├── .cursor
│   └── rules
│       ├── sae_evaluation.mdc
│       ├── sae_training.mdc
│       └── sparse_autoencoder_overview.mdc
├── .cursorrules
├── .github
│   └── workflows
│       ├── checks.yml
│       └── publish.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── README.md
├── examples
│   ├── generate_pythia_acts.py
│   ├── loading_llamascope_saes.ipynb
│   ├── replicate_llamascope_sae.sh
│   ├── train_pythia.py
│   └── train_pythia_pre_generated_acts.py
├── pyproject.toml
├── server
│   ├── .env.example
│   ├── __init__.py
│   └── app.py
├── src
│   └── lm_saes
│       ├── __init__.py
│       ├── abstract_sae.py
│       ├── activation
│       │   ├── __init__.py
│       │   ├── factory.py
│       │   ├── processors
│       │   │   ├── __init__.py
│       │   │   ├── activation.py
│       │   │   ├── cached_activation.py
│       │   │   ├── core.py
│       │   │   ├── huggingface.py
│       │   │   └── token.py
│       │   └── writer.py
│       ├── analysis
│       │   ├── __init__.py
│       │   ├── direct_logit_attributor.py
│       │   ├── feature_analyzer.py
│       │   └── feature_interpreter.py
│       ├── backend
│       │   ├── __init__.py
│       │   └── language_model.py
│       ├── circuit
│       │   ├── __init__.py
│       │   └── context.py
│       ├── clt.py
│       ├── config.py
│       ├── crosscoder.py
│       ├── database.py
│       ├── entrypoint.py
│       ├── evaluator.py
│       ├── initializer.py
│       ├── kernels.py
│       ├── optim.py
│       ├── resource_loaders.py
│       ├── runners
│       │   ├── __init__.py
│       │   ├── analyze.py
│       │   ├── autointerp.py
│       │   ├── eval.py
│       │   ├── generate.py
│       │   ├── train.py
│       │   └── utils.py
│       ├── sae.py
│       ├── trainer.py
│       └── utils
│           ├── __init__.py
│           ├── bytes.py
│           ├── concurrent.py
│           ├── config.py
│           ├── discrete.py
│           ├── distributed.py
│           ├── hooks.py
│           ├── huggingface.py
│           ├── logging.py
│           ├── math.py
│           ├── misc.py
│           ├── tensor_dict.py
│           └── timer.py
├── tests
│   ├── __init__.py
│   ├── integration
│   │   ├── test_activation_factory.py
│   │   └── test_train_sae.py
│   └── unit
│       ├── test_activation_processors.py
│       ├── test_activation_processors_distributed.py
│       ├── test_activation_writer.py
│       ├── test_clt.py
│       ├── test_clt_distributed.py
│       ├── test_concurrent.py
│       ├── test_crosscoder.py
│       ├── test_database.py
│       ├── test_discrete_mapper.py
│       ├── test_evaluator.py
│       ├── test_example.py
│       ├── test_feature_analyzer.py
│       ├── test_feature_interpreter.py
│       ├── test_hf_backend.py
│       ├── test_initializer.py
│       ├── test_misc.py
│       ├── test_mixcoder.py
│       ├── test_sae.py
│       └── test_util_distributed.py
└── ui
    ├── .env.example
    ├── .eslintrc.cjs
    ├── .gitignore
    ├── .prettierrc.cjs
    ├── README.md
    ├── bun.lockb
    ├── components.json
    ├── index.html
    ├── package.json
    ├── postcss.config.js
    ├── public
    │   ├── openmoss.ico
    │   └── vite.svg
    ├── src
    │   ├── components
    │   │   ├── app
    │   │   │   ├── feature-preview.tsx
    │   │   │   ├── navbar.tsx
    │   │   │   ├── sample.tsx
    │   │   │   ├── section-navigator.tsx
    │   │   │   └── token.tsx
    │   │   ├── attn-head
    │   │   │   └── attn-head-card.tsx
    │   │   ├── dictionary
    │   │   │   ├── dictionary-card.tsx
    │   │   │   └── sample.tsx
    │   │   ├── feature
    │   │   │   ├── feature-card.tsx
    │   │   │   ├── interpret.tsx
    │   │   │   └── sample.tsx
    │   │   ├── model
    │   │   │   ├── circuit.tsx
    │   │   │   └── model-card.tsx
    │   │   └── ui
    │   │       ├── accordion.tsx
    │   │       ├── badge.tsx
    │   │       ├── button.tsx
    │   │       ├── card.tsx
    │   │       ├── combobox.tsx
    │   │       ├── command.tsx
    │   │       ├── context-menu.tsx
    │   │       ├── data-table.tsx
    │   │       ├── dialog.tsx
    │   │       ├── dropdown-menu.tsx
    │   │       ├── hover-card.tsx
    │   │       ├── input.tsx
    │   │       ├── label.tsx
    │   │       ├── multiple-selector.tsx
    │   │       ├── pagination.tsx
    │   │       ├── popover.tsx
    │   │       ├── select.tsx
    │   │       ├── separator.tsx
    │   │       ├── switch.tsx
    │   │       ├── table.tsx
    │   │       ├── tabs.tsx
    │   │       ├── textarea.tsx
    │   │       ├── toggle.tsx
    │   │       └── tooltip.tsx
    │   ├── globals.css
    │   ├── lib
    │   │   └── utils.ts
    │   ├── main.tsx
    │   ├── routes
    │   │   ├── attn-heads
    │   │   │   └── page.tsx
    │   │   ├── bookmarks
    │   │   │   └── page.tsx
    │   │   ├── dictionaries
    │   │   │   └── page.tsx
    │   │   ├── features
    │   │   │   └── page.tsx
    │   │   ├── models
    │   │   │   └── page.tsx
    │   │   └── page.tsx
    │   ├── tanstack.d.ts
    │   ├── types
    │   │   ├── attn-head.ts
    │   │   ├── dictionary.ts
    │   │   ├── feature.ts
    │   │   └── model.ts
    │   ├── utils
    │   │   ├── array.ts
    │   │   ├── style.ts
    │   │   └── token.ts
    │   └── vite-env.d.ts
    ├── tailwind.config.js
    ├── tsconfig.json
    ├── tsconfig.node.json
    └── vite.config.ts

/.cursor/rules/sae_evaluation.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Evaluation overviews for Sparse Autoencoders
3 | globs:
4 | alwaysApply: false
5 | ---
6 | # SAE Evaluation and Analysis
7 | 
8 | Evaluating Sparse Autoencoders involves multiple metrics and analysis techniques to assess both quantitative performance and qualitative interpretability.
9 | 
10 | ## Evaluation Metrics
11 | 
12 | 1. **Sparsity-Fidelity Tradeoff**
13 |    - L0 Sparsity: Average number of active features per sample
14 |    - Explained Variance: Proportion of variance in the original activations explained by the SAE
15 |    - Delta LM Loss: Impact on language model loss when using reconstructed activations
16 | 
17 | 2. **Feature Characteristics**
18 |    - Activation Frequency: How often each feature activates across the dataset
19 |    - Monosemanticity: Whether each feature represents a single, interpretable concept
20 |    - Out-of-Distribution Generalization: Performance on contexts different from training data
21 | 
22 | The evaluation process is implemented in [src/lm_saes/evaluator.py](mdc:src/lm_saes/evaluator.py).
23 | 
24 | ## Analysis Tools
25 | 
26 | This codebase provides tools for analyzing SAEs through:
27 | 
28 | 1. **Feature Visualization**: Examining which tokens/inputs maximally activate each feature
29 | 2. **Feature Geometry**: Analyzing the relationships between features in the latent space
30 | 3. **Circuit Analysis**: Understanding how features interact in the context of the larger model
31 | 
32 | Analysis configurations can be specified through TOML files as shown in [examples/configuration/analyze.toml](mdc:examples/configuration/analyze.toml).
33 | 
34 | ## Visualization
35 | 
36 | Results from analysis can be visualized through a web interface:
37 | 
38 | 1. A FastAPI backend served with `uvicorn server.app:app`
39 | 2. A frontend for interactive exploration of features and their properties
40 | 
41 | The visualization makes it easier to identify patterns and interpret the meaning of learned features.
42 | 
--------------------------------------------------------------------------------
/.cursor/rules/sae_training.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description: Training overview for Sparse Autoencoders
3 | globs:
4 | alwaysApply: false
5 | ---
6 | # SAE Training Process
7 | 
8 | The training process for Sparse Autoencoders in this codebase focuses on reconstructing activations from language models while enforcing sparsity.
9 | 
10 | ## Training Pipeline
11 | 
12 | 1. **Activation Collection**: Extract activations from a pre-trained language model (e.g., Llama-3.1-8B)
13 | 2. **Initialization**: Initialize SAE parameters, often with decoder normalization
14 | 3. **Training**: Minimize reconstruction loss while enforcing sparsity
15 | 4. **Post-processing**: Transform SAEs to have desired properties (e.g., unit decoder norm)
16 | 
17 | The training process is implemented in [src/lm_saes/trainer.py](mdc:src/lm_saes/trainer.py).
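
To make the pipeline above concrete, here is a minimal, self-contained sketch of a single TopK-SAE training step on a batch of pre-collected activations. This is an illustration, not the actual `trainer.py` code; the tensor names are made up, and the sizes only mirror the Pythia example configs (`d_model=768`, `expansion_factor=8`, `top_k=50`).

```python
import torch
import torch.nn.functional as F

d_model, d_sae, k = 768, 768 * 8, 50  # hypothetical sizes
W_enc = torch.nn.Parameter(torch.randn(d_model, d_sae) * 0.01)
b_enc = torch.nn.Parameter(torch.zeros(d_sae))
W_dec = torch.nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
optimizer = torch.optim.Adam([W_enc, b_enc, W_dec], lr=4e-4)


def train_step(x: torch.Tensor) -> torch.Tensor:
    """Run one TopK-SAE training step on a batch of activations x: [batch, d_model]."""
    pre_acts = x @ W_enc + b_enc  # encode into the wider latent space
    values, indices = torch.topk(pre_acts, k=k, dim=-1)  # keep the k largest pre-activations
    feature_acts = torch.zeros_like(pre_acts).scatter(-1, indices, F.relu(values))
    x_hat = feature_acts @ W_dec  # decode back to d_model
    loss = F.mse_loss(x_hat, x)  # TopK enforces sparsity directly, so no L1 term here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():  # cf. step 2/4: keep decoder directions at unit norm
        W_dec /= W_dec.norm(dim=-1, keepdim=True)
    return loss
```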
18 | 
19 | ## Key Components
20 | 
21 | - **Activation Buffer**: Stores model activations generated on-the-fly during training
22 | - **Mixed Parallelism**: Combines data parallelism for activation generation with tensor parallelism for SAE training
23 | - **K-Annealing**: Schedule that gradually reduces the number of active features during training
24 | 
25 | ## Loss Functions
26 | 
27 | The loss function typically combines:
28 | 1. **Reconstruction Loss**: MSE between original activations and reconstructed activations
29 | 2. **Sparsity Loss**: L1 regularization or other mechanisms to promote sparsity
30 | 
31 | ## Training Configurations
32 | 
33 | Training configurations can be specified through TOML files as shown in [examples/configuration/train.toml](mdc:examples/configuration/train.toml).
34 | 
35 | ## Resource Considerations
36 | 
37 | - Training SAEs requires significant computational resources
38 | - The codebase includes optimizations for disk I/O and memory usage
39 | - Online activation generation eliminates the need for vast storage resources, making it more suitable for academic research
40 | 
--------------------------------------------------------------------------------
/.cursor/rules/sparse_autoencoder_overview.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description:
3 | globs:
4 | alwaysApply: true
5 | ---
6 | # Sparse Autoencoder Overview
7 | 
8 | Sparse Autoencoders (SAEs) are neural network models used to extract interpretable features from language models. They help address the superposition problem in neural networks by learning sparse, interpretable representations of activations.
9 | 
10 | ## Key Concepts
11 | 
12 | - **Superposition**: When a neural network represents multiple features in a single neuron, making interpretation difficult.
13 | - **Monosemanticity**: The desirable property where each feature represents exactly one concept.
14 | - **Sparsity**: The property where only a small subset of features activate for a given input, improving interpretability.
15 | 
16 | ## Architecture
17 | 
18 | A typical SAE consists of:
19 | - An **encoder** that maps model activations to a higher-dimensional latent space
20 | - A **decoder** that reconstructs the original activations from the latent space
21 | - A **sparsity mechanism** (often L1 regularization or TopK activation) that enforces sparse activations
22 | 
23 | The main implementation can be found in [src/lm_saes/sae.py](mdc:src/lm_saes/sae.py) with the abstract class defined in [src/lm_saes/abstract_sae.py](mdc:src/lm_saes/abstract_sae.py).
24 | 
25 | ## Types of SAEs
26 | 
27 | 1. **Vanilla SAEs**: Use ReLU activation + L1 regularization
28 | 2. **TopK SAEs**: Only retain the top K activations per sample, zeroing out the rest
29 | 3. **JumpReLU SAEs**: A variant with thresholded activation functions
30 | 
31 | ## Use Cases
32 | 
33 | - Mechanistic interpretability of language models
34 | - Discovering features and circuits in neural networks
35 | - Addressing model hallucination
36 | - Mitigating safety-relevant behaviors
37 | - Creating a more interpretable latent space
38 | 
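To make the three variants listed under "Types of SAEs" concrete, here is a minimal sketch of their sparsity mechanisms. The real implementations live in `sae.py`; the default `k` and `threshold` values are illustrative only, and in an actual JumpReLU SAE the threshold is a learned parameter.

```python
import torch


def vanilla(pre_acts: torch.Tensor) -> torch.Tensor:
    """Vanilla SAE: plain ReLU; sparsity is encouraged by an L1 penalty in the loss."""
    return torch.relu(pre_acts)


def topk(pre_acts: torch.Tensor, k: int = 50) -> torch.Tensor:
    """TopK SAE: keep the k largest pre-activations per sample, zero out the rest."""
    values, indices = torch.topk(pre_acts, k=k, dim=-1)
    return torch.zeros_like(pre_acts).scatter(-1, indices, torch.relu(values))


def jump_relu(pre_acts: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    """JumpReLU SAE: pass a value through only if it exceeds the threshold."""
    return pre_acts * (pre_acts > threshold)
```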
--------------------------------------------------------------------------------
/.cursorrules:
--------------------------------------------------------------------------------
1 | You are an expert in developing deep learning models in PyTorch.
2 | 
3 | Key principles:
4 | 
5 | - Add precise type hints and docstrings to all functions and classes. Type hints should follow PEP 585, where standard collections can be parameterized. In other words, use `list[int]` instead of `List[int]`.
6 | - When writing tests, use `pytest` and `pytest-mock`. Use `mocker` for mocking.
7 | - The project uses `uv` for dependency management. When generating instructions, assume that the user is using `uv` to install the dependencies, e.g. `uv add pydantic` to install `pydantic`, or `uv add --dev pytest` to install `pytest` as a dev dependency.
8 | - When writing docstrings for `dataclass`es or pydantic models, write each field's description right after the field.
9 | - Do not automatically run the code, since the code is designed to be run on remote GPU servers.
10 | 
--------------------------------------------------------------------------------
/.github/workflows/checks.yml:
--------------------------------------------------------------------------------
1 | name: Checks
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |       - dev
8 |     paths:
9 |       - "**" # Include all files by default
10 |       - "!.devcontainer/**"
11 |       - "!.vscode/**"
12 |       - "!.git*"
13 |       - "!*.md"
14 |       - "!.github/**"
15 |       - ".github/workflows/checks.yml" # Still include current workflow
16 |   pull_request:
17 |     branches:
18 |       - main
19 |       - dev
20 |     paths:
21 |       - "**"
22 |       - "!.devcontainer/**"
23 |       - "!.vscode/**"
24 |       - "!.git*"
25 |       - "!*.md"
26 |       - "!.github/**"
27 |       - ".github/workflows/checks.yml"
28 |   # Allow this workflow to be called from other workflows
29 |   workflow_call:
30 |     inputs:
31 |       # Requires at least one input to be valid, but in practice we don't need any
32 |       dummy:
33 |         type: string
34 |         required: false
35 | 
36 | permissions:
37 |   actions: write
38 |   contents: write
39 | 
40 | jobs:
41 |   type-checks:
42 |     name: Type Checks
43 |     runs-on: ubuntu-latest
44 |     steps:
45 |       - name: Checkout repository
46 |         uses: actions/checkout@v4
47 |         with:
48 |           submodules: "true"
49 |       - name: Install uv
50 |         uses: astral-sh/setup-uv@v5
51 |         with:
52 |           enable-cache: true
53 |           cache-dependency-glob: "uv.lock"
54 | 
55 |       - name: Install the project
56 |         run: uv sync --extra default --dev
57 | 
58 |       - name: Type check
59 |         run: uv run basedpyright .
60 | 
61 |       # - name: Unit tests
62 |       #   run: uv run pytest tests
63 | 
64 |   ruff:
65 |     runs-on: ubuntu-latest
66 |     steps:
67 |       - uses: actions/checkout@v4
68 |       - uses: astral-sh/ruff-action@v1
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: "Publish"
2 | 
3 | on:
4 |   release:
5 |     types: ["published"]
6 | 
7 | jobs:
8 |   run:
9 |     name: "Build and publish release"
10 |     runs-on: ubuntu-latest
11 | 
12 |     steps:
13 |       - uses: actions/checkout@v4
14 | 
15 |       - name: Install uv
16 |         uses: astral-sh/setup-uv@v3
17 |         with:
18 |           enable-cache: true
19 |           cache-dependency-glob: uv.lock
20 | 
21 |       - name: Set up Python
22 |         run: uv python install 3.12 # Pin the Python version used for the release build.
23 | 24 | - name: Build 25 | run: uv build 26 | 27 | - name: Publish 28 | run: uv publish -t ${{ secrets.PYPI_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | .pdm-python 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | # UI 179 | !ui/**/* 180 | 181 | # Custom 182 | /sync.sh 183 | /connect.sh 184 | /activations 185 | /data 186 | /legacy 187 | /run.py 188 | /checkpoints 189 | /wandb 190 | /exp 191 | /analysis-results 192 | /analysis 193 | /results 194 | .vscode/settings.json 195 | 196 | # uv 197 | # We ignore uv.lock as conflicting optional dependencies generates non-compatible uv.lock files. 198 | uv.lock -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "TransformerLens"] 2 | path = TransformerLens 3 | url = git@github.com:OpenMOSS/TransformerLens.git -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.11.7 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [--fix] 9 | # Run the formatter. 10 | - id: ruff-format 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language-Model-SAEs 2 | 3 | > [!IMPORTANT] 4 | > Currently the examples are outdated and some parallelism strategies are not working due to lack of bandwidth. We are working on better organizing recent updates and will make everything work ASAP. 5 | 6 | ## News 7 | 8 | - 2025.9.23 We leverage **Crosscoder** to track feature evolution across pre-training snapshots. Link: [Evolution of Concepts in Language Model Pre-Training](https://www.arxiv.org/abs/2509.17196). 9 | 10 | - 2025.8.23 We identify a prevalent low-rank structure in attention outputs as the key cause of dead features, and propose **Active Subspace Initialization** to improve sparse dictionary learning on these low-rank activations. 
Link: [Attention Layers Add Into Low-Dimensional Residual Subspaces](https://arxiv.org/abs/2508.16929).
11 | 
12 | - 2025.4.29 We introduce **Low-Rank Sparse Attention (Lorsa)** to attack attention superposition, extracting tens of thousands of true attention units from LLM attention layers. Link: [Towards Understanding the Nature of Attention with Low-Rank Sparse Decomposition](https://arxiv.org/abs/2504.20938).
13 | 
14 | - 2024.10.29 We introduce **Llama Scope**, our first contribution to the open-source Sparse Autoencoder ecosystem. Stay tuned! Link: [Llama Scope: Extracting Millions of Features from Llama-3.1-8B with Sparse Autoencoders](http://arxiv.org/abs/2410.20526).
15 | 
16 | - 2024.10.9 Transformers and Mambas are mechanistically similar at both the feature and circuit level. Can we follow this line and find **universal motifs and fundamental differences between language model architectures**? Link: [Towards Universality: Studying Mechanistic Similarity Across Language Model Architectures](https://arxiv.org/pdf/2410.06672).
17 | 
18 | - 2024.5.22 We propose hierarchical tracing, a promising method to **scale up sparse feature circuit analysis** to industrial-size language models! Link: [Automatically Identifying Local and Global Circuits with Linear Computation Graphs](https://arxiv.org/pdf/2405.13868).
19 | 
20 | - 2024.2.19 Our first attempt at SAE-based circuit analysis for Othello-GPT leads us to **an example of Attention Superposition in the wild**! Link: [Dictionary learning improves patch-free circuit discovery in mechanistic interpretability: A case study on othello-gpt](https://arxiv.org/pdf/2402.12201).
21 | 
22 | ## Installation
23 | 
24 | We use [uv](https://docs.astral.sh/uv/) to manage the dependencies, which is an alternative to [poetry](https://python-poetry.org/) or [pdm](https://pdm-project.org/). To install the required packages, just install [uv](https://docs.astral.sh/uv/getting-started/installation/), and run the following command:
25 | 
26 | ```bash
27 | uv sync --extra default
28 | ```
29 | 
30 | This will install all the required packages for the codebase in the `.venv` directory. For Ascend NPU support, run:
31 | 
32 | ```bash
33 | uv sync --extra npu
34 | ```
35 | 
36 | A forked version of `TransformerLens` is also included in the dependencies to provide the necessary tools for analyzing features.
37 | 
38 | If you want to use the visualization tools, you also need to install the required packages for the frontend, which uses [bun](https://bun.sh/) for dependency management. Follow the instructions on the website to install it, and then run the following command:
39 | 
40 | ```bash
41 | cd ui
42 | bun install
43 | ```
44 | 
45 | `bun` is not well-supported on Windows, so you may need to use WSL or other Linux-based solutions to run the frontend, or consider using a different package manager, such as `pnpm` or `yarn`.
46 | 
47 | ## Launch an Experiment
48 | 
49 | The guidelines and examples for launching experiments are generally outdated. For now, you may explore the `src/lm_saes/runners` folder for the interfaces for generating activations and for training and analyzing SAE variants. Analyzing SAEs requires a running MongoDB instance. More instructions will be provided in the near future.
50 | 
51 | ## Visualizing the Learned Dictionary
52 | 
53 | The analysis results will be saved using MongoDB, and you can use the provided visualization tools to visualize the learned dictionary. If you want to confirm that analysis results have actually landed in MongoDB, see the sanity-check snippet below.
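For example, a minimal check might look like this (the database name below is a placeholder, not necessarily what your configuration writes to):

```python
from pymongo import MongoClient  # pymongo is already a project dependency

client = MongoClient("mongodb://localhost:27017")  # should match MONGO_URI in server/.env
db = client["lm_saes"]  # hypothetical database name; use the one your analysis run targets
print(db.list_collection_names())  # non-empty once analysis results have been saved
```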
First, start the FastAPI server by running the following command: 54 | 55 | ```bash 56 | uvicorn server.app:app --port 24577 --env-file server/.env 57 | ``` 58 | 59 | Then, copy the `ui/.env.example` file to `ui/.env` and modify the `VITE_BACKEND_URL` to fit your server settings (by default, it's `http://localhost:24577`), and start the frontend by running the following command: 60 | 61 | ```bash 62 | cd ui 63 | bun dev --port 24576 64 | ``` 65 | 66 | That's it! You can now go to `http://localhost:24576` to visualize the learned dictionary and its features. 67 | 68 | ## Development 69 | 70 | We highly welcome contributions to this project. If you have any questions or suggestions, feel free to open an issue or a pull request. We are looking forward to hearing from you! 71 | 72 | TODO: Add development guidelines 73 | 74 | ## Acknowledgement 75 | 76 | The design of the pipeline (including the configuration and some training details) is highly inspired by the [mats_sae_training 77 | ](https://github.com/jbloomAus/mats_sae_training) project (now known as [SAELens](https://github.com/jbloomAus/SAELens)) and heavily relies on the [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) library. We thank the authors for their great work. 78 | 79 | ## Citation 80 | 81 | Please cite this library as: 82 | 83 | ``` 84 | @misc{Ge2024OpenMossSAEs, 85 | title = {OpenMoss Language Model Sparse Autoencoders}, 86 | author = {Xuyang Ge, Fukang Zhu, Junxuan Wang, Wentao Shu, Lingjie Chen, Zhengfu He}, 87 | url = {https://github.com/OpenMOSS/Language-Model-SAEs}, 88 | year = {2024} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /examples/generate_pythia_acts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryTarget, 5 | DatasetConfig, 6 | GenerateActivationsSettings, 7 | LanguageModelConfig, 8 | generate_activations, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | settings = GenerateActivationsSettings( 13 | model=LanguageModelConfig( 14 | model_name="EleutherAI/pythia-160m", 15 | device="cuda", 16 | dtype="torch.float32", 17 | ), 18 | model_name="pythia-160m", 19 | dataset=DatasetConfig( 20 | dataset_name_or_path="Skylion007/openwebtext", 21 | is_dataset_on_disk=True, 22 | ), 23 | dataset_name="openwebtext", 24 | hook_points=[f"blocks.{layer}.ln1.hook_normalized" for layer in range(12)], 25 | output_dir="activations", 26 | total_tokens=800_000_000, 27 | context_size=1024, 28 | n_samples_per_chunk=1, 29 | model_batch_size=32, 30 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 31 | batch_size=2048 * 16, 32 | buffer_size=2048 * 200, 33 | ) 34 | generate_activations(settings) 35 | torch.distributed.destroy_process_group() 36 | -------------------------------------------------------------------------------- /examples/replicate_llamascope_sae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | 4 | 5 | exp_factor=$1 6 | tc_in_abbr=$2 7 | layer=$3 8 | 9 | tc_out_abbr=$tc_in_abbr 10 | 11 | if [ "$exp_factor" -eq 128 ]; then 12 | tp_size=2 13 | else 14 | tp_size=1 15 | fi 16 | 17 | k=50 18 | 19 | 20 | if [ "$exp_factor" -eq 8 ]; then 21 | total_training_tokens=800000000 22 | lr=8e-4 23 | elif [ "$exp_factor" -eq 32 ]; then 24 | total_training_tokens=1600000000 25 | lr=8e-4 26 | elif [ "$exp_factor" -eq 128 ]; then 27 | total_training_tokens=3200000000 28 | lr=2e-4 29 | fi 30 | 31 
| if [ "$tc_in_abbr" = "TC" ]; then 32 | tc_out_abbr="M" 33 | fi 34 | 35 | WANDB_CONSOLE=off WANDB_MODE=offline torchrun --nproc_per_node=$tp_size --master_port=10110 ./examples/programmatic/train_llama_scope.py --total_training_tokens $total_training_tokens --layer $layer --lr $lr --clip_grad_norm 0.001 --exp_factor $exp_factor --batch_size 2048 --tp_size $tp_size --buffer_size 500000 --log_to_wandb false --store_batch_size 32 --k $k --tc_in_abbr $tc_in_abbr --tc_out_abbr $tc_out_abbr 36 | -------------------------------------------------------------------------------- /examples/train_pythia.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryConfig, 5 | ActivationFactoryDatasetSource, 6 | ActivationFactoryTarget, 7 | DatasetConfig, 8 | InitializerConfig, 9 | LanguageModelConfig, 10 | SAEConfig, 11 | TrainerConfig, 12 | TrainSAESettings, 13 | WandbConfig, 14 | train_sae, 15 | ) 16 | 17 | if __name__ == "__main__": 18 | settings = TrainSAESettings( 19 | sae=SAEConfig( 20 | hook_point_in="blocks.3.ln1.hook_normalized", 21 | hook_point_out="blocks.3.ln1.hook_normalized", 22 | d_model=768, 23 | expansion_factor=8, 24 | act_fn="topk", 25 | norm_activation="token-wise", 26 | sparsity_include_decoder_norm=True, 27 | top_k=50, 28 | dtype=torch.float32, 29 | device="cuda", 30 | ), 31 | initializer=InitializerConfig( 32 | init_search=True, 33 | state="training", 34 | ), 35 | trainer=TrainerConfig( 36 | lp=1, 37 | initial_k=768 / 2, 38 | lr=4e-4, 39 | lr_scheduler_name="constantwithwarmup", 40 | total_training_tokens=600_000_000, 41 | log_frequency=1000, 42 | eval_frequency=1000000, 43 | n_checkpoints=5, 44 | check_point_save_mode="linear", 45 | exp_result_path="results", 46 | ), 47 | wandb=WandbConfig( 48 | wandb_project="pythia-160m-test", 49 | exp_name="pythia-160m-test", 50 | ), 51 | activation_factory=ActivationFactoryConfig( 52 | sources=[ 53 | ActivationFactoryDatasetSource( 54 | name="openwebtext", 55 | ) 56 | ], 57 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 58 | hook_points=["blocks.3.ln1.hook_normalized"], 59 | batch_size=2048, 60 | buffer_size=None, 61 | ignore_token_ids=[], 62 | ), 63 | sae_name="pythia-160m-test-L3", 64 | sae_series="pythia-160m-test", 65 | model=LanguageModelConfig( 66 | model_name="EleutherAI/pythia-160m", 67 | device="cuda", 68 | dtype="torch.float32", 69 | ), 70 | model_name="pythia-160m", 71 | datasets={ 72 | "openwebtext": DatasetConfig( 73 | dataset_name_or_path="Skylion007/openwebtext", 74 | ) 75 | }, 76 | ) 77 | train_sae(settings) 78 | -------------------------------------------------------------------------------- /examples/train_pythia_pre_generated_acts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lm_saes import ( 4 | ActivationFactoryActivationsSource, 5 | ActivationFactoryConfig, 6 | ActivationFactoryTarget, 7 | InitializerConfig, 8 | SAEConfig, 9 | TrainerConfig, 10 | TrainSAESettings, 11 | WandbConfig, 12 | train_sae, 13 | ) 14 | 15 | if __name__ == "__main__": 16 | settings = TrainSAESettings( 17 | sae=SAEConfig( 18 | hook_point_in="blocks.3.ln1.hook_normalized", 19 | hook_point_out="blocks.3.ln1.hook_normalized", 20 | d_model=768, 21 | expansion_factor=8, 22 | act_fn="topk", 23 | norm_activation="token-wise", 24 | sparsity_include_decoder_norm=True, 25 | top_k=50, 26 | dtype=torch.float32, 27 | device="cuda", 28 | ), 29 | initializer=InitializerConfig( 30 | 
init_search=True, 31 | state="training", 32 | ), 33 | trainer=TrainerConfig( 34 | lp=1, 35 | initial_k=768 / 2, 36 | lr=4e-4, 37 | lr_scheduler_name="constantwithwarmup", 38 | total_training_tokens=600_000_000, 39 | log_frequency=1000, 40 | eval_frequency=1000000, 41 | n_checkpoints=5, 42 | check_point_save_mode="linear", 43 | exp_result_path="results", 44 | ), 45 | wandb=WandbConfig( 46 | wandb_project="pythia-160m-test", 47 | exp_name="pythia-160m-test", 48 | ), 49 | activation_factory=ActivationFactoryConfig( 50 | sources=[ 51 | ActivationFactoryActivationsSource( 52 | name="openwebtext", 53 | path="activations", 54 | sample_weights=1.0, 55 | device="cuda", 56 | ) 57 | ], 58 | target=ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D, 59 | hook_points=["blocks.3.ln1.hook_normalized"], 60 | batch_size=2048, 61 | buffer_size=None, 62 | ignore_token_ids=[], 63 | ), 64 | sae_name="pythia-160m-test-L3", 65 | sae_series="pythia-160m-test", 66 | ) 67 | train_sae(settings) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lm-saes" 3 | version = "0.1.0" 4 | description = "For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research. Open-sourced and constantly updated." 5 | dependencies = [ 6 | "datasets>=3.0.2", 7 | "transformers>=4.46.0", 8 | "einops>=0.8.0", 9 | "fastapi>=0.115.4", 10 | "matplotlib>=3.9.2", 11 | "numpy<2.0.0", 12 | "pandas>=2.2.3", 13 | "pymongo>=4.10.1", 14 | "tensorboardX>=2.6.2.2", 15 | # "torch>=2.5.0", 16 | # "torchvision>=0.20.1", 17 | "transformer-lens", 18 | "uvicorn>=0.32.0", 19 | "wandb>=0.18.5", 20 | "msgpack>=1.1.0", 21 | "plotly>=5.24.1", 22 | "openai>=1.52.2", 23 | "tiktoken>=0.8.0", 24 | "python-dotenv>=1.0.1", 25 | "jaxtyping>=0.2.34", 26 | "safetensors>=0.4.5", 27 | "pydantic>=2.10.6", 28 | "argparse>=1.4.0", 29 | "pyyaml>=6.0.2", 30 | "tomlkit>=0.13.2", 31 | "pydantic-settings>=2.7.1", 32 | "typing-extensions>=4.13.2", 33 | "more-itertools>=10.7.0", 34 | "json-repair>=0.44.1", 35 | ] 36 | requires-python = "==3.11.*" 37 | readme = "README.md" 38 | 39 | [[project.authors]] 40 | name = "Xuyang Ge" 41 | email = "xyge20@fudan.edu.cn" 42 | 43 | [[project.authors]] 44 | name = "Zhengfu He" 45 | email = "zfhe19@fudan.edu.cn" 46 | 47 | [[project.authors]] 48 | name = "Wentao Shu" 49 | email = "wtshu20@fudan.edu.cn" 50 | 51 | [[project.authors]] 52 | name = "Fukang Zhu" 53 | email = "fkzhu21@m.fudan.edu.cn" 54 | 55 | [[project.authors]] 56 | name = "Lingjie Chen" 57 | email = "ljchen21@m.fudan.edu.cn" 58 | 59 | [[project.authors]] 60 | name = "Junxuan Wang" 61 | email = "junxuanwang21@m.fudan.edu.cn" 62 | 63 | [project.license] 64 | text = "MIT" 65 | 66 | [project.scripts] 67 | lm-saes = "lm_saes.entrypoint:entrypoint" 68 | 69 | [dependency-groups] 70 | dev = [ 71 | "transformer-lens", 72 | "jupyter>=1.1.1", 73 | "ipywidgets>=8.1.5", 74 | "pytest>=8.3.3", 75 | "ipykernel>=6.29.5", 76 | "nbformat>=5.10.4", 77 | "kaleido==0.2.1", 78 | "pre-commit>=4.0.1", 79 | "ruff>=0.7.1", 80 | "basedpyright>=1.21.0", 81 | "scikit-learn>=1.6.0", 82 | "plotly>=5.24.1", 83 | "pandas>=2.2.3", 84 | "pytest-mock>=3.14.0", 85 | "typeguard>=4.4.1", 86 | "pyfakefs>=5.7.3", 87 | "mongomock>=4.3.0", 88 | "qwen-vl-utils>=0.0.10", 89 | "tabulate>=0.9.0", 90 | "gradio>=5.34.0", 91 | ] 92 | flash-attn = [ 93 | "flash-attn>=2.7.4.post1; (sys_platform == 'win32' or sys_platform == 'linux')", 94 | ] 95 | 96 | 
[project.optional-dependencies] 97 | # NPU variant of PyTorch (for Huawei Ascend hardware) 98 | default = ["torch==2.7.1", "torchvision>=0.22.1"] 99 | 100 | npu = [ 101 | "torch==2.6.0", 102 | "torchvision>=0.20.1", 103 | "torch-npu==2.6.0rc1", 104 | ] 105 | 106 | triton = ["triton"] 107 | 108 | [tool.uv.sources] 109 | torch = [{ index = "torch-cpu", extra = "npu" }] 110 | torchvision = [{ index = "torch-cpu", extra = "npu" }] 111 | 112 | [[tool.uv.index]] 113 | name = "torch-cpu" 114 | url = "https://download.pytorch.org/whl/cpu" 115 | explicit = true 116 | 117 | [tool.ruff] 118 | exclude = [ 119 | ".bzr", 120 | ".direnv", 121 | ".eggs", 122 | ".git", 123 | ".git-rewrite", 124 | ".hg", 125 | ".ipynb_checkpoints", 126 | ".mypy_cache", 127 | ".nox", 128 | ".pants.d", 129 | ".pyenv", 130 | ".pytest_cache", 131 | ".pytype", 132 | ".ruff_cache", 133 | ".svn", 134 | ".tox", 135 | ".venv", 136 | ".vscode", 137 | "__pypackages__", 138 | "_build", 139 | "buck-out", 140 | "build", 141 | "dist", 142 | "node_modules", 143 | "site-packages", 144 | "venv", 145 | "TransformerLens", 146 | "ui", 147 | ] 148 | line-length = 120 149 | indent-width = 4 150 | target-version = "py311" 151 | 152 | [tool.ruff.lint] 153 | select = ["E4", "E7", "E9", "F", "I"] 154 | ignore = ["E741", "F722"] 155 | fixable = ["ALL"] 156 | unfixable = [] 157 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 158 | 159 | [tool.ruff.format] 160 | quote-style = "double" 161 | indent-style = "space" 162 | skip-magic-trailing-comma = false 163 | line-ending = "auto" 164 | docstring-code-format = false 165 | docstring-code-line-length = "dynamic" 166 | 167 | [tool.pyright] 168 | ignore = [".venv/", "examples", "TransformerLens", "tests", "exp"] 169 | typeCheckingMode = "standard" 170 | reportRedeclaration = false 171 | reportPrivateImportUsage = false 172 | 173 | [tool.uv] 174 | package = true 175 | no-build-isolation-package = ["flash-attn"] 176 | conflicts = [[{ extra = "default" }, { extra = "npu" }], [{ extra = "triton" }, { extra = "npu" }],] 177 | 178 | [[tool.uv.dependency-metadata]] 179 | name = "flash-attn" 180 | version = "2.7.4.post1" 181 | requires-dist = ["torch", "einops"] 182 | 183 | [tool.uv.sources.transformer-lens] 184 | path = "./TransformerLens" 185 | editable = true -------------------------------------------------------------------------------- /server/.env.example: -------------------------------------------------------------------------------- 1 | MONGO_URI= # Must fill in 2 | SAE_SERIES= # Must fill in -------------------------------------------------------------------------------- /server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/server/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import ActivationFactory, ActivationWriter 2 | from .analysis import FeatureAnalyzer 3 | from .clt import CrossLayerTranscoder 4 | from .config import ( 5 | ActivationFactoryActivationsSource, 6 | ActivationFactoryConfig, 7 | ActivationFactoryDatasetSource, 8 | ActivationFactoryTarget, 9 | ActivationWriterConfig, 10 | BufferShuffleConfig, 11 | CLTConfig, 12 | CrossCoderConfig, 13 | DatasetConfig, 14 | DirectLogitAttributorConfig, 15 | FeatureAnalyzerConfig, 16 | InitializerConfig, 17 | 
LanguageModelConfig, 18 | LLaDAConfig, 19 | MongoDBConfig, 20 | SAEConfig, 21 | TrainerConfig, 22 | WandbConfig, 23 | ) 24 | from .crosscoder import CrossCoder 25 | from .database import MongoClient 26 | from .evaluator import EvalConfig, Evaluator 27 | from .resource_loaders import load_dataset, load_model 28 | from .runners import ( 29 | AnalyzeCrossCoderSettings, 30 | AnalyzeSAESettings, 31 | AutoInterpSettings, 32 | CheckActivationConsistencySettings, 33 | DirectLogitAttributeSettings, 34 | EvaluateCrossCoderSettings, 35 | EvaluateSAESettings, 36 | GenerateActivationsSettings, 37 | SweepingItem, 38 | SweepSAESettings, 39 | TrainCLTSettings, 40 | TrainCrossCoderSettings, 41 | TrainSAESettings, 42 | analyze_crosscoder, 43 | analyze_sae, 44 | auto_interp, 45 | check_activation_consistency, 46 | direct_logit_attribute, 47 | evaluate_crosscoder, 48 | evaluate_sae, 49 | generate_activations, 50 | sweep_sae, 51 | train_clt, 52 | train_crosscoder, 53 | train_sae, 54 | ) 55 | from .sae import SparseAutoEncoder 56 | 57 | __all__ = [ 58 | "ActivationFactory", 59 | "ActivationWriter", 60 | "CLTConfig", 61 | "CrossLayerTranscoder", 62 | "CrossCoderConfig", 63 | "CrossCoder", 64 | "SparseAutoEncoder", 65 | "LanguageModelConfig", 66 | "DatasetConfig", 67 | "ActivationFactoryActivationsSource", 68 | "ActivationFactoryDatasetSource", 69 | "ActivationFactoryConfig", 70 | "ActivationWriterConfig", 71 | "BufferShuffleConfig", 72 | "ActivationFactoryTarget", 73 | "load_dataset", 74 | "load_model", 75 | "FeatureAnalyzer", 76 | "EvaluateCrossCoderSettings", 77 | "evaluate_crosscoder", 78 | "EvaluateSAESettings", 79 | "Evaluator", 80 | "EvalConfig", 81 | "evaluate_sae", 82 | "GenerateActivationsSettings", 83 | "generate_activations", 84 | "CheckActivationConsistencySettings", 85 | "check_activation_consistency", 86 | "InitializerConfig", 87 | "SAEConfig", 88 | "TrainerConfig", 89 | "WandbConfig", 90 | "train_sae", 91 | "TrainSAESettings", 92 | "TrainCLTSettings", 93 | "train_clt", 94 | "AnalyzeSAESettings", 95 | "analyze_sae", 96 | "FeatureAnalyzerConfig", 97 | "MongoDBConfig", 98 | "MongoClient", 99 | "AnalyzeCrossCoderSettings", 100 | "analyze_crosscoder", 101 | "AutoInterpSettings", 102 | "SweepingItem", 103 | "SweepSAESettings", 104 | "TrainCrossCoderSettings", 105 | "auto_interp", 106 | "sweep_sae", 107 | "LLaDAConfig", 108 | "train_crosscoder", 109 | "DirectLogitAttributeSettings", 110 | "direct_logit_attribute", 111 | "DirectLogitAttributorConfig", 112 | ] 113 | -------------------------------------------------------------------------------- /src/lm_saes/activation/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ActivationFactory 2 | from .processors import ( 3 | ActivationBatchler, 4 | BaseActivationProcessor, 5 | HuggingFaceDatasetLoader, 6 | PadAndTruncateTokensProcessor, 7 | RawDatasetTokenProcessor, 8 | ) 9 | from .writer import ActivationWriter 10 | 11 | __all__ = [ 12 | "ActivationFactory", 13 | "BaseActivationProcessor", 14 | "ActivationBatchler", 15 | "HuggingFaceDatasetLoader", 16 | "PadAndTruncateTokensProcessor", 17 | "RawDatasetTokenProcessor", 18 | "ActivationWriter", 19 | ] 20 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import ActivationBatchler 2 | from .cached_activation import CachedActivationLoader 3 | from .core import 
BaseActivationProcessor 4 | from .huggingface import HuggingFaceDatasetLoader 5 | from .token import PadAndTruncateTokensProcessor, RawDatasetTokenProcessor 6 | 7 | __all__ = [ 8 | "BaseActivationProcessor", 9 | "ActivationBatchler", 10 | "HuggingFaceDatasetLoader", 11 | "PadAndTruncateTokensProcessor", 12 | "RawDatasetTokenProcessor", 13 | "CachedActivationLoader", 14 | ] 15 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/core.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, Iterable, TypeVar 3 | 4 | # Define type variables for input and output iterable containers 5 | InputT = TypeVar("InputT", bound=Iterable | None) 6 | OutputT = TypeVar("OutputT", bound=Iterable) 7 | 8 | 9 | class BaseActivationProcessor(Generic[InputT, OutputT], ABC): 10 | """Base class for activation processors. 11 | 12 | An activation processor transforms a stream of input data into a stream of output data. 13 | Common use cases include batching, filtering, and transforming activation data from language models. 14 | 15 | The processor can be used either by calling its `process()` method directly or by using the instance 16 | as a callable via `__call__()`. 17 | 18 | TypeVars: 19 | InputT: The type of input iterable container (e.g. List, Generator, Dataset) 20 | OutputT: The type of output iterable container (e.g. DataLoader, Generator) 21 | """ 22 | 23 | @abstractmethod 24 | def process(self, data: InputT, **kwargs) -> OutputT: 25 | """Process the input data stream and return transformed output stream. 26 | 27 | Args: 28 | data: Input data stream to process 29 | **kwargs: Additional keyword arguments for processing 30 | 31 | Returns: 32 | Processed output data stream 33 | 34 | Raises: 35 | NotImplementedError: This is an abstract method that must be implemented by subclasses 36 | """ 37 | raise NotImplementedError 38 | 39 | def __call__(self, data: InputT, **kwargs) -> OutputT: 40 | """Process data by calling the processor instance directly. 41 | 42 | This is a convenience wrapper around the process() method. 43 | 44 | Args: 45 | data: Input data stream to process 46 | **kwargs: Additional keyword arguments for processing 47 | Returns: 48 | Processed output data stream 49 | """ 50 | return self.process(data, **kwargs) 51 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/huggingface.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Iterable, Optional, cast 3 | 4 | import torch 5 | from datasets import Dataset 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | from lm_saes.activation.processors.core import BaseActivationProcessor 10 | 11 | 12 | def identity_collate_fn(x): 13 | """Identity collate function that returns the input as is. 14 | 15 | This should be defined in the global scope so it can be pickled when multiprocessing with `num_workers > 0`. 16 | """ 17 | return x 18 | 19 | 20 | class HuggingFaceDatasetLoader(BaseActivationProcessor[Dataset, Iterable[dict[str, Any]]]): 21 | """A processor that directly loads raw dataset from HuggingFace datasets. 22 | 23 | This processor takes a HuggingFace dataset and converts it into an iterable of raw data records. 24 | It can optionally add context index information to each data record and show a progress bar. 
25 | 26 | Args: 27 | batch_size (int): Number of samples per batch. The batch is only used for loading the dataset. 28 | The returned data is always a flattened list of raw data records. 29 | num_workers (int, optional): Number of workers to use for loading the dataset. 30 | Defaults to 0. Use a larger `num_workers` if you have a large dataset and want to speed up loading. 31 | with_info (bool, optional): Whether to include context index information with each batch. 32 | If True, returns tuples of (batch, info) where info contains context indices. This context 33 | index is used to identify the original sample in the dataset. 34 | Defaults to False. 35 | show_progress (bool, optional): Whether to display a progress bar while loading. 36 | Defaults to True. 37 | 38 | Returns: 39 | Iterable: An iterator over batches from the dataset. If with_info=True, yields tuples of 40 | (batch, info) where info contains context indices for each sample in the batch. 41 | """ 42 | 43 | def __init__(self, batch_size: int, num_workers: int = 0, with_info: bool = False, show_progress: bool = True): 44 | self.batch_size = batch_size 45 | self.num_workers = num_workers 46 | self.with_info = with_info 47 | self.show_progress = show_progress 48 | 49 | def process( 50 | self, 51 | data: Dataset, 52 | *, 53 | dataset_name: str | None = None, 54 | metadata: Optional[dict[str, Any]] = None, 55 | **kwargs, 56 | ) -> Iterable[dict[str, Any]]: 57 | """Process the input dataset into batches. 58 | 59 | Args: 60 | data (Dataset): Input HuggingFace dataset to process 61 | dataset_name (str, optional): Name of the dataset. If provided, it will be added to the info field. 62 | Defaults to None. 63 | metadata (dict[str, Any], optional): Metadata to add to each batch. Defaults to None. 64 | **kwargs: Additional keyword arguments for processing. Not used by this processor. 65 | 66 | Returns: 67 | Iterable: Iterator over batches, optionally with context info if with_info=True 68 | """ 69 | 70 | dataloader = cast( 71 | Iterable[list[dict[str, Any]]], 72 | DataLoader( 73 | cast(torch.utils.data.Dataset, data), 74 | batch_size=self.batch_size, 75 | shuffle=False, 76 | pin_memory=True, 77 | collate_fn=identity_collate_fn, 78 | num_workers=self.num_workers, 79 | ), 80 | ) 81 | 82 | if self.show_progress: 83 | dataloader = tqdm(dataloader, desc="Loading dataset") 84 | 85 | flattened = itertools.chain.from_iterable(dataloader) 86 | 87 | if self.with_info: 88 | flattened = map( 89 | lambda x: x[1] 90 | | { 91 | "meta": { 92 | "context_idx": x[0], 93 | **({"dataset_name": dataset_name} if dataset_name else {}), 94 | **(metadata if metadata else {}), 95 | } 96 | }, 97 | enumerate(flattened), 98 | ) 99 | 100 | return flattened 101 | -------------------------------------------------------------------------------- /src/lm_saes/activation/processors/token.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional 2 | 3 | import torch 4 | 5 | from lm_saes.activation.processors.core import BaseActivationProcessor 6 | from lm_saes.backend.language_model import LanguageModel 7 | from lm_saes.utils.misc import pad_and_truncate_tokens 8 | 9 | 10 | class RawDatasetTokenProcessor(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]): 11 | """Processor for converting raw token datasets into model-ready token format. 12 | 13 | This processor takes an iterable of dictionaries containing raw data (e.g. text and images) and converts 14 | them into a tokens. 
The output is a dictionary with a "tokens" key, which contains the (non-padded and non-truncated)
15 |     tokens. The "meta" key is preserved if it exists in the input.
16 | 
17 |     Args:
18 |         prepend_bos: Whether to prepend beginning-of-sequence token. If None, uses model default.
19 |     """
20 | 
21 |     def __init__(self, prepend_bos: bool | None = None):
22 |         self.prepend_bos = prepend_bos
23 | 
24 |     def process(self, data: Iterable[dict[str, Any]], *, model: LanguageModel, **kwargs) -> Iterable[dict[str, Any]]:
25 |         """Process raw data into tokens.
26 | 
27 |         Args:
28 |             data: Iterable of dictionaries containing raw data (e.g. text and images)
29 |             model: model to use for producing tokens
30 |             **kwargs: Additional keyword arguments. Not used by this processor.
31 | 
32 |         Yields:
33 |             dict: Processed token data with optional info field
34 |         """
35 |         for d in data:
36 |             # TODO: support TransformerLens backend
37 |             tokens = model.to_tokens(d, prepend_bos=self.prepend_bos)  # type: ignore
38 |             ret = {"tokens": tokens[0]}
39 |             if "meta" in d:
40 |                 ret = ret | {"meta": d["meta"]}
41 |             yield ret
42 | 
43 | 
44 | class PadAndTruncateTokensProcessor(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]):
45 |     """Processor for padding and truncating tokens to a desired sequence length.
46 | 
47 |     This processor takes an iterable of dictionaries containing tokens and pads or truncates them to a desired sequence length.
48 |     The output is a dictionary with a "tokens" key, which contains the padded tokens.
49 | 
50 |     Args:
51 |         seq_len (int): The desired sequence length to pad/truncate to.
52 |             Note: the pad token ID is not set here; it is resolved in `process()` (from the model's tokenizer, or 0).
53 |     """
54 | 
55 |     def __init__(self, seq_len: int):
56 |         self.seq_len = seq_len
57 | 
58 |     def process(
59 |         self,
60 |         data: Iterable[dict[str, Any]],
61 |         *,
62 |         pad_token_id: Optional[int] = None,
63 |         model: Optional[LanguageModel] = None,
64 |         **kwargs,
65 |     ) -> Iterable[dict[str, Any]]:
66 |         """Process tokens by padding or truncating to desired sequence length.
67 | 
68 |         Args:
69 |             data (Iterable[dict[str, Any]]): Input data containing tokens to process
70 |             pad_token_id (int, optional): The token ID to use for padding. Defaults to None.
71 |                 If not specified, the pad_token_id will be inferred from the model's tokenizer.
72 |                 If neither is provided, the pad_token_id will be 0.
73 |             model (LanguageModel, optional): The model to use for padding. Defaults to None.
74 |                 If provided, the pad_token_id will be inferred from the model's tokenizer.
75 |             **kwargs: Additional keyword arguments. Not used by this processor.
76 | 77 | Yields: 78 | dict[str, Any]: Dictionary containing processed tokens padded/truncated to seq_len, 79 | and original info field if present 80 | """ 81 | 82 | # Infer pad_token_id if not provided 83 | if pad_token_id is None: 84 | if model is not None: 85 | pad_token_id = model.pad_token_id 86 | else: 87 | pad_token_id = 0 88 | 89 | for d in data: 90 | assert "tokens" in d and isinstance(d["tokens"], torch.Tensor) 91 | assert pad_token_id is not None, "pad_token_id must be provided" 92 | tokens = pad_and_truncate_tokens(d["tokens"], seq_len=self.seq_len, pad_token_id=pad_token_id) 93 | ret = {"tokens": tokens} 94 | if "meta" in d: 95 | ret = ret | {"meta": d["meta"]} 96 | yield ret 97 | -------------------------------------------------------------------------------- /src/lm_saes/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_analyzer import FeatureAnalyzer 2 | from .feature_interpreter import ( 3 | AutoInterpConfig, 4 | ExplainerType, 5 | FeatureInterpreter, 6 | ScorerType, 7 | TokenizedSample, 8 | ) 9 | 10 | __all__ = [ 11 | "FeatureAnalyzer", 12 | "FeatureInterpreter", 13 | "AutoInterpConfig", 14 | "TokenizedSample", 15 | "ExplainerType", 16 | "ScorerType", 17 | ] 18 | -------------------------------------------------------------------------------- /src/lm_saes/analysis/direct_logit_attributor.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from transformer_lens import HookedTransformer 4 | 5 | from lm_saes.abstract_sae import AbstractSparseAutoEncoder 6 | from lm_saes.backend import LanguageModel 7 | from lm_saes.backend.language_model import TransformerLensLanguageModel 8 | from lm_saes.config import DirectLogitAttributorConfig 9 | from lm_saes.crosscoder import CrossCoder 10 | from lm_saes.sae import SparseAutoEncoder 11 | 12 | 13 | class DirectLogitAttributor: 14 | def __init__(self, cfg: DirectLogitAttributorConfig): 15 | self.cfg = cfg 16 | 17 | @torch.no_grad() 18 | def direct_logit_attribute(self, sae: AbstractSparseAutoEncoder, model: LanguageModel): 19 | assert isinstance(model, TransformerLensLanguageModel), ( 20 | "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend" 21 | ) 22 | model: HookedTransformer | None = model.model 23 | assert model is not None, "Model ckpt must be loaded for direct logit attribution" 24 | 25 | if isinstance(sae, CrossCoder): 26 | residual = sae.W_D[-1] 27 | elif isinstance(sae, SparseAutoEncoder): 28 | residual = sae.W_D 29 | else: 30 | raise ValueError(f"Unsupported SAE type: {type(sae)}") 31 | 32 | residual = einops.rearrange(residual, "batch d_model -> batch 1 d_model") # Add a context dimension 33 | 34 | if model.cfg.normalization_type is not None: 35 | residual = model.ln_final(residual) # [batch, pos, d_model] 36 | logits = model.unembed(residual) # [batch, pos, d_vocab] 37 | logits = einops.rearrange(logits, "batch 1 d_vocab -> batch d_vocab") # Remove the context dimension 38 | 39 | # Select the top k tokens 40 | top_k_logits, top_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1) 41 | top_k_tokens = [model.to_str_tokens(top_k_indices[i]) for i in range(sae.cfg.d_sae)] 42 | 43 | assert top_k_logits.shape == top_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), ( 44 | f"Top k logits and indices should have shape (d_sae, top_k), but got {top_k_logits.shape} and {top_k_indices.shape}" 45 | ) 46 | assert (len(top_k_tokens), len(top_k_tokens[0])) == 
(sae.cfg.d_sae, self.cfg.top_k), ( 47 | f"Top k tokens should have shape (d_sae, top_k), but got {len(top_k_tokens)} and {len(top_k_tokens[0])}" 48 | ) 49 | 50 | # Select the bottom k tokens 51 | bottom_k_logits, bottom_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1, largest=False) 52 | bottom_k_tokens = [model.to_str_tokens(bottom_k_indices[i]) for i in range(sae.cfg.d_sae)] 53 | 54 | assert bottom_k_logits.shape == bottom_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), ( 55 | f"Bottom k logits and indices should have shape (d_sae, top_k), but got {bottom_k_logits.shape} and {bottom_k_indices.shape}" 56 | ) 57 | assert (len(bottom_k_tokens), len(bottom_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), ( 58 | f"Bottom k tokens should have shape (d_sae, top_k), but got {len(bottom_k_tokens)} and {len(bottom_k_tokens[0])}" 59 | ) 60 | 61 | result = [ 62 | { 63 | "top_positive": [ 64 | {"token": token, "logit": logit} for token, logit in zip(top_k_tokens[i], top_k_logits[i].tolist()) 65 | ], 66 | "top_negative": [ 67 | {"token": token, "logit": logit} 68 | for token, logit in zip(bottom_k_tokens[i], bottom_k_logits[i].tolist()) 69 | ], 70 | } 71 | for i in range(sae.cfg.d_sae) 72 | ] 73 | return result 74 | -------------------------------------------------------------------------------- /src/lm_saes/backend/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model import LanguageModel 2 | 3 | __all__ = ["LanguageModel"] 4 | -------------------------------------------------------------------------------- /src/lm_saes/circuit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/src/lm_saes/circuit/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/circuit/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Callable, Tuple, Union 3 | 4 | import torch 5 | from transformer_lens.hook_points import HookedRootModule, HookPoint 6 | 7 | from ..sae import SparseAutoEncoder 8 | 9 | 10 | @contextmanager 11 | def apply_sae(model: HookedRootModule, saes: list[SparseAutoEncoder]): 12 | """ 13 | Apply the sparse autoencoders to the model. 14 | """ 15 | fwd_hooks: list[Tuple[Union[str, Callable], Callable]] = [] 16 | 17 | def get_fwd_hooks(sae: SparseAutoEncoder) -> list[Tuple[Union[str, Callable], Callable]]: 18 | if sae.cfg.hook_point_in == sae.cfg.hook_point_out: 19 | 20 | def hook(tensor: torch.Tensor, hook: HookPoint): 21 | reconstructed = sae.forward(tensor) 22 | return reconstructed + (tensor - reconstructed).detach() 23 | 24 | return [(sae.cfg.hook_point_in, hook)] 25 | else: 26 | x = None 27 | 28 | def hook_in(tensor: torch.Tensor, hook: HookPoint): 29 | nonlocal x 30 | x = tensor 31 | return tensor 32 | 33 | def hook_out(tensor: torch.Tensor, hook: HookPoint): 34 | nonlocal x 35 | assert x is not None, "hook_in must be called before hook_out." 
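            # Note on the lines below: the return value uses a straight-through trick.
            # Adding the *detached* reconstruction error back means the forward value stays
            # numerically equal to the original hook_point_out activation, while gradients
            # flow only through the SAE reconstruction path.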
36 | reconstructed = sae.forward(x) 37 | x = None 38 | return reconstructed + (tensor - reconstructed).detach() 39 | 40 | return [(sae.cfg.hook_point_in, hook_in), (sae.cfg.hook_point_out, hook_out)] 41 | 42 | for sae in saes: 43 | hooks = get_fwd_hooks(sae) 44 | fwd_hooks.extend(hooks) 45 | with model.mount_hooked_modules([(sae.cfg.hook_point_out, "sae", sae) for sae in saes]): 46 | with model.hooks(fwd_hooks): 47 | yield model 48 | 49 | 50 | @contextmanager 51 | def detach_at( 52 | model: HookedRootModule, 53 | hook_points: list[str], 54 | ): 55 | """ 56 | Detach the gradients on the given hook points. 57 | """ 58 | 59 | def generate_hook(): 60 | hook_pre = HookPoint() 61 | hook_post = HookPoint() 62 | 63 | def hook(tensor: torch.Tensor, hook: HookPoint): 64 | return hook_post(hook_pre(tensor).detach().requires_grad_()) 65 | 66 | return hook_pre, hook_post, hook 67 | 68 | hooks = {hook_point: generate_hook() for hook_point in hook_points} 69 | fwd_hooks = [(hook_point, hook) for hook_point, (_, _, hook) in hooks.items()] 70 | with model.mount_hooked_modules( 71 | [(hook_point, "pre", hook) for hook_point, (hook, _, _) in hooks.items()] 72 | + [(hook_point, "post", hook) for hook_point, (_, hook, _) in hooks.items()] 73 | ): 74 | with model.hooks(fwd_hooks): 75 | yield model 76 | -------------------------------------------------------------------------------- /src/lm_saes/resource_loaders.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Optional, cast 2 | 3 | import datasets 4 | import torch 5 | from torch.distributed.device_mesh import DeviceMesh 6 | 7 | from lm_saes.backend.language_model import ( 8 | LanguageModel, 9 | LLaDALanguageModel, 10 | QwenLanguageModel, 11 | QwenVLLanguageModel, 12 | TransformerLensLanguageModel, 13 | ) 14 | from lm_saes.config import DatasetConfig, LanguageModelConfig, LLaDAConfig 15 | 16 | 17 | def dataset_transform(data): 18 | if "image" in data: 19 | # Rename image to images 20 | data["images"] = data["image"] 21 | del data["image"] 22 | 23 | data["images"] = [[torch.tensor(image) for image in images] for images in data["images"]] 24 | return data 25 | 26 | 27 | def load_dataset_shard( 28 | cfg: DatasetConfig, 29 | shard_idx: int, 30 | n_shards: int, 31 | ) -> datasets.Dataset: 32 | if not cfg.is_dataset_on_disk: 33 | dataset = datasets.load_dataset(cfg.dataset_name_or_path, split="train", cache_dir=cfg.cache_dir) 34 | else: 35 | dataset = datasets.load_from_disk(cfg.dataset_name_or_path) 36 | dataset = cast(datasets.Dataset, dataset) 37 | dataset = dataset.shard(num_shards=n_shards, index=shard_idx, contiguous=True) 38 | dataset.set_transform(dataset_transform) 39 | return dataset 40 | 41 | 42 | def load_dataset( 43 | cfg: DatasetConfig, 44 | device_mesh: Optional[DeviceMesh] = None, 45 | n_shards: Optional[int] = None, 46 | start_shard: int = 0, 47 | ) -> tuple[datasets.Dataset, Optional[dict[str, Any]]]: 48 | if not cfg.is_dataset_on_disk: 49 | dataset = datasets.load_dataset( 50 | cfg.dataset_name_or_path, split="train", cache_dir=cfg.cache_dir, trust_remote_code=True 51 | ) 52 | else: 53 | dataset = datasets.load_from_disk(cfg.dataset_name_or_path) 54 | dataset = cast(datasets.Dataset, dataset) 55 | if device_mesh is not None: 56 | shard = dataset.shard( 57 | num_shards=n_shards or device_mesh.get_group("data").size(), 58 | index=start_shard + device_mesh.get_group("data").rank(), 59 | contiguous=True, 60 | ) 61 | shard_metadata = { 62 | "shard_idx": start_shard + 
device_mesh.get_group("data").rank(), 63 | "n_shards": n_shards or device_mesh.get_group("data").size(), 64 | } 65 | else: 66 | shard = dataset 67 | shard_metadata = None 68 | shard.set_transform(dataset_transform) 69 | return shard, shard_metadata 70 | 71 | 72 | def infer_model_backend(model_name: str) -> Literal["huggingface", "transformer_lens"]: 73 | if model_name.startswith("Qwen/Qwen2.5-VL"): 74 | return "huggingface" 75 | elif model_name.startswith("Qwen/Qwen2.5"): 76 | return "huggingface" 77 | elif model_name.startswith("GSAI-ML/LLaDA"): 78 | return "huggingface" 79 | else: 80 | return "transformer_lens" 81 | 82 | 83 | def load_model(cfg: LanguageModelConfig) -> LanguageModel: 84 | backend = infer_model_backend(cfg.model_name) if cfg.backend == "auto" else cfg.backend 85 | if backend == "huggingface": 86 | if cfg.model_name.startswith("Qwen/Qwen2.5-VL"): 87 | return QwenVLLanguageModel(cfg) 88 | elif cfg.model_name.startswith("Qwen/Qwen2.5"): 89 | return QwenLanguageModel(cfg) 90 | else: 91 | raise NotImplementedError(f"Model {cfg.model_name} not supported in HuggingFace backend.") 92 | elif backend == "transformer_lens": 93 | if cfg.model_name.startswith("GSAI-ML/LLaDA"): 94 | assert isinstance(cfg, LLaDAConfig) 95 | return LLaDALanguageModel(cfg) 96 | else: 97 | return TransformerLensLanguageModel(cfg) 98 | else: 99 | raise NotImplementedError(f"Backend {backend} not supported.") 100 | -------------------------------------------------------------------------------- /src/lm_saes/runners/__init__.py: -------------------------------------------------------------------------------- 1 | """Runner module for executing various operations on language models and SAEs.""" 2 | 3 | from .analyze import ( 4 | AnalyzeCrossCoderSettings, 5 | AnalyzeSAESettings, 6 | DirectLogitAttributeSettings, 7 | analyze_crosscoder, 8 | analyze_sae, 9 | direct_logit_attribute, 10 | ) 11 | from .autointerp import AutoInterpSettings, auto_interp 12 | from .eval import ( 13 | EvaluateCrossCoderSettings, 14 | EvaluateSAESettings, 15 | evaluate_crosscoder, 16 | evaluate_sae, 17 | ) 18 | from .generate import ( 19 | CheckActivationConsistencySettings, 20 | GenerateActivationsSettings, 21 | check_activation_consistency, 22 | generate_activations, 23 | ) 24 | from .train import ( 25 | SweepingItem, 26 | SweepSAESettings, 27 | TrainCLTSettings, 28 | TrainCrossCoderSettings, 29 | TrainSAESettings, 30 | sweep_sae, 31 | train_clt, 32 | train_crosscoder, 33 | train_sae, 34 | ) 35 | from .utils import load_config 36 | 37 | __all__ = [ 38 | "DirectLogitAttributeSettings", 39 | "direct_logit_attribute", 40 | "GenerateActivationsSettings", 41 | "generate_activations", 42 | "CheckActivationConsistencySettings", 43 | "check_activation_consistency", 44 | "TrainSAESettings", 45 | "train_sae", 46 | "TrainCrossCoderSettings", 47 | "train_crosscoder", 48 | "TrainCLTSettings", 49 | "train_clt", 50 | "SweepSAESettings", 51 | "SweepingItem", 52 | "sweep_sae", 53 | "AnalyzeSAESettings", 54 | "analyze_sae", 55 | "AnalyzeCrossCoderSettings", 56 | "analyze_crosscoder", 57 | "AutoInterpSettings", 58 | "auto_interp", 59 | "load_config", 60 | "EvaluateCrossCoderSettings", 61 | "evaluate_crosscoder", 62 | "EvaluateSAESettings", 63 | "evaluate_sae", 64 | ] 65 | -------------------------------------------------------------------------------- /src/lm_saes/runners/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for runners.""" 2 | 3 | from typing import Literal, Optional, 
TypeVar, overload
4 | 
5 | from lm_saes.database import MongoClient
6 | from lm_saes.utils.logging import get_logger
7 | 
8 | logger = get_logger("runners.utils")
9 | 
10 | T = TypeVar("T")
11 | 
12 | 
13 | @overload
14 | def load_config(
15 | config: Optional[T],
16 | name: Optional[str],
17 | mongo_client: Optional[MongoClient],
18 | config_type: str,
19 | required: Literal[True] = True,
20 | ) -> T: ...
21 | 
22 | 
23 | @overload
24 | def load_config(
25 | config: Optional[T],
26 | name: Optional[str],
27 | mongo_client: Optional[MongoClient],
28 | config_type: str,
29 | required: Literal[False] = False,
30 | ) -> Optional[T]: ...
31 | 
32 | 
33 | def load_config(
34 | config: Optional[T],
35 | name: Optional[str],
36 | mongo_client: Optional[MongoClient],
37 | config_type: str,
38 | required: bool = True,
39 | ) -> Optional[T]:
40 | """Load configuration from settings or database.
41 | 
42 | Args:
43 | config: Configuration provided directly in settings
44 | name: Name of the config to load from database
45 | mongo_client: Optional MongoDB client for database operations
46 | config_type: String identifier for error messages ('model' or 'dataset')
47 | required: Whether the config must be present
48 | 
49 | Returns:
50 | Loaded configuration or None if not required and not found
51 | 
52 | Raises:
53 | AssertionError: If config is required but not found
54 | """
55 | if mongo_client is not None and name is not None:
56 | if config is None:
57 | config = getattr(mongo_client, f"get_{config_type}_cfg")(name)
58 | logger.info(f"Loaded {config_type} config from database: {name}")
59 | else:
60 | getattr(mongo_client, f"add_{config_type}")(name, config)
61 | logger.info(f"Added {config_type} config to database: {name}")
62 | 
63 | if required:
64 | assert config is not None, f"{config_type} config not provided and not found in database"
65 | return config
66 | -------------------------------------------------------------------------------- /src/lm_saes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/src/lm_saes/utils/__init__.py -------------------------------------------------------------------------------- /src/lm_saes/utils/bytes.py: -------------------------------------------------------------------------------- 1 | import io
2 | from functools import lru_cache
3 | 
4 | import numpy as np
5 | 
6 | 
7 | def np_to_bytes(arr):
8 | with io.BytesIO() as buffer:
9 | np.save(buffer, arr)
10 | return buffer.getvalue()
11 | 
12 | 
13 | def bytes_to_np(b):
14 | with io.BytesIO(b) as buffer:
15 | return np.load(buffer)
16 | 
17 | 
18 | @lru_cache()
19 | def bytes_to_unicode():
20 | """
21 | Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
22 | characters the bpe code barfs on.
23 | 
24 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
25 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
26 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
27 | tables between utf-8 bytes and unicode strings.
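As a concrete example (read directly off the construction below): byte 0x20 (space) is one of the
excluded whitespace bytes, so it maps to chr(256 + 32) = 'Ġ' (U+0120); this is why GPT-2-style
tokenizers render a leading space as 'Ġ'.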
28 | """ 29 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 30 | cs = bs[:] 31 | n = 0 32 | for b in range(2**8): 33 | if b not in bs: 34 | bs.append(b) 35 | cs.append(2**8 + n) 36 | n += 1 37 | cs = [chr(n) for n in cs] 38 | return dict(zip(bs, cs)) 39 | -------------------------------------------------------------------------------- /src/lm_saes/utils/concurrent.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import Future, ThreadPoolExecutor 2 | from queue import Queue 3 | from typing import Iterable, Optional, TypeVar 4 | 5 | from tqdm import tqdm 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | class BackgroundGenerator(Iterable[T]): 11 | """Compute elements of a generator in a background thread pool. 12 | 13 | This class is optimized for scenarios where either the generator or the main thread 14 | performs GIL-releasing work (e.g., File I/O, network requests). Best used when the 15 | generator has no side effects on the program state. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | generator: Iterable[T], 21 | max_prefetch: int = 1, 22 | executor: Optional[ThreadPoolExecutor] = None, 23 | name: Optional[str] = None, 24 | ) -> None: 25 | """Initialize the background generator. 26 | 27 | Args: 28 | generator: The generator to compute elements from in a background thread. 29 | max_prefetch: Maximum number of elements to precompute ahead. If <= 0, 30 | the queue size is infinite. When max_prefetch elements have been 31 | computed, the background thread will wait for consumption. 32 | executor: Optional ThreadPoolExecutor to use. If None, creates a single-thread executor. 33 | """ 34 | self.queue = Queue(max_prefetch) 35 | self.generator = generator 36 | self._executor = executor or ThreadPoolExecutor(max_workers=1) 37 | self._owned_executor = executor is None 38 | self._future: Optional[Future] = None 39 | self.continue_iteration = True 40 | self._started = False 41 | self.pbar = tqdm(total=max_prefetch, desc=f"Background Processing {name}", smoothing=0.001, miniters=1) 42 | self._start() 43 | 44 | def _process_generator(self) -> None: 45 | """Process the generator items in the background thread.""" 46 | try: 47 | for item in self.generator: 48 | if not self.continue_iteration: 49 | break 50 | self.queue.put((True, item)) 51 | self.pbar.update(1) 52 | except Exception as e: 53 | self.queue.put((False, e)) 54 | self.pbar.update(1) 55 | finally: 56 | self.queue.put((False, StopIteration)) 57 | self.pbar.update(1) 58 | 59 | def _start(self) -> None: 60 | """Start the background processing.""" 61 | self._future = self._executor.submit(self._process_generator) 62 | 63 | def __next__(self) -> T: 64 | """Get the next item from the generator. 65 | 66 | Returns: 67 | The next item from the generator. 68 | 69 | Raises: 70 | StopIteration: When the generator is exhausted. 71 | Exception: Any exception raised by the generator. 72 | """ 73 | if self.continue_iteration: 74 | success, next_item = self.queue.get() 75 | self.pbar.update(-1) 76 | if success: 77 | return next_item 78 | else: 79 | self.continue_iteration = False 80 | raise next_item 81 | else: 82 | raise StopIteration 83 | 84 | def __iter__(self) -> "BackgroundGenerator[T]": 85 | """Return self as iterator. 86 | 87 | Returns: 88 | Self as iterator. 
89 | """ 90 | return self 91 | 92 | def close(self) -> None: 93 | """Close the generator and clean up resources.""" 94 | self.continue_iteration = False 95 | if self._future is not None: 96 | self._future.cancel() 97 | # Clear the queue to unblock any waiting threads 98 | while not self.queue.empty(): 99 | try: 100 | self.queue.get_nowait() 101 | self.pbar.update(-1) 102 | except Exception: 103 | pass 104 | if self._owned_executor: 105 | self._executor.shutdown(wait=False, cancel_futures=True) 106 | self.pbar.close() 107 | 108 | def __del__(self) -> None: 109 | """Clean up resources when the object is deleted.""" 110 | self.close() 111 | -------------------------------------------------------------------------------- /src/lm_saes/utils/discrete.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, cast 2 | 3 | import torch 4 | 5 | 6 | class DiscreteMapper: 7 | def __init__(self) -> None: 8 | """Initialize a new PythonDiscreteMapper with empty mappings.""" 9 | self.value_to_int: dict[str, int] = {} 10 | self.int_to_value: list[str] = [] 11 | self.counter: int = 0 12 | 13 | def _dist_encode(self, values: list[str], group: torch.distributed.ProcessGroup) -> list[int]: 14 | local_rank = torch.distributed.get_rank(group) 15 | not_seen_values = [value for value in values if value not in self.value_to_int] 16 | not_seen_values_list = [None] * group.size() 17 | torch.distributed.all_gather_object(not_seen_values_list, not_seen_values, group=group) 18 | not_seen_values_list = [ 19 | sublist for sublist in not_seen_values_list if (sublist is not None and len(sublist) > 0) 20 | ] 21 | flattened_not_seen_values = [item for sublist in not_seen_values_list for item in cast(list[str], sublist)] 22 | 23 | # get unique values 24 | unique_values = list(set(flattened_not_seen_values)) 25 | if len(unique_values) == 0: 26 | return [self.value_to_int[value] for value in values] 27 | update_hash = {} 28 | counter = self.counter 29 | broadcast_list: list[Any] = [None] * (len(unique_values) + 1) if local_rank != 0 else [] 30 | if local_rank == 0: 31 | for value in unique_values: 32 | assert value not in self.value_to_int 33 | update_hash[value] = counter 34 | broadcast_list.append(value) 35 | counter += 1 36 | broadcast_list.append(update_hash) 37 | torch.distributed.broadcast_object_list(broadcast_list, src=0, group=group) 38 | self.value_to_int |= broadcast_list[-1] 39 | self.int_to_value.extend(broadcast_list[:-1]) 40 | self.counter += len(broadcast_list[:-1]) 41 | 42 | # check if all hash_table is the same 43 | check_list = [None] * group.size() 44 | torch.distributed.all_gather_object(check_list, self.value_to_int, group=group) 45 | assert all(check_list[i] == check_list[0] for i in range(1, group.size())), ( 46 | "value_to_int is not consistent across processes" 47 | ) # TODO: Remove this check for speed up 48 | return [self.value_to_int[value] for value in values] 49 | 50 | def encode(self, values: list[str], group: Optional[torch.distributed.ProcessGroup] = None) -> list[int]: 51 | """Encode a list of strings to their corresponding integer indices. 
52 | 53 | Args: 54 | values: List of strings to encode 55 | 56 | Returns: 57 | List of integer indices 58 | """ 59 | if group is not None: 60 | return self._dist_encode(values, group) 61 | result = [] 62 | for value in values: 63 | if value not in self.value_to_int: 64 | self.value_to_int[value] = self.counter 65 | self.int_to_value.append(value) 66 | self.counter += 1 67 | 68 | result.append(self.value_to_int[value]) 69 | return result 70 | 71 | def decode(self, integers: list[int]) -> list[str]: 72 | """Decode a list of integers back to their corresponding strings. 73 | 74 | Args: 75 | integers: List of integer indices to decode 76 | 77 | Returns: 78 | List of decoded strings 79 | 80 | Raises: 81 | IndexError: If any integer is out of range 82 | """ 83 | return [self.int_to_value[i] for i in integers] 84 | 85 | def get_mapping(self) -> dict[str, int]: 86 | """Get the current mapping from strings to integers. 87 | 88 | Returns: 89 | Dictionary mapping strings to their integer indices 90 | """ 91 | return self.value_to_int.copy() 92 | 93 | 94 | class KeyedDiscreteMapper: 95 | def __init__(self) -> None: 96 | """Initialize a new PythonKeyedDiscreteMapper with empty mappers.""" 97 | self.mappers: dict[str, DiscreteMapper] = {} 98 | 99 | def encode(self, key: str, values: list[str], group: Optional[torch.distributed.ProcessGroup] = None) -> list[int]: 100 | """Encode a list of strings using the mapper associated with the given key. 101 | 102 | Args: 103 | key: The key identifying which mapper to use 104 | values: List of strings to encode 105 | 106 | Returns: 107 | List of integer indices 108 | """ 109 | if key not in self.mappers: 110 | self.mappers[key] = DiscreteMapper() 111 | return self.mappers[key].encode(values, group=group) 112 | 113 | def decode(self, key: str, integers: list[int]) -> list[str]: 114 | """Decode a list of integers using the mapper associated with the given key. 115 | 116 | Args: 117 | key: The key identifying which mapper to use 118 | integers: List of integer indices to decode 119 | 120 | Returns: 121 | List of decoded strings 122 | 123 | Raises: 124 | KeyError: If the key doesn't exist 125 | IndexError: If any integer is out of range 126 | """ 127 | if key not in self.mappers: 128 | raise KeyError("Key not found") 129 | return self.mappers[key].decode(integers) 130 | 131 | def keys(self) -> list[str]: 132 | """Get all keys currently in use. 133 | 134 | Returns: 135 | List of keys 136 | """ 137 | return list(self.mappers.keys()) 138 | -------------------------------------------------------------------------------- /src/lm_saes/utils/hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformer_lens.hook_points import HookPoint 3 | 4 | 5 | def compose_hooks(*hooks): 6 | """ 7 | Compose multiple hooks into a single hook by executing them in order. 8 | """ 9 | 10 | def composed_hook(tensor: torch.Tensor, hook: HookPoint): 11 | for hook_fn in hooks: 12 | tensor = hook_fn(tensor, hook) 13 | return tensor 14 | 15 | return composed_hook 16 | 17 | 18 | def retain_grad_hook(tensor: torch.Tensor, hook: HookPoint): 19 | """ 20 | Retain the gradient of the tensor at the given hook point. 21 | """ 22 | tensor.retain_grad() 23 | return tensor 24 | 25 | 26 | def detach_hook(tensor: torch.Tensor, hook: HookPoint): 27 | """ 28 | Detach the tensor at the given hook point. 
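The returned tensor re-enables gradients on a fresh graph, so a backward pass can treat this
hook point as a leaf and measure gradients with respect to it in isolation.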
29 | """ 30 | return tensor.detach().requires_grad_(True) 31 | -------------------------------------------------------------------------------- /src/lm_saes/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | # Some helper function to interact with huggingface API 2 | 3 | import os 4 | import re 5 | import shutil 6 | 7 | from huggingface_hub import create_repo, snapshot_download, upload_folder 8 | 9 | from .misc import print_once 10 | 11 | 12 | def upload_pretrained_sae_to_hf(sae_path: str, repo_id: str, private: bool = False): 13 | """Upload a pretrained SAE model to huggingface model hub 14 | 15 | Args: 16 | sae_path (str): path to the local SAE model 17 | """ 18 | 19 | from lm_saes.config import LanguageModelConfig 20 | from lm_saes.sae import SparseAutoEncoder 21 | 22 | # Load the model 23 | sae = SparseAutoEncoder.from_pretrained(sae_path) 24 | lm_config = LanguageModelConfig.from_pretrained_sae(sae_path) 25 | 26 | # Create local temporary directory for uploading 27 | folder_name = ( 28 | sae.cfg.hook_point_in 29 | if sae.cfg.hook_point_in == sae.cfg.hook_point_out 30 | else f"{sae.cfg.hook_point_in}-{sae.cfg.hook_point_out}" 31 | ) 32 | os.makedirs(f"{sae_path}/{folder_name}", exist_ok=True) 33 | 34 | try: 35 | # Save the model 36 | create_repo(repo_id=repo_id, private=private, exist_ok=True) 37 | 38 | sae.save_pretrained(f"{sae_path}/{folder_name}") 39 | sae.cfg.save_hyperparameters(f"{sae_path}/{folder_name}") 40 | lm_config.save_lm_config(f"{sae_path}/{folder_name}") 41 | 42 | # Upload the model 43 | upload_folder( 44 | folder_path=f"{sae_path}/{folder_name}", 45 | repo_id=repo_id, 46 | path_in_repo=folder_name, 47 | commit_message=f"Upload pretrained SAE model. Hook point: {folder_name}. Language Model Name: {lm_config.model_name}", 48 | ) 49 | 50 | finally: 51 | # Remove the temporary directory 52 | shutil.rmtree(f"{sae_path}/{folder_name}") 53 | 54 | 55 | def download_pretrained_sae_from_hf(repo_id: str, hook_point: str): 56 | """Download a pretrained SAE model from huggingface model hub 57 | 58 | Args: 59 | repo_id (str): id of the repo 60 | hook_point (str): hook point 61 | """ 62 | 63 | snapshot_path = snapshot_download(repo_id=repo_id, allow_patterns=[f"{hook_point}/*"]) 64 | return os.path.join(snapshot_path, hook_point) 65 | 66 | 67 | def _parse_repo_id(pretrained_name_or_path): 68 | pattern = r"L(\d{1,2})([RAMTC])-(8|32)x" 69 | 70 | def replace_match(match): 71 | sublayer = match.group(2) 72 | exp_factor = match.group(3) 73 | 74 | return f"LX{sublayer}-{exp_factor}x" 75 | 76 | output_string = re.sub(pattern, replace_match, pretrained_name_or_path) 77 | 78 | return output_string 79 | 80 | 81 | def parse_pretrained_name_or_path(pretrained_name_or_path: str): 82 | if os.path.exists(pretrained_name_or_path): 83 | return pretrained_name_or_path 84 | else: 85 | print_once(f"Local path `{pretrained_name_or_path}` not found. 
Downloading from huggingface model hub.") 86 | if pretrained_name_or_path.startswith("fnlp"): 87 | print_once("Downloading Llama Scope SAEs.") 88 | repo_id = _parse_repo_id(pretrained_name_or_path) 89 | hook_point = pretrained_name_or_path.split("/")[1] 90 | else: 91 | repo_id = "/".join(pretrained_name_or_path.split("/")[:2]) 92 | hook_point = "/".join(pretrained_name_or_path.split("/")[2:]) 93 | return download_pretrained_sae_from_hf(repo_id, hook_point) 94 | -------------------------------------------------------------------------------- /src/lm_saes/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_geometric_median(x: torch.Tensor, max_iter=1000) -> torch.Tensor: 5 | """ 6 | Compute the geometric median of a point cloud x. 7 | The geometric median is the point that minimizes the sum of distances to the other points. 8 | This function uses Weiszfeld's algorithm to compute the geometric median. 9 | 10 | Args: 11 | x: Input point cloud. Shape (n_points, n_dims) 12 | max_iter: Maximum number of iterations 13 | 14 | Returns: 15 | The geometric median of the point cloud. Shape (n_dims,) 16 | """ 17 | 18 | # Initialize the geometric median as the mean of the points 19 | y = x.mean(dim=0) 20 | 21 | for _ in range(max_iter): 22 | # Compute the weights 23 | w = 1 / (x - y.unsqueeze(0)).norm(dim=-1) 24 | 25 | # Update the geometric median 26 | y = (w.unsqueeze(-1) * x).sum(dim=0) / w.sum() 27 | 28 | return y 29 | 30 | 31 | def norm_ratio(a, b): 32 | a_norm = torch.norm(a, 2, dim=0).mean() 33 | b_norm = torch.norm(b, 2, dim=0).mean() 34 | return a_norm / b_norm 35 | -------------------------------------------------------------------------------- /src/lm_saes/utils/tensor_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Any, TypeVar, overload 2 | 3 | import torch 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def sort_dict_of_tensor( 9 | tensor_dict: dict[str, torch.Tensor], 10 | sort_key: str, 11 | sort_dim: int = 0, 12 | descending: bool = True, 13 | ): 14 | """ 15 | Sort a dictionary of tensors by the values of a tensor. 16 | 17 | Args: 18 | tensor_dict: Dictionary of tensors 19 | sort_key: Key of the tensor to sort by 20 | sort_dim: Dimension to sort along 21 | descending: Whether to sort in descending order 22 | 23 | Returns: 24 | A dictionary of tensors sorted by the values of the specified tensor 25 | """ 26 | sorted_idx = tensor_dict[sort_key].argsort(dim=sort_dim, descending=descending) 27 | return { 28 | k: v.gather(sort_dim, sorted_idx.unsqueeze(-1).expand_as(v.reshape(*sorted_idx.shape, -1)).reshape_as(v)) 29 | for k, v in tensor_dict.items() 30 | } 31 | 32 | 33 | def concat_dict_of_tensor(*dicts: dict[str, torch.Tensor], dim: int = 0) -> dict[str, torch.Tensor]: 34 | """ 35 | Concatenate a dictionary of tensors along a specified dimension. 36 | 37 | Args: 38 | *dicts: Dictionaries of tensors 39 | dim: Dimension to concatenate along 40 | 41 | Returns: 42 | A dictionary of tensors concatenated along the specified dimension 43 | """ 44 | return {k: torch.cat([d[k] for d in dicts], dim=dim) for k in dicts[0].keys()} 45 | 46 | 47 | @overload 48 | def move_dict_of_tensor_to_device( 49 | tensor_dict: dict[str, torch.Tensor], device: torch.device | str 50 | ) -> dict[str, torch.Tensor]: ... 51 | 52 | 53 | @overload 54 | def move_dict_of_tensor_to_device(tensor_dict: dict[str, Any], device: torch.device | str) -> dict[str, Any]: ... 
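# The overloads above let type checkers preserve the precise value type: a pure dict[str, Tensor]
# stays dict[str, Tensor], while mixed dicts fall back to dict[str, Any]. For instance,
# move_dict_of_tensor_to_device({"x": torch.ones(2), "meta": "keep"}, "cpu") moves "x" and
# leaves "meta" untouched.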
55 | 56 | 57 | def move_dict_of_tensor_to_device(tensor_dict: dict[str, Any], device: torch.device | str) -> dict[str, Any]: 58 | """ 59 | Move tensors in a dictionary to specified device, leaving non-tensor values unchanged. 60 | 61 | Args: 62 | tensor_dict: Dictionary containing tensors and possibly other types 63 | device: Target device to move tensors to 64 | 65 | Returns: 66 | Dictionary with tensors moved to specified device 67 | """ 68 | return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in tensor_dict.items()} 69 | 70 | 71 | def batch_size(tensor_dict: dict[str, torch.Tensor]) -> int: 72 | """ 73 | Get the batch size of a dictionary of tensors. 74 | """ 75 | return len(tensor_dict[list(tensor_dict.keys())[0]]) 76 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_train_sae.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from pytest_mock import MockerFixture 6 | from torch.distributed.device_mesh import init_device_mesh 7 | 8 | from lm_saes.config import InitializerConfig, MixCoderConfig, SAEConfig, TrainerConfig 9 | from lm_saes.initializer import Initializer 10 | from lm_saes.trainer import Trainer 11 | 12 | pytest.skip("This test needs fixing", allow_module_level=True) 13 | 14 | 15 | @pytest.fixture 16 | def sae_config() -> SAEConfig: 17 | return SAEConfig( 18 | hook_point_in="in", 19 | hook_point_out="out", 20 | d_model=2, 21 | expansion_factor=2, 22 | device="cpu", 23 | dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests 24 | act_fn="topk", 25 | norm_activation="dataset-wise", 26 | sparsity_include_decoder_norm=True, 27 | top_k=2, 28 | ) 29 | 30 | 31 | @pytest.fixture 32 | def mixcoder_config() -> MixCoderConfig: 33 | return MixCoderConfig( 34 | hook_point_in="in", 35 | hook_point_out="out", 36 | d_model=2, 37 | expansion_factor=2, 38 | device="cpu", 39 | dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests 40 | act_fn="topk", 41 | norm_activation="dataset-wise", 42 | top_k=2, 43 | modalities={"image": 2, "text": 2, "shared": 2}, 44 | ) 45 | 46 | 47 | @pytest.fixture 48 | def initializer_config() -> InitializerConfig: 49 | return InitializerConfig( 50 | state="training", 51 | init_search=True, 52 | l1_coefficient=0.00008, 53 | ) 54 | 55 | 56 | @pytest.fixture 57 | def trainer_config(tmp_path) -> TrainerConfig: 58 | # Remove tmp path 59 | os.rmdir(tmp_path) 60 | return TrainerConfig( 61 | initial_k=3, 62 | total_training_tokens=400, 63 | log_frequency=10, 64 | eval_frequency=10, 65 | n_checkpoints=0, 66 | exp_result_path=str(tmp_path), 67 | ) 68 | 69 | 70 | def test_train_sae( 71 | sae_config: SAEConfig, 72 | initializer_config: InitializerConfig, 73 | trainer_config: TrainerConfig, 74 | mocker: MockerFixture, 75 | tmp_path, 76 | ) -> None: 77 | wandb_runner = mocker.Mock() 78 | wandb_runner.log = lambda *args, **kwargs: None 79 | device_mesh = ( 80 | init_device_mesh( 81 | device_type="cuda", 82 | mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1), 83 | mesh_dim_names=("data", "model"), 84 | ) 85 | if os.environ.get("WORLD_SIZE") is not None 
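# torchrun sets WORLD_SIZE; when it is absent we fall back to single-process mode with no device mesh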
86 | else None 87 | ) 88 | activation_stream = [ 89 | { 90 | "in": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), 91 | "out": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), 92 | "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=sae_config.device), 93 | } 94 | for _ in range(200) 95 | ] 96 | initializer = Initializer(initializer_config) 97 | sae = initializer.initialize_sae_from_config( 98 | sae_config, 99 | device_mesh=device_mesh, 100 | activation_stream=activation_stream, 101 | ) 102 | trainer = Trainer(trainer_config) 103 | trainer.fit( 104 | sae=sae, 105 | activation_stream=activation_stream, 106 | eval_fn=lambda x: None, 107 | wandb_logger=wandb_runner, 108 | ) 109 | 110 | 111 | def test_train_mixcoder( 112 | mixcoder_config: MixCoderConfig, 113 | initializer_config: InitializerConfig, 114 | trainer_config: TrainerConfig, 115 | mocker: MockerFixture, 116 | tmp_path, 117 | ) -> None: 118 | wandb_runner = mocker.Mock() 119 | wandb_runner.log = lambda *args, **kwargs: None 120 | device_mesh = ( 121 | init_device_mesh( 122 | device_type="cuda", 123 | mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1), 124 | mesh_dim_names=("data", "model"), 125 | ) 126 | if os.environ.get("WORLD_SIZE") is not None 127 | else None 128 | ) 129 | activation_stream = [ 130 | { 131 | "in": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), 132 | "out": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), 133 | "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=mixcoder_config.device), 134 | } 135 | for _ in range(200) 136 | ] 137 | initializer = Initializer(initializer_config) 138 | tokenizer = mocker.Mock() 139 | tokenizer.get_vocab.return_value = { 140 | "IMGIMG1": 1, 141 | "IMGIMG2": 2, 142 | "IMGIMG3": 3, 143 | "IMGIMG4": 4, 144 | "TEXT1": 5, 145 | "TEXT2": 6, 146 | "TEXT3": 7, 147 | "TEXT4": 8, 148 | } 149 | model_name = "facebook/chameleon-7b" 150 | 151 | mixcoder_settings = {"tokenizer": tokenizer, "model_name": model_name} 152 | mixcoder = initializer.initialize_sae_from_config( 153 | mixcoder_config, 154 | device_mesh=device_mesh, 155 | activation_stream=activation_stream, 156 | mixcoder_settings=mixcoder_settings, 157 | ) 158 | trainer = Trainer(trainer_config) 159 | trainer.fit( 160 | sae=mixcoder, 161 | activation_stream=activation_stream, 162 | eval_fn=lambda x: None, 163 | wandb_logger=wandb_runner, 164 | ) 165 | -------------------------------------------------------------------------------- /tests/unit/test_activation_processors_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch.distributed.device_mesh import init_device_mesh 6 | from torch.distributed.tensor import DTensor 7 | 8 | from lm_saes.activation.processors.activation import ActivationBuffer 9 | from lm_saes.utils.distributed import DimMap 10 | 11 | 12 | def setup_distributed(): 13 | """Initialize distributed training with torchrun.""" 14 | # torchrun sets these environment variables automatically 15 | rank = int(os.environ.get("RANK", 0)) 16 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 17 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 18 | 19 | # Initialize the process group 20 | dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo") 21 | 22 | # Set device 23 | if torch.cuda.is_available(): 24 | torch.cuda.set_device(local_rank) 25 | device = 
f"cuda:{local_rank}" 26 | else: 27 | device = "cpu" 28 | 29 | return device, rank, world_size 30 | 31 | 32 | def cleanup_distributed(): 33 | """Clean up distributed training.""" 34 | dist.destroy_process_group() 35 | 36 | 37 | def test_activation_buffer_distributed(): 38 | device, rank, world_size = setup_distributed() 39 | print(f"Running test on device {device}, rank {rank}, world size {world_size}") 40 | assert world_size == 2, "This test requires 2 processes" 41 | 42 | # Initialize device mesh 43 | device_mesh = init_device_mesh( 44 | device_type="cuda" if torch.cuda.is_available() else "cpu", mesh_shape=[world_size], mesh_dim_names=["data"] 45 | ) 46 | 47 | # Initialize activation buffer 48 | buffer = ActivationBuffer(device_mesh=device_mesh) 49 | 50 | a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]], device=device) 51 | a = DimMap({"data": 0}).distribute(a, device_mesh) 52 | buffer = buffer.cat({"a": a}) 53 | batch, buffer = buffer.yield_batch(2) 54 | 55 | if rank == 0: 56 | assert isinstance(batch["a"], DTensor) and torch.allclose( 57 | batch["a"].to_local(), torch.tensor([[1, 2, 3]], device=device) 58 | ) 59 | else: 60 | assert isinstance(batch["a"], DTensor) and torch.allclose( 61 | batch["a"].to_local(), torch.tensor([[10, 11, 12]], device=device) 62 | ) 63 | 64 | a = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device) 65 | a = DimMap({"data": 0}).distribute(a, device_mesh) 66 | buffer = buffer.cat({"a": a}) 67 | 68 | remaining_buffer = buffer.consume() 69 | if rank == 0: 70 | assert isinstance(remaining_buffer["a"], DTensor) and torch.allclose( 71 | remaining_buffer["a"].to_local(), torch.tensor([[4, 5, 6], [7, 8, 9], [1, 2, 3]], device=device) 72 | ) 73 | else: 74 | assert isinstance(remaining_buffer["a"], DTensor) and torch.allclose( 75 | remaining_buffer["a"].to_local(), torch.tensor([[13, 14, 15], [16, 17, 18], [4, 5, 6]], device=device) 76 | ) 77 | 78 | cleanup_distributed() 79 | 80 | 81 | if __name__ == "__main__": 82 | test_activation_buffer_distributed() 83 | -------------------------------------------------------------------------------- /tests/unit/test_concurrent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from concurrent.futures import ThreadPoolExecutor 3 | 4 | import pytest 5 | 6 | from lm_saes.utils.concurrent import BackgroundGenerator 7 | 8 | 9 | def test_basic_iteration(): 10 | """Test basic iteration through a simple list.""" 11 | data = [1, 2, 3, 4, 5] 12 | bg = BackgroundGenerator(data) 13 | assert list(bg) == data 14 | 15 | 16 | def test_empty_generator(): 17 | """Test behavior with an empty generator.""" 18 | bg = BackgroundGenerator([]) 19 | with pytest.raises(StopIteration): 20 | next(bg) 21 | 22 | 23 | def test_max_prefetch(): 24 | """Test that max_prefetch limits the queue size.""" 25 | 26 | def slow_generator(): 27 | for i in range(3): 28 | time.sleep(0.1) # Simulate slow generation 29 | yield i 30 | 31 | bg = BackgroundGenerator(slow_generator(), max_prefetch=1) 32 | assert bg.queue.maxsize == 1 33 | 34 | # Consume all items 35 | assert list(bg) == [0, 1, 2] 36 | 37 | 38 | def test_exception_propagation(): 39 | """Test that exceptions from the generator are properly propagated.""" 40 | 41 | def failing_generator(): 42 | yield 1 43 | raise ValueError("Test error") 44 | yield 2 # This will never be reached 45 | 46 | bg = BackgroundGenerator(failing_generator()) 47 | 48 | # First item should be retrieved normally 49 | assert next(bg) == 1 50 | 51 | # 
Next call should raise the ValueError 52 | with pytest.raises(ValueError, match="Test error"): 53 | next(bg) 54 | 55 | 56 | def test_custom_executor(): 57 | """Test using a custom executor.""" 58 | data = [1, 2, 3] 59 | with ThreadPoolExecutor(max_workers=2) as executor: 60 | bg = BackgroundGenerator(data, executor=executor) 61 | assert list(bg) == data 62 | 63 | 64 | def test_generator_cleanup(mocker): 65 | """Test proper cleanup of resources.""" 66 | # Mock ThreadPoolExecutor 67 | mock_executor = mocker.Mock(spec=ThreadPoolExecutor) 68 | mock_executor.submit.return_value = mocker.Mock() 69 | 70 | data = [1, 2, 3] 71 | bg = BackgroundGenerator(data, executor=mock_executor) 72 | 73 | # Simulate deletion 74 | bg.__del__() 75 | 76 | # Verify that shutdown wasn't called since it's not an owned executor 77 | mock_executor.shutdown.assert_not_called() 78 | 79 | 80 | def test_large_dataset(): 81 | """Test handling of a larger dataset.""" 82 | large_data = range(1000) 83 | bg = BackgroundGenerator(large_data, max_prefetch=10) 84 | assert list(bg) == list(large_data) 85 | 86 | 87 | def test_multiple_iterations(): 88 | """Test that the generator can only be consumed once.""" 89 | data = [1, 2, 3] 90 | bg = BackgroundGenerator(data) 91 | 92 | # First iteration should work 93 | assert list(bg) == data 94 | 95 | # Second iteration should yield no items 96 | assert list(bg) == [] 97 | -------------------------------------------------------------------------------- /tests/unit/test_database.py: -------------------------------------------------------------------------------- 1 | import mongomock 2 | import pytest 3 | 4 | from lm_saes.config import MongoDBConfig, SAEConfig 5 | from lm_saes.database import MongoClient, SAERecord 6 | 7 | 8 | @pytest.fixture 9 | def mongo_client() -> MongoClient: 10 | """ 11 | Creates a MongoClient instance with an in-memory MongoDB. 
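mongomock patches the MongoDB client in-process, so these tests run without a real MongoDB server.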
12 | 13 | Returns: 14 | MongoClient: A configured MongoClient instance for testing 15 | """ 16 | with mongomock.patch(servers=(("fake", 27017),)): 17 | client = MongoClient(MongoDBConfig(mongo_uri="mongodb://fake", mongo_db="test_db")) 18 | yield client 19 | # Clear all collections after each test 20 | client.db.drop_collection("sae") 21 | client.db.drop_collection("analysis") 22 | client.db.drop_collection("feature") 23 | 24 | 25 | def test_create_and_get_sae(mongo_client: MongoClient) -> None: 26 | """Test creating and retrieving an SAE record.""" 27 | # Arrange 28 | name = "test_sae" 29 | series = "test_series" 30 | path = "test_path" 31 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 32 | 33 | # Act 34 | mongo_client.create_sae(name, series, path, cfg) 35 | result = mongo_client.get_sae(name, series) 36 | 37 | # Assert 38 | assert isinstance(result, SAERecord) 39 | assert result.name == name 40 | assert result.series == series 41 | assert result.path == path 42 | assert result.cfg.d_sae == cfg.d_sae 43 | 44 | 45 | def test_list_saes(mongo_client: MongoClient) -> None: 46 | """Test listing SAE records.""" 47 | # Arrange 48 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 49 | print(mongo_client.sae_collection.find_one()) 50 | mongo_client.create_sae("sae1", "series1", "test_path", cfg) 51 | mongo_client.create_sae("sae2", "series1", "test_path", cfg) 52 | mongo_client.create_sae("sae3", "series2", "test_path", cfg) 53 | 54 | # Act & Assert 55 | assert set(mongo_client.list_saes()) == {"sae1", "sae2", "sae3"} 56 | assert set(mongo_client.list_saes("series1")) == {"sae1", "sae2"} 57 | assert set(mongo_client.list_saes("series2")) == {"sae3"} 58 | 59 | 60 | def test_remove_sae(mongo_client: MongoClient) -> None: 61 | """Test removing an SAE record.""" 62 | # Arrange 63 | name = "test_sae" 64 | series = "test_series" 65 | path = "test_path" 66 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 67 | mongo_client.create_sae(name, series, path, cfg) 68 | 69 | # Act 70 | mongo_client.remove_sae(name, series) 71 | 72 | # Assert 73 | assert mongo_client.get_sae(name, series) is None 74 | assert mongo_client.list_saes(series) == [] 75 | 76 | 77 | def test_create_and_get_analysis(mongo_client: MongoClient) -> None: 78 | """Test creating and retrieving analysis records.""" 79 | # Arrange 80 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=10, d_model=10, expansion_factor=1) 81 | mongo_client.create_sae("test_sae", "test_series", "test_path", cfg) 82 | 83 | # Act 84 | mongo_client.create_analysis("test_analysis", "test_sae", "test_series") 85 | analyses = mongo_client.list_analyses("test_sae", "test_series") 86 | 87 | # Assert 88 | assert analyses == ["test_analysis"] 89 | 90 | 91 | def test_get_nonexistent_sae(mongo_client: MongoClient) -> None: 92 | """Test retrieving a non-existent SAE record.""" 93 | assert mongo_client.get_sae("nonexistent", "series") is None 94 | 95 | 96 | def test_get_feature(mongo_client: MongoClient) -> None: 97 | """Test retrieving feature records.""" 98 | # Arrange 99 | cfg = SAEConfig(hook_point_in="test_hook_point_in", d_sae=2, d_model=2, expansion_factor=1) 100 | mongo_client.create_sae("test_sae", "test_series", "test_path", cfg) 101 | 102 | # Act 103 | feature = mongo_client.get_feature("test_sae", "test_series", 0) 104 | 105 | # Assert 106 | assert feature is not None 107 | assert feature.sae_name == "test_sae" 108 | 
assert feature.sae_series == "test_series" 109 | assert feature.index == 0 110 | -------------------------------------------------------------------------------- /tests/unit/test_discrete_mapper.py: -------------------------------------------------------------------------------- 1 | from lm_saes.utils.discrete import DiscreteMapper, KeyedDiscreteMapper 2 | 3 | 4 | def test_discrete_mapper(): 5 | mapper = DiscreteMapper() 6 | assert mapper.encode(["a", "b", "a", "c"]) == [0, 1, 0, 2] 7 | assert mapper.decode([0, 1, 0, 2]) == ["a", "b", "a", "c"] 8 | 9 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2} 10 | 11 | assert mapper.encode(["a", "c", "d"]) == [0, 2, 3] 12 | assert mapper.decode([0, 2, 3]) == ["a", "c", "d"] 13 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2, "d": 3} 14 | 15 | assert mapper.encode(["a", "c", "d", "a"]) == [0, 2, 3, 0] 16 | assert mapper.decode([0, 2, 3, 0]) == ["a", "c", "d", "a"] 17 | assert mapper.get_mapping() == {"a": 0, "b": 1, "c": 2, "d": 3} 18 | 19 | 20 | def test_keyed_discrete_mapper(): 21 | mapper = KeyedDiscreteMapper() 22 | assert mapper.encode("foo", ["a", "b", "a", "c"]) == [0, 1, 0, 2] 23 | assert mapper.decode("foo", [0, 1, 0, 2]) == ["a", "b", "a", "c"] 24 | assert mapper.keys() == ["foo"] 25 | 26 | assert mapper.encode("bar", ["a", "c", "d"]) == [0, 1, 2] 27 | assert mapper.decode("bar", [0, 1, 2]) == ["a", "c", "d"] 28 | assert sorted(mapper.keys()) == ["bar", "foo"] 29 | 30 | 31 | if __name__ == "__main__": 32 | import timeit 33 | 34 | print(timeit.timeit("test_discrete_mapper()", globals=globals(), number=1000)) 35 | print(timeit.timeit("test_keyed_discrete_mapper()", globals=globals(), number=1000)) 36 | -------------------------------------------------------------------------------- /tests/unit/test_example.py: -------------------------------------------------------------------------------- 1 | def func(x): 2 | return x + 1 3 | 4 | 5 | def test_answer(): 6 | assert func(4) == 5 7 | -------------------------------------------------------------------------------- /tests/unit/test_hf_backend.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | import torch 5 | 6 | from lm_saes.backend.language_model import QwenLanguageModel, QwenVLLanguageModel 7 | from lm_saes.config import LanguageModelConfig 8 | 9 | pytest.skip("This test needs to switch to a smaller model", allow_module_level=True) 10 | # if not torch.cuda.is_available(): 11 | # pytest.skip("CUDA is not available", allow_module_level=True) 12 | 13 | 14 | class TestQwenVLLanguageModel: 15 | @pytest.fixture 16 | def language_model_config(self) -> LanguageModelConfig: 17 | return LanguageModelConfig( 18 | model_name="Qwen/Qwen2.5-VL-7B-Instruct", 19 | device="cuda", 20 | dtype=torch.bfloat16, 21 | use_flash_attn=True, 22 | d_model=3584, 23 | local_files_only=False, 24 | ) 25 | 26 | @pytest.fixture 27 | def multi_image_input(self) -> dict[str, Any]: 28 | return { 29 | "text": [f"{i}!" for i in range(10)], 30 | "images": [torch.randint(0, 256, (3, 3, 28 + 28 * i, 28)) for i in range(10)], 31 | } 32 | 33 | @pytest.fixture 34 | def single_image_input(self) -> dict[str, Any]: 35 | return { 36 | "text": ["test!" 
for _ in range(10)], 37 | "images": [torch.randint(0, 256, (1, 3, 28 + 28 * i, 28)) for i in range(10)], 38 | } 39 | 40 | def test_to_activations(self, language_model_config: LanguageModelConfig, multi_image_input: dict[str, Any]): 41 | hf_language_model = QwenVLLanguageModel(language_model_config) 42 | hf_language_model.model.eval() 43 | hook_points = [f"blocks.{i}.hook_resid_post" for i in range(28)] 44 | activations = hf_language_model.to_activations(multi_image_input, hook_points) 45 | assert len(activations) == 28 46 | 47 | model = hf_language_model.model 48 | processor = hf_language_model.processor 49 | for i in range(len(multi_image_input["text"])): 50 | multi_image_input["text"][i] = multi_image_input["text"][i].replace( 51 | "", "<|vision_start|><|image_pad|><|vision_end|>" 52 | ) 53 | inputs = processor( 54 | text=multi_image_input["text"], 55 | images=list(multi_image_input["images"]), 56 | return_tensors="pt", 57 | padding=True, 58 | ).to(language_model_config.device) 59 | outputs = model(**inputs, output_hidden_states=True) 60 | activations_from_model = {hook_points[i]: outputs.hidden_states[i + 1] for i in range(28)} 61 | for key in activations: 62 | assert torch.allclose(input=activations[key], other=activations_from_model[key], atol=1e-3) 63 | 64 | def test_trace(self, language_model_config: LanguageModelConfig, single_image_input: dict[str, Any]): 65 | hf_language_model = QwenVLLanguageModel(language_model_config) 66 | hf_language_model.model.eval() 67 | trace = hf_language_model.trace(single_image_input) 68 | assert trace[2][6] == (21, 14, 42, 28) 69 | 70 | 71 | class TestQwenLanguageModel: 72 | @pytest.fixture 73 | def language_model_config(self) -> LanguageModelConfig: 74 | return LanguageModelConfig( 75 | model_name="Qwen/Qwen2.5-7B", 76 | device="cuda", 77 | dtype=torch.bfloat16, 78 | d_model=3584, 79 | ) 80 | 81 | @pytest.fixture 82 | def raw_input(self) -> dict[str, Any]: 83 | return { 84 | "text": ["Hello, world!" 
* i for i in range(10)], 85 | } 86 | 87 | def test_to_activations(self, language_model_config: LanguageModelConfig, raw_input: dict[str, Any]): 88 | hook_points = [f"blocks.{i}.hook_resid_post" for i in range(28)] 89 | hf_language_model = QwenLanguageModel(language_model_config) 90 | hf_language_model.model.eval() 91 | activations = hf_language_model.to_activations(raw_input, hook_points) 92 | assert len(activations) == 28 93 | 94 | def test_trace(self, language_model_config: LanguageModelConfig, raw_input: dict[str, Any]): 95 | hf_language_model = QwenLanguageModel(language_model_config) 96 | hf_language_model.model.eval() 97 | trace = hf_language_model.trace(raw_input) 98 | print(trace) 99 | -------------------------------------------------------------------------------- /tests/unit/test_misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator 3 | 4 | import pytest 5 | import torch 6 | 7 | from lm_saes.utils.misc import calculate_activation_norm 8 | 9 | 10 | class TestCalculateActivationNorm: 11 | @pytest.fixture 12 | def mock_activation_stream(self) -> Iterator[dict[str, torch.Tensor]]: 13 | """Creates a mock activation stream with known values for testing.""" 14 | 15 | def stream_generator(): 16 | # Create 10 batches of activations 17 | for _ in range(10): 18 | yield { 19 | "layer1": torch.ones(4, 16), # norm will be sqrt(16) 20 | "layer2": torch.ones(4, 16) * 2, # norm will be sqrt(16) * 2 21 | } 22 | 23 | return stream_generator() 24 | 25 | def test_basic_functionality(self, mock_activation_stream): 26 | """Test basic functionality with default batch_num.""" 27 | result = calculate_activation_norm(mock_activation_stream, ["layer1", "layer2"]) 28 | 29 | assert isinstance(result, dict) 30 | assert "layer1" in result 31 | assert "layer2" in result 32 | 33 | # Expected norms: 34 | # layer1: sqrt(16) ≈ 4 35 | # layer2: sqrt(16) * 2 ≈ 8.0 36 | assert pytest.approx(result["layer1"], rel=1e-4) == 4.0 37 | assert pytest.approx(result["layer2"], rel=1e-4) == 8.0 38 | 39 | def test_custom_batch_num(self, mock_activation_stream): 40 | """Test with custom batch_num parameter.""" 41 | result = calculate_activation_norm(mock_activation_stream, batch_num=3, hook_points=["layer1", "layer2"]) 42 | 43 | # Should still give same results as we're averaging 44 | assert pytest.approx(result["layer1"], rel=1e-4) == 4.0 45 | assert pytest.approx(result["layer2"], rel=1e-4) == 8.0 46 | 47 | def test_empty_stream(self): 48 | """Test behavior with empty activation stream.""" 49 | empty_stream = iter([]) 50 | with pytest.warns(UserWarning): 51 | calculate_activation_norm(empty_stream, hook_points=[""]) 52 | 53 | def test_single_batch(self): 54 | """Test with a single batch of activations.""" 55 | 56 | def single_batch_stream(): 57 | yield {"single": torch.ones(2, 4)} # norm will be 2.0 58 | 59 | result = calculate_activation_norm(single_batch_stream(), hook_points=["single", "single"], batch_num=1) 60 | assert pytest.approx(result["single"], rel=1e-4) == 2.0 61 | 62 | def test_zero_tensors(self): 63 | """Test with zero tensors.""" 64 | 65 | def zero_stream(): 66 | for _ in range(10): 67 | yield {"zeros": torch.zeros(2, 4)} 68 | 69 | result = calculate_activation_norm(zero_stream(), hook_points=["zeros"]) 70 | assert result["zeros"] == 0.0 71 | 72 | def test_mixed_values(self): 73 | """Test with mixed positive/negative values.""" 74 | 75 | def mixed_stream(): 76 | for i in range(10): 77 | yield {"mixed": torch.tensor([[1.0, -2.0], 
[3.0, -4.0], [3.0, -2.0], [9.0, -4.0]]) * (i + 1)} 78 | 79 | result = calculate_activation_norm(mixed_stream(), hook_points=["mixed"], batch_num=10) 80 | assert ( 81 | pytest.approx(result["mixed"], rel=1e-4) == ((math.sqrt(5) + 5 + math.sqrt(13) + math.sqrt(97)) / 4) * 5.5 82 | ) 83 | -------------------------------------------------------------------------------- /tests/unit/test_mixcoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lm_saes.config import MixCoderConfig 5 | from lm_saes.mixcoder import MixCoder 6 | 7 | pytest.skip("This test needs fixing", allow_module_level=True) 8 | 9 | 10 | @pytest.fixture 11 | def config(): 12 | return MixCoderConfig( 13 | d_model=3, 14 | modalities={"text": 2, "image": 3, "shared": 4}, 15 | device="cpu", 16 | dtype=torch.float32, 17 | use_glu_encoder=True, 18 | use_decoder_bias=True, 19 | hook_point_in="hook_point_in", 20 | hook_point_out="hook_point_out", 21 | expansion_factor=1.0, 22 | top_k=2, 23 | act_fn="topk", 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def modality_indices(): 29 | return { 30 | "text": torch.tensor([1, 2, 3, 4]), 31 | "image": torch.tensor([5, 6, 7, 8]), 32 | } 33 | 34 | 35 | @pytest.fixture 36 | def mixcoder(config, modality_indices): 37 | model = MixCoder(config) 38 | model.init_parameters(modality_indices=modality_indices) 39 | model.decoder["text"].bias.data = torch.rand_like(model.decoder["text"].bias.data) 40 | model.decoder["image"].bias.data = torch.rand_like(model.decoder["image"].bias.data) 41 | model.decoder["shared"].bias.data = torch.rand_like(model.decoder["shared"].bias.data) 42 | model.encoder["text"].bias.data = torch.rand_like(model.encoder["text"].bias.data) 43 | model.encoder["image"].bias.data = torch.rand_like(model.encoder["image"].bias.data) 44 | model.encoder["shared"].bias.data = torch.rand_like(model.encoder["shared"].bias.data) 45 | return model 46 | 47 | 48 | def test_init_parameters(mixcoder, config): 49 | assert mixcoder.modality_index == {"text": (0, 2), "image": (2, 5), "shared": (5, 9)} 50 | assert torch.allclose(mixcoder.modality_indices["text"], torch.tensor([1, 2, 3, 4])) 51 | assert torch.allclose(mixcoder.modality_indices["image"], torch.tensor([5, 6, 7, 8])) 52 | 53 | 54 | def test_encode_decode(mixcoder, config): 55 | """Test the encoding and decoding process.""" 56 | mixcoder.set_dataset_average_activation_norm({"hook_point_in": 1.0, "hook_point_out": 1.0}) 57 | batch_size = 8 58 | x = torch.randn(batch_size, config.d_model) # batch, d_model 59 | tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) 60 | x_text = torch.cat([x[:4, :], torch.zeros(4, config.d_model)], dim=0) 61 | x_image = torch.cat([torch.zeros(4, config.d_model), x[4:, :]], dim=0) 62 | tokens_text = torch.tensor([1, 2, 3, 4, 0, 0, 0, 0]) 63 | tokens_image = torch.tensor([0, 0, 0, 0, 5, 6, 7, 8]) 64 | # Test encode 65 | feature_acts = mixcoder.encode(x, tokens=tokens) 66 | assert feature_acts.shape == (batch_size, config.d_sae) # batch, d_sae 67 | 68 | feature_acts_text = mixcoder.encode(x_text, tokens=tokens_text) 69 | assert feature_acts_text.shape == (batch_size, config.d_sae) 70 | feature_acts_image = mixcoder.encode(x_image, tokens=tokens_image) 71 | assert feature_acts_image.shape == (batch_size, config.d_sae) 72 | modality_index = mixcoder.get_modality_index() 73 | 74 | assert torch.allclose( 75 | feature_acts_text[:4, slice(*modality_index["text"])], 76 | feature_acts[:4, slice(*modality_index["text"])], 77 | ) 78 | assert 
torch.allclose( 79 | feature_acts_image[4:, slice(*modality_index["image"])], 80 | feature_acts[4:, slice(*modality_index["image"])], 81 | ) 82 | 83 | assert torch.allclose( 84 | torch.cat( 85 | [ 86 | feature_acts_text[:4, slice(*modality_index["shared"])], 87 | feature_acts_image[4:, slice(*modality_index["shared"])], 88 | ], 89 | dim=0, 90 | ), 91 | feature_acts[:, slice(*modality_index["shared"])], 92 | ) 93 | print(feature_acts) 94 | 95 | # Test decode 96 | reconstructed = mixcoder.decode(feature_acts, tokens=tokens) 97 | assert reconstructed.shape == (batch_size, config.d_model) 98 | 99 | reconstructed_text = mixcoder.decode(feature_acts_text, tokens=tokens_text) 100 | assert reconstructed_text.shape == (batch_size, config.d_model) 101 | 102 | reconstructed_image = mixcoder.decode(feature_acts_image, tokens=tokens_image) 103 | assert reconstructed_image.shape == (batch_size, config.d_model) 104 | 105 | assert torch.allclose(reconstructed_text[:4, :], reconstructed[:4, :]) 106 | assert torch.allclose(reconstructed_image[4:, :], reconstructed[4:, :]) 107 | 108 | 109 | def test_get_modality_activation_mask(mixcoder, config): 110 | """Test the _get_modality_activation method.""" 111 | tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) 112 | 113 | # Test text modality 114 | text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text") 115 | assert torch.all(text_activation_mask[:4] == 1) # First 4 positions should be 1 116 | assert torch.all(text_activation_mask[4:] == 0) # Last 4 positions should be 0 117 | 118 | # Test image modality 119 | image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image") 120 | assert torch.all(image_activation_mask[:4] == 0) # First 4 positions should be 0 121 | assert torch.all(image_activation_mask[4:] == 1) # Last 4 positions should be 1 122 | -------------------------------------------------------------------------------- /tests/unit/test_util_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch.distributed.device_mesh import init_device_mesh 6 | 7 | from lm_saes.utils.misc import get_mesh_dim_size, get_mesh_rank 8 | 9 | 10 | def setup_distributed(): 11 | """Initialize distributed training with torchrun.""" 12 | # torchrun sets these environment variables automatically 13 | rank = int(os.environ.get("RANK", 0)) 14 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 15 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 16 | 17 | # Initialize the process group 18 | dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo") 19 | 20 | # Set device 21 | if torch.cuda.is_available(): 22 | torch.cuda.set_device(local_rank) 23 | device = f"cuda:{local_rank}" 24 | else: 25 | device = "cpu" 26 | 27 | return device, rank, world_size 28 | 29 | 30 | def cleanup_distributed(): 31 | """Clean up distributed training.""" 32 | dist.destroy_process_group() 33 | 34 | 35 | def test_get_mesh_dim_size_none(): 36 | """Test get_mesh_dim_size with None device mesh.""" 37 | result = get_mesh_dim_size(None, "data") 38 | assert result == 1, f"Expected 1 for None device mesh, got {result}" 39 | 40 | 41 | def test_get_mesh_rank_none(): 42 | """Test get_mesh_rank with None device mesh.""" 43 | result = get_mesh_rank(None) 44 | assert result == 0, f"Expected 0 for None device mesh, got {result}" 45 | 46 | 47 | def test_get_mesh_dim_size_and_rank_2d_distributed(): 48 | """Test get_mesh_dim_size and get_mesh_rank with 2D 
47 | def test_get_mesh_dim_size_and_rank_3d_distributed(): 48 | """Test get_mesh_dim_size and get_mesh_rank with a 3D device mesh.""" 49 | device, rank, world_size = setup_distributed() 50 | print(f"Running 3D test on device {device}, rank {rank}, world size {world_size}") 51 | assert world_size == 8, "This test requires 8 processes" 52 | 53 | # Build a 3D mesh (2x2x2) 54 | device_mesh = init_device_mesh( 55 | device_type="cuda" if torch.cuda.is_available() else "cpu", mesh_shape=[2, 2, 2], mesh_dim_names=["a", "b", "c"] 56 | ) 57 | 58 | # Test get_mesh_dim_size for all three dimensions 59 | a_dim_size = get_mesh_dim_size(device_mesh, "a") 60 | b_dim_size = get_mesh_dim_size(device_mesh, "b") 61 | c_dim_size = get_mesh_dim_size(device_mesh, "c") 62 | 63 | assert a_dim_size == 2, f"Expected a dimension size 2, got {a_dim_size}" 64 | assert b_dim_size == 2, f"Expected b dimension size 2, got {b_dim_size}" 65 | assert c_dim_size == 2, f"Expected c dimension size 2, got {c_dim_size}" 66 | 67 | # Test get_mesh_rank 68 | mesh_rank = get_mesh_rank(device_mesh) 69 | assert mesh_rank == rank, f"Expected mesh rank {rank}, got {mesh_rank}" 70 | 71 | sub_mesh = device_mesh["b", "c"] 72 | sub_mesh_rank = get_mesh_rank(sub_mesh) 73 | assert sub_mesh_rank == rank % 4, f"Expected sub mesh rank {rank % 4}, got {sub_mesh_rank}" 74 | b_dim_size = get_mesh_dim_size(sub_mesh, "b") 75 | assert b_dim_size == 2, f"Expected b dimension size 2, got {b_dim_size}" 76 | c_dim_size = get_mesh_dim_size(sub_mesh, "c") 77 | assert c_dim_size == 2, f"Expected c dimension size 2, got {c_dim_size}" 78 | 79 | sub_mesh = device_mesh["a", "c"] 80 | sub_mesh_rank = get_mesh_rank(sub_mesh) 81 | if rank in [0, 1, 4, 5]: 82 | assert sub_mesh_rank == [0, 1, 4, 5].index(rank), ( 83 | f"Expected sub mesh rank {[0, 1, 4, 5].index(rank)}, got {sub_mesh_rank}" 84 | ) 85 | else: 86 | assert sub_mesh_rank == [2, 3, 6, 7].index(rank), ( 87 | f"Expected sub mesh rank {[2, 3, 6, 7].index(rank)}, got {sub_mesh_rank}" 88 | ) 89 | a_dim_size = get_mesh_dim_size(sub_mesh, "a") 90 | assert a_dim_size == 2, f"Expected a dimension size 2, got {a_dim_size}" 91 | c_dim_size = get_mesh_dim_size(sub_mesh, "c") 92 | assert c_dim_size == 2, f"Expected c dimension size 2, got {c_dim_size}" 93 | 94 | cleanup_distributed() 95 | 96 | 97 | if __name__ == "__main__": 98 | # Test non-distributed functions first 99 | test_get_mesh_dim_size_none() 100 | test_get_mesh_rank_none() 101 | 102 | # Test distributed functions 103 | test_get_mesh_dim_size_and_rank_3d_distributed() 104 | -------------------------------------------------------------------------------- /ui/.env.example: -------------------------------------------------------------------------------- 1 | VITE_BACKEND_URL=http://localhost:24577 -------------------------------------------------------------------------------- /ui/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | env: { browser: true, es2020: true }, 4 | extends: ["eslint:recommended", "plugin:@typescript-eslint/recommended", "plugin:react-hooks/recommended"], 5 | ignorePatterns: ["dist", ".eslintrc.cjs"], 6 | parser: "@typescript-eslint/parser", 7 | plugins: ["react-refresh"], 8 | rules: { 9 | "react-refresh/only-export-components": ["warn", { allowConstantExport: true }], 10 | "@typescript-eslint/no-unused-vars": [ 11 | "error", 12 | { 13 | args: "all", 14 | argsIgnorePattern: "^_", 15 | caughtErrors: "all", 16 | caughtErrorsIgnorePattern: "^_", 17 | destructuredArrayIgnorePattern: "^_", 18 | varsIgnorePattern: "^_", 19 | ignoreRestSiblings: true, 20 | }, 21 | ], 22 | }, 23 | }; 24 |
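// Illustration (hypothetical snippet, not from this repo): under the
// no-unused-vars options above, a leading underscore opts a binding out of the
// "error" report via argsIgnorePattern (_event), destructuredArrayIgnorePattern
// (_first), and caughtErrorsIgnorePattern (_err).
export function demo(_event: unknown, payload: string): string {
  const [_first, rest] = [0, payload] as const; // _first is intentionally unused
  try {
    return rest;
  } catch (_err) {
    return "";
  }
}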
-------------------------------------------------------------------------------- /ui/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 25 | 26 | # Environment variables 27 | .env 28 | .env.local 29 | .env.development 30 | .env.test 31 | .env.production -------------------------------------------------------------------------------- /ui/.prettierrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | printWidth: 120, // max 120 chars per line keeps code easy to read 3 | useTabs: false, // use spaces instead of tabs 4 | tabWidth: 2, // "visual width" of the "tab" 5 | trailingComma: "es5", // add trailing commas in objects, arrays, etc. 6 | semi: true, // add ; when needed 7 | singleQuote: false, // "" for strings instead of '' 8 | bracketSpacing: true, // import { some } ... instead of import {some} ... 9 | arrowParens: "always", // parentheses even for a single param in arrow functions: (a) => { } 10 | jsxSingleQuote: false, // "" for react props, like in html 11 | bracketSameLine: false, // pretty JSX 12 | endOfLine: "lf", // 'lf' for linux, 'crlf' for windows, we need to use 'lf' for git 13 | }; 14 |
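// Illustration (hypothetical file, not from this repo): a snippet as the options
// above would format it, with double quotes, semicolons, two-space indents, and
// parentheses around the single arrow-function parameter.
const greet = (name) => {
  return { message: "Hello, " + name, tags: ["ui", "prettier"] };
};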
-------------------------------------------------------------------------------- /ui/README.md: -------------------------------------------------------------------------------- 1 | # React + TypeScript + Vite 2 | 3 | This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. 4 | 5 | Currently, two official plugins are available: 6 | 7 | - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh 8 | - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh 9 | 10 | ## Expanding the ESLint configuration 11 | 12 | If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules: 13 | 14 | - Configure the top-level `parserOptions` property like this: 15 | 16 | ```js 17 | export default { 18 | // other rules... 19 | parserOptions: { 20 | ecmaVersion: "latest", 21 | sourceType: "module", 22 | project: ["./tsconfig.json", "./tsconfig.node.json"], 23 | tsconfigRootDir: __dirname, 24 | }, 25 | }; 26 | ``` 27 | 28 | - Replace `plugin:@typescript-eslint/recommended` with `plugin:@typescript-eslint/recommended-type-checked` or `plugin:@typescript-eslint/strict-type-checked` (see the sketch after this list) 29 | - Optionally add `plugin:@typescript-eslint/stylistic-type-checked` 30 | - Install [eslint-plugin-react](https://github.com/jsx-eslint/eslint-plugin-react) and add `plugin:react/recommended` & `plugin:react/jsx-runtime` to the `extends` list 31 |
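An illustrative sketch of the resulting `extends` list after that swap (adjust to your own configuration):

```js
extends: [
  "eslint:recommended",
  "plugin:@typescript-eslint/recommended-type-checked",
  "plugin:react-hooks/recommended",
],
```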
-------------------------------------------------------------------------------- /ui/bun.lockb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/ui/bun.lockb -------------------------------------------------------------------------------- /ui/components.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://ui.shadcn.com/schema.json", 3 | "style": "default", 4 | "rsc": false, 5 | "tsx": true, 6 | "tailwind": { 7 | "config": "tailwind.config.js", 8 | "css": "src/globals.css", 9 | "baseColor": "slate", 10 | "cssVariables": true, 11 | "prefix": "" 12 | }, 13 | "aliases": { 14 | "components": "@/components", 15 | "utils": "@/lib/utils" 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /ui/index.html: -------------------------------------------------------------------------------- 1 | <!doctype html> 2 | <html lang="en"> 3 | <head> 4 | <meta charset="UTF-8" /> 5 | <link rel="icon" type="image/x-icon" href="/openmoss.ico" /> 6 | <meta name="viewport" content="width=device-width, initial-scale=1.0" /> 7 | <title>Sparse Auto Encoder Visualizer</title> 8 | </head> 9 | <body> 10 | <div id="root"></div> 11 | <script type="module" src="/src/main.tsx"></script> 12 | </body> 13 | </html> 14 | -------------------------------------------------------------------------------- /ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ui", 3 | "private": true, 4 | "version": "0.0.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "build": "tsc && vite build", 9 | "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", 10 | "preview": "vite preview" 11 | }, 12 | "dependencies": { 13 | "@msgpack/msgpack": "^3.0.0-beta2", 14 | "@radix-ui/react-accordion": "^1.1.2", 15 | "@radix-ui/react-context-menu": "^2.2.1", 16 | "@radix-ui/react-dialog": "^1.1.1", 17 | "@radix-ui/react-dropdown-menu": "^2.0.6", 18 | "@radix-ui/react-hover-card": "^1.0.7", 19 | "@radix-ui/react-label": "^2.0.2", 20 | "@radix-ui/react-popover": "^1.1.1", 21 | "@radix-ui/react-select": "^2.0.0", 22 | "@radix-ui/react-separator": "^1.0.3", 23 | "@radix-ui/react-slot": "^1.0.2", 24 | "@radix-ui/react-switch": "^1.1.0", 25 | "@radix-ui/react-tabs": "^1.0.4", 26 | "@radix-ui/react-toggle": "^1.0.3", 27 | "@radix-ui/react-tooltip": "^1.1.2", 28 | "@tanstack/react-table": "^8.15.3", 29 | "@xyflow/react": "^12.0.4", 30 | "camelcase-keys": "^9.1.3", 31 | "class-variance-authority": "^0.7.0", 32 | "clsx": "^2.1.0", 33 | "cmdk": "1.0.0", 34 | "dagre": "^0.8.5", 35 | "lucide-react": "^0.358.0", 36 | "plotly.js": "^2.30.1", 37 | "react": "^18.2.0", 38 | "react-dom": "^18.2.0", 39 | "react-plotly.js": "^2.6.0", 40 | "react-router-dom": "^6.22.3", 41 | "react-use": "^17.5.0", 42 | "recharts": "^2.12.7", 43 | "snakecase-keys": "^8.0.1", 44 | "tailwind-merge": "^2.2.1", 45 | "tailwindcss-animate": "^1.0.7", 46 | "zod": "^3.22.4", 47 | "zustand": "^4.5.2" 48 | }, 49 | "devDependencies": { 50 | "@types/dagre": "^0.7.52", 51 | "@types/node": "^20.11.28", 52 | "@types/react": "^18.2.64", 53 | "@types/react-dom": "^18.2.21", 54 | "@types/react-plotly.js": "^2.6.3", 55 | "@typescript-eslint/eslint-plugin": "^7.1.1", 56 | "@typescript-eslint/parser": "^7.1.1", 57 | "@vitejs/plugin-react": "^4.2.1", 58 | "autoprefixer": "^10.4.18", 59 | "eslint": "^8.57.0", 60 | "eslint-plugin-react-hooks": "^4.6.0", 61 | "eslint-plugin-react-refresh": "^0.4.5", 62 | "postcss": "^8.4.35", 63 | "tailwindcss": "^3.4.1", 64 | "typescript": "^5.2.2", 65 | "vite": "^5.1.6" 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /ui/postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | }; 7 | -------------------------------------------------------------------------------- /ui/public/openmoss.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenMOSS/Language-Model-SAEs/dc8c9893108da636d156e507779ebb0aca6b9dca/ui/public/openmoss.ico -------------------------------------------------------------------------------- /ui/public/vite.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui/src/components/app/feature-preview.tsx: -------------------------------------------------------------------------------- 1 | import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 2 | import { useAsync } from "react-use"; 3 | import { Feature, FeatureSchema } from
"@/types/feature"; 4 | import { decode } from "@msgpack/msgpack"; 5 | import camelcaseKeys from "camelcase-keys"; 6 | import { create } from "zustand"; 7 | import { Link } from "react-router-dom"; 8 | 9 | const useFeaturePreviewStore = create<{ 10 | features: Record; 11 | addFeature: (feature: Feature) => void; 12 | }>((set) => ({ 13 | features: {}, 14 | addFeature: (feature: Feature) => 15 | set((state) => ({ 16 | features: { 17 | ...state.features, 18 | [`${feature.dictionaryName}---${feature.featureIndex}`]: feature, 19 | }, 20 | })), 21 | })); 22 | 23 | export const FeaturePreview = ({ dictionaryName, featureIndex }: { dictionaryName: string; featureIndex: number }) => { 24 | const featureInStore: Feature | null = useFeaturePreviewStore( 25 | (state) => state.features[`${dictionaryName}---${featureIndex}`] || null 26 | ); 27 | const addFeature = useFeaturePreviewStore((state) => state.addFeature); 28 | 29 | const state = useAsync(async () => { 30 | if (featureInStore) { 31 | return featureInStore; 32 | } 33 | const feature = await fetch( 34 | `${import.meta.env.VITE_BACKEND_URL}/dictionaries/${dictionaryName}/features/${featureIndex}`, 35 | { 36 | method: "GET", 37 | headers: { 38 | Accept: "application/x-msgpack", 39 | }, 40 | } 41 | ) 42 | .then(async (res) => { 43 | if (!res.ok) { 44 | throw new Error(await res.text()); 45 | } 46 | return res; 47 | }) 48 | .then(async (res) => await res.arrayBuffer()) 49 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 50 | .then((res) => decode(new Uint8Array(res)) as any) 51 | .then((res) => 52 | camelcaseKeys(res, { 53 | deep: true, 54 | stopPaths: ["sample_groups.samples.context"], 55 | }) 56 | ) 57 | .then((res) => FeatureSchema.parse(res)); 58 | addFeature(feature); 59 | return feature; 60 | }); 61 | 62 | return ( 63 |
<div> 64 | {state.loading && <div>Loading...</div>} 65 | {state.error && <div>Error: {state.error.message}</div>} 66 | {state.value && ( 67 | <div> 68 | <div>Feature:</div> 69 | <div>#{featureIndex}</div> 70 | <div>Interpretation:</div> 71 | <div>{state.value.interpretation?.text || "N/A"}</div> 72 | </div> 73 | )} 74 | </div>
75 | ); 76 | }; 77 | 78 | export const FeatureLinkWithPreview = ({ 79 | dictionaryName, 80 | featureIndex, 81 | }: { 82 | dictionaryName: string; 83 | featureIndex: number; 84 | }) => { 85 | return ( 86 | 87 | 88 | 89 | #{featureIndex} 90 | 91 | 92 | 93 | 94 | 95 | 96 | ); 97 | }; 98 | -------------------------------------------------------------------------------- /ui/src/components/app/navbar.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { Link, useLocation } from "react-router-dom"; 3 | 4 | export const AppNavbar = () => { 5 | const location = useLocation(); 6 | 7 | return ( 8 | 52 | ); 53 | }; 54 | -------------------------------------------------------------------------------- /ui/src/components/app/sample.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import { TokenGroup } from "./token"; 3 | import { cn } from "@/lib/utils"; 4 | 5 | export type SampleProps = { 6 | sampleName?: string; 7 | tokenGroups: T[][]; 8 | tokenGroupClassName?: (tokenGroup: T[], i: number) => string; 9 | tokenGroupProps?: (tokenGroup: T[], i: number) => React.HTMLProps; 10 | tokenInfoContent?: (tokenGroup: T[], i: number) => (token: T, i: number) => React.ReactNode; 11 | tokenGroupInfoContent?: (tokenGroup: T[], i: number) => React.ReactNode; 12 | customTokenGroup?: (tokens: T[], i: number) => React.ReactNode; 13 | foldedStart?: number; 14 | }; 15 | 16 | export const Sample = ({ 17 | sampleName, 18 | tokenGroups, 19 | tokenGroupClassName, 20 | tokenGroupProps, 21 | tokenInfoContent, 22 | tokenGroupInfoContent, 23 | customTokenGroup, 24 | foldedStart, 25 | }: SampleProps) => { 26 | const [folded, setFolded] = useState(true); 27 | 28 | return ( 29 |
30 |
setFolded(!folded) : undefined} 33 | > 34 | {sampleName && {sampleName}: } 35 | {folded && !!foldedStart && ...} 36 | {tokenGroups 37 | .slice((folded && foldedStart) || 0) 38 | .map((tokens, i) => 39 | customTokenGroup ? ( 40 | customTokenGroup(tokens, i) 41 | ) : ( 42 | 50 | ) 51 | )} 52 |
53 |
54 | ); 55 | }; 56 | -------------------------------------------------------------------------------- /ui/src/components/app/section-navigator.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; 3 | import { useEffect, useState } from "react"; 4 | 5 | export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => { 6 | const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null); 7 | 8 | const handleScroll = () => { 9 | // Use reduce instead of find for obtaining the last section that is in view 10 | const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => { 11 | const secElement = document.getElementById(section.id); 12 | if (!secElement) return result; 13 | const rect = secElement.getBoundingClientRect(); 14 | if (rect.top <= window.innerHeight / 2) { 15 | return section; 16 | } 17 | return result; 18 | }, null); 19 | 20 | setActiveSection(currentSection); 21 | }; 22 | 23 | useEffect(() => { 24 | window.addEventListener("scroll", handleScroll); 25 | 26 | // Run the handler to set the initial active section 27 | handleScroll(); 28 | 29 | return () => { 30 | window.removeEventListener("scroll", handleScroll); 31 | }; 32 | }); 33 | 34 | return ( 35 | 36 | 37 | 38 | CONTENTS 39 | 40 | 41 | 42 |
43 | 56 |
57 |
58 |
59 | ); 60 | }; 61 | -------------------------------------------------------------------------------- /ui/src/components/app/token.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "@/lib/utils"; 2 | import { mergeUint8Arrays } from "@/utils/array"; 3 | import { HoverCard, HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 4 | import { Fragment } from "react/jsx-runtime"; 5 | import { Separator } from "../ui/separator"; 6 | 7 | export type PlainTokenGroupProps = { 8 | tokenGroup: T[]; 9 | tokenGroupClassName?: string; 10 | tokenGroupProps?: React.HTMLProps; 11 | }; 12 | 13 | export const PlainTokenGroup = ({ 14 | tokenGroup, 15 | tokenGroupClassName, 16 | tokenGroupProps, 17 | }: PlainTokenGroupProps) => { 18 | const decoder = new TextDecoder("utf-8", { fatal: true }); 19 | 20 | return ( 21 | 28 | {decoder 29 | .decode(mergeUint8Arrays(tokenGroup.map((t) => t.token))) 30 | .replace("\n", "⏎") 31 | .replace("\t", "⇥") 32 | .replace("\r", "↵")} 33 | 34 | ); 35 | }; 36 | 37 | export type TokenGroupProps = PlainTokenGroupProps & { 38 | tokenInfoContent?: (token: T, i: number) => React.ReactNode; 39 | tokenGroupInfoContent?: React.ReactNode; 40 | }; 41 | 42 | export const TokenGroup = ({ 43 | tokenGroup, 44 | tokenGroupClassName, 45 | tokenGroupProps, 46 | tokenInfoContent, 47 | tokenGroupInfoContent, 48 | }: TokenGroupProps) => { 49 | if (!tokenInfoContent && !tokenGroupInfoContent) { 50 | return ( 51 | 56 | ); 57 | } 58 | 59 | return ( 60 | 61 | 62 | 67 | 68 | 69 | {tokenGroupInfoContent ? ( 70 | {tokenGroupInfoContent} 71 | ) : ( 72 | tokenGroup.map((token, i) => ( 73 | 74 | {tokenInfoContent?.(token, i)} 75 | {i < tokenGroup.length - 1 && } 76 | 77 | )) 78 | )} 79 | 80 | 81 | ); 82 | }; 83 | -------------------------------------------------------------------------------- /ui/src/components/attn-head/attn-head-card.tsx: -------------------------------------------------------------------------------- 1 | import { AttentionHead } from "@/types/attn-head"; 2 | import { Card, CardHeader, CardTitle, CardContent } from "../ui/card"; 3 | import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "../ui/table"; 4 | import { FeatureLinkWithPreview } from "../app/feature-preview"; 5 | 6 | export const AttentionHeadCard = ({ attnHead }: { attnHead: AttentionHead }) => { 7 | return ( 8 | 9 | 10 | 11 | Attention Head {attnHead.layer}.{attnHead.head} 12 | 13 | 14 | 15 |
16 | {attnHead.attnScores.map((attnScoreGroup, idx) => ( 17 |
18 |

19 | {attnScoreGroup.dictionary1Name} {" -> "} 20 | {attnScoreGroup.dictionary2Name} 21 |

22 | 23 | 24 | 25 | Feature 26 | Feature Attended 27 | Attention Score 28 | 29 | 30 | 31 | {attnScoreGroup.topAttnScores.map((attnScore, idx) => ( 32 | 33 | 34 | 38 | 39 | 40 | 44 | 45 | {attnScore.attnScore.toFixed(3)} 46 | 47 | ))} 48 | 49 |
50 |
51 | ))} 52 |
53 |
54 |
55 | ); 56 | }; 57 | -------------------------------------------------------------------------------- /ui/src/components/dictionary/dictionary-card.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import { Button } from "../ui/button"; 3 | import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; 4 | import { Dictionary, DictionarySampleCompact, DictionarySampleCompactSchema } from "@/types/dictionary"; 5 | import Plot from "react-plotly.js"; 6 | import { useAsyncFn } from "react-use"; 7 | import { decode } from "@msgpack/msgpack"; 8 | import camelcaseKeys from "camelcase-keys"; 9 | import { Textarea } from "../ui/textarea"; 10 | import { DictionarySample } from "./sample"; 11 | 12 | const DictionaryCustomInputArea = ({ dictionary }: { dictionary: Dictionary }) => { 13 | const [customInput, setCustomInput] = useState(""); 14 | const [samples, setSamples] = useState([]); 15 | const [state, submit] = useAsyncFn(async () => { 16 | if (!customInput) { 17 | alert("Please enter your input."); 18 | return; 19 | } 20 | const sample = await fetch( 21 | `${import.meta.env.VITE_BACKEND_URL}/dictionaries/${ 22 | dictionary.dictionaryName 23 | }/custom?input_text=${encodeURIComponent(customInput)}`, 24 | { 25 | method: "POST", 26 | headers: { 27 | Accept: "application/x-msgpack", 28 | }, 29 | } 30 | ) 31 | .then(async (res) => { 32 | if (!res.ok) { 33 | throw new Error(await res.text()); 34 | } 35 | return res; 36 | }) 37 | .then(async (res) => await res.arrayBuffer()) 38 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 39 | .then((res) => decode(new Uint8Array(res)) as any) 40 | .then((res) => 41 | camelcaseKeys(res, { 42 | deep: true, 43 | stopPaths: ["context"], 44 | }) 45 | ) 46 | .then((res) => DictionarySampleCompactSchema.parse(res)); 47 | setSamples((prev) => [...prev, sample]); 48 | }, [customInput]); 49 | 50 | return ( 51 |
52 |

Custom Input

53 |