├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .github └── workflows │ ├── checks.yml │ ├── release.yml │ └── storybook.yml ├── .gitignore ├── .vscode ├── cspell.json ├── extensions.json └── settings.json ├── LICENSE.txt ├── README.md ├── python ├── .pylintrc ├── Demonstration.ipynb ├── FetchActivationsDemo.ipynb ├── LICENSE.txt ├── README.md ├── circuitsvis │ ├── __init__.py │ ├── activations.py │ ├── attention.py │ ├── examples.py │ ├── logits.py │ ├── tests │ │ ├── __init__.py │ │ ├── snapshots │ │ │ ├── __init__.py │ │ │ ├── snap_test_activations.py │ │ │ ├── snap_test_attention.py │ │ │ ├── snap_test_hello.py │ │ │ ├── snap_test_render.py │ │ │ ├── snap_test_tokens.py │ │ │ ├── snap_test_topk_samples.py │ │ │ └── snap_test_topk_tokens.py │ │ ├── test_activations.py │ │ ├── test_attention.py │ │ ├── test_hello.py │ │ ├── test_init.py │ │ ├── test_tokens.py │ │ ├── test_topk_samples.py │ │ └── test_topk_tokens.py │ ├── tokens.py │ ├── topk_samples.py │ ├── topk_tokens.py │ └── utils │ │ ├── build_js.py │ │ ├── convert_props.py │ │ ├── render.py │ │ └── tests │ │ ├── __init__.py │ │ ├── snapshots │ │ ├── __init__.py │ │ └── snap_test_render.py │ │ ├── test_convert_props.py │ │ └── test_render.py ├── mypy.ini ├── poetry.lock ├── pyproject.toml └── setup.cfg └── react ├── .eslintignore ├── .eslintrc.js ├── .prettierrc ├── .storybook ├── main.js └── preview.js ├── README.md ├── esbuild.js ├── package.json ├── src ├── activations │ ├── TextNeuronActivations.stories.tsx │ ├── TextNeuronActivations.tsx │ ├── mocks │ │ └── textNeuronActivations.ts │ └── tests │ │ └── TestNeuronActivations.test.tsx ├── attention │ ├── AttentionHeads.stories.tsx │ ├── AttentionHeads.tsx │ ├── AttentionPattern.stories.tsx │ ├── AttentionPattern.tsx │ ├── AttentionPatterns.tsx │ ├── components │ │ ├── AttentionImage.tsx │ │ ├── AttentionTokens.tsx │ │ ├── tests │ │ │ └── useHoverLock.test.tsx │ │ └── useHoverLock.tsx │ ├── mocks │ │ └── attention.ts │ └── tests │ │ ├── 
AttentionPatterns.test.tsx │ │ └── __snapshots__ │ │ └── AttentionPatterns.test.tsx.snap ├── examples │ ├── Hello.stories.tsx │ └── Hello.tsx ├── index.ts ├── logits │ ├── TokenLogProbs.stories.tsx │ ├── TokenLogProbs.tsx │ └── mocks │ │ └── mockTokenLogProbs.ts ├── render-helper.ts ├── shared │ ├── NumberSelector.tsx │ ├── RangeSelector.tsx │ ├── SampleItems.tsx │ └── tests │ │ └── rangeStrArrConversion.test.ts ├── tokens │ ├── ColoredTokens.stories.tsx │ ├── ColoredTokens.tsx │ ├── ColoredTokensCustomTooltips.tsx │ ├── ColoredTokensMulti.stories.tsx │ ├── ColoredTokensMulti.tsx │ ├── mocks │ │ ├── coloredTokens.ts │ │ └── coloredTokensMulti.ts │ └── utils │ │ ├── Token.tsx │ │ └── TokenCustomTooltip.tsx ├── topk │ ├── TopkSamples.stories.tsx │ ├── TopkSamples.tsx │ ├── TopkTokens.stories.tsx │ ├── TopkTokens.tsx │ └── mocks │ │ ├── topkSamples.ts │ │ └── topkTokens.ts └── utils │ ├── arrayOps.ts │ ├── getTokenBackgroundColor.ts │ └── tests │ ├── arrayOps.test.ts │ └── getTokenBackgroundColor.test.ts ├── tsconfig.build.json ├── tsconfig.json └── yarn.lock /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.245.2/containers/python-3/.devcontainer/base.Dockerfile 2 | 3 | # [Choice] Python version (use -bullseye variants on local arm64/Apple Silicon): 3, 3.10, 3.9, 3.8, 3.7, 3.6, 3-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3.7-bullseye, 3.6-bullseye, 3-buster, 3.10-buster, 3.9-buster, 3.8-buster, 3.7-buster, 3.6-buster 4 | ARG VARIANT="3.10-bullseye" 5 | FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} 6 | 7 | # [Choice] Node.js version: none, lts/*, 16, 14, 12, 10 8 | ARG NODE_VERSION="none" 9 | RUN if [ "${NODE_VERSION}" != "none" ]; then su vscode -c "umask 0002 && . 
/usr/local/share/nvm/nvm.sh && nvm install ${NODE_VERSION} 2>&1"; fi 10 | 11 | # Install pip dependencies 12 | RUN pip3 --disable-pip-version-check --no-cache-dir install \ 13 | autopep8 \ 14 | jupyterlab 15 | 16 | # Poetry for Python package management 17 | USER vscode 18 | RUN curl -sSL https://install.python-poetry.org | python3 - 19 | ENV PATH "/home/vscode/.poetry/bin:$PATH" 20 | 21 | # [Optional] Uncomment this section to install additional OS packages. 22 | # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ 23 | # && apt-get -y install --no-install-recommends 24 | 25 | # [Optional] Uncomment this line to install global node packages. 26 | # RUN su vscode -c "source /usr/local/share/nvm/nvm.sh && npm install -g 27 | # " 2>&1 28 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.245.2/containers/python-3 3 | { 4 | "name": "Python 3", 5 | "build": { 6 | "dockerfile": "Dockerfile", 7 | "context": "..", 8 | "args": { 9 | // Update 'VARIANT' to pick a Python version: 3, 3.10, 3.9, 3.8, 3.7, 3.6 10 | // Append -bullseye or -buster to pin to an OS version. 11 | // Use -bullseye variants on local on arm64/Apple Silicon. 12 | "VARIANT": "3.10-bullseye", 13 | // Options 14 | "NODE_VERSION": "lts/*" 15 | } 16 | }, 17 | // Configure tool-specific properties. 18 | "customizations": { 19 | // Configure properties specific to VS Code. 20 | "vscode": { 21 | // Set *default* container specific settings.json values on container create. 22 | "settings": { 23 | "mypy.dmypyExecutable": "dmypy" 24 | }, 25 | // Add the IDs of extensions you want installed when the container is created. 
26 | "extensions": [ 27 | "2gua.rainbow-brackets", 28 | "christian-kohler.npm-intellisense", 29 | "christian-kohler.path-intellisense", 30 | "davidanson.vscode-markdownlint", 31 | "dbaeumer.vscode-eslint", 32 | "donjayamanne.githistory", 33 | "donjayamanne.python-extension-pack", 34 | "eg2.vscode-npm-script", 35 | "esbenp.prettier-vscode", 36 | "github.copilot", 37 | "github.vscode-pull-request-github", 38 | "ionutvmi.path-autocomplete", 39 | "mikoz.autoflake-extension", 40 | "ms-python.isort", 41 | "ms-python.pylint", 42 | "ms-python.python", 43 | "ms-python.vscode-pylance", 44 | "ms-toolsai.jupyter-keymap", 45 | "ms-toolsai.jupyter-renderers", 46 | "ms-toolsai.jupyter", 47 | "ms-vsliveshare.vsliveshare-pack", 48 | "njpwerner.autodocstring", 49 | "redhat.vscode-yaml", 50 | "richie5um2.vscode-sort-json", 51 | "rvest.vs-code-prettier-eslint", 52 | "stkb.rewrap", 53 | "streetsidesoftware.code-spell-checker-british-english", 54 | "streetsidesoftware.code-spell-checker", 55 | "tushortz.python-extended-snippets", 56 | "yzhang.markdown-all-in-one", 57 | "matangover.mypy", 58 | "github.vscode-github-actions" 59 | ] 60 | } 61 | }, 62 | // Run commands after the container is created: 63 | "postCreateCommand": "cd python && poetry config virtualenvs.in-project true && poetry self add 'poethepoet[poetry_plugin]' && poetry install --with dev && cd ../react && yarn", 64 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 
65 | "remoteUser": "vscode", 66 | "hostRequirements": { 67 | "memory": "6gb" 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths-ignore: 8 | - '.devcontainer/**' 9 | - '.github/**' 10 | - '.vscode/**' 11 | - '.gitignore' 12 | - 'README.md' 13 | pull_request: 14 | branches: 15 | - main 16 | paths-ignore: 17 | - '.devcontainer/**' 18 | - '.github/**' 19 | - '.vscode/**' 20 | - '.gitignore' 21 | - 'README.md' 22 | # Allow this workflow to be called from other workflows 23 | workflow_call: 24 | inputs: 25 | # Requires at least one input to be valid, but in practice we don't need any 26 | dummy: 27 | type: string 28 | required: false 29 | 30 | jobs: 31 | python-checks: 32 | name: Python Checks 33 | runs-on: ubuntu-latest 34 | strategy: 35 | matrix: 36 | python-version: 37 | - "3.8" 38 | - "3.9" 39 | - "3.10" 40 | - "3.11" 41 | env: 42 | working-directory: python 43 | defaults: 44 | run: 45 | working-directory: python 46 | steps: 47 | - uses: actions/checkout@v3 48 | - name: Install Poetry 49 | uses: snok/install-poetry@v1 50 | - name: Set up Python 51 | uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | cache: 'poetry' 55 | - name: Install dependencies 56 | run: poetry install --with dev 57 | - name: Pytest 58 | run: poetry run pytest 59 | - name: Type check 60 | run: poetry run mypy circuitsvis 61 | - name: Build check 62 | run: poetry build 63 | 64 | react-checks: 65 | name: React checks 66 | runs-on: ubuntu-latest 67 | env: 68 | working-directory: react 69 | defaults: 70 | run: 71 | working-directory: react 72 | steps: 73 | - uses: actions/checkout@v3 74 | - name: Install dependencies 75 | run: yarn 76 | - name: Jest 77 | run: yarn test 78 | - name: Lint 79 | run: yarn lint 80 | - name: Check 
types 81 | run: yarn typeCheck 82 | - name: Build check 83 | run: yarn build 84 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | checks: 10 | name: Run checks workflow 11 | uses: TransformerLensOrg/CircuitsVis/.github/workflows/checks.yml@main 12 | 13 | semver-parser: 14 | name: Parse the semantic version from the release 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Parse semver string 18 | id: semver_parser 19 | uses: booxmedialtd/ws-action-parse-semver@v1.4.7 20 | with: 21 | input_string: ${{ github.event.release.tag_name }} 22 | outputs: 23 | major: "${{ steps.semver_parser.outputs.major }}" 24 | minor: "${{ steps.semver_parser.outputs.minor }}" 25 | patch: "${{ steps.semver_parser.outputs.patch }}" 26 | semver: "${{ steps.semver_parser.outputs.fullversion }}" 27 | 28 | release-react: 29 | name: Release React Node JS package to NPMJS 30 | needs: 31 | - checks 32 | - semver-parser 33 | runs-on: ubuntu-latest 34 | permissions: 35 | actions: write 36 | contents: write 37 | env: 38 | working-directory: react 39 | defaults: 40 | run: 41 | working-directory: react 42 | steps: 43 | - uses: actions/checkout@v3 44 | - name: Install dependencies 45 | run: yarn 46 | - name: Build 47 | run: yarn build 48 | - name: Set the version 49 | run: yarn version --new-version ${{needs.semver-parser.outputs.semver}} --no-git-tag-version --no-commit-hooks 50 | - name: Publish to NPMJS 51 | id: publish 52 | uses: JS-DevTools/npm-publish@v1 53 | with: 54 | token: ${{ secrets.NPM_TOKEN }} 55 | package: ${{env.working-directory}}/package.json 56 | check-version: true 57 | 58 | release-python: 59 | name: Release Python package to PyPi 60 | needs: 61 | - checks 62 | - semver-parser 63 | - release-react 64 | runs-on: ubuntu-latest 65 | defaults: 66 | 
run: 67 | working-directory: python 68 | steps: 69 | - uses: actions/checkout@v3 70 | - name: Install Poetry 71 | uses: snok/install-poetry@v1 72 | with: 73 | version: 1.4.0 74 | - name: Poetry config 75 | run: poetry self add 'poethepoet[poetry_plugin]' 76 | - name: Set up Python 77 | uses: actions/setup-python@v4 78 | with: 79 | python-version: '3.9' 80 | cache: 'poetry' 81 | - name: Install dependencies 82 | run: poetry install --with dev 83 | - name: Set the version 84 | run: poetry version ${{needs.semver-parser.outputs.semver}} 85 | - name: Build 86 | run: poetry build 87 | - name: Publish 88 | run: poetry publish 89 | env: 90 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }} 91 | 92 | publish-storybook: 93 | name: Publish docs to GitHub Pages 94 | environment: 95 | name: github-pages 96 | url: ${{ steps.deployment.outputs.page_url }} 97 | needs: 98 | - release-react 99 | runs-on: ubuntu-latest 100 | env: 101 | working-directory: react 102 | defaults: 103 | run: 104 | working-directory: react 105 | permissions: 106 | contents: read 107 | pages: write 108 | id-token: write 109 | steps: 110 | - name: Checkout 111 | uses: actions/checkout@v3 112 | - name: Setup Pages 113 | uses: actions/configure-pages@v2 114 | - uses: actions/checkout@v3 115 | - name: Install dependencies 116 | run: yarn 117 | - name: Build 118 | run: yarn build 119 | - name: Build storybook 120 | run: yarn build-storybook 121 | - name: Upload artifact 122 | uses: actions/upload-pages-artifact@v1 123 | with: 124 | # Upload entire repository 125 | path: 'react/storybook-static' 126 | - name: Deploy to GitHub Pages 127 | id: deployment 128 | uses: actions/deploy-pages@v1 129 | -------------------------------------------------------------------------------- /.github/workflows/storybook.yml: -------------------------------------------------------------------------------- 1 | name: Storybook Publish to Chromatic 2 | 3 | on: push 4 | 5 | jobs: 6 | # Note this is designed primarily for 
testing. We also publish the Storybook 7 | # in the release workflow to GitHub Pages (which is then used for docs). 8 | storybook: 9 | runs-on: ubuntu-latest 10 | defaults: 11 | run: 12 | working-directory: react 13 | steps: 14 | - uses: actions/checkout@v1 15 | - name: Install dependencies 16 | run: yarn 17 | - name: Publish Storybook to Chromatic 18 | uses: chromaui/action@v1 19 | with: 20 | projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }} 21 | token: ${{ secrets.GITHUB_TOKEN }} 22 | workingDir: react 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution/packaging 2 | *.egg-info/ 3 | .DS_Store 4 | .eggs/ 5 | build/ 6 | .build/ 7 | dist/ 8 | 9 | # Python 10 | __pycache__ 11 | venv 12 | *-checkpoint.ipynb 13 | requirements.txt 14 | 15 | # Node 16 | node_modules/ 17 | package-lock.json 18 | yarn-error.log 19 | .parcel-cache/ 20 | storybook-static/ 21 | build-storybook.log 22 | -------------------------------------------------------------------------------- /.vscode/cspell.json: -------------------------------------------------------------------------------- 1 | { 2 | "language": "en,en-GB", 3 | "words": [ 4 | "circuitsvis", 5 | "Colab", 6 | "Colord", 7 | "crossorigin", 8 | "devcontainer", 9 | "Interpretability", 10 | "ndarray", 11 | "NPMJS" 12 | ] 13 | } 14 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "2gua.rainbow-brackets", 4 | "christian-kohler.npm-intellisense", 5 | "christian-kohler.path-intellisense", 6 | "davidanson.vscode-markdownlint", 7 | "dbaeumer.vscode-eslint", 8 | "donjayamanne.githistory", 9 | "donjayamanne.python-extension-pack", 10 | "eg2.vscode-npm-script", 11 | "esbenp.prettier-vscode", 12 | "github.copilot", 13 | 
"github.vscode-pull-request-github", 14 | "ionutvmi.path-autocomplete", 15 | "mikoz.autoflake-extension", 16 | "ms-python.isort", 17 | "ms-python.pylint", 18 | "ms-python.python", 19 | "ms-python.vscode-pylance", 20 | "ms-toolsai.jupyter-keymap", 21 | "ms-toolsai.jupyter-renderers", 22 | "ms-toolsai.jupyter", 23 | "ms-vsliveshare.vsliveshare-pack", 24 | "njpwerner.autodocstring", 25 | "redhat.vscode-yaml", 26 | "richie5um2.vscode-sort-json", 27 | "rvest.vs-code-prettier-eslint", 28 | "stkb.rewrap", 29 | "streetsidesoftware.code-spell-checker-british-english", 30 | "streetsidesoftware.code-spell-checker", 31 | "tushortz.python-extended-snippets", 32 | "yzhang.markdown-all-in-one", 33 | "matangover.mypy", 34 | "github.vscode-github-actions", 35 | "mikoz.black-py" 36 | ] 37 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "rewrap.autoWrap.enabled": true, 3 | "eslint.workingDirectories": [ 4 | { 5 | "mode": "auto" 6 | } 7 | ], 8 | "eslint.options": { 9 | "ignorePath": "../.gitignore" 10 | }, 11 | "editor.codeActionsOnSave": { 12 | "source.fixAll.eslint": "explicit", 13 | "source.organizeImports": "explicit" 14 | }, 15 | "eslint.validate": [ 16 | "javascript", 17 | "typescript" 18 | ], 19 | "rewrap.reformat": false, 20 | "editor.formatOnSave": true, 21 | "pylint.args": [ 22 | "--rcfile=./python/.pylintrc", 23 | "--generated-members=numpy.* ,torch.* ,cv2.* , cv.*" 24 | ], 25 | "python.testing.pytestArgs": [ 26 | "python/circuitsvis", 27 | ], 28 | "python.testing.pytestEnabled": true, 29 | "rewrap.wrappingColumn": 100, 30 | "pylint.importStrategy": "fromEnvironment", 31 | "notebook.formatOnCellExecution": true, 32 | "notebook.formatOnSave.enabled": true, 33 | "mypy.configFile": "python/mypy.ini", 34 | "mypy.targets": [ 35 | "python/circuitsvis", 36 | "python/circuitsvis/" 37 | ], 38 | "[python]": { 39 | 
"editor.defaultFormatter": "ms-python.black-formatter" 40 | }, 41 | } -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alan Cooney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CircuitsVis 2 | 3 | [![Release](https://github.com/alan-cooney/CircuitsVis/actions/workflows/release.yml/badge.svg)](https://github.com/alan-cooney/CircuitsVis/actions/workflows/release.yml) 4 | [![NPMJS](https://img.shields.io/npm/v/circuitsvis)](https://www.npmjs.com/package/circuitsvis) 5 | [![Pypi](https://img.shields.io/pypi/v/circuitsvis)](https://pypi.org/project/circuitsvis/) 6 | 7 | Mechanistic Interpretability visualizations that work in both Python (e.g. with 8 | [Jupyter Lab](https://jupyter.org/)) and JavaScript (e.g. [React](https://reactjs.org/) or plain HTML). 9 | 10 | View them all at [https://transformerlensorg.github.io/CircuitsVis](https://transformerlensorg.github.io/CircuitsVis) 11 | 12 | ## Use 13 | 14 | ### Install 15 | 16 | #### Python 17 | 18 | ```bash 19 | pip install circuitsvis 20 | ``` 21 | 22 | #### React 23 | 24 | ```bash 25 | yarn add circuitsvis 26 | ``` 27 | 28 | ### Add visualizations 29 | 30 | You can use any of the components from the [demo 31 | page](https://transformerlensorg.github.io/CircuitsVis). These show the source code for 32 | use with React, and for Python you can instead import the function with the same 33 | name. 34 | 35 | ```Python 36 | # Python Example 37 | from circuitsvis.tokens import colored_tokens 38 | colored_tokens(["My", "tokens"], [0.123, -0.226]) 39 | ``` 40 | 41 | ```TypeScript 42 | // React Example 43 | import ColoredTokens from "circuitsvis"; 44 | 45 | function Example() { 46 | 50 | } 51 | ``` 52 | 53 | ## Contribute 54 | 55 | ### Development requirements 56 | 57 | #### DevContainer 58 | 59 | For a one-click setup of your development environment, this project includes a 60 | DevContainer.
It can be used locally with [VS 61 | Code](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) 62 | or with [GitHub Codespaces](https://github.com/features/codespaces). 63 | 64 | #### Manual setup 65 | 66 | To create new visualizations you need [Node](https://nodejs.org/en/) (including 67 | [yarn](https://classic.yarnpkg.com/lang/en/docs/install/#mac-stable)) and Python 68 | (with [Poetry](https://python-poetry.org/)). 69 | 70 | Once you have these, you need to install both the Node & Python packages (note 71 | that for Python we use the 72 | [Poetry](https://python-poetry.org/docs/#installation) package management 73 | system). 74 | 75 | ```bash 76 | cd react && yarn 77 | ``` 78 | 79 | ```bash 80 | cd python && poetry install --with dev 81 | ``` 82 | 83 | #### Jupyter install 84 | 85 | If you want Jupyter as well, run `poetry install --with jupyter` or, if this 86 | fails due to a PyTorch bug on M1 MacBooks, run `poetry run pip install jupyter`. 87 | 88 | ### Creating visualizations 89 | 90 | #### React 91 | 92 | You'll first want to create the visualisation in React. To do this, you can copy 93 | the example from `/react/src/examples/Hello.tsx`. To view changes whilst editing 94 | this (in [Storybook](https://storybook.js.org/)), 95 | run the following from the `/react/` directory: 96 | 97 | ```bash 98 | yarn storybook 99 | ``` 100 | 101 | #### Python 102 | 103 | This project uses [Poetry](https://python-poetry.org/docs/#installation) for 104 | package management. To install run: 105 | 106 | ```bash 107 | poetry install 108 | ``` 109 | 110 | Once you've created your visualization in React, you can then create a short 111 | function in the Python library to render it. You can see an example in 112 | `/python/circuitsvis/examples.py`. 113 | 114 | Note that **this example will render from the CDN**, unless development mode is 115 | specified.
Your visualization will only be available on the CDN once it has been 116 | released to the latest production version of this library. 117 | 118 | #### Publishing a new release 119 | 120 | When a new GitHub release is created, the codebase will be automatically built 121 | and deployed to [PyPI](https://pypi.org/project/circuitsvis/). 122 | 123 | ### Citation 124 | 125 | Please cite this library as: 126 | 127 | ```BibTeX 128 | @misc{cooney2023circuitsvis, 129 | title = {CircuitsVis}, 130 | author = {Alan Cooney and Neel Nanda}, 131 | year = {2023}, 132 | howpublished = {\url{https://github.com/TransformerLensOrg/CircuitsVis}}, 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /python/.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | # Ignore test files 4 | ignore-patterns=(.)*_test\.py,test_(.)*\.py,(.)*snap_test_(.)*\.py 5 | 6 | disable=import-error, too-many-arguments, bare-except 7 | -------------------------------------------------------------------------------- /python/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alan Cooney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Circuits Vis 2 | 3 | Mechanistic Interpretability visualizations. 4 | 5 | ## Testing 6 | 7 | `pytest` -------------------------------------------------------------------------------- /python/circuitsvis/__init__.py: -------------------------------------------------------------------------------- 1 | """CircuitsVis""" 2 | from importlib_metadata import version 3 | import circuitsvis.activations 4 | import circuitsvis.attention 5 | import circuitsvis.examples 6 | import circuitsvis.tokens 7 | import circuitsvis.topk_samples 8 | import circuitsvis.topk_tokens 9 | import circuitsvis.logits 10 | 11 | __version__ = version("circuitsvis") 12 | 13 | __all__ = [ 14 | "activations", 15 | "attention", 16 | "examples", 17 | "tokens", 18 | "topk_samples", 19 | "topk_tokens", 20 | ] 21 | -------------------------------------------------------------------------------- /python/circuitsvis/activations.py: -------------------------------------------------------------------------------- 1 | """Activations visualizations""" 2 | from typing import List, Union, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from circuitsvis.utils.render import RenderedHTML, render 7 | 8 | 9 | def text_neuron_activations( 10 | tokens: Union[List[str], List[List[str]]], 11 | activations: Union[np.ndarray, 
torch.Tensor, List[np.ndarray], List[torch.Tensor]], 12 | first_dimension_name: Optional[str] = "Layer", 13 | second_dimension_name: Optional[str] = "Neuron", 14 | first_dimension_labels: Optional[List[str]] = None, 15 | second_dimension_labels: Optional[List[str]] = None, 16 | ) -> RenderedHTML: 17 | """Show activations (colored by intensity) for each token in a text or set 18 | of texts. 19 | 20 | Includes drop-downs for layer and neuron numbers. 21 | 22 | Args: 23 | tokens: List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`) 24 | activations: Activations of the shape [tokens x layers x neurons] if 25 | single sample or list of [tokens x layers x neurons] if multiple samples 26 | 27 | Returns: 28 | Html: Text neuron activations visualization 29 | """ 30 | # Verify that activations and tokens have the right shape and convert to 31 | # nested lists 32 | if isinstance(activations, (np.ndarray, torch.Tensor)): 33 | assert ( 34 | activations.ndim == 3 35 | ), "activations must be of shape [tokens x layers x neurons]" 36 | activations_list = activations.tolist() 37 | elif isinstance(activations, list): 38 | activations_list = [] 39 | for act in activations: 40 | assert ( 41 | act.ndim == 3 42 | ), "activations must be of shape [tokens x layers x neurons]" 43 | activations_list.append(act.tolist()) 44 | else: 45 | raise TypeError( 46 | f"activations must be of type np.ndarray, torch.Tensor, or list, not {type(activations)}" 47 | ) 48 | 49 | return render( 50 | "TextNeuronActivations", 51 | tokens=tokens, 52 | activations=activations_list, 53 | firstDimensionName=first_dimension_name, 54 | secondDimensionName=second_dimension_name, 55 | firstDimensionLabels=first_dimension_labels, 56 | secondDimensionLabels=second_dimension_labels, 57 | ) 58 | -------------------------------------------------------------------------------- /python/circuitsvis/attention.py: 
"""Attention visualisations"""
from typing import List, Optional, Union

import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render


def attention_heads(
    attention: Union[list, np.ndarray, torch.Tensor],
    tokens: List[str],
    attention_head_names: Optional[List[str]] = None,
    max_value: Optional[float] = None,
    min_value: Optional[float] = None,
    negative_color: Optional[str] = None,
    positive_color: Optional[str] = None,
    mask_upper_tri: Optional[bool] = None,
) -> RenderedHTML:
    """Attention Heads

    Attention patterns from destination to source tokens, for a group of heads.

    Displays a small heatmap for each attention head. When one is selected, it
    is then shown in full size.

    Args:
        attention: Attention head activations. Each head's pattern is of the
            shape [dest_tokens x src_tokens]; since this visualizes a group of
            heads, the full input is presumably [num_heads x dest_tokens x
            src_tokens] (matching the deprecated `attention_patterns`) — TODO
            confirm against the AttentionHeads React component.
        tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length
            as the list of values.
        attention_head_names: Optional list of names, one per attention head,
            used to label the individual heads in the visualization.
        max_value: Maximum value. Used to determine how dark the token color is
            when positive (i.e. based on how close it is to the maximum value).
        min_value: Minimum value. Used to determine how dark the token color is
            when negative (i.e. based on how close it is to the minimum value).
        negative_color: Color for negative values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the default
            here.
        positive_color: Color for positive values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the default
            here.
        mask_upper_tri: Whether or not to mask the upper triangular portion of
            the attention patterns. Should be true for causal attention, false
            for bidirectional attention.

    Returns:
        Html: Attention pattern visualization
    """
    # Props are camelCased to match the React component's prop names. None
    # values are stripped out by the props converter before rendering.
    kwargs = {
        "attention": attention,
        "attentionHeadNames": attention_head_names,
        "maxValue": max_value,
        "minValue": min_value,
        "negativeColor": negative_color,
        "positiveColor": positive_color,
        "tokens": tokens,
        "maskUpperTri": mask_upper_tri,
    }

    return render(
        "AttentionHeads",
        **kwargs
    )


def attention_patterns(
    tokens: List[str],
    attention: Union[list, np.ndarray, torch.Tensor],
) -> RenderedHTML:
    """Attention Patterns

    Visualization of attention head patterns.

    @deprecated Use `attention_heads` instead.

    Args:
        tokens: List of tokens (e.g. `["A", "person"]`)
        attention: Attention tensor of the shape [num_heads x dest_tokens x
            src_tokens]

    Returns:
        Html: Attention patterns visualization
    """
    return render(
        "AttentionPatterns",
        tokens=tokens,
        attention=attention,
    )


def attention_pattern(
    tokens: List[str],
    attention: Union[list, np.ndarray, torch.Tensor],
    max_value: Optional[float] = None,
    min_value: Optional[float] = None,
    negative_color: Optional[str] = None,
    show_axis_labels: Optional[bool] = None,
    positive_color: Optional[str] = None,
    mask_upper_tri: Optional[bool] = None,
) -> RenderedHTML:
    """Attention Pattern

    Attention pattern from destination to source tokens. Displays a heatmap of
    attention values (hover to see the specific values).

    Args:
        tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length
            as the list of values.
        attention: Attention head activations of the shape [dest_tokens x
            src_tokens]
        max_value: Maximum value. Used to determine how dark the token color is
            when positive (i.e. based on how close it is to the maximum value).
        min_value: Minimum value. Used to determine how dark the token color is
            when negative (i.e. based on how close it is to the minimum value).
        negative_color: Color for negative values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the default
            here.
        show_axis_labels: Whether to show axis labels.
        positive_color: Color for positive values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the default
            here.
        mask_upper_tri: Whether or not to mask the upper triangular portion of
            the attention patterns. Should be true for causal attention, false
            for bidirectional attention.

    Returns:
        Html: Attention pattern visualization
    """
    # Props are camelCased to match the React component's prop names.
    kwargs = {
        "tokens": tokens,
        "attention": attention,
        "minValue": min_value,
        "maxValue": max_value,
        "negativeColor": negative_color,
        "positiveColor": positive_color,
        "showAxisLabels": show_axis_labels,
        "maskUpperTri": mask_upper_tri,
    }

    return render(
        "AttentionPattern",
        **kwargs
    )


# --- python/circuitsvis/examples.py ---
"""Examples"""


def hello(
    name: str,
) -> RenderedHTML:
    """Hello example

    Args:
        name: Name to say hello to

    Returns:
        Html: Hello example
    """
    return render(
        "Hello",
        name=name,
    )
"""Log Prob visualization"""
from typing import Callable, List, Union

import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render

# Type aliases for array-like arguments of various ranks.
ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor]
ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor]
ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor]
IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor]


