├── RL ├── GSPO.ipynb ├── PPO.ipynb ├── KL.ipynb ├── DPO.ipynb └── GRPO.ipynb ├── Norm ├── RMSNorm.ipynb └── LayerNorm.ipynb ├── Components ├── SwiGLU.ipynb ├── Linear.ipynb ├── LoRA.ipynb ├── RoPE.ipynb └── BPE.ipynb ├── Functional ├── sft.ipynb ├── CE.ipynb ├── activation_fun.ipynb ├── InfoNCE.ipynb ├── sample.ipynb └── quantize.ipynb ├── Attention ├── mask.ipynb ├── MHA.ipynb ├── GQA.ipynb └── MHA_kvcache.ipynb └── readme.md /RL/GSPO.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RL/PPO.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RL/KL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bbc37522", 6 | "metadata": {}, 7 | "source": [ 8 | "# KL Divergence\n", 9 | "\n", 10 | "$$\n", 11 | "D_{KL}(P||Q) = \\sum_{x} P(x) \\log \\frac{P(x)}{Q(x)}\n", 12 | "$$\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "5cbb2d92", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import torch\n", 23 | "\n", 24 | "def compute_kl(logp, ref_logp, method=\"k1\"):\n", 25 | " logr=ref_logp - logp\n", 26 | " if method==\"k1\":\n", 27 | " kl=-logr\n", 28 | " elif method==\"k2\":\n", 29 | " kl=(logr ** 2) / 2\n", 30 | " else:\n", 31 | " kl=torch.exp(logr) - logr - 1\n", 32 | " return kl" 33 | ] 34 | } 35 | ], 36 | "metadata": { 37 | "language_info": { 38 | "name": "python" 39 | } 40 | }, 41 | "nbformat": 4, 42 | "nbformat_minor": 5 43 | } 44 | -------------------------------------------------------------------------------- /Norm/RMSNorm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "06c45769", 6 | "metadata": {}, 7 | "source": [ 8 | "# RMS Normalization" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "ba7eda15", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "\n", 21 | "class RMSNorm(nn.Module):\n", 22 | " def __init__(self, hidden_dim, eps):\n", 23 | " super().__init__()\n", 24 | " self.eps=eps\n", 25 | " self.weight=nn.Parameter(torch.ones(hidden_dim))\n", 26 | " \n", 27 | " def forward(self, x):\n", 28 | " rms=torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True)+self.eps)\n", 29 | " x_norm=x/rms\n", 30 | " return self.weight * x_norm" 31 | ] 32 | } 33 | ], 34 | "metadata": { 35 | "language_info": { 36 | "name": "python" 37 | } 38 | }, 39 | "nbformat": 4, 40 | "nbformat_minor": 5 41 | } 42 | -------------------------------------------------------------------------------- /Components/SwiGLU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3b271326", 6 | "metadata": {}, 7 | "source": [ 8 | "# SwiGLU" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "92484602", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "import torch.nn.functional as F\n", 21 | "\n", 22 | "class SwiGLU(nn.Module):\n", 23 | " def __init__(self, hidden_dim, intermediate_dim, bias=False):\n", 
super().__init__()\n", 25 | " self.gate_proj=nn.Linear(hidden_dim, intermediate_dim, bias=bias)\n", 26 | " self.up_proj=nn.Linear(hidden_dim, intermediate_dim, bias=bias)\n", 27 | " self.down_proj=nn.Linear(intermediate_dim, hidden_dim, bias=bias)\n", 28 | "\n", 29 | " def forward(self, x):\n", 30 | " return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))" 31 | ] 32 | } 33 | ], 34 | "metadata": { 35 | "kernelspec": { 36 | "display_name": "gaia", 37 | "language": "python", 38 | "name": "python3" 39 | }, 40 | "language_info": { 41 | "name": "python", 42 | "version": "3.11.13" 43 | } 44 | }, 45 | "nbformat": 4, 46 | "nbformat_minor": 5 47 | } 48 | -------------------------------------------------------------------------------- /Norm/LayerNorm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2cd0a0f8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Layer Normalization" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "5b93ae44", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "\n", 21 | "class LayerNorm(nn.Module):\n", 22 | " def __init__(self, hidden_dim, eps=1e-6):\n", 23 | " super().__init__()\n", 24 | " self.weight=nn.Parameter(torch.ones(hidden_dim))\n", 25 | " self.bias=nn.Parameter(torch.zeros(hidden_dim))\n", 26 | " self.eps=eps\n", 27 | "\n", 28 | " def forward(self, x):\n", 29 | " avg=x.mean(dim=-1, keepdim=True)\n", 30 | " var=x.var(dim=-1, keepdim=True, unbiased=False)\n", 31 | " x_norm=(x-avg) / torch.sqrt(var+self.eps)\n", 32 | " return x_norm * self.weight + self.bias" 33 | ] 34 | } 35 | ], 36 | "metadata": { 37 | "kernelspec": { 38 | "display_name": "gaia", 39 | "language": "python", 40 | "name": "python3" 41 | }, 42 | "language_info": { 43 | "name": "python", 44 | "version": "3.11.13" 45 | } 46 | }, 47 | "nbformat": 4, 48 | "nbformat_minor": 5 49 | } 50 | -------------------------------------------------------------------------------- /Functional/sft.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "22d00460", 6 | "metadata": {}, 7 | "source": [ 8 | "# SFT loss" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "6169822b", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "def causal_lm_loss(\n", 21 | " logits, # [batch_size, seq_len, hidden_dim]\n", 22 | " labels, # [batch_size, seq_len]\n", 23 | " pad_token_id: int = 0, \n", 24 | " **kwargs\n", 25 | "):\n", 26 | " # 取最后一个token前的所有token\n", 27 | " shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) # [batch_size * (seq_len-1), hidden_dim]\n", 28 | " # 取第一个token后的所有token\n", 29 | " shift_labels = labels[..., 1:].contiguous().view(-1) # [batch_size * (seq_len-1)]\n", 30 | " \n", 31 | " loss = F.cross_entropy(\n", 32 | " shift_logits, \n", 33 | " shift_labels, \n", 34 | " ignore_index=pad_token_id, \n", 35 | " reduction='mean', \n", 36 | " **kwargs\n", 37 | " )\n", 38 | " \n", 39 | " return loss" 40 | ] 41 | } 42 | ], 43 | "metadata": { 44 | "language_info": { 45 | "name": "python" 46 | } 47 | }, 48 | "nbformat": 4, 49 | "nbformat_minor": 5 50 | } 51 | -------------------------------------------------------------------------------- /RL/DPO.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b55975c6", 6 | "metadata": {}, 7 | "source": [ 8 | "# DPO loss\n", 9 | "\n", 10 | "$$\n", 11 | "\\begin{equation*} \\mathcal{L}_\\text{DPO}(\\pi_{\\theta}; \\pi_{ref}) = -\\mathbb{E}_{(x, y_w, y_l)\\sim \\mathcal{D}}\\left[\\log \\sigma \\left(\\beta \\log \\frac{\\pi_{\\theta}(y_w\\mid x)}{\\pi_{ref}(y_w\\mid x)} - \\beta \\log \\frac{\\pi_{\\theta}(y_l\\mid x)}{\\pi_{ref}(y_l\\mid x)}\\right)\\right] \\end{equation*}\n", 12 | "$$" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "e884ccd7", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from torch.nn import functional as F\n", 23 | "\n", 24 | "def dpo_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp, beta=0.1):\n", 25 | " chosen_logratio = chosen_logp - ref_chosen_logp\n", 26 | " rejected_logratio = rejected_logp - ref_rejected_logp\n", 27 | " logratio=chosen_logratio - rejected_logratio\n", 28 | " return - F.logsigmoid(beta * logratio).mean()" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "gaia", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "name": "python", 40 | "version": "3.11.13" 41 | } 42 | }, 43 | "nbformat": 4, 44 | "nbformat_minor": 5 45 | } 46 | -------------------------------------------------------------------------------- /Functional/CE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "04748b07", 6 | "metadata": {}, 7 | "source": [ 8 | "# Cross-Entropy Loss" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "912a0405", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torch.nn.functional as F\n", 20 | "\n", 21 | "def ce_loss(predict, target):\n", 22 | " \"\"\"\n", 23 | " Args:\n", 24 | " logits: 模型的未归一化输出 (形状: [batch_size, num_classes])\n", 25 | " target: 若为类别索引 (形状: [batch_size]),则为每个样本的类别;\n", 26 | " 若为概率分布 (形状: [batch_size, num_classes]),则为每个样本上各类别的分布。\n", 27 | " Returns:\n", 28 | " loss: 交叉熵损失标量\n", 29 | " \"\"\"\n", 30 | " log_prob=F.log_softmax(predict, dim=-1)\n", 31 | "\n", 32 | " if target.dim()==1:\n", 33 | " loss=-log_prob.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)\n", 34 | " else:\n", 35 | " loss=-(target*log_prob).sum(dim=-1)\n", 36 | " return loss.mean()" 37 | ] 38 | } 39 | ], 40 | "metadata": { 41 | "kernelspec": { 42 | "display_name": "gaia", 43 | "language": "python", 44 | "name": "python3" 45 | }, 46 | "language_info": { 47 | "name": "python", 48 | "version": "3.11.13" 49 | } 50 | }, 51 | "nbformat": 4, 52 | "nbformat_minor": 5 53 | } 54 | -------------------------------------------------------------------------------- /Components/Linear.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a6e97755", 6 | "metadata": {}, 7 | "source": [ 8 | "# Linear Layer" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "61311f95", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "\n", 21 | "class Linear(nn.Module):\n", 22 | " def __init__(self, in_dim, out_dim, bias=True):\n", 23 | " super().__init__()\n", 24 | " # 初始化为 
(out_dim, in_dim) 与pytorch底层张量存储方式有关,有利于计算效率\n", 25 | " self.weight=nn.Parameter(torch.randn(out_dim, in_dim))\n", 26 | " self.bias=None\n", 27 | " if bias:\n", 28 | " self.bias=nn.Parameter(torch.randn(out_dim))\n", 29 | " \n", 30 | " def forward(self, x):\n", 31 | " output=x @ self.weight.t()\n", 32 | " if self.bias is not None:\n", 33 | " output+=self.bias\n", 34 | " return output" 35 | ] 36 | } 37 | ], 38 | "metadata": { 39 | "kernelspec": { 40 | "display_name": "gaia", 41 | "language": "python", 42 | "name": "python3" 43 | }, 44 | "language_info": { 45 | "codemirror_mode": { 46 | "name": "ipython", 47 | "version": 3 48 | }, 49 | "file_extension": ".py", 50 | "mimetype": "text/x-python", 51 | "name": "python", 52 | "nbconvert_exporter": "python", 53 | "pygments_lexer": "ipython3", 54 | "version": "3.11.13" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 5 59 | } 60 | -------------------------------------------------------------------------------- /Attention/mask.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "16289e5a", 6 | "metadata": {}, 7 | "source": [ 8 | "# Attention Mask" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "78414fa9", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "\n", 20 | "def create_attention_mask(input_ids, pad_token_id=0, causal=True):\n", 21 | " batch_size, seq_len = input_ids.size()\n", 22 | "\n", 23 | " padding_mask = (input_ids == pad_token_id).view(batch_size, 1, 1, seq_len)\n", 24 | "\n", 25 | " if causal:\n", 26 | " causal_mask = torch.triu(\n", 27 | " torch.ones(seq_len, seq_len, dtype=torch.bool),\n", 28 | " diagonal=1 # 对角线及以下为False,对角线以上为True\n", 29 | " )\n", 30 | " causal_mask = causal_mask.view(1, 1, *causal_mask.shape)\n", 31 | " mask = padding_mask | causal_mask # [batch_size, 1, seq_len, seq_len] 是为了方便 torch.masked_fill广播\n", 32 | " else:\n", 33 | " mask = padding_mask.expand(batch_size, 1, seq_len, seq_len) # [batch_size, 1, seq_len, seq_len]\n", 34 | " return mask\n", 35 | "\n", 36 | "# 使用时:\n", 37 | "# mask = create_attention_mask(input_ids, pad_token_id=0)\n", 38 | "# attention_scores = attention_weights.masked_fill(mask, float('-inf'))" 39 | ] 40 | } 41 | ], 42 | "metadata": { 43 | "kernelspec": { 44 | "display_name": "gaia", 45 | "language": "python", 46 | "name": "python3" 47 | }, 48 | "language_info": { 49 | "name": "python", 50 | "version": "3.11.13" 51 | } 52 | }, 53 | "nbformat": 4, 54 | "nbformat_minor": 5 55 | } 56 | -------------------------------------------------------------------------------- /Functional/activation_fun.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "adf0b727", 6 | "metadata": {}, 7 | "source": [ 8 | "# Activation Function\n", 9 | "## Sigmoid" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "03069e00", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "\n", 21 | "def sigmoid(x):\n", 22 | " return 1 / (1+np.exp(-x))" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "aa98c150", 28 | "metadata": {}, 29 | "source": [ 30 | "## Softmax" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "90f42046", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import numpy as np\n", 41 | "\n", 42 | "def 
softmax(x, dim):\n", 43 | " exp_x=np.exp(x-np.max(x, axis=dim, keepdims=True)) # 减去最大值防止数值溢出\n", 44 | " return exp_x / np.sum(exp_x, axis=dim, keepdims=True)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "3d4dbccf", 50 | "metadata": {}, 51 | "source": [ 52 | "## SiLU (Swish)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "007abd08", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "import numpy as np\n", 63 | "\n", 64 | "def silu(x): \n", 65 | " return x / (1+np.exp(-x)) # sigmoid(x) = 1 / (1 + np.exp(-x))" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "gaia", 72 | "language": "python", 73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "name": "python", 77 | "version": "3.11.13" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 5 82 | } 83 | -------------------------------------------------------------------------------- /Functional/InfoNCE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c1a5891b", 6 | "metadata": {}, 7 | "source": [ 8 | "# InfoNCE Loss\n", 9 | "\n", 10 | "$$\n", 11 | "\\mathcal{L}_{info} = -\\frac{1}{N}\\sum_{i=1}^N \\log\\frac{\\exp(sim(x_{i}, x_i^+)/\\tau)}{\\sum_{k=1}^{K} \\exp(sim(x_i, x_k^{-})/\\tau)}\n", 12 | "$$" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "44b0a3a6", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import torch\n", 23 | "import torch.nn.functional as F\n", 24 | "\n", 25 | "def info_nce_loss(pairs, temperature=0.07):\n", 26 | " \"\"\"\n", 27 | " InfoNCE损失函数,使用in-batch样本作为negative sample\n", 28 | " Args:\n", 29 | " pairs (torch.Tensor): 输入样本对, shape: [batch_size, 2, feature_dim]\n", 30 | " temperature (float): 温度系数\n", 31 | " \"\"\"\n", 32 | " z1 = F.normalize(pairs[:, 0], p=2, dim=1) # L2归一化后余弦相似度等价于点积\n", 33 | " z2 = F.normalize(pairs[:, 1], p=2, dim=1) \n", 34 | " \n", 35 | " sim_matrix = z1 @ z2.T / temperature # 计算相似度矩阵\n", 36 | " pos_sim = sim_matrix.diagonal() \n", 37 | " total_sim = torch.logsumexp(sim_matrix, dim=1) # 分母:所有样本的相似度(包括正样本和负样本)\n", 38 | " \n", 39 | " # -log(exp(pos_sim) / exp(total_sim))\n", 40 | " loss = -pos_sim + total_sim\n", 41 | " \n", 42 | " return loss.mean()" 43 | ] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "gaia", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 3 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython3", 62 | "version": "3.11.13" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 5 67 | } 68 | -------------------------------------------------------------------------------- /Components/LoRA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ade7cd3b", 6 | "metadata": {}, 7 | "source": [ 8 | "# LoRA Linear Layer" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "c2bedc29", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import math\n", 21 | "\n", 22 | "class LoraLinear(nn.Module):\n", 23 | " def __init__(self, in_dim, out_dim, r, alpha,bias=True):\n", 24 | " 
super().__init__()\n", 25 | " self.in_dim = in_dim\n", 26 | " self.out_dim = out_dim\n", 27 | " self.r = r\n", 28 | " self.alpha = alpha \n", 29 | " self.scale = self.alpha / self.r\n", 30 | " \n", 31 | " self.linear = nn.Linear(in_dim, out_dim, bias=bias)\n", 32 | " self.lora_a = nn.Linear(in_dim, r, bias=False)\n", 33 | " self.lora_b = nn.Linear(r, out_dim, bias=False)\n", 34 | " self._init_weights()\n", 35 | " \n", 36 | " # 冻结原始权重\n", 37 | " self.linear.weight.requires_grad = False\n", 38 | " if self.linear.bias is not None:\n", 39 | " self.linear.bias.requires_grad = False\n", 40 | "\n", 41 | " def _init_weights(self):\n", 42 | " nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))\n", 43 | " nn.init.zeros_(self.lora_b.weight)\n", 44 | "\n", 45 | " def forward(self, x):\n", 46 | " original_output = self.linear(x)\n", 47 | " lora_output = self.lora_b(self.lora_a(x)) * self.scale\n", 48 | " return original_output + lora_output" 49 | ] 50 | } 51 | ], 52 | "metadata": { 53 | "kernelspec": { 54 | "display_name": "gaia", 55 | "language": "python", 56 | "name": "python3" 57 | }, 58 | "language_info": { 59 | "name": "python", 60 | "version": "3.11.13" 61 | } 62 | }, 63 | "nbformat": 4, 64 | "nbformat_minor": 5 65 | } 66 | -------------------------------------------------------------------------------- /Functional/sample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e67f6900", 6 | "metadata": {}, 7 | "source": [ 8 | "# 采样\n", 9 | "\n", 10 | "采样顺序需按照:temperature -> top-k -> top-p" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 7, 16 | "id": "0fa592b9", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "\n", 22 | "def sample(logits, greedy=False, temperature=1.0, top_k=0, top_p=0.0):\n", 23 | " \"\"\"\n", 24 | " logits: [batch_size, vocab_size] # 简化为单步采样\n", 25 | " \"\"\"\n", 26 | " if temperature == 0 or greedy: # 贪婪采样\n", 27 | " return torch.argmax(logits, dim=-1).unsqueeze(-1) # [batch_size, 1]\n", 28 | "\n", 29 | " if temperature > 0:\n", 30 | " logits = logits / temperature\n", 31 | "\n", 32 | " if top_k > 0:\n", 33 | " values, _ = torch.topk(logits, top_k) # [batch_size, top_k]\n", 34 | " min_values = values[:, -1].unsqueeze(-1) # [batch_size, 1]\n", 35 | " # 需要将topk logits散布回原来的位置,保持形状不变,方便后续的multinomial\n", 36 | " logits = torch.where(logits < min_values, torch.full_like(logits, -float(\"inf\")), logits)\n", 37 | "\n", 38 | " if 0 < top_p < 1:\n", 39 | " sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n", 40 | " probs = torch.softmax(sorted_logits, dim=-1)\n", 41 | " cumprobs = torch.cumsum(probs, dim=-1)\n", 42 | "\n", 43 | " mask = cumprobs > top_p\n", 44 | " mask[:, 1:] = mask[:, :-1].clone() # 将mask右移一位,表示当前位置之前的累积prob是否大于top_p\n", 45 | " mask[:, 0] = False\n", 46 | "\n", 47 | " sorted_logits[mask] = -float(\"inf\")\n", 48 | " logits = torch.full_like(logits, -float(\"inf\")).scatter(-1, sorted_indices, sorted_logits)\n", 49 | "\n", 50 | " probs = torch.softmax(logits, dim=-1)\n", 51 | " next_token_id = torch.multinomial(probs, num_samples=1) # 根据prob进行随机抽样\n", 52 | " return next_token_id" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "gaia", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": 
"text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.11.13" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 5 77 | } 78 | -------------------------------------------------------------------------------- /Components/RoPE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1defb7bf", 6 | "metadata": {}, 7 | "source": [ 8 | "# RoPE\n", 9 | "\n", 10 | "$$\n", 11 | "x=[x^{(0)},x^{(1)},...,x^{(|D|-1)}]\n", 12 | "$$\n", 13 | "$$\n", 14 | "f_{rope}([x^{(2d)},x^{(2d+1)}]^T)=\\begin{pmatrix} \\cos m\\theta_d & -\\sin m\\theta_d \\\\ \\sin m \\theta_d & \\cos m \\theta_d \\end{pmatrix}\\begin{pmatrix} x^{(2d)} \\\\ x^{(2d+1)} \\end{pmatrix}\n", 15 | "$$" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "4022f363", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import torch\n", 26 | "from torch import nn\n", 27 | "\n", 28 | "class RoPEEmbedding(nn.Module):\n", 29 | " def __init__(self, head_dim, max_seq_len, base=10000):\n", 30 | " super().__init__()\n", 31 | " assert head_dim % 2==0, \"维度必须为偶数\"\n", 32 | "\n", 33 | " self.head_dim=head_dim\n", 34 | " self.max_seq_len=max_seq_len\n", 35 | " self.base=base\n", 36 | "\n", 37 | " # 计算 theta = 1 / (base^(2i / head_dim))\n", 38 | " theta=1.0 / (base**(torch.arange(0, head_dim, 2).float() / head_dim)) \n", 39 | " \n", 40 | " pos_ids=torch.arange(max_seq_len).float()\n", 41 | " freqs=torch.outer(pos_ids, theta) # [max_seq_len, head_dim/2]\n", 42 | " sin = torch.sin(freqs)\n", 43 | " cos = torch.cos(freqs)\n", 44 | " self.register_buffer('sin_table', sin) # [max_seq_len, head_dim/2]\n", 45 | " self.register_buffer('cos_table', cos) # [max_seq_len, head_dim/2]\n", 46 | "\n", 47 | " def forward(self, x, offset=0):\n", 48 | " _, _, seq_len, _=x.shape # [batch_size, num_heads, seq_len, head_dim]\n", 49 | "\n", 50 | " sin=self.sin_table[offset:seq_len+offset]\n", 51 | " cos=self.cos_table[offset:seq_len+offset]\n", 52 | "\n", 53 | " x1=x[..., 0::2] # [batch_size, num_heads, seq_len, head_dim//2]\n", 54 | " x2=x[..., 1::2]\n", 55 | " rotated_x1=x1*cos - x2*sin\n", 56 | " rotated_x2=x2*cos + x1*sin\n", 57 | " # 使用 stack 和 flatten/reshape 来高效地交错合并\n", 58 | " # 1. 堆叠: [batch_size, num_heads, seq_len, head_dim / 2, 2]\n", 59 | " # 2. 
展平: [batch_size, num_heads, seq_len, head_dim] \n", 60 | " rotated_x = torch.stack((rotated_x1, rotated_x2), dim=-1).flatten(-2)\n", 61 | " return rotated_x" 62 | ] 63 | } 64 | ], 65 | "metadata": { 66 | "kernelspec": { 67 | "display_name": "gaia", 68 | "language": "python", 69 | "name": "python3" 70 | }, 71 | "language_info": { 72 | "name": "python", 73 | "version": "3.11.13" 74 | } 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 5 78 | } 79 | -------------------------------------------------------------------------------- /Attention/MHA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "16f2b896", 6 | "metadata": {}, 7 | "source": [ 8 | "# MHA" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "974bffb3", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "from xxx import RoPEEmbedding # 假设的RoPE模块\n", 21 | "\n", 22 | "\n", 23 | "class MultiHeadAttention(nn.Module):\n", 24 | " def __init__(self, hidden_dim, num_heads, max_seq_len, dropout=0.1):\n", 25 | " super().__init__()\n", 26 | " assert hidden_dim % num_heads == 0, \"hidden_dim must be divisible by num_heads\"\n", 27 | "\n", 28 | " self.hidden_dim = hidden_dim\n", 29 | " self.num_heads = num_heads\n", 30 | " self.head_dim = hidden_dim // num_heads\n", 31 | " self.scale=self.head_dim ** -0.5\n", 32 | " self.max_seq_len = max_seq_len\n", 33 | "\n", 34 | " self.q_proj = nn.Linear(hidden_dim, hidden_dim)\n", 35 | " self.k_proj = nn.Linear(hidden_dim, hidden_dim)\n", 36 | " self.v_proj = nn.Linear(hidden_dim, hidden_dim)\n", 37 | " self.o_proj = nn.Linear(hidden_dim, hidden_dim)\n", 38 | " self.dropout = nn.Dropout(dropout)\n", 39 | " self.rope = RoPEEmbedding(self.head_dim, max_seq_len)\n", 40 | " def forward(self, x, mask=None):\n", 41 | " batch_size = x.shape[0]\n", 42 | "\n", 43 | " Q = self.q_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # (batch, num_heads, seq_len, head_dim)\n", 44 | " K = self.k_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n", 45 | " V = self.v_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)\n", 46 | " \n", 47 | " Q = self.rope(Q)\n", 48 | " K = self.rope(K)\n", 49 | "\n", 50 | " attn_scores = Q @ K.transpose(-2, -1) * self.scale\n", 51 | " if mask is not None:\n", 52 | " attn_scores = attn_scores.masked_fill(mask, float('-inf'))\n", 53 | " attn_scores = torch.softmax(attn_scores, dim=-1)\n", 54 | " attn_scores = self.dropout(attn_scores)\n", 55 | "\n", 56 | " output = (attn_scores @ V).transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)\n", 57 | " output = self.o_proj(output)\n", 58 | " return output, attn_scores" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "kernelspec": { 64 | "display_name": "gaia", 65 | "language": "python", 66 | "name": "python3" 67 | }, 68 | "language_info": { 69 | "name": "python", 70 | "version": "3.11.13" 71 | } 72 | }, 73 | "nbformat": 4, 74 | "nbformat_minor": 5 75 | } 76 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # LLM Interview Code 2 | 3 | **A collection of hand-written (from-scratch) code snippets commonly asked in LLM interviews** 4 | 5 | > ps: Across dozens of interviews so far, I have only been asked to hand-write `MHA`, `RoPE`, `RMSNorm`, `BPE`, `InfoNCE`, and `DPO`. If this repo helps, please give it a star ⭐️~ 6 | 7 | ## Project Structure 8 | 
| Directory | File | Description |
|---|---|---|
| Attention | MHA.ipynb | Multi-Head Attention |
|  | GQA.ipynb | Grouped Query Attention |
|  | MHA_kvcache.ipynb | Attention with KV cache |
|  | mask.ipynb | Attention mask |
| Components | Linear.ipynb | Linear layer |
|  | BPE.ipynb | Byte Pair Encoding |
|  | LoRA.ipynb | LoRA linear layer |
|  | RoPE.ipynb | Rotary position embedding |
|  | SwiGLU.ipynb | SwiGLU activation |
| Norm | LayerNorm.ipynb | Layer normalization |
|  | RMSNorm.ipynb | RMS normalization |
| Functional | activation_fun.ipynb | Activation functions |
|  | CE.ipynb | Cross-entropy loss |
|  | InfoNCE.ipynb | InfoNCE loss |
|  | quantize.ipynb | Quantization |
|  | sft.ipynb | SFT loss |
| RL | DPO.ipynb | DPO loss |
|  | GRPO.ipynb | GRPO loss |
|  | GSPO.ipynb | GSPO loss |
|  | KL.ipynb | KL divergence |
|  | PPO.ipynb | PPO loss (see the sketch below) |
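`PPO.ipynb` and `GSPO.ipynb` are still empty placeholders in this repo. As a rough reference only (not code from these notebooks), a minimal sketch of the PPO clipped policy loss in the same style as `dpo_loss` and `compute_kl` might look like the following; the function name, arguments, and the default `clip_eps` are assumptions.

```python
import torch

def ppo_clip_loss(logp, old_logp, advantages, clip_eps=0.2):
    # Importance ratio pi_theta / pi_old, recovered from token/sequence log-probs
    ratio = torch.exp(logp - old_logp)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the clipped surrogate objective, so the loss is its negative mean
    return -torch.min(unclipped, clipped).mean()
```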