├── .flake8
├── pyproject.toml
├── .dockerignore
├── llm_transparency_tool
│   ├── components
│   │   ├── frontend
│   │   │   ├── src
│   │   │   │   ├── react-app-env.d.ts
│   │   │   │   ├── common.tsx
│   │   │   │   ├── index.tsx
│   │   │   │   ├── LlmViewer.css
│   │   │   │   ├── Selector.tsx
│   │   │   │   └── ContributionGraph.tsx
│   │   │   ├── .prettierrc
│   │   │   ├── .env
│   │   │   ├── tsconfig.json
│   │   │   ├── public
│   │   │   │   └── index.html
│   │   │   └── package.json
│   │   └── __init__.py
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── test_tlens_model.py
│   │   ├── transparent_llm.py
│   │   └── tlens_model.py
│   ├── routes
│   │   ├── __init__.py
│   │   ├── graph_node.py
│   │   ├── test_contributions.py
│   │   ├── graph.py
│   │   └── contributions.py
│   └── server
│       ├── graph_selection.py
│       ├── monitor.py
│       ├── styles.py
│       ├── utils.py
│       └── app.py
├── .gitignore
├── sample_input.txt
├── config
│   ├── docker_hosting.json
│   ├── docker_local.json
│   └── local.json
├── setup.py
├── env.yaml
├── CONTRIBUTING.md
├── Dockerfile
├── CODE_OF_CONDUCT.md
├── README.md
└── LICENSE
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | **/.git
2 | **/node_modules
3 | **/.mypy_cache
4 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/react-app-env.d.ts:
--------------------------------------------------------------------------------
1 | /// <reference types="react-scripts" />
2 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "endOfLine": "lf",
3 | "semi": false,
4 | "trailingComma": "es5"
5 | }
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/frontend/node_modules*
2 | **/frontend/build/
3 | **/frontend/.yarn*
4 | .vscode/
5 | .mypy_cache/
6 | __pycache__/
7 | .DS_Store
8 |
--------------------------------------------------------------------------------
/sample_input.txt:
--------------------------------------------------------------------------------
1 | The war lasted from the year 1732 to the year 17
2 | 5 + 4 = 9, 2 + 3 =
3 | When Mary and John went to the store, John gave a drink to
4 |
--------------------------------------------------------------------------------
/llm_transparency_tool/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/.env:
--------------------------------------------------------------------------------
1 | # Run the component's dev server on :3001
2 | # (The Streamlit dev server already runs on :3000)
3 | PORT=3001
4 |
5 | # Don't automatically open the web browser on `npm run start`.
6 | BROWSER=none
7 |
--------------------------------------------------------------------------------
/llm_transparency_tool/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/llm_transparency_tool/routes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/config/docker_hosting.json:
--------------------------------------------------------------------------------
1 | {
2 | "allow_loading_dataset_files": false,
3 | "max_user_string_length": 100,
4 | "preloaded_dataset_filename": "sample_input.txt",
5 | "debug": false,
6 | "demo_mode": true,
7 | "models": {
8 | "facebook/opt-125m": null,
9 | "gpt2": null,
10 | "distilgpt2": null
11 | },
12 | "default_model": "gpt2"
13 | }
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import setup
8 |
9 | setup(
10 | name="llm_transparency_tool",
11 | version="0.1",
12 | packages=["llm_transparency_tool"],
13 | )
14 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/common.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | export interface Point {
10 | x: number
11 | y: number
12 | }
13 |
14 | export interface Label {
15 | text: string
16 | pos: Point
17 | }
18 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es5",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "skipLibCheck": true,
7 | "esModuleInterop": true,
8 | "allowSyntheticDefaultImports": true,
9 | "strict": true,
10 | "forceConsistentCasingInFileNames": true,
11 | "module": "esnext",
12 | "moduleResolution": "node",
13 | "resolveJsonModule": true,
14 | "isolatedModules": true,
15 | "noEmit": true,
16 | "jsx": "react"
17 | },
18 | "include": ["src"]
19 | }
20 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: llmtt
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | dependencies:
7 | - python=3.12
8 | - pytorch
9 | - pytorch-cuda=11.8
10 | - nodejs
11 | - yarn
12 | - pip
13 | - pip:
14 | - datasets
15 | - einops
16 | - fancy_einsum
17 | - jaxtyping==0.2.25
18 | - networkx
19 | - plotly
20 | - pyinstrument
21 | - setuptools
22 | - streamlit
23 | - streamlit_extras
24 | - tokenizers
25 | - transformer_lens
26 | - transformers
27 | - pytest # fixes wrong dependencies of transformer_lens
28 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/public/index.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html lang="en">
3 | <head>
4 | <meta charset="UTF-8" />
5 | <title>Contribution Graph for Streamlit</title>
6 | </head>
7 | <body>
8 | <div id="root"></div>
9 | </body>
10 | </html>
11 |
--------------------------------------------------------------------------------
/config/docker_local.json:
--------------------------------------------------------------------------------
1 | {
2 | "allow_loading_dataset_files": true,
3 | "preloaded_dataset_filename": "sample_input.txt",
4 | "debug": true,
5 | "models": {
6 | "": null,
7 | "facebook/opt-125m": null,
8 | "facebook/opt-1.3b": null,
9 | "facebook/opt-2.7b": null,
10 | "facebook/opt-6.7b": null,
11 | "facebook/opt-13b": null,
12 | "facebook/opt-30b": null,
13 | "meta-llama/Llama-2-7b-hf": null,
14 | "meta-llama/Llama-2-7b-chat-hf": null,
15 | "meta-llama/Llama-2-13b-hf": null,
16 | "meta-llama/Llama-2-13b-chat-hf": null,
17 | "gpt2": null,
18 | "gpt2-medium": null,
19 | "gpt2-large": null,
20 | "gpt2-xl": null,
21 | "distilgpt2": null
22 | },
23 | "default_model": "distilgpt2",
24 | "demo_mode": false
25 | }
26 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "contribution_graph",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "@types/d3": "^7.4.0",
7 | "d3": "^7.8.5",
8 | "react": "^18.2.0",
9 | "react-dom": "^18.2.0",
10 | "streamlit-component-lib": "^2.0.0"
11 | },
12 | "scripts": {
13 | "start": "react-scripts start",
14 | "build": "react-scripts build",
15 | "test": "react-scripts test",
16 | "eject": "react-scripts eject"
17 | },
18 | "browserslist": {
19 | "production": [
20 | ">0.2%",
21 | "not dead",
22 | "not op_mini all"
23 | ],
24 | "development": [
25 | "last 1 chrome version",
26 | "last 1 firefox version",
27 | "last 1 safari version"
28 | ]
29 | },
30 | "homepage": ".",
31 | "devDependencies": {
32 | "@types/node": "^20.11.17",
33 | "@types/react": "^18.2.55",
34 | "@types/react-dom": "^18.2.19",
35 | "eslint-config-react-app": "^7.0.1",
36 | "react-scripts": "^5.0.1",
37 | "typescript": "^5.3.3"
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/index.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | import React from "react"
10 | import ReactDOM from "react-dom"
11 |
12 | import {
13 | ComponentProps,
14 | withStreamlitConnection,
15 | } from "streamlit-component-lib"
16 |
17 |
18 | import ContributionGraph from "./ContributionGraph"
19 | import Selector from "./Selector"
20 |
21 | const LlmViewerComponent = (props: ComponentProps) => {
22 | switch (props.args['component']) {
23 | case 'graph':
24 | return <ContributionGraph {...props} />
25 | case 'selector':
26 | return <Selector {...props} />
27 | default:
28 | return <></>
29 | }
30 | };
31 |
32 | const StreamlitLlmViewerComponent = withStreamlitConnection(LlmViewerComponent)
33 |
34 | ReactDOM.render(
35 | <React.StrictMode>
36 | <StreamlitLlmViewerComponent />
37 | </React.StrictMode>,
38 | document.getElementById("root")
39 | )
40 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to llm-transparency-tool
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints (see the example commands below).
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
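For reference, here is a minimal sketch of the lint and format checks implied by the `.flake8` and `pyproject.toml` files in this repo. It assumes `flake8` and `black` are installed in your environment (neither is pinned in `env.yaml`):

```bash
# Run from the repository root
flake8 .           # picks up .flake8 (max-line-length = 120)
black --check .    # picks up pyproject.toml ([tool.black], line-length = 120)
```
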
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Facebook's open source projects.
18 |
19 | Complete your CLA here: <https://code.facebook.com/cla>
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to llm-transparency-tool, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
8 |
9 | RUN apt-get update && apt-get install -y \
10 | wget \
11 | git \
12 | && apt-get clean \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | RUN useradd -m -u 1000 user
16 | USER user
17 |
18 | ENV HOME=/home/user
19 |
20 | RUN wget -P /tmp \
21 | "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh" \
22 | && bash /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh -b -p $HOME/mambaforge3 \
23 | && rm /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh
24 | ENV PATH $HOME/mambaforge3/bin:$PATH
25 |
26 | WORKDIR $HOME
27 |
28 | ENV REPO=$HOME/llm-transparency-tool
29 | COPY --chown=user . $REPO
30 |
31 | WORKDIR $REPO
32 |
33 | RUN mamba env create --name llmtt -f env.yaml -y
34 | ENV PATH $HOME/mambaforge3/envs/llmtt/bin:$PATH
35 | RUN pip install -e .
36 |
37 | RUN cd llm_transparency_tool/components/frontend \
38 | && yarn install \
39 | && yarn build
40 |
41 | EXPOSE 7860
42 | CMD ["streamlit", "run", "llm_transparency_tool/server/app.py", "--server.port=7860", "--server.address=0.0.0.0", "--theme.font=Inconsolata", "--", "config/docker_hosting.json"]
43 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/LlmViewer.css:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | .graph-container {
10 | display: flex;
11 | justify-content: center;
12 | align-items: center;
13 | height: 100vh;
14 | }
15 |
16 | .svg {
17 | border: 1px solid #ccc;
18 | }
19 |
20 | .layer-highlight {
21 | fill: #f0f5f0;
22 | }
23 |
24 | .selectable-item {
25 | stroke: black;
26 | cursor: pointer;
27 | }
28 |
29 | .selection,
30 | .selection:hover {
31 | fill: orange;
32 | }
33 |
34 | .active-residual-node {
35 | fill: yellowgreen;
36 | }
37 |
38 | .active-residual-node:hover {
39 | fill: olivedrab;
40 | }
41 |
42 | .active-ffn-node {
43 | fill: orchid;
44 | }
45 |
46 | .active-ffn-node:hover {
47 | fill: purple;
48 | }
49 |
50 | .inactive-node {
51 | fill: lightgray;
52 | stroke-width: 0.5px;
53 | }
54 |
55 | .inactive-node:hover {
56 | fill: gray;
57 | }
58 |
59 | .selectable-edge {
60 | cursor: pointer;
61 | }
62 |
63 | .token-selector {
64 | fill: lightblue;
65 | }
66 |
67 | .token-selector:hover {
68 | fill: cornflowerblue;
69 | }
70 |
71 | .selector-item {
72 | fill: lightblue;
73 | }
74 |
75 | .selector-item:hover {
76 | fill: cornflowerblue;
77 | }
--------------------------------------------------------------------------------
/config/local.json:
--------------------------------------------------------------------------------
1 | {
2 | "allow_loading_dataset_files": true,
3 | "preloaded_dataset_filename": "sample_input.txt",
4 | "debug": true,
5 | "models": {
6 | "": null,
7 |
8 | "meta-llama/Meta-Llama-3-8B": null,
9 | "meta-llama/Meta-Llama-3-70B": null,
10 | "meta-llama/Meta-Llama-3-8B-Instruct": null,
11 | "meta-llama/Meta-Llama-3-70B-Instruct": null,
12 | "mistral/mixtral-instruct": null,
13 |
14 | "gpt2": null,
15 | "distilgpt2": null,
16 | "facebook/opt-125m": null,
17 | "facebook/opt-1.3b": null,
18 | "EleutherAI/gpt-neo-125M": null,
19 | "Qwen/Qwen-1_8B": null,
20 | "Qwen/Qwen1.5-0.5B": null,
21 | "Qwen/Qwen1.5-0.5B-Chat": null,
22 | "Qwen/Qwen1.5-1.8B": null,
23 | "Qwen/Qwen1.5-1.8B-Chat": null,
24 | "microsoft/phi-1": null,
25 | "microsoft/phi-1_5": null,
26 | "microsoft/phi-2": null,
27 |
28 | "meta-llama/Llama-2-7b-hf": null,
29 | "meta-llama/Llama-2-7b-chat-hf": null,
30 |
31 | "meta-llama/Llama-2-13b-hf": null,
32 | "meta-llama/Llama-2-13b-chat-hf": null,
33 |
34 | "meta-llama/Llama-2-70b-hf": null,
35 | "meta-llama/Llama-2-70b-chat-hf": null,
36 |
37 |
38 | "gpt2-medium": null,
39 | "gpt2-large": null,
40 | "gpt2-xl": null,
41 |
42 | "mistralai/Mistral-7B-v0.1": null,
43 | "mistralai/Mistral-7B-Instruct-v0.1": null,
44 | "mistralai/Mistral-7B-Instruct-v0.2": null,
45 |
46 | "google/gemma-7b": null,
47 | "google/gemma-2b": null,
48 |
49 | "facebook/opt-2.7b": null,
50 | "facebook/opt-6.7b": null,
51 | "facebook/opt-13b": null,
52 | "facebook/opt-30b": null
53 | },
54 | "default_model": "",
55 | "demo_mode": false
56 | }
57 |
--------------------------------------------------------------------------------
/llm_transparency_tool/server/graph_selection.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 | from typing import Any, Dict, Optional
9 |
10 | from llm_transparency_tool.routes.graph_node import GraphNode, NodeType
11 |
12 |
13 | class UiGraphNode(GraphNode):
14 | @staticmethod
15 | def from_json(json: Dict[str, Any]) -> Optional["UiGraphNode"]:
16 | try:
17 | layer = json["cell"]["layer"]
18 | token = json["cell"]["token"]
19 | type = NodeType(json["item"])
20 | return UiGraphNode(layer, token, type)
21 | except (TypeError, KeyError):
22 | return None
23 |
24 |
25 | @dataclass
26 | class UiGraphEdge:
27 | source: UiGraphNode
28 | target: UiGraphNode
29 | weight: float
30 |
31 | @staticmethod
32 | def from_json(json: Dict[str, Any]) -> Optional["UiGraphEdge"]:
33 | try:
34 | source = UiGraphNode.from_json(json["from"])
35 | target = UiGraphNode.from_json(json["to"])
36 | if source is None or target is None:
37 | return None
38 | weight = float(json["weight"])
39 | return UiGraphEdge(source, target, weight)
40 | except (TypeError, KeyError):
41 | return None
42 |
43 |
44 | @dataclass
45 | class GraphSelection:
46 | node: Optional[UiGraphNode]
47 | edge: Optional[UiGraphEdge]
48 |
49 | @staticmethod
50 | def from_json(json: Dict[str, Any]) -> Optional["GraphSelection"]:
51 | try:
52 | node = UiGraphNode.from_json(json["node"])
53 | edge = UiGraphEdge.from_json(json["edge"])
54 | return GraphSelection(node, edge)
55 | except (TypeError, KeyError):
56 | return None
57 |
--------------------------------------------------------------------------------
/llm_transparency_tool/routes/graph_node.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 | from enum import Enum
9 | from typing import List, Optional
10 |
11 |
12 | class NodeType(Enum):
13 | AFTER_ATTN = "after_attn"
14 | AFTER_FFN = "after_ffn"
15 | FFN = "ffn"
16 | ORIGINAL = "original" # The original tokens
17 |
18 |
19 | def _format_block_hierachy_string(blocks: List[str]) -> str:
20 | return " ▸ ".join(blocks)
21 |
22 |
23 | @dataclass
24 | class GraphNode:
25 | layer: int
26 | token: int
27 | type: NodeType
28 |
29 | def is_in_residual_stream(self) -> bool:
30 | return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN]
31 |
32 | def get_residual_predecessor(self) -> Optional["GraphNode"]:
33 | """
34 | Get another graph node which points to the state of the residual stream before
35 | this node.
36 |
37 | Return None if the current representation is the first one in the residual stream.
38 | """
39 | scheme = {
40 | NodeType.AFTER_ATTN: GraphNode(
41 | layer=max(self.layer - 1, 0),
42 | token=self.token,
43 | type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL,
44 | ),
45 | NodeType.AFTER_FFN: GraphNode(
46 | layer=self.layer,
47 | token=self.token,
48 | type=NodeType.AFTER_ATTN,
49 | ),
50 | NodeType.FFN: GraphNode(
51 | layer=self.layer,
52 | token=self.token,
53 | type=NodeType.AFTER_ATTN,
54 | ),
55 | NodeType.ORIGINAL: None,
56 | }
57 | node = scheme[self.type]
58 | if node is None or node.layer < 0:
59 | return None
60 | return node
61 |
62 | def get_name(self) -> str:
63 | return _format_block_hierachy_string(
64 | [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
65 | )
66 |
67 | def get_predecessor_block_name(self) -> str:
68 | """
69 | Return the name of the block standing between current node and its predecessor
70 | in the residual stream.
71 | """
72 | scheme = {
73 | NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
74 | NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
75 | NodeType.FFN: [f"L{self.layer}", "ffn"],
76 | NodeType.ORIGINAL: ["Nothing"],
77 | }
78 | return _format_block_hierachy_string(scheme[self.type])
79 |
80 | def get_head_name(self, head: Optional[int]) -> str:
81 | path = [f"L{self.layer}", "attn"]
82 | if head is not None:
83 | path.append(f"H{head}")
84 | return _format_block_hierachy_string(path)
85 |
86 | def get_neuron_name(self, neuron: Optional[int]) -> str:
87 | path = [f"L{self.layer}", "ffn"]
88 | if neuron is not None:
89 | path.append(f"N{neuron}")
90 | return _format_block_hierachy_string(path)
91 |
--------------------------------------------------------------------------------
/llm_transparency_tool/server/monitor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import streamlit as st
9 | from pyinstrument import Profiler
10 | from typing import Dict
11 | import pandas as pd
12 |
13 |
14 | @st.cache_resource(max_entries=1, show_spinner=False)
15 | def init_gpu_memory():
16 | """
17 | When CUDA is initialized, it occupies some memory on the GPU. This overhead
18 | can sometimes make it difficult to understand how much memory is actually used
19 | by the model.
20 |
21 | This function is used to initialize CUDA and measure the overhead.
22 | """
23 | if not torch.cuda.is_available():
24 | return {}
25 |
26 | # Initialize CUDA on each device and measure the memory overhead.
27 | gpu_memory_overhead = {}
28 | for i in range(torch.cuda.device_count()):
29 | torch.ones(1).cuda(i)
30 | free, total = torch.cuda.mem_get_info(i)
31 | occupied = total - free
32 | gpu_memory_overhead[i] = occupied
33 |
34 | return gpu_memory_overhead
35 |
36 |
37 | class SystemMonitor:
38 | """
39 | This class is used to monitor the system resources such as GPU memory and CPU
40 | usage. It uses the pyinstrument library to profile the code and measure the
41 | execution time of different parts of the code.
42 | """
43 |
44 | def __init__(
45 | self,
46 | enabled: bool = False,
47 | ):
48 | self.enabled = enabled
49 | self.profiler = Profiler()
50 | self.overhead: Dict[int, int]
51 |
52 | def __enter__(self):
53 | if not self.enabled:
54 | return
55 |
56 | self.overhead = init_gpu_memory()
57 |
58 | self.profiler.__enter__()
59 |
60 | def __exit__(self, exc_type, exc_value, traceback):
61 | if not self.enabled:
62 | return
63 |
64 | self.profiler.__exit__(exc_type, exc_value, traceback)
65 |
66 | self.report_gpu_usage()
67 | self.report_profiler()
68 |
69 | with st.expander("Session state"):
70 | st.write(st.session_state)
71 |
72 | return None
73 |
74 | def report_gpu_usage(self):
75 |
76 | if not torch.cuda.is_available():
77 | return
78 |
79 | data = []
80 |
81 | for i in range(torch.cuda.device_count()):
82 | free, total = torch.cuda.mem_get_info(i)
83 | occupied = total - free
84 | data.append({
85 | 'overhead': self.overhead[i],
86 | 'occupied': occupied - self.overhead[i],
87 | 'free': free,
88 | })
89 | df = pd.DataFrame(data, columns=["overhead", "occupied", "free"])
90 |
91 | with st.sidebar.expander("System"):
92 | st.write("GPU memory on server")
93 | df /= 1024 ** 3 # Convert to GB
94 | st.bar_chart(df, width=200, height=200, color=["#fefefe", "#84c9ff", "#fe2b2b"])
95 |
96 | def report_profiler(self):
97 | html_code = self.profiler.output_html()
98 | with st.expander("Profiler", expanded=False):
99 | st.components.v1.html(html_code, height=1000, scrolling=True)
100 |
--------------------------------------------------------------------------------
/llm_transparency_tool/server/styles.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 |
9 | import matplotlib
10 |
11 | # Unofficial way to make the padding a bit smaller.
12 | margins_css = """
13 |
24 | """
25 |
26 |
27 | @dataclass
28 | class RenderSettings:
29 | column_proportions = [50, 30]
30 |
31 | # We don't know the actual height. This will be used in order to compute the table
32 | # viewport height when needed.
33 | table_cell_height = 36
34 |
35 | n_top_tokens = 30
36 | n_promoted_tokens = 15
37 | n_suppressed_tokens = 15
38 |
39 | n_top_neurons = 20
40 |
41 | attention_color_map = "Blues"
42 |
43 | no_model_alt_text = ""
44 |
45 |
46 | def string_to_display(s: str) -> str:
47 | return s.replace(" ", "·")
48 |
49 |
50 | def logits_color_map(positive_and_negative: bool) -> matplotlib.colors.Colormap:
51 | background_colors = {
52 | "red": [
53 | [0.0, 0.40, 0.40],
54 | [0.1, 0.69, 0.69],
55 | [0.2, 0.83, 0.83],
56 | [0.3, 0.95, 0.95],
57 | [0.4, 0.99, 0.99],
58 | [0.5, 1.0, 1.0],
59 | [0.6, 0.90, 0.90],
60 | [0.7, 0.72, 0.72],
61 | [0.8, 0.49, 0.49],
62 | [0.9, 0.30, 0.30],
63 | [1.0, 0.15, 0.15],
64 | ],
65 | "green": [
66 | [0.0, 0.0, 0.0],
67 | [0.1, 0.09, 0.09],
68 | [0.2, 0.37, 0.37],
69 | [0.3, 0.64, 0.64],
70 | [0.4, 0.85, 0.85],
71 | [0.5, 1.0, 1.0],
72 | [0.6, 0.96, 0.96],
73 | [0.7, 0.88, 0.88],
74 | [0.8, 0.73, 0.73],
75 | [0.9, 0.57, 0.57],
76 | [1.0, 0.39, 0.39],
77 | ],
78 | "blue": [
79 | [0.0, 0.12, 0.12],
80 | [0.1, 0.16, 0.16],
81 | [0.2, 0.30, 0.30],
82 | [0.3, 0.50, 0.50],
83 | [0.4, 0.78, 0.78],
84 | [0.5, 1.0, 1.0],
85 | [0.6, 0.81, 0.81],
86 | [0.7, 0.52, 0.52],
87 | [0.8, 0.25, 0.25],
88 | [0.9, 0.12, 0.12],
89 | [1.0, 0.09, 0.09],
90 | ],
91 | }
92 |
93 | if not positive_and_negative:
94 | # Stretch the top part to the whole range
95 | new_colors = {}
96 | for channel, colors in background_colors.items():
97 | new_colors[channel] = [
98 | [(value - 0.5) * 2, color, color]
99 | for value, color, _ in colors
100 | if value >= 0.5
101 | ]
102 | background_colors = new_colors
103 |
104 | return matplotlib.colors.LinearSegmentedColormap(
105 | f"RdYG-{positive_and_negative}",
106 | background_colors,
107 | )
108 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from typing import List, Optional
9 |
10 | import networkx as nx
11 | import streamlit.components.v1 as components
12 |
13 | from llm_transparency_tool.models.transparent_llm import ModelInfo
14 | from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode
15 |
16 | _RELEASE = True
17 |
18 | if _RELEASE:
19 | parent_dir = os.path.dirname(os.path.abspath(__file__))
20 | config = {
21 | "path": os.path.join(parent_dir, "frontend/build"),
22 | }
23 | else:
24 | config = {
25 | "url": "http://localhost:3001",
26 | }
27 |
28 | _component_func = components.declare_component("contribution_graph", **config)
29 |
30 |
31 | def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
32 | return node.layer < n_layers and node.token < n_tokens
33 |
34 |
35 | def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
36 | if not s:
37 | return True
38 | if s.node:
39 | if not is_node_valid(s.node, n_layers, n_tokens):
40 | return False
41 | if s.edge:
42 | for node in [s.edge.source, s.edge.target]:
43 | if not is_node_valid(node, n_layers, n_tokens):
44 | return False
45 | return True
46 |
47 |
48 | def contribution_graph(
49 | model_info: ModelInfo,
50 | tokens: List[str],
51 | graphs: List[nx.Graph],
52 | key: str,
53 | ) -> Optional[GraphSelection]:
54 | """Create a new instance of contribution graph.
55 |
56 | Returns the current graph selection (node and/or edge), or None if nothing is selected.
57 | """
58 | assert len(tokens) == len(graphs)
59 |
60 | result = _component_func(
61 | component="graph",
62 | model_info=model_info.__dict__,
63 | tokens=tokens,
64 | edges_per_token=[nx.node_link_data(g)["links"] for g in graphs],
65 | default=None,
66 | key=key,
67 | )
68 |
69 | selection = GraphSelection.from_json(result)
70 |
71 | n_tokens = len(tokens)
72 | n_layers = model_info.n_layers
73 | # We need this extra protection because even though the component has to check for
74 | # the validity of the selection, sometimes it allows invalid output. It's some
75 | # unexpected effect that has something to do with React and how the output value is
76 | # set for the component.
77 | if not is_selection_valid(selection, n_layers, n_tokens):
78 | selection = None
79 |
80 | return selection
81 |
82 |
83 | def selector(
84 | items: List[str],
85 | indices: List[int],
86 | temperatures: Optional[List[float]],
87 | preselected_index: Optional[int],
88 | key: str,
89 | ) -> Optional[int]:
90 | """Create a new instance of selector.
91 |
92 | Returns selected item index.
93 | """
94 | n = len(items)
95 | assert n == len(indices)
96 | items = [{"index": i, "text": s} for s, i in zip(items, indices)]
97 |
98 | if temperatures is not None:
99 | assert n == len(temperatures)
100 | for i, t in enumerate(temperatures):
101 | items[i]["temperature"] = t
102 |
103 | result = _component_func(
104 | component="selector",
105 | items=items,
106 | preselected_index=preselected_index,
107 | default=None,
108 | key=key,
109 | )
110 |
111 | return None if result is None else int(result)
112 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/llm_transparency_tool/server/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import uuid
8 | from typing import List, Optional, Tuple
9 |
10 | import networkx as nx
11 | import streamlit as st
12 | import torch
13 | import transformers
14 |
15 | import llm_transparency_tool.routes.graph
16 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
17 | from llm_transparency_tool.models.transparent_llm import TransparentLlm
18 |
19 | GPU = "gpu"
20 | CPU = "cpu"
21 |
22 | # This variable expresses the idea that batch_id = 0 while being more readable
23 | # than a bare 0.
24 | B0 = 0
25 |
26 |
27 | def possible_devices() -> List[str]:
28 | devices = []
29 | if torch.cuda.is_available():
30 | devices.append("gpu")
31 | devices.append("cpu")
32 | return devices
33 |
34 |
35 | def load_dataset(filename) -> List[str]:
36 | with open(filename) as f:
37 | dataset = [s.strip("\n") for s in f.readlines()]
38 | print(f"Loaded {len(dataset)} sentences from {filename}")
39 | return dataset
40 |
41 |
42 | @st.cache_resource(
43 | hash_funcs={
44 | TransformerLensTransparentLlm: id
45 | }
46 | )
47 | def load_model(
48 | model_name: str,
49 | _device: str,
50 | _model_path: Optional[str] = None,
51 | _dtype: torch.dtype = torch.float32,
52 | supported_model_name: Optional[str] = None,
53 | ) -> TransparentLlm:
54 | """
55 | Returns the loaded model. Arguments prefixed with `_` are not hashed by the
56 | Streamlit cache, so changing them alone will not trigger a reload.
57 | """
58 | assert _device in possible_devices()
59 |
60 | causal_lm = None
61 | tokenizer = None
62 |
63 | tl_lm = TransformerLensTransparentLlm(
64 | model_name=model_name,
65 | hf_model=causal_lm,
66 | tokenizer=tokenizer,
67 | device=_device,
68 | dtype=_dtype,
69 | supported_model_name=supported_model_name,
70 | )
71 |
72 | return tl_lm
73 |
74 |
75 | def run_model(model: TransparentLlm, sentence: str) -> None:
76 | print(f"Running inference for '{sentence}'")
77 | model.run([sentence])
78 |
79 |
80 | def load_model_with_session_caching(
81 | **kwargs,
82 | ) -> Tuple[TransparentLlm, str]:
83 | return load_model(**kwargs)
84 |
85 | def run_model_with_session_caching(
86 | _model: TransparentLlm,
87 | model_key: str,
88 | sentence: str,
89 | ):
90 | LAST_RUN_MODEL_KEY = "last_run_model_key"
91 | LAST_RUN_SENTENCE = "last_run_sentence"
92 | state = st.session_state
93 |
94 | if (
95 | state.get(LAST_RUN_MODEL_KEY, None) == model_key
96 | and state.get(LAST_RUN_SENTENCE, None) == sentence
97 | ):
98 | return
99 |
100 | run_model(_model, sentence)
101 | state[LAST_RUN_MODEL_KEY] = model_key
102 | state[LAST_RUN_SENTENCE] = sentence
103 |
104 |
105 | @st.cache_resource(
106 | hash_funcs={
107 | TransformerLensTransparentLlm: id
108 | }
109 | )
110 | def get_contribution_graph(
111 | model: TransparentLlm, # TODO bug here
112 | model_key: str,
113 | tokens: List[str],
114 | threshold: float,
115 | ) -> nx.Graph:
116 | """
117 | The `model_key` and `tokens` are used only for caching. The model itself is not
118 | hashed, hence the `_` in the beginning.
119 | """
120 | return llm_transparency_tool.routes.graph.build_full_graph(
121 | model,
122 | B0,
123 | threshold,
124 | )
125 |
126 |
127 | def st_placeholder(
128 | text: str,
129 | container=st,
130 | border: bool = True,
131 | height: Optional[int] = 500,
132 | ):
133 | empty = container.empty()
134 | empty.container(border=border, height=height).write(f'{text}', unsafe_allow_html=True)
135 | return empty
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM Transparency Tool
2 |
8 | ## Key functionality
9 |
10 | * Choose your model, choose or add your prompt, run the inference.
11 | * Browse the contribution graph.
12 | * Select the token to build the graph from.
13 | * Tune the contribution threshold.
14 | * Select the representation of any token after any block.
15 | * For the representation, see its projection to the output vocabulary and which tokens
16 | were promoted/suppressed by the previous block.
17 | * The following things are clickable:
18 | * Edges. Clicking an edge shows more info about the contributing attention head.
19 | * Heads when an edge is selected. You can see what this head is promoting/suppressing.
20 | * FFN blocks (little squares on the graph).
21 | * Neurons when an FFN block is selected.
22 |
23 |
24 | ## Installation
25 |
26 | ### Dockerized running
27 | ```bash
28 | # From the repository root directory
29 | docker build -t llm_transparency_tool .
30 | docker run --rm -p 7860:7860 llm_transparency_tool
31 | ```
32 |
33 | ### Local Installation
34 |
35 |
36 | ```bash
37 | # download
38 | git clone git@github.com:facebookresearch/llm-transparency-tool.git
39 | cd llm-transparency-tool
40 |
41 | # install the necessary packages
42 | conda env create --name llmtt -f env.yaml
43 | # install the `llm_transparency_tool` package
44 | pip install -e .
45 |
46 | # now, we need to build the frontend
47 | # don't worry, even `yarn` comes preinstalled by `env.yaml`
48 | cd llm_transparency_tool/components/frontend
49 | yarn install
50 | yarn build
51 | ```
52 |
53 | ### Launch
54 |
55 | ```bash
56 | streamlit run llm_transparency_tool/server/app.py -- config/local.json
57 | ```
58 |
59 |
60 | ## Adding support for your LLM
61 |
62 | Initially, the tool lets you select from just a handful of models. Here are the
63 | options for using your own model in the tool, from least to most
64 | effort.
65 |
66 |
67 | ### The model is already supported by TransformerLens
68 |
69 | The full list of supported models is [here](https://github.com/neelnanda-io/TransformerLens/blob/0825c5eb4196e7ad72d28bcf4e615306b3897490/transformer_lens/loading_from_pretrained.py#L18).
70 | In this case, the model only needs to be added to the configuration JSON file, as in the example below.
71 |
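For example, a model could be enabled by extending the `models` section of `config/local.json`. The entry below is only an illustration; use any name from the TransformerLens list:

```json
"models": {
    "gpt2": null,
    "EleutherAI/pythia-70m": null
}
```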
72 |
73 | ### Tuned version of a model supported by TransformerLens
74 |
75 | Add the official name of the model to the config along with the location to read the
76 | weights from.
77 |
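A sketch of what such an entry might look like. This assumes the value of a `models` entry is interpreted as the location of the weights (the shipped configs only use `null`, and `load_model` in `llm_transparency_tool/server/utils.py` accepts an optional `_model_path`); both the model name and the path are placeholders:

```json
"models": {
    "meta-llama/Llama-2-7b-hf": "/path/to/your/finetuned/weights"
}
```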
78 |
79 | ### The model is not supported by TransformerLens
80 |
81 | In this case, the UI wouldn't know how to create proper hooks for the model. You'd need
82 | to implement your own version of the [TransparentLlm](./llm_transparency_tool/models/transparent_llm.py#L28) class and alter the
83 | Streamlit app to use your implementation.
84 |
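Here is a minimal sketch of what such an implementation could look like. Everything in it is illustrative: the class name, the numbers in `ModelInfo`, and the method bodies are placeholders, and most of the abstract methods of `TransparentLlm` are omitted:

```python
from typing import List

import torch

from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm


class MyTransparentLlm(TransparentLlm):
    """Hypothetical wrapper around a custom model (abridged)."""

    def __init__(self, model_name: str, device: str = "cpu"):
        self._name = model_name
        self._device = device
        # Load your model and tokenizer here, and set up whatever hooks you
        # need to capture residual streams, attention and FFN outputs.

    def model_info(self) -> ModelInfo:
        # Placeholder numbers: report your model's real dimensions here.
        return ModelInfo(
            name=self._name,
            n_params_estimate=125_000_000,
            n_layers=12,
            n_heads=12,
            d_model=768,
            d_vocab=50257,
        )

    def run(self, sentences: List[str]) -> None:
        # Tokenize, run a forward pass, and cache every tensor the tool needs.
        raise NotImplementedError

    def logits(self) -> torch.Tensor:
        # Return the cached logits, shape [batch, pos, d_vocab].
        raise NotImplementedError
```

The Streamlit app would then need to construct this class instead of `TransformerLensTransparentLlm`, for example in `load_model` in `llm_transparency_tool/server/utils.py`.
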
85 | ## Citation
86 | If you use the LLM Transparency Tool for your research, please consider citing:
87 |
88 | ```bibtex
89 | @article{tufanov2024lm,
90 | title={LM Transparency Tool: Interactive Tool for Analyzing Transformer Language Models},
91 | author={Igor Tufanov and Karen Hambardzumyan and Javier Ferrando and Elena Voita},
92 | year={2024},
93 | journal={Arxiv},
94 | url={https://arxiv.org/abs/2404.07004}
95 | }
96 |
97 | @article{ferrando2024information,
98 | title={Information Flow Routes: Automatically Interpreting Language Models at Scale},
99 | author={Javier Ferrando and Elena Voita},
100 | year={2024},
101 | journal={Arxiv},
102 | url={https://arxiv.org/abs/2403.00824}
103 | }
104 | ```
105 |
106 | ## License
107 |
108 | This code is made available under a [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license, as found in the LICENSE file.
109 | However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.
110 |
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/Selector.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | import {
10 | ComponentProps,
11 | Streamlit,
12 | withStreamlitConnection,
13 | } from "streamlit-component-lib"
14 | import React, { useEffect, useMemo, useRef, useState } from 'react';
15 | import * as d3 from 'd3';
16 |
17 | import {
18 | Point,
19 | } from './common';
20 | import './LlmViewer.css';
21 |
22 | export const renderParams = {
23 | verticalGap: 24,
24 | horizontalGap: 24,
25 | itemSize: 8,
26 | }
27 |
28 | interface Item {
29 | index: number
30 | text: string
31 | temperature: number
32 | }
33 |
34 | const Selector = ({ args }: ComponentProps) => {
35 | const items: Item[] = args["items"]
36 | const preselected_index: number | null = args["preselected_index"]
37 | const n = items.length
38 |
39 | const [selection, setSelection] = useState<number | null>(null)
40 |
41 | // Ensure the preselected element has effect only when it's a new data.
42 | var args_json = JSON.stringify(args)
43 | useEffect(() => {
44 | setSelection(preselected_index)
45 | Streamlit.setComponentValue(preselected_index)
46 | }, [args_json, preselected_index]);
47 |
48 | const handleItemClick = (index: number) => {
49 | setSelection(index)
50 | Streamlit.setComponentValue(index)
51 | }
52 |
53 | const [xScale, yScale] = useMemo(() => {
54 | const x = d3.scaleLinear()
55 | .domain([0, 1])
56 | .range([0, renderParams.horizontalGap])
57 | const y = d3.scaleLinear()
58 | .domain([0, n - 1])
59 | .range([0, renderParams.verticalGap * (n - 1)])
60 | return [x, y]
61 | }, [n])
62 |
63 | const itemCoords: Point[] = useMemo(() => {
64 | return Array.from(Array(n).keys()).map(i => ({
65 | x: xScale(0.5),
66 | y: yScale(i + 0.5),
67 | }))
68 | }, [n, xScale, yScale])
69 |
70 | var hasTemperature = false
71 | if (n > 0) {
72 | var t = items[0].temperature
73 | hasTemperature = (t !== null && t !== undefined)
74 | }
75 | const colorScale = useMemo(() => {
76 | var min_t = 0.0
77 | var max_t = 1.0
78 | if (hasTemperature) {
79 | min_t = items[0].temperature
80 | max_t = items[0].temperature
81 | for (var i = 0; i < n; i++) {
82 | const t = items[i].temperature
83 | min_t = Math.min(min_t, t)
84 | max_t = Math.max(max_t, t)
85 | }
86 | }
87 | const norm = d3.scaleLinear([min_t, max_t], [0.0, 1.0])
88 | const colorScale = d3.scaleSequential(d3.interpolateYlGn);
89 | return d3.scaleSequential(value => colorScale(norm(value)))
90 | }, [items, hasTemperature, n])
91 |
92 | const totalW = 100
93 | const totalH = yScale(n)
94 | useEffect(() => {
95 | Streamlit.setFrameHeight(totalH)
96 | }, [totalH])
97 |
98 | const svgRef = useRef<SVGSVGElement | null>(null);
99 |
100 | useEffect(() => {
101 | const svg = d3.select(svgRef.current)
102 | svg.selectAll('*').remove()
103 |
104 | const getItemClass = (index: number) => {
105 | var style = 'selectable-item '
106 | style += index === selection ? 'selection' : 'selector-item'
107 | return style
108 | }
109 |
110 | const getItemColor = (item: Item) => {
111 | var t = item.temperature ?? 0.0
112 | return item.index === selection ? 'orange' : colorScale(t)
113 | }
114 |
115 | var icons = svg
116 | .selectAll('items')
117 | .data(Array.from(Array(n).keys()))
118 | .enter()
119 | .append('circle')
120 | .attr('cx', (i) => itemCoords[i].x)
121 | .attr('cy', (i) => itemCoords[i].y)
122 | .attr('r', renderParams.itemSize / 2)
123 | .on('click', (event: PointerEvent, i) => {
124 | handleItemClick(items[i].index)
125 | })
126 | .attr('class', (i) => getItemClass(items[i].index))
127 | if (hasTemperature) {
128 | icons.style('fill', (i) => getItemColor(items[i]))
129 | }
130 |
131 | svg
132 | .selectAll('labels')
133 | .data(Array.from(Array(n).keys()))
134 | .enter()
135 | .append('text')
136 | .attr('x', (i) => itemCoords[i].x + renderParams.horizontalGap / 2)
137 | .attr('y', (i) => itemCoords[i].y)
138 | .attr('text-anchor', 'left')
139 | .attr('alignment-baseline', 'middle')
140 | .text((i) => items[i].text)
141 |
142 | }, [
143 | items,
144 | n,
145 | itemCoords,
146 | selection,
147 | colorScale,
148 | hasTemperature,
149 | ])
150 |
151 | return <svg ref={svgRef} width={totalW} height={totalH} />
152 | }
153 |
154 | export default withStreamlitConnection(Selector)
155 |
--------------------------------------------------------------------------------
/llm_transparency_tool/routes/test_contributions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 | from typing import Any, List
9 |
10 | import torch
11 |
12 | import llm_transparency_tool.routes.contributions as contributions
13 |
14 |
15 | class TestContributions(unittest.TestCase):
16 | def setUp(self):
17 | torch.manual_seed(123)
18 |
19 | self.eps = 1e-4
20 |
21 | # It may be useful to run the test on GPU in case there are any issues with
22 | # creating temporary tensors on another device. But turn this off by default.
23 | self.test_on_gpu = False
24 |
25 | self.device = "cuda" if self.test_on_gpu else "cpu"
26 |
27 | self.batch = 4
28 | self.tokens = 5
29 | self.heads = 6
30 | self.d_model = 10
31 |
32 | self.decomposed_attn = torch.rand(
33 | self.batch,
34 | self.tokens,
35 | self.tokens,
36 | self.heads,
37 | self.d_model,
38 | device=self.device,
39 | )
40 | self.mlp_out = torch.rand(
41 | self.batch, self.tokens, self.d_model, device=self.device
42 | )
43 | self.resid_pre = torch.rand(
44 | self.batch, self.tokens, self.d_model, device=self.device
45 | )
46 | self.resid_mid = torch.rand(
47 | self.batch, self.tokens, self.d_model, device=self.device
48 | )
49 | self.resid_post = torch.rand(
50 | self.batch, self.tokens, self.d_model, device=self.device
51 | )
52 |
53 | def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
54 | self.assertTrue(
55 | torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(),
56 | t,
57 | )
58 |
59 | def test_mlp_contributions(self):
60 | mlp_out = torch.tensor([[[1.0, 1.0]]])
61 | resid_mid = torch.tensor([[[0.0, 0.0]]])
62 | resid_post = torch.tensor([[[1.0, 1.0]]])
63 |
64 | c_mlp, c_residual = contributions.get_mlp_contributions(
65 | resid_mid, resid_post, mlp_out
66 | )
67 | self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
68 | self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)
69 |
70 | def test_decomposed_attn_contributions(self):
71 | resid_pre = torch.tensor([[[2.0, 1.0]]])
72 | resid_mid = torch.tensor([[[2.0, 2.0]]])
73 | decomposed_attn = torch.tensor(
74 | [
75 | [
76 | [
77 | [
78 | [1.0, 1.0],
79 | [-1.0, 0.0],
80 | ]
81 | ]
82 | ]
83 | ]
84 | )
85 |
86 | c_attn, c_residual = contributions.get_attention_contributions(
87 | resid_pre, resid_mid, decomposed_attn, distance_norm=2
88 | )
89 | self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
90 | self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)
91 |
92 | def test_decomposed_mlp_contributions(self):
93 | pre = torch.tensor([10.0, 10.0])
94 | post = torch.tensor([-10.0, 10.0])
95 | neuron_impacts = torch.tensor(
96 | [
97 | [0.0, 1.0],
98 | [1.0, 0.0],
99 | [-21.0, -1.0],
100 | ]
101 | )
102 | c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
103 | pre, post, neuron_impacts, distance_norm=2
104 | )
105 | # A bit counter-intuitive, but the only vector pointing from 0 towards the
106 | # output is the first one.
107 | self._assert_tensor_eq(c_mlp, [1, 0, 0])
108 | self.assertAlmostEqual(c_residual, 0, delta=self.eps)
109 |
110 | def test_decomposed_mlp_contributions_single_direction(self):
111 | pre = torch.tensor([1.0, 1.0])
112 | post = torch.tensor([4.0, 4.0])
113 | neuron_impacts = torch.tensor(
114 | [
115 | [1.0, 1.0],
116 | [2.0, 2.0],
117 | ]
118 | )
119 | c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
120 | pre, post, neuron_impacts, distance_norm=2
121 | )
122 | self._assert_tensor_eq(c_mlp, [0.25, 0.5])
123 | self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)
124 |
125 | def test_attention_contributions_shape(self):
126 | c_attn, c_residual = contributions.get_attention_contributions(
127 | self.resid_pre, self.resid_mid, self.decomposed_attn
128 | )
129 | self.assertEqual(
130 | list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
131 | )
132 | self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
133 |
134 | def test_mlp_contributions_shape(self):
135 | c_mlp, c_residual = contributions.get_mlp_contributions(
136 | self.resid_mid, self.resid_post, self.mlp_out
137 | )
138 | self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
139 | self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
140 |
141 | def test_renormalizing_threshold(self):
142 | c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]])
143 | c_residual = torch.Tensor([0.8, 0.9])
144 | norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
145 | 0.1, c_blocks, c_residual
146 | )
147 | self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
148 | self._assert_tensor_eq(norm_residual, [0.842105, 1.0])
149 |
--------------------------------------------------------------------------------
/llm_transparency_tool/models/test_tlens_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import unittest
8 |
9 | import torch
10 |
11 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
12 | from llm_transparency_tool.models.transparent_llm import ModelInfo
13 |
14 |
15 | class TransparentLlmTestCase(unittest.TestCase):
16 | @classmethod
17 | def setUpClass(cls):
18 | # Picking the smallest model possible so that the test runs faster. It's ok to
19 | # change this model, but you'll need to update tokenization specifics in some
20 | # tests.
21 | cls._llm = TransformerLensTransparentLlm(
22 | model_name="facebook/opt-125m",
23 | device="cpu",
24 | )
25 |
26 | def setUp(self):
27 | self._llm.run(["test", "test 1"])
28 | self._eps = 1e-5
29 |
30 | def test_model_info(self):
31 | info = self._llm.model_info()
32 | self.assertEqual(
33 | info,
34 | ModelInfo(
35 | name="facebook/opt-125m",
36 | n_params_estimate=84934656,
37 | n_layers=12,
38 | n_heads=12,
39 | d_model=768,
40 | d_vocab=50272,
41 | ),
42 | )
43 |
44 | def test_tokens(self):
45 | tokens = self._llm.tokens()
46 |
47 | pad = 1
48 | bos = 2
49 | test = 21959
50 | one = 112
51 |
52 | self.assertEqual(tokens.tolist(), [[bos, test, pad], [bos, test, one]])
53 |
54 | def test_tokens_to_strings(self):
55 | s = self._llm.tokens_to_strings(torch.Tensor([2, 21959, 112]).to(torch.int))
56 | self.assertEqual(s, ["</s>", "test", " 1"])
57 |
58 | def test_manage_state(self):
59 | # One llm.run was called at the setup. Call one more and make sure the object
60 | # returns values for the new state.
61 | self._llm.run(["one", "two", "three", "four"])
62 | self.assertEqual(self._llm.tokens().shape[0], 4)
63 |
64 | def test_residual_in_and_out(self):
65 | """
66 | Test that residual_in is a residual_out for the previous layer.
67 | """
68 | for layer in range(1, 12):
69 | prev_residual_out = self._llm.residual_out(layer - 1)
70 | residual_in = self._llm.residual_in(layer)
71 | diff = torch.max(torch.abs(residual_in - prev_residual_out)).item()
72 | self.assertLess(diff, self._eps, f"layer {layer}")
73 |
74 | def test_residual_plus_block(self):
75 | """
76 | Make sure that new residual = old residual + block output. Here, block is an ffn
77 | or attention. It's not that obvious because it could be that layer norm is
78 | applied after the block output, but before saving the result to residual.
79 | Luckily, this is not the case in TransformerLens, and we're relying on that.
80 | """
81 | layer = 3
82 | batch = 0
83 | pos = 0
84 |
85 | residual_in = self._llm.residual_in(layer)[batch][pos]
86 | residual_mid = self._llm.residual_after_attn(layer)[batch][pos]
87 | residual_out = self._llm.residual_out(layer)[batch][pos]
88 | ffn_out = self._llm.ffn_out(layer)[batch][pos]
89 | attn_out = self._llm.attention_output(batch, layer, pos)
90 |
91 | a = residual_mid
92 | b = residual_in + attn_out
93 | diff = torch.max(torch.abs(a - b)).item()
94 | self.assertLess(diff, self._eps, "attn")
95 |
96 | a = residual_out
97 | b = residual_mid + ffn_out
98 | diff = torch.max(torch.abs(a - b)).item()
99 | self.assertLess(diff, self._eps, "ffn")
100 |
101 | def test_tensor_shapes(self):
102 | # Not much we can do about the tensors, but at least check their shapes and
103 | # that they don't contain NaNs.
104 | vocab_size = 50272
105 | n_batch = 2
106 | n_tokens = 3
107 | d_model = 768
108 | d_hidden = d_model * 4
109 | n_heads = 12
110 | layer = 5
111 |
112 | device = self._llm.residual_in(0).device
113 |
114 | for name, tensor, expected_shape in [
115 | ("r_in", self._llm.residual_in(layer), [n_batch, n_tokens, d_model]),
116 | (
117 | "r_mid",
118 | self._llm.residual_after_attn(layer),
119 | [n_batch, n_tokens, d_model],
120 | ),
121 | ("r_out", self._llm.residual_out(layer), [n_batch, n_tokens, d_model]),
122 | ("logits", self._llm.logits(), [n_batch, n_tokens, vocab_size]),
123 | ("ffn_out", self._llm.ffn_out(layer), [n_batch, n_tokens, d_model]),
124 | (
125 | "decomposed_ffn_out",
126 | self._llm.decomposed_ffn_out(0, 0, 0),
127 | [d_hidden, d_model],
128 | ),
129 | ("neuron_activations", self._llm.neuron_activations(0, 0, 0), [d_hidden]),
130 | ("neuron_output", self._llm.neuron_output(0, 0), [d_model]),
131 | (
132 | "attention_matrix",
133 | self._llm.attention_matrix(0, 0, 0),
134 | [n_tokens, n_tokens],
135 | ),
136 | (
137 | "attention_output_per_head",
138 | self._llm.attention_output_per_head(0, 0, 0, 0),
139 | [d_model],
140 | ),
141 | (
142 | "attention_output",
143 | self._llm.attention_output(0, 0, 0),
144 | [d_model],
145 | ),
146 | (
147 | "decomposed_attn",
148 | self._llm.decomposed_attn(0, layer),
149 | [n_tokens, n_tokens, n_heads, d_model],
150 | ),
151 | (
152 | "unembed",
153 | self._llm.unembed(torch.zeros([d_model]).to(device), normalize=True),
154 | [vocab_size],
155 | ),
156 | ]:
157 | self.assertEqual(list(tensor.shape), expected_shape, name)
158 | self.assertFalse(torch.any(tensor.isnan()), name)
159 |
160 |
161 | if __name__ == "__main__":
162 | unittest.main()
163 |
--------------------------------------------------------------------------------
/llm_transparency_tool/models/transparent_llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | from dataclasses import dataclass
9 | from typing import List
10 |
11 | import torch
12 | from jaxtyping import Float, Int
13 |
14 |
15 | @dataclass
16 | class ModelInfo:
17 | name: str
18 |
19 | # Not the actual number of parameters, but rather the order of magnitude
20 | n_params_estimate: int
21 |
22 | n_layers: int
23 | n_heads: int
24 | d_model: int
25 | d_vocab: int
26 |
27 |
28 | class TransparentLlm(ABC):
29 | """
30 | An abstract stateful interface for a language model. The model is supposed to be
31 | loaded at the class initialization.
32 |
33 | The internal state is the resulting tensors from the last call of the `run` method.
34 | Most of the methods could return values based on the state, but some may do cheap
35 | computations based on them.
36 | """
37 |
38 | @abstractmethod
39 | def model_info(self) -> ModelInfo:
40 | """
41 | Gives general info about the model. This method must be available before any
42 | calls of the `run`.
43 | """
44 | pass
45 |
46 | @abstractmethod
47 | def run(self, sentences: List[str]) -> None:
48 | """
49 | Run the inference on the given sentences in a single batch and store all
50 | necessary info in the internal state.
51 | """
52 | pass
53 |
54 | @abstractmethod
55 | def batch_size(self) -> int:
56 | """
57 | The size of the batch that was used for the last call of `run`.
58 | """
59 | pass
60 |
61 | @abstractmethod
62 | def tokens(self) -> Int[torch.Tensor, "batch pos"]:
63 | pass
64 |
65 | @abstractmethod
66 | def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
67 | pass
68 |
69 | @abstractmethod
70 | def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
71 | pass
72 |
73 | @abstractmethod
74 | def unembed(
75 | self,
76 | t: Float[torch.Tensor, "d_model"],
77 | normalize: bool,
78 | ) -> Float[torch.Tensor, "vocab"]:
79 | """
80 | Project the given vector (for example, the state of the residual stream for a
81 | layer and token) into the output vocabulary.
82 |
83 | normalize: whether to apply the final normalization before the unembedding.
84 | Setting it to True and applying to output of the last layer gives the output of
85 | the model.
86 | """
87 | pass
88 |
89 | # ================= Methods related to the residual stream =================
90 |
91 | @abstractmethod
92 | def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
93 | """
94 | The state of the residual stream before entering the layer. For example, when
95 |         layer == 0 these must be the embedded tokens (including positional embedding).
96 | """
97 | pass
98 |
99 | @abstractmethod
100 | def residual_after_attn(
101 | self, layer: int
102 | ) -> Float[torch.Tensor, "batch pos d_model"]:
103 | """
104 | The state of the residual stream after attention, but before the FFN in the
105 | given layer.
106 | """
107 | pass
108 |
109 | @abstractmethod
110 | def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
111 | """
112 | The state of the residual stream after the given layer. This is equivalent to the
113 | next layer's input.
114 | """
115 | pass
116 |
117 | # ================ Methods related to the feed-forward layer ===============
118 |
119 | @abstractmethod
120 | def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
121 | """
122 | The output of the FFN layer, before it gets merged into the residual stream.
123 | """
124 | pass
125 |
126 | @abstractmethod
127 | def decomposed_ffn_out(
128 | self,
129 | batch_i: int,
130 | layer: int,
131 | pos: int,
132 | ) -> Float[torch.Tensor, "hidden d_model"]:
133 | """
134 | A collection of vectors added to the residual stream by each neuron. It should
135 | be the same as neuron activations multiplied by neuron outputs.
136 | """
137 | pass
138 |
139 | @abstractmethod
140 | def neuron_activations(
141 | self,
142 | batch_i: int,
143 | layer: int,
144 | pos: int,
145 | ) -> Float[torch.Tensor, "d_ffn"]:
146 | """
147 | The content of the hidden layer right after the activation function was applied.
148 | """
149 | pass
150 |
151 | @abstractmethod
152 | def neuron_output(
153 | self,
154 | layer: int,
155 | neuron: int,
156 | ) -> Float[torch.Tensor, "d_model"]:
157 | """
158 | Return the value that the given neuron adds to the residual stream. It's a raw
159 | vector from the model parameters, no activation involved.
160 | """
161 | pass
162 |
163 | # ==================== Methods related to the attention ====================
164 |
165 | @abstractmethod
166 | def attention_matrix(
167 | self, batch_i, layer: int, head: int
168 | ) -> Float[torch.Tensor, "query_pos key_pos"]:
169 | """
170 |         Return a lower-triangular attention matrix.
171 | """
172 | pass
173 |
174 | @abstractmethod
175 | def attention_output(
176 | self,
177 | batch_i: int,
178 | layer: int,
179 | pos: int,
180 | head: int,
181 | ) -> Float[torch.Tensor, "d_model"]:
182 | """
183 | Return what the given head at the given layer and pos added to the residual
184 | stream.
185 | """
186 | pass
187 |
188 | @abstractmethod
189 | def decomposed_attn(
190 | self, batch_i: int, layer: int
191 | ) -> Float[torch.Tensor, "source target head d_model"]:
192 | """
193 | Here
194 | - source: index of token from the previous layer
195 | - target: index of token on the current layer
196 |         The decomposed attention tells which vector from the source representation was
197 |         used to contribute to the target representation.
198 | """
199 | pass
200 |
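201 |
202 | if __name__ == "__main__":
203 |     # Illustrative sketch (not part of the interface): a quick check of the `unembed`
204 |     # contract documented above -- with normalize=True, unembedding the residual
205 |     # output of the last layer should reproduce the model's logits. It borrows the
206 |     # TransformerLens-based implementation from this repo; the model and prompt are
207 |     # arbitrary choices for the example.
208 |     from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
209 |
210 |     llm = TransformerLensTransparentLlm(model_name="facebook/opt-125m", device="cpu")
211 |     llm.run(["The capital of France is"])
212 |
213 |     last_layer = llm.model_info().n_layers - 1
214 |     resid = llm.residual_out(last_layer)[0, -1]        # [d_model]
215 |     logits_from_resid = llm.unembed(resid, normalize=True)
216 |     logits_from_model = llm.logits()[0, -1]
217 |
218 |     # Should be close to zero (up to floating point noise).
219 |     print(torch.max(torch.abs(logits_from_resid - logits_from_model)).item())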
--------------------------------------------------------------------------------
/llm_transparency_tool/routes/graph.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import networkx as nx
10 | import torch
11 |
12 | import llm_transparency_tool.routes.contributions as contributions
13 | from llm_transparency_tool.models.transparent_llm import TransparentLlm
14 |
15 |
16 | class GraphBuilder:
17 | """
18 | Constructs the contributions graph with edges given one by one. The resulting graph
19 | is a networkx graph that can be accessed via the `graph` field. It contains the
20 | following types of nodes:
21 |
22 |     - X0_<token>: the original token.
23 |     - A<layer>_<token>: the residual stream after attention at the given layer for the
24 |     given token.
25 |     - M<layer>_<token>: the ffn block.
26 |     - I<layer>_<token>: the residual stream after the ffn block.
27 | """
28 |
29 | def __init__(self, n_layers: int, n_tokens: int):
30 | self._n_layers = n_layers
31 | self._n_tokens = n_tokens
32 |
33 | self.graph = nx.DiGraph()
34 | for layer in range(n_layers):
35 | for token in range(n_tokens):
36 | self.graph.add_node(f"A{layer}_{token}")
37 | self.graph.add_node(f"I{layer}_{token}")
38 | self.graph.add_node(f"M{layer}_{token}")
39 | for token in range(n_tokens):
40 | self.graph.add_node(f"X0_{token}")
41 |
42 | def get_output_node(self, token: int):
43 | return f"I{self._n_layers - 1}_{token}"
44 |
45 | def _add_edge(self, u: str, v: str, weight: float):
46 | # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
47 | # attention from the current token and the residual edge. Ideally these need to
48 | # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
49 | # but when we try to traverse it, we face some NetworkX issue with EDGE_OK
50 | # receiving 3 arguments instead of 2.
51 | if self.graph.has_edge(u, v):
52 | self.graph[u][v]["weight"] += weight
53 | else:
54 | self.graph.add_edge(u, v, weight=weight)
55 |
56 | def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
57 | self._add_edge(
58 | f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
59 | f"A{layer}_{token_to}",
60 | w,
61 | )
62 |
63 | def add_residual_to_attn(self, layer: int, token: int, w: float):
64 | self._add_edge(
65 | f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
66 | f"A{layer}_{token}",
67 | w,
68 | )
69 |
70 | def add_ffn_edge(self, layer: int, token: int, w: float):
71 | self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
72 | self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)
73 |
74 | def add_residual_to_ffn(self, layer: int, token: int, w: float):
75 | self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
76 |
77 |
78 | @torch.no_grad()
79 | def build_full_graph(
80 | model: TransparentLlm,
81 | batch_i: int = 0,
82 | renormalizing_threshold: Optional[float] = None,
83 | ) -> nx.Graph:
84 | """
85 | Build the contribution graph for all blocks of the model and all tokens.
86 |
87 | model: The transparent llm which already did the inference.
88 | batch_i: Which sentence to use from the batch that was given to the model.
89 |     renormalizing_threshold: If specified, thresholding and renormalization will be
90 |     applied to the contributions. All contributions below the threshold will be erased
91 |     and the rest will be renormalized.
92 | """
93 | n_layers = model.model_info().n_layers
94 | n_tokens = model.tokens()[batch_i].shape[0]
95 |
96 | builder = GraphBuilder(n_layers, n_tokens)
97 |
98 | for layer in range(n_layers):
99 | c_attn, c_resid_attn = contributions.get_attention_contributions(
100 | resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
101 | resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
102 | decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
103 | )
104 | if renormalizing_threshold is not None:
105 | c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
106 | renormalizing_threshold, c_attn, c_resid_attn
107 | )
108 | for token_from in range(n_tokens):
109 | for token_to in range(n_tokens):
110 |                 # Sum attention contributions over heads (batch dim is 1 after slicing).
111 |                 c = c_attn[0, token_to, token_from].sum().item()
112 | builder.add_attention_edge(layer, token_from, token_to, c)
113 | for token in range(n_tokens):
114 | builder.add_residual_to_attn(
115 |                 layer, token, c_resid_attn[0, token].item()
116 | )
117 |
118 | c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
119 | resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
120 | resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
121 | mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
122 | )
123 | if renormalizing_threshold is not None:
124 | c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
125 | renormalizing_threshold, c_ffn, c_resid_ffn
126 | )
127 | for token in range(n_tokens):
128 |             builder.add_ffn_edge(layer, token, c_ffn[0, token].item())
129 | builder.add_residual_to_ffn(
130 |                 layer, token, c_resid_ffn[0, token].item()
131 | )
132 |
133 | return builder.graph
134 |
135 |
136 | def build_paths_to_predictions(
137 | graph: nx.Graph,
138 | n_layers: int,
139 | n_tokens: int,
140 | starting_tokens: List[int],
141 | threshold: float,
142 | ) -> List[nx.Graph]:
143 | """
144 | Given the full graph, this function returns only the trees leading to the specified
145 | tokens. Edges with weight below `threshold` will be ignored.
146 | """
147 | builder = GraphBuilder(n_layers, n_tokens)
148 |
149 | rgraph = graph.reverse()
150 | search_graph = nx.subgraph_view(
151 | rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
152 | )
153 |
154 | result = []
155 | for start in starting_tokens:
156 | assert start < n_tokens
157 | assert start >= 0
158 | edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
159 | tree = search_graph.edge_subgraph(edges)
160 | # Reverse the edges because the dfs was going from upper layer downwards.
161 | result.append(tree.reverse())
162 |
163 | return result
164 |
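165 |
166 | if __name__ == "__main__":
167 |     # Illustrative sketch (not used by the app): how `build_full_graph` and
168 |     # `build_paths_to_predictions` fit together. Any TransparentLlm implementation
169 |     # that has already run inference would do; the model, prompt and thresholds here
170 |     # are arbitrary choices for the example.
171 |     from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
172 |
173 |     llm = TransformerLensTransparentLlm(model_name="facebook/opt-125m", device="cpu")
174 |     llm.run(["The capital of France is"])
175 |
176 |     n_layers = llm.model_info().n_layers
177 |     n_tokens = llm.tokens()[0].shape[0]
178 |
179 |     # Full graph over the X0_<token>, A<layer>_<token>, M<layer>_<token> and
180 |     # I<layer>_<token> nodes, with small contributions zeroed out and renormalized.
181 |     full_graph = build_full_graph(llm, batch_i=0, renormalizing_threshold=0.04)
182 |
183 |     # Keep only the edges that (transitively) feed the last token's output node.
184 |     trees = build_paths_to_predictions(
185 |         full_graph, n_layers, n_tokens, starting_tokens=[n_tokens - 1], threshold=0.04
186 |     )
187 |     print(f"{trees[0].number_of_edges()} edges lead to the last token's prediction")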
--------------------------------------------------------------------------------
/llm_transparency_tool/routes/contributions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Tuple
8 |
9 | import einops
10 | import torch
11 | from jaxtyping import Float
12 | from typeguard import typechecked
13 |
14 |
15 | @torch.no_grad()
16 | @typechecked
17 | def get_contributions(
18 | parts: torch.Tensor,
19 | whole: torch.Tensor,
20 | distance_norm: int = 1,
21 | ) -> torch.Tensor:
22 | """
23 | Compute contributions of the `parts` vectors into the `whole` vector.
24 |
25 | Shapes of the tensors are as follows:
26 | parts: p_1 ... p_k, v_1 ... v_n, d
27 | whole: v_1 ... v_n, d
28 | result: p_1 ... p_k, v_1 ... v_n
29 |
30 | Here
31 | * `p_1 ... p_k`: dimensions for enumerating the parts
32 | * `v_1 ... v_n`: dimensions listing the independent cases (batching),
33 | * `d` is the dimension to compute the distances on.
34 |
35 | The resulting contributions will be normalized so that
36 | for each v_: sum(over p_ of result(p_, v_)) = 1.
37 | """
38 | EPS = 1e-5
39 |
40 | k = len(parts.shape) - len(whole.shape)
41 | assert k >= 0
42 | assert parts.shape[k:] == whole.shape
43 | bc_whole = whole.expand(parts.shape) # new dims p_1 ... p_k are added to the front
44 |
45 | distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)
46 |
47 | whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
48 | distance = (whole_norm - distance).clip(min=EPS)
49 |
50 | sum = distance.sum(dim=tuple(range(k)), keepdim=True)
51 |
52 | return distance / sum
53 |
54 |
55 | @torch.no_grad()
56 | @typechecked
57 | def get_contributions_with_one_off_part(
58 | parts: torch.Tensor,
59 | one_off: torch.Tensor,
60 | whole: torch.Tensor,
61 | distance_norm: int = 1,
62 | ) -> Tuple[torch.Tensor, torch.Tensor]:
63 | """
64 | Same as computing the contributions, but there is one additional part. That's useful
65 | because we always have the residual stream as one of the parts.
66 |
67 | See `get_contributions` documentation about `parts` and `whole` dimensions. The
68 | `one_off` should have the same dimensions as `whole`.
69 |
70 | Returns a pair consisting of
71 | 1. contributions tensor for the `parts`
72 | 2. contributions tensor for the `one_off` vector
73 | """
74 | assert one_off.shape == whole.shape
75 |
76 | k = len(parts.shape) - len(whole.shape)
77 | assert k >= 0
78 |
79 | # Flatten the p_ dimensions, get contributions for the list, unflatten.
80 | flat = parts.flatten(start_dim=0, end_dim=k - 1)
81 | flat = torch.cat([flat, one_off.unsqueeze(0)])
82 | contributions = get_contributions(flat, whole, distance_norm)
83 | parts_contributions, one_off_contributions = torch.split(
84 | contributions, flat.shape[0] - 1
85 | )
86 | return (
87 | parts_contributions.unflatten(0, parts.shape[0:k]),
88 | one_off_contributions[0],
89 | )
90 |
91 |
92 | @torch.no_grad()
93 | @typechecked
94 | def get_attention_contributions(
95 | resid_pre: Float[torch.Tensor, "batch pos d_model"],
96 | resid_mid: Float[torch.Tensor, "batch pos d_model"],
97 | decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
98 | distance_norm: int = 1,
99 | ) -> Tuple[
100 | Float[torch.Tensor, "batch pos key_pos head"],
101 | Float[torch.Tensor, "batch pos"],
102 | ]:
103 | """
104 | Returns a pair of
105 | - a tensor of contributions of each token via each head
106 | - the contribution of the residual stream.
107 | """
108 |
109 | # part dimensions | batch dimensions | vector dimension
110 | # ----------------+------------------+-----------------
111 | # key_pos, head | batch, pos | d_model
112 | parts = einops.rearrange(
113 | decomposed_attn,
114 | "batch pos key_pos head d_model -> key_pos head batch pos d_model",
115 | )
116 | attn_contribution, residual_contribution = get_contributions_with_one_off_part(
117 | parts, resid_pre, resid_mid, distance_norm
118 | )
119 | return (
120 | einops.rearrange(
121 | attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
122 | ),
123 | residual_contribution,
124 | )
125 |
126 |
127 | @torch.no_grad()
128 | @typechecked
129 | def get_mlp_contributions(
130 | resid_mid: Float[torch.Tensor, "batch pos d_model"],
131 | resid_post: Float[torch.Tensor, "batch pos d_model"],
132 | mlp_out: Float[torch.Tensor, "batch pos d_model"],
133 | distance_norm: int = 1,
134 | ) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
135 | """
136 | Returns a pair of (mlp, residual) contributions for each sentence and token.
137 | """
138 |
139 | contributions = get_contributions(
140 | torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
141 | )
142 | return contributions[0], contributions[1]
143 |
144 |
145 | @torch.no_grad()
146 | @typechecked
147 | def get_decomposed_mlp_contributions(
148 | resid_mid: Float[torch.Tensor, "d_model"],
149 | resid_post: Float[torch.Tensor, "d_model"],
150 | decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
151 | distance_norm: int = 1,
152 | ) -> Tuple[Float[torch.Tensor, "hidden"], float]:
153 | """
154 | Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
155 | the hidden layer and thus computes a contribution per neuron.
156 |
157 |     Doesn't contain batch and token dimensions, for the sake of saving memory. But we may
158 | consider adding them.
159 | """
160 |
161 | neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
162 | decomposed_mlp_out, resid_mid, resid_post, distance_norm
163 | )
164 | return neuron_contributions, residual_contribution.item()
165 |
166 |
167 | @torch.no_grad()
168 | def apply_threshold_and_renormalize(
169 | threshold: float,
170 | c_blocks: torch.Tensor,
171 | c_residual: torch.Tensor,
172 | ) -> Tuple[torch.Tensor, torch.Tensor]:
173 | """
174 | Thresholding mechanism used in the original graphs paper. After the threshold is
175 |     applied, the remaining contributions are renormalized in order to sum up to 1 for
176 | each representation.
177 |
178 | threshold: The threshold.
179 | c_residual: Contribution of the residual stream for each representation. This tensor
180 | should contain 1 element per representation, i.e., its dimensions are all batch
181 | dimensions.
182 | c_blocks: Contributions of the blocks. Could be 1 block per representation, like
183 | ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
184 | must be a prefix if the shape of this tensor. The remaining dimensions are for
185 |         must be a prefix of the shape of this tensor. The remaining dimensions are for
186 | """
187 |
188 | block_dims = len(c_blocks.shape)
189 | resid_dims = len(c_residual.shape)
190 | bound_dims = block_dims - resid_dims
191 | assert bound_dims >= 0
192 | assert c_blocks.shape[0:resid_dims] == c_residual.shape
193 |
194 | c_blocks = c_blocks * (c_blocks > threshold)
195 | c_residual = c_residual * (c_residual > threshold)
196 |
197 | denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
198 | return (
199 | c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
200 | c_residual / denom,
201 | )
202 |
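203 |
204 | if __name__ == "__main__":
205 |     # Illustrative sketch (not used by the library): a toy check of the contribution
206 |     # semantics documented above. All numbers are arbitrary.
207 |     torch.manual_seed(0)
208 |
209 |     # 3 parts, 2 independent cases, vectors of dimension 4.
210 |     parts = torch.randn(3, 2, 4)
211 |     whole = parts.sum(dim=0)
212 |
213 |     c = get_contributions(parts, whole)
214 |     print(c.shape)       # torch.Size([3, 2])
215 |     print(c.sum(dim=0))  # tensor([1., 1.]) -- contributions sum to 1 for each case
216 |
217 |     # Thresholding: contributions below the threshold are zeroed out and the rest is
218 |     # renormalized so that blocks + residual still sum to 1 per representation.
219 |     c_blocks = torch.tensor([[0.02, 0.48], [0.10, 0.30]])
220 |     c_resid = torch.tensor([0.50, 0.60])
221 |     b, r = apply_threshold_and_renormalize(0.05, c_blocks, c_resid)
222 |     print(b.sum(dim=-1) + r)  # tensor([1., 1.])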
--------------------------------------------------------------------------------
/llm_transparency_tool/models/tlens_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 | from typing import List, Optional
9 |
10 | import torch
11 | import transformer_lens
12 | import transformers
13 | from fancy_einsum import einsum
14 | from jaxtyping import Float, Int
15 | from typeguard import typechecked
16 | import streamlit as st
17 |
18 | from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm
19 | from transformer_lens.loading_from_pretrained import MODEL_ALIASES, get_official_model_name
20 |
21 |
22 | @dataclass
23 | class _RunInfo:
24 | tokens: Int[torch.Tensor, "batch pos"]
25 | logits: Float[torch.Tensor, "batch pos d_vocab"]
26 | cache: transformer_lens.ActivationCache
27 |
28 |
29 | @st.cache_resource(
30 | max_entries=1,
31 | show_spinner=True,
32 | hash_funcs={
33 | transformers.PreTrainedModel: id,
34 | transformers.PreTrainedTokenizer: id
35 | }
36 | )
37 | def load_hooked_transformer(
38 | model_name: str,
39 | hf_model: Optional[transformers.PreTrainedModel] = None,
40 | tlens_device: str = "cuda",
41 | dtype: torch.dtype = torch.float32,
42 | supported_model_name: Optional[str] = None,
43 | ):
44 | if supported_model_name is None:
45 | supported_model_name = model_name
46 | supported_model_name = get_official_model_name(supported_model_name)
47 | if model_name not in MODEL_ALIASES:
48 | MODEL_ALIASES[supported_model_name] = []
49 | if model_name not in MODEL_ALIASES[supported_model_name]:
50 | MODEL_ALIASES[supported_model_name].append(model_name)
51 | tlens_model = transformer_lens.HookedTransformer.from_pretrained(
52 | model_name,
53 | hf_model=hf_model,
54 | fold_ln=False, # Keep layer norm where it is.
55 | center_writing_weights=False,
56 | center_unembed=False,
57 | device=tlens_device,
58 | dtype=dtype,
59 | )
60 | tlens_model.eval()
61 | return tlens_model
62 |
63 |
64 | # TODO(igortufanov): If we want to scale the app to multiple users, we need a more
65 | # careful, thread-safe implementation. The simplest option could be to wrap the
66 | # existing methods in mutexes.
67 | class TransformerLensTransparentLlm(TransparentLlm):
68 | """
69 | Implementation of Transparent LLM based on transformer lens.
70 |
71 | Args:
72 | - model_name: The official name of the model from HuggingFace. Even if the model was
73 | patched or loaded locally, the name should still be official because that's how
74 | transformer_lens treats the model.
75 | - hf_model: The language model as a HuggingFace class.
76 |     - tokenizer: The tokenizer as a HuggingFace class.
77 | - device: "gpu" or "cpu"
78 | """
79 |
80 | def __init__(
81 | self,
82 | model_name: str,
83 | hf_model: Optional[transformers.PreTrainedModel] = None,
84 | tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
85 | device: str = "gpu",
86 | dtype: torch.dtype = torch.float32,
87 |         supported_model_name: Optional[str] = None,
88 | ):
89 | if device == "gpu":
90 | self.device = "cuda"
91 | if not torch.cuda.is_available():
92 |                 raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
93 | elif device == "cpu":
94 | self.device = "cpu"
95 | else:
96 | raise RuntimeError(f"Specified device {device} is not a valid option")
97 |
98 | self.dtype = dtype
99 | self.hf_tokenizer = tokenizer
100 | self.hf_model = hf_model
101 |
102 | # self._model = tlens_model
103 | self._model_name = model_name
104 | self._supported_model_name = supported_model_name
105 | self._prepend_bos = True
106 | self._last_run = None
107 | self._run_exception = RuntimeError(
108 | "Tried to use the model output before calling the `run` method"
109 | )
110 |
111 | def copy(self):
112 | import copy
113 | return copy.copy(self)
114 |
115 | @property
116 | def _model(self):
117 | tlens_model = load_hooked_transformer(
118 | self._model_name,
119 | hf_model=self.hf_model,
120 | tlens_device=self.device,
121 | dtype=self.dtype,
122 | supported_model_name=self._supported_model_name,
123 | )
124 |
125 | if self.hf_tokenizer is not None:
126 | tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
127 |
128 | tlens_model.set_use_attn_result(True)
129 | tlens_model.set_use_attn_in(False)
130 | tlens_model.set_use_split_qkv_input(False)
131 |
132 | return tlens_model
133 |
134 | def model_info(self) -> ModelInfo:
135 | cfg = self._model.cfg
136 | return ModelInfo(
137 | name=self._model_name,
138 | n_params_estimate=cfg.n_params,
139 | n_layers=cfg.n_layers,
140 | n_heads=cfg.n_heads,
141 | d_model=cfg.d_model,
142 | d_vocab=cfg.d_vocab,
143 | )
144 |
145 | @torch.no_grad()
146 | def run(self, sentences: List[str]) -> None:
147 | tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
148 | logits, cache = self._model.run_with_cache(tokens)
149 |
150 | self._last_run = _RunInfo(
151 | tokens=tokens,
152 | logits=logits,
153 | cache=cache,
154 | )
155 |
156 | def batch_size(self) -> int:
157 | if not self._last_run:
158 | raise self._run_exception
159 | return self._last_run.logits.shape[0]
160 |
161 | @typechecked
162 | def tokens(self) -> Int[torch.Tensor, "batch pos"]:
163 | if not self._last_run:
164 | raise self._run_exception
165 | return self._last_run.tokens
166 |
167 | @typechecked
168 | def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
169 | return self._model.to_str_tokens(tokens)
170 |
171 | @typechecked
172 | def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
173 | if not self._last_run:
174 | raise self._run_exception
175 | return self._last_run.logits
176 |
177 | @torch.no_grad()
178 | @typechecked
179 | def unembed(
180 | self,
181 | t: Float[torch.Tensor, "d_model"],
182 | normalize: bool,
183 | ) -> Float[torch.Tensor, "vocab"]:
184 | # t: [d_model] -> [batch, pos, d_model]
185 | tdim = t.unsqueeze(0).unsqueeze(0)
186 | if normalize:
187 | normalized = self._model.ln_final(tdim)
188 | result = self._model.unembed(normalized)
189 | else:
190 | result = self._model.unembed(tdim.to(self.dtype))
191 | return result[0][0]
192 |
193 | def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
194 | if not self._last_run:
195 | raise self._run_exception
196 | return self._last_run.cache[f"blocks.{layer}.{block_name}"]
197 |
198 | # ================= Methods related to the residual stream =================
199 |
200 | @typechecked
201 | def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
202 | if not self._last_run:
203 | raise self._run_exception
204 | return self._get_block(layer, "hook_resid_pre")
205 |
206 | @typechecked
207 | def residual_after_attn(
208 | self, layer: int
209 | ) -> Float[torch.Tensor, "batch pos d_model"]:
210 | if not self._last_run:
211 | raise self._run_exception
212 | return self._get_block(layer, "hook_resid_mid")
213 |
214 | @typechecked
215 | def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
216 | if not self._last_run:
217 | raise self._run_exception
218 | return self._get_block(layer, "hook_resid_post")
219 |
220 | # ================ Methods related to the feed-forward layer ===============
221 |
222 | @typechecked
223 | def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
224 | if not self._last_run:
225 | raise self._run_exception
226 | return self._get_block(layer, "hook_mlp_out")
227 |
228 | @torch.no_grad()
229 | @typechecked
230 | def decomposed_ffn_out(
231 | self,
232 | batch_i: int,
233 | layer: int,
234 | pos: int,
235 | ) -> Float[torch.Tensor, "hidden d_model"]:
236 | # Take activations right before they're multiplied by W_out, i.e. non-linearity
237 | # and layer norm are already applied.
238 | processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
239 | return torch.mul(processed_activations.unsqueeze(-1), self._model.blocks[layer].mlp.W_out)
240 |
241 | @typechecked
242 | def neuron_activations(
243 | self,
244 | batch_i: int,
245 | layer: int,
246 | pos: int,
247 | ) -> Float[torch.Tensor, "hidden"]:
248 | return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]
249 |
250 | @typechecked
251 | def neuron_output(
252 | self,
253 | layer: int,
254 | neuron: int,
255 | ) -> Float[torch.Tensor, "d_model"]:
256 | return self._model.blocks[layer].mlp.W_out[neuron]
257 |
258 | # ==================== Methods related to the attention ====================
259 |
260 | @typechecked
261 | def attention_matrix(
262 | self, batch_i: int, layer: int, head: int
263 | ) -> Float[torch.Tensor, "query_pos key_pos"]:
264 | return self._get_block(layer, "attn.hook_pattern")[batch_i][head]
265 |
266 | @typechecked
267 | def attention_output_per_head(
268 | self,
269 | batch_i: int,
270 | layer: int,
271 | pos: int,
272 | head: int,
273 | ) -> Float[torch.Tensor, "d_model"]:
274 | return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]
275 |
276 | @typechecked
277 | def attention_output(
278 | self,
279 | batch_i: int,
280 | layer: int,
281 | pos: int,
282 | ) -> Float[torch.Tensor, "d_model"]:
283 | return self._get_block(layer, "hook_attn_out")[batch_i][pos]
284 |
285 | @torch.no_grad()
286 | @typechecked
287 | def decomposed_attn(
288 | self, batch_i: int, layer: int
289 | ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
290 | if not self._last_run:
291 | raise self._run_exception
292 | hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
293 | b_v = self._model.blocks[layer].attn.b_V
294 |
295 | # support for gqa
296 | num_head_groups = b_v.shape[-2] // hook_v.shape[-2]
297 | hook_v = hook_v.repeat_interleave(num_head_groups, dim=-2)
298 |
299 | v = hook_v + b_v
300 | pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
301 | z = einsum(
302 | "key_pos head d_head, "
303 | "head query_pos key_pos -> "
304 | "query_pos key_pos head d_head",
305 | v,
306 | pattern,
307 | )
308 | decomposed_attn = einsum(
309 | "pos key_pos head d_head, "
310 | "head d_head d_model -> "
311 | "pos key_pos head d_model",
312 | z,
313 | self._model.blocks[layer].attn.W_O,
314 | )
315 | return decomposed_attn
316 |
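317 |
318 | if __name__ == "__main__":
319 |     # Illustrative sketch (not used by the app): load a small model on CPU, run one
320 |     # prompt and look at a couple of the exposed tensors. The model name and prompt
321 |     # are arbitrary choices for the example.
322 |     llm = TransformerLensTransparentLlm(model_name="facebook/opt-125m", device="cpu")
323 |     llm.run(["The capital of France is"])
324 |
325 |     # Next-token prediction for the (single) sentence in the batch.
326 |     next_token_id = llm.logits()[0, -1].argmax().item()
327 |     print(llm.tokens_to_strings(torch.tensor([next_token_id]))[0])
328 |
329 |     # Attention pattern of one head: rows sum to 1 and, for a causal model, the
330 |     # matrix is lower-triangular.
331 |     pattern = llm.attention_matrix(batch_i=0, layer=0, head=0)
332 |     print(pattern.shape, pattern.sum(dim=-1))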
--------------------------------------------------------------------------------
/llm_transparency_tool/components/frontend/src/ContributionGraph.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * All rights reserved.
4 | *
5 | * This source code is licensed under the license found in the
6 | * LICENSE file in the root directory of this source tree.
7 | */
8 |
9 | import {
10 | ComponentProps,
11 | Streamlit,
12 | withStreamlitConnection,
13 | } from 'streamlit-component-lib'
14 | import React, { useEffect, useMemo, useRef, useState } from 'react';
15 | import * as d3 from 'd3';
16 |
17 | import {
18 | Label,
19 | Point,
20 | } from './common';
21 | import './LlmViewer.css';
22 |
23 | export const renderParams = {
24 | cellH: 32,
25 | cellW: 32,
26 | attnSize: 8,
27 | afterFfnSize: 8,
28 | ffnSize: 6,
29 | tokenSelectorSize: 16,
30 | layerCornerRadius: 6,
31 | }
32 |
33 | interface Cell {
34 | layer: number
35 | token: number
36 | }
37 |
38 | enum CellItem {
39 | AfterAttn = 'after_attn',
40 | AfterFfn = 'after_ffn',
41 | Ffn = 'ffn',
42 | Original = 'original', // They will only be at level = 0
43 | }
44 |
45 | interface Node {
46 | cell: Cell | null
47 | item: CellItem | null
48 | }
49 |
50 | interface NodeProps {
51 | node: Node
52 | pos: Point
53 | isActive: boolean
54 | }
55 |
56 | interface EdgeRaw {
57 | weight: number
58 | source: string
59 | target: string
60 | }
61 |
62 | interface Edge {
63 | weight: number
64 | from: Node
65 | to: Node
66 | fromPos: Point
67 | toPos: Point
68 | isSelectable: boolean
69 | isFfn: boolean
70 | }
71 |
72 | interface Selection {
73 | node: Node | null
74 | edge: Edge | null
75 | }
76 |
77 | function tokenPointerPolygon(origin: Point) {
78 | const r = renderParams.tokenSelectorSize / 2
79 | const dy = r / 2
80 | const dx = r * Math.sqrt(3.0) / 2
81 | // Draw an arrow looking down
82 | return [
83 | [origin.x, origin.y + r],
84 | [origin.x + dx, origin.y - dy],
85 | [origin.x - dx, origin.y - dy],
86 | ].toString()
87 | }
88 |
89 | function isSameCell(cell1: Cell | null, cell2: Cell | null) {
90 | if (cell1 == null || cell2 == null) {
91 | return false
92 | }
93 | return cell1.layer === cell2.layer && cell1.token === cell2.token
94 | }
95 |
96 | function isSameNode(node1: Node | null, node2: Node | null) {
97 | if (node1 === null || node2 === null) {
98 | return false
99 | }
100 | return isSameCell(node1.cell, node2.cell)
101 | && node1.item === node2.item;
102 | }
103 |
104 | function isSameEdge(edge1: Edge | null, edge2: Edge | null) {
105 | if (edge1 === null || edge2 === null) {
106 | return false
107 | }
108 | return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to);
109 | }
110 |
111 | function nodeFromString(name: string) {
112 | const match = name.match(/([AIMX])(\d+)_(\d+)/)
113 | if (match == null) {
114 | return {
115 | cell: null,
116 | item: null,
117 | }
118 | }
119 | const [, type, layerStr, tokenStr] = match
120 | const layer = +layerStr
121 | const token = +tokenStr
122 |
123 | const typeToCellItem = new Map([
124 | ['A', CellItem.AfterAttn],
125 | ['I', CellItem.AfterFfn],
126 | ['M', CellItem.Ffn],
127 | ['X', CellItem.Original],
128 | ])
129 | return {
130 | cell: {
131 | layer: layer,
132 | token: token,
133 | },
134 | item: typeToCellItem.get(type) ?? null,
135 | }
136 | }
137 |
138 | function isValidNode(node: Node, nLayers: number, nTokens: number) {
139 | if (node.cell === null) {
140 | return true
141 | }
142 | return node.cell.layer < nLayers && node.cell.token < nTokens
143 | }
144 |
145 | function isValidSelection(selection: Selection, nLayers: number, nTokens: number) {
146 | if (selection.node !== null) {
147 | return isValidNode(selection.node, nLayers, nTokens)
148 | }
149 | if (selection.edge !== null) {
150 | return isValidNode(selection.edge.from, nLayers, nTokens) &&
151 | isValidNode(selection.edge.to, nLayers, nTokens)
152 | }
153 | return true
154 | }
155 |
156 | const ContributionGraph = ({ args }: ComponentProps) => {
157 | const modelInfo = args['model_info']
158 | const tokens = args['tokens']
159 | const edgesRaw: EdgeRaw[][] = args['edges_per_token']
160 |
161 | const nLayers = modelInfo === null ? 0 : modelInfo.n_layers
162 | const nTokens = tokens === null ? 0 : tokens.length
163 |
164 |   const [selection, setSelection] = useState<Selection>({
165 | node: null,
166 | edge: null,
167 | })
168 | var curSelection = selection
169 | if (!isValidSelection(selection, nLayers, nTokens)) {
170 | curSelection = {
171 | node: null,
172 | edge: null,
173 | }
174 | setSelection(curSelection)
175 | Streamlit.setComponentValue(curSelection)
176 | }
177 |
178 | const [startToken, setStartToken] = useState(nTokens - 1)
179 | // We have startToken state var, but it won't be updated till next render, so use
180 | // this var in the current render.
181 | var curStartToken = startToken
182 | if (startToken >= nTokens) {
183 | curStartToken = nTokens - 1
184 | setStartToken(curStartToken)
185 | }
186 |
187 | const handleRepresentationClick = (node: Node) => {
188 | const newSelection: Selection = {
189 | node: node,
190 | edge: null,
191 | }
192 | setSelection(newSelection)
193 | Streamlit.setComponentValue(newSelection)
194 | }
195 |
196 | const handleEdgeClick = (edge: Edge) => {
197 | if (!edge.isSelectable) {
198 | return
199 | }
200 | const newSelection: Selection = {
201 | node: edge.to,
202 | edge: edge,
203 | }
204 | setSelection(newSelection)
205 | Streamlit.setComponentValue(newSelection)
206 | }
207 |
208 | const handleTokenClick = (t: number) => {
209 | setStartToken(t)
210 | }
211 |
212 | const [xScale, yScale] = useMemo(() => {
213 | const x = d3.scaleLinear()
214 | .domain([-2, nTokens - 1])
215 | .range([0, renderParams.cellW * (nTokens + 2)])
216 | const y = d3.scaleLinear()
217 | .domain([-1, nLayers])
218 | .range([renderParams.cellH * (nLayers + 2), 0])
219 | return [x, y]
220 | }, [nLayers, nTokens])
221 |
222 | const cells = useMemo(() => {
223 | let result: Cell[] = []
224 | for (let l = 0; l < nLayers; l++) {
225 | for (let t = 0; t < nTokens; t++) {
226 | result.push({
227 | layer: l,
228 | token: t,
229 | })
230 | }
231 | }
232 | return result
233 | }, [nLayers, nTokens])
234 |
235 | const nodeCoords = useMemo(() => {
236 | let result = new Map()
237 | const w = renderParams.cellW
238 | const h = renderParams.cellH
239 | for (var cell of cells) {
240 | const cx = xScale(cell.token + 0.5)
241 | const cy = yScale(cell.layer - 0.5)
242 | result.set(
243 | JSON.stringify({ cell: cell, item: CellItem.AfterAttn }),
244 | { x: cx, y: cy + h / 4 },
245 | )
246 | result.set(
247 | JSON.stringify({ cell: cell, item: CellItem.AfterFfn }),
248 | { x: cx, y: cy - h / 4 },
249 | )
250 | result.set(
251 | JSON.stringify({ cell: cell, item: CellItem.Ffn }),
252 | { x: cx + 5 * w / 16, y: cy },
253 | )
254 | }
255 | for (let t = 0; t < nTokens; t++) {
256 | cell = {
257 | layer: 0,
258 | token: t,
259 | }
260 | const cx = xScale(cell.token + 0.5)
261 | const cy = yScale(cell.layer - 1.0)
262 | result.set(
263 | JSON.stringify({ cell: cell, item: CellItem.Original }),
264 | { x: cx, y: cy + h / 4 },
265 | )
266 | }
267 | return result
268 | }, [cells, nTokens, xScale, yScale])
269 |
270 | const edges: Edge[][] = useMemo(() => {
271 | let result = []
272 | for (var edgeList of edgesRaw) {
273 | let edgesPerStartToken = []
274 | for (var edge of edgeList) {
275 | const u = nodeFromString(edge.source)
276 | const v = nodeFromString(edge.target)
277 | var isSelectable = (
278 | u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn
279 | )
280 | var isFfn = (
281 | u.cell !== null && v.cell !== null && (
282 | u.item === CellItem.Ffn || v.item === CellItem.Ffn
283 | )
284 | )
285 | edgesPerStartToken.push({
286 | weight: edge.weight,
287 | from: u,
288 | to: v,
289 | fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 },
290 | toPos: nodeCoords.get(JSON.stringify(v)) ?? { 'x': 0, 'y': 0 },
291 | isSelectable: isSelectable,
292 | isFfn: isFfn,
293 | })
294 | }
295 | result.push(edgesPerStartToken)
296 | }
297 | return result
298 | }, [edgesRaw, nodeCoords])
299 |
300 | const activeNodes = useMemo(() => {
301 | let result = new Set()
302 | for (var edge of edges[curStartToken]) {
303 | const u = JSON.stringify(edge.from)
304 | const v = JSON.stringify(edge.to)
305 | result.add(u)
306 | result.add(v)
307 | }
308 | return result
309 | }, [edges, curStartToken])
310 |
311 | const nodeProps = useMemo(() => {
312 |     let result: Array<NodeProps> = []
313 | nodeCoords.forEach((p: Point, node: string) => {
314 | result.push({
315 | node: JSON.parse(node),
316 | pos: p,
317 | isActive: activeNodes.has(node),
318 | })
319 | })
320 | return result
321 | }, [nodeCoords, activeNodes])
322 |
323 | const tokenLabels: Label[] = useMemo(() => {
324 | if (!tokens) {
325 | return []
326 | }
327 | return tokens.map((s: string, i: number) => ({
328 | text: s.replace(/ /g, '·'),
329 | pos: {
330 | x: xScale(i + 0.5),
331 | y: yScale(-1.5),
332 | },
333 | }))
334 | }, [tokens, xScale, yScale])
335 |
336 | const layerLabels: Label[] = useMemo(() => {
337 | return Array.from(Array(nLayers).keys()).map(i => ({
338 | text: 'L' + i,
339 | pos: {
340 | x: xScale(-0.25),
341 | y: yScale(i - 0.5),
342 | },
343 | }))
344 | }, [nLayers, xScale, yScale])
345 |
346 | const tokenSelectors: Array<[number, Point]> = useMemo(() => {
347 | return Array.from(Array(nTokens).keys()).map(i => ([
348 | i,
349 | {
350 | x: xScale(i + 0.5),
351 | y: yScale(nLayers - 0.5),
352 | }
353 | ]))
354 | }, [nTokens, nLayers, xScale, yScale])
355 |
356 | const totalW = xScale(nTokens + 2)
357 | const totalH = yScale(-4)
358 | useEffect(() => {
359 | Streamlit.setFrameHeight(totalH)
360 | }, [totalH])
361 |
362 | const colorScale = d3.scaleLinear(
363 | [0.0, 0.5, 1.0],
364 | ['#9eba66', 'darkolivegreen', 'darkolivegreen']
365 | )
366 | const ffnEdgeColorScale = d3.scaleLinear(
367 | [0.0, 0.5, 1.0],
368 | ['orchid', 'purple', 'purple']
369 | )
370 | const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])
371 |
372 |   const svgRef = useRef<SVGSVGElement | null>(null);
373 |
374 | useEffect(() => {
375 | const getNodeStyle = (p: NodeProps, type: string) => {
376 | if (isSameNode(p.node, curSelection.node)) {
377 | return 'selectable-item selection'
378 | }
379 | if (p.isActive) {
380 | return 'selectable-item active-' + type + '-node'
381 | }
382 | return 'selectable-item inactive-node'
383 | }
384 |
385 | const svg = d3.select(svgRef.current)
386 | svg.selectAll('*').remove()
387 |
388 | svg
389 | .selectAll('layers')
390 | .data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
391 | .enter()
392 | .append('rect')
393 | .attr('class', 'layer-highlight')
394 | .attr('x', xScale(-1.0))
395 | .attr('y', (layer) => yScale(layer))
396 | .attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
397 | .attr('height', (layer) => yScale(layer) - yScale(layer + 1))
398 | .attr('rx', renderParams.layerCornerRadius)
399 |
400 | svg
401 | .selectAll('edges')
402 | .data(edges[curStartToken])
403 | .enter()
404 | .append('line')
405 | .style('stroke', (edge: Edge) => {
406 | if (isSameEdge(edge, curSelection.edge)) {
407 | return 'orange'
408 | }
409 | if (edge.isFfn) {
410 | return ffnEdgeColorScale(edge.weight)
411 | }
412 | return colorScale(edge.weight)
413 | })
414 | .attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
415 | .style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
416 | .attr('x1', (edge: Edge) => edge.fromPos.x)
417 | .attr('y1', (edge: Edge) => edge.fromPos.y)
418 | .attr('x2', (edge: Edge) => edge.toPos.x)
419 | .attr('y2', (edge: Edge) => edge.toPos.y)
420 | .on('click', (event: PointerEvent, edge) => {
421 | handleEdgeClick(edge)
422 | })
423 |
424 | svg
425 | .selectAll('residual')
426 | .data(nodeProps)
427 | .enter()
428 | .filter((p) => {
429 | return p.node.item === CellItem.AfterAttn
430 | || p.node.item === CellItem.AfterFfn
431 | })
432 | .append('circle')
433 | .attr('class', (p) => getNodeStyle(p, 'residual'))
434 | .attr('cx', (p) => p.pos.x)
435 | .attr('cy', (p) => p.pos.y)
436 | .attr('r', renderParams.attnSize / 2)
437 | .on('click', (event: PointerEvent, p) => {
438 | handleRepresentationClick(p.node)
439 | })
440 |
441 | svg
442 | .selectAll('ffn')
443 | .data(nodeProps)
444 | .enter()
445 | .filter((p) => p.node.item === CellItem.Ffn && p.isActive)
446 | .append('rect')
447 | .attr('class', (p) => getNodeStyle(p, 'ffn'))
448 | .attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
449 | .attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
450 | .attr('width', renderParams.ffnSize)
451 | .attr('height', renderParams.ffnSize)
452 | .on('click', (event: PointerEvent, p) => {
453 | handleRepresentationClick(p.node)
454 | })
455 |
456 | svg
457 | .selectAll('token_labels')
458 | .data(tokenLabels)
459 | .enter()
460 | .append('text')
461 | .attr('x', (label: Label) => label.pos.x)
462 | .attr('y', (label: Label) => label.pos.y)
463 | .attr('text-anchor', 'end')
464 | .attr('dominant-baseline', 'middle')
465 | .attr('alignment-baseline', 'top')
466 | .attr('transform', (label: Label) =>
467 | 'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
468 | .text((label: Label) => label.text)
469 |
470 | svg
471 | .selectAll('layer_labels')
472 | .data(layerLabels)
473 | .enter()
474 | .append('text')
475 | .attr('x', (label: Label) => label.pos.x)
476 | .attr('y', (label: Label) => label.pos.y)
477 | .attr('text-anchor', 'middle')
478 | .attr('alignment-baseline', 'middle')
479 | .text((label: Label) => label.text)
480 |
481 | svg
482 | .selectAll('token_selectors')
483 | .data(tokenSelectors)
484 | .enter()
485 | .append('polygon')
486 | .attr('class', ([i,]) => (
487 | curStartToken === i
488 | ? 'selectable-item selection'
489 | : 'selectable-item token-selector'
490 | ))
491 | .attr('points', ([, p]) => tokenPointerPolygon(p))
492 | .attr('r', renderParams.tokenSelectorSize / 2)
493 | .on('click', (event: PointerEvent, [i,]) => {
494 | handleTokenClick(i)
495 | })
496 | }, [
497 | cells,
498 | edges,
499 | nodeProps,
500 | tokenLabels,
501 | layerLabels,
502 | tokenSelectors,
503 | curStartToken,
504 | curSelection,
505 | colorScale,
506 | ffnEdgeColorScale,
507 | edgeWidthScale,
508 | nLayers,
509 | nTokens,
510 | xScale,
511 | yScale
512 | ])
513 |
514 |   return <svg ref={svgRef} width={totalW} height={totalH} />
515 | }
516 |
517 | export default withStreamlitConnection(ContributionGraph)
518 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 | Section 1 -- Definitions.
71 |
72 | a. Adapted Material means material subject to Copyright and Similar
73 | Rights that is derived from or based upon the Licensed Material
74 | and in which the Licensed Material is translated, altered,
75 | arranged, transformed, or otherwise modified in a manner requiring
76 | permission under the Copyright and Similar Rights held by the
77 | Licensor. For purposes of this Public License, where the Licensed
78 | Material is a musical work, performance, or sound recording,
79 | Adapted Material is always produced where the Licensed Material is
80 | synched in timed relation with a moving image.
81 |
82 | b. Adapter's License means the license You apply to Your Copyright
83 | and Similar Rights in Your contributions to Adapted Material in
84 | accordance with the terms and conditions of this Public License.
85 |
86 | c. Copyright and Similar Rights means copyright and/or similar rights
87 | closely related to copyright including, without limitation,
88 | performance, broadcast, sound recording, and Sui Generis Database
89 | Rights, without regard to how the rights are labeled or
90 | categorized. For purposes of this Public License, the rights
91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
92 | Rights.
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. NonCommercial means not primarily intended for or directed towards
116 | commercial advantage or monetary compensation. For purposes of
117 | this Public License, the exchange of the Licensed Material for
118 | other material subject to Copyright and Similar Rights by digital
119 | file-sharing or similar means is NonCommercial provided there is
120 | no payment of monetary compensation in connection with the
121 | exchange.
122 |
123 | j. Share means to provide material to the public by any means or
124 | process that requires permission under the Licensed Rights, such
125 | as reproduction, public display, public performance, distribution,
126 | dissemination, communication, or importation, and to make material
127 | available to the public including in ways that members of the
128 | public may access the material from a place and at a time
129 | individually chosen by them.
130 |
131 | k. Sui Generis Database Rights means rights other than copyright
132 | resulting from Directive 96/9/EC of the European Parliament and of
133 | the Council of 11 March 1996 on the legal protection of databases,
134 | as amended and/or succeeded, as well as other essentially
135 | equivalent rights anywhere in the world.
136 |
137 | l. You means the individual or entity exercising the Licensed Rights
138 | under this Public License. Your has a corresponding meaning.
139 |
140 | Section 2 -- Scope.
141 |
142 | a. License grant.
143 |
144 | 1. Subject to the terms and conditions of this Public License,
145 | the Licensor hereby grants You a worldwide, royalty-free,
146 | non-sublicensable, non-exclusive, irrevocable license to
147 | exercise the Licensed Rights in the Licensed Material to:
148 |
149 | a. reproduce and Share the Licensed Material, in whole or
150 | in part, for NonCommercial purposes only; and
151 |
152 | b. produce, reproduce, and Share Adapted Material for
153 | NonCommercial purposes only.
154 |
155 | 2. Exceptions and Limitations. For the avoidance of doubt, where
156 | Exceptions and Limitations apply to Your use, this Public
157 | License does not apply, and You do not need to comply with
158 | its terms and conditions.
159 |
160 | 3. Term. The term of this Public License is specified in Section
161 | 6(a).
162 |
163 | 4. Media and formats; technical modifications allowed. The
164 | Licensor authorizes You to exercise the Licensed Rights in
165 | all media and formats whether now known or hereafter created,
166 | and to make technical modifications necessary to do so. The
167 | Licensor waives and/or agrees not to assert any right or
168 | authority to forbid You from making technical modifications
169 | necessary to exercise the Licensed Rights, including
170 | technical modifications necessary to circumvent Effective
171 | Technological Measures. For purposes of this Public License,
172 | simply making modifications authorized by this Section 2(a)
173 | (4) never produces Adapted Material.
174 |
175 | 5. Downstream recipients.
176 |
177 | a. Offer from the Licensor -- Licensed Material. Every
178 | recipient of the Licensed Material automatically
179 | receives an offer from the Licensor to exercise the
180 | Licensed Rights under the terms and conditions of this
181 | Public License.
182 |
183 | b. No downstream restrictions. You may not offer or impose
184 | any additional or different terms or conditions on, or
185 | apply any Effective Technological Measures to, the
186 | Licensed Material if doing so restricts exercise of the
187 | Licensed Rights by any recipient of the Licensed
188 | Material.
189 |
190 | 6. No endorsement. Nothing in this Public License constitutes or
191 | may be construed as permission to assert or imply that You
192 | are, or that Your use of the Licensed Material is, connected
193 | with, or sponsored, endorsed, or granted official status by,
194 | the Licensor or others designated to receive attribution as
195 | provided in Section 3(a)(1)(A)(i).
196 |
197 | b. Other rights.
198 |
199 | 1. Moral rights, such as the right of integrity, are not
200 | licensed under this Public License, nor are publicity,
201 | privacy, and/or other similar personality rights; however, to
202 | the extent possible, the Licensor waives and/or agrees not to
203 | assert any such rights held by the Licensor to the limited
204 | extent necessary to allow You to exercise the Licensed
205 | Rights, but not otherwise.
206 |
207 | 2. Patent and trademark rights are not licensed under this
208 | Public License.
209 |
210 | 3. To the extent possible, the Licensor waives any right to
211 | collect royalties from You for the exercise of the Licensed
212 | Rights, whether directly or through a collecting society
213 | under any voluntary or waivable statutory or compulsory
214 | licensing scheme. In all other cases the Licensor expressly
215 | reserves any right to collect such royalties, including when
216 | the Licensed Material is used other than for NonCommercial
217 | purposes.
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material (including in modified
227 | form), You must:
228 |
229 | a. retain the following if it is supplied by the Licensor
230 | with the Licensed Material:
231 |
232 | i. identification of the creator(s) of the Licensed
233 | Material and any others designated to receive
234 | attribution, in any reasonable manner requested by
235 | the Licensor (including by pseudonym if
236 | designated);
237 |
238 | ii. a copyright notice;
239 |
240 | iii. a notice that refers to this Public License;
241 |
242 | iv. a notice that refers to the disclaimer of
243 | warranties;
244 |
245 | v. a URI or hyperlink to the Licensed Material to the
246 | extent reasonably practicable;
247 |
248 | b. indicate if You modified the Licensed Material and
249 | retain an indication of any previous modifications; and
250 |
251 | c. indicate the Licensed Material is licensed under this
252 | Public License, and include the text of, or the URI or
253 | hyperlink to, this Public License.
254 |
255 | 2. You may satisfy the conditions in Section 3(a)(1) in any
256 | reasonable manner based on the medium, means, and context in
257 | which You Share the Licensed Material. For example, it may be
258 | reasonable to satisfy the conditions by providing a URI or
259 | hyperlink to a resource that includes the required
260 | information.
261 |
262 | 3. If requested by the Licensor, You must remove any of the
263 | information required by Section 3(a)(1)(A) to the extent
264 | reasonably practicable.
265 |
266 | 4. If You Share Adapted Material You produce, the Adapter's
267 | License You apply must not prevent recipients of the Adapted
268 | Material from complying with this Public License.
269 |
270 | Section 4 -- Sui Generis Database Rights.
271 |
272 | Where the Licensed Rights include Sui Generis Database Rights that
273 | apply to Your use of the Licensed Material:
274 |
275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 | to extract, reuse, reproduce, and Share all or a substantial
277 | portion of the contents of the database for NonCommercial purposes
278 | only;
279 |
280 | b. if You include all or a substantial portion of the database
281 | contents in a database in which You have Sui Generis Database
282 | Rights, then the database in which You have Sui Generis Database
283 | Rights (but not its individual contents) is Adapted Material; and
284 |
285 | c. You must comply with the conditions in Section 3(a) if You Share
286 | all or a substantial portion of the contents of the database.
287 |
288 | For the avoidance of doubt, this Section 4 supplements and does not
289 | replace Your obligations under this Public License where the Licensed
290 | Rights include other Copyright and Similar Rights.
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 | Section 6 -- Term and Termination.
321 |
322 | a. This Public License applies for the term of the Copyright and
323 | Similar Rights licensed here. However, if You fail to comply with
324 | this Public License, then Your rights under this Public License
325 | terminate automatically.
326 |
327 | b. Where Your right to use the Licensed Material has terminated under
328 | Section 6(a), it reinstates:
329 |
330 | 1. automatically as of the date the violation is cured, provided
331 | it is cured within 30 days of Your discovery of the
332 | violation; or
333 |
334 | 2. upon express reinstatement by the Licensor.
335 |
336 | For the avoidance of doubt, this Section 6(b) does not affect any
337 | right the Licensor may have to seek remedies for Your violations
338 | of this Public License.
339 |
340 | c. For the avoidance of doubt, the Licensor may also offer the
341 | Licensed Material under separate terms or conditions or stop
342 | distributing the Licensed Material at any time; however, doing so
343 | will not terminate this Public License.
344 |
345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346 | License.
347 |
348 | Section 7 -- Other Terms and Conditions.
349 |
350 | a. The Licensor shall not be bound by any additional or different
351 | terms or conditions communicated by You unless expressly agreed.
352 |
353 | b. Any arrangements, understandings, or agreements regarding the
354 | Licensed Material not stated herein are separate from and
355 | independent of the terms and conditions of this Public License.
356 |
357 | Section 8 -- Interpretation.
358 |
359 | a. For the avoidance of doubt, this Public License does not, and
360 | shall not be interpreted to, reduce, limit, restrict, or impose
361 | conditions on any use of the Licensed Material that could lawfully
362 | be made without permission under this Public License.
363 |
364 | b. To the extent possible, if any provision of this Public License is
365 | deemed unenforceable, it shall be automatically reformed to the
366 | minimum extent necessary to make it enforceable. If the provision
367 | cannot be reformed, it shall be severed from this Public License
368 | without affecting the enforceability of the remaining terms and
369 | conditions.
370 |
371 | c. No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/llm_transparency_tool/server/app.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | from dataclasses import dataclass, field
9 | from typing import Dict, List, Optional, Tuple
10 |
11 | import networkx as nx
12 | import pandas as pd
13 | import plotly.express
14 | import plotly.graph_objects as go
15 | import streamlit as st
16 | import streamlit_extras.row as st_row
17 | import torch
18 | from jaxtyping import Float
19 | from torch.amp import autocast
20 | from transformers import HfArgumentParser
21 |
22 | import llm_transparency_tool.components
23 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
24 | import llm_transparency_tool.routes.contributions as contributions
25 | import llm_transparency_tool.routes.graph
26 | from llm_transparency_tool.models.transparent_llm import TransparentLlm
27 | from llm_transparency_tool.routes.graph_node import NodeType
28 | from llm_transparency_tool.server.graph_selection import (
29 | GraphSelection,
30 | UiGraphEdge,
31 | UiGraphNode,
32 | )
33 | from llm_transparency_tool.server.styles import (
34 | RenderSettings,
35 | logits_color_map,
36 | margins_css,
37 | string_to_display,
38 | )
39 | from llm_transparency_tool.server.utils import (
40 | B0,
41 | get_contribution_graph,
42 | load_dataset,
43 | load_model,
44 | possible_devices,
45 | run_model_with_session_caching,
46 | st_placeholder,
47 | )
48 | from llm_transparency_tool.server.monitor import SystemMonitor
49 |
50 | from networkx.classes.digraph import DiGraph
51 |
52 |
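# Streamlit caches this helper across reruns; the graph argument is hashed by
# object identity (id) because nx graphs are large and not hashable by value.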
53 | @st.cache_resource(
54 | hash_funcs={
55 | nx.Graph: id,
56 | DiGraph: id
57 | }
58 | )
59 | def cached_build_paths_to_predictions(
60 | graph: nx.Graph,
61 | n_layers: int,
62 | n_tokens: int,
63 | starting_tokens: List[int],
64 | threshold: float,
65 | ):
66 | return llm_transparency_tool.routes.graph.build_paths_to_predictions(
67 | graph, n_layers, n_tokens, starting_tokens, threshold
68 | )
69 |
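# Same identity-based caching for the model: the stateless model is copied and
# then run, so the cached copy holds the activations for the given sentences
# while the original object stays reusable for other inputs.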
70 | @st.cache_resource(
71 | hash_funcs={
72 | TransformerLensTransparentLlm: id
73 | }
74 | )
75 | def cached_run_inference_and_populate_state(
76 | stateless_model,
77 | sentences,
78 | ):
79 | stateful_model = stateless_model.copy()
80 | stateful_model.run(sentences)
81 | return stateful_model
82 |
83 |
84 | @dataclass
85 | class LlmViewerConfig:
86 | debug: bool = field(
87 | default=False,
88 | metadata={"help": "Show debugging information, like the time profile."},
89 | )
90 |
91 | preloaded_dataset_filename: Optional[str] = field(
92 | default=None,
93 | metadata={"help": "The name of the text file to load the lines from."},
94 | )
95 |
96 | demo_mode: bool = field(
97 | default=False,
98 | metadata={"help": "Whether the app should be in the demo mode."},
99 | )
100 |
101 | allow_loading_dataset_files: bool = field(
102 | default=True,
103 | metadata={"help": "Whether the app should be able to load the dataset files " "on the server side."},
104 | )
105 |
106 | max_user_string_length: Optional[int] = field(
107 | default=None,
108 | metadata={
109 | "help": "Limit for the length of user-provided sentences (in characters), " "or None if there is no limit."
110 | },
111 | )
112 |
113 | models: Dict[str, str] = field(
114 | default_factory=dict,
115 | metadata={
116 | "help": "Locations of models which are stored locally. Dictionary: official "
117 |         "HuggingFace name -> path to dir. If None is specified, the model will be "
118 | "downloaded from HuggingFace."
119 | },
120 | )
121 |
122 | default_model: str = field(
123 | default="",
124 | metadata={"help": "The model to load once the UI is started."},
125 | )
126 |
127 |
128 | class App:
129 | _stateful_model: TransparentLlm = None
130 | render_settings = RenderSettings()
131 | _graph: Optional[nx.Graph] = None
132 | _contribution_threshold: float = 0.0
133 | _renormalize_after_threshold: bool = False
134 | _normalize_before_unembedding: bool = True
135 |
136 | @property
137 | def stateful_model(self) -> TransparentLlm:
138 | return self._stateful_model
139 |
140 | def __init__(self, config: LlmViewerConfig):
141 | self._config = config
142 | st.set_page_config(layout="wide")
143 | st.markdown(margins_css, unsafe_allow_html=True)
144 |
145 | def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
146 | if node is None:
147 | return None
148 | fn = {
149 | NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
150 | NodeType.AFTER_FFN: self.stateful_model.residual_out,
151 |             NodeType.FFN: None,  # FFN nodes are not on the residual stream, so they have no representation here
152 | NodeType.ORIGINAL: self.stateful_model.residual_in,
153 | }
154 | return fn[node.type](node.layer)[B0][node.token]
155 |
156 | def draw_model_info(self):
157 | info = self.stateful_model.model_info().__dict__
158 | df = pd.DataFrame(
159 | data=[str(x) for x in info.values()],
160 | index=info.keys(),
161 | columns=["Model parameter"],
162 | )
163 | st.dataframe(df, use_container_width=False)
164 |
165 | def draw_dataset_selection(self) -> int:
166 | def update_dataset(filename: Optional[str]):
167 | dataset = load_dataset(filename) if filename is not None else []
168 | st.session_state["dataset"] = dataset
169 | st.session_state["dataset_file"] = filename
170 |
171 | if "dataset" not in st.session_state:
172 | update_dataset(self._config.preloaded_dataset_filename)
173 |
174 |
175 | if not self._config.demo_mode:
176 | with st.sidebar.expander("Dataset", expanded=False):
177 | if self._config.allow_loading_dataset_files:
178 | row_f = st_row.row([2, 1], vertical_align="bottom")
179 | filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "", label_visibility="collapsed")
180 | if row_f.button("Load"):
181 | update_dataset(filename)
182 | row_s = st_row.row([2, 1], vertical_align="bottom")
183 | new_sentence = row_s.text_area("New sentence", label_visibility="collapsed")
184 | new_sentence_added = False
185 |
186 | if row_s.button("Add"):
187 | max_len = self._config.max_user_string_length
188 | n = len(new_sentence)
189 | if max_len is None or n <= max_len:
190 | st.session_state.dataset.append(new_sentence)
191 | new_sentence_added = True
192 | st.session_state.sentence_selector = new_sentence
193 | else:
194 |                         st.warning(f"Sentence length {n} is larger than the configured limit of {max_len}")
195 |
196 | sentences = st.session_state.dataset
197 | selection = st.selectbox(
198 | "Sentence",
199 | sentences,
200 | index=len(sentences) - 1,
201 | key="sentence_selector",
202 | )
203 | return selection
204 |
205 | def _unembed(
206 | self,
207 | representation: torch.Tensor,
208 | ) -> torch.Tensor:
209 | return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)
210 |
211 | def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
212 | tokens = self.stateful_model.tokens()[B0]
213 | n_tokens = tokens.shape[0]
214 | model_info = self.stateful_model.model_info()
215 |
216 | graphs = cached_build_paths_to_predictions(
217 | self._graph,
218 | model_info.n_layers,
219 | n_tokens,
220 | range(n_tokens),
221 | contribution_threshold,
222 | )
223 |
224 | return llm_transparency_tool.components.contribution_graph(
225 | model_info,
226 | self.stateful_model.tokens_to_strings(tokens),
227 | graphs,
228 | key=f"graph_{hash(self.sentence)}",
229 | )
230 |
231 | def draw_token_matrix(
232 | self,
233 | values: Float[torch.Tensor, "t t"],
234 | tokens: List[str],
235 | value_name: str,
236 | title: str,
237 | ):
238 | assert values.shape[0] == len(tokens)
239 | labels = {
240 | "x": "src",
241 | "y": "tgt",
242 | "color": value_name,
243 | }
244 |
245 | captions = [f"({i}){t}" for i, t in enumerate(tokens)]
246 |
247 | fig = plotly.express.imshow(
248 | values.cpu(),
249 | title=f'{title}',
250 | labels=labels,
251 | x=captions,
252 | y=captions,
253 | color_continuous_scale=self.render_settings.attention_color_map,
254 | aspect="equal",
255 | )
256 | fig.update_layout(
257 | autosize=True,
258 | margin=go.layout.Margin(
259 | l=50, # left margin
260 | r=0, # right margin
261 | b=100, # bottom margin
262 | t=100, # top margin
263 | # pad=10 # padding
264 | )
265 | )
266 | fig.update_xaxes(tickmode="linear")
267 | fig.update_yaxes(tickmode="linear")
268 | fig.update_coloraxes(showscale=False)
269 |
270 | st.plotly_chart(fig, use_container_width=True, theme=None)
271 |
272 | def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
273 | """
274 | Returns: the index of the selected head.
275 | """
276 |
277 | n_heads = self.stateful_model.model_info().n_heads
278 |
279 | layer = edge.target.layer
280 |
281 | head_contrib, _ = contributions.get_attention_contributions(
282 | resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
283 | resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
284 | decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
285 | )
286 |
287 | # [batch pos key_pos head] -> [head]
288 | flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
289 | assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"
290 |
291 | selected_head = llm_transparency_tool.components.selector(
292 | items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
293 | indices=range(-1, n_heads),
294 | temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
295 | preselected_index=flat_contrib.argmax().item(),
296 |             key=f"head_selector_layer_{layer}",  # _from_tok_{edge.source.token}_to_tok_{edge.target.token}
297 | )
298 | print(f"head_selector_layer_{layer}_from_tok_{edge.source.token}_to_tok_{edge.target.token}")
299 | if selected_head == -1 or selected_head is None:
300 | # selected_head = None
301 | selected_head = flat_contrib.argmax().item()
302 | print('****\n' * 3 + f"selected_head: {selected_head}" + '\n****\n' * 3)
303 |
304 | # Draw attention matrix and contributions for the selected head.
305 | if selected_head is not None:
306 | tokens = [
307 | string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
308 | ]
309 |
310 | with container_attention_map:
311 | attn_container, contrib_container = st.columns([1, 1])
312 | with attn_container:
313 | attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
314 | self.draw_token_matrix(
315 | attn,
316 | tokens,
317 | "attention",
318 | f"Attention map L{layer} H{selected_head}",
319 | )
320 | with contrib_container:
321 | contrib = head_contrib[B0, :, :, selected_head]
322 | self.draw_token_matrix(
323 | contrib,
324 | tokens,
325 | "contribution",
326 | f"Contribution map L{layer} H{selected_head}",
327 | )
328 |
329 | return selected_head
330 |
331 | def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
332 | """
333 | Returns: the index of the selected neuron.
334 | """
335 |
336 | resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
337 | resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
338 | decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
339 | c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)
340 |
341 | top_values, top_i = c_ffn.sort(descending=True)
342 | n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
343 | top_neurons = top_i[0:n].tolist()
344 |
345 | selected_neuron = llm_transparency_tool.components.selector(
346 | items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
347 | indices=range(-1, n),
348 | temperatures=[0.0] + top_values[0:n].tolist(),
349 | preselected_index=-1,
350 | key="neuron_selector",
351 | )
352 | if selected_neuron is None:
353 | selected_neuron = -1
354 | selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]
355 |
356 | return selected_neuron
357 |
358 | def _draw_token_table(
359 | self,
360 | n_top: int,
361 | n_bottom: int,
362 | representation: torch.Tensor,
363 | predecessor: Optional[torch.Tensor] = None,
364 | ):
365 | n_total = n_top + n_bottom
366 |
367 | logits = self._unembed(representation)
368 | n_vocab = logits.shape[0]
369 | scores, indices = torch.topk(logits, n_top, largest=True)
370 | positions = list(range(n_top))
371 |
372 | if n_bottom > 0:
373 | low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
374 | indices = torch.cat((indices, low_indices.flip(0)))
375 | scores = torch.cat((scores, low_scores.flip(0)))
376 | positions += range(n_vocab - n_bottom, n_vocab)
377 |
378 | tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]
379 |
380 | if predecessor is not None:
381 | pre_logits = self._unembed(predecessor)
382 | _, sorted_pre_indices = pre_logits.sort(descending=True)
383 | pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
384 | old_positions = [pre_indices_dict[i] for i in indices.tolist()]
385 |
386 | def pos_gain_string(pos, old_pos):
387 | if pos == old_pos:
388 | return ""
389 | sign = "↓" if pos > old_pos else "↑"
390 | return f"({sign}{abs(pos - old_pos)})"
391 |
392 | position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
393 | else:
394 | position_strings = [str(pos) for pos in positions]
395 |
396 | def pos_gain_color(s):
397 | color = "black"
398 | if isinstance(s, str):
399 | if "↓" in s:
400 | color = "red"
401 | if "↑" in s:
402 | color = "green"
403 | return f"color: {color}"
404 |
405 | top_df = pd.DataFrame(
406 | data=zip(position_strings, tokens, scores.tolist()),
407 | columns=["Pos", "Token", "Score"],
408 | )
409 |
410 | st.dataframe(
411 | top_df.style.applymap(pos_gain_color)
412 | .background_gradient(
413 | axis=0,
414 | cmap=logits_color_map(positive_and_negative=n_bottom > 0),
415 | )
416 | .format(precision=3),
417 | hide_index=True,
418 | height=self.render_settings.table_cell_height * (n_total + 1),
419 | use_container_width=True,
420 | )
421 |
422 | def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
423 | st.caption(block_name)
424 | self._draw_token_table(
425 | self.render_settings.n_promoted_tokens,
426 | self.render_settings.n_suppressed_tokens,
427 | representation,
428 | None,
429 | )
430 |
431 | def draw_top_tokens(
432 | self,
433 | node: UiGraphNode,
434 | container_top_tokens,
435 | container_token_dynamics,
436 | ) -> None:
437 | pre_node = node.get_residual_predecessor()
438 | if pre_node is None:
439 | return
440 |
441 | representation = self._get_representation(node)
442 | predecessor = self._get_representation(pre_node)
443 |
444 | with container_top_tokens:
445 | st.caption(node.get_name())
446 | self._draw_token_table(
447 | self.render_settings.n_top_tokens,
448 | 0,
449 | representation,
450 | predecessor,
451 | )
452 | if container_token_dynamics is not None:
453 | with container_token_dynamics:
454 | self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())
455 |
456 | def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
457 | block_name = node.get_head_name(head)
458 | block_output = (
459 | self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
460 | if head is not None
461 | else self.stateful_model.attention_output(B0, node.layer, node.token)
462 | )
463 | self.draw_token_dynamics(block_output, block_name)
464 |
465 | def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
466 | block_name = node.get_neuron_name(neuron)
467 | block_output = (
468 | self.stateful_model.neuron_output(node.layer, neuron)
469 | if neuron is not None
470 | else self.stateful_model.ffn_out(node.layer)[B0][node.token]
471 | )
472 | self.draw_token_dynamics(block_output, block_name)
473 |
474 | def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
475 | """
476 | Draw fp16/fp32 switch and AMP control.
477 |
478 | return: The selected precision and whether AMP should be enabled.
479 | """
480 |
481 | if device == "cpu":
482 | dtype = torch.float32
483 | else:
484 | dtype = st.selectbox(
485 | "Precision",
486 | [torch.float16, torch.bfloat16, torch.float32],
487 | index=0,
488 | )
489 |
490 | amp_enabled = dtype != torch.float32
491 |
492 | return dtype, amp_enabled
493 |
494 | def draw_controls(self):
495 | # model_container, data_container = st.columns([1, 1])
496 | with st.sidebar.expander("Model", expanded=True):
497 | list_of_devices = possible_devices()
498 | if len(list_of_devices) > 1:
499 | self.device = st.selectbox(
500 | "Device",
501 | possible_devices(),
502 | index=0,
503 | )
504 | else:
505 | self.device = list_of_devices[0]
506 |
507 | self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)
508 |
509 | model_list = list(self._config.models)
510 | default_choice = model_list.index(self._config.default_model)
511 |
512 | self.supported_model_name = st.selectbox(
513 | "Model name",
514 | model_list,
515 | index=default_choice,
516 | )
517 | self.model_name = st.text_input("Custom model name", value=self.supported_model_name)
518 |
519 | if self.model_name:
520 | self._stateful_model = load_model(
521 | model_name=self.model_name,
522 |                     _model_path=self._config.models.get(self.model_name),
523 | _device=self.device,
524 | _dtype=self.dtype,
525 | supported_model_name=None if not self.supported_model_name else self.supported_model_name,
526 | )
527 | self.model_key = self.model_name # TODO maybe something else?
528 | self.draw_model_info()
529 |
530 | self.sentence = self.draw_dataset_selection()
531 |
532 | with st.sidebar.expander("Graph", expanded=True):
533 | self._contribution_threshold = st.slider(
534 | min_value=0.01,
535 | max_value=0.1,
536 | step=0.01,
537 | value=0.04,
538 | format=r"%.3f",
539 | label="Contribution threshold",
540 | )
541 | self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
542 | self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)
543 |
544 | def run_inference(self):
545 |
546 | with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
547 | self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])
548 |
549 | with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
550 | self._graph = get_contribution_graph(
551 | self.stateful_model,
552 | self.model_key,
553 | self.stateful_model.tokens()[B0].tolist(),
554 | (self._contribution_threshold if self._renormalize_after_threshold else 0.0),
555 | )
556 |
557 | def draw_graph_and_selection(
558 | self,
559 | ) -> None:
560 | (
561 | container_graph,
562 | container_tokens,
563 | ) = st.columns(self.render_settings.column_proportions)
564 |
565 | container_graph_left, container_graph_right = container_graph.columns([5, 1])
566 |
567 | container_graph_left.write('##### Graph')
568 | heads_placeholder = container_graph_right.empty()
569 | heads_placeholder.write('##### Blocks')
570 | container_graph_right_used = False
571 |
572 | container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
573 | container_top_tokens.write('##### Top Tokens')
574 | container_top_tokens_used = False
575 | container_token_dynamics.write('##### Promoted Tokens')
576 | container_token_dynamics_used = False
577 |
578 | try:
579 |
580 | if self.sentence is None:
581 | return
582 |
583 | with container_graph_left:
584 | selection = self.draw_graph(self._contribution_threshold if not self._renormalize_after_threshold else 0.0)
585 |
586 | if selection is None:
587 | return
588 |
589 | node = selection.node
590 | edge = selection.edge
591 |
592 | if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
593 | with container_graph_right:
594 | container_graph_right_used = True
595 | heads_placeholder.write('##### Heads')
596 | head = self.draw_attn_info(edge, container_graph)
597 | with container_token_dynamics:
598 | self.draw_attention_dynamics(edge.target, head)
599 | container_token_dynamics_used = True
600 | elif node is not None and node.type == NodeType.FFN:
601 | with container_graph_right:
602 | container_graph_right_used = True
603 | heads_placeholder.write('##### Neurons')
604 | neuron = self.draw_ffn_info(node)
605 | with container_token_dynamics:
606 | self.draw_ffn_dynamics(node, neuron)
607 | container_token_dynamics_used = True
608 |
609 | if node is not None and node.is_in_residual_stream():
610 | self.draw_top_tokens(
611 | node,
612 | container_top_tokens,
613 | container_token_dynamics if not container_token_dynamics_used else None,
614 | )
615 | container_top_tokens_used = True
616 | container_token_dynamics_used = True
617 | finally:
618 | if not container_graph_right_used:
619 | st_placeholder('Click on an edge to see head contributions. \n\n'
620 |                     'Or click on an FFN node to see individual neuron contributions.', container_graph_right, height=1100)
621 | if not container_top_tokens_used:
622 | st_placeholder('Select a node from residual stream to see its top tokens.', container_top_tokens, height=1100)
623 | if not container_token_dynamics_used:
624 | st_placeholder('Select a node to see its promoted tokens.', container_token_dynamics, height=1100)
625 |
626 |
627 | def run(self):
628 |
629 | if self._config.demo_mode:
630 | with st.sidebar.expander("About", expanded=True):
631 | st.caption("""
632 |                 The app is deployed in Demo Mode, so only predefined models and inputs are available.\n
633 | You can still install the app locally and use your own models and inputs.\n
634 | See https://github.com/facebookresearch/llm-transparency-tool for more information.
635 | """)
636 |
637 | self.draw_controls()
638 |
639 | if not self.model_name:
640 | st.warning("No model selected")
641 | st.stop()
642 |
643 | if self.sentence is None:
644 | st.warning("No sentence selected")
645 | else:
646 | with torch.inference_mode():
647 | self.run_inference()
648 |
649 | self.draw_graph_and_selection()
650 |
651 |
652 | if __name__ == "__main__":
653 | top_parser = argparse.ArgumentParser()
654 | top_parser.add_argument("config_file")
655 | args = top_parser.parse_args()
656 |
657 | parser = HfArgumentParser([LlmViewerConfig])
658 | config = parser.parse_json_file(args.config_file)[0]
659 |
660 | with SystemMonitor(config.debug) as prof:
661 | app = App(config)
662 | app.run()
663 |
--------------------------------------------------------------------------------