├── .gitignore
├── Makefile
├── convert_qwen3_weights.py
├── README.md
└── qwen_moe.c

/.gitignore:
--------------------------------------------------------------------------------
# Compiled binaries
*.exe
*.out
*.o
*.obj
qwen3_moe
qwen3_moe_debug
qwen3_moe_no_omp
qwen3_moe_portable
qwen3_moe_static

# Model files (too large for git)
*.bin
*.safetensors
*.pt
*.pth

# Downloaded model directories
Qwen*-*/
*-A3B*/
*-A22B*/

# Python cache
__pycache__/
*.pyc
*.pyo
*.pyd
.Python

# Virtual environments
venv/
env/
.venv/
.env/

# IDEs and editors
.vscode/
.idea/
*.swp
*.swo
*~

# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs and temporary files
*.log
*.tmp
*.temp

# Jupyter notebook checkpoints
.ipynb_checkpoints/

# Build artifacts
build/
dist/
*.egg-info/

# Debug files
*.dSYM/
*.su
*.idb
*.pdb

# Core dumps
core.*

# Standalone files from other repositories
standalone-*.ipynb
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Makefile for Qwen3 MoE C inference
CC = gcc
CFLAGS = -O3 -Wall -Wextra -std=c99
LDFLAGS = -lm

# OpenMP support
OPENMP_FLAGS = -fopenmp
OPENMP_LIBS = -fopenmp

# Architecture specific optimizations
ARCH_FLAGS = -march=native -mtune=native

# Debug flags
DEBUG_FLAGS = -g -DDEBUG

# Default target
TARGET = qwen3_moe
SOURCE = qwen_moe.c

# Default build (optimized with OpenMP)
$(TARGET): $(SOURCE)
	$(CC) $(CFLAGS) $(ARCH_FLAGS) $(OPENMP_FLAGS) -o $@ $< $(LDFLAGS) $(OPENMP_LIBS)

# Debug build
debug: $(SOURCE)
	$(CC) $(CFLAGS) $(DEBUG_FLAGS) $(OPENMP_FLAGS) -o $(TARGET)_debug $< $(LDFLAGS) $(OPENMP_LIBS)

# Build without OpenMP (for compatibility)
no-openmp: $(SOURCE)
	$(CC) $(CFLAGS) $(ARCH_FLAGS) -o $(TARGET)_no_omp $< $(LDFLAGS)

# Build for older systems (no advanced arch flags)
portable: $(SOURCE)
	$(CC) -O2 -Wall -std=c99 $(OPENMP_FLAGS) -o $(TARGET)_portable $< $(LDFLAGS) $(OPENMP_LIBS)

# Static build
static: $(SOURCE)
	$(CC) $(CFLAGS) $(ARCH_FLAGS) $(OPENMP_FLAGS) -static -o $(TARGET)_static $< $(LDFLAGS) $(OPENMP_LIBS)

# Clean build artifacts
clean:
	rm -f $(TARGET) $(TARGET)_debug $(TARGET)_no_omp $(TARGET)_portable $(TARGET)_static

# Run tests (requires model file)
test: $(TARGET)
	@echo "Testing with dummy model (will fail without actual model file)"
	@echo "Usage: make MODEL=path/to/model.bin test-model"

test-model: $(TARGET)
	@if [ -z "$(MODEL)" ]; then \
		echo "Error: Please specify MODEL=path/to/model.bin"; \
		exit 1; \
	fi
	./$(TARGET) $(MODEL) 0.8 50

# Convert weights (requires Python script)
convert-weights:
	@echo "Converting Qwen3 weights to binary format..."
	@echo "Usage: python convert_qwen3_weights.py Qwen/Qwen3-Coder-30B-A3B-Instruct qwen3_moe.bin"

# Install dependencies for weight conversion
install-deps:
	pip install torch safetensors huggingface_hub

# Help target
help:
	@echo "Available targets:"
	@echo "  $(TARGET)       - Build optimized version with OpenMP (default)"
	@echo "  debug           - Build debug version"
	@echo "  no-openmp       - Build without OpenMP support"
	@echo "  portable        - Build portable version for older systems"
	@echo "  static          - Build statically linked version"
	@echo "  clean           - Remove build artifacts"
	@echo "  test-model      - Test with MODEL=path/to/model.bin"
	@echo "  convert-weights - Show weight conversion command"
	@echo "  install-deps    - Install Python dependencies for weight conversion"
	@echo "  help            - Show this help message"
	@echo ""
	@echo "Example usage:"
	@echo "  make                                 # Build optimized version"
	@echo "  make test-model MODEL=qwen3_moe.bin"
	@echo "  ./$(TARGET) model.bin 0.8 100        # Run with temperature 0.8, max 100 tokens"

.PHONY: debug no-openmp portable static clean test test-model convert-weights install-deps help
--------------------------------------------------------------------------------
/convert_qwen3_weights.py:
--------------------------------------------------------------------------------
59 | @echo "Usage: python convert_qwen3_weights.py Qwen/Qwen3-Coder-30B-A3B-Instruct qwen3_moe.bin" 60 | 61 | # Install dependencies for weight conversion 62 | install-deps: 63 | pip install torch safetensors huggingface_hub 64 | 65 | # Help target 66 | help: 67 | @echo "Available targets:" 68 | @echo " $(TARGET) - Build optimized version with OpenMP (default)" 69 | @echo " debug - Build debug version" 70 | @echo " no-openmp - Build without OpenMP support" 71 | @echo " portable - Build portable version for older systems" 72 | @echo " static - Build statically linked version" 73 | @echo " clean - Remove build artifacts" 74 | @echo " test-model - Test with MODEL=path/to/model.bin" 75 | @echo " convert-weights - Show weight conversion command" 76 | @echo " install-deps - Install Python dependencies for weight conversion" 77 | @echo " help - Show this help message" 78 | @echo "" 79 | @echo "Example usage:" 80 | @echo " make # Build optimized version" 81 | @echo " make test-model MODEL=qwen3_moe.bin" 82 | @echo " ./$(TARGET) model.bin 0.8 100 # Run with temperature 0.8, max 100 tokens" 83 | 84 | .PHONY: debug no-openmp portable static clean test test-model convert-weights install-deps help -------------------------------------------------------------------------------- /convert_qwen3_weights.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Convert Qwen3 MoE PyTorch weights to binary format for C inference 4 | """ 5 | 6 | import torch 7 | import struct 8 | import numpy as np 9 | import json 10 | from pathlib import Path 11 | from safetensors.torch import load_file 12 | from huggingface_hub import snapshot_download 13 | 14 | 15 | def convert_qwen3_to_binary(repo_id, output_path): 16 | """ 17 | Convert Qwen3 MoE model weights to binary format for C inference 18 | """ 19 | print(f"Downloading model from {repo_id}...") 20 | 21 | # Download model 22 | local_dir = Path(repo_id).parts[-1] 23 | repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir) 24 | 25 | # Load weights 26 | index_path = Path(repo_dir) / "model.safetensors.index.json" 27 | with open(index_path, "r") as f: 28 | index = json.load(f) 29 | 30 | weights_dict = {} 31 | for filename in set(index["weight_map"].values()): 32 | shard_path = Path(repo_dir) / filename 33 | shard = load_file(shard_path) 34 | weights_dict.update(shard) 35 | 36 | # Load config 37 | config_path = Path(repo_dir) / "config.json" 38 | with open(config_path, "r") as f: 39 | config_json = json.load(f) 40 | 41 | # Create C-compatible config 42 | config = { 43 | 'dim': config_json['hidden_size'], # 2048 44 | 'hidden_dim': 0, # Not used for MoE layers 45 | 'n_layers': config_json['num_hidden_layers'], # 48 46 | 'n_heads': config_json['num_attention_heads'], # 32 47 | 'n_kv_heads': config_json['num_key_value_heads'], # 4 48 | 'vocab_size': config_json['vocab_size'], # 151936 49 | 'seq_len': config_json['max_position_embeddings'], # 262144 50 | 'head_dim': config_json['hidden_size'] // config_json['num_attention_heads'], # 128 51 | 'qk_norm': 1 if config_json.get('qk_norm', False) else 0, 52 | 'num_experts': config_json['num_experts'], # 128 53 | 'num_experts_per_tok': config_json['num_experts_per_tok'], # 8 54 | 'moe_intermediate_size': config_json['moe_intermediate_size'], # 768 55 | 'rope_theta': float(config_json.get('rope_theta', 10000000.0)) 56 | } 57 | 58 | print("Config:") 59 | for k, v in config.items(): 60 | print(f" {k}: {v}") 61 | 62 | with open(output_path, 'wb') as f: 63 | # 
    with open(output_path, 'wb') as f:
        # Write config as binary (12 integers + 1 float)
        f.write(struct.pack('i', config['dim']))
        f.write(struct.pack('i', config['hidden_dim']))
        f.write(struct.pack('i', config['n_layers']))
        f.write(struct.pack('i', config['n_heads']))
        f.write(struct.pack('i', config['n_kv_heads']))
        f.write(struct.pack('i', config['vocab_size']))
        f.write(struct.pack('i', config['seq_len']))
        f.write(struct.pack('i', config['head_dim']))
        f.write(struct.pack('i', config['qk_norm']))
        f.write(struct.pack('i', config['num_experts']))
        f.write(struct.pack('i', config['num_experts_per_tok']))
        f.write(struct.pack('i', config['moe_intermediate_size']))
        f.write(struct.pack('f', config['rope_theta']))

        # Write weights in the order expected by the C code
        print("Writing weights...")

        # 1. Token embeddings
        print("  Token embeddings...")
        write_tensor(f, weights_dict["model.embed_tokens.weight"])

        # 2. RMSNorm weights (attention)
        print("  Attention RMSNorm weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.input_layernorm.weight"])

        # 3. RMSNorm weights (FFN)
        print("  FFN RMSNorm weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.post_attention_layernorm.weight"])

        # 4. QK norm weights (if enabled)
        if config['qk_norm']:
            print("  QK norm weights...")
            for l in range(config['n_layers']):
                write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.q_norm.weight"])
            for l in range(config['n_layers']):
                write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.k_norm.weight"])

        # 5. Attention weights
        print("  Attention Q weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.q_proj.weight"])

        print("  Attention K weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.k_proj.weight"])

        print("  Attention V weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.v_proj.weight"])

        print("  Attention O weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.self_attn.o_proj.weight"])

        # 6. MoE gating weights
        print("  MoE gating weights...")
        for l in range(config['n_layers']):
            write_tensor(f, weights_dict[f"model.layers.{l}.mlp.gate.weight"])

        # 7. Expert weights
        print("  Expert weights...")
        # Write all gate_proj weights
        for l in range(config['n_layers']):
            for e in range(config['num_experts']):
                write_tensor(f, weights_dict[f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight"])

        # Write all down_proj weights
        for l in range(config['n_layers']):
            for e in range(config['num_experts']):
                write_tensor(f, weights_dict[f"model.layers.{l}.mlp.experts.{e}.down_proj.weight"])

        # Write all up_proj weights
        for l in range(config['n_layers']):
            for e in range(config['num_experts']):
                write_tensor(f, weights_dict[f"model.layers.{l}.mlp.experts.{e}.up_proj.weight"])

        # 8. Final norm and output
        print("  Final norm and output...")
        write_tensor(f, weights_dict["model.norm.weight"])

        # Check if model uses weight tying
        if "lm_head.weight" in weights_dict:
            write_tensor(f, weights_dict["lm_head.weight"])
        else:
            # Reuse embedding weights
            write_tensor(f, weights_dict["model.embed_tokens.weight"])

    print(f"Successfully converted weights to {output_path}")


def write_tensor(f, tensor):
    """Write tensor to file in float32 format"""
    if tensor.dtype != torch.float32:
        tensor = tensor.float()
    tensor_np = tensor.cpu().numpy()
    f.write(tensor_np.tobytes())


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 3:
        print("Usage: python convert_qwen3_weights.py <repo_id> <output_path>")
        print("Example: python convert_qwen3_weights.py Qwen/Qwen3-Coder-30B-A3B-Instruct qwen3_moe.bin")
        sys.exit(1)

    repo_id = sys.argv[1]
    output_path = sys.argv[2]

    convert_qwen3_to_binary(repo_id, output_path)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Qwen_MOE_C

A pure C implementation for inference of Qwen3 Mixture-of-Experts models, inspired by Andrej Karpathy's llama2.c approach. This implementation supports the full Qwen3-30B-A3B model architecture with 128 experts and sparse activation.

## Features

- **Pure C Implementation**: Single file, no external dependencies except libc and libm
- **Memory Efficient**: Uses memory mapping to avoid loading the entire model into RAM
- **Optimized Performance**:
  - OpenMP parallelization for attention heads
  - SIMD-friendly matrix multiplication
  - Efficient sparse MoE computation (only activates 8 out of 128 experts)
- **Complete Architecture Support**:
  - Grouped Query Attention (GQA) with 32 query heads, 4 key/value heads
  - RMSNorm with QK normalization
  - RoPE (Rotary Position Embedding)
  - SwiGLU activation in experts
- **Robust Error Handling**: Comprehensive bounds checking and validation

## Architecture Details

The implementation supports Qwen3-30B-A3B models with the following specifications:

- **Model Size**: 30B parameters (with ~3B active per token)
- **Layers**: 48 transformer blocks
- **Dimensions**: 2048 hidden size, 128 head dimension
- **Attention**: 32 query heads, 4 key/value heads (8:1 ratio)
- **MoE**: 128 experts per layer, 8 experts activated per token
- **Expert Size**: 768 intermediate dimension per expert
- **Context**: Up to 262,144 tokens
- **Vocabulary**: 151,936 tokens
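The ~3B active figure can be checked from these numbers: each token runs 8 of the 128 experts per layer, i.e. 8 experts × 3 matrices × (768 × 2048) ≈ 37.7M expert parameters per layer, or ≈ 1.81B over 48 layers. Adding the attention projections (≈ 18.9M per layer, ≈ 0.91B total) and the 151,936 × 2048 ≈ 0.31B output projection gives roughly 3B parameters touched per token.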
## Building

### Requirements

- GCC or Clang compiler
- OpenMP (optional, for parallel processing)
- libm (math library)

### Quick Build

```bash
make
```

### Build Options

```bash
make            # Optimized build with OpenMP (recommended)
make debug      # Debug build with symbols
make no-openmp  # Build without OpenMP (single-threaded)
make portable   # Portable build for older systems
make static     # Static build (all libraries linked)
```

### Manual Compilation

```bash
# With OpenMP for multi-threading
gcc -O3 -fopenmp -march=native qwen_moe.c -lm -o qwen3_moe

# Without OpenMP
gcc -O3 qwen_moe.c -lm -o qwen3_moe
```

## Converting Model Weights

First, convert the PyTorch weights to binary format:

### Install Python Dependencies

```bash
make install-deps
# or manually:
pip install torch safetensors huggingface_hub
```

### Convert Weights

```bash
python convert_qwen3_weights.py Qwen/Qwen3-Coder-30B-A3B-Instruct qwen3_moe.bin
```

This will download the model from Hugging Face and convert it to the binary format expected by the C code.

### Supported Models

- `Qwen/Qwen3-Coder-30B-A3B-Instruct` (Qwen3 Coder Flash)
- `Qwen/Qwen3-30B-A3B-Thinking-2507` (Thinking model)
- `Qwen/Qwen3-235B-A22B-Instruct-2507` (Large instruct model)
- `Qwen/Qwen3-30B-A3B` (Original instruct/thinking hybrid)

## Usage

```bash
./qwen3_moe <model.bin> [temperature] [max_tokens]
```

### Parameters

- `model.bin`: Path to the converted binary model file
- `temperature`: Sampling temperature (0.0 for greedy, 0.8 recommended, default: 0.8)
- `max_tokens`: Maximum tokens to generate (default: 100)

### Examples

```bash
# Greedy decoding (deterministic)
./qwen3_moe qwen3_moe.bin 0.0 50

# Creative sampling
./qwen3_moe qwen3_moe.bin 1.0 200

# Balanced sampling (recommended)
./qwen3_moe qwen3_moe.bin 0.8 100
```

## Performance

### Memory Requirements

- **Model weights**: ~120 GB on disk (30B parameters × 4 bytes, float32)
- **KV cache**: ~48 GB as allocated (48 layers × 262,144 positions × 512 values × 2 caches × 4 bytes)
- **Activation buffers**: well under 1 GB

The weights are memory-mapped rather than loaded, so they do not all have to fit in RAM at once, but throughput depends heavily on how much of the working set fits in the page cache.

### Speed Optimizations

The implementation includes several performance optimizations:

1. **Memory Mapping**: Weights are memory-mapped, not loaded into RAM
2. **SIMD Optimization**: Matrix multiplication uses 4-way unrolling for better vectorization
3. **Sparse MoE**: Only computes active experts (8 out of 128)
4. **OpenMP Parallelization**: Attention heads processed in parallel
5. **Cache-Friendly Access**: Sequential memory access patterns where possible

### Expected Performance

On a modern CPU (e.g., Intel i9 or AMD Ryzen 9), rough estimates:

- **Single-threaded**: ~1-2 tokens/second
- **Multi-threaded**: ~3-8 tokens/second (depending on core count)

GPU acceleration is not currently implemented but would provide significant speedup.

## Architecture Implementation Details

### RoPE (Rotary Position Embedding)

```c
// Applied per-head, rotating (i, i + head_dim/2) pairs to match the
// Hugging Face weight layout
void rope(float* q, float* k, int pos, int n_heads, int n_kv_heads, int head_dim, float theta_base)
```

### Mixture of Experts (MoE)

```c
// Sparse computation: only processes top-8 experts
1. Compute gating scores for all 128 experts
2. Select top-8 using an efficient selection algorithm
3. Apply softmax to the selected experts
4. Compute SwiGLU: gate_proj → silu → × up_proj → down_proj
5. Weighted combination based on gating scores
```

### Grouped Query Attention

```c
// 32 query heads, 4 key/value heads (8:1 ratio)
// Keys and values are repeated across query head groups
```
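Concretely, each query head reads the key/value head given by integer division. This sketch mirrors the attention loop in `qwen_moe.c` (same variable names):

```c
// GQA head mapping: with 32 query heads and 4 KV heads, kv_mul = 8,
// so query heads 0..7 share KV head 0, heads 8..15 share KV head 1, etc.
int kv_mul = p->n_heads / p->n_kv_heads;
for (int h = 0; h < p->n_heads; h++) {
    float* q = s->q + h * head_size;          // this head's query
    float* k = s->key_cache + loff + t * kv_dim
             + (h / kv_mul) * head_size;      // shared K/V head
    // ... the q . k dot product gives the attention score at position t ...
}
```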
## Limitations and TODOs

### Current Limitations

1. **No Tokenizer**: Currently uses dummy tokens; needs proper Qwen3 tokenizer integration
2. **No Chat Templates**: Raw token generation only
3. **CPU Only**: No GPU acceleration
4. **Basic Sampling**: Only supports temperature sampling, no top-p/top-k

### Future Improvements

1. **Tokenizer Integration**: Add Qwen3 tokenizer for text input/output
2. **Quantization**: Add int8/int4 quantization for a smaller memory footprint
3. **CUDA Support**: GPU acceleration for faster inference
4. **Advanced Sampling**: Top-p, top-k, and other sampling strategies
5. **Beam Search**: Multi-sequence generation
6. **KV Cache Optimization**: Better memory management for long sequences

## Implementation Notes

### Binary Format

The binary format stores weights in the following order:

1. Config (12 integers + 1 float)
2. Token embeddings
3. RMSNorm weights (attention + FFN)
4. QK norm weights (if enabled)
5. Attention weights (Q, K, V, O)
6. MoE gating weights
7. Expert weights (gate_proj, down_proj, up_proj for all experts)
8. Final norm and output weights
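Because the layout is fixed, the expected file size follows directly from the header, which makes a cheap post-conversion sanity check. A minimal sketch (a hypothetical helper, not part of `qwen_moe.c`; it reuses that file's `Config` struct and assumes float32 weights):

```c
#include <stddef.h>

// Hypothetical helper: expected .bin size in bytes for a given Config,
// mirroring the layout above. Assumes the lm_head is written (it always
// is; the converter falls back to the embedding matrix when tied).
size_t expected_file_size(const Config *c) {
    size_t f = sizeof(float);
    size_t q_dim  = (size_t)c->n_heads    * c->head_dim;              // 32 * 128
    size_t kv_dim = (size_t)c->n_kv_heads * c->head_dim;              // 4 * 128
    size_t n = 12 * sizeof(int) + sizeof(float);                      // header
    n += (size_t)c->vocab_size * c->dim * f;                          // embeddings
    n += 2 * (size_t)c->n_layers * c->dim * f;                        // attn+ffn norms
    if (c->qk_norm) n += 2 * (size_t)c->n_layers * c->head_dim * f;   // q/k norms
    n += (size_t)c->n_layers * c->dim * 2 * (q_dim + kv_dim) * f;     // wq,wk,wv,wo
    n += (size_t)c->n_layers * c->dim * c->num_experts * f;           // gating
    n += 3 * (size_t)c->n_layers * c->num_experts *
         (size_t)c->moe_intermediate_size * c->dim * f;               // experts
    n += (size_t)c->dim * f;                                          // final norm
    n += (size_t)c->vocab_size * c->dim * f;                          // lm_head
    return n;
}
```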
### Error Handling

The implementation includes comprehensive error checking:

- Configuration validation
- Memory allocation failures
- File operation errors
- Bounds checking for array accesses
- Numerical stability checks

## Contributing

This implementation follows the simplicity principle of llama2.c. Contributions should:

1. Maintain the single-file architecture
2. Avoid external dependencies
3. Include comprehensive error handling
4. Add performance optimizations where possible
5. Follow the existing code style

## License

This code is released under the MIT License, following the same permissive approach as llama2.c.

## Acknowledgments

- **Sebastian Raschka** for the comprehensive Qwen3 MoE implementation and educational materials in his [LLMs-from-scratch repository](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3-moe.ipynb), which served as the reference for understanding the Qwen3 architecture details
- **Andrej Karpathy** for the llama2.c approach and inspiration for clean, educational C implementations
- **Qwen team** for the model architecture and pre-trained weights
- The broader LLM open-source community for advancing accessible AI implementations

## References

- [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388)
- [Sebastian Raschka's LLMs-from-scratch Repository](https://github.com/rasbt/LLMs-from-scratch) - Educational LLM implementations
- [Qwen3 MoE Standalone Notebook](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3-moe.ipynb) by Sebastian Raschka
- [llama2.c](https://github.com/karpathy/llama2.c) by Andrej Karpathy
- [Build a Large Language Model (From Scratch)](http://mng.bz/orYv) by Sebastian Raschka
--------------------------------------------------------------------------------
/qwen_moe.c:
--------------------------------------------------------------------------------
/* Inference for Qwen3 MoE Transformer model in pure C
 * Based on Andrej Karpathy's llama2.c approach
 * Supports Qwen3-30B-A3B models with Mixture of Experts
 */

// Compiler optimization hints
#ifdef __GNUC__
#pragma GCC optimize("O3")
#ifdef __x86_64__
#pragma GCC target("avx2,fma") // x86-64 only; remove for CPUs without AVX2
#endif
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <fcntl.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <io.h>
#define mmap(addr, len, prot, flags, fd, offset) \
    ((char*)MapViewOfFile(CreateFileMapping((HANDLE)_get_osfhandle(fd), NULL, PAGE_READONLY, 0, 0, NULL), \
                          FILE_MAP_READ, 0, 0, (len)))
#define munmap(addr, len) UnmapViewOfFile(addr)
#define MAP_FAILED NULL
#define MAP_PRIVATE 0
#define PROT_READ 0
#else
#include <sys/mman.h>
#include <unistd.h>
#endif

// ----------------------------------------------------------------------------
// Transformer model structures

typedef struct {
    int dim;                   // transformer dimension (2048)
    int hidden_dim;            // for standard ffn layers (not used in MoE layers)
    int n_layers;              // number of layers (48)
    int n_heads;               // number of query heads (32)
    int n_kv_heads;            // number of key/value heads (4)
    int vocab_size;            // vocabulary size (151936)
    int seq_len;               // max sequence length (262144)
    int head_dim;              // dimension per head (128)
    int qk_norm;               // whether to use QK normalization (1)
    int num_experts;           // number of experts per layer (128)
    int num_experts_per_tok;   // experts activated per token (8)
    int moe_intermediate_size; // expert hidden dimension (768)
    float rope_theta;          // RoPE theta base (10000000.0)
} Config;
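
// Note: in Qwen3-30B-A3B the attention projections are wider than the
// residual stream: n_heads * head_dim = 32 * 128 = 4096, while dim = 2048.
// Query/attention buffers and weight offsets below are therefore sized with
// n_heads * head_dim (q_dim) rather than assuming it equals dim.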

typedef struct {
    // token embedding table
    float* token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // QK norms (if enabled)
    float* q_norm_weight; // (layer, head_dim)
    float* k_norm_weight; // (layer, head_dim)
    // weights for attention (row-major, (out, in) as in the HF checkpoints)
    float* wq; // (layer, n_heads * head_dim, dim)
    float* wk; // (layer, n_kv_heads * head_dim, dim)
    float* wv; // (layer, n_kv_heads * head_dim, dim)
    float* wo; // (layer, dim, n_heads * head_dim)
    // MoE gating weights
    float* moe_gate; // (layer, num_experts, dim)
    // Expert weights - stored as contiguous arrays
    float* expert_w1; // (layer, num_experts, moe_intermediate_size, dim)
    float* expert_w2; // (layer, num_experts, dim, moe_intermediate_size)
    float* expert_w3; // (layer, num_experts, moe_intermediate_size, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // output projection
    float* wcls; // (vocab_size, dim)
} TransformerWeights;

typedef struct {
    // current wave of activations
    float *x;   // activation at current time stamp (dim,)
    float *xb;  // same, but inside a residual branch (max(dim, n_heads * head_dim),)
    float *xb2; // additional buffer (dim,)
    float *q;   // query (n_heads * head_dim,)
    float *k;   // key (n_kv_heads * head_dim,)
    float *v;   // value (n_kv_heads * head_dim,)
    float *att; // attention scores (n_heads, seq_len)
    float *logits; // output logits (vocab_size,)
    // kv cache
    float* key_cache;   // (layer, seq_len, n_kv_heads * head_dim)
    float* value_cache; // (layer, seq_len, n_kv_heads * head_dim)
    // MoE specific buffers
    float* gate_scores;     // (num_experts,)
    float* expert_outputs;  // (num_experts_per_tok, dim)
    float* moe_buffer;      // (moe_intermediate_size,)
    float* moe_temp_buffer; // (moe_intermediate_size,) - temporary buffer for SwiGLU
    int* topk_indices;      // (num_experts_per_tok,)
    float* topk_weights;    // (num_experts_per_tok,)
} RunState;

typedef struct {
    Config config;
    TransformerWeights weights;
    RunState state;
    // memory mapping
    int fd;
    float* data;
    ssize_t file_size;
} Transformer;

// ----------------------------------------------------------------------------
// Memory allocation and initialization

void malloc_run_state(RunState* s, Config* p) {
    int kv_dim = p->n_kv_heads * p->head_dim;
    int q_dim = p->n_heads * p->head_dim;
    // xb holds both dim-sized rmsnorm outputs and the q_dim-sized
    // concatenation of attention head outputs, so size it for the larger
    int xb_dim = q_dim > p->dim ? q_dim : p->dim;

    s->x = calloc(p->dim, sizeof(float));
    s->xb = calloc(xb_dim, sizeof(float));
    s->xb2 = calloc(p->dim, sizeof(float));
    s->q = calloc(q_dim, sizeof(float));
    s->k = calloc(kv_dim, sizeof(float));
    s->v = calloc(kv_dim, sizeof(float));
    s->att = calloc((size_t)p->n_heads * p->seq_len, sizeof(float));
    s->logits = calloc(p->vocab_size, sizeof(float));

    // size_t arithmetic: n_layers * seq_len * kv_dim overflows 32-bit int
    s->key_cache = calloc((size_t)p->n_layers * p->seq_len * kv_dim, sizeof(float));
    s->value_cache = calloc((size_t)p->n_layers * p->seq_len * kv_dim, sizeof(float));
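    // Note: the caches above are sized for the full context window; at
    // Qwen3-30B-A3B settings (48 layers x 262144 positions x 512 floats,
    // two caches) that is roughly 48 GiB of zero-initialized memory.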

    // MoE specific
    s->gate_scores = calloc(p->num_experts, sizeof(float));
    s->expert_outputs = calloc(p->num_experts_per_tok * p->dim, sizeof(float));
    s->moe_buffer = calloc(p->moe_intermediate_size, sizeof(float));
    s->moe_temp_buffer = calloc(p->moe_intermediate_size, sizeof(float));
    s->topk_indices = calloc(p->num_experts_per_tok, sizeof(int));
    s->topk_weights = calloc(p->num_experts_per_tok, sizeof(float));

    // Check allocations
    if (!s->x || !s->xb || !s->xb2 || !s->q || !s->k || !s->v || !s->att ||
        !s->logits || !s->key_cache || !s->value_cache || !s->gate_scores ||
        !s->expert_outputs || !s->moe_buffer || !s->moe_temp_buffer ||
        !s->topk_indices || !s->topk_weights) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}

void free_run_state(RunState* s) {
    free(s->x);
    free(s->xb);
    free(s->xb2);
    free(s->q);
    free(s->k);
    free(s->v);
    free(s->att);
    free(s->logits);
    free(s->key_cache);
    free(s->value_cache);
    free(s->gate_scores);
    free(s->expert_outputs);
    free(s->moe_buffer);
    free(s->moe_temp_buffer);
    free(s->topk_indices);
    free(s->topk_weights);
}

// ----------------------------------------------------------------------------
// Neural network blocks

void rmsnorm(float* o, float* x, float* weight, int size) {
    // Calculate sum of squares
    float ss = 0.0f;
    for (int j = 0; j < size; j++) {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-6f; // Qwen3 uses 1e-6 epsilon
    ss = 1.0f / sqrtf(ss);
    // Normalize and scale
    for (int j = 0; j < size; j++) {
        o[j] = weight[j] * (ss * x[j]);
    }
}

void softmax(float* x, int size) {
    // Find max for numerical stability
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // Exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // Normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}

void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < d; i++) {
        float val = 0.0f;
        // Add restrict hints for better vectorization
        float* __restrict__ w_row = w + (size_t)i * n;
        float* __restrict__ x_vec = x;

        // Process in chunks of 4 for better SIMD utilization
        int n_vec = n & ~3; // Round down to multiple of 4
        for (int j = 0; j < n_vec; j += 4) {
            val += w_row[j]     * x_vec[j] +
                   w_row[j + 1] * x_vec[j + 1] +
                   w_row[j + 2] * x_vec[j + 2] +
                   w_row[j + 3] * x_vec[j + 3];
        }

        // Handle remaining elements
        for (int j = n_vec; j < n; j++) {
            val += w_row[j] * x_vec[j];
        }

        xout[i] = val;
    }
}

// SiLU activation function (used in SwiGLU)
float silu(float x) {
    return x / (1.0f + expf(-x));
}

// Apply RoPE to q and k vectors per head. Hugging Face Qwen3 checkpoints use
// the "rotate half" layout: element i is paired with element i + head_dim/2,
// and pair i rotates at frequency theta_base^(-2i/head_dim).
void rope(float* q, float* k, int pos, int n_heads, int n_kv_heads, int head_dim, float theta_base) {
    int half = head_dim / 2;

    // Apply RoPE to each query head
    for (int h = 0; h < n_heads; h++) {
        float* head_q = q + h * head_dim;

        for (int i = 0; i < half; i++) {
            float freq = 1.0f / powf(theta_base, (float)(2 * i) / (float)head_dim);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);

            float v0 = head_q[i];
            float v1 = head_q[i + half];
            head_q[i] = v0 * fcr - v1 * fci;
            head_q[i + half] = v0 * fci + v1 * fcr;
        }
    }

    // Apply RoPE to each key head
    if (k != NULL) {
        for (int h = 0; h < n_kv_heads; h++) {
            float* head_k = k + h * head_dim;

            for (int i = 0; i < half; i++) {
                float freq = 1.0f / powf(theta_base, (float)(2 * i) / (float)head_dim);
                float val = pos * freq;
                float fcr = cosf(val);
                float fci = sinf(val);

                float v0 = head_k[i];
                float v1 = head_k[i + half];
                head_k[i] = v0 * fcr - v1 * fci;
                head_k[i + half] = v0 * fci + v1 * fcr;
            }
        }
    }
}

// Find the top-k values and their indices by insertion selection,
// which is O(n * k) and fast for the small k used in MoE routing
void topk(float* values, int n, int k, int* indices, float* topk_values) {
    // Initialize with first k elements
    for (int i = 0; i < k; i++) {
        indices[i] = i;
        topk_values[i] = values[i];
    }

    // Sort initial k elements descending
    for (int i = 0; i < k - 1; i++) {
        for (int j = i + 1; j < k; j++) {
            if (topk_values[j] > topk_values[i]) {
                float temp_val = topk_values[i];
                int temp_idx = indices[i];
                topk_values[i] = topk_values[j];
                indices[i] = indices[j];
                topk_values[j] = temp_val;
                indices[j] = temp_idx;
            }
        }
    }

    // Check remaining elements
    for (int i = k; i < n; i++) {
        // If current value is larger than smallest in top-k
        if (values[i] > topk_values[k - 1]) {
            // Find insertion position
            int pos = k - 1;
            while (pos > 0 && values[i] > topk_values[pos - 1]) {
                pos--;
            }

            // Shift and insert
            for (int j = k - 1; j > pos; j--) {
                topk_values[j] = topk_values[j - 1];
                indices[j] = indices[j - 1];
            }
            topk_values[pos] = values[i];
            indices[pos] = i;
        }
    }
}

// ----------------------------------------------------------------------------
// Forward pass

float* forward(Transformer* transformer, int token, int pos) {
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = p->n_kv_heads * p->head_dim;
    int q_dim = p->n_heads * p->head_dim; // 4096 for Qwen3-30B-A3B, wider than dim
    int head_size = p->head_dim;

    // Bounds checking
    if (token < 0 || token >= p->vocab_size) {
        fprintf(stderr, "Error: Token %d out of bounds [0, %d)\n", token, p->vocab_size);
        exit(EXIT_FAILURE);
    }
    if (pos < 0 || pos >= p->seq_len) {
        fprintf(stderr, "Error: Position %d out of bounds [0, %d)\n", pos, p->seq_len);
        exit(EXIT_FAILURE);
    }

    // Copy token embedding into x
    float* content_row = w->token_embedding_table + (size_t)token * dim;
    memcpy(x, content_row, dim * sizeof(float));

    // Forward all layers
    for (int l = 0; l < p->n_layers; l++) {
        // Attention RMSNorm
        rmsnorm(s->xb, x, w->rms_att_weight + l * dim, dim);

        // QKV projections (q_proj is (q_dim, dim); k/v are (kv_dim, dim))
        matmul(s->q, s->xb, w->wq + (size_t)l * dim * q_dim, dim, q_dim);
        matmul(s->k, s->xb, w->wk + (size_t)l * dim * kv_dim, dim, kv_dim);
        matmul(s->v, s->xb, w->wv + (size_t)l * dim * kv_dim, dim, kv_dim);

        // Apply QK normalization if enabled
        if (p->qk_norm) {
            // Normalize Q
            for (int h = 0; h < p->n_heads; h++) {
                rmsnorm(s->q + h * head_size, s->q + h * head_size,
                        w->q_norm_weight + l * head_size, head_size);
            }
            // Normalize K
            for (int h = 0; h < p->n_kv_heads; h++) {
                rmsnorm(s->k + h * head_size, s->k + h * head_size,
                        w->k_norm_weight + l * head_size, head_size);
            }
        }

        // Apply RoPE
        rope(s->q, s->k, pos, p->n_heads, p->n_kv_heads, head_size, p->rope_theta);

        // Cache K and V (size_t offset: the full cache exceeds 32-bit indexing)
        size_t loff = (size_t)l * p->seq_len * kv_dim;
        float* key_cache_row = s->key_cache + loff + (size_t)pos * kv_dim;
        float* value_cache_row = s->value_cache + loff + (size_t)pos * kv_dim;
        memcpy(key_cache_row, s->k, kv_dim * sizeof(float));
        memcpy(value_cache_row, s->v, kv_dim * sizeof(float));

        // Multihead attention over all q_dim head outputs
        memset(s->xb, 0, q_dim * sizeof(float));
        int kv_mul = p->n_heads / p->n_kv_heads;

        #pragma omp parallel for
        for (int h = 0; h < p->n_heads; h++) {
            float* q = s->q + h * head_size;
            float* att = s->att + (size_t)h * p->seq_len;

            // Compute attention scores
            for (int t = 0; t <= pos; t++) {
                float* k = s->key_cache + loff + (size_t)t * kv_dim + (h / kv_mul) * head_size;
                float score = 0.0f;
                for (int i = 0; i < head_size; i++) {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
                att[t] = score;
            }

            // Softmax
            softmax(att, pos + 1);

            // Weighted sum of values
            float* xb = s->xb + h * head_size;
            for (int t = 0; t <= pos; t++) {
                float* v = s->value_cache + loff + (size_t)t * kv_dim + (h / kv_mul) * head_size;
                float a = att[t];
                for (int i = 0; i < head_size; i++) {
                    xb[i] += a * v[i];
                }
            }
        }

        // Output projection: o_proj is (dim, q_dim)
        matmul(s->xb2, s->xb, w->wo + (size_t)l * q_dim * dim, q_dim, dim);

        // Residual connection
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb2[i];
        }

        // FFN RMSNorm
        rmsnorm(s->xb, x, w->rms_ffn_weight + l * dim, dim);

        // MoE layer
        // Compute gating scores
        matmul(s->gate_scores, s->xb, w->moe_gate + (size_t)l * dim * p->num_experts,
               dim, p->num_experts);

        // Find top-k experts
        topk(s->gate_scores, p->num_experts, p->num_experts_per_tok,
             s->topk_indices, s->topk_weights);

        // Softmax over selected experts
        softmax(s->topk_weights, p->num_experts_per_tok);

        // Clear expert outputs
        memset(s->expert_outputs, 0, p->num_experts_per_tok * dim * sizeof(float));

        // Process each selected expert
        for (int i = 0; i < p->num_experts_per_tok; i++) {
            int expert_id = s->topk_indices[i];
            float expert_weight = s->topk_weights[i];

            // Get expert weight offsets - weights are stored contiguously for all experts
            size_t expert_idx = (size_t)l * p->num_experts + expert_id;
            size_t w1_size = (size_t)p->moe_intermediate_size * dim;
            size_t w2_size = (size_t)dim * p->moe_intermediate_size;
            size_t w3_size = (size_t)p->moe_intermediate_size * dim;

            float* w1 = w->expert_w1 + expert_idx * w1_size; // gate_proj: dim -> intermediate
            float* w2 = w->expert_w2 + expert_idx * w2_size; // down_proj: intermediate -> dim
            float* w3 = w->expert_w3 + expert_idx * w3_size; // up_proj: dim -> intermediate

            // Compute SwiGLU: w2(silu(w1(x)) * w3(x))
            float* gate_output = s->moe_buffer;      // w1(x) output
            float* up_output = s->moe_temp_buffer;   // w3(x) output
            float* expert_output = s->expert_outputs + i * dim; // final output

            // w1(x) -> gate_output (gate projection)
            matmul(gate_output, s->xb, w1, dim, p->moe_intermediate_size);

            // w3(x) -> up_output (up projection)
            matmul(up_output, s->xb, w3, dim, p->moe_intermediate_size);

            // SwiGLU: silu(w1(x)) * w3(x)
            for (int j = 0; j < p->moe_intermediate_size; j++) {
                gate_output[j] = silu(gate_output[j]) * up_output[j];
            }

            // w2(silu(w1(x)) * w3(x)) -> expert_output (down projection)
            matmul(expert_output, gate_output, w2, p->moe_intermediate_size, dim);

            // Weight by gating score
            for (int j = 0; j < dim; j++) {
                expert_output[j] *= expert_weight;
            }
        }

        // Sum expert outputs
        memset(s->xb2, 0, dim * sizeof(float));
        for (int i = 0; i < p->num_experts_per_tok; i++) {
            float* expert_out = s->expert_outputs + i * dim;
            for (int j = 0; j < dim; j++) {
                s->xb2[j] += expert_out[j];
            }
        }

        // Residual connection
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb2[i];
        }
    }

    // Final RMSNorm
    rmsnorm(x, x, w->rms_final_weight, dim);

    // Classifier
    matmul(s->logits, x, w->wcls, dim, p->vocab_size);
    return s->logits;
}

// ----------------------------------------------------------------------------
// Model loading

void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) {
        fprintf(stderr, "Error: Couldn't open file %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }

    // Read config fields individually to avoid struct padding issues
    if (fread(&config->dim, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.dim\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->hidden_dim, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.hidden_dim\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->n_layers, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.n_layers\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->n_heads, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.n_heads\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->n_kv_heads, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.n_kv_heads\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->vocab_size, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.vocab_size\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->seq_len, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.seq_len\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->head_dim, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.head_dim\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->qk_norm, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.qk_norm\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->num_experts, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.num_experts\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->num_experts_per_tok, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.num_experts_per_tok\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->moe_intermediate_size, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.moe_intermediate_size\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (fread(&config->rope_theta, sizeof(float), 1, file) != 1) {
        fprintf(stderr, "Error: Failed to read config.rope_theta\n");
        fclose(file);
        exit(EXIT_FAILURE);
    }

    // Validate config parameters
    if (config->dim <= 0 || config->dim > 8192) {
        fprintf(stderr, "Error: Invalid dim %d\n", config->dim);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->n_layers <= 0 || config->n_layers > 200) {
        fprintf(stderr, "Error: Invalid n_layers %d\n", config->n_layers);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->n_heads <= 0 || config->n_heads > 256) {
        fprintf(stderr, "Error: Invalid n_heads %d\n", config->n_heads);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->n_kv_heads <= 0 || config->n_kv_heads > config->n_heads) {
        fprintf(stderr, "Error: Invalid n_kv_heads %d\n", config->n_kv_heads);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->head_dim <= 0 || config->head_dim > 512) {
        fprintf(stderr, "Error: Invalid head_dim %d\n", config->head_dim);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->num_experts <= 0 || config->num_experts > 1024) {
        fprintf(stderr, "Error: Invalid num_experts %d\n", config->num_experts);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->num_experts_per_tok <= 0 || config->num_experts_per_tok > config->num_experts) {
        fprintf(stderr, "Error: Invalid num_experts_per_tok %d\n", config->num_experts_per_tok);
        fclose(file);
        exit(EXIT_FAILURE);
    }
    if (config->n_heads % config->n_kv_heads != 0) {
        fprintf(stderr, "Error: n_heads (%d) must be divisible by n_kv_heads (%d)\n",
                config->n_heads, config->n_kv_heads);
        fclose(file);
        exit(EXIT_FAILURE);
    }

    // Get current position (config size)
    long config_size = ftell(file);

    // Get file size
    fseek(file, 0, SEEK_END);
    *file_size = ftell(file);
    fclose(file);

    // Memory map
    *fd = open(checkpoint, O_RDONLY);
    if (*fd == -1) {
        fprintf(stderr, "open failed!\n");
        exit(EXIT_FAILURE);
    }

    *data = (float*)mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) {
        fprintf(stderr, "mmap failed!\n");
        exit(EXIT_FAILURE);
    }

    // Map weights in the exact order they were written by the Python script
    float* ptr = (float*)((char*)*data + config_size);

    // 1. Token embeddings
    weights->token_embedding_table = ptr;
    ptr += (size_t)config->vocab_size * config->dim;

    // 2. RMSNorm weights (attention)
    weights->rms_att_weight = ptr;
    ptr += (size_t)config->n_layers * config->dim;

    // 3. RMSNorm weights (FFN)
    weights->rms_ffn_weight = ptr;
    ptr += (size_t)config->n_layers * config->dim;

    // 4. QK norm weights (if enabled)
    if (config->qk_norm) {
        weights->q_norm_weight = ptr;
        ptr += (size_t)config->n_layers * config->head_dim;
        weights->k_norm_weight = ptr;
        ptr += (size_t)config->n_layers * config->head_dim;
    } else {
        weights->q_norm_weight = NULL;
        weights->k_norm_weight = NULL;
    }

    // 5. Attention weights (q/o projections use n_heads * head_dim, which
    //    for Qwen3 is wider than dim)
    size_t q_dim = (size_t)config->n_heads * config->head_dim;
    size_t kv_dim = (size_t)config->n_kv_heads * config->head_dim;
    weights->wq = ptr;
    ptr += (size_t)config->n_layers * q_dim * config->dim;
    weights->wk = ptr;
    ptr += (size_t)config->n_layers * kv_dim * config->dim;
    weights->wv = ptr;
    ptr += (size_t)config->n_layers * kv_dim * config->dim;
    weights->wo = ptr;
    ptr += (size_t)config->n_layers * config->dim * q_dim;

    // 6. MoE gating weights
    weights->moe_gate = ptr;
    ptr += (size_t)config->n_layers * config->dim * config->num_experts;

    // 7. Expert weights (gate_proj for all experts, then down_proj, then up_proj)
    size_t total_experts = (size_t)config->n_layers * config->num_experts;

    weights->expert_w1 = ptr; // All gate_proj weights
    ptr += total_experts * config->moe_intermediate_size * config->dim;

    weights->expert_w2 = ptr; // All down_proj weights
    ptr += total_experts * config->dim * config->moe_intermediate_size;

    weights->expert_w3 = ptr; // All up_proj weights
    ptr += total_experts * config->moe_intermediate_size * config->dim;

    // 8. Final norm and output
    weights->rms_final_weight = ptr;
    ptr += config->dim;
    weights->wcls = ptr;
}

void build_transformer(Transformer *t, char* checkpoint_path) {
    read_checkpoint(checkpoint_path, &t->config, &t->weights,
                    &t->fd, &t->data, &t->file_size);
    malloc_run_state(&t->state, &t->config);
}

void free_transformer(Transformer* t) {
    if (t->data != MAP_FAILED) {
        munmap(t->data, t->file_size);
    }
    if (t->fd != -1) {
        close(t->fd);
    }
    free_run_state(&t->state);
}

// ----------------------------------------------------------------------------
// Sampling

int sample_argmax(float* probabilities, int n) {
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > max_p) {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}

int sample_temperature(float* logits, int n, float temperature) {
    // Greedy decoding when temperature is zero
    if (temperature == 0.0f) {
        return sample_argmax(logits, n);
    }

    // Find max for numerical stability
    float max_logit = logits[0];
    for (int i = 1; i < n; i++) {
        if (logits[i] > max_logit) {
            max_logit = logits[i];
        }
    }

    // Apply temperature and softmax
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        logits[i] = expf((logits[i] - max_logit) / temperature);
        sum += logits[i];
    }

    // Normalize to probabilities
    for (int i = 0; i < n; i++) {
        logits[i] /= sum;
    }

    // Sample from the distribution
    float r = (float)rand() / RAND_MAX;
    float cumsum = 0.0f;
    for (int i = 0; i < n; i++) {
        cumsum += logits[i];
        if (r < cumsum) {
            return i;
        }
    }
    return n - 1; // fallback
}

// ----------------------------------------------------------------------------
// Generation

void generate(Transformer *transformer, int* prompt_tokens, int num_prompt_tokens,
              int steps, float temperature) {
    int next;
    int token = prompt_tokens[0];
    int pos = 0;

    printf("\nGenerating with temperature %.2f...\n", temperature);

    while (pos < steps) {
        // Forward pass
        float* logits = forward(transformer, token, pos);

        // Get next token
        if (pos < num_prompt_tokens - 1) {
            next = prompt_tokens[pos + 1];
        } else {
            // Temperature-based sampling
            if (temperature == 0.0f) {
                next = sample_argmax(logits, transformer->config.vocab_size);
            } else {
                // Create a copy since temperature sampling modifies logits
                float* logits_copy = malloc(transformer->config.vocab_size * sizeof(float));
                if (!logits_copy) {
                    fprintf(stderr, "Error: Failed to allocate memory for logits copy\n");
                    exit(EXIT_FAILURE);
                }
                memcpy(logits_copy, logits, transformer->config.vocab_size * sizeof(float));
                next = sample_temperature(logits_copy, transformer->config.vocab_size, temperature);
                free(logits_copy);
            }
        }

        pos++;

        // Check for EOS token (you'd need to define this based on the tokenizer)
        if (next == 151643) { // Qwen3 <|endoftext|> token
            break;
        }

        // Print token (you'd need proper decoding here)
        if (pos >= num_prompt_tokens) {
            printf("Token %d: %d\n", pos, next);
        }

        token = next;
    }

    printf("\nGenerated %d tokens\n", pos);
}

// ----------------------------------------------------------------------------
// Main

int main(int argc, char *argv[]) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <checkpoint.bin> [temperature] [max_tokens]\n", argv[0]);
        fprintf(stderr, "  checkpoint.bin: Path to the binary model file\n");
        fprintf(stderr, "  temperature: Sampling temperature (0.0 for greedy, default 0.8)\n");
        fprintf(stderr, "  max_tokens: Maximum tokens to generate (default 100)\n");
        return 1;
    }

    char* checkpoint_path = argv[1];
    float temperature = (argc > 2) ? atof(argv[2]) : 0.8f;
    int max_tokens = (argc > 3) ? atoi(argv[3]) : 100;

    // Validate parameters
    if (temperature < 0.0f || temperature > 2.0f) {
        fprintf(stderr, "Warning: Temperature %.2f is outside recommended range [0.0, 2.0]\n", temperature);
    }
    if (max_tokens <= 0 || max_tokens > 4096) {
        fprintf(stderr, "Error: max_tokens must be between 1 and 4096\n");
        return 1;
    }

    // Seed random number generator
    srand(time(NULL));

    // Build transformer
    Transformer transformer;
    build_transformer(&transformer, checkpoint_path);

    printf("Model loaded successfully!\n");
    printf("Config: dim=%d, n_layers=%d, n_heads=%d, num_experts=%d\n",
           transformer.config.dim, transformer.config.n_layers,
           transformer.config.n_heads, transformer.config.num_experts);
    printf("Vocab size: %d, Seq len: %d, Head dim: %d\n",
           transformer.config.vocab_size, transformer.config.seq_len, transformer.config.head_dim);

    // Example generation with dummy tokens
    // In practice, you'd tokenize the input prompt
    int prompt_tokens[] = {151644, 882, 271}; // Example tokens (<|im_start|>user\n)
    int num_tokens = 3;

    generate(&transformer, prompt_tokens, num_tokens, max_tokens, temperature);

    // Cleanup
    free_transformer(&transformer);

    return 0;
}
--------------------------------------------------------------------------------