├── LICENSE
├── README.md
├── assets
│   ├── loss.png
│   └── shakespeare_char
│       ├── input.txt
│       └── meta.pkl
├── data
│   ├── shakespeare.py
│   └── tokenizers.py
├── env.sh
├── models
│   ├── gpt2.py
│   ├── llama.py
│   └── rope.py
├── sample.py
├── train.py
└── train_accelerate.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2025 CUHK (Qiuqiang Kong)
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minimal PyTorch implementation of GPT2 and Llama
2 | 
3 | This repository provides a minimal PyTorch implementation of GPT-2 and LLaMA, simplified for easier understanding and usage. It trains a character-level language model on about 1 million characters of text and converges in less than 5 minutes.
4 | 
5 | ## 0. Install dependencies
6 | 
7 | ```bash
8 | # Clone the repo
9 | git clone https://github.com/qiuqiangkong/mini_llm
10 | cd mini_llm
11 | 
12 | # Install Python environment
13 | conda create --name llm python=3.10
14 | 
15 | # Activate environment
16 | conda activate llm
17 | 
18 | # Install Python package dependencies
19 | bash env.sh
20 | ```
21 | 
22 | ## 1. Train
23 | 
24 | ```bash
25 | CUDA_VISIBLE_DEVICES=0 python train.py --model_name=Llama
26 | ```
27 | 
28 | We train the language model on the Shakespeare dataset, which contains about 1 million characters. Training for 10,000 steps takes around 20 minutes on a single RTX 4090.
29 | 
30 | ![Training & Validation Loss](assets/loss.png)
31 | 
32 | ### Train on multiple GPUs
33 | 
34 | We use the Hugging Face Accelerate library to train on multiple GPUs. train_accelerate.py only adds a few lines to train.py. Here is an example of running with 4 GPUs:
35 | 
36 | ```bash
37 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes 4 train_accelerate.py --model_name=Llama
38 | ```
39 | 
40 | This speeds up training by roughly 4x. The code can also be trained on multiple nodes, e.g., 32 GPUs across 4 nodes.
41 | 
42 | ## 2. Sample
43 | 
44 | ```bash
45 | CUDA_VISIBLE_DEVICES=0 python sample.py --model_name=Llama --ckpt_path="checkpoints/train/Llama/step=10000.pth"
46 | ```
47 | 
48 | The sampled text looks like:
49 | 
50 | 
51 | We may! though a bald prove. We three, I say! What                    
52 | must I see so, most heart?
53 | 
54 | Servant:
55 | He hath ribbons of an the city, which he main for her
56 | voices of the same winder. What say you to yours?
57 | 
58 | Provost:
59 | It was commanded so willingly I do at ever.
60 | So fortune
61 | 
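
You can also sample from Python instead of the command line. The following sketch mirrors sample.py (the checkpoint path and sampling parameters are only examples; adjust them to your own run):

```python
import torch

from data.tokenizers import TokenizerChar
from train import get_model

device = "cuda"

# Character-level tokenizer built from the dataset's meta file
tokenizer = TokenizerChar(meta_path="assets/shakespeare_char/meta.pkl")

# Build the model and load a trained checkpoint (example path)
model = get_model(model_name="Llama", vocab_size=len(tokenizer))
model.load_state_dict(torch.load("checkpoints/train/Llama/step=10000.pth"))
model.to(device)
model.eval()

# Start from a newline character and sample 256 new IDs
input_ids = torch.LongTensor([[tokenizer.stoi("\n")]]).to(device)  # shape: (1, 1)

with torch.no_grad():
    new_ids = model.generate(ids=input_ids, max_new_ids=256, temperature=1.0, top_k=200)  # shape: (1, 256)

print("".join(tokenizer.itos(i) for i in new_ids[0].cpu().numpy()))
```
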
62 | 63 | ## External links 64 | 65 | This repo is benefited from the following repos. 66 | 67 | NanoGPT: https://github.com/karpathy/nanoGPT 68 | 69 | Lit-Llama: https://github.com/Lightning-AI/lit-llama 70 | 71 | ## License 72 | 73 | MIT -------------------------------------------------------------------------------- /assets/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiuqiangkong/mini_llm/6913ca8553c1acddcbe1e42178ac222a042ddaeb/assets/loss.png -------------------------------------------------------------------------------- /assets/shakespeare_char/meta.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiuqiangkong/mini_llm/6913ca8553c1acddcbe1e42178ac222a042ddaeb/assets/shakespeare_char/meta.pkl -------------------------------------------------------------------------------- /data/shakespeare.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from typing_extensions import Literal 8 | 9 | 10 | class ShakespeareChar(Dataset): 11 | r"""Shakespear dataset with plain texts. Size: 1 MB.""" 12 | 13 | def __init__(self, 14 | text_path: str = "input.txt", 15 | tokenizer: object = None, 16 | split: Literal["train", "test"] = "train", 17 | seq_len: int = 256, 18 | ): 19 | super().__init__() 20 | 21 | self.seq_len = seq_len 22 | 23 | # Load all texts 24 | self.ids = load_text_to_ids(text_path=text_path, tokenizer=tokenizer, split=split) 25 | 26 | def __getitem__(self, index: int) -> dict: 27 | r"""Fetch a clip of IDs for training. The `index` argument is not used 28 | because we use only one book for training.""" 29 | 30 | # Randomly sample a position in the book 31 | idx = random.randint(0, len(self.ids) - self.seq_len - 1) 32 | 33 | data = { 34 | "id": self.ids[idx : idx + self.seq_len + 1] # shape: (seq_len + 1,) 35 | } 36 | 37 | return data 38 | 39 | def __len__(self): 40 | 41 | # We call 1000 steps as an `epoch` 42 | return 1000 43 | 44 | 45 | def load_text_to_ids( 46 | text_path: str, 47 | tokenizer: object, 48 | split: Literal["train", "test"] 49 | ) -> np.ndarray: 50 | r"""Load a text file and convert characters to tokens.""" 51 | 52 | # Load texts 53 | with open(text_path, 'r') as file: 54 | text = file.read() 55 | 56 | # Convert texts to token IDs 57 | ids = np.array([tokenizer.stoi(char) for char in text]) 58 | 59 | if split == "train": 60 | ids = ids[0 : 1003854] # Consistent with nanoGPT 61 | 62 | elif split == "test": 63 | ids = ids[1003854 :] # Consistent with nanoGPT 64 | 65 | else: 66 | raise ValueError(split) 67 | 68 | return ids -------------------------------------------------------------------------------- /data/tokenizers.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | class TokenizerChar: 5 | def __init__(self, meta_path: str) -> None: 6 | 7 | with open(meta_path, 'rb') as f: 8 | self.meta = pickle.load(f) 9 | 10 | def stoi(self, token: str) -> int: 11 | r"""E.g., 'a' -> 39. 12 | """ 13 | return self.meta["stoi"][token] 14 | 15 | def itos(self, id: int) -> str: 16 | r"""E.g., 39 -> 'a'. 
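Inverse of stoi(): maps a token ID back to its character.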
17 | """ 18 | return self.meta["itos"][id] 19 | 20 | def __len__(self) -> int: 21 | return self.meta["vocab_size"] -------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | pip install numpy==2.2.1 2 | pip install torch==2.5.1 3 | pip install tqdm==4.67.1 4 | pip install wandb==0.19.1 5 | pip install accelerate==1.2.1 -------------------------------------------------------------------------------- /models/gpt2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/karpathy/nanoGPT/blob/master/model.py 3 | """ 4 | from __future__ import annotations 5 | import math 6 | import inspect 7 | from dataclasses import dataclass 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | @dataclass 15 | class GPTConfig: 16 | block_size: int = 1024 17 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 18 | n_layer: int = 12 19 | n_head: int = 12 20 | n_embd: int = 768 21 | dropout: float = 0.0 22 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 23 | 24 | 25 | class GPT2(nn.Module): 26 | def __init__(self, config: GPTConfig): 27 | r"""GPT2. Modified from https://github.com/karpathy/nanoGPT/blob/master/model.py""" 28 | 29 | super().__init__() 30 | 31 | self.config = config 32 | 33 | # Word to embedding (wte) and word position embedding (wpe) 34 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 35 | self.wpe = nn.Embedding(config.block_size, config.n_embd) 36 | self.drop = nn.Dropout(config.dropout) 37 | 38 | # Transformer blocks 39 | self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) 40 | 41 | # Output layers 42 | self.ln_f = LayerNorm(config.n_embd, bias=config.bias) 43 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 44 | 45 | # Bind weights 46 | self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 47 | 48 | # init all weights 49 | self.apply(self._init_weights) 50 | 51 | # apply special scaled init to the residual projections, per GPT-2 paper 52 | for pn, p in self.named_parameters(): 53 | if pn.endswith('c_proj.weight'): 54 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 55 | 56 | def _init_weights(self, module): 57 | if isinstance(module, nn.Linear): 58 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 59 | if module.bias is not None: 60 | torch.nn.init.zeros_(module.bias) 61 | elif isinstance(module, nn.Embedding): 62 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 63 | 64 | def forward(self, ids: torch.LongTensor) -> torch.LongTensor: 65 | r"""Next id prediction with GPT2. 
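The logits at position t predict the ID at position t+1; during training the targets are simply the input IDs shifted by one position (see train.py).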
66 | 67 | b: batch_size 68 | t: time_steps 69 | d: hidden_size 70 | v: vocab_size 71 | 72 | Args: 73 | ids: (b, t) 74 | 75 | Outputs: 76 | logits: (b, t, v) 77 | """ 78 | 79 | device = ids.device 80 | B, T = ids.shape 81 | 82 | assert T <= self.config.block_size, "Can not forward sequence of {T} > {self.config.block_size}" 83 | 84 | # Absolute positions 85 | pos = torch.arange(0, T, dtype=torch.long, device=device) # shape: (t,) 86 | 87 | # ID embedding and position embedding 88 | id_emb = self.wte(ids) # shape: (b, t, d) 89 | pos_emb = self.wpe(pos) # shape: (t, d) 90 | x = self.drop(id_emb + pos_emb) # shape; (b, t, d) 91 | 92 | # Transformer 93 | for block in self.blocks: 94 | x = block(x) 95 | # x: (b, t, d) 96 | 97 | # Output layers 98 | x = self.ln_f(x) # shape: (b, t, d) 99 | logits = self.lm_head(x) # shape: (b, t, v) 100 | 101 | return logits 102 | 103 | @torch.no_grad() 104 | def generate( 105 | self, 106 | ids: torch.LongTensor, 107 | max_new_ids: int, 108 | temperature: float = 1.0, 109 | top_k: None | int = None 110 | ): 111 | r"""Next ID sampling with auto-regression. Make sure to use model.eval() 112 | 113 | b: batch_size 114 | t: time_steps 115 | v: vocab_size 116 | 117 | Args: 118 | ids: (b, 1) 119 | max_new_ids: int 120 | temperature: float 121 | top_k: None | int 122 | 123 | Returns: 124 | new_ids: (b, t), sampled IDs 125 | """ 126 | input_len = ids.shape[1] 127 | 128 | for _ in range(max_new_ids): 129 | 130 | # If the sequence context is growing too long we must crop it at block_size 131 | if ids.shape[1] <= self.config.block_size: 132 | prev_ids = ids 133 | else: 134 | prev_ids = ids[:, -self.config.block_size:] 135 | 136 | # Forward 137 | logits = self(prev_ids) # shape: (b, t, v) 138 | 139 | # Take the final step logits 140 | logits = logits[:, -1, :] / temperature # shape: (b, v) 141 | 142 | # Crop the logits to only the top k options 143 | if top_k is not None: 144 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 145 | logits[logits < v[:, [-1]]] = -float('Inf') 146 | 147 | # Convert logits to probabilities 148 | probs = F.softmax(logits, dim=-1) # shape: (b, v) 149 | 150 | # Sample the next ID 151 | next_id = torch.multinomial(probs, num_samples=1) # shape: (b, 1) 152 | 153 | # Append the sampled ID to the running IDs and continue 154 | ids = torch.cat((ids, next_id), dim=1) # shape: (b, t) 155 | 156 | new_ids = ids[:, input_len:] # shape: (b, t) 157 | 158 | return new_ids 159 | 160 | 161 | class LayerNorm(nn.Module): 162 | """ LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False """ 163 | 164 | def __init__(self, ndim: int, bias: bool): 165 | super().__init__() 166 | self.weight = nn.Parameter(torch.ones(ndim)) 167 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 168 | 169 | def forward(self, input): 170 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 171 | 172 | 173 | class CausalSelfAttention(nn.Module): 174 | 175 | def __init__(self, config: GPTConfig): 176 | super().__init__() 177 | 178 | assert config.n_embd % config.n_head == 0 179 | 180 | # key, query, value projections for all heads, but in a batch 181 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 182 | 183 | # output projection 184 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 185 | 186 | # regularization 187 | self.attn_dropout = nn.Dropout(config.dropout) 188 | self.resid_dropout = nn.Dropout(config.dropout) 189 | self.n_head = config.n_head 190 | self.n_embd = config.n_embd 191 | self.dropout = config.dropout 192 | 193 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 194 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 195 | 196 | if not self.flash: 197 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 198 | # causal mask to ensure that attention is only applied to the left in the input sequence 199 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 200 | .view(1, 1, config.block_size, config.block_size)) 201 | 202 | def forward(self, x): 203 | r"""Causal self attention. 204 | 205 | b: batch size 206 | t: time steps 207 | d: latent dim 208 | h: heads num 209 | 210 | Args: 211 | x: (b, t, d) 212 | 213 | Outputs: 214 | x: (b, t, d) 215 | """ 216 | 217 | B, T, D = x.shape 218 | 219 | # Calculate query, key, values 220 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 221 | # q, k, v shapes: (b, t, d) 222 | 223 | k = k.view(B, T, self.n_head, D // self.n_head).transpose(1, 2) 224 | q = q.view(B, T, self.n_head, D // self.n_head).transpose(1, 2) 225 | v = v.view(B, T, self.n_head, D // self.n_head).transpose(1, 2) 226 | # q, k, v shapes: (b, t, h, d/h) 227 | 228 | # Causal self-attention 229 | if self.flash: 230 | # Efficient attention using Flash Attention CUDA kernels 231 | x = torch.nn.functional.scaled_dot_product_attention( 232 | query=q, 233 | key=k, 234 | value=v, 235 | attn_mask=None, 236 | dropout_p=self.dropout if self.training else 0, 237 | is_causal=True 238 | ) 239 | # shape: (b, h, t, d/h) 240 | else: 241 | # manual implementation of attention 242 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # shape: (b, h, t, t) 243 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) # shape: (b, h, t, t) 244 | att = F.softmax(att, dim=-1) # shape: (b, h, t, t) 245 | att = self.attn_dropout(att) # shape: (b, h, t, t) 246 | x = att @ v # shape: (b, h, t, d/h) 247 | 248 | x = x.transpose(1, 2).contiguous().view(B, T, D) # shape: (b, t, d) 249 | 250 | # output projection 251 | x = self.resid_dropout(self.c_proj(x)) # shape: (b, t, d) 252 | 253 | return x 254 | 255 | class MLP(nn.Module): 256 | 257 | def __init__(self, config: GPTConfig): 258 | super().__init__() 259 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 260 | self.gelu = nn.GELU() 261 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 262 | self.dropout = nn.Dropout(config.dropout) 263 | 264 
| def forward(self, x): 265 | r"""MLP. 266 | 267 | Args: 268 | x: (b, t, d) 269 | 270 | Outputs: 271 | x: (b, t, d) 272 | """ 273 | x = self.c_fc(x) 274 | x = self.gelu(x) 275 | x = self.c_proj(x) 276 | x = self.dropout(x) 277 | return x 278 | 279 | class Block(nn.Module): 280 | 281 | def __init__(self, config: GPTConfig): 282 | super().__init__() 283 | self.att_norm = LayerNorm(config.n_embd, bias=config.bias) 284 | self.att = CausalSelfAttention(config) 285 | self.ffn_norm = LayerNorm(config.n_embd, bias=config.bias) 286 | self.mlp = MLP(config) 287 | 288 | def forward(self, x): 289 | r"""MLP. 290 | 291 | Args: 292 | x: (b, t, d) 293 | 294 | Outputs: 295 | x: (b, t, d) 296 | """ 297 | x = x + self.att(self.att_norm(x)) 298 | x = x + self.mlp(self.ffn_norm(x)) 299 | return x -------------------------------------------------------------------------------- /models/llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py 3 | """ 4 | import math 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from models.rope import build_rope, apply_rope 12 | 13 | 14 | @dataclass 15 | class LlamaConfig: 16 | block_size: int = 2048 17 | vocab_size: int = 32000 # Better to be divied by 64 18 | n_layer: int = 32 19 | n_head: int = 32 20 | n_embd: int = 4096 21 | 22 | 23 | # Default Llama configurations 24 | llama_configs = { 25 | "7B": dict(n_layer=32, n_head=32, n_embd=4096), 26 | "13B": dict(n_layer=40, n_head=40, n_embd=5120), 27 | "30B": dict(n_layer=60, n_head=52, n_embd=6656), 28 | "65B": dict(n_layer=80, n_head=64, n_embd=8192), 29 | } 30 | 31 | 32 | class Llama(nn.Module): 33 | r"""Llama model.""" 34 | 35 | def __init__(self, config: LlamaConfig) -> None: 36 | super().__init__() 37 | 38 | self.config = config 39 | 40 | # Word to embedding 41 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 42 | 43 | # Transformer blocks 44 | self.blocks = nn.ModuleList(Block(config) for _ in range(config.n_layer)) 45 | 46 | # Output layers 47 | self.ln_f = RMSNorm(config.n_embd) 48 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 49 | 50 | # Build RoPE cache 51 | rope = build_rope( 52 | seq_len=config.block_size, 53 | head_dim=config.n_embd // config.n_head, 54 | ) # shape: (t, head_dim/2, 2) 55 | self.register_buffer(name="rope", tensor=rope) 56 | 57 | def _init_weights(self, module: nn.Module) -> None: 58 | if isinstance(module, nn.Linear): 59 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 60 | elif isinstance(module, nn.Embedding): 61 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 62 | 63 | def forward( 64 | self, 65 | ids: torch.LongTensor, 66 | mask: None | torch.Tensor = None, 67 | ) -> torch.Tensor: 68 | r"""Next ID prediction with Llama. 
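Unlike GPT-2, no positional embedding is added here: positions are injected via RoPE inside each attention block, and a causal mask is built on the fly when `mask` is None.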
69 | 70 | b: batch_size 71 | t: time_steps 72 | d: hidden_size 73 | v: vocab_size 74 | 75 | Args: 76 | IDs: (b, t) 77 | mask: None | (1, 1, t, t) 78 | 79 | Outputs: 80 | logits: (b, t, v) 81 | """ 82 | 83 | device = ids.device 84 | B, T = ids.shape 85 | 86 | assert T <= self.config.block_size, "Can not forward sequence of {T} > {self.config.block_size}" 87 | 88 | if mask is None: 89 | mask = build_causal_mask(seq_len=T).to(device) 90 | 91 | # IDs embedding 92 | x = self.wte(ids) # shape: (b, t, d) 93 | 94 | # Transformer 95 | for block in self.blocks: 96 | x = block(x, self.rope, mask) 97 | # x: (b, t, d) 98 | 99 | # Output layers 100 | x = self.ln_f(x) # shape: (b, t, d) 101 | logits = self.lm_head(x) # shape: (b, t, v) 102 | 103 | return logits 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | ids: torch.LongTensor, 109 | max_new_ids: int, 110 | temperature: float = 1.0, 111 | top_k: None | int = None 112 | ): 113 | r"""Next ID sampling with auto-regression. Make sure to use model.eval() 114 | 115 | b: batch_size 116 | t: time_steps 117 | v: vocab_size 118 | 119 | Args: 120 | ids: (b, 1) 121 | max_new_ids: int 122 | temperature: float 123 | top_k: None | int 124 | 125 | Returns: 126 | new_ids: (b, t), sampled IDs 127 | """ 128 | input_len = ids.shape[1] 129 | 130 | for _ in range(max_new_ids): 131 | 132 | # If the sequence context is growing too long we must crop it at block_size 133 | if ids.shape[1] <= self.config.block_size: 134 | prev_ids = ids 135 | else: 136 | prev_ids = ids[:, -self.config.block_size:] 137 | 138 | # Forward 139 | logits = self(prev_ids) # shape: (b, t, v) 140 | 141 | # Take the final step logits 142 | logits = logits[:, -1, :] / temperature # shape: (b, v) 143 | 144 | # Crop the logits to only the top k options 145 | if top_k is not None: 146 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 147 | logits[logits < v[:, [-1]]] = -float('Inf') 148 | 149 | # Convert logits to probabilities 150 | probs = F.softmax(logits, dim=-1) # shape: (b, v) 151 | 152 | # Sample the next ID 153 | next_id = torch.multinomial(probs, num_samples=1) # shape: (b, 1) 154 | 155 | # Append the sampled ID to the running IDs and continue 156 | ids = torch.cat((ids, next_id), dim=1) # shape: (b, t) 157 | 158 | new_ids = ids[:, input_len:] # shape: (b, t) 159 | 160 | return new_ids 161 | 162 | 163 | class Block(nn.Module): 164 | def __init__(self, config: LlamaConfig) -> None: 165 | super().__init__() 166 | self.att_norm = RMSNorm(config.n_embd) 167 | self.att = CausalSelfAttention(config) 168 | self.ffn_norm = RMSNorm(config.n_embd) 169 | self.mlp = MLP(config) 170 | 171 | def forward( 172 | self, 173 | x: torch.Tensor, 174 | rope: torch.Tensor, 175 | mask: torch.Tensor, 176 | ) -> torch.Tensor: 177 | r""" 178 | 179 | Args: 180 | x: (b, t, d) 181 | rope: (t, head_dim/2) 182 | mask: (1, 1, t, t) 183 | 184 | Outputs: 185 | x: (b, t, d) 186 | """ 187 | x = x + self.att(self.att_norm(x), rope, mask) 188 | x = x + self.mlp(self.ffn_norm(x)) 189 | return x 190 | 191 | 192 | class RMSNorm(nn.Module): 193 | r"""Root Mean Square Layer Normalization. 194 | 195 | Ref: https://github.com/meta-llama/llama/blob/main/llama/model.py 196 | """ 197 | def __init__(self, dim: int, eps: float = 1e-6): 198 | 199 | super().__init__() 200 | self.eps = eps 201 | self.scale = nn.Parameter(torch.ones(dim)) 202 | 203 | def forward(self, x): 204 | r"""RMSNorm. 
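Scales each feature by the inverse root mean square of the last dimension: x * rsqrt(mean(x^2) + eps) * scale, with no mean subtraction or learned bias (unlike LayerNorm).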
205 | 206 | Args: 207 | x: (b, t, d) 208 | 209 | Outputs: 210 | x: (b, t, d) 211 | """ 212 | norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) 213 | output = x * torch.rsqrt(norm_x + self.eps) * self.scale 214 | return output 215 | 216 | 217 | class CausalSelfAttention(nn.Module): 218 | def __init__(self, config: LlamaConfig) -> None: 219 | super().__init__() 220 | assert config.n_embd % config.n_head == 0 221 | 222 | # key, query, value projections for all heads, but in a batch 223 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) 224 | 225 | # output projection 226 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) 227 | 228 | self.n_head = config.n_head 229 | self.n_embd = config.n_embd 230 | self.block_size = config.block_size 231 | 232 | def forward( 233 | self, 234 | x: torch.Tensor, 235 | rope: torch.Tensor, 236 | mask: torch.Tensor, 237 | ) -> torch.Tensor: 238 | r"""Causal self attention. 239 | 240 | b: batch size 241 | t: time steps 242 | d: latent dim 243 | h: heads num 244 | 245 | Args: 246 | x: (b, t, d) 247 | rope: (t, head_dim/2, 2) 248 | mask: (1, 1, ) 249 | 250 | Outputs: 251 | x: (b, t, d) 252 | """ 253 | B, T, D = x.shape 254 | 255 | # Calculate query, key, values 256 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 257 | # q, k, v shapes: (b, t, d) 258 | 259 | k = k.view(B, T, self.n_head, D // self.n_head) 260 | q = q.view(B, T, self.n_head, D // self.n_head) 261 | v = v.view(B, T, self.n_head, D // self.n_head) 262 | # q, k, v shapes: (b, t, h, head_dim) 263 | 264 | q = apply_rope(q, rope) 265 | k = apply_rope(k, rope) 266 | # q, k shapes: (b, t, h, head_dim) 267 | 268 | k = k.transpose(1, 2) 269 | q = q.transpose(1, 2) 270 | v = v.transpose(1, 2) 271 | # q, k, v shapes: (b, h, t, head_dim) 272 | 273 | # Efficient attention using Flash Attention CUDA kernels 274 | x = F.scaled_dot_product_attention( 275 | query=q, 276 | key=k, 277 | value=v, 278 | attn_mask=mask, 279 | dropout_p=0.0 280 | ) 281 | # shape: (b, h, t, head_dim) 282 | 283 | x = x.transpose(1, 2).contiguous().view(B, T, D) # shape: (b, t, d) 284 | 285 | # output projection 286 | x = self.c_proj(x) # shape: (b, t, d) 287 | 288 | return x 289 | 290 | 291 | class MLP(nn.Module): 292 | def __init__(self, config: LlamaConfig) -> None: 293 | super().__init__() 294 | 295 | # The hyper-parameters follow https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py 296 | hidden_dim = 4 * config.n_embd 297 | n_hidden = int(2 * hidden_dim / 3) 298 | 299 | self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) 300 | self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False) 301 | self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) 302 | 303 | def forward(self, x: torch.Tensor) -> torch.Tensor: 304 | r"""Causal self attention. 
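This is the SwiGLU feed-forward block used by Llama: silu(c_fc1(x)) * c_fc2(x), projected back to the model dimension by c_proj.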
305 | 306 | Args: 307 | x: (b, t, d) 308 | 309 | Outputs: 310 | x: (b, t, d) 311 | """ 312 | x = F.silu(self.c_fc1(x)) * self.c_fc2(x) 313 | x = self.c_proj(x) 314 | return x 315 | 316 | 317 | def build_causal_mask(seq_len: int) -> torch.Tensor: 318 | r"""Build causal mask.""" 319 | ones = torch.ones((seq_len, seq_len), dtype=torch.bool) # shape: (t, t) 320 | mask = torch.tril(ones)[None, None, :, :] # shape: (1, 1, t, t) 321 | return mask -------------------------------------------------------------------------------- /models/rope.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from: https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py 3 | """ 4 | import torch 5 | 6 | 7 | def build_rope( 8 | seq_len: int, head_dim: int, base: int = 10000 9 | ) -> torch.Tensor: 10 | r"""Rotary Position Embedding. 11 | Modified from: https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py 12 | 13 | Args: 14 | seq_len: int, e.g., 1024 15 | head_dim: head dim, e.g., 768/24 16 | base: int 17 | 18 | Outputs: 19 | cache: (t, head_dim/2, 2) 20 | """ 21 | 22 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim)) 23 | 24 | seq_idx = torch.arange(seq_len) 25 | 26 | # Calculate the product of position index and $\theta_i$ 27 | idx_theta = torch.outer(seq_idx, theta).float() 28 | 29 | cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) 30 | 31 | return cache 32 | 33 | 34 | def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: 35 | # truncate to support variable sizes 36 | T = x.size(1) 37 | rope_cache = rope_cache[:T] 38 | 39 | # cast because the reference does 40 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 41 | rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 42 | x_out2 = torch.stack( 43 | [ 44 | xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], 45 | xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], 46 | ], 47 | -1, 48 | ) 49 | 50 | x_out2 = x_out2.flatten(3) 51 | return x_out2.type_as(x) -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py 3 | """ 4 | from __future__ import annotations 5 | import argparse 6 | from pathlib import Path 7 | 8 | import torch 9 | 10 | from data.tokenizers import TokenizerChar 11 | from train import TokenizerChar, get_model 12 | 13 | 14 | def sample(args): 15 | 16 | # Arguments 17 | model_name = args.model_name 18 | ckpt_path = args.ckpt_path 19 | 20 | num_samples = 5 # Number of samples to draw 21 | max_new_ids = 256 # Number of IDs generated in each sample 22 | temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 23 | top_k = 200 # Retain only the top_k most likely IDs, clamp others to have 0 probability 24 | device = "cuda" 25 | 26 | # Paths 27 | root = "./assets/shakespeare_char" 28 | meta_path = Path(root, "meta.pkl") 29 | 30 | # Tokenizer 31 | tokenizer = TokenizerChar(meta_path=meta_path) 32 | 33 | # Load model 34 | model = get_model(model_name=model_name, vocab_size=len(tokenizer)) 35 | model.load_state_dict(torch.load(ckpt_path)) 36 | model.to(device) 37 | 38 | # Begin ID 39 | input_id = tokenizer.stoi("\n") # 0 40 | input_ids = torch.LongTensor([[input_id]]).to(device) # (b, 1) 41 | 42 | # Sample 43 | for n 
in range(num_samples): 44 | 45 | with torch.no_grad(): 46 | model.eval() 47 | ids = model.generate( 48 | ids=input_ids, 49 | max_new_ids=max_new_ids, 50 | temperature=temperature, 51 | top_k=top_k 52 | ) 53 | # shape: (b, t) 54 | 55 | ids = ids[0].cpu().numpy() 56 | strings = ids_to_text(ids, tokenizer) 57 | print(strings) 58 | print("------------") 59 | 60 | 61 | def ids_to_text(ids, tokenizer): 62 | return "".join([tokenizer.itos(id) for id in ids]) 63 | 64 | 65 | if __name__ == "__main__": 66 | 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model_name', type=str, required=True) 69 | parser.add_argument('--ckpt_path', type=str, required=True) 70 | args = parser.parse_args() 71 | 72 | sample(args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import pickle 5 | import random 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import wandb 14 | from torch.utils.data import DataLoader 15 | from tqdm import tqdm 16 | from typing_extensions import Literal 17 | 18 | from data.shakespeare import ShakespeareChar, load_text_to_ids 19 | from data.tokenizers import TokenizerChar 20 | 21 | 22 | def train(args): 23 | 24 | # Arguments 25 | model_name = args.model_name 26 | wandb_log = not args.no_log 27 | 28 | # Default parameters 29 | batch_size = 16 30 | num_workers = 16 31 | pin_memory = True 32 | learning_rate = 1e-4 33 | test_every_n_steps = 200 34 | save_every_n_steps = 2000 35 | training_steps = 10000 36 | seq_len = 256 37 | device = "cuda" 38 | 39 | filename = Path(__file__).stem 40 | 41 | # Paths 42 | root = "./assets/shakespeare_char" 43 | text_path = Path(root, "input.txt") 44 | meta_path = Path(root, "meta.pkl") 45 | 46 | # Checkpoints directory 47 | ckpts_dir = Path("./checkpoints", filename, model_name) 48 | Path(ckpts_dir).mkdir(parents=True, exist_ok=True) 49 | 50 | # Tokenizer 51 | tokenizer = TokenizerChar(meta_path=meta_path) 52 | 53 | # Dataset 54 | train_dataset = ShakespeareChar( 55 | text_path=text_path, 56 | tokenizer=tokenizer, 57 | split="train", 58 | seq_len=seq_len 59 | ) 60 | 61 | # Sampler 62 | train_sampler = MySampler(books_num=1) 63 | 64 | # Dataloader 65 | train_dataloader = DataLoader( 66 | dataset=train_dataset, 67 | batch_size=batch_size, 68 | sampler=train_sampler, 69 | num_workers=num_workers, 70 | pin_memory=pin_memory 71 | ) 72 | 73 | # Model 74 | model = get_model(model_name=model_name, vocab_size=len(tokenizer)) 75 | model.to(device) 76 | 77 | # Optimizer 78 | optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate) 79 | 80 | # Logger 81 | if wandb_log: 82 | wandb.init(project="mini_llm", name="{}".format(model_name)) 83 | 84 | # Train 85 | for step, data in enumerate(tqdm(train_dataloader)): 86 | 87 | # Move data to device 88 | input_ids = data["id"][:, 0 : -1].to(device) # (b, t) 89 | target_ids = data["id"][:, 1 :].to(device) # (b, t) 90 | 91 | # Forward 92 | model.train() 93 | logits = model(ids=input_ids) # shape: (b, t, vocab_size) 94 | 95 | # Loss 96 | loss = ce_loss(output=logits, target=target_ids) 97 | 98 | # Optimize 99 | optimizer.zero_grad() # Reset all parameter.grad to 0 100 | loss.backward() # Update all parameter.grad 101 | optimizer.step() # Update all parameters based on all parameter.grad 102 | 
103 | # Evaluate 104 | if step % test_every_n_steps == 0: 105 | 106 | loss_dict = {} 107 | 108 | for split in ["train", "test"]: 109 | loss = validate( 110 | text_path=text_path, 111 | tokenizer=tokenizer, 112 | split=split, 113 | model=model, 114 | seq_len=seq_len 115 | ) 116 | loss_dict[split] = loss 117 | 118 | if wandb_log: 119 | wandb.log( 120 | data={"train_loss": loss_dict["train"], "test_loss": loss_dict["test"]}, 121 | step=step 122 | ) 123 | 124 | print("Train loss: {}".format(loss_dict["train"])) 125 | print("Test loss: {}".format(loss_dict["test"])) 126 | 127 | # Save model 128 | if step % save_every_n_steps == 0: 129 | ckpt_path = Path(ckpts_dir, "step={}.pth".format(step)) 130 | torch.save(model.state_dict(), ckpt_path) 131 | print("Save model to {}".format(ckpt_path)) 132 | 133 | if step == training_steps: 134 | break 135 | 136 | 137 | class MySampler: 138 | def __init__(self, books_num: int): 139 | self.books_num = books_num 140 | 141 | def __iter__(self) -> int: 142 | while True: 143 | yield random.randint(a=0, b=self.books_num) 144 | 145 | 146 | def get_model(model_name: str, vocab_size: int) -> nn.Module: 147 | 148 | if model_name == "GPT2": 149 | from models.gpt2 import GPTConfig, GPT2 150 | config = GPTConfig( 151 | block_size=1024, 152 | vocab_size=vocab_size, 153 | n_layer=12, 154 | n_head=12, 155 | n_embd=768 156 | ) 157 | return GPT2(config=config) 158 | 159 | elif model_name == "Llama": 160 | from models.llama import LlamaConfig, Llama 161 | config = LlamaConfig( 162 | block_size=1024, 163 | vocab_size=vocab_size, 164 | n_layer=12, 165 | n_head=12, 166 | n_embd=768 167 | ) 168 | return Llama(config=config) 169 | 170 | else: 171 | raise ValueError(model_name) 172 | 173 | 174 | def ce_loss(output: torch.Tensor, target: torch.LongTensor) -> float: 175 | r"""Cross entropy loss. 
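Logits are flattened to (b*t, v) and targets to (b*t,) before calling F.cross_entropy; target IDs equal to -1 are ignored.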
176 | 177 | Args: 178 | output: (b, t, vocab_size) 179 | target: (b, t) 180 | 181 | Outputs: 182 | loss: torch.float 183 | """ 184 | 185 | B, T, V = output.shape 186 | 187 | loss = F.cross_entropy( 188 | input=output.flatten(0, 1), # shape: (b*t, vocab_size) 189 | target=target.flatten(0, 1), # shape: (b*t,) 190 | ignore_index=-1 191 | ) 192 | 193 | return loss 194 | 195 | 196 | def validate( 197 | text_path: str, 198 | tokenizer: object, 199 | split: Literal["train", "test"], 200 | model: nn.Module, 201 | seq_len: int, 202 | valid_steps: int = 100 203 | ) -> float: 204 | r"""Validate the model on part of data.""" 205 | 206 | device = next(model.parameters()).device 207 | 208 | # Load tokens 209 | ids = load_text_to_ids(text_path=text_path, tokenizer=tokenizer, split=split) 210 | 211 | losses = [] 212 | 213 | # Sample indexes from the beginning 214 | for i in range(valid_steps): 215 | 216 | # Fetch data 217 | bgn = i * seq_len 218 | end = (i + 1) * seq_len + 1 219 | clip_ids = ids[bgn : end] # shape: (t + 1,) 220 | 221 | input_ids = torch.LongTensor(clip_ids[None, 0 : -1]).to(device) # (b, t) 222 | target_ids = torch.LongTensor(clip_ids[None, 1 :]).to(device) # (b, t) 223 | 224 | # Forward 225 | with torch.no_grad(): 226 | model.eval() 227 | logits = model(ids=input_ids) # shape: (b, t, vocab_size) 228 | 229 | # Calculate loss 230 | loss = ce_loss(output=logits, target=target_ids) 231 | losses.append(loss.item()) 232 | 233 | return np.mean(losses) 234 | 235 | 236 | if __name__ == "__main__": 237 | 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument('--model_name', type=str, required=True) 240 | parser.add_argument('--no_log', action='store_true', default=False) 241 | args = parser.parse_args() 242 | 243 | train(args) -------------------------------------------------------------------------------- /train_accelerate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import pickle 5 | import random 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from accelerate import Accelerator 14 | import wandb 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from typing_extensions import Literal 18 | 19 | from data.shakespeare import ShakespeareChar, load_text_to_ids 20 | from data.tokenizers import TokenizerChar 21 | from train import MySampler, get_model, ce_loss, validate 22 | 23 | 24 | def train(args): 25 | 26 | # Arguments 27 | model_name = args.model_name 28 | wandb_log = not args.no_log 29 | 30 | # Default parameters 31 | batch_size = 16 32 | num_workers = 16 33 | pin_memory = True 34 | learning_rate = 1e-4 35 | test_every_n_steps = 200 36 | save_every_n_steps = 2000 37 | training_steps = 10000 38 | seq_len = 256 39 | device = "cuda" 40 | 41 | filename = Path(__file__).stem 42 | 43 | # Paths 44 | root = "./assets/shakespeare_char" 45 | text_path = Path(root, "input.txt") 46 | meta_path = Path(root, "meta.pkl") 47 | 48 | # Checkpoints directory 49 | ckpts_dir = Path("./checkpoints", filename, model_name) 50 | Path(ckpts_dir).mkdir(parents=True, exist_ok=True) 51 | 52 | # Tokenizer 53 | tokenizer = TokenizerChar(meta_path=meta_path) 54 | 55 | # Dataset 56 | train_dataset = ShakespeareChar( 57 | text_path=text_path, 58 | tokenizer=tokenizer, 59 | split="train", 60 | seq_len=seq_len 61 | ) 62 | 63 | # Sampler 64 | train_sampler = 
MySampler(books_num=1) 65 | 66 | # Dataloader 67 | train_dataloader = DataLoader( 68 | dataset=train_dataset, 69 | batch_size=batch_size, 70 | sampler=train_sampler, 71 | num_workers=num_workers, 72 | pin_memory=pin_memory 73 | ) 74 | 75 | # Model 76 | model = get_model(model_name=model_name, vocab_size=len(tokenizer)) 77 | model.to(device) 78 | 79 | # Optimizer 80 | optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate) 81 | 82 | # Prepare for multiprocessing 83 | accelerator = Accelerator() 84 | 85 | model, optimizer, train_dataloader = accelerator.prepare( 86 | model, optimizer, train_dataloader) 87 | 88 | # Logger 89 | if wandb_log and accelerator.is_main_process: 90 | wandb.init(project="mini_llm", name="{}".format(model_name)) 91 | 92 | # Train 93 | for step, data in enumerate(tqdm(train_dataloader)): 94 | 95 | # Move data to device 96 | input_ids = data["id"][:, 0 : -1] # (b, t) 97 | target_ids = data["id"][:, 1 :] # (b, t) 98 | 99 | # Forward 100 | model.train() 101 | logits = model(ids=input_ids) # shape: (b, t, vocab_size) 102 | 103 | # Loss 104 | loss = ce_loss(output=logits, target=target_ids) 105 | 106 | # Optimize 107 | optimizer.zero_grad() # Reset all parameter.grad to 0 108 | accelerator.backward(loss) # Update all parameter.grad 109 | optimizer.step() # Update all parameters based on all parameter.grad 110 | 111 | # Evaluate 112 | if step % test_every_n_steps == 0 and accelerator.is_main_process: 113 | 114 | loss_dict = {} 115 | 116 | for split in ["train", "test"]: 117 | loss = validate( 118 | text_path=text_path, 119 | tokenizer=tokenizer, 120 | split=split, 121 | model=accelerator.unwrap_model(model), 122 | seq_len=seq_len 123 | ) 124 | loss_dict[split] = loss 125 | 126 | print("Train loss: {}".format(loss_dict["train"])) 127 | print("Test loss: {}".format(loss_dict["test"])) 128 | 129 | if wandb_log: 130 | wandb.log( 131 | data={"train_loss": loss_dict["train"], "test_loss": loss_dict["test"]}, 132 | step=step 133 | ) 134 | 135 | # Save model 136 | if step % save_every_n_steps == 0 and accelerator.is_main_process: 137 | ckpt_path = Path(ckpts_dir, "step={}.pth".format(step)) 138 | torch.save(accelerator.unwrap_model(model).state_dict(), ckpt_path) 139 | print("Save model to {}".format(ckpt_path)) 140 | 141 | if step == training_steps: 142 | break 143 | 144 | 145 | 146 | if __name__ == "__main__": 147 | 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--model_name', type=str, required=True) 150 | parser.add_argument('--no_log', action='store_true', default=False) 151 | args = parser.parse_args() 152 | 153 | train(args) --------------------------------------------------------------------------------