├── .flake8 ├── pyproject.toml ├── .dockerignore ├── llm_transparency_tool ├── components │ ├── frontend │ │ ├── src │ │ │ ├── react-app-env.d.ts │ │ │ ├── common.tsx │ │ │ ├── index.tsx │ │ │ ├── LlmViewer.css │ │ │ ├── Selector.tsx │ │ │ └── ContributionGraph.tsx │ │ ├── .prettierrc │ │ ├── .env │ │ ├── tsconfig.json │ │ ├── public │ │ │ └── index.html │ │ └── package.json │ └── __init__.py ├── __init__.py ├── models │ ├── __init__.py │ ├── test_tlens_model.py │ ├── transparent_llm.py │ └── tlens_model.py ├── routes │ ├── __init__.py │ ├── graph_node.py │ ├── test_contributions.py │ ├── graph.py │ └── contributions.py └── server │ ├── graph_selection.py │ ├── monitor.py │ ├── styles.py │ ├── utils.py │ └── app.py ├── .gitignore ├── sample_input.txt ├── config ├── docker_hosting.json ├── docker_local.json └── local.json ├── setup.py ├── env.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── CODE_OF_CONDUCT.md ├── README.md └── LICENSE /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | **/.git 2 | **/node_modules 3 | **/.mypy_cache 4 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "endOfLine": "lf", 3 | "semi": false, 4 | "trailingComma": "es5" 5 | } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/frontend/node_modules* 2 | **/frontend/build/ 3 | **/frontend/.yarn* 4 | .vscode/ 5 | .mypy_cache/ 6 | __pycache__/ 7 | .DS_Store 8 | -------------------------------------------------------------------------------- /sample_input.txt: -------------------------------------------------------------------------------- 1 | The war lasted from the year 1732 to the year 17 2 | 5 + 4 = 9, 2 + 3 = 3 | When Mary and John went to the store, John gave a drink to 4 | -------------------------------------------------------------------------------- /llm_transparency_tool/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/.env: -------------------------------------------------------------------------------- 1 | # Run the component's dev server on :3001 2 | # (The Streamlit dev server already runs on :3000) 3 | PORT=3001 4 | 5 | # Don't automatically open the web browser on `npm run start`. 
6 | BROWSER=none 7 | -------------------------------------------------------------------------------- /llm_transparency_tool/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /llm_transparency_tool/routes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /config/docker_hosting.json: -------------------------------------------------------------------------------- 1 | { 2 | "allow_loading_dataset_files": false, 3 | "max_user_string_length": 100, 4 | "preloaded_dataset_filename": "sample_input.txt", 5 | "debug": false, 6 | "demo_mode": true, 7 | "models": { 8 | "facebook/opt-125m": null, 9 | "gpt2": null, 10 | "distilgpt2": null 11 | }, 12 | "default_model": "gpt2" 13 | } 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | 9 | setup( 10 | name="llm_transparency_tool", 11 | version="0.1", 12 | packages=["llm_transparency_tool"], 13 | ) 14 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/common.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | export interface Point { 10 | x: number 11 | y: number 12 | } 13 | 14 | export interface Label { 15 | text: string 16 | pos: Point 17 | } 18 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": ["dom", "dom.iterable", "esnext"], 5 | "allowJs": true, 6 | "skipLibCheck": true, 7 | "esModuleInterop": true, 8 | "allowSyntheticDefaultImports": true, 9 | "strict": true, 10 | "forceConsistentCasingInFileNames": true, 11 | "module": "esnext", 12 | "moduleResolution": "node", 13 | "resolveJsonModule": true, 14 | "isolatedModules": true, 15 | "noEmit": true, 16 | "jsx": "react" 17 | }, 18 | "include": ["src"] 19 | } 20 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: llmtt 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - python=3.12 8 | - pytorch 9 | - pytorch-cuda=11.8 10 | - nodejs 11 | - yarn 12 | - pip 13 | - pip: 14 | - datasets 15 | - einops 16 | - fancy_einsum 17 | - jaxtyping==0.2.25 18 | - networkx 19 | - plotly 20 | - pyinstrument 21 | - setuptools 22 | - streamlit 23 | - streamlit_extras 24 | - tokenizers 25 | - transformer_lens 26 | - transformers 27 | - pytest # fixes wrong dependencies of transformer_lens 28 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Contribution Graph for Streamlit 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 | 15 | -------------------------------------------------------------------------------- /config/docker_local.json: -------------------------------------------------------------------------------- 1 | { 2 | "allow_loading_dataset_files": true, 3 | "preloaded_dataset_filename": "sample_input.txt", 4 | "debug": true, 5 | "models": { 6 | "": null, 7 | "facebook/opt-125m": null, 8 | "facebook/opt-1.3b": null, 9 | "facebook/opt-2.7b": null, 10 | "facebook/opt-6.7b": null, 11 | "facebook/opt-13b": null, 12 | "facebook/opt-30b": null, 13 | "meta-llama/Llama-2-7b-hf": null, 14 | "meta-llama/Llama-2-7b-chat-hf": null, 15 | "meta-llama/Llama-2-13b-hf": null, 16 | "meta-llama/Llama-2-13b-chat-hf": null, 17 | "gpt2": null, 18 | "gpt2-medium": null, 19 | "gpt2-large": null, 20 | "gpt2-xl": null, 21 | "distilgpt2": null 22 | }, 23 | "default_model": "distilgpt2", 24 | "demo_mode": false 25 | } 26 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "contribution_graph", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@types/d3": "^7.4.0", 7 | "d3": "^7.8.5", 8 | "react": "^18.2.0", 9 | "react-dom": "^18.2.0", 10 | "streamlit-component-lib": "^2.0.0" 11 | }, 12 | "scripts": { 13 | "start": "react-scripts start", 14 | "build": "react-scripts build", 15 | "test": "react-scripts test", 16 | "eject": "react-scripts eject" 17 | }, 18 | "browserslist": { 19 | "production": [ 20 | ">0.2%", 21 | "not dead", 22 | "not op_mini all" 23 | ], 24 | "development": [ 25 | "last 1 chrome version", 26 | "last 1 firefox version", 27 | "last 1 safari version" 28 | ] 29 | }, 30 | "homepage": ".", 31 | "devDependencies": { 32 | "@types/node": "^20.11.17", 33 | "@types/react": "^18.2.55", 34 | "@types/react-dom": "^18.2.19", 35 | "eslint-config-react-app": "^7.0.1", 36 | "react-scripts": "^5.0.1", 37 | "typescript": "^5.3.3" 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/index.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | import React from "react" 10 | import ReactDOM from "react-dom" 11 | 12 | import { 13 | ComponentProps, 14 | withStreamlitConnection, 15 | } from "streamlit-component-lib" 16 | 17 | 18 | import ContributionGraph from "./ContributionGraph" 19 | import Selector from "./Selector" 20 | 21 | const LlmViewerComponent = (props: ComponentProps) => { 22 | switch (props.args['component']) { 23 | case 'graph': 24 | return 25 | case 'selector': 26 | return 27 | default: 28 | return <> 29 | } 30 | }; 31 | 32 | const StreamlitLlmViewerComponent = withStreamlitConnection(LlmViewerComponent) 33 | 34 | ReactDOM.render( 35 | 36 | 37 | , 38 | document.getElementById("root") 39 | ) 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to llm-transparency-tool 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 
4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to llm-transparency-tool, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 8 | 9 | RUN apt-get update && apt-get install -y \ 10 | wget \ 11 | git \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | RUN useradd -m -u 1000 user 16 | USER user 17 | 18 | ENV HOME=/home/user 19 | 20 | RUN wget -P /tmp \ 21 | "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh" \ 22 | && bash /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh -b -p $HOME/mambaforge3 \ 23 | && rm /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh 24 | ENV PATH $HOME/mambaforge3/bin:$PATH 25 | 26 | WORKDIR $HOME 27 | 28 | ENV REPO=$HOME/llm-transparency-tool 29 | COPY --chown=user . $REPO 30 | 31 | WORKDIR $REPO 32 | 33 | RUN mamba env create --name llmtt -f env.yaml -y 34 | ENV PATH $HOME/mambaforge3/envs/llmtt/bin:$PATH 35 | RUN pip install -e . 36 | 37 | RUN cd llm_transparency_tool/components/frontend \ 38 | && yarn install \ 39 | && yarn build 40 | 41 | EXPOSE 7860 42 | CMD ["streamlit", "run", "llm_transparency_tool/server/app.py", "--server.port=7860", "--server.address=0.0.0.0", "--theme.font=Inconsolata", "--", "config/docker_hosting.json"] 43 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/LlmViewer.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | .graph-container { 10 | display: flex; 11 | justify-content: center; 12 | align-items: center; 13 | height: 100vh; 14 | } 15 | 16 | .svg { 17 | border: 1px solid #ccc; 18 | } 19 | 20 | .layer-highlight { 21 | fill: #f0f5f0; 22 | } 23 | 24 | .selectable-item { 25 | stroke: black; 26 | cursor: pointer; 27 | } 28 | 29 | .selection, 30 | .selection:hover { 31 | fill: orange; 32 | } 33 | 34 | .active-residual-node { 35 | fill: yellowgreen; 36 | } 37 | 38 | .active-residual-node:hover { 39 | fill: olivedrab; 40 | } 41 | 42 | .active-ffn-node { 43 | fill: orchid; 44 | } 45 | 46 | .active-ffn-node:hover { 47 | fill: purple; 48 | } 49 | 50 | .inactive-node { 51 | fill: lightgray; 52 | stroke-width: 0.5px; 53 | } 54 | 55 | .inactive-node:hover { 56 | fill: gray; 57 | } 58 | 59 | .selectable-edge { 60 | cursor: pointer; 61 | } 62 | 63 | .token-selector { 64 | fill: lightblue; 65 | } 66 | 67 | .token-selector:hover { 68 | fill: cornflowerblue; 69 | } 70 | 71 | .selector-item { 72 | fill: lightblue; 73 | } 74 | 75 | .selector-item:hover { 76 | fill: cornflowerblue; 77 | } -------------------------------------------------------------------------------- /config/local.json: -------------------------------------------------------------------------------- 1 | { 2 | "allow_loading_dataset_files": true, 3 | "preloaded_dataset_filename": "sample_input.txt", 4 | "debug": true, 5 | "models": { 6 | "": null, 7 | 8 | "meta-llama/Meta-Llama-3-8B": null, 9 | "meta-llama/Meta-Llama-3-70B": null, 10 | "meta-llama/Meta-Llama-3-8B-Instruct": null, 11 | "meta-llama/Meta-Llama-3-70B-Instruct": null, 12 | "mistral/mixtral-instruct": null, 13 | 14 | "gpt2": null, 15 | "distilgpt2": null, 16 | "facebook/opt-125m": null, 17 | "facebook/opt-1.3b": null, 18 | "EleutherAI/gpt-neo-125M": null, 19 | "Qwen/Qwen-1_8B": null, 20 | "Qwen/Qwen1.5-0.5B": null, 21 | "Qwen/Qwen1.5-0.5B-Chat": null, 22 | "Qwen/Qwen1.5-1.8B": null, 23 | "Qwen/Qwen1.5-1.8B-Chat": null, 24 | "microsoft/phi-1": null, 25 | "microsoft/phi-1_5": null, 26 | "microsoft/phi-2": null, 27 | 28 | "meta-llama/Llama-2-7b-hf": null, 29 | "meta-llama/Llama-2-7b-chat-hf": null, 30 | 31 | "meta-llama/Llama-2-13b-hf": null, 32 | "meta-llama/Llama-2-13b-chat-hf": null, 33 | 34 | "meta-llama/Llama-2-70b-hf": null, 35 | "meta-llama/Llama-2-70b-chat-hf": null, 36 | 37 | 38 | "gpt2-medium": null, 39 | "gpt2-large": null, 40 | "gpt2-xl": null, 41 | 42 | "mistralai/Mistral-7B-v0.1": null, 43 | "mistralai/Mistral-7B-Instruct-v0.1": null, 44 | "mistralai/Mistral-7B-Instruct-v0.2": null, 45 | 46 | "google/gemma-7b": null, 47 | "google/gemma-2b": null, 48 | 49 | "facebook/opt-2.7b": null, 50 | "facebook/opt-6.7b": null, 51 | "facebook/opt-13b": null, 52 | "facebook/opt-30b": null 53 | }, 54 | "default_model": "", 55 | "demo_mode": false 56 | } 57 | -------------------------------------------------------------------------------- /llm_transparency_tool/server/graph_selection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from dataclasses import dataclass 8 | from typing import Any, Dict, Optional 9 | 10 | from llm_transparency_tool.routes.graph_node import GraphNode, NodeType 11 | 12 | 13 | class UiGraphNode(GraphNode): 14 | @staticmethod 15 | def from_json(json: Dict[str, Any]) -> Optional["UiGraphNode"]: 16 | try: 17 | layer = json["cell"]["layer"] 18 | token = json["cell"]["token"] 19 | type = NodeType(json["item"]) 20 | return UiGraphNode(layer, token, type) 21 | except (TypeError, KeyError): 22 | return None 23 | 24 | 25 | @dataclass 26 | class UiGraphEdge: 27 | source: UiGraphNode 28 | target: UiGraphNode 29 | weight: float 30 | 31 | @staticmethod 32 | def from_json(json: Dict[str, Any]) -> Optional["UiGraphEdge"]: 33 | try: 34 | source = UiGraphNode.from_json(json["from"]) 35 | target = UiGraphNode.from_json(json["to"]) 36 | if source is None or target is None: 37 | return None 38 | weight = float(json["weight"]) 39 | return UiGraphEdge(source, target, weight) 40 | except (TypeError, KeyError): 41 | return None 42 | 43 | 44 | @dataclass 45 | class GraphSelection: 46 | node: Optional[UiGraphNode] 47 | edge: Optional[UiGraphEdge] 48 | 49 | @staticmethod 50 | def from_json(json: Dict[str, Any]) -> Optional["GraphSelection"]: 51 | try: 52 | node = UiGraphNode.from_json(json["node"]) 53 | edge = UiGraphEdge.from_json(json["edge"]) 54 | return GraphSelection(node, edge) 55 | except (TypeError, KeyError): 56 | return None 57 | -------------------------------------------------------------------------------- /llm_transparency_tool/routes/graph_node.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from typing import List, Optional 10 | 11 | 12 | class NodeType(Enum): 13 | AFTER_ATTN = "after_attn" 14 | AFTER_FFN = "after_ffn" 15 | FFN = "ffn" 16 | ORIGINAL = "original" # The original tokens 17 | 18 | 19 | def _format_block_hierachy_string(blocks: List[str]) -> str: 20 | return " ▸ ".join(blocks) 21 | 22 | 23 | @dataclass 24 | class GraphNode: 25 | layer: int 26 | token: int 27 | type: NodeType 28 | 29 | def is_in_residual_stream(self) -> bool: 30 | return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN] 31 | 32 | def get_residual_predecessor(self) -> Optional["GraphNode"]: 33 | """ 34 | Get another graph node which points to the state of the residual stream before 35 | this node. 36 | 37 | Retun None if current representation is the first one in the residual stream. 
38 | """ 39 | scheme = { 40 | NodeType.AFTER_ATTN: GraphNode( 41 | layer=max(self.layer - 1, 0), 42 | token=self.token, 43 | type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL, 44 | ), 45 | NodeType.AFTER_FFN: GraphNode( 46 | layer=self.layer, 47 | token=self.token, 48 | type=NodeType.AFTER_ATTN, 49 | ), 50 | NodeType.FFN: GraphNode( 51 | layer=self.layer, 52 | token=self.token, 53 | type=NodeType.AFTER_ATTN, 54 | ), 55 | NodeType.ORIGINAL: None, 56 | } 57 | node = scheme[self.type] 58 | if node.layer < 0: 59 | return None 60 | return node 61 | 62 | def get_name(self) -> str: 63 | return _format_block_hierachy_string( 64 | [f"L{self.layer}", f"T{self.token}", str(self.type.value)] 65 | ) 66 | 67 | def get_predecessor_block_name(self) -> str: 68 | """ 69 | Return the name of the block standing between current node and its predecessor 70 | in the residual stream. 71 | """ 72 | scheme = { 73 | NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"], 74 | NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"], 75 | NodeType.FFN: [f"L{self.layer}", "ffn"], 76 | NodeType.ORIGINAL: ["Nothing"], 77 | } 78 | return _format_block_hierachy_string(scheme[self.type]) 79 | 80 | def get_head_name(self, head: Optional[int]) -> str: 81 | path = [f"L{self.layer}", "attn"] 82 | if head is not None: 83 | path.append(f"H{head}") 84 | return _format_block_hierachy_string(path) 85 | 86 | def get_neuron_name(self, neuron: Optional[int]) -> str: 87 | path = [f"L{self.layer}", "ffn"] 88 | if neuron is not None: 89 | path.append(f"N{neuron}") 90 | return _format_block_hierachy_string(path) 91 | -------------------------------------------------------------------------------- /llm_transparency_tool/server/monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import streamlit as st 9 | from pyinstrument import Profiler 10 | from typing import Dict 11 | import pandas as pd 12 | 13 | 14 | @st.cache_resource(max_entries=1, show_spinner=False) 15 | def init_gpu_memory(): 16 | """ 17 | When CUDA is initialized, it occupies some memory on the GPU thus this overhead 18 | can sometimes make it difficult to understand how much memory is actually used by 19 | the model. 20 | 21 | This function is used to initialize CUDA and measure the overhead. 22 | """ 23 | if not torch.cuda.is_available(): 24 | return {} 25 | 26 | # lets init torch gpu for a moment 27 | gpu_memory_overhead = {} 28 | for i in range(torch.cuda.device_count()): 29 | torch.ones(1).cuda(i) 30 | free, total = torch.cuda.mem_get_info(i) 31 | occupied = total - free 32 | gpu_memory_overhead[i] = occupied 33 | 34 | return gpu_memory_overhead 35 | 36 | 37 | class SystemMonitor: 38 | """ 39 | This class is used to monitor the system resources such as GPU memory and CPU 40 | usage. It uses the pyinstrument library to profile the code and measure the 41 | execution time of different parts of the code. 
42 | """ 43 | 44 | def __init__( 45 | self, 46 | enabled: bool = False, 47 | ): 48 | self.enabled = enabled 49 | self.profiler = Profiler() 50 | self.overhead: Dict[int, int] 51 | 52 | def __enter__(self): 53 | if not self.enabled: 54 | return 55 | 56 | self.overhead = init_gpu_memory() 57 | 58 | self.profiler.__enter__() 59 | 60 | def __exit__(self, exc_type, exc_value, traceback): 61 | if not self.enabled: 62 | return 63 | 64 | self.profiler.__exit__(exc_type, exc_value, traceback) 65 | 66 | self.report_gpu_usage() 67 | self.report_profiler() 68 | 69 | with st.expander("Session state"): 70 | st.write(st.session_state) 71 | 72 | return None 73 | 74 | def report_gpu_usage(self): 75 | 76 | if not torch.cuda.is_available(): 77 | return 78 | 79 | data = [] 80 | 81 | for i in range(torch.cuda.device_count()): 82 | free, total = torch.cuda.mem_get_info(i) 83 | occupied = total - free 84 | data.append({ 85 | 'overhead': self.overhead[i], 86 | 'occupied': occupied - self.overhead[i], 87 | 'free': free, 88 | }) 89 | df = pd.DataFrame(data, columns=["overhead", "occupied", "free"]) 90 | 91 | with st.sidebar.expander("System"): 92 | st.write("GPU memory on server") 93 | df /= 1024 ** 3 # Convert to GB 94 | st.bar_chart(df, width=200, height=200, color=["#fefefe", "#84c9ff", "#fe2b2b"]) 95 | 96 | def report_profiler(self): 97 | html_code = self.profiler.output_html() 98 | with st.expander("Profiler", expanded=False): 99 | st.components.v1.html(html_code, height=1000, scrolling=True) 100 | -------------------------------------------------------------------------------- /llm_transparency_tool/server/styles.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | 9 | import matplotlib 10 | 11 | # Unofficial way do make the padding a bit smaller. 12 | margins_css = """ 13 | 24 | """ 25 | 26 | 27 | @dataclass 28 | class RenderSettings: 29 | column_proportions = [50, 30] 30 | 31 | # We don't know the actual height. This will be used in order to compute the table 32 | # viewport height when needed. 
33 | table_cell_height = 36 34 | 35 | n_top_tokens = 30 36 | n_promoted_tokens = 15 37 | n_suppressed_tokens = 15 38 | 39 | n_top_neurons = 20 40 | 41 | attention_color_map = "Blues" 42 | 43 | no_model_alt_text = "" 44 | 45 | 46 | def string_to_display(s: str) -> str: 47 | return s.replace(" ", "·") 48 | 49 | 50 | def logits_color_map(positive_and_negative: bool) -> matplotlib.colors.Colormap: 51 | background_colors = { 52 | "red": [ 53 | [0.0, 0.40, 0.40], 54 | [0.1, 0.69, 0.69], 55 | [0.2, 0.83, 0.83], 56 | [0.3, 0.95, 0.95], 57 | [0.4, 0.99, 0.99], 58 | [0.5, 1.0, 1.0], 59 | [0.6, 0.90, 0.90], 60 | [0.7, 0.72, 0.72], 61 | [0.8, 0.49, 0.49], 62 | [0.9, 0.30, 0.30], 63 | [1.0, 0.15, 0.15], 64 | ], 65 | "green": [ 66 | [0.0, 0.0, 0.0], 67 | [0.1, 0.09, 0.09], 68 | [0.2, 0.37, 0.37], 69 | [0.3, 0.64, 0.64], 70 | [0.4, 0.85, 0.85], 71 | [0.5, 1.0, 1.0], 72 | [0.6, 0.96, 0.96], 73 | [0.7, 0.88, 0.88], 74 | [0.8, 0.73, 0.73], 75 | [0.9, 0.57, 0.57], 76 | [1.0, 0.39, 0.39], 77 | ], 78 | "blue": [ 79 | [0.0, 0.12, 0.12], 80 | [0.1, 0.16, 0.16], 81 | [0.2, 0.30, 0.30], 82 | [0.3, 0.50, 0.50], 83 | [0.4, 0.78, 0.78], 84 | [0.5, 1.0, 1.0], 85 | [0.6, 0.81, 0.81], 86 | [0.7, 0.52, 0.52], 87 | [0.8, 0.25, 0.25], 88 | [0.9, 0.12, 0.12], 89 | [1.0, 0.09, 0.09], 90 | ], 91 | } 92 | 93 | if not positive_and_negative: 94 | # Stretch the top part to the whole range 95 | new_colors = {} 96 | for channel, colors in background_colors.items(): 97 | new_colors[channel] = [ 98 | [(value - 0.5) * 2, color, color] 99 | for value, color, _ in colors 100 | if value >= 0.5 101 | ] 102 | background_colors = new_colors 103 | 104 | return matplotlib.colors.LinearSegmentedColormap( 105 | f"RdYG-{positive_and_negative}", 106 | background_colors, 107 | ) 108 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from typing import List, Optional 9 | 10 | import networkx as nx 11 | import streamlit.components.v1 as components 12 | 13 | from llm_transparency_tool.models.transparent_llm import ModelInfo 14 | from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode 15 | 16 | _RELEASE = True 17 | 18 | if _RELEASE: 19 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 20 | config = { 21 | "path": os.path.join(parent_dir, "frontend/build"), 22 | } 23 | else: 24 | config = { 25 | "url": "http://localhost:3001", 26 | } 27 | 28 | _component_func = components.declare_component("contribution_graph", **config) 29 | 30 | 31 | def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int): 32 | return node.layer < n_layers and node.token < n_tokens 33 | 34 | 35 | def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int): 36 | if not s: 37 | return True 38 | if s.node: 39 | if not is_node_valid(s.node, n_layers, n_tokens): 40 | return False 41 | if s.edge: 42 | for node in [s.edge.source, s.edge.target]: 43 | if not is_node_valid(node, n_layers, n_tokens): 44 | return False 45 | return True 46 | 47 | 48 | def contribution_graph( 49 | model_info: ModelInfo, 50 | tokens: List[str], 51 | graphs: List[nx.Graph], 52 | key: str, 53 | ) -> Optional[GraphSelection]: 54 | """Create a new instance of contribution graph. 55 | 56 | Returns selected graph node or None if nothing was selected. 57 | """ 58 | assert len(tokens) == len(graphs) 59 | 60 | result = _component_func( 61 | component="graph", 62 | model_info=model_info.__dict__, 63 | tokens=tokens, 64 | edges_per_token=[nx.node_link_data(g)["links"] for g in graphs], 65 | default=None, 66 | key=key, 67 | ) 68 | 69 | selection = GraphSelection.from_json(result) 70 | 71 | n_tokens = len(tokens) 72 | n_layers = model_info.n_layers 73 | # We need this extra protection because even though the component has to check for 74 | # the validity of the selection, sometimes it allows invalid output. It's some 75 | # unexpected effect that has something to do with React and how the output value is 76 | # set for the component. 77 | if not is_selection_valid(selection, n_layers, n_tokens): 78 | selection = None 79 | 80 | return selection 81 | 82 | 83 | def selector( 84 | items: List[str], 85 | indices: List[int], 86 | temperatures: Optional[List[float]], 87 | preselected_index: Optional[int], 88 | key: str, 89 | ) -> Optional[int]: 90 | """Create a new instance of selector. 91 | 92 | Returns selected item index. 
93 | """ 94 | n = len(items) 95 | assert n == len(indices) 96 | items = [{"index": i, "text": s} for s, i in zip(items, indices)] 97 | 98 | if temperatures is not None: 99 | assert n == len(temperatures) 100 | for i, t in enumerate(temperatures): 101 | items[i]["temperature"] = t 102 | 103 | result = _component_func( 104 | component="selector", 105 | items=items, 106 | preselected_index=preselected_index, 107 | default=None, 108 | key=key, 109 | ) 110 | 111 | return None if result is None else int(result) 112 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /llm_transparency_tool/server/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import uuid 8 | from typing import List, Optional, Tuple 9 | 10 | import networkx as nx 11 | import streamlit as st 12 | import torch 13 | import transformers 14 | 15 | import llm_transparency_tool.routes.graph 16 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm 17 | from llm_transparency_tool.models.transparent_llm import TransparentLlm 18 | 19 | GPU = "gpu" 20 | CPU = "cpu" 21 | 22 | # This variable is for expressing the idea that batch_id = 0, but make it more 23 | # readable than just 0. 24 | B0 = 0 25 | 26 | 27 | def possible_devices() -> List[str]: 28 | devices = [] 29 | if torch.cuda.is_available(): 30 | devices.append("gpu") 31 | devices.append("cpu") 32 | return devices 33 | 34 | 35 | def load_dataset(filename) -> List[str]: 36 | with open(filename) as f: 37 | dataset = [s.strip("\n") for s in f.readlines()] 38 | print(f"Loaded {len(dataset)} sentences from {filename}") 39 | return dataset 40 | 41 | 42 | @st.cache_resource( 43 | hash_funcs={ 44 | TransformerLensTransparentLlm: id 45 | } 46 | ) 47 | def load_model( 48 | model_name: str, 49 | _device: str, 50 | _model_path: Optional[str] = None, 51 | _dtype: torch.dtype = torch.float32, 52 | supported_model_name: Optional[str] = None, 53 | ) -> TransparentLlm: 54 | """ 55 | Returns the loaded model along with its key. The key is just a unique string which 56 | can be used later to identify if the model has changed. 
57 | """ 58 | assert _device in possible_devices() 59 | 60 | causal_lm = None 61 | tokenizer = None 62 | 63 | tl_lm = TransformerLensTransparentLlm( 64 | model_name=model_name, 65 | hf_model=causal_lm, 66 | tokenizer=tokenizer, 67 | device=_device, 68 | dtype=_dtype, 69 | supported_model_name=supported_model_name, 70 | ) 71 | 72 | return tl_lm 73 | 74 | 75 | def run_model(model: TransparentLlm, sentence: str) -> None: 76 | print(f"Running inference for '{sentence}'") 77 | model.run([sentence]) 78 | 79 | 80 | def load_model_with_session_caching( 81 | **kwargs, 82 | ) -> Tuple[TransparentLlm, str]: 83 | return load_model(**kwargs) 84 | 85 | def run_model_with_session_caching( 86 | _model: TransparentLlm, 87 | model_key: str, 88 | sentence: str, 89 | ): 90 | LAST_RUN_MODEL_KEY = "last_run_model_key" 91 | LAST_RUN_SENTENCE = "last_run_sentence" 92 | state = st.session_state 93 | 94 | if ( 95 | state.get(LAST_RUN_MODEL_KEY, None) == model_key 96 | and state.get(LAST_RUN_SENTENCE, None) == sentence 97 | ): 98 | return 99 | 100 | run_model(_model, sentence) 101 | state[LAST_RUN_MODEL_KEY] = model_key 102 | state[LAST_RUN_SENTENCE] = sentence 103 | 104 | 105 | @st.cache_resource( 106 | hash_funcs={ 107 | TransformerLensTransparentLlm: id 108 | } 109 | ) 110 | def get_contribution_graph( 111 | model: TransparentLlm, # TODO bug here 112 | model_key: str, 113 | tokens: List[str], 114 | threshold: float, 115 | ) -> nx.Graph: 116 | """ 117 | The `model_key` and `tokens` are used only for caching. The model itself is not 118 | hashed, hence the `_` in the beginning. 119 | """ 120 | return llm_transparency_tool.routes.graph.build_full_graph( 121 | model, 122 | B0, 123 | threshold, 124 | ) 125 | 126 | 127 | def st_placeholder( 128 | text: str, 129 | container=st, 130 | border: bool = True, 131 | height: Optional[int] = 500, 132 | ): 133 | empty = container.empty() 134 | empty.container(border=border, height=height).write(f'{text}', unsafe_allow_html=True) 135 | return empty 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # LLM Transparency Tool 3 |

4 | 5 | screenshot 6 | 7 | 8 | ## Key functionality 9 | 10 | * Choose your model, choose or add your prompt, run the inference. 11 | * Browse the contribution graph. 12 | * Select the token to build the graph from. 13 | * Tune the contribution threshold. 14 | * Select representation of any token after any block. 15 | * For the representation, see its projection onto the output vocabulary and see which tokens 16 | were promoted/suppressed by the previous block. 17 | * The following things are clickable: 18 | * Edges. Clicking an edge shows more info about the contributing attention head. 19 | * Heads when an edge is selected. You can see what this head is promoting/suppressing. 20 | * FFN blocks (little squares on the graph). 21 | * Neurons when an FFN block is selected. 22 | 23 | 24 | ## Installation 25 | 26 | ### Dockerized running 27 | ```bash 28 | # From the repository root directory 29 | docker build -t llm_transparency_tool . 30 | docker run --rm -p 7860:7860 llm_transparency_tool 31 | ``` 32 | 33 | ### Local Installation 34 | 35 | 36 | ```bash 37 | # download 38 | git clone git@github.com:facebookresearch/llm-transparency-tool.git 39 | cd llm-transparency-tool 40 | 41 | # install the necessary packages 42 | conda env create --name llmtt -f env.yaml 43 | # install the `llm_transparency_tool` package 44 | pip install -e . 45 | 46 | # now, we need to build the frontend 47 | # don't worry, even `yarn` comes preinstalled by `env.yaml` 48 | cd llm_transparency_tool/components/frontend 49 | yarn install 50 | yarn build 51 | ``` 52 | 53 | ### Launch 54 | 55 | ```bash 56 | streamlit run llm_transparency_tool/server/app.py -- config/local.json 57 | ``` 58 | 59 | 60 | ## Adding support for your LLM 61 | 62 | Initially, the tool allows you to select from just a handful of models. Here are the 63 | options for using your own model in the tool, from least to most 64 | effort. 65 | 66 | 67 | ### The model is already supported by TransformerLens 68 | 69 | The full list of supported models is [here](https://github.com/neelnanda-io/TransformerLens/blob/0825c5eb4196e7ad72d28bcf4e615306b3897490/transformer_lens/loading_from_pretrained.py#L18). 70 | In this case, the model can be added to the configuration JSON file (see the example after this section). 71 | 72 | 73 | ### Tuned version of a model supported by TransformerLens 74 | 75 | Add the official name of the model to the config along with the location to read the 76 | weights from. 77 | 78 | 79 | ### The model is not supported by TransformerLens 80 | 81 | In this case, the UI wouldn't know how to create the proper hooks for the model. You'd need 82 | to implement your own version of the [TransparentLlm](./llm_transparency_tool/models/transparent_llm.py#L28) class and alter the 83 | Streamlit app to use your implementation.
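For example, here is a minimal sketch of a config for the first option, modeled on the bundled `config/local.json`. The model name has to be one that TransformerLens recognizes; the `EleutherAI/pythia-160m` entry below is purely illustrative and is not part of the shipped configs, while `null` is what the existing configs use for models whose weights come from the default Hugging Face location.

```json
{
  "allow_loading_dataset_files": true,
  "preloaded_dataset_filename": "sample_input.txt",
  "debug": true,
  "models": {
    "gpt2": null,
    "EleutherAI/pythia-160m": null
  },
  "default_model": "gpt2",
  "demo_mode": false
}
```

Launching with `streamlit run llm_transparency_tool/server/app.py -- path/to/your_config.json` should then offer the new model in the UI. For the second option (a tuned checkpoint of a supported architecture), the description above implies the config entry also carries the location of the weights; check how `llm_transparency_tool/server/app.py` parses the model entries before relying on any particular value format.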
84 | 85 | ## Citation 86 | If you use the LLM Transparency Tool for your research, please consider citing: 87 | 88 | ```bibtex 89 | @article{tufanov2024lm, 90 | title={LM Transparency Tool: Interactive Tool for Analyzing Transformer Language Models}, 91 | author={Igor Tufanov and Karen Hambardzumyan and Javier Ferrando and Elena Voita}, 92 | year={2024}, 93 | journal={Arxiv}, 94 | url={https://arxiv.org/abs/2404.07004} 95 | } 96 | 97 | @article{ferrando2024information, 98 | title={Information Flow Routes: Automatically Interpreting Language Models at Scale}, 99 | author={Javier Ferrando and Elena Voita}, 100 | year={2024}, 101 | journal={Arxiv}, 102 | url={https://arxiv.org/abs/2403.00824} 103 | } 104 | ```` 105 | 106 | ## License 107 | 108 | This code is made available under a [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license, as found in the LICENSE file. 109 | However you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models. 110 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/Selector.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | import { 10 | ComponentProps, 11 | Streamlit, 12 | withStreamlitConnection, 13 | } from "streamlit-component-lib" 14 | import React, { useEffect, useMemo, useRef, useState } from 'react'; 15 | import * as d3 from 'd3'; 16 | 17 | import { 18 | Point, 19 | } from './common'; 20 | import './LlmViewer.css'; 21 | 22 | export const renderParams = { 23 | verticalGap: 24, 24 | horizontalGap: 24, 25 | itemSize: 8, 26 | } 27 | 28 | interface Item { 29 | index: number 30 | text: string 31 | temperature: number 32 | } 33 | 34 | const Selector = ({ args }: ComponentProps) => { 35 | const items: Item[] = args["items"] 36 | const preselected_index: number | null = args["preselected_index"] 37 | const n = items.length 38 | 39 | const [selection, setSelection] = useState(null) 40 | 41 | // Ensure the preselected element has effect only when it's a new data. 
42 | var args_json = JSON.stringify(args) 43 | useEffect(() => { 44 | setSelection(preselected_index) 45 | Streamlit.setComponentValue(preselected_index) 46 | }, [args_json, preselected_index]); 47 | 48 | const handleItemClick = (index: number) => { 49 | setSelection(index) 50 | Streamlit.setComponentValue(index) 51 | } 52 | 53 | const [xScale, yScale] = useMemo(() => { 54 | const x = d3.scaleLinear() 55 | .domain([0, 1]) 56 | .range([0, renderParams.horizontalGap]) 57 | const y = d3.scaleLinear() 58 | .domain([0, n - 1]) 59 | .range([0, renderParams.verticalGap * (n - 1)]) 60 | return [x, y] 61 | }, [n]) 62 | 63 | const itemCoords: Point[] = useMemo(() => { 64 | return Array.from(Array(n).keys()).map(i => ({ 65 | x: xScale(0.5), 66 | y: yScale(i + 0.5), 67 | })) 68 | }, [n, xScale, yScale]) 69 | 70 | var hasTemperature = false 71 | if (n > 0) { 72 | var t = items[0].temperature 73 | hasTemperature = (t !== null && t !== undefined) 74 | } 75 | const colorScale = useMemo(() => { 76 | var min_t = 0.0 77 | var max_t = 1.0 78 | if (hasTemperature) { 79 | min_t = items[0].temperature 80 | max_t = items[0].temperature 81 | for (var i = 0; i < n; i++) { 82 | const t = items[i].temperature 83 | min_t = Math.min(min_t, t) 84 | max_t = Math.max(max_t, t) 85 | } 86 | } 87 | const norm = d3.scaleLinear([min_t, max_t], [0.0, 1.0]) 88 | const colorScale = d3.scaleSequential(d3.interpolateYlGn); 89 | return d3.scaleSequential(value => colorScale(norm(value))) 90 | }, [items, hasTemperature, n]) 91 | 92 | const totalW = 100 93 | const totalH = yScale(n) 94 | useEffect(() => { 95 | Streamlit.setFrameHeight(totalH) 96 | }, [totalH]) 97 | 98 | const svgRef = useRef(null); 99 | 100 | useEffect(() => { 101 | const svg = d3.select(svgRef.current) 102 | svg.selectAll('*').remove() 103 | 104 | const getItemClass = (index: number) => { 105 | var style = 'selectable-item ' 106 | style += index === selection ? 'selection' : 'selector-item' 107 | return style 108 | } 109 | 110 | const getItemColor = (item: Item) => { 111 | var t = item.temperature ?? 0.0 112 | return item.index === selection ? 'orange' : colorScale(t) 113 | } 114 | 115 | var icons = svg 116 | .selectAll('items') 117 | .data(Array.from(Array(n).keys())) 118 | .enter() 119 | .append('circle') 120 | .attr('cx', (i) => itemCoords[i].x) 121 | .attr('cy', (i) => itemCoords[i].y) 122 | .attr('r', renderParams.itemSize / 2) 123 | .on('click', (event: PointerEvent, i) => { 124 | handleItemClick(items[i].index) 125 | }) 126 | .attr('class', (i) => getItemClass(items[i].index)) 127 | if (hasTemperature) { 128 | icons.style('fill', (i) => getItemColor(items[i])) 129 | } 130 | 131 | svg 132 | .selectAll('labels') 133 | .data(Array.from(Array(n).keys())) 134 | .enter() 135 | .append('text') 136 | .attr('x', (i) => itemCoords[i].x + renderParams.horizontalGap / 2) 137 | .attr('y', (i) => itemCoords[i].y) 138 | .attr('text-anchor', 'left') 139 | .attr('alignment-baseline', 'middle') 140 | .text((i) => items[i].text) 141 | 142 | }, [ 143 | items, 144 | n, 145 | itemCoords, 146 | selection, 147 | colorScale, 148 | hasTemperature, 149 | ]) 150 | 151 | return 152 | } 153 | 154 | export default withStreamlitConnection(Selector) 155 | -------------------------------------------------------------------------------- /llm_transparency_tool/routes/test_contributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import unittest 8 | from typing import Any, List 9 | 10 | import torch 11 | 12 | import llm_transparency_tool.routes.contributions as contributions 13 | 14 | 15 | class TestContributions(unittest.TestCase): 16 | def setUp(self): 17 | torch.manual_seed(123) 18 | 19 | self.eps = 1e-4 20 | 21 | # It may be useful to run the test on GPU in case there are any issues with 22 | # creating temporary tensors on another device. But turn this off by default. 23 | self.test_on_gpu = False 24 | 25 | self.device = "cuda" if self.test_on_gpu else "cpu" 26 | 27 | self.batch = 4 28 | self.tokens = 5 29 | self.heads = 6 30 | self.d_model = 10 31 | 32 | self.decomposed_attn = torch.rand( 33 | self.batch, 34 | self.tokens, 35 | self.tokens, 36 | self.heads, 37 | self.d_model, 38 | device=self.device, 39 | ) 40 | self.mlp_out = torch.rand( 41 | self.batch, self.tokens, self.d_model, device=self.device 42 | ) 43 | self.resid_pre = torch.rand( 44 | self.batch, self.tokens, self.d_model, device=self.device 45 | ) 46 | self.resid_mid = torch.rand( 47 | self.batch, self.tokens, self.d_model, device=self.device 48 | ) 49 | self.resid_post = torch.rand( 50 | self.batch, self.tokens, self.d_model, device=self.device 51 | ) 52 | 53 | def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]): 54 | self.assertTrue( 55 | torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(), 56 | t, 57 | ) 58 | 59 | def test_mlp_contributions(self): 60 | mlp_out = torch.tensor([[[1.0, 1.0]]]) 61 | resid_mid = torch.tensor([[[0.0, 0.0]]]) 62 | resid_post = torch.tensor([[[1.0, 1.0]]]) 63 | 64 | c_mlp, c_residual = contributions.get_mlp_contributions( 65 | resid_mid, resid_post, mlp_out 66 | ) 67 | self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps) 68 | self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps) 69 | 70 | def test_decomposed_attn_contributions(self): 71 | resid_pre = torch.tensor([[[2.0, 1.0]]]) 72 | resid_mid = torch.tensor([[[2.0, 2.0]]]) 73 | decomposed_attn = torch.tensor( 74 | [ 75 | [ 76 | [ 77 | [ 78 | [1.0, 1.0], 79 | [-1.0, 0.0], 80 | ] 81 | ] 82 | ] 83 | ] 84 | ) 85 | 86 | c_attn, c_residual = contributions.get_attention_contributions( 87 | resid_pre, resid_mid, decomposed_attn, distance_norm=2 88 | ) 89 | self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]]) 90 | self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps) 91 | 92 | def test_decomposed_mlp_contributions(self): 93 | pre = torch.tensor([10.0, 10.0]) 94 | post = torch.tensor([-10.0, 10.0]) 95 | neuron_impacts = torch.tensor( 96 | [ 97 | [0.0, 1.0], 98 | [1.0, 0.0], 99 | [-21.0, -1.0], 100 | ] 101 | ) 102 | c_mlp, c_residual = contributions.get_decomposed_mlp_contributions( 103 | pre, post, neuron_impacts, distance_norm=2 104 | ) 105 | # A bit counter-intuitive, but the only vector pointing from 0 towards the 106 | # output is the first one. 
107 | self._assert_tensor_eq(c_mlp, [1, 0, 0]) 108 | self.assertAlmostEqual(c_residual, 0, delta=self.eps) 109 | 110 | def test_decomposed_mlp_contributions_single_direction(self): 111 | pre = torch.tensor([1.0, 1.0]) 112 | post = torch.tensor([4.0, 4.0]) 113 | neuron_impacts = torch.tensor( 114 | [ 115 | [1.0, 1.0], 116 | [2.0, 2.0], 117 | ] 118 | ) 119 | c_mlp, c_residual = contributions.get_decomposed_mlp_contributions( 120 | pre, post, neuron_impacts, distance_norm=2 121 | ) 122 | self._assert_tensor_eq(c_mlp, [0.25, 0.5]) 123 | self.assertAlmostEqual(c_residual, 0.25, delta=self.eps) 124 | 125 | def test_attention_contributions_shape(self): 126 | c_attn, c_residual = contributions.get_attention_contributions( 127 | self.resid_pre, self.resid_mid, self.decomposed_attn 128 | ) 129 | self.assertEqual( 130 | list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads] 131 | ) 132 | self.assertEqual(list(c_residual.shape), [self.batch, self.tokens]) 133 | 134 | def test_mlp_contributions_shape(self): 135 | c_mlp, c_residual = contributions.get_mlp_contributions( 136 | self.resid_mid, self.resid_post, self.mlp_out 137 | ) 138 | self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens]) 139 | self.assertEqual(list(c_residual.shape), [self.batch, self.tokens]) 140 | 141 | def test_renormalizing_threshold(self): 142 | c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]]) 143 | c_residual = torch.Tensor([0.8, 0.9]) 144 | norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize( 145 | 0.1, c_blocks, c_residual 146 | ) 147 | self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]]) 148 | self._assert_tensor_eq(norm_residual, [0.842105, 1.0]) 149 | -------------------------------------------------------------------------------- /llm_transparency_tool/models/test_tlens_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm 12 | from llm_transparency_tool.models.transparent_llm import ModelInfo 13 | 14 | 15 | class TransparentLlmTestCase(unittest.TestCase): 16 | @classmethod 17 | def setUpClass(cls): 18 | # Picking the smallest model possible so that the test runs faster. It's ok to 19 | # change this model, but you'll need to update tokenization specifics in some 20 | # tests. 
21 | cls._llm = TransformerLensTransparentLlm( 22 | model_name="facebook/opt-125m", 23 | device="cpu", 24 | ) 25 | 26 | def setUp(self): 27 | self._llm.run(["test", "test 1"]) 28 | self._eps = 1e-5 29 | 30 | def test_model_info(self): 31 | info = self._llm.model_info() 32 | self.assertEqual( 33 | info, 34 | ModelInfo( 35 | name="facebook/opt-125m", 36 | n_params_estimate=84934656, 37 | n_layers=12, 38 | n_heads=12, 39 | d_model=768, 40 | d_vocab=50272, 41 | ), 42 | ) 43 | 44 | def test_tokens(self): 45 | tokens = self._llm.tokens() 46 | 47 | pad = 1 48 | bos = 2 49 | test = 21959 50 | one = 112 51 | 52 | self.assertEqual(tokens.tolist(), [[bos, test, pad], [bos, test, one]]) 53 | 54 | def test_tokens_to_strings(self): 55 | s = self._llm.tokens_to_strings(torch.Tensor([2, 21959, 112]).to(torch.int)) 56 | self.assertEqual(s, ["", "test", " 1"]) 57 | 58 | def test_manage_state(self): 59 | # One llm.run was called at the setup. Call one more and make sure the object 60 | # returns values for the new state. 61 | self._llm.run(["one", "two", "three", "four"]) 62 | self.assertEqual(self._llm.tokens().shape[0], 4) 63 | 64 | def test_residual_in_and_out(self): 65 | """ 66 | Test that residual_in is a residual_out for the previous layer. 67 | """ 68 | for layer in range(1, 12): 69 | prev_residual_out = self._llm.residual_out(layer - 1) 70 | residual_in = self._llm.residual_in(layer) 71 | diff = torch.max(torch.abs(residual_in - prev_residual_out)).item() 72 | self.assertLess(diff, self._eps, f"layer {layer}") 73 | 74 | def test_residual_plus_block(self): 75 | """ 76 | Make sure that new residual = old residual + block output. Here, block is an ffn 77 | or attention. It's not that obvious because it could be that layer norm is 78 | applied after the block output, but before saving the result to residual. 79 | Luckily, this is not the case in TransformerLens, and we're relying on that. 80 | """ 81 | layer = 3 82 | batch = 0 83 | pos = 0 84 | 85 | residual_in = self._llm.residual_in(layer)[batch][pos] 86 | residual_mid = self._llm.residual_after_attn(layer)[batch][pos] 87 | residual_out = self._llm.residual_out(layer)[batch][pos] 88 | ffn_out = self._llm.ffn_out(layer)[batch][pos] 89 | attn_out = self._llm.attention_output(batch, layer, pos) 90 | 91 | a = residual_mid 92 | b = residual_in + attn_out 93 | diff = torch.max(torch.abs(a - b)).item() 94 | self.assertLess(diff, self._eps, "attn") 95 | 96 | a = residual_out 97 | b = residual_mid + ffn_out 98 | diff = torch.max(torch.abs(a - b)).item() 99 | self.assertLess(diff, self._eps, "ffn") 100 | 101 | def test_tensor_shapes(self): 102 | # Not much we can do about the tensors, but at least check their shapes and 103 | # that they don't contain NaNs. 
104 | vocab_size = 50272 105 | n_batch = 2 106 | n_tokens = 3 107 | d_model = 768 108 | d_hidden = d_model * 4 109 | n_heads = 12 110 | layer = 5 111 | 112 | device = self._llm.residual_in(0).device 113 | 114 | for name, tensor, expected_shape in [ 115 | ("r_in", self._llm.residual_in(layer), [n_batch, n_tokens, d_model]), 116 | ( 117 | "r_mid", 118 | self._llm.residual_after_attn(layer), 119 | [n_batch, n_tokens, d_model], 120 | ), 121 | ("r_out", self._llm.residual_out(layer), [n_batch, n_tokens, d_model]), 122 | ("logits", self._llm.logits(), [n_batch, n_tokens, vocab_size]), 123 | ("ffn_out", self._llm.ffn_out(layer), [n_batch, n_tokens, d_model]), 124 | ( 125 | "decomposed_ffn_out", 126 | self._llm.decomposed_ffn_out(0, 0, 0), 127 | [d_hidden, d_model], 128 | ), 129 | ("neuron_activations", self._llm.neuron_activations(0, 0, 0), [d_hidden]), 130 | ("neuron_output", self._llm.neuron_output(0, 0), [d_model]), 131 | ( 132 | "attention_matrix", 133 | self._llm.attention_matrix(0, 0, 0), 134 | [n_tokens, n_tokens], 135 | ), 136 | ( 137 | "attention_output_per_head", 138 | self._llm.attention_output_per_head(0, 0, 0, 0), 139 | [d_model], 140 | ), 141 | ( 142 | "attention_output", 143 | self._llm.attention_output(0, 0, 0), 144 | [d_model], 145 | ), 146 | ( 147 | "decomposed_attn", 148 | self._llm.decomposed_attn(0, layer), 149 | [n_tokens, n_tokens, n_heads, d_model], 150 | ), 151 | ( 152 | "unembed", 153 | self._llm.unembed(torch.zeros([d_model]).to(device), normalize=True), 154 | [vocab_size], 155 | ), 156 | ]: 157 | self.assertEqual(list(tensor.shape), expected_shape, name) 158 | self.assertFalse(torch.any(tensor.isnan()), name) 159 | 160 | 161 | if __name__ == "__main__": 162 | unittest.main() 163 | -------------------------------------------------------------------------------- /llm_transparency_tool/models/transparent_llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | from dataclasses import dataclass 9 | from typing import List 10 | 11 | import torch 12 | from jaxtyping import Float, Int 13 | 14 | 15 | @dataclass 16 | class ModelInfo: 17 | name: str 18 | 19 | # Not the actual number of parameters, but rather the order of magnitude 20 | n_params_estimate: int 21 | 22 | n_layers: int 23 | n_heads: int 24 | d_model: int 25 | d_vocab: int 26 | 27 | 28 | class TransparentLlm(ABC): 29 | """ 30 | An abstract stateful interface for a language model. The model is supposed to be 31 | loaded at the class initialization. 32 | 33 | The internal state is the resulting tensors from the last call of the `run` method. 34 | Most of the methods could return values based on the state, but some may do cheap 35 | computations based on them. 36 | """ 37 | 38 | @abstractmethod 39 | def model_info(self) -> ModelInfo: 40 | """ 41 | Gives general info about the model. This method must be available before any 42 | calls of the `run`. 43 | """ 44 | pass 45 | 46 | @abstractmethod 47 | def run(self, sentences: List[str]) -> None: 48 | """ 49 | Run the inference on the given sentences in a single batch and store all 50 | necessary info in the internal state. 51 | """ 52 | pass 53 | 54 | @abstractmethod 55 | def batch_size(self) -> int: 56 | """ 57 | The size of the batch that was used for the last call of `run`. 
58 | """ 59 | pass 60 | 61 | @abstractmethod 62 | def tokens(self) -> Int[torch.Tensor, "batch pos"]: 63 | pass 64 | 65 | @abstractmethod 66 | def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]: 67 | pass 68 | 69 | @abstractmethod 70 | def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]: 71 | pass 72 | 73 | @abstractmethod 74 | def unembed( 75 | self, 76 | t: Float[torch.Tensor, "d_model"], 77 | normalize: bool, 78 | ) -> Float[torch.Tensor, "vocab"]: 79 | """ 80 | Project the given vector (for example, the state of the residual stream for a 81 | layer and token) into the output vocabulary. 82 | 83 | normalize: whether to apply the final normalization before the unembedding. 84 | Setting it to True and applying to output of the last layer gives the output of 85 | the model. 86 | """ 87 | pass 88 | 89 | # ================= Methods related to the residual stream ================= 90 | 91 | @abstractmethod 92 | def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 93 | """ 94 | The state of the residual stream before entering the layer. For example, when 95 | layer == 0 these must the embedded tokens (including positional embedding). 96 | """ 97 | pass 98 | 99 | @abstractmethod 100 | def residual_after_attn( 101 | self, layer: int 102 | ) -> Float[torch.Tensor, "batch pos d_model"]: 103 | """ 104 | The state of the residual stream after attention, but before the FFN in the 105 | given layer. 106 | """ 107 | pass 108 | 109 | @abstractmethod 110 | def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 111 | """ 112 | The state of the residual stream after the given layer. This is equivalent to the 113 | next layer's input. 114 | """ 115 | pass 116 | 117 | # ================ Methods related to the feed-forward layer =============== 118 | 119 | @abstractmethod 120 | def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 121 | """ 122 | The output of the FFN layer, before it gets merged into the residual stream. 123 | """ 124 | pass 125 | 126 | @abstractmethod 127 | def decomposed_ffn_out( 128 | self, 129 | batch_i: int, 130 | layer: int, 131 | pos: int, 132 | ) -> Float[torch.Tensor, "hidden d_model"]: 133 | """ 134 | A collection of vectors added to the residual stream by each neuron. It should 135 | be the same as neuron activations multiplied by neuron outputs. 136 | """ 137 | pass 138 | 139 | @abstractmethod 140 | def neuron_activations( 141 | self, 142 | batch_i: int, 143 | layer: int, 144 | pos: int, 145 | ) -> Float[torch.Tensor, "d_ffn"]: 146 | """ 147 | The content of the hidden layer right after the activation function was applied. 148 | """ 149 | pass 150 | 151 | @abstractmethod 152 | def neuron_output( 153 | self, 154 | layer: int, 155 | neuron: int, 156 | ) -> Float[torch.Tensor, "d_model"]: 157 | """ 158 | Return the value that the given neuron adds to the residual stream. It's a raw 159 | vector from the model parameters, no activation involved. 160 | """ 161 | pass 162 | 163 | # ==================== Methods related to the attention ==================== 164 | 165 | @abstractmethod 166 | def attention_matrix( 167 | self, batch_i, layer: int, head: int 168 | ) -> Float[torch.Tensor, "query_pos key_pos"]: 169 | """ 170 | Return a lower-diagonal attention matrix. 
171 | """ 172 | pass 173 | 174 | @abstractmethod 175 | def attention_output( 176 | self, 177 | batch_i: int, 178 | layer: int, 179 | pos: int, 180 | head: int, 181 | ) -> Float[torch.Tensor, "d_model"]: 182 | """ 183 | Return what the given head at the given layer and pos added to the residual 184 | stream. 185 | """ 186 | pass 187 | 188 | @abstractmethod 189 | def decomposed_attn( 190 | self, batch_i: int, layer: int 191 | ) -> Float[torch.Tensor, "source target head d_model"]: 192 | """ 193 | Here 194 | - source: index of token from the previous layer 195 | - target: index of token on the current layer 196 | The decomposed attention tells what vector from source representation was used 197 | in order to contribute to the taget representation. 198 | """ 199 | pass 200 | -------------------------------------------------------------------------------- /llm_transparency_tool/routes/graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import networkx as nx 10 | import torch 11 | 12 | import llm_transparency_tool.routes.contributions as contributions 13 | from llm_transparency_tool.models.transparent_llm import TransparentLlm 14 | 15 | 16 | class GraphBuilder: 17 | """ 18 | Constructs the contributions graph with edges given one by one. The resulting graph 19 | is a networkx graph that can be accessed via the `graph` field. It contains the 20 | following types of nodes: 21 | 22 | - X0_: the original token. 23 | - A_: the residual stream after attention at the given layer for the 24 | given token. 25 | - M_: the ffn block. 26 | - I_: the residual stream after the ffn block. 27 | """ 28 | 29 | def __init__(self, n_layers: int, n_tokens: int): 30 | self._n_layers = n_layers 31 | self._n_tokens = n_tokens 32 | 33 | self.graph = nx.DiGraph() 34 | for layer in range(n_layers): 35 | for token in range(n_tokens): 36 | self.graph.add_node(f"A{layer}_{token}") 37 | self.graph.add_node(f"I{layer}_{token}") 38 | self.graph.add_node(f"M{layer}_{token}") 39 | for token in range(n_tokens): 40 | self.graph.add_node(f"X0_{token}") 41 | 42 | def get_output_node(self, token: int): 43 | return f"I{self._n_layers - 1}_{token}" 44 | 45 | def _add_edge(self, u: str, v: str, weight: float): 46 | # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with 47 | # attention from the current token and the residual edge. Ideally these need to 48 | # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine, 49 | # but when we try to traverse it, we face some NetworkX issue with EDGE_OK 50 | # receiving 3 arguments instead of 2. 
51 | if self.graph.has_edge(u, v): 52 | self.graph[u][v]["weight"] += weight 53 | else: 54 | self.graph.add_edge(u, v, weight=weight) 55 | 56 | def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float): 57 | self._add_edge( 58 | f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}", 59 | f"A{layer}_{token_to}", 60 | w, 61 | ) 62 | 63 | def add_residual_to_attn(self, layer: int, token: int, w: float): 64 | self._add_edge( 65 | f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}", 66 | f"A{layer}_{token}", 67 | w, 68 | ) 69 | 70 | def add_ffn_edge(self, layer: int, token: int, w: float): 71 | self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w) 72 | self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w) 73 | 74 | def add_residual_to_ffn(self, layer: int, token: int, w: float): 75 | self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w) 76 | 77 | 78 | @torch.no_grad() 79 | def build_full_graph( 80 | model: TransparentLlm, 81 | batch_i: int = 0, 82 | renormalizing_threshold: Optional[float] = None, 83 | ) -> nx.Graph: 84 | """ 85 | Build the contribution graph for all blocks of the model and all tokens. 86 | 87 | model: The transparent llm that has already run the inference. 88 | batch_i: Which sentence to use from the batch that was given to the model. 89 | renormalizing_threshold: If specified, will apply renormalizing thresholding to the 90 | contributions. All contributions below the threshold will be erased and the rest 91 | will be renormalized. 92 | """ 93 | n_layers = model.model_info().n_layers 94 | n_tokens = model.tokens()[batch_i].shape[0] 95 | 96 | builder = GraphBuilder(n_layers, n_tokens) 97 | 98 | for layer in range(n_layers): 99 | c_attn, c_resid_attn = contributions.get_attention_contributions( 100 | resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0), 101 | resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0), 102 | decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0), 103 | ) 104 | if renormalizing_threshold is not None: 105 | c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize( 106 | renormalizing_threshold, c_attn, c_resid_attn 107 | ) 108 | for token_from in range(n_tokens): 109 | for token_to in range(n_tokens): 110 | # Sum attention contributions over heads.
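# Shape note (c_attn comes from get_attention_contributions, with the inputs unsqueezed
# to a batch of one): c_attn is [1, pos, key_pos, head], so the indexing below leaves a
# per-head vector whose sum is the total contribution of token_from to token_to.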
111 | c = c_attn[batch_i, token_to, token_from].sum().item() 112 | builder.add_attention_edge(layer, token_from, token_to, c) 113 | for token in range(n_tokens): 114 | builder.add_residual_to_attn( 115 | layer, token, c_resid_attn[batch_i, token].item() 116 | ) 117 | 118 | c_ffn, c_resid_ffn = contributions.get_mlp_contributions( 119 | resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0), 120 | resid_post=model.residual_out(layer)[batch_i].unsqueeze(0), 121 | mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0), 122 | ) 123 | if renormalizing_threshold is not None: 124 | c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize( 125 | renormalizing_threshold, c_ffn, c_resid_ffn 126 | ) 127 | for token in range(n_tokens): 128 | builder.add_ffn_edge(layer, token, c_ffn[batch_i, token].item()) 129 | builder.add_residual_to_ffn( 130 | layer, token, c_resid_ffn[batch_i, token].item() 131 | ) 132 | 133 | return builder.graph 134 | 135 | 136 | def build_paths_to_predictions( 137 | graph: nx.Graph, 138 | n_layers: int, 139 | n_tokens: int, 140 | starting_tokens: List[int], 141 | threshold: float, 142 | ) -> List[nx.Graph]: 143 | """ 144 | Given the full graph, this function returns only the trees leading to the specified 145 | tokens. Edges with weight below `threshold` will be ignored. 146 | """ 147 | builder = GraphBuilder(n_layers, n_tokens) 148 | 149 | rgraph = graph.reverse() 150 | search_graph = nx.subgraph_view( 151 | rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold 152 | ) 153 | 154 | result = [] 155 | for start in starting_tokens: 156 | assert start < n_tokens 157 | assert start >= 0 158 | edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start)) 159 | tree = search_graph.edge_subgraph(edges) 160 | # Reverse the edges because the dfs was going from upper layer downwards. 161 | result.append(tree.reverse()) 162 | 163 | return result 164 | -------------------------------------------------------------------------------- /llm_transparency_tool/routes/contributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple 8 | 9 | import einops 10 | import torch 11 | from jaxtyping import Float 12 | from typeguard import typechecked 13 | 14 | 15 | @torch.no_grad() 16 | @typechecked 17 | def get_contributions( 18 | parts: torch.Tensor, 19 | whole: torch.Tensor, 20 | distance_norm: int = 1, 21 | ) -> torch.Tensor: 22 | """ 23 | Compute contributions of the `parts` vectors into the `whole` vector. 24 | 25 | Shapes of the tensors are as follows: 26 | parts: p_1 ... p_k, v_1 ... v_n, d 27 | whole: v_1 ... v_n, d 28 | result: p_1 ... p_k, v_1 ... v_n 29 | 30 | Here 31 | * `p_1 ... p_k`: dimensions for enumerating the parts 32 | * `v_1 ... v_n`: dimensions listing the independent cases (batching), 33 | * `d` is the dimension to compute the distances on. 34 | 35 | The resulting contributions will be normalized so that 36 | for each v_: sum(over p_ of result(p_, v_)) = 1. 37 | """ 38 | EPS = 1e-5 39 | 40 | k = len(parts.shape) - len(whole.shape) 41 | assert k >= 0 42 | assert parts.shape[k:] == whole.shape 43 | bc_whole = whole.expand(parts.shape) # new dims p_1 ... 
p_k are added to the front 44 | 45 | distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm) 46 | 47 | whole_norm = torch.norm(whole, p=distance_norm, dim=-1) 48 | distance = (whole_norm - distance).clip(min=EPS) 49 | 50 | sum = distance.sum(dim=tuple(range(k)), keepdim=True) 51 | 52 | return distance / sum 53 | 54 | 55 | @torch.no_grad() 56 | @typechecked 57 | def get_contributions_with_one_off_part( 58 | parts: torch.Tensor, 59 | one_off: torch.Tensor, 60 | whole: torch.Tensor, 61 | distance_norm: int = 1, 62 | ) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | Same as computing the contributions, but there is one additional part. That's useful 65 | because we always have the residual stream as one of the parts. 66 | 67 | See `get_contributions` documentation about `parts` and `whole` dimensions. The 68 | `one_off` should have the same dimensions as `whole`. 69 | 70 | Returns a pair consisting of 71 | 1. contributions tensor for the `parts` 72 | 2. contributions tensor for the `one_off` vector 73 | """ 74 | assert one_off.shape == whole.shape 75 | 76 | k = len(parts.shape) - len(whole.shape) 77 | assert k >= 0 78 | 79 | # Flatten the p_ dimensions, get contributions for the list, unflatten. 80 | flat = parts.flatten(start_dim=0, end_dim=k - 1) 81 | flat = torch.cat([flat, one_off.unsqueeze(0)]) 82 | contributions = get_contributions(flat, whole, distance_norm) 83 | parts_contributions, one_off_contributions = torch.split( 84 | contributions, flat.shape[0] - 1 85 | ) 86 | return ( 87 | parts_contributions.unflatten(0, parts.shape[0:k]), 88 | one_off_contributions[0], 89 | ) 90 | 91 | 92 | @torch.no_grad() 93 | @typechecked 94 | def get_attention_contributions( 95 | resid_pre: Float[torch.Tensor, "batch pos d_model"], 96 | resid_mid: Float[torch.Tensor, "batch pos d_model"], 97 | decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"], 98 | distance_norm: int = 1, 99 | ) -> Tuple[ 100 | Float[torch.Tensor, "batch pos key_pos head"], 101 | Float[torch.Tensor, "batch pos"], 102 | ]: 103 | """ 104 | Returns a pair of 105 | - a tensor of contributions of each token via each head 106 | - the contribution of the residual stream. 107 | """ 108 | 109 | # part dimensions | batch dimensions | vector dimension 110 | # ----------------+------------------+----------------- 111 | # key_pos, head | batch, pos | d_model 112 | parts = einops.rearrange( 113 | decomposed_attn, 114 | "batch pos key_pos head d_model -> key_pos head batch pos d_model", 115 | ) 116 | attn_contribution, residual_contribution = get_contributions_with_one_off_part( 117 | parts, resid_pre, resid_mid, distance_norm 118 | ) 119 | return ( 120 | einops.rearrange( 121 | attn_contribution, "key_pos head batch pos -> batch pos key_pos head" 122 | ), 123 | residual_contribution, 124 | ) 125 | 126 | 127 | @torch.no_grad() 128 | @typechecked 129 | def get_mlp_contributions( 130 | resid_mid: Float[torch.Tensor, "batch pos d_model"], 131 | resid_post: Float[torch.Tensor, "batch pos d_model"], 132 | mlp_out: Float[torch.Tensor, "batch pos d_model"], 133 | distance_norm: int = 1, 134 | ) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]: 135 | """ 136 | Returns a pair of (mlp, residual) contributions for each sentence and token. 
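Illustrative note: because `get_contributions` normalizes over the parts, the two returned tensors sum to 1 at every (batch, pos); e.g. an MLP contribution of 0.7 at some position implies a residual contribution of 0.3 there.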
137 | """ 138 | 139 | contributions = get_contributions( 140 | torch.stack((mlp_out, resid_mid)), resid_post, distance_norm 141 | ) 142 | return contributions[0], contributions[1] 143 | 144 | 145 | @torch.no_grad() 146 | @typechecked 147 | def get_decomposed_mlp_contributions( 148 | resid_mid: Float[torch.Tensor, "d_model"], 149 | resid_post: Float[torch.Tensor, "d_model"], 150 | decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"], 151 | distance_norm: int = 1, 152 | ) -> Tuple[Float[torch.Tensor, "hidden"], float]: 153 | """ 154 | Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of 155 | the hidden layer and thus computes a contribution per neuron. 156 | 157 | Doesn't contain batch and token dimensions for sake of saving memory. But we may 158 | consider adding them. 159 | """ 160 | 161 | neuron_contributions, residual_contribution = get_contributions_with_one_off_part( 162 | decomposed_mlp_out, resid_mid, resid_post, distance_norm 163 | ) 164 | return neuron_contributions, residual_contribution.item() 165 | 166 | 167 | @torch.no_grad() 168 | def apply_threshold_and_renormalize( 169 | threshold: float, 170 | c_blocks: torch.Tensor, 171 | c_residual: torch.Tensor, 172 | ) -> Tuple[torch.Tensor, torch.Tensor]: 173 | """ 174 | Thresholding mechanism used in the original graphs paper. After the threshold is 175 | applied, the remaining contributions are renormalized on order to sum up to 1 for 176 | each representation. 177 | 178 | threshold: The threshold. 179 | c_residual: Contribution of the residual stream for each representation. This tensor 180 | should contain 1 element per representation, i.e., its dimensions are all batch 181 | dimensions. 182 | c_blocks: Contributions of the blocks. Could be 1 block per representation, like 183 | ffn, or heads*tokens blocks in case of attention. The shape of `c_residual` 184 | must be a prefix if the shape of this tensor. The remaining dimensions are for 185 | listing the blocks. 186 | """ 187 | 188 | block_dims = len(c_blocks.shape) 189 | resid_dims = len(c_residual.shape) 190 | bound_dims = block_dims - resid_dims 191 | assert bound_dims >= 0 192 | assert c_blocks.shape[0:resid_dims] == c_residual.shape 193 | 194 | c_blocks = c_blocks * (c_blocks > threshold) 195 | c_residual = c_residual * (c_residual > threshold) 196 | 197 | denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims))) 198 | return ( 199 | c_blocks / denom.reshape(denom.shape + (1,) * bound_dims), 200 | c_residual / denom, 201 | ) 202 | -------------------------------------------------------------------------------- /llm_transparency_tool/models/tlens_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from dataclasses import dataclass 8 | from typing import List, Optional 9 | 10 | import torch 11 | import transformer_lens 12 | import transformers 13 | from fancy_einsum import einsum 14 | from jaxtyping import Float, Int 15 | from typeguard import typechecked 16 | import streamlit as st 17 | 18 | from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm 19 | from transformer_lens.loading_from_pretrained import MODEL_ALIASES, get_official_model_name 20 | 21 | 22 | @dataclass 23 | class _RunInfo: 24 | tokens: Int[torch.Tensor, "batch pos"] 25 | logits: Float[torch.Tensor, "batch pos d_vocab"] 26 | cache: transformer_lens.ActivationCache 27 | 28 | 29 | @st.cache_resource( 30 | max_entries=1, 31 | show_spinner=True, 32 | hash_funcs={ 33 | transformers.PreTrainedModel: id, 34 | transformers.PreTrainedTokenizer: id 35 | } 36 | ) 37 | def load_hooked_transformer( 38 | model_name: str, 39 | hf_model: Optional[transformers.PreTrainedModel] = None, 40 | tlens_device: str = "cuda", 41 | dtype: torch.dtype = torch.float32, 42 | supported_model_name: Optional[str] = None, 43 | ): 44 | if supported_model_name is None: 45 | supported_model_name = model_name 46 | supported_model_name = get_official_model_name(supported_model_name) 47 | if model_name not in MODEL_ALIASES: 48 | MODEL_ALIASES[supported_model_name] = [] 49 | if model_name not in MODEL_ALIASES[supported_model_name]: 50 | MODEL_ALIASES[supported_model_name].append(model_name) 51 | tlens_model = transformer_lens.HookedTransformer.from_pretrained( 52 | model_name, 53 | hf_model=hf_model, 54 | fold_ln=False, # Keep layer norm where it is. 55 | center_writing_weights=False, 56 | center_unembed=False, 57 | device=tlens_device, 58 | dtype=dtype, 59 | ) 60 | tlens_model.eval() 61 | return tlens_model 62 | 63 | 64 | # TODO(igortufanov): If we want to scale the app to multiple users, we need more careful 65 | # thread-safe implementation. The simplest option could be to wrap the existing methods 66 | # in mutexes. 67 | class TransformerLensTransparentLlm(TransparentLlm): 68 | """ 69 | Implementation of Transparent LLM based on transformer lens. 70 | 71 | Args: 72 | - model_name: The official name of the model from HuggingFace. Even if the model was 73 | patched or loaded locally, the name should still be official because that's how 74 | transformer_lens treats the model. 75 | - hf_model: The language model as a HuggingFace class. 
76 | - tokenizer: The tokenizer as a HuggingFace class (optional). 77 | - device: "gpu" or "cpu" 78 | """ 79 | 80 | def __init__( 81 | self, 82 | model_name: str, 83 | hf_model: Optional[transformers.PreTrainedModel] = None, 84 | tokenizer: Optional[transformers.PreTrainedTokenizer] = None, 85 | device: str = "gpu", 86 | dtype: torch.dtype = torch.float32, 87 | supported_model_name: Optional[str] = None, 88 | ): 89 | if device == "gpu": 90 | self.device = "cuda" 91 | if not torch.cuda.is_available(): 92 | raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda") 93 | elif device == "cpu": 94 | self.device = "cpu" 95 | else: 96 | raise RuntimeError(f"Specified device {device} is not a valid option") 97 | 98 | self.dtype = dtype 99 | self.hf_tokenizer = tokenizer 100 | self.hf_model = hf_model 101 | 102 | # self._model = tlens_model 103 | self._model_name = model_name 104 | self._supported_model_name = supported_model_name 105 | self._prepend_bos = True 106 | self._last_run = None 107 | self._run_exception = RuntimeError( 108 | "Tried to use the model output before calling the `run` method" 109 | ) 110 | 111 | def copy(self): 112 | import copy 113 | return copy.copy(self) 114 | 115 | @property 116 | def _model(self): 117 | tlens_model = load_hooked_transformer( 118 | self._model_name, 119 | hf_model=self.hf_model, 120 | tlens_device=self.device, 121 | dtype=self.dtype, 122 | supported_model_name=self._supported_model_name, 123 | ) 124 | 125 | if self.hf_tokenizer is not None: 126 | tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left") 127 | 128 | tlens_model.set_use_attn_result(True) 129 | tlens_model.set_use_attn_in(False) 130 | tlens_model.set_use_split_qkv_input(False) 131 | 132 | return tlens_model 133 | 134 | def model_info(self) -> ModelInfo: 135 | cfg = self._model.cfg 136 | return ModelInfo( 137 | name=self._model_name, 138 | n_params_estimate=cfg.n_params, 139 | n_layers=cfg.n_layers, 140 | n_heads=cfg.n_heads, 141 | d_model=cfg.d_model, 142 | d_vocab=cfg.d_vocab, 143 | ) 144 | 145 | @torch.no_grad() 146 | def run(self, sentences: List[str]) -> None: 147 | tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos) 148 | logits, cache = self._model.run_with_cache(tokens) 149 | 150 | self._last_run = _RunInfo( 151 | tokens=tokens, 152 | logits=logits, 153 | cache=cache, 154 | ) 155 | 156 | def batch_size(self) -> int: 157 | if not self._last_run: 158 | raise self._run_exception 159 | return self._last_run.logits.shape[0] 160 | 161 | @typechecked 162 | def tokens(self) -> Int[torch.Tensor, "batch pos"]: 163 | if not self._last_run: 164 | raise self._run_exception 165 | return self._last_run.tokens 166 | 167 | @typechecked 168 | def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]: 169 | return self._model.to_str_tokens(tokens) 170 | 171 | @typechecked 172 | def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]: 173 | if not self._last_run: 174 | raise self._run_exception 175 | return self._last_run.logits 176 | 177 | @torch.no_grad() 178 | @typechecked 179 | def unembed( 180 | self, 181 | t: Float[torch.Tensor, "d_model"], 182 | normalize: bool, 183 | ) -> Float[torch.Tensor, "vocab"]: 184 | # t: [d_model] -> [batch, pos, d_model] 185 | tdim = t.unsqueeze(0).unsqueeze(0) 186 | if normalize: 187 | normalized = self._model.ln_final(tdim) 188 | result = self._model.unembed(normalized) 189 | else: 190 | result = self._model.unembed(tdim.to(self.dtype)) 191 | return result[0][0] 192 | 193 | def _get_block(self, layer: int, block_name: str) -> torch.Tensor: 194 | if not
self._last_run: 195 | raise self._run_exception 196 | return self._last_run.cache[f"blocks.{layer}.{block_name}"] 197 | 198 | # ================= Methods related to the residual stream ================= 199 | 200 | @typechecked 201 | def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 202 | if not self._last_run: 203 | raise self._run_exception 204 | return self._get_block(layer, "hook_resid_pre") 205 | 206 | @typechecked 207 | def residual_after_attn( 208 | self, layer: int 209 | ) -> Float[torch.Tensor, "batch pos d_model"]: 210 | if not self._last_run: 211 | raise self._run_exception 212 | return self._get_block(layer, "hook_resid_mid") 213 | 214 | @typechecked 215 | def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 216 | if not self._last_run: 217 | raise self._run_exception 218 | return self._get_block(layer, "hook_resid_post") 219 | 220 | # ================ Methods related to the feed-forward layer =============== 221 | 222 | @typechecked 223 | def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]: 224 | if not self._last_run: 225 | raise self._run_exception 226 | return self._get_block(layer, "hook_mlp_out") 227 | 228 | @torch.no_grad() 229 | @typechecked 230 | def decomposed_ffn_out( 231 | self, 232 | batch_i: int, 233 | layer: int, 234 | pos: int, 235 | ) -> Float[torch.Tensor, "hidden d_model"]: 236 | # Take activations right before they're multiplied by W_out, i.e. non-linearity 237 | # and layer norm are already applied. 238 | processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos] 239 | return torch.mul(processed_activations.unsqueeze(-1), self._model.blocks[layer].mlp.W_out) 240 | 241 | @typechecked 242 | def neuron_activations( 243 | self, 244 | batch_i: int, 245 | layer: int, 246 | pos: int, 247 | ) -> Float[torch.Tensor, "hidden"]: 248 | return self._get_block(layer, "mlp.hook_pre")[batch_i][pos] 249 | 250 | @typechecked 251 | def neuron_output( 252 | self, 253 | layer: int, 254 | neuron: int, 255 | ) -> Float[torch.Tensor, "d_model"]: 256 | return self._model.blocks[layer].mlp.W_out[neuron] 257 | 258 | # ==================== Methods related to the attention ==================== 259 | 260 | @typechecked 261 | def attention_matrix( 262 | self, batch_i: int, layer: int, head: int 263 | ) -> Float[torch.Tensor, "query_pos key_pos"]: 264 | return self._get_block(layer, "attn.hook_pattern")[batch_i][head] 265 | 266 | @typechecked 267 | def attention_output_per_head( 268 | self, 269 | batch_i: int, 270 | layer: int, 271 | pos: int, 272 | head: int, 273 | ) -> Float[torch.Tensor, "d_model"]: 274 | return self._get_block(layer, "attn.hook_result")[batch_i][pos][head] 275 | 276 | @typechecked 277 | def attention_output( 278 | self, 279 | batch_i: int, 280 | layer: int, 281 | pos: int, 282 | ) -> Float[torch.Tensor, "d_model"]: 283 | return self._get_block(layer, "hook_attn_out")[batch_i][pos] 284 | 285 | @torch.no_grad() 286 | @typechecked 287 | def decomposed_attn( 288 | self, batch_i: int, layer: int 289 | ) -> Float[torch.Tensor, "pos key_pos head d_model"]: 290 | if not self._last_run: 291 | raise self._run_exception 292 | hook_v = self._get_block(layer, "attn.hook_v")[batch_i] 293 | b_v = self._model.blocks[layer].attn.b_V 294 | 295 | # support for gqa 296 | num_head_groups = b_v.shape[-2] // hook_v.shape[-2] 297 | hook_v = hook_v.repeat_interleave(num_head_groups, dim=-2) 298 | 299 | v = hook_v + b_v 300 | pattern = self._get_block(layer, 
"attn.hook_pattern")[batch_i].to(v.dtype) 301 | z = einsum( 302 | "key_pos head d_head, " 303 | "head query_pos key_pos -> " 304 | "query_pos key_pos head d_head", 305 | v, 306 | pattern, 307 | ) 308 | decomposed_attn = einsum( 309 | "pos key_pos head d_head, " 310 | "head d_head d_model -> " 311 | "pos key_pos head d_model", 312 | z, 313 | self._model.blocks[layer].attn.W_O, 314 | ) 315 | return decomposed_attn 316 | -------------------------------------------------------------------------------- /llm_transparency_tool/components/frontend/src/ContributionGraph.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | import { 10 | ComponentProps, 11 | Streamlit, 12 | withStreamlitConnection, 13 | } from 'streamlit-component-lib' 14 | import React, { useEffect, useMemo, useRef, useState } from 'react'; 15 | import * as d3 from 'd3'; 16 | 17 | import { 18 | Label, 19 | Point, 20 | } from './common'; 21 | import './LlmViewer.css'; 22 | 23 | export const renderParams = { 24 | cellH: 32, 25 | cellW: 32, 26 | attnSize: 8, 27 | afterFfnSize: 8, 28 | ffnSize: 6, 29 | tokenSelectorSize: 16, 30 | layerCornerRadius: 6, 31 | } 32 | 33 | interface Cell { 34 | layer: number 35 | token: number 36 | } 37 | 38 | enum CellItem { 39 | AfterAttn = 'after_attn', 40 | AfterFfn = 'after_ffn', 41 | Ffn = 'ffn', 42 | Original = 'original', // They will only be at level = 0 43 | } 44 | 45 | interface Node { 46 | cell: Cell | null 47 | item: CellItem | null 48 | } 49 | 50 | interface NodeProps { 51 | node: Node 52 | pos: Point 53 | isActive: boolean 54 | } 55 | 56 | interface EdgeRaw { 57 | weight: number 58 | source: string 59 | target: string 60 | } 61 | 62 | interface Edge { 63 | weight: number 64 | from: Node 65 | to: Node 66 | fromPos: Point 67 | toPos: Point 68 | isSelectable: boolean 69 | isFfn: boolean 70 | } 71 | 72 | interface Selection { 73 | node: Node | null 74 | edge: Edge | null 75 | } 76 | 77 | function tokenPointerPolygon(origin: Point) { 78 | const r = renderParams.tokenSelectorSize / 2 79 | const dy = r / 2 80 | const dx = r * Math.sqrt(3.0) / 2 81 | // Draw an arrow looking down 82 | return [ 83 | [origin.x, origin.y + r], 84 | [origin.x + dx, origin.y - dy], 85 | [origin.x - dx, origin.y - dy], 86 | ].toString() 87 | } 88 | 89 | function isSameCell(cell1: Cell | null, cell2: Cell | null) { 90 | if (cell1 == null || cell2 == null) { 91 | return false 92 | } 93 | return cell1.layer === cell2.layer && cell1.token === cell2.token 94 | } 95 | 96 | function isSameNode(node1: Node | null, node2: Node | null) { 97 | if (node1 === null || node2 === null) { 98 | return false 99 | } 100 | return isSameCell(node1.cell, node2.cell) 101 | && node1.item === node2.item; 102 | } 103 | 104 | function isSameEdge(edge1: Edge | null, edge2: Edge | null) { 105 | if (edge1 === null || edge2 === null) { 106 | return false 107 | } 108 | return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to); 109 | } 110 | 111 | function nodeFromString(name: string) { 112 | const match = name.match(/([AIMX])(\d+)_(\d+)/) 113 | if (match == null) { 114 | return { 115 | cell: null, 116 | item: null, 117 | } 118 | } 119 | const [, type, layerStr, tokenStr] = match 120 | const layer = +layerStr 121 | const token = +tokenStr 122 | 123 | const typeToCellItem 
= new Map([ 124 | ['A', CellItem.AfterAttn], 125 | ['I', CellItem.AfterFfn], 126 | ['M', CellItem.Ffn], 127 | ['X', CellItem.Original], 128 | ]) 129 | return { 130 | cell: { 131 | layer: layer, 132 | token: token, 133 | }, 134 | item: typeToCellItem.get(type) ?? null, 135 | } 136 | } 137 | 138 | function isValidNode(node: Node, nLayers: number, nTokens: number) { 139 | if (node.cell === null) { 140 | return true 141 | } 142 | return node.cell.layer < nLayers && node.cell.token < nTokens 143 | } 144 | 145 | function isValidSelection(selection: Selection, nLayers: number, nTokens: number) { 146 | if (selection.node !== null) { 147 | return isValidNode(selection.node, nLayers, nTokens) 148 | } 149 | if (selection.edge !== null) { 150 | return isValidNode(selection.edge.from, nLayers, nTokens) && 151 | isValidNode(selection.edge.to, nLayers, nTokens) 152 | } 153 | return true 154 | } 155 | 156 | const ContributionGraph = ({ args }: ComponentProps) => { 157 | const modelInfo = args['model_info'] 158 | const tokens = args['tokens'] 159 | const edgesRaw: EdgeRaw[][] = args['edges_per_token'] 160 | 161 | const nLayers = modelInfo === null ? 0 : modelInfo.n_layers 162 | const nTokens = tokens === null ? 0 : tokens.length 163 | 164 | const [selection, setSelection] = useState({ 165 | node: null, 166 | edge: null, 167 | }) 168 | var curSelection = selection 169 | if (!isValidSelection(selection, nLayers, nTokens)) { 170 | curSelection = { 171 | node: null, 172 | edge: null, 173 | } 174 | setSelection(curSelection) 175 | Streamlit.setComponentValue(curSelection) 176 | } 177 | 178 | const [startToken, setStartToken] = useState(nTokens - 1) 179 | // We have startToken state var, but it won't be updated till next render, so use 180 | // this var in the current render. 
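// Illustrative example of the clamping below (hypothetical numbers): if a rerun shrinks
// the prompt so that nTokens drops from 5 to 3 while startToken is still 4, then
// curStartToken becomes 2 for this render and setStartToken(2) applies on the next one.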
181 | var curStartToken = startToken 182 | if (startToken >= nTokens) { 183 | curStartToken = nTokens - 1 184 | setStartToken(curStartToken) 185 | } 186 | 187 | const handleRepresentationClick = (node: Node) => { 188 | const newSelection: Selection = { 189 | node: node, 190 | edge: null, 191 | } 192 | setSelection(newSelection) 193 | Streamlit.setComponentValue(newSelection) 194 | } 195 | 196 | const handleEdgeClick = (edge: Edge) => { 197 | if (!edge.isSelectable) { 198 | return 199 | } 200 | const newSelection: Selection = { 201 | node: edge.to, 202 | edge: edge, 203 | } 204 | setSelection(newSelection) 205 | Streamlit.setComponentValue(newSelection) 206 | } 207 | 208 | const handleTokenClick = (t: number) => { 209 | setStartToken(t) 210 | } 211 | 212 | const [xScale, yScale] = useMemo(() => { 213 | const x = d3.scaleLinear() 214 | .domain([-2, nTokens - 1]) 215 | .range([0, renderParams.cellW * (nTokens + 2)]) 216 | const y = d3.scaleLinear() 217 | .domain([-1, nLayers]) 218 | .range([renderParams.cellH * (nLayers + 2), 0]) 219 | return [x, y] 220 | }, [nLayers, nTokens]) 221 | 222 | const cells = useMemo(() => { 223 | let result: Cell[] = [] 224 | for (let l = 0; l < nLayers; l++) { 225 | for (let t = 0; t < nTokens; t++) { 226 | result.push({ 227 | layer: l, 228 | token: t, 229 | }) 230 | } 231 | } 232 | return result 233 | }, [nLayers, nTokens]) 234 | 235 | const nodeCoords = useMemo(() => { 236 | let result = new Map() 237 | const w = renderParams.cellW 238 | const h = renderParams.cellH 239 | for (var cell of cells) { 240 | const cx = xScale(cell.token + 0.5) 241 | const cy = yScale(cell.layer - 0.5) 242 | result.set( 243 | JSON.stringify({ cell: cell, item: CellItem.AfterAttn }), 244 | { x: cx, y: cy + h / 4 }, 245 | ) 246 | result.set( 247 | JSON.stringify({ cell: cell, item: CellItem.AfterFfn }), 248 | { x: cx, y: cy - h / 4 }, 249 | ) 250 | result.set( 251 | JSON.stringify({ cell: cell, item: CellItem.Ffn }), 252 | { x: cx + 5 * w / 16, y: cy }, 253 | ) 254 | } 255 | for (let t = 0; t < nTokens; t++) { 256 | cell = { 257 | layer: 0, 258 | token: t, 259 | } 260 | const cx = xScale(cell.token + 0.5) 261 | const cy = yScale(cell.layer - 1.0) 262 | result.set( 263 | JSON.stringify({ cell: cell, item: CellItem.Original }), 264 | { x: cx, y: cy + h / 4 }, 265 | ) 266 | } 267 | return result 268 | }, [cells, nTokens, xScale, yScale]) 269 | 270 | const edges: Edge[][] = useMemo(() => { 271 | let result = [] 272 | for (var edgeList of edgesRaw) { 273 | let edgesPerStartToken = [] 274 | for (var edge of edgeList) { 275 | const u = nodeFromString(edge.source) 276 | const v = nodeFromString(edge.target) 277 | var isSelectable = ( 278 | u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn 279 | ) 280 | var isFfn = ( 281 | u.cell !== null && v.cell !== null && ( 282 | u.item === CellItem.Ffn || v.item === CellItem.Ffn 283 | ) 284 | ) 285 | edgesPerStartToken.push({ 286 | weight: edge.weight, 287 | from: u, 288 | to: v, 289 | fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 }, 290 | toPos: nodeCoords.get(JSON.stringify(v)) ?? 
{ 'x': 0, 'y': 0 }, 291 | isSelectable: isSelectable, 292 | isFfn: isFfn, 293 | }) 294 | } 295 | result.push(edgesPerStartToken) 296 | } 297 | return result 298 | }, [edgesRaw, nodeCoords]) 299 | 300 | const activeNodes = useMemo(() => { 301 | let result = new Set() 302 | for (var edge of edges[curStartToken]) { 303 | const u = JSON.stringify(edge.from) 304 | const v = JSON.stringify(edge.to) 305 | result.add(u) 306 | result.add(v) 307 | } 308 | return result 309 | }, [edges, curStartToken]) 310 | 311 | const nodeProps = useMemo(() => { 312 | let result: Array = [] 313 | nodeCoords.forEach((p: Point, node: string) => { 314 | result.push({ 315 | node: JSON.parse(node), 316 | pos: p, 317 | isActive: activeNodes.has(node), 318 | }) 319 | }) 320 | return result 321 | }, [nodeCoords, activeNodes]) 322 | 323 | const tokenLabels: Label[] = useMemo(() => { 324 | if (!tokens) { 325 | return [] 326 | } 327 | return tokens.map((s: string, i: number) => ({ 328 | text: s.replace(/ /g, '·'), 329 | pos: { 330 | x: xScale(i + 0.5), 331 | y: yScale(-1.5), 332 | }, 333 | })) 334 | }, [tokens, xScale, yScale]) 335 | 336 | const layerLabels: Label[] = useMemo(() => { 337 | return Array.from(Array(nLayers).keys()).map(i => ({ 338 | text: 'L' + i, 339 | pos: { 340 | x: xScale(-0.25), 341 | y: yScale(i - 0.5), 342 | }, 343 | })) 344 | }, [nLayers, xScale, yScale]) 345 | 346 | const tokenSelectors: Array<[number, Point]> = useMemo(() => { 347 | return Array.from(Array(nTokens).keys()).map(i => ([ 348 | i, 349 | { 350 | x: xScale(i + 0.5), 351 | y: yScale(nLayers - 0.5), 352 | } 353 | ])) 354 | }, [nTokens, nLayers, xScale, yScale]) 355 | 356 | const totalW = xScale(nTokens + 2) 357 | const totalH = yScale(-4) 358 | useEffect(() => { 359 | Streamlit.setFrameHeight(totalH) 360 | }, [totalH]) 361 | 362 | const colorScale = d3.scaleLinear( 363 | [0.0, 0.5, 1.0], 364 | ['#9eba66', 'darkolivegreen', 'darkolivegreen'] 365 | ) 366 | const ffnEdgeColorScale = d3.scaleLinear( 367 | [0.0, 0.5, 1.0], 368 | ['orchid', 'purple', 'purple'] 369 | ) 370 | const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0]) 371 | 372 | const svgRef = useRef(null); 373 | 374 | useEffect(() => { 375 | const getNodeStyle = (p: NodeProps, type: string) => { 376 | if (isSameNode(p.node, curSelection.node)) { 377 | return 'selectable-item selection' 378 | } 379 | if (p.isActive) { 380 | return 'selectable-item active-' + type + '-node' 381 | } 382 | return 'selectable-item inactive-node' 383 | } 384 | 385 | const svg = d3.select(svgRef.current) 386 | svg.selectAll('*').remove() 387 | 388 | svg 389 | .selectAll('layers') 390 | .data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1)) 391 | .enter() 392 | .append('rect') 393 | .attr('class', 'layer-highlight') 394 | .attr('x', xScale(-1.0)) 395 | .attr('y', (layer) => yScale(layer)) 396 | .attr('width', xScale(nTokens + 0.25) - xScale(-1.0)) 397 | .attr('height', (layer) => yScale(layer) - yScale(layer + 1)) 398 | .attr('rx', renderParams.layerCornerRadius) 399 | 400 | svg 401 | .selectAll('edges') 402 | .data(edges[curStartToken]) 403 | .enter() 404 | .append('line') 405 | .style('stroke', (edge: Edge) => { 406 | if (isSameEdge(edge, curSelection.edge)) { 407 | return 'orange' 408 | } 409 | if (edge.isFfn) { 410 | return ffnEdgeColorScale(edge.weight) 411 | } 412 | return colorScale(edge.weight) 413 | }) 414 | .attr('class', (edge: Edge) => edge.isSelectable ? 
'selectable-edge' : '') 415 | .style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight)) 416 | .attr('x1', (edge: Edge) => edge.fromPos.x) 417 | .attr('y1', (edge: Edge) => edge.fromPos.y) 418 | .attr('x2', (edge: Edge) => edge.toPos.x) 419 | .attr('y2', (edge: Edge) => edge.toPos.y) 420 | .on('click', (event: PointerEvent, edge) => { 421 | handleEdgeClick(edge) 422 | }) 423 | 424 | svg 425 | .selectAll('residual') 426 | .data(nodeProps) 427 | .enter() 428 | .filter((p) => { 429 | return p.node.item === CellItem.AfterAttn 430 | || p.node.item === CellItem.AfterFfn 431 | }) 432 | .append('circle') 433 | .attr('class', (p) => getNodeStyle(p, 'residual')) 434 | .attr('cx', (p) => p.pos.x) 435 | .attr('cy', (p) => p.pos.y) 436 | .attr('r', renderParams.attnSize / 2) 437 | .on('click', (event: PointerEvent, p) => { 438 | handleRepresentationClick(p.node) 439 | }) 440 | 441 | svg 442 | .selectAll('ffn') 443 | .data(nodeProps) 444 | .enter() 445 | .filter((p) => p.node.item === CellItem.Ffn && p.isActive) 446 | .append('rect') 447 | .attr('class', (p) => getNodeStyle(p, 'ffn')) 448 | .attr('x', (p) => p.pos.x - renderParams.ffnSize / 2) 449 | .attr('y', (p) => p.pos.y - renderParams.ffnSize / 2) 450 | .attr('width', renderParams.ffnSize) 451 | .attr('height', renderParams.ffnSize) 452 | .on('click', (event: PointerEvent, p) => { 453 | handleRepresentationClick(p.node) 454 | }) 455 | 456 | svg 457 | .selectAll('token_labels') 458 | .data(tokenLabels) 459 | .enter() 460 | .append('text') 461 | .attr('x', (label: Label) => label.pos.x) 462 | .attr('y', (label: Label) => label.pos.y) 463 | .attr('text-anchor', 'end') 464 | .attr('dominant-baseline', 'middle') 465 | .attr('alignment-baseline', 'top') 466 | .attr('transform', (label: Label) => 467 | 'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')') 468 | .text((label: Label) => label.text) 469 | 470 | svg 471 | .selectAll('layer_labels') 472 | .data(layerLabels) 473 | .enter() 474 | .append('text') 475 | .attr('x', (label: Label) => label.pos.x) 476 | .attr('y', (label: Label) => label.pos.y) 477 | .attr('text-anchor', 'middle') 478 | .attr('alignment-baseline', 'middle') 479 | .text((label: Label) => label.text) 480 | 481 | svg 482 | .selectAll('token_selectors') 483 | .data(tokenSelectors) 484 | .enter() 485 | .append('polygon') 486 | .attr('class', ([i,]) => ( 487 | curStartToken === i 488 | ? 'selectable-item selection' 489 | : 'selectable-item token-selector' 490 | )) 491 | .attr('points', ([, p]) => tokenPointerPolygon(p)) 492 | .attr('r', renderParams.tokenSelectorSize / 2) 493 | .on('click', (event: PointerEvent, [i,]) => { 494 | handleTokenClick(i) 495 | }) 496 | }, [ 497 | cells, 498 | edges, 499 | nodeProps, 500 | tokenLabels, 501 | layerLabels, 502 | tokenSelectors, 503 | curStartToken, 504 | curSelection, 505 | colorScale, 506 | ffnEdgeColorScale, 507 | edgeWidthScale, 508 | nLayers, 509 | nTokens, 510 | xScale, 511 | yScale 512 | ]) 513 | 514 | return 515 | } 516 | 517 | export default withStreamlitConnection(ContributionGraph) 518 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. 
Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 
69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. 
Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. 
Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 
291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. 
To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /llm_transparency_tool/server/app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
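# Descriptive note on this module (added comment, not upstream code): app.py is the
# Streamlit entry point of the tool. It loads a TransformerLens-backed model, runs a
# selected sentence through it, builds the token-level contribution graph, and renders
# panels for inspecting attention heads, FFN neurons, and promoted/suppressed tokens.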
6 | 7 | import argparse 8 | from dataclasses import dataclass, field 9 | from typing import Dict, List, Optional, Tuple 10 | 11 | import networkx as nx 12 | import pandas as pd 13 | import plotly.express 14 | import plotly.graph_objects as go 15 | import streamlit as st 16 | import streamlit_extras.row as st_row 17 | import torch 18 | from jaxtyping import Float 19 | from torch.amp import autocast 20 | from transformers import HfArgumentParser 21 | 22 | import llm_transparency_tool.components 23 | from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm 24 | import llm_transparency_tool.routes.contributions as contributions 25 | import llm_transparency_tool.routes.graph 26 | from llm_transparency_tool.models.transparent_llm import TransparentLlm 27 | from llm_transparency_tool.routes.graph_node import NodeType 28 | from llm_transparency_tool.server.graph_selection import ( 29 | GraphSelection, 30 | UiGraphEdge, 31 | UiGraphNode, 32 | ) 33 | from llm_transparency_tool.server.styles import ( 34 | RenderSettings, 35 | logits_color_map, 36 | margins_css, 37 | string_to_display, 38 | ) 39 | from llm_transparency_tool.server.utils import ( 40 | B0, 41 | get_contribution_graph, 42 | load_dataset, 43 | load_model, 44 | possible_devices, 45 | run_model_with_session_caching, 46 | st_placeholder, 47 | ) 48 | from llm_transparency_tool.server.monitor import SystemMonitor 49 | 50 | from networkx.classes.digraph import DiGraph 51 | 52 | 53 | @st.cache_resource( 54 | hash_funcs={ 55 | nx.Graph: id, 56 | DiGraph: id 57 | } 58 | ) 59 | def cached_build_paths_to_predictions( 60 | graph: nx.Graph, 61 | n_layers: int, 62 | n_tokens: int, 63 | starting_tokens: List[int], 64 | threshold: float, 65 | ): 66 | return llm_transparency_tool.routes.graph.build_paths_to_predictions( 67 | graph, n_layers, n_tokens, starting_tokens, threshold 68 | ) 69 | 70 | @st.cache_resource( 71 | hash_funcs={ 72 | TransformerLensTransparentLlm: id 73 | } 74 | ) 75 | def cached_run_inference_and_populate_state( 76 | stateless_model, 77 | sentences, 78 | ): 79 | stateful_model = stateless_model.copy() 80 | stateful_model.run(sentences) 81 | return stateful_model 82 | 83 | 84 | @dataclass 85 | class LlmViewerConfig: 86 | debug: bool = field( 87 | default=False, 88 | metadata={"help": "Show debugging information, like the time profile."}, 89 | ) 90 | 91 | preloaded_dataset_filename: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "The name of the text file to load the lines from."}, 94 | ) 95 | 96 | demo_mode: bool = field( 97 | default=False, 98 | metadata={"help": "Whether the app should be in the demo mode."}, 99 | ) 100 | 101 | allow_loading_dataset_files: bool = field( 102 | default=True, 103 | metadata={"help": "Whether the app should be able to load the dataset files " "on the server side."}, 104 | ) 105 | 106 | max_user_string_length: Optional[int] = field( 107 | default=None, 108 | metadata={ 109 | "help": "Limit for the length of user-provided sentences (in characters), " "or None if there is no limit." 110 | }, 111 | ) 112 | 113 | models: Dict[str, str] = field( 114 | default_factory=dict, 115 | metadata={ 116 | "help": "Locations of models which are stored locally. Dictionary: official " 117 | "HuggingFace name -> path to dir. If None is specified, the model will be" 118 | "downloaded from HuggingFace." 
119 | }, 120 | ) 121 | 122 | default_model: str = field( 123 | default="", 124 | metadata={"help": "The model to load once the UI is started."}, 125 | ) 126 | 127 | 128 | class App: 129 | _stateful_model: TransparentLlm = None 130 | render_settings = RenderSettings() 131 | _graph: Optional[nx.Graph] = None 132 | _contribution_threshold: float = 0.0 133 | _renormalize_after_threshold: bool = False 134 | _normalize_before_unembedding: bool = True 135 | 136 | @property 137 | def stateful_model(self) -> TransparentLlm: 138 | return self._stateful_model 139 | 140 | def __init__(self, config: LlmViewerConfig): 141 | self._config = config 142 | st.set_page_config(layout="wide") 143 | st.markdown(margins_css, unsafe_allow_html=True) 144 | 145 | def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]: 146 | if node is None: 147 | return None 148 | fn = { 149 | NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn, 150 | NodeType.AFTER_FFN: self.stateful_model.residual_out, 151 | NodeType.FFN: None, 152 | NodeType.ORIGINAL: self.stateful_model.residual_in, 153 | } 154 | return fn[node.type](node.layer)[B0][node.token] 155 | 156 | def draw_model_info(self): 157 | info = self.stateful_model.model_info().__dict__ 158 | df = pd.DataFrame( 159 | data=[str(x) for x in info.values()], 160 | index=info.keys(), 161 | columns=["Model parameter"], 162 | ) 163 | st.dataframe(df, use_container_width=False) 164 | 165 | def draw_dataset_selection(self) -> int: 166 | def update_dataset(filename: Optional[str]): 167 | dataset = load_dataset(filename) if filename is not None else [] 168 | st.session_state["dataset"] = dataset 169 | st.session_state["dataset_file"] = filename 170 | 171 | if "dataset" not in st.session_state: 172 | update_dataset(self._config.preloaded_dataset_filename) 173 | 174 | 175 | if not self._config.demo_mode: 176 | with st.sidebar.expander("Dataset", expanded=False): 177 | if self._config.allow_loading_dataset_files: 178 | row_f = st_row.row([2, 1], vertical_align="bottom") 179 | filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "", label_visibility="collapsed") 180 | if row_f.button("Load"): 181 | update_dataset(filename) 182 | row_s = st_row.row([2, 1], vertical_align="bottom") 183 | new_sentence = row_s.text_area("New sentence", label_visibility="collapsed") 184 | new_sentence_added = False 185 | 186 | if row_s.button("Add"): 187 | max_len = self._config.max_user_string_length 188 | n = len(new_sentence) 189 | if max_len is None or n <= max_len: 190 | st.session_state.dataset.append(new_sentence) 191 | new_sentence_added = True 192 | st.session_state.sentence_selector = new_sentence 193 | else: 194 | st.warning(f"Sentence length {n} is larger than " f"the configured limit of {max_len}") 195 | 196 | sentences = st.session_state.dataset 197 | selection = st.selectbox( 198 | "Sentence", 199 | sentences, 200 | index=len(sentences) - 1, 201 | key="sentence_selector", 202 | ) 203 | return selection 204 | 205 | def _unembed( 206 | self, 207 | representation: torch.Tensor, 208 | ) -> torch.Tensor: 209 | return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding) 210 | 211 | def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]: 212 | tokens = self.stateful_model.tokens()[B0] 213 | n_tokens = tokens.shape[0] 214 | model_info = self.stateful_model.model_info() 215 | 216 | graphs = cached_build_paths_to_predictions( 217 | self._graph, 218 | 
model_info.n_layers, 219 | n_tokens, 220 | range(n_tokens), 221 | contribution_threshold, 222 | ) 223 | 224 | return llm_transparency_tool.components.contribution_graph( 225 | model_info, 226 | self.stateful_model.tokens_to_strings(tokens), 227 | graphs, 228 | key=f"graph_{hash(self.sentence)}", 229 | ) 230 | 231 | def draw_token_matrix( 232 | self, 233 | values: Float[torch.Tensor, "t t"], 234 | tokens: List[str], 235 | value_name: str, 236 | title: str, 237 | ): 238 | assert values.shape[0] == len(tokens) 239 | labels = { 240 | "x": "src", 241 | "y": "tgt", 242 | "color": value_name, 243 | } 244 | 245 | captions = [f"({i}){t}" for i, t in enumerate(tokens)] 246 | 247 | fig = plotly.express.imshow( 248 | values.cpu(), 249 | title=f'{title}', 250 | labels=labels, 251 | x=captions, 252 | y=captions, 253 | color_continuous_scale=self.render_settings.attention_color_map, 254 | aspect="equal", 255 | ) 256 | fig.update_layout( 257 | autosize=True, 258 | margin=go.layout.Margin( 259 | l=50, # left margin 260 | r=0, # right margin 261 | b=100, # bottom margin 262 | t=100, # top margin 263 | # pad=10 # padding 264 | ) 265 | ) 266 | fig.update_xaxes(tickmode="linear") 267 | fig.update_yaxes(tickmode="linear") 268 | fig.update_coloraxes(showscale=False) 269 | 270 | st.plotly_chart(fig, use_container_width=True, theme=None) 271 | 272 | def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]: 273 | """ 274 | Returns: the index of the selected head. 275 | """ 276 | 277 | n_heads = self.stateful_model.model_info().n_heads 278 | 279 | layer = edge.target.layer 280 | 281 | head_contrib, _ = contributions.get_attention_contributions( 282 | resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0), 283 | resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0), 284 | decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0), 285 | ) 286 | 287 | # [batch pos key_pos head] -> [head] 288 | flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :] 289 | assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}" 290 | 291 | selected_head = llm_transparency_tool.components.selector( 292 | items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)], 293 | indices=range(-1, n_heads), 294 | temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(), 295 | preselected_index=flat_contrib.argmax().item(), 296 | key=f"head_selector_layer_{layer}" #_from_tok_{edge.source.token}_to_tok_{edge.target.token}", 297 | ) 298 | print(f"head_selector_layer_{layer}_from_tok_{edge.source.token}_to_tok_{edge.target.token}") 299 | if selected_head == -1 or selected_head is None: 300 | # selected_head = None 301 | selected_head = flat_contrib.argmax().item() 302 | print('****\n' * 3 + f"selected_head: {selected_head}" + '\n****\n' * 3) 303 | 304 | # Draw attention matrix and contributions for the selected head. 
305 | if selected_head is not None: 306 | tokens = [ 307 | string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0]) 308 | ] 309 | 310 | with container_attention_map: 311 | attn_container, contrib_container = st.columns([1, 1]) 312 | with attn_container: 313 | attn = self.stateful_model.attention_matrix(B0, layer, selected_head) 314 | self.draw_token_matrix( 315 | attn, 316 | tokens, 317 | "attention", 318 | f"Attention map L{layer} H{selected_head}", 319 | ) 320 | with contrib_container: 321 | contrib = head_contrib[B0, :, :, selected_head] 322 | self.draw_token_matrix( 323 | contrib, 324 | tokens, 325 | "contribution", 326 | f"Contribution map L{layer} H{selected_head}", 327 | ) 328 | 329 | return selected_head 330 | 331 | def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]: 332 | """ 333 | Returns: the index of the selected neuron. 334 | """ 335 | 336 | resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token] 337 | resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token] 338 | decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token) 339 | c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn) 340 | 341 | top_values, top_i = c_ffn.sort(descending=True) 342 | n = min(self.render_settings.n_top_neurons, c_ffn.shape[0]) 343 | top_neurons = top_i[0:n].tolist() 344 | 345 | selected_neuron = llm_transparency_tool.components.selector( 346 | items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)], 347 | indices=range(-1, n), 348 | temperatures=[0.0] + top_values[0:n].tolist(), 349 | preselected_index=-1, 350 | key="neuron_selector", 351 | ) 352 | if selected_neuron is None: 353 | selected_neuron = -1 354 | selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron] 355 | 356 | return selected_neuron 357 | 358 | def _draw_token_table( 359 | self, 360 | n_top: int, 361 | n_bottom: int, 362 | representation: torch.Tensor, 363 | predecessor: Optional[torch.Tensor] = None, 364 | ): 365 | n_total = n_top + n_bottom 366 | 367 | logits = self._unembed(representation) 368 | n_vocab = logits.shape[0] 369 | scores, indices = torch.topk(logits, n_top, largest=True) 370 | positions = list(range(n_top)) 371 | 372 | if n_bottom > 0: 373 | low_scores, low_indices = torch.topk(logits, n_bottom, largest=False) 374 | indices = torch.cat((indices, low_indices.flip(0))) 375 | scores = torch.cat((scores, low_scores.flip(0))) 376 | positions += range(n_vocab - n_bottom, n_vocab) 377 | 378 | tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)] 379 | 380 | if predecessor is not None: 381 | pre_logits = self._unembed(predecessor) 382 | _, sorted_pre_indices = pre_logits.sort(descending=True) 383 | pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())} 384 | old_positions = [pre_indices_dict[i] for i in indices.tolist()] 385 | 386 | def pos_gain_string(pos, old_pos): 387 | if pos == old_pos: 388 | return "" 389 | sign = "↓" if pos > old_pos else "↑" 390 | return f"({sign}{abs(pos - old_pos)})" 391 | 392 | position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)] 393 | else: 394 | position_strings = [str(pos) for pos in positions] 395 | 396 | def pos_gain_color(s): 397 | color = "black" 398 | if isinstance(s, str): 399 | if "↓" in s: 400 | color = "red" 401 | if "↑" in s: 402 | color = "green" 403 | return f"color: 
{color}" 404 | 405 | top_df = pd.DataFrame( 406 | data=zip(position_strings, tokens, scores.tolist()), 407 | columns=["Pos", "Token", "Score"], 408 | ) 409 | 410 | st.dataframe( 411 | top_df.style.applymap(pos_gain_color) 412 | .background_gradient( 413 | axis=0, 414 | cmap=logits_color_map(positive_and_negative=n_bottom > 0), 415 | ) 416 | .format(precision=3), 417 | hide_index=True, 418 | height=self.render_settings.table_cell_height * (n_total + 1), 419 | use_container_width=True, 420 | ) 421 | 422 | def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None: 423 | st.caption(block_name) 424 | self._draw_token_table( 425 | self.render_settings.n_promoted_tokens, 426 | self.render_settings.n_suppressed_tokens, 427 | representation, 428 | None, 429 | ) 430 | 431 | def draw_top_tokens( 432 | self, 433 | node: UiGraphNode, 434 | container_top_tokens, 435 | container_token_dynamics, 436 | ) -> None: 437 | pre_node = node.get_residual_predecessor() 438 | if pre_node is None: 439 | return 440 | 441 | representation = self._get_representation(node) 442 | predecessor = self._get_representation(pre_node) 443 | 444 | with container_top_tokens: 445 | st.caption(node.get_name()) 446 | self._draw_token_table( 447 | self.render_settings.n_top_tokens, 448 | 0, 449 | representation, 450 | predecessor, 451 | ) 452 | if container_token_dynamics is not None: 453 | with container_token_dynamics: 454 | self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name()) 455 | 456 | def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]): 457 | block_name = node.get_head_name(head) 458 | block_output = ( 459 | self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head) 460 | if head is not None 461 | else self.stateful_model.attention_output(B0, node.layer, node.token) 462 | ) 463 | self.draw_token_dynamics(block_output, block_name) 464 | 465 | def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]): 466 | block_name = node.get_neuron_name(neuron) 467 | block_output = ( 468 | self.stateful_model.neuron_output(node.layer, neuron) 469 | if neuron is not None 470 | else self.stateful_model.ffn_out(node.layer)[B0][node.token] 471 | ) 472 | self.draw_token_dynamics(block_output, block_name) 473 | 474 | def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]: 475 | """ 476 | Draw fp16/fp32 switch and AMP control. 477 | 478 | return: The selected precision and whether AMP should be enabled. 
479 | """ 480 | 481 | if device == "cpu": 482 | dtype = torch.float32 483 | else: 484 | dtype = st.selectbox( 485 | "Precision", 486 | [torch.float16, torch.bfloat16, torch.float32], 487 | index=0, 488 | ) 489 | 490 | amp_enabled = dtype != torch.float32 491 | 492 | return dtype, amp_enabled 493 | 494 | def draw_controls(self): 495 | # model_container, data_container = st.columns([1, 1]) 496 | with st.sidebar.expander("Model", expanded=True): 497 | list_of_devices = possible_devices() 498 | if len(list_of_devices) > 1: 499 | self.device = st.selectbox( 500 | "Device", 501 | possible_devices(), 502 | index=0, 503 | ) 504 | else: 505 | self.device = list_of_devices[0] 506 | 507 | self.dtype, self.amp_enabled = self.draw_precision_controls(self.device) 508 | 509 | model_list = list(self._config.models) 510 | default_choice = model_list.index(self._config.default_model) 511 | 512 | self.supported_model_name = st.selectbox( 513 | "Model name", 514 | model_list, 515 | index=default_choice, 516 | ) 517 | self.model_name = st.text_input("Custom model name", value=self.supported_model_name) 518 | 519 | if self.model_name: 520 | self._stateful_model = load_model( 521 | model_name=self.model_name, 522 | _model_path=self._config.models[self.model_name], 523 | _device=self.device, 524 | _dtype=self.dtype, 525 | supported_model_name=None if not self.supported_model_name else self.supported_model_name, 526 | ) 527 | self.model_key = self.model_name # TODO maybe something else? 528 | self.draw_model_info() 529 | 530 | self.sentence = self.draw_dataset_selection() 531 | 532 | with st.sidebar.expander("Graph", expanded=True): 533 | self._contribution_threshold = st.slider( 534 | min_value=0.01, 535 | max_value=0.1, 536 | step=0.01, 537 | value=0.04, 538 | format=r"%.3f", 539 | label="Contribution threshold", 540 | ) 541 | self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True) 542 | self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True) 543 | 544 | def run_inference(self): 545 | 546 | with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype): 547 | self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence]) 548 | 549 | with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype): 550 | self._graph = get_contribution_graph( 551 | self.stateful_model, 552 | self.model_key, 553 | self.stateful_model.tokens()[B0].tolist(), 554 | (self._contribution_threshold if self._renormalize_after_threshold else 0.0), 555 | ) 556 | 557 | def draw_graph_and_selection( 558 | self, 559 | ) -> None: 560 | ( 561 | container_graph, 562 | container_tokens, 563 | ) = st.columns(self.render_settings.column_proportions) 564 | 565 | container_graph_left, container_graph_right = container_graph.columns([5, 1]) 566 | 567 | container_graph_left.write('##### Graph') 568 | heads_placeholder = container_graph_right.empty() 569 | heads_placeholder.write('##### Blocks') 570 | container_graph_right_used = False 571 | 572 | container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1]) 573 | container_top_tokens.write('##### Top Tokens') 574 | container_top_tokens_used = False 575 | container_token_dynamics.write('##### Promoted Tokens') 576 | container_token_dynamics_used = False 577 | 578 | try: 579 | 580 | if self.sentence is None: 581 | return 582 | 583 | with container_graph_left: 584 | selection = self.draw_graph(self._contribution_threshold if not 
self._renormalize_after_threshold else 0.0) 585 | 586 | if selection is None: 587 | return 588 | 589 | node = selection.node 590 | edge = selection.edge 591 | 592 | if edge is not None and edge.target.type == NodeType.AFTER_ATTN: 593 | with container_graph_right: 594 | container_graph_right_used = True 595 | heads_placeholder.write('##### Heads') 596 | head = self.draw_attn_info(edge, container_graph) 597 | with container_token_dynamics: 598 | self.draw_attention_dynamics(edge.target, head) 599 | container_token_dynamics_used = True 600 | elif node is not None and node.type == NodeType.FFN: 601 | with container_graph_right: 602 | container_graph_right_used = True 603 | heads_placeholder.write('##### Neurons') 604 | neuron = self.draw_ffn_info(node) 605 | with container_token_dynamics: 606 | self.draw_ffn_dynamics(node, neuron) 607 | container_token_dynamics_used = True 608 | 609 | if node is not None and node.is_in_residual_stream(): 610 | self.draw_top_tokens( 611 | node, 612 | container_top_tokens, 613 | container_token_dynamics if not container_token_dynamics_used else None, 614 | ) 615 | container_top_tokens_used = True 616 | container_token_dynamics_used = True 617 | finally: 618 | if not container_graph_right_used: 619 | st_placeholder('Click on an edge to see head contributions. \n\n' 620 | 'Or click on FFN to see individual neuron contributions.', container_graph_right, height=1100) 621 | if not container_top_tokens_used: 622 | st_placeholder('Select a node from residual stream to see its top tokens.', container_top_tokens, height=1100) 623 | if not container_token_dynamics_used: 624 | st_placeholder('Select a node to see its promoted tokens.', container_token_dynamics, height=1100) 625 | 626 | 627 | def run(self): 628 | 629 | if self._config.demo_mode: 630 | with st.sidebar.expander("About", expanded=True): 631 | st.caption(""" 632 | The app is deployed in Demo Mode, thus only predefined models and inputs are available.\n 633 | You can still install the app locally and use your own models and inputs.\n 634 | See https://github.com/facebookresearch/llm-transparency-tool for more information. 635 | """) 636 | 637 | self.draw_controls() 638 | 639 | if not self.model_name: 640 | st.warning("No model selected") 641 | st.stop() 642 | 643 | if self.sentence is None: 644 | st.warning("No sentence selected") 645 | else: 646 | with torch.inference_mode(): 647 | self.run_inference() 648 | 649 | self.draw_graph_and_selection() 650 | 651 | 652 | if __name__ == "__main__": 653 | top_parser = argparse.ArgumentParser() 654 | top_parser.add_argument("config_file") 655 | args = top_parser.parse_args() 656 | 657 | parser = HfArgumentParser([LlmViewerConfig]) 658 | config = parser.parse_json_file(args.config_file)[0] 659 | 660 | with SystemMonitor(config.debug) as prof: 661 | app = App(config) 662 | app.run() 663 | --------------------------------------------------------------------------------
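The `__main__` block above takes a single positional argument: the path to a JSON config whose keys mirror the `LlmViewerConfig` fields. As a rough, self-contained sketch (not part of the repository), the same parsing step can be exercised on its own as shown below; `config/local.json` is used here only as an example path, and the full UI is normally started through Streamlit rather than by running this snippet directly.

from transformers import HfArgumentParser

from llm_transparency_tool.server.app import LlmViewerConfig

# Parse the JSON config into the dataclass, the same way app.py's __main__ does.
config: LlmViewerConfig = HfArgumentParser([LlmViewerConfig]).parse_json_file("config/local.json")[0]
print(config.default_model, list(config.models))

# The app itself is typically launched through Streamlit, with script arguments
# passed after "--", e.g.:
#   streamlit run llm_transparency_tool/server/app.py -- config/local.json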