def token_log_probs(
    token_indices: torch.Tensor,
    log_probs: torch.Tensor,
    to_string: Callable[[int], str],
    top_k: int = 10,
) -> RenderedHTML:
    """
    Takes the log probs for a model on some text. Outputs the tokens coloured by
    the log prob, and on hover shows you the top K tokens that the model guessed
    for that position, and where the true token ranked in that.

    The intended use case is to help debug and explore a model's outputs.

    Args:
        token_indices: Tensor of token indices (ie integers) of shape [N,].
            Assumed to begin with a Beginning of Sequence (BOS) token, which is
            not shown in the visualization. A dummy batch dimension ([1, N]) is
            also accepted and squeezed away.
        log_probs: Log Probabilities for predicting the next token. Tensor of
            shape [N, d_vocab] (or [1, N, d_vocab] with a dummy batch dim).
        to_string: A function mapping tokens (as integers) to their string value
        top_k: How many logits to show

    Returns:
        Html: Log prob visualization
    """
    # Strip a dummy batch dimension off either input if present.
    if len(token_indices.shape) == 2:
        token_indices = token_indices.squeeze(0)
    if len(log_probs.shape) == 3:
        log_probs = log_probs.squeeze(0)

    assert len(
        log_probs.shape) == 2, f"Log Probs shape must be 2D: {log_probs.shape}"
    assert len(
        token_indices.shape) == 1, f"Tokens shape must be 1D: {token_indices.shape}"
    assert token_indices.size(0) == log_probs.size(
        0), f"Number of tokens and log prob vectors must be identical, {log_probs.shape}, {token_indices.shape}"

    # The final position has no known next token, so its log prob vector is
    # dropped.
    log_probs = log_probs[:-1]

    # Full prompt as strings (including the BOS token; the component hides it).
    prompt = [to_string(index.item()) for index in token_indices]

    # Vocab indices ordered from most to least likely, per position.
    sorted_log_prob_indices = log_probs.sort(dim=-1, descending=True).indices

    # Top K log probs and their vocab indices per position; shapes are [N, K].
    top_k_log_probs, top_k_indices = log_probs.topk(top_k, dim=-1)

    # String form of each of the top K tokens, per position.
    top_k_tokens = [
        [to_string(token) for token in position_top_k]
        for position_top_k in top_k_indices.tolist()
    ]

    # Rank of the correct next token at each position. `.nonzero()` on the 2D
    # boolean mask returns an [X, 2] array of (position, rank) index pairs; we
    # keep only the rank (column 1). The BOS token is never a prediction
    # target, hence token_indices[1:].
    rank_matches = sorted_log_prob_indices == token_indices[1:, None]
    correct_token_rank = rank_matches.nonzero()[:, 1]
    assert len(correct_token_rank) == (len(token_indices) -
                                       1), "Some token indices were missing from sorted_log_prob_indices"

    # Log prob assigned to the correct next token at each position. The
    # [:, None] / squeeze dance is just to satisfy gather's shape rules.
    correct_token_log_prob = log_probs.gather(
        index=token_indices[1:, None], dim=-1).squeeze(1)

    return render(
        "TokenLogProbs",
        prompt=prompt,
        topKLogProbs=top_k_log_probs,
        topKTokens=top_k_tokens,
        correctTokenRank=correct_token_rank,
        correctTokenLogProb=correct_token_log_prob,
    )
snapshots['TestTextNeuronActivations.test_multi_matches_snapshot 1'] = '''
11 | ''' 19 | 20 | snapshots['TestTextNeuronActivations.test_single_matches_snapshot 1'] = '''
21 | ''' 29 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestAttention.test_matches_snapshot 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_hello.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestHello.test_matches_snapshot 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_render.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestRenderProd.test_example_element 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_tokens.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestTokens.test_matches_snapshot 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_topk_samples.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestTopkSamples.test_matches_snapshot 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/snapshots/snap_test_topk_tokens.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestTopkTokens.test_matches_snapshot 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/test_activations.py: -------------------------------------------------------------------------------- 1 | from circuitsvis.activations import text_neuron_activations 2 | import circuitsvis.utils.render 3 | import numpy as np 4 | 5 | 6 | class TestTextNeuronActivations: 7 | def test_single_matches_snapshot(self, snapshot, monkeypatch): 8 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 9 | 10 | res = text_neuron_activations( 11 | tokens=["a", "b"], 12 | activations=np.array( 13 | [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]] 14 | ), # [tokens (2) x layers (2) x neurons(3)] 15 | ) 16 | snapshot.assert_match(str(res)) 17 | 18 | def test_multi_matches_snapshot(self, snapshot, monkeypatch): 19 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 20 | 21 | res = text_neuron_activations( 22 | tokens=[["a", "b"], ["c", "d", "e"]], 23 | activations=[ 24 | np.array( 25 | [[[0, 1, 0], [0, 1, 1]], [[0, 1, 1], [1, 1, 1]]] 26 | ), # [tokens (2) x layers (2) x neurons(3)] 27 | np.array( 28 | [ 29 | [[0, 1, 0], [0, 1, 1]], 30 | [[0, 1, 1], [1, 1, 1]], 31 | [[0, 1, 1], [1, 1, 1]], 32 | ] 33 | ), # [tokens (3) x layers (2) x neurons(3)] 34 | ], 35 | ) 36 | snapshot.assert_match(str(res)) 37 | -------------------------------------------------------------------------------- /python/circuitsvis/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import circuitsvis 2 | import circuitsvis.utils.render 3 | import numpy as np 4 | from circuitsvis.attention import attention_patterns 5 | 6 | 7 | class TestAttention: 8 | def test_matches_snapshot(self, snapshot, monkeypatch): 9 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 10 | 11 | res = attention_patterns( 12 | tokens=["a", "b"], 13 | attention=np.array([[[0, 1], [0, 1]]]) 14 | ) 15 | 
# --- python/circuitsvis/tests/test_hello.py ---
import circuitsvis
import circuitsvis.utils.render
from circuitsvis.examples import hello


class TestHello:
    def test_matches_snapshot(self, snapshot, monkeypatch):
        # Pin the generated element id so the rendered HTML is deterministic.
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")

        rendered = hello(name="Bob")
        snapshot.assert_match(str(rendered))


# --- python/circuitsvis/tests/test_init.py ---
def test_version():
    """This also checks that the module initializes without errors"""
    assert type(circuitsvis.__version__) == str


def test_all():
    """Test that __all__ contains only names that are actually exported."""
    for name in circuitsvis.__all__:
        assert hasattr(circuitsvis, name)


# --- python/circuitsvis/tests/test_tokens.py ---
from circuitsvis.tokens import colored_tokens


class TestTokens:
    def test_matches_snapshot(self, snapshot, monkeypatch):
        # Pin the generated element id so the rendered HTML is deterministic.
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")

        rendered = colored_tokens(tokens=["a", "b"], values=[1, 2])
        snapshot.assert_match(str(rendered))


# --- python/circuitsvis/tests/test_topk_samples.py ---
from circuitsvis.topk_samples import topk_samples


class TestTopkSamples:
    def test_matches_snapshot(self, snapshot, monkeypatch):
        # Monkeypatch uuid4 to always return the same uuid
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
        # Samples for one layer: (n_neurons (2), samples (2), tokens (varied))
        tokens = [
            [
                ["And", " here"],
                ["This", " is", " another"],
            ],
            [
                ["Another", " example"],
                ["Weee", " is", " another"],
            ],
        ]
        # Matching activations, same nesting as `tokens`.
        activations = [
            [
                [0.2, 1],
                [1, 0.0, 0],
            ],
            [
                [0, 1],
                [0.5, 1, 1],
            ],
        ]
        rendered = topk_samples(tokens=[tokens], activations=[activations])
        snapshot.assert_match(str(rendered))


# --- python/circuitsvis/tests/test_topk_tokens.py ---
import numpy as np
from circuitsvis.topk_tokens import topk_tokens


class TestTopkTokens:
    def test_matches_snapshot(self, snapshot, monkeypatch):
        # Monkeypatch uuid4 to always return the same uuid
        monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock")
        rendered = topk_tokens(
            tokens=[["a", "b", "c", "d", "e"], ["f", "g", "h"]],
            # Each activation array is (n_layers, n_tokens, n_neurons).
            activations=[
                np.arange(30).reshape(2, 5, 3),
                np.arange(12).reshape(2, 3, 2),
            ],
        )
        snapshot.assert_match(str(rendered))
"""Tokens Visualizations"""
from typing import List, Optional, Union

import numpy as np
import torch
from circuitsvis.utils.render import RenderedHTML, render

# Type aliases for array-like arguments of various ranks.
ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor]
ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor]
ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor]
IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor]


def colored_tokens(
    tokens: List[str],
    values: Union[List[float], np.ndarray, torch.Tensor],
    min_value: Optional[float] = None,
    max_value: Optional[float] = None,
    negative_color: Optional[str] = None,
    positive_color: Optional[str] = None,
) -> RenderedHTML:
    """Show tokens (colored by values) for each token in some text

    Args:
        tokens: List of tokens (e.g. `["A", "person"]`)
        values: Values of the same length as the tokens
        min_value: Minimum value to use for the color scale
        max_value: Maximum value to use for the color scale
        negative_color: Color for negative values. This can be any valid CSS
            color string.
        positive_color: Color for positive values. This can be any valid CSS
            color string.

    Returns:
        Html: Colored tokens visualization
    """
    # Props are camelCased to match the React component's prop names. None
    # values are dropped by the props converter.
    kwargs = {
        "tokens": tokens,
        "values": values,
        "minValue": min_value,
        "maxValue": max_value,
        "negativeColor": negative_color,
        "positiveColor": positive_color,
    }

    return render(
        "ColoredTokens",
        **kwargs
    )


def colored_tokens_multi(
    tokens: List[str],
    values: torch.Tensor,
    labels: Optional[List[str]] = None,
) -> RenderedHTML:
    """Shows a sequence of tokens colored by their value.

    Takes in a tensor of values of shape [S, K] (S tokens, K different types of
    value).

    The user can hover or click on a button for each of the K types to color the
    token with those values.

    The user can hover over a token to see a list of the K values for that
    token.

    Args:
        tokens: List of string tokens, one for each token in the prompt.
            Length [S]
        values: The tensor of values to color tokens by. Shape [S, K]
        labels: The names of the values. Length [K].

    Returns:
        Html: Colored tokens multi visualization
    """
    assert len(tokens) == values.size(0), \
        f"Number of tokens ({len(tokens)}) must equal first dimension of values tensor, " + \
        f"shape {values.shape}"
    if labels:
        assert len(labels) == values.size(1), \
            f"Number of labels ({len(labels)}) must equal second dimension of values tensor, " + \
            f"shape {values.shape}"

    return render(
        "ColoredTokensMulti",
        tokens=tokens,
        values=values,
        labels=labels,
    )


