├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── app ├── .eslintrc.json ├── .gitignore ├── .prettierrc ├── app │ ├── components │ │ ├── ColorSpan.tsx │ │ ├── Controls.tsx │ │ ├── LatentExamples.tsx │ │ ├── LogitsTable.tsx │ │ └── prompt │ │ │ ├── PromptActivations.tsx │ │ │ ├── PromptLatentHeatmaps.tsx │ │ │ ├── PromptLayerHistograms.tsx │ │ │ ├── PromptLogitsInput.tsx │ │ │ ├── PromptLogitsRecon.tsx │ │ │ └── PromptLogitsSteer.tsx │ ├── favicon.ico │ ├── globals.css │ ├── layout.tsx │ ├── page.tsx │ ├── prompt │ │ └── [prompt] │ │ │ ├── page.client.tsx │ │ │ └── page.tsx │ └── use-select.tsx ├── components.json ├── components │ └── ui │ │ ├── button.tsx │ │ ├── card.tsx │ │ ├── chart.tsx │ │ ├── form.tsx │ │ ├── input.tsx │ │ ├── label.tsx │ │ ├── table.tsx │ │ ├── tabs.tsx │ │ └── textarea.tsx ├── lib │ ├── api.ts │ ├── format.ts │ ├── models.ts │ └── utils.ts ├── next.config.js ├── package-lock.json ├── package.json ├── postcss.config.js ├── tailwind.config.js └── tsconfig.json ├── citation.bib ├── figures.py ├── figures ├── __init__.py ├── embed_sim.py ├── entropy.py ├── heatmap.py ├── heatmap_aggregate.py ├── heatmap_prompt.py ├── layer_hist.py ├── layer_sim.py ├── layer_std.py ├── mmcs.py ├── num_layers.py ├── resid_sim.py ├── scatter_freq.py ├── test.py └── wdec_sim.py ├── layer_dists.py ├── layer_tests.py ├── mlsae ├── __init__.py ├── analysis │ ├── __init__.py │ ├── dists.py │ ├── examples.py │ └── variances.py ├── api │ ├── __init__.py │ ├── __main__.py │ ├── analyser.py │ └── models.py ├── metrics │ ├── __init__.py │ ├── auxiliary_loss.py │ ├── dead_latents.py │ ├── layerwise.py │ ├── layerwise_fvu.py │ ├── layerwise_l0_norm.py │ ├── layerwise_l1_norm.py │ ├── layerwise_logit_kl_div.py │ ├── layerwise_logit_mse.py │ ├── layerwise_loss_delta.py │ ├── layerwise_mse.py │ ├── mse_loss.py │ └── tests │ │ ├── __init__.py │ │ ├── test_dead_latents.py │ │ ├── test_layerwise_fvu.py │ │ ├── test_layerwise_l0_norm.py │ 
│ ├── test_layerwise_l1_norm.py │ │ ├── test_layerwise_mse.py │ │ └── test_loss_mse.py ├── model │ ├── __init__.py │ ├── autoencoders │ │ ├── __init__.py │ │ ├── standard.py │ │ ├── tests │ │ │ └── test_autoencoders.py │ │ ├── topk.py │ │ └── utils.py │ ├── data.py │ ├── decoder.py │ ├── geom_median.py │ ├── kernels.py │ ├── lightning.py │ ├── transformers │ │ ├── __init__.py │ │ ├── gemma2.py │ │ ├── gpt2.py │ │ ├── llama.py │ │ ├── models │ │ │ ├── gemma2 │ │ │ │ └── modeling_gemma2.py │ │ │ ├── gpt2 │ │ │ │ └── modeling_gpt2.py │ │ │ └── llama │ │ │ │ └── modeling_llama.py │ │ ├── pythia.py │ │ └── tests │ │ │ ├── test_gpt2.py │ │ │ └── test_llama.py │ └── types.py ├── model_card.py ├── trainer │ ├── __init__.py │ ├── config.py │ ├── test.py │ └── train.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── test.py ├── tests.py ├── train.py ├── upload.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # .DS_Store 165 | # .ruff_cache 166 | 167 | # .env 168 | # .venv 169 | 170 | # Slurm 171 | *.o 172 | *.out 173 | 174 | # PyTorch Lightning 175 | lightning_logs 176 | 177 | # Weights & Biases 178 | wandb_logs 179 | 180 | data 181 | models 182 | !mlsae/**/models 183 | out 184 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": ["charliermarsh.ruff", "ms-python.python"] 3 | } 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "ruff.configuration": "pyproject.toml", 3 | "ruff.nativeServer": true 4 | } 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Tim Lawson] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Layer Sparse Autoencoders (MLSAE) 2 | 3 | > [!NOTE] 4 | > This repository accompanies the preprint "Residual Stream Analysis with 5 | > Multi-Layer SAEs". 6 | > See [References](#references) for related work. 7 | 8 | ## Pretrained MLSAEs 9 | 10 | We define two types of model: plain PyTorch 11 | [MLSAE](./mlsae/model/autoencoders/) modules, which are relatively small; and 12 | PyTorch Lightning [MLSAETransformer](./mlsae/model/lightning.py) modules, which 13 | include the underlying transformer. HuggingFace collections for both are here: 14 | 15 | - [Multi-Layer Sparse Autoencoders](https://huggingface.co/collections/tim-lawson/multi-layer-sparse-autoencoders-66c2fe8896583c59b02ceb72) 16 | - [Multi-Layer Sparse Autoencoders with Transformers](https://huggingface.co/collections/tim-lawson/multi-layer-sparse-autoencoders-with-transformers-66c441c87d1b24912175ce08) 17 | 18 | We assume that pretrained MLSAEs have repo_ids with 19 | [this naming convention](./mlsae/utils.py): 20 | 21 | - `tim-lawson/mlsae-pythia-70m-deduped-x{expansion_factor}-k{k}` 22 | - `tim-lawson/mlsae-pythia-70m-deduped-x{expansion_factor}-k{k}-tfm` 23 | 24 | The Weights & Biases project for the paper is 25 | [here](https://wandb.ai/timlawson-/mlsae). 
26 | 27 | ## Installation 28 | 29 | Install Python dependencies with Poetry: 30 | 31 | ```bash 32 | poetry env use 3.12 33 | poetry install 34 | ``` 35 | 36 | Install Python dependencies with pip: 37 | 38 | ```bash 39 | python -m venv .venv 40 | source .venv/bin/activate 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | Install Node.js dependencies: 45 | 46 | ```bash 47 | cd app 48 | npm install 49 | ``` 50 | 51 | ## Training 52 | 53 | Train a single MLSAE: 54 | 55 | ```bash 56 | python train.py --help 57 | python train.py --model_name EleutherAI/pythia-70m-deduped --expansion_factor 64 -k 32 58 | ``` 59 | 60 | ## Analysis 61 | 62 | Test a single pretrained MLSAE: 63 | 64 | > [!WARNING] 65 | > We assume that the test split of `monology/pile-uncopyrighted` is already downloaded 66 | > and stored in `data/test.jsonl.zst`. 67 | 68 | ```bash 69 | python test.py --help 70 | python test.py --model_name EleutherAI/pythia-70m-deduped --expansion_factor 64 -k 32 71 | ``` 72 | 73 | Compute the distributions of latent activations over layers for a single 74 | pretrained MLSAE 75 | ([HuggingFace datasets](https://huggingface.co/collections/tim-lawson/mlsae-latent-distributions-over-layers-66d6a0ec9fcb6b494fb1808e)): 76 | 77 | ```bash 78 | python -m mlsae.analysis.dists --help 79 | python -m mlsae.analysis.dists --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm --max_tokens 100_000_000 80 | ``` 81 | 82 | Compute the maximally activating examples for each combination of latent and 83 | layer for a single pretrained MLSAE 84 | ([HuggingFace datasets](https://huggingface.co/collections/tim-lawson/mlsae-maximally-activating-examples-66dbcc999a962ae594f631b6)): 85 | 86 | ```bash 87 | python -m mlsae.analysis.examples --help 88 | python -m mlsae.analysis.examples --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm --max_tokens 1_000_000 89 | ``` 90 | 91 | ## Interactive visualizations 92 | 93 | Run the interactive web application for a single pretrained MLSAE: 94 
| 95 | ```bash 96 | python -m mlsae.api --help 97 | python -m mlsae.api --repo_id tim-lawson/mlsae-pythia-70m-deduped-x64-k32-tfm 98 | 99 | cd app 100 | npm run dev 101 | ``` 102 | 103 | Navigate to the URL printed by `npm run dev` (the Next.js default is `http://localhost:3000`), enter a prompt, and click 'Submit'. 104 | 105 | Alternatively, navigate directly to the `/prompt/{prompt}` route on the same host. 106 | 107 | ## Figures 108 | 109 | Compute the mean cosine similarities between residual stream activation vectors 110 | at adjacent layers of a single pretrained transformer: 111 | 112 | ```bash 113 | python figures/resid_cos_sim.py --help 114 | python figures/resid_cos_sim.py --model_name EleutherAI/pythia-70m-deduped 115 | ``` 116 | 117 | Save heatmaps of the distributions of latent activations over layers for 118 | multiple pretrained MLSAEs: 119 | 120 | ```bash 121 | python figures/dists_heatmaps.py --help 122 | python figures/dists_heatmaps.py --expansion_factor 32 64 128 -k 16 32 64 123 | ``` 124 | 125 | Save a CSV of the mean standard deviations of the distributions of latent 126 | activations over layers for multiple pretrained MLSAEs: 127 | 128 | ```bash 129 | python figures/dists_layer_std.py --help 130 | python figures/dists_layer_std.py --expansion_factor 32 64 128 -k 16 32 64 131 | ``` 132 | 133 | Save heatmaps of the maximum latent activations for a given prompt and multiple 134 | pretrained MLSAEs: 135 | 136 | ```bash 137 | python figures/prompt_heatmaps.py --help 138 | python figures/prompt_heatmaps.py --expansion_factor 32 64 128 -k 16 32 64 139 | ``` 140 | 141 | Save a CSV of the Mean Max Cosine Similarity (MMCS) for multiple pretrained 142 | MLSAEs: 143 | 144 | ```bash 145 | python figures/mmcs.py --help 146 | python figures/mmcs.py --expansion_factor 32 64 128 -k 16 32 64 147 | ``` 148 | 149 | ## References 150 | 151 | ### Code 152 | 153 | - 154 | - 155 | - 156 | - 157 | 158 | ### Papers 159 | 160 | - Gao et al. [2024] 161 | - Bricken et al. 
[2023] 162 | 163 | -------------------------------------------------------------------------------- /app/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["next/core-web-vitals", "prettier"], 3 | "rules": { 4 | "sort-imports": [ 5 | "error", 6 | { 7 | "ignoreDeclarationSort": true 8 | } 9 | ], 10 | "import/consistent-type-specifier-style": "error", 11 | "import/no-duplicates": ["error", { "prefer-inline": true }], 12 | "import/order": [ 13 | "warn", 14 | { 15 | "alphabetize": { 16 | "caseInsensitive": false, 17 | "order": "asc", 18 | "orderImportKind": "asc" 19 | }, 20 | "groups": [ 21 | "builtin", 22 | "type", 23 | "external", 24 | "internal", 25 | "parent", 26 | "sibling", 27 | "object", 28 | "index" 29 | ], 30 | "pathGroups": [ 31 | { 32 | "pattern": "~/**", 33 | "group": "external", 34 | "position": "after" 35 | } 36 | ], 37 | "pathGroupsExcludedImportTypes": ["builtin"] 38 | } 39 | ] 40 | }, 41 | "settings": { 42 | "import/resolver": { 43 | "typescript": {} 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /app/.gitignore: -------------------------------------------------------------------------------- 1 | # dependencies 2 | /node_modules 3 | /.pnp 4 | .pnp.js 5 | 6 | # testing 7 | /coverage 8 | 9 | # next.js 10 | /.next/ 11 | /out/ 12 | 13 | # production 14 | /build 15 | 16 | # misc 17 | .DS_Store 18 | *.pem 19 | 20 | # debug 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | 25 | # local env files 26 | .env*.local 27 | 28 | # vercel 29 | .vercel 30 | 31 | # typescript 32 | *.tsbuildinfo 33 | next-env.d.ts 34 | -------------------------------------------------------------------------------- /app/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "tabWidth": 2, 3 | "useTabs": false 4 | } 5 | 
-------------------------------------------------------------------------------- /app/app/components/ColorSpan.tsx: -------------------------------------------------------------------------------- 1 | import { cn } from "~/lib/utils"; 2 | 3 | export default function ColorSpan({ 4 | children, 5 | opacity, 6 | className, 7 | color = "bg-orange-400 dark:bg-orange-800", 8 | style = {}, 9 | ...props 10 | }: { 11 | children?: React.ReactNode; 12 | opacity: number; 13 | className?: string | boolean | undefined; 14 | color?: string; 15 | style?: React.CSSProperties; 16 | } & React.HTMLAttributes) { 17 | return ( 18 | 19 |
23 | {children} 24 | 25 | ); 26 | } 27 | -------------------------------------------------------------------------------- /app/app/components/Controls.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React from "react"; 4 | import { Select } from "~/app/use-select"; 5 | import { Input } from "~/components/ui/input"; 6 | import { Label } from "~/components/ui/label"; 7 | 8 | export default function Controls({ 9 | nLayers, 10 | nLatents, 11 | nPositions, 12 | stateLayer, 13 | stateLatent, 14 | statePosition, 15 | threshold, 16 | onChangeThreshold, 17 | factor, 18 | onChangeFactor, 19 | }: { 20 | nLayers: number; 21 | nLatents: number; 22 | nPositions: number; 23 | stateLayer: Select; 24 | stateLatent: Select; 25 | statePosition: Select; 26 | threshold: number; 27 | onChangeThreshold: (value: number) => void; 28 | factor: number; 29 | onChangeFactor: (value: number) => void; 30 | }) { 31 | const onChangeLayer = (event: React.ChangeEvent) => { 32 | stateLayer.onClick(event.target.valueAsNumber); 33 | }; 34 | 35 | const onChangeLatent = (event: React.ChangeEvent) => { 36 | stateLatent.onClick(event.target.valueAsNumber); 37 | }; 38 | 39 | const onChangePosition = (event: React.ChangeEvent) => { 40 | statePosition.onClick(event.target.valueAsNumber); 41 | }; 42 | 43 | return ( 44 |
45 |
46 | 47 | 57 |
58 |
59 | 60 | 70 |
71 |
72 | 73 | 83 |
84 |
85 | 88 | onChangeThreshold(event.target.valueAsNumber)} 93 | min={0.0} 94 | max={0.95} 95 | step={0.05} 96 | className="h-8" 97 | /> 98 |
99 |
100 | 103 | onChangeFactor(event.target.valueAsNumber)} 108 | min={-10} 109 | max={10} 110 | step={0.5} 111 | className="h-8" 112 | /> 113 |
114 |
115 | ); 116 | } 117 | -------------------------------------------------------------------------------- /app/app/components/LatentExamples.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React from "react"; 4 | import useSWR from "swr"; 5 | import ColorSpan from "~/app/components/ColorSpan"; 6 | import { Select } from "~/app/use-select"; 7 | import { Card, CardContent } from "~/components/ui/card"; 8 | import { 9 | Table, 10 | TableBody, 11 | TableCell, 12 | TableHead, 13 | TableHeader, 14 | TableRow, 15 | } from "~/components/ui/table"; 16 | import { getExamples } from "~/lib/api"; 17 | import { escapeWhitespace } from "~/lib/format"; 18 | 19 | export default function LatentExamplesComponent({ 20 | className, 21 | stateLatent, 22 | stateLayer, 23 | }: { 24 | className?: string; 25 | stateLatent: Select; 26 | stateLayer: Select; 27 | }) { 28 | const { data: examples } = useSWR( 29 | ["examples", stateLatent.clicked, stateLayer.clicked], 30 | ([_key, latent, layer]) => getExamples(latent, layer), 31 | { 32 | keepPreviousData: true, 33 | }, 34 | ); 35 | 36 | if (examples === undefined || examples.length === 0) { 37 | return null; 38 | } 39 | 40 | return ( 41 | 42 | 43 | 44 | 45 | 46 | Example 47 | Activation 48 | 49 | 50 | 51 | {examples.map((example, index) => { 52 | return ( 53 | 54 | 55 | {example.tokens.map((token, position) => { 56 | const value = example.acts[position]; 57 | return ( 58 | 59 | {escapeWhitespace(token)} 60 | 61 | ); 62 | })} 63 | 64 | 65 | {example.act.toFixed(3)} 66 | 67 | 68 | ); 69 | })} 70 | 71 |
72 |
73 |
74 | ); 75 | } 76 | -------------------------------------------------------------------------------- /app/app/components/LogitsTable.tsx: -------------------------------------------------------------------------------- 1 | import ColorSpan from "~/app/components/ColorSpan"; 2 | import { 3 | Table, 4 | TableBody, 5 | TableCaption, 6 | TableCell, 7 | TableRow, 8 | } from "~/components/ui/table"; 9 | import { escapeWhitespace, toSigned, toUnitInterval } from "~/lib/format"; 10 | import { LogitType } from "~/lib/models"; 11 | 12 | export default function LogitsTable({ 13 | caption, 14 | data, 15 | color, 16 | }: { 17 | caption: string; 18 | data: LogitType[] | undefined; 19 | color: string; 20 | }) { 21 | const hasProbability = data?.[0].prob !== null; 22 | return ( 23 | 24 | 25 | {caption} 26 | 27 | 28 | {data?.map((logit) => ( 29 | 30 | 31 | 35 | {escapeWhitespace(logit.token)} 36 | 37 | 38 | {!hasProbability && ( 39 | 40 | {toSigned(logit.logit)} 41 | 42 | )} 43 | {hasProbability && ( 44 | 45 | {logit.prob?.toFixed(3)} 46 | 47 | )} 48 | 49 | ))} 50 | 51 |
52 | ); 53 | } 54 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptActivations.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import ColorSpan from "~/app/components/ColorSpan"; 4 | import { Select } from "~/app/use-select"; 5 | import { Card, CardContent } from "~/components/ui/card"; 6 | import { escapeWhitespace } from "~/lib/format"; 7 | import { LatentActivationsType, TokenType } from "~/lib/models"; 8 | import { cn } from "~/lib/utils"; 9 | 10 | export default function PromptActivationsComponent({ 11 | className, 12 | tokens, 13 | latentActivations, 14 | stateLatent, 15 | stateLayer, 16 | statePosition, 17 | }: { 18 | className?: string; 19 | tokens: TokenType[]; 20 | latentActivations: LatentActivationsType; 21 | stateLatent: Select; 22 | stateLayer: Select; 23 | statePosition: Select; 24 | }) { 25 | return ( 26 | 27 | 28 |
29 | {tokens.map((token, index) => { 30 | const absolute = 31 | latentActivations.values[stateLayer.active][token.pos][ 32 | stateLatent.active 33 | ]; 34 | 35 | const relative = 36 | absolute / latentActivations.max[stateLayer.active][token.pos]; 37 | 38 | return ( 39 | statePosition.onClick(index)} 47 | onMouseEnter={() => statePosition.onMouseEnter(index)} 48 | onMouseLeave={statePosition.onMouseLeave} 49 | > 50 | {escapeWhitespace(token.token)} 51 | 52 | ); 53 | })} 54 |
55 |
56 |
57 | ); 58 | } 59 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptLatentHeatmaps.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React from "react"; 4 | import ColorSpan from "~/app/components/ColorSpan"; 5 | import { Select } from "~/app/use-select"; 6 | import { Card, CardContent } from "~/components/ui/card"; 7 | import { 8 | Table, 9 | TableBody, 10 | TableCell, 11 | TableHead, 12 | TableHeader, 13 | TableRow, 14 | } from "~/components/ui/table"; 15 | import { LatentActivationsType } from "~/lib/models"; 16 | import { cn } from "~/lib/utils"; 17 | 18 | export default function LatentHeatmapComponent({ 19 | latentActivations, 20 | threshold, 21 | stateLatent, 22 | stateLayer, 23 | statePosition, 24 | perToken = true, 25 | }: { 26 | latentActivations: LatentActivationsType; 27 | threshold: number; 28 | stateLatent: Select; 29 | stateLayer: Select; 30 | statePosition: Select; 31 | perToken?: boolean; 32 | }) { 33 | const nLayers = latentActivations.max.length; 34 | const layers = Array.from({ length: nLayers }).map((_, layer) => layer); 35 | 36 | const latentHeatmaps = React.useMemo(() => { 37 | return getLatentHeatmaps( 38 | latentActivations, 39 | threshold, 40 | statePosition.clicked, 41 | perToken, 42 | ); 43 | }, [latentActivations, threshold, statePosition.clicked, perToken]); 44 | 45 | return ( 46 | 47 | 48 | 49 | 50 | 51 | Latent 52 | Mean Layer 53 | {layers.map((layer) => ( 54 | 55 | Layer {layer} 56 | 57 | ))} 58 | 59 | 60 | 61 | {latentHeatmaps.map((latentHeatmap) => { 62 | return ( 63 | stateLatent.onClick(latentHeatmap.latent)} 67 | onMouseEnter={() => 68 | stateLatent.onMouseEnter(latentHeatmap.latent) 69 | } 70 | onMouseLeave={stateLatent.onMouseLeave} 71 | > 72 | 73 | {latentHeatmap.latent} 74 | 75 | 76 | {latentHeatmap.layer_mean.toFixed(2)} 77 | 78 | {layers.map((layer) => { 79 | const latentHeatmapLayer = 
latentHeatmap.layers.find( 80 | (heatmapLayer) => heatmapLayer.layer === layer, 81 | ); 82 | const absolute = latentHeatmapLayer?.absolute ?? 0; 83 | const relative = latentHeatmapLayer?.relative ?? 0; 84 | const string = absolute.toFixed(3); 85 | return ( 86 | 87 | stateLayer.onClick(layer)} 96 | onMouseEnter={() => stateLayer.onMouseEnter(layer)} 97 | onMouseLeave={stateLayer.onMouseLeave} 98 | title={string} 99 | > 100 | {absolute > 0 ? ( 101 | {string} 102 | ) : null} 103 | 104 | 105 | ); 106 | })} 107 | 108 | ); 109 | })} 110 | 111 |
112 |
113 |
114 | ); 115 | } 116 | 117 | interface LatentHeatmap { 118 | latent: number; 119 | layers: { 120 | layer: number; 121 | absolute: number; 122 | relative: number; 123 | }[]; 124 | layer_mean: number; 125 | } 126 | 127 | function getLatentHeatmaps( 128 | latentActivations: LatentActivationsType, 129 | threshold: number, 130 | position: number, 131 | perToken = true, 132 | ) { 133 | const latentMap: Record = {}; 134 | const nLayers = latentActivations.max.length; 135 | const nLatents = latentActivations.values[0][0].length; 136 | 137 | for (let layer = 0; layer < nLayers; layer++) { 138 | for (let latent = 0; latent < nLatents; latent++) { 139 | let absolute: number; 140 | let relative: number; 141 | if (perToken) { 142 | absolute = latentActivations.values[layer][position][latent]; 143 | relative = absolute / latentActivations.max[layer][position]; 144 | } else { 145 | absolute = latentActivations.values[layer].reduce((total, values) => { 146 | return total + values[latent]; 147 | }, 0); 148 | relative = 149 | absolute / 150 | latentActivations.max[layer].reduce((total, value) => { 151 | return total + value; 152 | }, 0); 153 | } 154 | if (relative > threshold) { 155 | if (latentMap[latent] === undefined) { 156 | latentMap[latent] = []; 157 | } 158 | latentMap[latent].push({ 159 | layer, 160 | absolute, 161 | relative, 162 | }); 163 | } 164 | } 165 | } 166 | 167 | const latentActivationTotal: Record = {}; 168 | const latentActivationLayerTotal: Record = {}; 169 | for (const key of Object.keys(latentMap)) { 170 | const latent = Number(key); 171 | for (let layer = 0; layer < nLayers; layer++) { 172 | const activation = latentActivations.values[layer][position][latent]; 173 | if (activation > 0) { 174 | if (latentActivationTotal[latent] === undefined) { 175 | latentActivationTotal[latent] = 0; 176 | } 177 | latentActivationTotal[latent] += activation; 178 | 179 | if (latentActivationLayerTotal[latent] === undefined) { 180 | latentActivationLayerTotal[latent] = 0; 181 | 
} 182 | latentActivationLayerTotal[latent] += activation * layer; 183 | } 184 | } 185 | } 186 | 187 | let latentHeatmaps: LatentHeatmap[] = Object.entries(latentMap).map( 188 | ([key, layers]) => { 189 | const latent = Number(key); 190 | return { 191 | latent, 192 | layers, 193 | layer_mean: 194 | latentActivationLayerTotal[latent] / latentActivationTotal[latent], 195 | }; 196 | }, 197 | ); 198 | 199 | latentHeatmaps.sort((a, b) => { 200 | const latentCenter = a.layer_mean - b.layer_mean; 201 | const activationTotal = 202 | b.layers.reduce((total, { absolute }) => total + absolute, 0) - 203 | a.layers.reduce((total, { absolute }) => total + absolute, 0); 204 | return latentCenter !== 0 ? latentCenter : activationTotal; 205 | }); 206 | 207 | return latentHeatmaps; 208 | } 209 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptLayerHistograms.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { scaleSymlog } from "d3-scale"; 4 | import React from "react"; 5 | import { Area, AreaChart, XAxis, YAxis } from "recharts"; 6 | import useSWR from "swr"; 7 | import colors from "tailwindcss/colors"; 8 | import { useDarkMode } from "usehooks-ts"; 9 | import { Select } from "~/app/use-select"; 10 | import { Card, CardContent } from "~/components/ui/card"; 11 | import { ChartConfig, ChartContainer } from "~/components/ui/chart"; 12 | import { 13 | Table, 14 | TableBody, 15 | TableCell, 16 | TableHead, 17 | TableHeader, 18 | TableRow, 19 | } from "~/components/ui/table"; 20 | import { getPromptLayerHistograms } from "~/lib/api"; 21 | import { toSigned } from "~/lib/format"; 22 | import { LayerHistogramsType } from "~/lib/models"; 23 | import { cn } from "~/lib/utils"; 24 | 25 | const symlog = scaleSymlog(); 26 | const chartConfig = { 27 | count: {}, 28 | } satisfies ChartConfig; 29 | const chartMargin = { top: 0, left: 0, right: 0, bottom: 0 } as 
/**
 * Converts one layer of a histogram into the point list consumed by the chart.
 *
 * Each bin edge is paired with the value stored at the same index for `layer`;
 * the final edge is then discarded.
 *
 * @param histograms Bin edges and per-layer values, read as `values[layer][bin]`.
 * @param layer Index of the layer whose histogram is wanted.
 * @returns One `{ edge, value }` object per edge except the last.
 */
