├── postcss.config.js
├── .gitignore
├── tsconfig.node.json
├── src
│   ├── index.tsx
│   ├── hooks
│   │   ├── useEventLogger.ts
│   │   ├── useTransformerMachine.ts
│   │   └── useTransformerDiagram.ts
│   ├── index.css
│   ├── utils
│   │   ├── math.ts
│   │   ├── randomWeights.ts
│   │   ├── data.ts
│   │   ├── constants.ts
│   │   ├── componentTransformations.ts
│   │   └── componentDataGenerator.ts
│   ├── types
│   │   └── events.d.ts
│   ├── App.tsx
│   ├── test
│   │   └── setup.ts
│   ├── components
│   │   ├── TokenVisualization.tsx
│   │   ├── TransformerDiagram.tsx
│   │   ├── HelpMenu.tsx
│   │   ├── MatrixVisualization.tsx
│   │   ├── SettingsMenu.tsx
│   │   └── ComponentDetailsPanel.tsx
│   └── state
│       └── transformerMachine.ts
├── vite.config.ts
├── tailwind.config.js
├── index.html
├── tsconfig.json
├── .eslintrc.json
├── package.json
├── .github
│   └── workflows
│       └── build-release.yml
├── LICENSE
├── server
│   └── server.py
└── README.md
/postcss.config.js:
--------------------------------------------------------------------------------
1 | export default {
2 | plugins: {
3 | tailwindcss: {},
4 | autoprefixer: {},
5 | },
6 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/
2 | dist/
3 | .env
4 | .env.local
5 | .env.development.local
6 | .env.test.local
7 | .env.production.local
8 | .env.development
9 | .env.test
10 | .env.production
11 |
--------------------------------------------------------------------------------
/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "composite": true,
4 | "skipLibCheck": true,
5 | "module": "ESNext",
6 | "moduleResolution": "bundler",
7 | "allowSyntheticDefaultImports": true
8 | },
9 | "include": ["vite.config.ts"]
10 | }
--------------------------------------------------------------------------------
/src/index.tsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import App from './App';
4 | import './index.css';
5 |
6 | const root = ReactDOM.createRoot(
7 | document.getElementById('root') as HTMLElement
8 | );
9 |
10 | root.render(
11 | <React.StrictMode>
12 | <App />
13 | </React.StrictMode>
14 | );
--------------------------------------------------------------------------------
/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vite'
2 | import react from '@vitejs/plugin-react'
3 |
4 | // https://vitejs.dev/config/
5 | export default defineConfig({
6 | plugins: [react()],
7 | server: {
8 | proxy: {
9 | '/log': {
10 | target: 'http://localhost:3001',
11 | changeOrigin: true,
12 | }
13 | }
14 | }
15 | })
--------------------------------------------------------------------------------
/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | export default {
3 | content: [
4 | "./index.html",
5 | "./src/**/*.{js,ts,jsx,tsx}",
6 | ],
7 | theme: {
8 | extend: {
9 | colors: {
10 | 'cs-blue': '#1062fb',
11 | 'cs-gold': '#ffc500',
12 | 'cs-sky': '#64d3ff',
13 | 'cs-deep': '#002570',
14 | },
15 | fontFamily: {
16 | sans: ['Inter', 'sans-serif'],
17 | },
18 | },
19 | },
20 | plugins: [],
21 | }
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 | <!doctype html>
2 | <html lang="en">
3 | <head>
4 | <meta charset="UTF-8" />
5 | <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6 | <title>Travel Through Transformers</title>
7 | </head>
8 | <body>
9 | <div id="root"></div>
10 | <script type="module" src="/src/index.tsx"></script>
11 | </body>
12 | </html>
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "useDefineForClassFields": true,
5 | "lib": ["ES2020", "DOM", "DOM.Iterable"],
6 | "module": "ESNext",
7 | "skipLibCheck": true,
8 |
9 | /* Bundler mode */
10 | "moduleResolution": "bundler",
11 | "allowImportingTsExtensions": true,
12 | "resolveJsonModule": true,
13 | "isolatedModules": true,
14 | "noEmit": true,
15 | "jsx": "react-jsx",
16 |
17 | /* Linting */
18 | "strict": true,
19 | "noUnusedLocals": true,
20 | "noUnusedParameters": true,
21 | "noFallthroughCasesInSwitch": true
22 | },
23 | "include": ["src"],
24 | "references": [{ "path": "./tsconfig.node.json" }]
25 | }
--------------------------------------------------------------------------------
/src/hooks/useEventLogger.ts:
--------------------------------------------------------------------------------
1 | import { useCallback } from 'react';
2 | import { LogEvent } from '../types/events';
3 |
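4 | /**
5 |  * React hook exposing logEvent(), which POSTs a single interaction event to
6 |  * the server's /log endpoint. Failures only produce a console warning, so
7 |  * logging can never interrupt the UI.
8 |  */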
4 | export function useEventLogger() {
5 | const logEvent = useCallback(async (eventType: LogEvent['event_type'], payload: Record<string, any>) => {
6 | const event: LogEvent = {
7 | timestamp: Date.now(),
8 | event_type: eventType,
9 | payload,
10 | };
11 |
12 | try {
13 | await fetch('/log', {
14 | method: 'POST',
15 | headers: {
16 | 'Content-Type': 'application/json',
17 | },
18 | body: JSON.stringify(event),
19 | });
20 | } catch (error) {
21 | console.warn('Failed to log event:', error);
22 | }
23 | }, []);
24 |
25 | return { logEvent };
26 | }
--------------------------------------------------------------------------------
/.eslintrc.json:
--------------------------------------------------------------------------------
1 | {
2 | "root": true,
3 | "env": {
4 | "browser": true,
5 | "es2020": true
6 | },
7 | "extends": [
8 | "eslint:recommended",
9 | "plugin:react-hooks/recommended"
10 | ],
11 | "ignorePatterns": ["dist", "*.config.js", "*.config.ts"],
12 | "parser": "@typescript-eslint/parser",
13 | "plugins": ["react-refresh", "@typescript-eslint"],
14 | "rules": {
15 | "react-refresh/only-export-components": [
16 | "warn",
17 | { "allowConstantExport": true }
18 | ],
19 | "no-unused-vars": "off",
20 | "@typescript-eslint/no-unused-vars": ["error", { "argsIgnorePattern": "^_" }],
21 | "prefer-const": "error",
22 | "no-var": "error",
23 | "no-console": ["warn", { "allow": ["warn", "error"] }]
24 | }
25 | }
--------------------------------------------------------------------------------
/src/index.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
5 | @layer base {
6 | html {
7 | font-family: Inter, system-ui, sans-serif;
8 | }
9 | }
10 |
11 | @layer components {
12 | .slider::-webkit-slider-thumb {
13 | appearance: none;
14 | height: 20px;
15 | width: 20px;
16 | border-radius: 50%;
17 | background: #1062fb;
18 | cursor: pointer;
19 | border: 2px solid #ffffff;
20 | box-shadow: 0 0 0 1px rgba(16, 98, 251, 0.2);
21 | }
22 |
23 | .slider::-moz-range-thumb {
24 | height: 20px;
25 | width: 20px;
26 | border-radius: 50%;
27 | background: #1062fb;
28 | cursor: pointer;
29 | border: 2px solid #ffffff;
30 | box-shadow: 0 0 0 1px rgba(16, 98, 251, 0.2);
31 | }
32 |
33 | .slider:focus::-webkit-slider-thumb {
34 | box-shadow: 0 0 0 3px rgba(16, 98, 251, 0.3);
35 | }
36 |
37 | .slider:focus::-moz-range-thumb {
38 | box-shadow: 0 0 0 3px rgba(16, 98, 251, 0.3);
39 | }
40 | }
41 |
42 | @layer utilities {
43 | .text-balance {
44 | text-wrap: balance;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/utils/math.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Performs matrix multiplication between two 2D arrays.
3 | * @param a - First matrix (m x k)
4 | * @param b - Second matrix (k x n)
5 | * @returns Result matrix (m x n)
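 | * Example: matmul([[1, 2]], [[3], [4]]) yields [[11]] (1*3 + 2*4).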
6 | */
7 | export function matmul(a: number[][], b: number[][]): number[][] {
8 | const result: number[][] = [];
9 | const rows = a.length;
10 | const cols = b[0].length;
11 | const inner = a[0].length;
12 |
13 | for (let i = 0; i < rows; i++) {
14 | result[i] = [];
15 | for (let j = 0; j < cols; j++) {
16 | result[i][j] = 0;
17 | for (let k = 0; k < inner; k++) {
18 | result[i][j] += a[i][k] * b[k][j];
19 | }
20 | }
21 | }
22 |
23 | return result;
24 | }
25 |
26 | /**
27 | * Applies softmax function row-wise to a 2D matrix.
28 | * @param matrix - Input matrix where each row will be normalized
29 | * @returns Matrix with each row normalized to sum to 1
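 | * Example: softmax([[1, 2, 3]]) ≈ [[0.090, 0.245, 0.665]].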
30 | */
31 | export function softmax(matrix: number[][]): number[][] {
32 | return matrix.map(row => {
33 | const max = Math.max(...row);
34 | const exp = row.map(x => Math.exp(x - max));
35 | const sum = exp.reduce((a, b) => a + b, 0);
36 | return exp.map(x => x / sum);
37 | });
38 | }
--------------------------------------------------------------------------------
/src/types/events.d.ts:
--------------------------------------------------------------------------------
1 | import { ZoomLevel as ImportedZoomLevel } from '../utils/constants';
2 |
3 | export type ZoomLevel = ImportedZoomLevel;
4 | export type EventType = 'step' | 'param_change' | 'toggle' | 'token_change' | 'zoom_change' | 'component_select' | 'animation_control' | 'attention_lens' | 'sequence_generation';
5 |
6 | export interface LogEvent {
7 | timestamp: number;
8 | event_type: EventType;
9 | payload: Record<string, any>;
10 | }
11 |
12 | export interface TransformerParams {
13 | numLayers: number;
14 | dModel: number;
15 | numHeads: number;
16 | seqLen: number;
17 | posEncoding: 'learned' | 'sinusoidal';
18 | dropout: boolean;
19 | }
20 |
21 | export interface ComponentData {
22 | inputs: ComponentSection[];
23 | parameters: ComponentSection[];
24 | outputs: ComponentSection[];
25 | description: string;
26 | category: 'embedding' | 'attention' | 'ffn' | 'output' | 'tokens' | 'positional';
27 | }
28 |
29 | export interface ComponentSection {
30 | id: string;
31 | label: string;
32 | description: string;
33 | type: 'matrix' | 'vector' | 'scalar' | 'tokens' | 'text';
34 | data?: number[][] | number[] | number | string[] | string;
35 | shape?: [number, number] | [number];
36 | metadata?: {
37 | [key: string]: any;
38 | };
39 | }
40 |
41 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "travel-through-transformers",
3 | "private": true,
4 | "version": "0.1.0",
5 | "type": "module",
6 | "scripts": {
7 | "dev": "vite",
8 | "build": "tsc && vite build",
9 | "preview": "vite preview",
10 | "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
11 | "test": "vitest",
12 | "test:e2e": "playwright test"
13 | },
14 | "dependencies": {
15 | "@heroicons/react": "^2.0.18",
16 | "@xstate/react": "^6.0.0",
17 | "d3": "^7.8.5",
18 | "i18next": "^25.3.0",
19 | "react": "^18.2.0",
20 | "react-dom": "^18.2.0",
21 | "react-i18next": "^15.5.3",
22 | "reactflow": "^11.11.4",
23 | "xstate": "^5.20.0"
24 | },
25 | "devDependencies": {
26 | "@playwright/test": "^1.43.0",
27 | "@testing-library/jest-dom": "^6.4.5",
28 | "@testing-library/react": "^15.0.6",
29 | "@types/d3": "^7.4.3",
30 | "@types/react": "^18.2.66",
31 | "@types/react-dom": "^18.2.22",
32 | "@typescript-eslint/eslint-plugin": "^7.2.0",
33 | "@typescript-eslint/parser": "^7.2.0",
34 | "@vitejs/plugin-react": "^4.2.1",
35 | "autoprefixer": "^10.4.19",
36 | "eslint": "^8.57.0",
37 | "eslint-plugin-react-hooks": "^4.6.0",
38 | "eslint-plugin-react-refresh": "^0.4.6",
39 | "postcss": "^8.4.38",
40 | "tailwindcss": "^3.4.3",
41 | "typescript": "^5.2.2",
42 | "vite": "^5.2.0",
43 | "vitest": "^1.6.0"
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/.github/workflows/build-release.yml:
--------------------------------------------------------------------------------
1 | name: Build and Release
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 |
7 | permissions:
8 | contents: write
9 |
10 | jobs:
11 | build-and-release:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 |
18 | - name: Setup Node.js
19 | uses: actions/setup-node@v4
20 | with:
21 | node-version: '20'
22 | cache: 'npm'
23 |
24 | - name: Install dependencies
25 | run: npm ci
26 |
27 | - name: Build project
28 | run: npm run build
29 |
30 | - name: Archive build output
31 | run: tar -czf dist.tar.gz dist
32 |
33 | - name: Upload build artifact (for workflow logs)
34 | uses: actions/upload-artifact@v4
35 | with:
36 | name: dist
37 | path: dist
38 |
39 | - name: Create GitHub Release and upload asset
40 | uses: ncipollo/release-action@v1
41 | with:
42 | token: ${{ secrets.GITHUB_TOKEN }}
43 | tag: release-${{ github.run_id }}
44 | name: Release ${{ github.run_number }}
45 | body: |
46 | Automated release for commit ${{ github.sha }} on branch ${{ github.ref_name }}.
47 | - Trigger: push to main
48 | - Run number: ${{ github.run_number }}
49 | artifacts: dist.tar.gz
50 | allowUpdates: false
51 | draft: false
52 | prerelease: false
53 | makeLatest: true
54 |
--------------------------------------------------------------------------------
/src/utils/randomWeights.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Random weight generation utilities with seeded randomness for reproducible results.
3 | * Used to generate matrices, vectors, embeddings, and positional encodings for the simulation.
4 | */
5 | class SeededRandom {
6 | private seed: number;
7 |
8 | constructor(seed: number = 42) {
9 | this.seed = seed;
10 | }
11 |
12 | next(): number {
13 | // Simple linear congruential generator
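 | // using the multiplier/increment constants popularized by Numerical Recipes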
14 | this.seed = (this.seed * 1664525 + 1013904223) % (2 ** 32);
15 | return this.seed / (2 ** 32);
16 | }
17 |
18 | uniform(min: number = 0, max: number = 1): number {
19 | return min + (max - min) * this.next();
20 | }
21 |
22 | normal(mean: number = 0, std: number = 1): number {
23 | // Box-Muller transform
24 | const u1 = this.next();
25 | const u2 = this.next();
26 | const z0 = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
27 | return mean + std * z0;
28 | }
29 | }
30 |
31 | const rng = new SeededRandom(42);
32 |
33 | /**
34 | * Generates a random matrix with specified dimensions and distribution.
35 | * @param rows - Number of rows
36 | * @param cols - Number of columns
37 | * @param distribution - Random distribution type ('uniform' or 'normal')
38 | * @returns 2D array of random numbers
39 | */
40 | export function generateMatrix(rows: number, cols: number, distribution: 'uniform' | 'normal' = 'normal'): number[][] {
41 | const matrix: number[][] = [];
42 |
43 | for (let i = 0; i < rows; i++) {
44 | matrix[i] = [];
45 | for (let j = 0; j < cols; j++) {
46 | if (distribution === 'uniform') {
47 | matrix[i][j] = rng.uniform(-1, 1);
48 | } else {
49 | matrix[i][j] = rng.normal(0, 0.1);
50 | }
51 | }
52 | }
53 |
54 | return matrix;
55 | }
56 |
57 | export function generateVector(size: number, distribution: 'uniform' | 'normal' = 'normal'): number[] {
58 | const vector: number[] = [];
59 |
60 | for (let i = 0; i < size; i++) {
61 | if (distribution === 'uniform') {
62 | vector[i] = rng.uniform(-1, 1);
63 | } else {
64 | vector[i] = rng.normal(0, 0.1);
65 | }
66 | }
67 |
68 | return vector;
69 | }
70 |
71 | export function generateEmbeddings(vocabSize: number, dModel: number): number[][] {
72 | return generateMatrix(vocabSize, dModel, 'normal');
73 | }
74 |
75 | /**
76 | * Generates positional encoding for transformer models.
77 | * @param seqLen - Sequence length
78 | * @param dModel - Model dimension
79 | * @param type - Type of encoding ('learned' uses random weights, 'sinusoidal' uses mathematical formula)
80 | * @returns 2D array representing positional encoding
81 | */
82 | export function generatePositionalEncoding(seqLen: number, dModel: number, type: 'learned' | 'sinusoidal'): number[][] {
83 | if (type === 'learned') {
84 | return generateMatrix(seqLen, dModel, 'normal');
85 | }
86 |
87 | // Sinusoidal positional encoding
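 | // PE(pos, 2i)   = sin(pos / 10000^(2i / dModel))
 | // PE(pos, 2i+1) = cos(pos / 10000^(2i / dModel))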
88 | const pos: number[][] = [];
89 |
90 | for (let i = 0; i < seqLen; i++) {
91 | pos[i] = [];
92 | for (let j = 0; j < dModel; j++) {
93 | const angle = i / Math.pow(10000, (2 * Math.floor(j / 2)) / dModel);
94 | pos[i][j] = j % 2 === 0 ? Math.sin(angle) : Math.cos(angle);
95 | }
96 | }
97 |
98 | return pos;
99 | }
--------------------------------------------------------------------------------
/src/App.tsx:
--------------------------------------------------------------------------------
1 | import { useTransformerMachine } from './hooks/useTransformerMachine';
2 | import { TransformerDiagram } from './components/TransformerDiagram';
3 | import { ComponentDetailsPanel } from './components/ComponentDetailsPanel';
4 | import { SettingsMenu } from './components/SettingsMenu';
5 | import { HelpMenu } from './components/HelpMenu';
6 | import { generateComponentData } from './utils/componentDataGenerator';
7 |
8 | function App() {
9 | const { state, actions } = useTransformerMachine();
10 |
11 | // Extract commonly used values from state
12 | const { params, selectedComponent } = state;
13 |
14 | // Component selection handlers
15 | const handleComponentClick = (componentId: string) => {
16 | actions.setZoomLevel('sub_layer', componentId);
17 | };
18 |
19 | const handleComponentClose = () => {
20 | actions.setZoomLevel('global');
21 | };
22 |
23 | // Generate component data for the enhanced panel
24 | const componentData = selectedComponent ? generateComponentData(selectedComponent, params) : null;
25 |
26 | return (
27 |
28 | {/* Header */}
29 |
30 |
31 |
32 |
33 |
34 |
35 | Travel Through Transformers
36 |
37 |
38 |
39 |
40 |
41 |
42 |
46 |
47 |
48 |
49 |
50 |
51 | {/* Main Content */}
52 |
53 |
54 |
55 | {/* Main Architecture View */}
56 |
57 |
62 |
63 |
64 | {/* Component Detail Panel */}
65 |
66 | {selectedComponent && componentData ? (
67 |
72 | ) : (
73 |
74 |
75 |
80 |
No Component Selected
81 |
82 | Click on a transformer component to view its details and internal structure.
83 |
84 |
85 |
86 | )}
87 |
88 |
89 |
90 |
91 | );
92 | }
93 |
94 | export default App;
--------------------------------------------------------------------------------
/src/utils/data.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Generates sequences and manages sequence data for the transformer visualization.
3 | * Provides realistic English-to-French translation examples with proper padding.
4 | */
5 |
6 | export class SequenceGenerator {
7 | // Longer, meaningful sequences for demonstration (20-25 tokens)
8 | private static readonly INPUT_SEQUENCE = [
9 | '<SOS>', 'The', 'artificial', 'intelligence', 'system', 'processes', 'natural', 'language',
10 | 'text', 'by', 'analyzing', 'patterns', 'in', 'large', 'datasets', 'to', 'understand',
11 | 'context', 'and', 'generate', 'meaningful', 'responses', 'for', 'users', '<EOS>'
12 | ];
13 |
14 | private static readonly OUTPUT_SEQUENCE = [
15 | '<SOS>', 'Le', 'système', 'd\'intelligence', 'artificielle', 'traite', 'le', 'texte',
16 | 'en', 'langage', 'naturel', 'en', 'analysant', 'les', 'motifs', 'dans', 'de',
17 | 'grandes', 'bases', 'de', 'données', 'pour', 'comprendre', 'le', 'contexte', '<EOS>'
18 | ];
19 |
20 |
21 |
22 | /**
23 | * Generates a sequence of specified length, using base sequences as templates
24 | */
25 | static generateSequence(length: number, type: 'input' | 'output' = 'input'): string[] {
26 | const baseSequence = type === 'input' ? this.INPUT_SEQUENCE : this.OUTPUT_SEQUENCE;
27 | return this.adjustSequenceLength(baseSequence, length);
28 | }
29 |
30 | /**
31 | * Adjusts sequence length by truncating or padding as needed
32 | */
33 | static adjustSequenceLength(sequence: string[], newLength: number): string[] {
34 | if (sequence.length === newLength) {
35 | return [...sequence];
36 | } else if (sequence.length > newLength) {
37 | return this.truncateSequence(sequence, newLength);
38 | } else {
39 | return this.padSequence(sequence, newLength);
40 | }
41 | }
42 |
43 | /**
44 | * Truncates sequence while preserving start/end tokens when possible
45 | */
46 | private static truncateSequence(sequence: string[], targetLength: number): string[] {
47 | if (targetLength < 2) {
48 | return sequence.slice(0, targetLength);
49 | }
50 |
51 | // Try to keep SOS and EOS tokens
52 | const hasStart = sequence[0] === '<SOS>';
53 | const hasEnd = sequence[sequence.length - 1] === '<EOS>';
54 |
55 | if (hasStart && hasEnd && targetLength >= 2) {
56 | const middleLength = targetLength - 2;
57 | const middleTokens = sequence.slice(1, sequence.length - 1).slice(0, middleLength);
58 | return ['<SOS>', ...middleTokens, '<EOS>'];
59 | } else if (hasStart && targetLength >= 1) {
60 | const remainingLength = targetLength - 1;
61 | const remainingTokens = sequence.slice(1, 1 + remainingLength);
62 | return ['<SOS>', ...remainingTokens];
63 | } else {
64 | return sequence.slice(0, targetLength);
65 | }
66 | }
67 |
68 | /**
69 | * Pads sequence to target length using appropriate padding tokens
70 | */
71 | private static padSequence(sequence: string[], targetLength: number): string[] {
72 | const padded = [...sequence];
73 | const tokensToAdd = targetLength - sequence.length;
74 |
75 | for (let i = 0; i < tokensToAdd; i++) {
76 | padded.push('<PAD>');
77 | }
78 |
79 | return padded;
80 | }
81 |
82 | /**
83 | * Generates matching output sequence for given input
84 | */
85 | static generateMatchingOutput(inputSequence: string[]): string[] {
86 | // For this demonstration, we'll use our base translation and adjust length
87 | return this.adjustSequenceLength(this.OUTPUT_SEQUENCE, inputSequence.length);
88 | }
89 |
90 | /**
91 | * Component-specific sequence generation
92 | */
93 | static generateForComponent(componentId: string, length: number): string[] {
94 | if (componentId.includes('output') || componentId.includes('decoder') || componentId.includes('predicted')) {
95 | return this.generateSequence(length, 'output');
96 | }
97 | return this.generateSequence(length, 'input');
98 | }
99 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Elastic License 2.0
2 |
3 | URL: https://www.elastic.co/licensing/elastic-license
4 |
5 | ## Acceptance
6 |
7 | By using the software, you agree to all of the terms and conditions below.
8 |
9 | ## Copyright License
10 |
11 | The licensor grants you a non-exclusive, royalty-free, worldwide,
12 | non-sublicensable, non-transferable license to use, copy, distribute, make
13 | available, and prepare derivative works of the software, in each case subject to
14 | the limitations and conditions below.
15 |
16 | ## Limitations
17 |
18 | You may not provide the software to third parties as a hosted or managed
19 | service, where the service provides users with access to any substantial set of
20 | the features or functionality of the software.
21 |
22 | You may not move, change, disable, or circumvent the license key functionality
23 | in the software, and you may not remove or obscure any functionality in the
24 | software that is protected by the license key.
25 |
26 | You may not alter, remove, or obscure any licensing, copyright, or other notices
27 | of the licensor in the software. Any use of the licensor’s trademarks is subject
28 | to applicable law.
29 |
30 | ## Patents
31 |
32 | The licensor grants you a license, under any patent claims the licensor can
33 | license, or becomes able to license, to make, have made, use, sell, offer for
34 | sale, import and have imported the software, in each case subject to the
35 | limitations and conditions in this license. This license does not cover any
36 | patent claims that you cause to be infringed by modifications or additions to
37 | the software. If you or your company make any written claim that the software
38 | infringes or contributes to infringement of any patent, your patent license for
39 | the software granted under these terms ends immediately. If your company makes
40 | such a claim, your patent license ends immediately for work on behalf of your
41 | company.
42 |
43 | ## Notices
44 |
45 | You must ensure that anyone who gets a copy of any part of the software from you
46 | also gets a copy of these terms.
47 |
48 | If you modify the software, you must include in any modified copies of the
49 | software prominent notices stating that you have modified the software.
50 |
51 | ## No Other Rights
52 |
53 | These terms do not imply any licenses other than those expressly granted in
54 | these terms.
55 |
56 | ## Termination
57 |
58 | If you use the software in violation of these terms, such use is not licensed,
59 | and your licenses will automatically terminate. If the licensor provides you
60 | with a notice of your violation, and you cease all violation of this license no
61 | later than 30 days after you receive that notice, your licenses will be
62 | reinstated retroactively. However, if you violate these terms after such
63 | reinstatement, any additional violation of these terms will cause your licenses
64 | to terminate automatically and permanently.
65 |
66 | ## No Liability
67 |
68 | *As far as the law allows, the software comes as is, without any warranty or
69 | condition, and the licensor will not be liable to you for any damages arising
70 | out of these terms or the use or nature of the software, under any kind of
71 | legal claim.*
72 |
73 | ## Definitions
74 |
75 | The **licensor** is the entity offering these terms, and the **software** is the
76 | software the licensor makes available under these terms, including any portion
77 | of it.
78 |
79 | **you** refers to the individual or entity agreeing to these terms.
80 |
81 | **your company** is any legal entity, sole proprietorship, or other kind of
82 | organization that you work for, plus all organizations that have control over,
83 | are under the control of, or are under common control with that
84 | organization. **control** means ownership of substantially all the assets of an
85 | entity, or the power to direct its management and policies by vote, contract, or
86 | otherwise. Control can be direct or indirect.
87 |
88 | **your licenses** are all the licenses granted to you for the software under
89 | these terms.
90 |
91 | **use** means anything you do with the software requiring one of your licenses.
92 |
93 | **trademark** means trademarks, service marks, and similar rights.
94 |
--------------------------------------------------------------------------------
/server/server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Simple HTTP server for the Travel Through Transformers simulation.
4 | Serves static files and handles event logging.
5 | """
6 |
7 | import json
8 | import os
9 | import sys
10 | from datetime import datetime
11 | from http.server import HTTPServer, SimpleHTTPRequestHandler
12 | from urllib.parse import urlparse
13 |
14 |
15 | class TransformerServerHandler(SimpleHTTPRequestHandler):
16 | def do_POST(self):
17 | """Handle POST requests for logging events."""
18 | if self.path == '/log':
19 | try:
20 | # Read the request body
21 | content_length = int(self.headers['Content-Length'])
22 | post_data = self.rfile.read(content_length)
23 | event_data = json.loads(post_data.decode('utf-8'))
24 |
25 | # Ensure logs directory exists
26 | os.makedirs('logs', exist_ok=True)
27 |
28 | # Append to log file
29 | log_file = 'logs/simulation_log.jsonl'
30 | with open(log_file, 'a', encoding='utf-8') as f:
31 | f.write(json.dumps(event_data) + '\n')
32 |
33 | # Send success response
34 | self.send_response(200)
35 | self.send_header('Content-Type', 'application/json')
36 | self.send_header('Access-Control-Allow-Origin', '*')
37 | self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
38 | self.send_header('Access-Control-Allow-Headers', 'Content-Type')
39 | self.end_headers()
40 |
41 | response = {'status': 'success', 'timestamp': datetime.now().isoformat()}
42 | self.wfile.write(json.dumps(response).encode('utf-8'))
43 |
44 | print(f"[{datetime.now().strftime('%H:%M:%S')}] Logged event: {event_data['event_type']}")
45 |
46 | except Exception as e:
47 | print(f"Error logging event: {e}")
48 | self.send_response(500)
49 | self.send_header('Content-Type', 'application/json')
50 | self.end_headers()
51 | error_response = {'status': 'error', 'message': str(e)}
52 | self.wfile.write(json.dumps(error_response).encode('utf-8'))
53 | else:
54 | self.send_response(404)
55 | self.end_headers()
56 |
57 | def do_OPTIONS(self):
58 | """Handle CORS preflight requests."""
59 | self.send_response(200)
60 | self.send_header('Access-Control-Allow-Origin', '*')
61 | self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
62 | self.send_header('Access-Control-Allow-Headers', 'Content-Type')
63 | self.end_headers()
64 |
65 | def do_GET(self):
66 | """Handle GET requests - serve static files from the dist directory."""
67 | if self.path == '/':
68 | self.path = '/index.html'
69 |
70 | # Look for the file on disk, ignoring any query string or fragment
71 | requested = urlparse(self.path).path
72 | if not os.path.exists(os.path.join('dist', requested.lstrip('/'))):
73 | # Fallback for SPA routing: serve index.html for unknown non-API paths
74 | if not requested.startswith('/api'):
75 | self.path = '/index.html'
76 |
77 | return super().do_GET()
79 |
80 | def translate_path(self, path):
81 | """Translate a /-separated PATH to the local filename syntax."""
82 | # Remove query string and fragment
83 | path = urlparse(path).path
84 |
85 | # Serve from dist directory
86 | root = os.path.join(os.getcwd(), 'dist')
87 | return os.path.join(root, path.lstrip('/'))
88 |
89 |
90 | def run_server(port=3000):
91 | """Run the HTTP server."""
92 | server_address = ('', port)
93 | httpd = HTTPServer(server_address, TransformerServerHandler)
94 |
95 | print(f"Starting server on http://localhost:{port}")
96 | print("Serving static files from ./dist/")
97 | print("Logging events to ./logs/simulation_log.jsonl")
98 | print("Press Ctrl+C to stop the server")
99 |
100 | try:
101 | httpd.serve_forever()
102 | except KeyboardInterrupt:
103 | print("\nServer stopped.")
104 | httpd.server_close()
105 |
106 |
107 | if __name__ == '__main__':
108 | port = int(sys.argv[1]) if len(sys.argv) > 1 else 3000
109 | run_server(port)
--------------------------------------------------------------------------------
/src/test/setup.ts:
--------------------------------------------------------------------------------
1 | import '@testing-library/jest-dom';
2 | import { beforeAll, afterAll, vi } from 'vitest';
3 |
4 | // Simple mocks for browser APIs without complex type declarations
5 | const mockIntersectionObserver = vi.fn(() => ({
6 | disconnect: vi.fn(),
7 | observe: vi.fn(),
8 | unobserve: vi.fn(),
9 | }));
10 |
11 | const mockResizeObserver = vi.fn(() => ({
12 | disconnect: vi.fn(),
13 | observe: vi.fn(),
14 | unobserve: vi.fn(),
15 | }));
16 |
17 | // Set up global mocks
18 | Object.defineProperty(window, 'IntersectionObserver', {
19 | writable: true,
20 | configurable: true,
21 | value: mockIntersectionObserver,
22 | });
23 |
24 | Object.defineProperty(window, 'ResizeObserver', {
25 | writable: true,
26 | configurable: true,
27 | value: mockResizeObserver,
28 | });
29 |
30 | Object.defineProperty(window, 'matchMedia', {
31 | writable: true,
32 | value: vi.fn().mockImplementation(query => ({
33 | matches: false,
34 | media: query,
35 | onchange: null,
36 | addListener: vi.fn(),
37 | removeListener: vi.fn(),
38 | addEventListener: vi.fn(),
39 | removeEventListener: vi.fn(),
40 | dispatchEvent: vi.fn(),
41 | })),
42 | });
43 |
44 | // Mock console methods for cleaner test output
45 | const originalConsoleError = console.error;
46 | const originalConsoleWarn = console.warn;
47 |
48 | beforeAll(() => {
49 | // Suppress React development warnings in tests
50 | console.error = (...args: any[]) => {
51 | if (
52 | typeof args[0] === 'string' &&
53 | (args[0].includes('Warning: ReactDOM.render is no longer supported') ||
54 | args[0].includes('Warning: An invalid form control'))
55 | ) {
56 | return;
57 | }
58 | originalConsoleError(...args);
59 | };
60 |
61 | console.warn = (...args: any[]) => {
62 | if (
63 | typeof args[0] === 'string' &&
64 | args[0].includes('React.createFactory() is deprecated')
65 | ) {
66 | return;
67 | }
68 | originalConsoleWarn(...args);
69 | };
70 | });
71 |
72 | afterAll(() => {
73 | console.error = originalConsoleError;
74 | console.warn = originalConsoleWarn;
75 | });
76 |
77 | // Mock D3 for visualization tests
78 | vi.mock('d3', () => ({
79 | select: vi.fn(() => ({
80 | selectAll: vi.fn(() => ({
81 | data: vi.fn(() => ({
82 | enter: vi.fn(() => ({
83 | append: vi.fn(() => ({
84 | attr: vi.fn().mockReturnThis(),
85 | style: vi.fn().mockReturnThis(),
86 | text: vi.fn().mockReturnThis()
87 | }))
88 | })),
89 | exit: vi.fn(() => ({
90 | remove: vi.fn()
91 | })),
92 | attr: vi.fn().mockReturnThis(),
93 | style: vi.fn().mockReturnThis(),
94 | text: vi.fn().mockReturnThis()
95 | }))
96 | })),
97 | append: vi.fn(() => ({
98 | attr: vi.fn().mockReturnThis(),
99 | style: vi.fn().mockReturnThis()
100 | })),
101 | attr: vi.fn().mockReturnThis(),
102 | style: vi.fn().mockReturnThis(),
103 | on: vi.fn().mockReturnThis()
104 | })),
105 | scaleLinear: vi.fn(() => ({
106 | domain: vi.fn(() => ({
107 | range: vi.fn().mockReturnThis()
108 | })),
109 | range: vi.fn().mockReturnThis()
110 | })),
111 | extent: vi.fn(() => [0, 1])
112 | }));
113 |
114 | // Educational simulation specific setup and utilities
115 | export const createMockTransformerParams = () => ({
116 | seqLen: 5,
117 | dModel: 64,
118 | numHeads: 4,
119 | numLayers: 2,
120 | vocabSize: 1000,
121 | maxSeqLen: 10,
122 | positionEncoding: 'sinusoidal' as const,
123 | dropout: 0.1,
124 | showValues: false
125 | });
126 |
127 | export const createMockComponentData = () => ({
128 | description: 'Test component for educational simulation',
129 | category: 'attention' as const,
130 | inputs: [],
131 | parameters: [],
132 | outputs: []
133 | });
134 |
135 | // Utility for testing async operations in educational simulations
136 | export const waitForSimulation = async (ms: number = 100) => {
137 | await new Promise(resolve => setTimeout(resolve, ms));
138 | };
139 |
140 | // Mock XState machine for state management tests
141 | export const createMockMachineContext = () => ({
142 | currentStep: 0,
143 | totalSteps: 10,
144 | activeToken: 0,
145 | focusToken: false,
146 | zoomLevel: 'global' as const,
147 | params: createMockTransformerParams(),
148 | isPlaying: false,
149 | playbackSpeed: 1.0,
150 | settingsOpen: false,
151 | darkMode: false,
152 | colorBlindMode: false,
153 | attentionLensActive: false,
154 | attentionTopK: 5,
155 | inputSequence: ['The', 'cat', 'sat'],
156 | outputSequence: ['Le', 'chat', 'était'],
157 | lastFrameTime: 0,
158 | currentFPS: 60
159 | });
--------------------------------------------------------------------------------
/src/components/TokenVisualization.tsx:
--------------------------------------------------------------------------------
1 | /**
2 | * TokenVisualization component for displaying tokenized sequences
3 | * with visual styling for different token types and interactive tooltips.
4 | */
5 |
6 | interface TokenVisualizationProps {
7 | tokens: string[];
8 | label: string;
9 | className?: string;
10 | }
11 |
12 | export function TokenVisualization({ tokens, label, className = '' }: TokenVisualizationProps) {
13 | const getTokenStyle = (token: string): string => {
14 | // Start token styling
15 | if (token === '<SOS>' || token === '<START>') {
16 | return 'bg-green-100 text-green-800 border-green-300';
17 | }
18 | // End token styling
19 | if (token === '<EOS>') {
20 | return 'bg-red-100 text-red-800 border-red-300';
21 | }
22 | // Pad token styling
23 | if (token === '<PAD>') {
24 | return 'bg-gray-100 text-gray-800 border-gray-300';
25 | }
26 |
27 | // Regular token styling (everything else)
28 | return 'bg-blue-50 text-blue-800 border-blue-200';
29 | };
30 |
31 | const getTokenTooltip = (token: string, index: number): string => {
32 | if (token === '<SOS>' || token === '<START>') {
33 | return 'Start Token - Marks the beginning of input';
34 | }
35 | if (token === '<EOS>') {
36 | return 'End Token - Marks the end of input';
37 | }
38 | if (token === '<PAD>') {
39 | return 'Pad Token - Used to fill sequences to fixed length';
40 | }
41 |
42 | return `Regular Token ${index}: "${token}" - Vocabulary token`;
43 | };
44 |
45 | return (
46 |
47 |
48 |
{label}
49 |
50 | Tokenized sequence with special tokens
51 |
52 |
53 |
54 |
55 | {/* Token sequence visualization */}
56 |
57 | {tokens.map((token, index) => (
58 |
67 |
68 |
69 | {index}
70 |
71 |
72 | {token}
73 |
74 |
75 |
76 | ))}
77 |
78 |
79 | {/* Token count and statistics */}
80 |
81 |
82 |
83 | Total Tokens:
84 | {tokens.length}
85 |
86 |
87 | Special Tokens:
88 |
89 | {tokens.filter(token => token.startsWith('<') && token.endsWith('>')).length}
90 |
91 |
92 |
93 |
94 |
95 | {/* Legend */}
96 |
97 |
Legend:
98 |
99 |
100 |
101 |
Start token
102 |
103 |
104 |
105 |
Regular token
106 |
107 |
108 |
109 |
End token
110 |
111 |
112 |
113 |
Pad token
114 |
115 |
116 |
117 |
118 |
119 | );
120 | }
--------------------------------------------------------------------------------
/src/components/TransformerDiagram.tsx:
--------------------------------------------------------------------------------
1 | import { useRef } from 'react';
2 | import { TransformerParams } from '../types/events';
3 | import { useTransformerDiagram } from '../hooks/useTransformerDiagram';
4 |
5 | interface TransformerDiagramProps {
6 | params: TransformerParams;
7 | selectedComponent: string | null;
8 | onComponentClick: (componentId: string) => void;
9 | className?: string;
10 | }
11 |
12 | export function TransformerDiagram({
13 | params,
14 | selectedComponent,
15 | onComponentClick,
16 | className = '',
17 | }: TransformerDiagramProps) {
18 | const containerRef = useRef<HTMLDivElement>(null);
19 | const svgRef = useTransformerDiagram(params, selectedComponent, onComponentClick, containerRef);
20 |
21 | return (
22 |
26 |
27 |
Transformer Architecture
28 |
29 | {/* Organized Legend */}
30 |
31 | {/* Architectural Sections */}
32 |
33 |
Architecture Sections
34 |
44 |
45 |
46 | {/* Processing Blocks */}
47 |
48 |
Processing Components
49 |
50 |
54 |
58 |
59 |
60 |
Positional
61 |
62 |
66 |
67 |
68 |
Add & Norm
69 |
70 |
71 |
72 |
Feed Forward
73 |
74 |
78 |
79 |
80 |
81 | {/* Connection Types */}
82 |
83 |
Data Flow
84 |
85 |
89 |
93 |
94 |
95 |
Cross-Attention
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 | );
107 | }
--------------------------------------------------------------------------------
/src/hooks/useTransformerMachine.ts:
--------------------------------------------------------------------------------
1 | import { useMachine } from '@xstate/react';
2 | import { useEffect, useRef } from 'react';
3 | import transformerMachine, { TransformerContext } from '../state/transformerMachine';
4 | import { useEventLogger } from './useEventLogger';
5 |
6 | export function useTransformerMachine() {
7 | const [state, send] = useMachine(transformerMachine);
8 | const { logEvent } = useEventLogger();
9 | const animationFrameRef = useRef<number>();
10 |
11 | // Animation loop for auto-play
12 | useEffect(() => {
13 | if (state.context.isPlaying) {
14 | const animate = () => {
15 | send({ type: 'TICK' });
16 | animationFrameRef.current = requestAnimationFrame(animate);
17 | };
18 | animationFrameRef.current = requestAnimationFrame(animate);
19 | } else {
20 | if (animationFrameRef.current) {
21 | cancelAnimationFrame(animationFrameRef.current);
22 | }
23 | }
24 |
25 | return () => {
26 | if (animationFrameRef.current) {
27 | cancelAnimationFrame(animationFrameRef.current);
28 | }
29 | };
30 | }, [state.context.isPlaying, send]);
31 |
32 | // Event logging will be handled by individual action calls
33 |
34 | // Convenience actions with logging
35 | const actions = {
36 | // Playback controls
37 | play: () => {
38 | send({ type: 'PLAY' });
39 | logEvent('animation_control', { action: 'play', current_step: state.context.currentStep });
40 | },
41 | pause: () => {
42 | send({ type: 'PAUSE' });
43 | logEvent('animation_control', { action: 'pause', current_step: state.context.currentStep });
44 | },
45 | nextStep: () => {
46 | send({ type: 'NEXT_STEP' });
47 | logEvent('step', { direction: 'next', step: state.context.currentStep + 1 });
48 | },
49 | prevStep: () => {
50 | send({ type: 'PREV_STEP' });
51 | logEvent('step', { direction: 'prev', step: Math.max(0, state.context.currentStep - 1) });
52 | },
53 | setStep: (step: number) => {
54 | send({ type: 'SET_STEP', step });
55 | logEvent('step', { direction: 'jump', step });
56 | },
57 | reset: () => {
58 | send({ type: 'RESET' });
59 | logEvent('step', { direction: 'reset', step: 0 });
60 | },
61 | setPlaybackSpeed: (speed: number) => {
62 | send({ type: 'SET_PLAYBACK_SPEED', speed });
63 | logEvent('animation_control', { action: 'speed_change', speed, current_step: state.context.currentStep });
64 | },
65 |
66 | // Token controls
67 | setActiveToken: (tokenIndex: number) => {
68 | send({ type: 'SET_ACTIVE_TOKEN', tokenIndex });
69 | logEvent('token_change', { tokenIndex, focusToken: state.context.focusToken });
70 | },
71 | toggleFocusToken: () => {
72 | send({ type: 'TOGGLE_FOCUS_TOKEN' });
73 | logEvent('toggle', { feature: 'focus_token', enabled: !state.context.focusToken });
74 | },
75 |
76 | // Zoom controls
77 | zoomIn: (component?: string, layer?: number) => {
78 | send({ type: 'ZOOM_IN', component, layer });
79 | logEvent('zoom_change', { zoomLevel: 'zoom_in', component, layer });
80 | },
81 | zoomOut: () => {
82 | send({ type: 'ZOOM_OUT' });
83 | logEvent('zoom_change', { zoomLevel: 'zoom_out' });
84 | },
85 | setZoomLevel: (level: string, component?: string) => {
86 | send({
87 | type: 'SET_ZOOM_LEVEL',
88 | level: level as TransformerContext['zoomLevel'],
89 | component
90 | });
91 | logEvent('zoom_change', { zoomLevel: level, component });
92 | },
93 |
94 | // Parameter controls
95 | updateParams: (params: Partial<TransformerContext['params']>) => {
96 | send({ type: 'UPDATE_PARAMS', params });
97 | logEvent('param_change', { changes: params, newParams: { ...state.context.params, ...params } });
98 | },
99 |
100 |
101 |
102 | // UI controls
103 | toggleSettings: () => {
104 | send({ type: 'TOGGLE_SETTINGS' });
105 | logEvent('toggle', { feature: 'settings', enabled: !state.context.settingsOpen });
106 | },
107 | toggleDarkMode: () => {
108 | send({ type: 'TOGGLE_DARK_MODE' });
109 | logEvent('toggle', { feature: 'dark_mode', enabled: !state.context.darkMode });
110 | },
111 | toggleColorBlindMode: () => {
112 | send({ type: 'TOGGLE_COLOR_BLIND_MODE' });
113 | logEvent('toggle', { feature: 'color_blind_mode', enabled: !state.context.colorBlindMode });
114 | },
115 |
116 | // Attention controls
117 | toggleAttentionLens: () => {
118 | send({ type: 'TOGGLE_ATTENTION_LENS' });
119 | logEvent('attention_lens', { active: !state.context.attentionLensActive, top_k: state.context.attentionTopK });
120 | },
121 | setAttentionTopK: (k: number) => {
122 | send({ type: 'SET_ATTENTION_TOP_K', k });
123 | logEvent('attention_lens', { active: state.context.attentionLensActive, top_k: k });
124 | },
125 |
126 | // Error handling
127 | setError: (message: string) => send({ type: 'ERROR', message }),
128 | clearError: () => send({ type: 'CLEAR_ERROR' }),
129 | };
130 |
131 | return {
132 | state: state.context,
133 | isIdle: state.matches('idle'),
134 | isPlaying: state.matches('playing'),
135 | hasError: state.matches('error'),
136 | actions,
137 | };
138 | }
--------------------------------------------------------------------------------
/src/components/HelpMenu.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useRef, useEffect } from 'react';
2 | import { QuestionMarkCircleIcon, XMarkIcon, BookOpenIcon, CursorArrowRaysIcon } from '@heroicons/react/24/outline';
3 |
4 | export function HelpMenu() {
5 | const [isOpen, setIsOpen] = useState(false);
6 | const menuRef = useRef<HTMLDivElement>(null);
7 |
8 | // Close menu when clicking outside
9 | useEffect(() => {
10 | function handleClickOutside(event: MouseEvent) {
11 | if (menuRef.current && !menuRef.current.contains(event.target as Node)) {
12 | setIsOpen(false);
13 | }
14 | }
15 |
16 | if (isOpen) {
17 | document.addEventListener('mousedown', handleClickOutside);
18 | return () => document.removeEventListener('mousedown', handleClickOutside);
19 | }
20 | }, [isOpen]);
21 |
22 | return (
23 |
24 | {/* Help Button */}
25 |
33 |
34 | {/* Help Panel */}
35 | {isOpen && (
36 |
37 |
38 | {/* Header */}
39 |
40 |
41 |
42 |
43 |
44 |
45 |
Learning Guide
46 |
Explore transformer architecture interactively
47 |
48 |
49 |
55 |
56 |
57 |
58 |
59 | {/* Interactive Features */}
60 |
61 |
62 |
63 |
Interactive Features
64 |
65 |
66 |
67 |
🔍 Component Explorer
68 |
69 | Click any component in the diagram to see its detailed structure,
70 | inputs, parameters, and mathematical operations.
71 |
72 |
73 |
74 |
⚙️ Parameter Controls
75 |
76 | Adjust model settings like number of layers, attention heads,
77 | and sequence length to see how they change the architecture.
78 |
79 |
80 |
81 |
📊 Data Visualization
82 |
83 | Switch between views to see abstract representations or
84 | actual numerical matrices and attention patterns.
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | {/* Understanding the Diagram */}
93 |
94 |
Understanding the Diagram
95 |
96 |
97 |
98 |
Encoder components (process input)
99 |
100 |
101 |
102 |
Decoder components (generate output)
103 |
104 |
105 |
106 |
Cross-attention (encoder → decoder)
107 |
108 |
109 |
110 |
Residual connections (skip paths)
111 |
112 |
113 |
114 |
115 |
116 |
117 | )}
118 |
119 | );
120 | }
--------------------------------------------------------------------------------
/src/utils/constants.ts:
--------------------------------------------------------------------------------
1 | // Core constants for the transformer simulation
2 | // Consolidated from multiple files for better import clarity
3 |
4 | import { TransformerParams } from '../types/events';
5 |
6 | // Theme and Colors
7 | export const COLORS = {
8 | // Primary Beyond Blue palette
9 | primary: '#1062fb',
10 | gold: '#ffc500',
11 |
12 | // Supporting colors
13 | white: '#ffffff',
14 | black: '#000000',
15 | gray: {
16 | 50: '#f9fafb',
17 | 100: '#f3f4f6',
18 | 200: '#e5e7eb',
19 | 300: '#d1d5db',
20 | 400: '#9ca3af',
21 | 500: '#6b7280',
22 | 600: '#4b5563',
23 | 700: '#374151',
24 | 800: '#1f2937',
25 | 900: '#111827',
26 | },
27 |
28 | // Semantic colors
29 | text: '#111827',
30 | textSecondary: '#6b7280',
31 | background: '#ffffff',
32 | border: '#e5e7eb',
33 |
34 | // Status colors
35 | success: '#10b981',
36 | warning: '#f59e0b',
37 | error: '#ef4444',
38 |
39 | // Component colors - CodeSignal brand colors
40 | encoder: '#1062fb', // Beyond Blue (Primary)
41 | decoder: '#002570', // Dive Deep Blue (Secondary)
42 | cross_attention: '#64D3FF', // SignalLight Blue (Secondary)
43 | residual: '#9ca3af', // Gray 400
44 | layer_norm: '#21CF82', // Victory Green (Tertiary)
45 | ffn: '#FF7232', // FiresIDE Orange (Tertiary)
46 |
47 | // Processing block colors
48 | input_tokens: '#6b7280', // Gray 500
49 | embedding: '#1062FB', // Beyond Blue (Primary)
50 | positional: '#FFC500', // Gold Standard Yellow (Secondary)
51 | attention: '#E6193F', // Ruby Red (Tertiary)
52 | output: '#002570', // Dive Deep Blue (Secondary)
53 |
54 | // Matrix cell colors (Viridis palette - color-blind safe)
55 | viridis: [
56 | '#440154', '#482777', '#3f4a8a', '#31678e',
57 | '#26838f', '#1f9d8a', '#6cce5a', '#b6de2b',
58 | '#fee825'
59 | ],
60 | };
61 |
62 | // Zoom Levels
63 | export const ZOOM_LEVELS = {
64 | GLOBAL: 'global' as const,
65 | LAYER: 'layer' as const,
66 | SUB_LAYER: 'sub_layer' as const,
67 | TENSOR_CELL: 'tensor_cell' as const,
68 | };
69 |
70 | export type ZoomLevel = typeof ZOOM_LEVELS[keyof typeof ZOOM_LEVELS];
71 |
72 | // Default Parameters
73 | export const DEFAULT_PARAMS: TransformerParams = {
74 | numLayers: 1,
75 | dModel: 64,
76 | numHeads: 4,
77 | seqLen: 8,
78 | posEncoding: 'sinusoidal',
79 | dropout: false,
80 | };
81 |
82 | export const PARAM_LIMITS = {
83 | numLayers: { min: 1, max: 6 },
84 | dModel: { min: 32, max: 512, step: 32 },
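 | // numHeads must evenly divide dModel so each head gets an equal slice of the model dimension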
85 | numHeads: { min: 1, max: 16 },
86 | seqLen: { min: 3, max: 32 },
87 | };
88 |
89 | // Transformer Steps
90 | export const TRANSFORMER_STEPS = [
91 | // Input processing
92 | {
93 | id: 'input_tokens',
94 | label: 'Input Tokens',
95 | description: 'Source sequence tokens as discrete symbols',
96 | category: 'input',
97 | component: 'input',
98 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
99 | },
100 | {
101 | id: 'input_embedding',
102 | label: 'Input Embedding',
103 | description: 'Convert source tokens to dense vectors',
104 | category: 'embedding',
105 | component: 'input',
106 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
107 | },
108 | {
109 | id: 'input_positional',
110 | label: 'Input Positional Encoding',
111 | description: 'Add position information to input embeddings',
112 | category: 'embedding',
113 | component: 'input',
114 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
115 | },
116 |
117 | // Encoder stack
118 | {
119 | id: 'encoder_self_attention',
120 | label: 'Encoder Self-Attention',
121 | description: 'Encoder attends to input sequence',
122 | category: 'encoder',
123 | component: 'encoder',
124 | zoomLevel: ZOOM_LEVELS.SUB_LAYER
125 | },
126 | {
127 | id: 'encoder_ffn',
128 | label: 'Encoder Feed-Forward',
129 | description: 'Process encoder representations through FFN',
130 | category: 'encoder',
131 | component: 'encoder',
132 | zoomLevel: ZOOM_LEVELS.SUB_LAYER
133 | },
134 |
135 | // Decoder input
136 | {
137 | id: 'output_tokens',
138 | label: 'Output Tokens',
139 | description: 'Target sequence tokens (shifted right)',
140 | category: 'input',
141 | component: 'decoder',
142 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
143 | },
144 | {
145 | id: 'output_embedding',
146 | label: 'Output Embedding',
147 | description: 'Convert target tokens to dense vectors',
148 | category: 'embedding',
149 | component: 'decoder',
150 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
151 | },
152 | {
153 | id: 'output_positional',
154 | label: 'Output Positional Encoding',
155 | description: 'Add position information to output embeddings',
156 | category: 'embedding',
157 | component: 'decoder',
158 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
159 | },
160 |
161 | // Decoder stack
162 | {
163 | id: 'decoder_self_attention',
164 | label: 'Decoder Self-Attention (Masked)',
165 | description: 'Decoder attends to previous output tokens',
166 | category: 'decoder',
167 | component: 'decoder',
168 | zoomLevel: ZOOM_LEVELS.SUB_LAYER
169 | },
170 | {
171 | id: 'cross_attention',
172 | label: 'Cross-Attention',
173 | description: 'Decoder queries attend to encoder keys/values',
174 | category: 'cross_attention',
175 | component: 'cross_attention',
176 | zoomLevel: ZOOM_LEVELS.SUB_LAYER
177 | },
178 | {
179 | id: 'decoder_ffn',
180 | label: 'Decoder Feed-Forward',
181 | description: 'Process decoder representations through FFN',
182 | category: 'decoder',
183 | component: 'decoder',
184 | zoomLevel: ZOOM_LEVELS.SUB_LAYER
185 | },
186 |
187 | // Output processing
188 | {
189 | id: 'linear_layer',
190 | label: 'Linear Layer',
191 | description: 'Project decoder output to vocabulary space',
192 | category: 'output',
193 | component: 'output',
194 | zoomLevel: ZOOM_LEVELS.LAYER
195 | },
196 | {
197 | id: 'softmax',
198 | label: 'Softmax',
199 | description: 'Convert logits to probability distribution',
200 | category: 'output',
201 | component: 'output',
202 | zoomLevel: ZOOM_LEVELS.TENSOR_CELL
203 | },
204 | ];
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Travel Through Transformers
2 |
3 | An interactive web-based simulation that lets learners follow a single token step-by-step through every component of a Transformer encoder/decoder stack.
4 |
5 | ## Features
6 |
7 | - **Component-focused visualization**: Click through different transformer components to see detailed internals
8 | - **Interactive parameters**: Adjust layers, model dimensions, attention heads, and sequence length in real time
9 | - **Dual visualization modes**: Abstract shape view for understanding flow, or detailed numerical values
10 | - **Multi-head attention visualization**: See how different attention heads process information
11 | - **Event logging**: All interactions are logged for analytics
12 |
13 | ## Quick Start
14 |
15 | ### Prerequisites
16 |
17 | - Node.js 18+ and npm
18 | - Python 3.7+
19 |
20 | ### Installation
21 |
22 | 1. **Install dependencies**:
23 | ```bash
24 | npm install
25 | ```
26 |
27 | 2. **Build the application**:
28 | ```bash
29 | npm run build
30 | ```
31 |
32 | 3. **Start the server**:
33 | ```bash
34 | python server/server.py
35 | ```
36 |
37 | 4. **Open your browser**:
38 | Navigate to `http://localhost:3000`
39 |
40 | ### Development Mode
41 |
42 | For development with hot reloading:
43 |
44 | ```bash
45 | # Terminal 1 - Start the development server
46 | npm run dev
47 |
48 | # Terminal 2 - Start the logging server (the Vite dev server proxies /log to port 3001)
49 | python server/server.py 3001
50 | ```
51 |
52 | Then open `http://localhost:5173` (development) or `http://localhost:3000` (production).
53 |
54 | ## Usage
55 |
56 | ### Controls
57 |
58 | 1. **Component Selection**: Click on transformer components in the diagram to explore their internals
59 | 2. **Show Values Toggle**: Switch between abstract block view and actual numerical matrices
60 | 3. **Model Parameters**:
61 | - Adjust number of layers (1-6)
62 | - Change model dimension (32-512)
63 | - Set attention heads (1-16, must divide model dimension)
64 | - Modify sequence length (3-32)
65 | - Choose positional encoding type
66 | - Enable dropout visualization
67 |
68 | ### Understanding the Visualization
69 |
70 | #### Abstract Mode (Default)
71 | - Colored blocks represent matrices with dimensions shown
72 | - Different colors indicate different types of operations
73 | - Active components highlighted with distinctive colors
74 | - Attention heads shown in different colors
75 |
76 | #### Values Mode
77 | - Heat maps show actual numerical values
78 | - Color intensity represents magnitude
79 | - Hover for precise values
80 | - Active components highlighted
81 |
82 | ## Architecture
83 |
84 | ```
85 | travel-through-transformers/
86 | ├── src/
87 | │ ├── components/ # React UI components
88 | │ │ ├── MatrixVisualization.tsx # D3-powered matrix visualization
89 | │ │ ├── TokenVisualization.tsx # Token display and interaction
90 | │ │ ├── SettingsMenu.tsx # Parameter controls
91 | │ │ ├── ComponentDetailsPanel.tsx # Component detail exploration
92 | │ │ ├── HelpMenu.tsx # Help system
93 | │ │ └── TransformerDiagram.tsx # Main architecture diagram
94 | │ ├── hooks/ # Custom React hooks
95 | │ │ ├── useTransformerMachine.ts # Main state management (XState)
96 | │ │ ├── useTransformerDiagram.ts # Diagram interaction logic
97 | │ │ └── useEventLogger.ts # Analytics logging
98 | │ ├── utils/ # Utility functions
99 | │ │ ├── math.ts # Matrix operations
100 | │ │ ├── randomWeights.ts # Seeded random generation
101 | │ │ ├── constants.ts # Configuration and steps
102 | │ │ ├── data.ts # Sample data generation
103 | │ │ ├── componentDataGenerator.ts # Component data creation
104 | │ │ └── componentTransformations.ts # Math transformations
105 | │ ├── state/ # State management
106 | │ │ └── transformerMachine.ts # XState machine definition
107 | │ └── types/ # TypeScript definitions
108 | │ └── events.d.ts # Event and parameter types
109 | ├── server/
110 | │ └── server.py # Python logging server
111 | └── logs/ # Event logs (generated)
112 | ```
113 |
114 | ## Educational Goals
115 |
116 | This simulation helps learners understand:
117 |
118 | 1. **Component Architecture**: How transformer components are organized and connected
119 | 2. **Attention Mechanism**: How queries, keys, and values interact (see the formula after this list)
120 | 3. **Multi-Head Attention**: How different heads capture different patterns
121 | 4. **Residual Connections**: How information flows around attention blocks
122 | 5. **Layer Normalization**: How activations are normalized
123 | 6. **Feed-Forward Networks**: How information is processed after attention
124 | 7. **Positional Encoding**: How position information is added to tokens
125 | 8. **Cross-Attention**: How decoder attends to encoder representations
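 |
 | The attention computation referenced in goals 2-3 is standard scaled dot-product attention:
 |
 | ```
 | Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V
 | ```
 |
 | where d_k is the per-head key dimension (dModel / numHeads).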
126 |
127 | ## Technical Details
128 |
129 | - **Frontend**: React + TypeScript + Vite
130 | - **State Management**: XState for complex state transitions
131 | - **Visualization**: D3.js for interactive SVG graphics
132 | - **Styling**: TailwindCSS with CodeSignal brand colors
133 | - **Math**: Custom lightweight tensor operations (no external ML libraries)
134 | - **Backend**: Simple Python HTTP server for logging
135 | - **Data**: Seeded random weights for reproducible results
136 |
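The exports of `src/utils/math.ts` are not shown in this README, so the following is only an illustrative sketch of the kind of plain-array operations involved (hypothetical names, not the module's actual API):

```ts
// Hypothetical helpers in the spirit of the project's lightweight math:
// plain number[][] matrices, no external ML libraries.
function matmul(a: number[][], b: number[][]): number[][] {
  // result[i][j] = sum_k a[i][k] * b[k][j]
  return a.map(row =>
    b[0].map((_, j) => row.reduce((sum, v, k) => sum + v * b[k][j], 0))
  );
}

function transpose(m: number[][]): number[][] {
  return m[0].map((_, j) => m.map(row => row[j]));
}
```
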
137 | ## Customization
138 |
139 | ### Adding New Components
140 |
141 | 1. Add component definition to transformer machine states
142 | 2. Implement component logic in `componentDataGenerator.ts`
143 | 3. Add appropriate visualizations in component files
144 | 4. Update component transformations in `componentTransformations.ts` (see the sketch below)
145 |
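For step 4, a new generator can follow the `ComponentSection` pattern already used in `componentTransformations.ts`. A hypothetical dropout block (the ids, labels, and wiring here are illustrative only):

```ts
import { ComponentSection } from '../types/events';

// Hypothetical example shaped like the existing transformation generators.
const generateDropoutTransformations = (): ComponentSection[] => [
  {
    id: 'dropout-mask',
    label: 'Dropout Masking',
    description: 'Randomly zero activations during training',
    type: 'text',
    data: 'output = mask ⊙ input / (1 - p)',
    metadata: {
      operation: 'masking',
      formula: 'y = m ⊙ x / (1 - p), m ~ Bernoulli(1 - p)',
      complexity: 'O(seq_len × d_model)',
      activation: 'none'
    }
  }
];
```
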
146 | ### Modifying Visualization
147 |
148 | - Matrix colors: Edit `COLORS` in `constants.ts`
149 | - D3 rendering: Modify `MatrixVisualization.tsx`
150 | - Component descriptions: Update transformer machine configuration
151 |
152 | ### Analytics
153 |
154 | Event logs are stored in `logs/simulation_log.jsonl` with the following schema:
155 | ```json
156 | {
157 | "timestamp": 1625239200,
158 | "event_type": "param_change" | "component_select" | "toggle" | "zoom_change",
159 | "payload": { /* event-specific data */ }
160 | }
161 | ```
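
In TypeScript terms, the same schema looks roughly like this (a sketch; the canonical event and parameter types live in `src/types/events.d.ts`):

```ts
type SimulationEvent = {
  timestamp: number; // Unix seconds
  event_type: 'param_change' | 'component_select' | 'toggle' | 'zoom_change';
  payload: Record<string, unknown>; // event-specific data
};
```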
162 |
163 | ## Browser Support
164 |
165 | - Chrome 90+
166 | - Firefox 88+
167 | - Safari 14+
168 | - Edge 90+
169 |
170 | ## Performance
171 |
172 | Optimized for:
173 | - 60 FPS animations
174 | - Sequence length ≤ 10
175 | - Attention heads ≤ 8
176 | - Model dimension ≤ 512
177 |
178 | ## License
179 |
180 | MIT License - see [LICENSE](LICENSE) for details.
181 |
182 | ## Contributing
183 |
184 | 1. Fork the repository
185 | 2. Create a feature branch
186 | 3. Make your changes
187 | 4. Add tests if applicable
188 | 5. Submit a pull request
189 |
190 | ## Troubleshooting
191 |
192 | ### Common Issues
193 |
194 | **"Cannot find module" errors**: Run `npm install`
195 |
196 | **Server won't start**: Check that port 3000 is available, or specify a different port: `python server/server.py 3001`
197 |
198 | **Visualization not updating**: Try refreshing the page or clearing the browser cache
199 |
200 | **Performance issues**: Reduce model parameters (fewer layers, smaller dimensions)
201 |
202 | ### Debug Mode
203 |
204 | Enable debug logging:
205 | ```bash
206 | DEBUG=1 python server/server.py
207 | ```
208 |
209 | ## Educational Extensions
210 |
211 | Future enhancements could include:
212 |
213 | - Real model weights from Hugging Face
214 | - Attention pattern analysis
215 | - Interactive quizzes between steps
216 | - Comparison with other architectures
217 | - Custom text input
218 | - Export/import configurations
--------------------------------------------------------------------------------
/src/components/MatrixVisualization.tsx:
--------------------------------------------------------------------------------
1 | import { useRef, useEffect, useState, useMemo } from 'react';
2 | import * as d3 from 'd3';
3 | import { MagnifyingGlassIcon } from '@heroicons/react/24/outline';
4 | import { COLORS } from '../utils/constants';
5 |
6 | interface MatrixVisualizationProps {
7 | data: number[][] | number[];
8 | shape?: [number, number] | [number];
9 | label: string;
10 | className?: string;
11 | }
12 |
13 | export function MatrixVisualization({
14 | data,
15 | label,
16 | className = ''
17 | }: MatrixVisualizationProps) {
18 | const svgRef = useRef<SVGSVGElement>(null);
19 | const detailSvgRef = useRef<SVGSVGElement>(null);
20 | const containerRef = useRef<HTMLDivElement>(null);
21 | const [showDetailModal, setShowDetailModal] = useState(false);
22 |
23 | // Convert data to 2D array if it's 1D - memoized to avoid recreation on every render
24 | const matrix2D = useMemo(() => {
25 | return Array.isArray(data[0]) ? data as number[][] : [data as number[]];
26 | }, [data]);
27 |
28 | const rows = matrix2D.length;
29 | const cols = matrix2D[0].length;
30 |
31 | // Simplified overview visualization
32 | useEffect(() => {
33 | if (!svgRef.current || !containerRef.current || !data) return;
34 |
35 | const svg = d3.select(svgRef.current);
36 | svg.selectAll('*').remove();
37 |
38 | const width = 200;
39 | const height = 150;
40 | svg.attr('width', width).attr('height', height);
41 |
42 | const margin = { top: 30, right: 40, bottom: 30, left: 40 };
43 | const innerWidth = width - margin.left - margin.right;
44 | const innerHeight = height - margin.top - margin.bottom;
45 |
46 | const g = svg.append('g').attr('transform', `translate(${margin.left},${margin.top})`);
47 |
48 | // Calculate matrix rectangle dimensions
49 | const aspectRatio = cols / rows;
50 | let rectWidth, rectHeight;
51 |
52 | if (aspectRatio > 1) {
53 | rectWidth = innerWidth;
54 | rectHeight = innerWidth / aspectRatio;
55 | } else {
56 | rectHeight = innerHeight;
57 | rectWidth = innerHeight * aspectRatio;
58 | }
59 |
60 | const rectX = (innerWidth - rectWidth) / 2;
61 | const rectY = (innerHeight - rectHeight) / 2;
62 |
63 | // Color based on data range for visual feedback
64 | const flatData = matrix2D.flat();
65 | const dataRange = d3.extent(flatData) as [number, number];
66 | const avgValue = d3.mean(flatData) || 0;
67 | const normalizedValue = dataRange[1] !== dataRange[0] ?
68 | (avgValue - dataRange[0]) / (dataRange[1] - dataRange[0]) : 0.5;
69 |
70 | const matrixColor = d3.interpolateViridis(normalizedValue);
71 |
72 | // Draw matrix rectangle
73 | g.append('rect')
74 | .attr('width', rectWidth)
75 | .attr('height', rectHeight)
76 | .attr('x', rectX)
77 | .attr('y', rectY)
78 | .attr('fill', matrixColor)
79 | .attr('fill-opacity', 0.7)
80 | .attr('stroke', COLORS.text)
81 | .attr('stroke-width', 1)
82 | .attr('rx', 4);
83 |
84 | // Row dimension label (left side)
85 | g.append('text')
86 | .attr('x', rectX - 8)
87 | .attr('y', rectY + rectHeight / 2)
88 | .attr('text-anchor', 'end')
89 | .attr('dominant-baseline', 'middle')
90 | .attr('font-size', 14)
91 | .attr('font-weight', 'bold')
92 | .attr('fill', COLORS.text)
93 | .text(rows);
94 |
95 | // Column dimension label (top)
96 | g.append('text')
97 | .attr('x', rectX + rectWidth / 2)
98 | .attr('y', rectY - 8)
99 | .attr('text-anchor', 'middle')
100 | .attr('dominant-baseline', 'middle')
101 | .attr('font-size', 14)
102 | .attr('font-weight', 'bold')
103 | .attr('fill', COLORS.text)
104 | .text(cols);
105 |
106 | }, [data, rows, cols, matrix2D]); // Added matrix2D to dependencies
107 |
108 | // Detailed modal visualization
109 | useEffect(() => {
110 | if (!detailSvgRef.current || !showDetailModal || !data) return;
111 |
112 | const svg = d3.select(detailSvgRef.current);
113 | svg.selectAll('*').remove();
114 |
115 | const maxWidth = 800;
116 | const maxHeight = 600;
117 |
118 | // Calculate dimensions for detailed view
119 | const cellSize = Math.min(maxWidth / cols, maxHeight / rows, 30);
120 | const width = cols * cellSize;
121 | const height = rows * cellSize;
122 |
123 | svg.attr('width', width + 80).attr('height', height + 80);
124 |
125 | const margin = { top: 40, right: 40, bottom: 40, left: 60 };
126 | const g = svg.append('g').attr('transform', `translate(${margin.left},${margin.top})`);
127 |
128 | // Create color scale
129 | const flatData = matrix2D.flat();
130 | const colorScale = d3.scaleSequential(d3.interpolateViridis)
131 | .domain(d3.extent(flatData) as [number, number]);
132 |
133 | // Create heatmap
134 | const cells = g.selectAll('.cell')
135 | .data(matrix2D.flatMap((row, i) =>
136 | row.map((value, j) => ({ value, row: i, col: j }))
137 | ))
138 | .enter()
139 | .append('g')
140 | .attr('class', 'cell')
141 | .attr('transform', d => `translate(${d.col * cellSize},${d.row * cellSize})`);
142 |
143 | cells.append('rect')
144 | .attr('width', cellSize - 1)
145 | .attr('height', cellSize - 1)
146 | .attr('fill', d => colorScale(d.value))
147 | .attr('stroke', '#fff')
148 | .attr('stroke-width', 0.5);
149 |
150 | // Add values if cells are large enough
151 | if (cellSize >= 20) {
152 | cells.append('text')
153 | .attr('x', cellSize / 2)
154 | .attr('y', cellSize / 2)
155 | .attr('text-anchor', 'middle')
156 | .attr('dominant-baseline', 'middle')
157 | .attr('font-size', Math.min(cellSize / 3, 12))
158 | .attr('font-weight', 'bold')
159 | .attr('fill', d => d3.lab(colorScale(d.value)).l > 60 ? '#000' : '#fff')
160 | .text(d => d.value.toFixed(3));
161 | }
162 |
163 | // Add axis labels with better spacing
164 | if (rows <= 20) {
165 | g.selectAll('.row-label')
166 | .data(d3.range(rows))
167 | .enter()
168 | .append('text')
169 | .attr('class', 'row-label')
170 | .attr('x', -8)
171 | .attr('y', d => d * cellSize + cellSize / 2)
172 | .attr('text-anchor', 'end')
173 | .attr('dominant-baseline', 'middle')
174 | .attr('font-size', Math.min(cellSize / 2, 12))
175 | .attr('fill', COLORS.text)
176 | .text(d => d);
177 | }
178 |
179 | if (cols <= 30) {
180 | g.selectAll('.col-label')
181 | .data(d3.range(cols))
182 | .enter()
183 | .append('text')
184 | .attr('class', 'col-label')
185 | .attr('x', d => d * cellSize + cellSize / 2)
186 | .attr('y', -8)
187 | .attr('text-anchor', 'middle')
188 | .attr('font-size', Math.min(cellSize / 2, 12))
189 | .attr('fill', COLORS.text)
190 | .text(d => d);
191 | }
192 |
193 | }, [data, showDetailModal, rows, cols, label, matrix2D]); // Added matrix2D to dependencies
194 |
195 | return (
196 | <>
197 |
198 |
199 |
200 |
{label}
201 |
202 |
209 |
210 |
211 |
212 |
213 |
214 |
215 | {/* Detail Modal */}
216 | {showDetailModal && (
217 |
218 |
219 |
220 |
221 |
{label}
222 |
Detailed matrix view • {rows}×{cols}
223 |
224 |
230 |
231 |
232 |
233 |
234 |
235 |
236 | )}
237 | >
238 | );
239 | }
--------------------------------------------------------------------------------
/src/components/SettingsMenu.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useRef, useEffect } from 'react';
2 | import { CogIcon, XMarkIcon } from '@heroicons/react/24/outline';
3 | import { TransformerParams } from '../types/events';
4 | import { PARAM_LIMITS } from '../utils/constants';
5 |
6 | interface SettingsMenuProps {
7 | params: TransformerParams;
8 | onParamsChange: (params: Partial<TransformerParams>) => void;
9 | }
10 |
11 | // Define preset configurations
12 | const PRESETS = {
13 | minimal: { numLayers: 1, dModel: 64, numHeads: 4, seqLen: 4 },
14 | standard: { numLayers: 3, dModel: 128, numHeads: 8, seqLen: 6 },
15 | large: { numLayers: 6, dModel: 256, numHeads: 8, seqLen: 8 },
16 | } as const;
17 |
18 | export function SettingsMenu({
19 | params,
20 | onParamsChange,
21 | }: SettingsMenuProps) {
22 | const [isOpen, setIsOpen] = useState(false);
23 | const menuRef = useRef<HTMLDivElement>(null);
24 |
25 | // Determine which preset is currently active
26 | const getActivePreset = () => {
27 | for (const [presetName, presetParams] of Object.entries(PRESETS)) {
28 | const isMatch = Object.entries(presetParams).every(([key, value]) => {
29 | return params[key as keyof TransformerParams] === value;
30 | });
31 | if (isMatch) return presetName;
32 | }
33 | return null;
34 | };
35 |
36 | const activePreset = getActivePreset();
37 |
38 | // Close menu when clicking outside
39 | useEffect(() => {
40 | function handleClickOutside(event: MouseEvent) {
41 | if (menuRef.current && !menuRef.current.contains(event.target as Node)) {
42 | setIsOpen(false);
43 | }
44 | }
45 |
46 | if (isOpen) {
47 | document.addEventListener('mousedown', handleClickOutside);
48 | return () => document.removeEventListener('mousedown', handleClickOutside);
49 | }
50 | }, [isOpen]);
51 |
52 | // Helper function to get button styling based on active state
53 | const getButtonClass = (presetName: string) => {
54 | const baseClass = "px-3 py-2 text-xs rounded-md transition-colors";
55 | const isActive = activePreset === presetName;
56 |
57 | if (isActive) {
58 | return `${baseClass} bg-cs-blue text-white hover:bg-cs-deep`;
59 | } else {
60 | return `${baseClass} bg-gray-100 hover:bg-gray-200 text-gray-700`;
61 | }
62 | };
63 |
64 | return (
65 |
66 | {/* Settings Button */}
67 |
75 |
76 | {/* Settings Panel */}
77 | {isOpen && (
78 |
79 |
80 | {/* Header */}
81 |
82 |
Configuration
83 |
89 |
90 |
91 | {/* Architecture Parameters */}
92 |
93 |
Architecture
94 |
95 | {/* Number of Layers */}
96 |
97 |
98 |
99 | {params.numLayers}
100 |
101 |
onParamsChange({ numLayers: parseInt(e.target.value) })}
107 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider"
108 | />
109 |
110 |
111 | {/* Model Dimension */}
112 |
113 |
114 |
115 | {params.dModel}
116 |
117 |
onParamsChange({ dModel: parseInt(e.target.value) })}
124 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider"
125 | />
126 |
127 |
128 | {/* Number of Heads */}
129 |
130 |
131 |
132 | {params.numHeads}
133 |
134 |
onParamsChange({ numHeads: parseInt(e.target.value) })}
140 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider"
141 | />
142 |
143 | Max: {Math.floor(params.dModel / 8)} (divisible constraint)
144 |
145 |
146 |
147 | {/* Sequence Length */}
148 |
149 |
150 |
151 | {params.seqLen}
152 |
153 |
onParamsChange({ seqLen: parseInt(e.target.value) })}
159 | className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer slider"
160 | />
161 |
162 |
163 | {/* Positional Encoding */}
164 |
165 |
166 |
174 |
175 |
176 |
177 | {/* Quick Presets */}
178 |
179 |
Quick Presets
180 |
181 |
187 |
193 |
199 |
200 |
201 |
202 |
203 | )}
204 |
205 | );
206 | }
--------------------------------------------------------------------------------
/src/state/transformerMachine.ts:
--------------------------------------------------------------------------------
1 | import { createMachine, assign } from 'xstate';
2 | import { TransformerParams } from '../types/events';
3 | import { DEFAULT_PARAMS, TRANSFORMER_STEPS, ZOOM_LEVELS, ZoomLevel } from '../utils/constants';
4 | import { SequenceGenerator } from '../utils/data';
5 |
6 | export interface TransformerContext {
7 | // Core simulation state
8 | currentStep: number;
9 | totalSteps: number;
10 | activeToken: number;
11 | focusToken: boolean;
12 |
13 | // Zoom and view state
14 | zoomLevel: ZoomLevel;
15 | selectedComponent?: string;
16 | selectedLayer?: number;
17 |
18 | // Model parameters
19 | params: TransformerParams;
20 |
21 | // UI state
22 | isPlaying: boolean;
23 | playbackSpeed: number; // 0.5x to 2x
24 |
25 | // Settings
26 | settingsOpen: boolean;
27 | darkMode: boolean;
28 | colorBlindMode: boolean;
29 |
30 | // Attention visualization
31 | attentionLensActive: boolean;
32 | attentionTopK: number;
33 |
34 | // Data - simplified to just sequences
35 | inputSequence: string[];
36 | outputSequence: string[];
37 |
38 | // Performance
39 | lastFrameTime: number;
40 | currentFPS: number;
41 |
42 | // Error state
43 | error?: string;
44 | }
45 |
46 | export type TransformerEvent =
47 | | { type: 'PLAY' }
48 | | { type: 'PAUSE' }
49 | | { type: 'NEXT_STEP' }
50 | | { type: 'PREV_STEP' }
51 | | { type: 'SET_STEP'; step: number }
52 | | { type: 'RESET' }
53 | | { type: 'SET_ACTIVE_TOKEN'; tokenIndex: number }
54 | | { type: 'TOGGLE_FOCUS_TOKEN' }
55 | | { type: 'ZOOM_IN'; component?: string; layer?: number }
56 | | { type: 'ZOOM_OUT' }
57 | | { type: 'SET_ZOOM_LEVEL'; level: ZoomLevel; component?: string }
58 | | { type: 'UPDATE_PARAMS'; params: Partial<TransformerParams> }
59 | | { type: 'TOGGLE_SETTINGS' }
60 | | { type: 'SET_PLAYBACK_SPEED'; speed: number }
61 | | { type: 'TOGGLE_DARK_MODE' }
62 | | { type: 'TOGGLE_COLOR_BLIND_MODE' }
63 | | { type: 'TOGGLE_ATTENTION_LENS' }
64 | | { type: 'SET_ATTENTION_TOP_K'; k: number }
65 | | { type: 'TICK' } // Animation frame
66 | | { type: 'ERROR'; message: string }
67 | | { type: 'CLEAR_ERROR' };
68 |
69 | // Helper function to create toggle actions
70 | const createToggleAction = <K extends keyof TransformerContext>(field: K) =>
71 | assign({ [field]: ({ context }: { context: TransformerContext }) => !context[field] });
72 |
73 | // Generate initial sequences - simplified
74 | const initialInputSequence = SequenceGenerator.generateSequence(DEFAULT_PARAMS.seqLen, 'input');
75 | const initialOutputSequence = SequenceGenerator.generateMatchingOutput(initialInputSequence);
76 |
77 | const transformerMachine = createMachine({
78 | /** @xstate-layout N4IgpgJg5mDOIC5QAoC2BDAxgCwJYDswBKAOnwGIA */
79 | id: 'transformer',
80 | initial: 'idle',
81 | context: {
82 | currentStep: 0,
83 | totalSteps: TRANSFORMER_STEPS.length,
84 | activeToken: 0,
85 | focusToken: false,
86 |
87 | zoomLevel: ZOOM_LEVELS.GLOBAL,
88 | selectedComponent: undefined,
89 | selectedLayer: undefined,
90 |
91 | params: DEFAULT_PARAMS,
92 |
93 | isPlaying: false,
94 | playbackSpeed: 1.0,
95 |
96 | settingsOpen: false,
97 | darkMode: false,
98 | colorBlindMode: false,
99 |
100 | attentionLensActive: false,
101 | attentionTopK: 5,
102 |
103 | inputSequence: initialInputSequence,
104 | outputSequence: initialOutputSequence,
105 |
106 | lastFrameTime: 0,
107 | currentFPS: 60,
108 | } as TransformerContext,
109 |
110 | states: {
111 | idle: {
112 | on: {
113 | PLAY: {
114 | target: 'playing',
115 | actions: 'setPlaying'
116 | },
117 | NEXT_STEP: {
118 | actions: 'nextStep'
119 | },
120 | PREV_STEP: {
121 | actions: 'prevStep'
122 | },
123 | SET_STEP: {
124 | actions: 'setStep'
125 | },
126 | RESET: {
127 | actions: 'reset'
128 | },
129 | SET_ACTIVE_TOKEN: {
130 | actions: 'setActiveToken'
131 | },
132 | TOGGLE_FOCUS_TOKEN: {
133 | actions: 'toggleFocusToken'
134 | },
135 | ZOOM_IN: {
136 | actions: 'zoomIn'
137 | },
138 | ZOOM_OUT: {
139 | actions: 'zoomOut'
140 | },
141 | SET_ZOOM_LEVEL: {
142 | actions: 'setZoomLevel'
143 | },
144 | UPDATE_PARAMS: {
145 | actions: 'updateParams'
146 | },
147 | TOGGLE_SETTINGS: {
148 | actions: 'toggleSettings'
149 | },
150 | SET_PLAYBACK_SPEED: {
151 | actions: 'setPlaybackSpeed'
152 | },
153 | TOGGLE_DARK_MODE: {
154 | actions: 'toggleDarkMode'
155 | },
156 | TOGGLE_COLOR_BLIND_MODE: {
157 | actions: 'toggleColorBlindMode'
158 | },
159 | TOGGLE_ATTENTION_LENS: {
160 | actions: 'toggleAttentionLens'
161 | },
162 | SET_ATTENTION_TOP_K: {
163 | actions: 'setAttentionTopK'
164 | },
165 | ERROR: {
166 | target: 'error',
167 | actions: 'setError'
168 | }
169 | }
170 | },
171 |
172 | playing: {
173 | on: {
174 | PAUSE: {
175 | target: 'idle',
176 | actions: 'setPaused'
177 | },
178 | TICK: {
179 | actions: 'animationTick'
180 | },
181 | // Allow other interactions while playing
182 | SET_ACTIVE_TOKEN: {
183 | actions: 'setActiveToken'
184 | },
185 | TOGGLE_FOCUS_TOKEN: {
186 | actions: 'toggleFocusToken'
187 | },
188 | ZOOM_IN: {
189 | actions: 'zoomIn'
190 | },
191 | ZOOM_OUT: {
192 | actions: 'zoomOut'
193 | },
194 | SET_ZOOM_LEVEL: {
195 | actions: 'setZoomLevel'
196 | },
197 | TOGGLE_SETTINGS: {
198 | actions: 'toggleSettings'
199 | },
200 | SET_PLAYBACK_SPEED: {
201 | actions: 'setPlaybackSpeed'
202 | },
203 | TOGGLE_ATTENTION_LENS: {
204 | actions: 'toggleAttentionLens'
205 | },
206 | SET_ATTENTION_TOP_K: {
207 | actions: 'setAttentionTopK'
208 | },
209 | ERROR: {
210 | target: 'error',
211 | actions: 'setError'
212 | }
213 | }
214 | },
215 |
216 | error: {
217 | on: {
218 | CLEAR_ERROR: {
219 | target: 'idle',
220 | actions: 'clearError'
221 | },
222 | RESET: {
223 | target: 'idle',
224 | actions: ['reset', 'clearError']
225 | }
226 | }
227 | }
228 | }
229 | }, {
230 | actions: {
231 | setPlaying: assign({
232 | isPlaying: true
233 | }),
234 |
235 | setPaused: assign({
236 | isPlaying: false
237 | }),
238 |
239 | nextStep: assign({
240 | currentStep: ({ context }) =>
241 | Math.min(context.currentStep + 1, context.totalSteps - 1)
242 | }),
243 |
244 | prevStep: assign({
245 | currentStep: ({ context }) =>
246 | Math.max(context.currentStep - 1, 0)
247 | }),
248 |
249 | setStep: assign({
250 | currentStep: ({ event }) =>
251 | event.type === 'SET_STEP' ? event.step : 0
252 | }),
253 |
254 | reset: assign({
255 | currentStep: 0,
256 | activeToken: 0,
257 | zoomLevel: ZOOM_LEVELS.GLOBAL,
258 | selectedComponent: undefined,
259 | selectedLayer: undefined
260 | }),
261 |
262 | setActiveToken: assign({
263 | activeToken: ({ event, context }) => {
264 | if (event.type !== 'SET_ACTIVE_TOKEN') return context.activeToken;
265 | // Ensure token index is within sequence bounds
266 | const tokenIndex = event.tokenIndex;
267 | const maxIndex = Math.max(0, context.inputSequence.length - 1);
268 | return Math.min(Math.max(0, tokenIndex), maxIndex);
269 | }
270 | }),
271 |
272 | toggleFocusToken: createToggleAction('focusToken'),
273 |
274 | zoomIn: assign({
275 | zoomLevel: ({ context }) => {
276 | const levels = Object.values(ZOOM_LEVELS);
277 | const currentIndex = levels.indexOf(context.zoomLevel);
278 | const nextIndex = Math.min(currentIndex + 1, levels.length - 1);
279 | return levels[nextIndex];
280 | },
281 | selectedComponent: ({ event }) =>
282 | event.type === 'ZOOM_IN' ? event.component : undefined,
283 | selectedLayer: ({ event }) =>
284 | event.type === 'ZOOM_IN' ? event.layer : undefined
285 | }),
286 |
287 | zoomOut: assign({
288 | zoomLevel: ({ context }) => {
289 | const levels = Object.values(ZOOM_LEVELS);
290 | const currentIndex = levels.indexOf(context.zoomLevel);
291 | const prevIndex = Math.max(currentIndex - 1, 0);
292 | return levels[prevIndex];
293 | },
294 | selectedComponent: ({ context }) =>
295 | context.zoomLevel === ZOOM_LEVELS.GLOBAL ? undefined : context.selectedComponent,
296 | selectedLayer: ({ context }) =>
297 | context.zoomLevel === ZOOM_LEVELS.GLOBAL ? undefined : context.selectedLayer
298 | }),
299 |
300 | setZoomLevel: assign({
301 | zoomLevel: ({ event }) =>
302 | event.type === 'SET_ZOOM_LEVEL' ? event.level : ZOOM_LEVELS.GLOBAL,
303 | selectedComponent: ({ event }) =>
304 | event.type === 'SET_ZOOM_LEVEL' ? event.component : undefined
305 | }),
306 |
307 | updateParams: assign(({ context, event }) => {
308 | if (event.type !== 'UPDATE_PARAMS') return {};
309 |
310 | const newParams = { ...context.params, ...event.params };
311 |
312 | // Validate parameters
313 | if (newParams.numHeads > 0 && newParams.dModel % newParams.numHeads !== 0) {
314 | newParams.numHeads = Math.floor(newParams.dModel / 8) || 1;
315 | }
316 |
317 | // Handle sequence length changes with truncation/padding instead of regeneration
318 | const lengthChanged = event.params.seqLen && event.params.seqLen !== context.params.seqLen;
319 |
320 | if (lengthChanged) {
321 | const newLen = event.params.seqLen!;
322 |
323 | // Adjust existing sequences to new length instead of generating new ones
324 | const adjustedInputSequence = SequenceGenerator.adjustSequenceLength(context.inputSequence, newLen);
325 | const adjustedOutputSequence = SequenceGenerator.adjustSequenceLength(context.outputSequence, newLen);
326 |
327 | return {
328 | params: newParams,
329 | inputSequence: adjustedInputSequence,
330 | outputSequence: adjustedOutputSequence,
331 | // Keep existing sequence context since we're just adjusting length
332 | // Reset active token if it's now out of bounds
333 | activeToken: Math.min(context.activeToken, Math.max(0, newLen - 1))
334 | };
335 | }
336 |
337 | return {
338 | params: newParams
339 | };
340 | }),
341 |
342 | toggleSettings: createToggleAction('settingsOpen'),
343 |
344 | setPlaybackSpeed: assign({
345 | playbackSpeed: ({ event }) =>
346 | event.type === 'SET_PLAYBACK_SPEED' ? event.speed : 1.0
347 | }),
348 |
349 | toggleDarkMode: createToggleAction('darkMode'),
350 |
351 | toggleColorBlindMode: createToggleAction('colorBlindMode'),
352 |
353 | toggleAttentionLens: createToggleAction('attentionLensActive'),
354 |
355 | setAttentionTopK: assign({
356 | attentionTopK: ({ event }) =>
357 | event.type === 'SET_ATTENTION_TOP_K' ? event.k : 5
358 | }),
359 |
360 | animationTick: assign(({ context }) => {
361 | const now = performance.now();
362 | const delta = now - context.lastFrameTime;
363 | const currentFPS = delta > 0 ? Math.round(1000 / delta) : 60;
364 |
365 | // Auto-advance step based on playback speed
366 | const frameInterval = 1000 / 24; // 24 FPS as specified in PRD
367 | const adjustedInterval = frameInterval / context.playbackSpeed;
368 | const timeSinceLastStep = now - context.lastFrameTime;
369 |
370 | const currentStep = timeSinceLastStep >= adjustedInterval
371 | ? Math.min(context.currentStep + 1, context.totalSteps - 1)
372 | : context.currentStep;
373 |
374 | return {
375 | lastFrameTime: now,
376 | currentFPS,
377 | currentStep
378 | };
379 | }),
380 |
381 | setError: assign({
382 | error: ({ event }) =>
383 | event.type === 'ERROR' ? event.message : undefined
384 | }),
385 |
386 | clearError: assign({
387 | error: undefined
388 | }),
389 |
390 |
391 | }
392 | });
393 |
394 | export default transformerMachine;
--------------------------------------------------------------------------------
/src/utils/componentTransformations.ts:
--------------------------------------------------------------------------------
1 | import { ComponentSection } from '../types/events';
2 |
3 | // Helper function to check if component is Add & Norm
4 | const isAddNormComponent = (componentId: string): boolean => {
5 | return componentId.includes('-norm');
6 | };
7 |
8 | // Helper function to get operations for Add & Norm components
9 | const getAddNormOperations = (): string[] => {
10 | return ['Residual Addition', 'Layer Normalization'];
11 | };
12 |
13 | // Helper function to format component names
14 | export const formatComponentName = (id: string): string => {
15 | return id.split('-').map(word =>
16 | word.charAt(0).toUpperCase() + word.slice(1)
17 | ).join(' ');
18 | };
19 |
20 | // Helper function to get category colors
21 | export const getCategoryColor = (category: string): string => {
22 | const colors = {
23 | embedding: 'bg-blue-100 text-blue-800',
24 | attention: 'bg-purple-100 text-purple-800',
25 | ffn: 'bg-green-100 text-green-800',
26 | output: 'bg-orange-100 text-orange-800',
27 | tokens: 'bg-gray-100 text-gray-800',
28 | positional: 'bg-amber-100 text-amber-800',
29 | };
30 | return colors[category as keyof typeof colors] || 'bg-gray-100 text-gray-800';
31 | };
32 |
33 | // Helper function to get operations for each component type
34 | export const getOperations = (category: string, componentId: string): string[] => {
35 | switch (category) {
36 | case 'embedding':
37 | return ['Token Lookup', 'Positional Addition'];
38 | case 'attention':
39 | // Handle Add & Norm blocks for attention
40 | if (isAddNormComponent(componentId)) {
41 | return getAddNormOperations();
42 | }
43 | return ['Q/K/V Projection', 'Scaled Dot-Product', 'Softmax', 'Value Weighting'];
44 | case 'ffn':
45 | // Handle Add & Norm blocks for FFN
46 | if (isAddNormComponent(componentId)) {
47 | return getAddNormOperations();
48 | }
49 | return ['Linear Expansion', 'ReLU Activation', 'Linear Projection'];
50 | case 'output':
51 | // Handle encoder output specifically
52 | if (componentId === 'encoder-output') {
53 | return ['Contextual Representation', 'Multi-Layer Processing'];
54 | }
55 | // Handle final output (Linear + Softmax)
56 | if (componentId === 'final-output') {
57 | return ['Vocabulary Projection', 'Softmax Computation'];
58 | }
59 | // Handle individual linear layer
60 | if (componentId.includes('linear') || componentId.includes('lm-head')) {
61 | return ['Vocabulary Projection'];
62 | }
63 | // Handle individual softmax
64 | if (componentId.includes('softmax')) {
65 | return ['Softmax Computation'];
66 | }
67 | // Handle predicted tokens (final sampling)
68 | if (componentId === 'predicted-tokens') {
69 | return ['Token Sampling'];
70 | }
71 | return ['Vocabulary Projection', 'Softmax', 'Token Sampling'];
72 | case 'tokens':
73 | return ['Tokenization'];
74 | case 'positional':
75 | return ['Sinusoidal Encoding'];
76 | default:
77 | // Handle Add & Norm blocks specifically (fallback)
78 | if (isAddNormComponent(componentId)) {
79 | return getAddNormOperations();
80 | }
81 | return ['Component Processing'];
82 | }
83 | };
84 |
85 | // Transformation generators for each component category
86 | const generateEmbeddingTransformations = (): ComponentSection[] => [
87 | {
88 | id: 'embedding-lookup',
89 | label: 'Token Lookup',
90 | description: 'Lookup embedding vectors from learned embedding matrix',
91 | type: 'text',
92 | data: 'embedding_matrix[token_indices]',
93 | metadata: {
94 | operation: 'lookup',
95 | formula: 'E = W_e[tokens]',
96 | complexity: 'O(seq_len)',
97 | activation: 'none'
98 | }
99 | },
100 | {
101 | id: 'positional-addition',
102 | label: 'Positional Addition',
103 | description: 'Element-wise addition of positional encoding to embeddings',
104 | type: 'text',
105 | data: 'embeddings + positional_encoding',
106 | metadata: {
107 | operation: 'addition',
108 | formula: 'H = E + P',
109 | complexity: 'O(seq_len × d_model)',
110 | activation: 'none'
111 | }
112 | }
113 | ];
114 |
115 | const generateAttentionTransformations = (): ComponentSection[] => [
116 | {
117 | id: 'linear-projections',
118 | label: 'Q/K/V Projection',
119 | description: 'Project input to query, key, and value spaces',
120 | type: 'text',
121 | data: 'Q = XW_Q, K = XW_K, V = XW_V',
122 | metadata: {
123 | operation: 'linear',
124 | formula: 'Q, K, V = XW_Q, XW_K, XW_V',
125 | complexity: 'O(seq_len × d_model²)',
126 | activation: 'none'
127 | }
128 | },
129 | {
130 | id: 'attention-scores',
131 | label: 'Scaled Dot-Product',
132 | description: 'Compute attention scores using scaled dot-product',
133 | type: 'text',
134 | data: 'scores = QK^T / √d_k',
135 | metadata: {
136 | operation: 'dot-product',
137 | formula: 'Attention(Q,K,V) = softmax(QK^T/√d_k)V',
138 | complexity: 'O(seq_len²)',
139 | activation: 'none'
140 | }
141 | },
142 | {
143 | id: 'softmax-activation',
144 | label: 'Softmax',
145 | description: 'Apply softmax to get attention weights',
146 | type: 'text',
147 | data: 'attention_weights = softmax(scores)',
148 | metadata: {
149 | operation: 'normalization',
150 | formula: 'α_ij = exp(e_ij) / Σ_k exp(e_ik)',
151 | complexity: 'O(seq_len²)',
152 | activation: 'softmax'
153 | }
154 | },
155 | {
156 | id: 'weighted-combination',
157 | label: 'Value Weighting',
158 | description: 'Compute weighted combination of values',
159 | type: 'text',
160 | data: 'output = attention_weights × V',
161 | metadata: {
162 | operation: 'weighted-sum',
163 | formula: 'head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)',
164 | complexity: 'O(seq_len × d_model)',
165 | activation: 'none'
166 | }
167 | }
168 | ];
169 |
170 | const generateFFNTransformations = (): ComponentSection[] => [
171 | {
172 | id: 'first-linear',
173 | label: 'Linear Expansion',
174 | description: 'Expand to intermediate dimension',
175 | type: 'text',
176 | data: 'intermediate = XW₁ + b₁',
177 | metadata: {
178 | operation: 'linear',
179 | formula: 'FFN(x) = max(0, xW₁ + b₁)W₂ + b₂',
180 | complexity: 'O(seq_len × d_model × d_ff)',
181 | activation: 'none'
182 | }
183 | },
184 | {
185 | id: 'relu-activation',
186 | label: 'ReLU Activation',
187 | description: 'Apply rectified linear unit activation',
188 | type: 'text',
189 | data: 'activated = max(0, intermediate)',
190 | metadata: {
191 | operation: 'activation',
192 | formula: 'ReLU(x) = max(0, x)',
193 | complexity: 'O(seq_len × d_ff)',
194 | activation: 'relu'
195 | }
196 | },
197 | {
198 | id: 'second-linear',
199 | label: 'Linear Projection',
200 | description: 'Project back to model dimension',
201 | type: 'text',
202 | data: 'output = activatedW₂ + b₂',
203 | metadata: {
204 | operation: 'linear',
205 | formula: 'W₂ ∈ ℝ^(d_ff × d_model)',
206 | complexity: 'O(seq_len × d_ff × d_model)',
207 | activation: 'none'
208 | }
209 | }
210 | ];
211 |
212 | const generateAddNormTransformations = (): ComponentSection[] => [
213 | {
214 | id: 'residual-addition',
215 | label: 'Residual Addition',
216 | description: 'Add input to block output (skip connection)',
217 | type: 'text',
218 | data: 'residual_sum = X + block_output',
219 | metadata: {
220 | operation: 'addition',
221 | formula: 'X_residual = X + SubLayer(X)',
222 | complexity: 'O(seq_len × d_model)',
223 | activation: 'none'
224 | }
225 | },
226 | {
227 | id: 'layer-normalization',
228 | label: 'Layer Normalization',
229 | description: 'Normalize across feature dimension',
230 | type: 'text',
231 | data: 'normalized = LayerNorm(residual_sum)',
232 | metadata: {
233 | operation: 'normalization',
234 | formula: 'LayerNorm(x) = γ ⊙ (x - μ)/σ + β',
235 | complexity: 'O(seq_len × d_model)',
236 | activation: 'none'
237 | }
238 | }
239 | ];
240 |
241 | const generateOutputTransformations = (componentId: string): ComponentSection[] => {
242 | const transformations: ComponentSection[] = [];
243 |
244 | if (componentId.includes('linear') || componentId.includes('lm-head')) {
245 | transformations.push({
246 | id: 'vocab-projection',
247 | label: 'Vocabulary Projection',
248 | description: 'Project hidden states to vocabulary size',
249 | type: 'text',
250 | data: 'logits = hidden_states × W_vocab',
251 | metadata: {
252 | operation: 'linear',
253 | formula: 'logits = HW_v + b_v',
254 | complexity: 'O(seq_len × d_model × vocab_size)',
255 | activation: 'none'
256 | }
257 | });
258 | }
259 |
260 | if (componentId.includes('softmax')) {
261 | transformations.push({
262 | id: 'softmax-computation',
263 | label: 'Softmax Computation',
264 | description: 'Convert logits to probability distribution',
265 | type: 'text',
266 | data: 'probs = softmax(logits)',
267 | metadata: {
268 | operation: 'normalization',
269 | formula: 'p_i = exp(logit_i) / Σ_j exp(logit_j)',
270 | complexity: 'O(seq_len × vocab_size)',
271 | activation: 'softmax'
272 | }
273 | });
274 | }
275 |
276 | if (componentId.includes('predicted')) {
277 | transformations.push({
278 | id: 'token-sampling',
279 | label: 'Token Sampling',
280 | description: 'Sample tokens from probability distribution',
281 | type: 'text',
282 | data: 'token = sample(probabilities)',
283 | metadata: {
284 | operation: 'sampling',
285 | formula: 'token ~ Categorical(p)',
286 | complexity: 'O(seq_len)',
287 | activation: 'none'
288 | }
289 | });
290 | }
291 |
292 | return transformations;
293 | };
294 |
295 | // Generate transformations for encoder output
296 | const generateEncoderOutputTransformations = (): ComponentSection[] => [
297 | {
298 | id: 'contextualized-representations',
299 | label: 'Contextual Representation',
300 | description: 'Final encoder representations with full contextual information',
301 | type: 'text',
302 | data: 'encoder_output = LayerNorm(H_N + Attention(H_N))',
303 | metadata: {
304 | operation: 'normalization',
305 | formula: 'H_enc = LayerNorm(H_N + MultiHeadAttention(H_N))',
306 | complexity: 'O(seq_len × d_model)',
307 | activation: 'none'
308 | }
309 | },
310 | {
311 | id: 'multi-layer-processing',
312 | label: 'Multi-Layer Processing',
313 | description: 'Result of processing through all encoder layers',
314 | type: 'text',
315 | data: 'for layer in encoder_layers: H = layer(H)',
316 | metadata: {
317 | operation: 'composition',
318 | formula: 'H_enc = EncoderLayer_N(...EncoderLayer_1(H_0))',
319 | complexity: 'O(N × seq_len × d_model²)',
320 | activation: 'various'
321 | }
322 | }
323 | ];
324 |
325 | // Generate transformations for final output (Linear + Softmax)
326 | const generateFinalOutputTransformations = (): ComponentSection[] => [
327 | {
328 | id: 'vocabulary-projection',
329 | label: 'Vocabulary Projection',
330 | description: 'Project decoder states to vocabulary logits',
331 | type: 'text',
332 | data: 'logits = decoder_output × W_vocab + b_vocab',
333 | metadata: {
334 | operation: 'linear',
335 | formula: 'logits = H_dec × W_v + b_v',
336 | complexity: 'O(seq_len × d_model × vocab_size)',
337 | activation: 'none'
338 | }
339 | },
340 | {
341 | id: 'softmax-computation',
342 | label: 'Softmax Computation',
343 | description: 'Convert logits to probability distribution over vocabulary',
344 | type: 'text',
345 | data: 'probabilities = softmax(logits)',
346 | metadata: {
347 | operation: 'normalization',
348 | formula: 'P(token_i) = exp(logit_i) / Σ_j exp(logit_j)',
349 | complexity: 'O(seq_len × vocab_size)',
350 | activation: 'softmax'
351 | }
352 | }
353 | ];
354 |
355 | // Main function to generate detailed transformation data
356 | export const generateDetailedTransformations = (category: string, componentId: string): ComponentSection[] => {
357 | switch (category) {
358 | case 'embedding':
359 | return generateEmbeddingTransformations();
360 | case 'attention':
361 | // Handle Add & Norm blocks for attention
362 | if (isAddNormComponent(componentId)) {
363 | return generateAddNormTransformations();
364 | }
365 | return generateAttentionTransformations();
366 | case 'ffn':
367 | // Handle Add & Norm blocks for FFN
368 | if (isAddNormComponent(componentId)) {
369 | return generateAddNormTransformations();
370 | }
371 | return generateFFNTransformations();
372 | case 'output':
373 | // Handle specific output components
374 | if (componentId === 'encoder-output') {
375 | return generateEncoderOutputTransformations();
376 | }
377 | if (componentId === 'final-output') {
378 | return generateFinalOutputTransformations();
379 | }
380 | return generateOutputTransformations(componentId);
381 | default:
382 | // Handle Add & Norm blocks specifically (fallback)
383 | if (isAddNormComponent(componentId)) {
384 | return generateAddNormTransformations();
385 | }
386 | // Default transformation for other components
387 | return [{
388 | id: 'default-transform',
389 | label: 'Component Processing',
390 | description: 'Process inputs through this component',
391 | type: 'text',
392 | data: 'output = f(inputs, parameters)',
393 | metadata: {
394 | operation: 'function',
395 | formula: 'y = f(x, θ)',
396 | complexity: 'O(n)',
397 | activation: 'varies'
398 | }
399 | }];
400 | }
401 | };
--------------------------------------------------------------------------------
/src/components/ComponentDetailsPanel.tsx:
--------------------------------------------------------------------------------
1 | import { useState } from 'react';
2 | import { EyeIcon, DocumentTextIcon, ChevronDownIcon, ChevronRightIcon, ListBulletIcon } from '@heroicons/react/24/outline';
3 | import { ComponentData, type ComponentSection } from '../types/events';
4 | import { MatrixVisualization } from './MatrixVisualization';
5 | import { TokenVisualization } from './TokenVisualization';
6 | import { formatComponentName, getCategoryColor, getOperations, generateDetailedTransformations } from '../utils/componentTransformations';
7 |
8 | interface ComponentDetailsPanelProps {
9 | componentId: string;
10 | componentData: ComponentData;
11 | onClose: () => void;
12 | }
13 |
14 | export function ComponentDetailsPanel({
15 | componentId,
16 | componentData,
17 | onClose
18 | }: ComponentDetailsPanelProps) {
19 | const [viewMode, setViewMode] = useState<'textual' | 'visual'>('textual');
20 | const [expandedSections, setExpandedSections] = useState<Set<string>>(new Set());
21 | const [showTransformationModal, setShowTransformationModal] = useState(false);
22 |
23 | const toggleSection = (sectionId: string) => {
24 | const newExpanded = new Set(expandedSections);
25 | if (newExpanded.has(sectionId)) {
26 | newExpanded.delete(sectionId);
27 | } else {
28 | newExpanded.add(sectionId);
29 | }
30 | setExpandedSections(newExpanded);
31 | };
32 |
33 |
34 |
35 |
36 |
37 | const operations = getOperations(componentData.category, componentId);
38 |
39 | return (
40 |
41 | {/* Header */}
42 |
43 |
44 |
45 |
46 | Component Details
47 |
48 |
49 | {componentData.category}
50 |
51 |
52 |
58 |
59 |
60 |
61 |
62 | {formatComponentName(componentId)}
63 |
64 |
65 | {componentData.description}
66 |
67 |
68 |
69 | {/* Operations Section */}
70 |
71 |
72 |
Operations
73 |
81 |
82 |
83 | {operations.map((operation, index) => (
84 | -
85 | {index + 1}.
86 | {operation}
87 |
88 | ))}
89 |
90 |
91 |
92 | {/* View Mode Toggle */}
93 |
94 |
105 |
116 |
117 |
118 |
119 | {/* Content */}
120 |
121 | {/* Inputs Section - Only show if there are inputs */}
122 | {componentData.inputs.length > 0 && (
123 | toggleSection('inputs')}
129 | emptyMessage="No inputs (starting component)"
130 | />
131 | )}
132 |
133 | {/* Parameters Section - Only show if there are parameters */}
134 | {componentData.parameters.length > 0 && (
135 | toggleSection('parameters')}
141 | emptyMessage="No learnable parameters"
142 | />
143 | )}
144 |
145 | {/* Outputs Section - Only show if there are outputs */}
146 | {componentData.outputs.length > 0 && (
147 | toggleSection('outputs')}
153 | emptyMessage="No outputs (final component)"
154 | />
155 | )}
156 |
157 |
158 | {/* Transformation Detail Modal */}
159 | {showTransformationModal && (
160 |
setShowTransformationModal(false)}
164 | />
165 | )}
166 |
167 | );
168 | }
169 |
170 | interface ComponentSectionProps {
171 | title: string;
172 | sections: ComponentSection[];
173 | viewMode: 'textual' | 'visual';
174 | isExpanded: boolean;
175 | onToggle: () => void;
176 | emptyMessage: string;
177 | }
178 |
179 | function ComponentSectionContainer({
180 | title,
181 | sections,
182 | viewMode,
183 | isExpanded,
184 | onToggle,
185 | emptyMessage
186 | }: ComponentSectionProps) {
187 | const sectionColors = {
188 | Inputs: 'border-blue-200 bg-blue-50',
189 | Parameters: 'border-purple-200 bg-purple-50',
190 | Outputs: 'border-green-200 bg-green-50',
191 | };
192 |
193 | const iconColors = {
194 | Inputs: 'text-blue-600',
195 | Parameters: 'text-purple-600',
196 | Outputs: 'text-green-600',
197 | };
198 |
199 | // Helper function to get brief summary of items
200 | const getBriefSummary = () => {
201 | if (sections.length === 0) {
202 | return [];
203 | }
204 |
205 | return sections.map(section => {
206 | let summary = section.label;
207 |
208 | // Add shape information if available
209 | if (section.shape) {
210 | const shapeStr = section.shape.length === 2
211 | ? `${section.shape[0]}×${section.shape[1]}`
212 | : `${section.shape[0]}`;
213 | summary += `, ${shapeStr}`;
214 | }
215 |
216 | return summary;
217 | });
218 | };
219 |
220 | return (
221 |
222 |
236 |
237 | {/* Brief summary when collapsed */}
238 | {!isExpanded && sections.length > 0 && (
239 |
240 |
241 | {getBriefSummary().map((summary, index) => (
242 |
243 | •
244 | {summary}
245 |
246 | ))}
247 |
248 |
249 | )}
250 |
251 | {isExpanded && (
252 |
253 | {sections.length === 0 ? (
254 |
255 |
{emptyMessage}
256 |
257 | ) : (
258 |
259 | {sections.map((section) => (
260 |
265 | ))}
266 |
267 | )}
268 |
269 | )}
270 |
271 | );
272 | }
273 |
274 | interface SectionItemProps {
275 | section: ComponentSection;
276 | viewMode: 'textual' | 'visual';
277 | }
278 |
279 | function SectionItem({ section, viewMode }: SectionItemProps) {
280 | const renderContent = () => {
281 | if (viewMode === 'textual') {
282 | return (
283 |
284 |
285 |
{section.label}
286 | {section.shape && (
287 |
288 | {section.shape.length === 2 ? `${section.shape[0]}×${section.shape[1]}` : `${section.shape[0]}`}
289 |
290 | )}
291 |
292 |
{section.description}
293 | {section.metadata && (
294 |
295 | {Object.entries(section.metadata).map(([key, value]) => (
296 |
297 | {key}:
298 | {String(value)}
299 |
300 | ))}
301 |
302 | )}
303 |
304 | );
305 | } else {
306 | // Visual mode
307 | if (section.type === 'tokens') {
308 | return (
309 |
313 | );
314 | } else if (section.type === 'matrix' || section.type === 'vector') {
315 | return (
316 |
321 | );
322 | } else if (section.type === 'scalar') {
323 | return (
324 |
325 |
326 |
327 | {typeof section.data === 'number' ? section.data.toFixed(4) : section.data}
328 |
329 |
{section.label}
330 |
331 |
332 | );
333 | } else {
334 | return (
335 |
336 |
{section.description}
337 |
338 | );
339 | }
340 | }
341 | };
342 |
343 | return (
344 |
345 | {renderContent()}
346 |
347 | );
348 | }
349 |
350 | // Detailed transformation modal (keep existing implementation)
351 | interface TransformationDetailModalProps {
352 | componentName: string;
353 | transformations: ComponentSection[];
354 | onClose: () => void;
355 | }
356 |
357 | function TransformationDetailModal({
358 | componentName,
359 | transformations,
360 | onClose
361 | }: TransformationDetailModalProps) {
362 | const getOperationIcon = (operation: string) => {
363 | switch (operation) {
364 | case 'linear':
365 | case 'lookup':
366 | return '⊗';
367 | case 'addition':
368 | return '+';
369 | case 'dot-product':
370 | return '⋅';
371 | case 'normalization':
372 | return 'σ';
373 | case 'activation':
374 | return 'f';
375 | case 'weighted-sum':
376 | return '∑';
377 | case 'sampling':
378 | return '🎲';
379 | default:
380 | return '⚙️';
381 | }
382 | };
383 |
384 | const getOperationColor = (operation: string) => {
385 | switch (operation) {
386 | case 'linear':
387 | case 'lookup':
388 | return 'bg-blue-100 text-blue-800 border-blue-300';
389 | case 'addition':
390 | return 'bg-green-100 text-green-800 border-green-300';
391 | case 'dot-product':
392 | return 'bg-purple-100 text-purple-800 border-purple-300';
393 | case 'normalization':
394 | return 'bg-yellow-100 text-yellow-800 border-yellow-300';
395 | case 'activation':
396 | return 'bg-red-100 text-red-800 border-red-300';
397 | case 'weighted-sum':
398 | return 'bg-indigo-100 text-indigo-800 border-indigo-300';
399 | case 'sampling':
400 | return 'bg-pink-100 text-pink-800 border-pink-300';
401 | default:
402 | return 'bg-gray-100 text-gray-800 border-gray-300';
403 | }
404 | };
405 |
406 | const getActivationBadge = (activation: string) => {
407 | if (!activation || activation === 'none') return null;
408 |
409 | const activationColors = {
410 | 'relu': 'bg-red-500 text-white',
411 | 'softmax': 'bg-blue-500 text-white',
412 | 'tanh': 'bg-green-500 text-white',
413 | 'sigmoid': 'bg-purple-500 text-white',
414 | 'gelu': 'bg-orange-500 text-white',
415 | };
416 |
417 | return (
418 |
419 | {activation.toUpperCase()}
420 |
421 | );
422 | };
423 |
424 | return (
425 |
426 |
427 | {/* Header */}
428 |
429 |
430 |
431 |
{componentName}
432 |
Detailed Transformation Flow
433 |
434 |
440 |
441 |
442 |
443 | {/* Content */}
444 |
445 |
446 | {transformations.map((transformation, index) => (
447 |
448 |
449 |
450 | {/* Step indicator */}
451 |
452 |
453 | {getOperationIcon(transformation.metadata?.operation || 'default')}
454 |
455 | {index < transformations.length - 1 && (
456 |
457 | )}
458 |
459 |
460 | {/* Content */}
461 |
462 |
463 |
464 |
{transformation.label}
465 |
{transformation.description}
466 |
467 |
468 | {getActivationBadge(transformation.metadata?.activation)}
469 |
470 | {transformation.metadata?.complexity}
471 |
472 |
473 |
474 |
475 | {/* Code and formula */}
476 |
477 |
478 |
Operation:
479 |
{transformation.data}
480 |
481 |
482 | {transformation.metadata?.formula && (
483 |
484 |
Mathematical Formula:
485 |
{transformation.metadata.formula}
486 |
487 | )}
488 |
489 |
490 |
491 |
492 |
493 | ))}
494 |
495 |
496 |
497 | {/* Footer */}
498 |
499 |
500 | Click outside the modal or the ✕ button to close
501 |
502 |
503 |
504 |
505 | );
506 | }
--------------------------------------------------------------------------------
/src/hooks/useTransformerDiagram.ts:
--------------------------------------------------------------------------------
1 | import { useRef, useEffect, RefObject } from 'react';
2 | import * as d3 from 'd3';
3 | import { TransformerParams } from '../types/events';
4 | import { COLORS } from '../utils/constants';
5 |
6 | interface ComponentPosition {
7 | x: number;
8 | y: number;
9 | width: number;
10 | height: number;
11 | id: string;
12 | }
13 |
14 | const drawImprovedTransformerArchitecture = (
15 | g: d3.Selection<SVGGElement, unknown, null, undefined>,
16 | params: TransformerParams,
17 | selectedComponent: string | null,
18 | onComponentClick: (componentId: string) => void,
19 | width: number,
20 | height: number,
21 | ) => {
22 | const componentWidth = 120;
23 | const componentHeight = 35;
24 | const positionalSize = 28; // Size for circular positional block
25 |
26 | const sectionWidth = Math.max(componentWidth + 40, Math.min(300, width * 0.45));
27 |
28 | const encoderX = width * 0.25;
29 | const decoderX = width * 0.75;
30 |
31 | const components: ComponentPosition[] = [];
32 |
33 | // CodeSignal brand colors for consistent visual identity
34 | const styles = {
35 | input: { color: '#6b7280', textColor: 'white', canClick: true }, // Gray 500
36 | embedding: { color: '#1062FB', textColor: 'white', canClick: true }, // Beyond Blue (Primary)
37 | positional: { color: '#FFC500', textColor: 'black', canClick: true }, // Gold Standard Yellow (Secondary)
38 | attention: { color: '#E6193F', textColor: 'white', canClick: true }, // Ruby Red (Tertiary)
39 | norm: { color: '#21CF82', textColor: 'white', canClick: true }, // Victory Green (Tertiary)
40 | ffn: { color: '#FF7232', textColor: 'white', canClick: true }, // FiresIDE Orange (Tertiary)
41 | output: { color: '#002570', textColor: 'white', canClick: true }, // Dive Deep Blue (Secondary)
42 | decoder: { color: '#002570', textColor: 'white', canClick: true }, // Dive Deep Blue (Secondary)
43 | };
44 |
45 | const CROSS_COLOR = '#64D3FF'; // SignalLight Blue (Secondary)
46 | const RESIDUAL_COLOR = '#9ca3af'; // Gray 400
47 |
48 | const createComponent = (
49 | id: string,
50 | label: string,
51 | type: string,
52 | x: number,
53 | y: number,
54 | width = componentWidth,
55 | height = componentHeight,
56 | ): ComponentPosition => {
57 | const isSelected = selectedComponent === id;
58 | const style = styles[type as keyof typeof styles];
59 |
60 | const group = g
61 | .append('g')
62 | .attr('class', 'component')
63 | .style('cursor', style.canClick ? 'pointer' : 'default')
64 | .on('click', () => style.canClick && onComponentClick(id));
65 |
66 | group
67 | .append('rect')
68 | .attr('x', x - width / 2 + 2)
69 | .attr('y', y - height / 2 + 2)
70 | .attr('width', width)
71 | .attr('height', height)
72 | .attr('rx', 6)
73 | .attr('fill', '#00000020')
74 | .attr('opacity', 0.3);
75 |
76 | group
77 | .append('rect')
78 | .attr('x', x - width / 2)
79 | .attr('y', y - height / 2)
80 | .attr('width', width)
81 | .attr('height', height)
82 | .attr('rx', 6)
83 | .attr('fill', style.color)
84 | .attr('stroke', isSelected ? COLORS.gold : 'none')
85 | .attr('stroke-width', isSelected ? 3 : 0)
86 | .style('filter', isSelected ? 'drop-shadow(0 0 10px rgba(255, 197, 0, 0.5))' : 'none');
87 |
88 | const lines = label.split('\n');
89 | const fontSize = lines.length > 1 ? 10 : 11;
90 | lines.forEach((line, i) => {
91 | group
92 | .append('text')
93 | .attr('x', x)
94 | .attr('y', y + (i - (lines.length - 1) / 2) * 12)
95 | .attr('text-anchor', 'middle')
96 | .attr('dy', '0.35em')
97 | .attr('font-size', fontSize)
98 | .attr('font-weight', isSelected ? 'bold' : '500')
99 | .attr('fill', style.textColor)
100 | .text(line);
101 | });
102 |
103 | const position = { x, y, width, height, id };
104 | components.push(position);
105 | return position;
106 | };
107 |
108 | const createPositionalComponent = (
109 | id: string,
110 | x: number,
111 | y: number,
112 | size = positionalSize,
113 | ): ComponentPosition => {
114 | const isSelected = selectedComponent === id;
115 | const style = styles.positional;
116 |
117 | const group = g
118 | .append('g')
119 | .attr('class', 'component')
120 | .style('cursor', style.canClick ? 'pointer' : 'default')
121 | .on('click', () => style.canClick && onComponentClick(id));
122 |
123 | // Drop shadow
124 | group
125 | .append('circle')
126 | .attr('cx', x + 2)
127 | .attr('cy', y + 2)
128 | .attr('r', size / 2)
129 | .attr('fill', '#00000020')
130 | .attr('opacity', 0.3);
131 |
132 | // Main circle
133 | group
134 | .append('circle')
135 | .attr('cx', x)
136 | .attr('cy', y)
137 | .attr('r', size / 2)
138 | .attr('fill', style.color)
139 | .attr('stroke', isSelected ? COLORS.gold : 'none')
140 | .attr('stroke-width', isSelected ? 3 : 0)
141 | .style('filter', isSelected ? 'drop-shadow(0 0 10px rgba(255, 197, 0, 0.5))' : 'none');
142 |
143 | // Position icon (clock-like symbol)
144 | const iconGroup = group.append('g');
145 |
146 | // Clock face
147 | iconGroup
148 | .append('circle')
149 | .attr('cx', x)
150 | .attr('cy', y)
151 | .attr('r', size / 4)
152 | .attr('fill', 'none')
153 | .attr('stroke', style.textColor)
154 | .attr('stroke-width', 1.5);
155 |
156 | // Clock hands
157 | iconGroup
158 | .append('line')
159 | .attr('x1', x)
160 | .attr('y1', y)
161 | .attr('x2', x)
162 | .attr('y2', y - size / 6)
163 | .attr('stroke', style.textColor)
164 | .attr('stroke-width', 1.5)
165 | .attr('stroke-linecap', 'round');
166 |
167 | iconGroup
168 | .append('line')
169 | .attr('x1', x)
170 | .attr('y1', y)
171 | .attr('x2', x + size / 8)
172 | .attr('y2', y)
173 | .attr('stroke', style.textColor)
174 | .attr('stroke-width', 1.5)
175 | .attr('stroke-linecap', 'round');
176 |
177 | // Center dot
178 | iconGroup
179 | .append('circle')
180 | .attr('cx', x)
181 | .attr('cy', y)
182 | .attr('r', 1.5)
183 | .attr('fill', style.textColor);
184 |
185 | const position = { x, y, width: size, height: size, id };
186 | components.push(position);
187 | return position;
188 | };
189 |
190 | const drawArrow = (from: ComponentPosition, to: ComponentPosition, type: 'main' | 'residual' | 'cross' | 'side' = 'main') => {
191 | let fromX, fromY, toX, toY;
192 |
193 | if (type === 'residual') {
194 | const isEncoder = from.x < width / 2;
195 | if (isEncoder) {
196 | fromX = from.x - from.width / 2;
197 | fromY = from.y + from.height / 4;
198 | toX = to.x - to.width / 2;
199 | toY = to.y - to.height / 4;
200 | } else {
201 | fromX = from.x + from.width / 2;
202 | fromY = from.y + from.height / 4;
203 | toX = to.x + to.width / 2;
204 | toY = to.y - to.height / 4;
205 | }
206 | } else if (type === 'cross') {
207 | fromX = from.x + from.width / 2;
208 | fromY = from.y;
209 | toX = to.x - to.width / 2;
210 | toY = to.y;
211 | } else if (type === 'side') {
212 | // Side-to-side arrow for positional to embedding
213 | const isEncoder = from.x < width / 2;
214 | if (isEncoder) {
215 | // Encoder: positional (right) to embedding (left) - arrow goes from left side of positional to right side of embedding
216 | fromX = from.x - from.width / 2;
217 | fromY = from.y;
218 | toX = to.x + to.width / 2;
219 | toY = to.y;
220 | } else {
221 | // Decoder: positional (left) to embedding (right) - arrow goes from right side of positional to left side of embedding
222 | fromX = from.x + from.width / 2;
223 | fromY = from.y;
224 | toX = to.x - to.width / 2;
225 | toY = to.y;
226 | }
227 | } else {
228 | fromX = from.x;
229 | fromY = from.y + from.height / 2;
230 | toX = to.x;
231 | toY = to.y - to.height / 2;
232 | }
233 |
234 | const strokeColor = type === 'residual' ? RESIDUAL_COLOR : type === 'cross' ? CROSS_COLOR : '#64748b';
235 | const strokeWidth = type === 'residual' ? 2 : type === 'cross' ? 2 : 1.5;
236 |
237 | const path = g
238 | .append('path')
239 | .attr('stroke', strokeColor)
240 | .attr('stroke-width', strokeWidth)
241 | .attr('fill', 'none')
242 | .attr('marker-end', `url(#arrowhead-${type === 'side' ? 'main' : type})`);
243 |
244 | if (type === 'cross') {
245 | const midX = (fromX + toX) / 2;
246 | path.attr('d', `M ${fromX} ${fromY} L ${midX} ${fromY} L ${midX} ${toY} L ${toX} ${toY}`);
247 | } else if (type === 'residual') {
248 | const sideOffset = 30;
249 | const isEncoder = from.x < width / 2;
250 | const sideX = isEncoder ? fromX - sideOffset : fromX + sideOffset;
251 | const pathData = [`M ${fromX} ${fromY}`, `L ${sideX} ${fromY}`, `L ${sideX} ${toY}`, `L ${toX} ${toY}`].join(' ');
252 | path.attr('d', pathData);
253 | } else {
254 | path.attr('d', `M ${fromX} ${fromY} L ${toX} ${toY}`);
255 | }
256 | };
257 |
258 | const defs = g.append('defs');
259 | ['main', 'residual', 'cross'].forEach((type) => {
260 | const color = type === 'residual' ? RESIDUAL_COLOR : type === 'cross' ? CROSS_COLOR : '#64748b';
261 | defs
262 | .append('marker')
263 | .attr('id', `arrowhead-${type}`)
264 | .attr('viewBox', '0 0 10 10')
265 | .attr('refX', 8)
266 | .attr('refY', 3)
267 | .attr('markerWidth', 6)
268 | .attr('markerHeight', 6)
269 | .attr('orient', 'auto')
270 | .append('path')
271 | .attr('d', 'M 0 0 L 10 3 L 0 6 z')
272 | .attr('fill', color);
273 | });
274 |
275 | g.append('rect')
276 | .attr('x', encoderX - sectionWidth / 2)
277 | .attr('y', 30)
278 | .attr('width', sectionWidth)
279 | .attr('height', height - 60)
280 | .attr('fill', '#3b82f6')
281 | .attr('fill-opacity', 0.2)
282 | .attr('stroke', '#3b82f6')
283 | .attr('stroke-width', 3)
284 | .attr('stroke-dasharray', '8,4')
285 | .attr('rx', 12);
286 |
287 | g.append('rect')
288 | .attr('x', decoderX - sectionWidth / 2)
289 | .attr('y', 30)
290 | .attr('width', sectionWidth)
291 | .attr('height', height - 60)
292 | .attr('fill', '#8b5cf6')
293 | .attr('fill-opacity', 0.2)
294 | .attr('stroke', '#8b5cf6')
295 | .attr('stroke-width', 3)
296 | .attr('stroke-dasharray', '8,4')
297 | .attr('rx', 12);
298 |
299 | const encoderOutput = drawEncoderStack(g, params, createComponent, createPositionalComponent, drawArrow, encoderX, sectionWidth);
300 | drawDecoderStack(g, params, createComponent, createPositionalComponent, drawArrow, decoderX, encoderOutput, sectionWidth);
301 |
302 | drawArrow(
303 | encoderOutput,
304 | components.find((c) => c.id.startsWith('decoder-layer-0-cross-attention'))!,
305 | 'cross',
306 | );
307 | };
308 |
309 | const drawEncoderStack = (
310 | g: d3.Selection<SVGGElement, unknown, null, undefined>,
311 | params: TransformerParams,
312 | createComponent: (id: string, label: string, type: string, x: number, y: number, width?: number, height?: number) => ComponentPosition,
313 | createPositionalComponent: (id: string, x: number, y: number, size?: number) => ComponentPosition,
314 | drawArrow: (from: ComponentPosition, to: ComponentPosition, type?: 'main' | 'residual' | 'cross' | 'side') => void,
315 | encoderX: number,
316 | sectionWidth: number,
317 | ): ComponentPosition => {
318 | const { numLayers } = params;
319 | const intraLayerSpacing = 20;
320 | const interLayerSpacing = 30;
321 | const componentHeight = 35;
322 | const positionalSize = 28;
323 | const initialY = 80;
324 |
325 | const inputTokens = createComponent('input-tokens', 'Input Tokens', 'input', encoderX, initialY);
326 |
327 | // Create separate embedding and positional components
328 | const embeddingY = inputTokens.y + componentHeight + interLayerSpacing;
329 | const embedding = createComponent('embedding', 'Embedding', 'embedding', encoderX, embeddingY);
330 |
331 | // Positional component to the right of embedding in encoder with more spacing
332 | const positional = createPositionalComponent('positional-encoding', encoderX + 100, embeddingY, positionalSize);
333 |
334 | // Arrows: input -> embedding, positional -> embedding (side)
335 | drawArrow(inputTokens, embedding);
336 | drawArrow(positional, embedding, 'side');
337 |
338 | let prevLayerOutput = embedding;
339 | let currentY = embedding.y + componentHeight / 2 + interLayerSpacing;
340 |
341 | for (let i = 0; i < numLayers; i++) {
342 | const layerGroup = g.append('g').attr('class', 'encoder-layer-group');
343 | const layerHeight = 4 * componentHeight + 5 * intraLayerSpacing;
344 | const layerY = currentY;
345 |
346 | layerGroup
347 | .append('rect')
348 | .attr('x', encoderX - sectionWidth / 2 + 10)
349 | .attr('y', layerY - intraLayerSpacing / 2)
350 | .attr('width', sectionWidth - 20)
351 | .attr('height', layerHeight)
352 | .attr('rx', 8)
353 | .attr('fill', 'rgba(255, 255, 255, 0.4)')
354 | .attr('stroke', '#9ca3af')
355 | .attr('stroke-dasharray', '3,3');
356 |
357 | const attention = createComponent(
358 | `encoder-layer-${i}-attention`,
359 | `Multi-Head\nAttention`,
360 | 'attention',
361 | encoderX,
362 | layerY + intraLayerSpacing + componentHeight / 2,
363 | );
364 | drawArrow(prevLayerOutput, attention);
365 |
366 | const attentionNorm = createComponent(`encoder-layer-${i}-attention-norm`, 'Add & Norm', 'norm', encoderX, attention.y + componentHeight + intraLayerSpacing);
367 | drawArrow(attention, attentionNorm);
368 | drawArrow(prevLayerOutput, attentionNorm, 'residual');
369 |
370 | const ffn = createComponent(`encoder-layer-${i}-ffn`, 'Feed Forward', 'ffn', encoderX, attentionNorm.y + componentHeight + intraLayerSpacing);
371 | drawArrow(attentionNorm, ffn);
372 |
373 | const ffnNorm = createComponent(`encoder-layer-${i}-ffn-norm`, 'Add & Norm', 'norm', encoderX, ffn.y + componentHeight + intraLayerSpacing);
374 | drawArrow(ffn, ffnNorm);
375 | drawArrow(attentionNorm, ffnNorm, 'residual');
376 |
377 | prevLayerOutput = ffnNorm;
378 | currentY = ffnNorm.y + componentHeight / 2 + interLayerSpacing;
379 | }
380 |
381 | const encoderOutput = createComponent(
382 | 'encoder-output',
383 | 'Encoder Output',
384 | 'output',
385 | encoderX,
386 | currentY + componentHeight / 2 + intraLayerSpacing,
387 | );
388 | drawArrow(prevLayerOutput, encoderOutput);
389 |
390 | return encoderOutput;
391 | };
392 |
393 | const drawDecoderStack = (
394 | g: d3.Selection<SVGGElement, unknown, null, undefined>,
395 | params: TransformerParams,
396 | createComponent: (id: string, label: string, type: string, x: number, y: number, width?: number, height?: number) => ComponentPosition,
397 | createPositionalComponent: (id: string, x: number, y: number, size?: number) => ComponentPosition,
398 | drawArrow: (from: ComponentPosition, to: ComponentPosition, type?: 'main' | 'residual' | 'cross' | 'side') => void,
399 | decoderX: number,
400 | encoderOutput: ComponentPosition,
401 | sectionWidth: number,
402 | ) => {
403 | const { numLayers } = params;
404 | const intraLayerSpacing = 20;
405 | const interLayerSpacing = 30;
406 | const componentHeight = 35;
407 | const positionalSize = 28;
408 | const initialY = 80;
409 |
410 | const outputTokens = createComponent('output-tokens', 'Output Tokens', 'input', decoderX, initialY);
411 |
412 | // Create separate embedding and positional components
413 | const embeddingY = outputTokens.y + componentHeight + interLayerSpacing;
414 | const embedding = createComponent('decoder-embedding', 'Embedding', 'embedding', decoderX, embeddingY);
415 |
416 | // Positional component to the left of embedding in decoder with more spacing
417 | const positional = createPositionalComponent('decoder-positional-encoding', decoderX - 100, embeddingY, positionalSize);
418 |
419 | // Arrows: output -> embedding, positional -> embedding (side)
420 | drawArrow(outputTokens, embedding);
421 | drawArrow(positional, embedding, 'side');
422 |
423 | let prevLayerOutput = embedding;
424 | let currentY = embedding.y + componentHeight / 2 + interLayerSpacing;
425 |
426 | for (let i = 0; i < numLayers; i++) {
427 | const layerGroup = g.append('g').attr('class', 'decoder-layer-group');
428 | const layerHeight = 6 * componentHeight + 7 * intraLayerSpacing;
429 | const layerY = currentY;
430 |
431 | layerGroup
432 | .append('rect')
433 | .attr('x', decoderX - sectionWidth / 2 + 10)
434 | .attr('y', layerY - intraLayerSpacing / 2)
435 | .attr('width', sectionWidth - 20)
436 | .attr('height', layerHeight)
437 | .attr('rx', 8)
438 | .attr('fill', 'rgba(255, 255, 255, 0.4)')
439 | .attr('stroke', '#9ca3af')
440 | .attr('stroke-dasharray', '3,3');
441 |
442 | const maskedAttention = createComponent(
443 | `decoder-layer-${i}-masked-attention`,
444 | `Masked Multi-Head\nAttention`,
445 | 'attention',
446 | decoderX,
447 | layerY + intraLayerSpacing + componentHeight / 2,
448 | );
449 | drawArrow(prevLayerOutput, maskedAttention);
450 |
451 | const maskedAttentionNorm = createComponent(`decoder-layer-${i}-masked-attention-norm`, 'Add & Norm', 'norm', decoderX, maskedAttention.y + componentHeight + intraLayerSpacing);
452 | drawArrow(maskedAttention, maskedAttentionNorm);
453 | drawArrow(prevLayerOutput, maskedAttentionNorm, 'residual');
454 |
455 | const crossAttention = createComponent(`decoder-layer-${i}-cross-attention`, `Cross-Attention`, 'attention', decoderX, maskedAttentionNorm.y + componentHeight + intraLayerSpacing);
456 | drawArrow(maskedAttentionNorm, crossAttention);
457 | drawArrow(encoderOutput, crossAttention, 'cross');
458 |
459 | const crossAttentionNorm = createComponent(`decoder-layer-${i}-cross-attention-norm`, 'Add & Norm', 'norm', decoderX, crossAttention.y + componentHeight + intraLayerSpacing);
460 | drawArrow(crossAttention, crossAttentionNorm);
461 | drawArrow(maskedAttentionNorm, crossAttentionNorm, 'residual');
462 |
463 | const ffn = createComponent(`decoder-layer-${i}-ffn`, 'Feed Forward', 'ffn', decoderX, crossAttentionNorm.y + componentHeight + intraLayerSpacing);
464 | drawArrow(crossAttentionNorm, ffn);
465 |
466 | const ffnNorm = createComponent(`decoder-layer-${i}-ffn-norm`, 'Add & Norm', 'norm', decoderX, ffn.y + componentHeight + intraLayerSpacing);
467 | drawArrow(ffn, ffnNorm);
468 | drawArrow(crossAttentionNorm, ffnNorm, 'residual');
469 |
470 | prevLayerOutput = ffnNorm;
471 | currentY = ffnNorm.y + componentHeight / 2 + interLayerSpacing;
472 | }
473 |
474 | const finalOutput = createComponent(
475 | 'final-output',
476 | 'Linear + Softmax',
477 | 'ffn',
478 | decoderX,
479 | currentY + componentHeight / 2 + intraLayerSpacing,
480 | );
481 | drawArrow(prevLayerOutput, finalOutput);
482 |
483 | const predictedTokens = createComponent('predicted-tokens', 'Predicted Tokens', 'output', decoderX, finalOutput.y + 60);
484 | drawArrow(finalOutput, predictedTokens);
485 | };
486 |
487 | export const useTransformerDiagram = (
488 | params: TransformerParams,
489 | selectedComponent: string | null,
490 | onComponentClick: (componentId: string) => void,
491 | containerRef: RefObject<HTMLDivElement>,
492 | ) => {
493 | const svgRef = useRef<SVGSVGElement>(null);
494 |
495 | useEffect(() => {
496 | if (!svgRef.current || !containerRef.current) return;
497 |
498 | const svg = d3.select(svgRef.current);
499 | svg.selectAll('*').remove();
500 |
501 | const containerRect = containerRef.current.getBoundingClientRect();
502 | const width = Math.max(containerRect.width, 300);
503 |
504 | const componentHeight = 35;
505 | const intraLayerSpacing = 20;
506 | const interLayerSpacing = 30;
507 | const initialY = 80;
508 |
509 | let totalHeight = initialY + componentHeight + interLayerSpacing;
510 |
511 | const encoderLayerH = 4 * componentHeight + 6 * intraLayerSpacing; // one spacing more than the rect drawn in drawEncoderStack, for vertical slack
512 | const encoderStackHeight = params.numLayers * (encoderLayerH + interLayerSpacing);
513 |
514 | const decoderLayerH = 6 * componentHeight + 8 * intraLayerSpacing; // likewise padded relative to drawDecoderStack
515 | const decoderStackHeight = params.numLayers * (decoderLayerH + interLayerSpacing);
516 |
517 | totalHeight += Math.max(encoderStackHeight, decoderStackHeight);
518 | totalHeight += componentHeight + interLayerSpacing + intraLayerSpacing;
519 | totalHeight += 60; // finalOutput -> predictedTokens offset (see drawDecoderStack)
520 | totalHeight += 60; // bottom margin
521 |
522 | const height = Math.max(totalHeight, 600);
523 |
524 | svg.attr('viewBox', `0 0 ${width} ${height}`).attr('preserveAspectRatio', 'xMidYMid meet');
525 |
526 | const mainGroup = svg.append('g');
527 |
528 | // Simple fixed positioning without zoom functionality
529 | mainGroup.attr('transform', 'translate(20, 20)');
530 |
531 | drawImprovedTransformerArchitecture(mainGroup, params, selectedComponent, onComponentClick, width - 40, height - 40);
532 | }, [params, selectedComponent, onComponentClick, containerRef]);
533 |
534 | return svgRef;
535 | };
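    | // Typical wiring (illustrative; names are hypothetical):
    | //   const containerRef = useRef<HTMLDivElement>(null);
    | //   const svgRef = useTransformerDiagram(params, selectedId, handleClick, containerRef);
    | //   return <div ref={containerRef}><svg ref={svgRef} /></div>;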
--------------------------------------------------------------------------------
/src/utils/componentDataGenerator.ts:
--------------------------------------------------------------------------------
1 | import { ComponentData, TransformerParams } from '../types/events';
2 | import { generateMatrix, generateEmbeddings, generatePositionalEncoding } from './randomWeights';
3 | import { matmul, softmax } from './math';
4 | import { SequenceGenerator } from './data';
5 |
6 | // Helper function to parse layer component IDs
7 | interface LayerComponentInfo {
8 | isLayerComponent: boolean;
9 | layerType: 'encoder' | 'decoder' | null;
10 | componentType: 'attention' | 'masked-attention' | 'cross-attention' | 'ffn' | 'norm' | null;
11 | isNorm: boolean;
12 | }
13 |
14 | function parseLayerComponentId(componentId: string): LayerComponentInfo {
15 | const isEncoderLayer = componentId.startsWith('encoder-layer-');
16 | const isDecoderLayer = componentId.startsWith('decoder-layer-');
17 |
18 | if (!isEncoderLayer && !isDecoderLayer) {
19 | return { isLayerComponent: false, layerType: null, componentType: null, isNorm: false };
20 | }
21 |
22 | const layerType = isEncoderLayer ? 'encoder' : 'decoder';
23 | const isNorm = componentId.includes('-norm');
24 |
25 | let componentType: LayerComponentInfo['componentType'] = null;
26 | if (componentId.includes('-masked-attention')) {
27 | componentType = 'masked-attention';
28 | } else if (componentId.includes('-cross-attention')) {
29 | componentType = 'cross-attention';
30 | } else if (componentId.includes('-attention')) {
31 | componentType = 'attention';
32 | } else if (componentId.includes('-ffn')) {
33 | componentType = 'ffn';
34 | } else if (isNorm) {
35 | componentType = 'norm';
36 | }
37 |
38 | return { isLayerComponent: true, layerType, componentType, isNorm };
39 | }
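    | // e.g. parseLayerComponentId('decoder-layer-2-cross-attention-norm') returns
    | //   { isLayerComponent: true, layerType: 'decoder', componentType: 'cross-attention', isNorm: true }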
40 |
41 | /**
42 | * Generates detailed component data for transformer visualization including inputs, parameters, and outputs.
43 | * @param componentId - Unique identifier for the transformer component
44 | * @param params - Transformer model parameters (dimensions, layers, etc.)
45 | * @returns ComponentData object with structured information about the component
46 | */
47 | export function generateComponentData(componentId: string, params: TransformerParams): ComponentData {
48 | const { dModel, numHeads, seqLen, posEncoding } = params;
49 |
50 | // Handle input tokens
51 | if (componentId === 'input-tokens' || componentId === 'output-tokens') {
52 | return generateInputTokensData(seqLen, componentId);
53 | }
54 |
55 | // Handle embeddings
56 | if (componentId === 'embedding' || componentId === 'decoder-embedding') {
57 | return generateEmbeddingData(seqLen, dModel, posEncoding);
58 | }
59 |
60 | // Handle positional encoding
61 | if (componentId === 'positional-encoding' || componentId === 'decoder-positional-encoding') {
62 | return generatePositionalEncodingData(seqLen, dModel, posEncoding);
63 | }
64 |
65 | // Handle layer components using helper function
66 | const layerInfo = parseLayerComponentId(componentId);
67 | if (layerInfo.isLayerComponent) {
68 | if (layerInfo.isNorm) {
69 | // Add & Norm components
70 | const blockType = componentId.includes('attention') ? 'attention' : 'ffn';
71 | return generateAddNormData(seqLen, dModel, blockType);
72 | }
73 |
74 | // Non-norm layer components
75 | switch (layerInfo.componentType) {
76 | case 'attention':
77 | return generateAttentionData(seqLen, dModel, numHeads, 'encoder');
78 | case 'masked-attention':
79 | return generateAttentionData(seqLen, dModel, numHeads, 'decoder-masked');
80 | case 'cross-attention':
81 | return generateAttentionData(seqLen, dModel, numHeads, 'cross');
82 | case 'ffn':
83 | return generateFFNData(seqLen, dModel);
84 | }
85 | }
86 |
87 | // Handle final output components
88 | if (componentId === 'predicted-tokens') {
89 | return generatePredictedTokensData(seqLen);
90 | }
91 |
92 | // Handle encoder output
93 | if (componentId === 'encoder-output') {
94 | return generateEncoderOutputData(seqLen, dModel);
95 | }
96 |
97 | // Handle final output (Linear + Softmax)
98 | if (componentId === 'final-output') {
99 | return generateFinalOutputData(seqLen, dModel);
100 | }
101 |
102 | // Default case - determine category from componentId
103 | return generateDefaultData(componentId);
104 | }
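    | // Illustrative call (param values are arbitrary; TransformerParams has more fields than shown):
    | //   generateComponentData('encoder-layer-0-attention', { dModel: 64, numHeads: 4, seqLen: 8, ... })
    | //   -> { category: 'attention', outputs: [{ id: 'attention-weights', shape: [8, 8] }, ...], ... }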
105 |
106 | function generateInputTokensData(seqLen: number, componentId: string): ComponentData {
107 | // Generate simple tokens
108 | const tokens = SequenceGenerator.generateForComponent(componentId, seqLen);
109 |
110 | const isOutputTokens = componentId === 'output-tokens';
111 | const description = isOutputTokens ?
112 | 'Target sequence tokens for decoder (shifted right during training). These guide the model\'s generation process.' :
113 | 'Raw input tokens before embedding. These discrete symbols represent words, subwords, or characters from the input text.';
114 |
115 | return {
116 | description,
117 | category: 'tokens',
118 | inputs: [],
119 | parameters: [],
120 | outputs: [
121 | {
122 | id: 'tokens',
123 | label: isOutputTokens ? 'Output Tokens' : 'Input Tokens',
124 | description: `Tokenized sequence (length: ${seqLen}) with special control tokens`,
125 | type: 'tokens',
126 | data: tokens,
127 | shape: [tokens.length],
128 | metadata: {
129 | vocab_size: 50000,
130 | special_tokens: tokens.filter(t => t.startsWith('<')).length,
131 | sequence_type: isOutputTokens ? 'target' : 'source',
132 | actual_tokens: tokens.filter(t => !t.startsWith('<')).length
133 | }
134 | }
135 | ]
136 | };
137 | }
138 |
139 | function generateEmbeddingData(seqLen: number, dModel: number, posEncoding: string): ComponentData {
140 | const embeddings = generateEmbeddings(seqLen, dModel);
141 | const positionalEnc = generatePositionalEncoding(seqLen, dModel, posEncoding as 'sinusoidal' | 'learned');
142 |
143 | // Generate simple tokens
144 | const tokens = SequenceGenerator.generateSequence(seqLen, 'input');
145 |
146 | // Combined embeddings with positional encoding
147 | const combinedEmbeddings = embeddings.map((row, i) =>
148 | row.map((val, j) => val + positionalEnc[i][j])
149 | );
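    | // H = E + P: elementwise sum over the [seqLen, dModel] grid.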
150 |
151 | return {
152 | description: `Converts discrete tokens into dense ${dModel}-dimensional vector representations and combines them with positional encoding to preserve sequence order information.`,
153 | category: 'embedding',
154 | inputs: [
155 | {
156 | id: 'tokens',
157 | label: 'Input Tokens',
158 | description: 'Discrete token indices from vocabulary',
159 | type: 'tokens',
160 | data: tokens,
161 | shape: [seqLen],
162 | metadata: { vocab_size: 50000, encoding: 'utf-8' }
163 | },
164 | {
165 | id: 'positional-encoding',
166 | label: 'Positional Encoding (P)',
167 | description: 'Positional information to be added to embeddings',
168 | type: 'matrix',
169 | data: positionalEnc,
170 | shape: [seqLen, dModel],
171 | metadata: { from_component: 'positional-encoding', type: posEncoding }
172 | }
173 | ],
174 | parameters: [
175 | {
176 | id: 'embedding-matrix',
177 | label: 'Embedding Matrix (W_e)',
178 | description: 'Learned lookup table mapping token indices to dense vectors',
179 | type: 'matrix',
180 | data: generateMatrix(Math.min(50000, 100), dModel).slice(0, Math.min(20, seqLen + 5)), // Show relevant subset
181 | shape: [50000, dModel],
182 | metadata: {
183 | learnable: true,
184 | initialized: 'xavier_uniform',
185 | parameters: 50000 * dModel,
186 | gradient_norm: 'clipped'
187 | }
188 | }
189 | ],
190 | outputs: [
191 | {
192 | id: 'positioned-embeddings',
193 | label: 'Positioned Embeddings (H)',
194 | description: 'Token embeddings enhanced with positional information (E + P)',
195 | type: 'matrix',
196 | data: combinedEmbeddings,
197 | shape: [seqLen, dModel],
198 | metadata: {
199 | operation: 'embedding_lookup + positional_encoding',
200 | range: 'normalized',
201 | ready_for_attention: true
202 | }
203 | }
204 | ]
205 | };
206 | }
207 |
208 | function generatePositionalEncodingData(seqLen: number, dModel: number, posEncoding: string): ComponentData {
209 | const posEnc = generatePositionalEncoding(seqLen, dModel, posEncoding as 'sinusoidal' | 'learned');
210 |
211 | return {
212 | description: `Generates ${posEncoding} positional information for sequence length ${seqLen}. This helps the model understand the order and position of tokens in the sequence.`,
213 | category: 'positional',
214 | inputs: [], // No inputs - positional encoding is generated independently
215 | parameters: posEncoding === 'learned' ? [
216 | {
217 | id: 'position-embeddings',
218 | label: 'Learned Position Embeddings (P)',
219 | description: 'Trainable position-specific vectors',
220 | type: 'matrix',
221 | data: posEnc,
222 | shape: [2048, dModel], // Standard max length
223 | metadata: { learnable: true, max_position: 2048 }
224 | }
225 | ] : [],
226 | outputs: [
227 | {
228 | id: 'positional-encoding',
229 | label: 'Positional Encoding (P)',
230 | description: posEncoding === 'learned' ?
231 | `Learned positional encodings for ${seqLen} positions` :
232 | `Fixed sinusoidal positional encodings computed from sine/cosine functions of token position`,
233 | type: 'matrix',
234 | data: posEnc,
235 | shape: [seqLen, dModel],
236 | metadata: {
237 | learnable: posEncoding === 'learned',
238 | formula: posEncoding === 'sinusoidal' ? 'PE(pos,2i) = sin(pos/10000^(2i/d)), PE(pos,2i+1) = cos(pos/10000^(2i/d))' : 'learned_vectors',
239 | max_position: 2048,
240 | frequency_bands: dModel / 2
241 | }
242 | }
243 | ]
244 | };
245 | }
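    | // Worked example of the sinusoidal formula above, for pos = 1 and the first frequency band (i = 0):
    | //   PE(1, 0) = sin(1 / 10000^0) = sin(1) ≈ 0.841
    | //   PE(1, 1) = cos(1 / 10000^0) = cos(1) ≈ 0.540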
246 |
247 | function generateAttentionData(seqLen: number, dModel: number, numHeads: number, attentionType: 'encoder' | 'decoder-masked' | 'cross' = 'encoder'): ComponentData {
248 | const headSize = dModel / numHeads;
249 | const input = generateEmbeddings(seqLen, dModel);
250 |
251 | // Generate Q, K, V matrices for visualization
252 | const q = generateMatrix(seqLen, headSize);
253 | const k = generateMatrix(seqLen, headSize);
254 | const v = generateMatrix(seqLen, headSize);
255 |
256 | // Compute attention scores with proper scaling
257 | const scale = 1 / Math.sqrt(headSize);
258 | const kTransposed = k[0].map((_, colIndex) => k.map(row => row[colIndex]));
259 | const rawScores = matmul(q, kTransposed);
260 | const scaledScores = rawScores.map(row => row.map(val => val * scale));
261 |
262 | // Apply masking for decoder attention
263 | if (attentionType === 'decoder-masked') {
264 | for (let i = 0; i < seqLen; i++) {
265 | for (let j = i + 1; j < seqLen; j++) {
266 | scaledScores[i][j] = -Infinity;
267 | }
268 | }
269 | }
270 |
271 | const attentionWeights = softmax(scaledScores);
272 | const output = matmul(attentionWeights, v);
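    | // Net effect: output = softmax(Q·Kᵀ / √d_head) · V for a single head (with the causal mask
    | // applied first for 'decoder-masked'); a full multi-head block would repeat this per head
    | // and concatenate before the W_O projection.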
273 |
274 | // Simplified descriptions based on attention type
275 | const descriptions = {
276 | 'encoder': `Multi-head self-attention with ${numHeads} heads processing ${seqLen} tokens. Each token can attend to all other tokens in the sequence to build contextual representations.`,
277 | 'decoder-masked': `Masked self-attention with ${numHeads} heads for autoregressive generation. Tokens can only attend to previous positions to maintain causality during generation.`,
278 | 'cross': `Cross-attention with ${numHeads} heads connecting encoder and decoder. Decoder queries attend to encoder keys and values, enabling the model to focus on relevant source information.`
279 | };
280 |
281 | const description = descriptions[attentionType];
282 |
283 | // Create inputs array based on attention type
284 | const inputs = [
285 | {
286 | id: 'input',
287 | label: attentionType === 'cross' ? 'Decoder Input (X)' : 'Input Embeddings (X)',
288 | description: attentionType === 'cross' ?
289 | 'Input from previous decoder layer (for queries)' :
290 | 'Input from previous layer (embeddings or previous attention block)',
291 | type: 'matrix' as const,
292 | data: input,
293 | shape: [seqLen, dModel] as [number, number],
294 | metadata: { from_previous_layer: true, sequence_length: seqLen }
295 | }
296 | ];
297 |
298 | // Add encoder output as input for cross-attention
299 | if (attentionType === 'cross') {
300 | inputs.push({
301 | id: 'encoder-output',
302 | label: 'Encoder Output (H_enc)',
303 | description: 'Encoder output used as keys and values for cross-attention',
304 | type: 'matrix' as const,
305 | data: generateEmbeddings(seqLen, dModel),
306 | shape: [seqLen, dModel] as [number, number],
307 | metadata: { from_encoder: true, used_for_keys_values: true } as any
308 | });
309 | }
310 |
311 | return {
312 | description,
313 | category: 'attention',
314 | inputs,
315 | parameters: [
316 | {
317 | id: 'wq',
318 | label: 'Query Weight Matrix (W_Q)',
319 | description: 'Projects input to query space for attention computation',
320 | type: 'matrix',
321 | data: generateMatrix(dModel, dModel),
322 | shape: [dModel, dModel],
323 | metadata: { learnable: true, heads: numHeads, head_dim: headSize }
324 | },
325 | {
326 | id: 'wk',
327 | label: 'Key Weight Matrix (W_K)',
328 | description: 'Projects input to key space for attention computation',
329 | type: 'matrix',
330 | data: generateMatrix(dModel, dModel),
331 | shape: [dModel, dModel],
332 | metadata: { learnable: true, heads: numHeads, head_dim: headSize }
333 | },
334 | {
335 | id: 'wv',
336 | label: 'Value Weight Matrix (W_V)',
337 | description: 'Projects input to value space for attention computation',
338 | type: 'matrix',
339 | data: generateMatrix(dModel, dModel),
340 | shape: [dModel, dModel],
341 | metadata: { learnable: true, heads: numHeads, head_dim: headSize }
342 | },
343 | {
344 | id: 'wo',
345 | label: 'Output Weight Matrix (W_O)',
346 | description: 'Projects concatenated multi-head output back to model dimension',
347 | type: 'matrix',
348 | data: generateMatrix(dModel, dModel),
349 | shape: [dModel, dModel],
350 | metadata: { learnable: true, final_projection: true, total_params: 4 * dModel * dModel }
351 | }
352 | ],
353 | outputs: [
354 | {
355 | id: 'attention-weights',
356 | label: 'Attention Weights (α)',
357 | description: `Attention probability distribution (${seqLen}×${seqLen}) showing which tokens attend to which`,
358 | type: 'matrix',
359 | data: attentionWeights,
360 | shape: [seqLen, seqLen],
361 | metadata: {
362 | normalized: true,
363 | range: [0, 1],
364 | scaled: true,
365 | masking: attentionType === 'decoder-masked' ? 'causal' : 'none',
366 | sequence_length: seqLen
367 | }
368 | },
369 | {
370 | id: 'attention-output',
371 | label: 'Attention Output (Z)',
372 | description: 'Weighted combination of values based on attention weights',
373 | type: 'matrix',
374 | data: output,
375 | shape: [seqLen, dModel],
376 | metadata: {
377 | multi_head: true,
378 | heads: numHeads,
379 | attention_type: attentionType,
380 | ready_for_residual: true
381 | }
382 | }
383 | ]
384 | };
385 | }
386 |
387 | function generateAddNormData(seqLen: number, dModel: number, blockType: 'attention' | 'ffn'): ComponentData {
388 | const residual = generateEmbeddings(seqLen, dModel);
389 | const blockOutput = generateEmbeddings(seqLen, dModel);
390 | const added = residual.map((row, i) => row.map((val, j) => val + blockOutput[i][j]));
391 |
392 | // Simple layer norm simulation
393 | const normalized = added.map(row => {
394 | const mean = row.reduce((sum, val) => sum + val, 0) / row.length;
395 | const variance = row.reduce((sum, val) => sum + Math.pow(val - mean, 2), 0) / row.length;
396 | const std = Math.sqrt(variance + 1e-5);
397 | return row.map(val => (val - mean) / std);
398 | });
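    | // i.e. LayerNorm(x + sublayer(x)) with ε = 1e-5; γ and β (listed under parameters below)
    | // are identity here, so each output row has ~zero mean and ~unit variance.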
399 |
400 | return {
401 | description: `Residual connection and layer normalization after ${blockType} block. The residual connection helps with gradient flow during training, while layer normalization stabilizes training by normalizing inputs to have zero mean and unit variance.`,
402 | category: blockType,
403 | inputs: [
404 | {
405 | id: 'residual',
406 | label: 'Residual Input (X)',
407 | description: 'Original input to be added back (skip connection)',
408 | type: 'matrix',
409 | data: residual,
410 | shape: [seqLen, dModel],
411 | metadata: { skip_connection: true, gradient_highway: true }
412 | },
413 | {
414 | id: 'block-output',
415 | label: `${blockType.charAt(0).toUpperCase() + blockType.slice(1)} Output (${blockType === 'attention' ? 'Z' : 'Y'})`,
416 | description: `Output from the ${blockType} block`,
417 | type: 'matrix',
418 | data: blockOutput,
419 | shape: [seqLen, dModel],
420 | metadata: { from_block: blockType, transformed: true }
421 | }
422 | ],
423 | parameters: [
424 | {
425 | id: 'gamma',
426 | label: 'Layer Norm Scale (γ)',
427 | description: 'Learnable scale parameter for layer normalization',
428 | type: 'vector',
429 | data: Array(dModel).fill(1),
430 | shape: [dModel],
431 | metadata: { learnable: true, initialized: 'ones', adaptive_scaling: true }
432 | },
433 | {
434 | id: 'beta',
435 | label: 'Layer Norm Bias (β)',
436 | description: 'Learnable bias parameter for layer normalization',
437 | type: 'vector',
438 | data: Array(dModel).fill(0),
439 | shape: [dModel],
440 | metadata: { learnable: true, initialized: 'zeros', shift_parameter: true }
441 | }
442 | ],
443 | outputs: [
444 | {
445 | id: 'normalized',
446 | label: 'Normalized Output (H_norm)',
447 | description: 'Layer normalized output after residual connection',
448 | type: 'matrix',
449 | data: normalized,
450 | shape: [seqLen, dModel],
451 | metadata: {
452 | mean: 0,
453 | variance: 1,
454 | operation: 'add_then_norm',
455 | stable_gradients: true,
456 | ready_for_next_layer: true
457 | }
458 | }
459 | ]
460 | };
461 | }
462 |
463 | function generateFFNData(seqLen: number, dModel: number): ComponentData {
464 | const input = generateEmbeddings(seqLen, dModel);
465 | const hiddenDim = dModel * 4; // Typical FFN expansion
466 |
467 | const w1 = generateMatrix(dModel, hiddenDim);
468 | const w2 = generateMatrix(hiddenDim, dModel);
469 |
470 | // Simulate FFN computation (simplified for visualization)
471 | const hidden = generateMatrix(seqLen, hiddenDim);
472 | const activated = hidden.map(row => row.map(val => Math.max(0, val))); // ReLU
473 | const output = generateMatrix(seqLen, dModel);
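    | // A faithful computation would be matmul(activated, w2) + b2 with
    | // activated = relu(matmul(input, w1) + b1); random matrices stand in here for visualization.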
474 |
475 | return {
476 | description: `Position-wise feed-forward network with expansion factor 4 (${dModel} → ${hiddenDim} → ${dModel}). Each position is processed independently with two linear transformations and ReLU activation.`,
477 | category: 'ffn',
478 | inputs: [
479 | {
480 | id: 'input',
481 | label: 'Input from Attention (X)',
482 | description: 'Normalized output from attention block',
483 | type: 'matrix',
484 | data: input,
485 | shape: [seqLen, dModel],
486 | metadata: { from_attention: true, sequence_parallel: true }
487 | }
488 | ],
489 | parameters: [
490 | {
491 | id: 'w1',
492 | label: 'First Linear Layer (W₁)',
493 | description: `Expands dimension from ${dModel} to ${hiddenDim} (4× expansion)`,
494 | type: 'matrix',
495 | data: w1,
496 | shape: [dModel, hiddenDim],
497 | metadata: { learnable: true, expansion_factor: 4, parameters: dModel * hiddenDim }
498 | },
499 | {
500 | id: 'b1',
501 | label: 'First Layer Bias (b₁)',
502 | description: 'Bias for first linear transformation',
503 | type: 'vector',
504 | data: Array(hiddenDim).fill(0),
505 | shape: [hiddenDim],
506 | metadata: { learnable: true, initialized: 'zeros' }
507 | },
508 | {
509 | id: 'w2',
510 | label: 'Second Linear Layer (W₂)',
511 | description: `Projects back from ${hiddenDim} to ${dModel}`,
512 | type: 'matrix',
513 | data: w2,
514 | shape: [hiddenDim, dModel],
515 | metadata: { learnable: true, projection_back: true, parameters: hiddenDim * dModel }
516 | },
517 | {
518 | id: 'b2',
519 | label: 'Second Layer Bias (b₂)',
520 | description: 'Bias for second linear transformation',
521 | type: 'vector',
522 | data: Array(dModel).fill(0),
523 | shape: [dModel],
524 | metadata: { learnable: true, initialized: 'zeros' }
525 | }
526 | ],
527 | outputs: [
528 | {
529 | id: 'hidden',
530 | label: 'Hidden Representation (H_hidden)',
531 | description: `Intermediate representation (${hiddenDim}D) after first linear layer and ReLU`,
532 | type: 'matrix',
533 | data: activated,
534 | shape: [seqLen, hiddenDim],
535 | metadata: { activation: 'ReLU', expanded: true, sparse: true }
536 | },
537 | {
538 | id: 'output',
539 | label: 'FFN Output (Y)',
540 | description: 'Final output after second linear transformation',
541 | type: 'matrix',
542 | data: output,
543 | shape: [seqLen, dModel],
544 | metadata: { ready_for_residual: true, position_wise: true }
545 | }
546 | ]
547 | };
548 | }
549 |
550 |
556 | function generatePredictedTokensData(seqLen: number): ComponentData {
557 | // Generate simple predicted tokens
558 | const predictedTokens = SequenceGenerator.generateSequence(seqLen, 'output');
559 |
560 | return {
561 | description: `Final predicted tokens from the transformer model after sampling from the probability distribution. These represent the model's best guess for the target sequence.`,
562 | category: 'output',
563 | inputs: [
564 | {
565 | id: 'probabilities',
566 | label: 'Token Probabilities (P)',
567 | description: 'Probability distribution over vocabulary from softmax',
568 | type: 'matrix',
569 | data: generateMatrix(seqLen, 100),
570 | shape: [seqLen, 50000],
571 | metadata: { from_softmax: true, ready_for_sampling: true }
572 | }
573 | ],
574 | parameters: [
575 | {
576 | id: 'sampling-strategy',
577 | label: 'Sampling Strategy (σ)',
578 | description: 'Method for selecting tokens from probability distribution',
579 | type: 'text',
580 | data: 'greedy',
581 | metadata: {
582 | options: ['greedy', 'top_k', 'top_p', 'beam_search'],
583 | current: 'greedy',
584 | deterministic: true
585 | }
586 | }
587 | ],
588 | outputs: [
589 | {
590 | id: 'predicted-tokens',
591 | label: 'Predicted Tokens (ŷ)',
592 | description: 'Final predicted sequence',
593 | type: 'tokens',
594 | data: predictedTokens,
595 | shape: [predictedTokens.length],
596 | metadata: {
597 | sampling_method: 'greedy',
598 | confidence: 0.85,
599 | sequence_length: seqLen,
600 | generated: true
601 | }
602 | }
603 | ]
604 | };
605 | }
606 |
607 | function generateEncoderOutputData(seqLen: number, dModel: number): ComponentData {
608 | const input = generateEmbeddings(seqLen, dModel);
609 | const output = generateEmbeddings(seqLen, dModel);
610 |
611 | return {
612 | description: `Final output of the encoder block, which consists of a series of encoder layers. This output contains rich contextual representations that are passed to the decoder for cross-attention and subsequent decoding.`,
613 | category: 'output',
614 | inputs: [
615 | {
616 | id: 'final-encoder-layer-output',
617 | label: 'Final Encoder Layer Output (H_N)',
618 | description: 'Output from the last encoder layer (normalized)',
619 | type: 'matrix',
620 | data: input,
621 | shape: [seqLen, dModel],
622 | metadata: { from_final_encoder_layer: true, sequence_length: seqLen }
623 | }
624 | ],
625 | parameters: [],
626 | outputs: [
627 | {
628 | id: 'encoder-output',
629 | label: 'Encoder Output (H_enc)',
630 | description: 'Final contextual representations from the encoder block',
631 | type: 'matrix',
632 | data: output,
633 | shape: [seqLen, dModel],
634 | metadata: {
635 | from_encoder: true,
636 | contextualized: true,
637 | used_for_cross_attention: true,
638 | all_layers_processed: true
639 | }
640 | }
641 | ]
642 | };
643 | }
644 |
645 | function generateFinalOutputData(seqLen: number, dModel: number): ComponentData {
646 | const input = generateEmbeddings(seqLen, dModel);
647 | const logits = generateMatrix(seqLen, Math.min(50000, 100)); // Show subset for visualization
648 | const probs = softmax(logits);
649 |
650 | return {
651 | description: `Combines the final linear layer (language model head) and softmax function to produce the final probability distribution over the vocabulary. This is the complete output processing pipeline of the transformer model.`,
652 | category: 'output',
653 | inputs: [
654 | {
655 | id: 'decoder-output',
656 | label: 'Decoder Output (H_dec)',
657 | description: 'Final hidden states from the decoder',
658 | type: 'matrix',
659 | data: input,
660 | shape: [seqLen, dModel],
661 | metadata: { from_decoder: true, final_layer: true }
662 | }
663 | ],
664 | parameters: [
665 | {
666 | id: 'vocab-projection-weights',
667 | label: 'Vocabulary Projection Weights (W_v)',
668 | description: 'Weight matrix for projecting to vocabulary size',
669 | type: 'matrix',
670 | data: generateMatrix(dModel, Math.min(50000, 100)),
671 | shape: [dModel, 50000],
672 | metadata: { learnable: true, vocab_size: 50000 }
673 | },
674 | {
675 | id: 'vocab-projection-bias',
676 | label: 'Vocabulary Projection Bias (b_v)',
677 | description: 'Bias vector for vocabulary projection',
678 | type: 'vector',
679 | data: Array(Math.min(50000, 100)).fill(0),
680 | shape: [50000],
681 | metadata: { learnable: true, initialized: 'zeros' }
682 | }
683 | ],
684 | outputs: [
685 | {
686 | id: 'token-probabilities',
687 | label: 'Token Probabilities (P)',
688 | description: 'Final probability distribution over vocabulary',
689 | type: 'matrix',
690 | data: probs,
691 | shape: [seqLen, 50000],
692 | metadata: {
693 | normalized: true,
694 | sum_to_one: true,
695 | ready_for_sampling: true,
696 | temperature_applied: true
697 | }
698 | }
699 | ]
700 | };
701 | }
702 |
703 | function generateDefaultData(componentId: string): ComponentData {
704 | // Determine category based on component ID
705 | let category: ComponentData['category'] = 'embedding';
706 |
707 | if (componentId.includes('token')) {
708 | category = 'tokens';
709 | } else if (componentId.includes('attention')) {
710 | category = 'attention';
711 | } else if (componentId.includes('ffn') || componentId.includes('feed')) {
712 | category = 'ffn';
713 | } else if (componentId.includes('output') || componentId.includes('linear') || componentId.includes('softmax')) {
714 | category = 'output';
715 | } else if (componentId.includes('positional') || componentId.includes('position')) {
716 | category = 'positional';
717 | }
718 |
719 | return {
720 | description: `Component details for ${componentId}. This component is part of the transformer architecture and contributes to the overall sequence processing.`,
721 | category,
722 | inputs: [],
723 | parameters: [],
724 | outputs: []
725 | };
726 | }
--------------------------------------------------------------------------------