├── .gitignore
├── README.md
├── conda-env.yaml
├── micro_llama.ipynb
└── micro_llama.py
/.gitignore:
--------------------------------------------------------------------------------
data/llama3-8b
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Micro LLAMA

This is a tiny implementation of the LLAMA 3 model architecture for didactic purposes. The entire implementation is approximately 180 lines of code, hence the name "micro".

The code uses the smallest LLAMA 3 model, i.e., the one with 8B parameters. This model is still 15GB in size and requires about 30GB of memory to run.
By default, the code runs on the CPU, but beware of the memory impact.

Start exploring the code using the notebook `micro_llama.ipynb`.

The model's code itself is entirely contained in the `micro_llama.py` file.

## Requirements

Use the following commands to create a suitable Conda environment, called `micro_llama`:

```bash
conda env create --file conda-env.yaml --yes
conda activate micro_llama
```

You can get rid of the Conda environment as follows:

```bash
conda remove -n micro_llama --all --yes
```

## References

This implementation is inspired by:

* [building-llama-3-from-scratch](https://lightning.ai/fareedhassankhan12/studios/building-llama-3-from-scratch)
* [Building-llama3-from-scratch](https://github.com/FareedKhan-dev/Building-llama3-from-scratch)
--------------------------------------------------------------------------------
/conda-env.yaml:
--------------------------------------------------------------------------------
name: micro_llama
channels:
  - conda-forge
  - defaults
dependencies:
  - blobfile
  - jupyter
  - matplotlib
  - nbstripout
  - numpy<2.0.0
  - pytorch::pytorch
  - tiktoken
--------------------------------------------------------------------------------
/micro_llama.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the model weights\n",
    "import os\n",
    "import urllib.request\n",
    "\n",
    "downloads = [\n",
    "    {\n",
    "        \"filename\": \"data/llama3-8b/tokenizer.model\",\n",
    "        \"url\": \"https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/1460c22666392e470910ce3d44ffeb2ab7dbd4df/original/tokenizer.model\",\n",
    "    },\n",
    "    {\n",
    "        \"filename\": \"data/llama3-8b/consolidated.00.pth\",\n",
    "        \"url\": \"https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/1460c22666392e470910ce3d44ffeb2ab7dbd4df/original/consolidated.00.pth\",\n",
    "    },\n",
    "]\n",
    "\n",
    "for download in downloads:\n",
    "    if not os.path.isfile(download[\"filename\"]):\n",
    "        os.makedirs(os.path.dirname(download[\"filename\"]), exist_ok=True)\n",
    "        print(f\"Downloading {download['url']} to {download['filename']}\")\n",
    "        urllib.request.urlretrieve(download[\"url\"], download[\"filename\"])\n",
    "    else:\n",
    "        print(f\"File {download['filename']} already found, skipping download\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Tiktoken tokenizer\n",
    "import torch\n",
    "import micro_llama\n",
    "\n",
    "tokenizer = micro_llama.make_tokenizer(\"data/llama3-8b/tokenizer.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate the Tiktoken tokenizer\n",
    "prompt = \"the answer to the ultimate question of life, the universe, and everything is \"\n",
    "tokens = tokenizer.encode(prompt)\n",
    "prompt_ = tokenizer.decode(tokens)\n",
    "\n",
    "print(prompt)\n",
    "print(tokens)\n",
    "print(prompt_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate the RoPE positional embedding\n",
    "N = 64\n",
    "D = 256\n",
    "theta = 500_000  # the value used by LLAMA 3\n",
    "theta = 5  # override with a much smaller value so the rotation is visible at this scale\n",
    "\n",
    "x = torch.randn(1, D)\n",
    "x = x.expand(N, D) + torch.randn(N, D) * 0.01\n",
    "x = x / x.norm(dim=-1, keepdim=True)\n",
    "\n",
    "y = micro_llama.rope(x.reshape(1, N, 1, D), theta=theta)\n",
    "y = y.reshape(N, D)\n",
    "\n",
    "M = x @ x.transpose(-2, -1)\n",
    "M_ = y @ y.transpose(-2, -1)\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "plt.figure()\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.imshow(M.detach().numpy())\n",
    "plt.title(\"Without RoPE\")\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.title(\"With RoPE\")\n",
    "plt.imshow(M_.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the LLAMA3 8B model\n",
    "llama = micro_llama.Llama()\n",
    "params = torch.load(\"data/llama3-8b/consolidated.00.pth\", weights_only=True)\n",
    "llama.load_state_dict(params)\n",
    "llama.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate the LLAMA3 model\n",
    "prompt = \"the answer to the ultimate question of life, the universe, and everything is \"\n",
    "x = torch.tensor([128000] + tokenizer.encode(prompt))\n",
    "print(tokenizer.decode(x.tolist()))\n",
    "\n",
    "y = llama(x.unsqueeze(0))\n",
    "print(tokenizer.decode(y.argmax(dim=-1)[0].tolist()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/micro_llama.py:
--------------------------------------------------------------------------------
import torch
import tiktoken
import tiktoken.load


def rope(x, theta):
    # Rotary positional embedding (RoPE): rotate each consecutive pair of
    # channels by an angle that grows linearly with the token position, at a
    # different frequency for each pair.
    B, N, H, D = x.shape
    freq = theta ** -torch.arange(0, 1, 2 / D, device=x.device)
    time = torch.arange(N, device=x.device)
    phase = freq.reshape(1, 1, 1, D // 2) * time.reshape(1, N, 1, 1)  # (1, N, 1, D//2)
    c = torch.cos(phase)
    s = torch.sin(phase)
    rot = torch.stack([c, s, -s, c], dim=-1).reshape(1, N, 1, D // 2, 2, 2)
    x = x.reshape(B, N, H, D // 2, 1, 2) @ rot
    return x.reshape(B, N, H, D)


def make_tokenizer(path):
    tokenizer_model = tiktoken.load.load_tiktoken_bpe(path)
    special_tokens = [
        "<|begin_of_text|>",  # Marks the beginning of a text sequence.
        "<|end_of_text|>",  # Marks the end of a text sequence.
        "<|reserved_special_token_0|>",  # Reserved for future use.
        "<|reserved_special_token_1|>",  # Reserved for future use.
        "<|reserved_special_token_2|>",  # Reserved for future use.
        "<|reserved_special_token_3|>",  # Reserved for future use.
        "<|start_header_id|>",  # Indicates the start of a header ID.
        "<|end_header_id|>",  # Indicates the end of a header ID.
        "<|reserved_special_token_4|>",  # Reserved for future use.
        "<|eot_id|>",  # Marks the end of a turn (in a conversational context).
    ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
    tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
    return tiktoken.Encoding(
        name=path,
        pat_str=tokenize_breaker,
        mergeable_ranks=tokenizer_model,
        special_tokens={
            token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)
        },
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, dim, epsilon):
        super().__init__()
        self.epsilon = epsilon
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root-mean-square of the activations, then rescale.
        return (x * self.weight) * torch.rsqrt(
            (x * x).mean(-1, keepdim=True) + self.epsilon
        )


class Attention(torch.nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads, rope_theta):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.rope_theta = rope_theta
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        # Grouped-query attention: keys and values use n_kv_heads heads, so
        # their projections are n_heads // n_kv_heads times smaller.
        self.wk = torch.nn.Linear(dim, dim // (n_heads // n_kv_heads), bias=False)
        self.wv = torch.nn.Linear(dim, dim // (n_heads // n_kv_heads), bias=False)
        self.wo = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        B, N, D = x.shape
        H = self.n_heads
        J = self.n_kv_heads
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        q = q.reshape((B, N, H, D // H))
        k = k.reshape((B, N, J, D // H))
        v = v.reshape((B, N, J, D // H))

        q = rope(q, theta=self.rope_theta)
        k = rope(k, theta=self.rope_theta)

        # Repeat each key/value head H // J times so that every query head
        # has a matching key/value head.
        k = (
            k.reshape((B, N, J, 1, D // H))
            .expand((B, N, J, H // J, D // H))
            .reshape((B, N, H, D // H))
        )
        v = (
            v.reshape((B, N, J, 1, D // H))
            .expand((B, N, J, H // J, D // H))
            .reshape((B, N, H, D // H))
        )

        q = q.transpose(1, 2)  # (B, H, N, D//H)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        dot = q @ k.transpose(2, 3) / (D // H) ** 0.5
        # Causal mask: each token attends only to itself and earlier tokens.
        mask = torch.full((N, N), float("-inf"), device=x.device)
        mask = torch.triu(mask, diagonal=1)
        dot = dot + mask  # (B, H, N, N)

        weight = torch.nn.functional.softmax(dot, dim=-1)
        x = weight @ v  # (B, H, N, D//H)
        x = x.transpose(1, 2).reshape((B, N, D))
        x = self.wo(x)
        return x


class FeedForward(torch.nn.Module):
    def __init__(self, dim, latent_dim):
        super().__init__()
        self.w1 = torch.nn.Linear(dim, latent_dim, bias=False)
        self.w2 = torch.nn.Linear(latent_dim, dim, bias=False)
        self.w3 = torch.nn.Linear(dim, latent_dim, bias=False)

    def forward(self, x):
        # SwiGLU: a silu-gated linear unit, as used by the LLAMA models.
        a = torch.nn.functional.silu(self.w1(x))
        b = self.w3(x)
        return self.w2(a * b)


class Layer(torch.nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads, norm_eps, latent_dim, rope_theta):
        super().__init__()
        self.attention_norm = RMSNorm(dim, norm_eps)
        self.ffn_norm = RMSNorm(dim, norm_eps)
        self.attention = Attention(dim, n_heads, n_kv_heads, rope_theta)
        self.feed_forward = FeedForward(dim, latent_dim)

    def forward(self, x):
        # Pre-norm transformer block: normalize, transform, and add the
        # result back to the residual stream.
        y = self.attention_norm(x)
        y = self.attention(y)
        x = x + y

        y = self.ffn_norm(x)
        y = self.feed_forward(y)
        x = x + y
        return x


class Llama(torch.nn.Module):
    def __init__(
        self,
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        vocab_size=128256,
        norm_eps=1e-5,
        latent_dim=14336,
        rope_theta=500000.0,
    ):
        super().__init__()
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.norm_eps = norm_eps
        self.latent_dim = latent_dim
        self.rope_theta = rope_theta

        self.tok_embeddings = torch.nn.Embedding(self.vocab_size, self.dim)
        self.layers = torch.nn.ModuleList(
            [
                Layer(
                    dim=self.dim,
                    n_heads=self.n_heads,
                    n_kv_heads=self.n_kv_heads,
                    norm_eps=self.norm_eps,
                    latent_dim=self.latent_dim,
                    rope_theta=self.rope_theta,
                )
                for _ in range(self.n_layers)
            ]
        )
        self.norm = RMSNorm(self.dim, self.norm_eps)
        self.output = torch.nn.Linear(self.dim, self.vocab_size, bias=False)

    def forward(self, x):
        x = self.tok_embeddings(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.output(x)
        return x
--------------------------------------------------------------------------------
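
As a closing example, here is a minimal greedy-decoding loop that ties the pieces above together. This is an illustrative sketch, not part of the repo: it assumes the weights have been downloaded as in the notebook, and since `micro_llama.py` keeps no KV cache, every step recomputes the forward pass over the whole sequence, which is slow (and memory-hungry) on CPU. Token ids 128000 and 128001 are `<|begin_of_text|>` and `<|end_of_text|>`, following the ordering in `make_tokenizer`.

```python
# Greedy decoding with micro_llama (an illustrative sketch, not repo code).
import torch
import micro_llama

tokenizer = micro_llama.make_tokenizer("data/llama3-8b/tokenizer.model")
llama = micro_llama.Llama()
llama.load_state_dict(torch.load("data/llama3-8b/consolidated.00.pth", weights_only=True))
llama.eval()

prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)  # 128000 = <|begin_of_text|>

with torch.no_grad():
    for _ in range(16):  # generate up to 16 new tokens
        logits = llama(torch.tensor(tokens).unsqueeze(0))  # (1, N, vocab_size)
        next_token = int(logits[0, -1].argmax())  # most likely continuation
        if next_token == 128001:  # stop at <|end_of_text|>
            break
        tokens.append(next_token)

print(tokenizer.decode(tokens))
```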