├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── palm.gif
│   └── palm_loss.png
├── data
│   └── openwebtext
│       └── prepare.py
├── model.py
├── nanoPaLM.nsys-rep
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | /data/openwebtext/*.bin
3 | /wandb/*
4 | /checkpoints/*
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Robert Riachi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nanoPALM
2 |
3 |
4 |
5 | Inspired by nanoGPT: the simplest, fastest repository for training/finetuning small to medium-sized PaLM models.
6 |
7 | This code tries to faithfully reproduce a functioning PaLM (paper: https://arxiv.org/pdf/2204.02311.pdf) and to train it as efficiently as possible.
8 |
9 | Training a ~213M-parameter model on OpenWebText for 100,000 iterations (~26 hours) on a single Nvidia 3090 GPU yields a val loss of 3.465. I was able to achieve roughly 1.15s/iter on my single machine, running a batch size of 16 with 4 gradient accumulation steps per optimizer step. I've also included the nsys report for those interested in looking at the profile of a couple of iterations of the model!
10 |
11 | # Getting started
12 |
13 | ## Requirements
14 |
15 | The code in this repo should work with any version of Python >= 3.9. The repo is meant to be lightweight, so there are only a few dependencies; you can install them by running:
16 |
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ## Prepare your training and validation data
22 |
23 | Simply run the following command. Warning: this takes ~54GB in the Hugging Face .cache dir and generates train.bin and val.bin, which take up ~18GB combined.
24 |
25 | ```
26 | python data/openwebtext/prepare.py
27 | ```
28 |
29 | ## Training
30 |
31 | Ideally you want to have some sort of CUDA-capable consumer GPU. The code is actively being developed, and for now the data loader will raise a NotImplementedError for non-CUDA devices.
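Before kicking off a run, you can sanity-check the prepared `.bin` files with a short snippet like the following (a minimal sketch, assuming `prepare.py` has already been run; it reuses the same GPT-2 tokenizer):

```
import numpy as np
from transformers import AutoTokenizer

# train.bin / val.bin are flat streams of uint16 GPT-2 token ids
train_data = np.memmap("data/openwebtext/train.bin", dtype=np.uint16, mode="r")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(f"{len(train_data):,} tokens")               # roughly 9B for train.bin
print(tokenizer.decode(train_data[:64].tolist()))  # first few tokens, decoded back to text
```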
32 |
33 | Training should just work. Hyper-parameters are defined globally in `train.py` for now, and the experiments are documented as comments explaining why some methods from the original paper were excluded (specifically, when they made training unstable).
34 |
35 | For reference, training for 100k iterations on a 3090 takes about 1 day.
36 |
37 | ```
38 | python train.py
39 | ```
40 |
41 | # Results
42 |
43 | ## Sample 1
44 |
45 | Prompt: `The meaning of life is`
46 |
47 | Response: `The meaning of life is not lost in the development process or by establishing and maintaining an ideal of shared purpose, where work and practices may not be done in accordance with the principles of imbalances.`
48 |
49 | ## Sample 2
50 |
51 | Prompt: `Once upon a time there was`
52 |
53 | Response: `Once upon a time there was no form of protest. For the last few years there were almost a thousand people who were confronted by police. Earlier this year we had seen an increase in people arrested in the U.S. One in every 20 arrests were made in the U.S.`
54 |
55 | ## Training performance for ~1 day on a single consumer GPU
56 |
57 | ![training loss](assets/palm_loss.png)
58 |
--------------------------------------------------------------------------------
/assets/palm.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/assets/palm.gif
--------------------------------------------------------------------------------
/assets/palm_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/assets/palm_loss.png
--------------------------------------------------------------------------------
/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # This code comes directly from nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py
2 | # However, this might change depending on the future direction of the project, i.e. different tokenization methods, etc...
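# For example (a hypothetical swap, not something this repo ships): using a different Hugging Face
# tokenizer should only mean changing ENCODING_METHOD below, e.g.
#   ENCODING_METHOD = 'EleutherAI/gpt-neox-20b'
# as long as its vocab still fits in the uint16 token ids used by the memmaps below
# (and vocab_size in train.py is updated to match).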
3 | 4 | # train.bin is ~17GB, val.bin ~8.5MB 5 | # train has ~9B tokens (9,035,582,198) 6 | # val has ~4M tokens (4,434,897) 7 | 8 | import os 9 | from tqdm import tqdm 10 | import numpy as np 11 | import multiprocessing as mp 12 | from transformers import AutoTokenizer 13 | from datasets import load_dataset 14 | 15 | NUM_PROC = mp.cpu_count() // 2 16 | ENCODING_METHOD = 'gpt2' 17 | 18 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 19 | dataset = load_dataset("openwebtext") 20 | 21 | # owt by default only contains the 'train' split, so create a test split 22 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 23 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 24 | 25 | tokenizer = AutoTokenizer.from_pretrained(ENCODING_METHOD) 26 | def process(example): 27 | # ignore special tokens and append EOT 28 | ids = tokenizer.encode(example['text']) + [tokenizer.eos_token_id] 29 | return {'ids': ids, 'len': len(ids)} 30 | 31 | # tokenize the dataset 32 | tokenized = split_dataset.map( 33 | process, 34 | remove_columns=['text'], 35 | desc="tokenizing the splits", 36 | num_proc=NUM_PROC, 37 | ) 38 | 39 | # concatenate all the ids in each dataset into one large file we can use for training 40 | for split, dset in tokenized.items(): 41 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 42 | arr = np.memmap(filename, dtype=np.uint16 , mode='w+', shape=(np.sum(dset['len']),)) 43 | 44 | print(f"writing {filename}...") 45 | idx = 0 46 | for example in tqdm(dset): 47 | arr[idx : idx + example['len']] = example['ids'] 48 | idx += example['len'] 49 | arr.flush() 50 | 51 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from dataclasses import dataclass 6 | 7 | 8 | def swiglu(x): 9 | x, gate = x.chunk(2, dim=-1) 10 | return F.silu(gate) * x 11 | 12 | 13 | class LayerNorm(nn.Module): 14 | # Disable bias in layernorm, since torch doesn't support bias=False 15 | # From PaLM paper: 16 | # No biases were used in any of the dense kernels or layer norms. 17 | # We found this to result in increased training stability for large models. 
18 |
19 |     def __init__(self, n_dim):
20 |         super().__init__()
21 |         self.weight = nn.Parameter(torch.ones(n_dim))
22 |
23 |     def forward(self, x):
24 |         # None here is for torch functional's bias param
25 |         return F.layer_norm(x, self.weight.shape, self.weight, None, 1e-5)
26 |
27 |
28 | class MultiQueryAttention(nn.Module):
29 |
30 |     def __init__(self, config):
31 |         super().__init__()
32 |
33 |         self.c_attn = nn.Linear(config.n_embed, (config.n_head + 2)
34 |                                 * (config.n_embed // config.n_head), bias=False)
35 |         self.out_proj = nn.Linear(config.n_embed, config.n_embed, bias=False)
36 |         self.attn_dropout = nn.Dropout(config.dropout)  # unused: attention dropout is applied inside scaled_dot_product_attention below
37 |         self.resid_dropout = nn.Dropout(config.dropout)
38 |         self.dropout = config.dropout
39 |         self.n_embed = config.n_embed
40 |         self.n_head = config.n_head
41 |         self.head_dim = self.n_embed // self.n_head
42 |
43 |     def rotate_embeddings(self, x):
44 |         # Rotate the two halves of the head dim, (x1, x2) -> (-x2, x1), matching the half-split cos/sin layout built in forward()
45 |         x1, x2 = x.chunk(2, dim=-1)
46 |         return torch.cat((-x2, x1), dim=-1)
47 |
48 |     def forward(self, x):
49 |
50 |         _, n_tokens, _ = x.shape
51 |         head_embed = self.n_embed//self.n_head
52 |
53 |         # Multi-Query Attention
54 |         q, k, v = self.c_attn(x).split(
55 |             [self.n_embed, head_embed, head_embed], dim=2)
56 |         q = q.view((*x.shape[:2], self.n_head, -1)).permute(0, 2, 1, 3)
57 |         k = k.view(*x.shape[:2], 1, head_embed).permute(0, 2, 1, 3)
58 |         v = v.view(*x.shape[:2], 1, head_embed).permute(0, 2, 1, 3)
59 |
60 |         # RoPE embeddings
61 |         pos = 10000**((-2 * torch.arange(0, self.head_dim, 2, device=x.device) - 1)/self.head_dim)
62 |         token_seq = torch.arange(n_tokens, dtype=pos.dtype, device=x.device).unsqueeze(1) @ pos.unsqueeze(0)
63 |         rotary_embds = torch.cat((token_seq, token_seq), dim=-1)
64 |
65 |         q = (q * rotary_embds.cos()) + \
66 |             (self.rotate_embeddings(q) * rotary_embds.sin())
67 |         k = (k * rotary_embds.cos()) + \
68 |             (self.rotate_embeddings(k) * rotary_embds.sin())
69 |
70 |         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
71 |
72 |         attn = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
73 |         return self.resid_dropout(self.out_proj(attn))
74 |
75 |
76 | class MLP(nn.Module):
77 |     def __init__(self, config):
78 |         super().__init__()
79 |
80 |         # Hidden dim is traditionally 4*n_embed; scale by n_head as well to compensate for the slimmer multi-query attention
81 |         h_dim = 4 * config.n_head * config.n_embed
82 |         # double h_dim b/c swiglu activation
83 |         self.fc = nn.Linear(config.n_embed, 2*h_dim, bias=False)
84 |         self.proj = nn.Linear(h_dim, config.n_embed, bias=False)
85 |         self.dropout = nn.Dropout(config.dropout)
86 |
87 |     def forward(self, x):
88 |         x = swiglu(self.fc(x))
89 |         return self.dropout(self.proj(x))
90 |
91 |
92 | class ParallelLayerBlock(nn.Module):
93 |     def __init__(self, config):
94 |         super().__init__()
95 |
96 |         self.mlp = MLP(config)
97 |         self.mlp_ln = LayerNorm(config.n_embed)
98 |
99 |         self.multi_query_attn = MultiQueryAttention(config)
100 |         self.mqa_ln = LayerNorm(config.n_embed)
101 |
102 |     def forward(self, x):
103 |         mlp_out = self.mlp(self.mlp_ln(x))
104 |         attn_out = self.multi_query_attn(self.mqa_ln(x))
105 |         return x + mlp_out + attn_out
106 |
107 |
108 | class PaLM(nn.Module):
109 |
110 |     def __init__(self, config):
111 |         super().__init__()
112 |
113 |         self.config = config
114 |
115 |         self.decoder = nn.ModuleDict(dict(
116 |             word_embds=nn.Embedding(config.vocab_size, config.n_embed),
117 |             drop=nn.Dropout(config.dropout),
118 |             blocks=nn.ModuleList([ParallelLayerBlock(config)
119 |                                   for _ in range(config.n_layer)]),
120 |             out_ln=LayerNorm(config.n_embed)
121 |         ))
122 |
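        # The decoder is a stack of "parallel" blocks: attention and the MLP each read the
        # same residual stream (through their own LayerNorm) and their outputs are summed,
        # following the PaLM paper's parallel layer formulation.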
123 |         # Set linear head weights to embedding weights according to paper
124 |         self.ln_vocab = nn.Linear(
125 |             config.n_embed, config.vocab_size, bias=False)
126 |         self.ln_vocab.weight = self.decoder.word_embds.weight
127 |
128 |         self.apply(self._init_weights)
129 |
130 |     def _init_weights(self, module):
131 |         # Paper inits all weights aside from embedding and layer_norm using W ~ N(0, 1/sqrt(n_in))
132 |         # Input embeddings get initialized to E ~ N(0,1) since layer_norm isn't applied to the embedding
133 |         if isinstance(module, nn.Linear):
134 |             torch.nn.init.normal_(module.weight, mean=0.0, std=1/math.sqrt(module.in_features))
135 |         elif isinstance(module, nn.Embedding):
136 |             nn.init.normal_(self.decoder.word_embds.weight) # maybe make std=0.02 here
137 |
138 |     @torch.no_grad()
139 |     def generate(self, input_tokens, max_length, terminal_ids=None, temp=1.0):
140 |
141 |         while input_tokens.shape[1] < max_length:  # compare the sequence length, not the batch dim
142 |
143 |             logits, _ = self(input_tokens)
144 |             logits = logits[:, -1, :] / temp
145 |
146 |             token_scores = F.softmax(logits, dim=-1)
147 |             next_token = torch.multinomial(token_scores, num_samples=1)
148 |
149 |             input_tokens = torch.cat((input_tokens, next_token), dim=1)
150 |
151 |             if terminal_ids and next_token in terminal_ids:
152 |                 break
153 |
154 |         return input_tokens
155 |
156 |
157 |     def forward(self, x, targets=None):
158 |
159 |         x = self.decoder.word_embds(x)
160 |         x = self.decoder.drop(x)
161 |
162 |         for block in self.decoder.blocks:
163 |             x = block(x)
164 |
165 |         x = self.decoder.out_ln(x)
166 |
167 |         logits = self.ln_vocab(x)
168 |
169 |         if targets is not None:
170 |             # Paper scales pre-softmax output logits by 1/sqrt(n_embed), but I can't get this to work well
171 |             loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
172 |                                    targets.contiguous().view(-1),
173 |                                    ignore_index=-1)
174 |
175 |             return logits, loss
176 |         return logits, None
177 |
178 |
179 | @dataclass
180 | class PaLMConfig:
181 |     n_embed: int
182 |     n_head: int
183 |     dropout: float
184 |     vocab_size: int
185 |     n_layer: int
186 |
--------------------------------------------------------------------------------
/nanoPaLM.nsys-rep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/nanoPaLM.nsys-rep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | transformers
3 | datasets
4 | tqdm
5 | wandb
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import wandb
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from torch.optim import AdamW
8 | from contextlib import nullcontext
9 | from model import PaLMConfig, PaLM, LayerNorm
10 | from transformers import AutoTokenizer
11 | from tqdm import tqdm
12 |
13 | # TODO: clean this up
14 | device = "cuda" if torch.cuda.is_available() else "cpu" # No love for MPS for now
15 | run_name = "palm"
16 |
17 | # Evaluation
18 | eval_freq = 100 # 1000
19 | num_evals = 20 # 100
20 | best_val_loss = 1e9
21 |
22 | # Data
23 | datasets_dir = 'data'
24 | dataset = "openwebtext"
25 | grad_accumulation_steps = 4
26 | batch_size = 16 # Paper ramps the batch size up over training, but that schedule is too extreme for consumer GPUs
27 | block_size = 512 # Paper uses 2048 but this might be a bit too extreme for consumer GPUs
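# For reference: with these settings, each optimizer step consumes
# batch_size * grad_accumulation_steps * block_size = 16 * 4 * 512 = 32,768 tokens.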
28 |
29 | # Training
30 | # Note: Paper uses lr=1e-2 for 10k iters, then drops to 1/sqrt(step)
31 | # I've found 2e-4 and cosine decay following Chinchilla guidelines to work better
32 | start_iter = 0 # TODO: Update this when loading from checkpoint in the future
33 | max_iters = 100000
34 | warmup_iters = 2000
35 | learning_rate = 2e-4 # Modified at runtime to follow cosine decay
36 | lr_decay_iters = max_iters # Chinchilla
37 | min_learning_rate = learning_rate / 10 # Chinchilla
38 | weight_decay = learning_rate**2.0 # Decoupled weight decay & modified at runtime
39 | grad_clip = 0.5
40 |
41 | # Precision
42 | precision = torch.bfloat16
43 | amp_enabled = (precision == torch.bfloat16) # Autocast only works with bfloat16 on my GPU; with float16 the loss becomes NaN and I haven't figured out why
44 | amp_ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(enabled=amp_enabled, device_type=device, dtype=precision)
45 | scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
46 |
47 | # WandB
48 | wandb_logging_enabled = False
49 | wandb_project_name = "nanoPaLM"
50 |
51 | # Config
52 | config = {k:v for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))}
53 |
54 | def get_lr(step):
55 |     # Warmup, else cosine decay learning rate
56 |     if step < warmup_iters:
57 |         return learning_rate * step / warmup_iters
58 |
59 |     decay = (step - warmup_iters) / (lr_decay_iters - warmup_iters)
60 |     coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
61 |     return min_learning_rate + coeff * (learning_rate - min_learning_rate)
62 |
63 |
64 | def update_optim(optim, step):
65 |
66 |     for group in optim.param_groups:
67 |         lr = get_lr(step)
68 |         group['lr'] = lr
69 |
70 |         # If not in no_decay group update decay
71 |         if group['weight_decay'] != 0.0:
72 |             group['weight_decay'] = lr**2.0
73 |
74 |
75 | def num_model_params(model):
76 |     units = ['', 'K', 'M', 'B', 'T']
77 |     total_params = sum(p.numel()
78 |                        for p in model.parameters() if p.requires_grad)
79 |     mag = int(math.floor(math.log(total_params, 1000)))
80 |     return f"{int(total_params / 1000**mag)}{units[mag]}"
81 |
82 |
83 | # Buffers for incoming data
84 | xy = torch.empty((batch_size, block_size+1), dtype=torch.int32).pin_memory()
85 | xy_cuda = torch.empty((batch_size, block_size+1), dtype=torch.int64, device="cuda")
86 |
87 | def load_batch(split, batch_size, device):
88 |     global xy
89 |     # Select which items to load
90 |     ix = torch.randint(len(split) - block_size, (batch_size,))
91 |     # Set the relevant elements of xy
92 |     for i, data_i in enumerate(ix):
93 |         xy[i].numpy()[...] = split[data_i:data_i+1+block_size]
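    # xy lives in pinned (page-locked) host memory, so the copy_ below can be issued as an
    # asynchronous host-to-device transfer into the persistent xy_cuda buffer.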
94 |     if device == 'cuda':
95 |         # Copy the incoming data directly from pinned memory into cuda mem
96 |         xy_cuda.copy_(xy, non_blocking=True)
97 |         # Slice out x and y
98 |         x = xy_cuda[:, :-1]
99 |         y = xy_cuda[:, 1:]
100 |     else:
101 |         raise NotImplementedError
102 |         #x, y = x.to(device), y.to(device)
103 |     return x, y
104 |
105 |
106 | @torch.no_grad()
107 | def evaluate_splits(model, splits, split_names, num_evals, batch_size, device):
108 |     model.eval()
109 |     split_losses = {}
110 |     for split, split_name in zip(splits, split_names):
111 |         losses = torch.zeros(num_evals)
112 |         for i in range(num_evals):
113 |             x, y = load_batch(split, batch_size, device)
114 |
115 |             with amp_ctx:
116 |                 _, loss = model(x, y)
117 |
118 |             losses[i] = loss.item()
119 |
120 |         split_losses[split_name] = losses.mean()
121 |     model.train()
122 |     return split_losses
123 |
124 |
125 | if __name__ == "__main__":
126 |
127 |     # Load data & tokenizer
128 |     data_dir = os.path.join(datasets_dir, dataset)
129 |     train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
130 |     val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
131 |     tokenizer = AutoTokenizer.from_pretrained('gpt2')
132 |
133 |     # Load model
134 |     palm_config = PaLMConfig(n_embed=768,
135 |                              n_head=6,
136 |                              dropout=0.1,
137 |                              vocab_size=50304, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
138 |                              n_layer=4)
139 |
140 |     model = PaLM(palm_config).to(device)
141 |     num_params = num_model_params(model)
142 |     print(f"Initializing PaLM model with {num_params} params")
143 |
144 |     # Initialize logging
145 |     if wandb_logging_enabled:
146 |         import wandb
147 |         wandb.init(project=wandb_project_name, name=run_name, config=palm_config)
148 |
149 |     # Disable weight decay for unwanted modules
150 |     # PaLM model has no bias so only include weight params
151 |     # Exclude ln_vocab.weight as its weight is tied to the word embedding weights
152 |     no_decay_modules = [LayerNorm, torch.nn.Embedding]
153 |     decay_modules = [torch.nn.Linear]
154 |     param_dict = {pn: p for pn, p in model.named_parameters()}
155 |     no_decay_params = [f"{n}.weight" for n, m in model.named_modules() if any(
156 |         isinstance(m, nd) for nd in no_decay_modules)]
157 |     decay_params = [f"{n}.weight" for n, m in model.named_modules() if any(
158 |         isinstance(m, nd) for nd in decay_modules)]
159 |
160 |     optimizer_grouped_parameters = [
161 |         {'params': [param_dict[p] for p in decay_params if p != 'ln_vocab.weight'], 'weight_decay': weight_decay},
162 |         {'params': [param_dict[p] for p in no_decay_params], 'weight_decay': 0.0}
163 |     ]
164 |
165 |     # Model uses betas=(0.9, (1-step**-0.8)), but I've found default works better w/ AdamW
166 |     optim = AdamW(optimizer_grouped_parameters,
167 |                   lr=learning_rate,
168 |                   fused=True if device == 'cuda' else False)
169 |
170 |     model = torch.compile(model)
171 |
172 |     # Training loop
173 |     for step in tqdm(range(start_iter, max_iters + 1)):
174 |
175 |         update_optim(optim, step)
176 |
177 |         if step % eval_freq == 0 and step != 0:
178 |             losses = evaluate_splits(model,
179 |                                      splits=[train_data, val_data],
180 |                                      split_names=['train', 'val'],
181 |                                      num_evals=num_evals,
182 |                                      batch_size=batch_size,
183 |                                      device=device)
184 |             print(f"Step {step}: train loss={losses['train']}, val loss={losses['val']}")
185 |
186 |             if wandb_logging_enabled:
187 |                 wandb.log({
188 |                     "iter": step,
189 |                     "train/loss": losses['train'],
190 |                     "val/loss": losses['val'],
191 |                     "lr": get_lr(step)
192 |                 })
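            # Checkpoint whenever the val loss improves; the dict below carries everything needed
            # to resume a run (wiring start_iter up to a loaded checkpoint is still a TODO, see the
            # comment next to start_iter above).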
193 |
194 |             if losses['val'] < best_val_loss:
195 |                 best_val_loss = losses['val']
196 |                 checkpoint = {
197 |                     'model': model.state_dict(),
198 |                     'optimizer': optim.state_dict(),
199 |                     'model_args': palm_config,
200 |                     'step': step,
201 |                     'best_val_loss': best_val_loss,
202 |                     'config': config,
203 |                 }
204 |                 print(f"Saving checkpoint, step:{step}, val_loss:{best_val_loss}")
205 |                 check_out = f"checkpoints/{run_name}"
206 |
207 |                 if not os.path.exists(check_out):
208 |                     os.makedirs(check_out)  # makedirs also creates the parent checkpoints/ dir on a fresh clone
209 |                 torch.save(checkpoint, os.path.join(check_out, "ckpt.pt"))
210 |
211 |         for micro_step in range(grad_accumulation_steps):
212 |             x, y = load_batch(train_data, batch_size, device=device)
213 |
214 |             with amp_ctx:
215 |                 logits, loss = model(x, y)
216 |
217 |             scaler.scale(loss / grad_accumulation_steps).backward()  # average the gradient over the accumulation steps
218 |
219 |         # Grad clipping for all model sizes
220 |         scaler.unscale_(optim)
221 |         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
222 |
223 |         scaler.step(optim)
224 |         scaler.update()
225 |
226 |         optim.zero_grad()
227 |
--------------------------------------------------------------------------------
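Sampling is not covered by `train.py`; it only saves checkpoints. A minimal sketch of generating text from a saved checkpoint with `PaLM.generate` might look like the following (this is not part of the repo: it assumes a CUDA device, the default checkpoint path `checkpoints/palm/ckpt.pt`, and it strips the `_orig_mod.` prefix that `torch.compile` adds to state-dict keys):

```
import torch
from transformers import AutoTokenizer
from model import PaLM

# weights_only=False is needed on newer torch versions because the checkpoint stores a PaLMConfig object
ckpt = torch.load("checkpoints/palm/ckpt.pt", map_location="cuda", weights_only=False)

# The checkpoint was saved from a torch.compile'd model, so strip the "_orig_mod." key prefix
state_dict = {k.removeprefix("_orig_mod."): v for k, v in ckpt['model'].items()}

model = PaLM(ckpt['model_args']).to("cuda").eval()
model.load_state_dict(state_dict)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = torch.tensor([tokenizer.encode("The meaning of life is")], device="cuda")
out = model.generate(prompt, max_length=64, terminal_ids=[tokenizer.eos_token_id])
print(tokenizer.decode(out[0].tolist()))
```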