├── loader.py ├── ray_tpu.py ├── readme.md ├── scripts └── init_ray.sh ├── setup.py ├── swarm_jax ├── __init__.py ├── embedding_layer.py ├── model.py ├── reversible_layer.py ├── swarm.py └── swarm_layer.py ├── swarm_run.py └── swarm_run_tpu.py /loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import mmap 5 | import numpy as np 6 | 7 | 8 | class TextLoader(): 9 | def __init__(self, fname, batchsize, sample_size, offset=0, length=0): 10 | self.f = open(fname, "r+b") 11 | self.mm = mmap.mmap(self.f.fileno(), length=length, offset=offset) 12 | self.file_size = os.stat(fname).st_size 13 | self.bs = np.product(batchsize) 14 | 15 | if isinstance(batchsize, tuple): 16 | self.batch_shape = batchsize 17 | else: 18 | self.batch_shape = (batchsize,) 19 | self.ss = sample_size 20 | 21 | self.np_mm = np.memmap(fname, dtype='uint8', mode='r', shape=(self.file_size,)) 22 | 23 | def get_samples(self): 24 | sample = np.random.randint(0, self.file_size - 2 - self.ss, self.bs) 25 | batch = np.zeros((self.bs, self.ss + 1)) 26 | 27 | for i in range(self.ss + 1): 28 | batch[:, i] = self.np_mm[sample + i] 29 | 30 | target = batch[:, 1:].astype(np.uint32) 31 | target = target.reshape(self.batch_shape + (self.ss,)) 32 | 33 | obs = batch[:, :-1].astype(np.uint32) 34 | obs = obs.reshape(self.batch_shape + (self.ss,)) 35 | 36 | return {"target": target, "obs": obs} 37 | 38 | 39 | if __name__ == "__main__": 40 | tl = TextLoader("data/enwik9", batchsize=(8, 128), sample_size=128) 41 | np.sum(tl.np_mm) 42 | print("preload done") 43 | 44 | for i in range(100): 45 | tl.get_samples() 46 | 47 | print("warmup done") 48 | 49 | start = time.time() 50 | 51 | it = 1000 52 | 53 | for i in range(it): 54 | tl.get_samples() 55 | 56 | t = time.time() - start 57 | print(f"samples done in {t} s") 58 | print(f"{tl.bs * it/t} eg/s") 59 | 60 | -------------------------------------------------------------------------------- /ray_tpu.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import time 5 | 6 | import glob 7 | import requests 8 | from fabric import Connection 9 | 10 | 11 | @functools.lru_cache() 12 | def get_bearer(): 13 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip() 14 | 15 | 16 | @functools.lru_cache() 17 | def get_project(): 18 | return subprocess.check_output("gcloud config list --format 'value(core.project)'", shell=True).decode( 19 | "utf-8").strip() 20 | 21 | 22 | def create_tpu( 23 | name, 24 | zone, 25 | type, 26 | preemptible, 27 | ): 28 | headers = { 29 | 'Authorization': f'Bearer {get_bearer()}', 30 | 'Content-Type': 'application/json', 31 | } 32 | 33 | params = ( 34 | ('node_id', name), 35 | ) 36 | 37 | data = {"accelerator_type": 38 | type, 39 | "runtime_version": 40 | 'v2-alpha', 41 | "network_config": 42 | {"enable_external_ips": True}, 43 | } 44 | 45 | if preemptible: 46 | data["schedulingConfig"] = {"preemptible": True} 47 | 48 | response = requests.post(f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes', 49 | headers=headers, params=params, json=data) 50 | 51 | print(response.json()) 52 | 53 | return response.status_code == 200 54 | 55 | 56 | def check_tpu(name, zone): 57 | headers = { 58 | 'Authorization': f'Bearer {get_bearer()}', 59 | } 60 | 61 | response = requests.get( 62 | 
f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 63 | headers=headers) 64 | 65 | return response.json() 66 | 67 | 68 | def delete_tpu(name, zone): 69 | headers = { 70 | 'Authorization': f'Bearer {get_bearer()}', 71 | } 72 | 73 | response = requests.delete( 74 | f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{zone}/nodes/{name}', 75 | headers=headers) 76 | 77 | return response.json() 78 | 79 | 80 | def wait_til(name, zone, state): 81 | while True: 82 | ret = check_tpu(name, zone) 83 | 84 | print(ret) 85 | 86 | matches = True 87 | for k, expected_v in state.items(): 88 | if k not in ret: 89 | matches = False 90 | continue 91 | if ret[k] != expected_v: 92 | matches = False 93 | 94 | if "error" in ret: 95 | return False 96 | 97 | if ret["state"] == "TERMINATED": 98 | return False 99 | 100 | if matches: 101 | return True 102 | 103 | time.sleep(1) 104 | 105 | 106 | def get_connection( 107 | name, 108 | zone, 109 | ): 110 | info = check_tpu(name, zone) 111 | outputs = [] 112 | for i in info["networkEndpoints"]: 113 | outputs.append(Connection(i["ipAddress"], 114 | connect_kwargs={ 115 | "key_filename": os.path.expanduser('~/.ssh/google_compute_engine'), })) 116 | return outputs 117 | 118 | 119 | def start_ray(conn, address): 120 | conn.sudo('rm -rf *.py') 121 | conn.sudo('rm -rf swarm_jax') 122 | 123 | for i in glob.glob("*.py"): 124 | print(i) 125 | conn.put(i, "") 126 | 127 | conn.run("mkdir swarm_jax -p") 128 | 129 | for i in glob.glob("swarm_jax/*.py"): 130 | print(i) 131 | conn.put(i, "swarm_jax/") 132 | 133 | conn.sudo('python3 setup.py install') 134 | 135 | conn.put("scripts/init_ray.sh", "/tmp/ray-tpu.sh") 136 | print(conn.sudo('chmod +x /tmp/ray-tpu.sh')) 137 | print(conn.sudo('/tmp/ray-tpu.sh')) 138 | try: 139 | print(conn.run('ray stop -f')) 140 | except Exception: 141 | pass 142 | print(conn.run(f"ray start --address={address} --load-code-from-local --resources='" + '{"tpu": 1}\'')) 143 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Pipelined Swarm Training 2 | 3 | Swarm training "framework" using Haiku + Jax + Ray. 4 | 5 | Designed (eventually) for training large language models in a model-parallel fashion across unreliable, heterogeneous nodes. 6 | 7 | Look in `swarm_run.py` for an example of running a character transformer on enwik8; a minimal sketch of the same setup appears after the TODO list below. 8 | 9 | # TODOs 10 | 11 | - [x] Forward passes 12 | - [x] Backward passes with activation reconstruction 13 | - [x] Run optimizer 14 | - [x] Logging 15 | - [x] Checkpointing 16 | - [x] Actually do pipelining 17 | - [x] fp16 with static loss scaling 18 | - [x] Integer quantization for activations and gradients between layers 19 | - [ ] Get rid of pipeline stalls from running optimizer 20 | - [ ] Data parallelism with multiple nodes per layer and gradient/weight aggregation 21 | - [ ] Heterogeneous nodes with potentially multiple layers per node 22 | - [ ] Handle unbalanced and unreliable nodes (layerdrop) 23 | - [ ] Dynamic node addition 24 | - [ ] 1T or bust?
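
# Minimal usage (sketch)

The snippet below is a trimmed-down version of `swarm_run.py`: it builds the character-level dataloader, the optax optimizer, and the activation/gradient precision config, then launches a `Swarm`. The run and checkpoint directory names are placeholders, and `ray.init` here just fakes plentiful `"tpu"` resources for a single-host run; a real multi-host swarm is brought up via `swarm_run_tpu.py` and `ray_tpu.py` instead.

```python
import optax
import ray

from loader import TextLoader
from swarm_jax.model import SwarmCharTransformer
from swarm_jax.swarm import Swarm
from swarm_jax.swarm_layer import NetworkPrecision

ray.init(resources={"tpu": 999})  # pretend the local node has plenty of "tpu" slots

# character-level enwik8: batches of shape (1, 16) with 128-byte samples
train_dataset = TextLoader("data/enwik8", batchsize=(1, 16), sample_size=128)

optimizer = optax.chain(
    optax.clip_by_global_norm(0.25),
    optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5))

# activations and gradients are quantized to uint16 when sent between layer actors
precision = NetworkPrecision(fwd_act="uint16", rev_act="uint16", grad="uint16")

# 2 ** 16 is the static loss scale; Swarm divides it back out before the optimizer step
swarm = Swarm(SwarmCharTransformer, optimizer, 2 ** 16, train_dataset.get_samples, precision)
swarm.run(epochs=100000, log_path="runs/example", ckpt_path="ckpt/example")

ray.shutdown()
```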
-------------------------------------------------------------------------------- /scripts/init_ray.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # initializes jax and installs ray on cloud TPUs 3 | 4 | # create tempfs for ray shared memory 5 | sudo mkdir /dev/shm 6 | sudo mount -t tmpfs -o size=100g tmpfs /dev/shm 7 | 8 | sudo pip install --upgrade jaxlib==0.1.59 9 | sudo pip install --upgrade jax ray fabric dataclasses optax git+https://github.com/deepmind/dm-haiku -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='swarm_jax', 5 | version='0.0.0', 6 | packages=['swarm_jax'] 7 | ) 8 | -------------------------------------------------------------------------------- /swarm_jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kingoflolz/swarm-jax/62cd943ba38c3aa8262b23e45f80870c7e7434f6/swarm_jax/__init__.py -------------------------------------------------------------------------------- /swarm_jax/embedding_layer.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | import random 4 | import time 5 | from queue import Queue 6 | 7 | import haiku as hk 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import optax 12 | import ray 13 | from typing import Optional 14 | 15 | from .swarm_layer import save_checkpoint, load_checkpoint, opt_state, run_threads, run_function, NetworkPrecision, \ 16 | quantize, dequantize, init_fn 17 | 18 | 19 | def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray: 20 | """Apply a unique LayerNorm to x with default settings.""" 21 | return hk.LayerNorm(axis=-1, 22 | create_scale=True, 23 | create_offset=True, 24 | name=name)(x) 25 | 26 | 27 | @ray.remote(resources={"tpu": 1}) 28 | class EmbeddingLayer(object): 29 | def __init__(self, obs, vocab: int, d_model: int, optimizer: optax.GradientTransformation, 30 | precision: NetworkPrecision): 31 | self.vocab = vocab 32 | self.d_model = d_model 33 | self.optimizer = optimizer 34 | self.precision = precision 35 | 36 | print("start init") 37 | self.devices = jax.local_device_count() 38 | print("done jax init") 39 | 40 | def embed_forward(x): 41 | embed_init = hk.initializers.TruncatedNormal(stddev=0.02) 42 | 43 | seq_length = x.shape[1] 44 | positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model], init=embed_init) 45 | 46 | o = hk.Embed(vocab, d_model, w_init=embed_init, name="embedding")(x) + positional_embeddings 47 | 48 | return o 49 | 50 | self.embed_fwd_fn = hk.transform(embed_forward) 51 | master_rng = jax.random.PRNGKey(random.getrandbits(32)) 52 | 53 | @functools.partial(jax.pmap) 54 | def embed_fwd_fn(obs, params): 55 | out = self.embed_fwd_fn.apply(params, None, obs) 56 | 57 | return out 58 | 59 | @functools.partial(jax.pmap, donate_argnums=(1, 2)) 60 | def embed_grad_fn(obs, y_dy, acc, params): 61 | y, dy = y_dy 62 | 63 | y_new, vjpfun = jax.vjp(self.embed_fwd_fn.apply, params, None, obs) 64 | weights_grad, _, _ = vjpfun(dy) 65 | diff = jnp.square(y - y_new).mean() 66 | cos_err = jnp.abs(1.0 - jnp.dot(y_new.flatten(), y.flatten()) / ( 67 | jnp.linalg.norm(y.flatten()) * jnp.linalg.norm(y_new.flatten()))) 68 | 69 | new_acc = jax.tree_multimap(operator.add, acc, 
weights_grad) 70 | return diff, cos_err, new_acc 71 | 72 | # we call all the functions here to trigger jit at init 73 | self.state = init_fn(master_rng, obs, self.embed_fwd_fn.init, optimizer) 74 | 75 | num_params = hk.data_structures.tree_size(self.state["params"]) 76 | print(f'Param count = {num_params}') 77 | 78 | self.embed_fwd = embed_fwd_fn 79 | e = self.embed_fwd(obs, self.state["params"]) 80 | 81 | self.embed_grad = embed_grad_fn 82 | _, _, new_acc = self.embed_grad(obs, (e, e), self.state["grad_acc"], self.state["params"]) 83 | self.state["grad_acc"] = new_acc 84 | 85 | self.state = opt_state(self.state, self.optimizer) 86 | self.state = init_fn(master_rng, obs, self.embed_fwd_fn.init, optimizer) 87 | 88 | self.init = False 89 | 90 | def run(self): 91 | def forward(obs, state): 92 | return quantize(self.embed_fwd(obs, state["params"]), self.precision.fwd_act) 93 | 94 | def backward(y_dy, obs, state): 95 | y, dy = y_dy 96 | y_dy = (dequantize(y, "float32"), dequantize(dy, "float32")) 97 | diff, cos_err, new_grad_acc = self.embed_grad(obs, y_dy, state["grad_acc"], state["params"]) 98 | state["grad_acc"] = new_grad_acc 99 | state["grad_count"] = state["grad_count"] + 1 100 | 101 | self.state = state 102 | 103 | return diff, cos_err 104 | 105 | self.fwd_q = Queue(2) 106 | self.bwd_q = Queue(2) 107 | self.init = True 108 | 109 | run_threads(self.state, self.fwd_q, self.bwd_q, 2, forward, backward) 110 | 111 | @ray.method(num_returns=2) 112 | def embed_forward(self, obs): 113 | while not self.init: 114 | time.sleep(0.1) 115 | return run_function(self.fwd_q, obs), None 116 | 117 | def embed_grad(self, obs, y_dy): 118 | while not self.init: 119 | time.sleep(0.1) 120 | return run_function(self.bwd_q, y_dy, obs) 121 | 122 | def opt(self): 123 | self.state = opt_state(self.state, self.optimizer) 124 | 125 | def get_params(self): 126 | return self.state["params"] 127 | 128 | def get_accum(self): 129 | return self.state["grad_acc"] 130 | 131 | def save(self, path, epoch): 132 | save_checkpoint(self.state, path, epoch) 133 | 134 | def load(self, path): 135 | ckpt = load_checkpoint(path) 136 | 137 | if ckpt: 138 | self.state = ckpt 139 | return True 140 | 141 | 142 | @ray.remote(resources={"tpu": 1}) 143 | class ProjLayer(object): 144 | def __init__(self, data, vocab: int, d_model: int, optimizer: optax.GradientTransformation, loss_scale: float, 145 | precision: NetworkPrecision): 146 | self.vocab = vocab 147 | self.d_model = d_model 148 | self.optimizer = optimizer 149 | self.loss_scale = loss_scale 150 | self.precision = precision 151 | 152 | data = dequantize(data, "float32") 153 | 154 | def debed_forward(x): 155 | x = layer_norm(x) 156 | 157 | return hk.Linear(vocab)(x) 158 | 159 | def debed_loss(x, target): 160 | logits = debed_forward(x) 161 | target_onehot = jax.nn.one_hot(target, vocab) 162 | 163 | assert logits.shape == target_onehot.shape 164 | 165 | loss = -jnp.sum(target_onehot * jax.nn.log_softmax(logits), axis=-1) 166 | loss = jnp.mean(loss) * self.loss_scale 167 | 168 | return loss 169 | 170 | self.proj_fwd_fn = hk.transform(debed_forward) 171 | self.proj_loss_fn = hk.transform(debed_loss) 172 | 173 | master_rng = jax.random.PRNGKey(random.getrandbits(32)) 174 | 175 | @functools.partial(jax.pmap) 176 | def debed_fwd_fn(target, params): 177 | out = self.proj_fwd_fn.apply(params, None, target) 178 | 179 | return out 180 | 181 | @functools.partial(jax.pmap, donate_argnums=(0, 2)) 182 | def debed_grad_fn(hidden, target, acc, params): 183 | loss, vjpfun = 
jax.vjp(self.proj_loss_fn.apply, params, None, hidden, target) 184 | weights_grad, _, x_grad, _ = vjpfun(np.ones((), dtype=hidden.dtype)) 185 | 186 | new_acc = jax.tree_multimap(operator.add, acc, weights_grad) 187 | return hidden, x_grad, loss, new_acc 188 | 189 | # we call all the functions here to trigger jit at init 190 | self.state = init_fn(master_rng, data, self.proj_fwd_fn.init, optimizer) 191 | 192 | num_params = hk.data_structures.tree_size(self.state["params"]) 193 | print(f'Param count = {num_params}') 194 | 195 | self.debed_fwd = debed_fwd_fn 196 | self.debed_fwd(jnp.zeros_like(data), self.state["params"]) 197 | 198 | self.debed_grad = debed_grad_fn 199 | _, _, _, new_acc = self.debed_grad(jnp.zeros_like(data), np.ones_like(data).mean(axis=-1), 200 | self.state["grad_acc"], 201 | self.state["params"]) 202 | self.state["grad_acc"] = new_acc 203 | 204 | self.state = opt_state(self.state, self.optimizer) 205 | self.state = init_fn(master_rng, jnp.zeros_like(data), self.proj_fwd_fn.init, optimizer) 206 | 207 | self.init = False 208 | 209 | def run(self): 210 | def forward(h, state): 211 | return self.debed_fwd(dequantize(h, "float32"), state["params"]) 212 | 213 | def backward(h, targets, state): 214 | hidden, x_grad, loss, new_acc = self.debed_grad(dequantize(h, "float32"), targets, state["grad_acc"], 215 | state["params"]) 216 | state["grad_acc"] = new_acc 217 | state["grad_count"] = state["grad_count"] + 1 218 | 219 | self.state = state 220 | 221 | return (quantize(hidden, self.precision.rev_act), quantize(x_grad, self.precision.grad)), loss 222 | 223 | self.fwd_q = Queue(2) 224 | self.bwd_q = Queue(2) 225 | self.init = True 226 | 227 | run_threads(self.state, self.fwd_q, self.bwd_q, 2, forward, backward) 228 | 229 | @ray.method(num_returns=2) 230 | def debed_forward(self, h): 231 | while not self.init: 232 | time.sleep(0.1) 233 | return run_function(self.fwd_q, h), 0 234 | 235 | @ray.method(num_returns=2) 236 | def debed_grad(self, h, targets): 237 | while not self.init: 238 | time.sleep(0.1) 239 | return run_function(self.bwd_q, h, targets) 240 | 241 | def opt(self): 242 | self.state = opt_state(self.state, self.optimizer) 243 | 244 | def get_params(self): 245 | return self.state["params"] 246 | 247 | def get_accum(self): 248 | return self.state["grad_acc"] 249 | 250 | def save(self, path, epoch): 251 | save_checkpoint(self.state, path, epoch) 252 | 253 | def load(self, path): 254 | ckpt = load_checkpoint(path) 255 | 256 | if ckpt: 257 | self.state = ckpt 258 | return True 259 | -------------------------------------------------------------------------------- /swarm_jax/model.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from typing import Optional, Callable 6 | 7 | 8 | class MultiHeadAttentionFixed(hk.Module): 9 | """Multi-headed attention mechanism. 
10 | 11 | With fixed attention scaling 12 | """ 13 | 14 | def __init__( 15 | self, 16 | num_heads: int, 17 | key_size: int, 18 | w_init_scale: float, 19 | query_size: Optional[int] = None, 20 | value_size: Optional[int] = None, 21 | model_size: Optional[int] = None, 22 | name: Optional[str] = None, 23 | ): 24 | super().__init__(name=name) 25 | self.num_heads = num_heads 26 | self.key_size = key_size 27 | self.query_size = query_size or key_size 28 | self.value_size = value_size or key_size 29 | self.model_size = model_size or key_size * num_heads 30 | self.w_init = hk.initializers.VarianceScaling(w_init_scale) 31 | 32 | def __call__( 33 | self, 34 | query: jnp.ndarray, 35 | mask: Optional[jnp.ndarray] = None, 36 | ) -> jnp.ndarray: 37 | """Compute (optionally masked) MHA with queries, keys & values.""" 38 | query_heads = self._linear_projection(query, self.query_size, "query") 39 | key_heads = self._linear_projection(query, self.key_size, "key") 40 | value_heads = self._linear_projection(query, self.value_size, "value") 41 | 42 | sqrt_key_size = np.sqrt(self.key_size) 43 | query_heads = query_heads / sqrt_key_size 44 | 45 | attention_logits = jnp.einsum("bthd,bThd->bhtT", query_heads, key_heads) 46 | 47 | seq_len = query.shape[1] 48 | causal_mask = np.tril(np.ones((seq_len, seq_len))) 49 | mask = mask * causal_mask if mask is not None else causal_mask 50 | 51 | attention_logits -= 1e10 * (1. - mask) 52 | 53 | attention_weights = jax.nn.softmax(attention_logits) 54 | attention = jnp.einsum("bhtT,bThd->bthd", attention_weights, value_heads) 55 | # Concatenate attention matrix of all heads into a single vector. 56 | attention_vec = jnp.reshape(attention, (*query.shape[:2], -1)) 57 | 58 | return hk.Linear(self.model_size, w_init=self.w_init)(attention_vec) 59 | 60 | @hk.transparent 61 | def _linear_projection( 62 | self, 63 | x: jnp.ndarray, 64 | head_size: int, 65 | name: Optional[str] = None 66 | ) -> jnp.ndarray: 67 | y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x) 68 | return y.reshape((*x.shape[:2], self.num_heads, head_size)) 69 | 70 | 71 | class DenseBlock(hk.Module): 72 | """A 2-layer MLP which widens then narrows the input.""" 73 | 74 | def __init__(self, 75 | init_scale: float, 76 | widening_factor: int = 4, 77 | name: Optional[str] = None): 78 | super().__init__(name=name) 79 | self._init_scale = init_scale 80 | self._widening_factor = widening_factor 81 | 82 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 83 | hiddens = x.shape[-1] 84 | initializer = hk.initializers.VarianceScaling(self._init_scale) 85 | x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x) 86 | x = jax.nn.gelu(x) 87 | return hk.Linear(hiddens, w_init=initializer)(x) 88 | 89 | 90 | class SwarmModel: 91 | def __init__(self, vocab: int, d_model: int, 92 | rev_init: Callable, rev_layers: int): 93 | self.vocab = vocab 94 | self.d_model = d_model 95 | self.rev_init = rev_init 96 | self.rev_layers = rev_layers 97 | 98 | 99 | n_layer = 6 100 | 101 | def char_layer_init(i): 102 | if i % 2: 103 | f = MultiHeadAttentionFixed( 104 | num_heads=8, 105 | key_size=128, 106 | w_init_scale=2. / n_layer, 107 | name=f'l{i}_f_attn', 108 | ) 109 | g = DenseBlock( 110 | init_scale=2. / n_layer, 111 | name=f'l{i}_g_dense', 112 | widening_factor=4 113 | ) 114 | else: 115 | f = DenseBlock( 116 | init_scale=2. / n_layer, 117 | name=f'l{i}_f_dense', 118 | widening_factor=4 119 | ) 120 | g = MultiHeadAttentionFixed( 121 | num_heads=8, 122 | key_size=128, 123 | w_init_scale=2. 
/ n_layer, 124 | name=f'l{i}_g_attn', 125 | ) 126 | return f, g 127 | 128 | 129 | SwarmCharTransformer = SwarmModel( 130 | vocab=256, 131 | d_model=512, 132 | rev_init=char_layer_init, 133 | rev_layers=n_layer 134 | ) 135 | 136 | SwarmCharTransformerBig = SwarmModel( 137 | vocab=256, 138 | d_model=2048, 139 | rev_init=char_layer_init, 140 | rev_layers=n_layer 141 | ) 142 | -------------------------------------------------------------------------------- /swarm_jax/reversible_layer.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | import random 4 | import time 5 | from queue import Queue 6 | 7 | import haiku as hk 8 | import jax 9 | import jax.numpy as jnp 10 | import optax 11 | import ray 12 | from typing import Callable 13 | 14 | from .swarm_layer import save_checkpoint, load_checkpoint, opt_state, run_threads, run_function, NetworkPrecision, \ 15 | quantize, dequantize, init_fn 16 | 17 | 18 | @ray.remote(resources={"tpu": 1}) 19 | class ReversibleLayer(object): 20 | def __init__( 21 | self, 22 | layer_init: Callable, 23 | layer: int, 24 | data: jnp.ndarray, 25 | optimizer: optax.GradientTransformation, 26 | precision: NetworkPrecision 27 | ): 28 | self.layer = layer 29 | self.optimizer = optimizer 30 | self.precision = precision 31 | 32 | data = dequantize(data, "float32") 33 | 34 | def forward(x): 35 | f, g = layer_init(layer) 36 | 37 | hidden = x.shape[-1] 38 | x1 = x[:, :, :hidden // 2] 39 | x2 = x[:, :, hidden // 2:] 40 | 41 | y1 = f(x2) + x1 42 | y2 = g(y1) + x2 43 | 44 | assert x1.shape == y1.shape 45 | assert x2.shape == y2.shape 46 | 47 | return jnp.concatenate((y1, y2), axis=-1) 48 | 49 | def reverse(y): 50 | f, g = layer_init(layer) 51 | 52 | hidden = y.shape[-1] 53 | y1 = y[:, :, :hidden // 2] 54 | y2 = y[:, :, hidden // 2:] 55 | 56 | x2 = y2 - g(y1) 57 | x1 = y1 - f(x2) 58 | 59 | return jnp.concatenate((x1, x2), axis=-1) 60 | 61 | self.forward_fn = hk.transform(forward) 62 | self.reverse_fn = hk.transform(reverse) 63 | 64 | master_rng = jax.random.PRNGKey(random.getrandbits(32)) 65 | 66 | @functools.partial(jax.pmap, donate_argnums=0) 67 | def forward_fn(x, params): 68 | out = self.forward_fn.apply(params, None, x) 69 | return out 70 | 71 | @functools.partial(jax.pmap, donate_argnums=(0, 1)) 72 | def reverse_fn(y_dy, acc, params): 73 | y, dy = y_dy 74 | reconstr_x = self.reverse_fn.apply(params, None, y) 75 | 76 | _, vjpfun = jax.vjp(self.forward_fn.apply, params, None, reconstr_x) 77 | weights_grad, _, x_grad = vjpfun(dy) 78 | 79 | new_acc = jax.tree_multimap(operator.add, acc, weights_grad) 80 | return (reconstr_x, x_grad), new_acc 81 | 82 | self.state = init_fn(master_rng, jnp.zeros_like(data), self.forward_fn.init, optimizer) 83 | num_params = hk.data_structures.tree_size(self.state["params"]) 84 | print(f'Param count = {num_params}') 85 | 86 | self.forward = forward_fn 87 | self.forward(jnp.zeros_like(data), self.state["params"]) 88 | 89 | self.reverse = reverse_fn 90 | _, new_acc = self.reverse((jnp.zeros_like(data), jnp.zeros_like(data)), self.state["grad_acc"], 91 | self.state["params"]) 92 | self.state["grad_acc"] = new_acc 93 | 94 | self.state = opt_state(self.state, self.optimizer) 95 | self.state = init_fn(master_rng, jnp.zeros_like(data), self.forward_fn.init, optimizer) 96 | 97 | self.init = False 98 | 99 | def run(self): 100 | def forward(h, state): 101 | return quantize(self.forward(dequantize(h, "float32"), state["params"]), self.precision.fwd_act) 102 | 103 | def backward(y_dy, 
state): 104 | y, dy = y_dy 105 | y_dy = (dequantize(y, "float32"), dequantize(dy, "float32")) 106 | x_dx, new_acc = self.reverse(y_dy, state["grad_acc"], state["params"]) 107 | state["grad_acc"] = new_acc 108 | state["grad_count"] = state["grad_count"] + 1 109 | 110 | self.state = state 111 | 112 | x, dx = x_dx 113 | return quantize(x, self.precision.rev_act), quantize(dx, self.precision.grad) 114 | 115 | self.fwd_q = Queue(2) 116 | self.bwd_q = Queue(2) 117 | self.init = True 118 | 119 | run_threads(self.state, self.fwd_q, self.bwd_q, 2, forward, backward) 120 | 121 | @ray.method(num_returns=2) 122 | def forward(self, h): 123 | while not self.init: 124 | time.sleep(0.1) 125 | return run_function(self.fwd_q, h), None 126 | 127 | @ray.method(num_returns=2) 128 | def backward(self, y_dy): 129 | while not self.init: 130 | time.sleep(0.1) 131 | return run_function(self.bwd_q, y_dy), None 132 | 133 | def opt(self): 134 | self.state = opt_state(self.state, self.optimizer) 135 | 136 | def get_params(self): 137 | return self.state["params"] 138 | 139 | def get_accum(self): 140 | return self.state["grad_acc"] 141 | 142 | def save(self, path, epoch): 143 | save_checkpoint(self.state, path, epoch) 144 | 145 | def load(self, path): 146 | ckpt = load_checkpoint(path) 147 | 148 | if ckpt: 149 | self.state = ckpt 150 | return True -------------------------------------------------------------------------------- /swarm_jax/swarm.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.pool import ThreadPool 2 | 3 | import numpy as np 4 | import optax 5 | import ray 6 | from tensorboardX import SummaryWriter 7 | from typing import Callable 8 | 9 | from .embedding_layer import EmbeddingLayer, ProjLayer 10 | from .model import SwarmModel 11 | from .reversible_layer import ReversibleLayer 12 | from .swarm_layer import NetworkPrecision 13 | 14 | 15 | class Swarm: 16 | def __init__(self, 17 | model: SwarmModel, 18 | optimizer: optax.GradientTransformation, 19 | loss_scale: float, 20 | dataloader: Callable, 21 | precision: NetworkPrecision): 22 | self.model = model 23 | self.optimizer = optax.chain( 24 | optax.scale(1 / loss_scale), 25 | optimizer 26 | ) 27 | self.dataloader = dataloader 28 | self.minibatches = 1 29 | self.loss_scale = loss_scale 30 | 31 | assert ray.is_initialized() # needs a valid ray cluster to start 32 | 33 | example = self.dataloader() 34 | self.embedding = EmbeddingLayer.options(max_concurrency=8).remote(example["obs"], self.model.vocab, 35 | self.model.d_model, self.optimizer, precision) 36 | self.embedding.run.remote() 37 | 38 | x, _ = self.embedding.embed_forward.remote(example["obs"]) 39 | 40 | self.proj = ProjLayer.options(max_concurrency=8).remote(x, self.model.vocab, self.model.d_model, self.optimizer, 41 | self.loss_scale, precision) 42 | self.proj.run.remote() 43 | 44 | self.layers = [] 45 | for i in range(model.rev_layers): 46 | self.layers.append( 47 | ReversibleLayer.options(max_concurrency=8).remote(self.model.rev_init, i, x, self.optimizer, precision)) 48 | 49 | for l in self.layers: 50 | l.run.remote() 51 | 52 | self.all_layers = [self.embedding] + self.layers + [self.proj] 53 | 54 | def run(self, epochs, log_path, ckpt_path): 55 | assert ray.is_initialized() # needs a valid ray cluster 56 | writer = SummaryWriter(log_path, flush_secs=5) 57 | 58 | ckpt_loads = [layer.load.remote(f"{ckpt_path}/{i}/") for i, layer in enumerate(self.all_layers)] 59 | print(f"checkpoint load status: {ray.get(ckpt_loads)}") 60 | 61 | pool = 
ThreadPool(16) # have max 16 concurrent examples in the network 62 | 63 | for e in range(epochs): 64 | if e % 5000 == 0: 65 | ckpt_saves = [layer.save.remote(f"{ckpt_path}/{i}/", e) for i, layer in enumerate(self.all_layers)] 66 | ray.wait(ckpt_saves, num_returns=len(ckpt_saves)) 67 | 68 | print(f"checkpoint saved") 69 | 70 | data = self.dataloader() 71 | 72 | def map_fn(_): 73 | return drive_example(self, data) 74 | 75 | result = list(pool.imap_unordered(map_fn, range(32))) # 32 microbatches per batch 76 | result = np.array(result) 77 | error, cos_err, loss = result.mean(axis=(0, 2)) 78 | 79 | opts = [layers.opt.remote() for layers in self.all_layers] 80 | ray.wait(opts, num_returns=len(opts)) 81 | 82 | writer.add_scalar("loss", loss / self.loss_scale, e) 83 | writer.add_scalar("reconstruction_error", error, e) 84 | writer.add_scalar("reconstruction_cos_error", cos_err, e) 85 | print(e, loss / self.loss_scale) 86 | 87 | 88 | # take a training example and shoves it through forward and backward of all layers 89 | def drive_example(swarm: Swarm, data): 90 | x, x_wait = swarm.embedding.embed_forward.remote(data["obs"]) 91 | ray.wait([x_wait]) 92 | 93 | # wrap all big ray objects in unit tuples to stop implicit .get 94 | for l in swarm.layers: 95 | x, x_wait = l.forward.remote((x,)) 96 | ray.wait([x_wait]) 97 | 98 | y_dy, loss = swarm.proj.debed_grad.remote((x,), data["target"]) 99 | ray.wait([loss]) 100 | 101 | for l in reversed(swarm.layers): 102 | y_dy, y_dy_wait = l.backward.remote((y_dy,)) 103 | ray.wait([y_dy_wait]) 104 | 105 | error = swarm.embedding.embed_grad.remote(data["obs"], (y_dy,)) 106 | ray.wait([error]) 107 | 108 | ret = ray.get(error) + (ray.get(loss),) 109 | 110 | return ret 111 | -------------------------------------------------------------------------------- /swarm_jax/swarm_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common methods for layer actors 3 | """ 4 | import pickle 5 | import re 6 | from functools import partial 7 | from pathlib import Path 8 | from queue import Queue, Empty 9 | from threading import Thread 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import ray 16 | from dataclasses import dataclass 17 | from glob import glob 18 | from typing import Callable 19 | 20 | 21 | # TODO: more intellegent checkpoint saving with deleting old checkpoints etc 22 | def save_checkpoint(state, path, epoch): 23 | Path(path).mkdir(parents=True, exist_ok=True) 24 | 25 | save_file = Path(path, f"ckpt_{epoch:06}.pkl") 26 | f = open(save_file, "wb") 27 | pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) 28 | 29 | 30 | def load_checkpoint(path): 31 | checkpoints = [int(re.findall(r'\d+', i)[-1]) for i in glob(f"{path}ckpt_*.pkl")] 32 | checkpoints.sort(reverse=True) 33 | 34 | if checkpoints: 35 | checkpoint_to_load = checkpoints[0] 36 | 37 | f = open(f"{path}ckpt_{checkpoint_to_load:06}.pkl", "rb") 38 | return pickle.load(f) 39 | return None 40 | 41 | 42 | # @partial(jax.jit, donate_argnums=(0, 1, 2), static_argnums=3) 43 | @partial(jax.jit, static_argnums=3) 44 | def opt_jit(grad_acc, opt_state, params, optimizer): 45 | total_grad = jax.tree_map(lambda x: jnp.mean(x, axis=0), grad_acc) 46 | 47 | cpu_device = jax.devices("cpu")[0] 48 | 49 | total_grad = jax.device_put(total_grad, device=cpu_device) 50 | cpu_params = jax.device_put(jax.tree_map(lambda x: x[0], params), device=cpu_device) 51 | 52 | updates, new_opt_state = optimizer.update(total_grad, opt_state) 53 | 54 
| new_params = optax.apply_updates(cpu_params, updates) 55 | 56 | new_grad_acc = jax.tree_map(jnp.zeros_like, grad_acc) 57 | return new_grad_acc, new_opt_state, new_params 58 | 59 | 60 | def opt_state(state, optimizer): 61 | new_grad_acc, new_opt_state, new_params = opt_jit(state["grad_acc"], 62 | state["opt_state"], 63 | state["params"], 64 | optimizer) 65 | 66 | state["grad_acc"] = new_grad_acc 67 | state["opt_state"] = new_opt_state 68 | state["params"] = jax.device_put_replicated(new_params, jax.local_devices()) 69 | state["grad_count"] = np.array(0) 70 | return state 71 | 72 | 73 | # @partial(jax.jit) 74 | def init_fn(master_rng, data, init_fn, optimizer): 75 | out_rng, init_rng = jax.random.split(master_rng) 76 | 77 | # copy the same initial params to each accelerator 78 | init_rng = jnp.broadcast_to(init_rng, (jax.local_device_count(),) + init_rng.shape) 79 | params = jax.pmap(init_fn)(init_rng, data) 80 | 81 | cpu_device = jax.devices("cpu")[0] 82 | 83 | # place optimizer state on CPU 84 | cpu_params = jax.tree_map(lambda x: jax.device_put(x[0], device=cpu_device), params) 85 | opt_state = optimizer.init(cpu_params) 86 | 87 | return dict( 88 | step=np.array(0), 89 | rng=out_rng, 90 | opt_state=opt_state, 91 | grad_acc=jax.tree_map(jnp.zeros_like, params), 92 | grad_count=np.array(0), 93 | params=params) 94 | 95 | 96 | # Thread to overlap remote transfers with computation (with bounded memory usage) 97 | # TODO: have hierarchy of remote -> RAM -> accelerator memory to overlap PCI-e transfer with computation 98 | class GetThread(Thread): 99 | def __init__(self, in_queue: Queue, out_queue: Queue): 100 | super().__init__() 101 | self.i_q = in_queue 102 | self.o_q = out_queue 103 | 104 | def run(self): 105 | while True: 106 | ret_q, obj_id, *aux = self.i_q.get() 107 | 108 | if isinstance(obj_id, ray.ObjectID): 109 | # GIL released here 110 | o = ray.get(obj_id) 111 | else: 112 | o = obj_id 113 | 114 | self.o_q.put((ret_q, o, *aux)) 115 | 116 | 117 | # TODO: bound the number of pending outputs when layerdrop is added, currently equal to number of pending inputs 118 | class RunThread(Thread): 119 | def __init__(self, fwd_q: Queue, bwd_q: Queue, fwd_fn: Callable, bwd_fn: Callable): 120 | super().__init__() 121 | self.fwd_q = fwd_q 122 | self.bwd_q = bwd_q 123 | 124 | self.fwd_fn = function_wrapper(fwd_fn) 125 | self.bwd_fn = function_wrapper(bwd_fn) 126 | 127 | def run(self): 128 | while True: 129 | # GIL released in fwd and bwd functions when e.g. XLA computation occurs 130 | # Prioritize backward over forward to minimize the amount of in flight batches (?) 
131 | # TODO: figure out if this actually makes sense from queuing theory perspective 132 | while not self.bwd_q.empty(): 133 | self.bwd_fn(*self.bwd_q.get_nowait()) 134 | self.bwd_q.task_done() 135 | while not self.fwd_q.empty(): 136 | if not self.bwd_q.empty(): 137 | break 138 | self.fwd_fn(*self.fwd_q.get_nowait()) 139 | self.fwd_q.task_done() 140 | try: 141 | # also release GIL here to not busy wait 142 | self.bwd_fn(*self.bwd_q.get(timeout=0.01)) 143 | self.bwd_q.task_done() 144 | except Empty: 145 | pass 146 | 147 | 148 | # create a get thread for both fwd and bwd as well as a run thread (blocks forever) 149 | def run_threads(state, fwd_in_q: Queue, bwd_in_q: Queue, queue_size: int, fwd_fn: Callable, bwd_fn: Callable): 150 | fwd_out_q = Queue(queue_size) 151 | bwd_out_q = Queue(queue_size) 152 | 153 | fwd_get = GetThread(fwd_in_q, fwd_out_q) 154 | bwd_get = GetThread(bwd_in_q, bwd_out_q) 155 | run = RunThread(fwd_out_q, bwd_out_q, partial(fwd_fn, state=state), partial(bwd_fn, state=state)) 156 | 157 | fwd_get.start() 158 | bwd_get.start() 159 | run.start() 160 | 161 | run.join() 162 | # should never get here 163 | raise Exception("Run thread terminated unexpectedly") 164 | 165 | 166 | # take a function and wrap it to return via queue instead 167 | def function_wrapper(fun): 168 | def ret_fun(q: Queue, *args): 169 | ret = fun(*args) 170 | q.put(ret) 171 | 172 | return ret_fun 173 | 174 | 175 | # runs a function via queue (blocking, run in threadpool) 176 | def run_function(q: Queue, obj_id, *aux): 177 | ret_q = Queue(1) 178 | 179 | if isinstance(obj_id, tuple): 180 | q.put((ret_q, obj_id[0], *aux)) 181 | else: 182 | q.put((ret_q, obj_id, *aux)) 183 | 184 | return ret_q.get() 185 | 186 | 187 | @partial(jax.jit, static_argnums=2) 188 | def int_quantize_jit(x: jnp.ndarray, max_int: int, to_type: str): 189 | min = x.min(axis=1, keepdims=True) 190 | max = x.max(axis=1, keepdims=True) 191 | 192 | offset = min 193 | scale = max - min 194 | 195 | normalized = (x - min) / scale 196 | return offset, scale, (normalized * max_int + 0.5).astype(to_type) # round to nearest instead of round to zero 197 | 198 | 199 | def quantize(x: jnp.ndarray, to_type: str): 200 | assert to_type in ["float16", "float32", "uint16", "uint8"] 201 | 202 | if "int" in to_type: 203 | max_int = 2 ** 8 - 1 if to_type == "uint8" else 2 ** 16 - 1 204 | return to_type, int_quantize_jit(x, max_int, to_type) 205 | else: 206 | return to_type, x.astype(to_type) 207 | 208 | 209 | @partial(jax.jit, static_argnums=4) 210 | def int_dequantize_jit(x: jnp.ndarray, scale: jnp.ndarray, offset: jnp.ndarray, max_int: int, to_type: str): 211 | return x.astype(to_type) * scale.astype(to_type) / max_int + offset.astype(to_type) 212 | 213 | 214 | def dequantize(x, to_type: str): 215 | from_type, data = x 216 | assert from_type in ["float16", "float32", "uint16", "uint8"] 217 | 218 | if "int" in from_type: 219 | offset, scale, data = data 220 | max_int = 2 ** 8 - 1 if from_type == "uint8" else 2 ** 16 - 1 221 | 222 | return int_dequantize_jit(data, scale, offset, max_int, to_type) 223 | else: 224 | return data.astype(to_type) 225 | 226 | 227 | @dataclass 228 | class NetworkPrecision: 229 | fwd_act: str 230 | rev_act: str 231 | grad: str 232 | 233 | 234 | if __name__ == "__main__": 235 | import os 236 | 237 | os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda-10.1" 238 | 239 | rng = jax.random.PRNGKey(0) 240 | 241 | r = jax.random.normal(rng, (16, 128, 512)) 242 | 243 | q = quantize(r, "uint16") 244 | d = dequantize(q, "float32") 245 | 
assert jnp.allclose(r, d, atol=1e-3, rtol=1e-3) 246 | 247 | q = quantize(r, "uint8") 248 | d = dequantize(q, "float32") 249 | assert jnp.allclose(r, d, atol=1e-1, rtol=1e-1) 250 | 251 | q = quantize(r, "float16") 252 | d = dequantize(q, "float32") 253 | assert jnp.allclose(r, d, atol=1e-3, rtol=1e-3) 254 | 255 | q = quantize(r, "float32") 256 | d = dequantize(q, "float32") 257 | assert jnp.allclose(r, d) 258 | -------------------------------------------------------------------------------- /swarm_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda-10.1" 4 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 5 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 6 | os.environ["JAX_DEBUG_NANS"] = "True" 7 | 8 | from swarm_jax.swarm_layer import NetworkPrecision 9 | 10 | from loader import TextLoader 11 | from swarm_jax.model import SwarmCharTransformer 12 | from swarm_jax.swarm import Swarm 13 | 14 | import ray 15 | import optax 16 | 17 | ray.init(resources={"tpu": 999}) # pretend we have infinite tpus lol 18 | 19 | train_dataset = TextLoader("data/enwik8", batchsize=(1, 16), sample_size=128, length=90000000) 20 | 21 | optimizer = optax.chain( 22 | optax.clip_by_global_norm(0.25), 23 | optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5)) 24 | 25 | prec = NetworkPrecision(fwd_act="uint16", rev_act="uint16", grad="uint16") 26 | 27 | model = SwarmCharTransformer 28 | swarm = Swarm(model, optimizer, 2 ** 16, train_dataset.get_samples, prec) 29 | swarm.run(100000, "runs/512_30L", "ckpt/512_30L") 30 | 31 | ray.shutdown() 32 | -------------------------------------------------------------------------------- /swarm_run_tpu.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing 3 | 4 | import optax 5 | import ray 6 | 7 | from loader import TextLoader 8 | from ray_tpu import start_ray, get_connection, create_tpu, wait_til 9 | from swarm_jax.model import SwarmCharTransformerBig 10 | from swarm_jax.swarm import Swarm 11 | from swarm_jax.swarm_layer import NetworkPrecision 12 | 13 | tpus = 8 14 | 15 | # for i in range(tpus): 16 | # delete_tpu(f"swarm-jax-test-{i}", "europe-west4-a") 17 | # 18 | # exit() 19 | 20 | head_info = ray.init(dashboard_host="0.0.0.0") 21 | address = head_info['redis_address'] 22 | 23 | conns = [] 24 | for i in range(tpus): 25 | create_tpu(f"swarm-jax-test-{i}", "europe-west4-a", "v3-8", False) 26 | 27 | for i in range(tpus): 28 | assert wait_til(f"swarm-jax-test-{i}", "europe-west4-a", {'state': 'READY', 'health': 'HEALTHY'}) 29 | 30 | for i in range(tpus): 31 | conns += get_connection(f"swarm-jax-test-{i}", "europe-west4-a") 32 | 33 | with multiprocessing.Pool(processes=tpus) as p: 34 | p.map(functools.partial(start_ray, address=address), conns) 35 | 36 | train_dataset = TextLoader("data/enwik9", batchsize=(8, 8), sample_size=1024, length=90000000) 37 | 38 | optimizer = optax.chain( 39 | optax.clip_by_global_norm(0.25), 40 | optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5)) 41 | 42 | prec = NetworkPrecision(fwd_act="float32", rev_act="float32", grad="float32") 43 | 44 | model = SwarmCharTransformerBig 45 | swarm = Swarm(model, optimizer, 2 ** 16, train_dataset.get_samples, prec) 46 | swarm.run(100000, "runs/512_30L", "ckpt/512_30L") 47 | 48 | ray.shutdown() 49 | --------------------------------------------------------------------------------
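A note on the inter-layer codec: `quantize` / `dequantize` in `swarm_jax/swarm_layer.py` ship activations and gradients between Ray actors as an integer payload plus per-slice `(offset, scale)` tensors. The sketch below re-implements the scheme in plain NumPy purely for illustration — the `encode` / `decode` names, the array shape, and the uint16 choice are local to this example, not part of the library.

```python
import numpy as np

# NumPy sketch of int_quantize_jit / int_dequantize_jit from swarm_jax/swarm_layer.py:
# each slice along axis 1 is affinely mapped onto [0, max_int], rounded to the
# nearest integer, and shipped together with its (offset, scale) pair.

def encode(x: np.ndarray, dtype=np.uint16):
    max_int = np.iinfo(dtype).max
    offset = x.min(axis=1, keepdims=True)
    scale = x.max(axis=1, keepdims=True) - offset
    payload = ((x - offset) / scale * max_int + 0.5).astype(dtype)  # round to nearest
    return offset, scale, payload

def decode(offset, scale, payload, dtype=np.float32):
    max_int = np.iinfo(payload.dtype).max
    return payload.astype(dtype) * scale.astype(dtype) / max_int + offset.astype(dtype)

x = np.random.default_rng(0).normal(size=(16, 128, 512)).astype(np.float32)
offset, scale, payload = encode(x)   # uint16 payload: half the bytes of float32
x_hat = decode(offset, scale, payload)
print(np.abs(x - x_hat).max())       # bounded by scale / (2 * 65535) per element
```

The per-element error is at most half a quantization step, which is why the self-test at the bottom of `swarm_layer.py` can assert `allclose` with `atol=1e-3` for the uint16 round trip.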