function chartData(histograms: LayerHistogramsType, layer: number) {
  const { edges, values } = histograms;
  const points = edges.map((edge, bin) => {
    const value = values[layer][bin];
    return { edge, value };
  });
  // Drop the entry built from the final edge.
  points.pop();
  return points;
}
75 | 76 | {layers.map((layer) => { 77 | const isActive = layer == stateLayer.active; 78 | const color = isDarkMode 79 | ? isActive 80 | ? colors.orange[800] 81 | : colors.slate[100] 82 | : isActive 83 | ? colors.orange[400] 84 | : colors.slate[900]; 85 | return ( 86 | stateLayer.onClick(layer)} 89 | onMouseEnter={() => stateLayer.onMouseEnter(layer)} 90 | onMouseLeave={stateLayer.onMouseLeave} 91 | > 92 | {layer} 93 | 94 | 101 | 106 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | ); 123 | })} 124 | 125 |
126 |
127 | ); 128 | } 129 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptLogitsInput.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import LogitsTable from "~/app/components/LogitsTable"; 4 | import { Select } from "~/app/use-select"; 5 | import { Card, CardContent, CardHeader, CardTitle } from "~/components/ui/card"; 6 | import { MaxLogitsType } from "~/lib/models"; 7 | 8 | export default function PromptLogitsInput({ 9 | values, 10 | statePosition, 11 | className, 12 | }: { 13 | values: MaxLogitsType; 14 | statePosition: Select; 15 | className?: string; 16 | }) { 17 | return ( 18 | 19 | 20 | Input activations 21 | 22 | 23 | 28 | 29 | 30 | ); 31 | } 32 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptLogitsRecon.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import useSWR from "swr"; 4 | import LogitsTable from "~/app/components/LogitsTable"; 5 | import { Select } from "~/app/use-select"; 6 | import { Card, CardContent, CardHeader, CardTitle } from "~/components/ui/card"; 7 | import { getPromptLogitsRecon as getPromptLogitsRecon } from "~/lib/api"; 8 | 9 | export default function PromptLogitsRecon({ 10 | prompt, 11 | stateLayer, 12 | statePosition, 13 | className, 14 | }: { 15 | prompt: string; 16 | stateLayer: Select; 17 | statePosition: Select; 18 | className?: string; 19 | }) { 20 | const { data } = useSWR( 21 | ["prompt/logits-recon", prompt, stateLayer.clicked], 22 | ([_key, prompt, layer]) => getPromptLogitsRecon(prompt, layer), 23 | { 24 | keepPreviousData: false, 25 | }, 26 | ); 27 | 28 | const [values, changes] = data ?? 
[]; 29 | return ( 30 | 31 | 32 | Reconstruction at layer {stateLayer.clicked} 33 | 34 | 35 | 40 | 45 | 50 | 51 | 52 | ); 53 | } 54 | -------------------------------------------------------------------------------- /app/app/components/prompt/PromptLogitsSteer.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import useSWR from "swr"; 4 | import LogitsTable from "~/app/components/LogitsTable"; 5 | import { Select } from "~/app/use-select"; 6 | import { Card, CardContent, CardHeader, CardTitle } from "~/components/ui/card"; 7 | import { getPromptLogitsSteer } from "~/lib/api"; 8 | 9 | /** Client component for the panel headed Steered by latent L at layer N, where L is stateLatent.clicked and N is stateLayer.clicked. Requests data with useSWR, keyed on the prompt, stateLatent.clicked, stateLayer.clicked and factor, by calling getPromptLogitsSteer(prompt, latent, layer, factor); keepPreviousData is set to false. The response is destructured as [values, changes], falling back to an empty array (so both names are undefined) while data is nullish. NOTE(review): the JSX markup is not visible in this copy of the file; how values, changes, statePosition and className are rendered should be confirmed against the original source. */ export default function PromptLogitsSteered({ 10 | prompt, 11 | stateLatent, 12 | stateLayer, 13 | statePosition, 14 | factor, 15 | className, 16 | }: { 17 | prompt: string; 18 | stateLatent: Select; 19 | stateLayer: Select; 20 | statePosition: Select; 21 | factor: number; 22 | className?: string; 23 | }) { 24 | const { data } = useSWR( 25 | [ 26 | "prompt/logits-steer", 27 | prompt, 28 | stateLatent.clicked, 29 | stateLayer.clicked, 30 | factor, 31 | ], 32 | /* factor is the fifth element of the key above, but the fetcher destructures only the first four elements and reads factor from the enclosing scope instead */ ([_key, prompt, latent, layer]) => 33 | getPromptLogitsSteer(prompt, latent, layer, factor), 34 | { 35 | keepPreviousData: false, 36 | }, 37 | ); 38 | 39 | const [values, changes] = data ?? 
[]; 40 | return ( 41 | 42 | 43 | 44 | Steered by latent {stateLatent.clicked} at layer {stateLayer.clicked} 45 | 46 | 47 | 48 | 53 | 58 | 63 | 64 | 65 | ); 66 | } 67 | -------------------------------------------------------------------------------- /app/app/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-lawson/mlsae/03ad37a0a1b4541d763859cb0c7c9ccb7ce67867/app/app/favicon.ico -------------------------------------------------------------------------------- /app/app/globals.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | -------------------------------------------------------------------------------- /app/app/layout.tsx: -------------------------------------------------------------------------------- 1 | import { Inter } from "next/font/google"; 2 | import "./globals.css"; 3 | import { cn } from "~/lib/utils"; 4 | 5 | /** Inter font from next/font/google, requested with the latin and latin-ext subsets. */ const inter = Inter({ 6 | subsets: ["latin", "latin-ext"], 7 | }); 8 | 9 | /** Layout component exported from app/app/layout.tsx. Declared async and takes a single prop, children (React.ReactNode), which is rendered below the heading text Residual Stream Analysis with Multi-Layer SAEs. NOTE(review): the JSX markup is not visible in this copy of the file; presumably the inter font and the cn helper are applied to the wrapping elements - confirm against the original source. */ export default async function Layout({ 10 | children, 11 | }: { 12 | children: React.ReactNode; 13 | }) { 14 | return ( 15 | 16 | 22 |
28 | Residual Stream Analysis with Multi-Layer SAEs 29 |
30 |
{children}
31 | 32 | 33 | ); 34 | } 35 | -------------------------------------------------------------------------------- /app/app/page.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { zodResolver } from "@hookform/resolvers/zod"; 4 | import { useRouter } from "next/navigation"; 5 | import { useForm } from "react-hook-form"; 6 | import { z } from "zod"; 7 | import { Button } from "~/components/ui/button"; 8 | import { 9 | Form, 10 | FormControl, 11 | FormField, 12 | FormItem, 13 | FormLabel, 14 | FormMessage, 15 | } from "~/components/ui/form"; 16 | import { Textarea } from "~/components/ui/textarea"; 17 | 18 | /** Zod schema for the prompt form: an object with one string field, prompt. */ const FormSchema = z.object({ 19 | prompt: z.string(), 20 | }); 21 | 22 | export default function InputForm() { 23 | const form = useForm>({ 24 | resolver: zodResolver(FormSchema), 25 | defaultValues: { 26 | prompt: "", 27 | }, 28 | }); 29 | 30 | const router = useRouter(); 31 | 32 | function onSubmit(data: z.infer) { 33 | router.push(`/prompt/${encodeURIComponent(data.prompt)}`); 34 | } 35 | 36 | return ( 37 |
38 | 42 | ( 46 | 47 | Prompt 48 | 49 |