\
201 | --enable_each_rank_log ../logs \
202 | train_llm.py \
203 | --experiment-name deepspeed-multi-node \
204 | --dataset-name tatsu-lab/alpaca \
205 | --model-name openai-community/gpt2
206 | ```
207 |
--------------------------------------------------------------------------------
/03-job-launchers/job.sbatch:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --ntasks-per-node=1
4 |
5 | source $(pwd)/../venv/bin/activate
6 |
7 | MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
8 | MASTER_PORT=$(expr 5000 + $(echo -n ${SLURM_JOBID} | tail -c 4))
9 | export TORCHELASTIC_ERROR_FILE=./error-${SLURM_JOBID}-${SLURM_NODEID}.json
10 | export OMP_NUM_THREADS=1
11 | export HF_HOME=../.cache
12 |
13 | printenv
14 |
15 | srun torchrun \
16 | --rdzv-id "slurm-${SLURM_JOBID}" \
17 | --rdzv-backend c10d \
18 | --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} \
19 | --nnodes ${SLURM_NNODES} \
20 | --nproc-per-node ${SLURM_GPUS_ON_NODE} \
21 | --redirects 3 \
22 | --log-dir ${SLURM_SUBMIT_DIR}/logs \
23 | ../03-multi-node/train_llm.py \
24 | --experiment-name gpt2-alpaca-slurm-$(date +%Y-%m-%dT%H-%M-%S) \
25 | --dataset-name tatsu-lab/alpaca \
26 | --model-name openai-community/gpt2
27 |
--------------------------------------------------------------------------------
/04-fully-sharded-data-parallel/train_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from contextlib import contextmanager
3 | import functools
4 | from itertools import chain
5 | import json
6 | import multiprocessing
7 | import os
8 | import time
9 | from pathlib import Path
10 | import logging
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from torch.utils.data.distributed import DistributedSampler
15 | from torch import distributed as dist
16 | from torch.distributed.elastic.multiprocessing.errors import record
17 | from torch.distributed.fsdp.fully_sharded_data_parallel import (
18 | FullyShardedDataParallel,
19 | CPUOffload,
20 | ShardingStrategy,
21 | )
22 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
23 | from torch.distributed.checkpoint.state_dict import (
24 | get_state_dict,
25 | set_state_dict,
26 | StateDictOptions,
27 | )
28 | from torch.distributed.checkpoint import load, save
29 |
30 |
31 | import wandb
32 | import tqdm
33 | import datasets
34 | from transformers import (
35 | AutoConfig,
36 | AutoModelForCausalLM,
37 | AutoTokenizer,
38 | default_data_collator,
39 | )
40 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
41 |
42 | # fixes for reset_parameters not existing
43 | LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
44 | LlamaRotaryEmbedding.reset_parameters = lambda _: None
45 |
46 | LOGGER = logging.getLogger(__name__)
47 |
48 |
49 | @record
50 | def main():
51 | parser = _get_parser()
52 | args = parser.parse_args()
53 |
54 | dist.init_process_group()
55 |
56 | rank = dist.get_rank()
57 | local_rank = rank % torch.cuda.device_count()
58 | world_size = dist.get_world_size()
59 |
60 | logging.basicConfig(
61 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s",
62 | level=logging.INFO,
63 | )
64 |
65 | LOGGER.info(os.environ)
66 | LOGGER.info(args)
67 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
68 |
69 | device = torch.device(f"cuda:{local_rank}")
70 | dtype = torch.bfloat16
71 | torch.cuda.set_device(device)
72 |
73 | torch.manual_seed(args.seed)
74 |
75 | with rank0_first():
76 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
77 | # NOTE: meta device will not allocate any memory
78 | with torch.device("meta"):
79 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
80 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
81 |
82 | LOGGER.info(f"Before FSDP: {get_mem_stats(device)}")
83 |
84 | wrap_policy = functools.partial(
85 | size_based_auto_wrap_policy, min_num_params=int(args.numel_to_wrap)
86 | )
87 | model = FullyShardedDataParallel(
88 | model,
89 | device_id=local_rank,
90 | sync_module_states=True,
91 | # NOTE: FULL_SHARD is equivalent to deepspeed ZeRO stage 3
92 | auto_wrap_policy=wrap_policy,
93 | sharding_strategy=ShardingStrategy.FULL_SHARD,
94 | cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"),
95 | )
96 |
97 | LOGGER.info(f"After FSDP: {get_mem_stats(device)}")
98 |
99 | # NOTE: since this can download data, make sure to do the main process first
100 | # NOTE: This assumes that the data is on a **shared** network drive, accessible to all processes
101 | with rank0_first():
102 | train_data = _load_and_preprocess_data(args, config)
103 | LOGGER.info(f"{len(train_data)} training samples")
104 |
105 | dataloader = DataLoader(
106 | train_data,
107 | batch_size=args.batch_size,
108 | collate_fn=default_data_collator,
109 | # NOTE: this sampler will split dataset evenly across workers
110 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True),
111 | )
112 | LOGGER.info(f"{len(dataloader)} batches per epoch")
113 |
114 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
115 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
116 | optimizer, T_max=1000, eta_min=args.lr * 1e-2
117 | )
118 |
119 | exp_dir: Path = Path(args.save_dir) / args.experiment_name
120 |
121 | # NOTE: full_state_dict=False means we will be saving sharded checkpoints.
122 | ckpt_opts = StateDictOptions(full_state_dict=False, cpu_offload=True)
123 |
124 | # attempt resume
125 | state = {
126 | "epoch": 0,
127 | "global_step": 0,
128 | "epoch_step": 0,
129 | "running_loss": 0,
130 | }
131 | resumed = False
132 | if (exp_dir / "state.json").exists():
133 | sharded_model_state, sharded_optimizer_state = get_state_dict(
134 | model, optimizer, options=ckpt_opts
135 | )
136 | load(
137 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state),
138 | checkpoint_id=exp_dir / "checkpoint",
139 | )
140 | set_state_dict(
141 | model,
142 | optimizer,
143 | model_state_dict=sharded_model_state,
144 | optim_state_dict=sharded_optimizer_state,
145 | options=ckpt_opts,
146 | )
147 | lr_scheduler.load_state_dict(
148 | torch.load(
149 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True
150 | )
151 | )
152 | with open(exp_dir / "state.json") as fp:
153 | state = json.load(fp)
154 | resumed = True
155 | LOGGER.info(f"Resumed={resumed} | {state}")
156 | dist.barrier()
157 |
158 | if (exp_dir.is_mount() and rank == 0) or (
159 | not exp_dir.is_mount() and local_rank == 0
160 | ):
161 | LOGGER.info(f"Creating experiment root directory")
162 | exp_dir.mkdir(parents=True, exist_ok=True)
163 | dist.barrier()
164 |
165 | (exp_dir / f"rank-{rank}").mkdir(parents=True, exist_ok=True)
166 | LOGGER.info(f"Worker saving to {exp_dir / f'rank-{rank}'}")
167 |
168 | if rank == 0:
169 | wandb.init(
170 | project="distributed-training-guide",
171 | dir=exp_dir,
172 | name=args.experiment_name,
173 | id=args.experiment_name,
174 | resume="must" if resumed else None,
175 | save_code=True,
176 | config={
177 | "args": vars(args),
178 | "training_data_size": len(train_data),
179 | "num_batches": len(dataloader),
180 | "world_size": world_size,
181 | },
182 | )
183 |
184 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
185 |
186 | for state["epoch"] in range(state["epoch"], args.num_epochs):
187 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}")
188 |
189 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=rank > 0)
190 | if state["epoch_step"] > 0:
191 | progress_bar.update(state["epoch_step"])
192 |
193 | dataloader.sampler.set_epoch(state["epoch"])
194 | batches = iter(dataloader)
195 |
196 | for i_step in range(len(dataloader)):
197 | with timers["data"], torch.no_grad():
198 | batch = next(batches)
199 | batch = {k: v.to(device=device) for k, v in batch.items()}
200 |
201 | if i_step < state["epoch_step"]:
202 | # NOTE: for resuming
203 | continue
204 |
205 | with timers["forward"]:
206 | outputs = model(**batch)
207 |
208 | with timers["backward"]:
209 | optimizer.zero_grad(set_to_none=True)
210 | outputs.loss.backward()
211 |
212 | with timers["update"]:
213 | optimizer.step()
214 | lr_scheduler.step()
215 |
216 | state["global_step"] += 1
217 | state["epoch_step"] += 1
218 | state["running_loss"] += outputs.loss.item()
219 | progress_bar.update(1)
220 |
221 | if state["global_step"] % args.log_freq == 0:
222 | tok_per_step = world_size * args.batch_size * args.seq_length
223 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
224 | info = {
225 | "global_step": state["global_step"],
226 | "lr": lr_scheduler.get_last_lr()[0],
227 | "running_loss": state["running_loss"] / args.log_freq,
228 | "epoch": state["epoch"],
229 | "epoch_progress": state["epoch_step"] / len(dataloader),
230 | "num_batches_remaining": len(dataloader) - i_step,
231 | **get_mem_stats(device),
232 | "tok/s": 1000 * tok_per_step / ms_per_step,
233 | "time/total": ms_per_step,
234 | **{
235 | f"time/{k}": timer.avg_elapsed_ms()
236 | for k, timer in timers.items()
237 | },
238 | }
239 |
240 | LOGGER.info(info)
241 | if rank == 0:
242 | wandb.log(info, step=state["global_step"])
243 |
244 | torch.cuda.reset_peak_memory_stats(device)
245 | state["running_loss"] = 0
246 | for t in timers.values():
247 | t.reset()
248 |
249 | if state["global_step"] % args.ckpt_freq == 0:
250 | dist.barrier()
251 | # NOTE: we have to call this on ALL ranks
252 | sharded_model_state, sharded_optimizer_state = get_state_dict(
253 | model, optimizer, options=ckpt_opts
254 | )
255 | save(
256 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state),
257 | checkpoint_id=exp_dir / "checkpoint",
258 | )
259 | if rank == 0:
260 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
261 | with open(exp_dir / "state.json", "w") as fp:
262 | json.dump(state, fp)
263 | dist.barrier()
264 |
265 | state["epoch_step"] = 0
266 |
267 |
268 | def _load_and_preprocess_data(args, config):
269 | """
270 | Function created using code found in
271 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py
272 | """
273 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
274 |
275 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True)
276 |
277 | column_names = data["train"].column_names
278 | text_column_name = "text" if "text" in column_names else column_names[0]
279 |
280 | def tokenize_function(examples):
281 | return tokenizer(examples[text_column_name])
282 |
283 | tokenized_datasets = data.map(
284 | tokenize_function,
285 | batched=True,
286 | remove_columns=column_names,
287 | num_proc=multiprocessing.cpu_count(),
288 | load_from_cache_file=True,
289 | desc="Running tokenizer on dataset",
290 | )
291 |
292 | seq_length = args.seq_length or tokenizer.model_max_length
293 | if seq_length > config.max_position_embeddings:
294 | seq_length = min(1024, config.max_position_embeddings)
295 |
296 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
297 | def group_texts(examples):
298 | # Concatenate all texts.
299 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
300 | total_length = len(concatenated_examples[list(examples.keys())[0]])
301 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
302 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
303 | if total_length > seq_length:
304 | total_length = (total_length // seq_length) * seq_length
305 | # Split by chunks of max_len.
306 | result = {
307 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
308 | for k, t in concatenated_examples.items()
309 | }
310 | result["labels"] = result["input_ids"].copy()
311 | return result
312 |
313 | lm_datasets = tokenized_datasets.map(
314 | group_texts,
315 | batched=True,
316 | num_proc=multiprocessing.cpu_count(),
317 | load_from_cache_file=True,
318 | desc=f"Grouping texts in chunks of {seq_length}",
319 | )
320 |
321 | return lm_datasets["train"]
322 |
323 |
324 | def get_mem_stats(device=None):
325 | mem = torch.cuda.memory_stats(device)
326 | props = torch.cuda.get_device_properties(device)
327 | return {
328 | "total_gb": 1e-9 * props.total_memory,
329 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"],
330 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"],
331 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"],
332 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"],
333 | }
334 |
335 |
336 | @contextmanager
337 | def rank0_first():
338 | rank = dist.get_rank()
339 | if rank == 0:
340 | yield
341 | dist.barrier()
342 | if rank > 0:
343 | yield
344 | dist.barrier()
345 |
346 |
347 | class LocalTimer:
348 | def __init__(self, device: torch.device):
349 | if device.type == "cpu":
350 | self.synchronize = lambda: torch.cpu.synchronize(device=device)
351 | elif device.type == "cuda":
352 | self.synchronize = lambda: torch.cuda.synchronize(device=device)
353 | self.measurements = []
354 | self.start_time = None
355 |
356 | def __enter__(self):
357 | self.synchronize()
358 | self.start_time = time.time()
359 | return self
360 |
361 | def __exit__(self, type, value, traceback):
362 | if traceback is None:
363 | self.synchronize()
364 | end_time = time.time()
365 | self.measurements.append(end_time - self.start_time)
366 | self.start_time = None
367 |
368 | def avg_elapsed_ms(self):
369 | return 1000 * (sum(self.measurements) / len(self.measurements))
370 |
371 | def reset(self):
372 | self.measurements = []
373 | self.start_time = None
374 |
375 |
376 | def _get_parser() -> argparse.ArgumentParser:
377 | parser = argparse.ArgumentParser()
378 | parser.add_argument("-e", "--experiment-name", default=None, required=True)
379 | parser.add_argument("-d", "--dataset-name", default=None, required=True)
380 | parser.add_argument("-m", "--model-name", default=None, required=True)
381 | parser.add_argument("--save-dir", default="../outputs")
382 | parser.add_argument("--seed", default=0, type=int)
383 | parser.add_argument("--num-epochs", default=100, type=int)
384 | parser.add_argument("--lr", default=3e-5, type=float)
385 | parser.add_argument("-b", "--batch-size", default=1, type=int)
386 | parser.add_argument("--log-freq", default=100, type=int)
387 | parser.add_argument("--ckpt-freq", default=500, type=int)
388 | parser.add_argument("-s", "--seq-length", default=1024, type=int)
389 | parser.add_argument(
390 | "--numel-to-wrap",
391 | default=100_000_000,
392 | type=int,
393 | help="Only applies FSDP to modules with numel > this value.",
394 | )
395 | parser.add_argument("--cpu-offload", default="off", choices=["on", "off"])
396 | return parser
397 |
398 |
399 | if __name__ == "__main__":
400 | main()
401 |
--------------------------------------------------------------------------------
/05-training-llama-405b/README.md:
--------------------------------------------------------------------------------
1 | # Training a 405B model
2 |
3 | **NOTE: This chapter's code builds off of [chapter 4's FSDP code](../04-fully-sharded-data-parallel/).**
4 |
5 | Here we are going to utilize an 8 node cluster (64 H100 GPUs) to train Llama 3.1 405B. **This does not use LoRA!** We are fully training the weights of a 405B model in plain PyTorch.
6 |
7 | The next few sections go through various changes we have to make to our FSDP code from chapter 4 to make training a 405b model work.
8 |
9 | Quick Jump:
10 | - [Use flash attention](#use-flash-attention)
11 | - [Download model weights](#download-model-weights)
12 | - [Loading pretrained weights](#loading-pretrained-weights)
13 | - [Sharding Llama 405B](#sharding-llama-405b)
14 | - [Gradient (aka activation) checkpointing](#gradient-aka-activation-checkpointing)
15 | - [CPU Offload \& fused optimizer kernels](#cpu-offload--fused-optimizer-kernels)
16 | - [NOT de-allocating gradients](#not-de-allocating-gradients)
17 | - [Launch command](#launch-command)
18 | - [Monitoring](#monitoring)
19 | - [Run statistics](#run-statistics)
20 | - [Other notes on settings that didn't affect throughput](#other-notes-on-settings-that-didnt-affect-throughput)
21 |
22 | ## Use flash attention
23 |
24 | Flash attention is a fused implementation of scaled dot product attention that heavily reduces memory usage. The whole goal behind it is to read from and write to GPU memory as little as possible, and to minimize the temporary memory used.
25 |
26 | Check out the [repo](https://github.com/Dao-AILab/flash-attention) and the [paper](https://arxiv.org/abs/2205.14135) for more information.
27 |
28 | This ends up saving us tens of GB of memory in the forward/backward pass.
29 |
30 | Install:
31 |
32 | ```bash
33 | pip install packaging
34 | pip install ninja
35 | pip install flash-attn --no-build-isolation
36 | ```
37 |
38 | Use it when we initialize our model:
39 |
40 | ```python
41 | model = AutoModelForCausalLM.from_pretrained(
42 | ...
43 | attn_implementation="flash_attention_2",
44 | )
45 | ```
46 |
47 | ## Download model weights
48 |
49 | The actual model weights are huge - they are split across 191 separate files, each about 4 GB, totaling about 764 GB.
50 |
51 | There are two options for storing these weights here (and they make a difference!):
52 |
53 | 1. A shared network drive that all the nodes can access
54 | 2. Locally on the main rank 0 node
55 |
56 | Node local storage is **much** faster when initializing. For some concrete numbers: while running this script on 8 8xH100 80GB nodes, initializing from the shared network drive took 50 minutes, while node local storage took only 3 minutes.
57 |
58 | This repo includes a download script for convenience; run this on node 0:
59 |
60 | ```bash
61 | cd distributed-training-guide/05-training-llama-405b
62 | python download.py
63 | ```
64 |
65 | And run this on the other nodes (to download config & tokenizer):
66 |
67 | ```bash
68 | cd distributed-training-guide/05-training-llama-405b
69 | python download.py --skip-model
70 | ```
71 |
72 | NOTE: you will likely have to log into your Hugging Face account using `huggingface-cli login`.
73 |
74 | ## Loading pretrained weights
75 |
76 | Actually loading the weights takes some time AND a lot of memory. Again, the full size is about 764 GB, so we need to make sure we have enough CPU RAM to hold the weights.
77 |
78 | There are four parts to this:
79 |
80 | 1. Loading the weights into RAM only on `rank==0`
81 | 2. Using the [meta](../04-fully-sharded-data-parallel/README.md#initialization-after-sharding---the-meta-device) device on `rank>0`
82 | 3. Using `from_config` instead of `from_pretrained` on `rank>0` so we don't need to download the weights on all the nodes.
83 | 1. Note that if you have the weights on a shared network drive, you can just use `from_pretrained` instead.
84 | 4. Enabling [sync_module_states](../04-fully-sharded-data-parallel/README.md#sync_module_states) in FSDP constructor
85 |
86 | You might think of using the `device_map` feature of `transformers` - e.g. `device_map="auto"` tries to smartly fill up memory. However, if you try this approach you'll end up with out-of-memory errors once FSDP starts moving parameters to the GPU.
87 |
88 | Here's our code snippet for doing this:
89 |
90 | ```python
91 | if rank == 0:
92 | with torch.device("cpu"):
93 | model = AutoModelForCausalLM.from_pretrained(...)
94 | else:
95 | with torch.device("meta"):
96 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
97 | ```
98 |
99 | Then later, sync_module_states in [FSDP constructor](../04-fully-sharded-data-parallel/README.md#the-fsdp-constructor) will make sure the weights are broadcasted from rank 0 to the other ranks.
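
A rough sketch of how this chapter's [train_llm.py](./train_llm.py) ties these pieces together in the FSDP constructor (`wrap_policy` here is the auto wrap policy covered in the next section) - `param_init_fn` materializes the meta tensors on `rank>0`, and `sync_module_states=True` broadcasts rank 0's weights to everyone else:

```python
model = FullyShardedDataParallel(
    model,
    device_id=local_rank,
    # ranks > 0 have parameters on the meta device; allocate (empty) GPU storage for them
    param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
    # broadcast the real weights from rank 0 to all other ranks
    sync_module_states=True,
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"),
)
```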
100 |
101 | ## Sharding Llama 405B
102 |
103 | Determining what layers you should shard is complex. If you are using `transformers`, they include a private class attribute called [_no_split_modules](https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/llama/modeling_llama.py#L784) that lists the modules you should not split when sharding. E.g. for Llama this attribute just contains `LlamaDecoderLayer`. So that is what we will wrap! During testing I also found that sharding the `nn.Embedding` layer at the beginning of the network improved throughput and reduced memory usage.
104 |
105 | We can use the [transformer_auto_wrap_policy()](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/wrap.py#L307C5-L307C33) to target the specific classes for those layers, and pass that as our [auto_wrap_policy in the FSDP constructor](../04-fully-sharded-data-parallel/README.md#what-layers-to-shard---the-auto_wrap_policy):
106 |
107 | ```python
108 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
109 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
110 |
111 | wrap_policy = functools.partial(
112 | transformer_auto_wrap_policy,
113 | transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
114 | )
115 | FSDP(..., auto_wrap_policy=wrap_policy)
116 | ```
117 |
118 | Please consult [our explanation on the FSDP constructor](../04-fully-sharded-data-parallel/README.md#the-fsdp-constructor) for more info.
119 |
120 | As a reminder - this will cause FSDP to gather all the parameters for each DecoderLayer (which includes Attention, Linear, and various norm modules), and shard them across all ranks. At the start of each DecoderLayer's forward/backward pass, FSDP will issue an all-gather so every rank has that layer's full weights in memory, and at the end of that layer's forward/backward it will free the full weights again.
121 |
122 | So where you apply FSDP determines where the all-gather happens!
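
An easy way to sanity check where the FSDP units (and therefore the all-gathers) end up is to simply print the wrapped model, which this chapter's train_llm.py does right after construction:

```python
# every FullyShardedDataParallel(...) wrapper in the printed tree is one unit whose
# parameters get all-gathered and freed together
LOGGER.info(f"FSDP architecture: {model}")
```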
123 |
124 | ## Gradient (aka activation) checkpointing
125 |
126 | Another piece of reducing memory usage is gradient checkpointing (first introduced in [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174)). Normally during the forward pass you have to keep each layer's inputs & outputs in memory until the backward pass, and keeping these intermediate tensors around takes a lot of memory. With gradient checkpointing, we instead **re-run** the forward pass during the backward pass to regenerate them. So we are doing more compute but saving a lot of memory.
127 |
128 | The method we are using is kind of a hidden method in pytorch, but this is actually exactly what [accelerate uses under the hood](https://github.com/huggingface/accelerate/blob/v0.34.2/src/accelerate/accelerator.py#L1492) so rest assured that it is a "standard" way of doing it:
129 |
130 | This piece of code has to go **after** the FSDP constructor!!! I'm not exactly sure of the reason, but it doesn't work before the FSDP initialization.
131 |
132 | ```python
133 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
134 | apply_activation_checkpointing,
135 | checkpoint_wrapper,
136 | )
137 |
138 | model = FSDP(...)
139 |
140 | apply_activation_checkpointing(
141 | model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy
142 | )
143 | ```
144 |
145 | ## CPU Offload & fused optimizer kernels
146 |
147 | Since the model is so large, we pretty much have to enable [CPU offloading](../04-fully-sharded-data-parallel/README.md#cpu-offload) with FSDP. **When using the CPUOffload feature of FSDP, the optimizer runs entirely on the CPU**. This is because there is a significant cost to transferring data to and from the GPU during `optimizer.step()`. At the time of writing there are open issues on how to overlap the `optimizer.step()` with the next `forward()` call.
148 |
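Enabling the offload is just a flag on the FSDP constructor - in this chapter's train_llm.py it is controlled by the `--cpu-offload on` argument:

```python
model = FullyShardedDataParallel(
    model,
    # ... other FSDP arguments as before ...
    # keep the sharded parameters in CPU memory, moving them to GPU only when needed
    cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"),
)
```
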
149 | By default the optimizer will use a non-fused kernel when running on the CPU, which generates a lot of intermediate tensors. By explicitly using the fused kernel we get a big speedup, which is especially important since we are running this step on the CPU:
150 |
151 | ```python
152 | torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
153 | ```
154 |
155 | If you want to peek through the pytorch code:
156 | 1. [_single_tensor_adamw()](https://github.com/pytorch/pytorch/blob/v2.4.1/torch/optim/adamw.py#L322) is the default implementation used
157 | 2. [_fused_adamw()](https://github.com/pytorch/pytorch/blob/v2.4.1/torch/optim/adamw.py#L612) is the fused implementation
158 |
159 | ## NOT de-allocating gradients
160 |
161 | You may have seen this `set_to_none` argument in [optimizer.zero_grad()](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html). According to the docs:
162 |
163 | > This will in general have lower memory footprint, and can modestly improve performance.
164 |
165 | Basically `set_to_none=True` will **deallocate the gradients** after they are used. In most GPU cases where we want to save a bit of memory, it is a good thing to de-allocate. However, in our case we are using CPU offload, which means all of our gradients are already on the CPU! Since they aren't taking up GPU memory, setting them to none just means paying for a lot of repeated allocation & de-allocation. So if you set `set_to_none=False` you should actually see a slight speed up in our case!
166 |
167 | ```python
168 | optimizer.zero_grad(set_to_none=args.cpu_offload == "off")
169 | ```
170 |
171 | ## Launch command
172 |
173 | That's pretty much all the changes you need from our base [FSDP code](../04-fully-sharded-data-parallel/). Now let's launch!
174 |
175 | We provide a customized [launch.sh](./launch.sh) script here based on the bash command for spawning torchrun on all available nodes:
176 |
177 | ```bash
178 | cd distributed-training-guide/05-training-llama-405b
179 | bash launch.sh # NOTE: this is non blocking
180 | ```
181 |
182 | Also note that this launch.sh specifies `HF_HOME` as an environment variable in the tmux session, so if you've not used the default value of `/home/ubuntu/.cache/huggingface`, please update the script!
183 |
184 | You can change the hostnames in the [hosts](./hosts) file in this directory.
185 |
186 | ## Monitoring
187 |
188 | We are using torchrun in our [launch.sh](./launch.sh) script, so we will get an output directory per node with a bunch of subdirectories containing our log files. It's a bit of a pain to manually monitor these, so here's a bash command for tailing all of them at once:
189 |
190 | ```bash
191 | cd distributed-training-guide/05-training-llama-405b
192 | find ../logs/ -name \*stderr.log | xargs tail -f
193 | ```
194 |
195 | Additionally, we have a `top`-like utility script for monitoring the entire cluster at the top level of this repo:
196 |
197 | ```bash
198 | cd distributed-training-guide/05-training-llama-405b
199 | python ../top-cluster.py hosts
200 | ```
201 |
202 | If you notice the nproc count or the power usage drop on any node, then you know that an error has occurred!
203 |
204 | To kill all the processes on all the nodes you can just kill the tmux sessions:
205 |
206 | ```bash
207 | xargs -a hosts -I{} ssh {} tmux kill-session -t torchrun-llama-405b
208 | ```
209 |
210 | ## Run statistics
211 |
212 | Training with `--seq-length 4096` and `--batch-size 1` on 64 H100 GPUs (8 separate nodes) has the following stats:
213 |
214 | - ~30s per iteration (data/forward/backward/update). Breakdown:
215 | - data: ~2ms
216 | - forward: ~7s
217 | - backward: ~19s
218 | - update: ~4s
219 | - Peak Memory Allocated: 52.9GB
220 | - Peak Memory Reserved: 77.9GB
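
As a back-of-the-envelope throughput estimate from these numbers: 64 GPUs × 1 sample × 4096 tokens ÷ ~30 s ≈ 8.7k tokens/s across the cluster.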
221 |
222 | Note that reserved memory reflects PyTorch's caching allocator, not just live tensors.
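
These numbers come straight from `torch.cuda.memory_stats()`; the `get_mem_stats()` helper in this chapter's train_llm.py is roughly:

```python
def get_mem_stats(device=None):
    mem = torch.cuda.memory_stats(device)
    props = torch.cuda.get_device_properties(device)
    return {
        "total_mem_in_gb": 1e-9 * props.total_memory,
        "curr_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.current"],
        "peak_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.peak"],
        "curr_resv_in_gb": 1e-9 * mem["reserved_bytes.all.current"],
        "peak_resv_in_gb": 1e-9 * mem["reserved_bytes.all.peak"],
    }
```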
223 |
224 | ## Other notes on settings that didn't affect throughput
225 |
226 | - Allowing tf32 had no impact on throughput (`torch.backends.cudnn.allow_tf32` and `torch.backends.cuda.matmul.allow_tf32`)
227 | - Enabling benchmarking had no impact on throughput (`torch.backends.cudnn.benchmark = True`)
228 | - Using cuDNN SDPA was slower (`attn_implementation="sdpa"` and `torch.backends.cuda.enable_cudnn_sdp(True)`)
229 | - torch.compile had no impact (`use_orig_params=True` and `torch.compile` after FSDP constructor)
230 | - Very minimal testing of NCCL environment variables either made things worse or had no impact (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html)
231 | - `PYTORCH_NO_CUDA_MEMORY_CACHING=1` made enough memory available that `--batch-size 2` or higher sequence lengths were possible, but it was much much slower.
232 | - It's possible that some well placed calls to `torch.cuda.empty_cache()` could achieve this without the throughput loss.
233 | - Only `FULL_SHARD` works. The other sharding strategies fail silently.
234 |
--------------------------------------------------------------------------------
/05-training-llama-405b/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import transformers
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--skip-model", default=False, action="store_true")
8 | args = parser.parse_args()
9 |
10 | os.environ["HF_HOME"] = "/home/ubuntu/.cache/huggingface"
11 |
12 | model_name = "meta-llama/Meta-Llama-3.1-405B"
13 |
14 | print(f"Downloading {model_name} to $HF_HOME = {os.environ['HF_HOME']}.")
15 |
16 | config = transformers.AutoConfig.from_pretrained(model_name)
17 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
18 | if not args.skip_model:
19 | with torch.device("meta"):
20 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
21 |
--------------------------------------------------------------------------------
/05-training-llama-405b/hosts:
--------------------------------------------------------------------------------
1 | ml-64-node-001
2 | ml-64-node-002
3 | ml-64-node-003
4 | ml-64-node-004
5 | ml-64-node-005
6 | ml-64-node-006
7 | ml-64-node-007
8 | ml-64-node-008
--------------------------------------------------------------------------------
/05-training-llama-405b/launch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | EXPERIMENT_NAME=llama-405b
4 |
5 | if [ ! -f ./hosts ]; then
6 | echo "ERROR: ./hosts file not found. Please add this file to this current directory."
7 | exit 1
8 | fi
9 |
10 | ssh $(head -n 1 hosts) $(which wandb) login
11 |
12 | xargs \
13 | -a hosts \
14 | -I {} \
15 | ssh {} \
16 | tmux new-session -d -s torchrun-${EXPERIMENT_NAME} -c $(pwd) \
17 | -e HF_HOME=/home/ubuntu/.cache/huggingface \
18 | -e OMP_NUM_THREADS=26 \
19 | -e NCCL_CROSS_NIC=1 \
20 | -e TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
21 | $(which python) -m torch.distributed.run \
22 | --rdzv-id ${EXPERIMENT_NAME} \
23 | --rdzv-backend c10d \
24 | --rdzv-endpoint $(head -n 1 hosts):5001 \
25 | --nnodes $(grep -c '^' hosts) \
26 | --nproc-per-node 8 \
27 | --redirects 3 \
28 | --log-dir ./logs \
29 | train_llm.py \
30 | --experiment-name ${EXPERIMENT_NAME} \
31 | --dataset-name Skylion007/openwebtext \
32 | --model-name meta-llama/Meta-Llama-3.1-405B \
33 | --batch-size 1 \
34 | --seq-length 4096 \
35 | --cpu-offload on \
36 | --log-freq 1
37 |
--------------------------------------------------------------------------------
/05-training-llama-405b/train_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from contextlib import contextmanager
3 | import functools
4 | from itertools import chain
5 | import json
6 | import multiprocessing
7 | import os
8 | import time
9 | from pathlib import Path
10 | import logging
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from torch.utils.data.distributed import DistributedSampler
15 | from torch import distributed as dist
16 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
17 | apply_activation_checkpointing,
18 | checkpoint_wrapper,
19 | )
20 | from torch.distributed.elastic.multiprocessing.errors import record
21 | from torch.distributed.fsdp.fully_sharded_data_parallel import (
22 | FullyShardedDataParallel,
23 | CPUOffload,
24 | ShardingStrategy,
25 | )
26 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
27 | from torch.distributed.checkpoint.state_dict import (
28 | get_state_dict,
29 | set_state_dict,
30 | StateDictOptions,
31 | )
32 | from torch.distributed.checkpoint import load, save
33 |
34 |
35 | import wandb
36 | import tqdm
37 | import datasets
38 | from transformers import (
39 | AutoConfig,
40 | AutoModelForCausalLM,
41 | AutoTokenizer,
42 | default_data_collator,
43 | )
44 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
45 |
46 | # fixes for reset_parameters not existing
47 | LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
48 | LlamaRotaryEmbedding.reset_parameters = lambda _: None
49 |
50 | LOGGER = logging.getLogger(__name__)
51 |
52 |
53 | @record
54 | def main():
55 | parser = _get_parser()
56 | args = parser.parse_args()
57 |
58 | dist.init_process_group()
59 |
60 | rank = dist.get_rank()
61 | local_rank = rank % torch.cuda.device_count()
62 | world_size = dist.get_world_size()
63 |
64 | logging.basicConfig(
65 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s",
66 | level=logging.INFO,
67 | )
68 |
69 | LOGGER.info(os.environ)
70 | LOGGER.info(args)
71 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
72 |
73 | device = torch.device(f"cuda:{local_rank}")
74 | dtype = torch.bfloat16
75 | torch.cuda.set_device(device)
76 |
77 | torch.manual_seed(args.seed)
78 |
79 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}")
80 |
81 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
82 | if rank == 0:
83 | with torch.device("cpu"):
84 | model = AutoModelForCausalLM.from_pretrained(
85 | args.model_name,
86 | torch_dtype=dtype,
87 | attn_implementation="flash_attention_2",
88 | use_cache=False,
89 | )
90 | else:
91 | with torch.device("meta"):
92 | model = AutoModelForCausalLM.from_config(
93 | config,
94 | torch_dtype=dtype,
95 | attn_implementation="flash_attention_2",
96 | )
97 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
98 |
99 | LOGGER.info(f"Before FSDP: {get_mem_stats(device)}")
100 |
101 | from torch.nn import Embedding
102 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
103 |
104 | wrap_policy = functools.partial(
105 | transformer_auto_wrap_policy,
106 | transformer_layer_cls={LlamaDecoderLayer, Embedding},
107 | )
108 | model = FullyShardedDataParallel(
109 | model,
110 | device_id=local_rank,
111 | param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
112 | sync_module_states=True,
113 | # NOTE: FULL_SHARD is equivalent to deepspeed ZeRO stage 3
114 | auto_wrap_policy=wrap_policy,
115 | sharding_strategy=ShardingStrategy.FULL_SHARD,
116 | cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"),
117 | )
118 |
119 | LOGGER.info(f"After FSDP: {get_mem_stats(device)}")
120 | LOGGER.info(f"FSDP architecture: {model}")
121 |
122 | # Applying gradient checkpointing - note that only the LlamaDecoderLayer supports this,
123 | # so we can just reuse our existing wrap_policy.
124 | apply_activation_checkpointing(
125 | model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy
126 | )
127 |
128 | # NOTE: since this can download data, make sure to do the main process first on each node
129 | # since we manually specified HF_HOME to be a node local drive.
130 | with rank_ordered(should_go_first=local_rank == 0):
131 | train_data = _load_and_preprocess_data(args, config)
132 | LOGGER.info(f"{len(train_data)} training samples")
133 |
134 | dataloader = DataLoader(
135 | train_data,
136 | batch_size=args.batch_size,
137 | collate_fn=default_data_collator,
138 | num_workers=1,
139 | prefetch_factor=2,
140 | # NOTE: this sampler will split dataset evenly across workers
141 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True),
142 | )
143 | LOGGER.info(f"{len(dataloader)} batches per epoch")
144 |
145 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
146 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
147 | optimizer, T_max=1000, eta_min=args.lr * 1e-2
148 | )
149 |
150 | exp_dir: Path = Path(args.save_dir) / args.experiment_name
151 |
152 | # NOTE: full_state_dict=False means we will be saving sharded checkpoints.
153 | ckpt_opts = StateDictOptions(full_state_dict=False, cpu_offload=True)
154 |
155 | # attempt resume
156 | state = {
157 | "epoch": 0,
158 | "global_step": 0,
159 | "epoch_step": 0,
160 | "running_loss": 0,
161 | }
162 | resumed = False
163 | if (exp_dir / "state.json").exists():
164 | sharded_model_state, sharded_optimizer_state = get_state_dict(
165 | model, optimizer, options=ckpt_opts
166 | )
167 | load(
168 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state),
169 | checkpoint_id=exp_dir / "checkpoint",
170 | )
171 | set_state_dict(
172 | model,
173 | optimizer,
174 | model_state_dict=sharded_model_state,
175 | optim_state_dict=sharded_optimizer_state,
176 | options=ckpt_opts,
177 | )
178 | lr_scheduler.load_state_dict(
179 | torch.load(
180 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True
181 | )
182 | )
183 | with open(exp_dir / "state.json") as fp:
184 | state = json.load(fp)
185 | resumed = True
186 | LOGGER.info(f"Resumed={resumed} | {state}")
187 | dist.barrier()
188 |
189 | if (exp_dir.is_mount() and rank == 0) or (
190 | not exp_dir.is_mount() and local_rank == 0
191 | ):
192 | LOGGER.info(f"Creating experiment root directory")
193 | exp_dir.mkdir(parents=True, exist_ok=True)
194 | dist.barrier()
195 |
196 | if rank == 0:
197 | wandb.init(
198 | project="distributed-training-guide",
199 | dir=exp_dir,
200 | name=args.experiment_name,
201 | id=args.experiment_name,
202 | resume="must" if resumed else None,
203 | save_code=True,
204 | config={
205 | "args": vars(args),
206 | "training_data_size": len(train_data),
207 | "num_batches": len(dataloader),
208 | "world_size": world_size,
209 | },
210 | )
211 |
212 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
213 |
214 | for state["epoch"] in range(state["epoch"], args.num_epochs):
215 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}")
216 |
217 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True)
218 | if state["epoch_step"] > 0:
219 | progress_bar.update(state["epoch_step"])
220 |
221 | dataloader.sampler.set_epoch(state["epoch"])
222 | batches = iter(dataloader)
223 |
224 | for i_step in range(len(dataloader)):
225 | with timers["data"], torch.no_grad():
226 | batch = next(batches)
227 | batch = {k: v.to(device=device) for k, v in batch.items()}
228 |
229 | if i_step < state["epoch_step"]:
230 | # NOTE: for resuming
231 | continue
232 |
233 | with timers["forward"]:
234 | outputs = model(**batch)
235 |
236 | with timers["backward"]:
237 | outputs.loss.backward()
238 |
239 | with timers["update"]:
240 | optimizer.step()
241 | lr_scheduler.step()
242 | optimizer.zero_grad(set_to_none=args.cpu_offload == "off")
243 |
244 | state["global_step"] += 1
245 | state["epoch_step"] += 1
246 | state["running_loss"] += outputs.loss.item()
247 | progress_bar.update(1)
248 |
249 | if state["global_step"] % args.log_freq == 0:
250 | tok_per_step = world_size * args.batch_size * args.seq_length
251 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
252 | info = {
253 | "global_step": state["global_step"],
254 | "lr": lr_scheduler.get_last_lr()[0],
255 | "running_loss": state["running_loss"] / args.log_freq,
256 | "epoch": state["epoch"],
257 | "epoch_progress": state["epoch_step"] / len(dataloader),
258 | "num_batches_remaining": len(dataloader) - i_step,
259 | **get_mem_stats(device),
260 | "tok/s": 1000 * tok_per_step / ms_per_step,
261 |                     "time/total": ms_per_step,
263 | **{
264 | f"time/{k}": timer.avg_elapsed_ms()
265 | for k, timer in timers.items()
266 | },
267 | }
268 |
269 | LOGGER.info(info)
270 | if rank == 0:
271 | wandb.log(info, step=state["global_step"])
272 |
273 | torch.cuda.reset_peak_memory_stats(device)
274 | state["running_loss"] = 0
275 | for t in timers.values():
276 | t.reset()
277 |
278 | if state["global_step"] % args.ckpt_freq == 0:
279 | LOGGER.info("Saving checkpoint.")
280 | dist.barrier()
281 | # NOTE: we have to call this on ALL ranks
282 | sharded_model_state, sharded_optimizer_state = get_state_dict(
283 | model, optimizer, options=ckpt_opts
284 | )
285 | save(
286 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state),
287 | checkpoint_id=exp_dir / "checkpoint",
288 | )
289 | if rank == 0:
290 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
291 | with open(exp_dir / "state.json", "w") as fp:
292 | json.dump(state, fp)
293 | dist.barrier()
294 |
295 | state["epoch_step"] = 0
296 |
297 |
298 | def _load_and_preprocess_data(args, config):
299 | """
300 | Function created using code found in
301 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py
302 | """
303 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
304 |
305 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True)
306 |
307 | column_names = data["train"].column_names
308 | text_column_name = "text" if "text" in column_names else column_names[0]
309 |
310 | def tokenize_function(examples):
311 | return tokenizer(examples[text_column_name])
312 |
313 | tokenized_datasets = data.map(
314 | tokenize_function,
315 | batched=True,
316 | remove_columns=column_names,
317 | num_proc=multiprocessing.cpu_count(),
318 | load_from_cache_file=True,
319 | desc="Running tokenizer on dataset",
320 | )
321 |
322 | seq_length = args.seq_length or tokenizer.model_max_length
323 | if seq_length > config.max_position_embeddings:
324 | seq_length = min(1024, config.max_position_embeddings)
325 |
326 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
327 | def group_texts(examples):
328 | # Concatenate all texts.
329 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
330 | total_length = len(concatenated_examples[list(examples.keys())[0]])
331 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
332 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
333 | if total_length > seq_length:
334 | total_length = (total_length // seq_length) * seq_length
335 | # Split by chunks of max_len.
336 | result = {
337 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
338 | for k, t in concatenated_examples.items()
339 | }
340 | result["labels"] = result["input_ids"].copy()
341 | return result
342 |
343 | lm_datasets = tokenized_datasets.map(
344 | group_texts,
345 | batched=True,
346 | num_proc=multiprocessing.cpu_count(),
347 | load_from_cache_file=True,
348 | desc=f"Grouping texts in chunks of {seq_length}",
349 | )
350 |
351 | return lm_datasets["train"]
352 |
353 |
354 | def get_mem_stats(device=None):
355 | mem = torch.cuda.memory_stats(device)
356 | props = torch.cuda.get_device_properties(device)
357 | return {
358 | "total_mem_in_gb": 1e-9 * props.total_memory,
359 | "curr_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.current"],
360 | "peak_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.peak"],
361 | "curr_resv_in_gb": 1e-9 * mem["reserved_bytes.all.current"],
362 | "peak_resv_in_gb": 1e-9 * mem["reserved_bytes.all.peak"],
363 | }
364 |
365 |
366 | @contextmanager
367 | def rank_ordered(*, should_go_first: bool):
368 | if should_go_first:
369 | yield
370 | dist.barrier()
371 | if not should_go_first:
372 | yield
373 | dist.barrier()
374 |
375 |
376 | class LocalTimer:
377 | def __init__(self, device: torch.device):
378 | if device.type == "cpu":
379 | self.synchronize = lambda: torch.cpu.synchronize(device=device)
380 | elif device.type == "cuda":
381 | self.synchronize = lambda: torch.cuda.synchronize(device=device)
382 | self.measurements = []
383 | self.start_time = None
384 |
385 | def __enter__(self):
386 | self.synchronize()
387 | self.start_time = time.time()
388 | return self
389 |
390 | def __exit__(self, type, value, traceback):
391 | if traceback is None:
392 | self.synchronize()
393 | end_time = time.time()
394 | self.measurements.append(end_time - self.start_time)
395 | self.start_time = None
396 |
397 | def avg_elapsed_ms(self):
398 | return 1000 * (sum(self.measurements) / len(self.measurements))
399 |
400 | def reset(self):
401 | self.measurements = []
402 | self.start_time = None
403 |
404 |
405 | def _get_parser() -> argparse.ArgumentParser:
406 | parser = argparse.ArgumentParser()
407 | parser.add_argument("-e", "--experiment-name", default=None, required=True)
408 | parser.add_argument("-d", "--dataset-name", default=None, required=True)
409 | parser.add_argument("-m", "--model-name", default=None, required=True)
410 | parser.add_argument("--save-dir", default="../outputs")
411 | parser.add_argument("--seed", default=0, type=int)
412 | parser.add_argument("--num-epochs", default=100, type=int)
413 | parser.add_argument("--lr", default=3e-5, type=float)
414 | parser.add_argument("-b", "--batch-size", default=1, type=int)
415 | parser.add_argument("--log-freq", default=100, type=int)
416 | parser.add_argument("--ckpt-freq", default=500, type=int)
417 | parser.add_argument("-s", "--seq-length", default=1024, type=int)
418 | parser.add_argument("--cpu-offload", default="on", choices=["on", "off"])
419 | return parser
420 |
421 |
422 | if __name__ == "__main__":
423 | main()
424 |
--------------------------------------------------------------------------------
/06-tensor-parallel/README.md:
--------------------------------------------------------------------------------
1 | # Tensor Parallelism (TP)
2 |
3 | So far we've just been using data parallel techniques. You may have heard of other parallelism techniques, and indeed the [Llama 405B paper](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/) actually uses 4D parallelism when training the 405B model:
4 |
5 | 1. Data parallel (FSDP as we've learned)
6 | 2. Tensor parallel (**this chapter**)
7 | 3. Context parallel (For long context lengths)
8 | 4. Pipeline/model parallel
9 |
10 | In this chapter we are going to dive into what tensor parallelism is, before we think about combining it with other types.
11 |
12 | ## Basics: What is tensor parallelism?
13 |
14 | TP splits the model weights **AND** computation across multiple GPUs.
15 |
16 | FSDP splits the model weights, but it gathers them back for the computation. Splitting the computation across GPUs is the difference.
17 |
18 | A result of this is that the data parallel world size is scaled **down** by your tensor parallel size, so the cost of allgathers/allreduces is reduced (e.g. 64 GPUs with a TP size of 8 leaves a data parallel group of only 8 ranks). This becomes a big factor when your cluster is large, and TP is a very effective way to scale up!
19 |
20 | Here are the benefits of this:
21 | 1. The peak GPU memory is reduced - instead of each GPU loading up the full weights for each layer, each GPU only loads `1/num_gpus` of the weights.
22 | 2. We now have `per GPU memory * num_gpus` as our amount of memory to use for each layer.
23 | 3. Less allgather/allreduce cost
24 |
25 | Here are the downsides:
26 | 1. Global batch size is reduced
27 | 2. Increased code complexity
28 |
29 | Note that this can only really be applied to certain modules, but most of the modules in an LLM work with it.
30 |
31 | ## Ensure all GPUs on a node get the same input
32 |
33 | Since we are splitting computation across GPUs, all GPUs in the same tensor parallel group need to receive the same input. (That is why the global batch size is reduced).
34 |
35 | First we are going to create our device mesh. A device mesh is just a way to view your devices in an N-dimensional way. So if you have 8 GPUs, you could organize it into a device mesh like `(2, 2, 2)`, or `(2, 4)`, or `(4, 2)` or even things like `(1, 8)`.
36 |
37 | The reason this is helpful is because we are going to name these dimensions, much like we do with tensor dimensions. Similar to how we have a batch and sequence dimension, for our device mesh we are going to have a data parallel and tensor parallel dimension.
38 |
39 | ```python
40 | gpus_on_node = torch.cuda.device_count()
41 | num_nodes = world_size // gpus_on_node
42 | mesh = dist.device_mesh.init_device_mesh(
43 | "cuda",
44 | (num_nodes, gpus_on_node),
45 | mesh_dim_names=("dp", "tp"),
46 | )
47 | ```
48 |
49 | So if we have 4 GPUs total, and have a `(2, 2)` device mesh, here are the assignments:
50 |
51 | | | DP rank | TP rank |
52 | | --- | --- | --- |
53 | | GPU 0 | 0 | 0 |
54 | | GPU 1 | 0 | 1 |
55 | | GPU 2 | 1 | 0 |
56 | | GPU 3 | 1 | 1 |
57 |
58 | This doesn't actually mean anything unless we update the rest of our code to use these device meshes, so let's see how we do that!
59 | A lot of the pytorch distributed APIs actually take an optional `mesh: Optional[DeviceMesh] = None` argument, we just haven't used it so far.
60 |
61 | The first place is actually our data sampler, and this is how we get all of our GPUs in the tensor parallel group the same input:
62 |
63 | ```python
64 | sampler=DistributedSampler(
65 | ...,
66 | num_replicas=mesh["dp"].size(),
67 | # NOTE: every GPU on a node will have the same "dp" rank,
68 | # meaning they will all receive the same input!
69 | rank=mesh["dp"].get_local_rank(),
70 | )
71 | ```
72 |
73 | From GPU 0's perspective above, it would have these arguments to DistributedSampler:
74 |
75 | | | num_replicas | rank|
76 | | --- | --- | --- |
77 | | GPU 0 | 2 | 0 |
78 | | GPU 1 | 2 | 0 |
79 | | GPU 2 | 2 | 1 |
80 | | GPU 3 | 2 | 1 |
81 |
82 | This is because our DP dimension has size 2, and the first table above shows the DP rank that each GPU passes to DistributedSampler.
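
If you want to verify this mapping on your own setup, a quick sanity check (not part of the chapter code) is to have every rank print its coordinates in the mesh:

```python
# each rank prints its global rank plus its position along the "dp" and "tp" mesh dimensions
dp_rank = mesh["dp"].get_local_rank()
tp_rank = mesh["tp"].get_local_rank()
print(f"global={dist.get_rank()} dp={dp_rank} tp={tp_rank}")
```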
83 |
84 | ## Parallelizing linear & attention modules
85 |
86 | Here's the code first and then there are graphics after this that explain how this works. Note that we are passing our `mesh["tp"]` to the API, which means this is happening across our tensor parallel group!
87 |
88 | ```python
89 | for layer in model.model.layers:
90 | tp.parallelize_module(
91 | layer,
92 | mesh["tp"],
93 | {
94 | "self_attn.q_proj": tp.ColwiseParallel(),
95 | "self_attn.k_proj": tp.ColwiseParallel(),
96 | "self_attn.v_proj": tp.ColwiseParallel(),
97 | "self_attn.o_proj": tp.RowwiseParallel(),
98 |
99 | "mlp.gate_proj": tp.ColwiseParallel(),
100 | "mlp.up_proj": tp.ColwiseParallel(),
101 | "mlp.down_proj": tp.RowwiseParallel(),
102 | },
103 | )
104 | ```
105 |
106 | ### colwise
107 |
108 | Our first three linear layers in self attention (the q/k/v projections) are all colwise parallel. This means we shard each weight matrix along dimension 0 (since `nn.Linear` stores its weight transposed). The remainder of the attention layer (including the attention computation itself) runs on this sharded output, so attention actually operates on smaller tensors.
109 |
110 |
111 |
112 | Image Source: [PyTorchLightning](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning#column-wise-parallel)
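
To make the shapes concrete, here is a small standalone sketch (with illustrative sizes, not taken from the chapter code) of what colwise sharding does to one of these projections:

```python
import torch.nn as nn

tp_size = 8  # hypothetical TP group size, e.g. all 8 GPUs on one node
q_proj = nn.Linear(4096, 4096, bias=False)  # illustrative hidden size

# nn.Linear stores its weight transposed: shape == (out_features, in_features)
print(q_proj.weight.shape)  # torch.Size([4096, 4096])

# ColwiseParallel shards dim 0 of this weight, so each TP rank holds a slice like:
rows_per_rank = q_proj.weight.shape[0] // tp_size
print(rows_per_rank, q_proj.weight.shape[1])  # 512 4096
# ...and produces output activations of shape (batch, seq, 512) instead of (batch, seq, 4096)
```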
113 |
114 | ### colwise into rowwise
115 |
116 | Our final layer in our self attention layer is another linear layer (o_proj). Note that we are doing rowwise parallel here. This actually lets us "recombine" across our tp dimension, as shown here:
117 |
118 |
119 |
120 | Image Source: [PyTorchLightning](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning#combined-parallel-layers)
121 |
122 | So the final output of self attention will be replicated again.
123 |
124 | ### Parallelizing Embedding layer
125 |
126 | The embedding weight gets sharded along dimension 1, meaning each GPU holds a different slice of the embedding vector associated with each token:
127 |
128 | | Embedding Weight Shape | Sharded Shape |
129 | | --- | --- |
130 | | `(vocab_size, hidden_dim)` | `(vocab_size, hidden_dim / mesh["tp"].size())` |
131 |
132 | In a normal embedding layer it:
133 | - Takes input tokens of `shape=(batch, seq)`
134 | - Outputs embeddings of `shape=(batch, seq, hidden_dim)`
135 |
136 | Now that we've sharded the embedding weight tensor, the layer will actually output:
137 | - Sharded output embeddings of `shape=(batch, seq, hidden_dim / mesh["tp"].size())`.
138 |
139 | We have a problem though: our *colwise* pieces of the `self_attn` module will receive the output of this module, and ColwiseParallel actually expects its input to be **replicated**, not sharded.
140 |
141 | So we need to do an allgather on the tensor to replicate it across the group (i.e. it will be back to `shape=(batch, seq, hidden_dim)`). Luckily we can just specify this additional transformation with the `output_layouts` argument:
142 |
143 | ```python
144 | tp.parallelize_module(
145 | model,
146 | mesh["tp"],
147 | {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
148 | )
149 | ```
150 |
151 | ### Parallelizing the final linear layer of the model
152 |
153 | ```python
154 | tp.parallelize_module(
155 | model,
156 | mesh["tp"],
157 | {
158 | "lm_head": tp.ColwiseParallel(
159 | output_layouts=Replicate()
160 | ),
161 | },
162 | )
163 | ```
164 |
165 | We have to include `Replicate()` here because our loss expects replicated tensors, but colwise by default shards on the last dimension.
166 |
167 | ## Parallelizing Norm Layers with SequenceParallel
168 |
169 | For normalization layers, it works a bit differently. We don't actually shard the layer's weights at all; instead, we shard the **input** along the sequence dimension!
170 |
171 | So our computation is split, and we need to do some work to join the results back together for the other modules:
172 |
173 | ```diff
174 | for layer in model.model.layers:
175 | tp.parallelize_module(
176 | layer,
177 | mesh["tp"],
178 | {
179 | + "input_layernorm": tp.SequenceParallel(),
180 | + "self_attn": tp.PrepareModuleInput(
181 | + input_kwarg_layouts={"hidden_states": Shard(dim=1)},
182 | + desired_input_kwarg_layouts={"hidden_states": Replicate()},
183 | + ),
184 | "self_attn.q_proj": tp.ColwiseParallel(),
185 | "self_attn.k_proj": tp.ColwiseParallel(),
186 | "self_attn.v_proj": tp.ColwiseParallel(),
187 | - "self_attn.o_proj": tp.RowwiseParallel(),
188 | + "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
189 | + "post_attention_layernorm": tp.SequenceParallel(),
190 | + "mlp": tp.PrepareModuleInput(
191 | + input_layouts=Shard(dim=1),
192 | + desired_input_layouts=Replicate(),
193 | + ),
194 | "mlp.gate_proj": tp.ColwiseParallel(),
195 | "mlp.up_proj": tp.ColwiseParallel(),
196 | - "mlp.down_proj": tp.RowwiseParallel(),
197 | + "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
198 | },
199 | )
200 | ```
201 |
202 | The `PrepareModuleInput` objects transform how the tensors are split up. E.g. for `self_attn`, the `hidden_states` input is sharded along dimension 1 because of the preceding `SequenceParallel`, but the `ColwiseParallel` layers expect their input to be replicated.
203 |
204 | We also need to change our embedding layer: since its output now feeds into a `SequenceParallel` layer, we need to shard it along dimension 1:
205 |
206 | ```diff
207 | tp.parallelize_module(
208 | model,
209 | mesh["tp"],
210 | - {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
211 | + {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))},
212 | )
213 | ```
214 |
215 | We actually need an additional change because of this, due to `transformers`-specific code. It computes the sequence length based on the output of the embedding layer, which will be wrong since we are now sharding it along the sequence dimension. Passing position_ids explicitly fixes this, but **it's very implementation specific**:
216 |
217 | ```diff
218 | with timers["data"], torch.no_grad():
219 | batch = next(batches)
220 | batch = {k: v.to(device=device) for k, v in batch.items()}
221 | + batch["position_ids"] = torch.arange(
222 | + 0, args.seq_length, device=device, dtype=torch.long
223 | + ).unsqueeze(0)
224 | ```
225 |
226 | And here is the diff for our final output from the network:
227 | ```diff
228 | tp.parallelize_module(
229 | model,
230 | mesh["tp"],
231 | {
232 | + "model.norm": tp.SequenceParallel(),
233 | "lm_head": tp.ColwiseParallel(
234 | + input_layouts=Shard(1),
235 | output_layouts=Replicate(),
236 | ),
237 | },
238 | )
239 | ```
240 |
241 | ## Parallelizing Loss computation
242 |
243 | There's an additional API for parallelizing the loss computation across the **class** dimension (at the time of writing it only works for cross entropy). We first need to use this context manager around our loss computation:
244 |
245 | ```python
246 | with tp.loss_parallel(), timers["forward"]:
247 | outputs = model(**batch)
248 |
249 | with tp.loss_parallel(), timers["backward"]:
250 | outputs.loss.backward()
251 | ```
252 |
253 | Then we also need to update the output of our `lm_head`, because `loss_parallel` requires a different sharding layout and a `DTensor` output:
254 |
255 | ```diff
256 | tp.parallelize_module(
257 | model,
258 | mesh["tp"],
259 | {
260 | "model.norm": tp.SequenceParallel(),
261 | "lm_head": tp.ColwiseParallel(
262 | input_layouts=Shard(1),
263 | - output_layouts=Replicate(),
264 | + output_layouts=Shard(-1),
265 | + use_local_output=False,
266 | ),
267 | },
268 | )
269 | ```
270 |
271 | `use_local_output=False` tells PyTorch to return a `DTensor` from the module, instead of a plain local `Tensor`.
272 |
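To make the distinction concrete, here is a small hypothetical sketch (not part of this chapter's training script) of the difference; with `use_local_output=False`, the `lm_head` output behaves like the `sharded` `DTensor` below rather than like the local shard:

```python
# Hypothetical sketch (not from the training script): DTensor vs. plain local Tensor.
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor

dist.init_process_group()
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = dist.device_mesh.init_device_mesh(
    "cuda", (dist.get_world_size(),), mesh_dim_names=("tp",)
)

torch.manual_seed(0)
full = torch.randn(4, 16, device="cuda")             # identical plain Tensor on every rank
sharded = distribute_tensor(full, mesh, [Shard(1)])  # a DTensor sharded on its last dim

print(type(sharded).__name__)       # DTensor - it knows its mesh and placements
print(sharded.to_local().shape)     # only this rank's shard, e.g. (4, 16 // world_size)
print(sharded.full_tensor().shape)  # all-gathers back to the full (4, 16)
```
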
273 | ## Computing throughput with our new world size
274 |
275 | Because a single GPU is no longer the unit of data parallelism (each TP group processes one batch together), we need to update our throughput calculation to use the data parallel size from our device mesh:
276 |
277 | ```diff
278 | if state["global_step"] % args.log_freq == 0:
279 | - tok_per_step = world_size * args.batch_size * args.seq_length
280 | + tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length
281 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
282 | ```
283 |
284 | ## Results
285 |
286 | Here are some results from launching training of Llama 3.1 8B on a single node of 8x H100s:
287 |
288 | Command:
289 | ```bash
290 | HF_HOME=/home/ubuntu/.cache/huggingface OMP_NUM_THREADS=26 torchrun --standalone --nproc-per-node gpu train_llm.py --experiment-name tp-llama-8b --dataset-name tatsu-lab/alpaca --model-name meta-llama/Llama-3.1-8B --log-freq 10 --batch-size 16 --seq-length 1024 --num-epochs 1
291 | ```
292 |
293 |
294 |
295 |
296 |
297 | ## Useful References
298 |
299 | For completeness, here are the relevant PyTorch docs/guides on how to achieve this:
300 | - [TP API docs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#tensor-parallelism-torch-distributed-tensor-parallel)
301 | - [2d Parallelism Tutorial](https://pytorch.org/tutorials/intermediate/TP_tutorial.html#large-scale-transformer-model-training-with-tensor-parallel-tp)
302 | - [Device Mesh tutorial](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html)
303 | - [PyTorch Lightning TP Tutorial](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning)
304 |
305 | ## PyTorch API Reference
306 | 
307 | Here is a brief explanation of how the API we are using works.
308 |
309 | - [tp.RowwiseParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.RowwiseParallel) shards the module's weights in a row-wise fashion.
310 |   - Inputs by default are sharded on the last dimension
311 |   - Outputs by default are replicated on all workers
312 | - [tp.ColwiseParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.ColwiseParallel) shards the module's weights in a column-wise fashion.
313 |   - Inputs by default are replicated on all workers
314 |   - Outputs by default are sharded on the last dimension
315 | - [tp.SequenceParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.SequenceParallel) shards the input/output across dimension 1. Module weights are NOT sharded.
316 | - [tp.PrepareModuleInput()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.PrepareModuleInput) lets you change the sharding configuration of input tensors.
317 | - `torch.distributed._tensor.Shard(dim=X)` indicates a tensor should be sharded along dimension X.
318 | - `torch.distributed._tensor.Replicate()` indicates a tensor should be replicated among all workers.
319 |
320 | How all of these things interact is actually very subtle and complex, which is why this guide is useful!
321 |
322 | You can also change most of the default behavior with arguments to these classes. For example, you can configure `RowwiseParallel` to assume its input is replicated instead of sharded.
323 |
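To tie the defaults above together, here is a minimal, self-contained sketch (not part of this chapter's training script; the module and dimension sizes are made up) of how `ColwiseParallel` and `RowwiseParallel` compose on a toy MLP when launched with `torchrun --nproc-per-node gpu`:

```python
# Toy sketch of composing ColwiseParallel + RowwiseParallel (illustrative only).
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.tensor.parallel as tp


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(128, 512, bias=False)
        self.down = nn.Linear(512, 128, bias=False)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))


dist.init_process_group()
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = dist.device_mesh.init_device_mesh(
    "cuda", (dist.get_world_size(),), mesh_dim_names=("tp",)
)

torch.manual_seed(0)  # same weights on every rank before sharding
mlp = ToyMLP().cuda()

# `up` is sharded column-wise: replicated input -> output sharded on the last dim.
# `down` is sharded row-wise: it consumes that sharded activation, and its default
# replicated output gives every rank the full (batch, 128) result via one all-reduce.
tp.parallelize_module(
    mlp,
    mesh,
    {"up": tp.ColwiseParallel(), "down": tp.RowwiseParallel()},
)

x = torch.randn(4, 128, device="cuda")
y = mlp(x)
print(y.shape)  # (4, 128), replicated on every rank
```

The chapter's real `parallelize_module` calls are doing exactly this, just with the Llama module names and with `SequenceParallel`/`PrepareModuleInput` mixed in.
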
--------------------------------------------------------------------------------
/06-tensor-parallel/train_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from contextlib import contextmanager
3 | from itertools import chain
4 | import json
5 | import multiprocessing
6 | import os
7 | import time
8 | from pathlib import Path
9 | import logging
10 |
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data.distributed import DistributedSampler
14 | from torch import distributed as dist
15 | import torch.distributed.tensor.parallel as tp
16 | from torch.distributed._tensor import Shard, Replicate
17 | from torch.distributed.elastic.multiprocessing.errors import record
18 | import torch.distributed.checkpoint as DCP
19 |
20 |
21 | import wandb
22 | import tqdm
23 | import datasets
24 | from transformers import (
25 | AutoConfig,
26 | AutoModelForCausalLM,
27 | AutoTokenizer,
28 | default_data_collator,
29 | )
30 |
31 | LOGGER = logging.getLogger(__name__)
32 |
33 |
34 | @record
35 | def main():
36 | parser = _get_parser()
37 | args = parser.parse_args()
38 |
39 | dist.init_process_group()
40 |
41 | gpus_on_node = torch.cuda.device_count()
42 |
43 | rank = dist.get_rank()
44 | local_rank = rank % gpus_on_node
45 | world_size = dist.get_world_size()
46 |
47 | assert (
48 | world_size % gpus_on_node == 0
49 | ), "This script assumes all nodes have the same amount of GPUs"
50 | num_nodes = world_size // gpus_on_node
51 |
52 | mesh = dist.device_mesh.init_device_mesh(
53 | "cuda",
54 | (num_nodes, gpus_on_node),
55 | mesh_dim_names=("dp", "tp"),
56 | )
57 |
58 | logging.basicConfig(
59 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s",
60 | level=logging.INFO,
61 | )
62 |
63 | LOGGER.info(os.environ)
64 | LOGGER.info(args)
65 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
66 | LOGGER.info(f"dp_size={mesh['dp'].size()} tp_size={mesh['tp'].size()}")
67 |
68 | device = torch.device(f"cuda:{local_rank}")
69 | dtype = torch.bfloat16
70 | torch.cuda.set_device(device)
71 |
72 | torch.manual_seed(args.seed)
73 |
74 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}")
75 |
76 | with rank_ordered(should_go_first=local_rank == 0):
77 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
78 | with device:
79 | model = AutoModelForCausalLM.from_config(
80 | config, torch_dtype=dtype, attn_implementation="flash_attention_2"
81 | )
82 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
83 |
84 | tp.parallelize_module(
85 | model,
86 | mesh["tp"],
87 | {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))},
88 | )
89 | for layer in model.model.layers:
90 | tp.parallelize_module(
91 | layer,
92 | mesh["tp"],
93 | {
94 | # SequenceParallel will apply sharding to sequence dimension.
95 | "input_layernorm": tp.SequenceParallel(),
96 |                 # The input to self_attn (which is the output of the SequenceParallel input_layernorm) will be sharded on dimension 1, but we want it to be the full replicated tensor.
97 | "self_attn": tp.PrepareModuleInput(
98 | input_kwarg_layouts={"hidden_states": Shard(dim=1)},
99 | desired_input_kwarg_layouts={"hidden_states": Replicate()},
100 | ),
101 | "self_attn.q_proj": tp.ColwiseParallel(),
102 | "self_attn.k_proj": tp.ColwiseParallel(),
103 | "self_attn.v_proj": tp.ColwiseParallel(),
104 | "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
105 | # Another sharding along sequence dimension.
106 | "post_attention_layernorm": tp.SequenceParallel(),
107 | "mlp": tp.PrepareModuleInput(
108 | input_layouts=Shard(dim=1),
109 | desired_input_layouts=Replicate(),
110 | ),
111 | "mlp.gate_proj": tp.ColwiseParallel(),
112 | "mlp.up_proj": tp.ColwiseParallel(),
113 | "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
114 | },
115 | )
116 |
117 | tp.parallelize_module(
118 | model,
119 | mesh["tp"],
120 | {
121 | "model.norm": tp.SequenceParallel(),
122 | "lm_head": tp.ColwiseParallel(
123 | input_layouts=Shard(1),
124 | output_layouts=Shard(-1), # for tp.loss_parallel
125 | use_local_output=False, # for tp.loss_parallel
126 | ),
127 | },
128 | )
129 |
130 | LOGGER.info(f"Final Architecture: {model}")
131 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
132 |
133 | model = model.to_empty(device=device)
134 | model.init_weights()
135 | model.train()
136 |
137 | LOGGER.info(f"{get_mem_stats(device)}")
138 |
139 | # NOTE: since this can download data, make sure to do the main process first on each node
140 | # since we manually specified HF_HOME to be a node local drive.
141 | with rank_ordered(should_go_first=local_rank == 0):
142 | train_data = _load_and_preprocess_data(args, config)
143 | LOGGER.info(f"{len(train_data)} training samples")
144 |
145 | dataloader = DataLoader(
146 | train_data,
147 | batch_size=args.batch_size,
148 | collate_fn=default_data_collator,
149 | num_workers=1,
150 | prefetch_factor=2,
151 | # NOTE: this sampler will split dataset evenly across workers
152 | sampler=DistributedSampler(
153 | train_data,
154 | shuffle=True,
155 | drop_last=True,
156 | num_replicas=mesh["dp"].size(), # equivalent to `num_nodes`
157 |             rank=mesh["dp"].get_local_rank(),  # equivalent to `rank // gpus_on_node`
158 | ),
159 | )
160 | LOGGER.info(f"{len(dataloader)} batches per epoch")
161 |
162 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
163 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
164 | optimizer, T_max=1000, eta_min=args.lr * 1e-2
165 | )
166 |
167 | exp_dir: Path = Path(args.save_dir) / args.experiment_name
168 |
169 | # attempt resume
170 | state = {
171 | "epoch": 0,
172 | "global_step": 0,
173 | "epoch_step": 0,
174 | "running_loss": 0,
175 | }
176 | resumed = False
177 | if (exp_dir / "state.json").exists():
178 | DCP.load(
179 | dict(model=model, optimizer=optimizer),
180 | checkpoint_id=exp_dir / "checkpoint",
181 | )
182 | lr_scheduler.load_state_dict(
183 | torch.load(
184 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True
185 | )
186 | )
187 | with open(exp_dir / "state.json") as fp:
188 | state = json.load(fp)
189 | resumed = True
190 | LOGGER.info(f"Resumed={resumed} | {state}")
191 | dist.barrier()
192 |
193 | if (exp_dir.is_mount() and rank == 0) or (
194 | not exp_dir.is_mount() and local_rank == 0
195 | ):
196 | LOGGER.info(f"Creating experiment root directory")
197 | exp_dir.mkdir(parents=True, exist_ok=True)
198 | dist.barrier()
199 |
200 | if rank == 0:
201 | wandb.init(
202 | project="distributed-training-guide",
203 | dir=exp_dir,
204 | name=args.experiment_name,
205 | id=args.experiment_name,
206 | resume="must" if resumed else None,
207 | save_code=True,
208 | config={
209 | "args": vars(args),
210 | "training_data_size": len(train_data),
211 | "num_batches": len(dataloader),
212 | "world_size": world_size,
213 | },
214 | )
215 |
216 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
217 |
218 | for state["epoch"] in range(state["epoch"], args.num_epochs):
219 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}")
220 |
221 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True)
222 | if state["epoch_step"] > 0:
223 | progress_bar.update(state["epoch_step"])
224 |
225 | batches = iter(dataloader)
226 |
227 | for i_step in range(len(dataloader)):
228 | with timers["data"], torch.no_grad():
229 | batch = next(batches)
230 | batch = {k: v.to(device=device) for k, v in batch.items()}
231 | batch["position_ids"] = torch.arange(
232 | 0, args.seq_length, device=device, dtype=torch.long
233 | ).unsqueeze(0)
234 |
235 | if i_step < state["epoch_step"]:
236 | # NOTE: for resuming
237 | continue
238 |
239 | with tp.loss_parallel(), timers["forward"]:
240 | outputs = model(**batch)
241 |
242 | with tp.loss_parallel(), timers["backward"]:
243 | outputs.loss.backward()
244 |
245 | with timers["update"]:
246 | optimizer.step()
247 | lr_scheduler.step()
248 | optimizer.zero_grad(set_to_none=True)
249 |
250 | state["global_step"] += 1
251 | state["epoch_step"] += 1
252 | state["running_loss"] += outputs.loss.item()
253 | progress_bar.update(1)
254 |
255 | if state["global_step"] % args.log_freq == 0:
256 | tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length
257 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
258 | info = {
259 | "global_step": state["global_step"],
260 | "lr": lr_scheduler.get_last_lr()[0],
261 | "running_loss": state["running_loss"] / args.log_freq,
262 | "epoch": state["epoch"],
263 | "epoch_progress": state["epoch_step"] / len(dataloader),
264 | "num_batches_remaining": len(dataloader) - i_step,
265 | "tok/s": 1000 * tok_per_step / ms_per_step,
266 | **get_mem_stats(device),
267 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()),
268 | **{
269 | f"time/{k}": timer.avg_elapsed_ms()
270 | for k, timer in timers.items()
271 | },
272 | }
273 |
274 | LOGGER.info(info)
275 | if rank == 0:
276 | wandb.log(info, step=state["global_step"])
277 |
278 | torch.cuda.reset_peak_memory_stats(device)
279 | state["running_loss"] = 0
280 | for t in timers.values():
281 | t.reset()
282 |
283 | if state["global_step"] % args.ckpt_freq == 0:
284 | LOGGER.info("Saving checkpoint.")
285 | dist.barrier()
286 | # NOTE: we have to call this on ALL ranks
287 | DCP.save(
288 | dict(model=model, optimizer=optimizer),
289 | checkpoint_id=exp_dir / "checkpoint",
290 | )
291 | if rank == 0:
292 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
293 | with open(exp_dir / "state.json", "w") as fp:
294 | json.dump(state, fp)
295 | dist.barrier()
296 |
297 | state["epoch_step"] = 0
298 |
299 |
300 | def get_mem_stats(device=None):
301 | mem = torch.cuda.memory_stats(device)
302 | props = torch.cuda.get_device_properties(device)
303 | return {
304 | "total_gb": 1e-9 * props.total_memory,
305 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"],
306 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"],
307 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"],
308 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"],
309 | }
310 |
311 |
312 | def _load_and_preprocess_data(args, config):
313 | """
314 | Function created using code found in
315 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py
316 | """
317 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
318 |
319 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True)
320 |
321 | column_names = data["train"].column_names
322 | text_column_name = "text" if "text" in column_names else column_names[0]
323 |
324 | def tokenize_function(examples):
325 | return tokenizer(examples[text_column_name])
326 |
327 | tokenized_datasets = data.map(
328 | tokenize_function,
329 | batched=True,
330 | remove_columns=column_names,
331 | num_proc=multiprocessing.cpu_count(),
332 | load_from_cache_file=True,
333 | desc="Running tokenizer on dataset",
334 | )
335 |
336 | seq_length = args.seq_length or tokenizer.model_max_length
337 | if seq_length > config.max_position_embeddings:
338 | seq_length = min(1024, config.max_position_embeddings)
339 |
340 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
341 | def group_texts(examples):
342 | # Concatenate all texts.
343 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
344 | total_length = len(concatenated_examples[list(examples.keys())[0]])
345 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
346 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
347 | if total_length > seq_length:
348 | total_length = (total_length // seq_length) * seq_length
349 | # Split by chunks of max_len.
350 | result = {
351 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
352 | for k, t in concatenated_examples.items()
353 | }
354 | result["labels"] = result["input_ids"].copy()
355 | return result
356 |
357 | lm_datasets = tokenized_datasets.map(
358 | group_texts,
359 | batched=True,
360 | num_proc=multiprocessing.cpu_count(),
361 | load_from_cache_file=True,
362 | desc=f"Grouping texts in chunks of {seq_length}",
363 | )
364 |
365 | return lm_datasets["train"]
366 |
367 |
368 | @contextmanager
369 | def rank_ordered(*, should_go_first: bool):
370 | if should_go_first:
371 | yield
372 | dist.barrier()
373 | if not should_go_first:
374 | yield
375 | dist.barrier()
376 |
377 |
378 | class LocalTimer:
379 | def __init__(self, device: torch.device):
380 | if device.type == "cpu":
381 | self.synchronize = lambda: torch.cpu.synchronize(device=device)
382 | elif device.type == "cuda":
383 | self.synchronize = lambda: torch.cuda.synchronize(device=device)
384 | self.measurements = []
385 | self.start_time = None
386 |
387 | def __enter__(self):
388 | self.synchronize()
389 | self.start_time = time.time()
390 | return self
391 |
392 | def __exit__(self, type, value, traceback):
393 | if traceback is None:
394 | self.synchronize()
395 | end_time = time.time()
396 | self.measurements.append(end_time - self.start_time)
397 | self.start_time = None
398 |
399 | def avg_elapsed_ms(self):
400 | return 1000 * (sum(self.measurements) / len(self.measurements))
401 |
402 | def reset(self):
403 | self.measurements = []
404 | self.start_time = None
405 |
406 |
407 | def _get_parser() -> argparse.ArgumentParser:
408 | parser = argparse.ArgumentParser()
409 | parser.add_argument("-e", "--experiment-name", default=None, required=True)
410 | parser.add_argument("-d", "--dataset-name", default=None, required=True)
411 | parser.add_argument("-m", "--model-name", default=None, required=True)
412 | parser.add_argument("--save-dir", default="../outputs")
413 | parser.add_argument("--seed", default=0, type=int)
414 | parser.add_argument("--num-epochs", default=100, type=int)
415 | parser.add_argument("--lr", default=3e-5, type=float)
416 | parser.add_argument("-b", "--batch-size", default=1, type=int)
417 | parser.add_argument("--log-freq", default=100, type=int)
418 | parser.add_argument("--ckpt-freq", default=500, type=int)
419 | parser.add_argument("-s", "--seq-length", default=None, type=int)
420 | return parser
421 |
422 |
423 | if __name__ == "__main__":
424 | main()
425 |
--------------------------------------------------------------------------------
/07-2d-parallel/README.md:
--------------------------------------------------------------------------------
1 | # 2d parallelism (TP + DP)
2 |
3 | Using both [FSDP](../04-fully-sharded-data-parallel) and [TP](../06-tensor-parallel) together is actually quite simple code-wise when starting from our [chapter 6 TP script](../06-tensor-parallel/train_llm.py).
4 |
5 | **Disclaimer**: this only works with PyTorch's **newer FSDP 2 API, which is still in alpha**.
6 |
7 | What does using these two together mean exactly? Let's walk through an example with 6 GPUs, 2-way FSDP and 3-way TP:
8 |
9 |
10 |
11 | When we first start out, every GPU holds the full model. Then we shard the model into 3 pieces (our TP dimension). The 3 shards in the graphic above are red+orange, yellow+green, and blue+purple. Note that GPU 0 and GPU 3 **have the exact same shard**! This is because they have the same tensor parallel rank, but different data parallel ranks. This means we have **duplicated** our model across our data parallel dimension.
12 |
13 | When we apply FSDP in the next step, we split those duplicated shards! So Shard red+orange (which is duplicated on GPU 0 & 3) is split into two pieces (Shard red and Shard orange).
14 |
15 | By the end we have 6 distinct shards of our model split on every GPU.
16 |
17 | Now, if you remember, FSDP does an all-gather of all the shards before the forward pass. When GPU 0 & GPU 3 are executing their forward passes, they will gather the two shards (Shard red and Shard orange) into local memory to form Shard red+orange, so that each one can use the full shard during computation.
18 |
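To make the grouping concrete, here is a short sketch (hypothetical, mirroring the 6-GPU figure above; not part of this chapter's script) of what that 2 x 3 device mesh looks like and which ranks end up grouped together:

```python
# Sketch of the 6 GPU example: 2-way data parallel (FSDP) x 3-way tensor parallel.
import torch
import torch.distributed as dist

dist.init_process_group()  # e.g. launched with `torchrun --nproc-per-node 6` (hypothetical)
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = dist.device_mesh.init_device_mesh(
    "cuda", (2, 3), mesh_dim_names=("dp", "tp")
)

# Ranks are laid out row-major over the (dp, tp) grid:
#   tp groups (rows):    {0, 1, 2} and {3, 4, 5}  - ranks that jointly hold one copy of the model
#   dp groups (columns): {0, 3}, {1, 4}, {2, 5}   - ranks that hold the same TP shard
# So GPU 0 and GPU 3 have the same tp rank (the same shard) but different dp ranks,
# which is exactly the duplication that fully_shard() splits up in the next step.
print(mesh["tp"].get_local_rank(), mesh["dp"].get_local_rank())
```
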
19 | ## Applying FSDP after TP
20 |
21 | We are starting from our [chapter 6 code](../06-tensor-parallel/train_llm.py), which already supports TP. So we just need to add FSDP to the script.
22 |
23 | The API is much simpler than the FSDP 1 API; this is all we need to add **after** our TP code:
24 |
25 | ```python
26 | from torch.distributed._composable.fsdp import fully_shard
27 |
28 | if mesh["dp"].size() > 1:
29 | for layer in model.model.layers:
30 | fully_shard(layer, mesh=mesh["dp"])
31 | fully_shard(model, mesh=mesh["dp"])
32 | ```
33 |
34 | Note how we are passing our `mesh["dp"]` here to indicate that this is happening across our data parallel dimension.
35 |
36 | ## Controlling TP size
37 |
38 | When creating our mesh we are going to set the TP size based on a CLI argument:
39 |
40 | ```python
41 | assert world_size % args.tp == 0
42 |
43 | mesh = dist.device_mesh.init_device_mesh(
44 | "cuda",
45 | (world_size // args.tp, args.tp),
46 | mesh_dim_names=("dp", "tp"),
47 | )
48 | ```
49 |
50 | and add it to our argparser:
51 |
52 | ```python
53 | parser.add_argument("--tp", default=8, type=int)
54 | ```
55 |
56 | ## Performance with different configurations
57 |
58 | Here are some training results for 4 different setups of the TP size:
59 | - 1x8 is 8 way TP, and no data parallelism. `--batch-size 18 --tp 8`
60 | - 2x4 is 4 way TP, with 2 groups of FSDP. `--batch-size 14 --tp 4`
61 | - 4x2 is 2 way TP, with 4 groups of FSDP. `--batch-size 10 --tp 2`
62 | - 8x1 is FSDP. `--batch-size 7 --tp 1`
63 |
64 | Note that all of these runs have the same `--lr` while having different batch sizes, which is why the loss curves are slightly different.
65 |
66 |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/07-2d-parallel/train_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from contextlib import contextmanager
3 | from itertools import chain
4 | import json
5 | import multiprocessing
6 | import os
7 | import time
8 | from pathlib import Path
9 | import logging
10 |
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data.distributed import DistributedSampler
14 | from torch import distributed as dist
15 | import torch.distributed.tensor.parallel as tp
16 | from torch.distributed._tensor import Shard, Replicate
17 | from torch.distributed.elastic.multiprocessing.errors import record
18 | import torch.distributed.checkpoint as DCP
19 | from torch.distributed._composable.fsdp import fully_shard
20 |
21 | import wandb
22 | import tqdm
23 | import datasets
24 | from transformers import (
25 | AutoConfig,
26 | AutoModelForCausalLM,
27 | AutoTokenizer,
28 | default_data_collator,
29 | )
30 |
31 | LOGGER = logging.getLogger(__name__)
32 |
33 |
34 | @record
35 | def main():
36 | parser = _get_parser()
37 | args = parser.parse_args()
38 |
39 | dist.init_process_group()
40 |
41 | rank = dist.get_rank()
42 | local_rank = rank % torch.cuda.device_count()
43 | world_size = dist.get_world_size()
44 |
45 | assert args.tp > 1
46 | assert world_size % args.tp == 0
47 |
48 | mesh = dist.device_mesh.init_device_mesh(
49 | "cuda",
50 | (world_size // args.tp, args.tp),
51 | mesh_dim_names=("dp", "tp"),
52 | )
53 |
54 | logging.basicConfig(
55 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s",
56 | level=logging.INFO,
57 | )
58 |
59 | LOGGER.info(os.environ)
60 | LOGGER.info(args)
61 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
62 | LOGGER.info(f"dp_size={mesh['dp'].size()} tp_size={mesh['tp'].size()}")
63 |
64 | device = torch.device(f"cuda:{local_rank}")
65 | dtype = torch.bfloat16
66 | torch.cuda.set_device(device)
67 |
68 | torch.manual_seed(args.seed)
69 |
70 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}")
71 |
72 | with rank_ordered(should_go_first=local_rank == 0):
73 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
74 | with device:
75 | model = AutoModelForCausalLM.from_config(
76 | config, torch_dtype=dtype, attn_implementation="flash_attention_2"
77 | )
78 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
79 |
80 | tp.parallelize_module(
81 | model,
82 | mesh["tp"],
83 | {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))},
84 | )
85 | for layer in model.model.layers:
86 | tp.parallelize_module(
87 | layer,
88 | mesh["tp"],
89 | {
90 | # SequenceParallel will apply sharding to sequence dimension.
91 | "input_layernorm": tp.SequenceParallel(),
92 |                 # The input to self_attn (which is the output of the SequenceParallel input_layernorm) will be sharded on dimension 1, but we want it to be the full replicated tensor.
93 | "self_attn": tp.PrepareModuleInput(
94 | input_kwarg_layouts={"hidden_states": Shard(dim=1)},
95 | desired_input_kwarg_layouts={"hidden_states": Replicate()},
96 | ),
97 | "self_attn.q_proj": tp.ColwiseParallel(),
98 | "self_attn.k_proj": tp.ColwiseParallel(),
99 | "self_attn.v_proj": tp.ColwiseParallel(),
100 | "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
101 | # Another sharding along sequence dimension.
102 | "post_attention_layernorm": tp.SequenceParallel(),
103 | "mlp": tp.PrepareModuleInput(
104 | input_layouts=Shard(dim=1),
105 | desired_input_layouts=Replicate(),
106 | ),
107 | "mlp.gate_proj": tp.ColwiseParallel(),
108 | "mlp.up_proj": tp.ColwiseParallel(),
109 | "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
110 | },
111 | )
112 |
113 | tp.parallelize_module(
114 | model,
115 | mesh["tp"],
116 | {
117 | "model.norm": tp.SequenceParallel(),
118 | "lm_head": tp.ColwiseParallel(
119 | input_layouts=Shard(1),
120 | output_layouts=Shard(-1), # for tp.loss_parallel
121 | use_local_output=False, # for tp.loss_parallel
122 | ),
123 | },
124 | )
125 |
126 | for layer in model.model.layers:
127 | fully_shard(layer, mesh=mesh["dp"])
128 | fully_shard(model, mesh=mesh["dp"])
129 |
130 | LOGGER.info(f"Final Architecture: {model}")
131 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
132 |
133 | model = model.to_empty(device=device)
134 | model.init_weights()
135 | model.train()
136 |
137 | LOGGER.info(f"{get_mem_stats(device)}")
138 |
139 | # NOTE: since this can download data, make sure to do the main process first on each node
140 | # since we manually specified HF_HOME to be a node local drive.
141 | with rank_ordered(should_go_first=local_rank == 0):
142 | train_data = _load_and_preprocess_data(args, config)
143 | LOGGER.info(f"{len(train_data)} training samples")
144 |
145 | dataloader = DataLoader(
146 | train_data,
147 | batch_size=args.batch_size,
148 | collate_fn=default_data_collator,
149 | num_workers=1,
150 | prefetch_factor=2,
151 | # NOTE: this sampler will split dataset evenly across workers
152 | sampler=DistributedSampler(
153 | train_data,
154 | shuffle=True,
155 | drop_last=True,
156 |             num_replicas=mesh["dp"].size(),  # equivalent to `world_size // args.tp`
157 |             rank=mesh["dp"].get_local_rank(),  # equivalent to `rank // args.tp`
158 | ),
159 | )
160 | LOGGER.info(f"{len(dataloader)} batches per epoch")
161 |
162 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True)
163 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
164 | optimizer, T_max=1000, eta_min=args.lr * 1e-2
165 | )
166 |
167 | exp_dir: Path = Path(args.save_dir) / args.experiment_name
168 |
169 | # attempt resume
170 | state = {
171 | "epoch": 0,
172 | "global_step": 0,
173 | "epoch_step": 0,
174 | "running_loss": 0,
175 | }
176 | resumed = False
177 | if (exp_dir / "state.json").exists():
178 | DCP.load(
179 | dict(model=model, optimizer=optimizer),
180 | checkpoint_id=exp_dir / "checkpoint",
181 | )
182 | lr_scheduler.load_state_dict(
183 | torch.load(
184 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True
185 | )
186 | )
187 | with open(exp_dir / "state.json") as fp:
188 | state = json.load(fp)
189 | resumed = True
190 | LOGGER.info(f"Resumed={resumed} | {state}")
191 | dist.barrier()
192 |
193 | if (exp_dir.is_mount() and rank == 0) or (
194 | not exp_dir.is_mount() and local_rank == 0
195 | ):
196 | LOGGER.info(f"Creating experiment root directory")
197 | exp_dir.mkdir(parents=True, exist_ok=True)
198 | dist.barrier()
199 |
200 | if rank == 0:
201 | wandb.init(
202 | project="distributed-training-guide",
203 | dir=exp_dir,
204 | name=args.experiment_name,
205 | id=args.experiment_name,
206 | resume="must" if resumed else None,
207 | save_code=True,
208 | config={
209 | "args": vars(args),
210 | "training_data_size": len(train_data),
211 | "num_batches": len(dataloader),
212 | "world_size": world_size,
213 | },
214 | )
215 |
216 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
217 |
218 | for state["epoch"] in range(state["epoch"], args.num_epochs):
219 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}")
220 |
221 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True)
222 | if state["epoch_step"] > 0:
223 | progress_bar.update(state["epoch_step"])
224 |
225 | batches = iter(dataloader)
226 |
227 | for i_step in range(len(dataloader)):
228 | with timers["data"], torch.no_grad():
229 | batch = next(batches)
230 | batch = {k: v.to(device=device) for k, v in batch.items()}
231 | batch["position_ids"] = torch.arange(
232 | 0, args.seq_length, device=device, dtype=torch.long
233 | ).unsqueeze(0)
234 |
235 | if i_step < state["epoch_step"]:
236 | # NOTE: for resuming
237 | continue
238 |
239 | with tp.loss_parallel(), timers["forward"]:
240 | outputs = model(**batch)
241 |
242 | with tp.loss_parallel(), timers["backward"]:
243 | outputs.loss.backward()
244 |
245 | with timers["update"]:
246 | optimizer.step()
247 | lr_scheduler.step()
248 | optimizer.zero_grad(set_to_none=True)
249 |
250 | state["global_step"] += 1
251 | state["epoch_step"] += 1
252 | state["running_loss"] += outputs.loss.item()
253 | progress_bar.update(1)
254 |
255 | if state["global_step"] % args.log_freq == 0:
256 | tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length
257 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
258 | info = {
259 | "global_step": state["global_step"],
260 | "lr": lr_scheduler.get_last_lr()[0],
261 | "running_loss": state["running_loss"] / args.log_freq,
262 | "epoch": state["epoch"],
263 | "epoch_progress": state["epoch_step"] / len(dataloader),
264 | "num_batches_remaining": len(dataloader) - i_step,
265 | "tok/s": 1000 * tok_per_step / ms_per_step,
266 | **get_mem_stats(device),
267 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()),
268 | **{
269 | f"time/{k}": timer.avg_elapsed_ms()
270 | for k, timer in timers.items()
271 | },
272 | }
273 |
274 | LOGGER.info(info)
275 | if rank == 0:
276 | wandb.log(info, step=state["global_step"])
277 |
278 | torch.cuda.reset_peak_memory_stats(device)
279 | state["running_loss"] = 0
280 | for t in timers.values():
281 | t.reset()
282 |
283 | if state["global_step"] % args.ckpt_freq == 0:
284 | LOGGER.info("Saving checkpoint.")
285 | dist.barrier()
286 | # NOTE: we have to call this on ALL ranks
287 | DCP.save(
288 | dict(model=model, optimizer=optimizer),
289 | checkpoint_id=exp_dir / "checkpoint",
290 | )
291 | if rank == 0:
292 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
293 | with open(exp_dir / "state.json", "w") as fp:
294 | json.dump(state, fp)
295 | dist.barrier()
296 |
297 | state["epoch_step"] = 0
298 |
299 |
300 | def get_mem_stats(device=None):
301 | mem = torch.cuda.memory_stats(device)
302 | props = torch.cuda.get_device_properties(device)
303 | return {
304 | "total_gb": 1e-9 * props.total_memory,
305 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"],
306 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"],
307 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"],
308 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"],
309 | }
310 |
311 |
312 | def _load_and_preprocess_data(args, config):
313 | """
314 | Function created using code found in
315 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py
316 | """
317 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
318 |
319 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True)
320 |
321 | column_names = data["train"].column_names
322 | text_column_name = "text" if "text" in column_names else column_names[0]
323 |
324 | def tokenize_function(examples):
325 | return tokenizer(examples[text_column_name])
326 |
327 | tokenized_datasets = data.map(
328 | tokenize_function,
329 | batched=True,
330 | remove_columns=column_names,
331 | num_proc=multiprocessing.cpu_count(),
332 | load_from_cache_file=True,
333 | desc="Running tokenizer on dataset",
334 | )
335 |
336 | seq_length = args.seq_length or tokenizer.model_max_length
337 | if seq_length > config.max_position_embeddings:
338 | seq_length = min(1024, config.max_position_embeddings)
339 |
340 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
341 | def group_texts(examples):
342 | # Concatenate all texts.
343 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
344 | total_length = len(concatenated_examples[list(examples.keys())[0]])
345 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
346 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
347 | if total_length > seq_length:
348 | total_length = (total_length // seq_length) * seq_length
349 | # Split by chunks of max_len.
350 | result = {
351 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
352 | for k, t in concatenated_examples.items()
353 | }
354 | result["labels"] = result["input_ids"].copy()
355 | return result
356 |
357 | lm_datasets = tokenized_datasets.map(
358 | group_texts,
359 | batched=True,
360 | num_proc=multiprocessing.cpu_count(),
361 | load_from_cache_file=True,
362 | desc=f"Grouping texts in chunks of {seq_length}",
363 | )
364 |
365 | return lm_datasets["train"]
366 |
367 |
368 | @contextmanager
369 | def rank_ordered(*, should_go_first: bool):
370 | if should_go_first:
371 | yield
372 | dist.barrier()
373 | if not should_go_first:
374 | yield
375 | dist.barrier()
376 |
377 |
378 | class LocalTimer:
379 | def __init__(self, device: torch.device):
380 | if device.type == "cpu":
381 | self.synchronize = lambda: torch.cpu.synchronize(device=device)
382 | elif device.type == "cuda":
383 | self.synchronize = lambda: torch.cuda.synchronize(device=device)
384 | self.measurements = []
385 | self.start_time = None
386 |
387 | def __enter__(self):
388 | self.synchronize()
389 | self.start_time = time.time()
390 | return self
391 |
392 | def __exit__(self, type, value, traceback):
393 | if traceback is None:
394 | self.synchronize()
395 | end_time = time.time()
396 | self.measurements.append(end_time - self.start_time)
397 | self.start_time = None
398 |
399 | def avg_elapsed_ms(self):
400 | return 1000 * (sum(self.measurements) / len(self.measurements))
401 |
402 | def reset(self):
403 | self.measurements = []
404 | self.start_time = None
405 |
406 |
407 | def _get_parser() -> argparse.ArgumentParser:
408 | parser = argparse.ArgumentParser()
409 | parser.add_argument("-e", "--experiment-name", default=None, required=True)
410 | parser.add_argument("-d", "--dataset-name", default=None, required=True)
411 | parser.add_argument("-m", "--model-name", default=None, required=True)
412 | parser.add_argument("--save-dir", default="../outputs")
413 | parser.add_argument("--seed", default=0, type=int)
414 | parser.add_argument("--num-epochs", default=100, type=int)
415 | parser.add_argument("--lr", default=3e-5, type=float)
416 | parser.add_argument("-b", "--batch-size", default=1, type=int)
417 | parser.add_argument("--log-freq", default=100, type=int)
418 | parser.add_argument("--ckpt-freq", default=500, type=int)
419 | parser.add_argument("-s", "--seq-length", default=None, type=int)
420 | parser.add_argument("--tp", default=8, type=int)
421 | return parser
422 |
423 |
424 | if __name__ == "__main__":
425 | main()
426 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Lambda, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distributed Training Guide
2 |
3 |
4 |
5 | [Neurips 2024 presentation slides here](https://docs.google.com/presentation/d/1ANMmkOGaruYKTvhnsAbZgI9GrdMliNvibWGuNYw6HX8/edit?usp=sharing)
6 |
7 | Ever wondered how to train a large neural network across a giant cluster? Look no further!
8 |
9 | This is a comprehensive guide on best practices for distributed training, diagnosing errors, and fully utilizing all resources available. It is organized into sequential chapters, each with a `README.md` and a `train_llm.py` script in them. The readme will discuss both the high level concepts of distributed training, and the code changes introduced in that chapter.
10 |
11 | The guide is written entirely in very minimal standard pytorch, using `transformers` and `datasets` for models and data, respectively. No other library is used for distributed code - the distributed stuff is entirely in pytorch.
12 |
13 | 1. [Chapter 1](./01-single-gpu/) - A standard Causal LLM training script that runs on a **single GPU**.
14 | 2. [Chapter 2](./02-distributed-data-parallel/) - Upgrades the training script to support **multiple GPUs and to use DDP**.
15 | 3. [Chapter 3](./03-job-launchers/) - Covers how to **launch training jobs** across clusters with multiple nodes.
16 | 4. [Chapter 4](./04-fully-sharded-data-parallel/) - Upgrades the training script to **use FSDP** instead of DDP for more optimal memory usage.
17 | 5. [Chapter 5](./05-training-llama-405b/) - Upgrades the training script to **train Llama-405b**.
18 | 6. [Chapter 6](./06-tensor-parallel/) - Upgrades our single GPU training script to support **tensor parallelism**.
19 | 7. [Chapter 7](./07-2d-parallel/) - Upgrades our TP training script to use **2d parallelism (FSDP + TP)**.
20 | 8. [Alternative Frameworks](./alternative-frameworks/) - Covers different frameworks that all work with pytorch under the hood.
21 | 9. [Diagnosing Errors](./diagnosing-errors/) - Best practices and how tos for **quickly diagnosing errors** in your cluster.
22 | 10. [Related Topics](./related-topics/) - Topics that you should be aware of when doing distributed training.
23 |
24 |
25 | Questions this guide answers:
26 |
27 | - How do I update a single GPU training/fine-tuning script to run on multiple GPUs or multiple nodes?
28 | - How do I diagnose hanging/errors that happen during training?
29 | - My model/optimizer is too big for a single GPU - how do I train/fine-tune it on my cluster?
30 | - How do I schedule/launch training on a cluster?
31 | - How do I scale my hyperparameters when increasing the number of workers?
32 |
33 | Best practices for logging stdout/stderr and wandb are also included, as logging is vitally important in diagnosing/debugging training runs on a cluster.
34 |
35 | Each of the training scripts is aimed at training a causal language model (e.g. GPT, Llama).
36 |
37 | ## Set up
38 |
39 | ### Clone this repo
40 |
41 | ```bash
42 | git clone https://github.com/LambdaLabsML/distributed-training-guide.git
43 | ```
44 |
45 | ### Virtual Environment
46 |
47 | ```bash
48 | cd distributed-training-guide
49 | python3 -m venv venv
50 | source venv/bin/activate
51 | python -m pip install -U pip
52 | pip install -U setuptools wheel
53 | pip install -r requirements.txt
54 | pip install flash-attn --no-build-isolation
55 | ```
56 |
57 | ### wandb
58 |
59 | This tutorial uses `wandb` as an experiment tracker.
60 |
61 | ```bash
62 | wandb login
63 | ```
64 |
65 |
66 | 🦄 Other exciting ML projects at Lambda: ML Times, Text2Video, GPU Benchmark.
67 |
68 |
--------------------------------------------------------------------------------
/alternative-frameworks/deepspeed/README.md:
--------------------------------------------------------------------------------
1 | # DeepSpeed ZeRO
2 |
3 | Install deepspeed: `pip install deepspeed`
4 |
5 |
6 |
7 | ZeRO is actually a collection of stages, each sharding more and more of the training state:
8 |
9 | > ZeRO Stage 1: The optimizer states (e.g., for Adam optimizer, 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition.
10 |
11 | > ZeRO Stage 2: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
12 |
13 | > ZeRO Stage 3: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.
14 |
15 | References:
16 | - [deepspeed docs](https://deepspeed.readthedocs.io/en/latest/zero3.html)
17 | - [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
18 | - [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
19 | - [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
20 |
21 | ## Integrating DeepSpeed into training code
22 |
23 | ### Argument Parsing
24 |
25 | ```diff
26 | @@ -305,11 +302,10 @@ def _get_parser() -> argparse.ArgumentParser:
27 | parser.add_argument("--log-freq", default=100, type=int)
28 | parser.add_argument("--ckpt-freq", default=500, type=int)
29 | + parser.add_argument("--local_rank", type=int, default=None)
30 | + deepspeed.add_config_arguments(parser)
31 | return parser
32 | ```
33 |
34 | ### Initialization
35 |
36 | Two main differences here:
37 | 1. We call `deepspeed.init_distributed` instead of using pytorch's `init_process_group`
38 | 2. We call `deepspeed.initialize` after we've constructed the model **instead** of wrapping the model with DDP.
39 |
40 | **NOTE**: `deepspeed.initialize` will construct the optimizer & lr_scheduler based on the config you pass in
41 |
42 | ```diff
43 | @@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel
44 | from torch import distributed as dist
45 | from torch.distributed.elastic.multiprocessing.errors import record
46 |
47 | +import deepspeed
48 | import numpy
49 | import wandb
50 | import tqdm
51 | @@ -42,10 +43,15 @@ def main():
52 | - dist.init_process_group()
53 | + deepspeed.init_distributed()
54 |
55 | rank = dist.get_rank()
56 | - local_rank = rank % torch.cuda.device_count()
57 | + local_rank = args.local_rank or (rank % torch.cuda.device_count())
58 | world_size = dist.get_world_size()
59 |
60 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
61 |
62 | @@ -73,10 +73,6 @@ def main():
63 | if len(tokenizer) > embedding_size:
64 | model.resize_token_embeddings(len(tokenizer))
65 |
66 | - model = DistributedDataParallel(
67 | - model, device_ids=[local_rank], output_device=local_rank
68 | - )
69 | -
70 | @@ -89,9 +95,11 @@ def main():
71 | )
72 | LOGGER.info(f"{len(dataloader)} batches per epoch")
73 |
74 | - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
75 | - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
76 | - optimizer, T_max=1000, eta_min=args.lr * 1e-2
77 | + model_engine: deepspeed.DeepSpeedEngine
78 | + model_engine, _, _, lr_scheduler = deepspeed.initialize(
79 | + args,
80 | + model=model,
81 | + model_parameters=(p for p in model.parameters() if p.requires_grad),
82 | )
83 | ```
84 |
85 | ### Train Loop
86 |
87 | Here we are just replacing our PyTorch calls with DeepSpeed calls. Note that we no longer have direct access to the optimizer/lr_scheduler, since DeepSpeed manages them.
88 |
89 | ```diff
90 | with timers["forward"]:
91 | - outputs = model(**batch)
92 | + outputs = model_engine(**batch)
93 |
94 | with timers["backward"]:
95 | - optimizer.zero_grad(set_to_none=True)
96 | - outputs.loss.backward()
97 | + model_engine.backward(outputs.loss)
98 |
99 | with timers["update"]:
100 | - optimizer.step()
101 | - lr_scheduler.step()
102 | + model_engine.step()
103 |
104 | state["global_step"] += 1
105 | state["epoch_step"] += 1
106 | ```
107 |
108 | ### Checkpoints
109 |
110 | Loading becomes:
111 |
112 | ```diff
113 | resumed = False
114 | - if (exp_dir / "state.json").exists():
115 | - model.load_state_dict(_load_to_device(exp_dir / "model.pt"))
116 | - optimizer.load_state_dict(_load_to_device(exp_dir / "optimizer.pt"))
117 | - lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt"))
118 | - with open(exp_dir / "state.json") as fp:
119 | - state = json.load(fp)
120 | - resumed = True
121 | + if (exp_dir / "pytorch_model.bin").exists():
122 | + load_path, state = model_engine.load_checkpoint(exp_dir)
123 | + resumed = load_path is not None
124 | ```
125 |
126 | Saving becomes: (**NOTE**: saving must be done on ALL ranks instead of just rank 0 - because of sharding)
127 |
128 | ```diff
129 | if state["global_step"] % args.ckpt_freq == 0:
130 | - if rank == 0:
131 | - torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
132 | - torch.save(model.state_dict(), exp_dir / "model.pt")
133 | - torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
134 | - with open(exp_dir / "state.json", "w") as fp:
135 | - json.dump(state, fp)
136 | + model_engine.save_checkpoint(exp_dir, client_state=state)
137 | dist.barrier()
138 | ```
139 |
140 | ## Configuration
141 |
142 | ```json
143 | {
144 | "train_micro_batch_size_per_gpu": 64,
145 | "optimizer": {
146 | "type": "Adam",
147 | "params": {
148 | "lr": 3e-5
149 | }
150 | },
151 | "scheduler": {
152 | "type": "WarmupCosineLR",
153 | "params": {
154 | "total_num_steps": 1000,
155 | "warmup_num_steps": 0,
156 | "cos_min_ratio": 1e-2
157 | }
158 | },
159 | "bf16": {
160 | "enabled": true
161 | },
162 | "zero_optimization": {
163 | "stage": 3,
164 | "offload_param": false,
165 | "offload_optimizer": false
166 | }
167 | }
168 | ```
169 |
170 | ## Command
171 |
172 | ```bash
173 | cd distributed-training-guide/alternative-frameworks/deepspeed
174 | export TORCHELASTIC_ERROR_FILE=../error.json
175 | export OMP_NUM_THREADS=1
176 | export HF_HOME=../.cache
177 | deepspeed \
178 | --enable_each_rank_log ../logs \
179 | train_llm.py \
180 | --experiment-name deepspeed-multi-node-$(date +%Y-%m-%dT%H-%M-%S) \
181 | --dataset-name tatsu-lab/alpaca \
182 | --model-name openai-community/gpt2 \
183 | --deepspeed_config ds_config.json
184 | ```
185 |
--------------------------------------------------------------------------------
/alternative-frameworks/deepspeed/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 1,
3 | "optimizer": {
4 | "type": "AdamW",
5 | "params": {
6 | "lr": 3e-5
7 | }
8 | },
9 | "scheduler": {
10 | "type": "WarmupCosineLR",
11 | "params": {
12 | "total_num_steps": 1000,
13 | "warmup_num_steps": 0,
14 | "cos_min_ratio": 1e-2
15 | }
16 | },
17 | "bf16": {
18 | "enabled": true
19 | },
20 | "zero_optimization": {
21 | "stage": 3,
22 | "offload_param": false,
23 | "offload_optimizer": false
24 | }
25 | }
--------------------------------------------------------------------------------
/alternative-frameworks/deepspeed/train_llm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from contextlib import contextmanager
3 | from itertools import chain
4 | import multiprocessing
5 | import os
6 | import time
7 | from pathlib import Path
8 | import logging
9 |
10 | import torch
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.distributed import DistributedSampler
13 | from torch import distributed as dist
14 | from torch.distributed.elastic.multiprocessing.errors import record
15 |
16 | import deepspeed
17 | import wandb
18 | import tqdm
19 | import datasets
20 | from transformers import (
21 | AutoConfig,
22 | AutoModelForCausalLM,
23 | AutoTokenizer,
24 | default_data_collator,
25 | )
26 |
27 | LOGGER = logging.getLogger(__name__)
28 |
29 |
30 | @record
31 | def main():
32 | parser = _get_parser()
33 | args = parser.parse_args()
34 |
35 | dist.init_process_group()
36 |
37 | rank = dist.get_rank()
38 | local_rank = args.local_rank or (rank % torch.cuda.device_count())
39 | world_size = dist.get_world_size()
40 |
41 | logging.basicConfig(
42 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s",
43 | level=logging.INFO,
44 | )
45 |
46 | LOGGER.info(os.environ)
47 | LOGGER.info(args)
48 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
49 |
50 | device = torch.device(f"cuda:{local_rank}")
51 | dtype = torch.bfloat16
52 | torch.cuda.set_device(device)
53 |
54 | torch.manual_seed(args.seed)
55 |
56 | with rank0_first():
57 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False)
58 | with deepspeed.zero.Init(remote_device="cpu", pin_memory=True):
59 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
60 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters")
61 |
62 | # NOTE: since this can download data, make sure to do the main process first
63 | # NOTE: This assumes that the data is on a **shared** network drive, accessible to all processes
64 | with rank0_first():
65 | train_data = _load_and_preprocess_data(args, config)
66 | LOGGER.info(f"{len(train_data)} training samples")
67 |
68 | model_engine: deepspeed.DeepSpeedEngine
69 | model_engine, _, _, lr_scheduler = deepspeed.initialize(
70 | args,
71 | model=model,
72 | model_parameters=(p for p in model.parameters() if p.requires_grad),
73 | )
74 |
75 | dataloader = DataLoader(
76 | train_data,
77 | batch_size=model_engine.train_micro_batch_size_per_gpu(),
78 | collate_fn=default_data_collator,
79 | # NOTE: this sampler will split dataset evenly across workers
80 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True),
81 | )
82 | LOGGER.info(f"{len(dataloader)} batches per epoch")
83 |
84 | exp_dir: Path = Path(args.save_dir) / args.experiment_name
85 |
86 | # attempt resume
87 | state = {
88 | "epoch": 0,
89 | "global_step": 0,
90 | "epoch_step": 0,
91 | "running_loss": 0,
92 | }
93 | resumed = False
94 | if (exp_dir / "pytorch_model.bin").exists():
95 | load_path, state = model_engine.load_checkpoint(exp_dir)
96 | resumed = load_path is not None
97 | LOGGER.info(f"Resumed={resumed} | {state}")
98 | dist.barrier()
99 |
100 | if (exp_dir.is_mount() and rank == 0) or (
101 | not exp_dir.is_mount() and local_rank == 0
102 | ):
103 | LOGGER.info(f"Creating experiment root directory")
104 | exp_dir.mkdir(parents=True, exist_ok=True)
105 | dist.barrier()
106 |
107 | (exp_dir / f"rank-{rank}").mkdir(parents=True, exist_ok=True)
108 | LOGGER.info(f"Worker saving to {exp_dir / f'rank-{rank}'}")
109 |
110 | if rank == 0:
111 | wandb.init(
112 | project="distributed-training-guide",
113 | dir=exp_dir,
114 | name=args.experiment_name,
115 | id=args.experiment_name,
116 | resume="must" if resumed else None,
117 | save_code=True,
118 | config={
119 | "args": vars(args),
120 | "training_data_size": len(train_data),
121 | "num_batches": len(dataloader),
122 | "world_size": world_size,
123 | },
124 | )
125 |
126 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
127 |
128 | for state["epoch"] in range(state["epoch"], args.num_epochs):
129 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}")
130 |
131 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=rank > 0)
132 | if state["epoch_step"] > 0:
133 | progress_bar.update(state["epoch_step"])
134 |
135 | dataloader.sampler.set_epoch(state["epoch"])
136 | batches = iter(dataloader)
137 |
138 | for i_step in range(len(dataloader)):
139 | with timers["data"], torch.no_grad():
140 | batch = next(batches)
141 | batch = {k: v.to(device=device) for k, v in batch.items()}
142 |
143 | if i_step < state["epoch_step"]:
144 | # NOTE: for resuming
145 | continue
146 |
147 | with timers["forward"]:
148 | outputs = model_engine(**batch)
149 |
150 | with timers["backward"]:
151 | model_engine.backward(outputs.loss)
152 |
153 | with timers["update"]:
154 | model_engine.step()
155 |
156 | state["global_step"] += 1
157 | state["epoch_step"] += 1
158 | state["running_loss"] += outputs.loss.item()
159 | progress_bar.update(1)
160 |
161 | if state["global_step"] % args.log_freq == 0:
162 | tok_per_step = (
163 | world_size
164 | * model_engine.train_micro_batch_size_per_gpu()
165 | * args.seq_length
166 | )
167 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
168 | info = {
169 | "global_step": state["global_step"],
170 | "lr": lr_scheduler.get_last_lr()[0],
171 | "running_loss": state["running_loss"] / args.log_freq,
172 | "epoch": state["epoch"],
173 | "epoch_progress": state["epoch_step"] / len(dataloader),
174 | "num_batches_remaining": len(dataloader) - i_step,
175 | **get_mem_stats(device),
176 | "tok/s": 1000 * tok_per_step / ms_per_step,
177 | "time/total": ms_per_step,
178 | **{
179 | f"time/{k}": timer.avg_elapsed_ms()
180 | for k, timer in timers.items()
181 | },
182 | }
183 |
184 | LOGGER.info(info)
185 | if rank == 0:
186 | wandb.log(info, step=state["global_step"])
187 |
188 | torch.cuda.reset_peak_memory_stats(device)
189 | state["running_loss"] = 0
190 | for t in timers.values():
191 | t.reset()
192 |
193 | if state["global_step"] % args.ckpt_freq == 0:
194 | LOGGER.info("Saving checkpoint.")
195 | model_engine.save_checkpoint(exp_dir, client_state=state)
196 | dist.barrier()
197 |
198 | state["epoch_step"] = 0
199 |
200 |
201 | def _load_and_preprocess_data(args, config):
202 | """
203 | Function created using code found in
204 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py
205 | """
206 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
207 |
208 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True)
209 |
210 | column_names = data["train"].column_names
211 | text_column_name = "text" if "text" in column_names else column_names[0]
212 |
213 | def tokenize_function(examples):
214 | return tokenizer(examples[text_column_name])
215 |
216 | tokenized_datasets = data.map(
217 | tokenize_function,
218 | batched=True,
219 | remove_columns=column_names,
220 | num_proc=multiprocessing.cpu_count(),
221 | load_from_cache_file=True,
222 | desc="Running tokenizer on dataset",
223 | )
224 |
225 | seq_length = args.seq_length or tokenizer.model_max_length
226 | if seq_length > config.max_position_embeddings:
227 | seq_length = min(1024, config.max_position_embeddings)
228 |
229 |     # Main data processing function that will concatenate all texts from our dataset and generate chunks of seq_length.
230 | def group_texts(examples):
231 | # Concatenate all texts.
232 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
233 | total_length = len(concatenated_examples[list(examples.keys())[0]])
234 |         # We drop the small remainder; if total_length < seq_length we exclude this batch and return an empty dict.
235 |         # We could add padding instead of dropping if the model supported it; customize this part to your needs.
236 | if total_length > seq_length:
237 | total_length = (total_length // seq_length) * seq_length
238 | # Split by chunks of max_len.
239 | result = {
240 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
241 | for k, t in concatenated_examples.items()
242 | }
243 | result["labels"] = result["input_ids"].copy()
244 | return result
245 |
246 | lm_datasets = tokenized_datasets.map(
247 | group_texts,
248 | batched=True,
249 | num_proc=multiprocessing.cpu_count(),
250 | load_from_cache_file=True,
251 | desc=f"Grouping texts in chunks of {seq_length}",
252 | )
253 |
254 | return lm_datasets["train"]
255 |
256 |
257 | def get_mem_stats(device=None):
258 | mem = torch.cuda.memory_stats(device)
259 | props = torch.cuda.get_device_properties(device)
260 | return {
261 | "total_mem_in_gb": 1e-9 * props.total_memory,
262 | "curr_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.current"],
263 | "peak_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.peak"],
264 | "curr_resv_in_gb": 1e-9 * mem["reserved_bytes.all.current"],
265 | "peak_resv_in_gb": 1e-9 * mem["reserved_bytes.all.peak"],
266 | }
267 |
268 |
269 | @contextmanager
270 | def rank0_first():
271 | rank = dist.get_rank()
272 | if rank == 0:
273 | yield
274 | dist.barrier()
275 | if rank > 0:
276 | yield
277 | dist.barrier()
278 |
279 |
280 | class LocalTimer:
281 | def __init__(self, device: torch.device):
282 | if device.type == "cpu":
283 | self.synchronize = lambda: torch.cpu.synchronize(device=device)
284 | elif device.type == "cuda":
285 | self.synchronize = lambda: torch.cuda.synchronize(device=device)
286 | self.measurements = []
287 | self.start_time = None
288 |
289 | def __enter__(self):
290 | self.synchronize()
291 | self.start_time = time.time()
292 | return self
293 |
294 | def __exit__(self, type, value, traceback):
295 | if traceback is None:
296 | self.synchronize()
297 | end_time = time.time()
298 | self.measurements.append(end_time - self.start_time)
299 | self.start_time = None
300 |
301 | def avg_elapsed_ms(self):
302 | return 1000 * (sum(self.measurements) / len(self.measurements))
303 |
304 | def reset(self):
305 | self.measurements = []
306 | self.start_time = None
307 |
308 |
309 | def _get_parser() -> argparse.ArgumentParser:
310 | parser = argparse.ArgumentParser()
311 | parser.add_argument("-e", "--experiment-name", default=None, required=True)
312 | parser.add_argument("-d", "--dataset-name", default=None, required=True)
313 | parser.add_argument("-m", "--model-name", default=None, required=True)
314 | parser.add_argument("--save-dir", default="../outputs")
315 | parser.add_argument("--seed", default=0, type=int)
316 | parser.add_argument("--num-epochs", default=100, type=int)
317 | parser.add_argument("--log-freq", default=100, type=int)
318 | parser.add_argument("--ckpt-freq", default=500, type=int)
319 | parser.add_argument("-s", "--seq-length", default=1024, type=int)
320 | parser.add_argument("--local_rank", type=int, default=None)
321 | deepspeed.add_config_arguments(parser)
322 | return parser
323 |
324 |
325 | if __name__ == "__main__":
326 | main()
327 |
--------------------------------------------------------------------------------
/diagnosing-errors/README.md:
--------------------------------------------------------------------------------
1 | # Diagnosing Errors
2 |
3 | Hanging and deadlocks can be caused by so many things, even your own code! Here are some diagnostic tools that will help you figure out what is going on.
4 |
5 | ## System metrics to watch for to diagnose hanging
6 |
7 | `GPU Power Usage` will be the main one - if the training process is hanging, the power usage will drop to around 10% for all workers:
8 |
9 | ```bash
10 | > nvidia-smi --query-gpu=power.draw,power.limit --format=csv,noheader
11 | 69.75 W, 700.00 W
12 | 75.10 W, 700.00 W
13 | 70.82 W, 700.00 W
14 | 69.29 W, 700.00 W
15 | 69.19 W, 700.00 W
16 | 68.72 W, 700.00 W
17 | 70.80 W, 700.00 W
18 | 70.87 W, 700.00 W
19 | ```
20 |
21 | Using our provided [top-cluster.py](../top-cluster.py) script will output something like this:
22 |
23 | ```bash
24 | > python top-cluster.py <hosts-file>
25 | ===2024-10-02 19:55:02.553039
26 | name util power memory nprocs
27 | cluster 100.0% 99.1% 96.9% 64
28 | node-001 100.0% 99.7% 96.1% 8
29 | node-002 100.0% 97.8% 96.9% 8
30 | node-003 100.0% 99.2% 97.2% 8
31 | node-004 100.0% 99.1% 97.4% 8
32 | node-005 100.0% 98.1% 97.1% 8
33 | node-006 100.0% 99.0% 97.7% 8
34 | node-007 100.0% 99.8% 96.9% 8
35 | node-008 100.0% 100.0% 96.2% 8
36 | ===
37 | ```
38 |
39 | ## Getting a dump of stack traces
40 |
41 | Use [py-spy](https://github.com/benfred/py-spy) to get a dump of stacktraces from all python threads in a running python program. Here's how you get a dump from each worker:
42 |
43 | ```bash
44 | sudo env "PATH=$PATH" py-spy dump --locals --pid <pid>
45 | ```
46 |
47 | ## Benchmarking/profiling
48 |
49 | You can use `py-spy top --pid <pid>` to get a `top`/`htop`-like view of the functions being called in your python process.
50 |
51 | ## Recording errors
52 |
53 | Python has a great built-in library called [faulthandler](https://docs.python.org/3/library/faulthandler.html) for capturing errors that occur in any thread of a python program. This is especially useful when you're using a DataLoader with `num_workers > 0`.
54 |
55 | It turns out pytorch already has a built-in way to use it! You just have to set the `TORCHELASTIC_ERROR_FILE` environment variable (e.g. `TORCHELASTIC_ERROR_FILE=../error.json`) and add the `@record` decorator to your main function.
56 |
57 | ```python
58 | from torch.distributed.elastic.multiprocessing.errors import record
59 |
60 | # NOTE: records errors to $TORCHELASTIC_ERROR_FILE
61 | @record
62 | def main():
63 | ...
64 | ```
65 |
66 | Luckily all the code in this guide has been doing this, and so should you! **Make sure to set `$TORCHELASTIC_ERROR_FILE`!**
67 |
68 | ## Checklist for system problems
69 |
70 | 1. System date/time is the same on every node (clock skew can cause NCCL timeouts)
71 | 2. NVLink has a valid topology: `nvidia-smi topo -m`
72 | 3. NVLink status: `nvidia-smi topo -p2p n` (additionally `w`/`r` in place of `n`)
73 | 4. Open file descriptor limit: `ulimit -aH` (look for the line containing `open files`)
74 | 5. `timeout` in `dist.init_process_group(timeout=...)` is sufficiently large (see the sketch below).
75 |
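For the last item, here is a minimal sketch (the exact timeout value is just an example) of passing a larger timeout when initializing the process group:

```python
import datetime

import torch.distributed as dist

# A generous timeout helps when e.g. rank 0 does long preprocessing before
# the first collective call; tune the value to your workload.
dist.init_process_group(timeout=datetime.timedelta(hours=1))
```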
--------------------------------------------------------------------------------
/related-topics/README.md:
--------------------------------------------------------------------------------
1 | # Related topics
2 |
3 | This directory contains a list of additional topics that are adjacent to everything discussed in prior chapters.
4 |
5 | These chapters don't each contain their own training script; the changes discussed in each are relatively small, and code snippets are provided to make it easy to add the features into your code.
6 |
--------------------------------------------------------------------------------
/related-topics/determinism/README.md:
--------------------------------------------------------------------------------
1 | # Determinism across resumes
2 |
3 | **NOTE: This chapter's code builds off of [chapter 3](../../03-multi-node/)'s code.**
4 |
5 | See pytorch's documentation on reproducibility: https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
6 |
7 | Notably, we also save & restore the RNG states from the various libraries, and we explicitly seed the data loading workers.
8 |
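One practical note not shown in the diff below: when running on CUDA, `torch.use_deterministic_algorithms(True)` also requires a cuBLAS workspace configuration, otherwise some operations will raise an error (see the pytorch reproducibility docs linked above). A minimal sketch:

```python
import os

# Required by torch.use_deterministic_algorithms(True) for cuBLAS operations;
# set it before CUDA work starts, or export it from your launch script.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```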
9 | ## Code Changes
10 |
11 | ```diff
12 | diff --git a/03-multi-node/train_llm.py b/10-determinism/train_llm.py
13 | index 24eacbd..0a3a029 100644
14 | --- a/03-multi-node/train_llm.py
15 | +++ b/10-determinism/train_llm.py
16 | @@ -40,6 +40,7 @@ def main():
17 |
18 | torch.set_num_threads(1)
19 | torch.set_num_interop_threads(1)
20 | + torch.use_deterministic_algorithms(True)
21 |
22 | torch.manual_seed(args.seed)
23 | torch.cuda.manual_seed_all(args.seed)
24 | @@ -84,6 +85,8 @@ def main():
25 | train_data = _load_and_preprocess_data(args, tokenizer, config)
26 | LOGGER.info(f"{len(train_data)} training samples")
27 |
28 | + g = torch.Generator()
29 | + g.manual_seed(args.seed)
30 | dataloader = DataLoader(
31 | train_data,
32 | batch_size=args.batch_size,
33 | @@ -91,6 +94,8 @@ def main():
34 | num_workers=1,
35 | # NOTE: this sampler will split dataset evenly across workers
36 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True),
37 | + worker_init_fn=_seed_worker,
38 | + generator=g,
39 | )
40 | LOGGER.info(f"{len(dataloader)} batches per epoch")
41 |
42 | @@ -116,6 +121,13 @@ def main():
43 | lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt"))
44 | with open(exp_dir / "state.json") as fp:
45 | state = json.load(fp)
46 | + rng_state = torch.load(
47 | + exp_dir / "rng.pt", weights_only=False, map_location="cpu"
48 | + )
49 | + numpy.random.set_state(rng_state["np"])
50 | + random.setstate(rng_state["random"])
51 | + torch.set_rng_state(rng_state["torch"])
52 | + torch.cuda.set_rng_state(rng_state["cuda"][local_rank], device)
53 | resumed = True
54 | LOGGER.info(f"Resumed={resumed} | {state}")
55 |
56 | @@ -208,11 +220,26 @@ def main():
57 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt")
58 | with open(exp_dir / "state.json", "w") as fp:
59 | json.dump(state, fp)
60 | + torch.save(
61 | + {
62 | + "np": numpy.random.get_state(),
63 | + "random": random.getstate(),
64 | + "torch": torch.get_rng_state(),
65 | + "cuda": torch.cuda.get_rng_state_all(),
66 | + },
67 | + exp_dir / "rng.pt",
68 | + )
69 | dist.barrier()
70 |
71 | state["epoch_step"] = 0
72 |
73 |
74 | +def _seed_worker(worker_id):
75 | + worker_seed = torch.initial_seed() % 2**32
76 | + numpy.random.seed(worker_seed)
77 | + random.seed(worker_seed)
78 | +
79 | +
80 | def _load_and_preprocess_data(args, tokenizer, config):
81 | data = datasets.load_dataset(
82 | args.dataset_name, trust_remote_code=True,
83 | ```
--------------------------------------------------------------------------------
/related-topics/effective-batch-size-and-lr/README.md:
--------------------------------------------------------------------------------
1 | # Effective Batch Size and LR
2 |
3 | As you scale up the number of nodes, the effective batch size (the number of items used per model update) increases as well:
4 |
5 | ```
6 | effective_batch_size = batch_size * world_size
7 | ```
8 |
9 | As you may know, increasing the batch size decreases the variance of the gradients your model trains on, meaning your gradients will be much smoother. This directly impacts the dynamics of how your model learns and changes!
10 |
11 | If you want to **exactly match the dynamics of single gpu training** when moving to multi node training, this chapter is aimed at you!
12 |
13 | ## Scaling Rules
14 |
15 | If you want exact training dynamics, you also have to scale the learning rate. How to do so depends on what optimizer you are using; the exact rules are not fully understood, and you can look into the following papers for more information:
16 |
17 | - [Exploring Learning Rate Scaling Rules for Distributed ML Training on Transient Resources](https://anakli.inf.ethz.ch/papers/learning_rate_distribml22.pdf)
18 |
19 | As of writing this, the most common rules that people use to scale learning rate are:
20 |
21 | ### Linear scaling rule
22 |
23 | ```python
24 | lr = args.lr * dist.get_world_size()
25 | ```
26 |
27 | This was first reported in the large minibatch SGD paper referenced below. However, this doesn't quite produce exactly the same training dynamics, and the paper actually used a **factor of the world size**.
28 |
29 | NOTE: **Be careful when using this for optimizers other than SGD**
30 |
31 | References:
32 | - [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/pdf/1706.02677)
33 |
34 | ### Square root scaling rule
35 |
36 | ```python
37 | lr = args.lr * numpy.sqrt(dist.get_world_size())
38 | ```
39 |
40 | This is proposed for use with the Adam optimizer; it keeps the learning rate proportional to the square root of the gradient variance (i.e. the gradient's standard deviation) as the batch size scales up.
41 |
42 | References:
43 | - [One weird trick for parallelizing convolutional neural networks](https://arxiv.org/pdf/1404.5997)
44 | - [Large-Batch Training for LSTM and Beyond](https://arxiv.org/pdf/1901.08256)
45 |
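As a rough illustration (this helper is not part of the chapters' code; the name and `rule` argument are just for this sketch), applying either rule to a single-GPU base learning rate might look like:

```python
import torch.distributed as dist


def scale_lr(base_lr: float, rule: str = "linear") -> float:
    """Scale a single-GPU learning rate for data-parallel training."""
    world_size = dist.get_world_size()
    if rule == "linear":  # commonly used with SGD
        return base_lr * world_size
    if rule == "sqrt":  # commonly used with Adam-style optimizers
        return base_lr * world_size**0.5
    raise ValueError(f"unknown rule: {rule}")


# e.g. after dist.init_process_group():
# optimizer = torch.optim.AdamW(model.parameters(), lr=scale_lr(args.lr, "sqrt"))
```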
--------------------------------------------------------------------------------
/related-topics/elastic-training/README.md:
--------------------------------------------------------------------------------
1 | # Elastic Training
2 |
3 | Elastic training is training where the launcher can restart a subset (or all) of the workers at various points throughout training.
4 |
5 | Contrary to what you might think, usually when 1 worker encounters an error, **ALL workers are restarted** (see https://pytorch.org/docs/stable/elastic/run.html#membership-changes).
6 |
7 | `torchrun` supports this via [elastic launch](https://pytorch.org/docs/stable/elastic/run.html#elastic-min-1-max-4-tolerates-up-to-3-membership-changes-or-failures):
8 |
9 | ```bash
10 | torchrun \
11 |     --nnodes=1:4 \
12 |     --max-restarts=3 \
13 | ...
14 | ```
15 |
16 | which means that torchrun will restart all the workers up to 3 times (and if some of the nodes go offline, it can continue with as few as 1 node).
17 |
18 | Note:
19 | - `rank`, `local_rank`, and `world_size` are not stable across restarts of a worker (see the sketch below).
20 | - Sometimes nodes have issues that can't be fixed just by restarting (like if you have a bug).
21 |
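Because of the first point above, it's safest to re-derive these values each time the worker (re)starts instead of caching them anywhere (the helper name below is just for illustration):

```python
import os

import torch.distributed as dist


def setup_distributed():
    # torchrun re-populates RANK / WORLD_SIZE / LOCAL_RANK on every restart,
    # so read them fresh rather than persisting them in a checkpoint or global.
    dist.init_process_group()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])
    return rank, local_rank, world_size
```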
22 | ## Code Changes
23 |
24 | No code changes are needed to do elastic training for our existing code. Instead it is more informative to play with a toy example where workers randomly crash to give you a sense for how it works.
25 |
26 | ```bash
27 | cd distributed-training-guide/96-elastic-training
28 | torchrun \
29 | --nnodes 1 \
30 | --nproc_per_node 8 \
31 | --max-restarts 3 \
32 | --redirects 3 \
33 | --log-dir ../logs \
34 | toy.py
35 | ```
36 |
37 | This toy script will randomly throw an error from each of the ranks. **No GPU required to try this command!**
38 |
39 | Inspect the log directory after you run this: for each attempt there will be one worker subdirectory containing an `error.json` file. You can also inspect each worker's stdout/stderr.
40 |
--------------------------------------------------------------------------------
/related-topics/elastic-training/toy.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import logging
4 | import os
5 |
6 | from torch import distributed as dist
7 | from torch.distributed.elastic.multiprocessing.errors import record
8 |
9 | LOGGER = logging.getLogger(__name__)
10 | _STATE_PATH = "./toy-state.json"
11 |
12 |
13 | @record
14 | def main():
15 | logging.basicConfig(level=logging.INFO)
16 |
17 | dist.init_process_group()
18 |
19 | rank = dist.get_rank()
20 | local_rank = os.environ["LOCAL_RANK"]
21 | world_size = dist.get_world_size()
22 |
23 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}")
24 |
25 | state = {"num_steps": 0}
26 | if os.path.exists(_STATE_PATH):
27 | with open(_STATE_PATH) as fp:
28 | state = json.load(fp)
29 |
30 | random.seed(rank + world_size * state["num_steps"])
31 |
32 | while True:
33 | value = random.random()
34 | LOGGER.info(f"[{rank=}] step={state['num_steps']} {value=}")
35 | if value < 0.001:
36 | raise ValueError("Encountered fake bad value.")
37 |
38 | state["num_steps"] += 1
39 |
40 | dist.barrier()
41 | if rank == 0:
42 | with open(_STATE_PATH, "w") as fp:
43 | json.dump(state, fp)
44 | dist.barrier()
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/related-topics/gradient-accumulation/README.md:
--------------------------------------------------------------------------------
1 | # Gradient Accumulation
2 |
3 | Gradient accumulation is a way to increase the effective batch sizes of your model updates.
4 |
5 | It is normally applied when your model is so big that you have to use a smaller batch size for the forward/backward pass.
6 |
7 | If on a single GPU you have a batch size of 4, and a gradient accumulation of 2, then your effective batch size is 8.
8 |
9 | However, applying gradient accumulation in the standard way will cause slowdowns in a distributed training setting because of gradient synchronization.
10 |
11 | ## Standard Implementation
12 |
13 | ```python
14 | outputs = model(**batch)
15 | outputs.loss.backward()
16 | if i_step % grad_accum == 0:
17 | optimizer.step()
18 | lr_scheduler.step()
19 | optimizer.zero_grad(set_to_none=True)
20 | ```
21 |
22 | ## DistributedDataParallel Implementation
23 |
24 | In a distributed setting, gradients will be synchronized at multiple points during the backward pass. It turns out we need to delay this synchronization until we do the full optimizer step!
25 |
26 | We can use [torch.nn.parallel.DistributedDataParallel.no_sync](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync) for this:
27 |
28 | ```python
29 | from contextlib import nullcontext
30 | maybe_sync_grads = model.no_sync if i_step % grad_accum != 0 else nullcontext
31 | with maybe_sync_grads():
32 | outputs = model(**batch)
33 | outputs.loss.backward()
34 | if i_step % grad_accum == 0:
35 | optimizer.step()
36 | lr_scheduler.step()
37 | optimizer.zero_grad(set_to_none=True)
38 | ```
39 |
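Putting it together, here is a minimal sketch of a full accumulation loop, assuming `model` is wrapped in DDP and that `dataloader`, `optimizer`, `lr_scheduler`, and `grad_accum` are defined as in the training scripts. Unlike the snippets above, it updates after every `grad_accum` micro-batches (the `i_step + 1` convention) and divides the loss by `grad_accum` so the accumulated gradient is an average rather than a sum:

```python
from contextlib import nullcontext

for i_step, batch in enumerate(dataloader):
    is_update_step = (i_step + 1) % grad_accum == 0
    # Only sync gradients across ranks on steps where we actually update.
    maybe_no_sync = nullcontext if is_update_step else model.no_sync
    with maybe_no_sync():
        outputs = model(**batch)
        (outputs.loss / grad_accum).backward()
    if is_update_step:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
```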
--------------------------------------------------------------------------------
/related-topics/optimizing-data-loading/README.md:
--------------------------------------------------------------------------------
1 | # Optimizing Data Loading
2 |
3 | **NOTE: This chapter's code builds off of [chapter 3](../../03-multi-node/)'s code.**
4 |
5 | An important part of achieving high throughput during distributed training is ensuring that all processes are moving at roughly the same speed. If one process is much faster, it will spend a lot of time waiting for the other processes to catch up. Data loading is actually a hugely important part of this.
6 |
7 | ## Motivating Example
8 |
9 | While writing this guide, I noticed a drop in GPU utilization **across all nodes** when moving from single node to multi node. When training single node, the GPU power draw was at 80%, and when I went to multi node, it dropped to 60% across all nodes.
10 |
11 | It turns out data loading was consistently slower on one node, causing **all nodes** to wait for it.
12 |
13 | In this guide's case, since data loading is relatively fast, simply updating the number of workers and the prefetch factor fixed it. In more complex examples, other optimizations or preprocessing may be needed.
14 |
15 | ## Loading data in parallel
16 |
17 | Most slowdowns here come down to the data itself:
18 |
19 | 1. If some of the processes read data more slowly, then they will already be behind. This can be due to disk reads being blocked, limits of open file descriptors, etc.
20 | 2. If you have batches of different sizes, then the model forward/backward calls will take different amounts of time.
21 |
22 | Most of these can be handled simply by doing data loading in another process (via `num_workers` argument):
23 |
24 | ```diff
25 | dataloader = DataLoader(
26 | train_data,
27 | batch_size=args.batch_size,
28 | collate_fn=default_data_collator,
29 | + num_workers=1,
30 | + prefetch_factor=2,
31 | # NOTE: this sampler will split dataset evenly across workers
32 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True),
33 | )
34 | ```
35 |
36 | This will cause the data loading to happen behind the scenes **in parallel to the batch processing**.
37 |
38 | You'll need to tune the `num_workers` and `prefetch_factor` settings based on a number of things:
39 | 1. How big your batch size is
40 | 2. How long a single row from your dataset takes to load/preprocess
41 | 3. How fast your batches take to process
42 |
43 | If you have `num_workers>0`, then you just want the time to fully load a batch to be less than the time to process the batch.
44 |
45 | ## Measuring wait time
46 |
47 | We can measure this phenomenon by adding some explicit `dist.barrier()` calls to our code with timing wrapped around them:
48 |
49 | ```diff --git a/03-multi-node/train_llm.py b/06-data-loading/train_llm.py
50 | index d5cb05c..26cadb8 100644
51 | --- a/03-multi-node/train_llm.py
52 | +++ b/06-data-loading/train_llm.py
53 | @@ -146,7 +148,10 @@ def main():
54 | },
55 | )
56 |
57 | - timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]}
58 | + timers = {
59 | + k: LocalTimer(device)
60 | + for k in ["data", "forward", "backward", "update", "waiting"]
61 | + }
62 |
63 | for state["epoch"] in range(state["epoch"], args.num_epochs):
64 | LOGGER.info(
65 | @@ -168,13 +173,22 @@ def main():
66 | # NOTE: for resuming
67 | continue
68 |
69 | + with timers["waiting"]:
70 | + dist.barrier()
71 | +
72 | with timers["forward"]:
73 | outputs = model(**batch)
74 |
75 | + with timers["waiting"]:
76 | + dist.barrier()
77 | +
78 | with timers["backward"]:
79 | optimizer.zero_grad(set_to_none=True)
80 | outputs.loss.backward()
81 |
82 | + with timers["waiting"]:
83 | + dist.barrier()
84 | +
85 | with timers["update"]:
86 | optimizer.step()
87 | lr_scheduler.step()
88 | ```
89 |
90 |
91 | ## Faster storage
92 |
93 | A very common setup is to have all of your data on networked data storage. While this is convenient for our code, it is not the most efficient for data reading.
94 |
95 | Similar to how cache is faster than RAM, and RAM is faster than disk, node-local storage is much faster than networked storage:
96 |
97 | 1. Cache (Fastest)
98 | 2. RAM
99 | 3. Machine local disk
100 | 4. Networked disk (Slowest)
101 |
102 | Simply copying all of your data to each node individually can improve the speed of data loading, at the cost of more storage.
103 |
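For example, a minimal sketch of staging a dataset from networked storage onto node-local disk before training (the paths here are hypothetical, and you'd typically run this once per node, e.g. guarded by `local_rank == 0`):

```python
import shutil
from pathlib import Path

# Hypothetical paths: copy from the shared filesystem to fast local storage.
src = Path("/shared/datasets/my-dataset")
dst = Path("/tmp/my-dataset")
if not dst.exists():
    shutil.copytree(src, dst)
```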
--------------------------------------------------------------------------------
/related-topics/wandb-configurations/README.md:
--------------------------------------------------------------------------------
1 | # wandb configurations
2 |
3 | There are a bunch of ways to configure wandb during your training runs. What will work best for you depends on how big your cluster is and what you want to track.
4 |
5 | ## rank 0
6 |
7 | This is the standard approach. You will only see system information from the node that has the rank 0 process, and only data from rank 0 will be logged. It is minimal information, and you still get to track the experiment progress.
8 |
9 | ```python
10 | if rank == 0:
11 | wandb.init(
12 | project="distributed-training-guide",
13 | dir=exp_dir,
14 | id=args.experiment_name,
15 | name=args.experiment_name,
16 | resume="must" if resumed else None,
17 | save_code=True,
18 | config=...,
19 | )
20 | ```
21 |
22 | ## local_rank 0 (every node)
23 |
24 | With this approach you can see system information from every node, and the number of wandb runs scales linearly with the number of nodes. This approach uses [wandb grouped runs](https://docs.wandb.ai/guides/runs/grouping/).
25 |
26 | ```python
27 | if local_rank == 0:
28 | wandb.init(
29 | project="distributed-training-guide",
30 | dir=exp_dir / f"rank-{rank}",
31 | group=args.experiment_name,
32 | name=f"rank-{rank}",
33 | id=f"{args.experiment_name}-{rank}",
34 | resume="must" if resumed else None,
35 | save_code=True,
36 | config=...,
37 | )
38 | ```
39 |
40 | If you want the run name to show the node index instead, you can set something like the following, where `local_world_size` is the number of workers per node (torchrun exports this as the `LOCAL_WORLD_SIZE` environment variable):
41 |
42 | ```python
43 | name=f"node-{rank // local_world_size}"
44 | ```
45 |
46 | ## every rank
47 |
48 | [Grouping docs](https://docs.wandb.ai/guides/runs/grouping)
49 |
50 | This configuration is really useful for tracking as much information about your cluster as possible. The downsides are that with a very large cluster you can hit wandb's rate limit, and the wandb graphs become unusable.
51 |
52 | ```python
53 | wandb.init(
54 | project="distributed-training-guide",
55 | dir=exp_dir / f"rank-{rank}",
56 | group=args.experiment_name,
57 | name=f"rank-{rank}",
58 | id=f"{args.experiment_name}-{rank}",
59 | resume="must" if resumed else None,
60 | save_code=True,
61 | config=...,
62 | )
63 | ```
64 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb==0.17.5
2 | torch==2.5.1
3 | tqdm
4 | datasets==3.2.0
5 | transformers==4.48.0
6 |
--------------------------------------------------------------------------------
/top-cluster.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | import time
4 | import datetime
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | "--poll-freq", default=1000, type=int, help="Frequency (in ms) to poll clusters"
9 | )
10 | parser.add_argument("hosts", help="File containing hostnames separated by newlines")
11 | args = parser.parse_args()
12 |
13 | with open(args.hosts) as fp:
14 | hosts = list(filter(None, map(str.strip, fp.readlines())))
15 |
16 | while True:
17 | procs = [
18 | subprocess.Popen(
19 | [
20 | "ssh",
21 | host,
22 | "nvidia-smi",
23 | "--query-gpu=utilization.gpu,power.draw,power.limit,memory.used,memory.total",
24 | "--format=csv,noheader,nounits",
25 | "&&",
26 | "nvidia-smi",
27 | "--query-compute-apps=pid",
28 | "--format=csv,noheader,nounits",
29 | ],
30 | stdout=subprocess.PIPE,
31 | stderr=subprocess.STDOUT,
32 | )
33 | for host in hosts
34 | ]
35 | for proc in procs:
36 | proc.wait()
37 |
38 | outputs = [proc.stdout.read().decode() for proc in procs]
39 |
40 | gpu_stats = {}
41 | node_stats = {
42 | host: dict(util=0, power_usage=0, memory_usage=0, num_gpus=0, num_procs=0)
43 | for host in hosts
44 | }
45 | cluster_stats = dict(util=0, power_usage=0, memory_usage=0, num_gpus=0, num_procs=0)
46 | for host, output in zip(hosts, outputs):
47 | gpu_stats[host] = {}
48 | for gpu, stats in enumerate(output.splitlines()):
49 | if "," not in stats:
50 | node_stats[host]["num_procs"] += 1
51 | cluster_stats["num_procs"] += 1
52 | continue
53 |
54 | util, power_draw, power_limit, memory_used, memory_total = map(
55 | float, stats.split(", ")
56 | )
57 | power_usage = 100 * power_draw / power_limit
58 | memory_usage = 100 * memory_used / memory_total
59 |
60 | gpu_stats[host][gpu] = dict(
61 | util=util, power_usage=power_usage, memory_usage=memory_usage
62 | )
63 | node_stats[host]["util"] += util
64 | node_stats[host]["memory_usage"] += memory_usage
65 | node_stats[host]["power_usage"] += power_usage
66 | node_stats[host]["num_gpus"] += 1
67 | cluster_stats["util"] += util
68 | cluster_stats["memory_usage"] += memory_usage
69 | cluster_stats["power_usage"] += power_usage
70 | cluster_stats["num_gpus"] += 1
71 |
72 | if cluster_stats["num_gpus"] > 0:
73 | cluster_stats["util"] /= cluster_stats["num_gpus"]
74 | cluster_stats["memory_usage"] /= cluster_stats["num_gpus"]
75 | cluster_stats["power_usage"] /= cluster_stats["num_gpus"]
76 | for host in hosts:
77 | if node_stats[host]["num_gpus"] == 0:
78 | continue
79 | node_stats[host]["util"] /= node_stats[host]["num_gpus"]
80 | node_stats[host]["memory_usage"] /= node_stats[host]["num_gpus"]
81 | node_stats[host]["power_usage"] /= node_stats[host]["num_gpus"]
82 |
83 | print(f"==={datetime.datetime.now()}")
84 | print(f"{'name':>10}\t{'util':>10}\t{'power':>10}\t{'memory':>10}\t{'nprocs':>10}")
85 | print(
86 | f"{'cluster':>10}\t{cluster_stats['util']:>9.1f}%\t{cluster_stats['power_usage']:>9.1f}%\t{cluster_stats['memory_usage']:>9.1f}%\t{cluster_stats['num_procs']:>10}"
87 | )
88 | for host, stats in node_stats.items():
89 | print(
90 | f"{host:>10}\t{stats['util']:>9.1f}%\t{stats['power_usage']:>9.1f}%\t{stats['memory_usage']:>9.1f}%\t{stats['num_procs']:>10}"
91 | )
92 | print("===")
93 |
94 | time.sleep(args.poll_freq / 1000.0)
95 |
--------------------------------------------------------------------------------