├── .gitignore ├── README.md ├── install.bash ├── iterate.yaml ├── media └── sharding.gif ├── minTransformerSharded ├── data.py ├── input.txt ├── layers.py ├── model.py └── train.py ├── model_parallel_tutorial ├── Todo.txt ├── Tutorial.ipynb ├── input.txt ├── memory.prof ├── setup.sh └── setupTPU.sh ├── old └── meshPipelineParallelism │ ├── .tmuxinator.yaml │ ├── dashboard │ └── readme.md │ ├── guide.md │ ├── infra │ ├── .tmuxinator.yaml │ ├── config.py │ ├── host_list │ ├── mpi_jax_test.py │ ├── mpi_test.py │ ├── ray_test.py │ ├── sentry.py │ └── utils.py │ ├── requirements.txt │ ├── steps.md │ ├── tests.ipynb │ └── tpu_key.pub ├── pipelineParallelism ├── infra │ ├── .tmuxinator.yaml │ ├── config.py │ ├── connection.py │ ├── ray_test.py │ └── utils.py └── launch.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scaling experiments 2 | 3 | A minimal implementation of a multi-device sharded transformer training, and a walk through of each component. The intention is educational - we'll build the required elements from the ground up and understand exactly where each computation is going. This is made exceptionally easy by jax, which is beautiful to work with on hardware meshes. All up it is only ~100 lines more than minGPT, with many of those as comments. Credit where its due - many elements from GPT-J's layer implementations are re-used here, but explained in detail. 4 | 5 | ![Alt Text](https://github.com/sholtodouglas/scalingExperiments/raw/main/media/sharding.gif) 6 | 7 | For production ready code, look at Mistral or the Deepspeed library (for Pytorch), or GPT-J (Jax). This repo is purely focused on exploring memory/compute strategies in multi-device training for a GPT style transformer - it could be further optimised through using float16, gcp streaming of tfrecords for dataloader, learning rate schedules etc 8 | 9 | This code uses the megatron-LM/GPT-J data+tensor parallelism scheme, which is simple and efficient on hardware meshes like TPUs. Soon, I'd like to look at pipeline parallelism, implement ZeRO style sharding - and use Ray to coordinate a K8s cluster of TPUv2s (for all those times you don't have a TPUvX-256!) 10 | 11 | This should be run on a TPU (either through GCP / TRC or Colab) as that gives us 8 devices to experiment with. In general, TPUs make training large models much easier - as your needs scale you can use bigger and bigger TPU pods, so its easy to see why Tesla is making their own extensible hardware mesh in Dojo. 12 | 13 | A couple of resources that I've leant on: 14 | 15 | - [Lilian Weng's notes on training large models](https://lilianweng.github.io/lil-log/2021/09/24/train-large-neural-networks.html) 16 | - [Ben Wang's GPT-J](https://github.com/kingoflolz/mesh-transformer-jax) 17 | - [Karpathy's MinGPT](https://github.com/karpathy/minGPT) 18 | - 3Blue1Brown's Manim library (to make the gif!) 19 | 20 | 21 | 22 | ## Usage 23 | 1. Run setup.sh (required for jax memory profiling). 24 | 2. Either work through the tutorial notebook, or run train.py. ~1 hour of training produces results identifiably Shakespearean output which is structured like a play. 
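For orientation, here is a minimal sketch (assuming a TPU with 8 visible devices and the default `devices: 4, shards: 2` config from `train.py`) of the logical mesh that `model.py` builds - rows are data-parallel replicas, columns are tensor-parallel shards:

```python
import jax
import numpy as np
from jax.experimental import maps

# 8 TPU cores arranged as a 4 x 2 mesh: 'batch' (data parallel) x 'shard' (tensor parallel)
mesh_devices = np.array(jax.devices()).reshape((4, 2))

with maps.mesh(mesh_devices, ('batch', 'shard')):
    # inside this context, the xmap-ed init/train/eval functions in model.py place the
    # named 'batch' and 'shard' axes onto the rows/columns of this mesh
    ...
```

Inside the mesh context, the per-example model code in `layers.py` only ever sees `[T, ...]` shaped arrays - the batch and shard dimensions are handled entirely by `xmap`.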
25 | 26 | 27 | -------------------------------------------------------------------------------- /install.bash: -------------------------------------------------------------------------------- 1 | sudo apt install software-properties-common 2 | sudo add-apt-repository ppa:greymd/tmux-xpanes 3 | sudo apt update 4 | gem install tmuxinator 5 | sudo apt install tmux-xpanes -------------------------------------------------------------------------------- /iterate.yaml: -------------------------------------------------------------------------------- 1 | name: iterate 2 | root: ~/ 3 | windows: 4 | - one: 5 | panes: 6 | <%- args.each do |arg| %> 7 | - echo <%= arg %> 8 | <%- end %> -------------------------------------------------------------------------------- /media/sharding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sholtodouglas/scalingExperiments/e8cec77d1c8e555c5d32aa795949e16952846c95/media/sharding.gif -------------------------------------------------------------------------------- /minTransformerSharded/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | class Dataloader(): 5 | ''' 6 | A super simple dataloader for a tiny dataset. Better implementations would 7 | - Pre-process the data into tf-records, and stream this from 8 | a GCP bucket. 9 | - Eliminate unncessary copies (e.g. as_numpy_iterator) 10 | ''' 11 | def __init__(self, config): 12 | super().__init__() 13 | 14 | if not os.path.exists('input.txt'): 15 | os.system('wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt') 16 | 17 | text = open('input.txt', 'r').read() 18 | self.vocab = sorted(list(set(text))) 19 | self.vocab_len = len(self.vocab) 20 | self.stoi = { ch:i for i,ch in enumerate(self.vocab) } 21 | self.itos = { i:ch for i,ch in enumerate(self.vocab) } 22 | tokens = [self.stoi[c] for c in text] 23 | d = tf.data.Dataset.from_tensor_slices(tokens) 24 | d = d.batch(config['block_size']+1, drop_remainder=True) # +1 because [:-1] will be x, and [1:] will be y 25 | self.d = iter(d.batch(config['batch_size_per_parallel']*config['devices'], drop_remainder=True).repeat().as_numpy_iterator()) 26 | 27 | def next_batch(self): 28 | b = self.d.next() 29 | return b[:, :-1], b[:, 1:] # x, y -------------------------------------------------------------------------------- /minTransformerSharded/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import haiku as hk 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | 8 | class TransformerLayerShard(hk.Module): 9 | ''' 10 | # A simple transformer layer shard that exists on one of the devices. 11 | ''' 12 | def __init__(self, config, name=None, init_scale=1.): 13 | super().__init__(name=name) 14 | heads = config["n_head"] 15 | dim = config["d_model"] 16 | self.shards = config["shards"] 17 | 18 | assert dim % heads == 0 19 | assert heads % self.shards == 0 20 | 21 | self.dim_per_head = dim // heads 22 | self.dim_per_shard = dim // self.shards 23 | self.heads_per_shard = heads // self.shards 24 | 25 | # GPT-J uses a common layer norm between the mlp and the self attention, minGPT uses different layers - lets go with one for now. Much of a muchness. 
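        # To make the sharding concrete - with the defaults in train.py (an illustrative
        # assumption: d_model=768, n_head=8, shards=2) this gives dim_per_head=96,
        # heads_per_shard=4 and dim_per_shard=384, i.e. each shard's q/k/v below are
        # 768 -> 384 projections holding 4 of the 8 heads.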
26 | self.ln = hk.LayerNorm(-1, True, True) 27 | 28 | # key, query and value projections for all heads on this shard 29 | self.q = hk.Linear(self.dim_per_shard, with_bias=False) 30 | self.k = hk.Linear(self.dim_per_shard, with_bias=False) 31 | self.v = hk.Linear(self.dim_per_shard, with_bias=False) 32 | 33 | # self att output projection 34 | self.att_proj = hk.Linear(dim, with_bias=False, w_init=hk.initializers.TruncatedNormal(stddev=init_scale / np.sqrt(dim))) 35 | 36 | # feedforward layers 37 | self.dense_proj = hk.Linear(self.dim_per_shard * 4) 38 | self.dense_proj_out = hk.Linear(dim, 39 | w_init=hk.initializers.TruncatedNormal(stddev=init_scale / np.sqrt(dim))) 40 | 41 | 42 | def self_attention(self, q, k ,v, bias): 43 | ''' 44 | k,q,v: [T, heads_per_shard, dim_per_head] 45 | ''' 46 | T, _, _ = k.shape 47 | 48 | # No batch dimension needed in the einsum because it is abstracted away by the xmap 49 | attention_logits = jnp.einsum('thd,Thd->htT', q, k) # [heads_per_shard, T,T] 50 | sqrt_key_size = np.sqrt(self.dim_per_head).astype(k.dtype) # [1,] 51 | 52 | attention_logits = attention_logits/sqrt_key_size # [heads_per_shard, T,T] 53 | 54 | attention_logits += bias # [B, heads_per_shard, T,T] 55 | 56 | attention_weights = jax.nn.softmax(attention_logits) # [heads_per_shard, T,T] 57 | 58 | weighted_values = jnp.einsum('htT, Thd->thd', attention_weights, v).reshape((T, self.dim_per_shard)) # [T, dim_per_shard] 59 | 60 | return self.att_proj(weighted_values) 61 | 62 | def feed_forward(self, x): 63 | ''' 64 | x: [T,embed_dim] 65 | ''' 66 | dense_proj = self.dense_proj(x) 67 | dense_proj = jax.nn.gelu(dense_proj) 68 | return self.dense_proj_out(dense_proj) 69 | 70 | 71 | def qkv_proj(self, x): 72 | ''' 73 | x: [T, embed_dim] 74 | ''' 75 | q = self.q(x).reshape(x.shape[:-1] + (self.heads_per_shard, self.dim_per_head)) # [T, heads_per_shard, dim_per_head] 76 | v = self.v(x).reshape(x.shape[:-1] + (self.heads_per_shard, self.dim_per_head)) # "" 77 | k = self.k(x).reshape(x.shape[:-1] + (self.heads_per_shard, self.dim_per_head)) # "" 78 | 79 | return q,k,v 80 | 81 | 82 | def __call__(self, x): 83 | ''' 84 | x: [T, embed_dim] 85 | ''' 86 | # preliminaries 87 | # print(x.shape) 88 | T,C = x.shape 89 | x = self.ln(x) # [T,embed_dim] 90 | 91 | # causal self attention 92 | q,k,v = self.qkv_proj(x) 93 | causal_mask = np.tril(np.ones((T,T))) # [T,T] 94 | bias = -1e10 * (1. - causal_mask) # [T,T] 95 | attn = self.self_attention(q, k ,v, bias) # [T,embed_dim] 96 | 97 | # feedforward 98 | ff = self.feed_forward(x) # [B,T,embed_dim] 99 | 100 | # block 101 | x = x + attn # [T,embed_dim] 102 | x = x + ff # [T,embed_dim] 103 | 104 | # We finally need to sum across shards to collect the information from each head into the new embedding. 105 | # The idea of heads 'adding to the information stream' is a nice way of thinking about transformers from Anthropic's 106 | # latest work. 107 | # In the full GPT-J implementation, they've defined a custom operator which does the psum on the forward pass but 108 | # is the identity function on the backward pass - currently testing how necessary that is. 
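        # A rough sketch of that kind of operator (illustrative only - not the exact
        # GPT-J code), using jax.custom_vjp to pin the backward rule explicitly:
        #
        #   @jax.custom_vjp
        #   def g_psum(x):                       # psum on the forward pass
        #       return jax.lax.psum(x, "shard")
        #
        #   def g_psum_fwd(x):
        #       return g_psum(x), None
        #
        #   def g_psum_bwd(_, g):
        #       return (g,)                      # identity on the backward pass
        #
        #   g_psum.defvjp(g_psum_fwd, g_psum_bwd)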
109 | return jax.lax.psum(x, "shard") 110 | 111 | 112 | 113 | class EmbeddingShard(hk.Module): 114 | def __init__(self, config, name=None): 115 | super().__init__(name=name) 116 | in_dim = config["n_vocab"] 117 | out_dim = config["d_model"] 118 | shards = config["shards"] 119 | 120 | assert in_dim % shards == 0 121 | 122 | self.in_dim = in_dim 123 | self.out_dim = out_dim 124 | self.in_dim_per_shard = in_dim // shards 125 | self.out_dim_per_shard = out_dim // shards 126 | 127 | embed_init = hk.initializers.TruncatedNormal(stddev=0.02) 128 | self.positional_embeddings = hk.get_parameter('pos_embs', [config["block_size"], self.out_dim_per_shard], init=embed_init) 129 | 130 | # notice unlike the ff transformer layer, this linear layer has the full output dimension because we are partitioning across vocab. 131 | self.proj = hk.Linear(self.out_dim, w_init=hk.initializers.TruncatedNormal(stddev=1 / np.sqrt(in_dim))) 132 | 133 | def __call__(self, x, dtype=jnp.bfloat16): 134 | 135 | # work out which shard we are on, and the start token index 136 | shard_start_index = jax.lax.axis_index('shard') * self.in_dim_per_shard 137 | 138 | # subtract the shard_start_index from the input indices. This means anything below it will be zero-d (as it will be a negative number) 139 | # at the same time, anything above 'in_dim_per_shard' will also be zero-d. This means that each shard gets a window of in_dim_per_shard indices 140 | # which it will expand to a one-hot representation - saving lots of space! 141 | input_onehot = jax.nn.one_hot(x - shard_start_index, self.in_dim_per_shard) 142 | proj_out = self.proj(input_onehot) 143 | # sum across shards to create a full embedding 144 | proj_out = jax.lax.psum(proj_out, "shard") 145 | # gets all of the positional embeddings as split across each shard (column wise split of positional embeddings) 146 | all_pos_embed = jax.lax.all_gather(self.positional_embeddings, 'shard') 147 | # flattens them, so now there are identical, complete positional embeddings on each device 148 | all_pos_embed = hk.Flatten()(jnp.transpose(all_pos_embed, (1, 0, 2))) 149 | 150 | proj_out += all_pos_embed[:proj_out.shape[0]] # only do the embeddings up to the length of the input sequence, to allow for variable input size 151 | 152 | return proj_out 153 | 154 | 155 | class ProjectionShard(hk.Module): 156 | def __init__(self, config, name=None): 157 | super().__init__(name=name) 158 | out_dim = config["n_vocab"] 159 | shards = config["shards"] 160 | 161 | assert out_dim % shards == 0 162 | 163 | self.dim = out_dim 164 | self.dim_per_shard = out_dim // shards 165 | 166 | self.norm = hk.LayerNorm(-1, True, True) 167 | 168 | self.proj = hk.Linear(self.dim_per_shard) 169 | 170 | def __call__(self, x): 171 | x = self.norm(x) 172 | proj = self.proj(x) 173 | 174 | all_proj = jax.lax.all_gather(proj, 'shard') 175 | 176 | return hk.Flatten()(jnp.transpose(all_proj, (1, 0, 2))) 177 | 178 | def loss(self, x, targets, z_loss=1): 179 | 180 | ''' 181 | x: [T, dim_per_shard] 182 | targets: [T] 183 | ''' 184 | x = self.norm(x) 185 | # calculate logits on a per shard basis 186 | logits = self.proj(x) # [T, dim_per_shard] 187 | # get the max of logits per dimension across the shards. Use this to prevent both under and overflow by subtracting it from every logit. 188 | # For an explaination on why you need this - see the opening pages of chapter 4 'Numerical computation' in goodfellow's deep learning book. 
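        # In short: softmax(x) = softmax(x - c) for any constant c, so subtracting the
        # global (cross-shard) max keeps every exp() argument <= 0, which stops exp from
        # overflowing and keeps the log-sum-exp below finite in low precision.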
189 | global_max = jax.lax.pmax(jax.lax.stop_gradient(logits.max(-1, keepdims=True)), "shard") 190 | logits -= jax.lax.stop_gradient(global_max) # [T, dim_per_shard] 191 | 192 | # As we are computing the output vocab matrix in a sharded fashion, only get the targets corresponding to that shard 193 | # using the same trick as used in the embedding matrix. 194 | shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard 195 | gt_onehot = jax.nn.one_hot(targets - shard_start_index, self.dim_per_shard) # [T, dim_per_shard] 196 | # this is a point multiplication, so it zeros out anything which isn't a 1 in the one-hot representation. 197 | # then sums along the embedding axis. See above code snippet for explaination for the next few lines. 198 | predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1) # [T] 199 | # subtract the summed logit from the summed 'predicted_logit' 200 | # Any entry but the correct one is 0 in 'predicted logit' - and due to the max used for stabilisation 201 | # the entry of the highest index will be 0. Therefore, the subtraction of the two will draw the highest index to the correct one. 202 | # By only working with sums when we are using the sharded version we minimise communication. 203 | predicted_logits = jax.lax.psum(predicted_logits, 'shard') 204 | exp_logits = jnp.exp(logits) 205 | sum_exp_logits = exp_logits.sum(axis=-1) 206 | sum_exp_logits = jax.lax.psum(sum_exp_logits, 'shard') 207 | loss = jnp.log(sum_exp_logits) - predicted_logits 208 | # An additional loss which keeps the logits small - avoiding roundoff errors in bfloat16 (according to the mesh tensorflow repo). 209 | loss += (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean() 210 | 211 | # Due to the above, it is easy to correctly predict accuracy. 212 | correct = (0.0 == predicted_logits) 213 | return loss.sum(), correct 214 | 215 | 216 | class CausalTransformerShard(hk.Module): 217 | def __init__(self, config): 218 | super().__init__() 219 | heads = config["n_head"] 220 | shards = config["shards"] 221 | layer_count = config["n_layer"] 222 | 223 | self.layers = [] 224 | self.heads = heads 225 | 226 | self.heads_per_shard = heads // shards 227 | 228 | self.embed = EmbeddingShard(config) 229 | 230 | init_scale = 2. 
/ layer_count 231 | 232 | for i in range(layer_count): 233 | self.layers.append(TransformerLayerShard(config, name=f"layer_{i}", init_scale=init_scale)) 234 | 235 | self.proj = ProjectionShard(config) 236 | 237 | 238 | def trunk(self, tokens): 239 | x = self.embed(tokens) 240 | 241 | for l in self.layers: 242 | x = x + l(x) 243 | return x 244 | 245 | def loss(self, tokens, targets): 246 | x = self.trunk(tokens) 247 | l, acc = self.proj.loss(x, targets) 248 | return l 249 | 250 | def __call__(self, tokens): 251 | 252 | x = self.trunk(tokens) 253 | return self.proj(x) 254 | 255 | -------------------------------------------------------------------------------- /minTransformerSharded/model.py: -------------------------------------------------------------------------------- 1 | 2 | import haiku as hk 3 | import jax 4 | import optax 5 | from layers import CausalTransformerShard 6 | from jax import value_and_grad 7 | import os 8 | import pickle 9 | import numpy as np 10 | 11 | 12 | class CausalTransformer(): 13 | def __init__(self, config): 14 | super().__init__() 15 | 16 | axis_names = ('batch', 'shard') 17 | mesh_devices = np.array(jax.devices()).reshape((config['devices'], config['shards'])) 18 | self.mesh_def = (mesh_devices, axis_names) 19 | 20 | self.config = config 21 | self.optimizer = optax.adam(1e-5) 22 | 23 | self.key = hk.PRNGSequence(42) 24 | 25 | self.init = jax.experimental.maps.xmap(fun=self.init_state, 26 | in_axes=(["shard", ...], # rngs 27 | ["batch", ...]), # x 28 | out_axes=["shard", ...], 29 | axis_resources={'shard': 'shard', 'batch': 'batch'}) 30 | 31 | self.forward = jax.experimental.maps.xmap(fun=self.eval_step, 32 | in_axes=(["shard", ...], # params 33 | ["batch", ...]), # x 34 | out_axes=(["batch", ...]), 35 | axis_resources={'shard': 'shard', 'batch': 'batch'}) 36 | 37 | 38 | self.train = jax.experimental.maps.xmap(fun=self.train_step, 39 | in_axes=(["shard", ...], # state 40 | ["batch", ...], # x 41 | ["batch", ...]),# y 42 | out_axes=([['batch'], # loss 43 | ['shard',...]]), # state 44 | axis_resources={'shard': 'shard', 'batch': 'batch'}) 45 | 46 | 47 | # Haiku pure functions for convenience 48 | def eval_fn(self, x): 49 | model = CausalTransformerShard(self.config) 50 | return model(x) 51 | 52 | def train_fn(self, x,y): 53 | model = CausalTransformerShard(self.config) 54 | return model.loss(x,y) 55 | 56 | 57 | def init_state(self, key, x): 58 | ''' 59 | A parallelised init function that ensures optimiser params are stored on the respective devices. 
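        Because self.init wraps this function in xmap with out_axes=["shard", ...]
        (and the 'shard' logical axis mapped onto the mesh's shard axis), the returned
        params and opt_state are laid out along that axis rather than gathered onto one host.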
60 | ''' 61 | params = hk.transform(self.eval_fn).init(key, x) 62 | 63 | return { 64 | "params": params, 65 | "step": np.array(0), 66 | "opt_state": self.optimizer.init(params) 67 | } 68 | 69 | def eval_step(self, params, x): 70 | 71 | forward_fn = hk.without_apply_rng(hk.transform(self.eval_fn)) 72 | out = forward_fn.apply(params, x) 73 | return out 74 | 75 | def train_step(self, state, x,y): 76 | 77 | l_fn = hk.without_apply_rng(hk.transform(self.train_fn)) 78 | loss, grads = value_and_grad(l_fn.apply)(state['params'], x,y) 79 | grads = jax.lax.pmean(grads, "batch") 80 | updates, new_opt_state = self.optimizer.update(grads, state['opt_state'], state['params']) 81 | 82 | return loss, { 83 | "params": optax.apply_updates(state['params'], updates), 84 | "step": state['step'] + 1, 85 | "opt_state": new_opt_state 86 | } 87 | 88 | def save(self, state): 89 | os.makedirs(self.config['ckpt_dir'], exist_ok=True) 90 | with open(os.path.join(self.config['ckpt_dir'], "arrays.npy"), "wb") as f: 91 | for x in jax.tree_leaves(state): 92 | np.save(f, x, allow_pickle=False) 93 | 94 | tree_struct = jax.tree_map(lambda t: 0, state) 95 | with open(os.path.join(self.config['ckpt_dir'], "tree.pkl"), "wb") as f: 96 | pickle.dump(tree_struct, f) 97 | 98 | def restore(self): 99 | ''' 100 | Usage: set state = model.restore() after initialising the model. 101 | ''' 102 | with open(os.path.join(self.config['ckpt_dir'], "tree.pkl"), "rb") as f: 103 | tree_struct = pickle.load(f) 104 | 105 | leaves, treedef = jax.tree_flatten(tree_struct) 106 | with open(os.path.join(self.config['ckpt_dir'], "arrays.npy"), "rb") as f: 107 | flat_state = [np.load(f) for _ in leaves] 108 | 109 | return jax.tree_unflatten(treedef, flat_state) 110 | -------------------------------------------------------------------------------- /minTransformerSharded/train.py: -------------------------------------------------------------------------------- 1 | from data import Dataloader 2 | from model import CausalTransformer 3 | import jax.numpy as jnp 4 | import jax 5 | 6 | GPTConfig = { 7 | 'n_vocab': 66, 8 | 'block_size': 32, 9 | 'n_layer' : 3, 10 | 'n_head' : 8, 11 | 'd_model' : 768, 12 | 'shards': 2, 13 | 'devices': 4, 14 | 'batch_size_per_parallel': 256, 15 | 'ckpt_dir': 'test'} 16 | 17 | # A downside of using the more memory efficient method of embedding sharding is that it requires equal shard size across devices 18 | # or a 'check which device I'm on, lookup desired shard size'. For the moment - easier to just have a few empty spots for tokens. 
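# (Concretely: the tiny Shakespeare character vocab has 65 symbols, so n_vocab is padded
#  up to 66, which divides evenly into 33 tokens per shard for the 2 shards above.)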
19 | 20 | assert GPTConfig['n_vocab'] % GPTConfig['shards'] == 0 21 | 22 | 23 | ds = Dataloader(GPTConfig) 24 | model = CausalTransformer(GPTConfig) 25 | 26 | 27 | x,y = ds.next_batch() # [B,T], [B,T] 28 | 29 | with jax.experimental.maps.mesh(*model.mesh_def): 30 | state = model.init(jnp.array(model.key.take(GPTConfig['shards'])), x) 31 | 32 | 33 | 34 | from tqdm import tqdm 35 | 36 | losses = [] 37 | with jax.experimental.maps.mesh(*model.mesh_def): 38 | steps = [t for t in range(0, 10000)] 39 | pbar = tqdm(steps) 40 | for t in pbar: 41 | x,y = ds.next_batch() 42 | loss, state = model.train(state, x,y) 43 | if t % 100 == 0: 44 | pbar.set_description(f"Loss: {loss.mean()}") 45 | losses.append(loss.mean()) 46 | 47 | # Non auto-regressive sampling (works faster so you can see if it broadly making sense after 15 minutes) 48 | with jax.experimental.maps.mesh(*model.mesh_def): 49 | x,y = ds.next_batch() 50 | y_pred = model.forward(state['params'], x) 51 | y_pred_logit = jnp.argmax(y_pred, -1) 52 | 53 | for i in range(0,100): 54 | print(''.join([ds.itos[c] for c in list(y_pred_logit[i])])) 55 | print('--------------------------') -------------------------------------------------------------------------------- /model_parallel_tutorial/Todo.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sholtodouglas/scalingExperiments/e8cec77d1c8e555c5d32aa795949e16952846c95/model_parallel_tutorial/Todo.txt -------------------------------------------------------------------------------- /model_parallel_tutorial/memory.prof: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sholtodouglas/scalingExperiments/e8cec77d1c8e555c5d32aa795949e16952846c95/model_parallel_tutorial/memory.prof -------------------------------------------------------------------------------- /model_parallel_tutorial/setup.sh: -------------------------------------------------------------------------------- 1 | sudo add-apt-repository ppa:longsleep/golang-backports -y 2 | sudo apt update 3 | sudo apt install golang-go 4 | env GOPATH=/root/go 5 | sudo apt-get install graphviz gv 6 | go install github.com/google/pprof@latest 7 | pip install jupyter optax tensorflow_datasets 8 | pip install dm-haiku==0.0.5 -------------------------------------------------------------------------------- /model_parallel_tutorial/setupTPU.sh: -------------------------------------------------------------------------------- 1 | 2 | gcloud alpha compute tpus tpu-vm create lfp1 --zone=us-central1-f --accelerator-type=v2-8 --version=v2-alpha --project learning-from-play-303306 3 | gcloud alpha compute tpus tpu-vm ssh lfp1 --zone=us-central1-f --project learning-from-play-303306 -- -L 8888:localhost:8888 4 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/.tmuxinator.yaml: -------------------------------------------------------------------------------- 1 | name: scaling 2 | root: ~/ 3 | windows: 4 | - servers: 5 | layout: tiled 6 | panes: 7 | - test_0: 8 | - clear 9 | - test_1: 10 | - clear 11 | - test_2: 12 | - clear 13 | - test_3: 14 | - clear 15 | - test_4: 16 | - clear 17 | - test_5: 18 | - clear 19 | - test_6: 20 | - clear 21 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/dashboard/readme.md: -------------------------------------------------------------------------------- 1 | This is composed of a flask server 
that can be pinged to show updates. 2 | 3 | I wonder if one can ping local flask from remote? 4 | 5 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/guide.md: -------------------------------------------------------------------------------- 1 | To achieve pipeline parallelism - we'll need to coordinate a whole cluster of machines. The optimal hardware would be something like a TPU pod, which is made up of up to 32 TPU boards (each with 8 TPU cores), connected with a high speed interconnect. I don't have access to one, but I do have access to a large cluster of separate TPUv2 boards - which is arguably more representative of the clusters of A100s (or similar) that are in use outside GCP. That makes them an excellent playground to explore inter-machine parallelism! 2 | 3 | As with the previous post, this is focused tightly on being educational and looking at some of the design decisions which might have gone into writing a framework like Deepspeed. 4 | 5 | - 6 | 7 | # Setup 8 | 9 | As we'll be using a whole cluster of machines - we'll need a way to orchestrate them. Normally, I'd opt for kubernetes through Google Kubernetes Engine (GKE) - but it appears to only support the older TPU nodes which require a separate host CPU driving them (and are commensurately much slower). Instead, we'll write our own simple orchestration functions (I'm quite confident TPU-VM's will be available through GKE in time). 10 | 11 | # Distributed Systems 12 | 13 | Two libaries stand out 14 | 15 | - Ray: A distributed systems library for python. It sets up a single head + multiple worker nodes, but theoretically the head shouldn't be a bottleneck as it does not copy data to the head in order to transfer it between worker nodes. 16 | - mpi4jax: Zero copy, multi-host communication of JAX arrays. No orchestrating node is required - all nodes send and receive directly to eachtother. 17 | 18 | mpi4jax is a lower level library and is likely to introduce more complexity into the code - but it may be faster as it has been directly optimised for the transfer of jax arrays from GPU/TPU memory. Lets test! Regardless, we'll still use Ray for multiprocessing of orchestration operations from our local machine. 19 | 20 | 21 | 22 | # Design 23 | 24 | Pre-emption seems reasonably common, so lets make it fault tolerant from the start. 25 | 26 | We should have a continously running process which performs the following steps 27 | 28 | - Tries to create the desired config 29 | - - Checks the currently active nodes, and constructs the best possible configuration from it. 30 | - - Sets that running 31 | - - If one of the existing ones is pre-empted, cancel the program somehow, recompute optimal arrangement 32 | - - If a new one is added, recompute optimal config. Wait till next save epoch to update 33 | 34 | 35 | 36 | # Setup 37 | 38 | Working with and debugging a distributed environment is a little more annoying! While we're getting started, I set up a tmuxinator that sshs into everything. 39 | 40 | To ensure the tmux looks pretty, copy the following into ~/.tmux.conf 41 | 42 | 43 | 44 | tmuxinator start scaling -p .tmuxinator.yaml 45 | 46 | 47 | 48 | # GCP setup 49 | 50 | Create a project with TPU accces 51 | 52 | Create an ssh key called 'google_compute_engine', and add it to the project. 
53 | ''' 54 | ssh-keygen -t rsa 55 | 56 | 57 | gcloud compute os-login ssh-keys add \ 58 | --key-file=KEY_FILE_PATH.pub \ 59 | --project=PROJECT 60 | 61 | ''' 62 | 63 | 64 | 65 | 66 | gcloud compute os-login ssh-keys add \ 67 | --key-file= /home/sholto/.ssh/google_compute.pub \ 68 | --project=learning-from-play-303306 69 | 70 | 71 | 72 | Enable each machine to access the others 73 | https://www.open-mpi.org/faq/?category=running#missing-prereqs 74 | https://github.com/NAThompson/mpi_clustering 75 | 76 | 77 | ssh-keygen -t rsa -f $HOME/.ssh/id_rsa -N '' -C "MPI Keys" 78 | 79 | Need to make sure each device shares an ssh key -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/.tmuxinator.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: scaling 3 | root: ~/ 4 | windows: 5 | - servers: 6 | layout: tiled 7 | panes: 8 | - test_0: 9 | - clear 10 | - test_1: 11 | - clear 12 | - test_2: 13 | - clear 14 | - test_3: 15 | - clear -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/config.py: -------------------------------------------------------------------------------- 1 | cluster_config = { 2 | 'nodes': 4, 3 | 'pipeline_length': 2, 4 | 'name': "test", 5 | "project": "learning-from-play-303306", 6 | "accelerator_type": "v2-8", 7 | "zone": "us-central1-f", 8 | "preemptible": False, 9 | "redis_password": "5241590000000000" # the default 10 | } 11 | 12 | constant_args = f"--zone={cluster_config['zone']} --project={cluster_config['project']}" 13 | 14 | 15 | # the formatting in this section is weird, but precisely maps to the tmuxinator file. 16 | # so that it is easily modifiable to anyone with familiarity there. 
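# e.g. running `python infra/config.py` rewrites .tmuxinator.yaml (in the current working
# directory) with one `test_<i>` pane per configured node, as steps.md describes.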
17 | if __name__ == "__main__": 18 | 19 | tmuxinator_header = ''' 20 | name: scaling 21 | root: ~/ 22 | windows: 23 | - servers: 24 | layout: tiled 25 | panes:''' 26 | 27 | with open('.tmuxinator.yaml', 'w') as file: 28 | 29 | 30 | panes = "".join([f''' 31 | - test_{i}: 32 | - clear''' for i in range(0, cluster_config['nodes'])] ) 33 | 34 | file.write(tmuxinator_header + panes) 35 | 36 | 37 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/host_list: -------------------------------------------------------------------------------- 1 | 10.128.15.204 slots=1 2 | 10.128.15.207 slots=1 3 | 10.128.15.201 slots=1 4 | 10.128.15.203 slots=1 5 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/mpi_jax_test.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import jax 3 | import jax.numpy as jnp 4 | import mpi4jax 5 | 6 | comm = MPI.COMM_WORLD 7 | rank = comm.Get_rank() 8 | 9 | print(f'hello from {rank} of {comm}') 10 | 11 | @jax.jit 12 | def foo(arr): 13 | arr = arr + rank 14 | arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm) 15 | return arr_sum 16 | 17 | a = jnp.zeros((3, 3)) 18 | result = foo(a) 19 | 20 | if rank == 0: 21 | print(result) -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/mpi_test.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | 3 | size = MPI.COMM_WORLD.Get_size() 4 | rank = MPI.COMM_WORLD.Get_rank() 5 | name = MPI.Get_processor_name() 6 | 7 | print("Hello from rank {0} of {1} on {2}".format(rank, size, name)) -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/ray_test.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import socket 3 | import time 4 | 5 | import ray 6 | from config import cluster_config, constant_args 7 | 8 | ray.init(address='auto', _redis_password=cluster_config['redis_password']) 9 | 10 | print('''This cluster $ consists of 11 | {} nodes in total 12 | {} CPU resources in total 13 | '''.format(len(ray.nodes()), ray.cluster_resources()['CPU'])) 14 | 15 | @ray.remote 16 | def f(): 17 | time.sleep(0.001) 18 | # Return IP address. 
19 | return socket.gethostbyname(socket.gethostname()) 20 | 21 | object_ids = [f.remote() for _ in range(10000)] 22 | ip_addresses = ray.get(object_ids) 23 | 24 | print('Tasks executed') 25 | for ip_address, num_tasks in Counter(ip_addresses).items(): 26 | print(' {} tasks on {}'.format(num_tasks, ip_address)) 27 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/sentry.py: -------------------------------------------------------------------------------- 1 | # This file handles the establishment and management of a cluster of TPU-VMS 2 | 3 | from .config import cluster_config, constant_args 4 | from .utils import construct_cluster_names, list_tpus 5 | from os import system 6 | import psutil 7 | 8 | 9 | ''' 10 | ''' 11 | 12 | import yaml 13 | 14 | LEADER = 0 # the leader node is the 0th 15 | 16 | def make_tmux(): 17 | names = construct_cluster_names(cluster_config['nodes']) 18 | 19 | def make_pane(name): 20 | pane = {name: [ 21 | "clear" 22 | ]} 23 | return pane 24 | 25 | data = { 26 | "name": "scaling", 27 | "root": "~/", 28 | "windows": [ 29 | {"servers": { 30 | "layout": "tiled", 31 | "panes": [make_pane(name) for name in names], 32 | } 33 | } 34 | ], 35 | } 36 | 37 | file = open(".tmuxinator.yaml", "w") 38 | yaml.dump(data, file) 39 | file.close() 40 | 41 | 42 | def tmux(command): 43 | system('tmux %s' % command) 44 | 45 | def tmux_shell(command): 46 | tmux('send-keys "%s" "C-m"' % command) 47 | 48 | def connect(): 49 | to_connect = construct_cluster_names(cluster_config['nodes']) 50 | connections = psutil.net_connections() 51 | ports = [conn.laddr.port for conn in connections] 52 | 53 | existing_tpus = list_tpus()['nodes'] 54 | ready_tpus = [node['name'].split('/')[-1] for node in existing_tpus if node['state'] == 'READY'] 55 | 56 | 57 | tmux("set -g pane-active-border-style bg=default,fg=magenta") 58 | tmux("set -g pane-border-style fg=green") 59 | 60 | # start the existing, but stopped nodes 61 | for name in to_connect: 62 | 63 | pane = int(name.split('-')[-1]) 64 | port = 8800 + pane 65 | tmux(f'select-pane -t {pane}') 66 | 67 | # if it is the correct type of node 68 | if name in ready_tpus: 69 | # tmux("set pane-border-style bg=green,fg=green") # make it red for failing to connect 70 | if port in ports: 71 | print(f"{port} already in use - indicates connection exists") 72 | else: 73 | 74 | tmux_shell(f"gcloud alpha compute tpus tpu-vm ssh {name} {constant_args} -- -L {port}:localhost:{port}") 75 | else: 76 | pass 77 | # tmux("set pane-border-style bg=red,fg=red") # make it red for failing to connect 78 | 79 | def run(cmd, pane=None): 80 | 81 | # run it on all of them 82 | if pane == None: 83 | panes_to_connect_to = [int(name.split('-')[-1]) for name in construct_cluster_names(cluster_config['nodes'])] 84 | connections = [conn.laddr.port for conn in psutil.net_connections()] 85 | active_panes = [pane for pane in panes_to_connect_to if 8800 + pane in connections] 86 | 87 | for pane in active_panes: 88 | tmux(f'select-pane -t {pane}') 89 | tmux_shell(cmd) 90 | # run on a specific pane 91 | else: 92 | tmux(f'select-pane -t {pane}') 93 | tmux_shell(cmd) 94 | 95 | 96 | def clear_all(): 97 | panes_to_connect_to = [int(name.split('-')[-1]) for name in construct_cluster_names(cluster_config['nodes'])] 98 | for pane in panes_to_connect_to: 99 | tmux(f'select-pane -t {pane}') 100 | tmux_shell('clear') -------------------------------------------------------------------------------- /old/meshPipelineParallelism/infra/utils.py: 
-------------------------------------------------------------------------------- 1 | # This file handles the establishment and management of a cluster of TPU-VMS 2 | import functools 3 | import os 4 | import requests 5 | import subprocess 6 | from tqdm import tqdm 7 | 8 | 9 | from fabric import Connection 10 | import os 11 | 12 | from infra.config import cluster_config, constant_args 13 | 14 | # @functools.lru_cache() # TODO this can error if it has been a while since it is called. 15 | def get_bearer(): 16 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip() 17 | 18 | def check_tpu(name): 19 | headers = { 20 | 'Authorization': f'Bearer {get_bearer()}', 21 | } 22 | 23 | response = requests.get( 24 | f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes/{name}", 25 | headers=headers) 26 | 27 | return response.json() 28 | 29 | 30 | def list_tpus(): 31 | headers = { 32 | 'Authorization': f'Bearer {get_bearer()}', 33 | } 34 | 35 | response = requests.get( 36 | f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes", 37 | headers=headers) 38 | 39 | return response.json() 40 | 41 | 42 | def create_tpu(name): 43 | print(f"Creating {name}") 44 | headers = { 45 | 'Authorization': f'Bearer {get_bearer()}', 46 | 'Content-Type': 'application/json', 47 | } 48 | 49 | params = ( 50 | ('node_id', name), 51 | ) 52 | 53 | data = {"accelerator_type": 54 | cluster_config['accelerator_type'], 55 | "runtime_version": 56 | 'v2-alpha', 57 | "network_config": 58 | {"enable_external_ips": True}, 59 | "schedulingConfig": 60 | {"preemptible": cluster_config['preemptible']}, 61 | } 62 | 63 | response = requests.post(f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes", 64 | headers=headers, params=params, json=data) 65 | 66 | print(response.json()) 67 | 68 | return response.status_code == 200 69 | 70 | 71 | def start_tpu(name): 72 | print(f"Starting {name}") 73 | headers = { 74 | 'Authorization': f'Bearer {get_bearer()}', 75 | } 76 | 77 | response = requests.post( 78 | f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes/{name}:start", 79 | headers=headers) 80 | 81 | return response.json() 82 | 83 | 84 | def stop_tpu(name): 85 | print(f"Stopping {name}") 86 | headers = { 87 | 'Authorization': f'Bearer {get_bearer()}', 88 | } 89 | 90 | response = requests.delete( 91 | f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes/{name}:stop", 92 | headers=headers) 93 | 94 | return response.json() 95 | 96 | 97 | def delete_tpu(name): 98 | print(f"Deleting {name}") 99 | headers = { 100 | 'Authorization': f'Bearer {get_bearer()}', 101 | } 102 | 103 | response = requests.delete( 104 | f"https://tpu.googleapis.com/v2alpha1/projects/{cluster_config['project']}/locations/{cluster_config['zone']}/nodes/{name}", 105 | headers=headers) 106 | 107 | return response.json() 108 | 109 | def construct_cluster_names(N: int): 110 | return [f"{cluster_config['name']}-{n}" for n in range(0, N)] 111 | 112 | def scale_cluster(): 113 | existing_tpus = list_tpus().get('nodes', []) 114 | 115 | # determine what we need to create 116 | to_construct = construct_cluster_names(cluster_config['nodes']) 117 | 118 | # start the existing, but stopped nodes 119 | for node in existing_tpus: 120 | name = 
node['name'].split('/')[-1] 121 | # if it is the correct type of node 122 | if name in to_construct and node['acceleratorType'] == cluster_config['accelerator_type']: 123 | # remove it from what we need to construct 124 | to_construct.remove(name) 125 | # start any that are stopped 126 | if node['state'] == 'STOPPED': 127 | start_tpu(name) 128 | if node['state'] == 'PREEMPTED': 129 | delete_tpu(name) 130 | start_tpu(name) 131 | 132 | # create the remainder 133 | for remaining in to_construct: 134 | res = create_tpu(remaining) # TODO: if res failed. 135 | 136 | 137 | 138 | def validate_cluster(): 139 | print('Validating cluster creation') 140 | for name in tqdm(construct_cluster_names(cluster_config['nodes'])): 141 | try: 142 | if check_tpu(name)['state'] != 'READY': 143 | print(f"Failed: {name}") 144 | except: 145 | print(f"TPU {name} not found") 146 | 147 | 148 | def chunks(l, n): 149 | n = max(1, n) 150 | return (l[i:i+n] for i in range(0, len(l), n)) 151 | 152 | 153 | def get_pipelines(): 154 | ''' 155 | Takes the currently active nodes and arranges them into full length pipelines 156 | If there are insufficient nodes for the final pipeline, they are ignored. 157 | ''' 158 | tpus = list_tpus().get('nodes', []) 159 | 160 | if len(tpus) < cluster_config['pipeline_length']: 161 | raise Exception('Insufficient Devices to form a pipeline') 162 | 163 | # create a connection object for each tpu to allow for easy file copy 164 | for tpu in tpus: 165 | tpu['connection_object'] = Connection(tpu['networkEndpoints'][0]['accessConfig']['externalIp'], connect_kwargs={ 166 | "key_filename": os.path.expanduser('~/.ssh/google_compute_engine'), }) 167 | 168 | # arrange them in pipelines 169 | complete_pipelines = [p for p in chunks(tpus, cluster_config['pipeline_length']) if len(p) == cluster_config['pipeline_length']] 170 | 171 | return complete_pipelines 172 | 173 | 174 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/requirements.txt: -------------------------------------------------------------------------------- 1 | ray[default] 2 | fabric -------------------------------------------------------------------------------- /old/meshPipelineParallelism/steps.md: -------------------------------------------------------------------------------- 1 | # Step 0 2 | 3 | Create a config at infra/config.py 4 | 5 | 6 | ``` 7 | # generate the tmuxinator pane 8 | python infra/config.py 9 | 10 | # run tmuxinator to set up the work space 11 | tmuxinator start scaling -p infra/.tmuxinator.yaml 12 | 13 | ``` 14 | 15 | 16 | # Step 1 17 | 18 | Create a cluster, which will have 1 leader (coordinating node), and N worker nodes. 
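For example (a minimal sketch, assuming the helpers in `infra/utils.py` - `tests.ipynb` drives the same functions interactively):

```python
from infra.utils import scale_cluster, validate_cluster, get_pipelines

scale_cluster()              # create missing TPU-VMs, restart stopped ones, recreate pre-empted ones
validate_cluster()           # report any node that is not yet READY
pipelines = get_pipelines()  # once nodes are READY: group them into pipelines of cluster_config['pipeline_length']
```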
19 | 20 | 21 | ``` 22 | python 23 | ``` 24 | -------------------------------------------------------------------------------- /old/meshPipelineParallelism/tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from infra.utils import construct_cluster_names, scale_cluster, list_tpus, check_tpu, validate_cluster, get_pipelines\n", 10 | "from infra.config import cluster_config, constant_args\n", 11 | "from infra.sentry import connect, run, clear_all, LEADER, tmux, tmux_shell\n", 12 | "\n", 13 | "from tqdm import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 4, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "{'nodes': 4,\n", 25 | " 'pipeline_length': 2,\n", 26 | " 'name': 'test',\n", 27 | " 'project': 'learning-from-play-303306',\n", 28 | " 'accelerator_type': 'v2-8',\n", 29 | " 'zone': 'us-central1-f',\n", 30 | " 'preemptible': False,\n", 31 | " 'redis_password': '5241590000000000'}" 32 | ] 33 | }, 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "output_type": "execute_result" 37 | } 38 | ], 39 | "source": [ 40 | "\n", 41 | "# Our primary maintenence loop needs to do the following:\n", 42 | "\n", 43 | "# Get a list of TPUs\n", 44 | "\n", 45 | "# If a TPU is active and has a spot in the pipelines then do nothing\n", 46 | "\n", 47 | "# If a TPU is not active, and it is in a pipeline, delete the TPU and halt all training loops in that pipeline.\n", 48 | "# (It was pre-empted and the others will be hanging waiting for it) Maybe we can set it up so that if a training loop is interrupted, it catches the \n", 49 | "cluster_config" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 5, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "ename": "NameError", 59 | "evalue": "name 'operating_tpus' is not defined", 60 | "output_type": "error", 61 | "traceback": [ 62 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 63 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 64 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0moperating_tpus\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 65 | "\u001b[0;31mNameError\u001b[0m: name 'operating_tpus' is not defined" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "operating_tpus[0]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 6, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Creating test-0\n", 83 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1645434281611-5d88385db5091-ad76b03e-fd6cd594', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-21T09:04:43.377054859Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test-0', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n", 84 | "Creating test-1\n", 85 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1645434286026-5d883861eaf96-27cef1d1-c627bef3', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 
'createTime': '2022-02-21T09:04:47.730907008Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test-1', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n", 86 | "Creating test-2\n", 87 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1645434290192-5d883865e3f7f-5b329f40-05f712f0', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-21T09:04:52.013019836Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test-2', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n", 88 | "Creating test-3\n", 89 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1645434294773-5d88386a42552-9ce0d4d0-ce508366', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-21T09:04:56.575477024Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test-3', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "scale_cluster()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 8, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# the next thing we need to do is look at avilable to make and create config\n", 104 | "\n", 105 | "\n", 106 | "valid_pipelines = get_pipelines()\n", 107 | "operating_tpus = [tpu for pipeline in valid_pipelines for tpu in pipeline]" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 9, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "connect()\n" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 27, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "run('cd scalingExperiments/pipelineParallelism')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 25, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "run('git clone https://github.com/sholtodouglas/scalingExperiments')\n", 135 | "run('cd scalingExperiments/pipelineParallelism')\n", 136 | "run('pip install -U \"ray[default]\"')\n", 137 | "# add the ray scripts to path - they are not accessible by default\n", 138 | "run('echo \"export PATH=\"$HOME/.local/bin:$PATH\"\" >> ~/.bashrc')\n", 139 | "run('source ~/.bashrc')" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 32, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "run('sudo apt-get update')\n", 149 | "run('sudo apt install -y libopenmpi-dev')\n", 150 | "run('pip install mpi4jax') # Todo this often doesn't get called .. I think b/c of the y?" 
151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 17, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def copy_file(path, dest):\n", 165 | " for tpu in tqdm(operating_tpus):\n", 166 | " tpu['connection_object'].put(path, dest)\n", 167 | "\n" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 73, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | " 0%| | 0/4 [00:00> .ssh/authorized_keys')\n", 252 | "\n", 253 | "subprocess.run(\"rm id_rsa id_rsa.pub\", shell=True)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 74, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "with open('infra/host_list', 'w') as f:\n", 263 | " for idx, tpu in enumerate(operating_tpus):\n", 264 | " # f.write(+'\\n')\n", 265 | " f.write(tpu['networkEndpoints'][0]['ipAddress']+ ' slots=1\\n')\n", 266 | "\n", 267 | "copy_file('infra/host_list', 'scalingExperiments/pipelineParallelism/infra') \n" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 30, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stderr", 277 | "output_type": "stream", 278 | "text": [ 279 | "100%|██████████| 4/4 [00:05<00:00, 1.39s/it]\n" 280 | ] 281 | } 282 | ], 283 | "source": [ 284 | "copy_file('infra/mpi_test.py', 'scalingExperiments/pipelineParallelism/infra')" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 106, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "run('mpirun --hostfile infra/host_list python3 infra/mpi_test.py',0)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 95, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stderr", 303 | "output_type": "stream", 304 | "text": [ 305 | "100%|██████████| 4/4 [00:06<00:00, 1.56s/it]\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "copy_file('infra/mpi_jax_test.py', 'scalingExperiments/pipelineParallelism/infra')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 126, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "run('mpirun --hostfile infra/host_list python3 infra/mpi_jax_test.py',0)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 118, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "run('^C')" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 116, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "run('pip install \"jax[tpu]>=0.2.16\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html')" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 125, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "run(\"pip install --upgrade 'jax[cpu]'\")" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "#############################################################################################################################\n", 361 | "run('ray start --head --port=6379', LEADER)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [], 369 | 
"source": [ 370 | "# set up ray on all the follower nodes\n", 371 | "for idx, tpu in enumerate(operating_tpus):\n", 372 | " if idx != LEADER:\n", 373 | " run(f\"ray start --address='{operating_tpus[LEADER]['networkEndpoints'][0]['ipAddress']}:6379' \\\n", 374 | " --redis-password='{cluster_config['redis_password']}'\", \\\n", 375 | " idx)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "run('python3 infra/ray_test.py')" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 78, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "data": { 394 | "text/plain": [ 395 | "'--zone=us-central1-f --project=learning-from-play-303306'" 396 | ] 397 | }, 398 | "execution_count": 78, 399 | "metadata": {}, 400 | "output_type": "execute_result" 401 | } 402 | ], 403 | "source": [ 404 | "constant_args" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "run('ray stop')" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "10.128.15.204, 10.128.15.207, 10.128.15.201, 10.128.15.203" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "mpirun -np 1 --hostfile infra/host_list --host t1v-n-e23727e8-w-0 python3 infra/mpi_test.py" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 88, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "run('ray stop')" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "run('pwd')" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "text/plain": [ 458 | "{'error': {'code': 401,\n", 459 | " 'message': 'Request had invalid authentication credentials. Expected OAuth 2 access token, login cookie or other valid authentication credential. 
378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "run('python3 infra/ray_test.py')" 385 | ] 386 | },
387 | { 388 | "cell_type": "code", 389 | "execution_count": 78, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "data": { 394 | "text/plain": [ 395 | "'--zone=us-central1-f --project=learning-from-play-303306'" 396 | ] 397 | }, 398 | "execution_count": 78, 399 | "metadata": {}, 400 | "output_type": "execute_result" 401 | } 402 | ], 403 | "source": [ 404 | "constant_args" 405 | ] 406 | },
407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "run('ray stop')" 412 | ] 413 | },
414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "10.128.15.204, 10.128.15.207, 10.128.15.201, 10.128.15.203" 421 | ] 422 | },
423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "mpirun -np 1 --hostfile infra/host_list --host t1v-n-e23727e8-w-0 python3 infra/mpi_test.py" 430 | ] 431 | },
432 | { 433 | "cell_type": "code", 434 | "execution_count": 88, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "run('ray stop')" 439 | ] 440 | },
441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "run('pwd')" 448 | ] 449 | },
450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "text/plain": [ 458 | "{'error': {'code': 401,\n", 459 | " 'message': 'Request had invalid authentication credentials. Expected OAuth 2 access token, login cookie or other valid authentication credential. See https://developers.google.com/identity/sign-in/web/devconsole-project.',\n", 460 | " 'status': 'UNAUTHENTICATED',\n", 461 | " 'details': [{'@type': 'type.googleapis.com/google.rpc.ErrorInfo',\n", 462 | " 'reason': 'ACCESS_TOKEN_TYPE_UNSUPPORTED',\n", 463 | " 'metadata': {'method': 'google.cloud.tpu.v2alpha1.Tpu.ListNodes',\n", 464 | " 'service': 'tpu.googleapis.com'}}]}}" 465 | ] 466 | }, 467 | "execution_count": 11, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "gcloud alpha compute scp infra/ray_test.py test-0:~/$HOME/" 474 | ] 475 | },
[lines 476-734 elided: output garbled during extraction. The elided cells contained a tqdm-style progress bar over 6 items streamed to stderr, and what appears to be a long psutil socket-connection (sconn) listing returned as an execute_result (execution_count 170) whose source had been cleared.]
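{ "cell_type": "markdown", "metadata": {}, "source": [ "The `ports` value echoed in the next cell appears to be the local port numbers pulled out of that elided connection listing. A minimal sketch of how such a listing can be rebuilt, assuming `psutil` is available (the original source for these cells was cleared, so this is not the code that actually produced the output):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: list the local ports currently in use on this host.\n", "# psutil.net_connections() returns sconn namedtuples (fd, family, type, laddr, raddr, status, pid);\n", "# laddr is an addr(ip, port) namedtuple when the socket is bound, else an empty tuple.\n", "import psutil\n", "\n", "conns = psutil.net_connections(kind='inet')\n", "ports = [c.laddr.port for c in conns if c.laddr]\n", "ports" ] },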
735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "metadata": {}, 739 | "outputs": [ 740 | { 741 | "data": { 742 | "text/plain": [
743 | "[43975,\n", 744 | " 44969,\n", 745 | " 3080,\n", 746 | " 45942,\n", 747 | " 3085,\n", 748 | " 7000,\n", 749 | " 8888,\n", 750 | " 45826,\n", 751 | " 48506,\n", 752 | " 5353,\n", 753 | " 36655,\n", 754 | " 45690,\n", 755 | " 54706,\n", 756 | " 40594,\n", 757 | " 39853,\n", 758 | " 48355,\n", 759 | " 56992,\n", 760 | " 49336,\n", 761 | " 5353,\n", 762 | " 8892,\n", 763 | " 55647,\n", 764 | " 8888,\n", 765 | " 50414,\n", 766 | " 59176,\n", 767 | " 46406,\n", 768 | " 46160,\n", 769 | " 55619,\n", 770 | " 631,\n", 771 | " 48030,\n", 772 | " 48286,\n", 773 | " 8805,\n", 774 | " 5353,\n", 775 | " 53,\n", 776 | " 48707,\n", 777 | " 54704,\n", 778 | " 52966,\n", 779 | " 8888,\n", 780 | " 43975,\n", 781 | " 53,\n", 782 | " 39432,\n", 783 | " 39467,\n", 784 | " 8888,\n",
785 | " 59235,\n", 786 | " 33198,\n", 787 | " 3333,\n", 788 | " 5353,\n", 789 | " 60801,\n", 790 | " 45589,\n", 791 | " 68,\n", 792 | " 57724,\n", 793 | " 36203,\n", 794 | " 8888,\n", 795 | " 33190,\n", 796 | " 55619,\n", 797 | " 36655,\n", 798 | " 33562,\n", 799 | " 57408,\n", 800 | " 45698,\n", 801 | " 47344,\n", 802 | " 8888,\n", 803 | " 56662,\n", 804 | " 58866,\n", 805 | " 8800,\n", 806 | " 56022,\n", 807 | " 8805,\n", 808 | " 41590,\n", 809 | " 8888,\n", 810 | " 631,\n", 811 | " 8888,\n", 812 | " 40116,\n", 813 | " 39242,\n", 814 | " 41584,\n", 815 | " 8888,\n", 816 | " 36636,\n", 817 | " 41092,\n", 818 | " 38478,\n", 819 | " 3333,\n", 820 | " 48824,\n", 821 | " 8891,\n", 822 | " 50280,\n", 823 | " 8888,\n", 824 | " 39853,\n", 825 | " 45730,\n", 826 | " 56660,\n", 827 | "
8800,\n", 828 | " 36655,\n", 829 | " 52467,\n", 830 | " 631,\n", 831 | " 33010,\n", 832 | " 3333,\n", 833 | " 56020,\n", 834 | " 46410,\n", 835 | " 39434,\n", 836 | " 8888,\n", 837 | " 58198,\n", 838 | " 54368,\n", 839 | " 8888,\n", 840 | " 8888,\n", 841 | " 8888,\n", 842 | " 8888,\n", 843 | " 58668,\n", 844 | " 34056,\n", 845 | " 47328,\n", 846 | " 33906,\n", 847 | " 39853]" 848 | ] 849 | }, 850 | "execution_count": 181, 851 | "metadata": {}, 852 | "output_type": "execute_result" 853 | } 854 | ], 855 | "source": [ 856 | "\n", 857 | "ports" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": null, 863 | "metadata": {}, 864 | "outputs": [], 865 | "source": [ 866 | "make_tmux(6)" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "metadata": {}, 873 | "outputs": [ 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "Deleting test_1\n", 879 | "Starting test_1\n", 880 | "Deleting test_3\n", 881 | "Starting test_3\n", 882 | "Creating test_4\n", 883 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1643935665172-5d726995b2f90-f51e4767-e303dfb7', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-04T00:47:46.994482809Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test_4', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n", 884 | "Creating test_5\n", 885 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1643935669267-5d7269999ace0-c9e483cb-0308fdfa', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-04T00:47:50.986263737Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test_5', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n" 886 | ] 887 | } 888 | ], 889 | "source": [ 890 | "scale_cluster(6)" 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "execution_count": null, 896 | "metadata": {}, 897 | "outputs": [ 898 | { 899 | "data": { 900 | "text/plain": [ 901 | "{'error': {'code': 404,\n", 902 | " 'message': \"Resource 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test_1' was not found\",\n", 903 | " 'status': 'NOT_FOUND',\n", 904 | " 'details': [{'@type': 'type.googleapis.com/google.rpc.ResourceInfo',\n", 905 | " 'resourceName': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test_1'}]}}" 906 | ] 907 | }, 908 | "execution_count": 71, 909 | "metadata": {}, 910 | "output_type": "execute_result" 911 | } 912 | ], 913 | "source": [ 914 | "check_tpu('test_1')" 915 | ] 916 | }, 917 | { 918 | "cell_type": "code", 919 | "execution_count": null, 920 | "metadata": {}, 921 | "outputs": [ 922 | { 923 | "name": "stdout", 924 | "output_type": "stream", 925 | "text": [ 926 | "Creating test_1\n", 927 | "{'name': 'projects/learning-from-play-303306/locations/us-central1-f/operations/operation-1643889519283-5d71bdad8cecc-cf195c1f-86a70427', 'metadata': {'@type': 'type.googleapis.com/google.cloud.common.OperationMetadata', 'createTime': '2022-02-03T11:58:41.098872920Z', 'target': 'projects/learning-from-play-303306/locations/us-central1-f/nodes/test_1', 'verb': 'create', 'cancelRequested': False, 'apiVersion': 'v2alpha1'}, 'done': False}\n" 928 | ] 929 | } 930 | ], 931 | "source": [ 932 | "scale_cluster(4)" 933 | ] 934 | }, 935 | { 936 | 
"cell_type": "code", 937 | "execution_count": null, 938 | "metadata": {}, 939 | "outputs": [ 940 | { 941 | "name": "stderr", 942 | "output_type": "stream", 943 | "text": [ 944 | " 0%| | 0/6 [00:00