├── .gitignore
├── Distributed.html
├── Distributed.ipynb
├── Distributed.py
├── LICENSE
├── README.md
├── drawing.py
├── lib.py
├── puzzles.ipynb
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/Distributed.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # cell_metadata_filter: -all
5 | # custom_cell_magics: kql
6 | # text_representation:
7 | # extension: .py
8 | # format_name: light
9 | # format_version: '1.5'
10 | # jupytext_version: 1.14.6
11 | # kernelspec:
12 | # display_name: venv
13 | # language: python
14 | # name: python3
15 | # ---
16 |
17 | # # LLM Training Puzzles
18 | #
19 | # by Sasha Rush ([@srush_nlp](https://twitter.com/srush_nlp))
20 |
21 | # %%capture
22 | # Uncomment to run in Colab
23 | # !pip install -qqq git+https://github.com/chalk-diagrams/chalk asyncio
24 | # !wget https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/lib.py https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/drawing.py
25 |
26 | from typing import List
27 | from lib import Model, Dist, WeightGrad
28 | from drawing import draw, draw_group
29 | from chalk import vcat
30 | import asyncio
31 | import chalk
32 | chalk.set_svg_height(400)
33 | chalk.set_svg_draw_height(600)
34 |
35 | # ## Preliminaries
36 | #
37 | # The goal of these puzzles is to learn about distributed training of LLMs. However, we will be primarily concerned with the speed and memory efficiency of completing a single update of the model. To make things simpler, we will abstract away from the standard tensor-based transformer model and just consider a stateless representation of each of the components of a multi-layer neural network.
38 | #
39 | #
40 |
41 | model = Model(layers=2, batches=4)
42 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
43 |
44 | # Our library has 5 parts:
45 | #
46 | # * Weights
47 | # * Optimizer States - Values needed to update the weights
48 | # * Activations - The internal values computed on the forward pass
49 | # * Grad Activations - The gradients of the loss with respect to the activations, needed for the backward pass
50 | # * Grad Weights - The gradients of the loss with respect to the weights, needed for updates
51 | #
52 | # For these puzzles, you are *not allowed* to have local variables. You need to store each of these in the dictionary corresponding to its type.
53 | #
54 | # We begin by tracing the lifecycle of a single model update.
55 |
56 | # Get the input activations to the model for batches 2, 3
57 | activations[0] = model.get_activation(batches=[2, 3])
58 | activations[0]
59 |
60 | # Load the weights (random) for layers 0 and 1
61 | for i in range(model.LAYERS):
62 | weights[i], opt_states[i] = model.load_weights(i)
63 | weights[0]
64 |
65 | # Activations can be moved forward a layer if you have the weights.
66 | activations[1] = model.forward(layer=0, inp=activations[0], weight=weights[0])
67 | activations[2] = model.forward(layer=1, inp=activations[1], weight=weights[1])
68 | activations[1]
69 |
70 | # Draw all the current activations in memory.
71 | draw_group(activations)
72 |
73 | # At the last layer, we can convert an activation to a grad activation by calling `loss`
74 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])
75 | grad_activations[model.LAYERS]
76 |
77 | # Calling `backward` requires the forward activation, the grad activation from the layer above, and the weights.
78 | # It returns the grad weights and the grad activation for the layer below.
79 | grad_weights[1], grad_activations[1] = model.backward(1, activations[1], grad_activations[2], weights[1])
80 | grad_weights[0], grad_activations[0] = model.backward(0, activations[0], grad_activations[1], weights[0])
81 | grad_activations[1]
82 |
83 | # We can use `del` to remove any values that are no longer needed and free their memory.
84 | print("Before memory:", model.memory())
85 | del grad_activations[1]
86 | print("After memory:", model.memory())
87 | model.status()
88 | draw_group(grad_activations)
89 |
90 | # Grad weights keep track of which batches they are for. Here we only have the grad weights for batches 2 and 3.
91 | draw_group(grad_weights)
92 |
93 | # If we try to update with grad weights that only cover some of the batches, we will get an error.
94 | try:
95 | model.update(0, weight_grad=grad_weights[0], weight=weights[0], opt_state=opt_states[0])
96 | except AssertionError as e:
97 | print("Error! Only have batches")
98 | print(e)
99 |
100 | # For this example, we can cheat and pretend we have the gradients for the other batches.
101 | grad_weights[0, 0] = model.fake_grad(0, [0,1])
102 | grad_weights[1, 0] = model.fake_grad(1, [0,1])
103 | grad_weights[0, 0]
104 |
105 |
106 | # Summing together grad_weights gives the full gradient.
107 | grad_weights[0] = grad_weights[0] + grad_weights[0, 0]
108 |
109 | # +
110 | # Now we can call update to get the new weights and opt_state.
111 | weights[0], opt_states[0] = model.update(0, weight_grad=grad_weights[0], weight=weights[0],
112 | opt_state=opt_states[0])
113 |
114 | # WARNING: You need to store every value in the storage dictionaries. Otherwise it is not counted towards memory.
115 | grad_weights[1] = grad_weights[1] + grad_weights[1, 0]
116 | weights[1], opt_states[1] = model.update(1, weight_grad=grad_weights[1],
117 | weight=weights[1], opt_state=opt_states[1])
118 | # -
119 |
120 |
121 | # We can complete the tests by setting these as the final weights and calling check.
122 | model.set_final_weight(0, weights[0])
123 | model.set_final_weight(1, weights[1])
124 | Model.check([model])
125 | draw_group(model.final_weights)
126 |
127 | # We can view the final outcome of the system as a diagram.
128 | # This shows the forward and backward passes (labeled with their batch numbers) and the updates.
129 | # The bars on the bottom show the memory in use at each time step.
130 | draw([model])
131 |
132 |
133 |
134 | # ### Puzzle 0 - Standard Training
135 | #
136 | # Write a standard (non-distributed) training loop that acts on all the batches and loads all the weights. It should just run forward, loss, backward, and update. Aim to minimize the maximum memory used; a rough estimate of where the memory target comes from is sketched below.
137 | #
138 | # * Target Time: 17 steps
139 | # * Target Memory: 2600000
140 |
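# Where does the memory target roughly come from? In lib.py a full layer of weights
# (and likewise its optimizer state) costs HIDDEN * HIDDEN = 512 * 512 units, and the
# activations for one batch at one layer cost HIDDEN * LENGTH = 512 * 256. The
# back-of-the-envelope estimate below is our own sketch; the exact peak depends on when
# values are deleted.

weight_mem = 512 * 512                                # one layer of weights (or one optimizer state)
act_mem = 4 * 512 * 256                               # activations for all 4 batches at one layer
weights_and_opt = 2 * weight_mem + 2 * weight_mem     # 2 layers of weights plus 2 optimizer states
peak = weights_and_opt + 3 * act_mem                  # plus live activations at layers 0, 1, 2
print(peak)                                           # 2621440, in line with the 2,600,000 target
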
141 | def basic(model: Model) -> Model:
142 | # Storage on device.
143 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
144 |
145 | # Load in the full weights
146 | for l in range(model.LAYERS):
147 | weights[l], opt_states[l] = model.load_weights(l)
148 |
149 | # Load the input layer activations
150 | activations[0] = model.get_activation(range(model.BATCHES))
151 |
152 | ## USER CODE
153 | # Forward
154 | for l in range(model.LAYERS):
155 | activations[l + 1] = model.forward(l, activations[l], weights[l])
156 |
157 | # Backward
158 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])
159 | del activations[model.LAYERS]
160 |
161 | for l in range(model.LAYERS - 1, -1, -1):
162 | grad_weights[l], grad_activations[l] = model.backward(
163 | l, activations[l], grad_activations[l + 1], weights[l]
164 | )
165 | del grad_activations[l + 1], activations[l]
166 | del grad_activations[0]
167 |     assert len(grad_activations) == 0 and len(activations) == 0
168 |
169 | # Update
170 | for l in range(model.LAYERS):
171 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
172 | ## END USER CODE
173 |
174 | for l in range(model.LAYERS):
175 | model.set_final_weight(l, weights[l])
176 | return model
177 |
178 |
179 | out = basic(Model(layers=2, batches=4, rank=0, dist=Dist(1)))
180 | draw_group(out.final_weights)
181 |
182 | draw([out])
183 |
184 | Model.check([out])
185 |
186 |
187 | # ### Puzzle 1 - Gradient Accumulation
188 | #
189 | # For this puzzle, the goal is to reduce the maximum memory usage. To do so you are going to run on each batch individually instead of all together.
190 | #
191 | # Write a function with four parts: first run on batch {0}, then {1}, and so on. Sum the grad weights and then update.
192 | #
193 | # * Target Time: 17 steps
194 | # * Target Memory: 2000000
195 |
196 | def grad_accum(model: Model) -> Model:
197 | # Storage on device.
198 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
199 |
200 | # Load in the full weights
201 | for l in range(model.LAYERS):
202 | weights[l], opt_states[l] = model.load_weights(l)
203 |
204 | ## USER CODE
205 | for r in range(model.BATCHES):
206 | # Load the input layer activations
207 | activations[0, r] = model.get_activation([r])
208 |
209 | ## USER CODE
210 | # Forward
211 | for l in range(model.LAYERS):
212 | activations[l + 1, r] = model.forward(l, activations[l, r], weights[l])
213 |
214 | # Backward
215 | grad_activations[model.LAYERS, r] = model.loss(activations[model.LAYERS, r])
216 | del activations[model.LAYERS, r]
217 |
218 | for l in range(model.LAYERS - 1, -1, -1):
219 | grad_weights[l, r], grad_activations[l, r] = model.backward(
220 | l, activations[l, r], grad_activations[l + 1, r], weights[l]
221 | )
222 | if r == 0:
223 | grad_weights[l] = grad_weights[l, r]
224 | else:
225 | grad_weights[l] = grad_weights[l] + grad_weights[l, r]
226 | del grad_activations[l + 1, r], activations[l,r], grad_weights[l, r]
227 | del grad_activations[0, r]
228 | assert len(grad_activations) == 0 and len(activations) == 0
229 |
230 | # Update
231 | for l in range(model.LAYERS):
232 | weights[l], opt_states[l] = \
233 | model.update(l,
234 | grad_weights[l], weights[l], opt_states[l])
235 |
236 | ## END USER CODE
237 | for l in range(model.LAYERS):
238 | model.set_final_weight(l, weights[l])
239 | return model
240 |
241 |
242 | out = grad_accum(Model(layers=2, batches=4, rank=0, dist=Dist(1)))
243 | draw_group(out.final_weights)
244 |
245 | draw([out])
246 |
247 | Model.check([out])
248 |
249 | # ## Communications: AllReduce
250 |
251 | # When working with multiple GPUs we need communication between them.
252 | # The primary communication primitives for GPUs are implemented in NCCL.
253 | #
254 | # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html
255 | #
256 | # We are not going to use these directly, but simulate them using Python and asyncio.
257 | #
258 | # The first operation is AllReduce. We will have 4 GPUs (ranks=4) and use them each to compute a batch of weight grads.
259 |
260 | ranks = 4
261 | weight_grads = [WeightGrad(0, 1, {i}, ranks) for i in range(ranks)]
262 | weight_grads[0] + weight_grads[1] + weight_grads[2] + weight_grads[3]
263 |
264 | # Simple asynchronous function that calls allreduce to sum the weight grads at layer 0
265 | async def myfunc(model: Model) -> WeightGrad:
266 | return await model.allreduce(weight_grads[model.rank], 0)
267 |
268 | # This code uses asyncio to run the above function on 4 "GPUs".
269 | dist = Dist(ranks)
270 | out_weight_grads = await asyncio.gather(*[
271 | myfunc(Model(layers=1, batches=1, rank=i, dist=dist))
272 | for i in range(ranks)])
273 | out_weight_grads[0]
274 |
275 | # Note: When running communication operations like AllReduce on a GPU, the communication happens in parallel with the computation on that GPU. The AllReduce API therefore does not block, and lets the model keep computing while the reduction completes. This means it is beneficial to issue AllReduce (and other communication) as early as possible so that other compute can run during the reduction.
276 | #
277 | # We will ignore this in these puzzles and represent communication as happening efficiently.
278 |
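# A small sketch (our own illustration, not one of the puzzles) of what the non-blocking
# pattern looks like with asyncio: start the allreduce as a task, let other compute run,
# and only await the result when it is actually needed.

async def myfunc_overlapped(model: Model) -> WeightGrad:
    handle = asyncio.ensure_future(model.allreduce(weight_grads[model.rank], 0))
    # ... other forward/backward compute could run here while the reduction is in flight ...
    return await handle

dist = Dist(ranks)
overlapped = await asyncio.gather(*[
    myfunc_overlapped(Model(layers=1, batches=1, rank=i, dist=dist))
    for i in range(ranks)])
overlapped[0]
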
279 | # ### Puzzle 2 - Distributed Data Parallel
280 | #
281 | # Run each batch on its own GPU (rank) in parallel rather than one at a time on a single device. Use allreduce to sum the grad weights across GPUs and then update on every machine. The main benefit of this approach is compute efficiency over gradient accumulation.
282 | #
283 | # * Total Steps: 5
284 | # * Total Memory: 1800000
285 |
286 | async def ddp(model: Model) -> Model:
287 | # Storage on device.
288 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
289 |     # Load the input activations for this rank's batch
290 | model.activations[0] = model.get_activation([model.rank])
291 |
292 | ## USER CODE
293 |
294 | # Load in the full weights
295 | for l in range(model.LAYERS):
296 | weights[l], opt_states[l] = model.load_weights(l)
297 |
298 | # Forward
299 | for l in range(model.LAYERS):
300 | activations[l + 1] = model.forward(l, activations[l], weights[l])
301 |
302 | # Backward
303 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])
304 |
305 | for l in range(model.LAYERS - 1, -1, -1):
306 | grad_weights[l], grad_activations[l] = model.backward(
307 | l, activations[l], grad_activations[l + 1], weights[l]
308 | )
309 | del grad_activations[l + 1], activations[l]
310 |
311 | # Update
312 | for l in range(model.LAYERS):
313 | grad_weights[l] = await model.allreduce(grad_weights[l], l)
314 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
315 |
316 | ## END USER CODE
317 | for l in range(model.LAYERS):
318 | model.set_final_weight(l, weights[l])
319 | return model
320 |
321 |
322 | dist = Dist(ranks)
323 | out = await asyncio.gather(*[
324 | ddp(Model(layers=2, batches=ranks, rank=i, dist=dist))
325 | for i in range(ranks)])
326 | draw_group(out[0].final_weights)
327 |
328 | draw(out)
329 |
330 | Model.check(out)
331 |
332 | # ## Communication: AllGather / Sharding
333 | #
334 | # Our next primitive is AllGather. This allows us to communicate "shards" of an object stored on different GPUs to all the GPUs.
335 |
336 | # Load only one shard (out of four) of the weights for layer 0.
337 | model = Model(layers=2, batches=1, rank=0, dist=Dist(1))
338 | weight, _ = model.load_weights(0, shard=0, total=4)
339 | weight
340 |
341 | # Combine together two shards on one machine.
342 | weights = [model.load_weights(0, shard=i, total=ranks)[0] for i in range(ranks)]
343 | weights[0].combine(weights[2])
344 |
345 | # +
346 | # Use allgather to collect the shards from all machines.
347 | async def mygather(model: Model) -> WeightGrad:
348 |     # Allgather collects the weight shards from all machines
349 | return await model.allgather(weights[model.rank], 0)
350 |
351 | dist = Dist(ranks)
352 | out_weights = await asyncio.gather(*[
353 | mygather(Model(layers=1, batches=1, rank=i, dist=dist))
354 | for i in range(ranks)])
355 | out_weights[0]
356 | # -
357 |
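# The gathered weight now covers all four shards, so it passes the `is_complete()` check
# that `forward` and `backward` require (a small check of our own).
assert out_weights[0].is_complete()
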
358 | # ### Puzzle 3: Weight-Sharded Data Parallel
359 | #
360 | # Run a model that shards each layer weight over all the machines. Reconstruct the full layer weight at each layer using allgather. Finally, sum the grad weights with allreduce and update the weights on each machine.
361 | #
362 | # * Total Steps: 20
363 | # * Total Memory: 2800000
364 |
365 | async def wsdp(model: Model) -> Model:
366 | # Storage on device.
367 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
368 |
369 |     # Load the input activations for this rank's batch
370 | model.activations[0] = model.get_activation([model.rank])
371 |
372 |     # Load a shard of the weights and optimizer state for every layer.
373 | for l in range(model.LAYERS):
374 | weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS)
375 |
376 | ## USER CODE
377 | # Forward
378 | for l in range(model.LAYERS):
379 | weights[l, 0] = await model.allgather(weights[l], l)
380 | activations[l + 1] = model.forward(l, activations[l], weights[l, 0])
381 | del weights[l, 0]
382 |
383 | # Backward
384 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])
385 |
386 | for l in range(model.LAYERS - 1, -1, -1):
387 | weights[l, 0] = await model.allgather(weights[l], l)
388 | grad_weights[l], grad_activations[l] = model.backward(
389 | l, activations[l], grad_activations[l + 1], weights[l, 0]
390 | )
391 | del grad_activations[l + 1], activations[l], weights[l, 0]
392 |
393 | # Update
394 | for l in range(model.LAYERS):
395 | grad_weights[l] = await model.allreduce(grad_weights[l], l)
396 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
397 |
398 | ## END USER CODE
399 | for l in range(model.LAYERS):
400 | model.set_final_weight(l, weights[l])
401 |
402 | return model
403 |
404 | dist = Dist(ranks)
405 | out = await asyncio.gather(*[
406 | wsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))
407 | for i in range(ranks)])
408 | draw_group(out[1].final_weights)
409 |
410 | draw(out)
411 |
412 | Model.check(out)
413 |
414 | # ## Communication: Scatter-Reduce
415 |
416 | # Scatter-Reduce sums (reduces) the grad weights across batches while scattering the result,
417 | # so that each rank keeps only its own shard. In lib.py it is simulated as an allreduce followed by taking your shard.
418 |
419 | grad_weight = WeightGrad(0, 1, batches={1}, total_batches=4,
420 | shards={1}, total=4)
421 | grad_weight
422 |
423 | grad_weights = {i: WeightGrad(0, 1, batches={i}, total_batches=4,
424 | shards={0,1,2,3}, total=4) for i in range(4)}
425 | grad_weights[2]
426 |
427 | # +
428 | async def scatterreduce(model: Model) -> WeightGrad:
429 |     # Scatter-reduce sums the weight grads and returns only this rank's shard
430 | return await model.scatterreduce(grad_weights[model.rank], 0)
431 |
432 | dist = Dist(ranks)
433 | out = await asyncio.gather(*[
434 | scatterreduce(Model(layers=1, batches=1, rank=i, dist=dist))
435 | for i in range(ranks)])
436 | out[0]
437 | # -
438 |
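# As a sanity check (our own, not part of the puzzles): scatter-reduce gives the same result
# as summing all the grad weights and then keeping only this rank's shard, which is exactly
# how `Dist.scatterreduce` is simulated in lib.py.
full = grad_weights[0] + grad_weights[1] + grad_weights[2] + grad_weights[3]
assert out[2] == full.shard(2, total=4)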
439 |
440 |
441 | # ### Puzzle 4: Fully-Sharded Data Parallel
442 | #
443 | # Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Collect the gradients with scatter-reduce.
444 | #
445 | # * Total Steps: 20
446 | # * Total Memory: 2300000
447 |
448 | async def fsdp(model: Model) -> Model:
449 | # Storage on device.
450 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
451 |
452 |     # Load the input activations for this rank's batch
453 | model.activations[0] = model.get_activation([model.rank])
454 |
455 |     # Load a shard of the weights and optimizer state for every layer.
456 | for l in range(model.LAYERS):
457 | weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS)
458 |
459 | ## USER CODE
460 | # Forward
461 | for l in range(model.LAYERS):
462 | weights[l, 0] = await model.allgather(weights[l], l)
463 | activations[l + 1] = model.forward(l, activations[l], weights[l, 0])
464 | del weights[l, 0]
465 |
466 | # Backward
467 | grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])
468 |     del activations[model.LAYERS]
469 |
470 | for l in range(model.LAYERS - 1, -1, -1):
471 | weights[l, 0] = await model.allgather(weights[l], l)
472 | grad_weights[l], grad_activations[l] = model.backward(
473 | l, activations[l], grad_activations[l + 1], weights[l, 0]
474 | )
475 | grad_weights[l] = await model.scatterreduce(grad_weights[l], l)
476 | del grad_activations[l + 1], activations[l], weights[l, 0]
477 |
478 | # Update
479 | for l in range(model.LAYERS):
480 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
481 |
482 | ## END USER CODE
483 | for l in range(model.LAYERS):
484 | model.set_final_weight(l, weights[l])
485 | return model
486 |
487 |
488 | dist = Dist(ranks)
489 | out = await asyncio.gather(*[
490 | fsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))
491 | for i in range(ranks)])
492 | draw_group(out[1].final_weights)
493 |
494 | draw(out)
495 |
496 | Model.check(out)
497 |
498 | # ## Communication: Point-to-Point
499 | #
500 | # An alternative approach to communication is to directly pass specific information between two GPUs. In our model, both GPUs involved in the exchange block and wait for the handoff.
501 |
502 | # +
503 | async def talk(model: Model) -> None:
504 | if model.rank == 0:
505 | await model.pass_to(1, "extra cheese")
506 | val = await model.receive()
507 | print(val)
508 | else:
509 | val = await model.receive()
510 | print(val)
511 |         await model.pass_to(0, "pizza")
512 |
513 | dist = Dist(2)
514 | result = await asyncio.gather(*[
515 | talk(Model(layers=1, batches=1, rank=i, dist=dist))
516 | for i in range(2)])
517 | # -
518 |
519 |
520 | # ### Puzzle 5: Pipeline Parallelism
521 | #
522 | # Split the layer weights and optimizers equally between GPUs. Have each GPU handle only its own layers. Pass the full set of batch activations and grad_activations between layers using p2p communication. No need for any global communication.
523 | #
524 | # * Total Steps: 66
525 | # * Total Memory: 3300000
526 |
527 | async def pipeline(model: Model) -> Model:
528 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
529 | per_rank = model.LAYERS // model.RANKS
530 | my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])
531 | for l in my_layers:
532 | weights[l], opt_states[l] = model.load_weights(l)
533 | ## USER CODE
534 |
535 | if model.rank == 0:
536 | activations[0] = model.get_activation(range(model.BATCHES))
537 | else:
538 | activations[my_layers[0]] = await model.receive()
539 |
540 | # Forward
541 | for l in my_layers:
542 | activations[l + 1] = model.forward(l, activations[l], weights[l])
543 |
544 | # Backward
545 | if model.rank == model.RANKS - 1:
546 | grad_activations[model.LAYERS] = model.loss(
547 | activations[model.LAYERS]
548 | )
549 | else:
550 | await model.pass_to(model.rank + 1, activations[l + 1])
551 | grad_activations[l + 1] = await model.receive()
552 |
553 | for l in reversed(my_layers):
554 | grad_weights[l], grad_activations[l] = model.backward(
555 | l, activations[l], grad_activations[l + 1], model.weights[l]
556 | )
557 | del model.grad_activations[l + 1], model.activations[l]
558 |
559 | if model.rank != 0:
560 | await model.pass_to(model.rank - 1, grad_activations[l])
561 |
562 | # Update
563 | for l in my_layers:
564 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
565 |
566 | ## END USER CODE
567 | for l in my_layers:
568 | model.set_final_weight(l, weights[l])
569 | return model
570 |
571 |
572 | dist = Dist(ranks)
573 | out = await asyncio.gather(*[
574 | pipeline(Model(layers=8, batches=ranks, rank=i, dist=dist))
575 | for i in range(ranks)])
576 | draw_group(out[1].final_weights)
577 |
578 | draw(out)
579 |
580 | Model.check(out)
581 |
582 | # ### Puzzle 6: GPipe Schedule
583 | #
584 | # A major issue with the pipeline approach is that it causes a "bubble", i.e. time spent in the later layers waiting for the earlier layers to complete. An alternative approach is to split the batches into smaller microbatches so you can pass them along earlier.
585 | #
586 | # In this puzzle, you should run each batch by itself and then pass it along. The graph should look similar to the one above but with a smaller bubble; a rough estimate of the bubble size is sketched below.
587 | #
588 | # * Total Steps: 33
589 | # * Total Memory: 4100000
590 |
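# Rough bubble-size arithmetic (our own approximation, not something lib.py computes):
# with P pipeline stages and M microbatches of equal cost, the idle "bubble" fraction is
# roughly (P - 1) / (M + P - 1), so splitting into more microbatches shrinks the bubble.

def bubble_fraction(stages: int, microbatches: int) -> float:
    return (stages - 1) / (microbatches + stages - 1)

print(bubble_fraction(4, 1))   # one block per stage, as in Puzzle 5 -> 0.75
print(bubble_fraction(4, 4))   # GPipe with 4 microbatches -> ~0.43
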
591 | async def gpipe(model: Model) -> Model:
592 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
593 | per_rank = model.LAYERS // model.RANKS
594 | my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])
595 | for l in my_layers:
596 | weights[l], opt_states[l] = model.load_weights(l)
597 |
598 | # USER CODE
599 | for mb in range(model.BATCHES):
600 | # Forward
601 | if model.rank == 0:
602 | activations[0, mb] = model.get_activation([mb])
603 | else:
604 | activations[my_layers[0], mb] = await model.receive()
605 |
606 | for l in my_layers:
607 | activations[l + 1, mb] = model.forward(l, activations[l, mb], weights[l])
608 | if model.rank != model.RANKS - 1:
609 | await model.pass_to(model.rank + 1, activations[l + 1, mb])
610 |
611 | for mb in range(model.BATCHES):
612 | # Backward
613 | if model.rank == model.RANKS - 1:
614 | grad_activations[model.LAYERS, mb] = model.loss(
615 | activations[model.LAYERS, mb]
616 | )
617 | else:
618 | grad_activations[my_layers[-1] + 1, mb] = await model.receive()
619 |
620 | for l in reversed(my_layers):
621 | grad_weights[l, mb], grad_activations[l, mb] = model.backward(
622 | l, activations[l, mb], grad_activations[l + 1, mb], weights[l]
623 | )
624 | del grad_activations[l + 1, mb], activations[l, mb]
625 |
626 | if model.rank != 0:
627 | await model.pass_to(model.rank - 1, grad_activations[l, mb])
628 |
629 | # Update
630 | for l in reversed(my_layers):
631 | for mb in range(model.BATCHES):
632 | if mb != 0:
633 | grad_weights[l] = grad_weights[l] + grad_weights[l, mb]
634 | else:
635 | grad_weights[l] = grad_weights[l, 0]
636 | del grad_weights[l, mb]
637 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l], opt_states[l])
638 |
639 | ## END USER CODE
640 | for l in my_layers:
641 | model.set_final_weight(l, weights[l])
642 |
643 | return model
644 |
645 |
646 | dist = Dist(ranks)
647 | out = await asyncio.gather(*[
648 | gpipe(Model(layers=8, batches=ranks, rank=i, dist=dist))
649 | for i in range(ranks)])
650 | draw_group(out[1].final_weights)
651 |
652 | draw(out)
653 |
654 | Model.check(out)
655 |
656 |
657 | # ### Puzzle 7: Pipeline + FSDP
658 | #
659 | # As a last exercise, we can put everything together. Here we are going to run pipeline parallelism while also sharding our weights across 16 different machines. The model only has 4 layers, so we will assign 4 GPUs to each layer in the pipeline-parallel approach; the small rank-mapping printout below shows how the 16 ranks are assigned.
660 | #
661 | # This example requires combining both collective communication and p2p communication effectively.
662 | #
663 | # * Total Steps: 15
664 | # * Total Memory: 1000000
665 |
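# For intuition, a small printout of our own showing how the 16 ranks map onto the pipeline
# and data-parallel dimensions in the solution below: rank % 4 picks the pipeline stage
# (which layer the GPU owns) and rank // 4 picks which batch it processes.

for rank in range(16):
    print(f"rank {rank:2d} -> pipeline stage (layer) {rank % 4}, batch {rank // 4}")
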
666 | async def pipeline_fsdp(model: Model) -> Model:
667 | weights, opt_states, activations, grad_activations, grad_weights = model.storage()
668 | per_rank = model.LAYERS // (model.RANKS // 4)
669 | my_layers = list([l + ((model.rank % 4) * per_rank) for l in range(per_rank)])
670 | for l in range(model.LAYERS):
671 | weights[l, 0], opt_states[l, 0] = model.load_weights(l, model.rank, model.RANKS)
672 | def empty_grad(l):
673 | return model.fake_grad(l, [])
674 | ## USER CODE
675 | # Forward
676 | for l in range(model.LAYERS):
677 | if l == my_layers[0]:
678 | if model.rank % 4 == 0:
679 | activations[0] = model.get_activation([model.rank // 4])
680 | else:
681 | activations[l] = await model.receive()
682 |
683 | weights[l] = await model.allgather(weights[l, 0], l)
684 | if l in my_layers:
685 | activations[l + 1] = model.forward(l, activations[l], weights[l])
686 | del weights[l]
687 | if l == my_layers[-1]:
688 |             if model.rank % 4 == 3:
689 | grad_activations[model.LAYERS] = model.loss(
690 | activations[model.LAYERS]
691 | )
692 | else:
693 | await model.pass_to(model.rank + 1, activations[l + 1])
694 | # Backward
695 |
696 | for l in reversed(range(model.LAYERS)):
697 | if l == my_layers[-1]:
698 | if model.rank % 4 != 3:
699 | grad_activations[l + 1] = await model.receive()
700 |
701 | weights[l] = await model.allgather(weights[l, 0], l)
702 | if l in my_layers:
703 | grad_weights[l], grad_activations[l] = model.backward(
704 | l, activations[l], grad_activations[l + 1], model.weights[l]
705 | )
706 | del grad_activations[l + 1], activations[l]
707 | grad_weights[l] = await model.scatterreduce(grad_weights[l], l)
708 | else:
709 | grad_weights[l] = await model.scatterreduce(empty_grad(l), l)
710 | del weights[l]
711 |
712 | if model.rank % 4 != 0 and l == my_layers[0]:
713 | await model.pass_to(model.rank - 1, grad_activations[l])
714 | for l in range(model.LAYERS):
715 | weights[l], opt_states[l] = model.update(l, grad_weights[l], weights[l, 0], opt_states[l, 0])
716 |
717 | # END USER CODE
718 | for l in range(model.LAYERS):
719 | model.set_final_weight(l, weights[l])
720 | # Update
721 | return model
722 |
723 | dist = Dist(16)
724 | out = await asyncio.gather(*[
725 | pipeline_fsdp(Model(layers=4, batches=ranks, rank=i, dist=dist))
726 | for i in range(16)])
727 |
728 |
729 | # +
730 | Model.check(out)
731 | chalk.set_svg_height(1000)
732 | chalk.set_svg_draw_height(1000)
733 |
734 | draw(out)
735 | # -
736 |
737 | # ### When does it make sense to combine?
738 | #
739 | # The goal of these exercises is to give you a sense of the different methods out there for distributed training. However, there is currently no one-size-fits-all approach to distributed training. The right choice depends on constants such as batch size, memory per GPU, communication overhead, implementation complexity, model size, and the specifics of the architecture.
740 | #
741 | # As an example of what's left to explore, this last method, Pipeline + FSDP, is often not a great choice due to the complexities of communication speed, and GPipe + FSDP runs into similar trouble. The paper [Breadth-First Pipeline Parallelism](https://arxiv.org/pdf/2211.05953.pdf) instead proposes a combination of pipeline scheduling and communication. Here's what it looks like.
742 |
743 | # 
744 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sasha Rush
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM Training Puzzles
2 | - by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp)
3 |
4 | 
5 |
6 |
7 | This is a collection of 8 challenging puzzles about training large language models (or really any NN) on many, many GPUs.
8 | Very few people actually get a chance to train on thousands of computers, but it is an interesting challenge and one that
9 | is critically important for modern AI. The goal of these puzzles is to get hands-on experience with the key primitives and to understand
10 | the goals of memory efficiency and compute pipelining.
11 |
12 |
13 | I recommend running in Colab. Click here and copy the notebook to get started.
14 |
15 | [](https://colab.research.google.com/github/srush/LLM-Training-Puzzles/blob/main/puzzles.ipynb)
16 |
17 | 
18 |
19 |
20 |
21 | If you are into this kind of thing, this is the 6th in a series of puzzles.
22 |
23 | * https://github.com/srush/gpu-puzzles
24 | * https://github.com/srush/tensor-puzzles
25 | * https://github.com/srush/autodiff-puzzles
26 | * https://github.com/srush/transformer-puzzles
27 | * https://github.com/srush/GPTworld
28 |
--------------------------------------------------------------------------------
/drawing.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Optional, FrozenSet
2 | from lib import Model, Activation, Weight, ActivationGrad, WeightGrad
3 |
4 | from chalk import rectangle, text, vcat, empty, hcat, Diagram, concat
5 | from colour import Color
6 |
7 | # backward_col = list(Color("Yellow").range_to("Red", LAYERS))
8 |
9 |
10 | def draw(models : Sequence[Model]) -> Diagram:
11 | TIME = 2
12 | layers = models[0].LAYERS
13 | forward_col = list(Color("green").range_to("red", layers + 1))
14 | MAXTIME = max(m.log[-1].time for m in models)
15 | MAXMEM = 0
16 | ARGMEM = None
17 | for i, m in enumerate(models):
18 | for event in m.log:
19 | if event.memory > MAXMEM:
20 | MAXMEM = event.memory
21 | ARGMEM = (i, event.time)
22 | print(f"Timesteps: {MAXTIME} \nMaximum Memory: {MAXMEM} at GPU: {ARGMEM[0]} time: {ARGMEM[1]}")
23 | def square(layer:int, time:int, length:int, s: str="") -> Diagram:
24 | PAD = 0.2
25 | return (
26 | (
27 | rectangle(length * TIME - 0.05, 0.9).line_width(0)
28 | + text(s, 0.9)
29 | .translate(0, 0.1)
30 | .line_width(0.05)
31 | .fill_color(Color("black"))
32 | )
33 | .align_l()
34 | .translate(time * TIME, 0)
35 | )
36 |
37 | def draw_forward(e):
38 | return square(e.layer, e.time, e.length, " ".join(map(str, e.batches))).fill_color(
39 | forward_col[e.layer]
40 | )
41 |
42 | def draw_backward(e):
43 | return (
44 | square(e.layer, e.time, e.length, " ".join(map(str, e.batches)))
45 | .fill_color(forward_col[e.layer])
46 | .fill_opacity(0.2)
47 | )
48 |
49 | def draw_update(e):
50 | return square(e.layer, e.time, e.length, "").fill_color(Color("yellow")).fill_opacity(0.5)
51 |
52 | def draw_allred(e):
53 | return square(e.layer, e.time, e.length).fill_color(forward_col[e.layer]).fill_opacity(0.5)
54 | def draw_pass(e):
55 | return square(0, e.time, e.length).fill_color(Color("grey")).fill_opacity(0.5)
56 |
57 | rows = []
58 |
59 | # Time
60 | for gpu in range(len(models)):
61 | d = empty()
62 | box = (
63 | rectangle(TIME * (MAXTIME + 1), 1).fill_color(Color("lightgrey")).align_l()
64 | )
65 | box = box + text(str(gpu), 1).line_width(0).with_envelope(
66 | rectangle(TIME, 1)
67 | ).translate(-TIME, 0).fill_color(Color("orange"))
68 | d += box
69 | for e in models[gpu].log:
70 | if e.typ == "forward":
71 | d += draw_forward(e)
72 | if e.typ == "backward":
73 | d += draw_backward(e)
74 | if e.typ == "update":
75 | d += draw_update(e)
76 | if e.typ in ["allreduce", "scatterreduce", "allgather"]:
77 | d += draw_allred(e)
78 | if e.typ in ["pass"]:
79 | d += draw_pass(e)
80 | rows.append(d)
81 | d = vcat(rows)
82 |
83 | rows = []
84 | for gpu in range(len(models)):
85 | row = rectangle(TIME * (MAXTIME + 1), 1).fill_color(Color("white")).align_l()
86 | row = row + text(str(gpu), 1).line_width(0).with_envelope(
87 | rectangle(TIME, 1)
88 | ).translate(-TIME, 0).fill_color(Color("grey"))
89 | for e in models[gpu].log:
90 | can = (
91 | rectangle(TIME * e.length, e.memory / (1.5 * MAXMEM))
92 | .align_b()
93 | .align_l()
94 | .line_width(0)
95 | .fill_color(Color("grey"))
96 | )
97 | row = row.align_b() + can.translate(TIME * e.time, 0)
98 | if gpu == ARGMEM[0]:
99 | row = row + rectangle(0.1, 1).translate(TIME * ARGMEM[1], 0).line_color(Color("red")).align_b()
100 |
101 | rows.append(row)
102 | d2 = vcat(rows)
103 | d = vcat([d, d2])
104 | # return rectangle(1.5, 0.5) + d.scale_uniform_to_x(1).center_xy()
105 | return rectangle(1.5, 8).line_width(0) + d.scale_uniform_to_y(len(models)).center_xy()
106 |
107 |
108 | def draw_network(layers:int, weight: Optional[int]=None, before:int=-1, after:int=100,
109 | shards:FrozenSet[int]=frozenset({}), total:int=1,
110 | batches: FrozenSet[int]=frozenset({0}), total_batches=1, is_grad: bool=False) -> Diagram:
111 | forward_col = list(Color("green").range_to("red", layers+1))
112 | def layer(l: int) -> Diagram:
113 | W = 3
114 | H = 1 if l < layers else 0
115 | layer = rectangle(W, H).line_width(0.2).align_b()
116 | shard_h = H * (1 / total)
117 | shard_w = W * (1 / total_batches)
118 |
119 | weight_shard = rectangle(shard_w, shard_h).line_width(0.01).align_t().align_l().line_color(Color("white"))
120 | weight_shards = concat([weight_shard.translate(batch * shard_w - (W/2),
121 | shard * shard_h - H)
122 | for shard in shards for batch in batches])
123 |
124 | connect_out = rectangle(1.5, 1.05).line_width(0.2).align_t()
125 | connect_w = 1.5 * (1 / total_batches)
126 | connect = rectangle(connect_w, 1).line_width(0.01).align_l().line_color(Color("white"))
127 | connect = concat([connect.translate(batch * connect_w - 1.5 * (1/2), 0.0)
128 | for batch in batches]).align_t().translate(0, 0.025)
129 |
130 | if l == weight:
131 | weight_shards = weight_shards.fill_color(forward_col[l])
132 | if is_grad:
133 | weight_shards = weight_shards.fill_opacity(0.5)
134 | else:
135 | weight_shards = empty()
136 | if l == before: #or (after != 100 and l <= after):
137 | connect = connect.fill_color(forward_col[l]).fill_opacity(1)
138 | elif l == after + 1:
139 | connect = connect.fill_color(forward_col[l]).fill_opacity(0.5)
140 | else:
141 | connect = empty()
142 | base = connect_out + layer
143 | return base, (connect + weight_shards).with_envelope(base)
144 | return vcat(reversed([layer(l)[0] for l in range(layers+1)])), vcat(reversed([layer(l)[1] for l in range(layers+1)]))
145 |
146 |
147 | def draw_group(group):
148 | group = list(group.values())
149 | base = group[0].draw()[0]
150 | return base + concat([g.draw()[1] for g in group])
151 |
152 |
153 | # hcat([base + Activation(2, 5, [0], 2).draw()[1],
154 | # base + Weight(2, 5, 0, (0,), 2).draw()[1],
155 | # base + WeightGrad(2, 5, {0}, 2, {1}, 2).draw()[1] + WeightGrad(2, 5, {1}, 2, {1}, 2).draw()[1],
156 | # base + ActivationGrad(2, 5, frozenset({1}), 2).draw()[1]], 1).render_svg("activation.svg", 500)
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import asyncio
3 | from dataclasses import dataclass
4 | from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, FrozenSet, TypeVar
5 | from chalk import Diagram
6 | import random
7 |
8 | class Barrier:
9 | """Sync across n ranks"""
10 | def __init__(self, target: int):
11 | self.counter = 0
12 | self.target = target
13 | self.lock = asyncio.Lock()
14 | self.round = 0
15 | self.done = 0
16 |
17 | async def wait(self, rank: int) -> None:
18 | while self.done > 0:
19 | await asyncio.sleep(0.01)
20 | async with self.lock:
21 | self.counter += 1
22 | while self.counter < self.target:
23 | await asyncio.sleep(0.01)
24 | self.done += 1
25 | if rank == 0:
26 | await self.reset()
27 |
28 | async def reset(self) -> None:
29 | while self.done < self.target:
30 | await asyncio.sleep(0.01)
31 | self.counter = 0
32 | self.done = 0
33 |
34 | T = TypeVar('T')
35 | class Reduceable(Protocol[T]):
36 | """
37 | A type that can be reduced.
38 | """
39 | def __add__(self, other: T) -> T:
40 | ...
41 |
42 | O = TypeVar('O')
43 | class Gatherable(Protocol[O]):
44 | """
45 | A type that can be sharded.
46 | """
47 | def shard(self, shard: int, total: int) -> O:
48 | ...
49 |
50 | def is_complete(self) -> bool:
51 | ...
52 |
53 | def combine(self, other: O) -> O:
54 | ...
55 |
56 |
57 | TO = TypeVar('TO')
58 | class ReduceableGatherable(Reduceable[TO], Gatherable[TO]):
59 | pass
60 |
61 |
62 | class Dist:
63 | def __init__(self, total: int) -> None:
64 | self.reduce: Optional[Any] = None
65 | self.gather: Optional[Any] = None
66 | self.ranks = total
67 | self.barrier = Barrier(total)
68 | self.queue : Sequence[asyncio.Queue[Any]] = [asyncio.Queue(maxsize=1) for i in range(total)]
69 | self.mtime = 0
70 |
71 | async def allreduce(self, rank: int, inp: T, time:int) -> Tuple[T, int]:
72 | if self.reduce is None:
73 | self.reduce = inp
74 | else:
75 | self.reduce = self.reduce + inp
76 | self.mtime = max(time, self.mtime)
77 | await self.barrier.wait(rank)
78 | q: T = self.reduce
79 | mtime = self.mtime
80 | await self.barrier.wait(rank)
81 | if rank == 0:
82 | self.reduce = None
83 | self.mtime = 0
84 | await self.barrier.wait(rank)
85 | return q, mtime
86 |
87 | async def allgather(self, rank: int, inp: O, time:int) -> Tuple[O, int]:
88 | if self.gather is None:
89 | self.gather = inp
90 | else:
91 | assert type(self.gather) == type(inp)
92 | self.gather = self.gather.combine(inp)
93 | self.mtime = max(time, self.mtime)
94 | await self.barrier.wait(rank)
95 | q: O = self.gather
96 | mtime = self.mtime
97 | await self.barrier.wait(rank)
98 | if rank == 0:
99 | self.gather = None
100 | self.mtime = 0
101 | await self.barrier.wait(rank)
102 | return q, mtime
103 |
104 | async def scatterreduce(self, rank: int, inp: TO, time:int) -> Tuple[TO, int]:
105 | x, time = await self.allreduce(rank, inp, time)
106 | y = x.shard(rank, self.ranks) # type: ignore
107 | return y, time # type: ignore
108 |
109 | async def receive(self, rank: int) -> Any:
110 | return await self.queue[rank].get()
111 |
112 | async def pass_to(self, rank: int, v: Any) -> None:
113 | await self.queue[rank].put(v)
114 |
115 |
116 | @dataclass
117 | class Weight(Gatherable["Weight"]):
118 | """
119 | The weights for a specific layer. Can be sharded.
120 |
121 | Required for forward and backward passes.
122 | """
123 | layer: int
124 | layers: int
125 | step: int
126 | shards: FrozenSet[int] = frozenset([0])
127 | total: int = 1
128 |
129 | def combine(self, other: Weight) -> Weight:
130 | return Weight(self.layer, self.layers, self.step, self.shards | other.shards, self.total)
131 |
132 | def memory(self) -> float:
133 | return (len(self.shards) / self.total) * HIDDEN * HIDDEN
134 |
135 | def shard(self, shard: int, total: int) -> Weight:
136 | assert self.is_complete()
137 | assert shard < total
138 | return Weight(self.layer, self.layers, self.step, frozenset([shard]), total)
139 |
140 | def is_complete(self) -> bool:
141 | return len(self.shards) == self.total
142 |
143 | def draw(self) -> Diagram:
144 | from drawing import draw_network
145 | return draw_network(self.layers, weight=self.layer,
146 | shards=self.shards, total=self.total)
147 | def _repr_svg_(self):
148 | d = self.draw()
149 | return (d[0] + d[1])._repr_svg_()
150 |
151 | HIDDEN = 512
152 | LENGTH = 256
153 | @dataclass
154 | class Activation:
155 | """
156 |     Activations needed at a specific layer for a specific set of batches.
157 | """
158 | layer: int
159 | layers: int
160 | batches: FrozenSet[int]
161 | total_batches: int
162 |
163 | def memory(self) -> int:
164 | return len(self.batches) * HIDDEN * LENGTH
165 |
166 | def draw(self) -> Diagram:
167 | from drawing import draw_network
168 | return draw_network(self.layers, before=self.layer,
169 | batches=self.batches, total_batches=self.total_batches)
170 |
171 | def _repr_svg_(self):
172 | d = self.draw()
173 | return (d[0] + d[1])._repr_svg_()
174 |
175 | @dataclass
176 | class WeightGrad(Reduceable["WeightGrad"], Gatherable["WeightGrad"]):
177 | """
178 | The gradient of the loss for a specific weight layer.
179 |
180 | May be sharded to correspond to different parts of the weights.
181 |
182 | May be split into different batches.
183 | """
184 |
185 |
186 | layer: int
187 | layers: int
188 | batches: FrozenSet[int]
189 | total_batches: int
190 | shards: FrozenSet[int] = frozenset([0])
191 | total: int = 1
192 |
193 | def __add__(self, other: WeightGrad) -> WeightGrad:
194 | assert self.layer == other.layer, "Only add same layer weight grads"
195 | assert self.shards == other.shards
196 | return WeightGrad(self.layer, self.layers, self.batches | other.batches, self.total_batches,
197 | self.shards, self.total)
198 |
199 | def combine(self, other: WeightGrad) -> WeightGrad:
200 | return WeightGrad(self.layer, self.layers, self.batches, self.total_batches,
201 | self.shards | other.shards, self.total)
202 |
203 | def memory(self) -> float:
204 | return (len(self.shards) / self.total) * HIDDEN * HIDDEN
205 |
206 | def shard(self, shard: int, total: int) -> WeightGrad:
207 | assert self.is_complete(), f"{self.shards} out of {self.total}"
208 | assert shard < total
209 | return WeightGrad(self.layer, self.layers, self.batches, self.total_batches, frozenset([shard]), total)
210 |
211 | def is_complete(self) -> bool:
212 | return len(self.shards) == self.total
213 |
214 | def draw(self) -> Diagram:
215 | from drawing import draw_network
216 | return draw_network(self.layers, weight=self.layer, shards=self.shards,
217 | batches=self.batches,
218 | total=self.total, total_batches=self.total_batches, is_grad=True)
219 |
220 | def _repr_svg_(self):
221 | d = self.draw()
222 | return (d[0] + d[1])._repr_svg_()
223 |
224 |
225 | @dataclass
226 | class OptState(Gatherable["OptState"]):
227 | """
228 | The state of the optimizer for a specific layer. Can be sharded.
229 |
230 |     In practice this represents Adam's saved values needed for optimization.
231 |
232 | Required for updating the weights.
233 | """
234 |
235 | layer: int
236 | layers: int
237 | step: int
238 | shards: FrozenSet[int] = frozenset([0,])
239 | total: int = 1
240 |
241 | def combine(self, other: OptState) -> OptState:
242 | return OptState(self.layer, self.layers, self.step, self.shards | other.shards, self.total)
243 |
244 | def memory(self) -> float:
245 | return HIDDEN * HIDDEN * (len(self.shards) / self.total)
246 |
247 | def draw(self) -> Diagram:
248 | from drawing import draw_network
249 | return draw_network(self.layers, before=self.layer, shards=self.shards, total=self.total)
250 |
251 | def _repr_svg_(self):
252 | d = self.draw()
253 | return (d[0] + d[1])._repr_svg_()
254 |
255 | @dataclass
256 | class ActivationGrad:
257 | """
258 | The gradient of the activations for a specific layer.
259 |
260 | May be split into different batches.
261 | """
262 |
263 | layer: int
264 | layers: int
265 | batches: FrozenSet[int]
266 | total_batches: int
267 |
268 | def memory(self) -> int:
269 | return len(self.batches) * HIDDEN * LENGTH
270 |
271 | def draw(self) -> Diagram:
272 | from drawing import draw_network
273 | return draw_network(self.layers, after=self.layer,
274 | batches=self.batches, total_batches=self.total_batches)
275 |
276 | def _repr_svg_(self):
277 | d = self.draw()
278 | return (d[0] + d[1])._repr_svg_()
279 |
280 |
281 | @dataclass
282 | class Event:
283 | "Internal representations of events in the model for the visualizer"
284 | typ: str
285 | layer: Optional[int]
286 | rank: int
287 | time: int
288 |     length: float
289 |     memory: float
290 | batches: FrozenSet[int] = frozenset()
291 |
292 |
293 | class Model:
294 | def __init__(self, rank: int=1, dist: Dist=Dist(1), layers: int=2, batches: int=1):
295 | self.rank = rank
296 | self.log: List[Event] = []
297 | self.dist = dist
298 | self.time = 0
299 | self.RANKS = dist.ranks
300 | self.LAYERS = layers
301 | self.BATCHES = batches
302 | self.final_weights: Dict[int, Weight] = {}
303 |
304 | self.weights: Dict[Any, Weight] = {}
305 | self.opt_states: Dict[Any, OptState] = {}
306 | self.activations: Dict[Any, Activation] = {}
307 | self.grad_activations: Dict[Any, ActivationGrad] = {}
308 | self.grad_weights: Dict[Any, WeightGrad] = {}
309 |
310 | def storage(self) -> Tuple[Dict[Any, Weight], Dict[Any, OptState], Dict[Any, Activation], Dict[Any, ActivationGrad], Dict[Any, WeightGrad]]:
311 | return self.weights, self.opt_states, self.activations, self.grad_activations, self.grad_weights
312 |
313 | def memory(self) -> int:
314 | mem = 0
315 | for d in list(self.storage()):
316 | assert isinstance(d, dict)
317 | for v in d.values():
318 | mem += v.memory()
319 | return mem
320 |
321 | def status(self):
322 | for d in list(self.storage()):
323 | for k, v in d.items():
324 | print(k, type(v), end=",")
325 | print()
326 |
327 | def event(self, typ: str, layer: Optional[int]=None, batches: FrozenSet[int]=frozenset({})) -> None:
328 | length = 0
329 | if typ in ["loss", "allgather"]:
330 | length = 0
331 | if typ in ["forward", "backward"]:
332 | length = len(batches)
333 | if typ in ["update"]:
334 | length = 0.5
335 | if typ in ["allreduce", "scatterreduce", "allgather"]:
336 | length = 0.3
337 | if typ in ["pass"]:
338 | length = 0.2
339 |
340 | self.log.append(Event(typ, layer, self.rank, self.time, length, self.memory(), batches))
341 | self.time += length
342 | def load_weights(self, layer: int, shard: int = 0, total:int = 1 ) -> Tuple[Weight, OptState]:
343 | return Weight(layer, self.LAYERS, 0, frozenset([shard]), total),\
344 | OptState(layer, self.LAYERS, 0, frozenset([shard]), total)
345 |
346 | def set_final_weight(self, layer: int, weight:Weight) -> None:
347 | self.final_weights[layer] = weight
348 |
349 | def get_activation(self, batches: Sequence[int]) -> Activation:
350 | return Activation(0, self.LAYERS, frozenset(batches), self.BATCHES)
351 |
352 | def forward(self, layer: int, inp: Activation, weight: Weight) -> Activation:
353 | "Take in activation at layer i and return layer i + 1"
354 | self.event("forward", layer, inp.batches)
355 | assert weight.is_complete()
356 | assert weight.layer == layer, f"Weight should be layer {layer}"
357 | assert inp.layer == layer, f"Input should be layer {layer}"
358 | return Activation(layer + 1, self.LAYERS, inp.batches, self.BATCHES)
359 |
360 | def backward(
361 | self, layer: int, inp: Activation, grad: ActivationGrad, weight: Weight
362 | ) -> Tuple[WeightGrad, ActivationGrad]:
363 | self.event("backward", layer, inp.batches)
364 | assert weight.is_complete()
365 | assert weight.layer == layer, f"Weight should be layer {layer}"
366 | assert inp.layer == layer, f"Input should be layer {layer}"
367 | assert set(inp.batches) == set(
368 | grad.batches
369 | ), f"Batch mismatch {set(inp.batches)}"
370 | assert grad.layer == layer, f"Activation Grad should be layer {layer}"
371 | return (WeightGrad(layer, self.LAYERS, inp.batches, self.BATCHES),
372 | ActivationGrad(layer - 1, self.LAYERS, inp.batches, self.BATCHES))
373 |
374 | def loss(self, inp: Activation) -> ActivationGrad:
375 | self.event("loss", self.LAYERS)
376 | assert inp.layer == self.LAYERS, f"Input should be final layer {self.LAYERS}"
377 | return ActivationGrad(self.LAYERS - 1, self.LAYERS, inp.batches, self.BATCHES)
378 |
379 | def update(self, layer: int,
380 | weight_grad: WeightGrad,
381 | weight: Weight,
382 | opt_state: OptState,
383 | shard: int = 0) -> Tuple[Weight, OptState]:
384 |
385 | assert weight.layer == layer, f"Weight should be layer {layer}"
386 | assert weight_grad.layer == layer, f"Grad weight should be layer {layer}"
387 | assert set(weight_grad.batches) == set(
388 | range(self.BATCHES)
389 | ), f"{set(weight_grad.batches)}"
390 | assert opt_state.layer == layer
391 | if weight_grad.total > 1:
392 | assert weight.shards.issubset(weight_grad.shards), f"Weight {weight.shards}"
393 | assert opt_state.shards.issubset(weight_grad.shards), f"Opt {opt_state.shards}"
394 | assert weight.step == opt_state.step
395 | new_opt = OptState(layer, self.LAYERS, opt_state.step + 1, opt_state.shards, opt_state.total)
396 | new_weight = Weight(layer, self.LAYERS, weight.step + 1, weight.shards, weight.total)
397 | self.event("update", None)
398 | return new_weight, new_opt
399 |
400 |     def fake_grad(self, layer: int, batches: List[int]) -> WeightGrad:
401 | return WeightGrad(layer, self.LAYERS, frozenset(batches), self.BATCHES)
402 |
403 | async def allreduce(self, v: T, layer: int) -> T:
404 | v, self.time = await self.dist.allreduce(self.rank, v, self.time)
405 | self.event("allreduce", layer)
406 |
407 | return v
408 |
409 | async def scatterreduce(self, v: TO, layer:int) -> TO:
410 | v, self.time = await self.dist.scatterreduce(self.rank, v, self.time)
411 | self.event("scatterreduce", layer)
412 | return v
413 |
414 | async def allgather(self, v: O, layer:int) -> O:
415 | v, self.time = await self.dist.allgather(self.rank, v, self.time)
416 | self.event("allgather", layer)
417 | return v
418 |
419 | async def pass_to(self, rank: int, v: Any) -> None:
420 | self.event("pass", None)
421 | await self.dist.pass_to(rank, (v, self.time))
422 |
423 | async def receive(self) -> Any:
424 | v, time = await self.dist.receive(self.rank)
425 | self.time = max(time, self.time)
426 | self.event("pass", None)
427 | return v
428 |
429 | @staticmethod
430 | def check(models : Sequence[Model]) -> None:
431 | for l in range(models[0].LAYERS):
432 | weight = None
433 | for m in models:
434 | if l in m.final_weights:
435 | assert m.final_weights[l].step == 1
436 | if weight is None:
437 | weight = m.final_weights[l]
438 | else:
439 | weight = weight.combine(m.final_weights[l])
440 | assert weight is not None, f"Missing weight {l}"
441 | assert weight.is_complete(), f"Weight not complete {weight}"
442 |
443 | print("Correct!")
444 | from IPython.display import HTML
445 | pups = [
446 | "2m78jPG",
447 | "pn1e9TO",
448 | "MQCIwzT",
449 | "udLK6FS",
450 | "ZNem5o3",
451 | "DS2IZ6K",
452 | "aydRUz8",
453 | "MVUdQYK",
454 | "kLvno0p",
455 | "wScLiVz",
456 | "Z0TII8i",
457 | "F1SChho",
458 | "9hRi2jN",
459 | "lvzRF3W",
460 | "fqHxOGI",
461 | "1xeUYme",
462 | "6tVqKyM",
463 | "CCxZ6Wr",
464 | "lMW0OPQ",
465 | "wHVpHVG",
466 | "Wj2PGRl",
467 | "HlaTE8H",
468 | "k5jALH0",
469 | "3V37Hqr",
470 | "Eq2uMTA",
471 | "Vy9JShx",
472 | "g9I2ZmK",
473 | "Nu4RH7f",
474 | "sWp0Dqd",
475 | "bRKfspn",
476 | "qawCMl5",
477 | "2F6j2B4",
478 | "fiJxCVA",
479 | "pCAIlxD",
480 | "zJx2skh",
481 | "2Gdl1u7",
482 | "aJJAY4c",
483 | "ros6RLC",
484 | "DKLBJh7",
485 | "eyxH0Wc",
486 | "rJEkEw4"]
487 | return HTML("""
488 |
491 | """%(random.sample(pups, 1)[0]))
--------------------------------------------------------------------------------
/puzzles.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "d8204266",
6 | "metadata": {},
7 | "source": [
8 | "# LLM Training Puzzles\n",
9 | "\n",
10 | "by Sasha Rush ([@srush_nlp](https://twitter.com/srush_nlp))"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "c65be14b",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "%%capture\n",
21 | "# Uncomment to run in Colab\n",
22 | "!pip install -qqq git+https://github.com/chalk-diagrams/chalk asyncio\n",
23 | "!wget https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/lib.py https://raw.githubusercontent.com/srush/LLM-Training-Puzzles/main/drawing.py"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "id": "d583c979",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "from typing import List\n",
34 | "from lib import Model, Dist, WeightGrad\n",
35 | "from drawing import draw, draw_group\n",
36 | "from chalk import vcat\n",
37 | "import asyncio\n",
38 | "import chalk\n",
39 | "chalk.set_svg_height(400)\n",
40 | "chalk.set_svg_draw_height(600)"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "id": "cc24dadb",
46 | "metadata": {},
47 | "source": [
48 | "## Preliminaries\n",
49 | "\n",
50 | "The goal of these puzzles is to learn about distributed training of LLMs. However, we will be primarily concerned with a speed and memory efficiency of completing a single update of the models. To make things simpler, we will abstract away from the standard tensor-based transformer model, and just consider a state-less representation of each of the components of a multi-layer neural network.\n",
51 | "\n"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "id": "ad71f90c",
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "model = Model(layers=2, batches=4)\n",
62 | "weights, opt_states, activations, grad_activations, grad_weights = model.storage()"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "61cf6388",
68 | "metadata": {},
69 | "source": [
70 | "Our library has 5 parts: \n",
71 | "\n",
72 | "* Weights\n",
73 | "* Optimizer States - Values needed to update the weights\n",
74 | "* Activations - The internal values computed on the forward pass\n",
75 | "* Grad Activations - The gradients of the loss wrt to activations, needed for backward pass\n",
76 | "* Grad Weights - The gradients of the loss wrt to weights, needed for updates\n",
77 | "\n",
78 | "For these puzzles, you are *not allowed* to have local variables. You need to store each of these in the dictionary corresponding to its type. \n",
79 | "\n",
80 | "We begin by tracing the lifecycle of a single model update."
81 | ]
82 | },
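  {
   "cell_type": "markdown",
   "id": "storage-rule-example",
   "metadata": {},
   "source": [
    "As a concrete (purely illustrative) reading of that rule, a forward step should be written as\n",
    "\n",
    "```python\n",
    "activations[1] = model.forward(layer=0, inp=activations[0], weight=weights[0])\n",
    "```\n",
    "\n",
    "rather than `act = model.forward(...)`: values kept only in local variables are not stored in the tracked dictionaries and so are not counted towards memory."
   ]
  },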
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "id": "0335a17b",
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "# Get the input activations to the model for batches 2, 3 \n",
91 | "activations[0] = model.get_activation(batches=[2, 3])\n",
92 | "activations[0]"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "id": "962ac1d8",
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# Load the weights (random) for layers 0 and 1\n",
103 | "for i in range(model.LAYERS):\n",
104 | " weights[i], opt_states[i] = model.load_weights(i)\n",
105 | "weights[0]"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "f7a83439",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "# Activations can be moved forward a layer if you have the weights.\n",
116 | "activations[1] = model.forward(layer=0, inp=activations[0], weight=weights[0])\n",
117 | "activations[2] = model.forward(layer=1, inp=activations[1], weight=weights[1])\n",
118 | "activations[1]"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "81d904e3",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "# Draw all the current activations in memory.\n",
129 | "draw_group(activations)"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "id": "80e8769a",
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "# At the last layer, we can convert an activation to a grad activation by calling `loss`\n",
140 | "grad_activations[model.LAYERS] = model.loss(activations[model.LAYERS])\n",
141 | "grad_activations[model.LAYERS]"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "id": "b3299042",
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# Calling `backward` requires the forward activation, the backward grad activation, and the weights.\n",
152 | "# It returns the grad weights and the backward activation.\n",
153 | "grad_weights[1], grad_activations[1] = model.backward(1, activations[1], grad_activations[2], weights[1])\n",
154 | "grad_weights[0], grad_activations[0] = model.backward(0, activations[0], grad_activations[1], weights[0])\n",
155 | "grad_activations[1]"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "33d82618",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# We can use delete to remove any memory that is not longer needed. \n",
166 | "print(\"Before memory:\", model.memory())\n",
167 | "del grad_activations[1]\n",
168 | "print(\"After memory:\", model.memory())\n",
169 | "model.status()\n",
170 | "draw_group(grad_activations)"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "id": "f48969f0",
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "# Grad weights keep track of which batches they are for. Here we only have the grad weights for batches 2 and 3.\n",
181 | "draw_group(grad_weights)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "d3d469f7",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# If we try to update with the grad weights we will get an error.\n",
192 | "try:\n",
193 | " model.update(0, weight_grad=grad_weights[0], weight=weights[0], opt_state=opt_states[0])\n",
194 | "except AssertionError as e:\n",
195 | " print(\"Error! Only have batches\")\n",
196 | " print(e)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "55046513",
203 | "metadata": {
204 | "lines_to_next_cell": 2
205 | },
206 | "outputs": [],
207 | "source": [
208 | "# For this example, we can cheat. Pretend we had the other gradients we needed. \n",
209 | "grad_weights[0, 0] = model.fake_grad(0, [0,1])\n",
210 | "grad_weights[1, 0] = model.fake_grad(1, [0,1])\n",
211 | "grad_weights[0, 0] "
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "id": "f5caee88",
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "# Summing together grad_weights gives the full gradient.\n",
222 | "grad_weights[0] = grad_weights[0] + grad_weights[0, 0]"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "id": "dfe2165e",
229 | "metadata": {
230 | "lines_to_next_cell": 2
231 | },
232 | "outputs": [],
233 | "source": [
234 | "# Now we can call update to the get the new weights and opt_state.\n",
235 | "weights[0], opt_states[0] = model.update(0, weight_grad=grad_weights[0], weight=weights[0], \n",
236 | " opt_state=opt_states[0])\n",
237 | "\n",
238 | "# WARNING: You need to set all variables. Otherwise they are not counted towards memory.\n",
239 | "grad_weights[1] = grad_weights[1] + grad_weights[1, 0]\n",
240 | "weights[1], opt_states[1] = model.update(1, weight_grad=grad_weights[1],\n",
241 | " weight=weights[1], opt_state=opt_states[1])"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "id": "f9582594",
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "# We can complete the tests by setting these as the final weights and calling check.\n",
252 | "model.set_final_weight(0, weights[0])\n",
253 | "model.set_final_weight(1, weights[1])\n",
254 | "Model.check([model])\n",
255 | "draw_group(model.final_weights)"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "id": "7b2f0a3b",
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "# We can view the final outcome of the system as a diagram. \n",
266 | "# This show the forward and backward passes (numbers of batches) and the updates.\n",
267 | "# The lines on the bottom show the memory that is used at each time step.\n",
268 | "draw([model])"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "id": "01aa8d66",
275 | "metadata": {},
276 | "outputs": [],
277 | "source": []
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "id": "26e5ea60",
282 | "metadata": {},
283 | "source": [
284 | "### Puzzle 0 - Standard Training\n",
285 | "\n",
286 | "Write a standard (non-distributed) training loop that acts on all the batches and loads all the weights. It should just run forward, loss, backward, and update. Aim for the least amount of max memory used. \n",
287 | "\n",
288 | "* Target Time: 17 steps\n",
289 | "* Target Memory: 2600000"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "id": "04dbf7ea",
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "def basic(model: Model) -> Model:\n",
300 | " # Storage on device.\n",
301 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
302 | "\n",
303 | " # Load in the full weights\n",
304 | " for l in range(model.LAYERS):\n",
305 | " weights[l], opt_states[l] = model.load_weights(l)\n",
306 | "\n",
307 | " # Load the input layer activations\n",
308 | " activations[0] = model.get_activation(range(model.BATCHES))\n",
309 | "\n",
310 | " assert False, 'TODO: Implement me'\n",
311 | " \n",
312 | " for l in range(model.LAYERS):\n",
313 | " model.set_final_weight(l, weights[l])\n",
314 | " return model"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "id": "e3632f24",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "out = basic(Model(layers=2, batches=4, rank=0, dist=Dist(1)))\n",
325 | "draw_group(out.final_weights)"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": null,
331 | "id": "21c87306",
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "draw([out])"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "74e71f5a",
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "Model.check([out])"
346 | ]
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "id": "d25bbcaf",
351 | "metadata": {},
352 | "source": [
353 | "### Puzzle 1 - Gradient Accumulation\n",
354 | "\n",
355 | "For this puzzle, the goal is to reduce max memory usage. To do so you are going to run on each batch individually instead of all together. \n",
356 | "\n",
357 | "Write a function with four parts. First run on batches {0} and then {1} etc. Sum the grad weights and then update.\n",
358 | "\n",
359 | "* Target Time: 17 steps\n",
360 | "* Target Memory: 2000000"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "id": "4c870ae6",
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "def grad_accum(model: Model) -> Model:\n",
371 | " # Storage on device.\n",
372 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
373 | "\n",
374 | " # Load in the full weights\n",
375 | " for l in range(model.LAYERS):\n",
376 | " weights[l], opt_states[l] = model.load_weights(l)\n",
377 | "\n",
378 | " assert False, 'TODO: Implement me'\n",
379 | " for l in range(model.LAYERS):\n",
380 | " model.set_final_weight(l, weights[l])\n",
381 | " return model"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "id": "672563cf",
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "out = grad_accum(Model(layers=2, batches=4, rank=0, dist=Dist(1)))\n",
392 | "draw_group(out.final_weights)"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "id": "77bb31d7",
399 | "metadata": {},
400 | "outputs": [],
401 | "source": [
402 | "draw([out])"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "id": "28cf8167",
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "Model.check([out])"
413 | ]
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "id": "ccfb345c",
418 | "metadata": {},
419 | "source": [
420 | "## Communications: AllReduce"
421 | ]
422 | },
423 | {
424 | "cell_type": "markdown",
425 | "id": "17383602",
426 | "metadata": {},
427 | "source": [
428 | "When working with multiple GPUs we need to have communication. \n",
429 | "The primary communication primitives for GPUs are implemented in NCCL. \n",
430 | "\n",
431 | "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html\n",
432 | "\n",
433 | "We are not going to use these directly, but simulate them using Python and asyncio. \n",
434 | "\n",
435 | "The first operation is AllReduce. We will have 4 GPUs (ranks=4) and use them each to compute a batch of weight grads."
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": null,
441 | "id": "1f7bb767",
442 | "metadata": {},
443 | "outputs": [],
444 | "source": [
445 | "ranks = 4\n",
446 | "weight_grads = [WeightGrad(0, 1, {i}, ranks) for i in range(ranks)]\n",
447 | "weight_grads[0] + weight_grads[1] + weight_grads[2] + weight_grads[3]"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": null,
453 | "id": "ef0232db",
454 | "metadata": {},
455 | "outputs": [],
456 | "source": [
457 | "# Simple asynchronous function that calls allreduce to sum the weight grads at layer 0\n",
458 | "async def myfunc(model: Model) -> WeightGrad:\n",
459 | " return await model.allreduce(weight_grads[model.rank], 0)"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": null,
465 | "id": "43a91ae4",
466 | "metadata": {},
467 | "outputs": [],
468 | "source": [
469 | "# This code uses asyncio to run the above function on 4 \"GPUs\" .\n",
470 | "dist = Dist(ranks)\n",
471 | "out_weight_grads = await asyncio.gather(*[\n",
472 | " myfunc(Model(layers=1, batches=1, rank=i, dist=dist))\n",
473 | " for i in range(ranks)])\n",
474 | "out_weight_grads[0]"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "id": "73d4ee90",
480 | "metadata": {},
481 | "source": [
482 | "Note: When running communication operations like AllReduce on a GPU, the communication happens in parallel to the computation on that GPU. That means the API for AllReduce does not block, and allows the model to continue running while waiting for this command to run. This means it is beneficial to run AllReduce (and other communication) as early as possible so that other compute can be run during the reduction. \n",
483 | "\n",
484 | "We will ignore this in these puzzles and represent communication as happening efficiently."
485 | ]
486 | },
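  {
   "cell_type": "markdown",
   "id": "allreduce-overlap-sketch",
   "metadata": {},
   "source": [
    "Purely as an illustration of that idea (it is not needed for the puzzles): since `model.allreduce` is a coroutine, you could start it early with asyncio and only await the result once you need it, letting other work run in between. A hypothetical sketch:\n",
    "\n",
    "```python\n",
    "# Start the reduction for layer l, keep doing other work, then collect the result.\n",
    "handle = asyncio.create_task(model.allreduce(grad_weights[l], l))\n",
    "# ... other forward/backward work for other layers ...\n",
    "grad_weights[l] = await handle\n",
    "```\n",
    "\n",
    "In these puzzles we treat communication as happening efficiently, so this overlap is not modeled."
   ]
  },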
487 | {
488 | "cell_type": "markdown",
489 | "id": "1020523e",
490 | "metadata": {},
491 | "source": [
492 | "### Puzzle 2 - Distributed Data Parallel\n",
493 | "\n",
494 | "Write a function with four parts. First run on batches {0} and then {1} etc. Sum the grad weights and then update. The main benefit of this approach is compute efficiency over gradient accumulation.\n",
495 | "\n",
496 | "* Total Steps: 5\n",
497 | "* Total Memory: 1800000"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "id": "f492668c",
504 | "metadata": {
505 | "lines_to_next_cell": 2
506 | },
507 | "outputs": [],
508 | "source": [
509 | "async def ddp(model: Model) -> Model:\n",
510 | " # Storage on device.\n",
511 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
512 | " # Load all the activations\n",
513 | " model.activations[0] = model.get_activation([model.rank])\n",
514 | "\n",
515 | " assert False, 'TODO: Implement me'\n",
516 | " for l in range(model.LAYERS):\n",
517 | " model.set_final_weight(l, weights[l])\n",
518 | " return model"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": null,
524 | "id": "d6ed44f7",
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "dist = Dist(ranks)\n",
529 | "out = await asyncio.gather(*[\n",
530 | " ddp(Model(layers=2, batches=ranks, rank=i, dist=dist))\n",
531 | " for i in range(ranks)])\n",
532 | "draw_group(out[0].final_weights)"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "id": "89411fad",
539 | "metadata": {},
540 | "outputs": [],
541 | "source": [
542 | "draw(out)"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": null,
548 | "id": "2df8166f",
549 | "metadata": {},
550 | "outputs": [],
551 | "source": [
552 | "Model.check(out)"
553 | ]
554 | },
555 | {
556 | "cell_type": "markdown",
557 | "id": "8c3df405",
558 | "metadata": {},
559 | "source": [
560 | "## Communication: AllGather / Sharding\n",
561 | "\n",
562 | "Our next primitive is AllGather. This allows us to communicate \"shards\" of an object stored on different GPUs to all the GPUs."
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": null,
568 | "id": "5b4bad85",
569 | "metadata": {},
570 | "outputs": [],
571 | "source": [
572 | "# Load only part of a weights.\n",
573 | "model = Model(layers=2, batches=1, rank=0, dist=Dist(1))\n",
574 | "weight, _ = model.load_weights(0, shard=0, total=4)\n",
575 | "weight"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": null,
581 | "id": "2597b152",
582 | "metadata": {},
583 | "outputs": [],
584 | "source": [
585 | "# Combine togegher two shards on one machine.\n",
586 | "weights = [model.load_weights(0, shard=i, total=ranks)[0] for i in range(ranks)]\n",
587 | "weights[0].combine(weights[2])"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": null,
593 | "id": "5e3d10b6",
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "# Use allgather to collect the shards from all machines.\n",
598 | "async def mygather(model: Model) -> WeightGrad:\n",
599 | " # Allreduce sums together all the weight grads\n",
600 | " return await model.allgather(weights[model.rank], 0)\n",
601 | "\n",
602 | "dist = Dist(ranks)\n",
603 | "out_weights = await asyncio.gather(*[\n",
604 | " mygather(Model(layers=1, batches=1, rank=i, dist=dist))\n",
605 | " for i in range(ranks)])\n",
606 | "out_weights[0]"
607 | ]
608 | },
609 | {
610 | "cell_type": "markdown",
611 | "id": "c3903613",
612 | "metadata": {},
613 | "source": [
614 | "### Puzzle 3: Weight-Sharded Data Parallel\n",
615 | "\n",
616 | "Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Finally update the weights on each machine using allreduce.\n",
617 | "\n",
618 | "* Total Steps: 20\n",
619 | "* Total Memory: 2800000"
620 | ]
621 | },
622 | {
623 | "cell_type": "code",
624 | "execution_count": null,
625 | "id": "4b674be7",
626 | "metadata": {},
627 | "outputs": [],
628 | "source": [
629 | "async def wsdp(model: Model) -> Model:\n",
630 | " # Storage on device.\n",
631 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
632 | "\n",
633 | " # Load all the activations\n",
634 | " model.activations[0] = model.get_activation([model.rank])\n",
635 | "\n",
636 | " # Load a shard of the weights for every layer. Load in the full weights\n",
637 | " for l in range(model.LAYERS):\n",
638 | " weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) \n",
639 | "\n",
640 | " assert False, 'TODO: Implement me'\n",
641 | " for l in range(model.LAYERS):\n",
642 | " model.set_final_weight(l, weights[l])\n",
643 | "\n",
644 | " return model"
645 | ]
646 | },
647 | {
648 | "cell_type": "code",
649 | "execution_count": null,
650 | "id": "6c23a1eb",
651 | "metadata": {},
652 | "outputs": [],
653 | "source": [
654 | "dist = Dist(ranks)\n",
655 | "out = await asyncio.gather(*[\n",
656 | " wsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))\n",
657 | " for i in range(ranks)])\n",
658 | "draw_group(out[1].final_weights)"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": null,
664 | "id": "09731c67",
665 | "metadata": {},
666 | "outputs": [],
667 | "source": [
668 | "draw(out)"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": null,
674 | "id": "d3ff46b6",
675 | "metadata": {},
676 | "outputs": [],
677 | "source": [
678 | "Model.check(out)"
679 | ]
680 | },
681 | {
682 | "cell_type": "markdown",
683 | "id": "32243386",
684 | "metadata": {},
685 | "source": [
686 | "## Communication: Scatter-Reduce"
687 | ]
688 | },
689 | {
690 | "cell_type": "markdown",
691 | "id": "44610031",
692 | "metadata": {},
693 | "source": [
694 | "Scatter across shards\n",
695 | "Reduce across batches"
696 | ]
697 | },
698 | {
699 | "cell_type": "code",
700 | "execution_count": null,
701 | "id": "1034dbb9",
702 | "metadata": {},
703 | "outputs": [],
704 | "source": [
705 | "grad_weight = WeightGrad(0, 1, batches={1}, total_batches=4, \n",
706 | " shards={1}, total=4)\n",
707 | "grad_weight"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": null,
713 | "id": "02680e7e",
714 | "metadata": {},
715 | "outputs": [],
716 | "source": [
717 | "grad_weights = {i: WeightGrad(0, 1, batches={i}, total_batches=4, \n",
718 | " shards={0,1,2,3}, total=4) for i in range(4)}\n",
719 | "grad_weights[2]"
720 | ]
721 | },
722 | {
723 | "cell_type": "code",
724 | "execution_count": null,
725 | "id": "1498773a",
726 | "metadata": {
727 | "lines_to_next_cell": 0
728 | },
729 | "outputs": [],
730 | "source": [
731 | "async def scatterreduce(model: Model) -> WeightGrad:\n",
732 | " # Allreduce sums together all the weight grads\n",
733 | " return await model.scatterreduce(grad_weights[model.rank], 0)\n",
734 | "\n",
735 | "dist = Dist(ranks)\n",
736 | "out = await asyncio.gather(*[\n",
737 | " scatterreduce(Model(layers=1, batches=1, rank=i, dist=dist))\n",
738 | " for i in range(ranks)])\n",
739 | "out[0]"
740 | ]
741 | },
742 | {
743 | "cell_type": "code",
744 | "execution_count": null,
745 | "id": "261435f2",
746 | "metadata": {
747 | "lines_to_next_cell": 2
748 | },
749 | "outputs": [],
750 | "source": []
751 | },
752 | {
753 | "cell_type": "markdown",
754 | "id": "7e62da15",
755 | "metadata": {},
756 | "source": [
757 | "### Puzzle 4: Fully-Sharded Data Parallel\n",
758 | "\n",
759 | "Run a model that shards each layer weight over all the machines. Reconstruct the layer weight at each layer using allgather. Collect the gradients with scatter-reduce.\n",
760 | "\n",
761 | "* Total Steps: 20\n",
762 | "* Total Memory: 2300000"
763 | ]
764 | },
765 | {
766 | "cell_type": "code",
767 | "execution_count": null,
768 | "id": "43a35535",
769 | "metadata": {
770 | "lines_to_next_cell": 2
771 | },
772 | "outputs": [],
773 | "source": [
774 | "async def fsdp(model: Model) -> Model:\n",
775 | " # Storage on device.\n",
776 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
777 | "\n",
778 | " # Load all the activations\n",
779 | " model.activations[0] = model.get_activation([model.rank])\n",
780 | "\n",
781 | " # Load a shard of the weights for every layer. Load in the full weights\n",
782 | " for l in range(model.LAYERS):\n",
783 | " weights[l], opt_states[l] = model.load_weights(l, model.rank, model.RANKS) \n",
784 | "\n",
785 | " assert False, 'TODO: Implement me'\n",
786 | " for l in range(model.LAYERS):\n",
787 | " model.set_final_weight(l, weights[l])\n",
788 | " return model"
789 | ]
790 | },
791 | {
792 | "cell_type": "code",
793 | "execution_count": null,
794 | "id": "dec61bda",
795 | "metadata": {},
796 | "outputs": [],
797 | "source": [
798 | "dist = Dist(ranks)\n",
799 | "out = await asyncio.gather(*[\n",
800 | " fsdp(Model(layers=6, batches=ranks, rank=i, dist=dist))\n",
801 | " for i in range(ranks)])\n",
802 | "draw_group(out[1].final_weights)"
803 | ]
804 | },
805 | {
806 | "cell_type": "code",
807 | "execution_count": null,
808 | "id": "b9f62f28",
809 | "metadata": {},
810 | "outputs": [],
811 | "source": [
812 | "draw(out)"
813 | ]
814 | },
815 | {
816 | "cell_type": "code",
817 | "execution_count": null,
818 | "id": "3a527fc7",
819 | "metadata": {},
820 | "outputs": [],
821 | "source": [
822 | "Model.check(out)"
823 | ]
824 | },
825 | {
826 | "cell_type": "markdown",
827 | "id": "0858bb4f",
828 | "metadata": {},
829 | "source": [
830 | "## Communication: Point-to-Point\n",
831 | "\n",
832 | "An alternative approach to communication is to directly communicate specific information between GPUs. In our model, both GPUs talking to each other block and wait for the handoff. "
833 | ]
834 | },
835 | {
836 | "cell_type": "code",
837 | "execution_count": null,
838 | "id": "de38df4d",
839 | "metadata": {
840 | "lines_to_next_cell": 2
841 | },
842 | "outputs": [],
843 | "source": [
844 | "async def talk(model: Model) -> None:\n",
845 | " if model.rank == 0:\n",
846 | " await model.pass_to(1, \"extra cheese\")\n",
847 | " val = await model.receive()\n",
848 | " print(val)\n",
849 | " else:\n",
850 | " val = await model.receive()\n",
851 | " print(val)\n",
852 | " val = await model.pass_to(0, \"pizza\")\n",
853 | "\n",
854 | "dist = Dist(2)\n",
855 | "result = await asyncio.gather(*[\n",
856 | " talk(Model(layers=1, batches=1, rank=i, dist=dist))\n",
857 | " for i in range(2)])"
858 | ]
859 | },
860 | {
861 | "cell_type": "markdown",
862 | "id": "027b159c",
863 | "metadata": {},
864 | "source": [
865 | "### Puzzle 5: Pipeline Parallelism\n",
866 | "\n",
867 | "Split the layer weights and optimizers equally between GPUs. Have each GPU handle only its layer. Pass the full set of batches for activations and grad_activations between layers using p2p communication. No need for any global communication.\n",
868 | "\n",
869 | "* Total Steps: 66\n",
870 | "* Total Memory: 3300000"
871 | ]
872 | },
873 | {
874 | "cell_type": "code",
875 | "execution_count": null,
876 | "id": "09feb2a6",
877 | "metadata": {
878 | "lines_to_next_cell": 2
879 | },
880 | "outputs": [],
881 | "source": [
882 | "async def pipeline(model: Model) -> Model:\n",
883 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
884 | " per_rank = model.LAYERS // model.RANKS\n",
885 | " my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])\n",
886 | " for l in my_layers:\n",
887 | " weights[l], opt_states[l] = model.load_weights(l)\n",
888 | " assert False, 'TODO: Implement me'\n",
889 | " for l in my_layers:\n",
890 | " model.set_final_weight(l, weights[l])\n",
891 | " return model"
892 | ]
893 | },
894 | {
895 | "cell_type": "code",
896 | "execution_count": null,
897 | "id": "2e5c381b",
898 | "metadata": {},
899 | "outputs": [],
900 | "source": [
901 | "dist = Dist(ranks)\n",
902 | "out = await asyncio.gather(*[\n",
903 | " pipeline(Model(layers=8, batches=ranks, rank=i, dist=dist))\n",
904 | " for i in range(ranks)])\n",
905 | "draw_group(out[1].final_weights)"
906 | ]
907 | },
908 | {
909 | "cell_type": "code",
910 | "execution_count": null,
911 | "id": "5a99ecad",
912 | "metadata": {},
913 | "outputs": [],
914 | "source": [
915 | "draw(out)"
916 | ]
917 | },
918 | {
919 | "cell_type": "code",
920 | "execution_count": null,
921 | "id": "2b2f11d5",
922 | "metadata": {},
923 | "outputs": [],
924 | "source": [
925 | "Model.check(out)"
926 | ]
927 | },
928 | {
929 | "cell_type": "markdown",
930 | "id": "f6606efc",
931 | "metadata": {},
932 | "source": [
933 | "### Puzzle 6: GPipe Schedule\n",
934 | "\n",
935 | "A major issue with the pipeline approach is that it causes a \"bubble\", i.e. time in the later layers waiting for the earlier layers to complete. An alternative approach is to split the batches smaller so you can pass them earlier. \n",
936 | "\n",
937 | "In this puzzle, you should run each batch by itself, and then pass. The graph should look similar as the one above but with a smaller bubble. \n",
938 | "\n",
939 | "* Total Steps: 33\n",
940 | "* Total Memory: 4100000"
941 | ]
942 | },
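  {
   "cell_type": "markdown",
   "id": "gpipe-bubble-note",
   "metadata": {},
   "source": [
    "As a rough rule of thumb (not required to solve the puzzle): with p pipeline stages and m microbatches, each stage is idle for roughly (p - 1) / (m + p - 1) of the schedule. Splitting the work into more, smaller microbatches therefore shrinks the bubble, at the cost of keeping more activations and in-flight values in memory at once (note the higher memory target above)."
   ]
  },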
943 | {
944 | "cell_type": "code",
945 | "execution_count": null,
946 | "id": "f5f33513",
947 | "metadata": {
948 | "lines_to_next_cell": 2
949 | },
950 | "outputs": [],
951 | "source": [
952 | "async def gpipe(model: Model) -> Model:\n",
953 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
954 | " per_rank = model.LAYERS // model.RANKS\n",
955 | " my_layers = list([l + (model.rank * per_rank) for l in range(per_rank)])\n",
956 | " for l in my_layers:\n",
957 | " weights[l], opt_states[l] = model.load_weights(l)\n",
958 | "\n",
959 | " assert False, 'TODO: Implement me'\n",
960 | " for l in my_layers:\n",
961 | " model.set_final_weight(l, weights[l])\n",
962 | "\n",
963 | " return model"
964 | ]
965 | },
966 | {
967 | "cell_type": "code",
968 | "execution_count": null,
969 | "id": "f5c73b29",
970 | "metadata": {},
971 | "outputs": [],
972 | "source": [
973 | "dist = Dist(ranks)\n",
974 | "out = await asyncio.gather(*[\n",
975 | " gpipe(Model(layers=8, batches=ranks, rank=i, dist=dist))\n",
976 | " for i in range(ranks)])\n",
977 | "draw_group(out[1].final_weights)"
978 | ]
979 | },
980 | {
981 | "cell_type": "code",
982 | "execution_count": null,
983 | "id": "0e759da9",
984 | "metadata": {},
985 | "outputs": [],
986 | "source": [
987 | "draw(out)"
988 | ]
989 | },
990 | {
991 | "cell_type": "code",
992 | "execution_count": null,
993 | "id": "47d3102b",
994 | "metadata": {
995 | "lines_to_next_cell": 2
996 | },
997 | "outputs": [],
998 | "source": [
999 | "Model.check(out)"
1000 | ]
1001 | },
1002 | {
1003 | "cell_type": "markdown",
1004 | "id": "3a4bc062",
1005 | "metadata": {},
1006 | "source": [
1007 | "### Puzzle 7: Pipeline + FSDP\n",
1008 | "\n",
1009 | "As a last exercise, we can put everything together. Here we are going to run a combination of pipeline parallelism while also sharding our weight between 16 different machines. Here the model only has 4 layers, so we will assign 4 GPUs to each layer in the pipeline parallel approach. \n",
1010 | "\n",
1011 | "This example requires combining both collective communication and p2p communication effectively. \n",
1012 | "\n",
1013 | "* Total Steps: 15\n",
1014 | "* Total Memory: 1000000"
1015 | ]
1016 | },
1017 | {
1018 | "cell_type": "code",
1019 | "execution_count": null,
1020 | "id": "34757c26",
1021 | "metadata": {},
1022 | "outputs": [],
1023 | "source": [
1024 | "async def pipeline_fsdp(model: Model) -> Model:\n",
1025 | " weights, opt_states, activations, grad_activations, grad_weights = model.storage()\n",
1026 | " per_rank = model.LAYERS // (model.RANKS // 4)\n",
1027 | " my_layers = list([l + ((model.rank % 4) * per_rank) for l in range(per_rank)])\n",
1028 | " for l in range(model.LAYERS):\n",
1029 | " weights[l, 0], opt_states[l, 0] = model.load_weights(l, model.rank, model.RANKS)\n",
1030 | " def empty_grad(l):\n",
1031 | " return model.fake_grad(l, [])\n",
1032 | " assert False, 'TODO: Implement me'\n",
1033 | " for l in range(model.LAYERS):\n",
1034 | " model.set_final_weight(l, weights[l])\n",
1035 | " # Update\n",
1036 | " return model"
1037 | ]
1038 | },
1039 | {
1040 | "cell_type": "code",
1041 | "execution_count": null,
1042 | "id": "35a0877b",
1043 | "metadata": {
1044 | "lines_to_next_cell": 2
1045 | },
1046 | "outputs": [],
1047 | "source": [
1048 | "dist = Dist(16)\n",
1049 | "out = await asyncio.gather(*[\n",
1050 | " pipeline_fsdp(Model(layers=4, batches=ranks, rank=i, dist=dist))\n",
1051 | " for i in range(16)])"
1052 | ]
1053 | },
1054 | {
1055 | "cell_type": "code",
1056 | "execution_count": null,
1057 | "id": "d93c6256",
1058 | "metadata": {},
1059 | "outputs": [],
1060 | "source": [
1061 | "Model.check(out)\n",
1062 | "chalk.set_svg_height(1000)\n",
1063 | "chalk.set_svg_draw_height(1000) \n",
1064 | "\n",
1065 | "draw(out)"
1066 | ]
1067 | },
1068 | {
1069 | "cell_type": "markdown",
1070 | "id": "24393483",
1071 | "metadata": {},
1072 | "source": [
1073 | "### When does it make sense to combine?\n",
1074 | "\n",
1075 | "The goal of these exercises is to give you a sense of the different methods out there for distributed training. However, there is not currently a one size fits all approach for distributed training. The right choice will depend on the constants such as batch size, memory per GPU, communication overhead, implementation complexity, model size, and specifics of architecture. \n",
1076 | "\n",
1077 | "As an example of what's left to explore, this last method Pipeline + FSDP is often not a great choice due to the complexities of communication speed. And in fact GPipe + FSDP also gets you into a bad place. The paper [Breadth First Pipeline Parallelism](https://arxiv.org/pdf/2211.05953.pdf) proposes instead a combination of pipeline scheduling and communication. Here's what it looks like. "
1078 | ]
1079 | },
1080 | {
1081 | "cell_type": "markdown",
1082 | "id": "35c9493a",
1083 | "metadata": {},
1084 | "source": [
1085 | ""
1086 | ]
1087 | }
1088 | ],
1089 | "metadata": {
1090 | "jupytext": {
1091 | "cell_metadata_filter": "-all",
1092 | "custom_cell_magics": "kql"
1093 | },
1094 | "kernelspec": {
1095 | "display_name": "venv",
1096 | "language": "python",
1097 | "name": "python3"
1098 | }
1099 | },
1100 | "nbformat": 4,
1101 | "nbformat_minor": 5
1102 | }
1103 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | asyncio
2 | git+https://github.com/chalk-diagrams/chalk/
3 |
--------------------------------------------------------------------------------