├── .gitignore ├── README.md ├── build.zig ├── download_weights.py ├── generate_nano_gpt.py ├── generate_test_data.py └── src ├── bpe.zig ├── main.zig ├── ops.zig └── tests.zig /.gitignore: -------------------------------------------------------------------------------- 1 | zig-cache/ 2 | zig-out/ 3 | models/ 4 | lib/OpenBLAS 5 | .venv/ 6 | .DS_Store 7 | main* 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zig_gpt2 2 | GPT-2 inference engine written in Zig. Generation time: ~28ms per token. 3 | 4 | ### Features: 5 | * No third-party dependencies besides BLAS (Accelerate or OpenBLAS). 6 | * No memory allocations at runtime. 7 | * Can run [NanoGPT](https://github.com/karpathy/nanoGPT). 8 | 9 | ### How to Run: 10 | 11 | Download the GPT-2 checkpoint from OpenAI. 12 | ```bash 13 | python3 download_weights.py 14 | ``` 15 | 16 | Build the Zig binary and run it with a prompt to generate completions: 17 | ```bash 18 | zig build -Doptimize=ReleaseFast 19 | ./zig-out/bin/zig_gpt2 "Marcus Aurelius said" 20 | ``` 21 | 22 | ### How to Test: 23 | 24 | Generate test data by forwarding random tensors through PyTorch ops. 25 | ```bash 26 | python3 generate_test_data.py 27 | ``` 28 | 29 | Run the tests. They verify that the Zig ops produce the same output as PyTorch. 30 | ```bash 31 | zig build test 32 | ``` 33 | 34 | --- 35 | 36 | ### TODO 37 | 38 | Implementation: 39 | * ✅ Implement basic ops: Embedding, Linear, LayerNorm, GELU, Softmax, CausalSelfAttention. 40 | * ✅ Implement transformer modules: MLP, Transformer block. 41 | * ✅ Implement the full GPT model. 42 | * ✅ Implement sampling from the model. 43 | * ✅ Implement BPE encoding/decoding. 44 | 45 | Efficiency: 46 | * ✅ Replace custom linear algebra kernels with BLAS. 47 | * ✅ Stream output as each new token is generated. 48 | * ✅ Create a central set of memory buffers and reuse them for each layer. No allocations at runtime. 49 | * ✅ Add KV cache. 50 | * Parallelize `softmax` and `gelu` operations. 51 | -------------------------------------------------------------------------------- /build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | // Although this function looks imperative, note that its job is to 4 | // declaratively construct a build graph that will be executed by an external 5 | // runner. 6 | pub fn build(b: *std.Build) void { 7 | // Standard target options allow the person running `zig build` to choose 8 | // what target to build for. Here we do not override the defaults, which 9 | // means any target is allowed, and the default is native. Other options 10 | // for restricting the supported target set are available. 11 | const target = b.standardTargetOptions(.{}); 12 | 13 | // Standard optimization options allow the person running `zig build` to select 14 | // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not 15 | // set a preferred release mode, allowing the user to decide how to optimize. 16 | const optimize = b.standardOptimizeOption(.{}); 17 | 18 | const exe = b.addExecutable(.{ 19 | .name = "zig_gpt2", 20 | // In this case the main source file is merely a path, however, in more 21 | // complicated build scripts, this could be a generated file. 22 | .root_source_file = .{ .path = "src/main.zig" }, 23 | .target = target, 24 | .optimize = optimize, 25 | }); 26 | exe.linkLibC(); 27 | 28 | // Link BLAS.
If you're not on MacOS, you need to download and build OpenBLAS and uncomment the 29 | // lines below. 30 | exe.linkFramework("Accelerate"); 31 | // exe.addIncludePath("lib/OpenBLAS"); 32 | // exe.addObjectFile("lib/OpenBLAS/libopenblas.a"); 33 | 34 | // // This declares intent for the executable to be installed into the 35 | // standard location when the user invokes the "install" step (the default 36 | // step when running `zig build`). 37 | b.installArtifact(exe); 38 | 39 | // This *creates* a Run step in the build graph, to be executed when another 40 | // step is evaluated that depends on it. The next line below will establish 41 | // such a dependency. 42 | const run_cmd = b.addRunArtifact(exe); 43 | 44 | // By making the run step depend on the install step, it will be run from the 45 | // installation directory rather than directly from within the cache directory. 46 | // This is not necessary, however, if the application depends on other installed 47 | // files, this ensures they will be present and in the expected location. 48 | run_cmd.step.dependOn(b.getInstallStep()); 49 | 50 | // This allows the user to pass arguments to the application in the build 51 | // command itself, like this: `zig build run -- arg1 arg2 etc` 52 | if (b.args) |args| { 53 | run_cmd.addArgs(args); 54 | } 55 | 56 | // This creates a build step. It will be visible in the `zig build --help` menu, 57 | // and can be selected like this: `zig build run` 58 | // This will evaluate the `run` step rather than the default, which is "install". 59 | const run_step = b.step("run", "Run the app"); 60 | run_step.dependOn(&run_cmd.step); 61 | 62 | // Creates a step for unit testing. This only builds the test executable 63 | // but does not run it. 64 | const unit_tests = b.addTest(.{ 65 | .root_source_file = .{ .path = "src/tests.zig" }, 66 | .target = target, 67 | .optimize = optimize, 68 | }); 69 | // Link BLAS. If you're not on MacOS, you need to download and build OpenBLAS and uncomment the 70 | // lines below. 71 | unit_tests.linkFramework("Accelerate"); 72 | // unit_tests.addIncludePath("lib/OpenBLAS"); 73 | // unit_tests.addObjectFile("lib/OpenBLAS/libopenblas.a"); 74 | 75 | const run_unit_tests = b.addRunArtifact(unit_tests); 76 | 77 | // Similar to creating the run step earlier, this exposes a `test` step to 78 | // the `zig build --help` menu, providing a way for the user to request 79 | // running the unit tests. 80 | const test_step = b.step("test", "Run unit tests"); 81 | test_step.dependOn(&run_unit_tests.step); 82 | } 83 | -------------------------------------------------------------------------------- /download_weights.py: -------------------------------------------------------------------------------- 1 | """Downloads GPT-2 checkpoints from OpenAI. 2 | 3 | Weight tensors are transposed and dumped in raw binary so they can easily be loaded into 4 | Zig/PyTorch. The unicode->byte encoder is statically generated and dumped to json. 5 | 6 | Based on https://github.com/openai/gpt-2. 7 | """ 8 | 9 | import json 10 | import os 11 | 12 | import numpy as np 13 | import requests 14 | import tensorflow as tf 15 | from tqdm import tqdm 16 | 17 | model = "models/124M" 18 | 19 | # Download the model weights from OpenAI if they don't already exist. 
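# NOTE: "models/124M" is the smallest public GPT-2 checkpoint; OpenAI publishes the larger
# 355M, 774M, and 1558M checkpoints with the same file layout, so only `model` needs to change.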
20 | if not os.path.exists(model): 21 | os.makedirs(model) 22 | for filename in [ 23 | "checkpoint", 24 | "encoder.json", 25 | "hparams.json", 26 | "model.ckpt.data-00000-of-00001", 27 | "model.ckpt.index", 28 | "model.ckpt.meta", 29 | "vocab.bpe", 30 | ]: 31 | resp = requests.get( 32 | f"https://openaipublic.blob.core.windows.net/gpt-2/{model}/{filename}", 33 | stream=True, 34 | ) 35 | 36 | with open(f"{model}/{filename}", "wb") as file_: 37 | file_size = int(resp.headers["content-length"]) 38 | chunk_size = 1000 39 | with tqdm( 40 | ncols=100, desc=f"Fetching {filename}", total=file_size, unit_scale=True 41 | ) as pbar: 42 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes. 43 | for chunk in resp.iter_content(chunk_size=chunk_size): 44 | file_.write(chunk) 45 | pbar.update(len(chunk)) 46 | 47 | 48 | # Dump the model weights in raw binary if they don't already exist. 49 | weights_dir = f"{model}/raw" 50 | if not os.path.exists(weights_dir): 51 | os.makedirs(weights_dir) 52 | checkpoint = tf.train.load_checkpoint(model) 53 | variables = sorted(list(checkpoint.get_variable_to_shape_map().keys())) 54 | with tqdm( 55 | ncols=100, desc="Dumping raw weights", total=len(variables), unit_scale=True 56 | ) as pbar: 57 | for name in variables: 58 | tensor = checkpoint.get_tensor(name).astype(np.float32).squeeze() 59 | # Store weight tensors in column major format. 60 | if name.endswith("/w"): 61 | tensor = tensor.T 62 | fname = name.replace("/", "-") 63 | with open(f"{weights_dir}/{fname}", "wb") as file_: 64 | file_.write(tensor.reshape(-1).tobytes()) 65 | pbar.update(1) 66 | 67 | 68 | # Statically create and dump the unicode->bytes encoder. 69 | def unicode_to_bytes(): 70 | """Returns a dictionary of unicode->byte.""" 71 | bs = ( 72 | list(range(ord("!"), ord("~") + 1)) 73 | + list(range(ord("¡"), ord("¬") + 1)) 74 | + list(range(ord("®"), ord("ÿ") + 1)) 75 | ) 76 | cs = bs[:] 77 | n = 0 78 | for b in range(2**8): 79 | if b not in bs: 80 | bs.append(b) 81 | cs.append(2**8 + n) 82 | n += 1 83 | cs = [chr(n) for n in cs] 84 | # !!NOTE!!: Unlike OpenAI's implementation, we dump out unicode->bytes so we don't 85 | # have to deal with non-string JSON keys. 86 | return dict(zip(cs, bs)) 87 | 88 | 89 | with open(f"{model}/byte_encoder.json", "w") as file_: 90 | json.dump(unicode_to_bytes(), file_) 91 | -------------------------------------------------------------------------------- /generate_nano_gpt.py: -------------------------------------------------------------------------------- 1 | """Definition of GPT-2, largely copied from NanoGPT [1]. 2 | 3 | NOTE: There is some divergence from NanoGPT: 4 | * Always use biases and the default LayerNorm (like GPT-2). 5 | * Use the same vocab size as GPT-2. 6 | * Remove dropout (does not affect inference). 7 | * Stripped down GPT module which only includes forward and returns logits. 8 | * No support for PyTorch 2.0 / flash attention. 9 | 10 | [1]: https://github.com/karpathy/nanoGPT/blob/master/model.py 11 | """ 12 | 13 | import math 14 | import os 15 | from dataclasses import dataclass 16 | 17 | import numpy as np 18 | import tiktoken 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn import functional as F 22 | 23 | 24 | @dataclass 25 | class GPTConfig: 26 | vocab_size: int = 50257 27 | block_size: int = 1024 28 | n_layer: int = 12 29 | n_head: int = 12 30 | n_embd: int = 768 31 | 32 | 33 | def new_gelu(x): 34 | """Gaussian Error Linear Unit (GELU) activation function.
35 | 36 | Copied from NanoGPT and identical to OpenAI GPT-2 implementation. 37 | Paper: https://arxiv.org/abs/1606.08415 38 | """ 39 | # fmt: off 40 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 41 | # fmt: on 42 | 43 | 44 | class MLP(nn.Module): 45 | def __init__(self, config: GPTConfig): 46 | super().__init__() 47 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) 48 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) 49 | 50 | def forward(self, x): 51 | x = self.c_fc(x) 52 | x = new_gelu(x) 53 | x = self.c_proj(x) 54 | return x 55 | 56 | 57 | class CausalSelfAttention(nn.Module): 58 | def __init__(self, config: GPTConfig): 59 | super().__init__() 60 | assert config.n_embd % config.n_head == 0 61 | self.n_head = config.n_head 62 | self.n_embd = config.n_embd 63 | # Key, query, value projections for all heads, but in a batch. 64 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 65 | # Output projection. 66 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 67 | # Causal mask. 68 | bias = torch.tril(torch.ones(config.block_size, config.block_size)) 69 | bias = bias.view(1, 1, config.block_size, config.block_size) 70 | self.register_buffer("bias", bias) 71 | 72 | def forward(self, x): 73 | # Batch size, sequence length, embedding dimensionality (n_embd). 74 | B, T, C = x.size() 75 | hs = C // self.n_head 76 | 77 | # Calculate query, key, values for all heads in batch and move head forward to 78 | # be the batch dim. 79 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 80 | k = k.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs) 81 | q = q.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs) 82 | v = v.view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs) 83 | 84 | # Manual implementation of attention. 85 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 86 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 87 | att = F.softmax(att, dim=-1) 88 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 89 | 90 | # Re-assemble all head outputs side by side. 91 | y = y.transpose(1, 2).contiguous().view(B, T, C) 92 | 93 | # Output projection. 94 | y = self.c_proj(y) 95 | return y 96 | 97 | 98 | class Block(nn.Module): 99 | def __init__(self, config: GPTConfig): 100 | super().__init__() 101 | self.ln_1 = nn.LayerNorm(config.n_embd) 102 | self.attn = CausalSelfAttention(config) 103 | self.ln_2 = nn.LayerNorm(config.n_embd) 104 | self.mlp = MLP(config) 105 | 106 | def forward(self, x): 107 | x = x + self.attn(self.ln_1(x)) 108 | x = x + self.mlp(self.ln_2(x)) 109 | return x 110 | 111 | 112 | class GPT(nn.Module): 113 | def __init__(self, config): 114 | super().__init__() 115 | self.config = config 116 | self.transformer = nn.ModuleDict( 117 | dict( 118 | wte=nn.Embedding(config.vocab_size, config.n_embd), 119 | wpe=nn.Embedding(config.block_size, config.n_embd), 120 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 121 | ln_f=nn.LayerNorm(config.n_embd), 122 | ) 123 | ) 124 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 125 | # https://paperswithcode.com/method/weight-tying 126 | self.transformer.wte.weight = self.lm_head.weight 127 | 128 | def forward(self, idx): 129 | _, t = idx.size() 130 | pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # (t) 131 | 132 | # Forward the GPT model. 
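# Shape flow: idx holds (B, T) token ids, the embeddings and blocks keep (B, T, n_embd), and the
# returned logits are (B, 1, vocab_size) since only the last position is projected.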
133 | tok_emb = self.transformer.wte(idx) # token embeddings (B, T, n_embd) 134 | pos_emb = self.transformer.wpe(pos) # position embeddings (T, n_embd) 135 | x = tok_emb + pos_emb 136 | for block in self.transformer.h: 137 | x = block(x) 138 | x = self.transformer.ln_f(x) 139 | 140 | # Mini-optimization: Only forward the lm_head on the very last position. 141 | return self.lm_head(x[:, [-1], :]) # Using [-1] to preserve the time dim. 142 | 143 | @torch.no_grad() 144 | def generate(self, idx, new_tokens, temp=0.8): 145 | assert len(idx) + new_tokens <= self.config.block_size 146 | 147 | for _ in range(new_tokens): 148 | logits = self(idx)[:, -1, :] / temp 149 | probs = F.softmax(logits, dim=-1) 150 | idx_next = torch.multinomial(probs, num_samples=1) 151 | idx = torch.cat((idx, idx_next), dim=1) 152 | return idx 153 | 154 | 155 | def load_linear(module: nn.Linear, name: str, in_f: int, out_f: int) -> None: 156 | with open(f"models/124M/raw/model-{name}-w", "rb") as file_: 157 | tensor = np.frombuffer(file_.read(), dtype=np.float32) 158 | module.weight.data = torch.tensor(tensor).reshape(out_f, in_f) 159 | 160 | with open(f"models/124M/raw/model-{name}-b", "rb") as file_: 161 | tensor = np.frombuffer(file_.read(), dtype=np.float32) 162 | module.bias.data = torch.tensor(tensor).reshape(out_f) 163 | 164 | 165 | def load_layernorm(module: nn.LayerNorm, name: str) -> None: 166 | with open(f"models/124M/raw/model-{name}-g", "rb") as file_: 167 | tensor = np.frombuffer(file_.read(), dtype=np.float32) 168 | module.weight.data = torch.tensor(tensor) 169 | 170 | with open(f"models/124M/raw/model-{name}-b", "rb") as file_: 171 | tensor = np.frombuffer(file_.read(), dtype=np.float32) 172 | module.bias.data = torch.tensor(tensor) 173 | 174 | 175 | def load_attention(module: CausalSelfAttention, layer: int, n_embd: int) -> None: 176 | load_linear(module.c_attn, f"h{layer}-attn-c_attn", n_embd, 3 * n_embd) 177 | load_linear(module.c_proj, f"h{layer}-attn-c_proj", n_embd, n_embd) 178 | 179 | 180 | def load_mlp(module: MLP, layer: int, n_embd: int) -> None: 181 | load_linear(module.c_fc, f"h{layer}-mlp-c_fc", n_embd, 4 * n_embd) 182 | load_linear(module.c_proj, f"h{layer}-mlp-c_proj", 4 * n_embd, n_embd) 183 | 184 | 185 | def load_block(module: Block, layer: int, n_embd: int) -> None: 186 | load_layernorm(module.ln_1, f"h{layer}-ln_1") 187 | load_attention(module.attn, layer, n_embd) 188 | load_layernorm(module.ln_2, f"h{layer}-ln_2") 189 | load_mlp(module.mlp, layer, n_embd) 190 | 191 | 192 | def load_embedding( 193 | module: nn.Embedding, name: str, vocab_size: int, n_embd: int 194 | ) -> None: 195 | with open(f"models/124M/raw/model-{name}", "rb") as file_: 196 | tensor = np.frombuffer(file_.read(), dtype=np.float32) 197 | tensor = torch.tensor(tensor).reshape(vocab_size, n_embd) 198 | module.weight.data = tensor 199 | 200 | 201 | def load_gpt(module: GPT, config: GPTConfig) -> None: 202 | load_embedding(module.transformer.wte, "wte", config.vocab_size, config.n_embd) 203 | load_embedding(module.transformer.wpe, "wpe", config.block_size, config.n_embd) 204 | for i in range(config.n_layer): 205 | load_block(module.transformer.h[i], i, config.n_embd) 206 | load_layernorm(module.transformer.ln_f, "ln_f") 207 | # Loading wte should automatically load lm_head since they point to the same tensor. 
208 | assert module.lm_head.weight is module.transformer.wte.weight 209 | 210 | 211 | gpt_config = GPTConfig() 212 | gpt = GPT(gpt_config).eval() 213 | load_gpt(gpt, gpt_config) 214 | 215 | encoder = tiktoken.get_encoding("gpt2") 216 | encoded = encoder.encode( 217 | "Marcus Aurelius said thus: ", allowed_special={"<|endoftext|>"} 218 | ) 219 | inputs = torch.tensor(encoded).view((1, -1)) 220 | outputs = gpt(inputs) 221 | 222 | generated = gpt.generate(inputs, 10).tolist()[0] 223 | print(encoder.decode(generated)) 224 | 225 | name_to_tensor = { 226 | "gpt_inputs": inputs, 227 | "gpt_outputs": outputs, 228 | } 229 | for name, tensor in name_to_tensor.items(): 230 | if not os.path.exists(f"models/test/{name}"): 231 | with open(f"models/test/{name}", "wb") as file_: 232 | file_.write(tensor.reshape(-1).detach().numpy().tobytes()) 233 | -------------------------------------------------------------------------------- /generate_test_data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | name_to_tensor = {} 9 | 10 | # Generate Linear. 11 | inputs = torch.randn(3, 768) 12 | 13 | linear = nn.Linear(in_features=768, out_features=4 * 768) 14 | outputs = linear(inputs) 15 | 16 | linear_no_bias = nn.Linear(in_features=768, out_features=4 * 768, bias=False) 17 | linear_no_bias.weight.data = linear.weight.data 18 | outputs_no_bias = linear_no_bias(inputs) 19 | 20 | name_to_tensor.update( 21 | { 22 | "linear_weight": linear.weight, 23 | "linear_bias": linear.bias, 24 | "linear_inputs": inputs, 25 | "linear_outputs": outputs, 26 | "linear_outputs_no_bias": outputs_no_bias, 27 | } 28 | ) 29 | 30 | 31 | # Generate GELU. 32 | def gelu(x): 33 | """Gaussian Error Linear Unit (GELU) activation function. 34 | 35 | Copied from the nanogpt repo and identical to OpenAI GPT2 implementation. Paper: 36 | https://arxiv.org/abs/1606.08415 37 | """ 38 | # fmt: off 39 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 40 | # fmt: on 41 | 42 | 43 | inputs = torch.randn(3, 768) 44 | outputs = gelu(inputs) 45 | name_to_tensor.update({"gelu_inputs": inputs, "gelu_outputs": outputs}) 46 | 47 | 48 | # Generate softmax. 49 | inputs = torch.randn(3, 768) 50 | outputs = F.softmax(inputs, dim=-1) 51 | name_to_tensor.update({"softmax_inputs": inputs, "softmax_outputs": outputs}) 52 | 53 | 54 | # Generate Embedding. 55 | embedding = nn.Embedding(10, 768) 56 | inputs = torch.randint(0, 10, (3,)) 57 | outputs = embedding(inputs) 58 | name_to_tensor.update( 59 | { 60 | "embedding_weight": embedding.weight, 61 | "embedding_inputs": inputs, 62 | "embedding_outputs": outputs, 63 | } 64 | ) 65 | 66 | 67 | # Generate LayerNorm. 68 | layer_norm = nn.LayerNorm(768) 69 | inputs = torch.randn(3, 768) 70 | outputs = layer_norm(inputs) 71 | name_to_tensor.update( 72 | { 73 | "layer_norm_weight": layer_norm.weight, 74 | "layer_norm_bias": layer_norm.bias, 75 | "layer_norm_inputs": inputs, 76 | "layer_norm_outputs": outputs, 77 | } 78 | ) 79 | 80 | 81 | # Generate causal self attention. 82 | batch_size, seq_len, n_head, head_dim = 1, 5, 12, 64 83 | n_embed = n_head * head_dim 84 | 85 | # Generate transpose intermediaries. 
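# Reference data for the (B, T, n_head, head_dim) -> (B, n_head, T, head_dim) transpose that the
# attention implementation performs before computing attention scores.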
86 | inputs = torch.randn(batch_size, seq_len, n_head, head_dim) 87 | outputs = inputs.transpose(1, 2) 88 | name_to_tensor.update({"transpose_inputs": inputs, "transpose_outputs": outputs}) 89 | 90 | # Generate split intermediaries. 91 | inputs = torch.randn(batch_size, seq_len, 3 * n_embed) 92 | q, k, v = inputs.split(n_embed, dim=2) 93 | name_to_tensor.update( 94 | {"split_inputs": inputs, "split_q": q, "split_k": k, "split_v": v} 95 | ) 96 | 97 | 98 | inputs = torch.randn(batch_size, seq_len, n_embed) 99 | c_attn = nn.Linear(in_features=n_embed, out_features=3 * n_embed) 100 | outputs = c_attn(inputs) 101 | name_to_tensor.update( 102 | { 103 | "attn_inputs": inputs, 104 | "attn_c_attn_weight": c_attn.weight, 105 | "attn_c_attn_bias": c_attn.bias, 106 | } 107 | ) 108 | 109 | # Generate intermediaries from scaled dot product attention. 110 | q, k, v = outputs.split(n_embed, dim=2) 111 | q = q.view(batch_size, seq_len, n_head, n_embed // n_head).transpose(1, 2) 112 | k = k.view(batch_size, seq_len, n_head, n_embed // n_head).transpose(1, 2) 113 | v = v.view(batch_size, seq_len, n_head, n_embed // n_head).transpose(1, 2) 114 | mask = torch.tril(torch.ones(seq_len, seq_len).view(1, 1, seq_len, seq_len)) 115 | attn = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1)) 116 | attn = attn.masked_fill(mask[:, :, :seq_len, :seq_len] == 0, float("-inf")) 117 | attn = F.softmax(attn, dim=-1) 118 | outputs = attn @ v 119 | name_to_tensor.update({"sdpa_q": q, "sdpa_k": k, "sdpa_v": v, "sdpa_outputs": outputs}) 120 | 121 | inputs = outputs.transpose(1, 2).contiguous().view(batch_size, seq_len, n_embed) 122 | c_proj = nn.Linear(n_embed, n_embed) 123 | outputs = c_proj(inputs) 124 | name_to_tensor.update( 125 | { 126 | "attn_c_proj_weight": c_proj.weight, 127 | "attn_c_proj_bias": c_proj.bias, 128 | "attn_outputs": outputs, 129 | } 130 | ) 131 | 132 | for name, tensor in name_to_tensor.items(): 133 | if not os.path.exists(f"models/test/{name}"): 134 | with open(f"models/test/{name}", "wb") as file_: 135 | file_.write(tensor.reshape(-1).detach().numpy().tobytes()) 136 | -------------------------------------------------------------------------------- /src/bpe.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const c = @cImport(@cInclude("regex.h")); 3 | 4 | pub const Encoder = struct { 5 | const Self = @This(); 6 | const int_to_str_t = std.hash_map.AutoHashMap(usize, []const u8); 7 | 8 | token_to_idx: std.json.ObjectMap, 9 | idx_to_token: int_to_str_t, 10 | unicode_to_byte: std.json.ObjectMap, 11 | byte_to_unicode: int_to_str_t, 12 | regex: *c.regex_t, 13 | 14 | pub fn init( 15 | token_to_idx: std.json.ObjectMap, 16 | unicode_to_byte: std.json.ObjectMap, 17 | allocator: std.mem.Allocator, 18 | ) !Self { 19 | // Setup encoders. 20 | var idx_to_token = int_to_str_t.init(allocator); 21 | var it = token_to_idx.iterator(); 22 | while (it.next()) |item| { 23 | try idx_to_token.put(@intCast(item.value_ptr.*.integer), item.key_ptr.*); 24 | } 25 | var byte_to_unicode = int_to_str_t.init(allocator); 26 | it = unicode_to_byte.iterator(); 27 | while (it.next()) |item| { 28 | try byte_to_unicode.put(@intCast(item.value_ptr.*.integer), item.key_ptr.*); 29 | } 30 | 31 | // Setup regex. 
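// The pattern below approximates GPT-2's BPE pre-tokenizer: contractions, optionally
// space-prefixed runs of letters, digits, or other symbols, and runs of whitespace, expressed
// with POSIX character classes instead of the original PCRE \p{L}/\p{N} classes.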
32 | var slice = try allocator.alignedAlloc(u8, @alignOf(c.regex_t), @sizeOf(c.regex_t)); 33 | const regex = @as(*c.regex_t, @ptrCast(slice.ptr)); 34 | const contractions = "'s|'t|'re|'ve|'m|'ll|'d"; 35 | const letters = "|[[:space:]]?[[:alpha:]]+"; 36 | const numbers = "|[[:space:]]?[[:digit:]]+"; 37 | const others = "|[[:space:]]?[^[:space:][:alpha:][:digit:]]+"; 38 | // TODO(eugenhotaj): Multiple spaces between tokens are not handled correctly! 39 | const space = "|[[:space:]]+"; 40 | _ = c.regcomp(regex, contractions ++ letters ++ numbers ++ others ++ space, c.REG_EXTENDED); 41 | 42 | return Self{ 43 | .token_to_idx = token_to_idx, 44 | .idx_to_token = idx_to_token, 45 | .unicode_to_byte = unicode_to_byte, 46 | .byte_to_unicode = byte_to_unicode, 47 | .regex = regex, 48 | }; 49 | } 50 | 51 | pub fn deinit(self: *Self) void { 52 | self.idx_to_token.deinit(); 53 | self.token_to_idx.deinit(); 54 | self.unicode_to_byte.deinit(); 55 | self.byte_to_unicode.deinit(); 56 | c.regfree(self.regex); 57 | } 58 | 59 | pub fn encode(self: Self, inputs: []const u8, outputs: []usize) usize { 60 | var matches: [1]c.regmatch_t = undefined; 61 | var token_idx: usize = 0; 62 | var offset: usize = 0; 63 | while (offset < inputs.len) { 64 | // Match next word. 65 | _ = c.regexec(self.regex, inputs[offset..].ptr, matches.len, &matches, 0); 66 | const match = matches[0]; 67 | const match_so = offset + @as(usize, @intCast(match.rm_so)); 68 | const match_eo = offset + @as(usize, @intCast(match.rm_eo)); 69 | 70 | // Replace bytes with unicode. 71 | var word: [20]u8 = undefined; 72 | var word_eo: usize = 0; 73 | for (inputs[match_so..match_eo]) |byte| { 74 | for (self.byte_to_unicode.get(byte).?) |code| { 75 | word[word_eo] = code; 76 | word_eo += 1; 77 | } 78 | } 79 | 80 | // Tokenize word. 81 | var token_so: usize = 0; 82 | var token_eo = word_eo; 83 | while (token_so < token_eo) { 84 | if (self.token_to_idx.contains(word[token_so..token_eo])) { 85 | outputs[token_idx] = @intCast(self.token_to_idx.get(word[token_so..token_eo]).?.integer); 86 | token_so = token_eo; 87 | token_eo = word_eo; 88 | token_idx += 1; 89 | } else { 90 | token_eo -= 1; 91 | } 92 | } 93 | 94 | offset = match_eo; 95 | } 96 | return token_idx; 97 | } 98 | 99 | pub fn decode(self: Self, inputs: []const usize, outputs: []u8) usize { 100 | var outputs_len: usize = 0; 101 | for (inputs) |idx| { 102 | const token = self.idx_to_token.get(idx).?; 103 | var i: usize = 0; 104 | while (i < token.len) { 105 | var char: []const u8 = undefined; 106 | if (self.unicode_to_byte.contains(token[i .. i + 1])) { 107 | char = token[i .. i + 1]; 108 | i += 1; 109 | } else { 110 | char = token[i .. 
i + 2]; 111 | i += 2; 112 | } 113 | outputs[outputs_len] = @intCast(self.unicode_to_byte.get(char).?.integer); 114 | outputs_len += 1; 115 | } 116 | } 117 | return outputs_len; 118 | } 119 | }; 120 | -------------------------------------------------------------------------------- /src/main.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const ops = @import("ops.zig"); 3 | const bpe = @import("bpe.zig"); 4 | 5 | const GPTConfig = struct { 6 | const Self = @This(); 7 | 8 | vocab_size: usize, 9 | context_size: usize, 10 | n_layer: usize, 11 | n_heads: usize, 12 | n_embed: usize, 13 | 14 | pub fn init(vocab_size: usize, context_size: usize, n_layer: usize, n_heads: usize, n_embed: usize) Self { 15 | return Self{ 16 | .vocab_size = vocab_size, 17 | .context_size = context_size, 18 | .n_layer = n_layer, 19 | .n_heads = n_heads, 20 | .n_embed = n_embed, 21 | }; 22 | } 23 | }; 24 | 25 | /// Structure which maintains state which is shared across all GPT layers. 26 | pub const State = struct { 27 | const Self = @This(); 28 | 29 | pos_emb: []f32, 30 | x: []f32, 31 | o: []f32, 32 | logits: []f32, 33 | decoded: []u8, 34 | 35 | // Intermediate buffers. 36 | _h: []f32, 37 | _4xh: []f32, 38 | _qkv: []f32, 39 | _q: []f32, 40 | _k: []f32, 41 | _v: []f32, 42 | _attn: []f32, 43 | 44 | allocator: std.mem.Allocator, 45 | 46 | pub fn init(config: GPTConfig, allocator: std.mem.Allocator) !Self { 47 | return Self{ 48 | .pos_emb = try allocator.alloc(f32, 1 * config.n_embed), 49 | .x = try allocator.alloc(f32, 1 * config.n_embed), 50 | .o = try allocator.alloc(f32, 1 * config.n_embed), 51 | .logits = try allocator.alloc(f32, config.vocab_size), 52 | .decoded = try allocator.alloc(u8, 20), 53 | 54 | ._h = try allocator.alloc(f32, 1 * config.n_embed), 55 | ._4xh = try allocator.alloc(f32, 1 * 4 * config.n_embed), 56 | ._qkv = try allocator.alloc(f32, 1 * 3 * config.n_embed), 57 | ._q = try allocator.alloc(f32, 1 * config.n_embed), 58 | ._k = try allocator.alloc(f32, config.context_size * config.n_embed), 59 | ._v = try allocator.alloc(f32, config.context_size * config.n_embed), 60 | ._attn = try allocator.alloc(f32, 1 * config.context_size), 61 | 62 | .allocator = allocator, 63 | }; 64 | } 65 | }; 66 | 67 | const MLP = struct { 68 | const Self = @This(); 69 | 70 | c_fc: ops.Linear, 71 | c_proj: ops.Linear, 72 | 73 | pub fn init(c_fc: ops.Linear, c_proj: ops.Linear) MLP { 74 | return MLP{ .c_fc = c_fc, .c_proj = c_proj }; 75 | } 76 | 77 | /// Computes the forward pass and writes the result to state.o. 78 | pub fn forward(self: Self, inputs: []const f32, state: State) void { 79 | self.c_fc.forward(inputs, state._4xh); 80 | ops.gelu(state._4xh); 81 | self.c_proj.forward(state._4xh, state.o); 82 | } 83 | }; 84 | 85 | const Block = struct { 86 | const Self = @This(); 87 | 88 | n_embed: usize, 89 | ln_1: ops.LayerNorm, 90 | attn: ops.CausalSelfAttention, 91 | ln_2: ops.LayerNorm, 92 | mlp: MLP, 93 | k_cache: []f32, 94 | v_cache: []f32, 95 | 96 | pub fn init( 97 | n_embed: usize, 98 | ln_1: ops.LayerNorm, 99 | attn: ops.CausalSelfAttention, 100 | ln_2: ops.LayerNorm, 101 | mlp: MLP, 102 | k_cache: []f32, 103 | v_cache: []f32, 104 | ) Self { 105 | return Self{ 106 | .n_embed = n_embed, 107 | .ln_1 = ln_1, 108 | .attn = attn, 109 | .ln_2 = ln_2, 110 | .mlp = mlp, 111 | .k_cache = k_cache, 112 | .v_cache = v_cache, 113 | }; 114 | } 115 | 116 | /// Computes the forward pass and writes the result to both state.x and state.o. 
This 117 | /// enables sequentially calling multiple Block.forwards() in a row without having to copy 118 | /// memory. 119 | pub fn forward(self: Self, seq_len: usize, inputs: []const f32, state: State) void { 120 | // Create a copy of x for residual computation. 121 | @memcpy(state._h, inputs); 122 | 123 | self.ln_1.forward(state._h); 124 | self.attn.forward( 125 | seq_len, 126 | state._h, 127 | self.k_cache[0 .. seq_len * self.n_embed], 128 | self.v_cache[0 .. seq_len * self.n_embed], 129 | state.o, 130 | state._qkv, 131 | state._q, 132 | state._k[0 .. seq_len * self.n_embed], 133 | state._v[0 .. seq_len * self.n_embed], 134 | state._attn[0..seq_len], 135 | ); 136 | for (0..state.o) |i| { 137 | state._h[i] = state.o[i] + inputs[i]; 138 | state.x[i] = state._h[i]; 139 | } 140 | self.ln_2.forward(state._h); 141 | self.mlp.forward(state._h, state); 142 | for (0..state.o) |i| { 143 | state.o[i] += state.x[i]; 144 | state.x[i] = state.o[i]; 145 | } 146 | } 147 | }; 148 | 149 | const GPT = struct { 150 | const Self = @This(); 151 | 152 | config: GPTConfig, 153 | wte: ops.Embedding, 154 | wpe: ops.Embedding, 155 | h: []const Block, 156 | ln_f: ops.LayerNorm, 157 | lm_head: ops.Linear, 158 | 159 | pub fn init( 160 | config: GPTConfig, 161 | wte: ops.Embedding, 162 | wpe: ops.Embedding, 163 | h: []const Block, 164 | ln_f: ops.LayerNorm, 165 | lm_head: ops.Linear, 166 | ) Self { 167 | return Self{ 168 | .config = config, 169 | .wte = wte, 170 | .wpe = wpe, 171 | .h = h, 172 | .ln_f = ln_f, 173 | .lm_head = lm_head, 174 | }; 175 | } 176 | 177 | /// Computes the forward pass and writes the result in state.logits. 178 | pub fn forward(self: Self, seq_len: usize, token: usize, compute_logits: bool, state: State) void { 179 | self.wpe.forward(&[1]usize{seq_len - 1}, state.pos_emb); 180 | self.wte.forward(&[1]usize{token}, state.x); 181 | for (0..self.config.n_embed) |i| { 182 | state.x[i] += state.pos_emb[i]; 183 | } 184 | 185 | // Forward the transformer. 186 | for (0..self.h.len) |i| { 187 | self.h[i].forward(seq_len, state.x, state); 188 | } 189 | self.ln_f.forward(state.x); 190 | 191 | // Compute logits. 192 | if (compute_logits) { 193 | self.lm_head.forward(state.x, state.logits); 194 | } 195 | } 196 | 197 | /// Samples the next token. 
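/// Scales the logits by 1/temp, applies softmax, and draws the next token index from the
/// resulting distribution.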
198 | pub fn sample(self: Self, seq_len: usize, temp: f32, token: usize, state: State) usize { 199 | self.forward(seq_len, token, true, state); 200 | for (0..state.logits.len) |i| { 201 | state.logits[i] /= temp; 202 | } 203 | ops.softmax(state.logits); 204 | var rng = std.rand.DefaultPrng.init(@intCast(std.time.timestamp())); 205 | var random = rng.random(); 206 | return random.weightedIndex(f32, state.logits); 207 | } 208 | }; 209 | 210 | pub fn load_linear( 211 | name: []const u8, 212 | in_features: usize, 213 | out_features: usize, 214 | allocator: std.mem.Allocator, 215 | ) !ops.Linear { 216 | const weight_path = try std.fmt.allocPrint(allocator, "models/124M/raw/model-{s}-w", .{name}); 217 | defer allocator.free(weight_path); 218 | var weight = try ops.load_tensor( 219 | weight_path, 220 | &[_]usize{ in_features, out_features }, 221 | f32, 222 | allocator, 223 | ); 224 | const bias_path = try std.fmt.allocPrint(allocator, "models/124M/raw/model-{s}-b", .{name}); 225 | defer allocator.free(bias_path); 226 | var bias = try ops.load_tensor( 227 | bias_path, 228 | &[_]usize{out_features}, 229 | f32, 230 | allocator, 231 | ); 232 | return ops.Linear.init(in_features, out_features, weight, bias); 233 | } 234 | 235 | pub fn load_layer_norm( 236 | name: []const u8, 237 | n_features: usize, 238 | allocator: std.mem.Allocator, 239 | ) !ops.LayerNorm { 240 | const weight_path = try std.fmt.allocPrint(allocator, "models/124M/raw/model-{s}-g", .{name}); 241 | defer allocator.free(weight_path); 242 | var weight = try ops.load_tensor( 243 | weight_path, 244 | &[_]usize{n_features}, 245 | f32, 246 | allocator, 247 | ); 248 | const bias_path = try std.fmt.allocPrint(allocator, "models/124M/raw/model-{s}-b", .{name}); 249 | defer allocator.free(bias_path); 250 | var bias = try ops.load_tensor( 251 | bias_path, 252 | &[_]usize{n_features}, 253 | f32, 254 | allocator, 255 | ); 256 | return ops.LayerNorm.init(n_features, weight, bias); 257 | } 258 | 259 | pub fn load_embedding(name: []const u8, vocab_size: usize, emb_dim: usize, allocator: std.mem.Allocator) !ops.Embedding { 260 | const path = try std.fmt.allocPrint(allocator, "models/124M/raw/model-{s}", .{name}); 261 | defer allocator.free(path); 262 | var weight = try ops.load_tensor( 263 | path, 264 | &[_]usize{ vocab_size, emb_dim }, 265 | f32, 266 | allocator, 267 | ); 268 | return ops.Embedding.init(emb_dim, weight); 269 | } 270 | 271 | pub fn load_block(layer_idx: usize, config: GPTConfig, allocator: std.mem.Allocator) !Block { 272 | const ln_1_name = try std.fmt.allocPrint(allocator, "h{any}-ln_1", .{layer_idx}); 273 | defer allocator.free(ln_1_name); 274 | const ln_1 = try load_layer_norm(ln_1_name, config.n_embed, allocator); 275 | 276 | const c_attn_name = try std.fmt.allocPrint(allocator, "h{any}-attn-c_attn", .{layer_idx}); 277 | defer allocator.free(c_attn_name); 278 | const c_attn = try load_linear(c_attn_name, config.n_embed, 3 * config.n_embed, allocator); 279 | 280 | const c_proj_name = try std.fmt.allocPrint(allocator, "h{any}-attn-c_proj", .{layer_idx}); 281 | defer allocator.free(c_proj_name); 282 | const c_proj = try load_linear(c_proj_name, config.n_embed, config.n_embed, allocator); 283 | 284 | const ln_2_name = try std.fmt.allocPrint(allocator, "h{any}-ln_2", .{layer_idx}); 285 | defer allocator.free(ln_2_name); 286 | const ln_2 = try load_layer_norm(ln_2_name, config.n_embed, allocator); 287 | 288 | const c_fc_name = try std.fmt.allocPrint(allocator, "h{any}-mlp-c_fc", .{layer_idx}); 289 | defer allocator.free(c_fc_name); 290 | 
const c_fc = try load_linear(c_fc_name, config.n_embed, 4 * config.n_embed, allocator); 291 | 292 | const mlp_c_proj_name = try std.fmt.allocPrint(allocator, "h{any}-mlp-c_proj", .{layer_idx}); 293 | defer allocator.free(mlp_c_proj_name); 294 | const mlp_c_proj = try load_linear(mlp_c_proj_name, 4 * config.n_embed, config.n_embed, allocator); 295 | 296 | const attn = ops.CausalSelfAttention.init(config.n_heads, config.n_embed, c_attn, c_proj); 297 | const mlp = MLP.init(c_fc, mlp_c_proj); 298 | const k_cache = try allocator.alloc(f32, config.context_size * config.n_embed); 299 | const v_cache = try allocator.alloc(f32, config.context_size * config.n_embed); 300 | 301 | return Block.init(config.n_embed, ln_1, attn, ln_2, mlp, k_cache, v_cache); 302 | } 303 | 304 | pub fn load_gpt(config: GPTConfig, allocator: std.mem.Allocator) !GPT { 305 | var wte = try load_embedding("wte", config.vocab_size, config.n_embed, allocator); 306 | const wpe = try load_embedding("wpe", config.context_size, config.n_embed, allocator); 307 | var h = try allocator.alloc(Block, config.n_layer); 308 | for (0..h.len) |i| { 309 | h[i] = try load_block(i, config, allocator); 310 | } 311 | const ln_f = try load_layer_norm("ln_f", config.n_embed, allocator); 312 | const lm_head = ops.Linear.init(config.n_embed, config.vocab_size, wte.weight, null); 313 | return GPT.init(config, wte, wpe, h, ln_f, lm_head); 314 | } 315 | 316 | pub fn load_encoder(allocator: std.mem.Allocator) !bpe.Encoder { 317 | const parsed_encoder = try ops.load_json("models/124M/encoder.json", allocator); 318 | const parsed_bytes_encoder = try ops.load_json("models/124M/byte_encoder.json", allocator); 319 | return bpe.Encoder.init(parsed_encoder.object, parsed_bytes_encoder.object, allocator); 320 | } 321 | 322 | pub fn generate( 323 | gpt: GPT, 324 | encoder: bpe.Encoder, 325 | temp: f32, 326 | inputs: []usize, 327 | state: State, 328 | ) void { 329 | var token: usize = undefined; 330 | for (0..gpt.config.context_size) |s| { 331 | if (s < inputs.len) { 332 | // Fill up KV cache. 333 | token = inputs[s]; 334 | gpt.forward(s + 1, token, false, state); 335 | } else { 336 | // Generate. 
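// Feed the previously emitted token back in one position at a time; keys and values for all
// earlier positions are already stored in each block's KV cache.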
337 | token = gpt.sample(s + 1, temp, token, state); 338 | } 339 | const decoded_len = encoder.decode(&[_]usize{token}, state.decoded); 340 | std.debug.print("{s}", .{state.decoded[0..decoded_len]}); 341 | } 342 | } 343 | 344 | pub fn main() !void { 345 | const temp = 0.8; 346 | const config = GPTConfig.init(50257, 1024, 12, 12, 768); 347 | 348 | var gpa = std.heap.GeneralPurposeAllocator(.{}){}; 349 | var arena = std.heap.ArenaAllocator.init(gpa.allocator()); 350 | defer arena.deinit(); 351 | const allocator = arena.allocator(); 352 | 353 | var inputs = try allocator.alloc(usize, config.context_size); 354 | var encoder = try load_encoder(allocator); 355 | defer encoder.deinit(); 356 | var state = try State.init(config, allocator); 357 | const gpt = try load_gpt(config, allocator); 358 | 359 | const args = try std.process.argsAlloc(allocator); 360 | defer std.process.argsFree(allocator, args); 361 | const prompt = args[1]; 362 | 363 | const input_tokens = encoder.encode(prompt, inputs); 364 | generate( 365 | gpt, 366 | encoder, 367 | temp, 368 | inputs[0..input_tokens], 369 | state, 370 | ); 371 | } 372 | -------------------------------------------------------------------------------- /src/ops.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const c = @cImport(@cInclude("Accelerate/Accelerate.h")); 3 | 4 | pub const Linear = struct { 5 | const Self = @This(); 6 | 7 | in_features: usize, 8 | out_features: usize, 9 | weight: []const f32, // Weights must be provided in *column major* order! 10 | bias: ?[]const f32, 11 | 12 | pub fn init(in_features: usize, out_features: usize, weight: []const f32, bias: ?[]const f32) Self { 13 | return Self{ 14 | .in_features = in_features, 15 | .out_features = out_features, 16 | .weight = weight, 17 | .bias = bias, 18 | }; 19 | } 20 | 21 | pub fn forward(self: Self, inputs: []const f32, outputs: []f32) void { 22 | const batch_size = inputs.len / self.in_features; 23 | var beta: f32 = 0.0; 24 | if (self.bias) |bias| { 25 | for (0..batch_size) |b| { 26 | @memcpy(outputs[b * self.out_features .. (b + 1) * self.out_features], bias); 27 | } 28 | beta = 1.0; 29 | } 30 | c.cblas_sgemm( 31 | c.CblasRowMajor, 32 | c.CblasNoTrans, 33 | c.CblasTrans, 34 | @intCast(batch_size), 35 | @intCast(self.out_features), 36 | @intCast(self.in_features), 37 | 1.0, 38 | inputs.ptr, 39 | @intCast(self.in_features), 40 | self.weight.ptr, 41 | @intCast(self.in_features), 42 | beta, 43 | outputs.ptr, 44 | @intCast(self.out_features), 45 | ); 46 | } 47 | }; 48 | 49 | pub const Embedding = struct { 50 | const Self = @This(); 51 | 52 | emb_dim: usize, 53 | weight: []const f32, 54 | 55 | pub fn init(emb_dim: usize, weight: []const f32) Self { 56 | return Self{ .emb_dim = emb_dim, .weight = weight }; 57 | } 58 | 59 | pub fn forward(self: Self, idxs: []const usize, embeddings: []f32) void { 60 | for (0..idxs.len) |i| { 61 | const idx = idxs[i]; 62 | @memcpy( 63 | embeddings[i * self.emb_dim .. (i + 1) * self.emb_dim], 64 | self.weight[self.emb_dim * idx .. 
self.emb_dim * (idx + 1)], 65 | ); 66 | } 67 | } 68 | }; 69 | 70 | pub const LayerNorm = struct { 71 | const Self = @This(); 72 | 73 | n_features: usize, 74 | weight: []const f32, 75 | bias: []const f32, 76 | eps: f32 = 1e-5, 77 | 78 | pub fn init(n_features: usize, weight: []const f32, bias: []const f32) Self { 79 | return Self{ .n_features = n_features, .weight = weight, .bias = bias }; 80 | } 81 | 82 | pub fn forward(self: Self, inputs: []f32) void { 83 | const batch_size = inputs.len / self.n_features; 84 | for (0..batch_size) |b| { 85 | // Compute the mean and variance. 86 | var mean: f32 = 0.0; 87 | var std_: f32 = 0.0; 88 | for (0..self.n_features) |i| { 89 | const x = inputs[b * self.n_features + i]; 90 | mean += x; 91 | std_ += x * x; 92 | } 93 | const n: f32 = @floatFromInt(self.n_features); 94 | mean /= n; 95 | std_ = @sqrt((std_ / n) - (mean * mean) + self.eps); 96 | 97 | // Normalize. 98 | for (0..self.n_features) |i| { 99 | const idx = b * self.n_features + i; 100 | const x = inputs[idx]; 101 | inputs[idx] = (x - mean) / std_ * self.weight[i] + self.bias[i]; 102 | } 103 | } 104 | } 105 | }; 106 | 107 | pub const CausalSelfAttention = struct { 108 | const Self = @This(); 109 | 110 | n_heads: usize, 111 | n_embed: usize, 112 | head_dim: usize, 113 | c_attn: Linear, 114 | c_proj: Linear, 115 | 116 | pub fn init(n_heads: usize, n_embed: usize, c_attn: Linear, c_proj: Linear) Self { 117 | return Self{ 118 | .n_heads = n_heads, 119 | .n_embed = n_embed, 120 | .head_dim = n_embed / n_heads, 121 | .c_attn = c_attn, 122 | .c_proj = c_proj, 123 | }; 124 | } 125 | 126 | // TODO(eugenhotaj): Remove the batch_size == 1 restriction. We impose this restriction right 127 | // now because extending the KV cache for larger batch sizes is a bit tedious. It involves 128 | // expanding the sequence dimension which requires copying and moving around memory. 129 | pub fn forward( 130 | self: Self, 131 | seq_len: usize, 132 | inputs: []const f32, 133 | k_cache: []f32, 134 | v_cache: []f32, 135 | outputs: []f32, 136 | // Parameters below are intermediate buffers used inside the function. 137 | _qkv: []f32, 138 | _q: []f32, 139 | _k: []f32, 140 | _v: []f32, 141 | _attn: []f32, 142 | ) void { 143 | self.c_attn.forward(inputs, _qkv); 144 | 145 | // Q: 1 * n_embed. 146 | self.split_qkv(1, _qkv, 0, outputs); 147 | Self.transpose([3]usize{ 1, self.n_heads, self.head_dim }, outputs, _q); 148 | 149 | const t_shape = [3]usize{ seq_len, self.n_heads, self.head_dim }; 150 | // Extend K: 1 * n_embed --> seq_len * n_embed. 151 | self.split_qkv(1, _qkv, 1, outputs); 152 | @memcpy(k_cache[(seq_len - 1) * self.n_embed .. seq_len * self.n_embed], outputs); 153 | Self.transpose(t_shape, k_cache, _k); 154 | 155 | // Extend V: 1 * n_embed --> seq_len * n_embed. 156 | self.split_qkv(1, _qkv, 2, outputs); 157 | @memcpy(v_cache[(seq_len - 1) * self.n_embed .. seq_len * self.n_embed], outputs); 158 | Self.transpose(t_shape, v_cache, _v); 159 | 160 | scaled_dot_product_attention( 161 | _q, 162 | _k, 163 | _v, 164 | self.n_heads, 165 | seq_len, 166 | self.head_dim, 167 | outputs, 168 | _attn, 169 | ); 170 | // Hack: Store untranspose in _q so we don't need to keep another buffer. 171 | Self.transpose([3]usize{ self.n_heads, 1, self.head_dim }, outputs, _q); 172 | self.c_proj.forward(_q, outputs); 173 | } 174 | 175 | /// Splits (seq_len, 3 * n_embed) -> (batch_size, n_heads, n_embed). The split_index 176 | /// determines which split to return. 
177 | pub fn split_qkv( 178 | self: Self, 179 | seq_len: usize, 180 | inputs: []const f32, 181 | split_idx: usize, 182 | outputs: []f32, 183 | ) void { 184 | const n_embed_ = 3 * self.n_embed; 185 | const batch_size = inputs.len / (seq_len * n_embed_); 186 | for (0..batch_size) |b| { 187 | for (0..seq_len) |r| { 188 | const out_offset = (b * seq_len * self.n_embed) + (r * self.n_embed); 189 | const in_offset = (b * seq_len * n_embed_) + (r * n_embed_) + (split_idx * self.n_embed); 190 | @memcpy( 191 | outputs[out_offset .. out_offset + self.n_embed], 192 | inputs[in_offset .. in_offset + self.n_embed], 193 | ); 194 | } 195 | } 196 | } 197 | 198 | // Transposes (b, t, n, h) --> (b, n, t, h) where shape contains the sizes of (t, n, h). 199 | pub fn transpose(shape: [3]usize, inputs: []const f32, outputs: []f32) void { 200 | const seq_len = shape[0]; 201 | const n_heads = shape[1]; 202 | const head_dim = shape[2]; 203 | const batch_size = inputs.len / (seq_len * n_heads * head_dim); 204 | for (0..batch_size) |b| { 205 | for (0..n_heads) |h| { 206 | for (0..seq_len) |s| { 207 | const in_offset = (b * seq_len * n_heads * head_dim) + (s * n_heads * head_dim) + (h * head_dim); 208 | const out_offset = (b * seq_len * n_heads * head_dim) + (h * seq_len * head_dim) + (s * head_dim); 209 | @memcpy( 210 | outputs[out_offset .. out_offset + head_dim], 211 | inputs[in_offset .. in_offset + head_dim], 212 | ); 213 | } 214 | } 215 | } 216 | } 217 | }; 218 | 219 | /// Computes Gaussian Error Linear Unit (GELU) activation on the given inputs tensor inplace. 220 | /// Paper: https://arxiv.org/abs/1606.08415 221 | pub fn gelu(inputs: []f32) void { 222 | for (0..inputs.len) |i| { 223 | const x = inputs[i]; 224 | inputs[i] = 0.5 * x * (1.0 + std.math.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))); 225 | // Faster, but less accurate gelu. 226 | // inputs[i] = x / (1.0 + @exp(-1.702 * x)); 227 | } 228 | } 229 | 230 | /// Computes the (stable) softmax of the given inputs vector inplace. 231 | pub fn softmax(inputs: []f32) void { 232 | const max = std.mem.max(f32, inputs); 233 | var sum: f32 = 0.0; 234 | for (0..inputs.len) |i| { 235 | inputs[i] = @exp(inputs[i] - max); 236 | sum += inputs[i]; 237 | } 238 | for (0..inputs.len) |i| { 239 | inputs[i] /= sum; 240 | } 241 | } 242 | 243 | /// Computes the scaled dot product attention. 244 | /// 245 | /// The dimensions of the input tensors are expected to be: 246 | /// q: batch_size * n_heads * 1 * head_dim 247 | /// k: batch_size * n_heads * seq_len * head_dim 248 | /// v: batch_size * n_heads * seq_len * head_dim 249 | pub fn scaled_dot_product_attention( 250 | q: []const f32, 251 | k: []const f32, 252 | v: []const f32, 253 | n_heads: usize, 254 | seq_len: usize, 255 | head_dim: usize, 256 | outputs: []f32, 257 | _attn: []f32, // Intermediate buffers used inside the function. 258 | ) void { 259 | const batch_size = k.len / (n_heads * seq_len * head_dim); 260 | for (0..batch_size) |b| { 261 | for (0..n_heads) |h| { 262 | const qo_offset = (b * n_heads * 1 * head_dim) + (h * 1 * head_dim); 263 | const kv_offset = (b * n_heads * seq_len * head_dim) + (h * seq_len * head_dim); 264 | 265 | // Compute attention logits, i.e. attn = softmax((q @ k.T) / sqrt(head_dim)). 266 | var q_slice = q[qo_offset .. qo_offset + 1 * head_dim]; 267 | var k_slice = k[kv_offset .. 
kv_offset + seq_len * head_dim]; 268 | c.cblas_sgemm( 269 | c.CblasRowMajor, 270 | c.CblasNoTrans, 271 | c.CblasTrans, 272 | 1, 273 | @intCast(seq_len), 274 | @intCast(head_dim), 275 | 1.0 / @sqrt(@as(f32, @floatFromInt(head_dim))), 276 | q_slice.ptr, 277 | @intCast(head_dim), 278 | k_slice.ptr, 279 | @intCast(head_dim), 280 | 0.0, 281 | _attn.ptr, 282 | @intCast(seq_len), 283 | ); 284 | softmax(_attn); 285 | 286 | // Compute attn @ v. 287 | var v_slice = v[kv_offset .. kv_offset + seq_len * head_dim]; 288 | var out_slice = outputs[qo_offset .. qo_offset + 1 * head_dim]; 289 | c.cblas_sgemm( 290 | c.CblasRowMajor, 291 | c.CblasNoTrans, 292 | c.CblasNoTrans, 293 | 1, 294 | @intCast(head_dim), 295 | @intCast(seq_len), 296 | 1.0, 297 | _attn.ptr, 298 | @intCast(seq_len), 299 | v_slice.ptr, 300 | @intCast(head_dim), 301 | 0.0, 302 | out_slice.ptr, 303 | @intCast(head_dim), 304 | ); 305 | } 306 | } 307 | } 308 | 309 | pub fn load_tensor(path: []const u8, shape: []const usize, comptime dtype: type, allocator: std.mem.Allocator) ![]dtype { 310 | var n_elements: usize = 1; 311 | for (shape) |item| { 312 | n_elements *= item; 313 | } 314 | var tensor = try allocator.alloc(dtype, n_elements); 315 | 316 | const fd = try std.fs.cwd().openFile(path, .{}); 317 | defer fd.close(); 318 | _ = try fd.readAll(std.mem.sliceAsBytes(tensor)); 319 | return tensor; 320 | } 321 | 322 | pub fn load_json(path: []const u8, allocator: std.mem.Allocator) !std.json.Value { 323 | const fd = try std.fs.cwd().openFile(path, .{}); 324 | const buffer = try fd.readToEndAlloc(allocator, 4 * 1024 * 1024); 325 | return std.json.parseFromSliceLeaky(std.json.Value, allocator, buffer, .{}); 326 | } 327 | -------------------------------------------------------------------------------- /src/tests.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const ops = @import("ops.zig"); 3 | 4 | pub fn expectTensorsApproxEqual(expected: []const f32, actual: []const f32) !void { 5 | for (0..expected.len) |i| { 6 | if (@fabs(expected[i]) < 1e-3) { 7 | try std.testing.expectApproxEqAbs( 8 | expected[i], 9 | actual[i], 10 | 5e-7, 11 | ); 12 | } else { 13 | try std.testing.expectApproxEqRel( 14 | expected[i], 15 | actual[i], 16 | 6e-4, 17 | ); 18 | } 19 | } 20 | } 21 | 22 | test "Linear" { 23 | const batch_size = 3; 24 | const in_features = 768; 25 | const out_features = 4 * 768; 26 | 27 | const allocator = std.heap.page_allocator; 28 | const weight = try ops.load_tensor( 29 | "models/test/linear_weight", 30 | &[_]usize{ in_features, out_features }, 31 | f32, 32 | allocator, 33 | ); 34 | defer allocator.free(weight); 35 | const bias = try ops.load_tensor( 36 | "models/test/linear_bias", 37 | &[_]usize{out_features}, 38 | f32, 39 | allocator, 40 | ); 41 | defer allocator.free(bias); 42 | const inputs = try ops.load_tensor( 43 | "models/test/linear_inputs", 44 | &[_]usize{ batch_size, in_features }, 45 | f32, 46 | allocator, 47 | ); 48 | defer allocator.free(inputs); 49 | const expected = try ops.load_tensor( 50 | "models/test/linear_outputs", 51 | &[_]usize{ batch_size, out_features }, 52 | f32, 53 | allocator, 54 | ); 55 | defer allocator.free(expected); 56 | 57 | // Test Linear with bias. 
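// Outputs from the Zig kernel are compared element-wise against the PyTorch reference
// tensors dumped by generate_test_data.py.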
58 | const linear = ops.Linear.init(in_features, out_features, weight, bias); 59 | const actual = try allocator.alloc(f32, batch_size * out_features); 60 | defer allocator.free(actual); 61 | linear.forward(inputs, actual); 62 | try expectTensorsApproxEqual(expected, actual); 63 | 64 | // Test Linear no bias. 65 | const expected_no_bias = try ops.load_tensor( 66 | "models/test/linear_outputs_no_bias", 67 | &[_]usize{ batch_size, out_features }, 68 | f32, 69 | allocator, 70 | ); 71 | defer allocator.free(expected_no_bias); 72 | 73 | const no_bias = ops.Linear.init(in_features, out_features, weight, null); 74 | const actual_no_bias = try allocator.alloc(f32, batch_size * out_features); 75 | defer allocator.free(actual_no_bias); 76 | no_bias.forward(inputs, actual_no_bias); 77 | try expectTensorsApproxEqual(expected_no_bias, actual_no_bias); 78 | } 79 | 80 | test "Embedding" { 81 | const batch_size = 3; 82 | const vocab_size = 10; 83 | const embedding_dim = 768; 84 | 85 | const allocator = std.heap.page_allocator; 86 | const weight = try ops.load_tensor( 87 | "models/test/embedding_weight", 88 | &[_]usize{ vocab_size, embedding_dim }, 89 | f32, 90 | allocator, 91 | ); 92 | defer allocator.free(weight); 93 | const inputs = try ops.load_tensor( 94 | "models/test/embedding_inputs", 95 | &[_]usize{batch_size}, 96 | usize, 97 | allocator, 98 | ); 99 | defer allocator.free(inputs); 100 | const expected = try ops.load_tensor( 101 | "models/test/embedding_outputs", 102 | &[_]usize{ batch_size, embedding_dim }, 103 | f32, 104 | allocator, 105 | ); 106 | defer allocator.free(expected); 107 | 108 | const embedding = ops.Embedding.init(embedding_dim, weight); 109 | const actual = try allocator.alloc(f32, batch_size * embedding_dim); 110 | defer allocator.free(actual); 111 | embedding.forward(inputs, actual); 112 | 113 | try expectTensorsApproxEqual(expected, actual); 114 | } 115 | 116 | test "LayerNorm" { 117 | const batch_size = 3; 118 | const in_features = 768; 119 | 120 | const allocator = std.heap.page_allocator; 121 | const weight = try ops.load_tensor( 122 | "models/test/layer_norm_weight", 123 | &[_]usize{in_features}, 124 | f32, 125 | allocator, 126 | ); 127 | defer allocator.free(weight); 128 | const bias = try ops.load_tensor( 129 | "models/test/layer_norm_bias", 130 | &[_]usize{in_features}, 131 | f32, 132 | allocator, 133 | ); 134 | defer allocator.free(bias); 135 | const inputs = try ops.load_tensor( 136 | "models/test/layer_norm_inputs", 137 | &[_]usize{ batch_size, in_features }, 138 | f32, 139 | allocator, 140 | ); 141 | defer allocator.free(inputs); 142 | const expected = try ops.load_tensor( 143 | "models/test/layer_norm_outputs", 144 | &[_]usize{ batch_size, in_features }, 145 | f32, 146 | allocator, 147 | ); 148 | defer allocator.free(expected); 149 | 150 | const layer_norm = ops.LayerNorm.init(in_features, weight, bias); 151 | layer_norm.forward(inputs); 152 | const actual = inputs; 153 | 154 | try expectTensorsApproxEqual(expected, actual); 155 | } 156 | 157 | test "CausalSelfAttention.split_qkv" { 158 | const batch_size = 3; 159 | const n_heads = 12; 160 | const seq_len = 5; 161 | const n_embed = 768; 162 | 163 | const allocator = std.heap.page_allocator; 164 | var inputs = try ops.load_tensor( 165 | "models/test/split_inputs", 166 | &[_]usize{ batch_size, seq_len, 3 * n_embed }, 167 | f32, 168 | allocator, 169 | ); 170 | defer allocator.free(inputs); 171 | var expected_q = try ops.load_tensor( 172 | "models/test/split_q", 173 | &[_]usize{ batch_size, seq_len, n_embed }, 174 | f32, 
175 | allocator, 176 | ); 177 | defer allocator.free(expected_q); 178 | var expected_k = try ops.load_tensor( 179 | "models/test/split_k", 180 | &[_]usize{ batch_size, seq_len, n_embed }, 181 | f32, 182 | allocator, 183 | ); 184 | defer allocator.free(expected_k); 185 | var expected_v = try ops.load_tensor( 186 | "models/test/split_v", 187 | &[_]usize{ batch_size, seq_len, n_embed }, 188 | f32, 189 | allocator, 190 | ); 191 | defer allocator.free(expected_v); 192 | 193 | const fake_attn = ops.CausalSelfAttention.init(n_heads, n_embed, undefined, undefined); 194 | const actual_q = try allocator.alloc(f32, batch_size * seq_len * n_embed); 195 | defer allocator.free(actual_q); 196 | fake_attn.split_qkv(seq_len, inputs, 0, actual_q); 197 | 198 | const actual_k = try allocator.alloc(f32, batch_size * seq_len * n_embed); 199 | defer allocator.free(actual_k); 200 | fake_attn.split_qkv(seq_len, inputs, 1, actual_k); 201 | 202 | const actual_v = try allocator.alloc(f32, batch_size * seq_len * n_embed); 203 | defer allocator.free(actual_v); 204 | fake_attn.split_qkv(seq_len, inputs, 2, actual_v); 205 | 206 | try expectTensorsApproxEqual(expected_q, actual_q); 207 | try expectTensorsApproxEqual(expected_k, actual_k); 208 | try expectTensorsApproxEqual(expected_v, actual_v); 209 | } 210 | 211 | test "CausalSelfAttention.transpose" { 212 | const batch_size = 3; 213 | const n_heads = 12; 214 | const seq_len = 5; 215 | const n_embed = 768; 216 | const head_dim = n_embed / n_heads; 217 | 218 | const allocator = std.heap.page_allocator; 219 | var inputs = try ops.load_tensor( 220 | "models/test/transpose_inputs", 221 | &[_]usize{ batch_size, seq_len, n_heads, head_dim }, 222 | f32, 223 | allocator, 224 | ); 225 | defer allocator.free(inputs); 226 | var expected = try ops.load_tensor( 227 | "models/test/transpose_outputs", 228 | &[_]usize{ batch_size, n_heads, seq_len, head_dim }, 229 | f32, 230 | allocator, 231 | ); 232 | defer allocator.free(expected); 233 | 234 | const actual = try allocator.alloc(f32, batch_size * seq_len * n_embed); 235 | defer allocator.free(actual); 236 | ops.CausalSelfAttention.transpose( 237 | [3]usize{ seq_len, n_heads, head_dim }, 238 | inputs, 239 | actual, 240 | ); 241 | 242 | try expectTensorsApproxEqual(expected, actual); 243 | } 244 | 245 | test "CausalSelfAttention.forward" { 246 | const batch_size = 1; 247 | const seq_len = 5; 248 | const n_heads = 12; 249 | const head_dim = 64; 250 | const n_embed = n_heads * head_dim; 251 | 252 | const allocator = std.heap.page_allocator; 253 | var inputs = try ops.load_tensor( 254 | "models/test/attn_inputs", 255 | &[_]usize{ batch_size, seq_len, n_embed }, 256 | f32, 257 | allocator, 258 | ); 259 | defer allocator.free(inputs); 260 | var c_attn_weight = try ops.load_tensor( 261 | "models/test/attn_c_attn_weight", 262 | &[_]usize{ n_embed, 3 * n_embed }, 263 | f32, 264 | allocator, 265 | ); 266 | defer allocator.free(c_attn_weight); 267 | var c_attn_bias = try ops.load_tensor( 268 | "models/test/attn_c_attn_bias", 269 | &[_]usize{3 * n_embed}, 270 | f32, 271 | allocator, 272 | ); 273 | var c_proj_weight = try ops.load_tensor( 274 | "models/test/attn_c_proj_weight", 275 | &[_]usize{ n_embed, n_embed }, 276 | f32, 277 | allocator, 278 | ); 279 | defer allocator.free(c_proj_weight); 280 | var c_proj_bias = try ops.load_tensor( 281 | "models/test/attn_c_proj_bias", 282 | &[_]usize{n_embed}, 283 | f32, 284 | allocator, 285 | ); 286 | defer allocator.free(c_proj_bias); 287 | var expected = try ops.load_tensor( 288 | 
"models/test/attn_outputs", 289 | &[_]usize{ batch_size, seq_len, n_embed }, 290 | f32, 291 | allocator, 292 | ); 293 | defer allocator.free(expected); 294 | 295 | const c_attn = ops.Linear.init(n_embed, 3 * n_embed, c_attn_weight, c_attn_bias); 296 | const c_proj = ops.Linear.init(n_embed, n_embed, c_proj_weight, c_proj_bias); 297 | const attn = ops.CausalSelfAttention.init(n_heads, n_embed, c_attn, c_proj); 298 | 299 | const actual = try allocator.alloc(f32, batch_size * seq_len * n_embed); 300 | defer allocator.free(actual); 301 | const k_cache = try allocator.alloc(f32, batch_size * seq_len * n_embed); 302 | defer allocator.free(k_cache); 303 | const v_cache = try allocator.alloc(f32, batch_size * seq_len * n_embed); 304 | defer allocator.free(v_cache); 305 | const _qkv = try allocator.alloc(f32, batch_size * 1 * 3 * n_embed); 306 | defer allocator.free(_qkv); 307 | const _q = try allocator.alloc(f32, batch_size * 1 * n_embed); 308 | defer allocator.free(_q); 309 | const _k = try allocator.alloc(f32, batch_size * seq_len * n_embed); 310 | defer allocator.free(_k); 311 | const _v = try allocator.alloc(f32, batch_size * seq_len * n_embed); 312 | defer allocator.free(_v); 313 | const _attn = try allocator.alloc(f32, 1 * seq_len); 314 | defer allocator.free(_attn); 315 | 316 | for (0..seq_len) |s| { 317 | attn.forward( 318 | s + 1, 319 | inputs[s * n_embed .. (s + 1) * n_embed], 320 | k_cache[0 .. (s + 1) * n_embed], 321 | v_cache[0 .. (s + 1) * n_embed], 322 | actual[s * n_embed .. (s + 1) * n_embed], 323 | _qkv, 324 | _q, 325 | _k[0 .. (s + 1) * n_embed], 326 | _v[0 .. (s + 1) * n_embed], 327 | _attn[0..(s + 1)], 328 | ); 329 | try expectTensorsApproxEqual( 330 | expected[s * n_embed .. (s + 1) * n_embed], 331 | actual[s * n_embed .. (s + 1) * n_embed], 332 | ); 333 | } 334 | } 335 | 336 | test "gelu" { 337 | const batch_size = 3; 338 | const in_features = 768; 339 | 340 | const allocator = std.heap.page_allocator; 341 | var inputs = try ops.load_tensor( 342 | "models/test/gelu_inputs", 343 | &[_]usize{ batch_size, in_features }, 344 | f32, 345 | allocator, 346 | ); 347 | defer allocator.free(inputs); 348 | const expected = try ops.load_tensor( 349 | "models/test/gelu_outputs", 350 | &[_]usize{ batch_size, in_features }, 351 | f32, 352 | allocator, 353 | ); 354 | defer allocator.free(expected); 355 | 356 | ops.gelu(inputs); 357 | const actual = inputs; 358 | 359 | try expectTensorsApproxEqual(expected, actual); 360 | } 361 | 362 | test "softmax" { 363 | const batch_size = 3; 364 | const in_features = 768; 365 | 366 | const allocator = std.heap.page_allocator; 367 | var inputs = try ops.load_tensor( 368 | "models/test/softmax_inputs", 369 | &[_]usize{ batch_size, in_features }, 370 | f32, 371 | allocator, 372 | ); 373 | defer allocator.free(inputs); 374 | const expected = try ops.load_tensor( 375 | "models/test/softmax_outputs", 376 | &[_]usize{ batch_size, in_features }, 377 | f32, 378 | allocator, 379 | ); 380 | defer allocator.free(expected); 381 | 382 | for (0..batch_size) |b| { 383 | ops.softmax(inputs[b * in_features .. (b + 1) * in_features]); 384 | } 385 | const actual = inputs; 386 | 387 | try expectTensorsApproxEqual(expected, actual); 388 | } 389 | --------------------------------------------------------------------------------