├── .devcontainer
│   ├── Dockerfile
│   └── devcontainer.json
├── .gitattributes
├── .gitconfig
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug.md
│   │   ├── compatibility.md
│   │   ├── proposal.md
│   │   └── question.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── checks.yml
│       └── release.yml
├── .gitignore
├── .vscode
│   ├── cspell.json
│   ├── extensions.json
│   └── settings.json
├── LICENSE
├── Main_Demo.ipynb
├── README.md
├── assets
│   ├── rm_transformer_lens_logo.png
│   └── transformer_lens_logo.png
├── debugging
│   ├── comparing-to-huggingface.ipynb
│   └── hf-tl-logit-comparator.ipynb
├── demos
│   ├── ARENA_Content.ipynb
│   ├── Activation_Patching_in_TL_Demo.ipynb
│   ├── Attribution_Patching_Demo.ipynb
│   ├── BERT.ipynb
│   ├── Colab_Compatibility.ipynb
│   ├── Config_Overhaul.ipynb
│   ├── Exploratory_Analysis_Demo.ipynb
│   ├── Grokking_Demo.ipynb
│   ├── Head_Detector_Demo.ipynb
│   ├── Interactive_Neuroscope.ipynb
│   ├── LLaMA.ipynb
│   ├── LLaMA2_GPU_Quantized.ipynb
│   ├── LLaVA.ipynb
│   ├── Main_Demo.ipynb
│   ├── No_Position_Experiment.ipynb
│   ├── Othello_GPT.ipynb
│   ├── Patchscopes_Generation_Demo.ipynb
│   ├── Qwen.ipynb
│   ├── SVD_Interpreter_Demo.ipynb
│   ├── Santa_Coder.ipynb
│   ├── T5.ipynb
│   ├── Tracr_to_Transformer_Lens_Demo.ipynb
│   ├── conftest.py
│   ├── doc_sanitize.cfg
│   └── stable_lm.ipynb
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── make.bat
│   ├── make_docs.py
│   └── source
│       ├── _static
│       │   ├── TransformerLens_Diagram.svg
│       │   └── transformer_lens_logo.png
│       ├── apidoc_templates
│       │   ├── module.rst_t
│       │   ├── package.rst_t
│       │   └── toc.rst_t
│       ├── conf.py
│       ├── content
│       │   ├── citation.md
│       │   ├── contributing.md
│       │   ├── gallery.md
│       │   ├── getting_started.md
│       │   ├── getting_started_mech_interp.md
│       │   ├── news
│       │   │   └── release-2.0.md
│       │   ├── special_cases.md
│       │   └── tutorials.md
│       ├── favicon.ico
│       └── index.md
├── easy_transformer
│   └── __init__.py
├── further_comments.md
├── makefile
├── poetry.lock
├── pyproject.toml
├── tests
│   ├── acceptance
│   │   ├── test_activation_cache.py
│   │   ├── test_evals.py
│   │   ├── test_hook_tokens.py
│   │   ├── test_hooked_encoder.py
│   │   ├── test_hooked_encoder_decoder.py
│   │   ├── test_hooked_transformer.py
│   │   ├── test_multi_gpu.py
│   │   └── test_tokenizer_special_tokens.py
│   ├── integration
│   │   ├── test_attention_mask.py
│   │   ├── test_cache_hook_names.py
│   │   ├── test_cache_pos_slice.py
│   │   ├── test_create_hooked_encoder.py
│   │   ├── test_cross_entropy_loss.py
│   │   ├── test_d_vocab.py
│   │   ├── test_grouped_query_attention.py
│   │   ├── test_head_detector.py
│   │   ├── test_hooks.py
│   │   ├── test_kv_cache.py
│   │   ├── test_left_padding.py
│   │   ├── test_loading_from_pretrained.py
│   │   ├── test_match_huggingface.py
│   │   ├── test_only_tokenizer.py
│   │   ├── test_prepend_bos.py
│   │   ├── test_start_at_layer.py
│   │   ├── test_stop_at_layer.py
│   │   ├── test_tokenization_methods.py
│   │   └── test_utils_tokens.py
│   ├── manual_checks
│   │   ├── manual_checks_type_annotations.py
│   │   └── manual_checks_typing.py
│   └── unit
│       ├── components
│       │   ├── mlps
│       │   │   ├── test_can_be_used_as_mlp.py
│       │   │   ├── test_gated_mlp.py
│       │   │   ├── test_mlp.py
│       │   │   └── test_moe.py
│       │   ├── test_abstract_attention.py
│       │   └── test_attention.py
│       ├── factored_matrix
│       │   ├── test_constructor.py
│       │   ├── test_get_item.py
│       │   ├── test_multiply_by_factored_matrix.py
│       │   ├── test_multiply_by_matrix.py
│       │   ├── test_multiply_by_scalar.py
│       │   ├── test_multiply_by_vector.py
│       │   └── test_properties.py
│       ├── factories
│       │   ├── test_activation_function_factory.py
│       │   └── test_mlp_factory.py
│       ├── pretrained_weight_conversions
│       │   └── test_neo.py
│       ├── test_hook_points.py
│       ├── test_hooked_root_module.py
│       ├── test_hooked_transformer_config.py
│       ├── test_loading_from_pretrained_utilities.py
│       ├── test_make_docs.py
│       ├── test_next_sentence_prediction.py
│       ├── test_split_qkv.py
│       ├── test_svd_interpreter.py
│       ├── test_use_attn_result.py
│       ├── test_utils.py
│       └── utilities
│           └── test_devices.py
└── transformer_lens
    ├── ActivationCache.py
    ├── BertNextSentencePrediction.py
    ├── FactoredMatrix.py
    ├── HookedEncoder.py
    ├── HookedEncoderDecoder.py
    ├── HookedTransformer.py
    ├── HookedTransformerConfig.py
    ├── SVDInterpreter.py
    ├── __init__.py
    ├── components
    │   ├── __init__.py
    │   ├── abstract_attention.py
    │   ├── attention.py
    │   ├── bert_block.py
    │   ├── bert_embed.py
    │   ├── bert_mlm_head.py
    │   ├── bert_nsp_head.py
    │   ├── bert_pooler.py
    │   ├── embed.py
    │   ├── grouped_query_attention.py
    │   ├── layer_norm.py
    │   ├── layer_norm_pre.py
    │   ├── mlps
    │   │   ├── can_be_used_as_mlp.py
    │   │   ├── gated_mlp.py
    │   │   ├── gated_mlp_4bit.py
    │   │   ├── mlp.py
    │   │   └── moe.py
    │   ├── pos_embed.py
    │   ├── rms_norm.py
    │   ├── rms_norm_pre.py
    │   ├── t5_attention.py
    │   ├── t5_block.py
    │   ├── token_typed_embed.py
    │   ├── transformer_block.py
    │   └── unembed.py
    ├── evals.py
    ├── factories
    │   ├── activation_function_factory.py
    │   └── mlp_factory.py
    ├── head_detector.py
    ├── hook_points.py
    ├── loading_from_pretrained.py
    ├── past_key_value_caching.py
    ├── patching.py
    ├── pretrained
    │   ├── __init__.py
    │   └── weight_conversions
    │       ├── __init__.py
    │       ├── bert.py
    │       ├── bloom.py
    │       ├── coder.py
    │       ├── gemma.py
    │       ├── gpt2.py
    │       ├── gptj.py
    │       ├── llama.py
    │       ├── mingpt.py
    │       ├── mistral.py
    │       ├── mixtral.py
    │       ├── nanogpt.py
    │       ├── neel_solu_old.py
    │       ├── neo.py
    │       ├── neox.py
    │       ├── opt.py
    │       ├── phi.py
    │       ├── phi3.py
    │       ├── qwen.py
    │       ├── qwen2.py
    │       └── t5.py
    ├── train.py
    ├── utilities
    │   ├── __init__.py
    │   ├── activation_functions.py
    │   ├── addmm.py
    │   ├── attention.py
    │   └── devices.py
    └── utils.py
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | # If .venv is already set up with python3.8, it will use python3.8. To use 3.11, remove it first.
2 |
3 | # Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present)
4 | # https://hub.docker.com/r/nvidia/cuda
5 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9
6 |
7 | ARG USERNAME
8 | ARG USER_UID=1000
9 | ARG USER_GID=$USER_UID
10 |
11 | # Create the user
12 | # https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
13 | RUN groupadd --gid $USER_GID $USERNAME \
14 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
15 | && usermod -a -G video user \
16 | && apt-get update \
17 | && apt-get install -y sudo \
18 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
19 | && chmod 0440 /etc/sudoers.d/$USERNAME
20 |
21 | # Install dependencies
22 | RUN apt-get update && \
23 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
24 | software-properties-common && \
25 | add-apt-repository -y ppa:deadsnakes/ppa && \
26 | apt-get update && \
27 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
28 | build-essential \
29 | python3.11 \
30 | python3.11-dev \
31 | python3.11-distutils \
32 | python3.11-venv \
33 | curl \
34 | git && \
35 | # Update python3 default to point to python3.11
36 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \
37 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 && \
38 | update-alternatives --set python3 /usr/bin/python3.11
39 |
40 | # Use the new user
41 | USER $USERNAME
42 |
43 | # Install poetry
44 | RUN curl -sSL https://install.python-poetry.org | python3.11 -
45 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.238.1/containers/python-3
3 | {
4 | "name": "Python 3",
5 | "build": {
6 | "dockerfile": "Dockerfile",
7 | "args": {
8 | "USERNAME": "user"
9 | }
10 | },
11 | // Configure tool-specific properties.
12 | "customizations": {
13 | // Configure properties specific to VS Code.
14 | "vscode": {
15 | // Set *default* container specific settings.json values on container create.
16 | "settings": {},
17 | // Add the IDs of extensions you want installed when the container is created.
18 | "extensions": []
19 | }
20 | },
21 | "containerUser": "user",
22 | // Install any dependencies
23 | "postCreateCommand": "poetry config virtualenvs.in-project true && poetry install --with dev"
24 | }
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb merge=nbdev-merge
2 |
--------------------------------------------------------------------------------
/.gitconfig:
--------------------------------------------------------------------------------
1 | # Generated by nbdev_install_hooks
2 | #
3 | # If you need to disable this instrumentation do:
4 | # git config --local --unset include.path
5 | #
6 | # To restore:
7 | # git config --local include.path ../.gitconfig
8 | #
9 | [merge "nbdev-merge"]
10 | name = resolve conflicts with nbdev_fix
11 | driver = nbdev_merge %O %A %B %P
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Submit a bug report
4 | title: "[Bug Report] Bug title"
5 |
6 | ---
7 |
8 | If you are submitting a bug report, please fill in the following details and use the tag [bug].
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **Code example**
14 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
15 |
16 | **System Info**
17 | Describe the characteristics of your environment:
18 | * Describe how `transformer_lens` was installed (pip, docker, source, ...)
19 | * What OS are you using? (Linux, MacOS, Windows)
20 | * Python version (We support 3.7--3.10 currently)
21 |
22 | **Additional context**
23 | Add any other context about the problem here.
24 |
25 | ### Checklist
26 |
27 | - [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**)
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/compatibility.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Compatibility Report
3 | about: Submit a compatibility report
4 | title: "[Compatibility Report] Model ID"
5 |
6 | ---
7 |
8 |
15 |
16 | ## Model
17 |
18 | REPLACE_WITH_MODEL_ID
19 |
20 | - [ ] This model was incompatible when it was introduced to TransformerLens
21 |
22 |
25 |
26 | The model seems to have worked as of REPLACE_WITH_LAST_COMPATIBLE_VERSION_NUMBER. It first started
27 | showing signs of incompatibility in REPLACE_WITH_FIRST_INCOMPATIBLE_VERSION_NUMBER.
28 |
29 | ### Example of some generations in transformers
30 |
31 |
32 | ### Code used to load the model in TransformerLens
33 |
34 |
35 | ### Example of some generations in TransformerLens
36 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/proposal.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Proposal
3 | about: Propose changes that are not bug fixes
4 | title: "[Proposal] Proposal title"
5 | ---
6 |
7 | ### Proposal
8 |
9 | A clear and concise description of the proposal.
10 |
11 | ### Motivation
12 |
13 | Please outline the motivation for the proposal.
14 | Is your feature request related to a problem? e.g., "I'm always frustrated when [...]".
15 | If this is related to another GitHub issue, please link here too.
16 |
17 | ### Pitch
18 |
19 | A clear and concise description of what you want to happen. It is useful to relate proposals to existing features (or the lack thereof), as well as to relevant mechanistic interpretability techniques or model architectures.
20 |
21 | ### Alternatives
22 |
23 | A clear and concise description of any alternative solutions or features you've considered, if any.
24 |
25 | ### Additional context
26 |
27 | Add any other context or screenshots about the feature request here.
28 |
29 | ### Checklist
30 |
31 | - [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/Transformerlens/issues) in the repo (**required**)
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Ask a question
4 | title: "[Question] Question title"
5 | ---
6 |
7 |
8 | ### Question
9 |
10 | If you're a beginner and have basic questions, you can ask them on various online forums such as:
11 | - The [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-1qosyh8g3-9bF3gamhLNJiqCL_QqLFrA)
12 | - The [Eleuther AI discord](https://discord.gg/zBGx3azzUn)
13 | - The [Mechanistic Interpretability Discord](https://discord.gg/wcuV4xnJ)
14 |
15 | Advanced/nontrivial questions, especially in areas where documentation is lacking, are very much welcome here.
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
8 | # Description
9 |
10 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
11 |
12 | Fixes # (issue)
13 |
14 | ## Type of change
15 |
16 | Please delete options that are not relevant.
17 |
18 | - [ ] Bug fix (non-breaking change which fixes an issue)
19 | - [ ] New feature (non-breaking change which adds functionality)
20 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
21 | - [ ] This change requires a documentation update
22 |
23 | ### Screenshots
24 | Please attach before and after screenshots of the change if applicable.
25 |
26 |
36 |
37 | # Checklist:
38 |
39 | - [ ] I have commented my code, particularly in hard-to-understand areas
40 | - [ ] I have made corresponding changes to the documentation
41 | - [ ] My changes generate no new warnings
42 | - [ ] I have added tests that prove my fix is effective or that my feature works
43 | - [ ] New and existing unit tests pass locally with my changes
44 | - [ ] I have not rewritten tests relating to key interfaces which would affect backward compatibility
45 |
46 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | jobs:
9 | checks:
10 | name: Run checks workflow
11 | uses: TransformerLensOrg/TransformerLens/.github/workflows/checks.yml@main
12 | secrets:
13 | HF_TOKEN: ${{ secrets.HF_TOKEN }}
14 |
15 | semver-parser:
16 | name: Parse the semantic version from the release
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Parse semver string
20 | id: semver_parser
21 | uses: booxmedialtd/ws-action-parse-semver@v1.4.7
22 | with:
23 | input_string: ${{ github.event.release.tag_name }}
24 | outputs:
25 | major: "${{ steps.semver_parser.outputs.major }}"
26 | minor: "${{ steps.semver_parser.outputs.minor }}"
27 | patch: "${{ steps.semver_parser.outputs.patch }}"
28 | semver: "${{ steps.semver_parser.outputs.fullversion }}"
29 |
30 | release-python:
31 | name: Release Python package to PyPi
32 | needs:
33 | - checks
34 | - semver-parser
35 | runs-on: ubuntu-latest
36 | steps:
37 | - uses: actions/checkout@v3
38 | - name: Install Poetry
39 | uses: snok/install-poetry@v1
40 | - name: Set up Python
41 | uses: actions/setup-python@v4
42 | with:
43 | python-version: '3.11'
44 | cache: 'poetry'
45 | - name: Poetry config
46 | run: poetry self add 'poethepoet[poetry_plugin]'
47 | - name: Install dependencies
48 | run: poetry install --with dev
49 | - name: Set the version
50 | run: poetry version ${{needs.semver-parser.outputs.semver}}
51 | - name: Build
52 | run: poetry build
53 | - name: Publish
54 | run: poetry publish
55 | env:
56 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN_PYPI }}
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | Testing_Notebook.ipynb
3 | transformer_lens/scratch.py
4 | .idea/
5 | wandb*
6 | transformer_lens\.egg*
7 | MANIFEST.in
8 | settings.ini
9 | _proc
10 | core.py
11 | nbs
12 | _modidx.py
13 | .ipynb_checkpoints
14 | env
15 | dist/
16 | docs/build
17 | .coverage*
18 | .DS_Store
19 | .pylintrc
20 | docs/source/generated
21 | **.orig
22 | .venv
23 |
--------------------------------------------------------------------------------
/.vscode/cspell.json:
--------------------------------------------------------------------------------
1 | {
2 | "language": "en,en-GB",
3 | "words": [
4 | "accum",
5 | "adrià",
6 | "aengus",
7 | "allclose",
8 | "alonso",
9 | "arange",
10 | "argmax",
11 | "argmaxy",
12 | "autodiff",
13 | "autoregressive",
14 | "barez",
15 | "beartype",
16 | "belrose",
17 | "bertsimas",
18 | "biderman",
19 | "bilal",
20 | "bincount",
21 | "caxis",
22 | "checkpointed",
23 | "chughtai",
24 | "circuitsvis",
25 | "Codeparrot",
26 | "codespaces",
27 | "colab",
28 | "collectstart",
29 | "colour",
30 | "conmy",
31 | "cooney",
32 | "crfm",
33 | "cumsum",
34 | "datapoint",
35 | "dictmodel",
36 | "dimitris",
37 | "disconfirm",
38 | "dmitrii",
39 | "docstrings",
40 | "doctest",
41 | "doctree",
42 | "dtype",
43 | "dtypes",
44 | "einops",
45 | "elhage",
46 | "endoftext",
47 | "eqnarray",
48 | "esben",
49 | "evals",
50 | "explictly",
51 | "fazl",
52 | "firstpage",
53 | "fspath",
54 | "furo",
55 | "garriga",
56 | "gelu",
57 | "githubpages",
58 | "gptj",
59 | "halawi",
60 | "heimersheim",
61 | "howpublished",
62 | "huggingface",
63 | "icml",
64 | "idxs",
65 | "imshow",
66 | "interp",
67 | "interpretability",
68 | "ioannis",
69 | "ipynb",
70 | "isin",
71 | "isort",
72 | "janiak",
73 | "Janky",
74 | "jaxtyping",
75 | "jett",
76 | "kaiming",
77 | "keepdim",
78 | "kissane",
79 | "konstas",
80 | "kran",
81 | "lastpage",
82 | "layernorm",
83 | "ldim",
84 | "lieberum",
85 | "logits",
86 | "logsumexp",
87 | "mavor",
88 | "maxdepth",
89 | "mingpt",
90 | "nanda",
91 | "ndarray",
92 | "ndim",
93 | "neel",
94 | "neox",
95 | "nitpicky",
96 | "occurences",
97 | "olah",
98 | "openwebtext",
99 | "overcomplete",
100 | "Overriden",
101 | "pagename",
102 | "pauly",
103 | "pretrained",
104 | "probs",
105 | "producting",
106 | "pycln",
107 | "pypi",
108 | "pytest",
109 | "randn",
110 | "rdim",
111 | "relu",
112 | "resid",
113 | "rprint",
114 | "rtml",
115 | "rtol",
116 | "shortformer",
117 | "softmax",
118 | "softmaxing",
119 | "solu",
120 | "stas",
121 | "templatedir",
122 | "templatename",
123 | "toctree",
124 | "topk",
125 | "tqdm",
126 | "transformerlens",
127 | "tril",
128 | "triu",
129 | "troitskii",
130 | "unembed",
131 | "unembedded",
132 | "unembedding",
133 | "unigram",
134 | "unsqueeze",
135 | "virtualenvs",
136 | "visualisation",
137 | "xaxis",
138 | "yaxis"
139 | ]
140 | }
141 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "christian-kohler.path-intellisense",
4 | "davidanson.vscode-markdownlint",
5 | "donjayamanne.githistory",
6 | "donjayamanne.python-extension-pack",
7 | "github.copilot",
8 | "github.vscode-pull-request-github",
9 | "ionutvmi.path-autocomplete",
10 | "mikoz.autoflake-extension",
11 | "ms-python.isort",
12 | "ms-python.pylint",
13 | "ms-python.python",
14 | "ms-python.vscode-pylance",
15 | "ms-toolsai.jupyter-keymap",
16 | "ms-toolsai.jupyter-renderers",
17 | "ms-toolsai.jupyter",
18 | "richie5um2.vscode-sort-json",
19 | "stkb.rewrap",
20 | "streetsidesoftware.code-spell-checker-british-english",
21 | "streetsidesoftware.code-spell-checker",
22 | "tamasfe.even-better-toml",
23 | "yzhang.markdown-all-in-one"
24 | ]
25 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[python]": {
3 | "editor.defaultFormatter": "ms-python.black-formatter"
4 | },
5 | "[toml]": {
6 | "editor.defaultFormatter": "tamasfe.even-better-toml"
7 | },
8 | "editor.codeActionsOnSave": {
9 | "source.organizeImports": "explicit"
10 | },
11 | "editor.formatOnSave": true,
12 | "editor.rulers": [100],
13 | "evenBetterToml.formatter.allowedBlankLines": 1,
14 | "evenBetterToml.formatter.arrayAutoCollapse": true,
15 | "evenBetterToml.formatter.arrayAutoExpand": true,
16 | "evenBetterToml.formatter.arrayTrailingComma": true,
17 | "evenBetterToml.formatter.columnWidth": 100,
18 | "evenBetterToml.formatter.compactArrays": true,
19 | "evenBetterToml.formatter.compactEntries": true,
20 | "evenBetterToml.formatter.compactInlineTables": true,
21 | "evenBetterToml.formatter.indentEntries": true,
22 | "evenBetterToml.formatter.indentString": " ",
23 | "evenBetterToml.formatter.indentTables": true,
24 | "evenBetterToml.formatter.inlineTableExpand": true,
25 | "evenBetterToml.formatter.reorderArrays": true,
26 | "evenBetterToml.formatter.reorderKeys": true,
27 | "evenBetterToml.formatter.trailingNewline": true,
28 | "evenBetterToml.schema.enabled": true,
29 | "evenBetterToml.schema.links": true,
30 | "evenBetterToml.syntax.semanticTokens": false,
31 | "mypy-type-checker.importStrategy": "fromEnvironment",
32 | "notebook.formatOnCellExecution": true,
33 | "notebook.formatOnSave.enabled": true,
34 | "pylint.importStrategy": "fromEnvironment",
35 | "python.testing.pytestArgs": [
36 | "transformer_lens",
37 | ],
38 | "python.testing.pytestEnabled": true,
39 | "rewrap.autoWrap.enabled": true,
40 | "rewrap.reformat": true,
41 | "rewrap.wrappingColumn": 100,
42 | "mypy.runUsingActiveInterpreter": true,
43 | "editor.defaultFormatter": "ms-python.black-formatter",
44 | "black-formatter.args": [
45 | "-l 100"
46 | ],
47 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 TransformerLensOrg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Main_Demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Notice\n",
9 | "\n",
10 | "All demos have been moved to the `/demos` directory in the root of the project [View on GitHub](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos)\n"
11 | ]
12 | }
13 | ],
14 | "metadata": {
15 | "kernelspec": {
16 | "display_name": ".venv",
17 | "language": "python",
18 | "name": "python3"
19 | },
20 | "language_info": {
21 | "name": "python",
22 | "version": "3.8.10"
23 | },
24 | "orig_nbformat": 4,
25 | "vscode": {
26 | "interpreter": {
27 | "hash": "b8da8bee8e62267e58dea1070cf9758944e5c526027e75fbcceb99a4c665691a"
28 | }
29 | }
30 | },
31 | "nbformat": 4,
32 | "nbformat_minor": 2
33 | }
34 |
--------------------------------------------------------------------------------
/assets/rm_transformer_lens_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/assets/rm_transformer_lens_logo.png
--------------------------------------------------------------------------------
/assets/transformer_lens_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/assets/transformer_lens_logo.png
--------------------------------------------------------------------------------
/demos/conftest.py:
--------------------------------------------------------------------------------
1 | def pytest_collectstart(collector):
2 | """Ignore several mimetypes when comparing notebooks."""
3 | if collector.fspath and collector.fspath.ext == ".ipynb":
4 | collector.skip_compare += (
5 | "text/html",
6 | "application/javascript",
7 | "application/vnd.plotly.v1+json", # Plotly
8 | )
9 |
--------------------------------------------------------------------------------
/demos/doc_sanitize.cfg:
--------------------------------------------------------------------------------
1 | [regex1]
2 | regex: \d{1,2}/\d{1,2}/\d{2,4}
3 | replace: DATE-STAMP
4 |
5 | [regex2]
6 | regex: \d{2}:\d{2}:\d{2}
7 | replace: TIME-STAMP
8 |
9 | [regex3]
10 | regex: 0[xX][0-9a-fA-F]+
11 | replace: HEX-CODE
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Transformer-Lens Docs
3 |
4 |
5 | This repo contains the [website](https://TransformerLensOrg.github.io/TransformerLens/) for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens). This site is currently in Beta and we are in the process of adding/editing information.
6 |
7 | The documentation uses Sphinx. However, the documentation is written in regular Markdown (md), NOT reStructuredText (rst).
8 |
9 | ## Build the Documentation
10 |
11 | First install the docs packages:
12 |
13 | ```bash
14 | poetry install --with docs
15 | ```
16 |
17 | Then for hot-reloading, run this (note the model properties table won't hot reload, but everything
18 | else will):
19 |
20 | ```bash
21 | poetry run docs-hot-reload
22 | ```
23 |
24 | Alternatively to build once, run:
25 |
26 | ```bash
27 | poetry run build-docs
28 | ```
29 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/transformer_lens_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/docs/source/_static/transformer_lens_logo.png
--------------------------------------------------------------------------------
/docs/source/apidoc_templates/module.rst_t:
--------------------------------------------------------------------------------
1 | {%- if show_headings %}
2 | {{- [basename] | join(' ') | e | heading }}
3 |
4 | {% endif -%}
5 | .. automodule:: {{ qualname }}
6 | {%- for option in automodule_options %}
7 | :{{ option }}:
8 | {%- endfor %}
--------------------------------------------------------------------------------
/docs/source/apidoc_templates/package.rst_t:
--------------------------------------------------------------------------------
1 | {%- macro automodule(modname, options) -%}
2 | .. automodule:: {{ modname }}
3 | {%- for option in options %}
4 | :{{ option }}:
5 | {%- endfor %}
6 | {%- endmacro %}
7 |
8 | {%- macro toctree(docnames) -%}
9 | .. toctree::
10 | :maxdepth: 1
11 | {% for docname in docnames %}
12 | {{ docname }}
13 | {%- endfor %}
14 | {%- endmacro %}
15 |
16 | {%- if is_namespace %}
17 | {{- [pkgname] | join(" ") | e | heading }}
18 | {% else %}
19 | {{- [pkgname] | join(" ") | e | heading }}
20 | {% endif %}
21 |
22 | {%- if is_namespace %}
23 | .. py:module:: {{ pkgname }}
24 | {% endif %}
25 |
26 | {%- if modulefirst and not is_namespace %}
27 | {{ automodule(pkgname, automodule_options) }}
28 | {% endif %}
29 |
30 | {%- if submodules %}
31 | Submodules
32 | ----------
33 | {% if separatemodules %}
34 | {{ toctree(submodules) }}
35 | {% else %}
36 | {%- for submodule in submodules %}
37 | {% if show_headings %}
38 | {{- [submodule, "module"] | join(" ") | e | heading(2) }}
39 | {% endif %}
40 | {{ automodule(submodule, automodule_options) }}
41 | {% endfor %}
42 | {%- endif %}
43 | {%- endif %}
44 |
45 | {%- if subpackages %}
46 | Subpackages
47 | -----------
48 |
49 | {{ toctree(subpackages) }}
50 | {% endif %}
51 |
--------------------------------------------------------------------------------
/docs/source/apidoc_templates/toc.rst_t:
--------------------------------------------------------------------------------
1 | Transformer Lens API
2 | --------------------
3 |
4 | If browsing the docs for the first time, we recommend initially looking at :doc:`HookedTransformer
5 | ` and :doc:`ActivationCache
6 | ` modules.
7 |
8 | Contents
9 | ^^^^^^^^
10 |
11 | .. toctree::
12 | :maxdepth: 3
13 | {% for docname in docnames %}
14 | {{ docname }}
15 | {%- endfor %}
--------------------------------------------------------------------------------
/docs/source/content/citation.md:
--------------------------------------------------------------------------------
1 |
2 | # Citation
3 |
4 | Please cite this library as:
5 |
6 | ```BibTeX
7 | @misc{nanda2022transformerlens,
8 | title = {TransformerLens},
9 | author = {Neel Nanda and Joseph Bloom},
10 | year = {2022},
11 | howpublished = {\url{https://github.com/TransformerLensOrg/TransformerLens}},
12 | }
13 | ```
14 |
--------------------------------------------------------------------------------
/docs/source/content/gallery.md:
--------------------------------------------------------------------------------
1 | # Gallery
2 |
3 | Research done involving TransformerLens:
4 |
5 | - [Progress Measures for Grokking via Mechanistic
6 | Interpretability](https://arxiv.org/abs/2301.05217) (ICLR Spotlight, 2023) by Neel Nanda, Lawrence
7 | Chan, Tom Lieberum, Jess Smith, Jacob Steinhardt
8 | - [Finding Neurons in a Haystack: Case Studies with Sparse
9 | Probing](https://arxiv.org/abs/2305.01610) by Wes Gurnee, Neel Nanda, Matthew Pauly, Katherine
10 | Harvey, Dmitrii Troitskii, Dimitris Bertsimas
11 | - [Towards Automated Circuit Discovery for Mechanistic
12 | Interpretability](https://arxiv.org/abs/2304.14997) by Arthur Conmy, Augustine N. Mavor-Parker,
13 | Aengus Lynch, Stefan Heimersheim, Adrià Garriga-Alonso
14 | - [Actually, Othello-GPT Has A Linear Emergent World Representation](https://neelnanda.io/othello)
15 | by Neel Nanda
16 | - [A circuit for Python docstrings in a 4-layer attention-only
17 | transformer](https://www.alignmentforum.org/posts/u6KXXmKFbXfWzoAXn/a-circuit-for-python-docstrings-in-a-4-layer-attention-only)
18 | by Stefan Heimersheim and Jett Janiak
19 | - [A Toy Model of Universality](https://arxiv.org/abs/2302.03025) (ICML, 2023) by Bilal Chughtai,
20 | Lawrence Chan, Neel Nanda
21 | - [N2G: A Scalable Approach for Quantifying Interpretable Neuron Representations in Large Language
22 | Models](https://openreview.net/forum?id=ZB6bK6MTYq) (2023, ICLR Workshop RTML) by Alex Foote, Neel
23 | Nanda, Esben Kran, Ioannis Konstas, Fazl Barez
24 | - [Eliciting Latent Predictions from Transformers with the Tuned
25 | Lens](https://arxiv.org/abs/2303.08112) by Nora Belrose, Zach Furman, Logan Smith, Danny Halawi,
26 | Igor Ostrovsky, Lev McKinney, Stella Biderman, Jacob Steinhardt
27 |
28 | User contributed examples of the library being used in action:
29 |
30 | - [Induction Heads Phase Change
31 | Replication](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb):
32 | A partial replication of [In-Context Learning and Induction
33 | Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)
34 | from Connor Kissane
35 | - [Decision Transformer
36 | Interpretability](https://github.com/jbloomAus/DecisionTransformerInterpretability): A set of
37 | scripts for training decision transformers which uses transformer lens to view intermediate
38 | activations, perform attribution and ablations. A write up of the initial work can be found
39 | [here](https://www.lesswrong.com/posts/bBuBDJBYHt39Q5zZy/decision-transformer-interpretability).
--------------------------------------------------------------------------------
/docs/source/content/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | **Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works, and the basic features**.
4 |
5 | To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!
6 |
7 | Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **Check out my [list of concrete open problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit) to figure out where to start.** It begins with advice on skilling up, and key resources to check out.
8 |
9 | If you're new to transformers, check out my [what is a transformer tutorial](https://neelnanda.io/transformer-tutorial) and [tutorial on coding GPT-2 from scratch](https://neelnanda.io/transformer-tutorial-2) (with [an accompanying template](https://neelnanda.io/transformer-template) to write one yourself)!
10 |
11 | ## Advice for Reading the Code
12 |
13 | One significant design decision made was to have a single transformer implementation that could support a range of subtly different GPT-style models. This has the upside of interpretability code just working for arbitrary models when you change the model name in `HookedTransformer.from_pretrained`! But it has the significant downside that the code implementing the model (in `HookedTransformer.py` and `components.py`) can be difficult to read. I recommend starting with my [Clean Transformer Demo](https://neelnanda.io/transformer-solution), which is a clean, minimal implementation of GPT-2 with the same internal architecture and activation names as HookedTransformer, but is significantly clearer and better documented.
14 |
15 | ## Installation
16 |
17 | `pip install git+https://github.com/TransformerLensOrg/TransformerLens`
18 |
19 | Import the library with `import transformer_lens`
20 |
21 | (Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/TransformerLensOrg/TransformerLens@v1`.)
22 |
23 | ## Huggingface Gated Access
24 |
25 | Some of the models available in TransformerLens require gated access to be used. Luckily TransformerLens provides a way to access those models via the configuration of an environmental variable. Simply configure your access token found [here](https://huggingface.co/settings/tokens) as `HF_TOKEN` in your environment.
26 |
27 | You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. If you attempt to use one of these models before you have accepted any related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. As of 23/4/24, the current list of gated models supported by TransformerLens is as follows.
28 |
29 | * https://huggingface.co/mistralai/Mixtral-8x7B-v0.1
30 | * https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
31 | * https://huggingface.co/mistralai/Mistral-7B-v0.1
32 |
--------------------------------------------------------------------------------
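As a concrete companion to the getting_started.md page above, here is a minimal sketch of the loading-and-caching workflow it describes. It is illustrative only: the model name `gpt2-small` and the `from_pretrained` / `run_with_cache` calls mirror the usage in this repo's own tests (e.g. `tests/acceptance/test_evals.py` and `tests/integration/test_cache_hook_names.py`), and for gated models you would first set `HF_TOKEN` as described in the page.

```python
# Minimal sketch of the getting-started workflow (illustrative, not part of the repo).
from transformer_lens import HookedTransformer

# Load a pretrained model; change the name to swap architectures.
model = HookedTransformer.from_pretrained("gpt2-small")

# Tokenize a prompt and run the model, caching every internal activation.
tokens = model.to_tokens("Hello World!")
logits, cache = model.run_with_cache(tokens)

# Activations are keyed by hook name, e.g. the first block's attention pattern.
print(logits.shape)
print(cache["blocks.0.attn.hook_pattern"].shape)
```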
/docs/source/content/getting_started_mech_interp.md:
--------------------------------------------------------------------------------
1 | # Getting Started in Mechanistic Interpretability
2 |
3 | Mechanistic interpretability is a very young and small field, and there are a _lot_ of open
4 | problems. This means there's both a lot of low-hanging fruit, and that the bar for entry is low - if
5 | you would like to help, please try working on one! The standard answer to "why has no one done this
6 | yet" is just that there aren't enough people! Key resources:
7 |
8 | - [A Guide to Getting Started in Mechanistic Interpretability](https://neelnanda.io/getting-started)
9 | - [ARENA Mechanistic Interpretability Tutorials](https://arena-ch1-transformers.streamlit.app/) from
10 | Callum McDougall. A comprehensive practical introduction to mech interp, written in
11 | TransformerLens - full of snippets to copy and they come with exercises and solutions! Notable
12 | tutorials:
13 | - [Coding GPT-2 from
14 | scratch](https://arena-ch1-transformers.streamlit.app/[1.1]_Transformer_from_Scratch), with
15 | accompanying video tutorial from me ([1](https://neelnanda.io/transformer-tutorial)
16 | [2](https://neelnanda.io/transformer-tutorial-2)) - a good introduction to transformers
17 | - [Introduction to Mech Interp and
18 | TransformerLens](https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp): An
19 | introduction to TransformerLens and mech interp via studying induction heads. Covers the
20 | foundational concepts of the library
21 | - [Indirect Object
22 | Identification](https://arena-ch1-transformers.streamlit.app/[1.3]_Indirect_Object_Identification):
23 | a replication of interpretability in the wild, that covers standard techniques in mech interp
24 | such as [direct logit
25 | attribution](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=disz2gTx-jooAcR0a5r8e7LZ),
26 | [activation patching and path
27 | patching](https://www.lesswrong.com/posts/xh85KbTFhbCz7taD4/how-to-think-about-activation-patching)
28 | - [Mech Interp Paper Reading List](https://neelnanda.io/paper-list)
29 | - [200 Concrete Open Problems in Mechanistic
30 | Interpretability](https://neelnanda.io/concrete-open-problems)
31 | - [A Comprehensive Mechanistic Interpretability Explainer](https://neelnanda.io/glossary): To look
32 | up all the jargon and unfamiliar terms you're going to come across!
33 | - [Neel Nanda's Youtube channel](https://www.youtube.com/channel/UCBMJ0D-omcRay8dh4QT0doQ): A range
34 | of mech interp video content, including [paper
35 | walkthroughs](https://www.youtube.com/watch?v=KV5gbOmHbjU&list=PL7m7hLIqA0hpsJYYhlt1WbHHgdfRLM2eY&index=1),
36 | and [walkthroughs of doing
37 | research](https://www.youtube.com/watch?v=yo4QvDn-vsU&list=PL7m7hLIqA0hr4dVOgjNwP2zjQGVHKeB7T)
--------------------------------------------------------------------------------
/docs/source/content/special_cases.md:
--------------------------------------------------------------------------------
1 | # Special Cases
2 |
3 | ## Mixture of Experts error rates
4 | Due to the Top-K gating performed in the hidden layer of Mixture of Experts models, small errors can be amplified
5 | greatly in cases where a different expert is selected, which leads to a higher than normal variance in the error rate
6 | of the final logits. In testing done on Mixtral running in half precision, the standard deviation of the absolute error
7 | rate of the logits compared to those from the default model was found to be around 2e-3.
8 |
9 | There are two main ways to mitigate this:
10 | 1. Disable preprocessing options by using `HookedTransformer.from_pretrained_no_processing` instead of `HookedTransformer.from_pretrained`
11 | 2. Increase the precision of the data type used in the model
12 |
--------------------------------------------------------------------------------
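A minimal sketch of the two mitigations listed in special_cases.md above. `HookedTransformer.from_pretrained_no_processing` is named in the page itself; the Mixtral model ID and the `dtype` keyword argument are assumptions added for illustration, so check the loading signature of your installed version before relying on them.

```python
# Sketch of the Mixture-of-Experts mitigations described above (illustrative only).
import torch

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained_no_processing(  # 1. skip weight preprocessing
    "mistralai/Mixtral-8x7B-v0.1",  # gated model; requires HF_TOKEN and an accepted agreement
    dtype=torch.float32,  # 2. use higher precision than half precision (assumed kwarg)
)
```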
/docs/source/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/docs/source/favicon.ico
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide-toc: true
3 | firstpage:
4 | lastpage:
5 | ---
6 |
7 | # TransformerLens
8 |
9 | (Formerly known as EasyTransformer) [](https://pypi.org/project/transformer-lens/)
10 |
11 | ## A Library for Mechanistic Interpretability of Generative Language Models
12 |
13 | This is a library for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models. The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this!
14 |
15 | TransformerLens lets you load in an open source language model, like GPT-2, and exposes the internal activations of the model to you. You can cache any internal activation in the model, and add in functions to edit, remove or replace these activations as the model runs. The core design principle I've followed is to enable exploratory analysis. One of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state. Part of what I aimed for is to make *my* experience of doing research easier and more fun, hopefully this transfers to you!
16 |
17 | I used to work for the [Anthropic interpretability team](https://transformer-circuits.pub/), and I wrote this library because after I left and tried doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! One of the great things about mechanistic interpretability is that you don't need large models or tons of compute. There are lots of important open problems that can be solved with a small model in a Colab notebook!
18 |
19 | The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research!
20 |
21 | A great place to start is to take a look at a helpful diagram of [all weight matrices and activation tensors with TransformerLens notation](_static/TransformerLens_Diagram.svg) courtesy of [Austin Kozlowski](https://github.com/akozlo). Another helpful tool to get you going as quickly as possible is our [Colab Compatibility Demo](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos/Colab_Compatibility.ipynb), which will give you a good idea of what you can do in various Colab environments.
22 |
23 | ```{toctree}
24 | :hidden:
25 | :caption: Introduction
26 |
27 | content/getting_started
28 | content/getting_started_mech_interp
29 | content/gallery
30 | ```
31 |
32 | ```{toctree}
33 | :hidden:
34 | :caption: Documentation
35 |
36 | generated/code/modules
37 | generated/model_properties_table.md
38 | ```
39 |
40 | ```{toctree}
41 | :hidden:
42 | :caption: Resources
43 |
44 | content/tutorials
45 | content/citation
46 | content/contributing
47 | generated/demos/Main_Demo
48 | generated/demos/Exploratory_Analysis_Demo
49 | content/special_cases
50 | ```
51 |
52 | ```{toctree}
53 | :hidden:
54 | :caption: News
55 |
56 | content/news/release-2.0
57 | ```
58 |
59 | ```{toctree}
60 | :hidden:
61 | :caption: Development
62 |
63 | content/contributing
64 | Code Coverage
65 | Github
66 | ```
67 |
--------------------------------------------------------------------------------
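To make the "cache, edit, remove or replace activations" idea in index.md above concrete, here is a minimal hook sketch. The `run_with_hooks` / `utils.get_act_name` usage mirrors `tests/integration/test_attention_mask.py` further down in this dump; the specific intervention (zero-ablating head 0 of layer 0) is purely an illustrative assumption.

```python
# Minimal sketch: edit an internal activation with a forward hook (illustrative only).
from transformer_lens import HookedTransformer, utils

model = HookedTransformer.from_pretrained("gpt2-small")
tokens = model.to_tokens("Hello World!")


def ablate_head_0(pattern, hook):
    # pattern: [batch, head_index, query_pos, key_pos]; zero out head 0's attention.
    pattern[:, 0, :, :] = 0.0
    return pattern


# Run the model with the hook attached to layer 0's attention pattern.
logits = model.run_with_hooks(
    tokens,
    fwd_hooks=[(utils.get_act_name("pattern", 0), ablate_head_0)],
)
```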
/easy_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.warning("DEPRECATED: Library has been renamed, import transformer_lens instead")
4 | from transformer_lens import *
5 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | format:
2 | poetry run pycln --all . --exclude "__init__.py"
3 | poetry run isort format .
4 | poetry run black .
5 |
6 | check-format:
7 | poetry run pycln --check --all . --exclude "__init__.py"
8 | poetry run isort --check-only .
9 | poetry run black --check .
10 |
11 | unit-test:
12 | poetry run pytest tests/unit
13 |
14 | integration-test:
15 | poetry run pytest tests/integration
16 |
17 | acceptance-test:
18 | poetry run pytest tests/acceptance
19 |
20 | coverage-report-test:
21 | poetry run pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/unit tests/integration tests/acceptance
22 |
23 | docstring-test:
24 | poetry run pytest transformer_lens/
25 |
26 | notebook-test:
27 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb
28 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb
29 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb
30 |
31 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb
32 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb
33 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb
34 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb
35 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Othello_GPT.ipynb
36 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Qwen.ipynb
37 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Santa_Coder.ipynb
38 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Stable_Lm.ipynb
39 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb
40 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb
41 |
42 | # Contains failing cells
43 |
44 | # Causes CI to hang
45 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb
46 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb
47 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb
48 |
49 | test:
50 | make unit-test
51 | make acceptance-test
52 | make docstring-test
53 | make notebook-test
54 |
55 | docs-hot-reload:
56 | poetry run docs-hot-reload
57 |
58 | build-docs:
59 | poetry run build-docs
60 |
--------------------------------------------------------------------------------
/tests/acceptance/test_evals.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from transformer_lens.evals import IOIDataset, ioi_eval
4 | from transformer_lens.HookedTransformer import HookedTransformer
5 |
6 |
7 | @pytest.fixture(scope="module")
8 | def model():
9 | return HookedTransformer.from_pretrained("gpt2-small")
10 |
11 |
12 | def test_basic_ioi_eval(model):
13 | """
14 | Test IOI evaluation with default dataset and settings.
15 | """
16 | results = ioi_eval(model, num_samples=100)
17 | assert results["Accuracy"] >= 0.99
18 |
19 |
20 | def test_symmetric_samples(model):
21 | """
22 | Test IOI evaluation with symmetric=True so prompts are in symmetric pairs.
23 | """
24 | ds = IOIDataset(tokenizer=model.tokenizer, num_samples=100, symmetric=True)
25 | results = ioi_eval(model, dataset=ds)
26 | assert results["Logit Difference"] > 2.0
27 | assert results["Accuracy"] > 0.9
28 |
29 |
30 | def test_custom_dataset_ioi_eval(model):
31 | """
32 | Test IOI eval with custom dataset using different templates, names, and objects.
33 | """
34 | ds = IOIDataset(
35 | tokenizer=model.tokenizer,
36 | num_samples=100,
37 | templates=["[A] met with [B]. [B] gave the [OBJECT] to [A]"],
38 | names=["Alice", "Bob", "Charlie"],
39 | nouns={"OBJECT": ["ball", "book"]},
40 | )
41 | results = ioi_eval(model, dataset=ds)
42 | assert results["Logit Difference"] > 2.0
43 | assert results["Accuracy"] >= 0.99
44 |
45 |
46 | def test_multitoken_names_ioi_eval(model):
47 | """
48 | Test the IOI evaluation with multi-token names in the dataset.
49 | """
50 | ds = IOIDataset(
51 | tokenizer=model.tokenizer,
52 | num_samples=100,
53 | names=["John Smith", "John Doe"],
54 | )
55 | results = ioi_eval(model, dataset=ds)
56 | assert results["Logit Difference"] > 2.0
57 | assert results["Accuracy"] >= 0.99
58 |
59 |
60 | def test_inverted_template(model):
61 | """
62 | Test IOI eval with an unnatural template (BAAA).
63 | This should result in a negative logit difference and very low accuracy.
64 | """
65 | ds = IOIDataset(
66 | tokenizer=model.tokenizer,
67 | num_samples=100,
68 | templates=["[B] met with [A]. [A] said hello to [A]"],
69 | )
70 | results = ioi_eval(model, dataset=ds)
71 | assert results["Logit Difference"] < -2.0
72 | assert results["Accuracy"] <= 0.01
73 |
--------------------------------------------------------------------------------
/tests/acceptance/test_hook_tokens.py:
--------------------------------------------------------------------------------
1 | # %%
2 |
3 | import functools
4 |
5 | import torch as t
6 | from jaxtyping import Int
7 |
8 | from transformer_lens import HookedTransformer, HookedTransformerConfig
9 | from transformer_lens.hook_points import HookPoint
10 |
11 |
12 | def test_patch_tokens():
13 | # Define small transformer
14 | cfg = HookedTransformerConfig(
15 | n_layers=1,
16 | d_mlp=10,
17 | d_model=10,
18 | d_head=5,
19 | n_heads=2,
20 | n_ctx=20,
21 | act_fn="relu",
22 | tokenizer_name="gpt2",
23 | use_hook_tokens=True,
24 | )
25 | model = HookedTransformer(cfg=cfg)
26 |
27 | # Define short prompt, and a token to replace the first token with (note this is index 1, because BOS)
28 | prompt = "Hello World!"
29 | modified_prompt = "Hi World!"
30 | new_first_token = model.to_single_token("Hi")
31 |
32 | # Define hook function to alter the first token
33 | def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int):
34 | assert (
35 | tokens[0, 0].item() != new_first_token
36 | ) # Need new_first_token to be different from original
37 | tokens[0, 0] = new_first_token
38 | return tokens
39 |
40 | # Run with hooks
41 | out_from_hook = model.run_with_hooks(
42 | prompt,
43 | prepend_bos=False,
44 | fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))],
45 | )
46 |
47 | out_direct = model(modified_prompt, prepend_bos=False)
48 |
49 | t.testing.assert_close(out_from_hook, out_direct)
50 |
--------------------------------------------------------------------------------
/tests/acceptance/test_tokenizer_special_tokens.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | import transformer_lens.loading_from_pretrained as loading
4 | from transformer_lens import HookedTransformer, HookedTransformerConfig
5 |
6 | # Gets tedious typing these out every time I want to sweep over all the distinct small models
7 | MODEL_TESTING_LIST = [
8 | "solu-1l",
9 | "gpt2-small",
10 | "gpt-neo-125M",
11 | "opt-125m",
12 | "opt-30b",
13 | "stanford-gpt2-small-a",
14 | "pythia-70m",
15 | ]
16 |
17 |
18 | def test_d_vocab_from_tokenizer():
19 | cfg = HookedTransformerConfig(
20 | n_layers=1, d_mlp=10, d_model=10, d_head=5, n_heads=2, n_ctx=20, act_fn="relu"
21 | )
22 | test_string = "a fish."
23 | # Test tokenizers for different models
24 | for model_name in MODEL_TESTING_LIST:
25 | if model_name == "solu-1l":
26 | tokenizer_name = "NeelNanda/gpt-neox-tokenizer-digits"
27 | else:
28 | tokenizer_name = loading.get_official_model_name(model_name)
29 |
30 | model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name))
31 |
32 | tokens_with_bos = model.to_tokens(test_string)
33 | tokens_without_bos = model.to_tokens(test_string, prepend_bos=False)
34 |
35 | # Check that the lengths are different by one
36 | assert (
37 | tokens_with_bos.shape[-1] == tokens_without_bos.shape[-1] + 1
38 | ), "BOS Token not added when expected"
39 | # Check that we don't have BOS when we disable the flag
40 | assert (
41 | tokens_without_bos.squeeze()[0] != model.tokenizer.bos_token_id
42 | ), "BOS token is present when it shouldn't be"
43 |
--------------------------------------------------------------------------------
/tests/integration/test_attention_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens import utils
4 | from transformer_lens.HookedTransformer import HookedTransformer
5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
6 |
7 |
8 | def test_attention_mask():
9 | # Verify the attention mask attends properly, including for low attention scores
10 | cfg = HookedTransformerConfig(
11 | d_head=1,
12 | d_model=12,
13 | d_vocab=2,
14 | n_ctx=5,
15 | n_layers=1,
16 | attn_only=True,
17 | attention_dir="causal",
18 | )
19 | model = HookedTransformer(cfg)
20 | input_length = 5
21 | input = torch.ones((1, input_length), dtype=torch.int64)
22 | layer = 0
23 | low_attn_score = 1e-6
24 | ones_input_matrix = torch.ones((input_length, input_length))
25 | masked = torch.triu(ones_input_matrix, diagonal=1).bool()
26 |
27 | def attn_scores_hook(attn_scores, hook):
28 | assert torch.all(
29 | attn_scores[:, :, masked] == float("-inf")
30 | ), "Attention scores excluded by the mask are not being set to -inf"
31 |
32 |         # Set the positions allowed by the mask to a low attention score
33 | attn_scores[:, :, ~masked] = low_attn_score
34 |
35 | return attn_scores
36 |
37 | def attn_hook(attn, hook):
38 | assert torch.all(attn[:, :, masked] == 0), "Attention pattern attends outside the mask"
39 |
40 | return attn
41 |
42 | fwd_hooks = [
43 | (utils.get_act_name("attn_scores", layer), attn_scores_hook),
44 | (utils.get_act_name("attn", layer), attn_hook),
45 | ]
46 |
47 | model.run_with_hooks(input, fwd_hooks=fwd_hooks)
48 |
49 |
50 | def test_masked_tokens():
51 | """Test that masking tokens works as expected."""
52 | MODEL = "solu-1l"
53 | prompts = [
54 | "Hello world!",
55 | "The quick brown fox jumps over the lazy dog.",
56 | ]
57 | model = HookedTransformer.from_pretrained(MODEL)
58 | tokens = model.to_tokens(prompts)
59 |
60 | # Part 1: If the mask is all ones, the output should be the same as if there was no mask.
61 | full_mask = torch.ones_like(tokens)
62 | no_mask_out = model(tokens)
63 | full_mask_out = model(tokens, attention_mask=full_mask)
64 | assert torch.allclose(no_mask_out, full_mask_out), "Full mask should be equivalent to no mask"
65 |
66 | # Part 2: If the mask has a column of zeros, the output should be the same as if that token
67 | # position was removed from the input.
68 | remove_tok_idx = 2
69 | edited_tokens = torch.cat([tokens[:, :remove_tok_idx], tokens[:, remove_tok_idx + 1 :]], dim=1)
70 | edited_mask = full_mask.clone()
71 | edited_mask[:, remove_tok_idx] = 0
72 | edited_no_mask_out = model(edited_tokens)
73 | edited_mask_out = model(tokens, attention_mask=edited_mask)
74 | edited_mask_out = torch.cat(
75 | [edited_mask_out[:, :remove_tok_idx], edited_mask_out[:, remove_tok_idx + 1 :]], dim=1
76 | )
77 | assert torch.allclose(
78 | edited_no_mask_out, edited_mask_out, atol=1e-4
79 |     ), "Masking out a token should be equivalent to removing it from the input"
80 |
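81 |
82 | # Illustrative sketch, not part of the original tests: utils.get_act_name (used in fwd_hooks above)
83 | # simply expands a short activation name and layer index into the full hook name.
84 | def _example_get_act_name():
85 |     assert utils.get_act_name("attn_scores", 0) == "blocks.0.attn.hook_attn_scores"
86 |     assert utils.get_act_name("pattern", 0) == "blocks.0.attn.hook_pattern"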
--------------------------------------------------------------------------------
/tests/integration/test_cache_hook_names.py:
--------------------------------------------------------------------------------
1 | from transformer_lens import HookedTransformer
2 |
3 | MODEL = "solu-1l"
4 |
5 | prompt = "Hello World!"
6 | model = HookedTransformer.from_pretrained(MODEL)
7 |
8 | act_names_in_cache = [
9 | "hook_embed",
10 | "hook_pos_embed",
11 | "blocks.0.hook_resid_pre",
12 | "blocks.0.ln1.hook_scale",
13 | "blocks.0.ln1.hook_normalized",
14 | "blocks.0.attn.hook_q",
15 | "blocks.0.attn.hook_k",
16 | "blocks.0.attn.hook_v",
17 | "blocks.0.attn.hook_attn_scores",
18 | "blocks.0.attn.hook_pattern",
19 | "blocks.0.attn.hook_z",
20 | "blocks.0.hook_attn_out",
21 | "blocks.0.hook_resid_mid",
22 | "blocks.0.ln2.hook_scale",
23 | "blocks.0.ln2.hook_normalized",
24 | "blocks.0.mlp.hook_pre",
25 | "blocks.0.mlp.hook_mid",
26 | "blocks.0.mlp.ln.hook_scale",
27 | "blocks.0.mlp.ln.hook_normalized",
28 | "blocks.0.mlp.hook_post",
29 | "blocks.0.hook_mlp_out",
30 | "blocks.0.hook_resid_post",
31 | "ln_final.hook_scale",
32 | "ln_final.hook_normalized",
33 | ]
34 |
35 |
36 | def test_cache_hook_names():
37 | _, cache = model.run_with_cache(prompt)
38 | assert list(cache.keys()) == act_names_in_cache
39 |
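40 |
41 | # Illustrative sketch, not part of the original test (prefixed with "_" so pytest does not collect
42 | # it): the names above are exactly the keys used to read activations back out of the cache.
43 | def _example_read_pattern_from_cache():
44 |     _, cache = model.run_with_cache(prompt)
45 |     pattern = cache["blocks.0.attn.hook_pattern"]  # [batch, n_heads, query_pos, key_pos]
46 |     assert pattern.shape[0] == 1  # a single prompt was passed in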
--------------------------------------------------------------------------------
/tests/integration/test_create_hooked_encoder.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers import AutoTokenizer, BertTokenizerFast
3 |
4 | from transformer_lens import HookedEncoder, HookedTransformerConfig
5 |
6 |
7 | @pytest.fixture
8 | def cfg():
9 | return HookedTransformerConfig(d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu")
10 |
11 |
12 | def test_pass_tokenizer(cfg):
13 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
14 | model = HookedEncoder(cfg, tokenizer=tokenizer)
15 | assert model.tokenizer == tokenizer
16 |
17 |
18 | def test_load_tokenizer_from_config(cfg):
19 | cfg.tokenizer_name = "bert-base-cased"
20 | model = HookedEncoder(cfg)
21 | assert isinstance(model.tokenizer, BertTokenizerFast)
22 |
23 |
24 | def test_load_without_tokenizer(cfg):
25 | cfg.d_vocab = 22
26 | model = HookedEncoder(cfg)
27 | assert model.tokenizer is None
28 |
29 |
30 | def test_cannot_load_without_tokenizer_or_d_vocab(cfg):
31 | with pytest.raises(AssertionError) as e:
32 | HookedEncoder(cfg)
33 | assert "Must provide a tokenizer if d_vocab is not provided" in str(e.value)
34 |
--------------------------------------------------------------------------------
/tests/integration/test_cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens.HookedTransformer import HookedTransformer
4 |
5 |
6 | def test_cross_entropy_attention_mask():
7 | """Check that adding a bunch of masked tokens to the input does not change the loss."""
8 | MODEL = "solu-1l"
9 | model = HookedTransformer.from_pretrained(MODEL)
10 |
11 | # Step 1: Get the default loss on a prompt
12 | prompt = ["The quick brown fox jumps over the lazy dog."]
13 | default_tokens = model.to_tokens(prompt)
14 | default_attention_mask = torch.ones_like(default_tokens)
15 | default_loss = model(default_tokens, return_type="loss")
16 | ones_mask_loss = model(
17 | default_tokens, attention_mask=default_attention_mask, return_type="loss"
18 | )
19 | assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6)
20 |
21 | # Step 2: Get the loss when we add some extra tokens to the input and set their attention mask
22 | # to zero
23 | extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."]
24 | extra_tokens = model.to_tokens(extra_prompt)
25 | extra_zeros_attention_mask = torch.zeros_like(extra_tokens)
26 |
27 | combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1)
28 | combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1)
29 | combined_masked_loss = model(
30 | combined_tokens, attention_mask=combined_attention_mask, return_type="loss"
31 | )
32 | assert torch.allclose(default_loss, combined_masked_loss)
33 |
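34 |
35 | # Illustrative sketch, not part of the original test: return_type="loss" gives the mean next-token
36 | # loss; loss_per_token=True returns one value per predicted position instead, which makes it easier
37 | # to see which positions a masked mean would exclude.
38 | def _example_per_token_loss():
39 |     model = HookedTransformer.from_pretrained("solu-1l")
40 |     tokens = model.to_tokens(["The quick brown fox jumps over the lazy dog."])
41 |     per_token_loss = model(tokens, return_type="loss", loss_per_token=True)
42 |     assert per_token_loss.shape == (1, tokens.shape[1] - 1)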
--------------------------------------------------------------------------------
/tests/integration/test_d_vocab.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | from transformer_lens import HookedTransformer, HookedTransformerConfig
4 |
5 |
6 | def test_d_vocab_from_tokenizer():
7 | cfg = HookedTransformerConfig(
8 | n_layers=1, d_mlp=10, d_model=10, d_head=5, n_heads=2, n_ctx=20, act_fn="relu"
9 | )
10 | model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained("gpt2"))
11 | assert model.cfg.d_vocab == 50257
12 | assert model.cfg.d_vocab_out == 50257
13 |
14 |
15 | def test_d_vocab_from_tokenizer_name():
16 | cfg = HookedTransformerConfig(
17 | n_layers=1,
18 | d_mlp=10,
19 | d_model=10,
20 | d_head=5,
21 | n_heads=2,
22 | n_ctx=20,
23 | act_fn="relu",
24 | tokenizer_name="gpt2",
25 | )
26 | model = HookedTransformer(cfg=cfg)
27 | assert model.cfg.d_vocab == 50257
28 | assert model.cfg.d_vocab_out == 50257
29 |
30 |
31 | def test_d_vocab_out_set():
32 | cfg = HookedTransformerConfig(
33 | n_layers=1,
34 | d_mlp=10,
35 | d_model=10,
36 | d_head=5,
37 | n_heads=2,
38 | n_ctx=20,
39 | act_fn="relu",
40 | d_vocab=100,
41 | d_vocab_out=90,
42 | )
43 | model = HookedTransformer(cfg=cfg)
44 | assert model.cfg.d_vocab == 100
45 | assert model.cfg.d_vocab_out == 90
46 |
47 |
48 | def test_d_vocab_out_set_d_vocab_infer():
49 | cfg = HookedTransformerConfig(
50 | n_layers=1,
51 | d_mlp=10,
52 | d_model=10,
53 | d_head=5,
54 | n_heads=2,
55 | n_ctx=20,
56 | act_fn="relu",
57 | d_vocab_out=90,
58 | tokenizer_name="gpt2",
59 | )
60 | model = HookedTransformer(cfg=cfg)
61 | assert model.cfg.d_vocab == 50257
62 | assert model.cfg.d_vocab_out == 90
63 |
--------------------------------------------------------------------------------
/tests/integration/test_loading_from_pretrained.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests that verify configs for pretrained models can be fetched via loading_from_pretrained (e.g. that get_basic_config returns a fully populated config for gpt2-small).
3 | """
4 |
5 | from transformer_lens import loading_from_pretrained as loading
6 |
7 |
8 | def test_get_basic_config():
9 | cfg = loading.get_basic_config("gpt2-small")
10 | assert cfg.d_model
11 | assert cfg.layer_norm_eps
12 | assert cfg.d_vocab
13 | assert cfg.init_range
14 | assert cfg.n_ctx
15 | assert cfg.d_head
16 | assert cfg.d_mlp
17 | assert cfg.n_heads
18 | assert cfg.n_layers
19 |
--------------------------------------------------------------------------------
/tests/integration/test_match_huggingface.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pytest
4 | import torch
5 | from transformers import AutoModelForCausalLM
6 |
7 | from transformer_lens import HookedTransformer
8 |
9 |
10 | class TestMatchHuggingFace:
11 | # fixtures
12 | @pytest.fixture(scope="class", params=["gpt2"])
13 | def model_name(self, request):
14 | return request.param
15 |
16 | # tests
17 | def test_compare_huggingface_mlp_match_local_implementation(self, model_name):
18 | tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device="cpu")
19 | hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
20 | tensor_shape = (3, 5, tl_model.cfg.d_model)
21 | test_tensor = torch.randn(tensor_shape)
22 |
23 | for layer_n in range(len(tl_model.blocks)):
24 | tl_out = tl_model.blocks[layer_n].mlp(test_tensor)
25 | hf_out = hf_model.transformer.h[layer_n].mlp(test_tensor)
26 |
27 | assert torch.sum(tl_out == hf_out) == math.prod(tensor_shape)
28 |
29 | def test_compare_huggingface_attention_match_local_implementation(self, model_name):
30 | tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device="cpu")
31 | hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
32 | batch, pos, d_model = 3, 5, tl_model.cfg.d_model
33 | input = torch.randn(batch, pos, d_model)
34 |
35 | for layer_n in range(len(tl_model.blocks)):
36 | tl_out = tl_model.blocks[layer_n].attn(
37 | query_input=input,
38 | key_input=input,
39 | value_input=input,
40 | past_kv_cache_entry=None,
41 | attention_mask=None,
42 | )
43 | hf_out, _, _ = hf_model.transformer.h[layer_n].attn(
44 | hidden_states=input, output_attentions=True
45 | )
46 |
47 | assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape)
48 |
--------------------------------------------------------------------------------
/tests/integration/test_utils_tokens.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | import numpy as np
4 | import pytest
5 | import torch
6 |
7 | import transformer_lens.utils as utils
8 | from transformer_lens import HookedTransformer
9 |
10 | MODEL = "solu-1l"
11 |
12 | model = HookedTransformer.from_pretrained(MODEL)
13 |
14 |
15 | @pytest.fixture
16 | def nested_list_1():
17 | return [1]
18 |
19 |
20 | @pytest.fixture
21 | def nested_list_1x1():
22 | return [[6]]
23 |
24 |
25 | @pytest.fixture
26 | def nested_list_1x3():
27 | return [[1, 2, 3]]
28 |
29 |
30 | def test_to_str_tokens(nested_list_1, nested_list_1x1, nested_list_1x3):
31 | tensor_1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1))
32 | assert isinstance(tensor_1_to_str_tokens, list)
33 | assert len(tensor_1_to_str_tokens) == 1
34 | assert isinstance(tensor_1_to_str_tokens[0], str)
35 |
36 | tensor_1x1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x1))
37 | assert isinstance(tensor_1x1_to_str_tokens, list)
38 | assert len(tensor_1x1_to_str_tokens) == 1
39 | assert isinstance(tensor_1x1_to_str_tokens[0], str)
40 |
41 | ndarray_1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1))
42 | assert isinstance(ndarray_1_to_str_tokens, list)
43 | assert len(ndarray_1_to_str_tokens) == 1
44 | assert isinstance(ndarray_1_to_str_tokens[0], str)
45 |
46 | ndarray_1x1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x1))
47 | assert isinstance(ndarray_1x1_to_str_tokens, list)
48 | assert len(ndarray_1x1_to_str_tokens) == 1
49 | assert isinstance(ndarray_1x1_to_str_tokens[0], str)
50 |
51 | single_int_to_single_str_token = model.to_single_str_token(3)
52 | assert isinstance(single_int_to_single_str_token, str)
53 |
54 | squeezable_tensor_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x3))
55 | assert isinstance(squeezable_tensor_to_str_tokens, list)
56 | assert len(squeezable_tensor_to_str_tokens) == 3
57 | assert isinstance(squeezable_tensor_to_str_tokens[0], str)
58 | assert isinstance(squeezable_tensor_to_str_tokens[1], str)
59 | assert isinstance(squeezable_tensor_to_str_tokens[2], str)
60 |
61 | squeezable_ndarray_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x3))
62 | assert isinstance(squeezable_ndarray_to_str_tokens, list)
63 | assert len(squeezable_ndarray_to_str_tokens) == 3
64 | assert isinstance(squeezable_ndarray_to_str_tokens[0], str)
65 | assert isinstance(squeezable_ndarray_to_str_tokens[1], str)
66 | assert isinstance(squeezable_ndarray_to_str_tokens[2], str)
67 |
68 |
69 | @pytest.mark.parametrize(
70 | "prepend_space_to_answer, tokenized_prompt, tokenized_answer",
71 | [
72 | (
73 | True,
74 | [
75 | "<|BOS|>",
76 | "The",
77 | " circumference",
78 | " is",
79 | " the",
80 | " perimeter",
81 | " of",
82 | " the",
83 | " circ",
84 | ],
85 | [" le", "."],
86 | ),
87 | (
88 | False,
89 | [
90 | "<|BOS|>",
91 | "The",
92 | " circumference",
93 | " is",
94 | " the",
95 | " perimeter",
96 | " of",
97 | " the",
98 | " circ",
99 | ],
100 | ["le", "."],
101 | ),
102 | ],
103 | )
104 | @mock.patch("builtins.print")
105 | def test_test_prompt(
106 | mocked_print,
107 | prepend_space_to_answer,
108 | tokenized_prompt,
109 | tokenized_answer,
110 | ):
111 | """
112 | Tests that utils.test_prompt produces the correct tokenization. In particular, when prepend_space_to_answer = False, the last token of the prompt
113 | and the first answer token should not be turned into one token (e.g. 'circ' and 'le' don't become 'circle'). See https://github.com/TransformerLensOrg/TransformerLens/issues/271
114 | for a more detailed explanation.
115 | """
116 | utils.test_prompt(
117 | "The circumference is the perimeter of the circ",
118 | "le.",
119 | model,
120 | prepend_space_to_answer=prepend_space_to_answer,
121 | )
122 |
123 | printed_tokenized_prompt = mock.call("Tokenized prompt:", tokenized_prompt)
124 | printed_tokenized_answer = mock.call("Tokenized answer:", tokenized_answer)
125 |
126 | assert mocked_print.mock_calls[0] == printed_tokenized_prompt
127 | assert mocked_print.mock_calls[1] == printed_tokenized_answer
128 |
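129 |
130 | # Illustrative sketch, not part of the original tests: the reason prepend_space_to_answer matters is
131 | # that concatenating prompt and answer before tokenizing can merge tokens across the boundary (e.g.
132 | # "circ" + "le" re-tokenizing as "circle"), whereas utils.test_prompt tokenizes them separately.
133 | def _example_tokenization_boundary():
134 |     merged = model.to_str_tokens("circle", prepend_bos=False)
135 |     separate = model.to_str_tokens("circ", prepend_bos=False) + model.to_str_tokens("le", prepend_bos=False)
136 |     return merged, separate  # the merged version is typically split into fewer tokens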
--------------------------------------------------------------------------------
/tests/manual_checks/manual_checks_type_annotations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from jaxtyping import Float
3 |
4 | from transformer_lens import HookedTransformer
5 |
6 | MODEL = "gpt2"
7 | model = HookedTransformer.from_pretrained(MODEL)
8 |
9 | prompt = "Hello World!"
10 | tokens = model.to_tokens(prompt, prepend_bos=False)
11 | logits_tokens = model(tokens)
12 | logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False)
13 |
14 | # N.B. This file was used to check that the type annotations work: occasionally changing one of the
15 | # annotated sizes confirmed that the type checker caught the mismatch.
16 |
--------------------------------------------------------------------------------
/tests/manual_checks/manual_checks_typing.py:
--------------------------------------------------------------------------------
1 | # Adapted from [HookedTransformer_Demo.ipynb]. Useful for testing that all the typing mechanisms work
2 | # out.
3 |
4 | # %%
5 |
6 | import torch as t
7 | from jaxtyping import Float
8 |
9 | from transformer_lens import HookedTransformer, utils
10 |
11 | DEVICE = utils.get_device()
12 | MODEL = "gpt2"
13 |
14 | # %%
15 | model = HookedTransformer.from_pretrained(MODEL)
16 | model.to(DEVICE)
17 |
18 | # %%
19 |
20 | prompt = "Hello World!"
21 | tokens = model.to_tokens(prompt, prepend_bos=False)
22 | logits_tokens = model(tokens)
23 | logits_text: Float[t.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False)
24 |
25 | # %%
26 |
27 | logits_text.shape
28 | # %%
29 |
--------------------------------------------------------------------------------
/tests/unit/components/mlps/test_can_be_used_as_mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import pytest
4 | import torch
5 |
6 | from transformer_lens.components import LayerNorm, LayerNormPre
7 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
8 | from transformer_lens.hook_points import HookPoint
9 | from transformer_lens.utils import solu
10 |
11 |
12 | @pytest.fixture
13 | def cfg() -> Dict[str, Any]:
14 | return {
15 | "n_layers": 12,
16 | "n_ctx": 1024,
17 | "d_head": 64,
18 | "d_model": 128,
19 | "d_mlp": 256,
20 | "dtype": torch.float32,
21 | "act_fn": "solu_ln",
22 | "normalization_type": "LN",
23 | "load_in_4bit": False,
24 | }
25 |
26 |
27 | def test_initialization(cfg: Dict[str, Any]):
28 | CanBeUsedAsMLP(cfg)
29 |
30 |
31 | def test_initialization_fails_without_d_mlp(cfg: Dict[str, Any]):
32 | cfg["d_mlp"] = None
33 |     with pytest.raises(ValueError):
34 |         CanBeUsedAsMLP(cfg)
35 |
36 |
37 | def test_select_activation_function_selects_function():
38 | cfg = {
39 | "n_layers": 12,
40 | "n_ctx": 1024,
41 | "d_head": 64,
42 | "d_model": 128,
43 | "d_mlp": 256,
44 | "dtype": torch.float32,
45 | "act_fn": "silu",
46 | "normalization_type": "LN",
47 | "load_in_4bit": False,
48 | }
49 |
50 | model = CanBeUsedAsMLP(cfg)
51 | model.select_activation_function()
52 | assert model.act_fn is not None
53 |
54 |
55 | def test_select_activation_function_with_layer_norm():
56 | cfg = {
57 | "n_layers": 12,
58 | "n_ctx": 1024,
59 | "d_head": 64,
60 | "d_model": 128,
61 | "d_mlp": 256,
62 | "dtype": torch.float32,
63 | "act_fn": "solu_ln",
64 | "normalization_type": "LN",
65 | "load_in_4bit": False,
66 | }
67 |
68 | model = CanBeUsedAsMLP(cfg)
69 | model.select_activation_function()
70 | assert model.act_fn == solu
71 | assert isinstance(model.hook_mid, HookPoint)
72 | assert isinstance(model.ln, LayerNorm)
73 |
74 |
75 | def test_select_activation_function_with_layer_norm_pre():
76 | cfg = {
77 | "n_layers": 12,
78 | "n_ctx": 1024,
79 | "d_head": 64,
80 | "d_model": 128,
81 | "d_mlp": 256,
82 | "dtype": torch.float32,
83 | "act_fn": "solu_ln",
84 | "normalization_type": "LNPre",
85 | "load_in_4bit": False,
86 | }
87 |
88 | model = CanBeUsedAsMLP(cfg)
89 | model.select_activation_function()
90 | assert model.act_fn == solu
91 | assert isinstance(model.hook_mid, HookPoint)
92 | assert isinstance(model.ln, LayerNormPre)
93 |
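94 |
95 | # Illustrative sketch, not part of the original tests: "solu_ln" means the SoLU activation,
96 | # solu(x) = x * softmax(x, dim=-1), followed by a LayerNorm, which is why
97 | # select_activation_function attaches both a hook_mid HookPoint and an ln module above.
98 | def _example_solu_definition():
99 |     x = torch.randn(2, 4, 256)
100 |     assert torch.allclose(solu(x), x * torch.softmax(x, dim=-1))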
--------------------------------------------------------------------------------
/tests/unit/components/mlps/test_gated_mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import pytest
4 | import torch
5 | import torch.nn as nn
6 |
7 | from transformer_lens.components import GatedMLP, LayerNorm
8 | from transformer_lens.utils import solu
9 |
10 |
11 | @pytest.fixture
12 | def cfg() -> Dict[str, Any]:
13 | return {
14 | "n_layers": 12,
15 | "n_ctx": 1024,
16 | "d_head": 64,
17 | "d_model": 128,
18 | "d_mlp": 256,
19 | "dtype": torch.float32,
20 | "act_fn": "solu_ln",
21 | "normalization_type": "LN",
22 | "load_in_4bit": False,
23 | }
24 |
25 |
26 | def test_initialization(cfg: Dict[str, Any]):
27 | model = GatedMLP(cfg)
28 | assert isinstance(model.W_in, nn.Parameter)
29 | assert isinstance(model.W_gate, nn.Parameter)
30 | assert isinstance(model.W_out, nn.Parameter)
31 | assert isinstance(model.b_in, nn.Parameter)
32 | assert isinstance(model.b_out, nn.Parameter)
33 | assert model.act_fn == solu
34 | assert isinstance(model.ln, LayerNorm)
35 |
36 |
37 | def test_forward(cfg: Dict[str, Any]):
38 | model = GatedMLP(cfg)
39 | x = torch.randn(2, 10, cfg["d_model"])
40 | output = model(x)
41 | assert output.shape == (2, 10, cfg["d_model"])
42 |
--------------------------------------------------------------------------------
/tests/unit/components/mlps/test_mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import pytest
4 | import torch
5 |
6 | from transformer_lens.components import LayerNorm
7 | from transformer_lens.components.mlps.mlp import MLP
8 | from transformer_lens.hook_points import HookPoint
9 |
10 |
11 | @pytest.fixture
12 | def cfg() -> Dict[str, Any]:
13 | return {
14 | "n_layers": 12,
15 | "n_ctx": 1024,
16 | "d_head": 64,
17 | "d_model": 128,
18 | "d_mlp": 256,
19 | "dtype": torch.float32,
20 | "act_fn": "solu_ln",
21 | "normalization_type": "LN",
22 | "load_in_4bit": False,
23 | }
24 |
25 |
26 | def test_initialization(cfg: Dict[str, Any]):
27 | MLP(cfg)
28 |
29 |
30 | def test_forward_without_layer_norm(cfg: Dict[str, Any]):
31 | cfg["act_fn"] = "solu"
32 |
33 | model = MLP(cfg)
34 |
35 | input = torch.full((1, 1, 128), 0.085)
36 |
37 | result = model(input)
38 |
39 | assert result.shape == (1, 1, 128)
40 |
41 |
42 | def test_forward_with_layer_norm(cfg: Dict[str, Any]):
43 | model = MLP(cfg)
44 | assert isinstance(model.hook_mid, HookPoint)
45 | assert isinstance(model.ln, LayerNorm)
46 |
47 | input = torch.full((1, 1, 128), 0.85)
48 | result = model(input)
49 | assert result.shape == (1, 1, 128)
50 |
--------------------------------------------------------------------------------
/tests/unit/components/mlps/test_moe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens.components import MoE
4 |
5 |
6 | def test_forward():
7 | cfg = {
8 | "d_model": 32,
9 | "d_mlp": 14336,
10 | "d_head": 4,
11 | "num_experts": 32,
12 | "n_layers": 16,
13 | "n_ctx": 2048,
14 | "experts_per_token": 4,
15 | "gated_mlp": True,
16 | "act_fn": "silu",
17 | }
18 | moe = MoE(cfg)
19 |
20 | x = torch.rand((1, 4, 32))
21 | moe(x)
22 |
--------------------------------------------------------------------------------
/tests/unit/components/test_abstract_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens.components import AbstractAttention
4 |
5 |
6 | def test_create_alibi_slope():
7 | n_ctx = 100
8 |
9 | # Expected result computed non-vectorized way
10 | expected = torch.zeros((n_ctx, n_ctx))
11 | for row in range(n_ctx):
12 | for col in range(n_ctx):
13 | expected[row, col] = float(min(col - row, 0))
14 |
15 | # Check against the method's vectorized version
16 | result = AbstractAttention.create_alibi_slope(n_ctx)
17 | assert torch.allclose(expected, result)
18 |
19 |
20 | def test_create_alibi_bias():
21 | n_heads = 2
22 | n_ctx = 4
23 |
24 | result = AbstractAttention.create_alibi_bias(n_heads, n_ctx, torch.device("cpu"))
25 |
26 | for matrix in result:
27 | n_row, n_col = matrix.size()
28 | slope = -matrix[1, 0]
29 | # Check if upper triangle is all zeros
30 | assert torch.equal(torch.triu(matrix), torch.zeros_like(matrix))
31 |
32 | ref_lower_triangle = torch.zeros_like(matrix)
33 | for i in range(1, n_row):
34 | for j in range(i):
35 | ref_lower_triangle[i, j] = -slope * (i - j)
36 |
37 | # Check if the lower triangle is decreasing by a constant slope (towards the bottom left corner).
38 | assert torch.equal(
39 | torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1)
40 | )
41 |
--------------------------------------------------------------------------------
/tests/unit/components/test_attention.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import pytest
3 | import torch
4 | import torch.nn as nn
5 | from transformers.utils import is_bitsandbytes_available
6 |
7 | from transformer_lens.components import Attention
8 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
9 | from transformer_lens.utilities.attention import complex_attn_linear
10 |
11 | if is_bitsandbytes_available():
12 | from bitsandbytes.nn.modules import Params4bit
13 |
14 |
15 | def test_attention_hooked_transformer_config():
16 | cfg = HookedTransformerConfig(
17 | n_layers=12,
18 | d_model=512,
19 | n_ctx=1024,
20 | d_head=64,
21 | n_heads=8,
22 | load_in_4bit=False,
23 | dtype=torch.float32,
24 | act_fn="relu",
25 | )
26 | attn = Attention(cfg)
27 | assert attn.cfg == cfg
28 | assert attn.cfg.n_layers == 12
29 | assert attn.cfg.d_model == 512
30 | assert attn.cfg.n_ctx == 1024
31 | assert attn.cfg.d_head == 64
32 | assert attn.cfg.n_heads == 8
33 | assert attn.cfg.load_in_4bit == False
34 | assert attn.cfg.dtype == torch.float32
35 | assert attn.cfg.act_fn == "relu"
36 |
37 | assert isinstance(attn.W_K, nn.Parameter)
38 | assert isinstance(attn.W_V, nn.Parameter)
39 | assert attn.W_K.shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
40 | assert attn.W_V.shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
41 |
42 | assert attn.b_K.shape == (cfg.n_heads, cfg.d_head)
43 | assert attn.b_V.shape == (cfg.n_heads, cfg.d_head)
44 | assert torch.all(attn.b_K == 0)
45 | assert torch.all(attn.b_V == 0)
46 |
47 |
48 | @pytest.mark.skipif(not is_bitsandbytes_available(), reason="bitsandbytes is not available")
49 | def test_attention_load_in_4bit():
50 | cfg = HookedTransformerConfig(
51 | n_layers=12,
52 | d_model=512,
53 | n_ctx=1024,
54 | d_head=64,
55 | n_heads=8,
56 | load_in_4bit=True,
57 | dtype=torch.float32,
58 | act_fn="relu",
59 | )
60 | attn = Attention(cfg)
61 | assert attn.cfg == cfg
62 | assert attn.cfg.n_layers == 12
63 | assert attn.cfg.d_model == 512
64 | assert attn.cfg.n_ctx == 1024
65 | assert attn.cfg.d_head == 64
66 | assert attn.cfg.n_heads == 8
67 |     assert attn.cfg.load_in_4bit == True
68 | assert attn.cfg.dtype == torch.float32
69 | assert attn.cfg.act_fn == "relu"
70 |
71 | assert isinstance(attn.W_K, Params4bit)
72 | assert isinstance(attn.W_V, Params4bit)
73 |     nq = int((cfg.d_model * cfg.d_model) / 2)  # 4-bit packing stores two weights per byte; here n_heads * d_head == d_model
74 | assert attn.W_K.data.shape == (nq, 1)
75 | assert attn.W_V.data.shape == (nq, 1)
76 |
77 | assert attn.b_K.shape == (cfg.n_heads, cfg.d_head)
78 | assert attn.b_V.shape == (cfg.n_heads, cfg.d_head)
79 | assert torch.all(attn.b_K == 0)
80 | assert torch.all(attn.b_V == 0)
81 |
82 |
83 | def test_attention_config_dict():
84 | cfg = {
85 | "n_layers": 12,
86 | "d_model": 512,
87 | "n_ctx": 1024,
88 | "d_head": 64,
89 | "n_heads": 8,
90 | "load_in_4bit": False,
91 | "dtype": torch.float32,
92 | "act_fn": "relu",
93 | }
94 | attn = Attention(cfg)
95 | assert attn.cfg.n_layers == 12
96 | assert attn.cfg.d_model == 512
97 | assert attn.cfg.n_ctx == 1024
98 | assert attn.cfg.d_head == 64
99 | assert attn.cfg.n_heads == 8
100 | assert attn.cfg.load_in_4bit == False
101 | assert attn.cfg.dtype == torch.float32
102 | assert attn.cfg.act_fn == "relu"
103 |
104 |
105 | def test_remove_einsum_from_complex_attn_linear():
106 | batch = 64
107 | pos = 128
108 | head_index = 8
109 | d_model = 512
110 | d_head = 64
111 | input = torch.randn(batch, pos, head_index, d_model)
112 | w = torch.randn(head_index, d_model, d_head)
113 | b = torch.randn(head_index, d_head)
114 | result_new = complex_attn_linear(input, w, b)
115 |
116 | # Check if new implementation without einsum produces correct shape
117 | assert result_new.shape == (batch, pos, head_index, d_head)
118 |
119 | # Old implementation used einsum
120 | result_old = (
121 | einops.einsum(
122 | input,
123 | w,
124 | "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
125 | )
126 | + b
127 | )
128 |
129 | # Check if the results are the same
130 | assert torch.allclose(result_new, result_old, atol=1e-4)
131 |
--------------------------------------------------------------------------------
/tests/unit/factored_matrix/test_constructor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from transformer_lens import FactoredMatrix
5 |
6 |
7 | def test_factored_matrix():
8 | A = torch.randn(5, 3)
9 | B = torch.randn(3, 7)
10 | f = FactoredMatrix(A, B)
11 |
12 | assert torch.equal(f.A, A)
13 | assert torch.equal(f.B, B)
14 |
15 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7)
16 | assert not f.has_leading_dims
17 | assert f.shape == (5, 7)
18 |
19 |
20 | def test_factored_matrix_b_leading_dims():
21 | A = torch.ones((5, 3))
22 | B = torch.ones((2, 4, 3, 7))
23 | f = FactoredMatrix(A, B)
24 |
25 | assert f.A.shape == (2, 4, 5, 3)
26 | assert torch.equal(f.B, B)
27 |
28 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7)
29 | assert f.has_leading_dims
30 | assert f.shape == (2, 4, 5, 7)
31 |
32 |
33 | def test_factored_matrix_a_b_leading_dims():
34 | A = torch.ones((4, 5, 3))
35 | B = torch.ones((2, 4, 3, 7))
36 | f = FactoredMatrix(A, B)
37 |
38 | assert f.A.shape == (2, 4, 5, 3)
39 | assert torch.equal(f.B, B)
40 |
41 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7)
42 | assert f.has_leading_dims
43 | assert f.shape == (2, 4, 5, 7)
44 |
45 |
46 | def test_factored_matrix_broadcast_mismatch():
47 | A = torch.ones((9, 5, 3))
48 | B = torch.ones((2, 4, 3, 7))
49 |
50 | with pytest.raises(RuntimeError) as e:
51 | FactoredMatrix(A, B)
52 |
53 | assert "Shape mismatch" in str(e.value)
54 |
55 |
56 | @pytest.mark.skip(
57 | """
58 | AssertionError will not be reached due to jaxtyping argument consistency
59 | checks, which are enabled at test time but not run time.
60 |
61 | See https://github.com/TransformerLensOrg/TransformerLens/issues/190
62 | """
63 | )
64 | def test_factored_matrix_inner_mismatch():
65 | A = torch.ones((2, 3, 4))
66 | B = torch.ones((2, 3, 5))
67 | with pytest.raises(AssertionError) as e:
68 | FactoredMatrix(A, B)
69 |
70 | assert "inner dimension" in str(e.value)
71 |
--------------------------------------------------------------------------------
/tests/unit/factored_matrix/test_get_item.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.testing import assert_close
4 |
5 | from transformer_lens import FactoredMatrix
6 |
7 |
8 | @pytest.fixture
9 | def sample_factored_matrix():
10 | A = torch.rand(2, 2, 2, 2, 2)
11 | B = torch.rand(2, 2, 2, 2, 2)
12 | return FactoredMatrix(A, B)
13 |
14 |
15 | def test_getitem_int(sample_factored_matrix):
16 | result = sample_factored_matrix[0]
17 | assert_close(result.A, sample_factored_matrix.A[0])
18 | assert_close(result.B, sample_factored_matrix.B[0])
19 |
20 |
21 | def test_getitem_tuple(sample_factored_matrix):
22 | result = sample_factored_matrix[(0, 1)]
23 | assert_close(result.A, sample_factored_matrix.A[0, 1])
24 | assert_close(result.B, sample_factored_matrix.B[0, 1])
25 |
26 |
27 | def test_getitem_slice(sample_factored_matrix):
28 | result = sample_factored_matrix[:, 1]
29 | assert_close(result.A, sample_factored_matrix.A[:, 1])
30 | assert_close(result.B, sample_factored_matrix.B[:, 1])
31 |
32 |
33 | def test_getitem_error(sample_factored_matrix):
34 | with pytest.raises(IndexError):
35 | _ = sample_factored_matrix[(0, 1, 2)]
36 |
37 |
38 | def test_getitem_multiple_slices(sample_factored_matrix):
39 | result = sample_factored_matrix[:, :, 1]
40 | assert_close(result.A, sample_factored_matrix.A[:, :, 1])
41 | assert_close(result.B, sample_factored_matrix.B[:, :, 1])
42 |
43 |
44 | def test_index_dimension_get_line(sample_factored_matrix):
45 | result = sample_factored_matrix[0, 0, 0, 1]
46 | assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 1])
47 |
48 |
49 | def test_index_dimension_get_element(sample_factored_matrix):
50 | result = sample_factored_matrix[0, 0, 0, 0, 1]
51 | assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 0, 1])
52 |
53 |
54 | def test_index_dimension_too_big(sample_factored_matrix):
55 | with pytest.raises(Exception):
56 | _ = sample_factored_matrix[1, 1, 1, 1, 1, 1]
57 |
58 |
59 | def test_getitem_sequences(sample_factored_matrix):
60 | A_idx = [0, 1]
61 | B_idx = [0]
62 | result = sample_factored_matrix[:, :, :, A_idx, B_idx]
63 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :])
64 | assert_close(result.B, sample_factored_matrix.B[:, :, :, :, B_idx])
65 |
66 |
67 | def test_getitem_sequences_and_ints(sample_factored_matrix):
68 | A_idx = [0, 1]
69 | B_idx = 0
70 | result = sample_factored_matrix[:, :, :, A_idx, B_idx]
71 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :])
72 | # we squeeze result.B, because indexing by ints is designed not to delete dimensions
73 | assert_close(result.B.squeeze(-1), sample_factored_matrix.B[:, :, :, :, B_idx])
74 |
75 |
76 | def test_getitem_tensors(sample_factored_matrix):
77 | A_idx = torch.tensor([0, 1])
78 | B_idx = torch.tensor([0])
79 | result = sample_factored_matrix[:, :, :, A_idx, B_idx]
80 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :])
81 | assert_close(result.B, sample_factored_matrix.B[:, :, :, :, B_idx])
82 |
--------------------------------------------------------------------------------
/tests/unit/factored_matrix/test_multiply_by_scalar.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import pytest
4 | import torch
5 | from torch.testing import assert_close
6 |
7 | from transformer_lens import FactoredMatrix
8 |
9 |
10 | # This test function is parametrized with different types of scalar values, plus non-scalar tensors that are expected to raise errors.
11 | # It considers cases with and without leading dimensions, as well as left and right multiplication.
12 | @pytest.mark.parametrize(
13 | "scalar, error_expected",
14 | [
15 | # Test cases with different types of scalar values.
16 | (torch.rand(1), None), # 1-element Tensor. No error expected.
17 | (random.random(), None), # float. No error expected.
18 | (random.randint(-100, 100), None), # int. No error expected.
19 | # Test cases with non-scalar values that are expected to raise errors.
20 | (
21 | torch.rand(2, 2),
22 | AssertionError,
23 | ), # Non-scalar Tensor. AssertionError expected.
24 | (torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
25 | ],
26 | )
27 | @pytest.mark.parametrize("leading_dim", [False, True])
28 | @pytest.mark.parametrize("multiply_from_left", [False, True])
29 | def test_multiply(scalar, leading_dim, multiply_from_left, error_expected):
30 | # Prepare a FactoredMatrix, with or without leading dimensions
31 | if leading_dim:
32 | a = torch.rand(6, 2, 3)
33 | b = torch.rand(6, 3, 4)
34 | else:
35 | a = torch.rand(2, 3)
36 | b = torch.rand(3, 4)
37 |
38 | fm = FactoredMatrix(a, b)
39 |
40 | if error_expected:
41 | # If an error is expected, check that the correct exception is raised.
42 | with pytest.raises(error_expected):
43 | if multiply_from_left:
44 | _ = fm * scalar
45 | else:
46 | _ = scalar * fm
47 | else:
48 | # If no error is expected, check that the multiplication results in the correct value.
49 | # Use FactoredMatrix.AB to calculate the product of the two factor matrices before comparing with the expected value.
50 | if multiply_from_left:
51 | assert_close((fm * scalar).AB, (a @ b) * scalar)
52 | else:
53 | assert_close((scalar * fm).AB, scalar * (a @ b))
54 |         # This next check is implementation-dependent and may break (and be removed) at any time!
55 | # It checks that the multiplication is performed on the A factor matrix.
56 | if multiply_from_left:
57 | assert_close((fm * scalar).A, a * scalar)
58 | else:
59 | assert_close((scalar * fm).A, scalar * a)
60 |
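61 |
62 | # Illustrative sketch, not part of the original tests: scalar multiplication can be folded into
63 | # either factor, since (s * A) @ B == s * (A @ B); the implementation-dependent check above relies
64 | # on it being folded into A.
65 | def _example_scalar_folding():
66 |     a, b, s = torch.rand(2, 3), torch.rand(3, 4), 2.5
67 |     assert_close((s * a) @ b, s * (a @ b))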
--------------------------------------------------------------------------------
/tests/unit/factored_matrix/test_multiply_by_vector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.testing import assert_close
3 |
4 | from transformer_lens import FactoredMatrix
5 |
6 |
7 | def test_left_matmul_by_vector():
8 | a = torch.rand(2, 3)
9 | b = torch.rand(3, 4)
10 |
11 | fm = FactoredMatrix(a, b)
12 | vector = torch.rand(4)
13 |
14 | assert_close(fm @ vector, (a @ b) @ vector)
15 |
16 |
17 | def test_left_matmul_by_vector_leading_dim():
18 | a = torch.rand(6, 2, 3)
19 | b = torch.rand(6, 3, 4)
20 |
21 | fm = FactoredMatrix(a, b)
22 | vector = torch.rand(4)
23 |
24 | assert_close(fm @ vector, (a @ b) @ vector)
25 |
26 |
27 | def test_right_matmul_by_vector():
28 | a = torch.rand(2, 3)
29 | b = torch.rand(3, 4)
30 |
31 | fm = FactoredMatrix(a, b)
32 | vector = torch.rand(2)
33 |
34 | assert_close(vector @ fm, vector @ (a @ b))
35 |
36 |
37 | def test_right_matmul_by_vector_leading_dim():
38 | a = torch.rand(6, 2, 3)
39 | b = torch.rand(6, 3, 4)
40 |
41 | fm = FactoredMatrix(a, b)
42 | vector = torch.rand(2)
43 |
44 | assert_close(vector @ fm, vector @ (a @ b))
45 |
--------------------------------------------------------------------------------
/tests/unit/factories/test_activation_function_factory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from transformer_lens.factories.activation_function_factory import (
5 | ActivationFunctionFactory,
6 | )
7 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
8 | from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS
9 |
10 |
11 | @pytest.mark.parametrize("act_function", SUPPORTED_ACTIVATIONS.keys())
12 | def test_pick_activation_function_runs(act_function):
13 | config = HookedTransformerConfig.unwrap(
14 | {"n_layers": 12, "n_ctx": 1024, "d_head": 64, "d_model": 128, "act_fn": act_function}
15 | )
16 | function = ActivationFunctionFactory.pick_activation_function(config)
17 | assert function is not None
18 | dummy_data = torch.zeros((1, 4, 32))
19 | result = function(dummy_data)
20 | assert isinstance(result, torch.Tensor)
21 |
--------------------------------------------------------------------------------
/tests/unit/factories/test_mlp_factory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers.utils import is_bitsandbytes_available
3 |
4 | from transformer_lens.components.mlps.gated_mlp import GatedMLP
5 | from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
6 | from transformer_lens.components.mlps.mlp import MLP
7 | from transformer_lens.factories.mlp_factory import MLPFactory
8 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
9 |
10 |
11 | def test_create_mlp_basic():
12 | config = HookedTransformerConfig.unwrap(
13 | {
14 | "n_layers": 12,
15 | "n_ctx": 1024,
16 | "d_head": 64,
17 | "d_model": 128,
18 | "act_fn": "solu",
19 | }
20 | )
21 | mlp = MLPFactory.create_mlp(config)
22 | assert isinstance(mlp, MLP)
23 |
24 |
25 | def test_create_mlp_gated():
26 | config = HookedTransformerConfig.unwrap(
27 | {
28 | "n_layers": 12,
29 | "n_ctx": 1024,
30 | "d_head": 64,
31 | "d_model": 128,
32 | "act_fn": "solu",
33 | "gated_mlp": True,
34 | }
35 | )
36 | mlp = MLPFactory.create_mlp(config)
37 | assert isinstance(mlp, GatedMLP)
38 |
39 |
40 | @pytest.mark.skipif(
41 | not is_bitsandbytes_available(),
42 | reason="4 bit not available on current architecture",
43 | )
44 | def test_create_mlp_gated_4bit():
45 | config = HookedTransformerConfig.unwrap(
46 | {
47 | "n_layers": 12,
48 | "n_ctx": 1024,
49 | "d_head": 64,
50 | "d_model": 128,
51 | "act_fn": "solu",
52 | "gated_mlp": True,
53 | "load_in_4bit": True,
54 | }
55 | )
56 | mlp = MLPFactory.create_mlp(config)
57 | assert isinstance(mlp, GatedMLP4Bit)
58 |
59 |
60 | def test_create_moe():
61 |     # MoE does not require bitsandbytes, so no availability guard is needed here
62 |     from transformer_lens.components import MoE
63 |
64 |     config = HookedTransformerConfig.unwrap(
65 |         {
66 |             "n_layers": 12,
67 |             "n_ctx": 1024,
68 |             "d_head": 64,
69 |             "d_model": 128,
70 |             "act_fn": "solu",
71 |             "gated_mlp": True,
72 |             "num_experts": 32,
73 |             "experts_per_token": 4,
74 |         }
75 |     )
76 |     mlp = MLPFactory.create_mlp(config)
77 |     # A config with num_experts set should produce an MoE module, not a 4-bit gated MLP
78 |     assert isinstance(mlp, MoE)
79 |
--------------------------------------------------------------------------------
/tests/unit/pretrained_weight_conversions/test_neo.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | import torch
4 |
5 | from transformer_lens import HookedTransformer
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 | from transformer_lens.pretrained.weight_conversions.neo import convert_neo_weights
8 |
9 |
10 | def get_default_config():
11 | return HookedTransformerConfig(
12 | d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True
13 | )
14 |
15 |
16 | def test_convert_neo_weights_exposed():
17 | cfg = get_default_config()
18 |
19 | class MockNeo:
20 | def __init__(self):
21 | self.transformer = HookedTransformer(cfg)
22 | self.transformer.wte = torch.nn.Embedding(cfg.d_vocab, cfg.d_model)
23 | self.transformer.wpe = torch.nn.Embedding(cfg.n_ctx, cfg.d_model)
24 | self.transformer.final_norm = torch.nn.LayerNorm(cfg.d_model)
25 | self.transformer.h = [mock.Mock() for _ in range(cfg.n_layers)]
26 | self.lm_head = torch.nn.Linear(cfg.d_model, cfg.d_vocab)
27 |
28 | for layer in self.transformer.h:
29 | layer.ln_1 = torch.nn.LayerNorm(cfg.d_model)
30 | layer.ln_2 = torch.nn.LayerNorm(cfg.d_model)
31 | layer.attn = mock.Mock()
32 | layer.attn.attention = mock.Mock()
33 | layer.attn.attention.q_proj = torch.nn.Linear(cfg.d_model, cfg.d_model)
34 | layer.attn.attention.k_proj = torch.nn.Linear(cfg.d_model, cfg.d_model)
35 | layer.attn.attention.v_proj = torch.nn.Linear(cfg.d_model, cfg.d_model)
36 | layer.attn.attention.out_proj = torch.nn.Linear(cfg.d_model, cfg.d_model)
37 | layer.mlp = mock.Mock()
38 | layer.mlp.c_fc = torch.nn.Linear(cfg.d_model, cfg.d_model)
39 | layer.mlp.c_proj = torch.nn.Linear(cfg.d_model, cfg.d_model)
40 |
41 | self.transformer.ln_f = torch.nn.LayerNorm(cfg.d_model)
42 |
43 | neo = MockNeo()
44 |
45 | try:
46 | convert_neo_weights(neo, cfg)
47 | function_works = True
48 | except Exception as e:
49 | function_works = False
50 | print(f"The convert_neo_weights function raised an error: {e}")
51 |
52 | assert function_works
53 |
--------------------------------------------------------------------------------
/tests/unit/test_hook_points.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | from transformer_lens.hook_points import HookPoint
4 |
5 |
6 | def setup_hook_point_and_hook():
7 | hook_point = HookPoint()
8 |
9 | def hook(activation, hook):
10 | return activation
11 |
12 | return hook_point, hook
13 |
14 |
15 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True)
16 | def test_add_hook_forward(mock_handle):
17 | mock_handle.return_value.id = 0
18 | hook_point, hook = setup_hook_point_and_hook()
19 | hook_point.add_hook(hook, dir="fwd")
20 | assert len(hook_point.fwd_hooks) == 1
21 |
22 |
23 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True)
24 | def test_add_hook_backward(mock_handle):
25 | mock_handle.return_value.id = 0
26 | hook_point, hook = setup_hook_point_and_hook()
27 | hook_point.add_hook(hook, dir="bwd")
28 | assert len(hook_point.bwd_hooks) == 1
29 |
30 |
31 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True)
32 | def test_add_hook_permanent(mock_handle):
33 | mock_handle.return_value.id = 0
34 | hook_point, hook = setup_hook_point_and_hook()
35 | hook_point.add_hook(hook, dir="fwd", is_permanent=True)
36 | assert hook_point.fwd_hooks[0].is_permanent
37 |
38 |
39 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True)
40 | def test_add_hook_with_level(mock_handle):
41 | mock_handle.return_value.id = 0
42 | hook_point, hook = setup_hook_point_and_hook()
43 | hook_point.add_hook(hook, dir="fwd", level=5)
44 | assert hook_point.fwd_hooks[0].context_level == 5
45 |
46 |
47 | @mock.patch("torch.utils.hooks.RemovableHandle")
48 | def test_add_hook_prepend(mock_handle):
49 | mock_handle.id = 0
50 | mock_handle.next_id = 1
51 |
52 | hook_point, _ = setup_hook_point_and_hook()
53 |
54 | def hook1(activation, hook):
55 | return activation
56 |
57 | def hook2(activation, hook):
58 | return activation
59 |
60 | hook_point.add_hook(hook1, dir="fwd")
61 | hook_point.add_hook(hook2, dir="fwd", prepend=True)
62 |
63 | assert len(hook_point.fwd_hooks) == 2
64 | assert hook_point.fwd_hooks[0].hook.id == 2
65 | assert hook_point.fwd_hooks[1].hook.id == 1
66 |
--------------------------------------------------------------------------------
/tests/unit/test_hooked_root_module.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | from transformer_lens.hook_points import HookedRootModule
4 |
5 | MODEL_NAME = "solu-2l"
6 |
7 |
8 | def test_enable_hook_with_name():
9 | model = HookedRootModule()
10 | model.mod_dict = {"linear": Mock()}
11 | model.context_level = 5
12 |
13 | hook = lambda x: False
14 | dir = "fwd"
15 |
16 | model._enable_hook_with_name("linear", hook=hook, dir=dir)
17 |
18 | model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5)
19 |
20 |
21 | def test_enable_hooks_for_points():
22 | model = HookedRootModule()
23 | model.mod_dict = {}
24 | model.context_level = 5
25 |
26 | hook_points = {
27 | "linear": Mock(),
28 | "attn": Mock(),
29 | }
30 |
31 | enabled = lambda x: x == "attn"
32 |
33 | hook = lambda x: False
34 | dir = "bwd"
35 |
36 |
37 | model._enable_hooks_for_points(
38 | hook_points=hook_points.items(), enabled=enabled, hook=hook, dir=dir
39 | )
40 |
41 | hook_points["attn"].add_hook.assert_called_with(hook, dir="bwd", level=5)
42 | hook_points["linear"].add_hook.assert_not_called()
43 |
44 |
45 | def test_enable_hook_with_string_param():
46 | model = HookedRootModule()
47 | model.mod_dict = {"linear": Mock()}
48 | model.context_level = 5
49 |
50 | hook = lambda x: False
51 | dir = "fwd"
52 |
53 | model._enable_hook("linear", hook=hook, dir=dir)
54 |
55 | model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5)
56 |
57 |
58 | def test_enable_hook_with_callable_param():
59 | model = HookedRootModule()
60 | model.mod_dict = {"linear": Mock()}
61 | model.hook_dict = {
62 | "linear": Mock(),
63 | "attn": Mock(),
64 | }
65 | model.context_level = 5
66 |
67 | enabled = lambda x: x == "attn"
68 |
69 | hook = lambda x: False
70 | dir = "fwd"
71 |
72 | model._enable_hook(enabled, hook=hook, dir=dir)
73 |
74 | model.mod_dict["linear"].add_hook.assert_not_called()
75 | model.hook_dict["attn"].add_hook.assert_called_with(hook, dir="fwd", level=5)
76 | model.hook_dict["linear"].add_hook.assert_not_called()
77 |
--------------------------------------------------------------------------------
/tests/unit/test_hooked_transformer_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests that config passed around TransformerLens can be unwrapped into an actual configuration object
3 | """
4 |
5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
6 |
7 |
8 | def test_hooked_transformer_config_object():
9 | hooked_transformer_config = HookedTransformerConfig(
10 | n_layers=2, d_vocab=100, d_model=6, n_ctx=5, d_head=2, attn_only=True
11 | )
12 | result = HookedTransformerConfig.unwrap(hooked_transformer_config)
13 | # Assert that the same object was returned
14 | assert result is hooked_transformer_config
15 |
16 |
17 | def test_hooked_transformer_config_dict():
18 | hooked_transformer_config_dict = {
19 | "n_layers": 2,
20 | "d_vocab": 100,
21 | "d_model": 6,
22 | "n_ctx": 5,
23 | "d_head": 2,
24 | "attn_only": True,
25 | }
26 | result = HookedTransformerConfig.unwrap(hooked_transformer_config_dict)
27 | # Assert that the new returned value has been transformed into a config object
28 | assert isinstance(result, HookedTransformerConfig)
29 |
30 |
31 | def test_is_layer_norm_activation_passes():
32 | hooked_transformer_config_dict = {
33 | "n_layers": 2,
34 | "d_vocab": 100,
35 | "d_model": 6,
36 | "n_ctx": 5,
37 | "d_head": 2,
38 | "attn_only": True,
39 | "act_fn": "solu_ln",
40 | }
41 | config = HookedTransformerConfig.unwrap(hooked_transformer_config_dict)
42 | assert config.is_layer_norm_activation()
43 |
44 |
45 | def test_is_layer_norm_activation_fails():
46 | hooked_transformer_config_dict = {
47 | "n_layers": 2,
48 | "d_vocab": 100,
49 | "d_model": 6,
50 | "n_ctx": 5,
51 | "d_head": 2,
52 | "attn_only": True,
53 | "act_fn": "relu",
54 | }
55 | config = HookedTransformerConfig.unwrap(hooked_transformer_config_dict)
56 | assert not config.is_layer_norm_activation()
57 |
--------------------------------------------------------------------------------
/tests/unit/test_loading_from_pretrained_utilities.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 |
3 | import pytest
4 |
5 | from transformer_lens import HookedTransformer
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 | from transformer_lens.loading_from_pretrained import fill_missing_keys
8 |
9 |
10 | def get_default_config():
11 | return HookedTransformerConfig(
12 | d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True
13 | )
14 |
15 |
16 | # Successes
17 |
18 |
19 | @mock.patch("logging.warning")
20 | def test_fill_missing_keys(mock_warning):
21 | cfg = get_default_config()
22 | model = HookedTransformer(cfg)
23 | default_state_dict = model.state_dict()
24 |
25 | incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "W_" not in k}
26 |
27 | filled_state_dict = fill_missing_keys(model, incomplete_state_dict)
28 |
29 | assert set(filled_state_dict.keys()) == set(default_state_dict.keys())
30 |
31 | # Check that warnings were issued for missing weight matrices
32 | for key in default_state_dict:
33 | if "W_" in key and key not in incomplete_state_dict:
34 | mock_warning.assert_any_call(
35 | f"Missing key for a weight matrix in pretrained, filled in with an empty tensor: {key}"
36 | )
37 |
38 |
39 | def test_fill_missing_keys_with_hf_model_keys():
40 | cfg = get_default_config()
41 | model = HookedTransformer(cfg)
42 | default_state_dict = model.state_dict()
43 |
44 | incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "hf_model" not in k}
45 |
46 | filled_state_dict = fill_missing_keys(model, incomplete_state_dict)
47 |
48 | expected_keys = set(default_state_dict.keys()) - {
49 | k for k in default_state_dict.keys() if "hf_model" in k
50 | }
51 | assert set(filled_state_dict.keys()) == expected_keys
52 |
53 |
54 | def test_fill_missing_keys_no_missing_keys():
55 | cfg = get_default_config()
56 | model = HookedTransformer(cfg)
57 | default_state_dict = model.state_dict()
58 |
59 | filled_state_dict = fill_missing_keys(model, default_state_dict)
60 |
61 | assert filled_state_dict == default_state_dict
62 |
63 |
64 | # Failures
65 |
66 |
67 | def test_fill_missing_keys_raises_error_on_invalid_model():
68 | invalid_model = None
69 | default_state_dict = {}
70 |
71 | with pytest.raises(AttributeError):
72 | fill_missing_keys(invalid_model, default_state_dict)
73 |
--------------------------------------------------------------------------------
/tests/unit/test_make_docs.py:
--------------------------------------------------------------------------------
1 | """Make Docs Tests."""
2 |
3 | import pytest
4 |
5 | from docs.make_docs import get_config, get_property
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 |
8 |
9 | def test_get_config():
10 | """Test get config with attn-only-1l model."""
11 | config: HookedTransformerConfig = get_config("attn-only-1l")
12 | assert config.attn_only is True
13 |
14 |
15 | def test_get_property():
16 | """Test get property with attn-only-1l model."""
17 | act_fn = get_property("act_fn", "attn-only-1l")
18 | assert act_fn == "attn_only"
19 |
20 | n_params = get_property("n_params", "attn-only-1l")
21 | assert n_params == "1.0M"
22 |
23 | n_layers = get_property("n_layers", "attn-only-1l")
24 | assert n_layers == 1
25 |
26 | d_model = get_property("d_model", "attn-only-1l")
27 | assert d_model == 512
28 |
29 | n_heads = get_property("n_heads", "attn-only-1l")
30 | assert n_heads == 8
31 |
32 | n_ctx = get_property("n_ctx", "attn-only-1l")
33 | assert n_ctx == 1024
34 |
35 | d_vocab = get_property("d_vocab", "attn-only-1l")
36 | assert d_vocab == 48262
37 |
38 | d_head = get_property("d_head", "attn-only-1l")
39 | assert d_head == 64
40 |
41 | d_mlp = get_property("d_mlp", "attn-only-1l")
42 | assert d_mlp == 2048
43 |
44 | n_key_value_heads = get_property("n_key_value_heads", "attn-only-1l")
45 | assert n_key_value_heads is None
46 |
47 | # Test an unknown property
48 | with pytest.raises(KeyError):
49 | get_property("unknown_property", "attn-only-1l")
50 |
--------------------------------------------------------------------------------
/tests/unit/test_split_qkv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens import HookedTransformer
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def test_split_qkv_normal_attn_correct():
8 | """Verifies that the split_qkv_input flag does not change the output for models with normal attention."""
9 | d_model = 128
10 | d_head = 8
11 | n_heads = 16
12 | n_ctx = 128
13 | n_layers = 1
14 | d_vocab = 10
15 |
16 | cfg = HookedTransformerConfig(
17 | d_model=d_model,
18 | d_head=d_head,
19 | n_heads=n_heads,
20 | n_ctx=n_ctx,
21 | n_layers=n_layers,
22 | attn_only=True,
23 | d_vocab=d_vocab,
24 | )
25 |
26 | model = HookedTransformer(cfg)
27 | assert model.cfg.use_split_qkv_input is False
28 |
29 | x = torch.arange(1, 9).unsqueeze(0)
30 | normal_output = model(x)
31 |
32 | model.set_use_split_qkv_input(True)
33 | assert model.cfg.use_split_qkv_input is True
34 |
35 | split_output = model(x)
36 |
37 | assert torch.allclose(normal_output, split_output, atol=1e-6)
38 |
39 |
40 | def test_split_qkv_grouped_query_attn_correct():
41 | """Verifies that the split_qkv_input flag does not change the output for models with grouped query attention."""
42 |
43 | d_model = 128
44 | d_head = 8
45 | n_heads = 16
46 | n_ctx = 128
47 | n_key_value_heads = 2
48 | n_layers = 1
49 | d_vocab = 10
50 |
51 | cfg = HookedTransformerConfig(
52 | d_model=d_model,
53 | d_head=d_head,
54 | n_heads=n_heads,
55 | n_ctx=n_ctx,
56 | n_key_value_heads=n_key_value_heads,
57 | n_layers=n_layers,
58 | attn_only=True,
59 | d_vocab=d_vocab,
60 | )
61 |
62 | model = HookedTransformer(cfg)
63 | assert model.cfg.use_split_qkv_input is False
64 |
65 | x = torch.arange(1, 9).unsqueeze(0)
66 | normal_output = model(x)
67 |
68 | model.set_use_split_qkv_input(True)
69 | assert model.cfg.use_split_qkv_input is True
70 |
71 | split_output = model(x)
72 |
73 | assert torch.allclose(normal_output, split_output, atol=1e-6)
74 |
--------------------------------------------------------------------------------
/tests/unit/test_use_attn_result.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_lens import HookedTransformer
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def test_attn_result_normal_attn_correct():
8 |     """Verifies that the use_attn_result flag does not change the output for models with normal attention."""
9 | d_model = 128
10 | d_head = 8
11 | n_heads = 16
12 | n_ctx = 128
13 | n_layers = 1
14 | d_vocab = 10
15 |
16 | cfg = HookedTransformerConfig(
17 | d_model=d_model,
18 | d_head=d_head,
19 | n_heads=n_heads,
20 | n_ctx=n_ctx,
21 | n_layers=n_layers,
22 | attn_only=True,
23 | d_vocab=d_vocab,
24 | )
25 |
26 | model = HookedTransformer(cfg)
27 |     assert model.cfg.use_attn_result is False
28 |
29 | x = torch.arange(1, 9).unsqueeze(0)
30 | normal_output = model(x)
31 |
32 | model.set_use_attn_result(True)
33 | assert model.cfg.use_attn_result is True
34 |
35 | split_output = model(x)
36 |
37 | assert torch.allclose(normal_output, split_output, atol=1e-6)
38 |
39 |
40 | def test_attn_result_grouped_query_attn_correct():
41 | """Verifies that the attn_result flag does not change the output for models with grouped query attention."""
42 |
43 | d_model = 128
44 | d_head = 8
45 | n_heads = 16
46 | n_ctx = 128
47 | n_key_value_heads = 2
48 | n_layers = 1
49 | d_vocab = 10
50 |
51 | cfg = HookedTransformerConfig(
52 | d_model=d_model,
53 | d_head=d_head,
54 | n_heads=n_heads,
55 | n_ctx=n_ctx,
56 | n_key_value_heads=n_key_value_heads,
57 | n_layers=n_layers,
58 | attn_only=True,
59 | d_vocab=d_vocab,
60 | )
61 |
62 | model = HookedTransformer(cfg)
63 | assert model.cfg.use_attn_result is False
64 |
65 | x = torch.arange(1, 9).unsqueeze(0)
66 | normal_output = model(x)
67 |
68 | model.set_use_attn_result(True)
69 | assert model.cfg.use_attn_result is True
70 |
71 | split_output = model(x)
72 |
73 | assert torch.allclose(normal_output, split_output, atol=1e-6)
74 |
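As with the split-QKV tests, these only check that `use_attn_result` leaves the logits unchanged; the flag's purpose is to materialise each head's individual contribution to the residual stream under `blocks.{l}.attn.hook_result`. A minimal sketch, assuming the same toy config as the tests; summing the per-head results and adding `b_O` should recover `hook_attn_out` up to floating-point error:

import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, attn_only=True, d_vocab=10
)
model = HookedTransformer(cfg)
model.set_use_attn_result(True)

x = torch.arange(1, 9).unsqueeze(0)
_, cache = model.run_with_cache(x)

result = cache["blocks.0.attn.hook_result"]  # per-head output, [batch, pos, n_heads, d_model]
attn_out = cache["blocks.0.hook_attn_out"]   # summed output, [batch, pos, d_model]

# Summing over heads (plus the output bias) reconstructs the layer's attention output.
assert torch.allclose(result.sum(dim=2) + model.blocks[0].attn.b_O, attn_out, atol=1e-5)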
--------------------------------------------------------------------------------
/tests/unit/utilities/test_devices.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import torch
4 |
5 | from transformer_lens.utilities.devices import (
6 | calculate_available_device_cuda_memory,
7 | determine_available_memory_for_available_devices,
8 | sort_devices_based_on_available_memory,
9 | )
10 |
11 |
12 | def mock_available_devices(memory_stats: list[tuple[int, int]]):
13 | torch.cuda.device_count = Mock(return_value=len(memory_stats))
14 |
15 | def device_props_return(*args, **kwargs):
16 | total_memory = memory_stats[args[0]][0]
17 | device_props = Mock()
18 | device_props.total_memory = total_memory
19 | return device_props
20 |
21 | def memory_allocated_return(*args, **kwargs):
22 | return memory_stats[args[0]][1]
23 |
24 | torch.cuda.get_device_properties = Mock(side_effect=device_props_return)
25 | torch.cuda.memory_allocated = Mock(side_effect=memory_allocated_return)
26 |
27 |
28 | def test_calculate_available_device_cuda_memory():
29 | mock_available_devices([(80, 40)])
30 |
31 | result = calculate_available_device_cuda_memory(0)
32 | assert result == 40
33 |
34 |
35 | def test_determine_available_memory_for_available_devices():
36 | mock_available_devices(
37 | [
38 | (80, 60),
39 | (80, 15),
40 | (80, 40),
41 | ]
42 | )
43 |
44 | result = determine_available_memory_for_available_devices(3)
45 |
46 | assert result == [
47 | (0, 20),
48 | (1, 65),
49 | (2, 40),
50 | ]
51 |
52 |
53 | def test_sort_devices_based_on_available_memory():
54 | devices = [
55 | (0, 20),
56 | (1, 65),
57 | (2, 40),
58 | ]
59 |
60 | result = sort_devices_based_on_available_memory(devices)
61 |
62 | assert result == [
63 | (1, 65),
64 | (2, 40),
65 | (0, 20),
66 | ]
67 |
--------------------------------------------------------------------------------
/transformer_lens/__init__.py:
--------------------------------------------------------------------------------
1 | from . import hook_points
2 | from . import utils
3 | from . import evals
4 | from .past_key_value_caching import (
5 | HookedTransformerKeyValueCache,
6 | HookedTransformerKeyValueCacheEntry,
7 | )
8 | from . import components
9 | from . import factories
10 | from .HookedTransformerConfig import HookedTransformerConfig
11 | from .FactoredMatrix import FactoredMatrix
12 | from .ActivationCache import ActivationCache
13 | from .HookedTransformer import HookedTransformer
14 | from .SVDInterpreter import SVDInterpreter
15 | from .HookedEncoder import HookedEncoder
16 | from .HookedEncoderDecoder import HookedEncoderDecoder
17 | from .BertNextSentencePrediction import BertNextSentencePrediction
18 | from . import head_detector
19 | from . import loading_from_pretrained as loading
20 | from . import patching
21 | from . import train
22 |
23 | from .past_key_value_caching import (
24 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache,
25 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry,
26 | )
27 | from .HookedTransformer import HookedTransformer as EasyTransformer
28 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig
29 |
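The trailing imports re-export the current classes under their legacy `EasyTransformer` names so that pre-rename code keeps working. The aliases are the same objects, not copies, which is easy to confirm:

import transformer_lens

# The legacy names are straight aliases for the current classes.
assert transformer_lens.EasyTransformer is transformer_lens.HookedTransformer
assert transformer_lens.EasyTransformerConfig is transformer_lens.HookedTransformerConfig
assert (
    transformer_lens.EasyTransformerKeyValueCache
    is transformer_lens.HookedTransformerKeyValueCache
)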
--------------------------------------------------------------------------------
/transformer_lens/components/__init__.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Components.
2 |
3 | This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
4 | needed to create many different types of generative language models. They are used by
5 | :class:`transformer_lens.HookedTransformer`.
6 | """
7 |
8 | # Independent classes
9 | from .abstract_attention import AbstractAttention
10 | from .layer_norm import LayerNorm
11 | from .layer_norm_pre import LayerNormPre
12 | from .pos_embed import PosEmbed
13 | from .rms_norm import RMSNorm
14 | from .rms_norm_pre import RMSNormPre
15 | from .token_typed_embed import TokenTypeEmbed
16 | from .unembed import Unembed
17 |
18 | # Only dependent on independent modules
19 | from .attention import Attention
20 | from .bert_mlm_head import BertMLMHead
21 | from .bert_nsp_head import BertNSPHead
22 | from .bert_pooler import BertPooler
23 | from .embed import Embed
24 | from .grouped_query_attention import GroupedQueryAttention
25 | from .mlps.gated_mlp import GatedMLP
26 | from .mlps.mlp import MLP
27 |
28 | # Interdependent modules
29 | from .bert_block import BertBlock
30 | from .bert_embed import BertEmbed
31 | from .mlps.moe import MoE
32 | from .transformer_block import TransformerBlock
33 | from .t5_attention import T5Attention
34 | from .t5_block import T5Block
35 |
--------------------------------------------------------------------------------
/transformer_lens/components/attention.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Attention Component.
2 |
3 | This module contains the :class:`Attention` component.
4 | """
5 | from typing import Dict, Optional, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from transformers.utils import is_bitsandbytes_available
10 |
11 | from transformer_lens.components import AbstractAttention
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 | if is_bitsandbytes_available():
15 | from bitsandbytes.nn.modules import Params4bit
16 |
17 |
18 | # Attention
19 | class Attention(AbstractAttention):
20 | def __init__(
21 | self,
22 | cfg: Union[Dict, HookedTransformerConfig],
23 | attn_type: str = "global",
24 | layer_id: Optional[int] = None,
25 | ):
26 | """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax
27 |
28 | Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]
29 |
30 | Args:
31 | cfg (Union[Dict, HookedTransformerConfig]): Config
32 | attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
33 | layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre-softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
34 | """
35 | super().__init__(cfg, attn_type, layer_id)
36 | self.cfg = HookedTransformerConfig.unwrap(cfg)
37 |
38 | if self.cfg.load_in_4bit:
39 | # 4-bit quantization convention
40 | nq = int((self.cfg.d_model * self.cfg.d_model) / 2)
41 | self.W_K = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
42 | self.W_V = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
43 | else:
44 | self.W_K = nn.Parameter(
45 | torch.empty(
46 | self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
47 | )
48 | )
49 | self.W_V = nn.Parameter(
50 | torch.empty(
51 | self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
52 | )
53 | )
54 | self.b_K = nn.Parameter(
55 | torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
56 | )
57 | self.b_V = nn.Parameter(
58 | torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
59 | )
60 |
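The docstring's shape convention (W_K and W_V as [head_index, d_model, d_head], multiplying on the right) can be checked directly by instantiating the component with a small config. A minimal sketch, assuming a CPU config without 4-bit quantization so the `nn.Parameter` branch above is taken:

from transformer_lens.components import Attention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    d_model=64, d_head=16, n_heads=4, n_ctx=32, n_layers=1, attn_only=True, d_vocab=10
)
attn = Attention(cfg)

# W_K / W_V follow the [head_index, d_model, d_head] convention from the docstring,
# and the per-head biases are [head_index, d_head].
assert attn.W_K.shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
assert attn.W_V.shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
assert attn.b_K.shape == (cfg.n_heads, cfg.d_head)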
--------------------------------------------------------------------------------
/transformer_lens/components/bert_block.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Bert Block Component.
2 |
3 | This module contains the :class:`BertBlock` component.
4 | """
5 | from typing import Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.components import Attention, LayerNorm
12 | from transformer_lens.factories.mlp_factory import MLPFactory
13 | from transformer_lens.hook_points import HookPoint
14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15 | from transformer_lens.utils import repeat_along_head_dimension
16 |
17 |
18 | class BertBlock(nn.Module):
19 | """
20 | BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.
21 | """
22 |
23 | def __init__(self, cfg: HookedTransformerConfig):
24 | super().__init__()
25 | self.cfg = cfg
26 |
27 | self.attn = Attention(cfg)
28 | self.ln1 = LayerNorm(cfg)
29 | self.mlp = MLPFactory.create_mlp(self.cfg)
30 | self.ln2 = LayerNorm(cfg)
31 |
32 | self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model]
33 | self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model]
34 | self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model]
35 |
36 | self.hook_attn_out = HookPoint() # [batch, pos, d_model]
37 | self.hook_mlp_in = HookPoint() # [batch, pos, d_model]
38 | self.hook_mlp_out = HookPoint() # [batch, pos, d_model]
39 | self.hook_resid_pre = HookPoint() # [batch, pos, d_model]
40 | self.hook_resid_mid = HookPoint() # [batch, pos, d_model]
41 | self.hook_resid_post = HookPoint() # [batch, pos, d_model]
42 | self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model]
43 |
44 | def forward(
45 | self,
46 | resid_pre: Float[torch.Tensor, "batch pos d_model"],
47 | additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
48 | ) -> Float[torch.Tensor, "batch pos d_model"]:
49 | resid_pre = self.hook_resid_pre(resid_pre)
50 |
51 | query_input = resid_pre
52 | key_input = resid_pre
53 | value_input = resid_pre
54 |
55 | if self.cfg.use_split_qkv_input:
56 | n_heads = self.cfg.n_heads
57 | query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads))
58 | key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads))
59 | value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads))
60 |
61 | attn_out = self.hook_attn_out(
62 | self.attn(
63 | query_input,
64 | key_input,
65 | value_input,
66 | additive_attention_mask=additive_attention_mask,
67 | )
68 | )
69 | resid_mid = self.hook_resid_mid(resid_pre + attn_out)
70 |
71 | mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
72 | normalized_resid_mid = self.ln1(mlp_in)
73 | mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
74 | resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
75 | normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))
76 |
77 | return normalized_resid_post
78 |
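Because this block is post-LN, the ordering is easy to misread: each LayerNorm is applied after its sublayer, and the MLP's residual add uses the already-normalised stream. A schematic paraphrase of the forward pass above (not a drop-in replacement; `attn`, `mlp`, `ln1`, `ln2` stand for the modules):

def bert_block_dataflow(resid_pre, attn, mlp, ln1, ln2):
    # Attention reads the raw residual stream, then is added back to it.
    resid_mid = resid_pre + attn(resid_pre, resid_pre, resid_pre)
    # Post-LN: the LayerNorm comes *after* the attention sublayer.
    normalized_resid_mid = ln1(resid_mid)
    # Note that the MLP's residual add uses the normalised stream, mirroring the code above.
    resid_post = normalized_resid_mid + mlp(normalized_resid_mid)
    return ln2(resid_post)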
--------------------------------------------------------------------------------
/transformer_lens/components/bert_embed.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Bert Embed Component.
2 |
3 | This module contains the :class:`BertEmbed` component.
4 | """
5 | from typing import Dict, Optional, Union
6 |
7 | import einops
8 | import torch
9 | import torch.nn as nn
10 | from jaxtyping import Float, Int
11 |
12 | from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed
13 | from transformer_lens.hook_points import HookPoint
14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15 |
16 |
17 | class BertEmbed(nn.Module):
18 | """
19 | Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
20 | """
21 |
22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
23 | super().__init__()
24 | self.cfg = HookedTransformerConfig.unwrap(cfg)
25 | self.embed = Embed(self.cfg)
26 | self.pos_embed = PosEmbed(self.cfg)
27 | self.token_type_embed = TokenTypeEmbed(self.cfg)
28 | self.ln = LayerNorm(self.cfg)
29 |
30 | self.hook_embed = HookPoint()
31 | self.hook_pos_embed = HookPoint()
32 | self.hook_token_type_embed = HookPoint()
33 |
34 | def forward(
35 | self,
36 | input_ids: Int[torch.Tensor, "batch pos"],
37 | token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
38 | ) -> Float[torch.Tensor, "batch pos d_model"]:
39 | base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device)
40 | index_ids = einops.repeat(base_index_id, "pos -> batch pos", batch=input_ids.shape[0])
41 | if token_type_ids is None:
42 | token_type_ids = torch.zeros_like(input_ids)
43 |
44 | word_embeddings_out = self.hook_embed(self.embed(input_ids))
45 | position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids))
46 | token_type_embeddings_out = self.hook_token_type_embed(
47 | self.token_type_embed(token_type_ids)
48 | )
49 |
50 | embeddings_out = word_embeddings_out + position_embeddings_out + token_type_embeddings_out
51 | layer_norm_out = self.ln(embeddings_out)
52 | return layer_norm_out
53 |
--------------------------------------------------------------------------------
/transformer_lens/components/bert_mlm_head.py:
--------------------------------------------------------------------------------
1 | """Hooked Encoder Bert MLM Head Component.
2 |
3 | This module contains the :class:`BertMLMHead` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.components import LayerNorm
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class BertMLMHead(nn.Module):
16 | """
17 | Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence.
18 | """
19 |
20 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
21 | super().__init__()
22 | self.cfg = HookedTransformerConfig.unwrap(cfg)
23 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype))
24 | self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
25 | self.act_fn = nn.GELU()
26 | self.ln = LayerNorm(self.cfg)
27 |
28 | def forward(
29 | self, resid: Float[torch.Tensor, "batch pos d_model"]
30 | ) -> Float[torch.Tensor, "batch pos d_model"]:
31 | resid = torch.matmul(resid, self.W) + self.b
32 | resid = self.act_fn(resid)
33 | resid = self.ln(resid)
34 | return resid
35 |
--------------------------------------------------------------------------------
/transformer_lens/components/bert_nsp_head.py:
--------------------------------------------------------------------------------
1 | """Hooked Encoder Bert NSP Head Component.
2 |
3 | This module contains the :class:`BertNSPHead` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class BertNSPHead(nn.Module):
16 | """
17 | Transforms BERT embeddings into logits. The purpose of this module is to predict whether or not sentence B follows sentence A.
18 | """
19 |
20 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
21 | super().__init__()
22 | self.cfg = HookedTransformerConfig.unwrap(cfg)
23 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, 2, dtype=self.cfg.dtype))
24 | self.b = nn.Parameter(torch.zeros(2, dtype=self.cfg.dtype))
25 | self.hook_nsp_out = HookPoint()
26 |
27 | def forward(
28 | self, resid: Float[torch.Tensor, "batch d_model"]
29 | ) -> Float[torch.Tensor, "batch 2"]:
30 | nsp_logits = torch.matmul(resid, self.W) + self.b
31 | return self.hook_nsp_out(nsp_logits)
32 |
--------------------------------------------------------------------------------
/transformer_lens/components/bert_pooler.py:
--------------------------------------------------------------------------------
1 | """Hooked Encoder Bert Pooler Component.
2 |
3 | This module contains the :class:`BertPooler` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class BertPooler(nn.Module):
16 | """
17 | Transforms the [CLS] token representation into a fixed-size sequence embedding.
18 | The purpose of this module is to convert variable-length sequence inputs into a single vector representation suitable for downstream tasks.
19 | (e.g. Next Sentence Prediction)
20 | """
21 |
22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
23 | super().__init__()
24 | self.cfg = HookedTransformerConfig.unwrap(cfg)
25 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype))
26 | self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
27 | self.activation = nn.Tanh()
28 | self.hook_pooler_out = HookPoint()
29 |
30 | def forward(
31 | self, resid: Float[torch.Tensor, "batch pos d_model"]
32 | ) -> Float[torch.Tensor, "batch d_model"]:
33 | first_token_tensor = resid[:, 0]
34 | pooled_output = torch.matmul(first_token_tensor, self.W) + self.b
35 | pooled_output = self.hook_pooler_out(self.activation(pooled_output))
36 | return pooled_output
37 |
--------------------------------------------------------------------------------
/transformer_lens/components/embed.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Embed Component.
2 |
3 | This module contains the :class:`Embed` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float, Int
10 |
11 | from transformer_lens.components import LayerNorm
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | # Embed & Unembed
16 | class Embed(nn.Module):
17 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
18 | super().__init__()
19 | self.cfg = HookedTransformerConfig.unwrap(cfg)
20 | self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
21 | torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype)
22 | )
23 | # Some models (e.g. Bloom) need post embedding layer norm
24 | if self.cfg.post_embedding_ln:
25 | self.ln = LayerNorm(self.cfg)
26 |
27 | def forward(
28 | self, tokens: Int[torch.Tensor, "batch pos"]
29 | ) -> Float[torch.Tensor, "batch pos d_model"]:
30 | # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
31 | # B acts as a tensor of indices into the second dimension (so >=0 and <b)
32 | if self.cfg.post_embedding_ln:
33 | return self.ln(self.W_E[tokens, :])
34 | return self.W_E[tokens, :]
35 |
--------------------------------------------------------------------------------
/transformer_lens/components/layer_norm.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Layer Norm Component.
2 |
3 | This module contains the :class:`LayerNorm` component.
4 | """
5 | from typing import Dict, Optional, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class LayerNorm(nn.Module):
16 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
17 | """
18 | LayerNorm with optional length parameter
19 |
20 | length (Optional[int]): The dimension of the LayerNorm. If not provided, assumed to be d_model
21 | """
22 | super().__init__()
23 | self.cfg = HookedTransformerConfig.unwrap(cfg)
24 | self.eps = self.cfg.eps
25 | if length is None:
26 | self.length = self.cfg.d_model
27 | else:
28 | self.length = length
29 |
30 | self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
31 | self.b = nn.Parameter(torch.zeros(self.length, dtype=self.cfg.dtype))
32 |
33 | # Adds a hook point for the normalisation scale factor
34 | self.hook_scale = HookPoint() # [batch, pos, 1]
35 | # hook_normalized is on the LN output
36 | self.hook_normalized = HookPoint() # [batch, pos, length]
37 |
38 | def forward(
39 | self,
40 | x: Union[
41 | Float[torch.Tensor, "batch pos d_model"],
42 | Float[torch.Tensor, "batch pos head_index d_model"],
43 | ],
44 | ) -> Union[
45 | Float[torch.Tensor, "batch pos d_model"],
46 | Float[torch.Tensor, "batch pos head_index d_model"],
47 | ]:
48 | if self.cfg.dtype not in [torch.float32, torch.float64]:
49 | x = x.to(torch.float32)
50 |
51 | x = x - x.mean(-1, keepdim=True) # [batch, pos, length]
52 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
53 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
54 | )
55 | x = x / scale # [batch, pos, length]
56 | return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype)
57 |
--------------------------------------------------------------------------------
/transformer_lens/components/layer_norm_pre.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Layer Norm Pre Component.
2 |
3 | This module contains the :class:`LayerNormPre` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | # LayerNormPre
16 | # I fold the LayerNorm weights and biases into later weights and biases.
17 | # This is just the 'center and normalise' part of LayerNorm
18 | # Centering is equivalent to just deleting one direction of residual space,
19 | # and is equivalent to centering the weight matrices of everything writing to the residual stream
20 | # Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
21 | class LayerNormPre(nn.Module):
22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
23 | """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
24 | normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
25 | should only be used in inference mode after folding in LayerNorm weights"""
26 | super().__init__()
27 | self.cfg = HookedTransformerConfig.unwrap(cfg)
28 | self.eps = self.cfg.eps
29 |
30 | # Adds a hook point for the normalisation scale factor
31 | self.hook_scale = HookPoint() # [batch, pos]
32 | # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0
33 | self.hook_normalized = HookPoint() # [batch, pos, length]
34 |
35 | def forward(
36 | self,
37 | x: Union[
38 | Float[torch.Tensor, "batch pos d_model"],
39 | Float[torch.Tensor, "batch pos head_index d_model"],
40 | ],
41 | ) -> Union[
42 | Float[torch.Tensor, "batch pos d_model"],
43 | Float[torch.Tensor, "batch pos head_index d_model"],
44 | ]:
45 | if self.cfg.dtype not in [torch.float32, torch.float64]:
46 | x = x.to(torch.float32)
47 |
48 | x = x - x.mean(-1, keepdim=True) # [batch, pos, length]
49 | scale: Union[
50 | Float[torch.Tensor, "batch pos 1"],
51 | Float[torch.Tensor, "batch pos head_index 1"],
52 | ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
53 | return self.hook_normalized(x / scale).to(self.cfg.dtype)
54 |
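The comments above describe LayerNormPre as the "center and normalise" part of LayerNorm: subtract the per-vector mean, then divide by the RMS of the centred vector, which (up to `eps`) leaves each vector with zero mean and unit standard deviation. A quick numeric check of that claim with plain torch:

import torch

x = torch.randn(2, 5, 16) * 3 + 1  # arbitrary scale and offset
eps = 1e-5

x_centered = x - x.mean(-1, keepdim=True)
scale = (x_centered.pow(2).mean(-1, keepdim=True) + eps).sqrt()
x_norm = x_centered / scale

assert torch.allclose(x_norm.mean(-1), torch.zeros(2, 5), atol=1e-5)                # zero mean
assert torch.allclose(x_norm.pow(2).mean(-1).sqrt(), torch.ones(2, 5), atol=1e-3)   # unit RMS (= std here)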
--------------------------------------------------------------------------------
/transformer_lens/components/mlps/can_be_used_as_mlp.py:
--------------------------------------------------------------------------------
1 | """Can Be Used as MLP component.
2 |
3 | This module serves as the base for everything within TransformerLens that can be used like an MLP.
4 | This does not necessarily mean that every component extending this class will be an MLP, but
5 | everything extending this class can be used interchangeably for an MLP.
6 | """
7 | from typing import Dict, Optional, Union
8 |
9 | import torch
10 | import torch.nn as nn
11 | from jaxtyping import Float
12 |
13 | from transformer_lens.components import LayerNorm, LayerNormPre
14 | from transformer_lens.factories.activation_function_factory import (
15 | ActivationFunctionFactory,
16 | )
17 | from transformer_lens.hook_points import HookPoint
18 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
19 | from transformer_lens.utilities.activation_functions import ActivationFunction
20 |
21 |
22 | class CanBeUsedAsMLP(nn.Module):
23 | # The actual activation function
24 | act_fn: ActivationFunction
25 |
26 | # The full config object for the model
27 | cfg: HookedTransformerConfig
28 |
29 | # The d mlp value pulled out of the config to make sure it always has a value
30 | d_mlp: int
31 |
32 | # The middle hook point will be None unless it specifically should be used
33 | hook_mid: Optional[HookPoint] # [batch, pos, d_mlp]
34 |
35 | # The layer norm component if the activation function is a layer norm
36 | ln: Optional[nn.Module]
37 |
38 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
39 | """The base init for all MLP like components
40 |
41 | Args:
42 | cfg (Union[Dict, HookedTransformerConfig]): The config for this instance
43 |
44 | Raises:
45 | ValueError: If there is a misconfiguration
46 | """
47 | super().__init__()
48 | self.cfg = HookedTransformerConfig.unwrap(cfg)
49 | if self.cfg.d_mlp is None:
50 | raise ValueError("d_mlp must be set to use an MLP")
51 |
52 | self.d_mlp = self.cfg.d_mlp
53 |
54 | def forward(
55 | self, x: Float[torch.Tensor, "batch pos d_model"]
56 | ) -> Float[torch.Tensor, "batch pos d_model"]:
57 | """The format for all forward functions for any MLP"""
58 | return x
59 |
60 | def select_activation_function(self) -> None:
61 | """This function should be called by all components in their init to get everything needed
62 | for activation functions setup.
63 |
64 | Raises:
65 | ValueError: If the configured activation function is not supported.
66 | """
67 |
68 | self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)
69 |
70 | if self.cfg.is_layer_norm_activation():
71 | self.hook_mid = HookPoint()
72 | if self.cfg.normalization_type == "LN":
73 | self.ln = LayerNorm(self.cfg, self.d_mlp)
74 | else:
75 | self.ln = LayerNormPre(self.cfg)
76 |
--------------------------------------------------------------------------------
/transformer_lens/components/mlps/gated_mlp.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Gated MLP Component.
2 |
3 | This module contains the :class:`GatedMLP` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 | from transformers.utils import is_bitsandbytes_available
11 |
12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
13 | from transformer_lens.hook_points import HookPoint
14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15 | from transformer_lens.utilities.addmm import batch_addmm
16 |
17 | if is_bitsandbytes_available():
18 | pass
19 |
20 |
21 | class GatedMLP(CanBeUsedAsMLP):
22 | """
23 | The equation of a gated MLP:
24 | pre = x @ W_gate
25 | pre_linear = x @ W_in
26 | post = Gelu(pre) * (pre_linear) + b_in
27 | mlp_out = post @ W_out + b_out
28 |
29 | In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
30 | """
31 |
32 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
33 | super().__init__(cfg)
34 | self.select_activation_function()
35 | self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype))
36 | self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype))
37 | self.W_gate = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype))
38 |
39 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))
40 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
41 |
42 | # hook on gate output but before act_fn
43 | self.hook_pre = HookPoint() # [batch, pos, d_mlp]
44 | # hook on the linear component of the input
45 | self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp]
46 | # hook on act_fn(gate_output) * W_in(x) + b_in
47 | self.hook_post = HookPoint() # [batch, pos, d_mlp]
48 |
49 | def forward(
50 | self, x: Float[torch.Tensor, "batch pos d_model"]
51 | ) -> Float[torch.Tensor, "batch pos d_model"]:
52 | # Technically, the gate and input projections could be fused into a single matmul, but this is more readable.
53 | if self.W_gate.device != x.device:
54 | x = x.to(self.W_gate.device)
55 | pre_act = self.hook_pre(
56 | torch.matmul(x, self.W_gate) # batch pos d_model, d_model d_mlp -> batch pos d_mlp
57 | ) # [batch, pos, d_mlp]
58 |
59 | if (
60 | self.cfg.is_layer_norm_activation()
61 | and self.hook_mid is not None
62 | and self.ln is not None
63 | ):
64 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
65 | post_act = self.hook_post(self.ln(mid_act))
66 | else:
67 | pre_linear = self.hook_pre_linear(
68 | torch.matmul(x, self.W_in) # batch pos d_model, d_model d_mlp -> batch pos d_mlp
69 | )
70 |
71 | post_act = self.hook_post(
72 | (self.act_fn(pre_act) * pre_linear) + self.b_in
73 | ) # [batch, pos, d_mlp]
74 |
75 | return batch_addmm(self.b_out, self.W_out, post_act)
76 |
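The docstring's single-equation form can be sanity-checked with plain tensor ops. The sketch below mirrors the non-LayerNorm branch of `forward` with random weights; names like `W_gate` and `W_in` match the parameters above, but the values are arbitrary:

import torch

d_model, d_mlp, batch, pos = 16, 32, 2, 5
x = torch.randn(batch, pos, d_model)
W_gate = torch.randn(d_model, d_mlp)
W_in = torch.randn(d_model, d_mlp)
W_out = torch.randn(d_mlp, d_model)
b_in = torch.zeros(d_mlp)
b_out = torch.zeros(d_model)

# mlp_out = (gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
pre = x @ W_gate                                   # gate branch, pre-activation
pre_linear = x @ W_in                              # linear branch
post = torch.nn.functional.gelu(pre) * pre_linear + b_in
mlp_out = post @ W_out + b_out
assert mlp_out.shape == (batch, pos, d_model)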
--------------------------------------------------------------------------------
/transformer_lens/components/mlps/gated_mlp_4bit.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Gated MLP Component.
2 |
3 | This module contains all the component :class:`GatedMLP`.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 | from transformers.utils import is_bitsandbytes_available
11 |
12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
13 | from transformer_lens.hook_points import HookPoint
14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15 |
16 | if is_bitsandbytes_available():
17 | import bitsandbytes as bnb
18 | from bitsandbytes.nn.modules import Params4bit
19 |
20 |
21 | class GatedMLP4Bit(CanBeUsedAsMLP):
22 | """
23 | The equation of a gated MLP:
24 | pre = x @ W_gate
25 | pre_linear = x @ W_in
26 | post = Gelu(pre) * (pre_linear) + b_in
27 | mlp_out = post @ W_out + b_out
28 |
29 | In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
30 | """
31 |
32 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
33 | super().__init__(cfg)
34 | self.select_activation_function()
35 |
36 | nq = int((self.cfg.d_model * self.d_mlp) / 2)
37 | self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
38 | self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
39 | self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
40 |
41 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))
42 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
43 |
44 | # hook on gate output but before act_fn
45 | self.hook_pre = HookPoint() # [batch, pos, d_mlp]
46 | # hook on the linear component of the input
47 | self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp]
48 | # hook on act_fn(gate_output) * W_in(x) + b_in
49 | self.hook_post = HookPoint() # [batch, pos, d_mlp]
50 |
51 | def forward(
52 | self, x: Float[torch.Tensor, "batch pos d_model"]
53 | ) -> Float[torch.Tensor, "batch pos d_model"]:
54 | # Technically, the gate and input projections could be fused into a single matmul, but this is more readable.
55 | pre_act = self.hook_pre(
56 | bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state)
57 | )
58 |
59 | if (
60 | self.cfg.is_layer_norm_activation()
61 | and self.hook_mid is not None
62 | and self.ln is not None
63 | ):
64 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
65 | post_act = self.hook_post(self.ln(mid_act))
66 | else:
67 | pre_linear = self.hook_pre_linear(
68 | bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state)
69 | )
70 |
71 | post_act = self.hook_post(
72 | (self.act_fn(pre_act) * pre_linear) + self.b_in
73 | ) # [batch, pos, d_mlp]
74 |
75 | return bnb.matmul_4bit(
76 | post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state
77 | )
78 |
--------------------------------------------------------------------------------
/transformer_lens/components/mlps/mlp.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer MLP Component.
2 |
3 | This module contains the :class:`MLP` component.
4 | """
5 |
6 | from typing import Dict, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | from jaxtyping import Float
11 |
12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
13 | from transformer_lens.hook_points import HookPoint
14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15 | from transformer_lens.utilities.addmm import batch_addmm
16 |
17 |
18 | class MLP(CanBeUsedAsMLP):
19 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
20 | super().__init__(cfg)
21 | self.select_activation_function()
22 |
23 | self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype))
24 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))
25 |
26 | self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype))
27 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
28 |
29 | self.hook_pre = HookPoint() # [batch, pos, d_mlp]
30 | self.hook_post = HookPoint() # [batch, pos, d_mlp]
31 |
32 | def forward(
33 | self, x: Float[torch.Tensor, "batch pos d_model"]
34 | ) -> Float[torch.Tensor, "batch pos d_model"]:
35 | # This is equivalent to (roughly) W_in @ x + b_in. It's important to
36 | # use a fused addmm to ensure it matches the Huggingface implementation
37 | # exactly.
38 | pre_act = self.hook_pre(batch_addmm(self.b_in, self.W_in, x)) # [batch, pos, d_mlp]
39 |
40 | if (
41 | self.cfg.is_layer_norm_activation()
42 | and self.hook_mid is not None
43 | and self.ln is not None
44 | ):
45 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
46 | post_act = self.hook_post(self.ln(mid_act))
47 | else:
48 | post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp]
49 | return batch_addmm(self.b_out, self.W_out, post_act)
50 |
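The comment above leans on `batch_addmm`, which (as used here and in the other components) computes `x @ W + b` over the flattened batch and position dimensions with a fused addmm so results bit-match Hugging Face. A rough pure-torch equivalent for intuition only; the real helper lives in `transformer_lens.utilities.addmm`, and this sketch just asserts the mathematical relationship, not the exact kernel:

import torch

def batch_addmm_sketch(b, W, x):
    # x: [batch, pos, d_in], W: [d_in, d_out], b: [d_out]
    batch, pos, d_in = x.shape
    out = torch.addmm(b, x.reshape(batch * pos, d_in), W)  # fused b + x @ W
    return out.reshape(batch, pos, -1)

x = torch.randn(2, 3, 8)
W = torch.randn(8, 4)
b = torch.randn(4)
assert torch.allclose(batch_addmm_sketch(b, W, x), x @ W + b, atol=1e-6)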
--------------------------------------------------------------------------------
/transformer_lens/components/pos_embed.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer POS Embed Component.
2 |
3 | This module contains the :class:`PosEmbed` component.
4 | """
5 | from typing import Dict, Optional, Union
6 |
7 | import einops
8 | import torch
9 | import torch.nn as nn
10 | from jaxtyping import Float, Int
11 |
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 | from transformer_lens.utils import get_offset_position_ids
14 |
15 |
16 | # Positional Embeddings
17 | class PosEmbed(nn.Module):
18 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
19 | super().__init__()
20 | self.cfg = HookedTransformerConfig.unwrap(cfg)
21 | self.W_pos = nn.Parameter(
22 | torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype)
23 | )
24 |
25 | def forward(
26 | self,
27 | tokens: Int[torch.Tensor, "batch pos"],
28 | past_kv_pos_offset: int = 0,
29 | attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
30 | ) -> Float[torch.Tensor, "batch new_pos d_model"]:
31 | """
32 | Forward pass for positional embeddings.
33 |
34 | Args:
35 | tokens (Int[torch.Tensor, "batch pos"]): Input tokens.
36 | past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0.
37 | attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens.
38 | Defaults to None.
39 |
40 | Returns:
41 | Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings.
42 | """
43 | tokens_length = tokens.size(-1)
44 |
45 | if attention_mask is None:
46 | pos_embed = self.W_pos[
47 | past_kv_pos_offset : tokens_length + past_kv_pos_offset, :
48 | ] # [pos, d_model]
49 | batch_pos_embed = einops.repeat(
50 | pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
51 | )
52 |
53 | else:
54 | # Separated from the no padding case for computational efficiency
55 | # (this code is a bit slower than the code above)
56 |
57 | offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
58 | pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model]
59 |
60 | # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice)
61 | padding_mask = ~attention_mask.bool() # [batch, tokens_length]
62 | offset_padding_mask = padding_mask[
63 | :, past_kv_pos_offset : tokens_length + past_kv_pos_offset
64 | ].unsqueeze(
65 | -1
66 | ) # [batch, pos, 1]
67 | batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed)
68 |
69 | return batch_pos_embed.clone()
70 |
--------------------------------------------------------------------------------
/transformer_lens/components/rms_norm.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer RMS Norm Component.
2 |
3 | This module contains the :class:`RMSNorm` component.
4 | """
5 | from typing import Dict, Optional, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class RMSNorm(nn.Module):
16 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
17 | """
18 | RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square)
19 |
20 | length (Optional[int]): The dimension of the RMSNorm. If not provided, assumed to be d_model
21 | """
22 | super().__init__()
23 | self.cfg = HookedTransformerConfig.unwrap(cfg)
24 | self.eps = self.cfg.eps
25 | if length is None:
26 | self.length = self.cfg.d_model
27 | else:
28 | self.length = length
29 |
30 | self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
31 |
32 | # Adds a hook point for the normalisation scale factor
33 | self.hook_scale = HookPoint() # [batch, pos, 1]
34 | self.hook_normalized = HookPoint() # [batch, pos, length]
35 |
36 | def forward(
37 | self, x: Float[torch.Tensor, "batch pos length"]
38 | ) -> Float[torch.Tensor, "batch pos length"]:
39 | if self.cfg.dtype not in [torch.float32, torch.float64]:
40 | x = x.to(torch.float32)
41 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
42 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
43 | )
44 | x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length]
45 |
46 | if x.device != self.w.device:
47 | self.to(x.device)
48 |
49 | return x * self.w
50 |
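RMSNorm is the scale-only part of LayerNorm: divide by the root-mean-square (plus `eps`) and multiply by a learned `w`, with no centering and no bias. A quick numerical cross-check against the module, assuming a minimal attention-only config just to satisfy the constructor:

import torch

from transformer_lens.components import RMSNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    d_model=16, d_head=4, n_heads=4, n_ctx=8, n_layers=1, attn_only=True, d_vocab=10
)
rms = RMSNorm(cfg)  # w is initialised to ones

x = torch.randn(2, 3, cfg.d_model)
manual = x / (x.pow(2).mean(-1, keepdim=True) + cfg.eps).sqrt() * rms.w
assert torch.allclose(rms(x), manual, atol=1e-6)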
--------------------------------------------------------------------------------
/transformer_lens/components/rms_norm_pre.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer RMS Norm Pre Component.
2 |
3 | This module contains the :class:`RMSNormPre` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Float
10 |
11 | from transformer_lens.hook_points import HookPoint
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 |
14 |
15 | class RMSNormPre(nn.Module):
16 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
17 | """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)"""
18 | super().__init__()
19 | self.cfg = HookedTransformerConfig.unwrap(cfg)
20 | self.eps = self.cfg.eps
21 |
22 | # Adds a hook point for the normalisation scale factor
23 | self.hook_scale = HookPoint() # [batch, pos]
24 | self.hook_normalized = HookPoint() # [batch, pos, length]
25 |
26 | def forward(
27 | self, x: Float[torch.Tensor, "batch pos length"]
28 | ) -> Float[torch.Tensor, "batch pos length"]:
29 | if self.cfg.dtype not in [torch.float32, torch.float64]:
30 | x = x.to(torch.float32)
31 |
32 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
33 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
34 | )
35 | return self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length]
36 |
--------------------------------------------------------------------------------
/transformer_lens/components/token_typed_embed.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Token Typed Embed Component.
2 |
3 | This module contains the :class:`TokenTypeEmbed` component.
4 | """
5 | from typing import Dict, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | from jaxtyping import Int
10 |
11 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
12 |
13 |
14 | class TokenTypeEmbed(nn.Module):
15 | """
16 | The token-type embedding is a binary id indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, the shape is (batch_size, sequence_length).
17 |
18 | See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
19 | """
20 |
21 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
22 | super().__init__()
23 | self.cfg = HookedTransformerConfig.unwrap(cfg)
24 | self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=self.cfg.dtype))
25 |
26 | def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
27 | return self.W_token_type[token_type_ids, :]
28 |
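To make the docstring concrete: for a pair like "[CLS] Sentence A [SEP] Sentence B [SEP]", the ids are 0 up to and including the first [SEP] and 1 afterwards, and the forward pass is a lookup into a 2-row table. A toy illustration with made-up segment lengths (pure torch, no tokenizer assumed):

import torch

# 4 tokens from segment A (incl. [CLS] and the first [SEP]), 3 tokens from segment B.
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1]])  # [batch=1, pos=7]

d_model = 8
W_token_type = torch.randn(2, d_model)  # same 2-row table shape as the module above

# The forward pass is just an embedding lookup into that table.
segment_embeddings = W_token_type[token_type_ids, :]
assert segment_embeddings.shape == (1, 7, d_model)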
--------------------------------------------------------------------------------
/transformer_lens/components/unembed.py:
--------------------------------------------------------------------------------
1 | """Hooked Transformer Unembed Component.
2 |
3 | This module contains the :class:`Unembed` component.
4 | """
5 |
6 | from typing import Dict, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | from jaxtyping import Float
11 |
12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
13 | from transformer_lens.utilities.addmm import batch_addmm
14 |
15 |
16 | class Unembed(nn.Module):
17 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
18 | super().__init__()
19 | self.cfg = HookedTransformerConfig.unwrap(cfg)
20 | # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different.
21 | self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
22 | torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype)
23 | )
24 | self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
25 | torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype)
26 | )
27 |
28 | def forward(
29 | self, residual: Float[torch.Tensor, "batch pos d_model"]
30 | ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
31 | return batch_addmm(self.b_U, self.W_U, residual)
32 |
--------------------------------------------------------------------------------
/transformer_lens/factories/activation_function_factory.py:
--------------------------------------------------------------------------------
1 | """Activation Function Factory
2 |
3 | Centralized location for selecting supported activation functions throughout TransformerLens
4 | """
5 |
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 | from transformer_lens.utilities.activation_functions import (
8 | SUPPORTED_ACTIVATIONS,
9 | ActivationFunction,
10 | )
11 |
12 |
13 | class ActivationFunctionFactory:
14 | @staticmethod
15 | def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction:
16 | """Use this to select what activation function is needed based on configuration.
17 |
18 | Args:
19 | cfg (HookedTransformerConfig): The already created hooked transformer config
20 |
21 | Raises:
22 | ValueError: If there is a problem with the requested activation function.
23 |
24 | Returns:
25 | ActivationFunction: The activation function based on the dictionary of supported activations.
26 | """
27 | act_fn = cfg.act_fn
28 |
29 | if act_fn is None:
30 | raise ValueError("act_fn not set when trying to select Activation Function")
31 |
32 | activation_function = SUPPORTED_ACTIVATIONS.get(act_fn)
33 |
34 | if activation_function is None:
35 | raise ValueError(f"Invalid activation function name: {act_fn}")
36 |
37 | return activation_function
38 |
--------------------------------------------------------------------------------
/transformer_lens/factories/mlp_factory.py:
--------------------------------------------------------------------------------
1 | """MLP Factory
2 |
3 | Centralized location for creating any MLP needed within TransformerLens
4 | """
5 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
6 | from transformer_lens.components.mlps.gated_mlp import GatedMLP
7 | from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
8 | from transformer_lens.components.mlps.mlp import MLP
9 | from transformer_lens.components.mlps.moe import MoE
10 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
11 |
12 |
13 | class MLPFactory:
14 | @staticmethod
15 | def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
16 | if cfg.num_experts:
17 | return MoE(cfg)
18 | elif cfg.gated_mlp:
19 | return GatedMLP(cfg) if not cfg.load_in_4bit else GatedMLP4Bit(cfg)
20 | else:
21 | return MLP(cfg)
22 |
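`create_mlp` picks the implementation purely from config flags: `num_experts` wins, then `gated_mlp` (with the 4-bit variant when `load_in_4bit` is set), otherwise the plain `MLP`. A small sketch exercising the two common branches; the MoE and 4-bit paths need extra configuration and dependencies, so they are not shown:

from transformer_lens.components.mlps.gated_mlp import GatedMLP
from transformer_lens.components.mlps.mlp import MLP
from transformer_lens.factories.mlp_factory import MLPFactory
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

base = dict(d_model=16, d_head=4, n_heads=4, n_ctx=8, n_layers=1, d_mlp=32, d_vocab=10)

plain_cfg = HookedTransformerConfig(act_fn="gelu", **base)
gated_cfg = HookedTransformerConfig(act_fn="gelu", gated_mlp=True, **base)

assert isinstance(MLPFactory.create_mlp(plain_cfg), MLP)
assert isinstance(MLPFactory.create_mlp(gated_cfg), GatedMLP)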
--------------------------------------------------------------------------------
/transformer_lens/pretrained/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/transformer_lens/pretrained/__init__.py
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/__init__.py:
--------------------------------------------------------------------------------
1 | from .neo import convert_neo_weights
2 | from .gpt2 import convert_gpt2_weights
3 | from .opt import convert_opt_weights
4 | from .gptj import convert_gptj_weights
5 | from .neox import convert_neox_weights
6 | from .llama import convert_llama_weights
7 | from .bert import convert_bert_weights
8 | from .mistral import convert_mistral_weights
9 | from .mixtral import convert_mixtral_weights
10 | from .bloom import convert_bloom_weights
11 | from .coder import convert_coder_weights
12 | from .qwen import convert_qwen_weights
13 | from .qwen2 import convert_qwen2_weights
14 | from .phi import convert_phi_weights
15 | from .phi3 import convert_phi3_weights
16 | from .gemma import convert_gemma_weights
17 | from .mingpt import convert_mingpt_weights
18 | from .nanogpt import convert_nanogpt_weights
19 | from .t5 import convert_t5_weights
20 | from .neel_solu_old import convert_neel_solu_old_weights
21 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/bert.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
4 |
5 |
6 | def convert_bert_weights(bert, cfg: HookedTransformerConfig):
7 | embeddings = bert.bert.embeddings
8 | state_dict = {
9 | "embed.embed.W_E": embeddings.word_embeddings.weight,
10 | "embed.pos_embed.W_pos": embeddings.position_embeddings.weight,
11 | "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight,
12 | "embed.ln.w": embeddings.LayerNorm.weight,
13 | "embed.ln.b": embeddings.LayerNorm.bias,
14 | }
15 |
16 | for l in range(cfg.n_layers):
17 | block = bert.bert.encoder.layer[l]
18 | state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
19 | block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads
20 | )
21 | state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
22 | block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads
23 | )
24 | state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
25 | block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads
26 | )
27 | state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(
28 | block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads
29 | )
30 | state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
31 | block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads
32 | )
33 | state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(
34 | block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads
35 | )
36 | state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
37 | block.attention.output.dense.weight,
38 | "m (i h) -> i h m",
39 | i=cfg.n_heads,
40 | )
41 | state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias
42 | state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight
43 | state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias
44 | state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(
45 | block.intermediate.dense.weight, "mlp model -> model mlp"
46 | )
47 | state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias
48 | state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(
49 | block.output.dense.weight, "model mlp -> mlp model"
50 | )
51 | state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias
52 | state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight
53 | state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias
54 |
55 | pooler = bert.bert.pooler
56 | state_dict["pooler.W"] = pooler.dense.weight.T
57 | state_dict["pooler.b"] = pooler.dense.bias
58 |
59 | mlm_head = bert.cls.predictions
60 | state_dict["mlm_head.W"] = mlm_head.transform.dense.weight.T
61 | state_dict["mlm_head.b"] = mlm_head.transform.dense.bias
62 | state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight
63 | state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias
64 |
65 | # The NSP head does not have an unembedding
66 | # so we are only using weights from the MLM head
67 | # Note: BERT uses tied embeddings
68 | state_dict["unembed.W_U"] = mlm_head.decoder.weight.T
69 | state_dict["unembed.b_U"] = mlm_head.decoder.bias
70 |
71 | nsp_head = bert.cls.seq_relationship
72 | state_dict["nsp_head.W"] = nsp_head.weight.T
73 | state_dict["nsp_head.b"] = nsp_head.bias
74 |
75 | return state_dict
76 |
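The conversion above repeatedly folds Hugging Face's fused `(n_heads * d_head, d_model)` projection matrices into TransformerLens's `[head_index, d_model, d_head]` layout via `einops.rearrange`. A toy shape walk-through of the `"(i h) m -> i m h"` pattern with arbitrary numbers:

import einops
import torch

n_heads, d_head, d_model = 4, 8, 32
hf_weight = torch.randn(n_heads * d_head, d_model)  # HF fuses heads on the output dim

tl_weight = einops.rearrange(hf_weight, "(i h) m -> i m h", i=n_heads)
assert tl_weight.shape == (n_heads, d_model, d_head)

# Head 2's slice is just the corresponding rows of the fused matrix, transposed.
assert torch.equal(tl_weight[2], hf_weight[2 * d_head : 3 * d_head].T)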
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/bloom.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
4 |
5 |
6 | def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
7 | state_dict = {}
8 |
9 | state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight
10 |
11 | # Bloom uses post embedding layer norm
12 | state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight
13 | state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias
14 |
15 | for l in range(cfg.n_layers):
16 | state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
17 | state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias
18 |
19 | W = bloom.transformer.h[l].self_attention.query_key_value.weight
20 |
21 | W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)
22 |
23 | W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
24 | W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads)
25 | W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads)
26 | W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads)
27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
28 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
29 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
30 |
31 | qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias
32 | qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head)
33 |
34 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :]
35 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :]
36 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]
37 |
38 | W_O = bloom.transformer.h[l].self_attention.dense.weight.T # [d_model, d_model]
39 | W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model]
40 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
41 | state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias
42 |
43 | state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight
44 | state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias
45 |
46 | W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
47 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in
48 | state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias
49 |
50 | W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
51 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out
52 | state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias
53 | state_dict["unembed.W_U"] = bloom.lm_head.weight.T
54 |
55 | state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
56 | state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
57 | return state_dict
58 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/coder.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_coder_weights(model, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = model.transformer.wte.weight
11 | state_dict["pos_embed.W_pos"] = model.transformer.wpe.weight
12 |
13 | for l in range(cfg.n_layers):
14 | state_dict[f"blocks.{l}.ln1.w"] = model.transformer.h[l].ln_1.weight
15 | state_dict[f"blocks.{l}.ln1.b"] = model.transformer.h[l].ln_1.bias
16 |
17 | # SantaCoder uses multi-query attention: q is produced by its own per-head linear
18 | # map, while k and v come from a single shared kv_attn map (output concat([k, v]))
19 | W_KV = model.transformer.h[l].attn.kv_attn.weight # [d_model, 2 * d_head]
20 | W_K, W_V = torch.tensor_split(W_KV, 2, dim=1)
21 | W_Q = model.transformer.h[l].attn.q_attn.weight # [d_model, d_model]
22 | W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
23 | W_K = einops.repeat(W_K, "m h -> i m h", i=cfg.n_heads)
24 | W_V = einops.repeat(W_V, "m h -> i m h", i=cfg.n_heads)
25 |
26 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
27 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
28 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
29 |
30 | b_Q = einops.rearrange(
31 | model.transformer.h[l].attn.q_attn.bias,
32 | "(index head)-> index head",
33 | index=cfg.n_heads,
34 | head=cfg.d_head,
35 | )
36 | b_KV = model.transformer.h[l].attn.kv_attn.bias # [2 * d_head]
37 | b_K, b_V = torch.tensor_split(b_KV, 2, dim=0)
38 | b_K = einops.repeat(b_K, "head -> index head", index=cfg.n_heads)
39 | b_V = einops.repeat(b_V, "head -> index head", index=cfg.n_heads)
40 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
41 | state_dict[f"blocks.{l}.attn.b_K"] = b_K
42 | state_dict[f"blocks.{l}.attn.b_V"] = b_V
43 |
44 | W_O = model.transformer.h[l].attn.c_proj.weight
45 | W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
46 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
47 | state_dict[f"blocks.{l}.attn.b_O"] = model.transformer.h[l].attn.c_proj.bias
48 |
49 | state_dict[f"blocks.{l}.ln2.w"] = model.transformer.h[l].ln_2.weight
50 | state_dict[f"blocks.{l}.ln2.b"] = model.transformer.h[l].ln_2.bias
51 |
52 | W_in = model.transformer.h[l].mlp.c_fc.weight
53 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in
54 | state_dict[f"blocks.{l}.mlp.b_in"] = model.transformer.h[l].mlp.c_fc.bias
55 |
56 | W_out = model.transformer.h[l].mlp.c_proj.weight
57 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out
58 | state_dict[f"blocks.{l}.mlp.b_out"] = model.transformer.h[l].mlp.c_proj.bias
59 | state_dict["unembed.W_U"] = model.lm_head.weight.T
60 |
61 | state_dict["ln_final.w"] = model.transformer.ln_f.weight
62 | state_dict["ln_final.b"] = model.transformer.ln_f.bias
63 | return state_dict
64 |
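SantaCoder's multi-query attention has a per-head query projection but a single key/value projection shared across heads, which the conversion broadcasts with `einops.repeat`. A toy sketch of that broadcast:

import einops
import torch

n_heads, d_head, d_model = 4, 8, 32

W_KV = torch.randn(d_model, 2 * d_head)        # one shared key head and one shared value head
W_K, W_V = torch.tensor_split(W_KV, 2, dim=1)  # each [d_model, d_head]

W_K = einops.repeat(W_K, "m h -> i m h", i=n_heads)
assert W_K.shape == (n_heads, d_model, d_head)
assert torch.equal(W_K[0], W_K[3])  # every head sees the same key projection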
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/gpt2.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = gpt2.transformer.wte.weight
11 | state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight
12 |
13 | for l in range(cfg.n_layers):
14 | state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight
15 | state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias
16 |
17 | # In GPT-2, q,k,v are produced by one big linear map, whose output is
18 | # concat([q, k, v])
19 | W = gpt2.transformer.h[l].attn.c_attn.weight
20 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
21 | W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
22 | W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads)
23 | W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads)
24 |
25 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
26 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
27 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
28 |
29 | qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias
30 | qkv_bias = einops.rearrange(
31 | qkv_bias,
32 | "(qkv index head)->qkv index head",
33 | qkv=3,
34 | index=cfg.n_heads,
35 | head=cfg.d_head,
36 | )
37 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
38 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
39 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]
40 |
41 | W_O = gpt2.transformer.h[l].attn.c_proj.weight
42 | W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
43 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
44 | state_dict[f"blocks.{l}.attn.b_O"] = gpt2.transformer.h[l].attn.c_proj.bias
45 |
46 | state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight
47 | state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias
48 |
49 | W_in = gpt2.transformer.h[l].mlp.c_fc.weight
50 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in
51 | state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias
52 |
53 | W_out = gpt2.transformer.h[l].mlp.c_proj.weight
54 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out
55 | state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias
56 | state_dict["unembed.W_U"] = gpt2.lm_head.weight.T
57 |
58 | state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight
59 | state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias
60 | return state_dict
61 |
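A minimal standalone sketch of the fused-QKV split used in convert_gpt2_weights above, with toy dimensions chosen purely for illustration: the fused c_attn weight of shape [d_model, 3 * d_model] is split into thirds, and each third is reshaped into a per-head [n_heads, d_model, d_head] tensor.

import einops
import torch

d_model, n_heads = 8, 2
d_head = d_model // n_heads

W = torch.randn(d_model, 3 * d_model)            # fused [q | k | v] projection
W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)  # each [d_model, d_model]
W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=n_heads)

assert W_Q.shape == (n_heads, d_model, d_head)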
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/gptj.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_gptj_weights(gptj, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = gptj.transformer.wte.weight
11 |
12 | for l in range(cfg.n_layers):
13 | state_dict[f"blocks.{l}.ln1.w"] = gptj.transformer.h[l].ln_1.weight
14 | state_dict[f"blocks.{l}.ln1.b"] = gptj.transformer.h[l].ln_1.bias
15 |
16 | W_Q = gptj.transformer.h[l].attn.q_proj.weight
17 | W_K = gptj.transformer.h[l].attn.k_proj.weight
18 | W_V = gptj.transformer.h[l].attn.v_proj.weight
19 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
20 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
21 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
22 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
23 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
24 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
25 |
26 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
27 | state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
28 | state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
29 |
30 | W_O = gptj.transformer.h[l].attn.out_proj.weight
31 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
32 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
33 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
34 |
35 | # GPT-J runs attention and MLP in parallel off a single LayerNorm, so ln1 and ln2 are tied.
36 | state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"]
37 | state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"]
38 |
39 | state_dict[f"blocks.{l}.mlp.W_in"] = gptj.transformer.h[l].mlp.fc_in.weight.T
40 | state_dict[f"blocks.{l}.mlp.b_in"] = gptj.transformer.h[l].mlp.fc_in.bias
41 |
42 | state_dict[f"blocks.{l}.mlp.W_out"] = gptj.transformer.h[l].mlp.fc_out.weight.T
43 | state_dict[f"blocks.{l}.mlp.b_out"] = gptj.transformer.h[l].mlp.fc_out.bias
44 | state_dict["ln_final.w"] = gptj.transformer.ln_f.weight
45 | state_dict["ln_final.b"] = gptj.transformer.ln_f.bias
46 |
47 | state_dict["unembed.W_U"] = gptj.lm_head.weight.T
48 | # GPT-J's lm_head is a Linear layer with a bias, unlike most GPT-style models
49 | state_dict["unembed.b_U"] = gptj.lm_head.bias
50 | return state_dict
51 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/llama.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | import einops
4 | import torch
5 |
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 |
8 |
9 | def convert_llama_weights(llama, cfg: HookedTransformerConfig):
10 | state_dict = {}
11 |
12 | state_dict["embed.W_E"] = llama.model.embed_tokens.weight
13 |
14 | # Some models with the Llama architecture use Grouped Query Attention; for these we need to modify
15 | # the state dict keys for the K/V attention weights/biases, prepending "_" to the key names.
16 | using_gqa = cfg.n_key_value_heads is not None
17 | gqa_uscore = "_" if using_gqa else ""
18 | # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None
19 | n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
20 |
21 | # Llama has no biases anywhere; everything else is handled roughly like
22 | # GPT-NeoX, just with different parameter names
23 |
24 | assert cfg.d_mlp is not None # keep mypy happy
25 |
26 | for l in range(cfg.n_layers):
27 | state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight
28 |
29 | W_Q = llama.model.layers[l].self_attn.q_proj.weight
30 | W_K = llama.model.layers[l].self_attn.k_proj.weight
31 | W_V = llama.model.layers[l].self_attn.v_proj.weight
32 |
33 | # in case of quantization,
34 | # parameters should stay as bitsandbytes.nn.modules.Params4bit
35 | if not cfg.load_in_4bit:
36 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
37 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
38 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)
39 |
40 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
41 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
42 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V
43 |
44 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
45 | cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
46 | )
47 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
48 | n_kv_heads,
49 | cfg.d_head,
50 | dtype=cfg.dtype,
51 | device=cfg.device,
52 | )
53 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
54 | n_kv_heads,
55 | cfg.d_head,
56 | dtype=cfg.dtype,
57 | device=cfg.device,
58 | )
59 |
60 | W_O = llama.model.layers[l].self_attn.o_proj.weight
61 |
62 | if not cfg.load_in_4bit:
63 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
64 |
65 | state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)
66 |
67 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
68 | cfg.d_model, dtype=cfg.dtype, device=cfg.device
69 | )
70 |
71 | state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight
72 |
73 | # in case of quantization,
74 | # parameters should stay as bitsandbytes.nn.modules.Params4bit
75 | if not cfg.load_in_4bit:
76 | state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T
77 | state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T
78 | state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T
79 | else:
80 | state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight
81 | state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight
82 | state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight
83 |
84 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
85 | cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
86 | )
87 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
88 | cfg.d_model, dtype=cfg.dtype, device=cfg.device
89 | )
90 |
91 | state_dict["ln_final.w"] = llama.model.norm.weight
92 |
93 | state_dict["unembed.W_U"] = llama.lm_head.weight.T
94 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)
95 |
96 | return state_dict
97 |
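A minimal standalone sketch of the grouped-query-attention key naming used in convert_llama_weights above; the SimpleNamespace config is just a stand-in for illustration. When n_key_value_heads is set, the K/V entries get an underscore prefix and the K/V head count differs from the query head count.

from types import SimpleNamespace

cfg = SimpleNamespace(n_heads=32, n_key_value_heads=8)

using_gqa = cfg.n_key_value_heads is not None
gqa_uscore = "_" if using_gqa else ""
n_kv_heads = cfg.n_key_value_heads if using_gqa else cfg.n_heads

assert f"blocks.0.attn.{gqa_uscore}W_K" == "blocks.0.attn._W_K"
assert n_kv_heads == 8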
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/mingpt.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
4 |
5 |
6 | def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
7 | # mingpt (https://github.com/karpathy/minGPT) is mostly similar to GPT-2,
8 | # but doesn't concat the QKV matrices.
9 | state_dict = {}
10 |
11 | state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"]
12 | state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze()
13 |
14 | for l in range(cfg.n_layers):
15 | state_dict[f"blocks.{l}.ln1.w"] = old_state_dict[f"blocks.{l}.ln1.weight"]
16 | state_dict[f"blocks.{l}.ln1.b"] = old_state_dict[f"blocks.{l}.ln1.bias"]
17 |
18 | W_Q = old_state_dict[f"blocks.{l}.attn.query.weight"]
19 | W_K = old_state_dict[f"blocks.{l}.attn.key.weight"]
20 | W_V = old_state_dict[f"blocks.{l}.attn.value.weight"]
21 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
22 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
23 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
24 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
25 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
26 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
27 |
28 | q_bias = einops.rearrange(
29 | old_state_dict[f"blocks.{l}.attn.query.bias"], "(i h)->i h", i=cfg.n_heads
30 | )
31 | k_bias = einops.rearrange(
32 | old_state_dict[f"blocks.{l}.attn.key.bias"], "(i h)->i h", i=cfg.n_heads
33 | )
34 | v_bias = einops.rearrange(
35 | old_state_dict[f"blocks.{l}.attn.value.bias"], "(i h)->i h", i=cfg.n_heads
36 | )
37 |
38 | state_dict[f"blocks.{l}.attn.b_Q"] = q_bias
39 | state_dict[f"blocks.{l}.attn.b_K"] = k_bias
40 | state_dict[f"blocks.{l}.attn.b_V"] = v_bias
41 |
42 | W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"]
43 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
44 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
45 | state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"]
46 |
47 | state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"]
48 | state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"]
49 |
50 | W_in = old_state_dict[f"blocks.{l}.mlp.0.weight"]
51 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in.T
52 | state_dict[f"blocks.{l}.mlp.b_in"] = old_state_dict[f"blocks.{l}.mlp.0.bias"]
53 |
54 | W_out = old_state_dict[f"blocks.{l}.mlp.2.weight"]
55 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T
56 | state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"]
57 |
58 | state_dict["unembed.W_U"] = old_state_dict["head.weight"].T
59 |
60 | state_dict["ln_final.w"] = old_state_dict["ln_f.weight"]
61 | state_dict["ln_final.b"] = old_state_dict["ln_f.bias"]
62 |
63 | return state_dict
64 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/mistral.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_mistral_weights(mistral, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = mistral.model.embed_tokens.weight
11 |
12 | assert cfg.n_key_value_heads is not None # keep mypy happy
13 | assert cfg.d_mlp is not None # keep mypy happy
14 |
15 | # Mistral has no biases anywhere
16 | for l in range(cfg.n_layers):
17 | state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight
18 |
19 | W_Q = mistral.model.layers[l].self_attn.q_proj.weight
20 | W_K = mistral.model.layers[l].self_attn.k_proj.weight
21 | W_V = mistral.model.layers[l].self_attn.v_proj.weight
22 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
23 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
24 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
25 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
26 | state_dict[f"blocks.{l}.attn._W_K"] = W_K
27 | state_dict[f"blocks.{l}.attn._W_V"] = W_V
28 |
29 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
30 | state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
31 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
32 | )
33 | state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
34 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
35 | )
36 |
37 | W_O = mistral.model.layers[l].self_attn.o_proj.weight
38 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
39 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
40 |
41 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
42 |
43 | state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight
44 |
45 | state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T
46 | state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T
47 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
48 |
49 | state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T
50 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
51 |
52 | state_dict["ln_final.w"] = mistral.model.norm.weight
53 |
54 | state_dict["unembed.W_U"] = mistral.lm_head.weight.T
55 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
56 |
57 | return state_dict
58 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/mixtral.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig):
8 | # The same as Mistral, but with the MLP replaced with MoE
9 | # As with Mistral, Mixtral has no biases
10 |
11 | state_dict = {}
12 |
13 | assert cfg.n_key_value_heads is not None # keep mypy happy
14 | assert cfg.d_mlp is not None
15 | assert cfg.num_experts is not None
16 |
17 | state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight
18 |
19 | for l in range(cfg.n_layers):
20 | state_dict[f"blocks.{l}.ln1.w"] = mixtral.model.layers[l].input_layernorm.weight
21 |
22 | W_Q = mixtral.model.layers[l].self_attn.q_proj.weight
23 | W_K = mixtral.model.layers[l].self_attn.k_proj.weight
24 | W_V = mixtral.model.layers[l].self_attn.v_proj.weight
25 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
26 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
27 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
28 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
29 | state_dict[f"blocks.{l}.attn._W_K"] = W_K
30 | state_dict[f"blocks.{l}.attn._W_V"] = W_V
31 |
32 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
33 | state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
34 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
35 | )
36 | state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
37 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
38 | )
39 |
40 | W_O = mixtral.model.layers[l].self_attn.o_proj.weight
41 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
42 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
43 |
44 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
45 |
46 | state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight
47 |
48 | state_dict[f"blocks.{l}.mlp.W_gate.weight"] = mixtral.model.layers[
49 | l
50 | ].block_sparse_moe.gate.weight
51 |
52 | # The mapping here from wn to W_{in/out/gate} is a bit confusing:
53 | # w1 -> W_gate
54 | # w2 -> W_out
55 | # w3 -> W_in
56 | # See https://github.com/mistralai/mistral-inference/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L128 for reference
57 | for e in range(cfg.num_experts):
58 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = (
59 | mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight
60 | )
61 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = (
62 | mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight
63 | )
64 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = (
65 | mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight
66 | )
67 |
68 | state_dict["ln_final.w"] = mixtral.model.norm.weight.data
69 |
70 | state_dict["unembed.W_U"] = mixtral.lm_head.weight.T
71 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
72 |
73 | return state_dict
74 |
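A minimal standalone sketch (toy dimensions, plain tensors instead of the real expert modules) of the gated expert forward pass implied by the w1/w2/w3 mapping above, roughly out = W_out(silu(W_gate(x)) * W_in(x)).

import torch
import torch.nn.functional as F

d_model, d_mlp = 8, 16
x = torch.randn(d_model)
W_gate = torch.randn(d_mlp, d_model)   # w1
W_out = torch.randn(d_model, d_mlp)    # w2
W_in = torch.randn(d_mlp, d_model)     # w3

out = W_out @ (F.silu(W_gate @ x) * (W_in @ x))
assert out.shape == (d_model,)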
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/nanogpt.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
8 | """For https://github.com/karpathy/nanoGPT
9 | There are two complications with converting nanogpt models:
10 | The first is that some state dicts have an unwanted prefix on keys that needs to be removed.
11 | The second is that the models can be saved with or without bias. By default, there
12 | is no bias. This function can handle both cases."""
13 | # Nanogpt models saved after torch.compile() have this unwanted prefix
14 | # This is a simple way to remove it
15 | unwanted_prefix = "_orig_mod."
16 | for k, v in list(old_state_dict.items()):
17 | if k.startswith(unwanted_prefix):
18 | old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k)
19 |
20 | new_state_dict = {}
21 | new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
22 | new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]
23 |
24 | new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"]
25 | new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"])
26 | new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T
27 |
28 | bias = False
29 | if "transformer.ln_f.bias" in old_state_dict:
30 | bias = True
31 | new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]
32 |
33 | for layer in range(cfg.n_layers):
34 | layer_key = f"transformer.h.{layer}"
35 |
36 | new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"]
37 | # A bias of zeros is required for folding layer norm
38 | new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
39 | old_state_dict[f"{layer_key}.ln_1.weight"]
40 | )
41 | new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"]
42 | new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
43 | old_state_dict[f"{layer_key}.ln_2.weight"]
44 | )
45 |
46 | W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
47 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
48 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
49 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
50 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
51 | new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
52 | new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
53 | new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V
54 |
55 | W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"]
56 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
57 | new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O
58 |
59 | new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[
60 | f"{layer_key}.mlp.c_fc.weight"
61 | ].T
62 | new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[
63 | f"{layer_key}.mlp.c_proj.weight"
64 | ].T
65 |
66 | if bias:
67 | new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"]
68 | new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"]
69 | new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
70 | f"{layer_key}.mlp.c_fc.bias"
71 | ]
72 | new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
73 | f"{layer_key}.mlp.c_proj.bias"
74 | ]
75 |
76 | B = old_state_dict[f"{layer_key}.attn.c_attn.bias"]
77 | B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
78 | B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads)
79 | B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads)
80 | B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads)
81 | new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
82 | new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
83 | new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
84 | new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
85 | f"{layer_key}.attn.c_proj.bias"
86 | ]
87 |
88 | return new_state_dict
89 |
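A minimal standalone sketch of the prefix stripping done at the top of convert_nanogpt_weights, on a toy state dict (the key names here are made up for illustration).

old_state_dict = {"_orig_mod.transformer.wte.weight": 1, "lm_head.weight": 2}

unwanted_prefix = "_orig_mod."
for k in list(old_state_dict):
    if k.startswith(unwanted_prefix):
        old_state_dict[k[len(unwanted_prefix):]] = old_state_dict.pop(k)

assert set(old_state_dict) == {"transformer.wte.weight", "lm_head.weight"}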
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/neel_solu_old.py:
--------------------------------------------------------------------------------
1 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
2 |
3 |
4 | def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig):
5 | """
6 | Converts the weights of my old SoLU models to the HookedTransformer format.
7 | Takes as input a state dict, *not* a model object.
8 |
9 | There are a bunch of dumb bugs in the original code, sorry!
10 |
11 | Models 1L, 2L, 4L and 6L have left facing weights (ie, weights have shape
12 | [dim_out, dim_in]) while HookedTransformer does right facing (ie [dim_in,
13 | dim_out]).
14 |
15 | 8L has *just* a left facing W_pos, the rest right facing.
16 |
17 | And some models were trained with
18 | """
19 | # Early models have left facing W_pos
20 | reverse_pos = cfg.n_layers <= 8
21 |
22 | # Models prior to 8L have left facing everything (8L has JUST left facing W_pos - sorry! Stupid bug)
23 | reverse_weights = cfg.n_layers <= 6
24 |
25 | new_state_dict = {}
26 | for k, v in state_dict.items():
27 | k = k.replace("norm", "ln")
28 | if k.startswith("ln."):
29 | k = k.replace("ln.", "ln_final.")
30 | new_state_dict[k] = v
31 |
32 | if reverse_pos:
33 | new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
34 | if reverse_weights:
35 | for k, v in new_state_dict.items():
36 | if "W_" in k and "W_pos" not in k:
37 | new_state_dict[k] = v.transpose(-2, -1)
38 | return new_state_dict
39 |
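A minimal standalone sketch of the left-facing vs right-facing convention described in the docstring above, with toy shapes: a left-facing [dim_out, dim_in] weight becomes right-facing by transposing its last two dimensions, which is exactly what the reverse_weights branch does.

import torch

W_left = torch.randn(512, 256)       # left facing: [dim_out, dim_in]
W_right = W_left.transpose(-2, -1)   # right facing: [dim_in, dim_out]

assert W_right.shape == (256, 512)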
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/neo.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_neo_weights(neo, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = neo.transformer.wte.weight
11 | state_dict["pos_embed.W_pos"] = neo.transformer.wpe.weight
12 |
13 | for l in range(cfg.n_layers):
14 | state_dict[f"blocks.{l}.ln1.w"] = neo.transformer.h[l].ln_1.weight
15 | state_dict[f"blocks.{l}.ln1.b"] = neo.transformer.h[l].ln_1.bias
16 |
17 | W_Q = neo.transformer.h[l].attn.attention.q_proj.weight
18 | W_K = neo.transformer.h[l].attn.attention.k_proj.weight
19 | W_V = neo.transformer.h[l].attn.attention.v_proj.weight
20 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
21 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
22 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
23 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
24 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
25 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
26 |
27 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
28 | state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
29 | state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
30 |
31 | W_O = neo.transformer.h[l].attn.attention.out_proj.weight
32 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
33 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
34 | state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias
35 |
36 | state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight
37 | state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias
38 |
39 | state_dict[f"blocks.{l}.mlp.W_in"] = neo.transformer.h[l].mlp.c_fc.weight.T
40 | state_dict[f"blocks.{l}.mlp.b_in"] = neo.transformer.h[l].mlp.c_fc.bias
41 |
42 | state_dict[f"blocks.{l}.mlp.W_out"] = neo.transformer.h[l].mlp.c_proj.weight.T
43 | state_dict[f"blocks.{l}.mlp.b_out"] = neo.transformer.h[l].mlp.c_proj.bias
44 | state_dict["ln_final.w"] = neo.transformer.ln_f.weight
45 | state_dict["ln_final.b"] = neo.transformer.ln_f.bias
46 |
47 | state_dict["unembed.W_U"] = neo.lm_head.weight.T
48 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
49 | return state_dict
50 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/neox.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_neox_weights(neox, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = neox.gpt_neox.embed_in.weight
11 |
12 | for l in range(cfg.n_layers):
13 | state_dict[f"blocks.{l}.ln1.w"] = neox.gpt_neox.layers[l].input_layernorm.weight
14 | state_dict[f"blocks.{l}.ln1.b"] = neox.gpt_neox.layers[l].input_layernorm.bias
15 |
16 | # For some inexplicable reason, NeoX both uses the concatenated QKV
17 | # matmul of GPT-2 (as far as I can tell this has a negligible performance
18 | # impact) AND flattens the axes in the DIFFERENT order of (head_index qkv
19 | # d_head) - this took me an hour to debug...
20 | W = neox.gpt_neox.layers[l].attention.query_key_value.weight
21 | W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=cfg.n_heads, qkv=3)
22 |
23 | # Split the fused projection into per-head Q, K and V weights
24 | state_dict[f"blocks.{l}.attn.W_Q"] = W[0]
25 | state_dict[f"blocks.{l}.attn.W_K"] = W[1]
26 | state_dict[f"blocks.{l}.attn.W_V"] = W[2]
27 |
28 | qkv_bias = neox.gpt_neox.layers[l].attention.query_key_value.bias
29 | qkv_bias = einops.rearrange(
30 | qkv_bias,
31 | "(index qkv head)->qkv index head",
32 | qkv=3,
33 | index=cfg.n_heads,
34 | head=cfg.d_head,
35 | )
36 | # Split the fused bias into per-head Q, K and V biases
37 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
38 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
39 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]
40 |
41 | W_O = neox.gpt_neox.layers[l].attention.dense.weight
42 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
43 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
44 | state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias
45 |
46 | state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight
47 | state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias
48 |
49 | state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T
50 | state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias
51 |
52 | state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T
53 | state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias
54 | state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight
55 | state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias
56 |
57 | state_dict["unembed.W_U"] = neox.embed_out.weight.T
58 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
59 | return state_dict
60 |
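A minimal standalone sketch (toy dimensions) of the unusual (head_index qkv d_head) flattening order mentioned in the comment above, showing that a single rearrange recovers a [qkv, head_index, d_model, d_head] tensor.

import einops
import torch

n_heads, d_head, d_model = 2, 3, 5

W = torch.randn(n_heads * 3 * d_head, d_model)
W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=n_heads, qkv=3)

assert W.shape == (3, n_heads, d_model, d_head)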
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/opt.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_opt_weights(opt, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 |
10 | state_dict["embed.W_E"] = opt.model.decoder.embed_tokens.weight
11 | state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :]  # drop HF OPT's 2-position embedding offset
12 |
13 | for l in range(cfg.n_layers):
14 | state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight
15 | state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias
16 |
17 | W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight
18 | W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight
19 | W_V = opt.model.decoder.layers[l].self_attn.v_proj.weight
20 | W_Q = einops.rearrange(
21 | W_Q,
22 | "(index d_head) d_model->index d_model d_head",
23 | index=cfg.n_heads,
24 | )
25 | W_K = einops.rearrange(
26 | W_K,
27 | "(index d_head) d_model->index d_model d_head",
28 | index=cfg.n_heads,
29 | )
30 | W_V = einops.rearrange(
31 | W_V,
32 | "(index d_head) d_model->index d_model d_head",
33 | index=cfg.n_heads,
34 | )
35 |
36 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
37 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
38 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
39 |
40 | q_bias = einops.rearrange(
41 | opt.model.decoder.layers[l].self_attn.q_proj.bias,
42 | "(head_index d_head)->head_index d_head",
43 | head_index=cfg.n_heads,
44 | d_head=cfg.d_head,
45 | )
46 | k_bias = einops.rearrange(
47 | opt.model.decoder.layers[l].self_attn.k_proj.bias,
48 | "(head_index d_head)->head_index d_head",
49 | head_index=cfg.n_heads,
50 | d_head=cfg.d_head,
51 | )
52 | v_bias = einops.rearrange(
53 | opt.model.decoder.layers[l].self_attn.v_proj.bias,
54 | "(head_index d_head)->head_index d_head",
55 | head_index=cfg.n_heads,
56 | d_head=cfg.d_head,
57 | )
58 |
59 | state_dict[f"blocks.{l}.attn.b_Q"] = q_bias
60 | state_dict[f"blocks.{l}.attn.b_K"] = k_bias
61 | state_dict[f"blocks.{l}.attn.b_V"] = v_bias
62 |
63 | W_O = opt.model.decoder.layers[l].self_attn.out_proj.weight
64 | W_O = einops.rearrange(
65 | W_O,
66 | "d_model (index d_head)->index d_head d_model",
67 | index=cfg.n_heads,
68 | )
69 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
70 | state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias
71 |
72 | state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight
73 | state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias
74 |
75 | state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T
76 | state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T
77 |
78 | state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias
79 | state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias
80 | state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight
81 | state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias
82 | state_dict["unembed.W_U"] = opt.lm_head.weight.T
83 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
84 | return state_dict
85 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/phi.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
4 |
5 |
6 | def convert_phi_weights(phi, cfg: HookedTransformerConfig):
7 | state_dict = {}
8 |
9 | state_dict["embed.W_E"] = phi.model.embed_tokens.weight
10 |
11 | for l in range(cfg.n_layers):
12 | state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight
13 | state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias
14 |
15 | W_Q = phi.model.layers[l].self_attn.q_proj.weight
16 | W_K = phi.model.layers[l].self_attn.k_proj.weight
17 | W_V = phi.model.layers[l].self_attn.v_proj.weight
18 | W_Q = einops.rearrange(
19 | W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
20 | )
21 | W_K = einops.rearrange(
22 | W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
23 | )
24 | W_V = einops.rearrange(
25 | W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
26 | )
27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
28 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
29 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
30 |
31 | b_Q = phi.model.layers[l].self_attn.q_proj.bias
32 | b_K = phi.model.layers[l].self_attn.k_proj.bias
33 | b_V = phi.model.layers[l].self_attn.v_proj.bias
34 | b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
35 | b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
36 | b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
37 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
38 | state_dict[f"blocks.{l}.attn.b_K"] = b_K
39 | state_dict[f"blocks.{l}.attn.b_V"] = b_V
40 |
41 | W_O = phi.model.layers[l].self_attn.dense.weight
42 | W_O = einops.rearrange(
43 | W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads
44 | )
45 |
46 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
47 | state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias
48 |
49 | # Layer Norm 1 and 2 are tied.
50 | state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"]
51 | state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"]
52 |
53 | state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T
54 | state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias
55 | state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T
56 | state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias
57 |
58 | state_dict["ln_final.w"] = phi.model.final_layernorm.weight
59 | state_dict["ln_final.b"] = phi.model.final_layernorm.bias
60 |
61 | state_dict["unembed.W_U"] = phi.lm_head.weight.T
62 | state_dict["unembed.b_U"] = phi.lm_head.bias
63 |
64 | return state_dict
65 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/phi3.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | import einops
4 | import torch
5 |
6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7 |
8 |
9 | def convert_phi3_weights(phi, cfg: HookedTransformerConfig):
10 | state_dict = {}
11 | state_dict["embed.W_E"] = phi.model.embed_tokens.weight
12 |
13 | # Some models with this architecture use Grouped Query Attention; for these we need to modify
14 | # the state dict keys for the K/V attention weights/biases, prepending "_" to the key names.
15 | using_gqa = cfg.n_key_value_heads is not None
16 | gqa_uscore = "_" if using_gqa else ""
17 | # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None
18 | n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
19 |
20 | for l in range(cfg.n_layers):
21 | state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight
22 | state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
23 |
24 | W = phi.model.layers[l].self_attn.qkv_proj.weight
25 | q_dim = cfg.n_heads * cfg.d_head
26 | kv_dim = n_kv_heads * cfg.d_head
27 | W_Q, W_K, W_V = W.split([q_dim, kv_dim, kv_dim], dim=0)
28 |
29 | W_Q = einops.rearrange(
30 | W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
31 | )
32 | W_K = einops.rearrange(
33 | W_K, "(n_kv_head d_head) d_model -> n_kv_head d_model d_head", n_kv_head=n_kv_heads
34 | )
35 | W_V = einops.rearrange(
36 | W_V, "(n_kv_head d_head) d_model -> n_kv_head d_model d_head", n_kv_head=n_kv_heads
37 | )
38 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
39 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
40 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V
41 |
42 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
43 | cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
44 | )
45 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
46 | n_kv_heads,
47 | cfg.d_head,
48 | dtype=cfg.dtype,
49 | )
50 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
51 | n_kv_heads,
52 | cfg.d_head,
53 | dtype=cfg.dtype,
54 | )
55 |
56 | W_O = phi.model.layers[l].self_attn.o_proj.weight
57 | W_O = einops.rearrange(
58 | W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads
59 | )
60 |
61 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
62 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
63 |
64 | state_dict[f"blocks.{l}.ln2.w"] = phi.model.layers[l].post_attention_layernorm.weight
65 | state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
66 |
67 | W = phi.model.layers[l].mlp.gate_up_proj.weight.T
68 | W_gate, W_in = torch.tensor_split(W, 2, dim=1)
69 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in
70 | state_dict[f"blocks.{l}.mlp.W_gate"] = W_gate
71 | state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.down_proj.weight.T
72 |
73 | state_dict["ln_final.w"] = phi.model.norm.weight
74 |
75 | state_dict["unembed.W_U"] = phi.lm_head.weight.T
76 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
77 |
78 | return state_dict
79 |
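A minimal standalone sketch (toy dimensions) of the fused qkv_proj split used in convert_phi3_weights above: the Q chunk is sized by the number of query heads and the K/V chunks by the number of key/value heads.

import torch

n_heads, n_kv_heads, d_head, d_model = 4, 2, 3, 5
q_dim, kv_dim = n_heads * d_head, n_kv_heads * d_head

W = torch.randn(q_dim + 2 * kv_dim, d_model)
W_Q, W_K, W_V = W.split([q_dim, kv_dim, kv_dim], dim=0)

assert W_Q.shape == (q_dim, d_model)
assert W_K.shape == (kv_dim, d_model) and W_V.shape == (kv_dim, d_model)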
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/qwen.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_qwen_weights(qwen, cfg: HookedTransformerConfig):
8 | state_dict = {}
9 | model = qwen.transformer
10 | state_dict["embed.W_E"] = model.wte.weight
11 |
12 | assert cfg.d_mlp is not None # keep mypy happy
13 |
14 | for l in range(cfg.n_layers):
15 | state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight
16 |
17 | W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0)
18 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
19 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
20 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
21 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
22 | state_dict[f"blocks.{l}.attn.W_K"] = W_K
23 | state_dict[f"blocks.{l}.attn.W_V"] = W_V
24 |
25 | b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0)
26 | b_Q = einops.rearrange(
27 | b_Q,
28 | "(n_head d_head) -> n_head d_head",
29 | n_head=cfg.n_heads,
30 | )
31 | b_K = einops.rearrange(
32 | b_K,
33 | "(n_head d_head) -> n_head d_head",
34 | n_head=cfg.n_heads,
35 | )
36 | b_V = einops.rearrange(
37 | b_V,
38 | "(n_head d_head) -> n_head d_head",
39 | n_head=cfg.n_heads,
40 | )
41 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
42 | state_dict[f"blocks.{l}.attn.b_K"] = b_K
43 | state_dict[f"blocks.{l}.attn.b_V"] = b_V
44 |
45 | W_O = model.h[l].attn.c_proj.weight
46 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
47 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
48 |
49 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
50 |
51 | state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight
52 |
53 | state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T
54 | state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T
55 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
56 |
57 | state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T
58 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
59 |
60 | state_dict["ln_final.w"] = model.ln_f.weight
61 |
62 | state_dict["unembed.W_U"] = qwen.lm_head.weight.T
63 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
64 |
65 | return state_dict
66 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/qwen2.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 |
4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
5 |
6 |
7 | def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig):
8 | # Note that this method is also applied for Qwen1.5 models, since they
9 | # have architecture type Qwen2ForCausalLM.
10 |
11 | state_dict = {}
12 |
13 | state_dict["embed.W_E"] = qwen.model.embed_tokens.weight
14 |
15 | assert cfg.d_mlp is not None # keep mypy happy
16 |
17 | for l in range(cfg.n_layers):
18 | state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight
19 |
20 | W_Q = qwen.model.layers[l].self_attn.q_proj.weight
21 | W_K = qwen.model.layers[l].self_attn.k_proj.weight
22 | W_V = qwen.model.layers[l].self_attn.v_proj.weight
23 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
24 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
25 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
26 |
27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
28 | state_dict[f"blocks.{l}.attn._W_K"] = W_K
29 | state_dict[f"blocks.{l}.attn._W_V"] = W_V
30 |
31 | b_Q = qwen.model.layers[l].self_attn.q_proj.bias
32 | b_Q = einops.rearrange(
33 | b_Q,
34 | "(n_head d_head) -> n_head d_head",
35 | n_head=cfg.n_heads,
36 | )
37 |
38 | b_K = qwen.model.layers[l].self_attn.k_proj.bias
39 | b_K = einops.rearrange(
40 | b_K,
41 | "(n_head d_head) -> n_head d_head",
42 | n_head=cfg.n_key_value_heads,
43 | )
44 |
45 | b_V = qwen.model.layers[l].self_attn.v_proj.bias
46 | b_V = einops.rearrange(
47 | b_V,
48 | "(n_head d_head) -> n_head d_head",
49 | n_head=cfg.n_key_value_heads,
50 | )
51 |
52 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
53 | state_dict[f"blocks.{l}.attn._b_K"] = b_K
54 | state_dict[f"blocks.{l}.attn._b_V"] = b_V
55 |
56 | W_O = qwen.model.layers[l].self_attn.o_proj.weight
57 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
58 | state_dict[f"blocks.{l}.attn.W_O"] = W_O
59 |
60 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
61 |
62 | state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight
63 |
64 | state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T
65 | state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T
66 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
67 |
68 | state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T
69 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
70 |
71 | state_dict["ln_final.w"] = qwen.model.norm.weight
72 |
73 | state_dict["unembed.W_U"] = qwen.lm_head.weight.T
74 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
75 |
76 | return state_dict
77 |
--------------------------------------------------------------------------------
/transformer_lens/pretrained/weight_conversions/t5.py:
--------------------------------------------------------------------------------
1 | import einops
2 |
3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
4 |
5 |
6 | def convert_t5_weights(t5, cfg: HookedTransformerConfig):
7 | state_dict = {
8 | "embed.W_E": t5.encoder.embed_tokens.weight,
9 | "unembed.W_U": t5.encoder.embed_tokens.weight.T,
10 | "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0]
11 | .layer[0]
12 | .SelfAttention.relative_attention_bias.weight,
13 | }
14 |
15 | for l in range(cfg.n_layers):
16 | block = t5.encoder.block[l]
17 | state_dict[f"encoder.{l}.attn.W_Q"] = einops.rearrange(
18 | block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
19 | )
20 | state_dict[f"encoder.{l}.attn.W_K"] = einops.rearrange(
21 | block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
22 | )
23 |
24 | state_dict[f"encoder.{l}.attn.W_V"] = einops.rearrange(
25 | block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
26 | )
27 |
28 | state_dict[f"encoder.{l}.attn.W_O"] = einops.rearrange(
29 | block.layer[0].SelfAttention.o.weight,
30 | "m (i h) -> i h m",
31 | i=cfg.n_heads,
32 | )
33 | state_dict[f"encoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight
34 |
35 | # fixme DenseReluDense may be T5DenseGatedActDense instead
36 | state_dict[f"encoder.{l}.mlp.W_in"] = einops.rearrange(
37 | block.layer[1].DenseReluDense.wi.weight, "mlp model -> model mlp"
38 | )
39 |
40 | state_dict[f"encoder.{l}.mlp.W_out"] = einops.rearrange(
41 | block.layer[1].DenseReluDense.wo.weight, "model mlp -> mlp model"
42 | )
43 | state_dict[f"encoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight
44 |
45 | state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight
46 |
47 | state_dict["decoder.0.attn.rel_pos_bias.weight"] = (
48 | t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
49 | )
50 |
51 | for l in range(cfg.n_layers):
52 | block = t5.decoder.block[l]
53 | state_dict[f"decoder.{l}.attn.W_Q"] = einops.rearrange(
54 | block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
55 | )
56 |
57 | state_dict[f"decoder.{l}.attn.W_K"] = einops.rearrange(
58 | block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
59 | )
60 | state_dict[f"decoder.{l}.attn.W_V"] = einops.rearrange(
61 | block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
62 | )
63 |
64 | state_dict[f"decoder.{l}.attn.W_O"] = einops.rearrange(
65 | block.layer[0].SelfAttention.o.weight,
66 | "m (i h) -> i h m",
67 | i=cfg.n_heads,
68 | )
69 |
70 | state_dict[f"decoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight
71 |
72 | state_dict[f"decoder.{l}.cross_attn.W_Q"] = einops.rearrange(
73 | block.layer[1].EncDecAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
74 | )
75 |
76 | state_dict[f"decoder.{l}.cross_attn.W_K"] = einops.rearrange(
77 | block.layer[1].EncDecAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
78 | )
79 |
80 | state_dict[f"decoder.{l}.cross_attn.W_V"] = einops.rearrange(
81 | block.layer[1].EncDecAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
82 | )
83 | state_dict[f"decoder.{l}.cross_attn.W_O"] = einops.rearrange(
84 | block.layer[1].EncDecAttention.o.weight,
85 | "m (i h) -> i h m",
86 | i=cfg.n_heads,
87 | )
88 | state_dict[f"decoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight
89 |
90 | # fixme DenseReluDense may be T5DenseGatedActDense instead
91 | state_dict[f"decoder.{l}.mlp.W_in"] = einops.rearrange(
92 | block.layer[2].DenseReluDense.wi.weight, "mlp model -> model mlp"
93 | )
94 | state_dict[f"decoder.{l}.mlp.W_out"] = einops.rearrange(
95 | block.layer[2].DenseReluDense.wo.weight, "model mlp -> mlp model"
96 | )
97 | state_dict[f"decoder.{l}.ln3.w"] = block.layer[2].layer_norm.weight
98 |
99 | state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight
100 |
101 | return state_dict
102 |
--------------------------------------------------------------------------------
/transformer_lens/utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/transformer_lens/utilities/__init__.py
--------------------------------------------------------------------------------
/transformer_lens/utilities/activation_functions.py:
--------------------------------------------------------------------------------
1 | """Activation Functions.
2 |
3 | Utilities for interacting with all supported activation functions.
4 | """
5 | from typing import Callable, Dict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from transformer_lens.utils import gelu_fast, gelu_new, solu
11 |
12 | # Convenient type for the format of each activation function
13 | ActivationFunction = Callable[..., torch.Tensor]
14 |
15 | # All currently supported activation functions. To add a new function, simply
16 | # put the name of the function as the key, and the value as the actual callable.
17 | SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = {
18 | "solu": solu,
19 | "solu_ln": solu,
20 | "gelu_new": gelu_new,
21 | "gelu_fast": gelu_fast,
22 | "silu": F.silu,
23 | "relu": F.relu,
24 | "gelu": F.gelu,
25 | "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
26 | }
27 |
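A minimal usage sketch: looking up an activation function by its config name from SUPPORTED_ACTIVATIONS (assuming transformer_lens is importable in the current environment).

import torch

from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS

act_fn = SUPPORTED_ACTIVATIONS["silu"]
out = act_fn(torch.randn(4, 8))
assert out.shape == (4, 8)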
--------------------------------------------------------------------------------
/transformer_lens/utilities/addmm.py:
--------------------------------------------------------------------------------
1 | """Addmm
2 |
3 | Implementations of Addmm functions matching Huggingface implementations.
4 | """
5 | import torch
6 | from jaxtyping import Float
7 |
8 |
9 | def vanilla_addmm(
10 | input: Float[torch.Tensor, "... #o"], # Must be broadcastable to "m o"
11 | mat1: Float[torch.Tensor, "m n"],
12 | mat2: Float[torch.Tensor, "n o"],
13 | ) -> Float[torch.Tensor, "m o"]:
14 | """Typechecked version of torch.addmm.
15 |
16 | Note that both mat1 and mat2 *must* be 2d matrices.
17 | """
18 | return torch.addmm(input, mat1, mat2)
19 |
20 |
21 | def batch_addmm(
22 | bias: Float[torch.Tensor, "... #d_out"], # Must be broadcastable to "... d_out"
23 | weight: Float[torch.Tensor, "d_in d_out"],
24 | x: Float[torch.Tensor, "... d_in"],
25 | ) -> Float[torch.Tensor, "... d_out"]:
26 | """Fused add-multiply with support for batch dimensions.
27 |
28 | Must match the Huggingface Conv1D implementation exactly.
29 | https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106
30 | """
31 | n_output_features = weight.shape[-1]
32 | size_out = x.size()[:-1] + (n_output_features,)
33 | x = vanilla_addmm(bias, x.view(-1, x.size(-1)), weight)
34 | x = x.view(size_out)
35 | return x
36 |
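A minimal standalone sketch (toy shapes) of what batch_addmm computes: x @ weight + bias with arbitrary leading batch dimensions, implemented by flattening to 2D for torch.addmm and reshaping back.

import torch

batch, pos, d_in, d_out = 2, 3, 4, 5
bias = torch.randn(d_out)
weight = torch.randn(d_in, d_out)
x = torch.randn(batch, pos, d_in)

expected = x @ weight + bias
flattened = torch.addmm(bias, x.view(-1, d_in), weight).view(batch, pos, d_out)

assert torch.allclose(expected, flattened, atol=1e-5)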
--------------------------------------------------------------------------------
/transformer_lens/utilities/attention.py:
--------------------------------------------------------------------------------
1 | """Attention.
2 |
3 | Utilities for attention components.
4 | """
5 |
6 | import einops
7 | import torch
8 | import torch.nn.functional as F
9 | from jaxtyping import Float
10 |
11 |
12 | def simple_attn_linear(
13 | input: Float[torch.Tensor, "batch pos d_model"],
14 | w: Float[torch.Tensor, "head_index d_model d_head"],
15 | b: Float[torch.Tensor, "head_index d_head"],
16 | ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
17 | """Linear layer for attention calculation."""
18 |
19 | if input.device != w.device:
20 | w = w.to(input.device)
21 | if input.device != b.device:
22 | b = b.to(input.device)
23 |
24 | w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
25 | b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
26 |
27 | return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])
28 |
29 |
30 | def complex_attn_linear(
31 | input: Float[torch.Tensor, "batch pos head_index d_model"],
32 | w: Float[torch.Tensor, "head_index d_model d_head"],
33 | b: Float[torch.Tensor, "head_index d_head"],
34 | ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
35 | """Linear layer for attention calculation.
36 |
37 | This is almost the same as simple_attn_linear, but the input tensor has an extra head_index dimension. It is used when the input to each attention head is calculated separately.
38 | """
39 |
40 | # Add singleton dimensions for broadcasting
41 | input = einops.rearrange(
42 | input, "batch pos head_index d_model -> batch pos head_index d_model 1"
43 | )
44 | w = einops.rearrange(w, "head_index d_model d_head -> 1 1 head_index d_model d_head")
45 |
46 | # Element-wise multiplication and sum over the d_model dimension
47 | result = input * w
48 | result = result.sum(dim=-2)
49 | return result + b
50 |
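A minimal usage sketch of simple_attn_linear with toy shapes (assuming transformer_lens is importable), showing the [batch, pos, head_index, d_head] output layout.

import torch

from transformer_lens.utilities.attention import simple_attn_linear

batch, pos, n_heads, d_model, d_head = 2, 3, 4, 8, 2
x = torch.randn(batch, pos, d_model)
w = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)

out = simple_attn_linear(x, w, b)
assert out.shape == (batch, pos, n_heads, d_head)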
--------------------------------------------------------------------------------