├── sae-viewer
│ ├── public
│ │ ├── robots.txt
│ │ └── favicon.ico
│ ├── .gitignore
│ ├── tailwind.config.js
│ ├── src
│ │ ├── index.tsx
│ │ ├── index.css
│ │ ├── utils.ts
│ │ ├── App.tsx
│ │ ├── feed.tsx
│ │ ├── components
│ │ │ ├── histogram.tsx
│ │ │ ├── tooltip.tsx
│ │ │ ├── tokenHeatmap.tsx
│ │ │ ├── tokenAblationmap.tsx
│ │ │ ├── featureSelect.tsx
│ │ │ └── featureInfo.tsx
│ │ ├── interpAPI.ts
│ │ ├── index.html
│ │ ├── autoencoder_registry.tsx
│ │ ├── types.ts
│ │ ├── App.css
│ │ └── welcome.tsx
│ ├── README.md
│ ├── tsconfig.json
│ └── package.json
├── sparse_autoencoder
│ ├── __init__.py
│ ├── loss.py
│ ├── paths.py
│ ├── model.py
│ ├── explanations.py
│ ├── kernels.py
│ └── train.py
├── .pre-commit-config.yaml
├── SECURITY.md
├── pyproject.toml
├── LICENSE
├── .gitignore
└── README.md
/sae-viewer/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/sae-viewer/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/sparse_autoencoder/HEAD/sae-viewer/public/favicon.ico
--------------------------------------------------------------------------------
/sae-viewer/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | dist/
3 | node_modules/
4 | __pycache__/
5 | .parcel-cache/
6 | .cache/
7 | blog_json/
8 |
--------------------------------------------------------------------------------
/sparse_autoencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Autoencoder
2 | from . import paths
3 |
4 | __all__ = ["Autoencoder"]
5 |
--------------------------------------------------------------------------------
/sae-viewer/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | module.exports = {
3 | content: ["./src/**/*.{html,js,jsx}"],
4 | theme: {
5 | extend: {},
6 | },
7 | plugins: [],
8 | }
9 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: trufflehog
5 | name: TruffleHog
6 | description: Detect secrets in your data.
7 | entry: bash -c 'trufflehog git file://. --since-commit HEAD --fail --no-update'
8 | language: system
9 | stages: ["commit", "push"]
10 |
--------------------------------------------------------------------------------
/sae-viewer/src/index.tsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import './index.css';
4 | import App from './App';
5 |
6 | const root = ReactDOM.createRoot(document.getElementById('root'));
7 | root.render(
8 |   <React.StrictMode>
9 |     <App />
10 |   </React.StrictMode>
11 | );
12 |
--------------------------------------------------------------------------------
/sae-viewer/README.md:
--------------------------------------------------------------------------------
1 | # SAE viewer
2 |
3 | The easiest way to view activation patterns is through the
4 | [public website](https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html).
5 | This directory contains the implementation of that website.
6 |
7 | ## Local development
8 |
9 | Install:
10 |
11 | ```npm install```
12 |
13 | Run:
14 |
15 | ```npm start```
16 |
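17 | Other scripts defined in `package.json`:
18 |
19 | Build (static assets via parcel, written to `dist/`):
20 |
21 | ```npm run build```
22 |
23 | Typecheck:
24 |
25 | ```npm run typecheck```
26 |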
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 | For a more in-depth look at our security policy, please check out our
3 | [Coordinated Vulnerability Disclosure Policy](https://openai.com/security/disclosure/#:~:text=Disclosure%20Policy,-Security%20is%20essential&text=OpenAI%27s%20coordinated%20vulnerability%20disclosure%20policy,expect%20from%20us%20in%20return.).
4 |
5 | Our PGP key can be located [at this address](https://cdn.openai.com/security.txt).
6 |
--------------------------------------------------------------------------------
/sae-viewer/src/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | margin: 0;
3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
5 | sans-serif;
6 | -webkit-font-smoothing: antialiased;
7 | -moz-osx-font-smoothing: grayscale;
8 | }
9 |
10 | code {
11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
12 | monospace;
13 | }
14 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "sparse_autoencoder"
3 | description="Sparse autoencoder for GPT2"
4 | version = "0.1"
5 | authors = [{name = "OpenAI"}]
6 | dependencies = [
7 | "blobfile == 2.0.2",
8 | "torch == 2.1.0",
9 | "transformer_lens == 1.9.1",
10 | ]
11 | readme = "README.md"
12 |
13 | [build-system]
14 | requires = ["setuptools>=64.0"]
15 | build-backend = "setuptools.build_meta"
16 |
17 | [tool.setuptools.packages.find]
18 | include = ["sparse_autoencoder*"]
19 |
--------------------------------------------------------------------------------
/sae-viewer/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es2021",
4 | "module": "commonjs",
5 | "lib": ["dom", "dom.iterable", "esnext"],
6 | "allowJs": true,
7 | "skipLibCheck": true,
8 | "esModuleInterop": true,
9 | "allowSyntheticDefaultImports": true,
10 | "strict": true,
11 | "forceConsistentCasingInFileNames": true,
12 | "moduleResolution": "node",
13 | "resolveJsonModule": true,
14 | "isolatedModules": true,
15 | "noEmit": true,
16 | "jsx": "react-jsx"
17 | },
18 | "include": ["src"]
19 | }
20 |
--------------------------------------------------------------------------------
/sae-viewer/src/utils.ts:
--------------------------------------------------------------------------------
1 | export const memoizeAsync = (fnname: string, fn: any) => {
2 | return async (...args: any) => {
3 | const key = `memoized:${fnname}:${args.map((x: any) => JSON.stringify(x)).join("-")}`
4 | const val = localStorage.getItem(key);
5 | if (val === null) {
6 | const value = await fn(...args)
7 | localStorage.setItem(key, JSON.stringify(value))
8 | console.log(`memoized ${fnname}(${args.map((x: any) => JSON.stringify(x)).join(", ")})`, value)
9 | return value
10 | } else {
11 | // console.log(`parsing`, val)
12 | return JSON.parse(val)
13 | }
14 | }
15 | }
16 |
17 |
18 | export const getQueryParams = () => {
19 | const urlParams = new URLSearchParams(window.location.search)
20 | const params: {[key: string]: any} = {}
21 | for (const [key, value] of urlParams.entries()) {
22 | params[key] = value
23 | }
24 | return params
25 | }
26 |
--------------------------------------------------------------------------------
/sae-viewer/src/App.tsx:
--------------------------------------------------------------------------------
1 | import "./App.css"
2 | import Feed from "./feed"
3 | import React from "react"
4 | import { Routes, Route, HashRouter } from "react-router-dom"
5 | import { AUTOENCODER_FAMILIES } from "./autoencoder_registry"
6 | import Welcome from "./welcome"
7 |
8 | function App() {
9 | return (
10 |   <HashRouter>
11 |     <div>
12 |       <Routes>
13 |         <Route path="/" element={<Welcome />} />
14 |         <Route path="/feature/:atom" element={<Feed />} />
15 |         {
16 |           Object.values(AUTOENCODER_FAMILIES).map((family) => {
17 |             let extra = '';
18 |             family.selectors.forEach((selector) => {
19 |               extra += `/${selector.key}/:${selector.key}`;
20 |             })
21 |             return <Route key={family.name} path={`/model/:subject/family/:family${extra}/feature/:atom`} element={<Feed />} />
22 |           })
23 |         }
24 |       </Routes>
25 |     </div>
26 |   </HashRouter>
27 | )
28 | }
29 |
30 | export default App
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/sae-viewer/src/feed.tsx:
--------------------------------------------------------------------------------
1 | import FeatureInfo from "./components/featureInfo"
2 | import React, { useEffect } from "react"
3 | import Welcome from "./welcome"
4 | import { Feature } from "./types"
5 | import FeatureSelect from "./components/featureSelect"
6 | import { useState } from "react"
7 | import { useParams, useNavigate, Link } from "react-router-dom"
8 |
9 | import { pathForFeature, DEFAULT_AUTOENCODER, AUTOENCODER_FAMILIES } from "./autoencoder_registry"
10 |
11 | export default function Feed() {
12 | const params = useParams();
13 | const navigate = useNavigate();
14 | let family = AUTOENCODER_FAMILIES[params.family || DEFAULT_AUTOENCODER.family];
15 | let feature: Feature = {
16 | // "layer": parseInt(params.layer),
17 | "atom": parseInt(params.atom),
18 | "autoencoder": family.get_ae(params),
19 | };
20 | console.log('feature', JSON.stringify(feature, null, 2))
21 |
22 | return (
23 |     <div>
24 |       <div>
25 |         <h1>
26 |           SAE viewer
27 |         </h1>
28 |         <FeatureSelect
29 |           init_feature={feature}
30 |           onFeatureChange={(f: Feature) => navigate(pathForFeature(f, {replace: true}))}
31 |           onFeatureSubmit={(f: Feature) => navigate(pathForFeature(f, {replace: true}))}
32 |         />
33 |       </div>
34 |       <div>
35 |         <FeatureInfo feature={feature} />
36 |       </div>
37 |     </div>
38 |
39 | )
40 | }
41 |
--------------------------------------------------------------------------------
/sae-viewer/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sae-viewer",
3 | "version": "0.1.67",
4 | "homepage": "https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html",
5 | "dependencies": {
6 | "@headlessui/react": "^1.7.8",
7 | "@headlessui/tailwindcss": "^0.1.2",
8 | "@types/d3-scale": "^4.0.3",
9 | "@types/lodash": "^4.14.194",
10 | "@types/react": "^18.0.37",
11 | "@types/react-dom": "^18.0.11",
12 | "d3-scale": "^4.0.2",
13 | "lodash": "^4.17.21",
14 | "plotly.js": "^2.31.0",
15 | "react": "^18.2.0",
16 | "react-dom": "^18.2.0",
17 | "react-plotly.js": "^2.6.0",
18 | "react-router-dom": "^6.10.0",
19 | "web-vitals": "^3.0.3"
20 | },
21 | "scripts": {
22 | "start": "rm -rf dist && rm -rf .parcel-cache && parcel src/index.html",
23 | "build": "parcel build src/index.html",
24 | "serve": "parcel serve src/index.html",
25 | "typecheck": "tsc -p ."
26 | },
27 | "eslintConfig": {
28 | "extends": [
29 | "react-app"
30 | ]
31 | },
32 | "alias": {
33 | "preact/jsx-dev-runtime": "preact/jsx-runtime"
34 | },
35 | "devDependencies": {
36 | "@observablehq/plot": "^0.6.5",
37 | "@parcel/transformer-typescript-tsc": "^2.8.3",
38 | "@parcel/validator-typescript": "^2.8.3",
39 | "buffer": "^5.7.1",
40 | "nodemon": "^2.0.22",
41 | "parcel": "^2.8.3",
42 | "preact": "^10.13.2",
43 | "process": "^0.11.10",
44 | "react-refresh": "0.10.0",
45 | "tailwindcss": "^3.2.4",
46 | "typescript": "^5.0.4"
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/histogram.tsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import Plot from 'react-plotly.js';
3 |
4 | // TODO get from data
5 | const BIN_WIDTH = 0.2;
6 | // # bins_fn = lambda lats: (lats / BIN_WIDTH).ceil().int()
7 | // bin_fn = lambda val: math.ceil(val / BIN_WIDTH)
8 | // bin_id_to_lower_bound = lambda xs: xs * BIN_WIDTH
9 |
10 | const HistogramDisplay = ({ data }) => {
11 | // min_bin = min(hist.keys())
12 | // max_bin = max(hist.keys())
13 | // ys = [hist.get(x, 0) for x in np.arange(min_bin, max_bin + 1)]
14 | // xs = np.arange(min_bin, max_bin + 2)
15 | // xs = bin_id_to_lower_bound(np.array(xs))
16 | const min_bin = Math.min(...Object.keys(data).map(Number));
17 | const max_bin = Math.max(...Object.keys(data).map(Number));
18 | const ys = Array.from({length: max_bin - min_bin + 1}, (_, i) => data[min_bin + i] || 0);
19 | let xs = Array.from({length: max_bin - min_bin + 2}, (_, i) => min_bin + i);
20 | xs = xs.map(x => x * BIN_WIDTH);
21 |
22 | const trace = {
23 | line: {shape: 'hvh'},
24 | mode: 'lines',
25 | type: 'scatter',
26 | x: xs,
27 | y: ys,
28 | fill: 'tozeroy',
29 | };
30 | const layout = {
31 | legend: {
32 | y: 0.5,
33 | font: {size: 16},
34 | traceorder: 'reversed',
35 | },
36 | yaxis: {
37 | type: 'log',
38 | autorange: true
39 | },
40 | margin: { l: 30, r: 0, b: 20, t: 0 },
41 | autosize: true,
42 | };
43 |
44 | return (
45 |     <Plot
46 |       data={[trace]}
47 |       layout={layout}
48 |       useResizeHandler={true}
49 |     />
50 | )
51 | }
52 |
53 | export default HistogramDisplay;
54 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/tooltip.tsx:
--------------------------------------------------------------------------------
1 | import React from "react"
2 |
3 | type Props = {
4 | content: React.ReactNode;
5 | tooltip: React.ReactNode;
6 | tooltipStyle?: React.CSSProperties;
7 | };
8 |
9 | function useOutsideClickAlerter(ref, fn) {
10 | React.useEffect(() => {
11 | /**
12 | * Alert if clicked on outside of element
13 | */
14 | function handleClickOutside(event) {
15 | if (ref.current && !ref.current.contains(event.target)) { fn() }
16 | }
17 | // Bind the event listener
18 | document.addEventListener("mousedown", handleClickOutside);
19 | return () => {
20 | // Unbind the event listener on clean up
21 | document.removeEventListener("mousedown", handleClickOutside);
22 | };
23 | }, [ref]);
24 | }
25 |
26 | const Tooltip: React.FunctionComponent<Props> = (props: Props) => {
27 | const [state, set_state] = React.useState({ force_open: false });
28 | const [hover, set_hover] = React.useState(false);
29 | const wrapperRef = React.useRef(null);
30 | useOutsideClickAlerter(wrapperRef, () => set_state({ force_open: false }));
31 | return (
32 |     <span ref={wrapperRef} style={{position: "relative"}}>
33 |       <span onClick={() => set_state({ force_open: !state.force_open })}
34 |         onMouseOver={() => set_hover(true)}
35 |         onMouseOut={() => setTimeout(() => set_hover(false), 100)}
36 |       >
37 |         {props.content}
38 |       </span>
39 |       <div style={{
40 |         display: (state.force_open || hover) ? "block" : "none",
41 |         position: "absolute", zIndex: 1, ...props.tooltipStyle }}>
42 |         {props.tooltip}
43 |       </div>
44 |     </span>
45 | );
46 | };
47 |
48 |
49 | export default Tooltip;
50 |
--------------------------------------------------------------------------------
/sparse_autoencoder/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def autoencoder_loss(
5 | reconstruction: torch.Tensor,
6 | original_input: torch.Tensor,
7 | latent_activations: torch.Tensor,
8 | l1_weight: float,
9 | ) -> torch.Tensor:
10 | """
11 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs])
12 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs])
13 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents])
14 | :param l1_weight: weight of L1 loss
15 | :return: loss (shape: [1])
16 | """
17 | return (
18 | normalized_mean_squared_error(reconstruction, original_input)
19 | + normalized_L1_loss(latent_activations, original_input) * l1_weight
20 | )
21 |
22 |
23 | def normalized_mean_squared_error(
24 | reconstruction: torch.Tensor,
25 | original_input: torch.Tensor,
26 | ) -> torch.Tensor:
27 | """
28 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs])
29 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs])
30 | :return: normalized mean squared error (shape: [1])
31 | """
32 | return (
33 | ((reconstruction - original_input) ** 2).mean(dim=1) / (original_input**2).mean(dim=1)
34 | ).mean()
35 |
36 |
37 | def normalized_L1_loss(
38 | latent_activations: torch.Tensor,
39 | original_input: torch.Tensor,
40 | ) -> torch.Tensor:
41 | """
42 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents])
43 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs])
44 | :return: normalized L1 loss (shape: [1])
45 | """
46 | return (latent_activations.abs().sum(dim=1) / original_input.norm(dim=1)).mean()
47 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/tokenHeatmap.tsx:
--------------------------------------------------------------------------------
1 | import React from "react"
2 | import { interpolateColor, Color, getInterpolatedColor, DEFAULT_COLORS, DEFAULT_BOUNDARIES, SequenceInfo } from '../types'
3 | import Tooltip from './tooltip'
4 |
5 | type Props = {
6 | info: SequenceInfo,
7 | colors?: Color[],
8 | boundaries?: number[]
9 | renderNewlines?: boolean,
10 | }
11 |
12 | function zip_sequence(sequence: SequenceInfo) {
13 | return sequence.tokens.map((token, idx) => ({
14 | token,
15 | highlight: idx === sequence.idx,
16 | activation: sequence.acts[idx],
17 | normalized_activation: sequence.normalized_acts ? sequence.normalized_acts[idx] : undefined
18 | }));
19 | }
20 |
21 | export default function TokenHeatmap({ info, colors = DEFAULT_COLORS, boundaries = DEFAULT_BOUNDARIES, renderNewlines }: Props) {
22 | //
23 | const zipped = zip_sequence(info)
24 | return (
25 |     <span>
26 |       {zipped.map(({ token, activation, normalized_activation, highlight }, i) => {
27 |         const color = getInterpolatedColor(colors, boundaries, normalized_activation || activation);
28 |         if (!renderNewlines) {
29 |           token = token.replace(/\n/g, '↵')
30 |         }
31 |         return <Tooltip
32 |           content={
33 |             <span
34 |               style={{
35 |                 backgroundColor: `rgb(${color.r}, ${color.g}, ${color.b})`,
36 |                 fontWeight: highlight ? 'bold' : 'normal',
37 |                 whiteSpace: 'pre-wrap',
38 |               }}
39 |             >
40 |               {token}
41 |             </span>
42 |           }
43 |           tooltip={<span>Activation: {activation.toFixed(2)}</span>}
44 |           key={i}
45 |         />
46 |       })}
47 |     </span>
48 | )
49 | }
50 |
--------------------------------------------------------------------------------
/sae-viewer/src/interpAPI.ts:
--------------------------------------------------------------------------------
1 | import {Feature, FeatureInfo} from './types';
2 | import {memoizeAsync} from "./utils"
3 |
4 | export const load_file_no_cache = async(path: string) => {
5 | const data = {
6 | path: path
7 | }
8 | const url = new URL("/load_az", window.location.href)
9 | url.port = '8000';
10 | return await (
11 | await fetch(url, {
12 | method: "POST", // or 'PUT'
13 | headers: {
14 | "Content-Type": "application/json",
15 | },
16 | body: JSON.stringify(data),
17 | })
18 | ).json()
19 |
20 | }
21 |
22 | export const load_file_az = async(path: string) => {
23 | const res = (
24 | await fetch(path, {
25 | method: "GET",
26 | mode: "cors",
27 | headers: {
28 | "Content-Type": "application/json",
29 | },
30 | })
31 | )
32 | if (!res.ok) {
33 | console.error(`HTTP error: ${res.status} - ${res.statusText}`);
34 | return;
35 | }
36 | return await res.json()
37 | }
38 |
39 |
40 | // export const load_file = memoizeAsync('load_file', load_file_no_cache)
41 | // export const load_file = window.location.host.indexOf('localhost:') === -1 ? load_file_az : load_file_no_cache;
42 | export const load_file = load_file_no_cache;
43 |
44 |
45 | export async function get_feature_info(feature: Feature, ablated?: boolean): Promise<FeatureInfo> {
46 | let load_fn = load_file_az;
47 | let prefix = "https://openaipublic.blob.core.windows.net/sparse-autoencoder/viewer"
48 | if (window.location.host.indexOf('localhost:') !== -1) {
49 | load_fn = load_file;
50 | prefix = "az://openaipublic/sparse-autoencoder/viewer"
51 | // prefix = az://oaialignment/interp/autoencoder-vis/ae
52 | }
53 |
54 | const ae = feature.autoencoder;
55 | const result = await load_fn(`${prefix}/${ae.subject}/${ae.path}/atoms/${feature.atom}${ablated ? '-ablated': ''}.json`)
56 | // console.log('result', result)
57 | return result
58 | }
59 |
--------------------------------------------------------------------------------
/sparse_autoencoder/paths.py:
--------------------------------------------------------------------------------
1 |
2 | def v1(location, layer_index):
3 | """
4 | Details:
5 | - Number of autoencoder latents: 32768
6 | - Number of training tokens: ~64M
7 | - Activation function: ReLU
8 | - L1 regularization strength: 0.01
9 | - Layer normed inputs: false
10 | - NeuronRecord files:
11 | `az://openaipublic/sparse-autoencoder/gpt2-small/{location}/collated_activations/{layer_index}/{latent_index}.json`
12 | """
13 | assert location in ["mlp_post_act", "resid_delta_mlp"]
14 | assert layer_index in range(12)
15 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}/autoencoders/{layer_index}.pt"
16 |
17 | def v4(location, layer_index):
18 | """
19 | Details:
20 | same as v1
21 | """
22 | assert location in ["mlp_post_act", "resid_delta_mlp"]
23 | assert layer_index in range(12)
24 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v4/autoencoders/{layer_index}.pt"
25 |
26 | def v5_32k(location, layer_index):
27 | """
28 | Details:
29 | - Number of autoencoder latents: 2**15 = 32768
30 | - Number of training tokens: TODO
31 | - Activation function: TopK(32)
32 | - L1 regularization strength: n/a
33 | - Layer normed inputs: true
34 | """
35 | assert location in ["resid_delta_attn", "resid_delta_mlp", "resid_post_attn", "resid_post_mlp"]
36 | assert layer_index in range(12)
37 | # note: it's actually 2**15 and 2**17 ~= 131k
38 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_32k/autoencoders/{layer_index}.pt"
39 |
40 | def v5_128k(location, layer_index):
41 | """
42 | Details:
43 | - Number of autoencoder latents: 2**17 = 131072
44 | - Number of training tokens: TODO
45 | - Activation function: TopK(32)
46 | - L1 regularization strength: n/a
47 | - Layer normed inputs: true
48 | """
49 | assert location in ["resid_delta_attn", "resid_delta_mlp", "resid_post_attn", "resid_post_mlp"]
50 | assert layer_index in range(12)
51 | # note: it's actually 2**15 and 2**17 ~= 131k
52 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_128k/autoencoders/{layer_index}.pt"
53 |
54 | # NOTE: we have larger autoencoders (up to 8M, with varying n and k) trained on layer 8 resid_post_mlp
55 | # we may release them in the future
56 |
--------------------------------------------------------------------------------
/sae-viewer/src/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
11 |
15 |
24 |
25 |
26 | SAE viewer
27 |
28 |
29 |
30 |
31 |
42 |
43 |
44 | You need to enable JavaScript to run this app.
45 |
46 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse autoencoders
2 |
3 | This repository hosts:
4 | - sparse autoencoders trained on the GPT2-small model's activations.
5 | - a visualizer for the autoencoders' features
6 |
7 | ### Install
8 |
9 | ```sh
10 | pip install git+https://github.com/openai/sparse_autoencoder.git
11 | ```
12 |
13 | ### Code structure
14 |
15 | See [sae-viewer](./sae-viewer/README.md) to see the visualizer code, hosted publicly [here](https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html).
16 |
17 | See [model.py](./sparse_autoencoder/model.py) for details on the autoencoder model architecture.
18 | See [train.py](./sparse_autoencoder/train.py) for autoencoder training code.
19 | See [paths.py](./sparse_autoencoder/paths.py) for more details on the available autoencoders.
20 |
21 | ### Example usage
22 |
23 | ```py
24 | import torch
25 | import blobfile as bf
26 | import transformer_lens
27 | import sparse_autoencoder
28 |
29 | # Extract neuron activations with transformer_lens
30 | model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
31 | device = next(model.parameters()).device
32 |
33 | prompt = "This is an example of a prompt that"
34 | tokens = model.to_tokens(prompt) # (1, n_tokens)
35 | with torch.no_grad():
36 | logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)
37 |
38 | layer_index = 6
39 | location = "resid_post_mlp"
40 |
41 | transformer_lens_loc = {
42 | "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
43 | "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
44 | "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
45 | "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
46 | "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
47 | }[location]
48 |
49 | with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f:
50 | state_dict = torch.load(f)
51 | autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
52 | autoencoder.to(device)
53 |
54 | input_tensor = activation_cache[transformer_lens_loc]
55 |
56 | input_tensor_ln = input_tensor
57 |
58 | with torch.no_grad():
59 | latent_activations, info = autoencoder.encode(input_tensor_ln)
60 | reconstructed_activations = autoencoder.decode(latent_activations, info)
61 |
62 | normalized_mse = (reconstructed_activations - input_tensor).pow(2).sum(dim=1) / (input_tensor).pow(2).sum(dim=1)
63 | print(location, normalized_mse)
64 | ```
65 |
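66 | The `latent_activations` tensor above has shape `(n_tokens, n_latents)`. As a minimal follow-on sketch (using only the variables from the example), the most strongly activating latents at the last token position can be listed with `topk`; the indices correspond to the feature ids browsable in the SAE viewer:
67 |
68 | ```py
69 | # indices and values of the 5 most active latents at the final token position
70 | top_values, top_indices = latent_activations[-1].topk(5)
71 | print(top_indices.tolist())
72 | print(top_values.tolist())
73 | ```
74 |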
--------------------------------------------------------------------------------
/sae-viewer/src/autoencoder_registry.tsx:
--------------------------------------------------------------------------------
1 | import { Autoencoder, Feature } from './types';
2 | import { Route } from "react-router-dom"
3 |
4 |
5 | export const GPT2_LAYER_FAMILY_32k = {
6 | subject: 'gpt2-small',
7 | name: 'v5_32k',
8 | label: 'n=32768, k=32, all locations',
9 | selectors: [
10 | {key: 'layer', label: 'Layer', values: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']},
11 | {key: 'location', label: 'Location', values: ['resid_post_attn', 'resid_post_mlp']},
12 | ],
13 | default_H: (H: {[key: string]: string}) => ({
14 | layer: H.layer || '8', location: H.location || 'resid_post_mlp'
15 | }),
16 | get_ae: (H: {[key: string]: string}) => ({
17 | subject: 'gpt2-small',
18 | family: 'v5_32k',
19 | H: {layer: H.layer, location: H.location},
20 | path: `v5_32k/layer_${H.layer}/${H.location}`,
21 | num_features: 32768,
22 | }),
23 | }
24 |
25 | export const GPT2_LAYER_FAMILY_128k = {
26 | subject: 'gpt2-small',
27 | name: 'v5_128k',
28 | label: 'n=131072, k=32, all locations',
29 | selectors: [
30 | {key: 'layer', label: 'Layer', values: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']},
31 | {key: 'location', label: 'Location', values: ['resid_post_attn', 'resid_post_mlp']},
32 | ],
33 | default_H: (H: {[key: string]: string}) => ({
34 | layer: H.layer || '8', location: H.location || 'resid_post_mlp'
35 | }),
36 | get_ae: (H: {[key: string]: string}) => ({
37 | subject: 'gpt2-small',
38 | family: 'v5_128k',
39 | H: {layer: H.layer, location: H.location},
40 | path: `v5_128k/layer_${H.layer}/${H.location}`,
41 | num_features: 131072,
42 | no_effects: true,
43 | }),
44 | }
45 |
46 | export const GPT2_NK_SWEEP = {
47 | subject: 'gpt2-small',
48 | name: 'v5_l8_postmlp',
49 | label: 'layer 8, resid post MLP, n/k sweep',
50 | selectors: [
51 | {key: 'num_features', label: 'Total features', values: ['2048', '8192', '32768', '131072', '524288', '2097152']},
52 | {key: 'num_active_features', label: 'Active features', values: ['8', '16', '32', '64', '128', '256', '512']},
53 | ],
54 | default_H: (H: {[key: string]: string}) => ({
55 | num_features: H.num_features || '2097152', num_active_features: H.num_active_features || '16'
56 | }),
57 | get_ae: (H: {[key: string]: string}) => ({
58 | subject: 'gpt2-small',
59 | family: 'v5_l8_postmlp',
60 | H: {num_features: H.num_features, num_active_features: H.num_active_features},
61 | path: `v5_l8_postmlp/n${H.num_features}/k${H.num_active_features}`,
62 | num_features: H.num_features,
63 | }),
64 | }
65 |
66 | export const GPT4_16m = {
67 | subject: 'gpt4',
68 | name: 'v5_latelayer_postmlp',
69 | label: 'n=16M',
70 | warning: 'Only 65536 features available. Activations shown on The Pile (uncopyrighted) instead of our internal training dataset.',
71 | selectors: [],
72 | default_H: (H: {[key: string]: string}) => ({}),
73 | get_ae: (H: {[key: string]: string}) => ({
74 | subject: 'gpt4',
75 | family: 'v5_latelayer_postmlp',
76 | H: {},
77 | path: `v5_latelayer_postmlp/n16777216/k256`,
78 | num_features: 65536,
79 | no_effects: true,
80 | }),
81 | }
82 |
83 | // export const DEFAULT_AUTOENCODER = GPT2_NK_SWEEP.get_ae(
84 | // GPT2_NK_SWEEP.default_H({})
85 | // );
86 | export const DEFAULT_AUTOENCODER = GPT4_16m.get_ae(
87 | GPT4_16m.default_H({})
88 | );
89 |
90 | export const AUTOENCODER_FAMILIES = Object.fromEntries(
91 | [
92 | GPT2_NK_SWEEP,
93 | GPT2_LAYER_FAMILY_32k,
94 | GPT2_LAYER_FAMILY_128k,
95 | GPT4_16m,
96 | ].map((family) => [family.name, family])
97 | );
98 |
99 | export const SUBJECT_MODELS = ['gpt2-small', 'gpt4'];
100 |
101 | export function pathForFeature(feature: Feature) {
102 | let res = `/model/${feature.autoencoder.subject}/family/${feature.autoencoder.family}`;
103 | // for (const [key, value] of Object.entries(feature.autoencoder.H)) {
104 | // res += `/${key}/${value}`;
105 | // }
106 | for (const selector of AUTOENCODER_FAMILIES[feature.autoencoder.family].selectors) {
107 | res += `/${selector.key}/${feature.autoencoder.H[selector.key]}`;
108 | }
109 | res += `/feature/${feature.atom}`;
110 | console.log('res', res)
111 | return res
112 | }
113 |
114 |
--------------------------------------------------------------------------------
/sae-viewer/src/types.ts:
--------------------------------------------------------------------------------
1 | import { scaleLinear } from "d3-scale"
2 | import { min, max, flatten } from "lodash"
3 |
4 | export type Autoencoder = {
5 | subject: string,
6 | num_features: number,
7 | family: string,
8 | H: {[key: string]: any},
9 | path: string,
10 | };
11 |
12 |
13 | export type Feature = {
14 | autoencoder: Autoencoder;
15 | atom: number;
16 | }
17 |
18 | export type TokenAndActivation = {
19 | token: string,
20 | activation: number
21 | normalized_activation?: number
22 | }
23 | export type TokenSequence = TokenAndActivation[]
24 |
25 | export type SequenceInfo = {
26 | density: number,
27 | doc_id: number,
28 | idx: number, // which act this document was selected for
29 | acts: number[],
30 | act: number,
31 | tokens: string[],
32 | token_ints: number[],
33 | normalized_acts?: number[],
34 | ablate_loss_diff?: number[],
35 | kl?: number[],
36 | top_downvote_tokens_logits?: string[][],
37 | top_downvotes_logits?: number[][],
38 | top_upvote_tokens_logits?: string[][],
39 | top_upvotes_logits?: number[][],
40 | top_downvote_tokens_probs?: string[][],
41 | top_downvotes_probs?: number[][],
42 | top_upvote_tokens_probs?: string[][],
43 | top_upvotes_probs?: number[][],
44 | }
45 |
46 | export function zip_sequence(sequence: SequenceInfo) {
47 | return sequence.tokens.map((token, idx) => ({
48 | token,
49 | highlight: idx === sequence.idx,
50 | activation: sequence.acts[idx],
51 | normalized_activation: sequence.normalized_acts ? sequence.normalized_acts[idx] : undefined
52 | }));
53 | }
54 |
55 | export type FeatureInfo = {
56 | density: number,
57 | mean_act: number,
58 | mean_act_squared: number,
59 | hist: {[key: number]: number},
60 | random: SequenceInfo[],
61 | top: SequenceInfo[],
62 | }
63 |
64 | export const normalizeSequences = (...sequences: SequenceInfo[][]) => {
65 | // console.log('sequences', sequences)
66 | let flattened: SequenceInfo[] = flatten(sequences)
67 | const maxActivation = Math.max(0, ...flattened.map((s) => Math.max(...s.acts)));
68 | const scaler = scaleLinear()
69 | // Even though we're only displaying positive activations, we still need to scale in a way that
70 | // accounts for the existence of negative activations, since our color scale includes them.
71 | .domain([0, maxActivation])
72 | .range([0, 1])
73 |
74 | sequences.map((seqs) => seqs.map((s) => {
75 | s.normalized_acts = s.acts.map((activation) => scaler(activation));
76 | }))
77 | }
78 |
79 | export const normalizeTokenActs = (...sequences: TokenSequence[][]) => {
80 | // console.log('sequences', sequences)
81 | let flattened: TokenAndActivation[] = flatten(flatten(sequences))
82 | // Replace all activations less than 0 in data.tokens with 0. This matches the format in the
83 | // top + random activation records displayed in the main grid.
84 | flattened = flattened.map(({token, activation}) => {
85 | return {
86 | token,
87 | activation: Math.max(activation, 0)
88 | }
89 | })
90 | const maxActivation = max(flattened.map((ta) => ta.activation)) || 0;
91 | const scaler = scaleLinear()
92 | // Even though we're only displaying positive activations, we still need to scale in a way that
93 | // accounts for the existence of negative activations, since our color scale includes them.
94 | .domain([0, maxActivation])
95 | .range([0, 1])
96 |
97 | return sequences.map((seq) => seq.map((tas) => tas.map(({ token, activation }) => ({
98 | token,
99 | activation,
100 | normalized_activation: scaler(activation),
101 | }))))
102 | }
103 |
104 | export type Color = {r: number, g: number, b: number};
105 | export function interpolateColor(color_l: Color, color_r: Color, value: number) {
106 | const color = {
107 | r: Math.round(color_l.r + (color_r.r - color_l.r) * value),
108 | g: Math.round(color_l.g + (color_r.g - color_l.g) * value),
109 | b: Math.round(color_l.b + (color_r.b - color_l.b) * value),
110 | }
111 | return color
112 | }
113 |
114 | export function getInterpolatedColor(colors: Color[], boundaries: number[], value: number) {
115 | const index = boundaries.findIndex((boundary) => boundary >= value)
116 | const colorIndex = Math.max(0, index - 1)
117 | const color_left = colors[colorIndex]
118 | const color_right = colors[colorIndex + 1]
119 | const boundary_left = boundaries[colorIndex]
120 | const boundary_right = boundaries[colorIndex + 1]
121 | const ratio = (value - boundary_left) / (boundary_right - boundary_left)
122 | const color = interpolateColor(color_left, color_right, ratio)
123 | return color
124 | }
125 |
126 | export const DEFAULT_COLORS = [
127 | { r: 255, g: 0, b: 0 },
128 | { r: 255, g: 255, b: 255 },
129 | { r: 0, g: 255, b: 0 },
130 | ]
131 | export const DEFAULT_BOUNDARIES = [
132 | -1, 0, 1
133 | ]
134 |
135 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/tokenAblationmap.tsx:
--------------------------------------------------------------------------------
1 | import React from "react"
2 | import { interpolateColor, Color, getInterpolatedColor, DEFAULT_COLORS, SequenceInfo } from '../types'
3 | import Tooltip from './tooltip'
4 | import { scaleLinear } from "d3-scale"
5 |
6 | type Props = {
7 | info: SequenceInfo,
8 | colors?: Color[],
9 | boundaries?: number[],
10 | renderNewlines?: boolean,
11 | }
12 |
13 | export const normalizeToUnitInterval = (arr: number[]) => {
14 | const max = Math.max(...arr);
15 | const min = Math.min(...arr);
16 | const max_abs = Math.max(Math.abs(max), Math.abs(min));
17 | const rescale = scaleLinear()
18 | // Even though we're only displaying positive activations, we still need to scale in a way that
19 | // accounts for the existence of negative activations, since our color scale includes them.
20 | .domain([-max_abs, max_abs])
21 | .range([-1, 1])
22 |
23 | return arr.map((x) => rescale(x));
24 | }
25 |
26 |
27 | export default function TokenAblationmap({ info, colors = DEFAULT_COLORS, renderNewlines }: Props) {
28 | //
29 | if (!info.ablate_loss_diff) {
30 | return <></>;
31 | }
32 | const lossDiffsNorm = normalizeToUnitInterval(info.ablate_loss_diff.map((x) => (-x)));
33 | return (
34 |
35 | {info.tokens.map((token, idx) => {
36 | const highlight = idx === info.idx;
37 | const loss_diff = (idx === 0) ? 0: info.ablate_loss_diff[idx-1];
38 | const kl = (idx === 0) ? 0: info.kl[idx-1];
39 | const activation = info.acts[idx];
40 | const top_downvotes = (idx === 0) ? [] : info.top_downvotes_logits[idx-1];
41 | const top_downvote_tokens = (idx === 0) ? [] : info.top_downvote_tokens_logits[idx-1];
42 | const top_upvotes = (idx === 0) ? [] : info.top_upvotes_logits[idx-1];
43 | const top_upvote_tokens = (idx === 0) ? [] : info.top_upvote_tokens_logits[idx-1];
44 | // const top_downvotes_weighted = (idx === 0) ? [] : info.top_downvotes_weighted[idx-1];
45 | // const top_downvote_tokens_weighted = (idx === 0) ? [] : info.top_downvote_tokens_weighted[idx-1];
46 | // const top_upvotes_weighted = (idx === 0) ? [] : info.top_upvotes_weighted[idx-1];
47 | // const top_upvote_tokens_weighted = (idx === 0) ? [] : info.top_upvote_tokens_weighted[idx-1];
48 | const top_downvotes_probs = (idx === 0) ? [] : info.top_downvotes_probs[idx-1];
49 | const top_downvote_tokens_probs = (idx === 0) ? [] : info.top_downvote_tokens_probs[idx-1];
50 | const top_upvotes_probs = (idx === 0) ? [] : info.top_upvotes_probs[idx-1];
51 | const top_upvote_tokens_probs = (idx === 0) ? [] : info.top_upvote_tokens_probs[idx-1];
52 | const color = getInterpolatedColor(colors, [-1, 0, 1], (idx === 0) ? 0 : lossDiffsNorm[idx-1]);
53 | if (!renderNewlines) {
54 | token = token.replace(/\n/g, '↵')
55 | }
56 | return
65 | {token}
66 |
67 | }
68 | tooltip={(idx <= info.idx) ? (prediction prior to ablated token)
:
69 | Loss diff: {loss_diff.toExponential(2)}
70 | KL(clean || ablated): {kl.toExponential(2)}
71 | Logit diffs:
72 |
73 |
74 |
75 | {
76 | ['', /*' (weighted)',*/ ' (probs)'].map((suffix, i) => {
77 | return
78 | Upvoted {suffix}
79 | Downvoted {suffix}
80 |
81 | })
82 | }
83 |
84 |
85 |
86 | {
87 | top_upvotes.map((upvote, j) => {
88 | const downvote = top_downvotes[j];
89 | return
90 | {upvote.toExponential(1)}
91 | {top_upvote_tokens[j]}
92 | {downvote.toExponential(1)}
93 | {top_downvote_tokens[j]}
94 | {top_upvotes_probs[j].toExponential(1)}
95 | {top_upvote_tokens_probs[j]}
96 | {top_downvotes_probs[j].toExponential(1)}
97 | {top_downvote_tokens_probs[j]}
98 |
99 | })
100 | }
101 |
102 |
103 |
}
104 | key={idx}
105 | />
106 | })}
107 |
108 | )
109 |
110 | }
111 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/featureSelect.tsx:
--------------------------------------------------------------------------------
1 | import FeatureInfo from "./featureInfo"
2 | import React, { useEffect, useState, FormEvent } from "react"
3 | import Welcome from "../welcome"
4 | import { Feature } from "../types"
5 |
6 |
7 | import { pathForFeature, DEFAULT_AUTOENCODER, SUBJECT_MODELS, AUTOENCODER_FAMILIES } from "../autoencoder_registry"
8 |
9 | type FeatureSelectProps = {
10 | init_feature: Feature,
11 | onFeatureChange?: (feature: Feature) => void,
12 | onFeatureSubmit: (feature: Feature) => void,
13 | }
14 |
15 | export default function FeatureSelect({init_feature, onFeatureChange, onFeatureSubmit, show_go}: FeatureSelectProps) {
16 | let [feature, setFeature] = useState(init_feature);
17 | let family = AUTOENCODER_FAMILIES[feature.autoencoder.family];
18 | let [warningAcknowledged, setWarningAcknowledged] = useState<boolean>(localStorage.getItem('warningAcknowledged') === 'true');
19 | let changeFeature = (feature: Feature) => {
20 | onFeatureChange && onFeatureChange(feature);
21 | setFeature(feature);
22 | }
23 | // console.log('features', feature.autoencoder.num_features)
24 | const feelingLuckySubmit = () => {
25 | const atom = Math.floor(Math.random() * feature.autoencoder.num_features);
26 | const random_feature: Feature = {atom, autoencoder: feature.autoencoder}
27 | setFeature(random_feature);
28 | onFeatureSubmit(random_feature);
29 | return false
30 | }
31 | const acknowledgeWarning = () => {
32 | localStorage.setItem('warningAcknowledged', 'true');
33 | setWarningAcknowledged(true);
34 | }
35 |
36 | if (!warningAcknowledged) {
37 | return (
38 |     <button
39 |       onClick={acknowledgeWarning}
40 |       className="border border-gray-300 rounded-md p-2"
41 |     >Note: by clicking this button you acknowledge that the content of the documents are taken randomly from the internet, and may contain offensive or inappropriate content.</button>
42 | )
43 | }
44 |
45 | return (
46 | <>
47 |
48 |
49 | Subject model {" "}
50 | {
51 | let family = Object.values(AUTOENCODER_FAMILIES).find((family) => (family.subject === e.target.value));
52 | changeFeature({
53 | atom: 0, autoencoder: family.get_ae(family.default_H(feature.autoencoder.H))
54 | })
55 | }}>
56 | {Object.values(SUBJECT_MODELS).map((subject_model) => (
57 | {subject_model}
58 | ))}
59 |
60 |
61 | {" "}
62 |
63 |
64 | Autoencoder {" "}
65 | family {" "}
66 | {
67 | let family = AUTOENCODER_FAMILIES[e.target.value];
68 | changeFeature({
69 | atom: 0, autoencoder: family.get_ae(family.default_H(feature.autoencoder.H))
70 | })
71 | }}>
72 | {Object.values(AUTOENCODER_FAMILIES).filter((family) => (family.subject === feature.autoencoder.subject)).map((family) => (
73 | {AUTOENCODER_FAMILIES[family.name].label}
74 | )) }
75 |
76 |
77 | {
78 | AUTOENCODER_FAMILIES[feature.autoencoder.family].selectors.map((selector) => (
79 |
80 | {" "}
81 | {selector.label || selector.key}
82 | {
84 | let family = AUTOENCODER_FAMILIES[feature.autoencoder.family];
85 | changeFeature({
86 | atom: 0, autoencoder: family.get_ae({...feature.autoencoder.H, [selector.key]: e.target.value})
87 | })
88 | }}
89 | >
90 | {selector.values.map((value) => (
91 | {value}
92 | ))}
93 |
94 |
95 | ))
96 | }
97 | {
98 | family.warning ? Note: {family.warning} : null
99 | }
100 |
101 |
102 | Feature
103 | (!isNaN(parseInt(e.target.value))) && changeFeature({...feature, atom: parseInt(e.target.value)})}
111 | className="border border-gray-300 rounded-md p-2"
112 | />
113 | {
114 | show_go &&
115 | {
118 | e.preventDefault()
119 | onFeatureSubmit(feature)
120 | return false
121 | }}
122 | className="border border-gray-300 rounded-md p-2"
123 | style={{ width: 200 }}
124 | disabled={!warningAcknowledged}
125 | >
126 | Go to feature {feature.atom}
127 |
128 | }
129 |
130 |
135 | I'm feeling lucky
136 |
137 |
138 |
139 |
140 |
141 |
142 | >
143 | )
144 | }
145 |
--------------------------------------------------------------------------------
/sae-viewer/src/components/featureInfo.tsx:
--------------------------------------------------------------------------------
1 | import React, { useEffect, useState, useRef } from "react"
2 | import { normalizeSequences, SequenceInfo, Feature, FeatureInfo } from "../types"
3 | import TokenHeatmap from "./tokenHeatmap";
4 | import TokenAblationmap from "./tokenAblationmap";
5 | import Histogram from "./histogram"
6 | import Tooltip from "./tooltip"
7 |
8 | import {get_feature_info} from "../interpAPI"
9 |
10 | export default ({ feature }: {feature: Feature}) => {
11 | const [data, setData] = useState(null as FeatureInfo | null)
12 | const [showingMore, setShowingMore] = useState({})
13 | const [renderNewlines, setRenderNewlines] = useState(false)
14 | const [isLoading, setIsLoading] = useState(true)
15 | const [got_error, setError] = useState(null)
16 | const currentFeatureRef = useRef(feature);
17 |
18 |
19 | useEffect(() => {
20 | async function fetchData() {
21 | setIsLoading(true)
22 | try {
23 | currentFeatureRef.current = feature; // Update current feature in ref on each effect run
24 | const result = await get_feature_info(feature)
25 | if (currentFeatureRef.current !== feature) {
26 | return;
27 | }
28 | normalizeSequences(result.top, result.random)
29 | result.top.sort((a, b) => b.act - a.act);
30 | setData(result)
31 | setIsLoading(false)
32 | setError(null);
33 | } catch (e) {
34 | setError(e);
35 | }
36 | try {
37 | const result = await get_feature_info(feature, true)
38 | if (currentFeatureRef.current !== feature) {
39 | return;
40 | }
41 | normalizeSequences(result.top, result.random)
42 | result.top.sort((a, b) => b.act - a.act);
43 | setData(result)
44 | setIsLoading(false)
45 | setError(null);
46 | } catch (e) {
47 | setError('Note: ablation effects data not available for this model');
48 | }
49 | }
50 | fetchData()
51 | }, [feature])
52 |
53 | if (isLoading) {
54 | return (
55 |
56 |
57 |
loading top dataset examples
58 | {
59 | got_error ?
Error loading data: {got_error} : null
60 | }
61 |
62 | )
63 | }
64 | if (!data) {
65 | throw new Error('no data. this should not happen.')
66 | }
67 |
68 | const all_sequences = []
69 | all_sequences.push({
70 | // label: '[0, 1] (Random)',
71 | label: 'Random positive activations',
72 | sequences: data.random,
73 | default_show: 5,
74 | })
75 | all_sequences.push({
76 | // label: '[0.999, 1] (Top quantile, sorted. 50 of 50000)',
77 | label: 'Top activations',
78 | sequences: data.top,
79 | default_show: 5,
80 | })
81 |
82 | // const activations = data.top_activations;
83 | return (
84 |
85 | {
86 | got_error ?
{got_error} : null
87 | }
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 | 0]'}/>
96 |
97 | {data.density.toExponential(2)}
98 |
99 |
100 |
101 |
102 | {data.mean_act ? data.mean_act.toExponential(2) : 'data not available'}
103 |
104 |
105 | E[a2 ]>}/>
106 | {data.mean_act_squared ? data.mean_act_squared.toExponential(2): 'data not available'}
107 |
108 |
109 | E[a3 ]/(E[a2 ])1.5 >}/>
110 | {data.skew ? data.skew.toExponential(2) : 'data not available'}
111 |
112 |
113 | E[a4 ]/(E[a2 ])2 >}/>
114 | {data.kurtosis ? data.kurtosis.toExponential(2) : 'data not available'}
115 |
116 |
117 |
118 |
119 | {
120 | all_sequences.map(({label, sequences, default_show}, idx) => {
121 | // console.log('sequences', sequences)
122 | const n_show = showingMore[label] ? sequences.length : default_show;
123 | return (
124 |
125 |
126 | {label}
127 | setShowingMore({...showingMore, [label]: !showingMore[label]})}>
129 | {showingMore[label] ? 'show less' : 'show more'}
130 |
131 | setRenderNewlines(!renderNewlines)}>
133 | {renderNewlines ? 'collapse newlines' : 'show newlines'}
134 |
135 |
136 |
137 |
138 |
139 | Doc ID Token Activation Activations
140 | {sequences.length && sequences[0].ablate_loss_diff && Effects }
141 |
142 |
143 |
144 | {sequences.slice(0, n_show).map((sequence, i) => (
145 |
146 | {sequence.doc_id} {sequence.idx} {sequence.act.toFixed(2)}
147 |
148 |
149 |
150 | {
151 | sequence.ablate_loss_diff &&
152 |
153 |
154 |
155 | }
156 |
157 | ))}
158 |
159 |
160 |
161 | )
162 | })
163 | }
164 |
165 | )
166 | }
167 |
--------------------------------------------------------------------------------
/sae-viewer/src/App.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
5 | select {
6 | margin: 5px;
7 | }
8 |
9 | :root {
10 | --secondary-color: #0d978b;
11 | --accent-color: #efefef;
12 | }
13 |
14 | table.activations-table {
15 | border: 1px solid gray;
16 | border-radius: 3px;
17 | border-spacing: 0;
18 | }
19 | table.activations-table td, table.activations-table th {
20 | border-bottom: 1px solid gray;
21 | border-right: 1px solid gray;
22 | border-left: 1px solid gray;
23 | }
24 | table.activations-table tr:last-child > td {
25 | border-bottom: none;
26 | }
27 |
28 | .full-width{
29 | width: 100vw;
30 | position: relative;
31 | margin-left: -50vw;
32 | left: 50%;
33 | }
34 |
35 | .App {
36 | text-align: center;
37 | }
38 |
39 | .center {
40 | text-align: center;
41 | }
42 |
43 | .App-logo {
44 | height: 40vmin;
45 | pointer-events: none;
46 | }
47 |
48 | @media (prefers-reduced-motion: no-preference) {
49 | .App-logo {
50 | animation: App-logo-spin infinite 20s linear;
51 | }
52 | }
53 |
54 | .App h1 {
55 | font-size: 1.75rem;
56 | }
57 |
58 | .App-article {
59 | background-color: #282c34;
60 | min-height: 100vh;
61 | display: flex;
62 | flex-direction: column;
63 | align-items: center;
64 | justify-content: center;
65 | font-size: calc(10px + 2vmin);
66 | color: white;
67 | }
68 |
69 | .App-link {
70 | color: #61dafb;
71 | }
72 |
73 | @keyframes App-logo-spin {
74 | from {
75 | transform: rotate(0deg);
76 | }
77 | to {
78 | transform: rotate(360deg);
79 | }
80 | }
81 |
82 |
83 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
84 | /* Structure
85 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
86 |
87 | body {
88 | margin: 0;
89 | padding: 0 1em;
90 | font-size: 12pt;
91 | }
92 |
93 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
94 | /* Typography
95 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
96 |
97 | h1 {
98 | font-size: 24pt;
99 | font-weight: 500;
100 | padding: 1em 0 0;
101 | display: block;
102 | color: #000;
103 | }
104 | h3 { padding: 0 0; }
105 | h2 { padding: 1em 0 0.5em 0; }
106 | h4, h5 {
107 | text-transform: uppercase;
108 | margin: 1em 0;
109 | justify-tracks: space-between;
110 | font-family: var(--sans-serif);
111 | font-size: 12pt;
112 | font-weight: 600;
113 | }
114 | h2, h3 { font-weight: 500; font-style: italic; }
115 | subtitle {
116 | color: #555;
117 | font-size: 18pt;
118 | font-style: italic;
119 | padding: 0;
120 | display: block;
121 | margin-bottom: 1em
122 | }
123 |
124 | a {
125 | transition: all .05s ease-in-out;
126 | color: #5c60c3 !important;
127 | font-style: normal;
128 | }
129 | a:hover { color: var(--accent-color)!important; }
130 | code, pre { color: var(--inline-code-color);
131 | background-color: #eee; border-radius: 3px; }
132 | pre { padding: 1em; margin: 2em 0; }
133 | code { padding: 0.3em; }
134 | .text-secondary, h3, h5 { color: var(--secondary-color); }
135 | .text-primary, h2,h4 { color: var(--primary-color); }
136 |
137 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
138 | /* Images
139 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
140 |
141 | img#logo {
142 | width: 50%;
143 | margin: 3em 0 0
144 | }
145 |
146 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
147 | /* Alerts */
148 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
149 |
150 | .alert {
151 | font-weight: 600;
152 | font-style: italic;
153 | display: block;
154 | background-color: #fff7f7;
155 | padding: 1em;
156 | margin: 0;
157 | border-radius: 5px;
158 | color: #f25555
159 | }
160 | .alert.cool {
161 | background-color: #f3f0fc;
162 | color: #7155cf;
163 | }
164 | .flash-alert {
165 | display: inline-block;
166 | transition: ease-in-out 1s;
167 | font-size: 14pt;
168 | margin: 1em 0;
169 | padding-top: 0.5em;
170 | }
171 | .flash-alert.success {
172 | color: #000;
173 | }
174 | .flash-alert.failure {
175 | color: red;
176 | }
177 | .flash-alert.hidden {
178 | display: none;
179 | }
180 |
181 |
182 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
183 | /* Sidenotes & Superscripts */
184 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
185 |
186 | body { counter-reset: count; }
187 | p { whitespace: nowrap; }
188 |
189 | /* Different behavior if the screen is too
190 | narrow to show a sidenote on the side. */
191 |
192 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
193 | /* Buttons */
194 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
195 |
196 | @media print {
197 | a.btn, button {
198 | display: none!important
199 | }
200 | }
201 |
202 | @media screen {
203 | a.btn, button {
204 | border-radius: 3px;
205 | color: #000 !important;
206 | text-decoration: none !important;
207 | font-size: 11pt;
208 | border: 1px solid #000;
209 | padding: 0.5em 1em;
210 | font-family: -apple-system,
211 | BlinkMacSystemFont,
212 | "avenir next",
213 | avenir,
214 | helvetica,
215 | "helvetica neue",
216 | ubuntu,
217 | roboto,
218 | noto,
219 | "segoe ui",
220 | arial,
221 | sans-serif !important;
222 | background: #fff;
223 | font-weight: 500;
224 | transition: all .05s ease-in-out,box-shadow-color .025s ease-in-out;
225 | display: inline-block;
226 | }
227 |
228 | a.btn:hover, button:hover {
229 | cursor: pointer;
230 | }
231 | a.btn:active, button.active, button:active {
232 | border: 1px solid;
233 | }
234 | a.btn.small,button.small {
235 | border: 1px solid #000;
236 | padding: .6em 1em;
237 | font-weight: 500
238 | }
239 | a.btn.small:hover,button.small:hover {
240 | }
241 | a.btn.small:active,button.small:active {
242 | }
243 | }
244 |
245 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
246 | /* Blockquotes & Epigraphs
247 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
248 |
249 | blockquote {
250 | margin: 1em;
251 | }
252 | div>blockquote>p {
253 | font-size: 13pt;
254 | color: #555;
255 | font-style: normal!important;
256 | margin: 0;
257 | padding: 1em 0 1.5em
258 | }
259 | blockquote > blockquote {
260 | padding: 0.5em 2em 1em 1.5em !important;
261 | }
262 |
263 | blockquote > blockquote,
264 | blockquote > blockquote > p {
265 | font-size: 14pt;
266 | padding: 0;
267 | margin: 0;
268 | text-align: center;
269 | font-style: italic;
270 | color: var(--epigraph-color);
271 | }
272 | blockquote footer {
273 | font-size: 12pt;
274 | text-align: inherit;
275 | display: block;
276 | font-style: normal;
277 | margin: 1em;
278 | color: #aaa;
279 | }
280 |
--------------------------------------------------------------------------------
/sparse_autoencoder/model.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Any
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def LN(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
9 | mu = x.mean(dim=-1, keepdim=True)
10 | x = x - mu
11 | std = x.std(dim=-1, keepdim=True)
12 | x = x / (std + eps)
13 | return x, mu, std
14 |
15 |
16 | class Autoencoder(nn.Module):
17 | """Sparse autoencoder
18 |
19 | Implements:
20 | latents = activation(encoder(x - pre_bias) + latent_bias)
21 | recons = decoder(latents) + pre_bias
22 | """
23 |
24 | def __init__(
25 | self, n_latents: int, n_inputs: int, activation: Callable = nn.ReLU(), tied: bool = False,
26 | normalize: bool = False
27 | ) -> None:
28 | """
29 | :param n_latents: dimension of the autoencoder latent
30 | :param n_inputs: dimensionality of the original data (e.g. residual stream, number of MLP hidden units)
31 | :param activation: activation function
32 | :param tied: whether to tie the encoder and decoder weights
33 | """
34 | super().__init__()
35 |
36 | self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
37 | self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
38 | self.latent_bias = nn.Parameter(torch.zeros(n_latents))
39 | self.activation = activation
40 | if tied:
41 | self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
42 | else:
43 | self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
44 | self.normalize = normalize
45 |
46 | self.stats_last_nonzero: torch.Tensor
47 | self.latents_activation_frequency: torch.Tensor
48 | self.latents_mean_square: torch.Tensor
49 | self.register_buffer("stats_last_nonzero", torch.zeros(n_latents, dtype=torch.long))
50 | self.register_buffer(
51 | "latents_activation_frequency", torch.ones(n_latents, dtype=torch.float)
52 | )
53 | self.register_buffer("latents_mean_square", torch.zeros(n_latents, dtype=torch.float))
54 |
55 | def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor:
56 | """
57 | :param x: input data (shape: [batch, n_inputs])
58 | :param latent_slice: slice of latents to compute
59 | Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
60 | :return: autoencoder latents before activation (shape: [batch, n_latents])
61 | """
62 | x = x - self.pre_bias
63 | latents_pre_act = F.linear(
64 | x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
65 | )
66 | return latents_pre_act
67 |
68 | def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
69 | if not self.normalize:
70 | return x, dict()
71 | x, mu, std = LN(x)
72 | return x, dict(mu=mu, std=std)
73 |
74 | def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
75 | """
76 | :param x: input data (shape: [batch, n_inputs])
77 | :return: autoencoder latents (shape: [batch, n_latents])
78 | """
79 | x, info = self.preprocess(x)
80 | return self.activation(self.encode_pre_act(x)), info
81 |
82 | def decode(self, latents: torch.Tensor, info: dict[str, Any] | None = None) -> torch.Tensor:
83 | """
84 | :param latents: autoencoder latents (shape: [batch, n_latents])
85 | :return: reconstructed data (shape: [batch, n_inputs])
86 | """
87 | ret = self.decoder(latents) + self.pre_bias
88 | if self.normalize:
89 | assert info is not None
90 | ret = ret * info["std"] + info["mu"]
91 | return ret
92 |
93 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
94 | """
95 | :param x: input data (shape: [batch, n_inputs])
96 | :return: autoencoder latents pre activation (shape: [batch, n_latents])
97 | autoencoder latents (shape: [batch, n_latents])
98 | reconstructed data (shape: [batch, n_inputs])
99 | """
100 | x, info = self.preprocess(x)
101 | latents_pre_act = self.encode_pre_act(x)
102 | latents = self.activation(latents_pre_act)
103 | recons = self.decode(latents, info)
104 |
105 | # set all indices of self.stats_last_nonzero where (latents != 0) to 0
106 | self.stats_last_nonzero *= (latents == 0).all(dim=0).long()
107 | self.stats_last_nonzero += 1
108 |
109 | return latents_pre_act, latents, recons
110 |
111 | @classmethod
112 | def from_state_dict(
113 | cls, state_dict: dict[str, torch.Tensor], strict: bool = True
114 | ) -> "Autoencoder":
115 | n_latents, d_model = state_dict["encoder.weight"].shape
116 |
117 | # Retrieve activation
118 | activation_class_name = state_dict.pop("activation", "ReLU")
119 | activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU)
120 | normalize = activation_class_name == "TopK" # NOTE: hacky way to determine if normalization is enabled
121 | activation_state_dict = state_dict.pop("activation_state_dict", {})
122 | if hasattr(activation_class, "from_state_dict"):
123 | activation = activation_class.from_state_dict(
124 | activation_state_dict, strict=strict
125 | )
126 | else:
127 | activation = activation_class()
128 | if hasattr(activation, "load_state_dict"):
129 | activation.load_state_dict(activation_state_dict, strict=strict)
130 |
131 | autoencoder = cls(n_latents, d_model, activation=activation, normalize=normalize)
132 | # Load remaining state dict
133 | autoencoder.load_state_dict(state_dict, strict=strict)
134 | return autoencoder
135 |
136 | def state_dict(self, destination=None, prefix="", keep_vars=False):
137 | sd = super().state_dict(destination, prefix, keep_vars)
138 | sd[prefix + "activation"] = self.activation.__class__.__name__
139 | if hasattr(self.activation, "state_dict"):
140 | sd[prefix + "activation_state_dict"] = self.activation.state_dict()
141 | return sd
142 |
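
A minimal usage sketch of the `Autoencoder` class above (toy sizes chosen purely for illustration; the released checkpoints are far larger):

```python
import torch
import torch.nn as nn

from sparse_autoencoder import Autoencoder

# toy dimensions, for illustration only
ae = Autoencoder(n_latents=512, n_inputs=64, activation=nn.ReLU())
x = torch.randn(8, 64)

latents_pre_act, latents, recons = ae(x)   # full forward pass
assert latents.shape == (8, 512) and recons.shape == x.shape

latents, info = ae.encode(x)               # encode / decode separately
recons = ae.decode(latents, info)
```
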
143 |
144 | class TiedTranspose(nn.Module):
145 | def __init__(self, linear: nn.Linear):
146 | super().__init__()
147 | self.linear = linear
148 |
149 | def forward(self, x: torch.Tensor) -> torch.Tensor:
150 | assert self.linear.bias is None
151 | return F.linear(x, self.linear.weight.t(), None)
152 |
153 | @property
154 | def weight(self) -> torch.Tensor:
155 | return self.linear.weight.t()
156 |
157 | @property
158 | def bias(self) -> torch.Tensor:
159 | return self.linear.bias
160 |
161 |
162 | class TopK(nn.Module):
163 | def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
164 | super().__init__()
165 | self.k = k
166 | self.postact_fn = postact_fn
167 |
168 | def forward(self, x: torch.Tensor) -> torch.Tensor:
169 | topk = torch.topk(x, k=self.k, dim=-1)
170 | values = self.postact_fn(topk.values)
171 | # make all other values 0
172 | result = torch.zeros_like(x)
173 | result.scatter_(-1, topk.indices, values)
174 | return result
175 |
176 | def state_dict(self, destination=None, prefix="", keep_vars=False):
177 | state_dict = super().state_dict(destination, prefix, keep_vars)
178 | state_dict.update({prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__})
179 | return state_dict
180 |
181 | @classmethod
182 | def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "TopK":
183 | k = state_dict["k"]
184 | postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]()
185 | return cls(k=k, postact_fn=postact_fn)
186 |
187 |
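
A quick illustration of the `TopK` activation above: only the k largest pre-activations survive, and the post-activation function (ReLU by default) is applied to them.

```python
import torch

from sparse_autoencoder.model import TopK

topk = TopK(k=2)
x = torch.tensor([[1.0, -3.0, 4.0, 0.5]])
print(topk(x))  # tensor([[1., 0., 4., 0.]]) -- everything outside the top 2 is zeroed
```
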
188 | ACTIVATIONS_CLASSES = {
189 | "ReLU": nn.ReLU,
190 | "Identity": nn.Identity,
191 | "TopK": TopK,
192 | }
193 |
--------------------------------------------------------------------------------
/sae-viewer/src/welcome.tsx:
--------------------------------------------------------------------------------
1 | import React from "react"
2 | import { useState, FormEvent } from "react"
3 | import { useNavigate } from "react-router-dom"
4 | import { Feature } from "./types"
5 | import FeatureSelect from "./components/featureSelect"
6 | import { pathForFeature, DEFAULT_AUTOENCODER, AUTOENCODER_FAMILIES } from "./autoencoder_registry"
7 |
8 | export default function Welcome() {
9 | const navigate = useNavigate()
10 |
11 | const GPT4_ATOMS_PER_SHARD = 1024;
12 | const displayFeatures = [
13 | /**************
14 | /* well explained + interesting
15 | ***************/
16 | {heading: 'GPT-4', heading_type: 'h4', feature: null, label: ''},
17 | {feature: {atom: 62 * GPT4_ATOMS_PER_SHARD + 53, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
18 | label: "humans have flaws", description: "descriptions of how humans are flawed"},
19 | {feature: {atom: 25 * GPT4_ATOMS_PER_SHARD + 8, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
20 | label: "police reports, especially child safety", description: "safety incidents especially related to children"},
21 | {feature: {atom: 9 * GPT4_ATOMS_PER_SHARD + 44, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
22 | label: "price changes", description: "ends of phrases describing commodity/equity price changes"},
23 | {feature: {atom: 17 * GPT4_ATOMS_PER_SHARD + 33, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
24 | label: "ratification (multilingual)", description: "ratification (multilingual)"},
25 | {feature: {atom: 3 * GPT4_ATOMS_PER_SHARD + 421, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
26 | label: "would [...]", description: "conditionals (things that would be true)"},
27 | {feature: {atom: 63 * GPT4_ATOMS_PER_SHARD + 8, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
28 | label: "identification documents (multilingual)", description: "identification documents (multilingual)"},
29 | {feature: {atom: 0 * GPT4_ATOMS_PER_SHARD + 14, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
30 | label: "lightly incremented timestamps", description: "timestamps being lightly incremented with recurring formats"},
31 | {heading: 'Technical knowledge', heading_type: 'h3', feature: null, label: ''},
32 | {feature: {atom: 40 * GPT4_ATOMS_PER_SHARD + 42, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
33 | label: "machine learning training logs", description: "machine learning training logs"},
34 | {feature: {atom: 12 * GPT4_ATOMS_PER_SHARD + 47, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
35 | label: "onclick/onchange = function(this)", description: "onclick/onchange = function(this)"},
36 | {feature: {atom: 54 * GPT4_ATOMS_PER_SHARD + 23, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
37 | label: "edges (graph theory) and related concepts", description: "edges (graph theory) and related concepts"},
38 | {feature: {atom: 56 * GPT4_ATOMS_PER_SHARD + 12, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
39 | label: "algebraic rings", description: "algebraic rings"},
40 | {feature: {atom: 28 * GPT4_ATOMS_PER_SHARD + 47, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
41 | label: "adenosine/dopamine receptors", description: "adenosine/dopamine receptors"},
42 | {feature: {atom: 2 * GPT4_ATOMS_PER_SHARD + 601, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})},
43 | label: "blockchain vibes", description: "blockchain vibes"},
44 |
45 |
46 | {heading: 'GPT-2 small', heading_type: 'h4', feature: null, label: ''},
47 | {feature: {atom: 488432, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
48 | num_features: '2097152', num_active_features: '8'
49 | })}, label: "rhetorical questions", description: "rhetorical questions"},
50 | {feature: {atom: 2088200, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
51 | num_features: '2097152', num_active_features: '8'
52 | })}, label: "counting human casualties", description: "counting human casualties"},
53 | {feature: {atom: 1621560, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
54 | num_features: '2097152', num_active_features: '8'
55 | })}, label: "X and Y phrases", description: "X and -> Y"},
56 | {feature: {atom: 733, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
57 | num_features: '32768', num_active_features: '8'
58 | })}, label: "Patrick/Patty surname predictor", description: "Predicts surnames after Patrick"},
59 | {feature: {atom: 64464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
60 | num_features: '131072', num_active_features: '32'
61 | })}, label: "things that are unknown", description: "things that are unknown"},
62 | {feature: {atom: 56907, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ // similar to 33248
63 | num_features: '131072', num_active_features: '32'
64 | })}, label: "words in quotes", description: "predicts words in quotes"},
65 | {feature: {atom: 1605835, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
66 | num_features: '2097152', num_active_features: '8'
67 | })}, label: "these/those responsible things", description: "these/those, in a phrase where something is responsible for something"},
68 | {feature: {atom: 8040, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
69 | num_features: '8192', num_active_features: '32'
70 | })}, label: "2018 natural disasters", description: "2018 natural disasters"},
71 | {feature: {atom: 21464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
72 | num_features: '131072', num_active_features: '32'
73 | })}, label: "addition in code", description: "addition in code"},
74 | {feature: {atom: 66232, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
75 | num_features: '131072', num_active_features: '32'
76 | })}, label: "function application", description: "function application"},
77 | {feature: {atom: 64464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
78 | num_features: '131072', num_active_features: '32'
79 | })}, label: "unclear/hidden things", description: "unclear/hidden things (top only)"},
80 | {feature: {atom: 10423, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
81 | num_features: '131072', num_active_features: '32'
82 | })}, label: "what the ...", description: "[who/what/when/where/why/how] the"},
83 | {heading: 'Safety relevant features (found via attribution methods)', heading_type: 'h3', feature: null, label: ''},
84 | {feature: {atom: 64840, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
85 | num_features: '131072', num_active_features: '32'
86 | })}, label: "profanity (1)", description: "activates in order to output profanity"},
87 | {feature: {atom: 104813, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
88 | num_features: '131072', num_active_features: '32'
89 | })}, label: "profanity (2)", description: "activates on profanity"},
90 | {feature: {atom: 101090, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
91 | num_features: '131072', num_active_features: '32'
92 | })}, label: "profanity (3)", description: "activates on 'fucking' (profane, not sexual contexts)"},
93 | {feature: {atom: 72185, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
94 | num_features: '131072', num_active_features: '32'
95 | })}, label: "erotic content", description: "erotic content"},
96 | {feature: {atom: 69134, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
97 | num_features: '131072', num_active_features: '32'
98 | })}, label: "[content warning] sexual abuse", description: "sexual abuse"},
99 | // {feature: {atom: 2, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({
100 | // num_features: '2097152', num_active_features: '8'
101 | // })}, label: "things being brought", description: "bring * -> together/back"},
102 | ]
103 |
104 | let [feature, setFeature] = useState({
105 | atom: 0, autoencoder: DEFAULT_AUTOENCODER
106 | })
107 | const handleClick = (click_feature: Feature) => {
108 | navigate(pathForFeature(click_feature))
109 | }
110 |
111 | return (
112 | <div>
113 | <p>Welcome! This is a viewer for sparse autoencoder features trained in this paper.</p>
114 | <p>Pick a feature:</p>
115 | <FeatureSelect
116 | feature={feature}
117 | onFeatureChange={(f: Feature) => setFeature(f)}
118 | onFeatureSubmit={(f: Feature) => navigate(pathForFeature(f))}
119 | show_go={true}
120 | />
121 |
122 | <div>
123 | <h3>Interesting features:</h3>
124 |
125 | <div>
126 |
127 |
128 | {displayFeatures.map(({ heading, heading_type, feature, label, description }, j) => (
129 | heading ?
130 | <div key={j}>{React.createElement(heading_type, {}, heading)}</div>
131 |
132 | : <a key={j}
133 | onClick={() => handleClick(feature)}
134 | style={{ width: 200 }}
135 | className="text-blue-500 hover:text-blue-700"
136 | title={description}
137 | >
138 | {label}
139 | </a>
140 | ))}
141 | </div>
142 | </div>
143 | </div>
144 |
145 | )
146 | }
147 |
--------------------------------------------------------------------------------
/sparse_autoencoder/explanations.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Callable
3 |
4 | import blobfile as bf
5 |
6 |
7 | class Explanation(ABC):
8 | def predict(self, tokens: list[str]) -> list[float]:
9 | raise NotImplementedError
10 |
11 | def dump(self) -> bytes:
12 | raise NotImplementedError
13 |
14 | @classmethod
15 | def load(cls, serialized: Any) -> "Explanation":
16 | raise NotImplementedError
17 |
18 | def dumpf(self, filename: str):
19 | d = self.dump()
20 | assert isinstance(d, bytes)
21 | with bf.BlobFile(filename, "wb") as f:
22 | f.write(d)
23 |
24 | @classmethod
25 | def loadf(cls, filename: str):
26 | with bf.BlobFile(filename, "rb") as f:
27 | return cls.load(f.read())
28 |
29 |
30 | _ANY_TOKEN = "token(*)"
31 | _START_TOKEN = "token(^)"
32 | _SALIENCY_KEY = ""
33 |
34 |
35 | class NtgExplanation(Explanation):
36 | def __init__(self, trie: dict):
37 | self.trie = trie
38 |
39 | def todict(self) -> dict:
40 | return {
41 | "trie": self.trie,
42 | }
43 |
44 | @classmethod
45 | def load(cls, serialized: dict) -> "Explanation":
46 | assert isinstance(serialized, dict)
47 | return cls(serialized["trie"])
48 |
49 | def predict(self, tokens: list[str]) -> list[float]:
50 | predicted_acts = []
51 | # for each token, traverse the trie beginning from that token and proceeding in reverse order until we match
52 | # a pattern or are no longer able to traverse.
53 | for i in range(len(tokens)):
54 | curr = self.trie
55 | for j in range(i, -1, -1):
56 | if tokens[j] not in curr and _ANY_TOKEN not in curr:
57 | predicted_acts.append(0)
58 | break
59 | if tokens[j] in curr:
60 | curr = curr[tokens[j]]
61 | else:
62 | curr = curr[_ANY_TOKEN]
63 | if _SALIENCY_KEY in curr:
64 | predicted_acts.append(curr[_SALIENCY_KEY])
65 | break
66 | # if we've reached the end of the sequence and haven't found a saliency value, append 0.
67 | elif j == 0:
68 | if _START_TOKEN in curr:
69 | curr = curr[_START_TOKEN]
70 | assert _SALIENCY_KEY in curr
71 | predicted_acts.append(curr[_SALIENCY_KEY])
72 | break
73 | predicted_acts.append(0)
74 | # We should have appended a value for each token in the sequence.
75 | assert len(predicted_acts) == len(tokens)
76 | return predicted_acts
77 |
78 | # TODO make this more efficient
79 | def predict_many(self, tokens_batch: list[list[str]]) -> list[list[float]]:
80 | return [self.predict(t) for t in tokens_batch]
81 |
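
An illustrative, hand-written trie in the format `NtgExplanation` consumes: keys are tokens read backwards from the activating position, `"token(*)"` (`_ANY_TOKEN`) matches any token, and the empty-string key (`_SALIENCY_KEY`) stores the predicted activation.

```python
from sparse_autoencoder.explanations import NtgExplanation

# pattern: "c" preceded by at least one arbitrary token
expl = NtgExplanation({"c": {"token(*)": {"": 1.0}}})
assert expl.predict(["x", "c"]) == [0, 1.0]
assert expl.predict(["c"]) == [0]  # no preceding token, so the pattern does not match
```
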
82 |
83 | def batched(iterable, bs):
84 | batch = []
85 | it = iter(iterable)
86 | while True:
87 | batch = []
88 | try:
89 | for _ in range(bs):
90 | batch.append(next(it))
91 | yield batch
92 | except StopIteration:
93 | if len(batch) > 0:
94 | yield batch
95 | return
96 |
97 |
98 | def apply_batched(fn, iterable, bs):
99 | for batch in batched(iterable, bs):
100 | ret = fn(batch)
101 | assert len(ret) == len(batch)
102 | yield from ret
103 |
104 |
105 | def batch_parallelize(algos, fn, batch_size):
106 | """
107 | Algorithms are coroutines that yield items to be processed in parallel.
108 | We concurrently run the algorithm on all items in the batch.
109 | """
110 | inputs = []
111 | for i, algo in enumerate(algos):
112 | inputs.append((i, next(algo)))
113 | results = [None] * len(algos)
114 | while len(inputs) > 0:
115 | ret = list(apply_batched(fn, [x[1] for x in inputs], batch_size))
116 | assert len(ret) == len(inputs)
117 | inds = [x[0] for x in inputs]
118 | inputs = []
119 | for i, r in zip(inds, ret):
120 | try:
121 | next_input = algos[i].send(r)
122 | inputs.append((i, next_input))
123 | except StopIteration as e:
124 | results[i] = e.value
125 | return results
126 |
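
A minimal, made-up illustration of the coroutine protocol `batch_parallelize` expects: each algorithm yields an item to be processed, is resumed with `fn`'s result for that item, and eventually returns its final value.

```python
from sparse_autoencoder.explanations import batch_parallelize

def add_one_after_doubling(x):
    doubled = yield x  # submit x for batched processing; resume with fn's output for x
    return doubled + 1

results = batch_parallelize(
    [add_one_after_doubling(v) for v in (1, 2, 3)],
    fn=lambda batch: [2 * v for v in batch],
    batch_size=2,
)
assert results == [3, 5, 7]
```
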
127 |
128 | def create_n2g_explanation(
129 | model_fn: Callable, train_set: list[dict], batch_size: int = 16,
130 | padding_token=4808 # " _" for GPT-2
131 | ) -> NtgExplanation:
132 | truncated = []
133 | # for each one find the index of the selected activation in the doc. truncate the sequences after this point.
134 | for doc in train_set:
135 | # get index of selected activation. for docs stored in 'top', this is the max activation.
136 | # for docs stored in 'random', it is a random positive activation (we sample activations, not docs
137 | # to populate 'random', so docs with more positive activations are more likely to be included).
138 | max_idx = doc["idx"]
139 | truncated.append(
140 | {
141 | "act": doc["act"],
142 | "acts": doc["acts"][: max_idx + 1],
143 | "tokens": doc["tokens"][: max_idx + 1],
144 | "token_ints": doc["token_ints"][: max_idx + 1],
145 | }
146 | )
147 |
148 | def get_minimal_subsequence(doc):
149 | for i in range(len(doc["token_ints"]) - 1, -1, -1):
150 | atom_acts = yield doc["token_ints"][i:]
151 | assert (
152 | len(atom_acts) == len(doc["token_ints"]) - i
153 | ), f"{len(atom_acts)} != {len(doc['token_ints']) - i}"
154 | if atom_acts[-1] / doc["act"] >= 0.5:
155 | return {
156 | "tokens": doc["tokens"][i:],
157 | "token_ints": doc["token_ints"][i:],
158 | "subsequence_act": atom_acts[-1],
159 | "orig_act": doc["act"],
160 | }
161 | print("Warning: no minimal subsequence found")
162 | # raise ValueError("No minimal subsequence found")
163 | return {
164 | "tokens": doc["tokens"],
165 | "token_ints": doc["token_ints"],
166 | "subsequence_act": doc["act"],
167 | "orig_act": doc["act"],
168 | }
169 |
170 | minimal_subsequences = batch_parallelize(
171 | [get_minimal_subsequence(doc) for doc in truncated], model_fn, batch_size
172 | )
173 |
174 | start_padded = apply_batched(
175 | model_fn,
176 | [[padding_token] + doc["token_ints"] for doc in minimal_subsequences],
177 | batch_size,
178 | )
179 | for min_seq, pad_atom_acts in zip(minimal_subsequences, start_padded):
180 | min_seq["can_pad_start"] = pad_atom_acts[-1] / min_seq["orig_act"] >= 0.5
181 |
182 | # for m in minimal_subsequences:
183 | # print("\t" + "".join(m["tokens"]))
184 |
185 | # for each token in a minimal subsequence, replace it with a padding token and compute the saliency value (1 - (orig act / new act))
186 | for doc in minimal_subsequences:
187 | all_seqs = []
188 | for i in range(len(doc["token_ints"])):
189 | tokens = doc["token_ints"][:i] + [padding_token] + doc["token_ints"][i + 1 :]
190 | assert len(tokens) == len(doc["token_ints"])
191 | all_seqs.append(tokens)
192 | saliency_vals = []
193 | all_atom_acts = apply_batched(model_fn, all_seqs, batch_size)
194 | for atom_acts, tokens in zip(all_atom_acts, all_seqs):
195 | assert len(atom_acts) == len(tokens)
196 | saliency_vals.append(1 - (atom_acts[-1] / doc["subsequence_act"]))
197 | doc["saliency_vals"] = saliency_vals
198 |
199 | trie = {}
200 | for doc in minimal_subsequences:
201 | curr = trie
202 | for i, (token, saliency) in enumerate(zip(doc["tokens"][::-1], doc["saliency_vals"][::-1])):
203 | if saliency < 0.5:
204 | token = _ANY_TOKEN
205 | if token not in curr:
206 | curr[token] = {}
207 | curr = curr[token]
208 | if i == len(doc["tokens"]) - 1:
209 | if not doc["can_pad_start"]:
210 | curr[_START_TOKEN] = {}
211 | curr = curr[_START_TOKEN]
212 | curr[_SALIENCY_KEY] = doc["subsequence_act"]
213 |
214 | return NtgExplanation(trie)
215 |
216 |
217 | if __name__ == "__main__":
218 |
219 | def first_position_fn(tokens: list[list[float]]) -> list[list[float]]:
220 | return [[1.0] + [0.0] * (len(toks) - 1) for toks in tokens]
221 |
222 | expl = create_n2g_explanation(
223 | first_position_fn,
224 | [
225 | {
226 | "idx": 0,
227 | "act": 1.0,
228 | "acts": [1.0, 0.0, 0.0],
229 | "tokens": ["a", "b", "c"],
230 | "token_ints": [0, 1, 2],
231 | },
232 | {
233 | "idx": 0,
234 | "act": 1.0,
235 | "acts": [1.0, 0.0, 0.0],
236 | "tokens": ["b", "c", "d"],
237 | "token_ints": [1, 2, 3],
238 | },
239 | ],
240 | )
241 | print(expl.trie)
242 | assert expl.predict(["a", "b", "c"]) == [1.0, 0.0, 0.0]
243 | assert expl.predict(["c", "b", "a"]) == [1.0, 0.0, 0.0]
244 |
245 | def c_fn(tokens: list[list[float]]) -> list[list[float]]:
246 | return [[1.0 if tok == 2 else 0.0 for tok in toks] for toks in tokens]
247 |
248 | expl = create_n2g_explanation(
249 | c_fn,
250 | [
251 | {
252 | "idx": 2,
253 | "act": 1.0,
254 | "acts": [0.0, 0.0, 1.0],
255 | "tokens": ["a", "b", "c"],
256 | "token_ints": [0, 1, 2],
257 | },
258 | {
259 | "idx": 1,
260 | "act": 1.0,
261 | "acts": [0.0, 1.0, 0.0],
262 | "tokens": ["b", "c", "d"],
263 | "token_ints": [1, 2, 3],
264 | },
265 | ],
266 | )
267 |
268 | print(expl.trie)
269 | assert expl.predict(["b", "c"]) == [0.0, 1.0]
270 | assert expl.predict(["a", "a", "a"]) == [0.0, 0.0, 0.0]
271 | assert expl.predict(["c", "b", "c"]) == [1.0, 0.0, 1.0]
272 |
273 | def a_star_c_fn(tokens: list[list[float]]) -> list[list[float]]:
274 | return [
275 | [1.0 if (tok == 2 and 0 in toks[:i]) else 0.0 for i, tok in enumerate(toks)]
276 | for toks in tokens
277 | ]
278 |
279 | expl = create_n2g_explanation(
280 | a_star_c_fn,
281 | [
282 | {
283 | "idx": 2,
284 | "act": 1.0,
285 | "acts": [0.0, 0.0, 1.0],
286 | "tokens": ["a", "b", "c"],
287 | "token_ints": [0, 1, 2],
288 | },
289 | {
290 | "idx": 2,
291 | "act": 1.0,
292 | "acts": [0.0, 0.0, 1.0],
293 | "tokens": ["b", "a", "c"],
294 | "token_ints": [1, 0, 2],
295 | },
296 | {
297 | "idx": 1,
298 | "act": 1.0,
299 | "acts": [0.0, 1.0, 0.0],
300 | "tokens": ["a", "c", "d"],
301 | "token_ints": [0, 2, 3],
302 | },
303 | ],
304 | )
305 |
306 | print(expl.trie)
307 | assert expl.predict(["b", "c"]) == [0.0, 0.0]
308 | assert expl.predict(["a", "c"]) == [0.0, 1.0]
309 | assert expl.predict(["a", "e", "c", "a", "c"]) == [0.0, 0.0, 1.0, 0.0, 1.0]
310 | # NOTE: should be 0, 0, 0, 1 but we're not smart enough to handle this yet
311 | assert expl.predict(["a", "b", "b", "c"]) == [0.0, 0.0, 0.0, 0.0]
312 |
313 | def zero_fn(tokens: list[list[float]]) -> list[list[float]]:
314 | return [[0.0 for tok in toks] for toks in tokens]
315 |
316 | expl = create_n2g_explanation(zero_fn, [])
317 |
318 | print(expl.trie)
319 | assert expl.predict(["b", "c"]) == [0.0, 0.0]
320 | assert expl.predict(["a", "c"]) == [0.0, 0.0]
321 | assert expl.predict(["a", "e", "c", "a", "c"]) == [0.0, 0.0, 0.0, 0.0, 0.0]
322 |
--------------------------------------------------------------------------------
/sparse_autoencoder/kernels.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import triton
4 | import triton.language as tl
5 |
6 |
7 | ## kernels
8 |
9 |
10 | def triton_sparse_transpose_dense_matmul(
11 | sparse_indices: torch.Tensor,
12 | sparse_values: torch.Tensor,
13 | dense: torch.Tensor,
14 | N: int,
15 | BLOCK_SIZE_AK=128,
16 | ) -> torch.Tensor:
17 | """
18 | calculates sparse.T @ dense (i.e. reducing along the collated dimension of sparse)
19 | dense must be contiguous along its last dimension, i.e. dense itself is contiguous (see the assert below)
20 |
21 | sparse_indices is shape (A, k)
22 | sparse_values is shape (A, k)
23 | dense is shape (A, B)
24 |
25 | output is shape (N, B)
26 | """
27 |
28 | assert sparse_indices.shape == sparse_values.shape
29 | assert sparse_indices.is_contiguous()
30 | assert sparse_values.is_contiguous()
31 | assert dense.is_contiguous() # contiguous along B
32 |
33 | K = sparse_indices.shape[1]
34 | A = dense.shape[0]
35 | B = dense.shape[1]
36 | assert sparse_indices.shape[0] == A
37 |
38 | # COO-format and sorted
39 | sorted_indices = sparse_indices.view(-1).sort()
40 | coo_indices = torch.stack(
41 | [
42 | torch.arange(A, device=sparse_indices.device).repeat_interleave(K)[
43 | sorted_indices.indices
44 | ],
45 | sorted_indices.values,
46 | ]
47 | ) # shape (2, A * K)
48 | coo_values = sparse_values.view(-1)[sorted_indices.indices] # shape (A * K,)
49 | return triton_coo_sparse_dense_matmul(coo_indices, coo_values, dense, N, BLOCK_SIZE_AK)
50 |
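
A reference check for the kernel above (sizes are arbitrary; assumes a CUDA device with Triton installed): densify the (A, N) sparse matrix and compare sparse.T @ dense against the kernel output.

```python
import torch

from sparse_autoencoder.kernels import triton_sparse_transpose_dense_matmul

A, K, N, B = 4, 2, 16, 8
inds = torch.randint(N, (A, K), device="cuda")
vals = torch.randn(A, K, device="cuda")
dense = torch.randn(A, B, device="cuda")

sparse_as_dense = torch.zeros(A, N, device="cuda").scatter_add_(-1, inds, vals)
ref = sparse_as_dense.T @ dense
out = triton_sparse_transpose_dense_matmul(inds, vals, dense, N=N)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```
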
51 |
52 | def triton_coo_sparse_dense_matmul(
53 | coo_indices: torch.Tensor,
54 | coo_values: torch.Tensor,
55 | dense: torch.Tensor,
56 | N: int,
57 | BLOCK_SIZE_AK=128,
58 | ) -> torch.Tensor:
59 | AK = coo_indices.shape[1]
60 | B = dense.shape[1]
61 |
62 | out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype)
63 |
64 | grid = lambda META: (
65 | triton.cdiv(AK, META["BLOCK_SIZE_AK"]),
66 | 1,
67 | )
68 | triton_sparse_transpose_dense_matmul_kernel[grid](
69 | coo_indices,
70 | coo_values,
71 | dense,
72 | out,
73 | stride_da=dense.stride(0),
74 | stride_db=dense.stride(1),
75 | B=B,
76 | N=N,
77 | AK=AK,
78 | BLOCK_SIZE_AK=BLOCK_SIZE_AK,
79 | BLOCK_SIZE_B=triton.next_power_of_2(B),
80 | )
81 | return out
82 |
83 |
84 | @triton.jit
85 | def triton_sparse_transpose_dense_matmul_kernel(
86 | coo_indices_ptr,
87 | coo_values_ptr,
88 | dense_ptr,
89 | out_ptr,
90 | stride_da,
91 | stride_db,
92 | B,
93 | N,
94 | AK,
95 | BLOCK_SIZE_AK: tl.constexpr,
96 | BLOCK_SIZE_B: tl.constexpr,
97 | ):
98 | """
99 | coo_indices is shape (2, AK)
100 | coo_values is shape (AK,)
101 | dense is shape (A, B), contiguous along B
102 | out is shape (N, B)
103 | """
104 |
105 | pid_ak = tl.program_id(0)
106 | pid_b = tl.program_id(1)
107 |
108 | coo_offsets = tl.arange(0, BLOCK_SIZE_AK)
109 | b_offsets = tl.arange(0, BLOCK_SIZE_B)
110 |
111 | A_coords = tl.load(
112 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
113 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
114 | )
115 | K_coords = tl.load(
116 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets + AK,
117 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
118 | )
119 | values = tl.load(
120 | coo_values_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
121 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
122 | )
123 |
124 | last_k = tl.min(K_coords)
125 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
126 |
127 | for ind in range(BLOCK_SIZE_AK):
128 | if ind + pid_ak * BLOCK_SIZE_AK < AK:
129 | # workaround to do A_coords[ind]
130 | a = tl.sum(
131 | tl.where(
132 | tl.arange(0, BLOCK_SIZE_AK) == ind,
133 | A_coords,
134 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
135 | )
136 | )
137 |
138 | k = tl.sum(
139 | tl.where(
140 | tl.arange(0, BLOCK_SIZE_AK) == ind,
141 | K_coords,
142 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64),
143 | )
144 | )
145 |
146 | v = tl.sum(
147 | tl.where(
148 | tl.arange(0, BLOCK_SIZE_AK) == ind,
149 | values,
150 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32),
151 | )
152 | )
153 |
154 | tl.device_assert(k < N)
155 |
156 | if k != last_k:
157 | tl.atomic_add(
158 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
159 | accum,
160 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
161 | )
162 | accum *= 0
163 | last_k = k
164 |
165 | if v != 0:
166 | accum += v * tl.load(dense_ptr + a * stride_da + b_offsets, mask=b_offsets < B)
167 |
168 | tl.atomic_add(
169 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
170 | accum,
171 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
172 | )
173 |
174 |
175 | def triton_sparse_dense_matmul(
176 | sparse_indices: torch.Tensor,
177 | sparse_values: torch.Tensor,
178 | dense: torch.Tensor,
179 | ) -> torch.Tensor:
180 | """
181 | calculates sparse @ dense (i.e. reducing along the uncollated dimension of sparse)
182 | dense must be contiguous along its last dimension, i.e. dense itself is contiguous (see the assert below)
183 |
184 | sparse_indices is shape (A, k)
185 | sparse_values is shape (A, k)
186 | dense is shape (N, B)
187 |
188 | output is shape (A, B)
189 | """
190 | N = dense.shape[0]
191 | assert sparse_indices.shape == sparse_values.shape
192 | assert sparse_indices.is_contiguous()
193 | assert sparse_values.is_contiguous()
194 | assert dense.is_contiguous() # contiguous along B
195 |
196 | A = sparse_indices.shape[0]
197 | K = sparse_indices.shape[1]
198 | B = dense.shape[1]
199 |
200 | out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype)
201 |
202 | triton_sparse_dense_matmul_kernel[(A,)](
203 | sparse_indices,
204 | sparse_values,
205 | dense,
206 | out,
207 | stride_dn=dense.stride(0),
208 | stride_db=dense.stride(1),
209 | A=A,
210 | B=B,
211 | N=N,
212 | K=K,
213 | BLOCK_SIZE_K=triton.next_power_of_2(K),
214 | BLOCK_SIZE_B=triton.next_power_of_2(B),
215 | )
216 | return out
217 |
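
The analogous densify-and-compare check for `triton_sparse_dense_matmul` (again arbitrary sizes, CUDA and Triton assumed):

```python
import torch

from sparse_autoencoder.kernels import triton_sparse_dense_matmul

A, K, N, B = 4, 2, 16, 8
inds = torch.randint(N, (A, K), device="cuda")
vals = torch.randn(A, K, device="cuda")
dense = torch.randn(N, B, device="cuda")

ref = torch.zeros(A, N, device="cuda").scatter_add_(-1, inds, vals) @ dense
out = triton_sparse_dense_matmul(inds, vals, dense)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```
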
218 |
219 | @triton.jit
220 | def triton_sparse_dense_matmul_kernel(
221 | sparse_indices_ptr,
222 | sparse_values_ptr,
223 | dense_ptr,
224 | out_ptr,
225 | stride_dn,
226 | stride_db,
227 | A,
228 | B,
229 | N,
230 | K,
231 | BLOCK_SIZE_K: tl.constexpr,
232 | BLOCK_SIZE_B: tl.constexpr,
233 | ):
234 | """
235 | sparse_indices is shape (A, K)
236 | sparse_values is shape (A, K)
237 | dense is shape (N, B), contiguous along B
238 | out is shape (A, B)
239 | """
240 |
241 | pid = tl.program_id(0)
242 |
243 | offsets_k = tl.arange(0, BLOCK_SIZE_K)
244 | sparse_indices = tl.load(
245 | sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K
246 | ) # shape (K,)
247 | sparse_values = tl.load(
248 | sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K
249 | ) # shape (K,)
250 |
251 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
252 |
253 | offsets_b = tl.arange(0, BLOCK_SIZE_B)
254 |
255 | for k in range(K):
256 | # workaround to do sparse_indices[k]
257 | i = tl.sum(
258 | tl.where(
259 | tl.arange(0, BLOCK_SIZE_K) == k,
260 | sparse_indices,
261 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
262 | )
263 | )
264 | # workaround to do sparse_values[k]
265 | v = tl.sum(
266 | tl.where(
267 | tl.arange(0, BLOCK_SIZE_K) == k,
268 | sparse_values,
269 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
270 | )
271 | )
272 |
273 | tl.device_assert(i < N)
274 | if v != 0:
275 | accum += v * tl.load(
276 | dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B
277 | )
278 |
279 | tl.store(out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B)
280 |
281 |
282 | def triton_dense_dense_sparseout_matmul(
283 | dense1: torch.Tensor,
284 | dense2: torch.Tensor,
285 | at_indices: torch.Tensor,
286 | ) -> torch.Tensor:
287 | """
288 | dense1: shape (A, B)
289 | dense2: shape (B, N)
290 | at_indices: shape (A, K)
291 | out values: shape (A, K)
292 | calculates dense1 @ dense2 only for the indices in at_indices
293 |
294 | equivalent to (dense1 @ dense2).gather(1, at_indices)
295 | """
296 | A, B = dense1.shape
297 | N = dense2.shape[1]
298 | assert dense2.shape[0] == B
299 | assert at_indices.shape[0] == A
300 | K = at_indices.shape[1]
301 | assert at_indices.is_contiguous()
302 |
303 | assert dense1.stride(1) == 1, "dense1 must be contiguous along B"
304 | assert dense2.stride(0) == 1, "dense2 must be contiguous along B"
305 |
306 | if K > 512:
307 | # print("WARN - using naive matmul for large K")
308 | # naive is more efficient for large K
309 | return (dense1 @ dense2).gather(1, at_indices)
310 |
311 | out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype)
312 |
313 | # grid = lambda META: (triton.cdiv(A, META['BLOCK_SIZE_A']),)
314 |
315 | triton_dense_dense_sparseout_matmul_kernel[(A,)](
316 | dense1,
317 | dense2,
318 | at_indices,
319 | out,
320 | stride_d1a=dense1.stride(0),
321 | stride_d1b=dense1.stride(1),
322 | stride_d2b=dense2.stride(0),
323 | stride_d2n=dense2.stride(1),
324 | A=A,
325 | B=B,
326 | N=N,
327 | K=K,
328 | BLOCK_SIZE_B=triton.next_power_of_2(B),
329 | BLOCK_SIZE_N=triton.next_power_of_2(N),
330 | BLOCK_SIZE_K=triton.next_power_of_2(K),
331 | )
332 |
333 | return out
334 |
335 |
336 | @triton.jit
337 | def triton_dense_dense_sparseout_matmul_kernel(
338 | dense1_ptr,
339 | dense2_ptr,
340 | at_indices_ptr,
341 | out_ptr,
342 | stride_d1a,
343 | stride_d1b,
344 | stride_d2b,
345 | stride_d2n,
346 | A,
347 | B,
348 | N,
349 | K,
350 | BLOCK_SIZE_B: tl.constexpr,
351 | BLOCK_SIZE_N: tl.constexpr,
352 | BLOCK_SIZE_K: tl.constexpr,
353 | ):
354 | """
355 | dense1: shape (A, B)
356 | dense2: shape (B, N)
357 | at_indices: shape (A, K)
358 | out values: shape (A, K)
359 | """
360 |
361 | pid = tl.program_id(0)
362 |
363 | offsets_k = tl.arange(0, BLOCK_SIZE_K)
364 | at_indices = tl.load(at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K) # shape (K,)
365 |
366 | offsets_b = tl.arange(0, BLOCK_SIZE_B)
367 | dense1 = tl.load(
368 | dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=offsets_b < B
369 | ) # shape (B,)
370 |
371 | accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
372 |
373 | for k in range(K):
374 | # workaround to do at_indices[b]
375 | i = tl.sum(
376 | tl.where(
377 | tl.arange(0, BLOCK_SIZE_K) == k,
378 | at_indices,
379 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
380 | )
381 | )
382 | tl.device_assert(i < N)
383 |
384 | dense2col = tl.load(
385 | dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=offsets_b < B
386 | ) # shape (B,)
387 | accum += tl.where(
388 | tl.arange(0, BLOCK_SIZE_K) == k,
389 | tl.sum(dense1 * dense2col),
390 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
391 | )
392 |
393 | tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K)
394 |
395 |
396 | class TritonDecoderAutograd(torch.autograd.Function):
397 | @staticmethod
398 | def forward(ctx, sparse_indices, sparse_values, decoder_weight):
399 | ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
400 | return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
401 |
402 | @staticmethod
403 | def backward(ctx, grad_output):
404 | sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
405 |
406 | assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
407 |
408 | decoder_grad = triton_sparse_transpose_dense_matmul(
409 | sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
410 | ).T
411 |
412 | return (
413 | None,
414 | triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices),
415 | # decoder is contiguous when transposed so this is a matching layout
416 | decoder_grad,
417 | None,
418 | )
419 |
420 |
421 | def triton_add_mul_(
422 | x: torch.Tensor,
423 | a: torch.Tensor,
424 | b: torch.Tensor,
425 | c: float,
426 | ):
427 | """
428 | does
429 | x += a * b * c
430 |
431 | x : [m, n]
432 | a : [m, n]
433 | b : [m, n]
434 | c : float
435 | """
436 |
437 | if len(a.shape) == 1:
438 | a = a[None, :].broadcast_to(x.shape)
439 |
440 | if len(b.shape) == 1:
441 | b = b[None, :].broadcast_to(x.shape)
442 |
443 | assert x.shape == a.shape == b.shape
444 |
445 | BLOCK_SIZE_M = 64
446 | BLOCK_SIZE_N = 64
447 | grid = lambda META: (
448 | triton.cdiv(x.shape[0], META["BLOCK_SIZE_M"]),
449 | triton.cdiv(x.shape[1], META["BLOCK_SIZE_N"]),
450 | )
451 | triton_add_mul_kernel[grid](
452 | x,
453 | a,
454 | b,
455 | c,
456 | x.stride(0),
457 | x.stride(1),
458 | a.stride(0),
459 | a.stride(1),
460 | b.stride(0),
461 | b.stride(1),
462 | BLOCK_SIZE_M,
463 | BLOCK_SIZE_N,
464 | x.shape[0],
465 | x.shape[1],
466 | )
467 |
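
A quick equivalence check for `triton_add_mul_` (CUDA assumed); note that a 1-D `a` or `b` is broadcast across rows:

```python
import torch

from sparse_autoencoder.kernels import triton_add_mul_

x = torch.randn(128, 96, device="cuda")
a = torch.randn_like(x)
b = torch.randn(96, device="cuda")  # 1-D, broadcast across rows
expected = x + a * b * 0.5
triton_add_mul_(x, a, b, c=0.5)
torch.testing.assert_close(x, expected, rtol=1e-5, atol=1e-5)
```
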
468 |
469 | @triton.jit
470 | def triton_add_mul_kernel(
471 | x_ptr,
472 | a_ptr,
473 | b_ptr,
474 | c,
475 | stride_x0,
476 | stride_x1,
477 | stride_a0,
478 | stride_a1,
479 | stride_b0,
480 | stride_b1,
481 | BLOCK_SIZE_M: tl.constexpr,
482 | BLOCK_SIZE_N: tl.constexpr,
483 | M: tl.constexpr,
484 | N: tl.constexpr,
485 | ):
486 | pid_m = tl.program_id(0)
487 | pid_n = tl.program_id(1)
488 |
489 | offsets_m = tl.arange(0, BLOCK_SIZE_M) + pid_m * BLOCK_SIZE_M
490 | offsets_n = tl.arange(0, BLOCK_SIZE_N) + pid_n * BLOCK_SIZE_N
491 |
492 | x = tl.load(
493 | x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1,
494 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
495 | )
496 | a = tl.load(
497 | a_ptr + offsets_m[:, None] * stride_a0 + offsets_n[None, :] * stride_a1,
498 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
499 | )
500 | b = tl.load(
501 | b_ptr + offsets_m[:, None] * stride_b0 + offsets_n[None, :] * stride_b1,
502 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
503 | )
504 |
505 | x_dtype = x.dtype
506 | x = (x.to(tl.float32) + a.to(tl.float32) * b.to(tl.float32) * c).to(x_dtype)
507 |
508 | tl.store(
509 | x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1,
510 | x,
511 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N),
512 | )
513 |
514 |
515 |
516 | def triton_sum_dim0_in_fp32(xs):
517 | a, b = xs.shape
518 |
519 | assert xs.is_contiguous()
520 | assert xs.dtype == torch.float16
521 |
522 | BLOCK_SIZE_A = min(triton.next_power_of_2(a), 512)
523 | BLOCK_SIZE_B = 64 # cache line is 128 bytes
524 |
525 | out = torch.zeros(b, dtype=torch.float32, device=xs.device)
526 |
527 | grid = lambda META: (triton.cdiv(b, META["BLOCK_SIZE_B"]),)
528 |
529 | triton_sum_dim0_in_fp32_kernel[grid](
530 | xs,
531 | out,
532 | stride_a=xs.stride(0),
533 | a=a,
534 | b=b,
535 | BLOCK_SIZE_A=BLOCK_SIZE_A,
536 | BLOCK_SIZE_B=BLOCK_SIZE_B,
537 | )
538 |
539 | return out
540 |
541 |
542 | @triton.jit
543 | def triton_sum_dim0_in_fp32_kernel(
544 | xs_ptr,
545 | out_ptr,
546 | stride_a,
547 | a,
548 | b,
549 | BLOCK_SIZE_A: tl.constexpr,
550 | BLOCK_SIZE_B: tl.constexpr,
551 | ):
552 | # each program handles 64 columns of xs
553 | pid = tl.program_id(0)
554 | offsets_b = tl.arange(0, BLOCK_SIZE_B) + pid * BLOCK_SIZE_B
555 |
556 | all_out = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
557 |
558 | for i in range(0, a, BLOCK_SIZE_A):
559 | offsets_a = tl.arange(0, BLOCK_SIZE_A) + i
560 | xs = tl.load(
561 | xs_ptr + offsets_a[:, None] * stride_a + offsets_b[None, :],
562 | mask=(offsets_a < a)[:, None] & (offsets_b < b)[None, :],
563 | other=0,
564 | )
565 | xs = xs.to(tl.float32)
566 | out = tl.sum(xs, axis=0)
567 | all_out += out
568 |
569 | tl.store(out_ptr + offsets_b, all_out, mask=offsets_b < b)
570 |
571 |
572 | def mse(
573 | output,
574 | target,
575 | ): # fusing fp32 cast and MSE to save memory
576 | assert output.shape == target.shape
577 | assert len(output.shape) == 2
578 | assert output.stride(1) == 1
579 | assert target.stride(1) == 1
580 |
581 | a, b = output.shape
582 |
583 | BLOCK_SIZE_B = triton.next_power_of_2(b)
584 |
585 | class _MSE(torch.autograd.Function):
586 | @staticmethod
587 | def forward(ctx, output, target):
588 | ctx.save_for_backward(output, target)
589 | out = torch.zeros(a, dtype=torch.float32, device=output.device)
590 |
591 | triton_mse_loss_fp16_kernel[(a,)](
592 | output,
593 | target,
594 | out,
595 | stride_a_output=output.stride(0),
596 | stride_a_target=target.stride(0),
597 | a=a,
598 | b=b,
599 | BLOCK_SIZE_B=BLOCK_SIZE_B,
600 | )
601 |
602 | return out
603 |
604 | @staticmethod
605 | def backward(ctx, grad_output):
606 | output, target = ctx.saved_tensors
607 | res = (output - target).float()
608 | res *= grad_output[:, None] * 2 / b
609 | return res, None
610 |
611 | return _MSE.apply(output, target).mean()
612 |
613 |
614 | def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
615 | # only used for auxk
616 | xs_mu = (
617 | triton_sum_dim0_in_fp32(xs) / xs.shape[0]
618 | if xs.dtype == torch.float16
619 | else xs.mean(dim=0)
620 | )
621 |
622 | loss = mse(recon, xs) / mse(
623 | xs_mu[None, :].broadcast_to(xs.shape), xs
624 | )
625 |
626 | return loss
627 |
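
For clarity, a plain-torch sketch of what `normalized_mse` computes: the reconstruction MSE relative to the MSE of always predicting the per-batch mean.

```python
import torch

def normalized_mse_reference(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # same quantity as normalized_mse above, without the fused fp32 kernels
    return ((recon - xs) ** 2).mean() / ((xs.mean(dim=0) - xs) ** 2).mean()
```
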
628 |
629 | @triton.jit
630 | def triton_mse_loss_fp16_kernel(
631 | output_ptr,
632 | target_ptr,
633 | out_ptr,
634 | stride_a_output,
635 | stride_a_target,
636 | a,
637 | b,
638 | BLOCK_SIZE_B: tl.constexpr,
639 | ):
640 | pid = tl.program_id(0)
641 | offsets_b = tl.arange(0, BLOCK_SIZE_B)
642 |
643 | output = tl.load(
644 | output_ptr + pid * stride_a_output + offsets_b,
645 | mask=offsets_b < b,
646 | )
647 | target = tl.load(
648 | target_ptr + pid * stride_a_target + offsets_b,
649 | mask=offsets_b < b,
650 | )
651 |
652 | output = output.to(tl.float32)
653 | target = target.to(tl.float32)
654 |
655 | mse = tl.sum((output - target) * (output - target)) / b
656 |
657 | tl.store(out_ptr + pid, mse)
658 |
659 |
--------------------------------------------------------------------------------
/sparse_autoencoder/train.py:
--------------------------------------------------------------------------------
1 | # bare bones training script using sparse kernels and sharding/data parallel.
2 | # the main purpose of this code is to provide a reference implementation to compare
3 | # against when implementing our training methodology into other codebases, and to
4 | # demonstrate how sharding/DP can be implemented for autoencoders. some limitations:
5 | # - many basic features (e.g. checkpointing, data loading, validation) are not implemented,
6 | # - the codebase is not designed to be extensible or easily hackable.
7 | # - this code is not guaranteed to run efficiently out of the box / in
8 | # combination with other changes, so you should profile it and make changes as needed.
9 | #
10 | # example launch command:
11 | # torchrun --nproc-per-node 8 train.py
12 |
13 |
14 | import os
15 | from dataclasses import dataclass
16 | from typing import Callable, Iterable, Iterator
17 |
18 | import torch
19 | import torch.distributed as dist
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import triton
23 | import triton.language as tl
24 | from sparse_autoencoder.kernels import *
25 | from torch.distributed import ReduceOp
26 |
27 | RANK = int(os.environ.get("RANK", "0"))
28 |
29 |
30 | ## parallelism
31 |
32 |
33 | @dataclass
34 | class Comm:
35 | group: torch.distributed.ProcessGroup
36 |
37 | def all_reduce(self, x, op=ReduceOp.SUM, async_op=False):
38 | return dist.all_reduce(x, op=op, group=self.group, async_op=async_op)
39 |
40 | def all_gather(self, x_list, x, async_op=False):
41 | return dist.all_gather(list(x_list), x, group=self.group, async_op=async_op)
42 |
43 | def broadcast(self, x, src, async_op=False):
44 | return dist.broadcast(x, src, group=self.group, async_op=async_op)
45 |
46 | def barrier(self):
47 | return dist.barrier(group=self.group)
48 |
49 | def size(self):
50 | return self.group.size()
51 |
52 |
53 | @dataclass
54 | class ShardingComms:
55 | n_replicas: int
56 | n_op_shards: int
57 | dp_rank: int
58 | sh_rank: int
59 | dp_comm: Comm | None
60 | sh_comm: Comm | None
61 | _rank: int
62 |
63 | def sh_allreduce_forward(self, x: torch.Tensor) -> torch.Tensor:
64 | if self.sh_comm is None:
65 | return x
66 |
67 | class AllreduceForward(torch.autograd.Function):
68 | @staticmethod
69 | def forward(ctx, input):
70 | assert self.sh_comm is not None
71 | self.sh_comm.all_reduce(input, async_op=True)
72 | return input
73 |
74 | @staticmethod
75 | def backward(ctx, grad_output):
76 | return grad_output
77 |
78 | return AllreduceForward.apply(x) # type: ignore
79 |
80 | def sh_allreduce_backward(self, x: torch.Tensor) -> torch.Tensor:
81 | if self.sh_comm is None:
82 | return x
83 |
84 | class AllreduceBackward(torch.autograd.Function):
85 | @staticmethod
86 | def forward(ctx, input):
87 | return input
88 |
89 | @staticmethod
90 | def backward(ctx, grad_output):
91 | grad_output = grad_output.clone()
92 | assert self.sh_comm is not None
93 | self.sh_comm.all_reduce(grad_output, async_op=True)
94 | return grad_output
95 |
96 | return AllreduceBackward.apply(x) # type: ignore
97 |
98 | def init_broadcast_(self, autoencoder):
99 | if self.dp_comm is not None:
100 | for p in autoencoder.parameters():
101 | self.dp_comm.broadcast(
102 | maybe_transpose(p.data),
103 | replica_shard_to_rank(
104 | replica_idx=0,
105 | shard_idx=self.sh_rank,
106 | n_op_shards=self.n_op_shards,
107 | ),
108 | )
109 |
110 | if self.sh_comm is not None:
111 | # pre_bias is the same across all shards
112 | self.sh_comm.broadcast(
113 | autoencoder.pre_bias.data,
114 | replica_shard_to_rank(
115 | replica_idx=self.dp_rank,
116 | shard_idx=0,
117 | n_op_shards=self.n_op_shards,
118 | ),
119 | )
120 |
121 | def dp_allreduce_(self, autoencoder) -> None:
122 | if self.dp_comm is None:
123 | return
124 |
125 | for param in autoencoder.parameters():
126 | if param.grad is not None:
127 | self.dp_comm.all_reduce(maybe_transpose(param.grad), op=ReduceOp.AVG, async_op=True)
128 |
129 | # make sure statistics for dead neurons are correct
130 | self.dp_comm.all_reduce( # type: ignore
131 | autoencoder.stats_last_nonzero, op=ReduceOp.MIN, async_op=True
132 | )
133 |
134 | def sh_allreduce_scale(self, scaler):
135 | if self.sh_comm is None:
136 | return
137 |
138 | if hasattr(scaler, "_scale") and scaler._scale is not None:
139 | self.sh_comm.all_reduce(scaler._scale, op=ReduceOp.MIN, async_op=True)
140 | self.sh_comm.all_reduce(scaler._growth_tracker, op=ReduceOp.MIN, async_op=True)
141 |
142 | def _sh_comm_op(self, x, op):
143 | if isinstance(x, (float, int)):
144 | x = torch.tensor(x, device="cuda")
145 |
146 | if not x.is_cuda:
147 | x = x.cuda()
148 |
149 | if self.sh_comm is None:
150 | return x
151 |
152 | out = x.clone()
153 | self.sh_comm.all_reduce(x, op=op, async_op=True)
154 | return out
155 |
156 | def sh_sum(self, x: torch.Tensor) -> torch.Tensor:
157 | return self._sh_comm_op(x, ReduceOp.SUM)
158 |
159 | def all_broadcast(self, x: torch.Tensor) -> torch.Tensor:
160 | if self.dp_comm is not None:
161 | self.dp_comm.broadcast(
162 | x,
163 | replica_shard_to_rank(
164 | replica_idx=0,
165 | shard_idx=self.sh_rank,
166 | n_op_shards=self.n_op_shards,
167 | ),
168 | )
169 |
170 | if self.sh_comm is not None:
171 | self.sh_comm.broadcast(
172 | x,
173 | replica_shard_to_rank(
174 | replica_idx=self.dp_rank,
175 | shard_idx=0,
176 | n_op_shards=self.n_op_shards,
177 | ),
178 | )
179 |
180 | return x
181 |
182 |
183 | def make_torch_comms(n_op_shards=4, n_replicas=2):
184 | if "RANK" not in os.environ:
185 | assert n_op_shards == 1
186 | assert n_replicas == 1
187 | return TRIVIAL_COMMS
188 |
189 | rank = int(os.environ.get("RANK"))
190 | world_size = int(os.environ.get("WORLD_SIZE", 1))
191 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % 8)
192 |
193 | print(f"{rank=}, {world_size=}")
194 | dist.init_process_group("nccl")
195 |
196 | my_op_shard_idx = rank % n_op_shards
197 | my_replica_idx = rank // n_op_shards
198 |
199 | shard_rank_lists = [list(range(i, i + n_op_shards)) for i in range(0, world_size, n_op_shards)]
200 |
201 | shard_groups = [dist.new_group(shard_rank_list) for shard_rank_list in shard_rank_lists]
202 |
203 | my_shard_group = shard_groups[my_replica_idx]
204 |
205 | replica_rank_lists = [
206 | list(range(i, n_op_shards * n_replicas, n_op_shards)) for i in range(n_op_shards)
207 | ]
208 |
209 | replica_groups = [dist.new_group(replica_rank_list) for replica_rank_list in replica_rank_lists]
210 |
211 | my_replica_group = replica_groups[my_op_shard_idx]
212 |
213 | torch.distributed.all_reduce(torch.ones(1).cuda())
214 | torch.cuda.synchronize()
215 |
216 | dp_comm = Comm(group=my_replica_group)
217 | sh_comm = Comm(group=my_shard_group)
218 |
219 | return ShardingComms(
220 | n_replicas=n_replicas,
221 | n_op_shards=n_op_shards,
222 | dp_comm=dp_comm,
223 | sh_comm=sh_comm,
224 | dp_rank=my_replica_idx,
225 | sh_rank=my_op_shard_idx,
226 | _rank=rank,
227 | )
228 |
229 |
230 | def replica_shard_to_rank(replica_idx, shard_idx, n_op_shards):
231 | return replica_idx * n_op_shards + shard_idx
232 |
233 |
234 | TRIVIAL_COMMS = ShardingComms(
235 | n_replicas=1,
236 | n_op_shards=1,
237 | dp_rank=0,
238 | sh_rank=0,
239 | dp_comm=None,
240 | sh_comm=None,
241 | _rank=0,
242 | )
243 |
244 |
245 | def sharded_topk(x, k, sh_comm, capacity_factor=None):
246 | batch = x.shape[0]
247 |
248 | if capacity_factor is not None and sh_comm is not None:
249 | k_in = min(int(k * capacity_factor // sh_comm.size()), k)
250 | else:
251 | k_in = k  # single shard (sh_comm is None) or no capacity factor: plain top-k
252 |
253 | topk = torch.topk(x, k=k_in, dim=-1)
254 | inds = topk.indices
255 | vals = topk.values
256 |
257 | if sh_comm is None:
258 | return inds, vals
259 |
260 | all_vals = torch.empty(sh_comm.size(), batch, k_in, dtype=vals.dtype, device=vals.device)
261 | sh_comm.all_gather(all_vals, vals, async_op=True)
262 |
263 | all_vals = all_vals.permute(1, 0, 2) # put shard dim next to k
264 | all_vals = all_vals.reshape(batch, -1) # flatten shard into k
265 |
266 | all_topk = torch.topk(all_vals, k=k, dim=-1)
267 | global_topk = all_topk.values
268 |
269 | dummy_vals = torch.zeros_like(vals)
270 | dummy_inds = torch.zeros_like(inds)
271 |
272 | my_inds = torch.where(vals >= global_topk[:, [-1]], inds, dummy_inds)
273 | my_vals = torch.where(vals >= global_topk[:, [-1]], vals, dummy_vals)
274 |
275 | return my_inds, my_vals
276 |
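
A small numeric illustration of the selection rule in `sharded_topk` above (hypothetical values, two shards, k=2, no capacity factor): each shard keeps its local top-k, the k-th largest value across all shards becomes the threshold, and local entries below it are zeroed out.

```python
import torch

k = 2
shard_a = torch.tensor([[5.0, 1.0, 3.0]])  # local pre-activations on shard A
shard_b = torch.tensor([[4.0, 2.0, 0.5]])  # local pre-activations on shard B

local_a, local_b = torch.topk(shard_a, k=k), torch.topk(shard_b, k=k)
all_vals = torch.cat([local_a.values, local_b.values], dim=-1)   # [[5., 3., 4., 2.]]
threshold = torch.topk(all_vals, k=k, dim=-1).values[:, [-1]]    # [[4.]]
keep_a = local_a.values >= threshold   # [[True, False]] -> shard A keeps only 5.0
keep_b = local_b.values >= threshold   # [[True, False]] -> shard B keeps only 4.0
```
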
277 |
278 | ## autoencoder
279 |
280 |
281 | class FastAutoencoder(nn.Module):
282 | """
283 | Top-K Autoencoder with sparse kernels. Implements:
284 |
285 | latents = relu(topk(encoder(x - pre_bias) + latent_bias))
286 | recons = decoder(latents) + pre_bias
287 | """
288 |
289 | def __init__(
290 | self,
291 | n_dirs_local: int,
292 | d_model: int,
293 | k: int,
294 | auxk: int | None,
295 | dead_steps_threshold: int,
296 | comms: ShardingComms | None = None,
297 | ):
298 | super().__init__()
299 | self.n_dirs_local = n_dirs_local
300 | self.d_model = d_model
301 | self.k = k
302 | self.auxk = auxk
303 | self.comms = comms if comms is not None else TRIVIAL_COMMS
304 | self.dead_steps_threshold = dead_steps_threshold
305 |
306 | self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
307 | self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
308 |
309 | self.pre_bias = nn.Parameter(torch.zeros(d_model))
310 | self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
311 |
312 | self.stats_last_nonzero: torch.Tensor
313 | self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
314 |
315 | def auxk_mask_fn(x):
316 | dead_mask = self.stats_last_nonzero > dead_steps_threshold
317 | x.data *= dead_mask # inplace to save memory
318 | return x
319 |
320 | self.auxk_mask_fn = auxk_mask_fn
321 |
322 | ## initialization
323 |
324 | # "tied" init
325 | self.decoder.weight.data = self.encoder.weight.data.T.clone()
326 |
327 | # store decoder in column major layout for kernel
328 | self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
329 |
330 | unit_norm_decoder_(self)
331 |
332 | @property
333 | def n_dirs(self):
334 | return self.n_dirs_local * self.comms.n_op_shards
335 |
336 | def forward(self, x):
337 | class EncWrapper(torch.autograd.Function):
338 | @staticmethod
339 | def forward(ctx, x, pre_bias, weight, latent_bias):
340 | x = x - pre_bias
341 | latents_pre_act = F.linear(x, weight, latent_bias)
342 |
343 | inds, vals = sharded_topk(
344 | latents_pre_act,
345 | k=self.k,
346 | sh_comm=self.comms.sh_comm,
347 | capacity_factor=4,
348 | )
349 |
350 | ## set num nonzero stat ##
351 | tmp = torch.zeros_like(self.stats_last_nonzero)
352 | tmp.scatter_add_(
353 | 0,
354 | inds.reshape(-1),
355 | (vals > 1e-3).to(tmp.dtype).reshape(-1),
356 | )
357 | self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
358 | self.stats_last_nonzero += 1
359 | ## end stats ##
360 |
361 | ## auxk
362 | if self.auxk is not None: # for auxk
363 | # IMPORTANT: has to go after stats update!
364 | # WARN: auxk_mask_fn can mutate latents_pre_act!
365 | auxk_inds, auxk_vals = sharded_topk(
366 | self.auxk_mask_fn(latents_pre_act),
367 | k=self.auxk,
368 | sh_comm=self.comms.sh_comm,
369 | capacity_factor=2,
370 | )
371 | ctx.save_for_backward(x, weight, inds, auxk_inds)
372 | else:
373 | ctx.save_for_backward(x, weight, inds)
374 | auxk_inds = None
375 | auxk_vals = None
376 |
377 | ## end auxk
378 |
379 | return (
380 | inds,
381 | vals,
382 | auxk_inds,
383 | auxk_vals,
384 | )
385 |
386 | @staticmethod
387 | def backward(ctx, _, grad_vals, __, grad_auxk_vals):
388 | # encoder backwards
389 | if self.auxk is not None:
390 | x, weight, inds, auxk_inds = ctx.saved_tensors
391 |
392 | all_inds = torch.cat((inds, auxk_inds), dim=-1)
393 | all_grad_vals = torch.cat((grad_vals, grad_auxk_vals), dim=-1)
394 | else:
395 | x, weight, inds = ctx.saved_tensors
396 |
397 | all_inds = inds
398 | all_grad_vals = grad_vals
399 |
400 | grad_sum = torch.zeros(self.n_dirs_local, dtype=torch.float32, device=grad_vals.device)
401 | grad_sum.scatter_add_(
402 | -1, all_inds.flatten(), all_grad_vals.flatten().to(torch.float32)
403 | )
404 |
405 | return (
406 | None,
407 | # pre_bias grad optimization - can reduce before mat-vec multiply
408 | -(grad_sum @ weight),
409 | triton_sparse_transpose_dense_matmul(all_inds, all_grad_vals, x, N=self.n_dirs_local),
410 | grad_sum,
411 | )
412 |
413 | pre_bias = self.comms.sh_allreduce_backward(self.pre_bias)
414 |
415 | # encoder
416 | inds, vals, auxk_inds, auxk_vals = EncWrapper.apply(
417 | x, pre_bias, self.encoder.weight, self.latent_bias
418 | )
419 |
420 | vals = torch.relu(vals)
421 | if auxk_vals is not None:
422 | auxk_vals = torch.relu(auxk_vals)
423 |
424 | recons = self.decode_sparse(inds, vals)
425 |
426 | return recons, {
427 | "auxk_inds": auxk_inds,
428 | "auxk_vals": auxk_vals,
429 | }
430 |
431 | def decode_sparse(self, inds, vals):
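        # reconstruct from (index, value) pairs with a sparse Triton matmul, then sum the
        # per-shard partial reconstructions across op shards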
432 | recons = TritonDecoderAutograd.apply(inds, vals, self.decoder.weight)
433 | recons = self.comms.sh_allreduce_forward(recons)
434 |
435 | return recons + self.pre_bias
436 |
437 |
438 | def unit_norm_decoder_(autoencoder: FastAutoencoder) -> None:
439 | """
440 | Unit normalize the decoder weights of an autoencoder.
441 | """
442 | autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
443 |
444 |
445 | def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
446 |     """Project out the gradient component parallel to each dictionary vector; assumes the decoder is already unit-normed."""
447 |
448 | assert autoencoder.decoder.weight.grad is not None
449 |
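    # for each dictionary column n: grad[:, n] -= (decoder[:, n] . grad[:, n]) * decoder[:, n]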
450 | triton_add_mul_(
451 | autoencoder.decoder.weight.grad,
452 | torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad),
453 | autoencoder.decoder.weight.data,
454 | c=-1,
455 | )
456 |
457 |
458 | def maybe_transpose(x):
459 | return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
460 |
461 |
462 | def sharded_grad_norm(autoencoder, comms, exclude=None):
463 | if exclude is None:
464 | exclude = []
465 | total_sq_norm = torch.zeros((), device="cuda", dtype=torch.float32)
466 | exclude = set(exclude)
467 |
468 | total_num_params = 0
469 | for param in autoencoder.parameters():
470 | if param in exclude:
471 | continue
472 | if param.grad is not None:
473 | sq_norm = ((param.grad).float() ** 2).sum()
474 | if param is autoencoder.pre_bias:
475 | total_sq_norm += sq_norm # pre_bias is the same across all shards
476 | else:
477 | total_sq_norm += comms.sh_sum(sq_norm)
478 |
479 |             param_shards = 1 if param is autoencoder.pre_bias else comms.n_op_shards
480 | total_num_params += param.numel() * param_shards
481 |
482 | return total_sq_norm.sqrt()
483 |
484 |
485 | def batch_tensors(
486 | it: Iterable[torch.Tensor],
487 | batch_size: int,
488 | drop_last=True,
489 | stream=None,
490 | ) -> Iterator[torch.Tensor]:
491 | """
492 | input is iterable of tensors of shape [batch_old, ...]
493 | output is iterable of tensors of shape [batch_size, ...]
494 | batch_old does not need to be divisible by batch_size
495 | """
496 |
497 | tensors = []
498 | batch_so_far = 0
499 |
500 | for t in it:
501 | tensors.append(t)
502 | batch_so_far += t.shape[0]
503 |
504 | if sum(t.shape[0] for t in tensors) < batch_size:
505 | continue
506 |
507 |         if len(tensors) == 1:
508 |             (concat,) = tensors
509 |         else:
510 |             with torch.cuda.stream(stream):
511 |                 concat = torch.cat(tensors, dim=0)
512 | 
513 |         # emit as many full batches as the concatenated buffer allows
514 |         offset = 0
515 |         while offset + batch_size <= concat.shape[0]:
516 |             yield concat[offset : offset + batch_size]
517 |             batch_so_far -= batch_size
518 |             offset += batch_size
519 | 
520 |         tensors = [concat[offset:]] if offset < concat.shape[0] else []
521 |
522 | if len(tensors) > 0 and not drop_last:
523 | yield torch.cat(tensors, dim=0)
524 |
525 |
526 | def print0(*a, **k):
527 | if RANK == 0:
528 | print(*a, **k)
529 |
530 |
531 | import wandb
532 |
533 |
534 | class Logger:
535 | def __init__(self, **kws):
536 | self.vals = {}
537 | self.enabled = (RANK == 0) and not kws.pop("dummy", False)
538 | if self.enabled:
539 | wandb.init(
540 | **kws
541 | )
542 |
543 | def logkv(self, k, v):
544 | if self.enabled:
545 | self.vals[k] = v.detach() if isinstance(v, torch.Tensor) else v
546 | return v
547 |
548 | def dumpkvs(self):
549 | if self.enabled:
550 | wandb.log(self.vals)
551 | self.vals = {}
552 |
553 |
554 | def training_loop_(
555 | ae, train_acts_iter, loss_fn, lr, comms, eps=6.25e-10, clip_grad=None, ema_multiplier=0.999, logger=None
556 | ):
557 | if logger is None:
558 | logger = Logger(dummy=True)
559 |
560 | scaler = torch.cuda.amp.GradScaler()
561 | autocast_ctx_manager = torch.cuda.amp.autocast()
562 |
563 | opt = torch.optim.Adam(ae.parameters(), lr=lr, eps=eps, fused=True)
564 | if ema_multiplier is not None:
565 | ema = EmaModel(ae, ema_multiplier=ema_multiplier)
566 |
567 | for i, flat_acts_train_batch in enumerate(train_acts_iter):
568 | flat_acts_train_batch = flat_acts_train_batch.cuda()
569 |
570 | with autocast_ctx_manager:
571 | recons, info = ae(flat_acts_train_batch)
572 |
573 | loss = loss_fn(ae, flat_acts_train_batch, recons, info, logger)
574 |
575 | print0(i, loss)
576 |
577 | logger.logkv("loss_scale", scaler.get_scale())
578 |
579 |         # route through the Logger so wandb is only touched when wandb.init was called
580 |         logger.logkv("train_loss", loss)
581 |
582 | loss = scaler.scale(loss)
583 | loss.backward()
584 |
585 |         # remove the gradient component parallel to the (unit-norm) decoder directions
586 |         unit_norm_decoder_grad_adjustment_(ae)
587 |
588 | # allreduce gradients
589 | comms.dp_allreduce_(ae)
590 |
591 | # keep fp16 loss scale synchronized across shards
592 | comms.sh_allreduce_scale(scaler)
593 |
594 |         # if you want to do anything with the gradients that depends on the absolute scale (e.g. clipping), do it after the unscale_
595 | scaler.unscale_(opt)
596 |
597 | # gradient clipping
598 | if clip_grad is not None:
599 | grad_norm = sharded_grad_norm(ae, comms)
600 | logger.logkv("grad_norm", grad_norm)
601 | grads = [x.grad for x in ae.parameters() if x.grad is not None]
602 | torch._foreach_mul_(grads, clip_grad / torch.clamp(grad_norm, min=clip_grad))
603 |
604 | if ema_multiplier is not None:
605 | ema.step()
606 |
607 | # take step with optimizer
608 | scaler.step(opt)
609 | scaler.update()
610 |         unit_norm_decoder_(ae)  # re-normalize decoder columns after the step so they stay unit norm
611 | logger.dumpkvs()
612 |
613 |
614 | def init_from_data_(ae, stats_acts_sample, comms):
615 | from geom_median.torch import compute_geometric_median
616 |
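    # pre_bias is initialized to the geometric median of an activation sample (a robust
    # center of the data); broadcast so every rank starts from the same value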
617 | ae.pre_bias.data = (
618 | compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
619 | )
620 | comms.all_broadcast(ae.pre_bias.data)
621 |
622 |     # encoder initialization (note: in our ablations we couldn't find clear evidence that this is beneficial; it is kept only for an exact match with the internal codebase)
623 | d_model = ae.d_model
624 | with torch.no_grad():
625 | x = torch.randn(256, d_model).cuda().to(stats_acts_sample.dtype)
626 | x /= x.norm(dim=-1, keepdim=True)
627 | x += ae.pre_bias.data
628 | comms.all_broadcast(x)
629 | recons, _ = ae(x)
630 | recons_norm = (recons - ae.pre_bias.data).norm(dim=-1).mean()
631 |
632 | ae.encoder.weight.data /= recons_norm.item()
633 | print0("x norm", x.norm(dim=-1).mean().item())
634 | print0("out norm", (ae(x)[0] - ae.pre_bias.data).norm(dim=-1).mean().item())
635 |
636 |
637 | from contextlib import contextmanager
638 |
639 |
640 | @contextmanager
641 | def temporary_weight_swap(model: torch.nn.Module, new_weights: list[torch.Tensor]):
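    # swap new_weights into the model's parameters for the duration of the context,
    # then swap the original weights back on exit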
642 | for _p, new_p in zip(model.parameters(), new_weights, strict=True):
643 | assert _p.shape == new_p.shape
644 | _p.data, new_p.data = new_p.data, _p.data
645 |
646 | yield
647 |
648 | for _p, new_p in zip(model.parameters(), new_weights, strict=True):
649 | assert _p.shape == new_p.shape
650 | _p.data, new_p.data = new_p.data, _p.data
651 |
652 |
653 | class EmaModel:
654 | def __init__(self, model, ema_multiplier):
655 | self.model = model
656 | self.ema_multiplier = ema_multiplier
657 | self.ema_weights = [torch.zeros_like(x, requires_grad=False) for x in model.parameters()]
658 | self.ema_steps = 0
659 |
660 | def step(self):
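        # in-place lerp with weight (1 - m): ema <- m * ema + (1 - m) * param, for every tensor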
661 | torch._foreach_lerp_(
662 | self.ema_weights,
663 | list(self.model.parameters()),
664 | 1 - self.ema_multiplier,
665 | )
666 | self.ema_steps += 1
667 |
668 | # context manager for setting the autoencoder weights to the EMA weights
669 | @contextmanager
670 | def use_ema_weights(self):
671 | assert self.ema_steps > 0
672 |
673 | # apply bias correction
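        # (same form as Adam's correction: the EMA buffers start at zero, so dividing by
        # 1 - m**t removes the bias toward zero early in training)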
674 | bias_correction = 1 - self.ema_multiplier**self.ema_steps
675 | ema_weights_bias_corrected = torch._foreach_div(self.ema_weights, bias_correction)
676 |
677 | with torch.no_grad():
678 | with temporary_weight_swap(self.model, ema_weights_bias_corrected):
679 | yield
680 |
681 |
682 | @dataclass
683 | class Config:
684 | n_op_shards: int = 1
685 | n_replicas: int = 8
686 |
687 | n_dirs: int = 32768
688 | bs: int = 131072
689 | d_model: int = 768
690 | k: int = 32
691 | auxk: int = 256
692 |
693 | lr: float = 1e-4
694 | eps: float = 6.25e-10
695 | clip_grad: float | None = None
696 | auxk_coef: float = 1 / 32
697 | dead_toks_threshold: int = 10_000_000
698 | ema_multiplier: float | None = None
699 |
700 | wandb_project: str | None = None
701 | wandb_name: str | None = None
702 |
703 |
704 | def main():
705 | cfg = Config()
706 | comms = make_torch_comms(n_op_shards=cfg.n_op_shards, n_replicas=cfg.n_replicas)
707 |
708 | ## dataloading is left as an exercise for the reader
709 | acts_iter = ...
710 | stats_acts_sample = ...
711 |
712 | n_dirs_local = cfg.n_dirs // cfg.n_op_shards
713 | bs_local = cfg.bs // cfg.n_replicas
714 |
715 | ae = FastAutoencoder(
716 | n_dirs_local=n_dirs_local,
717 | d_model=cfg.d_model,
718 | k=cfg.k,
719 | auxk=cfg.auxk,
720 | dead_steps_threshold=cfg.dead_toks_threshold // cfg.bs,
721 | comms=comms,
722 | )
723 | ae.cuda()
724 | init_from_data_(ae, stats_acts_sample, comms)
725 | # IMPORTANT: make sure all DP ranks have the same params
726 | comms.init_broadcast_(ae)
727 |
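    # normalize the reconstruction MSE by the activations' variance, i.e. the error of
    # always predicting the per-dimension mean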
728 | mse_scale = (
729 | 1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
730 | )
731 | comms.all_broadcast(mse_scale)
732 | mse_scale = mse_scale.item()
733 |
734 | logger = Logger(
735 | project=cfg.wandb_project,
736 | name=cfg.wandb_name,
737 | dummy=cfg.wandb_project is None,
738 | )
739 |
740 | training_loop_(
741 | ae,
742 | batch_tensors(
743 | acts_iter,
744 | bs_local,
745 | drop_last=True,
746 | ),
747 | lambda ae, flat_acts_train_batch, recons, info, logger: (
748 | # MSE
749 | logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
750 | # AuxK
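            # AuxK reconstructs the residual (x - recons); pre_bias is added to the target
            # because decode_sparse adds it to its output, so the two cancel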
751 | + logger.logkv(
752 | "train_maxk_recons",
753 | cfg.auxk_coef
754 | * normalized_mse(
755 | ae.decode_sparse(
756 | info["auxk_inds"],
757 | info["auxk_vals"],
758 | ),
759 | flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
760 | ).nan_to_num(0),
761 | )
762 | ),
763 | lr=cfg.lr,
764 | eps=cfg.eps,
765 | clip_grad=cfg.clip_grad,
766 | ema_multiplier=cfg.ema_multiplier,
767 | logger=logger,
768 | comms=comms,
769 | )
770 |
771 |
772 | if __name__ == "__main__":
773 | main()
774 |
--------------------------------------------------------------------------------