├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .gitattributes ├── .gitconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug.md │ ├── compatibility.md │ ├── proposal.md │ └── question.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── checks.yml │ └── release.yml ├── .gitignore ├── .vscode ├── cspell.json ├── extensions.json └── settings.json ├── LICENSE ├── Main_Demo.ipynb ├── README.md ├── assets ├── rm_transformer_lens_logo.png └── transformer_lens_logo.png ├── debugging ├── comparing-to-huggingface.ipynb └── hf-tl-logit-comparator.ipynb ├── demos ├── ARENA_Content.ipynb ├── Activation_Patching_in_TL_Demo.ipynb ├── Attribution_Patching_Demo.ipynb ├── BERT.ipynb ├── Colab_Compatibility.ipynb ├── Config_Overhaul.ipynb ├── Exploratory_Analysis_Demo.ipynb ├── Grokking_Demo.ipynb ├── Head_Detector_Demo.ipynb ├── Interactive_Neuroscope.ipynb ├── LLaMA.ipynb ├── LLaMA2_GPU_Quantized.ipynb ├── LLaVA.ipynb ├── Main_Demo.ipynb ├── No_Position_Experiment.ipynb ├── Othello_GPT.ipynb ├── Patchscopes_Generation_Demo.ipynb ├── Qwen.ipynb ├── SVD_Interpreter_Demo.ipynb ├── Santa_Coder.ipynb ├── T5.ipynb ├── Tracr_to_Transformer_Lens_Demo.ipynb ├── conftest.py ├── doc_sanitize.cfg └── stable_lm.ipynb ├── docs ├── Makefile ├── README.md ├── make.bat ├── make_docs.py └── source │ ├── _static │ ├── TransformerLens_Diagram.svg │ └── transformer_lens_logo.png │ ├── apidoc_templates │ ├── module.rst_t │ ├── package.rst_t │ └── toc.rst_t │ ├── conf.py │ ├── content │ ├── citation.md │ ├── contributing.md │ ├── gallery.md │ ├── getting_started.md │ ├── getting_started_mech_interp.md │ ├── news │ │ └── release-2.0.md │ ├── special_cases.md │ └── tutorials.md │ ├── favicon.ico │ └── index.md ├── easy_transformer └── __init__.py ├── further_comments.md ├── makefile ├── poetry.lock ├── pyproject.toml ├── tests ├── acceptance │ ├── test_activation_cache.py │ ├── test_evals.py │ ├── test_hook_tokens.py │ ├── test_hooked_encoder.py │ ├── test_hooked_encoder_decoder.py │ ├── test_hooked_transformer.py │ ├── test_multi_gpu.py │ └── test_tokenizer_special_tokens.py ├── integration │ ├── test_attention_mask.py │ ├── test_cache_hook_names.py │ ├── test_cache_pos_slice.py │ ├── test_create_hooked_encoder.py │ ├── test_cross_entropy_loss.py │ ├── test_d_vocab.py │ ├── test_grouped_query_attention.py │ ├── test_head_detector.py │ ├── test_hooks.py │ ├── test_kv_cache.py │ ├── test_left_padding.py │ ├── test_loading_from_pretrained.py │ ├── test_match_huggingface.py │ ├── test_only_tokenizer.py │ ├── test_prepend_bos.py │ ├── test_start_at_layer.py │ ├── test_stop_at_layer.py │ ├── test_tokenization_methods.py │ └── test_utils_tokens.py ├── manual_checks │ ├── manual_checks_type_annotations.py │ └── manual_checks_typing.py └── unit │ ├── components │ ├── mlps │ │ ├── test_can_be_used_as_mlp.py │ │ ├── test_gated_mlp.py │ │ ├── test_mlp.py │ │ └── test_moe.py │ ├── test_abstract_attention.py │ └── test_attention.py │ ├── factored_matrix │ ├── test_constructor.py │ ├── test_get_item.py │ ├── test_multiply_by_factored_matrix.py │ ├── test_multiply_by_matrix.py │ ├── test_multiply_by_scalar.py │ ├── test_multiply_by_vector.py │ └── test_properties.py │ ├── factories │ ├── test_activation_function_factory.py │ └── test_mlp_factory.py │ ├── pretrained_weight_conversions │ └── test_neo.py │ ├── test_hook_points.py │ ├── test_hooked_root_module.py │ ├── test_hooked_transformer_config.py │ ├── test_loading_from_pretrained_utilities.py │ ├── test_make_docs.py │ ├── test_next_sentence_prediction.py │ ├── test_split_qkv.py │ ├── 
test_svd_interpreter.py │ ├── test_use_attn_result.py │ ├── test_utils.py │ └── utilities │ └── test_devices.py └── transformer_lens ├── ActivationCache.py ├── BertNextSentencePrediction.py ├── FactoredMatrix.py ├── HookedEncoder.py ├── HookedEncoderDecoder.py ├── HookedTransformer.py ├── HookedTransformerConfig.py ├── SVDInterpreter.py ├── __init__.py ├── components ├── __init__.py ├── abstract_attention.py ├── attention.py ├── bert_block.py ├── bert_embed.py ├── bert_mlm_head.py ├── bert_nsp_head.py ├── bert_pooler.py ├── embed.py ├── grouped_query_attention.py ├── layer_norm.py ├── layer_norm_pre.py ├── mlps │ ├── can_be_used_as_mlp.py │ ├── gated_mlp.py │ ├── gated_mlp_4bit.py │ ├── mlp.py │ └── moe.py ├── pos_embed.py ├── rms_norm.py ├── rms_norm_pre.py ├── t5_attention.py ├── t5_block.py ├── token_typed_embed.py ├── transformer_block.py └── unembed.py ├── evals.py ├── factories ├── activation_function_factory.py └── mlp_factory.py ├── head_detector.py ├── hook_points.py ├── loading_from_pretrained.py ├── past_key_value_caching.py ├── patching.py ├── pretrained ├── __init__.py └── weight_conversions │ ├── __init__.py │ ├── bert.py │ ├── bloom.py │ ├── coder.py │ ├── gemma.py │ ├── gpt2.py │ ├── gptj.py │ ├── llama.py │ ├── mingpt.py │ ├── mistral.py │ ├── mixtral.py │ ├── nanogpt.py │ ├── neel_solu_old.py │ ├── neo.py │ ├── neox.py │ ├── opt.py │ ├── phi.py │ ├── phi3.py │ ├── qwen.py │ ├── qwen2.py │ └── t5.py ├── train.py ├── utilities ├── __init__.py ├── activation_functions.py ├── addmm.py ├── attention.py └── devices.py └── utils.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # If .venv is already setup with python3.8, it will use python3.8. To use 3.11 remove it first. 2 | 3 | # Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present) 4 | # https://hub.docker.com/r/nvidia/cuda 5 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9 6 | 7 | ARG USERNAME 8 | ARG USER_UID=1000 9 | ARG USER_GID=$USER_UID 10 | 11 | # Create the user 12 | # https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user 13 | RUN groupadd --gid $USER_GID $USERNAME \ 14 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 15 | && usermod -a -G video user \ 16 | && apt-get update \ 17 | && apt-get install -y sudo \ 18 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ 19 | && chmod 0440 /etc/sudoers.d/$USERNAME 20 | 21 | # Install dependencies 22 | RUN apt-get update && \ 23 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \ 24 | software-properties-common && \ 25 | add-apt-repository -y ppa:deadsnakes/ppa && \ 26 | apt-get update && \ 27 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \ 28 | build-essential \ 29 | python3.11 \ 30 | python3.11-dev \ 31 | python3.11-distutils \ 32 | python3.11-venv \ 33 | curl \ 34 | git && \ 35 | # Update python3 default to point to python3.11 36 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \ 37 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 && \ 38 | update-alternatives --set python3 /usr/bin/python3.11 39 | 40 | # User the new user 41 | USER $USERNAME 42 | 43 | # Install poetry 44 | RUN curl -sSL https://install.python-poetry.org | python3.11 - 45 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: 
-------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.238.1/containers/python-3 3 | { 4 | "name": "Python 3", 5 | "build": { 6 | "dockerfile": "Dockerfile", 7 | "args": { 8 | "USERNAME": "user" 9 | } 10 | }, 11 | // Configure tool-specific properties. 12 | "customizations": { 13 | // Configure properties specific to VS Code. 14 | "vscode": { 15 | // Set *default* container specific settings.json values on container create. 16 | "settings": {}, 17 | // Add the IDs of extensions you want installed when the container is created. 18 | "extensions": [] 19 | } 20 | }, 21 | "containerUser": "user", 22 | // Install any dependencies 23 | "postCreateCommand": "poetry config virtualenvs.in-project true && poetry install --with dev" 24 | } -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb merge=nbdev-merge 2 | -------------------------------------------------------------------------------- /.gitconfig: -------------------------------------------------------------------------------- 1 | # Generated by nbdev_install_hooks 2 | # 3 | # If you need to disable this instrumentation do: 4 | # git config --local --unset include.path 5 | # 6 | # To restore: 7 | # git config --local include.path ../.gitconfig 8 | # 9 | [merge "nbdev-merge"] 10 | name = resolve conflicts with nbdev_fix 11 | driver = nbdev_merge %O %A %B %P 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Submit a bug report 4 | title: "[Bug Report] Bug title" 5 | 6 | --- 7 | 8 | If you are submitting a bug report, please fill in the following details and use the tag [bug]. 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Code example** 14 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 15 | 16 | **System Info** 17 | Describe the characteristic of your environment: 18 | * Describe how `transformer_lens` was installed (pip, docker, source, ...) 19 | * What OS are you using? (Linux, MacOS, Windows) 20 | * Python version (We support 3.7--3.10 currently) 21 | 22 | **Additional context** 23 | Add any other context about the problem here. 24 | 25 | ### Checklist 26 | 27 | - [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**) 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/compatibility.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Compatibility Report 3 | about: Submit a compatibility report 4 | title: "[Compatibility Report] Model ID" 5 | 6 | --- 7 | 8 | 15 | 16 | ## Model 17 | 18 | REPLACE_WITH_MODEL_ID 19 | 20 | - [ ] This model was incompatible when it was introduced to TransformerLens 21 | 22 | 25 | 26 | The model seems to have worked as of REPLACE_WITH_LAST_COMPATIBLE_VERSION_NUMBER. It first started 27 | showing signs of incompatibility in REPLACE_WITH_FIRST_INCOMPATIBLE_VERSION_NUMBER. 
28 | 29 | ### Example of some generations in transformers 30 | 31 | 32 | ### Code used to load the model in TransformerLens 33 | 34 | 35 | ### Example of some generations in TransformerLens 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/proposal.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Proposal 3 | about: Propose changes that are not bug fixes 4 | title: "[Proposal] Proposal title" 5 | --- 6 | 7 | ### Proposal 8 | 9 | A clear and concise description of the proposal. 10 | 11 | ### Motivation 12 | 13 | Please outline the motivation for the proposal. 14 | Is your feature request related to a problem? e.g., "I'm always frustrated when [...]". 15 | If this is related to another GitHub issue, please link here too. 16 | 17 | ### Pitch 18 | 19 | A clear and concise description of what you want to happen. It is useful if you relate proposals to existing features or the lack thereof, as well as to relevant mechanistic interpretability techniques or model architectures. 20 | 21 | ### Alternatives 22 | 23 | A clear and concise description of any alternative solutions or features you've considered, if any. 24 | 25 | ### Additional context 26 | 27 | Add any other context or screenshots about the feature request here. 28 | 29 | ### Checklist 30 | 31 | - [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question 4 | title: "[Question] Question title" 5 | --- 6 | 7 | 8 | ### Question 9 | 10 | If you're a beginner and have basic questions, you can ask them on various online forums such as: 11 | - The [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-1qosyh8g3-9bF3gamhLNJiqCL_QqLFrA) 12 | - The [Eleuther AI discord](https://discord.gg/zBGx3azzUn) 13 | - The [Mechanistic Interpretability Discord](https://discord.gg/wcuV4xnJ) 14 | 15 | Advanced/nontrivial questions, especially in areas where documentation is lacking, are very much welcome here. -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 8 | # Description 9 | 10 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. 11 | 12 | Fixes # (issue) 13 | 14 | ## Type of change 15 | 16 | Please delete options that are not relevant. 17 | 18 | - [ ] Bug fix (non-breaking change which fixes an issue) 19 | - [ ] New feature (non-breaking change which adds functionality) 20 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 21 | - [ ] This change requires a documentation update 22 | 23 | ### Screenshots 24 | Please attach before and after screenshots of the change if applicable.
25 | 26 | 36 | 37 | # Checklist: 38 | 39 | - [ ] I have commented my code, particularly in hard-to-understand areas 40 | - [ ] I have made corresponding changes to the documentation 41 | - [ ] My changes generate no new warnings 42 | - [ ] I have added tests that prove my fix is effective or that my feature works 43 | - [ ] New and existing unit tests pass locally with my changes 44 | - [ ] I have not rewritten tests relating to key interfaces which would affect backward compatibility 45 | 46 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | checks: 10 | name: Run checks workflow 11 | uses: TransformerLensOrg/TransformerLens/.github/workflows/checks.yml@main 12 | secrets: 13 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 14 | 15 | semver-parser: 16 | name: Parse the semantic version from the release 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Parse semver string 20 | id: semver_parser 21 | uses: booxmedialtd/ws-action-parse-semver@v1.4.7 22 | with: 23 | input_string: ${{ github.event.release.tag_name }} 24 | outputs: 25 | major: "${{ steps.semver_parser.outputs.major }}" 26 | minor: "${{ steps.semver_parser.outputs.minor }}" 27 | patch: "${{ steps.semver_parser.outputs.patch }}" 28 | semver: "${{ steps.semver_parser.outputs.fullversion }}" 29 | 30 | release-python: 31 | name: Release Python package to PyPi 32 | needs: 33 | - checks 34 | - semver-parser 35 | runs-on: ubuntu-latest 36 | steps: 37 | - uses: actions/checkout@v3 38 | - name: Install Poetry 39 | uses: snok/install-poetry@v1 40 | - name: Set up Python 41 | uses: actions/setup-python@v4 42 | with: 43 | python-version: '3.11' 44 | cache: 'poetry' 45 | - name: Poetry config 46 | run: poetry self add 'poethepoet[poetry_plugin]' 47 | - name: Install dependencies 48 | run: poetry install --with dev 49 | - name: Set the version 50 | run: poetry version ${{needs.semver-parser.outputs.semver}} 51 | - name: Build 52 | run: poetry build 53 | - name: Publish 54 | run: poetry publish 55 | env: 56 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN_PYPI }} 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | Testing_Notebook.ipynb 3 | transformer_lens/scratch.py 4 | .idea/ 5 | wandb* 6 | transformer_lens\.egg* 7 | MANIFEST.in 8 | settings.ini 9 | _proc 10 | core.py 11 | nbs 12 | _modidx.py 13 | .ipynb_checkpoints 14 | env 15 | dist/ 16 | docs/build 17 | .coverage* 18 | .Ds_Store 19 | .pylintrc 20 | docs/source/generated 21 | **.orig 22 | .venv 23 | -------------------------------------------------------------------------------- /.vscode/cspell.json: -------------------------------------------------------------------------------- 1 | { 2 | "language": "en,en-GB", 3 | "words": [ 4 | "accum", 5 | "adrià", 6 | "aengus", 7 | "allclose", 8 | "alonso", 9 | "arange", 10 | "argmax", 11 | "argmaxy", 12 | "autodiff", 13 | "autoregressive", 14 | "barez", 15 | "beartype", 16 | "belrose", 17 | "bertsimas", 18 | "biderman", 19 | "bilal", 20 | "bincount", 21 | "caxis", 22 | "checkpointed", 23 | "chughtai", 24 | "circuitsvis", 25 | "Codeparrot", 26 | "codespaces", 27 | "colab", 28 | "collectstart", 29 | "colour", 30 | "conmy", 31 | "cooney", 32 | "crfm", 33 | "cumsum", 34 | "datapoint", 35 | 
"dictmodel", 36 | "dimitris", 37 | "disconfirm", 38 | "dmitrii", 39 | "docstrings", 40 | "doctest", 41 | "doctree", 42 | "dtype", 43 | "dtypes", 44 | "einops", 45 | "elhage", 46 | "endoftext", 47 | "eqnarray", 48 | "esben", 49 | "evals", 50 | "explictly", 51 | "fazl", 52 | "firstpage", 53 | "fspath", 54 | "furo", 55 | "garriga", 56 | "gelu", 57 | "githubpages", 58 | "gptj", 59 | "halawi", 60 | "heimersheim", 61 | "howpublished", 62 | "huggingface", 63 | "icml", 64 | "idxs", 65 | "imshow", 66 | "interp", 67 | "interpretability", 68 | "ioannis", 69 | "ipynb", 70 | "isin", 71 | "isort", 72 | "janiak", 73 | "Janky", 74 | "jaxtyping", 75 | "jett", 76 | "kaiming", 77 | "keepdim", 78 | "kissane", 79 | "konstas", 80 | "kran", 81 | "lastpage", 82 | "layernorm", 83 | "ldim", 84 | "lieberum", 85 | "logits", 86 | "logsumexp", 87 | "mavor", 88 | "maxdepth", 89 | "mingpt", 90 | "nanda", 91 | "ndarray", 92 | "ndim", 93 | "neel", 94 | "neox", 95 | "nitpicky", 96 | "occurences", 97 | "olah", 98 | "openwebtext", 99 | "overcomplete", 100 | "Overriden", 101 | "pagename", 102 | "pauly", 103 | "pretrained", 104 | "probs", 105 | "producting", 106 | "pycln", 107 | "pypi", 108 | "pytest", 109 | "randn", 110 | "rdim", 111 | "relu", 112 | "resid", 113 | "rprint", 114 | "rtml", 115 | "rtol", 116 | "shortformer", 117 | "softmax", 118 | "softmaxing", 119 | "solu", 120 | "stas", 121 | "templatedir", 122 | "templatename", 123 | "toctree", 124 | "topk", 125 | "tqdm", 126 | "transformerlens", 127 | "tril", 128 | "triu", 129 | "troitskii", 130 | "unembed", 131 | "unembedded", 132 | "unembedding", 133 | "unigram", 134 | "unsqueeze", 135 | "virtualenvs", 136 | "visualisation", 137 | "xaxis", 138 | "yaxis" 139 | ] 140 | } 141 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "christian-kohler.path-intellisense", 4 | "davidanson.vscode-markdownlint", 5 | "donjayamanne.githistory", 6 | "donjayamanne.python-extension-pack", 7 | "github.copilot", 8 | "github.vscode-pull-request-github", 9 | "ionutvmi.path-autocomplete", 10 | "mikoz.autoflake-extension", 11 | "ms-python.isort", 12 | "ms-python.pylint", 13 | "ms-python.python", 14 | "ms-python.vscode-pylance", 15 | "ms-toolsai.jupyter-keymap", 16 | "ms-toolsai.jupyter-renderers", 17 | "ms-toolsai.jupyter", 18 | "richie5um2.vscode-sort-json", 19 | "stkb.rewrap", 20 | "streetsidesoftware.code-spell-checker-british-english", 21 | "streetsidesoftware.code-spell-checker", 22 | "tamasfe.even-better-toml", 23 | "yzhang.markdown-all-in-one" 24 | ] 25 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.defaultFormatter": "ms-python.black-formatter" 4 | }, 5 | "[toml]": { 6 | "editor.defaultFormatter": "tamasfe.even-better-toml" 7 | }, 8 | "editor.codeActionsOnSave": { 9 | "source.organizeImports": "explicit" 10 | }, 11 | "editor.formatOnSave": true, 12 | "editor.rulers": [100], 13 | "evenBetterToml.formatter.allowedBlankLines": 1, 14 | "evenBetterToml.formatter.arrayAutoCollapse": true, 15 | "evenBetterToml.formatter.arrayAutoExpand": true, 16 | "evenBetterToml.formatter.arrayTrailingComma": true, 17 | "evenBetterToml.formatter.columnWidth": 100, 18 | "evenBetterToml.formatter.compactArrays": true, 19 | "evenBetterToml.formatter.compactEntries": 
true, 20 | "evenBetterToml.formatter.compactInlineTables": true, 21 | "evenBetterToml.formatter.indentEntries": true, 22 | "evenBetterToml.formatter.indentString": " ", 23 | "evenBetterToml.formatter.indentTables": true, 24 | "evenBetterToml.formatter.inlineTableExpand": true, 25 | "evenBetterToml.formatter.reorderArrays": true, 26 | "evenBetterToml.formatter.reorderKeys": true, 27 | "evenBetterToml.formatter.trailingNewline": true, 28 | "evenBetterToml.schema.enabled": true, 29 | "evenBetterToml.schema.links": true, 30 | "evenBetterToml.syntax.semanticTokens": false, 31 | "mypy-type-checker.importStrategy": "fromEnvironment", 32 | "notebook.formatOnCellExecution": true, 33 | "notebook.formatOnSave.enabled": true, 34 | "pylint.importStrategy": "fromEnvironment", 35 | "python.testing.pytestArgs": [ 36 | "transformer_lens", 37 | ], 38 | "python.testing.pytestEnabled": true, 39 | "rewrap.autoWrap.enabled": true, 40 | "rewrap.reformat": true, 41 | "rewrap.wrappingColumn": 100, 42 | "mypy.runUsingActiveInterpreter": true, 43 | "editor.defaultFormatter": "ms-python.black-formatter", 44 | "black-formatter.args": [ 45 | "-l 100" 46 | ], 47 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 TransformerLensOrg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Main_Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Notice\n", 9 | "\n", 10 | "All demos have been moved to the `/demos` directory in the root of the project [View on GitHub](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos)\n" 11 | ] 12 | } 13 | ], 14 | "metadata": { 15 | "kernelspec": { 16 | "display_name": ".venv", 17 | "language": "python", 18 | "name": "python3" 19 | }, 20 | "language_info": { 21 | "name": "python", 22 | "version": "3.8.10" 23 | }, 24 | "orig_nbformat": 4, 25 | "vscode": { 26 | "interpreter": { 27 | "hash": "b8da8bee8e62267e58dea1070cf9758944e5c526027e75fbcceb99a4c665691a" 28 | } 29 | } 30 | }, 31 | "nbformat": 4, 32 | "nbformat_minor": 2 33 | } 34 | -------------------------------------------------------------------------------- /assets/rm_transformer_lens_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/assets/rm_transformer_lens_logo.png -------------------------------------------------------------------------------- /assets/transformer_lens_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/assets/transformer_lens_logo.png -------------------------------------------------------------------------------- /demos/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_collectstart(collector): 2 | """Ignore several mimetypes when comparing notebooks.""" 3 | if collector.fspath and collector.fspath.ext == ".ipynb": 4 | collector.skip_compare += ( 5 | "text/html", 6 | "application/javascript", 7 | "application/vnd.plotly.v1+json", # Plotly 8 | ) 9 | -------------------------------------------------------------------------------- /demos/doc_sanitize.cfg: -------------------------------------------------------------------------------- 1 | [regex1] 2 | regex: \d{1,2}/\d{1,2}/\d{2,4} 3 | replace: DATE-STAMP 4 | 5 | [regex2] 6 | regex: \d{2}:\d{2}:\d{2} 7 | replace: TIME-STAMP 8 | 9 | [regex3] 10 | regex: 0[xX][0-9a-fA-F]+ 11 | replace: HEX-CODE -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
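# For example, `make html` builds the HTML docs into $(BUILDDIR)/html and `make clean`
# removes previous build output; both are handled by the catch-all rule below.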
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Transformer-Lens Docs 3 | 4 | 5 | This repo contains the [website](https://TransformerLensOrg.github.io/TransformerLens/) for [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens). This site is currently in Beta and we are in the process of adding/editing information. 6 | 7 | The documentation uses Sphinx. However, the documentation is written in regular md, NOT rst. 8 | 9 | ## Build the Documentation 10 | 11 | First install the docs packages: 12 | 13 | ```bash 14 | poetry install --with docs 15 | ``` 16 | 17 | Then for hot-reloading, run this (note the model properties table won't hot reload, but everything 18 | else will): 19 | 20 | ```bash 21 | poetry run docs-hot-reload 22 | ``` 23 | 24 | Alternatively to build once, run: 25 | 26 | ```bash 27 | poetry run build-docs 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/transformer_lens_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/docs/source/_static/transformer_lens_logo.png -------------------------------------------------------------------------------- /docs/source/apidoc_templates/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- [basename] | join(' ') | e | heading }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} -------------------------------------------------------------------------------- /docs/source/apidoc_templates/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. 
toctree:: 10 | :maxdepth: 1 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- [pkgname] | join(" ") | e | heading }} 18 | {% else %} 19 | {{- [pkgname] | join(" ") | e | heading }} 20 | {% endif %} 21 | 22 | {%- if is_namespace %} 23 | .. py:module:: {{ pkgname }} 24 | {% endif %} 25 | 26 | {%- if modulefirst and not is_namespace %} 27 | {{ automodule(pkgname, automodule_options) }} 28 | {% endif %} 29 | 30 | {%- if submodules %} 31 | Submodules 32 | ---------- 33 | {% if separatemodules %} 34 | {{ toctree(submodules) }} 35 | {% else %} 36 | {%- for submodule in submodules %} 37 | {% if show_headings %} 38 | {{- [submodule, "module"] | join(" ") | e | heading(2) }} 39 | {% endif %} 40 | {{ automodule(submodule, automodule_options) }} 41 | {% endfor %} 42 | {%- endif %} 43 | {%- endif %} 44 | 45 | {%- if subpackages %} 46 | Subpackages 47 | ----------- 48 | 49 | {{ toctree(subpackages) }} 50 | {% endif %} 51 | -------------------------------------------------------------------------------- /docs/source/apidoc_templates/toc.rst_t: -------------------------------------------------------------------------------- 1 | Transformer Lens API 2 | -------------------- 3 | 4 | If browsing the docs for the first time, we recommend initially looking at :doc:`HookedTransformer 5 | ` and :doc:`ActivationCache 6 | ` modules. 7 | 8 | Contents 9 | ^^^^^^^^ 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | {% for docname in docnames %} 14 | {{ docname }} 15 | {%- endfor %} -------------------------------------------------------------------------------- /docs/source/content/citation.md: -------------------------------------------------------------------------------- 1 | 2 | # Citation 3 | 4 | Please cite this library as: 5 | 6 | ```BibTeX 7 | @misc{nanda2022transformerlens, 8 | title = {TransformerLens}, 9 | author = {Neel Nanda and Joseph Bloom}, 10 | year = {2022}, 11 | howpublished = {\url{https://github.com/TransformerLensOrg/TransformerLens}}, 12 | } 13 | ``` 14 | -------------------------------------------------------------------------------- /docs/source/content/gallery.md: -------------------------------------------------------------------------------- 1 | # Gallery 2 | 3 | Research done involving TransformerLens: 4 | 5 | - [Progress Measures for Grokking via Mechanistic 6 | Interpretability](https://arxiv.org/abs/2301.05217) (ICLR Spotlight, 2023) by Neel Nanda, Lawrence 7 | Chan, Tom Lieberum, Jess Smith, Jacob Steinhardt 8 | - [Finding Neurons in a Haystack: Case Studies with Sparse 9 | Probing](https://arxiv.org/abs/2305.01610) by Wes Gurnee, Neel Nanda, Matthew Pauly, Katherine 10 | Harvey, Dmitrii Troitskii, Dimitris Bertsimas 11 | - [Towards Automated Circuit Discovery for Mechanistic 12 | Interpretability](https://arxiv.org/abs/2304.14997) by Arthur Conmy, Augustine N. 
Mavor-Parker, 13 | Aengus Lynch, Stefan Heimersheim, Adrià Garriga-Alonso 14 | - [Actually, Othello-GPT Has A Linear Emergent World Representation](https://neelnanda.io/othello) 15 | by Neel Nanda 16 | - [A circuit for Python docstrings in a 4-layer attention-only 17 | transformer](https://www.alignmentforum.org/posts/u6KXXmKFbXfWzoAXn/a-circuit-for-python-docstrings-in-a-4-layer-attention-only) 18 | by Stefan Heimersheim and Jett Janiak 19 | - [A Toy Model of Universality](https://arxiv.org/abs/2302.03025) (ICML, 2023) by Bilal Chughtai, 20 | Lawrence Chan, Neel Nanda 21 | - [N2G: A Scalable Approach for Quantifying Interpretable Neuron Representations in Large Language 22 | Models](https://openreview.net/forum?id=ZB6bK6MTYq) (2023, ICLR Workshop RTML) by Alex Foote, Neel 23 | Nanda, Esben Kran, Ioannis Konstas, Fazl Barez 24 | - [Eliciting Latent Predictions from Transformers with the Tuned 25 | Lens](https://arxiv.org/abs/2303.08112) by Nora Belrose, Zach Furman, Logan Smith, Danny Halawi, 26 | Igor Ostrovsky, Lev McKinney, Stella Biderman, Jacob Steinhardt 27 | 28 | User-contributed examples of the library being used in action: 29 | 30 | - [Induction Heads Phase Change 31 | Replication](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb): 32 | A partial replication of [In-Context Learning and Induction 33 | Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) 34 | from Connor Kissane 35 | - [Decision Transformer 36 | Interpretability](https://github.com/jbloomAus/DecisionTransformerInterpretability): A set of 37 | scripts for training decision transformers which use TransformerLens to view intermediate 38 | activations, and to perform attribution and ablations. A write-up of the initial work can be found 39 | [here](https://www.lesswrong.com/posts/bBuBDJBYHt39Q5zZy/decision-transformer-interpretability). -------------------------------------------------------------------------------- /docs/source/content/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | **Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works and its basic features**. 4 | 5 | To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)! 6 | 7 | Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **Check out my [list of concrete open problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit) to figure out where to start.** It begins with advice on skilling up, and key resources to check out. 8 | 9 | If you're new to transformers, check out my [what is a transformer tutorial](https://neelnanda.io/transformer-tutorial) and [tutorial on coding GPT-2 from scratch](https://neelnanda.io/transformer-tutorial-2) (with [an accompanying template](https://neelnanda.io/transformer-template) to write one yourself)! 10 | 11 | ## Advice for Reading the Code 12 | 13 | One significant design decision made was to have a single transformer implementation that could support a range of subtly different GPT-style models.
This has the upside of interpretability code just working for arbitrary models when you change the model name in `HookedTransformer.from_pretrained`! But it has the significant downside that the code implementing the model (in `HookedTransformer.py` and `components.py`) can be difficult to read. I recommend starting with my [Clean Transformer Demo](https://neelnanda.io/transformer-solution), which is a clean, minimal implementation of GPT-2 with the same internal architecture and activation names as HookedTransformer, but is significantly clearer and better documented. 14 | 15 | ## Installation 16 | 17 | `pip install git+https://github.com/TransformerLensOrg/TransformerLens` 18 | 19 | Import the library with `import transformer_lens`. 20 | 21 | (Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/TransformerLensOrg/TransformerLens@v1`.) 22 | 23 | ## Hugging Face Gated Access 24 | 25 | Some of the models available in TransformerLens require gated access. Luckily, TransformerLens provides a way to access those models via an environment variable: simply set your access token (found [here](https://huggingface.co/settings/tokens)) as `HF_TOKEN` in your environment. 26 | 27 | You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. If you attempt to use one of these models before you have accepted the related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. As of 23/4/24, the current list of gated models supported by TransformerLens is as follows: 28 | 29 | * https://huggingface.co/mistralai/Mixtral-8x7B-v0.1 30 | * https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 31 | * https://huggingface.co/mistralai/Mistral-7B-v0.1 32 | -------------------------------------------------------------------------------- /docs/source/content/getting_started_mech_interp.md: -------------------------------------------------------------------------------- 1 | # Getting Started in Mechanistic Interpretability 2 | 3 | Mechanistic interpretability is a very young and small field, and there are a _lot_ of open 4 | problems. This means there's both a lot of low-hanging fruit, and that the bar for entry is low - if 5 | you would like to help, please try working on one! The standard answer to "why has no one done this 6 | yet" is just that there aren't enough people! Key resources: 7 | 8 | - [A Guide to Getting Started in Mechanistic Interpretability](https://neelnanda.io/getting-started) 9 | - [ARENA Mechanistic Interpretability Tutorials](https://arena-ch1-transformers.streamlit.app/) from 10 | Callum McDougall. A comprehensive practical introduction to mech interp, written in 11 | TransformerLens - full of snippets to copy and they come with exercises and solutions!
Notable 12 | tutorials: 13 | - [Coding GPT-2 from 14 | scratch](https://arena-ch1-transformers.streamlit.app/[1.1]_Transformer_from_Scratch), with 15 | accompanying video tutorial from me ([1](https://neelnanda.io/transformer-tutorial) 16 | [2](https://neelnanda.io/transformer-tutorial-2)) - a good introduction to transformers 17 | - [Introduction to Mech Interp and 18 | TransformerLens](https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp): An 19 | introduction to TransformerLens and mech interp via studying induction heads. Covers the 20 | foundational concepts of the library 21 | - [Indirect Object 22 | Identification](https://arena-ch1-transformers.streamlit.app/[1.3]_Indirect_Object_Identification): 23 | a replication of interpretability in the wild, that covers standard techniques in mech interp 24 | such as [direct logit 25 | attribution](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=disz2gTx-jooAcR0a5r8e7LZ), 26 | [activation patching and path 27 | patching](https://www.lesswrong.com/posts/xh85KbTFhbCz7taD4/how-to-think-about-activation-patching) 28 | - [Mech Interp Paper Reading List](https://neelnanda.io/paper-list) 29 | - [200 Concrete Open Problems in Mechanistic 30 | Interpretability](https://neelnanda.io/concrete-open-problems) 31 | - [A Comprehensive Mechanistic Interpretability Explainer](https://neelnanda.io/glossary): To look 32 | up all the jargon and unfamiliar terms you're going to come across! 33 | - [Neel Nanda's Youtube channel](https://www.youtube.com/channel/UCBMJ0D-omcRay8dh4QT0doQ): A range 34 | of mech interp video content, including [paper 35 | walkthroughs](https://www.youtube.com/watch?v=KV5gbOmHbjU&list=PL7m7hLIqA0hpsJYYhlt1WbHHgdfRLM2eY&index=1), 36 | and [walkthroughs of doing 37 | research](https://www.youtube.com/watch?v=yo4QvDn-vsU&list=PL7m7hLIqA0hr4dVOgjNwP2zjQGVHKeB7T) -------------------------------------------------------------------------------- /docs/source/content/special_cases.md: -------------------------------------------------------------------------------- 1 | # Special Cases 2 | 3 | ## Mixture of Experts error rates 4 | Due to the Top-K gating performed in the hidden layer of Mixture of Experts models, small errors can be amplified 5 | greatly in cases where a different expert is selected, which leads to a higher than normal variance in the error rate 6 | of the final logits. In testing done on Mixtral running in half precision, the standard deviation of the absolute error 7 | rate of the logits compared to those from the default model was found to be around 2e-3. 8 | 9 | There are two main ways to mitigate this: 10 | 1. Disable preprocessing options by using `HookedTransformer.from_pretrained_no_processing` instead of `HookedTransformer.from_pretrained` 11 | 2. 
Increase the precision of the data type used in the model 12 | -------------------------------------------------------------------------------- /docs/source/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/docs/source/favicon.ico -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide-toc: true 3 | firstpage: 4 | lastpage: 5 | --- 6 | 7 | # TransformerLens 8 | 9 | (Formerly known as EasyTransformer) [![Pypi](https://img.shields.io/pypi/v/transformer-lens)](https://pypi.org/project/transformer-lens/) 10 | 11 | ## A Library for Mechanistic Interpretability of Generative Language Models 12 | 13 | This is a library for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models. The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! 14 | 15 | TransformerLens lets you load in an open source language model, like GPT-2, and exposes the internal activations of the model to you. You can cache any internal activation in the model, and add in functions to edit, remove or replace these activations as the model runs. The core design principle I've followed is to enable exploratory analysis. One of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state. Part of what I aimed for is to make *my* experience of doing research easier and more fun, hopefully this transfers to you! 16 | 17 | I used to work for the [Anthropic interpretability team](https://transformer-circuits.pub/), and I wrote this library because after I left and tried doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! One of the great things about mechanistic interpretability is that you don't need large models or tons of compute. There are lots of important open problems that can be solved with a small model in a Colab notebook! 18 | 19 | The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research! 
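To make that concrete, here is a minimal sketch of the core cache-and-hook workflow. The model name, the hook name and the choice to zero out a single attention head are purely illustrative:

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
prompt = "The quick brown fox jumps over the lazy dog"

# Run the model and cache every internal activation.
logits, cache = model.run_with_cache(prompt)
layer0_pattern = cache["blocks.0.attn.hook_pattern"]  # [batch, head, query_pos, key_pos]

# Intervene on the same activation as the model runs: here, zero out head 7 in layer 0.
def zero_ablate_head(pattern, hook, head_index=7):
    pattern[:, head_index, :, :] = 0.0
    return pattern

ablated_logits = model.run_with_hooks(
    prompt,
    fwd_hooks=[("blocks.0.attn.hook_pattern", zero_ablate_head)],
)
```

The full set of hook names available for a model is simply the set of keys of `cache`.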
20 | 21 | A great place to start is to take a look at a helpful diagram of [all weight matrices and activation tensors with TransformerLens notation](_static/TransformerLens_Diagram.svg) courtesy of [Austin Kozlowski](https://github.com/akozlo). Another helpful tool to help you get going as quickly as possible is our [Colab Compatability Demo](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos/Colab_Compatibility.ipynb), which will give you a good idea of what you can do in various Colab environments. 22 | 23 | ```{toctree} 24 | :hidden: 25 | :caption: Introduction 26 | 27 | content/getting_started 28 | content/getting_started_mech_interp 29 | content/gallery 30 | ``` 31 | 32 | ```{toctree} 33 | :hidden: 34 | :caption: Documentation 35 | 36 | generated/code/modules 37 | generated/model_properties_table.md 38 | ``` 39 | 40 | ```{toctree} 41 | :hidden: 42 | :caption: Resources 43 | 44 | content/tutorials 45 | content/citation 46 | content/contributing 47 | generated/demos/Main_Demo 48 | generated/demos/Exploratory_Analysis_Demo 49 | content/special_cases 50 | ``` 51 | 52 | ```{toctree} 53 | :hidden: 54 | :caption: News 55 | 56 | content/news/release-2.0 57 | ``` 58 | 59 | ```{toctree} 60 | :hidden: 61 | :caption: Development 62 | 63 | content/contributing 64 | Code Coverage 65 | Github 66 | ``` 67 | -------------------------------------------------------------------------------- /easy_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.warning("DEPRECATED: Library has been renamed, import transformer_lens instead") 4 | from transformer_lens import * 5 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | format: 2 | poetry run pycln --all . --exclude "__init__.py" 3 | poetry run isort format . 4 | poetry run black . 5 | 6 | check-format: 7 | poetry run pycln --check --all . --exclude "__init__.py" 8 | poetry run isort --check-only . 9 | poetry run black --check . 
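# A typical local workflow (assuming the dev dependencies are installed via `poetry install --with dev`):
#   make format        # auto-fix formatting before committing
#   make check-format  # verify formatting without modifying files
#   make unit-test     # run only the fast unit tests
#   make test          # full suite: unit, acceptance, docstring and notebook tests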
10 | 11 | unit-test: 12 | poetry run pytest tests/unit 13 | 14 | integration-test: 15 | poetry run pytest tests/integration 16 | 17 | acceptance-test: 18 | poetry run pytest tests/acceptance 19 | 20 | coverage-report-test: 21 | poetry run pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/unit tests/integration tests/acceptance 22 | 23 | docstring-test: 24 | poetry run pytest transformer_lens/ 25 | 26 | notebook-test: 27 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb 28 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb 29 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb 30 | 31 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb 32 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb 33 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb 34 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb 35 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Othello_GPT.ipynb 36 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Qwen.ipynb 37 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Santa_Coder.ipynb 38 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Stable_Lm.ipynb 39 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb 40 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb 41 | 42 | # Contains failing cells 43 | 44 | # Causes CI to hang 45 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb 46 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb 47 | poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb 48 | 49 | test: 50 | make unit-test 51 | make acceptance-test 52 | make docstring-test 53 | make notebook-test 54 | 55 | docs-hot-reload: 56 | poetry run docs-hot-reload 57 | 58 | build-docs: 59 | poetry run build-docs 60 | -------------------------------------------------------------------------------- /tests/acceptance/test_evals.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from transformer_lens.evals import IOIDataset, ioi_eval 4 | from transformer_lens.HookedTransformer import HookedTransformer 5 | 6 | 7 | @pytest.fixture(scope="module") 8 | def model(): 9 | return HookedTransformer.from_pretrained("gpt2-small") 10 | 11 | 12 | def test_basic_ioi_eval(model): 13 | """ 14 | Test IOI evaluation with default dataset and settings. 15 | """ 16 | results = ioi_eval(model, num_samples=100) 17 | assert results["Accuracy"] >= 0.99 18 | 19 | 20 | def test_symmetric_samples(model): 21 | """ 22 | Test IOI evaluation with symmetric=True so prompts are in symmetric pairs. 23 | """ 24 | ds = IOIDataset(tokenizer=model.tokenizer, num_samples=100, symmetric=True) 25 | results = ioi_eval(model, dataset=ds) 26 | assert results["Logit Difference"] > 2.0 27 | assert results["Accuracy"] > 0.9 28 | 29 | 30 | def test_custom_dataset_ioi_eval(model): 31 | """ 32 | Test IOI eval with custom dataset using different templates, names, and objects. 
33 | """ 34 | ds = IOIDataset( 35 | tokenizer=model.tokenizer, 36 | num_samples=100, 37 | templates=["[A] met with [B]. [B] gave the [OBJECT] to [A]"], 38 | names=["Alice", "Bob", "Charlie"], 39 | nouns={"OBJECT": ["ball", "book"]}, 40 | ) 41 | results = ioi_eval(model, dataset=ds) 42 | assert results["Logit Difference"] > 2.0 43 | assert results["Accuracy"] >= 0.99 44 | 45 | 46 | def test_multitoken_names_ioi_eval(model): 47 | """ 48 | Test the IOI evaluation with multi-token names in the dataset. 49 | """ 50 | ds = IOIDataset( 51 | tokenizer=model.tokenizer, 52 | num_samples=100, 53 | names=["John Smith", "John Doe"], 54 | ) 55 | results = ioi_eval(model, dataset=ds) 56 | assert results["Logit Difference"] > 2.0 57 | assert results["Accuracy"] >= 0.99 58 | 59 | 60 | def test_inverted_template(model): 61 | """ 62 | Test IOI eval with an unnatural template (BAAA). 63 | This should result in a negative logit difference and very low accuracy. 64 | """ 65 | ds = IOIDataset( 66 | tokenizer=model.tokenizer, 67 | num_samples=100, 68 | templates=["[B] met with [A]. [A] said hello to [A]"], 69 | ) 70 | results = ioi_eval(model, dataset=ds) 71 | assert results["Logit Difference"] < -2.0 72 | assert results["Accuracy"] <= 0.01 73 | -------------------------------------------------------------------------------- /tests/acceptance/test_hook_tokens.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import functools 4 | 5 | import torch as t 6 | from jaxtyping import Int 7 | 8 | from transformer_lens import HookedTransformer, HookedTransformerConfig 9 | from transformer_lens.hook_points import HookPoint 10 | 11 | 12 | def test_patch_tokens(): 13 | # Define small transformer 14 | cfg = HookedTransformerConfig( 15 | n_layers=1, 16 | d_mlp=10, 17 | d_model=10, 18 | d_head=5, 19 | n_heads=2, 20 | n_ctx=20, 21 | act_fn="relu", 22 | tokenizer_name="gpt2", 23 | use_hook_tokens=True, 24 | ) 25 | model = HookedTransformer(cfg=cfg) 26 | 27 | # Define short prompt, and a token to replace the first token with (note this is index 1, because BOS) 28 | prompt = "Hello World!" 29 | modified_prompt = "Hi World!" 
30 | new_first_token = model.to_single_token("Hi") 31 | 32 | # Define hook function to alter the first token 33 | def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int): 34 | assert ( 35 | tokens[0, 0].item() != new_first_token 36 | ) # Need new_first_token to be different from original 37 | tokens[0, 0] = new_first_token 38 | return tokens 39 | 40 | # Run with hooks 41 | out_from_hook = model.run_with_hooks( 42 | prompt, 43 | prepend_bos=False, 44 | fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))], 45 | ) 46 | 47 | out_direct = model(modified_prompt, prepend_bos=False) 48 | 49 | t.testing.assert_close(out_from_hook, out_direct) 50 | -------------------------------------------------------------------------------- /tests/acceptance/test_tokenizer_special_tokens.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | import transformer_lens.loading_from_pretrained as loading 4 | from transformer_lens import HookedTransformer, HookedTransformerConfig 5 | 6 | # Get's tedious typing these out everytime I want to sweep over all the distinct small models 7 | MODEL_TESTING_LIST = [ 8 | "solu-1l", 9 | "gpt2-small", 10 | "gpt-neo-125M", 11 | "opt-125m", 12 | "opt-30b", 13 | "stanford-gpt2-small-a", 14 | "pythia-70m", 15 | ] 16 | 17 | 18 | def test_d_vocab_from_tokenizer(): 19 | cfg = HookedTransformerConfig( 20 | n_layers=1, d_mlp=10, d_model=10, d_head=5, n_heads=2, n_ctx=20, act_fn="relu" 21 | ) 22 | test_string = "a fish." 23 | # Test tokenizers for different models 24 | for model_name in MODEL_TESTING_LIST: 25 | if model_name == "solu-1l": 26 | tokenizer_name = "NeelNanda/gpt-neox-tokenizer-digits" 27 | else: 28 | tokenizer_name = loading.get_official_model_name(model_name) 29 | 30 | model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name)) 31 | 32 | tokens_with_bos = model.to_tokens(test_string) 33 | tokens_without_bos = model.to_tokens(test_string, prepend_bos=False) 34 | 35 | # Check that the lengths are different by one 36 | assert ( 37 | tokens_with_bos.shape[-1] == tokens_without_bos.shape[-1] + 1 38 | ), "BOS Token not added when expected" 39 | # Check that we don't have BOS when we disable the flag 40 | assert ( 41 | tokens_without_bos.squeeze()[0] != model.tokenizer.bos_token_id 42 | ), "BOS token is present when it shouldn't be" 43 | -------------------------------------------------------------------------------- /tests/integration/test_attention_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_lens import utils 4 | from transformer_lens.HookedTransformer import HookedTransformer 5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 6 | 7 | 8 | def test_attention_mask(): 9 | # Verify the attention mask attends properly, including for low attention scores 10 | cfg = HookedTransformerConfig( 11 | d_head=1, 12 | d_model=12, 13 | d_vocab=2, 14 | n_ctx=5, 15 | n_layers=1, 16 | attn_only=True, 17 | attention_dir="causal", 18 | ) 19 | model = HookedTransformer(cfg) 20 | input_length = 5 21 | input = torch.ones((1, input_length), dtype=torch.int64) 22 | layer = 0 23 | low_attn_score = 1e-6 24 | ones_input_matrix = torch.ones((input_length, input_length)) 25 | masked = torch.triu(ones_input_matrix, diagonal=1).bool() 26 | 27 | def attn_scores_hook(attn_scores, hook): 28 | assert torch.all( 29 | 
attn_scores[:, :, masked] == float("-inf") 30 | ), "Attention scores excluded by the mask are not being set to -inf" 31 | 32 | # Set low attention scores that are attended to by the mask 33 | attn_scores[:, :, ~masked] = low_attn_score 34 | 35 | return attn_scores 36 | 37 | def attn_hook(attn, hook): 38 | assert torch.all(attn[:, :, masked] == 0), "Attention pattern attends outside the mask" 39 | 40 | return attn 41 | 42 | fwd_hooks = [ 43 | (utils.get_act_name("attn_scores", layer), attn_scores_hook), 44 | (utils.get_act_name("attn", layer), attn_hook), 45 | ] 46 | 47 | model.run_with_hooks(input, fwd_hooks=fwd_hooks) 48 | 49 | 50 | def test_masked_tokens(): 51 | """Test that masking tokens works as expected.""" 52 | MODEL = "solu-1l" 53 | prompts = [ 54 | "Hello world!", 55 | "The quick brown fox jumps over the lazy dog.", 56 | ] 57 | model = HookedTransformer.from_pretrained(MODEL) 58 | tokens = model.to_tokens(prompts) 59 | 60 | # Part 1: If the mask is all ones, the output should be the same as if there was no mask. 61 | full_mask = torch.ones_like(tokens) 62 | no_mask_out = model(tokens) 63 | full_mask_out = model(tokens, attention_mask=full_mask) 64 | assert torch.allclose(no_mask_out, full_mask_out), "Full mask should be equivalent to no mask" 65 | 66 | # Part 2: If the mask has a column of zeros, the output should be the same as if that token 67 | # position was removed from the input. 68 | remove_tok_idx = 2 69 | edited_tokens = torch.cat([tokens[:, :remove_tok_idx], tokens[:, remove_tok_idx + 1 :]], dim=1) 70 | edited_mask = full_mask.clone() 71 | edited_mask[:, remove_tok_idx] = 0 72 | edited_no_mask_out = model(edited_tokens) 73 | edited_mask_out = model(tokens, attention_mask=edited_mask) 74 | edited_mask_out = torch.cat( 75 | [edited_mask_out[:, :remove_tok_idx], edited_mask_out[:, remove_tok_idx + 1 :]], dim=1 76 | ) 77 | assert torch.allclose( 78 | edited_no_mask_out, edited_mask_out, atol=1e-4 79 | ), "Edited mask should be equivalent to no mask" 80 | -------------------------------------------------------------------------------- /tests/integration/test_cache_hook_names.py: -------------------------------------------------------------------------------- 1 | from transformer_lens import HookedTransformer 2 | 3 | MODEL = "solu-1l" 4 | 5 | prompt = "Hello World!" 
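# act_names_in_cache (defined below) enumerates every hook name expected in the activation cache for this 1-layer solu model, in forward-pass order.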
6 | model = HookedTransformer.from_pretrained(MODEL) 7 | 8 | act_names_in_cache = [ 9 | "hook_embed", 10 | "hook_pos_embed", 11 | "blocks.0.hook_resid_pre", 12 | "blocks.0.ln1.hook_scale", 13 | "blocks.0.ln1.hook_normalized", 14 | "blocks.0.attn.hook_q", 15 | "blocks.0.attn.hook_k", 16 | "blocks.0.attn.hook_v", 17 | "blocks.0.attn.hook_attn_scores", 18 | "blocks.0.attn.hook_pattern", 19 | "blocks.0.attn.hook_z", 20 | "blocks.0.hook_attn_out", 21 | "blocks.0.hook_resid_mid", 22 | "blocks.0.ln2.hook_scale", 23 | "blocks.0.ln2.hook_normalized", 24 | "blocks.0.mlp.hook_pre", 25 | "blocks.0.mlp.hook_mid", 26 | "blocks.0.mlp.ln.hook_scale", 27 | "blocks.0.mlp.ln.hook_normalized", 28 | "blocks.0.mlp.hook_post", 29 | "blocks.0.hook_mlp_out", 30 | "blocks.0.hook_resid_post", 31 | "ln_final.hook_scale", 32 | "ln_final.hook_normalized", 33 | ] 34 | 35 | 36 | def test_cache_hook_names(): 37 | _, cache = model.run_with_cache(prompt) 38 | assert list(cache.keys()) == act_names_in_cache 39 | -------------------------------------------------------------------------------- /tests/integration/test_create_hooked_encoder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoTokenizer, BertTokenizerFast 3 | 4 | from transformer_lens import HookedEncoder, HookedTransformerConfig 5 | 6 | 7 | @pytest.fixture 8 | def cfg(): 9 | return HookedTransformerConfig(d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu") 10 | 11 | 12 | def test_pass_tokenizer(cfg): 13 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 14 | model = HookedEncoder(cfg, tokenizer=tokenizer) 15 | assert model.tokenizer == tokenizer 16 | 17 | 18 | def test_load_tokenizer_from_config(cfg): 19 | cfg.tokenizer_name = "bert-base-cased" 20 | model = HookedEncoder(cfg) 21 | assert isinstance(model.tokenizer, BertTokenizerFast) 22 | 23 | 24 | def test_load_without_tokenizer(cfg): 25 | cfg.d_vocab = 22 26 | model = HookedEncoder(cfg) 27 | assert model.tokenizer is None 28 | 29 | 30 | def test_cannot_load_without_tokenizer_or_d_vocab(cfg): 31 | with pytest.raises(AssertionError) as e: 32 | HookedEncoder(cfg) 33 | assert "Must provide a tokenizer if d_vocab is not provided" in str(e.value) 34 | -------------------------------------------------------------------------------- /tests/integration/test_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_lens.HookedTransformer import HookedTransformer 4 | 5 | 6 | def test_cross_entropy_attention_mask(): 7 | """Check that adding a bunch of masked tokens to the input does not change the loss.""" 8 | MODEL = "solu-1l" 9 | model = HookedTransformer.from_pretrained(MODEL) 10 | 11 | # Step 1: Get the default loss on a prompt 12 | prompt = ["The quick brown fox jumps over the lazy dog."] 13 | default_tokens = model.to_tokens(prompt) 14 | default_attention_mask = torch.ones_like(default_tokens) 15 | default_loss = model(default_tokens, return_type="loss") 16 | ones_mask_loss = model( 17 | default_tokens, attention_mask=default_attention_mask, return_type="loss" 18 | ) 19 | assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6) 20 | 21 | # Step 2: Get the loss when we add some extra tokens to the input and set their attention mask 22 | # to zero 23 | extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."] 24 | extra_tokens = model.to_tokens(extra_prompt) 25 | extra_zeros_attention_mask = 
torch.zeros_like(extra_tokens) 26 | 27 | combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1) 28 | combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1) 29 | combined_masked_loss = model( 30 | combined_tokens, attention_mask=combined_attention_mask, return_type="loss" 31 | ) 32 | assert torch.allclose(default_loss, combined_masked_loss) 33 | -------------------------------------------------------------------------------- /tests/integration/test_d_vocab.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | from transformer_lens import HookedTransformer, HookedTransformerConfig 4 | 5 | 6 | def test_d_vocab_from_tokenizer(): 7 | cfg = HookedTransformerConfig( 8 | n_layers=1, d_mlp=10, d_model=10, d_head=5, n_heads=2, n_ctx=20, act_fn="relu" 9 | ) 10 | model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained("gpt2")) 11 | assert model.cfg.d_vocab == 50257 12 | assert model.cfg.d_vocab_out == 50257 13 | 14 | 15 | def test_d_vocab_from_tokenizer_name(): 16 | cfg = HookedTransformerConfig( 17 | n_layers=1, 18 | d_mlp=10, 19 | d_model=10, 20 | d_head=5, 21 | n_heads=2, 22 | n_ctx=20, 23 | act_fn="relu", 24 | tokenizer_name="gpt2", 25 | ) 26 | model = HookedTransformer(cfg=cfg) 27 | assert model.cfg.d_vocab == 50257 28 | assert model.cfg.d_vocab_out == 50257 29 | 30 | 31 | def test_d_vocab_out_set(): 32 | cfg = HookedTransformerConfig( 33 | n_layers=1, 34 | d_mlp=10, 35 | d_model=10, 36 | d_head=5, 37 | n_heads=2, 38 | n_ctx=20, 39 | act_fn="relu", 40 | d_vocab=100, 41 | d_vocab_out=90, 42 | ) 43 | model = HookedTransformer(cfg=cfg) 44 | assert model.cfg.d_vocab == 100 45 | assert model.cfg.d_vocab_out == 90 46 | 47 | 48 | def test_d_vocab_out_set_d_vocab_infer(): 49 | cfg = HookedTransformerConfig( 50 | n_layers=1, 51 | d_mlp=10, 52 | d_model=10, 53 | d_head=5, 54 | n_heads=2, 55 | n_ctx=20, 56 | act_fn="relu", 57 | d_vocab_out=90, 58 | tokenizer_name="gpt2", 59 | ) 60 | model = HookedTransformer(cfg=cfg) 61 | assert model.cfg.d_vocab == 50257 62 | assert model.cfg.d_vocab_out == 90 63 | -------------------------------------------------------------------------------- /tests/integration/test_loading_from_pretrained.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for loading_from_pretrained, verifying that get_basic_config returns a config with the core architecture fields (d_model, n_layers, etc.) populated for a known model.
3 | """ 4 | 5 | from transformer_lens import loading_from_pretrained as loading 6 | 7 | 8 | def test_get_basic_config(): 9 | cfg = loading.get_basic_config("gpt2-small") 10 | assert cfg.d_model 11 | assert cfg.layer_norm_eps 12 | assert cfg.d_vocab 13 | assert cfg.init_range 14 | assert cfg.n_ctx 15 | assert cfg.d_head 16 | assert cfg.d_mlp 17 | assert cfg.n_heads 18 | assert cfg.n_layers 19 | -------------------------------------------------------------------------------- /tests/integration/test_match_huggingface.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | import torch 5 | from transformers import AutoModelForCausalLM 6 | 7 | from transformer_lens import HookedTransformer 8 | 9 | 10 | class TestMatchHuggingFace: 11 | # fixtures 12 | @pytest.fixture(scope="class", params=["gpt2"]) 13 | def model_name(self, request): 14 | return request.param 15 | 16 | # tests 17 | def test_compare_huggingface_mlp_match_local_implementation(self, model_name): 18 | tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device="cpu") 19 | hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu") 20 | tensor_shape = (3, 5, tl_model.cfg.d_model) 21 | test_tensor = torch.randn(tensor_shape) 22 | 23 | for layer_n in range(len(tl_model.blocks)): 24 | tl_out = tl_model.blocks[layer_n].mlp(test_tensor) 25 | hf_out = hf_model.transformer.h[layer_n].mlp(test_tensor) 26 | 27 | assert torch.sum(tl_out == hf_out) == math.prod(tensor_shape) 28 | 29 | def test_compare_huggingface_attention_match_local_implementation(self, model_name): 30 | tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device="cpu") 31 | hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu") 32 | batch, pos, d_model = 3, 5, tl_model.cfg.d_model 33 | input = torch.randn(batch, pos, d_model) 34 | 35 | for layer_n in range(len(tl_model.blocks)): 36 | tl_out = tl_model.blocks[layer_n].attn( 37 | query_input=input, 38 | key_input=input, 39 | value_input=input, 40 | past_kv_cache_entry=None, 41 | attention_mask=None, 42 | ) 43 | hf_out, _, _ = hf_model.transformer.h[layer_n].attn( 44 | hidden_states=input, output_attentions=True 45 | ) 46 | 47 | assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape) 48 | -------------------------------------------------------------------------------- /tests/integration/test_utils_tokens.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | import transformer_lens.utils as utils 8 | from transformer_lens import HookedTransformer 9 | 10 | MODEL = "solu-1l" 11 | 12 | model = HookedTransformer.from_pretrained(MODEL) 13 | 14 | 15 | @pytest.fixture 16 | def nested_list_1(): 17 | return [1] 18 | 19 | 20 | @pytest.fixture 21 | def nested_list_1x1(): 22 | return [[6]] 23 | 24 | 25 | @pytest.fixture 26 | def nested_list_1x3(): 27 | return [[1, 2, 3]] 28 | 29 | 30 | def test_to_str_tokens(nested_list_1, nested_list_1x1, nested_list_1x3): 31 | tensor_1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1)) 32 | assert isinstance(tensor_1_to_str_tokens, list) 33 | assert len(tensor_1_to_str_tokens) == 1 34 | assert isinstance(tensor_1_to_str_tokens[0], str) 35 | 36 | tensor_1x1_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x1)) 37 | assert isinstance(tensor_1x1_to_str_tokens, list) 38 | assert 
len(tensor_1x1_to_str_tokens) == 1 39 | assert isinstance(tensor_1x1_to_str_tokens[0], str) 40 | 41 | ndarray_1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1)) 42 | assert isinstance(ndarray_1_to_str_tokens, list) 43 | assert len(ndarray_1_to_str_tokens) == 1 44 | assert isinstance(ndarray_1_to_str_tokens[0], str) 45 | 46 | ndarray_1x1_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x1)) 47 | assert isinstance(ndarray_1x1_to_str_tokens, list) 48 | assert len(ndarray_1x1_to_str_tokens) == 1 49 | assert isinstance(ndarray_1x1_to_str_tokens[0], str) 50 | 51 | single_int_to_single_str_token = model.to_single_str_token(3) 52 | assert isinstance(single_int_to_single_str_token, str) 53 | 54 | squeezable_tensor_to_str_tokens = model.to_str_tokens(torch.tensor(nested_list_1x3)) 55 | assert isinstance(squeezable_tensor_to_str_tokens, list) 56 | assert len(squeezable_tensor_to_str_tokens) == 3 57 | assert isinstance(squeezable_tensor_to_str_tokens[0], str) 58 | assert isinstance(squeezable_tensor_to_str_tokens[1], str) 59 | assert isinstance(squeezable_tensor_to_str_tokens[2], str) 60 | 61 | squeezable_ndarray_to_str_tokens = model.to_str_tokens(np.array(nested_list_1x3)) 62 | assert isinstance(squeezable_ndarray_to_str_tokens, list) 63 | assert len(squeezable_ndarray_to_str_tokens) == 3 64 | assert isinstance(squeezable_ndarray_to_str_tokens[0], str) 65 | assert isinstance(squeezable_ndarray_to_str_tokens[1], str) 66 | assert isinstance(squeezable_ndarray_to_str_tokens[2], str) 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "prepend_space_to_answer, tokenized_prompt, tokenized_answer", 71 | [ 72 | ( 73 | True, 74 | [ 75 | "<|BOS|>", 76 | "The", 77 | " circumference", 78 | " is", 79 | " the", 80 | " perimeter", 81 | " of", 82 | " the", 83 | " circ", 84 | ], 85 | [" le", "."], 86 | ), 87 | ( 88 | False, 89 | [ 90 | "<|BOS|>", 91 | "The", 92 | " circumference", 93 | " is", 94 | " the", 95 | " perimeter", 96 | " of", 97 | " the", 98 | " circ", 99 | ], 100 | ["le", "."], 101 | ), 102 | ], 103 | ) 104 | @mock.patch("builtins.print") 105 | def test_test_prompt( 106 | mocked_print, 107 | prepend_space_to_answer, 108 | tokenized_prompt, 109 | tokenized_answer, 110 | ): 111 | """ 112 | Tests that utils.test_prompt produces the correct tokenization. In particular, when prepend_space_to_answer = False, the last token of the prompt 113 | and the first answer token should not be turned into one token (e.g. 'circ' and 'le' don't become 'circle'). See https://github.com/TransformerLensOrg/TransformerLens/issues/271 114 | for a more detailed explanation. 115 | """ 116 | utils.test_prompt( 117 | "The circumference is the perimeter of the circ", 118 | "le.", 119 | model, 120 | prepend_space_to_answer=prepend_space_to_answer, 121 | ) 122 | 123 | printed_tokenized_prompt = mock.call("Tokenized prompt:", tokenized_prompt) 124 | printed_tokenized_answer = mock.call("Tokenized answer:", tokenized_answer) 125 | 126 | assert mocked_print.mock_calls[0] == printed_tokenized_prompt 127 | assert mocked_print.mock_calls[1] == printed_tokenized_answer 128 | -------------------------------------------------------------------------------- /tests/manual_checks/manual_checks_type_annotations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | 4 | from transformer_lens import HookedTransformer 5 | 6 | MODEL = "gpt2" 7 | model = HookedTransformer.from_pretrained(MODEL) 8 | 9 | prompt = "Hello World!" 
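# The Float[...] annotation two lines below records the expected logits shape ("1 n_tokens d_vocab"); per the note at the end of this file, changing one of the dimensions is enough for the type checker to flag it.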
10 | tokens = model.to_tokens(prompt, prepend_bos=False) 11 | logits_tokens = model(tokens) 12 | logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False) 13 | 14 | # N.B. This file was used to check that the type annotations were working (they were); occasionally 15 | # changing one of the sizes confirmed that the type checker caught it. 16 | -------------------------------------------------------------------------------- /tests/manual_checks/manual_checks_typing.py: -------------------------------------------------------------------------------- 1 | # Adapted from [HookedTransformer_Demo.ipynb]. Useful for testing that all the typing mechanisms work 2 | # out. 3 | 4 | # %% 5 | 6 | import torch as t 7 | from jaxtyping import Float 8 | 9 | from transformer_lens import HookedTransformer, utils 10 | 11 | DEVICE = utils.get_device() 12 | MODEL = "gpt2" 13 | 14 | # %% 15 | model = HookedTransformer.from_pretrained(MODEL) 16 | model.to(DEVICE) 17 | 18 | # %% 19 | 20 | prompt = "Hello World!" 21 | tokens = model.to_tokens(prompt, prepend_bos=False) 22 | logits_tokens = model(tokens) 23 | logits_text: Float[t.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False) 24 | 25 | # %% 26 | 27 | logits_text.shape 28 | # %% 29 | -------------------------------------------------------------------------------- /tests/unit/components/mlps/test_can_be_used_as_mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | import torch 5 | 6 | from transformer_lens.components import LayerNorm, LayerNormPre 7 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 8 | from transformer_lens.hook_points import HookPoint 9 | from transformer_lens.utils import solu 10 | 11 | 12 | @pytest.fixture 13 | def cfg() -> Dict[str, Any]: 14 | return { 15 | "n_layers": 12, 16 | "n_ctx": 1024, 17 | "d_head": 64, 18 | "d_model": 128, 19 | "d_mlp": 256, 20 | "dtype": torch.float32, 21 | "act_fn": "solu_ln", 22 | "normalization_type": "LN", 23 | "load_in_4bit": False, 24 | } 25 | 26 | 27 | def test_initialization(cfg: Dict[str, Any]): 28 | CanBeUsedAsMLP(cfg) 29 | 30 | 31 | def test_initialization_fails_without_d_mlp(cfg: Dict[str, Any]): 32 | cfg["d_mlp"] = None 33 | with pytest.raises(ValueError): 34 | CanBeUsedAsMLP(cfg) 35 | 36 | 37 | def test_select_activation_function_selects_function(): 38 | cfg = { 39 | "n_layers": 12, 40 | "n_ctx": 1024, 41 | "d_head": 64, 42 | "d_model": 128, 43 | "d_mlp": 256, 44 | "dtype": torch.float32, 45 | "act_fn": "silu", 46 | "normalization_type": "LN", 47 | "load_in_4bit": False, 48 | } 49 | 50 | model = CanBeUsedAsMLP(cfg) 51 | model.select_activation_function() 52 | assert model.act_fn is not None 53 | 54 | 55 | def test_select_activation_function_with_layer_norm(): 56 | cfg = { 57 | "n_layers": 12, 58 | "n_ctx": 1024, 59 | "d_head": 64, 60 | "d_model": 128, 61 | "d_mlp": 256, 62 | "dtype": torch.float32, 63 | "act_fn": "solu_ln", 64 | "normalization_type": "LN", 65 | "load_in_4bit": False, 66 | } 67 | 68 | model = CanBeUsedAsMLP(cfg) 69 | model.select_activation_function() 70 | assert model.act_fn == solu 71 | assert isinstance(model.hook_mid, HookPoint) 72 | assert isinstance(model.ln, LayerNorm) 73 | 74 | 75 | def test_select_activation_function_with_layer_norm_pre(): 76 | cfg = { 77 | "n_layers": 12, 78 | "n_ctx": 1024, 79 | "d_head": 64, 80 | "d_model": 128, 81 | "d_mlp": 256, 82 | "dtype": torch.float32, 83 | "act_fn": "solu_ln", 84 |
"normalization_type": "LNPre", 85 | "load_in_4bit": False, 86 | } 87 | 88 | model = CanBeUsedAsMLP(cfg) 89 | model.select_activation_function() 90 | assert model.act_fn == solu 91 | assert isinstance(model.hook_mid, HookPoint) 92 | assert isinstance(model.ln, LayerNormPre) 93 | -------------------------------------------------------------------------------- /tests/unit/components/mlps/test_gated_mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | 7 | from transformer_lens.components import GatedMLP, LayerNorm 8 | from transformer_lens.utils import solu 9 | 10 | 11 | @pytest.fixture 12 | def cfg() -> Dict[str, Any]: 13 | return { 14 | "n_layers": 12, 15 | "n_ctx": 1024, 16 | "d_head": 64, 17 | "d_model": 128, 18 | "d_mlp": 256, 19 | "dtype": torch.float32, 20 | "act_fn": "solu_ln", 21 | "normalization_type": "LN", 22 | "load_in_4bit": False, 23 | } 24 | 25 | 26 | def test_initialization(cfg: Dict[str, Any]): 27 | model = GatedMLP(cfg) 28 | assert isinstance(model.W_in, nn.Parameter) 29 | assert isinstance(model.W_gate, nn.Parameter) 30 | assert isinstance(model.W_out, nn.Parameter) 31 | assert isinstance(model.b_in, nn.Parameter) 32 | assert isinstance(model.b_out, nn.Parameter) 33 | assert model.act_fn == solu 34 | assert isinstance(model.ln, LayerNorm) 35 | 36 | 37 | def test_forward(cfg: Dict[str, Any]): 38 | model = GatedMLP(cfg) 39 | x = torch.randn(2, 10, cfg["d_model"]) 40 | output = model(x) 41 | assert output.shape == (2, 10, cfg["d_model"]) 42 | -------------------------------------------------------------------------------- /tests/unit/components/mlps/test_mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | import torch 5 | 6 | from transformer_lens.components import LayerNorm 7 | from transformer_lens.components.mlps.mlp import MLP 8 | from transformer_lens.hook_points import HookPoint 9 | 10 | 11 | @pytest.fixture 12 | def cfg() -> Dict[str, Any]: 13 | return { 14 | "n_layers": 12, 15 | "n_ctx": 1024, 16 | "d_head": 64, 17 | "d_model": 128, 18 | "d_mlp": 256, 19 | "dtype": torch.float32, 20 | "act_fn": "solu_ln", 21 | "normalization_type": "LN", 22 | "load_in_4bit": False, 23 | } 24 | 25 | 26 | def test_initialization(cfg: Dict[str, Any]): 27 | MLP(cfg) 28 | 29 | 30 | def test_forward_without_layer_norm(cfg: Dict[str, Any]): 31 | cfg["act_fn"] = "solu" 32 | 33 | model = MLP(cfg) 34 | 35 | input = torch.full((1, 1, 128), 0.085) 36 | 37 | result = model(input) 38 | 39 | assert result.shape == (1, 1, 128) 40 | 41 | 42 | def test_forward_with_layer_norm(cfg: Dict[str, Any]): 43 | model = MLP(cfg) 44 | assert isinstance(model.hook_mid, HookPoint) 45 | assert isinstance(model.ln, LayerNorm) 46 | 47 | input = torch.full((1, 1, 128), 0.85) 48 | result = model(input) 49 | assert result.shape == (1, 1, 128) 50 | -------------------------------------------------------------------------------- /tests/unit/components/mlps/test_moe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_lens.components import MoE 4 | 5 | 6 | def test_forward(): 7 | cfg = { 8 | "d_model": 32, 9 | "d_mlp": 14336, 10 | "d_head": 4, 11 | "num_experts": 32, 12 | "n_layers": 16, 13 | "n_ctx": 2048, 14 | "experts_per_token": 4, 15 | "gated_mlp": True, 16 | "act_fn": "silu", 17 | } 18 | moe = MoE(cfg) 19 | 20 | x = 
torch.rand((1, 4, 32)) 21 | moe(x) 22 | -------------------------------------------------------------------------------- /tests/unit/components/test_abstract_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_lens.components import AbstractAttention 4 | 5 | 6 | def test_create_alibi_slope(): 7 | n_ctx = 100 8 | 9 | # Expected result computed non-vectorized way 10 | expected = torch.zeros((n_ctx, n_ctx)) 11 | for row in range(n_ctx): 12 | for col in range(n_ctx): 13 | expected[row, col] = float(min(col - row, 0)) 14 | 15 | # Check against the method's vectorized version 16 | result = AbstractAttention.create_alibi_slope(n_ctx) 17 | assert torch.allclose(expected, result) 18 | 19 | 20 | def test_create_alibi_bias(): 21 | n_heads = 2 22 | n_ctx = 4 23 | 24 | result = AbstractAttention.create_alibi_bias(n_heads, n_ctx, torch.device("cpu")) 25 | 26 | for matrix in result: 27 | n_row, n_col = matrix.size() 28 | slope = -matrix[1, 0] 29 | # Check if upper triangle is all zeros 30 | assert torch.equal(torch.triu(matrix), torch.zeros_like(matrix)) 31 | 32 | ref_lower_triangle = torch.zeros_like(matrix) 33 | for i in range(1, n_row): 34 | for j in range(i): 35 | ref_lower_triangle[i, j] = -slope * (i - j) 36 | 37 | # Check if the lower triangle is decreasing by a constant slope (towards the bottom left corner). 38 | assert torch.equal( 39 | torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1) 40 | ) 41 | -------------------------------------------------------------------------------- /tests/unit/components/test_attention.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import pytest 3 | import torch 4 | import torch.nn as nn 5 | from transformers.utils import is_bitsandbytes_available 6 | 7 | from transformer_lens.components import Attention 8 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 9 | from transformer_lens.utilities.attention import complex_attn_linear 10 | 11 | if is_bitsandbytes_available(): 12 | from bitsandbytes.nn.modules import Params4bit 13 | 14 | 15 | def test_attention_hooked_transformer_config(): 16 | cfg = HookedTransformerConfig( 17 | n_layers=12, 18 | d_model=512, 19 | n_ctx=1024, 20 | d_head=64, 21 | n_heads=8, 22 | load_in_4bit=False, 23 | dtype=torch.float32, 24 | act_fn="relu", 25 | ) 26 | attn = Attention(cfg) 27 | assert attn.cfg == cfg 28 | assert attn.cfg.n_layers == 12 29 | assert attn.cfg.d_model == 512 30 | assert attn.cfg.n_ctx == 1024 31 | assert attn.cfg.d_head == 64 32 | assert attn.cfg.n_heads == 8 33 | assert attn.cfg.load_in_4bit == False 34 | assert attn.cfg.dtype == torch.float32 35 | assert attn.cfg.act_fn == "relu" 36 | 37 | assert isinstance(attn.W_K, nn.Parameter) 38 | assert isinstance(attn.W_V, nn.Parameter) 39 | assert attn.W_K.shape == (cfg.n_heads, cfg.d_model, cfg.d_head) 40 | assert attn.W_V.shape == (cfg.n_heads, cfg.d_model, cfg.d_head) 41 | 42 | assert attn.b_K.shape == (cfg.n_heads, cfg.d_head) 43 | assert attn.b_V.shape == (cfg.n_heads, cfg.d_head) 44 | assert torch.all(attn.b_K == 0) 45 | assert torch.all(attn.b_V == 0) 46 | 47 | 48 | @pytest.mark.skipif(not is_bitsandbytes_available(), reason="bitsandbytes is not available") 49 | def test_attention_load_in_4bit(): 50 | cfg = HookedTransformerConfig( 51 | n_layers=12, 52 | d_model=512, 53 | n_ctx=1024, 54 | d_head=64, 55 | n_heads=8, 56 | load_in_4bit=True, 57 | dtype=torch.float32, 58 | 
act_fn="relu", 59 | ) 60 | attn = Attention(cfg) 61 | assert attn.cfg == cfg 62 | assert attn.cfg.n_layers == 12 63 | assert attn.cfg.d_model == 512 64 | assert attn.cfg.n_ctx == 1024 65 | assert attn.cfg.d_head == 64 66 | assert attn.cfg.n_heads == 8 67 | assert attn.cfg.load_in_4bit == True  # the config above sets load_in_4bit=True 68 | assert attn.cfg.dtype == torch.float32 69 | assert attn.cfg.act_fn == "relu" 70 | 71 | assert isinstance(attn.W_K, Params4bit) 72 | assert isinstance(attn.W_V, Params4bit) 73 | nq = int((cfg.d_model * cfg.d_model) / 2) 74 | assert attn.W_K.data.shape == (nq, 1) 75 | assert attn.W_V.data.shape == (nq, 1) 76 | 77 | assert attn.b_K.shape == (cfg.n_heads, cfg.d_head) 78 | assert attn.b_V.shape == (cfg.n_heads, cfg.d_head) 79 | assert torch.all(attn.b_K == 0) 80 | assert torch.all(attn.b_V == 0) 81 | 82 | 83 | def test_attention_config_dict(): 84 | cfg = { 85 | "n_layers": 12, 86 | "d_model": 512, 87 | "n_ctx": 1024, 88 | "d_head": 64, 89 | "n_heads": 8, 90 | "load_in_4bit": False, 91 | "dtype": torch.float32, 92 | "act_fn": "relu", 93 | } 94 | attn = Attention(cfg) 95 | assert attn.cfg.n_layers == 12 96 | assert attn.cfg.d_model == 512 97 | assert attn.cfg.n_ctx == 1024 98 | assert attn.cfg.d_head == 64 99 | assert attn.cfg.n_heads == 8 100 | assert attn.cfg.load_in_4bit == False 101 | assert attn.cfg.dtype == torch.float32 102 | assert attn.cfg.act_fn == "relu" 103 | 104 | 105 | def test_remove_einsum_from_complex_attn_linear(): 106 | batch = 64 107 | pos = 128 108 | head_index = 8 109 | d_model = 512 110 | d_head = 64 111 | input = torch.randn(batch, pos, head_index, d_model) 112 | w = torch.randn(head_index, d_model, d_head) 113 | b = torch.randn(head_index, d_head) 114 | result_new = complex_attn_linear(input, w, b) 115 | 116 | # Check if new implementation without einsum produces correct shape 117 | assert result_new.shape == (batch, pos, head_index, d_head) 118 | 119 | # Old implementation used einsum 120 | result_old = ( 121 | einops.einsum( 122 | input, 123 | w, 124 | "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head", 125 | ) 126 | + b 127 | ) 128 | 129 | # Check if the results are the same 130 | assert torch.allclose(result_new, result_old, atol=1e-4) 131 | -------------------------------------------------------------------------------- /tests/unit/factored_matrix/test_constructor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from transformer_lens import FactoredMatrix 5 | 6 | 7 | def test_factored_matrix(): 8 | A = torch.randn(5, 3) 9 | B = torch.randn(3, 7) 10 | f = FactoredMatrix(A, B) 11 | 12 | assert torch.equal(f.A, A) 13 | assert torch.equal(f.B, B) 14 | 15 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7) 16 | assert not f.has_leading_dims 17 | assert f.shape == (5, 7) 18 | 19 | 20 | def test_factored_matrix_b_leading_dims(): 21 | A = torch.ones((5, 3)) 22 | B = torch.ones((2, 4, 3, 7)) 23 | f = FactoredMatrix(A, B) 24 | 25 | assert f.A.shape == (2, 4, 5, 3) 26 | assert torch.equal(f.B, B) 27 | 28 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7) 29 | assert f.has_leading_dims 30 | assert f.shape == (2, 4, 5, 7) 31 | 32 | 33 | def test_factored_matrix_a_b_leading_dims(): 34 | A = torch.ones((4, 5, 3)) 35 | B = torch.ones((2, 4, 3, 7)) 36 | f = FactoredMatrix(A, B) 37 | 38 | assert f.A.shape == (2, 4, 5, 3) 39 | assert torch.equal(f.B, B) 40 | 41 | assert (f.ldim, f.mdim, f.rdim) == (5, 3, 7) 42 | assert f.has_leading_dims 43 | assert f.shape == (2, 4, 5, 7) 44 | 45 | 46
| def test_factored_matrix_broadcast_mismatch(): 47 | A = torch.ones((9, 5, 3)) 48 | B = torch.ones((2, 4, 3, 7)) 49 | 50 | with pytest.raises(RuntimeError) as e: 51 | FactoredMatrix(A, B) 52 | 53 | assert "Shape mismatch" in str(e.value) 54 | 55 | 56 | @pytest.mark.skip( 57 | """ 58 | AssertionError will not be reached due to jaxtyping argument consistency 59 | checks, which are enabled at test time but not run time. 60 | 61 | See https://github.com/TransformerLensOrg/TransformerLens/issues/190 62 | """ 63 | ) 64 | def test_factored_matrix_inner_mismatch(): 65 | A = torch.ones((2, 3, 4)) 66 | B = torch.ones((2, 3, 5)) 67 | with pytest.raises(AssertionError) as e: 68 | FactoredMatrix(A, B) 69 | 70 | assert "inner dimension" in str(e.value) 71 | -------------------------------------------------------------------------------- /tests/unit/factored_matrix/test_get_item.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.testing import assert_close 4 | 5 | from transformer_lens import FactoredMatrix 6 | 7 | 8 | @pytest.fixture 9 | def sample_factored_matrix(): 10 | A = torch.rand(2, 2, 2, 2, 2) 11 | B = torch.rand(2, 2, 2, 2, 2) 12 | return FactoredMatrix(A, B) 13 | 14 | 15 | def test_getitem_int(sample_factored_matrix): 16 | result = sample_factored_matrix[0] 17 | assert_close(result.A, sample_factored_matrix.A[0]) 18 | assert_close(result.B, sample_factored_matrix.B[0]) 19 | 20 | 21 | def test_getitem_tuple(sample_factored_matrix): 22 | result = sample_factored_matrix[(0, 1)] 23 | assert_close(result.A, sample_factored_matrix.A[0, 1]) 24 | assert_close(result.B, sample_factored_matrix.B[0, 1]) 25 | 26 | 27 | def test_getitem_slice(sample_factored_matrix): 28 | result = sample_factored_matrix[:, 1] 29 | assert_close(result.A, sample_factored_matrix.A[:, 1]) 30 | assert_close(result.B, sample_factored_matrix.B[:, 1]) 31 | 32 | 33 | def test_getitem_error(sample_factored_matrix): 34 | with pytest.raises(IndexError): 35 | _ = sample_factored_matrix[(0, 1, 2)] 36 | 37 | 38 | def test_getitem_multiple_slices(sample_factored_matrix): 39 | result = sample_factored_matrix[:, :, 1] 40 | assert_close(result.A, sample_factored_matrix.A[:, :, 1]) 41 | assert_close(result.B, sample_factored_matrix.B[:, :, 1]) 42 | 43 | 44 | def test_index_dimension_get_line(sample_factored_matrix): 45 | result = sample_factored_matrix[0, 0, 0, 1] 46 | assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 1]) 47 | 48 | 49 | def test_index_dimension_get_element(sample_factored_matrix): 50 | result = sample_factored_matrix[0, 0, 0, 0, 1] 51 | assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 0, 1]) 52 | 53 | 54 | def test_index_dimension_too_big(sample_factored_matrix): 55 | with pytest.raises(Exception): 56 | _ = sample_factored_matrix[1, 1, 1, 1, 1, 1] 57 | 58 | 59 | def test_getitem_sequences(sample_factored_matrix): 60 | A_idx = [0, 1] 61 | B_idx = [0] 62 | result = sample_factored_matrix[:, :, :, A_idx, B_idx] 63 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :]) 64 | assert_close(result.B, sample_factored_matrix.B[:, :, :, :, B_idx]) 65 | 66 | 67 | def test_getitem_sequences_and_ints(sample_factored_matrix): 68 | A_idx = [0, 1] 69 | B_idx = 0 70 | result = sample_factored_matrix[:, :, :, A_idx, B_idx] 71 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :]) 72 | # we squeeze result.B, because indexing by ints is designed not to delete dimensions 73 | 
assert_close(result.B.squeeze(-1), sample_factored_matrix.B[:, :, :, :, B_idx]) 74 | 75 | 76 | def test_getitem_tensors(sample_factored_matrix): 77 | A_idx = torch.tensor([0, 1]) 78 | B_idx = torch.tensor([0]) 79 | result = sample_factored_matrix[:, :, :, A_idx, B_idx] 80 | assert_close(result.A, sample_factored_matrix.A[:, :, :, A_idx, :]) 81 | assert_close(result.B, sample_factored_matrix.B[:, :, :, :, B_idx]) 82 | -------------------------------------------------------------------------------- /tests/unit/factored_matrix/test_multiply_by_scalar.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | import torch 5 | from torch.testing import assert_close 6 | 7 | from transformer_lens import FactoredMatrix 8 | 9 | 10 | # This test function is parametrized with different types of scalars, including non-scalar tensors and arrays, to check that the correct errors are raised. 11 | # Considers cases with and without leading dimensions as well as left and right multiplication. 12 | @pytest.mark.parametrize( 13 | "scalar, error_expected", 14 | [ 15 | # Test cases with different types of scalar values. 16 | (torch.rand(1), None), # 1-element Tensor. No error expected. 17 | (random.random(), None), # float. No error expected. 18 | (random.randint(-100, 100), None), # int. No error expected. 19 | # Test cases with non-scalar values that are expected to raise errors. 20 | ( 21 | torch.rand(2, 2), 22 | AssertionError, 23 | ), # Non-scalar Tensor. AssertionError expected. 24 | (torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected. 25 | ], 26 | ) 27 | @pytest.mark.parametrize("leading_dim", [False, True]) 28 | @pytest.mark.parametrize("multiply_from_left", [False, True]) 29 | def test_multiply(scalar, leading_dim, multiply_from_left, error_expected): 30 | # Prepare a FactoredMatrix, with or without leading dimensions 31 | if leading_dim: 32 | a = torch.rand(6, 2, 3) 33 | b = torch.rand(6, 3, 4) 34 | else: 35 | a = torch.rand(2, 3) 36 | b = torch.rand(3, 4) 37 | 38 | fm = FactoredMatrix(a, b) 39 | 40 | if error_expected: 41 | # If an error is expected, check that the correct exception is raised. 42 | with pytest.raises(error_expected): 43 | if multiply_from_left: 44 | _ = fm * scalar 45 | else: 46 | _ = scalar * fm 47 | else: 48 | # If no error is expected, check that the multiplication results in the correct value. 49 | # Use FactoredMatrix.AB to calculate the product of the two factor matrices before comparing with the expected value. 50 | if multiply_from_left: 51 | assert_close((fm * scalar).AB, (a @ b) * scalar) 52 | else: 53 | assert_close((scalar * fm).AB, scalar * (a @ b)) 54 | # This next test is implementation dependant and can be broken and removed at any time! 55 | # It checks that the multiplication is performed on the A factor matrix. 
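# (i.e. the scalar is expected to be folded into the A factor, leaving B untouched; an implementation detail rather than part of the public contract)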
56 | if multiply_from_left: 57 | assert_close((fm * scalar).A, a * scalar) 58 | else: 59 | assert_close((scalar * fm).A, scalar * a) 60 | -------------------------------------------------------------------------------- /tests/unit/factored_matrix/test_multiply_by_vector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing import assert_close 3 | 4 | from transformer_lens import FactoredMatrix 5 | 6 | 7 | def test_left_matmul_by_vector_left(): 8 | a = torch.rand(2, 3) 9 | b = torch.rand(3, 4) 10 | 11 | fm = FactoredMatrix(a, b) 12 | vector = torch.rand(4) 13 | 14 | assert_close(fm @ vector, (a @ b) @ vector) 15 | 16 | 17 | def test_left_matmul_by_vector_leading_dim(): 18 | a = torch.rand(6, 2, 3) 19 | b = torch.rand(6, 3, 4) 20 | 21 | fm = FactoredMatrix(a, b) 22 | vector = torch.rand(4) 23 | 24 | assert_close(fm @ vector, (a @ b) @ vector) 25 | 26 | 27 | def test_right_matmul_by_vector(): 28 | a = torch.rand(2, 3) 29 | b = torch.rand(3, 4) 30 | 31 | fm = FactoredMatrix(a, b) 32 | vector = torch.rand(2) 33 | 34 | assert_close(vector @ fm, vector @ (a @ b)) 35 | 36 | 37 | def test_right_matmul_by_vector_leading_dim(): 38 | a = torch.rand(6, 2, 3) 39 | b = torch.rand(6, 3, 4) 40 | 41 | fm = FactoredMatrix(a, b) 42 | vector = torch.rand(2) 43 | 44 | assert_close(vector @ fm, vector @ (a @ b)) 45 | -------------------------------------------------------------------------------- /tests/unit/factories/test_activation_function_factory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from transformer_lens.factories.activation_function_factory import ( 5 | ActivationFunctionFactory, 6 | ) 7 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 8 | from transformer_lens.utilities.activation_functions import SUPPORTED_ACTIVATIONS 9 | 10 | 11 | @pytest.mark.parametrize("act_function", SUPPORTED_ACTIVATIONS.keys()) 12 | def test_pick_activation_function_runs(act_function): 13 | config = HookedTransformerConfig.unwrap( 14 | {"n_layers": 12, "n_ctx": 1024, "d_head": 64, "d_model": 128, "act_fn": act_function} 15 | ) 16 | function = ActivationFunctionFactory.pick_activation_function(config) 17 | assert function is not None 18 | dummy_data = torch.zeros((1, 4, 32)) 19 | result = function(dummy_data) 20 | assert isinstance(result, torch.Tensor) 21 | -------------------------------------------------------------------------------- /tests/unit/factories/test_mlp_factory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers.utils import is_bitsandbytes_available 3 | 4 | from transformer_lens.components.mlps.gated_mlp import GatedMLP 5 | from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit 6 | from transformer_lens.components.mlps.mlp import MLP 7 | from transformer_lens.factories.mlp_factory import MLPFactory 8 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 9 | 10 | 11 | def test_create_mlp_basic(): 12 | config = HookedTransformerConfig.unwrap( 13 | { 14 | "n_layers": 12, 15 | "n_ctx": 1024, 16 | "d_head": 64, 17 | "d_model": 128, 18 | "act_fn": "solu", 19 | } 20 | ) 21 | mlp = MLPFactory.create_mlp(config) 22 | assert isinstance(mlp, MLP) 23 | 24 | 25 | def test_create_mlp_gated(): 26 | config = HookedTransformerConfig.unwrap( 27 | { 28 | "n_layers": 12, 29 | "n_ctx": 1024, 30 | "d_head": 64, 31 | "d_model": 128, 32 | 
"act_fn": "solu", 33 | "gated_mlp": True, 34 | } 35 | ) 36 | mlp = MLPFactory.create_mlp(config) 37 | assert isinstance(mlp, GatedMLP) 38 | 39 | 40 | @pytest.mark.skipif( 41 | not is_bitsandbytes_available(), 42 | reason="4 bit not available on current architecture", 43 | ) 44 | def test_create_mlp_gated_4bit(): 45 | config = HookedTransformerConfig.unwrap( 46 | { 47 | "n_layers": 12, 48 | "n_ctx": 1024, 49 | "d_head": 64, 50 | "d_model": 128, 51 | "act_fn": "solu", 52 | "gated_mlp": True, 53 | "load_in_4bit": True, 54 | } 55 | ) 56 | mlp = MLPFactory.create_mlp(config) 57 | assert isinstance(mlp, GatedMLP4Bit) 58 | 59 | 60 | def test_create_moe(): 61 | if is_bitsandbytes_available(): 62 | config = HookedTransformerConfig.unwrap( 63 | { 64 | "n_layers": 12, 65 | "n_ctx": 1024, 66 | "d_head": 64, 67 | "d_model": 128, 68 | "act_fn": "solu", 69 | "gated_mlp": True, 70 | "num_experts": 32, 71 | } 72 | ) 73 | mlp = MLPFactory.create_mlp(config) 74 | assert isinstance(mlp, GatedMLP4Bit) 75 | -------------------------------------------------------------------------------- /tests/unit/pretrained_weight_conversions/test_neo.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import torch 4 | 5 | from transformer_lens import HookedTransformer 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | from transformer_lens.pretrained.weight_conversions.neo import convert_neo_weights 8 | 9 | 10 | def get_default_config(): 11 | return HookedTransformerConfig( 12 | d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True 13 | ) 14 | 15 | 16 | def test_convert_neo_weights_exposed(): 17 | cfg = get_default_config() 18 | 19 | class MockNeo: 20 | def __init__(self): 21 | self.transformer = HookedTransformer(cfg) 22 | self.transformer.wte = torch.nn.Embedding(cfg.d_vocab, cfg.d_model) 23 | self.transformer.wpe = torch.nn.Embedding(cfg.n_ctx, cfg.d_model) 24 | self.transformer.final_norm = torch.nn.LayerNorm(cfg.d_model) 25 | self.transformer.h = [mock.Mock() for _ in range(cfg.n_layers)] 26 | self.lm_head = torch.nn.Linear(cfg.d_model, cfg.d_vocab) 27 | 28 | for layer in self.transformer.h: 29 | layer.ln_1 = torch.nn.LayerNorm(cfg.d_model) 30 | layer.ln_2 = torch.nn.LayerNorm(cfg.d_model) 31 | layer.attn = mock.Mock() 32 | layer.attn.attention = mock.Mock() 33 | layer.attn.attention.q_proj = torch.nn.Linear(cfg.d_model, cfg.d_model) 34 | layer.attn.attention.k_proj = torch.nn.Linear(cfg.d_model, cfg.d_model) 35 | layer.attn.attention.v_proj = torch.nn.Linear(cfg.d_model, cfg.d_model) 36 | layer.attn.attention.out_proj = torch.nn.Linear(cfg.d_model, cfg.d_model) 37 | layer.mlp = mock.Mock() 38 | layer.mlp.c_fc = torch.nn.Linear(cfg.d_model, cfg.d_model) 39 | layer.mlp.c_proj = torch.nn.Linear(cfg.d_model, cfg.d_model) 40 | 41 | self.transformer.ln_f = torch.nn.LayerNorm(cfg.d_model) 42 | 43 | neo = MockNeo() 44 | 45 | try: 46 | convert_neo_weights(neo, cfg) 47 | function_works = True 48 | except Exception as e: 49 | function_works = False 50 | print(f"The convert_neo_weights function raised an error: {e}") 51 | 52 | assert function_works 53 | -------------------------------------------------------------------------------- /tests/unit/test_hook_points.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | from transformer_lens.hook_points import HookPoint 4 | 5 | 6 | def setup_hook_point_and_hook(): 7 | hook_point = 
HookPoint() 8 | 9 | def hook(activation, hook): 10 | return activation 11 | 12 | return hook_point, hook 13 | 14 | 15 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) 16 | def test_add_hook_forward(mock_handle): 17 | mock_handle.return_value.id = 0 18 | hook_point, hook = setup_hook_point_and_hook() 19 | hook_point.add_hook(hook, dir="fwd") 20 | assert len(hook_point.fwd_hooks) == 1 21 | 22 | 23 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) 24 | def test_add_hook_backward(mock_handle): 25 | mock_handle.return_value.id = 0 26 | hook_point, hook = setup_hook_point_and_hook() 27 | hook_point.add_hook(hook, dir="bwd") 28 | assert len(hook_point.bwd_hooks) == 1 29 | 30 | 31 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) 32 | def test_add_hook_permanent(mock_handle): 33 | mock_handle.return_value.id = 0 34 | hook_point, hook = setup_hook_point_and_hook() 35 | hook_point.add_hook(hook, dir="fwd", is_permanent=True) 36 | assert hook_point.fwd_hooks[0].is_permanent 37 | 38 | 39 | @mock.patch("torch.utils.hooks.RemovableHandle", autospec=True) 40 | def test_add_hook_with_level(mock_handle): 41 | mock_handle.return_value.id = 0 42 | hook_point, hook = setup_hook_point_and_hook() 43 | hook_point.add_hook(hook, dir="fwd", level=5) 44 | assert hook_point.fwd_hooks[0].context_level == 5 45 | 46 | 47 | @mock.patch("torch.utils.hooks.RemovableHandle") 48 | def test_add_hook_prepend(mock_handle): 49 | mock_handle.id = 0 50 | mock_handle.next_id = 1 51 | 52 | hook_point, _ = setup_hook_point_and_hook() 53 | 54 | def hook1(activation, hook): 55 | return activation 56 | 57 | def hook2(activation, hook): 58 | return activation 59 | 60 | hook_point.add_hook(hook1, dir="fwd") 61 | hook_point.add_hook(hook2, dir="fwd", prepend=True) 62 | 63 | assert len(hook_point.fwd_hooks) == 2 64 | assert hook_point.fwd_hooks[0].hook.id == 2 65 | assert hook_point.fwd_hooks[1].hook.id == 1 66 | -------------------------------------------------------------------------------- /tests/unit/test_hooked_root_module.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | from transformer_lens.hook_points import HookedRootModule 4 | 5 | MODEL_NAME = "solu-2l" 6 | 7 | 8 | def test_enable_hook_with_name(): 9 | model = HookedRootModule() 10 | model.mod_dict = {"linear": Mock()} 11 | model.context_level = 5 12 | 13 | hook = lambda x: False 14 | dir = "fwd" 15 | 16 | model._enable_hook_with_name("linear", hook=hook, dir=dir) 17 | 18 | model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5) 19 | 20 | 21 | def test_enable_hooks_for_points(): 22 | model = HookedRootModule() 23 | model.mod_dict = {} 24 | model.context_level = 5 25 | 26 | hook_points = { 27 | "linear": Mock(), 28 | "attn": Mock(), 29 | } 30 | 31 | enabled = lambda x: x == "attn" 32 | 33 | hook = lambda x: False 34 | dir = "bwd" 35 | 36 | print(hook_points.items()) 37 | model._enable_hooks_for_points( 38 | hook_points=hook_points.items(), enabled=enabled, hook=hook, dir=dir 39 | ) 40 | 41 | hook_points["attn"].add_hook.assert_called_with(hook, dir="bwd", level=5) 42 | hook_points["linear"].add_hook.assert_not_called() 43 | 44 | 45 | def test_enable_hook_with_string_param(): 46 | model = HookedRootModule() 47 | model.mod_dict = {"linear": Mock()} 48 | model.context_level = 5 49 | 50 | hook = lambda x: False 51 | dir = "fwd" 52 | 53 | model._enable_hook("linear", hook=hook, dir=dir) 54 | 55 | 
model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5) 56 | 57 | 58 | def test_enable_hook_with_callable_param(): 59 | model = HookedRootModule() 60 | model.mod_dict = {"linear": Mock()} 61 | model.hook_dict = { 62 | "linear": Mock(), 63 | "attn": Mock(), 64 | } 65 | model.context_level = 5 66 | 67 | enabled = lambda x: x == "attn" 68 | 69 | hook = lambda x: False 70 | dir = "fwd" 71 | 72 | model._enable_hook(enabled, hook=hook, dir=dir) 73 | 74 | model.mod_dict["linear"].add_hook.assert_not_called() 75 | model.hook_dict["attn"].add_hook.assert_called_with(hook, dir="fwd", level=5) 76 | model.hook_dict["linear"].add_hook.assert_not_called() 77 | -------------------------------------------------------------------------------- /tests/unit/test_hooked_transformer_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests that config passed around TransformerLens can be unwrapped into an actual configuration object 3 | """ 4 | 5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 6 | 7 | 8 | def test_hooked_transformer_config_object(): 9 | hooked_transformer_config = HookedTransformerConfig( 10 | n_layers=2, d_vocab=100, d_model=6, n_ctx=5, d_head=2, attn_only=True 11 | ) 12 | result = HookedTransformerConfig.unwrap(hooked_transformer_config) 13 | # Assert that the same object was returned 14 | assert result is hooked_transformer_config 15 | 16 | 17 | def test_hooked_transformer_config_dict(): 18 | hooked_transformer_config_dict = { 19 | "n_layers": 2, 20 | "d_vocab": 100, 21 | "d_model": 6, 22 | "n_ctx": 5, 23 | "d_head": 2, 24 | "attn_only": True, 25 | } 26 | result = HookedTransformerConfig.unwrap(hooked_transformer_config_dict) 27 | # Assert that the new returned value has been transformed into a config object 28 | assert isinstance(result, HookedTransformerConfig) 29 | 30 | 31 | def test_is_layer_norm_activation_passes(): 32 | hooked_transformer_config_dict = { 33 | "n_layers": 2, 34 | "d_vocab": 100, 35 | "d_model": 6, 36 | "n_ctx": 5, 37 | "d_head": 2, 38 | "attn_only": True, 39 | "act_fn": "solu_ln", 40 | } 41 | config = HookedTransformerConfig.unwrap(hooked_transformer_config_dict) 42 | assert config.is_layer_norm_activation() 43 | 44 | 45 | def test_is_layer_norm_activation_fails(): 46 | hooked_transformer_config_dict = { 47 | "n_layers": 2, 48 | "d_vocab": 100, 49 | "d_model": 6, 50 | "n_ctx": 5, 51 | "d_head": 2, 52 | "attn_only": True, 53 | "act_fn": "relu", 54 | } 55 | config = HookedTransformerConfig.unwrap(hooked_transformer_config_dict) 56 | assert not config.is_layer_norm_activation() 57 | -------------------------------------------------------------------------------- /tests/unit/test_loading_from_pretrained_utilities.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import pytest 4 | 5 | from transformer_lens import HookedTransformer 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | from transformer_lens.loading_from_pretrained import fill_missing_keys 8 | 9 | 10 | def get_default_config(): 11 | return HookedTransformerConfig( 12 | d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True 13 | ) 14 | 15 | 16 | # Successes 17 | 18 | 19 | @mock.patch("logging.warning") 20 | def test_fill_missing_keys(mock_warning): 21 | cfg = get_default_config() 22 | model = HookedTransformer(cfg) 23 | default_state_dict = model.state_dict() 24 | 25 | 
incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "W_" not in k} 26 | 27 | filled_state_dict = fill_missing_keys(model, incomplete_state_dict) 28 | 29 | assert set(filled_state_dict.keys()) == set(default_state_dict.keys()) 30 | 31 | # Check that warnings were issued for missing weight matrices 32 | for key in default_state_dict: 33 | if "W_" in key and key not in incomplete_state_dict: 34 | mock_warning.assert_any_call( 35 | f"Missing key for a weight matrix in pretrained, filled in with an empty tensor: {key}" 36 | ) 37 | 38 | 39 | def test_fill_missing_keys_with_hf_model_keys(): 40 | cfg = get_default_config() 41 | model = HookedTransformer(cfg) 42 | default_state_dict = model.state_dict() 43 | 44 | incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "hf_model" not in k} 45 | 46 | filled_state_dict = fill_missing_keys(model, incomplete_state_dict) 47 | 48 | expected_keys = set(default_state_dict.keys()) - { 49 | k for k in default_state_dict.keys() if "hf_model" in k 50 | } 51 | assert set(filled_state_dict.keys()) == expected_keys 52 | 53 | 54 | def test_fill_missing_keys_no_missing_keys(): 55 | cfg = get_default_config() 56 | model = HookedTransformer(cfg) 57 | default_state_dict = model.state_dict() 58 | 59 | filled_state_dict = fill_missing_keys(model, default_state_dict) 60 | 61 | assert filled_state_dict == default_state_dict 62 | 63 | 64 | # Failures 65 | 66 | 67 | def test_fill_missing_keys_raises_error_on_invalid_model(): 68 | invalid_model = None 69 | default_state_dict = {} 70 | 71 | with pytest.raises(AttributeError): 72 | fill_missing_keys(invalid_model, default_state_dict) 73 | -------------------------------------------------------------------------------- /tests/unit/test_make_docs.py: -------------------------------------------------------------------------------- 1 | """Make Docs Tests.""" 2 | 3 | import pytest 4 | 5 | from docs.make_docs import get_config, get_property 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | 8 | 9 | def test_get_config(): 10 | """Test get config with attn-only-1l model.""" 11 | config: HookedTransformerConfig = get_config("attn-only-1l") 12 | assert config.attn_only is True 13 | 14 | 15 | def test_get_property(): 16 | """Test get property with attn-only-1l model.""" 17 | act_fn = get_property("act_fn", "attn-only-1l") 18 | assert act_fn == "attn_only" 19 | 20 | n_params = get_property("n_params", "attn-only-1l") 21 | assert n_params == "1.0M" 22 | 23 | n_layers = get_property("n_layers", "attn-only-1l") 24 | assert n_layers == 1 25 | 26 | d_model = get_property("d_model", "attn-only-1l") 27 | assert d_model == 512 28 | 29 | n_heads = get_property("n_heads", "attn-only-1l") 30 | assert n_heads == 8 31 | 32 | n_ctx = get_property("n_ctx", "attn-only-1l") 33 | assert n_ctx == 1024 34 | 35 | d_vocab = get_property("d_vocab", "attn-only-1l") 36 | assert d_vocab == 48262 37 | 38 | d_head = get_property("d_head", "attn-only-1l") 39 | assert d_head == 64 40 | 41 | d_mlp = get_property("d_mlp", "attn-only-1l") 42 | assert d_mlp == 2048 43 | 44 | n_key_value_heads = get_property("n_key_value_heads", "attn-only-1l") 45 | assert n_key_value_heads is None 46 | 47 | # Test an unknown property 48 | with pytest.raises(KeyError): 49 | get_property("unknown_property", "attn-only-1l") 50 | -------------------------------------------------------------------------------- /tests/unit/test_split_qkv.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | from transformer_lens import HookedTransformer 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def test_split_qkv_normal_attn_correct(): 8 | """Verifies that the split_qkv_input flag does not change the output for models with normal attention.""" 9 | d_model = 128 10 | d_head = 8 11 | n_heads = 16 12 | n_ctx = 128 13 | n_layers = 1 14 | d_vocab = 10 15 | 16 | cfg = HookedTransformerConfig( 17 | d_model=d_model, 18 | d_head=d_head, 19 | n_heads=n_heads, 20 | n_ctx=n_ctx, 21 | n_layers=n_layers, 22 | attn_only=True, 23 | d_vocab=d_vocab, 24 | ) 25 | 26 | model = HookedTransformer(cfg) 27 | assert model.cfg.use_split_qkv_input is False 28 | 29 | x = torch.arange(1, 9).unsqueeze(0) 30 | normal_output = model(x) 31 | 32 | model.set_use_split_qkv_input(True) 33 | assert model.cfg.use_split_qkv_input is True 34 | 35 | split_output = model(x) 36 | 37 | assert torch.allclose(normal_output, split_output, atol=1e-6) 38 | 39 | 40 | def test_split_qkv_grouped_query_attn_correct(): 41 | """Verifies that the split_qkv_input flag does not change the output for models with grouped query attention.""" 42 | 43 | d_model = 128 44 | d_head = 8 45 | n_heads = 16 46 | n_ctx = 128 47 | n_key_value_heads = 2 48 | n_layers = 1 49 | d_vocab = 10 50 | 51 | cfg = HookedTransformerConfig( 52 | d_model=d_model, 53 | d_head=d_head, 54 | n_heads=n_heads, 55 | n_ctx=n_ctx, 56 | n_key_value_heads=n_key_value_heads, 57 | n_layers=n_layers, 58 | attn_only=True, 59 | d_vocab=d_vocab, 60 | ) 61 | 62 | model = HookedTransformer(cfg) 63 | assert model.cfg.use_split_qkv_input is False 64 | 65 | x = torch.arange(1, 9).unsqueeze(0) 66 | normal_output = model(x) 67 | 68 | model.set_use_split_qkv_input(True) 69 | assert model.cfg.use_split_qkv_input is True 70 | 71 | split_output = model(x) 72 | 73 | assert torch.allclose(normal_output, split_output, atol=1e-6) 74 | -------------------------------------------------------------------------------- /tests/unit/test_use_attn_result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_lens import HookedTransformer 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def test_atten_result_normal_attn_correct(): 8 | """Verifies that the attn_result flag does not change the output for models with normal attention.""" 9 | d_model = 128 10 | d_head = 8 11 | n_heads = 16 12 | n_ctx = 128 13 | n_layers = 1 14 | d_vocab = 10 15 | 16 | cfg = HookedTransformerConfig( 17 | d_model=d_model, 18 | d_head=d_head, 19 | n_heads=n_heads, 20 | n_ctx=n_ctx, 21 | n_layers=n_layers, 22 | attn_only=True, 23 | d_vocab=d_vocab, 24 | ) 25 | 26 | model = HookedTransformer(cfg) 27 | assert model.cfg.use_split_qkv_input is False 28 | 29 | x = torch.arange(1, 9).unsqueeze(0) 30 | normal_output = model(x) 31 | 32 | model.set_use_attn_result(True) 33 | assert model.cfg.use_attn_result is True 34 | 35 | split_output = model(x) 36 | 37 | assert torch.allclose(normal_output, split_output, atol=1e-6) 38 | 39 | 40 | def test_atten_result_grouped_query_attn_correct(): 41 | """Verifies that the atten_result flag does not change the output for models with grouped query attention.""" 42 | 43 | d_model = 128 44 | d_head = 8 45 | n_heads = 16 46 | n_ctx = 128 47 | n_key_value_heads = 2 48 | n_layers = 1 49 | d_vocab = 10 50 | 51 | cfg = HookedTransformerConfig( 52 | d_model=d_model, 53 | d_head=d_head, 54 | n_heads=n_heads, 55 | n_ctx=n_ctx, 56 | 
n_key_value_heads=n_key_value_heads, 57 | n_layers=n_layers, 58 | attn_only=True, 59 | d_vocab=d_vocab, 60 | ) 61 | 62 | model = HookedTransformer(cfg) 63 | assert model.cfg.use_split_qkv_input is False 64 | 65 | x = torch.arange(1, 9).unsqueeze(0) 66 | normal_output = model(x) 67 | 68 | model.set_use_attn_result(True) 69 | assert model.cfg.use_attn_result is True 70 | 71 | split_output = model(x) 72 | 73 | assert torch.allclose(normal_output, split_output, atol=1e-6) 74 | -------------------------------------------------------------------------------- /tests/unit/utilities/test_devices.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import torch 4 | 5 | from transformer_lens.utilities.devices import ( 6 | calculate_available_device_cuda_memory, 7 | determine_available_memory_for_available_devices, 8 | sort_devices_based_on_available_memory, 9 | ) 10 | 11 | 12 | def mock_available_devices(memory_stats: list[tuple[int, int]]): 13 | torch.cuda.device_count = Mock(return_value=len(memory_stats)) 14 | 15 | def device_props_return(*args, **kwargs): 16 | total_memory = memory_stats[args[0]][0] 17 | device_props = Mock() 18 | device_props.total_memory = total_memory 19 | return device_props 20 | 21 | def memory_allocated_return(*args, **kwargs): 22 | return memory_stats[args[0]][1] 23 | 24 | torch.cuda.get_device_properties = Mock(side_effect=device_props_return) 25 | torch.cuda.memory_allocated = Mock(side_effect=memory_allocated_return) 26 | 27 | 28 | def test_calculate_available_device_cuda_memory(): 29 | mock_available_devices([(80, 40)]) 30 | 31 | result = calculate_available_device_cuda_memory(0) 32 | assert result == 40 33 | 34 | 35 | def test_determine_available_memory_for_available_devices(): 36 | mock_available_devices( 37 | [ 38 | (80, 60), 39 | (80, 15), 40 | (80, 40), 41 | ] 42 | ) 43 | 44 | result = determine_available_memory_for_available_devices(3) 45 | 46 | assert result == [ 47 | (0, 20), 48 | (1, 65), 49 | (2, 40), 50 | ] 51 | 52 | 53 | def test_sort_devices_based_on_available_memory(): 54 | devices = [ 55 | (0, 20), 56 | (1, 65), 57 | (2, 40), 58 | ] 59 | 60 | result = sort_devices_based_on_available_memory(devices) 61 | 62 | assert result == [ 63 | (1, 65), 64 | (2, 40), 65 | (0, 20), 66 | ] 67 | -------------------------------------------------------------------------------- /transformer_lens/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hook_points 2 | from . import utils 3 | from . import evals 4 | from .past_key_value_caching import ( 5 | HookedTransformerKeyValueCache, 6 | HookedTransformerKeyValueCacheEntry, 7 | ) 8 | from . import components 9 | from . import factories 10 | from .HookedTransformerConfig import HookedTransformerConfig 11 | from .FactoredMatrix import FactoredMatrix 12 | from .ActivationCache import ActivationCache 13 | from .HookedTransformer import HookedTransformer 14 | from .SVDInterpreter import SVDInterpreter 15 | from .HookedEncoder import HookedEncoder 16 | from .HookedEncoderDecoder import HookedEncoderDecoder 17 | from .BertNextSentencePrediction import BertNextSentencePrediction 18 | from . import head_detector 19 | from . import loading_from_pretrained as loading 20 | from . import patching 21 | from . 
import train 22 | 23 | from .past_key_value_caching import ( 24 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache, 25 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry, 26 | ) 27 | from .HookedTransformer import HookedTransformer as EasyTransformer 28 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig 29 | -------------------------------------------------------------------------------- /transformer_lens/components/__init__.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Components. 2 | 3 | This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) 4 | needed to create many different types of generative language models. They are used by 5 | :class:`transformer_lens.HookedTransformer`. 6 | """ 7 | 8 | # Independent classes 9 | from .abstract_attention import AbstractAttention 10 | from .layer_norm import LayerNorm 11 | from .layer_norm_pre import LayerNormPre 12 | from .pos_embed import PosEmbed 13 | from .rms_norm import RMSNorm 14 | from .rms_norm_pre import RMSNormPre 15 | from .token_typed_embed import TokenTypeEmbed 16 | from .unembed import Unembed 17 | 18 | # Only dependent on independent modules 19 | from .attention import Attention 20 | from .bert_mlm_head import BertMLMHead 21 | from .bert_nsp_head import BertNSPHead 22 | from .bert_pooler import BertPooler 23 | from .embed import Embed 24 | from .grouped_query_attention import GroupedQueryAttention 25 | from .mlps.gated_mlp import GatedMLP 26 | from .mlps.mlp import MLP 27 | 28 | # Interdependent modules 29 | from .bert_block import BertBlock 30 | from .bert_embed import BertEmbed 31 | from .mlps.moe import MoE 32 | from .transformer_block import TransformerBlock 33 | from .t5_attention import T5Attention 34 | from .t5_block import T5Block 35 | -------------------------------------------------------------------------------- /transformer_lens/components/attention.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Attention Component. 2 | 3 | This module contains all the component :class:`Attention`. 4 | """ 5 | from typing import Dict, Optional, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from transformers.utils import is_bitsandbytes_available 10 | 11 | from transformer_lens.components import AbstractAttention 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | if is_bitsandbytes_available(): 15 | from bitsandbytes.nn.modules import Params4bit 16 | 17 | 18 | # Attention 19 | class Attention(AbstractAttention): 20 | def __init__( 21 | self, 22 | cfg: Union[Dict, HookedTransformerConfig], 23 | attn_type: str = "global", 24 | layer_id: Optional[int] = None, 25 | ): 26 | """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax 27 | 28 | Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos] 29 | 30 | Args: 31 | cfg (Union[Dict, HookedTransformerConfig]): Config 32 | attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". 33 | layer_id (int, optional): The index of the current layer. 
Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 34 | """ 35 | super().__init__(cfg, attn_type, layer_id) 36 | self.cfg = HookedTransformerConfig.unwrap(cfg) 37 | 38 | if self.cfg.load_in_4bit: 39 | # 4-bit quantization convention 40 | nq = int((self.cfg.d_model * self.cfg.d_model) / 2) 41 | self.W_K = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 42 | self.W_V = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 43 | else: 44 | self.W_K = nn.Parameter( 45 | torch.empty( 46 | self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype 47 | ) 48 | ) 49 | self.W_V = nn.Parameter( 50 | torch.empty( 51 | self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype 52 | ) 53 | ) 54 | self.b_K = nn.Parameter( 55 | torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) 56 | ) 57 | self.b_V = nn.Parameter( 58 | torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) 59 | ) 60 | -------------------------------------------------------------------------------- /transformer_lens/components/bert_block.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Bert Block Component. 2 | 3 | This module contains all the component :class:`BertBlock`. 4 | """ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.components import Attention, LayerNorm 12 | from transformer_lens.factories.mlp_factory import MLPFactory 13 | from transformer_lens.hook_points import HookPoint 14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 15 | from transformer_lens.utils import repeat_along_head_dimension 16 | 17 | 18 | class BertBlock(nn.Module): 19 | """ 20 | BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. 
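
    In rough pseudocode (mirroring the ``forward`` method below; variable names are shorthand, every tensor is [batch, pos, d_model], and hook points are omitted):

        resid_mid      = resid_pre + attn(resid_pre)
        normalized_mid = ln1(resid_mid)
        resid_post     = normalized_mid + mlp(normalized_mid)
        out            = ln2(resid_post)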
21 | """ 22 | 23 | def __init__(self, cfg: HookedTransformerConfig): 24 | super().__init__() 25 | self.cfg = cfg 26 | 27 | self.attn = Attention(cfg) 28 | self.ln1 = LayerNorm(cfg) 29 | self.mlp = MLPFactory.create_mlp(self.cfg) 30 | self.ln2 = LayerNorm(cfg) 31 | 32 | self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] 33 | self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] 34 | self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] 35 | 36 | self.hook_attn_out = HookPoint() # [batch, pos, d_model] 37 | self.hook_mlp_in = HookPoint() # [batch, pos, d_model] 38 | self.hook_mlp_out = HookPoint() # [batch, pos, d_model] 39 | self.hook_resid_pre = HookPoint() # [batch, pos, d_model] 40 | self.hook_resid_mid = HookPoint() # [batch, pos, d_model] 41 | self.hook_resid_post = HookPoint() # [batch, pos, d_model] 42 | self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] 43 | 44 | def forward( 45 | self, 46 | resid_pre: Float[torch.Tensor, "batch pos d_model"], 47 | additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, 48 | ) -> Float[torch.Tensor, "batch pos d_model"]: 49 | resid_pre = self.hook_resid_pre(resid_pre) 50 | 51 | query_input = resid_pre 52 | key_input = resid_pre 53 | value_input = resid_pre 54 | 55 | if self.cfg.use_split_qkv_input: 56 | n_heads = self.cfg.n_heads 57 | query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads)) 58 | key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads)) 59 | value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads)) 60 | 61 | attn_out = self.hook_attn_out( 62 | self.attn( 63 | query_input, 64 | key_input, 65 | value_input, 66 | additive_attention_mask=additive_attention_mask, 67 | ) 68 | ) 69 | resid_mid = self.hook_resid_mid(resid_pre + attn_out) 70 | 71 | mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone()) 72 | normalized_resid_mid = self.ln1(mlp_in) 73 | mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) 74 | resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out) 75 | normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post)) 76 | 77 | return normalized_resid_post 78 | -------------------------------------------------------------------------------- /transformer_lens/components/bert_embed.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Bert Embed Component. 2 | 3 | This module contains all the component :class:`BertEmbed`. 4 | """ 5 | from typing import Dict, Optional, Union 6 | 7 | import einops 8 | import torch 9 | import torch.nn as nn 10 | from jaxtyping import Float, Int 11 | 12 | from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed 13 | from transformer_lens.hook_points import HookPoint 14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 15 | 16 | 17 | class BertEmbed(nn.Module): 18 | """ 19 | Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
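
    In rough pseudocode (mirroring the ``forward`` method below; names are shorthand):

        embeddings_out = embed(input_ids) + pos_embed(position_ids) + token_type_embed(token_type_ids)
        out            = ln(embeddings_out)

    where ``position_ids`` is just ``0 .. pos-1`` repeated over the batch, and ``token_type_ids`` defaults to all zeros when not provided.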
20 | """ 21 | 22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 23 | super().__init__() 24 | self.cfg = HookedTransformerConfig.unwrap(cfg) 25 | self.embed = Embed(self.cfg) 26 | self.pos_embed = PosEmbed(self.cfg) 27 | self.token_type_embed = TokenTypeEmbed(self.cfg) 28 | self.ln = LayerNorm(self.cfg) 29 | 30 | self.hook_embed = HookPoint() 31 | self.hook_pos_embed = HookPoint() 32 | self.hook_token_type_embed = HookPoint() 33 | 34 | def forward( 35 | self, 36 | input_ids: Int[torch.Tensor, "batch pos"], 37 | token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 38 | ) -> Float[torch.Tensor, "batch pos d_model"]: 39 | base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) 40 | index_ids = einops.repeat(base_index_id, "pos -> batch pos", batch=input_ids.shape[0]) 41 | if token_type_ids is None: 42 | token_type_ids = torch.zeros_like(input_ids) 43 | 44 | word_embeddings_out = self.hook_embed(self.embed(input_ids)) 45 | position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) 46 | token_type_embeddings_out = self.hook_token_type_embed( 47 | self.token_type_embed(token_type_ids) 48 | ) 49 | 50 | embeddings_out = word_embeddings_out + position_embeddings_out + token_type_embeddings_out 51 | layer_norm_out = self.ln(embeddings_out) 52 | return layer_norm_out 53 | -------------------------------------------------------------------------------- /transformer_lens/components/bert_mlm_head.py: -------------------------------------------------------------------------------- 1 | """Hooked Encoder Bert MLM Head Component. 2 | 3 | This module contains all the component :class:`BertMLMHead`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.components import LayerNorm 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | class BertMLMHead(nn.Module): 16 | """ 17 | Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. 18 | """ 19 | 20 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 21 | super().__init__() 22 | self.cfg = HookedTransformerConfig.unwrap(cfg) 23 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)) 24 | self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 25 | self.act_fn = nn.GELU() 26 | self.ln = LayerNorm(self.cfg) 27 | 28 | def forward( 29 | self, resid: Float[torch.Tensor, "batch pos d_model"] 30 | ) -> Float[torch.Tensor, "batch pos d_model"]: 31 | resid = torch.matmul(resid, self.W) + self.b 32 | resid = self.act_fn(resid) 33 | resid = self.ln(resid) 34 | return resid 35 | -------------------------------------------------------------------------------- /transformer_lens/components/bert_nsp_head.py: -------------------------------------------------------------------------------- 1 | """Hooked Encoder Bert NSP Head Component. 2 | 3 | This module contains all the component :class:`BertNSPHead`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.hook_points import HookPoint 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | class BertNSPHead(nn.Module): 16 | """ 17 | Transforms BERT embeddings into logits. The purpose of this module is to predict whether or not sentence B follows sentence A. 
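
    In rough pseudocode (mirroring the ``forward`` method below): ``nsp_logits = resid @ W + b``, a [batch, 2] tensor of logits, where ``resid`` is a pooled [batch, d_model] representation (e.g. the output of :class:`BertPooler`).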
18 | """ 19 | 20 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 21 | super().__init__() 22 | self.cfg = HookedTransformerConfig.unwrap(cfg) 23 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, 2, dtype=self.cfg.dtype)) 24 | self.b = nn.Parameter(torch.zeros(2, dtype=self.cfg.dtype)) 25 | self.hook_nsp_out = HookPoint() 26 | 27 | def forward( 28 | self, resid: Float[torch.Tensor, "batch d_model"] 29 | ) -> Float[torch.Tensor, "batch 2"]: 30 | nsp_logits = torch.matmul(resid, self.W) + self.b 31 | return self.hook_nsp_out(nsp_logits) 32 | -------------------------------------------------------------------------------- /transformer_lens/components/bert_pooler.py: -------------------------------------------------------------------------------- 1 | """Hooked Encoder Bert Pooler Component. 2 | 3 | This module contains all the component :class:`BertPooler`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.hook_points import HookPoint 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | class BertPooler(nn.Module): 16 | """ 17 | Transforms the [CLS] token representation into a fixed-size sequence embedding. 18 | The purpose of this module is to convert variable-length sequence inputs into a single vector representation suitable for downstream tasks. 19 | (e.g. Next Sentence Prediction) 20 | """ 21 | 22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 23 | super().__init__() 24 | self.cfg = HookedTransformerConfig.unwrap(cfg) 25 | self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)) 26 | self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 27 | self.activation = nn.Tanh() 28 | self.hook_pooler_out = HookPoint() 29 | 30 | def forward( 31 | self, resid: Float[torch.Tensor, "batch pos d_model"] 32 | ) -> Float[torch.Tensor, "batch d_model"]: 33 | first_token_tensor = resid[:, 0] 34 | pooled_output = torch.matmul(first_token_tensor, self.W) + self.b 35 | pooled_output = self.hook_pooler_out(self.activation(pooled_output)) 36 | return pooled_output 37 | -------------------------------------------------------------------------------- /transformer_lens/components/embed.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Embed Component. 2 | 3 | This module contains all the component :class:`Embed`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float, Int 10 | 11 | from transformer_lens.components import LayerNorm 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | # Embed & Unembed 16 | class Embed(nn.Module): 17 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 18 | super().__init__() 19 | self.cfg = HookedTransformerConfig.unwrap(cfg) 20 | self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter( 21 | torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype) 22 | ) 23 | # Some models (e.g. 
Bloom) need post embedding layer norm 24 | if self.cfg.post_embedding_ln: 25 | self.ln = LayerNorm(self.cfg) 26 | 27 | def forward( 28 | self, tokens: Int[torch.Tensor, "batch pos"] 29 | ) -> Float[torch.Tensor, "batch pos d_model"]: 30 | # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] 31 | # B acts as a tensor of indices into the second dimension (so >=0 and Union[ 45 | Float[torch.Tensor, "batch pos d_model"], 46 | Float[torch.Tensor, "batch pos head_index d_model"], 47 | ]: 48 | if self.cfg.dtype not in [torch.float32, torch.float64]: 49 | x = x.to(torch.float32) 50 | 51 | x = x - x.mean(-1, keepdim=True) # [batch, pos, length] 52 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( 53 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() 54 | ) 55 | x = x / scale # [batch, pos, length] 56 | return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype) 57 | -------------------------------------------------------------------------------- /transformer_lens/components/layer_norm_pre.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Layer Norm Pre Component. 2 | 3 | This module contains all the component :class:`LayerNormPre`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.hook_points import HookPoint 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | # LayerNormPre 16 | # I fold the LayerNorm weights and biases into later weights and biases. 17 | # This is just the 'center and normalise' part of LayerNorm 18 | # Centering is equivalent to just deleting one direction of residual space, 19 | # and is equivalent to centering the weight matrices of everything writing to the residual stream 20 | # Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere 21 | class LayerNormPre(nn.Module): 22 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 23 | """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is 24 | normally d_model, but is d_mlp for softmax. Not needed as a parameter. 
This 25 | should only be used in inference mode after folding in LayerNorm weights""" 26 | super().__init__() 27 | self.cfg = HookedTransformerConfig.unwrap(cfg) 28 | self.eps = self.cfg.eps 29 | 30 | # Adds a hook point for the normalisation scale factor 31 | self.hook_scale = HookPoint() # [batch, pos] 32 | # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0 33 | self.hook_normalized = HookPoint() # [batch, pos, length] 34 | 35 | def forward( 36 | self, 37 | x: Union[ 38 | Float[torch.Tensor, "batch pos d_model"], 39 | Float[torch.Tensor, "batch pos head_index d_model"], 40 | ], 41 | ) -> Union[ 42 | Float[torch.Tensor, "batch pos d_model"], 43 | Float[torch.Tensor, "batch pos head_index d_model"], 44 | ]: 45 | if self.cfg.dtype not in [torch.float32, torch.float64]: 46 | x = x.to(torch.float32) 47 | 48 | x = x - x.mean(-1, keepdim=True) # [batch, pos, length] 49 | scale: Union[ 50 | Float[torch.Tensor, "batch pos 1"], 51 | Float[torch.Tensor, "batch pos head_index 1"], 52 | ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()) 53 | return self.hook_normalized(x / scale).to(self.cfg.dtype) 54 | -------------------------------------------------------------------------------- /transformer_lens/components/mlps/can_be_used_as_mlp.py: -------------------------------------------------------------------------------- 1 | """Can Be Used as MLP component. 2 | 3 | This module serves as the base for everything within TransformerLens that can be used like an MLP. 4 | This does not necessarily mean that every component extending this class will be an MLP, but 5 | everything extending this class can be used interchangeably for an MLP. 6 | """ 7 | from typing import Dict, Optional, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | from jaxtyping import Float 12 | 13 | from transformer_lens.components import LayerNorm, LayerNormPre 14 | from transformer_lens.factories.activation_function_factory import ( 15 | ActivationFunctionFactory, 16 | ) 17 | from transformer_lens.hook_points import HookPoint 18 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 19 | from transformer_lens.utilities.activation_functions import ActivationFunction 20 | 21 | 22 | class CanBeUsedAsMLP(nn.Module): 23 | # The actual activation function 24 | act_fn: ActivationFunction 25 | 26 | # The full config object for the model 27 | cfg: HookedTransformerConfig 28 | 29 | # The d mlp value pulled out of the config to make sure it always has a value 30 | d_mlp: int 31 | 32 | # The middle hook point will be None unless it specifically should be used 33 | hook_mid: Optional[HookPoint] # [batch, pos, d_mlp] 34 | 35 | # The layer norm component if the activation function is a layer norm 36 | ln: Optional[nn.Module] 37 | 38 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 39 | """The base init for all MLP like components 40 | 41 | Args: 42 | config (Union[Dict, HookedTransformerConfig]): The config for this instance 43 | 44 | Raises: 45 | ValueError: If there is a misconfiguration 46 | """ 47 | super().__init__() 48 | self.cfg = HookedTransformerConfig.unwrap(cfg) 49 | if self.cfg.d_mlp is None: 50 | raise ValueError("d_mlp must be set to use an MLP") 51 | 52 | self.d_mlp = self.cfg.d_mlp 53 | 54 | def forward( 55 | self, x: Float[torch.Tensor, "batch pos d_model"] 56 | ) -> Float[torch.Tensor, "batch pos d_model"]: 57 | """The format for all forward functions for any MLP""" 58 | return x 59 | 60 | def select_activation_function(self) -> None: 
61 | """This function should be called by all components in their init to get everything needed 62 | for activation functions setup. 63 | 64 | Raises: 65 | ValueError: If the configure activation function is not supported. 66 | """ 67 | 68 | self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) 69 | 70 | if self.cfg.is_layer_norm_activation(): 71 | self.hook_mid = HookPoint() 72 | if self.cfg.normalization_type == "LN": 73 | self.ln = LayerNorm(self.cfg, self.d_mlp) 74 | else: 75 | self.ln = LayerNormPre(self.cfg) 76 | -------------------------------------------------------------------------------- /transformer_lens/components/mlps/gated_mlp.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Gated MLP Component. 2 | 3 | This module contains all the component :class:`GatedMLP`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | from transformers.utils import is_bitsandbytes_available 11 | 12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 13 | from transformer_lens.hook_points import HookPoint 14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 15 | from transformer_lens.utilities.addmm import batch_addmm 16 | 17 | if is_bitsandbytes_available(): 18 | pass 19 | 20 | 21 | class GatedMLP(CanBeUsedAsMLP): 22 | """ 23 | The equation of a gated MLP: 24 | pre = x @ W_gate 25 | pre_linear = x @ W_in 26 | post = Gelu(pre) * (pre_linear) + b_in 27 | mlp_out = post @ W_out + b_out 28 | 29 | In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out 30 | """ 31 | 32 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 33 | super().__init__(cfg) 34 | self.select_activation_function() 35 | self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype)) 36 | self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype)) 37 | self.W_gate = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype)) 38 | 39 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype)) 40 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 41 | 42 | # hook on gate output but before act_fn 43 | self.hook_pre = HookPoint() # [batch, pos, d_mlp] 44 | # hook on the linear component of the input 45 | self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp] 46 | # hook on act_fn(gate_output) * W_in(x) + b_in 47 | self.hook_post = HookPoint() # [batch, pos, d_mlp] 48 | 49 | def forward( 50 | self, x: Float[torch.Tensor, "batch pos d_model"] 51 | ) -> Float[torch.Tensor, "batch pos d_model"]: 52 | # Technically, all these einsums could be done with a single matmul, but this is more readable. 
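        # Shape sketch for the default (non-LayerNorm-activation) branch below, using
        # illustrative sizes d_model=512, d_mlp=2048 (these numbers are just an example):
        #   x           [batch, pos, 512]
        #   pre_act     [batch, pos, 2048] = x @ W_gate                           (hook_pre)
        #   pre_linear  [batch, pos, 2048] = x @ W_in                             (hook_pre_linear)
        #   post_act    [batch, pos, 2048] = act_fn(pre_act) * pre_linear + b_in  (hook_post)
        #   output      [batch, pos, 512]  = post_act @ W_out + b_out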
53 | if self.W_gate.device != x.device: 54 | x = x.to(self.W_gate.device) 55 | pre_act = self.hook_pre( 56 | torch.matmul(x, self.W_gate) # batch pos d_model, d_model d_mlp -> batch pos d_mlp 57 | ) # [batch, pos, d_mlp] 58 | 59 | if ( 60 | self.cfg.is_layer_norm_activation() 61 | and self.hook_mid is not None 62 | and self.ln is not None 63 | ): 64 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] 65 | post_act = self.hook_post(self.ln(mid_act)) 66 | else: 67 | pre_linear = self.hook_pre_linear( 68 | torch.matmul(x, self.W_in) # batch pos d_model, d_model d_mlp -> batch pos d_mlp 69 | ) 70 | 71 | post_act = self.hook_post( 72 | (self.act_fn(pre_act) * pre_linear) + self.b_in 73 | ) # [batch, pos, d_mlp] 74 | 75 | return batch_addmm(self.b_out, self.W_out, post_act) 76 | -------------------------------------------------------------------------------- /transformer_lens/components/mlps/gated_mlp_4bit.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Gated MLP Component. 2 | 3 | This module contains all the component :class:`GatedMLP`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | from transformers.utils import is_bitsandbytes_available 11 | 12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 13 | from transformer_lens.hook_points import HookPoint 14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 15 | 16 | if is_bitsandbytes_available(): 17 | import bitsandbytes as bnb 18 | from bitsandbytes.nn.modules import Params4bit 19 | 20 | 21 | class GatedMLP4Bit(CanBeUsedAsMLP): 22 | """ 23 | The equation of a gated MLP: 24 | pre = x @ W_gate 25 | pre_linear = x @ W_in 26 | post = Gelu(pre) * (pre_linear) + b_in 27 | mlp_out = post @ W_out + b_out 28 | 29 | In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out 30 | """ 31 | 32 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 33 | super().__init__(cfg) 34 | self.select_activation_function() 35 | 36 | nq = int((self.cfg.d_model * self.d_mlp) / 2) 37 | self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 38 | self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 39 | self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 40 | 41 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype)) 42 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 43 | 44 | # hook on gate output but before act_fn 45 | self.hook_pre = HookPoint() # [batch, pos, d_mlp] 46 | # hook on the linear component of the input 47 | self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp] 48 | # hook on act_fn(gate_output) * W_in(x) + b_in 49 | self.hook_post = HookPoint() # [batch, pos, d_mlp] 50 | 51 | def forward( 52 | self, x: Float[torch.Tensor, "batch pos d_model"] 53 | ) -> Float[torch.Tensor, "batch pos d_model"]: 54 | # Technically, all these einsums could be done with a single matmul, but this is more readable. 
55 | pre_act = self.hook_pre( 56 | bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state) 57 | ) 58 | 59 | if ( 60 | self.cfg.is_layer_norm_activation() 61 | and self.hook_mid is not None 62 | and self.ln is not None 63 | ): 64 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] 65 | post_act = self.hook_post(self.ln(mid_act)) 66 | else: 67 | pre_linear = self.hook_pre_linear( 68 | bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state) 69 | ) 70 | 71 | post_act = self.hook_post( 72 | (self.act_fn(pre_act) * pre_linear) + self.b_in 73 | ) # [batch, pos, d_mlp] 74 | 75 | return bnb.matmul_4bit( 76 | post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state 77 | ) 78 | -------------------------------------------------------------------------------- /transformer_lens/components/mlps/mlp.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer MLP Component. 2 | 3 | This module contains all the component :class:`MLP`. 4 | """ 5 | 6 | from typing import Dict, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | from jaxtyping import Float 11 | 12 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 13 | from transformer_lens.hook_points import HookPoint 14 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 15 | from transformer_lens.utilities.addmm import batch_addmm 16 | 17 | 18 | class MLP(CanBeUsedAsMLP): 19 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 20 | super().__init__(cfg) 21 | self.select_activation_function() 22 | 23 | self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype)) 24 | self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype)) 25 | 26 | self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype)) 27 | self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 28 | 29 | self.hook_pre = HookPoint() # [batch, pos, d_mlp] 30 | self.hook_post = HookPoint() # [batch, pos, d_mlp] 31 | 32 | def forward( 33 | self, x: Float[torch.Tensor, "batch pos d_model"] 34 | ) -> Float[torch.Tensor, "batch pos d_model"]: 35 | # This is equivalent to (roughly) W_in @ x + b_in. It's important to 36 | # use a fused addmm to ensure it matches the Huggingface implementation 37 | # exactly. 38 | pre_act = self.hook_pre(batch_addmm(self.b_in, self.W_in, x)) # [batch, pos, d_mlp] 39 | 40 | if ( 41 | self.cfg.is_layer_norm_activation() 42 | and self.hook_mid is not None 43 | and self.ln is not None 44 | ): 45 | mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] 46 | post_act = self.hook_post(self.ln(mid_act)) 47 | else: 48 | post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] 49 | return batch_addmm(self.b_out, self.W_out, post_act) 50 | -------------------------------------------------------------------------------- /transformer_lens/components/pos_embed.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer POS Embed Component. 2 | 3 | This module contains all the component :class:`PosEmbed`. 
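
In the simplest case (no attention mask and no key/value cache offset), the ``forward`` method below reduces to a lookup into the learned table ``W_pos`` of shape [n_ctx, d_model]:

    pos_embed = W_pos[:tokens.shape[-1]]  # [pos, d_model]
    out = einops.repeat(pos_embed, "pos d_model -> batch pos d_model", batch=tokens.shape[0])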
4 | """ 5 | from typing import Dict, Optional, Union 6 | 7 | import einops 8 | import torch 9 | import torch.nn as nn 10 | from jaxtyping import Float, Int 11 | 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | from transformer_lens.utils import get_offset_position_ids 14 | 15 | 16 | # Positional Embeddings 17 | class PosEmbed(nn.Module): 18 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 19 | super().__init__() 20 | self.cfg = HookedTransformerConfig.unwrap(cfg) 21 | self.W_pos = nn.Parameter( 22 | torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype) 23 | ) 24 | 25 | def forward( 26 | self, 27 | tokens: Int[torch.Tensor, "batch pos"], 28 | past_kv_pos_offset: int = 0, 29 | attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 30 | ) -> Float[torch.Tensor, "batch new_pos d_model"]: 31 | """ 32 | Forward pass for positional embeddings. 33 | 34 | Args: 35 | tokens (Int[torch.Tensor, "batch pos"]): Input tokens. 36 | past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0. 37 | attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens. 38 | Defaults to None. 39 | 40 | Returns: 41 | Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings. 42 | """ 43 | tokens_length = tokens.size(-1) 44 | 45 | if attention_mask is None: 46 | pos_embed = self.W_pos[ 47 | past_kv_pos_offset : tokens_length + past_kv_pos_offset, : 48 | ] # [pos, d_model] 49 | batch_pos_embed = einops.repeat( 50 | pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0) 51 | ) 52 | 53 | else: 54 | # Separated from the no padding case for computational efficiency 55 | # (this code is a bit slower than the code above) 56 | 57 | offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) 58 | pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model] 59 | 60 | # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice) 61 | padding_mask = ~attention_mask.bool() # [batch, tokens_length] 62 | offset_padding_mask = padding_mask[ 63 | :, past_kv_pos_offset : tokens_length + past_kv_pos_offset 64 | ].unsqueeze( 65 | -1 66 | ) # [batch, pos, 1] 67 | batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed) 68 | 69 | return batch_pos_embed.clone() 70 | -------------------------------------------------------------------------------- /transformer_lens/components/rms_norm.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer RMS Norm Component. 2 | 3 | This module contains all the component :class:`RMSNorm`. 4 | """ 5 | from typing import Dict, Optional, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.hook_points import HookPoint 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | class RMSNorm(nn.Module): 16 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None): 17 | """ 18 | RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square) 19 | 20 | length (Optional[int]): If the dimension of the RMSNorm. 
If not provided, assumed to be d_model 21 | """ 22 | super().__init__() 23 | self.cfg = HookedTransformerConfig.unwrap(cfg) 24 | self.eps = self.cfg.eps 25 | if length is None: 26 | self.length = self.cfg.d_model 27 | else: 28 | self.length = length 29 | 30 | self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype)) 31 | 32 | # Adds a hook point for the normalisation scale factor 33 | self.hook_scale = HookPoint() # [batch, pos, 1] 34 | self.hook_normalized = HookPoint() # [batch, pos, length] 35 | 36 | def forward( 37 | self, x: Float[torch.Tensor, "batch pos length"] 38 | ) -> Float[torch.Tensor, "batch pos length"]: 39 | if self.cfg.dtype not in [torch.float32, torch.float64]: 40 | x = x.to(torch.float32) 41 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( 42 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() 43 | ) 44 | x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length] 45 | 46 | if x.device != self.w.device: 47 | self.to(x.device) 48 | 49 | return x * self.w 50 | -------------------------------------------------------------------------------- /transformer_lens/components/rms_norm_pre.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer RMS Norm Pre Component. 2 | 3 | This module contains all the component :class:`RMSNormPre`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Float 10 | 11 | from transformer_lens.hook_points import HookPoint 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | 14 | 15 | class RMSNormPre(nn.Module): 16 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 17 | """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)""" 18 | super().__init__() 19 | self.cfg = HookedTransformerConfig.unwrap(cfg) 20 | self.eps = self.cfg.eps 21 | 22 | # Adds a hook point for the normalisation scale factor 23 | self.hook_scale = HookPoint() # [batch, pos] 24 | self.hook_normalized = HookPoint() # [batch, pos, length] 25 | 26 | def forward( 27 | self, x: Float[torch.Tensor, "batch pos length"] 28 | ) -> Float[torch.Tensor, "batch pos length"]: 29 | if self.cfg.dtype not in [torch.float32, torch.float64]: 30 | x = x.to(torch.float32) 31 | 32 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( 33 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() 34 | ) 35 | return self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length] 36 | -------------------------------------------------------------------------------- /transformer_lens/components/token_typed_embed.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Token Typed Embed Component. 2 | 3 | This module contains all the component :class:`TokenTypeEmbed`. 4 | """ 5 | from typing import Dict, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | from jaxtyping import Int 10 | 11 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 12 | 13 | 14 | class TokenTypeEmbed(nn.Module): 15 | """ 16 | The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. 
Typically, shape is (batch_size, sequence_length). 17 | 18 | See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf 19 | """ 20 | 21 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 22 | super().__init__() 23 | self.cfg = HookedTransformerConfig.unwrap(cfg) 24 | self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=self.cfg.dtype)) 25 | 26 | def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): 27 | return self.W_token_type[token_type_ids, :] 28 | -------------------------------------------------------------------------------- /transformer_lens/components/unembed.py: -------------------------------------------------------------------------------- 1 | """Hooked Transformer Unembed Component. 2 | 3 | This module contains all the component :class:`Unembed`. 4 | """ 5 | 6 | from typing import Dict, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | from jaxtyping import Float 11 | 12 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 13 | from transformer_lens.utilities.addmm import batch_addmm 14 | 15 | 16 | class Unembed(nn.Module): 17 | def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 18 | super().__init__() 19 | self.cfg = HookedTransformerConfig.unwrap(cfg) 20 | # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different. 21 | self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter( 22 | torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype) 23 | ) 24 | self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter( 25 | torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype) 26 | ) 27 | 28 | def forward( 29 | self, residual: Float[torch.Tensor, "batch pos d_model"] 30 | ) -> Float[torch.Tensor, "batch pos d_vocab_out"]: 31 | return batch_addmm(self.b_U, self.W_U, residual) 32 | -------------------------------------------------------------------------------- /transformer_lens/factories/activation_function_factory.py: -------------------------------------------------------------------------------- 1 | """Activation Function Factory 2 | 3 | Centralized location for selection supported activation functions throughout TransformerLens 4 | """ 5 | 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | from transformer_lens.utilities.activation_functions import ( 8 | SUPPORTED_ACTIVATIONS, 9 | ActivationFunction, 10 | ) 11 | 12 | 13 | class ActivationFunctionFactory: 14 | @staticmethod 15 | def pick_activation_function(cfg: HookedTransformerConfig) -> ActivationFunction: 16 | """Use this to select what activation function is needed based on configuration. 17 | 18 | Args: 19 | cfg (HookedTransformerConfig): The already created hooked transformer config 20 | 21 | Raises: 22 | ValueError: If there is a problem with the requested activation function. 23 | 24 | Returns: 25 | ActivationFunction: The activation function based on the dictionary of supported activations. 
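
        Example (a sketch; assumes ``cfg.act_fn`` has been set to a supported name such as ``"gelu"``):

            act_fn = ActivationFunctionFactory.pick_activation_function(cfg)
            post_act = act_fn(pre_act)  # applied elementwise, e.g. [batch, pos, d_mlp] -> [batch, pos, d_mlp]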
26 | """ 27 | act_fn = cfg.act_fn 28 | 29 | if act_fn is None: 30 | raise ValueError("act_fn not set when trying to select Activation Function") 31 | 32 | activation_function = SUPPORTED_ACTIVATIONS.get(act_fn) 33 | 34 | if activation_function is None: 35 | raise ValueError(f"Invalid activation function name: {act_fn}") 36 | 37 | return activation_function 38 | -------------------------------------------------------------------------------- /transformer_lens/factories/mlp_factory.py: -------------------------------------------------------------------------------- 1 | """MLP Factory 2 | 3 | Centralized location for creating any MLP needed within TransformerLens 4 | """ 5 | from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 6 | from transformer_lens.components.mlps.gated_mlp import GatedMLP 7 | from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit 8 | from transformer_lens.components.mlps.mlp import MLP 9 | from transformer_lens.components.mlps.moe import MoE 10 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 11 | 12 | 13 | class MLPFactory: 14 | @staticmethod 15 | def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP: 16 | if cfg.num_experts: 17 | return MoE(cfg) 18 | elif cfg.gated_mlp: 19 | return GatedMLP(cfg) if not cfg.load_in_4bit else GatedMLP4Bit(cfg) 20 | else: 21 | return MLP(cfg) 22 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/transformer_lens/pretrained/__init__.py -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/__init__.py: -------------------------------------------------------------------------------- 1 | from .neo import convert_neo_weights 2 | from .gpt2 import convert_gpt2_weights 3 | from .opt import convert_opt_weights 4 | from .gptj import convert_gptj_weights 5 | from .neox import convert_neox_weights 6 | from .llama import convert_llama_weights 7 | from .bert import convert_bert_weights 8 | from .mistral import convert_mistral_weights 9 | from .mixtral import convert_mixtral_weights 10 | from .bloom import convert_bloom_weights 11 | from .coder import convert_coder_weights 12 | from .qwen import convert_qwen_weights 13 | from .qwen2 import convert_qwen2_weights 14 | from .phi import convert_phi_weights 15 | from .phi3 import convert_phi3_weights 16 | from .gemma import convert_gemma_weights 17 | from .mingpt import convert_mingpt_weights 18 | from .nanogpt import convert_nanogpt_weights 19 | from .t5 import convert_t5_weights 20 | from .neel_solu_old import convert_neel_solu_old_weights 21 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/bert.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 4 | 5 | 6 | def convert_bert_weights(bert, cfg: HookedTransformerConfig): 7 | embeddings = bert.bert.embeddings 8 | state_dict = { 9 | "embed.embed.W_E": embeddings.word_embeddings.weight, 10 | "embed.pos_embed.W_pos": embeddings.position_embeddings.weight, 11 | "embed.token_type_embed.W_token_type": 
embeddings.token_type_embeddings.weight, 12 | "embed.ln.w": embeddings.LayerNorm.weight, 13 | "embed.ln.b": embeddings.LayerNorm.bias, 14 | } 15 | 16 | for l in range(cfg.n_layers): 17 | block = bert.bert.encoder.layer[l] 18 | state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( 19 | block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads 20 | ) 21 | state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( 22 | block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads 23 | ) 24 | state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( 25 | block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads 26 | ) 27 | state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange( 28 | block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads 29 | ) 30 | state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( 31 | block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads 32 | ) 33 | state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange( 34 | block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads 35 | ) 36 | state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( 37 | block.attention.output.dense.weight, 38 | "m (i h) -> i h m", 39 | i=cfg.n_heads, 40 | ) 41 | state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias 42 | state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight 43 | state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias 44 | state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange( 45 | block.intermediate.dense.weight, "mlp model -> model mlp" 46 | ) 47 | state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias 48 | state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange( 49 | block.output.dense.weight, "model mlp -> mlp model" 50 | ) 51 | state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias 52 | state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight 53 | state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias 54 | 55 | pooler = bert.bert.pooler 56 | state_dict["pooler.W"] = pooler.dense.weight.T 57 | state_dict["pooler.b"] = pooler.dense.bias 58 | 59 | mlm_head = bert.cls.predictions 60 | state_dict["mlm_head.W"] = mlm_head.transform.dense.weight.T 61 | state_dict["mlm_head.b"] = mlm_head.transform.dense.bias 62 | state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight 63 | state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias 64 | 65 | # The NSP head does not have an unembedding 66 | # so we are only using weights from the MLM head 67 | # Note: BERT uses tied embeddings 68 | state_dict["unembed.W_U"] = mlm_head.decoder.weight.T 69 | state_dict["unembed.b_U"] = mlm_head.decoder.bias 70 | 71 | nsp_head = bert.cls.seq_relationship 72 | state_dict["nsp_head.W"] = nsp_head.weight.T 73 | state_dict["nsp_head.b"] = nsp_head.bias 74 | 75 | return state_dict 76 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/bloom.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 4 | 5 | 6 | def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): 7 | state_dict = {} 8 | 9 | state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight 10 | 11 | # Bloom uses post embedding layer norm 12 | state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight 13 | state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias 14 
| 15 | for l in range(cfg.n_layers): 16 | state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight 17 | state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias 18 | 19 | W = bloom.transformer.h[l].self_attention.query_key_value.weight 20 | 21 | W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head) 22 | 23 | W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :] 24 | W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads) 25 | W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads) 26 | W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads) 27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 28 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 29 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 30 | 31 | qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias 32 | qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head) 33 | 34 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :] 35 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :] 36 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :] 37 | 38 | W_O = bloom.transformer.h[l].self_attention.dense.weight.T # [1024, 1024] 39 | W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model] 40 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 41 | state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias 42 | 43 | state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight 44 | state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias 45 | 46 | W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T 47 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in 48 | state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias 49 | 50 | W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T 51 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out 52 | state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias 53 | state_dict["unembed.W_U"] = bloom.lm_head.weight.T 54 | 55 | state_dict["ln_final.w"] = bloom.transformer.ln_f.weight 56 | state_dict["ln_final.b"] = bloom.transformer.ln_f.bias 57 | return state_dict 58 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/coder.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_coder_weights(model, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = model.transformer.wte.weight 11 | state_dict["pos_embed.W_pos"] = model.transformer.wpe.weight 12 | 13 | for l in range(cfg.n_layers): 14 | state_dict[f"blocks.{l}.ln1.w"] = model.transformer.h[l].ln_1.weight 15 | state_dict[f"blocks.{l}.ln1.b"] = model.transformer.h[l].ln_1.bias 16 | 17 | # In GPT-2, q,k,v are produced by one big linear map, whose output is 18 | # concat([q, k, v]) 19 | W_KV = model.transformer.h[l].attn.kv_attn.weight # [d_model, 2 * d_head] 20 | W_K, W_V = torch.tensor_split(W_KV, 2, dim=1) 21 | W_Q = model.transformer.h[l].attn.q_attn.weight # [d_model, d_model] 22 | W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads) 23 | W_K = einops.repeat(W_K, "m h -> i m h", i=cfg.n_heads) 24 | W_V = einops.repeat(W_V, "m h -> i m h", i=cfg.n_heads) 25 | 26 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 27 | 
state_dict[f"blocks.{l}.attn.W_K"] = W_K 28 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 29 | 30 | b_Q = einops.rearrange( 31 | model.transformer.h[l].attn.q_attn.bias, 32 | "(index head)-> index head", 33 | index=cfg.n_heads, 34 | head=cfg.d_head, 35 | ) 36 | b_KV = model.transformer.h[l].attn.kv_attn.bias # [2 * d_head] 37 | b_K, b_V = torch.tensor_split(b_KV, 2, dim=0) 38 | b_K = einops.repeat(b_K, "head -> index head", index=cfg.n_heads) 39 | b_V = einops.repeat(b_V, "head -> index head", index=cfg.n_heads) 40 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 41 | state_dict[f"blocks.{l}.attn.b_K"] = b_K 42 | state_dict[f"blocks.{l}.attn.b_V"] = b_V 43 | 44 | W_O = model.transformer.h[l].attn.c_proj.weight 45 | W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads) 46 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 47 | state_dict[f"blocks.{l}.attn.b_O"] = model.transformer.h[l].attn.c_proj.bias 48 | 49 | state_dict[f"blocks.{l}.ln2.w"] = model.transformer.h[l].ln_2.weight 50 | state_dict[f"blocks.{l}.ln2.b"] = model.transformer.h[l].ln_2.bias 51 | 52 | W_in = model.transformer.h[l].mlp.c_fc.weight 53 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in 54 | state_dict[f"blocks.{l}.mlp.b_in"] = model.transformer.h[l].mlp.c_fc.bias 55 | 56 | W_out = model.transformer.h[l].mlp.c_proj.weight 57 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out 58 | state_dict[f"blocks.{l}.mlp.b_out"] = model.transformer.h[l].mlp.c_proj.bias 59 | state_dict["unembed.W_U"] = model.lm_head.weight.T 60 | 61 | state_dict["ln_final.w"] = model.transformer.ln_f.weight 62 | state_dict["ln_final.b"] = model.transformer.ln_f.bias 63 | return state_dict 64 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/gpt2.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = gpt2.transformer.wte.weight 11 | state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight 12 | 13 | for l in range(cfg.n_layers): 14 | state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight 15 | state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias 16 | 17 | # In GPT-2, q,k,v are produced by one big linear map, whose output is 18 | # concat([q, k, v]) 19 | W = gpt2.transformer.h[l].attn.c_attn.weight 20 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1) 21 | W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads) 22 | W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads) 23 | W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads) 24 | 25 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 26 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 27 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 28 | 29 | qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias 30 | qkv_bias = einops.rearrange( 31 | qkv_bias, 32 | "(qkv index head)->qkv index head", 33 | qkv=3, 34 | index=cfg.n_heads, 35 | head=cfg.d_head, 36 | ) 37 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0] 38 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1] 39 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2] 40 | 41 | W_O = gpt2.transformer.h[l].attn.c_proj.weight 42 | W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads) 43 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 44 | state_dict[f"blocks.{l}.attn.b_O"] = 
gpt2.transformer.h[l].attn.c_proj.bias 45 | 46 | state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight 47 | state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias 48 | 49 | W_in = gpt2.transformer.h[l].mlp.c_fc.weight 50 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in 51 | state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias 52 | 53 | W_out = gpt2.transformer.h[l].mlp.c_proj.weight 54 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out 55 | state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias 56 | state_dict["unembed.W_U"] = gpt2.lm_head.weight.T 57 | 58 | state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight 59 | state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias 60 | return state_dict 61 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/gptj.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_gptj_weights(gptj, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = gptj.transformer.wte.weight 11 | 12 | for l in range(cfg.n_layers): 13 | state_dict[f"blocks.{l}.ln1.w"] = gptj.transformer.h[l].ln_1.weight 14 | state_dict[f"blocks.{l}.ln1.b"] = gptj.transformer.h[l].ln_1.bias 15 | 16 | W_Q = gptj.transformer.h[l].attn.q_proj.weight 17 | W_K = gptj.transformer.h[l].attn.k_proj.weight 18 | W_V = gptj.transformer.h[l].attn.v_proj.weight 19 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 20 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 21 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 22 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 23 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 24 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 25 | 26 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 27 | state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 28 | state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 29 | 30 | W_O = gptj.transformer.h[l].attn.out_proj.weight 31 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 32 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 33 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 34 | 35 | # Layer Norm 1 and 2 are tied. 36 | state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"] 37 | state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"] 38 | 39 | state_dict[f"blocks.{l}.mlp.W_in"] = gptj.transformer.h[l].mlp.fc_in.weight.T 40 | state_dict[f"blocks.{l}.mlp.b_in"] = gptj.transformer.h[l].mlp.fc_in.bias 41 | 42 | state_dict[f"blocks.{l}.mlp.W_out"] = gptj.transformer.h[l].mlp.fc_out.weight.T 43 | state_dict[f"blocks.{l}.mlp.b_out"] = gptj.transformer.h[l].mlp.fc_out.bias 44 | state_dict["ln_final.w"] = gptj.transformer.ln_f.weight 45 | state_dict["ln_final.b"] = gptj.transformer.ln_f.bias 46 | 47 | state_dict["unembed.W_U"] = gptj.lm_head.weight.T 48 | # Contains a bias, for some reason? 
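# (GPT-J's HuggingFace lm_head is a full nn.Linear with bias=True, unlike GPT-2's bias-free
# tied unembedding, so that bias is carried over into unembed.b_U below.)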
49 | state_dict["unembed.b_U"] = gptj.lm_head.bias 50 | return state_dict 51 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/llama.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import einops 4 | import torch 5 | 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | 8 | 9 | def convert_llama_weights(llama, cfg: HookedTransformerConfig): 10 | state_dict = {} 11 | 12 | state_dict["embed.W_E"] = llama.model.embed_tokens.weight 13 | 14 | # Some models with the Llama architecture use Grouped Query Attention, and so for these we need to modify 15 | # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names. 16 | using_gqa = cfg.n_key_value_heads is not None 17 | gqa_uscore = "_" if using_gqa else "" 18 | # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None 19 | n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) 20 | 21 | # llama has no biases anywhere and deals with everything else roughly like 22 | # GPTNeoX with different names 23 | 24 | assert cfg.d_mlp is not None # keep mypy happy 25 | 26 | for l in range(cfg.n_layers): 27 | state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight 28 | 29 | W_Q = llama.model.layers[l].self_attn.q_proj.weight 30 | W_K = llama.model.layers[l].self_attn.k_proj.weight 31 | W_V = llama.model.layers[l].self_attn.v_proj.weight 32 | 33 | # in case of quantization, 34 | # parameters should stay as bitsandbytes.nn.modules.Params4bit 35 | if not cfg.load_in_4bit: 36 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 37 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) 38 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) 39 | 40 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 41 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K 42 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V 43 | 44 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( 45 | cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device 46 | ) 47 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( 48 | n_kv_heads, 49 | cfg.d_head, 50 | dtype=cfg.dtype, 51 | device=cfg.device, 52 | ) 53 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( 54 | n_kv_heads, 55 | cfg.d_head, 56 | dtype=cfg.dtype, 57 | device=cfg.device, 58 | ) 59 | 60 | W_O = llama.model.layers[l].self_attn.o_proj.weight 61 | 62 | if not cfg.load_in_4bit: 63 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 64 | 65 | state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) 66 | 67 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( 68 | cfg.d_model, dtype=cfg.dtype, device=cfg.device 69 | ) 70 | 71 | state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight 72 | 73 | # in case of quantization, 74 | # parameters should stay as bitsandbytes.nn.modules.Params4bit 75 | if not cfg.load_in_4bit: 76 | state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T 77 | state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T 78 | state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T 79 | else: 80 | state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight 81 | state_dict[f"blocks.{l}.mlp.W_gate"] = 
llama.model.layers[l].mlp.gate_proj.weight 82 | state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight 83 | 84 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( 85 | cfg.d_mlp, dtype=cfg.dtype, device=cfg.device 86 | ) 87 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( 88 | cfg.d_model, dtype=cfg.dtype, device=cfg.device 89 | ) 90 | 91 | state_dict["ln_final.w"] = llama.model.norm.weight 92 | 93 | state_dict["unembed.W_U"] = llama.lm_head.weight.T 94 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) 95 | 96 | return state_dict 97 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/mingpt.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 4 | 5 | 6 | def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): 7 | # mingpt (https://github.com/karpathy/minGPT) is mostly similar to GPT-2, 8 | # but doesn't concat the QKV matrices. 9 | state_dict = {} 10 | 11 | state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"] 12 | state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze() 13 | 14 | for l in range(cfg.n_layers): 15 | state_dict[f"blocks.{l}.ln1.w"] = old_state_dict[f"blocks.{l}.ln1.weight"] 16 | state_dict[f"blocks.{l}.ln1.b"] = old_state_dict[f"blocks.{l}.ln1.bias"] 17 | 18 | W_Q = old_state_dict[f"blocks.{l}.attn.query.weight"] 19 | W_K = old_state_dict[f"blocks.{l}.attn.key.weight"] 20 | W_V = old_state_dict[f"blocks.{l}.attn.value.weight"] 21 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 22 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 23 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 24 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 25 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 26 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 27 | 28 | q_bias = einops.rearrange( 29 | old_state_dict[f"blocks.{l}.attn.query.bias"], "(i h)->i h", i=cfg.n_heads 30 | ) 31 | k_bias = einops.rearrange( 32 | old_state_dict[f"blocks.{l}.attn.key.bias"], "(i h)->i h", i=cfg.n_heads 33 | ) 34 | v_bias = einops.rearrange( 35 | old_state_dict[f"blocks.{l}.attn.value.bias"], "(i h)->i h", i=cfg.n_heads 36 | ) 37 | 38 | state_dict[f"blocks.{l}.attn.b_Q"] = q_bias 39 | state_dict[f"blocks.{l}.attn.b_K"] = k_bias 40 | state_dict[f"blocks.{l}.attn.b_V"] = v_bias 41 | 42 | W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"] 43 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 44 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 45 | state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"] 46 | 47 | state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"] 48 | state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"] 49 | 50 | W_in = old_state_dict[f"blocks.{l}.mlp.0.weight"] 51 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in.T 52 | state_dict[f"blocks.{l}.mlp.b_in"] = old_state_dict[f"blocks.{l}.mlp.0.bias"] 53 | 54 | W_out = old_state_dict[f"blocks.{l}.mlp.2.weight"] 55 | state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T 56 | state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"] 57 | 58 | state_dict["unembed.W_U"] = old_state_dict["head.weight"].T 59 | 60 | state_dict["ln_final.w"] = old_state_dict["ln_f.weight"] 61 | state_dict["ln_final.b"] = old_state_dict["ln_f.bias"] 62 | 63 | 
return state_dict 64 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/mistral.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = mistral.model.embed_tokens.weight 11 | 12 | assert cfg.n_key_value_heads is not None # keep mypy happy 13 | assert cfg.d_mlp is not None # keep mypy happy 14 | 15 | # Mistral has no biases anywhere 16 | for l in range(cfg.n_layers): 17 | state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight 18 | 19 | W_Q = mistral.model.layers[l].self_attn.q_proj.weight 20 | W_K = mistral.model.layers[l].self_attn.k_proj.weight 21 | W_V = mistral.model.layers[l].self_attn.v_proj.weight 22 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 23 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 24 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 25 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 26 | state_dict[f"blocks.{l}.attn._W_K"] = W_K 27 | state_dict[f"blocks.{l}.attn._W_V"] = W_V 28 | 29 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 30 | state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( 31 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 32 | ) 33 | state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( 34 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 35 | ) 36 | 37 | W_O = mistral.model.layers[l].self_attn.o_proj.weight 38 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 39 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 40 | 41 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 42 | 43 | state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight 44 | 45 | state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T 46 | state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T 47 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 48 | 49 | state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T 50 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 51 | 52 | state_dict["ln_final.w"] = mistral.model.norm.weight 53 | 54 | state_dict["unembed.W_U"] = mistral.lm_head.weight.T 55 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 56 | 57 | return state_dict 58 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/mixtral.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): 8 | # The same as Mistral, but with the MLP replaced with MoE 9 | # As with Mistral, Mixtral has no biases 10 | 11 | state_dict = {} 12 | 13 | assert cfg.n_key_value_heads is not None # keep mypy happy 14 | assert cfg.d_mlp is not None 15 | assert cfg.num_experts is not None 16 | 17 | state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight 18 | 19 | for l in range(cfg.n_layers): 20 | 
state_dict[f"blocks.{l}.ln1.w"] = mixtral.model.layers[l].input_layernorm.weight 21 | 22 | W_Q = mixtral.model.layers[l].self_attn.q_proj.weight 23 | W_K = mixtral.model.layers[l].self_attn.k_proj.weight 24 | W_V = mixtral.model.layers[l].self_attn.v_proj.weight 25 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 26 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 27 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 28 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 29 | state_dict[f"blocks.{l}.attn._W_K"] = W_K 30 | state_dict[f"blocks.{l}.attn._W_V"] = W_V 31 | 32 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 33 | state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( 34 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 35 | ) 36 | state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( 37 | cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 38 | ) 39 | 40 | W_O = mixtral.model.layers[l].self_attn.o_proj.weight 41 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 42 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 43 | 44 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 45 | 46 | state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight 47 | 48 | state_dict[f"blocks.{l}.mlp.W_gate.weight"] = mixtral.model.layers[ 49 | l 50 | ].block_sparse_moe.gate.weight 51 | 52 | # The mapping here from wn to W_{in/out/gate} is a bit confusing: 53 | # w1 -> W_gate 54 | # w2 -> W_out 55 | # w3 -> W_in 56 | # See https://github.com/mistralai/mistral-inference/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L128 for reference 57 | for e in range(cfg.num_experts): 58 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = ( 59 | mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight 60 | ) 61 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = ( 62 | mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight 63 | ) 64 | state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = ( 65 | mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight 66 | ) 67 | 68 | state_dict["ln_final.w"] = mixtral.model.norm.weight.data 69 | 70 | state_dict["unembed.W_U"] = mixtral.lm_head.weight.T 71 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 72 | 73 | return state_dict 74 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/nanogpt.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): 8 | """For https://github.com/karpathy/nanoGPT 9 | There are two complications with converting nanogpt models: 10 | The first is that some state dicts have an unwanted prefix on keys that needs to be removed. 11 | The second is that the models can be saved with or without bias. By default, there 12 | is no bias. This function can handle both cases.""" 13 | # Nanogpt models saved after torch.compile() have this unwanted prefix 14 | # This is a simple way to remove it 15 | unwanted_prefix = "_orig_mod." 
16 | for k, v in list(old_state_dict.items()): 17 | if k.startswith(unwanted_prefix): 18 | old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k) 19 | 20 | new_state_dict = {} 21 | new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"] 22 | new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"] 23 | 24 | new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"] 25 | new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"]) 26 | new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T 27 | 28 | bias = False 29 | if "transformer.ln_f.bias" in old_state_dict: 30 | bias = True 31 | new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"] 32 | 33 | for layer in range(cfg.n_layers): 34 | layer_key = f"transformer.h.{layer}" 35 | 36 | new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"] 37 | # A bias of zeros is required for folding layer norm 38 | new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like( 39 | old_state_dict[f"{layer_key}.ln_1.weight"] 40 | ) 41 | new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"] 42 | new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like( 43 | old_state_dict[f"{layer_key}.ln_2.weight"] 44 | ) 45 | 46 | W = old_state_dict[f"{layer_key}.attn.c_attn.weight"] 47 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) 48 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 49 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 50 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 51 | new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q 52 | new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K 53 | new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V 54 | 55 | W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"] 56 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 57 | new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O 58 | 59 | new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[ 60 | f"{layer_key}.mlp.c_fc.weight" 61 | ].T 62 | new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[ 63 | f"{layer_key}.mlp.c_proj.weight" 64 | ].T 65 | 66 | if bias: 67 | new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"] 68 | new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"] 69 | new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ 70 | f"{layer_key}.mlp.c_fc.bias" 71 | ] 72 | new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ 73 | f"{layer_key}.mlp.c_proj.bias" 74 | ] 75 | 76 | B = old_state_dict[f"{layer_key}.attn.c_attn.bias"] 77 | B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0) 78 | B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads) 79 | B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads) 80 | B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads) 81 | new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q 82 | new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K 83 | new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V 84 | new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[ 85 | f"{layer_key}.attn.c_proj.bias" 86 | ] 87 | 88 | return new_state_dict 89 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/neel_solu_old.py: -------------------------------------------------------------------------------- 1 | from transformer_lens.HookedTransformerConfig import 
HookedTransformerConfig 2 | 3 | 4 | def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig): 5 | """ 6 | Converts the weights of my old SoLU models to the HookedTransformer format. 7 | Takes as input a state dict, *not* a model object. 8 | 9 | There are a bunch of dumb bugs in the original code, sorry! 10 | 11 | Models 1L, 2L, 4L and 6L have left facing weights (ie, weights have shape 12 | [dim_out, dim_in]) while HookedTransformer does right facing (ie [dim_in, 13 | dim_out]). 14 | 15 | 8L has *just* a left facing W_pos, the rest right facing. 16 | 17 | And some models were trained with 18 | """ 19 | # Early models have left facing W_pos 20 | reverse_pos = cfg.n_layers <= 8 21 | 22 | # Models prior to 8L have left facing everything (8L has JUST left facing W_pos - sorry! Stupid bug) 23 | reverse_weights = cfg.n_layers <= 6 24 | 25 | new_state_dict = {} 26 | for k, v in state_dict.items(): 27 | k = k.replace("norm", "ln") 28 | if k.startswith("ln."): 29 | k = k.replace("ln.", "ln_final.") 30 | new_state_dict[k] = v 31 | 32 | if reverse_pos: 33 | new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T 34 | if reverse_weights: 35 | for k, v in new_state_dict.items(): 36 | if "W_" in k and "W_pos" not in k: 37 | new_state_dict[k] = v.transpose(-2, -1) 38 | return new_state_dict 39 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/neo.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_neo_weights(neo, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = neo.transformer.wte.weight 11 | state_dict["pos_embed.W_pos"] = neo.transformer.wpe.weight 12 | 13 | for l in range(cfg.n_layers): 14 | state_dict[f"blocks.{l}.ln1.w"] = neo.transformer.h[l].ln_1.weight 15 | state_dict[f"blocks.{l}.ln1.b"] = neo.transformer.h[l].ln_1.bias 16 | 17 | W_Q = neo.transformer.h[l].attn.attention.q_proj.weight 18 | W_K = neo.transformer.h[l].attn.attention.k_proj.weight 19 | W_V = neo.transformer.h[l].attn.attention.v_proj.weight 20 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 21 | W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 22 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 23 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 24 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 25 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 26 | 27 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 28 | state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 29 | state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 30 | 31 | W_O = neo.transformer.h[l].attn.attention.out_proj.weight 32 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 33 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 34 | state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias 35 | 36 | state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight 37 | state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias 38 | 39 | state_dict[f"blocks.{l}.mlp.W_in"] = neo.transformer.h[l].mlp.c_fc.weight.T 40 | state_dict[f"blocks.{l}.mlp.b_in"] = neo.transformer.h[l].mlp.c_fc.bias 41 | 42 | state_dict[f"blocks.{l}.mlp.W_out"] = 
neo.transformer.h[l].mlp.c_proj.weight.T 43 | state_dict[f"blocks.{l}.mlp.b_out"] = neo.transformer.h[l].mlp.c_proj.bias 44 | state_dict["ln_final.w"] = neo.transformer.ln_f.weight 45 | state_dict["ln_final.b"] = neo.transformer.ln_f.bias 46 | 47 | state_dict["unembed.W_U"] = neo.lm_head.weight.T 48 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 49 | return state_dict 50 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/neox.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_neox_weights(neox, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = neox.gpt_neox.embed_in.weight 11 | 12 | for l in range(cfg.n_layers): 13 | state_dict[f"blocks.{l}.ln1.w"] = neox.gpt_neox.layers[l].input_layernorm.weight 14 | state_dict[f"blocks.{l}.ln1.b"] = neox.gpt_neox.layers[l].input_layernorm.bias 15 | 16 | # For some inexplicable reason, NeoX both uses the concatenated QKV 17 | # matmul of GPT-2 (afaict this has a neglible performance impact) AND 18 | # has the flattened axis in the DIFFERENT order of (head_index qkv 19 | # d_head) - this took me an hour to debug... 20 | W = neox.gpt_neox.layers[l].attention.query_key_value.weight 21 | W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=cfg.n_heads, qkv=3) 22 | 23 | # Fold in layer norm weights 24 | state_dict[f"blocks.{l}.attn.W_Q"] = W[0] 25 | state_dict[f"blocks.{l}.attn.W_K"] = W[1] 26 | state_dict[f"blocks.{l}.attn.W_V"] = W[2] 27 | 28 | qkv_bias = neox.gpt_neox.layers[l].attention.query_key_value.bias 29 | qkv_bias = einops.rearrange( 30 | qkv_bias, 31 | "(index qkv head)->qkv index head", 32 | qkv=3, 33 | index=cfg.n_heads, 34 | head=cfg.d_head, 35 | ) 36 | # Fold in layer norm biases 37 | state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0] 38 | state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1] 39 | state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2] 40 | 41 | W_O = neox.gpt_neox.layers[l].attention.dense.weight 42 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 43 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 44 | state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias 45 | 46 | state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight 47 | state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias 48 | 49 | state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T 50 | state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias 51 | 52 | state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T 53 | state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias 54 | state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight 55 | state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias 56 | 57 | state_dict["unembed.W_U"] = neox.embed_out.weight.T 58 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 59 | return state_dict 60 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/opt.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from 
transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_opt_weights(opt, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | 10 | state_dict["embed.W_E"] = opt.model.decoder.embed_tokens.weight 11 | state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :] 12 | 13 | for l in range(cfg.n_layers): 14 | state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight 15 | state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias 16 | 17 | W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight 18 | W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight 19 | W_V = opt.model.decoder.layers[l].self_attn.v_proj.weight 20 | W_Q = einops.rearrange( 21 | W_Q, 22 | "(index d_head) d_model->index d_model d_head", 23 | index=cfg.n_heads, 24 | ) 25 | W_K = einops.rearrange( 26 | W_K, 27 | "(index d_head) d_model->index d_model d_head", 28 | index=cfg.n_heads, 29 | ) 30 | W_V = einops.rearrange( 31 | W_V, 32 | "(index d_head) d_model->index d_model d_head", 33 | index=cfg.n_heads, 34 | ) 35 | 36 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 37 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 38 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 39 | 40 | q_bias = einops.rearrange( 41 | opt.model.decoder.layers[l].self_attn.q_proj.bias, 42 | "(head_index d_head)->head_index d_head", 43 | head_index=cfg.n_heads, 44 | d_head=cfg.d_head, 45 | ) 46 | k_bias = einops.rearrange( 47 | opt.model.decoder.layers[l].self_attn.k_proj.bias, 48 | "(head_index d_head)->head_index d_head", 49 | head_index=cfg.n_heads, 50 | d_head=cfg.d_head, 51 | ) 52 | v_bias = einops.rearrange( 53 | opt.model.decoder.layers[l].self_attn.v_proj.bias, 54 | "(head_index d_head)->head_index d_head", 55 | head_index=cfg.n_heads, 56 | d_head=cfg.d_head, 57 | ) 58 | 59 | state_dict[f"blocks.{l}.attn.b_Q"] = q_bias 60 | state_dict[f"blocks.{l}.attn.b_K"] = k_bias 61 | state_dict[f"blocks.{l}.attn.b_V"] = v_bias 62 | 63 | W_O = opt.model.decoder.layers[l].self_attn.out_proj.weight 64 | W_O = einops.rearrange( 65 | W_O, 66 | "d_model (index d_head)->index d_head d_model", 67 | index=cfg.n_heads, 68 | ) 69 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 70 | state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias 71 | 72 | state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight 73 | state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias 74 | 75 | state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T 76 | state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T 77 | 78 | state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias 79 | state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias 80 | state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight 81 | state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias 82 | state_dict["unembed.W_U"] = opt.lm_head.weight.T 83 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 84 | return state_dict 85 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/phi.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 4 | 5 | 6 | def convert_phi_weights(phi, cfg: HookedTransformerConfig): 7 | 
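# Phi-1/1.5/2 use a GPT-J-style parallel attention + MLP block with a single shared
# LayerNorm, which is why ln2 is simply tied to ln1 further down, mirroring the GPT-J
# conversion above.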
state_dict = {} 8 | 9 | state_dict["embed.W_E"] = phi.model.embed_tokens.weight 10 | 11 | for l in range(cfg.n_layers): 12 | state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight 13 | state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias 14 | 15 | W_Q = phi.model.layers[l].self_attn.q_proj.weight 16 | W_K = phi.model.layers[l].self_attn.k_proj.weight 17 | W_V = phi.model.layers[l].self_attn.v_proj.weight 18 | W_Q = einops.rearrange( 19 | W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 20 | ) 21 | W_K = einops.rearrange( 22 | W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 23 | ) 24 | W_V = einops.rearrange( 25 | W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 26 | ) 27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 28 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 29 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 30 | 31 | b_Q = phi.model.layers[l].self_attn.q_proj.bias 32 | b_K = phi.model.layers[l].self_attn.k_proj.bias 33 | b_V = phi.model.layers[l].self_attn.v_proj.bias 34 | b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 35 | b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 36 | b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 37 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 38 | state_dict[f"blocks.{l}.attn.b_K"] = b_K 39 | state_dict[f"blocks.{l}.attn.b_V"] = b_V 40 | 41 | W_O = phi.model.layers[l].self_attn.dense.weight 42 | W_O = einops.rearrange( 43 | W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads 44 | ) 45 | 46 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 47 | state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias 48 | 49 | # Layer Norm 1 and 2 are tied. 50 | state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"] 51 | state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"] 52 | 53 | state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T 54 | state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias 55 | state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T 56 | state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias 57 | 58 | state_dict["ln_final.w"] = phi.model.final_layernorm.weight 59 | state_dict["ln_final.b"] = phi.model.final_layernorm.bias 60 | 61 | state_dict["unembed.W_U"] = phi.lm_head.weight.T 62 | state_dict["unembed.b_U"] = phi.lm_head.bias 63 | 64 | return state_dict 65 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/phi3.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import einops 4 | import torch 5 | 6 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 7 | 8 | 9 | def convert_phi3_weights(phi, cfg: HookedTransformerConfig): 10 | state_dict = {} 11 | state_dict["embed.W_E"] = phi.model.embed_tokens.weight 12 | 13 | # Some models with this architecture use Grouped Query Attention, and so for these we need to modify 14 | # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names. 
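# e.g. with GQA the K/V entries are stored as "blocks.0.attn._W_K" / "blocks.0.attn._W_V"
# (and "_b_K" / "_b_V"), matching the parameter names of the GroupedQueryAttention component.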
15 | using_gqa = cfg.n_key_value_heads is not None 16 | gqa_uscore = "_" if using_gqa else "" 17 | # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None 18 | n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) 19 | 20 | for l in range(cfg.n_layers): 21 | state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight 22 | state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 23 | 24 | W = phi.model.layers[l].self_attn.qkv_proj.weight 25 | q_dim = cfg.n_heads * cfg.d_head 26 | kv_dim = n_kv_heads * cfg.d_head 27 | W_Q, W_K, W_V = W.split([q_dim, kv_dim, kv_dim], dim=0) 28 | 29 | W_Q = einops.rearrange( 30 | W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 31 | ) 32 | W_K = einops.rearrange( 33 | W_K, "(n_kv_head d_head) d_model -> n_kv_head d_model d_head", n_kv_head=n_kv_heads 34 | ) 35 | W_V = einops.rearrange( 36 | W_V, "(n_kv_head d_head) d_model -> n_kv_head d_model d_head", n_kv_head=n_kv_heads 37 | ) 38 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 39 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K 40 | state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V 41 | 42 | state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( 43 | cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device 44 | ) 45 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( 46 | n_kv_heads, 47 | cfg.d_head, 48 | dtype=cfg.dtype, 49 | ) 50 | state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( 51 | n_kv_heads, 52 | cfg.d_head, 53 | dtype=cfg.dtype, 54 | ) 55 | 56 | W_O = phi.model.layers[l].self_attn.o_proj.weight 57 | W_O = einops.rearrange( 58 | W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads 59 | ) 60 | 61 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 62 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 63 | 64 | state_dict[f"blocks.{l}.ln2.w"] = phi.model.layers[l].post_attention_layernorm.weight 65 | state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 66 | 67 | W = phi.model.layers[l].mlp.gate_up_proj.weight.T 68 | W_gate, W_in = torch.tensor_split(W, 2, dim=1) 69 | state_dict[f"blocks.{l}.mlp.W_in"] = W_in 70 | state_dict[f"blocks.{l}.mlp.W_gate"] = W_gate 71 | state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.down_proj.weight.T 72 | 73 | state_dict["ln_final.w"] = phi.model.norm.weight 74 | 75 | state_dict["unembed.W_U"] = phi.lm_head.weight.T 76 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 77 | 78 | return state_dict 79 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/qwen.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): 8 | state_dict = {} 9 | model = qwen.transformer 10 | state_dict["embed.W_E"] = model.wte.weight 11 | 12 | assert cfg.d_mlp is not None # keep mypy happy 13 | 14 | for l in range(cfg.n_layers): 15 | state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight 16 | 17 | W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0) 18 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 19 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) 20 | W_V = einops.rearrange(W_V, 
"(n h) m->n m h", n=cfg.n_heads) 21 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 22 | state_dict[f"blocks.{l}.attn.W_K"] = W_K 23 | state_dict[f"blocks.{l}.attn.W_V"] = W_V 24 | 25 | b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0) 26 | b_Q = einops.rearrange( 27 | b_Q, 28 | "(n_head d_head) -> n_head d_head", 29 | n_head=cfg.n_heads, 30 | ) 31 | b_K = einops.rearrange( 32 | b_K, 33 | "(n_head d_head) -> n_head d_head", 34 | n_head=cfg.n_heads, 35 | ) 36 | b_V = einops.rearrange( 37 | b_V, 38 | "(n_head d_head) -> n_head d_head", 39 | n_head=cfg.n_heads, 40 | ) 41 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 42 | state_dict[f"blocks.{l}.attn.b_K"] = b_K 43 | state_dict[f"blocks.{l}.attn.b_V"] = b_V 44 | 45 | W_O = model.h[l].attn.c_proj.weight 46 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 47 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 48 | 49 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 50 | 51 | state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight 52 | 53 | state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T 54 | state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T 55 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 56 | 57 | state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T 58 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 59 | 60 | state_dict["ln_final.w"] = model.ln_f.weight 61 | 62 | state_dict["unembed.W_U"] = qwen.lm_head.weight.T 63 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 64 | 65 | return state_dict 66 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/qwen2.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | 6 | 7 | def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): 8 | # Note that this method is also applied for Qwen1.5 models, since they 9 | # have architecture type Qwen2ForCausalLM. 
10 | 11 | state_dict = {} 12 | 13 | state_dict["embed.W_E"] = qwen.model.embed_tokens.weight 14 | 15 | assert cfg.d_mlp is not None # keep mypy happy 16 | 17 | for l in range(cfg.n_layers): 18 | state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight 19 | 20 | W_Q = qwen.model.layers[l].self_attn.q_proj.weight 21 | W_K = qwen.model.layers[l].self_attn.k_proj.weight 22 | W_V = qwen.model.layers[l].self_attn.v_proj.weight 23 | W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 24 | W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 25 | W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 26 | 27 | state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 28 | state_dict[f"blocks.{l}.attn._W_K"] = W_K 29 | state_dict[f"blocks.{l}.attn._W_V"] = W_V 30 | 31 | b_Q = qwen.model.layers[l].self_attn.q_proj.bias 32 | b_Q = einops.rearrange( 33 | b_Q, 34 | "(n_head d_head) -> n_head d_head", 35 | n_head=cfg.n_heads, 36 | ) 37 | 38 | b_K = qwen.model.layers[l].self_attn.k_proj.bias 39 | b_K = einops.rearrange( 40 | b_K, 41 | "(n_head d_head) -> n_head d_head", 42 | n_head=cfg.n_key_value_heads, 43 | ) 44 | 45 | b_V = qwen.model.layers[l].self_attn.v_proj.bias 46 | b_V = einops.rearrange( 47 | b_V, 48 | "(n_head d_head) -> n_head d_head", 49 | n_head=cfg.n_key_value_heads, 50 | ) 51 | 52 | state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 53 | state_dict[f"blocks.{l}.attn._b_K"] = b_K 54 | state_dict[f"blocks.{l}.attn._b_V"] = b_V 55 | 56 | W_O = qwen.model.layers[l].self_attn.o_proj.weight 57 | W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 58 | state_dict[f"blocks.{l}.attn.W_O"] = W_O 59 | 60 | state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 61 | 62 | state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight 63 | 64 | state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T 65 | state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T 66 | state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 67 | 68 | state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T 69 | state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 70 | 71 | state_dict["ln_final.w"] = qwen.model.norm.weight 72 | 73 | state_dict["unembed.W_U"] = qwen.lm_head.weight.T 74 | state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 75 | 76 | return state_dict 77 | -------------------------------------------------------------------------------- /transformer_lens/pretrained/weight_conversions/t5.py: -------------------------------------------------------------------------------- 1 | import einops 2 | 3 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 4 | 5 | 6 | def convert_t5_weights(t5, cfg: HookedTransformerConfig): 7 | state_dict = { 8 | "embed.W_E": t5.encoder.embed_tokens.weight, 9 | "unembed.W_U": t5.encoder.embed_tokens.weight.T, 10 | "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0] 11 | .layer[0] 12 | .SelfAttention.relative_attention_bias.weight, 13 | } 14 | 15 | for l in range(cfg.n_layers): 16 | block = t5.encoder.block[l] 17 | state_dict[f"encoder.{l}.attn.W_Q"] = einops.rearrange( 18 | block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 19 | ) 20 | state_dict[f"encoder.{l}.attn.W_K"] = einops.rearrange( 21 | block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 22 | ) 23 | 24 | 
state_dict[f"encoder.{l}.attn.W_V"] = einops.rearrange( 25 | block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 26 | ) 27 | 28 | state_dict[f"encoder.{l}.attn.W_O"] = einops.rearrange( 29 | block.layer[0].SelfAttention.o.weight, 30 | "m (i h) -> i h m", 31 | i=cfg.n_heads, 32 | ) 33 | state_dict[f"encoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight 34 | 35 | # fixme DenseReluDense may be T5DenseGatedActDense instead 36 | state_dict[f"encoder.{l}.mlp.W_in"] = einops.rearrange( 37 | block.layer[1].DenseReluDense.wi.weight, "mlp model -> model mlp" 38 | ) 39 | 40 | state_dict[f"encoder.{l}.mlp.W_out"] = einops.rearrange( 41 | block.layer[1].DenseReluDense.wo.weight, "model mlp -> mlp model" 42 | ) 43 | state_dict[f"encoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight 44 | 45 | state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight 46 | 47 | state_dict["decoder.0.attn.rel_pos_bias.weight"] = ( 48 | t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight 49 | ) 50 | 51 | for l in range(cfg.n_layers): 52 | block = t5.decoder.block[l] 53 | state_dict[f"decoder.{l}.attn.W_Q"] = einops.rearrange( 54 | block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 55 | ) 56 | 57 | state_dict[f"decoder.{l}.attn.W_K"] = einops.rearrange( 58 | block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 59 | ) 60 | state_dict[f"decoder.{l}.attn.W_V"] = einops.rearrange( 61 | block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 62 | ) 63 | 64 | state_dict[f"decoder.{l}.attn.W_O"] = einops.rearrange( 65 | block.layer[0].SelfAttention.o.weight, 66 | "m (i h) -> i h m", 67 | i=cfg.n_heads, 68 | ) 69 | 70 | state_dict[f"decoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight 71 | 72 | state_dict[f"decoder.{l}.cross_attn.W_Q"] = einops.rearrange( 73 | block.layer[1].EncDecAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 74 | ) 75 | 76 | state_dict[f"decoder.{l}.cross_attn.W_K"] = einops.rearrange( 77 | block.layer[1].EncDecAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 78 | ) 79 | 80 | state_dict[f"decoder.{l}.cross_attn.W_V"] = einops.rearrange( 81 | block.layer[1].EncDecAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 82 | ) 83 | state_dict[f"decoder.{l}.cross_attn.W_O"] = einops.rearrange( 84 | block.layer[1].EncDecAttention.o.weight, 85 | "m (i h) -> i h m", 86 | i=cfg.n_heads, 87 | ) 88 | state_dict[f"decoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight 89 | 90 | # fixme DenseReluDense may be T5DenseGatedActDense instead 91 | state_dict[f"decoder.{l}.mlp.W_in"] = einops.rearrange( 92 | block.layer[2].DenseReluDense.wi.weight, "mlp model -> model mlp" 93 | ) 94 | state_dict[f"decoder.{l}.mlp.W_out"] = einops.rearrange( 95 | block.layer[2].DenseReluDense.wo.weight, "model mlp -> mlp model" 96 | ) 97 | state_dict[f"decoder.{l}.ln3.w"] = block.layer[2].layer_norm.weight 98 | 99 | state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight 100 | 101 | return state_dict 102 | -------------------------------------------------------------------------------- /transformer_lens/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/TransformerLens/b5a16f849649a237cc02cc2c272ae4dc2085abe4/transformer_lens/utilities/__init__.py -------------------------------------------------------------------------------- /transformer_lens/utilities/activation_functions.py: 
-------------------------------------------------------------------------------- 1 | """Activation Functions. 2 | 3 | Utilities for interacting with all supported activation functions. 4 | """ 5 | from typing import Callable, Dict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from transformer_lens.utils import gelu_fast, gelu_new, solu 11 | 12 | # Convenient type for the format of each activation function 13 | ActivationFunction = Callable[..., torch.Tensor] 14 | 15 | # All currently supported activation functions. To add a new function, simply 16 | # put the name of the function as the key, and the value as the actual callable. 17 | SUPPORTED_ACTIVATIONS: Dict[str, ActivationFunction] = { 18 | "solu": solu, 19 | "solu_ln": solu, 20 | "gelu_new": gelu_new, 21 | "gelu_fast": gelu_fast, 22 | "silu": F.silu, 23 | "relu": F.relu, 24 | "gelu": F.gelu, 25 | "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"), 26 | } 27 | -------------------------------------------------------------------------------- /transformer_lens/utilities/addmm.py: -------------------------------------------------------------------------------- 1 | """Addmm 2 | 3 | Implementations of Addmm functions matching Huggingface implementations. 4 | """ 5 | import torch 6 | from jaxtyping import Float 7 | 8 | 9 | def vanilla_addmm( 10 | input: Float[torch.Tensor, "... #o"], # Must be broadcastable to "m o" 11 | mat1: Float[torch.Tensor, "m n"], 12 | mat2: Float[torch.Tensor, "n o"], 13 | ) -> Float[torch.Tensor, "m o"]: 14 | """Typechecked version of torch.addmm. 15 | 16 | Note that both mat1 and mat2 *must* be 2d matrices. 17 | """ 18 | return torch.addmm(input, mat1, mat2) 19 | 20 | 21 | def batch_addmm( 22 | bias: Float[torch.Tensor, "... #d_out"], # Must be broadcastable to "... d_out" 23 | weight: Float[torch.Tensor, "d_in d_out"], 24 | x: Float[torch.Tensor, "... d_in"], 25 | ) -> Float[torch.Tensor, "... d_out"]: 26 | """Fused add-multiply with support for batch dimensions. 27 | 28 | Must match the Huggingface Conv1D implementation exactly. 29 | https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/pytorch_utils.py#L102-L106 30 | """ 31 | n_output_features = weight.shape[-1] 32 | size_out = x.size()[:-1] + (n_output_features,) 33 | x = vanilla_addmm(bias, x.view(-1, x.size(-1)), weight) 34 | x = x.view(size_out) 35 | return x 36 | -------------------------------------------------------------------------------- /transformer_lens/utilities/attention.py: -------------------------------------------------------------------------------- 1 | """Attention. 2 | 3 | Utilities for attention components. 
4 | """ 5 | 6 | import einops 7 | import torch 8 | import torch.nn.functional as F 9 | from jaxtyping import Float 10 | 11 | 12 | def simple_attn_linear( 13 | input: Float[torch.Tensor, "batch pos d_model"], 14 | w: Float[torch.Tensor, "head_index d_model d_head"], 15 | b: Float[torch.Tensor, "head_index d_head"], 16 | ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 17 | """Linear layer for attention calculation.""" 18 | 19 | if input.device != w.device: 20 | w = w.to(input.device) 21 | if input.device != b.device: 22 | b = b.to(input.device) 23 | 24 | w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model") 25 | b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)") 26 | 27 | return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1]) 28 | 29 | 30 | def complex_attn_linear( 31 | input: Float[torch.Tensor, "batch pos head_index d_model"], 32 | w: Float[torch.Tensor, "head_index d_model d_head"], 33 | b: Float[torch.Tensor, "head_index d_head"], 34 | ) -> Float[torch.Tensor, "batch pos head_index d_head"]: 35 | """Linear layer for attention calculation. 36 | 37 | This is almost the same as simple_attn_linear, but the input tensor has an extra head_index dimension, used when calculating the input of each attention head separately. 38 | """ 39 | 40 | # Add singleton dimensions for broadcasting 41 | input = einops.rearrange( 42 | input, "batch pos head_index d_model -> batch pos head_index d_model 1" 43 | ) 44 | w = einops.rearrange(w, "head_index d_model d_head -> 1 1 head_index d_model d_head") 45 | 46 | # Element-wise multiplication and sum over the d_model dimension 47 | result = input * w 48 | result = result.sum(dim=-2) 49 | return result + b 50 | --------------------------------------------------------------------------------