├── .gitignore
├── LICENSE
├── README.md
├── README_en.md
├── images
│   ├── 42.png
│   ├── a10.png
│   ├── afterattention.png
│   ├── archi.png
│   ├── attention.png
│   ├── embeddings.png
│   ├── finallayer.png
│   ├── freq_cis.png
│   ├── god.png
│   ├── heads.png
│   ├── implllama3_30_0.png
│   ├── implllama3_39_0.png
│   ├── implllama3_41_0.png
│   ├── implllama3_42_0.png
│   ├── implllama3_50_0.png
│   ├── implllama3_52_0.png
│   ├── implllama3_54_0.png
│   ├── karpathyminbpe.png
│   ├── keys.png
│   ├── keys0.png
│   ├── last_norm.png
│   ├── mask.png
│   ├── model.png
│   ├── norm.png
│   ├── norm_after.png
│   ├── q_per_token.png
│   ├── qkmatmul.png
│   ├── qkv.png
│   ├── qsplit.png
│   ├── rms.png
│   ├── rope.png
│   ├── ropesplit.png
│   ├── softmax.png
│   ├── stacked.png
│   ├── swiglu.png
│   ├── tokens.png
│   ├── v0.png
│   ├── value.png
│   └── weightmatrix.png
├── llama3-from-scratch_en.ipynb
├── llama3-from-scratch_zh.ipynb
├── llama3
│   ├── README.md
│   ├── model.py
│   └── tokenizer.py
├── pdf
│   └── 从零实现 Llama3 模型.pdf
└── requirements.txt


/.gitignore:
--------------------------------------------------------------------------------
Meta-Llama-3-8B-Instruct-2layers/*


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Nishant Aklecha

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementing the Llama3 Model from Scratch

## Notes

1. This document is a translation of naklecha's [llama3-from-scratch](https://github.com/naklecha/llama3-from-scratch) repository. I only translated the English into Chinese and changed nothing else, apart from slightly modifying the model weight files so they are easier to load. Original English version: [README_en.md](README_en.md).
2. The original model has been uploaded to ModelScope (about 15 GB): [Meta-Llama-3-8B-Instruct](https://www.modelscope.cn/models/wdndev/Meta-Llama-3-8B-Instruct-torch/summary).
3. The original Llama3 8B model has 32 Transformer layers, and the original repository loads it on the CPU, so loading all of the parameters fails on a machine with 16 GB of RAM. I therefore took the first 2 layers of the original Llama3 8B weights and saved them again (about 2.7 GB). This document can load them directly, and **in my tests memory usage was about 4 to 5 GB**. The only drawback is that the final prediction will be wrong, but that does not get in the way of learning the matrix transformations and everything else. Link: [Meta-Llama-3-8B-Instruct-2layers](https://www.modelscope.cn/models/wdndev/Meta-Llama-3-8B-Instruct-2layers/summary).
4. If you find this useful, please give it a star. Thanks!

## Model and Colab

Model links

- Hugging Face link: https://huggingface.co/wdndev/Meta-Llama-3-8B-Instruct-2layers
- ModelScope link: https://www.modelscope.cn/models/wdndev/Meta-Llama-3-8B-Instruct-2layers

Colab links

- llama3-from-scratch-en: https://colab.research.google.com/drive/1X9yEa4hAZzgrwTuxHValBoVt1qfx6AXv?usp=sharing
- llama3-from-scratch-zh: https://colab.research.google.com/drive/11MQb8Bn4Ck707VEcqqGVdytqOk3OrQQK?usp=sharing

## Implementing Llama3 from scratch

In this file, Llama3 is implemented from scratch, one tensor and matrix multiplication at a time.

Tensors are also loaded directly from the model files Meta provides for Llama3, so you need to download the weights before running this file.
Here is the official link: https://llama.meta.com/llama-downloads/

> The original model has been uploaded to ModelScope (about 15 GB): [Meta-Llama-3-8B-Instruct](https://www.modelscope.cn/models/wdndev/Meta-Llama-3-8B-Instruct-torch/summary)


## tokenizer

I'm not going to implement a BPE tokenizer (but Andrej Karpathy has a really clean implementation).

Here is his repository: https://github.com/karpathy/minbpe


```python
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import matplotlib.pyplot as plt

# Path to the tokenizer model
tokenizer_path = "Meta-Llama-3-8B-Instruct/tokenizer.model"
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    # Special tokens get ids right after the regular BPE tokens
    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)

# Check that encoding followed by decoding round-trips
tokenizer.decode(tokenizer.encode("hello world!"))
```

```bash
hello world!
```


## Reading the model file

Normally, how you read a model file depends on how the model classes are written and on the variable names inside them.

But since we are implementing Llama3 from scratch, we will read the file one tensor at a time.


```python
# Load the model weights (a dict mapping tensor names to tensors)
model = torch.load("Meta-Llama-3-8B-Instruct/consolidated.00.pth")
print(json.dumps(list(model.keys())[:20], indent=4))
```

```bash
[
    "tok_embeddings.weight",
    "layers.0.attention.wq.weight",
    "layers.0.attention.wk.weight",
    "layers.0.attention.wv.weight",
    "layers.0.attention.wo.weight",
    "layers.0.feed_forward.w1.weight",
    "layers.0.feed_forward.w3.weight",
    "layers.0.feed_forward.w2.weight",
    "layers.0.attention_norm.weight",
    "layers.0.ffn_norm.weight",
    "layers.1.attention.wq.weight",
    "layers.1.attention.wk.weight",
    "layers.1.attention.wv.weight",
    "layers.1.attention.wo.weight",
    "layers.1.feed_forward.w1.weight",
    "layers.1.feed_forward.w3.weight",
    "layers.1.feed_forward.w2.weight",
    "layers.1.attention_norm.weight",
    "layers.1.ffn_norm.weight",
    "layers.2.attention.wq.weight"
]
```

```python
# Load the model configuration
with open("Meta-Llama-3-8B-Instruct/params.json", "r") as f:
    config = json.load(f)
config
```

```json
{
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "multiple_of": 1024,
    "ffn_dim_multiplier": 1.3,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0
}
```


## Using this config to infer details about the model

1. The model has 32 Transformer layers
2. Each multi-head attention block has 32 heads
3. The vocabulary size, and so on


```python
# Extract the model hyperparameters from the config
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
```

## Converting text to tokens

Here we use tiktoken (OpenAI's library) as the tokenizer.

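Where does the `128000` that gets prepended to the prompt tokens come from? In the tokenizer setup, special tokens are numbered starting at `len(mergeable_ranks)`. The plain-Python sketch below rebuilds only that numbering. It does not load `tokenizer.model`. It assumes `len(mergeable_ranks) == 128000`, which is consistent with `<|begin_of_text|>` being token 128000.

```python
# Rebuild only the special-token numbering from the tokenizer setup.
# Assumption: len(mergeable_ranks) == 128000 for Llama3's tokenizer.model.
num_regular_tokens = 128000

llama3_special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]

# Special tokens are numbered right after the regular BPE tokens.
special_token_ids = {token: num_regular_tokens + i for i, token in enumerate(llama3_special_tokens)}

print(len(llama3_special_tokens))                       # 256
print(special_token_ids["<|begin_of_text|>"])           # 128000
print(special_token_ids["<|eot_id|>"])                  # 128009
print(num_regular_tokens + len(llama3_special_tokens))  # 128256, equal to vocab_size in params.json
```

Under that assumption, the 256 special tokens account exactly for the gap between the regular BPE vocabulary and `vocab_size`.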

```python
prompt = "the answer to the ultimate question of life, the universe, and everything is "

# Encode into tokens, prepending <|begin_of_text|> (id 128000)
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)

# Decode each token back into its piece of text
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
```

```bash
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
```

## Converting tokens to embeddings

Here we use the built-in neural network module (`torch.nn.Embedding`).

Either way, our `[17x1]` tokens are now `[17x4096]`: 17 embeddings (one per token), each of length 4096.

Note: keep track of the shapes; it makes everything much easier to understand.

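`torch.nn.Embedding` is essentially a lookup table. Calling it with token ids returns the matching rows of its weight matrix. Here is a minimal sketch with toy sizes; the sizes and ids are made up for illustration and use `toy_` names so they don't overwrite the notebook's variables:

```python
import torch

# Toy sizes for illustration; the real layer is torch.nn.Embedding(128256, 4096).
toy_vocab_size, toy_dim = 10, 4
toy_embedding = torch.nn.Embedding(toy_vocab_size, toy_dim)

toy_tokens = torch.tensor([3, 1, 3, 7])           # made-up token ids
via_layer = toy_embedding(toy_tokens)             # shape [4, toy_dim]
via_indexing = toy_embedding.weight[toy_tokens]   # plain row lookup

print(via_layer.shape)                            # torch.Size([4, 4])
print(torch.equal(via_layer, via_indexing))       # True
```

The same token id (3 here) always maps to the same row, which is why the three `"the"` tokens in the prompt start out with identical embeddings.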

```python
# Create an embedding layer and copy in the pretrained weights
embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])

# Look up the (not yet normalized) token embeddings
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
```

```bash
torch.Size([17, 4096])
```


## Next, normalize the embeddings with RMS normalization

Note that the shape does not change after this step; only the values are normalized.

One thing to keep in mind: we need `norm_eps` (from the config) so that we never accidentally divide by zero when the RMS is 0.

Here is the formula: each vector is divided by the root mean square of its elements (with `norm_eps` added inside the square root) and then multiplied elementwise by the learned norm weights.

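Written out as an equation that matches the `rms_norm` code: $d$ is `dim` = 4096, $\epsilon$ is `norm_eps`, and $g$ is the learned norm weight vector.

```math
\mathrm{RMSNorm}(x)_j = \frac{x_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \epsilon}} \, g_j
```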

```python
# RMS normalization

# A more explicit, equivalent version:
# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
#     return tensor * (norm_weights / rms)

def rms_norm(tensor, norm_weights):
    # Divide by sqrt(mean(x^2) + eps) along the last dimension, then scale by the learned weights
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
```

# Building the first Transformer layer


### Normalization

Access `layers.0` in the model dict (this is the first layer).


After normalization the shape is still `[17x4096]`: the same as the embeddings, but normalized.

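Here is a quick sanity check of `rms_norm` on random data. With all norm weights set to 1, every row of the output should have an RMS of almost exactly 1. The shape should not change, and rescaling the input should leave the output essentially unchanged. The function is copied from above so the snippet runs on its own. The input tensor is a made-up stand-in for the token embeddings.

```python
import torch

norm_eps = 1e-05  # same value as config["norm_eps"]

def rms_norm(tensor, norm_weights):
    # Copied from above: divide by sqrt(mean(x^2) + eps), then scale by the weights.
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

torch.manual_seed(0)
x = torch.randn(17, 4096) * 3.0 + 1.0   # made-up stand-in for token embeddings
ones = torch.ones(4096)                 # weights of 1 isolate the normalization itself

y = rms_norm(x, ones)
rms_per_row = y.pow(2).mean(-1).sqrt()

print(y.shape)                                                 # torch.Size([17, 4096])
print(torch.allclose(rms_per_row, torch.ones(17), atol=1e-3))  # True
print(torch.allclose(rms_norm(2 * x, ones), y, atol=1e-4))     # True: (almost) scale invariant
```

Scale invariance is only approximate because of `norm_eps`, but at these magnitudes the difference is far below the tolerance.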

```python
# Normalize the token embeddings with the first layer's attention_norm weights
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
token_embeddings.shape
```

```bash
torch.Size([17, 4096])
```


### Implementing attention from scratch

Load the attention heads of the first Transformer layer


When we load the `query`, `key`, `value` and `output` weight matrices from the model, their shapes are `[4096x4096]`, `[1024x4096]`, `[1024x4096]` and `[4096x4096]`.

At first glance this looks odd, because ideally we would want a separate q, k, v and o for each head.

The authors bundled them together. Why? Because it helps parallelize the attention-head computations.

Let's unwrap everything...


```python
# Print the shapes of the first layer's attention weights
print(
    model["layers.0.attention.wq.weight"].shape,
    model["layers.0.attention.wk.weight"].shape,
    model["layers.0.attention.wv.weight"].shape,
    model["layers.0.attention.wo.weight"].shape
)
```

```bash
torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])
```

### Unwrapping the query

In the next section we unwrap the queries of the individual attention heads; the resulting shape is `[32x128x4096]`.

Here 32 is the number of attention heads in Llama3, 128 is the size of a query vector, and 4096 is the size of a token embedding.

```python
# Reshape the query weights to [n_heads, head_dim, dim]
q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
q_layer0.shape
```

```bash
torch.Size([32, 128, 4096])
```


### Implementing the first head of the first layer

Here we take the `query` weight matrix of the first head of the first layer; its size is `[128x4096]`.


```python
q_layer0_head0 = q_layer0[0]
q_layer0_head0.shape
```

```bash
torch.Size([128, 4096])
```


### Now multiply the query weights with the token embeddings to get a query for each token

The resulting shape is `[17x128]`, because there are 17 tokens and each token has a query of length 128.

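Why is bundling the heads into one `[4096x4096]` matrix harmless, and why is it parallel-friendly? Here is a toy check. One big matmul followed by splitting the last axis into heads gives the same result as projecting head by head with the reshaped weights. The sizes are hypothetical and kept small. The `toy_` names avoid overwriting the notebook's `dim`, `n_heads` and `head_dim`.

```python
import torch

# Toy sizes (hypothetical); Llama3 8B uses 32 heads, head_dim 128, dim 4096, and 17 tokens here.
toy_heads, toy_head_dim, toy_dim, toy_tokens = 4, 8, 32, 5
torch.manual_seed(0)
wq = torch.randn(toy_heads * toy_head_dim, toy_dim)  # bundled weight, like layers.0.attention.wq.weight
x = torch.randn(toy_tokens, toy_dim)                 # stand-in for the normalized token embeddings

# Route 1: one matmul for all heads, then split the last axis into heads.
q_all = torch.matmul(x, wq.T).view(toy_tokens, toy_heads, toy_head_dim)

# Route 2: reshape the weight per head (as in this notebook) and project head by head.
wq_per_head = wq.view(toy_heads, toy_head_dim, toy_dim)
q_each = torch.stack([torch.matmul(x, wq_per_head[h].T) for h in range(toy_heads)], dim=1)

print(q_all.shape)                               # torch.Size([5, 4, 8])
print(torch.allclose(q_all, q_each, atol=1e-4))  # True
```

Row `h * head_dim + j` of the bundled matrix is simply row `j` of head `h`, so the single large matmul computes every head's projection at once.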

```python
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
q_per_token.shape
```

```bash
torch.Size([17, 128])
```


## Positional encoding

At this point each token has a query vector, but if you think about it, the individual query vectors have no idea where they sit in the prompt.

```text
query: "the answer to the ultimate question of life, the universe, and everything is "
```

My example prompt uses `"the"` three times. The query vectors of these three `"the"` tokens (each of length 128) need to differ according to their positions in the prompt. We can do this with RoPE (rotary positional embedding).

### RoPE
Watch this video (it's the one I watched) to understand the math behind it.
https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s

> Bilibili link for readers in China: [Rotary Positional Embeddings Combining Absolute and Relative](https://www.bilibili.com/video/BV1nt421N7U5/?vd_source=6bc8f793c75740c7bcfb8e281f986a8e&t=530s)

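One property that makes RoPE attractive: rotate a query at position `m` and a key at position `n`, and their dot product depends only on the offset `m - n`, not on the absolute positions. Below is a minimal sketch in float64. The helper `rope_rotate` is our own, not part of the original code. It uses the same frequency formula as this notebook, with `rope_theta = 500000`. The query and key are random stand-ins.

```python
import torch

def rope_rotate(x, positions, rope_theta=500000.0):
    # x: [n_tokens, head_dim]; positions: [n_tokens] integer positions.
    half = x.shape[-1] // 2
    freqs = 1.0 / (rope_theta ** (torch.arange(half, dtype=torch.float64) / half))  # same formula as this notebook
    angles = torch.outer(positions.to(torch.float64), freqs)        # [n_tokens, half]
    rotations = torch.polar(torch.ones_like(angles), angles)        # unit complex numbers e^{i * angle}
    x_complex = torch.view_as_complex(x.to(torch.float64).reshape(x.shape[0], half, 2))
    return torch.view_as_real(x_complex * rotations).reshape(x.shape)

torch.manual_seed(0)
q = torch.randn(1, 128, dtype=torch.float64)  # a made-up query vector
k = torch.randn(1, 128, dtype=torch.float64)  # a made-up key vector

def score(m, n):
    # Dot product of q rotated to position m and k rotated to position n.
    q_rot = rope_rotate(q, torch.tensor([m]))
    k_rot = rope_rotate(k, torch.tensor([n]))
    return (q_rot * k_rot).sum().item()

print(score(5, 2), score(105, 102))  # equal up to float error: both have offset 3
print(score(5, 2), score(5, 3))      # generally different: offsets 3 vs 2
```

For each pair this is $\mathrm{Re}\big(q\,\bar{k}\,e^{i(m-n)\theta}\big)$, so the absolute positions cancel and only the relative offset survives.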
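
```python
# Split each 128-dim query into 64 pairs (cast to float32 so it can be viewed as complex later)
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
q_per_token_split_into_pairs.shape
```

```bash
torch.Size([17, 64, 2])
```

RoPE gives every position in the prompt its own rotation, built from sine and cosine functions.

In the step above, we split the `query` vectors into pairs; next, we apply a rotation-angle shift to each pair!


We now have a tensor of size `[17x64x2]`: for each token in the prompt, its length-128 query has been split into 64 pairs! Each of these 64 pairs will be rotated by `m*(theta)`, where `m` is the position of the token whose query is being rotated and `theta` depends on which pair it is!
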
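In equation form, pair $i$ consists of elements $2i$ and $2i+1$, which is how `.view(..., -1, 2)` groups them. Writing $\Theta$ for `rope_theta` = 500000, the frequencies computed in the next cells are $\theta_i$ below, and the pair at position $m$ is rotated by the angle $m\theta_i$:

```math
\begin{aligned}
\theta_i &= \Theta^{-i/64} = 500000^{-2i/128}, \qquad i = 0, \dots, 63 \\
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
&= \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}
\end{aligned}
```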


## Rotating the vectors with complex-number multiplication

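Why complex numbers? Treat a pair $(x, y)$ as $x + iy$. Multiplying it by $e^{i\phi}$ is exactly the 2x2 rotation by $\phi$. The toy check below compares the two routes on random pairs and random angles. The data is made up, and only the shapes mirror the notebook.

```python
import math
import torch

torch.manual_seed(0)
pairs = torch.randn(17, 64, 2)             # made-up stand-in for q_per_token_split_into_pairs
angles = torch.rand(17, 64) * 2 * math.pi  # made-up angles (in the notebook: position * frequency)

# Route 1: complex multiplication, as in this notebook.
rotations = torch.polar(torch.ones_like(angles), angles)
via_complex = torch.view_as_real(torch.view_as_complex(pairs) * rotations)

# Route 2: an explicit 2x2 rotation matrix applied to each (x, y) pair.
cos, sin = torch.cos(angles), torch.sin(angles)
x, y = pairs[..., 0], pairs[..., 1]
via_matrix = torch.stack([x * cos - y * sin, x * sin + y * cos], dim=-1)

print(torch.allclose(via_complex, via_matrix, atol=1e-5))  # True
```

The complex route is just a compact way to apply 64 independent 2D rotations per token in a single elementwise multiply.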

```python
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
zero_to_one_split_into_64_parts
```

```bash
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])
```


```python
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
freqs
```

```bash
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
```

```python
# Angle for every (position, pair): position m times frequency theta_i
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
# Turn each angle into a unit complex number e^{i * angle}
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape

# Plot the row of freqs_cis for position 3 (the fourth row), first 17 of its 64 entries
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
```


![png](images/implllama3_30_0.png)


### Now every element pair of each token's query has a complex number (the angle-change vector)

We can convert the query (split into pairs) into complex numbers and then multiply them elementwise by `freqs_cis` to rotate each query according to its position.

```python
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
q_per_token_as_complex_numbers.shape
```

```bash
torch.Size([17, 64])
```

```python
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
q_per_token_as_complex_numbers_rotated.shape
```

```bash
torch.Size([17, 64])
```



### Once we have the rotated vectors

We can get the query pairs back by viewing the complex numbers as real numbers again.


```python
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
q_per_token_split_into_pairs_rotated.shape
```

```bash
torch.Size([17, 64, 2])
```


The rotated pairs are now merged back, giving a new query vector (the rotated query vector) of shape `[17x128]`, where 17 is the number of tokens and 128 is the dimension of the query vector.


```python
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
q_per_token_rotated.shape
```

```bash
torch.Size([17, 128])
```

# keys (almost identical to queries)

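Keys go through exactly the same rotation as queries, so the split, complex, rotate and merge steps can be folded into one helper. This is a sketch. `apply_rope` is a name we made up and is not part of the original code. A `toy_freqs_cis` is rebuilt the same way as `freqs_cis` so the snippet runs on its own.

```python
import torch

def apply_rope(x, freqs_cis):
    """Rotate each (even, odd) pair of x by the matching unit complex number.

    x:         [n_tokens, head_dim] real tensor (the queries or keys of one head)
    freqs_cis: [n_positions, head_dim // 2] complex tensor with n_positions >= n_tokens
    """
    x_pairs = x.float().reshape(x.shape[0], -1, 2)
    x_complex = torch.view_as_complex(x_pairs)
    x_rotated = torch.view_as_real(x_complex * freqs_cis[: x.shape[0]])
    return x_rotated.reshape(x.shape)

# Rebuild freqs_cis the same way as above (rope_theta = 500000, 17 positions, 64 pairs).
rope_theta_value = 500000.0
toy_freqs = 1.0 / (rope_theta_value ** (torch.arange(64, dtype=torch.float32) / 64))
toy_freqs_cis = torch.polar(torch.ones(17, 64), torch.outer(torch.arange(17, dtype=torch.float32), toy_freqs))

torch.manual_seed(0)
toy_keys = torch.randn(17, 128)  # made-up keys for one head
toy_keys_rotated = apply_rope(toy_keys, toy_freqs_cis)

print(toy_keys_rotated.shape)                                                          # torch.Size([17, 128])
print(torch.allclose(toy_keys_rotated[0], toy_keys[0]))                                # True: position 0 is not rotated
print(torch.allclose(toy_keys_rotated.norm(dim=-1), toy_keys.norm(dim=-1), atol=1e-4)) # True: lengths are kept
```

In the notebook this would read `apply_rope(q_per_token, freqs_cis)` and `apply_rope(k_per_token, freqs_cis)`.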

I'm lazy, so I won't go through the math for keys in detail. The only things to keep in mind are:

- keys also produce key vectors of dimension 128
- **keys have only 1/4 as many weights as queries, because each key weight matrix is shared across 4 heads to reduce computation**
- keys are also rotated to add positional information, for the same reason as queries


```python
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
k_layer0.shape
```

```bash
torch.Size([8, 128, 4096])
```

```python
k_layer0_head0 = k_layer0[0]
k_layer0_head0.shape
```

```bash
torch.Size([128, 4096])
```

```python
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)
k_per_token.shape
```

```bash
torch.Size([17, 128])
```


```python
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
k_per_token_split_into_pairs.shape
```

```bash
torch.Size([17, 64, 2])
```


```python
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
k_per_token_as_complex_numbers.shape
```

```bash
torch.Size([17, 64])
```


```python
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
k_per_token_split_into_pairs_rotated.shape
```

```bash
torch.Size([17, 64, 2])
```


```python
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
k_per_token_rotated.shape
```

```bash
torch.Size([17, 128])
```



## Now we have the rotated queries and keys for every token


Each query and key now has shape `[17x128]`.

## Next, multiply the query and key matrices

This gives a score that maps every token to every other token.

The score describes how well each token's query relates to each token's key. That's self-attention :)

The shape of the attention score matrix (`qk_per_token`) is `[17x17]`, where 17 is the number of tokens in the prompt.

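The next few steps together compute the usual scaled dot-product attention for each head. In the notation below (ours, not the original's), $Q$ and $K$ are the rotated `[17x128]` queries and keys and $V$ is the values. $d_k$ is `head_dim` = 128. $M$ is the causal mask: 0 on and below the diagonal, $-\infty$ above. The softmax is taken over each row.

```math
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V
```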

```python
# Scaled dot-product scores: divide by sqrt(head_dim)
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
qk_per_token.shape
```

```bash
torch.Size([17, 17])
```

# Now we have to mask the QK scores

During Llama3's training, the qk scores of future tokens are masked.

Why? Because during training we only learn to predict tokens using past tokens.

So at inference time we also set the scores of future tokens to negative infinity, which makes their attention weights exactly zero after the softmax.

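Here is a tiny example of why $-\infty$ works as a mask. `exp(-inf)` is 0, so after the softmax the masked (future) positions get exactly zero weight, while each row still sums to 1. The 4x4 scores are random toy values, and the `toy_` names avoid overwriting the notebook's `mask`.

```python
import torch

torch.manual_seed(0)
toy_scores = torch.randn(4, 4)  # made-up attention scores for 4 tokens

# -inf strictly above the diagonal: token i may only attend to tokens 0..i.
toy_mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
toy_weights = torch.softmax(toy_scores + toy_mask, dim=1)

print(toy_weights)             # zeros above the diagonal
print(toy_weights.sum(dim=1))  # each row sums to 1 (up to float rounding)
```

The first token can only attend to itself, so its row is exactly `[1, 0, 0, 0]`.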

```python
def display_qk_heatmap(qk_per_token):
    # Heatmap of a [n_tokens x n_tokens] score matrix, labelled with the prompt tokens
    _, ax = plt.subplots()
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)

display_qk_heatmap(qk_per_token)
```

![png](images/implllama3_50_0.png)

```python
# -inf strictly above the diagonal (future tokens), 0 elsewhere
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
mask
```

```bash
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
```



```python
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)
```

![png](images/implllama3_52_0.png)


```python
# Row-wise softmax turns the masked scores into attention weights
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)
```

![png](images/implllama3_54_0.png)

## values (the last part of attention)


These scores (between 0 and 1) determine how much of the value matrix is used for each token.

Just like keys, value weights are shared across every 4 attention heads (to save computation).

So the shape of the value weight matrix below is `[8x128x4096]`.


```python
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
v_layer0.shape
```

```bash
torch.Size([8, 128, 4096])
```

Here is the value weight matrix of the first head in the first layer of Llama3:

```python
v_layer0_head0 = v_layer0[0]
v_layer0_head0.shape
```

```bash
torch.Size([128, 4096])
```


## value vectors


Now we use the value weights to get the value vectors for each token. The result has size `[17x128]`, where 17 is the number of tokens in the prompt and 128 is the dimension of each token's value vector.


```python
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
v_per_token.shape
```

```bash
torch.Size([17, 128])
```

## attention


Multiplying the attention weights by each token's values gives an attention vector of shape `[17x128]`.

```python
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
qkv_attention.shape
```

```bash
torch.Size([17, 128])
```


# multi-head attention


We now have the attention values for the first head of the first layer.

Next we run a loop that performs exactly the same math as the cells above, but for every head in the first layer.


```python
qkv_attention_store = []

for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4]  # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4]  # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    # Rotate the queries (RoPE)
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    # Rotate the keys (RoPE)
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    # Scaled scores, causal mask, softmax, then the weighted sum of values
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)

len(qkv_attention_store)
```

```bash
32
```

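The loop handles one head at a time and hardcodes `head//4`. The snippet below sketches how the same computation can run for all heads at once, comparing the loop with a batched version. The batched version uses `repeat_interleave` to give each query head its shared key/value head (grouped-query attention). To keep it short it uses small random weights and leaves out RoPE. The sizes are hypothetical and use `toy_` names so the notebook's variables are untouched.

```python
import torch

# Toy sizes (hypothetical); Llama3 8B uses 32 query heads, 8 key/value heads, head_dim 128, dim 4096.
toy_tokens, toy_dim, toy_heads, toy_kv_heads, toy_head_dim = 6, 32, 8, 2, 4
group = toy_heads // toy_kv_heads  # query heads per shared key/value head (4 in Llama3 8B)

torch.manual_seed(0)
x = torch.randn(toy_tokens, toy_dim)
wq = torch.randn(toy_heads * toy_head_dim, toy_dim) * 0.1  # small random weights
wk = torch.randn(toy_kv_heads * toy_head_dim, toy_dim) * 0.1
wv = torch.randn(toy_kv_heads * toy_head_dim, toy_dim) * 0.1
causal_mask = torch.triu(torch.full((toy_tokens, toy_tokens), float("-inf")), diagonal=1)

def attention_loop():
    # One head at a time, like the loop above (RoPE left out for brevity).
    outs = []
    for head in range(toy_heads):
        q = x @ wq.view(toy_heads, toy_head_dim, toy_dim)[head].T
        k = x @ wk.view(toy_kv_heads, toy_head_dim, toy_dim)[head // group].T
        v = x @ wv.view(toy_kv_heads, toy_head_dim, toy_dim)[head // group].T
        weights = torch.softmax(q @ k.T / toy_head_dim ** 0.5 + causal_mask, dim=1)
        outs.append(weights @ v)
    return torch.cat(outs, dim=-1)

def attention_batched():
    # Project every head at once, then move the head axis to the front.
    q = (x @ wq.T).view(toy_tokens, toy_heads, toy_head_dim).transpose(0, 1)     # [heads, tokens, head_dim]
    k = (x @ wk.T).view(toy_tokens, toy_kv_heads, toy_head_dim).transpose(0, 1)  # [kv_heads, tokens, head_dim]
    v = (x @ wv.T).view(toy_tokens, toy_kv_heads, toy_head_dim).transpose(0, 1)
    # Repeat each key/value head `group` times so query head h lines up with kv head h // group.
    k = k.repeat_interleave(group, dim=0)
    v = v.repeat_interleave(group, dim=0)
    weights = torch.softmax(q @ k.transpose(1, 2) / toy_head_dim ** 0.5 + causal_mask, dim=-1)
    return (weights @ v).transpose(0, 1).reshape(toy_tokens, toy_heads * toy_head_dim)

print(attention_loop().shape)                                            # torch.Size([6, 32])
print(torch.allclose(attention_loop(), attention_batched(), atol=1e-5))  # True
```

The batched version also produces the heads already concatenated along the last axis, which is the layout the next step expects.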

We now have the qkv_attention matrices for all 32 heads of the first layer. Next we concatenate all the attention outputs into one large matrix of size `[17x4096]`.

```python
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
stacked_qkv_attention.shape
```

```bash
torch.Size([17, 4096])
```



# weight matrix, one of the final steps


For layer 0, the last thing to do is multiply by the output weight matrix.

```python
w_layer0 = model["layers.0.attention.wo.weight"]
w_layer0.shape
```

```bash
torch.Size([4096, 4096])
```


### This is a simple linear layer, so we just do a matrix multiplication


```python
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
embedding_delta.shape
```

```bash
torch.Size([17, 4096])
```


After attention, we now have the change in the embedding values, which should be added to the original token embeddings (a residual connection).

```python
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
embedding_after_edit.shape
```

```bash
torch.Size([17, 4096])
```



## Normalize it, then run a feed-forward neural network


```python
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
embedding_after_edit_normalized.shape
```

```bash
torch.Size([17, 4096])
```

## Load the FFN weights and implement the feed-forward network

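The `multiple_of` and `ffn_dim_multiplier` entries in `params.json` have not been used so far. They size the FFN's hidden layer. The sketch below reproduces that sizing rule as we understand it from Meta's reference `model.py`; treat it as an assumption, not a quote. For this config it gives 14336, which should match `model["layers.0.feed_forward.w1.weight"].shape[0]`. The sketch also wraps the SwiGLU computation in a small function and checks it on toy shapes. `ffn_hidden_dim`, `swiglu_ffn` and the `toy_` tensors are our own names.

```python
import torch
import torch.nn.functional as F

def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    # Sizing rule as we understand it from Meta's reference model.py (an assumption):
    # start from 4 * dim, keep 2/3 of it, apply the multiplier, round up to a multiple of multiple_of.
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(4096, 1024, 1.3))  # 14336 with the values from params.json

def swiglu_ffn(x, w1, w2, w3):
    # w1, w3: [hidden, dim]; w2: [dim, hidden].
    # SiLU-gated branch (w1) times linear branch (w3), projected back to dim with w2.
    return torch.matmul(F.silu(torch.matmul(x, w1.T)) * torch.matmul(x, w3.T), w2.T)

# Toy shapes (hypothetical) just to check the plumbing.
torch.manual_seed(0)
toy_dim, toy_hidden = 8, 24
toy_x = torch.randn(17, toy_dim)
toy_w1, toy_w3 = torch.randn(toy_hidden, toy_dim), torch.randn(toy_hidden, toy_dim)
toy_w2 = torch.randn(toy_dim, toy_hidden)
print(swiglu_ffn(toy_x, toy_w1, toy_w2, toy_w3).shape)  # torch.Size([17, 8])
```

The arithmetic: 2/3 of 16384 is 10922, times 1.3 is 14198, and rounding up to a multiple of 1024 gives 14336.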

Llama3 uses a `SwiGLU` feed-forward network. This architecture is very good at adding non-linearity when the model needs it.

Using this kind of feed-forward architecture in LLMs is fairly standard these days.


```python
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]
# SwiGLU: silu(x W1^T) * (x W3^T), projected back with W2^T
output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
output_after_feedforward.shape
```

```bash
torch.Size([17, 4096])
```


# After the first layer, we finally have new, edited embeddings for every token

Only 31 more layers to go (one for loop away).

You can think of this edited embedding as containing information about all the queries asked in the first layer.

Each subsequent layer encodes increasingly complex queries about the questions asked, until we end up with an embedding that knows everything about the next token we need.

```python
layer_0_embedding = embedding_after_edit+output_after_feedforward
layer_0_embedding.shape
```

```bash
torch.Size([17, 4096])
```



# Putting it all together


That's it. Everything we did above, now done all at once, for every layer.

```python
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    # Attention block: RMS norm, then per-head q/k/v projections
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # every 4 query heads share one key head
        v_layer_head = v_layer[head//4]  # and one value head
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    # Concatenate all heads, project with wo, and add the residual
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    # Feed-forward block: RMS norm, SwiGLU, and another residual
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward
```

# We now have the final embedding, the model's basis for predicting the next token

The shape of the embedding is the same as that of the regular token embeddings, `[17x4096]`, where 17 is the number of tokens and 4096 is the embedding dimension.


```python
final_embedding = rms_norm(final_embedding, model["norm.weight"])
final_embedding.shape
```

```bash
torch.Size([17, 4096])
```


# Finally, decode the embedding into a token value

We use the output decoder to convert the final embedding into a token.

```python
model["output.weight"].shape
```

```bash
torch.Size([128256, 4096])
```

# Using the embedding of the last token to predict the next value

Hopefully, as expected, 42 :)

Note: according to the book *The Hitchhiker's Guide to the Galaxy*, "the answer to the ultimate question of life, the universe, and everything is 42". Most modern language models should answer 42 here, which should validate our entire code! Wish me luck :)

```python
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
logits.shape
```

```bash
torch.Size([128256])
```

### The model predicted token number 2983; is this the token number for 42?

This is the last part of the code; hopefully you're feeling confident by now :)

```python
next_token = torch.argmax(logits, dim=-1)
next_token
```

```bash
tensor(2983)
```

# Decoding
```python
tokenizer.decode([next_token.item()])
```

```bash
42
```

# Thank you, love you :)

This is the end. I hope you enjoyed it!

If you want to support my work:

1. Follow me on Twitter: https://twitter.com/naklecha
2. Or buy me a coffee: [https://www.buymeacoffee.com/naklecha](https://www.buymeacoffee.com/naklecha)

Honestly, if you made it this far, you've already made my day :)

--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
# llama3 implemented from scratch
in this file, i implemented llama3 from scratch, one tensor and matrix multiplication at a time.
also, i'm going to load tensors directly from the model file that meta provided for llama3, so you need to download the weights before running this file.
here is the official link to download the weights: https://llama.meta.com/llama-downloads/
## tokenizer
i'm not going to implement a bpe tokenizer (but andrej karpathy has a really clean implementation)
link to his implementation: https://github.com/karpathy/minbpe
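For intuition only, here's a tiny sketch of the core BPE training step: count adjacent pairs, then merge the most frequent pair into a new token id. It is a toy written in the spirit of minbpe, not minbpe's code and not the tokenizer Llama 3 actually ships. The function names `most_common_pair` and `merge` are hypothetical.

```python
from collections import Counter

def most_common_pair(ids):
    """Return the most frequent pair of adjacent ids."""
    pairs = Counter(zip(ids, ids[1:]))
    return max(pairs, key=pairs.get)

def merge(ids, pair, new_id):
    """Replace every occurrence of `pair` in `ids` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list("aaabdaaabac".encode("utf-8"))  # start from raw bytes (values 0..255)
pair = most_common_pair(ids)               # (97, 97), i.e. "aa"
ids = merge(ids, pair, 256)                # 256 is the first id after the 256 byte values
print(pair, ids)
```

A real tokenizer repeats this merge step many times and stores the learned merges. Llama 3's `tokenizer.model` holds the resulting ranks, which `load_tiktoken_bpe` reads below.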
```python
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import matplotlib.pyplot as plt

tokenizer_path = "Meta-Llama-3-8B/tokenizer.model"
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)

tokenizer.decode(tokenizer.encode("hello world!"))
```

    'hello world!'

## reading the model file
normally, reading this depends on how the model classes are written and the variable names inside them.
but since we are implementing llama3 from scratch, we will read the file one tensor at a time.
```python
model = torch.load("Meta-Llama-3-8B/consolidated.00.pth")
print(json.dumps(list(model.keys())[:20], indent=4))
```

    [
        "tok_embeddings.weight",
        "layers.0.attention.wq.weight",
        "layers.0.attention.wk.weight",
        "layers.0.attention.wv.weight",
        "layers.0.attention.wo.weight",
        "layers.0.feed_forward.w1.weight",
        "layers.0.feed_forward.w3.weight",
        "layers.0.feed_forward.w2.weight",
        "layers.0.attention_norm.weight",
        "layers.0.ffn_norm.weight",
        "layers.1.attention.wq.weight",
        "layers.1.attention.wk.weight",
        "layers.1.attention.wv.weight",
        "layers.1.attention.wo.weight",
        "layers.1.feed_forward.w1.weight",
        "layers.1.feed_forward.w3.weight",
        "layers.1.feed_forward.w2.weight",
        "layers.1.attention_norm.weight",
        "layers.1.ffn_norm.weight",
        "layers.2.attention.wq.weight"
    ]

```python
with open("Meta-Llama-3-8B/params.json", "r") as f:
    config = json.load(f)
config
```

    {'dim': 4096,
     'n_layers': 32,
     'n_heads': 32,
     'n_kv_heads': 8,
     'vocab_size': 128256,
     'multiple_of': 1024,
     'ffn_dim_multiplier': 1.3,
     'norm_eps': 1e-05,
     'rope_theta': 500000.0}

## we use this config to infer details about the model, like
1. the model has 32 transformer layers
2. each multi-head attention block has 32 heads
3. the vocab size and so on

```python
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
```

## converting text to tokens
here we use tiktoken (OpenAI's BPE tokenizer library) as the tokenizer
144 | 145 | 146 | ```python 147 | prompt = "the answer to the ultimate question of life, the universe, and everything is " 148 | tokens = [128000] + tokenizer.encode(prompt) 149 | print(tokens) 150 | tokens = torch.tensor(tokens) 151 | prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens] 152 | print(prompt_split_as_tokens) 153 | ``` 154 | 155 | [128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220] 156 | ['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' '] 157 | 158 | 159 | ## converting tokens to their embedding 160 | IM SORRY but this is the only part of the codebase where i use an inbuilt neural network module 161 |
anyway, so our [17x1] tokens are now [17x4096], i.e. 17 embeddings (one for each token) of length 4096
note: keep track of the shapes, it makes everything much easier to understand
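`torch.nn.Embedding` is the one built-in module used here, and it's nothing more than row indexing into the weight matrix. Here's a minimal sketch with a made-up toy vocabulary (the sizes below are hypothetical):

```python
import torch

torch.manual_seed(0)
vocab_size, dim = 10, 4                 # toy sizes, not Llama 3's 128256 x 4096
weight = torch.randn(vocab_size, dim)   # stands in for model["tok_embeddings.weight"]
tokens = torch.tensor([3, 1, 4, 1, 5])

embedding_layer = torch.nn.Embedding(vocab_size, dim)
with torch.no_grad():
    embedding_layer.weight.copy_(weight)

via_module = embedding_layer(tokens)    # [5, 4]
via_indexing = weight[tokens]           # same thing: pick row t for each token id t
print(torch.equal(via_module, via_indexing))
```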
170 | 171 | 172 | ```python 173 | embedding_layer = torch.nn.Embedding(vocab_size, dim) 174 | embedding_layer.weight.data.copy_(model["tok_embeddings.weight"]) 175 | token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16) 176 | token_embeddings_unnormalized.shape 177 | ``` 178 | 179 | 180 | 181 | 182 | torch.Size([17, 4096]) 183 | 184 | 185 | 186 | ## we then normalize the embedding using rms normalization 187 | please, note after this step the shapes dont change, the values are just normalized 188 |
things to keep in mind: we need a norm_eps (from config) because we don't want to accidentally set rms to 0 and divide by 0
here is the formula:
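Written out for one token's embedding $x$ of dimension $d$ (4096 here), with learned per-channel weight $w$ (for example `layers.0.attention_norm.weight`) and $\epsilon$ = `norm_eps`, the `rms_norm` function below computes:

$$
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \, w_i
$$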
```python
# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
#     return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
```

# building the first layer of the transformer

### normalization
you will see me accessing layer.0 from the model dict (this is the first layer)
anyway, so after normalizing, our shapes are still [17x4096], the same as the embedding but normalized
```python
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
token_embeddings.shape
```

    torch.Size([17, 4096])

### attention implemented from scratch
let's load the attention heads of the first layer of the transformer
> when we load the query, key, value and output weights from the model, we notice their shapes are [4096x4096], [1024x4096], [1024x4096] and [4096x4096].
> at first glance this is weird, because ideally we want the q, k, v and o weights for each head individually.
> the authors of the code bundled them together because it helps parallelize the attention head multiplication.
> i'm going to unwrap everything...

```python
print(
    model["layers.0.attention.wq.weight"].shape,
    model["layers.0.attention.wk.weight"].shape,
    model["layers.0.attention.wv.weight"].shape,
    model["layers.0.attention.wo.weight"].shape
)
```

    torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])

### unwrapping query
in the next section we will unwrap the queries from multiple attention heads; the resulting shape is [32x128x4096]

here, 32 is the number of attention heads in llama3, 128 is the size of the query vector and 4096 is the size of the token embedding

```python
q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
q_layer0.shape
```

    torch.Size([32, 128, 4096])

### i'm going to implement the first head of the first layer
here i access the query weight matrix of the first head of the first layer; the size of this query weight matrix is [128x4096]

```python
q_layer0_head0 = q_layer0[0]
q_layer0_head0.shape
```

    torch.Size([128, 4096])

### we now multiply the query weights with the token embedding, to receive a query for each token
here you can see the resulting shape is [17x128]; this is because we have 17 tokens and for each token there is a query of length 128.
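Bundling the heads into one matrix is harmless, and it is faster. To check that, here's a small sketch with random toy tensors (the sizes are made up, not Llama 3's). Multiplying by the full `wq` once and then splitting the output into heads gives the same per-head queries as the `view(n_heads, head_dim, dim)` followed by one matmul per head, as done above.

```python
import torch

torch.manual_seed(0)
n_tokens, n_heads, head_dim = 5, 4, 8       # hypothetical toy sizes
dim = n_heads * head_dim
x = torch.randn(n_tokens, dim)              # stands in for token_embeddings
wq = torch.randn(n_heads * head_dim, dim)   # stands in for layers.0.attention.wq.weight

# per head, the way it's done above: unwrap, then one matmul per head
wq_per_head = wq.view(n_heads, head_dim, dim)
q_per_head = torch.stack([x @ wq_per_head[h].T for h in range(n_heads)])  # [n_heads, n_tokens, head_dim]

# batched: one big matmul, then split the last dimension into heads
q_all = (x @ wq.T).view(n_tokens, n_heads, head_dim).transpose(0, 1)      # [n_heads, n_tokens, head_dim]

print(torch.allclose(q_per_head, q_all, atol=1e-5))
```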
```python
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
q_per_token.shape
```

    torch.Size([17, 128])

## positional encoding
we are now at a stage where we have a query vector for each token in our prompt, but if you think about it, each individual query vector has no idea about its position in the prompt.

prompt: "the answer to the ultimate question of life, the universe, and everything is "

in our prompt we have used "the" three times. we need the query vectors of all 3 "the" tokens (each of size [1x128]) to be different, based on their positions in the prompt. we perform these rotations using RoPE (rotary positional embedding).

### RoPE
watch this video (this is what i watched) to understand the math.
https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s
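The property that makes RoPE useful can be checked numerically. Rotate a query at position m and a key at position n. Their dot product then depends only on the offset m - n, not on the absolute positions. This is a small sketch with random toy vectors. It uses the same complex-number trick and the same frequency formula (theta = 500000) as the code below. The `rope` helper is hypothetical and for illustration only.

```python
import torch

def rope(x, pos, theta=500000.0):
    """Rotate each consecutive pair of x (length d) by pos * freq_i, RoPE-style."""
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(d // 2, dtype=torch.float64) / (d // 2)))
    rotation = torch.polar(torch.ones_like(freqs), pos * freqs)  # e^{i * pos * freq_i}
    x_complex = torch.view_as_complex(x.view(-1, 2))
    return torch.view_as_real(x_complex * rotation).flatten()

torch.manual_seed(0)
q = torch.randn(128, dtype=torch.float64)
k = torch.randn(128, dtype=torch.float64)

score_a = rope(q, 5) @ rope(k, 2)    # positions 5 and 2, offset 3
score_b = rope(q, 13) @ rope(k, 10)  # positions 13 and 10, offset 3
score_c = rope(q, 13) @ rope(k, 2)   # offset 11
print(score_a.item(), score_b.item(), score_c.item())
```

Since `score_a` and `score_b` match, attention scores see only how far apart two tokens are, while the three "the" tokens still end up with different rotated queries.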
```python
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
q_per_token_split_into_pairs.shape
```

    torch.Size([17, 64, 2])

in the above step, we split the query vectors into pairs; we will apply a rotational angle shift to each pair!

we now have a tensor of size [17x64x2]: the 128-length queries split into 64 pairs for each token in the prompt! each of those 64 pairs will be rotated by m*theta_i, where m is the position of the token whose query we are rotating and theta_i is the frequency assigned to pair i.
## using complex number multiplication to rotate a vector
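Multiplication by a complex number is a rotation. Take the pair $(x, y)$, view it as $x + iy$, and multiply by $e^{i\phi} = \cos\phi + i\sin\phi$. The result is the same as applying the 2D rotation matrix for angle $\phi$. Here's a quick check with made-up numbers, using the same `view_as_complex` / `polar` / `view_as_real` calls as the code below:

```python
import math
import torch

x, y, phi = 3.0, 4.0, math.pi / 6  # an arbitrary pair and angle, for illustration

# rotation via a 2x2 matrix
rotation_matrix = torch.tensor([[math.cos(phi), -math.sin(phi)],
                                [math.sin(phi),  math.cos(phi)]], dtype=torch.float64)
rotated_by_matrix = rotation_matrix @ torch.tensor([x, y], dtype=torch.float64)

# rotation via complex multiplication, as done with freqs_cis
pair = torch.tensor([x, y], dtype=torch.float64)
z = torch.view_as_complex(pair)  # 3 + 4i
rotation = torch.polar(torch.tensor(1.0, dtype=torch.float64),
                       torch.tensor(phi, dtype=torch.float64))  # e^{i*phi}
rotated_by_complex = torch.view_as_real(z * rotation)  # back to a real pair

print(rotated_by_matrix, rotated_by_complex)  # same values; the length stays 5
```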
```python
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
zero_to_one_split_into_64_parts
```

    tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
            0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
            0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
            0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
            0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
            0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
            0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
            0.9844])

```python
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
freqs
```

    tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
            2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
            8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
            2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
            7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
            2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
            6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
            1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
            5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
            1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
            4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])

```python
# one row of angles per token position (17 positions for this prompt)
freqs_for_each_token = torch.outer(torch.arange(len(tokens)), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape

# viewing the row of freqs_cis for position 3
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
```

![png](images/implllama3_30_0.png)

### now that we have a complex number (the angle change vector) for every token's query element
we can convert our queries (the ones we split into pairs) into complex numbers and then multiply them by freqs_cis to rotate each query based on its position
honestly this is beautiful to think about :)

```python
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
q_per_token_as_complex_numbers.shape
```

    torch.Size([17, 64])

```python
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
q_per_token_as_complex_numbers_rotated.shape
```

    torch.Size([17, 64])

### after the rotated vector is obtained
we can get our queries back as pairs by viewing the complex numbers as real numbers again

```python
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
q_per_token_split_into_pairs_rotated.shape
```

    torch.Size([17, 64, 2])

the rotated pairs are now merged; we now have a new query vector (the rotated query vector) of shape [17x128], where 17 is the number of tokens and 128 is the dim of the query vector

```python
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
q_per_token_rotated.shape
```

    torch.Size([17, 128])

# keys (almost the same as queries)
i'm lazy as fuck, so i'm not going to go through the math for keys. the only things you need to keep in mind are:
> keys generate key vectors, also of dimension 128
> keys have only 1/4th the number of weights as queries, because the weights for keys are shared across 4 heads at a time (grouped-query attention), to reduce the number of computations needed
> keys are also rotated to add positional info, just like queries, for the same reasons

```python
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
k_layer0.shape
```

    torch.Size([8, 128, 4096])

```python
k_layer0_head0 = k_layer0[0]
k_layer0_head0.shape
```

    torch.Size([128, 4096])

```python
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)
k_per_token.shape
```

    torch.Size([17, 128])

```python
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
k_per_token_split_into_pairs.shape
```

    torch.Size([17, 64, 2])

```python
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
k_per_token_as_complex_numbers.shape
```

    torch.Size([17, 64])

```python
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
k_per_token_split_into_pairs_rotated.shape
```

    torch.Size([17, 64, 2])

```python
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
k_per_token_rotated.shape
```

    torch.Size([17, 128])

## at this stage we now have the rotated values of both queries and keys, for each token.
each of the queries and keys is now of shape [17x128].

## in the next step we will multiply the query and key matrices
doing this will give us a score mapping each token to every other token
this score describes how well each token's query relates to each token's key.
THIS IS SELF ATTENTION :)
the shape of the attention score matrix (qk_per_token) is [17x17], where 17 is the number of tokens in the prompt
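A note on the `/(head_dim)**0.5` in the next cell. Suppose query and key entries are roughly independent with unit variance. Then the raw dot product of two 128-dim vectors has variance of about 128, which would push the softmax toward a near one-hot output. Dividing by $\sqrt{d}$ brings the variance back to about 1. This is a quick empirical sketch on random data, and the sample count is arbitrary:

```python
import torch

torch.manual_seed(0)
head_dim, n_samples = 128, 10_000
q = torch.randn(n_samples, head_dim)
k = torch.randn(n_samples, head_dim)

raw_scores = (q * k).sum(dim=-1)            # one q.k dot product per sample
scaled_scores = raw_scores / head_dim**0.5  # what the attention code does

print(raw_scores.var().item(), scaled_scores.var().item())  # roughly 128 and roughly 1
```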
```python
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
qk_per_token.shape
```

    torch.Size([17, 17])

# we now have to mask query key scores
during the training process of llama3, the future token qk scores are masked.
why? because during training we only learn to predict tokens using past tokens.
as a result, during inference we also mask the future tokens: their scores are set to -inf, so after the softmax they get zero weight.
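Adding `-inf` works because `exp(-inf)` is 0. Masked positions get exactly zero weight after the softmax, and each row still sums to 1 over the positions it is allowed to see. Here's a tiny sketch on toy 4x4 scores that are all zero, for simplicity:

```python
import torch

n = 4
scores = torch.zeros(n, n)  # pretend every q.k score is 0
mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
# row i spreads its weight evenly over positions 0..i and gives 0 to the future
```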
```python
def display_qk_heatmap(qk_per_token):
    _, ax = plt.subplots()
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)

display_qk_heatmap(qk_per_token)
```

![png](images/implllama3_50_0.png)

```python
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
mask
```

    tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

```python
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)
```

![png](images/implllama3_52_0.png)
```python
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)
```

![png](images/implllama3_54_0.png)

## values (almost the end of attention)
these scores (between 0 and 1) determine how much each token's value vector contributes to each token's output
> just like keys, value weights are also shared across every 4 attention heads (to save computation)
> as a result, the shape of the value weight matrix below is [8x128x4096]

```python
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
v_layer0.shape
```

    torch.Size([8, 128, 4096])

the value weight matrix of the first layer, first head is given below

```python
v_layer0_head0 = v_layer0[0]
v_layer0_head0.shape
```

    torch.Size([128, 4096])

## value vectors
we now use the value weights to get the attention values per token. this is of size [17x128], where 17 is the number of tokens in the prompt and 128 is the dim of the value vector per token

```python
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
v_per_token.shape
```

    torch.Size([17, 128])

## attention
the resultant attention vector after multiplying with the values per token is of shape [17x128]

```python
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
qkv_attention.shape
```

    torch.Size([17, 128])

# multi head attention
WE NOW HAVE THE ATTENTION VALUE OF THE FIRST LAYER AND FIRST HEAD
now i'm going to run a loop and perform the exact same math as the cells above, but for every head in the first layer

```python
qkv_attention_store = []

for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4]  # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4]  # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)

len(qkv_attention_store)
```

    32
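The `head//4` indexing in the loop above is grouped-query attention. The 32 query heads share 8 key/value heads: query heads 0-3 use kv head 0, heads 4-7 use kv head 1, and so on. An equivalent approach is to repeat each kv head 4 times with `torch.repeat_interleave`, so there is one kv head per query head. Here's a sketch with random tensors, using the real head counts but a toy `head_dim` and `dim`:

```python
import torch

torch.manual_seed(0)
n_heads, n_kv_heads, head_dim, dim = 32, 8, 4, 16  # real head counts, toy head_dim/dim
n_rep = n_heads // n_kv_heads                       # 4 query heads per kv head

k_layer = torch.randn(n_kv_heads, head_dim, dim)    # stands in for k_layer0

# option 1: index the shared head directly, as the loop does
k_by_index = torch.stack([k_layer[head // n_rep] for head in range(n_heads)])

# option 2: expand to one kv head per query head up front
k_repeated = torch.repeat_interleave(k_layer, n_rep, dim=0)  # [32, head_dim, dim]

print(torch.equal(k_by_index, k_repeated))
```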
we now have the qkv_attention matrix for all 32 heads on the first layer. next i'm going to merge all attention scores into one large matrix of size [17x4096]
we are almost at the end :)

```python
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
stacked_qkv_attention.shape
```

    torch.Size([17, 4096])

# weight matrix, one of the final steps
one of the last things to do for the layer 0 attention is to multiply the stacked attention output by the output weight matrix (wo)

```python
w_layer0 = model["layers.0.attention.wo.weight"]
w_layer0.shape
```

    torch.Size([4096, 4096])

### this is a simple linear layer, so we just matmul

```python
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
embedding_delta.shape
```

    torch.Size([17, 4096])
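Here is one way to read this step. `wo` is applied to the concatenation of the 32 head outputs. That is the same as each head projecting its own 128-dim output with its own 128-column slice of `wo`, then summing the results. Here's a small sketch with random toy tensors (the sizes are hypothetical):

```python
import torch

torch.manual_seed(0)
n_tokens, n_heads, head_dim = 5, 4, 8
dim = n_heads * head_dim
head_outputs = [torch.randn(n_tokens, head_dim) for _ in range(n_heads)]  # like qkv_attention_store
wo = torch.randn(dim, dim)                                               # like layers.0.attention.wo.weight

# what the notebook does: concatenate heads, then one matmul
delta_concat = torch.cat(head_outputs, dim=-1) @ wo.T

# same thing, head by head: head h only touches columns h*head_dim:(h+1)*head_dim of wo
delta_sum = sum(head_outputs[h] @ wo[:, h * head_dim:(h + 1) * head_dim].T for h in range(n_heads))

print(torch.allclose(delta_concat, delta_sum, atol=1e-4))
```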
we now have the change in the embedding value after attention, which should be added to the original token embeddings

```python
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
embedding_after_edit.shape
```

    torch.Size([17, 4096])

## we normalize and then run a feed forward neural network on the edited embedding
```python
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
embedding_after_edit_normalized.shape
```

    torch.Size([17, 4096])

## loading the ff weights and implementing the feed forward network
in llama3, they used a SwiGLU feedforward network. this network architecture is really good at adding non-linearity when needed by the model.
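This is a minimal sketch of the SwiGLU block on toy random tensors, plus a derivation of the hidden size. The hidden-size rule below is our reading of how Meta's reference `model.py` derives it from `dim`, `multiple_of` and `ffn_dim_multiplier`, so treat it as an assumption. With this config it gives 14336, which should match the first dimension of `w1` and `w3` (not printed in this notebook). The names `ffn_hidden_dim` and `swiglu_ffn` are hypothetical.

```python
import torch
import torch.nn.functional as F

def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    """Hidden size of the feed-forward block (assumed to follow Meta's reference recipe)."""
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

def swiglu_ffn(x, w1, w2, w3):
    """SwiGLU: silu(x @ w1.T) gates (x @ w3.T) elementwise, then w2 projects back to dim."""
    return (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

print(ffn_hidden_dim(4096, 1024, 1.3))  # 14336 for the llama3 8B config above

torch.manual_seed(0)
n_tokens, dim, hidden = 5, 16, 48       # toy sizes
x = torch.randn(n_tokens, dim)
w1, w3 = torch.randn(hidden, dim), torch.randn(hidden, dim)
w2 = torch.randn(dim, hidden)
print(swiglu_ffn(x, w1, w2, w3).shape)  # torch.Size([5, 16])
```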
it's pretty standard to use this feed forward network architecture in llms these days

```python
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]
# SwiGLU: silu(x @ w1.T) gates (x @ w3.T) elementwise, then w2 projects back to dim
output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
output_after_feedforward.shape
```

    torch.Size([17, 4096])

# WE FINALLY HAVE NEW EDITED EMBEDDINGS FOR EACH TOKEN AFTER THE FIRST LAYER
just 31 more layers to go before we are done (one for loop away)
you can imagine this edited embedding as having information about all the queries asked in the first layer
now each layer will encode more and more complex queries on the questions asked, until we have an embedding that knows everything about the next token that we need.

```python
layer_0_embedding = embedding_after_edit+output_after_feedforward
layer_0_embedding.shape
```

    torch.Size([17, 4096])

# god, everything all at once
yep, this is it. everything we did before, all at once, for every single layer.
# have fun reading :)

```python
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # 4 query heads share each key head
        v_layer_head = v_layer[head//4]  # ... and each value head
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        # rotate queries and keys with RoPE
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        # scaled, causally masked attention for this head
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    # merge the heads, project with wo, add the residual
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    # SwiGLU feed forward on the normalized result, then add the residual
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward
```

# we now have the final embedding, the best guess the model could make about the next token
the shape of the embedding is the same as regular token embeddings, [17x4096], where 17 is the number of tokens and 4096 is the embedding dim
1030 | 1031 |
1032 | 1033 | 1034 | ```python 1035 | final_embedding = rms_norm(final_embedding, model["norm.weight"]) 1036 | final_embedding.shape 1037 | ``` 1038 | 1039 | 1040 | 1041 | 1042 | torch.Size([17, 4096]) 1043 | 1044 | 1045 | 1046 | # finally, let's decode the embedding into the token value 1047 |
1048 | 1049 |
1050 | we will use the output decoder to convert the final embedding into a token 1051 | 1052 | 1053 | ```python 1054 | model["output.weight"].shape 1055 | ``` 1056 | 1057 | 1058 | 1059 | 1060 | torch.Size([128256, 4096]) 1061 | 1062 | 1063 | 1064 | # we use the embedding of the last token to predict the next value 1065 | hopefully in our case, 42 :) 1066 | note: 42 is the answer to "the answer to the ultimate question of life, the universe, and everything is ", according to the book "The Hitchhiker's Guide to the Galaxy". Most modern LLMs would answer with 42 here, which should validate our entire code! wish me luck :) 1067 | 1068 | 1069 | ```python 1070 | logits = torch.matmul(final_embedding[-1], model["output.weight"].T) 1071 | logits.shape 1072 | ``` 1073 | 1074 | 1075 | 1076 | 1077 | torch.Size([128256]) 1078 | 1079 | 1080 | 1081 | ### the model predicted token number 2983 as the next token; is this the token number for 42? 1082 | I'M HYPING YOU UP, this is the last cell of code, hopefully you had fun :) 1083 | 1084 | 1085 | ```python 1086 | next_token = torch.argmax(logits, dim=-1) 1087 | next_token 1088 | ``` 1089 | 1090 | 1091 | 1092 | 1093 | tensor(2983) 1094 | 1095 | 1096 | 1097 | # let's fucking go 1098 |
1099 | 1100 |
1101 | 1102 | 1103 | ```python 1104 | tokenizer.decode([next_token.item()]) 1105 | ``` 1106 | 1107 | 1108 | 1109 | 1110 | '42' 1111 | 1112 | 1113 | 1114 | # thank you, i love you :) 1115 | 1116 | This is the end. Hopefully you enjoyed reading it! 1117 | 1118 | If you want to support my work 1119 | 1120 | 1. follow me on twitter https://twitter.com/naklecha 1121 | 2. or, buy me a coffee [https://www.buymeacoffee.com/naklecha](https://www.buymeacoffee.com/naklecha) 1122 | 1123 | Honestly, if you made it this far you already made my day :) 1124 | 1125 | ## what motivates me? 1126 | 1127 | My friends and I are on a mission - to make research more accessible! 1128 | We created a research lab called A10 - [AAAAAAAAAA.org](http://aaaaaaaaaa.org/) 1129 | 1130 | A10 twitter - https://twitter.com/aaaaaaaaaaorg 1131 | 1132 | our thesis: 1133 |
1134 | 1135 |
1136 | -------------------------------------------------------------------------------- /images/42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/42.png -------------------------------------------------------------------------------- /images/a10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/a10.png -------------------------------------------------------------------------------- /images/afterattention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/afterattention.png -------------------------------------------------------------------------------- /images/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/archi.png -------------------------------------------------------------------------------- /images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/attention.png -------------------------------------------------------------------------------- /images/embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/embeddings.png -------------------------------------------------------------------------------- /images/finallayer.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/finallayer.png -------------------------------------------------------------------------------- /images/freq_cis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/freq_cis.png -------------------------------------------------------------------------------- /images/god.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/god.png -------------------------------------------------------------------------------- /images/heads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/heads.png -------------------------------------------------------------------------------- /images/implllama3_30_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_30_0.png -------------------------------------------------------------------------------- /images/implllama3_39_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_39_0.png -------------------------------------------------------------------------------- /images/implllama3_41_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_41_0.png -------------------------------------------------------------------------------- /images/implllama3_42_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_42_0.png -------------------------------------------------------------------------------- /images/implllama3_50_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_50_0.png -------------------------------------------------------------------------------- /images/implllama3_52_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_52_0.png -------------------------------------------------------------------------------- /images/implllama3_54_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/implllama3_54_0.png -------------------------------------------------------------------------------- /images/karpathyminbpe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/karpathyminbpe.png -------------------------------------------------------------------------------- /images/keys.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/keys.png -------------------------------------------------------------------------------- /images/keys0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/keys0.png -------------------------------------------------------------------------------- /images/last_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/last_norm.png -------------------------------------------------------------------------------- /images/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/mask.png -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/model.png -------------------------------------------------------------------------------- /images/norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/norm.png -------------------------------------------------------------------------------- /images/norm_after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/norm_after.png 
-------------------------------------------------------------------------------- /images/q_per_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/q_per_token.png -------------------------------------------------------------------------------- /images/qkmatmul.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/qkmatmul.png -------------------------------------------------------------------------------- /images/qkv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/qkv.png -------------------------------------------------------------------------------- /images/qsplit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/qsplit.png -------------------------------------------------------------------------------- /images/rms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/rms.png -------------------------------------------------------------------------------- /images/rope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/rope.png -------------------------------------------------------------------------------- /images/ropesplit.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/ropesplit.png -------------------------------------------------------------------------------- /images/softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/softmax.png -------------------------------------------------------------------------------- /images/stacked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/stacked.png -------------------------------------------------------------------------------- /images/swiglu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/swiglu.png -------------------------------------------------------------------------------- /images/tokens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/tokens.png -------------------------------------------------------------------------------- /images/v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/v0.png -------------------------------------------------------------------------------- /images/value.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/value.png -------------------------------------------------------------------------------- /images/weightmatrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/images/weightmatrix.png -------------------------------------------------------------------------------- /llama3/README.md: -------------------------------------------------------------------------------- 1 | # Llama3 from Scratch 2 | 3 | ## 1. Introduction 4 | 5 | The `llama3` folder contains the Llama3 PyTorch implementation extracted from the [meta-llama3](https://github.com/meta-llama/llama3) repository, with the `fairscale` dependency removed to make it easier to use and learn from. 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /llama3/model.py: -------------------------------------------------------------------------------- 1 | """ llama3 PyTorch implementation 2 | """ 3 | 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | @dataclass 14 | class ModelArgs: 15 | dim: int = 4096 16 | n_layers: int = 32 17 | n_heads: int = 32 18 | n_kv_heads: Optional[int] = None 19 | vocab_size: int = -1 20 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 21 | ffn_dim_multiplier: Optional[float] = None 22 | norm_eps: float = 1e-5 23 | rope_theta: float = 500000 24 | 25 | max_seq_len: int = 2048 26 | 27 | 28 | class RMSNorm(torch.nn.Module): 29 | """ Root Mean Square Layer Normalization 30 | """ 31 | def __init__(self, dim: int, eps: float = 1e-6): 32 | super().__init__() 33 | self.eps = eps 34 | self.weight = nn.Parameter(torch.ones(dim)) 35 | 36 | def 
_norm(self, x): 37 | # (B, seq_len, dim) 38 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
39 | 40 | def forward(self, x): 41 | output = self._norm(x.float()).type_as(x) 42 | # (dim) * (B, seq_len, dim) --> (B, seq_len, dim) 43 | return output * self.weight 44 | 45 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 46 | """ Precompute the cosines and sines of the rotary (RoPE) angles 47 | for the given sequence length and head dimension 48 | """ 49 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 50 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) 51 | freqs = torch.outer(t, freqs).float() 52 | freqs_cos = torch.cos(freqs) 53 | freqs_sin = torch.sin(freqs) 54 | return freqs_cos, freqs_sin 55 | 56 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 57 | ndim = x.ndim 58 | assert 0 <= 1 < ndim 59 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 60 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 61 | return freqs_cis.view(shape) 62 | 63 | def apply_rotary_emb( 64 | xq: torch.Tensor, 65 | xk: torch.Tensor, 66 | freqs_cos: torch.Tensor, 67 | freqs_sin: torch.Tensor 68 | ) -> Tuple[torch.Tensor, torch.Tensor]: 69 | """ Apply rotary position embeddings to the query and key tensors using the precomputed cos/sin tensors 70 | """ 71 | 72 | # reshape xq and xk to match the complex representation 73 | xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) 74 | xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) 75 | 76 | # reshape freqs_cos and freqs_sin for broadcasting 77 | freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) 78 | freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) 79 | 80 | # apply rotation using real numbers 81 | xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin 82 | xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos 83 | xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin 84 | xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos 85 | 86 | # flatten last two dimensions 87 | xq_out = torch.stack([xq_out_r, 
xq_out_i],
dim=-1).flatten(3) 88 | xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) 89 | 90 | return xq_out.type_as(xq), xk_out.type_as(xk) 91 | 92 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 93 | """ Repeating the heads of keys and values to match the number of query heads 94 | """ 95 | bs, slen, n_kv_heads, head_dim = x.shape 96 | if n_rep == 1: 97 | return x 98 | return ( 99 | x[:, :, :, None, :] # (B, seq_len, n_kv_heads, 1, head_size), added a new dimension 100 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 101 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 102 | ) 103 | 104 | 105 | class Attention(nn.Module): 106 | """ Grouped-Query Attention with RoPE applied to queries and keys (no KV cache) 107 | """ 108 | def __init__(self, args: ModelArgs): 109 | super().__init__() 110 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 111 | model_parallel_size = 1 112 | self.n_local_heads = args.n_heads // model_parallel_size 113 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 114 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 115 | self.head_dim = args.dim // args.n_heads 116 | 117 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 118 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 119 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 120 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) 121 | 122 | # use flash attention or a manual implementation? 123 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 124 | # self.flash = False 125 | if not self.flash: 126 | print("WARNING: using slow attention.
Flash Attention requires PyTorch >= 2.0") 127 | mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) 128 | mask = torch.triu(mask, diagonal=1) 129 | self.register_buffer("mask", mask) 130 | 131 | def forward( 132 | self, 133 | x: torch.Tensor, 134 | freqs_cos: torch.Tensor, 135 | freqs_sin: torch.Tensor, 136 | ): 137 | # (batch_size, seqlen, dim) 138 | bsz, seqlen, _ = x.shape 139 | 140 | xq = self.wq(x) # (bs, seqlen, dim) --> (bs, seqlen, n_q_heads * head_size) 141 | xk = self.wk(x) # (bs, seqlen, dim) --> (bs, seqlen, n_kv_heads * head_size) 142 | xv = self.wv(x) # (bs, seqlen, dim) --> (bs, seqlen, n_kv_heads * head_size) 143 | 144 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 145 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 146 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 147 | 148 | xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) 149 | 150 | # repeat k/v heads if n_kv_heads < n_heads 151 | xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 152 | xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 153 | 154 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 155 | xk = xk.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 156 | xv = xv.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 157 | 158 | if self.flash: 159 | output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True) 160 | else: 161 | # manual implementation 162 | scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) 163 | assert hasattr(self, 'mask') 164 | scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, seqlen) 165 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 166 | output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) 167 |
168 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 169 | 170 | output = self.wo(output) 171 | return output 172 | 173 | 174 | class FeedForward(nn.Module): 175 | """ Feed forward with SwiGLU 176 | """ 177 | def __init__( 178 | self, 179 | dim: int, 180 | hidden_dim: int, 181 | multiple_of: int, 182 | ffn_dim_multiplier: Optional[float], 183 | ): 184 | super().__init__() 185 | hidden_dim = int(2 * hidden_dim / 3) 186 | # custom dim factor multiplier 187 | if ffn_dim_multiplier is not None: 188 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 189 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 190 | 191 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 192 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 193 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 194 | 195 | def forward(self, x): 196 | # in SwiGLU, the Swish function is used to gate the linear function of GLU 197 | # swish(x) = x * sigmoid(beta * x) 198 | # when beta = 1, swish function becomes same as the sigmoid linear unit function (SiLU) 199 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 200 | 201 | 202 | class TransformerBlock(nn.Module): 203 | """ Transformer block: communication followed by computation 204 | """ 205 | def __init__(self, layer_id: int, args: ModelArgs): 206 | super().__init__() 207 | self.n_heads = args.n_heads 208 | self.dim = args.dim 209 | self.head_dim = args.dim // args.n_heads 210 | self.attention = Attention(args) 211 | self.feed_forward = FeedForward( 212 | dim=args.dim, 213 | hidden_dim=4 * args.dim, 214 | multiple_of=args.multiple_of, 215 | ffn_dim_multiplier=args.ffn_dim_multiplier, 216 | ) 217 | self.layer_id = layer_id 218 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 219 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 220 | 221 | def forward( 222 | self, 223 | x: torch.Tensor, 224 | freqs_cos: torch.Tensor, 225 | freqs_sin: torch.Tensor, 226 | ): 227 | # (B, seq_len, dim) + (B, seq_len, 
dim) --> (B, seq_len, dim) 228 | h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin) 229 | out = h + self.feed_forward(self.ffn_norm(h)) 230 | return out 231 | 232 | class Transformer(nn.Module): 233 | """ Transformer module 234 | """ 235 | def __init__(self, params: ModelArgs): 236 | super().__init__() 237 | self.params = params 238 | self.vocab_size = params.vocab_size 239 | self.n_layers = params.n_layers 240 | 241 | self.tok_embeddings = nn.Embedding(self.vocab_size, params.dim) 242 | 243 | self.layers = torch.nn.ModuleList() 244 | for layer_id in range(params.n_layers): 245 | self.layers.append(TransformerBlock(layer_id, params)) 246 | # final normalization layer 247 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 248 | # final language model head 249 | self.output = nn.Linear(params.dim, self.vocab_size, bias=False) 250 | 251 | freqs_cos, freqs_sin = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len, params.rope_theta) 252 | self.register_buffer("freqs_cos", freqs_cos, persistent=False) 253 | self.register_buffer("freqs_sin", freqs_sin, persistent=False) 254 | 255 | # init all weights 256 | self.apply(self._init_weights) 257 | # apply special scaled init to the residual projections (w2 in the FFN, wo in attention), per GPT-2 paper 258 | for pn, p in self.named_parameters(): 259 | if pn.endswith('w2.weight') or pn.endswith('wo.weight'): 260 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) 261 | 262 | self.last_loss = None 263 | 264 | def _init_weights(self, module): 265 | if isinstance(module, nn.Linear): 266 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 267 | if module.bias is not None: 268 | torch.nn.init.zeros_(module.bias) 269 | elif isinstance(module, nn.Embedding): 270 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 271 | 272 | @torch.inference_mode() 273 | def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None): 274 | _bsz, seqlen = 
tokens.shape 275 | h =
self.tok_embeddings(tokens) # (bs, seq_len) --> (bs, seq_len, dim) 276 | 277 | freqs_cos = self.freqs_cos[:seqlen] 278 | freqs_sin = self.freqs_sin[:seqlen] 279 | 280 | for layer in self.layers: 281 | h = layer(h, freqs_cos, freqs_sin) # (bs, seq_len , dim) 282 | h = self.norm(h) # (bs, seq_len , dim) 283 | 284 | if targets is not None: 285 | # if we are given some desired targets also calculate the loss 286 | logits = self.output(h).float() # (bs, seq_len, vocab_size) 287 | self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 288 | else: 289 | # inference-time mini-optimization: only forward the output on the very last position 290 | logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim 291 | self.last_loss = None 292 | 293 | # (bs, seq_len, vocab_size) 294 | return logits 295 | 296 | @torch.inference_mode() 297 | def generate(self, tokens, max_new_tokens, temperature=1.0, top_k=None, eos=None): 298 | """ 299 | Take a conditioning sequence of indices tokens (LongTensor of shape (b,t)) and complete 300 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 301 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 302 | Also note this is a super inefficient version of sampling with no key/value cache. 
303 | """ 304 | for _ in range(max_new_tokens): 305 | # if the sequence context is growing too long we must crop it at max_seq_len 306 | token_cond = tokens if tokens.size(1) <= self.params.max_seq_len else tokens[:, -self.params.max_seq_len:] 307 | # forward the model to get the logits for the index in the sequence 308 | logits = self(token_cond) 309 | logits = logits[:, -1, :] # crop to just the final time step 310 | if temperature == 0.0: 311 | # "sample" the single most likely index 312 | _, next_token = torch.topk(logits, k=1, dim=-1) 313 | else: 314 | # pluck the logits at the final step and scale by desired temperature 315 | logits = logits / temperature 316 | # optionally crop the logits to only the top k options 317 | if top_k is not None: 318 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 319 | logits[logits < v[:, [-1]]] = -float('Inf') 320 | # apply softmax to convert logits to (normalized) probabilities 321 | probs = F.softmax(logits, dim=-1) 322 | next_token = torch.multinomial(probs, num_samples=1) 323 | # append sampled index to the running sequence and continue 324 | tokens = torch.cat((tokens, next_token), dim=1) 325 | if next_token == eos: 326 | break 327 | 328 | return tokens 329 | 330 | def print_model_parameters(model): 331 | """ print model parameters 332 | """ 333 | param_sum = 0 334 | for name, param in model.named_parameters(): 335 | if param.requires_grad: 336 | param_sum += param.numel() 337 | print(f"Layer: {name}, Parameters: {param.numel()}") 338 | print(f"Total of parameters: {param_sum}") 339 | 340 | 341 | if __name__ == "__main__": 342 | device: str = 'cuda' if torch.cuda.is_available() else 'cpu' 343 | args_xxx = ModelArgs( 344 | dim=4096, 345 | n_layers=2, 346 | n_heads=32, 347 | n_kv_heads=8, 348 | multiple_of=1024, 349 | vocab_size=128256, 350 | ffn_dim_multiplier=1.3, 351 | norm_eps=1e-05, 352 | rope_theta=500000.0 353 | ) 354 | 355 | model = Transformer(args_xxx).to(device) 356 | print("init") 357 | 358 |
checkpoint_path = "Meta-Llama-3-8B-Instruct-2layers/consolidated_2layers.pth" 359 | checkpoint = torch.load(checkpoint_path, map_location=device) 360 | model.load_state_dict(checkpoint, strict=False) 361 | print("load success") 362 | 363 | x = torch.tensor([[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]]).to(device) 364 | print(x.shape) 365 | logits = model(x) 366 | print(logits.size()) 367 | 368 | next_token = torch.argmax(logits[:, -1], dim=-1) 369 | print(next_token) 370 | # tensor([50210]) 371 | 372 | next_token = model.generate(x, max_new_tokens=1, temperature=0) 373 | print(next_token) 374 | 375 | # print_model_parameters(model) 376 | 377 | 378 | -------------------------------------------------------------------------------- /llama3/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import ( 4 | AbstractSet, 5 | cast, 6 | Collection, 7 | Dict, 8 | Iterator, 9 | List, 10 | Literal, 11 | Sequence, 12 | Union, 13 | ) 14 | 15 | import tiktoken 16 | from tiktoken.load import load_tiktoken_bpe 17 | 18 | class Tokenizer: 19 | """ Tokenizing and encoding/decoding text using the Tiktoken tokenizer """ 20 | 21 | special_tokens: Dict[str, int] 22 | # number of reserved special tokens 23 | num_reserved_special_tokens = 256 24 | # regex pattern for splitting the text 25 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" 26 | 27 | def __init__(self, model_path: str): 28 | assert os.path.isfile(model_path), model_path 29 | 30 | # loading existing tiktoken model 31 | mergeable_ranks = load_tiktoken_bpe(model_path) 32 | # number of base tokens = number of tokens in the existing tiktoken model 33 | num_base_tokens = len(mergeable_ranks) 34 | # list of some special tokens we will add to the tokenizer 35 | special_tokens = [ 36 | "<|begin_of_text|>", 
37 | "<|end_of_text|>", 38 | "<|reserved_special_token_0|>", 39 | "<|reserved_special_token_1|>", 40 | "<|reserved_special_token_2|>", 41 | "<|reserved_special_token_3|>", 42 | "<|start_header_id|>", 43 | "<|end_header_id|>", 44 | "<|reserved_special_token_4|>", 45 | "<|eot_id|>", # end of turn 46 | ] + [ 47 | f"<|reserved_special_token_{i}|>" 48 | for i in range(5, self.num_reserved_special_tokens - 5) 49 | ] 50 | # creating a dictionary of special tokens mentioned above 51 | self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)} 52 | 53 | self.model = tiktoken.Encoding( 54 | name=Path(model_path).name, 55 | pat_str=self.pat_str, 56 | mergeable_ranks=mergeable_ranks, 57 | special_tokens=self.special_tokens, 58 | ) 59 | 60 | # vocabulary size 61 | self.n_words: int = self.model.n_vocab 62 | 63 | # BOS / EOS token IDs 64 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 65 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 66 | self.pad_id: int = -1 67 | self.stop_tokens = { 68 | self.special_tokens["<|end_of_text|>"], 69 | self.special_tokens["<|eot_id|>"], 70 | } 71 | 72 | def encode( 73 | self, 74 | s: str, 75 | *, 76 | bos: bool, 77 | eos: bool, 78 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 79 | disallowed_special: Union[Literal["all"], Collection[str]] = (), 80 | ) -> List[int]: 81 | """ 82 | Encodes a string into a list of token IDs. 83 | 84 | Args: 85 | s (str): The input string to be encoded. 86 | bos (bool): Whether to prepend the beginning-of-sequence token. 87 | eos (bool): Whether to append the end-of-sequence token. 88 | allowed_special ("all"|set[str]): allowed special tokens in string 89 | disallowed_special ("all"|set[str]): special tokens that raise an error when in string 90 | 91 | Returns: 92 | list[int]: A list of token IDs. 93 | 94 | By default, setting disallowed_special=() encodes a string by ignoring 95 | special tokens.
Specifically: 96 | - Setting `disallowed_special` to () will cause all text corresponding 97 | to special tokens to be encoded as natural text (insteading of raising 98 | an error). 99 | - Setting `allowed_special` to "all" will treat all text corresponding 100 | to special tokens to be encoded as special tokens. 101 | """ 102 | 103 | assert type(s) is str, "input must be string" 104 | 105 | # the tiktoken tokenizer can handle <=400k chars without pyo3_runtime.PanicException 106 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 107 | 108 | # max number of consecutive whitespace characters in a substring 109 | MAX_NUM_WHITESPACES_CHARS = 25_000 110 | 111 | # iterating over subsequences and splitting if we exceed the limit of max consecutive non-whitespace or whitespace characters 112 | substrs = ( 113 | substr 114 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) 115 | for substr in self._split_whitespaces_or_nonwhitespaces( 116 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NUM_WHITESPACES_CHARS 117 | ) 118 | ) 119 | # list of token ids 120 | t: List[int] = [] 121 | for substr in substrs: 122 | t.extend( 123 | self.model.encode( 124 | substr, 125 | allowed_special=allowed_special, 126 | disallowed_special=disallowed_special, 127 | ) 128 | ) 129 | 130 | # prepending the beginning-of-sequence token 131 | if bos: 132 | t.insert(0, self.bos_id) 133 | 134 | # appending the end-of-sequence token 135 | if eos: 136 | t.append(self.eos_id) 137 | return t 138 | 139 | def decode(self, t: Sequence[int]) -> str: 140 | """ 141 | Decodes a list of token IDs into a string. 142 | 143 | Args: 144 | t (List[int]): The list of token IDs to be decoded. 145 | 146 | Returns: 147 | str: The decoded string. 148 | """ 149 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
--------------------------------------------------------------------------------
/pdf/从零实现 Llama3 模型.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdndev/llama3-from-scratch-zh/9aaab6416985fc151c36eeca5e4f52c1a987efbc/pdf/从零实现 Llama3 模型.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
sentencepiece
tiktoken
torch
blobfile
matplotlib
--------------------------------------------------------------------------------
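You can exercise the special-token layout and the run-splitting helper in `tokenizer.py` without the tiktoken model file. The following is a minimal standalone sketch, not code from this repository. `build_special_tokens` and `split_runs` are hypothetical names that re-implement the logic of `Tokenizer.__init__` and `Tokenizer._split_whitespaces_or_nonwhitespaces`. `NUM_BASE_TOKENS = 128_000` is an assumed stand-in for `len(mergeable_ranks)`. It agrees with the `128000` that opens the test input in `model.py`, which is the `<|begin_of_text|>` ID under that assumption.

```python
from typing import Dict, Iterator

# Hypothetical constant: the real Tokenizer computes this as len(mergeable_ranks).
# 128_000 is assumed here to match the released Llama 3 tokenizer.
NUM_BASE_TOKENS = 128_000
NUM_RESERVED_SPECIAL_TOKENS = 256


def build_special_tokens(num_base_tokens: int) -> Dict[str, int]:
    """Recreates the special-token table built in Tokenizer.__init__."""
    names = [
        "<|begin_of_text|>",
        "<|end_of_text|>",
        "<|reserved_special_token_0|>",
        "<|reserved_special_token_1|>",
        "<|reserved_special_token_2|>",
        "<|reserved_special_token_3|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|reserved_special_token_4|>",
        "<|eot_id|>",
    ] + [
        f"<|reserved_special_token_{i}|>"
        for i in range(5, NUM_RESERVED_SPECIAL_TOKENS - 5)
    ]
    # Special tokens are numbered immediately after the base BPE vocabulary.
    return {name: num_base_tokens + i for i, name in enumerate(names)}


def split_runs(s: str, max_len: int) -> Iterator[str]:
    """Standalone copy of Tokenizer._split_whitespaces_or_nonwhitespaces."""
    current_len = 0
    current_is_space = s[0].isspace() if s else False
    start = 0
    for i, ch in enumerate(s):
        is_space = ch.isspace()
        if current_is_space ^ is_space:
            # The run type changed (space <-> non-space): start a new run.
            current_len = 1
            current_is_space = is_space
        else:
            current_len += 1
            if current_len > max_len:
                # The run is too long: cut before this character and keep counting.
                yield s[start:i]
                start = i
                current_len = 1
    yield s[start:]


special = build_special_tokens(NUM_BASE_TOKENS)
print(len(special), special["<|begin_of_text|>"], special["<|eot_id|>"])
# prints: 256 128000 128009
print(list(split_runs("aaaaaaa  bb", 3)))
# prints: ['aaa', 'aaa', 'a  bb']
```

Cuts only happen inside a run longer than `max_len`, so joining the pieces always reproduces the input. `encode` then tokenizes each piece separately and concatenates the token lists, which keeps each tiktoken call under the limits noted in its comments.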