├── .gitignore ├── 01-single-gpu ├── README.md └── train_llm.py ├── 02-distributed-data-parallel ├── README.md └── train_llm.py ├── 03-job-launchers ├── README.md └── job.sbatch ├── 04-fully-sharded-data-parallel ├── README.md └── train_llm.py ├── 05-training-llama-405b ├── README.md ├── download.py ├── hosts ├── launch.sh └── train_llm.py ├── 06-tensor-parallel ├── README.md └── train_llm.py ├── 07-2d-parallel ├── README.md └── train_llm.py ├── LICENSE ├── README.md ├── alternative-frameworks └── deepspeed │ ├── README.md │ ├── ds_config.json │ └── train_llm.py ├── diagnosing-errors └── README.md ├── related-topics ├── README.md ├── determinism │ └── README.md ├── effective-batch-size-and-lr │ └── README.md ├── elastic-training │ ├── README.md │ └── toy.py ├── gradient-accumulation │ └── README.md ├── optimizing-data-loading │ └── README.md └── wandb-configurations │ └── README.md ├── requirements.txt └── top-cluster.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | .vscode 164 | wandb 165 | outputs 166 | .cache 167 | logs 168 | error.json -------------------------------------------------------------------------------- /01-single-gpu/README.md: -------------------------------------------------------------------------------- 1 | # Single GPU 2 | 3 | This is the "standard" single gpu training script. It doesn't do anything with distributed, and aims to be as simple as possible. 4 | 5 | The rest of this guide uses this code as the basis, so this chapter explains all the different parts of the code and why we do them. 6 | 7 | ## Command 8 | 9 | ```bash 10 | cd distributed-training-guide/01-single-gpu 11 | python train_llm.py \ 12 | --experiment-name gpt2-alpaca-single-gpu-$(date +%Y-%m-%dT%H-%M-%S) \ 13 | --dataset-name tatsu-lab/alpaca \ 14 | --model-name openai-community/gpt2 15 | ``` 16 | 17 | ## Code explanation 18 | 19 | This explanation goes roughly in code order, starting from the top. 20 | 21 | ### Argument parsing 22 | 23 | Our training script is a CLI (command line interface) program. That means you run it from a terminal. We have a variety of arguments we'd like the user (you) to be able to change using the CLI. 
So this is a very standard python way to enable that: 24 | 25 | ```python 26 | def main(): 27 | parser = _get_parser() 28 | args = parser.parse_args() 29 | 30 | 31 | def _get_parser() -> argparse.ArgumentParser: 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--experiment-name", default=None, required=True) 34 | parser.add_argument("--dataset-name", default=None, required=True) 35 | parser.add_argument("--model-name", default=None, required=True) 36 | parser.add_argument("--save-dir", default="../outputs") 37 | parser.add_argument("--seed", default=0, type=int) 38 | parser.add_argument("--num-epochs", default=100, type=int) 39 | parser.add_argument("--lr", default=3e-5, type=float) 40 | parser.add_argument("--batch-size", default=1, type=int) 41 | parser.add_argument("--log-freq", default=100, type=int) 42 | parser.add_argument("--ckpt-freq", default=500, type=int) 43 | return parser 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | ``` 49 | 50 | ### Setting up logging 51 | 52 | For this guide, we just use the built-in `logging` package for python. This will output everything to stdout/stderr, and we use command line tools to redirect this output to files for later. 53 | 54 | ```python 55 | LOGGER = logging.getLogger(__name__) 56 | 57 | logging.basicConfig( 58 | format=f"[%(asctime)s] %(levelname)s:%(message)s", 59 | level=logging.INFO, 60 | ) 61 | 62 | LOGGER.info(os.environ) 63 | LOGGER.info(args) 64 | ``` 65 | 66 | It's useful to be able to see what environment variables & CLI args we are running the program with (especially with multiple nodes involved later), so we log those first. 67 | 68 | ### pytorch setup 69 | 70 | As we are using pytorch, there are a couple of useful things to do before we initialize anything: 71 | 72 | ```python 73 | device = torch.device("cuda") 74 | dtype = torch.bfloat16 75 | torch.manual_seed(args.seed) 76 | ``` 77 | 78 | Here we are saying that the device we will be using for the rest of the script is a GPU (specifically a CUDA device), and that we are going to be training with bfloat16 (aka bf16), which is a 16 bit floating point format (float is 32 bits, and double is 64 bits). 79 | 80 | ### Initializing the model 81 | 82 | We are training a BF16 causal language model (think GPT) using `transformers`: 83 | 84 | ```python 85 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 86 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device) 87 | ``` 88 | 89 | ### Initializing our dataset 90 | 91 | We are using `datasets` to load and preprocess our dataset. The processing code used in this guide was sourced from https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py. 92 | 93 | We encourage readers to check out the `datasets` documentation if they want more information. 94 | 95 | ### Data Loading, LR Schedule, Optimizer 96 | 97 | The next section of code is fairly standard pytorch. We are using a DataLoader to iterate our dataset, the AdamW optimizer, and a Cosine Annealing LR schedule.
98 | 99 | ```python 100 | dataloader = DataLoader( 101 | train_data, 102 | batch_size=args.batch_size, 103 | shuffle=True, 104 | drop_last=True, 105 | collate_fn=default_data_collator, 106 | ) 107 | 108 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 109 | 110 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 111 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 112 | ) 113 | ``` 114 | 115 | Note: The `fused=True` argument to the optimizer results in a fused kernel being used. This is faster in pretty much all cases, so we enable it in all the chapters in this guide!! 116 | 117 | ### Outputs & Resuming 118 | 119 | We save checkpoints into `args.save_dir/args.experiment_name` - `--experiment-name is a **unique** run identifier 120 | 121 | ```python 122 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 123 | ``` 124 | 125 | If `args.save_dir/args.experiment_name/state.json` already exists, we attempt to resume. This means if a checkpoint already exists for our experiment_name, then we interpret this as a resumed run. 126 | 127 | ```python 128 | state = { 129 | "epoch": 0, 130 | "global_step": 0, 131 | "epoch_step": 0, 132 | "running_loss": 0, 133 | } 134 | resumed = False 135 | if (exp_dir / "state.json").exists(): 136 | model.load_state_dict(_load_to_device(exp_dir / "model.pt")) 137 | optimizer.load_state_dict(_load_to_device(exp_dir / "optimizer.pt")) 138 | lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt")) 139 | with open(exp_dir / "state.json") as fp: 140 | state = json.load(fp) 141 | resumed = True 142 | ``` 143 | 144 | ### Experiment tracking with Weights & Biases (wandb) 145 | 146 | We resume the run in [wandb](https://wandb.ai/) if we loaded a checkpoint (& also ensure that our unique experiment ID is used for the wandb run id). 147 | 148 | We include a couple of useful initialization flags here as well, so wandb will save our code, and include some hyperparameters we specified on the CLI. 149 | 150 | When we resume a run, we tell wandb that we "must" initialize in resume mode. 151 | 152 | ```python 153 | wandb.init( 154 | project="distributed-training-guide", 155 | dir=exp_dir, 156 | name=args.experiment_name, 157 | id=args.experiment_name, 158 | resume="must" if resumed else None, 159 | save_code=True, 160 | config={ 161 | "args": vars(args), 162 | "training_data_size": len(train_data), 163 | "num_batches": len(dataloader), 164 | }, 165 | ) 166 | ``` 167 | 168 | ### Iterating our batches 169 | 170 | We do this in a non-standard way so we can time various parts of the training loop. Normally, we wouldn't be able to time the actual construction of the batch, but by manually pulling the next batch using `next()`, we can time it: 171 | 172 | ```python 173 | batches = iter(dataloader) 174 | 175 | for i_step in range(len(dataloader)): 176 | # Here we measure the time it takes to generate a batch and move it to the GPU 177 | with timers["data"], torch.no_grad(): 178 | batch = next(batches) 179 | batch = {k: v.to(device=device) for k, v in batch.items()} 180 | ``` 181 | 182 | ### Forward/backward/update 183 | 184 | This is standard pytorch code, with the addition of timing so we can benchmark: 185 | 186 | ```python 187 | with timers["forward"]: 188 | outputs = model(**batch) 189 | 190 | with timers["backward"]: 191 | # NOTE: set_to_none=True will de-allocate the gradients, saving us some memory. 
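    # Gradients accumulate across backward() calls by default, so we clear them once per step before computing new ones.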
192 | optimizer.zero_grad(set_to_none=True) 193 | outputs.loss.backward() 194 | 195 | with timers["update"]: 196 | optimizer.step() 197 | lr_scheduler.step() 198 | ``` 199 | 200 | ### Logging to wandb (& stdout) 201 | 202 | The next blocks of code involve logging various tidbits about how our training is going: 203 | 204 | We do this based on the `--log-freq` argument, e.g. if we do `--log-freq 100` we will log this data every 100 steps. 205 | 206 | Note that we both log to our LOGGER, and also wandb. 207 | 208 | ```python 209 | if state["global_step"] % args.log_freq == 0: 210 | info = { 211 | "global_step": state["global_step"], 212 | "lr": lr_scheduler.get_last_lr()[0], 213 | "running_loss": state["running_loss"] / args.log_freq, 214 | "epoch": state["epoch"], 215 | "epoch_progress": state["epoch_step"] / len(dataloader), 216 | "num_batches_remaining": len(dataloader) - i_step, 217 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()), 218 | **{ 219 | f"time/{k}": timer.avg_elapsed_ms() 220 | for k, timer in timers.items() 221 | }, 222 | } 223 | 224 | LOGGER.info(info) 225 | wandb.log(info, step=state["global_step"]) 226 | 227 | state["running_loss"] = 0 228 | for t in timers.values(): 229 | t.reset() 230 | ``` 231 | 232 | ### Checkpoints 233 | 234 | The final block of code is our checkpointing logic, here just using `torch.save`. 235 | 236 | Note that we are saving the optimizer and LR scheduler in addition to the model! 237 | 238 | ```python 239 | if state["global_step"] % args.ckpt_freq == 0: 240 | LOGGER.info("Saving checkpoint.") 241 | torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt") 242 | torch.save(model.state_dict(), exp_dir / "model.pt") 243 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 244 | with open(exp_dir / "state.json", "w") as fp: 245 | json.dump(state, fp) 246 | ``` 247 | -------------------------------------------------------------------------------- /01-single-gpu/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import chain 3 | import json 4 | import multiprocessing 5 | import os 6 | import time 7 | from pathlib import Path 8 | import logging 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | import wandb 13 | import tqdm 14 | import datasets 15 | from transformers import ( 16 | AutoConfig, 17 | AutoModelForCausalLM, 18 | AutoTokenizer, 19 | default_data_collator, 20 | ) 21 | 22 | LOGGER = logging.getLogger(__name__) 23 | 24 | 25 | def main(): 26 | parser = _get_parser() 27 | args = parser.parse_args() 28 | 29 | # Will be modifying this in future version to include rank information 30 | logging.basicConfig( 31 | format=f"[%(asctime)s] %(levelname)s:%(message)s", 32 | level=logging.INFO, 33 | ) 34 | 35 | # Helpful to log this information when running on multiple nodes to make sure all nodes have the same environment. 36 | LOGGER.info(os.environ) 37 | LOGGER.info(args) 38 | 39 | # This guide assumes CUDA device is available, and does all training in bf16 40 | device = torch.device("cuda") 41 | dtype = torch.bfloat16 42 | 43 | # Seed pytorch's RNG. 
See https://pytorch.org/docs/stable/notes/randomness.html 44 | torch.manual_seed(args.seed) 45 | 46 | # Note: Initializing an **untrained** model 47 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 48 | with device: 49 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 50 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 51 | 52 | train_data = _load_and_preprocess_data(args, config) 53 | LOGGER.info(f"{len(train_data)} training samples") 54 | 55 | # Standard pytorch dataset iterator 56 | dataloader = DataLoader( 57 | train_data, 58 | batch_size=args.batch_size, 59 | shuffle=True, 60 | drop_last=True, 61 | collate_fn=default_data_collator, 62 | ) 63 | LOGGER.info(f"{len(dataloader)} batches per epoch") 64 | 65 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 66 | 67 | # NOTE: T_max and eta_min were arbitrarily chosen 68 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 69 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 70 | ) 71 | 72 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 73 | LOGGER.info(f"Experiment saving to {exp_dir}") 74 | 75 | # attempt resume 76 | state = { 77 | "epoch": 0, 78 | "global_step": 0, 79 | "epoch_step": 0, 80 | "running_loss": 0, 81 | } 82 | resumed = False 83 | if (exp_dir / "state.json").exists(): 84 | # NOTE: weights_only is to protect against arbitrary code execution with pickle decoding. 85 | def _load_to_device(p): 86 | return torch.load(p, map_location=device, weights_only=True) 87 | 88 | model.load_state_dict(_load_to_device(exp_dir / "model.pt")) 89 | optimizer.load_state_dict(_load_to_device(exp_dir / "optimizer.pt")) 90 | lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt")) 91 | with open(exp_dir / "state.json") as fp: 92 | state = json.load(fp) 93 | resumed = True 94 | LOGGER.info(f"Resumed={resumed} | {state}") 95 | 96 | LOGGER.info(f"Creating experiment root directory") 97 | exp_dir.mkdir(parents=True, exist_ok=True) 98 | 99 | # Initializing [wandb](https://wandb.ai/) - a very useful experiment tracking library. 100 | wandb.init( 101 | project="distributed-training-guide", 102 | dir=exp_dir, 103 | name=args.experiment_name, 104 | id=args.experiment_name, 105 | resume="must" if resumed else None, 106 | save_code=True, 107 | config={ 108 | "args": vars(args), 109 | "training_data_size": len(train_data), 110 | "num_batches": len(dataloader), 111 | }, 112 | ) 113 | 114 | # will be using to understand breakdown of speed 115 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 116 | 117 | for state["epoch"] in range(state["epoch"], args.num_epochs): 118 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 119 | 120 | progress_bar = tqdm.tqdm(range(len(dataloader))) 121 | if state["epoch_step"] > 0: 122 | progress_bar.update(state["epoch_step"]) 123 | 124 | # NOTE: This is not standard. Normally you can just iterate directly over dataloader. 125 | # We are doing this so we can explicitly measure the time it takes to generate a batch. 126 | batches = iter(dataloader) 127 | 128 | for i_step in range(len(dataloader)): 129 | # Here we measure the time it takes to generate a batch and move it to the GPU 130 | with timers["data"], torch.no_grad(): 131 | batch = next(batches) 132 | batch = {k: v.to(device=device) for k, v in batch.items()} 133 | 134 | # For resuming, this has to come after getting the next batch, so we move through the dataset properly. 
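            # i.e. batches before the resume point are still drawn from the dataloader (keeping the shuffle order consistent); we just skip the forward/backward work for them.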
135 | if i_step < state["epoch_step"]: 136 | # NOTE: for resuming 137 | continue 138 | 139 | with timers["forward"]: 140 | outputs = model(**batch) 141 | 142 | with timers["backward"]: 143 | # NOTE: set_to_none=True will de-allocate the gradients, saving us some memory. 144 | optimizer.zero_grad(set_to_none=True) 145 | outputs.loss.backward() 146 | 147 | with timers["update"]: 148 | optimizer.step() 149 | lr_scheduler.step() 150 | 151 | state["global_step"] += 1 152 | state["epoch_step"] += 1 153 | state["running_loss"] += outputs.loss.item() 154 | progress_bar.update(1) 155 | 156 | if state["global_step"] % args.log_freq == 0: 157 | tok_per_step = args.batch_size * args.seq_length 158 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 159 | info = { 160 | "global_step": state["global_step"], 161 | "lr": lr_scheduler.get_last_lr()[0], 162 | "running_loss": state["running_loss"] / args.log_freq, 163 | "epoch": state["epoch"], 164 | "epoch_progress": state["epoch_step"] / len(dataloader), 165 | "num_batches_remaining": len(dataloader) - i_step, 166 | **get_mem_stats(device), 167 | "tok/s": 1000 * tok_per_step / ms_per_step, 168 | "time/total": ms_per_step, 169 | **{ 170 | f"time/{k}": timer.avg_elapsed_ms() 171 | for k, timer in timers.items() 172 | }, 173 | } 174 | 175 | LOGGER.info(info) 176 | wandb.log(info, step=state["global_step"]) 177 | 178 | torch.cuda.reset_peak_memory_stats(device) 179 | state["running_loss"] = 0 180 | for t in timers.values(): 181 | t.reset() 182 | 183 | if state["global_step"] % args.ckpt_freq == 0: 184 | LOGGER.info("Saving checkpoint.") 185 | torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt") 186 | torch.save(model.state_dict(), exp_dir / "model.pt") 187 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 188 | with open(exp_dir / "state.json", "w") as fp: 189 | json.dump(state, fp) 190 | 191 | state["epoch_step"] = 0 192 | 193 | 194 | def _load_and_preprocess_data(args, config): 195 | """ 196 | Function created using code found in 197 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 198 | """ 199 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 200 | 201 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 202 | 203 | column_names = data["train"].column_names 204 | text_column_name = "text" if "text" in column_names else column_names[0] 205 | 206 | def tokenize_function(examples): 207 | return tokenizer(examples[text_column_name]) 208 | 209 | tokenized_datasets = data.map( 210 | tokenize_function, 211 | batched=True, 212 | remove_columns=column_names, 213 | num_proc=multiprocessing.cpu_count(), 214 | load_from_cache_file=True, 215 | desc="Running tokenizer on dataset", 216 | ) 217 | 218 | seq_length = args.seq_length or tokenizer.model_max_length 219 | if seq_length > config.max_position_embeddings: 220 | seq_length = min(1024, config.max_position_embeddings) 221 | 222 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 223 | def group_texts(examples): 224 | # Concatenate all texts. 225 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 226 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 227 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 
228 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 229 | if total_length > seq_length: 230 | total_length = (total_length // seq_length) * seq_length 231 | # Split by chunks of max_len. 232 | result = { 233 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 234 | for k, t in concatenated_examples.items() 235 | } 236 | result["labels"] = result["input_ids"].copy() 237 | return result 238 | 239 | lm_datasets = tokenized_datasets.map( 240 | group_texts, 241 | batched=True, 242 | num_proc=multiprocessing.cpu_count(), 243 | load_from_cache_file=True, 244 | desc=f"Grouping texts in chunks of {seq_length}", 245 | ) 246 | 247 | return lm_datasets["train"] 248 | 249 | 250 | def get_mem_stats(device=None): 251 | mem = torch.cuda.memory_stats(device) 252 | props = torch.cuda.get_device_properties(device) 253 | return { 254 | "total_gb": 1e-9 * props.total_memory, 255 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"], 256 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"], 257 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"], 258 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"], 259 | } 260 | 261 | 262 | class LocalTimer: 263 | def __init__(self, device: torch.device): 264 | if device.type == "cpu": 265 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 266 | elif device.type == "cuda": 267 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 268 | self.measurements = [] 269 | self.start_time = None 270 | 271 | def __enter__(self): 272 | self.synchronize() 273 | self.start_time = time.time() 274 | return self 275 | 276 | def __exit__(self, type, value, traceback): 277 | if traceback is None: 278 | self.synchronize() 279 | end_time = time.time() 280 | self.measurements.append(end_time - self.start_time) 281 | self.start_time = None 282 | 283 | def avg_elapsed_ms(self): 284 | return 1000 * (sum(self.measurements) / len(self.measurements)) 285 | 286 | def reset(self): 287 | self.measurements = [] 288 | self.start_time = None 289 | 290 | 291 | def _get_parser() -> argparse.ArgumentParser: 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 294 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 295 | parser.add_argument("-m", "--model-name", default=None, required=True) 296 | parser.add_argument("--save-dir", default="../outputs") 297 | parser.add_argument("--seed", default=0, type=int) 298 | parser.add_argument("--num-epochs", default=100, type=int) 299 | parser.add_argument("--lr", default=3e-5, type=float) 300 | parser.add_argument("-b", "--batch-size", default=1, type=int) 301 | parser.add_argument("--log-freq", default=100, type=int) 302 | parser.add_argument("--ckpt-freq", default=500, type=int) 303 | parser.add_argument("-s", "--seq-length", default=1024, type=int) 304 | return parser 305 | 306 | 307 | if __name__ == "__main__": 308 | main() 309 | -------------------------------------------------------------------------------- /02-distributed-data-parallel/README.md: -------------------------------------------------------------------------------- 1 | # Multi GPU on a single node 2 | 3 | **NOTE: This chapter's code builds off of [chapter 1](../01-single-gpu).** 4 | 5 | Single node command: 6 | 7 | ```bash 8 | cd distributed-training-guide/02-distributed-data-parallel 9 | export TORCHELASTIC_ERROR_FILE=../error.json 10 | export OMP_NUM_THREADS=1 11 | 
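# torchrun spawns one copy of train_llm.py per GPU on this node (see --nproc-per-node below)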
torchrun --standalone \ 12 | --nproc-per-node gpu \ 13 | --redirects 3 \ 14 | --log-dir ../logs \ 15 | train_llm.py \ 16 | --experiment-name gpt2-alpaca-ddp-$(date +%Y-%m-%dT%H-%M-%S) \ 17 | --dataset-name tatsu-lab/alpaca \ 18 | --model-name openai-community/gpt2 19 | ``` 20 | 21 | For multi node, see our [chapter on job launchers](../03-job-launchers/). 22 | 23 | Quick jump: 24 | - [How Distributed Training works](#how-distributed-training-works) 25 | - [Using torchrun](#using-torchrun-instead-of-python) 26 | - [Code Changes](#code-changes) 27 | - [Calling dist.init_process_group() and torch.cuda.set_device()](#calling-distinit_process_group-and-torchcudaset_device) 28 | - [Including rank in logging statements](#including-rank-in-logging-statements) 29 | - [DistributedDataParallel (DDP)](#using-distributeddataparallel) 30 | - [DistributedSampler](#using-distributedsampler) 31 | - I/O related guards 32 | - [Downloading model/data in rank 0 first](#downloading-model--data-in-rank-0-first) 33 | - [Interacting with file system on rank 0 only](#only-creating-experiment-directory-on-rank-0) 34 | - [wandb on rank 0 only](#wandb-runs-on-rank-0) 35 | - [Checkpoints from rank 0 only](#save-checkpoint-on-rank-0) 36 | - [Optimizing memory - Zero Redundancy](#optimizing-memory---zero-redundancy-optimizer) 37 | - [How multi node works](#how-multi-node-works) 38 | - [Shared storage - Managing your python virtual environment across nodes](#shared-storage---managing-your-python-virtual-environment-across-nodes) 39 | - [Shared storage - Mangaging your dataset/model checkpoints across nodes](#shared-storage---mangaging-your-datasetmodel-checkpoints-across-nodes) 40 | - [`$HF_HOME` - The downloaded Model/Dataset directory](#hf_home---the-downloaded-modeldataset-directory) 41 | 42 | ## Dictionary 43 | 44 | - `world size`: the total number of participating gpus 45 | - `rank` the global unique id of this worker (from `0` up to and including `world_size - 1`) 46 | - `local_rank` the rank local to this machine (from `0` up to and including `torch.cuda.device_count() - 1`) 47 | 48 | ## How Distributed Training works 49 | 50 | Before we get into the changes required to do distributed training, let's think a little bit. If you've ever done parallel computations before, you know one way to achieve parallelization is to simply split your workload over all your cores. This is really useful if your task is relatively the same for all of the things you want to process. In fact, this is how python's multiprocessing.Pool.map object works. 51 | 52 | Well distributed training with a GPU actually works the same way - we are splitting our workload (which is the batches from our dataset) over multiple GPUs. However we have an additional problem: how do we ensure that the model on all of our GPUs is the same? 53 | 54 | We can actually achieve this in a very clever way. For sake of simplicity let's assume: 55 | 1. Our model and optimizer fully fit on every GPU 56 | 2. We initialize our model the exact same way on all of our GPUs 57 | 3. Our optimizer has the exact same settings on all of our GPUs 58 | 59 | Now let's focus on our training loop. The canonical one in pytorch is: 60 | 61 | ```python 62 | loss = model(**batch) # 1. Forward pass asynchronously 63 | optimizer.zero_grad() # 2. Reset gradients asynchronously 64 | loss.backward() # 3. calculates gradients asynchronously 65 | optimizer.step() # 4. synchronize gradients & update weights 66 | ``` 67 | 68 | The first 3 lines of the above can all be done asychronously. 
`loss.backward()` will compute gradients in each of our training processes. The clever bit is that `optimizer.step()` will synchronize the gradients across all processes before actually updating the model parameters. 69 | 70 | Here is a high level depiction of how a single step of training works when using this data parallel technique: 71 | 72 | image 73 | 74 | So to be explicit: **`optimizer.step()` is a synchronization point across ALL processes**. 75 | 76 | So how does pytorch achieve this? 77 | 78 | 1. [Running N copies of our training script with torchrun.](#using-torchrun-instead-of-python) 79 | 2. [Splitting data across our workers with DistributedSampler](#using-distributedsampler) 80 | 3. [Synchronizing our gradients with DistributedDataParallel](#using-distributeddataparallel) 81 | 82 | ## Using `torchrun` instead of `python` 83 | 84 | When we use `torchrun` to launch a distributed training job, what's happening is that it **launches N separate processes** (where N is the number of GPUs you specify in `--nproc-per-node`), all running your same training script: 85 | 86 | ``` 87 | > torchrun --nproc-per-node 3 train_llm.py ... 88 | Launches subproc `$RANK=0 $WORLD_SIZE=3 train_llm.py ...` 89 | Launches subproc `$RANK=1 $WORLD_SIZE=3 train_llm.py ...` 90 | Launches subproc `$RANK=2 $WORLD_SIZE=3 train_llm.py ...` 91 | ``` 92 | 93 | It will also set up some synchronization between the processes. Then each of the processes runs the same training code and needs to synchronize at various points. Each of the processes gets an id (the `rank`), which tells it what device to use. 94 | 95 | When running on multiple nodes, you need to run torchrun on every machine, but other than that it works exactly the same. See our [job launchers chapter](../03-job-launchers/) for how to do this. 96 | 97 | Here are some of the common CLI arguments to torchrun used throughout this guide: 98 | - `--standalone` is used when only running on a single node. 99 | - `--nnodes` is the number of nodes we are using, in this case 1, but once we go to multiple nodes, this will be > 1. 100 | - `--nproc-per-node` is the number of processes. `gpu` means to use all available GPUs. 101 | - `--redirects 3` redirects stdout & stderr into files 102 | - `--log-dir ../logs` configures the log directory 103 | 104 | #### TORCHELASTIC_ERROR_FILE 105 | 106 | **Very important to include this for debugging!** 107 | 108 | When one of the workers (including a thread from a worker process) has an error, torchrun will save the error to the filepath controlled by this environment variable. 109 | 110 | You also need to add a `@record` annotation (imported `from torch.distributed.elastic.multiprocessing.errors import record`) to your main function: 111 | 112 | ```diff 113 | +@record 114 | def main(): 115 | ``` 116 | 117 | #### OMP_NUM_THREADS 118 | 119 | pytorch by default tries to take advantage of all the cores available when doing computations, even when you are on the GPU. Since we have multiple processes running pytorch, if we didn't set `OMP_NUM_THREADS` to something else, all of them would try to use all available cores. 120 | 121 | You can manually check how many available cores there are and then split them accordingly. E.g. if there were 32 cores on a machine and 8 GPUs, you could set OMP_NUM_THREADS to 4. 122 | 123 | 124 | ## Code Changes 125 | 126 | ### Calling `dist.init_process_group()` and `torch.cuda.set_device()` 127 | 128 | You are required to call both of these before calling other dist APIs.
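Under `torchrun`, `dist.init_process_group()` can be called with no arguments, because the default `env://` rendezvous reads the connection info from environment variables that `torchrun` sets in every process it spawns. A rough sketch of what this looks like from each process's point of view (the values shown are just illustrative):

```python
import os

from torch import distributed as dist

# torchrun exports these (among others) before our script starts, e.g. for the
# 2nd GPU on the 2nd of two 8-GPU nodes:
#   MASTER_ADDR=node0  MASTER_PORT=29500  WORLD_SIZE=16  RANK=9  LOCAL_RANK=1
assert {"MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"} <= os.environ.keys()

dist.init_process_group()  # env:// rendezvous - connects to MASTER_ADDR:MASTER_PORT
```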
129 | 130 | `dist.init_process_group()` will block until `WORLD_SIZE` processes have called it. 131 | 132 | ```diff 133 | +from torch import distributed as dist 134 | 135 | def main(): 136 | parser = _get_parser() 137 | args = parser.parse_args() 138 | 139 | + dist.init_process_group() 140 | + rank = dist.get_rank() 141 | + local_rank = rank % torch.cuda.device_count() 142 | + world_size = dist.get_world_size() 143 | ``` 144 | 145 | Then we can set our device using rank: 146 | 147 | ```diff 148 | -device = torch.device(f"cuda") 149 | +device = torch.device(f"cuda:{local_rank}") 150 | ``` 151 | 152 | And finally add your call to torch.cuda.set_device shortly after that: 153 | 154 | ```diff 155 | device = torch.device(f"cuda:{local_rank}") 156 | dtype = torch.bfloat16 157 | +torch.cuda.set_device(device) 158 | ``` 159 | 160 | If you don't call torch.cuda.set_device, processes may not be using the correct CUDA device. 161 | 162 | ### Including rank in logging statements 163 | 164 | This is a helpful thing to do to handle all the processes outputting to the same file, or even when you're browsing a single log file it's useful to have this on every log statement: 165 | 166 | ```diff 167 | logging.basicConfig( 168 | - format=f"[%(asctime)s] %(levelname)s:%(message)s", 169 | + format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 170 | level=logging.INFO, 171 | ) 172 | ``` 173 | 174 | ### Using DistributedDataParallel 175 | 176 | ```diff 177 | +from torch.nn.parallel import DistributedDataParallel 178 | 179 | with device: 180 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 181 | 182 | +model = DistributedDataParallel(model, device_ids=[local_rank]) 183 | ``` 184 | 185 | Funnily enough you might assume that the DDP module splits batches across processes, but that is not what it does at all! 186 | 187 | This is a model wrapper class that ensures **gradients are synchronized before calling optimizer.step()**. I encourage you to read the documentation for this, it's very informative: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html. 188 | 189 | This class also ensures that *model parameters are equal when you construct it*! 190 | 191 | It achieves all of this through some [very special model hooks](https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/parallel/distributed.py#L939) to sum all the gradients from all the ranks on all the ranks together: 192 | 193 | ```python 194 | # NOTE: internal pytorch code found at https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/parallel/distributed.py#L939 195 | gradient = param.grad / self.process_group.size() 196 | gradient = fcol.all_reduce(gradient, "sum", self.process_group) 197 | ``` 198 | 199 | ### Using DistributedSampler 200 | 201 | In our normal training script we use a `torch.utils.data.DataLoader` to batch our data. One of the arguments to DataLoader is a `sampler`, which basically samples items from the dataset when constructing the batches. You can think of the sampler as doing: 202 | 203 | ```python 204 | def simple_sampler(): 205 | worker_len = len(dataset) 206 | return random.choice(range(worker_len)) 207 | ``` 208 | 209 | The clever thing that the DistributedSampler does is it partitions the length of the dataset across each of our workers. 
You don't even have to partition the actual dataset - it just chooses the integers that it returns from a specific subset of the dataset: 210 | 211 | ```python 212 | def distributed_sampler(): 213 | worker_len = len(dataset) // dist.get_world_size() 214 | return dist.get_rank() * worker_len + random.choice(range(worker_len)) 215 | ``` 216 | 217 | Our code changes are very minimal for this to work! 218 | 219 | ```diff 220 | +from torch.utils.data.distributed import DistributedSampler 221 | 222 | dataloader = DataLoader( 223 | train_data, 224 | batch_size=args.batch_size, 225 | - shuffle=True, 226 | - drop_last=True, 227 | collate_fn=default_data_collator, 228 | + sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 229 | ) 230 | ``` 231 | 232 | You also need to call [DistributedSampler.set_epoch](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler). Here's the quote from the pytorch doc on this: 233 | 234 | ```diff 235 | for state["epoch"] in range(state["epoch"], args.num_epochs): 236 | + dataloader.sampler.set_epoch(state["epoch"]) 237 | batches = iter(dataloader) 238 | ``` 239 | 240 | > In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used. 241 | 242 | ### Downloading model & data in rank 0 first 243 | 244 | This is mainly necessary because loading our data may download/preprocess some data and write to disk. 245 | 246 | If we didn't do rank 0 first, all of our ranks may try to download the data at once, which will slow everything down. 247 | 248 | **NOTE: A good best practice is to have your data already downloaded & preprocessed into a shared network drive** 249 | 250 | We can add a simple context manager to do this: 251 | 252 | ```python 253 | @contextmanager 254 | def rank0_first(): 255 | rank = dist.get_rank() 256 | if rank == 0: 257 | yield 258 | dist.barrier() 259 | if rank > 0: 260 | yield 261 | dist.barrier() 262 | ``` 263 | 264 | Downloading model weights & tokenizer: 265 | 266 | ```diff 267 | +with rank0_first(): 268 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 269 | with device: 270 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 271 | ``` 272 | 273 | Downloading data: 274 | 275 | ```diff 276 | +with rank0_first(): 277 | train_data = _load_and_preprocess_data(args, tokenizer, config) 278 | ``` 279 | 280 | ### Only creating experiment directory on rank 0 281 | 282 | Note the `dist.barrier()` calls before and after we create the directory. **These are very important!** 283 | 284 | Since we check to see if the experiment directory already exists right before creating the experiment directory, we need to ensure that **all processes have checked for its existence**. So the first `dist.barrier()` call ensures that all workers have already checked the existence of that. Then and only then can we create the directory on rank 0. 285 | 286 | ```diff 287 | +if rank == 0: 288 | exp_dir.mkdir(parents=True, exist_ok=True) 289 | +dist.barrier() 290 | ``` 291 | 292 | ### wandb runs on rank 0 293 | 294 | The standard approach for doing distributed wandb is to only invoke wandb on rank 0 process. This is very easy to implement and you don't have to worry about scaling issues as you add more ranks. 
295 | 296 | There are other approaches you can use, like grouped wandb runs, which you can read about in our chapter on [wandb-configurations](../related-topics/wandb-configurations/) for more details. 297 | 298 | ```diff 299 | +if rank == 0: 300 | wandb.init( 301 | project="distributed-training-guide", 302 | dir=exp_dir, 303 | name=args.experiment_name, 304 | id=args.experiment_name, 305 | resume="must" if resumed else None, 306 | save_code=True, 307 | config={ 308 | "args": vars(args), 309 | "training_data_size": len(train_data), 310 | "num_batches": len(dataloader), 311 | "rank": rank, 312 | "world_size": world_size, 313 | }, 314 | ) 315 | ``` 316 | 317 | and 318 | 319 | ```diff 320 | +if rank == 0: 321 | wandb.log(info, step=state["global_step"]) 322 | ``` 323 | 324 | ### Save checkpoint on rank 0 325 | 326 | We only want one of our ranks to save a checkpoint. Otherwise the ranks might write to the same file and corrupt each other. 327 | 328 | ```diff 329 | if state["global_step"] % args.ckpt_freq == 0: 330 | + if rank == 0: 331 | torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt") 332 | torch.save(model.state_dict(), exp_dir / "model.pt") 333 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 334 | with open(exp_dir / "state.json", "w") as fp: 335 | json.dump(state, fp) 336 | + dist.barrier() 337 | ``` 338 | 339 | ## Optimizing memory - Zero Redundancy Optimizer 340 | 341 | DDP stores the entire model and optimizer on every single GPU. This is especially wasteful regarding the optimizer. Thankfully we have [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) which we can easily add to reduce memory usage: 342 | 343 | ```diff 344 | + from torch.distributed.optim import ZeroRedundancyOptimizer 345 | 346 | -optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 347 | +optimizer = ZeroRedundancyOptimizer( 348 | + model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr 349 | +) 350 | ``` 351 | 352 | Unfortunately the code to save a state dict for ZeRO is exorbitantly slow, so we also have to remove saving the optimizer state dicts. 353 | 354 | ## How multi node works 355 | 356 | It actually works in much the same way as the single node multi GPU. Since in the single node setting we have multiple processes, now we are just adding extra processes on different machines. 357 | 358 | The main differences here to consider are: 359 | 1. How to maintain the same environment on every node 360 | 2. How the nodes get in contact with each other (the `rdzv` arguments in the torchrun command) 361 | 3. How each node will access the data 362 | 363 | Error reporting/handling becomes extremely important with more than 1 node. Networking issues are very common, and there are some subtle things that you need to ensure are identical between the machines. 364 | 365 | ## Shared storage - Managing your python virtual environment across nodes 366 | 367 | For this the easiest approach is to create your python virtual environment in a shared network drive that all nodes can access. This way all of your nodes are using the exact same python executable/environment. 368 | 369 | Creating the virtual environment is the same as normal, you just want the directory to be shared. 370 | 371 | ## Shared storage - Mangaging your dataset/model checkpoints across nodes 372 | 373 | Again, the easiest approach here is to keep your data in a shared network drive. 
One thing to note is that shared network drives are slower to read from than node local drives. If you run into slowdowns in data loading, you can copy the data or model into node local storage. 374 | 375 | When using `transformers` or `datasets`, make sure to set the `$HF_HOME` environment variable to control where huggingface downloads both datasets and model weights. 376 | 377 | ## `$HF_HOME` - The downloaded Model/Dataset directory 378 | 379 | Huggingface `transformers` and `datasets` library will download things to `$HF_HOME` by default. `$HF_HOME` defaults to a **node local** value. There are two options for you here: 380 | 381 | 1. Keep `$HF_HOME` as node local and change `with rank0_first()` to be `with local_rank0_first()` 382 | 2. Change `$HF_HOME` to be a shared network drive 383 | 384 | A third option which requires code changes to the code in this repo would be to do this automatically in code: 385 | 386 | ```python 387 | @contextmanager 388 | def rank_ordered(first: bool): 389 | if first: 390 | yield 391 | dist.barrier() 392 | if not first: 393 | yield 394 | dist.barrier() 395 | 396 | # Determine if HF_HOME is node local or shared directory 397 | hf_home_is_networked = os.path.ismount(os.environ["HF_HOME"]) 398 | 399 | if hf_home_is_networked: 400 | # We want rank 0 to go first (download will ONLY occur in rank 0) if directory is shared 401 | should_go_first = rank == 0 402 | else: 403 | # If directory is node local we want a SINGLE process on the node to download the data (local rank 0) 404 | should_go_first = local_rank == 0 405 | 406 | with rank_ordered(should_go_first): 407 | train_data = _load_and_preprocess_data(args, tokenizer, config) 408 | ``` 409 | -------------------------------------------------------------------------------- /02-distributed-data-parallel/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | from itertools import chain 4 | import json 5 | import multiprocessing 6 | import os 7 | import time 8 | from pathlib import Path 9 | import logging 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch.nn.parallel import DistributedDataParallel 15 | from torch import distributed as dist 16 | from torch.distributed.elastic.multiprocessing.errors import record 17 | from torch.distributed.optim import ZeroRedundancyOptimizer 18 | 19 | import wandb 20 | import tqdm 21 | import datasets 22 | from transformers import ( 23 | AutoConfig, 24 | AutoModelForCausalLM, 25 | AutoTokenizer, 26 | default_data_collator, 27 | ) 28 | 29 | LOGGER = logging.getLogger(__name__) 30 | 31 | 32 | @record 33 | def main(): 34 | parser = _get_parser() 35 | args = parser.parse_args() 36 | 37 | dist.init_process_group() 38 | 39 | rank = dist.get_rank() 40 | local_rank = rank % torch.cuda.device_count() 41 | world_size = dist.get_world_size() 42 | 43 | logging.basicConfig( 44 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 45 | level=logging.INFO, 46 | ) 47 | 48 | LOGGER.info(os.environ) 49 | LOGGER.info(args) 50 | LOGGER.info(f"local_rank={local_rank} rank={rank} world_size={world_size}") 51 | 52 | device = torch.device(f"cuda:{local_rank}") 53 | dtype = torch.bfloat16 54 | torch.cuda.set_device(device) 55 | 56 | torch.manual_seed(args.seed) 57 | 58 | # NOTE: assumes $HF_HOME is shared storage 59 | with rank0_first(): 60 | config = 
AutoConfig.from_pretrained(args.model_name, use_cache=False) 61 | with device: 62 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 63 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 64 | 65 | model = DistributedDataParallel(model, device_ids=[local_rank]) 66 | 67 | # NOTE: Assumes that $HF_HOME is shared storage 68 | with rank0_first(): 69 | train_data = _load_and_preprocess_data(args, config) 70 | LOGGER.info(f"{len(train_data)} training samples") 71 | 72 | dataloader = DataLoader( 73 | train_data, 74 | batch_size=args.batch_size, 75 | collate_fn=default_data_collator, 76 | # NOTE: this sampler will split dataset evenly across workers 77 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 78 | ) 79 | LOGGER.info(f"{len(dataloader)} batches per epoch") 80 | 81 | optimizer = ZeroRedundancyOptimizer( 82 | model.parameters(), optimizer_class=torch.optim.AdamW, lr=args.lr 83 | ) 84 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 85 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 86 | ) 87 | 88 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 89 | 90 | # attempt resume 91 | state = { 92 | "epoch": 0, 93 | "global_step": 0, 94 | "epoch_step": 0, 95 | "running_loss": 0, 96 | } 97 | resumed = False 98 | if (exp_dir / "state.json").exists(): 99 | 100 | def _load_to_device(p): 101 | return torch.load(p, map_location=device, weights_only=True) 102 | 103 | model.load_state_dict(_load_to_device(exp_dir / "model.pt")) 104 | optimizer.load_state_dict(_load_to_device(exp_dir / "optimizer.pt")) 105 | lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt")) 106 | with open(exp_dir / "state.json") as fp: 107 | state = json.load(fp) 108 | resumed = True 109 | LOGGER.info(f"Resumed={resumed} | {state}") 110 | dist.barrier() 111 | 112 | if rank == 0: 113 | LOGGER.info(f"Creating experiment root directory") 114 | exp_dir.mkdir(parents=True, exist_ok=True) 115 | dist.barrier() 116 | 117 | if rank == 0: 118 | wandb.init( 119 | project="distributed-training-guide", 120 | dir=exp_dir, 121 | name=args.experiment_name, 122 | id=args.experiment_name, 123 | resume="must" if resumed else None, 124 | save_code=True, 125 | config={ 126 | "args": vars(args), 127 | "training_data_size": len(train_data), 128 | "num_batches": len(dataloader), 129 | "world_size": world_size, 130 | }, 131 | ) 132 | 133 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 134 | 135 | for state["epoch"] in range(state["epoch"], args.num_epochs): 136 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 137 | 138 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=rank > 0) 139 | if state["epoch_step"] > 0: 140 | progress_bar.update(state["epoch_step"]) 141 | 142 | # We need to do this so we shuffle differently on each epoch in a reproducible way. 
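        # (DistributedSampler shuffles with a generator seeded by seed + epoch, so every rank agrees on the permutation.)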
143 | dataloader.sampler.set_epoch(state["epoch"]) 144 | batches = iter(dataloader) 145 | 146 | for i_step in range(len(dataloader)): 147 | with timers["data"], torch.no_grad(): 148 | batch = next(batches) 149 | batch = {k: v.to(device=device) for k, v in batch.items()} 150 | 151 | if i_step < state["epoch_step"]: 152 | # NOTE: for resuming 153 | continue 154 | 155 | with timers["forward"]: 156 | outputs = model(**batch) 157 | 158 | with timers["backward"]: 159 | optimizer.zero_grad(set_to_none=True) 160 | outputs.loss.backward() 161 | 162 | with timers["update"]: 163 | optimizer.step() 164 | lr_scheduler.step() 165 | 166 | state["global_step"] += 1 167 | state["epoch_step"] += 1 168 | state["running_loss"] += outputs.loss.item() 169 | progress_bar.update(1) 170 | 171 | if state["global_step"] % args.log_freq == 0: 172 | tok_per_step = world_size * args.batch_size * args.seq_length 173 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 174 | info = { 175 | "global_step": state["global_step"], 176 | "lr": lr_scheduler.get_last_lr()[0], 177 | "running_loss": state["running_loss"] / args.log_freq, 178 | "epoch": state["epoch"], 179 | "epoch_progress": state["epoch_step"] / len(dataloader), 180 | "num_batches_remaining": len(dataloader) - i_step, 181 | **get_mem_stats(device), 182 | "tok/s": 1000 * tok_per_step / ms_per_step, 183 | "time/total": ms_per_step, 184 | **{ 185 | f"time/{k}": timer.avg_elapsed_ms() 186 | for k, timer in timers.items() 187 | }, 188 | } 189 | 190 | LOGGER.info(info) 191 | if rank == 0: 192 | wandb.log(info, step=state["global_step"]) 193 | 194 | torch.cuda.reset_peak_memory_stats(device) 195 | state["running_loss"] = 0 196 | for t in timers.values(): 197 | t.reset() 198 | 199 | if state["global_step"] % args.ckpt_freq == 0: 200 | if rank == 0: 201 | LOGGER.info("Saving checkpoint.") 202 | torch.save(model.state_dict(), exp_dir / "model.pt") 203 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 204 | with open(exp_dir / "state.json", "w") as fp: 205 | json.dump(state, fp) 206 | dist.barrier() 207 | 208 | state["epoch_step"] = 0 209 | 210 | 211 | def _load_and_preprocess_data(args, config): 212 | """ 213 | Function created using code found in 214 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 215 | """ 216 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 217 | 218 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 219 | 220 | column_names = data["train"].column_names 221 | text_column_name = "text" if "text" in column_names else column_names[0] 222 | 223 | def tokenize_function(examples): 224 | return tokenizer(examples[text_column_name]) 225 | 226 | tokenized_datasets = data.map( 227 | tokenize_function, 228 | batched=True, 229 | remove_columns=column_names, 230 | num_proc=multiprocessing.cpu_count(), 231 | load_from_cache_file=True, 232 | desc="Running tokenizer on dataset", 233 | ) 234 | 235 | seq_length = args.seq_length or tokenizer.model_max_length 236 | if seq_length > config.max_position_embeddings: 237 | seq_length = min(1024, config.max_position_embeddings) 238 | 239 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 240 | def group_texts(examples): 241 | # Concatenate all texts. 
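        # (itertools.chain flattens the per-example token lists in this map batch into one long list per key.)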
242 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 243 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 244 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 245 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 246 | if total_length > seq_length: 247 | total_length = (total_length // seq_length) * seq_length 248 | # Split by chunks of max_len. 249 | result = { 250 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 251 | for k, t in concatenated_examples.items() 252 | } 253 | result["labels"] = result["input_ids"].copy() 254 | return result 255 | 256 | lm_datasets = tokenized_datasets.map( 257 | group_texts, 258 | batched=True, 259 | num_proc=multiprocessing.cpu_count(), 260 | load_from_cache_file=True, 261 | desc=f"Grouping texts in chunks of {seq_length}", 262 | ) 263 | 264 | return lm_datasets["train"] 265 | 266 | 267 | def get_mem_stats(device=None): 268 | mem = torch.cuda.memory_stats(device) 269 | props = torch.cuda.get_device_properties(device) 270 | return { 271 | "total_gb": 1e-9 * props.total_memory, 272 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"], 273 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"], 274 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"], 275 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"], 276 | } 277 | 278 | 279 | @contextmanager 280 | def rank0_first(): 281 | rank = dist.get_rank() 282 | if rank == 0: 283 | yield 284 | dist.barrier() 285 | if rank > 0: 286 | yield 287 | dist.barrier() 288 | 289 | 290 | class LocalTimer: 291 | def __init__(self, device: torch.device): 292 | if device.type == "cpu": 293 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 294 | elif device.type == "cuda": 295 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 296 | self.measurements = [] 297 | self.start_time = None 298 | 299 | def __enter__(self): 300 | self.synchronize() 301 | self.start_time = time.time() 302 | return self 303 | 304 | def __exit__(self, type, value, traceback): 305 | if traceback is None: 306 | self.synchronize() 307 | end_time = time.time() 308 | self.measurements.append(end_time - self.start_time) 309 | self.start_time = None 310 | 311 | def avg_elapsed_ms(self): 312 | return 1000 * (sum(self.measurements) / len(self.measurements)) 313 | 314 | def reset(self): 315 | self.measurements = [] 316 | self.start_time = None 317 | 318 | 319 | def _get_parser() -> argparse.ArgumentParser: 320 | parser = argparse.ArgumentParser() 321 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 322 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 323 | parser.add_argument("-m", "--model-name", default=None, required=True) 324 | parser.add_argument("--save-dir", default="../outputs") 325 | parser.add_argument("--seed", default=0, type=int) 326 | parser.add_argument("--num-epochs", default=100, type=int) 327 | parser.add_argument("--lr", default=3e-5, type=float) 328 | parser.add_argument("-b", "--batch-size", default=1, type=int) 329 | parser.add_argument("--log-freq", default=100, type=int) 330 | parser.add_argument("--ckpt-freq", default=500, type=int) 331 | parser.add_argument("-s", "--seq-length", default=1024, type=int) 332 | return parser 333 | 334 | 335 | if __name__ == "__main__": 336 | main() 337 | 
-------------------------------------------------------------------------------- /03-job-launchers/README.md: -------------------------------------------------------------------------------- 1 | # Job Launchers 2 | 3 | **NOTE: This chapter's code is identical to [chapter 2](../02-distributed-data-parallel/)'s code, so the command uses the training script from chapter 2.** If the job launcher requires code changes to work, the code changes will be called out. 4 | 5 | Since it is quite cumbersome to manually SSH into every node and start a training job, there are various ways to launch distributed training jobs from a single node. 6 | 7 | Quick jump: 8 | - [Bash per node](#bash-commands-xargssshtmux) 9 | - [slurm](#slurm) 10 | - [mpi](#mpirun) 11 | - [deepspeed](#deepspeed) 12 | 13 | ## Bash Commands (xargs/ssh/tmux) 14 | 15 | Since the main thing we need to do is spawn processes on other machines, we can combine a few bash tools together to achieve this. This approach is one of the most lightweight approaches for this, and makes it easy to edit the commands any way you want. While it takes a bit to understand how all the bash commands work together, they are generally applicable to other problems as well. 16 | 17 | Put your list of hostnames/IPs in a file called `hosts`. Each line represents a single node that we will launch `torchrun` on. 18 | 19 | ``` 20 | 21 | 22 | ... 23 | 24 | ``` 25 | 26 | Then we can use ssh to launch `torchrun` on each of the hosts. This command is very similar to our previous bash command, except we are using `torchrun` (`python -m torch.distributed.run`) instead of just invoking our python script. 27 | 28 | ```bash 29 | cd distributed-training-guide/03-job-launchers 30 | JOB_NAME=multi-node-tmux 31 | xargs -a hosts -I {} \ 32 | ssh {} tmux new-session -d -s $JOB_NAME -c $(pwd) \ 33 | -e TORCHELASTIC_ERROR_FILE=../error.json \ 34 | -e OMP_NUM_THREADS=1 \ 35 | -e HF_HOME=../.cache \ 36 | $(which python) -m torch.distributed.run \ 37 | --rdzv-id $JOB_NAME \ 38 | --rdzv-backend c10d \ 39 | --rdzv-endpoint $(head -n 1 hosts):5001 \ 40 | --nnodes $(grep -c '^' hosts) \ 41 | --nproc-per-node gpu \ 42 | --redirects 3 \ 43 | --log-dir ../logs \ 44 | ../02-distributed-data-parallel/train_llm.py \ 45 | --experiment-name $JOB_NAME \ 46 | --dataset-name tatsu-lab/alpaca \ 47 | --model-name openai-community/gpt2 48 | ``` 49 | 50 | Monitoring the output: 51 | ```bash 52 | find ../logs/ -name \*stderr.log | xargs tail -f 53 | ``` 54 | 55 | Killing the job: 56 | ```bash 57 | xargs -a hosts -I{} ssh {} tmux kill-session -t $JOB_NAME 58 | ``` 59 | 60 | Here's how these work: 61 | 62 | 1. `xargs -a hosts -I {}` reads the lines from the `hosts` file, and replaces `{}` in the command following with each line 63 | 2. `ssh {} tmux new-session -d -s $JOB_NAME -c $(pwd)` creates a tmux session on each of the hosts in the hosts file 64 | 1. `-d` means detached, so we can spawn it without blocking 65 | 2. `-s $JOB_NAME` means the sessions will have the name of our job, meaning we can kill them easily. 66 | 3. `-c $(pwd)` means every process will have the working directory that we launch this command from 67 | 4. `-e =` will set up an environment variable for the process we launch using tmux 68 | 69 | From there on we just paste our normal python command, note that we use `$(which python)` to get the absolute path to whatever interpreter executable we are using. 70 | 71 | ## slurm 72 | 73 | slurm is a very popular job scheduling software often used with clusters. 
74 | 75 | Submit the training job using the provided `job.sbatch` script: 76 | 77 | ```bash 78 | cd distributed-training-guide/03-job-launchers 79 | sbatch --nodes 2 --gpus 16 --cpus-per-task 8 job.sbatch 80 | ``` 81 | 82 | By default slurm assigns 1 task per node, which is great for us because we will invoke torchrun once per node. 83 | 84 | The command above requests a total of 16 gpus from 2 nodes total. 85 | 86 | ### The slurm file 87 | 88 | This is mostly identical to our torchrun command that we have been using thus far, just with various settings controlled by slurm. 89 | 90 | The command listed below will be run on each node (since we have specified `--ntasks-per-node=1`). 91 | 92 | ```bash 93 | # SBATCH --ntasks-per-node=1 94 | 95 | source $(pwd)/../venv/bin/activate 96 | 97 | MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 98 | MASTER_PORT=$(expr 5000 + $(echo -n ${SLURM_JOBID} | tail -c 4)) 99 | export TORCHELASTIC_ERROR_FILE=./error-${SLURM_JOBID}-${SLURM_NODEID}.json 100 | export OMP_NUM_THREADS=1 101 | export HF_HOME=../.cache 102 | 103 | printenv 104 | 105 | srun torchrun \ 106 | --rdzv-id "slurm-${SLURM_JOBID}" \ 107 | --rdzv-backend c10d \ 108 | --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 109 | --nnodes ${SLURM_NNODES} \ 110 | --nproc-per-node ${SLURM_GPUS_ON_NODE} \ 111 | --redirects 3 \ 112 | --log-dir ${SLURM_SUBMIT_DIR}/logs \ 113 | ../02-distributed-data-parallel/train_llm.py \ 114 | --experiment-name gpt2-alpaca-slurm-$(date +%Y-%m-%dT%H-%M-%S) \ 115 | --dataset-name tatsu-lab/alpaca \ 116 | --model-name openai-community/gpt2 117 | ``` 118 | 119 | ## mpirun 120 | 121 | There are two main flavors of MPI implementation, OpenMPI and MPICH. Either of them will work and we will use the OpenMPI implementation in this blog. **You need to install OpenMPI**. 122 | 123 | ### Code Changes 124 | 125 | Use MPI environment variables when initializing the process group: 126 | 127 | ```diff 128 | - dist.init_process_group() 129 | + dist.init_process_group( 130 | + rank=int(os.environ["OMPI_COMM_WORLD_RANK"]), 131 | + world_size=int(os.environ["OMPI_COMM_WORLD_SIZE"]), 132 | + ) 133 | ``` 134 | 135 | ### Command 136 | 137 | ```bash 138 | cd distributed-training-guide/03-job-launchers 139 | mpirun \ 140 | -H :,...,: \ 141 | -x MASTER_ADDR= \ 142 | -x MASTER_PORT=5001 \ 143 | -x TORCHELASTIC_ERROR_FILE=../error.json \ 144 | -x OMP_NUM_THREADS=1 \ 145 | -x HF_HOME=../.cache \ 146 | -bind-to none \ 147 | -map-by slot \ 148 | -wdir $(pwd) \ 149 | -output-filename ../logs/mpi-multi-node \ 150 | $(which python) train_llm.py \ 151 | --experiment-name mpi-multi-node \ 152 | --dataset-name tatsu-lab/alpaca \ 153 | --model-name openai-community/gpt2 154 | ``` 155 | 156 | Arguments: 157 | - `-H` specifies the hosts we want to launch on AND the number of processes per host 158 | - `-x` sets up an environment variable in all the launched processes 159 | - `-wdir` sets up the working directory for the launched processes 160 | - `-bind-to none` specifies Open MPI to not bind a training process to a single CPU core (which would hurt performance). 161 | - `-map-by slot` allows you to have a mixture of different NUMA configurations because the default behavior is to bind to the socket. 162 | 163 | Notes: 164 | - We have to specify `MASTER_ADDR` and `MASTER_PORT` for pytorch to know how to talk to each other 165 | - In our code we have to pass the rank and world size based on the `$OMPI_COMM_WORLD_RANK` and `$OMPI_COMM_WORLD_SIZE` environment variables. 
166 | - We use `$(which python)` to get the absolute path of our python interpreter - if you are launch from a head node instead of a worker node, you'll need to change this. 167 | 168 | ## deepspeed 169 | 170 | deepspeed is a distributed training library with many optimizations. We go into some of these optimizations in more detail in later chapters, but here we can just use the launcher included with it. 171 | 172 | **NOTE: you do not have to integrate deepspeed into your training code to use the deepspeed launcher.** 173 | 174 | Install: `pip install deepspeed` 175 | 176 | ### Code Changes 177 | 178 | Add `--local_rank` to cli parsing: 179 | ```diff 180 | parser.add_argument("--log-freq", default=100, type=int) 181 | parser.add_argument("--ckpt-freq", default=500, type=int) 182 | + parser.add_argument("--local_rank", type=int, default=None) 183 | return parser 184 | ``` 185 | 186 | Use it when initializing local_rank: 187 | ```diff 188 | - local_rank = rank % torch.cuda.device_count() 189 | + local_rank = args.local_rank or (rank % torch.cuda.device_count()) 190 | ``` 191 | 192 | ### Command 193 | 194 | ```bash 195 | cd distributed-training-guide/03-job-launchers 196 | export HF_HOME=../.cache 197 | export TORCHELASTIC_ERROR_FILE=../error.json 198 | export OMP_NUM_THREADS=1 199 | deepspeed \ 200 | --include @ \ 201 | --enable_each_rank_log ../logs \ 202 | train_llm.py \ 203 | --experiment-name deepspeed-multi-node \ 204 | --dataset-name tatsu-lab/alpaca \ 205 | --model-name openai-community/gpt2 206 | ``` 207 | -------------------------------------------------------------------------------- /03-job-launchers/job.sbatch: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SBATCH --ntasks-per-node=1 4 | 5 | source $(pwd)/../venv/bin/activate 6 | 7 | MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 8 | MASTER_PORT=$(expr 5000 + $(echo -n ${SLURM_JOBID} | tail -c 4)) 9 | export TORCHELASTIC_ERROR_FILE=./error-${SLURM_JOBID}-${SLURM_NODEID}.json 10 | export OMP_NUM_THREADS=1 11 | export HF_HOME=../.cache 12 | 13 | printenv 14 | 15 | srun torchrun \ 16 | --rdzv-id "slurm-${SLURM_JOBID}" \ 17 | --rdzv-backend c10d \ 18 | --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 19 | --nnodes ${SLURM_NNODES} \ 20 | --nproc-per-node ${SLURM_GPUS_ON_NODE} \ 21 | --redirects 3 \ 22 | --log-dir ${SLURM_SUBMIT_DIR}/logs \ 23 | ../03-multi-node/train_llm.py \ 24 | --experiment-name gpt2-alpaca-slurm-$(date +%Y-%m-%dT%H-%M-%S) \ 25 | --dataset-name tatsu-lab/alpaca \ 26 | --model-name openai-community/gpt2 27 | -------------------------------------------------------------------------------- /04-fully-sharded-data-parallel/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | import functools 4 | from itertools import chain 5 | import json 6 | import multiprocessing 7 | import os 8 | import time 9 | from pathlib import Path 10 | import logging 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch import distributed as dist 16 | from torch.distributed.elastic.multiprocessing.errors import record 17 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 18 | FullyShardedDataParallel, 19 | CPUOffload, 20 | ShardingStrategy, 21 | ) 22 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy 23 | from 
torch.distributed.checkpoint.state_dict import ( 24 | get_state_dict, 25 | set_state_dict, 26 | StateDictOptions, 27 | ) 28 | from torch.distributed.checkpoint import load, save 29 | 30 | 31 | import wandb 32 | import tqdm 33 | import datasets 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForCausalLM, 37 | AutoTokenizer, 38 | default_data_collator, 39 | ) 40 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding 41 | 42 | # fixes for reset_parameters not existing 43 | LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight) 44 | LlamaRotaryEmbedding.reset_parameters = lambda _: None 45 | 46 | LOGGER = logging.getLogger(__name__) 47 | 48 | 49 | @record 50 | def main(): 51 | parser = _get_parser() 52 | args = parser.parse_args() 53 | 54 | dist.init_process_group() 55 | 56 | rank = dist.get_rank() 57 | local_rank = rank % torch.cuda.device_count() 58 | world_size = dist.get_world_size() 59 | 60 | logging.basicConfig( 61 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 62 | level=logging.INFO, 63 | ) 64 | 65 | LOGGER.info(os.environ) 66 | LOGGER.info(args) 67 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 68 | 69 | device = torch.device(f"cuda:{local_rank}") 70 | dtype = torch.bfloat16 71 | torch.cuda.set_device(device) 72 | 73 | torch.manual_seed(args.seed) 74 | 75 | with rank0_first(): 76 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 77 | # NOTE: meta device will not allocate any memory 78 | with torch.device("meta"): 79 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 80 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 81 | 82 | LOGGER.info(f"Before FSDP: {get_mem_stats(device)}") 83 | 84 | wrap_policy = functools.partial( 85 | size_based_auto_wrap_policy, min_num_params=int(args.numel_to_wrap) 86 | ) 87 | model = FullyShardedDataParallel( 88 | model, 89 | device_id=local_rank, 90 | sync_module_states=True, 91 | # NOTE: FULL_SHARD is equivalent to deepspeed ZeRO stage 3 92 | auto_wrap_policy=wrap_policy, 93 | sharding_strategy=ShardingStrategy.FULL_SHARD, 94 | cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"), 95 | ) 96 | 97 | LOGGER.info(f"After FSDP: {get_mem_stats(device)}") 98 | 99 | # NOTE: since this can download data, make sure to do the main process first 100 | # NOTE: This assumes that the data is on a **shared** network drive, accessible to all processes 101 | with rank0_first(): 102 | train_data = _load_and_preprocess_data(args, config) 103 | LOGGER.info(f"{len(train_data)} training samples") 104 | 105 | dataloader = DataLoader( 106 | train_data, 107 | batch_size=args.batch_size, 108 | collate_fn=default_data_collator, 109 | # NOTE: this sampler will split dataset evenly across workers 110 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 111 | ) 112 | LOGGER.info(f"{len(dataloader)} batches per epoch") 113 | 114 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 115 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 116 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 117 | ) 118 | 119 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 120 | 121 | # NOTE: full_state_dict=False means we will be saving sharded checkpoints. 
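# With sharded checkpoints, each rank only saves/loads its own shard of the
# model and optimizer state through torch.distributed.checkpoint (the load()
# and save() calls below), so the full weights are never gathered on one rank.
# cpu_offload=True moves the state dict tensors to CPU before they are written.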
122 | ckpt_opts = StateDictOptions(full_state_dict=False, cpu_offload=True) 123 | 124 | # attempt resume 125 | state = { 126 | "epoch": 0, 127 | "global_step": 0, 128 | "epoch_step": 0, 129 | "running_loss": 0, 130 | } 131 | resumed = False 132 | if (exp_dir / "state.json").exists(): 133 | sharded_model_state, sharded_optimizer_state = get_state_dict( 134 | model, optimizer, options=ckpt_opts 135 | ) 136 | load( 137 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state), 138 | checkpoint_id=exp_dir / "checkpoint", 139 | ) 140 | set_state_dict( 141 | model, 142 | optimizer, 143 | model_state_dict=sharded_model_state, 144 | optim_state_dict=sharded_optimizer_state, 145 | options=ckpt_opts, 146 | ) 147 | lr_scheduler.load_state_dict( 148 | torch.load( 149 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True 150 | ) 151 | ) 152 | with open(exp_dir / "state.json") as fp: 153 | state = json.load(fp) 154 | resumed = True 155 | LOGGER.info(f"Resumed={resumed} | {state}") 156 | dist.barrier() 157 | 158 | if (exp_dir.is_mount() and rank == 0) or ( 159 | not exp_dir.is_mount() and local_rank == 0 160 | ): 161 | LOGGER.info(f"Creating experiment root directory") 162 | exp_dir.mkdir(parents=True, exist_ok=True) 163 | dist.barrier() 164 | 165 | (exp_dir / f"rank-{rank}").mkdir(parents=True, exist_ok=True) 166 | LOGGER.info(f"Worker saving to {exp_dir / f'rank-{rank}'}") 167 | 168 | if rank == 0: 169 | wandb.init( 170 | project="distributed-training-guide", 171 | dir=exp_dir, 172 | name=args.experiment_name, 173 | id=args.experiment_name, 174 | resume="must" if resumed else None, 175 | save_code=True, 176 | config={ 177 | "args": vars(args), 178 | "training_data_size": len(train_data), 179 | "num_batches": len(dataloader), 180 | "world_size": world_size, 181 | }, 182 | ) 183 | 184 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 185 | 186 | for state["epoch"] in range(state["epoch"], args.num_epochs): 187 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 188 | 189 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=rank > 0) 190 | if state["epoch_step"] > 0: 191 | progress_bar.update(state["epoch_step"]) 192 | 193 | dataloader.sampler.set_epoch(state["epoch"]) 194 | batches = iter(dataloader) 195 | 196 | for i_step in range(len(dataloader)): 197 | with timers["data"], torch.no_grad(): 198 | batch = next(batches) 199 | batch = {k: v.to(device=device) for k, v in batch.items()} 200 | 201 | if i_step < state["epoch_step"]: 202 | # NOTE: for resuming 203 | continue 204 | 205 | with timers["forward"]: 206 | outputs = model(**batch) 207 | 208 | with timers["backward"]: 209 | optimizer.zero_grad(set_to_none=True) 210 | outputs.loss.backward() 211 | 212 | with timers["update"]: 213 | optimizer.step() 214 | lr_scheduler.step() 215 | 216 | state["global_step"] += 1 217 | state["epoch_step"] += 1 218 | state["running_loss"] += outputs.loss.item() 219 | progress_bar.update(1) 220 | 221 | if state["global_step"] % args.log_freq == 0: 222 | tok_per_step = world_size * args.batch_size * args.seq_length 223 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 224 | info = { 225 | "global_step": state["global_step"], 226 | "lr": lr_scheduler.get_last_lr()[0], 227 | "running_loss": state["running_loss"] / args.log_freq, 228 | "epoch": state["epoch"], 229 | "epoch_progress": state["epoch_step"] / len(dataloader), 230 | "num_batches_remaining": len(dataloader) - i_step, 231 | **get_mem_stats(device), 232 | 
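# NOTE: "tok/s" below is the aggregate tokens processed per second across all
# ranks, computed from the averaged per-step timings accumulated above.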
"tok/s": 1000 * tok_per_step / ms_per_step, 233 | "time/total": ms_per_step, 234 | **{ 235 | f"time/{k}": timer.avg_elapsed_ms() 236 | for k, timer in timers.items() 237 | }, 238 | } 239 | 240 | LOGGER.info(info) 241 | if rank == 0: 242 | wandb.log(info, step=state["global_step"]) 243 | 244 | torch.cuda.reset_peak_memory_stats(device) 245 | state["running_loss"] = 0 246 | for t in timers.values(): 247 | t.reset() 248 | 249 | if state["global_step"] % args.ckpt_freq == 0: 250 | dist.barrier() 251 | # NOTE: we have to call this on ALL ranks 252 | sharded_model_state, sharded_optimizer_state = get_state_dict( 253 | model, optimizer, options=ckpt_opts 254 | ) 255 | save( 256 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state), 257 | checkpoint_id=exp_dir / "checkpoint", 258 | ) 259 | if rank == 0: 260 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 261 | with open(exp_dir / "state.json", "w") as fp: 262 | json.dump(state, fp) 263 | dist.barrier() 264 | 265 | state["epoch_step"] = 0 266 | 267 | 268 | def _load_and_preprocess_data(args, config): 269 | """ 270 | Function created using code found in 271 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 272 | """ 273 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 274 | 275 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 276 | 277 | column_names = data["train"].column_names 278 | text_column_name = "text" if "text" in column_names else column_names[0] 279 | 280 | def tokenize_function(examples): 281 | return tokenizer(examples[text_column_name]) 282 | 283 | tokenized_datasets = data.map( 284 | tokenize_function, 285 | batched=True, 286 | remove_columns=column_names, 287 | num_proc=multiprocessing.cpu_count(), 288 | load_from_cache_file=True, 289 | desc="Running tokenizer on dataset", 290 | ) 291 | 292 | seq_length = args.seq_length or tokenizer.model_max_length 293 | if seq_length > config.max_position_embeddings: 294 | seq_length = min(1024, config.max_position_embeddings) 295 | 296 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 297 | def group_texts(examples): 298 | # Concatenate all texts. 299 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 300 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 301 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 302 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 303 | if total_length > seq_length: 304 | total_length = (total_length // seq_length) * seq_length 305 | # Split by chunks of max_len. 
306 | result = { 307 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 308 | for k, t in concatenated_examples.items() 309 | } 310 | result["labels"] = result["input_ids"].copy() 311 | return result 312 | 313 | lm_datasets = tokenized_datasets.map( 314 | group_texts, 315 | batched=True, 316 | num_proc=multiprocessing.cpu_count(), 317 | load_from_cache_file=True, 318 | desc=f"Grouping texts in chunks of {seq_length}", 319 | ) 320 | 321 | return lm_datasets["train"] 322 | 323 | 324 | def get_mem_stats(device=None): 325 | mem = torch.cuda.memory_stats(device) 326 | props = torch.cuda.get_device_properties(device) 327 | return { 328 | "total_gb": 1e-9 * props.total_memory, 329 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"], 330 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"], 331 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"], 332 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"], 333 | } 334 | 335 | 336 | @contextmanager 337 | def rank0_first(): 338 | rank = dist.get_rank() 339 | if rank == 0: 340 | yield 341 | dist.barrier() 342 | if rank > 0: 343 | yield 344 | dist.barrier() 345 | 346 | 347 | class LocalTimer: 348 | def __init__(self, device: torch.device): 349 | if device.type == "cpu": 350 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 351 | elif device.type == "cuda": 352 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 353 | self.measurements = [] 354 | self.start_time = None 355 | 356 | def __enter__(self): 357 | self.synchronize() 358 | self.start_time = time.time() 359 | return self 360 | 361 | def __exit__(self, type, value, traceback): 362 | if traceback is None: 363 | self.synchronize() 364 | end_time = time.time() 365 | self.measurements.append(end_time - self.start_time) 366 | self.start_time = None 367 | 368 | def avg_elapsed_ms(self): 369 | return 1000 * (sum(self.measurements) / len(self.measurements)) 370 | 371 | def reset(self): 372 | self.measurements = [] 373 | self.start_time = None 374 | 375 | 376 | def _get_parser() -> argparse.ArgumentParser: 377 | parser = argparse.ArgumentParser() 378 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 379 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 380 | parser.add_argument("-m", "--model-name", default=None, required=True) 381 | parser.add_argument("--save-dir", default="../outputs") 382 | parser.add_argument("--seed", default=0, type=int) 383 | parser.add_argument("--num-epochs", default=100, type=int) 384 | parser.add_argument("--lr", default=3e-5, type=float) 385 | parser.add_argument("-b", "--batch-size", default=1, type=int) 386 | parser.add_argument("--log-freq", default=100, type=int) 387 | parser.add_argument("--ckpt-freq", default=500, type=int) 388 | parser.add_argument("-s", "--seq-length", default=1024, type=int) 389 | parser.add_argument( 390 | "--numel-to-wrap", 391 | default=100_000_000, 392 | type=int, 393 | help="Only applies FSDP to modules with numel > this value.", 394 | ) 395 | parser.add_argument("--cpu-offload", default="off", choices=["on", "off"]) 396 | return parser 397 | 398 | 399 | if __name__ == "__main__": 400 | main() 401 | -------------------------------------------------------------------------------- /05-training-llama-405b/README.md: -------------------------------------------------------------------------------- 1 | # Training a 405B model 2 | 3 | **NOTE: This chapter's code builds off of [chapter 4's FSDP 
code](../04-fully-sharded-data-parallel/).** 4 | 5 | Here we are going to utilize an 8 node cluster (64 H100 GPUs) to train Llama 3.1 405B. **This does not utilize LORA!** We are actually fully training the weights of a 405b model in plain pytorch. 6 | 7 | The next few sections go through various changes we have to make to our FSDP code from chapter 4 to make training a 405b model work. 8 | 9 | Quick Jump: 10 | - [Use flash attention](#use-flash-attention) 11 | - [Download model weights](#download-model-weights) 12 | - [Loading pretrained weights](#loading-pretrained-weights) 13 | - [Sharding Llama 405B](#sharding-llama-405b) 14 | - [Gradient (aka activation) checkpointing](#gradient-aka-activation-checkpointing) 15 | - [CPU Offload \& fused optimizer kernels](#cpu-offload--fused-optimizer-kernels) 16 | - [NOT de-allocating gradients](#not-de-allocating-gradients) 17 | - [Launch command](#launch-command) 18 | - [Monitoring](#monitoring) 19 | - [Run statistics](#run-statistics) 20 | - [Other notes on settings that didn't affect throughput](#other-notes-on-settings-that-didnt-affect-throughput) 21 | 22 | ## Use flash attention 23 | 24 | Flash attention is a fused implementation of scaled dot product attention that heavily minimizes memory usage. The whole goal behind it is to query memory as little as possible, and minimize temporary memory used. 25 | 26 | Check out the [repo](https://github.com/Dao-AILab/flash-attention) and the [paper](https://arxiv.org/abs/2205.14135) for more information. 27 | 28 | This ends up saving us 10s of gb in the forward/backward pass. 29 | 30 | Install: 31 | 32 | ```bash 33 | pip install packaging 34 | pip install ninja 35 | pip install flash-attn --no-build-isolation 36 | ``` 37 | 38 | Use it when we initialize our model: 39 | 40 | ```python 41 | model = AutoModelForCausalLM.from_pretrained( 42 | ... 43 | attn_implementation="flash_attention_2", 44 | ) 45 | ``` 46 | 47 | ## Download model weights 48 | 49 | The actual model weights are huge - it contains 191 separate files which are each about 4GB - totally about 764 GB. 50 | 51 | There are two options for storing these weights here (and they make a difference!): 52 | 53 | 1. A shared network drive that all the nodes can access 54 | 2. Locally on the main rank 0 node 55 | 56 | Node local storage is **much** faster when initializing. For some numbers, while running this script on 8 8xH100 80GB nodes, the shared network drive took 50 minutes to initialize, while the node local storage only took 3 minutes. 57 | 58 | There's a download script in this repo for utility, run this on node 0: 59 | 60 | ```bash 61 | cd distributed-training-guide/05-training-llama-405b 62 | python download.py 63 | ``` 64 | 65 | And run this on the other nodes (to download config & tokenizer): 66 | 67 | ```bash 68 | cd distributed-training-guide/05-training-llama-405b 69 | python download.py --skip-model 70 | ``` 71 | 72 | NOTE: you will likely have to log into your huggingface account using `huggingface-cli login`. 73 | 74 | ## Loading pretrained weights 75 | 76 | When we actual load the weights, it will take some time AND takes a lot of memory to load. Again the full size is about 764 GB, so we need to make sure we have enough RAM to store the weights. 77 | 78 | There's three parts to this: 79 | 80 | 1. Loading the weights into RAM only on `rank==0` 81 | 2. Using the [meta](../04-fully-sharded-data-parallel/README.md#initialization-after-sharding---the-meta-device) device on `rank>0` 82 | 3. 
Using `from_config` instead of `from_pretrained` on `rank>0` so we don't need to download the weights on all the nodes. 83 | 1. Note that if you have the weights on a shared network drive, you can just use `from_pretrained` instead. 84 | 4. Enabling [sync_module_states](../04-fully-sharded-data-parallel/README.md#sync_module_states) in FSDP constructor 85 | 86 | You might think of using the `device_map` feature of `transformers` - e.g. `device_map="auto"` tries to smartly fill up memory. However if you try this approach you'll end up with out of memory errors when FSDP tries to start sending memory to the GPU. 87 | 88 | Here's our code snippet for doing this: 89 | 90 | ```python 91 | if rank == 0: 92 | with torch.device("cpu"): 93 | model = AutoModelForCausalLM.from_pretrained(...) 94 | else: 95 | with torch.device("meta"): 96 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 97 | ``` 98 | 99 | Then later, sync_module_states in [FSDP constructor](../04-fully-sharded-data-parallel/README.md#the-fsdp-constructor) will make sure the weights are broadcasted from rank 0 to the other ranks. 100 | 101 | ## Sharding Llama 405B 102 | 103 | Determining what layers you should shard is complex. If you are using `transformers`, they include a private attribute on classes called [_no_split_modules](https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/llama/modeling_llama.py#L784) that will contain classes that you should not shard anything under them. E.g. for Llama this attribute just contains `LlamaDecoderLayer`. So that is what we will wrap! During testing I also found that sharding the `nn.Embedding` layer at the beginning of the network improved throughput and reduced memory usage. 104 | 105 | We can use the [transformer_auto_wrap_policy()](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/wrap.py#L307C5-L307C33) to target the specific classes for those layers, and pass that as our [auto_wrap_policy in the FSDP constructor](../04-fully-sharded-data-parallel/README.md#what-layers-to-shard---the-auto_wrap_policy): 106 | 107 | ```python 108 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 109 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 110 | 111 | wrap_policy = functools.partial( 112 | transformer_auto_wrap_policy, 113 | transformer_layer_cls={LlamaDecoderLayer, nn.Embedding}, 114 | ) 115 | FSDP(..., auto_wrap_policy=wrap_policy) 116 | ``` 117 | 118 | Please consult [our explanation on the FSDP constructor](../04-fully-sharded-data-parallel/README.md#the-fsdp-constructor) for more info. 119 | 120 | As a reminder - this will cause FSDP to gather all the parameters for each DecoderLayer (which includes Attention, Linear, and various norm modules), and shard them across the world. At the start of forward/backward pass FSDP will issue an all-gather so all the nodes have the full weights in memory, and at the end of the DecoderLayer forward/backward, it will free up the full weights again. 121 | 122 | So where you apply FSDP determines where the all-gather happens! 123 | 124 | ## Gradient (aka activation) checkpointing 125 | 126 | Another piece of reducing memory usage is gradient checkpointing (first introduced in [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174)). Normally when you do the forward pass, you have to keep the input & output in memory until you run the backward pass. This takes up a lot of memory to keep these intermediate tensors around. 
With gradient checkpointing, we actually **re-run** the forward pass during backwards to regenerate the output. So we are doing more compute but saving a lot of memory. 127 | 128 | The method we are using is kind of a hidden method in pytorch, but this is actually exactly what [accelerate uses under the hood](https://github.com/huggingface/accelerate/blob/v0.34.2/src/accelerate/accelerator.py#L1492) so rest assured that it is a "standard" way of doing it: 129 | 130 | This piece of code has to go **after** the FSDP constructor!!! I'm not exactly sure of the reason, but it doesn't work before the FSDP initialization. 131 | 132 | ```python 133 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 134 | apply_activation_checkpointing, 135 | checkpoint_wrapper, 136 | ) 137 | 138 | model = FSDP(...) 139 | 140 | apply_activation_checkpointing( 141 | model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy 142 | ) 143 | ``` 144 | 145 | ## CPU Offload & fused optimizer kernels 146 | 147 | Since the model is so large, we pretty much have to enable [CPU offloading](../04-fully-sharded-data-parallel/README.md#cpu-offload) with FSDP. **When using CPUOffload feature of FSDP, the optimizer entirely runs on the CPU**. This is because there is significant cost to transfer data to and from the GPU when doing `optimizer.step()`. At the time of this being written there are open issues on how to overlap the `optimizer.step()` with the next `forward()` call. 148 | 149 | By default the optimizers will use non-fused kernel when running on the CPU which will generate a lot of intermediate tensors. By explicitly using the fused kernel we get a lot of speedup, which is especially important since we are running that step on the CPU: 150 | 151 | ```python 152 | torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 153 | ``` 154 | 155 | If you want to peek through the pytorch code: 156 | 1. [_single_tensor_adamw()](https://github.com/pytorch/pytorch/blob/v2.4.1/torch/optim/adamw.py#L322) is the default implementation used 157 | 2. [_fused_adamw()](https://github.com/pytorch/pytorch/blob/v2.4.1/torch/optim/adamw.py#L612) is the fused implementation 158 | 159 | ## NOT de-allocating gradients 160 | 161 | You may have seen this `set_to_none` argument in [optimizer.zero_grad()](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html). According to the docs: 162 | 163 | > This will in general have lower memory footprint, and can modestly improve performance. 164 | 165 | Basically `set_to_none=True` will **deallocate the gradients** after they are used. In most GPU cases where we want to save a bit of memory, it is a good thing to de-allocate. However in our case we are using CPU offload, which means all of our gradients are already on the CPU! Since we aren't taking up GPU memory, that means we just have to pay for allocating & de-allocating a lot if we do set to none. So if you set `set_to_none=False` you should actually see a slight speed up for our case! 166 | 167 | ```python 168 | optimizer.zero_grad(set_to_none=args.cpu_offload == "off") 169 | ``` 170 | 171 | ## Launch command 172 | 173 | That's pretty much all the changes you need from our base [FSDP code](../04-fully-sharded-data-parallel/). Now let's launch! 
174 | 175 | We provide a customized [launch.sh](./launch.sh) script here based on the bash command for spawning torchrun on all available nodes: 176 | 177 | ```bash 178 | cd distributed-training-guide/05-training-llama-405b 179 | bash launch.sh # NOTE: this is non blocking 180 | ``` 181 | 182 | Also note that this launch.sh specifies `HF_HOME` as an environment variable in the tmux session, so if you've not used the default value of `/home/ubuntu/.cache/huggingface`, please update the script! 183 | 184 | You can change the hostnames in the [hosts](./hosts) file in this directory. 185 | 186 | ## Monitoring 187 | 188 | We are using torchrun in our [launch.sh](./launch.sh) script, so we will get an output directory per node with a bunch of sub directories with our log files in them. It's a bit of a pain to manually monitor these, so here's a bash command for tailing all of them at once: 189 | 190 | ```bash 191 | cd distributed-training-guide/05-training-llama-405b 192 | find ../logs/ -name \*stderr.log | xargs tail -f 193 | ``` 194 | 195 | Additionally, we have a top like utility script for monitoring the entire cluster at the top level of this directory: 196 | 197 | ```bash 198 | cd distributed-training-guide/05-training-llama-405b 199 | python ../top-cluster.py hosts 200 | ``` 201 | 202 | If you notice any of the nprocs go down or the power usage go down then you know that an error has occurred! 203 | 204 | To kill all the processes on all the nodes you can just kill the tmux sessions: 205 | 206 | ```bash 207 | xargs -a hosts -I{} ssh {} tmux kill-session -t torchrun-llama-405b 208 | ``` 209 | 210 | ## Run statistics 211 | 212 | Training with `--seq-length 4096` and `--batch-size 1` on 64 H100 gpus (8 separate nodes) has the following stats: 213 | 214 | - ~30s per iteration (data/forward/backward/update). Breakdown is 215 | - data: ~2ms 216 | - forward: ~7s 217 | - backward: ~19s 218 | - update: ~4s 219 | - Peak Memory Allocated: 52.9GB 220 | - Peak Memory Reserved: 77.9GB 221 | 222 | Noting that reserved memory has to do with pytorch allocation caching. 223 | 224 | ## Other notes on settings that didn't affect throughput 225 | 226 | - Allowing tf32 had no impact on throughput (`torch.backends.cudnn.allow_tf32` and `torch.backends.cuda.matmul.allow_tf32`) 227 | - Enabling benchmarking had no impact on throughput (`torch.backends.cudnn.benchmark = True`) 228 | - Using CuDNN sdpa was slower (`attn_implementation="sdpa"` and `torch.backends.cuda.enable_cudnn_sdp(True)`) 229 | - torch.compile had no impact (`use_orig_params=True` and `torch.compile` after FSDP constructor) 230 | - Very minimal testing of NCCL environment variables either made things worse or had no impact (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) 231 | - `PYTORCH_NO_CUDA_MEMORY_CACHING=1` made enough memory available that `--batch-size 2` or higher sequence lengths were possible, but it was much much slower. 232 | - It's possible that some well placed calls to `torch.cuda.empty_cache()` could achieve this without the throughput loss. 233 | - Only `FULL_SHARD` works. Others fail silently. 
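For a rough sense of what the run statistics above translate to in throughput, here is a back-of-the-envelope calculation using the same `tok/s` formula as the training script (approximate, since the ~30s per iteration figure is itself an average):

```python
# Approximate cluster-wide throughput implied by the run statistics above.
world_size = 64          # 8 nodes x 8 H100s
batch_size = 1
seq_length = 4096
seconds_per_step = 30    # data + forward + backward + update

tokens_per_step = world_size * batch_size * seq_length   # 262,144 tokens
print(f"~{tokens_per_step / seconds_per_step:.0f} tok/s")  # roughly 8,700 tok/s
```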
234 | -------------------------------------------------------------------------------- /05-training-llama-405b/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import transformers 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--skip-model", default=False, action="store_true") 8 | args = parser.parse_args() 9 | 10 | os.environ["HF_HOME"] = "/home/ubuntu/.cache/huggingface" 11 | 12 | model_name = "meta-llama/Meta-Llama-3.1-405B" 13 | 14 | print(f"Downloading {model_name} to $HF_HOME = {os.environ['HF_HOME']}.") 15 | 16 | config = transformers.AutoConfig.from_pretrained(model_name) 17 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 18 | if not args.skip_model: 19 | with torch.device("meta"): 20 | model = transformers.AutoModelForCausalLM.from_pretrained(model_name) 21 | -------------------------------------------------------------------------------- /05-training-llama-405b/hosts: -------------------------------------------------------------------------------- 1 | ml-64-node-001 2 | ml-64-node-002 3 | ml-64-node-003 4 | ml-64-node-004 5 | ml-64-node-005 6 | ml-64-node-006 7 | ml-64-node-007 8 | ml-64-node-008 -------------------------------------------------------------------------------- /05-training-llama-405b/launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXPERIMENT_NAME=llama-405b 4 | 5 | if [ ! -f ./hosts ]; then 6 | echo "ERROR: ./hosts file not found. Please add this file to this current directory." 7 | exit 1 8 | fi 9 | 10 | ssh $(head -n 1 hosts) $(which wandb) login 11 | 12 | xargs \ 13 | -a hosts \ 14 | -I {} \ 15 | ssh {} \ 16 | tmux new-session -d -s torchrun-${EXPERIMENT_NAME} -c $(pwd) \ 17 | -e HF_HOME=/home/ubuntu/.cache/huggingface \ 18 | -e OMP_NUM_THREADS=26 \ 19 | -e NCCL_CROSS_NIC=1 \ 20 | -e TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ 21 | $(which python) -m torch.distributed.run \ 22 | --rdzv-id ${EXPERIMENT_NAME} \ 23 | --rdzv-backend c10d \ 24 | --rdzv-endpoint $(head -n 1 hosts):5001 \ 25 | --nnodes $(grep -c '^' hosts) \ 26 | --nproc-per-node 8 \ 27 | --redirects 3 \ 28 | --log-dir ./logs \ 29 | train_llm.py \ 30 | --experiment-name ${EXPERIMENT_NAME} \ 31 | --dataset-name Skylion007/openwebtext \ 32 | --model-name meta-llama/Meta-Llama-3.1-405B \ 33 | --batch-size 1 \ 34 | --seq-length 4096 \ 35 | --cpu-offload on \ 36 | --log-freq 1 37 | -------------------------------------------------------------------------------- /05-training-llama-405b/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | import functools 4 | from itertools import chain 5 | import json 6 | import multiprocessing 7 | import os 8 | import time 9 | from pathlib import Path 10 | import logging 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch import distributed as dist 16 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 17 | apply_activation_checkpointing, 18 | checkpoint_wrapper, 19 | ) 20 | from torch.distributed.elastic.multiprocessing.errors import record 21 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 22 | FullyShardedDataParallel, 23 | CPUOffload, 24 | ShardingStrategy, 25 | ) 26 | from torch.distributed.fsdp.wrap import 
transformer_auto_wrap_policy 27 | from torch.distributed.checkpoint.state_dict import ( 28 | get_state_dict, 29 | set_state_dict, 30 | StateDictOptions, 31 | ) 32 | from torch.distributed.checkpoint import load, save 33 | 34 | 35 | import wandb 36 | import tqdm 37 | import datasets 38 | from transformers import ( 39 | AutoConfig, 40 | AutoModelForCausalLM, 41 | AutoTokenizer, 42 | default_data_collator, 43 | ) 44 | from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding 45 | 46 | # fixes for reset_parameters not existing 47 | LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight) 48 | LlamaRotaryEmbedding.reset_parameters = lambda _: None 49 | 50 | LOGGER = logging.getLogger(__name__) 51 | 52 | 53 | @record 54 | def main(): 55 | parser = _get_parser() 56 | args = parser.parse_args() 57 | 58 | dist.init_process_group() 59 | 60 | rank = dist.get_rank() 61 | local_rank = rank % torch.cuda.device_count() 62 | world_size = dist.get_world_size() 63 | 64 | logging.basicConfig( 65 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 66 | level=logging.INFO, 67 | ) 68 | 69 | LOGGER.info(os.environ) 70 | LOGGER.info(args) 71 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 72 | 73 | device = torch.device(f"cuda:{local_rank}") 74 | dtype = torch.bfloat16 75 | torch.cuda.set_device(device) 76 | 77 | torch.manual_seed(args.seed) 78 | 79 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}") 80 | 81 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 82 | if rank == 0: 83 | with torch.device("cpu"): 84 | model = AutoModelForCausalLM.from_pretrained( 85 | args.model_name, 86 | torch_dtype=dtype, 87 | attn_implementation="flash_attention_2", 88 | use_cache=False, 89 | ) 90 | else: 91 | with torch.device("meta"): 92 | model = AutoModelForCausalLM.from_config( 93 | config, 94 | torch_dtype=dtype, 95 | attn_implementation="flash_attention_2", 96 | ) 97 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 98 | 99 | LOGGER.info(f"Before FSDP: {get_mem_stats(device)}") 100 | 101 | from torch.nn import Embedding 102 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 103 | 104 | wrap_policy = functools.partial( 105 | transformer_auto_wrap_policy, 106 | transformer_layer_cls={LlamaDecoderLayer, Embedding}, 107 | ) 108 | model = FullyShardedDataParallel( 109 | model, 110 | device_id=local_rank, 111 | param_init_fn=lambda m: m.to_empty(device=device, recurse=False), 112 | sync_module_states=True, 113 | # NOTE: FULL_SHARD is equivalent to deepspeed ZeRO stage 3 114 | auto_wrap_policy=wrap_policy, 115 | sharding_strategy=ShardingStrategy.FULL_SHARD, 116 | cpu_offload=CPUOffload(offload_params=args.cpu_offload == "on"), 117 | ) 118 | 119 | LOGGER.info(f"After FSDP: {get_mem_stats(device)}") 120 | LOGGER.info(f"FSDP architecture: {model}") 121 | 122 | # Applying gradient checkpointing - note that only the LlamaDecoderLayer supports this, 123 | # so we can just reuse our existing wrap_policy. 124 | apply_activation_checkpointing( 125 | model, checkpoint_wrapper_fn=checkpoint_wrapper, auto_wrap_policy=wrap_policy 126 | ) 127 | 128 | # NOTE: since this can download data, make sure to do the main process first on each node 129 | # since we manually specified HF_HOME to be a node local drive. 
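# rank_ordered() lets local_rank 0 on each node run the block first (populating
# that node's local HF_HOME cache), then the remaining local ranks enter after
# a barrier and read from the already-populated cache.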
130 | with rank_ordered(should_go_first=local_rank == 0): 131 | train_data = _load_and_preprocess_data(args, config) 132 | LOGGER.info(f"{len(train_data)} training samples") 133 | 134 | dataloader = DataLoader( 135 | train_data, 136 | batch_size=args.batch_size, 137 | collate_fn=default_data_collator, 138 | num_workers=1, 139 | prefetch_factor=2, 140 | # NOTE: this sampler will split dataset evenly across workers 141 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 142 | ) 143 | LOGGER.info(f"{len(dataloader)} batches per epoch") 144 | 145 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 146 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 147 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 148 | ) 149 | 150 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 151 | 152 | # NOTE: full_state_dict=False means we will be saving sharded checkpoints. 153 | ckpt_opts = StateDictOptions(full_state_dict=False, cpu_offload=True) 154 | 155 | # attempt resume 156 | state = { 157 | "epoch": 0, 158 | "global_step": 0, 159 | "epoch_step": 0, 160 | "running_loss": 0, 161 | } 162 | resumed = False 163 | if (exp_dir / "state.json").exists(): 164 | sharded_model_state, sharded_optimizer_state = get_state_dict( 165 | model, optimizer, options=ckpt_opts 166 | ) 167 | load( 168 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state), 169 | checkpoint_id=exp_dir / "checkpoint", 170 | ) 171 | set_state_dict( 172 | model, 173 | optimizer, 174 | model_state_dict=sharded_model_state, 175 | optim_state_dict=sharded_optimizer_state, 176 | options=ckpt_opts, 177 | ) 178 | lr_scheduler.load_state_dict( 179 | torch.load( 180 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True 181 | ) 182 | ) 183 | with open(exp_dir / "state.json") as fp: 184 | state = json.load(fp) 185 | resumed = True 186 | LOGGER.info(f"Resumed={resumed} | {state}") 187 | dist.barrier() 188 | 189 | if (exp_dir.is_mount() and rank == 0) or ( 190 | not exp_dir.is_mount() and local_rank == 0 191 | ): 192 | LOGGER.info(f"Creating experiment root directory") 193 | exp_dir.mkdir(parents=True, exist_ok=True) 194 | dist.barrier() 195 | 196 | if rank == 0: 197 | wandb.init( 198 | project="distributed-training-guide", 199 | dir=exp_dir, 200 | name=args.experiment_name, 201 | id=args.experiment_name, 202 | resume="must" if resumed else None, 203 | save_code=True, 204 | config={ 205 | "args": vars(args), 206 | "training_data_size": len(train_data), 207 | "num_batches": len(dataloader), 208 | "world_size": world_size, 209 | }, 210 | ) 211 | 212 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 213 | 214 | for state["epoch"] in range(state["epoch"], args.num_epochs): 215 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 216 | 217 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True) 218 | if state["epoch_step"] > 0: 219 | progress_bar.update(state["epoch_step"]) 220 | 221 | dataloader.sampler.set_epoch(state["epoch"]) 222 | batches = iter(dataloader) 223 | 224 | for i_step in range(len(dataloader)): 225 | with timers["data"], torch.no_grad(): 226 | batch = next(batches) 227 | batch = {k: v.to(device=device) for k, v in batch.items()} 228 | 229 | if i_step < state["epoch_step"]: 230 | # NOTE: for resuming 231 | continue 232 | 233 | with timers["forward"]: 234 | outputs = model(**batch) 235 | 236 | with timers["backward"]: 237 | outputs.loss.backward() 238 | 239 | with timers["update"]: 240 | 
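# With --cpu-offload on, this optimizer step runs entirely on the CPU using the
# fused AdamW kernel; zero_grad() below keeps gradients allocated
# (set_to_none=False) in that case to avoid repeated CPU alloc/free cycles.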
optimizer.step() 241 | lr_scheduler.step() 242 | optimizer.zero_grad(set_to_none=args.cpu_offload == "off") 243 | 244 | state["global_step"] += 1 245 | state["epoch_step"] += 1 246 | state["running_loss"] += outputs.loss.item() 247 | progress_bar.update(1) 248 | 249 | if state["global_step"] % args.log_freq == 0: 250 | tok_per_step = world_size * args.batch_size * args.seq_length 251 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 252 | info = { 253 | "global_step": state["global_step"], 254 | "lr": lr_scheduler.get_last_lr()[0], 255 | "running_loss": state["running_loss"] / args.log_freq, 256 | "epoch": state["epoch"], 257 | "epoch_progress": state["epoch_step"] / len(dataloader), 258 | "num_batches_remaining": len(dataloader) - i_step, 259 | **get_mem_stats(device), 260 | "tok/s": 1000 * tok_per_step / ms_per_step, 261 | "time/total": ms_per_step, 262 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()), 263 | **{ 264 | f"time/{k}": timer.avg_elapsed_ms() 265 | for k, timer in timers.items() 266 | }, 267 | } 268 | 269 | LOGGER.info(info) 270 | if rank == 0: 271 | wandb.log(info, step=state["global_step"]) 272 | 273 | torch.cuda.reset_peak_memory_stats(device) 274 | state["running_loss"] = 0 275 | for t in timers.values(): 276 | t.reset() 277 | 278 | if state["global_step"] % args.ckpt_freq == 0: 279 | LOGGER.info("Saving checkpoint.") 280 | dist.barrier() 281 | # NOTE: we have to call this on ALL ranks 282 | sharded_model_state, sharded_optimizer_state = get_state_dict( 283 | model, optimizer, options=ckpt_opts 284 | ) 285 | save( 286 | dict(model=sharded_model_state, optimizer=sharded_optimizer_state), 287 | checkpoint_id=exp_dir / "checkpoint", 288 | ) 289 | if rank == 0: 290 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 291 | with open(exp_dir / "state.json", "w") as fp: 292 | json.dump(state, fp) 293 | dist.barrier() 294 | 295 | state["epoch_step"] = 0 296 | 297 | 298 | def _load_and_preprocess_data(args, config): 299 | """ 300 | Function created using code found in 301 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 302 | """ 303 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 304 | 305 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 306 | 307 | column_names = data["train"].column_names 308 | text_column_name = "text" if "text" in column_names else column_names[0] 309 | 310 | def tokenize_function(examples): 311 | return tokenizer(examples[text_column_name]) 312 | 313 | tokenized_datasets = data.map( 314 | tokenize_function, 315 | batched=True, 316 | remove_columns=column_names, 317 | num_proc=multiprocessing.cpu_count(), 318 | load_from_cache_file=True, 319 | desc="Running tokenizer on dataset", 320 | ) 321 | 322 | seq_length = args.seq_length or tokenizer.model_max_length 323 | if seq_length > config.max_position_embeddings: 324 | seq_length = min(1024, config.max_position_embeddings) 325 | 326 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 327 | def group_texts(examples): 328 | # Concatenate all texts. 329 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 330 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 331 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 
332 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 333 | if total_length > seq_length: 334 | total_length = (total_length // seq_length) * seq_length 335 | # Split by chunks of max_len. 336 | result = { 337 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 338 | for k, t in concatenated_examples.items() 339 | } 340 | result["labels"] = result["input_ids"].copy() 341 | return result 342 | 343 | lm_datasets = tokenized_datasets.map( 344 | group_texts, 345 | batched=True, 346 | num_proc=multiprocessing.cpu_count(), 347 | load_from_cache_file=True, 348 | desc=f"Grouping texts in chunks of {seq_length}", 349 | ) 350 | 351 | return lm_datasets["train"] 352 | 353 | 354 | def get_mem_stats(device=None): 355 | mem = torch.cuda.memory_stats(device) 356 | props = torch.cuda.get_device_properties(device) 357 | return { 358 | "total_mem_in_gb": 1e-9 * props.total_memory, 359 | "curr_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.current"], 360 | "peak_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.peak"], 361 | "curr_resv_in_gb": 1e-9 * mem["reserved_bytes.all.current"], 362 | "peak_resv_in_gb": 1e-9 * mem["reserved_bytes.all.peak"], 363 | } 364 | 365 | 366 | @contextmanager 367 | def rank_ordered(*, should_go_first: bool): 368 | if should_go_first: 369 | yield 370 | dist.barrier() 371 | if not should_go_first: 372 | yield 373 | dist.barrier() 374 | 375 | 376 | class LocalTimer: 377 | def __init__(self, device: torch.device): 378 | if device.type == "cpu": 379 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 380 | elif device.type == "cuda": 381 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 382 | self.measurements = [] 383 | self.start_time = None 384 | 385 | def __enter__(self): 386 | self.synchronize() 387 | self.start_time = time.time() 388 | return self 389 | 390 | def __exit__(self, type, value, traceback): 391 | if traceback is None: 392 | self.synchronize() 393 | end_time = time.time() 394 | self.measurements.append(end_time - self.start_time) 395 | self.start_time = None 396 | 397 | def avg_elapsed_ms(self): 398 | return 1000 * (sum(self.measurements) / len(self.measurements)) 399 | 400 | def reset(self): 401 | self.measurements = [] 402 | self.start_time = None 403 | 404 | 405 | def _get_parser() -> argparse.ArgumentParser: 406 | parser = argparse.ArgumentParser() 407 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 408 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 409 | parser.add_argument("-m", "--model-name", default=None, required=True) 410 | parser.add_argument("--save-dir", default="../outputs") 411 | parser.add_argument("--seed", default=0, type=int) 412 | parser.add_argument("--num-epochs", default=100, type=int) 413 | parser.add_argument("--lr", default=3e-5, type=float) 414 | parser.add_argument("-b", "--batch-size", default=1, type=int) 415 | parser.add_argument("--log-freq", default=100, type=int) 416 | parser.add_argument("--ckpt-freq", default=500, type=int) 417 | parser.add_argument("-s", "--seq-length", default=1024, type=int) 418 | parser.add_argument("--cpu-offload", default="on", choices=["on", "off"]) 419 | return parser 420 | 421 | 422 | if __name__ == "__main__": 423 | main() 424 | -------------------------------------------------------------------------------- /06-tensor-parallel/README.md: -------------------------------------------------------------------------------- 1 | # Tensor 
Parallelism (TP) 2 | 3 | So far we've just been using data parallel techniques. You may have heard of other parallelism techniques, and indeed the [Llama 405B paper](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/) actually uses 4D parallelism when training the 405B model: 4 | 5 | 1. Data parallel (FSDP as we've learned) 6 | 2. Tensor parallel (**this chapter**) 7 | 3. Context parallel (For long context lengths) 8 | 4. Pipeline/model parallel 9 | 10 | In this chapter we are going to diving into what tensor parallelism is, before we think about combining it with other types. 11 | 12 | ## Basics: What is tensor parallelism? 13 | 14 | TP splits the model weights **AND** computation across multiple GPUs. 15 | 16 | FSDP splits the model weights, but it gathers them back for the computation. Splitting the computation across GPUs is the difference. 17 | 18 | A result of this is the world size is scaled **down** by your tensor parallel size => the cost of allgathers/allreduces is reduced. This becomes a big factor when your cluster is large, and TP is a very effective way to scale up! 19 | 20 | Here are the benefits of this: 21 | 1. The peak GPU memory is reduced - now instead of each GPU fully loading up the full weights for each layer, they now only load `1/num_gpus` of the weights. 22 | 2. We now have `per GPU memory * num_gpus` as our amount of memory to use for each layer. 23 | 3. Less allgather/allreduce cost 24 | 25 | Here are the downsides: 26 | 1. Global batch size is reduced 27 | 2. Increased code complexity 28 | 29 | Note that this can only really be applied to certain modules, but most of the modules in an LLM work with it. 30 | 31 | ## Ensure all GPUs on a node get the same input 32 | 33 | Since we are splitting computation across GPUs, all GPUs in the same group need to receive the same input. (That is why the global batch size is reduced). 34 | 35 | First we are going to create our device mesh. A device mesh is just a way to view your devices in an N-dimensional way. So if you have 8 GPUs, you could organize it into a device mesh like `(2, 2, 2)`, or `(2, 4)`, or `(4, 2)` or even things like `(1, 8)`. 36 | 37 | The reason this is helpful is because we are going to name these dimensions, much like we do with tensor dimensions. Similar to how we have a batch and sequence dimension, for our device mesh we are going to have a data parallel and tensor parallel dimension. 38 | 39 | ```python 40 | gpus_on_node = torch.cuda.device_count() 41 | num_nodes = world_size // gpus_on_node 42 | mesh = dist.device_mesh.init_device_mesh( 43 | "cuda", 44 | (num_nodes, gpus_on_node), 45 | mesh_dim_names=("dp", "tp"), 46 | ) 47 | ``` 48 | 49 | So if we have 4 GPUs total, and have a `(2, 2)` device mesh, here are the assignments: 50 | 51 | | | DP rank | TP rank | 52 | | --- | --- | --- | 53 | | GPU 0 | 0 | 0 | 54 | | GPU 1 | 0 | 1 | 55 | | GPU 2 | 1 | 0 | 56 | | GPU 3 | 1 | 1 | 57 | 58 | This doesn't actually mean anything unless we update the rest of our code to use these device meshes, so let's see how we do that! 59 | A lot of the pytorch distributed APIs actually take an optional `mesh: Optional[DeviceMesh] = None` argument, we just haven't used it so far. 
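As a quick sanity check, each rank can query its coordinates in the mesh. Here's a minimal sketch (assuming the 4 GPU, `(2, 2)` mesh from the table above) that prints which data parallel and tensor parallel group a rank belongs to:

```python
# Each rank sees the same mesh object, but its coordinates differ.
dp_rank = mesh["dp"].get_local_rank()  # 0 for GPUs 0-1, 1 for GPUs 2-3
tp_rank = mesh["tp"].get_local_rank()  # 0 for GPUs 0 & 2, 1 for GPUs 1 & 3
print(f"rank={dist.get_rank()} dp={dp_rank} tp={tp_rank} tp_size={mesh['tp'].size()}")
```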
60 | 
61 | The first place is actually our data sampler, and this is how we give all of the GPUs in the same tensor parallel group the same input:
62 | 
63 | ```python
64 | sampler=DistributedSampler(
65 |     ...,
66 |     num_replicas=mesh["dp"].size(),
67 |     # NOTE: every GPU on a node will have the same "dp" rank,
68 |     # meaning they will all receive the same input!
69 |     rank=mesh["dp"].get_local_rank(),
70 | )
71 | ```
72 | 
73 | Using the 4 GPU `(2, 2)` mesh from above, these are the arguments each GPU would pass to DistributedSampler:
74 | 
75 | | | num_replicas | rank |
76 | | --- | --- | --- |
77 | | GPU 0 | 2 | 0 |
78 | | GPU 1 | 2 | 0 |
79 | | GPU 2 | 2 | 1 |
80 | | GPU 3 | 2 | 1 |
81 | 
82 | This is because our DP dimension has size 2, and the `rank` we pass to DistributedSampler is the DP local rank shown in the first table above.
83 | 
84 | ## Parallelizing linear & attention modules
85 | 
86 | Here's the code first; the graphics after it explain how this works. Note that we are passing our `mesh["tp"]` to the API, which means this is happening across our tensor parallel group!
87 | 
88 | ```python
89 | for layer in model.model.layers:
90 |     tp.parallelize_module(
91 |         layer,
92 |         mesh["tp"],
93 |         {
94 |             "self_attn.q_proj": tp.ColwiseParallel(),
95 |             "self_attn.k_proj": tp.ColwiseParallel(),
96 |             "self_attn.v_proj": tp.ColwiseParallel(),
97 |             "self_attn.o_proj": tp.RowwiseParallel(),
98 | 
99 |             "mlp.gate_proj": tp.ColwiseParallel(),
100 |             "mlp.up_proj": tp.ColwiseParallel(),
101 |             "mlp.down_proj": tp.RowwiseParallel(),
102 |         },
103 |     )
104 | ```
105 | 
106 | ### colwise
107 | 
108 | Our first three linear layers in self attention (q/k/v projection) are all colwise linear. This means we shard the weight matrix inside the module along dimension 0 (since it's stored in a transposed format). The remainder of the attention layer (including the attention computation itself) uses this sharded output, so attention actually runs on smaller tensors.
109 | 
110 | 
111 | 
112 | Image Source: [PyTorchLightning](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning#column-wise-parallel)
113 | 
114 | ### colwise into rowwise
115 | 
116 | The final layer in our self attention block is another linear layer (o_proj). Note that we are doing rowwise parallel here. This actually lets us "recombine" across our tp dimension, as shown here:
117 | 
118 | 
119 | 
120 | Image Source: [PyTorchLightning](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning#combined-parallel-layers)
121 | 
122 | So the final output of self attention will be replicated again.
123 | 
124 | ### Parallelizing Embedding layer
125 | 
126 | The embedding weight gets sharded along dimension 1, meaning each GPU holds a different slice of the data associated with each token:
127 | 
128 | | Embedding Weight Shape | Sharded Shape |
129 | | --- | --- |
130 | | `(vocab_size, hidden_dim)` | `(vocab_size, hidden_dim / mesh["tp"].size())` |
131 | 
132 | In a normal embedding layer it:
133 | - Takes input tokens of `shape=(batch, seq)`
134 | - Outputs embeddings of `shape=(batch, seq, hidden_dim)`
135 | 
136 | Now that we've sharded the embedding weight tensor, the layer will actually output:
137 | - Sharded output embeddings of `shape=(batch, seq, hidden_dim / mesh["tp"].size())`.
138 | 
139 | We have a problem though: Our *colwise* pieces of the `self_attn` module will receive the output of this module.
ColwiseParallel actually expects input to be **replicated**, not sharded.
140 | 
141 | So we need to do an allgather on the tensor to replicate it across the group (i.e. it will be back to `shape=(batch, seq, hidden_dim)`). Luckily we can just specify this additional transformation with the `output_layouts` argument:
142 | 
143 | ```python
144 | tp.parallelize_module(
145 |     model,
146 |     mesh["tp"],
147 |     {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
148 | )
149 | ```
150 | 
151 | ### Parallelizing the final linear layer of the model
152 | 
153 | ```python
154 | tp.parallelize_module(
155 |     model,
156 |     mesh["tp"],
157 |     {
158 |         "lm_head": tp.ColwiseParallel(
159 |             output_layouts=Replicate()
160 |         ),
161 |     },
162 | )
163 | ```
164 | 
165 | We have to include `Replicate()` here because our loss expects replicated tensors, but colwise by default shards on the last dimension.
166 | 
167 | ## Parallelizing Norm Layers with SequenceParallel
168 | 
169 | For normalization layers, it works a bit differently. We don't actually shard the layer's weights at all; instead, we shard the **input** to this layer on the sequence dimension!
170 | 
171 | So our computation is split, and we need to do some work to join the results back together for the other modules:
172 | 
173 | ```diff
174 |  for layer in model.model.layers:
175 |      tp.parallelize_module(
176 |          layer,
177 |          mesh["tp"],
178 |          {
179 | +            "input_layernorm": tp.SequenceParallel(),
180 | +            "self_attn": tp.PrepareModuleInput(
181 | +                input_kwarg_layouts={"hidden_states": Shard(dim=1)},
182 | +                desired_input_kwarg_layouts={"hidden_states": Replicate()},
183 | +            ),
184 |              "self_attn.q_proj": tp.ColwiseParallel(),
185 |              "self_attn.k_proj": tp.ColwiseParallel(),
186 |              "self_attn.v_proj": tp.ColwiseParallel(),
187 | -            "self_attn.o_proj": tp.RowwiseParallel(),
188 | +            "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
189 | +            "post_attention_layernorm": tp.SequenceParallel(),
190 | +            "mlp": tp.PrepareModuleInput(
191 | +                input_layouts=Shard(dim=1),
192 | +                desired_input_layouts=Replicate(),
193 | +            ),
194 |              "mlp.gate_proj": tp.ColwiseParallel(),
195 |              "mlp.up_proj": tp.ColwiseParallel(),
196 | -            "mlp.down_proj": tp.RowwiseParallel(),
197 | +            "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
198 |          },
199 |      )
200 | ```
201 | 
202 | The `PrepareModuleInput` objects transform how the tensors are split up. E.g. for `self_attn` the hidden_states input is sharded along the 1st dimension because of the `SequenceParallel`, but all the `ColwiseParallel` layers expect their input to be replicated.
203 | 
204 | We also need to change our embedding layer: since its output now flows into a SequenceParallel layer, we need to shard it along dimension 1:
205 | 
206 | ```diff
207 |  tp.parallelize_module(
208 |      model,
209 |      mesh["tp"],
210 | -    {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
211 | +    {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))},
212 |  )
213 | ```
214 | 
215 | We actually need an additional change because of this, due to `transformers` specific code: it computes the sequence length based on the output of the embedding layer, which will be wrong since we are now sharding it along the sequence dimension.
Passing position_ids explicitly will fix this, but **it's very implementation specific**:
216 | 
217 | ```diff
218 |  with timers["data"], torch.no_grad():
219 |      batch = next(batches)
220 |      batch = {k: v.to(device=device) for k, v in batch.items()}
221 | +    batch["position_ids"] = torch.arange(
222 | +        0, args.seq_length, device=device, dtype=torch.long
223 | +    ).unsqueeze(0)
224 | ```
225 | 
226 | And here is the diff for our final output from the network:
227 | ```diff
228 |  tp.parallelize_module(
229 |      model,
230 |      mesh["tp"],
231 |      {
232 | +        "model.norm": tp.SequenceParallel(),
233 |          "lm_head": tp.ColwiseParallel(
234 | +            input_layouts=Shard(1),
235 |              output_layouts=Replicate(),
236 |          ),
237 |      },
238 |  )
239 | ```
240 | 
241 | ## Parallelizing Loss computation
242 | 
243 | There's an additional api for parallelizing the loss computation across the **class** dimension (at the time of writing it only works for Cross Entropy). We first need to use this context manager around our loss computation:
244 | 
245 | ```python
246 | with tp.loss_parallel(), timers["forward"]:
247 |     outputs = model(**batch)
248 | 
249 | with tp.loss_parallel(), timers["backward"]:
250 |     outputs.loss.backward()
251 | ```
252 | 
253 | Then we need to update the output of our `lm_head` for this also, because loss_parallel requires a different sharding format and a `DTensor` output:
254 | 
255 | ```diff
256 |  tp.parallelize_module(
257 |      model,
258 |      mesh["tp"],
259 |      {
260 |          "model.norm": tp.SequenceParallel(),
261 |          "lm_head": tp.ColwiseParallel(
262 |              input_layouts=Shard(1),
263 | -            output_layouts=Replicate(),
264 | +            output_layouts=Shard(-1),
265 | +            use_local_output=False,
266 |          ),
267 |      },
268 |  )
269 | ```
270 | 
271 | `use_local_output=False` tells pytorch to return a `DTensor` from the operation, instead of a normal `Tensor`.
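As a quick sanity check (hypothetical, not part of this chapter's script), you can confirm that the model now produces a `DTensor` for the logits, sharded along the class/vocab dimension, by inspecting its placements after a forward pass. `model` and `batch` here refer to the objects already set up in this chapter's training script:

```python
# Hypothetical check: with use_local_output=False the lm_head output stays a DTensor.
from torch.distributed._tensor import DTensor

with tp.loss_parallel():
    outputs = model(**batch)

assert isinstance(outputs.logits, DTensor)
# Expect something like (Shard(dim=2),), i.e. sharded on the last (vocab) dimension.
print(outputs.logits.placements)
```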
272 | 
273 | ## Computing throughput with our new world size
274 | 
275 | Because a single GPU no longer processes its own distinct batch (all GPUs in a tensor parallel group share the same batch), we need to update our throughput calculation to use our device mesh:
276 | 
277 | ```diff
278 |  if state["global_step"] % args.log_freq == 0:
279 | -    tok_per_step = world_size * args.batch_size * args.seq_length
280 | +    tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length
281 |      ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
282 | ```
283 | 
284 | ## Results
285 | 
286 | Here are some results from launching training for llama 8B on a single node of 8x H100s:
287 | 
288 | Command:
289 | ```bash
290 | HF_HOME=/home/ubuntu/.cache/huggingface OMP_NUM_THREADS=26 torchrun --standalone --nproc-per-node gpu train_llm.py --experiment-name tp-llama-8b --dataset-name tatsu-lab/alpaca --model-name meta-llama/Llama-3.1-8B --log-freq 10 --batch-size 16 --seq-length 1024 --num-epochs 1
291 | ```
292 | 
293 | 
294 | 
295 | 
296 | 
297 | ## Useful References
298 | 
299 | For completeness, here are the relevant docs/guides from pytorch on how to achieve this:
300 | - [TP API docs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#tensor-parallelism-torch-distributed-tensor-parallel)
301 | - [2d Parallelism Tutorial](https://pytorch.org/tutorials/intermediate/TP_tutorial.html#large-scale-transformer-model-training-with-tensor-parallel-tp)
302 | - [Device Mesh tutorial](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html)
303 | - [PyTorch Lightning TP Tutorial](https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning)
304 | 
305 | ## Pytorch API Reference
306 | 
307 | Here is a brief explanation of how the APIs used in this chapter work.
308 | 
309 | - [tp.RowwiseParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.RowwiseParallel) shards the module's weights in a row-wise fashion.
310 |   - Inputs by default are sharded on the last dimension
311 |   - Outputs by default are replicated on all workers
312 | - [tp.ColwiseParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.ColwiseParallel) shards the module's weights in a column-wise fashion.
313 |   - Inputs by default are replicated on all workers
314 |   - Outputs by default are sharded on the last dimension
315 | - [tp.SequenceParallel()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.SequenceParallel) shards the input/output across dimension 1. Module weights are NOT sharded.
316 | - [tp.PrepareModuleInput()](https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.PrepareModuleInput) lets you change the sharding configuration of input tensors.
317 | - `torch.distributed._tensor.Shard(dim=X)` indicates a tensor should be sharded along dimension X
318 | - `torch.distributed._tensor.Replicate()` indicates a tensor should be replicated among all workers.
319 | 
320 | How all of these things interact is actually very subtle and complex, which is why this guide is useful!
321 | 
322 | You can also change most of the default behavior with arguments to these classes. For example, you can change RowwiseParallel to assume the input is replicated instead of sharded.
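A minimal sketch of that kind of override (illustrative only, not something this chapter's script does):

```python
# Illustrative plan showing layout overrides; the module names are from the llama
# architecture used in this chapter, but the layouts here are arbitrary examples.
import torch.distributed.tensor.parallel as tp
from torch.distributed._tensor import Replicate, Shard

layer_plan = {
    # RowwiseParallel normally assumes its input is sharded on the last
    # dimension; here we tell it to expect a replicated input and to emit
    # an output sharded on the sequence dimension instead.
    "mlp.down_proj": tp.RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    # ColwiseParallel normally assumes a replicated input; here we declare
    # that the input arrives sharded on the sequence dimension.
    "mlp.up_proj": tp.ColwiseParallel(input_layouts=Shard(1)),
}
```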
323 | -------------------------------------------------------------------------------- /06-tensor-parallel/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | from itertools import chain 4 | import json 5 | import multiprocessing 6 | import os 7 | import time 8 | from pathlib import Path 9 | import logging 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch import distributed as dist 15 | import torch.distributed.tensor.parallel as tp 16 | from torch.distributed._tensor import Shard, Replicate 17 | from torch.distributed.elastic.multiprocessing.errors import record 18 | import torch.distributed.checkpoint as DCP 19 | 20 | 21 | import wandb 22 | import tqdm 23 | import datasets 24 | from transformers import ( 25 | AutoConfig, 26 | AutoModelForCausalLM, 27 | AutoTokenizer, 28 | default_data_collator, 29 | ) 30 | 31 | LOGGER = logging.getLogger(__name__) 32 | 33 | 34 | @record 35 | def main(): 36 | parser = _get_parser() 37 | args = parser.parse_args() 38 | 39 | dist.init_process_group() 40 | 41 | gpus_on_node = torch.cuda.device_count() 42 | 43 | rank = dist.get_rank() 44 | local_rank = rank % gpus_on_node 45 | world_size = dist.get_world_size() 46 | 47 | assert ( 48 | world_size % gpus_on_node == 0 49 | ), "This script assumes all nodes have the same amount of GPUs" 50 | num_nodes = world_size // gpus_on_node 51 | 52 | mesh = dist.device_mesh.init_device_mesh( 53 | "cuda", 54 | (num_nodes, gpus_on_node), 55 | mesh_dim_names=("dp", "tp"), 56 | ) 57 | 58 | logging.basicConfig( 59 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 60 | level=logging.INFO, 61 | ) 62 | 63 | LOGGER.info(os.environ) 64 | LOGGER.info(args) 65 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 66 | LOGGER.info(f"dp_size={mesh['dp'].size()} tp_size={mesh['tp'].size()}") 67 | 68 | device = torch.device(f"cuda:{local_rank}") 69 | dtype = torch.bfloat16 70 | torch.cuda.set_device(device) 71 | 72 | torch.manual_seed(args.seed) 73 | 74 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}") 75 | 76 | with rank_ordered(should_go_first=local_rank == 0): 77 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 78 | with device: 79 | model = AutoModelForCausalLM.from_config( 80 | config, torch_dtype=dtype, attn_implementation="flash_attention_2" 81 | ) 82 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 83 | 84 | tp.parallelize_module( 85 | model, 86 | mesh["tp"], 87 | {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))}, 88 | ) 89 | for layer in model.model.layers: 90 | tp.parallelize_module( 91 | layer, 92 | mesh["tp"], 93 | { 94 | # SequenceParallel will apply sharding to sequence dimension. 95 | "input_layernorm": tp.SequenceParallel(), 96 | # The input to self_attn (which is the output from the SequenceParallel input_layer_norm) will be sharded on dimension 1, but we wanted it to be the whole tensor. 
97 | "self_attn": tp.PrepareModuleInput( 98 | input_kwarg_layouts={"hidden_states": Shard(dim=1)}, 99 | desired_input_kwarg_layouts={"hidden_states": Replicate()}, 100 | ), 101 | "self_attn.q_proj": tp.ColwiseParallel(), 102 | "self_attn.k_proj": tp.ColwiseParallel(), 103 | "self_attn.v_proj": tp.ColwiseParallel(), 104 | "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)), 105 | # Another sharding along sequence dimension. 106 | "post_attention_layernorm": tp.SequenceParallel(), 107 | "mlp": tp.PrepareModuleInput( 108 | input_layouts=Shard(dim=1), 109 | desired_input_layouts=Replicate(), 110 | ), 111 | "mlp.gate_proj": tp.ColwiseParallel(), 112 | "mlp.up_proj": tp.ColwiseParallel(), 113 | "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)), 114 | }, 115 | ) 116 | 117 | tp.parallelize_module( 118 | model, 119 | mesh["tp"], 120 | { 121 | "model.norm": tp.SequenceParallel(), 122 | "lm_head": tp.ColwiseParallel( 123 | input_layouts=Shard(1), 124 | output_layouts=Shard(-1), # for tp.loss_parallel 125 | use_local_output=False, # for tp.loss_parallel 126 | ), 127 | }, 128 | ) 129 | 130 | LOGGER.info(f"Final Architecture: {model}") 131 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 132 | 133 | model = model.to_empty(device=device) 134 | model.init_weights() 135 | model.train() 136 | 137 | LOGGER.info(f"{get_mem_stats(device)}") 138 | 139 | # NOTE: since this can download data, make sure to do the main process first on each node 140 | # since we manually specified HF_HOME to be a node local drive. 141 | with rank_ordered(should_go_first=local_rank == 0): 142 | train_data = _load_and_preprocess_data(args, config) 143 | LOGGER.info(f"{len(train_data)} training samples") 144 | 145 | dataloader = DataLoader( 146 | train_data, 147 | batch_size=args.batch_size, 148 | collate_fn=default_data_collator, 149 | num_workers=1, 150 | prefetch_factor=2, 151 | # NOTE: this sampler will split dataset evenly across workers 152 | sampler=DistributedSampler( 153 | train_data, 154 | shuffle=True, 155 | drop_last=True, 156 | num_replicas=mesh["dp"].size(), # equivalent to `num_nodes` 157 | rank=mesh["dp"].get_local_rank(), # equivalent to `rank // num_nodes` 158 | ), 159 | ) 160 | LOGGER.info(f"{len(dataloader)} batches per epoch") 161 | 162 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 163 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 164 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 165 | ) 166 | 167 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 168 | 169 | # attempt resume 170 | state = { 171 | "epoch": 0, 172 | "global_step": 0, 173 | "epoch_step": 0, 174 | "running_loss": 0, 175 | } 176 | resumed = False 177 | if (exp_dir / "state.json").exists(): 178 | DCP.load( 179 | dict(model=model, optimizer=optimizer), 180 | checkpoint_id=exp_dir / "checkpoint", 181 | ) 182 | lr_scheduler.load_state_dict( 183 | torch.load( 184 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True 185 | ) 186 | ) 187 | with open(exp_dir / "state.json") as fp: 188 | state = json.load(fp) 189 | resumed = True 190 | LOGGER.info(f"Resumed={resumed} | {state}") 191 | dist.barrier() 192 | 193 | if (exp_dir.is_mount() and rank == 0) or ( 194 | not exp_dir.is_mount() and local_rank == 0 195 | ): 196 | LOGGER.info(f"Creating experiment root directory") 197 | exp_dir.mkdir(parents=True, exist_ok=True) 198 | dist.barrier() 199 | 200 | if rank == 0: 201 | wandb.init( 202 | project="distributed-training-guide", 203 | 
dir=exp_dir, 204 | name=args.experiment_name, 205 | id=args.experiment_name, 206 | resume="must" if resumed else None, 207 | save_code=True, 208 | config={ 209 | "args": vars(args), 210 | "training_data_size": len(train_data), 211 | "num_batches": len(dataloader), 212 | "world_size": world_size, 213 | }, 214 | ) 215 | 216 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 217 | 218 | for state["epoch"] in range(state["epoch"], args.num_epochs): 219 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 220 | 221 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True) 222 | if state["epoch_step"] > 0: 223 | progress_bar.update(state["epoch_step"]) 224 | 225 | batches = iter(dataloader) 226 | 227 | for i_step in range(len(dataloader)): 228 | with timers["data"], torch.no_grad(): 229 | batch = next(batches) 230 | batch = {k: v.to(device=device) for k, v in batch.items()} 231 | batch["position_ids"] = torch.arange( 232 | 0, args.seq_length, device=device, dtype=torch.long 233 | ).unsqueeze(0) 234 | 235 | if i_step < state["epoch_step"]: 236 | # NOTE: for resuming 237 | continue 238 | 239 | with tp.loss_parallel(), timers["forward"]: 240 | outputs = model(**batch) 241 | 242 | with tp.loss_parallel(), timers["backward"]: 243 | outputs.loss.backward() 244 | 245 | with timers["update"]: 246 | optimizer.step() 247 | lr_scheduler.step() 248 | optimizer.zero_grad(set_to_none=True) 249 | 250 | state["global_step"] += 1 251 | state["epoch_step"] += 1 252 | state["running_loss"] += outputs.loss.item() 253 | progress_bar.update(1) 254 | 255 | if state["global_step"] % args.log_freq == 0: 256 | tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length 257 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 258 | info = { 259 | "global_step": state["global_step"], 260 | "lr": lr_scheduler.get_last_lr()[0], 261 | "running_loss": state["running_loss"] / args.log_freq, 262 | "epoch": state["epoch"], 263 | "epoch_progress": state["epoch_step"] / len(dataloader), 264 | "num_batches_remaining": len(dataloader) - i_step, 265 | "tok/s": 1000 * tok_per_step / ms_per_step, 266 | **get_mem_stats(device), 267 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()), 268 | **{ 269 | f"time/{k}": timer.avg_elapsed_ms() 270 | for k, timer in timers.items() 271 | }, 272 | } 273 | 274 | LOGGER.info(info) 275 | if rank == 0: 276 | wandb.log(info, step=state["global_step"]) 277 | 278 | torch.cuda.reset_peak_memory_stats(device) 279 | state["running_loss"] = 0 280 | for t in timers.values(): 281 | t.reset() 282 | 283 | if state["global_step"] % args.ckpt_freq == 0: 284 | LOGGER.info("Saving checkpoint.") 285 | dist.barrier() 286 | # NOTE: we have to call this on ALL ranks 287 | DCP.save( 288 | dict(model=model, optimizer=optimizer), 289 | checkpoint_id=exp_dir / "checkpoint", 290 | ) 291 | if rank == 0: 292 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 293 | with open(exp_dir / "state.json", "w") as fp: 294 | json.dump(state, fp) 295 | dist.barrier() 296 | 297 | state["epoch_step"] = 0 298 | 299 | 300 | def get_mem_stats(device=None): 301 | mem = torch.cuda.memory_stats(device) 302 | props = torch.cuda.get_device_properties(device) 303 | return { 304 | "total_gb": 1e-9 * props.total_memory, 305 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"], 306 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"], 307 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"], 308 | "peak_resv_gb": 1e-9 
* mem["reserved_bytes.all.peak"], 309 | } 310 | 311 | 312 | def _load_and_preprocess_data(args, config): 313 | """ 314 | Function created using code found in 315 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 316 | """ 317 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 318 | 319 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 320 | 321 | column_names = data["train"].column_names 322 | text_column_name = "text" if "text" in column_names else column_names[0] 323 | 324 | def tokenize_function(examples): 325 | return tokenizer(examples[text_column_name]) 326 | 327 | tokenized_datasets = data.map( 328 | tokenize_function, 329 | batched=True, 330 | remove_columns=column_names, 331 | num_proc=multiprocessing.cpu_count(), 332 | load_from_cache_file=True, 333 | desc="Running tokenizer on dataset", 334 | ) 335 | 336 | seq_length = args.seq_length or tokenizer.model_max_length 337 | if seq_length > config.max_position_embeddings: 338 | seq_length = min(1024, config.max_position_embeddings) 339 | 340 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 341 | def group_texts(examples): 342 | # Concatenate all texts. 343 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 344 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 345 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 346 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 347 | if total_length > seq_length: 348 | total_length = (total_length // seq_length) * seq_length 349 | # Split by chunks of max_len. 
350 | result = { 351 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 352 | for k, t in concatenated_examples.items() 353 | } 354 | result["labels"] = result["input_ids"].copy() 355 | return result 356 | 357 | lm_datasets = tokenized_datasets.map( 358 | group_texts, 359 | batched=True, 360 | num_proc=multiprocessing.cpu_count(), 361 | load_from_cache_file=True, 362 | desc=f"Grouping texts in chunks of {seq_length}", 363 | ) 364 | 365 | return lm_datasets["train"] 366 | 367 | 368 | @contextmanager 369 | def rank_ordered(*, should_go_first: bool): 370 | if should_go_first: 371 | yield 372 | dist.barrier() 373 | if not should_go_first: 374 | yield 375 | dist.barrier() 376 | 377 | 378 | class LocalTimer: 379 | def __init__(self, device: torch.device): 380 | if device.type == "cpu": 381 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 382 | elif device.type == "cuda": 383 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 384 | self.measurements = [] 385 | self.start_time = None 386 | 387 | def __enter__(self): 388 | self.synchronize() 389 | self.start_time = time.time() 390 | return self 391 | 392 | def __exit__(self, type, value, traceback): 393 | if traceback is None: 394 | self.synchronize() 395 | end_time = time.time() 396 | self.measurements.append(end_time - self.start_time) 397 | self.start_time = None 398 | 399 | def avg_elapsed_ms(self): 400 | return 1000 * (sum(self.measurements) / len(self.measurements)) 401 | 402 | def reset(self): 403 | self.measurements = [] 404 | self.start_time = None 405 | 406 | 407 | def _get_parser() -> argparse.ArgumentParser: 408 | parser = argparse.ArgumentParser() 409 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 410 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 411 | parser.add_argument("-m", "--model-name", default=None, required=True) 412 | parser.add_argument("--save-dir", default="../outputs") 413 | parser.add_argument("--seed", default=0, type=int) 414 | parser.add_argument("--num-epochs", default=100, type=int) 415 | parser.add_argument("--lr", default=3e-5, type=float) 416 | parser.add_argument("-b", "--batch-size", default=1, type=int) 417 | parser.add_argument("--log-freq", default=100, type=int) 418 | parser.add_argument("--ckpt-freq", default=500, type=int) 419 | parser.add_argument("-s", "--seq-length", default=None, type=int) 420 | return parser 421 | 422 | 423 | if __name__ == "__main__": 424 | main() 425 | -------------------------------------------------------------------------------- /07-2d-parallel/README.md: -------------------------------------------------------------------------------- 1 | # 2d parallelism (TP + DP) 2 | 3 | Using both [FSDP](../04-fully-sharded-data-parallel) and [TP](../06-tensor-parallel) is actually quite simple code wise when starting from our [chapter 6 TP script](../06-tensor-parallel/train_llm.py). 4 | 5 | **Disclaimer** this only works if you use pytorch's **newer FSDP 2 api, which is still in alpha stages**. 6 | 7 | What does using these two together mean exactly? Let's get into an example with 6 GPUs, 2 way FSDP and 3 way TP: 8 | 9 | image 10 | 11 | When we first start out every gpu holds the full model. Then we shard the model into 3 pieces (our TP dimension). The 3 shards in the graphic above are red+orange, yellow+green, and blue+purple. Note that GPU 0 and GPU 3 **have the exact same shard**! This is because they are the same tensor parallel rank, but are different data parallel ranks. 
This means we have **duplicated** our model across our data parallel dimension.
12 | 
13 | When we apply FSDP in the next step, we split those duplicated shards! So Shard red+orange (which is duplicated on GPU 0 & 3) is split into two pieces (Shard red and Shard orange).
14 | 
15 | By the end we have 6 distinct shards of our model, one on every GPU.
16 | 
17 | Now, if you remember, FSDP does an allgather of all the shards before the forward pass. When GPU 0 & GPU 3 are executing their forward passes, they will gather the two shards (Shard red and Shard orange) into local memory to form Shard red+orange, so that each one can use the full shard during computation.
18 | 
19 | ## Applying FSDP after TP
20 | 
21 | We are starting from our [chapter 6 code](../06-tensor-parallel/train_llm.py), which already supports TP. So we just need to add FSDP to the script.
22 | 
23 | The API is much simpler than the FSDP 1 API; this is all we need to add **after** our TP code:
24 | 
25 | ```python
26 | from torch.distributed._composable.fsdp import fully_shard
27 | 
28 | if mesh["dp"].size() > 1:
29 |     for layer in model.model.layers:
30 |         fully_shard(layer, mesh=mesh["dp"])
31 |     fully_shard(model, mesh=mesh["dp"])
32 | ```
33 | 
34 | Note how we are passing our `mesh["dp"]` here to indicate that this is happening across our data parallel dimension.
35 | 
36 | ## Controlling TP size
37 | 
38 | When creating our mesh we are going to set the TP size based on a CLI argument:
39 | 
40 | ```python
41 | assert world_size % args.tp == 0
42 | 
43 | mesh = dist.device_mesh.init_device_mesh(
44 |     "cuda",
45 |     (world_size // args.tp, args.tp),
46 |     mesh_dim_names=("dp", "tp"),
47 | )
48 | ```
49 | 
50 | and add it to our argparser:
51 | 
52 | ```python
53 | parser.add_argument("--tp", default=8, type=int)
54 | ```
55 | 
56 | ## Performance with different configurations
57 | 
58 | Here are some training results for 4 different setups of the TP size:
59 | - 1x8 is 8 way TP, and no data parallelism. `--batch-size 18 --tp 8`
60 | - 2x4 is 4 way TP, with 2 groups of FSDP. `--batch-size 14 --tp 4`
61 | - 4x2 is 2 way TP, with 4 groups of FSDP. `--batch-size 10 --tp 2`
62 | - 8x1 is FSDP. `--batch-size 7 --tp 1`
63 | 
64 | Note that all of these runs have the same `--lr` while having different batch sizes, which is why the loss curves are slightly different.
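If you wanted the runs above to be more directly comparable, one common heuristic (covered in the related-topics chapter on effective batch size and learning rate) is to scale the learning rate with the global batch size. A hypothetical sketch, not what was used for these runs:

```python
# Hypothetical linear LR scaling; base_lr and reference_global_batch are
# assumptions for illustration, not values taken from this repo.
base_lr = 3e-5               # LR tuned at some reference global batch size
reference_global_batch = 8   # the global batch size base_lr was tuned at
global_batch = mesh["dp"].size() * args.batch_size
lr = base_lr * global_batch / reference_global_batch
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
```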
65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /07-2d-parallel/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | from itertools import chain 4 | import json 5 | import multiprocessing 6 | import os 7 | import time 8 | from pathlib import Path 9 | import logging 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.distributed import DistributedSampler 14 | from torch import distributed as dist 15 | import torch.distributed.tensor.parallel as tp 16 | from torch.distributed._tensor import Shard, Replicate 17 | from torch.distributed.elastic.multiprocessing.errors import record 18 | import torch.distributed.checkpoint as DCP 19 | from torch.distributed._composable.fsdp import fully_shard 20 | 21 | import wandb 22 | import tqdm 23 | import datasets 24 | from transformers import ( 25 | AutoConfig, 26 | AutoModelForCausalLM, 27 | AutoTokenizer, 28 | default_data_collator, 29 | ) 30 | 31 | LOGGER = logging.getLogger(__name__) 32 | 33 | 34 | @record 35 | def main(): 36 | parser = _get_parser() 37 | args = parser.parse_args() 38 | 39 | dist.init_process_group() 40 | 41 | rank = dist.get_rank() 42 | local_rank = rank % torch.cuda.device_count() 43 | world_size = dist.get_world_size() 44 | 45 | assert args.tp > 1 46 | assert world_size % args.tp == 0 47 | 48 | mesh = dist.device_mesh.init_device_mesh( 49 | "cuda", 50 | (world_size // args.tp, args.tp), 51 | mesh_dim_names=("dp", "tp"), 52 | ) 53 | 54 | logging.basicConfig( 55 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 56 | level=logging.INFO, 57 | ) 58 | 59 | LOGGER.info(os.environ) 60 | LOGGER.info(args) 61 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 62 | LOGGER.info(f"dp_size={mesh['dp'].size()} tp_size={mesh['tp'].size()}") 63 | 64 | device = torch.device(f"cuda:{local_rank}") 65 | dtype = torch.bfloat16 66 | torch.cuda.set_device(device) 67 | 68 | torch.manual_seed(args.seed) 69 | 70 | LOGGER.info(f"Loading model from HF_HOME={os.environ['HF_HOME']}") 71 | 72 | with rank_ordered(should_go_first=local_rank == 0): 73 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 74 | with device: 75 | model = AutoModelForCausalLM.from_config( 76 | config, torch_dtype=dtype, attn_implementation="flash_attention_2" 77 | ) 78 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 79 | 80 | tp.parallelize_module( 81 | model, 82 | mesh["tp"], 83 | {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))}, 84 | ) 85 | for layer in model.model.layers: 86 | tp.parallelize_module( 87 | layer, 88 | mesh["tp"], 89 | { 90 | # SequenceParallel will apply sharding to sequence dimension. 91 | "input_layernorm": tp.SequenceParallel(), 92 | # The input to self_attn (which is the output from the SequenceParallel input_layer_norm) will be sharded on dimension 1, but we wanted it to be the whole tensor. 93 | "self_attn": tp.PrepareModuleInput( 94 | input_kwarg_layouts={"hidden_states": Shard(dim=1)}, 95 | desired_input_kwarg_layouts={"hidden_states": Replicate()}, 96 | ), 97 | "self_attn.q_proj": tp.ColwiseParallel(), 98 | "self_attn.k_proj": tp.ColwiseParallel(), 99 | "self_attn.v_proj": tp.ColwiseParallel(), 100 | "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)), 101 | # Another sharding along sequence dimension. 
102 | "post_attention_layernorm": tp.SequenceParallel(), 103 | "mlp": tp.PrepareModuleInput( 104 | input_layouts=Shard(dim=1), 105 | desired_input_layouts=Replicate(), 106 | ), 107 | "mlp.gate_proj": tp.ColwiseParallel(), 108 | "mlp.up_proj": tp.ColwiseParallel(), 109 | "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)), 110 | }, 111 | ) 112 | 113 | tp.parallelize_module( 114 | model, 115 | mesh["tp"], 116 | { 117 | "model.norm": tp.SequenceParallel(), 118 | "lm_head": tp.ColwiseParallel( 119 | input_layouts=Shard(1), 120 | output_layouts=Shard(-1), # for tp.loss_parallel 121 | use_local_output=False, # for tp.loss_parallel 122 | ), 123 | }, 124 | ) 125 | 126 | for layer in model.model.layers: 127 | fully_shard(layer, mesh=mesh["dp"]) 128 | fully_shard(model, mesh=mesh["dp"]) 129 | 130 | LOGGER.info(f"Final Architecture: {model}") 131 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 132 | 133 | model = model.to_empty(device=device) 134 | model.init_weights() 135 | model.train() 136 | 137 | LOGGER.info(f"{get_mem_stats(device)}") 138 | 139 | # NOTE: since this can download data, make sure to do the main process first on each node 140 | # since we manually specified HF_HOME to be a node local drive. 141 | with rank_ordered(should_go_first=local_rank == 0): 142 | train_data = _load_and_preprocess_data(args, config) 143 | LOGGER.info(f"{len(train_data)} training samples") 144 | 145 | dataloader = DataLoader( 146 | train_data, 147 | batch_size=args.batch_size, 148 | collate_fn=default_data_collator, 149 | num_workers=1, 150 | prefetch_factor=2, 151 | # NOTE: this sampler will split dataset evenly across workers 152 | sampler=DistributedSampler( 153 | train_data, 154 | shuffle=True, 155 | drop_last=True, 156 | num_replicas=mesh["dp"].size(), # equivalent to `num_nodes` 157 | rank=mesh["dp"].get_local_rank(), # equivalent to `rank // num_nodes` 158 | ), 159 | ) 160 | LOGGER.info(f"{len(dataloader)} batches per epoch") 161 | 162 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, fused=True) 163 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 164 | optimizer, T_max=1000, eta_min=args.lr * 1e-2 165 | ) 166 | 167 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 168 | 169 | # attempt resume 170 | state = { 171 | "epoch": 0, 172 | "global_step": 0, 173 | "epoch_step": 0, 174 | "running_loss": 0, 175 | } 176 | resumed = False 177 | if (exp_dir / "state.json").exists(): 178 | DCP.load( 179 | dict(model=model, optimizer=optimizer), 180 | checkpoint_id=exp_dir / "checkpoint", 181 | ) 182 | lr_scheduler.load_state_dict( 183 | torch.load( 184 | exp_dir / "lr_scheduler.pt", map_location=device, weights_only=True 185 | ) 186 | ) 187 | with open(exp_dir / "state.json") as fp: 188 | state = json.load(fp) 189 | resumed = True 190 | LOGGER.info(f"Resumed={resumed} | {state}") 191 | dist.barrier() 192 | 193 | if (exp_dir.is_mount() and rank == 0) or ( 194 | not exp_dir.is_mount() and local_rank == 0 195 | ): 196 | LOGGER.info(f"Creating experiment root directory") 197 | exp_dir.mkdir(parents=True, exist_ok=True) 198 | dist.barrier() 199 | 200 | if rank == 0: 201 | wandb.init( 202 | project="distributed-training-guide", 203 | dir=exp_dir, 204 | name=args.experiment_name, 205 | id=args.experiment_name, 206 | resume="must" if resumed else None, 207 | save_code=True, 208 | config={ 209 | "args": vars(args), 210 | "training_data_size": len(train_data), 211 | "num_batches": len(dataloader), 212 | "world_size": world_size, 213 | }, 214 | ) 215 
| 216 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 217 | 218 | for state["epoch"] in range(state["epoch"], args.num_epochs): 219 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 220 | 221 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=True) 222 | if state["epoch_step"] > 0: 223 | progress_bar.update(state["epoch_step"]) 224 | 225 | batches = iter(dataloader) 226 | 227 | for i_step in range(len(dataloader)): 228 | with timers["data"], torch.no_grad(): 229 | batch = next(batches) 230 | batch = {k: v.to(device=device) for k, v in batch.items()} 231 | batch["position_ids"] = torch.arange( 232 | 0, args.seq_length, device=device, dtype=torch.long 233 | ).unsqueeze(0) 234 | 235 | if i_step < state["epoch_step"]: 236 | # NOTE: for resuming 237 | continue 238 | 239 | with tp.loss_parallel(), timers["forward"]: 240 | outputs = model(**batch) 241 | 242 | with tp.loss_parallel(), timers["backward"]: 243 | outputs.loss.backward() 244 | 245 | with timers["update"]: 246 | optimizer.step() 247 | lr_scheduler.step() 248 | optimizer.zero_grad(set_to_none=True) 249 | 250 | state["global_step"] += 1 251 | state["epoch_step"] += 1 252 | state["running_loss"] += outputs.loss.item() 253 | progress_bar.update(1) 254 | 255 | if state["global_step"] % args.log_freq == 0: 256 | tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length 257 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 258 | info = { 259 | "global_step": state["global_step"], 260 | "lr": lr_scheduler.get_last_lr()[0], 261 | "running_loss": state["running_loss"] / args.log_freq, 262 | "epoch": state["epoch"], 263 | "epoch_progress": state["epoch_step"] / len(dataloader), 264 | "num_batches_remaining": len(dataloader) - i_step, 265 | "tok/s": 1000 * tok_per_step / ms_per_step, 266 | **get_mem_stats(device), 267 | "time/total": sum(t.avg_elapsed_ms() for t in timers.values()), 268 | **{ 269 | f"time/{k}": timer.avg_elapsed_ms() 270 | for k, timer in timers.items() 271 | }, 272 | } 273 | 274 | LOGGER.info(info) 275 | if rank == 0: 276 | wandb.log(info, step=state["global_step"]) 277 | 278 | torch.cuda.reset_peak_memory_stats(device) 279 | state["running_loss"] = 0 280 | for t in timers.values(): 281 | t.reset() 282 | 283 | if state["global_step"] % args.ckpt_freq == 0: 284 | LOGGER.info("Saving checkpoint.") 285 | dist.barrier() 286 | # NOTE: we have to call this on ALL ranks 287 | DCP.save( 288 | dict(model=model, optimizer=optimizer), 289 | checkpoint_id=exp_dir / "checkpoint", 290 | ) 291 | if rank == 0: 292 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 293 | with open(exp_dir / "state.json", "w") as fp: 294 | json.dump(state, fp) 295 | dist.barrier() 296 | 297 | state["epoch_step"] = 0 298 | 299 | 300 | def get_mem_stats(device=None): 301 | mem = torch.cuda.memory_stats(device) 302 | props = torch.cuda.get_device_properties(device) 303 | return { 304 | "total_gb": 1e-9 * props.total_memory, 305 | "curr_alloc_gb": 1e-9 * mem["allocated_bytes.all.current"], 306 | "peak_alloc_gb": 1e-9 * mem["allocated_bytes.all.peak"], 307 | "curr_resv_gb": 1e-9 * mem["reserved_bytes.all.current"], 308 | "peak_resv_gb": 1e-9 * mem["reserved_bytes.all.peak"], 309 | } 310 | 311 | 312 | def _load_and_preprocess_data(args, config): 313 | """ 314 | Function created using code found in 315 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 316 | """ 317 | tokenizer = 
AutoTokenizer.from_pretrained(args.model_name) 318 | 319 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 320 | 321 | column_names = data["train"].column_names 322 | text_column_name = "text" if "text" in column_names else column_names[0] 323 | 324 | def tokenize_function(examples): 325 | return tokenizer(examples[text_column_name]) 326 | 327 | tokenized_datasets = data.map( 328 | tokenize_function, 329 | batched=True, 330 | remove_columns=column_names, 331 | num_proc=multiprocessing.cpu_count(), 332 | load_from_cache_file=True, 333 | desc="Running tokenizer on dataset", 334 | ) 335 | 336 | seq_length = args.seq_length or tokenizer.model_max_length 337 | if seq_length > config.max_position_embeddings: 338 | seq_length = min(1024, config.max_position_embeddings) 339 | 340 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 341 | def group_texts(examples): 342 | # Concatenate all texts. 343 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 344 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 345 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 346 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 347 | if total_length > seq_length: 348 | total_length = (total_length // seq_length) * seq_length 349 | # Split by chunks of max_len. 350 | result = { 351 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 352 | for k, t in concatenated_examples.items() 353 | } 354 | result["labels"] = result["input_ids"].copy() 355 | return result 356 | 357 | lm_datasets = tokenized_datasets.map( 358 | group_texts, 359 | batched=True, 360 | num_proc=multiprocessing.cpu_count(), 361 | load_from_cache_file=True, 362 | desc=f"Grouping texts in chunks of {seq_length}", 363 | ) 364 | 365 | return lm_datasets["train"] 366 | 367 | 368 | @contextmanager 369 | def rank_ordered(*, should_go_first: bool): 370 | if should_go_first: 371 | yield 372 | dist.barrier() 373 | if not should_go_first: 374 | yield 375 | dist.barrier() 376 | 377 | 378 | class LocalTimer: 379 | def __init__(self, device: torch.device): 380 | if device.type == "cpu": 381 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 382 | elif device.type == "cuda": 383 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 384 | self.measurements = [] 385 | self.start_time = None 386 | 387 | def __enter__(self): 388 | self.synchronize() 389 | self.start_time = time.time() 390 | return self 391 | 392 | def __exit__(self, type, value, traceback): 393 | if traceback is None: 394 | self.synchronize() 395 | end_time = time.time() 396 | self.measurements.append(end_time - self.start_time) 397 | self.start_time = None 398 | 399 | def avg_elapsed_ms(self): 400 | return 1000 * (sum(self.measurements) / len(self.measurements)) 401 | 402 | def reset(self): 403 | self.measurements = [] 404 | self.start_time = None 405 | 406 | 407 | def _get_parser() -> argparse.ArgumentParser: 408 | parser = argparse.ArgumentParser() 409 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 410 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 411 | parser.add_argument("-m", "--model-name", default=None, required=True) 412 | parser.add_argument("--save-dir", default="../outputs") 413 | parser.add_argument("--seed", 
default=0, type=int) 414 | parser.add_argument("--num-epochs", default=100, type=int) 415 | parser.add_argument("--lr", default=3e-5, type=float) 416 | parser.add_argument("-b", "--batch-size", default=1, type=int) 417 | parser.add_argument("--log-freq", default=100, type=int) 418 | parser.add_argument("--ckpt-freq", default=500, type=int) 419 | parser.add_argument("-s", "--seq-length", default=None, type=int) 420 | parser.add_argument("--tp", default=8, type=int) 421 | return parser 422 | 423 | 424 | if __name__ == "__main__": 425 | main() 426 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lambda, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed Training Guide 2 | 3 | 4 | 5 | [Neurips 2024 presentation slides here](https://docs.google.com/presentation/d/1ANMmkOGaruYKTvhnsAbZgI9GrdMliNvibWGuNYw6HX8/edit?usp=sharing) 6 | 7 | Ever wondered how to train a large neural network across a giant cluster? Look no further! 8 | 9 | This is a comprehensive guide on best practices for distributed training, diagnosing errors, and fully utilizing all resources available. It is organized into sequential chapters, each with a `README.md` and a `train_llm.py` script in them. The readme will discuss both the high level concepts of distributed training, and the code changes introduced in that chapter. 10 | 11 | The guide is written entirely in very minimal standard pytorch, using `transformers` and `datasets` for models and data, respectively. No other library is used for distributed code - the distributed stuff is entirely in pytorch. 12 | 13 | 1. [Chapter 1](./01-single-gpu/) - A standard Causal LLM training script that runs on a **single GPU**. 14 | 2. [Chapter 2](./02-distributed-data-parallel/) - Upgrades the training script to support **multiple GPUs and to use DDP**. 15 | 3. [Chapter 3](./03-job-launchers/) - Covers how to **launch training jobs** across clusters with multiple nodes. 16 | 4. [Chapter 4](./04-fully-sharded-data-parallel/) - Upgrades the training script to **use FSDP** instead of DDP for more optimal memory usage. 17 | 5. 
[Chapter 5](./05-training-llama-405b/) - Upgrades the training script to **train Llama-405b**.
18 | 6. [Chapter 6](./06-tensor-parallel/) - Upgrades our single GPU training script to support **tensor parallelism**.
19 | 7. [Chapter 7](./07-2d-parallel/) - Upgrades our TP training script to use **2d parallelism (FSDP + TP)**.
20 | 8. [Alternative Frameworks](./alternative-frameworks/) - Covers different frameworks that all work with pytorch under the hood.
21 | 9. [Diagnosing Errors](./diagnosing-errors/) - Best practices and how-tos for **quickly diagnosing errors** in your cluster.
22 | 10. [Related Topics](./related-topics/) - Topics that you should be aware of when doing distributed training.
23 | 
24 | 
25 | Questions this guide answers:
26 | 
27 | - How do I update a single gpu training/fine tuning script to run on multiple GPUs or multiple nodes?
28 | - How do I diagnose hanging/errors that happen during training?
29 | - My model/optimizer is too big for a single gpu - how do I train/fine tune it on my cluster?
30 | - How do I schedule/launch training on a cluster?
31 | - How do I scale my hyperparameters when increasing the number of workers?
32 | 
33 | Best practices for logging stdout/stderr and wandb are also included, as logging is vitally important in diagnosing/debugging training runs on a cluster.
34 | 
35 | Each of the training scripts is aimed at training a causal language model (i.e. gpt/llama).
36 | 
37 | ## Set up
38 | 
39 | ### Clone this repo
40 | 
41 | ```bash
42 | git clone https://github.com/LambdaLabsML/distributed-training-guide.git
43 | ```
44 | 
45 | ### Virtual Environment
46 | 
47 | ```bash
48 | cd distributed-training-guide
49 | python3 -m venv venv
50 | source venv/bin/activate
51 | python -m pip install -U pip
52 | pip install -U setuptools wheel
53 | pip install -r requirements.txt
54 | pip install flash-attn --no-build-isolation
55 | ```
56 | 
57 | ### wandb
58 | 
59 | This tutorial uses `wandb` as an experiment tracker.
60 | 
61 | ```bash
62 | wandb login
63 | ```
64 | 
65 | 

66 | 🦄 Other exciting ML projects at Lambda: ML Times, Text2Video, GPU Benchmark. 67 |

68 | -------------------------------------------------------------------------------- /alternative-frameworks/deepspeed/README.md: -------------------------------------------------------------------------------- 1 | # DeepSpeed ZeRO 2 | 3 | Install deepspeed: `pip install deepspeed` 4 | 5 | image 6 | 7 | This is actually a collection of modes to shard more and more memory: 8 | 9 | > ZeRO Stage 1: The optimizer states (e.g., for Adam optimizer, 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. 10 | 11 | > ZeRO Stage 2: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. 12 | 13 | > ZeRO Stage 3: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes. 14 | 15 | References: 16 | - [deepspeed docs](https://deepspeed.readthedocs.io/en/latest/zero3.html) 17 | - [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) 18 | - [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840) 19 | - [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857) 20 | 21 | ## Integrating DeepSpeed into training code 22 | 23 | ### Argument Parsing 24 | 25 | ```diff 26 | @@ -305,11 +302,10 @@ def _get_parser() -> argparse.ArgumentParser: 27 | parser.add_argument("--log-freq", default=100, type=int) 28 | parser.add_argument("--ckpt-freq", default=500, type=int) 29 | + parser.add_argument("--local_rank", type=int, default=None) 30 | + deepspeed.add_config_arguments(parser) 31 | return parser 32 | ``` 33 | 34 | ### Initialization 35 | 36 | Two main differences here: 37 | 1. We call `deepspeed.init_distributed` instead of using pytorch's `init_process_group` 38 | 2. We call `deepspeed.initialize` after we've constructed the model **instead** of wrapping the model with DDP. 
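Putting those two differences together, a minimal sketch of the initialization flow looks something like this (a simplified illustration; the exact diff against our training script follows below):

```python
# Simplified sketch of deepspeed initialization (not the full training script).
import deepspeed

# 1. Replaces dist.init_process_group()
deepspeed.init_distributed()

# 2. After constructing the model, wrap it with deepspeed instead of DDP.
#    The optimizer and lr scheduler come from the deepspeed config (ds_config.json).
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    args=args,  # parsed CLI args, including --deepspeed_config
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
)
```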
39 | 40 | **NOTE**: `deepspeed.initialize` will construct the optimizer & lr_scheduler based on the config you pass in 41 | 42 | ```diff 43 | @@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel 44 | from torch import distributed as dist 45 | from torch.distributed.elastic.multiprocessing.errors import record 46 | 47 | +import deepspeed 48 | import numpy 49 | import wandb 50 | import tqdm 51 | @@ -42,10 +43,15 @@ def main(): 52 | - dist.init_process_group() 53 | + deepspeed.init_distributed() 54 | 55 | rank = dist.get_rank() 56 | - local_rank = rank % torch.cuda.device_count() 57 | + local_rank = args.local_rank or (rank % torch.cuda.device_count()) 58 | world_size = dist.get_world_size() 59 | 60 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 61 | 62 | @@ -73,10 +73,6 @@ def main(): 63 | if len(tokenizer) > embedding_size: 64 | model.resize_token_embeddings(len(tokenizer)) 65 | 66 | - model = DistributedDataParallel( 67 | - model, device_ids=[local_rank], output_device=local_rank 68 | - ) 69 | - 70 | @@ -89,9 +95,11 @@ def main(): 71 | ) 72 | LOGGER.info(f"{len(dataloader)} batches per epoch") 73 | 74 | - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 75 | - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 76 | - optimizer, T_max=1000, eta_min=args.lr * 1e-2 77 | + model_engine: deepspeed.DeepSpeedEngine 78 | + model_engine, _, _, lr_scheduler = deepspeed.initialize( 79 | + args, 80 | + model=model, 81 | + model_parameters=(p for p in model.parameters() if p.requires_grad), 82 | ) 83 | ``` 84 | 85 | ### Train Loop 86 | 87 | Here we are just going to be replacing our pytorch calls with deepspeed calls. Note that we don't have direct access to optimizer/lr_scheduler anymore since deepspeed handles that. 
88 | 89 | ```diff 90 | with timers["forward"]: 91 | - outputs = model(**batch) 92 | + outputs = model_engine(**batch) 93 | 94 | with timers["backward"]: 95 | - optimizer.zero_grad(set_to_none=True) 96 | - outputs.loss.backward() 97 | + model_engine.backward(outputs.loss) 98 | 99 | with timers["update"]: 100 | - optimizer.step() 101 | - lr_scheduler.step() 102 | + model_engine.step() 103 | 104 | state["global_step"] += 1 105 | state["epoch_step"] += 1 106 | ``` 107 | 108 | ### Checkpoints 109 | 110 | Loading becomes: 111 | 112 | ```diff 113 | resumed = False 114 | - if (exp_dir / "state.json").exists(): 115 | - model.load_state_dict(_load_to_device(exp_dir / "model.pt")) 116 | - optimizer.load_state_dict(_load_to_device(exp_dir / "optimizer.pt")) 117 | - lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt")) 118 | - with open(exp_dir / "state.json") as fp: 119 | - state = json.load(fp) 120 | - resumed = True 121 | + if (exp_dir / "pytorch_model.bin").exists(): 122 | + load_path, state = model_engine.load_checkpoint(exp_dir) 123 | + resumed = load_path is not None 124 | ``` 125 | 126 | Saving becomes: (**NOTE**: saving must be done on ALL ranks instead of just rank 0 - because of sharding) 127 | 128 | ```diff 129 | if state["global_step"] % args.ckpt_freq == 0: 130 | - if rank == 0: 131 | - torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt") 132 | - torch.save(model.state_dict(), exp_dir / "model.pt") 133 | - torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 134 | - with open(exp_dir / "state.json", "w") as fp: 135 | - json.dump(state, fp) 136 | + model_engine.save_checkpoint(exp_dir, client_state=state) 137 | dist.barrier() 138 | ``` 139 | 140 | ## Configuration 141 | 142 | ```json 143 | { 144 | "train_micro_batch_size_per_gpu": 64, 145 | "optimizer": { 146 | "type": "Adam", 147 | "params": { 148 | "lr": 3e-5 149 | } 150 | }, 151 | "scheduler": { 152 | "type": "WarmupCosineLR", 153 | "params": { 154 | "total_num_steps": 1000, 155 | "warmup_num_steps": 0, 156 | "cos_min_ratio": 1e-2 157 | } 158 | }, 159 | "bf16": { 160 | "enabled": true 161 | }, 162 | "zero_optimization": { 163 | "stage": 3, 164 | "offload_param": false, 165 | "offload_optimizer": false 166 | } 167 | } 168 | ``` 169 | 170 | ## Command 171 | 172 | ```bash 173 | cd distributed-training-guide/05-sharding-deepspeed 174 | export TORCHELASTIC_ERROR_FILE=../error.json 175 | export OMP_NUM_THREADS=1 176 | export HF_HOME=../.cache 177 | deepspeed \ 178 | --enable_each_rank_log ../logs \ 179 | train_llm.py \ 180 | --experiment-name deepspeed-multi-node-$(date +%Y-%m-%dT%H-%M-%S) \ 181 | --dataset-name tatsu-lab/alpaca \ 182 | --model-name openai-community/gpt2 \ 183 | --deepspeed_config ds_config.json 184 | ``` 185 | -------------------------------------------------------------------------------- /alternative-frameworks/deepspeed/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "optimizer": { 4 | "type": "AdamW", 5 | "params": { 6 | "lr": 3e-5 7 | } 8 | }, 9 | "scheduler": { 10 | "type": "WarmupCosineLR", 11 | "params": { 12 | "total_num_steps": 1000, 13 | "warmup_num_steps": 0, 14 | "cos_min_ratio": 1e-2 15 | } 16 | }, 17 | "bf16": { 18 | "enabled": true 19 | }, 20 | "zero_optimization": { 21 | "stage": 3, 22 | "offload_param": false, 23 | "offload_optimizer": false 24 | } 25 | } -------------------------------------------------------------------------------- 
/alternative-frameworks/deepspeed/train_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from contextlib import contextmanager 3 | from itertools import chain 4 | import multiprocessing 5 | import os 6 | import time 7 | from pathlib import Path 8 | import logging 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch import distributed as dist 14 | from torch.distributed.elastic.multiprocessing.errors import record 15 | 16 | import deepspeed 17 | import wandb 18 | import tqdm 19 | import datasets 20 | from transformers import ( 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | default_data_collator, 25 | ) 26 | 27 | LOGGER = logging.getLogger(__name__) 28 | 29 | 30 | @record 31 | def main(): 32 | parser = _get_parser() 33 | args = parser.parse_args() 34 | 35 | dist.init_process_group() 36 | 37 | rank = dist.get_rank() 38 | local_rank = args.local_rank or (rank % torch.cuda.device_count()) 39 | world_size = dist.get_world_size() 40 | 41 | logging.basicConfig( 42 | format=f"[rank={rank}] [%(asctime)s] %(levelname)s:%(message)s", 43 | level=logging.INFO, 44 | ) 45 | 46 | LOGGER.info(os.environ) 47 | LOGGER.info(args) 48 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 49 | 50 | device = torch.device(f"cuda:{local_rank}") 51 | dtype = torch.bfloat16 52 | torch.cuda.set_device(device) 53 | 54 | torch.manual_seed(args.seed) 55 | 56 | with rank0_first(): 57 | config = AutoConfig.from_pretrained(args.model_name, use_cache=False) 58 | with deepspeed.zero.Init(remote_device="cpu", pin_memory=True): 59 | model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) 60 | LOGGER.info(f"{sum(p.numel() for p in model.parameters())} model parameters") 61 | 62 | # NOTE: since this can download data, make sure to do the main process first 63 | # NOTE: This assumes that the data is on a **shared** network drive, accessible to all processes 64 | with rank0_first(): 65 | train_data = _load_and_preprocess_data(args, config) 66 | LOGGER.info(f"{len(train_data)} training samples") 67 | 68 | model_engine: deepspeed.DeepSpeedEngine 69 | model_engine, _, _, lr_scheduler = deepspeed.initialize( 70 | args, 71 | model=model, 72 | model_parameters=(p for p in model.parameters() if p.requires_grad), 73 | ) 74 | 75 | dataloader = DataLoader( 76 | train_data, 77 | batch_size=model_engine.train_micro_batch_size_per_gpu(), 78 | collate_fn=default_data_collator, 79 | # NOTE: this sampler will split dataset evenly across workers 80 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 81 | ) 82 | LOGGER.info(f"{len(dataloader)} batches per epoch") 83 | 84 | exp_dir: Path = Path(args.save_dir) / args.experiment_name 85 | 86 | # attempt resume 87 | state = { 88 | "epoch": 0, 89 | "global_step": 0, 90 | "epoch_step": 0, 91 | "running_loss": 0, 92 | } 93 | resumed = False 94 | if (exp_dir / "pytorch_model.bin").exists(): 95 | load_path, state = model_engine.load_checkpoint(exp_dir) 96 | resumed = load_path is not None 97 | LOGGER.info(f"Resumed={resumed} | {state}") 98 | dist.barrier() 99 | 100 | if (exp_dir.is_mount() and rank == 0) or ( 101 | not exp_dir.is_mount() and local_rank == 0 102 | ): 103 | LOGGER.info(f"Creating experiment root directory") 104 | exp_dir.mkdir(parents=True, exist_ok=True) 105 | dist.barrier() 106 | 107 | (exp_dir / f"rank-{rank}").mkdir(parents=True, exist_ok=True) 108 | 
LOGGER.info(f"Worker saving to {exp_dir / f'rank-{rank}'}") 109 | 110 | if rank == 0: 111 | wandb.init( 112 | project="distributed-training-guide", 113 | dir=exp_dir, 114 | name=args.experiment_name, 115 | id=args.experiment_name, 116 | resume="must" if resumed else None, 117 | save_code=True, 118 | config={ 119 | "args": vars(args), 120 | "training_data_size": len(train_data), 121 | "num_batches": len(dataloader), 122 | "world_size": world_size, 123 | }, 124 | ) 125 | 126 | timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 127 | 128 | for state["epoch"] in range(state["epoch"], args.num_epochs): 129 | LOGGER.info(f"Begin epoch {state['epoch']} at step {state['epoch_step']}") 130 | 131 | progress_bar = tqdm.tqdm(range(len(dataloader)), disable=rank > 0) 132 | if state["epoch_step"] > 0: 133 | progress_bar.update(state["epoch_step"]) 134 | 135 | dataloader.sampler.set_epoch(state["epoch"]) 136 | batches = iter(dataloader) 137 | 138 | for i_step in range(len(dataloader)): 139 | with timers["data"], torch.no_grad(): 140 | batch = next(batches) 141 | batch = {k: v.to(device=device) for k, v in batch.items()} 142 | 143 | if i_step < state["epoch_step"]: 144 | # NOTE: for resuming 145 | continue 146 | 147 | with timers["forward"]: 148 | outputs = model_engine(**batch) 149 | 150 | with timers["backward"]: 151 | model_engine.backward(outputs.loss) 152 | 153 | with timers["update"]: 154 | model_engine.step() 155 | 156 | state["global_step"] += 1 157 | state["epoch_step"] += 1 158 | state["running_loss"] += outputs.loss.item() 159 | progress_bar.update(1) 160 | 161 | if state["global_step"] % args.log_freq == 0: 162 | tok_per_step = ( 163 | world_size 164 | * model_engine.train_micro_batch_size_per_gpu() 165 | * args.seq_length 166 | ) 167 | ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values()) 168 | info = { 169 | "global_step": state["global_step"], 170 | "lr": lr_scheduler.get_last_lr()[0], 171 | "running_loss": state["running_loss"] / args.log_freq, 172 | "epoch": state["epoch"], 173 | "epoch_progress": state["epoch_step"] / len(dataloader), 174 | "num_batches_remaining": len(dataloader) - i_step, 175 | **get_mem_stats(device), 176 | "tok/s": 1000 * tok_per_step / ms_per_step, 177 | "time/total": ms_per_step, 178 | **{ 179 | f"time/{k}": timer.avg_elapsed_ms() 180 | for k, timer in timers.items() 181 | }, 182 | } 183 | 184 | LOGGER.info(info) 185 | if rank == 0: 186 | wandb.log(info, step=state["global_step"]) 187 | 188 | torch.cuda.reset_peak_memory_stats(device) 189 | state["running_loss"] = 0 190 | for t in timers.values(): 191 | t.reset() 192 | 193 | if state["global_step"] % args.ckpt_freq == 0: 194 | LOGGER.info("Saving checkpoint.") 195 | model_engine.save_checkpoint(exp_dir, client_state=state) 196 | dist.barrier() 197 | 198 | state["epoch_step"] = 0 199 | 200 | 201 | def _load_and_preprocess_data(args, config): 202 | """ 203 | Function created using code found in 204 | https://github.com/huggingface/transformers/blob/v4.45.1/examples/pytorch/language-modeling/run_clm_no_trainer.py 205 | """ 206 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 207 | 208 | data = datasets.load_dataset(args.dataset_name, trust_remote_code=True) 209 | 210 | column_names = data["train"].column_names 211 | text_column_name = "text" if "text" in column_names else column_names[0] 212 | 213 | def tokenize_function(examples): 214 | return tokenizer(examples[text_column_name]) 215 | 216 | tokenized_datasets = data.map( 217 | tokenize_function, 218 | 
batched=True, 219 | remove_columns=column_names, 220 | num_proc=multiprocessing.cpu_count(), 221 | load_from_cache_file=True, 222 | desc="Running tokenizer on dataset", 223 | ) 224 | 225 | seq_length = args.seq_length or tokenizer.model_max_length 226 | if seq_length > config.max_position_embeddings: 227 | seq_length = min(1024, config.max_position_embeddings) 228 | 229 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 230 | def group_texts(examples): 231 | # Concatenate all texts. 232 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 233 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 234 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 235 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 236 | if total_length > seq_length: 237 | total_length = (total_length // seq_length) * seq_length 238 | # Split by chunks of max_len. 239 | result = { 240 | k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)] 241 | for k, t in concatenated_examples.items() 242 | } 243 | result["labels"] = result["input_ids"].copy() 244 | return result 245 | 246 | lm_datasets = tokenized_datasets.map( 247 | group_texts, 248 | batched=True, 249 | num_proc=multiprocessing.cpu_count(), 250 | load_from_cache_file=True, 251 | desc=f"Grouping texts in chunks of {seq_length}", 252 | ) 253 | 254 | return lm_datasets["train"] 255 | 256 | 257 | def get_mem_stats(device=None): 258 | mem = torch.cuda.memory_stats(device) 259 | props = torch.cuda.get_device_properties(device) 260 | return { 261 | "total_mem_in_gb": 1e-9 * props.total_memory, 262 | "curr_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.current"], 263 | "peak_alloc_in_gb": 1e-9 * mem["allocated_bytes.all.peak"], 264 | "curr_resv_in_gb": 1e-9 * mem["reserved_bytes.all.current"], 265 | "peak_resv_in_gb": 1e-9 * mem["reserved_bytes.all.peak"], 266 | } 267 | 268 | 269 | @contextmanager 270 | def rank0_first(): 271 | rank = dist.get_rank() 272 | if rank == 0: 273 | yield 274 | dist.barrier() 275 | if rank > 0: 276 | yield 277 | dist.barrier() 278 | 279 | 280 | class LocalTimer: 281 | def __init__(self, device: torch.device): 282 | if device.type == "cpu": 283 | self.synchronize = lambda: torch.cpu.synchronize(device=device) 284 | elif device.type == "cuda": 285 | self.synchronize = lambda: torch.cuda.synchronize(device=device) 286 | self.measurements = [] 287 | self.start_time = None 288 | 289 | def __enter__(self): 290 | self.synchronize() 291 | self.start_time = time.time() 292 | return self 293 | 294 | def __exit__(self, type, value, traceback): 295 | if traceback is None: 296 | self.synchronize() 297 | end_time = time.time() 298 | self.measurements.append(end_time - self.start_time) 299 | self.start_time = None 300 | 301 | def avg_elapsed_ms(self): 302 | return 1000 * (sum(self.measurements) / len(self.measurements)) 303 | 304 | def reset(self): 305 | self.measurements = [] 306 | self.start_time = None 307 | 308 | 309 | def _get_parser() -> argparse.ArgumentParser: 310 | parser = argparse.ArgumentParser() 311 | parser.add_argument("-e", "--experiment-name", default=None, required=True) 312 | parser.add_argument("-d", "--dataset-name", default=None, required=True) 313 | parser.add_argument("-m", "--model-name", default=None, required=True) 314 | parser.add_argument("--save-dir", default="../outputs") 
315 | parser.add_argument("--seed", default=0, type=int) 316 | parser.add_argument("--num-epochs", default=100, type=int) 317 | parser.add_argument("--log-freq", default=100, type=int) 318 | parser.add_argument("--ckpt-freq", default=500, type=int) 319 | parser.add_argument("-s", "--seq-length", default=1024, type=int) 320 | parser.add_argument("--local_rank", type=int, default=None) 321 | deepspeed.add_config_arguments(parser) 322 | return parser 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /diagnosing-errors/README.md: -------------------------------------------------------------------------------- 1 | # Diagnosing Errors 2 | 3 | Hanging and deadlocks can be caused by so many things, even your own code! Here's some diagnostic tools that will help you figure out what is going on. 4 | 5 | ## System metrics to watch for to diagnose hanging 6 | 7 | `GPU Power Usage` will be the main one - if the training process is hanging, then the power usage will drop to around ~10% for all workers: 8 | 9 | ```bash 10 | > nvidia-smi --query-gpu=power.draw,power.limit --format=csv,noheader 11 | 69.75 W, 700.00 W 12 | 75.10 W, 700.00 W 13 | 70.82 W, 700.00 W 14 | 69.29 W, 700.00 W 15 | 69.19 W, 700.00 W 16 | 68.72 W, 700.00 W 17 | 70.80 W, 700.00 W 18 | 70.87 W, 700.00 W 19 | ``` 20 | 21 | Using our provided [top-cluster.py](../top-cluster.py) script will output something like this: 22 | 23 | ```bash 24 | > python top-cluster.py 25 | ===2024-10-02 19:55:02.553039 26 | name util power memory nprocs 27 | cluster 100.0% 99.1% 96.9% 64 28 | node-001 100.0% 99.7% 96.1% 8 29 | node-002 100.0% 97.8% 96.9% 8 30 | node-003 100.0% 99.2% 97.2% 8 31 | node-004 100.0% 99.1% 97.4% 8 32 | node-005 100.0% 98.1% 97.1% 8 33 | node-006 100.0% 99.0% 97.7% 8 34 | node-007 100.0% 99.8% 96.9% 8 35 | node-008 100.0% 100.0% 96.2% 8 36 | === 37 | ``` 38 | 39 | ## Getting a dump of stack traces 40 | 41 | Use [py-spy](https://github.com/benfred/py-spy) to get a dump of stacktraces from all python threads in a running python program. Here's how you get a dump from each worker: 42 | 43 | ``` 44 | sudo env "PATH=$PATH" py-spy dump --locals --pid 45 | ``` 46 | 47 | ## Benchmarking/profiling 48 | 49 | You can use `py-spy top --pid <>`, to get a `top`/`htop` like view of the functions that are being called in your python process. 50 | 51 | ## Recording errors 52 | 53 | Python has a great built in library for getting errors that occur in any thread of a python program called [faulthandler](https://docs.python.org/3/library/faulthandler.html). This is especially useful when you're using a DataLoader with num_workers > 0. 54 | 55 | Turns out, pytorch already has a built in way to use it! You just have to set `TORCHELASTIC_ERROR_FILE=../error.json` environment variable and add a `@record` annotation to your main function. 56 | 57 | ```python 58 | from torch.distributed.elastic.multiprocessing.errors import record 59 | 60 | # NOTE: records errors to $TORCHELASTIC_ERROR_FILE 61 | @record 62 | def main(): 63 | ... 64 | ``` 65 | 66 | Luckily all the code in this guide has been doing this, and so should you! **Make sure to set $TORCHELASTIC_ERROR_FILE**!. 67 | 68 | ## Checklist for system problems 69 | 70 | 1. System date time on each system is the same (can cause NCCL timeouts) 71 | 2. NVLink valid topology `nvidia-smi topo -m` 72 | 3. NVLink status `nvidia-smi topo -p2p n` (additionally `w`/`r` in place of `n`) 73 | 4. 
Open file descriptor limit `ulimit -aH` (and then look for line containing `open files`). 74 | 5. `timeout` in `dist.init_process_group(timeout=...)` is sufficiently large. 75 | -------------------------------------------------------------------------------- /related-topics/README.md: -------------------------------------------------------------------------------- 1 | # Related topics 2 | 3 | This directory contains a list of additional topics that are adjacent to everything discussed in prior chapters. 4 | 5 | These chapters don't contain a training script individually, but the changes discussed in each are relatively small, and code snippets are provided to make it easy to add the features into your code. 6 | -------------------------------------------------------------------------------- /related-topics/determinism/README.md: -------------------------------------------------------------------------------- 1 | # Determinism across resumes 2 | 3 | **NOTE: This chapter's code builds off of [chapter 3](../../03-multi-node/)'s code.** 4 | 5 | See pytorch's documnetation on reproducibility: https://pytorch.org/docs/stable/notes/randomness.html#reproducibility 6 | 7 | Notably we are also saving & restoring the rng states from various libraries, and explicitly seeding the workers for data loading. 8 | 9 | ## Code Changes 10 | 11 | ```diff 12 | diff --git a/03-multi-node/train_llm.py b/10-determinism/train_llm.py 13 | index 24eacbd..0a3a029 100644 14 | --- a/03-multi-node/train_llm.py 15 | +++ b/10-determinism/train_llm.py 16 | @@ -40,6 +40,7 @@ def main(): 17 | 18 | torch.set_num_threads(1) 19 | torch.set_num_interop_threads(1) 20 | + torch.use_deterministic_algorithms(True) 21 | 22 | torch.manual_seed(args.seed) 23 | torch.cuda.manual_seed_all(args.seed) 24 | @@ -84,6 +85,8 @@ def main(): 25 | train_data = _load_and_preprocess_data(args, tokenizer, config) 26 | LOGGER.info(f"{len(train_data)} training samples") 27 | 28 | + g = torch.Generator() 29 | + g.manual_seed(args.seed) 30 | dataloader = DataLoader( 31 | train_data, 32 | batch_size=args.batch_size, 33 | @@ -91,6 +94,8 @@ def main(): 34 | num_workers=1, 35 | # NOTE: this sampler will split dataset evenly across workers 36 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 37 | + worker_init_fn=_seed_worker, 38 | + generator=g, 39 | ) 40 | LOGGER.info(f"{len(dataloader)} batches per epoch") 41 | 42 | @@ -116,6 +121,13 @@ def main(): 43 | lr_scheduler.load_state_dict(_load_to_device(exp_dir / "lr_scheduler.pt")) 44 | with open(exp_dir / "state.json") as fp: 45 | state = json.load(fp) 46 | + rng_state = torch.load( 47 | + exp_dir / "rng.pt", weights_only=False, map_location="cpu" 48 | + ) 49 | + numpy.random.set_state(rng_state["np"]) 50 | + random.setstate(rng_state["random"]) 51 | + torch.set_rng_state(rng_state["torch"]) 52 | + torch.cuda.set_rng_state(rng_state["cuda"][local_rank], device) 53 | resumed = True 54 | LOGGER.info(f"Resumed={resumed} | {state}") 55 | 56 | @@ -208,11 +220,26 @@ def main(): 57 | torch.save(lr_scheduler.state_dict(), exp_dir / "lr_scheduler.pt") 58 | with open(exp_dir / "state.json", "w") as fp: 59 | json.dump(state, fp) 60 | + torch.save( 61 | + { 62 | + "np": numpy.random.get_state(), 63 | + "random": random.getstate(), 64 | + "torch": torch.get_rng_state(), 65 | + "cuda": torch.cuda.get_rng_state_all(), 66 | + }, 67 | + exp_dir / "rng.pt", 68 | + ) 69 | dist.barrier() 70 | 71 | state["epoch_step"] = 0 72 | 73 | 74 | +def _seed_worker(worker_id): 75 | + worker_seed = torch.initial_seed() % 
2**32 76 | + numpy.random.seed(worker_seed) 77 | + random.seed(worker_seed) 78 | + 79 | + 80 | def _load_and_preprocess_data(args, tokenizer, config): 81 | data = datasets.load_dataset( 82 | args.dataset_name, trust_remote_code=True, 83 | ``` -------------------------------------------------------------------------------- /related-topics/effective-batch-size-and-lr/README.md: -------------------------------------------------------------------------------- 1 | # Effective Batch Size and LR 2 | 3 | As you scale up the number of nodes, the effective batch size (the amount of items used for model updates) increases as well: 4 | 5 | ``` 6 | effective_batch_size = batch_size * world_size 7 | ``` 8 | 9 | As you may know, increasing the batch size means that the variance of the data that your model is training on decreases, meaning your gradients will be much smoother. This directly impacts the dynamics of how your model learns and changes! 10 | 11 | If you want to **exactly match the dynamics of single gpu training** when moving to multi node training, this chapter is aimed at you! 12 | 13 | ## Scaling Rules 14 | 15 | If you want exact training dynamics, you have to also scale the learning rate. However, this depends on what optimizer you are using. The exact rules are not fully understood, and you can look into the following papers for more information: 16 | 17 | - [Exploring Learning Rate Scaling Rules for Distributed ML Training on Transient Resources](https://anakli.inf.ethz.ch/papers/learning_rate_distribml22.pdf) 18 | 19 | As of writing this, the most common rules that people use to scale learning rate are: 20 | 21 | ### Linear scaling rule 22 | 23 | ```python 24 | lr = args.lr * dist.get_world_size() 25 | ``` 26 | 27 | This was first reported in the large minibatch SGD paper above. However this doesn't quite produce exactly the same training dynamics, and the paper actually used a **factor of the world size**. 28 | 29 | NOTE: **Be careful when using this for optimizers other than SGD** 30 | 31 | References: 32 | - [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/pdf/1706.02677) 33 | 34 | ### Square root scaling rule 35 | 36 | ```python 37 | lr = args.lr * numpy.sqrt(dist.get_world_size()) 38 | ``` 39 | 40 | This is proposed for use with the Adam optimizer, and maintains the square root of the variance of the gradient when scaling the number of batches. 41 | 42 | References: 43 | - [One weird trick for parallelizing convolutional neural networks](https://arxiv.org/pdf/1404.5997) 44 | - [Large-Batch Training for LSTM and Beyond](https://arxiv.org/pdf/1901.08256) 45 | -------------------------------------------------------------------------------- /related-topics/elastic-training/README.md: -------------------------------------------------------------------------------- 1 | # Elastic Training 2 | 3 | Elastic training is training where the launcher can restart a subset (or all) of the workers at various points throughout training. 4 | 5 | Contrary to what you might think, usually when 1 worker encounters an error, **ALL workers are restarted** (see https://pytorch.org/docs/stable/elastic/run.html#membership-changes). 6 | 7 | `torchrun` supports this via [elastic launch](https://pytorch.org/docs/stable/elastic/run.html#elastic-min-1-max-4-tolerates-up-to-3-membership-changes-or-failures): 8 | 9 | ```bash 10 | torchrun 11 | --nnodes=1:4 12 | --max-restarts=3 13 | ... 
14 | ``` 15 | 16 | which means that torchrun will restart all the workers up to 3 times (and if some of the nodes go offline, it can use as few as 1). 17 | 18 | Note: 19 | - `rank`, `local_rank`, and `world_size` are all not stable across restarts of a worker. 20 | - Sometimes nodes have issues that can't be fixed just by restarting (like if you have a bug). 21 | 22 | ## Code Changes 23 | 24 | No code changes are needed to do elastic training for our existing code. Instead it is more informative to play with a toy example where workers randomly crash to give you a sense for how it works. 25 | 26 | ```bash 27 | cd distributed-training-guide/96-elastic-training 28 | torchrun \ 29 | --nnodes 1 \ 30 | --nproc_per_node 8 \ 31 | --max-restarts 3 \ 32 | --redirects 3 \ 33 | --log-dir ../logs \ 34 | toy.py 35 | ``` 36 | 37 | This toy script will randomly throw an error from each of the ranks. **No GPU required to try this command!** 38 | 39 | Inspect the log directory after you run this, for each attempt, there will be 1 worker sub directory that has a `error.json` file in it. You can also inspect each worker's stdout/stderr. 40 | -------------------------------------------------------------------------------- /related-topics/elastic-training/toy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import logging 4 | import os 5 | 6 | from torch import distributed as dist 7 | from torch.distributed.elastic.multiprocessing.errors import record 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | _STATE_PATH = "./toy-state.json" 11 | 12 | 13 | @record 14 | def main(): 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | dist.init_process_group() 18 | 19 | rank = dist.get_rank() 20 | local_rank = os.environ["LOCAL_RANK"] 21 | world_size = dist.get_world_size() 22 | 23 | LOGGER.info(f"local_rank={local_rank} rank={rank} world size={world_size}") 24 | 25 | state = {"num_steps": 0} 26 | if os.path.exists(_STATE_PATH): 27 | with open(_STATE_PATH) as fp: 28 | state = json.load(fp) 29 | 30 | random.seed(rank + world_size * state["num_steps"]) 31 | 32 | while True: 33 | value = random.random() 34 | LOGGER.info(f"[{rank=}] step={state['num_steps']} {value=}") 35 | if value < 0.001: 36 | raise ValueError("Encountered fake bad value.") 37 | 38 | state["num_steps"] += 1 39 | 40 | dist.barrier() 41 | if rank == 0: 42 | with open(_STATE_PATH, "w") as fp: 43 | json.dump(state, fp) 44 | dist.barrier() 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /related-topics/gradient-accumulation/README.md: -------------------------------------------------------------------------------- 1 | # Gradient Accumulation 2 | 3 | Gradient accumulation is a way to increase the effective batch sizes of your model updates. 4 | 5 | It is normally applied when your model is so big that you use a lower batch size when running the forward/backward pass. 6 | 7 | If on a single GPU you have a batch size of 4, and a gradient accumulation of 2, then your effective batch size is 8. 8 | 9 | However, applying gradient accumulation in a standard way will cause slowdowns in distributed training setting because of gradient synchronization. 
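Before getting to the implementations, a quick sanity check on the arithmetic above. This tiny sketch (illustrative names only, not taken from the training script) also folds in the data-parallel world size from the effective batch size chapter:

```python
batch_size = 4   # items per forward/backward pass on one GPU
grad_accum = 2   # optimizer steps only happen every `grad_accum` micro-batches
world_size = 1   # e.g. dist.get_world_size() in a distributed run

# 4 * 2 * 1 = 8 items contribute to every model update
effective_batch_size = batch_size * grad_accum * world_size
print(effective_batch_size)  # 8
```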
10 | 11 | ## Standard Implementation 12 | 13 | ```python 14 | outputs = model(**batch) 15 | outputs.loss.backward() 16 | if i_step % grad_accum == 0: 17 | optimizer.step() 18 | lr_scheduler.step() 19 | optimizer.zero_grad(set_to_none=True) 20 | ``` 21 | 22 | ## DistributedDataParallel Implementation 23 | 24 | In a distributed setting, gradients are synchronized at multiple points during the backward pass. It turns out we need to delay this synchronization until the step where we actually update the model! 25 | 26 | We can use [torch.nn.parallel.DistributedDataParallel.no_sync](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync) for this: 27 | 28 | ```python 29 | from contextlib import nullcontext 30 | maybe_sync_grads = model.no_sync if i_step % grad_accum != 0 else nullcontext 31 | with maybe_sync_grads(): 32 | outputs = model(**batch) 33 | outputs.loss.backward() 34 | if i_step % grad_accum == 0: 35 | optimizer.step() 36 | lr_scheduler.step() 37 | optimizer.zero_grad(set_to_none=True) 38 | ``` 39 | -------------------------------------------------------------------------------- /related-topics/optimizing-data-loading/README.md: -------------------------------------------------------------------------------- 1 | # Optimizing Data Loading 2 | 3 | **NOTE: This chapter's code builds off of [chapter 3](../../03-multi-node/)'s code.** 4 | 5 | An important part of achieving high throughput during distributed training is ensuring that all processes are moving at roughly the same speed. If one process is much faster, it will spend a lot of time waiting for the other processes to catch up. Data loading is actually a hugely important part of this. 6 | 7 | ## Motivating Example 8 | 9 | While writing this guide, I noticed a drop in GPU utilization **across all nodes** when moving from single node to multi node. When training on a single node, the GPU power draw was at 80%; when I went to multi node, it dropped to 60% across all nodes. 10 | 11 | It turns out data loading was consistently slower on one node, causing **all nodes** to wait for it. 12 | 13 | In this guide's case, since data loading is relatively fast, simply updating the number of workers and the prefetch factor fixed it. In more complex examples, other optimizations or preprocessing may be needed. 14 | 15 | ## Loading data in parallel 16 | 17 | Most slowdowns in this case come from the data itself: 18 | 19 | 1. If some of the processes read data more slowly, then they will already be behind. This can be due to disk reads being blocked, limits on open file descriptors, etc. 20 | 2. If you have batches of different sizes, then the model forward/backward calls will take different amounts of time. 21 | 22 | Most of these can be handled simply by doing data loading in another process (via the `num_workers` argument): 23 | 24 | ```diff 25 | dataloader = DataLoader( 26 | train_data, 27 | batch_size=args.batch_size, 28 | collate_fn=default_data_collator, 29 | + num_workers=1, 30 | + prefetch_factor=2, 31 | # NOTE: this sampler will split dataset evenly across workers 32 | sampler=DistributedSampler(train_data, shuffle=True, drop_last=True), 33 | ) 34 | ``` 35 | 36 | This will cause the data loading to happen behind the scenes **in parallel to the batch processing**. 37 | 38 | You'll need to tune the `num_workers` and `prefetch_factor` settings based on a number of things: 39 | 1. How big your batch size is 40 | 2.
How long a single row from your dataset takes to load/preprocess 41 | 3. How fast your batches take to process 42 | 43 | If you have `num_workers>0`, then you just want the time to fully load a batch to be less than the time to process the batch. 44 | 45 | ## Measuring wait time 46 | 47 | We can measure this phenomena by adding some explicit `dist.barrier()` calls in our code with our timing wrapped around it: 48 | 49 | ```diff --git a/03-multi-node/train_llm.py b/06-data-loading/train_llm.py 50 | index d5cb05c..26cadb8 100644 51 | --- a/03-multi-node/train_llm.py 52 | +++ b/06-data-loading/train_llm.py 53 | @@ -146,7 +148,10 @@ def main(): 54 | }, 55 | ) 56 | 57 | - timers = {k: LocalTimer(device) for k in ["data", "forward", "backward", "update"]} 58 | + timers = { 59 | + k: LocalTimer(device) 60 | + for k in ["data", "forward", "backward", "update", "waiting"] 61 | + } 62 | 63 | for state["epoch"] in range(state["epoch"], args.num_epochs): 64 | LOGGER.info( 65 | @@ -168,13 +173,22 @@ def main(): 66 | # NOTE: for resuming 67 | continue 68 | 69 | + with timers["waiting"]: 70 | + dist.barrier() 71 | + 72 | with timers["forward"]: 73 | outputs = model(**batch) 74 | 75 | + with timers["waiting"]: 76 | + dist.barrier() 77 | + 78 | with timers["backward"]: 79 | optimizer.zero_grad(set_to_none=True) 80 | outputs.loss.backward() 81 | 82 | + with timers["waiting"]: 83 | + dist.barrier() 84 | + 85 | with timers["update"]: 86 | optimizer.step() 87 | lr_scheduler.step() 88 | ``` 89 | 90 | 91 | ## Faster storage 92 | 93 | A very common setup is to have all of your data on networked data storage. While this is convenient for our code, it is not the most efficient for data reading. 94 | 95 | Similar to how the cache is faster than ram, and ram is faster than disk - local node storage is much faster than networked storage: 96 | 97 | 1. Cache (Fastest) 98 | 2. RAM 99 | 3. Machine local disk 100 | 4. Networked disk (Slowest) 101 | 102 | Simply copying all of your data to each node individual can improve the speed of data loading, at the cost of more storage. 103 | -------------------------------------------------------------------------------- /related-topics/wandb-configurations/README.md: -------------------------------------------------------------------------------- 1 | # wandb configurations 2 | 3 | There are a bunch of ways to configure wandb during your training runs. What will work best for you depends on how big your cluster is and what you want to track. 4 | 5 | ## rank 0 6 | 7 | This is the standard approach. You will only see system information from the node that has the rank 0 process, and only data from rank 0 will be logged. It is minimal information, and you still get to track the experiment progress. 8 | 9 | ```python 10 | if rank == 0: 11 | wandb.init( 12 | project="distributed-training-guide", 13 | dir=exp_dir, 14 | id=args.experiment_name, 15 | name=args.experiment_name, 16 | resume="must" if resumed else None, 17 | save_code=True, 18 | config=..., 19 | ) 20 | ``` 21 | 22 | ## local_rank 0 (every node) 23 | 24 | With this approach you can see system information from all nodes, and it scales linearly with number of nodes. This approach uses [wandb grouped runs](https://docs.wandb.ai/guides/runs/grouping/). 
25 | 26 | ```python 27 | if local_rank == 0: 28 | wandb.init( 29 | project="distributed-training-guide", 30 | dir=exp_dir / f"rank-{rank}", 31 | group=args.experiment_name, 32 | name=f"rank-{rank}", 33 | id=f"{args.experiment_name}-{rank}", 34 | resume="must" if resumed else None, 35 | save_code=True, 36 | config=..., 37 | ) 38 | ``` 39 | 40 | If you want the name to appear as the node id you can set: 41 | 42 | ```python 43 | name=f"node-{rank // world_size}" 44 | ``` 45 | 46 | ## every rank 47 | 48 | [Grouping docs](https://docs.wandb.ai/guides/runs/grouping) 49 | 50 | This configuration is really useful for tracking as much information about your cluster as possible. The downsides are that if you have a very large cluster, you can hit the ratelimit of wandb, and the wandb graphs become unusable. 51 | 52 | ```python 53 | wandb.init( 54 | project="distributed-training-guide", 55 | dir=exp_dir / f"rank-{rank}", 56 | group=args.experiment_name, 57 | name=f"rank-{rank}", 58 | id=f"{args.experiment_name}-{rank}", 59 | resume="must" if resumed else None, 60 | save_code=True, 61 | config=..., 62 | ) 63 | ``` 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb==0.17.5 2 | torch==2.5.1 3 | tqdm 4 | datasets==3.2.0 5 | transformers==4.48.0 6 | -------------------------------------------------------------------------------- /top-cluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import time 4 | import datetime 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--poll-freq", default=1000, type=int, help="Frequency (in ms) to poll clusters" 9 | ) 10 | parser.add_argument("hosts", help="File containing hostnames separated by newlines") 11 | args = parser.parse_args() 12 | 13 | with open(args.hosts) as fp: 14 | hosts = list(filter(None, map(str.strip, fp.readlines()))) 15 | 16 | while True: 17 | procs = [ 18 | subprocess.Popen( 19 | [ 20 | "ssh", 21 | host, 22 | "nvidia-smi", 23 | "--query-gpu=utilization.gpu,power.draw,power.limit,memory.used,memory.total", 24 | "--format=csv,noheader,nounits", 25 | "&&", 26 | "nvidia-smi", 27 | "--query-compute-apps=pid", 28 | "--format=csv,noheader,nounits", 29 | ], 30 | stdout=subprocess.PIPE, 31 | stderr=subprocess.STDOUT, 32 | ) 33 | for host in hosts 34 | ] 35 | for proc in procs: 36 | proc.wait() 37 | 38 | outputs = [proc.stdout.read().decode() for proc in procs] 39 | 40 | gpu_stats = {} 41 | node_stats = { 42 | host: dict(util=0, power_usage=0, memory_usage=0, num_gpus=0, num_procs=0) 43 | for host in hosts 44 | } 45 | cluster_stats = dict(util=0, power_usage=0, memory_usage=0, num_gpus=0, num_procs=0) 46 | for host, output in zip(hosts, outputs): 47 | gpu_stats[host] = {} 48 | for gpu, stats in enumerate(output.splitlines()): 49 | if "," not in stats: 50 | node_stats[host]["num_procs"] += 1 51 | cluster_stats["num_procs"] += 1 52 | continue 53 | 54 | util, power_draw, power_limit, memory_used, memory_total = map( 55 | float, stats.split(", ") 56 | ) 57 | power_usage = 100 * power_draw / power_limit 58 | memory_usage = 100 * memory_used / memory_total 59 | 60 | gpu_stats[host][gpu] = dict( 61 | util=util, power_usage=power_usage, memory_usage=memory_usage 62 | ) 63 | node_stats[host]["util"] += util 64 | node_stats[host]["memory_usage"] += memory_usage 65 | node_stats[host]["power_usage"] += power_usage 66 | 
node_stats[host]["num_gpus"] += 1 67 | cluster_stats["util"] += util 68 | cluster_stats["memory_usage"] += memory_usage 69 | cluster_stats["power_usage"] += power_usage 70 | cluster_stats["num_gpus"] += 1 71 | 72 | if cluster_stats["num_gpus"] > 0: 73 | cluster_stats["util"] /= cluster_stats["num_gpus"] 74 | cluster_stats["memory_usage"] /= cluster_stats["num_gpus"] 75 | cluster_stats["power_usage"] /= cluster_stats["num_gpus"] 76 | for host in hosts: 77 | if node_stats[host]["num_gpus"] == 0: 78 | continue 79 | node_stats[host]["util"] /= node_stats[host]["num_gpus"] 80 | node_stats[host]["memory_usage"] /= node_stats[host]["num_gpus"] 81 | node_stats[host]["power_usage"] /= node_stats[host]["num_gpus"] 82 | 83 | print(f"==={datetime.datetime.now()}") 84 | print(f"{'name':>10}\t{'util':>10}\t{'power':>10}\t{'memory':>10}\t{'nprocs':>10}") 85 | print( 86 | f"{'cluster':>10}\t{cluster_stats['util']:>9.1f}%\t{cluster_stats['power_usage']:>9.1f}%\t{cluster_stats['memory_usage']:>9.1f}%\t{cluster_stats['num_procs']:>10}" 87 | ) 88 | for host, stats in node_stats.items(): 89 | print( 90 | f"{host:>10}\t{stats['util']:>9.1f}%\t{stats['power_usage']:>9.1f}%\t{stats['memory_usage']:>9.1f}%\t{stats['num_procs']:>10}" 91 | ) 92 | print("===") 93 | 94 | time.sleep(args.poll_freq / 1000.0) 95 | --------------------------------------------------------------------------------