├── requirements.txt
├── license
│   ├── README.md
│   ├── LICENSE.txt
│   ├── nanoGPT_LICENSE.txt
│   └── llama_LICENSE.txt
├── README.md
├── rope.py
├── sample.py
├── reproduce
│   ├── jobs.py
│   ├── README.md
│   └── plots.py
├── prepare.py
├── data.py
├── megabyte.py
├── spacebyte.py
├── util.py
├── transformer.py
├── train.py
└── spacebyte_figure.svg

/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | zstandard
3 | fire
4 | numpy
5 | torch>=2.0.0
6 | wandb
7 | datasets
8 | tiktoken
9 | sentencepiece
10 | flash-attn>=2.3.0
11 | 
--------------------------------------------------------------------------------
/license/README.md:
--------------------------------------------------------------------------------
1 | This project was initiated from Andrej Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT).
2 | See `nanoGPT_LICENSE.txt` for its license.
3 | 
4 | `rope.py` is adapted from Meta's [Llama 2 code](https://github.com/facebookresearch/llama/blob/main/llama/model.py).
5 | See `llama_LICENSE.txt` for its license.
6 | Everything else is licensed under an MIT license; see `LICENSE.txt`.
7 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpaceByte
2 | 
3 | This is the implementation of SpaceByte used in [SpaceByte: Towards Deleting Tokenization from Large Language Modeling](https://arxiv.org/abs/2404.14408).
4 | SpaceByte is a tokenization-free large language model (LLM) architecture that uses multiscale modeling at the byte and word levels to match the performance of standard LLM architectures that rely on tokenization.
5 | See the `reproduce` directory for instructions on reproducing the results in our paper.
6 | 
7 | ![SpaceByte architecture schematic](spacebyte_figure.svg)
8 | 
--------------------------------------------------------------------------------
/license/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Kevin Slagle
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /license/nanoGPT_LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # adapted from 4 | # https://github.com/facebookresearch/llama/blob/main/llama/model.py 5 | 6 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 7 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 8 | 9 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) # type: ignore 10 | freqs = torch.outer(t, freqs) # type: ignore 11 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 12 | return freqs_cis 13 | 14 | def apply_rotary_emb( 15 | xq: torch.Tensor, 16 | xk: torch.Tensor, 17 | freqs_cis: torch.Tensor, 18 | ): 19 | device = xq.device 20 | if not torch.cuda.is_available(): 21 | xq = xq.to('cpu') 22 | xk = xk.to('cpu') 23 | freqs_cis = freqs_cis.to('cpu') 24 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 25 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 26 | freqs_cis = reshape_for_rotary_broadcast(freqs_cis, xq_) 27 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 28 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 29 | return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device) 30 | 31 | def reshape_for_rotary_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 32 | ndim = x.ndim 33 | assert 0 <= 1 < ndim 34 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 35 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 36 | return freqs_cis.view(*shape) 37 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | 5 | import torch 6 | 7 | import util 8 | import train 9 | 10 | def sample( 11 | model: str, 12 | start: str = '', # can also specify a file by "FILE:prompt.txt" 13 | max_tokens: int = None, 14 | num_samples: int = 1, 15 | temperature: float = 1.0, 16 | top_k: int = None, 17 | 18 | seed: int = 1, 19 | 
batch_size: int = 10, 20 | device: str = None, 21 | dtype: str = None, 22 | compile: bool = False, 23 | 24 | quiet: bool = False, 25 | check_logits: bool = False, 26 | ): 27 | 28 | print_ = print if not quiet else (lambda *args, **kwargs: None) 29 | 30 | if device is None: 31 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 32 | 33 | if dtype is None: 34 | # todo M2 supposedly support bfloat16 https://en.wikipedia.org/wiki/Bfloat16_floating-point_format 35 | dtype = 'bfloat16' if 'cuda' in device and torch.cuda.is_bf16_supported() else 'float32' 36 | 37 | if isinstance(model, str): 38 | model, _ = train.Train.model_from_checkpoint(model, device=device) 39 | 40 | if compile is None: 41 | compile = 'cuda' in device 42 | if compile: 43 | model = torch.compile(model) 44 | 45 | tokenizer = util.Tokenizer(model.config.tokenizer) 46 | 47 | # encode the beginning of the prompt 48 | if start.startswith('FILE:'): 49 | with open(start[5:], 'r', encoding='utf-8') as f: 50 | start = f.read() 51 | print_(start) 52 | print_() 53 | start_tokens = tokenizer.encode(start, prepend_BOS=True, device=device) 54 | start_tokens = start_tokens.broadcast_to(batch_size, *start_tokens.shape) 55 | 56 | def check_logits_func(delta_logits): 57 | eps = torch.finfo(getattr(torch, dtype)).eps 58 | print_(f'logits error: {util.mean2(delta_logits):7.2g} ' 59 | f'(max={delta_logits.abs().max():7.2g}, eps={eps:.2g})') 60 | assert util.mean2(delta_logits) < 100*eps 61 | 62 | torch.manual_seed(seed) 63 | model.eval() 64 | if quiet: 65 | log = dict(times=[], generations=[]) 66 | with torch.inference_mode(): 67 | with util.autocast_context(dtype): 68 | for _ in range(num_samples): 69 | t0 = time.time() 70 | tokens, logits = model.generate(start_tokens, max_tokens=max_tokens, temperature=temperature, 71 | top_k=top_k, logits=check_logits, check_logits_func=check_logits_func if check_logits else None) 72 | t0 = time.time() - t0 73 | if check_logits: 74 | T = model.config.context_size 75 | forward_logits, _ = model(tokens[:, :-1][:, :T], tokens[:, 1:1+T]) 76 | check_logits_func(forward_logits - logits[:, :T]) 77 | if quiet: 78 | log['times'].append(t0) 79 | log['generations'].append(tokens) 80 | print_(f'generation took {t0:.3f}s, {(tokens[:, 1:].numel() - start_tokens[:, 1:].numel())/t0:.1f} tps') 81 | for generation in tokens: 82 | print_('---------------') 83 | print_(tokenizer.decode(generation[1:])) 84 | 85 | if quiet: 86 | return log 87 | 88 | import fire 89 | if __name__ == '__main__': 90 | fire.Fire(sample) 91 | -------------------------------------------------------------------------------- /reproduce/jobs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # to run jobs on other nodes, set SUBMIT_COMMAND to a submission command 4 | SUBMIT_COMMAND = [] 5 | # SUBMIT_COMMAND = ['echo'] # just print the commands to run 6 | 7 | import math 8 | import subprocess 9 | 10 | def submit(*, flops, **kwargs): 11 | kwargs['iters'] = f'{flops}/flops' 12 | cmd = SUBMIT_COMMAND + [ 13 | 'python3', 'train.py'] + [ 14 | f'--{k}' if v is True else f'--{k}={v}' for k, v in kwargs.items() if v is not None] 15 | subprocess.run(cmd) 16 | 17 | out_dir='spacebyte' 18 | Ld = {192:4, 256:8, 384:16, 512:24, 768:32, 1024:32, 1536:48, 2048:48} # 32:1, 64:2, 128:3, 19 | ds = list(Ld.keys()) 20 | Ls = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48] 21 | 22 | rope = True 23 | beta2 = 0.98 24 | B = 64 25 | lr = 0.5 26 | lr = f'{lr}e-2*{B}**0.5' 27 | for dataset in ['github', 'arxiv', 
'pg19']: 28 | for flops in ['1e18', '1e19']: 29 | for tok in [None, 'gpt2', 'sp']: 30 | for d in ds: 31 | if {'1e16': 192 <= d <= 256, '1e17': 256 <= d <= 384, '1e18': 384 <= d <= 768, '1e19': 512 <= d <= 1024}[flops]: 32 | for L in Ls: 33 | P = 6 if 'github' not in dataset else 8 34 | args = dict(dataset=dataset, flops=flops, tokenizer=tok, batch_size=B, lr=lr, beta2=beta2, 35 | context_size=d, d_model=d, n_layers=L, rope=rope, out_dir=out_dir) 36 | 37 | good_L = L == Ld[d]//2 or L == Ld[d] 38 | if good_L and 2 <= L: 39 | if tok is None: 40 | pass 41 | submit(**(args | dict(context_size=P*d, attention_window=d))) 42 | submit(**args) 43 | else: 44 | pass 45 | submit(**args) 46 | 47 | good_L = Ld[d]//4 < L <= Ld[d]//2 48 | if good_L and 2 <= L and tok is None: 49 | for d_local in ds: 50 | if d//2 <= d_local < d: 51 | if d_local >= {'1e17': 192, '1e18': 256, '1e19': 384}[flops]: 52 | for L_local in [L]: 53 | for P_MB in [4, 8]: 54 | mega_args = args | dict(model='MegaByte', patch_size=P_MB, context_size=P_MB*d, 55 | d_local=d_local, n_local_layers=L_local) 56 | submit(**mega_args) 57 | 58 | for P in [P]: 59 | for patch_method in ['utf8', 'periodic']: 60 | TG = d if P != 6 else None 61 | TL = P*TG if P != 6 else None 62 | wise_args = args | dict(model='SpaceByte', patch_method=patch_method, 63 | global_context_size=TG, context_size=TL, d_local=d_local, n_local_layers=L_local, 64 | local_attention_window=d_local) 65 | submit(**wise_args) 66 | -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | import datasets 8 | 9 | import util 10 | 11 | def prepare(dataset_name: str, 12 | name: str = None, 13 | max_data_bytes: str = -1, # max number of bytes from dataset to use 14 | tokenizer: str = None, 15 | test_fraction: float = 0.01, # fraction of train data to allocate to test and validation if they don't exist 16 | out_dir: str = None, 17 | filter: str = None, 18 | ): 19 | if isinstance(max_data_bytes, str): 20 | max_data_bytes = round(eval(max_data_bytes)) 21 | if name and ',' in name: 22 | name = tuple(name.split(',')) # fire does this automatically 23 | 24 | if out_dir is None: 25 | out_dir = os.path.join('datasets', dataset_name 26 | + (f'_{name}' if name else '') 27 | + (f'.{filter.split(".")[-1]}' if filter else '') 28 | ) 29 | os.makedirs(out_dir, exist_ok=True) 30 | gen = np.random.default_rng(0) 31 | 32 | if tokenizer is not None and tokenizer not in util.Tokenizer.tiktoken_encodings: 33 | tokenizer = f'{out_dir}/{tokenizer}' 34 | tokenizer = util.Tokenizer(tokenizer) 35 | 36 | def open_split(split): 37 | file_name = os.path.join(out_dir, split + tokenizer.file_suffix) 38 | assert not os.path.exists(file_name) 39 | return open(file_name, 'wb') 40 | 41 | def get_dataset(name): 42 | if name is not None: 43 | print('loaded', name) 44 | return datasets.load_dataset(dataset_name, name, streaming=True, trust_remote_code=True) 45 | if isinstance(name, tuple): 46 | dataset_list = [get_dataset(name0) for name0 in name] 47 | def merged_dataset(split): 48 | merge_gen = np.random.default_rng(abs(hash(split))) 49 | dataset_iters = [iter(d[split]) for d in dataset_list] 50 | while True: 51 | if len(dataset_iters) == 0: 52 | break 53 | i = merge_gen.integers(len(dataset_iters)) 54 | try: 55 | yield next(dataset_iters[i]) 56 | except StopIteration: 57 | del dataset_iters[i] 58 | dataset 
= {split: merged_dataset(split) for split in dataset_list[0].keys()} 59 | else: 60 | dataset = get_dataset(name) 61 | 62 | print(f'found splits: {list(dataset.keys())}') 63 | for split, data in dataset.items(): 64 | total_data_bytes = 0 65 | with open_split(split) as out_file: 66 | test_files = [] 67 | if split == 'train': 68 | for s in 'test', 'validation': 69 | if s not in dataset: 70 | print(f'{s} not in dataset. Randomly allocating {test_fraction*100:.2g}% of train to {s}...') 71 | test_files.append(open_split(s)) 72 | 73 | for example in tqdm(data, split): 74 | if filter is not None: 75 | *filter_keys, filter_value = filter.split('.') 76 | ex = example 77 | for key in filter_keys: 78 | ex = ex[key] 79 | if ex != filter_value: 80 | continue 81 | 82 | key = 'code' if 'github-code' in dataset_name else \ 83 | 'content' if 'the-stack' in dataset_name else 'text' 84 | text = example[key] 85 | 86 | if max_data_bytes > 0: 87 | total_data_bytes += len(text.encode('utf-8')) 88 | if total_data_bytes > max_data_bytes: 89 | break 90 | 91 | text = tokenizer.encode(text, prepend_BOS=True, dtype=tokenizer.dtype, tensor=np.array) 92 | assert tokenizer.BOS not in text[1:] 93 | text = text.tobytes() 94 | write_text = True 95 | 96 | if len(test_files) > 0: 97 | r = gen.random() 98 | for f in test_files: 99 | if r < test_fraction: 100 | f.write(text) 101 | write_text = False 102 | break 103 | else: 104 | r -= test_fraction 105 | 106 | if write_text: 107 | out_file.write(text) 108 | 109 | for f in test_files: 110 | f.close() 111 | 112 | import fire 113 | if __name__ == '__main__': 114 | fire.Fire(prepare) 115 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import util 7 | 8 | def dataset(name, tokenizer=None): 9 | if name == 'zeros': 10 | return ZeroesDataset(tokenizer) 11 | else: 12 | return MemmapDataset(name, tokenizer) 13 | 14 | class MemmapDataset: 15 | def __init__(self, name, tokenizer=None): 16 | super().__init__() 17 | 18 | if not isinstance(tokenizer, util.Tokenizer): 19 | if tokenizer is not None and tokenizer not in util.Tokenizer.tiktoken_encodings: 20 | tokenizer = f'datasets/{name}/{tokenizer}' 21 | tokenizer = util.Tokenizer(tokenizer) 22 | self.tokenizer = tokenizer 23 | 24 | data_dir = os.path.join('datasets', name) 25 | self.data = {} 26 | self.bytes_per_token = None 27 | for file_name in os.listdir(data_dir): 28 | if file_name.endswith(tokenizer.file_suffix): 29 | split = file_name[:-len(tokenizer.file_suffix)] 30 | # we recreate np.memmap with every access to avoid a memory leak 31 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 32 | self.data[split] = lambda file_name=file_name: \ 33 | np.memmap(os.path.join(data_dir, file_name), dtype=tokenizer.dtype, mode='r') 34 | if split == 'train': 35 | # NOTE: assume that the tokenized file contains the same data as the .txt file 36 | self.bytes_per_token = 1 if tokenizer.name is None else ( 37 | os.path.getsize(os.path.join(data_dir, split + '.txt')) * np.dtype(tokenizer.dtype).itemsize / 38 | os.path.getsize(os.path.join(data_dir, file_name)) ) 39 | print(f'MemmapDataset: found {list(self.splits())} splits') 40 | assert self.bytes_per_token is not None 41 | 42 | if 'validation' in self.data: 43 | assert 'val' not in self.data 44 | self.data['val'] = self.data['validation'] 45 | del 
self.data['validation'] 46 | 47 | self.vocab_size = tokenizer.vocab_size 48 | self.BOS = tokenizer.BOS 49 | 50 | def splits(self): 51 | return self.data.keys() 52 | 53 | def iter(self, split, *, context_size, batch_size=1, seed=0, device='cpu'): 54 | data = self.data[split] 55 | data_size = len(data()) 56 | T = context_size 57 | B = batch_size 58 | rand_gen = torch.Generator() 59 | rand_gen.manual_seed(seed) 60 | 61 | while True: 62 | targets = torch.zeros(B, T, dtype=torch.int64) 63 | tokens = torch.full((B, T), self.BOS, dtype=torch.int64) 64 | 65 | b = 0 66 | while b < B: 67 | t = torch.randint(data_size, tuple(), generator=rand_gen) 68 | target = data()[t:] 69 | 70 | # align with BOS if found in next T tokens 71 | BOS_index, = (target[:T] == self.BOS).nonzero() 72 | if len(BOS_index) > 0: 73 | target = target[BOS_index[0]+1:] 74 | 75 | target = target[:T] 76 | if len(target) < T: 77 | continue 78 | targets[b] = torch.from_numpy(target.astype(np.int64)) 79 | b += 1 80 | 81 | tokens[:,1:] = targets[:, :-1] 82 | yield to_device(tokens, device), to_device(targets, device) 83 | 84 | class ZeroesDataset: 85 | def __init__(self, tokenizer=None): 86 | if not isinstance(tokenizer, util.Tokenizer): 87 | tokenizer = util.Tokenizer(tokenizer) 88 | self.tokenizer = tokenizer 89 | 90 | self.vocab_size = tokenizer.vocab_size 91 | self.BOS = tokenizer.BOS 92 | self.bytes_per_token = 1 93 | 94 | def splits(self): 95 | return ['train', 'val', 'test'] 96 | 97 | def iter(self, split, *, context_size, batch_size=1, seed=0, device='cpu'): 98 | T = context_size 99 | B = batch_size 100 | 101 | while True: 102 | tokens = torch.full((B, T), 0) 103 | targets = torch.cat([torch.full((B, 1), self.BOS), tokens[:, :-1]], 1) 104 | yield to_device(tokens, device), to_device(targets, device) 105 | 106 | def to_device(x, device): 107 | if 'cuda' in device: 108 | x = x.pin_memory() 109 | if 'cpu' not in device: 110 | # non_blocking=True is bugged on mps 111 | # https://github.com/pytorch/pytorch/issues/83015 112 | x = x.to(device, non_blocking = 'mps' not in device) 113 | return x 114 | -------------------------------------------------------------------------------- /reproduce/README.md: -------------------------------------------------------------------------------- 1 | # reproduction 2 | 3 | To reproduce the results in our paper, follow the steps below, starting from the main directory. 
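At a high level, the steps are: install the Python requirements, prepare the datasets with `prepare.py`, train the models with `train.py` (either directly, as in the training sections below, or through the grid search in `reproduce/jobs.py`), and create the plots with the `reproduce/experiments.ipynb` notebook. A condensed sketch of the byte-level pg19 path is shown below; it only abbreviates commands that appear in full in the following sections, so copy the complete training flags from there.

```
pip install -r requirements.txt
python3 prepare.py pg19 --out_dir=datasets/pg19
python3 train.py --dataset=pg19 --model=SpaceByte ...   # full flag list in the training sections below
```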
4 | 5 | ## install 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## prepare datasets 12 | 13 | Download datasets and save two copies, one in UTF8 and the other using the gpt2 tokenizer: 14 | 15 | ``` 16 | python3 prepare.py pg19 --out_dir=datasets/pg19 17 | python3 prepare.py pg19 --out_dir=datasets/pg19 --tokenizer=gpt2 18 | python3 prepare.py lucadiliello/STORIES --out_dir=datasets/STORIES 19 | python3 prepare.py lucadiliello/STORIES --out_dir=datasets/STORIES --tokenizer=gpt2 20 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/arxiv --max_data_bytes=20e9 --filter=meta.pile_set_name.ArXiv 21 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/arxiv --max_data_bytes=20e9 --filter=meta.pile_set_name.ArXiv --tokenizer=gpt2 22 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/github --max_data_bytes=20e9 --filter=meta.pile_set_name.Github 23 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/github --max_data_bytes=20e9 --filter=meta.pile_set_name.Github --tokenizer=gpt2 24 | ``` 25 | 26 | In order to prepare the datasets using the sentencepiece tokenizer, you'll need to install sentencepiece with the spm_train command to train the sentencepiece tokenizer. 27 | See [sentencepiece installation](https://github.com/google/sentencepiece?tab=readme-ov-file#installation). 28 | To train the sentencepiece tokenizers: 29 | 30 | ``` 31 | for dir in pg19 arxiv github; do 32 | cd datasets/$dir 33 | spm_train --input=train.txt --model_prefix=sp --model_type=bpe --vocab_size=50257 --num_threads=32 --allow_whitespace_only_pieces=True --remove_extra_whitespaces=False --byte_fallback=True --normalization_rule_name=identity --input_sentence_size=10000000 34 | cd - 35 | done 36 | ``` 37 | 38 | Download datasets and save using sentencepiece (sp) tokens: 39 | 40 | ``` 41 | python3 prepare.py pg19 --out_dir=datasets/pg19 --tokenizer=sp 42 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/arxiv --max_data_bytes=20e9 --filter=meta.pile_set_name.ArXiv --tokenizer=sp 43 | python3 prepare.py monology/pile-uncopyrighted --out_dir=datasets/github --max_data_bytes=20e9 --filter=meta.pile_set_name.Github --tokenizer=sp 44 | ``` 45 | 46 | ## SpaceByte-793M+184M training 47 | 48 | To train SpaceByte-793M+184M on pg19, STORIES, arxiv, and github, run: 49 | 50 | ``` 51 | for dataset in pg19 STORIES arxiv github; do 52 | python3 train.py --dataset=$dataset --model=SpaceByte --context_size=8192 --global_context_size=1344 --d_model=1536 --d_local=768 --n_layers=28 --n_local_layers=26 --local_attention_window=768 --rope --batch_size=64 --iters='30e9/tokens' --lr='0.5e-2*B**0.5' --beta2=0.98 --patch_method=utf8 --micro_batch_size=2 --out_dir=spacebyte-793M184M 53 | done 54 | ``` 55 | 56 | The trained models will appear in a `spacebyte-793M184M` subdirectory. 57 | wandb will be used for logging into a project named `spacebyte-793M184M`. 58 | See `test bits per byte' in wandb for the bits-per-byte for the test split. 
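
Bits-per-byte can also be estimated offline from a saved checkpoint: it is the mean cross entropy in nats per predicted token divided by (ln 2 × the dataset's average bytes per token), and the bytes-per-token factor is 1 for byte-level models such as SpaceByte. Below is a minimal sketch using the helpers in this repository; it assumes a CUDA device, and the checkpoint path, batch size, and number of evaluation batches are placeholders rather than values used in the paper, so it only approximates the `test bits per byte` logged to wandb.

```
# Sketch: estimate test bits-per-byte from a saved checkpoint.
# The checkpoint path is hypothetical -- point it at the checkpoint that
# train.py wrote inside the spacebyte-793M184M output directory.
import math
import torch

import data
import train
import util

model, _ = train.Train.model_from_checkpoint('spacebyte-793M184M/<checkpoint>', device='cuda')
model.eval()

dataset = data.dataset('pg19', tokenizer=model.config.tokenizer)
batches = dataset.iter('test', context_size=model.config.context_size, batch_size=4, device='cuda')

cross_entropy = util.MeanError()
with torch.inference_mode(), util.autocast_context('bfloat16'):
    for _ in range(64):  # average over a few randomly sampled test batches
        tokens, targets = next(batches)
        _, losses = model(tokens, targets)
        cross_entropy.add(losses['cross entropy'].item())

# nats per token -> bits per byte
bits_per_byte = cross_entropy.mean() / (math.log(2) * dataset.bytes_per_token)
print(f'test bits per byte ~ {bits_per_byte:.4f}')
```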
59 | 60 | ## Transformer-1B training 61 | 62 | To train the subword Transformer-1B models on pg19, STORIES, arxiv, and github, run: 63 | 64 | ``` 65 | python3 train.py --batch_size=64 --beta2=0.98 --context_size=2048 --d_model=1536 --iters='7.49e9/tokens' --lr=0.5e-2*B**0.5 --rope=True --micro_batch_size=2 --tokenizer=sp --dataset=pg19 --n_layers=40 --out_dir=spacebyte8_medium2 66 | python3 train.py --batch_size=64 --beta2=0.98 --context_size=2048 --d_model=1536 --iters='6.83e9/tokens' --lr=0.5e-2*B**0.5 --rope=True --micro_batch_size=2 --tokenizer=gpt2 --dataset=STORIES --n_layers=44 --out_dir=spacebyte8_medium2 67 | python3 train.py --batch_size=64 --beta2=0.98 --context_size=2048 --d_model=1536 --iters='8.10e9/tokens' --lr=0.5e-2*B**0.5 --rope=True --micro_batch_size=2 --tokenizer=sp --dataset=arxiv --n_layers=37 --out_dir=spacebyte8_medium2 68 | python3 train.py --batch_size=64 --beta2=0.98 --context_size=2048 --d_model=1536 --iters='9.52e9/tokens' --lr=0.5e-2*B**0.5 --rope=True --micro_batch_size=2 --tokenizer=sp --dataset=github --n_layers=31 --out_dir=spacebyte8_medium2 69 | ``` 70 | 71 | ## Pareto frontier grid search 72 | 73 | To train the Pareto frontier models using the grid search, you'll first want to set the SUBMIT_COMMAND variable in `reproduce/jobs.py` so that the training runs aren't all done locally. 74 | Then you can launch the grid search using 75 | 76 | ``` 77 | python3 reproduce/jobs.py 78 | ``` 79 | 80 | The trained models will appear in a `spacebyte` subdirectory. 81 | wandb will be used for logging into a project named `spacebyte`. 82 | 83 | To create the plots and table data in the paper, move `reproduce/plots.py` and the `reproduce/experiments.ipynb` jupyter notebook into the main directory and run the jupyter notebook. 84 | -------------------------------------------------------------------------------- /megabyte.py: -------------------------------------------------------------------------------- 1 | from transformer import * 2 | 3 | @dataclass 4 | class MegaByteConfig(TransformerConfig): 5 | patch_size: int = 8 6 | 7 | d_local: int = None 8 | n_local_layers: int = None 9 | 10 | # I don't think this is needed, but I include it to follow the MegaByte paper. 11 | use_padding: bool = True 12 | 13 | def __post_init__(self): 14 | c = self 15 | 16 | if c.context_size is None: 17 | c.context_size = c.patch_size * c.d_model 18 | 19 | super().__post_init__() 20 | 21 | if c.d_local is None: 22 | c.d_local = c.d_model//2 23 | if c.n_local_layers is None: 24 | c.n_local_layers = c.n_layers 25 | 26 | class MegaByte(Model): 27 | Config = MegaByteConfig 28 | 29 | def __init__(self, config: MegaByteConfig): 30 | super().__init__() 31 | self.config = config 32 | c = config 33 | 34 | T = c.context_size 35 | K = int_div(T, c.patch_size) 36 | D_G = int_div(c.d_model, c.patch_size) 37 | 38 | assert not c.tie_embedding # not implemented 39 | self.global_token_embedding = nn.Embedding(c.vocab_size, D_G) 40 | assert c.position_encoding # else not implemented 41 | self.global_position_encoding = nn.Parameter(torch.randn(T, D_G)) 42 | 43 | if c.use_padding: 44 | self.global_pad = nn.Parameter(torch.randn(c.d_model)) 45 | 46 | global_config = c.copy(context_size=K) 47 | self.global_blocks = nn.ModuleList([TransformerBlock(global_config) for _ in range(c.n_layers)]) 48 | 49 | self.global_to_local = nn.Linear(D_G, c.d_local) 50 | 51 | self.local_token_embedding = nn.Embedding(c.vocab_size, c.d_local) 52 | # Local position encoding does not appear in Fig. 
2 of the MegaByte paper, but it is used according to: 53 | # https://openreview.net/forum?id=JTmO2V9Xpz¬eId=VhgZzXezYZ 54 | self.local_position_encoding = nn.Parameter(torch.randn(c.patch_size, c.d_local)) 55 | 56 | if c.use_padding: 57 | self.local_pad = nn.Parameter(torch.randn(c.d_local)) 58 | 59 | local_config = c.copy(d_model=c.d_local, context_size=c.patch_size, attention_window=None) 60 | self.local_blocks = nn.ModuleList([TransformerBlock(local_config) for _ in range(c.n_local_layers)]) 61 | 62 | self.logits = Logits(local_config) 63 | 64 | super().__post_init__() 65 | 66 | generate = Transformer.generate 67 | train_log = Transformer.train_log 68 | 69 | def num_params(self, embedding=True): 70 | n = num_params(self) 71 | if not embedding: 72 | n -= num_params(self.global_token_embedding) 73 | n -= num_params(self.global_position_encoding) 74 | n -= num_params(self.local_token_embedding) 75 | n -= num_params(self.local_position_encoding) 76 | return n 77 | 78 | def n_mult_add(self, training=False): 79 | c = self.config 80 | P = c.patch_size 81 | T = c.context_size 82 | K = T // P 83 | d = c.d_local 84 | V = c.vocab_size 85 | 86 | n = sum(module.n_mult_add(K) for module in self.global_blocks) 87 | n += T * num_params(self.global_to_local) 88 | n += K * sum(module.n_mult_add(P) for module in self.local_blocks) 89 | 90 | return n + T*d*V 91 | 92 | def forward(self, tokens, targets=None, *, cache=None, log=None): 93 | c = self.config 94 | B, T0 = tokens.shape 95 | P = c.patch_size 96 | D = c.d_model 97 | D_G = int_div(D, P) 98 | d = c.d_local 99 | 100 | if cache is None: 101 | t0 = 0 102 | pending_global_tokens = None 103 | else: 104 | prefix = self.module_name + '->' 105 | t0 = cache.get(prefix+'t0', 0) 106 | cache[prefix+'t0'] = t0 + T0 107 | pending_global_tokens = cache.get(prefix+'pending_global_tokens', None) 108 | if pending_global_tokens is None: 109 | pending_global_tokens = torch.full((B, P-1), c.BOS, device=tokens.device) 110 | 111 | global_tokens = torch.cat([pending_global_tokens, tokens], 1) 112 | Kx = global_tokens.shape[1] // P 113 | if cache is not None: 114 | cache[prefix+'pending_global_tokens'] = global_tokens[:, Kx*P:] 115 | 116 | if Kx > 0: 117 | global_tokens = global_tokens[:, :Kx*P] 118 | global_t = (P-1 + t0) // P 119 | 120 | x = self.global_token_embedding(global_tokens) # (B, Kx*P, D_G) 121 | x = x + self.global_position_encoding[global_t*P : (global_t+Kx)*P] 122 | x = x.view(B, Kx, D) 123 | 124 | if c.use_padding and global_t == 0: 125 | x = torch.cat([self.global_pad.broadcast_to(B,1,D), x[:,1:]], 1) # B, K, D 126 | 127 | for block in self.global_blocks: 128 | x = block(x, cache=cache, cache_seqlen=global_t, log=log) 129 | 130 | x = x.view(B*Kx, P, D_G) 131 | x = self.global_to_local(x).view(B, Kx*P, d) 132 | 133 | local_emb = self.local_token_embedding(tokens) # (B, T0, d) 134 | if t0==0 and T0>=P: 135 | t1 = (T0//P)*P 136 | y = local_emb[:, :t1].reshape(B*(T0//P), P, d) + self.local_position_encoding 137 | local_emb = local_emb[:, t1:] # (B, T0-t1, d) 138 | 139 | if c.use_padding: 140 | y = torch.cat([self.local_pad.broadcast_to(len(y), 1, d), y[:, 1:]], 1) 141 | 142 | y = x[:, :t1].reshape(-1, P, d) + y 143 | x = x[:, t1:] 144 | 145 | for block in self.local_blocks: 146 | y = block(y, log=log) # cache not needed 147 | 148 | y = y.view(B, t1, d) 149 | else: 150 | t1 = t0 151 | 152 | T1 = local_emb.shape[1] 153 | if cache is not None: 154 | if t0 > 0: 155 | x = torch.cat([cache[prefix+'local-in'], x], 1) if Kx>0 else cache[prefix+'local-in'] 156 | 
cache[prefix+'local-in'] = x[:, T1:] 157 | assert (t1 + T1 + x[:, T1:].shape[1]) % P == 0 158 | x = x[:, :T1] 159 | 160 | if T1 > 0: 161 | assert t1//P == (t1 + T1 - 1)//P 162 | z = local_emb + self.local_position_encoding[t1%P : (t1+T1-1)%P + 1] # (B, T1, d) 163 | 164 | if c.use_padding and t1%P == 0: 165 | z = torch.cat([self.local_pad.broadcast_to(B, 1, d), z[:, 1:]], 1) 166 | 167 | z = x + z 168 | 169 | if t1%P == 0: 170 | cache[prefix+'local-cache'] = {} 171 | for block in self.local_blocks: 172 | z = block(z, cache=cache[prefix+'local-cache'], cache_seqlen=t1%P, log=log) 173 | 174 | y = torch.cat([y, z], 1) if t0==0 and T0>=P else z 175 | 176 | logits = self.logits(y, log=log).view(B, T0, c.vocab_size) 177 | 178 | losses = None 179 | if targets is not None: 180 | losses = {} 181 | losses['cross entropy'] = util.cross_entropy(logits, targets) 182 | losses['loss'] = losses['cross entropy'] 183 | 184 | return logits, losses 185 | -------------------------------------------------------------------------------- /license/llama_LICENSE.txt: -------------------------------------------------------------------------------- 1 | LLAMA 2 COMMUNITY LICENSE AGREEMENT 2 | Llama 2 Version Release Date: July 18, 2023 3 | 4 | "Agreement" means the terms and conditions for use, reproduction, distribution and 5 | modification of the Llama Materials set forth herein. 6 | 7 | "Documentation" means the specifications, manuals and documentation 8 | accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- 9 | libraries/llama-downloads/. 10 | 11 | "Licensee" or "you" means you, or your employer or any other person or entity (if 12 | you are entering into this Agreement on such person or entity's behalf), of the age 13 | required under applicable laws, rules or regulations to provide legal consent and that 14 | has legal authority to bind your employer or such other person or entity if you are 15 | entering in this Agreement on their behalf. 16 | 17 | "Llama 2" means the foundational large language models and software and 18 | algorithms, including machine-learning model code, trained model weights, 19 | inference-enabling code, training-enabling code, fine-tuning enabling code and other 20 | elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- 21 | libraries/llama-downloads/. 22 | 23 | "Llama Materials" means, collectively, Meta's proprietary Llama 2 and 24 | Documentation (and any portion thereof) made available under this Agreement. 25 | 26 | "Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you 27 | are an entity, your principal place of business is in the EEA or Switzerland) and Meta 28 | Platforms, Inc. (if you are located outside of the EEA or Switzerland). 29 | 30 | By clicking "I Accept" below or by using or distributing any portion or element of the 31 | Llama Materials, you agree to be bound by this Agreement. 32 | 33 | 1. License Rights and Redistribution. 34 | 35 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non- 36 | transferable and royalty-free limited license under Meta's intellectual property or 37 | other rights owned by Meta embodied in the Llama Materials to use, reproduce, 38 | distribute, copy, create derivative works of, and make modifications to the Llama 39 | Materials. 40 | 41 | b. Redistribution and Use. 42 | 43 | i. 
If you distribute or make the Llama Materials, or any derivative works 44 | thereof, available to a third party, you shall provide a copy of this Agreement to such 45 | third party. 46 | ii. If you receive Llama Materials, or any derivative works thereof, from 47 | a Licensee as part of an integrated end user product, then Section 2 of this 48 | Agreement will not apply to you. 49 | 50 | iii. You must retain in all copies of the Llama Materials that you 51 | distribute the following attribution notice within a "Notice" text file distributed as a 52 | part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, 53 | Copyright (c) Meta Platforms, Inc. All Rights Reserved." 54 | 55 | iv. Your use of the Llama Materials must comply with applicable laws 56 | and regulations (including trade compliance laws and regulations) and adhere to the 57 | Acceptable Use Policy for the Llama Materials (available at 58 | https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into 59 | this Agreement. 60 | 61 | v. You will not use the Llama Materials or any output or results of the 62 | Llama Materials to improve any other large language model (excluding Llama 2 or 63 | derivative works thereof). 64 | 65 | 2. Additional Commercial Terms. If, on the Llama 2 version release date, the 66 | monthly active users of the products or services made available by or for Licensee, 67 | or Licensee's affiliates, is greater than 700 million monthly active users in the 68 | preceding calendar month, you must request a license from Meta, which Meta may 69 | grant to you in its sole discretion, and you are not authorized to exercise any of the 70 | rights under this Agreement unless or until Meta otherwise expressly grants you 71 | such rights. 72 | 73 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE 74 | LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE 75 | PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, 76 | EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY 77 | WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR 78 | FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE 79 | FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING 80 | THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR 81 | USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. 82 | 83 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE 84 | LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, 85 | NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS 86 | AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, 87 | CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN 88 | IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF 89 | ANY OF THE FOREGOING. 90 | 91 | 5. Intellectual Property. 92 | 93 | a. No trademark licenses are granted under this Agreement, and in 94 | connection with the Llama Materials, neither Meta nor Licensee may use any name 95 | or mark owned by or associated with the other or any of its affiliates, except as 96 | required for reasonable and customary use in describing and redistributing the 97 | Llama Materials. 98 | 99 | b. Subject to Meta's ownership of Llama Materials and derivatives made by or 100 | for Meta, with respect to any derivative works and modifications of the Llama 101 | Materials that are made by you, as between you and Meta, you are and will be the 102 | owner of such derivative works and modifications. 103 | 104 | c. 
If you institute litigation or other proceedings against Meta or any entity 105 | (including a cross-claim or counterclaim in a lawsuit) alleging that the Llama 106 | Materials or Llama 2 outputs or results, or any portion of any of the foregoing, 107 | constitutes an infringement of intellectual property or other rights owned or licensable 108 | by you, then any licenses granted to you under this Agreement shall terminate as of 109 | the date such litigation or claim is filed or instituted. You will indemnify and hold 110 | harmless Meta from and against any claim by any third party arising out of or related 111 | to your use or distribution of the Llama Materials. 112 | 113 | 6. Term and Termination. The term of this Agreement will commence upon your 114 | acceptance of this Agreement or access to the Llama Materials and will continue in 115 | full force and effect until terminated in accordance with the terms and conditions 116 | herein. Meta may terminate this Agreement if you are in breach of any term or 117 | condition of this Agreement. Upon termination of this Agreement, you shall delete 118 | and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the 119 | termination of this Agreement. 120 | 121 | 7. Governing Law and Jurisdiction. This Agreement will be governed and 122 | construed under the laws of the State of California without regard to choice of law 123 | principles, and the UN Convention on Contracts for the International Sale of Goods 124 | does not apply to this Agreement. The courts of California shall have exclusive 125 | jurisdiction of any dispute arising out of this Agreement. 126 | -------------------------------------------------------------------------------- /spacebyte.py: -------------------------------------------------------------------------------- 1 | from transformer import * 2 | 3 | @dataclass 4 | class SpaceByteConfig(TransformerConfig): 5 | patch_method: str = 'utf8' # 'utf8', 'periodic' 6 | 7 | # inherited from TransformerConfig: 8 | # d_model = global model dimension 9 | # context_size = context size in bytes 10 | # n_layers = number of layers for the global model 11 | 12 | # for efficient training, global_context_size should be roughly equal to context_size/P where P is the average patch size 13 | global_context_size: int = None 14 | 15 | d_local: int = None 16 | n_initial_layers: int = None # number of layers for the first local model 17 | n_local_layers: int = None # total number of layers for both local models 18 | local_attention_window: int = None 19 | 20 | print_patches: float = 0 # fraction of time to print the patches 21 | 22 | def __post_init__(self): 23 | c = self 24 | 25 | if self.n_layers is None: 26 | # half of the Transformer default 27 | self.n_layers = round(max(4, 12.49*math.log2(c.d_model/154.3)) / 2) 28 | 29 | default_patch_size: int = 6 30 | if c.global_context_size is None: 31 | c.global_context_size = c.context_size//default_patch_size if c.context_size else self.d_model 32 | if c.context_size is None: 33 | c.context_size = default_patch_size * c.global_context_size 34 | if c.patch_method == 'periodic': 35 | assert c.context_size % c.global_context_size == 0 36 | 37 | super().__post_init__() 38 | 39 | assert c.tokenizer is None 40 | 41 | if c.d_local is None: 42 | c.d_local = c.d_model // 2 43 | if c.n_local_layers is None: 44 | c.n_local_layers = c.n_layers 45 | if c.n_initial_layers is None: 46 | c.n_initial_layers = c.n_local_layers // 2 47 | if c.local_attention_window is None: 48 | c.local_attention_window 
= c.d_local 49 | 50 | class SpaceByte(Model): 51 | Config = SpaceByteConfig 52 | 53 | def __init__(self, config: SpaceByteConfig): 54 | super().__init__() 55 | self.config = config 56 | c = self.config 57 | 58 | self.token_embedding = nn.Embedding( 59 | c.padded_vocab_size() if c.tie_embedding else c.vocab_size, c.d_local) 60 | assert c.position_encoding # else not implemented 61 | self.local_position_encoding = nn.Parameter(torch.randn(c.context_size, c.d_local)) 62 | 63 | local_config = c.copy(d_model=c.d_local, attention_window=c.local_attention_window) 64 | self.initial_blocks = nn.ModuleList([TransformerBlock(local_config) for _ in range(c.n_initial_layers)]) 65 | 66 | self.global_position_encoding = nn.Parameter(torch.randn(c.global_context_size, c.d_model)) 67 | 68 | global_config = c.copy(context_size=c.global_context_size) 69 | self.global_blocks = nn.ModuleList([TransformerBlock(global_config) for l in range(c.n_layers)]) 70 | 71 | self.final_blocks = nn.ModuleList([ 72 | TransformerBlock(local_config) for _ in range(c.n_initial_layers, c.n_local_layers) ]) 73 | self.logits = Logits(local_config, self.token_embedding if c.tie_embedding else None) 74 | 75 | super().__post_init__() 76 | 77 | def num_params(self, embedding=True): 78 | n = num_params(self) 79 | if not embedding: 80 | if not self.config.tie_embedding: 81 | n -= num_params(self.token_embedding) 82 | n -= num_params(self.local_position_encoding) + num_params(self.global_position_encoding) 83 | return n 84 | 85 | def generate(self, tokens, *, max_tokens=None, temperature=1.0, top_k=None, input_lengths=None, logits=False, 86 | check_logits_func=None): 87 | return Transformer.generate(self, tokens, max_tokens=max_tokens, temperature=temperature, top_k=top_k, 88 | input_lengths=input_lengths, logits=logits, check_logits_func=check_logits_func, use_cache=False) 89 | 90 | def n_mult_add(self, training=False): 91 | c = self.config 92 | TL = c.context_size 93 | TG = c.global_context_size 94 | d = c.d_local 95 | V = c.vocab_size 96 | 97 | n = sum(module.n_mult_add(TL) for module in self.initial_blocks + self.final_blocks) 98 | n += sum(module.n_mult_add(TG) for module in self.global_blocks) 99 | 100 | return n + TL*d*V 101 | 102 | def forward(self, tokens, targets=None, *, log=None): 103 | c = self.config 104 | 105 | x = self.token_embedding(tokens) 106 | B, Tx, d = x.shape 107 | 108 | assert Tx <= c.context_size 109 | x = x + self.local_position_encoding[:Tx] 110 | 111 | for block in self.initial_blocks: 112 | x = block(x, log=log) 113 | 114 | D = c.d_model 115 | T = c.context_size 116 | TG = c.global_context_size 117 | P = T // TG 118 | if c.patch_method != 'periodic': 119 | global_T = torch.full((B,), -1) 120 | max_global_T = min(TG, Tx) 121 | global_ts = torch.full((B, max_global_T), Tx-1, device=tokens.device) 122 | stop_gen = [] 123 | def set_global_ts(use_global): 124 | for b, use_global0 in enumerate(use_global): 125 | global_ts0, = use_global0.nonzero(as_tuple=True) 126 | if len(global_ts0) > TG: 127 | if targets is not None: 128 | targets[b, global_ts0[TG]:] = -1 129 | else: 130 | stop_gen.append((b, global_ts0[TG])) 131 | global_ts0 = global_ts0[:TG] 132 | global_T[b] = len(global_ts0) 133 | assert global_T[b] <= max_global_T 134 | global_ts[b, :global_T[b]] = global_ts0 135 | 136 | if c.patch_method == 'utf8': 137 | # https://en.wikipedia.org/wiki/UTF-8#Encoding 138 | # https://en.wikipedia.org/wiki/UTF-8#Codepage_layout 139 | use_global = ( 140 | (tokens < ord('0')) | 141 | ((ord('9') < tokens) & (tokens < 
ord('A'))) | 142 | ((ord('Z') < tokens) & (tokens < ord('a'))) | 143 | ((ord('z') < tokens) & (tokens < 0b1000_0000)) | 144 | (0b1100_0000 <= tokens) 145 | ) 146 | else: 147 | assert False 148 | 149 | use_global[:, 1:] &= use_global[:, :-1].bitwise_not() 150 | 151 | if c.patch_method == 'utf8': 152 | use_global |= tokens == c.BOS 153 | set_global_ts(use_global) 154 | 155 | if log is not None: 156 | log['global_T'] = global_T 157 | log['global_ts'] = global_ts 158 | 159 | y = x.gather(1, global_ts[:, :, None].expand(B, max_global_T, d)) 160 | else: 161 | y = x[:, ::P] 162 | max_global_T = y.shape[1] 163 | 164 | y = torch.cat([torch.zeros(B, max_global_T, D-d, **like(x)), y], -1) 165 | 166 | y = y + self.global_position_encoding[:max_global_T] 167 | 168 | # print patch boundaries 169 | if c.print_patches > 0 and targets is not None and torch.rand(()) < c.print_patches: 170 | b0, t0, T0 = 0, 0, 128 171 | global_ts0, targets0 = global_ts[b0].cpu(), targets[b0].cpu() 172 | print() 173 | print(f'TG={global_T[b0]/TG:.0%}, ignored={targets0.eq(-1).float().mean().item():.0%}') 174 | while t0 < T: 175 | print(''.join('!' if t_ in global_ts0 else ' ' for t_ in range(t0, t0+T0)) + '|') 176 | print(util.chrs(targets0[t0:t0+T0]) + '|') 177 | print(self.dataset_tokenizer.decode(tokens[b0, t0:t0+T0])) 178 | t0 += T0 179 | print() 180 | 181 | for block in self.global_blocks: 182 | y = block(y, log=log) 183 | 184 | if c.patch_method != 'periodic': 185 | x = torch.stack([ 186 | x0.index_add(0, ts[:Ty0], y0[:Ty0, -d:]) 187 | for x0, ts, Ty0, y0 in zip(x, global_ts, global_T, y, strict=True) ]) 188 | else: 189 | x = x.index_add(1, torch.arange(0, Tx, P, device=x.device), y[:, :, -d:]) 190 | del y 191 | 192 | for block in self.final_blocks: 193 | x = block(x, log=log) 194 | 195 | logits = self.logits(x, log=log) 196 | 197 | losses = None 198 | if targets is not None: 199 | losses = {} 200 | ignore_index = -1 if c.patch_method != 'periodic' else None 201 | losses['cross entropy'] = util.cross_entropy(logits, targets, ignore_index=ignore_index) 202 | losses['loss'] = losses['cross entropy'] 203 | 204 | if c.patch_method != 'periodic': 205 | losses['global context'] = sum(global_T) / len(global_T) 206 | losses['ignored fraction'] = (targets == -1).float().mean() 207 | else: 208 | if c.patch_method != 'periodic': 209 | for b, t in stop_gen: 210 | # not enough global blocks to continue generation 211 | logits[b, t:] = -1e4 212 | logits[b, t:, c.BOS] = 1e4 213 | 214 | return logits, losses 215 | -------------------------------------------------------------------------------- /reproduce/plots.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numbers 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | 7 | import mpl_toolkits.axes_grid1 8 | 9 | def tensor_items(xs, dtype=None): 10 | if isinstance(xs, list): 11 | return [tensor_items(x) for x in xs] 12 | elif isinstance(xs, tuple): 13 | return tuple(tensor_items(x) for x in xs) 14 | elif isinstance(xs, dict): 15 | return {k: tensor_items(v) for k,v in xs.items()} 16 | elif isinstance(xs, torch.Tensor): 17 | if dtype is None: 18 | dtype = xs.dtype 19 | if dtype == torch.bfloat16: 20 | dtype = torch.float32 21 | return xs.item() if xs.dim()==0 else xs.detach().cpu().to(dtype=dtype).numpy() 22 | elif isinstance(xs, np.ndarray): 23 | return xs if dtype is None else xs.astype(dtype, copy=True) 24 | elif hasattr(xs, '__next__') and not hasattr(xs, '__getitem__'): 25 | return 
(tensor_items(x) for x in xs) 26 | else: 27 | return xs 28 | 29 | def gridplot_data(dataset, x_func, y_func, label_funcs, *args, grid_kwargs={}, **kwargs): 30 | if not isinstance(label_funcs[0], list): 31 | label_funcs = [label_funcs] 32 | 33 | subplots = [ [subplot_data(dataset, x_func, y_func, f, *args, **kwargs) for f in fs] for fs in label_funcs ] 34 | grid(*subplots, **grid_kwargs) 35 | 36 | # xys = [(optional xs, ys, optional 'legend label', options), ...] 37 | # options is a dictionary with the following options: 38 | # 'color' = matlibplot color 39 | # others are passed to matlibplot 40 | def plot(xys, scale='linear', *, 41 | labels=(None,None,None), dataRange=None, 42 | subplot=plt, figsize=(9,5), 43 | plotrange=None, # ([x_min, x_max], [y_min, y_max]) 44 | new_figure=True, 45 | legend=True, # or a dict of legend options 46 | args=['o'], 47 | **kwargs): 48 | 49 | if len(xys) == 0: 50 | print(f'plots.plot: nothing to plot: {xys}') 51 | return 52 | 53 | if not isinstance(xys, list) or isinstance(xys[0], numbers.Number): 54 | xys = [xys] 55 | 56 | xys = [ (dataRange, xy) if not isinstance(xy, tuple) else \ 57 | (dataRange, *xy) if not (hasattr(xy[1], '__getitem__') and 58 | isinstance(xy[1][0], (int, float, torch.Tensor))) else \ 59 | xy for xy in xys] 60 | 61 | use_legend = False 62 | if scale == 'log': 63 | scale = ('linear', 'log') 64 | if isinstance(scale, str): 65 | scale = (scale, scale) 66 | 67 | is_plt = subplot is plt 68 | if new_figure and is_plt: 69 | plt.figure(figsize=figsize) 70 | for x, y, *other in xys: 71 | label = None 72 | if len(other) > 0 and not isinstance(other[0], dict): 73 | label = other[0] 74 | del other[0] 75 | use_legend = legend is not False 76 | 77 | options = {} 78 | if len(other) > 0: 79 | options = other[0] 80 | del other[0] 81 | 82 | assert len(other) == 0, other 83 | 84 | if x is None: 85 | # assert not callable(y) 86 | x = np.arange(1, len(y)+1) 87 | # elif isinstance(x, tuple): # x = (min, max, optional: # points) 88 | # if len(x) == 2: 89 | # x = (*x, 1000 if callable(y) else len(y)) 90 | # x = np.linspace(*x) if not (callable(y) and scale[0] == 'log') else \ 91 | # np.exp(np.linspace(np.log(x[0]), np.log(x[1]), x[2])) 92 | 93 | # if callable(y): 94 | # try: 95 | # y = y(x) 96 | # except: 97 | # y = [y(xi) for xi in x] 98 | # joined0 = True 99 | 100 | if isinstance(x[0], str): 101 | plotter = subplot.bar 102 | subplot.xticks(rotation=90) 103 | else: 104 | plotter = subplot.plot 105 | 106 | if isinstance(x, torch.Tensor): 107 | x = x.detach().cpu() 108 | if isinstance(y, torch.Tensor): 109 | y = y.detach().cpu() 110 | 111 | if 'args' in options: 112 | args0 = options['args'] 113 | del options['args'] 114 | else: 115 | args0 = args 116 | 117 | assert len(x) == len(y) 118 | try: 119 | plotter(x, y, *args0, label=label, **(kwargs|options)) 120 | except: 121 | print(f'{x=}') 122 | print(f'{y=}') 123 | print(f'{label=}') 124 | raise 125 | 126 | if scale[0] != 'linear': 127 | (plt.xscale if is_plt else subplot.set_xscale)(scale[0]) 128 | (plt.yscale if is_plt else subplot.set_yscale)(scale[1]) 129 | 130 | if use_legend: 131 | subplot.legend(**legend) if isinstance(legend, dict) else subplot.legend() 132 | 133 | assert isinstance(labels, (tuple, list)) 134 | for f,l in zip((plt.xlabel if is_plt else subplot.set_xlabel, 135 | plt.ylabel if is_plt else subplot.set_ylabel, 136 | plt.title if is_plt else subplot.set_title), labels): 137 | if l is not None: 138 | f(l) 139 | 140 | if plotrange is not None: 141 | if plotrange[0] is not None: 142 | 
plt.xlim(*plotrange[0]) if is_plt else subplot.set_xlim(*plotrange[0]) 143 | if plotrange[1] is not None: 144 | plt.ylim(*plotrange[1]) if is_plt else subplot.set_ylim(*plotrange[1]) 145 | 146 | if new_figure and is_plt: 147 | plt.show() 148 | 149 | return subplot 150 | 151 | def plot_data(dataset, x_func, y_func, label_func, *args, 152 | legend_labels=[], # in order to specify ordering 153 | post_process=lambda data: data, 154 | options_func=lambda label: {}, subplot=plot, **kwargs): 155 | xyls_dict = collections.defaultdict(list) 156 | def sortable_label(label): 157 | try: 158 | return (legend_labels.index(label), str(l)) 159 | except ValueError: 160 | (len(legend_labels), str(l)) 161 | for d in dataset: 162 | l = label_func(d) 163 | if l is not None: 164 | xyls_dict[(sortable_label(l), l)].append((x_func(d), y_func(d))) 165 | for k in xyls_dict: 166 | xyls_dict[k].sort() 167 | plot_data = [([x for x,_ in xys], [y for _,y in xys], str(l), options_func(l)) 168 | for (_,l), xys in sorted(xyls_dict.items())] 169 | return subplot(post_process(plot_data), *args, **kwargs) 170 | 171 | def subplot_data(*args, **kwargs): 172 | return plot_data(*args, subplot=SubPlot, **kwargs) 173 | 174 | # example: 175 | # plots.grid([plots.SubPlot(torch.arange(5)), plots.SubPlot(torch.arange(4))], figsize=6) 176 | def grid(*subplots, figsize=16): 177 | subplots = np.matrix(subplots) # NOTE: use grid([...], [...]) instead of grid([[...], [...]]) 178 | n_rows, n_cols = subplots.shape 179 | if isinstance(figsize, int): 180 | figsize = ( figsize, figsize * n_rows/n_cols * 0.8 * n_cols/(1+n_cols) ) 181 | fig, axs = plt.subplots(*subplots.shape, figsize=figsize) 182 | # fig, axs = plt.subplots(*subplots.shape, figsize=(figsize, figsize * (1+n_rows)/(1+n_cols))) 183 | axs = np.matrix(axs).reshape(*subplots.shape) 184 | 185 | for i in range(n_rows): 186 | for j in range(n_cols): 187 | sub = subplots[i,j] 188 | sub.plot(*sub.args, **sub.kwargs, subplot=axs[i,j]) 189 | 190 | plt.tight_layout(pad=1*n_rows/n_cols) 191 | # plt.tight_layout(pad=10*n_rows/n_cols) 192 | plt.show() 193 | 194 | class Sub: 195 | def __init__(self, plot, *args, **kwargs): 196 | self.plot = plot 197 | self.args = args 198 | self.kwargs = kwargs 199 | 200 | class SubPlot(Sub): 201 | def __init__(self, *args, **kwargs): 202 | super().__init__(plot, *args, **kwargs) 203 | 204 | class SubImage(Sub): 205 | def __init__(self, *args, **kwargs): 206 | super().__init__(image, *args, **kwargs) 207 | 208 | def images(images, grid_figsize=16, **kwargs): 209 | grid(*[SubImage(image, **kwargs) for image in images], figsize=grid_figsize) 210 | 211 | def image(image, title=None, subplot=plt, figsize=(9,5), legend=None, **kwargs): 212 | C, H, W = image.shape 213 | 214 | is_plt = subplot is plt 215 | if is_plt: 216 | plt.figure(figsize=figsize) 217 | 218 | if isinstance(image, torch.Tensor): 219 | image = image.cpu().detach().numpy() 220 | 221 | if C == 1: 222 | image = image[0] 223 | if 'cmap' not in kwargs: 224 | kwargs = dict(cmap='gray', vmin=0, vmax=1) | kwargs 225 | elif legend is None: 226 | legend = True 227 | elif C == 3: 228 | image = image.transpose(1, 2, 0) 229 | else: 230 | assert False 231 | 232 | im = subplot.imshow(image, **kwargs) 233 | if legend: 234 | if is_plt: 235 | plt.colorbar() 236 | else: 237 | cax = mpl_toolkits.axes_grid1.make_axes_locatable(subplot).append_axes("right", size="5%", pad=0.05) 238 | plt.colorbar(im, cax=cax) 239 | if title is not None: 240 | subplot.set_title(title) 241 | subplot.axis('off') 242 | if is_plt: 243 | 
plt.show() 244 | 245 | def hist(x, *args, show=True, **kwargs): 246 | kwargs = kwargs | dict(bins='auto', density=True, histtype='step') 247 | plt.hist(tensor_items(x), *args, **kwargs) 248 | if show: 249 | plt.show() 250 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | import os 3 | import sys 4 | import contextlib 5 | # import resource 6 | import dataclasses 7 | # import math 8 | # from collections.abc import Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | 15 | import tiktoken 16 | import sentencepiece as spm 17 | 18 | # https://twitter.com/karpathy/status/1621578354024677377 19 | # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc 20 | DIM_MULT = 64 21 | 22 | def interactive_mode(): 23 | return hasattr(sys, 'ps1') 24 | 25 | def notNone(x): 26 | return True if x is not None else None 27 | 28 | def chrs(ts): 29 | def chr_(t): 30 | return chr(t if 32 <= t <= 126 or (161 <= t and t != 173) else 164) # or (128 <= t <= 130) 31 | return ''.join(chr_(t) for t in ts) 32 | 33 | def mean2(x: torch.Tensor): 34 | return (x*x).mean().sqrt() 35 | 36 | def norm2(x: torch.Tensor): 37 | return torch.dot(x.flatten(), x.flatten()) 38 | 39 | def ceil_div(n: int, k: int) -> int: 40 | return (n+k-1)//k 41 | 42 | def ceil(n: int, k: int) -> int: 43 | return ceil_div(n, k)*k 44 | 45 | def int_div(x, y): 46 | div, rem = divmod(x, y) 47 | assert rem == 0 48 | return div 49 | 50 | def is_pow2(n:int): 51 | return n.bit_count() == 1 52 | 53 | class MeanError: 54 | def __init__(self): 55 | self.n = 0 56 | self.sum = 0 57 | self.sum_squares = 0 58 | 59 | def add(self, x): 60 | self.n += 1 61 | self.sum += x 62 | self.sum_squares += x*x 63 | 64 | def mean(self): 65 | return self.sum / self.n 66 | 67 | def error(self): 68 | n = self.n 69 | if n == 0: 70 | return float('nan') * self.sum 71 | err = (self.sum_squares/n - (self.sum/n)**2) / (n-1) 72 | if isinstance(err, float): 73 | err = max(0, err) 74 | elif isinstance(err, torch.Tensor): 75 | err = err.clamp(min=0) 76 | elif isinstance(err, np.ndarray): 77 | err = np.clip(err, 0, None) 78 | else: 79 | print(type(err)) 80 | assert False 81 | return err**0.5 82 | 83 | byte_BOS = 255 84 | 85 | class Tokenizer: 86 | tiktoken_encodings = ['gpt2', 'cl100k_base'] 87 | 88 | def __init__(self, tokenizer: str = None): 89 | self.name = tokenizer 90 | 91 | if tokenizer is None: 92 | self.tokenizer = None 93 | self.vocab_size = 256 94 | self.BOS = byte_BOS 95 | self.file_suffix = '.txt' 96 | elif tokenizer in Tokenizer.tiktoken_encodings: 97 | self.tokenizer = tiktoken.get_encoding(tokenizer) 98 | self.vocab_size = self.tokenizer.n_vocab 99 | self.BOS = self.tokenizer.eot_token 100 | self.file_suffix = '.' + tokenizer 101 | else: 102 | self.tokenizer = spm.SentencePieceProcessor(model_file=tokenizer + '.model') 103 | self.vocab_size = self.tokenizer.get_piece_size() 104 | self.BOS = self.tokenizer.piece_to_id('') 105 | self.file_suffix = '.' 
+ tokenizer.split('/')[-1] 106 | 107 | assert self.vocab_size <= 2**32 108 | self.dtype = np.uint8 if self.vocab_size <= 2**8 else \ 109 | np.uint16 if self.vocab_size <= 2**16 else np.uint32 110 | 111 | def encode(self, text, prepend_BOS=False, dtype=torch.int64, tensor=torch.tensor, **kwargs): 112 | if self.tokenizer is not None: 113 | ret = self.tokenizer.encode_ordinary(text) if hasattr(self.tokenizer, 'encode_ordinary') else \ 114 | self.tokenizer.encode(text) # for sentencepiece 115 | if prepend_BOS: 116 | ret = [self.BOS] + ret 117 | else: 118 | ret = text.encode('utf-8') 119 | if prepend_BOS: 120 | ret = bytes([self.BOS]) + ret 121 | ret = bytearray(ret) 122 | return tensor(ret, dtype=dtype, **kwargs) 123 | 124 | def decode(self, tokens): 125 | if self.tokenizer is not None: 126 | return self.tokenizer.decode(list(tokens)) 127 | else: 128 | return tensor_to_str(tokens, errors='replace') 129 | 130 | # kwargs example: errors='replace' 131 | def tensor_to_str(x, **kwargs): 132 | return x.cpu().to(torch.uint8).numpy().tobytes().decode('utf-8', **kwargs) 133 | 134 | def autocast_context(dtype): 135 | # torch.backends.cuda.enable_mem_efficient_sdp(False) 136 | # torch.backends.cuda.enable_math_sdp(False) 137 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 138 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 139 | if isinstance(dtype, str): 140 | dtype = eval(f'torch.{dtype}') 141 | return torch.amp.autocast(device_type='cuda', dtype=dtype) \ 142 | if dtype != torch.float32 else contextlib.nullcontext() 143 | 144 | def get_memory_stats(*, device): 145 | ret = ''.join([ 146 | f'{psutil.Process(os.getpid()).memory_info().rss/2**30:.1f}GB RAM', 147 | f', {torch.cuda.max_memory_allocated()/2**30:.1f}GB cuda' if 'cuda' in device else '', 148 | f', {torch.mps.driver_allocated_memory()/2**30:.1f}GB mps' if 'mps' in device else '' ]) 149 | if 'cuda' in device: 150 | torch.cuda.reset_peak_memory_stats() 151 | return ret 152 | 153 | def like(tensor: torch.Tensor): 154 | return {'dtype': tensor.dtype, 'device': tensor.device} 155 | 156 | def num_params(module): 157 | if isinstance(module, nn.Module): 158 | return sum(p.numel() for p in module.parameters()) 159 | else: 160 | return module.numel() 161 | 162 | def log_forward(forward): 163 | def logged_forward(self, x, *args, log=None, **kwargs): 164 | y = forward(self, x, *args, log=log, **kwargs) 165 | # if log is not None: 166 | # log[f'{self.module_name}.x'] = mean2(x.detach()) 167 | # log[f'{self.module_name}.y'] = mean2(y.detach()) 168 | return y 169 | return logged_forward 170 | 171 | def tensor_items(xs, dtype=None): 172 | if isinstance(xs, list): 173 | return [tensor_items(x) for x in xs] 174 | elif isinstance(xs, tuple): 175 | return tuple(tensor_items(x) for x in xs) 176 | elif isinstance(xs, dict): 177 | return {k: tensor_items(v) for k,v in xs.items()} 178 | elif isinstance(xs, torch.Tensor): 179 | if dtype is None: 180 | dtype = xs.dtype 181 | if dtype == torch.bfloat16: 182 | dtype = torch.float32 183 | return xs.item() if xs.dim()==0 else xs.detach().cpu().to(dtype=dtype).numpy() 184 | elif isinstance(xs, np.ndarray): 185 | return xs if dtype is None else xs.astype(dtype, copy=True) 186 | elif hasattr(xs, '__next__') and not hasattr(xs, '__getitem__'): 187 | return (tensor_items(x) for x in xs) 188 | else: 189 | return xs 190 | 191 | def default_device(): 192 | return 'cuda' if torch.cuda.is_available() else \ 193 | 'mps' if torch.backends.mps.is_available() else 'cpu' 194 | 195 | def 
synchronize_device(device): 196 | device = torch.device(device).type 197 | eval(f'torch.{device}.synchronize')() 198 | 199 | def empty_cache(device): 200 | def device_is(dev): 201 | return dev in device if isinstance(device, str) else dev == torch.device(device).type 202 | 203 | if device_is('cuda'): 204 | torch.cuda.empty_cache() 205 | elif device_is('mps'): 206 | torch.mps.empty_cache() 207 | 208 | def make_dataclasses(data_classes, **kwargs): 209 | field_typess = [ {field.name : field.type for field in dataclasses.fields(data_class)} 210 | for data_class in data_classes ] 211 | dicts = [{} for _ in data_classes] 212 | 213 | for k, v in kwargs.items(): 214 | used = False 215 | for field_types, dict0 in zip(field_typess, dicts): 216 | field_type = field_types.get(k) 217 | if field_type is not None: 218 | assert not used 219 | used = True 220 | if v is not None: 221 | dict0[k] = field_type(v) 222 | 223 | if not used: 224 | raise Exception(f'make_dataclasses: {k} not found in {data_classes}') 225 | 226 | return [data_class(**dict0) for data_class, dict0 in zip(data_classes, dicts)] 227 | 228 | # def entropy(logits): 229 | # return - (logits.softmax(-1) * logits.log_softmax(-1)).sum(-1) # todo inner product 230 | 231 | def cross_entropy(logits, targets, reduction='mean', ignore_index=None): 232 | B, T, V = logits.shape 233 | 234 | if reduction == 'mean' and ignore_index is None: 235 | return F.cross_entropy(logits.reshape(B*T, V), targets.reshape(B*T), reduction='mean') 236 | 237 | cross_entropy = F.cross_entropy(logits.reshape(B*T, V), targets.reshape(B*T), 238 | ignore_index=ignore_index if ignore_index is not None else -100, 239 | reduction='none').view(B, T) 240 | if reduction == 'none': 241 | return cross_entropy 242 | 243 | if reduction == 'batch': 244 | return cross_entropy.sum(0) / (targets >= 0).sum(0) # T 245 | 246 | # NOTE: pytorch takes the mean over batch and context at the same time, 247 | # which isn't invariant under changes of micro-batch size when some indices are ignored 248 | # to fix this, we take the mean over the context before taking a separate mean over the batch 249 | 250 | cross_entropy = cross_entropy.sum(1) / (targets >= 0).sum(1) # B 251 | if reduction == 'context': 252 | return cross_entropy 253 | 254 | assert reduction == 'mean' 255 | return cross_entropy.mean() 256 | 257 | # Tee 258 | 259 | class Tee: 260 | def __init__(self, file_name): 261 | self.file = None 262 | if Tee._TeeStreams is None: 263 | # do nothing if sys.stdout or sys.stderr have been modified (e.g. 
in jupyter) 264 | if sys.stdout is not sys.__stdout__ or sys.stderr is not sys.__stderr__: 265 | return 266 | sys.stdout = Tee._TeeStream(sys.stdout) 267 | sys.stderr = Tee._TeeStream(sys.stderr) 268 | Tee._TeeStreams = (sys.stdout, sys.stderr) 269 | 270 | self.file = open(file_name, 'w', buffering=1) 271 | Tee._files.append(self.file) 272 | 273 | _files = [] 274 | _TeeStreams = None 275 | 276 | def __del__(self): 277 | if self.file is not None: 278 | self.file.close() 279 | Tee._files.remove(self.file) 280 | self.file = None 281 | 282 | class _TeeStream: 283 | def __init__(self, stream): 284 | self.stream = stream 285 | 286 | def write(self, data): 287 | self.stream.write(data) 288 | for f in Tee._files: 289 | f.write(data) 290 | 291 | def flush(self): 292 | self.stream.flush() 293 | for f in Tee._files: 294 | f.flush() 295 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from dataclasses import dataclass 4 | # from typing import Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | import torch.utils.checkpoint 10 | 11 | try: 12 | import flash_attn 13 | except: 14 | flash_attn = None 15 | 16 | from rope import * 17 | 18 | import util 19 | from util import log_forward, num_params, notNone, like, int_div 20 | 21 | @dataclass 22 | class TransformerConfig: 23 | vocab_size: int = None 24 | BOS: int = None 25 | 26 | tokenizer: str = None 27 | d_model: int = 256 28 | context_size: int = None 29 | d_key: int = 64 30 | d_ff_mult: int = 4 31 | n_layers: int = None 32 | tie_embedding: bool = None 33 | 34 | # "Small-scale proxies for large-scale Transformer training instabilities" 35 | qk_layer_norm: bool = True 36 | 37 | # 'muTransfer' = "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer" 38 | init: str = 'simple' # 'standard', 'muTransfer' 39 | 40 | nonlinear: str = 'ReLU' # ReLU, GELU 41 | rope: bool = False 42 | position_encoding: bool = True 43 | attention_groups: int = None # grouped-query attention 44 | attention_window: int = None 45 | sparse_attention: bool = None # requires attention_window 46 | layer_norm: str = 'LayerNorm' # 'RMSNorm' 47 | 48 | def padded_vocab_size(self): 49 | d = 64 50 | return util.ceil(self.vocab_size, d) 51 | 52 | def __post_init__(self): 53 | d = self.d_model 54 | 55 | if self.context_size is None: 56 | self.context_size = d 57 | 58 | if self.n_layers is None: 59 | # n_layers ~ (log(d) - 5.039)/0.0555) from Eq 11 in "The Depth-to-Width Interplay in Self-Attention" 60 | self.n_layers = max(3, round(12.49*math.log2(d/154.3))) 61 | 62 | if self.tie_embedding is None: 63 | self.tie_embedding = self.tokenizer is not None 64 | 65 | assert self.init in ('standard', 'muTransfer', 'simple') 66 | 67 | def copy(self, **kwargs): 68 | c = copy.copy(self) 69 | for k, v in kwargs.items(): 70 | setattr(c, k, v) 71 | return c 72 | 73 | class Model(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | 77 | def __post_init__(self): 78 | for name, module in self.named_modules(): 79 | module.module_name = name 80 | for name, param in self.named_parameters(): 81 | param.parameter_name = name 82 | 83 | def n_flops(self, training=True, average=False): 84 | return 2*(1 + 2*training)*self.n_mult_add() 85 | 86 | def num_params(self, embedding=True): 87 | return util.num_params(self) 88 | 89 | def next_iter(self, train_fraction): 90 | 
return dict(warmup_lr_again=False, check_grads=True) 91 | 92 | def train_log(self, optimizer_state): 93 | return {} 94 | 95 | class NoModel(Model): 96 | Config = TransformerConfig 97 | 98 | def __init__(self, config: TransformerConfig): 99 | super().__init__() 100 | self.config = config 101 | self.param = nn.Parameter(torch.tensor(1.)) 102 | super().__post_init__() 103 | 104 | def n_mult_add(self): 105 | return 1 106 | 107 | def forward(self, tokens, targets=None, *, log=None): 108 | B, T = tokens.shape 109 | logits = torch.zeros(B, T, self.config.vocab_size, device=tokens.device) 110 | losses = {'loss': self.param**2} 111 | return logits, losses 112 | 113 | class Transformer(Model): 114 | Config = TransformerConfig 115 | 116 | def __init__(self, config: TransformerConfig): 117 | super().__init__() 118 | self.config = config 119 | c = self.config 120 | 121 | self.token_embedding = nn.Embedding( 122 | c.padded_vocab_size() if c.tie_embedding else c.vocab_size, c.d_model) 123 | if c.position_encoding: 124 | self.position_encoding = nn.Parameter(torch.randn(c.context_size, c.d_model)) 125 | 126 | self.blocks = nn.ModuleList([TransformerBlock( 127 | c.copy(sparse_attention=(l%2) and c.sparse_attention) 128 | ) for l in range(c.n_layers)]) 129 | 130 | self.logits = Logits(c, self.token_embedding if c.tie_embedding else None) 131 | 132 | super().__post_init__() 133 | 134 | def num_params(self, embedding=True): 135 | c = self.config 136 | n = num_params(self) 137 | if not embedding: 138 | if not c.tie_embedding: 139 | n -= num_params(self.token_embedding) 140 | if c.position_encoding: 141 | n -= num_params(self.position_encoding) 142 | return n 143 | 144 | def n_mult_add(self): 145 | c = self.config 146 | T = c.context_size 147 | d = c.d_model 148 | V = c.vocab_size 149 | # simple case: L*T*(4*d*d + 2*d*W + 2*d*d_ff) 150 | return sum(block.n_mult_add(T) for block in self.blocks) + T*d*V 151 | 152 | def forward(self, tokens, targets=None, *, cache=None, log=None, ignore_index=None): 153 | assert targets is None or cache is None 154 | _, Tx = tokens.shape 155 | 156 | x = self.token_embedding(tokens) 157 | 158 | if cache is None: 159 | t = 0 160 | else: 161 | prefix = self.module_name + '->' 162 | t = cache.get(prefix+'t', 0) 163 | cache[prefix+'t'] = t + Tx 164 | 165 | if self.config.position_encoding: 166 | x = x + self.position_encoding[t:t+Tx] 167 | 168 | for block in self.blocks: 169 | x = block(x, cache=cache, cache_seqlen=t, log=log) 170 | 171 | logits = self.logits(x, log=log) 172 | 173 | losses = None 174 | if targets is not None: 175 | losses = {} 176 | losses['cross entropy'] = util.cross_entropy(logits, targets, ignore_index=ignore_index) 177 | losses['loss'] = losses['cross entropy'] 178 | 179 | return logits, losses 180 | 181 | def generate(self, tokens, *, max_tokens=None, temperature=1.0, top_k=None, input_lengths=None, 182 | logits=False, use_cache=True, check_logits_func=None): 183 | if max_tokens is None: 184 | max_tokens = self.config.context_size + 1 185 | else: 186 | assert max_tokens <= self.config.context_size + 1 187 | B, Tx = tokens.shape 188 | if input_lengths is None: 189 | input_lengths = torch.full((B,), Tx, device=tokens.device) 190 | V = self.config.vocab_size 191 | if max_tokens > Tx: 192 | tokens = torch.cat([tokens, torch.zeros((B, max_tokens-Tx), **like(tokens))], 1) 193 | else: 194 | tokens = tokens[:, :max_tokens].clone() 195 | 196 | save_logits = logits or check_logits_func is not None 197 | if save_logits: 198 | generated_logits = torch.full((B, max_tokens - 
1, V), float('nan'), device=tokens.device) 199 | 200 | t = 0 201 | cache = {} 202 | with torch.no_grad(): 203 | while True: 204 | current_length = t + 1 if t > 0 else input_lengths.min() 205 | if current_length >= max_tokens: 206 | break 207 | 208 | if use_cache: 209 | next_logits, _ = self(tokens[:, t:current_length], cache=cache) 210 | else: 211 | next_logits, _ = self(tokens[:, :current_length]) 212 | if check_logits_func is not None and t > 0: 213 | check_logits_func(next_logits[:, :t] - generated_logits[:, :t]) 214 | next_logits = next_logits[:, t:current_length] 215 | 216 | if save_logits: 217 | generated_logits[:, t:current_length] = next_logits 218 | t = current_length 219 | next_logits = next_logits[:, -1] / temperature # (B, V) 220 | if top_k is not None and top_k < V: 221 | v, _ = torch.topk(next_logits, top_k) 222 | next_logits[next_logits < v[:, -1:]] = -math.inf 223 | 224 | next_token = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)[:,0] # (B, 1) 225 | next_token = next_token.where(current_length >= input_lengths, tokens[:, t]) 226 | tokens[:, t] = next_token 227 | del cache 228 | 229 | return tokens, generated_logits if logits else None 230 | 231 | class Logits(nn.Module): 232 | def __init__(self, config, tied_token_embedding=None): 233 | super().__init__() 234 | self.vocab_size = config.vocab_size 235 | 236 | self.layer_norm = eval(config.layer_norm)(config.d_model, config) 237 | self.logits_linear = tied_embedding_linear(tied_token_embedding, config) \ 238 | if tied_token_embedding is not None else \ 239 | Linear(config.d_model, config.padded_vocab_size(), config, output=True) 240 | 241 | def forward(self, x, *, log): 242 | x = self.layer_norm(x, log=log) 243 | x = self.logits_linear(x, log=log) 244 | x = x[:, :, :self.vocab_size] # throw away padding 245 | return x 246 | 247 | def tied_embedding_linear(embedding, config): 248 | if config.init == 'standard': 249 | mult = 1 250 | elif config.init == 'muTransfer': 251 | mult = 1 / config.d_model 252 | elif config.init == 'simple': 253 | mult = config.d_model ** -0.5 # PaLM does this: "Because the input and output embedding layers are shared, we scale the pre-softmax output logits by 1/√n, where n is the embedding size." 
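# For concreteness (assuming the default d_model = 256): the tied-embedding
# logit multiplier above works out to 1 for 'standard', 1/256 ≈ 0.0039 for
# 'muTransfer', and 256 ** -0.5 = 1/16 = 0.0625 for the default 'simple' init.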
254 | else: 255 | assert False 256 | 257 | return lambda x, *, log=None: F.linear(mult * x, embedding.weight) 258 | 259 | class TransformerBlock(nn.Module): 260 | def __init__(self, config): 261 | super().__init__() 262 | 263 | self.attention = SelfAttention(config) 264 | self.feedforward = FeedForward(config) 265 | 266 | def n_mult_add(self, T): 267 | return self.attention.n_mult_add(T) + self.feedforward.n_mult_add(T) 268 | 269 | def forward(self, x, *, log, **kwargs): 270 | x = x + self.attention(x, log=log, **kwargs) 271 | x = x + self.feedforward(x, log=log) 272 | return x 273 | 274 | class SelfAttention(nn.Module): 275 | def __init__(self, config: TransformerConfig): 276 | super().__init__() 277 | d = config.d_model 278 | self.d_model = d 279 | self.d_key = config.d_key 280 | self.init = config.init 281 | self.attention_window = config.attention_window 282 | self.sparse_attention = config.sparse_attention 283 | self.context_size = config.context_size 284 | assert self.attention_window is None or self.attention_window > 0 285 | n_head = int_div(d, self.d_key) 286 | 287 | self.layer_norm = eval(config.layer_norm)(d, config) 288 | self.attention_groups = config.attention_groups 289 | g = config.attention_groups 290 | if g is None: 291 | self.QKV_linear = Linear(d, 3*d, config) 292 | else: 293 | assert n_head % g == 0 294 | D_k = g * self.d_key 295 | self.Q_linear = Linear(d, d, config) 296 | self.KV_linear = Linear(d, 2*D_k, config) 297 | self.linear = Linear(d, d, config) 298 | 299 | self.qk_layer_norm = config.qk_layer_norm 300 | if self.qk_layer_norm: 301 | self.Q_layer_norm = eval(config.layer_norm)(self.d_key, config) 302 | self.K_layer_norm = eval(config.layer_norm)(self.d_key, config) 303 | 304 | self.rope = config.rope 305 | if config.rope: 306 | freqs_cis = precompute_freqs_cis(self.d_key, config.context_size) 307 | self.register_buffer('freqs_cis', freqs_cis, persistent=False) 308 | 309 | def n_mult_add(self, T): 310 | W = self.attention_window 311 | if W is None: 312 | W = T 313 | d = self.d_model 314 | return T*(4*d*d + 2*d*W) 315 | 316 | def forward(self, x, *, log, cache=None, cache_seqlen=None): 317 | if self.sparse_attention: 318 | B, Tx, d = x.shape 319 | W = self.attention_window 320 | x = x.view(B, Tx//W, W, d).transpose(1, 2).reshape(B*W, Tx//W, d) 321 | 322 | B, Tx, d = x.shape 323 | d_k = self.d_key 324 | n_h = d // d_k 325 | 326 | x = self.layer_norm(x, log=log) 327 | 328 | g = self.attention_groups 329 | if g is None: 330 | Q, K, V = self.QKV_linear(x, log=log).view(B, Tx, 3*n_h, d_k).chunk(3, dim=2) 331 | else: 332 | Q = self.Q_linear(x, log=log).view(B, Tx, n_h, d_k) 333 | K, V = self.KV_linear(x, log=log).view(B, Tx, 2*g, d_k).chunk(2, dim=2) 334 | 335 | if self.qk_layer_norm: 336 | Q = self.Q_layer_norm(Q, log=log) 337 | K = self.K_layer_norm(K, log=log) 338 | Q = Q.to(V.dtype) 339 | K = K.to(V.dtype) 340 | 341 | if cache is None: 342 | t = 0 343 | # ignore cache_seqlens 344 | else: 345 | prefix = self.module_name + '->' 346 | assert cache_seqlen is not None 347 | t = cache_seqlen 348 | 349 | if self.rope: 350 | Q, K = apply_rotary_emb(Q, K, self.freqs_cis[t:t+Tx]) 351 | 352 | if self.init == 'muTransfer' and not self.qk_layer_norm: 353 | Q = Q * d_k**-0.5 354 | 355 | if cache is not None: 356 | if prefix+'KV' not in cache: 357 | cache[prefix+'KV'] = torch.stack([K, V]) 358 | else: 359 | cache_KV = cache[prefix+'KV'] 360 | cache_T = cache_KV.shape[1+1] 361 | if t+Tx >= cache_T: 362 | padding_shape = list(cache_KV.shape) 363 | padding_shape[1+1] = 
min(2**round(math.log2(t+Tx) + 1), self.context_size) - cache_T 364 | cache_KV = torch.cat([ 365 | cache_KV, 366 | torch.zeros(*padding_shape, **like(cache_KV)) 367 | ], 1+1) 368 | cache[prefix+'KV'] = cache_KV 369 | 370 | cache_KV[0, :, t:t+Tx] = K 371 | cache_KV[1, :, t:t+Tx] = V 372 | K, V = cache_KV[:, :, :t+Tx] 373 | 374 | attention_window = None if self.sparse_attention else self.attention_window 375 | if flash_attn is not None and Q.device.type == 'cuda': 376 | window_size = (-1,-1) if attention_window is None else (attention_window-1, 0) 377 | x = flash_attn.flash_attn_func(Q, K, V, causal=True, window_size=window_size) # (B, T, n_h, d_k) 378 | else: 379 | assert Q.device.type != 'cuda' or Tx > 1 # else significant performance loss 380 | Tk = K.shape[1] 381 | 382 | mask = None 383 | if cache is not None or attention_window is not None: 384 | q_ts = torch.arange(t, t+Tx, **like(Q))[:,None] 385 | k_ts = torch.arange(Tk, **like(Q)) 386 | mask = q_ts >= k_ts 387 | if attention_window is not None: 388 | mask &= q_ts - k_ts <= attention_window 389 | 390 | Q = Q.transpose(1, 2) 391 | K = K.transpose(1, 2) 392 | V = V.transpose(1, 2) 393 | if g is not None: 394 | Q = Q.view(B, g, n_h//g, Tx, d_k) 395 | K = K.view(B, g, 1, Tk, d_k) 396 | V = V.view(B, g, 1, Tk, d_k) 397 | if mask is not None: 398 | mask = mask[None] 399 | 400 | x = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask, is_causal=mask is None) 401 | x = x.view(B, n_h, Tx, d_k).transpose(1, 2).contiguous() 402 | 403 | x = x.view(B, Tx, d) 404 | x = self.linear(x, log=log) 405 | 406 | if self.sparse_attention: 407 | BW, Tx_W, d = x.shape 408 | B = BW // W 409 | Tx = Tx_W * W 410 | W = self.attention_window 411 | x = x.view(B, W, Tx//W, d).transpose(1, 2).reshape(B, Tx, d) 412 | 413 | return x 414 | 415 | class FeedForward(nn.Module): 416 | def __init__(self, config): 417 | super().__init__() 418 | d = config.d_model 419 | 420 | ReLU = nn.ReLU() 421 | GELU = nn.GELU(approximate='tanh') 422 | SiLU = nn.SiLU() 423 | self.nonlinear = eval(config.nonlinear) 424 | 425 | d_ff = config.d_ff_mult * d 426 | self.layer_norm = eval(config.layer_norm)(d, config) 427 | self.linear_1 = Linear(d, d_ff, config) 428 | self.linear_2 = Linear(d_ff, d, config) 429 | 430 | def n_mult_add(self, T): 431 | return T*(num_params(self.linear_1) + num_params(self.linear_2)) 432 | 433 | def forward(self, x, *, log): 434 | x = self.layer_norm(x, log=log) 435 | x = self.linear_1(x, log=log) 436 | x = self.nonlinear(x) 437 | x = self.linear_2(x, log=log) 438 | return x 439 | 440 | class Linear(nn.Module): 441 | '''Linear with normal-distributed initialization and no bias''' 442 | def __init__(self, d_in, d_out, config, output=False): 443 | super().__init__() 444 | 445 | self.mult = 1 446 | if config.init == 'standard': 447 | std = d_in ** -0.5 448 | lr_mult = 1 449 | elif config.init == 'muTransfer' and not output: 450 | # [1] "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer" 451 | # # hidden weights column of Table 3 and 8 of [1]: 452 | std = d_in ** -0.5 453 | lr_mult = 1 / d_in 454 | 455 | # hidden weights column of Table 8 of [1] modifed to have std = 1: 456 | # std = 1 457 | # self.mult = d_in ** -0.5 458 | # lr_mult = d_in ** -0.5 459 | elif config.init == 'muTransfer' and output: 460 | # output weights column of Table 3 of [1]: 461 | std = 1 / d_in 462 | lr_mult = 1 / d_in 463 | 464 | # output weights column of Table 8 of [1]: 465 | # std = 1 466 | # self.mult = 1 / d_in 467 | # lr_mult = 1 468 | elif config.init 
== 'simple': 469 | std = d_in ** -0.5 470 | lr_mult = d_in ** -0.5 471 | 472 | # std = 1 473 | # self.mult = d_in ** -0.5 474 | # lr_mult = 1 475 | else: 476 | assert False 477 | 478 | self.weight = nn.Parameter(torch.normal( 0, std, (d_out, d_in) )) 479 | 480 | if lr_mult != 1: 481 | self.weight.lr_mult = lr_mult 482 | 483 | @log_forward 484 | def forward(self, x, *, log): 485 | if self.mult != 1: 486 | x = self.mult * x 487 | 488 | return F.linear(x, self.weight) 489 | 490 | def num_linear_params(module): 491 | return sum(num_params(submod) for submod in module.modules() if isinstance(submod, Linear)) 492 | 493 | class LayerNorm(nn.Module): 494 | def __init__(self, d, config): 495 | super().__init__() 496 | self.d = d 497 | self.scale = nn.Parameter(torch.ones(d)) 498 | 499 | @log_forward 500 | def forward(self, x, *, log): 501 | return F.layer_norm(x, (self.d,), self.scale, bias=None) 502 | 503 | class RMSNorm(nn.Module): 504 | def __init__(self, d, config): 505 | super().__init__() 506 | self.scale = nn.Parameter(torch.ones(d)) 507 | 508 | @log_forward 509 | def forward(self, x, *, log): 510 | eps = 1e-5 511 | return x * self.scale / (x.var(-1, keepdim=True, correction=0) + eps).sqrt() 512 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from pathlib import Path 5 | import sys 6 | import shutil 7 | import time 8 | import datetime 9 | import math 10 | # import itertools 11 | import copy 12 | import contextlib 13 | import collections 14 | import dataclasses 15 | from dataclasses import dataclass 16 | # from typing import List, Optional, Tuple, Union 17 | from tqdm import tqdm 18 | 19 | import wandb 20 | import torch 21 | import torch.nn as nn 22 | from torch.nn import functional as F 23 | from torch.optim import Adam, NAdam, AdamW # needed for TrainConfig.optimizer 24 | 25 | import util 26 | from util import notNone 27 | import data 28 | from transformer import Transformer, TransformerConfig, NoModel 29 | from spacebyte import SpaceByte, SpaceByteConfig 30 | from megabyte import MegaByte, MegaByteConfig 31 | 32 | @dataclass 33 | class TrainConfig: 34 | # model 35 | model: str = 'Transformer' 36 | model_seed: int = 0 37 | 38 | # command line 39 | args: dict = None 40 | note: str = '' 41 | 42 | # data 43 | data_seed: int = 0 44 | dataset: str = 'pg19' 45 | batch_size: int = 64 46 | micro_batch_size: int = None 47 | 48 | # checkpoint 49 | load_checkpoint_dir: str = None 50 | checkpoint: bool = True 51 | checkpoint_model: bool = True 52 | min_checkpoint_interval: float = 10*60 # in seconds 53 | 54 | # optimization 55 | lr: str = '0.5e-2 * B**0.5' # can use: B, T, d 56 | iters: str = '20 * N / tokens' # can use: N, num embedding E, tokens, and flops 57 | optimizer: str = 'AdamW' 58 | optimizer_kwargs: str = 'dict(fused=using_cuda)' 59 | decay_lr: str = 'half-cos' # 'half-cos', 'cos' 60 | warmup_iters: str = 'iters/100' 61 | beta1: float = 0.9 62 | beta2: str = 0.98 63 | # note: pytorch does not implement weight_decay as in "Decoupled Weight Decay Regularization" since pytorch multiplies the weight decay by the max learning rate while the paper does not 64 | weight_decay: float = 0.01 65 | grad_clip: float = 1.0 66 | 67 | # system 68 | device: str = None # 'cpu', 'mps', 'cuda' 69 | compile: bool = False 70 | dtype: str = None # 'float32' for no autocast. 'bfloat16' or 'float16' for autocast. 
'float16' will use a GradScaler 71 | use_deterministic: bool = False 72 | 73 | # logging 74 | out_dir: str = 'out' # set to '' to disable. auto subdir unless trailing / 75 | log_interval: float = 3 # in seconds 76 | 77 | # evaluating 78 | eval_interval: str = 'iters / 50' 79 | eval_iters: str = 'eval_interval/40 * B/mB' # number of micro-batches for mid-training evaluation 80 | final_eval_iters: str = 'min(iters/30, 2**14) * B/mB' # number of micro-batches for the final evaluation 81 | 82 | # wandb logging 83 | wandb_log: bool = None 84 | wandb_project: str = None 85 | wandb_run_name: str = None 86 | 87 | # debug 88 | checkpoint_nan: bool = False 89 | check_nan: bool = False 90 | 91 | def __post_init__(self): 92 | if self.device is None: 93 | self.device = util.default_device() 94 | 95 | if self.compile is None: 96 | self.compile = 'cuda' in self.device 97 | 98 | if self.dtype is None: 99 | # todo M2 supports bfloat16 https://en.wikipedia.org/wiki/Bfloat16_floating-point_format 100 | self.dtype = 'bfloat16' if 'cuda' in self.device and torch.cuda.is_bf16_supported() else 'float32' 101 | 102 | if self.wandb_log is None: 103 | self.wandb_log = self.out_dir != 'out' and self.out_dir[-1] != '/' 104 | 105 | if self.wandb_project is None: 106 | self.wandb_project = self.out_dir 107 | if self.args is not None: 108 | def run_args(sep=' '): 109 | def simplify(v): 110 | if isinstance(v, tuple): 111 | return ','.join(str(x) for x in v) 112 | else: 113 | return v 114 | return sep.join(f'--{k}={simplify(v)}' for k, v in sorted(self.args.items())) 115 | if self.wandb_run_name is None: 116 | self.wandb_run_name = run_args() 117 | if self.out_dir and self.out_dir[-1] != '/': 118 | self.out_dir = os.path.join(self.out_dir, 'Train' + run_args('')) 119 | 120 | class Train: 121 | def from_checkpoint(dir_name, device=None, 122 | train_config_override=dict(wandb_log=False, out_dir=''), 123 | model_config_override=dict()): 124 | if device is None: 125 | device = util.default_device() 126 | model, checkpoint = Train.model_from_checkpoint(dir_name, device, config_override=model_config_override) 127 | 128 | train_config = checkpoint['train_config'] 129 | for k in list(train_config): 130 | if not hasattr(TrainConfig, k): 131 | print(f"TrainChar: WARNING! 
'{k}' no longer in TrainCharConfig") 132 | del train_config[k] 133 | train_config['device'] = device 134 | train_config |= train_config_override 135 | train_config = TrainConfig(**train_config) 136 | 137 | return Train(model, train_config, checkpoint=checkpoint) 138 | 139 | def __init__(self, model_or_config, train_config: TrainConfig, checkpoint=None, verbose=None): 140 | super().__init__() 141 | self.train_config = train_config 142 | c = self.train_config 143 | self._estimate_losses_dataset_iters = {} 144 | 145 | # c.out_dir might still be '' 146 | if c.out_dir: 147 | print('out_dir =', c.out_dir) 148 | os.makedirs(c.out_dir, exist_ok = not c.wandb_log) 149 | code_dir = os.path.dirname(os.path.realpath(__file__)) 150 | for f in os.listdir(code_dir): 151 | if f[-3:] == '.py': 152 | shutil.copy(os.path.join(code_dir, f), c.out_dir) 153 | self.tee = util.Tee(os.path.join(c.out_dir, 'stdout.txt')) 154 | 155 | if verbose is None: 156 | verbose = not (util.interactive_mode() and checkpoint is not None) 157 | 158 | if c.use_deterministic: 159 | torch.use_deterministic_algorithms(True) 160 | 161 | if c.check_nan: 162 | torch.set_anomaly_enabled(True, check_nan=True) 163 | 164 | self.autocast = util.autocast_context(c.dtype) 165 | self.grad_scaler = torch.cuda.amp.GradScaler(enabled = 'float16' == c.dtype) 166 | 167 | if isinstance(model_or_config, nn.Module): 168 | self.model = model_or_config 169 | model_config = self.model.config 170 | else: 171 | self.model = None 172 | model_config = model_or_config 173 | 174 | self.dataset = data.dataset(c.dataset, model_config.tokenizer) 175 | if model_config.vocab_size is None: 176 | model_config.vocab_size = self.dataset.vocab_size 177 | else: 178 | assert model_config.vocab_size == self.dataset.vocab_size 179 | if model_config.BOS is None: 180 | model_config.BOS = self.dataset.BOS 181 | else: 182 | assert model_config.BOS == self.dataset.BOS 183 | 184 | self.checkpoint_vars = ['iter_num', 'decay_lr_from_iter', 'lr_sum', 'best_val_loss', 'total_flops', 'total_tokens', 185 | 'train_time', 'eval_time'] 186 | if c.load_checkpoint_dir or checkpoint is not None: 187 | checkpoint = c.load_checkpoint_dir if checkpoint is None else checkpoint 188 | device = self.train_config.device 189 | if self.model is None: 190 | model_config_override = {k: v for k, v in c.args.items() if hasattr(model_config, k)} 191 | self.model, checkpoint = Train.model_from_checkpoint(checkpoint, device, config_override=model_config_override) 192 | for var in self.checkpoint_vars: 193 | if var in checkpoint: 194 | setattr(self, var, checkpoint[var]) 195 | self.iter_num += 1 196 | # we load the optimizer checkpoint after initializing it below 197 | else: 198 | checkpoint = None 199 | torch.manual_seed(c.model_seed) 200 | self.model = eval(c.model)(model_config) 201 | self.model.to(c.device) 202 | for var in self.checkpoint_vars: 203 | setattr(self, var, 0) 204 | self.best_val_loss = math.inf 205 | mc = self.model.config 206 | self.model.dataset_tokenizer = self.dataset.tokenizer # useful for debugging 207 | 208 | N = self.model.num_params() 209 | N_E = self.model.num_params(embedding=False) 210 | T = mc.context_size 211 | B = c.batch_size 212 | 213 | if c.micro_batch_size is None: 214 | c.micro_batch_size = c.batch_size 215 | assert c.batch_size % c.micro_batch_size == 0 216 | mB = c.micro_batch_size 217 | 218 | if isinstance(c.iters, str): 219 | flops = c.batch_size * self.model.n_flops(average=True) 220 | tokens = c.batch_size * T 221 | N = N_E 222 | c.iters = round(eval(c.iters)) 223 | 
N = self.model.num_params() 224 | del flops, tokens 225 | iters = c.iters 226 | 227 | if isinstance(c.beta2, str): 228 | c.beta2 = eval(c.beta2) 229 | 230 | if isinstance(c.eval_interval, str): 231 | c.eval_interval = math.ceil(eval(c.eval_interval)) 232 | 233 | if isinstance(c.final_eval_iters, str): 234 | c.final_eval_iters = math.ceil(eval(c.final_eval_iters)) 235 | if 0 < c.final_eval_iters < 3: 236 | c.final_eval_iters = 3 # avoid a strange wandb error 237 | if isinstance(c.eval_iters, str): 238 | eval_interval = c.eval_interval 239 | c.eval_iters = math.ceil(eval(c.eval_iters)) 240 | if 0 < c.eval_iters < 3: 241 | c.eval_iters = 3 # avoid a strange wandb error 242 | 243 | d = mc.d_model 244 | if isinstance(c.lr, str): 245 | c.lr = eval(c.lr) 246 | lr = c.lr 247 | 248 | if isinstance(c.warmup_iters, str): 249 | c.warmup_iters = eval(c.warmup_iters) 250 | 251 | self.meta = { 252 | 'parameters': N, 253 | 'non-embedding parameters': N_E, 254 | 'mult-add per token': self.model.n_mult_add() / mc.context_size, 255 | 'train FLOPs per token': self.model.n_flops(average=True) / mc.context_size, 256 | 'train FLOPs': c.iters * B * self.model.n_flops(average=True), 257 | 'train tokens': c.iters * B * T, 258 | 'bytes per token': self.dataset.bytes_per_token, 259 | } 260 | if verbose: 261 | b = {'float32': 4, 'bfloat16': 2, 'float16': 2}[c.dtype] 262 | print(f'total parameters: {N/1e6:,.1f}M') 263 | print(f'total non-embedding parameters: {N_E/1e6:,.1f}M') 264 | print(f'model memory: {b*N/1e6:,.1f}MB') 265 | print(f'model+grad+Adam memory: {(1+1+2)*b*N/1e6:,.1f}MB') 266 | print(f'mB*T*d memory: {b*mB*T*d/1e6:,.1f}MB') 267 | print(f'mB*T*V memory: {b*mB*T*mc.vocab_size/1e6:,.1f}MB') 268 | print(f'train FLOPs per token = {self.meta["train FLOPs per token"] / 1e6:,.2f}M') 269 | print(f'mult-add per token = {self.meta["mult-add per token"] / 1e6:,.2f}M') 270 | print() 271 | 272 | param_groups = [] 273 | for param in self.model.parameters(): 274 | weight_decay = (param.weight_decay_mult if hasattr(param, 'weight_decay_mult') else 1) * c.weight_decay 275 | lr_mult = param.lr_mult if hasattr(param, 'lr_mult') else 1 276 | for group in param_groups: 277 | if math.isclose(weight_decay, group['weight_decay']) and math.isclose(lr_mult, group['lr_mult']): 278 | group['params'].append(param) 279 | break 280 | else: 281 | param_groups.append({'params': [param], 'weight_decay': weight_decay, 'lr_mult': lr_mult}) 282 | if verbose and len(param_groups) > 1: 283 | for group in param_groups: 284 | group = copy.copy(group) 285 | group['params'] = [param.parameter_name for param in group['params']] 286 | print(group) 287 | print() 288 | 289 | using_cuda = 'cuda' in c.device # possibly used by eval(c.optimizer_kwargs) 290 | if isinstance(c.optimizer_kwargs, str): 291 | c.optimizer_kwargs = eval(c.optimizer_kwargs) 292 | self.optimizer = eval(c.optimizer)(param_groups, lr=c.lr, betas=(c.beta1, c.beta2), **c.optimizer_kwargs) 293 | if checkpoint is not None: 294 | self.optimizer.load_state_dict(checkpoint['optimizer']) 295 | 296 | if c.compile: 297 | print('compiling the model...') 298 | self.model = torch.compile(self.model) 299 | 300 | if c.wandb_log: 301 | wandb.init( project=c.wandb_project, name=c.wandb_run_name, save_code=True, # reinit=True, 302 | config=dataclasses.asdict(mc) | dataclasses.asdict(c) | self.meta ) 303 | print() 304 | 305 | if verbose: 306 | for k, v in dataclasses.asdict(mc).items(): 307 | print(f'{k} = {v}') 308 | print() 309 | for k, v in dataclasses.asdict(c).items(): 310 | print(f'{k} = {v}') 
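# For concreteness, the string-valued schedule fields printed above have already
# been eval'd into numbers at this point. Assuming the defaults with illustrative
# sizes (lr = '0.5e-2 * B**0.5' with B = 64; iters = '20 * N / tokens' with
# tokens = B*T, T = 1024, and N evaluated as N_E = 100e6 non-embedding
# parameters): the peak learning rate is 0.005 * 8 = 0.04, and
# iters = round(20 * 100e6 / (64 * 1024)) = 30518, i.e. roughly 20 training
# tokens per non-embedding parameter.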
311 | print() 312 | 313 | def dataset_iter(self, split, dataset=None, **kwargs): 314 | c = self.train_config 315 | mc = self.model.config 316 | if dataset is None: 317 | dataset = self.dataset 318 | kwargs = dict(context_size=mc.context_size, batch_size=c.micro_batch_size, seed=c.data_seed, device=c.device) | kwargs 319 | return dataset.iter(split, **kwargs) 320 | 321 | def train(self, callback=None): 322 | c = self.train_config 323 | mc = self.model.config 324 | 325 | print(f'tokens todo: {(c.iters-self.iter_num) * c.batch_size * mc.context_size:.4g}') 326 | print(f'FLOPs todo: {(c.iters-self.iter_num) * c.batch_size * self.model.n_flops(average=True):.4g}\n') 327 | 328 | # if c.wandb_log: 329 | # wandb.watch(self.model, log_freq=c.eval_interval) 330 | 331 | self.optimizer.zero_grad() 332 | data_iter = self.dataset_iter('train') 333 | next_tokens = next(data_iter) 334 | 335 | # todo skip data up to self.iter_num in case of loading from a checkpoint 336 | if c.data_seed == 0: 337 | assert self.iter_num == 0 # for now, at least just make sure we change the data seed 338 | 339 | train_time_t0 = time.time() 340 | last_log_time = train_time_t0 341 | last_log_iter = -1 342 | last_checkpoint_time = time.time() 343 | 344 | for iter_num in range(self.iter_num, c.iters): 345 | self.iter_num = iter_num 346 | self.model.train() 347 | model_flags = self.model.next_iter(iter_num / c.iters) 348 | if model_flags['warmup_lr_again']: 349 | print('Train: warmup_lr_again') 350 | self.decay_lr_from_iter = self.iter_num 351 | 352 | lr = c.lr 353 | if c.decay_lr: 354 | t = iter_num/c.iters 355 | if c.decay_lr[:3] == 'cos': 356 | mult = eval(c.decay_lr[3:]) if len(c.decay_lr)>3 else 0.1 357 | lr *= mult + (1-mult) * (math.cos(math.pi*t)+1)/2 358 | elif c.decay_lr == 'half-cos': 359 | lr *= math.cos(0.5*math.pi*t) 360 | else: 361 | assert False 362 | if c.warmup_iters > 0: 363 | lr = lr * min(1, (iter_num+1 - self.decay_lr_from_iter)/c.warmup_iters) 364 | for param_group in self.optimizer.param_groups: 365 | param_group['lr'] = lr * param_group['lr_mult'] 366 | self.lr_sum += lr 367 | 368 | final_iter = iter_num == c.iters-1 369 | run_eval = c.final_eval_iters > 0 if final_iter else \ 370 | c.eval_iters > 0 and c.eval_interval > 0 and iter_num % c.eval_interval == 0 and iter_num > 0 371 | 372 | n_micro_batches = util.int_div(c.batch_size, c.micro_batch_size) 373 | for b in range(n_micro_batches): 374 | tokens, targets = next_tokens 375 | model_log = None # {} if run_eval and b+1 == n_micro_batches else None 376 | with self.autocast: 377 | _, losses = self.model(tokens, targets, log=model_log) 378 | loss = losses['loss'] 379 | 380 | next_tokens = next(data_iter) 381 | 382 | if c.checkpoint_nan and math.isnan(loss): 383 | if c.out_dir: 384 | self.checkpoint(c.out_dir + '_nan') 385 | raise Exception('nan') 386 | 387 | self.grad_scaler.scale(loss).backward() 388 | 389 | if c.grad_clip > 0: 390 | self.grad_scaler.unscale_(self.optimizer) 391 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), c.grad_clip) 392 | self.grad_scaler.step(self.optimizer) 393 | self.grad_scaler.update() 394 | 395 | # check that all parameters got a gradient, else we likely have unused parameters 396 | if iter_num == self.decay_lr_from_iter and model_flags['check_grads']: 397 | for param in self.model.parameters(): 398 | assert param.grad is not None, param.parameter_name 399 | 400 | if model_log is not None: 401 | model_log = model_log | self.model.train_log(self.optimizer) 402 | self.optimizer.zero_grad(set_to_none=True) 403 | 404 | 
flops = c.batch_size * self.model.n_flops() 405 | self.total_flops += flops 406 | self.total_tokens += n_micro_batches * tokens.numel() 407 | 408 | t1 = time.time() 409 | print_stats = run_eval or t1 - last_log_time > c.log_interval or iter_num == 0 410 | if print_stats: 411 | dt = (t1 - last_log_time) / (iter_num - last_log_iter) 412 | last_log_time = t1 413 | flops /= dt 414 | tps = tokens.numel() / dt 415 | last_log_iter = iter_num 416 | 417 | # print(util.chrs(tokens[0,:128].cpu())) 418 | # print(util.chrs(targets[0,:128].cpu())) 419 | 420 | # single print statement so all lines are printed at the same time 421 | eta = str(datetime.timedelta(seconds=round(dt*(c.iters-iter_num)))) 422 | print(f'iter {iter_num:6} ({100*iter_num/c.iters:7.3g}%,', 423 | f'{eta:15}):', 424 | ', '.join(f'{name} {loss:<8.3g}' for name, loss in losses.items()) + ',', 425 | f'{dt*1000:4.0f}ms, {flops/1e12:5.3g} TFLOP/s, {tps/1000:.1f}k tps,', 426 | util.get_memory_stats(device=c.device) + ',', 427 | datetime.datetime.now().strftime('%m/%d %H:%M:%S') ) 428 | assert not math.isnan(loss) 429 | del dt 430 | 431 | if callback is not None: 432 | callback(**locals()) 433 | del loss, losses 434 | 435 | # evaluate the loss, log, and write checkpoints 436 | if run_eval: 437 | util.synchronize_device(c.device) 438 | util.empty_cache(c.device) 439 | eval_start_time = time.time() 440 | self.train_time += eval_start_time - train_time_t0 441 | 442 | ideally_val = 'val' if 'val' in self.dataset.splits() else 'train' 443 | if not final_iter: 444 | eval_iters = c.eval_iters 445 | else: 446 | eval_iters = c.final_eval_iters 447 | splits = self.dataset.splits() 448 | split_losses = self.estimate_losses(eval_iters, splits, continue_iter=not final_iter) 449 | 450 | util.synchronize_device(c.device) 451 | eval_dt = time.time() - eval_start_time 452 | self.eval_time += eval_dt 453 | 454 | print(f'\neval iter {iter_num}: {eval_dt*1000:,.0f}ms') 455 | sorted_losses = list(split_losses[ideally_val].items()) 456 | sorted_losses.sort() 457 | print( f'{ideally_val} losses:', 458 | ', '.join(f'{name} {loss:.4g}' for name, loss in sorted_losses if isinstance(loss, float)) ) 459 | print() 460 | 461 | if c.wandb_log: 462 | wandb_losses = { f'{split} {name}': loss 463 | for split, losses in split_losses.items() 464 | for name, loss in losses.items() 465 | if 'token_XE' not in name or final_iter } 466 | final_losses = {f'final {name}': loss for name, loss in split_losses.items()} if final_iter else {} 467 | checkpoint_dict = { var.replace('_', ' '): getattr(self, var) 468 | for var in self.checkpoint_vars if var != 'iter_num' } 469 | wandb_dict = { 470 | 'iter': iter_num, 471 | 'lr': lr, 472 | 'FLOP/s': flops, 473 | 'TFLOP/s': flops / 1e12, 474 | 'eta': eta, 475 | 'trained percent': 100*iter_num/c.iters, 476 | } | checkpoint_dict | wandb_losses | final_losses | (model_log or {}) 477 | try: 478 | wandb.log(wandb_dict) 479 | except Exception as e: 480 | print('\nwandb.log Exception:', e) 481 | print('wandb_dict = ', wandb_dict) 482 | print() 483 | 484 | loss = split_losses[ideally_val]['loss'] 485 | new_best = loss < self.best_val_loss 486 | if new_best: 487 | self.best_val_loss = loss 488 | 489 | save_checkpoint = c.out_dir and c.checkpoint and ( final_iter or 490 | time.time() > last_checkpoint_time + c.min_checkpoint_interval ) 491 | if save_checkpoint: 492 | best_ckpt = os.path.join(c.out_dir, 'ckpt_best_loss.pt') 493 | checkpoint_dict = {'losses': split_losses} 494 | if not new_best: 495 | ckpt = os.path.join(c.out_dir, 'ckpt.pt') 496 | if 
os.path.exists(ckpt): 497 | shutil.move(ckpt, best_ckpt) 498 | self.checkpoint(c.out_dir, checkpoint_dict) 499 | last_checkpoint_time = time.time() 500 | 501 | if (new_best or final_iter) and os.path.exists(best_ckpt): 502 | os.remove(best_ckpt) 503 | 504 | util.empty_cache(c.device) 505 | train_time_t0 = time.time() 506 | last_log_time = time.time() 507 | 508 | if c.out_dir and (Path(c.out_dir) / 'STOP').exists(): 509 | assert False, 'STOP' 510 | 511 | if c.out_dir: 512 | with open(os.path.join(c.out_dir, 'FINISHED_TRAINING'), 'a'): 513 | pass 514 | 515 | def checkpoint(self, dir_name, checkpoint_dict): 516 | t0 = time.time() 517 | 518 | checkpoint_dict = dict( 519 | model = self.model.__class__.__name__, 520 | model_config = dataclasses.asdict(self.model.config), 521 | Train = self.__class__.__name__, 522 | train_config = dataclasses.asdict(self.train_config) 523 | ) | checkpoint_dict | self.meta 524 | for var in self.checkpoint_vars: 525 | checkpoint_dict[var] = getattr(self, var) 526 | 527 | print(f'saving checkpoint to {dir_name}') 528 | os.makedirs(dir_name, exist_ok=True) 529 | 530 | torch.save(checkpoint_dict, os.path.join(dir_name, 'ckpt_small.pt')) 531 | 532 | if self.train_config.checkpoint_model: 533 | checkpoint_dict |= dict( 534 | state_dict = self.model.state_dict(), 535 | optimizer = self.optimizer.state_dict() ) 536 | torch.save(checkpoint_dict, os.path.join(dir_name, 'ckpt.pt')) 537 | 538 | print(f'checkpoint saved in {time.time()-t0:.1f} seconds') 539 | print() 540 | 541 | def model_from_checkpoint(checkpoint, device=None, config_override={}): 542 | if isinstance(checkpoint, str): 543 | checkpoint = torch.load(os.path.join(checkpoint, 'ckpt.pt'), map_location=device) 544 | 545 | Model = eval(checkpoint['model']) 546 | config = checkpoint['model_config'] 547 | for k in list(config): 548 | if not hasattr(Model.Config, k): 549 | print(f"Train: WARNING! 
'{k}' no longer in {Model.Config.__name__}") 550 | del config[k] 551 | config = config | config_override 552 | model_config = Model.Config(**config) 553 | 554 | model = Model(model_config) 555 | if device: 556 | model.to(device) 557 | 558 | state_dict = checkpoint['state_dict'] 559 | model.load_state_dict(state_dict) 560 | 561 | return model, checkpoint 562 | 563 | def config_from_args(**kwargs): 564 | Model = eval(kwargs.get('model', 'Transformer')) 565 | ModelConfig = Model.Config 566 | model_config, train_config = util.make_dataclasses([ModelConfig, TrainConfig], **kwargs, args=kwargs) 567 | return model_config, train_config 568 | 569 | def from_args(**kwargs): 570 | model_config, train_config = Train.config_from_args(**kwargs) 571 | return Train(model_config, train_config) 572 | 573 | def estimate_losses(self, eval_iters, splits=None, continue_iter=False): 574 | # implement continue_iter 575 | def dataset_iter(*args, **kwargs): 576 | if not continue_iter: 577 | for x in self.dataset_iter(*args, **kwargs): 578 | yield x 579 | else: 580 | # cycle through all the data 581 | key_kwargs = copy.copy(kwargs) 582 | if 'dataset' in key_kwargs: 583 | key_kwargs['dataset'] = '' 584 | key = (*args, *sorted(key_kwargs.items())) 585 | iters = self._estimate_losses_dataset_iters 586 | while True: 587 | if key not in iters: 588 | iters[key] = self.dataset_iter(*args, **kwargs) 589 | for x in iters[key]: 590 | yield x 591 | del iters[key] 592 | 593 | def _estimate_loss(dataset_iter): 594 | return estimate_loss(dataset_iter, eval_iters, self.model, 595 | bytes_per_token=self.dataset.bytes_per_token, autocast=self.autocast) 596 | 597 | if splits is None: 598 | splits = self.dataset.splits() 599 | losses = {split: _estimate_loss(dataset_iter(split)) for split in splits} 600 | 601 | return losses 602 | 603 | def estimate_loss(dataset_iter, eval_iters, model, bytes_per_token=None, autocast=contextlib.nullcontext()): 604 | model.eval() 605 | with torch.inference_mode(): 606 | all_losses = collections.defaultdict(util.MeanError) 607 | for _ in range(eval_iters): 608 | tokens, targets = next(dataset_iter) 609 | with autocast: 610 | logits, losses = model(tokens, targets) 611 | 612 | losses['token_XE'] = util.cross_entropy(logits, targets, reduction='batch', ignore_index=-1) 613 | 614 | losses = util.tensor_items(losses, dtype=torch.float64) 615 | for name, loss in losses.items(): 616 | all_losses[name].add(loss) 617 | 618 | for name, losses in list(all_losses.items()): 619 | all_losses[name] = losses.mean() 620 | all_losses[name + ' stat'] = losses.error() 621 | 622 | if bytes_per_token is not None: 623 | BPB_mult = 1 / (bytes_per_token * math.log(2)) 624 | if 'cross entropy' in all_losses: 625 | all_losses['bits per byte'] = BPB_mult * all_losses['cross entropy'] 626 | all_losses['bits per byte stat'] = BPB_mult * all_losses['cross entropy stat'] 627 | 628 | all_losses = util.tensor_items(all_losses) 629 | return all_losses 630 | 631 | import sample 632 | 633 | def train(**kwargs): 634 | class LastLine: 635 | def __init__(self): 636 | self.last = '' 637 | 638 | def write(self, data): 639 | if len(self.last) and self.last[-1] == '\n': 640 | self.last = '' 641 | self.last += data 642 | 643 | def flush(self): 644 | pass 645 | 646 | # benchmark inference 647 | if 'benchmark_generate' in kwargs: 648 | del kwargs['benchmark_generate'] 649 | 650 | print(f'initializaing model...') 651 | sys.stdout = LastLine() 652 | trainer = Train.from_args(**kwargs) 653 | sys.stdout = sys.__stdout__ 654 | tc = trainer.train_config 
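# The loop below benchmarks generation throughput on the model that was just
# constructed: for each batch size it calls sample.sample, prints the
# wall-clock time, tokens per second, and memory use, then doubles
# batch_size; it only stops when sampling fails (typically out of memory).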
655 | 656 | B = int(kwargs.get('batch_size', 1)) 657 | while True: 658 | print(f'generating batch_size={B}') 659 | log = sample.sample( 660 | model = trainer.model, 661 | num_samples = 2, 662 | batch_size = B, 663 | device = tc.device, 664 | dtype = tc.dtype, 665 | compile = False, 666 | quiet = True) 667 | t = log['times'][-1] 668 | ys = log['generations'][-1] 669 | print(f'generation took {t:7.3f}s, {ys[:, 1:].numel()/t/1000:6.2f} ktps,', 670 | util.get_memory_stats(device=tc.device)) 671 | B *= 2 672 | 673 | # benchmark different batch sizes 674 | if kwargs.get('batch_size', '') == 'test': 675 | B = 1 676 | sys.stdout = LastLine() 677 | class NextBatchSize(Exception): 678 | pass 679 | 680 | while True: 681 | kwargs['batch_size'] = B 682 | sys.__stdout__.write(f'batch size = {B}:\n') 683 | trainer = Train.from_args(**kwargs) 684 | 685 | done = False 686 | def callback(**kwargs): 687 | nonlocal done 688 | if kwargs['print_stats']: 689 | if done: 690 | raise NextBatchSize 691 | done = True 692 | 693 | try: 694 | trainer.train(callback=callback) 695 | except NextBatchSize: 696 | sys.__stdout__.write(sys.stdout.last) 697 | B *= 2 698 | 699 | mB = None 700 | out_dir = None 701 | while True: 702 | try: 703 | t = time.time() 704 | 705 | model_config, train_config = Train.config_from_args(**kwargs) 706 | if out_dir is None: 707 | out_dir = train_config.out_dir 708 | else: 709 | train_config.out_dir = out_dir 710 | trainer = Train(model_config, train_config) 711 | 712 | if mB is None: 713 | mB = trainer.train_config.micro_batch_size 714 | 715 | try: 716 | trainer.train() 717 | except BaseException as e: 718 | i = 1 719 | while os.path.exists(fail_dir := f'{out_dir}_FAILED{i}'): 720 | i += 1 721 | shutil.move(out_dir, fail_dir) 722 | print(f'\nout_dir moved to {fail_dir}\n') 723 | 724 | trainer.tee.__del__() 725 | raise e 726 | 727 | # if trainer.train_config.wandb_log: 728 | # wandb.finish() 729 | return 730 | except torch.cuda.OutOfMemoryError as e: 731 | if mB == 1: 732 | print() 733 | print(f'micro_batch_size = {mB} can not be decreased') 734 | raise e 735 | 736 | # dt = time.time() - t 737 | # if dt > 30*60: 738 | # print() 739 | # print(f'OOM after running for {round(dt/60)} minutes') 740 | # raise e 741 | 742 | print() 743 | print(e) 744 | print(f'Retrying with micro_batch_size = {mB} -> {mB//2} ...') 745 | print() 746 | mB //= 2 747 | kwargs['micro_batch_size'] = mB 748 | kwargs['note'] = f"{kwargs.get('note','')}_mB{mB}" 749 | torch.cuda.empty_cache() 750 | # if trainer.train_config.wandb_log: 751 | # wandb.finish() 752 | 753 | import fire 754 | if __name__ == '__main__': 755 | fire.Fire(train) 756 | -------------------------------------------------------------------------------- /spacebyte_figure.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | _ 111 | 112 | 113 | 114 | 115 | s 116 | 117 | 118 | 119 | 120 | p 121 | 122 | 123 | 124 | 125 | a 126 | 127 | 128 | 129 | 130 | c 131 | 132 | 133 | 134 | 135 | e 136 | 
[remainder of spacebyte_figure.svg omitted: vector-graphics markup for the SpaceByte architecture schematic. The recoverable figure text consists of the example byte sequences "_space_is__all" and "_yspace_is_all" and the stage labels "embedding", "local transformer blocks" (L_local^(1) and L_local^(2)), "global transformer blocks" (L_global), and "de-embedding".]
--------------------------------------------------------------------------------
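For orientation, here is a minimal sketch of how the training entry point in train.py can be driven. The concrete values below (model width, context size, batch size, dataset, output directory) are illustrative assumptions, not settings from the paper, and the dataset is assumed to have already been prepared locally.

# Command line, via fire.Fire(train) at the bottom of train.py; keyword flags
# are split between the model config (e.g. TransformerConfig) and TrainConfig:
#
#   python train.py --model=Transformer --dataset=pg19 \
#       --d_model=512 --context_size=1024 --batch_size=32 --out_dir=out_test/
#
# The same run driven from Python:

import train

trainer = train.Train.from_args(
    model='Transformer',    # any model class imported in train.py (Transformer, SpaceByte, MegaByte, NoModel)
    dataset='pg19',         # illustrative; must already be prepared
    d_model=512,            # TransformerConfig field (illustrative size)
    context_size=1024,      # TransformerConfig field (illustrative size)
    batch_size=32,          # TrainConfig field (illustrative size)
    out_dir='out_test/',    # trailing '/' keeps this exact directory and leaves wandb logging disabled
)
trainer.train()

Copies of the source files, stdout.txt, and checkpoints are written into out_dir, and a saved run can be reloaded later with Train.from_checkpoint('out_test/').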