def visualize_model_performance(
    tokens: torch.Tensor,
    str_tokens: List[str],
    logits: torch.Tensor,
):
    """Visualizes model performance on some text

    Shows logits, log probs, and probabilities for predicting each token (from
    the previous tokens), colors the tokens according to one of logits, log
    probs and probabilities, according to user input.

    Allows the user to enter custom bounds for the values (eg, saturate color of
    probability at 0.01)

    Args:
        tokens: Tensor of token indices, shape [N] (or [1, N] with a dummy
            batch dimension).
        str_tokens: String form of each token, same length as `tokens`.
        logits: Logits for predicting the next token, shape [N, d_vocab] (or
            [1, N, d_vocab] with a dummy batch dimension).

    Returns:
        Html: Colored tokens multi visualization (one value column each for
        logits, log probs and probs of the correct next token)
    """
    # Strip dummy batch dimensions, if present.
    if len(tokens.shape) == 2:
        assert tokens.shape[0] == 1, \
            f"tokens must be rank 1, or rank 2 with a dummy batch dimension. Shape: {tokens.shape}"
        tokens = tokens[0]
    if len(logits.shape) == 3:
        assert logits.shape[0] == 1, \
            f"logits must be rank 2, or rank 3 with a dummy batch dimension. Shape: {logits.shape}"
        logits = logits[0]
    assert len(str_tokens) == len(tokens), \
        "Must have same number of tokens and str_tokens"
    assert len(tokens) == logits.shape[0], \
        "Must have the same number of tokens and logit vectors"

    # We remove the final vector of logits, as it can't predict anything.
    logits = logits[:-1]
    log_probs = logits.log_softmax(dim=-1)
    probs = logits.softmax(dim=-1)
    # For each position, gather the value assigned to the actual next token
    # (tokens[1:]), giving one [N-1] column per metric.
    values = torch.stack([
        logits.gather(-1, tokens[1:, None])[:, 0],
        log_probs.gather(-1, tokens[1:, None])[:, 0],
        probs.gather(-1, tokens[1:, None])[:, 0],
    ], dim=1)
    labels = ["logits", "log_probs", "probs"]
    return colored_tokens_multi(str_tokens[1:], values, labels)


# --- python/circuitsvis/topk_samples.py ---
"""Activations visualizations"""


def topk_samples(
    tokens: List[List[List[List[str]]]],
    activations: List[List[List[List[float]]]],
    zeroth_dimension_name: Optional[str] = "Layer",
    first_dimension_name: Optional[str] = "Neuron",
    zeroth_dimension_labels: Optional[List[str]] = None,
    first_dimension_labels: Optional[List[str]] = None,
) -> RenderedHTML:
    """List of samples in descending order of max token activation value for the
    selected layer and neuron (or whatever other dimension names are specified).

    Args:
        tokens: List of tokens of shape [layers x neurons x samples x tokens]
        activations: Activations of shape [layers x neurons x samples x tokens]
        zeroth_dimension_name: Zeroth dimension to display (e.g. "Layer")
        first_dimension_name: First dimension to display (e.g. "Neuron")
        zeroth_dimension_labels: Optional list of labels for each value in the
            zeroth dimension
        first_dimension_labels: Optional list of labels for each value in the
            first dimension

    Returns:
        Html: TopkSamples visualization
    """
    return render(
        "TopkSamples",
        tokens=tokens,
        activations=activations,
        zerothDimensionName=zeroth_dimension_name,
        firstDimensionName=first_dimension_name,
        zerothDimensionLabels=zeroth_dimension_labels,
        firstDimensionLabels=first_dimension_labels,
    )


# --- python/circuitsvis/topk_tokens.py ---
def topk_tokens(
    tokens: List[List[str]],
    activations: List[np.ndarray],  # np.ndarray: [n_layers, n_tokens, n_neurons]
    max_k: int = 10,
    first_dimension_name: str = "Layer",
    third_dimension_name: str = "Neuron",
    sample_labels: Optional[List[str]] = None,
    first_dimension_labels: Optional[List[str]] = None,
) -> RenderedHTML:
    """Show a table of the topk and bottomk activations.

    The columns correspond to the given third_dimension_name.

    Includes drop-downs for all dimensions as well as options to choose the
    number of columns to show.

    Note that we can't set labels for the third dimension because the
    visualisation uses pagination on this dimension.

    Args:
        tokens: Nested list of tokens for each sample (e.g. `[["A", "person"],
            ["He" "ran"]]`)
        activations: List with the same length as tokens (indicating the number
            of samples) containing activations of the shape [n_layers x
            n_tokens x n_neurons]
        max_k: Maximum number of top and bottom activations to show. This
            prevents us sending too much data to the react component.
        first_dimension_name: Name of the first dimension (e.g. "Layer")
        third_dimension_name: Name of the third dimension (e.g. "Neuron")
        sample_labels: Optional list of labels for each sample
        first_dimension_labels: Optional list of labels for each value in the
            first dimension

    Returns:
        Html: Topk activations visualization
    """

    assert len(tokens) == len(activations), "tokens and activations must be same length"
    assert all(
        act.ndim == 3 for act in activations
    ), "activations must be of the form [n_layers, n_tokens, n_neurons]"

    topk_vals = []
    topk_idxs = []
    bottomk_vals = []
    bottomk_idxs = []
    for sample_acts in activations:
        # get topk tokens for each object
        sample_acts_tensor = torch.from_numpy(sample_acts)
        # Set max_k to the min of the number of tokens and the max_k
        k = min(sample_acts_tensor.shape[1], max_k)

        sample_topk_vals, sample_topk_idxs = sample_acts_tensor.topk(k=k, dim=1)
        # also get bottom k vals
        sample_bottomk_vals, sample_bottomk_idxs = sample_acts_tensor.topk(
            k=k, dim=1, largest=False
        )

        # reverse sort order of bottomk vals and idxs
        sample_bottomk_vals = sample_bottomk_vals.flip(dims=(1,))
        sample_bottomk_idxs = sample_bottomk_idxs.flip(dims=(1,))

        topk_vals.append(sample_topk_vals.tolist())
        topk_idxs.append(sample_topk_idxs.tolist())
        bottomk_vals.append(sample_bottomk_vals.tolist())
        bottomk_idxs.append(sample_bottomk_idxs.tolist())

    return render(
        "TopkTokens",
        tokens=tokens,
        topkVals=topk_vals,
        topkIdxs=topk_idxs,
        bottomkVals=bottomk_vals,
        bottomkIdxs=bottomk_idxs,
        firstDimensionName=first_dimension_name,
        thirdDimensionName=third_dimension_name,
        sampleLabels=sample_labels,
        firstDimensionLabels=first_dimension_labels,
    )
"""Build helpers for creating a distributable package"""


def build() -> None:
    """Bundle the JavaScript/TypeScript source files for a Python package
    release.

    The render utilities are imported lazily so that importing this module has
    no side effects and no hard dependency at import time.
    """
    from circuitsvis.utils.render import install_if_necessary, bundle_source

    # Install node modules if they are missing
    install_if_necessary()

    # Bundle the JS/TS source files
    bundle_source()


# --- python/circuitsvis/utils/convert_props.py ---
"""Convert Python to JavaScript safe props"""
import json
from typing import Dict, Union

import numpy as np
import torch

# Types that may appear as prop values on the Python side.
PythonProperty = Union[
    dict,
    list,
    None,
    np.ndarray,
    torch.Tensor,
    bool,
    float,
    int,
    str,
]

# JSON-serializable equivalents (what JavaScript will receive).
JavaScriptProperty = Union[
    bool,  # Boolean
    dict,  # Object
    float,  # Number
    int,  # Number
    list,  # Array
    str,  # String
    None,  # Undefined
]


def convert_prop_type(prop: PythonProperty) -> JavaScriptProperty:
    """Convert property to a JavaScript supported type

    For example, JavaScript doesn't support numpy arrays or torch tensors, so we
    convert them to lists (which JavaScript will recognize as an array).

    Containers (lists, tuples and dicts) are converted recursively, so e.g. a
    list of numpy arrays also becomes JSON-serializable. Values that are already
    JSON-safe are returned unchanged.

    Args:
        prop: The property to convert

    Returns:
        JavaScriptProperty: JavaScript safe property
    """
    if isinstance(prop, torch.Tensor):
        return prop.tolist()
    if isinstance(prop, np.ndarray):
        return prop.tolist()
    if isinstance(prop, np.generic):
        # Numpy scalars (e.g. np.float32(1.0)) are not JSON serializable
        return prop.item()
    if isinstance(prop, (list, tuple)):
        # Recurse so nested arrays/tensors are also converted
        return [convert_prop_type(item) for item in prop]
    if isinstance(prop, dict):
        # Recurse into dict values (keys are assumed to be strings already)
        return {key: convert_prop_type(value) for key, value in prop.items()}

    return prop


def convert_props(props: Dict[str, PythonProperty]) -> str:
    """Convert a set of properties to a JavaScript safe string

    Properties set to None are dropped entirely (they would render as undefined
    on the JavaScript side anyway).

    Args:
        props: The properties to convert

    Returns:
        str: JavaScript safe properties (JSON encoded)
    """
    props_with_values = {k: v for k, v in props.items() if v is not None}

    return json.dumps({k: convert_prop_type(v) for k, v in props_with_values.items()})
