├── src ├── __init__.py ├── configs │ ├── __init__.py │ ├── shakespeare_char.py │ ├── openwebtext.py │ ├── openwebtext_mh.py │ └── openwebtext_xl.py ├── sharding.py ├── layers.py ├── model.py └── train.py ├── .gitignore ├── requirements.txt ├── data ├── shakespeare_char │ ├── readme.md │ └── prepare.py └── openwebtext │ ├── readme.md │ └── prepare.py ├── scripts ├── setup.sh ├── test_data.py ├── test_ckpt.py ├── test_rotary.py ├── test_jax.py └── tpu_commands.sh ├── launch.py ├── README.md └── sample.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/**/*.txt 2 | data/**/*.pkl 3 | data/**/*.bin 4 | __pycache__ 5 | _launch.py 6 | slurm_*.sh 7 | *.out 8 | outputs 9 | *_env 10 | wandb/ 11 | plugins/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | equinox==0.11.2 2 | tiktoken==0.5.1 3 | optax==0.1.7 4 | tqdm==4.66.1 5 | transformers==4.35.2 6 | tiktoken==0.5.1 7 | orbax==0.1.9 8 | rich==13.7.0 9 | gcsfs 10 | wandb -------------------------------------------------------------------------------- /data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 
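for later use, a minimal sketch of reading the result back (the ids are GPT-2 BPE tokens stored as flat `uint16` streams, matching the trailing comment in `prepare.py`):

```python
import numpy as np
import tiktoken

# memory-map the ~17GB token stream instead of loading it into RAM
train_tokens = np.memmap('train.bin', dtype=np.uint16, mode='r')
enc = tiktoken.get_encoding("gpt2")
print(train_tokens.shape)                      # ~9,035,582,198 token ids, in document order
print(enc.decode(train_tokens[:20].tolist()))  # decode the first few tokens back to text
```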
11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /src/configs/shakespeare_char.py: -------------------------------------------------------------------------------- 1 | from src.train import ExperimentConfig 2 | from src.model import GPTConfig 3 | 4 | config = ExperimentConfig( 5 | rundir='', 6 | data_dir='data/shakespeare_char', 7 | learning_rate=1e-3, 8 | batch_size=64, 9 | warmup_steps=100, 10 | min_lr=1e-4, 11 | lr_decay_steps=5000, 12 | max_steps=5000, 13 | beta2=0.99, 14 | weight_decay=1e-4, 15 | eval_interval=2000, 16 | compute_dtype='bfloat16', 17 | param_dtype='float32', 18 | g_accum_iters=1, 19 | shard_model=False, 20 | model_config=GPTConfig( 21 | block_size=256, vocab_size=65, n_layer=6, n_head=6, n_embd=384, dropout=0.2), 22 | ) 23 | -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | source scripts/tpu_commands.sh 4 | 5 | # Remove any outdated info from known hosts. 6 | for ip in $(tpu midGPT ips $2); do ssh-keygen -R $ip; done 7 | 8 | tpu midGPT copy $2 9 | tpu midGPT ssh $2 "pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" 10 | tpu midGPT ssh $2 "cd midGPT; pip install -r requirements.txt" 11 | 12 | # Attach and mount PD that has dataset. 13 | gcloud alpha compute tpus tpu-vm attach-disk $2 \ 14 | --zone=$1 \ 15 | --disk=$3 \ 16 | --mode=read-only 17 | 18 | tpu midGPT ssh $2 "sudo mkdir -p /mnt/disks/persist" 19 | tpu midGPT ssh $2 "sudo mount -o discard,defaults /dev/sdb /mnt/disks/persist" 20 | -------------------------------------------------------------------------------- /src/configs/openwebtext.py: -------------------------------------------------------------------------------- 1 | from src.train import ExperimentConfig 2 | from src.model import GPTConfig 3 | 4 | config = ExperimentConfig( 5 | rundir='', 6 | data_dir='data/openwebtext', 7 | learning_rate=1e-3, 8 | batch_size=128, 9 | warmup_steps=5_000, 10 | min_lr=1e-5, 11 | lr_decay_steps=60_000, 12 | max_steps=60_000, 13 | beta2=0.95, 14 | weight_decay=1e-4, 15 | eval_interval=1000, 16 | compute_dtype='bfloat16', 17 | param_dtype='float32', 18 | g_accum_iters=16, # eff BS = 2048 19 | shard_model=False, 20 | model_config=GPTConfig( 21 | block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0) 22 | ) 23 | -------------------------------------------------------------------------------- /src/configs/openwebtext_mh.py: -------------------------------------------------------------------------------- 1 | from src.train import ExperimentConfig 2 | from src.model import GPTConfig 3 | 4 | config = ExperimentConfig( 5 | rundir='', 6 | data_dir='/mnt/disks/persist/openwebtext', 7 | learning_rate=1e-3, 8 | batch_size=2048, 9 | warmup_steps=5_000, 10 | min_lr=1e-5, 11 | lr_decay_steps=60_000, 12 | max_steps=60_000, 13 | beta2=0.95, 14 | weight_decay=1e-4, 15 | eval_interval=1000, 16 | compute_dtype='bfloat16', 17 | param_dtype='float32', 18 | g_accum_iters=1, 19 | shard_model=False, 20 | model_config=GPTConfig( 21 | block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0) 22 | 
) 23 | -------------------------------------------------------------------------------- /src/configs/openwebtext_xl.py: -------------------------------------------------------------------------------- 1 | from src.train import ExperimentConfig 2 | from src.model import GPTConfig 3 | 4 | config = ExperimentConfig( 5 | rundir='', 6 | data_dir='/mnt/disks/persist/openwebtext', 7 | learning_rate=1e-3, 8 | batch_size=1024, 9 | warmup_steps=2500, 10 | min_lr=1e-5, 11 | lr_decay_steps=25_000, 12 | max_steps=25_000, 13 | beta2=0.95, 14 | weight_decay=1e-4, 15 | eval_interval=1000, 16 | compute_dtype='bfloat16', 17 | param_dtype='float32', 18 | g_accum_iters=1, 19 | shard_model=True, 20 | model_config=GPTConfig( 21 | block_size=1024, vocab_size=50304, n_layer=24, n_head=16, n_embd=2048, dropout=0.0) 22 | ) 23 | -------------------------------------------------------------------------------- /scripts/test_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | script_dir = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.insert(0, os.path.dirname(script_dir)) 5 | 6 | import time 7 | import os 8 | import numpy as np 9 | import jax 10 | from src.train import get_batch 11 | 12 | start = time.time() 13 | train_data = np.memmap(os.path.join('/mnt/disks/persist/openwebtext', 'train.bin'), dtype=np.uint16, mode='r').copy() 14 | print(f"Worker {jax.process_index()}; Time to load train.bin: {time.time() - start}") 15 | x_GxBxD, y_GxBxD = get_batch( 16 | train_data, 1024, 128, 4 17 | ) 18 | 19 | # time how long it takes to get 100 batches 20 | start = time.time() 21 | for i in range(100): 22 | x_GxBxD, y_GxBxD = get_batch( 23 | train_data, 1024, 128, 4 24 | ) 25 | end = time.time() 26 | print(f"Worker {jax.process_index()}; Batches per second: {100 / (end - start)}") -------------------------------------------------------------------------------- /scripts/test_ckpt.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.experimental import mesh_utils 4 | import orbax.checkpoint as ocp 5 | Mesh, NamedSharding = jax.sharding.Mesh, jax.sharding.NamedSharding 6 | P, with_sharding_constraint = jax.sharding.PartitionSpec, jax.lax.with_sharding_constraint 7 | 8 | jax.distributed.initialize() 9 | 10 | options = ocp.CheckpointManagerOptions( 11 | max_to_keep=1, save_interval_steps=100) 12 | mngr = ocp.CheckpointManager( 13 | "gs://training_out/test", 14 | ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()), 15 | options=options) 16 | 17 | mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=('data',)) 18 | shardings = NamedSharding(mesh, P('data', None)) 19 | @jax.jit 20 | def init(): 21 | x = jnp.ones((128 * 64, 128 * 100)) 22 | return with_sharding_constraint(x, shardings) 23 | A = init() 24 | mngr.save(0, [A]) -------------------------------------------------------------------------------- /scripts/test_rotary.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | script_dir = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.insert(0, os.path.dirname(script_dir)) 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from src.layers import fixed_pos_embedding, apply_rotary_pos_emb 9 | 10 | 11 | def test_rotary(): 12 | key = jax.random.PRNGKey(0) 13 | key1, key2 = jax.random.split(key) 14 | T, C = 32, 64 15 | Q_HxTxC = jax.random.normal(key1, (8, T, C)) 16 | K_HxTxC = 
jax.random.normal(key2, (8, T, C)) 17 | # Shift K, Q along T dimension 18 | shift = 5 19 | Qshift_HxTxC = jnp.roll(Q_HxTxC, shift, axis=1) 20 | Kshift_HxTxC = jnp.roll(K_HxTxC, shift, axis=1) 21 | sin_TxCp, cos_TxCp = fixed_pos_embedding(C, T) 22 | Q_HxTxC = apply_rotary_pos_emb(Q_HxTxC, sin_TxCp, cos_TxCp) 23 | K_HxTxC = apply_rotary_pos_emb(K_HxTxC, sin_TxCp, cos_TxCp) 24 | A_HxTxT = Q_HxTxC @ jnp.transpose(K_HxTxC, (0, 2, 1)) 25 | 26 | Qshift_HxTxC = apply_rotary_pos_emb(Qshift_HxTxC, sin_TxCp, cos_TxCp) 27 | Kshift_HxTxC = apply_rotary_pos_emb(Kshift_HxTxC, sin_TxCp, cos_TxCp) 28 | Ashift_HxTxT = Qshift_HxTxC @ jnp.transpose(Kshift_HxTxC, (0, 2, 1)) 29 | 30 | A_HxTxT_shifted = jnp.roll(A_HxTxT, shift, axis=(-2, -1)) 31 | print(jnp.abs(Ashift_HxTxT[:, shift:, shift:] - A_HxTxT_shifted[:, shift:, shift:]).max()) 32 | return 33 | 34 | 35 | if __name__ == "__main__": 36 | test_rotary() 37 | -------------------------------------------------------------------------------- /src/sharding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | 4 | jtu = jax.tree_util 5 | NamedSharding = jax.sharding.NamedSharding 6 | P = jax.sharding.PartitionSpec 7 | 8 | 9 | def tree_broadcast(prefix, target): 10 | def _broadcast(leaf, subtree): 11 | return jtu.tree_map(lambda _: leaf, subtree) 12 | return jtu.tree_map(_broadcast, prefix, target) 13 | 14 | 15 | def reshard(tree, shardings): 16 | # From https://github.com/google-research/big_vision/blob/1b17abc6b754175dcd92e9db3e13c409e2ccb951/big_vision/utils.py#L1288 17 | def _make_global_arr(x, shard, shape): 18 | # Avoid unnecessary copies and transfers: 19 | if hasattr(x, "sharding") and x.sharding.is_equivalent_to(shard, len(shape)): 20 | return x 21 | if not getattr(x, "is_fully_addressable", True): 22 | raise RuntimeError("Trying to reshard a non-fully-addressable array. See link above.") 23 | x = jax.device_get(x) # Might be on local devices. 24 | xs = [jax.device_put(x[s], device=d) 25 | for d, s in shard.addressable_devices_indices_map(shape).items()] 26 | return jax.make_array_from_single_device_arrays(shape, shard, xs) 27 | 28 | shapes = jax.tree_map(np.shape, tree) 29 | shardings = tree_broadcast(shardings, tree) 30 | return jax.tree_map(_make_global_arr, tree, shardings, shapes) 31 | 32 | 33 | def get_shard_fn(mesh, sharding): 34 | """Shard fn for data parallelism.""" 35 | n_procs = jax.process_count() 36 | def shard(x): 37 | local_ds = mesh.local_devices 38 | xs = jax.device_put(np.split(x, len(local_ds), axis=1), local_ds) 39 | global_shape = (x.shape[0], x.shape[1] * n_procs, *x.shape[2:]) 40 | # each proc has its own sub-batch--"combine" them together into a jax array. 
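        # Shape sketch, following the (G, B, T) batches produced by get_batch:
        #   x is this process's slice with shape (G, B_local, T); axis 1 is the batch axis.
        #   np.split along axis 1 leaves one equal chunk of that local batch per local device,
        #   and the global batch axis is B_local * n_procs, so the per-device pieces are stitched
        #   into a single globally-sharded jax.Array without moving any data between hosts.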
41 | return jax.make_array_from_single_device_arrays(global_shape, sharding, xs) 42 | return shard 43 | -------------------------------------------------------------------------------- /scripts/test_jax.py: -------------------------------------------------------------------------------- 1 | # The following code snippet will be run on all TPU hosts 2 | import numpy as np 3 | import jax 4 | from jax.experimental import mesh_utils 5 | import numpy as np 6 | Mesh, NamedSharding = jax.sharding.Mesh, jax.sharding.NamedSharding 7 | P, with_sharding_constraint = jax.sharding.PartitionSpec, jax.lax.with_sharding_constraint 8 | 9 | 10 | def tree_broadcast(prefix, target): 11 | def _broadcast(leaf, subtree): 12 | return jax.tree_map(lambda _: leaf, subtree) 13 | return jax.tree_map(_broadcast, prefix, target) 14 | 15 | 16 | def reshard(tree, shardings): 17 | def _make_global_arr(x, shard, shape): 18 | # Avoid unnecessary copies and transfers: 19 | if hasattr(x, "sharding") and x.sharding.is_equivalent_to(shard, len(shape)): # pylint: disable=line-too-long 20 | return x 21 | if not getattr(x, "is_fully_addressable", True): 22 | raise RuntimeError("Trying to reshard a non-fully-addressable array. " 23 | "Please see the doc-comment for detailed explanation.") 24 | x = jax.device_get(x) # Might be on local devices. 25 | xs = [jax.device_put(x[s], device=d) 26 | for d, s in shard.addressable_devices_indices_map(shape).items()] 27 | return jax.make_array_from_single_device_arrays(shape, shard, xs) 28 | 29 | shapes = jax.tree_map(np.shape, tree) 30 | shardings = tree_broadcast(shardings, tree) 31 | return jax.tree_map(_make_global_arr, tree, shardings, shapes) 32 | 33 | 34 | # The total number of TPU cores in the Pod 35 | device_count = jax.device_count() 36 | 37 | # The number of TPU cores attached to this host 38 | local_device_count = jax.local_device_count() 39 | 40 | mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=('data',)) 41 | A = np.ones((128 * 64, 128 * 100)) 42 | B = np.ones((128 * 100, 256)) 43 | shardings = (NamedSharding(mesh, P('data', None)), NamedSharding(mesh, P('data', None))) 44 | A, B = reshard((A, B), shardings) 45 | 46 | @jax.jit 47 | def op(A, B): 48 | return A @ B 49 | 50 | result = op(A, B) 51 | 52 | # Print from a single host to avoid duplicated output 53 | if jax.process_index() == 0: 54 | print('global device count:', jax.device_count()) 55 | print('local device count:', jax.local_device_count()) 56 | jax.debug.visualize_array_sharding(A) 57 | jax.debug.visualize_array_sharding(B) 58 | jax.debug.visualize_array_sharding(result) -------------------------------------------------------------------------------- /data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 
6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from datetime import datetime 3 | import argparse 4 | import os 5 | 6 | import equinox as eqx 7 | import gcsfs 8 | import jax 9 | import json 10 | import wandb 11 | from jax.experimental.multihost_utils import sync_global_devices 12 | 13 | from src.train import train 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--config", type=str, required=True) 17 | parser.add_argument("--rundir", type=str) 18 | parser.add_argument("--debug", action="store_true") 19 | parser.add_argument("--multihost", action="store_true") 20 | cmd_args = parser.parse_args() 21 | 22 | if cmd_args.multihost: 23 | jax.distributed.initialize() 24 | # load config from src.configs 25 | config = getattr( 26 | __import__("src.configs", fromlist=[cmd_args.config]), cmd_args.config 27 | ).config 28 | if cmd_args.rundir is not None: 29 | config.rundir = cmd_args.rundir 30 | elif not cmd_args.debug: 31 | assert not cmd_args.multihost, "Multihost must prespecify rundir." 
32 | config.rundir = os.path.join( 33 | "outputs", datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 34 | ) 35 | if cmd_args.debug: 36 | config.debug = True 37 | 38 | if jax.process_index() == 0: # Wandb and config setup 39 | wandb_id = None 40 | config_dict = asdict(config) 41 | if not cmd_args.debug: 42 | print(f"Writing to {config.rundir}") 43 | if config.rundir.startswith("gs://"): 44 | print("Using GCS filesystem") 45 | fs = gcsfs.GCSFileSystem() 46 | fopen, exists = fs.open, fs.exists 47 | else: 48 | print("Using local filesystem") 49 | config.rundir = os.path.abspath(config.rundir) 50 | fs, fopen, exists = os, open, os.path.exists 51 | 52 | # make sure the directory exists 53 | fs.makedirs(config.rundir, exist_ok=True) 54 | 55 | # write config as json 56 | with fopen(os.path.join(config.rundir, "config.json"), "w") as f: 57 | f.write(json.dumps(config_dict)) 58 | 59 | # Load wandb id or write it, for proper wandb resuming. 60 | wandb_id_path = os.path.join(config.rundir, "wandb_id.txt") 61 | if exists(wandb_id_path): 62 | with fopen(wandb_id_path, "r") as f: 63 | wandb_id = f.read() 64 | else: 65 | wandb_id = wandb.util.generate_id() 66 | with fopen(wandb_id_path, "w") as f: 67 | f.write(wandb_id) 68 | wandb.init(project="midgpt", id=wandb_id, resume="allow", config=config_dict) 69 | if cmd_args.multihost: 70 | sync_global_devices("end_wandb_init") 71 | eqx.tree_pprint(config) 72 | train(config) 73 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed. 16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | if __name__ == '__main__': 20 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 21 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 22 | 23 | # owt by default only contains the 'train' split, so create a test split 24 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 25 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 26 | 27 | # this results in: 28 | # >>> split_dataset 29 | # DatasetDict({ 30 | # train: Dataset({ 31 | # features: ['text'], 32 | # num_rows: 8009762 33 | # }) 34 | # val: Dataset({ 35 | # features: ['text'], 36 | # num_rows: 4007 37 | # }) 38 | # }) 39 | 40 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 41 | enc = tiktoken.get_encoding("gpt2") 42 | def process(example): 43 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 44 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 45 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
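        # With EOT appended per document, the concatenated .bin files end up as
        # [doc0 tokens, EOT, doc1 tokens, EOT, ...]; the random block_size crops taken by
        # get_batch in src/train.py can therefore straddle document boundaries.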
46 | out = {'ids': ids, 'len': len(ids)} 47 | return out 48 | 49 | # tokenize the dataset 50 | tokenized = split_dataset.map( 51 | process, 52 | remove_columns=['text'], 53 | desc="tokenizing the splits", 54 | num_proc=num_proc, 55 | ) 56 | 57 | # concatenate all the ids in each dataset into one large file we can use for training 58 | for split, dset in tokenized.items(): 59 | arr_len = np.sum(dset['len'], dtype=np.uint64) 60 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 61 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 62 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 63 | total_batches = 1024 64 | 65 | idx = 0 66 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 67 | # Batch together samples for faster write 68 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 69 | arr_batch = np.concatenate(batch['ids']) 70 | # Write into mmap 71 | arr[idx : idx + len(arr_batch)] = arr_batch 72 | idx += len(arr_batch) 73 | arr.flush() 74 | 75 | # train.bin is ~17GB, val.bin ~8.5MB 76 | # train has ~9B tokens (9,035,582,198) 77 | # val has ~4M tokens (4,434,897) 78 | 79 | # to read the bin files later, e.g. with numpy: 80 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 81 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing as tp 3 | import numpy as np 4 | import equinox as eqx 5 | import jax 6 | 7 | jnp = jax.numpy 8 | KeyArray = tp.Any 9 | Array = jax.numpy.ndarray 10 | jrandom = jax.random 11 | 12 | 13 | class Embedding(eqx.Module): 14 | """For some reason, default Embedding impl is slow under vmap+JIT.""" 15 | V: int = eqx.field(static=True) # num embeddings 16 | D: int = eqx.field(static=True) # embedding size 17 | weight_VxD: Array 18 | 19 | def __init__( 20 | self, num_embeddings: int, embedding_size: int, weight: tp.Optional[Array]=None, 21 | *, key: tp.Optional[KeyArray]=None 22 | ): 23 | super().__init__() 24 | self.V, self.D = num_embeddings, embedding_size 25 | if weight is not None: 26 | self.weight_VxD = weight 27 | elif key is not None: 28 | self.weight_VxD = jrandom.normal(key, (self.V, self.D)) 29 | else: 30 | raise ValueError("need weight or key to be not None") 31 | 32 | @jax.named_scope("Embedding") 33 | def __call__(self, x_T, *, key=None): 34 | return jnp.take(self.weight_VxD, x_T, axis=0) 35 | 36 | 37 | class Linear(eqx.Module): 38 | """Linear with trunc normal init.""" 39 | weight_MxN: Array 40 | 41 | def __init__( 42 | self, in_features: int, out_features: int, weight: tp.Optional[Array]=None, 43 | *, key: tp.Optional[KeyArray]=None 44 | ): 45 | super().__init__() 46 | if weight is not None: 47 | self.weight_MxN = weight 48 | elif key is not None: 49 | self.weight_MxN = (1 / math.sqrt(in_features)) * jrandom.truncated_normal( 50 | key, lower=-2, upper=2, shape=(out_features, in_features)) 51 | else: 52 | raise ValueError("need weight or key to be not None") 53 | 54 | @jax.named_scope("Linear") 55 | def __call__(self, x_N: Array, *, key: KeyArray=None) -> Array: 56 | x_M = self.weight_MxN @ x_N 57 | return x_M 58 | 59 | 60 | class RMSNorm(eqx.Module): 61 | weight_M: tp.Optional[Array] 62 | eps: float 63 | 64 | def __init__(self, dim: int, use_weight=False, eps=1e-6): 65 | super().__init__() 66 | self.eps, self.weight_M = eps, None 67 | if use_weight: 68 | 
self.weight_M = jnp.ones((dim,)) 69 | 70 | @jax.named_scope("RMSNorm") 71 | def __call__(self, x_M: Array) -> Array: 72 | out_M = x_M * jax.lax.rsqrt(jnp.mean(jnp.square(x_M), keepdims=True) + self.eps) 73 | if self.weight_M is not None: 74 | out_M = out_M * self.weight_M 75 | return out_M 76 | 77 | 78 | ## RoPE functions 79 | def fixed_pos_embedding(C: int, T: int) -> tp.Tuple[np.ndarray, np.ndarray]: 80 | inv_freq_D = 1.0 / (10000 ** (np.arange(0, C, 2) / C)) # D = C // 2 81 | sinusoid_inp_TxD = np.einsum("i,j -> i j", np.arange(T), inv_freq_D) 82 | return np.sin(sinusoid_inp_TxD), np.cos(sinusoid_inp_TxD) 83 | 84 | 85 | def rotate_every_two(x: Array) -> Array: # [a b c d] -> [-b a -d c] 86 | x1 = x[..., ::2] 87 | x2 = x[..., 1::2] 88 | x = jnp.stack((-x2, x1), axis=-1) 89 | return jnp.reshape(x, x.shape[:-2] + (-1,)) 90 | 91 | 92 | def apply_rotary_pos_emb(x_HxTxC: Array, sin_TxD_np: np.ndarray, cos_TxD_np: np.ndarray) -> Array: 93 | sin_TxD = jnp.asarray(sin_TxD_np, dtype=x_HxTxC.dtype) 94 | cos_TxD = jnp.asarray(cos_TxD_np, dtype=x_HxTxC.dtype) 95 | sin_1xTxC = jnp.stack((sin_TxD, sin_TxD), axis=-1) 96 | sin_1xTxC = jnp.reshape(sin_1xTxC, sin_1xTxC.shape[:-2] + (-1,)) 97 | cos_1xTxC = jnp.stack((cos_TxD, cos_TxD), axis=-1) 98 | cos_1xTxC = jnp.reshape(cos_1xTxC, cos_1xTxC.shape[:-2] + (-1,)) 99 | return (x_HxTxC * cos_1xTxC) + (rotate_every_two(x_HxTxC) * sin_1xTxC) 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # midGPT 2 | A simple and hackable repository for experimenting on LLM pretraining, built using Jax+[Equinox](https://github.com/patrick-kidger/equinox). This codebase trains GPT-style decoder-only Transformers with billions of parameters on TPUs or GPUs. 3 | 4 | MidGPT is inspired by [NanoGPT](https://github.com/karpathy/nanoGPT/), but supports FSDP across multiple devices and hosts for training larger models. It also includes some recent Transformer improvements: rotary embeddings, RMSNorm, QK-Layernorm, and independent weight decay, which can improve or stabilize training at larger scales. 5 | 6 | Model code is in `src/model.py`, training code is in `src/train.py`. Experiments are configured in `src/configs/*.py`. Tested on Python **3.10.12**. 7 | 8 | This project is supported by the [TPU Research Cloud](https://sites.research.google/trc/about/). 9 | 10 | ## Data preparation 11 | 12 | As in nanoGPT, we support shakespeare_char (character-level prediction of Shakespeare texts) and openwebtext. The datasets are first processed into numpy memmapped `.bin` files: 13 | 14 | ```bash 15 | cd data/openwebtext # or data/shakespeare_char 16 | python prepare.py 17 | ``` 18 | 19 | ## Single host, multiple device setup 20 | From a fresh Python 3.10+ virtualenv, [install Jax](https://jax.readthedocs.io/en/latest/installation.html) for your accelerator type, then `pip install -r requirements.txt`. To profile performance, also `pip install tensorflow-cpu tensorboard-plugin-profile`. 21 | 22 | Start training: 23 | ```bash 24 | export WANDB_API_KEY= 25 | python launch.py --config=shakespeare_char 26 | python launch.py --config=openwebtext # 124M model 27 | ``` 28 | 29 | By default, this will create a timestamped rundir in `outputs/`. 
You can also manually specify `--rundir`, which is useful for resuming training: 30 | ```bash 31 | # Create new run at rundir, or resume training if it already exists: 32 | python launch.py --config=openwebtext --rundir= 33 | ``` 34 | 35 | Add a `--debug` if you want to (1) enable jax profiler and (2) skip checkpoint saving. 36 | 37 | ## Multihost setup 38 | Multihost training has only been tested on TPU slices (e.g., TPU v3-128), and we assume the dataset is openwebtext. Before starting, change the `tpu_project` and `tpu_zone` variables in `scripts/tpu_commands.sh` to your project ID and zone. Then, source the TPU commands: 39 | ```bash 40 | source scripts/tpu_commands.sh 41 | ``` 42 | 43 | 44 | The data should be in a folder `openwebtext/` on a Google Cloud persistent disk, which will then be mounted to each host. Modify `scripts/setup.sh` with the correct zone and disk name, then: 45 | ```bash 46 | ./scripts/setup.sh # after bringing up TPU slice 47 | ``` 48 | 49 | To start training a 1.5B model: 50 | ```bash 51 | tpu midGPT ssh 'tmux new -d -s launch "WANDB_API_KEY= python ~/midGPT/launch.py --config=openwebtext_xl --multihost --rundir=gs://your_bucket_name/run_name"' 52 | ``` 53 | 54 | ## Expected performance 55 | The config `openwebtext.py` trains a 124M model analogous to nanoGPT, and should achieve ~2.80 val loss after all 60,000 steps. The config `openwebtext_xl.py` trains a 1.5B model, and should achieve a val loss ~2.42 after all 25,000 steps. On a TPU v3-128, the 1.5B model should take ~16.5 hours to train (throughput: ~444K tokens per second, MFU=47.8%). 56 | 57 | ## Acknowledgements 58 | Compute was generously provided by the TPU Research Cloud (TRC). 59 | 60 | * Tasks and data loading copied from [nanoGPT](https://github.com/karpathy/nanoGPT/) 61 | * TPU shell commands adapted from [easyLM](https://github.com/young-geng/EasyLM) 62 | * Higher learning rates, independent weight decay, and QK-LayerNorm were adopted based on the results of [small-scale proxies](https://arxiv.org/abs/2309.14322) 63 | 64 | ## Citation 65 | If you would like to cite this work: 66 | 67 | ``` 68 | @article{zhou2023midgpt, 69 | author={Allan Zhou and Nicholas C. Landolfi and Yiding Jiang}, 70 | title={mid{GPT}: a simple and hackable repository for LLM pretraining}, 71 | year={2023}, 72 | url={https://github.com/AllanYangZhou/midGPT}, 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import argparse 5 | import os 6 | import json 7 | import pickle 8 | 9 | from jax.experimental import mesh_utils 10 | from src.model import GPT 11 | from src.train import ExperimentConfig 12 | import equinox as eqx 13 | import gcsfs 14 | import jax 15 | import jax.numpy as jnp 16 | import jax.random as jrandom 17 | import numpy as np 18 | import optax # type: ignore 19 | import orbax.checkpoint as ocp 20 | import tiktoken 21 | 22 | from src.train import cast_pytree 23 | 24 | jtu = jax.tree_util 25 | NamedSharding, Mesh = jax.sharding.NamedSharding, jax.sharding.Mesh 26 | P = jax.sharding.PartitionSpec 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | # outputs directory, e.g., outputs/2023-11-25-00-52-09 31 | parser.add_argument("--ckpt_dir", type=str, required=True) 32 | # start with "\n"... or "<|endoftext|>" or etc. 
Can also specify a file, use as: "FILE:prompt.txt" 33 | parser.add_argument("--start", type=str, default="\n") 34 | parser.add_argument("--num_samples", type=int, default=10) 35 | parser.add_argument("--max_new_tokens", type=int, default=500) 36 | parser.add_argument("--temperature", type=float, default=0.8) 37 | cmd_args = parser.parse_args() 38 | 39 | if cmd_args.ckpt_dir.startswith("gs://"): 40 | print("Using GCS filesystem for checkpoint") 41 | fs = gcsfs.GCSFileSystem() 42 | fopen = fs.open 43 | else: 44 | print("Using local filesystem for checkpoint") 45 | fs = os 46 | fopen = open 47 | 48 | 49 | def from_json(json_path, dataclass_type): 50 | def convert(dict_or_list, dataclass_type): 51 | if isinstance(dict_or_list, dict): 52 | field_types = { 53 | f.name: f.type for f in dataclass_type.__dataclass_fields__.values() 54 | } 55 | return dataclass_type( 56 | **{k: convert(v, field_types[k]) for k, v in dict_or_list.items()} 57 | ) 58 | elif isinstance(dict_or_list, list): 59 | return [convert(elem, dataclass_type.__args__[0]) for elem in dict_or_list] 60 | else: 61 | return dict_or_list 62 | 63 | with fopen(json_path, "r") as f: 64 | json_string = f.read() 65 | return convert(json.loads(json_string), dataclass_type) 66 | 67 | 68 | def generate( 69 | config, batched_model, idx, max_new_tokens, temperature=1.0, key=None 70 | ): 71 | block_size = config.model_config.block_size 72 | for _ in range(max_new_tokens): 73 | # take the final block_size tokens for conditioning, if the sequence is too long 74 | idx_cond = idx if idx.shape[1] <= block_size else idx[:, -block_size:] 75 | pluck_T = idx.shape[1] - 1 76 | if idx_cond.shape[1] < block_size: 77 | B, pad_T = idx_cond.shape[0], block_size - idx_cond.shape[1] 78 | padding = jnp.zeros((B, pad_T), dtype=idx_cond.dtype) 79 | idx_cond_new = jnp.concatenate([idx_cond, padding], axis=1) 80 | else: 81 | idx_cond_new = idx_cond 82 | # take the forward pass 83 | logits = batched_model(idx_cond_new) 84 | # pluck the logits at the final step and scale by desired temperature 85 | logits = logits[:, pluck_T, :] / temperature 86 | key, next_token_key = jrandom.split(key) 87 | # sample from the distribution 88 | idx_next = jax.random.categorical( 89 | next_token_key, 90 | logits, 91 | axis=1, 92 | ).reshape((idx.shape[0], 1)) 93 | # append sampled index to the running sequence and continue 94 | idx = jnp.concatenate([idx, idx_next], axis=1) 95 | return idx 96 | 97 | 98 | # load the model 99 | config_path: str = os.path.join(cmd_args.ckpt_dir, "config.json") 100 | config: ExperimentConfig = from_json(config_path, ExperimentConfig) 101 | eqx.tree_pprint(config) 102 | 103 | mngr = ocp.CheckpointManager( 104 | config.rundir, 105 | ocp.PyTreeCheckpointer(), 106 | ) 107 | # model_leaves, _opt_state = mngr.restore(mngr.latest_step()) 108 | # model = GPT(config.model_config, key=jrandom.PRNGKey(0)) 109 | # model: GPT = jtu.tree_unflatten(jtu.tree_structure(model), model_leaves) 110 | 111 | model = GPT(config.model_config, key=jrandom.PRNGKey(0)) 112 | 113 | # both of these are unused, but just for loading the checkpoint 114 | scheduler = optax.warmup_cosine_decay_schedule( 115 | 0, 116 | config.learning_rate, 117 | config.warmup_steps, 118 | config.lr_decay_steps, 119 | end_value=config.min_lr, 120 | ) 121 | optimizer = optax.chain( 122 | optax.clip_by_global_norm(1.0), 123 | optax.scale_by_adam(b2=config.beta2), 124 | optax.add_decayed_weights(config.weight_decay / config.learning_rate), 125 | optax.scale_by_schedule(scheduler), 126 | optax.scale(-1), 127 | ) 
128 | 129 | # 130 | opt_state = optimizer.init(eqx.filter(model, eqx.is_array)) 131 | ex_state = (jtu.tree_leaves(model), jtu.tree_leaves(opt_state)) 132 | ex_shardings = jtu.tree_map(lambda x: x.sharding if eqx.is_array(x) else None, ex_state) 133 | restore_args = ocp.checkpoint_utils.construct_restore_args(ex_state, ex_shardings) 134 | model_leaves, opt_state_leaves = mngr.restore( 135 | mngr.latest_step(), restore_kwargs={"restore_args": restore_args} 136 | ) 137 | model = jtu.tree_unflatten(jtu.tree_structure(model), model_leaves) 138 | 139 | # set up encoding/decoding 140 | # the next several lines are copied directly from nanoGPT 141 | # look for the meta pickle in case it is available in the dataset folder 142 | # only for shakespeare_char 143 | load_meta = False 144 | meta_path = os.path.join(config.data_dir, "meta.pkl") 145 | load_meta = os.path.exists(meta_path) 146 | if load_meta: 147 | print(f"Loading meta from LOCAL {meta_path}...") 148 | with open(meta_path, "rb") as f: 149 | meta = pickle.load(f) 150 | # TODO want to make this more general to arbitrary encoder/decoder schemes 151 | stoi, itos = meta["stoi"], meta["itos"] 152 | encode = lambda s: [stoi[c] for c in s] 153 | decode = lambda l: "".join([itos[i] for i in l]) 154 | else: 155 | # ok let's assume gpt-2 encodings by default 156 | print("No LOCAL meta.pkl found, assuming GPT-2 encodings...") 157 | enc = tiktoken.get_encoding("gpt2") 158 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 159 | decode = lambda l: enc.decode(l) 160 | 161 | model = cast_pytree(model, jnp.dtype(config.compute_dtype)) 162 | 163 | block_size = config.model_config.block_size 164 | batched_model = eqx.filter_jit(jax.vmap(eqx.Partial(model, inference=True))) 165 | 166 | # load the prompt 167 | start = cmd_args.start 168 | if start.startswith("FILE:"): 169 | with open(start[5:], "r", encoding="utf-8") as f: 170 | start = f.read() 171 | 172 | key = jrandom.PRNGKey(0) 173 | 174 | start_ids = encode(start if start != "" else "\n") 175 | x = np.array([start_ids for _ in range(cmd_args.num_samples)]) 176 | devices = jax.devices() 177 | mesh = Mesh(mesh_utils.create_device_mesh((len(devices),)), axis_names=("data",)) 178 | # TODO: currently replicating all data. Shard data properly. 179 | data_sharding = NamedSharding(mesh, P(None, None)) 180 | x = jax.device_put(x, data_sharding) 181 | jax.debug.visualize_array_sharding(x) 182 | jax.debug.visualize_array_sharding(model.lm_head.weight_MxN) 183 | 184 | print("generating samples...") 185 | key, sample_key = jrandom.split(key) 186 | y = generate( 187 | config, 188 | batched_model, 189 | x, 190 | cmd_args.max_new_tokens, 191 | temperature=cmd_args.temperature, 192 | key=sample_key, 193 | ) 194 | samples = [decode(y[i].tolist()) for i in range(cmd_args.num_samples)] 195 | for s in samples: 196 | print(s) 197 | print("---------------") 198 | -------------------------------------------------------------------------------- /scripts/tpu_commands.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | function _tpu_ips { 4 | tpu_zone=$1 5 | tpu_project=$2 6 | tpu_name=$3 7 | gcloud alpha compute tpus tpu-vm describe $tpu_name --zone $tpu_zone --project $tpu_project | grep 'externalIp' | awk '{print $2}' 8 | } 9 | 10 | function _tpu_create { 11 | tpu_zone=$1 12 | tpu_project=$2 13 | tpu_gen=$3 14 | tpu_cores=$4 15 | tpu_name=$5 16 | if [ "$tpu_gen" = "v3" ]; then 17 | software_version='tpu-vm-base' 18 | else 19 | software_version='tpu-vm-v4-base' 20 | fi 21 | 22 | if [[ $tpu_cores =~ ^[0-9]+$ ]]; then 23 | gcloud alpha compute tpus tpu-vm create \ 24 | $tpu_name \ 25 | --accelerator-type="$tpu_gen-$tpu_cores" \ 26 | --version $software_version \ 27 | --zone $tpu_zone \ 28 | --project $tpu_project 29 | else 30 | gcloud alpha compute tpus tpu-vm create \ 31 | $tpu_name \ 32 | --type="$tpu_gen" \ 33 | --topology="$tpu_cores" \ 34 | --version $software_version \ 35 | --zone $tpu_zone \ 36 | --project $tpu_project 37 | fi 38 | } 39 | 40 | function _tpu_retry_create { 41 | while true; do 42 | _tpu_create "$@" 43 | sleep 120s 44 | done 45 | } 46 | 47 | function _tpu_cp_ssh_key { 48 | tpu_zone=$1 49 | tpu_project=$2 50 | tpu_name=$3 51 | 52 | gcloud alpha compute tpus tpu-vm scp \ 53 | $HOME/.ssh/authorized_keys \ 54 | $tpu_name:/home/$USER/.ssh/ \ 55 | --worker=all \ 56 | --project $tpu_project \ 57 | --zone $tpu_zone 58 | } 59 | 60 | function _tpu_setup { 61 | tpu_zone=$1 62 | tpu_project=$2 63 | tpu_name=$3 64 | 65 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 66 | for host in $tpu_ips; do 67 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 68 | ssh $host '~/tpu_vm_setup.sh' & 69 | done 70 | wait &> /dev/null 71 | 72 | for host in $tpu_ips; do 73 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 74 | ssh $host '~/tpu_vm_setup.sh' & 75 | done 76 | wait &> /dev/null 77 | } 78 | 79 | function _tpu_check { 80 | tpu_zone=$1 81 | tpu_project=$2 82 | tpu_name=$3 83 | 84 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 85 | for host in $tpu_ips; do 86 | echo "============== Checking host: $host ==============" 87 | ssh $host 'tmux capture-pane -pt launch -S -2000' 88 | echo 89 | echo 90 | done 91 | } 92 | 93 | function _tpu_copy { 94 | tpu_zone=$1 95 | tpu_project=$2 96 | tpu_name=$3 97 | 98 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 99 | for host in $tpu_ips; do 100 | rsync -e "ssh -o StrictHostKeyChecking=no" -avPI --exclude=logs --exclude=__pycache__ --exclude=.git --exclude='*_env' $PROJECT_HOME/$PROJECT_NAME $host:~/ & 101 | done 102 | wait &> /dev/null 103 | sleep 1s 104 | 105 | for host in $tpu_ips; do 106 | rsync -e "ssh -o StrictHostKeyChecking=no" -avPI --exclude=logs --exclude=__pycache__ --exclude=.git --exclude='*_env' $PROJECT_HOME/$PROJECT_NAME $host:~/ & 107 | done 108 | wait &> /dev/null 109 | sleep 1s 110 | } 111 | 112 | function _tpu_stop { 113 | tpu_zone=$1 114 | tpu_project=$2 115 | tpu_name=$3 116 | 117 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 118 | for host in $tpu_ips; do 119 | ssh $host 'tmux kill-session -t launch ; pkill -9 python' & 120 | done 121 | wait &> /dev/null 122 | } 123 | 124 | function _tpu_launch { 125 | tpu_zone=$1 126 | tpu_project=$2 127 | tpu_name=$3 128 | command=$4 129 | 130 | if [ -z "$command" ]; then 131 | echo "Invalid syntax!" 
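        # expected usage via the dispatcher below: tpu midGPT launch <tpu_name> '<command>'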
132 | return 1 133 | fi 134 | 135 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 136 | for host in $tpu_ips; do 137 | ssh $host 'tmux new -d -s launch "$command"' & 138 | done 139 | wait &> /dev/null 140 | } 141 | 142 | function _tpu_maintain { 143 | tpu_zone=$1 144 | tpu_project=$2 145 | tpu_name=$3 146 | 147 | gcloud alpha compute tpus tpu-vm simulate-maintenance-event $tpu_name \ 148 | --project $tpu_project \ 149 | --zone=$tpu_zone \ 150 | --workers=all 151 | } 152 | 153 | function _tpu_ssh { 154 | tpu_zone=$1 155 | tpu_project=$2 156 | tpu_name=$3 157 | command="$4" 158 | 159 | if [ -z "$command" ]; then 160 | echo "Invalid syntax!" 161 | return 1 162 | fi 163 | 164 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 165 | for host in $tpu_ips; do 166 | ssh $host "$command" & 167 | done 168 | wait &> /dev/null 169 | } 170 | 171 | function _tpu_reboot { 172 | tpu_zone=$1 173 | tpu_project=$2 174 | tpu_name=$3 175 | 176 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 177 | for host in $tpu_ips; do 178 | ssh $host 'sudo reboot' & 179 | done 180 | wait &> /dev/null 181 | } 182 | 183 | 184 | function tpu { 185 | trap "trap - SIGINT SIGTERM; return 1;" SIGINT SIGTERM 186 | 187 | 188 | # =============== TPU Project Specific Definitions =============== 189 | export PROJECT_HOME='..' 190 | export PROJECT_NAME='midGPT' 191 | tpu_zone='europe-west4-a' 192 | if [ "$1" = "midGPT" ]; then 193 | tpu_project='midgpt-405721' 194 | tpu_zone='europe-west4-a' 195 | tpu_gen='v3' 196 | else 197 | echo "Invalid syntax!" 198 | trap - SIGINT SIGTERM 199 | return 1 200 | fi 201 | # =============== End of TPU Project Specific Definitions =============== 202 | 203 | 204 | if [ "$2" = "list" ]; then 205 | gcloud alpha compute tpus tpu-vm list --zone $tpu_zone --project $tpu_project 206 | elif [ "$2" = "describe" ]; then 207 | gcloud alpha compute tpus tpu-vm describe $3 --zone $tpu_zone --project $tpu_project 208 | elif [ "$2" = "ips" ]; then 209 | _tpu_ips $tpu_zone $tpu_project $3 210 | elif [ "$2" = "delete" ]; then 211 | gcloud alpha compute tpus tpu-vm delete $3 --zone $tpu_zone --project $tpu_project --quiet 212 | elif [ "$2" = "delete_queued" ]; then 213 | gcloud alpha compute tpus queued-resources delete $3 --project $tpu_project --zone $tpu_zone 214 | elif [ "$2" = "create" ]; then 215 | _tpu_create $tpu_zone $tpu_project $tpu_gen $3 $4 216 | elif [ "$2" = "cp_ssh_key" ]; then 217 | _tpu_cp_ssh_key $tpu_zone $tpu_project $3 218 | elif [ "$2" = "retry_create" ]; then 219 | _tpu_retry_create $tpu_zone $tpu_project $tpu_gen $3 $4 220 | elif [ "$2" = "cs" ]; then 221 | _tpu_create $tpu_zone $tpu_project $tpu_gen $3 $4 222 | sleep 90s 223 | _tpu_setup $tpu_zone $tpu_project $4 224 | elif [ "$2" = "check" ]; then 225 | _tpu_check $tpu_zone $tpu_project $3 226 | elif [ "$2" = "setup" ]; then 227 | _tpu_setup $tpu_zone $tpu_project $3 228 | elif [ "$2" = "copy" ]; then 229 | _tpu_copy $tpu_zone $tpu_project $3 230 | elif [ "$2" = "stop" ]; then 231 | _tpu_stop $tpu_zone $tpu_project $3 232 | elif [ "$2" = "launch" ]; then 233 | _tpu_launch $tpu_zone $tpu_project $3 $4 234 | elif [ "$2" = "cl" ]; then 235 | _tpu_copy $tpu_zone $tpu_project $3 236 | _tpu_launch $tpu_zone $tpu_project $3 $4 237 | elif [ "$2" = "maintain" ]; then 238 | _tpu_maintain $tpu_zone $tpu_project $3 239 | elif [ "$2" = "ssh" ]; then 240 | _tpu_ssh $tpu_zone $tpu_project $3 "$4" 241 | elif [ "$2" = "reboot" ]; then 242 | _tpu_reboot $tpu_zone $tpu_project $3 243 | elif [ "$2" = "df" ]; then 244 | _tpu_ssh $tpu_zone 
$tpu_project $3 'df -h | grep root' 245 | else 246 | echo "Invalid syntax!" 247 | trap - SIGINT SIGTERM 248 | return 1 249 | fi 250 | trap - SIGINT SIGTERM 251 | } 252 | 253 | 254 | export -f tpu _tpu_ips _tpu_create _tpu_setup _tpu_check _tpu_copy _tpu_stop _tpu_launch _tpu_maintain _tpu_ssh _tpu_reboot 255 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | import typing as tp 4 | import equinox as eqx 5 | import jax 6 | from .layers import Linear, Embedding, RMSNorm, fixed_pos_embedding, apply_rotary_pos_emb 7 | 8 | jnp, jrandom, vmap, jtu = jax.numpy, jax.random, jax.vmap, jax.tree_util 9 | Array = jax.Array 10 | KeyArray = tp.Any 11 | P = jax.sharding.PartitionSpec 12 | NamedSharding = jax.sharding.NamedSharding 13 | Mesh = jax.sharding.Mesh 14 | with_sharding_constraint = jax.lax.with_sharding_constraint 15 | 16 | 17 | class MLP(eqx.Module): 18 | c_fc: Linear 19 | c_proj: Linear 20 | dropout: eqx.nn.Dropout 21 | 22 | def __init__(self, n_embd, dropout, key): 23 | key1, key2 = jrandom.split(key) 24 | self.c_fc = Linear(n_embd, 4 * n_embd, key=key1) 25 | self.c_proj = Linear(4 * n_embd, n_embd, key=key2) 26 | self.dropout = eqx.nn.Dropout(dropout) 27 | 28 | @jax.named_scope('mlp') 29 | def __call__(self, x_D, inference=False, key=None): 30 | x_D = jax.nn.gelu(self.c_fc(x_D)) 31 | return self.dropout(self.c_proj(x_D), inference=inference, key=key) 32 | 33 | 34 | class CausalSelfAttention(eqx.Module): 35 | n_head: int 36 | n_embd: int 37 | c_attn: Linear 38 | c_proj: Linear 39 | attn_dropout: eqx.nn.Dropout 40 | resid_dropout: eqx.nn.Dropout 41 | q_ln: eqx.nn.LayerNorm 42 | k_ln: eqx.nn.LayerNorm 43 | 44 | def __init__(self, n_embd, n_head, dropout, key): 45 | key1, key2 = jrandom.split(key) 46 | assert n_embd % n_head == 0 47 | self.n_head, self.n_embd = n_head, n_embd 48 | self.c_attn = Linear(n_embd, 3 * n_embd, key=key1) 49 | self.c_proj = Linear(n_embd, n_embd, key=key2) 50 | self.attn_dropout = eqx.nn.Dropout(dropout) 51 | self.resid_dropout = eqx.nn.Dropout(dropout) 52 | self.q_ln = eqx.nn.LayerNorm(n_embd // n_head, eps=1e-6, use_weight=True, use_bias=False) 53 | self.k_ln = eqx.nn.LayerNorm(n_embd // n_head, eps=1e-6, use_weight=True, use_bias=False) 54 | 55 | @jax.named_scope('causal_sa') 56 | def __call__(self, x_TxD, inference=False, key=None): 57 | adrop_key, pdrop_key = jrandom.split(key) if key is not None else (None, None) 58 | T, D = x_TxD.shape 59 | Q_TxD, K_TxD, V_TxD = jnp.split(vmap(self.c_attn)(x_TxD), 3, axis=-1) 60 | C = self.n_embd // self.n_head 61 | Q_HxTxC = jnp.transpose(jnp.reshape(Q_TxD, (T, self.n_head, C)), (1, 0, 2)) 62 | K_HxTxC = jnp.transpose(jnp.reshape(K_TxD, (T, self.n_head, C)), (1, 0, 2)) 63 | # QK LayerNorm 64 | Q_HxTxC = vmap(vmap(self.q_ln))(Q_HxTxC) 65 | K_HxTxC = vmap(vmap(self.k_ln))(K_HxTxC) 66 | # Rotary embeddings 67 | sin_TxCp, cos_TxCp = fixed_pos_embedding(C, T) # Cp = C//2 68 | Q_HxTxC = apply_rotary_pos_emb(Q_HxTxC, sin_TxCp, cos_TxCp) 69 | K_HxTxC = apply_rotary_pos_emb(K_HxTxC, sin_TxCp, cos_TxCp) 70 | V_HxTxC = jnp.transpose(jnp.reshape(V_TxD, (T, self.n_head, C)), (1, 0, 2)) 71 | A_HxTxT = Q_HxTxC @ jnp.transpose(K_HxTxC, (0, 2, 1)) 72 | causal_mask = jnp.tril(jnp.ones((1, T, T))) == 0 73 | A_HxTxT = jnp.where(causal_mask, float('-inf'), A_HxTxT) 74 | # Softmax should be in full precision. 
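        # With compute_dtype=bfloat16 the logits A_HxTxT arrive in bf16, whose small mantissa
        # makes exp/normalization over a length-T row lossy. Upcast to float32, apply the
        # 1/sqrt(C) scaling and softmax there, then cast the attention weights back so the
        # A @ V matmul below stays in the cheaper compute dtype.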
75 | orig_dtype = A_HxTxT.dtype 76 | A_HxTxT = jax.nn.softmax(A_HxTxT.astype(jnp.float32) / jnp.sqrt(C), axis=-1) 77 | A_HxTxT = A_HxTxT.astype(orig_dtype) 78 | A_HxTxT = self.attn_dropout(A_HxTxT, inference=inference, key=adrop_key) 79 | out_TxD = jnp.reshape(jnp.transpose(A_HxTxT @ V_HxTxC, (1, 0, 2)), (T, D)) 80 | out_TxD = self.resid_dropout(vmap(self.c_proj)(out_TxD), inference=inference, key=pdrop_key) 81 | return out_TxD 82 | 83 | 84 | class Block(eqx.Module): 85 | attn: CausalSelfAttention 86 | mlp: MLP 87 | ln1: RMSNorm 88 | ln2: RMSNorm 89 | 90 | def __init__(self, n_embd, n_head, dropout, key): 91 | key1, key2 = jrandom.split(key) 92 | self.attn = CausalSelfAttention(n_embd=n_embd, n_head=n_head, dropout=dropout, key=key1) 93 | self.mlp = MLP(n_embd=n_embd, dropout=dropout, key=key2) 94 | self.ln1 = RMSNorm(n_embd) 95 | self.ln2 = RMSNorm(n_embd) 96 | 97 | @jax.named_scope('block') 98 | def __call__(self, x_TxD, inference=False, key=None): 99 | attn_key, mlp_key = (None, None) 100 | if key is not None: 101 | attn_key, mlp_key = jrandom.split(key) 102 | mlp_key = jrandom.split(mlp_key, x_TxD.shape[0]) 103 | x_TxD = x_TxD + self.attn(vmap(self.ln1)(x_TxD), inference=inference, key=attn_key) 104 | mlp = vmap(self.mlp, in_axes=(0, None, 0)) 105 | return x_TxD + mlp(vmap(self.ln2)(x_TxD), inference, mlp_key) 106 | 107 | 108 | @dataclass 109 | class GPTConfig: 110 | block_size: int # Max sequence length 111 | vocab_size: int # No. of tokens 112 | n_layer: int # No. of transformer blocks 113 | n_head: int # No. attention heads 114 | n_embd: int # Hidden dimension 115 | dropout: float 116 | 117 | 118 | class GPT(eqx.Module): 119 | wte: Embedding 120 | drop: eqx.nn.Dropout 121 | blocks: tp.List[Block] 122 | ln_f: RMSNorm 123 | lm_head: Linear 124 | n_layer: int 125 | 126 | def __init__(self, config, key): 127 | self.n_layer = config.n_layer 128 | block_key, head_key = jrandom.split(key) 129 | self.drop = eqx.nn.Dropout(config.dropout) 130 | def make_block(_key): 131 | return Block(config.n_embd, config.n_head, config.dropout, _key) 132 | self.blocks = eqx.filter_vmap(make_block)(jrandom.split(block_key, config.n_layer)) 133 | self.ln_f = RMSNorm(config.n_embd, eps=1e-5) 134 | embed_std = (1 / math.sqrt(config.n_embd)) 135 | wte_wt = embed_std * jrandom.normal(head_key, (config.vocab_size, config.n_embd)) 136 | self.wte = Embedding(config.vocab_size, config.n_embd, weight=wte_wt) 137 | # Share first and last layer parameters. 138 | self.lm_head = Linear(config.n_embd, config.vocab_size, weight=wte_wt) 139 | 140 | @jax.named_scope('gpt') 141 | def __call__(self, x_T, inference=False, key=None): 142 | # Either (inference=False and key) or (inference=True and key=None) 143 | drop_key, block_keys = None, None 144 | if key is not None: 145 | drop_key, block_keys = jrandom.split(key) 146 | block_keys = jrandom.split(block_keys, self.n_layer) 147 | x_TxD = self.drop(self.wte(x_T), inference=inference, key=drop_key) 148 | dynamic_blocks, static_blocks = eqx.partition(self.blocks, eqx.is_array) 149 | @jax.checkpoint 150 | def block_fn(_x_TxD: Array, block_and_key: tp.Tuple[GPT, tp.Optional[KeyArray]]): 151 | _dynamic_block, _key = block_and_key 152 | block = eqx.combine(_dynamic_block, static_blocks) 153 | return block(_x_TxD, inference=inference, key=_key), None 154 | # Set unroll=self.n_layer for better speed (but slower compile). 
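        # self.blocks was built with eqx.filter_vmap, so every block parameter carries a leading
        # (n_layer,) axis; scan slices off one layer per iteration and eqx.combine re-attaches the
        # static (non-array) parts. jax.checkpoint above rematerializes each block's activations
        # during the backward pass, trading recompute for a large activation-memory saving.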
155 | x_TxD, _ = jax.lax.scan(block_fn, x_TxD, (dynamic_blocks, block_keys), unroll=1) 156 | x_TxD = vmap(self.ln_f)(x_TxD) 157 | logits_TxV = vmap(self.lm_head)(x_TxD) 158 | return logits_TxV 159 | 160 | 161 | def count_params(model: GPT) -> int: 162 | dupe = jnp.size(model.lm_head.weight_MxN) # embedding and final layer are shared. 163 | tot = sum([jnp.size(x) for x in jtu.tree_leaves(model) if isinstance(x, jax.Array)]) 164 | return tot - dupe # non-embedding only. 165 | 166 | 167 | def shard_gpt( 168 | model: GPT, mesh: Mesh, shard_model: bool, sharding_fn=with_sharding_constraint 169 | ) -> eqx.Module: 170 | """Shard model parameters over devices (TPUs or GPUs).""" 171 | def sharding_map(x: Array) -> NamedSharding: 172 | axes: tuple[tp.Any, ...] = (None,) * x.ndim 173 | if x.size > 2**18 and shard_model: 174 | axes = (None,) * (x.ndim - 1) + ('data',) 175 | return NamedSharding(mesh, P(*axes)) 176 | dynamic_model, static_model = eqx.partition(model, eqx.is_array) 177 | dynamic_model = jtu.tree_map(lambda x: sharding_fn(x, sharding_map(x)), dynamic_model) 178 | return eqx.combine(dynamic_model, static_model) 179 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from functools import partial 3 | from dataclasses import dataclass 4 | import os 5 | import equinox as eqx 6 | import jax 7 | from jax.experimental import mesh_utils 8 | import optax 9 | import orbax.checkpoint as ocp 10 | import numpy as np 11 | import wandb 12 | from tqdm import trange 13 | from .model import GPT, GPTConfig, shard_gpt, count_params 14 | from .sharding import reshard, get_shard_fn 15 | 16 | jax.config.update("jax_threefry_partitionable", True) 17 | 18 | jnp, jrandom, vmap, scan, jtu = jax.numpy, jax.random, jax.vmap, jax.lax.scan, jax.tree_util 19 | Array = jax.Array 20 | KeyArray = tp.Any 21 | Mesh = jax.sharding.Mesh 22 | NamedSharding = jax.sharding.NamedSharding 23 | P, with_sharding_constraint = jax.sharding.PartitionSpec, jax.lax.with_sharding_constraint 24 | 25 | 26 | @dataclass 27 | class ExperimentConfig: 28 | rundir: str # Directory containing ckpts and logs. 29 | data_dir: str # Dataset directory 30 | learning_rate: float 31 | batch_size: int # GLOBAL across all devices (not per device) 32 | warmup_steps: int 33 | min_lr: float # Final LR after decay 34 | lr_decay_steps: int 35 | max_steps: int # No. 
of grad steps 36 | beta2: float 37 | weight_decay: float 38 | eval_interval: int 39 | param_dtype: str # bfloat16 or float32 40 | compute_dtype: str 41 | g_accum_iters: int # Accumulate this many grads before step 42 | shard_model: bool 43 | model_config: GPTConfig 44 | debug: bool = False 45 | 46 | 47 | def cast_pytree(pytree: tp.Any, dtype: jnp.dtype) -> tp.Any: 48 | """Cast a pytree of arrays to a given dtype, ignore non-arrays.""" 49 | def cast(x): 50 | if eqx.is_array(x): 51 | return x.astype(dtype) 52 | return x 53 | return jtu.tree_map(cast, pytree) 54 | 55 | 56 | def get_batch( 57 | data, block_size: int, batch_size: int, g_accum_iters: tp.Optional[int]=None 58 | ) -> tp.Tuple[np.ndarray, np.ndarray]: 59 | bs = batch_size * (g_accum_iters or 1) 60 | ix = np.random.randint(0, len(data) - block_size, size=(bs,)) 61 | x = np.take(data, np.arange(block_size) + ix[:, None], axis=0).astype(np.int32) 62 | y = np.take(data, np.arange(1, block_size + 1) + ix[:, None], axis=0).astype(np.int32) 63 | if g_accum_iters is not None: # reshape to (g_accum_steps, batch_size, block_size) 64 | x = x.reshape(g_accum_iters, batch_size, block_size) 65 | y = y.reshape(g_accum_iters, batch_size, block_size) 66 | return x, y 67 | 68 | 69 | def make_training_fns( 70 | config: ExperimentConfig, optimizer: optax.GradientTransformationExtraArgs, 71 | mesh: Mesh) -> tp.Tuple[tp.Callable, tp.Callable]: 72 | def loss_fn(model_params: GPT, model_static: GPT, x: Array, y: Array, key: tp.Optional[KeyArray]) -> Array: 73 | model = eqx.combine(model_params, model_static) 74 | if key is not None: 75 | key = jrandom.split(key, x.shape[0]) 76 | logits = vmap(model)(x, key=key).astype(jnp.float32) 77 | return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean() 78 | 79 | @partial(eqx.filter_jit, donate='all') 80 | def step(model: GPT, opt_state, x_GxBxT: Array, y_GxBxT: Array, key: KeyArray): 81 | G = config.g_accum_iters 82 | params, static = eqx.partition((model), eqx.is_array) 83 | params_cpt = cast_pytree(params, jnp.dtype(config.compute_dtype)) 84 | # compute loss and grad on microbatch, then scan over microbatches 85 | def microstep(grad_so_far, xykey_g: tp.Tuple[Array, Array, KeyArray]): 86 | loss, grad = jax.value_and_grad(loss_fn)(params_cpt, static, *xykey_g) 87 | grad = shard_gpt(grad, mesh, config.shard_model) 88 | grad_so_far = jtu.tree_map(lambda x, y: x + y, grad, grad_so_far) 89 | return grad_so_far, loss 90 | all_keys = jrandom.split(key, config.g_accum_iters) 91 | init_grad = jtu.tree_map(jnp.zeros_like, params) 92 | grad, loss_G = scan(microstep, init_grad, (x_GxBxT, y_GxBxT, all_keys)) 93 | # Grad accumulated (summed) over G, so divide. 
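        # loss_G stacks the G per-microbatch mean losses, so its mean is the loss over the full
        # batch_size * g_accum_iters examples; dividing the summed grads by G likewise makes this
        # update equivalent to one step on that effective batch (e.g. "eff BS = 2048" in
        # configs/openwebtext.py comes from 128 * 16).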
94 |         loss, grad = jnp.mean(loss_G, axis=0), jtu.tree_map(lambda x: x / G, grad)
95 |         updates, opt_state = optimizer.update(grad, opt_state, params)
96 |         model = eqx.combine(optax.apply_updates(params, updates), static)
97 |         return model, opt_state, loss
98 | 
99 |     @eqx.filter_jit
100 |     def simple_loss(model: tp.Union[GPT, eqx.Partial], x: Array, y: Array, key: tp.Optional[KeyArray]) -> Array:
101 |         """Same as loss_fn, but doesn't split params into compute/static."""
102 |         model_params, model_static = eqx.partition(model, eqx.is_array)
103 |         return loss_fn(model_params, model_static, x, y, key)
104 | 
105 |     data_sharding = NamedSharding(mesh, P(None, ('replica', 'data'), None))  # (G, B, D)
106 |     shard_fn = get_shard_fn(mesh, data_sharding)
107 |     def evaluate(model: GPT, data: np.ndarray) -> float:
108 |         eval_model = eqx.Partial(cast_pytree(model, jnp.dtype(config.compute_dtype)), inference=True)
109 |         tot_loss = 0
110 |         num_eval_steps = 1 if config.debug else 200
111 |         for _ in range(num_eval_steps):
112 |             x_1xBxD_np, y_1xBxD_np = get_batch(data, config.model_config.block_size, config.batch_size, 1)
113 |             x_1xBxD, y_1xBxD = jtu.tree_map(shard_fn, (x_1xBxD_np, y_1xBxD_np))
114 |             x_BxD, y_BxD = x_1xBxD.squeeze(0), y_1xBxD.squeeze(0)
115 |             loss = simple_loss(eval_model, x_BxD, y_BxD, None).item()
116 |             tot_loss = tot_loss + loss
117 |         return tot_loss / num_eval_steps
118 | 
119 |     return step, evaluate
120 | 
121 | 
122 | def split_array_by_idx(arr_N, proc_idx, n_proc):
123 |     n = arr_N.shape[0] // n_proc + 1  # tokens per process
124 |     return arr_N[proc_idx * n:(proc_idx + 1) * n]
125 | 
126 | 
127 | def train(config: ExperimentConfig):
128 |     n_proc, proc_idx = jax.process_count(), jax.process_index()
129 |     n_devices = jax.device_count()  # Assumes num_devices is a multiple of 8.
130 |     mesh = Mesh(mesh_utils.create_device_mesh((n_devices // 8, 8)), axis_names=('replica', 'data'))
131 | 
132 |     train_data = np.memmap(os.path.join(config.data_dir, 'train.bin'), dtype=np.uint16, mode='r').copy()
133 |     val_data = np.memmap(os.path.join(config.data_dir, 'val.bin'), dtype=np.uint16, mode='r').copy()
134 |     print("Raw shapes", train_data.shape, val_data.shape)
135 |     train_data = split_array_by_idx(train_data, proc_idx, n_proc)
136 |     val_data = split_array_by_idx(val_data, proc_idx, n_proc)
137 |     print(f"Process {proc_idx}/{n_proc}. train_data.shape={train_data.shape}, val_data.shape={val_data.shape}.")
138 | 
139 |     if not config.debug:
140 |         options = ocp.CheckpointManagerOptions(
141 |             max_to_keep=1, save_interval_steps=config.eval_interval)
142 |         mngr = ocp.CheckpointManager(
143 |             config.rundir,
144 |             ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()),
145 |             options=options)
146 | 
147 |     scheduler = optax.warmup_cosine_decay_schedule(
148 |         0, config.learning_rate, config.warmup_steps, config.lr_decay_steps,
149 |         end_value=config.min_lr)
150 |     @jax.jit
151 |     def get_lr(_opt_state):
152 |         return scheduler(_opt_state[3].count)  # [3] indexes the scale_by_schedule state, which tracks the step count.
153 |     optimizer = optax.chain(
154 |         optax.clip_by_global_norm(1.0),
155 |         optax.scale_by_adam(b2=config.beta2),
156 |         optax.add_decayed_weights(config.weight_decay / config.learning_rate),  # Rescaled by the schedule below, so peak effective decay is weight_decay.
157 |         optax.scale_by_schedule(scheduler),
158 |         optax.scale(-1),
159 |     )
160 |     step, evaluate = make_training_fns(config, optimizer, mesh)
161 | 
162 |     key = jrandom.PRNGKey(0)
163 |     def init_model(model_key):
164 |         model = GPT(config.model_config, model_key)
165 |         model = cast_pytree(model, config.param_dtype)
166 |         model = shard_gpt(model, mesh, config.shard_model)
167 |         return model
168 |     key, key1 = jrandom.split(key)
169 |     # Use jit with sharding constraints to init sharded model+opt.
170 |     model = eqx.filter_jit(init_model)(key1)
171 |     print(f'Model has {count_params(model)} parameters.')
172 |     def repl_opt_scalars(x: Array):
173 |         if x.ndim == 0:
174 |             x = reshard(x, NamedSharding(mesh, P()))
175 |         return x
176 |     opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
177 |     opt_state = jtu.tree_map(repl_opt_scalars, opt_state)
178 |     first_step = 0
179 |     if not config.debug and mngr.latest_step() is not None:  # Restore existing checkpoint.
180 |         ex_state = (jtu.tree_leaves(model), jtu.tree_leaves(opt_state))
181 |         ex_shardings = jtu.tree_map(lambda x: x.sharding if eqx.is_array(x) else None, ex_state)
182 |         restore_args = ocp.checkpoint_utils.construct_restore_args(ex_state, ex_shardings)
183 |         model_leaves, opt_state_leaves = mngr.restore(
184 |             mngr.latest_step(), restore_kwargs={'restore_args': restore_args})
185 |         model = jtu.tree_unflatten(jtu.tree_structure(model), model_leaves)
186 |         opt_state = jtu.tree_unflatten(jtu.tree_structure(opt_state), opt_state_leaves)
187 |         first_step = mngr.latest_step() + 1
188 |     data_sharding = NamedSharding(mesh, P(None, ('replica', 'data'), None))  # (G, B, D)
189 |     shard_fn = get_shard_fn(mesh, data_sharding)
190 |     postfix_values = {}  # values to display in the progress bar
191 |     pbar = trange(
192 |         first_step, config.max_steps, initial=first_step, total=config.max_steps,
193 |         disable=jax.process_index() != 0)
194 |     for itr in pbar:
195 |         if itr % config.eval_interval == 0:
196 |             train_loss = evaluate(model, train_data)
197 |             val_loss = evaluate(model, val_data)
198 |             postfix_values['train_loss'] = train_loss
199 |             postfix_values['val_loss'] = val_loss
200 |             if jax.process_index() == 0:
201 |                 wandb.log({'loss/train': train_loss, 'loss/val': val_loss}, step=itr)
202 |         key, key1 = jrandom.split(key)
203 |         x_GxBxD, y_GxBxD = get_batch(
204 |             train_data, config.model_config.block_size, config.batch_size, config.g_accum_iters)
205 |         if config.debug and itr == 0:
206 |             jax.profiler.start_trace(config.rundir)
207 |         x_GxBxD, y_GxBxD = jtu.tree_map(shard_fn, (x_GxBxD, y_GxBxD))
208 |         model, opt_state, loss = step(model, opt_state, x_GxBxD, y_GxBxD, key1)
209 |         if config.debug and itr == 0:
210 |             loss.block_until_ready()
211 |             jax.profiler.stop_trace()
212 |         if jax.process_index() == 0 and itr % 20 == 0:
213 |             wandb.log({'loss/optimized': loss.item()}, step=itr)
214 |         if not config.debug:
215 |             mngr.save(itr, (jtu.tree_leaves(model), jtu.tree_leaves(opt_state)))
216 |         postfix_values['loss'] = loss.item()
217 |         postfix_values['lr'] = get_lr(opt_state).item()
218 |         if pbar.format_dict['rate'] is not None:
219 |             postfix_values['thpt'] = pbar.format_dict['rate'] * config.batch_size * config.g_accum_iters  # sequences per second
220 |         pbar.set_postfix(**postfix_values)
221 |     pbar.close()
222 |     if jax.process_index() == 0:
223 |         wandb.finish()
224 |     if not config.debug:
225 |         mngr.wait_until_finished()
226 | 
--------------------------------------------------------------------------------
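
For orientation, here is a minimal launcher sketch showing how the pieces above compose into a run. It is not the repository's own launch.py (which is not included in this dump) and makes assumptions: the project name and output path are arbitrary placeholders, the dataset has already been built with data/shakespeare_char/prepare.py, and the host exposes a multiple of 8 devices (e.g. a TPU v3-8), since train() builds a ('replica', 'data') mesh with an 8-wide data axis.

# Hypothetical single-host launcher sketch; the repo's launch.py may differ.
import os
import dataclasses
import wandb
from src.configs.shakespeare_char import config  # module-level ExperimentConfig
from src.train import train

# ExperimentConfig is a plain dataclass, so dataclasses.replace fills in rundir,
# which the Orbax CheckpointManager uses for checkpoints.
config = dataclasses.replace(config, rundir=os.path.abspath('outputs/shakespeare_run0'))
wandb.init(project='midgpt')  # train() logs train/val losses via wandb.log on process 0.
train(config)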