├── sae-viewer ├── public │ ├── robots.txt │ └── favicon.ico ├── .gitignore ├── tailwind.config.js ├── src │ ├── index.tsx │ ├── index.css │ ├── utils.ts │ ├── App.tsx │ ├── feed.tsx │ ├── components │ │ ├── histogram.tsx │ │ ├── tooltip.tsx │ │ ├── tokenHeatmap.tsx │ │ ├── tokenAblationmap.tsx │ │ ├── featureSelect.tsx │ │ └── featureInfo.tsx │ ├── interpAPI.ts │ ├── index.html │ ├── autoencoder_registry.tsx │ ├── types.ts │ ├── App.css │ └── welcome.tsx ├── README.md ├── tsconfig.json └── package.json ├── sparse_autoencoder ├── __init__.py ├── loss.py ├── paths.py ├── model.py ├── explanations.py ├── kernels.py └── train.py ├── .pre-commit-config.yaml ├── SECURITY.md ├── pyproject.toml ├── LICENSE ├── .gitignore └── README.md /sae-viewer/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /sae-viewer/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/sparse_autoencoder/HEAD/sae-viewer/public/favicon.ico -------------------------------------------------------------------------------- /sae-viewer/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | dist/ 3 | node_modules/ 4 | __pycache__/ 5 | .parcel-cache/ 6 | .cache/ 7 | blog_json/ 8 | -------------------------------------------------------------------------------- /sparse_autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Autoencoder 2 | from . import paths 3 | 4 | __all__ = ["Autoencoder"] 5 | -------------------------------------------------------------------------------- /sae-viewer/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: ["./src/**/*.{html,js,jsx}"], 4 | theme: { 5 | extend: {}, 6 | }, 7 | plugins: [], 8 | } 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: trufflehog 5 | name: TruffleHog 6 | description: Detect secrets in your data. 7 | entry: bash -c 'trufflehog git file://. --since-commit HEAD --fail --no-update' 8 | language: system 9 | stages: ["commit", "push"] 10 | -------------------------------------------------------------------------------- /sae-viewer/src/index.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | 6 | const root = ReactDOM.createRoot(document.getElementById('root')); 7 | root.render( 8 | 9 | 10 | 11 | ); 12 | -------------------------------------------------------------------------------- /sae-viewer/README.md: -------------------------------------------------------------------------------- 1 | # SAE viewer 2 | 3 | The easiest way to view activation patterns is through the 4 | [public website](https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html). 5 | This directory contains the implementation of that website. 
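The viewer is a client-side React app; feature data is fetched from public blob storage (see `src/interpAPI.ts` and `src/autoencoder_registry.tsx` for how the URLs are built). As a minimal sketch, assuming the `v5_32k` family at layer 8, location `resid_post_mlp`, feature 0 (all example values), a single feature record can be fetched directly:

```ts
// Sketch only: fetch one feature record the way the viewer does (see src/interpAPI.ts).
// Every path component (subject model, family, layer, location, feature index) is an example value.
const prefix = "https://openaipublic.blob.core.windows.net/sparse-autoencoder/viewer";
const url = `${prefix}/gpt2-small/v5_32k/layer_8/resid_post_mlp/atoms/0.json`;

async function fetchFeature() {
  const res = await fetch(url);
  if (!res.ok) throw new Error(`HTTP error: ${res.status} - ${res.statusText}`);
  // JSON record containing density, the activation histogram, and top/random activation examples
  return res.json();
}

fetchFeature().then((info) => console.log(info));
```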
6 | 7 | ## Local development 8 | 9 | Install: 10 | 11 | ```npm install``` 12 | 13 | Run: 14 | 15 | ```npm start``` 16 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | For a more in-depth look at our security policy, please check out our 3 | [Coordinated Vulnerability Disclosure Policy](https://openai.com/security/disclosure/#:~:text=Disclosure%20Policy,-Security%20is%20essential&text=OpenAI%27s%20coordinated%20vulnerability%20disclosure%20policy,expect%20from%20us%20in%20return.). 4 | 5 | Our PGP key can located [at this address.](https://cdn.openai.com/security.txt) 6 | -------------------------------------------------------------------------------- /sae-viewer/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sparse_autoencoder" 3 | description="Sparse autoencoder for GPT2" 4 | version = "0.1" 5 | authors = [{name = "OpenAI"}] 6 | dependencies = [ 7 | "blobfile == 2.0.2", 8 | "torch == 2.1.0", 9 | "transformer_lens == 1.9.1", 10 | ] 11 | readme = "README.md" 12 | 13 | [build-system] 14 | requires = ["setuptools>=64.0"] 15 | build-backend = "setuptools.build_meta" 16 | 17 | [tool.setuptools.packages.find] 18 | include = ["sparse_autoencoder*"] 19 | -------------------------------------------------------------------------------- /sae-viewer/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es2021", 4 | "module": "commonjs", 5 | "lib": ["dom", "dom.iterable", "esnext"], 6 | "allowJs": true, 7 | "skipLibCheck": true, 8 | "esModuleInterop": true, 9 | "allowSyntheticDefaultImports": true, 10 | "strict": true, 11 | "forceConsistentCasingInFileNames": true, 12 | "moduleResolution": "node", 13 | "resolveJsonModule": true, 14 | "isolatedModules": true, 15 | "noEmit": true, 16 | "jsx": "react-jsx" 17 | }, 18 | "include": ["src"] 19 | } 20 | -------------------------------------------------------------------------------- /sae-viewer/src/utils.ts: -------------------------------------------------------------------------------- 1 | export const memoizeAsync = (fnname: string, fn: any) => { 2 | return async (...args: any) => { 3 | const key = `memoized:${fnname}:${args.map((x: any) => JSON.stringify(x)).join("-")}` 4 | const val = localStorage.getItem(key); 5 | if (val === null) { 6 | const value = await fn(...args) 7 | localStorage.setItem(key, JSON.stringify(value)) 8 | console.log(`memoized ${fnname}(${args.map((x: any) => JSON.stringify(x)).join(", ")})`, value) 9 | return value 10 | } else { 11 | // console.log(`parsing`, val) 12 | return JSON.parse(val) 13 | } 14 | } 15 | } 16 | 17 | 18 | export const getQueryParams = () => { 19 | const urlParams = new URLSearchParams(window.location.search) 20 | const params: 
{[key: string]: any} = {} 21 | for (const [key, value] of urlParams.entries()) { 22 | params[key] = value 23 | } 24 | return params 25 | } 26 | -------------------------------------------------------------------------------- /sae-viewer/src/App.tsx: -------------------------------------------------------------------------------- 1 | import "./App.css" 2 | import Feed from "./feed" 3 | import React from "react" 4 | import { Routes, Route, HashRouter } from "react-router-dom" 5 | import { AUTOENCODER_FAMILIES } from "./autoencoder_registry" 6 | import Welcome from "./welcome" 7 | 8 | function App() { 9 | return ( 10 |
11 | 12 | 13 | } /> 14 | } /> 15 | { 16 | Object.values(AUTOENCODER_FAMILIES).map((family) => { 17 | let extra = ''; 18 | family.selectors.forEach((selector) => { 19 | extra += `/${selector.key}/:${selector.key}`; 20 | }) 21 | return } /> 22 | }) 23 | } 24 | 25 | 26 |
27 | ) 28 | } 29 | 30 | export default App 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sae-viewer/src/feed.tsx: -------------------------------------------------------------------------------- 1 | import FeatureInfo from "./components/featureInfo" 2 | import React, { useEffect } from "react" 3 | import Welcome from "./welcome" 4 | import { Feature } from "./types" 5 | import FeatureSelect from "./components/featureSelect" 6 | import { useState } from "react" 7 | import { useParams, useNavigate, Link } from "react-router-dom" 8 | 9 | import { pathForFeature, DEFAULT_AUTOENCODER, AUTOENCODER_FAMILIES } from "./autoencoder_registry" 10 | 11 | export default function Feed() { 12 | const params = useParams(); 13 | const navigate = useNavigate(); 14 | let family = AUTOENCODER_FAMILIES[params.family || DEFAULT_AUTOENCODER.family]; 15 | let feature: Feature = { 16 | // "layer": parseInt(params.layer), 17 | "atom": parseInt(params.atom), 18 | "autoencoder": family.get_ae(params), 19 | }; 20 | console.log('feature', JSON.stringify(feature, null, 2)) 21 | 22 | return ( 23 |
24 |
25 |

26 | SAE viewer 27 |

28 | navigate(pathForFeature(f, {replace: true}))} 31 | onFeatureSubmit={(f: Feature) => navigate(pathForFeature(f, {replace: true}))} 32 | /> 33 |
34 |
35 |
36 | 37 |
38 |
39 | ) 40 | } 41 | -------------------------------------------------------------------------------- /sae-viewer/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sae-viewer", 3 | "version": "0.1.67", 4 | "homepage": "https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html", 5 | "dependencies": { 6 | "@headlessui/react": "^1.7.8", 7 | "@headlessui/tailwindcss": "^0.1.2", 8 | "@types/d3-scale": "^4.0.3", 9 | "@types/lodash": "^4.14.194", 10 | "@types/react": "^18.0.37", 11 | "@types/react-dom": "^18.0.11", 12 | "d3-scale": "^4.0.2", 13 | "lodash": "^4.17.21", 14 | "plotly.js": "^2.31.0", 15 | "react": "^18.2.0", 16 | "react-dom": "^18.2.0", 17 | "react-plotly.js": "^2.6.0", 18 | "react-router-dom": "^6.10.0", 19 | "web-vitals": "^3.0.3" 20 | }, 21 | "scripts": { 22 | "start": "rm -rf dist && rm -rf .parcel-cache && parcel src/index.html", 23 | "build": "parcel build src/index.html", 24 | "serve": "parcel serve src/index.html", 25 | "typecheck": "tsc -p ." 26 | }, 27 | "eslintConfig": { 28 | "extends": [ 29 | "react-app" 30 | ] 31 | }, 32 | "alias": { 33 | "preact/jsx-dev-runtime": "preact/jsx-runtime" 34 | }, 35 | "devDependencies": { 36 | "@observablehq/plot": "^0.6.5", 37 | "@parcel/transformer-typescript-tsc": "^2.8.3", 38 | "@parcel/validator-typescript": "^2.8.3", 39 | "buffer": "^5.7.1", 40 | "nodemon": "^2.0.22", 41 | "parcel": "^2.8.3", 42 | "preact": "^10.13.2", 43 | "process": "^0.11.10", 44 | "react-refresh": "0.10.0", 45 | "tailwindcss": "^3.2.4", 46 | "typescript": "^5.0.4" 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /sae-viewer/src/components/histogram.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import Plot from 'react-plotly.js'; 3 | 4 | // TODO get from data 5 | const BIN_WIDTH = 0.2; 6 | // # bins_fn = lambda lats: (lats / BIN_WIDTH).ceil().int() 7 | // bin_fn = lambda val: math.ceil(val / BIN_WIDTH) 8 | // bin_id_to_lower_bound = lambda xs: xs * BIN_WIDTH 9 | 10 | const HistogramDisplay = ({ data }) => { 11 | // min_bin = min(hist.keys()) 12 | // max_bin = max(hist.keys()) 13 | // ys = [hist.get(x, 0) for x in np.arange(min_bin, max_bin + 1)] 14 | // xs = np.arange(min_bin, max_bin + 2) 15 | // xs = bin_id_to_lower_bound(np.array(xs)) 16 | const min_bin = Math.min(...Object.keys(data).map(Number)); 17 | const max_bin = Math.max(...Object.keys(data).map(Number)); 18 | const ys = Array.from({length: max_bin - min_bin + 1}, (_, i) => data[min_bin + i] || 0); 19 | let xs = Array.from({length: max_bin - min_bin + 2}, (_, i) => min_bin + i); 20 | xs = xs.map(x => x * BIN_WIDTH); 21 | 22 | const trace = { 23 | line: {shape: 'hvh'}, 24 | mode: 'lines', 25 | type: 'scatter', 26 | x: xs, 27 | y: ys, 28 | fill: 'tozeroy', 29 | }; 30 | const layout = { 31 | legend: { 32 | y: 0.5, 33 | font: {size: 16}, 34 | traceorder: 'reversed', 35 | }, 36 | yaxis: { 37 | type: 'log', 38 | autorange: true 39 | }, 40 | margin: { l: 30, r: 0, b: 20, t: 0 }, 41 | autosize: true, 42 | }; 43 | 44 | return ( 45 | 50 | ) 51 | } 52 | 53 | export default HistogramDisplay; 54 | -------------------------------------------------------------------------------- /sae-viewer/src/components/tooltip.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | 3 | type Props = { 4 | content: React.ReactNode; 5 | tooltip: React.ReactNode; 6 
| tooltipStyle?: React.CSSProperties; 7 | }; 8 | 9 | function useOutsideClickAlerter(ref, fn) { 10 | React.useEffect(() => { 11 | /** 12 | * Alert if clicked on outside of element 13 | */ 14 | function handleClickOutside(event) { 15 | if (ref.current && !ref.current.contains(event.target)) { fn() } 16 | } 17 | // Bind the event listener 18 | document.addEventListener("mousedown", handleClickOutside); 19 | return () => { 20 | // Unbind the event listener on clean up 21 | document.removeEventListener("mousedown", handleClickOutside); 22 | }; 23 | }, [ref]); 24 | } 25 | 26 | const Tooltip: React.FunctionComponent = (props: Props) => { 27 | const [state, set_state] = React.useState({ force_open: false }); 28 | const [hover, set_hover] = React.useState(false); 29 | const wrapperRef = React.useRef(null); 30 | useOutsideClickAlerter(wrapperRef, () => set_state({ force_open: false })); 31 | return ( 32 | 33 | set_state({ force_open: !state.force_open })} 34 | onMouseOver={() => set_hover(true)} 35 | onMouseOut={() => setTimeout(() => set_hover(false), 100)} 36 | > 37 | {props.content} 38 | 39 | 42 | {props.tooltip} 43 | 44 | 45 | ); 46 | }; 47 | 48 | 49 | export default Tooltip; 50 | -------------------------------------------------------------------------------- /sparse_autoencoder/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def autoencoder_loss( 5 | reconstruction: torch.Tensor, 6 | original_input: torch.Tensor, 7 | latent_activations: torch.Tensor, 8 | l1_weight: float, 9 | ) -> torch.Tensor: 10 | """ 11 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs]) 12 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 13 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents]) 14 | :param l1_weight: weight of L1 loss 15 | :return: loss (shape: [1]) 16 | """ 17 | return ( 18 | normalized_mean_squared_error(reconstruction, original_input) 19 | + normalized_L1_loss(latent_activations, original_input) * l1_weight 20 | ) 21 | 22 | 23 | def normalized_mean_squared_error( 24 | reconstruction: torch.Tensor, 25 | original_input: torch.Tensor, 26 | ) -> torch.Tensor: 27 | """ 28 | :param reconstruction: output of Autoencoder.decode (shape: [batch, n_inputs]) 29 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 30 | :return: normalized mean squared error (shape: [1]) 31 | """ 32 | return ( 33 | ((reconstruction - original_input) ** 2).mean(dim=1) / (original_input**2).mean(dim=1) 34 | ).mean() 35 | 36 | 37 | def normalized_L1_loss( 38 | latent_activations: torch.Tensor, 39 | original_input: torch.Tensor, 40 | ) -> torch.Tensor: 41 | """ 42 | :param latent_activations: output of Autoencoder.encode (shape: [batch, n_latents]) 43 | :param original_input: input of Autoencoder.encode (shape: [batch, n_inputs]) 44 | :return: normalized L1 loss (shape: [1]) 45 | """ 46 | return (latent_activations.abs().sum(dim=1) / original_input.norm(dim=1)).mean() 47 | -------------------------------------------------------------------------------- /sae-viewer/src/components/tokenHeatmap.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | import { interpolateColor, Color, getInterpolatedColor, DEFAULT_COLORS, DEFAULT_BOUNDARIES, SequenceInfo } from '../types' 3 | import Tooltip from './tooltip' 4 | 5 | type Props = { 6 | info: SequenceInfo, 7 | colors?: Color[], 8 | boundaries?: 
number[] 9 | renderNewlines?: boolean, 10 | } 11 | 12 | function zip_sequence(sequence: SequenceInfo) { 13 | return sequence.tokens.map((token, idx) => ({ 14 | token, 15 | highlight: idx === sequence.idx, 16 | activation: sequence.acts[idx], 17 | normalized_activation: sequence.normalized_acts ? sequence.normalized_acts[idx] : undefined 18 | })); 19 | } 20 | 21 | export default function TokenHeatmap({ info, colors = DEFAULT_COLORS, boundaries = DEFAULT_BOUNDARIES, renderNewlines }: Props) { 22 | //
23 | const zipped = zip_sequence(info) 24 | return ( 25 |
26 | {zipped.map(({ token, activation, normalized_activation, highlight }, i) => { 27 | const color = getInterpolatedColor(colors, boundaries, normalized_activation || activation); 28 | if (!renderNewlines) { 29 | token = token.replace(/\n/g, '↵') 30 | } 31 | return 40 | {token} 41 | 42 | } 43 | tooltip={
Activation: {activation.toFixed(2)}
} 44 | key={i} 45 | /> 46 | })} 47 |
48 | ) 49 | } 50 | -------------------------------------------------------------------------------- /sae-viewer/src/interpAPI.ts: -------------------------------------------------------------------------------- 1 | import {Feature, FeatureInfo} from './types'; 2 | import {memoizeAsync} from "./utils" 3 | 4 | export const load_file_no_cache = async(path: string) => { 5 | const data = { 6 | path: path 7 | } 8 | const url = new URL("/load_az", window.location.href) 9 | url.port = '8000'; 10 | return await ( 11 | await fetch(url, { 12 | method: "POST", // or 'PUT' 13 | headers: { 14 | "Content-Type": "application/json", 15 | }, 16 | body: JSON.stringify(data), 17 | }) 18 | ).json() 19 | 20 | } 21 | 22 | export const load_file_az = async(path: string) => { 23 | const res = ( 24 | await fetch(path, { 25 | method: "GET", 26 | mode: "cors", 27 | headers: { 28 | "Content-Type": "application/json", 29 | }, 30 | }) 31 | ) 32 | if (!res.ok) { 33 | console.error(`HTTP error: ${res.status} - ${res.statusText}`); 34 | return; 35 | } 36 | return await res.json() 37 | } 38 | 39 | 40 | // export const load_file = memoizeAsync('load_file', load_file_no_cache) 41 | // export const load_file = window.location.host.indexOf('localhost:') === -1 ? load_file_az : load_file_no_cache; 42 | export const load_file = load_file_no_cache; 43 | 44 | 45 | export async function get_feature_info(feature: Feature, ablated?: boolean): Promise { 46 | let load_fn = load_file_az; 47 | let prefix = "https://openaipublic.blob.core.windows.net/sparse-autoencoder/viewer" 48 | if (window.location.host.indexOf('localhost:') !== -1) { 49 | load_fn = load_file; 50 | prefix = "az://openaipublic/sparse-autoencoder/viewer" 51 | // prefix = az://oaialignment/interp/autoencoder-vis/ae 52 | } 53 | 54 | const ae = feature.autoencoder; 55 | const result = await load_fn(`${prefix}/${ae.subject}/${ae.path}/atoms/${feature.atom}${ablated ? 
'-ablated': ''}.json`) 56 | // console.log('result', result) 57 | return result 58 | } 59 | -------------------------------------------------------------------------------- /sparse_autoencoder/paths.py: -------------------------------------------------------------------------------- 1 | 2 | def v1(location, layer_index): 3 | """ 4 | Details: 5 | - Number of autoencoder latents: 32768 6 | - Number of training tokens: ~64M 7 | - Activation function: ReLU 8 | - L1 regularization strength: 0.01 9 | - Layer normed inputs: false 10 | - NeuronRecord files: 11 | `az://openaipublic/sparse-autoencoder/gpt2-small/{location}/collated_activations/{layer_index}/{latent_index}.json` 12 | """ 13 | assert location in ["mlp_post_act", "resid_delta_mlp"] 14 | assert layer_index in range(12) 15 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}/autoencoders/{layer_index}.pt" 16 | 17 | def v4(location, layer_index): 18 | """ 19 | Details: 20 | same as v1 21 | """ 22 | assert location in ["mlp_post_act", "resid_delta_mlp"] 23 | assert layer_index in range(12) 24 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v4/autoencoders/{layer_index}.pt" 25 | 26 | def v5_32k(location, layer_index): 27 | """ 28 | Details: 29 | - Number of autoencoder latents: 2**15 = 32768 30 | - Number of training tokens: TODO 31 | - Activation function: TopK(32) 32 | - L1 regularization strength: n/a 33 | - Layer normed inputs: true 34 | """ 35 | assert location in ["resid_delta_attn", "resid_delta_mlp", "resid_post_attn", "resid_post_mlp"] 36 | assert layer_index in range(12) 37 | # note: it's actually 2**15 and 2**17 ~= 131k 38 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_32k/autoencoders/{layer_index}.pt" 39 | 40 | def v5_128k(location, layer_index): 41 | """ 42 | Details: 43 | - Number of autoencoder latents: 2**17 = 131072 44 | - Number of training tokens: TODO 45 | - Activation function: TopK(32) 46 | - L1 regularization strength: n/a 47 | - Layer normed inputs: true 48 | """ 49 | assert location in ["resid_delta_attn", "resid_delta_mlp", "resid_post_attn", "resid_post_mlp"] 50 | assert layer_index in range(12) 51 | # note: it's actually 2**15 and 2**17 ~= 131k 52 | return f"az://openaipublic/sparse-autoencoder/gpt2-small/{location}_v5_128k/autoencoders/{layer_index}.pt" 53 | 54 | # NOTE: we have larger autoencoders (up to 8M, with varying n and k) trained on layer 8 resid_post_mlp 55 | # we may release them in the future 56 | -------------------------------------------------------------------------------- /sae-viewer/src/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 15 | 24 | 25 | 26 | SAE viewer 27 | 28 | 29 | 30 | 31 | 42 | 43 | 44 | 45 |
46 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse autoencoders 2 | 3 | This repository hosts: 4 | - sparse autoencoders trained on the GPT2-small model's activations. 5 | - a visualizer for the autoencoders' features 6 | 7 | ### Install 8 | 9 | ```sh 10 | pip install git+https://github.com/openai/sparse_autoencoder.git 11 | ``` 12 | 13 | ### Code structure 14 | 15 | See [sae-viewer](./sae-viewer/README.md) to see the visualizer code, hosted publicly [here](https://openaipublic.blob.core.windows.net/sparse-autoencoder/sae-viewer/index.html). 16 | 17 | See [model.py](./sparse_autoencoder/model.py) for details on the autoencoder model architecture. 18 | See [train.py](./sparse_autoencoder/train.py) for autoencoder training code. 
19 | See [paths.py](./sparse_autoencoder/paths.py) for more details on the available autoencoders. 20 | 21 | ### Example usage 22 | 23 | ```py 24 | import torch 25 | import blobfile as bf 26 | import transformer_lens 27 | import sparse_autoencoder 28 | 29 | # Extract neuron activations with transformer_lens 30 | model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False) 31 | device = next(model.parameters()).device 32 | 33 | prompt = "This is an example of a prompt that" 34 | tokens = model.to_tokens(prompt) # (1, n_tokens) 35 | with torch.no_grad(): 36 | logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True) 37 | 38 | layer_index = 6 39 | location = "resid_post_mlp" 40 | 41 | transformer_lens_loc = { 42 | "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post", 43 | "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out", 44 | "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid", 45 | "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out", 46 | "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post", 47 | }[location] 48 | 49 | with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f: 50 | state_dict = torch.load(f) 51 | autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict) 52 | autoencoder.to(device) 53 | 54 | input_tensor = activation_cache[transformer_lens_loc] 55 | 56 | input_tensor_ln = input_tensor 57 | 58 | with torch.no_grad(): 59 | latent_activations, info = autoencoder.encode(input_tensor_ln) 60 | reconstructed_activations = autoencoder.decode(latent_activations, info) 61 | 62 | normalized_mse = (reconstructed_activations - input_tensor).pow(2).sum(dim=1) / (input_tensor).pow(2).sum(dim=1) 63 | print(location, normalized_mse) 64 | ``` 65 | -------------------------------------------------------------------------------- /sae-viewer/src/autoencoder_registry.tsx: -------------------------------------------------------------------------------- 1 | import { Autoencoder, Feature } from './types'; 2 | import { Route } from "react-router-dom" 3 | 4 | 5 | export const GPT2_LAYER_FAMILY_32k = { 6 | subject: 'gpt2-small', 7 | name: 'v5_32k', 8 | label: 'n=32768, k=32, all locations', 9 | selectors: [ 10 | {key: 'layer', label: 'Layer', values: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']}, 11 | {key: 'location', label: 'Location', values: ['resid_post_attn', 'resid_post_mlp']}, 12 | ], 13 | default_H: (H: {[key: string]: string}) => ({ 14 | layer: H.layer || '8', location: H.location || 'resid_post_mlp' 15 | }), 16 | get_ae: (H: {[key: string]: string}) => ({ 17 | subject: 'gpt2-small', 18 | family: 'v5_32k', 19 | H: {layer: H.layer, location: H.location}, 20 | path: `v5_32k/layer_${H.layer}/${H.location}`, 21 | num_features: 32768, 22 | }), 23 | } 24 | 25 | export const GPT2_LAYER_FAMILY_128k = { 26 | subject: 'gpt2-small', 27 | name: 'v5_128k', 28 | label: 'n=131072, k=32, all locations', 29 | selectors: [ 30 | {key: 'layer', label: 'Layer', values: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']}, 31 | {key: 'location', label: 'Location', values: ['resid_post_attn', 'resid_post_mlp']}, 32 | ], 33 | default_H: (H: {[key: string]: string}) => ({ 34 | layer: H.layer || '8', location: H.location || 'resid_post_mlp' 35 | }), 36 | get_ae: (H: {[key: string]: string}) => ({ 37 | subject: 'gpt2-small', 38 | family: 'v5_128k', 39 | H: {layer: H.layer, location: H.location}, 40 | path: `v5_128k/layer_${H.layer}/${H.location}`, 41 | num_features: 
131072, 42 | no_effects: true, 43 | }), 44 | } 45 | 46 | export const GPT2_NK_SWEEP = { 47 | subject: 'gpt2-small', 48 | name: 'v5_l8_postmlp', 49 | label: 'layer 8, resid post MLP, n/k sweep', 50 | selectors: [ 51 | {key: 'num_features', label: 'Total features', values: ['2048', '8192', '32768', '131072', '524288', '2097152']}, 52 | {key: 'num_active_features', label: 'Active features', values: ['8', '16', '32', '64', '128', '256', '512']}, 53 | ], 54 | default_H: (H: {[key: string]: string}) => ({ 55 | num_features: H.num_features || '2097152', num_active_features: H.num_active_features || '16' 56 | }), 57 | get_ae: (H: {[key: string]: string}) => ({ 58 | subject: 'gpt2-small', 59 | family: 'v5_l8_postmlp', 60 | H: {num_features: H.num_features, num_active_features: H.num_active_features}, 61 | path: `v5_l8_postmlp/n${H.num_features}/k${H.num_active_features}`, 62 | num_features: H.num_features, 63 | }), 64 | } 65 | 66 | export const GPT4_16m = { 67 | subject: 'gpt4', 68 | name: 'v5_latelayer_postmlp', 69 | label: 'n=16M', 70 | warning: 'Only 65536 features available. Activations shown on The Pile (uncopyrighted) instead of our internal training dataset.', 71 | selectors: [], 72 | default_H: (H: {[key: string]: string}) => ({}), 73 | get_ae: (H: {[key: string]: string}) => ({ 74 | subject: 'gpt4', 75 | family: 'v5_latelayer_postmlp', 76 | H: {}, 77 | path: `v5_latelayer_postmlp/n16777216/k256`, 78 | num_features: 65536, 79 | no_effects: true, 80 | }), 81 | } 82 | 83 | // export const DEFAULT_AUTOENCODER = GPT2_NK_SWEEP.get_ae( 84 | // GPT2_NK_SWEEP.default_H({}) 85 | // ); 86 | export const DEFAULT_AUTOENCODER = GPT4_16m.get_ae( 87 | GPT4_16m.default_H({}) 88 | ); 89 | 90 | export const AUTOENCODER_FAMILIES = Object.fromEntries( 91 | [ 92 | GPT2_NK_SWEEP, 93 | GPT2_LAYER_FAMILY_32k, 94 | GPT2_LAYER_FAMILY_128k, 95 | GPT4_16m, 96 | ].map((family) => [family.name, family]) 97 | ); 98 | 99 | export const SUBJECT_MODELS = ['gpt2-small', 'gpt4']; 100 | 101 | export function pathForFeature(feature: Feature) { 102 | let res = `/model/${feature.autoencoder.subject}/family/${feature.autoencoder.family}`; 103 | // for (const [key, value] of Object.entries(feature.autoencoder.H)) { 104 | // res += `/${key}/${value}`; 105 | // } 106 | for (const selector of AUTOENCODER_FAMILIES[feature.autoencoder.family].selectors) { 107 | res += `/${selector.key}/${feature.autoencoder.H[selector.key]}`; 108 | } 109 | res += `/feature/${feature.atom}`; 110 | console.log('res', res) 111 | return res 112 | } 113 | 114 | -------------------------------------------------------------------------------- /sae-viewer/src/types.ts: -------------------------------------------------------------------------------- 1 | import { scaleLinear } from "d3-scale" 2 | import { min, max, flatten } from "lodash" 3 | 4 | export type Autoencoder = { 5 | subject: string, 6 | num_features: number, 7 | family: string, 8 | H: {[key: string]: any}, 9 | path: string, 10 | }; 11 | 12 | 13 | export type Feature = { 14 | autoencoder: Autoencoder; 15 | atom: number; 16 | } 17 | 18 | export type TokenAndActivation = { 19 | token: string, 20 | activation: number 21 | normalized_activation?: number 22 | } 23 | export type TokenSequence = TokenAndActivation[] 24 | 25 | export type SequenceInfo = { 26 | density: number, 27 | doc_id: number, 28 | idx: number, // which act this document was selected for 29 | acts: number[], 30 | act: number, 31 | tokens: string[], 32 | token_ints: number[], 33 | normalized_acts?: number[], 34 | ablate_loss_diff?: 
number[], 35 | kl?: number[], 36 | top_downvote_tokens_logits?: string[][], 37 | top_downvotes_logits?: number[][], 38 | top_upvote_tokens_logits?: string[][], 39 | top_upvotes_logits?: number[][], 40 | top_downvote_tokens_probs?: string[][], 41 | top_downvotes_probs?: number[][], 42 | top_upvote_tokens_probs?: string[][], 43 | top_upvotes_probs?: number[][], 44 | } 45 | 46 | export function zip_sequence(sequence: SequenceInfo) { 47 | return sequence.tokens.map((token, idx) => ({ 48 | token, 49 | highlight: idx === sequence.idx, 50 | activation: sequence.acts[idx], 51 | normalized_activation: sequence.normalized_acts ? sequence.normalized_acts[idx] : undefined 52 | })); 53 | } 54 | 55 | export type FeatureInfo = { 56 | density: number, 57 | mean_act: number, 58 | mean_act_squared: number, 59 | hist: {[key: number]: number}, 60 | random: SequenceInfo[], 61 | top: SequenceInfo[], 62 | } 63 | 64 | export const normalizeSequences = (...sequences: SequenceInfo[][]) => { 65 | // console.log('sequences', sequences) 66 | let flattened: SequenceInfo[] = flatten(sequences) 67 | const maxActivation = Math.max(0, ...flattened.map((s) => Math.max(...s.acts))); 68 | const scaler = scaleLinear() 69 | // Even though we're only displaying positive activations, we still need to scale in a way that 70 | // accounts for the existence of negative activations, since our color scale includes them. 71 | .domain([0, maxActivation]) 72 | .range([0, 1]) 73 | 74 | sequences.map((seqs) => seqs.map((s) => { 75 | s.normalized_acts = s.acts.map((activation) => scaler(activation)); 76 | })) 77 | } 78 | 79 | export const normalizeTokenActs = (...sequences: TokenSequence[][]) => { 80 | // console.log('sequences', sequences) 81 | let flattened: TokenAndActivation[] = flatten(flatten(sequences)) 82 | // Replace all activations less than 0 in data.tokens with 0. This matches the format in the 83 | // top + random activation records displayed in the main grid. 84 | flattened = flattened.map(({token, activation}) => { 85 | return { 86 | token, 87 | activation: Math.max(activation, 0) 88 | } 89 | }) 90 | const maxActivation = max(flattened.map((ta) => ta.activation)) || 0; 91 | const scaler = scaleLinear() 92 | // Even though we're only displaying positive activations, we still need to scale in a way that 93 | // accounts for the existence of negative activations, since our color scale includes them. 
94 | .domain([0, maxActivation]) 95 | .range([0, 1]) 96 | 97 | return sequences.map((seq) => seq.map((tas) => tas.map(({ token, activation }) => ({ 98 | token, 99 | activation, 100 | normalized_activation: scaler(activation), 101 | })))) 102 | } 103 | 104 | export type Color = {r: number, g: number, b: number}; 105 | export function interpolateColor(color_l: Color, color_r: Color, value: number) { 106 | const color = { 107 | r: Math.round(color_l.r + (color_r.r - color_l.r) * value), 108 | g: Math.round(color_l.g + (color_r.g - color_l.g) * value), 109 | b: Math.round(color_l.b + (color_r.b - color_l.b) * value), 110 | } 111 | return color 112 | } 113 | 114 | export function getInterpolatedColor(colors: Color[], boundaries: number[], value: number) { 115 | const index = boundaries.findIndex((boundary) => boundary >= value) 116 | const colorIndex = Math.max(0, index - 1) 117 | const color_left = colors[colorIndex] 118 | const color_right = colors[colorIndex + 1] 119 | const boundary_left = boundaries[colorIndex] 120 | const boundary_right = boundaries[colorIndex + 1] 121 | const ratio = (value - boundary_left) / (boundary_right - boundary_left) 122 | const color = interpolateColor(color_left, color_right, ratio) 123 | return color 124 | } 125 | 126 | export const DEFAULT_COLORS = [ 127 | { r: 255, g: 0, b: 0 }, 128 | { r: 255, g: 255, b: 255 }, 129 | { r: 0, g: 255, b: 0 }, 130 | ] 131 | export const DEFAULT_BOUNDARIES = [ 132 | -1, 0, 1 133 | ] 134 | 135 | -------------------------------------------------------------------------------- /sae-viewer/src/components/tokenAblationmap.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | import { interpolateColor, Color, getInterpolatedColor, DEFAULT_COLORS, SequenceInfo } from '../types' 3 | import Tooltip from './tooltip' 4 | import { scaleLinear } from "d3-scale" 5 | 6 | type Props = { 7 | info: SequenceInfo, 8 | colors?: Color[], 9 | boundaries?: number[], 10 | renderNewlines?: boolean, 11 | } 12 | 13 | export const normalizeToUnitInterval = (arr: number[]) => { 14 | const max = Math.max(...arr); 15 | const min = Math.min(...arr); 16 | const max_abs = Math.max(Math.abs(max), Math.abs(min)); 17 | const rescale = scaleLinear() 18 | // Even though we're only displaying positive activations, we still need to scale in a way that 19 | // accounts for the existence of negative activations, since our color scale includes them. 20 | .domain([-max_abs, max_abs]) 21 | .range([-1, 1]) 22 | 23 | return arr.map((x) => rescale(x)); 24 | } 25 | 26 | 27 | export default function TokenAblationmap({ info, colors = DEFAULT_COLORS, renderNewlines }: Props) { 28 | //
29 | if (!info.ablate_loss_diff) { 30 | return <> ; 31 | } 32 | const lossDiffsNorm = normalizeToUnitInterval(info.ablate_loss_diff.map((x) => (-x))); 33 | return ( 34 |
35 | {info.tokens.map((token, idx) => { 36 | const highlight = idx === info.idx; 37 | const loss_diff = (idx === 0) ? 0: info.ablate_loss_diff[idx-1]; 38 | const kl = (idx === 0) ? 0: info.kl[idx-1]; 39 | const activation = info.acts[idx]; 40 | const top_downvotes = (idx === 0) ? [] : info.top_downvotes_logits[idx-1]; 41 | const top_downvote_tokens = (idx === 0) ? [] : info.top_downvote_tokens_logits[idx-1]; 42 | const top_upvotes = (idx === 0) ? [] : info.top_upvotes_logits[idx-1]; 43 | const top_upvote_tokens = (idx === 0) ? [] : info.top_upvote_tokens_logits[idx-1]; 44 | // const top_downvotes_weighted = (idx === 0) ? [] : info.top_downvotes_weighted[idx-1]; 45 | // const top_downvote_tokens_weighted = (idx === 0) ? [] : info.top_downvote_tokens_weighted[idx-1]; 46 | // const top_upvotes_weighted = (idx === 0) ? [] : info.top_upvotes_weighted[idx-1]; 47 | // const top_upvote_tokens_weighted = (idx === 0) ? [] : info.top_upvote_tokens_weighted[idx-1]; 48 | const top_downvotes_probs = (idx === 0) ? [] : info.top_downvotes_probs[idx-1]; 49 | const top_downvote_tokens_probs = (idx === 0) ? [] : info.top_downvote_tokens_probs[idx-1]; 50 | const top_upvotes_probs = (idx === 0) ? [] : info.top_upvotes_probs[idx-1]; 51 | const top_upvote_tokens_probs = (idx === 0) ? [] : info.top_upvote_tokens_probs[idx-1]; 52 | const color = getInterpolatedColor(colors, [-1, 0, 1], (idx === 0) ? 0 : lossDiffsNorm[idx-1]); 53 | if (!renderNewlines) { 54 | token = token.replace(/\n/g, '↵') 55 | } 56 | return 65 | {token} 66 | 67 | } 68 | tooltip={(idx <= info.idx) ?
(prediction prior to ablated token)
:
69 | Loss diff: {loss_diff.toExponential(2)}
70 | KL(clean || ablated): {kl.toExponential(2)}
71 | Logit diffs: 72 | 73 | 74 | 75 | { 76 | ['', /*' (weighted)',*/ ' (probs)'].map((suffix, i) => { 77 | return 78 | 79 | 80 | 81 | }) 82 | } 83 | 84 | 85 | 86 | { 87 | top_upvotes.map((upvote, j) => { 88 | const downvote = top_downvotes[j]; 89 | return 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | }) 100 | } 101 | 102 |
Upvoted {suffix}Downvoted {suffix}
{upvote.toExponential(1)}{top_upvote_tokens[j]}{downvote.toExponential(1)}{top_downvote_tokens[j]}{top_upvotes_probs[j].toExponential(1)}{top_upvote_tokens_probs[j]}{top_downvotes_probs[j].toExponential(1)}{top_downvote_tokens_probs[j]}
103 |
} 104 | key={idx} 105 | /> 106 | })} 107 |
108 | ) 109 | 110 | } 111 | -------------------------------------------------------------------------------- /sae-viewer/src/components/featureSelect.tsx: -------------------------------------------------------------------------------- 1 | import FeatureInfo from "./components/featureInfo" 2 | import React, { useEffect, useState, FormEvent } from "react" 3 | import Welcome from "./welcome" 4 | import { Feature } from "./types" 5 | import { useState } from "react" 6 | 7 | import { pathForFeature, DEFAULT_AUTOENCODER, SUBJECT_MODELS, AUTOENCODER_FAMILIES } from "../autoencoder_registry" 8 | 9 | type FeatureSelectProps = { 10 | init_feature: Feature, 11 | onFeatureChange?: (feature: Feature) => void, 12 | onFeatureSubmit: (feature: Feature) => void, 13 | } 14 | 15 | export default function FeatureSelect({init_feature, onFeatureChange, onFeatureSubmit, show_go}: FeatureSelectProps) { 16 | let [feature, setFeature] = useState(init_feature); 17 | let family = AUTOENCODER_FAMILIES[feature.autoencoder.family]; 18 | let [warningAcknowledged, setWarningAcknowledged] = useState(localStorage.getItem('warningAcknowledged') === 'true'); 19 | let changeFeature = (feature: Feature) => { 20 | onFeatureChange && onFeatureChange(feature); 21 | setFeature(feature); 22 | } 23 | // console.log('features', feature.autoencoder.num_features) 24 | const feelingLuckySubmit = () => { 25 | const atom = Math.floor(Math.random() * feature.autoencoder.num_features); 26 | const random_feature: Feature = {atom, autoencoder: feature.autoencoder} 27 | setFeature(random_feature); 28 | onFeatureSubmit(random_feature); 29 | return false 30 | } 31 | const acknowledgeWarning = () => { 32 | localStorage.setItem('warningAcknowledged', 'true'); 33 | setWarningAcknowledged(true); 34 | } 35 | 36 | if (!warningAcknowledged) { 37 | return () 43 | } 44 | 45 | return ( 46 | <> 47 |

48 | 49 | Subject model {" "} 50 | 60 | 61 | {" "} 62 |
63 | 64 | Autoencoder {" "} 65 | family {" "} 66 | 76 | 77 | { 78 | AUTOENCODER_FAMILIES[feature.autoencoder.family].selectors.map((selector) => ( 79 | 80 | {" "} 81 | {selector.label || selector.key} 82 | 94 | 95 | )) 96 | } 97 | { 98 | family.warning ? Note: {family.warning} : null 99 | } 100 | 101 |
102 | Feature 103 | (!isNaN(parseInt(e.target.value))) && changeFeature({...feature, atom: parseInt(e.target.value)})} 111 | className="border border-gray-300 rounded-md p-2" 112 | /> 113 | { 114 | show_go && 115 | 128 | } 129 |
130 | 137 | 138 | 139 |
140 | 141 |

142 | 143 | ) 144 | } 145 | -------------------------------------------------------------------------------- /sae-viewer/src/components/featureInfo.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState, useRef } from "react" 2 | import { normalizeSequences, SequenceInfo, Feature, FeatureInfo } from "../types" 3 | import TokenHeatmap from "./tokenHeatmap"; 4 | import TokenAblationmap from "./tokenAblationmap"; 5 | import Histogram from "./histogram" 6 | import Tooltip from "./tooltip" 7 | 8 | import {get_feature_info} from "../interpAPI" 9 | 10 | export default ({ feature }: {feature: Feature}) => { 11 | const [data, setData] = useState(null as FeatureInfo | null) 12 | const [showingMore, setShowingMore] = useState({}) 13 | const [renderNewlines, setRenderNewlines] = useState(false) 14 | const [isLoading, setIsLoading] = useState(true) 15 | const [got_error, setError] = useState(null) 16 | const currentFeatureRef = useRef(feature); 17 | 18 | 19 | useEffect(() => { 20 | async function fetchData() { 21 | setIsLoading(true) 22 | try { 23 | currentFeatureRef.current = feature; // Update current feature in ref on each effect run 24 | const result = await get_feature_info(feature) 25 | if (currentFeatureRef.current !== feature) { 26 | return; 27 | } 28 | normalizeSequences(result.top, result.random) 29 | result.top.sort((a, b) => b.act - a.act); 30 | setData(result) 31 | setIsLoading(false) 32 | setError(null); 33 | } catch (e) { 34 | setError(e); 35 | } 36 | try { 37 | const result = await get_feature_info(feature, true) 38 | if (currentFeatureRef.current !== feature) { 39 | return; 40 | } 41 | normalizeSequences(result.top, result.random) 42 | result.top.sort((a, b) => b.act - a.act); 43 | setData(result) 44 | setIsLoading(false) 45 | setError(null); 46 | } catch (e) { 47 | setError('Note: ablation effects data not available for this model'); 48 | } 49 | } 50 | fetchData() 51 | }, [feature]) 52 | 53 | if (isLoading) { 54 | return ( 55 |
56 |
57 |
loading top dataset examples
58 | { 59 | got_error ? Error loading data: {got_error} : null 60 | } 61 |
62 | ) 63 | } 64 | if (!data) { 65 | throw new Error('no data. this should not happen.') 66 | } 67 | 68 | const all_sequences = [] 69 | all_sequences.push({ 70 | // label: '[0, 1] (Random)', 71 | label: 'Random positive activations', 72 | sequences: data.random, 73 | default_show: 5, 74 | }) 75 | all_sequences.push({ 76 | // label: '[0.999, 1] (Top quantile, sorted. 50 of 50000)', 77 | label: 'Top activations', 78 | sequences: data.top, 79 | default_show: 5, 80 | }) 81 | 82 | // const activations = data.top_activations; 83 | return ( 84 |
85 | { 86 | got_error ? {got_error} : null 87 | } 88 |
89 |
90 | 91 |
92 | 93 | 94 | 95 | 97 | 98 | 99 | 100 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 |
0]'}/> 96 | {data.density.toExponential(2)}
101 | {data.mean_act ? data.mean_act.toExponential(2) : 'data not available'}
E[a2]}/>{data.mean_act_squared ? data.mean_act_squared.toExponential(2): 'data not available'}
E[a3]/(E[a2])1.5}/>{data.skew ? data.skew.toExponential(2) : 'data not available'}
E[a4]/(E[a2])2}/>{data.kurtosis ? data.kurtosis.toExponential(2) : 'data not available'}
118 |
119 | { 120 | all_sequences.map(({label, sequences, default_show}, idx) => { 121 | // console.log('sequences', sequences) 122 | const n_show = showingMore[label] ? sequences.length : default_show; 123 | return ( 124 | 125 |

126 | {label} 127 | 131 | 135 |

136 | 137 | 138 | 139 | 140 | {sequences.length && sequences[0].ablate_loss_diff && } 141 | 142 | 143 | 144 | {sequences.slice(0, n_show).map((sequence, i) => ( 145 | 146 | 147 | 150 | { 151 | sequence.ablate_loss_diff && 152 | 155 | } 156 | 157 | ))} 158 | 159 |
Doc IDTokenActivationActivationsEffects
{sequence.doc_id}{sequence.idx}{sequence.act.toFixed(2)} 148 | 149 | 153 | 154 |
160 |
161 | ) 162 | }) 163 | } 164 |
165 | ) 166 | } 167 | -------------------------------------------------------------------------------- /sae-viewer/src/App.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | select { 6 | margin: 5px; 7 | } 8 | 9 | :root { 10 | --secondary-color: #0d978b; 11 | --accent-color: #efefef; 12 | } 13 | 14 | table.activations-table { 15 | border: 1px solid gray; 16 | border-radius: 3px; 17 | border-spacing: 0; 18 | } 19 | table.activations-table td, table.activations-table th { 20 | border-bottom: 1px solid gray; 21 | border-right: 1px solid gray; 22 | border-left: 1px solid gray; 23 | } 24 | table.activations-table tr:last-child > td { 25 | border-bottom: none; 26 | } 27 | 28 | .full-width{ 29 | width: 100vw; 30 | position: relative; 31 | margin-left: -50vw; 32 | left: 50%; 33 | } 34 | 35 | .App { 36 | text-align: center; 37 | } 38 | 39 | .center { 40 | text-align: center; 41 | } 42 | 43 | .App-logo { 44 | height: 40vmin; 45 | pointer-events: none; 46 | } 47 | 48 | @media (prefers-reduced-motion: no-preference) { 49 | .App-logo { 50 | animation: App-logo-spin infinite 20s linear; 51 | } 52 | } 53 | 54 | .App h1 { 55 | font-size: 1.75rem; 56 | } 57 | 58 | .App-article { 59 | background-color: #282c34; 60 | min-height: 100vh; 61 | display: flex; 62 | flex-direction: column; 63 | align-items: center; 64 | justify-content: center; 65 | font-size: calc(10px + 2vmin); 66 | color: white; 67 | } 68 | 69 | .App-link { 70 | color: #61dafb; 71 | } 72 | 73 | @keyframes App-logo-spin { 74 | from { 75 | transform: rotate(0deg); 76 | } 77 | to { 78 | transform: rotate(360deg); 79 | } 80 | } 81 | 82 | 83 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 84 | /* Structure 85 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 86 | 87 | body { 88 | margin: 0; 89 | padding: 0 1em; 90 | font-size: 12pt; 91 | } 92 | 93 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 94 | /* Typography 95 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 96 | 97 | h1 { 98 | font-size: 24pt; 99 | font-weight: 500; 100 | padding: 1em 0 0; 101 | display: block; 102 | color: #000; 103 | } 104 | h3 { padding: 0 0; } 105 | h2 { padding: 1em 0 0.5em 0; } 106 | h4, h5 { 107 | text-transform: uppercase; 108 | margin: 1em 0; 109 | justify-tracks: space-between; 110 | font-family: var(--sans-serif); 111 | font-size: 12pt; 112 | font-weight: 600; 113 | } 114 | h2, h3 { font-weight: 500; font-style: italic; } 115 | subtitle { 116 | color: #555; 117 | font-size: 18pt; 118 | font-style: italic; 119 | padding: 0; 120 | display: block; 121 | margin-bottom: 1em 122 | } 123 | 124 | a { 125 | transition: all .05s ease-in-out; 126 | color: #5c60c3 !important; 127 | font-style: normal; 128 | } 129 | a:hover { color: var(--accent-color)!important; } 130 | code, pre { color: var(--inline-code-color); 131 | background-color: #eee; border-radius: 3px; } 132 | pre { padding: 1em; margin: 2em 0; } 133 | code { padding: 0.3em; } 134 | .text-secondary, h3, h5 { color: var(--secondary-color); } 135 | .text-primary, h2,h4 { color: var(--primary-color); } 136 | 137 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 138 | /* Images 139 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 140 | 141 | img#logo { 142 | width: 50%; 143 | margin: 3em 0 0 144 | } 145 | 146 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 147 | /* Alerts */ 148 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 149 | 150 | .alert { 151 | font-weight: 600; 152 | font-style: italic; 153 | display: block; 154 
| background-color: #fff7f7; 155 | padding: 1em; 156 | margin: 0; 157 | border-radius: 5px; 158 | color: #f25555 159 | } 160 | .alert.cool { 161 | background-color: #f3f0fc; 162 | color: #7155cf; 163 | } 164 | .flash-alert { 165 | display: inline-block; 166 | transition: ease-in-out 1s; 167 | font-size: 14pt; 168 | margin: 1em 0; 169 | padding-top: 0.5em; 170 | } 171 | .flash-alert.success { 172 | color: #000; 173 | } 174 | .flash-alert.failure { 175 | color: red; 176 | } 177 | .flash-alert.hidden { 178 | display: none; 179 | } 180 | 181 | 182 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 183 | /* Sidenotes & Superscripts */ 184 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 185 | 186 | body { counter-reset: count; } 187 | p { whitespace: nowrap; } 188 | 189 | /* Different behavior if the screen is too 190 | narrow to show a sidenote on the side. */ 191 | 192 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 193 | /* Buttons */ 194 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 195 | 196 | @media print { 197 | a.btn, button { 198 | display: none!important 199 | } 200 | } 201 | 202 | @media screen { 203 | a.btn, button { 204 | border-radius: 3px; 205 | color: #000 !important; 206 | text-decoration: none !important; 207 | font-size: 11pt; 208 | border: 1px solid #000; 209 | padding: 0.5em 1em; 210 | font-family: -apple-system, 211 | BlinkMacSystemFont, 212 | "avenir next", 213 | avenir, 214 | helvetica, 215 | "helvetica neue", 216 | ubuntu, 217 | roboto, 218 | noto, 219 | "segoe ui", 220 | arial, 221 | sans-serif !important; 222 | background: #fff; 223 | font-weight: 500; 224 | transition: all .05s ease-in-out,box-shadow-color .025s ease-in-out; 225 | display: inline-block; 226 | } 227 | 228 | a.btn:hover, button:hover { 229 | cursor: pointer; 230 | } 231 | a.btn:active, button.active, button:active { 232 | border: 1px solid; 233 | } 234 | a.btn.small,button.small { 235 | border: 1px solid #000; 236 | padding: .6em 1em; 237 | font-weight: 500 238 | } 239 | a.btn.small:hover,button.small:hover { 240 | } 241 | a.btn.small:active,button.small:active { 242 | } 243 | } 244 | 245 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 246 | /* Blockquotes & Epigraphs 247 | /* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ 248 | 249 | blockquote { 250 | margin: 1em; 251 | } 252 | div>blockquote>p { 253 | font-size: 13pt; 254 | color: #555; 255 | font-style: normal!important; 256 | margin: 0; 257 | padding: 1em 0 1.5em 258 | } 259 | blockquote > blockquote { 260 | padding: 0.5em 2em 1em 1.5em !important; 261 | } 262 | 263 | blockquote > blockquote, 264 | blockquote > blockquote > p { 265 | font-size: 14pt; 266 | padding: 0; 267 | margin: 0; 268 | text-align: center; 269 | font-style: italic; 270 | color: var(--epigraph-color); 271 | } 272 | blockquote footer { 273 | font-size: 12pt; 274 | text-align: inherit; 275 | display: block; 276 | font-style: normal; 277 | margin: 1em; 278 | color: #aaa; 279 | } 280 | -------------------------------------------------------------------------------- /sparse_autoencoder/model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def LN(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 9 | mu = x.mean(dim=-1, keepdim=True) 10 | x = x - mu 11 | std = x.std(dim=-1, keepdim=True) 12 | x = x / (std + eps) 13 | return x, mu, std 14 | 15 | 16 | class Autoencoder(nn.Module): 17 | """Sparse 
autoencoder 18 | 19 | Implements: 20 | latents = activation(encoder(x - pre_bias) + latent_bias) 21 | recons = decoder(latents) + pre_bias 22 | """ 23 | 24 | def __init__( 25 | self, n_latents: int, n_inputs: int, activation: Callable = nn.ReLU(), tied: bool = False, 26 | normalize: bool = False 27 | ) -> None: 28 | """ 29 | :param n_latents: dimension of the autoencoder latent 30 | :param n_inputs: dimensionality of the original data (e.g residual stream, number of MLP hidden units) 31 | :param activation: activation function 32 | :param tied: whether to tie the encoder and decoder weights 33 | """ 34 | super().__init__() 35 | 36 | self.pre_bias = nn.Parameter(torch.zeros(n_inputs)) 37 | self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False) 38 | self.latent_bias = nn.Parameter(torch.zeros(n_latents)) 39 | self.activation = activation 40 | if tied: 41 | self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder) 42 | else: 43 | self.decoder = nn.Linear(n_latents, n_inputs, bias=False) 44 | self.normalize = normalize 45 | 46 | self.stats_last_nonzero: torch.Tensor 47 | self.latents_activation_frequency: torch.Tensor 48 | self.latents_mean_square: torch.Tensor 49 | self.register_buffer("stats_last_nonzero", torch.zeros(n_latents, dtype=torch.long)) 50 | self.register_buffer( 51 | "latents_activation_frequency", torch.ones(n_latents, dtype=torch.float) 52 | ) 53 | self.register_buffer("latents_mean_square", torch.zeros(n_latents, dtype=torch.float)) 54 | 55 | def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor: 56 | """ 57 | :param x: input data (shape: [batch, n_inputs]) 58 | :param latent_slice: slice of latents to compute 59 | Example: latent_slice = slice(0, 10) to compute only the first 10 latents. 
60 | :return: autoencoder latents before activation (shape: [batch, n_latents]) 61 | """ 62 | x = x - self.pre_bias 63 | latents_pre_act = F.linear( 64 | x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice] 65 | ) 66 | return latents_pre_act 67 | 68 | def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]: 69 | if not self.normalize: 70 | return x, dict() 71 | x, mu, std = LN(x) 72 | return x, dict(mu=mu, std=std) 73 | 74 | def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]: 75 | """ 76 | :param x: input data (shape: [batch, n_inputs]) 77 | :return: autoencoder latents (shape: [batch, n_latents]) 78 | """ 79 | x, info = self.preprocess(x) 80 | return self.activation(self.encode_pre_act(x)), info 81 | 82 | def decode(self, latents: torch.Tensor, info: dict[str, Any] | None = None) -> torch.Tensor: 83 | """ 84 | :param latents: autoencoder latents (shape: [batch, n_latents]) 85 | :return: reconstructed data (shape: [batch, n_inputs]) 86 | """ 87 | ret = self.decoder(latents) + self.pre_bias 88 | if self.normalize: 89 | assert info is not None 90 | ret = ret * info["std"] + info["mu"] 91 | return ret 92 | 93 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 94 | """ 95 | :param x: input data (shape: [batch, n_inputs]) 96 | :return: autoencoder latents pre activation (shape: [batch, n_latents]) 97 | autoencoder latents (shape: [batch, n_latents]) 98 | reconstructed data (shape: [batch, n_inputs]) 99 | """ 100 | x, info = self.preprocess(x) 101 | latents_pre_act = self.encode_pre_act(x) 102 | latents = self.activation(latents_pre_act) 103 | recons = self.decode(latents, info) 104 | 105 | # set all indices of self.stats_last_nonzero where (latents != 0) to 0 106 | self.stats_last_nonzero *= (latents == 0).all(dim=0).long() 107 | self.stats_last_nonzero += 1 108 | 109 | return latents_pre_act, latents, recons 110 | 111 | @classmethod 112 | def from_state_dict( 113 | cls, state_dict: dict[str, torch.Tensor], strict: bool = True 114 | ) -> "Autoencoder": 115 | n_latents, d_model = state_dict["encoder.weight"].shape 116 | 117 | # Retrieve activation 118 | activation_class_name = state_dict.pop("activation", "ReLU") 119 | activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU) 120 | normalize = activation_class_name == "TopK" # NOTE: hacky way to determine if normalization is enabled 121 | activation_state_dict = state_dict.pop("activation_state_dict", {}) 122 | if hasattr(activation_class, "from_state_dict"): 123 | activation = activation_class.from_state_dict( 124 | activation_state_dict, strict=strict 125 | ) 126 | else: 127 | activation = activation_class() 128 | if hasattr(activation, "load_state_dict"): 129 | activation.load_state_dict(activation_state_dict, strict=strict) 130 | 131 | autoencoder = cls(n_latents, d_model, activation=activation, normalize=normalize) 132 | # Load remaining state dict 133 | autoencoder.load_state_dict(state_dict, strict=strict) 134 | return autoencoder 135 | 136 | def state_dict(self, destination=None, prefix="", keep_vars=False): 137 | sd = super().state_dict(destination, prefix, keep_vars) 138 | sd[prefix + "activation"] = self.activation.__class__.__name__ 139 | if hasattr(self.activation, "state_dict"): 140 | sd[prefix + "activation_state_dict"] = self.activation.state_dict() 141 | return sd 142 | 143 | 144 | class TiedTranspose(nn.Module): 145 | def __init__(self, linear: nn.Linear): 146 | super().__init__() 147 | self.linear = linear 148 | 
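    # When tied=True, the Autoencoder's decoder is this transpose wrapper around the
    # encoder's nn.Linear, so self.decoder(latents) reduces to latents @ encoder.weight
    # (no bias). Minimal usage sketch for the classes in this file (hypothetical sizes,
    # illustrative only):
    #   ae = Autoencoder(n_latents=4096, n_inputs=768, activation=TopK(k=32), normalize=True)
    #   x = torch.randn(8, 768)
    #   pre_acts, latents, recons = ae(x)                    # recons.shape == (8, 768)
    #   ae2 = Autoencoder.from_state_dict(ae.state_dict())   # round-trips TopK k and normalize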
149 | def forward(self, x: torch.Tensor) -> torch.Tensor: 150 | assert self.linear.bias is None 151 | return F.linear(x, self.linear.weight.t(), None) 152 | 153 | @property 154 | def weight(self) -> torch.Tensor: 155 | return self.linear.weight.t() 156 | 157 | @property 158 | def bias(self) -> torch.Tensor: 159 | return self.linear.bias 160 | 161 | 162 | class TopK(nn.Module): 163 | def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None: 164 | super().__init__() 165 | self.k = k 166 | self.postact_fn = postact_fn 167 | 168 | def forward(self, x: torch.Tensor) -> torch.Tensor: 169 | topk = torch.topk(x, k=self.k, dim=-1) 170 | values = self.postact_fn(topk.values) 171 | # make all other values 0 172 | result = torch.zeros_like(x) 173 | result.scatter_(-1, topk.indices, values) 174 | return result 175 | 176 | def state_dict(self, destination=None, prefix="", keep_vars=False): 177 | state_dict = super().state_dict(destination, prefix, keep_vars) 178 | state_dict.update({prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__}) 179 | return state_dict 180 | 181 | @classmethod 182 | def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "TopK": 183 | k = state_dict["k"] 184 | postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]() 185 | return cls(k=k, postact_fn=postact_fn) 186 | 187 | 188 | ACTIVATIONS_CLASSES = { 189 | "ReLU": nn.ReLU, 190 | "Identity": nn.Identity, 191 | "TopK": TopK, 192 | } 193 | -------------------------------------------------------------------------------- /sae-viewer/src/welcome.tsx: -------------------------------------------------------------------------------- 1 | import React from "react" 2 | import { useState, FormEvent } from "react" 3 | import { useNavigate } from "react-router-dom" 4 | import { Feature } from "./types" 5 | import FeatureSelect from "./components/featureSelect" 6 | import { pathForFeature, DEFAULT_AUTOENCODER, AUTOENCODER_FAMILIES } from "./autoencoder_registry" 7 | 8 | export default function Welcome() { 9 | const navigate = useNavigate() 10 | 11 | const GPT4_ATOMS_PER_SHARD = 1024; 12 | const displayFeatures = [ 13 | /************** 14 | /* well explained + interesting 15 | ***************/ 16 | {heading: 'GPT-4', heading_type: 'h4', feature: null, label: ''}, 17 | {feature: {atom: 62 * GPT4_ATOMS_PER_SHARD + 53, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 18 | label: "humans have flaws", description: "descriptions of how humans are flawed"}, 19 | {feature: {atom: 25 * GPT4_ATOMS_PER_SHARD + 8, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 20 | label: "police reports, especially child safety", description: "safety incidents especially related to children"}, 21 | {feature: {atom: 9 * GPT4_ATOMS_PER_SHARD + 44, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 22 | label: "price changes", description: "ends of phrases describing commodity/equity price changes"}, 23 | {feature: {atom: 17 * GPT4_ATOMS_PER_SHARD + 33, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 24 | label: "ratification (multilingual)", description: "ratification (multilingual)"}, 25 | {feature: {atom: 3 * GPT4_ATOMS_PER_SHARD + 421, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 26 | label: "would [...]", description: "conditionals (things that would be true)"}, 27 | {feature: {atom: 63 * GPT4_ATOMS_PER_SHARD + 8, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 
28 | label: "identification documents (multilingual)", description: "identification documents (multilingual)"}, 29 | {feature: {atom: 0 * GPT4_ATOMS_PER_SHARD + 14, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 30 | label: "lightly incremented timestamps", description: "timestamps being lightly incremented with recurring formats"}, 31 | {heading: 'Technical knowledge', heading_type: 'h3', feature: null, label: ''}, 32 | {feature: {atom: 40 * GPT4_ATOMS_PER_SHARD + 42, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 33 | label: "machine learning training logs", description: "machine learning training logs"}, 34 | {feature: {atom: 12 * GPT4_ATOMS_PER_SHARD + 47, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 35 | label: "onclick/onchange = function(this)", description: "onclick/onchange = function(this)"}, 36 | {feature: {atom: 54 * GPT4_ATOMS_PER_SHARD + 23, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 37 | label: "edges (graph theory) and related concepts", description: "edges (graph theory) and related concepts"}, 38 | {feature: {atom: 56 * GPT4_ATOMS_PER_SHARD + 12, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 39 | label: "algebraic rings", description: "algebraic rings"}, 40 | {feature: {atom: 28 * GPT4_ATOMS_PER_SHARD + 47, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 41 | label: "adenosine/dopamine receptors", description: "adenosine/dopamine receptors"}, 42 | {feature: {atom: 2 * GPT4_ATOMS_PER_SHARD + 601, autoencoder: AUTOENCODER_FAMILIES['v5_latelayer_postmlp'].get_ae({})}, 43 | label: "blockchain vibes", description: "blockchain vibes"}, 44 | 45 | 46 | {heading: 'GPT-2 small', heading_type: 'h4', feature: null, label: ''}, 47 | {feature: {atom: 488432, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 48 | num_features: '2097152', num_active_features: '8' 49 | })}, label: "rhetorical questions", description: "rhetorical questions"}, 50 | {feature: {atom: 2088200, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 51 | num_features: '2097152', num_active_features: '8' 52 | })}, label: "counting human casualties", description: "counting human casualties"}, 53 | {feature: {atom: 1621560, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 54 | num_features: '2097152', num_active_features: '8' 55 | })}, label: "X and Y phrases", description: "X and -> Y"}, 56 | {feature: {atom: 733, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 57 | num_features: '32768', num_active_features: '8' 58 | })}, label: "Patrick/Patty surname predictor", description: "Predicts surnames after Patrick"}, 59 | {feature: {atom: 64464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 60 | num_features: '131072', num_active_features: '32' 61 | })}, label: "things that are unknown", description: "things that are unknown"}, 62 | {feature: {atom: 56907, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ // similar to 33248 63 | num_features: '131072', num_active_features: '32' 64 | })}, label: "words in quotes", description: "predicts words in quotes"}, 65 | {feature: {atom: 1605835, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 66 | num_features: '2097152', num_active_features: '8' 67 | })}, label: "these/those responsible things", description: "these/those, in a phrase where something is responsible for something"}, 68 | {feature: {atom: 8040, autoencoder: 
AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 69 | num_features: '8192', num_active_features: '32' 70 | })}, label: "2018 natural disasters", description: "2018 natural disasters"}, 71 | {feature: {atom: 21464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 72 | num_features: '131072', num_active_features: '32' 73 | })}, label: "addition in code", description: "addition in code"}, 74 | {feature: {atom: 66232, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 75 | num_features: '131072', num_active_features: '32' 76 | })}, label: "function application", description: "function application"}, 77 | {feature: {atom: 64464, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 78 | num_features: '131072', num_active_features: '32' 79 | })}, label: "unclear/hidden things", description: "unclear/hidden things (top only)"}, 80 | {feature: {atom: 10423, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 81 | num_features: '131072', num_active_features: '32' 82 | })}, label: "what the ...", description: "[who/what/when/where/why/how] the"}, 83 | {heading: 'Safety relevant features (found via attribution methods)', heading_type: 'h3', feature: null, label: ''}, 84 | {feature: {atom: 64840, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 85 | num_features: '131072', num_active_features: '32' 86 | })}, label: "profanity (1)", description: "activates in order to output profanity"}, 87 | {feature: {atom: 104813, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 88 | num_features: '131072', num_active_features: '32' 89 | })}, label: "profanity (2)", description: "activates on profanity"}, 90 | {feature: {atom: 101090, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 91 | num_features: '131072', num_active_features: '32' 92 | })}, label: "profanity (3)", description: "activates on 'fucking' (profane, not sexual contexts)"}, 93 | {feature: {atom: 72185, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 94 | num_features: '131072', num_active_features: '32' 95 | })}, label: "erotic content", description: "erotic content"}, 96 | {feature: {atom: 69134, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 97 | num_features: '131072', num_active_features: '32' 98 | })}, label: "[content warning] sexual abuse", description: "sexual abuse"}, 99 | // {feature: {atom: 2, autoencoder: AUTOENCODER_FAMILIES['v5_l8_postmlp'].get_ae({ 100 | // num_features: '2097152', num_active_features: '8' 101 | // })}, label: "things being brought", description: "bring * -> together/back"}, 102 | ] 103 | 104 | let [feature, setFeature] = useState({ 105 | atom: 0, autoencoder: DEFAULT_AUTOENCODER 106 | }) 107 | const handleClick = (click_feature: Feature) => { 108 | navigate(pathForFeature(click_feature)) 109 | } 110 | 111 | return ( 112 |
113 |

Welcome! This is a viewer for sparse autoencoder features trained in this paper

114 |

Pick a feature:

115 | setFeature(f)} 118 | onFeatureSubmit={(f: Feature) => navigate(pathForFeature(f))} 119 | show_go={true} 120 | /> 121 | 122 |
123 |

Interesting features:

124 |
125 |
128 | {displayFeatures.map(({ heading, heading_type, feature, label, description }, j) => ( 129 | heading ?
130 | {React.createElement(heading_type, {}, heading)} 131 |
: 140 | ))} 141 |
142 |
143 |
144 |
145 | ) 146 | } 147 | -------------------------------------------------------------------------------- /sparse_autoencoder/explanations.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Callable 3 | 4 | import blobfile as bf 5 | 6 | 7 | class Explanation(ABC): 8 | def predict(self, tokens: list[str]) -> list[float]: 9 | raise NotImplementedError 10 | 11 | def dump(self) -> bytes: 12 | raise NotImplementedError 13 | 14 | @classmethod 15 | def load(cls, serialized: Any) -> "Explanation": 16 | raise NotImplementedError 17 | 18 | def dumpf(self, filename: str): 19 | d = self.dump() 20 | assert isinstance(d, bytes) 21 | with bf.BlobFile(filename, "wb") as f: 22 | f.write(d) 23 | 24 | @classmethod 25 | def loadf(cls, filename: str): 26 | with bf.BlobFile(filename, "rb") as f: 27 | return cls.load(f.read()) 28 | 29 | 30 | _ANY_TOKEN = "token(*)" 31 | _START_TOKEN = "token(^)" 32 | _SALIENCY_KEY = "" 33 | 34 | 35 | class NtgExplanation(Explanation): 36 | def __init__(self, trie: dict): 37 | self.trie = trie 38 | 39 | def todict(self) -> dict: 40 | return { 41 | "trie": self.trie, 42 | } 43 | 44 | @classmethod 45 | def load(cls, serialized: dict) -> "Explanation": 46 | assert isinstance(serialized, dict) 47 | return cls(serialized["trie"]) 48 | 49 | def predict(self, tokens: list[str]) -> list[float]: 50 | predicted_acts = [] 51 | # for each token, traverse the trie beginning from that token and proceeding in reverse order until we match 52 | # a pattern or are no longer able to traverse. 53 | for i in range(len(tokens)): 54 | curr = self.trie 55 | for j in range(i, -1, -1): 56 | if tokens[j] not in curr and _ANY_TOKEN not in curr: 57 | predicted_acts.append(0) 58 | break 59 | if tokens[j] in curr: 60 | curr = curr[tokens[j]] 61 | else: 62 | curr = curr[_ANY_TOKEN] 63 | if _SALIENCY_KEY in curr: 64 | predicted_acts.append(curr[_SALIENCY_KEY]) 65 | break 66 | # if we"ve reached the end of the sequence and haven't found a saliency value, append 0. 67 | elif j == 0: 68 | if _START_TOKEN in curr: 69 | curr = curr[_START_TOKEN] 70 | assert _SALIENCY_KEY in curr 71 | predicted_acts.append(curr[_SALIENCY_KEY]) 72 | break 73 | predicted_acts.append(0) 74 | # We should have appended a value for each token in the sequence. 75 | assert len(predicted_acts) == len(tokens) 76 | return predicted_acts 77 | 78 | # TODO make this more efficient 79 | def predict_many(self, tokens_batch: list[list[str]]) -> list[list[float]]: 80 | return [self.predict(t) for t in tokens_batch] 81 | 82 | 83 | def batched(iterable, bs): 84 | batch = [] 85 | it = iter(iterable) 86 | while True: 87 | batch = [] 88 | try: 89 | for _ in range(bs): 90 | batch.append(next(it)) 91 | yield batch 92 | except StopIteration: 93 | if len(batch) > 0: 94 | yield batch 95 | return 96 | 97 | 98 | def apply_batched(fn, iterable, bs): 99 | for batch in batched(iterable, bs): 100 | ret = fn(batch) 101 | assert len(ret) == len(batch) 102 | yield from ret 103 | 104 | 105 | def batch_parallelize(algos, fn, batch_size): 106 | """ 107 | Algorithms are coroutines that yield items to be processed in parallel. 108 | We concurrently run the algorithm on all items in the batch. 
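    Protocol: each coroutine yields an item to process, receives the corresponding entry of
    fn's batched output via .send(), and the value it finally returns (carried by
    StopIteration.value) becomes its slot in the returned results list.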
109 | """ 110 | inputs = [] 111 | for i, algo in enumerate(algos): 112 | inputs.append((i, next(algo))) 113 | results = [None] * len(algos) 114 | while len(inputs) > 0: 115 | ret = list(apply_batched(fn, [x[1] for x in inputs], batch_size)) 116 | assert len(ret) == len(inputs) 117 | inds = [x[0] for x in inputs] 118 | inputs = [] 119 | for i, r in zip(inds, ret): 120 | try: 121 | next_input = algos[i].send(r) 122 | inputs.append((i, next_input)) 123 | except StopIteration as e: 124 | results[i] = e.value 125 | return results 126 | 127 | 128 | def create_n2g_explanation( 129 | model_fn: Callable, train_set: list[dict], batch_size: int = 16, 130 | padding_token=4808 # " _" for GPT-2 131 | ) -> NtgExplanation: 132 | truncated = [] 133 | # for each one find the index of the selected activation in the doc. truncate the sequences after this point. 134 | for doc in train_set: 135 | # get index of selected activation. for docs stored in 'top', this is the max activation. 136 | # for docs stored in 'random', it is a random positive activation (we sample activations, not docs 137 | # to populate 'random', so docs with more positive activations are more likely to be included). 138 | max_idx = doc["idx"] 139 | truncated.append( 140 | { 141 | "act": doc["act"], 142 | "acts": doc["acts"][: max_idx + 1], 143 | "tokens": doc["tokens"][: max_idx + 1], 144 | "token_ints": doc["token_ints"][: max_idx + 1], 145 | } 146 | ) 147 | 148 | def get_minimal_subsequence(doc): 149 | for i in range(len(doc["token_ints"]) - 1, -1, -1): 150 | atom_acts = yield doc["token_ints"][i:] 151 | assert ( 152 | len(atom_acts) == len(doc["token_ints"]) - i 153 | ), f"{len(atom_acts)} != {len(doc['token_ints']) - i}" 154 | if atom_acts[-1] / doc["act"] >= 0.5: 155 | return { 156 | "tokens": doc["tokens"][i:], 157 | "token_ints": doc["token_ints"][i:], 158 | "subsequence_act": atom_acts[-1], 159 | "orig_act": doc["act"], 160 | } 161 | print("Warning: no minimal subsequence found") 162 | # raise ValueError("No minimal subsequence found") 163 | return { 164 | "tokens": doc["tokens"], 165 | "token_ints": doc["token_ints"], 166 | "subsequence_act": doc["act"], 167 | "orig_act": doc["act"], 168 | } 169 | 170 | minimal_subsequences = batch_parallelize( 171 | [get_minimal_subsequence(doc) for doc in truncated], model_fn, batch_size 172 | ) 173 | 174 | start_padded = apply_batched( 175 | model_fn, 176 | [[padding_token] + doc["token_ints"] for doc in minimal_subsequences], 177 | batch_size, 178 | ) 179 | for min_seq, pad_atom_acts in zip(minimal_subsequences, start_padded): 180 | min_seq["can_pad_start"] = pad_atom_acts[-1] / min_seq["orig_act"] >= 0.5 181 | 182 | # for m in minimal_subsequences: 183 | # print("\t" + "".join(m["tokens"])) 184 | 185 | # for each token in a minimal subsequence, replace it with a padding token and compute the saliency value (1 - (orig act / new act)) 186 | for doc in minimal_subsequences: 187 | all_seqs = [] 188 | for i in range(len(doc["token_ints"])): 189 | tokens = doc["token_ints"][:i] + [padding_token] + doc["token_ints"][i + 1 :] 190 | assert len(tokens) == len(doc["token_ints"]) 191 | all_seqs.append(tokens) 192 | saliency_vals = [] 193 | all_atom_acts = apply_batched(model_fn, all_seqs, batch_size) 194 | for atom_acts, tokens in zip(all_atom_acts, all_seqs): 195 | assert len(atom_acts) == len(tokens) 196 | saliency_vals.append(1 - (atom_acts[-1] / doc["subsequence_act"])) 197 | doc["saliency_vals"] = saliency_vals 198 | 199 | trie = {} 200 | for doc in minimal_subsequences: 201 | curr = trie 202 | for 
i, (token, saliency) in enumerate(zip(doc["tokens"][::-1], doc["saliency_vals"][::-1])): 203 | if saliency < 0.5: 204 | token = _ANY_TOKEN 205 | if token not in curr: 206 | curr[token] = {} 207 | curr = curr[token] 208 | if i == len(doc["tokens"]) - 1: 209 | if not doc["can_pad_start"]: 210 | curr[_START_TOKEN] = {} 211 | curr = curr[_START_TOKEN] 212 | curr[_SALIENCY_KEY] = doc["subsequence_act"] 213 | 214 | return NtgExplanation(trie) 215 | 216 | 217 | if __name__ == "__main__": 218 | 219 | def first_position_fn(tokens: list[list[float]]) -> list[list[float]]: 220 | return [[1.0] + [0.0] * (len(toks) - 1) for toks in tokens] 221 | 222 | expl = create_n2g_explanation( 223 | first_position_fn, 224 | [ 225 | { 226 | "idx": 0, 227 | "act": 1.0, 228 | "acts": [1.0, 0.0, 0.0], 229 | "tokens": ["a", "b", "c"], 230 | "token_ints": [0, 1, 2], 231 | }, 232 | { 233 | "idx": 0, 234 | "act": 1.0, 235 | "acts": [1.0, 0.0, 0.0], 236 | "tokens": ["b", "c", "d"], 237 | "token_ints": [1, 2, 3], 238 | }, 239 | ], 240 | ) 241 | print(expl.trie) 242 | assert expl.predict(["a", "b", "c"]) == [1.0, 0.0, 0.0] 243 | assert expl.predict(["c", "b", "a"]) == [1.0, 0.0, 0.0] 244 | 245 | def c_fn(tokens: list[list[float]]) -> list[list[float]]: 246 | return [[1.0 if tok == 2 else 0.0 for tok in toks] for toks in tokens] 247 | 248 | expl = create_n2g_explanation( 249 | c_fn, 250 | [ 251 | { 252 | "idx": 2, 253 | "act": 1.0, 254 | "acts": [0.0, 0.0, 1.0], 255 | "tokens": ["a", "b", "c"], 256 | "token_ints": [0, 1, 2], 257 | }, 258 | { 259 | "idx": 1, 260 | "act": 1.0, 261 | "acts": [0.0, 1.0, 0.0], 262 | "tokens": ["b", "c", "d"], 263 | "token_ints": [1, 2, 3], 264 | }, 265 | ], 266 | ) 267 | 268 | print(expl.trie) 269 | assert expl.predict(["b", "c"]) == [0.0, 1.0] 270 | assert expl.predict(["a", "a", "a"]) == [0.0, 0.0, 0.0] 271 | assert expl.predict(["c", "b", "c"]) == [1.0, 0.0, 1.0] 272 | 273 | def a_star_c_fn(tokens: list[list[float]]) -> list[list[float]]: 274 | return [ 275 | [1.0 if (tok == 2 and 0 in toks[:i]) else 0.0 for i, tok in enumerate(toks)] 276 | for toks in tokens 277 | ] 278 | 279 | expl = create_n2g_explanation( 280 | a_star_c_fn, 281 | [ 282 | { 283 | "idx": 2, 284 | "act": 1.0, 285 | "acts": [0.0, 0.0, 1.0], 286 | "tokens": ["a", "b", "c"], 287 | "token_ints": [0, 1, 2], 288 | }, 289 | { 290 | "idx": 2, 291 | "act": 1.0, 292 | "acts": [0.0, 0.0, 1.0], 293 | "tokens": ["b", "a", "c"], 294 | "token_ints": [1, 0, 2], 295 | }, 296 | { 297 | "idx": 1, 298 | "act": 1.0, 299 | "acts": [0.0, 1.0, 0.0], 300 | "tokens": ["a", "c", "d"], 301 | "token_ints": [0, 2, 3], 302 | }, 303 | ], 304 | ) 305 | 306 | print(expl.trie) 307 | assert expl.predict(["b", "c"]) == [0.0, 0.0] 308 | assert expl.predict(["a", "c"]) == [0.0, 1.0] 309 | assert expl.predict(["a", "e", "c", "a", "c"]) == [0.0, 0.0, 1.0, 0.0, 1.0] 310 | # NOTE: should be 0, 0, 0, 1 but we're not smart enough to handle this yet 311 | assert expl.predict(["a", "b", "b", "c"]) == [0.0, 0.0, 0.0, 0.0] 312 | 313 | def zero_fn(tokens: list[list[float]]) -> list[list[float]]: 314 | return [[0.0 for tok in toks] for toks in tokens] 315 | 316 | expl = create_n2g_explanation(zero_fn, []) 317 | 318 | print(expl.trie) 319 | assert expl.predict(["b", "c"]) == [0.0, 0.0] 320 | assert expl.predict(["a", "c"]) == [0.0, 0.0] 321 | assert expl.predict(["a", "e", "c", "a", "c"]) == [0.0, 0.0, 0.0, 0.0, 0.0] 322 | -------------------------------------------------------------------------------- /sparse_autoencoder/kernels.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | ## kernels 8 | 9 | 10 | def triton_sparse_transpose_dense_matmul( 11 | sparse_indices: torch.Tensor, 12 | sparse_values: torch.Tensor, 13 | dense: torch.Tensor, 14 | N: int, 15 | BLOCK_SIZE_AK=128, 16 | ) -> torch.Tensor: 17 | """ 18 | calculates sparse.T @ dense (i.e reducing along the collated dimension of sparse) 19 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous) 20 | 21 | sparse_indices is shape (A, k) 22 | sparse_values is shape (A, k) 23 | dense is shape (A, B) 24 | 25 | output is shape (N, B) 26 | """ 27 | 28 | assert sparse_indices.shape == sparse_values.shape 29 | assert sparse_indices.is_contiguous() 30 | assert sparse_values.is_contiguous() 31 | assert dense.is_contiguous() # contiguous along B 32 | 33 | K = sparse_indices.shape[1] 34 | A = dense.shape[0] 35 | B = dense.shape[1] 36 | assert sparse_indices.shape[0] == A 37 | 38 | # COO-format and sorted 39 | sorted_indices = sparse_indices.view(-1).sort() 40 | coo_indices = torch.stack( 41 | [ 42 | torch.arange(A, device=sparse_indices.device).repeat_interleave(K)[ 43 | sorted_indices.indices 44 | ], 45 | sorted_indices.values, 46 | ] 47 | ) # shape (2, A * K) 48 | coo_values = sparse_values.view(-1)[sorted_indices.indices] # shape (A * K,) 49 | return triton_coo_sparse_dense_matmul(coo_indices, coo_values, dense, N, BLOCK_SIZE_AK) 50 | 51 | 52 | def triton_coo_sparse_dense_matmul( 53 | coo_indices: torch.Tensor, 54 | coo_values: torch.Tensor, 55 | dense: torch.Tensor, 56 | N: int, 57 | BLOCK_SIZE_AK=128, 58 | ) -> torch.Tensor: 59 | AK = coo_indices.shape[1] 60 | B = dense.shape[1] 61 | 62 | out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype) 63 | 64 | grid = lambda META: ( 65 | triton.cdiv(AK, META["BLOCK_SIZE_AK"]), 66 | 1, 67 | ) 68 | triton_sparse_transpose_dense_matmul_kernel[grid]( 69 | coo_indices, 70 | coo_values, 71 | dense, 72 | out, 73 | stride_da=dense.stride(0), 74 | stride_db=dense.stride(1), 75 | B=B, 76 | N=N, 77 | AK=AK, 78 | BLOCK_SIZE_AK=BLOCK_SIZE_AK, 79 | BLOCK_SIZE_B=triton.next_power_of_2(B), 80 | ) 81 | return out 82 | 83 | 84 | @triton.jit 85 | def triton_sparse_transpose_dense_matmul_kernel( 86 | coo_indices_ptr, 87 | coo_values_ptr, 88 | dense_ptr, 89 | out_ptr, 90 | stride_da, 91 | stride_db, 92 | B, 93 | N, 94 | AK, 95 | BLOCK_SIZE_AK: tl.constexpr, 96 | BLOCK_SIZE_B: tl.constexpr, 97 | ): 98 | """ 99 | coo_indices is shape (2, AK) 100 | coo_values is shape (AK,) 101 | dense is shape (A, B), contiguous along B 102 | out is shape (N, B) 103 | """ 104 | 105 | pid_ak = tl.program_id(0) 106 | pid_b = tl.program_id(1) 107 | 108 | coo_offsets = tl.arange(0, BLOCK_SIZE_AK) 109 | b_offsets = tl.arange(0, BLOCK_SIZE_B) 110 | 111 | A_coords = tl.load( 112 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets, 113 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 114 | ) 115 | K_coords = tl.load( 116 | coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets + AK, 117 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 118 | ) 119 | values = tl.load( 120 | coo_values_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets, 121 | mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK, 122 | ) 123 | 124 | last_k = tl.min(K_coords) 125 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32) 126 | 127 | for ind in range(BLOCK_SIZE_AK): 128 | if ind + pid_ak * BLOCK_SIZE_AK < AK: 129 | # workaround to do A_coords[ind] 130 | a = 
tl.sum( 131 | tl.where( 132 | tl.arange(0, BLOCK_SIZE_AK) == ind, 133 | A_coords, 134 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64), 135 | ) 136 | ) 137 | 138 | k = tl.sum( 139 | tl.where( 140 | tl.arange(0, BLOCK_SIZE_AK) == ind, 141 | K_coords, 142 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64), 143 | ) 144 | ) 145 | 146 | v = tl.sum( 147 | tl.where( 148 | tl.arange(0, BLOCK_SIZE_AK) == ind, 149 | values, 150 | tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32), 151 | ) 152 | ) 153 | 154 | tl.device_assert(k < N) 155 | 156 | if k != last_k: 157 | tl.atomic_add( 158 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets, 159 | accum, 160 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B, 161 | ) 162 | accum *= 0 163 | last_k = k 164 | 165 | if v != 0: 166 | accum += v * tl.load(dense_ptr + a * stride_da + b_offsets, mask=b_offsets < B) 167 | 168 | tl.atomic_add( 169 | out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets, 170 | accum, 171 | mask=BLOCK_SIZE_B * pid_b + b_offsets < B, 172 | ) 173 | 174 | 175 | def triton_sparse_dense_matmul( 176 | sparse_indices: torch.Tensor, 177 | sparse_values: torch.Tensor, 178 | dense: torch.Tensor, 179 | ) -> torch.Tensor: 180 | """ 181 | calculates sparse @ dense (i.e reducing along the uncollated dimension of sparse) 182 | dense must be contiguous along dim 0 (in other words, dense.T is contiguous) 183 | 184 | sparse_indices is shape (A, k) 185 | sparse_values is shape (A, k) 186 | dense is shape (N, B) 187 | 188 | output is shape (A, B) 189 | """ 190 | N = dense.shape[0] 191 | assert sparse_indices.shape == sparse_values.shape 192 | assert sparse_indices.is_contiguous() 193 | assert sparse_values.is_contiguous() 194 | assert dense.is_contiguous() # contiguous along B 195 | 196 | A = sparse_indices.shape[0] 197 | K = sparse_indices.shape[1] 198 | B = dense.shape[1] 199 | 200 | out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype) 201 | 202 | triton_sparse_dense_matmul_kernel[(A,)]( 203 | sparse_indices, 204 | sparse_values, 205 | dense, 206 | out, 207 | stride_dn=dense.stride(0), 208 | stride_db=dense.stride(1), 209 | A=A, 210 | B=B, 211 | N=N, 212 | K=K, 213 | BLOCK_SIZE_K=triton.next_power_of_2(K), 214 | BLOCK_SIZE_B=triton.next_power_of_2(B), 215 | ) 216 | return out 217 | 218 | 219 | @triton.jit 220 | def triton_sparse_dense_matmul_kernel( 221 | sparse_indices_ptr, 222 | sparse_values_ptr, 223 | dense_ptr, 224 | out_ptr, 225 | stride_dn, 226 | stride_db, 227 | A, 228 | B, 229 | N, 230 | K, 231 | BLOCK_SIZE_K: tl.constexpr, 232 | BLOCK_SIZE_B: tl.constexpr, 233 | ): 234 | """ 235 | sparse_indices is shape (A, K) 236 | sparse_values is shape (A, K) 237 | dense is shape (N, B), contiguous along B 238 | out is shape (A, B) 239 | """ 240 | 241 | pid = tl.program_id(0) 242 | 243 | offsets_k = tl.arange(0, BLOCK_SIZE_K) 244 | sparse_indices = tl.load( 245 | sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K 246 | ) # shape (K,) 247 | sparse_values = tl.load( 248 | sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K 249 | ) # shape (K,) 250 | 251 | accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32) 252 | 253 | offsets_b = tl.arange(0, BLOCK_SIZE_B) 254 | 255 | for k in range(K): 256 | # workaround to do sparse_indices[k] 257 | i = tl.sum( 258 | tl.where( 259 | tl.arange(0, BLOCK_SIZE_K) == k, 260 | sparse_indices, 261 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 262 | ) 263 | ) 264 | # workaround to do sparse_values[k] 265 | v = tl.sum( 266 | tl.where( 267 | tl.arange(0, BLOCK_SIZE_K) == k, 268 | sparse_values, 269 | 
tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32), 270 | ) 271 | ) 272 | 273 | tl.device_assert(i < N) 274 | if v != 0: 275 | accum += v * tl.load( 276 | dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B 277 | ) 278 | 279 | tl.store(out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B) 280 | 281 | 282 | def triton_dense_dense_sparseout_matmul( 283 | dense1: torch.Tensor, 284 | dense2: torch.Tensor, 285 | at_indices: torch.Tensor, 286 | ) -> torch.Tensor: 287 | """ 288 | dense1: shape (A, B) 289 | dense2: shape (B, N) 290 | at_indices: shape (A, K) 291 | out values: shape (A, K) 292 | calculates dense1 @ dense2 only for the indices in at_indices 293 | 294 | equivalent to (dense1 @ dense2).gather(1, at_indices) 295 | """ 296 | A, B = dense1.shape 297 | N = dense2.shape[1] 298 | assert dense2.shape[0] == B 299 | assert at_indices.shape[0] == A 300 | K = at_indices.shape[1] 301 | assert at_indices.is_contiguous() 302 | 303 | assert dense1.stride(1) == 1, "dense1 must be contiguous along B" 304 | assert dense2.stride(0) == 1, "dense2 must be contiguous along B" 305 | 306 | if K > 512: 307 | # print("WARN - using naive matmul for large K") 308 | # naive is more efficient for large K 309 | return (dense1 @ dense2).gather(1, at_indices) 310 | 311 | out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype) 312 | 313 | # grid = lambda META: (triton.cdiv(A, META['BLOCK_SIZE_A']),) 314 | 315 | triton_dense_dense_sparseout_matmul_kernel[(A,)]( 316 | dense1, 317 | dense2, 318 | at_indices, 319 | out, 320 | stride_d1a=dense1.stride(0), 321 | stride_d1b=dense1.stride(1), 322 | stride_d2b=dense2.stride(0), 323 | stride_d2n=dense2.stride(1), 324 | A=A, 325 | B=B, 326 | N=N, 327 | K=K, 328 | BLOCK_SIZE_B=triton.next_power_of_2(B), 329 | BLOCK_SIZE_N=triton.next_power_of_2(N), 330 | BLOCK_SIZE_K=triton.next_power_of_2(K), 331 | ) 332 | 333 | return out 334 | 335 | 336 | @triton.jit 337 | def triton_dense_dense_sparseout_matmul_kernel( 338 | dense1_ptr, 339 | dense2_ptr, 340 | at_indices_ptr, 341 | out_ptr, 342 | stride_d1a, 343 | stride_d1b, 344 | stride_d2b, 345 | stride_d2n, 346 | A, 347 | B, 348 | N, 349 | K, 350 | BLOCK_SIZE_B: tl.constexpr, 351 | BLOCK_SIZE_N: tl.constexpr, 352 | BLOCK_SIZE_K: tl.constexpr, 353 | ): 354 | """ 355 | dense1: shape (A, B) 356 | dense2: shape (B, N) 357 | at_indices: shape (A, K) 358 | out values: shape (A, K) 359 | """ 360 | 361 | pid = tl.program_id(0) 362 | 363 | offsets_k = tl.arange(0, BLOCK_SIZE_K) 364 | at_indices = tl.load(at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K) # shape (K,) 365 | 366 | offsets_b = tl.arange(0, BLOCK_SIZE_B) 367 | dense1 = tl.load( 368 | dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=offsets_b < B 369 | ) # shape (B,) 370 | 371 | accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) 372 | 373 | for k in range(K): 374 | # workaround to do at_indices[b] 375 | i = tl.sum( 376 | tl.where( 377 | tl.arange(0, BLOCK_SIZE_K) == k, 378 | at_indices, 379 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 380 | ) 381 | ) 382 | tl.device_assert(i < N) 383 | 384 | dense2col = tl.load( 385 | dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=offsets_b < B 386 | ) # shape (B,) 387 | accum += tl.where( 388 | tl.arange(0, BLOCK_SIZE_K) == k, 389 | tl.sum(dense1 * dense2col), 390 | tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64), 391 | ) 392 | 393 | tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K) 394 | 395 | 396 | class TritonDecoderAutograd(torch.autograd.Function): 397 | 
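    # Custom autograd wrapper around the sparse decode: forward runs
    # triton_sparse_dense_matmul on (indices, values, decoder_weight.T); backward uses the
    # transpose and sparse-output kernels above to produce gradients for the values and the
    # decoder weight (indices get no gradient).
    # Illustrative call, assuming inds/vals of shape [batch, k] from a top-k encoder and a
    # decoder weight of shape [d_model, n_latents]:
    #   recons = TritonDecoderAutograd.apply(inds, vals, decoder.weight)   # [batch, d_model]
    # This is how FastAutoencoder.decode_sparse in train.py uses it.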
@staticmethod 398 | def forward(ctx, sparse_indices, sparse_values, decoder_weight): 399 | ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight) 400 | return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T) 401 | 402 | @staticmethod 403 | def backward(ctx, grad_output): 404 | sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors 405 | 406 | assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient" 407 | 408 | decoder_grad = triton_sparse_transpose_dense_matmul( 409 | sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1] 410 | ).T 411 | 412 | return ( 413 | None, 414 | triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices), 415 | # decoder is contiguous when transposed so this is a matching layout 416 | decoder_grad, 417 | None, 418 | ) 419 | 420 | 421 | def triton_add_mul_( 422 | x: torch.Tensor, 423 | a: torch.Tensor, 424 | b: torch.Tensor, 425 | c: float, 426 | ): 427 | """ 428 | does 429 | x += a * b * c 430 | 431 | x : [m, n] 432 | a : [m, n] 433 | b : [m, n] 434 | c : float 435 | """ 436 | 437 | if len(a.shape) == 1: 438 | a = a[None, :].broadcast_to(x.shape) 439 | 440 | if len(b.shape) == 1: 441 | b = b[None, :].broadcast_to(x.shape) 442 | 443 | assert x.shape == a.shape == b.shape 444 | 445 | BLOCK_SIZE_M = 64 446 | BLOCK_SIZE_N = 64 447 | grid = lambda META: ( 448 | triton.cdiv(x.shape[0], META["BLOCK_SIZE_M"]), 449 | triton.cdiv(x.shape[1], META["BLOCK_SIZE_N"]), 450 | ) 451 | triton_add_mul_kernel[grid]( 452 | x, 453 | a, 454 | b, 455 | c, 456 | x.stride(0), 457 | x.stride(1), 458 | a.stride(0), 459 | a.stride(1), 460 | b.stride(0), 461 | b.stride(1), 462 | BLOCK_SIZE_M, 463 | BLOCK_SIZE_N, 464 | x.shape[0], 465 | x.shape[1], 466 | ) 467 | 468 | 469 | @triton.jit 470 | def triton_add_mul_kernel( 471 | x_ptr, 472 | a_ptr, 473 | b_ptr, 474 | c, 475 | stride_x0, 476 | stride_x1, 477 | stride_a0, 478 | stride_a1, 479 | stride_b0, 480 | stride_b1, 481 | BLOCK_SIZE_M: tl.constexpr, 482 | BLOCK_SIZE_N: tl.constexpr, 483 | M: tl.constexpr, 484 | N: tl.constexpr, 485 | ): 486 | pid_m = tl.program_id(0) 487 | pid_n = tl.program_id(1) 488 | 489 | offsets_m = tl.arange(0, BLOCK_SIZE_M) + pid_m * BLOCK_SIZE_M 490 | offsets_n = tl.arange(0, BLOCK_SIZE_N) + pid_n * BLOCK_SIZE_N 491 | 492 | x = tl.load( 493 | x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1, 494 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N), 495 | ) 496 | a = tl.load( 497 | a_ptr + offsets_m[:, None] * stride_a0 + offsets_n[None, :] * stride_a1, 498 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N), 499 | ) 500 | b = tl.load( 501 | b_ptr + offsets_m[:, None] * stride_b0 + offsets_n[None, :] * stride_b1, 502 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N), 503 | ) 504 | 505 | x_dtype = x.dtype 506 | x = (x.to(tl.float32) + a.to(tl.float32) * b.to(tl.float32) * c).to(x_dtype) 507 | 508 | tl.store( 509 | x_ptr + offsets_m[:, None] * stride_x0 + offsets_n[None, :] * stride_x1, 510 | x, 511 | mask=(offsets_m[:, None] < M) & (offsets_n[None, :] < N), 512 | ) 513 | 514 | 515 | 516 | def triton_sum_dim0_in_fp32(xs): 517 | a, b = xs.shape 518 | 519 | assert xs.is_contiguous() 520 | assert xs.dtype == torch.float16 521 | 522 | BLOCK_SIZE_A = min(triton.next_power_of_2(a), 512) 523 | BLOCK_SIZE_B = 64 # cache line is 128 bytes 524 | 525 | out = torch.zeros(b, 
dtype=torch.float32, device=xs.device) 526 | 527 | grid = lambda META: (triton.cdiv(b, META["BLOCK_SIZE_B"]),) 528 | 529 | triton_sum_dim0_in_fp32_kernel[grid]( 530 | xs, 531 | out, 532 | stride_a=xs.stride(0), 533 | a=a, 534 | b=b, 535 | BLOCK_SIZE_A=BLOCK_SIZE_A, 536 | BLOCK_SIZE_B=BLOCK_SIZE_B, 537 | ) 538 | 539 | return out 540 | 541 | 542 | @triton.jit 543 | def triton_sum_dim0_in_fp32_kernel( 544 | xs_ptr, 545 | out_ptr, 546 | stride_a, 547 | a, 548 | b, 549 | BLOCK_SIZE_A: tl.constexpr, 550 | BLOCK_SIZE_B: tl.constexpr, 551 | ): 552 | # each program handles 64 columns of xs 553 | pid = tl.program_id(0) 554 | offsets_b = tl.arange(0, BLOCK_SIZE_B) + pid * BLOCK_SIZE_B 555 | 556 | all_out = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32) 557 | 558 | for i in range(0, a, BLOCK_SIZE_A): 559 | offsets_a = tl.arange(0, BLOCK_SIZE_A) + i 560 | xs = tl.load( 561 | xs_ptr + offsets_a[:, None] * stride_a + offsets_b[None, :], 562 | mask=(offsets_a < a)[:, None] & (offsets_b < b)[None, :], 563 | other=0, 564 | ) 565 | xs = xs.to(tl.float32) 566 | out = tl.sum(xs, axis=0) 567 | all_out += out 568 | 569 | tl.store(out_ptr + offsets_b, all_out, mask=offsets_b < b) 570 | 571 | 572 | def mse( 573 | output, 574 | target, 575 | ): # fusing fp32 cast and MSE to save memory 576 | assert output.shape == target.shape 577 | assert len(output.shape) == 2 578 | assert output.stride(1) == 1 579 | assert target.stride(1) == 1 580 | 581 | a, b = output.shape 582 | 583 | BLOCK_SIZE_B = triton.next_power_of_2(b) 584 | 585 | class _MSE(torch.autograd.Function): 586 | @staticmethod 587 | def forward(ctx, output, target): 588 | ctx.save_for_backward(output, target) 589 | out = torch.zeros(a, dtype=torch.float32, device=output.device) 590 | 591 | triton_mse_loss_fp16_kernel[(a,)]( 592 | output, 593 | target, 594 | out, 595 | stride_a_output=output.stride(0), 596 | stride_a_target=target.stride(0), 597 | a=a, 598 | b=b, 599 | BLOCK_SIZE_B=BLOCK_SIZE_B, 600 | ) 601 | 602 | return out 603 | 604 | @staticmethod 605 | def backward(ctx, grad_output): 606 | output, target = ctx.saved_tensors 607 | res = (output - target).float() 608 | res *= grad_output[:, None] * 2 / b 609 | return res, None 610 | 611 | return _MSE.apply(output, target).mean() 612 | 613 | 614 | def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor: 615 | # only used for auxk 616 | xs_mu = ( 617 | triton_sum_dim0_in_fp32(xs) / xs.shape[0] 618 | if xs.dtype == torch.float16 619 | else xs.mean(dim=0) 620 | ) 621 | 622 | loss = mse(recon, xs) / mse( 623 | xs_mu[None, :].broadcast_to(xs.shape), xs 624 | ) 625 | 626 | return loss 627 | 628 | 629 | @triton.jit 630 | def triton_mse_loss_fp16_kernel( 631 | output_ptr, 632 | target_ptr, 633 | out_ptr, 634 | stride_a_output, 635 | stride_a_target, 636 | a, 637 | b, 638 | BLOCK_SIZE_B: tl.constexpr, 639 | ): 640 | pid = tl.program_id(0) 641 | offsets_b = tl.arange(0, BLOCK_SIZE_B) 642 | 643 | output = tl.load( 644 | output_ptr + pid * stride_a_output + offsets_b, 645 | mask=offsets_b < b, 646 | ) 647 | target = tl.load( 648 | target_ptr + pid * stride_a_target + offsets_b, 649 | mask=offsets_b < b, 650 | ) 651 | 652 | output = output.to(tl.float32) 653 | target = target.to(tl.float32) 654 | 655 | mse = tl.sum((output - target) * (output - target)) / b 656 | 657 | tl.store(out_ptr + pid, mse) 658 | 659 | 660 | def triton_add_mul_( 661 | x: torch.Tensor, 662 | a: torch.Tensor, 663 | b: torch.Tensor, 664 | c: float, 665 | ): 666 | """ 667 | does 668 | x += a * b * c 669 | 670 | x : [m, n] 671 | a : [m, 
n] 672 | b : [m, n] 673 | c : float 674 | """ 675 | 676 | if len(a.shape) == 1: 677 | a = a[None, :].broadcast_to(x.shape) 678 | 679 | if len(b.shape) == 1: 680 | b = b[None, :].broadcast_to(x.shape) 681 | 682 | assert x.shape == a.shape == b.shape 683 | 684 | BLOCK_SIZE_M = 64 685 | BLOCK_SIZE_N = 64 686 | grid = lambda META: ( 687 | triton.cdiv(x.shape[0], META["BLOCK_SIZE_M"]), 688 | triton.cdiv(x.shape[1], META["BLOCK_SIZE_N"]), 689 | ) 690 | triton_add_mul_kernel[grid]( 691 | x, 692 | a, 693 | b, 694 | c, 695 | x.stride(0), 696 | x.stride(1), 697 | a.stride(0), 698 | a.stride(1), 699 | b.stride(0), 700 | b.stride(1), 701 | BLOCK_SIZE_M, 702 | BLOCK_SIZE_N, 703 | x.shape[0], 704 | x.shape[1], 705 | ) 706 | -------------------------------------------------------------------------------- /sparse_autoencoder/train.py: -------------------------------------------------------------------------------- 1 | # bare bones training script using sparse kernels and sharding/data parallel. 2 | # the main purpose of this code is to provide a reference implementation to compare 3 | # against when implementing our training methodology into other codebases, and to 4 | # demonstrate how sharding/DP can be implemented for autoencoders. some limitations: 5 | # - many basic features (e.g checkpointing, data loading, validation) are not implemented, 6 | # - the codebase is not designed to be extensible or easily hackable. 7 | # - this code is not guaranteed to run efficiently out of the box / in 8 | # combination with other changes, so you should profile it and make changes as needed. 9 | # 10 | # example launch command: 11 | # torchrun --nproc-per-node 8 train.py 12 | 13 | 14 | import os 15 | from dataclasses import dataclass 16 | from typing import Callable, Iterable, Iterator 17 | 18 | import torch 19 | import torch.distributed as dist 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import triton 23 | import triton.language as tl 24 | from sparse_autoencoder.kernels import * 25 | from torch.distributed import ReduceOp 26 | 27 | RANK = int(os.environ.get("RANK", "0")) 28 | 29 | 30 | ## parallelism 31 | 32 | 33 | @dataclass 34 | class Comm: 35 | group: torch.distributed.ProcessGroup 36 | 37 | def all_reduce(self, x, op=ReduceOp.SUM, async_op=False): 38 | return dist.all_reduce(x, op=op, group=self.group, async_op=async_op) 39 | 40 | def all_gather(self, x_list, x, async_op=False): 41 | return dist.all_gather(list(x_list), x, group=self.group, async_op=async_op) 42 | 43 | def broadcast(self, x, src, async_op=False): 44 | return dist.broadcast(x, src, group=self.group, async_op=async_op) 45 | 46 | def barrier(self): 47 | return dist.barrier(group=self.group) 48 | 49 | def size(self): 50 | return self.group.size() 51 | 52 | 53 | @dataclass 54 | class ShardingComms: 55 | n_replicas: int 56 | n_op_shards: int 57 | dp_rank: int 58 | sh_rank: int 59 | dp_comm: Comm | None 60 | sh_comm: Comm | None 61 | _rank: int 62 | 63 | def sh_allreduce_forward(self, x: torch.Tensor) -> torch.Tensor: 64 | if self.sh_comm is None: 65 | return x 66 | 67 | class AllreduceForward(torch.autograd.Function): 68 | @staticmethod 69 | def forward(ctx, input): 70 | assert self.sh_comm is not None 71 | self.sh_comm.all_reduce(input, async_op=True) 72 | return input 73 | 74 | @staticmethod 75 | def backward(ctx, grad_output): 76 | return grad_output 77 | 78 | return AllreduceForward.apply(x) # type: ignore 79 | 80 | def sh_allreduce_backward(self, x: torch.Tensor) -> torch.Tensor: 81 | if self.sh_comm is None: 82 | 
return x 83 | 84 | class AllreduceBackward(torch.autograd.Function): 85 | @staticmethod 86 | def forward(ctx, input): 87 | return input 88 | 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | grad_output = grad_output.clone() 92 | assert self.sh_comm is not None 93 | self.sh_comm.all_reduce(grad_output, async_op=True) 94 | return grad_output 95 | 96 | return AllreduceBackward.apply(x) # type: ignore 97 | 98 | def init_broadcast_(self, autoencoder): 99 | if self.dp_comm is not None: 100 | for p in autoencoder.parameters(): 101 | self.dp_comm.broadcast( 102 | maybe_transpose(p.data), 103 | replica_shard_to_rank( 104 | replica_idx=0, 105 | shard_idx=self.sh_rank, 106 | n_op_shards=self.n_op_shards, 107 | ), 108 | ) 109 | 110 | if self.sh_comm is not None: 111 | # pre_bias is the same across all shards 112 | self.sh_comm.broadcast( 113 | autoencoder.pre_bias.data, 114 | replica_shard_to_rank( 115 | replica_idx=self.dp_rank, 116 | shard_idx=0, 117 | n_op_shards=self.n_op_shards, 118 | ), 119 | ) 120 | 121 | def dp_allreduce_(self, autoencoder) -> None: 122 | if self.dp_comm is None: 123 | return 124 | 125 | for param in autoencoder.parameters(): 126 | if param.grad is not None: 127 | self.dp_comm.all_reduce(maybe_transpose(param.grad), op=ReduceOp.AVG, async_op=True) 128 | 129 | # make sure statistics for dead neurons are correct 130 | self.dp_comm.all_reduce( # type: ignore 131 | autoencoder.stats_last_nonzero, op=ReduceOp.MIN, async_op=True 132 | ) 133 | 134 | def sh_allreduce_scale(self, scaler): 135 | if self.sh_comm is None: 136 | return 137 | 138 | if hasattr(scaler, "_scale") and scaler._scale is not None: 139 | self.sh_comm.all_reduce(scaler._scale, op=ReduceOp.MIN, async_op=True) 140 | self.sh_comm.all_reduce(scaler._growth_tracker, op=ReduceOp.MIN, async_op=True) 141 | 142 | def _sh_comm_op(self, x, op): 143 | if isinstance(x, (float, int)): 144 | x = torch.tensor(x, device="cuda") 145 | 146 | if not x.is_cuda: 147 | x = x.cuda() 148 | 149 | if self.sh_comm is None: 150 | return x 151 | 152 | out = x.clone() 153 | self.sh_comm.all_reduce(x, op=op, async_op=True) 154 | return out 155 | 156 | def sh_sum(self, x: torch.Tensor) -> torch.Tensor: 157 | return self._sh_comm_op(x, ReduceOp.SUM) 158 | 159 | def all_broadcast(self, x: torch.Tensor) -> torch.Tensor: 160 | if self.dp_comm is not None: 161 | self.dp_comm.broadcast( 162 | x, 163 | replica_shard_to_rank( 164 | replica_idx=0, 165 | shard_idx=self.sh_rank, 166 | n_op_shards=self.n_op_shards, 167 | ), 168 | ) 169 | 170 | if self.sh_comm is not None: 171 | self.sh_comm.broadcast( 172 | x, 173 | replica_shard_to_rank( 174 | replica_idx=self.dp_rank, 175 | shard_idx=0, 176 | n_op_shards=self.n_op_shards, 177 | ), 178 | ) 179 | 180 | return x 181 | 182 | 183 | def make_torch_comms(n_op_shards=4, n_replicas=2): 184 | if "RANK" not in os.environ: 185 | assert n_op_shards == 1 186 | assert n_replicas == 1 187 | return TRIVIAL_COMMS 188 | 189 | rank = int(os.environ.get("RANK")) 190 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 191 | os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % 8) 192 | 193 | print(f"{rank=}, {world_size=}") 194 | dist.init_process_group("nccl") 195 | 196 | my_op_shard_idx = rank % n_op_shards 197 | my_replica_idx = rank // n_op_shards 198 | 199 | shard_rank_lists = [list(range(i, i + n_op_shards)) for i in range(0, world_size, n_op_shards)] 200 | 201 | shard_groups = [dist.new_group(shard_rank_list) for shard_rank_list in shard_rank_lists] 202 | 203 | my_shard_group = shard_groups[my_replica_idx] 204 | 205 
| replica_rank_lists = [ 206 | list(range(i, n_op_shards * n_replicas, n_op_shards)) for i in range(n_op_shards) 207 | ] 208 | 209 | replica_groups = [dist.new_group(replica_rank_list) for replica_rank_list in replica_rank_lists] 210 | 211 | my_replica_group = replica_groups[my_op_shard_idx] 212 | 213 | torch.distributed.all_reduce(torch.ones(1).cuda()) 214 | torch.cuda.synchronize() 215 | 216 | dp_comm = Comm(group=my_replica_group) 217 | sh_comm = Comm(group=my_shard_group) 218 | 219 | return ShardingComms( 220 | n_replicas=n_replicas, 221 | n_op_shards=n_op_shards, 222 | dp_comm=dp_comm, 223 | sh_comm=sh_comm, 224 | dp_rank=my_replica_idx, 225 | sh_rank=my_op_shard_idx, 226 | _rank=rank, 227 | ) 228 | 229 | 230 | def replica_shard_to_rank(replica_idx, shard_idx, n_op_shards): 231 | return replica_idx * n_op_shards + shard_idx 232 | 233 | 234 | TRIVIAL_COMMS = ShardingComms( 235 | n_replicas=1, 236 | n_op_shards=1, 237 | dp_rank=0, 238 | sh_rank=0, 239 | dp_comm=None, 240 | sh_comm=None, 241 | _rank=0, 242 | ) 243 | 244 | 245 | def sharded_topk(x, k, sh_comm, capacity_factor=None): 246 | batch = x.shape[0] 247 | 248 | if capacity_factor is not None: 249 | k_in = min(int(k * capacity_factor // sh_comm.size()), k) 250 | else: 251 | k_in = k 252 | 253 | topk = torch.topk(x, k=k_in, dim=-1) 254 | inds = topk.indices 255 | vals = topk.values 256 | 257 | if sh_comm is None: 258 | return inds, vals 259 | 260 | all_vals = torch.empty(sh_comm.size(), batch, k_in, dtype=vals.dtype, device=vals.device) 261 | sh_comm.all_gather(all_vals, vals, async_op=True) 262 | 263 | all_vals = all_vals.permute(1, 0, 2) # put shard dim next to k 264 | all_vals = all_vals.reshape(batch, -1) # flatten shard into k 265 | 266 | all_topk = torch.topk(all_vals, k=k, dim=-1) 267 | global_topk = all_topk.values 268 | 269 | dummy_vals = torch.zeros_like(vals) 270 | dummy_inds = torch.zeros_like(inds) 271 | 272 | my_inds = torch.where(vals >= global_topk[:, [-1]], inds, dummy_inds) 273 | my_vals = torch.where(vals >= global_topk[:, [-1]], vals, dummy_vals) 274 | 275 | return my_inds, my_vals 276 | 277 | 278 | ## autoencoder 279 | 280 | 281 | class FastAutoencoder(nn.Module): 282 | """ 283 | Top-K Autoencoder with sparse kernels. 
Implements: 284 | 285 | latents = relu(topk(encoder(x - pre_bias) + latent_bias)) 286 | recons = decoder(latents) + pre_bias 287 | """ 288 | 289 | def __init__( 290 | self, 291 | n_dirs_local: int, 292 | d_model: int, 293 | k: int, 294 | auxk: int | None, 295 | dead_steps_threshold: int, 296 | comms: ShardingComms | None = None, 297 | ): 298 | super().__init__() 299 | self.n_dirs_local = n_dirs_local 300 | self.d_model = d_model 301 | self.k = k 302 | self.auxk = auxk 303 | self.comms = comms if comms is not None else TRIVIAL_COMMS 304 | self.dead_steps_threshold = dead_steps_threshold 305 | 306 | self.encoder = nn.Linear(d_model, n_dirs_local, bias=False) 307 | self.decoder = nn.Linear(n_dirs_local, d_model, bias=False) 308 | 309 | self.pre_bias = nn.Parameter(torch.zeros(d_model)) 310 | self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local)) 311 | 312 | self.stats_last_nonzero: torch.Tensor 313 | self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long)) 314 | 315 | def auxk_mask_fn(x): 316 | dead_mask = self.stats_last_nonzero > dead_steps_threshold 317 | x.data *= dead_mask # inplace to save memory 318 | return x 319 | 320 | self.auxk_mask_fn = auxk_mask_fn 321 | 322 | ## initialization 323 | 324 | # "tied" init 325 | self.decoder.weight.data = self.encoder.weight.data.T.clone() 326 | 327 | # store decoder in column major layout for kernel 328 | self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T 329 | 330 | unit_norm_decoder_(self) 331 | 332 | @property 333 | def n_dirs(self): 334 | return self.n_dirs_local * self.comms.n_op_shards 335 | 336 | def forward(self, x): 337 | class EncWrapper(torch.autograd.Function): 338 | @staticmethod 339 | def forward(ctx, x, pre_bias, weight, latent_bias): 340 | x = x - pre_bias 341 | latents_pre_act = F.linear(x, weight, latent_bias) 342 | 343 | inds, vals = sharded_topk( 344 | latents_pre_act, 345 | k=self.k, 346 | sh_comm=self.comms.sh_comm, 347 | capacity_factor=4, 348 | ) 349 | 350 | ## set num nonzero stat ## 351 | tmp = torch.zeros_like(self.stats_last_nonzero) 352 | tmp.scatter_add_( 353 | 0, 354 | inds.reshape(-1), 355 | (vals > 1e-3).to(tmp.dtype).reshape(-1), 356 | ) 357 | self.stats_last_nonzero *= 1 - tmp.clamp(max=1) 358 | self.stats_last_nonzero += 1 359 | ## end stats ## 360 | 361 | ## auxk 362 | if self.auxk is not None: # for auxk 363 | # IMPORTANT: has to go after stats update! 364 | # WARN: auxk_mask_fn can mutate latents_pre_act! 
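                    # auxk recomputes a top-k restricted (via auxk_mask_fn) to latents whose
                    # stats_last_nonzero exceeds dead_steps_threshold, i.e. latents that have
                    # not fired for a long time; their values feed the auxiliary
                    # reconstruction loss (scaled by auxk_coef) so dead latents still receive
                    # gradient.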
365 | auxk_inds, auxk_vals = sharded_topk( 366 | self.auxk_mask_fn(latents_pre_act), 367 | k=self.auxk, 368 | sh_comm=self.comms.sh_comm, 369 | capacity_factor=2, 370 | ) 371 | ctx.save_for_backward(x, weight, inds, auxk_inds) 372 | else: 373 | ctx.save_for_backward(x, weight, inds) 374 | auxk_inds = None 375 | auxk_vals = None 376 | 377 | ## end auxk 378 | 379 | return ( 380 | inds, 381 | vals, 382 | auxk_inds, 383 | auxk_vals, 384 | ) 385 | 386 | @staticmethod 387 | def backward(ctx, _, grad_vals, __, grad_auxk_vals): 388 | # encoder backwards 389 | if self.auxk is not None: 390 | x, weight, inds, auxk_inds = ctx.saved_tensors 391 | 392 | all_inds = torch.cat((inds, auxk_inds), dim=-1) 393 | all_grad_vals = torch.cat((grad_vals, grad_auxk_vals), dim=-1) 394 | else: 395 | x, weight, inds = ctx.saved_tensors 396 | 397 | all_inds = inds 398 | all_grad_vals = grad_vals 399 | 400 | grad_sum = torch.zeros(self.n_dirs_local, dtype=torch.float32, device=grad_vals.device) 401 | grad_sum.scatter_add_( 402 | -1, all_inds.flatten(), all_grad_vals.flatten().to(torch.float32) 403 | ) 404 | 405 | return ( 406 | None, 407 | # pre_bias grad optimization - can reduce before mat-vec multiply 408 | -(grad_sum @ weight), 409 | triton_sparse_transpose_dense_matmul(all_inds, all_grad_vals, x, N=self.n_dirs_local), 410 | grad_sum, 411 | ) 412 | 413 | pre_bias = self.comms.sh_allreduce_backward(self.pre_bias) 414 | 415 | # encoder 416 | inds, vals, auxk_inds, auxk_vals = EncWrapper.apply( 417 | x, pre_bias, self.encoder.weight, self.latent_bias 418 | ) 419 | 420 | vals = torch.relu(vals) 421 | if auxk_vals is not None: 422 | auxk_vals = torch.relu(auxk_vals) 423 | 424 | recons = self.decode_sparse(inds, vals) 425 | 426 | return recons, { 427 | "auxk_inds": auxk_inds, 428 | "auxk_vals": auxk_vals, 429 | } 430 | 431 | def decode_sparse(self, inds, vals): 432 | recons = TritonDecoderAutograd.apply(inds, vals, self.decoder.weight) 433 | recons = self.comms.sh_allreduce_forward(recons) 434 | 435 | return recons + self.pre_bias 436 | 437 | 438 | def unit_norm_decoder_(autoencoder: FastAutoencoder) -> None: 439 | """ 440 | Unit normalize the decoder weights of an autoencoder. 
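    The norm is taken over dim 0, so each column of decoder.weight (the d_model-dimensional
    dictionary direction for one latent) ends up with unit L2 norm.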
441 | """ 442 | autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0) 443 | 444 | 445 | def unit_norm_decoder_grad_adjustment_(autoencoder) -> None: 446 | """project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed""" 447 | 448 | assert autoencoder.decoder.weight.grad is not None 449 | 450 | triton_add_mul_( 451 | autoencoder.decoder.weight.grad, 452 | torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad), 453 | autoencoder.decoder.weight.data, 454 | c=-1, 455 | ) 456 | 457 | 458 | def maybe_transpose(x): 459 | return x.T if not x.is_contiguous() and x.T.is_contiguous() else x 460 | 461 | 462 | def sharded_grad_norm(autoencoder, comms, exclude=None): 463 | if exclude is None: 464 | exclude = [] 465 | total_sq_norm = torch.zeros((), device="cuda", dtype=torch.float32) 466 | exclude = set(exclude) 467 | 468 | total_num_params = 0 469 | for param in autoencoder.parameters(): 470 | if param in exclude: 471 | continue 472 | if param.grad is not None: 473 | sq_norm = ((param.grad).float() ** 2).sum() 474 | if param is autoencoder.pre_bias: 475 | total_sq_norm += sq_norm # pre_bias is the same across all shards 476 | else: 477 | total_sq_norm += comms.sh_sum(sq_norm) 478 | 479 | param_shards = comms.n_op_shards if param is autoencoder.pre_bias else 1 480 | total_num_params += param.numel() * param_shards 481 | 482 | return total_sq_norm.sqrt() 483 | 484 | 485 | def batch_tensors( 486 | it: Iterable[torch.Tensor], 487 | batch_size: int, 488 | drop_last=True, 489 | stream=None, 490 | ) -> Iterator[torch.Tensor]: 491 | """ 492 | input is iterable of tensors of shape [batch_old, ...] 493 | output is iterable of tensors of shape [batch_size, ...] 
494 |     batch_old does not need to be divisible by batch_size
495 |     """
496 | 
497 |     tensors = []
498 |     batch_so_far = 0
499 | 
500 |     for t in it:
501 |         tensors.append(t)
502 |         batch_so_far += t.shape[0]
503 | 
504 |         if sum(t.shape[0] for t in tensors) < batch_size:
505 |             continue
506 | 
507 |         while batch_so_far >= batch_size:
508 |             if len(tensors) == 1:
509 |                 (concat,) = tensors
510 |             else:
511 |                 with torch.cuda.stream(stream):
512 |                     concat = torch.cat(tensors, dim=0)
513 | 
514 |             offset = 0
515 |             while offset + batch_size <= concat.shape[0]:
516 |                 yield concat[offset : offset + batch_size]
517 |                 batch_so_far -= batch_size
518 |                 offset += batch_size
519 | 
520 |             tensors = [concat[offset:]] if offset < concat.shape[0] else []
521 | 
522 |     if len(tensors) > 0 and not drop_last:
523 |         yield torch.cat(tensors, dim=0)
524 | 
525 | 
526 | def print0(*a, **k):
527 |     if RANK == 0:
528 |         print(*a, **k)
529 | 
530 | 
531 | import wandb
532 | 
533 | 
534 | class Logger:
535 |     def __init__(self, **kws):
536 |         self.vals = {}
537 |         self.enabled = (RANK == 0) and not kws.pop("dummy", False)
538 |         if self.enabled:
539 |             wandb.init(
540 |                 **kws
541 |             )
542 | 
543 |     def logkv(self, k, v):
544 |         if self.enabled:
545 |             self.vals[k] = v.detach() if isinstance(v, torch.Tensor) else v
546 |         return v
547 | 
548 |     def dumpkvs(self):
549 |         if self.enabled:
550 |             wandb.log(self.vals)
551 |             self.vals = {}
552 | 
553 | 
554 | def training_loop_(
555 |     ae, train_acts_iter, loss_fn, lr, comms, eps=6.25e-10, clip_grad=None, ema_multiplier=0.999, logger=None
556 | ):
557 |     if logger is None:
558 |         logger = Logger(dummy=True)
559 | 
560 |     scaler = torch.cuda.amp.GradScaler()
561 |     autocast_ctx_manager = torch.cuda.amp.autocast()
562 | 
563 |     opt = torch.optim.Adam(ae.parameters(), lr=lr, eps=eps, fused=True)
564 |     if ema_multiplier is not None:
565 |         ema = EmaModel(ae, ema_multiplier=ema_multiplier)
566 | 
567 |     for i, flat_acts_train_batch in enumerate(train_acts_iter):
568 |         flat_acts_train_batch = flat_acts_train_batch.cuda()
569 | 
570 |         with autocast_ctx_manager:
571 |             recons, info = ae(flat_acts_train_batch)
572 | 
573 |             loss = loss_fn(ae, flat_acts_train_batch, recons, info, logger)
574 | 
575 |         print0(i, loss)
576 | 
577 |         logger.logkv("loss_scale", scaler.get_scale())
578 | 
579 |         # route per-step metrics through the Logger so nothing is sent to wandb when logging is disabled
580 |         logger.logkv("train_loss", loss)
581 | 
582 |         loss = scaler.scale(loss)
583 |         loss.backward()
584 | 
585 |         unit_norm_decoder_(ae)
586 |         unit_norm_decoder_grad_adjustment_(ae)
587 | 
588 |         # allreduce gradients
589 |         comms.dp_allreduce_(ae)
590 | 
591 |         # keep fp16 loss scale synchronized across shards
592 |         comms.sh_allreduce_scale(scaler)
593 | 
594 |         # if you want to do anything with the gradients that depends on the absolute scale (e.g. clipping), do it after the unscale_
595 |         scaler.unscale_(opt)
596 | 
597 |         # gradient clipping
598 |         if clip_grad is not None:
599 |             grad_norm = sharded_grad_norm(ae, comms)
600 |             logger.logkv("grad_norm", grad_norm)
601 |             grads = [x.grad for x in ae.parameters() if x.grad is not None]
602 |             torch._foreach_mul_(grads, clip_grad / torch.clamp(grad_norm, min=clip_grad))
603 | 
604 |         if ema_multiplier is not None:
605 |             ema.step()
606 | 
607 |         # take step with optimizer
608 |         scaler.step(opt)
609 |         scaler.update()
610 |         opt.zero_grad(set_to_none=True)  # backward() accumulates into .grad, so reset before the next step
611 |         logger.dumpkvs()
612 | 
613 | 
614 | def init_from_data_(ae, stats_acts_sample, comms):
615 |     from geom_median.torch import compute_geometric_median
616 | 
617 |     ae.pre_bias.data = (
618 |         compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
619 |     )
620 |
comms.all_broadcast(ae.pre_bias.data) 621 | 622 | # encoder initialization (note: in our ablations we couldn't find clear evidence that this is beneficial, this is just to ensure exact match with internal codebase) 623 | d_model = ae.d_model 624 | with torch.no_grad(): 625 | x = torch.randn(256, d_model).cuda().to(stats_acts_sample.dtype) 626 | x /= x.norm(dim=-1, keepdim=True) 627 | x += ae.pre_bias.data 628 | comms.all_broadcast(x) 629 | recons, _ = ae(x) 630 | recons_norm = (recons - ae.pre_bias.data).norm(dim=-1).mean() 631 | 632 | ae.encoder.weight.data /= recons_norm.item() 633 | print0("x norm", x.norm(dim=-1).mean().item()) 634 | print0("out norm", (ae(x)[0] - ae.pre_bias.data).norm(dim=-1).mean().item()) 635 | 636 | 637 | from contextlib import contextmanager 638 | 639 | 640 | @contextmanager 641 | def temporary_weight_swap(model: torch.nn.Module, new_weights: list[torch.Tensor]): 642 | for _p, new_p in zip(model.parameters(), new_weights, strict=True): 643 | assert _p.shape == new_p.shape 644 | _p.data, new_p.data = new_p.data, _p.data 645 | 646 | yield 647 | 648 | for _p, new_p in zip(model.parameters(), new_weights, strict=True): 649 | assert _p.shape == new_p.shape 650 | _p.data, new_p.data = new_p.data, _p.data 651 | 652 | 653 | class EmaModel: 654 | def __init__(self, model, ema_multiplier): 655 | self.model = model 656 | self.ema_multiplier = ema_multiplier 657 | self.ema_weights = [torch.zeros_like(x, requires_grad=False) for x in model.parameters()] 658 | self.ema_steps = 0 659 | 660 | def step(self): 661 | torch._foreach_lerp_( 662 | self.ema_weights, 663 | list(self.model.parameters()), 664 | 1 - self.ema_multiplier, 665 | ) 666 | self.ema_steps += 1 667 | 668 | # context manager for setting the autoencoder weights to the EMA weights 669 | @contextmanager 670 | def use_ema_weights(self): 671 | assert self.ema_steps > 0 672 | 673 | # apply bias correction 674 | bias_correction = 1 - self.ema_multiplier**self.ema_steps 675 | ema_weights_bias_corrected = torch._foreach_div(self.ema_weights, bias_correction) 676 | 677 | with torch.no_grad(): 678 | with temporary_weight_swap(self.model, ema_weights_bias_corrected): 679 | yield 680 | 681 | 682 | @dataclass 683 | class Config: 684 | n_op_shards: int = 1 685 | n_replicas: int = 8 686 | 687 | n_dirs: int = 32768 688 | bs: int = 131072 689 | d_model: int = 768 690 | k: int = 32 691 | auxk: int = 256 692 | 693 | lr: float = 1e-4 694 | eps: float = 6.25e-10 695 | clip_grad: float | None = None 696 | auxk_coef: float = 1 / 32 697 | dead_toks_threshold: int = 10_000_000 698 | ema_multiplier: float | None = None 699 | 700 | wandb_project: str | None = None 701 | wandb_name: str | None = None 702 | 703 | 704 | def main(): 705 | cfg = Config() 706 | comms = make_torch_comms(n_op_shards=cfg.n_op_shards, n_replicas=cfg.n_replicas) 707 | 708 | ## dataloading is left as an exercise for the reader 709 | acts_iter = ... 710 | stats_acts_sample = ... 
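    # A minimal sketch of one possible dataloader for the exercise above. The helper
    # name and shard layout are hypothetical; any iterable of [n_tokens, d_model]
    # activation tensors works here.
    def _example_acts_iter(shard_paths: list[str]) -> Iterator[torch.Tensor]:
        # each shard is assumed to hold a torch-saved float tensor of shape [n_tokens, d_model]
        for path in shard_paths:
            yield torch.load(path, map_location="cpu")
    # stats_acts_sample could then be, e.g., a slice of the first shard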
711 | 712 | n_dirs_local = cfg.n_dirs // cfg.n_op_shards 713 | bs_local = cfg.bs // cfg.n_replicas 714 | 715 | ae = FastAutoencoder( 716 | n_dirs_local=n_dirs_local, 717 | d_model=cfg.d_model, 718 | k=cfg.k, 719 | auxk=cfg.auxk, 720 | dead_steps_threshold=cfg.dead_toks_threshold // cfg.bs, 721 | comms=comms, 722 | ) 723 | ae.cuda() 724 | init_from_data_(ae, stats_acts_sample, comms) 725 | # IMPORTANT: make sure all DP ranks have the same params 726 | comms.init_broadcast_(ae) 727 | 728 | mse_scale = ( 729 | 1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean() 730 | ) 731 | comms.all_broadcast(mse_scale) 732 | mse_scale = mse_scale.item() 733 | 734 | logger = Logger( 735 | project=cfg.wandb_project, 736 | name=cfg.wandb_name, 737 | dummy=cfg.wandb_project is None, 738 | ) 739 | 740 | training_loop_( 741 | ae, 742 | batch_tensors( 743 | acts_iter, 744 | bs_local, 745 | drop_last=True, 746 | ), 747 | lambda ae, flat_acts_train_batch, recons, info, logger: ( 748 | # MSE 749 | logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch)) 750 | # AuxK 751 | + logger.logkv( 752 | "train_maxk_recons", 753 | cfg.auxk_coef 754 | * normalized_mse( 755 | ae.decode_sparse( 756 | info["auxk_inds"], 757 | info["auxk_vals"], 758 | ), 759 | flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(), 760 | ).nan_to_num(0), 761 | ) 762 | ), 763 | lr=cfg.lr, 764 | eps=cfg.eps, 765 | clip_grad=cfg.clip_grad, 766 | ema_multiplier=cfg.ema_multiplier, 767 | logger=logger, 768 | comms=comms, 769 | ) 770 | 771 | 772 | if __name__ == "__main__": 773 | main() 774 | --------------------------------------------------------------------------------
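Note: `EmaModel.use_ema_weights()` is defined in `train.py` above but never called there. A minimal sketch of how it could be used, assuming `training_loop_` is extended to hand back its `EmaModel` instance; the checkpoint path below is an assumption:

```
# hedged sketch: saving bias-corrected EMA weights after training
# assumes `ae` (a FastAutoencoder) and `ema` (an EmaModel with at least one step taken)
with ema.use_ema_weights():
    # inside the context, `ae`'s parameters are temporarily the bias-corrected EMA weights
    torch.save(ae.state_dict(), "ae_ema_checkpoint.pt")  # path is an assumption
# the original training weights are restored on exit
```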