├── inference
│   ├── __init__.py
│   └── generator.py
├── models
│   ├── __init__.py
│   ├── config.py
│   ├── mtp.py
│   ├── layers.py
│   ├── moe.py
│   ├── attention.py
│   └── model.py
├── requirements.txt
├── training
│   ├── __init__.py
│   ├── data_loader.py
│   └── trainer.py
├── .gitignore
├── main.py
├── run_inference.py
├── prepare_data_tiny_stories.py
├── README.md
├── notebooks
│   ├── Multi_Token_Prediction_from_Scratch.ipynb
│   └── Mixture_of_Experts_from_Scratch.ipynb
└── prepare_data_fineweb.py
/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import generate_text
2 |
3 | __all__ = ['generate_text']
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import DeepSeekConfig
2 | from .model import DeepSeekV3
3 |
4 | __all__ = ['DeepSeekConfig', 'DeepSeekV3']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | numpy>=1.21.0
3 | tqdm>=4.64.0
4 | tiktoken>=0.5.0
5 | datasets>=2.19.0
6 | transformers>=4.30.0
7 | wandb
8 | python-dotenv
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_loader import get_batch, estimate_loss
2 | from .trainer import train_model
3 |
4 | __all__ = ['get_batch', 'estimate_loss', 'train_model']
--------------------------------------------------------------------------------
/models/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | @dataclass
4 | class DeepSeekConfig:
5 | # Model architecture
6 | vocab_size: int = 50257
7 | block_size: int = 128
8 | n_layer: int = 6
9 | n_embd: int = 384
10 | n_head: int = 6
11 |
12 | # MLA configuration
13 | kv_lora_rank: int = 128
14 | q_lora_rank: int = 192
15 | rope_dim: int = 32
16 |
17 | # MoE configuration
18 | n_experts: int = 4
19 | n_experts_per_token: int = 2
20 | expert_intermediate_size: int = 512
21 | shared_expert_intermediate_size: int = 768
22 | use_shared_expert: bool = True
23 |
24 | # MTP configuration
25 | mtp_num_heads: int = 2
26 |
27 | # Training parameters
28 | dropout: float = 0.1
29 | bias: bool = True
30 | aux_loss_weight: float = 0.0
31 | mtp_loss_weight: float = 0.3
--------------------------------------------------------------------------------
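
A quick back-of-the-envelope sketch (not part of the repository) of what the default MoE values above imply: each SwiGLU expert holds three weight matrices (see `models/layers.py`), so the routed-expert parameter count and the fraction active per token follow directly from `n_embd`, `expert_intermediate_size`, `n_experts`, and `n_experts_per_token`. Biases and the shared expert are ignored here.

```python
# Illustrative arithmetic only: routed-expert parameters implied by the
# default DeepSeekConfig values (biases and the shared expert ignored).
n_embd = 384
expert_hidden = 512          # expert_intermediate_size
n_experts, n_active = 4, 2   # n_experts, n_experts_per_token

params_per_expert = 3 * n_embd * expert_hidden   # gate_proj + up_proj + down_proj
total_routed = n_experts * params_per_expert
active_routed = n_active * params_per_expert

print(f"params per expert : {params_per_expert:,}")   # 589,824
print(f"all routed experts: {total_routed:,}")         # 2,359,296
print(f"active per token  : {active_routed:,} ({active_routed / total_routed:.0%})")
```
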
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data files (large binary files)
2 | train.bin
3 | validation.bin
4 | *.bin
5 |
6 | input_output_pairs.csv
7 | training_samples.csv
8 |
9 | myenv/
10 |
11 |
12 | # Model weights (large files)
13 | best_deepseek_v3.pt
14 | *.pt
15 | *.pth
16 |
17 | # Weights & Biases
18 | wandb/
19 | *.wandb
20 | wandb-metadata.json
21 | wandb-summary.json
22 |
23 | # Common ML artifacts
24 |
25 | checkpoints/
26 | logs/
27 | outputs/
28 |
29 |
30 |
31 | # Python cache
32 | __pycache__/
33 | *.py[cod]
34 | *$py.class
35 | *.so
36 |
37 | # Environment
38 | .env
39 | .venv
40 | env/
41 | venv/
42 | ENV/
43 | env.bak/
44 | venv.bak/
45 |
46 | # IDE
47 | .vscode/
48 | .idea/
49 | *.swp
50 | *.swo
51 | *~
52 |
53 | # OS
54 | .DS_Store
55 | .DS_Store?
56 | ._*
57 | .Spotlight-V100
58 | .Trashes
59 | ehthumbs.db
60 | Thumbs.db
61 |
62 | # Jupyter Notebook
63 | .ipynb_checkpoints
64 |
65 | # PyTorch
66 | *.pth.tar
67 |
68 | # Logs
69 | *.log
70 | logs/
71 |
72 | # Temporary files
73 | tmp/
74 | temp/
--------------------------------------------------------------------------------
/training/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def get_batch(split, config, batch_size, device_type, device):
5 | if split == 'train':
6 | data = np.memmap('train.bin', dtype=np.uint16, mode='r')
7 | else:
8 | data = np.memmap('validation.bin', dtype=np.uint16, mode='r')
9 |
10 | ix = torch.randint(len(data) - config.block_size, (batch_size,))
11 | x = torch.stack([torch.from_numpy((data[i:i+config.block_size]).astype(np.int64)) for i in ix])
12 | y = torch.stack([torch.from_numpy((data[i+1:i+1+config.block_size]).astype(np.int64)) for i in ix])
13 |
14 | if device_type == 'cuda':
15 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
16 | else:
17 | x, y = x.to(device), y.to(device)
18 |
19 | return x, y
20 |
21 | def estimate_loss(model, config, eval_iters, batch_size, device_type, device, ctx):
22 | out = {}
23 | model.eval()
24 | with torch.inference_mode():
25 | for split in ['train', 'val']:
26 | losses = torch.zeros(eval_iters)
27 | for k in range(eval_iters):
28 | X, Y = get_batch(split, config, batch_size, device_type, device)
29 | with ctx:
30 | _, loss, _, _ = model(X, Y)
31 | losses[k] = loss.item()
32 | out[split] = losses.mean()
33 | model.train()
34 | return out
--------------------------------------------------------------------------------
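
A minimal usage sketch for the loader above, assuming `train.bin`/`validation.bin` have already been written by one of the `prepare_data_*.py` scripts and the snippet is run from the repository root; this mirrors how `training/trainer.py` calls the helper.

```python
# Minimal sketch: fetch one batch the same way training/trainer.py does.
# Assumes train.bin / validation.bin exist in the working directory.
import torch
from models import DeepSeekConfig
from training import get_batch

config = DeepSeekConfig(block_size=128)
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cuda" if "cuda" in device else "cpu"

x, y = get_batch("train", config, batch_size=8, device_type=device_type, device=device)
print(x.shape, y.shape)  # torch.Size([8, 128]) twice; y is x shifted right by one token
```
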
/models/mtp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .layers import RMSNorm
4 | from .attention import MultiHeadLatentAttention
5 | from .moe import MoELayer
6 |
7 | class MultiTokenPredictionHead(nn.Module):
8 | def __init__(self, config, depth):
9 | super().__init__()
10 | self.depth = depth
11 | self.n_embd = config.n_embd
12 |
13 | # Combine previous hidden state with future token embedding
14 | self.combine_proj = nn.Linear(2 * config.n_embd, config.n_embd, bias=config.bias)
15 |
16 | # Normalization
17 | self.norm1 = RMSNorm(config.n_embd)
18 | self.norm2 = RMSNorm(config.n_embd)
19 |
20 | # Transformer components
21 | self.attn = MultiHeadLatentAttention(config)
22 | self.mlp = MoELayer(config)
23 | self.attn_norm = RMSNorm(config.n_embd)
24 | self.mlp_norm = RMSNorm(config.n_embd)
25 |
26 | def forward(self, prev_hidden, future_token_embed):
27 | # Normalize inputs
28 | prev_norm = self.norm1(prev_hidden)
29 | future_norm = self.norm2(future_token_embed)
30 |
31 | # Combine representations
32 | combined = torch.cat([prev_norm, future_norm], dim=-1)
33 | hidden = self.combine_proj(combined)
34 |
35 | # Process through transformer components
36 | hidden = hidden + self.attn(self.attn_norm(hidden))
37 | hidden = hidden + self.mlp(self.mlp_norm(hidden))
38 |
39 | return hidden
--------------------------------------------------------------------------------
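
A shape-only sketch of the head above, run with a deliberately tiny, made-up configuration: the head consumes the previous hidden states plus the embeddings of tokens `depth` steps ahead and returns refined hidden states of the same shape.

```python
# Shape sketch with a tiny, hypothetical config (values chosen only for speed).
import torch
from models import DeepSeekConfig
from models.mtp import MultiTokenPredictionHead

config = DeepSeekConfig(n_embd=64, n_head=4, block_size=32, kv_lora_rank=16,
                        q_lora_rank=24, rope_dim=8, n_experts=4,
                        n_experts_per_token=2, expert_intermediate_size=128,
                        shared_expert_intermediate_size=128, dropout=0.0)
head = MultiTokenPredictionHead(config, depth=1)

B, T = 2, 16
prev_hidden = torch.randn(B, T, config.n_embd)    # hidden states from the trunk
future_embed = torch.randn(B, T, config.n_embd)   # embeddings of tokens `depth` steps ahead
print(head(prev_hidden, future_embed).shape)      # torch.Size([2, 16, 64])
```
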
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(self, ndim, eps=1e-6):
7 | super().__init__()
8 | self.eps = eps
9 | self.weight = nn.Parameter(torch.ones(ndim))
10 |
11 | def forward(self, x):
12 | norm = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
13 | return self.weight * x / (norm + self.eps)
14 |
15 | class RotaryEmbedding(nn.Module):
16 | def __init__(self, dim, max_seq_len=2048):
17 | super().__init__()
18 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
19 | self.register_buffer('inv_freq', inv_freq)
20 | self.max_seq_len = max_seq_len
21 |
22 | def forward(self, x, seq_len=None):
23 | if seq_len is None:
24 | seq_len = x.shape[-2]
25 |
26 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
27 | freqs = torch.outer(t, self.inv_freq)
28 | cos, sin = freqs.cos(), freqs.sin()
29 | return cos, sin
30 |
31 | def apply_rope(x, cos, sin):
32 | x1, x2 = x.chunk(2, dim=-1)
33 | return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
34 |
35 | class SwiGLU(nn.Module):
36 | def __init__(self, in_features, hidden_features, out_features, bias=True):
37 | super().__init__()
38 | self.gate_proj = nn.Linear(in_features, hidden_features, bias=bias)
39 | self.up_proj = nn.Linear(in_features, hidden_features, bias=bias)
40 | self.down_proj = nn.Linear(hidden_features, out_features, bias=bias)
41 |
42 | def forward(self, x):
43 | gate = self.gate_proj(x)
44 | up = self.up_proj(x)
45 | return self.down_proj(F.silu(gate) * up)
--------------------------------------------------------------------------------
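
A small sanity sketch for the RoPE helpers above: `RotaryEmbedding` returns cos/sin tables of width `rope_dim // 2`, and `apply_rope` rotates paired halves of the last dimension, which leaves per-position vector norms unchanged.

```python
# RoPE sanity sketch: check shapes and that the rotation preserves norms.
import torch
from models.layers import RotaryEmbedding, apply_rope

rope_dim, T = 32, 10
rope = RotaryEmbedding(rope_dim, max_seq_len=128)

x = torch.randn(2, 4, T, rope_dim)                   # (batch, heads, seq, rope_dim)
cos, sin = rope(x, seq_len=T)                        # each (T, rope_dim // 2)
x_rot = apply_rope(x, cos, sin)

print(cos.shape, x_rot.shape)                        # (10, 16) and (2, 4, 10, 32)
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # True
```
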
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepSeek-V3 Main Script
3 |
4 | Simple entry point for training and inference.
5 | """
6 |
7 | import torch
8 | from models import DeepSeekConfig, DeepSeekV3
9 | from training import train_model
10 | from inference import generate_text
11 |
12 | def demo():
13 | """Run a simple demo."""
14 | print("=" * 50)
15 | print("DEEPSEEK-V3 DEMO")
16 | print("=" * 50)
17 |
18 | # Create small model for demo
19 | config = DeepSeekConfig(
20 | vocab_size=1000,
21 | block_size=64,
22 | n_layer=3,
23 | n_head=4,
24 | n_embd=128,
25 | kv_lora_rank=32,
26 | q_lora_rank=48,
27 | n_experts=4,
28 | n_experts_per_token=2,
29 | mtp_num_heads=1,
30 | dropout=0.0
31 | )
32 |
33 | model = DeepSeekV3(config)
34 | total_params = sum(p.numel() for p in model.parameters())
35 | print(f"Created model with {total_params:,} parameters")
36 |
37 | # Test forward pass
38 | batch_size, seq_len = 2, 32
39 | test_input = torch.randint(0, config.vocab_size, (batch_size, seq_len))
40 | test_targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))
41 |
42 | with torch.no_grad():
43 | logits, total_loss, main_loss, mtp_loss = model(test_input, test_targets)
44 |
45 | print(f"Input shape: {test_input.shape}")
46 | print(f"Output shape: {logits.shape}")
47 | print(f"Main loss: {main_loss:.4f}")
48 | if mtp_loss is not None:
49 | print(f"MTP loss: {mtp_loss:.4f}")
50 | print(f"Total loss: {total_loss:.4f}")
51 |
52 | # Test generation
53 | prompt = torch.randint(0, config.vocab_size, (1, 5))
54 | with torch.no_grad():
55 | generated = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=10)
56 |
57 | print(f"Generated {generated.shape[1] - prompt.shape[1]} tokens")
58 | print("Demo completed!")
59 |
60 | def main():
61 | """Main function."""
62 | import sys
63 |
64 | if len(sys.argv) < 2:
65 | demo()
66 | return
67 |
68 | mode = sys.argv[1]
69 |
70 | if mode == "demo":
71 | demo()
72 | elif mode == "train":
73 | print("Starting training...")
74 | train_model()
75 | else:
76 | print("Usage: python main.py [demo|train|inference]")
77 |
78 | if __name__ == "__main__":
79 | main()
--------------------------------------------------------------------------------
/inference/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tiktoken
3 | from models import DeepSeekV3
4 |
5 | def generate_text(model_path, config, prompt, max_tokens=100, temperature=0.8, top_k=50):
6 | """Generate text from a prompt using trained model."""
7 | device = "cuda" if torch.cuda.is_available() else "cpu"
8 |
9 | # Load model
10 | model = DeepSeekV3(config)
11 | model.load_state_dict(torch.load(model_path, map_location=device))
12 | model = model.to(device)
13 | model.eval()
14 |
15 | # Tokenize input
16 | enc = tiktoken.get_encoding("gpt2")
17 | context = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(0).to(device)
18 |
19 | # Generate
20 | with torch.no_grad():
21 | generated = model.generate(context, max_tokens, temperature, top_k)
22 |
23 | # Decode and return
24 | result = enc.decode(generated.squeeze().tolist())
25 | return result
26 |
27 | def run_inference_examples():
28 | """Run inference examples with different prompts."""
29 | try:
30 | from models import DeepSeekConfig
31 |
32 | config = DeepSeekConfig(
33 | vocab_size=50257,
34 | block_size=128,
35 | n_layer=4,
36 | n_head=4,
37 | n_embd=256,
38 | kv_lora_rank=64,
39 | q_lora_rank=96,
40 | n_experts=4,
41 | n_experts_per_token=2,
42 | mtp_num_heads=1,
43 | dropout=0.1
44 | )
45 |
46 | test_prompts = [
47 | "Once upon a time",
48 | "The little girl",
49 | "In a magical forest",
50 | "The brave knight"
51 | ]
52 |
53 | print("=" * 50)
54 | print("DEEPSEEK-V3 INFERENCE EXAMPLES")
55 | print("=" * 50)
56 |
57 | for prompt in test_prompts:
58 | result = generate_text(
59 | "best_deepseek_v3.pt",
60 | config,
61 | prompt,
62 | max_tokens=80,
63 | temperature=0.7,
64 | top_k=40
65 | )
66 |
67 | print(f"\nPrompt: '{prompt}'")
68 | print("Generated:", result)
69 | print("-" * 30)
70 |
71 | except FileNotFoundError:
72 | print("Model file 'best_deepseek_v3.pt' not found. Please train the model first.")
73 | except Exception as e:
74 | print(f"Error during inference: {e}")
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple script to run inference with DeepSeek-V3
3 | Change the prompts below and run this file to generate text!
4 | """
5 |
6 | import torch
7 | import tiktoken
8 | from models import DeepSeekConfig, DeepSeekV3
9 |
10 | def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=50):
11 | """Generate text from your prompt."""
12 |
13 | # Model configuration (same as training)
14 | config = DeepSeekConfig(
15 | vocab_size=50257,
16 | block_size=128,
17 | n_layer=4,
18 | n_head=4,
19 | n_embd=256,
20 | kv_lora_rank=64,
21 | q_lora_rank=96,
22 | n_experts=4,
23 | n_experts_per_token=2,
24 | mtp_num_heads=1,
25 | dropout=0.1
26 | )
27 |
28 | device = "cuda" if torch.cuda.is_available() else "cpu"
29 |
30 | # Load model
31 | model = DeepSeekV3(config)
32 | try:
33 | model.load_state_dict(torch.load("best_deepseek_v3.pt", map_location=device))
34 | print("✓ Loaded trained model")
35 | except FileNotFoundError:
36 | print("⚠️ No trained model found, using random weights")
37 |
38 | model = model.to(device)
39 | model.eval()
40 |
41 | # Tokenize input
42 | tokenizer = tiktoken.get_encoding("gpt2")
43 | context = torch.tensor(tokenizer.encode_ordinary(prompt)).unsqueeze(0).to(device)
44 |
45 | # Generate
46 | with torch.no_grad():
47 | generated = model.generate(context, max_tokens, temperature, top_k)
48 |
49 | # Convert back to text
50 | result = tokenizer.decode(generated.squeeze().tolist())
51 | return result
52 |
53 |
54 | if __name__ == "__main__":
55 | print("=" * 60)
56 | print("DEEPSEEK-V3 TEXT GENERATION")
57 | print("=" * 60)
58 |
59 | # ============================================
60 | # CHANGE THESE PROMPTS TO WHATEVER YOU WANT!
61 | # ============================================
62 |
63 | my_prompts = [
64 | "Once upon a time",
65 | "The little girl found a magic",
66 | "In the future, artificial intelligence will",
67 | "The secret to happiness is",
68 | "Yesterday I went to the store and",
69 | ]
70 |
71 | # ============================================
72 | # CHANGE THESE PARAMETERS TO EXPERIMENT!
73 | # ============================================
74 |
75 | max_tokens = 80 # How many words to generate
76 | temperature = 0.8 # 0.1=boring, 0.8=balanced, 1.2=crazy
77 | top_k = 50 # Vocabulary limit
78 |
79 | # Generate text for each prompt
80 | for i, prompt in enumerate(my_prompts, 1):
81 | print(f"\n{i}. Prompt: '{prompt}'")
82 | print("-" * 40)
83 |
84 | result = generate_text(
85 | prompt=prompt,
86 | max_tokens=max_tokens,
87 | temperature=temperature,
88 | top_k=top_k
89 | )
90 |
91 | print(f"Generated: {result}")
92 | print()
93 |
94 | print("=" * 60)
95 | print("DONE! Edit the prompts above and run again!")
96 | print("=" * 60)
--------------------------------------------------------------------------------
/prepare_data_tiny_stories.py:
--------------------------------------------------------------------------------
1 | """
2 | Data Preparation Script for DeepSeek-V3
3 |
4 | This script downloads and tokenizes the TinyStories dataset.
5 | Run this BEFORE training the model.
6 | """
7 |
8 | import os
9 | import numpy as np
10 | import tiktoken
11 | from tqdm.auto import tqdm
12 | from datasets import load_dataset
13 |
14 | def prepare_tinystories_dataset():
15 | """Download and tokenize TinyStories dataset."""
16 |
17 | print("=" * 60)
18 | print("PREPARING TINYSTORIES DATASET")
19 | print("=" * 60)
20 |
21 | # Check if already prepared
22 | if os.path.exists("train.bin") and os.path.exists("validation.bin"):
23 | print("✓ Dataset files already exist!")
24 |
25 | # Show file info
26 | train_size = os.path.getsize("train.bin") / (1024 * 1024)
27 | val_size = os.path.getsize("validation.bin") / (1024 * 1024)
28 |
29 | train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
30 | val_data = np.memmap('validation.bin', dtype=np.uint16, mode='r')
31 |
32 | print(f"├── train.bin: {len(train_data):,} tokens ({train_size:.1f} MB)")
33 | print(f"└── validation.bin: {len(val_data):,} tokens ({val_size:.1f} MB)")
34 | return
35 |
36 | print("Downloading TinyStories dataset...")
37 |
38 | # Load dataset
39 | ds = load_dataset("roneneldan/TinyStories")
40 |
41 | # Initialize tokenizer
42 | enc = tiktoken.get_encoding("gpt2")
43 |
44 | def process_example(example):
45 | """Tokenize text."""
46 | ids = enc.encode_ordinary(example['text'])
47 | return {'ids': ids, 'len': len(ids)}
48 |
49 | print("Tokenizing dataset...")
50 |
51 | # Tokenize
52 | tokenized = ds.map(
53 | process_example,
54 | remove_columns=['text'],
55 | desc="Tokenizing splits",
56 | num_proc=8,
57 | )
58 |
59 | # Create binary files
60 | for split, dset in tokenized.items():
61 | arr_len = np.sum(dset['len'], dtype=np.uint64)
62 | filename = f'{split}.bin'
63 | dtype = np.uint16 # GPT-2 vocab size < 2^16
64 |
65 | print(f"Creating {filename}...")
66 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
67 | total_batches = 1024
68 |
69 | idx = 0
70 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
71 | batch = dset.shard(
72 | num_shards=total_batches,
73 | index=batch_idx,
74 | contiguous=True
75 | ).with_format('numpy')
76 |
77 | arr_batch = np.concatenate(batch['ids'])
78 | arr[idx : idx + len(arr_batch)] = arr_batch
79 | idx += len(arr_batch)
80 |
81 | arr.flush()
82 |
83 | size_mb = os.path.getsize(filename) / (1024 * 1024)
84 | print(f"✓ {filename}: {arr_len:,} tokens ({size_mb:.1f} MB)")
85 |
86 | print("\n Dataset preparation completed!")
87 | print("You can now run: python main.py train")
88 |
89 |
90 | if __name__ == "__main__":
91 | prepare_tinystories_dataset()
--------------------------------------------------------------------------------
/models/moe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .layers import SwiGLU
5 |
6 | class MoELayer(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 | self.n_experts = config.n_experts
11 | self.top_k = config.n_experts_per_token
12 | self.n_embd = config.n_embd
13 |
14 | # Router
15 | self.router = nn.Linear(config.n_embd, config.n_experts, bias=False)
16 |
17 | # Expert MLPs
18 | self.experts = nn.ModuleList([
19 | SwiGLU(
20 | config.n_embd,
21 | config.expert_intermediate_size,
22 | config.n_embd,
23 | config.bias
24 | ) for _ in range(config.n_experts)
25 | ])
26 |
27 | # Shared expert
28 | if config.use_shared_expert:
29 | self.shared_expert = SwiGLU(
30 | config.n_embd,
31 | config.shared_expert_intermediate_size,
32 | config.n_embd,
33 | config.bias
34 | )
35 | else:
36 | self.shared_expert = None
37 |
38 | # Auxiliary-loss-free load balancing
39 | self.register_buffer('expert_bias', torch.zeros(config.n_experts))
40 | self.bias_update_rate = 0.001
41 |
42 | def forward(self, x):
43 | batch_size, seq_len, hidden_size = x.shape
44 | x_flat = x.view(-1, hidden_size)
45 |
46 | # Routing phase
47 | router_logits = self.router(x_flat) + self.expert_bias
48 |
49 | # Top-k routing
50 | top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
51 | routing_weights = torch.zeros_like(router_logits)
52 | routing_weights.scatter_(-1, top_k_indices, F.softmax(top_k_logits, dim=-1))
53 |
54 | # Expert computation
55 | output = torch.zeros_like(x_flat)
56 | expert_usage = torch.zeros(self.n_experts, device=x.device)
57 |
58 | # Process through selected experts
59 | for expert_idx in range(self.n_experts):
60 | expert_mask = (top_k_indices == expert_idx).any(dim=-1)
61 | expert_usage[expert_idx] = expert_mask.sum().float()
62 |
63 | if expert_mask.any():
64 | expert_input = x_flat[expert_mask]
65 | expert_output = self.experts[expert_idx](expert_input)
66 |
67 | # Weight by routing probability
68 | weights = routing_weights[expert_mask, expert_idx].unsqueeze(-1)
69 | output[expert_mask] += expert_output * weights
70 |
71 | # Add shared expert output
72 | if self.shared_expert is not None:
73 | shared_output = self.shared_expert(x_flat)
74 | output += shared_output
75 |
76 | # Auxiliary-loss-free load balancing
77 | if self.training:
78 | with torch.no_grad():
79 | avg_usage = expert_usage.mean()
80 | for i in range(self.n_experts):
81 | if expert_usage[i] > avg_usage:
82 | self.expert_bias[i] -= self.bias_update_rate
83 | else:
84 | self.expert_bias[i] += self.bias_update_rate
85 |
86 | return output.view(batch_size, seq_len, hidden_size)
--------------------------------------------------------------------------------
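
A tiny numeric sketch of the routing step in `MoELayer.forward` above: keep the top-k router logits per token, softmax only over those, and scatter the probabilities back so every non-selected expert receives weight zero.

```python
# Routing-step sketch with hand-picked logits (2 tokens, 4 experts, top-2).
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 1.5, -1.0],
                              [0.1, 3.0, 0.2,  2.5]])
top_k = 2

top_k_logits, top_k_indices = torch.topk(router_logits, top_k, dim=-1)
routing_weights = torch.zeros_like(router_logits)
routing_weights.scatter_(-1, top_k_indices, F.softmax(top_k_logits, dim=-1))

print(top_k_indices)     # tensor([[0, 2], [1, 3]])
print(routing_weights)   # each row sums to 1; zeros for unselected experts
```
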
/models/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from .layers import RMSNorm, RotaryEmbedding, apply_rope
6 |
7 | class MultiHeadLatentAttention(nn.Module):
8 | def __init__(self, config):
9 | super().__init__()
10 | self.config = config
11 | self.n_embd = config.n_embd
12 | self.n_head = config.n_head
13 | self.head_dim = config.n_embd // config.n_head
14 |
15 | # Compression dimensions
16 | self.kv_lora_rank = config.kv_lora_rank
17 | self.q_lora_rank = config.q_lora_rank
18 | self.rope_dim = config.rope_dim
19 |
20 | # KV compression
21 | self.kv_proj = nn.Linear(self.n_embd, self.kv_lora_rank, bias=False)
22 | self.kv_norm = RMSNorm(self.kv_lora_rank)
23 |
24 | # KV decompression
25 | self.k_decompress = nn.Linear(self.kv_lora_rank, self.n_head * self.head_dim, bias=False)
26 | self.v_decompress = nn.Linear(self.kv_lora_rank, self.n_head * self.head_dim, bias=False)
27 |
28 | # Query compression
29 | self.q_proj = nn.Linear(self.n_embd, self.q_lora_rank, bias=False)
30 | self.q_decompress = nn.Linear(self.q_lora_rank, self.n_head * self.head_dim, bias=False)
31 |
32 | # RoPE projections
33 | self.k_rope_proj = nn.Linear(self.n_embd, self.n_head * self.rope_dim, bias=False)
34 | self.q_rope_proj = nn.Linear(self.q_lora_rank, self.n_head * self.rope_dim, bias=False)
35 |
36 | # Output projection
37 | self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=config.bias)
38 |
39 | # Dropout
40 | self.attn_dropout = nn.Dropout(config.dropout)
41 | self.resid_dropout = nn.Dropout(config.dropout)
42 |
43 | # RoPE
44 | self.rope = RotaryEmbedding(self.rope_dim, config.block_size)
45 |
46 | # Causal mask
47 | self.register_buffer(
48 | "causal_mask",
49 | torch.tril(torch.ones(config.block_size, config.block_size)).view(
50 | 1, 1, config.block_size, config.block_size
51 | )
52 | )
53 |
54 | def forward(self, x):
55 | B, T, C = x.size()
56 |
57 | # Compression phase
58 | kv_compressed = self.kv_norm(self.kv_proj(x))
59 | q_compressed = self.q_proj(x)
60 |
61 | # Decompression phase
62 | k_content = self.k_decompress(kv_compressed)
63 | v = self.v_decompress(kv_compressed)
64 | q_content = self.q_decompress(q_compressed)
65 |
66 | # RoPE components
67 | k_rope = self.k_rope_proj(x)
68 | q_rope = self.q_rope_proj(q_compressed)
69 |
70 | # Reshape for multi-head attention
71 | k_content = k_content.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
72 | v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
73 | q_content = q_content.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
74 | k_rope = k_rope.view(B, T, self.n_head, self.rope_dim).transpose(1, 2)
75 | q_rope = q_rope.view(B, T, self.n_head, self.rope_dim).transpose(1, 2)
76 |
77 | # Apply RoPE
78 | cos, sin = self.rope(x, T)
79 | q_rope = apply_rope(q_rope, cos, sin)
80 | k_rope = apply_rope(k_rope, cos, sin)
81 |
82 | # Concatenate content and rope parts
83 | q = torch.cat([q_content, q_rope], dim=-1)
84 | k = torch.cat([k_content, k_rope], dim=-1)
85 |
86 | # Attention computation
87 | scale = 1.0 / math.sqrt(q.size(-1))
88 | scores = torch.matmul(q, k.transpose(-2, -1)) * scale
89 |
90 | # Apply causal mask
91 | scores = scores.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float('-inf'))
92 |
93 | # Softmax and dropout
94 | attn_weights = F.softmax(scores, dim=-1)
95 | attn_weights = self.attn_dropout(attn_weights)
96 |
97 | # Apply attention to values
98 | out = torch.matmul(attn_weights, v)
99 |
100 | # Reshape and project
101 | out = out.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
102 | out = self.resid_dropout(self.o_proj(out))
103 |
104 | return out
--------------------------------------------------------------------------------
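
A forward-pass sketch of the module above with a small, hypothetical config: the output keeps the `(B, T, n_embd)` shape, while K and V are reconstructed from a `kv_lora_rank`-dimensional latent, which is the quantity MLA is designed to keep small (this implementation recomputes it each forward rather than caching it).

```python
# MLA shape sketch with a small, made-up config.
import torch
from models import DeepSeekConfig
from models.attention import MultiHeadLatentAttention

config = DeepSeekConfig(n_embd=128, n_head=4, block_size=64, kv_lora_rank=32,
                        q_lora_rank=48, rope_dim=16, dropout=0.0)
attn = MultiHeadLatentAttention(config)

x = torch.randn(2, 16, config.n_embd)
print(attn(x).shape)              # torch.Size([2, 16, 128])
print(attn.kv_proj(x).shape)      # torch.Size([2, 16, 32]) -- the compressed KV latent
```
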
/training/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import os
4 | import wandb
5 | from tqdm.auto import tqdm
6 | from contextlib import nullcontext
7 | from models import DeepSeekConfig, DeepSeekV3
8 | from .data_loader import get_batch, estimate_loss
9 | from dotenv import load_dotenv
10 |
11 | # Load environment variables
12 | load_dotenv()
13 |
14 | def train_model():
15 | # Configuration
16 | config = DeepSeekConfig(
17 | vocab_size=50257,
18 | block_size=1024,
19 | n_layer=8,
20 | n_head=8,
21 | n_embd=512,
22 | kv_lora_rank=128,
23 | q_lora_rank=192,
24 | n_experts=8,
25 | n_experts_per_token=2,
26 | mtp_num_heads=1,
27 | dropout=0.1
28 | )
29 |
30 | # Training parameters
31 | learning_rate = 3e-4
32 | max_iters = 20000
33 | warmup_steps = 2000
34 | min_lr = 1e-5
35 | eval_iters = 1000
36 | batch_size = 32
37 | gradient_accumulation_steps = 8
38 |
39 | # Device setup
40 | device = "cuda" if torch.cuda.is_available() else "cpu"
41 | device_type = 'cuda' if 'cuda' in device else 'cpu'
42 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
43 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
44 | ctx = nullcontext() if device_type == 'cpu' else torch.cuda.amp.autocast(dtype=ptdtype)
45 |
46 | # Initialize wandb
47 | wandb.init(
48 | project="deepseek-v3-training",
49 | config={
50 | "learning_rate": learning_rate,
51 | "max_iters": max_iters,
52 | "warmup_steps": warmup_steps,
53 | "min_lr": min_lr,
54 | "eval_iters": eval_iters,
55 | "batch_size": batch_size,
56 | "gradient_accumulation_steps": gradient_accumulation_steps,
57 | "device": device,
58 | "dtype": dtype,
59 | "vocab_size": config.vocab_size,
60 | "block_size": config.block_size,
61 | "n_layer": config.n_layer,
62 | "n_head": config.n_head,
63 | "n_embd": config.n_embd,
64 | "kv_lora_rank": config.kv_lora_rank,
65 | "q_lora_rank": config.q_lora_rank,
66 | "n_experts": config.n_experts,
67 | "n_experts_per_token": config.n_experts_per_token,
68 | "mtp_num_heads": config.mtp_num_heads,
69 | "dropout": config.dropout
70 | }
71 | )
72 |
73 | # Initialize model
74 | torch.manual_seed(42)
75 | model = DeepSeekV3(config)
76 | model = model.to(device)
77 |
78 | # Print model info
79 | total_params = sum(p.numel() for p in model.parameters())
80 | print(f"DeepSeek-V3 model with {total_params:,} parameters")
81 |
82 | # Log model parameters to wandb
83 | wandb.log({"total_parameters": total_params})
84 |
85 | # Optimizer
86 | optimizer = torch.optim.AdamW(
87 | model.parameters(),
88 | lr=learning_rate,
89 | betas=(0.9, 0.95),
90 | weight_decay=0.1,
91 | eps=1e-9
92 | )
93 |
94 | # Training loop
95 | model.train()
96 | best_val_loss = float('inf')
97 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
98 |
99 | for epoch in tqdm(range(max_iters)):
100 | # Evaluation
101 | if epoch % eval_iters == 0 and epoch != 0:
102 | losses = estimate_loss(model, config, eval_iters, batch_size, device_type, device, ctx)
103 | print(f"Epoch {epoch}: train {losses['train']:.4f}, val {losses['val']:.4f}")
104 |
105 | # Log evaluation losses to wandb
106 | wandb.log({
107 | "epoch": epoch,
108 | "train_loss": losses['train'],
109 | "val_loss": losses['val'],
110 | "best_val_loss": best_val_loss
111 | })
112 |
113 | if losses['val'] < best_val_loss:
114 | best_val_loss = losses['val']
115 | torch.save(model.state_dict(), "best_deepseek_v3.pt")
116 |
117 | # Log best model save to wandb
118 | wandb.log({"best_val_loss_updated": best_val_loss})
119 |
120 | # Training step
121 | X, y = get_batch("train", config, batch_size, device_type, device)
122 |
123 | with ctx:
124 | _, total_loss, main_loss, mtp_loss = model(X, y)
125 | loss = total_loss / gradient_accumulation_steps
126 | scaler.scale(loss).backward()
127 |
128 | if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):
129 | scaler.step(optimizer)
130 | scaler.update()
131 | optimizer.zero_grad(set_to_none=True)
132 |
133 | # Learning rate scheduling
134 | if epoch < warmup_steps:
135 | lr = learning_rate * (epoch + 1) / warmup_steps
136 | else:
137 | progress = (epoch - warmup_steps) / (max_iters - warmup_steps)
138 | lr = min_lr + (learning_rate - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
139 |
140 | for param_group in optimizer.param_groups:
141 | param_group['lr'] = lr
142 |
143 | # Log training metrics to wandb every step
144 | wandb.log({
145 | "step": epoch,
146 | "total_loss": total_loss.item(),
147 | "main_loss": main_loss.item(),
148 | "mtp_loss": mtp_loss.item(),
149 | "learning_rate": lr,
150 | "scaled_loss": loss.item()
151 | })
152 |
153 | print("Training completed!")
154 |
155 | # Finish wandb run
156 | wandb.finish()
157 |
158 | return model, config
--------------------------------------------------------------------------------
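
The schedule inside `train_model` above is linear warmup followed by cosine decay; a standalone sketch of the same formula makes its shape easy to inspect.

```python
# Standalone copy of the LR schedule used in train_model (same constants).
import math

learning_rate, min_lr = 3e-4, 1e-5
max_iters, warmup_steps = 20_000, 2_000

def lr_at(step: int) -> float:
    if step < warmup_steps:                       # linear warmup
        return learning_rate * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (max_iters - warmup_steps)
    return min_lr + (learning_rate - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

for step in (0, 1_000, 2_000, 10_000, 19_999):
    print(f"step {step:>6}: lr = {lr_at(step):.2e}")
```
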
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from .layers import RMSNorm
6 | from .attention import MultiHeadLatentAttention
7 | from .moe import MoELayer
8 | from .mtp import MultiTokenPredictionHead
9 |
10 | class DeepSeekBlock(nn.Module):
11 | def __init__(self, config):
12 | super().__init__()
13 | self.ln_1 = RMSNorm(config.n_embd)
14 | self.attn = MultiHeadLatentAttention(config)
15 | self.ln_2 = RMSNorm(config.n_embd)
16 | self.mlp = MoELayer(config)
17 |
18 | def forward(self, x):
19 | x = x + self.attn(self.ln_1(x))
20 | x = x + self.mlp(self.ln_2(x))
21 | return x
22 |
23 | class DeepSeekV3(nn.Module):
24 | def __init__(self, config):
25 | super().__init__()
26 | self.config = config
27 |
28 | # Token and position embeddings
29 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
30 | self.wpe = nn.Embedding(config.block_size, config.n_embd)
31 | self.drop = nn.Dropout(config.dropout)
32 |
33 | # Transformer blocks
34 | self.h = nn.ModuleList([DeepSeekBlock(config) for _ in range(config.n_layer)])
35 |
36 | # Final layer norm
37 | self.ln_f = RMSNorm(config.n_embd)
38 |
39 | # Language modeling head
40 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
41 |
42 | # Weight tying
43 | self.wte.weight = self.lm_head.weight
44 |
45 | # Multi-Token Prediction heads
46 | if config.mtp_num_heads > 0:
47 | self.mtp_heads = nn.ModuleList([
48 | MultiTokenPredictionHead(config, depth)
49 | for depth in range(1, config.mtp_num_heads + 1)
50 | ])
51 | else:
52 | self.mtp_heads = None
53 |
54 | # Initialize weights
55 | self.apply(self._init_weights)
56 |
57 | # Special initialization for residual projections
58 | for pn, p in self.named_parameters():
59 | if pn.endswith('o_proj.weight') or pn.endswith('down_proj.weight'):
60 | nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
61 |
62 | def _init_weights(self, module):
63 | if isinstance(module, nn.Linear):
64 | nn.init.normal_(module.weight, mean=0.0, std=0.02)
65 | if module.bias is not None:
66 | nn.init.zeros_(module.bias)
67 | elif isinstance(module, nn.Embedding):
68 | nn.init.normal_(module.weight, mean=0.0, std=0.02)
69 |
70 | def forward(self, idx, targets=None):
71 | device = idx.device
72 | b, t = idx.size()
73 | assert t <= self.config.block_size
74 |
75 | # Token and position embeddings
76 | pos = torch.arange(0, t, dtype=torch.long, device=device)
77 | tok_emb = self.wte(idx)
78 | pos_emb = self.wpe(pos)
79 | x = self.drop(tok_emb + pos_emb)
80 |
81 | # Transformer blocks
82 | for block in self.h:
83 | x = block(x)
84 |
85 | # Final norm
86 | x = self.ln_f(x)
87 |
88 | # Main language modeling head
89 | main_logits = self.lm_head(x)
90 | main_loss = None
91 |
92 | if targets is not None:
93 | main_loss = F.cross_entropy(
94 | main_logits.view(-1, main_logits.size(-1)),
95 | targets.view(-1),
96 | ignore_index=-1
97 | )
98 |
99 | # Multi-Token Prediction
100 | mtp_loss = None
101 | if self.mtp_heads is not None and targets is not None:
102 | mtp_losses = []
103 | current_hidden = x
104 |
105 | for depth, mtp_head in enumerate(self.mtp_heads, 1):
106 | if t > depth:
107 | future_indices = idx[:, depth:]
108 | future_embeds = self.wte(future_indices)
109 |
110 | if future_embeds.size(1) < current_hidden.size(1):
111 | pad_size = current_hidden.size(1) - future_embeds.size(1)
112 | padding = torch.zeros(
113 | b, pad_size, self.config.n_embd,
114 | device=device, dtype=future_embeds.dtype
115 | )
116 | future_embeds = torch.cat([future_embeds, padding], dim=1)
117 | elif future_embeds.size(1) > current_hidden.size(1):
118 | future_embeds = future_embeds[:, :current_hidden.size(1)]
119 |
120 | current_hidden = mtp_head(current_hidden, future_embeds)
121 | mtp_logits = self.lm_head(current_hidden)
122 |
123 | if t > depth + 1:
124 | shift_logits = mtp_logits[..., :-(depth+1), :].contiguous()
125 | shift_labels = targets[..., depth+1:].contiguous()
126 |
127 | if shift_labels.numel() > 0:
128 | mtp_loss_single = F.cross_entropy(
129 | shift_logits.view(-1, shift_logits.size(-1)),
130 | shift_labels.view(-1),
131 | ignore_index=-1
132 | )
133 | mtp_losses.append(mtp_loss_single)
134 |
135 | if mtp_losses:
136 | mtp_loss = torch.stack(mtp_losses).mean()
137 |
138 | # Combine losses
139 | if targets is not None:
140 | if mtp_loss is not None:
141 | total_loss = main_loss + self.config.mtp_loss_weight * mtp_loss
142 | return main_logits, total_loss, main_loss, mtp_loss
143 | else:
144 | return main_logits, main_loss, main_loss, None
145 | else:
146 | return main_logits[:, [-1], :], None, None, None
147 |
148 | @torch.no_grad()
149 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
150 | for _ in range(max_new_tokens):
151 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
152 | logits, _, _, _ = self(idx_cond)
153 | logits = logits[:, -1, :] / temperature
154 |
155 | if top_k is not None:
156 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
157 | logits[logits < v[:, [-1]]] = -float('Inf')
158 |
159 | probs = F.softmax(logits, dim=-1)
160 | idx_next = torch.multinomial(probs, num_samples=1)
161 | idx = torch.cat((idx, idx_next), dim=1)
162 |
163 | return idx
--------------------------------------------------------------------------------
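
The trickiest part of `DeepSeekV3.forward` above is how MTP logits are aligned with targets: for a head at depth `d`, `mtp_logits[:, :-(d+1)]` is scored against `targets[:, d+1:]`. A tiny index walk-through makes the offset explicit.

```python
# Index walk-through of the MTP loss alignment (pure Python, no tensors).
t, depth = 6, 1
positions = list(range(t))

logit_positions  = positions[:-(depth + 1)]   # [0, 1, 2, 3]
target_positions = positions[depth + 1:]      # [2, 3, 4, 5]

for i, j in zip(logit_positions, target_positions):
    # targets[j] is the token at stream position j + 1, so as implemented each
    # hidden state at position i is scored against the token depth + 2 steps after idx[i].
    print(f"hidden position {i} -> targets[{j}]")
```
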
/README.md:
--------------------------------------------------------------------------------
1 | # DeepSeek V3 from Scratch
2 |
3 | A complete implementation of the DeepSeek V3 architecture with modern transformer innovations including Multi-Head Latent Attention (MLA), Mixture of Experts (MoE), and Multi-Token Prediction (MTP). This project demonstrates the implementation of a 100+ million parameter language model trained on the FineWeb-Edu dataset.
4 |
5 | ## Architecture Overview
6 |
7 |
8 | 
9 |
10 | DeepSeek V3 introduces several key architectural improvements over traditional transformer models:
11 |
12 | ### Core Innovations
13 |
14 | **Multi-Head Latent Attention (MLA)**
15 | 
16 | - Compresses key-value pairs into shared latent representations
17 | - Dramatically reduces memory usage compared to traditional multi-head attention
18 | - Maintains the expressiveness of multiple attention heads while using significantly less memory
19 | - Critical for handling longer sequences efficiently
20 |
21 |
22 |
23 |
24 |
25 |
26 | **Mixture of Experts (MoE)**
27 | 
28 | - Replaces dense feed-forward networks with sparse expert networks
29 | - Uses 8 experts but only activates 2 per token
 30 | - Achieves 4x the expert parameter capacity while using only ~25% of the compute of running all experts (2 of 8 active per token)
31 | - Each expert specializes in different domains (numbers, language, code, etc.)
32 |
33 | **Multi-Token Prediction (MTP)**
34 | 
35 | - Predicts multiple tokens simultaneously during training
36 | - Improves training efficiency by providing more learning signals per forward pass
37 | - Enables faster inference through speculative decoding
38 |
39 | **Additional Components**
40 | - **RoPE (Rotary Positional Encoding)**: Better handling of longer sequences and relative positions
41 | - **RMS Norm**: Computationally simpler normalization without mean centering
42 | - **SwiGLU Activation**: Gated activation function for improved information flow control
43 |
44 |
 45 | **Final Model Weights**
46 |
47 | https://huggingface.co/Mayank022/DeepSeek-V3-from-Scratch/tree/main
48 |
49 |
50 | ## Model Configuration
51 |
52 |
53 |
54 | ### Training Parameters
55 |
56 | | Parameter | Value | Description |
57 | |-----------|-------|-------------|
58 | | Model Parameters | 109,032,032 | Total trainable parameters |
59 | | Vocabulary Size | 50,257 | Number of unique tokens |
60 | | Block Size | 1,024 | Maximum sequence length |
61 | | Embedding Dimension | 512 | Hidden dimension size |
62 | | Number of Layers | 8 | Transformer blocks |
63 | | Attention Heads | 8 | Multi-head attention |
64 | | Batch Size | 32 | Training batch size |
65 | | Learning Rate | 0.0003 | Initial learning rate |
66 | | Min Learning Rate | 0.00001 | Minimum learning rate |
67 | | Warmup Steps | 2,000 | Learning rate warmup |
68 | | Max Iterations | 20,000 | Maximum training steps |
69 | | Dropout | 0.1 | Dropout probability |
70 | | Gradient Accumulation | 8 | Steps before optimizer update |
71 |
72 | ### MoE Configuration
73 |
74 | | Parameter | Value | Description |
75 | |-----------|-------|-------------|
76 | | Number of Experts | 8 | Total expert networks |
77 | | Experts per Token | 2 | Active experts per forward pass |
78 | | Expert Efficiency | 25% | Computation vs full dense model |
79 | | Capacity Multiplier | 4x | Model capacity increase |
80 |
81 | ### Attention Configuration
82 |
83 | | Parameter | Value | Description |
84 | |-----------|-------|-------------|
85 | | KV LoRA Rank | 128 | Key-Value compression rank |
86 | | Q LoRA Rank | 192 | Query compression rank |
87 | | MTP Heads | 1 | Multi-token prediction heads |
88 |
89 | ## Dataset
90 |
91 | **Primary Dataset**: FineWeb-Edu (CC-MAIN-2024 subset)
92 |
93 | https://huggingface.co/spaces/HuggingFaceFW/blogpost-fineweb-v1
94 |
95 | https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu
96 |
97 | - Total available records: 13 million
98 | - Used for training: 2 million records
99 | - Training tokens: 2.5 billion
100 | - Validation tokens: 132.8 million
101 | - Format: Educational web content optimized for language model training
102 |
103 | **Fallback Dataset**: TinyStories
104 |
105 | https://huggingface.co/datasets/roneneldan/TinyStories
106 |
107 | - Used for initial prototyping and architecture validation
108 | - Simpler content for testing basic functionality
109 |
110 | **Paper**
111 |
112 | https://arxiv.org/pdf/2412.19437
113 |
114 |
115 | ## Project Structure
116 |
117 | ```
118 | DeepSeek-from-Scratch/
119 | ├── models/
120 | │ ├── attention.py # Multi-Head Latent Attention implementation
121 | │ ├── config.py # Model configuration parameters
122 | │ ├── layers.py # RoPE, RMS Norm, SwiGLU implementations
123 | │   ├── model.py                  # DeepSeekV3 model and transformer blocks
124 | │ ├── moe.py # Mixture of Experts implementation
125 | │ └── mtp.py # Multi-Token Prediction implementation
126 | ├── training/
127 | │ ├── data_loader.py # Dataset loading and preprocessing
128 | │ └── trainer.py # Training loop and optimization
129 | ├── inference/
130 | │   └── generator.py              # Text generation utilities
131 | ├── run_inference.py              # Inference script (repository root)
132 | ├── notebooks/
133 | │ ├── Mixture_of_Experts_from_Scratch.ipynb
134 | │ ├── Multi_Head_Latent_Attention_From_Scratch.ipynb
135 | │ └── Multi_Token_Prediction_from_Scratch.ipynb
136 | ├── prepare_data_fineweb.py # FineWeb dataset preparation
137 | ├── prepare_data_tiny_stories.py # TinyStories dataset preparation
138 | └── main.py # Main training script
139 | ```
140 |
141 | ## Installation
142 |
143 | 1. Clone the repository:
144 | ```bash
145 | git clone https://github.com/username/DeepSeek-from-Scratch.git
146 | cd DeepSeek-from-Scratch
147 | ```
148 |
149 | 2. Create and activate virtual environment:
150 | ```bash
151 | python -m venv deepseek_env
152 | source deepseek_env/bin/activate # Linux/Mac
153 | # or
154 | deepseek_env\Scripts\activate # Windows
155 | ```
156 |
157 | 3. Install dependencies:
158 | ```bash
159 | pip install -r requirements.txt
160 | ```
161 |
162 | ## Training
163 |
164 | ### Data Preparation
165 |
166 | Prepare the FineWeb-Edu dataset:
167 | ```bash
168 | python prepare_data_fineweb.py
169 | ```
170 |
171 | Or use TinyStories for testing:
172 | ```bash
173 | python prepare_data_tiny_stories.py
174 | ```
175 |
176 | ### Start Training
177 |
178 | ```bash
179 | python main.py train
180 | ```
181 |
182 | Training configurations can be modified in `models/config.py`.
183 |
184 | ### Monitoring
185 |
186 | Training progress is tracked using Weights & Biases:
187 | - Model checkpoints are saved automatically
188 | - Loss curves and metrics are logged in real-time
189 | - Training time: Approximately 7 hours on A100 80GB
190 | - Cost: $9.53 for full training run
191 |
192 | ## Inference
193 |
194 | Update the test prompts in `run_inference.py` (at the repository root) and run:
195 |
196 | ```bash
197 | python run_inference.py
198 | ```
199 |
200 | The model will generate text completions based on your input prompts.
201 |
202 | ## Key Implementation Details
203 |
204 | ### Multi-Head Latent Attention
205 |
206 | Traditional multi-head attention stores separate key-value pairs for each head, leading to significant memory overhead. MLA compresses these into shared latent representations:
207 |
208 | - Memory reduction: Proportional to number of attention heads
209 | - Performance maintenance: Retains expressiveness of full multi-head attention
210 | - Scalability: Enables training with longer sequences
211 |
212 | ### Mixture of Experts
213 |
214 | Instead of processing every token through the same large feed-forward network, MoE routes tokens to specialized experts:
215 |
216 | - Sparse activation: only 2 of the 8 routed experts (~25% of expert parameters) are active per token
217 | - Specialization: Different experts learn different types of patterns
218 | - Efficiency: 4x capacity increase with minimal computational overhead
219 |
220 | ### Multi-Token Prediction
221 |
222 | Enhances training by predicting multiple future tokens simultaneously:
223 |
224 | - Training efficiency: More learning signals per forward pass
225 | - Inference optimization: Enables speculative decoding techniques
226 | - Performance improvement: Better gradient flow during training
227 |
228 | ## Performance Metrics
229 |
230 | ### Training Results
231 |
232 | - **Final Loss**: Achieved convergence after 20,000 iterations
233 | - **Training Time**: 7 hours 1 minute on NVIDIA A100 80GB
234 | - **Memory Usage**: Efficient memory utilization with MLA compression
235 | - **Convergence**: Stable training with proper learning rate scheduling
236 |
237 | ### Model Efficiency
238 |
239 | - **Parameter Efficiency**: 109M parameters with MoE sparse activation
240 | - **Memory Efficiency**: Reduced KV cache through latent attention
241 | - **Computational Efficiency**: 25% active parameters per forward pass
242 |
243 | ## Technical Challenges Addressed
244 |
245 | ### Dataset Selection
246 |
247 | The choice of dataset was critical for demonstrating the architecture's benefits:
248 |
249 | - **TinyStories**: Too simple, didn't justify advanced architecture components
250 | - **Raw Web Data**: Too complex for resource-constrained training
251 | - **FineWeb-Edu**: Perfect balance of complexity and educational content quality
252 |
253 | ### Architecture Decisions
254 |
255 | Careful consideration was given to which components to include:
256 |
257 | - **Essential Components**: MLA, MoE, MTP all included for comprehensive implementation
258 | - **Training Constraints**: Context length limited to 1024 tokens due to compute budget
259 | - **Resource Management**: Balanced model size with available GPU memory
260 |
261 | ## Future Enhancements
262 |
263 | ### Planned Improvements
264 |
265 | - **Dataset Expansion**: Experiment with larger subsets of FineWeb-Edu
266 | - **Evaluation Metrics**: Implement comprehensive benchmarking suite
267 | - **Architecture Extensions**: Additional transformer innovations and optimizations
268 | - **Scaling Studies**: Analysis of performance across different model sizes
269 |
270 |
271 |
--------------------------------------------------------------------------------
/notebooks/Multi_Token_Prediction_from_Scratch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "source": [
20 | "# Multi-Token Prediction (MTP)"
21 | ],
22 | "metadata": {
23 | "id": "GZps_56evbje"
24 | }
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "source": [
29 | "## Step 0: Load Packages"
30 | ],
31 | "metadata": {
32 | "id": "ymOhz91jvhK8"
33 | }
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {
39 | "id": "yaCHMAGSrrA-"
40 | },
41 | "outputs": [],
42 | "source": [
43 | "import torch\n",
44 | "import torch.nn as nn\n",
45 | "import torch.nn.functional as F"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "source": [
51 | "## Step 1: Define RMSNorm Class"
52 | ],
53 | "metadata": {
54 | "id": "Q7OaXcEdvrkp"
55 | }
56 | },
57 | {
58 | "cell_type": "code",
59 | "source": [
60 | "class RMSNorm(nn.Module):\n",
61 | " \"\"\"Root Mean Square Layer Norm (no learning weights) \"\"\"\n",
62 | " def __init__(self,d_model,eps:float = 1e-8):\n",
63 | " super().__init__()\n",
64 | " self.eps = eps\n",
65 | "\n",
66 | " def forward(self,x):\n",
67 | " # x: (batch,d_model)\n",
68 | " rms = torch.sqrt(x.pow(2).mean(dim=-1,keepdim=True)+ self.eps)\n",
69 | " return x / rms"
70 | ],
71 | "metadata": {
72 | "id": "7brrcEzBvrYC"
73 | },
74 | "execution_count": null,
75 | "outputs": []
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "source": [
80 | "## Step 2: Define the Multi-Token Prediction (MTP) class"
81 | ],
82 | "metadata": {
83 | "id": "-2jlelWbzk5q"
84 | }
85 | },
86 | {
87 | "cell_type": "code",
88 | "source": [
89 | "class SimpleMTP(nn.Module):\n",
90 | " def __init__(self,d_model:int,vocab_size:int,num_heads:int=3,nhead: int =1):\n",
91 | " \"\"\"\n",
92 | " d_model: hidden size (8 in this example)\n",
93 | " num_heads: number of sequential MTP steps (D)\n",
94 | " nhead: attention heads in each Transformer block\n",
95 | " \"\"\"\n",
96 | " super().__init__()\n",
97 | " self.d_model = d_model\n",
98 | " self.vocab_size = vocab_size\n",
99 | " self.num_heads = num_heads\n",
100 | "\n",
101 | " # shared modules\n",
102 | " self.rmsnorm = RMSNorm(d_model)\n",
103 | " self.embed = nn.Embedding(vocab_size,d_model)\n",
104 | " self.unembed = nn.Linear(d_model,vocab_size,bias=False)\n",
105 | " # share weights between embed and unembed\n",
106 | " self.unembed.weight = self.embed.weight\n",
107 | "\n",
108 | " # one projection + one Transformer per head\n",
109 | " self.projections = nn.ModuleList([\n",
110 | " nn.Linear(2*d_model,d_model) for _ in range(num_heads)\n",
111 | "\n",
112 | " ])\n",
113 | " self.transformers = nn.ModuleList([\n",
114 | " nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead)\n",
115 | " for _ in range(num_heads)\n",
116 | " ])\n",
117 | "\n",
118 | " def forward(self,token_ids:torch.LongTensor,init_hidden:torch.Tensor = None):\n",
119 | " \"\"\"\n",
120 | " token_ids: (batch,seq_len) integer IDs of your input tokens\n",
121 | " init_hidden: optional (batch,seq_len,d_model) base hidden states;\n",
122 | " if None, uses token embedding as initial hidden.\n",
123 | "\n",
124 | " Returns:\n",
125 | " logits_out: Tensor of shape (batch,T-D,D,vocab_size),\n",
126 | " where T=seq_len and D=num_heads\n",
127 | "\n",
128 | "\n",
129 | "\n",
130 | " \"\"\"\n",
131 | "\n",
132 | " B,T = token_ids.shape\n",
133 | " device = token_ids.device\n",
134 | " # token embeddings: (B,T,d_model)\n",
135 | " embeds = self.embed(token_ids)\n",
136 | "\n",
137 | "\n",
138 | " # base hidden states\n",
139 | " if init_hidden is None:\n",
140 | " h0_seq = embeds # use embeddings as base hidden\n",
141 | " else:\n",
142 | " h0_seq = init_hidden # user-provided base states\n",
143 | "\n",
144 | " outputs = [] # will hold (B,D,vocab_size) for each i\n",
145 | " # slide over positions where i + D < T\n",
146 | " max_i = T - self.num_heads - 1\n",
147 | " for i in range(0,max_i + 1):\n",
148 | " # previous hidden for depth 0 at pos i\n",
149 | " h_prev = h0_seq[:,i,:] # (B,d_model)\n",
150 | "\n",
151 | "\n",
152 | " # collect logits for all k at this i\n",
153 | "\n",
154 | " logits_k = []\n",
155 | "\n",
156 | " for k in range(self.num_heads):\n",
157 | " # future token embed at pos i + (k+ 1)\n",
158 | " future_pos = i + (k+1)\n",
159 | " tok_embed = embeds[:,future_pos,:] # (B,d_model)\n",
160 | "\n",
161 | " # 1) RMS-normalize\n",
162 | " h_norm = self.rmsnorm(h_prev) # (B,d_model)\n",
163 | " e_norm = self.rmsnorm(tok_embed) # (B,d_model)\n",
164 | "\n",
165 | " # 2) concatenate -> (B,2*d_model)\n",
166 | " merged = torch.cat([h_norm,e_norm],dim=-1)\n",
167 | "\n",
168 | " # 3) project back to d_model\n",
169 | " proj = self.projections[k](merged) # (B, d_model)\n",
170 | "\n",
171 | " # 4) Transformer block (expects shape (S,B,d_model))\n",
172 | " x = proj.unsqueeze(0) # (1,B,d_model)\n",
173 | " x = self.transformers[k](x) # (1,B,d_model)\n",
174 | " h_curr = x.squeeze(0) # (B,d_model)\n",
175 | "\n",
176 | " # 5) unembed -> logits\n",
177 | " logits = self.unembed(h_curr) # (B,vocab_size)\n",
178 | " logits_k.append(logits)\n",
179 | "\n",
180 | " # 6) chain hidden for next depth\n",
181 | " h_prev = h_curr\n",
182 | "\n",
183 | " # stack along. depth axis -> (B,D,vocab_size)\n",
184 | " logits_k = torch.stack(logits_k,dim=1)\n",
185 | " outputs.append(logits_k)\n",
186 | "\n",
187 | " # stack along sequence axis -> (T-D,B,D,V) then permute -> (B,T-D,D,V)\n",
188 | "\n",
189 | " out = torch.stack(outputs,dim=0)\n",
190 | " out = out.permute(1,0,2,3).contiguous()\n",
191 | " return out"
192 | ],
193 | "metadata": {
194 | "id": "y52_ypdIvq-c"
195 | },
196 | "execution_count": null,
197 | "outputs": []
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "source": [
202 | "## Step 3: Pass input tokens through the model and generate multiple next tokens."
203 | ],
204 | "metadata": {
205 | "id": "cxMqwj0w7GdC"
206 | }
207 | },
208 | {
209 | "cell_type": "code",
210 | "source": [
211 | "batch_size, seq_len,d_model,vocab_size = 1,8,8,5000\n",
212 | "model = SimpleMTP(d_model=d_model,vocab_size=vocab_size,num_heads=3)\n",
213 | "tokens = torch.randint(0,vocab_size,(batch_size,seq_len))\n",
214 | "\n",
215 | "\n",
216 | "# Forward pass\n",
217 | "logits = model(tokens)\n",
218 | "# logits.shape == (1,4-3,3,5000) -> (batch_size,T-D,D,V)\n",
219 | "print(\"Logits shape:\",logits.shape)\n",
220 | "\n",
221 | "# If you want to inspect the 1-step ahead predition at postition i=0:\n",
222 | "print(\"Head k=0 at i=0 logits:\",logits[0,0,0]) # a tensor of length vocab_size\n",
223 | "\n",
224 | "# Or to get all predictions at i=0 as token IDs:\n",
225 | "\n",
226 | "pred_ids = logits[0,0].argmax(dim=-1)\n",
227 | "print(\"Predicted tokens at i=0 for all heads:\",pred_ids) # a length-3 tensor"
228 | ],
229 | "metadata": {
230 | "colab": {
231 | "base_uri": "https://localhost:8080/"
232 | },
233 | "id": "dhOmCR8X1yei",
234 | "outputId": "c8d9a829-b5b8-4236-8fd3-a4354059690b"
235 | },
236 | "execution_count": null,
237 | "outputs": [
238 | {
239 | "output_type": "stream",
240 | "name": "stdout",
241 | "text": [
242 | "Logits shape: torch.Size([1, 5, 3, 5000])\n",
243 | "Head k=0 at i=0 logits: tensor([ 3.6052, 2.8964, -1.7114, ..., 1.4961, -3.3179, 1.1599],\n",
244 | " grad_fn=)\n",
245 | "Predicted tokens at i=0 for all heads: tensor([4207, 4708, 4765])\n"
246 | ]
247 | }
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "source": [
253 | "## Step 4: Calcuate loss betweeen Loss between target tokens and predicted tokens"
254 | ],
255 | "metadata": {
256 | "id": "9GP9CvKg-VOT"
257 | }
258 | },
259 | {
260 | "cell_type": "code",
261 | "source": [
262 | "batch_size, seq_len, vocab_size = 1,8,5000\n",
263 | "\n",
264 | "# old (wrong): targets = torch.randint(0, vocab_size,(1,4))\n",
265 | "# new (right):\n",
266 | "\n",
267 | "targets = torch.randint(0,vocab_size,(batch_size,seq_len))\n",
268 | "print(\"targets.shape ->\",targets.shape) # torch.Size([1,8])\n",
269 | "\n",
270 | "\n",
271 | "# Now recompute:\n",
272 | "\n",
273 | "logits = model(tokens) # shape (1,5,3,5000)\n",
274 | "B,L,D,V = logits.shape # (1,5,3,5000)\n",
275 | "_,T = targets.shape # (1,8)\n",
276 | "assert L == T - D # 5 == 8 -3 passes\n",
277 | "\n",
278 | "\n",
279 | "# Double-loop loss:\n",
280 | "loss = 0.0\n",
281 | "for i in range(L):\n",
282 | " for k in range(D): # i = 0...4\n",
283 | " logits_ik = logits[:,i,k,:] # (1,5000)\n",
284 | " target_ik = targets[:,i + (k + 1)] # (1,)\n",
285 | " loss += F.cross_entropy(logits_ik,target_ik)\n",
286 | "\n",
287 | "loss = loss / (L*D)\n",
288 | "print(\"MTP loss:\",loss.item())\n",
289 | "\n",
290 | "\n"
291 | ],
292 | "metadata": {
293 | "colab": {
294 | "base_uri": "https://localhost:8080/"
295 | },
296 | "id": "6fvZFMQy-yXk",
297 | "outputId": "ff4b3e52-7d7a-4cb6-f1c0-12ef0f800f9b"
298 | },
299 | "execution_count": null,
300 | "outputs": [
301 | {
302 | "output_type": "stream",
303 | "name": "stdout",
304 | "text": [
305 | "targets.shape -> torch.Size([1, 8])\n",
306 | "MTP loss: 13.472195625305176\n"
307 | ]
308 | }
309 | ]
310 | }
311 | ]
312 | }
--------------------------------------------------------------------------------
/prepare_data_fineweb.py:
--------------------------------------------------------------------------------
1 | """
2 | Memory-Efficient FineWeb-Edu Dataset Preprocessing
3 |
4 | This version processes large datasets without loading everything into memory at once.
5 | Key improvements:
6 | - Streaming tokenization
7 | - Direct file writing
8 | - Memory-efficient shuffling
9 | - Progress tracking
10 | - Optional quality filtering
11 | """
12 |
13 | import os
14 | import logging
15 | import numpy as np
16 | import tiktoken
17 | from tqdm.auto import tqdm
18 | from datasets import load_dataset
19 | import random
20 | import re
21 | import pandas as pd
22 | import tempfile
23 | import mmap
24 |
25 | # ============================================================================
26 | # CONFIGURATION SECTION
27 | # ============================================================================
28 |
29 | # Main Dataset Processing Options
30 | TRAIN_ON_CUSTOM_ROWS = True
31 | CUSTOM_ROW_COUNT = 800000
32 |
33 | # Dataset Configuration
34 | DATASET_OPTIONS = [
35 | ("HuggingFaceFW/fineweb-edu", "CC-MAIN-2024-51"),
36 | ]
37 |
38 | # Processing Parameters
39 | CONTEXT_LENGTH = 1024
40 | MIN_TEXT_LENGTH = 50 # Only used when USE_QUALITY_FILTERING = True
41 | TRAIN_SPLIT = 0.95
42 | RANDOM_SEED = 42
43 | USE_QUALITY_FILTERING = False # Set to False to process ALL rows, True to use filtering
44 |
45 | # Memory Management
46 | BATCH_SIZE = 8000 # Process in smaller batches
47 | BUFFER_SIZE = 75000 # Tokens to buffer before writing
48 | TEMP_DIR = "temp_tokens" # Directory for temporary files
49 |
50 | # File Output Configuration
51 | TRAIN_FILENAME = "train.bin"
52 | VALIDATION_FILENAME = "validation.bin"
53 | LOG_FILENAME = "dataset_preparation.log"
54 |
55 | # Tokenizer Configuration
56 | TOKENIZER_NAME = "gpt2"
57 | DTYPE = np.uint16
58 |
59 | # ============================================================================
60 |
61 | # Setup logging
62 | logging.basicConfig(
63 | level=logging.INFO,
64 | format='%(asctime)s - %(message)s',
65 | handlers=[
66 | logging.StreamHandler(),
67 | logging.FileHandler(LOG_FILENAME)
68 | ]
69 | )
70 | logger = logging.getLogger(__name__)
71 |
72 | class MemoryEfficientTokenizer:
73 | """Tokenizes and saves data without loading everything into memory."""
74 |
75 | def __init__(self, tokenizer_name, context_length, dtype):
76 | self.tokenizer = tiktoken.get_encoding(tokenizer_name)
77 | self.context_length = context_length
78 | self.dtype = dtype
79 | self.temp_files = []
80 |
81 | def create_temp_dir(self):
82 | """Create temporary directory for intermediate files."""
83 | if not os.path.exists(TEMP_DIR):
84 | os.makedirs(TEMP_DIR)
85 |
86 | def clean_temp_files(self):
87 | """Clean up temporary files."""
88 | for temp_file in self.temp_files:
89 | if os.path.exists(temp_file):
90 | os.remove(temp_file)
91 | if os.path.exists(TEMP_DIR):
92 | try:
93 | os.rmdir(TEMP_DIR)
94 |             except OSError:  # directory may be missing or not empty
95 | pass
96 |
97 | def process_batch_to_temp_file(self, batch_samples, batch_id):
98 | """Process a batch of samples and save tokens to temporary file."""
99 | temp_filename = os.path.join(TEMP_DIR, f"batch_{batch_id}.bin")
100 | token_buffer = []
101 |
102 | for text in batch_samples:
103 | tokens = self.tokenizer.encode_ordinary(text)
104 |
105 | # Split into chunks of context_length
106 | for i in range(0, len(tokens), self.context_length):
107 | chunk = tokens[i:i + self.context_length]
108 | if len(chunk) == self.context_length: # Only complete chunks
109 | token_buffer.extend(chunk)
110 |
111 | # Write buffer to file when it gets large
112 | if len(token_buffer) >= BUFFER_SIZE:
113 | self._append_tokens_to_file(temp_filename, token_buffer)
114 | token_buffer = []
115 |
116 | # Write remaining tokens
117 | if token_buffer:
118 | self._append_tokens_to_file(temp_filename, token_buffer)
119 |
120 | self.temp_files.append(temp_filename)
121 | return temp_filename
122 |
123 | def _append_tokens_to_file(self, filename, tokens):
124 | """Append tokens to a binary file."""
125 | token_array = np.array(tokens, dtype=self.dtype)
126 | with open(filename, 'ab') as f:
127 | token_array.tofile(f)
128 |
129 | def merge_temp_files(self, temp_files, output_filename):
130 | """Merge temporary files into final output file."""
131 | logger.info(f"Merging {len(temp_files)} temporary files into {output_filename}")
132 |
133 | total_tokens = 0
134 | with open(output_filename, 'wb') as output_file:
135 | for temp_file in tqdm(temp_files, desc="Merging files"):
136 | if os.path.exists(temp_file):
137 | # Read and write in chunks to avoid memory issues
138 | with open(temp_file, 'rb') as f:
139 | while True:
140 | chunk = f.read(BUFFER_SIZE * 2) # Read in bytes
141 | if not chunk:
142 | break
143 | output_file.write(chunk)
144 | total_tokens += len(chunk) // 2 # 2 bytes per uint16
145 |
146 | logger.info(f"Merged {total_tokens:,} tokens into {output_filename}")
147 | return total_tokens
148 |
149 | class StreamingDataProcessor:
150 | """Process dataset in streaming fashion without loading all into memory."""
151 |
152 | def __init__(self):
153 | self.tokenizer_processor = MemoryEfficientTokenizer(
154 | TOKENIZER_NAME, CONTEXT_LENGTH, DTYPE
155 | )
156 |
157 | def clean_text(self, text):
158 | """Remove excessive whitespace and control characters."""
159 | text = re.sub(r'\s+', ' ', text)
160 | text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t')
161 | return text.strip()
162 |
163 | def is_quality_text(self, text):
164 | """Apply quality filters to determine if text should be included."""
165 | if not USE_QUALITY_FILTERING:
166 | return True # Accept all texts when filtering is disabled
167 |
168 | # Apply quality filters only when enabled
169 | return len(text.strip()) > MIN_TEXT_LENGTH
170 |
171 | def process_dataset_streaming(self, dataset):
172 | """Process dataset in streaming fashion with memory-efficient approach."""
173 | logger.info(f"Processing {CUSTOM_ROW_COUNT:,} rows with streaming approach")
174 | logger.info(f"Quality filtering: {'ENABLED' if USE_QUALITY_FILTERING else 'DISABLED'}")
175 |
176 | # Create temporary directory
177 | self.tokenizer_processor.create_temp_dir()
178 |
179 | # Process data in batches
180 | train_temp_files = []
181 | val_temp_files = []
182 |
183 | current_batch = []
184 | processed_count = 0
185 | accepted_count = 0
186 | train_count = 0
187 | val_count = 0
188 | batch_id = 0
189 |
190 | progress_bar = tqdm(total=CUSTOM_ROW_COUNT, desc="Processing samples")
191 |
192 | try:
193 | for example in dataset:
194 | if accepted_count >= CUSTOM_ROW_COUNT:
195 | break
196 |
197 | text = example.get('text', '')
198 | if not text:
199 | continue
200 |
201 | cleaned_text = self.clean_text(text)
202 | processed_count += 1
203 |
204 | if self.is_quality_text(cleaned_text):
205 | # Determine if this sample goes to train or validation
206 | if random.random() < TRAIN_SPLIT:
207 | current_batch.append(('train', cleaned_text))
208 | train_count += 1
209 | else:
210 | current_batch.append(('val', cleaned_text))
211 | val_count += 1
212 |
213 | accepted_count += 1
214 | progress_bar.update(1)
215 |
216 | # Process batch when it reaches target size
217 | if len(current_batch) >= BATCH_SIZE:
218 | self._process_current_batch(current_batch, batch_id,
219 | train_temp_files, val_temp_files)
220 | current_batch = []
221 | batch_id += 1
222 |
223 | # Update progress info
224 | progress_bar.set_postfix({
225 | 'train': train_count,
226 | 'val': val_count,
227 | 'batches': batch_id,
228 | 'processed': processed_count,
229 | 'accepted': accepted_count
230 | })
231 |
232 | except Exception as e:
233 | logger.error(f"Error during processing: {e}")
234 | finally:
235 | progress_bar.close()
236 |
237 | # Process remaining samples
238 | if current_batch:
239 | self._process_current_batch(current_batch, batch_id,
240 | train_temp_files, val_temp_files)
241 |
242 | logger.info(f"Processed {processed_count:,} total samples")
243 | logger.info(f"Accepted {accepted_count:,} samples")
244 | logger.info(f"Training samples: {train_count:,}, Validation samples: {val_count:,}")
245 |
246 | return train_temp_files, val_temp_files
247 |
248 | def _process_current_batch(self, batch, batch_id, train_temp_files, val_temp_files):
249 | """Process current batch and save to appropriate temporary files."""
250 | train_samples = [text for split, text in batch if split == 'train']
251 | val_samples = [text for split, text in batch if split == 'val']
252 |
253 | if train_samples:
254 | train_temp_file = self.tokenizer_processor.process_batch_to_temp_file(
255 | train_samples, f"train_{batch_id}"
256 | )
257 | train_temp_files.append(train_temp_file)
258 |
259 | if val_samples:
260 | val_temp_file = self.tokenizer_processor.process_batch_to_temp_file(
261 | val_samples, f"val_{batch_id}"
262 | )
263 | val_temp_files.append(val_temp_file)
264 |
265 | def finalize_dataset(self, train_temp_files, val_temp_files):
266 | """Merge temporary files into final dataset files."""
267 | # Merge training files
268 | if train_temp_files:
269 | train_tokens = self.tokenizer_processor.merge_temp_files(
270 | train_temp_files, TRAIN_FILENAME
271 | )
272 | else:
273 | train_tokens = 0
274 |
275 | # Merge validation files
276 | if val_temp_files:
277 | val_tokens = self.tokenizer_processor.merge_temp_files(
278 | val_temp_files, VALIDATION_FILENAME
279 | )
280 | else:
281 | val_tokens = 0
282 |
283 | # Clean up temporary files
284 | self.tokenizer_processor.clean_temp_files()
285 |
286 | return train_tokens, val_tokens
287 |
288 | def load_dataset_with_fallbacks():
289 | """Load dataset with fallback options."""
290 | for dataset_name, config_name in DATASET_OPTIONS:
291 | try:
292 | logger.info(f"Attempting to load {dataset_name}")
293 |
294 | if config_name:
295 | ds = load_dataset(dataset_name, name=config_name, split="train", streaming=True)
296 | else:
297 | ds = load_dataset(dataset_name, split="train", streaming=True)
298 |
299 | logger.info(f"Successfully loaded {dataset_name}")
300 | return ds, dataset_name
301 |
302 | except Exception as e:
303 | logger.warning(f"Failed to load {dataset_name}: {str(e)[:100]}")
304 | continue
305 |
306 | raise RuntimeError("Could not load any dataset. Check internet connection.")
307 |
308 | def verify_files():
309 | """Verify the created binary files."""
310 | logger.info("Verifying output files")
311 |
312 | if not (os.path.exists(TRAIN_FILENAME) and os.path.exists(VALIDATION_FILENAME)):
313 | logger.error("Binary files not found")
314 | return False
315 |
316 | try:
317 | train_data = np.fromfile(TRAIN_FILENAME, dtype=DTYPE)
318 | val_data = np.fromfile(VALIDATION_FILENAME, dtype=DTYPE)
319 |
320 | logger.info(f"Training tokens: {len(train_data):,}")
321 | logger.info(f"Validation tokens: {len(val_data):,}")
322 | logger.info(f"Training sequences: {len(train_data) // CONTEXT_LENGTH:,}")
323 | logger.info(f"Validation sequences: {len(val_data) // CONTEXT_LENGTH:,}")
324 |
325 | return True
326 |
327 | except Exception as e:
328 | logger.error(f"Verification failed: {e}")
329 | return False
330 |
331 | def main():
332 | """Main function to prepare the dataset with memory-efficient processing."""
333 | logger.info("Starting memory-efficient FineWeb-edu dataset preparation")
334 | logger.info(f"Configuration: {CUSTOM_ROW_COUNT:,} rows, context length {CONTEXT_LENGTH}")
335 | logger.info(f"Batch size: {BATCH_SIZE}, Buffer size: {BUFFER_SIZE}")
336 | logger.info(f"Quality filtering: {'ENABLED' if USE_QUALITY_FILTERING else 'DISABLED'}")
337 |
338 | # Set random seed
339 | random.seed(RANDOM_SEED)
340 |
341 | # Check if files already exist
342 | if os.path.exists(TRAIN_FILENAME) and os.path.exists(VALIDATION_FILENAME):
343 | logger.info("Dataset files already exist")
344 | verify_files()
345 | return
346 |
347 | # Load dataset
348 | dataset, dataset_name = load_dataset_with_fallbacks()
349 | logger.info(f"Using dataset: {dataset_name}")
350 |
351 | # Initialize processor
352 | processor = StreamingDataProcessor()
353 |
354 | # Process dataset
355 | train_temp_files, val_temp_files = processor.process_dataset_streaming(dataset)
356 |
357 | # Finalize dataset
358 | train_tokens, val_tokens = processor.finalize_dataset(train_temp_files, val_temp_files)
359 |
360 | # Summary
361 | logger.info("Dataset preparation completed")
362 | logger.info(f"Training tokens: {train_tokens:,}")
363 | logger.info(f"Validation tokens: {val_tokens:,}")
364 | logger.info(f"Ready for model training")
365 |
366 | # Verify output
367 | verify_files()
368 |
369 | if __name__ == "__main__":
370 | main()
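371 | 
372 | # ----------------------------------------------------------------------------
373 | # Illustrative usage sketch (assumption, not part of the preprocessing pipeline):
374 | # one possible way to memory-map the generated .bin files and pull a random
375 | # (input, target) pair for next-token training. The actual training code may
376 | # read these files differently; this helper is defined but never called here.
377 | # ----------------------------------------------------------------------------
378 | def _example_read_tokens(filename=TRAIN_FILENAME, context_length=CONTEXT_LENGTH):
379 |     """Memory-map a token file and return one random (input, target) pair."""
380 |     data = np.memmap(filename, dtype=DTYPE, mode='r')
381 |     start = random.randint(0, len(data) - context_length - 2)
382 |     x = data[start:start + context_length].astype(np.int64)          # inputs
383 |     y = data[start + 1:start + 1 + context_length].astype(np.int64)  # targets shifted by one
384 |     return x, y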
--------------------------------------------------------------------------------
/notebooks/Mixture_of_Experts_from_Scratch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/"
9 | },
10 | "id": "60-Qwx_hcDp0",
11 | "outputId": "ee0aa230-6dad-4298-b657-ed0097f94c8b"
12 | },
13 | "outputs": [
14 | {
15 | "output_type": "execute_result",
16 | "data": {
17 | "text/plain": [
18 | ""
19 | ]
20 | },
21 | "metadata": {},
22 | "execution_count": 72
23 | }
24 | ],
25 | "source": [
26 |     "# Import the necessary packages and set the seed for reproducibility.\n",
27 | "\n",
28 | "import torch\n",
29 | "import torch.nn as nn\n",
30 | "from torch.nn import functional as F\n",
31 | "torch.manual_seed(42)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "id": "4qAclO2Hoiu7"
38 | },
39 | "source": [
40 | "# Expert Module"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "id": "In-gJA9hoPk5"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "class Expert(nn.Module):\n",
52 |     "    \"\"\" Each Expert is a simple MLP: a linear layer followed by a non-linearity \"\"\"\n",
53 | "\n",
54 | " def __init__(self,n_embd):\n",
55 | " super().__init__()\n",
56 | " self.net = nn.Sequential(\n",
57 | " nn.Linear(n_embd,4*n_embd),\n",
58 | " nn.ReLU(),\n",
59 | " nn.Linear(4*n_embd,n_embd),\n",
60 | " nn.Dropout(dropout),\n",
61 | " )\n",
62 | "\n",
63 | " def forward(self,x):\n",
64 | " return self.net(x)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "colab": {
72 | "base_uri": "https://localhost:8080/"
73 | },
74 | "id": "eFwE10WJkApd",
75 | "outputId": "336ec790-96c4-4683-b481-dcd9be23167e"
76 | },
77 | "outputs": [
78 | {
79 | "output_type": "stream",
80 | "name": "stdout",
81 | "text": [
82 | "--2025-06-30 17:48:35-- https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt\n",
83 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
84 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
85 | "HTTP request sent, awaiting response... 200 OK\n",
86 | "Length: 1115394 (1.1M) [text/plain]\n",
87 | "Saving to: ‘input.txt.3’\n",
88 | "\n",
89 | "input.txt.3 100%[===================>] 1.06M --.-KB/s in 0.005s \n",
90 | "\n",
91 | "2025-06-30 17:48:35 (206 MB/s) - ‘input.txt.3’ saved [1115394/1115394]\n",
92 | "\n"
93 | ]
94 | }
95 | ],
96 | "source": [
97 | "!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {
103 | "id": "mJBLPhTwP7Zi"
104 | },
105 | "source": [
106 | "# Step 2: Implement a Router"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "colab": {
114 | "base_uri": "https://localhost:8080/"
115 | },
116 | "id": "6fjVIKAfP-0K",
117 | "outputId": "4306a632-a42d-445b-8a6f-4bde20e084a1"
118 | },
119 | "outputs": [
120 | {
121 | "output_type": "stream",
122 | "name": "stdout",
123 | "text": [
124 | "tensor([[[-0.8934, 0.1072, 0.5144, -0.2811],\n",
125 | " [-0.8497, -0.9384, -0.1564, -0.3949],\n",
126 | " [-0.1428, 0.6368, 0.3035, 0.9877],\n",
127 | " [-0.3160, -0.8139, -0.3333, -0.0203]]], grad_fn=)\n"
128 | ]
129 | }
130 | ],
131 | "source": [
132 | "# Understanding how gating works\n",
133 | "\n",
134 | "num_experts = 4\n",
135 | "top_k= 2\n",
136 | "n_embed = 32\n",
137 | "\n",
138 |     "# Example multi-head attention output; for a simple illustration, consider n_embed=32, context_length=4\n",
139 | "mh_output = torch.randn(1,4,n_embed)\n",
140 | "\n",
141 | "topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32,4)\n",
142 | "\n",
143 | "logits = topkgate_linear(mh_output)\n",
144 | "\n",
145 | "print(logits)"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {
151 | "id": "bl1JAC5rR8Oy"
152 | },
153 | "source": [
154 |     "# Step 3: Implement top-k expert selection"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/"
163 | },
164 | "id": "kqKQrLXiRMPg",
165 | "outputId": "43082836-93cf-48ad-be6d-db09513816ec"
166 | },
167 | "outputs": [
168 | {
169 | "output_type": "execute_result",
170 | "data": {
171 | "text/plain": [
172 | "(tensor([[[ 0.5144, 0.1072],\n",
173 | " [-0.1564, -0.3949],\n",
174 | " [ 0.9877, 0.6368],\n",
175 | " [-0.0203, -0.3160]]], grad_fn=),\n",
176 | " tensor([[[2, 1],\n",
177 | " [2, 3],\n",
178 | " [3, 1],\n",
179 | " [3, 0]]]))"
180 | ]
181 | },
182 | "metadata": {},
183 | "execution_count": 76
184 | }
185 | ],
186 | "source": [
187 | "top_k_logits, top_k_indices = logits.topk(top_k,dim=-1) # Get top-k experts\n",
188 | "top_k_logits, top_k_indices"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {
194 | "id": "KN9MqZftSaWy"
195 | },
196 | "source": [
197 |     "# Step 4: Use -inf and apply softmax\n"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "metadata": {
204 | "colab": {
205 | "base_uri": "https://localhost:8080/"
206 | },
207 | "id": "2sYkYIfHScvP",
208 | "outputId": "ea23bb19-380b-457b-f4f3-d26f2599d7c0"
209 | },
210 | "outputs": [
211 | {
212 | "output_type": "execute_result",
213 | "data": {
214 | "text/plain": [
215 | "tensor([[[ -inf, 0.1072, 0.5144, -inf],\n",
216 | " [ -inf, -inf, -0.1564, -0.3949],\n",
217 | " [ -inf, 0.6368, -inf, 0.9877],\n",
218 | " [-0.3160, -inf, -inf, -0.0203]]], grad_fn=)"
219 | ]
220 | },
221 | "metadata": {},
222 | "execution_count": 77
223 | }
224 | ],
225 | "source": [
226 |     "zeros = torch.full_like(logits,float('-inf')) # full_like creates a tensor of the same shape, filled with a specified value\n",
227 | "sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)\n",
228 | "sparse_logits"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {
235 | "colab": {
236 | "base_uri": "https://localhost:8080/"
237 | },
238 | "id": "WgqU9QJOVI6N",
239 | "outputId": "b0b4027d-b3e0-4253-ad18-d4c9d3920d82"
240 | },
241 | "outputs": [
242 | {
243 | "output_type": "execute_result",
244 | "data": {
245 | "text/plain": [
246 | "tensor([[[0.0000, 0.3996, 0.6004, 0.0000],\n",
247 | " [0.0000, 0.0000, 0.5594, 0.4406],\n",
248 | " [0.0000, 0.4132, 0.0000, 0.5868],\n",
249 | " [0.4266, 0.0000, 0.0000, 0.5734]]], grad_fn=)"
250 | ]
251 | },
252 | "metadata": {},
253 | "execution_count": 78
254 | }
255 | ],
256 | "source": [
257 | "gating_output = F.softmax(sparse_logits,dim=-1)\n",
258 | "gating_output"
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {
264 | "id": "pHN4YFSsVSAU"
265 | },
266 | "source": [
267 | "# Step 5: Create a class for TopKRouting"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {
274 | "id": "SYIqxvEcVPAa"
275 | },
276 | "outputs": [],
277 | "source": [
278 | "# First define the top k router module\n",
279 | "class TopkRouter(nn.Module):\n",
280 | " def __init__(self, n_embed, num_experts, top_k):\n",
281 | " super(TopkRouter, self).__init__()\n",
282 | " self.top_k = top_k\n",
283 | " self.linear =nn.Linear(n_embed, num_experts)\n",
284 | "\n",
285 |     "    def forward(self, mh_output):\n",
286 |     "        # mh_output is the output tensor from the multihead self-attention block\n",
287 | " logits = self.linear(mh_output)\n",
288 | " top_k_logits, indices = logits.topk(self.top_k, dim=-1)\n",
289 | " zeros = torch.full_like(logits, float('-inf'))\n",
290 | " sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
291 | " router_output = F.softmax(sparse_logits, dim=-1)\n",
292 | " return router_output, indices\n"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {
299 | "colab": {
300 | "base_uri": "https://localhost:8080/"
301 | },
302 | "id": "4SlPZzX0XMhK",
303 | "outputId": "52c7ed93-f2e4-4a23-90de-b9685199f827"
304 | },
305 | "outputs": [
306 | {
307 | "output_type": "execute_result",
308 | "data": {
309 | "text/plain": [
310 | "(torch.Size([2, 4, 4]),\n",
311 | " tensor([[[0.0000, 0.0000, 0.5321, 0.4679],\n",
312 | " [0.0000, 0.8659, 0.1341, 0.0000],\n",
313 | " [0.0000, 0.4096, 0.0000, 0.5904],\n",
314 | " [0.0000, 0.0000, 0.5828, 0.4172]],\n",
315 | " \n",
316 | " [[0.6210, 0.0000, 0.3790, 0.0000],\n",
317 | " [0.7005, 0.0000, 0.2995, 0.0000],\n",
318 | " [0.0000, 0.8630, 0.1370, 0.0000],\n",
319 | " [0.0000, 0.3759, 0.6241, 0.0000]]], grad_fn=),\n",
320 | " tensor([[[2, 3],\n",
321 | " [1, 2],\n",
322 | " [3, 1],\n",
323 | " [2, 3]],\n",
324 | " \n",
325 | " [[0, 2],\n",
326 | " [0, 2],\n",
327 | " [1, 2],\n",
328 | " [2, 1]]]))"
329 | ]
330 | },
331 | "metadata": {},
332 | "execution_count": 80
333 | }
334 | ],
335 | "source": [
336 | "#Testing this out:\n",
337 | "num_experts = 4\n",
338 | "top_k = 2\n",
339 | "n_embd = 32\n",
340 | "\n",
341 | "mh_output = torch.randn(2, 4, n_embd) # Example input\n",
342 | "top_k_gate = TopkRouter(n_embd, num_experts, top_k)\n",
343 | "gating_output, indices = top_k_gate(mh_output)\n",
344 | "gating_output.shape, gating_output, indices\n"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "metadata": {
350 | "id": "52WYwFrpZpSd"
351 | },
352 | "source": [
353 |     "# Step 6: Create a class for Noisy Top-k Routing"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {
359 | "id": "8yhBaSC9bUFT"
360 | },
361 | "source": [
362 | "Noisy top-k Gating is an important tool in training MoE models.\n",
363 | "\n",
364 | "Essentially, you don't want all the tokens to be sent to the same set of 'favored' experts.\n",
365 | "\n",
366 |     "You want a fine balance of exploitation and exploration. To help with load balancing, scaled Gaussian noise is added to the logits from the gating linear layer. This makes training more efficient."
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": null,
372 | "metadata": {
373 | "id": "gVd8OFafXy4o"
374 | },
375 | "outputs": [],
376 | "source": [
377 |     "# Changing the above to accommodate noisy top-k gating\n",
378 | "class NoisyTopkRouter(nn.Module):\n",
379 | " def __init__(self, n_embed, num_experts, top_k):\n",
380 | " super(NoisyTopkRouter, self).__init__()\n",
381 | " self.top_k = top_k\n",
382 | " #layer for router logits\n",
383 | " self.topkroute_linear = nn.Linear(n_embed, num_experts)\n",
384 | " self.noise_linear =nn.Linear(n_embed, num_experts)\n",
385 | "\n",
386 | "\n",
387 | " def forward(self, mh_output):\n",
388 |     "        # mh_output is the output tensor from the multihead self-attention block\n",
389 | " logits = self.topkroute_linear(mh_output)\n",
390 | "\n",
391 | " #Noise logits\n",
392 | " noise_logits = self.noise_linear(mh_output)\n",
393 | "\n",
394 | " #Adding scaled unit gaussian noise to the logits\n",
395 | " noise = torch.randn_like(logits)*F.softplus(noise_logits)\n",
396 | " noisy_logits = logits + noise\n",
397 | "\n",
398 | " top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)\n",
399 | " zeros = torch.full_like(noisy_logits, float('-inf'))\n",
400 | " sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
401 | " router_output = F.softmax(sparse_logits, dim=-1)\n",
402 | " return router_output, indices\n"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "metadata": {
409 | "colab": {
410 | "base_uri": "https://localhost:8080/"
411 | },
412 | "id": "wKlwNFjtn0_P",
413 | "outputId": "a61ff7bd-8ded-404f-bb2d-9f4ab65e5060"
414 | },
415 | "outputs": [
416 | {
417 | "output_type": "execute_result",
418 | "data": {
419 | "text/plain": [
420 | "(torch.Size([2, 4, 8]),\n",
421 | " tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6640, 0.3360],\n",
422 | " [0.4700, 0.0000, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
423 | " [0.0000, 0.0000, 0.4277, 0.0000, 0.0000, 0.0000, 0.5723, 0.0000],\n",
424 | " [0.0000, 0.0000, 0.0000, 0.0000, 0.4105, 0.0000, 0.0000, 0.5895]],\n",
425 | " \n",
426 | " [[0.0000, 0.4177, 0.0000, 0.0000, 0.5823, 0.0000, 0.0000, 0.0000],\n",
427 | " [0.0000, 0.6713, 0.0000, 0.0000, 0.3287, 0.0000, 0.0000, 0.0000],\n",
428 | " [0.0000, 0.1670, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8330],\n",
429 | " [0.4322, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5678]]],\n",
430 | " grad_fn=),\n",
431 | " tensor([[[6, 7],\n",
432 | " [2, 0],\n",
433 | " [6, 2],\n",
434 | " [7, 4]],\n",
435 | " \n",
436 | " [[4, 1],\n",
437 | " [1, 4],\n",
438 | " [7, 1],\n",
439 | " [7, 0]]]))"
440 | ]
441 | },
442 | "metadata": {},
443 | "execution_count": 82
444 | }
445 | ],
446 | "source": [
447 | "#Testing this out, again:\n",
448 | "num_experts = 8\n",
449 | "top_k = 2\n",
450 | "n_embd = 16\n",
451 | "\n",
452 | "mh_output = torch.randn(2, 4, n_embd) # Example input\n",
453 | "noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)\n",
454 | "gating_output, indices = noisy_top_k_gate(mh_output)\n",
455 | "gating_output.shape, gating_output, indices\n",
456 | "\n"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "metadata": {
463 | "id": "M-fNH0N3oFX0"
464 | },
465 | "outputs": [],
466 | "source": [
467 | "class SparseMoE(nn.Module):\n",
468 | " def __init__(self, n_embed, num_experts, top_k):\n",
469 | " super(SparseMoE, self).__init__()\n",
470 | " self.router = NoisyTopkRouter(n_embed, num_experts, top_k)\n",
471 | " self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n",
472 | " self.top_k = top_k\n",
473 | "\n",
474 | " def forward(self, x):\n",
475 | " gating_output, indices = self.router(x)\n",
476 | " final_output = torch.zeros_like(x)\n",
477 | "\n",
478 | " # Reshape inputs for batch processing\n",
479 | " flat_x = x.view(-1, x.size(-1))\n",
480 | " flat_gating_output = gating_output.view(-1, gating_output.size(-1))\n",
481 | "\n",
482 |     "        # Loop over experts; for each, process all tokens routed to it in one batch\n",
483 | " for i, expert in enumerate(self.experts):\n",
484 | " # Create a mask for the inputs where the current expert is in top-k\n",
485 | " expert_mask = (indices == i).any(dim=-1)\n",
486 | " flat_mask = expert_mask.view(-1)\n",
487 | "\n",
488 | " if flat_mask.any():\n",
489 | " expert_input = flat_x[flat_mask]\n",
490 | " expert_output = expert(expert_input)\n",
491 | "\n",
492 | " # Extract and apply gating scores\n",
493 | " gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n",
494 | " weighted_output = expert_output * gating_scores\n",
495 | "\n",
496 | " # Update final output additively by indexing and adding\n",
497 | " final_output[expert_mask] += weighted_output.squeeze(1)\n",
498 | "\n",
499 | " return final_output\n"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "metadata": {
506 | "colab": {
507 | "base_uri": "https://localhost:8080/"
508 | },
509 | "id": "o8x7da7BzGe5",
510 | "outputId": "fdb9974c-7bf2-4680-95b4-048957e7322b"
511 | },
512 | "outputs": [
513 | {
514 | "output_type": "stream",
515 | "name": "stdout",
516 | "text": [
517 | "Shape of the final output: torch.Size([1, 4, 16])\n",
518 | "tensor([[[ 0.4342, -0.4548, -0.0086, -0.0395, -0.1140, -0.0925, -0.2745,\n",
519 | " -0.0195, -0.3261, 0.1218, -0.0654, 0.1763, -0.1743, -0.2058,\n",
520 | " -0.1101, 0.1939],\n",
521 | " [-0.0275, -0.1477, -0.0534, 0.4124, 0.4952, -0.4857, -0.2878,\n",
522 | " -0.1663, 0.3566, 0.3007, 0.0182, 0.4281, 0.2393, 0.4560,\n",
523 | " -0.1753, 0.0000],\n",
524 | " [-0.0070, 0.0822, 0.0469, 0.2413, -0.3832, -0.0366, -0.1590,\n",
525 | " -0.1679, -0.2089, 0.0903, 0.2140, 0.2162, -0.0403, -0.0554,\n",
526 | " -0.1221, 0.2234],\n",
527 | " [-0.1403, -0.1466, 0.3349, -0.1171, 0.8054, 0.3437, -0.2110,\n",
528 | " 0.2405, -0.0367, 0.3583, -0.0914, -0.3995, -0.4149, -0.0041,\n",
529 | " 0.0000, -0.0971]]], grad_fn=)\n"
530 | ]
531 | }
532 | ],
533 | "source": [
534 | "import torch\n",
535 | "import torch.nn as nn\n",
536 | "\n",
537 | "#Let's test this out\n",
538 | "num_experts = 8\n",
539 | "top_k = 2\n",
540 | "n_embd = 16\n",
541 | "dropout=0.1\n",
542 | "\n",
543 | "mh_output = torch.randn(1, 4, n_embd) # Example multi-head attention output\n",
544 | "sparse_moe = SparseMoE(n_embd, num_experts, top_k)\n",
545 | "final_output = sparse_moe(mh_output)\n",
546 | "print(\"Shape of the final output:\", final_output.shape)\n",
547 | "print(final_output)\n"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {
553 | "id": "9RBnw49ezXYG"
554 | },
555 | "source": [
556 |     "# Step 8: Putting together all the building blocks of MoE"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": null,
562 | "metadata": {
563 | "id": "V6J9Qsa7zGyx"
564 | },
565 | "outputs": [],
566 | "source": [
567 |     "# Create a self-attention + mixture-of-experts block that may be repeated several times\n",
568 | "class Block(nn.Module):\n",
569 | " \"\"\" Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) \"\"\"\n",
570 | "\n",
571 | " def __init__(self, n_embed, n_head, num_experts, top_k):\n",
572 | " # n_embed: embedding dimension, n_head: the number of heads we'd like\n",
573 | " super().__init__()\n",
574 | " head_size = n_embed // n_head\n",
575 | " self.sa = MultiHeadAttention(n_head, head_size)\n",
576 | " self.smoe = SparseMoE(n_embed, num_experts, top_k)\n",
577 | " self.ln1 = nn.LayerNorm(n_embed)\n",
578 | " self.ln2 = nn.LayerNorm(n_embed)\n",
579 | "\n",
580 | " def forward(self, x):\n",
581 | " x = x + self.sa(self.ln1(x))\n",
582 | " x = x + self.smoe(self.ln2(x))\n",
583 | " return x\n"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {
590 | "id": "ta_Hf1SQ0qwe"
591 | },
592 | "outputs": [],
593 | "source": [
594 | "class SparseMoELanguageModel(nn.Module):\n",
595 | "\n",
596 | " def __init__(self):\n",
597 | " super().__init__()\n",
598 | " # each token directly reads off the logits for the next token from a lookup table\n",
599 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n",
600 | " self.position_embedding_table = nn.Embedding(block_size, n_embed)\n",
601 | " self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n",
602 | " self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n",
603 | " self.lm_head = nn.Linear(n_embed, vocab_size)\n",
604 | "\n",
605 | " def forward(self, idx, targets=None):\n",
606 | " B, T = idx.shape\n",
607 | "\n",
608 | " # idx and targets are both (B,T) tensor of integers\n",
609 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
610 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
611 | " x = tok_emb + pos_emb # (B,T,C)\n",
612 | " x = self.blocks(x) # (B,T,C)\n",
613 | " x = self.ln_f(x) # (B,T,C)\n",
614 | " logits = self.lm_head(x) # (B,T,vocab_size)\n",
615 | "\n",
616 | " if targets is None:\n",
617 | " loss = None\n",
618 | " else:\n",
619 | " B, T, C = logits.shape\n",
620 | " logits = logits.view(B*T, C)\n",
621 | " targets = targets.view(B*T)\n",
622 | " loss = F.cross_entropy(logits, targets)\n",
623 | "\n",
624 | " return logits, loss\n",
625 | "\n",
626 | " def generate(self, idx, max_new_tokens):\n",
627 | " # idx is (B, T) array of indices in the current context\n",
628 | " for _ in range(max_new_tokens):\n",
629 | " # crop idx to the last block_size tokens\n",
630 | " idx_cond = idx[:, -block_size:]\n",
631 | " # get the predictions\n",
632 | " logits, loss = self(idx_cond)\n",
633 | " # focus only on the last time step\n",
634 | " logits = logits[:, -1, :] # becomes (B, C)\n",
635 | " # apply softmax to get probabilities\n",
636 | " probs = F.softmax(logits, dim=-1) # (B, C)\n",
637 | " # sample from the distribution\n",
638 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
639 | " # append sampled index to the running sequence\n",
640 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
641 | " return idx\n"
642 | ]
643 | },
644 | {
645 | "cell_type": "markdown",
646 | "metadata": {
647 | "id": "9skcec_600xP"
648 | },
649 | "source": [
650 | "# Step 9: Code the entire transformer block: (Multi-Head Attention)"
651 | ]
652 | },
653 | {
654 | "cell_type": "code",
655 | "execution_count": null,
656 | "metadata": {
657 | "id": "pt0W8MNj0rFd"
658 | },
659 | "outputs": [],
660 | "source": [
661 | "class Head(nn.Module):\n",
662 | " \"\"\" one head of self-attention \"\"\"\n",
663 | "\n",
664 | " def __init__(self,head_size):\n",
665 | " super().__init__()\n",
666 | " self.key = nn.Linear(n_embed,head_size,bias=False)\n",
667 | " self.query = nn.Linear(n_embed,head_size,bias=False)\n",
668 | " self.value = nn.Linear(n_embed,head_size,bias=False)\n",
669 | "\n",
670 | " self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))\n",
671 | " self.dropout = nn.Dropout(dropout)\n",
672 | "\n",
673 | "\n",
674 | "\n",
675 | " def forward(self,x):\n",
676 | " B,T,C = x.shape\n",
677 | " k = self.key(x)\n",
678 | " q = self.query(x)\n",
679 |     "        # Compute attention scores (\"affinities\")\n",
680 | " wei = q @ k.transpose(-2,-1)*C**-0.5 # (B,T,C) @ (B,C,T) -> (B,T,T)\n",
681 | " wei = wei.masked_fill(self.tril[:T,:T]== 0,float('-inf')) # (B,T,T)\n",
682 | " wei = F.softmax(wei,dim=-1)\n",
683 | " wei = self.dropout(wei)\n",
684 | "\n",
685 |     "        # perform the weighted aggregation of the values\n",
686 | " v = self.value(x) # (B,T,C)\n",
687 | " out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)\n",
688 | " return out\n",
689 | "\n",
690 | "\n",
691 | "# Multi-Headed Self Attention\n",
692 | "\n",
693 | "class MultiHeadAttention(nn.Module):\n",
694 | " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
695 | "\n",
696 | "\n",
697 | " def __init__(self,num_heads,head_size):\n",
698 | " super().__init__()\n",
699 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
700 | " self.proj = nn.Linear(n_embed,n_embed)\n",
701 | " self.dropout = nn.Dropout(dropout)\n",
702 | "\n",
703 | "\n",
704 | " def forward(self,x):\n",
705 | " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
706 | " out = self.dropout(self.proj(out))\n",
707 | " return out\n",
708 | "\n",
709 | "\n"
710 | ]
711 | },
712 | {
713 | "cell_type": "markdown",
714 | "metadata": {
715 | "id": "EYZ9CqedqXjy"
716 | },
717 | "source": [
718 |     "# Step 10: Code the entire transformer block: (Assemble all layers)"
719 | ]
720 | },
721 | {
722 | "cell_type": "markdown",
723 | "metadata": {
724 | "id": "l-qt80KRq7SO"
725 | },
726 | "source": [
727 | "Start by defining a Self-Attention combined with Mixture-of-Experts (MoE) block, designed to be modular and reusable across multiple layers of the model.\n",
728 | "(For better clarity, key architectural parameters are copied and reused.)"
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": null,
734 | "metadata": {
735 | "id": "2bKBebaxqWC9"
736 | },
737 | "outputs": [],
738 | "source": []
739 | },
740 | {
741 | "cell_type": "markdown",
742 | "metadata": {
743 | "id": "TOngSSo8wgkN"
744 | },
745 | "source": [
746 | "# Step 11: Define entire language model architecture\n"
747 | ]
748 | },
749 | {
750 | "cell_type": "markdown",
751 | "metadata": {
752 | "id": "nVwVQwFVwnQS"
753 | },
754 | "source": [
755 |     "Finally, putting it all together to create a sparse mixture-of-experts language model."
756 | ]
757 | },
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "metadata": {
762 | "id": "gQKFUAWFrhl6"
763 | },
764 | "outputs": [],
765 | "source": [
766 | "class SparseMoELanguageModel(nn.Module):\n",
767 | "\n",
768 | " def __init__(self):\n",
769 | " super().__init__()\n",
770 | " # each token directly reads off the logits for the next token from a lookup table\n",
771 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n",
772 | " self.position_embedding_table = nn.Embedding(block_size, n_embed)\n",
773 | " self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n",
774 | " self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n",
775 | " self.lm_head = nn.Linear(n_embed, vocab_size)\n",
776 | "\n",
777 | " def forward(self, idx, targets=None):\n",
778 | " B, T = idx.shape\n",
779 | "\n",
780 | " # idx and targets are both (B,T) tensor of integers\n",
781 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
782 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
783 | " x = tok_emb + pos_emb # (B,T,C)\n",
784 | " x = self.blocks(x) # (B,T,C)\n",
785 | " x = self.ln_f(x) # (B,T,C)\n",
786 | " logits = self.lm_head(x) # (B,T,vocab_size)\n",
787 | "\n",
788 | " if targets is None:\n",
789 | " loss = None\n",
790 | " else:\n",
791 | " B, T, C = logits.shape\n",
792 | " logits = logits.view(B*T, C)\n",
793 | " targets = targets.view(B*T)\n",
794 | " loss = F.cross_entropy(logits, targets)\n",
795 | "\n",
796 | " return logits, loss\n",
797 | "\n",
798 | " def generate(self, idx, max_new_tokens):\n",
799 | " # idx is (B, T) array of indices in the current context\n",
800 | " for _ in range(max_new_tokens):\n",
801 | " # crop idx to the last block_size tokens\n",
802 | " idx_cond = idx[:, -block_size:]\n",
803 | " # get the predictions\n",
804 | " logits, loss = self(idx_cond)\n",
805 | " # focus only on the last time step\n",
806 | " logits = logits[:, -1, :] # becomes (B, C)\n",
807 | " # apply softmax to get probabilities\n",
808 | " probs = F.softmax(logits, dim=-1) # (B, C)\n",
809 | " # sample from the distribution\n",
810 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
811 | " # append sampled index to the running sequence\n",
812 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
813 | " return idx\n"
814 | ]
815 | },
816 | {
817 | "cell_type": "markdown",
818 | "metadata": {
819 | "id": "d4ljt-EB0vi8"
820 | },
821 | "source": [
822 | "# Step 12: Create training and testing data"
823 | ]
824 | },
825 | {
826 | "cell_type": "code",
827 | "execution_count": null,
828 | "metadata": {
829 | "id": "Z7qpudyh0nhD"
830 | },
831 | "outputs": [],
832 | "source": [
833 | "torch.manual_seed(42)\n",
834 | "\n",
835 | "with open('input.txt','r',encoding='utf-8') as f:\n",
836 | " text = f.read()\n",
837 | "\n",
838 | "# here are all the unique characters that occur in this text\n",
839 | "\n",
840 | "chars = sorted(list(set(text)))\n",
841 | "vocab_size = len(chars)\n",
842 | "# create a mapping from characters to integers\n",
843 | "\n",
844 | "stoi = {ch:i for i,ch in enumerate(chars)}\n",
845 | "itos = {i:ch for i,ch in enumerate(chars)}\n",
846 |     "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
847 |     "decode = lambda l: ''.join(itos[i] for i in l) # decoder: take a list of integers, output a string\n",
848 | "\n",
849 | "\n",
850 | "# Train and test splits\n",
851 | "\n",
852 | "data = torch.tensor(encode(text),dtype=torch.long)\n",
853 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
854 | "train_data = data[:n]\n",
855 | "val_data = data[n:]\n",
856 | "\n",
857 | "\n",
858 | "# data loading\n",
859 | "\n",
860 | "def get_batch(split):\n",
861 | " # generate a small batch of data of inputs x and target y\n",
862 | " data = train_data if split == 'train' else val_data\n",
863 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
864 | "\n",
865 | " # Input: tokens at positions [i, i+1, ..., i+block_size-1]\n",
866 | " x = torch.stack([data[i:i+block_size] for i in ix])\n",
867 | "\n",
868 | " # Target: tokens at positions [i+1, i+2, ..., i+block_size]\n",
869 | " # This is the \"next token\" for each input token\n",
870 | " y = torch.stack([data[i+1:i+1+block_size] for i in ix])\n",
871 | "\n",
872 | " x, y = x.to(device), y.to(device)\n",
873 | " return x, y\n",
874 | "\n",
875 | "\n",
876 | "\n"
877 | ]
878 | },
879 | {
880 | "cell_type": "markdown",
881 | "metadata": {
882 | "id": "NfX2k052LtFQ"
883 | },
884 | "source": [
885 |     "# Step 13: Define the loss estimation helper"
886 | ]
887 | },
888 | {
889 | "cell_type": "code",
890 | "execution_count": null,
891 | "metadata": {
892 | "id": "Fp7_Ea610brK"
893 | },
894 | "outputs": [],
895 | "source": [
896 | "@torch.no_grad()\n",
897 | "def estimate_loss():\n",
898 | " out = {}\n",
899 | " model.eval()\n",
900 | " for split in ['train','val']:\n",
901 | " losses = torch.zeros(eval_iters)\n",
902 | " for k in range(eval_iters):\n",
903 | " X,Y = get_batch(split)\n",
904 | " logits, loss = model(X,Y)\n",
905 | " losses[k] = loss.item()\n",
906 | " out[split] = losses.mean()\n",
907 | " model.train()\n",
908 | " return out\n"
909 | ]
910 | },
911 | {
912 | "cell_type": "markdown",
913 | "metadata": {
914 | "id": "0o514O42MKOn"
915 | },
916 | "source": [
917 |     "# Step 14: Define training loop parameters and other hyperparameters"
918 | ]
919 | },
920 | {
921 | "cell_type": "code",
922 | "execution_count": null,
923 | "metadata": {
924 | "id": "jHIgfItVMJks"
925 | },
926 | "outputs": [],
927 | "source": [
928 |     "# First define hyperparameters and boilerplate code. Imports and data preparation code are repeated for convenience.\n",
929 | "\n",
930 | "import torch\n",
931 | "import torch.nn as nn\n",
932 | "from torch.nn import functional as F\n",
933 | "from torch.nn import init\n",
934 | "\n",
935 | "# hyperparameters\n",
936 | "batch_size = 16 # how many independent sequences will we process in parallel?\n",
937 |     "block_size = 32 # what is the maximum context length for prediction?\n",
938 | "max_iters = 1000\n",
939 | "eval_interval = 100\n",
940 | "learning_rate = 1e-3\n",
941 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
942 | "eval_iters = 400\n",
943 | "head_size = 16\n",
944 | "n_embed = 128\n",
945 | "n_head = 8\n",
946 | "n_layer = 8\n",
947 | "dropout = 0.1\n",
948 | "num_experts = 8\n",
949 | "top_k = 2"
950 | ]
951 | },
952 | {
953 | "cell_type": "markdown",
954 | "metadata": {
955 | "id": "IA266vdgNjci"
956 | },
957 | "source": [
958 | "# Step 15: Initialize the entire model"
959 | ]
960 | },
961 | {
962 | "cell_type": "code",
963 | "execution_count": null,
964 | "metadata": {
965 | "id": "4nDIBLGTMJcq"
966 | },
967 | "outputs": [],
968 | "source": [
969 | "def kaiming_init__weights(m):\n",
970 |     "    if isinstance(m, nn.Linear):\n",
971 | " init.kaiming_normal_(m.weight)"
972 | ]
973 | },
974 | {
975 | "cell_type": "code",
976 | "execution_count": null,
977 | "metadata": {
978 | "colab": {
979 | "base_uri": "https://localhost:8080/"
980 | },
981 | "id": "CakR1eE2NyiA",
982 | "outputId": "da8d6d0a-8c61-40b4-d6ab-9ed0c5635cb2"
983 | },
984 | "outputs": [
985 | {
986 | "output_type": "execute_result",
987 | "data": {
988 | "text/plain": [
989 | "SparseMoELanguageModel(\n",
990 | " (token_embedding_table): Embedding(65, 128)\n",
991 | " (position_embedding_table): Embedding(32, 128)\n",
992 | " (blocks): Sequential(\n",
993 | " (0): Block(\n",
994 | " (sa): MultiHeadAttention(\n",
995 | " (heads): ModuleList(\n",
996 | " (0-7): 8 x Head(\n",
997 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
998 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
999 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1000 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1001 | " )\n",
1002 | " )\n",
1003 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1004 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1005 | " )\n",
1006 | " (smoe): SparseMoE(\n",
1007 | " (router): NoisyTopkRouter(\n",
1008 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1009 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1010 | " )\n",
1011 | " (experts): ModuleList(\n",
1012 | " (0-7): 8 x Expert(\n",
1013 | " (net): Sequential(\n",
1014 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1015 | " (1): ReLU()\n",
1016 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1017 | " (3): Dropout(p=0.1, inplace=False)\n",
1018 | " )\n",
1019 | " )\n",
1020 | " )\n",
1021 | " )\n",
1022 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1023 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1024 | " )\n",
1025 | " (1): Block(\n",
1026 | " (sa): MultiHeadAttention(\n",
1027 | " (heads): ModuleList(\n",
1028 | " (0-7): 8 x Head(\n",
1029 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1030 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1031 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1032 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1033 | " )\n",
1034 | " )\n",
1035 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1036 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1037 | " )\n",
1038 | " (smoe): SparseMoE(\n",
1039 | " (router): NoisyTopkRouter(\n",
1040 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1041 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1042 | " )\n",
1043 | " (experts): ModuleList(\n",
1044 | " (0-7): 8 x Expert(\n",
1045 | " (net): Sequential(\n",
1046 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1047 | " (1): ReLU()\n",
1048 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1049 | " (3): Dropout(p=0.1, inplace=False)\n",
1050 | " )\n",
1051 | " )\n",
1052 | " )\n",
1053 | " )\n",
1054 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1055 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1056 | " )\n",
1057 | " (2): Block(\n",
1058 | " (sa): MultiHeadAttention(\n",
1059 | " (heads): ModuleList(\n",
1060 | " (0-7): 8 x Head(\n",
1061 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1062 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1063 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1064 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1065 | " )\n",
1066 | " )\n",
1067 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1068 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1069 | " )\n",
1070 | " (smoe): SparseMoE(\n",
1071 | " (router): NoisyTopkRouter(\n",
1072 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1073 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1074 | " )\n",
1075 | " (experts): ModuleList(\n",
1076 | " (0-7): 8 x Expert(\n",
1077 | " (net): Sequential(\n",
1078 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1079 | " (1): ReLU()\n",
1080 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1081 | " (3): Dropout(p=0.1, inplace=False)\n",
1082 | " )\n",
1083 | " )\n",
1084 | " )\n",
1085 | " )\n",
1086 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1087 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1088 | " )\n",
1089 | " (3): Block(\n",
1090 | " (sa): MultiHeadAttention(\n",
1091 | " (heads): ModuleList(\n",
1092 | " (0-7): 8 x Head(\n",
1093 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1094 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1095 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1096 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1097 | " )\n",
1098 | " )\n",
1099 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1100 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1101 | " )\n",
1102 | " (smoe): SparseMoE(\n",
1103 | " (router): NoisyTopkRouter(\n",
1104 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1105 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1106 | " )\n",
1107 | " (experts): ModuleList(\n",
1108 | " (0-7): 8 x Expert(\n",
1109 | " (net): Sequential(\n",
1110 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1111 | " (1): ReLU()\n",
1112 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1113 | " (3): Dropout(p=0.1, inplace=False)\n",
1114 | " )\n",
1115 | " )\n",
1116 | " )\n",
1117 | " )\n",
1118 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1119 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1120 | " )\n",
1121 | " (4): Block(\n",
1122 | " (sa): MultiHeadAttention(\n",
1123 | " (heads): ModuleList(\n",
1124 | " (0-7): 8 x Head(\n",
1125 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1126 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1127 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1128 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1129 | " )\n",
1130 | " )\n",
1131 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1132 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1133 | " )\n",
1134 | " (smoe): SparseMoE(\n",
1135 | " (router): NoisyTopkRouter(\n",
1136 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1137 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1138 | " )\n",
1139 | " (experts): ModuleList(\n",
1140 | " (0-7): 8 x Expert(\n",
1141 | " (net): Sequential(\n",
1142 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1143 | " (1): ReLU()\n",
1144 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1145 | " (3): Dropout(p=0.1, inplace=False)\n",
1146 | " )\n",
1147 | " )\n",
1148 | " )\n",
1149 | " )\n",
1150 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1151 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1152 | " )\n",
1153 | " (5): Block(\n",
1154 | " (sa): MultiHeadAttention(\n",
1155 | " (heads): ModuleList(\n",
1156 | " (0-7): 8 x Head(\n",
1157 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1158 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1159 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1160 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1161 | " )\n",
1162 | " )\n",
1163 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1164 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1165 | " )\n",
1166 | " (smoe): SparseMoE(\n",
1167 | " (router): NoisyTopkRouter(\n",
1168 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1169 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1170 | " )\n",
1171 | " (experts): ModuleList(\n",
1172 | " (0-7): 8 x Expert(\n",
1173 | " (net): Sequential(\n",
1174 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1175 | " (1): ReLU()\n",
1176 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1177 | " (3): Dropout(p=0.1, inplace=False)\n",
1178 | " )\n",
1179 | " )\n",
1180 | " )\n",
1181 | " )\n",
1182 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1183 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1184 | " )\n",
1185 | " (6): Block(\n",
1186 | " (sa): MultiHeadAttention(\n",
1187 | " (heads): ModuleList(\n",
1188 | " (0-7): 8 x Head(\n",
1189 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1190 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1191 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1192 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1193 | " )\n",
1194 | " )\n",
1195 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1196 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1197 | " )\n",
1198 | " (smoe): SparseMoE(\n",
1199 | " (router): NoisyTopkRouter(\n",
1200 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1201 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1202 | " )\n",
1203 | " (experts): ModuleList(\n",
1204 | " (0-7): 8 x Expert(\n",
1205 | " (net): Sequential(\n",
1206 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1207 | " (1): ReLU()\n",
1208 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1209 | " (3): Dropout(p=0.1, inplace=False)\n",
1210 | " )\n",
1211 | " )\n",
1212 | " )\n",
1213 | " )\n",
1214 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1215 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1216 | " )\n",
1217 | " (7): Block(\n",
1218 | " (sa): MultiHeadAttention(\n",
1219 | " (heads): ModuleList(\n",
1220 | " (0-7): 8 x Head(\n",
1221 | " (key): Linear(in_features=128, out_features=16, bias=False)\n",
1222 | " (query): Linear(in_features=128, out_features=16, bias=False)\n",
1223 | " (value): Linear(in_features=128, out_features=16, bias=False)\n",
1224 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1225 | " )\n",
1226 | " )\n",
1227 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
1228 | " (dropout): Dropout(p=0.1, inplace=False)\n",
1229 | " )\n",
1230 | " (smoe): SparseMoE(\n",
1231 | " (router): NoisyTopkRouter(\n",
1232 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1233 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
1234 | " )\n",
1235 | " (experts): ModuleList(\n",
1236 | " (0-7): 8 x Expert(\n",
1237 | " (net): Sequential(\n",
1238 | " (0): Linear(in_features=128, out_features=512, bias=True)\n",
1239 | " (1): ReLU()\n",
1240 | " (2): Linear(in_features=512, out_features=128, bias=True)\n",
1241 | " (3): Dropout(p=0.1, inplace=False)\n",
1242 | " )\n",
1243 | " )\n",
1244 | " )\n",
1245 | " )\n",
1246 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1247 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1248 | " )\n",
1249 | " )\n",
1250 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
1251 | " (lm_head): Linear(in_features=128, out_features=65, bias=True)\n",
1252 | ")"
1253 | ]
1254 | },
1255 | "metadata": {},
1256 | "execution_count": 93
1257 | }
1258 | ],
1259 | "source": [
1260 | "model = SparseMoELanguageModel()\n",
1261 | "model.apply(kaiming_init__weights)"
1262 | ]
1263 | },
1264 | {
1265 | "cell_type": "markdown",
1266 | "metadata": {
1267 | "id": "vr6Bz9-CODMn"
1268 | },
1269 | "source": [
1270 |     "# Step 16: Run the pre-training loop"
1271 | ]
1272 | },
1273 | {
1274 | "cell_type": "code",
1275 | "execution_count": null,
1276 | "metadata": {
1277 | "colab": {
1278 | "base_uri": "https://localhost:8080/"
1279 | },
1280 | "id": "QY7j6jf6OFrZ",
1281 | "outputId": "077b4ee9-db74-4ea9-8ac9-e65001e391a0"
1282 | },
1283 | "outputs": [
1284 | {
1285 | "output_type": "stream",
1286 | "name": "stdout",
1287 | "text": [
1288 | "8.996545 M parameters\n",
1289 | "step 0: train loss 5.1127, val loss 5.1287\n",
1290 | "step 100: train loss 2.6564, val loss 2.6589\n",
1291 | "step 200: train loss 2.4973, val loss 2.4938\n",
1292 | "step 300: train loss 2.3931, val loss 2.3948\n",
1293 | "step 400: train loss 2.3098, val loss 2.3223\n",
1294 | "step 500: train loss 2.2309, val loss 2.2536\n",
1295 | "step 600: train loss 2.1587, val loss 2.2019\n",
1296 | "step 700: train loss 2.1083, val loss 2.1624\n",
1297 | "step 800: train loss 2.0612, val loss 2.1126\n",
1298 | "step 900: train loss 2.0206, val loss 2.0880\n",
1299 | "step 999: train loss 1.9632, val loss 2.0509\n",
1300 | "=== VOCABULARY DEBUG ===\n",
1301 | "Actual vocabulary size from text: 65\n",
1302 | "Current vocab_size variable: 65\n",
1303 | "Model's output layer expects: 65\n",
1304 | "Vocabulary sizes match\n",
1305 | "=== END DEBUG ===\n"
1306 | ]
1307 | }
1308 | ],
1309 | "source": [
1310 | "m = model.to(device)\n",
1311 | "# Print the number of parameters in the model\n",
1312 | "print(sum(p.numel() for p in m.parameters())/1e6,'M parameters')\n",
1313 | "\n",
1314 | "# create a PyTorch optimizer\n",
1315 | "optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)\n",
1316 | "\n",
1317 | "for iter in range(max_iters):\n",
1318 | " # every once in a while evaluate the loss on train and val sets\n",
1319 | " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
1320 | " losses = estimate_loss()\n",
1321 | " print(\"step {}: train loss {:.4f}, val loss {:.4f}\".format(iter, losses['train'], losses['val']))\n",
1322 | "\n",
1323 |     "    # sample a batch of data\n",
1324 | " xb, yb = get_batch('train')\n",
1325 | "\n",
1326 | " # evaluate the loss\n",
1327 | " logits, loss = model(xb,yb)\n",
1328 | " optimizer.zero_grad(set_to_none=True)\n",
1329 | " loss.backward()\n",
1330 | " optimizer.step()\n",
1331 | "\n",
1332 | "# Debug the vocabulary mismatch\n",
1333 | "print(\"=== VOCABULARY DEBUG ===\")\n",
1334 | "\n",
1335 | "# Check actual vocabulary\n",
1336 | "with open('input.txt', 'r', encoding='utf-8') as f:\n",
1337 | " text = f.read()\n",
1338 | "\n",
1339 | "chars = sorted(list(set(text)))\n",
1340 | "actual_vocab_size = len(chars)\n",
1341 | "\n",
1342 | "print(f\"Actual vocabulary size from text: {actual_vocab_size}\")\n",
1343 | "print(f\"Current vocab_size variable: {vocab_size}\")\n",
1344 | "\n",
1345 | "# Check model's expected vocabulary size\n",
1346 | "model_vocab_size = model.lm_head.out_features\n",
1347 | "print(f\"Model's output layer expects: {model_vocab_size}\")\n",
1348 | "\n",
1349 | "# Show the mismatch\n",
1350 | "if actual_vocab_size != model_vocab_size:\n",
1351 | " print(f\"MISMATCH: Model expects {model_vocab_size} but data has {actual_vocab_size}\")\n",
1352 | " print(\"Solution: Recreate the model with the correct vocab_size\")\n",
1353 | "else:\n",
1354 | " print(\"Vocabulary sizes match\")\n",
1355 | "\n",
1356 | "print(\"=== END DEBUG ===\")"
1357 | ]
1358 | },
1359 | {
1360 | "cell_type": "markdown",
1361 | "metadata": {
1362 | "id": "vk9Vw9cOaYm2"
1363 | },
1364 | "source": [
1365 | "# Step 17: Inference"
1366 | ]
1367 | },
1368 | {
1369 | "cell_type": "code",
1370 | "source": [
1371 | "# generate from the model\n",
1372 | "\n",
1373 | "context = torch.zeros((1,1), dtype=torch.long, device=device)\n",
1374 |     "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
1375 | ],
1376 | "metadata": {
1377 | "id": "Vu0pMx1ICON1"
1378 | },
1379 | "execution_count": null,
1380 | "outputs": []
1381 | }
1382 | ],
1383 | "metadata": {
1384 | "colab": {
1385 | "provenance": [],
1386 | "machine_shape": "hm"
1387 | },
1388 | "kernelspec": {
1389 | "display_name": "Python 3",
1390 | "name": "python3"
1391 | },
1392 | "language_info": {
1393 | "name": "python"
1394 | }
1395 | },
1396 | "nbformat": 4,
1397 | "nbformat_minor": 0
1398 | }
--------------------------------------------------------------------------------