├── .gitignore ├── .readthedocs.yaml ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── docs ├── Makefile ├── images │ ├── logo.png │ ├── logo.svg │ └── notebook_images │ │ ├── inspectgpt2_card.png │ │ └── trace_info_flow_card.png ├── make.bat └── source │ ├── conf.py │ ├── index.rst │ ├── tutorials │ ├── common_hooks.rst │ ├── getting_started.rst │ ├── hooking.rst │ ├── index.rst │ ├── installation.rst │ └── interfaces.rst │ ├── unseal.hooks.rst │ ├── unseal.interface.rst │ ├── unseal.logit_lense.rst │ └── unseal.transformers_util.rst ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── hooks │ ├── test_common_hooks.py │ ├── test_commons.py │ ├── test_rome_hooks.py │ └── test_util.py └── test_transformer_util.py └── unseal ├── __init__.py ├── circuits └── utils.py ├── hooks ├── __init__.py ├── common_hooks.py ├── commons.py ├── rome_hooks.py └── util.py ├── logit_lense.py ├── transformers_util.py └── visuals ├── __init__.py ├── streamlit_interfaces ├── __init__.py ├── all_layers_single_input.py ├── commons.py ├── compare_two_inputs.py ├── interface_setup.py ├── load_only.py ├── registered_models.json ├── single_layer_single_input.py ├── split_full_model_vis_into_layers.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | testing.py 133 | 134 | *visualizations/ 135 | *gpt*json 136 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Required for RTD to use 3.9 9 | build: 10 | image: testing 11 | 12 | # Set python version 13 | python: 14 | version: 3.9 15 | install: 16 | - method: setuptools 17 | path: . 18 | 19 | # Build documentation in the docs/ directory with Sphinx 20 | sphinx: 21 | configuration: docs/source/conf.py -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "restructuredtext.confPath": "${workspaceFolder}/docs/source", 3 | "esbonio.server.enabled": false, 4 | "esbonio.sphinx.confDir": "${workspaceFolder}/docs/source" 5 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tom Lieberum 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unseal - Mechanistic Interpretability for Transformers 2 | 3 | 4 | 5 | 6 | ## Prerequisites 7 | 8 | Unseal requires python 3.6+. 9 | 10 | 11 | ## Installation 12 | 13 | For its visualizations interfaces, Unseal uses [this fork](https://github.com/TomFrederik/pysvelte) of the PySvelte library, which can be installed via pip: 14 | 15 | ```sh 16 | git clone git@github.com:TomFrederik/PySvelte.git 17 | cd PySvelte 18 | pip install -e . 19 | ``` 20 | 21 | In order to run PySvelte, you will also need to install ``npm`` via your package manager. 22 | The hooking functionality of Unseal should still work without PySvelte, but we can't give any guarantees 23 | 24 | Install Unseal via pip 25 | 26 | ```sh 27 | pip install unseal 28 | ``` 29 | 30 | ## Usage 31 | 32 | We refer to our documentation for tutorials and usage guides: 33 | 34 | [Documentation](https://unseal.readthedocs.io/en/latest/) 35 | 36 | 37 | ## Notebooks 38 | 39 | Here are some notebooks that also showcase Unseal's functionalities. 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/logo.png -------------------------------------------------------------------------------- /docs/images/notebook_images/inspectgpt2_card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/notebook_images/inspectgpt2_card.png -------------------------------------------------------------------------------- /docs/images/notebook_images/trace_info_flow_card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/notebook_images/trace_info_flow_card.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | from unseal import __version__ 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'Unseal' 21 | copyright = '2022, Tom Lieberum' 22 | author = 'Tom Lieberum' 23 | 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = __version__ 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.coverage', 35 | 'sphinx.ext.autodoc', 36 | 'sphinx.ext.napoleon', 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = [] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes. 52 | # 53 | html_theme = "sphinx_rtd_theme" 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | html_static_path = ['_static'] 59 | 60 | # -- Extension configuration ------------------------------------------------- 61 | html_theme_options = { 62 | 'canonical_url': 'docs/', 63 | 'analytics_id': 'UA-136588502-1', # Provided by Google in your dashboard 64 | # 'logo_only': False, 65 | # 'display_version': True, 66 | # 'prev_next_buttons_location': 'bottom', 67 | # 'style_external_links': False, 68 | # 'vcs_pageview_mode': '', 69 | # 'style_nav_header_background': 'white', 70 | # Toc options 71 | 'collapse_navigation': False, 72 | 'sticky_navigation': True, 73 | 'navigation_depth': 4, 74 | # 'includehidden': True, 75 | # 'titles_only': False 76 | } -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Unseal documentation master file, created by 2 | sphinx-quickstart on Wed Feb 9 14:22:48 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | .. _main: 6 | 7 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/logo.png?raw=true 8 | :width: 400px 9 | 10 | Welcome to Unseal's documentation! 11 | ================================== 12 | 13 | Unseal wants to help you get started in mechanistic interpretability research on transformers! 14 | 15 | Install Unseal by following the instructions on the :ref:`installation page `. 16 | 17 | If you are new to Unseal, you should start by reading the :ref:`getting started ` guide. 18 | 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | :caption: Tutorials and Guides 23 | 24 | tutorials/installation 25 | tutorials/getting_started 26 | tutorials/hooking 27 | tutorials/common_hooks 28 | tutorials/interfaces 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :caption: API Reference 33 | 34 | unseal.hooks 35 | unseal.interface 36 | unseal.transformers_util 37 | unseal.logit_lense 38 | 39 | 40 | 41 | Indices and tables 42 | ================== 43 | 44 | * :ref:`genindex` 45 | * :ref:`modindex` 46 | * :ref:`search` 47 | -------------------------------------------------------------------------------- /docs/source/tutorials/common_hooks.rst: -------------------------------------------------------------------------------- 1 | .. _common_hooks: 2 | 3 | ============================ 4 | Common Hooks 5 | ============================ 6 | 7 | The most common hooking functions are supported out of the box by Unseal. 8 | 9 | Some of the methods can be used directly as a function in a hook, others return such a function and some will return the hook itself. 10 | This will be indicated in the docstring. 11 | 12 | Saving Outputs 13 | ============== 14 | 15 | This method can be used directly in the construction of a hook. 16 | 17 | .. automethod:: unseal.hooks.common_hooks.save_output 18 | 19 | 20 | Replacing Activations 21 | ===================== 22 | 23 | This method is a factory and returns a function that can be used in a hook to replace the activation of a layer. 24 | 25 | .. automethod:: unseal.hooks.common_hooks.replace_activation 26 | 27 | 28 | Saving Attention 29 | ===================== 30 | 31 | .. automethod:: unseal.hooks.common_hooks.transformers_get_attention 32 | 33 | 34 | Creating an Attention Hook 35 | =========================== 36 | 37 | .. automethod:: unseal.hooks.common_hooks.create_attention_hook 38 | 39 | 40 | Creating a Logit Hook 41 | =========================== 42 | 43 | .. automethod:: unseal.hooks.common_hooks.create_logit_hook 44 | 45 | 46 | GPT ``_attn`` Wrapper 47 | ===================== 48 | 49 | .. automethod:: unseal.hooks.common_hooks.gpt_attn_wrapper 50 | 51 | 52 | -------------------------------------------------------------------------------- /docs/source/tutorials/getting_started.rst: -------------------------------------------------------------------------------- 1 | .. _getting_started: 2 | 3 | Getting started 4 | =============== 5 | 6 | If you just want to play around with models, then head to the :ref:`interface ` section. 7 | 8 | If you want to learn more about how Unseal works under the hood, check out our section on :ref:`hooking `. -------------------------------------------------------------------------------- /docs/source/tutorials/hooking.rst: -------------------------------------------------------------------------------- 1 | .. _hooking: 2 | 3 | 4 | =============== 5 | Hooks 6 | =============== 7 | 8 | Hooks are at the heart of Unseal. In short, a hook is an access point to a model. It is defined by the point of the model at 9 | which it attaches and by the operation that it executes (usually either during the forward or backward pass). 10 | 11 | To read more about the original concept of a hook in PyTorch read `here `_. 12 | 13 | In Unseal, a hook is an object consisting of a ``layer_name`` (the point at which it attaches), 14 | a ``func`` (the function it executes), and a ``key`` (an identifying string unique to the hook) 15 | 16 | In order to simplify the hooking interface, Unseal wraps every model in the ``hooks.HookedModel`` class. 17 | 18 | 19 | 20 | hooks.HookedModel 21 | ======================= 22 | 23 | You can access the top-level structure of a so-wrapped model by printing it (i.e. its ``__repr__`` property): 24 | 25 | .. code-block:: python 26 | 27 | import torch.nn as nn 28 | from unseal.hooks import HookedModel 29 | 30 | model = nn.Sequential( 31 | nn.Linear(8,64), 32 | nn.ReLU(), 33 | nn.Sequential( 34 | nn.Linear(64,256), 35 | nn.ReLU(), 36 | nn.Linear(256,1) 37 | ) 38 | ) 39 | model = HookedModel(model) 40 | 41 | print(model) 42 | 43 | # equivalent: 44 | # print(model.model) 45 | 46 | ''' Output: 47 | Sequential( 48 | (0): Linear(in_features=8, out_features=64, bias=True) 49 | (1): ReLU() 50 | (2): Sequential( 51 | (0): Linear(in_features=64, out_features=256, bias=True) 52 | (1): ReLU() 53 | (2): Linear(in_features=256, out_features=1, bias=True) 54 | ) 55 | ) 56 | ''' 57 | 58 | 59 | A HookedModel also has special references to every layer which you can access via the ``layers`` attribute: 60 | 61 | .. code-block:: python 62 | 63 | print(model.layers) 64 | '''Output: 65 | OrderedDict([('0', Linear(in_features=8, out_features=64, bias=True)), ('1', ReLU()), ('2', Sequential( 66 | (0): Linear(in_features=64, out_features=256, bias=True) 67 | (1): ReLU() 68 | (2): Linear(in_features=256, out_features=1, bias=True) 69 | )), ('2->0', Linear(in_features=64, out_features=256, bias=True)), ('2->1', ReLU()), ('2->2', Linear(in_features=256, out_features=1, bias=True))]) 70 | ''' 71 | 72 | 73 | You can see that each layer has its own identifying string (e.g. ``'2->2'``). If you want to only display the layer names you can simply call ``model.layers.keys()``. 74 | 75 | Hooked forward passes 76 | --------------------- 77 | 78 | The most important feature of a HookedModel object is its modified ``forward`` method which allows a user to temporarily add a hook to the model, perform a forward pass 79 | and record the result in the context attribute of the HookedModel. 80 | 81 | For this, the forward method takes an additional ``hooks`` argument which is a ``list`` of ``Hooks`` which get registered. After the forward pass, the hooks are removed 82 | again (to ensure consistent behavior). Hooks have access to the ``save_ctx`` attribute of the HookedModel, so anything you want to access later goes there and can 83 | be recalled via ``model.save_ctx[your_hook_key]``. Beware that the context attribute does not get reset automatically, so running a lot of 84 | different hooks can fill up your memory. 85 | 86 | 87 | Writing your own hooks 88 | ====================== 89 | 90 | As mentioned above, hooks are triples ``(layer_name, func, key)``. After choosing the attachment point (the ``layer_name``, an element from ``model.layers.keys()``), 91 | you need to implement the hooking function. 92 | 93 | Every hooking function needs to follow the signature ``save_ctx, input, output -> output``. 94 | 95 | ``save_ctx`` is a dictionary which is initialized empty by the HookedModule class 96 | during the forward pass. ``input`` and ``output`` are the input and output of the module respectively. If the hook is not modifying the output, the function does 97 | not need to return anything, as that is the default behavior. 98 | 99 | For example, let's implement a hook which saves the input and output of the first linear layer in the network we defined above: 100 | 101 | 102 | .. code-block:: python 103 | 104 | import torch 105 | from unseal import Hook 106 | 107 | # define the hooking function 108 | def save_input_output(save_ctx, input, output): 109 | # make sure to not clutter the gpu and not keep track of gradients. 110 | save_ctx['input'] = input.detach().cpu() 111 | save_ctx['output'] = output.detach().cpu() 112 | 113 | # create Hook object 114 | my_hook = Hook('0', func, 'save_input_output_0') 115 | 116 | # create random input tensor 117 | input_tensor = torch.rand((1,8)) 118 | 119 | # forward pass with our new hook 120 | model.forward(input_tensor, hooks=[my_hook]) 121 | 122 | # now we can access the model's context object 123 | print(model.save_ctx['save_input_output_0']['input']) 124 | print(model.save_ctx['save_input_output_0']['output']) 125 | 126 | '''Output: 127 | tensor([[0.5778, 0.0257, 0.4552, 0.4787, 0.9211, 0.0284, 0.8347, 0.9621]]) 128 | tensor([[-0.6566, 1.0794, 0.1455, -0.0396, 0.0411, 0.2184, -0.3484, -0.1095, 129 | -0.2990, -0.1757, 0.1078, 0.2126, 0.4414, 0.1682, -0.2449, 0.0090, 130 | -0.0726, -0.0325, -0.5832, 0.1020, -0.2699, 0.0223, -0.8340, -0.4016, 131 | -0.2808, -0.5337, 0.1518, 1.1230, 1.1380, -0.1437, 0.2738, 0.4592, 132 | -0.7136, -0.3247, 0.2068, -0.5012, 0.4446, -0.4551, 0.2015, -0.3641, 133 | -0.1598, -0.7272, 0.0271, 0.2181, -0.3253, 0.2763, -0.5745, 0.4344, 134 | 0.0255, -0.2492, 0.1586, 0.2404, -0.2033, -0.6197, -0.1098, 0.3736, 135 | 0.1246, -0.4697, -0.7690, 0.0981, -0.0255, 0.2133, 0.3061, 0.1846]]) 136 | ''' 137 | 138 | To make things easier for you, Unseal comes with a few pre-implemented hooking functions, which 139 | we will explain in the next section. -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. _tutorials: 2 | 3 | Tutorials and Guides 4 | ==================== 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :caption: Tutorials and Guides 9 | 10 | tutorials/getting_started 11 | tutorials/hooking 12 | -------------------------------------------------------------------------------- /docs/source/tutorials/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installing_unseal: 2 | 3 | ===================== 4 | Installing Unseal 5 | ===================== 6 | 7 | Prerequisites 8 | ------------- 9 | 10 | Unseal requires python 3.6+. 11 | 12 | 13 | Installation 14 | ------------ 15 | 16 | For its interfaces, Unseal uses `this fork `_ of the PySvelte library, which can be installed via pip: 17 | 18 | .. code-block:: console 19 | 20 | git clone git@github.com:TomFrederik/PySvelte.git 21 | cd PySvelte 22 | pip install -e . 23 | 24 | In order to run PySvelte, you will also need to install ``npm`` via your package manager. 25 | 26 | Install Unseal via pip: 27 | 28 | .. code-block:: console 29 | 30 | pip install unseal -------------------------------------------------------------------------------- /docs/source/tutorials/interfaces.rst: -------------------------------------------------------------------------------- 1 | .. _interfaces: 2 | 3 | ==================== 4 | Interfaces in Unseal 5 | ==================== 6 | 7 | .. contents:: Contents 8 | 9 | Unseal wants to provide simple and intuitive interfaces for exploring 10 | large language models. 11 | 12 | At its core it relies on a combination of Streamlit and PySvelte. 13 | 14 | Notebooks 15 | ========== 16 | 17 | Here we collect Google Colab Notebooks that demonstrate Unseal functionalities. 18 | 19 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/notebook_images/inspectgpt2_card.png?raw=true 20 | :target: https://colab.research.google.com/drive/1Y1y2GnDT-Uzvyp8pUWWXt8lEfHWxje3b?usp=sharing 21 | 22 | 23 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/notebook_images/trace_info_flow_card.png?raw=true 24 | :target: https://colab.research.google.com/drive/1ljCPvbr7VPEIlbZQvrUceLSDsdeo3oRH?usp=sharing 25 | 26 | 27 | 28 | Streamlit interfaces 29 | ==================== 30 | 31 | Unseal comes with several native interfaces that are ready to use out of the box. 32 | 33 | All the pre-built interfaces are available in the ``unseal.interface.streamlit_interfaces`` package. 34 | 35 | To run any of the interfaces, you can navigate to the ``streamlit_interfaces`` directory and run 36 | 37 | .. code-block:: bash 38 | 39 | streamlit run 40 | 41 | -------------------------------------------------------------------------------- /docs/source/unseal.hooks.rst: -------------------------------------------------------------------------------- 1 | unseal.hooks package 2 | ===================== 3 | 4 | This package handles the nitty-gritty of hooking to a model. 5 | 6 | 7 | hooks module 8 | ------------------------------- 9 | 10 | .. automodule:: unseal.hooks 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | hooks.common_hooks module 17 | ------------------------------- 18 | 19 | .. automodule:: unseal.hooks.common_hooks 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | 25 | hooks.rome_hooks module 26 | ------------------------------- 27 | 28 | .. automodule:: unseal.hooks.rome_hooks 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | 34 | hooks.util module 35 | ------------------------------- 36 | 37 | .. automodule:: unseal.hooks.util 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | 42 | hooks.commons module 43 | ------------------------------- 44 | 45 | .. automodule:: unseal.interface.commons 46 | :members: 47 | :undoc-members: 48 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/unseal.interface.rst: -------------------------------------------------------------------------------- 1 | unseal.visuals package 2 | ======================== 3 | 4 | This package creates graphical outputs, e.g. for displaying in GUIs. 5 | 6 | 7 | visuals module 8 | ------------------------------- 9 | 10 | .. automodule:: unseal.visuals 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | visuals.commons module 17 | ------------------------------- 18 | 19 | .. automodule:: unseal.visuals.commons 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | 25 | visuals.utils module 26 | ------------------------------- 27 | 28 | .. automodule:: unseal.visuals.utils 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/unseal.logit_lense.rst: -------------------------------------------------------------------------------- 1 | logit_lense module 2 | ------------------------------- 3 | 4 | .. automodule:: unseal.logit_lense 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/unseal.transformers_util.rst: -------------------------------------------------------------------------------- 1 | transformers_util module 2 | ------------------------------- 3 | 4 | .. automodule:: unseal.transformers_util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file=README.md 3 | license_files=LICENSE -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os, sys 3 | sys.path.insert(0, os.path.abspath(".")) 4 | from unseal import __version__ 5 | 6 | version = __version__ 7 | 8 | with open("README.md", "r") as fh: 9 | long_description = fh.read() 10 | 11 | setup( 12 | name='Unseal', 13 | version=version, 14 | packages=find_packages(exclude=[]), 15 | python_requires='>=3.6.0', 16 | install_requires=[ 17 | 'torch', 18 | 'einops>=0.3.2', 19 | 'numpy', 20 | 'transformers', 21 | 'tqdm', 22 | 'matplotlib', 23 | 'streamlit', 24 | ], 25 | # entry_points={ 26 | # 'console_scripts': [ 27 | # '"unseal compare" = unseal.commands.interfaces.compare_two_inputs:main', 28 | # ] 29 | # }, 30 | description=( 31 | "Unseal: " 32 | "A collection of infrastructure and tools for research in " 33 | "transformer interpretability." 34 | ), 35 | long_description=long_description, 36 | long_description_content_type="text/markdown", 37 | author="The Unseal Team", 38 | author_email="tlieberum@outlook.de", 39 | url="https://github.com/TomFrederik/unseal/", 40 | license="Apache License 2.0", 41 | keywords=[ 42 | "pytorch", 43 | "tensor", 44 | "machine learning", 45 | "neural networks", 46 | "interpretability", 47 | "transformers", 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /tests/hooks/test_common_hooks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/tests/hooks/test_common_hooks.py -------------------------------------------------------------------------------- /tests/hooks/test_commons.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Callable 3 | 4 | import pytest 5 | import torch 6 | from unseal.hooks import commons 7 | 8 | class TestHookedModel(): 9 | def test_constructor(self): 10 | model = commons.HookedModel(torch.nn.Module()) 11 | assert model is not None 12 | assert isinstance(model.model, torch.nn.Module) 13 | assert isinstance(model.save_ctx, dict) 14 | assert isinstance(model.layers, OrderedDict) 15 | 16 | 17 | with pytest.raises(TypeError): 18 | model = commons.HookedModel('not a module') 19 | 20 | def test_init_refs(self): 21 | model = commons.HookedModel( 22 | torch.nn.Sequential( 23 | torch.nn.Sequential( 24 | torch.nn.Linear(10,10) 25 | ), 26 | torch.nn.Linear(10,10) 27 | ) 28 | ) 29 | assert list(model.layers.keys()) == ['0', '0->0', '1'] 30 | 31 | def test__hook_wrapper(self): 32 | model = commons.HookedModel(torch.nn.Module()) 33 | def test_func(save_ctx, inp, output): 34 | return 1 35 | 36 | func = model._hook_wrapper(test_func, 'key') 37 | assert isinstance(func, Callable) 38 | 39 | def test_get_ctx_keys(self): 40 | model = commons.HookedModel(torch.nn.Module()) 41 | assert model.get_ctx_keys() == [] 42 | 43 | model.save_ctx['key'] = dict() 44 | assert model.get_ctx_keys() == ['key'] 45 | 46 | def test_repr(self): 47 | model = commons.HookedModel(torch.nn.Module()) 48 | assert model.__repr__() == model.model.__repr__() 49 | 50 | def test_device(self): 51 | model = commons.HookedModel(torch.nn.Linear(10, 10)) 52 | assert model.device.type == 'cpu' 53 | 54 | if torch.cuda.is_available(): 55 | model.to('cuda') 56 | assert model.device.type == 'cuda' 57 | 58 | def test_forward(self): 59 | model = commons.HookedModel(torch.nn.Sequential(torch.nn.Linear(10, 10))) 60 | def fn(save_ctx, input, output): 61 | save_ctx['key'] = 1 62 | hook = commons.Hook('0', fn, 'key') 63 | model.forward(torch.rand(10), [hook]) 64 | assert model.save_ctx['key']['key'] == 1 65 | 66 | def test___call__(self): 67 | model = commons.HookedModel(torch.nn.Sequential(torch.nn.Linear(10, 10))) 68 | def fn(save_ctx, input, output): 69 | save_ctx['key'] = 2 70 | hook = commons.Hook('0', fn, 'key') 71 | model(torch.rand(10), [hook]) 72 | assert model.save_ctx['key']['key'] == 2 73 | 74 | 75 | def test_hook(): 76 | def correct_func(save_ctx, input, output): 77 | save_ctx['key'] = 1 78 | hook = commons.Hook('0', correct_func, 'key') 79 | assert hook.layer_name == '0' 80 | assert hook.func == correct_func 81 | assert hook.key == 'key' 82 | 83 | def false_func(save_ctx): 84 | save_ctx['key'] = 1 85 | with pytest.raises(TypeError): 86 | hook = commons.Hook('0', false_func, 'key') -------------------------------------------------------------------------------- /tests/hooks/test_rome_hooks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/tests/hooks/test_rome_hooks.py -------------------------------------------------------------------------------- /tests/hooks/test_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Iterable 3 | 4 | from unseal.hooks import util 5 | import numpy as np 6 | import pytest 7 | import torch 8 | def test_create_slice_from_str(): 9 | arr = np.random.rand(10,10,10) 10 | valid_idx_strings = [ 11 | '...,3:5,:', 12 | '4:5,:,:', 13 | '0,0,0', 14 | '-1,-1,-1' 15 | ] 16 | valid_result_arrs = [ 17 | arr[...,3:5,:], 18 | arr[4:5,:,:], 19 | arr[0,0,0], 20 | arr[-1,-1,-1], 21 | ] 22 | for string, subarr in zip(valid_idx_strings, valid_result_arrs): 23 | assert np.all(arr[util.create_slice_from_str(string)] == subarr), f"{util.create_slice_from_str(string)}, {subarr}" 24 | 25 | invalid_idx_strings = [ 26 | '', 27 | ] 28 | for idx in invalid_idx_strings: 29 | with pytest.raises(ValueError): 30 | subarry = arr[util.create_slice_from_str(idx)] 31 | 32 | def test_recursive_to_device(): 33 | if not torch.cuda.is_available(): 34 | logging.warning('CUDA not available, skipping recursive_to_device test.') 35 | return 36 | 37 | tensor = torch.rand(10,10,10) 38 | tensor = util.recursive_to_device(tensor, torch.device('cuda')) 39 | assert tensor.device.type == 'cuda' 40 | 41 | iterable = [tensor, [[tensor], tensor]] 42 | 43 | recursive_cuda_check(iterable) 44 | 45 | def recursive_cuda_check(iterable): 46 | if isinstance(iterable, torch.Tensor): 47 | assert iterable.device.type == 'cuda' 48 | elif isinstance(iterable, Iterable): 49 | for item in iterable: 50 | recursive_cuda_check(item) -------------------------------------------------------------------------------- /tests/test_transformer_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unseal.transformers_util as tutil 3 | from unseal.hooks import HookedModel 4 | 5 | def test_load_model(): 6 | model, tokenizer, config = tutil.load_from_pretrained('gpt2') 7 | assert model is not None 8 | assert tokenizer is not None 9 | assert config is not None 10 | 11 | def test_load_model_with_dir(): 12 | model, tokenizer, config = tutil.load_from_pretrained('gpt-neo-125M', model_dir='EleutherAI') 13 | assert model is not None 14 | assert tokenizer is not None 15 | assert config is not None 16 | 17 | def test_load_model_eleuther_without_dir(): 18 | model, tokenizer, config = tutil.load_from_pretrained('gpt-neo-125M') 19 | assert model is not None 20 | assert tokenizer is not None 21 | assert config is not None 22 | 23 | def test_load_model_with_low_mem(): 24 | model, tokenizer, config = tutil.load_from_pretrained('gpt2', low_cpu_mem_usage=True) 25 | assert model is not None 26 | assert tokenizer is not None 27 | assert config is not None 28 | 29 | def test_get_num_layers_gpt2(): 30 | model, *_ = tutil.load_from_pretrained('gpt2') 31 | model = HookedModel(model) 32 | assert tutil.get_num_layers(model, 'transformer->h') == 12 33 | 34 | def test_get_num_layers_transformer(): 35 | model = torch.nn.Transformer(d_model=10, nhead=2, num_encoder_layers=0, num_decoder_layers=10) 36 | model = HookedModel(model) 37 | assert tutil.get_num_layers(model, 'decoder->layers') 38 | -------------------------------------------------------------------------------- /unseal/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | __version__ = '0.2.1' -------------------------------------------------------------------------------- /unseal/circuits/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | from typing import Tuple, TypeVar 4 | 5 | import einops 6 | import torch 7 | from torch import Tensor 8 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention 9 | from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoSelfAttention 10 | from transformers.models.gptj.modeling_gptj import GPTJAttention 11 | 12 | Attention = TypeVar('Attention', GPT2Attention, GPTNeoSelfAttention, GPTJAttention) 13 | 14 | def get_qkv_weights(attention_module: Attention) -> Tuple[Tensor, Tensor, Tensor]: 15 | if isinstance(attention_module, GPT2Attention): 16 | q, k, v = attention_module.c_proj.weight.chunk(3) 17 | print(f"{q.shape = }") 18 | print(f"{k.shape = }") 19 | print(f"{v.shape = }") 20 | else: 21 | q, k, v = attention_module.q_proj.weight, attention_module.k_proj.weight, attention_module.v_proj.weight 22 | 23 | q = einops.rearrange(q, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim) 24 | k = einops.rearrange(k, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim) 25 | v = einops.rearrange(v, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim) 26 | 27 | return q, k, v 28 | 29 | def get_o_weight(attention_module: Attention) -> Tensor: 30 | if not isinstance(attention_module, Attention): 31 | raise TypeError(f"{attention_module} is not an instance of Attention, has type {type(attention_module)}") 32 | 33 | if isinstance(attention_module, GPT2Attention): 34 | return einops.rearrange(attention_module.c_proj.weight, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim) 35 | else: 36 | return einops.rearrange(attention_module.out_proj.weight, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim) 37 | 38 | def composition(a: Tensor, b: Tensor): 39 | return (a.T @ b).norm(p='fro') / (a.norm(p='fro') * b.norm(p='fro')) 40 | 41 | def q_composition(qk: Tensor, ov: Tensor): 42 | return composition(qk, ov) 43 | 44 | def k_composition(qk: Tensor, ov: Tensor): 45 | return composition(qk.T, ov) 46 | 47 | def v_composition(ov_2: Tensor, ov_1: Tensor): 48 | return composition(ov_2.T, ov_1) 49 | 50 | def get_init_limits(weight: Tensor) -> float: 51 | # compute xavier uniform initialization limits 52 | return (-math.sqrt(6/(weight.shape[0] + weight.shape[1])), math.sqrt(6/(weight.shape[0] + weight.shape[1]))) 53 | 54 | def approx_baseline(shape_1, shape_2, limits_1, limits_2, num_samples, device='cpu'): 55 | baseline = 0 56 | for i in range(num_samples): 57 | mat_1 = torch.distributions.Uniform(limits_1[0], limits_1[1]).rsample(shape_1).to(device) 58 | mat_2 = torch.distributions.Uniform(limits_2[0], limits_2[1]).rsample(shape_2).to(device) 59 | 60 | baseline += composition(mat_1, mat_2) 61 | 62 | return baseline / num_samples 63 | 64 | def compute_all_baselines(attention_module: Attention, num_samples): 65 | device = next(attention_module.parameters()).device 66 | 67 | q, k, v = get_qkv_weights(attention_module) 68 | o = get_o_weight(attention_module) 69 | # print(f"q: {q.shape}") 70 | # print(f"k: {k.shape}") 71 | # print(f"v: {v.shape}") 72 | # print(f"o: {o.shape}") 73 | qk_shape = (q.shape[1],) + (k.shape[1],) 74 | ov_shape = (o.shape[1],) + (v.shape[1],) 75 | # print(f"{qk_shape = }") 76 | # print(f"{ov_shape = }") 77 | qk_limits = get_init_limits(attention_module.qkv_proj.weight) 78 | ov_limits = get_init_limits(o) 79 | 80 | qk_baseline = approx_baseline(qk_shape, ov_shape, qk_limits, ov_limits, num_samples, device) 81 | v_baseline = approx_baseline(ov_shape, ov_shape, ov_limits, ov_limits, num_samples, device) 82 | 83 | return qk_baseline, v_baseline 84 | 85 | @torch.no_grad() 86 | def compute_all_compositions(attn_1: Attention, attn_2: Attention, num_samples: int = 1000, subtract_baseline: bool = False): 87 | if subtract_baseline: 88 | qk_baseline, v_baseline = compute_all_baselines(attn_2, num_samples) 89 | else: 90 | qk_baseline, v_baseline = 0, 0 91 | # print(f"{qk_baseline = }") 92 | # print(f"{v_baseline = }") 93 | 94 | q_1, k_1, v_1 = get_qkv_weights(attn_1) 95 | # print(f"v_1: {v_1.shape}") 96 | o_1 = get_o_weight(attn_1) 97 | # print(f"o_1: {o_1.shape}") 98 | ov_1 = torch.einsum('abc, acd -> abd', einops.rearrange(o_1, 'a c b -> a b c'), v_1) 99 | qk_1 = torch.einsum('abc, acd -> abd', einops.rearrange(q_1, 'a c b -> a b c'), k_1) 100 | # print(f"{qk_1.shape = }") 101 | # print(f"{ov_1.shape = }") 102 | q_2, k_2, v_2 = get_qkv_weights(attn_2) 103 | o_2 = get_o_weight(attn_2) 104 | ov_2 = torch.einsum('abc, acd -> abd', einops.rearrange(o_2, 'a c b -> a b c'), v_2) 105 | qk_2 = torch.einsum('abc, acd -> abd', einops.rearrange(q_2, 'a c b -> a b c'), k_2) 106 | # print(f"{qk_2.shape = }") 107 | # print(f"{ov_2.shape = }") 108 | 109 | q_comps = [] 110 | k_comps = [] 111 | v_comps = [] 112 | 113 | for head_1, head_2 in itertools.product(range(attn_1.num_heads), range(attn_2.num_heads)): 114 | q_comps.append(q_composition(qk_2[head_2], ov_1[head_1])) 115 | k_comps.append(k_composition(qk_2[head_2], ov_1[head_1])) 116 | v_comps.append(v_composition(ov_2[head_2], ov_1[head_1])) 117 | 118 | q_comps = torch.stack(q_comps) - qk_baseline 119 | k_comps = torch.stack(k_comps) - qk_baseline 120 | v_comps = torch.stack(v_comps) - v_baseline 121 | 122 | 123 | return q_comps.clamp(min=0).cpu(), k_comps.clamp(min=0).cpu(), v_comps.clamp(min=0).cpu() 124 | 125 | -------------------------------------------------------------------------------- /unseal/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from . import common_hooks 2 | from . import util 3 | from .commons import Hook, HookedModel -------------------------------------------------------------------------------- /unseal/hooks/common_hooks.py: -------------------------------------------------------------------------------- 1 | # pre-implemented common hooks 2 | import logging 3 | import math 4 | from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union 5 | 6 | import einops 7 | import torch 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | 11 | from . import util 12 | from .commons import Hook, HookedModel 13 | 14 | 15 | def save_output(cpu: bool = True, detach: bool = True) -> Callable: 16 | """Basic hooking function for saving the output of a module to the global context object 17 | 18 | :param cpu: Whether to save the output to cpu. 19 | :type cpu: bool 20 | :param detach: Whether to detach the output. 21 | :type detach: bool 22 | :return: Function that saves the output to the context object. 23 | :rtype: Callable 24 | """ 25 | def inner(save_ctx, input, output): 26 | if detach: 27 | if isinstance(output, torch.Tensor): 28 | output = output.detach() 29 | else: 30 | logging.warning('Detaching tensor for iterables not implemented') 31 | # not implemented 32 | if cpu: 33 | if isinstance(output, torch.Tensor): 34 | output = output.cpu() 35 | elif isinstance(output, Iterable): # hope for the best 36 | output = util.recursive_to_device(output, 'cpu') 37 | else: 38 | raise TypeError(f"Unsupported type for output {type(output)}") 39 | save_ctx['output'] = output 40 | 41 | return inner 42 | 43 | def replace_activation(indices: str, replacement_tensor: torch.Tensor, tuple_index: int = None) -> Callable: 44 | """Creates a hook which replaces a module's activation (output) with a replacement tensor. 45 | If there is a dimension mismatch, the replacement tensor is copied along the leading dimensions of the output. 46 | 47 | Example: If the activation has shape ``(B, T, D)`` and replacement tensor has shape ``(D,)`` which you want to plug in 48 | at position t in the T dimension for every tensor in the batch, then indices should be ``:,t,:``. 49 | 50 | :param indices: Indices at which to insert the replacement tensor 51 | :type indices: str 52 | :param replacement_tensor: Tensor that is filled in. 53 | :type replacement_tensor: torch.Tensor 54 | :param tuple_index: Index of the tuple in the output of the module. 55 | :type tuple_index: int 56 | :return: Function that replaces part of a given tensor with replacement_tensor 57 | :rtype: Callable 58 | """ 59 | slice_ = util.create_slice_from_str(indices) 60 | def func(save_ctx, input, output): 61 | if tuple_index is None: 62 | # add dummy dimensions if shape mismatch 63 | diff = len(output[slice_].shape) - len(replacement_tensor.shape) 64 | rep = replacement_tensor[(None,)*diff].to(input.device) 65 | # replace part of tensor 66 | output[slice_] = rep 67 | else: 68 | # add dummy dimensions if shape mismatch 69 | diff = len(output[tuple_index][slice_].shape) - len(replacement_tensor.shape) 70 | rep = replacement_tensor[(None,)*diff].to(input.device) 71 | # replace part of tensor 72 | output[tuple_index][slice_] = rep 73 | return output 74 | 75 | return func 76 | 77 | def transformers_get_attention( 78 | heads: Optional[Union[int, Iterable[int], str]] = None, 79 | output_idx: Optional[int] = None, 80 | ) -> Callable: 81 | """Creates a hooking function to get the attention patterns of a given layer. 82 | 83 | :param heads: The heads for which to save the attention, defaults to None 84 | :type heads: Optional[Union[int, Iterable[int], str]], optional 85 | :param output_idx: If the attention module returns a tuple, use this argument to index it, defaults to None 86 | :type output_idx: Optional[int], optional 87 | :return: func, hooking function that saves attention of the specified heads 88 | :rtype: Callable 89 | """ 90 | 91 | # convert head string to slice 92 | if heads is None: 93 | heads = ":" 94 | if isinstance(heads, str): 95 | heads = util.create_slice_from_str(heads) 96 | 97 | def func(save_ctx, input, output): 98 | if output_idx is None: 99 | save_ctx['attn'] = output[:,heads,...].detach().cpu() 100 | else: 101 | save_ctx['attn'] = output[output_idx][:,heads,...].detach().cpu() 102 | 103 | return func 104 | 105 | def create_attention_hook( 106 | layer: int, 107 | key: str, 108 | output_idx: Optional[int] = None, 109 | attn_name: Optional[str] = 'attn', 110 | layer_key_prefix: Optional[str] = None, 111 | heads: Optional[Union[int, Iterable[int], str]] = None 112 | ) -> Hook: 113 | """Creates a hook which saves the attention patterns of a given layer. 114 | 115 | :param layer: The layer to hook. 116 | :type layer: int 117 | :param key: The key to use for saving the attention patterns. 118 | :type key: str 119 | :param output_idx: If the module output is a tuple, index it with this. GPT like models need this to be equal to 2, defaults to None 120 | :type output_idx: Optional[int], optional 121 | :param attn_name: The name of the attention module in the transformer, defaults to 'attn' 122 | :type attn_name: Optional[str], optional 123 | :param layer_key_prefix: The prefix in the model structure before the layer idx, e.g. 'transformer->h', defaults to None 124 | :type layer_key_prefix: Optional[str], optional 125 | :param heads: Which heads to save the attention pattern for. Can be int, tuple of ints or string like '1:3', defaults to None 126 | :type heads: Optional[Union[int, Iterable[int], str]], optional 127 | :return: Hook which saves the attention patterns 128 | :rtype: Hook 129 | """ 130 | if layer_key_prefix is None: 131 | layer_key_prefix = "" 132 | 133 | func = transformers_get_attention(heads, output_idx) 134 | return Hook(f'{layer_key_prefix}{layer}->{attn_name}', func, key) 135 | 136 | 137 | def create_logit_hook( 138 | layer:int, 139 | model: HookedModel, 140 | unembedding_key: str, 141 | layer_key_prefix: Optional[str] = None, 142 | target: Optional[Union[int, List[int]]] = None, 143 | position: Optional[Union[int, List[int]]] = None, 144 | key: Optional[str] = None, 145 | split_heads: Optional[bool] = False, 146 | num_heads: Optional[int] = None, 147 | ) -> Hook: 148 | """Create a hook that saves the logits of a layer's output. 149 | Outputs are saved to save_ctx[key]['logits']. 150 | 151 | :param layer: The number of the layer 152 | :type layer: int 153 | :param model: The model. 154 | :type model: HookedModel 155 | :param unembedding_key: The key/name of the embedding matrix, e.g. 'lm_head' for causal LM models 156 | :type unembedding_key: str 157 | :param layer_key_prefix: The prefix of the key of the layer, e.g. 'transformer->h' for GPT like models 158 | :type layer_key_prefix: str 159 | :param target: The target token(s) to extract logits for. Defaults to all tokens. 160 | :type target: Union[int, List[int]] 161 | :param position: The position for which to extract logits for. Defaults to all positions. 162 | :type position: Union[int, List[int]] 163 | :param key: The key of the hook. Defaults to {layer}_logits. 164 | :type key: str 165 | :param split_heads: Whether to split the heads. Defaults to False. 166 | :type split_heads: bool 167 | :param num_heads: The number of heads to split. Defaults to None. 168 | :type num_heads: int 169 | :return: The hook. 170 | :rtype: Hook 171 | """ 172 | if layer_key_prefix is None: 173 | layer_key_prefix = "" 174 | 175 | # generate slice 176 | if target is None: 177 | target = ":" 178 | else: 179 | if isinstance(target, int): 180 | target = str(target) 181 | else: 182 | target = "[" + ",".join(str(t) for t in target) + "]" 183 | if position is None: 184 | position = ":" 185 | else: 186 | if isinstance(position, int): 187 | position = str(position) 188 | else: 189 | position = "[" + ",".join(str(p) for p in position) + "]" 190 | position_slice = util.create_slice_from_str(f":,{position},:") 191 | target_slice = util.create_slice_from_str(f"{target},:") 192 | 193 | # load the relevant part of the vocab matrix 194 | vocab_matrix = model.layers[unembedding_key].weight[target_slice].T 195 | 196 | # split vocab matrix 197 | if split_heads: 198 | if num_heads is None: 199 | raise ValueError("num_heads must be specified if split_heads is True") 200 | else: 201 | vocab_matrix = einops.rearrange(vocab_matrix, '(num_heads head_dim) vocab_size -> num_heads head_dim vocab_size', num_heads=num_heads) 202 | 203 | def inner(save_ctx, input, output): 204 | if split_heads: 205 | einsum_in = einops.rearrange(output[0][position_slice], 'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim', num_heads=num_heads) 206 | einsum_out = torch.einsum('bcij,cjk->bcik', einsum_in, vocab_matrix) 207 | else: 208 | einsum_in = output[0][position_slice] 209 | einsum_out = torch.einsum('bij,jk->bik', einsum_in, vocab_matrix) 210 | 211 | save_ctx['logits'] = einsum_out.detach().cpu() 212 | 213 | # write key 214 | if key is None: 215 | key = str(layer) + '_logits' 216 | 217 | # create hook 218 | hook = Hook(f'{layer_key_prefix}{layer}', inner, key) 219 | 220 | return hook 221 | 222 | def gpt_attn_wrapper( 223 | func: Callable, 224 | save_ctx: Dict, 225 | c_proj: torch.Tensor, 226 | vocab_embedding: torch.Tensor, 227 | target_ids: torch.Tensor, 228 | batch_size: Optional[int] = None, 229 | ) -> Tuple[Callable, Callable]: 230 | """Wraps around the [AttentionBlock]._attn function to save the individual heads' logits. 231 | This is necessary because the individual heads' logits are not available on a module level and thus not accessible via a hook. 232 | 233 | :param func: original _attn function 234 | :type func: Callable 235 | :param save_ctx: context to which the logits will be saved 236 | :type save_ctx: Dict 237 | :param c_proj: projection matrix, this is W_O in Anthropic's terminology 238 | :type c_proj: torch.Tensor 239 | :param vocab_matrix: vocabulary/embedding matrix, this is W_V in Anthropic's terminology 240 | :type vocab_matrix: torch.Tensor 241 | :param target_ids: indices of the target tokens for which the logits are computed 242 | :type target_ids: torch.Tensor 243 | :param batch_size: batch size to reduce memory footprint, defaults to None 244 | :type batch_size: Optional[int] 245 | :return: inner, func, the wrapped function and the original function 246 | :rtype: Tuple[Callable, Callable] 247 | """ 248 | # TODO Find a smarter/more efficient way of implementing this function 249 | # TODO clean up this function 250 | def inner(query, key, value, *args, **kwargs): 251 | nonlocal c_proj 252 | nonlocal target_ids 253 | nonlocal vocab_embedding 254 | nonlocal batch_size 255 | attn_output, attn_weights = func(query, key, value, *args, **kwargs) 256 | if batch_size is None: 257 | batch_size = attn_output.shape[0] 258 | with torch.no_grad(): 259 | temp = attn_weights[...,None] * value[:,:,None] 260 | if len(c_proj.shape) == 2: 261 | c_proj = einops.rearrange(c_proj, '(head_dim num_heads) out_dim -> head_dim num_heads out_dim', num_heads=attn_output.shape[1]) 262 | c_proj = einops.rearrange(c_proj, 'h n o -> n h o') 263 | temp = temp[0,:,:-1,:-1] # could this be done earlier? 264 | new_temp = [] 265 | for head in tqdm(range(temp.shape[0])): 266 | new_temp.append([]) 267 | for i in range(math.ceil(temp.shape[1] / batch_size)): 268 | out = temp[head, i*batch_size:(i+1)*batch_size] @ c_proj[head] 269 | out = out @ vocab_embedding # compute logits 270 | new_temp[-1].append(out) 271 | 272 | # center logits 273 | new_temp[-1] = torch.cat(new_temp[-1]) 274 | new_temp[-1] -= torch.mean(new_temp[-1], dim=-1, keepdim=True) 275 | # select targets 276 | new_temp[-1] = new_temp[-1][...,torch.arange(len(target_ids)), :, target_ids]#.to('cpu') 277 | 278 | new_temp = torch.stack(new_temp, dim=0) # stack heads 279 | max_pos_value = torch.amax(new_temp).item() 280 | max_neg_value = torch.amax(-new_temp).item() 281 | 282 | save_ctx['logits'] = { 283 | 'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach().cpu(), 284 | 'neg': (-new_temp/max_neg_value).clamp(min=0, max=1).detach().cpu(), 285 | } 286 | return attn_output, attn_weights 287 | return inner, func 288 | 289 | #TODO update docs here 290 | def additive_output_noise( 291 | indices: str, 292 | mean: Optional[float] = 0, 293 | std: Optional[float] = 0.1 294 | ) -> Callable: 295 | slice_ = util.create_slice_from_str(indices) 296 | def func(save_ctx, input, output): 297 | noise = mean + std * torch.randn_like(output[slice_]) 298 | output[slice_] += noise 299 | return output 300 | return func 301 | 302 | def hidden_patch_hook_fn( 303 | position: int, 304 | replacement_tensor: torch.Tensor, 305 | ) -> Callable: 306 | indices = "...," + str(position) + len(replacement_tensor.shape) * ",:" 307 | inner = replace_activation(indices, replacement_tensor) 308 | def func(save_ctx, input, output): 309 | output[0][...] = inner(save_ctx, input, output[0]) 310 | return output 311 | -------------------------------------------------------------------------------- /unseal/hooks/commons.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import dataclass 3 | from inspect import signature 4 | from typing import List, Callable 5 | 6 | import torch 7 | 8 | from . import util 9 | 10 | class Hook: 11 | layer_name: str 12 | func: Callable 13 | key: str 14 | def __init__(self, layer_name: str, func: Callable, key: str): 15 | # check that func takes three arguments 16 | if len(signature(func).parameters) != 3: 17 | raise TypeError(f'Hook function {func.__name__} should have three arguments, but has {len(signature(func).parameters)}.') 18 | 19 | self.layer_name = layer_name 20 | self.func = func 21 | self.key = key 22 | 23 | class HookedModel(torch.nn.Module): 24 | def __init__(self, model): 25 | """Wrapper around a module that allows forward passes with hooks and a context object. 26 | 27 | :param model: Model to be hooked 28 | :type model: nn.Module 29 | :raises TypeError: Incorrect model type 30 | """ 31 | super().__init__() 32 | 33 | # check inputs 34 | if not isinstance(model, torch.nn.Module): 35 | raise TypeError(f"model should be type torch.nn.Module but is {type(model)}") 36 | 37 | self.model = model 38 | 39 | # initialize hooks 40 | self.init_refs() 41 | 42 | # init context for accessing hook output 43 | self.save_ctx = dict() 44 | 45 | def init_refs(self): 46 | """Creates references for every layer in a model.""" 47 | 48 | self.layers = OrderedDict() 49 | 50 | # recursive naming function 51 | def name_layers(net, prefix=[]): 52 | if hasattr(net, "_modules"): 53 | for name, layer in net._modules.items(): 54 | if layer is None: 55 | # e.g. GoogLeNet's aux1 and aux2 layers 56 | continue 57 | self.layers["->".join(prefix + [name])] = layer 58 | name_layers(layer, prefix=prefix + [name]) 59 | else: 60 | raise ValueError('net has not _modules attribute! Check if your model is properly instantiated..') 61 | 62 | name_layers(self.model) 63 | 64 | 65 | def forward( 66 | self, 67 | input_ids: torch.Tensor, 68 | hooks: List[Hook], 69 | *args, 70 | **kwargs, 71 | ): 72 | """Wrapper around the default forward pass that temporarily registers hooks, executes the forward pass and then closes hooks again. 73 | """ 74 | # register hooks 75 | registered_hooks = [] 76 | for hook in hooks: 77 | layer = self.layers.get(hook.layer_name, None) 78 | if layer is None: 79 | raise ValueError(f'Layer {hook.layer_name} was not found during hook registration! Here is the whole model for reference:\n {self.__repr__}') 80 | self.save_ctx[hook.key] = dict() # create sub-context for each hook to write to 81 | registered_hooks.append(layer.register_forward_hook(self._hook_wrapper(hook.func, hook.key))) 82 | 83 | # forward 84 | output = self.model(input_ids, *args, **kwargs) #TODO generalize to non-HF models which would not have an input_ids kwarg 85 | 86 | # remove hooks 87 | for hook in registered_hooks: 88 | hook.remove() 89 | 90 | return output 91 | 92 | def _hook_wrapper(self, func, hook_key): 93 | """Wrapper to comply with PyTorch's hooking API while enabling saving to context. 94 | 95 | :param func: [description] 96 | :type func: [type] 97 | :param hook_key: [description] 98 | :type hook_key: [type] 99 | :return: [description] 100 | :rtype: [type] 101 | """ 102 | return lambda model, input, output: func(save_ctx=self.save_ctx[hook_key], input=input[0], output=output) 103 | 104 | def get_ctx_keys(self): 105 | return list(self.save_ctx.keys()) 106 | 107 | def __repr__(self): 108 | return self.model.__repr__() 109 | 110 | @property 111 | def device(self): 112 | return next(self.model.parameters()).device -------------------------------------------------------------------------------- /unseal/hooks/rome_hooks.py: -------------------------------------------------------------------------------- 1 | #TODO for version 1.0.0: Remove this file 2 | import logging 3 | 4 | from .common_hooks import additive_output_noise, hidden_patch_hook_fn 5 | 6 | logging.warning("rome_hooks.py is deprecated and will be removed in version 1.0.0. Please use common_hooks.py instead.") 7 | 8 | -------------------------------------------------------------------------------- /unseal/hooks/util.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Iterable, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | def create_slice_from_str(indices: str) -> slice: 9 | """Creates a slice object from a string representing the slice. 10 | 11 | :param indices: String representing the slice, e.g. ``...,3:5,:`` 12 | :type indices: str 13 | :return: Slice object corresponding to the input indices. 14 | :rtype: slice 15 | """ 16 | if len(indices) == 0: 17 | raise ValueError('Empty string is not a valid slice.') 18 | return eval(f'np.s_[{indices}]') 19 | 20 | 21 | def recursive_to_device( 22 | iterable: Union[Iterable, torch.Tensor], 23 | device: Union[str, torch.device], 24 | ) -> Iterable: 25 | """Recursively puts an Iterable of (Iterable of (...)) tensors on the given device 26 | 27 | :param iterable: Tensor or Iterable of tensors or iterables of ... 28 | :type iterable: Tensor or Iterable 29 | :param device: Device on which to put the object 30 | :type device: Union[str, torch.device] 31 | :raises TypeError: Unexpected tyes 32 | :return: Nested iterable with the tensors on the new device 33 | :rtype: Iterable 34 | """ 35 | if isinstance(iterable, torch.Tensor): 36 | return iterable.to(device) 37 | 38 | new = [] 39 | for i, item in enumerate(iterable): 40 | if isinstance(item, torch.Tensor): 41 | new.append(item.to(device)) 42 | elif isinstance(item, Iterable): 43 | new.append(recursive_to_device(item, device)) 44 | else: 45 | raise TypeError(f'Expected type tensor or Iterable but got {type(item)}.') 46 | if isinstance(iterable, Tuple): 47 | new = tuple(new) 48 | return new -------------------------------------------------------------------------------- /unseal/logit_lense.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from transformers import AutoTokenizer 7 | 8 | from .transformers_util import get_num_layers 9 | from .hooks.common_hooks import create_logit_hook 10 | from .hooks.commons import HookedModel 11 | 12 | def generate_logit_lense( 13 | model: HookedModel, 14 | tokenizer: AutoTokenizer, 15 | sentence: str, 16 | layers: Optional[List[int]] = None, 17 | ranks: Optional[bool] = False, 18 | kl_div: Optional[bool] = False, 19 | include_input: Optional[bool] = False, 20 | layer_key_prefix: Optional[str] = None, 21 | ): 22 | """Generates the necessary data to generate the plots from the logits `lense post 23 | `_. 24 | 25 | Returns None for ranks and kl_div if not specified. 26 | 27 | :param model: Model that is investigated. 28 | :type model: HookedModel 29 | :param tokenizer: Tokenizer of the model. 30 | :type tokenizer: AutoTokenizer 31 | :param sentence: Sentence to be analyzed. 32 | :type sentence: str 33 | :param layers: List of layers to be investigated. 34 | :type layers: Optional[List[int]] 35 | :param ranks: Whether to return ranks of the correct token throughout layers, defaults to False 36 | :type ranks: Optional[bool], optional 37 | :param kl_div: Whether to return the KL divergence between intermediate probabilities and final output probabilities, defaults to False 38 | :type kl_div: Optional[bool], optional 39 | :param include_input: Whether to include the immediate logits/ranks/kld after embedding the input, defaults to False 40 | :type include_input: Optional[bool], optional 41 | :param layer_key_prefix: Prefix for the layer keys, e.g. 'transformer->h' for GPT like models, defaults to None 42 | :type layer_key_prefix: Optional[str], optional 43 | :return: logits, ranks, kl_div 44 | :rtype: Tuple[torch.Tensor] 45 | """ 46 | 47 | # TODO 48 | if include_input: 49 | logging.warning("include_input is not implemented yet") 50 | 51 | # prepare model input 52 | tokenized_sentence = tokenizer.encode(sentence, return_tensors='pt').to(model.device) 53 | targets = tokenizer.encode(sentence)[1:] 54 | 55 | # instantiate hooks 56 | num_layers = get_num_layers(model, layer_key_prefix=layer_key_prefix) 57 | if layers is None: 58 | layers = list(range(num_layers)) 59 | logit_hooks = [create_logit_hook(layer, model, 'lm_head', layer_key_prefix) for layer in layers] 60 | 61 | # run model 62 | model.forward(tokenized_sentence, hooks=logit_hooks) 63 | logits = torch.stack([model.save_ctx[str(layer) + '_logits']['logits'] for layer in range(num_layers)], dim=0) 64 | logits = F.log_softmax(logits, dim=-1) 65 | 66 | # compute ranks and kld 67 | if ranks: 68 | inverted_ranks = torch.argsort(logits, dim=-1, descending=True) 69 | ranks = torch.argsort(inverted_ranks, dim=-1) + 1 70 | ranks = ranks[:, torch.arange(len(targets)), targets] 71 | else: 72 | ranks = None 73 | 74 | if kl_div: # Note: logits are already normalized internally by the logit_hook 75 | kl_div = F.kl_div(logits, logits[-1][None], reduction='none', log_target=True).sum(dim=-1) 76 | kl_div = kl_div[:, torch.arange(len(targets)), targets] 77 | else: 78 | kl_div = None 79 | 80 | logits = logits[:, torch.arange(len(targets)), targets] 81 | 82 | return logits, ranks, kl_div 83 | 84 | 85 | -------------------------------------------------------------------------------- /unseal/transformers_util.py: -------------------------------------------------------------------------------- 1 | # utility functions for interacting with huggingface's transformers library 2 | import logging 3 | import os 4 | from typing import Optional, Tuple 5 | 6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 7 | from transformers.file_utils import RepositoryNotFoundError 8 | 9 | from .hooks.commons import HookedModel 10 | 11 | def load_from_pretrained( 12 | model_name: str, 13 | model_dir: Optional[str] = None, 14 | load_model: Optional[bool] = True, 15 | load_tokenizer: Optional[bool] = True, 16 | load_config: Optional[bool] = True, 17 | low_cpu_mem_usage: Optional[bool] = False, 18 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer, AutoConfig]: 19 | """Load a pretrained model from huggingface's transformer library 20 | 21 | :param model_name: Name of the model, e.g. `gpt2` or `gpt-neo`. 22 | :type model_name: str 23 | :param model_dir: Directory in which to look for the model, e.g. `EleutherAI`, defaults to None 24 | :type model_dir: Optional[str], optional 25 | :param load_model: Whether to load the model itself, defaults to True 26 | :type load_model: Optional[bool], optional 27 | :param load_tokenizer: Whether to load the tokenizer, defaults to True 28 | :type load_tokenizer: Optional[bool], optional 29 | :param load_config: Whether to load the config file, defaults to True 30 | :type load_config: Optional[bool], optional 31 | :param low_cpu_mem_usage: Whether to use low-memory mode, experimental feature of HF, defaults to False 32 | :type low_cpu_mem_usage: bool, optional 33 | :return: model, tokenizer, config. Returns None values for those elements which were not loaded. 34 | :rtype: Tuple[AutoModelForCausalLM, AutoTokenizer, AutoConfig] 35 | """ 36 | if model_dir is None: 37 | try: 38 | logging.info(f'Loading model {model_name}') 39 | 40 | model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None 41 | tokenizer = AutoTokenizer.from_pretrained(model_name) if load_tokenizer else None 42 | config = AutoConfig.from_pretrained(model_name) if load_config else None 43 | 44 | except (RepositoryNotFoundError, OSError) as error: 45 | logging.warning("Couldn't find model in default folder. Trying EleutherAI/...") 46 | 47 | model = AutoModelForCausalLM.from_pretrained(f'EleutherAI/{model_name}', low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None 48 | tokenizer = AutoTokenizer.from_pretrained(f'EleutherAI/{model_name}') if load_tokenizer else None 49 | config = AutoConfig.from_pretrained(f'EleutherAI/{model_name}') if load_config else None 50 | 51 | else: 52 | model = AutoModelForCausalLM.from_pretrained(os.path.join(model_dir, model_name), low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None 53 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_dir, model_name)) if load_tokenizer else None 54 | config = AutoConfig.from_pretrained(os.path.join(model_dir, model_name)) if load_config else None 55 | 56 | return model, tokenizer, config 57 | 58 | def get_num_layers(model: HookedModel, layer_key_prefix: Optional[str] = None) -> int: 59 | """Get the number of layers in a model 60 | 61 | :param model: The model to get the number of layers from 62 | :type model: HookedModel 63 | :param layer_key_prefix: The prefix to use for the layer keys, defaults to None 64 | :type layer_key_prefix: Optional[str], optional 65 | :return: The number of layers in the model 66 | :rtype: int 67 | """ 68 | if layer_key_prefix is None: 69 | layer_key_prefix = "" 70 | 71 | return len(model.layers[f"{layer_key_prefix}"]) -------------------------------------------------------------------------------- /unseal/visuals/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | from .commons import * -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/all_layers_single_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import streamlit as st 5 | 6 | from unseal.visuals.streamlit_interfaces import utils 7 | from unseal.visuals.streamlit_interfaces import interface_setup as setup 8 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES 9 | 10 | # perform startup tasks 11 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json') 12 | 13 | # create sidebar 14 | with st.sidebar: 15 | setup.create_sidebar() 16 | 17 | sample = st.checkbox('Enable sampling', value=False, key='sample') 18 | if sample: 19 | setup.create_sample_sliders() 20 | setup.on_sampling_config_change() 21 | 22 | if "storage" not in st.session_state: 23 | st.session_state["storage"] = [""] 24 | 25 | # input 1 26 | placeholder1 = st.empty() 27 | placeholder1.text_area(label='Input 1', on_change=utils.on_text_change, key='input_text_1', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text_1')) 28 | if sample: 29 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text"), key="sample_text") 30 | 31 | # sometimes need to force a re-render 32 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=0)) 33 | 34 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization) 35 | st.download_button( 36 | label='Download Visualization', 37 | data=f, 38 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json', 39 | mime='application/json', 40 | help='Download the visualizations as a json of html files.', 41 | key='download_button' 42 | ) 43 | 44 | # show the html visualization 45 | if st.session_state.model is not None: 46 | cols = st.columns(1) 47 | for col_idx, col in enumerate(cols): 48 | if f"col_{col_idx}" in st.session_state.visualization: 49 | with col: 50 | for layer in range(st.session_state.num_layers): 51 | if f"layer_{layer}" in st.session_state.visualization[f"col_{col_idx}"]: 52 | with st.expander(f'Layer {layer}'): 53 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{layer}"], height=600) 54 | else: 55 | st.session_state.visualization[f"col_{col_idx}"] = dict() -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/commons.py: -------------------------------------------------------------------------------- 1 | # define some global variables 2 | HF_MODELS = [ 3 | 'gpt2', 4 | 'gpt2-medium', 5 | 'gpt2-large', 6 | 'gpt2-xl', 7 | 'gpt-neo-125M', 8 | 'gpt-neo-1.3B', 9 | 'gpt-neo-2.7B', 10 | 'gpt-j-6b', 11 | ] 12 | 13 | SESSION_STATE_VARIABLES = [ 14 | 'model', 15 | 'tokenizer', 16 | 'config', 17 | 'registered_models', 18 | 'registered_model_names', 19 | 'num_layers', 20 | ] 21 | -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/compare_two_inputs.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import streamlit as st 4 | 5 | from unseal.visuals.streamlit_interfaces import utils 6 | from unseal.visuals.streamlit_interfaces import interface_setup as setup 7 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES 8 | 9 | # perform startup tasks 10 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json') 11 | 12 | # create sidebar 13 | with st.sidebar: 14 | setup.create_sidebar() 15 | 16 | sample = st.checkbox('Enable sampling', value=False, key='sample') 17 | if sample: 18 | setup.create_sample_sliders() 19 | setup.on_sampling_config_change() 20 | 21 | if "storage" not in st.session_state: 22 | st.session_state["storage"] = ["", ""] 23 | 24 | # input 1 25 | placeholder1 = st.empty() 26 | placeholder1.text_area(label='Input 1', on_change=utils.on_text_change, key='input_text_1', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text_1')) 27 | if sample: 28 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text_1"), key="sample_text_1") 29 | 30 | # input 2 31 | placeholder2 = st.empty() 32 | placeholder2.text_area(label='Input 2', on_change=utils.on_text_change, key='input_text_2', value=st.session_state.storage[1], kwargs=dict(col_idx=1, text_key='input_text_2')) 33 | if sample: 34 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=1, key="input_text_2"), key="sample_text_2") 35 | 36 | # sometimes need to force a re-render 37 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=[0,1])) 38 | 39 | # download button 40 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization) 41 | st.download_button( 42 | label='Download Visualization', 43 | data=f, 44 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json', 45 | mime='application/json', 46 | help='Download the visualizations as a json of html files.', 47 | key='download_button' 48 | ) 49 | 50 | # show the html visualization 51 | if st.session_state.model is not None: 52 | cols = st.columns(2) 53 | for col_idx, col in enumerate(cols): 54 | if f"col_{col_idx}" in st.session_state.visualization: 55 | with col: 56 | for layer in range(st.session_state.num_layers): 57 | if f"layer_{layer}" in st.session_state.visualization[f"col_{col_idx}"]: 58 | with st.expander(f'Layer {layer}'): 59 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{layer}"], height=600, scrolling=True) 60 | else: 61 | st.session_state.visualization[f"col_{col_idx}"] = dict() 62 | -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/interface_setup.py: -------------------------------------------------------------------------------- 1 | # Functions that are usually only called once during setup of the interface 2 | import importlib 3 | import json 4 | from typing import Optional, List, Tuple 5 | 6 | import streamlit as st 7 | import torch 8 | 9 | from .commons import HF_MODELS 10 | from ...transformers_util import get_num_layers, load_from_pretrained 11 | from ...hooks import HookedModel 12 | 13 | def create_model_config(model_names): 14 | with st.form('model_config'): 15 | st.write('## Model Config') 16 | 17 | if model_names is None: 18 | model_options = list() 19 | else: 20 | model_options = model_names 21 | 22 | st.selectbox( 23 | 'Model', 24 | options=model_options, 25 | key='model_name', 26 | index=0, 27 | ) 28 | 29 | devices = ['cpu'] 30 | if torch.cuda.is_available(): 31 | devices += ['cuda'] 32 | st.selectbox( 33 | 'Device', 34 | options=devices, 35 | index=0, 36 | key='device' 37 | ) 38 | 39 | st.text_area(label='Prefix Prompt', key='prefix_prompt', value='') 40 | 41 | submitted = st.form_submit_button("Save model config") 42 | if submitted: 43 | st.session_state.model, st.session_state.tokenizer, st.session_state.config = on_config_submit(st.session_state.model_name) 44 | st.write('Model config saved!') 45 | 46 | 47 | def create_sidebar(): 48 | st.checkbox('Show only local models', value=False, key='local_only') 49 | 50 | if not st.session_state.local_only: 51 | model_names = st.session_state.registered_model_names + HF_MODELS 52 | else: 53 | model_names = st.session_state.registered_model_names 54 | 55 | create_model_config(model_names) 56 | 57 | def create_sample_sliders(): 58 | st.slider(label="Temperature", min_value=0., max_value=1.0, value=0., step=0.01, key='temperature', on_change=on_sampling_config_change) 59 | st.slider(label="Response length", min_value=1, max_value=1024, value=64, step=1, key='response_length', on_change=on_sampling_config_change) 60 | st.slider(label="Top P", min_value=0., max_value=1.0, value=1., step=0.01, key='top_p', on_change=on_sampling_config_change) 61 | st.slider(label="Repetition Penalty (1 = no penalty)", min_value=0.01, max_value=1.0, value=1., step=0.01, key='repetition_penalty', on_change=on_sampling_config_change) 62 | st.slider(label="Number of Beams", min_value=1, max_value=10, value=1, step=1, key='num_beams', on_change=on_sampling_config_change) 63 | 64 | def load_registered_models(model_file_path: str = './registered_models.json') -> None: 65 | try: 66 | with open(model_file_path, 'r') as f: 67 | st.session_state.registered_models = json.load(f) 68 | except FileNotFoundError: 69 | st.warning(f"Did not find a 'registered_models.json'. Only showing HF models") 70 | st.session_state.registered_models = dict() 71 | st.session_state.registered_model_names = list(st.session_state.registered_models.keys()) 72 | 73 | 74 | def startup(variables: List[str], mode_file_path: Optional[str] = './registered_models.json') -> None: 75 | """Performs startup tasks for the app. 76 | 77 | :param variables: List of variable names that should be intialized. 78 | :type variables: List[str] 79 | :param model_file_path: Path to the file containing the registered models. 80 | :type model_file_path: Optional[str] 81 | """ 82 | 83 | if 'startup_done' not in st.session_state: 84 | # set wide layout 85 | st.set_page_config(layout="wide") 86 | 87 | # initialize session state variables 88 | init_session_state(variables) 89 | st.session_state['visualization'] = dict() 90 | st.session_state['startup_done'] = True 91 | 92 | # load externally registered models 93 | load_registered_models(mode_file_path) 94 | 95 | def init_session_state(variables: List[str]) -> None: 96 | """Initialize session state variables to None. 97 | 98 | :param variables: List of variable names to initialize. 99 | :type variables: List[str] 100 | """ 101 | for var in variables: 102 | if var not in st.session_state: 103 | st.session_state[var] = None 104 | 105 | def on_config_submit(model_name: str) -> Tuple: 106 | """Function that is called on submitting the config form. 107 | 108 | :param model_name: Name of the model that should be loaded 109 | :type model_name: str 110 | :return: Model, tokenizer, config 111 | :rtype: Tuple 112 | """ 113 | # load model, hook it and put it on device and in eval mode 114 | model, tokenizer, config = load_model(model_name) 115 | model = HookedModel(model) 116 | model.to(st.session_state.device).eval() 117 | 118 | if model_name in st.session_state.registered_models: 119 | st.session_state.num_layers = get_num_layers(model, config['layer_key_prefix']) 120 | else: 121 | st.session_state.num_layers = get_num_layers(model, 'transformer->h') 122 | 123 | return model, tokenizer, config 124 | 125 | def on_sampling_config_change(): 126 | st.session_state.sample_kwargs = dict( 127 | temperature=st.session_state.temperature, 128 | max_length=st.session_state.response_length, 129 | top_p=st.session_state.top_p, 130 | repetition_penalty=1/st.session_state.repetition_penalty, 131 | num_beams=st.session_state.num_beams 132 | ) 133 | 134 | @st.experimental_singleton 135 | def load_model(model_name: str) -> Tuple: 136 | """Load the specified model with its tokenizer and config. 137 | 138 | :param model_name: Model name, e.g. 'gpt2-xl' 139 | :type model_name: str 140 | :return: Model, Tokenizer, Config 141 | :rtype: Tuple 142 | """ 143 | if model_name in st.session_state.registered_model_names: 144 | # import model constructor 145 | constructor = st.session_state.registered_models[model_name]['constructor'] 146 | constructor_module = importlib.import_module('.'.join(constructor.split('.')[:-1])) 147 | constructor_class = getattr(constructor_module, constructor.split('.')[-1]) 148 | 149 | # load model from checkpoint --> make sure that your class has this method, it's default for pl.LightningModules 150 | checkpoint = st.session_state.registered_models[model_name]['checkpoint'] 151 | model = constructor_class.load_from_checkpoint(checkpoint) 152 | 153 | # load tokenizer 154 | tokenizer = st.session_state.registered_models[model_name]['tokenizer'] # TODO how to deal with this? 155 | tokenizer_module = importlib.import_module('.'.join(tokenizer.split('.')[:-1])) 156 | tokenizer_class = getattr(tokenizer_module, tokenizer.split('.')[-1]) 157 | tokenizer = tokenizer_class() 158 | 159 | # load config 160 | config = st.session_state.registered_models[model_name]['config'] 161 | 162 | else: # attempt to load from huggingface 163 | model, tokenizer, config = load_from_pretrained(model_name) 164 | 165 | return model, tokenizer, config 166 | -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/load_only.py: -------------------------------------------------------------------------------- 1 | import json 2 | from io import StringIO 3 | 4 | import streamlit as st 5 | 6 | from unseal.visuals.streamlit_interfaces import interface_setup as setup 7 | from unseal.visuals import utils 8 | 9 | def on_file_upload(): 10 | if st.session_state.uploaded_file is None: 11 | return 12 | 13 | data = json.loads(StringIO(st.session_state.uploaded_file.getvalue().decode('utf-8')).read()) 14 | 15 | st.session_state.visualization = data 16 | 17 | for layer in range(len(data)): 18 | html_str = st.session_state.visualization[f'layer_{layer}'] 19 | with st.expander(f'Layer {layer}'): 20 | st.components.v1.html(html_str, height=600, scrolling=True) 21 | 22 | # set page config to wide layout 23 | st.set_page_config(layout='wide') 24 | 25 | # create sidebar 26 | with st.sidebar: 27 | st.file_uploader( 28 | label='Upload Visualization', 29 | accept_multiple_files=False, 30 | on_change=on_file_upload, 31 | help='Upload the visualizations as an json of html files.', 32 | key='uploaded_file' 33 | ) 34 | 35 | 36 | def load(): 37 | with open("./visualizations/gpt2-xl.json", "r") as fp: 38 | data = json.load(fp) 39 | 40 | st.session_state.visualization = data 41 | for layer in range(len(data['col_0'])): 42 | html_str = st.session_state.visualization['col_0'][f'layer_{layer}'] 43 | with st.expander(f'Layer {layer}'): 44 | st.components.v1.html(html_str, height=1000, width=800, scrolling=True) 45 | 46 | load() 47 | 48 | -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/registered_models.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/single_layer_single_input.py: -------------------------------------------------------------------------------- 1 | def layer_change(): 2 | utils.text_change(0) 3 | 4 | import json 5 | import time 6 | 7 | import streamlit as st 8 | 9 | from unseal.visuals.streamlit_interfaces import utils 10 | from unseal.visuals.streamlit_interfaces import interface_setup as setup 11 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES 12 | 13 | # perform startup tasks 14 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json') 15 | 16 | # create sidebar 17 | with st.sidebar: 18 | setup.create_sidebar() 19 | 20 | sample = st.checkbox('Enable sampling', value=False, key='sample') 21 | if sample: 22 | setup.create_sample_sliders() 23 | setup.on_sampling_config_change() 24 | 25 | if "storage" not in st.session_state: 26 | st.session_state["storage"] = [""] 27 | 28 | # select layer 29 | if st.session_state.num_layers is None: 30 | options = list(['Select a model!']) 31 | else: 32 | options = list(range(st.session_state.num_layers)) 33 | st.selectbox('Layer', options=options, key='layer', on_change=layer_change, index=0) 34 | 35 | # input 1 36 | placeholder = st.empty() 37 | placeholder.text_area(label='Input', on_change=utils.on_text_change, key='input_text', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text')) 38 | if sample: 39 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text"), key="sample_text") 40 | 41 | # sometimes need to force a re-render 42 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=0)) 43 | 44 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization) 45 | st.download_button( 46 | label='Download Visualization', 47 | data=f, 48 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json', 49 | mime='application/json', 50 | help='Download the visualizations as a json of html files.', 51 | key='download_button' 52 | ) 53 | 54 | # show the html visualization 55 | if st.session_state.model is not None: 56 | cols = st.columns(1) 57 | for col_idx, col in enumerate(cols): 58 | if f"col_{col_idx}" in st.session_state.visualization: 59 | with col: 60 | if f"layer_{st.session_state.layer}" in st.session_state.visualization[f"col_{col_idx}"]: 61 | with st.expander(f'Layer {st.session_state.layer}'): 62 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{st.session_state.layer}"], height=600) 63 | else: 64 | st.session_state.visualization[f"col_{col_idx}"] = dict() -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/split_full_model_vis_into_layers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | 5 | def main(args): 6 | with open(args.file, 'r') as f: 7 | data = json.load(f) 8 | 9 | # get model_name from path 10 | filename = os.path.basename(args.file) 11 | model_name = filename.split('.')[0] 12 | print(f"model_name: {model_name}") 13 | for i in range(2): 14 | if f'col_{i}' in data: 15 | for j in range(len(data[f'col_{i}'])): 16 | os.makedirs(f'{args.target_folder}/col_{i}', exist_ok=True) 17 | with open(f'{args.target_folder}/col_{i}/layer_{j}.txt', 'w') as f: 18 | json.dump(data[f'col_{i}'][f'layer_{j}'], f) 19 | else: 20 | print(f"col_{i} not in data -> skipping") 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--file', type=str, required=True) 25 | parser.add_argument('--target_folder', type=str, required=True) 26 | args = parser.parse_args() 27 | 28 | main(args) 29 | -------------------------------------------------------------------------------- /unseal/visuals/streamlit_interfaces/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | import streamlit as st 4 | from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoPreTrainedModel 5 | 6 | from ..utils import compute_attn_logits 7 | 8 | def sample_text(model, col_idx, key): 9 | text = st.session_state[key] 10 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0: 11 | text = st.session_state.prefix_prompt + '\n' + text 12 | model_inputs = st.session_state.tokenizer.encode(text, return_tensors='pt').to(st.session_state.device) 13 | output = model.model.generate(model_inputs, **st.session_state.sample_kwargs, min_length=0, output_attentions=True) 14 | output_text = st.session_state.tokenizer.decode(output[0], skip_special_tokens=True) 15 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0: 16 | output_text = output_text.lstrip(st.session_state.prefix_prompt + '\n') 17 | 18 | st.session_state["storage"][col_idx] = output_text 19 | text_change(col_idx=col_idx) 20 | 21 | 22 | def on_text_change(col_idx: Union[int, List[int]], text_key): 23 | if isinstance(col_idx, list): 24 | for idx in col_idx: 25 | on_text_change(idx, text_key) 26 | else: 27 | st.session_state["storage"][col_idx] = st.session_state[text_key] 28 | text_change(col_idx) 29 | 30 | def get_attn_logits_args(): 31 | # get args for compute_attn_logits 32 | if st.session_state.model_name in st.session_state.registered_model_names: 33 | attn_name = st.session_state.config['attn_name'] 34 | output_idx = st.session_state.config['output_idx'] 35 | layer_key_prefix = st.session_state.config['layer_key_prefix'] 36 | out_proj_name = st.session_state.config['out_proj_name'] 37 | attn_suffix = st.session_state.config['attn_suffix'] 38 | unembedding_key = st.session_state.config['unembedding_key'] 39 | elif isinstance(st.session_state.model.model, GPTNeoPreTrainedModel): 40 | attn_name = 'attn' 41 | output_idx = 2 42 | layer_key_prefix = 'transformer->h' 43 | out_proj_name = 'out_proj' 44 | attn_suffix = 'attention' 45 | unembedding_key = 'lm_head' 46 | else: 47 | attn_name = 'attn' 48 | output_idx = 2 49 | layer_key_prefix = 'transformer->h' 50 | out_proj_name = 'c_proj' 51 | attn_suffix = None 52 | unembedding_key = 'lm_head' 53 | return attn_name, output_idx, layer_key_prefix, out_proj_name, attn_suffix, unembedding_key 54 | 55 | def text_change(col_idx: Union[int, List[int]]): 56 | if isinstance(col_idx, list): 57 | for idx in col_idx: 58 | text_change(idx) 59 | return 60 | 61 | text = st.session_state["storage"][col_idx] 62 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0: 63 | text = st.session_state.prefix_prompt + '\n' + text 64 | 65 | if text is None or len(text) == 0: 66 | return 67 | 68 | attn_name, output_idx, layer_key_prefix, out_proj_name, attn_suffix, unembedding_key = get_attn_logits_args() 69 | 70 | if 'layer' in st.session_state: 71 | layer = st.session_state['layer'] 72 | else: 73 | layer = None 74 | 75 | compute_attn_logits( 76 | st.session_state.model, 77 | st.session_state.model_name, 78 | st.session_state.tokenizer, 79 | st.session_state.num_layers, 80 | text, 81 | st.session_state.visualization[f'col_{col_idx}'], 82 | attn_name = attn_name, 83 | output_idx = output_idx, 84 | layer_key_prefix = layer_key_prefix, 85 | out_proj_name = out_proj_name, 86 | attn_suffix = attn_suffix, 87 | unembedding_key = unembedding_key, 88 | layer_id = layer, 89 | ) -------------------------------------------------------------------------------- /unseal/visuals/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gc 3 | from typing import Optional, Callable, Dict 4 | 5 | import einops 6 | import pysvelte as ps 7 | import torch 8 | from transformers import AutoTokenizer 9 | 10 | from ..hooks import HookedModel 11 | from ..hooks.common_hooks import create_attention_hook, gpt_attn_wrapper 12 | 13 | def compute_attn_logits( 14 | model: HookedModel, 15 | model_name: str, 16 | tokenizer: AutoTokenizer, 17 | num_layers: int, 18 | text: str, 19 | html_storage: Dict, 20 | save_path: Optional[str] = None, 21 | attn_name: Optional[str] = 'attn', 22 | output_idx: Optional[int] = 2, 23 | layer_key_prefix: Optional[str] = None, 24 | out_proj_name: Optional[str] = 'out_proj', 25 | attn_suffix: Optional[str] = None, 26 | unembedding_key: Optional[str] = 'lm_head', 27 | layer_id: Optional[int] = None, 28 | batch_size: Optional[int] = None, 29 | ): 30 | # parse inputs 31 | if save_path is None: 32 | save_path = f"{model_name}.json" 33 | if layer_key_prefix is None: 34 | layer_key_prefix = "" 35 | else: 36 | layer_key_prefix += "->" 37 | if attn_suffix is None or attn_suffix == "": 38 | attn_suffix = "" 39 | if layer_id is None: 40 | layer_list = list(range(num_layers)) 41 | else: 42 | layer_list = [layer_id] 43 | 44 | # tokenize without tokenization artifact -> needed for visualization 45 | tokenized_text = tokenizer.tokenize(text) 46 | tokenized_text = list(map(tokenizer.convert_tokens_to_string, map(lambda x: [x], tokenized_text))) 47 | 48 | # encode text 49 | model_input = tokenizer.encode(text, return_tensors='pt').to(model.device) 50 | target_ids = tokenizer.encode(text)[1:] 51 | 52 | # compute attention pattern 53 | attn_hooks = [create_attention_hook(layer, f'attn_layer_{layer}', output_idx, attn_name, layer_key_prefix) for layer in layer_list] 54 | model.forward(model_input, hooks=attn_hooks, output_attentions=True) 55 | 56 | # compute logits 57 | for layer in layer_list: 58 | 59 | # wrap the _attn function to create logit attribution 60 | model.save_ctx[f'logit_layer_{layer}'] = dict() 61 | old_fn = wrap_gpt_attn(model, layer, target_ids, unembedding_key, attn_name, attn_suffix, layer_key_prefix, out_proj_name, batch_size) 62 | 63 | # forward pass 64 | model.forward(model_input, hooks=[]) 65 | 66 | # parse attentions for this layer 67 | attention = model.save_ctx[f"attn_layer_{layer}"]['attn'][0] 68 | attention = einops.rearrange(attention, 'h n1 n2 -> n1 n2 h') 69 | 70 | # parse logits 71 | if model_input.shape[1] > 1: # otherwise we don't have any logit attribution 72 | logits = model.save_ctx[f'logit_layer_{layer}']['logits'] 73 | pos_logits = logits['pos'] 74 | neg_logits = logits['neg'] 75 | pos_logits = pad_logits(pos_logits) 76 | neg_logits = pad_logits(neg_logits) 77 | 78 | pos_logits = einops.rearrange(pos_logits, 'h n1 n2 -> n1 n2 h') 79 | neg_logits = einops.rearrange(neg_logits, 'h n1 n2 -> n1 n2 h') 80 | else: 81 | pos_logits = torch.zeros((attention.shape[0], attention.shape[1], attention.shape[2])) 82 | neg_logits = torch.zeros((attention.shape[0], attention.shape[1], attention.shape[2])) 83 | 84 | # compute and display the html object 85 | html_object = ps.AttentionLogits(tokens=tokenized_text, attention=attention, pos_logits=pos_logits, neg_logits=neg_logits, head_labels=[f'{layer}:{j}' for j in range(attention.shape[-1])]) 86 | html_object = html_object.update_meta(suppress_title=True) 87 | html_str = html_object.html_page_str() 88 | 89 | # save html string 90 | html_storage[f'layer_{layer}'] = html_str 91 | 92 | # reset _attn function 93 | reset_attn_fn(model, layer, old_fn, attn_name, attn_suffix, layer_key_prefix) 94 | 95 | # save progress so far 96 | with open(save_path, "w") as f: 97 | json.dump(html_storage, f) 98 | 99 | # garbage collection 100 | gc.collect() 101 | if torch.cuda.is_available(): 102 | torch.cuda.empty_cache() 103 | 104 | return html_storage 105 | 106 | def pad_logits(logits): 107 | logits = torch.cat([torch.zeros_like(logits[:,0,None]), logits], dim=1) 108 | logits = torch.cat([logits, torch.zeros_like(logits[:,:,0,None])], dim=2) 109 | return logits 110 | 111 | def wrap_gpt_attn( 112 | model: HookedModel, 113 | layer: int, 114 | target_ids: Callable, 115 | unembedding_key: str, 116 | attn_name: Optional[str] = 'attn', 117 | attn_suffix: Optional[str] = None, 118 | layer_key_prefix: Optional[str] = None, 119 | out_proj_name: Optional[str] = 'out_proj', 120 | batch_size: Optional[int] = None, 121 | ) -> Callable: 122 | # parse inputs 123 | if layer_key_prefix is None: 124 | layer_key_prefix = "" 125 | if attn_suffix is None: 126 | attn_suffix = "" 127 | 128 | attn_name = f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}" 129 | out_proj_name = attn_name + f"->{out_proj_name}" 130 | 131 | model.layers[attn_name]._attn, old_fn = gpt_attn_wrapper( 132 | model.layers[attn_name]._attn, 133 | model.save_ctx[f'logit_layer_{layer}'], 134 | model.layers[out_proj_name].weight, 135 | model.layers[unembedding_key].weight.T, 136 | target_ids, 137 | batch_size, 138 | ) 139 | 140 | return old_fn 141 | 142 | 143 | def reset_attn_fn( 144 | model: HookedModel, 145 | layer: int, 146 | old_fn: Callable, 147 | attn_name: Optional[str] = 'attn', 148 | attn_suffix: Optional[str] = None, 149 | layer_key_prefix: Optional[str] = None, 150 | ) -> None: 151 | # parse inputs 152 | if layer_key_prefix is None: 153 | layer_key_prefix = "" 154 | if attn_suffix is None: 155 | attn_suffix = "" 156 | 157 | # reset _attn function to old_fn 158 | del model.layers[f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}"]._attn 159 | model.layers[f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}"]._attn = old_fn --------------------------------------------------------------------------------