├── .gitignore ├── Distributed.html ├── Distributed.ipynb ├── Distributed.py ├── LICENSE ├── README.md ├── drawing.py ├── lib.py ├── puzzles.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Distributed.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # cell_metadata_filter: -all 5 | # custom_cell_magics: kql 6 | # text_representation: 7 | # extension: .py 8 | # format_name: light 9 | # format_version: '1.5' 10 | # jupytext_version: 1.14.6 11 | # kernelspec: 12 | # display_name: venv 13 | # language: python 14 | # name: python3 15 | # --- 16 | 17 | # # LLM Training Puzzles 18 | # 19 | # by Sasha Rush ([@srush_nlp](https://twitter.com/srush_nlp)) 20 | 21 | # %%capture 22 | # Uncomment to run in Colab 23 | # !pip install -qqq git+https://github.com/chalk-diagrams/chalk asyncio 24 | # !wget https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/lib.py https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/drawing.py 25 | 26 | from typing import List 27 | from lib import Model, Dist, WeightGrad 28 | from drawing import draw, draw_group 29 | from chalk import vcat 30 | import asyncio 31 | import chalk 32 | chalk.set_svg_height(400) 33 | chalk.set_svg_draw_height(600) 34 | 35 | # ## Preliminaries 36 | # 37 | # The goal of these puzzles is to learn about distributed training of LLMs. However, we will be primarily concerned with a speed and memory efficiency of completing a single update of the models. To make things simpler, we will abstract away from the standard tensor-based transformer model, and just consider a state-less representation of each of the components of a multi-layer neural network. 38 | # 39 | # 40 | 41 | model = Model(layers=2, batches=4) 42 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 43 | 44 | # Our library has 5 parts: 45 | # 46 | # * Weights 47 | # * Optimizer States - Values needed to update the weights 48 | # * Activations - The internal values computed on the forward pass 49 | # * Grad Activations - The gradients of the loss wrt to activations, needed for backward pass 50 | # * Grad Weights - The gradients of the loss wrt to weights, needed for updates 51 | # 52 | # For these puzzles, you are *not allowed* to have local variables. You need to store each of these in the dictionary corresponding to its type. 53 | # 54 | # We begin by tracing the lifecycle of a single model update. 
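# As a concrete sketch of that storage rule (using only the calls traced below), results
# must be written back into the typed dictionaries rather than held in bare locals:
#
#     activations[0] = model.get_activation(batches=[2, 3])   # stored, so it counts towards memory
#     act = model.get_activation(batches=[2, 3])               # a bare local like this is not allowed
#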
55 | 56 | # Get the input activations to the model for batches 2, 3 57 | activations[0] = model.get_activation(batches=[2, 3]) 58 | activations[0] 59 | 60 | # Load the weights (random) for layers 0 and 1 61 | for i in range(model.LAYERS): 62 | weights[i], opt_states[i] = model.load_weights(i) 63 | weights[0] 64 | 65 | # Activations can be moved forward a layer if you have the weights. 66 | activations[1] = model.forward(layer=0, inp=activations[0], weight=weights[0]) 67 | activations[2] = model.forward(layer=1, inp=activations[1], weight=weights[1]) 68 | activations[1] 69 | 70 | # Draw all the current activations in memory. 71 | draw_group(activations) 72 | 73 | # At the last layer, we can convert an activation to a grad activation by calling `loss` 74 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS]) 75 | grad_activations[model.LAYERS] 76 | 77 | # Calling `backward` requires the forward activation, the backward grad activation, and the weights. 78 | # It returns the grad weights and the backward activation. 79 | grad_weights[1], grad_activations[1] = model.backward(1, activations[1], grad_activations[2], weights[1]) 80 | grad_weights[0], grad_activations[0] = model.backward(0, activations[0], grad_activations[1], weights[0]) 81 | grad_activations[1] 82 | 83 | # We can use delete to remove any memory that is not longer needed. 84 | print("Before memory:", model.memory()) 85 | del grad_activations[1] 86 | print("After memory:", model.memory()) 87 | model.status() 88 | draw_group(grad_activations) 89 | 90 | # Grad weights keep track of which batches they are for. Here we only have the grad weights for batches 2 and 3. 91 | draw_group(grad_weights) 92 | 93 | # If we try to update with the grad weights we will get an error. 94 | try: 95 | model.update(0, weight_grad=grad_weights[0], weight=weights[0], opt_state=opt_states[0]) 96 | except AssertionError as e: 97 | print("Error! Only have batches") 98 | print(e) 99 | 100 | # For this example, we can cheat. Pretend we had the other gradients we needed. 101 | grad_weights[0, 0] = model.fake_grad(0, [0,1]) 102 | grad_weights[1, 0] = model.fake_grad(1, [0,1]) 103 | grad_weights[0, 0] 104 | 105 | 106 | # Summing together grad_weights gives the full gradient. 107 | grad_weights[0] = grad_weights[0] + grad_weights[0, 0] 108 | 109 | # + 110 | # Now we can call update to the get the new weights and opt_state. 111 | weights[0], opt_states[0] = model.update(0, weight_grad=grad_weights[0], weight=weights[0], 112 | opt_state=opt_states[0]) 113 | 114 | # WARNING: You need to set all variables. Otherwise they are not counted towards memory. 115 | grad_weights[1] = grad_weights[1] + grad_weights[1, 0] 116 | weights[1], opt_states[1] = model.update(1, weight_grad=grad_weights[1], 117 | weight=weights[1], opt_state=opt_states[1]) 118 | # - 119 | 120 | 121 | # We can complete the tests by setting these as the final weights and calling check. 122 | model.set_final_weight(0, weights[0]) 123 | model.set_final_weight(1, weights[1]) 124 | Model.check([model]) 125 | draw_group(model.final_weights) 126 | 127 | # We can view the final outcome of the system as a diagram. 128 | # This show the forward and backward passes (numbers of batches) and the updates. 129 | # The lines on the bottom show the memory that is used at each time step. 130 | draw([model]) 131 | 132 | 133 | 134 | # ### Puzzle 0 - Standard Training 135 | # 136 | # Write a standard (non-distributed) training loop that acts on all the batches and loads all the weights. 
It should just run forward, loss, backward, and update. Aim for the least amount of max memory used. 137 | # 138 | # * Target Time: 17 steps 139 | # * Target Memory: 2600000 140 | 141 | def basic(model: Model) -> Model: 142 | # Storage on device. 143 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 144 | 145 | # Load in the full weights 146 | for l in range(model.LAYERS): 147 | weights[l], opt_states[l] = model.load_weights(l) 148 | 149 | # Load the input layer activations 150 | activations[0] = model.get_activation(range(model.BATCHES)) 151 | 152 | ## USER CODE 153 | # Forward 154 | for l in range(model.LAYERS): 155 | activations[l + 1] = model.forward(l, activations[l], weights[l]) 156 | 157 | # Backward 158 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS]) 159 | del activations[model.LAYERS] 160 | 161 | for l in range(model.LAYERS - 1, -1, -1): 162 | grad_weights[l], grad_activations[l] = model.backward( 163 | l, activations[l], grad_activations[l + 1], weights[l] 164 | ) 165 | del grad_activations[l + 1], activations[l] 166 | del grad_activations[0] 167 | assert len(grad_activations) == 0 and len(activations) ==0 168 | 169 | # Update 170 | for l in range(model.LAYERS): 171 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 172 | ## END USER CODE 173 | 174 | for l in range(model.LAYERS): 175 | model.set_final_weight(l, weights[l]) 176 | return model 177 | 178 | 179 | out = basic(Model(layers=2, batches=4, rank=0, dist=Dist(1))) 180 | draw_group(out.final_weights) 181 | 182 | draw([out]) 183 | 184 | Model.check([out]) 185 | 186 | 187 | # ### Puzzle 1 - Gradient Accumulation 188 | # 189 | # For this puzzle, the goal is to reduce max memory usage. To do so you are going to run on each batch individually instead of all together. 190 | # 191 | # Write a function with four parts. First run on batches {0} and then {1} etc. Sum the grad weights and then update. 192 | # 193 | # * Target Time: 17 steps 194 | # * Target Memory: 2000000 195 | 196 | def grad_accum(model: Model) -> Model: 197 | # Storage on device. 
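    # Sketch of the idea behind this puzzle (using only calls from Puzzle 0): the
    # per-batch grad weights computed below add up to the full-batch gradient,
    #   grad_weights[l] == grad_weights[l, 0] + ... + grad_weights[l, BATCHES - 1]
    # so only a single batch of activations needs to be alive at any one time.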
198 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 199 | 200 | # Load in the full weights 201 | for l in range(model.LAYERS): 202 | weights[l], opt_states[l] = model.load_weights(l) 203 | 204 | ## USER CODE 205 | for r in range(model.BATCHES): 206 | # Load the input layer activations 207 | activations[0, r] = model.get_activation([r]) 208 | 209 | ## USER CODE 210 | # Forward 211 | for l in range(model.LAYERS): 212 | activations[l + 1, r] = model.forward(l, activations[l, r], weights[l]) 213 | 214 | # Backward 215 | grad_activations[model.LAYERS, r] = model.loss(activations[model.LAYERS, r]) 216 | del activations[model.LAYERS, r] 217 | 218 | for l in range(model.LAYERS - 1, -1, -1): 219 | grad_weights[l, r], grad_activations[l, r] = model.backward( 220 | l, activations[l, r], grad_activations[l + 1, r], weights[l] 221 | ) 222 | if r == 0: 223 | grad_weights[l] = grad_weights[l, r] 224 | else: 225 | grad_weights[l] = grad_weights[l] + grad_weights[l, r] 226 | del grad_activations[l + 1, r], activations[l,r], grad_weights[l, r] 227 | del grad_activations[0, r] 228 | assert len(grad_activations) == 0 and len(activations) == 0 229 | 230 | # Update 231 | for l in range(model.LAYERS): 232 | weights[l], opt_states[l] = \ 233 | model.update(l, 234 | grad_weights[l], weights[l], opt_states[l]) 235 | 236 | ## END USER CODE 237 | for l in range(model.LAYERS): 238 | model.set_final_weight(l, weights[l]) 239 | return model 240 | 241 | 242 | out = grad_accum(Model(layers=2, batches=4, rank=0, dist=Dist(1))) 243 | draw_group(out.final_weights) 244 | 245 | draw([out]) 246 | 247 | Model.check([out]) 248 | 249 | # ## Communications: AllReduce 250 | 251 | # When working with multiple GPUs we need to have communication. 252 | # The primary communication primitives for GPUs are implemented in NCCL. 253 | # 254 | # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html 255 | # 256 | # We are not going to use these directly, but simulate them using Python and asyncio. 257 | # 258 | # The first operation is AllReduce. We will have 4 GPUs (ranks=4) and use them each to compute a batch of weight grads. 259 | 260 | ranks = 4 261 | weight_grads = [WeightGrad(0, 1, {i}, ranks) for i in range(ranks)] 262 | weight_grads[0] + weight_grads[1] + weight_grads[2] + weight_grads[3] 263 | 264 | # Simple asynchronous function that calls allreduce to sum the weight grads at layer 0 265 | async def myfunc(model: Model) -> WeightGrad: 266 | return await model.allreduce(weight_grads[model.rank], 0) 267 | 268 | # This code uses asyncio to run the above function on 4 "GPUs" . 269 | dist = Dist(ranks) 270 | out_weight_grads = await asyncio.gather(*[ 271 | myfunc(Model(layers=1, batches=1, rank=i, dist=dist)) 272 | for i in range(ranks)]) 273 | out_weight_grads[0] 274 | 275 | # Note: When running communication operations like AllReduce on a GPU, the communication happens in parallel to the computation on that GPU. That means the API for AllReduce does not block, and allows the model to continue running while waiting for this command to run. This means it is beneficial to run AllReduce (and other communication) as early as possible so that other compute can be run during the reduction. 276 | # 277 | # We will ignore this in these puzzles and represent communication as happening efficiently. 278 | 279 | # ### Puzzle 2 - Distributed Data Parallel 280 | # 281 | # Write a function with four parts. First run on batches {0} and then {1} etc. Sum the grad weights and then update. 
The main benefit of this approach is compute efficiency over gradient accumulation. 282 | # 283 | # * Total Steps: 5 284 | # * Total Memory: 1800000 285 | 286 | async def ddp(model: Model) -> Model: 287 | # Storage on device. 288 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 289 | # Load all the activations 290 | model.activations[0] = model.get_activation([model.rank]) 291 | 292 | ## USER CODE 293 | 294 | # Load in the full weights 295 | for l in range(model.LAYERS): 296 | weights[l], opt_states[l] = model.load_weights(l) 297 | 298 | # Forward 299 | for l in range(model.LAYERS): 300 | activations[l + 1] = model.forward(l, activations[l], weights[l]) 301 | 302 | # Backward 303 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS]) 304 | 305 | for l in range(model.LAYERS - 1, -1, -1): 306 | grad_weights[l], grad_activations[l] = model.backward( 307 | l, activations[l], grad_activations[l + 1], weights[l] 308 | ) 309 | del grad_activations[l + 1], activations[l] 310 | 311 | # Update 312 | for l in range(model.LAYERS): 313 | grad_weights[l] = await model.allreduce(grad_weights[l], l) 314 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 315 | 316 | ## END USER CODE 317 | for l in range(model.LAYERS): 318 | model.set_final_weight(l, weights[l]) 319 | return model 320 | 321 | 322 | dist = Dist(ranks) 323 | out = await asyncio.gather(*[ 324 | ddp(Model(layers=2, batches=ranks, rank=i, dist=dist)) 325 | for i in range(ranks)]) 326 | draw_group(out[0].final_weights) 327 | 328 | draw(out) 329 | 330 | Model.check(out) 331 | 332 | # ## Communication: AllGather / Sharding 333 | # 334 | # Our next primitive is AllGather. This allows us to communicate "shards" of an object stored on different GPUs to all the GPUs. 335 | 336 | # Load only part of a weights. 337 | model = Model(layers=2, batches=1, rank=0, dist=Dist(1)) 338 | weight, _ = model.load_weights(0, shard=0, total=4) 339 | weight 340 | 341 | # Combine togegher two shards on one machine. 342 | weights = [model.load_weights(0, shard=i, total=ranks)[0] for i in range(ranks)] 343 | weights[0].combine(weights[2]) 344 | 345 | # + 346 | # Use allgather to collect the shards from all machines. 347 | async def mygather(model: Model) -> WeightGrad: 348 | # Allreduce sums together all the weight grads 349 | return await model.allgather(weights[model.rank], 0) 350 | 351 | dist = Dist(ranks) 352 | out_weights = await asyncio.gather(*[ 353 | mygather(Model(layers=1, batches=1, rank=i, dist=dist)) 354 | for i in range(ranks)]) 355 | out_weights[0] 356 | # - 357 | 358 | # ### Puzzle 3: Weight-Sharded Data Parallel 359 | # 360 | # Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Finally update the weights on each machine using allreduce. 361 | # 362 | # * Total Steps: 20 363 | # * Total Memory: 2800000 364 | 365 | async def wsdp(model: Model) -> Model: 366 | # Storage on device. 367 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 368 | 369 | # Load all the activations 370 | model.activations[0] = model.get_activation([model.rank]) 371 | 372 | # Load a shard of the weights for every layer. 
Load in the full optimizer states 373 | for l in range(model.LAYERS): 374 | weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) 375 | 376 | ## USER CODE 377 | # Forward 378 | for l in range(model.LAYERS): 379 | weights[l, 0] = await model.allgather(weights[l], l) 380 | activations[l + 1] = model.forward(l, activations[l], weights[l, 0]) 381 | del weights[l, 0] 382 | 383 | # Backward 384 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS]) 385 | 386 | for l in range(model.LAYERS - 1, -1, -1): 387 | weights[l, 0] = await model.allgather(weights[l], l) 388 | grad_weights[l], grad_activations[l] = model.backward( 389 | l, activations[l], grad_activations[l + 1], weights[l, 0] 390 | ) 391 | del grad_activations[l + 1], activations[l], weights[l, 0] 392 | 393 | # Update 394 | for l in range(model.LAYERS): 395 | grad_weights[l] = await model.allreduce(grad_weights[l], l) 396 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 397 | 398 | ## END USER CODE 399 | for l in range(model.LAYERS): 400 | model.set_final_weight(l, weights[l]) 401 | 402 | return model 403 | 404 | dist = Dist(ranks) 405 | out = await asyncio.gather(*[ 406 | wsdp(Model(layers=6, batches=ranks, rank=i, dist=dist)) 407 | for i in range(ranks)]) 408 | draw_group(out[1].final_weights) 409 | 410 | draw(out) 411 | 412 | Model.check(out) 413 | 414 | # ## Communication: Scatter-Reduce 415 | 416 | # Scatter across shards 417 | # Reduce across batches 418 | 419 | grad_weight = WeightGrad(0, 1, batches={1}, total_batches=4, 420 | shards={1}, total=4) 421 | grad_weight 422 | 423 | grad_weights = {i: WeightGrad(0, 1, batches={i}, total_batches=4, 424 | shards={0,1,2,3}, total=4) for i in range(4)} 425 | grad_weights[2] 426 | 427 | # + 428 | async def scatterreduce(model: Model) -> WeightGrad: 429 | # Allreduce sums together all the weight grads 430 | return await model.scatterreduce(grad_weights[model.rank], 0) 431 | 432 | dist = Dist(ranks) 433 | out = await asyncio.gather(*[ 434 | scatterreduce(Model(layers=1, batches=1, rank=i, dist=dist)) 435 | for i in range(ranks)]) 436 | out[0] 437 | # - 438 | 439 | 440 | 441 | # ### Puzzle 4: Fully-Sharded Data Parallel 442 | # 443 | # Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Collect the gradients with scatter-reduce. 444 | # 445 | # * Total Steps: 20 446 | # * Total Memory: 2300000 447 | 448 | async def fsdp(model: Model) -> Model: 449 | # Storage on device. 450 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 451 | 452 | # Load all the activations 453 | model.activations[0] = model.get_activation([model.rank]) 454 | 455 | # Load a shard of the weights for every layer. 
Load in the full weights 456 | for l in range(model.LAYERS): 457 | weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) 458 | 459 | ## USER CODE 460 | # Forward 461 | for l in range(model.LAYERS): 462 | weights[l, 0] = await model.allgather(weights[l], l) 463 | activations[l + 1] = model.forward(l, activations[l], weights[l, 0]) 464 | del weights[l, 0] 465 | 466 | # Backward 467 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS]) 468 | del(activations[model.LAYERS]) 469 | 470 | for l in range(model.LAYERS - 1, -1, -1): 471 | weights[l, 0] = await model.allgather(weights[l], l) 472 | grad_weights[l], grad_activations[l] = model.backward( 473 | l, activations[l], grad_activations[l + 1], weights[l, 0] 474 | ) 475 | grad_weights[l] = await model.scatterreduce(grad_weights[l], l) 476 | del grad_activations[l + 1], activations[l], weights[l, 0] 477 | 478 | # Update 479 | for l in range(model.LAYERS): 480 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 481 | 482 | ## END USER CODE 483 | for l in range(model.LAYERS): 484 | model.set_final_weight(l, weights[l]) 485 | return model 486 | 487 | 488 | dist = Dist(ranks) 489 | out = await asyncio.gather(*[ 490 | fsdp(Model(layers=6, batches=ranks, rank=i, dist=dist)) 491 | for i in range(ranks)]) 492 | draw_group(out[1].final_weights) 493 | 494 | draw(out) 495 | 496 | Model.check(out) 497 | 498 | # ## Communication: Point-to-Point 499 | # 500 | # An alternative approach to communication is to directly communicate specific information between GPUs. In our model, both GPUs talking to each other block and wait for the handoff. 501 | 502 | # + 503 | async def talk(model: Model) -> None: 504 | if model.rank == 0: 505 | await model.pass_to(1, "extra cheese") 506 | val = await model.receive() 507 | print(val) 508 | else: 509 | val = await model.receive() 510 | print(val) 511 | val = await model.pass_to(0, "pizza") 512 | 513 | dist = Dist(2) 514 | result = await asyncio.gather(*[ 515 | talk(Model(layers=1, batches=1, rank=i, dist=dist)) 516 | for i in range(2)]) 517 | # - 518 | 519 | 520 | # ### Puzzle 5: Pipeline Parallelism 521 | # 522 | # Split the layer weights and optimizers equally between GPUs. Have each GPU handle only its layer. Pass the full set of batches for activations and grad_activations between layers using p2p communication. No need for any global communication. 
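# (Concretely, the scaffold below sets per_rank = model.LAYERS // model.RANKS and
# gives rank r the contiguous block of layers starting at r * per_rank; for
# example, with 8 layers on 4 GPUs, rank 1 owns layers 2 and 3.)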
523 | # 524 | # * Total Steps: 66 525 | # * Total Memory: 3300000 526 | 527 | async def pipeline(model: Model) -> Model: 528 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 529 | per_rank = model.LAYERS // model.RANKS 530 | my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)]) 531 | for l in my_layers: 532 | weights[l], opt_states[l] = model.load_weights(l) 533 | ## USER CODE 534 | 535 | if model.rank == 0: 536 | activations[0] = model.get_activation(range(model.BATCHES)) 537 | else: 538 | activations[my_layers[0]] = await model.receive() 539 | 540 | # Forward 541 | for l in my_layers: 542 | activations[l + 1] = model.forward(l, activations[l], weights[l]) 543 | 544 | # Backward 545 | if model.rank == model.RANKS - 1: 546 | grad_activations[model.LAYERS] = model.loss( 547 | activations[model.LAYERS] 548 | ) 549 | else: 550 | await model.pass_to(model.rank + 1, activations[l + 1]) 551 | grad_activations[l + 1] = await model.receive() 552 | 553 | for l in reversed(my_layers): 554 | grad_weights[l], grad_activations[l] = model.backward( 555 | l, activations[l], grad_activations[l + 1], model.weights[l] 556 | ) 557 | del model.grad_activations[l + 1], model.activations[l] 558 | 559 | if model.rank != 0: 560 | await model.pass_to(model.rank - 1, grad_activations[l]) 561 | 562 | # Update 563 | for l in my_layers: 564 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 565 | 566 | ## END USER CODE 567 | for l in my_layers: 568 | model.set_final_weight(l, weights[l]) 569 | return model 570 | 571 | 572 | dist = Dist(ranks) 573 | out = await asyncio.gather(*[ 574 | pipeline(Model(layers=8, batches=ranks, rank=i, dist=dist)) 575 | for i in range(ranks)]) 576 | draw_group(out[1].final_weights) 577 | 578 | draw(out) 579 | 580 | Model.check(out) 581 | 582 | # ### Puzzle 6: GPipe Schedule 583 | # 584 | # A major issue with the pipeline approach is that it causes a "bubble", i.e. time in the later layers waiting for the earlier layers to complete. An alternative approach is to split the batches smaller so you can pass them earlier. 585 | # 586 | # In this puzzle, you should run each batch by itself, and then pass. The graph should look similar as the one above but with a smaller bubble. 
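# (The schedule below forwards one microbatch at a time and passes it on
# immediately, so the next stage can start while earlier microbatches are still
# in flight; the per-microbatch grad weights are then summed before the update,
# just as in gradient accumulation.)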
587 | # 588 | # * Total Steps: 33 589 | # * Total Memory: 4100000 590 | 591 | async def gpipe(model: Model) -> Model: 592 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 593 | per_rank = model.LAYERS // model.RANKS 594 | my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)]) 595 | for l in my_layers: 596 | weights[l], opt_states[l] = model.load_weights(l) 597 | 598 | # USER CODE 599 | for mb in range(model.BATCHES): 600 | # Forward 601 | if model.rank == 0: 602 | activations[0, mb] = model.get_activation([mb]) 603 | else: 604 | activations[my_layers[0], mb] = await model.receive() 605 | 606 | for l in my_layers: 607 | activations[l + 1, mb] = model.forward(l, activations[l, mb], weights[l]) 608 | if model.rank != model.RANKS - 1: 609 | await model.pass_to(model.rank + 1, activations[l + 1, mb]) 610 | 611 | for mb in range(model.BATCHES): 612 | # Backward 613 | if model.rank == model.RANKS - 1: 614 | grad_activations[model.LAYERS, mb] = model.loss( 615 | activations[model.LAYERS, mb] 616 | ) 617 | else: 618 | grad_activations[my_layers[-1] + 1, mb] = await model.receive() 619 | 620 | for l in reversed(my_layers): 621 | grad_weights[l, mb], grad_activations[l, mb] = model.backward( 622 | l, activations[l, mb], grad_activations[l + 1, mb], weights[l] 623 | ) 624 | del grad_activations[l + 1, mb], activations[l, mb] 625 | 626 | if model.rank != 0: 627 | await model.pass_to(model.rank - 1, grad_activations[l, mb]) 628 | 629 | # Update 630 | for l in reversed(my_layers): 631 | for mb in range(model.BATCHES): 632 | if mb != 0: 633 | grad_weights[l] = grad_weights[l] + grad_weights[l, mb] 634 | else: 635 | grad_weights[l] = grad_weights[l, 0] 636 | del grad_weights[l, mb] 637 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l]) 638 | 639 | ## END USER CODE 640 | for l in my_layers: 641 | model.set_final_weight(l, weights[l]) 642 | 643 | return model 644 | 645 | 646 | dist = Dist(ranks) 647 | out = await asyncio.gather(*[ 648 | gpipe(Model(layers=8, batches=ranks, rank=i, dist=dist)) 649 | for i in range(ranks)]) 650 | draw_group(out[1].final_weights) 651 | 652 | draw(out) 653 | 654 | Model.check(out) 655 | 656 | 657 | # ### Puzzle 7: Pipeline + FSDP 658 | # 659 | # As a last exercise, we can put everything together. Here we are going to run a combination of pipeline parallelism while also sharding our weight between 16 different machines. Here the model only has 4 layers, so we will assign 4 GPUs to each layer in the pipeline parallel approach. 660 | # 661 | # This example requires combining both collective communication and p2p communication effectively. 
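# (In the scaffold below, model.rank % 4 picks a GPU's pipeline stage, and hence
# the single layer it owns, while model.rank // 4 picks which batch its pipeline
# group processes; every layer's weights remain sharded across all 16 ranks and
# are re-assembled with allgather only when needed.)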
662 | # 663 | # * Total Steps: 15 664 | # * Total Memory: 1000000 665 | 666 | async def pipeline_fsdp(model: Model) -> Model: 667 | weights, opt_states, activations, grad_activations, grad_weights = model.storage() 668 | per_rank = model.LAYERS // (model.RANKS // 4) 669 | my_layers = list([l + ((model.rank % 4) * per_rank) for l in range(per_rank)]) 670 | for l in range(model.LAYERS): 671 | weights[l, 0], opt_states[l, 0] = model.load_weights(l, model.rank, model.RANKS) 672 | def empty_grad(l): 673 | return model.fake_grad(l, []) 674 | ## USER CODE 675 | # Forward 676 | for l in range(model.LAYERS): 677 | if l == my_layers[0]: 678 | if model.rank % 4 == 0: 679 | activations[0] = model.get_activation([model.rank // 4]) 680 | else: 681 | activations[l] = await model.receive() 682 | 683 | weights[l] = await model.allgather(weights[l, 0], l) 684 | if l in my_layers: 685 | activations[l + 1] = model.forward(l, activations[l], weights[l]) 686 | del weights[l] 687 | if l == my_layers[-1]: 688 | if model.rank % 4 == 3 : 689 | grad_activations[model.LAYERS] = model.loss( 690 | activations[model.LAYERS] 691 | ) 692 | else: 693 | await model.pass_to(model.rank + 1, activations[l + 1]) 694 | # Backward 695 | 696 | for l in reversed(range(model.LAYERS)): 697 | if l == my_layers[-1]: 698 | if model.rank % 4 != 3: 699 | grad_activations[l + 1] = await model.receive() 700 | 701 | weights[l] = await model.allgather(weights[l, 0], l) 702 | if l in my_layers: 703 | grad_weights[l], grad_activations[l] = model.backward( 704 | l, activations[l], grad_activations[l + 1], model.weights[l] 705 | ) 706 | del grad_activations[l + 1], activations[l] 707 | grad_weights[l] = await model.scatterreduce(grad_weights[l], l) 708 | else: 709 | grad_weights[l] = await model.scatterreduce(empty_grad(l), l) 710 | del weights[l] 711 | 712 | if model.rank % 4 != 0 and l == my_layers[0]: 713 | await model.pass_to(model.rank - 1, grad_activations[l]) 714 | for l in range(model.LAYERS): 715 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l, 0], opt_states[l, 0]) 716 | 717 | # END USER CODE 718 | for l in range(model.LAYERS): 719 | model.set_final_weight(l, weights[l]) 720 | # Update 721 | return model 722 | 723 | dist = Dist(16) 724 | out = await asyncio.gather(*[ 725 | pipeline_fsdp(Model(layers=4, batches=ranks, rank=i, dist=dist)) 726 | for i in range(16)]) 727 | 728 | 729 | # + 730 | Model.check(out) 731 | chalk.set_svg_height(1000) 732 | chalk.set_svg_draw_height(1000) 733 | 734 | draw(out) 735 | # - 736 | 737 | # ### When does it make sense to combine? 738 | # 739 | # The goal of these exercises is to give you a sense of the different methods out there for distributed training. However, there is not currently a one size fits all approach for distributed training. The right choice will depend on the constants such as batch size, memory per GPU, communication overhead, implementation complexity, model size, and specifics of architecture. 740 | # 741 | # As an example of what's left to explore, this last method Pipeline + FSDP is often not a great choice due to the complexities of communication speed. And in fact GPipe + FSDP also gets you into a bad place. The paper [Breadth First Pipeline Parallelism](https://arxiv.org/pdf/2211.05953.pdf) proposes instead a combination of pipeline scheduling and communication. Here's what it looks like. 
742 | 743 | # ![image.png](https://github.com/srush/LLM-Training-Puzzles/assets/35882/f286089a-83bd-483c-b441-f154821d161c) 744 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sasha Rush 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM Training Puzzles 2 | - by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp) 3 | 4 | ![image](https://github.com/srush/LLM-Training-Puzzles/assets/35882/0c46911f-ad1c-4e7a-a42b-2bc2537cccc3) 5 | 6 | 7 | This is a collection of 8 challenging puzzles about training large language models (or really any NN) on many, many GPUs. 8 | Very few people actually get a chance to train on thousands of computers, but it is an interesting challenge and one that 9 | is critically important for modern AI. The goal of these puzzles is to get hands-on experience with the key primitives and to understand 10 | the goals of memory efficiency and compute pipelining. 11 | 12 | 13 | I recommend running in Colab. Click here and copy the notebook to get start. 14 | 15 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/LLM-Training-Puzzles/blob/main/puzzles.ipynb) 16 | 17 | ![image](https://github.com/srush/LLM-Training-Puzzles/assets/35882/6d16fc9e-3d14-4bd0-b7c7-d056e49854ac) 18 | 19 | 20 | 21 | If you are into this kind of thing, this is 6th in a series of these puzzles. 
22 | 23 | * https://github.com/srush/gpu-puzzles 24 | * https://github.com/srush/tensor-puzzles 25 | * https://github.com/srush/autodiff-puzzles 26 | * https://github.com/srush/transformer-puzzles 27 | * https://github.com/srush/GPTworld 28 | -------------------------------------------------------------------------------- /drawing.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional, FrozenSet 2 | from lib import Model, Activation, Weight, ActivationGrad, WeightGrad 3 | 4 | from chalk import rectangle, text, vcat, empty, hcat, Diagram, concat 5 | from colour import Color 6 | 7 | # backward_col = list(Color("Yellow").range_to("Red", LAYERS)) 8 | 9 | 10 | def draw(models : Sequence[Model]) -> Diagram: 11 | TIME = 2 12 | layers = models[0].LAYERS 13 | forward_col = list(Color("green").range_to("red", layers + 1)) 14 | MAXTIME = max(m.log[-1].time for m in models) 15 | MAXMEM = 0 16 | ARGMEM = None 17 | for i, m in enumerate(models): 18 | for event in m.log: 19 | if event.memory > MAXMEM: 20 | MAXMEM = event.memory 21 | ARGMEM = (i, event.time) 22 | print(f"Timesteps: {MAXTIME} \nMaximum Memory: {MAXMEM} at GPU: {ARGMEM[0]} time: {ARGMEM[1]}") 23 | def square(layer:int, time:int, length:int, s: str="") -> Diagram: 24 | PAD = 0.2 25 | return ( 26 | ( 27 | rectangle(length * TIME - 0.05, 0.9).line_width(0) 28 | + text(s, 0.9) 29 | .translate(0, 0.1) 30 | .line_width(0.05) 31 | .fill_color(Color("black")) 32 | ) 33 | .align_l() 34 | .translate(time * TIME, 0) 35 | ) 36 | 37 | def draw_forward(e): 38 | return square(e.layer, e.time, e.length, " ".join(map(str, e.batches))).fill_color( 39 | forward_col[e.layer] 40 | ) 41 | 42 | def draw_backward(e): 43 | return ( 44 | square(e.layer, e.time, e.length, " ".join(map(str, e.batches))) 45 | .fill_color(forward_col[e.layer]) 46 | .fill_opacity(0.2) 47 | ) 48 | 49 | def draw_update(e): 50 | return square(e.layer, e.time, e.length, "").fill_color(Color("yellow")).fill_opacity(0.5) 51 | 52 | def draw_allred(e): 53 | return square(e.layer, e.time, e.length).fill_color(forward_col[e.layer]).fill_opacity(0.5) 54 | def draw_pass(e): 55 | return square(0, e.time, e.length).fill_color(Color("grey")).fill_opacity(0.5) 56 | 57 | rows = [] 58 | 59 | # Time 60 | for gpu in range(len(models)): 61 | d = empty() 62 | box = ( 63 | rectangle(TIME * (MAXTIME + 1), 1).fill_color(Color("lightgrey")).align_l() 64 | ) 65 | box = box + text(str(gpu), 1).line_width(0).with_envelope( 66 | rectangle(TIME, 1) 67 | ).translate(-TIME, 0).fill_color(Color("orange")) 68 | d += box 69 | for e in models[gpu].log: 70 | if e.typ == "forward": 71 | d += draw_forward(e) 72 | if e.typ == "backward": 73 | d += draw_backward(e) 74 | if e.typ == "update": 75 | d += draw_update(e) 76 | if e.typ in ["allreduce", "scatterreduce", "allgather"]: 77 | d += draw_allred(e) 78 | if e.typ in ["pass"]: 79 | d += draw_pass(e) 80 | rows.append(d) 81 | d = vcat(rows) 82 | 83 | rows = [] 84 | for gpu in range(len(models)): 85 | row = rectangle(TIME * (MAXTIME + 1), 1).fill_color(Color("white")).align_l() 86 | row = row + text(str(gpu), 1).line_width(0).with_envelope( 87 | rectangle(TIME, 1) 88 | ).translate(-TIME, 0).fill_color(Color("grey")) 89 | for e in models[gpu].log: 90 | can = ( 91 | rectangle(TIME * e.length, e.memory / (1.5 * MAXMEM)) 92 | .align_b() 93 | .align_l() 94 | .line_width(0) 95 | .fill_color(Color("grey")) 96 | ) 97 | row = row.align_b() + can.translate(TIME * e.time, 0) 98 | if gpu == ARGMEM[0]: 99 | row = row + 
rectangle(0.1, 1).translate(TIME * ARGMEM[1], 0).line_color(Color("red")).align_b() 100 | 101 | rows.append(row) 102 | d2 = vcat(rows) 103 | d = vcat([d, d2]) 104 | # return rectangle(1.5, 0.5) + d.scale_uniform_to_x(1).center_xy() 105 | return rectangle(1.5, 8).line_width(0) + d.scale_uniform_to_y(len(models)).center_xy() 106 | 107 | 108 | def draw_network(layers:int, weight: Optional[int]=None, before:int=-1, after:int=100, 109 | shards:FrozenSet[int]=frozenset({}), total:int=1, 110 | batches: FrozenSet[int]=frozenset({0}), total_batches=1, is_grad: bool=False) -> Diagram: 111 | forward_col = list(Color("green").range_to("red", layers+1)) 112 | def layer(l: int) -> Diagram: 113 | W = 3 114 | H = 1 if l < layers else 0 115 | layer = rectangle(W, H).line_width(0.2).align_b() 116 | shard_h = H * (1 / total) 117 | shard_w = W * (1 / total_batches) 118 | 119 | weight_shard = rectangle(shard_w, shard_h).line_width(0.01).align_t().align_l().line_color(Color("white")) 120 | weight_shards = concat([weight_shard.translate(batch * shard_w - (W/2), 121 | shard * shard_h - H) 122 | for shard in shards for batch in batches]) 123 | 124 | connect_out = rectangle(1.5, 1.05).line_width(0.2).align_t() 125 | connect_w = 1.5 * (1 / total_batches) 126 | connect = rectangle(connect_w, 1).line_width(0.01).align_l().line_color(Color("white")) 127 | connect = concat([connect.translate(batch * connect_w - 1.5 * (1/2), 0.0) 128 | for batch in batches]).align_t().translate(0, 0.025) 129 | 130 | if l == weight: 131 | weight_shards = weight_shards.fill_color(forward_col[l]) 132 | if is_grad: 133 | weight_shards = weight_shards.fill_opacity(0.5) 134 | else: 135 | weight_shards = empty() 136 | if l == before: #or (after != 100 and l <= after): 137 | connect = connect.fill_color(forward_col[l]).fill_opacity(1) 138 | elif l == after + 1: 139 | connect = connect.fill_color(forward_col[l]).fill_opacity(0.5) 140 | else: 141 | connect = empty() 142 | base = connect_out + layer 143 | return base, (connect + weight_shards).with_envelope(base) 144 | return vcat(reversed([layer(l)[0] for l in range(layers+1)])), vcat(reversed([layer(l)[1] for l in range(layers+1)])) 145 | 146 | 147 | def draw_group(group): 148 | group = list(group.values()) 149 | base = group[0].draw()[0] 150 | return base + concat([g.draw()[1] for g in group]) 151 | 152 | 153 | # hcat([base + Activation(2, 5, [0], 2).draw()[1], 154 | # base + Weight(2, 5, 0, (0,), 2).draw()[1], 155 | # base + WeightGrad(2, 5, {0}, 2, {1}, 2).draw()[1] + WeightGrad(2, 5, {1}, 2, {1}, 2).draw()[1], 156 | # base + ActivationGrad(2, 5, frozenset({1}), 2).draw()[1]], 1).render_svg("activation.svg", 500) -------------------------------------------------------------------------------- /lib.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import asyncio 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, FrozenSet, TypeVar 5 | from chalk import Diagram 6 | import random 7 | 8 | class Barrier: 9 | """Sync across n ranks""" 10 | def __init__(self, target: int): 11 | self.counter = 0 12 | self.target = target 13 | self.lock = asyncio.Lock() 14 | self.round = 0 15 | self.done = 0 16 | 17 | async def wait(self, rank: int) -> None: 18 | while self.done > 0: 19 | await asyncio.sleep(0.01) 20 | async with self.lock: 21 | self.counter += 1 22 | while self.counter < self.target: 23 | await asyncio.sleep(0.01) 24 | self.done += 1 25 | if rank == 0: 26 | 
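            # Rank 0 clears the barrier once every rank has passed it, so the
            # same Barrier object can be reused for the next collective.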
await self.reset() 27 | 28 | async def reset(self) -> None: 29 | while self.done < self.target: 30 | await asyncio.sleep(0.01) 31 | self.counter = 0 32 | self.done = 0 33 | 34 | T = TypeVar('T') 35 | class Reduceable(Protocol[T]): 36 | """ 37 | A type that can be reduced. 38 | """ 39 | def __add__(self, other: T) -> T: 40 | ... 41 | 42 | O = TypeVar('O') 43 | class Gatherable(Protocol[O]): 44 | """ 45 | A type that can be sharded. 46 | """ 47 | def shard(self, shard: int, total: int) -> O: 48 | ... 49 | 50 | def is_complete(self) -> bool: 51 | ... 52 | 53 | def combine(self, other: O) -> O: 54 | ... 55 | 56 | 57 | TO = TypeVar('TO') 58 | class ReduceableGatherable(Reduceable[TO], Gatherable[TO]): 59 | pass 60 | 61 | 62 | class Dist: 63 | def __init__(self, total: int) -> None: 64 | self.reduce: Optional[Any] = None 65 | self.gather: Optional[Any] = None 66 | self.ranks = total 67 | self.barrier = Barrier(total) 68 | self.queue : Sequence[asyncio.Queue[Any]] = [asyncio.Queue(maxsize=1) for i in range(total)] 69 | self.mtime = 0 70 | 71 | async def allreduce(self, rank: int, inp: T, time:int) -> Tuple[T, int]: 72 | if self.reduce is None: 73 | self.reduce = inp 74 | else: 75 | self.reduce = self.reduce + inp 76 | self.mtime = max(time, self.mtime) 77 | await self.barrier.wait(rank) 78 | q: T = self.reduce 79 | mtime = self.mtime 80 | await self.barrier.wait(rank) 81 | if rank == 0: 82 | self.reduce = None 83 | self.mtime = 0 84 | await self.barrier.wait(rank) 85 | return q, mtime 86 | 87 | async def allgather(self, rank: int, inp: O, time:int) -> Tuple[O, int]: 88 | if self.gather is None: 89 | self.gather = inp 90 | else: 91 | assert type(self.gather) == type(inp) 92 | self.gather = self.gather.combine(inp) 93 | self.mtime = max(time, self.mtime) 94 | await self.barrier.wait(rank) 95 | q: O = self.gather 96 | mtime = self.mtime 97 | await self.barrier.wait(rank) 98 | if rank == 0: 99 | self.gather = None 100 | self.mtime = 0 101 | await self.barrier.wait(rank) 102 | return q, mtime 103 | 104 | async def scatterreduce(self, rank: int, inp: TO, time:int) -> Tuple[TO, int]: 105 | x, time = await self.allreduce(rank, inp, time) 106 | y = x.shard(rank, self.ranks) # type: ignore 107 | return y, time # type: ignore 108 | 109 | async def receive(self, rank: int) -> Any: 110 | return await self.queue[rank].get() 111 | 112 | async def pass_to(self, rank: int, v: Any) -> None: 113 | await self.queue[rank].put(v) 114 | 115 | 116 | @dataclass 117 | class Weight(Gatherable["Weight"]): 118 | """ 119 | The weights for a specific layer. Can be sharded. 120 | 121 | Required for forward and backward passes. 
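    A sharded Weight holds only the pieces listed in `shards` out of `total`;
    `combine` merges shard sets and `is_complete` checks that every piece is present.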
122 | """ 123 | layer: int 124 | layers: int 125 | step: int 126 | shards: FrozenSet[int] = frozenset([0]) 127 | total: int = 1 128 | 129 | def combine(self, other: Weight) -> Weight: 130 | return Weight(self.layer, self.layers, self.step, self.shards | other.shards, self.total) 131 | 132 | def memory(self) -> float: 133 | return (len(self.shards) / self.total) * HIDDEN * HIDDEN 134 | 135 | def shard(self, shard: int, total: int) -> Weight: 136 | assert self.is_complete() 137 | assert shard < total 138 | return Weight(self.layer, self.layers, self.step, frozenset([shard]), total) 139 | 140 | def is_complete(self) -> bool: 141 | return len(self.shards) == self.total 142 | 143 | def draw(self) -> Diagram: 144 | from drawing import draw_network 145 | return draw_network(self.layers, weight=self.layer, 146 | shards=self.shards, total=self.total) 147 | def _repr_svg_(self): 148 | d = self.draw() 149 | return (d[0] + d[1])._repr_svg_() 150 | 151 | HIDDEN = 512 152 | LENGTH = 256 153 | @dataclass 154 | class Activation: 155 | """ 156 | Activations need for a specific layer for a specific set of batches. 157 | """ 158 | layer: int 159 | layers: int 160 | batches: FrozenSet[int] 161 | total_batches: int 162 | 163 | def memory(self) -> int: 164 | return len(self.batches) * HIDDEN * LENGTH 165 | 166 | def draw(self) -> Diagram: 167 | from drawing import draw_network 168 | return draw_network(self.layers, before=self.layer, 169 | batches=self.batches, total_batches=self.total_batches) 170 | 171 | def _repr_svg_(self): 172 | d = self.draw() 173 | return (d[0] + d[1])._repr_svg_() 174 | 175 | @dataclass 176 | class WeightGrad(Reduceable["WeightGrad"], Gatherable["WeightGrad"]): 177 | """ 178 | The gradient of the loss for a specific weight layer. 179 | 180 | May be sharded to correspond to different parts of the weights. 181 | 182 | May be split into different batches. 183 | """ 184 | 185 | 186 | layer: int 187 | layers: int 188 | batches: FrozenSet[int] 189 | total_batches: int 190 | shards: FrozenSet[int] = frozenset([0]) 191 | total: int = 1 192 | 193 | def __add__(self, other: WeightGrad) -> WeightGrad: 194 | assert self.layer == other.layer, "Only add same layer weight grads" 195 | assert self.shards == other.shards 196 | return WeightGrad(self.layer, self.layers, self.batches | other.batches, self.total_batches, 197 | self.shards, self.total) 198 | 199 | def combine(self, other: WeightGrad) -> WeightGrad: 200 | return WeightGrad(self.layer, self.layers, self.batches, self.total_batches, 201 | self.shards | other.shards, self.total) 202 | 203 | def memory(self) -> float: 204 | return (len(self.shards) / self.total) * HIDDEN * HIDDEN 205 | 206 | def shard(self, shard: int, total: int) -> WeightGrad: 207 | assert self.is_complete(), f"{self.shards} out of {self.total}" 208 | assert shard < total 209 | return WeightGrad(self.layer, self.layers, self.batches, self.total_batches, frozenset([shard]), total) 210 | 211 | def is_complete(self) -> bool: 212 | return len(self.shards) == self.total 213 | 214 | def draw(self) -> Diagram: 215 | from drawing import draw_network 216 | return draw_network(self.layers, weight=self.layer, shards=self.shards, 217 | batches=self.batches, 218 | total=self.total, total_batches=self.total_batches, is_grad=True) 219 | 220 | def _repr_svg_(self): 221 | d = self.draw() 222 | return (d[0] + d[1])._repr_svg_() 223 | 224 | 225 | @dataclass 226 | class OptState(Gatherable["OptState"]): 227 | """ 228 | The state of the optimizer for a specific layer. Can be sharded. 
229 | 230 | In pratice this represents ADAM's saved values needed for optimization. 231 | 232 | Required for updating the weights. 233 | """ 234 | 235 | layer: int 236 | layers: int 237 | step: int 238 | shards: FrozenSet[int] = frozenset([0,]) 239 | total: int = 1 240 | 241 | def combine(self, other: OptState) -> OptState: 242 | return OptState(self.layer, self.layers, self.step, self.shards | other.shards, self.total) 243 | 244 | def memory(self) -> float: 245 | return HIDDEN * HIDDEN * (len(self.shards) / self.total) 246 | 247 | def draw(self) -> Diagram: 248 | from drawing import draw_network 249 | return draw_network(self.layers, before=self.layer, shards=self.shards, total=self.total) 250 | 251 | def _repr_svg_(self): 252 | d = self.draw() 253 | return (d[0] + d[1])._repr_svg_() 254 | 255 | @dataclass 256 | class ActivationGrad: 257 | """ 258 | The gradient of the activations for a specific layer. 259 | 260 | May be split into different batches. 261 | """ 262 | 263 | layer: int 264 | layers: int 265 | batches: FrozenSet[int] 266 | total_batches: int 267 | 268 | def memory(self) -> int: 269 | return len(self.batches) * HIDDEN * LENGTH 270 | 271 | def draw(self) -> Diagram: 272 | from drawing import draw_network 273 | return draw_network(self.layers, after=self.layer, 274 | batches=self.batches, total_batches=self.total_batches) 275 | 276 | def _repr_svg_(self): 277 | d = self.draw() 278 | return (d[0] + d[1])._repr_svg_() 279 | 280 | 281 | @dataclass 282 | class Event: 283 | "Internal representations of events in the model for the visualizer" 284 | typ: str 285 | layer: Optional[int] 286 | rank: int 287 | time: int 288 | length: int 289 | memory: int 290 | batches: FrozenSet[int] = frozenset() 291 | 292 | 293 | class Model: 294 | def __init__(self, rank: int=1, dist: Dist=Dist(1), layers: int=2, batches: int=1): 295 | self.rank = rank 296 | self.log: List[Event] = [] 297 | self.dist = dist 298 | self.time = 0 299 | self.RANKS = dist.ranks 300 | self.LAYERS = layers 301 | self.BATCHES = batches 302 | self.final_weights: Dict[int, Weight] = {} 303 | 304 | self.weights: Dict[Any, Weight] = {} 305 | self.opt_states: Dict[Any, OptState] = {} 306 | self.activations: Dict[Any, Activation] = {} 307 | self.grad_activations: Dict[Any, ActivationGrad] = {} 308 | self.grad_weights: Dict[Any, WeightGrad] = {} 309 | 310 | def storage(self) -> Tuple[Dict[Any, Weight], Dict[Any, OptState], Dict[Any, Activation], Dict[Any, ActivationGrad], Dict[Any, WeightGrad]]: 311 | return self.weights, self.opt_states, self.activations, self.grad_activations, self.grad_weights 312 | 313 | def memory(self) -> int: 314 | mem = 0 315 | for d in list(self.storage()): 316 | assert isinstance(d, dict) 317 | for v in d.values(): 318 | mem += v.memory() 319 | return mem 320 | 321 | def status(self): 322 | for d in list(self.storage()): 323 | for k, v in d.items(): 324 | print(k, type(v), end=",") 325 | print() 326 | 327 | def event(self, typ: str, layer: Optional[int]=None, batches: FrozenSet[int]=frozenset({})) -> None: 328 | length = 0 329 | if typ in ["loss", "allgather"]: 330 | length = 0 331 | if typ in ["forward", "backward"]: 332 | length = len(batches) 333 | if typ in ["update"]: 334 | length = 0.5 335 | if typ in ["allreduce", "scatterreduce", "allgather"]: 336 | length = 0.3 337 | if typ in ["pass"]: 338 | length = 0.2 339 | 340 | self.log.append(Event(typ, layer, self.rank, self.time, length, self.memory(), batches)) 341 | self.time += length 342 | def load_weights(self, layer: int, shard: int = 0, total:int = 
1 ) -> Tuple[Weight, OptState]: 343 | return Weight(layer, self.LAYERS, 0, frozenset([shard]), total),\ 344 | OptState(layer, self.LAYERS, 0, frozenset([shard]), total) 345 | 346 | def set_final_weight(self, layer: int, weight:Weight) -> None: 347 | self.final_weights[layer] = weight 348 | 349 | def get_activation(self, batches: Sequence[int]) -> Activation: 350 | return Activation(0, self.LAYERS, frozenset(batches), self.BATCHES) 351 | 352 | def forward(self, layer: int, inp: Activation, weight: Weight) -> Activation: 353 | "Take in activation at layer i and return layer i + 1" 354 | self.event("forward", layer, inp.batches) 355 | assert weight.is_complete() 356 | assert weight.layer == layer, f"Weight should be layer {layer}" 357 | assert inp.layer == layer, f"Input should be layer {layer}" 358 | return Activation(layer + 1, self.LAYERS, inp.batches, self.BATCHES) 359 | 360 | def backward( 361 | self, layer: int, inp: Activation, grad: ActivationGrad, weight: Weight 362 | ) -> Tuple[WeightGrad, ActivationGrad]: 363 | self.event("backward", layer, inp.batches) 364 | assert weight.is_complete() 365 | assert weight.layer == layer, f"Weight should be layer {layer}" 366 | assert inp.layer == layer, f"Input should be layer {layer}" 367 | assert set(inp.batches) == set( 368 | grad.batches 369 | ), f"Batch mismatch {set(inp.batches)}" 370 | assert grad.layer == layer, f"Activation Grad should be layer {layer}" 371 | return (WeightGrad(layer, self.LAYERS, inp.batches, self.BATCHES), 372 | ActivationGrad(layer - 1, self.LAYERS, inp.batches, self.BATCHES)) 373 | 374 | def loss(self, inp: Activation) -> ActivationGrad: 375 | self.event("loss", self.LAYERS) 376 | assert inp.layer == self.LAYERS, f"Input should be final layer {self.LAYERS}" 377 | return ActivationGrad(self.LAYERS - 1, self.LAYERS, inp.batches, self.BATCHES) 378 | 379 | def update(self, layer: int, 380 | weight_grad: WeightGrad, 381 | weight: Weight, 382 | opt_state: OptState, 383 | shard: int = 0) -> Tuple[Weight, OptState]: 384 | 385 | assert weight.layer == layer, f"Weight should be layer {layer}" 386 | assert weight_grad.layer == layer, f"Grad weight should be layer {layer}" 387 | assert set(weight_grad.batches) == set( 388 | range(self.BATCHES) 389 | ), f"{set(weight_grad.batches)}" 390 | assert opt_state.layer == layer 391 | if weight_grad.total > 1: 392 | assert weight.shards.issubset(weight_grad.shards), f"Weight {weight.shards}" 393 | assert opt_state.shards.issubset(weight_grad.shards), f"Opt {opt_state.shards}" 394 | assert weight.step == opt_state.step 395 | new_opt = OptState(layer, self.LAYERS, opt_state.step + 1, opt_state.shards, opt_state.total) 396 | new_weight = Weight(layer, self.LAYERS, weight.step + 1, weight.shards, weight.total) 397 | self.event("update", None) 398 | return new_weight, new_opt 399 | 400 | def fake_grad(self, layer: int, batches= List[int]): 401 | return WeightGrad(layer, self.LAYERS, frozenset(batches), self.BATCHES) 402 | 403 | async def allreduce(self, v: T, layer: int) -> T: 404 | v, self.time = await self.dist.allreduce(self.rank, v, self.time) 405 | self.event("allreduce", layer) 406 | 407 | return v 408 | 409 | async def scatterreduce(self, v: TO, layer:int) -> TO: 410 | v, self.time = await self.dist.scatterreduce(self.rank, v, self.time) 411 | self.event("scatterreduce", layer) 412 | return v 413 | 414 | async def allgather(self, v: O, layer:int) -> O: 415 | v, self.time = await self.dist.allgather(self.rank, v, self.time) 416 | self.event("allgather", layer) 417 | return v 418 | 419 | 
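    # Point-to-point communication: `pass_to` hands a value (paired with the
    # sender's clock) to another rank's queue, and `receive` blocks until a value
    # arrives, advancing the local clock to the later of the two. Used by the
    # pipeline-parallel puzzles.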
async def pass_to(self, rank: int, v: Any) -> None: 420 | self.event("pass", None) 421 | await self.dist.pass_to(rank, (v, self.time)) 422 | 423 | async def receive(self) -> Any: 424 | v, time = await self.dist.receive(self.rank) 425 | self.time = max(time, self.time) 426 | self.event("pass", None) 427 | return v 428 | 429 | @staticmethod 430 | def check(models : Sequence[Model]) -> None: 431 | for l in range(models[0].LAYERS): 432 | weight = None 433 | for m in models: 434 | if l in m.final_weights: 435 | assert m.final_weights[l].step == 1 436 | if weight is None: 437 | weight = m.final_weights[l] 438 | else: 439 | weight = weight.combine(m.final_weights[l]) 440 | assert weight is not None, f"Missing weight {l}" 441 | assert weight.is_complete(), f"Weight not complete {weight}" 442 | 443 | print("Correct!") 444 | from IPython.display import HTML 445 | pups = [ 446 | "2m78jPG", 447 | "pn1e9TO", 448 | "MQCIwzT", 449 | "udLK6FS", 450 | "ZNem5o3", 451 | "DS2IZ6K", 452 | "aydRUz8", 453 | "MVUdQYK", 454 | "kLvno0p", 455 | "wScLiVz", 456 | "Z0TII8i", 457 | "F1SChho", 458 | "9hRi2jN", 459 | "lvzRF3W", 460 | "fqHxOGI", 461 | "1xeUYme", 462 | "6tVqKyM", 463 | "CCxZ6Wr", 464 | "lMW0OPQ", 465 | "wHVpHVG", 466 | "Wj2PGRl", 467 | "HlaTE8H", 468 | "k5jALH0", 469 | "3V37Hqr", 470 | "Eq2uMTA", 471 | "Vy9JShx", 472 | "g9I2ZmK", 473 | "Nu4RH7f", 474 | "sWp0Dqd", 475 | "bRKfspn", 476 | "qawCMl5", 477 | "2F6j2B4", 478 | "fiJxCVA", 479 | "pCAIlxD", 480 | "zJx2skh", 481 | "2Gdl1u7", 482 | "aJJAY4c", 483 | "ros6RLC", 484 | "DKLBJh7", 485 | "eyxH0Wc", 486 | "rJEkEw4"] 487 | return HTML(""" 488 | 491 | """%(random.sample(pups, 1)[0])) -------------------------------------------------------------------------------- /puzzles.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d8204266", 6 | "metadata": {}, 7 | "source": [ 8 | "# LLM Training Puzzles\n", 9 | "\n", 10 | "by Sasha Rush ([@srush_nlp](https://twitter.com/srush_nlp))" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "c65be14b", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%%capture\n", 21 | "# Uncomment to run in Colab\n", 22 | "!pip install -qqq git+https://github.com/chalk-diagrams/chalk asyncio\n", 23 | "!wget https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/lib.py https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/drawing.py" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "d583c979", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "from typing import List\n", 34 | "from lib import Model, Dist, WeightGrad\n", 35 | "from drawing import draw, draw_group\n", 36 | "from chalk import vcat\n", 37 | "import asyncio\n", 38 | "import chalk\n", 39 | "chalk.set_svg_height(400)\n", 40 | "chalk.set_svg_draw_height(600)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "cc24dadb", 46 | "metadata": {}, 47 | "source": [ 48 | "## Preliminaries\n", 49 | "\n", 50 | "The goal of these puzzles is to learn about distributed training of LLMs. However, we will be primarily concerned with a speed and memory efficiency of completing a single update of the models. 
To make things simpler, we will abstract away from the standard tensor-based transformer model, and just consider a state-less representation of each of the components of a multi-layer neural network.\n", 51 | "\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "ad71f90c", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "model = Model(layers=2, batches=4)\n", 62 | "weights, opt_states, activations, grad_activations, grad_weights = model.storage()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "61cf6388", 68 | "metadata": {}, 69 | "source": [ 70 | "Our library has 5 parts: \n", 71 | "\n", 72 | "* Weights\n", 73 | "* Optimizer States - Values needed to update the weights\n", 74 | "* Activations - The internal values computed on the forward pass\n", 75 | "* Grad Activations - The gradients of the loss wrt to activations, needed for backward pass\n", 76 | "* Grad Weights - The gradients of the loss wrt to weights, needed for updates\n", 77 | "\n", 78 | "For these puzzles, you are *not allowed* to have local variables. You need to store each of these in the dictionary corresponding to its type. \n", 79 | "\n", 80 | "We begin by tracing the lifecycle of a single model update." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "0335a17b", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# Get the input activations to the model for batches 2, 3 \n", 91 | "activations[0] = model.get_activation(batches=[2, 3])\n", 92 | "activations[0]" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "962ac1d8", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# Load the weights (random) for layers 0 and 1\n", 103 | "for i in range(model.LAYERS):\n", 104 | " weights[i], opt_states[i] = model.load_weights(i)\n", 105 | "weights[0]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "f7a83439", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# Activations can be moved forward a layer if you have the weights.\n", 116 | "activations[1] = model.forward(layer=0, inp=activations[0], weight=weights[0])\n", 117 | "activations[2] = model.forward(layer=1, inp=activations[1], weight=weights[1])\n", 118 | "activations[1]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "81d904e3", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Draw all the current activations in memory.\n", 129 | "draw_group(activations)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "80e8769a", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "# At the last layer, we can convert an activation to a grad activation by calling `loss`\n", 140 | "grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])\n", 141 | "grad_activations[model.LAYERS]" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "b3299042", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# Calling `backward` requires the forward activation, the backward grad activation, and the weights.\n", 152 | "# It returns the grad weights and the backward activation.\n", 153 | "grad_weights[1], grad_activations[1] = model.backward(1, activations[1], grad_activations[2], weights[1])\n", 154 | "grad_weights[0], grad_activations[0] = model.backward(0, 
activations[0], grad_activations[1], weights[0])\n", 155 | "grad_activations[1]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "33d82618", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# We can use delete to remove any memory that is not longer needed. \n", 166 | "print(\"Before memory:\", model.memory())\n", 167 | "del grad_activations[1]\n", 168 | "print(\"After memory:\", model.memory())\n", 169 | "model.status()\n", 170 | "draw_group(grad_activations)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "id": "f48969f0", 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "# Grad weights keep track of which batches they are for. Here we only have the grad weights for batches 2 and 3.\n", 181 | "draw_group(grad_weights)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "d3d469f7", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# If we try to update with the grad weights we will get an error.\n", 192 | "try:\n", 193 | " model.update(0, weight_grad=grad_weights[0], weight=weights[0], opt_state=opt_states[0])\n", 194 | "except AssertionError as e:\n", 195 | " print(\"Error! Only have batches\")\n", 196 | " print(e)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "55046513", 203 | "metadata": { 204 | "lines_to_next_cell": 2 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "# For this example, we can cheat. Pretend we had the other gradients we needed. \n", 209 | "grad_weights[0, 0] = model.fake_grad(0, [0,1])\n", 210 | "grad_weights[1, 0] = model.fake_grad(1, [0,1])\n", 211 | "grad_weights[0, 0] " 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "f5caee88", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# Summing together grad_weights gives the full gradient.\n", 222 | "grad_weights[0] = grad_weights[0] + grad_weights[0, 0]" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "dfe2165e", 229 | "metadata": { 230 | "lines_to_next_cell": 2 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "# Now we can call update to the get the new weights and opt_state.\n", 235 | "weights[0], opt_states[0] = model.update(0, weight_grad=grad_weights[0], weight=weights[0], \n", 236 | " opt_state=opt_states[0])\n", 237 | "\n", 238 | "# WARNING: You need to set all variables. Otherwise they are not counted towards memory.\n", 239 | "grad_weights[1] = grad_weights[1] + grad_weights[1, 0]\n", 240 | "weights[1], opt_states[1] = model.update(1, weight_grad=grad_weights[1],\n", 241 | " weight=weights[1], opt_state=opt_states[1])" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "f9582594", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "# We can complete the tests by setting these as the final weights and calling check.\n", 252 | "model.set_final_weight(0, weights[0])\n", 253 | "model.set_final_weight(1, weights[1])\n", 254 | "Model.check([model])\n", 255 | "draw_group(model.final_weights)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "id": "7b2f0a3b", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "# We can view the final outcome of the system as a diagram. 
\n", 266 | "# This show the forward and backward passes (numbers of batches) and the updates.\n", 267 | "# The lines on the bottom show the memory that is used at each time step.\n", 268 | "draw([model])" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "id": "01aa8d66", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "26e5ea60", 282 | "metadata": {}, 283 | "source": [ 284 | "### Puzzle 0 - Standard Training\n", 285 | "\n", 286 | "Write a standard (non-distributed) training loop that acts on all the batches and loads all the weights. It should just run forward, loss, backward, and update. Aim for the least amount of max memory used. \n", 287 | "\n", 288 | "* Target Time: 17 steps\n", 289 | "* Target Memory: 2600000" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "04dbf7ea", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "def basic(model: Model) -> Model:\n", 300 | " # Storage on device.\n", 301 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 302 | "\n", 303 | " # Load in the full weights\n", 304 | " for l in range(model.LAYERS):\n", 305 | " weights[l], opt_states[l] = model.load_weights(l)\n", 306 | "\n", 307 | " # Load the input layer activations\n", 308 | " activations[0] = model.get_activation(range(model.BATCHES))\n", 309 | "\n", 310 | " assert False, 'TODO: Implement me'\n", 311 | " \n", 312 | " for l in range(model.LAYERS):\n", 313 | " model.set_final_weight(l, weights[l])\n", 314 | " return model" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "e3632f24", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "out = basic(Model(layers=2, batches=4, rank=0, dist=Dist(1)))\n", 325 | "draw_group(out.final_weights)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "21c87306", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "draw([out])" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "id": "74e71f5a", 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "Model.check([out])" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "id": "d25bbcaf", 351 | "metadata": {}, 352 | "source": [ 353 | "### Puzzle 1 - Gradient Accumulation\n", 354 | "\n", 355 | "For this puzzle, the goal is to reduce max memory usage. To do so you are going to run on each batch individually instead of all together. \n", 356 | "\n", 357 | "Write a function with four parts. First run on batches {0} and then {1} etc. 
Sum the grad weights and then update.\n", 358 | "\n", 359 | "* Target Time: 17 steps\n", 360 | "* Target Memory: 2000000" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "4c870ae6", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "def grad_accum(model: Model) -> Model:\n", 371 | " # Storage on device.\n", 372 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 373 | "\n", 374 | " # Load in the full weights\n", 375 | " for l in range(model.LAYERS):\n", 376 | " weights[l], opt_states[l] = model.load_weights(l)\n", 377 | "\n", 378 | " assert False, 'TODO: Implement me'\n", 379 | " for l in range(model.LAYERS):\n", 380 | " model.set_final_weight(l, weights[l])\n", 381 | " return model" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "672563cf", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "out = grad_accum(Model(layers=2, batches=4, rank=0, dist=Dist(1)))\n", 392 | "draw_group(out.final_weights)" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "id": "77bb31d7", 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "draw([out])" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "28cf8167", 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "Model.check([out])" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "ccfb345c", 418 | "metadata": {}, 419 | "source": [ 420 | "## Communications: AllReduce" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "id": "17383602", 426 | "metadata": {}, 427 | "source": [ 428 | "When working with multiple GPUs we need to have communication. \n", 429 | "The primary communication primitives for GPUs are implemented in NCCL. \n", 430 | "\n", 431 | "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html\n", 432 | "\n", 433 | "We are not going to use these directly, but simulate them using Python and asyncio. \n", 434 | "\n", 435 | "The first operation is AllReduce. We will have 4 GPUs (ranks=4) and use them each to compute a batch of weight grads." 
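As a sketch of the pattern AllReduce enables in training (illustrative only, not a puzzle solution; it assumes one batch per rank, i.e. BATCHES == RANKS, and the function name is made up): each rank produces a gradient from its own batch, the allreduce sums those gradients across ranks, and every rank then applies the identical update.

# Hypothetical sketch: sum per-rank weight grads with allreduce, then update locally.
async def sync_and_update(model: Model) -> Model:
    weights, opt_states, activations, grad_activations, grad_weights = model.storage()
    for l in range(model.LAYERS):
        weights[l], opt_states[l] = model.load_weights(l)
        # stand-in for the gradient this rank would compute from its own batch
        grad_weights[l] = model.fake_grad(l, [model.rank])
    for l in range(model.LAYERS):
        grad_weights[l] = await model.allreduce(grad_weights[l], l)
        weights[l], opt_states[l] = model.update(l, weight_grad=grad_weights[l],
                                                 weight=weights[l], opt_state=opt_states[l])
        model.set_final_weight(l, weights[l])
    return model

It would be driven the same way as the demo below: one Model per rank, run together with asyncio.gather.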
436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "1f7bb767", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "ranks = 4\n", 446 | "weight_grads = [WeightGrad(0, 1, {i}, ranks) for i in range(ranks)]\n", 447 | "weight_grads[0] + weight_grads[1] + weight_grads[2] + weight_grads[3]" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "id": "ef0232db", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "# Simple asynchronous function that calls allreduce to sum the weight grads at layer 0\n", 458 | "async def myfunc(model: Model) -> WeightGrad:\n", 459 | " return await model.allreduce(weight_grads[model.rank], 0)" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "id": "43a91ae4", 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "# This code uses asyncio to run the above function on 4 \"GPUs\" .\n", 470 | "dist = Dist(ranks)\n", 471 | "out_weight_grads = await asyncio.gather(*[\n", 472 | " myfunc(Model(layers=1, batches=1, rank=i, dist=dist))\n", 473 | " for i in range(ranks)])\n", 474 | "out_weight_grads[0]" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "id": "73d4ee90", 480 | "metadata": {}, 481 | "source": [ 482 | "Note: When running communication operations like AllReduce on a GPU, the communication happens in parallel to the computation on that GPU. That means the API for AllReduce does not block, and allows the model to continue running while waiting for this command to run. This means it is beneficial to run AllReduce (and other communication) as early as possible so that other compute can be run during the reduction. \n", 483 | "\n", 484 | "We will ignore this in these puzzles and represent communication as happening efficiently." 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "id": "1020523e", 490 | "metadata": {}, 491 | "source": [ 492 | "### Puzzle 2 - Distributed Data Parallel\n", 493 | "\n", 494 | "Write a function with four parts. First run on batches {0} and then {1} etc. Sum the grad weights and then update. 
The main benefit of this approach is compute efficiency over gradient accumulation.\n", 495 | "\n", 496 | "* Total Steps: 5\n", 497 | "* Total Memory: 1800000" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "id": "f492668c", 504 | "metadata": { 505 | "lines_to_next_cell": 2 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "async def ddp(model: Model) -> Model:\n", 510 | " # Storage on device.\n", 511 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 512 | " # Load all the activations\n", 513 | " model.activations[0] = model.get_activation([model.rank])\n", 514 | "\n", 515 | " assert False, 'TODO: Implement me'\n", 516 | " for l in range(model.LAYERS):\n", 517 | " model.set_final_weight(l, weights[l])\n", 518 | " return model" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "id": "d6ed44f7", 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "dist = Dist(ranks)\n", 529 | "out = await asyncio.gather(*[\n", 530 | " ddp(Model(layers=2, batches=ranks, rank=i, dist=dist))\n", 531 | " for i in range(ranks)])\n", 532 | "draw_group(out[0].final_weights)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "id": "89411fad", 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "draw(out)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": null, 548 | "id": "2df8166f", 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "Model.check(out)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "id": "8c3df405", 558 | "metadata": {}, 559 | "source": [ 560 | "## Communication: AllGather / Sharding\n", 561 | "\n", 562 | "Our next primitive is AllGather. This allows us to communicate \"shards\" of an object stored on different GPUs to all the GPUs." 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "id": "5b4bad85", 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "# Load only part of a weights.\n", 573 | "model = Model(layers=2, batches=1, rank=0, dist=Dist(1))\n", 574 | "weight, _ = model.load_weights(0, shard=0, total=4)\n", 575 | "weight" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "id": "2597b152", 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "# Combine togegher two shards on one machine.\n", 586 | "weights = [model.load_weights(0, shard=i, total=ranks)[0] for i in range(ranks)]\n", 587 | "weights[0].combine(weights[2])" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "id": "5e3d10b6", 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "# Use allgather to collect the shards from all machines.\n", 598 | "async def mygather(model: Model) -> WeightGrad:\n", 599 | " # Allreduce sums together all the weight grads\n", 600 | " return await model.allgather(weights[model.rank], 0)\n", 601 | "\n", 602 | "dist = Dist(ranks)\n", 603 | "out_weights = await asyncio.gather(*[\n", 604 | " mygather(Model(layers=1, batches=1, rank=i, dist=dist))\n", 605 | " for i in range(ranks)])\n", 606 | "out_weights[0]" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "id": "c3903613", 612 | "metadata": {}, 613 | "source": [ 614 | "### Puzzle 3: Weight-Sharded Data Parallel\n", 615 | "\n", 616 | "Run a model that shards each layer weight over all the machines. 
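(Below is a sketch of the per-layer gather-use-free pattern the rest of this description asks for; the temporary "full" key and the helper name are assumptions, not part of the library.)

# Hypothetical sketch: rebuild the full weight for layer l from this rank's shard,
# use it for one forward step, then drop it so only the shard stays in memory.
async def forward_with_gather(model: Model, l: int, weights, activations) -> None:
    weights[l, "full"] = await model.allgather(weights[l], l)
    activations[l + 1] = model.forward(l, activations[l], weights[l, "full"])
    del weights[l, "full"]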
Reconstruct the layer weight at each layer using allgather. Finally update the weights on each machine using allreduce.\n", 617 | "\n", 618 | "* Total Steps: 20\n", 619 | "* Total Memory: 2800000" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": null, 625 | "id": "4b674be7", 626 | "metadata": {}, 627 | "outputs": [], 628 | "source": [ 629 | "async def wsdp(model: Model) -> Model:\n", 630 | " # Storage on device.\n", 631 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 632 | "\n", 633 | " # Load all the activations\n", 634 | " model.activations[0] = model.get_activation([model.rank])\n", 635 | "\n", 636 | " # Load a shard of the weights for every layer. Load in the full weights\n", 637 | " for l in range(model.LAYERS):\n", 638 | " weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) \n", 639 | "\n", 640 | " assert False, 'TODO: Implement me'\n", 641 | " for l in range(model.LAYERS):\n", 642 | " model.set_final_weight(l, weights[l])\n", 643 | "\n", 644 | " return model" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "id": "6c23a1eb", 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [ 654 | "dist = Dist(ranks)\n", 655 | "out = await asyncio.gather(*[\n", 656 | " wsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))\n", 657 | " for i in range(ranks)])\n", 658 | "draw_group(out[1].final_weights)" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | "id": "09731c67", 665 | "metadata": {}, 666 | "outputs": [], 667 | "source": [ 668 | "draw(out)" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "id": "d3ff46b6", 675 | "metadata": {}, 676 | "outputs": [], 677 | "source": [ 678 | "Model.check(out)" 679 | ] 680 | }, 681 | { 682 | "cell_type": "markdown", 683 | "id": "32243386", 684 | "metadata": {}, 685 | "source": [ 686 | "## Communication: Scatter-Reduce" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "id": "44610031", 692 | "metadata": {}, 693 | "source": [ 694 | "Scatter across shards\n", 695 | "Reduce across batches" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": null, 701 | "id": "1034dbb9", 702 | "metadata": {}, 703 | "outputs": [], 704 | "source": [ 705 | "grad_weight = WeightGrad(0, 1, batches={1}, total_batches=4, \n", 706 | " shards={1}, total=4)\n", 707 | "grad_weight" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": null, 713 | "id": "02680e7e", 714 | "metadata": {}, 715 | "outputs": [], 716 | "source": [ 717 | "grad_weights = {i: WeightGrad(0, 1, batches={i}, total_batches=4, \n", 718 | " shards={0,1,2,3}, total=4) for i in range(4)}\n", 719 | "grad_weights[2]" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": null, 725 | "id": "1498773a", 726 | "metadata": { 727 | "lines_to_next_cell": 0 728 | }, 729 | "outputs": [], 730 | "source": [ 731 | "async def scatterreduce(model: Model) -> WeightGrad:\n", 732 | " # Allreduce sums together all the weight grads\n", 733 | " return await model.scatterreduce(grad_weights[model.rank], 0)\n", 734 | "\n", 735 | "dist = Dist(ranks)\n", 736 | "out = await asyncio.gather(*[\n", 737 | " scatterreduce(Model(layers=1, batches=1, rank=i, dist=dist))\n", 738 | " for i in range(ranks)])\n", 739 | "out[0]" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "id": "261435f2", 746 | "metadata": { 747 | 
"lines_to_next_cell": 2 748 | }, 749 | "outputs": [], 750 | "source": [] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "id": "7e62da15", 755 | "metadata": {}, 756 | "source": [ 757 | "### Puzzle 4: Fully-Sharded Data Parallel\n", 758 | "\n", 759 | "Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Collect the gradients with scatter-reduce.\n", 760 | "\n", 761 | "* Total Steps: 20\n", 762 | "* Total Memory: 2300000" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "id": "43a35535", 769 | "metadata": { 770 | "lines_to_next_cell": 2 771 | }, 772 | "outputs": [], 773 | "source": [ 774 | "async def fsdp(model: Model) -> Model:\n", 775 | " # Storage on device.\n", 776 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 777 | "\n", 778 | " # Load all the activations\n", 779 | " model.activations[0] = model.get_activation([model.rank])\n", 780 | "\n", 781 | " # Load a shard of the weights for every layer. Load in the full weights\n", 782 | " for l in range(model.LAYERS):\n", 783 | " weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) \n", 784 | "\n", 785 | " assert False, 'TODO: Implement me'\n", 786 | " for l in range(model.LAYERS):\n", 787 | " model.set_final_weight(l, weights[l])\n", 788 | " return model" 789 | ] 790 | }, 791 | { 792 | "cell_type": "code", 793 | "execution_count": null, 794 | "id": "dec61bda", 795 | "metadata": {}, 796 | "outputs": [], 797 | "source": [ 798 | "dist = Dist(ranks)\n", 799 | "out = await asyncio.gather(*[\n", 800 | " fsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))\n", 801 | " for i in range(ranks)])\n", 802 | "draw_group(out[1].final_weights)" 803 | ] 804 | }, 805 | { 806 | "cell_type": "code", 807 | "execution_count": null, 808 | "id": "b9f62f28", 809 | "metadata": {}, 810 | "outputs": [], 811 | "source": [ 812 | "draw(out)" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": null, 818 | "id": "3a527fc7", 819 | "metadata": {}, 820 | "outputs": [], 821 | "source": [ 822 | "Model.check(out)" 823 | ] 824 | }, 825 | { 826 | "cell_type": "markdown", 827 | "id": "0858bb4f", 828 | "metadata": {}, 829 | "source": [ 830 | "## Communication: Point-to-Point\n", 831 | "\n", 832 | "An alternative approach to communication is to directly communicate specific information between GPUs. In our model, both GPUs talking to each other block and wait for the handoff. " 833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": null, 838 | "id": "de38df4d", 839 | "metadata": { 840 | "lines_to_next_cell": 2 841 | }, 842 | "outputs": [], 843 | "source": [ 844 | "async def talk(model: Model) -> None:\n", 845 | " if model.rank == 0:\n", 846 | " await model.pass_to(1, \"extra cheese\")\n", 847 | " val = await model.receive()\n", 848 | " print(val)\n", 849 | " else:\n", 850 | " val = await model.receive()\n", 851 | " print(val)\n", 852 | " val = await model.pass_to(0, \"pizza\")\n", 853 | "\n", 854 | "dist = Dist(2)\n", 855 | "result = await asyncio.gather(*[\n", 856 | " talk(Model(layers=1, batches=1, rank=i, dist=dist))\n", 857 | " for i in range(2)])" 858 | ] 859 | }, 860 | { 861 | "cell_type": "markdown", 862 | "id": "027b159c", 863 | "metadata": {}, 864 | "source": [ 865 | "### Puzzle 5: Pipeline Parallelism\n", 866 | "\n", 867 | "Split the layer weights and optimizers equally between GPUs. Have each GPU handle only its layer. 
Pass the full set of batches for activations and grad_activations between layers using p2p communication. No need for any global communication.\n", 868 | "\n", 869 | "* Total Steps: 66\n", 870 | "* Total Memory: 3300000" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "execution_count": null, 876 | "id": "09feb2a6", 877 | "metadata": { 878 | "lines_to_next_cell": 2 879 | }, 880 | "outputs": [], 881 | "source": [ 882 | "async def pipeline(model: Model) -> Model:\n", 883 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 884 | " per_rank = model.LAYERS // model.RANKS\n", 885 | " my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])\n", 886 | " for l in my_layers:\n", 887 | " weights[l], opt_states[l] = model.load_weights(l)\n", 888 | " assert False, 'TODO: Implement me'\n", 889 | " for l in my_layers:\n", 890 | " model.set_final_weight(l, weights[l])\n", 891 | " return model" 892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": null, 897 | "id": "2e5c381b", 898 | "metadata": {}, 899 | "outputs": [], 900 | "source": [ 901 | "dist = Dist(ranks)\n", 902 | "out = await asyncio.gather(*[\n", 903 | " pipeline(Model(layers=8, batches=ranks, rank=i, dist=dist))\n", 904 | " for i in range(ranks)])\n", 905 | "draw_group(out[1].final_weights)" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": null, 911 | "id": "5a99ecad", 912 | "metadata": {}, 913 | "outputs": [], 914 | "source": [ 915 | "draw(out)" 916 | ] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "execution_count": null, 921 | "id": "2b2f11d5", 922 | "metadata": {}, 923 | "outputs": [], 924 | "source": [ 925 | "Model.check(out)" 926 | ] 927 | }, 928 | { 929 | "cell_type": "markdown", 930 | "id": "f6606efc", 931 | "metadata": {}, 932 | "source": [ 933 | "### Puzzle 6: GPipe Schedule\n", 934 | "\n", 935 | "A major issue with the pipeline approach is that it causes a \"bubble\", i.e. time in the later layers waiting for the earlier layers to complete. An alternative approach is to split the batches smaller so you can pass them earlier. \n", 936 | "\n", 937 | "In this puzzle, you should run each batch by itself, and then pass. The graph should look similar as the one above but with a smaller bubble. 
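A sketch of the forward half of that per-batch schedule on one stage (the batch-indexed keys and the first-rank handling are assumptions; the backward half and the updates are omitted):

# Hypothetical sketch: push one micro-batch at a time through this stage and hand it
# to the next rank as soon as it is done, instead of waiting for all batches.
async def stage_forward_microbatched(model: Model, weights, activations, my_layers) -> None:
    for b in range(model.BATCHES):
        if model.rank == 0:
            activations[my_layers[0], b] = model.get_activation([b])
        else:
            activations[my_layers[0], b] = await model.receive()
        for l in my_layers:
            activations[l + 1, b] = model.forward(l, activations[l, b], weights[l])
        if model.rank < model.RANKS - 1:
            await model.pass_to(model.rank + 1, activations[my_layers[-1] + 1, b])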
\n", 938 | "\n", 939 | "* Total Steps: 33\n", 940 | "* Total Memory: 4100000" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "execution_count": null, 946 | "id": "f5f33513", 947 | "metadata": { 948 | "lines_to_next_cell": 2 949 | }, 950 | "outputs": [], 951 | "source": [ 952 | "async def gpipe(model: Model) -> Model:\n", 953 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 954 | " per_rank = model.LAYERS // model.RANKS\n", 955 | " my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])\n", 956 | " for l in my_layers:\n", 957 | " weights[l], opt_states[l] = model.load_weights(l)\n", 958 | "\n", 959 | " assert False, 'TODO: Implement me'\n", 960 | " for l in my_layers:\n", 961 | " model.set_final_weight(l, weights[l])\n", 962 | "\n", 963 | " return model" 964 | ] 965 | }, 966 | { 967 | "cell_type": "code", 968 | "execution_count": null, 969 | "id": "f5c73b29", 970 | "metadata": {}, 971 | "outputs": [], 972 | "source": [ 973 | "dist = Dist(ranks)\n", 974 | "out = await asyncio.gather(*[\n", 975 | " gpipe(Model(layers=8, batches=ranks, rank=i, dist=dist))\n", 976 | " for i in range(ranks)])\n", 977 | "draw_group(out[1].final_weights)" 978 | ] 979 | }, 980 | { 981 | "cell_type": "code", 982 | "execution_count": null, 983 | "id": "0e759da9", 984 | "metadata": {}, 985 | "outputs": [], 986 | "source": [ 987 | "draw(out)" 988 | ] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": null, 993 | "id": "47d3102b", 994 | "metadata": { 995 | "lines_to_next_cell": 2 996 | }, 997 | "outputs": [], 998 | "source": [ 999 | "Model.check(out)" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "markdown", 1004 | "id": "3a4bc062", 1005 | "metadata": {}, 1006 | "source": [ 1007 | "### Puzzle 7: Pipeline + FSDP\n", 1008 | "\n", 1009 | "As a last exercise, we can put everything together. Here we are going to run a combination of pipeline parallelism while also sharding our weight between 16 different machines. Here the model only has 4 layers, so we will assign 4 GPUs to each layer in the pipeline parallel approach. \n", 1010 | "\n", 1011 | "This example requires combining both collective communication and p2p communication effectively. 
\n", 1012 | "\n", 1013 | "* Total Steps: 15\n", 1014 | "* Total Memory: 1000000" 1015 | ] 1016 | }, 1017 | { 1018 | "cell_type": "code", 1019 | "execution_count": null, 1020 | "id": "34757c26", 1021 | "metadata": {}, 1022 | "outputs": [], 1023 | "source": [ 1024 | "async def pipeline_fsdp(model: Model) -> Model:\n", 1025 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n", 1026 | " per_rank = model.LAYERS // (model.RANKS // 4)\n", 1027 | " my_layers = list([l + ((model.rank % 4) * per_rank) for l in range(per_rank)])\n", 1028 | " for l in range(model.LAYERS):\n", 1029 | " weights[l, 0], opt_states[l, 0] = model.load_weights(l, model.rank, model.RANKS)\n", 1030 | " def empty_grad(l):\n", 1031 | " return model.fake_grad(l, [])\n", 1032 | " assert False, 'TODO: Implement me'\n", 1033 | " for l in range(model.LAYERS):\n", 1034 | " model.set_final_weight(l, weights[l])\n", 1035 | " # Update\n", 1036 | " return model" 1037 | ] 1038 | }, 1039 | { 1040 | "cell_type": "code", 1041 | "execution_count": null, 1042 | "id": "35a0877b", 1043 | "metadata": { 1044 | "lines_to_next_cell": 2 1045 | }, 1046 | "outputs": [], 1047 | "source": [ 1048 | "dist = Dist(16)\n", 1049 | "out = await asyncio.gather(*[\n", 1050 | " pipeline_fsdp(Model(layers=4, batches=ranks, rank=i, dist=dist))\n", 1051 | " for i in range(16)])" 1052 | ] 1053 | }, 1054 | { 1055 | "cell_type": "code", 1056 | "execution_count": null, 1057 | "id": "d93c6256", 1058 | "metadata": {}, 1059 | "outputs": [], 1060 | "source": [ 1061 | "Model.check(out)\n", 1062 | "chalk.set_svg_height(1000)\n", 1063 | "chalk.set_svg_draw_height(1000) \n", 1064 | "\n", 1065 | "draw(out)" 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "markdown", 1070 | "id": "24393483", 1071 | "metadata": {}, 1072 | "source": [ 1073 | "### When does it make sense to combine?\n", 1074 | "\n", 1075 | "The goal of these exercises is to give you a sense of the different methods out there for distributed training. However, there is not currently a one size fits all approach for distributed training. The right choice will depend on the constants such as batch size, memory per GPU, communication overhead, implementation complexity, model size, and specifics of architecture. \n", 1076 | "\n", 1077 | "As an example of what's left to explore, this last method Pipeline + FSDP is often not a great choice due to the complexities of communication speed. And in fact GPipe + FSDP also gets you into a bad place. The paper [Breadth First Pipeline Parallelism](https://arxiv.org/pdf/2211.05953.pdf) proposes instead a combination of pipeline scheduling and communication. Here's what it looks like. 
" 1078 | ] 1079 | }, 1080 | { 1081 | "cell_type": "markdown", 1082 | "id": "35c9493a", 1083 | "metadata": {}, 1084 | "source": [ 1085 | "![image.png](https://github.com/srush/LLM-Training-Puzzles/assets/35882/f286089a-83bd-483c-b441-f154821d161c)" 1086 | ] 1087 | } 1088 | ], 1089 | "metadata": { 1090 | "jupytext": { 1091 | "cell_metadata_filter": "-all", 1092 | "custom_cell_magics": "kql" 1093 | }, 1094 | "kernelspec": { 1095 | "display_name": "venv", 1096 | "language": "python", 1097 | "name": "python3" 1098 | } 1099 | }, 1100 | "nbformat": 4, 1101 | "nbformat_minor": 5 1102 | } 1103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asyncio 2 | git+https://github.com/chalk-diagrams/chalk/ 3 | --------------------------------------------------------------------------------