├── .gitignore
├── .readthedocs.yaml
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── images
│   │   ├── logo.png
│   │   ├── logo.svg
│   │   └── notebook_images
│   │       ├── inspectgpt2_card.png
│   │       └── trace_info_flow_card.png
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── index.rst
│       ├── tutorials
│       │   ├── common_hooks.rst
│       │   ├── getting_started.rst
│       │   ├── hooking.rst
│       │   ├── index.rst
│       │   ├── installation.rst
│       │   └── interfaces.rst
│       ├── unseal.hooks.rst
│       ├── unseal.interface.rst
│       ├── unseal.logit_lense.rst
│       └── unseal.transformers_util.rst
├── pyproject.toml
├── setup.cfg
├── setup.py
├── tests
│   ├── hooks
│   │   ├── test_common_hooks.py
│   │   ├── test_commons.py
│   │   ├── test_rome_hooks.py
│   │   └── test_util.py
│   └── test_transformer_util.py
└── unseal
    ├── __init__.py
    ├── circuits
    │   └── utils.py
    ├── hooks
    │   ├── __init__.py
    │   ├── common_hooks.py
    │   ├── commons.py
    │   ├── rome_hooks.py
    │   └── util.py
    ├── logit_lense.py
    ├── transformers_util.py
    └── visuals
        ├── __init__.py
        ├── streamlit_interfaces
        │   ├── __init__.py
        │   ├── all_layers_single_input.py
        │   ├── commons.py
        │   ├── compare_two_inputs.py
        │   ├── interface_setup.py
        │   ├── load_only.py
        │   ├── registered_models.json
        │   ├── single_layer_single_input.py
        │   ├── split_full_model_vis_into_layers.py
        │   └── utils.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | testing.py
133 |
134 | *visualizations/
135 | *gpt*json
136 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Required for RTD to use 3.9
9 | build:
10 | image: testing
11 |
12 | # Set python version
13 | python:
14 | version: 3.9
15 | install:
16 | - method: setuptools
17 | path: .
18 |
19 | # Build documentation in the docs/ directory with Sphinx
20 | sphinx:
21 | configuration: docs/source/conf.py
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "restructuredtext.confPath": "${workspaceFolder}/docs/source",
3 | "esbonio.server.enabled": false,
4 | "esbonio.sphinx.confDir": "${workspaceFolder}/docs/source"
5 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tom Lieberum
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unseal - Mechanistic Interpretability for Transformers
2 |
3 |
4 |
5 |
6 | ## Prerequisites
7 |
8 | Unseal requires Python 3.6+.
9 |
10 |
11 | ## Installation
12 |
13 | For its visualization interfaces, Unseal uses [this fork](https://github.com/TomFrederik/pysvelte) of the PySvelte library, which can be installed via pip:
14 |
15 | ```sh
16 | git clone git@github.com:TomFrederik/PySvelte.git
17 | cd PySvelte
18 | pip install -e .
19 | ```
20 |
21 | In order to run PySvelte, you will also need to install ``npm`` via your package manager.
22 | The hooking functionality of Unseal should still work without PySvelte, but we can't give any guarantees.
23 |
24 | Install Unseal via pip:
25 |
26 | ```sh
27 | pip install unseal
28 | ```
29 |
30 | ## Usage
31 |
32 | We refer to our documentation for tutorials and usage guides:
33 |
34 | [Documentation](https://unseal.readthedocs.io/en/latest/)
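
A minimal sketch of the core hooking workflow (``transformer->ln_f`` is GPT-2's final layer norm; see the documentation for details):

```python
from unseal.hooks import Hook, HookedModel
from unseal.hooks.common_hooks import save_output
from unseal.transformers_util import load_from_pretrained

# load a HuggingFace model and wrap it for hooking
model, tokenizer, config = load_from_pretrained('gpt2')
model = HookedModel(model)

# save the (detached, CPU) output of the final layer norm during a forward pass
hook = Hook('transformer->ln_f', save_output(), 'final_ln')
input_ids = tokenizer('Hello world', return_tensors='pt')['input_ids']
model(input_ids, hooks=[hook])

print(model.save_ctx['final_ln']['output'].shape)
```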
35 |
36 |
37 | ## Notebooks
38 |
39 | Here are some notebooks that also showcase Unseal's functionality.
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/logo.png
--------------------------------------------------------------------------------
/docs/images/notebook_images/inspectgpt2_card.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/notebook_images/inspectgpt2_card.png
--------------------------------------------------------------------------------
/docs/images/notebook_images/trace_info_flow_card.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/docs/images/notebook_images/trace_info_flow_card.png
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../..'))
16 | from unseal import __version__
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Unseal'
21 | copyright = '2022, Tom Lieberum'
22 | author = 'Tom Lieberum'
23 |
24 |
25 | # The full version, including alpha/beta/rc tags
26 | release = __version__
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | 'sphinx.ext.coverage',
35 | 'sphinx.ext.autodoc',
36 | 'sphinx.ext.napoleon',
37 | ]
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = []
46 |
47 |
48 | # -- Options for HTML output -------------------------------------------------
49 |
50 | # The theme to use for HTML and HTML Help pages. See the documentation for
51 | # a list of builtin themes.
52 | #
53 | html_theme = "sphinx_rtd_theme"
54 |
55 | # Add any paths that contain custom static files (such as style sheets) here,
56 | # relative to this directory. They are copied after the builtin static files,
57 | # so a file named "default.css" will overwrite the builtin "default.css".
58 | html_static_path = ['_static']
59 |
60 | # -- Extension configuration -------------------------------------------------
61 | html_theme_options = {
62 | 'canonical_url': 'docs/',
63 | 'analytics_id': 'UA-136588502-1', # Provided by Google in your dashboard
64 | # 'logo_only': False,
65 | # 'display_version': True,
66 | # 'prev_next_buttons_location': 'bottom',
67 | # 'style_external_links': False,
68 | # 'vcs_pageview_mode': '',
69 | # 'style_nav_header_background': 'white',
70 | # Toc options
71 | 'collapse_navigation': False,
72 | 'sticky_navigation': True,
73 | 'navigation_depth': 4,
74 | # 'includehidden': True,
75 | # 'titles_only': False
76 | }
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Unseal documentation master file, created by
2 | sphinx-quickstart on Wed Feb 9 14:22:48 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 | .. _main:
6 |
7 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/logo.png?raw=true
8 | :width: 400px
9 |
10 | Welcome to Unseal's documentation!
11 | ==================================
12 |
13 | Unseal wants to help you get started in mechanistic interpretability research on transformers!
14 |
15 | Install Unseal by following the instructions on the :ref:`installation page <installing_unseal>`.
16 |
17 | If you are new to Unseal, you should start by reading the :ref:`getting started <getting_started>` guide.
18 |
19 |
20 | .. toctree::
21 | :maxdepth: 1
22 | :caption: Tutorials and Guides
23 |
24 | tutorials/installation
25 | tutorials/getting_started
26 | tutorials/hooking
27 | tutorials/common_hooks
28 | tutorials/interfaces
29 |
30 | .. toctree::
31 | :maxdepth: 1
32 | :caption: API Reference
33 |
34 | unseal.hooks
35 | unseal.interface
36 | unseal.transformers_util
37 | unseal.logit_lense
38 |
39 |
40 |
41 | Indices and tables
42 | ==================
43 |
44 | * :ref:`genindex`
45 | * :ref:`modindex`
46 | * :ref:`search`
47 |
--------------------------------------------------------------------------------
/docs/source/tutorials/common_hooks.rst:
--------------------------------------------------------------------------------
1 | .. _common_hooks:
2 |
3 | ============================
4 | Common Hooks
5 | ============================
6 |
7 | The most common hooking functions are supported out of the box by Unseal.
8 |
9 | Some of these methods can be used directly as the function in a hook, others return such a function, and some return the hook itself.
10 | The respective docstring indicates which case applies.
11 |
12 | Saving Outputs
13 | ==============
14 |
15 | This method can be used directly in the construction of a hook.
16 |
17 | .. automethod:: unseal.hooks.common_hooks.save_output
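
For instance, to store a layer's (detached, CPU) output during a forward pass (a minimal sketch, assuming ``model`` is a ``HookedModel`` and ``input_tensor`` a matching input):

.. code-block:: python

    from unseal.hooks import Hook
    from unseal.hooks.common_hooks import save_output

    hook = Hook('0', save_output(cpu=True, detach=True), 'layer0_out')
    model(input_tensor, hooks=[hook])
    output = model.save_ctx['layer0_out']['output']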
18 |
19 |
20 | Replacing Activations
21 | =====================
22 |
23 | This method is a factory and returns a function that can be used in a hook to replace the activation of a layer.
24 |
25 | .. automethod:: unseal.hooks.common_hooks.replace_activation
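
For example, to overwrite the activation at position ``t = 3`` of every sequence in a batch with a fixed vector (a sketch; the layer name and hidden size of 64 are assumptions):

.. code-block:: python

    import torch
    from unseal.hooks import Hook
    from unseal.hooks.common_hooks import replace_activation

    # for an activation of shape (B, T, D) with D = 64
    patch = replace_activation(':,3,:', torch.zeros(64))
    hook = Hook('2->0', patch, 'patch_t3')
    model(input_tensor, hooks=[hook])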
26 |
27 |
28 | Saving Attention
29 | =====================
30 |
31 | .. automethod:: unseal.hooks.common_hooks.transformers_get_attention
32 |
33 |
34 | Creating an Attention Hook
35 | ===========================
36 |
37 | .. automethod:: unseal.hooks.common_hooks.create_attention_hook
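
A sketch for a GPT-2 style model; note the trailing ``->`` in the prefix, since the layer name is assembled as ``f'{layer_key_prefix}{layer}->{attn_name}'``:

.. code-block:: python

    from unseal.hooks.common_hooks import create_attention_hook

    hook = create_attention_hook(
        layer=0,
        key='attn_0',
        output_idx=2,  # GPT-like models return the attention pattern as the third tuple element
        layer_key_prefix='transformer->h->',
        heads='0:4',   # only save heads 0 through 3
    )
    model(input_ids, hooks=[hook], output_attentions=True)
    attn = model.save_ctx['attn_0']['attn']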
38 |
39 |
40 | Creating a Logit Hook
41 | ===========================
42 |
43 | .. automethod:: unseal.hooks.common_hooks.create_logit_hook
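
A sketch for a GPT-2 style causal LM, projecting the residual stream after block 5 through the unembedding (``lm_head``); the logits land in ``save_ctx['logits_5']['logits']``:

.. code-block:: python

    from unseal.hooks.common_hooks import create_logit_hook

    hook = create_logit_hook(
        layer=5,
        model=model,
        unembedding_key='lm_head',
        layer_key_prefix='transformer->h->',
        key='logits_5',
    )
    model(input_ids, hooks=[hook])
    logits = model.save_ctx['logits_5']['logits']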
44 |
45 |
46 | GPT ``_attn`` Wrapper
47 | =====================
48 |
49 | .. automethod:: unseal.hooks.common_hooks.gpt_attn_wrapper
50 |
51 |
52 |
--------------------------------------------------------------------------------
/docs/source/tutorials/getting_started.rst:
--------------------------------------------------------------------------------
1 | .. _getting_started:
2 |
3 | Getting started
4 | ===============
5 |
6 | If you just want to play around with models, then head to the :ref:`interface <interfaces>` section.
7 |
8 | If you want to learn more about how Unseal works under the hood, check out our section on :ref:`hooking <hooking>`.
--------------------------------------------------------------------------------
/docs/source/tutorials/hooking.rst:
--------------------------------------------------------------------------------
1 | .. _hooking:
2 |
3 |
4 | ===============
5 | Hooks
6 | ===============
7 |
8 | Hooks are at the heart of Unseal. In short, a hook is an access point to a model. It is defined by the point of the model at
9 | which it attaches and by the operation that it executes (usually either during the forward or backward pass).
10 |
11 | To read more about the original concept of a hook in PyTorch, see the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_.
12 |
13 | In Unseal, a hook is an object consisting of a ``layer_name`` (the point at which it attaches),
14 | a ``func`` (the function it executes), and a ``key`` (an identifying string unique to the hook).
15 |
16 | In order to simplify the hooking interface, Unseal wraps every model in the ``hooks.HookedModel`` class.
17 |
18 |
19 |
20 | hooks.HookedModel
21 | =======================
22 |
23 | You can inspect the top-level structure of a wrapped model by printing it (i.e. via its ``__repr__`` method):
24 |
25 | .. code-block:: python
26 |
27 | import torch.nn as nn
28 | from unseal.hooks import HookedModel
29 |
30 | model = nn.Sequential(
31 | nn.Linear(8,64),
32 | nn.ReLU(),
33 | nn.Sequential(
34 | nn.Linear(64,256),
35 | nn.ReLU(),
36 | nn.Linear(256,1)
37 | )
38 | )
39 | model = HookedModel(model)
40 |
41 | print(model)
42 |
43 | # equivalent:
44 | # print(model.model)
45 |
46 | ''' Output:
47 | Sequential(
48 | (0): Linear(in_features=8, out_features=64, bias=True)
49 | (1): ReLU()
50 | (2): Sequential(
51 | (0): Linear(in_features=64, out_features=256, bias=True)
52 | (1): ReLU()
53 | (2): Linear(in_features=256, out_features=1, bias=True)
54 | )
55 | )
56 | '''
57 |
58 |
59 | A HookedModel also has special references to every layer which you can access via the ``layers`` attribute:
60 |
61 | .. code-block:: python
62 |
63 | print(model.layers)
64 | '''Output:
65 | OrderedDict([('0', Linear(in_features=8, out_features=64, bias=True)), ('1', ReLU()), ('2', Sequential(
66 | (0): Linear(in_features=64, out_features=256, bias=True)
67 | (1): ReLU()
68 | (2): Linear(in_features=256, out_features=1, bias=True)
69 | )), ('2->0', Linear(in_features=64, out_features=256, bias=True)), ('2->1', ReLU()), ('2->2', Linear(in_features=256, out_features=1, bias=True))])
70 | '''
71 |
72 |
73 | You can see that each layer has its own identifying string (e.g. ``'2->2'``). If you want to only display the layer names you can simply call ``model.layers.keys()``.
74 |
75 | Hooked forward passes
76 | ---------------------
77 |
78 | The most important feature of a HookedModel object is its modified ``forward`` method which allows a user to temporarily add a hook to the model, perform a forward pass
79 | and record the result in the context attribute of the HookedModel.
80 |
81 | For this, the forward method takes an additional ``hooks`` argument which is a ``list`` of ``Hooks`` which get registered. After the forward pass, the hooks are removed
82 | again (to ensure consistent behavior). Hooks have access to the ``save_ctx`` attribute of the HookedModel, so anything you want to access later goes there and can
83 | be recalled via ``model.save_ctx[your_hook_key]``. Beware that the context attribute does not get reset automatically, so running a lot of
84 | different hooks can fill up your memory.
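
If memory does become an issue, the context is a plain dictionary and can simply be cleared by hand:

.. code-block:: python

    model.save_ctx.clear()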
85 |
86 |
87 | Writing your own hooks
88 | ======================
89 |
90 | As mentioned above, hooks are triples ``(layer_name, func, key)``. After choosing the attachment point (the ``layer_name``, an element from ``model.layers.keys()``),
91 | you need to implement the hooking function.
92 |
93 | Every hooking function needs to follow the signature ``save_ctx, input, output -> output``.
94 |
95 | ``save_ctx`` is a dictionary which is initialized empty by the ``HookedModel`` class
96 | during the forward pass. ``input`` and ``output`` are the input and output of the module respectively. If the hook is not modifying the output, the function does
97 | not need to return anything, as that is the default behavior.
98 |
99 | For example, let's implement a hook which saves the input and output of the first linear layer in the network we defined above:
100 |
101 |
102 | .. code-block:: python
103 |
104 | import torch
105 |     from unseal.hooks import Hook
106 |
107 | # define the hooking function
108 | def save_input_output(save_ctx, input, output):
109 | # make sure to not clutter the gpu and not keep track of gradients.
110 | save_ctx['input'] = input.detach().cpu()
111 | save_ctx['output'] = output.detach().cpu()
112 |
113 | # create Hook object
114 |     my_hook = Hook('0', save_input_output, 'save_input_output_0')
115 |
116 | # create random input tensor
117 | input_tensor = torch.rand((1,8))
118 |
119 | # forward pass with our new hook
120 | model.forward(input_tensor, hooks=[my_hook])
121 |
122 | # now we can access the model's context object
123 | print(model.save_ctx['save_input_output_0']['input'])
124 | print(model.save_ctx['save_input_output_0']['output'])
125 |
126 | '''Output:
127 | tensor([[0.5778, 0.0257, 0.4552, 0.4787, 0.9211, 0.0284, 0.8347, 0.9621]])
128 | tensor([[-0.6566, 1.0794, 0.1455, -0.0396, 0.0411, 0.2184, -0.3484, -0.1095,
129 | -0.2990, -0.1757, 0.1078, 0.2126, 0.4414, 0.1682, -0.2449, 0.0090,
130 | -0.0726, -0.0325, -0.5832, 0.1020, -0.2699, 0.0223, -0.8340, -0.4016,
131 | -0.2808, -0.5337, 0.1518, 1.1230, 1.1380, -0.1437, 0.2738, 0.4592,
132 | -0.7136, -0.3247, 0.2068, -0.5012, 0.4446, -0.4551, 0.2015, -0.3641,
133 | -0.1598, -0.7272, 0.0271, 0.2181, -0.3253, 0.2763, -0.5745, 0.4344,
134 | 0.0255, -0.2492, 0.1586, 0.2404, -0.2033, -0.6197, -0.1098, 0.3736,
135 | 0.1246, -0.4697, -0.7690, 0.0981, -0.0255, 0.2133, 0.3061, 0.1846]])
136 | '''
137 |
138 | To make things easier for you, Unseal comes with a few pre-implemented hooking functions, which
139 | we will explain in the next section.
--------------------------------------------------------------------------------
/docs/source/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | .. _tutorials:
2 |
3 | Tutorials and Guides
4 | ====================
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 | :caption: Tutorials and Guides
9 |
10 | tutorials/getting_started
11 | tutorials/hooking
12 |
--------------------------------------------------------------------------------
/docs/source/tutorials/installation.rst:
--------------------------------------------------------------------------------
1 | .. _installing_unseal:
2 |
3 | =====================
4 | Installing Unseal
5 | =====================
6 |
7 | Prerequisites
8 | -------------
9 |
10 | Unseal requires Python 3.6+.
11 |
12 |
13 | Installation
14 | ------------
15 |
16 | For its interfaces, Unseal uses `this fork <https://github.com/TomFrederik/pysvelte>`_ of the PySvelte library, which can be installed via pip:
17 |
18 | .. code-block:: console
19 |
20 | git clone git@github.com:TomFrederik/PySvelte.git
21 | cd PySvelte
22 | pip install -e .
23 |
24 | In order to run PySvelte, you will also need to install ``npm`` via your package manager.
25 |
26 | Install Unseal via pip:
27 |
28 | .. code-block:: console
29 |
30 | pip install unseal
--------------------------------------------------------------------------------
/docs/source/tutorials/interfaces.rst:
--------------------------------------------------------------------------------
1 | .. _interfaces:
2 |
3 | ====================
4 | Interfaces in Unseal
5 | ====================
6 |
7 | .. contents:: Contents
8 |
9 | Unseal wants to provide simple and intuitive interfaces for exploring
10 | large language models.
11 |
12 | At its core it relies on a combination of Streamlit and PySvelte.
13 |
14 | Notebooks
15 | ==========
16 |
17 | Here we collect Google Colab notebooks that demonstrate Unseal's functionality.
18 |
19 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/notebook_images/inspectgpt2_card.png?raw=true
20 | :target: https://colab.research.google.com/drive/1Y1y2GnDT-Uzvyp8pUWWXt8lEfHWxje3b?usp=sharing
21 |
22 |
23 | .. image:: https://github.com/TomFrederik/unseal/blob/main/docs/images/notebook_images/trace_info_flow_card.png?raw=true
24 | :target: https://colab.research.google.com/drive/1ljCPvbr7VPEIlbZQvrUceLSDsdeo3oRH?usp=sharing
25 |
26 |
27 |
28 | Streamlit interfaces
29 | ====================
30 |
31 | Unseal comes with several native interfaces that are ready to use out of the box.
32 |
33 | All the pre-built interfaces are available in the ``unseal.visuals.streamlit_interfaces`` package.
34 |
35 | To run any of the interfaces, you can navigate to the ``streamlit_interfaces`` directory and run
36 |
37 | .. code-block:: bash
38 | 
39 |    streamlit run <interface_name>.py
40 | 
41 | where ``<interface_name>`` is one of the scripts in that directory, e.g. ``single_layer_single_input.py``.
--------------------------------------------------------------------------------
/docs/source/unseal.hooks.rst:
--------------------------------------------------------------------------------
1 | unseal.hooks package
2 | =====================
3 |
4 | This package handles the nitty-gritty of hooking to a model.
5 |
6 |
7 | hooks module
8 | -------------------------------
9 |
10 | .. automodule:: unseal.hooks
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 |
16 | hooks.common_hooks module
17 | -------------------------------
18 |
19 | .. automodule:: unseal.hooks.common_hooks
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
24 |
25 | hooks.rome_hooks module
26 | -------------------------------
27 |
28 | .. automodule:: unseal.hooks.rome_hooks
29 | :members:
30 | :undoc-members:
31 | :show-inheritance:
32 |
33 |
34 | hooks.util module
35 | -------------------------------
36 |
37 | .. automodule:: unseal.hooks.util
38 | :members:
39 | :undoc-members:
40 | :show-inheritance:
41 |
42 | hooks.commons module
43 | -------------------------------
44 |
45 | .. automodule:: unseal.hooks.commons
46 | :members:
47 | :undoc-members:
48 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/unseal.interface.rst:
--------------------------------------------------------------------------------
1 | unseal.visuals package
2 | ========================
3 |
4 | This package creates graphical outputs, e.g. for displaying in GUIs.
5 |
6 |
7 | visuals module
8 | -------------------------------
9 |
10 | .. automodule:: unseal.visuals
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 |
16 | visuals.commons module
17 | -------------------------------
18 |
19 | .. automodule:: unseal.visuals.commons
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
24 |
25 | visuals.utils module
26 | -------------------------------
27 |
28 | .. automodule:: unseal.visuals.utils
29 | :members:
30 | :undoc-members:
31 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/unseal.logit_lense.rst:
--------------------------------------------------------------------------------
1 | logit_lense module
2 | -------------------------------
3 |
4 | .. automodule:: unseal.logit_lense
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/unseal.transformers_util.rst:
--------------------------------------------------------------------------------
1 | transformers_util module
2 | -------------------------------
3 |
4 | .. automodule:: unseal.transformers_util
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file=README.md
3 | license_files=LICENSE
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import os, sys
3 | sys.path.insert(0, os.path.abspath("."))
4 | from unseal import __version__
5 |
6 | version = __version__
7 |
8 | with open("README.md", "r") as fh:
9 | long_description = fh.read()
10 |
11 | setup(
12 | name='Unseal',
13 | version=version,
14 | packages=find_packages(exclude=[]),
15 | python_requires='>=3.6.0',
16 | install_requires=[
17 | 'torch',
18 | 'einops>=0.3.2',
19 | 'numpy',
20 | 'transformers',
21 | 'tqdm',
22 | 'matplotlib',
23 | 'streamlit',
24 | ],
25 | # entry_points={
26 | # 'console_scripts': [
27 | # '"unseal compare" = unseal.commands.interfaces.compare_two_inputs:main',
28 | # ]
29 | # },
30 | description=(
31 | "Unseal: "
32 | "A collection of infrastructure and tools for research in "
33 | "transformer interpretability."
34 | ),
35 | long_description=long_description,
36 | long_description_content_type="text/markdown",
37 | author="The Unseal Team",
38 | author_email="tlieberum@outlook.de",
39 | url="https://github.com/TomFrederik/unseal/",
40 |     license="MIT",
41 | keywords=[
42 | "pytorch",
43 | "tensor",
44 | "machine learning",
45 | "neural networks",
46 | "interpretability",
47 | "transformers",
48 | ],
49 | )
50 |
--------------------------------------------------------------------------------
/tests/hooks/test_common_hooks.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/tests/hooks/test_common_hooks.py
--------------------------------------------------------------------------------
/tests/hooks/test_commons.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Callable
3 |
4 | import pytest
5 | import torch
6 | from unseal.hooks import commons
7 |
8 | class TestHookedModel():
9 | def test_constructor(self):
10 | model = commons.HookedModel(torch.nn.Module())
11 | assert model is not None
12 | assert isinstance(model.model, torch.nn.Module)
13 | assert isinstance(model.save_ctx, dict)
14 | assert isinstance(model.layers, OrderedDict)
15 |
16 |
17 | with pytest.raises(TypeError):
18 | model = commons.HookedModel('not a module')
19 |
20 | def test_init_refs(self):
21 | model = commons.HookedModel(
22 | torch.nn.Sequential(
23 | torch.nn.Sequential(
24 | torch.nn.Linear(10,10)
25 | ),
26 | torch.nn.Linear(10,10)
27 | )
28 | )
29 | assert list(model.layers.keys()) == ['0', '0->0', '1']
30 |
31 | def test__hook_wrapper(self):
32 | model = commons.HookedModel(torch.nn.Module())
33 | def test_func(save_ctx, inp, output):
34 | return 1
35 |
36 | func = model._hook_wrapper(test_func, 'key')
37 | assert isinstance(func, Callable)
38 |
39 | def test_get_ctx_keys(self):
40 | model = commons.HookedModel(torch.nn.Module())
41 | assert model.get_ctx_keys() == []
42 |
43 | model.save_ctx['key'] = dict()
44 | assert model.get_ctx_keys() == ['key']
45 |
46 | def test_repr(self):
47 | model = commons.HookedModel(torch.nn.Module())
48 | assert model.__repr__() == model.model.__repr__()
49 |
50 | def test_device(self):
51 | model = commons.HookedModel(torch.nn.Linear(10, 10))
52 | assert model.device.type == 'cpu'
53 |
54 | if torch.cuda.is_available():
55 | model.to('cuda')
56 | assert model.device.type == 'cuda'
57 |
58 | def test_forward(self):
59 | model = commons.HookedModel(torch.nn.Sequential(torch.nn.Linear(10, 10)))
60 | def fn(save_ctx, input, output):
61 | save_ctx['key'] = 1
62 | hook = commons.Hook('0', fn, 'key')
63 | model.forward(torch.rand(10), [hook])
64 | assert model.save_ctx['key']['key'] == 1
65 |
66 | def test___call__(self):
67 | model = commons.HookedModel(torch.nn.Sequential(torch.nn.Linear(10, 10)))
68 | def fn(save_ctx, input, output):
69 | save_ctx['key'] = 2
70 | hook = commons.Hook('0', fn, 'key')
71 | model(torch.rand(10), [hook])
72 | assert model.save_ctx['key']['key'] == 2
73 |
74 |
75 | def test_hook():
76 | def correct_func(save_ctx, input, output):
77 | save_ctx['key'] = 1
78 | hook = commons.Hook('0', correct_func, 'key')
79 | assert hook.layer_name == '0'
80 | assert hook.func == correct_func
81 | assert hook.key == 'key'
82 |
83 | def false_func(save_ctx):
84 | save_ctx['key'] = 1
85 | with pytest.raises(TypeError):
86 | hook = commons.Hook('0', false_func, 'key')
--------------------------------------------------------------------------------
/tests/hooks/test_rome_hooks.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TomFrederik/unseal/96a8251e9204b4b41213b9a4258058a180e38e82/tests/hooks/test_rome_hooks.py
--------------------------------------------------------------------------------
/tests/hooks/test_util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Iterable
3 |
4 | from unseal.hooks import util
5 | import numpy as np
6 | import pytest
7 | import torch
8 | def test_create_slice_from_str():
9 | arr = np.random.rand(10,10,10)
10 | valid_idx_strings = [
11 | '...,3:5,:',
12 | '4:5,:,:',
13 | '0,0,0',
14 | '-1,-1,-1'
15 | ]
16 | valid_result_arrs = [
17 | arr[...,3:5,:],
18 | arr[4:5,:,:],
19 | arr[0,0,0],
20 | arr[-1,-1,-1],
21 | ]
22 | for string, subarr in zip(valid_idx_strings, valid_result_arrs):
23 | assert np.all(arr[util.create_slice_from_str(string)] == subarr), f"{util.create_slice_from_str(string)}, {subarr}"
24 |
25 | invalid_idx_strings = [
26 | '',
27 | ]
28 | for idx in invalid_idx_strings:
29 | with pytest.raises(ValueError):
30 | subarry = arr[util.create_slice_from_str(idx)]
31 |
32 | def test_recursive_to_device():
33 | if not torch.cuda.is_available():
34 | logging.warning('CUDA not available, skipping recursive_to_device test.')
35 | return
36 |
37 | tensor = torch.rand(10,10,10)
38 | tensor = util.recursive_to_device(tensor, torch.device('cuda'))
39 | assert tensor.device.type == 'cuda'
40 |
41 | iterable = [tensor, [[tensor], tensor]]
42 |
43 | recursive_cuda_check(iterable)
44 |
45 | def recursive_cuda_check(iterable):
46 | if isinstance(iterable, torch.Tensor):
47 | assert iterable.device.type == 'cuda'
48 | elif isinstance(iterable, Iterable):
49 | for item in iterable:
50 | recursive_cuda_check(item)
--------------------------------------------------------------------------------
/tests/test_transformer_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unseal.transformers_util as tutil
3 | from unseal.hooks import HookedModel
4 |
5 | def test_load_model():
6 | model, tokenizer, config = tutil.load_from_pretrained('gpt2')
7 | assert model is not None
8 | assert tokenizer is not None
9 | assert config is not None
10 |
11 | def test_load_model_with_dir():
12 | model, tokenizer, config = tutil.load_from_pretrained('gpt-neo-125M', model_dir='EleutherAI')
13 | assert model is not None
14 | assert tokenizer is not None
15 | assert config is not None
16 |
17 | def test_load_model_eleuther_without_dir():
18 | model, tokenizer, config = tutil.load_from_pretrained('gpt-neo-125M')
19 | assert model is not None
20 | assert tokenizer is not None
21 | assert config is not None
22 |
23 | def test_load_model_with_low_mem():
24 | model, tokenizer, config = tutil.load_from_pretrained('gpt2', low_cpu_mem_usage=True)
25 | assert model is not None
26 | assert tokenizer is not None
27 | assert config is not None
28 |
29 | def test_get_num_layers_gpt2():
30 | model, *_ = tutil.load_from_pretrained('gpt2')
31 | model = HookedModel(model)
32 | assert tutil.get_num_layers(model, 'transformer->h') == 12
33 |
34 | def test_get_num_layers_transformer():
35 | model = torch.nn.Transformer(d_model=10, nhead=2, num_encoder_layers=0, num_decoder_layers=10)
36 | model = HookedModel(model)
37 | assert tutil.get_num_layers(model, 'decoder->layers')
38 |
--------------------------------------------------------------------------------
/unseal/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
2 |
3 | __version__ = '0.2.1'
--------------------------------------------------------------------------------
/unseal/circuits/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import math
3 | from typing import Tuple, TypeVar
4 |
5 | import einops
6 | import torch
7 | from torch import Tensor
8 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
9 | from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoSelfAttention
10 | from transformers.models.gptj.modeling_gptj import GPTJAttention
11 |
12 | Attention = TypeVar('Attention', GPT2Attention, GPTNeoSelfAttention, GPTJAttention)
13 |
14 | def get_qkv_weights(attention_module: Attention) -> Tuple[Tensor, Tensor, Tensor]:
15 |     if isinstance(attention_module, GPT2Attention):
16 |         # GPT2 packs Q, K and V into a single Conv1D whose weight has shape (embed_dim, 3 * embed_dim)
17 |         q, k, v = attention_module.c_attn.weight.chunk(3, dim=1)
18 |         # transpose to the (out_features, in_features) convention of the nn.Linear branch below
19 |         q, k, v = q.T, k.T, v.T
20 | else:
21 | q, k, v = attention_module.q_proj.weight, attention_module.k_proj.weight, attention_module.v_proj.weight
22 |
23 | q = einops.rearrange(q, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim)
24 | k = einops.rearrange(k, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim)
25 | v = einops.rearrange(v, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim)
26 |
27 | return q, k, v
28 |
29 | def get_o_weight(attention_module: Attention) -> Tensor:
30 |     if not isinstance(attention_module, (GPT2Attention, GPTNeoSelfAttention, GPTJAttention)):
31 | raise TypeError(f"{attention_module} is not an instance of Attention, has type {type(attention_module)}")
32 |
33 | if isinstance(attention_module, GPT2Attention):
34 | return einops.rearrange(attention_module.c_proj.weight, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim)
35 | else:
36 | return einops.rearrange(attention_module.out_proj.weight, '(num_heads head_dim) out_dim -> num_heads head_dim out_dim', head_dim=attention_module.head_dim)
37 |
38 | def composition(a: Tensor, b: Tensor):
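    """Composition strength ||a^T b||_F / (||a||_F * ||b||_F), the measure behind Q-, K- and V-composition in Anthropic's transformer-circuits framework."""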
39 | return (a.T @ b).norm(p='fro') / (a.norm(p='fro') * b.norm(p='fro'))
40 |
41 | def q_composition(qk: Tensor, ov: Tensor):
42 | return composition(qk, ov)
43 |
44 | def k_composition(qk: Tensor, ov: Tensor):
45 | return composition(qk.T, ov)
46 |
47 | def v_composition(ov_2: Tensor, ov_1: Tensor):
48 | return composition(ov_2.T, ov_1)
49 |
50 | def get_init_limits(weight: Tensor) -> Tuple[float, float]:
51 | # compute xavier uniform initialization limits
52 | return (-math.sqrt(6/(weight.shape[0] + weight.shape[1])), math.sqrt(6/(weight.shape[0] + weight.shape[1])))
53 |
54 | def approx_baseline(shape_1, shape_2, limits_1, limits_2, num_samples, device='cpu'):
55 | baseline = 0
56 | for i in range(num_samples):
57 | mat_1 = torch.distributions.Uniform(limits_1[0], limits_1[1]).rsample(shape_1).to(device)
58 | mat_2 = torch.distributions.Uniform(limits_2[0], limits_2[1]).rsample(shape_2).to(device)
59 |
60 | baseline += composition(mat_1, mat_2)
61 |
62 | return baseline / num_samples
63 |
64 | def compute_all_baselines(attention_module: Attention, num_samples):
65 | device = next(attention_module.parameters()).device
66 |
67 | q, k, v = get_qkv_weights(attention_module)
68 | o = get_o_weight(attention_module)
69 | # print(f"q: {q.shape}")
70 | # print(f"k: {k.shape}")
71 | # print(f"v: {v.shape}")
72 | # print(f"o: {o.shape}")
73 | qk_shape = (q.shape[1],) + (k.shape[1],)
74 | ov_shape = (o.shape[1],) + (v.shape[1],)
75 | # print(f"{qk_shape = }")
76 | # print(f"{ov_shape = }")
77 |     qk_limits = get_init_limits(q)  # per-head Q weights; mirrors the treatment of o below
78 | ov_limits = get_init_limits(o)
79 |
80 | qk_baseline = approx_baseline(qk_shape, ov_shape, qk_limits, ov_limits, num_samples, device)
81 | v_baseline = approx_baseline(ov_shape, ov_shape, ov_limits, ov_limits, num_samples, device)
82 |
83 | return qk_baseline, v_baseline
84 |
85 | @torch.no_grad()
86 | def compute_all_compositions(attn_1: Attention, attn_2: Attention, num_samples: int = 1000, subtract_baseline: bool = False):
87 | if subtract_baseline:
88 | qk_baseline, v_baseline = compute_all_baselines(attn_2, num_samples)
89 | else:
90 | qk_baseline, v_baseline = 0, 0
91 | # print(f"{qk_baseline = }")
92 | # print(f"{v_baseline = }")
93 |
94 | q_1, k_1, v_1 = get_qkv_weights(attn_1)
95 | # print(f"v_1: {v_1.shape}")
96 | o_1 = get_o_weight(attn_1)
97 | # print(f"o_1: {o_1.shape}")
98 | ov_1 = torch.einsum('abc, acd -> abd', einops.rearrange(o_1, 'a c b -> a b c'), v_1)
99 | qk_1 = torch.einsum('abc, acd -> abd', einops.rearrange(q_1, 'a c b -> a b c'), k_1)
100 | # print(f"{qk_1.shape = }")
101 | # print(f"{ov_1.shape = }")
102 | q_2, k_2, v_2 = get_qkv_weights(attn_2)
103 | o_2 = get_o_weight(attn_2)
104 | ov_2 = torch.einsum('abc, acd -> abd', einops.rearrange(o_2, 'a c b -> a b c'), v_2)
105 | qk_2 = torch.einsum('abc, acd -> abd', einops.rearrange(q_2, 'a c b -> a b c'), k_2)
106 | # print(f"{qk_2.shape = }")
107 | # print(f"{ov_2.shape = }")
108 |
109 | q_comps = []
110 | k_comps = []
111 | v_comps = []
112 |
113 | for head_1, head_2 in itertools.product(range(attn_1.num_heads), range(attn_2.num_heads)):
114 | q_comps.append(q_composition(qk_2[head_2], ov_1[head_1]))
115 | k_comps.append(k_composition(qk_2[head_2], ov_1[head_1]))
116 | v_comps.append(v_composition(ov_2[head_2], ov_1[head_1]))
117 |
118 | q_comps = torch.stack(q_comps) - qk_baseline
119 | k_comps = torch.stack(k_comps) - qk_baseline
120 | v_comps = torch.stack(v_comps) - v_baseline
121 |
122 |
123 | return q_comps.clamp(min=0).cpu(), k_comps.clamp(min=0).cpu(), v_comps.clamp(min=0).cpu()
124 |
125 |
--------------------------------------------------------------------------------
/unseal/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | from . import common_hooks
2 | from . import util
3 | from .commons import Hook, HookedModel
--------------------------------------------------------------------------------
/unseal/hooks/common_hooks.py:
--------------------------------------------------------------------------------
1 | # pre-implemented common hooks
2 | import logging
3 | import math
4 | from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
5 |
6 | import einops
7 | import torch
8 | import torch.nn.functional as F
9 | from tqdm import tqdm
10 |
11 | from . import util
12 | from .commons import Hook, HookedModel
13 |
14 |
15 | def save_output(cpu: bool = True, detach: bool = True) -> Callable:
16 | """Basic hooking function for saving the output of a module to the global context object
17 |
18 | :param cpu: Whether to save the output to cpu.
19 | :type cpu: bool
20 | :param detach: Whether to detach the output.
21 | :type detach: bool
22 | :return: Function that saves the output to the context object.
23 | :rtype: Callable
24 | """
25 | def inner(save_ctx, input, output):
26 | if detach:
27 | if isinstance(output, torch.Tensor):
28 | output = output.detach()
29 | else:
30 | logging.warning('Detaching tensor for iterables not implemented')
31 | # not implemented
32 | if cpu:
33 | if isinstance(output, torch.Tensor):
34 | output = output.cpu()
35 | elif isinstance(output, Iterable): # hope for the best
36 | output = util.recursive_to_device(output, 'cpu')
37 | else:
38 | raise TypeError(f"Unsupported type for output {type(output)}")
39 | save_ctx['output'] = output
40 |
41 | return inner
42 |
43 | def replace_activation(indices: str, replacement_tensor: torch.Tensor, tuple_index: Optional[int] = None) -> Callable:
44 | """Creates a hook which replaces a module's activation (output) with a replacement tensor.
45 | If there is a dimension mismatch, the replacement tensor is copied along the leading dimensions of the output.
46 |
47 | Example: If the activation has shape ``(B, T, D)`` and replacement tensor has shape ``(D,)`` which you want to plug in
48 | at position t in the T dimension for every tensor in the batch, then indices should be ``:,t,:``.
49 |
50 | :param indices: Indices at which to insert the replacement tensor
51 | :type indices: str
52 | :param replacement_tensor: Tensor that is filled in.
53 | :type replacement_tensor: torch.Tensor
54 | :param tuple_index: Index of the tuple in the output of the module.
55 | :type tuple_index: int
56 | :return: Function that replaces part of a given tensor with replacement_tensor
57 | :rtype: Callable
58 | """
59 | slice_ = util.create_slice_from_str(indices)
60 | def func(save_ctx, input, output):
61 | if tuple_index is None:
62 | # add dummy dimensions if shape mismatch
63 | diff = len(output[slice_].shape) - len(replacement_tensor.shape)
64 | rep = replacement_tensor[(None,)*diff].to(input.device)
65 | # replace part of tensor
66 | output[slice_] = rep
67 | else:
68 | # add dummy dimensions if shape mismatch
69 | diff = len(output[tuple_index][slice_].shape) - len(replacement_tensor.shape)
70 | rep = replacement_tensor[(None,)*diff].to(input.device)
71 | # replace part of tensor
72 | output[tuple_index][slice_] = rep
73 | return output
74 |
75 | return func
76 |
77 | def transformers_get_attention(
78 | heads: Optional[Union[int, Iterable[int], str]] = None,
79 | output_idx: Optional[int] = None,
80 | ) -> Callable:
81 | """Creates a hooking function to get the attention patterns of a given layer.
82 |
83 | :param heads: The heads for which to save the attention, defaults to None
84 | :type heads: Optional[Union[int, Iterable[int], str]], optional
85 | :param output_idx: If the attention module returns a tuple, use this argument to index it, defaults to None
86 | :type output_idx: Optional[int], optional
87 | :return: func, hooking function that saves attention of the specified heads
88 | :rtype: Callable
89 | """
90 |
91 | # convert head string to slice
92 | if heads is None:
93 | heads = ":"
94 | if isinstance(heads, str):
95 | heads = util.create_slice_from_str(heads)
96 |
97 | def func(save_ctx, input, output):
98 | if output_idx is None:
99 | save_ctx['attn'] = output[:,heads,...].detach().cpu()
100 | else:
101 | save_ctx['attn'] = output[output_idx][:,heads,...].detach().cpu()
102 |
103 | return func
104 |
105 | def create_attention_hook(
106 | layer: int,
107 | key: str,
108 | output_idx: Optional[int] = None,
109 | attn_name: Optional[str] = 'attn',
110 | layer_key_prefix: Optional[str] = None,
111 | heads: Optional[Union[int, Iterable[int], str]] = None
112 | ) -> Hook:
113 | """Creates a hook which saves the attention patterns of a given layer.
114 |
115 | :param layer: The layer to hook.
116 | :type layer: int
117 | :param key: The key to use for saving the attention patterns.
118 | :type key: str
119 | :param output_idx: If the module output is a tuple, index it with this. GPT like models need this to be equal to 2, defaults to None
120 | :type output_idx: Optional[int], optional
121 | :param attn_name: The name of the attention module in the transformer, defaults to 'attn'
122 | :type attn_name: Optional[str], optional
123 |     :param layer_key_prefix: The prefix in the model structure before the layer idx, including the trailing arrow, e.g. 'transformer->h->', defaults to None
124 | :type layer_key_prefix: Optional[str], optional
125 | :param heads: Which heads to save the attention pattern for. Can be int, tuple of ints or string like '1:3', defaults to None
126 | :type heads: Optional[Union[int, Iterable[int], str]], optional
127 | :return: Hook which saves the attention patterns
128 | :rtype: Hook
129 | """
130 | if layer_key_prefix is None:
131 | layer_key_prefix = ""
132 |
133 | func = transformers_get_attention(heads, output_idx)
134 | return Hook(f'{layer_key_prefix}{layer}->{attn_name}', func, key)
135 |
136 |
137 | def create_logit_hook(
138 | layer:int,
139 | model: HookedModel,
140 | unembedding_key: str,
141 | layer_key_prefix: Optional[str] = None,
142 | target: Optional[Union[int, List[int]]] = None,
143 | position: Optional[Union[int, List[int]]] = None,
144 | key: Optional[str] = None,
145 | split_heads: Optional[bool] = False,
146 | num_heads: Optional[int] = None,
147 | ) -> Hook:
148 | """Create a hook that saves the logits of a layer's output.
149 | Outputs are saved to save_ctx[key]['logits'].
150 |
151 | :param layer: The number of the layer
152 | :type layer: int
153 | :param model: The model.
154 | :type model: HookedModel
155 | :param unembedding_key: The key/name of the embedding matrix, e.g. 'lm_head' for causal LM models
156 | :type unembedding_key: str
157 |     :param layer_key_prefix: The prefix of the key of the layer, including the trailing arrow, e.g. 'transformer->h->' for GPT-like models
158 | :type layer_key_prefix: str
159 | :param target: The target token(s) to extract logits for. Defaults to all tokens.
160 | :type target: Union[int, List[int]]
161 | :param position: The position for which to extract logits for. Defaults to all positions.
162 | :type position: Union[int, List[int]]
163 | :param key: The key of the hook. Defaults to {layer}_logits.
164 | :type key: str
165 | :param split_heads: Whether to split the heads. Defaults to False.
166 | :type split_heads: bool
167 | :param num_heads: The number of heads to split. Defaults to None.
168 | :type num_heads: int
169 | :return: The hook.
170 | :rtype: Hook
171 | """
172 | if layer_key_prefix is None:
173 | layer_key_prefix = ""
174 |
175 | # generate slice
176 | if target is None:
177 | target = ":"
178 | else:
179 | if isinstance(target, int):
180 | target = str(target)
181 | else:
182 | target = "[" + ",".join(str(t) for t in target) + "]"
183 | if position is None:
184 | position = ":"
185 | else:
186 | if isinstance(position, int):
187 | position = str(position)
188 | else:
189 | position = "[" + ",".join(str(p) for p in position) + "]"
190 | position_slice = util.create_slice_from_str(f":,{position},:")
191 | target_slice = util.create_slice_from_str(f"{target},:")
192 |
193 | # load the relevant part of the vocab matrix
194 | vocab_matrix = model.layers[unembedding_key].weight[target_slice].T
195 |
196 | # split vocab matrix
197 | if split_heads:
198 | if num_heads is None:
199 | raise ValueError("num_heads must be specified if split_heads is True")
200 | else:
201 | vocab_matrix = einops.rearrange(vocab_matrix, '(num_heads head_dim) vocab_size -> num_heads head_dim vocab_size', num_heads=num_heads)
202 |
203 | def inner(save_ctx, input, output):
204 | if split_heads:
205 | einsum_in = einops.rearrange(output[0][position_slice], 'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim', num_heads=num_heads)
206 | einsum_out = torch.einsum('bcij,cjk->bcik', einsum_in, vocab_matrix)
207 | else:
208 | einsum_in = output[0][position_slice]
209 | einsum_out = torch.einsum('bij,jk->bik', einsum_in, vocab_matrix)
210 |
211 | save_ctx['logits'] = einsum_out.detach().cpu()
212 |
213 | # write key
214 | if key is None:
215 | key = str(layer) + '_logits'
216 |
217 | # create hook
218 | hook = Hook(f'{layer_key_prefix}{layer}', inner, key)
219 |
220 | return hook
221 |
222 | def gpt_attn_wrapper(
223 | func: Callable,
224 | save_ctx: Dict,
225 | c_proj: torch.Tensor,
226 | vocab_embedding: torch.Tensor,
227 | target_ids: torch.Tensor,
228 | batch_size: Optional[int] = None,
229 | ) -> Tuple[Callable, Callable]:
230 | """Wraps around the [AttentionBlock]._attn function to save the individual heads' logits.
231 | This is necessary because the individual heads' logits are not available on a module level and thus not accessible via a hook.
232 |
233 | :param func: original _attn function
234 | :type func: Callable
235 | :param save_ctx: context to which the logits will be saved
236 | :type save_ctx: Dict
237 | :param c_proj: projection matrix, this is W_O in Anthropic's terminology
238 | :type c_proj: torch.Tensor
239 |     :param vocab_embedding: vocabulary/unembedding matrix, this is W_U in Anthropic's terminology
240 |     :type vocab_embedding: torch.Tensor
241 | :param target_ids: indices of the target tokens for which the logits are computed
242 | :type target_ids: torch.Tensor
243 | :param batch_size: batch size to reduce memory footprint, defaults to None
244 | :type batch_size: Optional[int]
245 | :return: inner, func, the wrapped function and the original function
246 | :rtype: Tuple[Callable, Callable]
247 | """
248 | # TODO Find a smarter/more efficient way of implementing this function
249 | # TODO clean up this function
250 | def inner(query, key, value, *args, **kwargs):
251 | nonlocal c_proj
252 | nonlocal target_ids
253 | nonlocal vocab_embedding
254 | nonlocal batch_size
255 | attn_output, attn_weights = func(query, key, value, *args, **kwargs)
256 | if batch_size is None:
257 | batch_size = attn_output.shape[0]
258 | with torch.no_grad():
259 | temp = attn_weights[...,None] * value[:,:,None]
260 | if len(c_proj.shape) == 2:
261 | c_proj = einops.rearrange(c_proj, '(head_dim num_heads) out_dim -> head_dim num_heads out_dim', num_heads=attn_output.shape[1])
262 | c_proj = einops.rearrange(c_proj, 'h n o -> n h o')
263 | temp = temp[0,:,:-1,:-1] # could this be done earlier?
264 | new_temp = []
265 | for head in tqdm(range(temp.shape[0])):
266 | new_temp.append([])
267 | for i in range(math.ceil(temp.shape[1] / batch_size)):
268 | out = temp[head, i*batch_size:(i+1)*batch_size] @ c_proj[head]
269 | out = out @ vocab_embedding # compute logits
270 | new_temp[-1].append(out)
271 |
272 | # center logits
273 | new_temp[-1] = torch.cat(new_temp[-1])
274 | new_temp[-1] -= torch.mean(new_temp[-1], dim=-1, keepdim=True)
275 | # select targets
276 | new_temp[-1] = new_temp[-1][...,torch.arange(len(target_ids)), :, target_ids]#.to('cpu')
277 |
278 | new_temp = torch.stack(new_temp, dim=0) # stack heads
279 | max_pos_value = torch.amax(new_temp).item()
280 | max_neg_value = torch.amax(-new_temp).item()
281 |
282 | save_ctx['logits'] = {
283 | 'pos': (new_temp/max_pos_value).clamp(min=0, max=1).detach().cpu(),
284 | 'neg': (-new_temp/max_neg_value).clamp(min=0, max=1).detach().cpu(),
285 | }
286 | return attn_output, attn_weights
287 | return inner, func
288 |
289 | def additive_output_noise(
290 |     indices: str,
291 |     mean: Optional[float] = 0,
292 |     std: Optional[float] = 0.1
293 | ) -> Callable:
294 |     """Creates a hooking function that adds Gaussian noise with the given mean and std to the part of a module's output selected by ``indices``."""
295 | slice_ = util.create_slice_from_str(indices)
296 | def func(save_ctx, input, output):
297 | noise = mean + std * torch.randn_like(output[slice_])
298 | output[slice_] += noise
299 | return output
300 | return func
301 |
302 | def hidden_patch_hook_fn(
303 | position: int,
304 | replacement_tensor: torch.Tensor,
305 | ) -> Callable:
306 | indices = "...," + str(position) + len(replacement_tensor.shape) * ",:"
307 | inner = replace_activation(indices, replacement_tensor)
308 | def func(save_ctx, input, output):
309 | output[0][...] = inner(save_ctx, input, output[0])
310 | return output
311 |     return func
--------------------------------------------------------------------------------
/unseal/hooks/commons.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from dataclasses import dataclass
3 | from inspect import signature
4 | from typing import List, Callable
5 |
6 | import torch
7 |
8 | from . import util
9 |
10 | class Hook:
11 | layer_name: str
12 | func: Callable
13 | key: str
14 | def __init__(self, layer_name: str, func: Callable, key: str):
15 | # check that func takes three arguments
16 | if len(signature(func).parameters) != 3:
17 | raise TypeError(f'Hook function {func.__name__} should have three arguments, but has {len(signature(func).parameters)}.')
18 |
19 | self.layer_name = layer_name
20 | self.func = func
21 | self.key = key
22 |
23 | class HookedModel(torch.nn.Module):
24 | def __init__(self, model):
25 | """Wrapper around a module that allows forward passes with hooks and a context object.
26 |
27 | :param model: Model to be hooked
28 | :type model: nn.Module
29 | :raises TypeError: Incorrect model type
30 | """
31 | super().__init__()
32 |
33 | # check inputs
34 | if not isinstance(model, torch.nn.Module):
35 | raise TypeError(f"model should be type torch.nn.Module but is {type(model)}")
36 |
37 | self.model = model
38 |
39 | # initialize hooks
40 | self.init_refs()
41 |
42 | # init context for accessing hook output
43 | self.save_ctx = dict()
44 |
45 | def init_refs(self):
46 | """Creates references for every layer in a model."""
47 |
48 | self.layers = OrderedDict()
49 |
50 | # recursive naming function
51 | def name_layers(net, prefix=[]):
52 | if hasattr(net, "_modules"):
53 | for name, layer in net._modules.items():
54 | if layer is None:
55 | # e.g. GoogLeNet's aux1 and aux2 layers
56 | continue
57 | self.layers["->".join(prefix + [name])] = layer
58 | name_layers(layer, prefix=prefix + [name])
59 | else:
60 | raise ValueError('net has no _modules attribute! Check that your model is properly instantiated.')
61 |
62 | name_layers(self.model)
63 |
64 |
65 | def forward(
66 | self,
67 | input_ids: torch.Tensor,
68 | hooks: List[Hook],
69 | *args,
70 | **kwargs,
71 | ):
72 | """Wrapper around the default forward pass that temporarily registers hooks, executes the forward pass and then closes hooks again.
73 | """
74 | # register hooks
75 | registered_hooks = []
76 | for hook in hooks:
77 | layer = self.layers.get(hook.layer_name, None)
78 | if layer is None:
79 | raise ValueError(f'Layer {hook.layer_name} was not found during hook registration! Here is the whole model for reference:\n{self}')
80 | self.save_ctx[hook.key] = dict() # create sub-context for each hook to write to
81 | registered_hooks.append(layer.register_forward_hook(self._hook_wrapper(hook.func, hook.key)))
82 |
83 | # forward
84 | output = self.model(input_ids, *args, **kwargs) #TODO generalize to non-HF models which would not have an input_ids kwarg
85 |
86 | # remove hooks
87 | for hook in registered_hooks:
88 | hook.remove()
89 |
90 | return output
91 |
92 | def _hook_wrapper(self, func, hook_key):
93 | """Wrapper to comply with PyTorch's hooking API while enabling saving to context.
94 |
95 | :param func: Hook function that takes (save_ctx, input, output).
96 | :type func: Callable
97 | :param hook_key: Key of the sub-context in self.save_ctx that the hook writes to.
98 | :type hook_key: str
99 | :return: Function matching PyTorch's forward-hook signature (module, input, output).
100 | :rtype: Callable
101 | """
102 | return lambda model, input, output: func(save_ctx=self.save_ctx[hook_key], input=input[0], output=output)
103 |
104 | def get_ctx_keys(self):
105 | return list(self.save_ctx.keys())
106 |
107 | def __repr__(self):
108 | return self.model.__repr__()
109 |
110 | @property
111 | def device(self):
112 | return next(self.model.parameters()).device
--------------------------------------------------------------------------------
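
A hedged sketch of the core HookedModel workflow: register a hook that writes into save_ctx, run a forward pass, then read the result. The layer name 'transformer->h->0' assumes a GPT-2 style model.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from unseal.hooks.commons import Hook, HookedModel

    model = HookedModel(AutoModelForCausalLM.from_pretrained('gpt2'))
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    def save_output(save_ctx, input, output):
        # hook functions take (save_ctx, input, output); stash the block's hidden states
        save_ctx['hidden'] = output[0].detach()

    hook = Hook(layer_name='transformer->h->0', func=save_output, key='block0')
    model.forward(tokenizer.encode('Hello world', return_tensors='pt'), hooks=[hook])
    print(model.save_ctx['block0']['hidden'].shape)  # (1, num_tokens, hidden_dim)
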
/unseal/hooks/rome_hooks.py:
--------------------------------------------------------------------------------
1 | #TODO for version 1.0.0: Remove this file
2 | import logging
3 |
4 | from .common_hooks import additive_output_noise, hidden_patch_hook_fn
5 |
6 | logging.warning("rome_hooks.py is deprecated and will be removed in version 1.0.0. Please use common_hooks.py instead.")
7 |
8 |
--------------------------------------------------------------------------------
/unseal/hooks/util.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Iterable, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | def create_slice_from_str(indices: str) -> Union[slice, Tuple]:
9 | """Creates a slice object from a string representing the slice.
10 |
11 | :param indices: String representing the slice, e.g. ``...,3:5,:``
12 | :type indices: str
13 | :return: Index object corresponding to the input string: a single slice, or a tuple of slices/Ellipsis for multi-axis strings.
14 | :rtype: Union[slice, Tuple]
15 | """
16 | if len(indices) == 0:
17 | raise ValueError('Empty string is not a valid slice.')
18 | return eval(f'np.s_[{indices}]')  # evaluates the raw string, so only pass trusted input
19 |
20 |
21 | def recursive_to_device(
22 | iterable: Union[Iterable, torch.Tensor],
23 | device: Union[str, torch.device],
24 | ) -> Iterable:
25 | """Recursively puts an Iterable of (Iterable of (...)) tensors on the given device
26 |
27 | :param iterable: Tensor or Iterable of tensors or iterables of ...
28 | :type iterable: Tensor or Iterable
29 | :param device: Device on which to put the object
30 | :type device: Union[str, torch.device]
31 | :raises TypeError: Unexpected types
32 | :return: Nested iterable with the tensors on the new device
33 | :rtype: Iterable
34 | """
35 | if isinstance(iterable, torch.Tensor):
36 | return iterable.to(device)
37 |
38 | new = []
39 | for i, item in enumerate(iterable):
40 | if isinstance(item, torch.Tensor):
41 | new.append(item.to(device))
42 | elif isinstance(item, Iterable):
43 | new.append(recursive_to_device(item, device))
44 | else:
45 | raise TypeError(f'Expected type tensor or Iterable but got {type(item)}.')
46 | if isinstance(iterable, tuple):
47 | new = tuple(new)
48 | return new
--------------------------------------------------------------------------------
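
For illustration, create_slice_from_str applied to a small tensor (values arbitrary):

    import torch
    from unseal.hooks.util import create_slice_from_str

    t = torch.arange(24).reshape(2, 3, 4)
    idx = create_slice_from_str('...,1:3')  # equivalent to np.s_[..., 1:3]
    print(t[idx].shape)  # torch.Size([2, 3, 2])
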
/unseal/logit_lense.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from transformers import AutoTokenizer
7 |
8 | from .transformers_util import get_num_layers
9 | from .hooks.common_hooks import create_logit_hook
10 | from .hooks.commons import HookedModel
11 |
12 | def generate_logit_lense(
13 | model: HookedModel,
14 | tokenizer: AutoTokenizer,
15 | sentence: str,
16 | layers: Optional[List[int]] = None,
17 | ranks: Optional[bool] = False,
18 | kl_div: Optional[bool] = False,
19 | include_input: Optional[bool] = False,
20 | layer_key_prefix: Optional[str] = None,
21 | ):
22 | """Generates the necessary data to generate the plots from the logits `lense post
23 | `_.
24 |
25 | Returns None for ranks and kl_div if not specified.
26 |
27 | :param model: Model that is investigated.
28 | :type model: HookedModel
29 | :param tokenizer: Tokenizer of the model.
30 | :type tokenizer: AutoTokenizer
31 | :param sentence: Sentence to be analyzed.
32 | :type sentence: str
33 | :param layers: List of layers to be investigated.
34 | :type layers: Optional[List[int]]
35 | :param ranks: Whether to return ranks of the correct token throughout layers, defaults to False
36 | :type ranks: Optional[bool], optional
37 | :param kl_div: Whether to return the KL divergence between intermediate probabilities and final output probabilities, defaults to False
38 | :type kl_div: Optional[bool], optional
39 | :param include_input: Whether to include the immediate logits/ranks/kld after embedding the input, defaults to False
40 | :type include_input: Optional[bool], optional
41 | :param layer_key_prefix: Prefix for the layer keys, e.g. 'transformer->h' for GPT like models, defaults to None
42 | :type layer_key_prefix: Optional[str], optional
43 | :return: logits, ranks, kl_div
44 | :rtype: Tuple[torch.Tensor]
45 | """
46 |
47 | # TODO
48 | if include_input:
49 | logging.warning("include_input is not implemented yet")
50 |
51 | # prepare model input
52 | tokenized_sentence = tokenizer.encode(sentence, return_tensors='pt').to(model.device)
53 | targets = tokenizer.encode(sentence)[1:]
54 |
55 | # instantiate hooks
56 | num_layers = get_num_layers(model, layer_key_prefix=layer_key_prefix)
57 | if layers is None:
58 | layers = list(range(num_layers))
59 | logit_hooks = [create_logit_hook(layer, model, 'lm_head', layer_key_prefix) for layer in layers]
60 |
61 | # run model
62 | model.forward(tokenized_sentence, hooks=logit_hooks)
63 | logits = torch.stack([model.save_ctx[str(layer) + '_logits']['logits'] for layer in layers], dim=0)  # only hooked layers have save_ctx entries
64 | logits = F.log_softmax(logits, dim=-1)
65 |
66 | # compute ranks and kld
67 | if ranks:
68 | inverted_ranks = torch.argsort(logits, dim=-1, descending=True)
69 | ranks = torch.argsort(inverted_ranks, dim=-1) + 1
70 | ranks = ranks[:, torch.arange(len(targets)), targets]
71 | else:
72 | ranks = None
73 |
74 | if kl_div: # Note: logits are already normalized internally by the logit_hook
75 | kl_div = F.kl_div(logits, logits[-1][None], reduction='none', log_target=True).sum(dim=-1)
76 | kl_div = kl_div[:, torch.arange(len(targets))]  # (num_layers, seq_len) after the vocab sum; keep positions that have a target
77 | else:
78 | kl_div = None
79 |
80 | logits = logits[:, torch.arange(len(targets)), targets]
81 |
82 | return logits, ranks, kl_div
83 |
84 |
85 |
--------------------------------------------------------------------------------
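
A minimal sketch of running the logit lens end to end, using the loading helper from transformers_util.py below; 'transformer->h' assumes a GPT-2 style layer layout.

    from unseal.hooks.commons import HookedModel
    from unseal.logit_lense import generate_logit_lense
    from unseal.transformers_util import load_from_pretrained

    model, tokenizer, _ = load_from_pretrained('gpt2')
    model = HookedModel(model).eval()

    logits, ranks, kl_div = generate_logit_lense(
        model, tokenizer, 'The quick brown fox',
        ranks=True, kl_div=True, layer_key_prefix='transformer->h',
    )
    print(logits.shape)  # (num_layers, num_targets) log-probs of the correct next tokens
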
/unseal/transformers_util.py:
--------------------------------------------------------------------------------
1 | # utility functions for interacting with huggingface's transformers library
2 | import logging
3 | import os
4 | from typing import Optional, Tuple
5 |
6 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
7 | from transformers.file_utils import RepositoryNotFoundError
8 |
9 | from .hooks.commons import HookedModel
10 |
11 | def load_from_pretrained(
12 | model_name: str,
13 | model_dir: Optional[str] = None,
14 | load_model: Optional[bool] = True,
15 | load_tokenizer: Optional[bool] = True,
16 | load_config: Optional[bool] = True,
17 | low_cpu_mem_usage: Optional[bool] = False,
18 | ) -> Tuple[AutoModelForCausalLM, AutoTokenizer, AutoConfig]:
19 | """Load a pretrained model from huggingface's transformer library
20 |
21 | :param model_name: Name of the model, e.g. `gpt2` or `gpt-neo`.
22 | :type model_name: str
23 | :param model_dir: Directory in which to look for the model, e.g. `EleutherAI`, defaults to None
24 | :type model_dir: Optional[str], optional
25 | :param load_model: Whether to load the model itself, defaults to True
26 | :type load_model: Optional[bool], optional
27 | :param load_tokenizer: Whether to load the tokenizer, defaults to True
28 | :type load_tokenizer: Optional[bool], optional
29 | :param load_config: Whether to load the config file, defaults to True
30 | :type load_config: Optional[bool], optional
31 | :param low_cpu_mem_usage: Whether to use low-memory mode, experimental feature of HF, defaults to False
32 | :type low_cpu_mem_usage: bool, optional
33 | :return: model, tokenizer, config. Returns None values for those elements which were not loaded.
34 | :rtype: Tuple[AutoModelForCausalLM, AutoTokenizer, AutoConfig]
35 | """
36 | if model_dir is None:
37 | try:
38 | logging.info(f'Loading model {model_name}')
39 |
40 | model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None
41 | tokenizer = AutoTokenizer.from_pretrained(model_name) if load_tokenizer else None
42 | config = AutoConfig.from_pretrained(model_name) if load_config else None
43 |
44 | except (RepositoryNotFoundError, OSError) as error:
45 | logging.warning("Couldn't find model in default folder. Trying EleutherAI/...")
46 |
47 | model = AutoModelForCausalLM.from_pretrained(f'EleutherAI/{model_name}', low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None
48 | tokenizer = AutoTokenizer.from_pretrained(f'EleutherAI/{model_name}') if load_tokenizer else None
49 | config = AutoConfig.from_pretrained(f'EleutherAI/{model_name}') if load_config else None
50 |
51 | else:
52 | model = AutoModelForCausalLM.from_pretrained(os.path.join(model_dir, model_name), low_cpu_mem_usage=low_cpu_mem_usage) if load_model else None
53 | tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_dir, model_name)) if load_tokenizer else None
54 | config = AutoConfig.from_pretrained(os.path.join(model_dir, model_name)) if load_config else None
55 |
56 | return model, tokenizer, config
57 |
58 | def get_num_layers(model: HookedModel, layer_key_prefix: Optional[str] = None) -> int:
59 | """Get the number of layers in a model
60 |
61 | :param model: The model to get the number of layers from
62 | :type model: HookedModel
63 | :param layer_key_prefix: Key under which the model's layer list is registered, e.g. 'transformer->h' for GPT-like models, defaults to None
64 | :type layer_key_prefix: Optional[str], optional
65 | :return: The number of layers in the model
66 | :rtype: int
67 | """
68 | if layer_key_prefix is None:
69 | layer_key_prefix = ""
70 |
71 | return len(model.layers[layer_key_prefix])
--------------------------------------------------------------------------------
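
An illustrative call that exercises the EleutherAI/ fallback (model name taken from the HF_MODELS list used by the streamlit interfaces):

    from unseal.hooks.commons import HookedModel
    from unseal.transformers_util import load_from_pretrained, get_num_layers

    # 'gpt-neo-125M' is not a top-level HF repo, so the fallback retries 'EleutherAI/gpt-neo-125M'
    model, tokenizer, config = load_from_pretrained('gpt-neo-125M')
    hooked = HookedModel(model)
    print(get_num_layers(hooked, layer_key_prefix='transformer->h'))
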
/unseal/visuals/__init__.py:
--------------------------------------------------------------------------------
1 | from . import utils
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/__init__.py:
--------------------------------------------------------------------------------
1 | from .commons import *
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/all_layers_single_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 | import streamlit as st
5 |
6 | from unseal.visuals.streamlit_interfaces import utils
7 | from unseal.visuals.streamlit_interfaces import interface_setup as setup
8 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES
9 |
10 | # perform startup tasks
11 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json')
12 |
13 | # create sidebar
14 | with st.sidebar:
15 | setup.create_sidebar()
16 |
17 | sample = st.checkbox('Enable sampling', value=False, key='sample')
18 | if sample:
19 | setup.create_sample_sliders()
20 | setup.on_sampling_config_change()
21 |
22 | if "storage" not in st.session_state:
23 | st.session_state["storage"] = [""]
24 |
25 | # input 1
26 | placeholder1 = st.empty()
27 | placeholder1.text_area(label='Input 1', on_change=utils.on_text_change, key='input_text_1', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text_1'))
28 | if sample:
29 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text"), key="sample_text")
30 |
31 | # sometimes need to force a re-render
32 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=0))
33 |
34 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization)
35 | st.download_button(
36 | label='Download Visualization',
37 | data=f,
38 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json',
39 | mime='application/json',
40 | help='Download the visualizations as a json of html files.',
41 | key='download_button'
42 | )
43 |
44 | # show the html visualization
45 | if st.session_state.model is not None:
46 | cols = st.columns(1)
47 | for col_idx, col in enumerate(cols):
48 | if f"col_{col_idx}" in st.session_state.visualization:
49 | with col:
50 | for layer in range(st.session_state.num_layers):
51 | if f"layer_{layer}" in st.session_state.visualization[f"col_{col_idx}"]:
52 | with st.expander(f'Layer {layer}'):
53 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{layer}"], height=600)
54 | else:
55 | st.session_state.visualization[f"col_{col_idx}"] = dict()
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/commons.py:
--------------------------------------------------------------------------------
1 | # define some global variables
2 | HF_MODELS = [
3 | 'gpt2',
4 | 'gpt2-medium',
5 | 'gpt2-large',
6 | 'gpt2-xl',
7 | 'gpt-neo-125M',
8 | 'gpt-neo-1.3B',
9 | 'gpt-neo-2.7B',
10 | 'gpt-j-6b',
11 | ]
12 |
13 | SESSION_STATE_VARIABLES = [
14 | 'model',
15 | 'tokenizer',
16 | 'config',
17 | 'registered_models',
18 | 'registered_model_names',
19 | 'num_layers',
20 | ]
21 |
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/compare_two_inputs.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import streamlit as st
4 |
5 | from unseal.visuals.streamlit_interfaces import utils
6 | from unseal.visuals.streamlit_interfaces import interface_setup as setup
7 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES
8 |
9 | # perform startup tasks
10 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json')
11 |
12 | # create sidebar
13 | with st.sidebar:
14 | setup.create_sidebar()
15 |
16 | sample = st.checkbox('Enable sampling', value=False, key='sample')
17 | if sample:
18 | setup.create_sample_sliders()
19 | setup.on_sampling_config_change()
20 |
21 | if "storage" not in st.session_state:
22 | st.session_state["storage"] = ["", ""]
23 |
24 | # input 1
25 | placeholder1 = st.empty()
26 | placeholder1.text_area(label='Input 1', on_change=utils.on_text_change, key='input_text_1', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text_1'))
27 | if sample:
28 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text_1"), key="sample_text_1")
29 |
30 | # input 2
31 | placeholder2 = st.empty()
32 | placeholder2.text_area(label='Input 2', on_change=utils.on_text_change, key='input_text_2', value=st.session_state.storage[1], kwargs=dict(col_idx=1, text_key='input_text_2'))
33 | if sample:
34 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=1, key="input_text_2"), key="sample_text_2")
35 |
36 | # sometimes need to force a re-render
37 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=[0,1]))
38 |
39 | # download button
40 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization)
41 | st.download_button(
42 | label='Download Visualization',
43 | data=f,
44 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json',
45 | mime='application/json',
46 | help='Download the visualizations as a json of html files.',
47 | key='download_button'
48 | )
49 |
50 | # show the html visualization
51 | if st.session_state.model is not None:
52 | cols = st.columns(2)
53 | for col_idx, col in enumerate(cols):
54 | if f"col_{col_idx}" in st.session_state.visualization:
55 | with col:
56 | for layer in range(st.session_state.num_layers):
57 | if f"layer_{layer}" in st.session_state.visualization[f"col_{col_idx}"]:
58 | with st.expander(f'Layer {layer}'):
59 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{layer}"], height=600, scrolling=True)
60 | else:
61 | st.session_state.visualization[f"col_{col_idx}"] = dict()
62 |
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/interface_setup.py:
--------------------------------------------------------------------------------
1 | # Functions that are usually only called once during setup of the interface
2 | import importlib
3 | import json
4 | from typing import Optional, List, Tuple
5 |
6 | import streamlit as st
7 | import torch
8 |
9 | from .commons import HF_MODELS
10 | from ...transformers_util import get_num_layers, load_from_pretrained
11 | from ...hooks import HookedModel
12 |
13 | def create_model_config(model_names):
14 | with st.form('model_config'):
15 | st.write('## Model Config')
16 |
17 | if model_names is None:
18 | model_options = list()
19 | else:
20 | model_options = model_names
21 |
22 | st.selectbox(
23 | 'Model',
24 | options=model_options,
25 | key='model_name',
26 | index=0,
27 | )
28 |
29 | devices = ['cpu']
30 | if torch.cuda.is_available():
31 | devices += ['cuda']
32 | st.selectbox(
33 | 'Device',
34 | options=devices,
35 | index=0,
36 | key='device'
37 | )
38 |
39 | st.text_area(label='Prefix Prompt', key='prefix_prompt', value='')
40 |
41 | submitted = st.form_submit_button("Save model config")
42 | if submitted:
43 | st.session_state.model, st.session_state.tokenizer, st.session_state.config = on_config_submit(st.session_state.model_name)
44 | st.write('Model config saved!')
45 |
46 |
47 | def create_sidebar():
48 | st.checkbox('Show only local models', value=False, key='local_only')
49 |
50 | if not st.session_state.local_only:
51 | model_names = st.session_state.registered_model_names + HF_MODELS
52 | else:
53 | model_names = st.session_state.registered_model_names
54 |
55 | create_model_config(model_names)
56 |
57 | def create_sample_sliders():
58 | st.slider(label="Temperature", min_value=0., max_value=1.0, value=0., step=0.01, key='temperature', on_change=on_sampling_config_change)
59 | st.slider(label="Response length", min_value=1, max_value=1024, value=64, step=1, key='response_length', on_change=on_sampling_config_change)
60 | st.slider(label="Top P", min_value=0., max_value=1.0, value=1., step=0.01, key='top_p', on_change=on_sampling_config_change)
61 | st.slider(label="Repetition Penalty (1 = no penalty)", min_value=0.01, max_value=1.0, value=1., step=0.01, key='repetition_penalty', on_change=on_sampling_config_change)
62 | st.slider(label="Number of Beams", min_value=1, max_value=10, value=1, step=1, key='num_beams', on_change=on_sampling_config_change)
63 |
64 | def load_registered_models(model_file_path: str = './registered_models.json') -> None:
65 | try:
66 | with open(model_file_path, 'r') as f:
67 | st.session_state.registered_models = json.load(f)
68 | except FileNotFoundError:
69 | st.warning(f"Did not find a 'registered_models.json'. Only showing HF models")
70 | st.session_state.registered_models = dict()
71 | st.session_state.registered_model_names = list(st.session_state.registered_models.keys())
72 |
73 |
74 | def startup(variables: List[str], model_file_path: Optional[str] = './registered_models.json') -> None:
75 | """Performs startup tasks for the app.
76 |
77 | :param variables: List of variable names that should be initialized.
78 | :type variables: List[str]
79 | :param model_file_path: Path to the file containing the registered models.
80 | :type model_file_path: Optional[str]
81 | """
82 |
83 | if 'startup_done' not in st.session_state:
84 | # set wide layout
85 | st.set_page_config(layout="wide")
86 |
87 | # initialize session state variables
88 | init_session_state(variables)
89 | st.session_state['visualization'] = dict()
90 | st.session_state['startup_done'] = True
91 |
92 | # load externally registered models
93 | load_registered_models(model_file_path)
94 |
95 | def init_session_state(variables: List[str]) -> None:
96 | """Initialize session state variables to None.
97 |
98 | :param variables: List of variable names to initialize.
99 | :type variables: List[str]
100 | """
101 | for var in variables:
102 | if var not in st.session_state:
103 | st.session_state[var] = None
104 |
105 | def on_config_submit(model_name: str) -> Tuple:
106 | """Function that is called on submitting the config form.
107 |
108 | :param model_name: Name of the model that should be loaded
109 | :type model_name: str
110 | :return: Model, tokenizer, config
111 | :rtype: Tuple
112 | """
113 | # load model, hook it and put it on device and in eval mode
114 | model, tokenizer, config = load_model(model_name)
115 | model = HookedModel(model)
116 | model.to(st.session_state.device).eval()
117 |
118 | if model_name in st.session_state.registered_models:
119 | st.session_state.num_layers = get_num_layers(model, config['layer_key_prefix'])
120 | else:
121 | st.session_state.num_layers = get_num_layers(model, 'transformer->h')
122 |
123 | return model, tokenizer, config
124 |
125 | def on_sampling_config_change():
126 | st.session_state.sample_kwargs = dict(
127 | temperature=st.session_state.temperature,
128 | max_length=st.session_state.response_length,
129 | top_p=st.session_state.top_p,
130 | repetition_penalty=1/st.session_state.repetition_penalty,
131 | num_beams=st.session_state.num_beams
132 | )
133 |
134 | @st.experimental_singleton
135 | def load_model(model_name: str) -> Tuple:
136 | """Load the specified model with its tokenizer and config.
137 |
138 | :param model_name: Model name, e.g. 'gpt2-xl'
139 | :type model_name: str
140 | :return: Model, Tokenizer, Config
141 | :rtype: Tuple
142 | """
143 | if model_name in st.session_state.registered_model_names:
144 | # import model constructor
145 | constructor = st.session_state.registered_models[model_name]['constructor']
146 | constructor_module = importlib.import_module('.'.join(constructor.split('.')[:-1]))
147 | constructor_class = getattr(constructor_module, constructor.split('.')[-1])
148 |
149 | # load model from checkpoint --> make sure that your class has this method, it's default for pl.LightningModules
150 | checkpoint = st.session_state.registered_models[model_name]['checkpoint']
151 | model = constructor_class.load_from_checkpoint(checkpoint)
152 |
153 | # load tokenizer
154 | tokenizer = st.session_state.registered_models[model_name]['tokenizer'] # TODO how to deal with this?
155 | tokenizer_module = importlib.import_module('.'.join(tokenizer.split('.')[:-1]))
156 | tokenizer_class = getattr(tokenizer_module, tokenizer.split('.')[-1])
157 | tokenizer = tokenizer_class()
158 |
159 | # load config
160 | config = st.session_state.registered_models[model_name]['config']
161 |
162 | else: # attempt to load from huggingface
163 | model, tokenizer, config = load_from_pretrained(model_name)
164 |
165 | return model, tokenizer, config
166 |
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/load_only.py:
--------------------------------------------------------------------------------
1 | import json
2 | from io import StringIO
3 |
4 | import streamlit as st
5 |
6 | from unseal.visuals.streamlit_interfaces import interface_setup as setup
7 | from unseal.visuals import utils
8 |
9 | def on_file_upload():
10 | if st.session_state.uploaded_file is None:
11 | return
12 |
13 | data = json.loads(StringIO(st.session_state.uploaded_file.getvalue().decode('utf-8')).read())
14 |
15 | st.session_state.visualization = data
16 |
17 | for layer in range(len(data)):
18 | html_str = st.session_state.visualization[f'layer_{layer}']
19 | with st.expander(f'Layer {layer}'):
20 | st.components.v1.html(html_str, height=600, scrolling=True)
21 |
22 | # set page config to wide layout
23 | st.set_page_config(layout='wide')
24 |
25 | # create sidebar
26 | with st.sidebar:
27 | st.file_uploader(
28 | label='Upload Visualization',
29 | accept_multiple_files=False,
30 | on_change=on_file_upload,
31 | help='Upload the visualizations as a json of html files.',
32 | key='uploaded_file'
33 | )
34 |
35 |
36 | def load():
37 | with open("./visualizations/gpt2-xl.json", "r") as fp:
38 | data = json.load(fp)
39 |
40 | st.session_state.visualization = data
41 | for layer in range(len(data['col_0'])):
42 | html_str = st.session_state.visualization['col_0'][f'layer_{layer}']
43 | with st.expander(f'Layer {layer}'):
44 | st.components.v1.html(html_str, height=1000, width=800, scrolling=True)
45 |
46 | load()  # eagerly render the default gpt2-xl visualization from disk
47 |
48 |
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/registered_models.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
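
The shipped file is an empty dict. Judging from load_model in interface_setup.py and get_attn_logits_args in streamlit_interfaces/utils.py, a registered entry might look like the following; every name and path here is hypothetical:

    {
        "my_model": {
            "constructor": "my_package.models.MyLightningModule",
            "checkpoint": "/path/to/checkpoint.ckpt",
            "tokenizer": "my_package.tokenizers.MyTokenizer",
            "config": {
                "attn_name": "attn",
                "output_idx": 2,
                "layer_key_prefix": "transformer->h",
                "out_proj_name": "out_proj",
                "attn_suffix": null,
                "unembedding_key": "lm_head"
            }
        }
    }
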
/unseal/visuals/streamlit_interfaces/single_layer_single_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 | import streamlit as st
5 |
6 | from unseal.visuals.streamlit_interfaces import utils
7 | from unseal.visuals.streamlit_interfaces import interface_setup as setup
8 | from unseal.visuals.streamlit_interfaces.commons import SESSION_STATE_VARIABLES
9 |
10 | def layer_change():
11 | utils.text_change(0)
12 |
13 | # perform startup tasks
14 | setup.startup(SESSION_STATE_VARIABLES, './registered_models.json')
15 |
16 | # create sidebar
17 | with st.sidebar:
18 | setup.create_sidebar()
19 |
20 | sample = st.checkbox('Enable sampling', value=False, key='sample')
21 | if sample:
22 | setup.create_sample_sliders()
23 | setup.on_sampling_config_change()
24 |
25 | if "storage" not in st.session_state:
26 | st.session_state["storage"] = [""]
27 |
28 | # select layer
29 | if st.session_state.num_layers is None:
30 | options = list(['Select a model!'])
31 | else:
32 | options = list(range(st.session_state.num_layers))
33 | st.selectbox('Layer', options=options, key='layer', on_change=layer_change, index=0)
34 |
35 | # input 1
36 | placeholder = st.empty()
37 | placeholder.text_area(label='Input', on_change=utils.on_text_change, key='input_text', value=st.session_state.storage[0], kwargs=dict(col_idx=0, text_key='input_text'))
38 | if sample:
39 | st.button(label="Sample", on_click=utils.sample_text, kwargs=dict(model=st.session_state.model, col_idx=0, key="input_text"), key="sample_text")
40 |
41 | # sometimes need to force a re-render
42 | st.button('Show Attention', on_click=utils.text_change, kwargs=dict(col_idx=0))
43 |
44 | f = json.encoder.JSONEncoder().encode(st.session_state.visualization)
45 | st.download_button(
46 | label='Download Visualization',
47 | data=f,
48 | file_name=f'{st.session_state.model_name}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.json',
49 | mime='application/json',
50 | help='Download the visualizations as a json of html files.',
51 | key='download_button'
52 | )
53 |
54 | # show the html visualization
55 | if st.session_state.model is not None:
56 | cols = st.columns(1)
57 | for col_idx, col in enumerate(cols):
58 | if f"col_{col_idx}" in st.session_state.visualization:
59 | with col:
60 | if f"layer_{st.session_state.layer}" in st.session_state.visualization[f"col_{col_idx}"]:
61 | with st.expander(f'Layer {st.session_state.layer}'):
62 | st.components.v1.html(st.session_state.visualization[f"col_{col_idx}"][f"layer_{st.session_state.layer}"], height=600)
63 | else:
64 | st.session_state.visualization[f"col_{col_idx}"] = dict()
--------------------------------------------------------------------------------
/unseal/visuals/streamlit_interfaces/split_full_model_vis_into_layers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 |
5 | def main(args):
6 | with open(args.file, 'r') as f:
7 | data = json.load(f)
8 |
9 | # get model_name from path
10 | filename = os.path.basename(args.file)
11 | model_name = filename.split('.')[0]
12 | print(f"model_name: {model_name}")
13 | for i in range(2):
14 | if f'col_{i}' in data:
15 | for j in range(len(data[f'col_{i}'])):
16 | os.makedirs(f'{args.target_folder}/col_{i}', exist_ok=True)
17 | with open(f'{args.target_folder}/col_{i}/layer_{j}.txt', 'w') as f:
18 | json.dump(data[f'col_{i}'][f'layer_{j}'], f)
19 | else:
20 | print(f"col_{i} not in data -> skipping")
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--file', type=str, required=True)
25 | parser.add_argument('--target_folder', type=str, required=True)
26 | args = parser.parse_args()
27 |
28 | main(args)
29 |
--------------------------------------------------------------------------------
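
For reference, an illustrative invocation (paths hypothetical): python split_full_model_vis_into_layers.py --file ./visualizations/gpt2-xl.json --target_folder ./split, which writes one layer_<j>.txt file per column and layer.
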
/unseal/visuals/streamlit_interfaces/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 | import streamlit as st
4 | from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoPreTrainedModel
5 |
6 | from ..utils import compute_attn_logits
7 |
8 | def sample_text(model, col_idx, key):
9 | text = st.session_state[key]
10 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0:
11 | text = st.session_state.prefix_prompt + '\n' + text
12 | model_inputs = st.session_state.tokenizer.encode(text, return_tensors='pt').to(st.session_state.device)
13 | output = model.model.generate(model_inputs, **st.session_state.sample_kwargs, min_length=0, output_attentions=True)
14 | output_text = st.session_state.tokenizer.decode(output[0], skip_special_tokens=True)
15 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0:
16 | output_text = output_text[len(st.session_state.prefix_prompt) + 1:]  # str.lstrip would strip a character set, not the prefix
17 |
18 | st.session_state["storage"][col_idx] = output_text
19 | text_change(col_idx=col_idx)
20 |
21 |
22 | def on_text_change(col_idx: Union[int, List[int]], text_key):
23 | if isinstance(col_idx, list):
24 | for idx in col_idx:
25 | on_text_change(idx, text_key)
26 | else:
27 | st.session_state["storage"][col_idx] = st.session_state[text_key]
28 | text_change(col_idx)
29 |
30 | def get_attn_logits_args():
31 | # get args for compute_attn_logits
32 | if st.session_state.model_name in st.session_state.registered_model_names:
33 | attn_name = st.session_state.config['attn_name']
34 | output_idx = st.session_state.config['output_idx']
35 | layer_key_prefix = st.session_state.config['layer_key_prefix']
36 | out_proj_name = st.session_state.config['out_proj_name']
37 | attn_suffix = st.session_state.config['attn_suffix']
38 | unembedding_key = st.session_state.config['unembedding_key']
39 | elif isinstance(st.session_state.model.model, GPTNeoPreTrainedModel):
40 | attn_name = 'attn'
41 | output_idx = 2
42 | layer_key_prefix = 'transformer->h'
43 | out_proj_name = 'out_proj'
44 | attn_suffix = 'attention'
45 | unembedding_key = 'lm_head'
46 | else:
47 | attn_name = 'attn'
48 | output_idx = 2
49 | layer_key_prefix = 'transformer->h'
50 | out_proj_name = 'c_proj'
51 | attn_suffix = None
52 | unembedding_key = 'lm_head'
53 | return attn_name, output_idx, layer_key_prefix, out_proj_name, attn_suffix, unembedding_key
54 |
55 | def text_change(col_idx: Union[int, List[int]]):
56 | if isinstance(col_idx, list):
57 | for idx in col_idx:
58 | text_change(idx)
59 | return
60 |
61 | text = st.session_state["storage"][col_idx]
62 | if st.session_state.prefix_prompt is not None and len(st.session_state.prefix_prompt) > 0:
63 | text = st.session_state.prefix_prompt + '\n' + text
64 |
65 | if text is None or len(text) == 0:
66 | return
67 |
68 | attn_name, output_idx, layer_key_prefix, out_proj_name, attn_suffix, unembedding_key = get_attn_logits_args()
69 |
70 | if 'layer' in st.session_state:
71 | layer = st.session_state['layer']
72 | else:
73 | layer = None
74 |
75 | compute_attn_logits(
76 | st.session_state.model,
77 | st.session_state.model_name,
78 | st.session_state.tokenizer,
79 | st.session_state.num_layers,
80 | text,
81 | st.session_state.visualization[f'col_{col_idx}'],
82 | attn_name = attn_name,
83 | output_idx = output_idx,
84 | layer_key_prefix = layer_key_prefix,
85 | out_proj_name = out_proj_name,
86 | attn_suffix = attn_suffix,
87 | unembedding_key = unembedding_key,
88 | layer_id = layer,
89 | )
--------------------------------------------------------------------------------
/unseal/visuals/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import gc
3 | from typing import Optional, Callable, Dict, List
4 |
5 | import einops
6 | import pysvelte as ps
7 | import torch
8 | from transformers import AutoTokenizer
9 |
10 | from ..hooks import HookedModel
11 | from ..hooks.common_hooks import create_attention_hook, gpt_attn_wrapper
12 |
13 | def compute_attn_logits(
14 | model: HookedModel,
15 | model_name: str,
16 | tokenizer: AutoTokenizer,
17 | num_layers: int,
18 | text: str,
19 | html_storage: Dict,
20 | save_path: Optional[str] = None,
21 | attn_name: Optional[str] = 'attn',
22 | output_idx: Optional[int] = 2,
23 | layer_key_prefix: Optional[str] = None,
24 | out_proj_name: Optional[str] = 'out_proj',
25 | attn_suffix: Optional[str] = None,
26 | unembedding_key: Optional[str] = 'lm_head',
27 | layer_id: Optional[int] = None,
28 | batch_size: Optional[int] = None,
29 | ):
30 | # parse inputs
31 | if save_path is None:
32 | save_path = f"{model_name}.json"
33 | if layer_key_prefix is None:
34 | layer_key_prefix = ""
35 | else:
36 | layer_key_prefix += "->"
37 | if attn_suffix is None or attn_suffix == "":
38 | attn_suffix = ""
39 | if layer_id is None:
40 | layer_list = list(range(num_layers))
41 | else:
42 | layer_list = [layer_id]
43 |
44 | # tokenize without tokenization artifact -> needed for visualization
45 | tokenized_text = tokenizer.tokenize(text)
46 | tokenized_text = list(map(tokenizer.convert_tokens_to_string, map(lambda x: [x], tokenized_text)))
47 |
48 | # encode text
49 | model_input = tokenizer.encode(text, return_tensors='pt').to(model.device)
50 | target_ids = tokenizer.encode(text)[1:]
51 |
52 | # compute attention pattern
53 | attn_hooks = [create_attention_hook(layer, f'attn_layer_{layer}', output_idx, attn_name, layer_key_prefix) for layer in layer_list]
54 | model.forward(model_input, hooks=attn_hooks, output_attentions=True)
55 |
56 | # compute logits
57 | for layer in layer_list:
58 |
59 | # wrap the _attn function to create logit attribution
60 | model.save_ctx[f'logit_layer_{layer}'] = dict()
61 | old_fn = wrap_gpt_attn(model, layer, target_ids, unembedding_key, attn_name, attn_suffix, layer_key_prefix, out_proj_name, batch_size)
62 |
63 | # forward pass
64 | model.forward(model_input, hooks=[])
65 |
66 | # parse attentions for this layer
67 | attention = model.save_ctx[f"attn_layer_{layer}"]['attn'][0]
68 | attention = einops.rearrange(attention, 'h n1 n2 -> n1 n2 h')
69 |
70 | # parse logits
71 | if model_input.shape[1] > 1: # otherwise we don't have any logit attribution
72 | logits = model.save_ctx[f'logit_layer_{layer}']['logits']
73 | pos_logits = logits['pos']
74 | neg_logits = logits['neg']
75 | pos_logits = pad_logits(pos_logits)
76 | neg_logits = pad_logits(neg_logits)
77 |
78 | pos_logits = einops.rearrange(pos_logits, 'h n1 n2 -> n1 n2 h')
79 | neg_logits = einops.rearrange(neg_logits, 'h n1 n2 -> n1 n2 h')
80 | else:
81 | pos_logits = torch.zeros((attention.shape[0], attention.shape[1], attention.shape[2]))
82 | neg_logits = torch.zeros((attention.shape[0], attention.shape[1], attention.shape[2]))
83 |
84 | # compute and display the html object
85 | html_object = ps.AttentionLogits(tokens=tokenized_text, attention=attention, pos_logits=pos_logits, neg_logits=neg_logits, head_labels=[f'{layer}:{j}' for j in range(attention.shape[-1])])
86 | html_object = html_object.update_meta(suppress_title=True)
87 | html_str = html_object.html_page_str()
88 |
89 | # save html string
90 | html_storage[f'layer_{layer}'] = html_str
91 |
92 | # reset _attn function
93 | reset_attn_fn(model, layer, old_fn, attn_name, attn_suffix, layer_key_prefix)
94 |
95 | # save progress so far
96 | with open(save_path, "w") as f:
97 | json.dump(html_storage, f)
98 |
99 | # garbage collection
100 | gc.collect()
101 | if torch.cuda.is_available():
102 | torch.cuda.empty_cache()
103 |
104 | return html_storage
105 |
106 | def pad_logits(logits):
107 | logits = torch.cat([torch.zeros_like(logits[:,0,None]), logits], dim=1)
108 | logits = torch.cat([logits, torch.zeros_like(logits[:,:,0,None])], dim=2)
109 | return logits
110 |
111 | def wrap_gpt_attn(
112 | model: HookedModel,
113 | layer: int,
114 | target_ids: List[int],
115 | unembedding_key: str,
116 | attn_name: Optional[str] = 'attn',
117 | attn_suffix: Optional[str] = None,
118 | layer_key_prefix: Optional[str] = None,
119 | out_proj_name: Optional[str] = 'out_proj',
120 | batch_size: Optional[int] = None,
121 | ) -> Callable:
122 | # parse inputs
123 | if layer_key_prefix is None:
124 | layer_key_prefix = ""
125 | if attn_suffix is None:
126 | attn_suffix = ""
127 |
128 | attn_name = f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}"
129 | out_proj_name = attn_name + f"->{out_proj_name}"
130 |
131 | model.layers[attn_name]._attn, old_fn = gpt_attn_wrapper(
132 | model.layers[attn_name]._attn,
133 | model.save_ctx[f'logit_layer_{layer}'],
134 | model.layers[out_proj_name].weight,
135 | model.layers[unembedding_key].weight.T,
136 | target_ids,
137 | batch_size,
138 | )
139 |
140 | return old_fn
141 |
142 |
143 | def reset_attn_fn(
144 | model: HookedModel,
145 | layer: int,
146 | old_fn: Callable,
147 | attn_name: Optional[str] = 'attn',
148 | attn_suffix: Optional[str] = None,
149 | layer_key_prefix: Optional[str] = None,
150 | ) -> None:
151 | # parse inputs
152 | if layer_key_prefix is None:
153 | layer_key_prefix = ""
154 | if attn_suffix is None:
155 | attn_suffix = ""
156 |
157 | # reset _attn function to old_fn
158 | del model.layers[f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}"]._attn
159 | model.layers[f"{layer_key_prefix}{layer}->{attn_name}{attn_suffix}"]._attn = old_fn
--------------------------------------------------------------------------------
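
Finally, a hedged sketch of producing the attention/logit visualizations outside of Streamlit, with the GPT-2 defaults taken from get_attn_logits_args in streamlit_interfaces/utils.py:

    from unseal.hooks.commons import HookedModel
    from unseal.transformers_util import load_from_pretrained
    from unseal.visuals.utils import compute_attn_logits

    model, tokenizer, _ = load_from_pretrained('gpt2')
    model = HookedModel(model).eval()

    html_storage = {}
    compute_attn_logits(
        model, 'gpt2', tokenizer, num_layers=12,
        text='The quick brown fox', html_storage=html_storage,
        layer_key_prefix='transformer->h', out_proj_name='c_proj',
        attn_suffix=None, unembedding_key='lm_head', layer_id=0,
    )
    # html_storage['layer_0'] now holds the HTML page for layer 0's visualization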