Jupyter Lab) 39 | """ 40 | 41 | def __init__(self, local_src: str, cdn_src: str): 42 | self.local_src = local_src 43 | self.cdn_src = cdn_src 44 | 45 | def _repr_html_(self) -> str: 46 | """Jupyter/Colab HTML Representation 47 | 48 | When Jupyter sees this method, it renders the HTML. 49 | 50 | Returns: 51 | str: HTML for Jupyter/Colab 52 | """ 53 | # Use local source if we're in dev mode 54 | if is_in_dev_mode(): 55 | return self.local_src 56 | 57 | # Use local source if we're offline 58 | if not internet_on(): 59 | return self.local_src 60 | 61 | # Otherwise use the CDN 62 | return self.cdn_src 63 | 64 | def __html__(self) -> str: 65 | """Used by some tooling as an alternative to _repr_html_""" 66 | return self._repr_html_() 67 | 68 | def show_code(self) -> str: 69 | """Show the code as HTML source code 70 | 71 | This loads JavaScript from the CDN, so it will not work offline. 72 | 73 | Returns: 74 | str: HTML source code (with JavaScript from CDN) 75 | """ 76 | return self.cdn_src 77 | 78 | def __str__(self): 79 | """String type conversion handler 80 | 81 | Returns: 82 | str: HTML source code (with JavaScript from CDN) 83 | """ 84 | return self.cdn_src 85 | 86 | 87 | def install_if_necessary() -> None: 88 | """Install node modules if they're missing.""" 89 | node_modules = REACT_DIR / "node_modules" 90 | if not node_modules.exists(): 91 | subprocess.run( 92 | ["yarn"], 93 | cwd=REACT_DIR, 94 | capture_output=True, 95 | text=True, 96 | check=True 97 | ) 98 | 99 | 100 | def bundle_source(dev_mode: bool = True) -> None: 101 | """Bundle up the JavaScript/TypeScript source files 102 | 103 | Bundles the files together and then also copies them to the Python dist/ 104 | directory. 
This allows the Python package to also include these files when 105 | it is installed.""" 106 | # Build 107 | build_command = [ 108 | "yarn", 109 | "buildBrowser", 110 | ] 111 | 112 | if dev_mode: 113 | build_command.append("--dev") 114 | 115 | subprocess.run(build_command, 116 | cwd=REACT_DIR, 117 | capture_output=True, 118 | text=True, 119 | check=True 120 | ) 121 | 122 | # Copy files to python dist directory (overwriting any existing files) 123 | react_dist = REACT_DIR / "dist" 124 | python_dist = Path(__file__).parent.parent / "dist" 125 | if os.path.exists(python_dist): 126 | # Python 3.7 doesn't support the exist_ok argument, so we have to delete 127 | # the destination directory first 128 | shutil.rmtree(python_dist) 129 | shutil.copytree(react_dist, python_dist) 130 | 131 | 132 | def render_local(react_element_name: str, **kwargs) -> str: 133 | """Render (using local JavaScript files)""" 134 | # Create a random ID for the div (that we render into) 135 | # This is done to avoid name clashes on a page with many rendered 136 | # CircuitsVis elements. Note we shorten the UUID to be a reasonable length 137 | uuid = "circuits-vis-" + str(uuid4())[:13] 138 | 139 | # Stringify keyword args 140 | props = convert_props(kwargs) 141 | 142 | # Build if in dev mode 143 | if is_in_dev_mode(): 144 | install_if_necessary() 145 | bundle_source() 146 | 147 | # Load the JS 148 | filename = Path(__file__).parent.parent / "dist" / "cdn" / "iife.js" 149 | with open(filename, encoding="utf-8") as file: 150 | inline_js = file.read() 151 | # Remove any closing script tags (as this breaks inline code) 152 | inline_js = inline_js.replace("", "") 153 | 154 | html = f"""
155 | """ 164 | 165 | return html 166 | 167 | 168 | def render_cdn(react_element_name: str, **kwargs: PythonProperty) -> str: 169 | """Render (from the CDN) 170 | 171 | Args: 172 | react_element_name (str): Name of the React element to render 173 | 174 | Returns: 175 | RenderedHTML: HTML for the visualization 176 | """ 177 | # Create a random ID for the div (that we render into) 178 | # This is done to avoid name clashes on a page with many rendered 179 | # CircuitsVis elements. Note we shorten the UUID to be a reasonable length 180 | uuid = "circuits-vis-" + str(uuid4())[:13] 181 | 182 | # Stringify keyword args 183 | props = convert_props(kwargs) 184 | 185 | html = f"""
186 | """ 194 | 195 | return html 196 | 197 | 198 | def render( 199 | react_element_name: str, 200 | **kwargs: PythonProperty 201 | ) -> RenderedHTML: 202 | """Render a visualization to HTML 203 | 204 | This will show the visualization in Jupyter Lab/Colab by default, and show a 205 | string representation of the code otherwise (or if you wrap in `str()`). 206 | 207 | Args: 208 | react_element_name (str): Visualization element name from React codebase 209 | use this if directly developing this library). Defaults to False. 210 | 211 | Returns: 212 | Html: HTML for the visualization 213 | """ 214 | local_src = render_local(react_element_name, **kwargs) 215 | cdn_src = render_cdn(react_element_name, **kwargs) 216 | return RenderedHTML(local_src, cdn_src) 217 | -------------------------------------------------------------------------------- /python/circuitsvis/utils/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/CircuitsVis/f8c1c81c8d6ab1f7882587e2c9ca188bb95c0f58/python/circuitsvis/utils/tests/__init__.py -------------------------------------------------------------------------------- /python/circuitsvis/utils/tests/snapshots/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TransformerLensOrg/CircuitsVis/f8c1c81c8d6ab1f7882587e2c9ca188bb95c0f58/python/circuitsvis/utils/tests/snapshots/__init__.py -------------------------------------------------------------------------------- /python/circuitsvis/utils/tests/snapshots/snap_test_render.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # snapshottest: v1 - https://goo.gl/zC4yUc 3 | from __future__ import unicode_literals 4 | 5 | from snapshottest import Snapshot 6 | 7 | 8 | snapshots = Snapshot() 9 | 10 | snapshots['TestRenderDev.test_example_element 1'] = '''
11 | ''' 19 | -------------------------------------------------------------------------------- /python/circuitsvis/utils/tests/test_convert_props.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from circuitsvis.utils.convert_props import convert_prop_type 3 | import numpy as np 4 | 5 | 6 | class TestConvertPropType: 7 | 8 | def test_dict(self): 9 | res = convert_prop_type({"a": 1}) 10 | assert res == {"a": 1} 11 | 12 | def test_list(self): 13 | res = convert_prop_type([1, 2, 3]) 14 | assert res == [1, 2, 3] 15 | 16 | def test_none(self): 17 | res = convert_prop_type(None) 18 | assert res == None 19 | 20 | def test_ndarray(self): 21 | res = convert_prop_type(np.array([1, 2, 3])) 22 | assert res == [1, 2, 3] 23 | 24 | def test_tensor(self): 25 | res = convert_prop_type(torch.Tensor([1, 2, 3])) 26 | assert res == [1, 2, 3] 27 | 28 | def test_bool(self): 29 | res = convert_prop_type(True) 30 | assert res is True 31 | 32 | def test_float(self): 33 | res = convert_prop_type(1.0) 34 | assert res == 1.0 35 | 36 | def test_int(self): 37 | res = convert_prop_type(1) 38 | assert res == 1 39 | 40 | def test_str(self): 41 | res = convert_prop_type("hello") 42 | assert res == "hello" 43 | -------------------------------------------------------------------------------- /python/circuitsvis/utils/tests/test_render.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from urllib import request 3 | 4 | import circuitsvis.utils.render 5 | from circuitsvis.utils.render import (RenderedHTML, bundle_source, 6 | install_if_necessary, internet_on, 7 | is_in_dev_mode, render, render_cdn, 8 | render_local) 9 | 10 | 11 | class TestIsInDevMode: 12 | def test_is_in_dev_mode(self): 13 | assert is_in_dev_mode() 14 | 15 | def test_is_not_in_dev_mode(self): 16 | does_not_exist_dir = Path(__file__) / "does_not_exist" 17 | assert not is_in_dev_mode(does_not_exist_dir) 18 | 19 | 
20 | class TestInternetOn: 21 | def test_internet_on(self, monkeypatch): 22 | def mock_urlopen(url, timeout): 23 | return True 24 | monkeypatch.setattr(request, "urlopen", mock_urlopen) 25 | assert internet_on() 26 | 27 | def test_internet_off(self, monkeypatch): 28 | def mock_urlopen(url, timeout): 29 | raise Exception() 30 | monkeypatch.setattr(request, "urlopen", mock_urlopen) 31 | assert not internet_on() 32 | 33 | 34 | class TestRenderedHTML: 35 | def test_jupyter_renders(self): 36 | src = "

Hi

" 37 | html = RenderedHTML(src, src) 38 | 39 | # Check the _repr_html_ method is defined (as Jupyter Lab displays this) 40 | assert html._repr_html_() == src 41 | 42 | def test_show_code(self): 43 | src = "

Hi

" 44 | html = RenderedHTML(src, src) 45 | assert html.show_code() == src 46 | 47 | def test_string(self): 48 | src = "

Hi

" 49 | html = RenderedHTML(src, src) 50 | assert str(html) == src 51 | 52 | 53 | class TestInstallIfNecessary: 54 | def test_runs_without_errors(self): 55 | install_if_necessary() 56 | 57 | 58 | class TestBundleSource: 59 | def test_runs_without_errors(self): 60 | bundle_source() 61 | # Run twice, to check it doesn't fail if the directory already exists 62 | bundle_source() 63 | 64 | 65 | class TestRenderLocal: 66 | def runs_without_error(self): 67 | render_local("Hello", name="Bob") 68 | 69 | 70 | class TestRenderDev: 71 | def test_example_element(self, snapshot, monkeypatch): 72 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 73 | monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") 74 | 75 | res = render_cdn("Hello", name="Bob") 76 | snapshot.assert_match(str(res)) 77 | 78 | 79 | class TestRender: 80 | def test_stringified_render_is_from_cdn(self, monkeypatch): 81 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 82 | monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") 83 | 84 | prod = render_cdn("Hello", name="Bob") 85 | res = render("Hello", name="Bob") 86 | assert str(res) == str(prod) 87 | 88 | def test_jupyter_verson_is_from_local(self, monkeypatch): 89 | monkeypatch.setattr(circuitsvis.utils.render, "uuid4", lambda: "mock") 90 | monkeypatch.setattr(circuitsvis, "__version__", "1.0.0") 91 | 92 | dev = render_local("Hello", name="Bob") 93 | res = render("Hello", name="Bob") 94 | assert res._repr_html_() == dev 95 | -------------------------------------------------------------------------------- /python/mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "circuitsvis" 3 | version = "0.0.0" # This is 
updated by the CI/CD pipeline 4 | description = "Mechanistic Interpretability Visualizations" 5 | authors = ["Alan Cooney <41682961+alan-cooney@users.noreply.github.com>"] 6 | license = "MIT" 7 | readme = "README.md" 8 | include = ["circuitsvis/dist/cdn/iife.js", "circuitsvis/dist/cdn/iife.js.map"] 9 | 10 | [tool.poetry.dependencies] 11 | importlib-metadata = ">=5.1.0" 12 | numpy = [{ version = ">=1.20,<1.25", python = ">=3.8,<3.9" }, 13 | { version = ">=1.24", python = ">=3.9,<3.12" }, 14 | { version = ">=1.26", python = ">=3.12,<3.13" }] 15 | python = ">=3.8" 16 | torch = ">=2.1.1" 17 | 18 | [tool.poetry.group.dev.dependencies] 19 | autopep8 = ">=2.0" 20 | mypy = ">=0.990" 21 | poethepoet = ">=0.16.5" 22 | pytest = ">=7.2" 23 | snapshottest = ">=0.6" 24 | twine = ">=4.0.1" 25 | 26 | [tool.poetry.group.jupyter.dependencies] 27 | jupyterlab = ">=3.5" 28 | 29 | [build-system] 30 | requires = ["poetry-core"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | [tool.poe.poetry_hooks] 34 | pre_build = "bundle-js" 35 | 36 | [tool.poe.tasks.bundle-js] 37 | script = "circuitsvis.utils.build_js:build" 38 | help = "Bundle up the latest version of the react library as a single script file." 
39 | -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | long_description = file: README.md 3 | long_description_content_type = text/markdown 4 | -------------------------------------------------------------------------------- /react/.eslintignore: -------------------------------------------------------------------------------- 1 | !.*.js 2 | node_modules/ 3 | dist/ 4 | -------------------------------------------------------------------------------- /react/.eslintrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | env: { 3 | browser: true, 4 | es6: true, 5 | node: true, 6 | jest: true 7 | }, 8 | extends: [ 9 | "airbnb-base", 10 | "airbnb-typescript", 11 | "plugin:prettier/recommended", 12 | "plugin:react/recommended", 13 | "plugin:react-hooks/recommended", 14 | "plugin:testing-library/react", 15 | "plugin:jest-dom/recommended", 16 | "plugin:storybook/recommended" 17 | ], 18 | parser: "@typescript-eslint/parser", 19 | parserOptions: { 20 | project: ["./tsconfig.json"], 21 | ecmaVersion: 2018, 22 | ecmaFeatures: { 23 | jsx: true 24 | }, 25 | sourceType: "module" 26 | }, 27 | plugins: [ 28 | "@typescript-eslint", 29 | "jest-dom", 30 | "jsx-a11y", 31 | "react", 32 | "testing-library" 33 | ], 34 | rules: { 35 | "import/prefer-default-export": "off", 36 | "import/no-extraneous-dependencies": [ 37 | "error", 38 | { devDependencies: ["**/*.test.*", "**/*.stories.*"] } 39 | ], 40 | "no-plusplus": ["error", { allowForLoopAfterthoughts: true }], 41 | "react/jsx-filename-extension": "off", 42 | "react/require-default-props": "off", // Not needed with TS 43 | "react/react-in-jsx-scope": "off" // Esbuild injects this for us 44 | }, 45 | settings: { 46 | react: { 47 | version: "detect" 48 | } 49 | } 50 | }; 51 | 
-------------------------------------------------------------------------------- /react/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "trailingComma": "none" 3 | } 4 | -------------------------------------------------------------------------------- /react/.storybook/main.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | "stories": [ 3 | "../src/**/*.stories.mdx", 4 | "../src/**/*.stories.@(js|jsx|ts|tsx)" 5 | ], 6 | "addons": [ 7 | "@storybook/addon-links", 8 | "@storybook/addon-essentials", 9 | "@storybook/addon-interactions" 10 | ], 11 | "framework": "@storybook/react", 12 | core: { 13 | builder: 'webpack5', 14 | }, 15 | } 16 | -------------------------------------------------------------------------------- /react/.storybook/preview.js: -------------------------------------------------------------------------------- 1 | export const parameters = { 2 | actions: { argTypesRegex: "^on[A-Z].*" }, 3 | controls: { 4 | matchers: { 5 | color: /(background|color)$/i, 6 | date: /Date$/, 7 | }, 8 | }, 9 | viewMode: "docs", 10 | } 11 | -------------------------------------------------------------------------------- /react/README.md: -------------------------------------------------------------------------------- 1 | # CircuitsVis 2 | 3 | Mechanistic Interpretability visualizations in React. 4 | 5 | View all available components in Storybook at 6 | https://TransformerLensOrg.github.io/CircuitsVis . 7 | 8 | ## Use 9 | 10 | ### Within a React Project 11 | 12 | First install the package: 13 | 14 | ```shell 15 | yarn add circuitsvis 16 | ``` 17 | 18 | Then import and use the visualizations directly: 19 | 20 | ```tsx 21 | import { Hello } from "circuitsvis"; 22 | 23 | export function Demo() { 24 | return ; 25 | } 26 | ``` 27 | 28 | ### Standalone 29 | 30 | You can use this package directly from a CDN (e.g. unpkg) to render visualizations. 
31 | 32 | #### Modern ES Modules Approach 33 | 34 | ```html 35 |
36 | 37 | 46 | ``` 47 | 48 | #### ES6 Approach (supports more legacy browsers) 49 | 50 | ```html 51 |
52 | 53 | 58 | 59 | 72 | ``` 73 | 74 | ### Within a Python project 75 | 76 | See https://github.com/TransformerLensOrg/CircuitsVis for details of how to use this 77 | library within a Python project. 78 | -------------------------------------------------------------------------------- /react/esbuild.js: -------------------------------------------------------------------------------- 1 | /* eslint-disable import/no-extraneous-dependencies */ 2 | const esbuild = require("esbuild"); 3 | const { externalGlobalPlugin } = require("esbuild-plugin-external-global"); 4 | const yargs = require("yargs"); 5 | 6 | /** 7 | * Command line arguments 8 | */ 9 | const commandLineArgs = yargs.option("dev", { 10 | alias: "D", 11 | describe: "Development mode (only builds IIFE version, for speed)", 12 | type: "boolean", 13 | default: false 14 | }).argv; 15 | 16 | async function buildAll() { 17 | /** 18 | * CDN IIFE Build (with React) 19 | */ 20 | await esbuild.build({ 21 | entryPoints: ["src/index.ts"], 22 | outfile: "dist/cdn/iife.js", 23 | bundle: true, 24 | target: "es6", 25 | minify: true, 26 | legalComments: "none", 27 | sourcemap: true, 28 | globalName: "CircuitsVis" // Components available as e.g. `CircuitsVis.Hello` 29 | }); 30 | 31 | if (!commandLineArgs.dev) { 32 | /** 33 | * CDN ESM Build (with React) 34 | */ 35 | await esbuild.build({ 36 | entryPoints: ["src/index.ts"], 37 | outfile: "dist/cdn/esm.js", 38 | bundle: true, 39 | target: "es2020", 40 | format: "esm", 41 | minify: true, 42 | platform: "browser", 43 | legalComments: "none", 44 | sourcemap: true 45 | }); 46 | 47 | /** 48 | * CDN IIFE Build (without React) 49 | * 50 | * This allows the user to import and run alongside their own browser import of 51 | * React (whichever version that may be). 
52 | */ 53 | await esbuild.build({ 54 | entryPoints: ["src/index.ts"], 55 | outfile: "dist/cdn/without-react.iife.js", 56 | bundle: true, 57 | target: "es6", 58 | minify: true, 59 | legalComments: "none", 60 | sourcemap: true, 61 | globalName: "CircuitsVis", // Components available as e.g. `CircuitsVis.Hello` 62 | plugins: [ 63 | // Exclude React/ReactDom from the browser 64 | externalGlobalPlugin({ 65 | react: "window.React", 66 | "react-dom": "window.ReactDOM" 67 | }) 68 | ] 69 | }); 70 | 71 | /** 72 | * CommonJS Version 73 | */ 74 | await esbuild.build({ 75 | entryPoints: ["src/index.ts"], 76 | outdir: "dist/commonjs/", 77 | external: ["./node_modules/*"], 78 | sourcemap: true, 79 | bundle: true, 80 | target: ["node12"], 81 | platform: "node", 82 | format: "cjs" 83 | }); 84 | } 85 | } 86 | 87 | buildAll(); 88 | -------------------------------------------------------------------------------- /react/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "circuitsvis", 3 | "description": "Mechanistic Interpretability Visualizations", 4 | "version": "0.0.0", 5 | "main": "dist/commonjs/index.js", 6 | "module": "dist/module/index.js", 7 | "unpkg": "dist/cdn/iife.js", 8 | "jsdelivr": "dist/cdn/iife.js", 9 | "source": "src/index.ts", 10 | "types": "dist/module/index.d.ts", 11 | "sideEffects": false, 12 | "files": [ 13 | "dist" 14 | ], 15 | "exports": { 16 | "./package.json": "./package.json", 17 | ".": { 18 | "types": "./dist/index.d.ts", 19 | "import": "./dist/index.esm.mjs", 20 | "require": "./dist/index.cjs.js" 21 | } 22 | }, 23 | "scripts": { 24 | "test": "jest", 25 | "lint": "eslint src", 26 | "buildBrowser": "node ./esbuild.js", 27 | "buildNode": "tsc --project tsconfig.build.json", 28 | "build": "yarn buildBrowser && yarn buildNode", 29 | "typeCheck": "tsc --noEmit", 30 | "storybook": "start-storybook -p 6006", 31 | "build-storybook": "build-storybook" 32 | }, 33 | "repository": 
"https://github.com/TransformerLensOrg/CircuitsVis", 34 | "author": "Alan Cooney", 35 | "license": "MIT", 36 | "dependencies": { 37 | "@tensorflow/tfjs": "^4.1.0", 38 | "chart.js": "^4.0.1", 39 | "chartjs-chart-matrix": "^1.3.0", 40 | "colord": "^2.9.3", 41 | "react-chartjs-2": "^5.0.1", 42 | "react-grid-system": "^8.1.6", 43 | "react-popper-tooltip": "^4.4.2", 44 | "tinycolor2": "^1.4.2" 45 | }, 46 | "devDependencies": { 47 | "@storybook/addon-actions": "^6.5.14", 48 | "@storybook/addon-essentials": "^6.5.14", 49 | "@storybook/addon-interactions": "^6.5.14", 50 | "@storybook/addon-links": "^6.5.14", 51 | "@storybook/builder-webpack5": "^6.5.14", 52 | "@storybook/manager-webpack5": "^6.5.14", 53 | "@storybook/preset-typescript": "^3.0.0", 54 | "@storybook/react": "^6.5.14", 55 | "@storybook/testing-library": "^0.0.13", 56 | "@tensorflow/tfjs-node": "^4.1.0", 57 | "@testing-library/jest-dom": "5.16.5", 58 | "@testing-library/react": "^13.4.0", 59 | "@testing-library/user-event": "14.4.3", 60 | "@types/jest": "29.2.4", 61 | "@types/node": "^18.11.11", 62 | "@types/react": "^18.0.26", 63 | "@types/testing-library__jest-dom": "^5.14.5", 64 | "@types/tinycolor2": "^1.4.3", 65 | "@typescript-eslint/eslint-plugin": "5.45.1", 66 | "@typescript-eslint/parser": "^5.45.1", 67 | "chromatic": "^6.12.0", 68 | "esbuild": "^0.16.1", 69 | "esbuild-plugin-external-global": "^1.0.1", 70 | "eslint": "8.29.0", 71 | "eslint-config-airbnb-base": "15.0.0", 72 | "eslint-config-airbnb-typescript": "17.0.0", 73 | "eslint-config-prettier": "8.5.0", 74 | "eslint-import-resolver-typescript": "3.5.2", 75 | "eslint-plugin-import": "2.26.0", 76 | "eslint-plugin-jest": "27.1.6", 77 | "eslint-plugin-jest-dom": "4.0.3", 78 | "eslint-plugin-jsx-a11y": "6.6.1", 79 | "eslint-plugin-prettier": "4.2.1", 80 | "eslint-plugin-react": "7.31.11", 81 | "eslint-plugin-react-hooks": "4.6.0", 82 | "eslint-plugin-storybook": "^0.6.8", 83 | "eslint-plugin-testing-library": "5.9.1", 84 | "jest": "^29.3.1", 85 | 
"jest-canvas-mock": "^2.4.0", 86 | "jest-environment-jsdom": "29.3.1", 87 | "prettier": "^2.8.1", 88 | "react": "^18.2.0", 89 | "react-dom": "^18.2.0", 90 | "testing-library": "^0.0.2", 91 | "ts-jest": "^29.0.3", 92 | "typescript": "^4.9.3" 93 | }, 94 | "peerDependencies": { 95 | "react": "^16.8.0 || ^17 || ^18" 96 | }, 97 | "jest": { 98 | "testEnvironment": "jsdom", 99 | "transform": { 100 | "^.+\\.tsx?$": [ 101 | "ts-jest" 102 | ] 103 | }, 104 | "setupFiles": [ 105 | "jest-canvas-mock" 106 | ] 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /react/src/activations/TextNeuronActivations.stories.tsx: -------------------------------------------------------------------------------- 1 | import { ComponentStory, ComponentMeta } from "@storybook/react"; 2 | import React from "react"; 3 | import { 4 | mockActivations, 5 | mockTokens, 6 | neuronLabels 7 | } from "./mocks/textNeuronActivations"; 8 | import { TextNeuronActivations } from "./TextNeuronActivations"; 9 | 10 | export default { 11 | component: TextNeuronActivations 12 | } as ComponentMeta; 13 | 14 | const Template: ComponentStory = (args) => ( 15 | 16 | ); 17 | 18 | export const MultipleSamples = Template.bind({}); 19 | MultipleSamples.args = { 20 | tokens: mockTokens, 21 | activations: mockActivations, 22 | secondDimensionLabels: neuronLabels 23 | }; 24 | 25 | export const SingleSample = Template.bind({}); 26 | SingleSample.args = { 27 | tokens: mockTokens[0], 28 | activations: mockActivations[0], 29 | secondDimensionLabels: neuronLabels 30 | }; 31 | -------------------------------------------------------------------------------- /react/src/activations/TextNeuronActivations.tsx: -------------------------------------------------------------------------------- 1 | import { Rank, tensor, Tensor1D, Tensor3D } from "@tensorflow/tfjs"; 2 | import React, { useState, useEffect } from "react"; 3 | import { Container, Row, Col } from "react-grid-system"; 4 | import { 
SampleItems } from "../shared/SampleItems"; 5 | import { RangeSelector } from "../shared/RangeSelector"; 6 | import { NumberSelector } from "../shared/NumberSelector"; 7 | import { minMaxInNestedArray } from "../utils/arrayOps"; 8 | 9 | /** 10 | * Get the selected activations 11 | * 12 | * @param activations All activations [ tokens x layers x neurons ] 13 | * @param layerNumber 14 | * @param neuronNumber 15 | */ 16 | export function getSelectedActivations( 17 | activations: Tensor3D, 18 | layerNumber: number, 19 | neuronNumber: number 20 | ): number[] { 21 | const relevantActivations = activations 22 | .slice([0, layerNumber, neuronNumber], [-1, 1, 1]) 23 | .squeeze([1, 2]); 24 | return relevantActivations.arraySync(); 25 | } 26 | 27 | /** 28 | * Show activations (colored by intensity) for each token. 29 | * 30 | * Includes drop-downs for e.g. showing the activations for the selected layer 31 | * and neuron for the given samples. 32 | */ 33 | export function TextNeuronActivations({ 34 | tokens, 35 | activations, 36 | firstDimensionName = "Layer", 37 | secondDimensionName = "Neuron", 38 | firstDimensionLabels, 39 | secondDimensionLabels 40 | }: TextNeuronActivationsProps) { 41 | // If there is only one sample (i.e. if tokens is an array of strings), cast tokens and activations to an array with 42 | // a single element 43 | const tokensList: string[][] = 44 | typeof tokens[0] === "string" 45 | ? ([tokens] as string[][]) 46 | : (tokens as string[][]); 47 | const activationsList: number[][][][] = 48 | typeof activations[0][0][0] === "number" 49 | ? 
([activations] as number[][][][]) 50 | : (activations as number[][][][]); 51 | 52 | // Obtain min and max activations for a consistent color scale across all samples 53 | const [minValue, maxValue] = minMaxInNestedArray(activationsList); 54 | 55 | // Convert the activations to a tensor 56 | const activationsTensors = activationsList.map((sampleActivations) => { 57 | return tensor(sampleActivations); 58 | }); 59 | 60 | // Get number of layers/neurons 61 | const numberOfLayers = activationsTensors[0].shape[1]; 62 | const numberOfNeurons = activationsTensors[0].shape[2]; 63 | const numberOfSamples = activationsTensors.length; 64 | 65 | const [samplesPerPage, setSamplesPerPage] = useState( 66 | Math.min(5, numberOfSamples) 67 | ); 68 | const [sampleNumbers, setSampleNumbers] = useState([ 69 | ...Array(samplesPerPage).keys() 70 | ]); 71 | const [layerNumber, setLayerNumber] = useState(0); 72 | const [neuronNumber, setNeuronNumber] = useState(0); 73 | 74 | useEffect(() => { 75 | // When the user changes the samplesPerPage, update the sampleNumbers 76 | setSampleNumbers([...Array(samplesPerPage).keys()]); 77 | }, [samplesPerPage]); 78 | 79 | // Get the relevant activations for the selected samples, layer, and neuron. 
80 | const selectedActivations: number[][] = sampleNumbers.map((sampleNumber) => { 81 | return getSelectedActivations( 82 | activationsTensors[sampleNumber], 83 | layerNumber, 84 | neuronNumber 85 | ); 86 | }); 87 | 88 | const selectedTokens: string[][] = sampleNumbers.map((sampleNumber) => { 89 | return tokensList[sampleNumber]; 90 | }); 91 | 92 | const selectRowStyle = { 93 | paddingTop: 5, 94 | paddingBottom: 5 95 | }; 96 | 97 | return ( 98 | 99 | 100 | 101 | 102 | 103 | 106 | 113 | 114 | 115 | 116 | 117 | 120 | 127 | 128 | 129 | {/* Only show the sample selector if there is more than one sample */} 130 | {numberOfSamples > 1 && ( 131 | 132 | 133 | 136 | 143 | 144 | 145 | )} 146 | 147 | 148 | {/* Only show the sample per page selector if there is more than one sample */} 149 | {numberOfSamples > 1 && ( 150 | 151 | 152 | 158 | 165 | 166 | 167 | )} 168 | 169 | 170 | 171 | 172 | 178 | 179 | 180 | 181 | ); 182 | } 183 | 184 | export interface TextNeuronActivationsProps { 185 | /** 186 | * List of lists of tokens (if multiple samples) or a list of tokens (if 187 | * single sample) 188 | * 189 | * If multiple samples, each list must be the same length as the number of activations in the 190 | * corresponding activations list. 191 | */ 192 | tokens: string[][] | string[]; 193 | 194 | /** 195 | * Activations 196 | * 197 | * If multiple samples, will be a nested list of numbers, of the form [ sample x tokens x layers x neurons 198 | * ]. If a single sample, will be a list of numbers of the form [ tokens x layers x neurons ]. 
199 | */ 200 | activations: number[][][][] | number[][][]; 201 | 202 | /** 203 | * Name of the first dimension 204 | */ 205 | firstDimensionName?: string; 206 | 207 | /** 208 | * Name of the second dimension 209 | */ 210 | secondDimensionName?: string; 211 | 212 | /** 213 | * Labels for the first dimension 214 | */ 215 | firstDimensionLabels?: string[]; 216 | 217 | /** 218 | * Labels for the second dimension 219 | */ 220 | secondDimensionLabels?: string[]; 221 | } 222 | -------------------------------------------------------------------------------- /react/src/activations/mocks/textNeuronActivations.ts: -------------------------------------------------------------------------------- 1 | const text: string = ` 2 | A goose (PL: geese) is a bird of any of several waterfowl species in the family Anatidae. This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese). Some other birds, mostly related to the shelducks, have "goose" as part of their names. More distantly related members of the family Anatidae are swans, most of which are larger than true geese, and ducks, which are smaller. 3 | 4 | The term "goose" may refer to either a male or female bird, but when paired with "gander", refers specifically to a female one (the latter referring to a male). Young birds before fledging are called goslings.[1] The collective noun for a group of geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when flying close together, they are called a plump.[2] 5 | Contents 6 | 7 | 1 Etymology 8 | 2 True geese and their relatives 9 | 3 Fossil record 10 | 4 Migratory patterns 11 | 4.1 Preparation 12 | 4.2 Navigation 13 | 4.3 Formation 14 | 5 Other birds called "geese" 15 | 6 In popular culture 16 | 6.1 "Gray Goose Laws" in Iceland 17 | 7 Gallery 18 | 8 See also 19 | 9 References 20 | 10 Further reading 21 | 11 External links 22 | 23 | Etymology 24 | 25 | The word "goose" is a direct descendant of,*ghans-. 
In Germanic languages, the root gave Old English gōs with the plural gēs and gandres (becoming Modern English goose, geese, gander, and gosling, respectively), Frisian goes, gies and guoske, New High German Gans, Gänse, and Ganter, and Old Norse gās. 26 | 27 | This term also gave Lithuanian: žąsìs, Irish: gé (goose, from Old Irish géiss), Hindi: कलहंस, Latin: anser, Spanish: ganso, Ancient Greek: χήν (khēn), Dutch: gans, Albanian: gatë swans), Finnish: hanhi, Avestan zāō, Polish: gęś, Romanian: gâscă / gânsac, Ukrainian: гуска / гусак (huska / husak), Russian: гусыня / гусь (gusyna / gus), Czech: husa, and Persian: غاز (ghāz).[1][3] 28 | True geese and their relatives 29 | Snow geese (Anser caerulescens) in Quebec, Canada 30 | Chinese geese (Anser cygnoides domesticus), the domesticated form of the swan goose (Anser cygnoides) 31 | Barnacle geese (Branta leucopsis) in Naantali, Finland 32 | 33 | The two living genera of true geese are: Anser, grey geese and white geese, such as the greylag goose and snow goose, and Branta, black geese, such as the Canada goose. 34 | 35 | Two genera of geese are only tentatively placed in the Anserinae; they may belong to the shelducks or form a subfamily on their own: Cereopsis, the Cape Barren goose, and Cnemiornis, the prehistoric New Zealand goose. Either these or, more probably, the goose-like coscoroba swan is the closest living relative of the true geese. 36 | 37 | Fossils of true geese are hard to assign to genus; all that can be said is that their fossil record, particularly in North America, is dense and comprehensively documents many different species of true geese that have been around since about 10 million years ago in the Miocene. The aptly named Anser atavus (meaning "progenitor goose") from some 12 million years ago had even more plesiomorphies in common with swans. In addition, some goose-like birds are known from subfossil remains found on the Hawaiian Islands. 
38 | 39 | Geese are monogamous, living in permanent pairs throughout the year; however, unlike most other permanently monogamous animals, they are territorial only during the short nesting season. Paired geese are more dominant and feed more, two factors that result in more young.[4][5] 40 | 41 | Geese honk while in flight to encourage other members of the flock to maintain a 'v-formation' and to help communicate with one another.[6] 42 | Fossil record 43 | 44 | Geese fossils have been found ranging from 10 to 12 million years ago (Middle Miocene). Garganornis ballmanni from Late Miocene (~ 6-9 Ma) of Gargano region of central Italy, stood one and a half meters tall and weighed about 22 kilograms. The evidence suggests the bird was flightless, unlike modern geese.[7] 45 | Migratory patterns 46 | 47 | Geese like the Canada goose do not always migrate.[8] Some members of the species only move south enough to ensure a supply of food and water. When European settlers came to America, the birds were seen as easy prey and were almost wiped out of the population. The species was reintroduced across the northern U.S. 
range and their population has been growing ever since.[9] 48 | Preparation 49 | `; 50 | 51 | function chunkText(textArr: string[]): string[][] { 52 | const chunks: string[][] = []; 53 | let i = 0; 54 | // Split textArr into 12 chunks of 75 tokens 55 | const chunkSize = 75; 56 | while (i < textArr.length) { 57 | chunks.push(textArr.slice(i, i + chunkSize)); 58 | i += chunkSize; 59 | } 60 | return chunks; 61 | } 62 | 63 | export const mockTokens: string[][] = chunkText(text.split(/(?=\s)/)); 64 | 65 | const numLayers: number = 2; 66 | const numNeurons: number = 3; 67 | function createRandom3DActivationMatrix(shape: number[]) { 68 | return Array.from(Array(shape[0]), () => 69 | Array.from(Array(shape[1]), () => 70 | Array.from(Array(shape[2]), () => Math.random()) 71 | ) 72 | ); 73 | } 74 | export const mockActivations: number[][][][] = mockTokens.map((tokens) => { 75 | return createRandom3DActivationMatrix([tokens.length, numLayers, numNeurons]); 76 | }); 77 | 78 | export const neuronLabels: string[] = ["3", "9", "42"]; 79 | -------------------------------------------------------------------------------- /react/src/activations/tests/TestNeuronActivations.test.tsx: -------------------------------------------------------------------------------- 1 | import { tensor } from "@tensorflow/tfjs-node"; 2 | import { render, screen } from "@testing-library/react"; 3 | import React from "react"; 4 | import { 5 | getSelectedActivations, 6 | TextNeuronActivations 7 | } from "../TextNeuronActivations"; 8 | import { mockActivations, mockTokens } from "../mocks/textNeuronActivations"; 9 | 10 | describe("getSelectedActivations", () => { 11 | it("gets the correct activation for a specified layer and neuron", () => { 12 | const activations: number[][][] = [ 13 | // Token 0 14 | [ 15 | // Layer 0 16 | [ 17 | 0, // Neuron 0 18 | 1 // Neuron 1 19 | ], 20 | // Layer 1 21 | [ 22 | 10, // Neuron 0 23 | 20 // Neuron 1 24 | ] 25 | ] 26 | ]; 27 | 28 | const res = 
getSelectedActivations(tensor(activations), 1, 1); 29 | 30 | expect(res).toEqual([20]); 31 | }); 32 | }); 33 | 34 | describe("TextNeuronActivations", () => { 35 | it("renders", () => { 36 | render( 37 | 41 | ); 42 | 43 | // Check it renders 44 | screen.getByText("Layer:"); 45 | }); 46 | }); 47 | -------------------------------------------------------------------------------- /react/src/attention/AttentionHeads.stories.tsx: -------------------------------------------------------------------------------- 1 | import { ComponentStory, ComponentMeta } from "@storybook/react"; 2 | import React from "react"; 3 | 4 | import { AttentionHeads } from "./AttentionHeads"; 5 | import { mockAttention, mockTokens } from "./mocks/attention"; 6 | 7 | export default { 8 | component: AttentionHeads, 9 | argTypes: { 10 | negativeColor: { control: "color" }, 11 | positiveColor: { control: "color" } 12 | } 13 | } as ComponentMeta; 14 | 15 | const Template: ComponentStory = (args) => ( 16 | 17 | ); 18 | 19 | export const InductionHeadsLayer: ComponentStory = 20 | Template.bind({}); 21 | InductionHeadsLayer.args = { 22 | tokens: mockTokens, 23 | attention: mockAttention 24 | }; 25 | -------------------------------------------------------------------------------- /react/src/attention/AttentionHeads.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { Col, Container, Row } from "react-grid-system"; 3 | import { AttentionPattern } from "./AttentionPattern"; 4 | import { useHoverLock, UseHoverLockState } from "./components/useHoverLock"; 5 | 6 | /** 7 | * Attention head color 8 | * 9 | * @param idx Head index 10 | * @param numberOfHeads Number of heads 11 | * @param alpha Opaqueness (0% = fully transparent, 100% = fully opaque) 12 | */ 13 | export function attentionHeadColor( 14 | idx: number, 15 | numberOfHeads: number, 16 | alpha: string = "100%" 17 | ): string { 18 | const hue = Math.round((idx / numberOfHeads) * 
360); 19 | 20 | return `hsla(${hue}, 70%, 50%, ${alpha})`; 21 | } 22 | 23 | /** 24 | * Attention Heads Selector 25 | */ 26 | export function AttentionHeadsSelector({ 27 | attention, 28 | attentionHeadNames, 29 | focused, 30 | maxValue, 31 | minValue, 32 | negativeColor, 33 | onClick, 34 | onMouseEnter, 35 | onMouseLeave, 36 | positiveColor, 37 | maskUpperTri, 38 | tokens 39 | }: AttentionHeadsProps & { 40 | attentionHeadNames: string[]; 41 | } & UseHoverLockState) { 42 | return ( 43 | 44 | {attention.map((headAttention, idx) => { 45 | const isFocused = focused === idx; 46 | 47 | return ( 48 | 49 |
onClick(idx)} 52 | onMouseEnter={() => onMouseEnter(idx)} 53 | onMouseLeave={onMouseLeave} 54 | > 55 |
70 |

82 | {attentionHeadNames[idx]} 83 |

84 | 85 | 95 |
96 |
97 | 98 | ); 99 | })} 100 |
101 | ); 102 | } 103 | 104 | /** 105 | * Attention patterns from destination to source tokens, for a group of heads. 106 | * 107 | * Displays a small heatmap for each attention head. When one is selected, it is 108 | * then shown in full size. 109 | */ 110 | export function AttentionHeads({ 111 | attention, 112 | attentionHeadNames, 113 | maxValue, 114 | minValue, 115 | negativeColor, 116 | positiveColor, 117 | maskUpperTri = true, 118 | tokens 119 | }: AttentionHeadsProps) { 120 | // Attention head focussed state 121 | const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0); 122 | 123 | const headNames = 124 | attentionHeadNames || attention.map((_, idx) => `Head ${idx}`); 125 | 126 | return ( 127 | 128 |

129 | Head Selector (hover to view, click to lock) 130 |

131 | 132 | 146 | 147 | 148 | 149 |

{headNames[focused]} Zoomed

150 |
151 |

163 | {headNames[focused]} 164 |

165 | 175 |
176 | 177 |
178 | 179 | 180 |
181 | ); 182 | } 183 | 184 | export interface AttentionHeadsProps { 185 | /** 186 | * Attention heads activations 187 | * 188 | * Of the shape [ heads x dest_pos x src_pos ] 189 | */ 190 | attention: number[][][]; 191 | 192 | /** 193 | * Names for each attention head 194 | * 195 | * Useful if e.g. you want to label the heads with the layer they are from. 196 | */ 197 | attentionHeadNames?: string[]; 198 | 199 | /** 200 | * Maximum value 201 | * 202 | * Used to determine how dark the token color is when positive (i.e. based on 203 | * how close it is to the maximum value). 204 | * 205 | * @default Math.max(...values) 206 | */ 207 | maxValue?: number; 208 | 209 | /** 210 | * Minimum value 211 | * 212 | * Used to determine how dark the token color is when negative (i.e. based on 213 | * how close it is to the minimum value). 214 | * 215 | * @default Math.min(...values) 216 | */ 217 | minValue?: number; 218 | 219 | /** 220 | * Negative color 221 | * 222 | * Color to use for negative values. This can be any valid CSS color string. 223 | * 224 | * Be mindful of color blindness if not using the default here. 225 | * 226 | * @default red 227 | * 228 | * @example rgb(255, 0, 0) 229 | * 230 | * @example #ff0000 231 | */ 232 | negativeColor?: string; 233 | 234 | /** 235 | * Positive color 236 | * 237 | * Color to use for positive values. This can be any valid CSS color string. 238 | * 239 | * Be mindful of color blindness if not using the default here. 240 | * 241 | * @default blue 242 | * 243 | * @example rgb(0, 0, 255) 244 | * 245 | * @example #0000ff 246 | */ 247 | positiveColor?: string; 248 | 249 | /** 250 | * Mask upper triangular 251 | * 252 | * Whether or not to mask the upper triangular portion of the attention patterns. 253 | * 254 | * Should be true for causal attention, false for bidirectional attention. 
255 | * 256 | * @default true 257 | */ 258 | maskUpperTri?: boolean; 259 | 260 | /** 261 | * Show axis labels 262 | */ 263 | showAxisLabels?: boolean; 264 | 265 | /** 266 | * List of tokens 267 | * 268 | * Must be the same length as the list of values. 269 | */ 270 | tokens: string[]; 271 | } 272 | -------------------------------------------------------------------------------- /react/src/attention/AttentionPattern.stories.tsx: -------------------------------------------------------------------------------- 1 | import { ComponentStory, ComponentMeta } from "@storybook/react"; 2 | import React from "react"; 3 | 4 | import { AttentionPattern } from "./AttentionPattern"; 5 | import { 6 | mockAttention, 7 | mockShortAttention, 8 | mockShortPrompt, 9 | mockTokens 10 | } from "./mocks/attention"; 11 | 12 | export default { 13 | component: AttentionPattern, 14 | argTypes: { 15 | negativeColor: { control: "color" }, 16 | positiveColor: { control: "color" }, 17 | tokens: { control: { type: "object", raw: true } }, 18 | values: { control: { type: "object", raw: true } } 19 | } 20 | } as ComponentMeta; 21 | 22 | const Template: ComponentStory = (args) => ( 23 | 24 | ); 25 | 26 | export const InductionHead: ComponentStory = 27 | Template.bind({}); 28 | InductionHead.args = { 29 | tokens: mockTokens, 30 | attention: mockAttention[0] 31 | }; 32 | 33 | export const SmallDummyHead: ComponentStory = 34 | Template.bind({}); 35 | SmallDummyHead.args = { 36 | tokens: mockShortPrompt, 37 | attention: mockShortAttention 38 | }; 39 | -------------------------------------------------------------------------------- /react/src/attention/AttentionPattern.tsx: -------------------------------------------------------------------------------- 1 | import React, { useMemo } from "react"; 2 | import { MatrixController, MatrixElement } from "chartjs-chart-matrix"; 3 | import { 4 | Chart as ChartJS, 5 | CategoryScale, 6 | Tooltip, 7 | ScriptableContext, 8 | TooltipItem, 9 | ChartData, 10 | 
LinearScale 11 | } from "chart.js"; 12 | import { Chart, ChartProps } from "react-chartjs-2"; 13 | import { Col, Row } from "react-grid-system"; 14 | import { colord } from "colord"; 15 | import { getTokenBackgroundColor } from "../utils/getTokenBackgroundColor"; 16 | 17 | /** 18 | * Register ChartJS plugins 19 | */ 20 | ChartJS.register( 21 | CategoryScale, 22 | Tooltip, 23 | MatrixElement, 24 | MatrixController, 25 | LinearScale 26 | ); 27 | 28 | /** 29 | * Block data point 30 | * 31 | * Contains information about a single block on the chart. 32 | */ 33 | export interface Block { 34 | /** Source token with index suffix */ 35 | x: string; 36 | /** Destination token with index suffix */ 37 | y: string; 38 | /** Attention value */ 39 | v: number; 40 | /** Source token */ 41 | srcToken: string; 42 | /** Destination token */ 43 | destToken: string; 44 | /** Source index */ 45 | srcIdx: number; 46 | /** Destination index */ 47 | destIdx: number; 48 | } 49 | 50 | const DefaultUpperTriColor = "rgb(200,200,200)"; 51 | 52 | /** 53 | * Attention pattern from destination to source tokens. Displays a heatmap of 54 | * attention values (hover to see the specific values). 
55 | */ 56 | export function AttentionPattern({ 57 | attention, 58 | maxValue = 1, 59 | minValue = -1, 60 | negativeColor, 61 | positiveColor, 62 | upperTriColor = DefaultUpperTriColor, 63 | showAxisLabels = true, 64 | zoomed = false, 65 | maskUpperTri = true, 66 | tokens 67 | }: AttentionPatternProps) { 68 | // Tokens must be unique (for the categories), so we add an index prefix 69 | const uniqueTokens = useMemo( 70 | () => tokens.map((token, idx) => `${token.replace(/\s/g, "")} (${idx})`), 71 | [tokens] 72 | ); 73 | 74 | // Memoize the chart data 75 | const chartData = useMemo(() => { 76 | return attention 77 | .map((src, destIdx) => 78 | src.map((value, srcIdx) => ({ 79 | srcIdx, 80 | destIdx, 81 | srcToken: tokens[srcIdx], 82 | destToken: tokens[destIdx], 83 | x: uniqueTokens[srcIdx], 84 | y: uniqueTokens[destIdx], 85 | v: value 86 | })) 87 | ) 88 | .flat(); 89 | }, [attention, tokens, uniqueTokens]); 90 | 91 | // Format the chart data 92 | const data: ChartData<"matrix", Block[], unknown> = { 93 | datasets: [ 94 | { 95 | // Data must be given in the form {x: xCategory, y: yCategory, v: value} 96 | data: chartData, 97 | // Set the background color for each block, based on the attention value 98 | backgroundColor(context: ScriptableContext<"matrix">) { 99 | const block = context.dataset.data[context.dataIndex] as any as Block; 100 | if (maskUpperTri && block.srcIdx > block.destIdx) { 101 | // Color the upper triangular part separately 102 | return colord(upperTriColor).toRgbString(); 103 | } 104 | const color = getTokenBackgroundColor( 105 | block.v, 106 | minValue, 107 | maxValue, 108 | negativeColor, 109 | positiveColor 110 | ); 111 | return color.toRgbString(); 112 | }, 113 | // Block size 114 | width: (ctx) => ctx.chart.chartArea.width / tokens.length, 115 | height: (ctx) => ctx.chart.chartArea.height / tokens.length 116 | } 117 | ] 118 | }; 119 | 120 | // Chart options 121 | const options: ChartProps<"matrix", Block[], unknown>["options"] = { 122 | 
animation: { 123 | duration: 0 // general animation time 124 | }, 125 | plugins: { 126 | // Tooltip (hover) options 127 | tooltip: { 128 | enabled: showAxisLabels, 129 | yAlign: "bottom", 130 | callbacks: { 131 | title: () => "", // Hide the title 132 | label({ raw }: TooltipItem<"matrix">) { 133 | const block = raw as Block; 134 | if (maskUpperTri && block.destIdx < block.srcIdx) { 135 | // Just show N/A for the upper triangular part 136 | return "N/A"; 137 | } 138 | return [ 139 | `(${block.destIdx}, ${block.srcIdx})`, 140 | `Src: ${block.srcToken}`, 141 | `Dest: ${block.destToken} `, 142 | `Val: ${block.v}` 143 | ]; 144 | } 145 | } 146 | } 147 | }, 148 | scales: { 149 | x: { 150 | title: { display: true, text: "Source Token", padding: 1 }, 151 | type: "category" as any, 152 | labels: uniqueTokens, 153 | offset: true, 154 | ticks: { display: true, minRotation: 45, maxRotation: 90 }, 155 | grid: { display: false }, 156 | display: showAxisLabels 157 | }, 158 | y: { 159 | title: { display: true, text: "Destination Token", padding: 1 }, 160 | type: "category" as any, 161 | offset: true, 162 | labels: [...uniqueTokens].reverse(), 163 | ticks: { display: true }, 164 | grid: { display: false }, 165 | display: showAxisLabels 166 | } 167 | } 168 | }; 169 | 170 | return ( 171 | 172 | 173 |
194 | 202 |
203 |
204 | 205 | ); 206 | } 207 | 208 | export interface AttentionPatternProps { 209 | /** 210 | * Attention head activations 211 | * 212 | * Of the shape [ dest_pos x src_pos ] 213 | */ 214 | attention: number[][]; 215 | 216 | /** 217 | * Maximum value 218 | * 219 | * Used to determine how dark the token color is when positive (i.e. based on 220 | * how close it is to the maximum value). 221 | * 222 | * @default Math.max(...values) 223 | */ 224 | maxValue?: number; 225 | 226 | /** 227 | * Minimum value 228 | * 229 | * Used to determine how dark the token color is when negative (i.e. based on 230 | * how close it is to the minimum value). 231 | * 232 | * @default Math.min(...values) 233 | */ 234 | minValue?: number; 235 | 236 | /** 237 | * Negative color 238 | * 239 | * Color to use for negative values. This can be any valid CSS color string. 240 | * 241 | * Be mindful of color blindness if not using the default here. 242 | * 243 | * @default red 244 | * 245 | * @example rgb(255, 0, 0) 246 | * 247 | * @example #ff0000 248 | */ 249 | negativeColor?: string; 250 | 251 | /** 252 | * Positive color 253 | * 254 | * Color to use for positive values. This can be any valid CSS color string. 255 | * 256 | * Be mindful of color blindness if not using the default here. 257 | * 258 | * @default blue 259 | * 260 | * @example rgb(0, 0, 255) 261 | * 262 | * @example #0000ff 263 | */ 264 | positiveColor?: string; 265 | 266 | /** 267 | * Mask upper triangular 268 | * 269 | * Whether or not to mask the upper triangular portion of the attention patterns. 270 | * 271 | * Should be true for causal attention, false for bidirectional attention. 272 | * 273 | * @default true 274 | */ 275 | maskUpperTri?: boolean; 276 | 277 | /** 278 | * Upper triangular color 279 | * 280 | * Color to use for the upper triangular part of the attention pattern to make visualization slightly nicer. 281 | * Only applied if maskUpperTri is set to true. 
282 | * 283 | * @default rgb(200, 200, 200) 284 | * 285 | * @example rgb(200, 200, 200) 286 | * 287 | * @example #C8C8C8 288 | */ 289 | upperTriColor?: string; 290 | 291 | /** 292 | * Show axis labels 293 | */ 294 | showAxisLabels?: boolean; 295 | 296 | /** 297 | * Is this a zoomed in view? 298 | */ 299 | zoomed?: boolean; 300 | 301 | /** 302 | * List of tokens 303 | * 304 | * Must be the same length as the list of values. 305 | */ 306 | tokens: string[]; 307 | } 308 | -------------------------------------------------------------------------------- /react/src/attention/AttentionPatterns.tsx: -------------------------------------------------------------------------------- 1 | import React, { useMemo, useState } from "react"; 2 | import { einsum, Rank, tensor, Tensor3D, Tensor4D } from "@tensorflow/tfjs"; 3 | import tinycolor from "tinycolor2"; 4 | import { AttentionImage } from "./components/AttentionImage"; 5 | import { Tokens, TokensView } from "./components/AttentionTokens"; 6 | import { useHoverLock } from "./components/useHoverLock"; 7 | 8 | /** 9 | * Color the attention values by heads 10 | * 11 | * We want attention values to be colored by each head (i.e. becoming [heads x 12 | * dest_tokens x src_tokens x rgb_color_channel]). This way, when outputting an 13 | * image of just one attention head it will be colored (by the specific hue 14 | * assigned to that attention head) rather than grayscale. 15 | * 16 | * Importantly, when outputting an image that averages 17 | * several attention heads we can then also average over the colors (so that we 18 | * can see for each destination-source token pair which head is most important). 19 | * For example, if the specific pair is very red, it suggests that the red 20 | * attention head is most important for this destination-source token combination. 21 | * 22 | * @param attentionInput Attention input as [dest_tokens x source_tokens x 23 | * heads] array (this is the format provided by the Python interface). 
24 | * 25 | * @returns Tensor of the shape [heads x dest_tokens x src_tokens x 26 | * rgb_color_channel] 27 | */ 28 | export function colorAttentionTensors(attentionInput: number[][][]): Tensor4D { 29 | // Create a TensorFlow tensor from the attention data 30 | const attentionTensor = tensor(attentionInput); // [heads x dest_tokens x source_tokens] 31 | 32 | const attention = attentionTensor.arraySync() as number[][][]; 33 | 34 | // Set the colors 35 | const colored = attention.map((head, headNumber) => 36 | head.map((destination) => 37 | destination.map((sourceAttention) => { 38 | // Color 39 | const attentionColor = tinycolor({ 40 | h: (headNumber / attention.length) * 360, // Hue (degrees 0-360) 41 | s: 0.8, // Saturation (slightly off 100% to make less glaring) 42 | l: 1 - 0.75 * sourceAttention // Luminance (shows amount of attention) 43 | }); 44 | 45 | // Return as a nested list in the format [red, green, blue] 46 | const { r, g, b } = attentionColor.toRgb(); 47 | return [r, g, b]; 48 | }) 49 | ) 50 | ); 51 | 52 | return tensor(colored); 53 | } 54 | 55 | /** 56 | * Attention Patterns 57 | * 58 | * @deprecated Use `AttentionHeads` instead. 59 | */ 60 | export function AttentionPatterns({ 61 | tokens, 62 | attention, 63 | headLabels 64 | }: { 65 | /** Array of tokens e.g. 
`["Hello", "my", "name", "is"...]` */ 66 | tokens: string[]; 67 | /** Attention input as [dest_tokens x source_tokens x heads] (JSON stringified) */ 68 | attention: number[][][]; 69 | /** Head labels */ 70 | headLabels?: string[]; 71 | }) { 72 | // State for the token view type 73 | const [tokensView, setTokensView] = useState( 74 | TokensView.DESTINATION_TO_SOURCE 75 | ); 76 | 77 | // Attention head focussed state 78 | const { 79 | focused: focusedHead, 80 | onClick: onClickHead, 81 | onMouseEnter: onMouseEnterHead, 82 | onMouseLeave: onMouseLeaveHead 83 | } = useHoverLock(); 84 | 85 | // State for which token is focussed 86 | const { 87 | focused: focussedToken, 88 | onClick: onClickToken, 89 | onMouseEnter: onMouseEnterToken, 90 | onMouseLeave: onMouseLeaveToken 91 | } = useHoverLock(); 92 | 93 | // Color the attention values (by head) 94 | const coloredAttention = useMemo( 95 | () => colorAttentionTensors(attention), 96 | [attention] 97 | ); 98 | const heads = coloredAttention.unstack(0); 99 | 100 | // Max attention color across all heads 101 | // This is helpful as we can see if, for example, only one or two colored 102 | // heads are focussing on a specific source token from a destination token. 103 | // To do this we re-arrange to put heads at the last dimension and then max 104 | // this (by color darkness, so min in terms of rgb values) 105 | const maxAttentionAcrossHeads = einsum("hdsc -> dsch", coloredAttention).min( 106 | 3 107 | ); 108 | 109 | // Get the focused head based on the state (selected/hovered) 110 | const focusedAttention = 111 | focusedHead === null ? maxAttentionAcrossHeads : heads[focusedHead]; 112 | 113 | return ( 114 |
115 |
116 |
117 |

Attention Patterns

118 | 119 |
120 | 121 |
122 |

123 | Head selector 124 | 125 | {" "} 126 | (hover to focus, click to lock) 127 | 128 |

129 |
130 | {heads.map((head, headNumber) => ( 131 |
onClickHead(headNumber)} 138 | onMouseEnter={() => onMouseEnterHead(headNumber)} 139 | onMouseLeave={onMouseLeaveHead} 140 | > 141 | 146 |
147 | {headLabels?.[headNumber] ?? `Head ${headNumber}`} 148 |
149 |
150 | ))} 151 |
152 |
153 |
154 | 155 |
156 |

157 | Tokens 158 | (click to focus) 159 |

160 | 171 |
172 | 182 |
183 |
184 |
185 | ); 186 | } 187 | -------------------------------------------------------------------------------- /react/src/attention/components/AttentionImage.tsx: -------------------------------------------------------------------------------- 1 | import React, { CSSProperties, useEffect, useRef } from "react"; 2 | import { browser, Tensor3D } from "@tensorflow/tfjs"; 3 | 4 | export interface AttentionImageProps { 5 | /** 6 | * Attention patterns (destination to source tokens), colored by attention head 7 | * 8 | * Should be [n_tokens x n_tokens x color_channels] 9 | */ 10 | coloredAttention: Tensor3D; 11 | 12 | style?: CSSProperties; 13 | 14 | /** Adds a box-shadow to the canvas when true */ 15 | isSelected?: boolean; 16 | } 17 | 18 | /** 19 | * Attention Image 20 | * 21 | * Shows the attention from destination tokens to source tokens, as a [n_tokens 22 | * x n_tokens] image. 23 | */ 24 | export function AttentionImage({ 25 | coloredAttention, 26 | style = {}, 27 | isSelected = false 28 | }: AttentionImageProps) { 29 | // Add a reference to the HTML Canvas element in the DOM, so we can update it 30 | const canvasRef = useRef(null); 31 | 32 | // Draw the attention pattern onto the HTML Canvas 33 | // Runs in `useEffect` as we need the canvas to be added to the DOM first, 34 | // before we can interact with it. 
35 | useEffect(() => { 36 | const canvas = canvasRef.current; 37 | browser.toPixels(coloredAttention.toInt(), canvas as HTMLCanvasElement); 38 | }, [coloredAttention]); 39 | 40 | return ( 41 | 59 | ); 60 | } 61 | -------------------------------------------------------------------------------- /react/src/attention/components/AttentionTokens.tsx: -------------------------------------------------------------------------------- 1 | import { einsum, Rank, Tensor, Tensor3D, Tensor4D } from "@tensorflow/tfjs"; 2 | import tinycolor from "tinycolor2"; 3 | import React from "react"; 4 | 5 | export enum TokensView { 6 | DESTINATION_TO_SOURCE = "DESTINATION_TO_SOURCE", 7 | SOURCE_TO_DESTINATION = "SOURCE_TO_DESTINATION" 8 | } 9 | 10 | /** 11 | * Get the relevant attention values to average (for an individual token) 12 | * 13 | * Used to calculate the color of a specific token block (div). 14 | * 15 | * @param maxAttentionAcrossHeads [dest_tokens x src_tokens x rgb] 16 | * @param tokenIndex Current token index 17 | * @param tokensView 18 | * @param focusedToken Selected/focused token 19 | * 20 | * @returns Relevant tokens from which to average the color [dest_tokens x src_tokens x rgb] 21 | */ 22 | export function getTokensToAverage( 23 | maxAttentionAcrossHeads: Tensor3D, 24 | tokenIndex: number, 25 | tokensView: TokensView, 26 | focusedToken?: number 27 | ): Tensor3D { 28 | // Default: If no tokens are selected, we're going to average over all source 29 | // tokens available to look at (i.e. up to this current token) 30 | // Note: End values are inclusive 31 | let destinationStart: number = tokenIndex; 32 | let destinationEnd: number = tokenIndex; 33 | let sourceStart: number = 0; 34 | let sourceEnd: number = tokenIndex; 35 | 36 | // If a token is selected (and we're showing destination -> source attention), 37 | // show the attention from the selected destination token to this token. 
38 | if ( 39 | typeof focusedToken === "number" && 40 | tokensView === TokensView.DESTINATION_TO_SOURCE 41 | ) { 42 | destinationStart = focusedToken; 43 | destinationEnd = focusedToken; 44 | sourceStart = tokenIndex; 45 | sourceEnd = tokenIndex; 46 | } 47 | 48 | // If a token is selected (but instead we're showing source -> destination), 49 | // show the attention from the selected source token to this token. 50 | else if ( 51 | typeof focusedToken === "number" && 52 | tokensView === TokensView.SOURCE_TO_DESTINATION 53 | ) { 54 | destinationStart = tokenIndex; 55 | destinationEnd = tokenIndex; 56 | sourceStart = focusedToken; 57 | sourceEnd = focusedToken; 58 | } 59 | 60 | return maxAttentionAcrossHeads.slice( 61 | [destinationStart, sourceStart], 62 | [destinationEnd + 1 - destinationStart, sourceEnd + 1 - sourceStart] 63 | ); 64 | } 65 | 66 | /** 67 | * Individual Token 68 | */ 69 | export function Token({ 70 | focusedToken, 71 | onClickToken, 72 | onMouseEnterToken, 73 | onMouseLeaveToken, 74 | maxAttentionAcrossHeads, 75 | text, 76 | tokenIndex, 77 | tokensView 78 | }: { 79 | focusedToken?: number; 80 | onClickToken: (e: number) => void; 81 | onMouseEnterToken: (e: number) => void; 82 | onMouseLeaveToken: () => void; 83 | maxAttentionAcrossHeads: Tensor3D; 84 | text: string; 85 | tokenIndex: number; 86 | tokensView: TokensView; 87 | }) { 88 | const isFocused = focusedToken !== null && focusedToken === tokenIndex; 89 | 90 | // Get the average of the colors of the source tokens that we can attend to. 
91 | const relevantTokens = getTokensToAverage( 92 | maxAttentionAcrossHeads, 93 | tokenIndex, 94 | tokensView, 95 | focusedToken 96 | ); 97 | 98 | const averageColor = relevantTokens 99 | .mean>(0) 100 | .mean>(0); 101 | const [r, g, b] = averageColor.arraySync(); 102 | const backgroundColor = tinycolor({ r, g, b }); 103 | 104 | // Set the text color to always be visible (allowing for the background color) 105 | const textColor = backgroundColor.getBrightness() < 180 ? "white" : "black"; 106 | 107 | return ( 108 | 96 | )} 97 |
98 | ); 99 | } 100 | 101 | export function Tooltip({ 102 | title, 103 | labels, 104 | values, 105 | tokenIndex, 106 | currentValueIndex 107 | }: { 108 | title: string; 109 | labels: string[]; 110 | values: Tensor2D; 111 | tokenIndex: number; 112 | currentValueIndex: number; 113 | }) { 114 | const numValues = values.shape[1]; 115 | 116 | const valueRows = []; 117 | for (let i = 0; i < numValues; i++) { 118 | valueRows.push( 119 | 120 | {labels[i]} 121 | 127 | {values.bufferSync().get(tokenIndex, i).toFixed(PRECISION)} 128 | 129 | 130 | ); 131 | } 132 | 133 | return ( 134 | <> 135 |
138 | {title} 139 |
140 | 141 | {valueRows} 142 |
143 | 144 | ); 145 | } 146 | 147 | /** 148 | * Extension of ColoredTokens to allow K vectors of values across tokens. Each 149 | * vector has a positive and negative color associated. For the selected vector, 150 | * display tokens with a background representing how negative (close to 151 | * `negativeColor`) or positive (close to `positiveColor`) the token is. Zero is 152 | * always displayed as white. 153 | * 154 | * Hover over a token, to view all K of its values. 155 | */ 156 | export function ColoredTokensMulti({ 157 | tokens, 158 | values, 159 | labels, 160 | positiveBounds, 161 | negativeBounds 162 | }: ColoredTokensMultiProps) { 163 | const valuesTensor = tensor(values); 164 | 165 | const numValues = valuesTensor.shape[1]; 166 | 167 | // Define default positive and negative bounds if not provided 168 | // These are the max/min elements of the value tensor, capped at +-1e-7 (not 169 | // zero, to avoid a bug in our color calculation code) 170 | const positiveBoundsTensor: Tensor1D = positiveBounds 171 | ? tensor(positiveBounds) 172 | : valuesTensor.max(0).maximum(1e-7); 173 | const negativeBoundsTensor: Tensor1D = negativeBounds 174 | ? 
tensor(negativeBounds) 175 | : valuesTensor.min(0).minimum(-1e-7); 176 | 177 | // Define default labels if not provided 178 | const valueLabels = 179 | labels || Array.from(Array(numValues).keys()).map((_, i) => `${i}`); 180 | 181 | const [displayedValueIndex, setDisplayedValueIndex] = useState(0); 182 | 183 | // Positive and negative bounds state 184 | const defaultPositiveBound = Number( 185 | positiveBoundsTensor.arraySync()[displayedValueIndex].toFixed(PRECISION) 186 | ); 187 | const defaultNegativeBound = Number( 188 | negativeBoundsTensor.arraySync()[displayedValueIndex].toFixed(PRECISION) 189 | ); 190 | const [positiveBound, setOverridePositiveBound] = useState( 191 | Number(defaultPositiveBound) 192 | ); 193 | const [negativeBound, setOverrideNegativeBound] = useState( 194 | Number(defaultNegativeBound) 195 | ); 196 | 197 | const displayedValues = valuesTensor 198 | .slice([0, displayedValueIndex], [-1, 1]) 199 | .squeeze([1]); 200 | 201 | // Padding to ensure that the tooltip is visible - pretty janky, sorry! 202 | return ( 203 |
204 | 210 | 211 | 217 | 223 | 224 |
225 | 226 | ( 232 | 240 | ))} 241 | /> 242 |
243 | ); 244 | } 245 | 246 | export interface ColoredTokensMultiProps { 247 | /** 248 | * The prompt for the model, split into S tokens (as strings) 249 | */ 250 | tokens: string[]; 251 | /** 252 | * The tensor of values across the tokens. Shape [S, K] 253 | */ 254 | values: number[][]; 255 | /** 256 | * The labels for the K vectors 257 | */ 258 | labels?: string[]; 259 | /** 260 | * 261 | */ 262 | positiveBounds?: number[]; 263 | /** 264 | */ 265 | negativeBounds?: number[]; 266 | } 267 | -------------------------------------------------------------------------------- /react/src/tokens/mocks/coloredTokens.ts: -------------------------------------------------------------------------------- 1 | export const mockTokens = [ 2 | "class", 3 | " Reddit", 4 | ":", 5 | "\n ", 6 | " update", 7 | "_", 8 | "checked", 9 | " =", 10 | " False", 11 | "\n ", 12 | " _", 13 | "rat", 14 | "el", 15 | "imit", 16 | "_", 17 | "regex", 18 | " =", 19 | " re", 20 | ".", 21 | "compile", 22 | "(", 23 | "r", 24 | '"', 25 | "([", 26 | "0", 27 | "-", 28 | "9", 29 | "]{", 30 | "1", 31 | ",", 32 | "3", 33 | "})", 34 | " (", 35 | "mill", 36 | "iseconds", 37 | "?", 38 | "|", 39 | "seconds", 40 | "?", 41 | "|", 42 | "minutes", 43 | "?)", 44 | '")', 45 | "\n\n ", 46 | " @", 47 | "property", 48 | "\n ", 49 | " def", 50 | " _", 51 | "next", 52 | "_", 53 | "unique", 54 | "(", 55 | "self", 56 | ")", 57 | " ->", 58 | " int", 59 | ":", 60 | "\n ", 61 | " value", 62 | " =", 63 | " self", 64 | "._", 65 | "unique", 66 | "_", 67 | "counter", 68 | "\n ", 69 | " self", 70 | "._", 71 | "unique", 72 | "_", 73 | "counter", 74 | " +=", 75 | " 1", 76 | "\n ", 77 | " return", 78 | " value", 79 | "\n\n ", 80 | " @", 81 | "property", 82 | "\n ", 83 | " def", 84 | " read", 85 | "_", 86 | "only", 87 | "(", 88 | "self", 89 | ")", 90 | " ->", 91 | " bool", 92 | ":", 93 | "\n ", 94 | ' """', 95 | "Return", 96 | " ``", 97 | "True", 98 | "``", 99 | " when", 100 | " using", 101 | " the", 102 | " ``", 103 | "Read", 104 | "Only", 
105 | "Author", 106 | "izer", 107 | "``", 108 | '."""', 109 | "\n ", 110 | " return", 111 | " self", 112 | "._", 113 | "core", 114 | " ==", 115 | " self", 116 | "._", 117 | "read", 118 | "_", 119 | "only", 120 | "_", 121 | "core", 122 | "\n\n ", 123 | " @", 124 | "read", 125 | "_", 126 | "only", 127 | ".", 128 | "set", 129 | "ter", 130 | "\n ", 131 | " def", 132 | " read", 133 | "_", 134 | "only", 135 | "(", 136 | "self", 137 | ",", 138 | " value", 139 | ":", 140 | " bool", 141 | ")", 142 | " ->", 143 | " None", 144 | ":", 145 | "\n ", 146 | ' """', 147 | "Set", 148 | " or", 149 | " un", 150 | "set", 151 | " the", 152 | " use", 153 | " of", 154 | " the", 155 | " Read", 156 | "Only", 157 | "Author", 158 | "izer", 159 | ".", 160 | "\n ", 161 | " :", 162 | "ra", 163 | "ises", 164 | ":", 165 | " :", 166 | "class", 167 | ":", 168 | "`.", 169 | "Client", 170 | "Exception", 171 | "`", 172 | " when", 173 | " attempting", 174 | " to", 175 | " un", 176 | "set", 177 | " ``", 178 | "read", 179 | "_", 180 | "only", 181 | "``", 182 | " and", 183 | "\n ", 184 | " only", 185 | " the", 186 | " ``", 187 | "Read", 188 | "Only", 189 | "Author", 190 | "izer", 191 | "``", 192 | " is", 193 | " available", 194 | ".", 195 | "\n ", 196 | ' """', 197 | "\n ", 198 | " if", 199 | " value", 200 | ":", 201 | "\n ", 202 | " self", 203 | "._", 204 | "core", 205 | " =", 206 | " self", 207 | "._", 208 | "read", 209 | "_", 210 | "only", 211 | "_", 212 | "core", 213 | "\n ", 214 | " elif", 215 | " self", 216 | "._", 217 | "authorized", 218 | "_", 219 | "core", 220 | " is", 221 | " None", 222 | ":", 223 | "\n ", 224 | " raise", 225 | " Client", 226 | "Exception", 227 | "(", 228 | "\n ", 229 | ' "', 230 | "read", 231 | "_", 232 | "only", 233 | " cannot", 234 | " be", 235 | " un", 236 | "set", 237 | " as", 238 | " only", 239 | " the", 240 | " Read", 241 | "Only", 242 | "Author", 243 | "izer", 244 | " is", 245 | " available", 246 | '."', 247 | "\n ", 248 | " )", 249 | "\n ", 250 | " else", 251 | ":", 252 
| "\n ", 253 | " self", 254 | "._", 255 | "core", 256 | " =", 257 | " self", 258 | "._", 259 | "authorized", 260 | "_", 261 | "core", 262 | "\n\n ", 263 | " @", 264 | "property", 265 | "\n ", 266 | " def", 267 | " validate", 268 | "_", 269 | "on", 270 | "_", 271 | "submit", 272 | "(", 273 | "self", 274 | ")", 275 | " ->", 276 | " bool", 277 | ":", 278 | "\n ", 279 | ' """', 280 | "Get", 281 | " validate", 282 | "_", 283 | "on", 284 | "_", 285 | "submit", 286 | ".", 287 | "\n ", 288 | "..", 289 | " deprecated", 290 | "::", 291 | " 7", 292 | ".", 293 | "0", 294 | "\n ", 295 | " If", 296 | " property", 297 | " :", 298 | "attr", 299 | ":", 300 | "`.", 301 | "validate", 302 | "_", 303 | "on", 304 | "_", 305 | "submit", 306 | "`", 307 | " is", 308 | " set", 309 | " to", 310 | " ``", 311 | "False", 312 | "``", 313 | ",", 314 | " the", 315 | " behavior", 316 | " is", 317 | "\n ", 318 | " deprecated", 319 | " by", 320 | " Reddit", 321 | ".", 322 | " This", 323 | " attribute", 324 | " will", 325 | " be", 326 | " removed", 327 | " around", 328 | " May", 329 | "-", 330 | "June", 331 | " 2", 332 | "0", 333 | "2", 334 | "0", 335 | ".", 336 | "\n ", 337 | ' """', 338 | "\n ", 339 | " value", 340 | " =", 341 | " self", 342 | "._", 343 | "validate", 344 | "_", 345 | "on", 346 | "_", 347 | "submit", 348 | "\n ", 349 | " if", 350 | " value", 351 | " is", 352 | " False", 353 | ":", 354 | "\n ", 355 | " warn", 356 | "(", 357 | "\n ", 358 | ' "', 359 | "Red", 360 | "dit", 361 | " will", 362 | " check", 363 | " for", 364 | " validation", 365 | " on", 366 | " all", 367 | " posts", 368 | " around", 369 | " May", 370 | "-", 371 | "June", 372 | " 2", 373 | "0", 374 | "2", 375 | "0", 376 | ".", 377 | " It", 378 | '"', 379 | "\n ", 380 | ' "', 381 | " is", 382 | " recommended", 383 | " to", 384 | " check", 385 | " for", 386 | " validation", 387 | " by", 388 | " setting", 389 | '"', 390 | "\n ", 391 | ' "', 392 | " redd", 393 | "it", 394 | ".", 395 | "validate", 396 | "_", 397 | "on", 398 | 
"_", 399 | "submit", 400 | " to", 401 | " True", 402 | '.",', 403 | "\n ", 404 | " category", 405 | "=", 406 | "Dep", 407 | "rec", 408 | "ation", 409 | "Warning", 410 | ",", 411 | "\n ", 412 | " stack", 413 | "level", 414 | "=", 415 | "3", 416 | ",", 417 | "\n ", 418 | " )", 419 | "\n ", 420 | " return" 421 | ]; 422 | 423 | export const mockValues = [ 424 | 0.0000006219, 0.0145308562, -0.1404035836, 0.0001031339, -0.0181442499, 425 | 0.0000425652, 0.4676490724, 0.0379571021, -0.1879520416, -0.007731745, 426 | 0.0000087938, -0.0127301402, 0.2066464573, 0.0776088983, 0.0006695495, 427 | -0.2520848811, -0.0823620856, -0.0334108472, 0.245113492, -0.1432019472, 428 | -0.0066892505, -0.0300149024, -0.009421261, -0.0221996512, 0.5490510464, 429 | -0.1218788922, 0.0105798542, 0.0484184623, 0.0801258683, -0.0239099413, 430 | -0.0197054446, 0.0011228092, 0.0000681803, 0.2287244499, 0.0063738613, 431 | -0.0005307829, 0.0139577975, 0.2580736578, 0.6625726223, -0.0009637121, 432 | 0.066787228, 0.2879539132, 0.1427007169, -0.007757809, 0.1295484602, 433 | 0.0134078264, -0.0097765326, -0.0366173349, -0.0000446837, 0.0205183029, 434 | -0.0000258216, 0.1400976628, 0.0655861497, 0.0012099133, -0.0720179677, 435 | -0.0070694387, 0.1904485226, 0.0582311749, -0.003322866, -0.0540601015, 436 | 0.1502584517, 0.1261327267, -0.0011224833, -0.4380444586, 0.0069346358, 437 | 0.0386789814, -0.0245163366, -0.2321336865, -0.0314541124, -0.08843261, 438 | 0.9286111593, -0.1562622041, 0.1237872243, 0.1001729071, 0.024910897, 439 | 0.5526285768, 0.0023488998, -0.2850846946, 0.0118290782, 0.0058026314, 440 | 0.0047158003, 0.0006096047, -0.0790141821, 0.011721639, 0.1362534761, 441 | 0.0091130733, -0.2058984935, 0.0454577208, 0.1273727119, 0.1381959319, 442 | 0.0429062247, -0.0642386526, 0.1462859362, -0.0010326442, 0.2920421362, 443 | 0.002445817, -0.0018207836, 0.004715913, -0.0230856687, 0.0407866389, 444 | 0.0010117516, 0.5489981771, -0.0001264173, 0.0007876776, 0.2675945163, 445 | 0.4242863953, 
-0.0210899711, 0.2453442812, 0.0855717659, -0.0085759163, 446 | -0.000562535, 0.000116363, 0.0085016415, 0.0134009123, -0.0005975422, 447 | 0.2166460752, 0.2984373569, -0.2958423793, -0.0008974103, 0.2694529295, 448 | -0.1098018885, 0.0134821096, 0.5570789576, 0.117154181, 0.3000327051, 449 | 0.4688822627, 0.1061223745, -0.0078112483, -0.0078518391, 0.8520200253, 450 | 0.1286870837, 0.3439406753, 0.1901453137, 0.0072892308, 0.0131719112, 451 | 0.2396282405, 0.010022047, -0.3157484233, 0.0698616207, -0.0080025792, 452 | -0.1726672649, 0.1600424647, 0.0671297312, 0.2149727941, 0.0894118994, 453 | 0.0005835033, -0.0000675181, 0.0946147144, 0.0470751151, -0.0010825, 454 | 0.131246388, 0.045971334, 0.0011529038, 0.6017510891, 0.7375714779, 455 | 0.9098120332, -0.0137378313, 0.0703185946, 0.0590063371, -0.0246331077, 456 | -0.0027519464, -0.0677859783, 0.0881179646, -0.1653440893, -0.0123041077, 457 | -0.286318779, -0.000936829, 0.1067639887, 0.2747517526, 0.0903260782, 458 | 0.0810994208, 0.013060689, 0.0007890144, 0.9062747359, -0.009627467, 459 | 0.0792724416, 0.0516216159, 0.0196556449, 0.1200390458, -0.0026325649, 460 | 0.0102911554, 0.0018374863, 0.1356917173, -0.0386384428, 0.7473417521, 461 | 0.1174088717, 0.942612648, 0.8944275379, 0.5595457554, -0.0744787753, 462 | -0.0371254981, 0.4966602921, 0.4688055515, 0.1482989192, 0.286432147, 463 | 0.1191697195, 0.0031747962, 0.5068869591, 0.0697489977, 0.2824061513, 464 | -0.0376368165, 0.0198621545, 0.1926025301, 0.413629055, -0.0337253809, 465 | 0.4140422344, 0.2303587794, 0.0281034708, -0.0063169599, 0.6945232749, 466 | 0.1411154866, 0.0252213627, -0.4500791132, 0.0029962659, -0.0000003174, 467 | -0.2665536106, 0.0001449844, 0.0666593164, 0.125900805, -0.2045109272, 468 | -0.0008355975, -0.0386986323, -0.0000711503, 0.5098578334, 0.1436619163, 469 | -0.1012291908, 0.0344754159, 0.0561907329, 0.4081306159, 0.0034198165, 470 | 0.0003852564, -0.3534452617, 0.0014236877, 0.9586799145, 0.0012088121, 471 | 0.0009353122, 
0.0330252498, 0.0019751845, 0.2715310454, 0.8898829818, 472 | 0.836817503, 0.0882308781, -0.0037925355, 0.313105762, -0.2465211898, 473 | 0.6578400135, 0.0435889363, -0.1754517257, -0.0080716014, 0.0057964921, 474 | 0.0600164458, 0.0766673684, 0.7506139874, 0.6541926265, 0.6603399515, 475 | 0.0395969152, 0.0000002805, 0.203969419, 0.5281766057, 0.435046792, 476 | 0.1880384684, -0.79277879, 0.0236440301, 0.0015734434, 0.0020925375, 477 | -0.0624228716, -0.0024808338, -0.0582574606, 0.0018162932, 0.1995362043, 478 | 0.0034110546, -0.1866874397, 0.1613387465, 0.4188560545, 0.0775891542, 479 | 0.1022245288, 0.1870539784, 0.0052574091, 0.0025829612, 0.1834496558, 480 | 0.73578161, 0.2195843458, 0.8987259865, 0.0121229813, -0.0303361509, 481 | 0.0002931431, -0.0195899308, -0.2794342637, 0.0003327613, 0.0898362398, 482 | 0.226510331, -0.108599633, -0.0014194585, -0.0000567417, 0.0033046999, 483 | 0.1063713953, -0.0404051207, 0.0553080775, -0.0005358383, -0.1770697832, 484 | 0.262383461, 0.175275743, 0.6111828685, 0.1434824914, 0.3208144009, 485 | -0.012219294, -0.1834220886, 0.0149550941, 0.1380700469, -0.0015727878, 486 | 0.2584753931, -0.0277175345, 0.0003288229, 0.0949437171, 0.0207306352, 487 | -0.0049975035, 0.0035048975, 0.0000889587, 0.1136212647, -0.0034325775, 488 | 0.0000552796, 0.1600628495, 0.0297037363, -0.1841051579, -0.0000346369, 489 | 0.0025556006, 0.0020889556, 0.0079352856, -0.1264779866, 0.319478184, 490 | 0.0008403971, 0.1294415742, 0.0411299765, 0.1301106811, 0.090507865, 491 | 0.059785068, 0.0036087139, 0.3591989875, 0.3647009134, 0.071782589, 492 | -0.0264552459, -0.0083175302, 0.3546337485, 0.0599297881, 0.3915102482, 493 | 0.0252607167, 0.1225005835, 0.471676141, -0.4233388901, 0.0037213129, 494 | -0.0658706427, 0.0143287778, -0.0000618782, 0.0446123667, 0.4080998302, 495 | 0.5767087936, -0.001053106, 0.0884372517, -0.0042287335, 0.0034043523, 496 | 0.0120782629, 0.0006982221, -0.0016106404, -0.0076557342, 0.0115865655, 497 | -0.0003337541, 
0.0219879113, 0.8558861017, 0.5967720747, 0.4480031729, 498 | 0.4373111129, 0.8445559144, 0.9268567562, 0.3059120774, 0.0063580293, 499 | 0.0000039492, 0.2881910205, 0.3730633855, 0.0017905204, 0.0016960911, 500 | -0.1373336017, -0.0334244072, 0.0517534427, 0.1390397102, -0.0196680874, 501 | -0.001462996, 0.0013619808, 0.2718190253, 0.8149151802, 0.0000475792, 502 | -0.2028101683, -0.0254933126, -0.0419868752, 0.1284008026, 0.3714885116, 503 | 0.3263232112, 0.3687552214, 0.0001573412, 0.0150237354, 0.0031800342, 504 | -0.0472598076, -0.0021545473, 0.1496221721, -0.0023701997, -0.1677999645, 505 | -0.0129204392, 0.0103296041, 0.094870463, -0.2831195295, 0.008339718, 506 | 0.2930948734, 0.1508107185, 0.0295093879, 0.0405177772, 0.0641460866, 507 | 0.2621290684, 0.1435324848, 0.0173472911, 0.2825519443 508 | ]; 509 | -------------------------------------------------------------------------------- /react/src/tokens/mocks/coloredTokensMulti.ts: -------------------------------------------------------------------------------- 1 | export const mockTokens: string[] = [ 2 | "<|BOS|>", 3 | "We", 4 | " all", 5 | " live", 6 | " in", 7 | " a", 8 | " yellow", 9 | " submarine" 10 | ]; 11 | 12 | export const mockValues: number[][] = [ 13 | [ 14 | 0.00019838476146105677, 6.558105087606236e-5, 7.382185867754743e-5, 15 | 0.0001276412804145366, 0.0006976894219405949 16 | ], 17 | [ 18 | -2.8924760044901632e-5, -0.00014778069453313947, 0.00014617902343161404, 19 | -8.635249832877889e-5, 0.0004980250378139317 20 | ], 21 | [ 22 | 0.0005748102557845414, -0.00015581399202346802, 7.338493742281571e-5, 23 | -0.0001935002946993336, -0.00019373299437575042 24 | ], 25 | [ 26 | 0.0001386010553687811, 2.15080854104599e-5, 0.0016489957924932241, 27 | 9.453385428059846e-5, 0.0019853608682751656 28 | ], 29 | [ 30 | 0.0005312012508511543, 0.00032850843854248524, -8.134488598443568e-5, 31 | -0.00018524321785662323, -8.109412738122046e-5 32 | ], 33 | [ 34 | -0.00019468815298750997, -5.316149690770544e-5, 
-0.00018978440493810922, 35 | -0.0001922725496115163, 0.00037922835326753557 36 | ], 37 | [ 38 | -0.0001361396862193942, 6.918911094544455e-5, -0.00017853885947261006, 39 | -8.563703886466101e-5, 0.00041159408283419907 40 | ], 41 | [ 42 | 9.39336241572164e-5, 0.0001966943236766383, -0.00010445583757245913, 43 | -7.004357757978141e-5, 0.0007441306370310485 44 | ] 45 | ]; 46 | 47 | export const mockLabels: string[] = ["L0N0", "L0N1", "L0N2", "L0N3", "L0N4"]; 48 | -------------------------------------------------------------------------------- /react/src/tokens/utils/Token.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { colord, AnyColor } from "colord"; 3 | import { usePopperTooltip } from "react-popper-tooltip"; 4 | import { getTokenBackgroundColor } from "../../utils/getTokenBackgroundColor"; 5 | 6 | export function formatTokenText(token: string) { 7 | // Handle special tokens (e.g. spaces/line breaks) 8 | const tokenReplaceSpaces = token.replace(/\s/g, " "); 9 | const tokenReplaceLineBreaks = tokenReplaceSpaces.replace(/\n/g, "¶"); 10 | return tokenReplaceLineBreaks; 11 | } 12 | 13 | /** 14 | * Token (shown as an inline block) 15 | */ 16 | export function Token({ 17 | token, 18 | value, 19 | min, 20 | max, 21 | negativeColor, 22 | positiveColor 23 | }: { 24 | token: string; 25 | value: number; 26 | min: number; 27 | max: number; 28 | negativeColor?: AnyColor; 29 | positiveColor?: AnyColor; 30 | }) { 31 | // Hover state 32 | const { getTooltipProps, setTooltipRef, setTriggerRef, visible } = 33 | usePopperTooltip({ 34 | followCursor: true 35 | }); 36 | 37 | // Get the background color 38 | const backgroundColor = getTokenBackgroundColor( 39 | value, 40 | min, 41 | max, 42 | negativeColor, 43 | positiveColor 44 | ).toRgbString(); 45 | 46 | // Get the text color 47 | const textColor = 48 | colord(backgroundColor).brightness() < 0.6 ? 
"white" : "black"; 49 | 50 | // Format the span (CSS style) 51 | const spanStyle: React.CSSProperties = { 52 | display: "inline-block", 53 | backgroundColor, 54 | color: textColor, 55 | lineHeight: "1em", 56 | padding: "3px 0", 57 | marginLeft: -1, 58 | marginBottom: 1, 59 | borderWidth: 1, 60 | borderStyle: "solid", 61 | borderColor: "#eee" 62 | }; 63 | 64 | // Handle special tokens (e.g. spaces/line breaks) 65 | const tokenReplaceLineBreaks = formatTokenText(token); 66 | const lineBreakElements = token.match(/\n/g)!; 67 | 68 | return ( 69 | <> 70 | 71 | 75 | {lineBreakElements?.map((_break, idx) => ( 76 |
77 | ))} 78 |
79 | 80 | {visible && ( 81 |
96 | {token} 97 |
98 | {value} 99 |
100 | )} 101 | 102 | ); 103 | } 104 | -------------------------------------------------------------------------------- /react/src/tokens/utils/TokenCustomTooltip.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { colord, AnyColor } from "colord"; 3 | import { usePopperTooltip } from "react-popper-tooltip"; 4 | import { getTokenBackgroundColor } from "../../utils/getTokenBackgroundColor"; 5 | import { formatTokenText } from "./Token"; 6 | 7 | /** 8 | * Token (shown as an inline block) 9 | */ 10 | export function TokenCustomTooltip({ 11 | token, 12 | value, 13 | min, 14 | max, 15 | negativeColor, 16 | positiveColor, 17 | tooltip = <>{"Intentionally Left Blank"} 18 | }: { 19 | token: string; 20 | value: number; 21 | min: number; 22 | max: number; 23 | negativeColor?: AnyColor; 24 | positiveColor?: AnyColor; 25 | tooltip?: React.ReactNode; 26 | }) { 27 | // Hover state 28 | const { getTooltipProps, setTooltipRef, setTriggerRef, visible } = 29 | usePopperTooltip({ 30 | followCursor: true 31 | }); 32 | 33 | // Get the background color 34 | const backgroundColor = getTokenBackgroundColor( 35 | value, 36 | min, 37 | max, 38 | negativeColor, 39 | positiveColor 40 | ).toRgbString(); 41 | 42 | // Get the text color 43 | const textColor = 44 | colord(backgroundColor).brightness() < 0.6 ? "white" : "black"; 45 | 46 | // Format the span (CSS style) 47 | const spanStyle: React.CSSProperties = { 48 | display: "inline-block", 49 | backgroundColor, 50 | color: textColor, 51 | lineHeight: "1em", 52 | padding: "3px 0", 53 | marginLeft: -1, 54 | marginBottom: 1, 55 | borderWidth: 1, 56 | borderStyle: "solid", 57 | borderColor: "#eee" 58 | }; 59 | 60 | // Handle special tokens (e.g. spaces/line breaks) 61 | const tokenReplaceLineBreaks = formatTokenText(token); 62 | const lineBreakElements = token.match(/\n/g)!; 63 | 64 | return ( 65 | <> 66 | 67 | 71 | {lineBreakElements?.map((_break, idx) => ( 72 |
73 | ))} 74 |
75 | 76 | {visible && ( 77 |
91 | {tooltip} 92 |
93 | )} 94 | 95 | ); 96 | } 97 | -------------------------------------------------------------------------------- /react/src/topk/TopkSamples.stories.tsx: -------------------------------------------------------------------------------- 1 | import { ComponentStory, ComponentMeta } from "@storybook/react"; 2 | import React from "react"; 3 | import { mockActivations, mockTokens, neuronLabels } from "./mocks/topkSamples"; 4 | import { TopkSamples } from "./TopkSamples"; 5 | 6 | export default { 7 | component: TopkSamples 8 | } as ComponentMeta; 9 | 10 | const Template: ComponentStory = (args) => ( 11 | 12 | ); 13 | 14 | export const ExampleSamples: ComponentStory = Template.bind( 15 | {} 16 | ); 17 | ExampleSamples.args = { 18 | tokens: mockTokens, 19 | activations: mockActivations, 20 | firstDimensionLabels: neuronLabels 21 | }; 22 | -------------------------------------------------------------------------------- /react/src/topk/TopkSamples.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | import { Container, Row, Col } from "react-grid-system"; 3 | import { SampleItems } from "../shared/SampleItems"; 4 | import { RangeSelector } from "../shared/RangeSelector"; 5 | import { NumberSelector } from "../shared/NumberSelector"; 6 | import { minMaxInNestedArray } from "../utils/arrayOps"; 7 | 8 | /** 9 | * List of samples in descending order of max token activation value for the 10 | * selected layer and neuron (or whatever other dimension names are specified). 
11 | */ 12 | export function TopkSamples({ 13 | tokens, 14 | activations, 15 | zerothDimensionName = "Layer", 16 | firstDimensionName = "Neuron", 17 | zerothDimensionLabels, 18 | firstDimensionLabels 19 | }: TopkSamplesProps) { 20 | const numberOfLayers = activations.length; 21 | const numberOfNeurons = activations[0].length; 22 | const numberOfSamples = activations[0][0].length; 23 | 24 | const [samplesPerPage, setSamplesPerPage] = useState( 25 | Math.min(5, numberOfSamples) 26 | ); 27 | const [sampleNumbers, setSampleNumbers] = useState([ 28 | ...Array(samplesPerPage).keys() 29 | ]); 30 | const [layerNumber, setLayerNumber] = useState(0); 31 | const [neuronNumber, setNeuronNumber] = useState(0); 32 | 33 | useEffect(() => { 34 | // When the user changes the samplesPerPage, update the sampleNumbers 35 | setSampleNumbers([...Array(samplesPerPage).keys()]); 36 | }, [samplesPerPage]); 37 | 38 | // Get the relevant activations for the selected layer and neuron. 39 | const selectedActivations: number[][] = sampleNumbers.map((sampleNumber) => { 40 | return activations[layerNumber][neuronNumber][sampleNumber]; 41 | }); 42 | const selectedTokens: string[][] = sampleNumbers.map((sampleNumber) => { 43 | return tokens[layerNumber][neuronNumber][sampleNumber]; 44 | }); 45 | 46 | // For a consistent color scale across all samples in this layer and neuron 47 | const [minValue, maxValue] = minMaxInNestedArray( 48 | activations[layerNumber][neuronNumber] 49 | ); 50 | 51 | const selectRowStyle = { 52 | paddingTop: 5, 53 | paddingBottom: 5 54 | }; 55 | 56 | return ( 57 | 58 | 59 | 60 | 61 | 62 | 65 | 72 | 73 | 74 | 75 | 76 | 79 | 86 | 87 | 88 | {/* Only show the sample selector if there is more than one sample */} 89 | {numberOfSamples > 1 && ( 90 | 91 | 92 | 95 | 102 | 103 | 104 | )} 105 | 106 | 107 | {/* Only show the sample per page selector if there is more than one sample */} 108 | {numberOfSamples > 1 && ( 109 | 110 | 111 | 117 | 124 | 125 | 126 | )} 127 | 128 | 129 | 130 | 
131 | 137 | 138 | 139 | 140 | ); 141 | } 142 | 143 | export interface TopkSamplesProps { 144 | /** 145 | * Nested list of tokens of shape [layers x neurons x samples x tokens] 146 | * 147 | * The inner most dimension must be the same size as the inner most dimension of activations. 148 | * 149 | * For example, the first and second dimensisons (1-indexed) may correspond to 150 | * layers and neurons. 151 | */ 152 | tokens: string[][][][]; 153 | 154 | /** 155 | * Activations for the tokens with shape [layers x neurons x samples x tokens] 156 | * 157 | */ 158 | activations: number[][][][]; 159 | 160 | /** 161 | * Name of the zeroth dimension 162 | */ 163 | zerothDimensionName?: string; 164 | 165 | /** 166 | * Name of the first dimension 167 | */ 168 | firstDimensionName?: string; 169 | 170 | /** 171 | * Labels for the zeroth dimension 172 | */ 173 | zerothDimensionLabels?: string[]; 174 | 175 | /** 176 | * Labels for the first dimension 177 | */ 178 | firstDimensionLabels?: string[]; 179 | } 180 | -------------------------------------------------------------------------------- /react/src/topk/TopkTokens.stories.tsx: -------------------------------------------------------------------------------- 1 | import { ComponentStory, ComponentMeta } from "@storybook/react"; 2 | import React from "react"; 3 | import { 4 | mockTokens, 5 | topkVals, 6 | topkIdxs, 7 | bottomkVals, 8 | bottomkIdxs, 9 | objType, 10 | layerLabels 11 | } from "./mocks/topkTokens"; 12 | import { TopkTokens } from "./TopkTokens"; 13 | 14 | export default { 15 | component: TopkTokens 16 | } as ComponentMeta; 17 | 18 | const Template: ComponentStory = (args) => ( 19 | 20 | ); 21 | 22 | export const ExampleTokens: ComponentStory = Template.bind( 23 | {} 24 | ); 25 | ExampleTokens.args = { 26 | tokens: mockTokens, 27 | topkVals, 28 | topkIdxs, 29 | bottomkVals, 30 | bottomkIdxs, 31 | thirdDimensionName: objType, 32 | firstDimensionLabels: layerLabels 33 | }; 34 | 
-------------------------------------------------------------------------------- /react/src/topk/mocks/topkSamples.ts: -------------------------------------------------------------------------------- 1 | const text: string = ` 2 | A goose (PL: geese) is a bird of any of several waterfowl species in the family Anatidae. This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese). Some other birds, mostly related to the shelducks, have "goose" as part of their names. More distantly related members of the family Anatidae are swans, most of which are larger than true geese, and ducks, which are smaller. 3 | 4 | The term "goose" may refer to either a male or female bird, but when paired with "gander", refers specifically to a female one (the latter referring to a male). Young birds before fledging are called goslings.[1] The collective noun for a group of geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when flying close together, they are called a plump.[2] 5 | Contents 6 | 7 | 1 Etymology 8 | 2 True geese and their relatives 9 | 3 Fossil record 10 | 4 Migratory patterns 11 | 4.1 Preparation 12 | 4.2 Navigation 13 | 4.3 Formation 14 | 5 Other birds called "geese" 15 | 6 In popular culture 16 | 6.1 "Gray Goose Laws" in Iceland 17 | 7 Gallery 18 | 8 See also 19 | 9 References 20 | 10 Further reading 21 | 11 External links 22 | 23 | Etymology 24 | 25 | The word "goose" is a direct descendant of,*ghans-. In Germanic languages, the root gave Old English gōs with the plural gēs and gandres (becoming Modern English goose, geese, gander, and gosling, respectively), Frisian goes, gies and guoske, New High German Gans, Gänse, and Ganter, and Old Norse gās. 
26 | 27 | This term also gave Lithuanian: žąsìs, Irish: gé (goose, from Old Irish géiss), Hindi: कलहंस, Latin: anser, Spanish: ganso, Ancient Greek: χήν (khēn), Dutch: gans, Albanian: gatë swans), Finnish: hanhi, Avestan zāō, Polish: gęś, Romanian: gâscă / gânsac, Ukrainian: гуска / гусак (huska / husak), Russian: гусыня / гусь (gusyna / gus), Czech: husa, and Persian: غاز (ghāz).[1][3] 28 | True geese and their relatives 29 | Snow geese (Anser caerulescens) in Quebec, Canada 30 | Chinese geese (Anser cygnoides domesticus), the domesticated form of the swan goose (Anser cygnoides) 31 | Barnacle geese (Branta leucopsis) in Naantali, Finland 32 | 33 | The two living genera of true geese are: Anser, grey geese and white geese, such as the greylag goose and snow goose, and Branta, black geese, such as the Canada goose. 34 | 35 | Two genera of geese are only tentatively placed in the Anserinae; they may belong to the shelducks or form a subfamily on their own: Cereopsis, the Cape Barren goose, and Cnemiornis, the prehistoric New Zealand goose. Either these or, more probably, the goose-like coscoroba swan is the closest living relative of the true geese. 36 | 37 | Fossils of true geese are hard to assign to genus; all that can be said is that their fossil record, particularly in North America, is dense and comprehensively documents many different species of true geese that have been around since about 10 million years ago in the Miocene. The aptly named Anser atavus (meaning "progenitor goose") from some 12 million years ago had even more plesiomorphies in common with swans. In addition, some goose-like birds are known from subfossil remains found on the Hawaiian Islands. 38 | 39 | Geese are monogamous, living in permanent pairs throughout the year; however, unlike most other permanently monogamous animals, they are territorial only during the short nesting season. 
Paired geese are more dominant and feed more, two factors that result in more young.[4][5] 40 | 41 | Geese honk while in flight to encourage other members of the flock to maintain a 'v-formation' and to help communicate with one another.[6] 42 | Fossil record 43 | 44 | Geese fossils have been found ranging from 10 to 12 million years ago (Middle Miocene). Garganornis ballmanni from Late Miocene (~ 6-9 Ma) of Gargano region of central Italy, stood one and a half meters tall and weighed about 22 kilograms. The evidence suggests the bird was flightless, unlike modern geese.[7] 45 | Migratory patterns 46 | 47 | Geese like the Canada goose do not always migrate.[8] Some members of the species only move south enough to ensure a supply of food and water. When European settlers came to America, the birds were seen as easy prey and were almost wiped out of the population. The species was reintroduced across the northern U.S. range and their population has been growing ever since.[9] 48 | Preparation 49 | `; 50 | 51 | function chunkText(textArr: string[]): string[][] { 52 | const chunks: string[][] = []; 53 | let i = 0; 54 | // Split textArr into 12 chunks of 75 tokens 55 | const chunkSize = 75; 56 | while (i < textArr.length) { 57 | chunks.push(textArr.slice(i, i + chunkSize)); 58 | i += chunkSize; 59 | } 60 | return chunks; 61 | } 62 | 63 | // This creates 12 samples 64 | const mockTokensFlat: string[][] = chunkText(text.split(/(?=\s)/)); 65 | const numLayers: number = 2; 66 | const numNeurons: number = 2; 67 | const numK: number = 3; 68 | 69 | // Convert mockTokensFlat to a nested array of size (n_layers, n_neurons, numK, 70 | // numTokens) where numTokens varies by sample 71 | let sampleIdx: number = 0; 72 | export const mockTokens: string[][][][] = []; 73 | export const mockActivations: number[][][][] = []; 74 | for (let l = 0; l < numLayers; l += 1) { 75 | const layerTokens: string[][][] = []; 76 | const layerActivations: number[][][] = []; 77 | for (let n = 0; n < 
numNeurons; n += 1) { 78 | const neuronTokens: string[][] = []; 79 | const neuronActivations: number[][] = []; 80 | for (let k = 0; k < numK; k += 1) { 81 | neuronTokens.push(mockTokensFlat[sampleIdx]); 82 | neuronActivations.push( 83 | Array.from(Array(mockTokensFlat[sampleIdx].length), () => Math.random()) 84 | ); 85 | sampleIdx += 1; 86 | } 87 | layerTokens.push(neuronTokens); 88 | layerActivations.push(neuronActivations); 89 | } 90 | mockTokens.push(layerTokens); 91 | mockActivations.push(layerActivations); 92 | } 93 | 94 | export const neuronLabels: string[] = ["3", "42"]; 95 | -------------------------------------------------------------------------------- /react/src/topk/mocks/topkTokens.ts: -------------------------------------------------------------------------------- 1 | import { 2 | Rank, 3 | tensor, 4 | Tensor3D, 5 | reverse, 6 | topk as tfTopk 7 | } from "@tensorflow/tfjs"; 8 | 9 | const text: string = ` 10 | A goose (PL: geese) is a bird of any of several waterfowl species in the family Anatidae. This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese). Some other birds, mostly related to the shelducks, have "goose" as part of their names. More distantly related members of the family Anatidae are swans, most of which are larger than true geese, and ducks, which are smaller. 11 | 12 | The term "goose" may refer to either a male or female bird, but when paired with "gander", refers specifically to a female one (the latter referring to a male). 
Young birds before fledging are called goslings.[1] The collective noun for a group of geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when flying close together, they are called a plump.[2] 13 | Contents 14 | 15 | 1 Etymology 16 | 2 True geese and their relatives 17 | 3 Fossil record 18 | 4 Migratory patterns 19 | 4.1 Preparation 20 | 4.2 Navigation 21 | 4.3 Formation 22 | 5 Other birds called "geese" 23 | 6 In popular culture 24 | 6.1 "Gray Goose Laws" in Iceland 25 | 7 Gallery 26 | 8 See also 27 | 9 References 28 | 10 Further reading 29 | 11 External links 30 | 31 | Etymology 32 | 33 | The word "goose" is a direct descendant of,*ghans-. In Germanic languages, the root gave Old English gōs with the plural gēs and gandres (becoming Modern English goose, geese, gander, and gosling, respectively), Frisian goes, gies and guoske, New High German Gans, Gänse, and Ganter, and Old Norse gās. 34 | 35 | This term also gave Lithuanian: žąsìs, Irish: gé (goose, from Old Irish géiss), Hindi: कलहंस, Latin: anser, Spanish: ganso, Ancient Greek: χήν (khēn), Dutch: gans, Albanian: gatë swans), Finnish: hanhi, Avestan zāō, Polish: gęś, Romanian: gâscă / gânsac, Ukrainian: гуска / гусак (huska / husak), Russian: гусыня / гусь (gusyna / gus), Czech: husa, and Persian: غاز (ghāz).[1][3] 36 | True geese and their relatives 37 | Snow geese (Anser caerulescens) in Quebec, Canada 38 | Chinese geese (Anser cygnoides domesticus), the domesticated form of the swan goose (Anser cygnoides) 39 | Barnacle geese (Branta leucopsis) in Naantali, Finland 40 | 41 | The two living genera of true geese are: Anser, grey geese and white geese, such as the greylag goose and snow goose, and Branta, black geese, such as the Canada goose. 42 | 43 | Two genera of geese are only tentatively placed in the Anserinae; they may belong to the shelducks or form a subfamily on their own: Cereopsis, the Cape Barren goose, and Cnemiornis, the prehistoric New Zealand goose. 
Either these or, more probably, the goose-like coscoroba swan is the closest living relative of the true geese. 44 | 45 | Fossils of true geese are hard to assign to genus; all that can be said is that their fossil record, particularly in North America, is dense and comprehensively documents many different species of true geese that have been around since about 10 million years ago in the Miocene. The aptly named Anser atavus (meaning "progenitor goose") from some 12 million years ago had even more plesiomorphies in common with swans. In addition, some goose-like birds are known from subfossil remains found on the Hawaiian Islands. 46 | 47 | Geese are monogamous, living in permanent pairs throughout the year; however, unlike most other permanently monogamous animals, they are territorial only during the short nesting season. Paired geese are more dominant and feed more, two factors that result in more young.[4][5] 48 | 49 | Geese honk while in flight to encourage other members of the flock to maintain a 'v-formation' and to help communicate with one another.[6] 50 | Fossil record 51 | 52 | Geese fossils have been found ranging from 10 to 12 million years ago (Middle Miocene). Garganornis ballmanni from Late Miocene (~ 6-9 Ma) of Gargano region of central Italy, stood one and a half meters tall and weighed about 22 kilograms. The evidence suggests the bird was flightless, unlike modern geese.[7] 53 | Migratory patterns 54 | 55 | Geese like the Canada goose do not always migrate.[8] Some members of the species only move south enough to ensure a supply of food and water. When European settlers came to America, the birds were seen as easy prey and were almost wiped out of the population. The species was reintroduced across the northern U.S. 
range and their population has been growing ever since.[9] 56 | Preparation 57 | `; 58 | 59 | function chunkText(textArr: string[]): string[][] { 60 | const chunks: string[][] = []; 61 | let i = 0; 62 | // Split textArr into 12 chunks of 75 tokens 63 | const chunkSize = 75; 64 | while (i < textArr.length) { 65 | chunks.push(textArr.slice(i, i + chunkSize)); 66 | i += chunkSize; 67 | } 68 | return chunks; 69 | } 70 | 71 | function createRandom3DActivationMatrix(shape: number[]): number[][][] { 72 | return Array.from(Array(shape[0]), () => 73 | Array.from(Array(shape[1]), () => 74 | Array.from(Array(shape[2]), () => Math.random()) 75 | ) 76 | ); 77 | } 78 | 79 | const numLayers: number = 2; 80 | const numSVDDirs: number = 30; 81 | const k: number = 5; 82 | 83 | // [samples x tokens] 84 | export const mockTokens: string[][] = chunkText(text.split(/(?=\s)/)); 85 | 86 | // [samples x layers x neurons x tokens] 87 | const mockActivations: Tensor3D[] = mockTokens.map((tokens) => { 88 | return tensor( 89 | createRandom3DActivationMatrix([numLayers, numSVDDirs, tokens.length]) 90 | ); 91 | }); 92 | 93 | // All have shape [samples x layers x k x neurons] 94 | export const topkVals: number[][][][] = []; 95 | export const topkIdxs: number[][][][] = []; 96 | export const bottomkVals: number[][][][] = []; 97 | export const bottomkIdxs: number[][][][] = []; 98 | 99 | for (let sampleNum = 0; sampleNum < mockActivations.length; sampleNum += 1) { 100 | const { values: sampleTopkValsRaw, indices: sampleTopkIdxsRaw } = tfTopk( 101 | mockActivations[sampleNum], 102 | k, 103 | true 104 | ); // [layers x neurons x k] 105 | const { values: sampleBottomkValsRaw, indices: sampleBottomkIdxsRaw } = 106 | tfTopk(mockActivations[sampleNum].mul(-1), k, true); // [layers x neurons x k] 107 | // Append the sample TopkValsRaw to the topkVals array 108 | topkVals.push( 109 | sampleTopkValsRaw.transpose([0, 2, 1]).arraySync() as number[][][] 110 | ); // [layers x k x neurons] 111 | topkIdxs.push( 112 
| sampleTopkIdxsRaw.transpose([0, 2, 1]).arraySync() as number[][][] 113 | ); // [layers x k x neurons] 114 | bottomkVals.push( 115 | reverse(sampleBottomkValsRaw.mul(-1), -1) 116 | .transpose([0, 2, 1]) 117 | .arraySync() as number[][][] 118 | ); // [layers x k x neurons] 119 | bottomkIdxs.push( 120 | reverse(sampleBottomkIdxsRaw, -1) 121 | .transpose([0, 2, 1]) 122 | .arraySync() as number[][][] 123 | ); // [layers x k x neurons] 124 | } 125 | 126 | export const objType: string = "SVD Direction"; 127 | 128 | export const layerLabels: string[] = ["10", "12"]; 129 | -------------------------------------------------------------------------------- /react/src/utils/arrayOps.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Slices a 2D array. 3 | * 4 | * @param arr - The 2D array to slice 5 | * @param dims - The dimensions to slice, as 2D array of pairs of start and end indices 6 | * @returns The sliced array 7 | */ 8 | export function arraySlice2D( 9 | arr: number[][], 10 | dims: [number, number][] 11 | ): number[][] { 12 | // Recursively apply the slicing operation on each dimension 13 | return arr 14 | .slice(dims[0][0], dims[0][1]) 15 | .map((row) => row.slice(dims[1][0], dims[1][1])); 16 | } 17 | 18 | type NestedArrayOfNumbers = number[] | NestedArrayOfNumbers[]; 19 | /** 20 | * Finds the minimum and maximum values in a nested array of arbitrary depth. 21 | * 22 | * @param {any[]} arr - The input array. 23 | * @returns {[number, number]} A tuple containing the minimum and maximum values in the array. 
24 | * 25 | * @example 26 | * minMaxInNestedArray([1, 2, 3, [4, 5, [6, 7]], 8]); 27 | * // returns [1, 8] 28 | * 29 | * @example 30 | * minMaxInNestedArray([[[[1]]]], 2, 3, [[4]]); 31 | * // returns [1, 4] 32 | */ 33 | export function minMaxInNestedArray( 34 | arr: NestedArrayOfNumbers 35 | ): [number, number] { 36 | if (arr.length === 0) { 37 | return [0, 1]; 38 | } 39 | let min = Number.MAX_VALUE; 40 | let max = Number.MIN_VALUE; 41 | for (let i = 0; i < arr.length; i += 1) { 42 | if (Array.isArray(arr[i])) { 43 | const [subMin, subMax] = minMaxInNestedArray( 44 | arr[i] as NestedArrayOfNumbers 45 | ); 46 | min = Math.min(min, subMin); 47 | max = Math.max(max, subMax); 48 | } else { 49 | min = Math.min(min, arr[i] as number); 50 | max = Math.max(max, arr[i] as number); 51 | } 52 | } 53 | return [min, max]; 54 | } 55 | -------------------------------------------------------------------------------- /react/src/utils/getTokenBackgroundColor.ts: -------------------------------------------------------------------------------- 1 | import { colord, extend, AnyColor, Colord } from "colord"; 2 | import mixPlugin from "colord/plugins/mix"; 3 | import namesPlugin from "colord/plugins/names"; 4 | 5 | extend([mixPlugin, namesPlugin]); 6 | 7 | /** 8 | * Get the token background color 9 | * 10 | * Defaults to color blind friendly colors (https://davidmathlogic.com/colorblind/#%23D81B60-%231E88E5-%23FFC107-%23004D40) 11 | */ 12 | export function getTokenBackgroundColor( 13 | value: number, 14 | min: number, 15 | max: number, 16 | negativeColor: AnyColor = "red", 17 | positiveColor: AnyColor = "blue" 18 | ): Colord { 19 | // original_color.mix("white", x) interpolates between original_color and 20 | // white, with x being the ratio of white. So x=0 is original_color, x=1 is 21 | // white. Clamp at 0 to avoid negative values. 
22 | if (value >= 0) { 23 | return colord(positiveColor).mix( 24 | colord("white"), 25 | Math.min(Math.max(1 - value / max, 0), 1) 26 | ); 27 | } 28 | 29 | // value and min are assumed to be negative. We negate them to be consistent with the positive case. 30 | return colord(negativeColor).mix( 31 | colord("white"), 32 | Math.min(Math.max(1 - -value / -min, 0), 1) 33 | ); 34 | } 35 | -------------------------------------------------------------------------------- /react/src/utils/tests/arrayOps.test.ts: -------------------------------------------------------------------------------- 1 | import { arraySlice2D, minMaxInNestedArray } from "../arrayOps"; 2 | 3 | describe("arraySlice", () => { 4 | it("should slice a 2D array", () => { 5 | const arr: number[][] = [ 6 | [1, 2, 3], 7 | [4, 5, 6], 8 | [7, 8, 9] 9 | ]; 10 | const dims: [number, number][] = [ 11 | [1, 2], 12 | [1, 3] 13 | ]; 14 | const sliced = arraySlice2D(arr, dims); 15 | expect(sliced).toEqual([[5, 6]]); 16 | }); 17 | it("should slice a 2D array at the edges", () => { 18 | const arr: number[][] = [ 19 | [1, 2, 3], 20 | [4, 5, 6], 21 | [7, 8, 9] 22 | ]; 23 | const dims: [number, number][] = [ 24 | [0, 3], 25 | [0, 3] 26 | ]; 27 | const sliced = arraySlice2D(arr, dims); 28 | expect(sliced).toEqual(arr); 29 | }); 30 | }); 31 | 32 | describe("minMaxInNestedArray", () => { 33 | it("should return minimum and maximum values for a flat array", () => { 34 | const arr: number[] = [1, 2, 3, 4, 5]; 35 | const [min, max] = minMaxInNestedArray(arr); 36 | expect(min).toEqual(1); 37 | expect(max).toEqual(5); 38 | }); 39 | 40 | it("should return minimum and maximum values for a nested array", () => { 41 | const arr: any[] = [1, 2, 3, [4, 5, [6, 7]], 8]; 42 | const [min, max] = minMaxInNestedArray(arr); 43 | expect(min).toEqual(1); 44 | expect(max).toEqual(8); 45 | }); 46 | 47 | it("should return minimum and maximum values for a deeply nested array", () => { 48 | const arr: any[] = [[[[1]]], 2, 3, [[4]]]; 49 | const [min, 
max] = minMaxInNestedArray(arr); 50 | expect(min).toEqual(1); 51 | expect(max).toEqual(4); 52 | }); 53 | 54 | it("should return 0 and 1 for an empty array", () => { 55 | const arr: any[] = []; 56 | const [min, max] = minMaxInNestedArray(arr); 57 | expect(min).toEqual(0); 58 | expect(max).toEqual(1); 59 | }); 60 | }); 61 | -------------------------------------------------------------------------------- /react/src/utils/tests/getTokenBackgroundColor.test.ts: -------------------------------------------------------------------------------- 1 | import { colord } from "colord"; 2 | import { getTokenBackgroundColor } from "../getTokenBackgroundColor"; 3 | 4 | describe("getBackgroundColor", () => { 5 | it("sets a positive color to blue", () => { 6 | const res = getTokenBackgroundColor(1, -1, 1); 7 | const hsl = res.toHsl(); 8 | const greenHue = colord("blue").toHsv().h; 9 | expect(hsl.h).toBeCloseTo(greenHue); 10 | }); 11 | 12 | it("sets a negative color to red", () => { 13 | const res = getTokenBackgroundColor(-1, -1, 1); 14 | const hsl = res.toHsl(); 15 | const blueHue = colord("red").toHsv().h; 16 | expect(hsl.h).toBeCloseTo(blueHue); 17 | }); 18 | 19 | it("sets 0 to white", () => { 20 | // Should be 80% red 21 | const res = getTokenBackgroundColor(0, -1, 1); 22 | const hsl = res.toHsl(); 23 | expect(hsl.l).toBeCloseTo(100); 24 | }); 25 | 26 | it("Check that 0 returns a brightness of 1", () => { 27 | // If the brightness is <0.6, the text will be white and invisible 28 | const res = getTokenBackgroundColor(0, 0, 1); 29 | expect(res.brightness()).toBeCloseTo(1); 30 | }); 31 | }); 32 | -------------------------------------------------------------------------------- /react/tsconfig.build.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./tsconfig.json", 3 | "exclude": [ 4 | "src/**/*.spec.ts", 5 | "src/**/*.test.ts", 6 | "src/**/*.test.tsx", 7 | "src/**/*.stories.ts", 8 | "src/**/*.stories.tsx", 9 | "esbuild.js" 
10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /react/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "sourceMap": true, 4 | "module": "es2015", 5 | "target": "es2018", 6 | "moduleResolution": "node", 7 | "outDir": "./dist/module", 8 | "jsx": "react", 9 | "skipLibCheck": true, 10 | "declaration": true, 11 | "declarationMap": true, 12 | "esModuleInterop": true, 13 | "lib": [ 14 | "dom", 15 | "dom.iterable", 16 | "esnext" 17 | ], 18 | "strict": true, 19 | "noImplicitReturns": true, 20 | "noFallthroughCasesInSwitch": true, 21 | }, 22 | "include": [ 23 | "src", 24 | "esbuild.js" 25 | ], 26 | "exclude": [ 27 | "node_modules", 28 | ], 29 | } 30 | --------------------------------------------------------------------------------