├── .gitignore
├── LICENSE
├── README.md
├── assets
├── palm.gif
└── palm_loss.png
├── data
└── openwebtext
│ └── prepare.py
├── model.py
├── nanoPaLM.nsys-rep
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | /data/openwebtext/*.bin
3 | /wandb/*
4 | /checkpoints/*
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Robert Riachi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nanoPALM
2 |
3 |
4 |
5 | Inspired by nanoGPT, the simplest, fastest repository for training/finetuning small to medium-sized PALM models.
6 |
7 | This code tries to faithfully reproduce and create a functioning PaLM (Paper: https://arxiv.org/pdf/2204.02311.pdf) as efficiently as possible.
8 |
9 | Trained on OpenWebText, using ~213M params and running on a single Nvidia 3090 GPU for 100,000 iterations (~26 hours) yeilds a val loss of 3.465. I was able to achieve roughly 1.15s/iter on my single machine, running a batch size of 16 with 4 grad accumulations per optimizer step. I've also included the nsys report for those interested in looking at the profile of a couple iterations of the model!
10 |
11 | # Getting started
12 |
13 | ## Requirements
14 |
15 | The code in this repo should work with any version of Python >=3.9. The purpose of this repo is to be lightweight, we only have a few dependencies, you can install them by running:
16 |
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ## Prepare your training and validation data
22 |
23 | Simply run the following command. Warning this takes 54GB in huggingface .cache dir, and generates train.bin and val.bin which take up ~18GB
24 |
25 | ```
26 | python data/openwebtext/prepare.py
27 | ```
28 |
29 | ## Training
30 |
31 | Ideally you want to have some sort of consumer GPU, the code is actively being developed, but the data loader will raise a NotImplemented error for non-cuda devices.
32 |
33 | Training should just work, hyper-parameters are defined globally in `train.py` for now, and all experimentation is documented as comments explaining why some original methods from the paper were excluded (specifically when they made training unstable).
34 |
35 | For reference, training for 100k iterations on a 3090 takes about 1 day.
36 |
37 | ```
38 | python train.py
39 | ```
40 |
41 | # Results
42 |
43 | ## Sample 1
44 |
45 | Prompt: `The meaning of life is`
46 |
47 | Response: `The meaning of life is not lost in the development process or by establishing and maintaining an ideal of shared purpose, where work and practices may not be done in accordance with the principles of imbalances.`
48 |
49 | ## Sample 2
50 |
51 | Prompt: `Once upon a time there was`
52 |
53 | Response: `Once upon a time there was no form of protest. For the last few years there were almost a thousand people who were confronted by police. Earlier this year we had seen an increase in people arrested in the U.S. One in every 20 arrests were made in the U.S.`
54 |
55 | ## Training performance for ~1 day on a single consumer GPU
56 |
57 |
58 |
--------------------------------------------------------------------------------
/assets/palm.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/assets/palm.gif
--------------------------------------------------------------------------------
/assets/palm_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/assets/palm_loss.png
--------------------------------------------------------------------------------
/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # This code comes directly from nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py
2 | # However, this might change depending on the future direction of the project, i.e. different tokenization methods, etc...
3 |
4 | # train.bin is ~17GB, val.bin ~8.5MB
5 | # train has ~9B tokens (9,035,582,198)
6 | # val has ~4M tokens (4,434,897)
7 |
8 | import os
9 | from tqdm import tqdm
10 | import numpy as np
11 | import multiprocessing as mp
12 | from transformers import AutoTokenizer
13 | from datasets import load_dataset
14 |
15 | NUM_PROC = mp.cpu_count() // 2
16 | ENCODING_METHOD = 'gpt2'
17 |
18 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
19 | dataset = load_dataset("openwebtext")
20 |
21 | # owt by default only contains the 'train' split, so create a test split
22 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
23 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
24 |
25 | tokenizer = AutoTokenizer.from_pretrained(ENCODING_METHOD)
26 | def process(example):
27 | # ignore special tokens and append EOT
28 | ids = tokenizer.encode(example['text']) + [tokenizer.eos_token_id]
29 | return {'ids': ids, 'len': len(ids)}
30 |
31 | # tokenize the dataset
32 | tokenized = split_dataset.map(
33 | process,
34 | remove_columns=['text'],
35 | desc="tokenizing the splits",
36 | num_proc=NUM_PROC,
37 | )
38 |
39 | # concatenate all the ids in each dataset into one large file we can use for training
40 | for split, dset in tokenized.items():
41 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
42 | arr = np.memmap(filename, dtype=np.uint16 , mode='w+', shape=(np.sum(dset['len']),))
43 |
44 | print(f"writing {filename}...")
45 | idx = 0
46 | for example in tqdm(dset):
47 | arr[idx : idx + example['len']] = example['ids']
48 | idx += example['len']
49 | arr.flush()
50 |
51 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from dataclasses import dataclass
6 |
7 |
8 | def swiglu(x):
9 | x, gate = x.chunk(2, dim=-1)
10 | return F.silu(gate) * x
11 |
12 |
13 | class LayerNorm(nn.Module):
14 | # Disable bias in layernorm, since torch doesn't support bias=False
15 | # From PaLM paper:
16 | # No biases were used in any of the dense kernels or layer norms.
17 | # We found this to result in increased training stability for large models.
18 |
19 | def __init__(self, n_dim):
20 | super().__init__()
21 | self.weight = nn.Parameter(torch.ones(n_dim))
22 |
23 | def forward(self, x):
24 | # None here is for torch functional's bias param
25 | return F.layer_norm(x, self.weight.shape, self.weight, None, 1e-5)
26 |
27 |
28 | class MultiQueryAttention(nn.Module):
29 |
30 | def __init__(self, config):
31 | super().__init__()
32 |
33 | self.c_attn = nn.Linear(config.n_embed, (config.n_head + 2)
34 | * (config.n_embed // config.n_head), bias=False)
35 | self.out_proj = nn.Linear(config.n_embed, config.n_embed, bias=False)
36 | self.attn_dropout = nn.Dropout(config.dropout)
37 | self.resid_dropout = nn.Dropout(config.dropout)
38 | self.dropout = config.dropout
39 | self.n_embed = config.n_embed
40 | self.n_head = config.n_head
41 | self.head_dim = self.n_embed // self.n_head
42 |
43 | def rotate_embeddings(self, x):
44 | x = x.view(*x.shape[:-1], -1, 2).flip(-1)
45 | x[...,0] *= -1
46 | return x.flatten(start_dim=-2)
47 |
48 | def forward(self, x):
49 |
50 | _, n_tokens, _ = x.shape
51 | head_embed = self.n_embed//self.n_head
52 |
53 | # Multi-Query Attention
54 | q, k, v = self.c_attn(x).split(
55 | [self.n_embed, head_embed, head_embed], dim=2)
56 | q = q.view((*x.shape[:2], self.n_head, -1)).permute(0, 2, 1, 3)
57 | k = k.view(*x.shape[:2], 1, head_embed).permute(0, 2, 1, 3)
58 | v = v.view(*x.shape[:2], 1, head_embed).permute(0, 2, 1, 3)
59 |
60 | # RoPE embeddings
61 | pos = 10000**((-2 * torch.arange(0, self.head_dim, 2, device=x.device) - 1)/self.head_dim)
62 | token_seq = torch.arange(n_tokens, dtype=pos.dtype, device=x.device).unsqueeze(1) @ pos.unsqueeze(0)
63 | rotary_embds = torch.cat((token_seq, token_seq), dim=-1)
64 |
65 | q = (q * rotary_embds.cos()) + \
66 | (self.rotate_embeddings(q) * rotary_embds.sin())
67 | k = (k * rotary_embds.cos()) + \
68 | (self.rotate_embeddings(k) * rotary_embds.sin())
69 |
70 | attn = F.scaled_dot_product_attention(q,k,v, dropout_p=self.dropout, is_causal=True)
71 |
72 | attn = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
73 | return self.resid_dropout(self.out_proj(attn))
74 |
75 |
76 | class MLP(nn.Module):
77 | def __init__(self, config):
78 | super().__init__()
79 |
80 | # Traditionally scale by 4, but overcompensate b/c multi-query attention
81 | h_dim = 4 * config.n_head * config.n_embed
82 | # double h_dim b/c swiglu activation
83 | self.fc = nn.Linear(config.n_embed, 2*h_dim, bias=False)
84 | self.proj = nn.Linear(h_dim, config.n_embed, bias=False)
85 | self.dropout = nn.Dropout(config.dropout)
86 |
87 | def forward(self, x):
88 | x = swiglu(self.fc(x))
89 | return self.dropout(self.proj(x))
90 |
91 |
92 | class ParallelLayerBlock(nn.Module):
93 | def __init__(self, config):
94 | super().__init__()
95 |
96 | self.mlp = MLP(config)
97 | self.mlp_ln = LayerNorm(config.n_embed)
98 |
99 | self.multi_query_attn = MultiQueryAttention(config)
100 | self.mqa_ln = LayerNorm(config.n_embed)
101 |
102 | def forward(self, x):
103 | mlp_out = self.mlp(self.mlp_ln(x))
104 | attn_out = self.multi_query_attn(self.mqa_ln(x))
105 | return x + mlp_out + attn_out
106 |
107 |
108 | class PaLM(nn.Module):
109 |
110 | def __init__(self, config):
111 | super().__init__()
112 |
113 | self.config = config
114 |
115 | self.decoder = nn.ModuleDict(dict(
116 | word_embds=nn.Embedding(config.vocab_size, config.n_embed),
117 | drop=nn.Dropout(config.dropout),
118 | blocks=nn.ModuleList([ParallelLayerBlock(config)
119 | for _ in range(config.n_layer)]),
120 | out_ln=LayerNorm(config.n_embed)
121 | ))
122 |
123 | # Set linear head weights to embedding weights according to paper
124 | self.ln_vocab = nn.Linear(
125 | config.n_embed, config.vocab_size, bias=False)
126 | self.ln_vocab.weight = self.decoder.word_embds.weight
127 |
128 | self.apply(self._init_weights)
129 |
130 | def _init_weights(self, module):
131 | # Paper inits all weights aside from embedding and layer_norm using W ~ N(0, 1/sqrt(n_in))
132 | # Input embeddings get initalized to E ~ N(0,1) since layer_norm isn't applied to the embedding
133 | if isinstance(module, nn.Linear):
134 | torch.nn.init.normal_(module.weight, mean=0.0, std=1/math.sqrt(module.in_features))
135 | elif isinstance(module, nn.Embedding):
136 | nn.init.normal_(self.decoder.word_embds.weight) # maybe make std=0.02 here
137 |
138 | @torch.no_grad()
139 | def generate(self, input_tokens, max_length, terminal_ids=None, temp=1.0):
140 |
141 | while len(input_tokens) < max_length:
142 |
143 | logits, _ = self(input_tokens)
144 | logits = logits[:, -1, :] / temp
145 |
146 | token_scores = F.softmax(logits, dim=-1)
147 | next_token = torch.multinomial(token_scores, num_samples=1)
148 |
149 | input_tokens = torch.cat((input_tokens, next_token), dim=1)
150 |
151 | if terminal_ids and next_token in terminal_ids:
152 | break
153 |
154 | return input_tokens
155 |
156 |
157 | def forward(self, x, targets=None):
158 |
159 | x = self.decoder.word_embds(x)
160 | x = self.decoder.drop(x)
161 |
162 | for block in self.decoder.blocks:
163 | x = block(x)
164 |
165 | x = self.decoder.out_ln(x)
166 |
167 | logits = self.ln_vocab(x)
168 |
169 | if targets is not None:
170 | # Paper scales pre-softmax output logits by 1/sqrt(n_embed), but I can't get this to work well
171 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
172 | targets.contiguous().view(-1),
173 | ignore_index=-1)
174 |
175 | return logits, loss
176 | return logits, None
177 |
178 |
179 | @dataclass
180 | class PaLMConfig:
181 | n_embed: int
182 | n_head: int
183 | dropout: float
184 | vocab_size: int
185 | n_layer: int
186 |
--------------------------------------------------------------------------------
/nanoPaLM.nsys-rep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RobertRiachi/nanoPALM/71ea685c447ec41be6537def6b5fedb43007e7f1/nanoPaLM.nsys-rep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | transformers
3 | datasets
4 | tqdm
5 | wandb
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import wandb
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from torch.optim import AdamW
8 | from contextlib import nullcontext
9 | from model import PaLMConfig, PaLM, LayerNorm
10 | from transformers import AutoTokenizer
11 | from tqdm import tqdm
12 |
13 | # TODO: clean this up
14 | device = "cuda" if torch.cuda.is_available() else "cpu" # No love for MPS for now
15 | run_name = "palm"
16 |
17 | # Evaluation
18 | eval_freq = 100#1000
19 | num_evals = 20#100
20 | best_val_loss = 1e9
21 |
22 | # Data
23 | datasets_dir = 'data'
24 | dataset = "openwebtext"
25 | grad_accumulation_steps = 4
26 | batch_size = 16 # Paper follows get_bs function defined below, but this might be too extreme for consumer GPUs
27 | block_size = 512 # Paper uses 2048 but this might be a bit too extreme for consumer GPUs
28 |
29 | # Training
30 | # Note: Paper uses lr=1e-2 for 10k iters, then drops to 1/sqrt(step)
31 | # I've found 2e-4 and cosine decay following Chinchilla guidelines to work better
32 | start_iter = 0 # TODO: Update this when loading from checkpoint in the future
33 | max_iters = 100000
34 | warmup_iters = 2000
35 | learning_rate = 2e-4 # Modified at runtime to follow cosine decay
36 | lr_decay_iters = max_iters # Chinchilla
37 | min_learning_rate = learning_rate / 10 # Chinchilla
38 | weight_decay = learning_rate**2.0 # Decoupled weight decay & modified at runtime
39 | grad_clip = 0.5
40 |
41 | # Precision
42 | precision = torch.bfloat16
43 | amp_enabled = (precision == torch.bfloat16) # Only works with bfloat16 on my gpu, else loss becomes nan not sure why
44 | amp_ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(enabled=amp_enabled, device_type=device, dtype=precision)
45 | scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
46 |
47 | # WandB
48 | wandb_logging_enabled = False
49 | wandb_project_name = "nanoPaLM"
50 |
51 | # Config
52 | config = {k:v for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))}
53 |
54 | def get_lr(step):
55 | # Warmup, else cosine decay learning rate
56 | if step < warmup_iters:
57 | return learning_rate * step / warmup_iters
58 |
59 | decay = (step - warmup_iters) / (lr_decay_iters - warmup_iters)
60 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
61 | return min_learning_rate + coeff * (learning_rate - min_learning_rate)
62 |
63 |
64 | def update_optim(optim, step):
65 |
66 | for group in optim.param_groups:
67 | lr = get_lr(step)
68 | group['lr'] = lr
69 |
70 | # If not in no_decay group update decay
71 | if group['weight_decay'] != 0.0:
72 | group['weight_decay'] = lr**2.0
73 |
74 |
75 | def num_model_params(model):
76 | units = ['', 'K', 'M', 'B', 'T']
77 | total_params = sum(p.numel()
78 | for p in model.parameters() if p.requires_grad)
79 | mag = int(math.floor(math.log(total_params, 1000)))
80 | return f"{int(total_params / 1000**mag)}{units[mag]}"
81 |
82 |
83 | # Buffers for incoming data
84 | xy = torch.empty((batch_size, block_size+1), dtype=torch.int32).pin_memory()
85 | xy_cuda = torch.empty((batch_size, block_size+1), dtype=torch.int64, device="cuda")
86 |
87 | def load_batch(split, batch_size, device):
88 | global xy
89 | # Select which items to load
90 | ix = torch.randint(len(split) - block_size, (batch_size,))
91 | # Set the relevant elements of xy
92 | for i, data_i in enumerate(ix):
93 | xy[i].numpy()[...] = split[data_i:data_i+1+block_size]
94 | if device == 'cuda':
95 | # Copy the incoming data directly from pinned memory into cuda mem
96 | xy_cuda.copy_(xy, non_blocking=True)
97 | # Slice out x and y
98 | x = xy_cuda[:, :-1]
99 | y = xy_cuda[:, 1:]
100 | else:
101 | raise NotImplementedError
102 | #x, y = x.to(device), y.to(device)
103 | return x, y
104 |
105 |
106 | @torch.no_grad()
107 | def evaluate_splits(model, splits, split_names, num_evals, batch_size, device):
108 | model.eval()
109 | split_losses = {}
110 | for split, split_name in zip(splits, split_names):
111 | losses = torch.zeros(num_evals)
112 | for i in range(num_evals):
113 | x, y = load_batch(split, batch_size, device)
114 |
115 | with amp_ctx:
116 | _, loss = model(x, y)
117 |
118 | losses[i] = loss.item()
119 |
120 | split_losses[split_name] = losses.mean()
121 | model.train()
122 | return split_losses
123 |
124 |
125 | if __name__ == "__main__":
126 |
127 | # Load data & tokenizer
128 | data_dir = os.path.join(datasets_dir, dataset)
129 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
130 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
131 | tokenizer = AutoTokenizer.from_pretrained('gpt2')
132 |
133 | # Load model
134 | palm_config = PaLMConfig(n_embed=768,
135 | n_head=6,
136 | dropout=0.1,
137 | vocab_size=50304, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
138 | n_layer=4)
139 |
140 | model = PaLM(palm_config).to(device)
141 | num_params = num_model_params(model)
142 | print(f"Initializing PaLM model with {num_params} params")
143 |
144 | # Initalize logging
145 | if wandb_logging_enabled:
146 | import wandb
147 | wandb.init(project=wandb_project_name, name=run_name, config=palm_config)
148 |
149 | # Disable weight decay for unwanted modules
150 | # PaLM model has no bias so only include weight params
151 | # Exclude ln_vocab.weight as it's weight is tied to the word embedding weights
152 | no_decay_modules = [LayerNorm, torch.nn.Embedding]
153 | decay_modules = [torch.nn.Linear]
154 | param_dict = {pn: p for pn, p in model.named_parameters()}
155 | no_decay_params = [f"{n}.weight" for n, m in model.named_modules() if any(
156 | nd for nd in decay_modules if isinstance(m, nd))]
157 | decay_params = [f"{n}.weight" for n, m in model.named_modules() if any(
158 | nd for nd in no_decay_modules if isinstance(m, nd))]
159 |
160 | optimizer_grouped_parameters = [
161 | {'params': [param_dict[p] for p in decay_params], 'weight_decay': weight_decay},
162 | {'params': [param_dict[p] for p in no_decay_params if p != 'ln_vocab.weight'], 'weight_decay': 0.0}
163 | ]
164 |
165 | # Model uses betas=(0.9, (1-step**-0.8)), but I've found default works better w/ AdamW
166 | optim = AdamW(optimizer_grouped_parameters,
167 | lr=learning_rate,
168 | fused=True if device == 'cuda' else False)
169 |
170 | model = torch.compile(model)
171 |
172 | # Training loop
173 | for step in tqdm(range(start_iter, max_iters + 1)):
174 |
175 | update_optim(optim, step)
176 |
177 | if step % eval_freq == 0 and step != 0:
178 | losses = evaluate_splits(model,
179 | splits=[train_data, val_data],
180 | split_names=['train', 'val'],
181 | num_evals=num_evals,
182 | batch_size=batch_size,
183 | device=device)
184 | print(f"Step {step}: Training loss={losses['train']}")
185 |
186 | if wandb_logging_enabled:
187 | wandb.log({
188 | "iter": step,
189 | "train/loss": losses['train'],
190 | "val/loss": losses['val'],
191 | "lr": get_lr(step)
192 | })
193 |
194 | if losses['val'] < best_val_loss:
195 | best_val_loss = losses['val']
196 | checkpoint = {
197 | 'model': model.state_dict(),
198 | 'optimizer': optim.state_dict(),
199 | 'model_args': palm_config,
200 | 'step': step,
201 | 'best_val_loss': best_val_loss,
202 | 'config': config,
203 | }
204 | print(f"Saving checkpoint, step:{step}, val_loss:{best_val_loss}")
205 | check_out = f"checkpoints/{run_name}"
206 |
207 | if not os.path.exists(check_out):
208 | os.mkdir(check_out)
209 | torch.save(checkpoint, os.path.join(check_out, "ckpt.pt"))
210 |
211 | for micro_step in range(grad_accumulation_steps):
212 | x, y = load_batch(train_data, batch_size, device=device)
213 |
214 | with amp_ctx:
215 | logits, loss = model(x, y)
216 |
217 | scaler.scale(loss).backward()
218 |
219 | # Grad clipping for all model sizes
220 | scaler.unscale_(optim)
221 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
222 |
223 | scaler.step(optim)
224 | scaler.update()
225 |
226 | optim.zero_grad()
227 |
--------------------------------------------------------------------------------