├── inference ├── __init__.py └── generator.py ├── models ├── __init__.py ├── config.py ├── mtp.py ├── layers.py ├── moe.py ├── attention.py └── model.py ├── requirements.txt ├── training ├── __init__.py ├── data_loader.py └── trainer.py ├── .gitignore ├── main.py ├── run_inference.py ├── prepare_data_tiny_stories.py ├── README.md ├── notebooks ├── Multi_Token_Prediction_from_Scratch.ipynb └── Mixture_of_Experts_from_Scratch.ipynb └── prepare_data_fineweb.py /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import generate_text 2 | 3 | __all__ = ['generate_text'] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import DeepSeekConfig 2 | from .model import DeepSeekV3 3 | 4 | __all__ = ['DeepSeekConfig', 'DeepSeekV3'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | numpy>=1.21.0 3 | tqdm>=4.64.0 4 | tiktoken>=0.5.0 5 | datasets>=2.19.0 6 | transformers>=4.30.0 7 | wandb 8 | python-dotenv -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import get_batch, estimate_loss 2 | from .trainer import train_model 3 | 4 | __all__ = ['get_batch', 'estimate_loss', 'train_model'] -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class DeepSeekConfig: 5 | # Model architecture 6 | vocab_size: int = 50257 7 | block_size: int = 128 8 | n_layer: int = 6 9 | n_embd: int = 384 10 | n_head: int = 6 11 | 12 | # MLA configuration 13 | kv_lora_rank: int = 128 14 | q_lora_rank: int = 192 15 | rope_dim: int = 32 16 | 17 | # MoE configuration 18 | n_experts: int = 4 19 | n_experts_per_token: int = 2 20 | expert_intermediate_size: int = 512 21 | shared_expert_intermediate_size: int = 768 22 | use_shared_expert: bool = True 23 | 24 | # MTP configuration 25 | mtp_num_heads: int = 2 26 | 27 | # Training parameters 28 | dropout: float = 0.1 29 | bias: bool = True 30 | aux_loss_weight: float = 0.0 31 | mtp_loss_weight: float = 0.3 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data files (large binary files) 2 | train.bin 3 | validation.bin 4 | *.bin 5 | 6 | input_output_pairs.csv 7 | training_samples.csv 8 | 9 | myenv/ 10 | 11 | 12 | # Model weights (large files) 13 | best_deepseek_v3.pt 14 | *.pt 15 | *.pth 16 | 17 | # Weights & Biases 18 | wandb/ 19 | *.wandb 20 | wandb-metadata.json 21 | wandb-summary.json 22 | 23 | # Common ML artifacts 24 | 25 | checkpoints/ 26 | logs/ 27 | outputs/ 28 | 29 | 30 | 31 | # Python cache 32 | __pycache__/ 33 | *.py[cod] 34 | *$py.class 35 | *.so 36 | 37 | # Environment 38 | .env 39 | .venv 40 | env/ 41 | venv/ 42 | ENV/ 43 | env.bak/ 44 | venv.bak/ 45 | 46 | # IDE 47 | .vscode/ 48 | .idea/ 49 | *.swp 50 | *.swo 51 | *~ 52 | 53 | # OS 54 | .DS_Store 55 | .DS_Store? 
56 | ._* 57 | .Spotlight-V100 58 | .Trashes 59 | ehthumbs.db 60 | Thumbs.db 61 | 62 | # Jupyter Notebook 63 | .ipynb_checkpoints 64 | 65 | # PyTorch 66 | *.pth.tar 67 | 68 | # Logs 69 | *.log 70 | logs/ 71 | 72 | # Temporary files 73 | tmp/ 74 | temp/ -------------------------------------------------------------------------------- /training/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_batch(split, config, batch_size, device_type, device): 5 | if split == 'train': 6 | data = np.memmap('train.bin', dtype=np.uint16, mode='r') 7 | else: 8 | data = np.memmap('validation.bin', dtype=np.uint16, mode='r') 9 | 10 | ix = torch.randint(len(data) - config.block_size, (batch_size,)) 11 | x = torch.stack([torch.from_numpy((data[i:i+config.block_size]).astype(np.int64)) for i in ix]) 12 | y = torch.stack([torch.from_numpy((data[i+1:i+1+config.block_size]).astype(np.int64)) for i in ix]) 13 | 14 | if device_type == 'cuda': 15 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 16 | else: 17 | x, y = x.to(device), y.to(device) 18 | 19 | return x, y 20 | 21 | def estimate_loss(model, config, eval_iters, batch_size, device_type, device, ctx): 22 | out = {} 23 | model.eval() 24 | with torch.inference_mode(): 25 | for split in ['train', 'val']: 26 | losses = torch.zeros(eval_iters) 27 | for k in range(eval_iters): 28 | X, Y = get_batch(split, config, batch_size, device_type, device) 29 | with ctx: 30 | _, loss, _, _ = model(X, Y) 31 | losses[k] = loss.item() 32 | out[split] = losses.mean() 33 | model.train() 34 | return out -------------------------------------------------------------------------------- /models/mtp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .layers import RMSNorm 4 | from .attention import MultiHeadLatentAttention 5 | from .moe import MoELayer 6 | 7 | class MultiTokenPredictionHead(nn.Module): 8 | def __init__(self, config, depth): 9 | super().__init__() 10 | self.depth = depth 11 | self.n_embd = config.n_embd 12 | 13 | # Combine previous hidden state with future token embedding 14 | self.combine_proj = nn.Linear(2 * config.n_embd, config.n_embd, bias=config.bias) 15 | 16 | # Normalization 17 | self.norm1 = RMSNorm(config.n_embd) 18 | self.norm2 = RMSNorm(config.n_embd) 19 | 20 | # Transformer components 21 | self.attn = MultiHeadLatentAttention(config) 22 | self.mlp = MoELayer(config) 23 | self.attn_norm = RMSNorm(config.n_embd) 24 | self.mlp_norm = RMSNorm(config.n_embd) 25 | 26 | def forward(self, prev_hidden, future_token_embed): 27 | # Normalize inputs 28 | prev_norm = self.norm1(prev_hidden) 29 | future_norm = self.norm2(future_token_embed) 30 | 31 | # Combine representations 32 | combined = torch.cat([prev_norm, future_norm], dim=-1) 33 | hidden = self.combine_proj(combined) 34 | 35 | # Process through transformer components 36 | hidden = hidden + self.attn(self.attn_norm(hidden)) 37 | hidden = hidden + self.mlp(self.mlp_norm(hidden)) 38 | 39 | return hidden -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, ndim, eps=1e-6): 7 | super().__init__() 8 | self.eps = eps 9 | 
self.weight = nn.Parameter(torch.ones(ndim)) 10 | 11 | def forward(self, x): 12 | norm = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5) 13 | return self.weight * x / (norm + self.eps) 14 | 15 | class RotaryEmbedding(nn.Module): 16 | def __init__(self, dim, max_seq_len=2048): 17 | super().__init__() 18 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 19 | self.register_buffer('inv_freq', inv_freq) 20 | self.max_seq_len = max_seq_len 21 | 22 | def forward(self, x, seq_len=None): 23 | if seq_len is None: 24 | seq_len = x.shape[-2] 25 | 26 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 27 | freqs = torch.outer(t, self.inv_freq) 28 | cos, sin = freqs.cos(), freqs.sin() 29 | return cos, sin 30 | 31 | def apply_rope(x, cos, sin): 32 | x1, x2 = x.chunk(2, dim=-1) 33 | return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) 34 | 35 | class SwiGLU(nn.Module): 36 | def __init__(self, in_features, hidden_features, out_features, bias=True): 37 | super().__init__() 38 | self.gate_proj = nn.Linear(in_features, hidden_features, bias=bias) 39 | self.up_proj = nn.Linear(in_features, hidden_features, bias=bias) 40 | self.down_proj = nn.Linear(hidden_features, out_features, bias=bias) 41 | 42 | def forward(self, x): 43 | gate = self.gate_proj(x) 44 | up = self.up_proj(x) 45 | return self.down_proj(F.silu(gate) * up) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepSeek-V3 Main Script 3 | 4 | Simple entry point for training and inference. 5 | """ 6 | 7 | import torch 8 | from models import DeepSeekConfig, DeepSeekV3 9 | from training import train_model 10 | from inference import generate_text 11 | 12 | def demo(): 13 | """Run a simple demo.""" 14 | print("=" * 50) 15 | print("DEEPSEEK-V3 DEMO") 16 | print("=" * 50) 17 | 18 | # Create small model for demo 19 | config = DeepSeekConfig( 20 | vocab_size=1000, 21 | block_size=64, 22 | n_layer=3, 23 | n_head=4, 24 | n_embd=128, 25 | kv_lora_rank=32, 26 | q_lora_rank=48, 27 | n_experts=4, 28 | n_experts_per_token=2, 29 | mtp_num_heads=1, 30 | dropout=0.0 31 | ) 32 | 33 | model = DeepSeekV3(config) 34 | total_params = sum(p.numel() for p in model.parameters()) 35 | print(f"Created model with {total_params:,} parameters") 36 | 37 | # Test forward pass 38 | batch_size, seq_len = 2, 32 39 | test_input = torch.randint(0, config.vocab_size, (batch_size, seq_len)) 40 | test_targets = torch.randint(0, config.vocab_size, (batch_size, seq_len)) 41 | 42 | with torch.no_grad(): 43 | logits, total_loss, main_loss, mtp_loss = model(test_input, test_targets) 44 | 45 | print(f"Input shape: {test_input.shape}") 46 | print(f"Output shape: {logits.shape}") 47 | print(f"Main loss: {main_loss:.4f}") 48 | if mtp_loss is not None: 49 | print(f"MTP loss: {mtp_loss:.4f}") 50 | print(f"Total loss: {total_loss:.4f}") 51 | 52 | # Test generation 53 | prompt = torch.randint(0, config.vocab_size, (1, 5)) 54 | with torch.no_grad(): 55 | generated = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=10) 56 | 57 | print(f"Generated {generated.shape[1] - prompt.shape[1]} tokens") 58 | print("Demo completed!") 59 | 60 | def main(): 61 | """Main function.""" 62 | import sys 63 | 64 | if len(sys.argv) < 2: 65 | demo() 66 | return 67 | 68 | mode = sys.argv[1] 69 | 70 | if mode == "demo": 71 | demo() 72 | elif mode == "train": 73 | print("Starting training...") 74 | train_model() 75 | else: 76 | 
print("Usage: python main.py [demo|train|inference]") 77 | 78 | if __name__ == "__main__": 79 | main() -------------------------------------------------------------------------------- /inference/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tiktoken 3 | from models import DeepSeekV3 4 | 5 | def generate_text(model_path, config, prompt, max_tokens=100, temperature=0.8, top_k=50): 6 | """Generate text from a prompt using trained model.""" 7 | device = "cuda" if torch.cuda.is_available() else "cpu" 8 | 9 | # Load model 10 | model = DeepSeekV3(config) 11 | model.load_state_dict(torch.load(model_path, map_location=device)) 12 | model = model.to(device) 13 | model.eval() 14 | 15 | # Tokenize input 16 | enc = tiktoken.get_encoding("gpt2") 17 | context = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(0).to(device) 18 | 19 | # Generate 20 | with torch.no_grad(): 21 | generated = model.generate(context, max_tokens, temperature, top_k) 22 | 23 | # Decode and return 24 | result = enc.decode(generated.squeeze().tolist()) 25 | return result 26 | 27 | def run_inference_examples(): 28 | """Run inference examples with different prompts.""" 29 | try: 30 | from models import DeepSeekConfig 31 | 32 | config = DeepSeekConfig( 33 | vocab_size=50257, 34 | block_size=128, 35 | n_layer=4, 36 | n_head=4, 37 | n_embd=256, 38 | kv_lora_rank=64, 39 | q_lora_rank=96, 40 | n_experts=4, 41 | n_experts_per_token=2, 42 | mtp_num_heads=1, 43 | dropout=0.1 44 | ) 45 | 46 | test_prompts = [ 47 | "Once upon a time", 48 | "The little girl", 49 | "In a magical forest", 50 | "The brave knight" 51 | ] 52 | 53 | print("=" * 50) 54 | print("DEEPSEEK-V3 INFERENCE EXAMPLES") 55 | print("=" * 50) 56 | 57 | for prompt in test_prompts: 58 | result = generate_text( 59 | "best_deepseek_v3.pt", 60 | config, 61 | prompt, 62 | max_tokens=80, 63 | temperature=0.7, 64 | top_k=40 65 | ) 66 | 67 | print(f"\nPrompt: '{prompt}'") 68 | print("Generated:", result) 69 | print("-" * 30) 70 | 71 | except FileNotFoundError: 72 | print("Model file 'best_deepseek_v3.pt' not found. Please train the model first.") 73 | except Exception as e: 74 | print(f"Error during inference: {e}") -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple script to run inference with DeepSeek-V3 3 | Change the prompts below and run this file to generate text! 
4 | """ 5 | 6 | import torch 7 | import tiktoken 8 | from models import DeepSeekConfig, DeepSeekV3 9 | 10 | def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=50): 11 | """Generate text from your prompt.""" 12 | 13 | # Model configuration (same as training) 14 | config = DeepSeekConfig( 15 | vocab_size=50257, 16 | block_size=128, 17 | n_layer=4, 18 | n_head=4, 19 | n_embd=256, 20 | kv_lora_rank=64, 21 | q_lora_rank=96, 22 | n_experts=4, 23 | n_experts_per_token=2, 24 | mtp_num_heads=1, 25 | dropout=0.1 26 | ) 27 | 28 | device = "cuda" if torch.cuda.is_available() else "cpu" 29 | 30 | # Load model 31 | model = DeepSeekV3(config) 32 | try: 33 | model.load_state_dict(torch.load("best_deepseek_v3.pt", map_location=device)) 34 | print("✓ Loaded trained model") 35 | except FileNotFoundError: 36 | print("⚠️ No trained model found, using random weights") 37 | 38 | model = model.to(device) 39 | model.eval() 40 | 41 | # Tokenize input 42 | tokenizer = tiktoken.get_encoding("gpt2") 43 | context = torch.tensor(tokenizer.encode_ordinary(prompt)).unsqueeze(0).to(device) 44 | 45 | # Generate 46 | with torch.no_grad(): 47 | generated = model.generate(context, max_tokens, temperature, top_k) 48 | 49 | # Convert back to text 50 | result = tokenizer.decode(generated.squeeze().tolist()) 51 | return result 52 | 53 | 54 | if __name__ == "__main__": 55 | print("=" * 60) 56 | print("DEEPSEEK-V3 TEXT GENERATION") 57 | print("=" * 60) 58 | 59 | # ============================================ 60 | # CHANGE THESE PROMPTS TO WHATEVER YOU WANT! 61 | # ============================================ 62 | 63 | my_prompts = [ 64 | "Once upon a time", 65 | "The little girl found a magic", 66 | "In the future, artificial intelligence will", 67 | "The secret to happiness is", 68 | "Yesterday I went to the store and", 69 | ] 70 | 71 | # ============================================ 72 | # CHANGE THESE PARAMETERS TO EXPERIMENT! 73 | # ============================================ 74 | 75 | max_tokens = 80 # How many words to generate 76 | temperature = 0.8 # 0.1=boring, 0.8=balanced, 1.2=crazy 77 | top_k = 50 # Vocabulary limit 78 | 79 | # Generate text for each prompt 80 | for i, prompt in enumerate(my_prompts, 1): 81 | print(f"\n{i}. Prompt: '{prompt}'") 82 | print("-" * 40) 83 | 84 | result = generate_text( 85 | prompt=prompt, 86 | max_tokens=max_tokens, 87 | temperature=temperature, 88 | top_k=top_k 89 | ) 90 | 91 | print(f"Generated: {result}") 92 | print() 93 | 94 | print("=" * 60) 95 | print("DONE! Edit the prompts above and run again!") 96 | print("=" * 60) -------------------------------------------------------------------------------- /prepare_data_tiny_stories.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Preparation Script for DeepSeek-V3 3 | 4 | This script downloads and tokenizes the TinyStories dataset. 5 | Run this BEFORE training the model. 
6 | """ 7 | 8 | import os 9 | import numpy as np 10 | import tiktoken 11 | from tqdm.auto import tqdm 12 | from datasets import load_dataset 13 | 14 | def prepare_tinystories_dataset(): 15 | """Download and tokenize TinyStories dataset.""" 16 | 17 | print("=" * 60) 18 | print("PREPARING TINYSTORIES DATASET") 19 | print("=" * 60) 20 | 21 | # Check if already prepared 22 | if os.path.exists("train.bin") and os.path.exists("validation.bin"): 23 | print("✓ Dataset files already exist!") 24 | 25 | # Show file info 26 | train_size = os.path.getsize("train.bin") / (1024 * 1024) 27 | val_size = os.path.getsize("validation.bin") / (1024 * 1024) 28 | 29 | train_data = np.memmap('train.bin', dtype=np.uint16, mode='r') 30 | val_data = np.memmap('validation.bin', dtype=np.uint16, mode='r') 31 | 32 | print(f"├── train.bin: {len(train_data):,} tokens ({train_size:.1f} MB)") 33 | print(f"└── validation.bin: {len(val_data):,} tokens ({val_size:.1f} MB)") 34 | return 35 | 36 | print("Downloading TinyStories dataset...") 37 | 38 | # Load dataset 39 | ds = load_dataset("roneneldan/TinyStories") 40 | 41 | # Initialize tokenizer 42 | enc = tiktoken.get_encoding("gpt2") 43 | 44 | def process_example(example): 45 | """Tokenize text.""" 46 | ids = enc.encode_ordinary(example['text']) 47 | return {'ids': ids, 'len': len(ids)} 48 | 49 | print("Tokenizing dataset...") 50 | 51 | # Tokenize 52 | tokenized = ds.map( 53 | process_example, 54 | remove_columns=['text'], 55 | desc="Tokenizing splits", 56 | num_proc=8, 57 | ) 58 | 59 | # Create binary files 60 | for split, dset in tokenized.items(): 61 | arr_len = np.sum(dset['len'], dtype=np.uint64) 62 | filename = f'{split}.bin' 63 | dtype = np.uint16 # GPT-2 vocab size < 2^16 64 | 65 | print(f"Creating {filename}...") 66 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 67 | total_batches = 1024 68 | 69 | idx = 0 70 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 71 | batch = dset.shard( 72 | num_shards=total_batches, 73 | index=batch_idx, 74 | contiguous=True 75 | ).with_format('numpy') 76 | 77 | arr_batch = np.concatenate(batch['ids']) 78 | arr[idx : idx + len(arr_batch)] = arr_batch 79 | idx += len(arr_batch) 80 | 81 | arr.flush() 82 | 83 | size_mb = os.path.getsize(filename) / (1024 * 1024) 84 | print(f"✓ {filename}: {arr_len:,} tokens ({size_mb:.1f} MB)") 85 | 86 | print("\n Dataset preparation completed!") 87 | print("You can now run: python main.py train") 88 | 89 | 90 | if __name__ == "__main__": 91 | prepare_tinystories_dataset() -------------------------------------------------------------------------------- /models/moe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .layers import SwiGLU 5 | 6 | class MoELayer(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | self.n_experts = config.n_experts 11 | self.top_k = config.n_experts_per_token 12 | self.n_embd = config.n_embd 13 | 14 | # Router 15 | self.router = nn.Linear(config.n_embd, config.n_experts, bias=False) 16 | 17 | # Expert MLPs 18 | self.experts = nn.ModuleList([ 19 | SwiGLU( 20 | config.n_embd, 21 | config.expert_intermediate_size, 22 | config.n_embd, 23 | config.bias 24 | ) for _ in range(config.n_experts) 25 | ]) 26 | 27 | # Shared expert 28 | if config.use_shared_expert: 29 | self.shared_expert = SwiGLU( 30 | config.n_embd, 31 | config.shared_expert_intermediate_size, 32 | 
config.n_embd, 33 | config.bias 34 | ) 35 | else: 36 | self.shared_expert = None 37 | 38 | # Auxiliary-loss-free load balancing 39 | self.register_buffer('expert_bias', torch.zeros(config.n_experts)) 40 | self.bias_update_rate = 0.001 41 | 42 | def forward(self, x): 43 | batch_size, seq_len, hidden_size = x.shape 44 | x_flat = x.view(-1, hidden_size) 45 | 46 | # Routing phase 47 | router_logits = self.router(x_flat) + self.expert_bias 48 | 49 | # Top-k routing 50 | top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1) 51 | routing_weights = torch.zeros_like(router_logits) 52 | routing_weights.scatter_(-1, top_k_indices, F.softmax(top_k_logits, dim=-1)) 53 | 54 | # Expert computation 55 | output = torch.zeros_like(x_flat) 56 | expert_usage = torch.zeros(self.n_experts, device=x.device) 57 | 58 | # Process through selected experts 59 | for expert_idx in range(self.n_experts): 60 | expert_mask = (top_k_indices == expert_idx).any(dim=-1) 61 | expert_usage[expert_idx] = expert_mask.sum().float() 62 | 63 | if expert_mask.any(): 64 | expert_input = x_flat[expert_mask] 65 | expert_output = self.experts[expert_idx](expert_input) 66 | 67 | # Weight by routing probability 68 | weights = routing_weights[expert_mask, expert_idx].unsqueeze(-1) 69 | output[expert_mask] += expert_output * weights 70 | 71 | # Add shared expert output 72 | if self.shared_expert is not None: 73 | shared_output = self.shared_expert(x_flat) 74 | output += shared_output 75 | 76 | # Auxiliary-loss-free load balancing 77 | if self.training: 78 | with torch.no_grad(): 79 | avg_usage = expert_usage.mean() 80 | for i in range(self.n_experts): 81 | if expert_usage[i] > avg_usage: 82 | self.expert_bias[i] -= self.bias_update_rate 83 | else: 84 | self.expert_bias[i] += self.bias_update_rate 85 | 86 | return output.view(batch_size, seq_len, hidden_size) -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .layers import RMSNorm, RotaryEmbedding, apply_rope 6 | 7 | class MultiHeadLatentAttention(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | self.config = config 11 | self.n_embd = config.n_embd 12 | self.n_head = config.n_head 13 | self.head_dim = config.n_embd // config.n_head 14 | 15 | # Compression dimensions 16 | self.kv_lora_rank = config.kv_lora_rank 17 | self.q_lora_rank = config.q_lora_rank 18 | self.rope_dim = config.rope_dim 19 | 20 | # KV compression 21 | self.kv_proj = nn.Linear(self.n_embd, self.kv_lora_rank, bias=False) 22 | self.kv_norm = RMSNorm(self.kv_lora_rank) 23 | 24 | # KV decompression 25 | self.k_decompress = nn.Linear(self.kv_lora_rank, self.n_head * self.head_dim, bias=False) 26 | self.v_decompress = nn.Linear(self.kv_lora_rank, self.n_head * self.head_dim, bias=False) 27 | 28 | # Query compression 29 | self.q_proj = nn.Linear(self.n_embd, self.q_lora_rank, bias=False) 30 | self.q_decompress = nn.Linear(self.q_lora_rank, self.n_head * self.head_dim, bias=False) 31 | 32 | # RoPE projections 33 | self.k_rope_proj = nn.Linear(self.n_embd, self.n_head * self.rope_dim, bias=False) 34 | self.q_rope_proj = nn.Linear(self.q_lora_rank, self.n_head * self.rope_dim, bias=False) 35 | 36 | # Output projection 37 | self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=config.bias) 38 | 39 | # Dropout 40 | self.attn_dropout = 
nn.Dropout(config.dropout) 41 | self.resid_dropout = nn.Dropout(config.dropout) 42 | 43 | # RoPE 44 | self.rope = RotaryEmbedding(self.rope_dim, config.block_size) 45 | 46 | # Causal mask 47 | self.register_buffer( 48 | "causal_mask", 49 | torch.tril(torch.ones(config.block_size, config.block_size)).view( 50 | 1, 1, config.block_size, config.block_size 51 | ) 52 | ) 53 | 54 | def forward(self, x): 55 | B, T, C = x.size() 56 | 57 | # Compression phase 58 | kv_compressed = self.kv_norm(self.kv_proj(x)) 59 | q_compressed = self.q_proj(x) 60 | 61 | # Decompression phase 62 | k_content = self.k_decompress(kv_compressed) 63 | v = self.v_decompress(kv_compressed) 64 | q_content = self.q_decompress(q_compressed) 65 | 66 | # RoPE components 67 | k_rope = self.k_rope_proj(x) 68 | q_rope = self.q_rope_proj(q_compressed) 69 | 70 | # Reshape for multi-head attention 71 | k_content = k_content.view(B, T, self.n_head, self.head_dim).transpose(1, 2) 72 | v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) 73 | q_content = q_content.view(B, T, self.n_head, self.head_dim).transpose(1, 2) 74 | k_rope = k_rope.view(B, T, self.n_head, self.rope_dim).transpose(1, 2) 75 | q_rope = q_rope.view(B, T, self.n_head, self.rope_dim).transpose(1, 2) 76 | 77 | # Apply RoPE 78 | cos, sin = self.rope(x, T) 79 | q_rope = apply_rope(q_rope, cos, sin) 80 | k_rope = apply_rope(k_rope, cos, sin) 81 | 82 | # Concatenate content and rope parts 83 | q = torch.cat([q_content, q_rope], dim=-1) 84 | k = torch.cat([k_content, k_rope], dim=-1) 85 | 86 | # Attention computation 87 | scale = 1.0 / math.sqrt(q.size(-1)) 88 | scores = torch.matmul(q, k.transpose(-2, -1)) * scale 89 | 90 | # Apply causal mask 91 | scores = scores.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float('-inf')) 92 | 93 | # Softmax and dropout 94 | attn_weights = F.softmax(scores, dim=-1) 95 | attn_weights = self.attn_dropout(attn_weights) 96 | 97 | # Apply attention to values 98 | out = torch.matmul(attn_weights, v) 99 | 100 | # Reshape and project 101 | out = out.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) 102 | out = self.resid_dropout(self.o_proj(out)) 103 | 104 | return out -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | import wandb 5 | from tqdm.auto import tqdm 6 | from contextlib import nullcontext 7 | from models import DeepSeekConfig, DeepSeekV3 8 | from .data_loader import get_batch, estimate_loss 9 | from dotenv import load_dotenv 10 | 11 | # Load environment variables 12 | load_dotenv() 13 | 14 | def train_model(): 15 | # Configuration 16 | config = DeepSeekConfig( 17 | vocab_size=50257, 18 | block_size=1024, 19 | n_layer=8, 20 | n_head=8, 21 | n_embd=512, 22 | kv_lora_rank=128, 23 | q_lora_rank=192, 24 | n_experts=8, 25 | n_experts_per_token=2, 26 | mtp_num_heads=1, 27 | dropout=0.1 28 | ) 29 | 30 | # Training parameters 31 | learning_rate = 3e-4 32 | max_iters = 20000 33 | warmup_steps = 2000 34 | min_lr = 1e-5 35 | eval_iters = 1000 36 | batch_size = 32 37 | gradient_accumulation_steps = 8 38 | 39 | # Device setup 40 | device = "cuda" if torch.cuda.is_available() else "cpu" 41 | device_type = 'cuda' if 'cuda' in device else 'cpu' 42 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 43 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': 
torch.float16}[dtype] 44 | ctx = nullcontext() if device_type == 'cpu' else torch.cuda.amp.autocast(dtype=ptdtype) 45 | 46 | # Initialize wandb 47 | wandb.init( 48 | project="deepseek-v3-training", 49 | config={ 50 | "learning_rate": learning_rate, 51 | "max_iters": max_iters, 52 | "warmup_steps": warmup_steps, 53 | "min_lr": min_lr, 54 | "eval_iters": eval_iters, 55 | "batch_size": batch_size, 56 | "gradient_accumulation_steps": gradient_accumulation_steps, 57 | "device": device, 58 | "dtype": dtype, 59 | "vocab_size": config.vocab_size, 60 | "block_size": config.block_size, 61 | "n_layer": config.n_layer, 62 | "n_head": config.n_head, 63 | "n_embd": config.n_embd, 64 | "kv_lora_rank": config.kv_lora_rank, 65 | "q_lora_rank": config.q_lora_rank, 66 | "n_experts": config.n_experts, 67 | "n_experts_per_token": config.n_experts_per_token, 68 | "mtp_num_heads": config.mtp_num_heads, 69 | "dropout": config.dropout 70 | } 71 | ) 72 | 73 | # Initialize model 74 | torch.manual_seed(42) 75 | model = DeepSeekV3(config) 76 | model = model.to(device) 77 | 78 | # Print model info 79 | total_params = sum(p.numel() for p in model.parameters()) 80 | print(f"DeepSeek-V3 model with {total_params:,} parameters") 81 | 82 | # Log model parameters to wandb 83 | wandb.log({"total_parameters": total_params}) 84 | 85 | # Optimizer 86 | optimizer = torch.optim.AdamW( 87 | model.parameters(), 88 | lr=learning_rate, 89 | betas=(0.9, 0.95), 90 | weight_decay=0.1, 91 | eps=1e-9 92 | ) 93 | 94 | # Training loop 95 | model.train() 96 | best_val_loss = float('inf') 97 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 98 | 99 | for epoch in tqdm(range(max_iters)): 100 | # Evaluation 101 | if epoch % eval_iters == 0 and epoch != 0: 102 | losses = estimate_loss(model, config, eval_iters, batch_size, device_type, device, ctx) 103 | print(f"Epoch {epoch}: train {losses['train']:.4f}, val {losses['val']:.4f}") 104 | 105 | # Log evaluation losses to wandb 106 | wandb.log({ 107 | "epoch": epoch, 108 | "train_loss": losses['train'], 109 | "val_loss": losses['val'], 110 | "best_val_loss": best_val_loss 111 | }) 112 | 113 | if losses['val'] < best_val_loss: 114 | best_val_loss = losses['val'] 115 | torch.save(model.state_dict(), "best_deepseek_v3.pt") 116 | 117 | # Log best model save to wandb 118 | wandb.log({"best_val_loss_updated": best_val_loss}) 119 | 120 | # Training step 121 | X, y = get_batch("train", config, batch_size, device_type, device) 122 | 123 | with ctx: 124 | _, total_loss, main_loss, mtp_loss = model(X, y) 125 | loss = total_loss / gradient_accumulation_steps 126 | scaler.scale(loss).backward() 127 | 128 | if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters): 129 | scaler.step(optimizer) 130 | scaler.update() 131 | optimizer.zero_grad(set_to_none=True) 132 | 133 | # Learning rate scheduling 134 | if epoch < warmup_steps: 135 | lr = learning_rate * (epoch + 1) / warmup_steps 136 | else: 137 | progress = (epoch - warmup_steps) / (max_iters - warmup_steps) 138 | lr = min_lr + (learning_rate - min_lr) * 0.5 * (1 + math.cos(math.pi * progress)) 139 | 140 | for param_group in optimizer.param_groups: 141 | param_group['lr'] = lr 142 | 143 | # Log training metrics to wandb every step 144 | wandb.log({ 145 | "step": epoch, 146 | "total_loss": total_loss.item(), 147 | "main_loss": main_loss.item(), 148 | "mtp_loss": mtp_loss.item(), 149 | "learning_rate": lr, 150 | "scaled_loss": loss.item() 151 | }) 152 | 153 | print("Training completed!") 154 | 155 | # Finish wandb run 156 | 
wandb.finish() 157 | 158 | return model, config -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .layers import RMSNorm 6 | from .attention import MultiHeadLatentAttention 7 | from .moe import MoELayer 8 | from .mtp import MultiTokenPredictionHead 9 | 10 | class DeepSeekBlock(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.ln_1 = RMSNorm(config.n_embd) 14 | self.attn = MultiHeadLatentAttention(config) 15 | self.ln_2 = RMSNorm(config.n_embd) 16 | self.mlp = MoELayer(config) 17 | 18 | def forward(self, x): 19 | x = x + self.attn(self.ln_1(x)) 20 | x = x + self.mlp(self.ln_2(x)) 21 | return x 22 | 23 | class DeepSeekV3(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | self.config = config 27 | 28 | # Token and position embeddings 29 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 30 | self.wpe = nn.Embedding(config.block_size, config.n_embd) 31 | self.drop = nn.Dropout(config.dropout) 32 | 33 | # Transformer blocks 34 | self.h = nn.ModuleList([DeepSeekBlock(config) for _ in range(config.n_layer)]) 35 | 36 | # Final layer norm 37 | self.ln_f = RMSNorm(config.n_embd) 38 | 39 | # Language modeling head 40 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 41 | 42 | # Weight tying 43 | self.wte.weight = self.lm_head.weight 44 | 45 | # Multi-Token Prediction heads 46 | if config.mtp_num_heads > 0: 47 | self.mtp_heads = nn.ModuleList([ 48 | MultiTokenPredictionHead(config, depth) 49 | for depth in range(1, config.mtp_num_heads + 1) 50 | ]) 51 | else: 52 | self.mtp_heads = None 53 | 54 | # Initialize weights 55 | self.apply(self._init_weights) 56 | 57 | # Special initialization for residual projections 58 | for pn, p in self.named_parameters(): 59 | if pn.endswith('o_proj.weight') or pn.endswith('down_proj.weight'): 60 | nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 61 | 62 | def _init_weights(self, module): 63 | if isinstance(module, nn.Linear): 64 | nn.init.normal_(module.weight, mean=0.0, std=0.02) 65 | if module.bias is not None: 66 | nn.init.zeros_(module.bias) 67 | elif isinstance(module, nn.Embedding): 68 | nn.init.normal_(module.weight, mean=0.0, std=0.02) 69 | 70 | def forward(self, idx, targets=None): 71 | device = idx.device 72 | b, t = idx.size() 73 | assert t <= self.config.block_size 74 | 75 | # Token and position embeddings 76 | pos = torch.arange(0, t, dtype=torch.long, device=device) 77 | tok_emb = self.wte(idx) 78 | pos_emb = self.wpe(pos) 79 | x = self.drop(tok_emb + pos_emb) 80 | 81 | # Transformer blocks 82 | for block in self.h: 83 | x = block(x) 84 | 85 | # Final norm 86 | x = self.ln_f(x) 87 | 88 | # Main language modeling head 89 | main_logits = self.lm_head(x) 90 | main_loss = None 91 | 92 | if targets is not None: 93 | main_loss = F.cross_entropy( 94 | main_logits.view(-1, main_logits.size(-1)), 95 | targets.view(-1), 96 | ignore_index=-1 97 | ) 98 | 99 | # Multi-Token Prediction 100 | mtp_loss = None 101 | if self.mtp_heads is not None and targets is not None: 102 | mtp_losses = [] 103 | current_hidden = x 104 | 105 | for depth, mtp_head in enumerate(self.mtp_heads, 1): 106 | if t > depth: 107 | future_indices = idx[:, depth:] 108 | future_embeds = self.wte(future_indices) 109 | 110 | if future_embeds.size(1) < 
current_hidden.size(1): 111 | pad_size = current_hidden.size(1) - future_embeds.size(1) 112 | padding = torch.zeros( 113 | b, pad_size, self.config.n_embd, 114 | device=device, dtype=future_embeds.dtype 115 | ) 116 | future_embeds = torch.cat([future_embeds, padding], dim=1) 117 | elif future_embeds.size(1) > current_hidden.size(1): 118 | future_embeds = future_embeds[:, :current_hidden.size(1)] 119 | 120 | current_hidden = mtp_head(current_hidden, future_embeds) 121 | mtp_logits = self.lm_head(current_hidden) 122 | 123 | if t > depth + 1: 124 | shift_logits = mtp_logits[..., :-(depth+1), :].contiguous() 125 | shift_labels = targets[..., depth+1:].contiguous() 126 | 127 | if shift_labels.numel() > 0: 128 | mtp_loss_single = F.cross_entropy( 129 | shift_logits.view(-1, shift_logits.size(-1)), 130 | shift_labels.view(-1), 131 | ignore_index=-1 132 | ) 133 | mtp_losses.append(mtp_loss_single) 134 | 135 | if mtp_losses: 136 | mtp_loss = torch.stack(mtp_losses).mean() 137 | 138 | # Combine losses 139 | if targets is not None: 140 | if mtp_loss is not None: 141 | total_loss = main_loss + self.config.mtp_loss_weight * mtp_loss 142 | return main_logits, total_loss, main_loss, mtp_loss 143 | else: 144 | return main_logits, main_loss, main_loss, None 145 | else: 146 | return main_logits[:, [-1], :], None, None, None 147 | 148 | @torch.no_grad() 149 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 150 | for _ in range(max_new_tokens): 151 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 152 | logits, _, _, _ = self(idx_cond) 153 | logits = logits[:, -1, :] / temperature 154 | 155 | if top_k is not None: 156 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 157 | logits[logits < v[:, [-1]]] = -float('Inf') 158 | 159 | probs = F.softmax(logits, dim=-1) 160 | idx_next = torch.multinomial(probs, num_samples=1) 161 | idx = torch.cat((idx, idx_next), dim=1) 162 | 163 | return idx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSeek V3 from Scratch 2 | 3 | A complete implementation of the DeepSeek V3 architecture with modern transformer innovations including Multi-Head Latent Attention (MLA), Mixture of Experts (MoE), and Multi-Token Prediction (MTP). This project demonstrates the implementation of a 100+ million parameter language model trained on the FineWeb-Edu dataset. 
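For a quick feel of the API before the architecture details, here is a minimal sketch of building the model and running one forward pass. It mirrors the small demo configuration in `main.py`; the exact values are illustrative and are not the training configuration.

```python
import torch
from models import DeepSeekConfig, DeepSeekV3

# Tiny illustrative configuration (same spirit as the demo in main.py).
config = DeepSeekConfig(
    vocab_size=1000, block_size=64, n_layer=3, n_head=4, n_embd=128,
    kv_lora_rank=32, q_lora_rank=48, n_experts=4, n_experts_per_token=2,
    mtp_num_heads=1, dropout=0.0,
)
model = DeepSeekV3(config)

# With targets, the forward pass returns (logits, total_loss, main_loss, mtp_loss).
tokens = torch.randint(0, config.vocab_size, (2, 32))
targets = torch.randint(0, config.vocab_size, (2, 32))
logits, total_loss, main_loss, mtp_loss = model(tokens, targets)

# Autoregressive sampling from a short prompt.
generated = model.generate(tokens[:, :5], max_new_tokens=20, temperature=1.0, top_k=10)
```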
4 | 5 | ## Architecture Overview 6 | 7 | 8 | ![DeepSeek architecture](https://github.com/user-attachments/assets/8751e031-61e8-4ef2-9823-5e4316bd6356) 9 | 10 | DeepSeek V3 introduces several key architectural improvements over traditional transformer models: 11 | 12 | ### Core Innovations 13 | 14 | **Multi-Head Latent Attention (MLA)** 15 | ![Multi-Head-Latent-Attention](https://github.com/user-attachments/assets/564a2bf0-ab76-4a50-ae91-2f3eadef337d) 16 | - Compresses key-value pairs into shared latent representations 17 | - Dramatically reduces memory usage compared to traditional multi-head attention 18 | - Maintains the expressiveness of multiple attention heads while using significantly less memory 19 | - Critical for handling longer sequences efficiently 20 | 21 | 22 | 23 | 24 | 25 | 26 | **Mixture of Experts (MoE)** 27 | ![Mixture of Experts](https://github.com/user-attachments/assets/d7a4196d-753f-4aa5-9534-067c2a84c0ae) 28 | - Replaces dense feed-forward networks with sparse expert networks 29 | - Uses 8 experts but only activates 2 per token 30 | - Achieves 4x model capacity with only 25% computational overhead 31 | - Each expert specializes in different domains (numbers, language, code, etc.) 32 | 33 | **Multi-Token Prediction (MTP)** 34 | ![Multi Token prediction](https://github.com/user-attachments/assets/52051bc1-641e-44f4-af4e-63f64f133a64) 35 | - Predicts multiple tokens simultaneously during training 36 | - Improves training efficiency by providing more learning signals per forward pass 37 | - Enables faster inference through speculative decoding 38 | 39 | **Additional Components** 40 | - **RoPE (Rotary Positional Encoding)**: Better handling of longer sequences and relative positions 41 | - **RMS Norm**: Computationally simpler normalization without mean centering 42 | - **SwiGLU Activation**: Gated activation function for improved information flow control 43 | 44 | 45 | Final Model Weights 46 | 47 | https://huggingface.co/Mayank022/DeepSeek-V3-from-Scratch/tree/main 48 | 49 | 50 | ## Model Configuration 51 | 52 | Model Summary 53 | 54 | ### Training Parameters 55 | 56 | | Parameter | Value | Description | 57 | |-----------|-------|-------------| 58 | | Model Parameters | 109,032,032 | Total trainable parameters | 59 | | Vocabulary Size | 50,257 | Number of unique tokens | 60 | | Block Size | 1,024 | Maximum sequence length | 61 | | Embedding Dimension | 512 | Hidden dimension size | 62 | | Number of Layers | 8 | Transformer blocks | 63 | | Attention Heads | 8 | Multi-head attention | 64 | | Batch Size | 32 | Training batch size | 65 | | Learning Rate | 0.0003 | Initial learning rate | 66 | | Min Learning Rate | 0.00001 | Minimum learning rate | 67 | | Warmup Steps | 2,000 | Learning rate warmup | 68 | | Max Iterations | 20,000 | Maximum training steps | 69 | | Dropout | 0.1 | Dropout probability | 70 | | Gradient Accumulation | 8 | Steps before optimizer update | 71 | 72 | ### MoE Configuration 73 | 74 | | Parameter | Value | Description | 75 | |-----------|-------|-------------| 76 | | Number of Experts | 8 | Total expert networks | 77 | | Experts per Token | 2 | Active experts per forward pass | 78 | | Expert Efficiency | 25% | Computation vs full dense model | 79 | | Capacity Multiplier | 4x | Model capacity increase | 80 | 81 | ### Attention Configuration 82 | 83 | | Parameter | Value | Description | 84 | |-----------|-------|-------------| 85 | | KV LoRA Rank | 128 | Key-Value compression rank | 86 | | Q LoRA Rank | 192 | Query compression rank | 87 | | MTP Heads | 1 
| Multi-token prediction heads | 88 | 89 | ## Dataset 90 | 91 | **Primary Dataset**: FineWeb-Edu (CC-MAIN-2024 subset) 92 | 93 | https://huggingface.co/spaces/HuggingFaceFW/blogpost-fineweb-v1 94 | 95 | https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu 96 | 97 | - Total available records: 13 million 98 | - Used for training: 2 million records 99 | - Training tokens: 2.5 billion 100 | - Validation tokens: 132.8 million 101 | - Format: Educational web content optimized for language model training 102 | 103 | **Fallback Dataset**: TinyStories 104 | 105 | https://huggingface.co/datasets/roneneldan/TinyStories 106 | 107 | - Used for initial prototyping and architecture validation 108 | - Simpler content for testing basic functionality 109 | 110 | Paper 111 | 112 | https://arxiv.org/pdf/2412.19437 113 | 114 | 115 | ## Project Structure 116 | 117 | ``` 118 | DeepSeek-from-Scratch/ 119 | ├── models/ 120 | │ ├── attention.py # Multi-Head Latent Attention implementation 121 | │ ├── config.py # Model configuration parameters 122 | │ ├── layers.py # RoPE, RMS Norm, SwiGLU implementations 123 | │ ├── model.py # DeepSeek model and transformer blocks 124 | │ ├── moe.py # Mixture of Experts implementation 125 | │ └── mtp.py # Multi-Token Prediction implementation 126 | ├── training/ 127 | │ ├── data_loader.py # Dataset loading and preprocessing 128 | │ └── trainer.py # Training loop and optimization 129 | ├── inference/ 130 | │ └── generator.py # Text generation utilities 131 | ├── notebooks/ 132 | │ ├── Mixture_of_Experts_from_Scratch.ipynb 133 | │ ├── Multi_Head_Latent_Attention_From_Scratch.ipynb 134 | │ └── Multi_Token_Prediction_from_Scratch.ipynb 135 | ├── prepare_data_fineweb.py # FineWeb dataset preparation 136 | ├── prepare_data_tiny_stories.py # TinyStories dataset preparation 137 | ├── run_inference.py # Inference script 138 | └── main.py # Main entry point (demo and training) 139 | ``` 140 | 141 | ## Installation 142 | 143 | 1. Clone the repository: 144 | ```bash 145 | git clone https://github.com/username/DeepSeek-from-Scratch.git 146 | cd DeepSeek-from-Scratch 147 | ``` 148 | 149 | 2. Create and activate virtual environment: 150 | ```bash 151 | python -m venv deepseek_env 152 | source deepseek_env/bin/activate # Linux/Mac 153 | # or 154 | deepseek_env\Scripts\activate # Windows 155 | ``` 156 | 157 | 3. Install dependencies: 158 | ```bash 159 | pip install -r requirements.txt 160 | ``` 161 | 162 | ## Training 163 | 164 | ### Data Preparation 165 | 166 | Prepare the FineWeb-Edu dataset: 167 | ```bash 168 | python prepare_data_fineweb.py 169 | ``` 170 | 171 | Or use TinyStories for testing: 172 | ```bash 173 | python prepare_data_tiny_stories.py 174 | ``` 175 | 176 | ### Start Training 177 | 178 | ```bash 179 | python main.py train 180 | ``` 181 | 182 | Training configurations can be modified in `models/config.py`. 183 | 184 | ### Monitoring 185 | 186 | Training progress is tracked using Weights & Biases: 187 | - Model checkpoints are saved automatically 188 | - Loss curves and metrics are logged in real-time 189 | - Training time: Approximately 7 hours on A100 80GB 190 | - Cost: $9.53 for full training run 191 | 192 | ## Inference 193 | 194 | Update the test prompts in `run_inference.py` and run: 195 | 196 | ```bash 197 | python run_inference.py 198 | ``` 199 | 200 | The model will generate text completions based on your input prompts.
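For programmatic use, the same functionality is exposed through `inference/generator.py`. A minimal sketch is shown below; the checkpoint path and configuration mirror the defaults in `run_inference.py` and should be adjusted to match the checkpoint you actually trained.

```python
from models import DeepSeekConfig
from inference import generate_text

# Must match the configuration the checkpoint was trained with
# (these values mirror run_inference.py; trainer.py uses a larger config).
config = DeepSeekConfig(
    vocab_size=50257, block_size=128, n_layer=4, n_head=4, n_embd=256,
    kv_lora_rank=64, q_lora_rank=96, n_experts=4, n_experts_per_token=2,
    mtp_num_heads=1, dropout=0.1,
)

# Loads best_deepseek_v3.pt (saved by the trainer), tokenizes the prompt with
# the GPT-2 tokenizer, and samples a completion.
text = generate_text(
    "best_deepseek_v3.pt",
    config,
    "Once upon a time",
    max_tokens=80,
    temperature=0.7,
    top_k=40,
)
print(text)
```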
201 | 202 | ## Key Implementation Details 203 | 204 | ### Multi-Head Latent Attention 205 | 206 | Traditional multi-head attention stores separate key-value pairs for each head, leading to significant memory overhead. MLA compresses these into shared latent representations: 207 | 208 | - Memory reduction: Proportional to number of attention heads 209 | - Performance maintenance: Retains expressiveness of full multi-head attention 210 | - Scalability: Enables training with longer sequences 211 | 212 | ### Mixture of Experts 213 | 214 | Instead of processing every token through the same large feed-forward network, MoE routes tokens to specialized experts: 215 | 216 | - Sparse activation: Only 25% of the model is active per token 217 | - Specialization: Different experts learn different types of patterns 218 | - Efficiency: 4x capacity increase with minimal computational overhead 219 | 220 | ### Multi-Token Prediction 221 | 222 | Enhances training by predicting multiple future tokens simultaneously: 223 | 224 | - Training efficiency: More learning signals per forward pass 225 | - Inference optimization: Enables speculative decoding techniques 226 | - Performance improvement: Better gradient flow during training 227 | 228 | ## Performance Metrics 229 | 230 | ### Training Results 231 | 232 | - **Final Loss**: Achieved convergence after 20,000 iterations 233 | - **Training Time**: 7 hours 1 minute on NVIDIA A100 80GB 234 | - **Memory Usage**: Efficient memory utilization with MLA compression 235 | - **Convergence**: Stable training with proper learning rate scheduling 236 | 237 | ### Model Efficiency 238 | 239 | - **Parameter Efficiency**: 109M parameters with MoE sparse activation 240 | - **Memory Efficiency**: Reduced KV cache through latent attention 241 | - **Computational Efficiency**: 25% active parameters per forward pass 242 | 243 | ## Technical Challenges Addressed 244 | 245 | ### Dataset Selection 246 | 247 | The choice of dataset was critical for demonstrating the architecture's benefits: 248 | 249 | - **TinyStories**: Too simple, didn't justify advanced architecture components 250 | - **Raw Web Data**: Too complex for resource-constrained training 251 | - **FineWeb-Edu**: Perfect balance of complexity and educational content quality 252 | 253 | ### Architecture Decisions 254 | 255 | Careful consideration was given to which components to include: 256 | 257 | - **Essential Components**: MLA, MoE, MTP all included for comprehensive implementation 258 | - **Training Constraints**: Context length limited to 1024 tokens due to compute budget 259 | - **Resource Management**: Balanced model size with available GPU memory 260 | 261 | ## Future Enhancements 262 | 263 | ### Planned Improvements 264 | 265 | - **Dataset Expansion**: Experiment with larger subsets of FineWeb-Edu 266 | - **Evaluation Metrics**: Implement comprehensive benchmarking suite 267 | - **Architecture Extensions**: Additional transformer innovations and optimizations 268 | - **Scaling Studies**: Analysis of performance across different model sizes 269 | 270 | 271 | -------------------------------------------------------------------------------- /notebooks/Multi_Token_Prediction_from_Scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | 
}, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "# Multi-Token Prediction (MTP)" 21 | ], 22 | "metadata": { 23 | "id": "GZps_56evbje" 24 | } 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "source": [ 29 | "## Step 0: Load Packages" 30 | ], 31 | "metadata": { 32 | "id": "ymOhz91jvhK8" 33 | } 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "id": "yaCHMAGSrrA-" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "import torch\n", 44 | "import torch.nn as nn\n", 45 | "import torch.nn.functional as F" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "## Step 1: Define RMSNorm Class" 52 | ], 53 | "metadata": { 54 | "id": "Q7OaXcEdvrkp" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "class RMSNorm(nn.Module):\n", 61 | " \"\"\"Root Mean Square Layer Norm (no learning weights) \"\"\"\n", 62 | " def __init__(self,d_model,eps:float = 1e-8):\n", 63 | " super().__init__()\n", 64 | " self.eps = eps\n", 65 | "\n", 66 | " def forward(self,x):\n", 67 | " # x: (batch,d_model)\n", 68 | " rms = torch.sqrt(x.pow(2).mean(dim=-1,keepdim=True)+ self.eps)\n", 69 | " return x / rms" 70 | ], 71 | "metadata": { 72 | "id": "7brrcEzBvrYC" 73 | }, 74 | "execution_count": null, 75 | "outputs": [] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "source": [ 80 | "## Step 2: Define the Multi-Token Prediction (MTP) class" 81 | ], 82 | "metadata": { 83 | "id": "-2jlelWbzk5q" 84 | } 85 | }, 86 | { 87 | "cell_type": "code", 88 | "source": [ 89 | "class SimpleMTP(nn.Module):\n", 90 | " def __init__(self,d_model:int,vocab_size:int,num_heads:int=3,nhead: int =1):\n", 91 | " \"\"\"\n", 92 | " d_model: hidden size (8 in this example)\n", 93 | " num_heads: number of sequential MTP steps (D)\n", 94 | " nhead: attention heads in each Transformer block\n", 95 | " \"\"\"\n", 96 | " super().__init__()\n", 97 | " self.d_model = d_model\n", 98 | " self.vocab_size = vocab_size\n", 99 | " self.num_heads = num_heads\n", 100 | "\n", 101 | " # shared modules\n", 102 | " self.rmsnorm = RMSNorm(d_model)\n", 103 | " self.embed = nn.Embedding(vocab_size,d_model)\n", 104 | " self.unembed = nn.Linear(d_model,vocab_size,bias=False)\n", 105 | " # share weights between embed and unembed\n", 106 | " self.unembed.weight = self.embed.weight\n", 107 | "\n", 108 | " # one projection + one Transformer per head\n", 109 | " self.projections = nn.ModuleList([\n", 110 | " nn.Linear(2*d_model,d_model) for _ in range(num_heads)\n", 111 | "\n", 112 | " ])\n", 113 | " self.transformers = nn.ModuleList([\n", 114 | " nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead)\n", 115 | " for _ in range(num_heads)\n", 116 | " ])\n", 117 | "\n", 118 | " def forward(self,token_ids:torch.LongTensor,init_hidden:torch.Tensor = None):\n", 119 | " \"\"\"\n", 120 | " token_ids: (batch,seq_len) integer IDs of your input tokens\n", 121 | " init_hidden: optional (batch,seq_len,d_model) base hidden states;\n", 122 | " if None, uses token embedding as initial hidden.\n", 123 | "\n", 124 | " Returns:\n", 125 | " logits_out: Tensor of shape (batch,T-D,D,vocab_size),\n", 126 | " where T=seq_len and D=num_heads\n", 127 | "\n", 128 | "\n", 129 | "\n", 130 | " \"\"\"\n", 131 | "\n", 132 | " B,T = token_ids.shape\n", 133 | " device = token_ids.device\n", 134 | " # token embeddings: (B,T,d_model)\n", 135 | " embeds = self.embed(token_ids)\n", 136 | "\n", 137 | "\n", 138 | " # base hidden states\n", 139 | " if init_hidden is None:\n", 140 | " h0_seq = embeds # use 
embeddings as base hidden\n", 141 | " else:\n", 142 | " h0_seq = init_hidden # user-provided base states\n", 143 | "\n", 144 | " outputs = [] # will hold (B,D,vocab_size) for each i\n", 145 | " # slide over positions where i + D < T\n", 146 | " max_i = T - self.num_heads - 1\n", 147 | " for i in range(0,max_i + 1):\n", 148 | " # previous hidden for depth 0 at pos i\n", 149 | " h_prev = h0_seq[:,i,:] # (B,d_model)\n", 150 | "\n", 151 | "\n", 152 | " # collect logits for all k at this i\n", 153 | "\n", 154 | " logits_k = []\n", 155 | "\n", 156 | " for k in range(self.num_heads):\n", 157 | " # future token embed at pos i + (k+ 1)\n", 158 | " future_pos = i + (k+1)\n", 159 | " tok_embed = embeds[:,future_pos,:] # (B,d_model)\n", 160 | "\n", 161 | " # 1) RMS-normalize\n", 162 | " h_norm = self.rmsnorm(h_prev) # (B,d_model)\n", 163 | " e_norm = self.rmsnorm(tok_embed) # (B,d_model)\n", 164 | "\n", 165 | " # 2) concatenate -> (B,2*d_model)\n", 166 | " merged = torch.cat([h_norm,e_norm],dim=-1)\n", 167 | "\n", 168 | " # 3) project back to d_model\n", 169 | " proj = self.projections[k](merged) # (B, d_model)\n", 170 | "\n", 171 | " # 4) Transformer block (expects shape (S,B,d_model))\n", 172 | " x = proj.unsqueeze(0) # (1,B,d_model)\n", 173 | " x = self.transformers[k](x) # (1,B,d_model)\n", 174 | " h_curr = x.squeeze(0) # (B,d_model)\n", 175 | "\n", 176 | " # 5) unembed -> logits\n", 177 | " logits = self.unembed(h_curr) # (B,vocab_size)\n", 178 | " logits_k.append(logits)\n", 179 | "\n", 180 | " # 6) chain hidden for next depth\n", 181 | " h_prev = h_curr\n", 182 | "\n", 183 | " # stack along. depth axis -> (B,D,vocab_size)\n", 184 | " logits_k = torch.stack(logits_k,dim=1)\n", 185 | " outputs.append(logits_k)\n", 186 | "\n", 187 | " # stack along sequence axis -> (T-D,B,D,V) then permute -> (B,T-D,D,V)\n", 188 | "\n", 189 | " out = torch.stack(outputs,dim=0)\n", 190 | " out = out.permute(1,0,2,3).contiguous()\n", 191 | " return out" 192 | ], 193 | "metadata": { 194 | "id": "y52_ypdIvq-c" 195 | }, 196 | "execution_count": null, 197 | "outputs": [] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "source": [ 202 | "## Step 3: Pass input tokens through the model and generate multiple next tokens." 
203 |  ], 204 |  "metadata": { 205 | "id": "cxMqwj0w7GdC" 206 | } 207 | }, 208 | { 209 | "cell_type": "code", 210 | "source": [ 211 | "batch_size, seq_len,d_model,vocab_size = 1,8,8,5000\n", 212 | "model = SimpleMTP(d_model=d_model,vocab_size=vocab_size,num_heads=3)\n", 213 | "tokens = torch.randint(0,vocab_size,(batch_size,seq_len))\n", 214 | "\n", 215 | "\n", 216 | "# Forward pass\n", 217 | "logits = model(tokens)\n", 218 | "# logits.shape == (1,8-3,3,5000) -> (batch_size,T-D,D,V)\n", 219 | "print(\"Logits shape:\",logits.shape)\n", 220 | "\n", 221 | "# If you want to inspect the 1-step-ahead prediction at position i=0:\n", 222 | "print(\"Head k=0 at i=0 logits:\",logits[0,0,0]) # a tensor of length vocab_size\n", 223 | "\n", 224 | "# Or to get all predictions at i=0 as token IDs:\n", 225 | "\n", 226 | "pred_ids = logits[0,0].argmax(dim=-1)\n", 227 | "print(\"Predicted tokens at i=0 for all heads:\",pred_ids) # a length-3 tensor" 228 | ], 229 | "metadata": { 230 | "colab": { 231 | "base_uri": "https://localhost:8080/" 232 | }, 233 | "id": "dhOmCR8X1yei", 234 | "outputId": "c8d9a829-b5b8-4236-8fd3-a4354059690b" 235 | }, 236 | "execution_count": null, 237 | "outputs": [ 238 | { 239 | "output_type": "stream", 240 | "name": "stdout", 241 | "text": [ 242 | "Logits shape: torch.Size([1, 5, 3, 5000])\n", 243 | "Head k=0 at i=0 logits: tensor([ 3.6052, 2.8964, -1.7114, ..., 1.4961, -3.3179, 1.1599],\n", 244 | " grad_fn=)\n", 245 | "Predicted tokens at i=0 for all heads: tensor([4207, 4708, 4765])\n" 246 | ] 247 | } 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "source": [ 253 | "## Step 4: Calculate loss between target tokens and predicted tokens" 254 | ], 255 | "metadata": { 256 | "id": "9GP9CvKg-VOT" 257 | } 258 | }, 259 | { 260 | "cell_type": "code", 261 | "source": [ 262 | "batch_size, seq_len, vocab_size = 1,8,5000\n", 263 | "\n", 264 | "# old (wrong): targets = torch.randint(0, vocab_size,(1,4))\n", 265 | "# new (right):\n", 266 | "\n", 267 | "targets = torch.randint(0,vocab_size,(batch_size,seq_len))\n", 268 | "print(\"targets.shape ->\",targets.shape) # torch.Size([1,8])\n", 269 | "\n", 270 | "\n", 271 | "# Now recompute:\n", 272 | "\n", 273 | "logits = model(tokens) # shape (1,5,3,5000)\n", 274 | "B,L,D,V = logits.shape # (1,5,3,5000)\n", 275 | "_,T = targets.shape # (1,8)\n", 276 | "assert L == T - D # 5 == 8 - 3, passes\n", 277 | "\n", 278 | "\n", 279 | "# Double-loop loss:\n", 280 | "loss = 0.0\n", 281 | "for i in range(L): # i = 0...4\n", 282 | " for k in range(D): # k = 0...2\n", 283 | " logits_ik = logits[:,i,k,:] # (1,5000)\n", 284 | " target_ik = targets[:,i + (k + 1)] # (1,)\n", 285 | " loss += F.cross_entropy(logits_ik,target_ik)\n", 286 | "\n", 287 | "loss = loss / (L*D)\n", 288 | "print(\"MTP loss:\",loss.item())\n", 289 | "\n", 290 | "\n" 291 | ], 292 | "metadata": { 293 | "colab": { 294 | "base_uri": "https://localhost:8080/" 295 | }, 296 | "id": "6fvZFMQy-yXk", 297 | "outputId": "ff4b3e52-7d7a-4cb6-f1c0-12ef0f800f9b" 298 | }, 299 | "execution_count": null, 300 | "outputs": [ 301 | { 302 | "output_type": "stream", 303 | "name": "stdout", 304 | "text": [ 305 | "targets.shape -> torch.Size([1, 8])\n", 306 | "MTP loss: 13.472195625305176\n" 307 | ] 308 | } 309 | ] 310 | } 311 | ] 312 | } -------------------------------------------------------------------------------- /prepare_data_fineweb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Memory-Efficient FineWeb-Edu Dataset Preprocessing 3 | 4 | This version 
processes large datasets without loading everything into memory at once. 5 | Key improvements: 6 | - Streaming tokenization 7 | - Direct file writing 8 | - Memory-efficient shuffling 9 | - Progress tracking 10 | - Optional quality filtering 11 | """ 12 | 13 | import os 14 | import logging 15 | import numpy as np 16 | import tiktoken 17 | from tqdm.auto import tqdm 18 | from datasets import load_dataset 19 | import random 20 | import re 21 | import pandas as pd 22 | import tempfile 23 | import mmap 24 | 25 | # ============================================================================ 26 | # CONFIGURATION SECTION 27 | # ============================================================================ 28 | 29 | # Main Dataset Processing Options 30 | TRAIN_ON_CUSTOM_ROWS = True 31 | CUSTOM_ROW_COUNT = 800000 32 | 33 | # Dataset Configuration 34 | DATASET_OPTIONS = [ 35 | ("HuggingFaceFW/fineweb-edu", "CC-MAIN-2024-51"), 36 | ] 37 | 38 | # Processing Parameters 39 | CONTEXT_LENGTH = 1024 40 | MIN_TEXT_LENGTH = 50 # Only used when USE_QUALITY_FILTERING = True 41 | TRAIN_SPLIT = 0.95 42 | RANDOM_SEED = 42 43 | USE_QUALITY_FILTERING = False # Set to False to process ALL rows, True to use filtering 44 | 45 | # Memory Management 46 | BATCH_SIZE = 8000 # Process in smaller batches 47 | BUFFER_SIZE = 75000 # Tokens to buffer before writing 48 | TEMP_DIR = "temp_tokens" # Directory for temporary files 49 | 50 | # File Output Configuration 51 | TRAIN_FILENAME = "train.bin" 52 | VALIDATION_FILENAME = "validation.bin" 53 | LOG_FILENAME = "dataset_preparation.log" 54 | 55 | # Tokenizer Configuration 56 | TOKENIZER_NAME = "gpt2" 57 | DTYPE = np.uint16 58 | 59 | # ============================================================================ 60 | 61 | # Setup logging 62 | logging.basicConfig( 63 | level=logging.INFO, 64 | format='%(asctime)s - %(message)s', 65 | handlers=[ 66 | logging.StreamHandler(), 67 | logging.FileHandler(LOG_FILENAME) 68 | ] 69 | ) 70 | logger = logging.getLogger(__name__) 71 | 72 | class MemoryEfficientTokenizer: 73 | """Tokenizes and saves data without loading everything into memory.""" 74 | 75 | def __init__(self, tokenizer_name, context_length, dtype): 76 | self.tokenizer = tiktoken.get_encoding(tokenizer_name) 77 | self.context_length = context_length 78 | self.dtype = dtype 79 | self.temp_files = [] 80 | 81 | def create_temp_dir(self): 82 | """Create temporary directory for intermediate files.""" 83 | if not os.path.exists(TEMP_DIR): 84 | os.makedirs(TEMP_DIR) 85 | 86 | def clean_temp_files(self): 87 | """Clean up temporary files.""" 88 | for temp_file in self.temp_files: 89 | if os.path.exists(temp_file): 90 | os.remove(temp_file) 91 | if os.path.exists(TEMP_DIR): 92 | try: 93 | os.rmdir(TEMP_DIR) 94 | except: 95 | pass 96 | 97 | def process_batch_to_temp_file(self, batch_samples, batch_id): 98 | """Process a batch of samples and save tokens to temporary file.""" 99 | temp_filename = os.path.join(TEMP_DIR, f"batch_{batch_id}.bin") 100 | token_buffer = [] 101 | 102 | for text in batch_samples: 103 | tokens = self.tokenizer.encode_ordinary(text) 104 | 105 | # Split into chunks of context_length 106 | for i in range(0, len(tokens), self.context_length): 107 | chunk = tokens[i:i + self.context_length] 108 | if len(chunk) == self.context_length: # Only complete chunks 109 | token_buffer.extend(chunk) 110 | 111 | # Write buffer to file when it gets large 112 | if len(token_buffer) >= BUFFER_SIZE: 113 | self._append_tokens_to_file(temp_filename, token_buffer) 114 | token_buffer = [] 115 | 
116 | # Write remaining tokens 117 | if token_buffer: 118 | self._append_tokens_to_file(temp_filename, token_buffer) 119 | 120 | self.temp_files.append(temp_filename) 121 | return temp_filename 122 | 123 | def _append_tokens_to_file(self, filename, tokens): 124 | """Append tokens to a binary file.""" 125 | token_array = np.array(tokens, dtype=self.dtype) 126 | with open(filename, 'ab') as f: 127 | token_array.tofile(f) 128 | 129 | def merge_temp_files(self, temp_files, output_filename): 130 | """Merge temporary files into final output file.""" 131 | logger.info(f"Merging {len(temp_files)} temporary files into {output_filename}") 132 | 133 | total_tokens = 0 134 | with open(output_filename, 'wb') as output_file: 135 | for temp_file in tqdm(temp_files, desc="Merging files"): 136 | if os.path.exists(temp_file): 137 | # Read and write in chunks to avoid memory issues 138 | with open(temp_file, 'rb') as f: 139 | while True: 140 | chunk = f.read(BUFFER_SIZE * 2) # Read in bytes 141 | if not chunk: 142 | break 143 | output_file.write(chunk) 144 | total_tokens += len(chunk) // 2 # 2 bytes per uint16 145 | 146 | logger.info(f"Merged {total_tokens:,} tokens into {output_filename}") 147 | return total_tokens 148 | 149 | class StreamingDataProcessor: 150 | """Process dataset in streaming fashion without loading all into memory.""" 151 | 152 | def __init__(self): 153 | self.tokenizer_processor = MemoryEfficientTokenizer( 154 | TOKENIZER_NAME, CONTEXT_LENGTH, DTYPE 155 | ) 156 | 157 | def clean_text(self, text): 158 | """Remove excessive whitespace and control characters.""" 159 | text = re.sub(r'\s+', ' ', text) 160 | text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t') 161 | return text.strip() 162 | 163 | def is_quality_text(self, text): 164 | """Apply quality filters to determine if text should be included.""" 165 | if not USE_QUALITY_FILTERING: 166 | return True # Accept all texts when filtering is disabled 167 | 168 | # Apply quality filters only when enabled 169 | return len(text.strip()) > MIN_TEXT_LENGTH 170 | 171 | def process_dataset_streaming(self, dataset): 172 | """Process dataset in streaming fashion with memory-efficient approach.""" 173 | logger.info(f"Processing {CUSTOM_ROW_COUNT:,} rows with streaming approach") 174 | logger.info(f"Quality filtering: {'ENABLED' if USE_QUALITY_FILTERING else 'DISABLED'}") 175 | 176 | # Create temporary directory 177 | self.tokenizer_processor.create_temp_dir() 178 | 179 | # Process data in batches 180 | train_temp_files = [] 181 | val_temp_files = [] 182 | 183 | current_batch = [] 184 | processed_count = 0 185 | accepted_count = 0 186 | train_count = 0 187 | val_count = 0 188 | batch_id = 0 189 | 190 | progress_bar = tqdm(total=CUSTOM_ROW_COUNT, desc="Processing samples") 191 | 192 | try: 193 | for example in dataset: 194 | if accepted_count >= CUSTOM_ROW_COUNT: 195 | break 196 | 197 | text = example.get('text', '') 198 | if not text: 199 | continue 200 | 201 | cleaned_text = self.clean_text(text) 202 | processed_count += 1 203 | 204 | if self.is_quality_text(cleaned_text): 205 | # Determine if this sample goes to train or validation 206 | if random.random() < TRAIN_SPLIT: 207 | current_batch.append(('train', cleaned_text)) 208 | train_count += 1 209 | else: 210 | current_batch.append(('val', cleaned_text)) 211 | val_count += 1 212 | 213 | accepted_count += 1 214 | progress_bar.update(1) 215 | 216 | # Process batch when it reaches target size 217 | if len(current_batch) >= BATCH_SIZE: 218 | 
self._process_current_batch(current_batch, batch_id, 219 | train_temp_files, val_temp_files) 220 | current_batch = [] 221 | batch_id += 1 222 | 223 | # Update progress info 224 | progress_bar.set_postfix({ 225 | 'train': train_count, 226 | 'val': val_count, 227 | 'batches': batch_id, 228 | 'processed': processed_count, 229 | 'accepted': accepted_count 230 | }) 231 | 232 | except Exception as e: 233 | logger.error(f"Error during processing: {e}") 234 | finally: 235 | progress_bar.close() 236 | 237 | # Process remaining samples 238 | if current_batch: 239 | self._process_current_batch(current_batch, batch_id, 240 | train_temp_files, val_temp_files) 241 | 242 | logger.info(f"Processed {processed_count:,} total samples") 243 | logger.info(f"Accepted {accepted_count:,} samples") 244 | logger.info(f"Training samples: {train_count:,}, Validation samples: {val_count:,}") 245 | 246 | return train_temp_files, val_temp_files 247 | 248 | def _process_current_batch(self, batch, batch_id, train_temp_files, val_temp_files): 249 | """Process current batch and save to appropriate temporary files.""" 250 | train_samples = [text for split, text in batch if split == 'train'] 251 | val_samples = [text for split, text in batch if split == 'val'] 252 | 253 | if train_samples: 254 | train_temp_file = self.tokenizer_processor.process_batch_to_temp_file( 255 | train_samples, f"train_{batch_id}" 256 | ) 257 | train_temp_files.append(train_temp_file) 258 | 259 | if val_samples: 260 | val_temp_file = self.tokenizer_processor.process_batch_to_temp_file( 261 | val_samples, f"val_{batch_id}" 262 | ) 263 | val_temp_files.append(val_temp_file) 264 | 265 | def finalize_dataset(self, train_temp_files, val_temp_files): 266 | """Merge temporary files into final dataset files.""" 267 | # Merge training files 268 | if train_temp_files: 269 | train_tokens = self.tokenizer_processor.merge_temp_files( 270 | train_temp_files, TRAIN_FILENAME 271 | ) 272 | else: 273 | train_tokens = 0 274 | 275 | # Merge validation files 276 | if val_temp_files: 277 | val_tokens = self.tokenizer_processor.merge_temp_files( 278 | val_temp_files, VALIDATION_FILENAME 279 | ) 280 | else: 281 | val_tokens = 0 282 | 283 | # Clean up temporary files 284 | self.tokenizer_processor.clean_temp_files() 285 | 286 | return train_tokens, val_tokens 287 | 288 | def load_dataset_with_fallbacks(): 289 | """Load dataset with fallback options.""" 290 | for dataset_name, config_name in DATASET_OPTIONS: 291 | try: 292 | logger.info(f"Attempting to load {dataset_name}") 293 | 294 | if config_name: 295 | ds = load_dataset(dataset_name, name=config_name, split="train", streaming=True) 296 | else: 297 | ds = load_dataset(dataset_name, split="train", streaming=True) 298 | 299 | logger.info(f"Successfully loaded {dataset_name}") 300 | return ds, dataset_name 301 | 302 | except Exception as e: 303 | logger.warning(f"Failed to load {dataset_name}: {str(e)[:100]}") 304 | continue 305 | 306 | raise RuntimeError("Could not load any dataset. 
Check internet connection.") 307 | 308 | def verify_files(): 309 | """Verify the created binary files.""" 310 | logger.info("Verifying output files") 311 | 312 | if not (os.path.exists(TRAIN_FILENAME) and os.path.exists(VALIDATION_FILENAME)): 313 | logger.error("Binary files not found") 314 | return False 315 | 316 | try: 317 | train_data = np.fromfile(TRAIN_FILENAME, dtype=DTYPE) 318 | val_data = np.fromfile(VALIDATION_FILENAME, dtype=DTYPE) 319 | 320 | logger.info(f"Training tokens: {len(train_data):,}") 321 | logger.info(f"Validation tokens: {len(val_data):,}") 322 | logger.info(f"Training sequences: {len(train_data) // CONTEXT_LENGTH:,}") 323 | logger.info(f"Validation sequences: {len(val_data) // CONTEXT_LENGTH:,}") 324 | 325 | return True 326 | 327 | except Exception as e: 328 | logger.error(f"Verification failed: {e}") 329 | return False 330 | 331 | def main(): 332 | """Main function to prepare the dataset with memory-efficient processing.""" 333 | logger.info("Starting memory-efficient FineWeb-edu dataset preparation") 334 | logger.info(f"Configuration: {CUSTOM_ROW_COUNT:,} rows, context length {CONTEXT_LENGTH}") 335 | logger.info(f"Batch size: {BATCH_SIZE}, Buffer size: {BUFFER_SIZE}") 336 | logger.info(f"Quality filtering: {'ENABLED' if USE_QUALITY_FILTERING else 'DISABLED'}") 337 | 338 | # Set random seed 339 | random.seed(RANDOM_SEED) 340 | 341 | # Check if files already exist 342 | if os.path.exists(TRAIN_FILENAME) and os.path.exists(VALIDATION_FILENAME): 343 | logger.info("Dataset files already exist") 344 | verify_files() 345 | return 346 | 347 | # Load dataset 348 | dataset, dataset_name = load_dataset_with_fallbacks() 349 | logger.info(f"Using dataset: {dataset_name}") 350 | 351 | # Initialize processor 352 | processor = StreamingDataProcessor() 353 | 354 | # Process dataset 355 | train_temp_files, val_temp_files = processor.process_dataset_streaming(dataset) 356 | 357 | # Finalize dataset 358 | train_tokens, val_tokens = processor.finalize_dataset(train_temp_files, val_temp_files) 359 | 360 | # Summary 361 | logger.info("Dataset preparation completed") 362 | logger.info(f"Training tokens: {train_tokens:,}") 363 | logger.info(f"Validation tokens: {val_tokens:,}") 364 | logger.info(f"Ready for model training") 365 | 366 | # Verify output 367 | verify_files() 368 | 369 | if __name__ == "__main__": 370 | main() -------------------------------------------------------------------------------- /notebooks/Mixture_of_Experts_from_Scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "60-Qwx_hcDp0", 11 | "outputId": "ee0aa230-6dad-4298-b657-ed0097f94c8b" 12 | }, 13 | "outputs": [ 14 | { 15 | "output_type": "execute_result", 16 | "data": { 17 | "text/plain": [ 18 | "" 19 | ] 20 | }, 21 | "metadata": {}, 22 | "execution_count": 72 23 | } 24 | ], 25 | "source": [ 26 | "# Import the necessary packages and set the seed for reproducibility.\n", 27 | "\n", 28 | "import torch\n", 29 | "import torch.nn as nn\n", 30 | "from torch.nn import functional as F\n", 31 | "torch.manual_seed(42)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "4qAclO2Hoiu7" 38 | }, 39 | "source": [ 40 | "# Expert Module" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "id": "In-gJA9hoPk5" 48 | }, 49 | "outputs": [], 50 
| "source": [ 51 | "class Expert(nn.Module):\n", 52 | " \"\"\" An MLP is a simple linear layer followed by a non-linearity i.e each Expert \"\"\"\n", 53 | "\n", 54 | " def __init__(self,n_embd):\n", 55 | " super().__init__()\n", 56 | " self.net = nn.Sequential(\n", 57 | " nn.Linear(n_embd,4*n_embd),\n", 58 | " nn.ReLU(),\n", 59 | " nn.Linear(4*n_embd,n_embd),\n", 60 | " nn.Dropout(dropout),\n", 61 | " )\n", 62 | "\n", 63 | " def forward(self,x):\n", 64 | " return self.net(x)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "colab": { 72 | "base_uri": "https://localhost:8080/" 73 | }, 74 | "id": "eFwE10WJkApd", 75 | "outputId": "336ec790-96c4-4683-b481-dcd9be23167e" 76 | }, 77 | "outputs": [ 78 | { 79 | "output_type": "stream", 80 | "name": "stdout", 81 | "text": [ 82 | "--2025-06-30 17:48:35-- https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt\n", 83 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", 84 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", 85 | "HTTP request sent, awaiting response... 200 OK\n", 86 | "Length: 1115394 (1.1M) [text/plain]\n", 87 | "Saving to: ‘input.txt.3’\n", 88 | "\n", 89 | "input.txt.3 100%[===================>] 1.06M --.-KB/s in 0.005s \n", 90 | "\n", 91 | "2025-06-30 17:48:35 (206 MB/s) - ‘input.txt.3’ saved [1115394/1115394]\n", 92 | "\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "mJBLPhTwP7Zi" 104 | }, 105 | "source": [ 106 | "# Step 2: Implement a Router" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "colab": { 114 | "base_uri": "https://localhost:8080/" 115 | }, 116 | "id": "6fjVIKAfP-0K", 117 | "outputId": "4306a632-a42d-445b-8a6f-4bde20e084a1" 118 | }, 119 | "outputs": [ 120 | { 121 | "output_type": "stream", 122 | "name": "stdout", 123 | "text": [ 124 | "tensor([[[-0.8934, 0.1072, 0.5144, -0.2811],\n", 125 | " [-0.8497, -0.9384, -0.1564, -0.3949],\n", 126 | " [-0.1428, 0.6368, 0.3035, 0.9877],\n", 127 | " [-0.3160, -0.8139, -0.3333, -0.0203]]], grad_fn=)\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# Understanding how gating works\n", 133 | "\n", 134 | "num_experts = 4\n", 135 | "top_k= 2\n", 136 | "n_embed = 32\n", 137 | "\n", 138 | "# Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length = 4\n", 139 | "mh_output = torch.randn(1,4,n_embed)\n", 140 | "\n", 141 | "topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32,4)\n", 142 | "\n", 143 | "logits = topkgate_linear(mh_output)\n", 144 | "\n", 145 | "print(logits)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": { 151 | "id": "bl1JAC5rR8Oy" 152 | }, 153 | "source": [ 154 | "# Step 3: Implement topk load balancing" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/" 163 | }, 164 | "id": "kqKQrLXiRMPg", 165 | "outputId": "43082836-93cf-48ad-be6d-db09513816ec" 166 | }, 167 | "outputs": [ 168 | { 169 | "output_type": "execute_result", 170 | "data": { 171 | "text/plain": [ 172 | "(tensor([[[ 0.5144, 0.1072],\n", 173 | " [-0.1564, -0.3949],\n", 174 | " [ 0.9877, 
0.6368],\n", 175 | " [-0.0203, -0.3160]]], grad_fn=),\n", 176 | " tensor([[[2, 1],\n", 177 | " [2, 3],\n", 178 | " [3, 1],\n", 179 | " [3, 0]]]))" 180 | ] 181 | }, 182 | "metadata": {}, 183 | "execution_count": 76 184 | } 185 | ], 186 | "source": [ 187 | "top_k_logits, top_k_indices = logits.topk(top_k,dim=-1) # Get top-k experts\n", 188 | "top_k_logits, top_k_indices" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": { 194 | "id": "KN9MqZftSaWy" 195 | }, 196 | "source": [ 197 | "# Step 4: Use -inf and apply softmax\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": { 204 | "colab": { 205 | "base_uri": "https://localhost:8080/" 206 | }, 207 | "id": "2sYkYIfHScvP", 208 | "outputId": "ea23bb19-380b-457b-f4f3-d26f2599d7c0" 209 | }, 210 | "outputs": [ 211 | { 212 | "output_type": "execute_result", 213 | "data": { 214 | "text/plain": [ 215 | "tensor([[[ -inf, 0.1072, 0.5144, -inf],\n", 216 | " [ -inf, -inf, -0.1564, -0.3949],\n", 217 | " [ -inf, 0.6368, -inf, 0.9877],\n", 218 | " [-0.3160, -inf, -inf, -0.0203]]], grad_fn=)" 219 | ] 220 | }, 221 | "metadata": {}, 222 | "execution_count": 77 223 | } 224 | ], 225 | "source": [ 226 | "zeros = torch.full_like(logits,float('-inf')) # full_like clones a tensor and fills it with a specified value\n", 227 | "sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)\n", 228 | "sparse_logits" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": { 236 | "base_uri": "https://localhost:8080/" 237 | }, 238 | "id": "WgqU9QJOVI6N", 239 | "outputId": "b0b4027d-b3e0-4253-ad18-d4c9d3920d82" 240 | }, 241 | "outputs": [ 242 | { 243 | "output_type": "execute_result", 244 | "data": { 245 | "text/plain": [ 246 | "tensor([[[0.0000, 0.3996, 0.6004, 0.0000],\n", 247 | " [0.0000, 0.0000, 0.5594, 0.4406],\n", 248 | " [0.0000, 0.4132, 0.0000, 0.5868],\n", 249 | " [0.4266, 0.0000, 0.0000, 0.5734]]], grad_fn=)" 250 | ] 251 | }, 252 | "metadata": {}, 253 | "execution_count": 78 254 | } 255 | ], 256 | "source": [ 257 | "gating_output = F.softmax(sparse_logits,dim=-1)\n", 258 | "gating_output" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "pHN4YFSsVSAU" 265 | }, 266 | "source": [ 267 | "# Step 5: Create a class for TopKRouting" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": { 274 | "id": "SYIqxvEcVPAa" 275 | }, 276 | "outputs": [], 277 | "source": [ 278 | "# First define the top k router module\n", 279 | "class TopkRouter(nn.Module):\n", 280 | " def __init__(self, n_embed, num_experts, top_k):\n", 281 | " super(TopkRouter, self).__init__()\n", 282 | " self.top_k = top_k\n", 283 | " self.linear = nn.Linear(n_embed, num_experts)\n", 284 | "\n", 285 | " def forward(self, mh_output):\n", 286 | " # mh_output is the output tensor from the multihead self attention block\n", 287 | " logits = self.linear(mh_output)\n", 288 | " top_k_logits, indices = logits.topk(self.top_k, dim=-1)\n", 289 | " zeros = torch.full_like(logits, float('-inf'))\n", 290 | " sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n", 291 | " router_output = F.softmax(sparse_logits, dim=-1)\n", 292 | " return router_output, indices\n" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": { 299 | "colab": { 300 | "base_uri": "https://localhost:8080/" 301 | }, 302 | "id": "4SlPZzX0XMhK", 303 | "outputId": 
"52c7ed93-f2e4-4a23-90de-b9685199f827" 304 | }, 305 | "outputs": [ 306 | { 307 | "output_type": "execute_result", 308 | "data": { 309 | "text/plain": [ 310 | "(torch.Size([2, 4, 4]),\n", 311 | " tensor([[[0.0000, 0.0000, 0.5321, 0.4679],\n", 312 | " [0.0000, 0.8659, 0.1341, 0.0000],\n", 313 | " [0.0000, 0.4096, 0.0000, 0.5904],\n", 314 | " [0.0000, 0.0000, 0.5828, 0.4172]],\n", 315 | " \n", 316 | " [[0.6210, 0.0000, 0.3790, 0.0000],\n", 317 | " [0.7005, 0.0000, 0.2995, 0.0000],\n", 318 | " [0.0000, 0.8630, 0.1370, 0.0000],\n", 319 | " [0.0000, 0.3759, 0.6241, 0.0000]]], grad_fn=),\n", 320 | " tensor([[[2, 3],\n", 321 | " [1, 2],\n", 322 | " [3, 1],\n", 323 | " [2, 3]],\n", 324 | " \n", 325 | " [[0, 2],\n", 326 | " [0, 2],\n", 327 | " [1, 2],\n", 328 | " [2, 1]]]))" 329 | ] 330 | }, 331 | "metadata": {}, 332 | "execution_count": 80 333 | } 334 | ], 335 | "source": [ 336 | "#Testing this out:\n", 337 | "num_experts = 4\n", 338 | "top_k = 2\n", 339 | "n_embd = 32\n", 340 | "\n", 341 | "mh_output = torch.randn(2, 4, n_embd) # Example input\n", 342 | "top_k_gate = TopkRouter(n_embd, num_experts, top_k)\n", 343 | "gating_output, indices = top_k_gate(mh_output)\n", 344 | "gating_output.shape, gating_output, indices\n" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": { 350 | "id": "52WYwFrpZpSd" 351 | }, 352 | "source": [ 353 | "# Step 6 Create. a class for Noisy Topk Routing" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "8yhBaSC9bUFT" 360 | }, 361 | "source": [ 362 | "Noisy top-k Gating is an important tool in training MoE models.\n", 363 | "\n", 364 | "Essentially, you don't want all the tokens to be sent to the same set of 'favored' experts.\n", 365 | "\n", 366 | "You want a fine balance of exploitation and exploraation. For this purpose, to load balance, it is helps logits from the gating linear layer. 
This makes training more efficient" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": { 373 | "id": "gVd8OFafXy4o" 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "#Changing the above to accommodate noisy top-k gating\n", 378 | "class NoisyTopkRouter(nn.Module):\n", 379 | " def __init__(self, n_embed, num_experts, top_k):\n", 380 | " super(NoisyTopkRouter, self).__init__()\n", 381 | " self.top_k = top_k\n", 382 | " #layer for router logits\n", 383 | " self.topkroute_linear = nn.Linear(n_embed, num_experts)\n", 384 | " self.noise_linear = nn.Linear(n_embed, num_experts)\n", 385 | "\n", 386 | "\n", 387 | " def forward(self, mh_output):\n", 388 | " # mh_output is the output tensor from the multihead self attention block\n", 389 | " logits = self.topkroute_linear(mh_output)\n", 390 | "\n", 391 | " #Noise logits\n", 392 | " noise_logits = self.noise_linear(mh_output)\n", 393 | "\n", 394 | " #Adding scaled unit gaussian noise to the logits\n", 395 | " noise = torch.randn_like(logits)*F.softplus(noise_logits)\n", 396 | " noisy_logits = logits + noise\n", 397 | "\n", 398 | " top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)\n", 399 | " zeros = torch.full_like(noisy_logits, float('-inf'))\n", 400 | " sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n", 401 | " router_output = F.softmax(sparse_logits, dim=-1)\n", 402 | " return router_output, indices\n" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "colab": { 410 | "base_uri": "https://localhost:8080/" 411 | }, 412 | "id": "wKlwNFjtn0_P", 413 | "outputId": "a61ff7bd-8ded-404f-bb2d-9f4ab65e5060" 414 | }, 415 | "outputs": [ 416 | { 417 | "output_type": "execute_result", 418 | "data": { 419 | "text/plain": [ 420 | "(torch.Size([2, 4, 8]),\n", 421 | " tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6640, 0.3360],\n", 422 | " [0.4700, 0.0000, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 423 | " [0.0000, 0.0000, 0.4277, 0.0000, 0.0000, 0.0000, 0.5723, 0.0000],\n", 424 | " [0.0000, 0.0000, 0.0000, 0.0000, 0.4105, 0.0000, 0.0000, 0.5895]],\n", 425 | " \n", 426 | " [[0.0000, 0.4177, 0.0000, 0.0000, 0.5823, 0.0000, 0.0000, 0.0000],\n", 427 | " [0.0000, 0.6713, 0.0000, 0.0000, 0.3287, 0.0000, 0.0000, 0.0000],\n", 428 | " [0.0000, 0.1670, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8330],\n", 429 | " [0.4322, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5678]]],\n", 430 | " grad_fn=),\n", 431 | " tensor([[[6, 7],\n", 432 | " [2, 0],\n", 433 | " [6, 2],\n", 434 | " [7, 4]],\n", 435 | " \n", 436 | " [[4, 1],\n", 437 | " [1, 4],\n", 438 | " [7, 1],\n", 439 | " [7, 0]]]))" 440 | ] 441 | }, 442 | "metadata": {}, 443 | "execution_count": 82 444 | } 445 | ], 446 | "source": [ 447 | "#Testing this out, again:\n", 448 | "num_experts = 8\n", 449 | "top_k = 2\n", 450 | "n_embd = 16\n", 451 | "\n", 452 | "mh_output = torch.randn(2, 4, n_embd) # Example input\n", 453 | "noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)\n", 454 | "gating_output, indices = noisy_top_k_gate(mh_output)\n", 455 | "gating_output.shape, gating_output, indices\n", 456 | "\n" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": { 463 | "id": "M-fNH0N3oFX0" 464 | }, 465 | "outputs": [], 466 | "source": [ 467 | "class SparseMoE(nn.Module):\n", 468 | " def __init__(self, n_embed, num_experts, top_k):\n", 469 | " super(SparseMoE, self).__init__()\n", 470 | " self.router = 
NoisyTopkRouter(n_embed, num_experts, top_k)\n", 471 | " self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n", 472 | " self.top_k = top_k\n", 473 | "\n", 474 | " def forward(self, x):\n", 475 | " gating_output, indices = self.router(x)\n", 476 | " final_output = torch.zeros_like(x)\n", 477 | "\n", 478 | " # Reshape inputs for batch processing\n", 479 | " flat_x = x.view(-1, x.size(-1))\n", 480 | " flat_gating_output = gating_output.view(-1, gating_output.size(-1))\n", 481 | "\n", 482 | " # Process each expert in parallel\n", 483 | " for i, expert in enumerate(self.experts):\n", 484 | " # Create a mask for the inputs where the current expert is in top-k\n", 485 | " expert_mask = (indices == i).any(dim=-1)\n", 486 | " flat_mask = expert_mask.view(-1)\n", 487 | "\n", 488 | " if flat_mask.any():\n", 489 | " expert_input = flat_x[flat_mask]\n", 490 | " expert_output = expert(expert_input)\n", 491 | "\n", 492 | " # Extract and apply gating scores\n", 493 | " gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n", 494 | " weighted_output = expert_output * gating_scores\n", 495 | "\n", 496 | " # Update final output additively by indexing and adding\n", 497 | " final_output[expert_mask] += weighted_output.squeeze(1)\n", 498 | "\n", 499 | " return final_output\n" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "colab": { 507 | "base_uri": "https://localhost:8080/" 508 | }, 509 | "id": "o8x7da7BzGe5", 510 | "outputId": "fdb9974c-7bf2-4680-95b4-048957e7322b" 511 | }, 512 | "outputs": [ 513 | { 514 | "output_type": "stream", 515 | "name": "stdout", 516 | "text": [ 517 | "Shape of the final output: torch.Size([1, 4, 16])\n", 518 | "tensor([[[ 0.4342, -0.4548, -0.0086, -0.0395, -0.1140, -0.0925, -0.2745,\n", 519 | " -0.0195, -0.3261, 0.1218, -0.0654, 0.1763, -0.1743, -0.2058,\n", 520 | " -0.1101, 0.1939],\n", 521 | " [-0.0275, -0.1477, -0.0534, 0.4124, 0.4952, -0.4857, -0.2878,\n", 522 | " -0.1663, 0.3566, 0.3007, 0.0182, 0.4281, 0.2393, 0.4560,\n", 523 | " -0.1753, 0.0000],\n", 524 | " [-0.0070, 0.0822, 0.0469, 0.2413, -0.3832, -0.0366, -0.1590,\n", 525 | " -0.1679, -0.2089, 0.0903, 0.2140, 0.2162, -0.0403, -0.0554,\n", 526 | " -0.1221, 0.2234],\n", 527 | " [-0.1403, -0.1466, 0.3349, -0.1171, 0.8054, 0.3437, -0.2110,\n", 528 | " 0.2405, -0.0367, 0.3583, -0.0914, -0.3995, -0.4149, -0.0041,\n", 529 | " 0.0000, -0.0971]]], grad_fn=)\n" 530 | ] 531 | } 532 | ], 533 | "source": [ 534 | "import torch\n", 535 | "import torch.nn as nn\n", 536 | "\n", 537 | "#Let's test this out\n", 538 | "num_experts = 8\n", 539 | "top_k = 2\n", 540 | "n_embd = 16\n", 541 | "dropout=0.1\n", 542 | "\n", 543 | "mh_output = torch.randn(1, 4, n_embd) # Example multi-head attention output\n", 544 | "sparse_moe = SparseMoE(n_embd, num_experts, top_k)\n", 545 | "final_output = sparse_moe(mh_output)\n", 546 | "print(\"Shape of the final output:\", final_output.shape)\n", 547 | "print(final_output)\n" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": { 553 | "id": "9RBnw49ezXYG" 554 | }, 555 | "source": [ 556 | "# Step 8: Putting together all the building blocks of MoE" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "metadata": { 563 | "id": "V6J9Qsa7zGyx" 564 | }, 565 | "outputs": [], 566 | "source": [ 567 | "#Create a self attention + mixture of experts block, that may be repeated several times\n", 568 | "class Block(nn.Module):\n", 569 | " \"\"\" Mixture of 
Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) \"\"\"\n", 570 | "\n", 571 | " def __init__(self, n_embed, n_head, num_experts, top_k):\n", 572 | " # n_embed: embedding dimension, n_head: the number of heads we'd like\n", 573 | " super().__init__()\n", 574 | " head_size = n_embed // n_head\n", 575 | " self.sa = MultiHeadAttention(n_head, head_size)\n", 576 | " self.smoe = SparseMoE(n_embed, num_experts, top_k)\n", 577 | " self.ln1 = nn.LayerNorm(n_embed)\n", 578 | " self.ln2 = nn.LayerNorm(n_embed)\n", 579 | "\n", 580 | " def forward(self, x):\n", 581 | " x = x + self.sa(self.ln1(x))\n", 582 | " x = x + self.smoe(self.ln2(x))\n", 583 | " return x\n" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "metadata": { 590 | "id": "ta_Hf1SQ0qwe" 591 | }, 592 | "outputs": [], 593 | "source": [ 594 | "class SparseMoELanguageModel(nn.Module):\n", 595 | "\n", 596 | " def __init__(self):\n", 597 | " super().__init__()\n", 598 | " # each token directly reads off the logits for the next token from a lookup table\n", 599 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n", 600 | " self.position_embedding_table = nn.Embedding(block_size, n_embed)\n", 601 | " self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n", 602 | " self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n", 603 | " self.lm_head = nn.Linear(n_embed, vocab_size)\n", 604 | "\n", 605 | " def forward(self, idx, targets=None):\n", 606 | " B, T = idx.shape\n", 607 | "\n", 608 | " # idx and targets are both (B,T) tensor of integers\n", 609 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", 610 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", 611 | " x = tok_emb + pos_emb # (B,T,C)\n", 612 | " x = self.blocks(x) # (B,T,C)\n", 613 | " x = self.ln_f(x) # (B,T,C)\n", 614 | " logits = self.lm_head(x) # (B,T,vocab_size)\n", 615 | "\n", 616 | " if targets is None:\n", 617 | " loss = None\n", 618 | " else:\n", 619 | " B, T, C = logits.shape\n", 620 | " logits = logits.view(B*T, C)\n", 621 | " targets = targets.view(B*T)\n", 622 | " loss = F.cross_entropy(logits, targets)\n", 623 | "\n", 624 | " return logits, loss\n", 625 | "\n", 626 | " def generate(self, idx, max_new_tokens):\n", 627 | " # idx is (B, T) array of indices in the current context\n", 628 | " for _ in range(max_new_tokens):\n", 629 | " # crop idx to the last block_size tokens\n", 630 | " idx_cond = idx[:, -block_size:]\n", 631 | " # get the predictions\n", 632 | " logits, loss = self(idx_cond)\n", 633 | " # focus only on the last time step\n", 634 | " logits = logits[:, -1, :] # becomes (B, C)\n", 635 | " # apply softmax to get probabilities\n", 636 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 637 | " # sample from the distribution\n", 638 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 639 | " # append sampled index to the running sequence\n", 640 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 641 | " return idx\n" 642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": { 647 | "id": "9skcec_600xP" 648 | }, 649 | "source": [ 650 | "# Step 9: Code the entire transformer block: (Multi-Head Attention)" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": { 657 | "id": "pt0W8MNj0rFd" 658 | }, 659 | "outputs": [], 660 | "source": [ 661 | "class 
Head(nn.Module):\n", 662 | " \"\"\" one head of self-attention \"\"\"\n", 663 | "\n", 664 | " def __init__(self,head_size):\n", 665 | " super().__init__()\n", 666 | " self.key = nn.Linear(n_embed,head_size,bias=False)\n", 667 | " self.query = nn.Linear(n_embed,head_size,bias=False)\n", 668 | " self.value = nn.Linear(n_embed,head_size,bias=False)\n", 669 | "\n", 670 | " self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))\n", 671 | " self.dropout = nn.Dropout(dropout)\n", 672 | "\n", 673 | "\n", 674 | "\n", 675 | " def forward(self,x):\n", 676 | " B,T,C = x.shape\n", 677 | " k = self.key(x)\n", 678 | " q = self.query(x)\n", 679 | " # Compute attention scores (\"affinities\")\n", 680 | " wei = q @ k.transpose(-2,-1)*C**-0.5 # (B,T,C) @ (B,C,T) -> (B,T,T)\n", 681 | " wei = wei.masked_fill(self.tril[:T,:T]== 0,float('-inf')) # (B,T,T)\n", 682 | " wei = F.softmax(wei,dim=-1)\n", 683 | " wei = self.dropout(wei)\n", 684 | "\n", 685 | " # perform the weighted aggregation of the values\n", 686 | " v = self.value(x) # (B,T,C)\n", 687 | " out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)\n", 688 | " return out\n", 689 | "\n", 690 | "\n", 691 | "# Multi-Headed Self Attention\n", 692 | "\n", 693 | "class MultiHeadAttention(nn.Module):\n", 694 | " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", 695 | "\n", 696 | "\n", 697 | " def __init__(self,num_heads,head_size):\n", 698 | " super().__init__()\n", 699 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", 700 | " self.proj = nn.Linear(n_embed,n_embed)\n", 701 | " self.dropout = nn.Dropout(dropout)\n", 702 | "\n", 703 | "\n", 704 | " def forward(self,x):\n", 705 | " out = torch.cat([h(x) for h in self.heads], dim=-1)\n", 706 | " out = self.dropout(self.proj(out))\n", 707 | " return out\n", 708 | "\n", 709 | "\n" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": { 715 | "id": "EYZ9CqedqXjy" 716 | }, 717 | "source": [ 718 | "# Step 10: Code the entire transformer block: (Assemble all layers)" 719 | ] 720 | }, 721 | { 722 | "cell_type": "markdown", 723 | "metadata": { 724 | "id": "l-qt80KRq7SO" 725 | }, 726 | "source": [ 727 | "Start by defining a Self-Attention combined with Mixture-of-Experts (MoE) block, designed to be modular and reusable across multiple layers of the model.\n", 728 | "(For better clarity, key architectural parameters are copied and reused.)" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": null, 734 | "metadata": { 735 | "id": "2bKBebaxqWC9" 736 | }, 737 | "outputs": [], 738 | "source": [] 739 | }, 740 | { 741 | "cell_type": "markdown", 742 | "metadata": { 743 | "id": "TOngSSo8wgkN" 744 | }, 745 | "source": [ 746 | "# Step 11: Define entire language model architecture\n" 747 | ] 748 | }, 749 | { 750 | "cell_type": "markdown", 751 | "metadata": { 752 | "id": "nVwVQwFVwnQS" 753 | }, 754 | "source": [ 755 | "Finally, putting it all together to create a sparse mixture of experts language model" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": { 762 | "id": "gQKFUAWFrhl6" 763 | }, 764 | "outputs": [], 765 | "source": [ 766 | "class SparseMoELanguageModel(nn.Module):\n", 767 | "\n", 768 | " def __init__(self):\n", 769 | " super().__init__()\n", 770 | " # each token directly reads off the logits for the next token from a lookup table\n", 771 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n", 772 | " self.position_embedding_table = nn.Embedding(block_size, 
n_embed)\n", 773 | " self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n", 774 | " self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n", 775 | " self.lm_head = nn.Linear(n_embed, vocab_size)\n", 776 | "\n", 777 | " def forward(self, idx, targets=None):\n", 778 | " B, T = idx.shape\n", 779 | "\n", 780 | " # idx and targets are both (B,T) tensor of integers\n", 781 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", 782 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", 783 | " x = tok_emb + pos_emb # (B,T,C)\n", 784 | " x = self.blocks(x) # (B,T,C)\n", 785 | " x = self.ln_f(x) # (B,T,C)\n", 786 | " logits = self.lm_head(x) # (B,T,vocab_size)\n", 787 | "\n", 788 | " if targets is None:\n", 789 | " loss = None\n", 790 | " else:\n", 791 | " B, T, C = logits.shape\n", 792 | " logits = logits.view(B*T, C)\n", 793 | " targets = targets.view(B*T)\n", 794 | " loss = F.cross_entropy(logits, targets)\n", 795 | "\n", 796 | " return logits, loss\n", 797 | "\n", 798 | " def generate(self, idx, max_new_tokens):\n", 799 | " # idx is (B, T) array of indices in the current context\n", 800 | " for _ in range(max_new_tokens):\n", 801 | " # crop idx to the last block_size tokens\n", 802 | " idx_cond = idx[:, -block_size:]\n", 803 | " # get the predictions\n", 804 | " logits, loss = self(idx_cond)\n", 805 | " # focus only on the last time step\n", 806 | " logits = logits[:, -1, :] # becomes (B, C)\n", 807 | " # apply softmax to get probabilities\n", 808 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 809 | " # sample from the distribution\n", 810 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 811 | " # append sampled index to the running sequence\n", 812 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 813 | " return idx\n" 814 | ] 815 | }, 816 | { 817 | "cell_type": "markdown", 818 | "metadata": { 819 | "id": "d4ljt-EB0vi8" 820 | }, 821 | "source": [ 822 | "# Step 12: Create training and testing data" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": null, 828 | "metadata": { 829 | "id": "Z7qpudyh0nhD" 830 | }, 831 | "outputs": [], 832 | "source": [ 833 | "torch.manual_seed(42)\n", 834 | "\n", 835 | "with open('input.txt','r',encoding='utf-8') as f:\n", 836 | " text = f.read()\n", 837 | "\n", 838 | "# here are all the unique characters that occur in this text\n", 839 | "\n", 840 | "chars = sorted(list(set(text)))\n", 841 | "vocab_size = len(chars)\n", 842 | "# create a mapping from characters to integers\n", 843 | "\n", 844 | "stoi = {ch:i for i,ch in enumerate(chars)}\n", 845 | "itos = {i:ch for i,ch in enumerate(chars)}\n", 846 | "encode = lambda s: [stoi[c] for c in s] # encoder: take a string,output a list of integers\n", 847 | "decode = lambda l: ''.join(itos[i] for i in l) # # decoder: take a list of integers, output a string\n", 848 | "\n", 849 | "\n", 850 | "# Train and test splits\n", 851 | "\n", 852 | "data = torch.tensor(encode(text),dtype=torch.long)\n", 853 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n", 854 | "train_data = data[:n]\n", 855 | "val_data = data[n:]\n", 856 | "\n", 857 | "\n", 858 | "# data loading\n", 859 | "\n", 860 | "def get_batch(split):\n", 861 | " # generate a small batch of data of inputs x and target y\n", 862 | " data = train_data if split == 'train' else val_data\n", 863 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 864 | "\n", 865 | " # Input: tokens at 
positions [i, i+1, ..., i+block_size-1]\n", 866 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 867 | "\n", 868 | " # Target: tokens at positions [i+1, i+2, ..., i+block_size]\n", 869 | " # This is the \"next token\" for each input token\n", 870 | " y = torch.stack([data[i+1:i+1+block_size] for i in ix])\n", 871 | "\n", 872 | " x, y = x.to(device), y.to(device)\n", 873 | " return x, y\n", 874 | "\n", 875 | "\n", 876 | "\n" 877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "metadata": { 882 | "id": "NfX2k052LtFQ" 883 | }, 884 | "source": [ 885 | "# Step 13: Define LLM Loss" 886 | ] 887 | }, 888 | { 889 | "cell_type": "code", 890 | "execution_count": null, 891 | "metadata": { 892 | "id": "Fp7_Ea610brK" 893 | }, 894 | "outputs": [], 895 | "source": [ 896 | "@torch.no_grad()\n", 897 | "def estimate_loss():\n", 898 | " out = {}\n", 899 | " model.eval()\n", 900 | " for split in ['train','val']:\n", 901 | " losses = torch.zeros(eval_iters)\n", 902 | " for k in range(eval_iters):\n", 903 | " X,Y = get_batch(split)\n", 904 | " logits, loss = model(X,Y)\n", 905 | " losses[k] = loss.item()\n", 906 | " out[split] = losses.mean()\n", 907 | " model.train()\n", 908 | " return out\n" 909 | ] 910 | }, 911 | { 912 | "cell_type": "markdown", 913 | "metadata": { 914 | "id": "0o514O42MKOn" 915 | }, 916 | "source": [ 917 | "# Step 14: Define training loop parameters and other hyperparameters" 918 | ] 919 | }, 920 | { 921 | "cell_type": "code", 922 | "execution_count": null, 923 | "metadata": { 924 | "id": "jHIgfItVMJks" 925 | }, 926 | "outputs": [], 927 | "source": [ 928 | "# First defining hyperparameters and boilerplate code. Imports and data preparation code are repeated for convenience\n", 929 | "\n", 930 | "import torch\n", 931 | "import torch.nn as nn\n", 932 | "from torch.nn import functional as F\n", 933 | "from torch.nn import init\n", 934 | "\n", 935 | "# hyperparameters\n", 936 | "batch_size = 16 # how many independent sequences will we process in parallel?\n", 937 | "block_size = 32 # what is the maximum context length for prediction?\n", 938 | "max_iters = 1000\n", 939 | "eval_interval = 100\n", 940 | "learning_rate = 1e-3\n", 941 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 942 | "eval_iters = 400\n", 943 | "head_size = 16\n", 944 | "n_embed = 128\n", 945 | "n_head = 8\n", 946 | "n_layer = 8\n", 947 | "dropout = 0.1\n", 948 | "num_experts = 8\n", 949 | "top_k = 2" 950 | ] 951 | }, 952 | { 953 | "cell_type": "markdown", 954 | "metadata": { 955 | "id": "IA266vdgNjci" 956 | }, 957 | "source": [ 958 | "# Step 15: Initialize the entire model" 959 | ] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": null, 964 | "metadata": { 965 | "id": "4nDIBLGTMJcq" 966 | }, 967 | "outputs": [], 968 | "source": [ 969 | "def kaiming_init__weights(m):\n", 970 | " if isinstance(m, (nn.Linear)):\n", 971 | " init.kaiming_normal_(m.weight)" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": null, 977 | "metadata": { 978 | "colab": { 979 | "base_uri": "https://localhost:8080/" 980 | }, 981 | "id": "CakR1eE2NyiA", 982 | "outputId": "da8d6d0a-8c61-40b4-d6ab-9ed0c5635cb2" 983 | }, 984 | "outputs": [ 985 | { 986 | "output_type": "execute_result", 987 | "data": { 988 | "text/plain": [ 989 | "SparseMoELanguageModel(\n", 990 | " (token_embedding_table): Embedding(65, 128)\n", 991 | " (position_embedding_table): Embedding(32, 128)\n", 992 | " (blocks): Sequential(\n", 993 | " (0): Block(\n", 994 | " (sa): MultiHeadAttention(\n", 995 | " (heads): 
ModuleList(\n", 996 | " (0-7): 8 x Head(\n", 997 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 998 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 999 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1000 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1001 | " )\n", 1002 | " )\n", 1003 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1004 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1005 | " )\n", 1006 | " (smoe): SparseMoE(\n", 1007 | " (router): NoisyTopkRouter(\n", 1008 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1009 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1010 | " )\n", 1011 | " (experts): ModuleList(\n", 1012 | " (0-7): 8 x Expert(\n", 1013 | " (net): Sequential(\n", 1014 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1015 | " (1): ReLU()\n", 1016 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1017 | " (3): Dropout(p=0.1, inplace=False)\n", 1018 | " )\n", 1019 | " )\n", 1020 | " )\n", 1021 | " )\n", 1022 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1023 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1024 | " )\n", 1025 | " (1): Block(\n", 1026 | " (sa): MultiHeadAttention(\n", 1027 | " (heads): ModuleList(\n", 1028 | " (0-7): 8 x Head(\n", 1029 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1030 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1031 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1032 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1033 | " )\n", 1034 | " )\n", 1035 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1036 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1037 | " )\n", 1038 | " (smoe): SparseMoE(\n", 1039 | " (router): NoisyTopkRouter(\n", 1040 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1041 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1042 | " )\n", 1043 | " (experts): ModuleList(\n", 1044 | " (0-7): 8 x Expert(\n", 1045 | " (net): Sequential(\n", 1046 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1047 | " (1): ReLU()\n", 1048 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1049 | " (3): Dropout(p=0.1, inplace=False)\n", 1050 | " )\n", 1051 | " )\n", 1052 | " )\n", 1053 | " )\n", 1054 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1055 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1056 | " )\n", 1057 | " (2): Block(\n", 1058 | " (sa): MultiHeadAttention(\n", 1059 | " (heads): ModuleList(\n", 1060 | " (0-7): 8 x Head(\n", 1061 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1062 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1063 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1064 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1065 | " )\n", 1066 | " )\n", 1067 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1068 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1069 | " )\n", 1070 | " (smoe): SparseMoE(\n", 1071 | " (router): NoisyTopkRouter(\n", 1072 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1073 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1074 | " )\n", 1075 | " (experts): ModuleList(\n", 1076 | " (0-7): 8 x Expert(\n", 1077 | " 
(net): Sequential(\n", 1078 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1079 | " (1): ReLU()\n", 1080 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1081 | " (3): Dropout(p=0.1, inplace=False)\n", 1082 | " )\n", 1083 | " )\n", 1084 | " )\n", 1085 | " )\n", 1086 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1087 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1088 | " )\n", 1089 | " (3): Block(\n", 1090 | " (sa): MultiHeadAttention(\n", 1091 | " (heads): ModuleList(\n", 1092 | " (0-7): 8 x Head(\n", 1093 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1094 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1095 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1096 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1097 | " )\n", 1098 | " )\n", 1099 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1100 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1101 | " )\n", 1102 | " (smoe): SparseMoE(\n", 1103 | " (router): NoisyTopkRouter(\n", 1104 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1105 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1106 | " )\n", 1107 | " (experts): ModuleList(\n", 1108 | " (0-7): 8 x Expert(\n", 1109 | " (net): Sequential(\n", 1110 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1111 | " (1): ReLU()\n", 1112 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1113 | " (3): Dropout(p=0.1, inplace=False)\n", 1114 | " )\n", 1115 | " )\n", 1116 | " )\n", 1117 | " )\n", 1118 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1119 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1120 | " )\n", 1121 | " (4): Block(\n", 1122 | " (sa): MultiHeadAttention(\n", 1123 | " (heads): ModuleList(\n", 1124 | " (0-7): 8 x Head(\n", 1125 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1126 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1127 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1128 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1129 | " )\n", 1130 | " )\n", 1131 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1132 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1133 | " )\n", 1134 | " (smoe): SparseMoE(\n", 1135 | " (router): NoisyTopkRouter(\n", 1136 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1137 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1138 | " )\n", 1139 | " (experts): ModuleList(\n", 1140 | " (0-7): 8 x Expert(\n", 1141 | " (net): Sequential(\n", 1142 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1143 | " (1): ReLU()\n", 1144 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1145 | " (3): Dropout(p=0.1, inplace=False)\n", 1146 | " )\n", 1147 | " )\n", 1148 | " )\n", 1149 | " )\n", 1150 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1151 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1152 | " )\n", 1153 | " (5): Block(\n", 1154 | " (sa): MultiHeadAttention(\n", 1155 | " (heads): ModuleList(\n", 1156 | " (0-7): 8 x Head(\n", 1157 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1158 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1159 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1160 | " 
(dropout): Dropout(p=0.1, inplace=False)\n", 1161 | " )\n", 1162 | " )\n", 1163 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1164 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1165 | " )\n", 1166 | " (smoe): SparseMoE(\n", 1167 | " (router): NoisyTopkRouter(\n", 1168 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1169 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1170 | " )\n", 1171 | " (experts): ModuleList(\n", 1172 | " (0-7): 8 x Expert(\n", 1173 | " (net): Sequential(\n", 1174 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1175 | " (1): ReLU()\n", 1176 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1177 | " (3): Dropout(p=0.1, inplace=False)\n", 1178 | " )\n", 1179 | " )\n", 1180 | " )\n", 1181 | " )\n", 1182 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1183 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1184 | " )\n", 1185 | " (6): Block(\n", 1186 | " (sa): MultiHeadAttention(\n", 1187 | " (heads): ModuleList(\n", 1188 | " (0-7): 8 x Head(\n", 1189 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1190 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1191 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1192 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1193 | " )\n", 1194 | " )\n", 1195 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1196 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1197 | " )\n", 1198 | " (smoe): SparseMoE(\n", 1199 | " (router): NoisyTopkRouter(\n", 1200 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1201 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1202 | " )\n", 1203 | " (experts): ModuleList(\n", 1204 | " (0-7): 8 x Expert(\n", 1205 | " (net): Sequential(\n", 1206 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1207 | " (1): ReLU()\n", 1208 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1209 | " (3): Dropout(p=0.1, inplace=False)\n", 1210 | " )\n", 1211 | " )\n", 1212 | " )\n", 1213 | " )\n", 1214 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1215 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1216 | " )\n", 1217 | " (7): Block(\n", 1218 | " (sa): MultiHeadAttention(\n", 1219 | " (heads): ModuleList(\n", 1220 | " (0-7): 8 x Head(\n", 1221 | " (key): Linear(in_features=128, out_features=16, bias=False)\n", 1222 | " (query): Linear(in_features=128, out_features=16, bias=False)\n", 1223 | " (value): Linear(in_features=128, out_features=16, bias=False)\n", 1224 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1225 | " )\n", 1226 | " )\n", 1227 | " (proj): Linear(in_features=128, out_features=128, bias=True)\n", 1228 | " (dropout): Dropout(p=0.1, inplace=False)\n", 1229 | " )\n", 1230 | " (smoe): SparseMoE(\n", 1231 | " (router): NoisyTopkRouter(\n", 1232 | " (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1233 | " (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n", 1234 | " )\n", 1235 | " (experts): ModuleList(\n", 1236 | " (0-7): 8 x Expert(\n", 1237 | " (net): Sequential(\n", 1238 | " (0): Linear(in_features=128, out_features=512, bias=True)\n", 1239 | " (1): ReLU()\n", 1240 | " (2): Linear(in_features=512, out_features=128, bias=True)\n", 1241 | " (3): Dropout(p=0.1, inplace=False)\n", 1242 | " )\n", 1243 | " )\n", 1244 | 
" )\n", 1245 | " )\n", 1246 | " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1247 | " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1248 | " )\n", 1249 | " )\n", 1250 | " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 1251 | " (lm_head): Linear(in_features=128, out_features=65, bias=True)\n", 1252 | ")" 1253 | ] 1254 | }, 1255 | "metadata": {}, 1256 | "execution_count": 93 1257 | } 1258 | ], 1259 | "source": [ 1260 | "model = SparseMoELanguageModel()\n", 1261 | "model.apply(kaiming_init__weights)" 1262 | ] 1263 | }, 1264 | { 1265 | "cell_type": "markdown", 1266 | "metadata": { 1267 | "id": "vr6Bz9-CODMn" 1268 | }, 1269 | "source": [ 1270 | "# Step 16: Run the pre-training Loop" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": null, 1276 | "metadata": { 1277 | "colab": { 1278 | "base_uri": "https://localhost:8080/" 1279 | }, 1280 | "id": "QY7j6jf6OFrZ", 1281 | "outputId": "077b4ee9-db74-4ea9-8ac9-e65001e391a0" 1282 | }, 1283 | "outputs": [ 1284 | { 1285 | "output_type": "stream", 1286 | "name": "stdout", 1287 | "text": [ 1288 | "8.996545 M parameters\n", 1289 | "step 0: train loss 5.1127, val loss 5.1287\n", 1290 | "step 100: train loss 2.6564, val loss 2.6589\n", 1291 | "step 200: train loss 2.4973, val loss 2.4938\n", 1292 | "step 300: train loss 2.3931, val loss 2.3948\n", 1293 | "step 400: train loss 2.3098, val loss 2.3223\n", 1294 | "step 500: train loss 2.2309, val loss 2.2536\n", 1295 | "step 600: train loss 2.1587, val loss 2.2019\n", 1296 | "step 700: train loss 2.1083, val loss 2.1624\n", 1297 | "step 800: train loss 2.0612, val loss 2.1126\n", 1298 | "step 900: train loss 2.0206, val loss 2.0880\n", 1299 | "step 999: train loss 1.9632, val loss 2.0509\n", 1300 | "=== VOCABULARY DEBUG ===\n", 1301 | "Actual vocabulary size from text: 65\n", 1302 | "Current vocab_size variable: 65\n", 1303 | "Model's output layer expects: 65\n", 1304 | "Vocabulary sizes match\n", 1305 | "=== END DEBUG ===\n" 1306 | ] 1307 | } 1308 | ], 1309 | "source": [ 1310 | "m = model.to(device)\n", 1311 | "# Print the number of parameters in the model\n", 1312 | "print(sum(p.numel() for p in m.parameters())/1e6,'M parameters')\n", 1313 | "\n", 1314 | "# create a PyTorch optimizer\n", 1315 | "optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)\n", 1316 | "\n", 1317 | "for iter in range(max_iters):\n", 1318 | " # every once in a while evaluate the loss on train and val sets\n", 1319 | " if iter % eval_interval == 0 or iter == max_iters - 1:\n", 1320 | " losses = estimate_loss()\n", 1321 | " print(\"step {}: train loss {:.4f}, val loss {:.4f}\".format(iter, losses['train'], losses['val']))\n", 1322 | "\n", 1323 | " # sample. 
a batch of data\n", 1324 | " xb, yb = get_batch('train')\n", 1325 | "\n", 1326 | " # evaluate the loss\n", 1327 | " logits, loss = model(xb,yb)\n", 1328 | " optimizer.zero_grad(set_to_none=True)\n", 1329 | " loss.backward()\n", 1330 | " optimizer.step()\n", 1331 | "\n", 1332 | "# Debug the vocabulary mismatch\n", 1333 | "print(\"=== VOCABULARY DEBUG ===\")\n", 1334 | "\n", 1335 | "# Check actual vocabulary\n", 1336 | "with open('input.txt', 'r', encoding='utf-8') as f:\n", 1337 | " text = f.read()\n", 1338 | "\n", 1339 | "chars = sorted(list(set(text)))\n", 1340 | "actual_vocab_size = len(chars)\n", 1341 | "\n", 1342 | "print(f\"Actual vocabulary size from text: {actual_vocab_size}\")\n", 1343 | "print(f\"Current vocab_size variable: {vocab_size}\")\n", 1344 | "\n", 1345 | "# Check model's expected vocabulary size\n", 1346 | "model_vocab_size = model.lm_head.out_features\n", 1347 | "print(f\"Model's output layer expects: {model_vocab_size}\")\n", 1348 | "\n", 1349 | "# Show the mismatch\n", 1350 | "if actual_vocab_size != model_vocab_size:\n", 1351 | " print(f\"MISMATCH: Model expects {model_vocab_size} but data has {actual_vocab_size}\")\n", 1352 | " print(\"Solution: Recreate the model with the correct vocab_size\")\n", 1353 | "else:\n", 1354 | " print(\"Vocabulary sizes match\")\n", 1355 | "\n", 1356 | "print(\"=== END DEBUG ===\")" 1357 | ] 1358 | }, 1359 | { 1360 | "cell_type": "markdown", 1361 | "metadata": { 1362 | "id": "vk9Vw9cOaYm2" 1363 | }, 1364 | "source": [ 1365 | "# Step 17: Inference" 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "source": [ 1371 | "# generate from the model\n", 1372 | "\n", 1373 | "context = torch.zeros((1,1), dtype=torch.long, device=device)\n", 1374 | "print(decode(m.generate(context,max_new_tokens=2000)[0].tolist()))" 1375 | ], 1376 | "metadata": { 1377 | "id": "Vu0pMx1ICON1" 1378 | }, 1379 | "execution_count": null, 1380 | "outputs": [] 1381 | } 1382 | ], 1383 | "metadata": { 1384 | "colab": { 1385 | "provenance": [], 1386 | "machine_shape": "hm" 1387 | }, 1388 | "kernelspec": { 1389 | "display_name": "Python 3", 1390 | "name": "python3" 1391 | }, 1392 | "language_info": { 1393 | "name": "python" 1394 | } 1395 | }, 1396 | "nbformat": 4, 1397 | "nbformat_minor": 0 1398 | } --------------------------------------------------------------------------------