├── postcss.config.js ├── .gitignore ├── tsconfig.node.json ├── src │ ├── index.tsx │ ├── hooks │ │ ├── useEventLogger.ts │ │ ├── useTransformerMachine.ts │ │ └── useTransformerDiagram.ts │ ├── index.css │ ├── utils │ │ ├── math.ts │ │ ├── randomWeights.ts │ │ ├── data.ts │ │ ├── constants.ts │ │ ├── componentTransformations.ts │ │ └── componentDataGenerator.ts │ ├── types │ │ └── events.d.ts │ ├── App.tsx │ ├── test │ │ └── setup.ts │ ├── components │ │ ├── TokenVisualization.tsx │ │ ├── TransformerDiagram.tsx │ │ ├── HelpMenu.tsx │ │ ├── MatrixVisualization.tsx │ │ ├── SettingsMenu.tsx │ │ └── ComponentDetailsPanel.tsx │ └── state │ └── transformerMachine.ts ├── vite.config.ts ├── tailwind.config.js ├── index.html ├── tsconfig.json ├── .eslintrc.json ├── package.json ├── .github │ └── workflows │ └── build-release.yml ├── LICENSE ├── server │ └── server.py └── README.md /postcss.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | dist/ 3 | .env 4 | .env.local 5 | .env.development.local 6 | .env.test.local 7 | .env.production.local 8 | .env.development 9 | .env.test 10 | .env.production 11 | -------------------------------------------------------------------------------- /tsconfig.node.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "composite": true, 4 | "skipLibCheck": true, 5 | "module": "ESNext", 6 | "moduleResolution": "bundler", 7 | "allowSyntheticDefaultImports": true 8 | }, 9 | "include": ["vite.config.ts"] 10 | } -------------------------------------------------------------------------------- /src/index.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import App from './App'; 4 | import './index.css'; 5 | 6 | const root = ReactDOM.createRoot( 7 | document.getElementById('root') as HTMLElement 8 | ); 9 | 10 | root.render( 11 | <React.StrictMode> 12 | <App /> 13 | </React.StrictMode> 14 | ); -------------------------------------------------------------------------------- /vite.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from 'vite' 2 | import react from '@vitejs/plugin-react' 3 | 4 | // https://vitejs.dev/config/ 5 | export default defineConfig({ 6 | plugins: [react()], 7 | server: { 8 | proxy: { 9 | '/log': { 10 | target: 'http://localhost:3001', 11 | changeOrigin: true, 12 | } 13 | } 14 | } 15 | }) -------------------------------------------------------------------------------- /tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | export default { 3 | content: [ 4 | "./index.html", 5 | "./src/**/*.{js,ts,jsx,tsx}", 6 | ], 7 | theme: { 8 | extend: { 9 | colors: { 10 | 'cs-blue': '#1062fb', 11 | 'cs-gold': '#ffc500', 12 | 'cs-sky': '#64d3ff', 13 | 'cs-deep': '#002570', 14 | }, 15 | fontFamily: { 16 | sans: ['Inter', 'sans-serif'], 17 | }, 18 | }, 19 | }, 20 | plugins: [], 21 | } -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Travel Through
Transformers 8 | 9 | 10 | 11 | 12 | 13 |
14 | 15 | 16 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "useDefineForClassFields": true, 5 | "lib": ["ES2020", "DOM", "DOM.Iterable"], 6 | "module": "ESNext", 7 | "skipLibCheck": true, 8 | 9 | /* Bundler mode */ 10 | "moduleResolution": "bundler", 11 | "allowImportingTsExtensions": true, 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "noEmit": true, 15 | "jsx": "react-jsx", 16 | 17 | /* Linting */ 18 | "strict": true, 19 | "noUnusedLocals": true, 20 | "noUnusedParameters": true, 21 | "noFallthroughCasesInSwitch": true 22 | }, 23 | "include": ["src"], 24 | "references": [{ "path": "./tsconfig.node.json" }] 25 | } -------------------------------------------------------------------------------- /src/hooks/useEventLogger.ts: -------------------------------------------------------------------------------- 1 | import { useCallback } from 'react'; 2 | import { LogEvent } from '../types/events'; 3 | 4 | export function useEventLogger() { 5 | const logEvent = useCallback(async (eventType: LogEvent['event_type'], payload: Record) => { 6 | const event: LogEvent = { 7 | timestamp: Date.now(), 8 | event_type: eventType, 9 | payload, 10 | }; 11 | 12 | try { 13 | await fetch('/log', { 14 | method: 'POST', 15 | headers: { 16 | 'Content-Type': 'application/json', 17 | }, 18 | body: JSON.stringify(event), 19 | }); 20 | } catch (error) { 21 | console.warn('Failed to log event:', error); 22 | } 23 | }, []); 24 | 25 | return { logEvent }; 26 | } -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "root": true, 3 | "env": { 4 | "browser": true, 5 | "es2020": true 6 | }, 7 | "extends": [ 8 | "eslint:recommended", 9 | "plugin:react-hooks/recommended" 10 | ], 11 | "ignorePatterns": ["dist", "*.config.js", "*.config.ts"], 12 | "parser": "@typescript-eslint/parser", 13 | "plugins": ["react-refresh", "@typescript-eslint"], 14 | "rules": { 15 | "react-refresh/only-export-components": [ 16 | "warn", 17 | { "allowConstantExport": true } 18 | ], 19 | "no-unused-vars": "off", 20 | "@typescript-eslint/no-unused-vars": ["error", { "argsIgnorePattern": "^_" }], 21 | "prefer-const": "error", 22 | "no-var": "error", 23 | "no-console": ["warn", { "allow": ["warn", "error"] }] 24 | } 25 | } -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | @layer base { 6 | html { 7 | font-family: Inter, system-ui, sans-serif; 8 | } 9 | } 10 | 11 | @layer components { 12 | .slider::-webkit-slider-thumb { 13 | appearance: none; 14 | height: 20px; 15 | width: 20px; 16 | border-radius: 50%; 17 | background: #1062fb; 18 | cursor: pointer; 19 | border: 2px solid #ffffff; 20 | box-shadow: 0 0 0 1px rgba(16, 98, 251, 0.2); 21 | } 22 | 23 | .slider::-moz-range-thumb { 24 | height: 20px; 25 | width: 20px; 26 | border-radius: 50%; 27 | background: #1062fb; 28 | cursor: pointer; 29 | border: 2px solid #ffffff; 30 | box-shadow: 0 0 0 1px rgba(16, 98, 251, 0.2); 31 | } 32 | 33 | .slider:focus::-webkit-slider-thumb { 34 | box-shadow: 0 0 0 3px rgba(16, 98, 251, 0.3); 35 | } 36 | 37 | 
.slider:focus::-moz-range-thumb { 38 | box-shadow: 0 0 0 3px rgba(16, 98, 251, 0.3); 39 | } 40 | } 41 | 42 | @layer utilities { 43 | .text-balance { 44 | text-wrap: balance; 45 | } 46 | } -------------------------------------------------------------------------------- /src/utils/math.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Performs matrix multiplication between two 2D arrays. 3 | * @param a - First matrix (m x k) 4 | * @param b - Second matrix (k x n) 5 | * @returns Result matrix (m x n) 6 | */ 7 | export function matmul(a: number[][], b: number[][]): number[][] { 8 | const result: number[][] = []; 9 | const rows = a.length; 10 | const cols = b[0].length; 11 | const inner = a[0].length; 12 | 13 | for (let i = 0; i < rows; i++) { 14 | result[i] = []; 15 | for (let j = 0; j < cols; j++) { 16 | result[i][j] = 0; 17 | for (let k = 0; k < inner; k++) { 18 | result[i][j] += a[i][k] * b[k][j]; 19 | } 20 | } 21 | } 22 | 23 | return result; 24 | } 25 | 26 | /** 27 | * Applies softmax function row-wise to a 2D matrix. 28 | * @param matrix - Input matrix where each row will be normalized 29 | * @returns Matrix with each row normalized to sum to 1 30 | */ 31 | export function softmax(matrix: number[][]): number[][] { 32 | return matrix.map(row => { 33 | const max = Math.max(...row); 34 | const exp = row.map(x => Math.exp(x - max)); 35 | const sum = exp.reduce((a, b) => a + b, 0); 36 | return exp.map(x => x / sum); 37 | }); 38 | } -------------------------------------------------------------------------------- /src/types/events.d.ts: -------------------------------------------------------------------------------- 1 | import { ZoomLevel as ImportedZoomLevel } from '../utils/constants'; 2 | 3 | export type ZoomLevel = ImportedZoomLevel; 4 | export type EventType = 'step' | 'param_change' | 'toggle' | 'token_change' | 'zoom_change' | 'component_select' | 'animation_control' | 'attention_lens' | 'sequence_generation'; 5 | 6 | export interface LogEvent { 7 | timestamp: number; 8 | event_type: EventType; 9 | payload: Record; 10 | } 11 | 12 | export interface TransformerParams { 13 | numLayers: number; 14 | dModel: number; 15 | numHeads: number; 16 | seqLen: number; 17 | posEncoding: 'learned' | 'sinusoidal'; 18 | dropout: boolean; 19 | } 20 | 21 | export interface ComponentData { 22 | inputs: ComponentSection[]; 23 | parameters: ComponentSection[]; 24 | outputs: ComponentSection[]; 25 | description: string; 26 | category: 'embedding' | 'attention' | 'ffn' | 'output' | 'tokens' | 'positional'; 27 | } 28 | 29 | export interface ComponentSection { 30 | id: string; 31 | label: string; 32 | description: string; 33 | type: 'matrix' | 'vector' | 'scalar' | 'tokens' | 'text'; 34 | data?: number[][] | number[] | number | string[] | string; 35 | shape?: [number, number] | [number]; 36 | metadata?: { 37 | [key: string]: any; 38 | }; 39 | } 40 | 41 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "travel-through-transformers", 3 | "private": true, 4 | "version": "0.1.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vite", 8 | "build": "tsc && vite build", 9 | "preview": "vite preview", 10 | "lint": "eslint . 
--ext ts,tsx --report-unused-disable-directives --max-warnings 0", 11 | "test": "vitest", 12 | "test:e2e": "playwright test" 13 | }, 14 | "dependencies": { 15 | "@heroicons/react": "^2.0.18", 16 | "@xstate/react": "^6.0.0", 17 | "d3": "^7.8.5", 18 | "i18next": "^25.3.0", 19 | "react": "^18.2.0", 20 | "react-dom": "^18.2.0", 21 | "react-i18next": "^15.5.3", 22 | "reactflow": "^11.11.4", 23 | "xstate": "^5.20.0" 24 | }, 25 | "devDependencies": { 26 | "@playwright/test": "^1.43.0", 27 | "@testing-library/jest-dom": "^6.4.5", 28 | "@testing-library/react": "^15.0.6", 29 | "@types/d3": "^7.4.3", 30 | "@types/react": "^18.2.66", 31 | "@types/react-dom": "^18.2.22", 32 | "@typescript-eslint/eslint-plugin": "^7.2.0", 33 | "@typescript-eslint/parser": "^7.2.0", 34 | "@vitejs/plugin-react": "^4.2.1", 35 | "autoprefixer": "^10.4.19", 36 | "eslint": "^8.57.0", 37 | "eslint-plugin-react-hooks": "^4.6.0", 38 | "eslint-plugin-react-refresh": "^0.4.6", 39 | "postcss": "^8.4.38", 40 | "tailwindcss": "^3.4.3", 41 | "typescript": "^5.2.2", 42 | "vite": "^5.2.0", 43 | "vitest": "^1.6.0" 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /.github/workflows/build-release.yml: -------------------------------------------------------------------------------- 1 | name: Build and Release 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | permissions: 8 | contents: write 9 | 10 | jobs: 11 | build-and-release: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | 18 | - name: Setup Node.js 19 | uses: actions/setup-node@v4 20 | with: 21 | node-version: '20' 22 | cache: 'npm' 23 | 24 | - name: Install dependencies 25 | run: npm ci 26 | 27 | - name: Build project 28 | run: npm run build 29 | 30 | - name: Archive build output 31 | run: tar -czf dist.tar.gz dist 32 | 33 | - name: Upload build artifact (for workflow logs) 34 | uses: actions/upload-artifact@v4 35 | with: 36 | name: dist 37 | path: dist 38 | 39 | - name: Create GitHub Release and upload asset 40 | uses: ncipollo/release-action@v1 41 | with: 42 | token: ${{ secrets.GITHUB_TOKEN }} 43 | tag: release-${{ github.run_id }} 44 | name: Release ${{ github.run_number }} 45 | body: | 46 | Automated release for commit ${{ github.sha }} on branch ${{ github.ref_name }}. 47 | - Trigger: push to main 48 | - Run number: ${{ github.run_number }} 49 | artifacts: dist.tar.gz 50 | allowUpdates: false 51 | draft: false 52 | prerelease: false 53 | makeLatest: true 54 | -------------------------------------------------------------------------------- /src/utils/randomWeights.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Random weight generation utilities with seeded randomness for reproducible results. 3 | * Used to generate matrices, vectors, embeddings, and positional encodings for the simulation. 
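 *
 * A minimal usage sketch (illustrative only; every export referenced here is defined below in this file):
 *
 * @example
 * // Deterministic across runs because all generators share the module-level SeededRandom(42).
 * const weights = generateMatrix(64, 64, 'normal');            // e.g. an illustrative d_model x d_model projection
 * const embeddings = generateEmbeddings(1000, 64);             // vocab_size x d_model lookup table
 * const pe = generatePositionalEncoding(8, 64, 'sinusoidal');  // seq_len x d_model, no randomness involved
 * // The sinusoidal branch implements PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
 * // PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).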
4 | */ 5 | class SeededRandom { 6 | private seed: number; 7 | 8 | constructor(seed: number = 42) { 9 | this.seed = seed; 10 | } 11 | 12 | next(): number { 13 | // Simple linear congruential generator 14 | this.seed = (this.seed * 1664525 + 1013904223) % (2 ** 32); 15 | return this.seed / (2 ** 32); 16 | } 17 | 18 | uniform(min: number = 0, max: number = 1): number { 19 | return min + (max - min) * this.next(); 20 | } 21 | 22 | normal(mean: number = 0, std: number = 1): number { 23 | // Box-Muller transform 24 | const u1 = this.next(); 25 | const u2 = this.next(); 26 | const z0 = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2); 27 | return mean + std * z0; 28 | } 29 | } 30 | 31 | const rng = new SeededRandom(42); 32 | 33 | /** 34 | * Generates a random matrix with specified dimensions and distribution. 35 | * @param rows - Number of rows 36 | * @param cols - Number of columns 37 | * @param distribution - Random distribution type ('uniform' or 'normal') 38 | * @returns 2D array of random numbers 39 | */ 40 | export function generateMatrix(rows: number, cols: number, distribution: 'uniform' | 'normal' = 'normal'): number[][] { 41 | const matrix: number[][] = []; 42 | 43 | for (let i = 0; i < rows; i++) { 44 | matrix[i] = []; 45 | for (let j = 0; j < cols; j++) { 46 | if (distribution === 'uniform') { 47 | matrix[i][j] = rng.uniform(-1, 1); 48 | } else { 49 | matrix[i][j] = rng.normal(0, 0.1); 50 | } 51 | } 52 | } 53 | 54 | return matrix; 55 | } 56 | 57 | export function generateVector(size: number, distribution: 'uniform' | 'normal' = 'normal'): number[] { 58 | const vector: number[] = []; 59 | 60 | for (let i = 0; i < size; i++) { 61 | if (distribution === 'uniform') { 62 | vector[i] = rng.uniform(-1, 1); 63 | } else { 64 | vector[i] = rng.normal(0, 0.1); 65 | } 66 | } 67 | 68 | return vector; 69 | } 70 | 71 | export function generateEmbeddings(vocabSize: number, dModel: number): number[][] { 72 | return generateMatrix(vocabSize, dModel, 'normal'); 73 | } 74 | 75 | /** 76 | * Generates positional encoding for transformer models. 77 | * @param seqLen - Sequence length 78 | * @param dModel - Model dimension 79 | * @param type - Type of encoding ('learned' uses random weights, 'sinusoidal' uses mathematical formula) 80 | * @returns 2D array representing positional encoding 81 | */ 82 | export function generatePositionalEncoding(seqLen: number, dModel: number, type: 'learned' | 'sinusoidal'): number[][] { 83 | if (type === 'learned') { 84 | return generateMatrix(seqLen, dModel, 'normal'); 85 | } 86 | 87 | // Sinusoidal positional encoding 88 | const pos: number[][] = []; 89 | 90 | for (let i = 0; i < seqLen; i++) { 91 | pos[i] = []; 92 | for (let j = 0; j < dModel; j++) { 93 | const angle = i / Math.pow(10000, (2 * Math.floor(j / 2)) / dModel); 94 | pos[i][j] = j % 2 === 0 ? 
Math.sin(angle) : Math.cos(angle); 95 | } 96 | } 97 | 98 | return pos; 99 | } -------------------------------------------------------------------------------- /src/App.tsx: -------------------------------------------------------------------------------- 1 | import { useTransformerMachine } from './hooks/useTransformerMachine'; 2 | import { TransformerDiagram } from './components/TransformerDiagram'; 3 | import { ComponentDetailsPanel } from './components/ComponentDetailsPanel'; 4 | import { SettingsMenu } from './components/SettingsMenu'; 5 | import { HelpMenu } from './components/HelpMenu'; 6 | import { generateComponentData } from './utils/componentDataGenerator'; 7 | 8 | function App() { 9 | const { state, actions } = useTransformerMachine(); 10 | 11 | // Extract commonly used values from state 12 | const { params, selectedComponent } = state; 13 | 14 | // Component selection handlers 15 | const handleComponentClick = (componentId: string) => { 16 | actions.setZoomLevel('sub_layer', componentId); 17 | }; 18 | 19 | const handleComponentClose = () => { 20 | actions.setZoomLevel('global'); 21 | }; 22 | 23 | // Generate component data for the enhanced panel 24 | const componentData = selectedComponent ? generateComponentData(selectedComponent, params) : null; 25 | 26 | return ( 27 |
28 | {/* Header */} 29 |
30 |
31 |
32 |
33 |
34 |

35 | Travel Through Transformers 36 |

37 |
38 |
39 | 40 |
41 | 42 | 46 |
47 |
48 |
49 |
50 | 51 | {/* Main Content */} 52 |
53 | 54 |
55 | {/* Main Architecture View */} 56 |
57 | 62 |
63 | 64 | {/* Component Detail Panel */} 65 |
66 | {selectedComponent && componentData ? ( 67 | 72 | ) : ( 73 |
74 |
75 |
76 | 77 | 78 | 79 |
80 |

No Component Selected

81 |

82 | Click on a transformer component to view its details and internal structure. 83 |

84 |
85 |
86 | )} 87 |
88 |
89 |
90 |
91 | ); 92 | } 93 | 94 | export default App; -------------------------------------------------------------------------------- /src/utils/data.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Generates sequences and manages sequence data for the transformer visualization. 3 | * Provides realistic English-to-French translation examples with proper padding. 4 | */ 5 | 6 | export class SequenceGenerator { 7 | // Longer, meaningful sequences for demonstration (20-25 tokens) 8 | private static readonly INPUT_SEQUENCE = [ 9 | '<SOS>', 'The', 'artificial', 'intelligence', 'system', 'processes', 'natural', 'language', 10 | 'text', 'by', 'analyzing', 'patterns', 'in', 'large', 'datasets', 'to', 'understand', 11 | 'context', 'and', 'generate', 'meaningful', 'responses', 'for', 'users', '<EOS>' 12 | ]; 13 | 14 | private static readonly OUTPUT_SEQUENCE = [ 15 | '<SOS>', 'Le', 'système', 'd\'intelligence', 'artificielle', 'traite', 'le', 'texte', 16 | 'en', 'langage', 'naturel', 'en', 'analysant', 'les', 'motifs', 'dans', 'de', 17 | 'grandes', 'bases', 'de', 'données', 'pour', 'comprendre', 'le', 'contexte', '<EOS>' 18 | ]; 19 | 20 | 21 | 22 | /** 23 | * Generates a sequence of specified length, using base sequences as templates 24 | */ 25 | static generateSequence(length: number, type: 'input' | 'output' = 'input'): string[] { 26 | const baseSequence = type === 'input' ? this.INPUT_SEQUENCE : this.OUTPUT_SEQUENCE; 27 | return this.adjustSequenceLength(baseSequence, length); 28 | } 29 | 30 | /** 31 | * Adjusts sequence length by truncating or padding as needed 32 | */ 33 | static adjustSequenceLength(sequence: string[], newLength: number): string[] { 34 | if (sequence.length === newLength) { 35 | return [...sequence]; 36 | } else if (sequence.length > newLength) { 37 | return this.truncateSequence(sequence, newLength); 38 | } else { 39 | return this.padSequence(sequence, newLength); 40 | } 41 | } 42 | 43 | /** 44 | * Truncates sequence while preserving start/end tokens when possible 45 | */ 46 | private static truncateSequence(sequence: string[], targetLength: number): string[] { 47 | if (targetLength < 2) { 48 | return sequence.slice(0, targetLength); 49 | } 50 | 51 | // Try to keep SOS and EOS tokens 52 | const hasStart = sequence[0] === '<SOS>'; 53 | const hasEnd = sequence[sequence.length - 1] === '<EOS>'; 54 | 55 | if (hasStart && hasEnd && targetLength >= 2) { 56 | const middleLength = targetLength - 2; 57 | const middleTokens = sequence.slice(1, sequence.length - 1).slice(0, middleLength); 58 | return ['<SOS>', ...middleTokens, '<EOS>']; 59 | } else if (hasStart && targetLength >= 1) { 60 | const remainingLength = targetLength - 1; 61 | const remainingTokens = sequence.slice(1, 1 + remainingLength); 62 | return ['<SOS>', ...remainingTokens]; 63 | } else { 64 | return sequence.slice(0, targetLength); 65 | } 66 | } 67 | 68 | /** 69 | * Pads sequence to target length using appropriate padding tokens 70 | */ 71 | private static padSequence(sequence: string[], targetLength: number): string[] { 72 | const padded = [...sequence]; 73 | const tokensToAdd = targetLength - sequence.length; 74 | 75 | for (let i = 0; i < tokensToAdd; i++) { 76 | padded.push('<PAD>'); 77 | } 78 | 79 | return padded; 80 | } 81 | 82 | /** 83 | * Generates matching output sequence for given input 84 | */ 85 | static generateMatchingOutput(inputSequence: string[]): string[] { 86 | // For this demonstration, we'll use our base translation and adjust length 87 | return this.adjustSequenceLength(this.OUTPUT_SEQUENCE,
inputSequence.length); 88 | } 89 | 90 | /** 91 | * Component-specific sequence generation 92 | */ 93 | static generateForComponent(componentId: string, length: number): string[] { 94 | if (componentId.includes('output') || componentId.includes('decoder') || componentId.includes('predicted')) { 95 | return this.generateSequence(length, 'output'); 96 | } 97 | return this.generateSequence(length, 'input'); 98 | } 99 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Elastic License 2.0 2 | 3 | URL: https://www.elastic.co/licensing/elastic-license 4 | 5 | ## Acceptance 6 | 7 | By using the software, you agree to all of the terms and conditions below. 8 | 9 | ## Copyright License 10 | 11 | The licensor grants you a non-exclusive, royalty-free, worldwide, 12 | non-sublicensable, non-transferable license to use, copy, distribute, make 13 | available, and prepare derivative works of the software, in each case subject to 14 | the limitations and conditions below. 15 | 16 | ## Limitations 17 | 18 | You may not provide the software to third parties as a hosted or managed 19 | service, where the service provides users with access to any substantial set of 20 | the features or functionality of the software. 21 | 22 | You may not move, change, disable, or circumvent the license key functionality 23 | in the software, and you may not remove or obscure any functionality in the 24 | software that is protected by the license key. 25 | 26 | You may not alter, remove, or obscure any licensing, copyright, or other notices 27 | of the licensor in the software. Any use of the licensor’s trademarks is subject 28 | to applicable law. 29 | 30 | ## Patents 31 | 32 | The licensor grants you a license, under any patent claims the licensor can 33 | license, or becomes able to license, to make, have made, use, sell, offer for 34 | sale, import and have imported the software, in each case subject to the 35 | limitations and conditions in this license. This license does not cover any 36 | patent claims that you cause to be infringed by modifications or additions to 37 | the software. If you or your company make any written claim that the software 38 | infringes or contributes to infringement of any patent, your patent license for 39 | the software granted under these terms ends immediately. If your company makes 40 | such a claim, your patent license ends immediately for work on behalf of your 41 | company. 42 | 43 | ## Notices 44 | 45 | You must ensure that anyone who gets a copy of any part of the software from you 46 | also gets a copy of these terms. 47 | 48 | If you modify the software, you must include in any modified copies of the 49 | software prominent notices stating that you have modified the software. 50 | 51 | ## No Other Rights 52 | 53 | These terms do not imply any licenses other than those expressly granted in 54 | these terms. 55 | 56 | ## Termination 57 | 58 | If you use the software in violation of these terms, such use is not licensed, 59 | and your licenses will automatically terminate. If the licensor provides you 60 | with a notice of your violation, and you cease all violation of this license no 61 | later than 30 days after you receive that notice, your licenses will be 62 | reinstated retroactively. 
However, if you violate these terms after such 63 | reinstatement, any additional violation of these terms will cause your licenses 64 | to terminate automatically and permanently. 65 | 66 | ## No Liability 67 | 68 | *As far as the law allows, the software comes as is, without any warranty or 69 | condition, and the licensor will not be liable to you for any damages arising 70 | out of these terms or the use or nature of the software, under any kind of 71 | legal claim.* 72 | 73 | ## Definitions 74 | 75 | The **licensor** is the entity offering these terms, and the **software** is the 76 | software the licensor makes available under these terms, including any portion 77 | of it. 78 | 79 | **you** refers to the individual or entity agreeing to these terms. 80 | 81 | **your company** is any legal entity, sole proprietorship, or other kind of 82 | organization that you work for, plus all organizations that have control over, 83 | are under the control of, or are under common control with that 84 | organization. **control** means ownership of substantially all the assets of an 85 | entity, or the power to direct its management and policies by vote, contract, or 86 | otherwise. Control can be direct or indirect. 87 | 88 | **your licenses** are all the licenses granted to you for the software under 89 | these terms. 90 | 91 | **use** means anything you do with the software requiring one of your licenses. 92 | 93 | **trademark** means trademarks, service marks, and similar rights. 94 | -------------------------------------------------------------------------------- /server/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Simple HTTP server for the Travel Through Transformers simulation. 4 | Serves static files and handles event logging. 
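
Usage sketch (the port argument is optional and defaults to 3000):

    python server/server.py          # serve ./dist on http://localhost:3000
    python server/server.py 3001     # match the Vite dev proxy target for /log

Events are JSON objects appended to logs/simulation_log.jsonl; an illustrative
POST (field values are examples only):

    curl -X POST http://localhost:3000/log -H 'Content-Type: application/json' -d '{"timestamp": 1625239200000, "event_type": "toggle", "payload": {}}'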
5 | """ 6 | 7 | import json 8 | import os 9 | import sys 10 | from datetime import datetime 11 | from http.server import HTTPServer, SimpleHTTPRequestHandler 12 | from urllib.parse import urlparse 13 | 14 | 15 | class TransformerServerHandler(SimpleHTTPRequestHandler): 16 | def do_POST(self): 17 | """Handle POST requests for logging events.""" 18 | if self.path == '/log': 19 | try: 20 | # Read the request body 21 | content_length = int(self.headers['Content-Length']) 22 | post_data = self.rfile.read(content_length) 23 | event_data = json.loads(post_data.decode('utf-8')) 24 | 25 | # Ensure logs directory exists 26 | os.makedirs('logs', exist_ok=True) 27 | 28 | # Append to log file 29 | log_file = 'logs/simulation_log.jsonl' 30 | with open(log_file, 'a', encoding='utf-8') as f: 31 | f.write(json.dumps(event_data) + '\n') 32 | 33 | # Send success response 34 | self.send_response(200) 35 | self.send_header('Content-Type', 'application/json') 36 | self.send_header('Access-Control-Allow-Origin', '*') 37 | self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') 38 | self.send_header('Access-Control-Allow-Headers', 'Content-Type') 39 | self.end_headers() 40 | 41 | response = {'status': 'success', 'timestamp': datetime.now().isoformat()} 42 | self.wfile.write(json.dumps(response).encode('utf-8')) 43 | 44 | print(f"[{datetime.now().strftime('%H:%M:%S')}] Logged event: {event_data['event_type']}") 45 | 46 | except Exception as e: 47 | print(f"Error logging event: {e}") 48 | self.send_response(500) 49 | self.send_header('Content-Type', 'application/json') 50 | self.end_headers() 51 | error_response = {'status': 'error', 'message': str(e)} 52 | self.wfile.write(json.dumps(error_response).encode('utf-8')) 53 | else: 54 | self.send_response(404) 55 | self.end_headers() 56 | 57 | def do_OPTIONS(self): 58 | """Handle CORS preflight requests.""" 59 | self.send_response(200) 60 | self.send_header('Access-Control-Allow-Origin', '*') 61 | self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') 62 | self.send_header('Access-Control-Allow-Headers', 'Content-Type') 63 | self.end_headers() 64 | 65 | def do_GET(self): 66 | """Handle GET requests - serve static files from dist directory.""" 67 | if self.path == '/': 68 | self.path = '/index.html' 69 | 70 | # Serve files from dist directory 71 | if os.path.exists(f'dist{self.path}'): 72 | return super().do_GET() 73 | 74 | # Fallback for SPA routing 75 | if not self.path.startswith('/api') and not os.path.exists(f'dist{self.path}'): 76 | self.path = '/index.html' 77 | 78 | return super().do_GET() 79 | 80 | def translate_path(self, path): 81 | """Translate a /-separated PATH to the local filename syntax.""" 82 | # Remove query string and fragment 83 | path = urlparse(path).path 84 | 85 | # Serve from dist directory 86 | root = os.path.join(os.getcwd(), 'dist') 87 | return os.path.join(root, path.lstrip('/')) 88 | 89 | 90 | def run_server(port=3000): 91 | """Run the HTTP server.""" 92 | server_address = ('', port) 93 | httpd = HTTPServer(server_address, TransformerServerHandler) 94 | 95 | print(f"Starting server on http://localhost:{port}") 96 | print("Serving static files from ./dist/") 97 | print("Logging events to ./logs/simulation_log.jsonl") 98 | print("Press Ctrl+C to stop the server") 99 | 100 | try: 101 | httpd.serve_forever() 102 | except KeyboardInterrupt: 103 | print("\nServer stopped.") 104 | httpd.server_close() 105 | 106 | 107 | if __name__ == '__main__': 108 | port = int(sys.argv[1]) if len(sys.argv) > 1 else 3000 109 | 
run_server(port) -------------------------------------------------------------------------------- /src/test/setup.ts: -------------------------------------------------------------------------------- 1 | import '@testing-library/jest-dom'; 2 | import { beforeAll, afterAll, vi } from 'vitest'; 3 | 4 | // Simple mocks for browser APIs without complex type declarations 5 | const mockIntersectionObserver = vi.fn(() => ({ 6 | disconnect: vi.fn(), 7 | observe: vi.fn(), 8 | unobserve: vi.fn(), 9 | })); 10 | 11 | const mockResizeObserver = vi.fn(() => ({ 12 | disconnect: vi.fn(), 13 | observe: vi.fn(), 14 | unobserve: vi.fn(), 15 | })); 16 | 17 | // Set up global mocks 18 | Object.defineProperty(window, 'IntersectionObserver', { 19 | writable: true, 20 | configurable: true, 21 | value: mockIntersectionObserver, 22 | }); 23 | 24 | Object.defineProperty(window, 'ResizeObserver', { 25 | writable: true, 26 | configurable: true, 27 | value: mockResizeObserver, 28 | }); 29 | 30 | Object.defineProperty(window, 'matchMedia', { 31 | writable: true, 32 | value: vi.fn().mockImplementation(query => ({ 33 | matches: false, 34 | media: query, 35 | onchange: null, 36 | addListener: vi.fn(), 37 | removeListener: vi.fn(), 38 | addEventListener: vi.fn(), 39 | removeEventListener: vi.fn(), 40 | dispatchEvent: vi.fn(), 41 | })), 42 | }); 43 | 44 | // Mock console methods for cleaner test output 45 | const originalConsoleError = console.error; 46 | const originalConsoleWarn = console.warn; 47 | 48 | beforeAll(() => { 49 | // Suppress React development warnings in tests 50 | console.error = (...args: any[]) => { 51 | if ( 52 | typeof args[0] === 'string' && 53 | (args[0].includes('Warning: ReactDOM.render is no longer supported') || 54 | args[0].includes('Warning: An invalid form control')) 55 | ) { 56 | return; 57 | } 58 | originalConsoleError(...args); 59 | }; 60 | 61 | console.warn = (...args: any[]) => { 62 | if ( 63 | typeof args[0] === 'string' && 64 | args[0].includes('React.createFactory() is deprecated') 65 | ) { 66 | return; 67 | } 68 | originalConsoleWarn(...args); 69 | }; 70 | }); 71 | 72 | afterAll(() => { 73 | console.error = originalConsoleError; 74 | console.warn = originalConsoleWarn; 75 | }); 76 | 77 | // Mock D3 for visualization tests 78 | vi.mock('d3', () => ({ 79 | select: vi.fn(() => ({ 80 | selectAll: vi.fn(() => ({ 81 | data: vi.fn(() => ({ 82 | enter: vi.fn(() => ({ 83 | append: vi.fn(() => ({ 84 | attr: vi.fn().mockReturnThis(), 85 | style: vi.fn().mockReturnThis(), 86 | text: vi.fn().mockReturnThis() 87 | })) 88 | })), 89 | exit: vi.fn(() => ({ 90 | remove: vi.fn() 91 | })), 92 | attr: vi.fn().mockReturnThis(), 93 | style: vi.fn().mockReturnThis(), 94 | text: vi.fn().mockReturnThis() 95 | })) 96 | })), 97 | append: vi.fn(() => ({ 98 | attr: vi.fn().mockReturnThis(), 99 | style: vi.fn().mockReturnThis() 100 | })), 101 | attr: vi.fn().mockReturnThis(), 102 | style: vi.fn().mockReturnThis(), 103 | on: vi.fn().mockReturnThis() 104 | })), 105 | scaleLinear: vi.fn(() => ({ 106 | domain: vi.fn(() => ({ 107 | range: vi.fn().mockReturnThis() 108 | })), 109 | range: vi.fn().mockReturnThis() 110 | })), 111 | extent: vi.fn(() => [0, 1]) 112 | })); 113 | 114 | // Educational simulation specific setup and utilities 115 | export const createMockTransformerParams = () => ({ 116 | seqLen: 5, 117 | dModel: 64, 118 | numHeads: 4, 119 | numLayers: 2, 120 | vocabSize: 1000, 121 | maxSeqLen: 10, 122 | positionEncoding: 'sinusoidal' as const, 123 | dropout: 0.1, 124 | showValues: false 125 | }); 126 | 127 | export 
const createMockComponentData = () => ({ 128 | description: 'Test component for educational simulation', 129 | category: 'attention' as const, 130 | inputs: [], 131 | parameters: [], 132 | outputs: [] 133 | }); 134 | 135 | // Utility for testing async operations in educational simulations 136 | export const waitForSimulation = async (ms: number = 100) => { 137 | await new Promise(resolve => setTimeout(resolve, ms)); 138 | }; 139 | 140 | // Mock XState machine for state management tests 141 | export const createMockMachineContext = () => ({ 142 | currentStep: 0, 143 | totalSteps: 10, 144 | activeToken: 0, 145 | focusToken: false, 146 | zoomLevel: 'global' as const, 147 | params: createMockTransformerParams(), 148 | isPlaying: false, 149 | playbackSpeed: 1.0, 150 | settingsOpen: false, 151 | darkMode: false, 152 | colorBlindMode: false, 153 | attentionLensActive: false, 154 | attentionTopK: 5, 155 | inputSequence: ['The', 'cat', 'sat'], 156 | outputSequence: ['Le', 'chat', 'était'], 157 | lastFrameTime: 0, 158 | currentFPS: 60 159 | }); -------------------------------------------------------------------------------- /src/components/TokenVisualization.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * TokenVisualization component for displaying tokenized sequences 3 | * with visual styling for different token types and interactive tooltips. 4 | */ 5 | 6 | interface TokenVisualizationProps { 7 | tokens: string[]; 8 | label: string; 9 | className?: string; 10 | } 11 | 12 | export function TokenVisualization({ tokens, label, className = '' }: TokenVisualizationProps) { 13 | const getTokenStyle = (token: string): string => { 14 | // Start token styling 15 | if (token === '' || token === '') { 16 | return 'bg-green-100 text-green-800 border-green-300'; 17 | } 18 | // End token styling 19 | if (token === '') { 20 | return 'bg-red-100 text-red-800 border-red-300'; 21 | } 22 | // Pad token styling 23 | if (token === '') { 24 | return 'bg-gray-100 text-gray-800 border-gray-300'; 25 | } 26 | 27 | // Regular token styling (everything else) 28 | return 'bg-blue-50 text-blue-800 border-blue-200'; 29 | }; 30 | 31 | const getTokenTooltip = (token: string, index: number): string => { 32 | if (token === '' || token === '') { 33 | return 'Start Token - Marks the beginning of input'; 34 | } 35 | if (token === '') { 36 | return 'End Token - Marks the end of input'; 37 | } 38 | if (token === '') { 39 | return 'Pad Token - Used to fill sequences to fixed length'; 40 | } 41 | 42 | return `Regular Token ${index}: "${token}" - Vocabulary token`; 43 | }; 44 | 45 | return ( 46 |
47 |
48 |
{label}
49 |

50 | Tokenized sequence with special tokens 51 |

52 |
53 | 54 |
55 | {/* Token sequence visualization */} 56 |
57 | {tokens.map((token, index) => ( 58 |
67 |
68 | 69 | {index} 70 | 71 | 72 | {token} 73 | 74 |
75 |
76 | ))} 77 |
78 | 79 | {/* Token count and statistics */} 80 |
81 |
82 |
83 | Total Tokens: 84 | {tokens.length} 85 |
86 |
87 | Special Tokens: 88 | 89 | {tokens.filter(token => token.startsWith('<') && token.endsWith('>')).length} 90 | 91 |
92 |
93 |
94 | 95 | {/* Legend */} 96 |
97 |
Legend:
98 |
99 |
100 |
101 | Start token 102 |
103 |
104 |
105 | Regular token 106 |
107 |
108 |
109 | End token 110 |
111 |
112 |
113 | Pad token 114 |
115 |
116 |
117 |
118 |
119 | ); 120 | } -------------------------------------------------------------------------------- /src/components/TransformerDiagram.tsx: -------------------------------------------------------------------------------- 1 | import { useRef } from 'react'; 2 | import { TransformerParams } from '../types/events'; 3 | import { useTransformerDiagram } from '../hooks/useTransformerDiagram'; 4 | 5 | interface TransformerDiagramProps { 6 | params: TransformerParams; 7 | selectedComponent: string | null; 8 | onComponentClick: (componentId: string) => void; 9 | className?: string; 10 | } 11 | 12 | export function TransformerDiagram({ 13 | params, 14 | selectedComponent, 15 | onComponentClick, 16 | className = '', 17 | }: TransformerDiagramProps) { 18 | const containerRef = useRef(null); 19 | const svgRef = useTransformerDiagram(params, selectedComponent, onComponentClick, containerRef); 20 | 21 | return ( 22 |
26 |
27 |

Transformer Architecture

28 | 29 | {/* Organized Legend */} 30 |
31 | {/* Architectural Sections */} 32 |
33 |

Architecture Sections

34 |
35 |
36 |
37 | Encoder 38 |
39 |
40 |
41 | Decoder 42 |
43 |
44 |
45 | 46 | {/* Processing Blocks */} 47 |
48 |

Processing Components

49 |
50 |
51 |
52 | Tokens 53 |
54 |
55 |
56 | Embedding 57 |
58 |
59 |
60 | Positional 61 |
62 |
63 |
64 | Attention 65 |
66 |
67 |
68 | Add & Norm 69 |
70 |
71 |
72 | Feed Forward 73 |
74 |
75 |
76 | Output 77 |
78 |
79 |
80 | 81 | {/* Connection Types */} 82 |
83 |

Data Flow

84 |
85 |
86 |
87 | Main Flow 88 |
89 |
90 |
91 | Residual 92 |
93 |
94 |
95 | Cross-Attention 96 |
97 |
98 |
99 |
100 |
101 | 102 |
103 | 104 |
105 |
106 | ); 107 | } -------------------------------------------------------------------------------- /src/hooks/useTransformerMachine.ts: -------------------------------------------------------------------------------- 1 | import { useMachine } from '@xstate/react'; 2 | import { useEffect, useRef } from 'react'; 3 | import transformerMachine, { TransformerContext } from '../state/transformerMachine'; 4 | import { useEventLogger } from './useEventLogger'; 5 | 6 | export function useTransformerMachine() { 7 | const [state, send] = useMachine(transformerMachine); 8 | const { logEvent } = useEventLogger(); 9 | const animationFrameRef = useRef(); 10 | 11 | // Animation loop for auto-play 12 | useEffect(() => { 13 | if (state.context.isPlaying) { 14 | const animate = () => { 15 | send({ type: 'TICK' }); 16 | animationFrameRef.current = requestAnimationFrame(animate); 17 | }; 18 | animationFrameRef.current = requestAnimationFrame(animate); 19 | } else { 20 | if (animationFrameRef.current) { 21 | cancelAnimationFrame(animationFrameRef.current); 22 | } 23 | } 24 | 25 | return () => { 26 | if (animationFrameRef.current) { 27 | cancelAnimationFrame(animationFrameRef.current); 28 | } 29 | }; 30 | }, [state.context.isPlaying, send]); 31 | 32 | // Event logging will be handled by individual action calls 33 | 34 | // Convenience actions with logging 35 | const actions = { 36 | // Playback controls 37 | play: () => { 38 | send({ type: 'PLAY' }); 39 | logEvent('animation_control', { action: 'play', current_step: state.context.currentStep }); 40 | }, 41 | pause: () => { 42 | send({ type: 'PAUSE' }); 43 | logEvent('animation_control', { action: 'pause', current_step: state.context.currentStep }); 44 | }, 45 | nextStep: () => { 46 | send({ type: 'NEXT_STEP' }); 47 | logEvent('step', { direction: 'next', step: state.context.currentStep + 1 }); 48 | }, 49 | prevStep: () => { 50 | send({ type: 'PREV_STEP' }); 51 | logEvent('step', { direction: 'prev', step: Math.max(0, state.context.currentStep - 1) }); 52 | }, 53 | setStep: (step: number) => { 54 | send({ type: 'SET_STEP', step }); 55 | logEvent('step', { direction: 'jump', step }); 56 | }, 57 | reset: () => { 58 | send({ type: 'RESET' }); 59 | logEvent('step', { direction: 'reset', step: 0 }); 60 | }, 61 | setPlaybackSpeed: (speed: number) => { 62 | send({ type: 'SET_PLAYBACK_SPEED', speed }); 63 | logEvent('animation_control', { action: 'speed_change', speed, current_step: state.context.currentStep }); 64 | }, 65 | 66 | // Token controls 67 | setActiveToken: (tokenIndex: number) => { 68 | send({ type: 'SET_ACTIVE_TOKEN', tokenIndex }); 69 | logEvent('token_change', { tokenIndex, focusToken: state.context.focusToken }); 70 | }, 71 | toggleFocusToken: () => { 72 | send({ type: 'TOGGLE_FOCUS_TOKEN' }); 73 | logEvent('toggle', { feature: 'focus_token', enabled: !state.context.focusToken }); 74 | }, 75 | 76 | // Zoom controls 77 | zoomIn: (component?: string, layer?: number) => { 78 | send({ type: 'ZOOM_IN', component, layer }); 79 | logEvent('zoom_change', { zoomLevel: 'zoom_in', component, layer }); 80 | }, 81 | zoomOut: () => { 82 | send({ type: 'ZOOM_OUT' }); 83 | logEvent('zoom_change', { zoomLevel: 'zoom_out' }); 84 | }, 85 | setZoomLevel: (level: string, component?: string) => { 86 | send({ 87 | type: 'SET_ZOOM_LEVEL', 88 | level: level as TransformerContext['zoomLevel'], 89 | component 90 | }); 91 | logEvent('zoom_change', { zoomLevel: level, component }); 92 | }, 93 | 94 | // Parameter controls 95 | updateParams: (params: Partial) => { 96 | send({ type: 
'UPDATE_PARAMS', params }); 97 | logEvent('param_change', { changes: params, newParams: { ...state.context.params, ...params } }); 98 | }, 99 | 100 | 101 | 102 | // UI controls 103 | toggleSettings: () => { 104 | send({ type: 'TOGGLE_SETTINGS' }); 105 | logEvent('toggle', { feature: 'settings', enabled: !state.context.settingsOpen }); 106 | }, 107 | toggleDarkMode: () => { 108 | send({ type: 'TOGGLE_DARK_MODE' }); 109 | logEvent('toggle', { feature: 'dark_mode', enabled: !state.context.darkMode }); 110 | }, 111 | toggleColorBlindMode: () => { 112 | send({ type: 'TOGGLE_COLOR_BLIND_MODE' }); 113 | logEvent('toggle', { feature: 'color_blind_mode', enabled: !state.context.colorBlindMode }); 114 | }, 115 | 116 | // Attention controls 117 | toggleAttentionLens: () => { 118 | send({ type: 'TOGGLE_ATTENTION_LENS' }); 119 | logEvent('attention_lens', { active: !state.context.attentionLensActive, top_k: state.context.attentionTopK }); 120 | }, 121 | setAttentionTopK: (k: number) => { 122 | send({ type: 'SET_ATTENTION_TOP_K', k }); 123 | logEvent('attention_lens', { active: state.context.attentionLensActive, top_k: k }); 124 | }, 125 | 126 | // Error handling 127 | setError: (message: string) => send({ type: 'ERROR', message }), 128 | clearError: () => send({ type: 'CLEAR_ERROR' }), 129 | }; 130 | 131 | return { 132 | state: state.context, 133 | isIdle: state.matches('idle'), 134 | isPlaying: state.matches('playing'), 135 | hasError: state.matches('error'), 136 | actions, 137 | }; 138 | } -------------------------------------------------------------------------------- /src/components/HelpMenu.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useRef, useEffect } from 'react'; 2 | import { QuestionMarkCircleIcon, XMarkIcon, BookOpenIcon, CursorArrowRaysIcon } from '@heroicons/react/24/outline'; 3 | 4 | export function HelpMenu() { 5 | const [isOpen, setIsOpen] = useState(false); 6 | const menuRef = useRef(null); 7 | 8 | // Close menu when clicking outside 9 | useEffect(() => { 10 | function handleClickOutside(event: MouseEvent) { 11 | if (menuRef.current && !menuRef.current.contains(event.target as Node)) { 12 | setIsOpen(false); 13 | } 14 | } 15 | 16 | if (isOpen) { 17 | document.addEventListener('mousedown', handleClickOutside); 18 | return () => document.removeEventListener('mousedown', handleClickOutside); 19 | } 20 | }, [isOpen]); 21 | 22 | return ( 23 |
24 | {/* Help Button */} 25 | 33 | 34 | {/* Help Panel */} 35 | {isOpen && ( 36 |
37 |
38 | {/* Header */} 39 |
40 |
41 |
42 | 43 |
44 |
45 |

Learning Guide

46 |

Explore transformer architecture interactively

47 |
48 |
49 | 55 |
56 | 57 | 58 | 59 | {/* Interactive Features */} 60 |
61 |
62 | 63 |

Interactive Features

64 |
65 |
66 |
67 |
🔍 Component Explorer
68 |

69 | Click any component in the diagram to see its detailed structure, 70 | inputs, parameters, and mathematical operations. 71 |

72 |
73 |
74 |
⚙️ Parameter Controls
75 |

76 | Adjust model settings like number of layers, attention heads, 77 | and sequence length to see how they change the architecture. 78 |

79 |
80 |
81 |
📊 Data Visualization
82 |

83 | Switch between views to see abstract representations or 84 | actual numerical matrices and attention patterns. 85 |

86 |
87 |
88 |
89 | 90 | 91 | 92 | {/* Understanding the Diagram */} 93 |
94 |

Understanding the Diagram

95 |
96 |
97 |
98 | Encoder components (process input) 99 |
100 |
101 |
102 | Decoder components (generate output) 103 |
104 |
105 |
106 | Cross-attention (encoder → decoder) 107 |
108 |
109 |
110 | Residual connections (skip paths) 111 |
112 |
113 |
114 | 115 |
116 |
117 | )} 118 |
119 | ); 120 | } -------------------------------------------------------------------------------- /src/utils/constants.ts: -------------------------------------------------------------------------------- 1 | // Core constants for the transformer simulation 2 | // Consolidated from multiple files for better import clarity 3 | 4 | import { TransformerParams } from '../types/events'; 5 | 6 | // Theme and Colors 7 | export const COLORS = { 8 | // Primary Beyond Blue palette 9 | primary: '#1062fb', 10 | gold: '#ffc500', 11 | 12 | // Supporting colors 13 | white: '#ffffff', 14 | black: '#000000', 15 | gray: { 16 | 50: '#f9fafb', 17 | 100: '#f3f4f6', 18 | 200: '#e5e7eb', 19 | 300: '#d1d5db', 20 | 400: '#9ca3af', 21 | 500: '#6b7280', 22 | 600: '#4b5563', 23 | 700: '#374151', 24 | 800: '#1f2937', 25 | 900: '#111827', 26 | }, 27 | 28 | // Semantic colors 29 | text: '#111827', 30 | textSecondary: '#6b7280', 31 | background: '#ffffff', 32 | border: '#e5e7eb', 33 | 34 | // Status colors 35 | success: '#10b981', 36 | warning: '#f59e0b', 37 | error: '#ef4444', 38 | 39 | // Component colors - CodeSignal brand colors 40 | encoder: '#1062fb', // Beyond Blue (Primary) 41 | decoder: '#002570', // Dive Deep Blue (Secondary) 42 | cross_attention: '#64D3FF', // SignalLight Blue (Secondary) 43 | residual: '#9ca3af', // Gray 400 44 | layer_norm: '#21CF82', // Victory Green (Tertiary) 45 | ffn: '#FF7232', // FiresIDE Orange (Tertiary) 46 | 47 | // Processing block colors 48 | input_tokens: '#6b7280', // Gray 500 49 | embedding: '#1062FB', // Beyond Blue (Primary) 50 | positional: '#FFC500', // Gold Standard Yellow (Secondary) 51 | attention: '#E6193F', // Ruby Red (Tertiary) 52 | output: '#002570', // Dive Deep Blue (Secondary) 53 | 54 | // Matrix cell colors (Viridis palette - color-blind safe) 55 | viridis: [ 56 | '#440154', '#482777', '#3f4a8a', '#31678e', 57 | '#26838f', '#1f9d8a', '#6cce5a', '#b6de2b', 58 | '#fee825' 59 | ], 60 | }; 61 | 62 | // Zoom Levels 63 | export const ZOOM_LEVELS = { 64 | GLOBAL: 'global' as const, 65 | LAYER: 'layer' as const, 66 | SUB_LAYER: 'sub_layer' as const, 67 | TENSOR_CELL: 'tensor_cell' as const, 68 | }; 69 | 70 | export type ZoomLevel = typeof ZOOM_LEVELS[keyof typeof ZOOM_LEVELS]; 71 | 72 | // Default Parameters 73 | export const DEFAULT_PARAMS: TransformerParams = { 74 | numLayers: 1, 75 | dModel: 64, 76 | numHeads: 4, 77 | seqLen: 8, 78 | posEncoding: 'sinusoidal', 79 | dropout: false, 80 | }; 81 | 82 | export const PARAM_LIMITS = { 83 | numLayers: { min: 1, max: 6 }, 84 | dModel: { min: 32, max: 512, step: 32 }, 85 | numHeads: { min: 1, max: 16 }, 86 | seqLen: { min: 3, max: 32 }, 87 | }; 88 | 89 | // Transformer Steps 90 | export const TRANSFORMER_STEPS = [ 91 | // Input processing 92 | { 93 | id: 'input_tokens', 94 | label: 'Input Tokens', 95 | description: 'Source sequence tokens as discrete symbols', 96 | category: 'input', 97 | component: 'input', 98 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 99 | }, 100 | { 101 | id: 'input_embedding', 102 | label: 'Input Embedding', 103 | description: 'Convert source tokens to dense vectors', 104 | category: 'embedding', 105 | component: 'input', 106 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 107 | }, 108 | { 109 | id: 'input_positional', 110 | label: 'Input Positional Encoding', 111 | description: 'Add position information to input embeddings', 112 | category: 'embedding', 113 | component: 'input', 114 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 115 | }, 116 | 117 | // Encoder stack 118 | { 119 | id: 'encoder_self_attention', 120 | label: 'Encoder 
Self-Attention', 121 | description: 'Encoder attends to input sequence', 122 | category: 'encoder', 123 | component: 'encoder', 124 | zoomLevel: ZOOM_LEVELS.SUB_LAYER 125 | }, 126 | { 127 | id: 'encoder_ffn', 128 | label: 'Encoder Feed-Forward', 129 | description: 'Process encoder representations through FFN', 130 | category: 'encoder', 131 | component: 'encoder', 132 | zoomLevel: ZOOM_LEVELS.SUB_LAYER 133 | }, 134 | 135 | // Decoder input 136 | { 137 | id: 'output_tokens', 138 | label: 'Output Tokens', 139 | description: 'Target sequence tokens (shifted right)', 140 | category: 'input', 141 | component: 'decoder', 142 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 143 | }, 144 | { 145 | id: 'output_embedding', 146 | label: 'Output Embedding', 147 | description: 'Convert target tokens to dense vectors', 148 | category: 'embedding', 149 | component: 'decoder', 150 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 151 | }, 152 | { 153 | id: 'output_positional', 154 | label: 'Output Positional Encoding', 155 | description: 'Add position information to output embeddings', 156 | category: 'embedding', 157 | component: 'decoder', 158 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 159 | }, 160 | 161 | // Decoder stack 162 | { 163 | id: 'decoder_self_attention', 164 | label: 'Decoder Self-Attention (Masked)', 165 | description: 'Decoder attends to previous output tokens', 166 | category: 'decoder', 167 | component: 'decoder', 168 | zoomLevel: ZOOM_LEVELS.SUB_LAYER 169 | }, 170 | { 171 | id: 'cross_attention', 172 | label: 'Cross-Attention', 173 | description: 'Decoder queries attend to encoder keys/values', 174 | category: 'cross_attention', 175 | component: 'cross_attention', 176 | zoomLevel: ZOOM_LEVELS.SUB_LAYER 177 | }, 178 | { 179 | id: 'decoder_ffn', 180 | label: 'Decoder Feed-Forward', 181 | description: 'Process decoder representations through FFN', 182 | category: 'decoder', 183 | component: 'decoder', 184 | zoomLevel: ZOOM_LEVELS.SUB_LAYER 185 | }, 186 | 187 | // Output processing 188 | { 189 | id: 'linear_layer', 190 | label: 'Linear Layer', 191 | description: 'Project decoder output to vocabulary space', 192 | category: 'output', 193 | component: 'output', 194 | zoomLevel: ZOOM_LEVELS.LAYER 195 | }, 196 | { 197 | id: 'softmax', 198 | label: 'Softmax', 199 | description: 'Convert logits to probability distribution', 200 | category: 'output', 201 | component: 'output', 202 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL 203 | }, 204 | ]; -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Travel Through Transformers 2 | 3 | An interactive web-based simulation that lets learners follow a single token step-by-step through every component of a Transformer encoder/decoder stack. 4 | 5 | ## Features 6 | 7 | - **Component-focused visualization**: Click through different transformer components to see detailed internals 8 | - **Interactive parameters**: Adjust layers, model dimensions, attention heads, and sequence length in real-time 9 | - **Dual visualization modes**: Abstract shape view for understanding flow, or detailed numerical values 10 | - **Multi-head attention visualization**: See how different attention heads process information 11 | - **Event logging**: All interactions are logged for analytics 12 | 13 | ## Quick Start 14 | 15 | ### Prerequisites 16 | 17 | - Node.js 18+ and npm 18 | - Python 3.7+ 19 | 20 | ### Installation 21 | 22 | 1.
**Install dependencies**: 23 | ```bash 24 | npm install 25 | ``` 26 | 27 | 2. **Build the application**: 28 | ```bash 29 | npm run build 30 | ``` 31 | 32 | 3. **Start the server**: 33 | ```bash 34 | python server/server.py 35 | ``` 36 | 37 | 4. **Open your browser**: 38 | Navigate to `http://localhost:3000` 39 | 40 | ### Development Mode 41 | 42 | For development with hot reloading: 43 | 44 | ```bash 45 | # Terminal 1 - Start the development server 46 | npm run dev 47 | 48 | # Terminal 2 - Start the logging server on port 3001 (the Vite dev proxy forwards /log there) 49 | python server/server.py 3001 50 | ``` 51 | 52 | Then open `http://localhost:5173` (development) or `http://localhost:3000` (production). 53 | 54 | ## Usage 55 | 56 | ### Controls 57 | 58 | 1. **Component Selection**: Click on transformer components in the diagram to explore their internals 59 | 2. **Show Values Toggle**: Switch between abstract block view and actual numerical matrices 60 | 3. **Model Parameters**: 61 | - Adjust number of layers (1-6) 62 | - Change model dimension (32-512) 63 | - Set attention heads (1-16, must divide the model dimension) 64 | - Modify sequence length (3-32) 65 | - Choose positional encoding type 66 | - Enable dropout visualization 67 | 68 | ### Understanding the Visualization 69 | 70 | #### Abstract Mode (Default) 71 | - Colored blocks represent matrices with dimensions shown 72 | - Different colors indicate different types of operations 73 | - Active components highlighted with distinctive colors 74 | - Attention heads shown in different colors 75 | 76 | #### Values Mode 77 | - Heat maps show actual numerical values 78 | - Color intensity represents magnitude 79 | - Hover for precise values 80 | - Active components highlighted 81 | 82 | ## Architecture 83 | 84 | ``` 85 | travel-through-transformers/ 86 | ├── src/ 87 | │ ├── components/ # React UI components 88 | │ │ ├── MatrixVisualization.tsx # D3-powered matrix visualization 89 | │ │ ├── TokenVisualization.tsx # Token display and interaction 90 | │ │ ├── SettingsMenu.tsx # Parameter controls 91 | │ │ ├── ComponentDetailsPanel.tsx # Component detail exploration 92 | │ │ ├── HelpMenu.tsx # Help system 93 | │ │ └── TransformerDiagram.tsx # Main architecture diagram 94 | │ ├── hooks/ # Custom React hooks 95 | │ │ ├── useTransformerMachine.ts # Main state management (XState) 96 | │ │ ├── useTransformerDiagram.ts # Diagram interaction logic 97 | │ │ └── useEventLogger.ts # Analytics logging 98 | │ ├── utils/ # Utility functions 99 | │ │ ├── math.ts # Matrix operations 100 | │ │ ├── randomWeights.ts # Seeded random generation 101 | │ │ ├── constants.ts # Configuration and steps 102 | │ │ ├── data.ts # Sample data generation 103 | │ │ ├── componentDataGenerator.ts # Component data creation 104 | │ │ └── componentTransformations.ts # Math transformations 105 | │ ├── state/ # State management 106 | │ │ └── transformerMachine.ts # XState machine definition 107 | │ └── types/ # TypeScript definitions 108 | │ └── events.d.ts # Event and parameter types 109 | ├── server/ 110 | │ └── server.py # Python logging server 111 | └── logs/ # Event logs (generated) 112 | ``` 113 | 114 | ## Educational Goals 115 | 116 | This simulation helps learners understand: 117 | 118 | 1. **Component Architecture**: How transformer components are organized and connected 119 | 2. **Attention Mechanism**: How queries, keys, and values interact (see the sketch after this list) 120 | 3. **Multi-Head Attention**: How different heads capture different patterns 121 | 4. **Residual Connections**: How information flows around attention blocks 122 | 5. **Layer Normalization**: How activations are normalized 123 | 6. **Feed-Forward Networks**: How information is processed after attention 124 | 7. **Positional Encoding**: How position information is added to tokens 125 | 8. **Cross-Attention**: How decoder attends to encoder representations 126 |
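To ground goal 2, here is a minimal, illustrative sketch of single-head scaled dot-product attention built from the repository's own `matmul` and `softmax` helpers in `src/utils/math.ts`. The `transpose` helper and the import path are assumptions made for this example; they are not part of the repo.

```typescript
import { matmul, softmax } from './src/utils/math';

// Illustrative helper (not in the repo): swap rows and columns of a matrix.
function transpose(m: number[][]): number[][] {
  return m[0].map((_, j) => m.map(row => row[j]));
}

// Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k)·V
function scaledDotProductAttention(Q: number[][], K: number[][], V: number[][]): number[][] {
  const dK = K[0].length;                          // key dimension d_k
  const scores = matmul(Q, transpose(K))           // (seqLen x seqLen) similarity scores
    .map(row => row.map(x => x / Math.sqrt(dK)));  // scale to keep the softmax well-behaved
  const weights = softmax(scores);                 // row-wise normalization: each row sums to 1
  return matmul(weights, V);                       // weighted sum of value vectors
}
```

Multi-head attention (goal 3) repeats this computation once per head on separately projected Q, K, and V matrices and concatenates the head outputs.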
127 | ## Technical Details 128 | 129 | - **Frontend**: React + TypeScript + Vite 130 | - **State Management**: XState for complex state transitions 131 | - **Visualization**: D3.js for interactive SVG graphics 132 | - **Styling**: TailwindCSS with CodeSignal brand colors 133 | - **Math**: Custom lightweight tensor operations (no external ML libraries) 134 | - **Backend**: Simple Python HTTP server for logging 135 | - **Data**: Seeded random weights for reproducible results 136 | 137 | ## Customization 138 | 139 | ### Adding New Components 140 | 141 | 1. Add component definition to transformer machine states 142 | 2. Implement component logic in `componentDataGenerator.ts` 143 | 3. Add appropriate visualizations in component files 144 | 4. Update component transformations in `componentTransformations.ts` 145 | 146 | ### Modifying Visualization 147 | 148 | - Matrix colors: Edit `COLORS` in `constants.ts` 149 | - D3 rendering: Modify `MatrixVisualization.tsx` 150 | - Component descriptions: Update transformer machine configuration 151 | 152 | ### Analytics 153 | 154 | Event logs are stored in `logs/simulation_log.jsonl`, one JSON object per line; timestamps come from `Date.now()` (Unix epoch milliseconds). Schema: 155 | ```json 156 | { 157 | "timestamp": 1625239200000, 158 | "event_type": "step" | "param_change" | "toggle" | "token_change" | "zoom_change" | "component_select" | "animation_control" | "attention_lens" | "sequence_generation", 159 | "payload": { /* event-specific data */ } 160 | } 161 | ``` 162 | 163 | ## Browser Support 164 | 165 | - Chrome 90+ 166 | - Firefox 88+ 167 | - Safari 14+ 168 | - Edge 90+ 169 | 170 | ## Performance 171 | 172 | Optimized for: 173 | - 60 FPS animations 174 | - Sequence length ≤ 10 175 | - Attention heads ≤ 8 176 | - Model dimension ≤ 512 177 | 178 | ## License 179 | 180 | Elastic License 2.0 - see [LICENSE](LICENSE) for details. 181 | 182 | ## Contributing 183 | 184 | 1. Fork the repository 185 | 2. Create a feature branch 186 | 3. Make your changes 187 | 4. Add tests if applicable 188 | 5.
Submit a pull request 189 | 190 | ## Troubleshooting 191 | 192 | ### Common Issues 193 | 194 | **"Cannot find module" errors**: Run `npm install` 195 | 196 | **Server won't start**: Check that port 3000 is available, or specify a different port: `python server/server.py 3001` 197 | 198 | **Visualization not updating**: Try refreshing the page or clearing browser cache 199 | 200 | **Performance issues**: Reduce model parameters (fewer layers, smaller dimensions) 201 | 202 | ### Debug Mode 203 | 204 | Enable debug logging: 205 | ```bash 206 | DEBUG=1 python server/server.py 207 | ``` 208 | 209 | ## Educational Extensions 210 | 211 | Future enhancements could include: 212 | 213 | - Real model weights from Hugging Face 214 | - Attention pattern analysis 215 | - Interactive quizzes between steps 216 | - Comparison with other architectures 217 | - Custom text input 218 | - Export/import configurations -------------------------------------------------------------------------------- /src/components/MatrixVisualization.tsx: -------------------------------------------------------------------------------- 1 | import { useRef, useEffect, useState, useMemo } from 'react'; 2 | import * as d3 from 'd3'; 3 | import { MagnifyingGlassIcon } from '@heroicons/react/24/outline'; 4 | import { COLORS } from '../utils/constants'; 5 | 6 | interface MatrixVisualizationProps { 7 | data: number[][] | number[]; 8 | shape?: [number, number] | [number]; 9 | label: string; 10 | className?: string; 11 | } 12 | 13 | export function MatrixVisualization({ 14 | data, 15 | label, 16 | className = '' 17 | }: MatrixVisualizationProps) { 18 | const svgRef = useRef(null); 19 | const detailSvgRef = useRef(null); 20 | const containerRef = useRef(null); 21 | const [showDetailModal, setShowDetailModal] = useState(false); 22 | 23 | // Convert data to 2D array if it's 1D - memoized to avoid recreation on every render 24 | const matrix2D = useMemo(() => { 25 | return Array.isArray(data[0]) ? data as number[][] : [data as number[]]; 26 | }, [data]); 27 | 28 | const rows = matrix2D.length; 29 | const cols = matrix2D[0].length; 30 | 31 | // Simplified overview visualization 32 | useEffect(() => { 33 | if (!svgRef.current || !containerRef.current || !data) return; 34 | 35 | const svg = d3.select(svgRef.current); 36 | svg.selectAll('*').remove(); 37 | 38 | const width = 200; 39 | const height = 150; 40 | svg.attr('width', width).attr('height', height); 41 | 42 | const margin = { top: 30, right: 40, bottom: 30, left: 40 }; 43 | const innerWidth = width - margin.left - margin.right; 44 | const innerHeight = height - margin.top - margin.bottom; 45 | 46 | const g = svg.append('g').attr('transform', `translate(${margin.left},${margin.top})`); 47 | 48 | // Calculate matrix rectangle dimensions 49 | const aspectRatio = cols / rows; 50 | let rectWidth, rectHeight; 51 | 52 | if (aspectRatio > 1) { 53 | rectWidth = innerWidth; 54 | rectHeight = innerWidth / aspectRatio; 55 | } else { 56 | rectHeight = innerHeight; 57 | rectWidth = innerHeight * aspectRatio; 58 | } 59 | 60 | const rectX = (innerWidth - rectWidth) / 2; 61 | const rectY = (innerHeight - rectHeight) / 2; 62 | 63 | // Color based on data range for visual feedback 64 | const flatData = matrix2D.flat(); 65 | const dataRange = d3.extent(flatData) as [number, number]; 66 | const avgValue = d3.mean(flatData) || 0; 67 | const normalizedValue = dataRange[1] !== dataRange[0] ? 
68 | (avgValue - dataRange[0]) / (dataRange[1] - dataRange[0]) : 0.5; 69 | 70 | const matrixColor = d3.interpolateViridis(normalizedValue); 71 | 72 | // Draw matrix rectangle 73 | g.append('rect') 74 | .attr('width', rectWidth) 75 | .attr('height', rectHeight) 76 | .attr('x', rectX) 77 | .attr('y', rectY) 78 | .attr('fill', matrixColor) 79 | .attr('fill-opacity', 0.7) 80 | .attr('stroke', COLORS.text) 81 | .attr('stroke-width', 1) 82 | .attr('rx', 4); 83 | 84 | // Row dimension label (left side) 85 | g.append('text') 86 | .attr('x', rectX - 8) 87 | .attr('y', rectY + rectHeight / 2) 88 | .attr('text-anchor', 'end') 89 | .attr('dominant-baseline', 'middle') 90 | .attr('font-size', 14) 91 | .attr('font-weight', 'bold') 92 | .attr('fill', COLORS.text) 93 | .text(rows); 94 | 95 | // Column dimension label (top) 96 | g.append('text') 97 | .attr('x', rectX + rectWidth / 2) 98 | .attr('y', rectY - 8) 99 | .attr('text-anchor', 'middle') 100 | .attr('dominant-baseline', 'middle') 101 | .attr('font-size', 14) 102 | .attr('font-weight', 'bold') 103 | .attr('fill', COLORS.text) 104 | .text(cols); 105 | 106 | }, [data, rows, cols, matrix2D]); // Added matrix2D to dependencies 107 | 108 | // Detailed modal visualization 109 | useEffect(() => { 110 | if (!detailSvgRef.current || !showDetailModal || !data) return; 111 | 112 | const svg = d3.select(detailSvgRef.current); 113 | svg.selectAll('*').remove(); 114 | 115 | const maxWidth = 800; 116 | const maxHeight = 600; 117 | 118 | // Calculate dimensions for detailed view 119 | const cellSize = Math.min(maxWidth / cols, maxHeight / rows, 30); 120 | const width = cols * cellSize; 121 | const height = rows * cellSize; 122 | 123 | svg.attr('width', width + 80).attr('height', height + 80); 124 | 125 | const margin = { top: 40, right: 40, bottom: 40, left: 60 }; 126 | const g = svg.append('g').attr('transform', `translate(${margin.left},${margin.top})`); 127 | 128 | // Create color scale 129 | const flatData = matrix2D.flat(); 130 | const colorScale = d3.scaleSequential(d3.interpolateViridis) 131 | .domain(d3.extent(flatData) as [number, number]); 132 | 133 | // Create heatmap 134 | const cells = g.selectAll('.cell') 135 | .data(matrix2D.flatMap((row, i) => 136 | row.map((value, j) => ({ value, row: i, col: j })) 137 | )) 138 | .enter() 139 | .append('g') 140 | .attr('class', 'cell') 141 | .attr('transform', d => `translate(${d.col * cellSize},${d.row * cellSize})`); 142 | 143 | cells.append('rect') 144 | .attr('width', cellSize - 1) 145 | .attr('height', cellSize - 1) 146 | .attr('fill', d => colorScale(d.value)) 147 | .attr('stroke', '#fff') 148 | .attr('stroke-width', 0.5); 149 | 150 | // Add values if cells are large enough 151 | if (cellSize >= 20) { 152 | cells.append('text') 153 | .attr('x', cellSize / 2) 154 | .attr('y', cellSize / 2) 155 | .attr('text-anchor', 'middle') 156 | .attr('dominant-baseline', 'middle') 157 | .attr('font-size', Math.min(cellSize / 3, 12)) 158 | .attr('font-weight', 'bold') 159 | .attr('fill', d => d3.lab(colorScale(d.value)).l > 60 ? 
'#000' : '#fff') 160 | .text(d => d.value.toFixed(3)); 161 | } 162 | 163 | // Add axis labels with better spacing 164 | if (rows <= 20) { 165 | g.selectAll('.row-label') 166 | .data(d3.range(rows)) 167 | .enter() 168 | .append('text') 169 | .attr('class', 'row-label') 170 | .attr('x', -8) 171 | .attr('y', d => d * cellSize + cellSize / 2) 172 | .attr('text-anchor', 'end') 173 | .attr('dominant-baseline', 'middle') 174 | .attr('font-size', Math.min(cellSize / 2, 12)) 175 | .attr('fill', COLORS.text) 176 | .text(d => d); 177 | } 178 | 179 | if (cols <= 30) { 180 | g.selectAll('.col-label') 181 | .data(d3.range(cols)) 182 | .enter() 183 | .append('text') 184 | .attr('class', 'col-label') 185 | .attr('x', d => d * cellSize + cellSize / 2) 186 | .attr('y', -8) 187 | .attr('text-anchor', 'middle') 188 | .attr('font-size', Math.min(cellSize / 2, 12)) 189 | .attr('fill', COLORS.text) 190 | .text(d => d); 191 | } 192 | 193 | }, [data, showDetailModal, rows, cols, label, matrix2D]); // Added matrix2D to dependencies 194 | 195 | return ( 196 | <> 197 |
198 |
199 |
200 |
{label}
201 |
202 | 209 |
210 |
211 | 212 |
213 |
214 | 215 | {/* Detail Modal */} 216 | {showDetailModal && ( 217 |
218 |
219 |
220 |
221 |

{label}

222 |

Detailed matrix view • {rows}×{cols}

223 |
224 | 230 |
231 |
232 | 233 |
234 |
235 |
236 | )} 237 | 238 | ); 239 | } -------------------------------------------------------------------------------- /src/components/SettingsMenu.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useRef, useEffect } from 'react'; 2 | import { CogIcon, XMarkIcon } from '@heroicons/react/24/outline'; 3 | import { TransformerParams } from '../types/events'; 4 | import { PARAM_LIMITS } from '../utils/constants'; 5 | 6 | interface SettingsMenuProps { 7 | params: TransformerParams; 8 | onParamsChange: (params: Partial) => void; 9 | } 10 | 11 | // Define preset configurations 12 | const PRESETS = { 13 | minimal: { numLayers: 1, dModel: 64, numHeads: 4, seqLen: 4 }, 14 | standard: { numLayers: 3, dModel: 128, numHeads: 8, seqLen: 6 }, 15 | large: { numLayers: 6, dModel: 256, numHeads: 8, seqLen: 8 }, 16 | } as const; 17 | 18 | export function SettingsMenu({ 19 | params, 20 | onParamsChange, 21 | }: SettingsMenuProps) { 22 | const [isOpen, setIsOpen] = useState(false); 23 | const menuRef = useRef(null); 24 | 25 | // Determine which preset is currently active 26 | const getActivePreset = () => { 27 | for (const [presetName, presetParams] of Object.entries(PRESETS)) { 28 | const isMatch = Object.entries(presetParams).every(([key, value]) => { 29 | return params[key as keyof TransformerParams] === value; 30 | }); 31 | if (isMatch) return presetName; 32 | } 33 | return null; 34 | }; 35 | 36 | const activePreset = getActivePreset(); 37 | 38 | // Close menu when clicking outside 39 | useEffect(() => { 40 | function handleClickOutside(event: MouseEvent) { 41 | if (menuRef.current && !menuRef.current.contains(event.target as Node)) { 42 | setIsOpen(false); 43 | } 44 | } 45 | 46 | if (isOpen) { 47 | document.addEventListener('mousedown', handleClickOutside); 48 | return () => document.removeEventListener('mousedown', handleClickOutside); 49 | } 50 | }, [isOpen]); 51 | 52 | // Helper function to get button styling based on active state 53 | const getButtonClass = (presetName: string) => { 54 | const baseClass = "px-3 py-2 text-xs rounded-md transition-colors"; 55 | const isActive = activePreset === presetName; 56 | 57 | if (isActive) { 58 | return `${baseClass} bg-cs-blue text-white hover:bg-cs-deep`; 59 | } else { 60 | return `${baseClass} bg-gray-100 hover:bg-gray-200 text-gray-700`; 61 | } 62 | }; 63 | 64 | return ( 65 |
66 | {/* Settings Button */} 67 | 75 | 76 | {/* Settings Panel */} 77 | {isOpen && ( 78 |
79 |
80 | {/* Header */} 81 |
82 |

Configuration

83 | 89 |
90 | 91 | {/* Architecture Parameters */} 92 |
93 |

Architecture

94 | 95 | {/* Number of Layers */} 96 |
97 |
98 | 99 | {params.numLayers} 100 |
101 | onParamsChange({ numLayers: parseInt(e.target.value) })} 107 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider" 108 | /> 109 |
110 | 111 | {/* Model Dimension */} 112 |
113 |
114 | 115 | {params.dModel} 116 |
117 | onParamsChange({ dModel: parseInt(e.target.value) })} 124 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider" 125 | /> 126 |
127 | 128 | {/* Number of Heads */} 129 |
130 |
131 | 132 | {params.numHeads} 133 |
134 | onParamsChange({ numHeads: parseInt(e.target.value) })} 140 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider" 141 | /> 142 |

143 | Max: {Math.floor(params.dModel / 8)} (divisible constraint) 144 |

145 |
146 | 147 | {/* Sequence Length */} 148 |
149 |
150 | 151 | {params.seqLen} 152 |
153 | onParamsChange({ seqLen: parseInt(e.target.value) })} 159 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider" 160 | /> 161 |
162 | 163 | {/* Positional Encoding */} 164 |
165 | 166 | 174 |
175 |
176 | 177 | {/* Quick Presets */} 178 |
179 |

Quick Presets

180 |
181 | 187 | 193 | 199 |
200 |
201 |
202 |
203 | )} 204 |
205 | ); 206 | } -------------------------------------------------------------------------------- /src/state/transformerMachine.ts: -------------------------------------------------------------------------------- 1 | import { createMachine, assign } from 'xstate'; 2 | import { TransformerParams } from '../types/events'; 3 | import { DEFAULT_PARAMS, TRANSFORMER_STEPS, ZOOM_LEVELS, ZoomLevel } from '../utils/constants'; 4 | import { SequenceGenerator } from '../utils/data'; 5 | 6 | export interface TransformerContext { 7 | // Core simulation state 8 | currentStep: number; 9 | totalSteps: number; 10 | activeToken: number; 11 | focusToken: boolean; 12 | 13 | // Zoom and view state 14 | zoomLevel: ZoomLevel; 15 | selectedComponent?: string; 16 | selectedLayer?: number; 17 | 18 | // Model parameters 19 | params: TransformerParams; 20 | 21 | // UI state 22 | isPlaying: boolean; 23 | playbackSpeed: number; // 0.5x to 2x 24 | 25 | // Settings 26 | settingsOpen: boolean; 27 | darkMode: boolean; 28 | colorBlindMode: boolean; 29 | 30 | // Attention visualization 31 | attentionLensActive: boolean; 32 | attentionTopK: number; 33 | 34 | // Data - simplified to just sequences 35 | inputSequence: string[]; 36 | outputSequence: string[]; 37 | 38 | // Performance 39 | lastFrameTime: number; 40 | currentFPS: number; 41 | 42 | // Error state 43 | error?: string; 44 | } 45 | 46 | export type TransformerEvent = 47 | | { type: 'PLAY' } 48 | | { type: 'PAUSE' } 49 | | { type: 'NEXT_STEP' } 50 | | { type: 'PREV_STEP' } 51 | | { type: 'SET_STEP'; step: number } 52 | | { type: 'RESET' } 53 | | { type: 'SET_ACTIVE_TOKEN'; tokenIndex: number } 54 | | { type: 'TOGGLE_FOCUS_TOKEN' } 55 | | { type: 'ZOOM_IN'; component?: string; layer?: number } 56 | | { type: 'ZOOM_OUT' } 57 | | { type: 'SET_ZOOM_LEVEL'; level: ZoomLevel; component?: string } 58 | | { type: 'UPDATE_PARAMS'; params: Partial } 59 | | { type: 'TOGGLE_SETTINGS' } 60 | | { type: 'SET_PLAYBACK_SPEED'; speed: number } 61 | | { type: 'TOGGLE_DARK_MODE' } 62 | | { type: 'TOGGLE_COLOR_BLIND_MODE' } 63 | | { type: 'TOGGLE_ATTENTION_LENS' } 64 | | { type: 'SET_ATTENTION_TOP_K'; k: number } 65 | | { type: 'TICK' } // Animation frame 66 | | { type: 'ERROR'; message: string } 67 | | { type: 'CLEAR_ERROR' }; 68 | 69 | // Helper function to create toggle actions 70 | const createToggleAction = (field: K) => 71 | assign({ [field]: ({ context }: { context: TransformerContext }) => !context[field] }); 72 | 73 | // Generate initial sequences - simplified 74 | const initialInputSequence = SequenceGenerator.generateSequence(DEFAULT_PARAMS.seqLen, 'input'); 75 | const initialOutputSequence = SequenceGenerator.generateMatchingOutput(initialInputSequence); 76 | 77 | const transformerMachine = createMachine({ 78 | /** @xstate-layout N4IgpgJg5mDOIC5QAoC2BDAxgCwJYDswBKAOnwGIA */ 79 | id: 'transformer', 80 | initial: 'idle', 81 | context: { 82 | currentStep: 0, 83 | totalSteps: TRANSFORMER_STEPS.length, 84 | activeToken: 0, 85 | focusToken: false, 86 | 87 | zoomLevel: ZOOM_LEVELS.GLOBAL, 88 | selectedComponent: undefined, 89 | selectedLayer: undefined, 90 | 91 | params: DEFAULT_PARAMS, 92 | 93 | isPlaying: false, 94 | playbackSpeed: 1.0, 95 | 96 | settingsOpen: false, 97 | darkMode: false, 98 | colorBlindMode: false, 99 | 100 | attentionLensActive: false, 101 | attentionTopK: 5, 102 | 103 | inputSequence: initialInputSequence, 104 | outputSequence: initialOutputSequence, 105 | 106 | lastFrameTime: 0, 107 | currentFPS: 60, 108 | } as TransformerContext, 109 | 110 | states: { 111 | 
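// Three top-level states drive the simulation:
//   - 'idle':    stepping is manual (NEXT_STEP / PREV_STEP / SET_STEP)
//   - 'playing': TICK events auto-advance currentStep via the animationTick action
//   - 'error':   entered on ERROR; left via CLEAR_ERROR or RESET
// Most view events (zoom, token selection, toggles) are accepted in both
// 'idle' and 'playing' so the UI stays interactive during playback.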
idle: { 112 | on: { 113 | PLAY: { 114 | target: 'playing', 115 | actions: 'setPlaying' 116 | }, 117 | NEXT_STEP: { 118 | actions: 'nextStep' 119 | }, 120 | PREV_STEP: { 121 | actions: 'prevStep' 122 | }, 123 | SET_STEP: { 124 | actions: 'setStep' 125 | }, 126 | RESET: { 127 | actions: 'reset' 128 | }, 129 | SET_ACTIVE_TOKEN: { 130 | actions: 'setActiveToken' 131 | }, 132 | TOGGLE_FOCUS_TOKEN: { 133 | actions: 'toggleFocusToken' 134 | }, 135 | ZOOM_IN: { 136 | actions: 'zoomIn' 137 | }, 138 | ZOOM_OUT: { 139 | actions: 'zoomOut' 140 | }, 141 | SET_ZOOM_LEVEL: { 142 | actions: 'setZoomLevel' 143 | }, 144 | UPDATE_PARAMS: { 145 | actions: 'updateParams' 146 | }, 147 | TOGGLE_SETTINGS: { 148 | actions: 'toggleSettings' 149 | }, 150 | SET_PLAYBACK_SPEED: { 151 | actions: 'setPlaybackSpeed' 152 | }, 153 | TOGGLE_DARK_MODE: { 154 | actions: 'toggleDarkMode' 155 | }, 156 | TOGGLE_COLOR_BLIND_MODE: { 157 | actions: 'toggleColorBlindMode' 158 | }, 159 | TOGGLE_ATTENTION_LENS: { 160 | actions: 'toggleAttentionLens' 161 | }, 162 | SET_ATTENTION_TOP_K: { 163 | actions: 'setAttentionTopK' 164 | }, 165 | ERROR: { 166 | target: 'error', 167 | actions: 'setError' 168 | } 169 | } 170 | }, 171 | 172 | playing: { 173 | on: { 174 | PAUSE: { 175 | target: 'idle', 176 | actions: 'setPaused' 177 | }, 178 | TICK: { 179 | actions: 'animationTick' 180 | }, 181 | // Allow other interactions while playing 182 | SET_ACTIVE_TOKEN: { 183 | actions: 'setActiveToken' 184 | }, 185 | TOGGLE_FOCUS_TOKEN: { 186 | actions: 'toggleFocusToken' 187 | }, 188 | ZOOM_IN: { 189 | actions: 'zoomIn' 190 | }, 191 | ZOOM_OUT: { 192 | actions: 'zoomOut' 193 | }, 194 | SET_ZOOM_LEVEL: { 195 | actions: 'setZoomLevel' 196 | }, 197 | TOGGLE_SETTINGS: { 198 | actions: 'toggleSettings' 199 | }, 200 | SET_PLAYBACK_SPEED: { 201 | actions: 'setPlaybackSpeed' 202 | }, 203 | TOGGLE_ATTENTION_LENS: { 204 | actions: 'toggleAttentionLens' 205 | }, 206 | SET_ATTENTION_TOP_K: { 207 | actions: 'setAttentionTopK' 208 | }, 209 | ERROR: { 210 | target: 'error', 211 | actions: 'setError' 212 | } 213 | } 214 | }, 215 | 216 | error: { 217 | on: { 218 | CLEAR_ERROR: { 219 | target: 'idle', 220 | actions: 'clearError' 221 | }, 222 | RESET: { 223 | target: 'idle', 224 | actions: ['reset', 'clearError'] 225 | } 226 | } 227 | } 228 | } 229 | }, { 230 | actions: { 231 | setPlaying: assign({ 232 | isPlaying: true 233 | }), 234 | 235 | setPaused: assign({ 236 | isPlaying: false 237 | }), 238 | 239 | nextStep: assign({ 240 | currentStep: ({ context }) => 241 | Math.min(context.currentStep + 1, context.totalSteps - 1) 242 | }), 243 | 244 | prevStep: assign({ 245 | currentStep: ({ context }) => 246 | Math.max(context.currentStep - 1, 0) 247 | }), 248 | 249 | setStep: assign({ 250 | currentStep: ({ event }) => 251 | event.type === 'SET_STEP' ? 
event.step : 0 252 | }), 253 | 254 | reset: assign({ 255 | currentStep: 0, 256 | activeToken: 0, 257 | zoomLevel: ZOOM_LEVELS.GLOBAL, 258 | selectedComponent: undefined, 259 | selectedLayer: undefined 260 | }), 261 | 262 | setActiveToken: assign({ 263 | activeToken: ({ event, context }) => { 264 | if (event.type !== 'SET_ACTIVE_TOKEN') return context.activeToken; 265 | // Ensure token index is within sequence bounds 266 | const tokenIndex = event.tokenIndex; 267 | const maxIndex = Math.max(0, context.inputSequence.length - 1); 268 | return Math.min(Math.max(0, tokenIndex), maxIndex); 269 | } 270 | }), 271 | 272 | toggleFocusToken: createToggleAction('focusToken'), 273 | 274 | zoomIn: assign({ 275 | zoomLevel: ({ context }) => { 276 | const levels = Object.values(ZOOM_LEVELS); 277 | const currentIndex = levels.indexOf(context.zoomLevel); 278 | const nextIndex = Math.min(currentIndex + 1, levels.length - 1); 279 | return levels[nextIndex]; 280 | }, 281 | selectedComponent: ({ event }) => 282 | event.type === 'ZOOM_IN' ? event.component : undefined, 283 | selectedLayer: ({ event }) => 284 | event.type === 'ZOOM_IN' ? event.layer : undefined 285 | }), 286 | 287 | zoomOut: assign({ 288 | zoomLevel: ({ context }) => { 289 | const levels = Object.values(ZOOM_LEVELS); 290 | const currentIndex = levels.indexOf(context.zoomLevel); 291 | const prevIndex = Math.max(currentIndex - 1, 0); 292 | return levels[prevIndex]; 293 | }, 294 | selectedComponent: ({ context }) => 295 | context.zoomLevel === ZOOM_LEVELS.GLOBAL ? undefined : context.selectedComponent, 296 | selectedLayer: ({ context }) => 297 | context.zoomLevel === ZOOM_LEVELS.GLOBAL ? undefined : context.selectedLayer 298 | }), 299 | 300 | setZoomLevel: assign({ 301 | zoomLevel: ({ event }) => 302 | event.type === 'SET_ZOOM_LEVEL' ? event.level : ZOOM_LEVELS.GLOBAL, 303 | selectedComponent: ({ event }) => 304 | event.type === 'SET_ZOOM_LEVEL' ? event.component : undefined 305 | }), 306 | 307 | updateParams: assign(({ context, event }) => { 308 | if (event.type !== 'UPDATE_PARAMS') return {}; 309 | 310 | const newParams = { ...context.params, ...event.params }; 311 | 312 | // Validate parameters 313 | if (newParams.numHeads > 0 && newParams.dModel % newParams.numHeads !== 0) { 314 | newParams.numHeads = Math.floor(newParams.dModel / 8) || 1; 315 | } 316 | 317 | // Handle sequence length changes with truncation/padding instead of regeneration 318 | const lengthChanged = event.params.seqLen && event.params.seqLen !== context.params.seqLen; 319 | 320 | if (lengthChanged) { 321 | const newLen = event.params.seqLen!; 322 | 323 | // Adjust existing sequences to new length instead of generating new ones 324 | const adjustedInputSequence = SequenceGenerator.adjustSequenceLength(context.inputSequence, newLen); 325 | const adjustedOutputSequence = SequenceGenerator.adjustSequenceLength(context.outputSequence, newLen); 326 | 327 | return { 328 | params: newParams, 329 | inputSequence: adjustedInputSequence, 330 | outputSequence: adjustedOutputSequence, 331 | // Keep existing sequence context since we're just adjusting length 332 | // Reset active token if it's now out of bounds 333 | activeToken: Math.min(context.activeToken, Math.max(0, newLen - 1)) 334 | }; 335 | } 336 | 337 | return { 338 | params: newParams 339 | }; 340 | }), 341 | 342 | toggleSettings: createToggleAction('settingsOpen'), 343 | 344 | setPlaybackSpeed: assign({ 345 | playbackSpeed: ({ event }) => 346 | event.type === 'SET_PLAYBACK_SPEED' ? 
event.speed : 1.0 347 | }), 348 | 349 | toggleDarkMode: createToggleAction('darkMode'), 350 | 351 | toggleColorBlindMode: createToggleAction('colorBlindMode'), 352 | 353 | toggleAttentionLens: createToggleAction('attentionLensActive'), 354 | 355 | setAttentionTopK: assign({ 356 | attentionTopK: ({ event }) => 357 | event.type === 'SET_ATTENTION_TOP_K' ? event.k : 5 358 | }), 359 | 360 | animationTick: assign(({ context }) => { 361 | const now = performance.now(); 362 | const delta = now - context.lastFrameTime; 363 | const currentFPS = delta > 0 ? Math.round(1000 / delta) : 60; 364 | 365 | // Auto-advance step based on playback speed 366 | const frameInterval = 1000 / 24; // 24 FPS as specified in PRD 367 | const adjustedInterval = frameInterval / context.playbackSpeed; 368 | const timeSinceLastStep = now - context.lastFrameTime; 369 | 370 | const currentStep = timeSinceLastStep >= adjustedInterval 371 | ? Math.min(context.currentStep + 1, context.totalSteps - 1) 372 | : context.currentStep; 373 | 374 | return { 375 | lastFrameTime: now, 376 | currentFPS, 377 | currentStep 378 | }; 379 | }), 380 | 381 | setError: assign({ 382 | error: ({ event }) => 383 | event.type === 'ERROR' ? event.message : undefined 384 | }), 385 | 386 | clearError: assign({ 387 | error: undefined 388 | }), 389 | 390 | 391 | } 392 | }); 393 | 394 | export default transformerMachine; -------------------------------------------------------------------------------- /src/utils/componentTransformations.ts: -------------------------------------------------------------------------------- 1 | import { ComponentSection } from '../types/events'; 2 | 3 | // Helper function to check if component is Add & Norm 4 | const isAddNormComponent = (componentId: string): boolean => { 5 | return componentId.includes('-norm'); 6 | }; 7 | 8 | // Helper function to get operations for Add & Norm components 9 | const getAddNormOperations = (): string[] => { 10 | return ['Residual Addition', 'Layer Normalization']; 11 | }; 12 | 13 | // Helper function to format component names 14 | export const formatComponentName = (id: string): string => { 15 | return id.split('-').map(word => 16 | word.charAt(0).toUpperCase() + word.slice(1) 17 | ).join(' '); 18 | }; 19 | 20 | // Helper function to get category colors 21 | export const getCategoryColor = (category: string): string => { 22 | const colors = { 23 | embedding: 'bg-blue-100 text-blue-800', 24 | attention: 'bg-purple-100 text-purple-800', 25 | ffn: 'bg-green-100 text-green-800', 26 | output: 'bg-orange-100 text-orange-800', 27 | tokens: 'bg-gray-100 text-gray-800', 28 | positional: 'bg-amber-100 text-amber-800', 29 | }; 30 | return colors[category as keyof typeof colors] || 'bg-gray-100 text-gray-800'; 31 | }; 32 | 33 | // Helper function to get operations for each component type 34 | export const getOperations = (category: string, componentId: string): string[] => { 35 | switch (category) { 36 | case 'embedding': 37 | return ['Token Lookup', 'Positional Addition']; 38 | case 'attention': 39 | // Handle Add & Norm blocks for attention 40 | if (isAddNormComponent(componentId)) { 41 | return getAddNormOperations(); 42 | } 43 | return ['Q/K/V Projection', 'Scaled Dot-Product', 'Softmax', 'Value Weighting']; 44 | case 'ffn': 45 | // Handle Add & Norm blocks for FFN 46 | if (isAddNormComponent(componentId)) { 47 | return getAddNormOperations(); 48 | } 49 | return ['Linear Expansion', 'ReLU Activation', 'Linear Projection']; 50 | case 'output': 51 | // Handle encoder output specifically 52 | if 
(componentId === 'encoder-output') { 53 | return ['Contextual Representation', 'Multi-Layer Processing']; 54 | } 55 | // Handle final output (Linear + Softmax) 56 | if (componentId === 'final-output') { 57 | return ['Vocabulary Projection', 'Softmax Computation']; 58 | } 59 | // Handle individual linear layer 60 | if (componentId.includes('linear') || componentId.includes('lm-head')) { 61 | return ['Vocabulary Projection']; 62 | } 63 | // Handle individual softmax 64 | if (componentId.includes('softmax')) { 65 | return ['Softmax Computation']; 66 | } 67 | // Handle predicted tokens (final sampling) 68 | if (componentId === 'predicted-tokens') { 69 | return ['Token Sampling']; 70 | } 71 | return ['Vocabulary Projection', 'Softmax', 'Token Sampling']; 72 | case 'tokens': 73 | return ['Tokenization']; 74 | case 'positional': 75 | return ['Sinusoidal Encoding']; 76 | default: 77 | // Handle Add & Norm blocks specifically (fallback) 78 | if (isAddNormComponent(componentId)) { 79 | return getAddNormOperations(); 80 | } 81 | return ['Component Processing']; 82 | } 83 | }; 84 | 85 | // Transformation generators for each component category 86 | const generateEmbeddingTransformations = (): ComponentSection[] => [ 87 | { 88 | id: 'embedding-lookup', 89 | label: 'Token Lookup', 90 | description: 'Lookup embedding vectors from learned embedding matrix', 91 | type: 'text', 92 | data: 'embedding_matrix[token_indices]', 93 | metadata: { 94 | operation: 'lookup', 95 | formula: 'E = W_e[tokens]', 96 | complexity: 'O(seq_len)', 97 | activation: 'none' 98 | } 99 | }, 100 | { 101 | id: 'positional-addition', 102 | label: 'Positional Addition', 103 | description: 'Element-wise addition of positional encoding to embeddings', 104 | type: 'text', 105 | data: 'embeddings + positional_encoding', 106 | metadata: { 107 | operation: 'addition', 108 | formula: 'H = E + P', 109 | complexity: 'O(seq_len × d_model)', 110 | activation: 'none' 111 | } 112 | } 113 | ]; 114 | 115 | const generateAttentionTransformations = (): ComponentSection[] => [ 116 | { 117 | id: 'linear-projections', 118 | label: 'Q/K/V Projection', 119 | description: 'Project input to query, key, and value spaces', 120 | type: 'text', 121 | data: 'Q = XW_Q, K = XW_K, V = XW_V', 122 | metadata: { 123 | operation: 'linear', 124 | formula: 'Q, K, V = XW_Q, XW_K, XW_V', 125 | complexity: 'O(seq_len × d_model²)', 126 | activation: 'none' 127 | } 128 | }, 129 | { 130 | id: 'attention-scores', 131 | label: 'Scaled Dot-Product', 132 | description: 'Compute attention scores using scaled dot-product', 133 | type: 'text', 134 | data: 'scores = QK^T / √d_k', 135 | metadata: { 136 | operation: 'dot-product', 137 | formula: 'Attention(Q,K,V) = softmax(QK^T/√d_k)V', 138 | complexity: 'O(seq_len²)', 139 | activation: 'none' 140 | } 141 | }, 142 | { 143 | id: 'softmax-activation', 144 | label: 'Softmax', 145 | description: 'Apply softmax to get attention weights', 146 | type: 'text', 147 | data: 'attention_weights = softmax(scores)', 148 | metadata: { 149 | operation: 'normalization', 150 | formula: 'α_ij = exp(e_ij) / Σ_k exp(e_ik)', 151 | complexity: 'O(seq_len²)', 152 | activation: 'softmax' 153 | } 154 | }, 155 | { 156 | id: 'weighted-combination', 157 | label: 'Value Weighting', 158 | description: 'Compute weighted combination of values', 159 | type: 'text', 160 | data: 'output = attention_weights × V', 161 | metadata: { 162 | operation: 'weighted-sum', 163 | formula: 'head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)', 164 | complexity: 'O(seq_len × d_model)', 165 | 
activation: 'none' 166 | } 167 | } 168 | ]; 169 | 170 | const generateFFNTransformations = (): ComponentSection[] => [ 171 | { 172 | id: 'first-linear', 173 | label: 'Linear Expansion', 174 | description: 'Expand to intermediate dimension', 175 | type: 'text', 176 | data: 'intermediate = XW₁ + b₁', 177 | metadata: { 178 | operation: 'linear', 179 | formula: 'FFN(x) = max(0, xW₁ + b₁)W₂ + b₂', 180 | complexity: 'O(seq_len × d_model × d_ff)', 181 | activation: 'none' 182 | } 183 | }, 184 | { 185 | id: 'relu-activation', 186 | label: 'ReLU Activation', 187 | description: 'Apply rectified linear unit activation', 188 | type: 'text', 189 | data: 'activated = max(0, intermediate)', 190 | metadata: { 191 | operation: 'activation', 192 | formula: 'ReLU(x) = max(0, x)', 193 | complexity: 'O(seq_len × d_ff)', 194 | activation: 'relu' 195 | } 196 | }, 197 | { 198 | id: 'second-linear', 199 | label: 'Linear Projection', 200 | description: 'Project back to model dimension', 201 | type: 'text', 202 | data: 'output = activatedW₂ + b₂', 203 | metadata: { 204 | operation: 'linear', 205 | formula: 'W₂ ∈ ℝ^(d_ff × d_model)', 206 | complexity: 'O(seq_len × d_ff × d_model)', 207 | activation: 'none' 208 | } 209 | } 210 | ]; 211 | 212 | const generateAddNormTransformations = (): ComponentSection[] => [ 213 | { 214 | id: 'residual-addition', 215 | label: 'Residual Addition', 216 | description: 'Add input to block output (skip connection)', 217 | type: 'text', 218 | data: 'residual_sum = X + block_output', 219 | metadata: { 220 | operation: 'addition', 221 | formula: 'X_residual = X + SubLayer(X)', 222 | complexity: 'O(seq_len × d_model)', 223 | activation: 'none' 224 | } 225 | }, 226 | { 227 | id: 'layer-normalization', 228 | label: 'Layer Normalization', 229 | description: 'Normalize across feature dimension', 230 | type: 'text', 231 | data: 'normalized = LayerNorm(residual_sum)', 232 | metadata: { 233 | operation: 'normalization', 234 | formula: 'LayerNorm(x) = γ ⊙ (x - μ)/σ + β', 235 | complexity: 'O(seq_len × d_model)', 236 | activation: 'none' 237 | } 238 | } 239 | ]; 240 | 241 | const generateOutputTransformations = (componentId: string): ComponentSection[] => { 242 | const transformations: ComponentSection[] = []; 243 | 244 | if (componentId.includes('linear') || componentId.includes('lm-head')) { 245 | transformations.push({ 246 | id: 'vocab-projection', 247 | label: 'Vocabulary Projection', 248 | description: 'Project hidden states to vocabulary size', 249 | type: 'text', 250 | data: 'logits = hidden_states × W_vocab', 251 | metadata: { 252 | operation: 'linear', 253 | formula: 'logits = HW_v + b_v', 254 | complexity: 'O(seq_len × d_model × vocab_size)', 255 | activation: 'none' 256 | } 257 | }); 258 | } 259 | 260 | if (componentId.includes('softmax')) { 261 | transformations.push({ 262 | id: 'softmax-computation', 263 | label: 'Softmax Computation', 264 | description: 'Convert logits to probability distribution', 265 | type: 'text', 266 | data: 'probs = softmax(logits)', 267 | metadata: { 268 | operation: 'normalization', 269 | formula: 'p_i = exp(logit_i) / Σ_j exp(logit_j)', 270 | complexity: 'O(seq_len × vocab_size)', 271 | activation: 'softmax' 272 | } 273 | }); 274 | } 275 | 276 | if (componentId.includes('predicted')) { 277 | transformations.push({ 278 | id: 'token-sampling', 279 | label: 'Token Sampling', 280 | description: 'Sample tokens from probability distribution', 281 | type: 'text', 282 | data: 'token = sample(probabilities)', 283 | metadata: { 284 | operation: 'sampling', 285 | formula: 
'token ~ Categorical(p)', 286 | complexity: 'O(seq_len)', 287 | activation: 'none' 288 | } 289 | }); 290 | } 291 | 292 | return transformations; 293 | }; 294 | 295 | // Generate transformations for encoder output 296 | const generateEncoderOutputTransformations = (): ComponentSection[] => [ 297 | { 298 | id: 'contextualized-representations', 299 | label: 'Contextual Representation', 300 | description: 'Final encoder representations with full contextual information', 301 | type: 'text', 302 | data: 'encoder_output = LayerNorm(H_N + Attention(H_N))', 303 | metadata: { 304 | operation: 'normalization', 305 | formula: 'H_enc = LayerNorm(H_N + MultiHeadAttention(H_N))', 306 | complexity: 'O(seq_len × d_model)', 307 | activation: 'none' 308 | } 309 | }, 310 | { 311 | id: 'multi-layer-processing', 312 | label: 'Multi-Layer Processing', 313 | description: 'Result of processing through all encoder layers', 314 | type: 'text', 315 | data: 'for layer in encoder_layers: H = layer(H)', 316 | metadata: { 317 | operation: 'composition', 318 | formula: 'H_enc = EncoderLayer_N(...EncoderLayer_1(H_0))', 319 | complexity: 'O(N × seq_len × d_model²)', 320 | activation: 'various' 321 | } 322 | } 323 | ]; 324 | 325 | // Generate transformations for final output (Linear + Softmax) 326 | const generateFinalOutputTransformations = (): ComponentSection[] => [ 327 | { 328 | id: 'vocabulary-projection', 329 | label: 'Vocabulary Projection', 330 | description: 'Project decoder states to vocabulary logits', 331 | type: 'text', 332 | data: 'logits = decoder_output × W_vocab + b_vocab', 333 | metadata: { 334 | operation: 'linear', 335 | formula: 'logits = H_dec × W_v + b_v', 336 | complexity: 'O(seq_len × d_model × vocab_size)', 337 | activation: 'none' 338 | } 339 | }, 340 | { 341 | id: 'softmax-computation', 342 | label: 'Softmax Computation', 343 | description: 'Convert logits to probability distribution over vocabulary', 344 | type: 'text', 345 | data: 'probabilities = softmax(logits)', 346 | metadata: { 347 | operation: 'normalization', 348 | formula: 'P(token_i) = exp(logit_i) / Σ_j exp(logit_j)', 349 | complexity: 'O(seq_len × vocab_size)', 350 | activation: 'softmax' 351 | } 352 | } 353 | ]; 354 | 355 | // Main function to generate detailed transformation data 356 | export const generateDetailedTransformations = (category: string, componentId: string): ComponentSection[] => { 357 | switch (category) { 358 | case 'embedding': 359 | return generateEmbeddingTransformations(); 360 | case 'attention': 361 | // Handle Add & Norm blocks for attention 362 | if (isAddNormComponent(componentId)) { 363 | return generateAddNormTransformations(); 364 | } 365 | return generateAttentionTransformations(); 366 | case 'ffn': 367 | // Handle Add & Norm blocks for FFN 368 | if (isAddNormComponent(componentId)) { 369 | return generateAddNormTransformations(); 370 | } 371 | return generateFFNTransformations(); 372 | case 'output': 373 | // Handle specific output components 374 | if (componentId === 'encoder-output') { 375 | return generateEncoderOutputTransformations(); 376 | } 377 | if (componentId === 'final-output') { 378 | return generateFinalOutputTransformations(); 379 | } 380 | return generateOutputTransformations(componentId); 381 | default: 382 | // Handle Add & Norm blocks specifically (fallback) 383 | if (isAddNormComponent(componentId)) { 384 | return generateAddNormTransformations(); 385 | } 386 | // Default transformation for other components 387 | return [{ 388 | id: 'default-transform', 389 | label: 'Component 
Processing', 390 | description: 'Process inputs through this component', 391 | type: 'text', 392 | data: 'output = f(inputs, parameters)', 393 | metadata: { 394 | operation: 'function', 395 | formula: 'y = f(x, θ)', 396 | complexity: 'O(n)', 397 | activation: 'varies' 398 | } 399 | }]; 400 | } 401 | }; -------------------------------------------------------------------------------- /src/components/ComponentDetailsPanel.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from 'react'; 2 | import { EyeIcon, DocumentTextIcon, ChevronDownIcon, ChevronRightIcon, ListBulletIcon } from '@heroicons/react/24/outline'; 3 | import { ComponentData, type ComponentSection } from '../types/events'; 4 | import { MatrixVisualization } from './MatrixVisualization'; 5 | import { TokenVisualization } from './TokenVisualization'; 6 | import { formatComponentName, getCategoryColor, getOperations, generateDetailedTransformations } from '../utils/componentTransformations'; 7 | 8 | interface ComponentDetailsPanelProps { 9 | componentId: string; 10 | componentData: ComponentData; 11 | onClose: () => void; 12 | } 13 | 14 | export function ComponentDetailsPanel({ 15 | componentId, 16 | componentData, 17 | onClose 18 | }: ComponentDetailsPanelProps) { 19 | const [viewMode, setViewMode] = useState<'textual' | 'visual'>('textual'); 20 | const [expandedSections, setExpandedSections] = useState>(new Set()); 21 | const [showTransformationModal, setShowTransformationModal] = useState(false); 22 | 23 | const toggleSection = (sectionId: string) => { 24 | const newExpanded = new Set(expandedSections); 25 | if (newExpanded.has(sectionId)) { 26 | newExpanded.delete(sectionId); 27 | } else { 28 | newExpanded.add(sectionId); 29 | } 30 | setExpandedSections(newExpanded); 31 | }; 32 | 33 | 34 | 35 | 36 | 37 | const operations = getOperations(componentData.category, componentId); 38 | 39 | return ( 40 |
41 | {/* Header */} 42 |
43 |
44 |
45 |

46 | Component Details 47 |

48 | 49 | {componentData.category} 50 | 51 |
52 | 58 |
59 | 60 |
61 |

62 | {formatComponentName(componentId)} 63 |

64 |

65 | {componentData.description} 66 |

67 |
68 | 69 | {/* Operations Section */} 70 |
71 |
72 |
Operations
73 | 81 |
82 |
    83 | {operations.map((operation, index) => ( 84 |
85 | {index + 1}. 86 | {operation} 87 |
88 | ))} 89 |
90 |
91 | 92 | {/* View Mode Toggle */} 93 |
94 | 105 | 116 |
117 |
118 | 119 | {/* Content */} 120 |
121 | {/* Inputs Section - Only show if there are inputs */} 122 | {componentData.inputs.length > 0 && ( 123 | toggleSection('inputs')} 129 | emptyMessage="No inputs (starting component)" 130 | /> 131 | )} 132 | 133 | {/* Parameters Section - Only show if there are parameters */} 134 | {componentData.parameters.length > 0 && ( 135 | toggleSection('parameters')} 141 | emptyMessage="No learnable parameters" 142 | /> 143 | )} 144 | 145 | {/* Outputs Section - Only show if there are outputs */} 146 | {componentData.outputs.length > 0 && ( 147 | toggleSection('outputs')} 153 | emptyMessage="No outputs (final component)" 154 | /> 155 | )} 156 |
157 | 158 | {/* Transformation Detail Modal */} 159 | {showTransformationModal && ( 160 | setShowTransformationModal(false)} 164 | /> 165 | )} 166 |
167 | ); 168 | } 169 | 170 | interface ComponentSectionProps { 171 | title: string; 172 | sections: ComponentSection[]; 173 | viewMode: 'textual' | 'visual'; 174 | isExpanded: boolean; 175 | onToggle: () => void; 176 | emptyMessage: string; 177 | } 178 | 179 | function ComponentSectionContainer({ 180 | title, 181 | sections, 182 | viewMode, 183 | isExpanded, 184 | onToggle, 185 | emptyMessage 186 | }: ComponentSectionProps) { 187 | const sectionColors = { 188 | Inputs: 'border-blue-200 bg-blue-50', 189 | Parameters: 'border-purple-200 bg-purple-50', 190 | Outputs: 'border-green-200 bg-green-50', 191 | }; 192 | 193 | const iconColors = { 194 | Inputs: 'text-blue-600', 195 | Parameters: 'text-purple-600', 196 | Outputs: 'text-green-600', 197 | }; 198 | 199 | // Helper function to get brief summary of items 200 | const getBriefSummary = () => { 201 | if (sections.length === 0) { 202 | return []; 203 | } 204 | 205 | return sections.map(section => { 206 | let summary = section.label; 207 | 208 | // Add shape information if available 209 | if (section.shape) { 210 | const shapeStr = section.shape.length === 2 211 | ? `${section.shape[0]}×${section.shape[1]}` 212 | : `${section.shape[0]}`; 213 | summary += `, ${shapeStr}`; 214 | } 215 | 216 | return summary; 217 | }); 218 | }; 219 | 220 | return ( 221 |
222 | 236 | 237 | {/* Brief summary when collapsed */} 238 | {!isExpanded && sections.length > 0 && ( 239 |
240 |
241 | {getBriefSummary().map((summary, index) => ( 242 |
243 | 244 | {summary} 245 |
246 | ))} 247 |
248 |
249 | )} 250 | 251 | {isExpanded && ( 252 |
253 | {sections.length === 0 ? ( 254 |
255 |

{emptyMessage}

256 |
257 | ) : ( 258 |
259 | {sections.map((section) => ( 260 | 265 | ))} 266 |
267 | )} 268 |
269 | )} 270 |
271 | ); 272 | } 273 | 274 | interface SectionItemProps { 275 | section: ComponentSection; 276 | viewMode: 'textual' | 'visual'; 277 | } 278 | 279 | function SectionItem({ section, viewMode }: SectionItemProps) { 280 | const renderContent = () => { 281 | if (viewMode === 'textual') { 282 | return ( 283 |
284 |
285 |
{section.label}
286 | {section.shape && ( 287 | 288 | {section.shape.length === 2 ? `${section.shape[0]}×${section.shape[1]}` : `${section.shape[0]}`} 289 | 290 | )} 291 |
292 |

{section.description}

293 | {section.metadata && ( 294 |
295 | {Object.entries(section.metadata).map(([key, value]) => ( 296 |
297 | {key}: 298 | {String(value)} 299 |
300 | ))} 301 |
302 | )} 303 |
304 | ); 305 | } else { 306 | // Visual mode 307 | if (section.type === 'tokens') { 308 | return ( 309 | 313 | ); 314 | } else if (section.type === 'matrix' || section.type === 'vector') { 315 | return ( 316 | 321 | ); 322 | } else if (section.type === 'scalar') { 323 | return ( 324 |
325 |
326 |
327 | {typeof section.data === 'number' ? section.data.toFixed(4) : section.data} 328 |
329 |
{section.label}
330 |
331 |
332 | ); 333 | } else { 334 | return ( 335 |
336 |
{section.description}
337 |
338 | ); 339 | } 340 | } 341 | }; 342 | 343 | return ( 344 |
345 | {renderContent()} 346 |
347 | ); 348 | } 349 | 350 | // Detailed transformation modal (keep existing implementation) 351 | interface TransformationDetailModalProps { 352 | componentName: string; 353 | transformations: ComponentSection[]; 354 | onClose: () => void; 355 | } 356 | 357 | function TransformationDetailModal({ 358 | componentName, 359 | transformations, 360 | onClose 361 | }: TransformationDetailModalProps) { 362 | const getOperationIcon = (operation: string) => { 363 | switch (operation) { 364 | case 'linear': 365 | case 'lookup': 366 | return '⊗'; 367 | case 'addition': 368 | return '+'; 369 | case 'dot-product': 370 | return '⋅'; 371 | case 'normalization': 372 | return 'σ'; 373 | case 'activation': 374 | return 'f'; 375 | case 'weighted-sum': 376 | return '∑'; 377 | case 'sampling': 378 | return '🎲'; 379 | default: 380 | return '⚙️'; 381 | } 382 | }; 383 | 384 | const getOperationColor = (operation: string) => { 385 | switch (operation) { 386 | case 'linear': 387 | case 'lookup': 388 | return 'bg-blue-100 text-blue-800 border-blue-300'; 389 | case 'addition': 390 | return 'bg-green-100 text-green-800 border-green-300'; 391 | case 'dot-product': 392 | return 'bg-purple-100 text-purple-800 border-purple-300'; 393 | case 'normalization': 394 | return 'bg-yellow-100 text-yellow-800 border-yellow-300'; 395 | case 'activation': 396 | return 'bg-red-100 text-red-800 border-red-300'; 397 | case 'weighted-sum': 398 | return 'bg-indigo-100 text-indigo-800 border-indigo-300'; 399 | case 'sampling': 400 | return 'bg-pink-100 text-pink-800 border-pink-300'; 401 | default: 402 | return 'bg-gray-100 text-gray-800 border-gray-300'; 403 | } 404 | }; 405 | 406 | const getActivationBadge = (activation: string) => { 407 | if (!activation || activation === 'none') return null; 408 | 409 | const activationColors = { 410 | 'relu': 'bg-red-500 text-white', 411 | 'softmax': 'bg-blue-500 text-white', 412 | 'tanh': 'bg-green-500 text-white', 413 | 'sigmoid': 'bg-purple-500 text-white', 414 | 'gelu': 'bg-orange-500 text-white', 415 | }; 416 | 417 | return ( 418 | 419 | {activation.toUpperCase()} 420 | 421 | ); 422 | }; 423 | 424 | return ( 425 |
426 |
427 | {/* Header */} 428 |
429 |
430 |
431 |

{componentName}

432 |

Detailed Transformation Flow

433 |
434 | 440 |
441 |
442 | 443 | {/* Content */} 444 |
445 |
446 | {transformations.map((transformation, index) => ( 447 |
448 |
449 |
450 | {/* Step indicator */} 451 |
452 |
453 | {getOperationIcon(transformation.metadata?.operation || 'default')} 454 |
455 | {index < transformations.length - 1 && ( 456 |
457 | )} 458 |
459 | 460 | {/* Content */} 461 |
462 |
463 |
464 |

{transformation.label}

465 |

{transformation.description}

466 |
467 |
468 | {getActivationBadge(transformation.metadata?.activation)} 469 | 470 | {transformation.metadata?.complexity} 471 | 472 |
473 |
474 | 475 | {/* Code and formula */} 476 |
477 |
478 |
Operation:
479 |
{transformation.data}
480 |
481 | 482 | {transformation.metadata?.formula && ( 483 |
484 |
Mathematical Formula:
485 |
{transformation.metadata.formula}
486 |
487 | )} 488 |
489 |
490 |
491 |
492 |
493 | ))} 494 |
495 |
496 | 497 | {/* Footer */} 498 |
499 |

500 | Click outside the modal or the ✕ button to close 501 |

502 |
503 |
504 |
505 | ); 506 | } -------------------------------------------------------------------------------- /src/hooks/useTransformerDiagram.ts: -------------------------------------------------------------------------------- 1 | import { useRef, useEffect, RefObject } from 'react'; 2 | import * as d3 from 'd3'; 3 | import { TransformerParams } from '../types/events'; 4 | import { COLORS } from '../utils/constants'; 5 | 6 | interface ComponentPosition { 7 | x: number; 8 | y: number; 9 | width: number; 10 | height: number; 11 | id: string; 12 | } 13 | 14 | const drawImprovedTransformerArchitecture = ( 15 | g: d3.Selection, 16 | params: TransformerParams, 17 | selectedComponent: string | null, 18 | onComponentClick: (componentId: string) => void, 19 | width: number, 20 | height: number, 21 | ) => { 22 | const componentWidth = 120; 23 | const componentHeight = 35; 24 | const positionalSize = 28; // Size for circular positional block 25 | 26 | const sectionWidth = Math.max(componentWidth + 40, Math.min(300, width * 0.45)); 27 | 28 | const encoderX = width * 0.25; 29 | const decoderX = width * 0.75; 30 | 31 | const components: ComponentPosition[] = []; 32 | 33 | // CodeSignal brand colors for consistent visual identity 34 | const styles = { 35 | input: { color: '#6b7280', textColor: 'white', canClick: true }, // Gray 500 36 | embedding: { color: '#1062FB', textColor: 'white', canClick: true }, // Beyond Blue (Primary) 37 | positional: { color: '#FFC500', textColor: 'black', canClick: true }, // Gold Standard Yellow (Secondary) 38 | attention: { color: '#E6193F', textColor: 'white', canClick: true }, // Ruby Red (Tertiary) 39 | norm: { color: '#21CF82', textColor: 'white', canClick: true }, // Victory Green (Tertiary) 40 | ffn: { color: '#FF7232', textColor: 'white', canClick: true }, // FiresIDE Orange (Tertiary) 41 | output: { color: '#002570', textColor: 'white', canClick: true }, // Dive Deep Blue (Secondary) 42 | decoder: { color: '#002570', textColor: 'white', canClick: true }, // Dive Deep Blue (Secondary) 43 | }; 44 | 45 | const CROSS_COLOR = '#64D3FF'; // SignalLight Blue (Secondary) 46 | const RESIDUAL_COLOR = '#9ca3af'; // Gray 400 47 | 48 | const createComponent = ( 49 | id: string, 50 | label: string, 51 | type: string, 52 | x: number, 53 | y: number, 54 | width = componentWidth, 55 | height = componentHeight, 56 | ): ComponentPosition => { 57 | const isSelected = selectedComponent === id; 58 | const style = styles[type as keyof typeof styles]; 59 | 60 | const group = g 61 | .append('g') 62 | .attr('class', 'component') 63 | .style('cursor', style.canClick ? 'pointer' : 'default') 64 | .on('click', () => style.canClick && onComponentClick(id)); 65 | 66 | group 67 | .append('rect') 68 | .attr('x', x - width / 2 + 2) 69 | .attr('y', y - height / 2 + 2) 70 | .attr('width', width) 71 | .attr('height', height) 72 | .attr('rx', 6) 73 | .attr('fill', '#00000020') 74 | .attr('opacity', 0.3); 75 | 76 | group 77 | .append('rect') 78 | .attr('x', x - width / 2) 79 | .attr('y', y - height / 2) 80 | .attr('width', width) 81 | .attr('height', height) 82 | .attr('rx', 6) 83 | .attr('fill', style.color) 84 | .attr('stroke', isSelected ? COLORS.gold : 'none') 85 | .attr('stroke-width', isSelected ? 3 : 0) 86 | .style('filter', isSelected ? 'drop-shadow(0 0 10px rgba(255, 197, 0, 0.5))' : 'none'); 87 | 88 | const lines = label.split('\n'); 89 | const fontSize = lines.length > 1 ? 
10 : 11; 90 | lines.forEach((line, i) => { 91 | group 92 | .append('text') 93 | .attr('x', x) 94 | .attr('y', y + (i - (lines.length - 1) / 2) * 12) 95 | .attr('text-anchor', 'middle') 96 | .attr('dy', '0.35em') 97 | .attr('font-size', fontSize) 98 | .attr('font-weight', isSelected ? 'bold' : '500') 99 | .attr('fill', style.textColor) 100 | .text(line); 101 | }); 102 | 103 | const position = { x, y, width, height, id }; 104 | components.push(position); 105 | return position; 106 | }; 107 | 108 | const createPositionalComponent = ( 109 | id: string, 110 | x: number, 111 | y: number, 112 | size = positionalSize, 113 | ): ComponentPosition => { 114 | const isSelected = selectedComponent === id; 115 | const style = styles.positional; 116 | 117 | const group = g 118 | .append('g') 119 | .attr('class', 'component') 120 | .style('cursor', style.canClick ? 'pointer' : 'default') 121 | .on('click', () => style.canClick && onComponentClick(id)); 122 | 123 | // Drop shadow 124 | group 125 | .append('circle') 126 | .attr('cx', x + 2) 127 | .attr('cy', y + 2) 128 | .attr('r', size / 2) 129 | .attr('fill', '#00000020') 130 | .attr('opacity', 0.3); 131 | 132 | // Main circle 133 | group 134 | .append('circle') 135 | .attr('cx', x) 136 | .attr('cy', y) 137 | .attr('r', size / 2) 138 | .attr('fill', style.color) 139 | .attr('stroke', isSelected ? COLORS.gold : 'none') 140 | .attr('stroke-width', isSelected ? 3 : 0) 141 | .style('filter', isSelected ? 'drop-shadow(0 0 10px rgba(255, 197, 0, 0.5))' : 'none'); 142 | 143 | // Position icon (clock-like symbol) 144 | const iconGroup = group.append('g'); 145 | 146 | // Clock face 147 | iconGroup 148 | .append('circle') 149 | .attr('cx', x) 150 | .attr('cy', y) 151 | .attr('r', size / 4) 152 | .attr('fill', 'none') 153 | .attr('stroke', style.textColor) 154 | .attr('stroke-width', 1.5); 155 | 156 | // Clock hands 157 | iconGroup 158 | .append('line') 159 | .attr('x1', x) 160 | .attr('y1', y) 161 | .attr('x2', x) 162 | .attr('y2', y - size / 6) 163 | .attr('stroke', style.textColor) 164 | .attr('stroke-width', 1.5) 165 | .attr('stroke-linecap', 'round'); 166 | 167 | iconGroup 168 | .append('line') 169 | .attr('x1', x) 170 | .attr('y1', y) 171 | .attr('x2', x + size / 8) 172 | .attr('y2', y) 173 | .attr('stroke', style.textColor) 174 | .attr('stroke-width', 1.5) 175 | .attr('stroke-linecap', 'round'); 176 | 177 | // Center dot 178 | iconGroup 179 | .append('circle') 180 | .attr('cx', x) 181 | .attr('cy', y) 182 | .attr('r', 1.5) 183 | .attr('fill', style.textColor); 184 | 185 | const position = { x, y, width: size, height: size, id }; 186 | components.push(position); 187 | return position; 188 | }; 189 | 190 | const drawArrow = (from: ComponentPosition, to: ComponentPosition, type: 'main' | 'residual' | 'cross' | 'side' = 'main') => { 191 | let fromX, fromY, toX, toY; 192 | 193 | if (type === 'residual') { 194 | const isEncoder = from.x < width / 2; 195 | if (isEncoder) { 196 | fromX = from.x - from.width / 2; 197 | fromY = from.y + from.height / 4; 198 | toX = to.x - to.width / 2; 199 | toY = to.y - to.height / 4; 200 | } else { 201 | fromX = from.x + from.width / 2; 202 | fromY = from.y + from.height / 4; 203 | toX = to.x + to.width / 2; 204 | toY = to.y - to.height / 4; 205 | } 206 | } else if (type === 'cross') { 207 | fromX = from.x + from.width / 2; 208 | fromY = from.y; 209 | toX = to.x - to.width / 2; 210 | toY = to.y; 211 | } else if (type === 'side') { 212 | // Side-to-side arrow for positional to embedding 213 | const isEncoder = from.x < width / 2; 214 | 
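// The positional-encoding circle sits to the right of the encoder's embedding
// block but to the left of the decoder's, so the side arrow must leave from
// whichever side of the circle faces the embedding it feeds into.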
if (isEncoder) { 215 | // Encoder: positional (right) to embedding (left) - arrow goes from left side of positional to right side of embedding 216 | fromX = from.x - from.width / 2; 217 | fromY = from.y; 218 | toX = to.x + to.width / 2; 219 | toY = to.y; 220 | } else { 221 | // Decoder: positional (left) to embedding (right) - arrow goes from right side of positional to left side of embedding 222 | fromX = from.x + from.width / 2; 223 | fromY = from.y; 224 | toX = to.x - to.width / 2; 225 | toY = to.y; 226 | } 227 | } else { 228 | fromX = from.x; 229 | fromY = from.y + from.height / 2; 230 | toX = to.x; 231 | toY = to.y - to.height / 2; 232 | } 233 | 234 | const strokeColor = type === 'residual' ? RESIDUAL_COLOR : type === 'cross' ? CROSS_COLOR : '#64748b'; 235 | const strokeWidth = type === 'residual' ? 2 : type === 'cross' ? 2 : 1.5; 236 | 237 | const path = g 238 | .append('path') 239 | .attr('stroke', strokeColor) 240 | .attr('stroke-width', strokeWidth) 241 | .attr('fill', 'none') 242 | .attr('marker-end', `url(#arrowhead-${type === 'side' ? 'main' : type})`); 243 | 244 | if (type === 'cross') { 245 | const midX = (fromX + toX) / 2; 246 | path.attr('d', `M ${fromX} ${fromY} L ${midX} ${fromY} L ${midX} ${toY} L ${toX} ${toY}`); 247 | } else if (type === 'residual') { 248 | const sideOffset = 30; 249 | const isEncoder = from.x < width / 2; 250 | const sideX = isEncoder ? fromX - sideOffset : fromX + sideOffset; 251 | const pathData = [`M ${fromX} ${fromY}`, `L ${sideX} ${fromY}`, `L ${sideX} ${toY}`, `L ${toX} ${toY}`].join(' '); 252 | path.attr('d', pathData); 253 | } else { 254 | path.attr('d', `M ${fromX} ${fromY} L ${toX} ${toY}`); 255 | } 256 | }; 257 | 258 | const defs = g.append('defs'); 259 | ['main', 'residual', 'cross'].forEach((type) => { 260 | const color = type === 'residual' ? RESIDUAL_COLOR : type === 'cross' ? 
CROSS_COLOR : '#64748b'; 261 | defs 262 | .append('marker') 263 | .attr('id', `arrowhead-${type}`) 264 | .attr('viewBox', '0 0 10 10') 265 | .attr('refX', 8) 266 | .attr('refY', 3) 267 | .attr('markerWidth', 6) 268 | .attr('markerHeight', 6) 269 | .attr('orient', 'auto') 270 | .append('path') 271 | .attr('d', 'M 0 0 L 10 3 L 0 6 z') 272 | .attr('fill', color); 273 | }); 274 | 275 | g.append('rect') 276 | .attr('x', encoderX - sectionWidth / 2) 277 | .attr('y', 30) 278 | .attr('width', sectionWidth) 279 | .attr('height', height - 60) 280 | .attr('fill', '#3b82f6') 281 | .attr('fill-opacity', 0.2) 282 | .attr('stroke', '#3b82f6') 283 | .attr('stroke-width', 3) 284 | .attr('stroke-dasharray', '8,4') 285 | .attr('rx', 12); 286 | 287 | g.append('rect') 288 | .attr('x', decoderX - sectionWidth / 2) 289 | .attr('y', 30) 290 | .attr('width', sectionWidth) 291 | .attr('height', height - 60) 292 | .attr('fill', '#8b5cf6') 293 | .attr('fill-opacity', 0.2) 294 | .attr('stroke', '#8b5cf6') 295 | .attr('stroke-width', 3) 296 | .attr('stroke-dasharray', '8,4') 297 | .attr('rx', 12); 298 | 299 | const encoderOutput = drawEncoderStack(g, params, createComponent, createPositionalComponent, drawArrow, encoderX, sectionWidth); 300 | drawDecoderStack(g, params, createComponent, createPositionalComponent, drawArrow, decoderX, encoderOutput, sectionWidth); 301 | 302 | drawArrow( 303 | encoderOutput, 304 | components.find((c) => c.id.startsWith('decoder-layer-0-cross-attention'))!, 305 | 'cross', 306 | ); 307 | }; 308 | 309 | const drawEncoderStack = ( 310 | g: d3.Selection, 311 | params: TransformerParams, 312 | createComponent: (id: string, label: string, type: string, x: number, y: number, width?: number, height?: number) => ComponentPosition, 313 | createPositionalComponent: (id: string, x: number, y: number, size?: number) => ComponentPosition, 314 | drawArrow: (from: ComponentPosition, to: ComponentPosition, type?: 'main' | 'residual' | 'cross' | 'side') => void, 315 | encoderX: number, 316 | sectionWidth: number, 317 | ): ComponentPosition => { 318 | const { numLayers } = params; 319 | const intraLayerSpacing = 20; 320 | const interLayerSpacing = 30; 321 | const componentHeight = 35; 322 | const positionalSize = 28; 323 | const initialY = 80; 324 | 325 | const inputTokens = createComponent('input-tokens', 'Input Tokens', 'input', encoderX, initialY); 326 | 327 | // Create separate embedding and positional components 328 | const embeddingY = inputTokens.y + componentHeight + interLayerSpacing; 329 | const embedding = createComponent('embedding', 'Embedding', 'embedding', encoderX, embeddingY); 330 | 331 | // Positional component to the right of embedding in encoder with more spacing 332 | const positional = createPositionalComponent('positional-encoding', encoderX + 100, embeddingY, positionalSize); 333 | 334 | // Arrows: input -> embedding, positional -> embedding (side) 335 | drawArrow(inputTokens, embedding); 336 | drawArrow(positional, embedding, 'side'); 337 | 338 | let prevLayerOutput = embedding; 339 | let currentY = embedding.y + componentHeight / 2 + interLayerSpacing; 340 | 341 | for (let i = 0; i < numLayers; i++) { 342 | const layerGroup = g.append('g').attr('class', 'encoder-layer-group'); 343 | const layerHeight = 4 * componentHeight + 5 * intraLayerSpacing; 344 | const layerY = currentY; 345 | 346 | layerGroup 347 | .append('rect') 348 | .attr('x', encoderX - sectionWidth / 2 + 10) 349 | .attr('y', layerY - intraLayerSpacing / 2) 350 | .attr('width', sectionWidth - 20) 351 | .attr('height', 
layerHeight) 352 | .attr('rx', 8) 353 | .attr('fill', 'rgba(255, 255, 255, 0.4)') 354 | .attr('stroke', '#9ca3af') 355 | .attr('stroke-dasharray', '3,3'); 356 | 357 | const attention = createComponent( 358 | `encoder-layer-${i}-attention`, 359 | `Multi-Head\nAttention`, 360 | 'attention', 361 | encoderX, 362 | layerY + intraLayerSpacing + componentHeight / 2, 363 | ); 364 | drawArrow(prevLayerOutput, attention); 365 | 366 | const attentionNorm = createComponent(`encoder-layer-${i}-attention-norm`, 'Add & Norm', 'norm', encoderX, attention.y + componentHeight + intraLayerSpacing); 367 | drawArrow(attention, attentionNorm); 368 | drawArrow(prevLayerOutput, attentionNorm, 'residual'); 369 | 370 | const ffn = createComponent(`encoder-layer-${i}-ffn`, 'Feed Forward', 'ffn', encoderX, attentionNorm.y + componentHeight + intraLayerSpacing); 371 | drawArrow(attentionNorm, ffn); 372 | 373 | const ffnNorm = createComponent(`encoder-layer-${i}-ffn-norm`, 'Add & Norm', 'norm', encoderX, ffn.y + componentHeight + intraLayerSpacing); 374 | drawArrow(ffn, ffnNorm); 375 | drawArrow(attentionNorm, ffnNorm, 'residual'); 376 | 377 | prevLayerOutput = ffnNorm; 378 | currentY = ffnNorm.y + componentHeight / 2 + interLayerSpacing; 379 | } 380 | 381 | const encoderOutput = createComponent( 382 | 'encoder-output', 383 | 'Encoder Output', 384 | 'output', 385 | encoderX, 386 | currentY + componentHeight / 2 + intraLayerSpacing, 387 | ); 388 | drawArrow(prevLayerOutput, encoderOutput); 389 | 390 | return encoderOutput; 391 | }; 392 | 393 | const drawDecoderStack = ( 394 | g: d3.Selection<SVGGElement, unknown, null, undefined>, 395 | params: TransformerParams, 396 | createComponent: (id: string, label: string, type: string, x: number, y: number, width?: number, height?: number) => ComponentPosition, 397 | createPositionalComponent: (id: string, x: number, y: number, size?: number) => ComponentPosition, 398 | drawArrow: (from: ComponentPosition, to: ComponentPosition, type?: 'main' | 'residual' | 'cross' | 'side') => void, 399 | decoderX: number, 400 | encoderOutput: ComponentPosition, 401 | sectionWidth: number, 402 | ) => { 403 | const { numLayers } = params; 404 | const intraLayerSpacing = 20; 405 | const interLayerSpacing = 30; 406 | const componentHeight = 35; 407 | const positionalSize = 28; 408 | const initialY = 80; 409 | 410 | const outputTokens = createComponent('output-tokens', 'Output Tokens', 'input', decoderX, initialY); 411 | 412 | // Create separate embedding and positional components 413 | const embeddingY = outputTokens.y + componentHeight + interLayerSpacing; 414 | const embedding = createComponent('decoder-embedding', 'Embedding', 'embedding', decoderX, embeddingY); 415 | 416 | // Positional component to the left of embedding in decoder with more spacing 417 | const positional = createPositionalComponent('decoder-positional-encoding', decoderX - 100, embeddingY, positionalSize); 418 | 419 | // Arrows: output -> embedding, positional -> embedding (side) 420 | drawArrow(outputTokens, embedding); 421 | drawArrow(positional, embedding, 'side'); 422 | 423 | let prevLayerOutput = embedding; 424 | let currentY = embedding.y + componentHeight / 2 + interLayerSpacing; 425 | 426 | for (let i = 0; i < numLayers; i++) { 427 | const layerGroup = g.append('g').attr('class', 'decoder-layer-group'); 428 | const layerHeight = 6 * componentHeight + 7 * intraLayerSpacing; 429 | const layerY = currentY; 430 | 431 | layerGroup 432 | .append('rect') 433 | .attr('x', decoderX - sectionWidth / 2 + 10) 434 | .attr('y', layerY - intraLayerSpacing / 2) 435 | 
.attr('width', sectionWidth - 20) 436 | .attr('height', layerHeight) 437 | .attr('rx', 8) 438 | .attr('fill', 'rgba(255, 255, 255, 0.4)') 439 | .attr('stroke', '#9ca3af') 440 | .attr('stroke-dasharray', '3,3'); 441 | 442 | const maskedAttention = createComponent( 443 | `decoder-layer-${i}-masked-attention`, 444 | `Masked Multi-Head\nAttention`, 445 | 'attention', 446 | decoderX, 447 | layerY + intraLayerSpacing + componentHeight / 2, 448 | ); 449 | drawArrow(prevLayerOutput, maskedAttention); 450 | 451 | const maskedAttentionNorm = createComponent(`decoder-layer-${i}-masked-attention-norm`, 'Add & Norm', 'norm', decoderX, maskedAttention.y + componentHeight + intraLayerSpacing); 452 | drawArrow(maskedAttention, maskedAttentionNorm); 453 | drawArrow(prevLayerOutput, maskedAttentionNorm, 'residual'); 454 | 455 | const crossAttention = createComponent(`decoder-layer-${i}-cross-attention`, `Cross-Attention`, 'attention', decoderX, maskedAttentionNorm.y + componentHeight + intraLayerSpacing); 456 | drawArrow(maskedAttentionNorm, crossAttention); 457 | drawArrow(encoderOutput, crossAttention, 'cross'); 458 | 459 | const crossAttentionNorm = createComponent(`decoder-layer-${i}-cross-attention-norm`, 'Add & Norm', 'norm', decoderX, crossAttention.y + componentHeight + intraLayerSpacing); 460 | drawArrow(crossAttention, crossAttentionNorm); 461 | drawArrow(maskedAttentionNorm, crossAttentionNorm, 'residual'); 462 | 463 | const ffn = createComponent(`decoder-layer-${i}-ffn`, 'Feed Forward', 'ffn', decoderX, crossAttentionNorm.y + componentHeight + intraLayerSpacing); 464 | drawArrow(crossAttentionNorm, ffn); 465 | 466 | const ffnNorm = createComponent(`decoder-layer-${i}-ffn-norm`, 'Add & Norm', 'norm', decoderX, ffn.y + componentHeight + intraLayerSpacing); 467 | drawArrow(ffn, ffnNorm); 468 | drawArrow(crossAttentionNorm, ffnNorm, 'residual'); 469 | 470 | prevLayerOutput = ffnNorm; 471 | currentY = ffnNorm.y + componentHeight / 2 + interLayerSpacing; 472 | } 473 | 474 | const finalOutput = createComponent( 475 | 'final-output', 476 | 'Linear + Softmax', 477 | 'ffn', 478 | decoderX, 479 | currentY + componentHeight / 2 + intraLayerSpacing, 480 | ); 481 | drawArrow(prevLayerOutput, finalOutput); 482 | 483 | const predictedTokens = createComponent('predicted-tokens', 'Predicted Tokens', 'output', decoderX, finalOutput.y + 60); 484 | drawArrow(finalOutput, predictedTokens); 485 | }; 486 | 487 | export const useTransformerDiagram = ( 488 | params: TransformerParams, 489 | selectedComponent: string | null, 490 | onComponentClick: (componentId: string) => void, 491 | containerRef: RefObject<HTMLElement>, 492 | ) => { 493 | const svgRef = useRef<SVGSVGElement | null>(null); 494 | 495 | useEffect(() => { 496 | if (!svgRef.current || !containerRef.current) return; 497 | 498 | const svg = d3.select(svgRef.current); 499 | svg.selectAll('*').remove(); 500 | 501 | const containerRect = containerRef.current.getBoundingClientRect(); 502 | const width = Math.max(containerRect.width, 300); 503 | 504 | const componentHeight = 35; 505 | const intraLayerSpacing = 20; 506 | const interLayerSpacing = 30; 507 | const initialY = 80; 508 | 509 | let totalHeight = initialY + componentHeight + interLayerSpacing; 510 | 511 | const encoderLayerH = 4 * componentHeight + 6 * intraLayerSpacing; // height estimate; one spacing unit larger than the drawn layerHeight, leaving extra padding 512 | const encoderStackHeight = params.numLayers * (encoderLayerH + interLayerSpacing); 513 | 514 | const decoderLayerH = 6 * componentHeight + 8 * intraLayerSpacing; 515 | const decoderStackHeight = params.numLayers * (decoderLayerH + interLayerSpacing); 516 | 517 | 
totalHeight += Math.max(encoderStackHeight, decoderStackHeight); 518 | totalHeight += componentHeight + interLayerSpacing + intraLayerSpacing; 519 | totalHeight += 60; 520 | totalHeight += 60; 521 | 522 | const height = Math.max(totalHeight, 600); 523 | 524 | svg.attr('viewBox', `0 0 ${width} ${height}`).attr('preserveAspectRatio', 'xMidYMid meet'); 525 | 526 | const mainGroup = svg.append('g'); 527 | 528 | // Simple fixed positioning without zoom functionality 529 | mainGroup.attr('transform', 'translate(20, 20)'); 530 | 531 | drawImprovedTransformerArchitecture(mainGroup, params, selectedComponent, onComponentClick, width - 40, height - 40); 532 | }, [params, selectedComponent, onComponentClick, containerRef]); 533 | 534 | return svgRef; 535 | }; -------------------------------------------------------------------------------- /src/utils/componentDataGenerator.ts: -------------------------------------------------------------------------------- 1 | import { ComponentData, TransformerParams } from '../types/events'; 2 | import { generateMatrix, generateEmbeddings, generatePositionalEncoding } from './randomWeights'; 3 | import { matmul, softmax } from './math'; 4 | import { SequenceGenerator } from './data'; 5 | 6 | // Helper function to parse layer component IDs 7 | interface LayerComponentInfo { 8 | isLayerComponent: boolean; 9 | layerType: 'encoder' | 'decoder' | null; 10 | componentType: 'attention' | 'masked-attention' | 'cross-attention' | 'ffn' | 'norm' | null; 11 | isNorm: boolean; 12 | } 13 | 14 | function parseLayerComponentId(componentId: string): LayerComponentInfo { 15 | const isEncoderLayer = componentId.startsWith('encoder-layer-'); 16 | const isDecoderLayer = componentId.startsWith('decoder-layer-'); 17 | 18 | if (!isEncoderLayer && !isDecoderLayer) { 19 | return { isLayerComponent: false, layerType: null, componentType: null, isNorm: false }; 20 | } 21 | 22 | const layerType = isEncoderLayer ? 'encoder' : 'decoder'; 23 | const isNorm = componentId.includes('-norm'); 24 | 25 | let componentType: LayerComponentInfo['componentType'] = null; 26 | if (componentId.includes('-masked-attention')) { 27 | componentType = 'masked-attention'; 28 | } else if (componentId.includes('-cross-attention')) { 29 | componentType = 'cross-attention'; 30 | } else if (componentId.includes('-attention')) { 31 | componentType = 'attention'; 32 | } else if (componentId.includes('-ffn')) { 33 | componentType = 'ffn'; 34 | } else if (isNorm) { 35 | componentType = 'norm'; 36 | } 37 | 38 | return { isLayerComponent: true, layerType, componentType, isNorm }; 39 | } 40 | 41 | /** 42 | * Generates detailed component data for transformer visualization including inputs, parameters, and outputs. 43 | * @param componentId - Unique identifier for the transformer component 44 | * @param params - Transformer model parameters (dimensions, layers, etc.) 
45 | * @returns ComponentData object with structured information about the component 46 | */ 47 | export function generateComponentData(componentId: string, params: TransformerParams): ComponentData { 48 | const { dModel, numHeads, seqLen, posEncoding } = params; 49 | 50 | // Handle input tokens 51 | if (componentId === 'input-tokens' || componentId === 'output-tokens') { 52 | return generateInputTokensData(seqLen, componentId); 53 | } 54 | 55 | // Handle embeddings 56 | if (componentId === 'embedding' || componentId === 'decoder-embedding') { 57 | return generateEmbeddingData(seqLen, dModel); 58 | } 59 | 60 | // Handle positional encoding 61 | if (componentId === 'positional-encoding' || componentId === 'decoder-positional-encoding') { 62 | return generatePositionalEncodingData(seqLen, dModel, posEncoding); 63 | } 64 | 65 | // Handle layer components using helper function 66 | const layerInfo = parseLayerComponentId(componentId); 67 | if (layerInfo.isLayerComponent) { 68 | if (layerInfo.isNorm) { 69 | // Add & Norm components 70 | const blockType = componentId.includes('attention') ? 'attention' : 'ffn'; 71 | return generateAddNormData(seqLen, dModel, blockType); 72 | } 73 | 74 | // Non-norm layer components 75 | switch (layerInfo.componentType) { 76 | case 'attention': 77 | return generateAttentionData(seqLen, dModel, numHeads, 'encoder'); 78 | case 'masked-attention': 79 | return generateAttentionData(seqLen, dModel, numHeads, 'decoder-masked'); 80 | case 'cross-attention': 81 | return generateAttentionData(seqLen, dModel, numHeads, 'cross'); 82 | case 'ffn': 83 | return generateFFNData(seqLen, dModel); 84 | } 85 | } 86 | 87 | // Handle final output components 88 | if (componentId === 'predicted-tokens') { 89 | return generatePredictedTokensData(seqLen); 90 | } 91 | 92 | // Handle encoder output 93 | if (componentId === 'encoder-output') { 94 | return generateEncoderOutputData(seqLen, dModel); 95 | } 96 | 97 | // Handle final output (Linear + Softmax) 98 | if (componentId === 'final-output') { 99 | return generateFinalOutputData(seqLen, dModel); 100 | } 101 | 102 | // Default case - determine category from componentId 103 | return generateDefaultData(componentId); 104 | } 105 | 106 | function generateInputTokensData(seqLen: number, componentId: string): ComponentData { 107 | // Generate simple tokens 108 | const tokens = SequenceGenerator.generateForComponent(componentId, seqLen); 109 | 110 | const isOutputTokens = componentId === 'output-tokens'; 111 | const description = isOutputTokens ? 112 | 'Target sequence tokens for decoder (shifted right during training). These guide the model\'s generation process.' : 113 | 'Raw input tokens before embedding. These discrete symbols represent words, subwords, or characters from the input text.'; 114 | 115 | return { 116 | description, 117 | category: 'tokens', 118 | inputs: [], 119 | parameters: [], 120 | outputs: [ 121 | { 122 | id: 'tokens', 123 | label: isOutputTokens ? 'Output Tokens' : 'Input Tokens', 124 | description: `Tokenized sequence (length: ${seqLen}) with special control tokens`, 125 | type: 'tokens', 126 | data: tokens, 127 | shape: [tokens.length], 128 | metadata: { 129 | vocab_size: 50000, 130 | special_tokens: tokens.filter(t => t.startsWith('<')).length, 131 | sequence_type: isOutputTokens ? 
'target' : 'source', 132 | actual_tokens: tokens.filter(t => !t.startsWith('<')).length 133 | } 134 | } 135 | ] 136 | }; 137 | } 138 | 139 | function generateEmbeddingData(seqLen: number, dModel: number): ComponentData { 140 | const embeddings = generateEmbeddings(seqLen, dModel); 141 | const positionalEnc = generatePositionalEncoding(seqLen, dModel, 'sinusoidal'); 142 | 143 | // Generate simple tokens 144 | const tokens = SequenceGenerator.generateSequence(seqLen, 'input'); 145 | 146 | // Combined embeddings with positional encoding 147 | const combinedEmbeddings = embeddings.map((row, i) => 148 | row.map((val, j) => val + positionalEnc[i][j]) 149 | ); 150 | 151 | return { 152 | description: `Converts discrete tokens into dense ${dModel}-dimensional vector representations and combines them with positional encoding to preserve sequence order information.`, 153 | category: 'embedding', 154 | inputs: [ 155 | { 156 | id: 'tokens', 157 | label: 'Input Tokens', 158 | description: 'Discrete token indices from vocabulary', 159 | type: 'tokens', 160 | data: tokens, 161 | shape: [seqLen], 162 | metadata: { vocab_size: 50000, encoding: 'utf-8' } 163 | }, 164 | { 165 | id: 'positional-encoding', 166 | label: 'Positional Encoding (P)', 167 | description: 'Positional information to be added to embeddings', 168 | type: 'matrix', 169 | data: positionalEnc, 170 | shape: [seqLen, dModel], 171 | metadata: { from_component: 'positional-encoding', type: 'sinusoidal' } 172 | } 173 | ], 174 | parameters: [ 175 | { 176 | id: 'embedding-matrix', 177 | label: 'Embedding Matrix (W_e)', 178 | description: 'Learned lookup table mapping token indices to dense vectors', 179 | type: 'matrix', 180 | data: generateMatrix(Math.min(50000, 100), dModel).slice(0, Math.min(20, seqLen + 5)), // Show relevant subset 181 | shape: [50000, dModel], 182 | metadata: { 183 | learnable: true, 184 | initialized: 'xavier_uniform', 185 | parameters: 50000 * dModel, 186 | gradient_norm: 'clipped' 187 | } 188 | } 189 | ], 190 | outputs: [ 191 | { 192 | id: 'positioned-embeddings', 193 | label: 'Positioned Embeddings (H)', 194 | description: 'Token embeddings enhanced with positional information (E + P)', 195 | type: 'matrix', 196 | data: combinedEmbeddings, 197 | shape: [seqLen, dModel], 198 | metadata: { 199 | operation: 'embedding_lookup + positional_encoding', 200 | range: 'normalized', 201 | ready_for_attention: true 202 | } 203 | } 204 | ] 205 | }; 206 | } 207 | 208 | function generatePositionalEncodingData(seqLen: number, dModel: number, posEncoding: string): ComponentData { 209 | const posEnc = generatePositionalEncoding(seqLen, dModel, posEncoding as 'sinusoidal' | 'learned'); 210 | 211 | return { 212 | description: `Generates ${posEncoding} positional information for sequence length ${seqLen}. This helps the model understand the order and position of tokens in the sequence.`, 213 | category: 'positional', 214 | inputs: [], // No inputs - positional encoding is generated independently 215 | parameters: posEncoding === 'learned' ? [ 216 | { 217 | id: 'position-embeddings', 218 | label: 'Learned Position Embeddings (P)', 219 | description: 'Trainable position-specific vectors', 220 | type: 'matrix', 221 | data: posEnc, 222 | shape: [2048, dModel], // Standard max length 223 | metadata: { learnable: true, max_position: 2048 } 224 | } 225 | ] : [], 226 | outputs: [ 227 | { 228 | id: 'positional-encoding', 229 | label: 'Positional Encoding (P)', 230 | description: posEncoding === 'learned' ? 
231 | `Learned positional encodings for ${seqLen} positions` : 232 | `Fixed sinusoidal positional encodings computed using mathematical formulas`, 233 | type: 'matrix', 234 | data: posEnc, 235 | shape: [seqLen, dModel], 236 | metadata: { 237 | learnable: posEncoding === 'learned', 238 | formula: posEncoding === 'sinusoidal' ? 'PE(pos,2i) = sin(pos/10000^(2i/d)), PE(pos,2i+1) = cos(pos/10000^(2i/d))' : 'learned_vectors', 239 | max_position: 2048, 240 | frequency_bands: dModel / 2 241 | } 242 | } 243 | ] 244 | }; 245 | } 246 | 247 | function generateAttentionData(seqLen: number, dModel: number, numHeads: number, attentionType: string = 'encoder'): ComponentData { 248 | const headSize = dModel / numHeads; 249 | const input = generateEmbeddings(seqLen, dModel); 250 | 251 | // Generate Q, K, V matrices for visualization 252 | const q = generateMatrix(seqLen, headSize); 253 | const k = generateMatrix(seqLen, headSize); 254 | const v = generateMatrix(seqLen, headSize); 255 | 256 | // Compute attention scores with proper scaling 257 | const scale = 1 / Math.sqrt(headSize); 258 | const kTransposed = k[0].map((_, colIndex) => k.map(row => row[colIndex])); 259 | const rawScores = matmul(q, kTransposed); 260 | let scaledScores = rawScores.map(row => row.map(val => val * scale)); 261 | 262 | // Apply masking for decoder attention 263 | if (attentionType === 'decoder-masked') { 264 | for (let i = 0; i < seqLen; i++) { 265 | for (let j = i + 1; j < seqLen; j++) { 266 | scaledScores[i][j] = -Infinity; 267 | } 268 | } 269 | } 270 | 271 | const attentionWeights = softmax(scaledScores); 272 | const output = matmul(attentionWeights, v); 273 | 274 | // Simplified descriptions based on attention type 275 | const descriptions = { 276 | 'encoder': `Multi-head self-attention with ${numHeads} heads processing ${seqLen} tokens. Each token can attend to all other tokens in the sequence to build contextual representations.`, 277 | 'decoder-masked': `Masked self-attention with ${numHeads} heads for autoregressive generation. Tokens can only attend to previous positions to maintain causality during generation.`, 278 | 'cross': `Cross-attention with ${numHeads} heads connecting encoder and decoder. Decoder queries attend to encoder keys and values, enabling the model to focus on relevant source information.` 279 | }; 280 | 281 | const description = descriptions[attentionType as keyof typeof descriptions] || descriptions.encoder; 282 | 283 | // Create inputs array based on attention type 284 | const inputs = [ 285 | { 286 | id: 'input', 287 | label: attentionType === 'cross' ? 'Decoder Input (X)' : 'Input Embeddings (X)', 288 | description: attentionType === 'cross' ? 
289 | 'Input from previous decoder layer (for queries)' : 290 | 'Input from previous layer (embeddings or previous attention block)', 291 | type: 'matrix' as const, 292 | data: input, 293 | shape: [seqLen, dModel] as [number, number], 294 | metadata: { from_previous_layer: true, sequence_length: seqLen } 295 | } 296 | ]; 297 | 298 | // Add encoder output as input for cross-attention 299 | if (attentionType === 'cross') { 300 | inputs.push({ 301 | id: 'encoder-output', 302 | label: 'Encoder Output (H_enc)', 303 | description: 'Encoder output used as keys and values for cross-attention', 304 | type: 'matrix' as const, 305 | data: generateEmbeddings(seqLen, dModel), 306 | shape: [seqLen, dModel] as [number, number], 307 | metadata: { from_encoder: true, used_for_keys_values: true } as any 308 | }); 309 | } 310 | 311 | return { 312 | description, 313 | category: 'attention', 314 | inputs, 315 | parameters: [ 316 | { 317 | id: 'wq', 318 | label: 'Query Weight Matrix (W_Q)', 319 | description: 'Projects input to query space for attention computation', 320 | type: 'matrix', 321 | data: generateMatrix(dModel, dModel), 322 | shape: [dModel, dModel], 323 | metadata: { learnable: true, heads: numHeads, head_dim: headSize } 324 | }, 325 | { 326 | id: 'wk', 327 | label: 'Key Weight Matrix (W_K)', 328 | description: 'Projects input to key space for attention computation', 329 | type: 'matrix', 330 | data: generateMatrix(dModel, dModel), 331 | shape: [dModel, dModel], 332 | metadata: { learnable: true, heads: numHeads, head_dim: headSize } 333 | }, 334 | { 335 | id: 'wv', 336 | label: 'Value Weight Matrix (W_V)', 337 | description: 'Projects input to value space for attention computation', 338 | type: 'matrix', 339 | data: generateMatrix(dModel, dModel), 340 | shape: [dModel, dModel], 341 | metadata: { learnable: true, heads: numHeads, head_dim: headSize } 342 | }, 343 | { 344 | id: 'wo', 345 | label: 'Output Weight Matrix (W_O)', 346 | description: 'Projects concatenated multi-head output back to model dimension', 347 | type: 'matrix', 348 | data: generateMatrix(dModel, dModel), 349 | shape: [dModel, dModel], 350 | metadata: { learnable: true, final_projection: true, total_params: 4 * dModel * dModel } 351 | } 352 | ], 353 | outputs: [ 354 | { 355 | id: 'attention-weights', 356 | label: 'Attention Weights (α)', 357 | description: `Attention probability distribution (${seqLen}×${seqLen}) showing which tokens attend to which`, 358 | type: 'matrix', 359 | data: attentionWeights, 360 | shape: [seqLen, seqLen], 361 | metadata: { 362 | normalized: true, 363 | range: [0, 1], 364 | scaled: true, 365 | masking: attentionType === 'decoder-masked' ? 
'causal' : 'none', 366 | sequence_length: seqLen 367 | } 368 | }, 369 | { 370 | id: 'attention-output', 371 | label: 'Attention Output (Z)', 372 | description: 'Weighted combination of values based on attention weights', 373 | type: 'matrix', 374 | data: output, // computed for a single head (seqLen x headSize); the declared shape reflects the full multi-head output 375 | shape: [seqLen, dModel], 376 | metadata: { 377 | multi_head: true, 378 | heads: numHeads, 379 | attention_type: attentionType, 380 | ready_for_residual: true 381 | } 382 | } 383 | ] 384 | }; 385 | } 386 | 387 | function generateAddNormData(seqLen: number, dModel: number, blockType: 'attention' | 'ffn'): ComponentData { 388 | const residual = generateEmbeddings(seqLen, dModel); 389 | const blockOutput = generateEmbeddings(seqLen, dModel); 390 | const added = residual.map((row, i) => row.map((val, j) => val + blockOutput[i][j])); 391 | 392 | // Simple layer norm simulation 393 | const normalized = added.map(row => { 394 | const mean = row.reduce((sum, val) => sum + val, 0) / row.length; 395 | const variance = row.reduce((sum, val) => sum + Math.pow(val - mean, 2), 0) / row.length; 396 | const std = Math.sqrt(variance + 1e-5); 397 | return row.map(val => (val - mean) / std); 398 | }); 399 | 400 | return { 401 | description: `Residual connection and layer normalization after the ${blockType} block. The residual connection helps with gradient flow during training, while layer normalization stabilizes training by normalizing inputs to have zero mean and unit variance.`, 402 | category: blockType, 403 | inputs: [ 404 | { 405 | id: 'residual', 406 | label: 'Residual Input (X)', 407 | description: 'Original input to be added back (skip connection)', 408 | type: 'matrix', 409 | data: residual, 410 | shape: [seqLen, dModel], 411 | metadata: { skip_connection: true, gradient_highway: true } 412 | }, 413 | { 414 | id: 'block-output', 415 | label: `${blockType.charAt(0).toUpperCase() + blockType.slice(1)} Output (${blockType === 'attention' ? 
'Z' : 'Y'})`, 416 | description: `Output from the ${blockType} block`, 417 | type: 'matrix', 418 | data: blockOutput, 419 | shape: [seqLen, dModel], 420 | metadata: { from_block: blockType, transformed: true } 421 | } 422 | ], 423 | parameters: [ 424 | { 425 | id: 'gamma', 426 | label: 'Layer Norm Scale (γ)', 427 | description: 'Learnable scale parameter for layer normalization', 428 | type: 'vector', 429 | data: Array(dModel).fill(1), 430 | shape: [dModel], 431 | metadata: { learnable: true, initialized: 'ones', adaptive_scaling: true } 432 | }, 433 | { 434 | id: 'beta', 435 | label: 'Layer Norm Bias (β)', 436 | description: 'Learnable bias parameter for layer normalization', 437 | type: 'vector', 438 | data: Array(dModel).fill(0), 439 | shape: [dModel], 440 | metadata: { learnable: true, initialized: 'zeros', shift_parameter: true } 441 | } 442 | ], 443 | outputs: [ 444 | { 445 | id: 'normalized', 446 | label: 'Normalized Output (H_norm)', 447 | description: 'Layer normalized output after residual connection', 448 | type: 'matrix', 449 | data: normalized, 450 | shape: [seqLen, dModel], 451 | metadata: { 452 | mean: 0, 453 | variance: 1, 454 | operation: 'add_then_norm', 455 | stable_gradients: true, 456 | ready_for_next_layer: true 457 | } 458 | } 459 | ] 460 | }; 461 | } 462 | 463 | function generateFFNData(seqLen: number, dModel: number): ComponentData { 464 | const input = generateEmbeddings(seqLen, dModel); 465 | const hiddenDim = dModel * 4; // Typical FFN expansion 466 | 467 | const w1 = generateMatrix(dModel, hiddenDim); 468 | const w2 = generateMatrix(hiddenDim, dModel); 469 | 470 | // Simulate FFN computation (simplified for visualization) 471 | const hidden = generateMatrix(seqLen, hiddenDim); 472 | const activated = hidden.map(row => row.map(val => Math.max(0, val))); // ReLU 473 | const output = generateMatrix(seqLen, dModel); 474 | 475 | return { 476 | description: `Position-wise feed-forward network with expansion factor 4 (${dModel} → ${hiddenDim} → ${dModel}). 
Each position is processed independently with two linear transformations and ReLU activation.`, 477 | category: 'ffn', 478 | inputs: [ 479 | { 480 | id: 'input', 481 | label: 'Input from Attention (X)', 482 | description: 'Normalized output from attention block', 483 | type: 'matrix', 484 | data: input, 485 | shape: [seqLen, dModel], 486 | metadata: { from_attention: true, sequence_parallel: true } 487 | } 488 | ], 489 | parameters: [ 490 | { 491 | id: 'w1', 492 | label: 'First Linear Layer (W₁)', 493 | description: `Expands dimension from ${dModel} to ${hiddenDim} (4× expansion)`, 494 | type: 'matrix', 495 | data: w1, 496 | shape: [dModel, hiddenDim], 497 | metadata: { learnable: true, expansion_factor: 4, parameters: dModel * hiddenDim } 498 | }, 499 | { 500 | id: 'b1', 501 | label: 'First Layer Bias (b₁)', 502 | description: 'Bias for first linear transformation', 503 | type: 'vector', 504 | data: Array(hiddenDim).fill(0), 505 | shape: [hiddenDim], 506 | metadata: { learnable: true, initialized: 'zeros' } 507 | }, 508 | { 509 | id: 'w2', 510 | label: 'Second Linear Layer (W₂)', 511 | description: `Projects back from ${hiddenDim} to ${dModel}`, 512 | type: 'matrix', 513 | data: w2, 514 | shape: [hiddenDim, dModel], 515 | metadata: { learnable: true, projection_back: true, parameters: hiddenDim * dModel } 516 | }, 517 | { 518 | id: 'b2', 519 | label: 'Second Layer Bias (b₂)', 520 | description: 'Bias for second linear transformation', 521 | type: 'vector', 522 | data: Array(dModel).fill(0), 523 | shape: [dModel], 524 | metadata: { learnable: true, initialized: 'zeros' } 525 | } 526 | ], 527 | outputs: [ 528 | { 529 | id: 'hidden', 530 | label: 'Hidden Representation (H_hidden)', 531 | description: `Intermediate representation (${hiddenDim}D) after first linear layer and ReLU`, 532 | type: 'matrix', 533 | data: activated, 534 | shape: [seqLen, hiddenDim], 535 | metadata: { activation: 'ReLU', expanded: true, sparse: true } 536 | }, 537 | { 538 | id: 'output', 539 | label: 'FFN Output (Y)', 540 | description: 'Final output after second linear transformation', 541 | type: 'matrix', 542 | data: output, 543 | shape: [seqLen, dModel], 544 | metadata: { ready_for_residual: true, position_wise: true } 545 | } 546 | ] 547 | }; 548 | } 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | function generatePredictedTokensData(seqLen: number): ComponentData { 557 | // Generate simple predicted tokens 558 | const predictedTokens = SequenceGenerator.generateSequence(seqLen, 'output'); 559 | 560 | return { 561 | description: `Final predicted tokens from the transformer model after sampling from the probability distribution. 
These represent the model's best guess for the target sequence.`, 562 | category: 'output', 563 | inputs: [ 564 | { 565 | id: 'probabilities', 566 | label: 'Token Probabilities (P)', 567 | description: 'Probability distribution over vocabulary from softmax', 568 | type: 'matrix', 569 | data: generateMatrix(seqLen, 100), // vocabulary subset shown for visualization; declared shape is the full vocabulary 570 | shape: [seqLen, 50000], 571 | metadata: { from_softmax: true, ready_for_sampling: true } 572 | } 573 | ], 574 | parameters: [ 575 | { 576 | id: 'sampling-strategy', 577 | label: 'Sampling Strategy (σ)', 578 | description: 'Method for selecting tokens from probability distribution', 579 | type: 'text', 580 | data: 'greedy', 581 | metadata: { 582 | options: ['greedy', 'top_k', 'top_p', 'beam_search'], 583 | current: 'greedy', 584 | deterministic: true 585 | } 586 | } 587 | ], 588 | outputs: [ 589 | { 590 | id: 'predicted-tokens', 591 | label: 'Predicted Tokens (ŷ)', 592 | description: 'Final predicted sequence', 593 | type: 'tokens', 594 | data: predictedTokens, 595 | shape: [predictedTokens.length], 596 | metadata: { 597 | sampling_method: 'greedy', 598 | confidence: 0.85, 599 | sequence_length: seqLen, 600 | generated: true 601 | } 602 | } 603 | ] 604 | }; 605 | } 606 | 607 | function generateEncoderOutputData(seqLen: number, dModel: number): ComponentData { 608 | const input = generateEmbeddings(seqLen, dModel); 609 | const output = generateEmbeddings(seqLen, dModel); 610 | 611 | return { 612 | description: `Final output of the encoder block, which consists of a series of encoder layers. This output contains rich contextual representations that are passed to the decoder for cross-attention and subsequent decoding.`, 613 | category: 'output', 614 | inputs: [ 615 | { 616 | id: 'final-encoder-layer-output', 617 | label: 'Final Encoder Layer Output (H_N)', 618 | description: 'Output from the last encoder layer (normalized)', 619 | type: 'matrix', 620 | data: input, 621 | shape: [seqLen, dModel], 622 | metadata: { from_final_encoder_layer: true, sequence_length: seqLen } 623 | } 624 | ], 625 | parameters: [], 626 | outputs: [ 627 | { 628 | id: 'encoder-output', 629 | label: 'Encoder Output (H_enc)', 630 | description: 'Final contextual representations from the encoder block', 631 | type: 'matrix', 632 | data: output, 633 | shape: [seqLen, dModel], 634 | metadata: { 635 | from_encoder: true, 636 | contextualized: true, 637 | used_for_cross_attention: true, 638 | all_layers_processed: true 639 | } 640 | } 641 | ] 642 | }; 643 | } 644 | 645 | function generateFinalOutputData(seqLen: number, dModel: number): ComponentData { 646 | const input = generateEmbeddings(seqLen, dModel); 647 | const logits = generateMatrix(seqLen, Math.min(50000, 100)); // Show subset for visualization 648 | const probs = softmax(logits); 649 | 650 | return { 651 | description: `Combines the final linear layer (language model head) and softmax function to produce the final probability distribution over the vocabulary. 
This is the complete output processing pipeline of the transformer model.`, 652 | category: 'output', 653 | inputs: [ 654 | { 655 | id: 'decoder-output', 656 | label: 'Decoder Output (H_dec)', 657 | description: 'Final hidden states from the decoder', 658 | type: 'matrix', 659 | data: input, 660 | shape: [seqLen, dModel], 661 | metadata: { from_decoder: true, final_layer: true } 662 | } 663 | ], 664 | parameters: [ 665 | { 666 | id: 'vocab-projection-weights', 667 | label: 'Vocabulary Projection Weights (W_v)', 668 | description: 'Weight matrix for projecting to vocabulary size', 669 | type: 'matrix', 670 | data: generateMatrix(dModel, Math.min(50000, 100)), 671 | shape: [dModel, 50000], 672 | metadata: { learnable: true, vocab_size: 50000 } 673 | }, 674 | { 675 | id: 'vocab-projection-bias', 676 | label: 'Vocabulary Projection Bias (b_v)', 677 | description: 'Bias vector for vocabulary projection', 678 | type: 'vector', 679 | data: Array(Math.min(50000, 100)).fill(0), 680 | shape: [50000], 681 | metadata: { learnable: true, initialized: 'zeros' } 682 | } 683 | ], 684 | outputs: [ 685 | { 686 | id: 'token-probabilities', 687 | label: 'Token Probabilities (P)', 688 | description: 'Final probability distribution over vocabulary', 689 | type: 'matrix', 690 | data: probs, 691 | shape: [seqLen, 50000], 692 | metadata: { 693 | normalized: true, 694 | sum_to_one: true, 695 | ready_for_sampling: true, 696 | temperature_applied: true 697 | } 698 | } 699 | ] 700 | }; 701 | } 702 | 703 | function generateDefaultData(componentId: string): ComponentData { 704 | // Determine category based on component ID 705 | let category: ComponentData['category'] = 'embedding'; 706 | 707 | if (componentId.includes('token')) { 708 | category = 'tokens'; 709 | } else if (componentId.includes('attention')) { 710 | category = 'attention'; 711 | } else if (componentId.includes('ffn') || componentId.includes('feed')) { 712 | category = 'ffn'; 713 | } else if (componentId.includes('output') || componentId.includes('linear') || componentId.includes('softmax')) { 714 | category = 'output'; 715 | } else if (componentId.includes('positional') || componentId.includes('position')) { 716 | category = 'positional'; 717 | } 718 | 719 | return { 720 | description: `Component details for ${componentId}. This component is part of the transformer architecture and contributes to the overall sequence processing.`, 721 | category, 722 | inputs: [], 723 | parameters: [], 724 | outputs: [] 725 | }; 726 | } --------------------------------------------------------------------------------
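A minimal usage sketch of generateComponentData (illustrative only, not a file from the repository; import paths are written relative to the repository root, the parameter values are hypothetical, and TransformerParams is assumed to carry at least the fields the generator destructures, dModel, numHeads, seqLen, and posEncoding, plus the numLayers field read by the diagram hook):

import { generateComponentData } from './src/utils/componentDataGenerator';
import type { TransformerParams } from './src/types/events';

// Hypothetical parameter values for a small demo model; the cast hedges
// any additional fields the real TransformerParams type may declare.
const params = {
  dModel: 64,
  numHeads: 4,
  seqLen: 8,
  numLayers: 2,
  posEncoding: 'sinusoidal',
} as TransformerParams;

// Masked self-attention in the first decoder layer: parseLayerComponentId
// routes this ID to generateAttentionData(..., 'decoder-masked').
const data = generateComponentData('decoder-layer-0-masked-attention', params);

console.log(data.category);         // 'attention'
console.log(data.outputs[0].shape); // [8, 8]: seqLen x seqLen attention weights, causally masked

// Assuming the softmax helper in utils/math maps the -Infinity masked scores
// to zero probability mass, position 0 can attend only to itself, so
// data.outputs[0].data[0] is approximately [1, 0, 0, 0, 0, 0, 0, 0].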