├── LICENSE ├── README.md ├── makemore.py └── names.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # makemore 3 | 4 | makemore takes one text file as input, where each line is assumed to be one training thing, and generates more things like it. Under the hood, it is an autoregressive character-level language model, with a wide choice of models from bigrams all the way to a Transformer (exactly as seen in GPT). For example, we can feed it a database of names, and makemore will generate cool baby name ideas that all sound name-like, but are not already existing names. Or if we feed it a database of company names then we can generate new ideas for naming a company. Or we can just feed it valid Scrabble words and generate English-like babble. 5 | 6 | This is not meant to be a heavyweight library with a billion switches and knobs. It is one hackable file, and is mostly intended for educational purposes. [PyTorch](https://pytorch.org) is the only requirement. 7 | 8 | The current implementation follows a few key papers: 9 | 10 | - Bigram (one character predicts the next one with a lookup table of counts) 11 | - MLP, following [Bengio et al. 2003](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf) 12 | - CNN, following [DeepMind WaveNet 2016](https://arxiv.org/abs/1609.03499) (in progress...) 13 | - RNN, following [Mikolov et al. 2010](https://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf) 14 | - LSTM, following [Graves et al. 2014](https://arxiv.org/abs/1308.0850) 15 | - GRU, following [Kyunghyun Cho et al. 2014](https://arxiv.org/abs/1409.1259) 16 | - Transformer, following [Vaswani et al. 2017](https://arxiv.org/abs/1706.03762) 17 | 18 | ### Usage 19 | 20 | The included `names.txt` dataset, as an example, has the most common 32K names taken from [ssa.gov](https://www.ssa.gov/oact/babynames/) for the year 2018. It looks like: 21 | 22 | ``` 23 | emma 24 | olivia 25 | ava 26 | isabella 27 | sophia 28 | charlotte 29 | ...
30 | ``` 31 | 32 | Let's point the script at it: 33 | 34 | ```bash 35 | $ python makemore.py -i names.txt -o names 36 | ``` 37 | 38 | Training progress, logs, and the model will all be saved to the working directory `names`. The default model is a super tiny 200K-param transformer. Many more training configurations are available - see the argparse and read the code. Training does not require any special hardware; it runs on my MacBook Air and will run on anything else, but if you have a GPU then training will fly faster. As training progresses the script will print some samples throughout. However, if you'd like to sample manually, you can use the `--sample-only` flag, e.g. in a separate terminal do: 39 | 40 | ```bash 41 | $ python makemore.py -i names.txt -o names --sample-only 42 | ``` 43 | 44 | This will load the best model so far and print more samples on demand. Here are some unique baby names that eventually get generated from the current default settings (test logprob of ~1.92, though much lower logprobs are achievable with some hyperparameter tuning): 45 | 46 | ``` 47 | dontell 48 | khylum 49 | camatena 50 | aeriline 51 | najlah 52 | sherrith 53 | ryel 54 | irmi 55 | taislee 56 | mortaz 57 | akarli 58 | maxfelynn 59 | biolett 60 | zendy 61 | laisa 62 | halliliana 63 | goralynn 64 | brodynn 65 | romima 66 | chiyomin 67 | loghlyn 68 | melichae 69 | mahmed 70 | irot 71 | helicha 72 | besdy 73 | ebokun 74 | lucianno 75 | ``` 76 | 77 | Have fun! 78 | 79 | ### License 80 | 81 | MIT 82 | -------------------------------------------------------------------------------- /makemore.py: -------------------------------------------------------------------------------- 1 | """ 2 | you give this script some words (one per line) and it will generate more things like it. 3 | uses super state of the art Transformer AI tech 4 | this code is intended to be super hackable. tune it to your needs. 5 | 6 | Changes from minGPT: 7 | - I removed the from_pretrained function where we init with GPT2 weights 8 | - I removed dropout layers because the models we train here are small, 9 | it's not necessary to understand at this stage and at this scale. 10 | - I removed weight decay and all of the complexity around what parameters are 11 | and are not weight decayed. I don't believe this should make a massive 12 | difference at the scale that we operate on here. 13 | """ 14 | 15 | import os 16 | import sys 17 | import time 18 | import math 19 | import argparse 20 | from dataclasses import dataclass 21 | from typing import List 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch.nn import functional as F 26 | from torch.utils.data import Dataset 27 | from torch.utils.data.dataloader import DataLoader 28 | from torch.utils.tensorboard import SummaryWriter 29 | 30 | # ----------------------------------------------------------------------------- 31 | 32 | @dataclass 33 | class ModelConfig: 34 | block_size: int = None # length of the input sequences of integers 35 | vocab_size: int = None # the input integers are in range [0 .. vocab_size -1] 36 | # parameters below control the sizes of each model slightly differently 37 | n_layer: int = 4 38 | n_embd: int = 64 39 | n_embd2: int = 64 40 | n_head: int = 4 41 | 42 | # ----------------------------------------------------------------------------- 43 | # Transformer Language Model (*exactly* as used in GPT-2) 44 | 45 | class NewGELU(nn.Module): 46 | """ 47 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
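    Concretely, this is the tanh approximation computed in forward() below:
        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))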
48 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 49 | """ 50 | def forward(self, x): 51 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 52 | 53 | class CausalSelfAttention(nn.Module): 54 | """ 55 | A vanilla multi-head masked self-attention layer with a projection at the end. 56 | It is possible to use torch.nn.MultiheadAttention here but I am including an 57 | explicit implementation here to show that there is nothing too scary here. 58 | """ 59 | 60 | def __init__(self, config): 61 | super().__init__() 62 | assert config.n_embd % config.n_head == 0 63 | # key, query, value projections for all heads, but in a batch 64 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 65 | # output projection 66 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 67 | # causal mask to ensure that attention is only applied to the left in the input sequence 68 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 69 | .view(1, 1, config.block_size, config.block_size)) 70 | self.n_head = config.n_head 71 | self.n_embd = config.n_embd 72 | 73 | def forward(self, x): 74 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 75 | 76 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 77 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 78 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 79 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 80 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 81 | 82 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 83 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 84 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 85 | att = F.softmax(att, dim=-1) 86 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 87 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 88 | 89 | # output projection 90 | y = self.c_proj(y) 91 | return y 92 | 93 | class Block(nn.Module): 94 | """ an unassuming Transformer block """ 95 | 96 | def __init__(self, config): 97 | super().__init__() 98 | self.ln_1 = nn.LayerNorm(config.n_embd) 99 | self.attn = CausalSelfAttention(config) 100 | self.ln_2 = nn.LayerNorm(config.n_embd) 101 | self.mlp = nn.ModuleDict(dict( 102 | c_fc = nn.Linear(config.n_embd, 4 * config.n_embd), 103 | c_proj = nn.Linear(4 * config.n_embd, config.n_embd), 104 | act = NewGELU(), 105 | )) 106 | m = self.mlp 107 | self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x))) # MLP forward 108 | 109 | def forward(self, x): 110 | x = x + self.attn(self.ln_1(x)) 111 | x = x + self.mlpf(self.ln_2(x)) 112 | return x 113 | 114 | class Transformer(nn.Module): 115 | """ Transformer Language Model, exactly as seen in GPT-2 """ 116 | 117 | def __init__(self, config): 118 | super().__init__() 119 | self.block_size = config.block_size 120 | 121 | self.transformer = nn.ModuleDict(dict( 122 | wte = nn.Embedding(config.vocab_size, config.n_embd), 123 | wpe = nn.Embedding(config.block_size, config.n_embd), 124 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 125 | ln_f = nn.LayerNorm(config.n_embd), 126 | )) 127 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 128 | 129 | # report number of parameters (note we don't count the 
decoder parameters in lm_head) 130 | n_params = sum(p.numel() for p in self.transformer.parameters()) 131 | print("number of parameters: %.2fM" % (n_params/1e6,)) 132 | 133 | def get_block_size(self): 134 | return self.block_size 135 | 136 | def forward(self, idx, targets=None): 137 | device = idx.device 138 | b, t = idx.size() 139 | assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" 140 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 141 | 142 | # forward the GPT model itself 143 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 144 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 145 | x = tok_emb + pos_emb 146 | for block in self.transformer.h: 147 | x = block(x) 148 | x = self.transformer.ln_f(x) 149 | logits = self.lm_head(x) 150 | 151 | # if we are given some desired targets also calculate the loss 152 | loss = None 153 | if targets is not None: 154 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 155 | 156 | return logits, loss 157 | 158 | # ----------------------------------------------------------------------------- 159 | # Bag of Words (BoW) language model 160 | 161 | class CausalBoW(nn.Module): 162 | """ 163 | Causal bag of words. Averages the preceding elements and looks suspiciously like 164 | a CausalAttention module you'd find in a transformer, for no apparent reason at all ;) 165 | """ 166 | def __init__(self, config): 167 | super().__init__() 168 | 169 | # used to mask out vectors and preserve autoregressive property 170 | self.block_size = config.block_size 171 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 172 | .view(1, config.block_size, config.block_size)) 173 | 174 | def forward(self, x): 175 | B, T, C = x.size() # batch size, sequence length, n_embd 176 | 177 | # do the weighted average of all preceeding token features 178 | att = torch.zeros((B, T, T), device=x.device) 179 | att = att.masked_fill(self.bias[:,:T,:T] == 0, float('-inf')) 180 | att = F.softmax(att, dim=-1) 181 | y = att @ x # (B, T, T) x (B, T, C) -> (B, T, C) 182 | 183 | return y 184 | 185 | class BoWBlock(nn.Module): 186 | """ collects BoW features and adds an MLP """ 187 | 188 | def __init__(self, config): 189 | super().__init__() 190 | 191 | # Causal BoW module 192 | self.cbow = CausalBoW(config) 193 | # MLP assembler 194 | self.mlp = nn.ModuleDict(dict( 195 | c_fc = nn.Linear(config.n_embd, config.n_embd2), 196 | c_proj = nn.Linear(config.n_embd2, config.n_embd), 197 | )) 198 | m = self.mlp 199 | self.mlpf = lambda x: m.c_proj(F.tanh(m.c_fc(x))) # MLP forward 200 | 201 | def forward(self, x): 202 | x = x + self.cbow(x) 203 | x = x + self.mlpf(x) 204 | return x 205 | 206 | class BoW(nn.Module): 207 | """ 208 | takes the previous block_size tokens, encodes them with a lookup table, 209 | also encodes their positions with lookup table, then averages all of those 210 | embeddings up and uses that to predict the next token. 
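    The average is causal: the prediction at position t only uses embeddings at positions <= t,
    which the CausalBoW module above implements as uniform attention under a lower-triangular mask.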
211 | """ 212 | 213 | def __init__(self, config): 214 | super().__init__() 215 | self.block_size = config.block_size 216 | self.vocab_size = config.vocab_size 217 | # token embedding 218 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 219 | # position embedding 220 | self.wpe = nn.Embedding(config.block_size, config.n_embd) 221 | # context block 222 | self.context_block = BoWBlock(config) 223 | # language model head decoder layer 224 | self.lm_head = nn.Linear(config.n_embd, self.vocab_size) 225 | 226 | def get_block_size(self): 227 | return self.block_size 228 | 229 | def forward(self, idx, targets=None): 230 | 231 | device = idx.device 232 | b, t = idx.size() 233 | assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" 234 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 235 | 236 | # forward the token and position embedding layers 237 | tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd) 238 | pos_emb = self.wpe(pos) # position embeddings of shape (1, t, n_embd) 239 | # add and run through the decoder MLP 240 | x = tok_emb + pos_emb 241 | # run the bag of words context module 242 | x = self.context_block(x) 243 | # decode to next token probability 244 | logits = self.lm_head(x) 245 | 246 | # if we are given some desired targets also calculate the loss 247 | loss = None 248 | if targets is not None: 249 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 250 | 251 | return logits, loss 252 | 253 | # ----------------------------------------------------------------------------- 254 | """ 255 | Recurrent Neural Net language model: either a vanilla RNN recurrence or a GRU. 256 | Did not implement an LSTM because its API is a bit more annoying as it has 257 | both a hidden state and a cell state, but it's very similar to GRU and in 258 | practice works just as well. 259 | """ 260 | 261 | class RNNCell(nn.Module): 262 | """ 263 | the job of a 'Cell' is to: 264 | take input at current time step x_{t} and the hidden state at the 265 | previous time step h_{t-1} and return the resulting hidden state 266 | h_{t} at the current timestep 267 | """ 268 | def __init__(self, config): 269 | super().__init__() 270 | self.xh_to_h = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2) 271 | 272 | def forward(self, xt, hprev): 273 | xh = torch.cat([xt, hprev], dim=1) 274 | ht = F.tanh(self.xh_to_h(xh)) 275 | return ht 276 | 277 | class GRUCell(nn.Module): 278 | """ 279 | same job as RNN cell, but a bit more complicated recurrence formula 280 | that makes the GRU more expressive and easier to optimize. 
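    In the notation of forward() below, with [a, b] denoting concatenation along channels:
        r    = sigmoid(xh_to_r([x_t, h_{t-1}]))      # reset gate
        hbar = tanh(xh_to_hbar([x_t, r * h_{t-1}]))  # candidate hidden state
        z    = sigmoid(xh_to_z([x_t, h_{t-1}]))      # update gate
        h_t  = (1 - z) * h_{t-1} + z * hbar          # blend old state and candidate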
281 | """ 282 | def __init__(self, config): 283 | super().__init__() 284 | # input, forget, output, gate 285 | self.xh_to_z = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2) 286 | self.xh_to_r = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2) 287 | self.xh_to_hbar = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2) 288 | 289 | def forward(self, xt, hprev): 290 | # first use the reset gate to wipe some channels of the hidden state to zero 291 | xh = torch.cat([xt, hprev], dim=1) 292 | r = F.sigmoid(self.xh_to_r(xh)) 293 | hprev_reset = r * hprev 294 | # calculate the candidate new hidden state hbar 295 | xhr = torch.cat([xt, hprev_reset], dim=1) 296 | hbar = F.tanh(self.xh_to_hbar(xhr)) 297 | # calculate the switch gate that determines if each channel should be updated at all 298 | z = F.sigmoid(self.xh_to_z(xh)) 299 | # blend the previous hidden state and the new candidate hidden state 300 | ht = (1 - z) * hprev + z * hbar 301 | return ht 302 | 303 | class RNN(nn.Module): 304 | 305 | def __init__(self, config, cell_type): 306 | super().__init__() 307 | self.block_size = config.block_size 308 | self.vocab_size = config.vocab_size 309 | self.start = nn.Parameter(torch.zeros(1, config.n_embd2)) # the starting hidden state 310 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) # token embeddings table 311 | if cell_type == 'rnn': 312 | self.cell = RNNCell(config) 313 | elif cell_type == 'gru': 314 | self.cell = GRUCell(config) 315 | self.lm_head = nn.Linear(config.n_embd2, self.vocab_size) 316 | 317 | def get_block_size(self): 318 | return self.block_size 319 | 320 | def forward(self, idx, targets=None): 321 | device = idx.device 322 | b, t = idx.size() 323 | 324 | # embed all the integers up front and all at once for efficiency 325 | emb = self.wte(idx) # (b, t, n_embd) 326 | 327 | # sequentially iterate over the inputs and update the RNN state each tick 328 | hprev = self.start.expand((b, -1)) # expand out the batch dimension 329 | hiddens = [] 330 | for i in range(t): 331 | xt = emb[:, i, :] # (b, n_embd) 332 | ht = self.cell(xt, hprev) # (b, n_embd2) 333 | hprev = ht 334 | hiddens.append(ht) 335 | 336 | # decode the outputs 337 | hidden = torch.stack(hiddens, 1) # (b, t, n_embd2) 338 | logits = self.lm_head(hidden) 339 | 340 | # if we are given some desired targets also calculate the loss 341 | loss = None 342 | if targets is not None: 343 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 344 | 345 | return logits, loss 346 | 347 | # ----------------------------------------------------------------------------- 348 | # MLP language model 349 | 350 | class MLP(nn.Module): 351 | """ 352 | takes the previous block_size tokens, encodes them with a lookup table, 353 | concatenates the vectors and predicts the next token with an MLP. 354 | 355 | Reference: 356 | Bengio et al. 
2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf 357 | """ 358 | 359 | def __init__(self, config): 360 | super().__init__() 361 | self.block_size = config.block_size 362 | self.vocab_size = config.vocab_size 363 | self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd) # token embeddings table 364 | # +1 in the line above for a special token that gets inserted if encoding a token 365 | # before the beginning of the input sequence 366 | self.mlp = nn.Sequential( 367 | nn.Linear(self.block_size * config.n_embd, config.n_embd2), 368 | nn.Tanh(), 369 | nn.Linear(config.n_embd2, self.vocab_size) 370 | ) 371 | 372 | def get_block_size(self): 373 | return self.block_size 374 | 375 | def forward(self, idx, targets=None): 376 | 377 | # gather the word embeddings of the previous 3 words 378 | embs = [] 379 | for k in range(self.block_size): 380 | tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd) 381 | idx = torch.roll(idx, 1, 1) 382 | idx[:, 0] = self.vocab_size # special token 383 | embs.append(tok_emb) 384 | 385 | # concat all of the embeddings together and pass through an MLP 386 | x = torch.cat(embs, -1) # (b, t, n_embd * block_size) 387 | logits = self.mlp(x) 388 | 389 | # if we are given some desired targets also calculate the loss 390 | loss = None 391 | if targets is not None: 392 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 393 | 394 | return logits, loss 395 | 396 | # ----------------------------------------------------------------------------- 397 | # Bigram language model 398 | 399 | class Bigram(nn.Module): 400 | """ 401 | Bigram Language Model 'neural net', simply a lookup table of logits for the 402 | next character given a previous character. 403 | """ 404 | 405 | def __init__(self, config): 406 | super().__init__() 407 | n = config.vocab_size 408 | self.logits = nn.Parameter(torch.zeros((n, n))) 409 | 410 | def get_block_size(self): 411 | return 1 # this model only needs one previous character to predict the next 412 | 413 | def forward(self, idx, targets=None): 414 | 415 | # 'forward pass', lol 416 | logits = self.logits[idx] 417 | 418 | # if we are given some desired targets also calculate the loss 419 | loss = None 420 | if targets is not None: 421 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 422 | 423 | return logits, loss 424 | 425 | # ----------------------------------------------------------------------------- 426 | # helper functions for evaluating and sampling from the model 427 | 428 | @torch.no_grad() 429 | def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): 430 | """ 431 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 432 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 433 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
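    A minimal usage sketch, assuming `model` is any of the trained models above and `idx` is a
    (b, 1) LongTensor of start tokens (index 0):
        idx = generate(model, idx, max_new_tokens=16, do_sample=True, top_k=40)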
434 | """ 435 | block_size = model.get_block_size() 436 | for _ in range(max_new_tokens): 437 | # if the sequence context is growing too long we must crop it at block_size 438 | idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:] 439 | # forward the model to get the logits for the index in the sequence 440 | logits, _ = model(idx_cond) 441 | # pluck the logits at the final step and scale by desired temperature 442 | logits = logits[:, -1, :] / temperature 443 | # optionally crop the logits to only the top k options 444 | if top_k is not None: 445 | v, _ = torch.topk(logits, top_k) 446 | logits[logits < v[:, [-1]]] = -float('Inf') 447 | # apply softmax to convert logits to (normalized) probabilities 448 | probs = F.softmax(logits, dim=-1) 449 | # either sample from the distribution or take the most likely element 450 | if do_sample: 451 | idx_next = torch.multinomial(probs, num_samples=1) 452 | else: 453 | _, idx_next = torch.topk(probs, k=1, dim=-1) 454 | # append sampled index to the running sequence and continue 455 | idx = torch.cat((idx, idx_next), dim=1) 456 | 457 | return idx 458 | 459 | def print_samples(num=10): 460 | """ samples from the model and pretty prints the decoded samples """ 461 | X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device) 462 | top_k = args.top_k if args.top_k != -1 else None 463 | steps = train_dataset.get_output_length() - 1 # -1 because we already start with token (index 0) 464 | X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu') 465 | train_samples, test_samples, new_samples = [], [], [] 466 | for i in range(X_samp.size(0)): 467 | # get the i'th row of sampled integers, as python list 468 | row = X_samp[i, 1:].tolist() # note: we need to crop out the first token 469 | # token 0 is the token, so we crop the output sequence at that point 470 | crop_index = row.index(0) if 0 in row else len(row) 471 | row = row[:crop_index] 472 | word_samp = train_dataset.decode(row) 473 | # separately track samples that we have and have not seen before 474 | if train_dataset.contains(word_samp): 475 | train_samples.append(word_samp) 476 | elif test_dataset.contains(word_samp): 477 | test_samples.append(word_samp) 478 | else: 479 | new_samples.append(word_samp) 480 | print('-'*80) 481 | for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]: 482 | print(f"{len(lst)} samples that are {desc}:") 483 | for word in lst: 484 | print(word) 485 | print('-'*80) 486 | 487 | @torch.inference_mode() 488 | def evaluate(model, dataset, batch_size=50, max_batches=None): 489 | model.eval() 490 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0) 491 | losses = [] 492 | for i, batch in enumerate(loader): 493 | batch = [t.to(args.device) for t in batch] 494 | X, Y = batch 495 | logits, loss = model(X, Y) 496 | losses.append(loss.item()) 497 | if max_batches is not None and i >= max_batches: 498 | break 499 | mean_loss = torch.tensor(losses).mean().item() 500 | model.train() # reset model back to training mode 501 | return mean_loss 502 | 503 | # ----------------------------------------------------------------------------- 504 | # helper functions for creating the training and test Datasets that emit words 505 | 506 | class CharDataset(Dataset): 507 | 508 | def __init__(self, words, chars, max_word_length): 509 | self.words = words 510 | self.chars = chars 511 | self.max_word_length = max_word_length 512 | self.stoi = {ch:i+1 for i,ch in enumerate(chars)} 513 | 
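        # note: index 0 is reserved for the special start/stop token, hence the i+1 offset above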
self.itos = {i:s for s,i in self.stoi.items()} # inverse mapping 514 | 515 | def __len__(self): 516 | return len(self.words) 517 | 518 | def contains(self, word): 519 | return word in self.words 520 | 521 | def get_vocab_size(self): 522 | return len(self.chars) + 1 # all the possible characters and special 0 token 523 | 524 | def get_output_length(self): 525 | return self.max_word_length + 1 # token followed by words 526 | 527 | def encode(self, word): 528 | ix = torch.tensor([self.stoi[w] for w in word], dtype=torch.long) 529 | return ix 530 | 531 | def decode(self, ix): 532 | word = ''.join(self.itos[i] for i in ix) 533 | return word 534 | 535 | def __getitem__(self, idx): 536 | word = self.words[idx] 537 | ix = self.encode(word) 538 | x = torch.zeros(self.max_word_length + 1, dtype=torch.long) 539 | y = torch.zeros(self.max_word_length + 1, dtype=torch.long) 540 | x[1:1+len(ix)] = ix 541 | y[:len(ix)] = ix 542 | y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations 543 | return x, y 544 | 545 | def create_datasets(input_file): 546 | 547 | # preprocessing of the input text file 548 | with open(input_file, 'r') as f: 549 | data = f.read() 550 | words = data.splitlines() 551 | words = [w.strip() for w in words] # get rid of any leading or trailing white space 552 | words = [w for w in words if w] # get rid of any empty strings 553 | chars = sorted(list(set(''.join(words)))) # all the possible characters 554 | max_word_length = max(len(w) for w in words) 555 | print(f"number of examples in the dataset: {len(words)}") 556 | print(f"max word length: {max_word_length}") 557 | print(f"number of unique characters in the vocabulary: {len(chars)}") 558 | print("vocabulary:") 559 | print(''.join(chars)) 560 | 561 | # partition the input data into a training and the test set 562 | test_set_size = min(1000, int(len(words) * 0.1)) # 10% of the training set, or up to 1000 examples 563 | rp = torch.randperm(len(words)).tolist() 564 | train_words = [words[i] for i in rp[:-test_set_size]] 565 | test_words = [words[i] for i in rp[-test_set_size:]] 566 | print(f"split up the dataset into {len(train_words)} training examples and {len(test_words)} test examples") 567 | 568 | # wrap in dataset objects 569 | train_dataset = CharDataset(train_words, chars, max_word_length) 570 | test_dataset = CharDataset(test_words, chars, max_word_length) 571 | 572 | return train_dataset, test_dataset 573 | 574 | class InfiniteDataLoader: 575 | """ 576 | this is really hacky and I'm not proud of it, but there doesn't seem to be 577 | a better way in PyTorch to just create an infinite dataloader? 578 | """ 579 | 580 | def __init__(self, dataset, **kwargs): 581 | train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10)) 582 | self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs) 583 | self.data_iter = iter(self.train_loader) 584 | 585 | def next(self): 586 | try: 587 | batch = next(self.data_iter) 588 | except StopIteration: # this will technically only happen after 1e10 samples... (i.e. 
basically never) 589 | self.data_iter = iter(self.train_loader) 590 | batch = next(self.data_iter) 591 | return batch 592 | 593 | # ----------------------------------------------------------------------------- 594 | if __name__ == '__main__': 595 | 596 | # parse command line args 597 | parser = argparse.ArgumentParser(description="Make More") 598 | # system/input/output 599 | parser.add_argument('--input-file', '-i', type=str, default='names.txt', help="input file with things one per line") 600 | parser.add_argument('--work-dir', '-o', type=str, default='out', help="output working directory") 601 | parser.add_argument('--resume', action='store_true', help="when this flag is used, we will resume optimization from existing model in the workdir") 602 | parser.add_argument('--sample-only', action='store_true', help="just sample from the model and quit, don't train") 603 | parser.add_argument('--num-workers', '-n', type=int, default=4, help="number of data workers for both train/test") 604 | parser.add_argument('--max-steps', type=int, default=-1, help="max number of optimization steps to run for, or -1 for infinite.") 605 | parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, examples: cpu|cuda|cuda:2|mps") 606 | parser.add_argument('--seed', type=int, default=3407, help="seed") 607 | # sampling 608 | parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k") 609 | # model 610 | parser.add_argument('--type', type=str, default='transformer', help="model class type to use, bigram|mlp|rnn|gru|bow|transformer") 611 | parser.add_argument('--n-layer', type=int, default=4, help="number of layers") 612 | parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)") 613 | parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model") 614 | parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model") 615 | # optimization 616 | parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization") 617 | parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate") 618 | parser.add_argument('--weight-decay', '-w', type=float, default=0.01, help="weight decay") 619 | args = parser.parse_args() 620 | print(vars(args)) 621 | 622 | # system inits 623 | torch.manual_seed(args.seed) 624 | torch.cuda.manual_seed_all(args.seed) 625 | os.makedirs(args.work_dir, exist_ok=True) 626 | writer = SummaryWriter(log_dir=args.work_dir) 627 | 628 | # init datasets 629 | train_dataset, test_dataset = create_datasets(args.input_file) 630 | vocab_size = train_dataset.get_vocab_size() 631 | block_size = train_dataset.get_output_length() 632 | print(f"dataset determined that: {vocab_size=}, {block_size=}") 633 | 634 | # init model 635 | config = ModelConfig(vocab_size=vocab_size, block_size=block_size, 636 | n_layer=args.n_layer, n_head=args.n_head, 637 | n_embd=args.n_embd, n_embd2=args.n_embd2) 638 | if args.type == 'transformer': 639 | model = Transformer(config) 640 | elif args.type == 'bigram': 641 | model = Bigram(config) 642 | elif args.type == 'mlp': 643 | model = MLP(config) 644 | elif args.type == 'rnn': 645 | model = RNN(config, cell_type='rnn') 646 | elif args.type == 'gru': 647 | model = RNN(config, cell_type='gru') 648 | elif args.type == 'bow': 649 | model = BoW(config) 650 | else: 651 | raise ValueError(f'model type {args.type} is not recognized') 652 | 
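    # note: every model class above exposes the same forward(idx, targets) -> (logits, loss)
    # interface and a get_block_size() method, which is what lets the rest of this script
    # treat the different model types interchangeably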
model.to(args.device) 653 | print(f"model #params: {sum(p.numel() for p in model.parameters())}") 654 | if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming 655 | print("resuming from existing model in the workdir") 656 | model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt'))) 657 | if args.sample_only: 658 | print_samples(num=50) 659 | sys.exit() 660 | 661 | # init optimizer 662 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(0.9, 0.99), eps=1e-8) 663 | 664 | # init dataloader 665 | batch_loader = InfiniteDataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers) 666 | 667 | # training loop 668 | best_loss = None 669 | step = 0 670 | while True: 671 | 672 | t0 = time.time() 673 | 674 | # get the next batch, ship to device, and unpack it to input and target 675 | batch = batch_loader.next() 676 | batch = [t.to(args.device) for t in batch] 677 | X, Y = batch 678 | 679 | # feed into the model 680 | logits, loss = model(X, Y) 681 | 682 | # calculate the gradient, update the weights 683 | model.zero_grad(set_to_none=True) 684 | loss.backward() 685 | optimizer.step() 686 | 687 | # wait for all CUDA work on the GPU to finish then calculate iteration time taken 688 | if args.device.startswith('cuda'): 689 | torch.cuda.synchronize() 690 | t1 = time.time() 691 | 692 | # logging 693 | if step % 10 == 0: 694 | print(f"step {step} | loss {loss.item():.4f} | step time {(t1-t0)*1000:.2f}ms") 695 | 696 | # evaluate the model 697 | if step > 0 and step % 500 == 0: 698 | train_loss = evaluate(model, train_dataset, batch_size=100, max_batches=10) 699 | test_loss = evaluate(model, test_dataset, batch_size=100, max_batches=10) 700 | writer.add_scalar("Loss/train", train_loss, step) 701 | writer.add_scalar("Loss/test", test_loss, step) 702 | writer.flush() 703 | print(f"step {step} train loss: {train_loss} test loss: {test_loss}") 704 | # save the model to disk if it has improved 705 | if best_loss is None or test_loss < best_loss: 706 | out_path = os.path.join(args.work_dir, "model.pt") 707 | print(f"test loss {test_loss} is the best so far, saving model to {out_path}") 708 | torch.save(model.state_dict(), out_path) 709 | best_loss = test_loss 710 | 711 | # sample from the model 712 | if step > 0 and step % 200 == 0: 713 | print_samples(num=10) 714 | 715 | step += 1 716 | # termination conditions 717 | if args.max_steps >= 0 and step >= args.max_steps: 718 | break 719 | 720 | --------------------------------------------------------------------------------
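If you would rather poke at a trained model from Python instead of through the command line, here is a minimal sketch (not part of the repo itself) of loading a checkpoint and sampling with the classes defined above. It assumes a previous run such as `python makemore.py -i names.txt -o names` already wrote `names/model.pt` with the default transformer settings; importing `makemore` is safe because the training code is guarded by `if __name__ == '__main__'`.

```python
# hypothetical example: load a trained makemore checkpoint and sample from it
# programmatically, reusing the classes and helpers defined in makemore.py.
# assumes a default-config transformer was already trained into the working
# directory 'names' (i.e. names/model.pt exists).
import torch
from makemore import ModelConfig, Transformer, create_datasets, generate

# rebuild the character vocabulary from the same input file used for training
train_dataset, test_dataset = create_datasets('names.txt')
config = ModelConfig(vocab_size=train_dataset.get_vocab_size(),
                     block_size=train_dataset.get_output_length(),
                     n_layer=4, n_head=4, n_embd=64, n_embd2=64)

model = Transformer(config)
model.load_state_dict(torch.load('names/model.pt', map_location='cpu'))
model.eval()

# start every row with the special token (index 0) and sample up to block_size-1 steps
X_init = torch.zeros(10, 1, dtype=torch.long)
steps = train_dataset.get_output_length() - 1
X_samp = generate(model, X_init, steps, do_sample=True)
for row in X_samp[:, 1:].tolist():                  # crop the leading start token
    row = row[:row.index(0)] if 0 in row else row   # cut at the stop token
    print(train_dataset.decode(row))
```

The same pattern works for the other model types (Bigram, MLP, RNN/GRU, BoW); only the class constructed from the config changes, since they all share the same forward interface.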