├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── RoBERTa ├── README.md ├── fairseq.patch ├── prepare_wikitext.sh └── run.sh ├── SECURITY.md ├── SUPPORT.md ├── Swin-Transformer ├── README.md ├── Swin-Transformer.patch ├── requirements.txt └── run.sh ├── common ├── stat_communication.py └── te_utils.py ├── deit ├── README.md ├── deit.patch ├── requirements.txt └── run.sh └── gpt3 ├── Megatron-DeepSpeed.patch ├── Megatron-LM.patch ├── README.md ├── prepare_wikipedia.sh ├── pretrain_13b_megatron.sh ├── pretrain_13b_megatron_ds.sh ├── pretrain_345m_megatron.sh ├── pretrain_345m_megatron_ds.sh └── pretrain_6b7_megatron.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://github.com/github/gitignore 2 | 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | */output* 143 | *.log 144 | */data/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/Swin-Transformer"] 2 | path = third_party/Swin-Transformer 3 | url = https://github.com/microsoft/Swin-Transformer.git 4 | branch = main 5 | [submodule "third_party/deit"] 6 | path = third_party/deit 7 | url = https://github.com/facebookresearch/deit.git 8 | branch = main 9 | [submodule "third_party/fairseq"] 10 | path = third_party/fairseq 11 | url = https://github.com/facebookresearch/fairseq.git 12 | [submodule "Megatron-DeepSpeed"] 13 | path = third_party/Megatron-DeepSpeed 14 | url = https://github.com/microsoft/Megatron-DeepSpeed.git 15 | [submodule "third_party/Megatron-LM"] 16 | path = third_party/Megatron-LM 17 | url = https://github.com/NVIDIA/Megatron-LM.git 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MS-AMP Examples 2 | This repository contains various training examples including [DeiT](https://github.com/facebookresearch/deit), [Swin-Transformer](https://github.com/microsoft/Swin-Transformer), [RoBERTa](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md) and [GPT-3](https://github.com/microsoft/Megatron-DeepSpeed#gpt-pretraining) that use [MS-AMP](https://github.com/Azure/MS-AMP). 3 | 4 | # Get started 5 | 6 | ## Prerequisites 7 | In order to run examples in this repository, you need to install MS-AMP first, and then clone the repository and submodule with the following command: 8 | ``` 9 | git clone https://github.com/Azure/MS-AMP-Examples.git 10 | cd MS-AMP-Examples 11 | git submodule update --init --recursive 12 | ``` 13 | 14 | ## Swin-Transformer 15 | This folder contains end-to-end training of [Swin-Transformer](https://github.com/microsoft/Swin-Transformer) model using MS-AMP. 16 | 17 | ## DeiT 18 | This folder contains end-to-end training of [DeiT](https://github.com/facebookresearch/deit) model using MS-AMP. 19 | 20 | ## RoBERTa 21 | This folder contains end-to-end training of [RoBERTa](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md) model using MS-AMP. 22 | 23 | ## GPT-3 24 | This folder contains end-to-end training of [GPT-3](https://github.com/NVIDIA/Megatron-LM#gpt-pretraining) model using MS-AMP. 25 | 26 | ## Contributing 27 | 28 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 29 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 30 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 31 | 32 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 33 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 34 | provided by the bot. You will only need to do this once across all repos using our CLA. 35 | 36 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 37 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 38 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 39 | 40 | ## Trademarks 41 | 42 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 43 | trademarks or logos is subject to and must follow 44 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 45 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 46 | Any use of third-party trademarks or logos are subject to those third-party's policies. 
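## MS-AMP in these examples (sketch)

All of the examples follow the same basic pattern: the upstream training code is patched so that, when MS-AMP is requested, the model and optimizer are converted by MS-AMP before training starts (the Swin-Transformer patch, for example, calls `msamp.initialize`, while the fairseq patch replaces the linear layers and the optimizer directly). The snippet below is only a minimal sketch of that pattern; the toy model, loss, and hyperparameters are illustrative and are not taken from any of the examples here.

```python
# Minimal sketch of the MS-AMP usage pattern; all shapes and hyperparameters are placeholders.
import torch
import msamp

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Convert model and optimizer; the launch scripts in this repo mostly use opt level "O2".
model, optimizer = msamp.initialize(model, optimizer, "O2")

for _ in range(10):
    x = torch.randn(8, 768, device="cuda")
    loss = model(x).float().pow(2).mean()   # dummy loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```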
47 | -------------------------------------------------------------------------------- /RoBERTa/README.md: -------------------------------------------------------------------------------- 1 | # This is an example of RoBERTa using MS-AMP 2 | This example demonstrates how to use MS-AMP in [RoBERTa](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md). 3 | 4 | ## Apply patch to fairseq 5 | We made a few changes to the official [fairseq](https://github.com/facebookresearch/fairseq) repository and packaged them into a patch. You need to apply this patch to third_party/fairseq. 6 | ``` 7 | cd ../third_party/fairseq 8 | git apply ../../RoBERTa/fairseq.patch 9 | ``` 10 | 11 | ## Install fairseq 12 | You need to install fairseq before training RoBERTa. It is recommended to use venv for virtual environments, but it is not strictly necessary. 13 | ``` 14 | pip install --no-build-isolation -v -e . 15 | pip install fvcore 16 | cd - 17 | ``` 18 | You can verify that fairseq was installed successfully by executing `python -c "import fairseq; print(fairseq.__version__)"`. 19 | 20 | ## Data preparation 21 | Currently we haven't published the data we use in this example. You can use a public dataset such as the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/) or your own data. Please see the [tutorial for pretraining RoBERTa using your own data](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.pretraining.md). 22 | 23 | Here is an example of preparing the WikiText-103 dataset: 24 | ``` 25 | sh prepare_wikitext.sh 26 | ``` 27 | After running the above command, a folder named data-bin will be generated. The file structure should look like: 28 | ``` 29 | $ tree data-bin/ 30 | data-bin/ 31 | └── wikitext-103 32 | ├── dict.txt 33 | ├── preprocess.log 34 | ├── test.bin 35 | ├── test.idx 36 | ├── train.bin 37 | ├── train.idx 38 | ├── valid.bin 39 | └── valid.idx 40 | ``` 41 | 42 | If you use your own data, don't forget to change the variable DATA_PATH to the data folder in the launch script. 43 | 44 | 45 | ## Train RoBERTa model with AMP 46 | Run the following command to train the RoBERTa base model using AMP: 47 | ``` 48 | sh run.sh amp 49 | ``` 50 | 51 | ## Train RoBERTa model with MS-AMP 52 | Run the following command to train the RoBERTa base model using MS-AMP: 53 | ``` 54 | sh run.sh msamp 55 | ``` 56 |
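## What fairseq.patch changes (overview)

The patch below hooks MS-AMP into fairseq in a few places: it adds the `common.msamp` / `common.msamp_opt_level` config options, converts the model's `nn.Linear` modules to FP8 with `msamp.nn.LinearReplacer.replace` in the trainer, backs the Adam optimizer with MS-AMP's low-bit `LBAdamW` (the patch configures FP32 Adam statistics for opt level `O1` and uint8/FP16 statistics for `O2`), routes gradient clipping through `msamp.clip_grad_norm_`, and adds throughput/TFLOPs logging based on `fvcore`'s FLOP counter. Outside of fairseq, the core flow looks roughly like the sketch below; it is only an illustration, and the toy model, batch, and hyperparameters are made up rather than taken from this example.

```python
# Standalone sketch of what fairseq.patch wires into the fairseq trainer/optimizer.
# The model, batch, and hyperparameters are made-up placeholders.
import torch
import msamp

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).cuda()
# Replace nn.Linear with MS-AMP FP8 linear layers (what Trainer.__init__ does in the patch).
model = msamp.nn.LinearReplacer.replace(model)

# Low-bit AdamW: the patch uses FP32 Adam statistics for O1 and uint8/FP16 statistics for O2.
optimizer = msamp.LBAdamW(
    model.parameters(), lr=5e-4, exp_avg_dtype=torch.uint8, exp_avg_sq_dtype=torch.float16
)

x = torch.randn(2, 1024, device="cuda")
loss = model(x).float().pow(2).mean()   # dummy loss
loss.backward()
msamp.clip_grad_norm_(model.parameters(), 1.0)   # MS-AMP-aware gradient clipping
optimizer.step()
optimizer.zero_grad()
```
-------------------------------------------------------------------------------- /RoBERTa/fairseq.patch: -------------------------------------------------------------------------------- 1 | diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py 2 | index 5fdfab38..e444294c 100644 3 | --- a/fairseq/dataclass/configs.py 4 | +++ b/fairseq/dataclass/configs.py 5 | @@ -250,7 +250,12 @@ class CommonConfig(FairseqDataclass): 6 | "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail."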
7 | }, 8 | ) 9 | - 10 | + # MS-AMP config 11 | + msamp: bool = field(default=False, metadata={"help": "use microsoft automatic mixed precision"}) 12 | + msamp_opt_level: str = field( 13 | + default="O1", 14 | + metadata={"help": "microsoft automatic mixed precision optimization level"}, 15 | + ) 16 | 17 | @dataclass 18 | class DistributedTrainingConfig(FairseqDataclass): 19 | diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py 20 | index d7ced919..02ce8e4f 100644 21 | --- a/fairseq/models/roberta/model.py 22 | +++ b/fairseq/models/roberta/model.py 23 | @@ -473,6 +473,7 @@ class RobertaLMHead(nn.Module): 24 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None): 25 | super().__init__() 26 | self.dense = nn.Linear(embed_dim, embed_dim) 27 | + self.dense.use_fp32_linear = True 28 | self.activation_fn = utils.get_activation_fn(activation_fn) 29 | self.layer_norm = LayerNorm(embed_dim) 30 | 31 | diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py 32 | index 678ec7c6..f535ac12 100644 33 | --- a/fairseq/optim/adam.py 34 | +++ b/fairseq/optim/adam.py 35 | @@ -16,7 +16,7 @@ from fairseq.dataclass import FairseqDataclass 36 | from fairseq.optim import FairseqOptimizer, register_optimizer 37 | from fairseq.optim.fused_adam import get_fused_adam_class 38 | from omegaconf import II, OmegaConf 39 | - 40 | +import msamp 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | @@ -39,6 +39,8 @@ class FairseqAdamConfig(FairseqDataclass): 45 | # TODO common vars below in parent 46 | tpu: bool = II("common.tpu") 47 | lr: List[float] = II("optimization.lr") 48 | + msamp: bool = II("common.msamp") 49 | + msamp_opt_level: str = II("common.msamp_opt_level") 50 | 51 | 52 | @register_optimizer("adam", dataclass=FairseqAdamConfig) 53 | @@ -58,7 +60,19 @@ class FairseqAdam(FairseqOptimizer): 54 | and fused_adam_cls is not None 55 | and torch.cuda.is_available() 56 | ) 57 | - if getattr(cfg, "tpu", False): 58 | + 59 | + if cfg.msamp: 60 | + logger.info(f"using LBAdamW, msamp opt level is {cfg.msamp_opt_level}") 61 | + if cfg.msamp_opt_level == 'O1': 62 | + self.optimizer_config['exp_avg_dtype'] = torch.float32 63 | + self.optimizer_config['exp_avg_sq_dtype'] = torch.float32 64 | + elif cfg.msamp_opt_level == 'O2': 65 | + self.optimizer_config['exp_avg_dtype'] = torch.uint8 66 | + self.optimizer_config['exp_avg_sq_dtype'] = torch.float16 67 | + else: 68 | + logger.warning(f"msamp opt level {cfg.msamp_opt_level} is not supported") 69 | + self._optimizer = LBAdamW(params, **self.optimizer_config) 70 | + elif getattr(cfg, "tpu", False): 71 | if self.cfg.fp16_adam_stats: 72 | raise NotImplementedError("--fp16-adam-stats is only supported on GPU") 73 | # on TPUs we use the Adam defined here, since it 74 | @@ -106,6 +120,10 @@ class FairseqAdam(FairseqOptimizer): 75 | dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM) 76 | dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM) 77 | 78 | + def all_reduce_grads(self, model): 79 | + if self.cfg.msamp and hasattr(self._optimizer, "all_reduce_grads"): 80 | + self._optimizer.all_reduce_grads(model) 81 | + super().all_reduce_grads(model) 82 | 83 | class Adam(torch.optim.Optimizer): 84 | r"""Implements Adam algorithm. 
85 | @@ -237,3 +255,31 @@ class Adam(torch.optim.Optimizer): 86 | p.data.copy_(p_data_fp32) 87 | 88 | return loss 89 | + 90 | +class LBAdamW(msamp.LBAdamW): 91 | + def __init__( 92 | + self, 93 | + params, 94 | + lr=1e-3, 95 | + betas=(0.9, 0.999), 96 | + eps=1e-8, 97 | + weight_decay=0, 98 | + amsgrad=False, 99 | + exp_avg_dtype=torch.uint8, 100 | + exp_avg_sq_dtype=torch.float16, 101 | + ): 102 | + defaults = dict( 103 | + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, 104 | + exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype 105 | + ) 106 | + super().__init__(params, **defaults) 107 | + 108 | + @property 109 | + def supports_memory_efficient_fp16(self): 110 | + # DO NOT USE MemoryEfficientFP16Optimizer 111 | + return False 112 | + 113 | + @property 114 | + def supports_flat_params(self): 115 | + # since FP16 params with different scaling factor 116 | + return False 117 | diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py 118 | index 7e541175..5fff9e04 100644 119 | --- a/fairseq/optim/fairseq_optimizer.py 120 | +++ b/fairseq/optim/fairseq_optimizer.py 121 | @@ -7,6 +7,7 @@ import torch 122 | from fairseq import utils 123 | from fairseq.dataclass.utils import gen_parser_from_dataclass 124 | 125 | +import msamp 126 | 127 | class FairseqOptimizer(object): 128 | def __init__(self, cfg): 129 | @@ -109,6 +110,8 @@ class FairseqOptimizer(object): 130 | 131 | def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): 132 | """Clips gradient norm.""" 133 | + if hasattr(self.cfg, "msamp") and self.cfg.msamp: 134 | + return msamp.clip_grad_norm_(self.params, max_norm) 135 | return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) 136 | 137 | def step(self, closure=None, scale=1.0, groups=None): 138 | diff --git a/fairseq/trainer.py b/fairseq/trainer.py 139 | index da1f9491..ec6e1732 100644 140 | --- a/fairseq/trainer.py 141 | +++ b/fairseq/trainer.py 142 | @@ -18,6 +18,7 @@ from typing import Any, Dict, List 143 | 144 | import torch 145 | from omegaconf import OmegaConf 146 | +import msamp 147 | 148 | from fairseq import checkpoint_utils, models, optim, utils 149 | from fairseq.dataclass.configs import FairseqConfig 150 | @@ -115,6 +116,14 @@ class Trainer(object): 151 | ): 152 | self._criterion = self._criterion.to(device=self.device) 153 | self._model = self._model.to(device=self.device) 154 | + 155 | + if self.cfg.common.msamp: 156 | + logger.info(f"msamp is enabled, opt level is {self.cfg.common.msamp_opt_level}") 157 | + assert self._model is not None 158 | + self._model = msamp.nn.LinearReplacer.replace(self._model) 159 | + # self._model, _ = msamp.initialize(self._model, None, self.cfg.common.msamp_opt_level) 160 | + logger.info(f"FP8 model is: {self._model}") 161 | + 162 | self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel 163 | self.last_device = None 164 | if self.cuda and self.pipeline_model_parallel: 165 | diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py 166 | index 376bd1d0..75197bc0 100644 167 | --- a/fairseq_cli/train.py 168 | +++ b/fairseq_cli/train.py 169 | @@ -26,6 +26,7 @@ logger = logging.getLogger("fairseq_cli.train") 170 | import numpy as np 171 | import torch 172 | from omegaconf import DictConfig, OmegaConf 173 | +from fvcore.nn import FlopCountAnalysis 174 | 175 | from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils 176 | from fairseq.data import data_utils, iterators 177 | @@ -123,6 +124,16 @@ def main(cfg: FairseqConfig) 
-> None: 178 | ) 179 | ) 180 | 181 | + # compute model flops: create fake input and use fvcore to compute the model flops. 182 | + tokens_per_sample = cfg.task.tokens_per_sample 183 | + logger.info(f'tokens_per_sample: {tokens_per_sample}') 184 | + token_ids = [0] + [100] * (tokens_per_sample - 2) + [2] 185 | + src_tokens = torch.tensor(token_ids).unsqueeze(0).cuda() 186 | + with torch.no_grad(): 187 | + model.cuda() 188 | + model_flops = FlopCountAnalysis(model, src_tokens).total() 189 | + logger.info(f'model flops: {model_flops}') 190 | + 191 | # Load valid dataset (we load training data below, based on the latest checkpoint) 192 | # We load the valid dataset AFTER building the model 193 | data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) 194 | @@ -187,7 +198,7 @@ def main(cfg: FairseqConfig) -> None: 195 | break 196 | 197 | # train for one epoch 198 | - valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) 199 | + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr, model_flops) 200 | if should_stop: 201 | break 202 | 203 | @@ -244,7 +255,7 @@ def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: 204 | 205 | @metrics.aggregate("train") 206 | def train( 207 | - cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr 208 | + cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, model_flops=0 209 | ) -> Tuple[List[Optional[float]], bool]: 210 | """Train the model for one epoch and return validation losses.""" 211 | # Initialize data iterator 212 | @@ -320,6 +331,12 @@ def train( 213 | num_updates = trainer.get_num_updates() 214 | if num_updates % cfg.common.log_interval == 0: 215 | stats = get_training_stats(metrics.get_smoothed_values("train_inner")) 216 | + throughput = stats["bsz"] * stats["ups"] 217 | + stats["throughput"] = throughput 218 | + throughput_per_gpu = throughput / cfg.distributed_training.distributed_world_size 219 | + # The reason of x 3: 1 forward + 2 backward. 
The reason of x 2: 1MACs = 2FLOPs 220 | + stats["tflops"] = throughput_per_gpu * 3 * 2 * model_flops / 1e12 221 | + 222 | progress.log(stats, tag="train_inner", step=num_updates) 223 | 224 | # reset mid-epoch stats after each log interval 225 | diff --git a/setup.py b/setup.py 226 | index 8a9b2f97..1bc5650d 100644 227 | --- a/setup.py 228 | +++ b/setup.py 229 | @@ -187,7 +187,6 @@ def do_setup(package_data): 230 | "torch>=1.10", 231 | "tqdm", 232 | "bitarray", 233 | - "torchaudio>=0.8.0", 234 | ], 235 | extras_require={ 236 | "dev": ["flake8", "pytest", "black==22.3.0"], 237 | -------------------------------------------------------------------------------- /RoBERTa/prepare_wikitext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Microsoft Corporation - All rights reserved 4 | # Licensed under the MIT License 5 | 6 | set -e 7 | 8 | if [ -d "data-bin" ] 9 | then 10 | rm -rf data-bin 11 | fi 12 | 13 | # wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip 14 | # unzip wikitext-103-raw-v1.zip 15 | 16 | wget https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz 17 | tar -zxvf wikitext-103.tar.gz && mv wikitext-103 wikitext-103-raw 18 | 19 | 20 | mkdir -p gpt2_bpe 21 | wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 22 | wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 23 | 24 | export PYTHONPATH=../third_party/fairseq:$PYTHONPATH 25 | 26 | for SPLIT in train valid test; do \ 27 | python -m examples.roberta.multiprocessing_bpe_encoder \ 28 | --encoder-json gpt2_bpe/encoder.json \ 29 | --vocab-bpe gpt2_bpe/vocab.bpe \ 30 | --inputs wikitext-103-raw/wiki.${SPLIT}.tokens \ 31 | --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \ 32 | --keep-empty \ 33 | --workers 60; \ 34 | done 35 | 36 | 37 | wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt 38 | fairseq-preprocess \ 39 | --only-source \ 40 | --srcdict gpt2_bpe/dict.txt \ 41 | --trainpref wikitext-103-raw/wiki.train.bpe \ 42 | --validpref wikitext-103-raw/wiki.valid.bpe \ 43 | --testpref wikitext-103-raw/wiki.test.bpe \ 44 | --destdir data-bin/wikitext-103 \ 45 | --workers 60 46 | 47 | rm -rf gpt2_bpe 48 | rm -rf wikitext-103* 49 | -------------------------------------------------------------------------------- /RoBERTa/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Microsoft Corporation - All rights reserved 4 | # Licensed under the MIT License 5 | 6 | set -e 7 | 8 | USAGE="usage: bash run.sh [amp|msamp]" 9 | 10 | if [ "$#" -ne 1 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | DATA_PATH=$PWD/data-bin/wikitext-103 16 | GPU_NUM=4 17 | amp_type=$1 18 | fairseq_train=`which fairseq-hydra-train` 19 | 20 | if [ "$amp_type" = "amp" ]; then 21 | echo "run RoBERTa base with AMP" 22 | SAVE_PATH=$PWD/checkpoints/roberta_amp/ 23 | $fairseq_train \ 24 | --config-dir ../third_party/fairseq/examples/roberta/config/pretraining \ 25 | --config-name base \ 26 | task.data=$DATA_PATH \ 27 | checkpoint.save_dir=$SAVE_PATH \ 28 | dataset.skip_invalid_size_inputs_valid_test=True \ 29 | dataset.batch_size=64 \ 30 | optimization.update_freq=[8] \ 31 | common.fp16=False \ 32 | common.amp=True \ 33 | checkpoint.save_interval_updates=500 \ 34 | common.log_interval=20 \ 35 | dataset.validate_interval_updates=500 \ 36 | distributed_training.ddp_backend=c10d 
\ 37 | distributed_training.distributed_world_size=$GPU_NUM 38 | 39 | elif [ "$amp_type" = "msamp" ]; then 40 | echo "run RoBERTa base with MS-AMP" 41 | SAVE_PATH=$PWD/checkpoints/roberta_msamp/ 42 | $fairseq_train \ 43 | --config-dir ../third_party/fairseq/examples/roberta/config/pretraining \ 44 | --config-name base \ 45 | task.data=$DATA_PATH \ 46 | checkpoint.save_dir=$SAVE_PATH \ 47 | dataset.skip_invalid_size_inputs_valid_test=True \ 48 | dataset.batch_size=64 \ 49 | optimization.update_freq=[8] \ 50 | common.fp16=False \ 51 | common.amp=True \ 52 | checkpoint.save_interval_updates=500 \ 53 | common.log_interval=20 \ 54 | dataset.validate_interval_updates=500 \ 55 | common.msamp=True \ 56 | common.msamp_opt_level=O2 \ 57 | distributed_training.ddp_backend=c10d \ 58 | distributed_training.distributed_world_size=$GPU_NUM 59 | else 60 | echo $USAGE 61 | exit 1 62 | fi 63 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses [GitHub Issues] to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new issue. 8 | 9 | For help and questions about using this project, please create a new post in [GitHub Discussions]. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 14 | 15 | [GitHub Issues]: https://github.com/Azure/MS-AMP-Examples/issues 16 | [GitHub Discussions]: https://github.com/Azure/MS-AMP-Examples/discussions -------------------------------------------------------------------------------- /Swin-Transformer/README.md: -------------------------------------------------------------------------------- 1 | # This is an example of Swin-Transformer using MS-AMP 2 | This example demonstrates how to use MS-AMP in [Swin-Transformer](https://github.com/microsoft/Swin-Transformer). 3 | 4 | ## Data preparation 5 | We use the standard ImageNet dataset, which you can download from http://image-net.org/. The file structure should look like: 6 | ``` 7 | $ tree data 8 | ImageNet 9 | ├── train 10 | │ ├── class1 11 | │ │ ├── img1.jpeg 12 | │ │ ├── img2.jpeg 13 | │ │ └── ... 14 | │ ├── class2 15 | │ │ ├── img3.jpeg 16 | │ │ └── ... 17 | │ └── ... 18 | └── val 19 | ├── class1 20 | │ ├── img4.jpeg 21 | │ ├── img5.jpeg 22 | │ └── ... 23 | ├── class2 24 | │ ├── img6.jpeg 25 | │ └── ... 26 | └── ... 27 | ``` 28 | After that, you may need to change the variable DATA_PATH to the data folder in the launch script. 29 | 30 | ## Install dependencies 31 | You need to install dependencies before training Swin-Transformer. It is recommended to use venv for virtual environments, but it is not strictly necessary. 32 | ``` 33 | cd Swin-Transformer 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## Apply patch to Swin-Transformer 38 | We made a few changes to the official Swin-Transformer and packaged them into a patch. You need to apply this patch to third_party/Swin-Transformer. 39 | ``` 40 | cd ../third_party/Swin-Transformer 41 | git apply ../../Swin-Transformer/Swin-Transformer.patch 42 | cd - 43 | ``` 44 | 45 | ## Train Swin-Transformer tiny model with AMP 46 | Run the following command to train a tiny Swin-Transformer model using AMP. 47 | ``` 48 | sh run.sh tiny amp 49 | ``` 50 | 51 | ## Train Swin-Transformer tiny model with MS-AMP 52 | Run the following command to train a tiny Swin-Transformer model using MS-AMP. 53 | ``` 54 | sh run.sh tiny msamp 55 | ``` 56 | 57 | ## Train Swin-Transformer giant model with AMP 58 | Run the following command to train a giant Swin-Transformer model using AMP. 59 | ``` 60 | sh run.sh giant amp 61 | ``` 62 |
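63 | ## Train Swin-Transformer giant model with TE
In this mode the patched `main.py` swaps the model's `nn.Linear` layers for Transformer Engine wrappers around `te.Linear` (`replace_with_telinear` in `common/te_utils.py`) and runs the forward pass under `te.fp8_autocast`. The snippet below is a rough, standalone sketch of that mechanism; the layer and batch sizes are made up (kept as multiples of 16, which is why `common/te_utils.py` pads shapes), while the recipe values mirror the defaults in `common/te_utils.py`.

```python
# Sketch of the Transformer Engine FP8 path used by the "te-fp8" mode.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

layer = te.Linear(1024, 1024, bias=True).cuda()
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,        # E4M3 during forward pass, E5M2 during backward pass
    amax_history_len=16,
    amax_compute_algo="max",
)

x = torch.randn(32, 1024, device="cuda")
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)
y.float().sum().backward()
```
64 | Run the following command to train a giant Swin-Transformer model using FP8 in Transformer Engine.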
65 | ``` 66 | sh run.sh giant te-fp8 67 | ``` 68 | 69 | ## Train Swin-Transformer giant model with MS-AMP 70 | Run the following command to train a giant Swin-Transformer model using MS-AMP. You can observe significant GPU memory saving using `nvidia-smi` compared with AMP. 71 | ``` 72 | sh run.sh giant msamp 73 | ``` -------------------------------------------------------------------------------- /Swin-Transformer/Swin-Transformer.patch: -------------------------------------------------------------------------------- 1 | diff --git a/config.py b/config.py 2 | index 1671ec3..be7c6fd 100644 3 | --- a/config.py 4 | +++ b/config.py 5 | @@ -260,6 +260,12 @@ _C.LOCAL_RANK = 0 6 | _C.FUSED_WINDOW_PROCESS = False 7 | _C.FUSED_LAYERNORM = False 8 | 9 | +# ms-amp 10 | +_C.ENABLE_MSAMP = False 11 | +_C.MSAMP_OPT_LEVEL = 'O2' 12 | + 13 | +# te-fp8 14 | +_C.ENABLE_TEFP8 = False 15 | 16 | def _update_config_from_file(config, cfg_file): 17 | config.defrost() 18 | @@ -333,6 +339,16 @@ def update_config(config, args): 19 | if _check_args('optim'): 20 | config.TRAIN.OPTIMIZER.NAME = args.optim 21 | 22 | + ## msamp 23 | + if _check_args('enable_msamp'): 24 | + config.ENABLE_MSAMP = args.enable_msamp 25 | + if _check_args('msamp_opt_level'): 26 | + config.MSAMP_OPT_LEVEL = args.msamp_opt_level 27 | + 28 | + # te-fp8 29 | + if _check_args('enable_tefp8'): 30 | + config.ENABLE_TEFP8 = args.enable_tefp8 31 | + 32 | # set local rank for distributed training 33 | config.LOCAL_RANK = args.local_rank 34 | 35 | diff --git a/configs/swin/swin_giant_patch4_window7_224.yaml b/configs/swin/swin_giant_patch4_window7_224.yaml 36 | new file mode 100644 37 | index 0000000..67b4476 38 | --- /dev/null 39 | +++ b/configs/swin/swin_giant_patch4_window7_224.yaml 40 | @@ -0,0 +1,9 @@ 41 | +MODEL: 42 | + TYPE: swin 43 | + NAME: swin_giant_patch4_window7_224 44 | + DROP_PATH_RATE: 0.5 45 | + SWIN: 46 | + EMBED_DIM: 448 47 | + DEPTHS: [ 2, 2, 18, 2 ] 48 | + NUM_HEADS: [ 14, 28, 56, 112 ] 49 | + WINDOW_SIZE: 7 50 | \ No newline at end of file 51 | diff --git a/main.py b/main.py 52 | index 84230ea..a40369e 100644 53 | --- a/main.py 54 | +++ b/main.py 55 | @@ -19,6 +19,7 @@ import torch.distributed as dist 56 | 57 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 58 | from timm.utils import accuracy, AverageMeter 59 | +import msamp 60 | 61 | from config import get_config 62 | from models import build_model 63 | @@ -29,6 +30,10 @@ from logger import create_logger 64 | from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \ 65 | reduce_tensor 66 | 67 | +import sys 68 | +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) 69 | + 70 | +from common.te_utils import replace_with_telinear, TeUtils 71 | 72 | def parse_option(): 73 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 74 | @@ -64,7 +69,7 @@ def parse_option(): 75 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 76 | 77 | # distributed training 78 | - parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 79 | + parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel') 80 | 81 | # for acceleration 82 | parser.add_argument('--fused_window_process', action='store_true', 83 | @@ -73,9 +78,18 @@ def parse_option(): 84 | ## overwrite optimizer in config (*.yaml) if specified, e.g., 
fused_adam/fused_lamb 85 | parser.add_argument('--optim', type=str, 86 | help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.') 87 | + # ms-amp 88 | + parser.add_argument('--enable-msamp', action='store_true', default=False, help='enable MS-AMP') 89 | + parser.add_argument('--msamp-opt-level', type=str, default='O1', help='MS-AMP optimization level') 90 | + 91 | + # te-fp8 92 | + parser.add_argument('--enable-tefp8', action='store_true', default=False, help='enable TE-FP8') 93 | 94 | args, unparsed = parser.parse_known_args() 95 | 96 | + if args.local_rank is None and 'LOCAL_RANK' in os.environ: 97 | + args.local_rank = int(os.environ['LOCAL_RANK']) 98 | + 99 | config = get_config(args) 100 | 101 | return args, config 102 | @@ -88,6 +102,10 @@ def main(config): 103 | model = build_model(config) 104 | logger.info(str(model)) 105 | 106 | + # get flops of the model. 107 | + model_flops = model.flops() 108 | + args.model_flops = model_flops 109 | + 110 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 111 | logger.info(f"number of params: {n_parameters}") 112 | if hasattr(model, 'flops'): 113 | @@ -95,9 +113,22 @@ def main(config): 114 | logger.info(f"number of GFLOPs: {flops / 1e9}") 115 | 116 | model.cuda() 117 | + 118 | + if config.ENABLE_TEFP8: 119 | + logger.info('te-fp8 is enabled') 120 | + model = replace_with_telinear(model) 121 | + 122 | model_without_ddp = model 123 | 124 | optimizer = build_optimizer(config, model) 125 | + if config.ENABLE_MSAMP: 126 | + logger.info(f"msamp is enabled, opt level is {config.MSAMP_OPT_LEVEL}") 127 | + model, optimizer = msamp.initialize(model, optimizer, config.MSAMP_OPT_LEVEL) 128 | + 129 | + if dist.get_rank() == 0: 130 | + logger.info(f'type of optimizer is {optimizer}') 131 | + logger.info(f'model is {model}') 132 | + 133 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 134 | loss_scaler = NativeScalerWithGradNormCount() 135 | 136 | @@ -177,6 +208,8 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix 137 | 138 | start = time.time() 139 | end = time.time() 140 | + 141 | + autocast_context = TeUtils.get_autocast(config.AMP_ENABLE, config.ENABLE_TEFP8) 142 | for idx, (samples, targets) in enumerate(data_loader): 143 | samples = samples.cuda(non_blocking=True) 144 | targets = targets.cuda(non_blocking=True) 145 | @@ -184,14 +217,14 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix 146 | if mixup_fn is not None: 147 | samples, targets = mixup_fn(samples, targets) 148 | 149 | - with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 150 | + with autocast_context(): 151 | outputs = model(samples) 152 | loss = criterion(outputs, targets) 153 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 154 | 155 | # this attribute is added by timm on one optimizer (adahessian) 156 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 157 | - grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, 158 | + grad_norm = loss_scaler(loss, model, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, 159 | parameters=model.parameters(), create_graph=is_second_order, 160 | update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) 161 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 162 | @@ -207,19 +240,26 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix 163 | scaler_meter.update(loss_scale_value) 164 | 
batch_time.update(time.time() - end) 165 | end = time.time() 166 | - 167 | + throughput_per_gpu = args.batch_size / batch_time.val 168 | + throughput = dist.get_world_size() * throughput_per_gpu 169 | + throughput_avg = dist.get_world_size() * args.batch_size / batch_time.avg 170 | if idx % config.PRINT_FREQ == 0: 171 | lr = optimizer.param_groups[0]['lr'] 172 | wd = optimizer.param_groups[0]['weight_decay'] 173 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 174 | etas = batch_time.avg * (num_steps - idx) 175 | + ratio = 3 if config.TRAIN.USE_CHECKPOINT else 4 176 | + # First mutiply by ratio: 1 for forward, 2 for backward, 1 for activation checkpoint. Then multiply by 2: 1MACs = 2FLOPs 177 | + tflops = ratio * 2 * args.model_flops * throughput_per_gpu / 1e12 178 | logger.info( 179 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 180 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t' 181 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 182 | + f'throughput {throughput:.2f} ({throughput_avg: .2f})\t' 183 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 184 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 185 | f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' 186 | + f'tflops {tflops:.2f}\t' 187 | f'mem {memory_used:.0f}MB') 188 | epoch_time = time.time() - start 189 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 190 | @@ -236,12 +276,13 @@ def validate(config, data_loader, model): 191 | acc5_meter = AverageMeter() 192 | 193 | end = time.time() 194 | + autocast_context = TeUtils.get_autocast(config.AMP_ENABLE, config.ENABLE_TEFP8) 195 | for idx, (images, target) in enumerate(data_loader): 196 | images = images.cuda(non_blocking=True) 197 | target = target.cuda(non_blocking=True) 198 | 199 | # compute output 200 | - with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 201 | + with autocast_context(): 202 | output = model(images) 203 | 204 | # measure accuracy and record loss 205 | diff --git a/models/build.py b/models/build.py 206 | index c37384d..772e68b 100644 207 | --- a/models/build.py 208 | +++ b/models/build.py 209 | @@ -50,6 +50,8 @@ def build_model(config, is_pretrain=False): 210 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 211 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 212 | fused_window_process=config.FUSED_WINDOW_PROCESS) 213 | + if config.ENABLE_MSAMP or config.ENABLE_TEFP8: 214 | + model.head.use_fp32_linear = True 215 | elif model_type == 'swinv2': 216 | model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE, 217 | patch_size=config.MODEL.SWINV2.PATCH_SIZE, 218 | diff --git a/utils.py b/utils.py 219 | index eb607cf..7be4dd6 100644 220 | --- a/utils.py 221 | +++ b/utils.py 222 | @@ -8,8 +8,8 @@ 223 | import os 224 | import torch 225 | import torch.distributed as dist 226 | -from torch._six import inf 227 | - 228 | +from torch import inf 229 | +from msamp import clip_grad_norm_ 230 | 231 | def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): 232 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 233 | @@ -198,13 +198,15 @@ class NativeScalerWithGradNormCount: 234 | def __init__(self): 235 | self._scaler = torch.cuda.amp.GradScaler() 236 | 237 | - def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 238 | + def __call__(self, loss, model, optimizer, clip_grad=None, parameters=None, create_graph=False, 
update_grad=True): 239 | self._scaler.scale(loss).backward(create_graph=create_graph) 240 | if update_grad: 241 | if clip_grad is not None: 242 | + if hasattr(optimizer, 'all_reduce_grads'): 243 | + optimizer.all_reduce_grads(model) 244 | assert parameters is not None 245 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 246 | - norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 247 | + norm = clip_grad_norm_(parameters, clip_grad) 248 | else: 249 | self._scaler.unscale_(optimizer) 250 | norm = ampscaler_get_grad_norm(parameters) 251 | -------------------------------------------------------------------------------- /Swin-Transformer/requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | timm==0.4.12 3 | opencv-python==4.9.0.80 4 | termcolor==1.1.0 5 | yacs==0.1.8 6 | pyyaml 7 | scipy -------------------------------------------------------------------------------- /Swin-Transformer/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Microsoft Corporation - All rights reserved 4 | # Licensed under the MIT License 5 | 6 | set -e 7 | 8 | USAGE="usage: bash run.sh [tiny|giant] [amp|msamp|te-fp8]" 9 | 10 | if [ "$#" -ne 2 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | DATA_PATH=../../ImageNet 16 | GPU_NUM=8 17 | MASTER_PORT=12345 18 | 19 | model=$1 20 | amp_type=$2 21 | 22 | if [ "$model" == "tiny" -a "$amp_type" == "amp" ]; then 23 | echo "run tiny Swin-Transformer with AMP" 24 | python -m torch.distributed.launch \ 25 | --nproc_per_node $GPU_NUM \ 26 | --master_port $MASTER_PORT \ 27 | ../third_party/Swin-Transformer/main.py \ 28 | --cfg ../third_party/Swin-Transformer/configs/swin/swin_tiny_patch4_window7_224.yaml \ 29 | --data-path $DATA_PATH \ 30 | --batch-size 128 \ 31 | --output output 32 | elif [ "$model" == "tiny" -a "$amp_type" == "msamp" ]; then 33 | echo "run tiny Swin-Transformer with MS-AMP" 34 | python -m torch.distributed.launch \ 35 | --nproc_per_node $GPU_NUM \ 36 | --master_port $MASTER_PORT \ 37 | ../third_party/Swin-Transformer/main.py \ 38 | --cfg ../third_party/Swin-Transformer/configs/swin/swin_tiny_patch4_window7_224.yaml \ 39 | --data-path $DATA_PATH \ 40 | --batch-size 128 \ 41 | --output output_msamp \ 42 | --enable-msamp \ 43 | --msamp-opt-level O2 44 | elif [ "$model" == "giant" -a "$amp_type" == "amp" ]; then 45 | echo "run giant Swin-Transformer with AMP" 46 | python -m torch.distributed.launch \ 47 | --nproc_per_node $GPU_NUM \ 48 | --master_port $MASTER_PORT \ 49 | ../third_party/Swin-Transformer/main.py \ 50 | --cfg ../third_party/Swin-Transformer/configs/swin/swin_giant_patch4_window7_224.yaml \ 51 | --data-path $DATA_PATH \ 52 | --output output_giant \ 53 | --batch-size 16 54 | elif [ "$model" == "giant" -a "$amp_type" == "msamp" ]; then 55 | echo "run giant Swin-Transformer with MS-AMP" 56 | python -m torch.distributed.launch \ 57 | --nproc_per_node $GPU_NUM \ 58 | --master_port $MASTER_PORT \ 59 | ../third_party/Swin-Transformer/main.py \ 60 | --cfg ../third_party/Swin-Transformer/configs/swin/swin_giant_patch4_window7_224.yaml \ 61 | --data-path $DATA_PATH \ 62 | --batch-size 16 \ 63 | --output output_giant_msamp \ 64 | --enable-msamp \ 65 | --msamp-opt-level O2 66 | elif [ "$model" == "giant" -a "$amp_type" == "te-fp8" ]; then 67 | echo "run giant Swin-Transformer with Transformer Engine FP8" 68 | python -m torch.distributed.launch \ 69 | --nproc_per_node $GPU_NUM \
70 | --master_port $MASTER_PORT \ 71 | ../third_party/Swin-Transformer/main.py \ 72 | --cfg ../third_party/Swin-Transformer/configs/swin/swin_giant_patch4_window7_224.yaml \ 73 | --data-path $DATA_PATH \ 74 | --batch-size 16 \ 75 | --output output_giant_tefp8 \ 76 | --enable-tefp8 77 | else 78 | echo $USAGE 79 | exit 1 80 | fi 81 | -------------------------------------------------------------------------------- /common/stat_communication.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import subprocess 5 | import glob 6 | import time 7 | import os 8 | import sys 9 | import argparse 10 | 11 | 12 | def get_ib_bandwidth(): 13 | """Get the infiniband bandwidth.""" 14 | # check if IB is avaiable. 15 | output = os.popen('ibstat').read() 16 | if 'CA type: ' not in output: 17 | print('ib is not avaiable') 18 | return 19 | 20 | # Read the value from Infiniband counter files. 21 | ib_counter_file="/sys/class/infiniband/mlx5_ib*/ports/1/counters/port_*_data" 22 | prev_sum = 0 23 | while True: 24 | current_sum = 0 25 | file_count = 0 26 | for file_path in glob.glob(ib_counter_file): 27 | file_count += 1 28 | with open(file_path, 'r') as f: 29 | current_value=int(f.read().strip()) 30 | current_sum += current_value 31 | if file_count == 0: 32 | print(f'there is no file match {ib_counter_file}, please check') 33 | return 34 | 35 | if prev_sum != 0: 36 | ib_bandwidth = current_sum - prev_sum 37 | print(f'infiniband bandwidth:{ib_bandwidth/1e9:.2f} GB') 38 | 39 | prev_sum = current_sum 40 | time.sleep(1) 41 | 42 | 43 | def get_nvlink_bandwidth(): 44 | """Get the nvlink bandwidth.""" 45 | # Get the number of GPUs on this machine 46 | process = subprocess.Popen(['nvidia-smi', '-L'], stdout=subprocess.PIPE) 47 | num_gpus = 0 48 | for line in process.stdout: 49 | line_str = line.decode().strip() 50 | if line_str.startswith('GPU'): 51 | num_gpus += 1 52 | print(f'num_gpus:{num_gpus}') 53 | 54 | # Run the command and compute nvlink bandwidth. 55 | process = subprocess.Popen(['dcgmi', 'dmon', '-e', '1011,1012'], stdout=subprocess.PIPE) 56 | 57 | gpu_ids = [] 58 | nvlink_bandwidth=0 59 | 60 | for line in process.stdout: 61 | line_str = line.decode().strip() 62 | if not line_str.startswith('GPU'): 63 | continue 64 | arr = line_str.split() 65 | gpu_id = arr[1] 66 | gpu_ids.append(gpu_id) 67 | 68 | if arr[2] != 'N/A': 69 | nvlink_bandwidth += int(arr[2]) 70 | if arr[3] != 'N/A': 71 | nvlink_bandwidth += int(arr[3]) 72 | if len(gpu_ids) == num_gpus: 73 | print(f"nvlink bandwidth:{nvlink_bandwidth/1e9:.2f} GB") 74 | gpu_ids.clear() 75 | nvlink_bandwidth = 0 76 | 77 | 78 | def main(): 79 | """The entry point function.""" 80 | parser = argparse.ArgumentParser(description='stat communication') 81 | parser.add_argument('--ib', action='store_true', help='stat infiniband bandwidth') 82 | parser.add_argument('--nvlink', action='store_true', help='stat nvlink bandwidth') 83 | 84 | args = parser.parse_args() 85 | if args.ib: 86 | get_ib_bandwidth() 87 | elif args.nvlink: 88 | get_nvlink_bandwidth() 89 | else: 90 | print('please specify --ib or --nvlink') 91 | sys.exit(1) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() -------------------------------------------------------------------------------- /common/te_utils.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT License. 
4 | 5 | from contextlib import contextmanager, nullcontext 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import transformer_engine.pytorch as te 12 | from transformer_engine.common.recipe import Format, DelayedScaling 13 | 14 | 15 | class FP8Linear(nn.Module): 16 | """This is a wrapper of te.Linear, which pads the input to a multiple of 16 and then truncates the output to the original size.""" 17 | TE_BASE = 16 18 | 19 | def __init__(self, in_features, out_features, bias=True, params_dtype=torch.float32): 20 | """Constructor of FP8Linear. 21 | 22 | Args: 23 | in_features (int): size of each input sample. 24 | out_features (int): size of each output sample. 25 | bias (bool): If set to ``False``, the layer will not learn an additive bias. Default: ``True``. 26 | params_dtype (torch.dtype): The data type of the weight and bias. Default: ``torch.float32``. 27 | """ 28 | super().__init__() 29 | self.in_features = in_features 30 | self.out_features = out_features 31 | pad_in_features = FP8Linear.Round2Times(in_features, FP8Linear.TE_BASE) 32 | pad_out_features = FP8Linear.Round2Times(out_features, FP8Linear.TE_BASE) 33 | self.in_padding = pad_in_features - in_features 34 | self.linear = te.Linear(pad_in_features, pad_out_features, bias=bias, params_dtype=params_dtype) 35 | 36 | def forward(self, x): 37 | """Forward pass of FP8Linear. 38 | 39 | Args: 40 | x (torch.Tensor): Input tensor of shape ``[batch_size, ..., in_features]``. 41 | 42 | Returns: 43 | torch.Tensor: Output tensor of shape ``[batch_size, ..., out_features]``. 44 | """ 45 | return self.padding_block_io(self.linear, x) 46 | 47 | @staticmethod 48 | def Round2Times(value, base): 49 | """Round up the value to the nearest multiple of base. 50 | 51 | Args: 52 | value (int): The value to be rounded. 53 | 54 | Returns: 55 | int: The rounded value. 56 | """ 57 | return (value + base - 1) // base * base 58 | 59 | def padding_block_io(self, block, x): 60 | """Pad the input to a multiple of 16 and then truncate the output to the original size. 61 | 62 | Args: 63 | block (torch.nn.Module): The block to be padded. 64 | x (torch.Tensor): Input tensor of shape ``[batch_size, ..., in_features]``. 65 | 66 | Returns: 67 | torch.Tensor: Output tensor of shape ``[batch_size, ..., out_features]``. 68 | """ 69 | shape = list(x.shape) 70 | last_dim = shape[-1] 71 | # reshape to 2 dims 72 | x = x.view(-1, last_dim) 73 | first_dim = len(x) 74 | pad_first_dim = FP8Linear.Round2Times(first_dim, FP8Linear.TE_BASE) 75 | first_dim_padding = pad_first_dim - first_dim 76 | if first_dim_padding > 0 or self.in_padding > 0: 77 | x = F.pad(x, (0, self.in_padding, 0, first_dim_padding)) 78 | x = block(x) 79 | x = x[:first_dim, :self.out_features] 80 | shape[-1] = self.out_features 81 | return x.reshape(shape) 82 | 83 | @property 84 | def weight(self): 85 | return self.linear.weight[:self.out_features, :self.in_features] 86 | 87 | @weight.setter 88 | def weight(self, value): 89 | raise NotImplementedError 90 | 91 | @property 92 | def bias(self): 93 | if self.linear.bias is None: 94 | return None 95 | return self.linear.bias[:self.out_features] 96 | 97 | @bias.setter 98 | def bias(self, value): 99 | raise NotImplementedError 100 | 101 | 102 | @torch.no_grad() 103 | def replace_with_telinear(model): 104 | """ 105 | Replace torch.nn.Linear with FP8Linear in a model. 106 | 107 | Args: 108 | model (torch.nn.Module): The model to be replaced. 109 | 110 | Returns: 111 | torch.nn.Module: The model with FP8Linear. 
112 | """ 113 | model = model.cuda() 114 | def _replace_with_telinear(model): 115 | if isinstance(model, torch.nn.Linear): 116 | if getattr(model, 'use_fp32_linear', False): 117 | return model 118 | te_linear = build_telinear_from_linear(model) 119 | return te_linear 120 | else: 121 | for name, module in model.named_children(): 122 | setattr(model, name, _replace_with_telinear(module)) 123 | return model 124 | model = _replace_with_telinear(model) 125 | return model 126 | 127 | 128 | @torch.no_grad() 129 | def build_telinear_from_linear(linear): 130 | """build FP8Linear from torch.nn.Linear. 131 | 132 | Args: 133 | linear (torch.nn.Linear): The torch.nn.Linear to be replaced. 134 | 135 | Returns: 136 | FP8Linear: The FP8Linear with same input and output features. 137 | """ 138 | weight_dtype = torch.float32 139 | te_linear = FP8Linear( 140 | in_features=linear.in_features, 141 | out_features=linear.out_features, 142 | bias=linear.bias is not None, 143 | params_dtype=weight_dtype, 144 | ).cuda() 145 | te_linear.weight[:linear.out_features, :linear.in_features].copy_(linear.weight.to(te_linear.weight.dtype)) 146 | if linear.bias is not None: 147 | te_linear.bias[:linear.out_features].copy_(linear.bias.to(te_linear.bias.dtype)) 148 | return te_linear 149 | 150 | 151 | class TeUtils: 152 | """A utility class for using transformer_engine.""" 153 | @staticmethod 154 | def get_fp8_recipe(format, max_history_len, amax_compute_algo): 155 | """Get the recipe for FP8. 156 | 157 | Args: 158 | format (str): The format of FP8. It can be 'hybrid', 'e4m3' or 'e5m2'. 159 | max_history_len (int): The maximum length of history. 160 | amax_compute_algo (str): The algorithm to compute amax. It can be 'max' or 'most_recent'. 161 | 162 | Returns: 163 | DelayedScaling: The recipe for FP8. 164 | """ 165 | assert format in ['hybrid', 'e4m3', 'e5m2'] 166 | if format == 'hybrid': 167 | fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass 168 | elif format == 'e4m3': 169 | fp8_format = Format.E4M3 170 | elif format == 'e5m2': 171 | fp8_format = Format.E5M2 172 | 173 | fp8_recipe = DelayedScaling( 174 | fp8_format=fp8_format, 175 | amax_history_len=max_history_len, 176 | amax_compute_algo=amax_compute_algo) 177 | return fp8_recipe 178 | 179 | @staticmethod 180 | def get_autocast(amp_enable, fp8_enable, fp8_format='hybrid', max_history_len=16, amax_compute_algo='max'): 181 | """Get the context manager for autocast. 182 | 183 | Args: 184 | amp_enable (bool): If set to ``True``, the amp autocast will be enabled. 185 | fp8_enable (bool): If set to ``True``, the fp8 autocast will be enabled. 186 | fp8_format (str): The format of FP8. It can be 'hybrid', 'e4m3' or 'e5m2'. 187 | max_history_len (int): The maximum length of history. 188 | amax_compute_algo (str): The algorithm to compute amax. It can be 'max' or 'most_recent'. 189 | 190 | Returns: 191 | contextmanager: The context manager for autocast. 
192 | """ 193 | autocast_context = lambda : torch.cuda.amp.autocast(enabled=amp_enable, dtype=torch.bfloat16) 194 | if fp8_enable: 195 | fp8_recipe = TeUtils.get_fp8_recipe(fp8_format, max_history_len, amax_compute_algo) 196 | fp8_context = lambda : te.fp8_autocast(enabled=fp8_enable, fp8_recipe=fp8_recipe) 197 | else: 198 | fp8_context = lambda : nullcontext() 199 | 200 | @contextmanager 201 | def context_manager(*args, **kwargs): 202 | with autocast_context(): 203 | with fp8_context(): 204 | yield 205 | 206 | return context_manager -------------------------------------------------------------------------------- /deit/README.md: -------------------------------------------------------------------------------- 1 | # This is an example of DeiT using MS-AMP 2 | This example demonstrates how to use MS-AMP in [DeiT](https://github.com/facebookresearch/deit). 3 | 4 | ## Data preparation 5 | Before training, please download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for the torchvision `datasets.ImageFolder`, and the training and validation data are expected to be in the train folder and val folder respectively: 6 | ``` 7 | /path/to/imagenet/ 8 | train/ 9 | class1/ 10 | img1.jpeg 11 | class2/ 12 | img2.jpeg 13 | val/ 14 | class1/ 15 | img3.jpeg 16 | class2/ 17 | img4.jpeg 18 | ``` 19 | After that, you may need to change the varaible DATA_PATH to the data folder in launch script. 20 | 21 | ## Install dependencies 22 | You need to install depedencies before training DeiT. It is recommended to use venv for virtual environments, but it is not strictly necessary. 23 | ``` 24 | cd deit 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Apply patch to DeiT 29 | We made a few changes to the official DeiT and packaged it into a patch. You need to apply this patch to third_party/deit. 30 | ``` 31 | cd ../third_party/deit 32 | git apply ../../deit/deit.patch 33 | cd - 34 | ``` 35 | 36 | ## Train DeiT-Small model with AMP 37 | Run the following command to train a small DeiT model using AMP. 38 | ``` 39 | sh run.sh small amp 40 | ``` 41 | 42 | ## Train DeiT-Small model with MS-AMP 43 | Run the following command to train a small DeiT model using MS-AMP. 44 | ``` 45 | sh run.sh small msamp 46 | ``` 47 | 48 | ## Train DeiT-Large model with AMP 49 | Run the following command to train a large DeiT model using AMP. The model has 1.3 billion parameters. 50 | ``` 51 | sh run.sh large amp 52 | ``` 53 | 54 | ## Train DeiT-Large model with TE 55 | Run the following command to train a large DeiT model using FP8 in Transformer Engine. 56 | ``` 57 | sh run.sh large te-fp8 58 | ``` 59 | 60 | ## Train DeiT-Large model with MS-AMP 61 | Run the following command to train a large DeiT model using MS-AMP. You can observe significant GPU memory saving using `nvidia-smi` compared with AMP. 
62 | ``` 63 | sh run.sh large msamp 64 | ``` -------------------------------------------------------------------------------- /deit/deit.patch: -------------------------------------------------------------------------------- 1 | diff --git a/engine.py b/engine.py 2 | index ed10cea..ccc33e2 100644 3 | --- a/engine.py 4 | +++ b/engine.py 5 | @@ -15,6 +15,9 @@ from timm.utils import accuracy, ModelEma 6 | from losses import DistillationLoss 7 | import utils 8 | 9 | +import os 10 | +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) 11 | +from common.te_utils import TeUtils 12 | 13 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 14 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 15 | @@ -22,11 +25,14 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 16 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 17 | set_training_mode=True, args = None): 18 | model.train(set_training_mode) 19 | - metric_logger = utils.MetricLogger(delimiter=" ") 20 | + metric_logger = utils.MetricLogger(delimiter=" ", model_flops=args.model_flops) 21 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 22 | header = 'Epoch: [{}]'.format(epoch) 23 | print_freq = 10 24 | 25 | + # amp_enable, fp8_enable, fp8_format='hybrid', max_history_len=1, amax_compute_algo='max' 26 | + autocast_context = TeUtils.get_autocast(True, args.enable_te_fp8) 27 | + 28 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 29 | samples = samples.to(device, non_blocking=True) 30 | targets = targets.to(device, non_blocking=True) 31 | @@ -37,7 +43,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 32 | if args.bce_loss: 33 | targets = targets.gt(0.0).type(targets.dtype) 34 | 35 | - with torch.cuda.amp.autocast(): 36 | + with autocast_context(): 37 | outputs = model(samples) 38 | loss = criterion(samples, outputs, targets) 39 | 40 | @@ -51,7 +57,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 41 | 42 | # this attribute is added by timm on one optimizer (adahessian) 43 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 44 | - loss_scaler(loss, optimizer, clip_grad=max_norm, 45 | + loss_scaler(loss, model, optimizer, clip_grad=max_norm, 46 | parameters=model.parameters(), create_graph=is_second_order) 47 | 48 | torch.cuda.synchronize() 49 | @@ -67,7 +73,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 50 | 51 | 52 | @torch.no_grad() 53 | -def evaluate(data_loader, model, device): 54 | +def evaluate(data_loader, model, device, args): 55 | criterion = torch.nn.CrossEntropyLoss() 56 | 57 | metric_logger = utils.MetricLogger(delimiter=" ") 58 | @@ -75,13 +81,14 @@ def evaluate(data_loader, model, device): 59 | 60 | # switch to evaluation mode 61 | model.eval() 62 | + autocast_context = TeUtils.get_autocast(True, args.enable_te_fp8) 63 | 64 | for images, target in metric_logger.log_every(data_loader, 10, header): 65 | images = images.to(device, non_blocking=True) 66 | target = target.to(device, non_blocking=True) 67 | 68 | # compute output 69 | - with torch.cuda.amp.autocast(): 70 | + with autocast_context(): 71 | output = model(images) 72 | loss = criterion(output, target) 73 | 74 | diff --git a/main.py b/main.py 75 | index bc8c418..75fce85 100644 76 | --- a/main.py 77 | +++ b/main.py 78 | @@ -15,7 +15,9 @@ from timm.models import create_model 79 | from 
timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 80 | from timm.scheduler import create_scheduler 81 | from timm.optim import create_optimizer 82 | -from timm.utils import NativeScaler, get_state_dict, ModelEma 83 | +from timm.utils import get_state_dict, ModelEma 84 | +import msamp 85 | +from scaler import NativeScalerWithGradReduce 86 | 87 | from datasets import build_dataset 88 | from engine import train_one_epoch, evaluate 89 | @@ -28,6 +30,12 @@ import models_v2 90 | 91 | import utils 92 | 93 | +import os 94 | +import sys 95 | +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) 96 | + 97 | +from fvcore.nn import FlopCountAnalysis 98 | +from common.te_utils import replace_with_telinear 99 | 100 | def get_args_parser(): 101 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 102 | @@ -181,6 +189,13 @@ def get_args_parser(): 103 | parser.add_argument('--world_size', default=1, type=int, 104 | help='number of distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | + 107 | + # msamp parameters 108 | + parser.add_argument('--enable-msamp', action='store_true', default=False, help='enable MS-AMP') 109 | + parser.add_argument('--msamp-opt-level', type=str, default='O1', help='MS-AMP optimization level') 110 | + 111 | + # transformer engine 112 | + parser.add_argument("--enable-te-fp8", action='store_true', default=False, help='enable TE-FP8') 113 | return parser 114 | 115 | 116 | @@ -205,6 +220,8 @@ def main(args): 117 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 118 | dataset_val, _ = build_dataset(is_train=False, args=args) 119 | 120 | + first_image = dataset_train[0][0] 121 | + 122 | if True: # args.distributed: 123 | num_tasks = utils.get_world_size() 124 | global_rank = utils.get_rank() 125 | @@ -266,6 +283,14 @@ def main(args): 126 | img_size=args.input_size 127 | ) 128 | 129 | + input = first_image.unsqueeze(0).cuda() 130 | + model_flops = FlopCountAnalysis(model.cuda(), input).total() 131 | + 132 | + args.model_flops = model_flops 133 | + print(f"model flops: {model_flops}") 134 | + 135 | + if args.enable_msamp or args.enable_te_fp8: 136 | + model.head.use_fp32_linear = True 137 | 138 | if args.finetune: 139 | if args.finetune.startswith('https'): 140 | @@ -336,17 +361,33 @@ def main(args): 141 | device='cpu' if args.model_ema_force_cpu else '', 142 | resume='') 143 | 144 | + if not args.unscale_lr: 145 | + linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 146 | + args.lr = linear_scaled_lr 147 | + 148 | + if args.enable_te_fp8: 149 | + print("te-fp8 is enabled") 150 | + assert not args.enable_msamp, 'msamp and te-fp8 cannot be enabled at the same time' 151 | + model = replace_with_telinear(model) 152 | + 153 | + optimizer = create_optimizer(args, model) 154 | + 155 | + if args.enable_msamp: 156 | + print(f'msamp is enabled, opt_level: {args.msamp_opt_level}') 157 | + model, optimizer = msamp.initialize(model, optimizer, args.msamp_opt_level) 158 | + 159 | + if utils.get_rank() == 0: 160 | + print(f'type of optimizer is {type(optimizer)}') 161 | + print(f'model is {model}') 162 | + 163 | model_without_ddp = model 164 | if args.distributed: 165 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 166 | model_without_ddp = model.module 167 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 168 | 
print('number of params:', n_parameters) 169 | - if not args.unscale_lr: 170 | - linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 171 | - args.lr = linear_scaled_lr 172 | - optimizer = create_optimizer(args, model_without_ddp) 173 | - loss_scaler = NativeScaler() 174 | + 175 | + loss_scaler = NativeScalerWithGradReduce() 176 | 177 | lr_scheduler, _ = create_scheduler(args, optimizer) 178 | 179 | @@ -406,7 +447,7 @@ def main(args): 180 | loss_scaler.load_state_dict(checkpoint['scaler']) 181 | lr_scheduler.step(args.start_epoch) 182 | if args.eval: 183 | - test_stats = evaluate(data_loader_val, model, device) 184 | + test_stats = evaluate(data_loader_val, model, device, args) 185 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 186 | return 187 | 188 | @@ -434,13 +475,13 @@ def main(args): 189 | 'optimizer': optimizer.state_dict(), 190 | 'lr_scheduler': lr_scheduler.state_dict(), 191 | 'epoch': epoch, 192 | - 'model_ema': get_state_dict(model_ema), 193 | + 'model_ema': get_state_dict(model_ema) if model_ema else None, 194 | 'scaler': loss_scaler.state_dict(), 195 | 'args': args, 196 | }, checkpoint_path) 197 | 198 | 199 | - test_stats = evaluate(data_loader_val, model, device) 200 | + test_stats = evaluate(data_loader_val, model, device, args) 201 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 202 | 203 | if max_accuracy < test_stats["acc1"]: 204 | @@ -453,7 +494,7 @@ def main(args): 205 | 'optimizer': optimizer.state_dict(), 206 | 'lr_scheduler': lr_scheduler.state_dict(), 207 | 'epoch': epoch, 208 | - 'model_ema': get_state_dict(model_ema), 209 | + 'model_ema': get_state_dict(model_ema) if model_ema else None, 210 | 'scaler': loss_scaler.state_dict(), 211 | 'args': args, 212 | }, checkpoint_path) 213 | diff --git a/models.py b/models.py 214 | index 5b22ef3..9919678 100644 215 | --- a/models.py 216 | +++ b/models.py 217 | @@ -10,7 +10,7 @@ from timm.models.layers import trunc_normal_ 218 | 219 | 220 | __all__ = [ 221 | - 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 222 | + 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 'deit_large_patch16_224', 223 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 224 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 225 | 'deit_base_distilled_patch16_384', 226 | @@ -103,6 +103,14 @@ def deit_base_patch16_224(pretrained=False, **kwargs): 227 | model.load_state_dict(checkpoint["model"]) 228 | return model 229 | 230 | +@register_model 231 | +def deit_large_patch16_224(pretrained=False, **kwargs): 232 | + model = VisionTransformer( 233 | + patch_size=16, embed_dim=2048, depth=24, num_heads=32, mlp_ratio=4, qkv_bias=True, 234 | + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 235 | + assert not pretrained 236 | + model.default_cfg = _cfg() 237 | + return model 238 | 239 | @register_model 240 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 241 | diff --git a/scaler.py b/scaler.py 242 | new file mode 100644 243 | index 0000000..dba0819 244 | --- /dev/null 245 | +++ b/scaler.py 246 | @@ -0,0 +1,28 @@ 247 | + 248 | +import torch 249 | +from msamp import clip_grad_norm_ 250 | + 251 | + 252 | +class NativeScalerWithGradReduce: 253 | + state_dict_key = "amp_scaler" 254 | + 255 | + def __init__(self): 256 | + self._scaler = torch.cuda.amp.GradScaler() 257 | + 258 | + def __call__(self, loss, model, optimizer, 
clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 259 | + self._scaler.scale(loss).backward(create_graph=create_graph) 260 | + if hasattr(optimizer, 'all_reduce_grads'): 261 | + optimizer.all_reduce_grads(model) 262 | + if clip_grad is not None: 263 | + assert parameters is not None 264 | + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 265 | + assert clip_mode == 'norm' 266 | + clip_grad_norm_(parameters, clip_grad) 267 | + self._scaler.step(optimizer) 268 | + self._scaler.update() 269 | + 270 | + def state_dict(self): 271 | + return self._scaler.state_dict() 272 | + 273 | + def load_state_dict(self, state_dict): 274 | + self._scaler.load_state_dict(state_dict) 275 | \ No newline at end of file 276 | diff --git a/utils.py b/utils.py 277 | index d1064f6..e96ca83 100644 278 | --- a/utils.py 279 | +++ b/utils.py 280 | @@ -14,7 +14,6 @@ import datetime 281 | import torch 282 | import torch.distributed as dist 283 | 284 | - 285 | class SmoothedValue(object): 286 | """Track a series of values and provide access to smoothed values over a 287 | window or the global series average. 288 | @@ -78,9 +77,10 @@ class SmoothedValue(object): 289 | 290 | 291 | class MetricLogger(object): 292 | - def __init__(self, delimiter="\t"): 293 | + def __init__(self, delimiter="\t", model_flops=0): 294 | self.meters = defaultdict(SmoothedValue) 295 | self.delimiter = delimiter 296 | + self.model_flops = model_flops 297 | 298 | def update(self, **kwargs): 299 | for k, v in kwargs.items(): 300 | @@ -127,16 +127,27 @@ class MetricLogger(object): 301 | 'eta: {eta}', 302 | '{meters}', 303 | 'time: {time}', 304 | - 'data: {data}' 305 | + 'data: {data}', 306 | + 'throughput: {throughput}', 307 | + 'tflops: {tflops:.2f}' 308 | ] 309 | if torch.cuda.is_available(): 310 | log_msg.append('max mem: {memory:.0f}') 311 | + 312 | log_msg = self.delimiter.join(log_msg) 313 | MB = 1024.0 * 1024.0 314 | + 315 | for obj in iterable: 316 | data_time.update(time.time() - end) 317 | yield obj 318 | + 319 | + batch_size = obj[0].shape[0] 320 | iter_time.update(time.time() - end) 321 | + throughput_per_gpu = int(batch_size / iter_time.value) 322 | + throughput = throughput_per_gpu * get_world_size() 323 | + # First mutiply by 3: 1 for forward, 2 for backward. 
Then multiply by 2: 1MACs = 2FLOPs 324 | + tflops = 3 * 2 * self.model_flops * throughput_per_gpu / 1e12 325 | + 326 | if i % print_freq == 0 or i == len(iterable) - 1: 327 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 328 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 329 | @@ -145,7 +156,10 @@ class MetricLogger(object): 330 | i, len(iterable), eta=eta_string, 331 | meters=str(self), 332 | time=str(iter_time), data=str(data_time), 333 | - memory=torch.cuda.max_memory_allocated() / MB)) 334 | + throughput=throughput, 335 | + tflops=tflops, 336 | + memory=torch.cuda.max_memory_allocated() / MB, 337 | + )) 338 | else: 339 | print(log_msg.format( 340 | i, len(iterable), eta=eta_string, 341 | -------------------------------------------------------------------------------- /deit/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | fvcore -------------------------------------------------------------------------------- /deit/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Microsoft Corporation - All rights reserved 4 | # Licensed under the MIT License 5 | 6 | set -e 7 | 8 | USAGE="usage: bash run.sh [small|large] [amp|msamp|te-fp8]" 9 | 10 | if [ "$#" -ne 2 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | DATA_PATH=../../ImageNet 16 | GPU_NUM=8 17 | 18 | model=$1 19 | amp_type=$2 20 | 21 | if [ "$model" == "small" -a "$amp_type" == "amp" ]; then 22 | echo "run small DeiT with AMP" 23 | python -m torch.distributed.launch \ 24 | --nproc_per_node=$GPU_NUM \ 25 | --use_env \ 26 | ../third_party/deit/main.py \ 27 | --model deit_small_patch16_224 \ 28 | --batch-size 128 \ 29 | --data-path $DATA_PATH \ 30 | --output_dir output \ 31 | --no-model-ema 32 | elif [ "$model" == "small" -a "$amp_type" == "msamp" ]; then 33 | echo "run small DeiT with MS-AMP" 34 | python -m torch.distributed.launch \ 35 | --nproc_per_node=$GPU_NUM \ 36 | --use_env ../third_party/deit/main.py \ 37 | --model deit_small_patch16_224 \ 38 | --batch-size 128 \ 39 | --data-path $DATA_PATH \ 40 | --output_dir output_msamp \ 41 | --no-model-ema \ 42 | --enable-msamp \ 43 | --msamp-opt-level O2 44 | elif [ "$model" == "large" -a "$amp_type" == "amp" ]; then 45 | echo "run large DeiT with AMP" 46 | python -m torch.distributed.launch \ 47 | --nproc_per_node=$GPU_NUM \ 48 | --use_env \ 49 | ../third_party/deit/main.py \ 50 | --model deit_large_patch16_224 \ 51 | --batch-size 64 \ 52 | --data-path $DATA_PATH \ 53 | --output_dir output_large \ 54 | --no-model-ema 55 | elif [ "$model" == "large" -a "$amp_type" == "msamp" ]; then 56 | echo "run large DeiT with MS-AMP" 57 | python -m torch.distributed.launch \ 58 | --nproc_per_node=$GPU_NUM \ 59 | --use_env \ 60 | ../third_party/deit/main.py \ 61 | --model deit_large_patch16_224 \ 62 | --batch-size 64 \ 63 | --data-path $DATA_PATH \ 64 | --output_dir output_large_msamp \ 65 | --no-model-ema \ 66 | --enable-msamp \ 67 | --msamp-opt-level O2 68 | elif [ "$model" == "large" -a "$amp_type" == "te-fp8" ]; then 69 | echo "run large Deit with transformer engine fp8" 70 | python -m torch.distributed.launch \ 71 | --nproc_per_node=$GPU_NUM \ 72 | --use_env \ 73 | ../third_party/deit/main.py \ 74 | --model deit_large_patch16_224 \ 75 | --batch-size 64 \ 76 | --data-path $DATA_PATH \ 77 | --output_dir output_large_te \ 78 | --no-model-ema \ 79 | --enable-te-fp8 80 | else 81 | echo $USAGE 82 | exit 1 83 | fi 84 | 
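A minimal usage sketch for the launch script above (the ImageNet path is an assumption; substitute your own extracted dataset directory):
```bash
# Point DATA_PATH in run.sh at the extracted ImageNet directory, then launch
# the large DeiT model with MS-AMP on 8 GPUs.
sed -i 's|^DATA_PATH=.*|DATA_PATH=/data/imagenet|' run.sh
sh run.sh large msamp
```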
-------------------------------------------------------------------------------- /gpt3/Megatron-DeepSpeed.patch: -------------------------------------------------------------------------------- 1 | diff --git a/megatron/arguments.py b/megatron/arguments.py 2 | index bdd1745..4058c11 100644 3 | --- a/megatron/arguments.py 4 | +++ b/megatron/arguments.py 5 | @@ -46,6 +46,7 @@ def parse_args(extra_args_provider=None, defaults={}, 6 | parser = _add_memoryopt_args(parser) 7 | parser = _add_activation_checkpoint_args(parser) 8 | parser = _add_distillation_args(parser) 9 | + parser = _add_msamp_args(parser) 10 | 11 | # Custom arguments. 12 | if extra_args_provider is not None: 13 | @@ -936,3 +937,11 @@ def _add_distillation_args(parser): 14 | help='Directory containing a teacher model checkpoint.') 15 | 16 | return parser 17 | + 18 | + 19 | +def _add_msamp_args(parser): 20 | + group = parser.add_argument_group('MS-AMP', 'MS-AMP configurations') 21 | + group.add_argument('--msamp', action='store_true', help='Enable MS-AMP', default=False) 22 | + group.add_argument('--msamp-opt-level', type=str, default='O2', choices=['O1', 'O2', 'O3'], 23 | + help='MS-AMP optimization level') 24 | + return parser 25 | diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py 26 | index 32bb5fc..dd894a4 100644 27 | --- a/megatron/mpu/__init__.py 28 | +++ b/megatron/mpu/__init__.py 29 | @@ -23,6 +23,7 @@ from .initialize import is_unitialized 30 | from .initialize import destroy_model_parallel 31 | from .initialize import get_data_parallel_group 32 | from .initialize import get_data_parallel_rank 33 | +from .initialize import get_data_parallel_src_rank 34 | from .initialize import get_data_parallel_world_size 35 | from .initialize import get_embedding_group 36 | from .initialize import get_model_parallel_group 37 | diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py 38 | index c24d117..dda8269 100644 39 | --- a/megatron/mpu/initialize.py 40 | +++ b/megatron/mpu/initialize.py 41 | @@ -45,6 +45,10 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None 42 | # rank when broadcasting from the first or last pipeline stage 43 | _PIPELINE_GLOBAL_RANKS = None 44 | 45 | +# A list of global ranks for each data parallel group to ease calculation of the source 46 | +# rank when broadcasting weights from src to all other data parallel ranks 47 | +_DATA_PARALLEL_GLOBAL_RANKS = None 48 | + 49 | def is_unitialized(): 50 | """Useful for code segments that may be accessed with or without mpu initialization""" 51 | return _DATA_PARALLEL_GROUP is None 52 | @@ -105,8 +109,13 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, 53 | 54 | # Build the data-parallel groups. 55 | global _DATA_PARALLEL_GROUP 56 | + global _DATA_PARALLEL_GLOBAL_RANKS 57 | + 58 | assert _DATA_PARALLEL_GROUP is None, \ 59 | 'data parallel group is already initialized' 60 | + assert _DATA_PARALLEL_GLOBAL_RANKS is None, \ 61 | + 'data parallel global ranks is already initialized' 62 | + 63 | all_data_parallel_group_ranks = [] 64 | for i in range(pipeline_model_parallel_size): 65 | start_rank = i * num_pipeline_model_parallel_groups 66 | @@ -118,6 +127,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, 67 | group = torch.distributed.new_group(ranks) 68 | if rank in ranks: 69 | _DATA_PARALLEL_GROUP = group 70 | + _DATA_PARALLEL_GLOBAL_RANKS = ranks 71 | 72 | # Build the model-parallel groups. 
73 | global _MODEL_PARALLEL_GROUP 74 | @@ -198,6 +208,14 @@ def get_pipeline_model_parallel_group(): 75 | return _PIPELINE_MODEL_PARALLEL_GROUP 76 | 77 | 78 | +def get_data_parallel_src_rank(): 79 | + """Calculate the global rank corresponding to the first local rank 80 | + in the data parallel group.""" 81 | + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \ 82 | + "Data parallel group is not initialized" 83 | + return _DATA_PARALLEL_GLOBAL_RANKS[0] 84 | + 85 | + 86 | def get_data_parallel_group(): 87 | """Get the data parallel group the caller rank belongs to.""" 88 | assert _DATA_PARALLEL_GROUP is not None, \ 89 | @@ -363,7 +381,6 @@ def get_data_parallel_rank(): 90 | """Return my rank for the data parallel group.""" 91 | return torch.distributed.get_rank(group=get_data_parallel_group()) 92 | 93 | - 94 | def destroy_model_parallel(): 95 | """Set the groups to none.""" 96 | global _TENSOR_MODEL_PARALLEL_GROUP 97 | diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py 98 | index 5d168c6..3554270 100644 99 | --- a/megatron/mpu/layers.py 100 | +++ b/megatron/mpu/layers.py 101 | @@ -21,6 +21,7 @@ 102 | import math 103 | 104 | import torch 105 | +import torch.nn as nn 106 | import torch.nn.functional as F 107 | import torch.nn.init as init 108 | from torch.nn.parameter import Parameter 109 | @@ -254,36 +255,44 @@ class ColumnParallelLinear(torch.nn.Module): 110 | # Initialize weight. 111 | args = get_args() 112 | if args.use_cpu_initialization: 113 | - self.weight = Parameter(torch.empty(self.output_size_per_partition, 114 | + _weight = Parameter(torch.empty(self.output_size_per_partition, 115 | self.input_size, 116 | dtype=args.params_dtype)) 117 | self.master_weight = _initialize_affine_weight_cpu( 118 | - self.weight, self.output_size, self.input_size, 119 | + _weight, self.output_size, self.input_size, 120 | self.output_size_per_partition, 0, init_method, 121 | stride=stride, return_master_weight=keep_master_weight_for_test) 122 | else: 123 | - self.weight = Parameter(torch.empty( 124 | + _weight = Parameter(torch.empty( 125 | self.output_size_per_partition, self.input_size, 126 | device=get_accelerator().current_device_name(), dtype=args.params_dtype)) 127 | - _initialize_affine_weight_gpu(self.weight, init_method, 128 | + _initialize_affine_weight_gpu(_weight, init_method, 129 | partition_dim=0, stride=stride) 130 | 131 | if bias: 132 | if args.use_cpu_initialization: 133 | - self.bias = Parameter(torch.empty( 134 | + _bias = Parameter(torch.empty( 135 | self.output_size_per_partition, dtype=args.params_dtype)) 136 | else: 137 | - self.bias = Parameter(torch.empty( 138 | + _bias = Parameter(torch.empty( 139 | self.output_size_per_partition, 140 | device=get_accelerator().current_device_name(), 141 | dtype=args.params_dtype)) 142 | - set_tensor_model_parallel_attributes(self.bias, True, 0, stride) 143 | + set_tensor_model_parallel_attributes(_bias, True, 0, stride) 144 | # Always initialize bias to zero. 
145 | with torch.no_grad(): 146 | - self.bias.zero_() 147 | + _bias.zero_() 148 | else: 149 | self.register_parameter('bias', None) 150 | 151 | + self.linear = nn.Linear(out_features=self.output_size_per_partition, in_features=self.input_size, 152 | + bias=not self.skip_bias_add, dtype=args.params_dtype) 153 | + self.linear.weight = _weight 154 | + if not self.skip_bias_add: 155 | + self.linear.bias = _bias 156 | + else: 157 | + self.output_bias = _bias 158 | + 159 | 160 | 161 | def forward(self, input_): 162 | @@ -294,15 +303,14 @@ class ColumnParallelLinear(torch.nn.Module): 163 | input_parallel = copy_to_tensor_model_parallel_region(input_) 164 | 165 | # Matrix multiply. 166 | + output_parallel = self.linear(input_parallel) 167 | 168 | - bias = self.bias if not self.skip_bias_add else None 169 | - output_parallel = F.linear(input_parallel, self.weight, bias) 170 | if self.gather_output and not self.is_expert_without_slicing: 171 | # All-gather across the partitions. 172 | output = gather_from_tensor_model_parallel_region(output_parallel) 173 | else: 174 | - output = output_parallel 175 | - output_bias = self.bias if self.skip_bias_add else None 176 | + output = output_parallel 177 | + output_bias = self.output_bias if self.skip_bias_add else None 178 | return output, output_bias 179 | 180 | 181 | @@ -365,18 +373,18 @@ class RowParallelLinear(torch.nn.Module): 182 | # Initialize weight. 183 | args = get_args() 184 | if args.use_cpu_initialization: 185 | - self.weight = Parameter(torch.empty(self.output_size, 186 | + _weight = Parameter(torch.empty(self.output_size, 187 | self.input_size_per_partition, 188 | dtype=args.params_dtype)) 189 | self.master_weight = _initialize_affine_weight_cpu( 190 | - self.weight, self.output_size, self.input_size, 191 | + _weight, self.output_size, self.input_size, 192 | self.input_size_per_partition, 1, init_method, 193 | stride=stride, return_master_weight=keep_master_weight_for_test) 194 | else: 195 | - self.weight = Parameter(torch.empty( 196 | + _weight = Parameter(torch.empty( 197 | self.output_size, self.input_size_per_partition, 198 | device=get_accelerator().current_device_name(), dtype=args.params_dtype)) 199 | - _initialize_affine_weight_gpu(self.weight, init_method, 200 | + _initialize_affine_weight_gpu(_weight, init_method, 201 | partition_dim=1, stride=stride) 202 | if bias: 203 | if args.use_cpu_initialization: 204 | @@ -392,6 +400,10 @@ class RowParallelLinear(torch.nn.Module): 205 | else: 206 | self.register_parameter('bias', None) 207 | 208 | + assert skip_bias_add 209 | + self.linear = nn.Linear(out_features=self.output_size, in_features=self.input_size_per_partition, 210 | + bias=False, dtype=args.params_dtype) 211 | + self.linear.weight = _weight 212 | 213 | 214 | def forward(self, input_): 215 | @@ -401,7 +413,8 @@ class RowParallelLinear(torch.nn.Module): 216 | else: 217 | input_parallel = scatter_to_tensor_model_parallel_region(input_) 218 | # Matrix multiply. 219 | - output_parallel = F.linear(input_parallel, self.weight) 220 | + output_parallel = self.linear(input_parallel) 221 | + 222 | # All-reduce across all the partitions. 
223 | if self.is_expert_without_slicing: # non-expert only tensor-parallelism 224 | output_ = output_parallel 225 | diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py 226 | index 659d680..19fda6c 100644 227 | --- a/megatron/optimizer/__init__.py 228 | +++ b/megatron/optimizer/__init__.py 229 | @@ -12,6 +12,7 @@ 230 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 231 | # See the License for the specific language governing permissions and 232 | # limitations under the License. 233 | +import torch 234 | from deepspeed.accelerator import get_accelerator 235 | if get_accelerator().device_name() == 'cuda': 236 | from apex.optimizers import FusedAdam as Adam 237 | @@ -20,6 +21,7 @@ else: 238 | from torch.optim import Adam 239 | from torch.optim import SGD 240 | 241 | +from msamp.optim import LBAdamW 242 | 243 | from megatron import get_args 244 | from megatron.model import LayerNorm 245 | @@ -59,7 +61,7 @@ def get_megatron_optimizer(model): 246 | if args.create_moe_param_group: 247 | from deepspeed.moe.utils import is_moe_param, split_params_into_different_moe_groups_for_optimizer 248 | param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups) 249 | - 250 | + 251 | if args.cpu_optimizer: 252 | assert args.optimizer == 'adam', 'CPU offloading is for Adam' 253 | if args.cpu_torch_adam: 254 | @@ -71,20 +73,40 @@ def get_megatron_optimizer(model): 255 | lr=args.lr, 256 | weight_decay=args.weight_decay) 257 | else: 258 | - if args.optimizer == 'adam': 259 | - optimizer = Adam(param_groups, 260 | - lr=args.lr, 261 | - weight_decay=args.weight_decay, 262 | - betas=(args.adam_beta1, args.adam_beta2), 263 | - eps=args.adam_eps) 264 | - elif args.optimizer == 'sgd': 265 | - optimizer = SGD(param_groups, 266 | - lr=args.lr, 267 | - weight_decay=args.weight_decay, 268 | - momentum=args.sgd_momentum) 269 | + if args.msamp: 270 | + print(f"Using MS-AMP optimizer, opt_level is {args.msamp_opt_level}") 271 | + if args.msamp_opt_level == 'O2' or args.msamp_opt_level == 'O3': 272 | + exp_avg_dtype = torch.uint8 273 | + exp_avg_sq_dtype = torch.float16 274 | + elif args.msamp_opt_level == 'O1': 275 | + exp_avg_dtype = torch.float32 276 | + exp_avg_sq_dtype = torch.float32 277 | + else: 278 | + raise Exception(f'Unsupported msamp_opt_level: {args.msamp_opt_level}') 279 | + optimizer = LBAdamW(param_groups, 280 | + lr=args.lr, 281 | + weight_decay=args.weight_decay, 282 | + betas=(args.adam_beta1, args.adam_beta2), 283 | + eps=args.adam_eps, 284 | + exp_avg_dtype=exp_avg_dtype, 285 | + exp_avg_sq_dtype=exp_avg_sq_dtype, 286 | + tensor_scale=True, 287 | + ) 288 | else: 289 | - raise Exception('{} optimizer is not supported.'.format( 290 | - args.optimizer)) 291 | + if args.optimizer == 'adam': 292 | + optimizer = Adam(param_groups, 293 | + lr=args.lr, 294 | + weight_decay=args.weight_decay, 295 | + betas=(args.adam_beta1, args.adam_beta2), 296 | + eps=args.adam_eps) 297 | + elif args.optimizer == 'sgd': 298 | + optimizer = SGD(param_groups, 299 | + lr=args.lr, 300 | + weight_decay=args.weight_decay, 301 | + momentum=args.sgd_momentum) 302 | + else: 303 | + raise Exception('{} optimizer is not supported.'.format( 304 | + args.optimizer)) 305 | 306 | if args.deepspeed: 307 | return optimizer 308 | diff --git a/megatron/training.py b/megatron/training.py 309 | index 94133e7..7100f03 100644 310 | --- a/megatron/training.py 311 | +++ b/megatron/training.py 312 | @@ -60,6 +60,9 @@ from deepspeed.compression.compress import init_compression, 
redundancy_clean 313 | from megatron.model.transformer import ParallelTransformerLayer 314 | from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd 315 | 316 | +from msamp import deepspeed 317 | +from msamp.nn import LinearReplacer 318 | + 319 | def print_datetime(string): 320 | """Note that this call will sync across all ranks.""" 321 | torch.distributed.barrier() 322 | @@ -436,6 +439,14 @@ def setup_model_and_optimizer(model_provider_func, teacher=False, 323 | 324 | model = get_model(model_provider_func) 325 | 326 | + if args.msamp: 327 | + assert len(model) == 1 328 | + model[0] = LinearReplacer.replace(model[0], 329 | + src_rank=mpu.get_data_parallel_src_rank(), 330 | + group=mpu.get_data_parallel_group()) 331 | + print('after replaced with FP8Linear, model is: ') 332 | + print(model[0]) 333 | + 334 | # initialize the compression here 335 | student_global_steps = 0 336 | if args.kd or args.mos: 337 | -------------------------------------------------------------------------------- /gpt3/Megatron-LM.patch: -------------------------------------------------------------------------------- 1 | diff --git a/megatron/arguments.py b/megatron/arguments.py 2 | index ae42b83e..f427bc50 100644 3 | --- a/megatron/arguments.py 4 | +++ b/megatron/arguments.py 5 | @@ -38,6 +38,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): 6 | parser = _add_inference_args(parser) 7 | parser = _add_transformer_engine_args(parser) 8 | parser = _add_retro_args(parser) 9 | + parser = _add_msamp_args(parser) 10 | 11 | # Custom arguments. 12 | if extra_args_provider is not None: 13 | @@ -1306,3 +1307,10 @@ def _add_vision_args(parser): 14 | help='warmup teacher temperaure epochs') 15 | 16 | return parser 17 | + 18 | + 19 | +def _add_msamp_args(parser): 20 | + group = parser.add_argument_group(title="msamp") 21 | + group.add_argument('--msamp', action='store_true', default=False, 22 | + help='whether to enable msamp') 23 | + return parser 24 | \ No newline at end of file 25 | diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py 26 | index a86444cc..600f49d8 100644 27 | --- a/megatron/core/tensor_parallel/layers.py 28 | +++ b/megatron/core/tensor_parallel/layers.py 29 | @@ -439,7 +439,9 @@ def linear_with_grad_accumulation_and_async_allreduce( 30 | "maximum speedup" 31 | ) 32 | linear_with_grad_accumulation_and_async_allreduce.warned = True 33 | - 34 | + if hasattr(weight, '_scaling_metas'): 35 | + from msamp.megatron import FP8LinearWithGradAccumulationAndAsyncCommunication 36 | + return FP8LinearWithGradAccumulationAndAsyncCommunication.apply(*args) 37 | return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) 38 | 39 | 40 | @@ -513,14 +515,14 @@ class ColumnParallelLinear(torch.nn.Module): 41 | # Initialize weight. 
42 | if not skip_weight_param_allocation: 43 | if config.use_cpu_initialization: 44 | - self.weight = Parameter( 45 | + _weight = Parameter( 46 | torch.empty( 47 | self.output_size_per_partition, self.input_size, dtype=config.params_dtype 48 | ) 49 | ) 50 | if config.perform_initialization: 51 | self.master_weight = _initialize_affine_weight_cpu( 52 | - self.weight, 53 | + _weight, 54 | self.output_size, 55 | self.input_size, 56 | self.output_size_per_partition, 57 | @@ -530,7 +532,7 @@ class ColumnParallelLinear(torch.nn.Module): 58 | return_master_weight=keep_master_weight_for_test, 59 | ) 60 | else: 61 | - self.weight = Parameter( 62 | + _weight = Parameter( 63 | torch.empty( 64 | self.output_size_per_partition, 65 | self.input_size, 66 | @@ -540,10 +542,10 @@ class ColumnParallelLinear(torch.nn.Module): 67 | ) 68 | if config.perform_initialization: 69 | _initialize_affine_weight_gpu( 70 | - self.weight, init_method, partition_dim=0, stride=stride 71 | + _weight, init_method, partition_dim=0, stride=stride 72 | ) 73 | else: 74 | - self.weight = None 75 | + _weight = None 76 | 77 | if bias: 78 | if config.use_cpu_initialization: 79 | @@ -597,6 +599,17 @@ class ColumnParallelLinear(torch.nn.Module): 80 | ) 81 | 82 | self._forward_impl = linear_with_grad_accumulation_and_async_allreduce 83 | + self.linear = torch.nn.Linear(self.input_size, self.output_size_per_partition, bias=False, dtype=config.params_dtype) 84 | + assert self.linear.weight.shape == _weight.shape 85 | + self.linear.weight = _weight 86 | + 87 | + @property 88 | + def weight(self): 89 | + return self.linear.weight 90 | + 91 | + @weight.setter 92 | + def weight(self, value): 93 | + raise RuntimeError('Do not set weight.') 94 | 95 | def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): 96 | """Forward of ColumnParallelLinear 97 | @@ -722,14 +735,14 @@ class RowParallelLinear(torch.nn.Module): 98 | # we allocate the transpose. 99 | # Initialize weight. 
100 | if config.use_cpu_initialization: 101 | - self.weight = Parameter( 102 | + _weight = Parameter( 103 | torch.empty( 104 | self.output_size, self.input_size_per_partition, dtype=config.params_dtype 105 | ) 106 | ) 107 | if config.perform_initialization: 108 | self.master_weight = _initialize_affine_weight_cpu( 109 | - self.weight, 110 | + _weight, 111 | self.output_size, 112 | self.input_size, 113 | self.input_size_per_partition, 114 | @@ -740,7 +753,7 @@ class RowParallelLinear(torch.nn.Module): 115 | params_dtype=config.params_dtype, 116 | ) 117 | else: 118 | - self.weight = Parameter( 119 | + _weight = Parameter( 120 | torch.empty( 121 | self.output_size, 122 | self.input_size_per_partition, 123 | @@ -750,7 +763,7 @@ class RowParallelLinear(torch.nn.Module): 124 | ) 125 | if config.perform_initialization: 126 | _initialize_affine_weight_gpu( 127 | - self.weight, init_method, partition_dim=1, stride=stride 128 | + _weight, init_method, partition_dim=1, stride=stride 129 | ) 130 | if bias: 131 | if config.use_cpu_initialization: 132 | @@ -774,6 +787,18 @@ class RowParallelLinear(torch.nn.Module): 133 | 134 | self._forward_impl = linear_with_grad_accumulation_and_async_allreduce 135 | 136 | + self.linear = torch.nn.Linear(self.input_size_per_partition, self.output_size, bias=False, dtype=config.params_dtype) 137 | + assert self.linear.weight.shape == _weight.shape 138 | + self.linear.weight = _weight 139 | + 140 | + @property 141 | + def weight(self): 142 | + return self.linear.weight 143 | + 144 | + @weight.setter 145 | + def weight(self, value): 146 | + raise RuntimeError('Do not set weight.') 147 | + 148 | def forward(self, input_): 149 | """Forward of RowParallelLinear 150 | 151 | diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py 152 | index 7aca206c..1368434a 100644 153 | --- a/megatron/model/transformer.py 154 | +++ b/megatron/model/transformer.py 155 | @@ -1418,8 +1418,8 @@ class ParallelTransformer(MegatronModule): 156 | tp_group=mpu.get_tensor_model_parallel_group(), 157 | get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, 158 | fuse_wgrad_accumulation=config.gradient_accumulation_fusion, 159 | - apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, 160 | - attention_softmax_in_fp32=config.attention_softmax_in_fp32, 161 | + # apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, 162 | + # attention_softmax_in_fp32=config.attention_softmax_in_fp32, 163 | seq_length=args.seq_length, 164 | micro_batch_size=args.micro_batch_size, 165 | sequence_parallel=config.sequence_parallel, 166 | diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py 167 | index 484e9b32..e85984d7 100644 168 | --- a/megatron/optimizer/__init__.py 169 | +++ b/megatron/optimizer/__init__.py 170 | @@ -5,10 +5,12 @@ from apex.optimizers import FusedSGD as SGD 171 | 172 | from megatron import get_args 173 | 174 | -from .distrib_optimizer import DistributedOptimizer 175 | from .grad_scaler import ConstantGradScaler, DynamicGradScaler 176 | from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer 177 | 178 | +import torch 179 | +from msamp.optim import LBAdamW 180 | +from msamp.megatron import FP8DistributedOptimizer as DistributedOptimizer 181 | 182 | def get_param_groups(modules, 183 | no_weight_decay_cond, 184 | @@ -73,11 +75,21 @@ def get_megatron_optimizer(model, 185 | lr_mult) 186 | 187 | if args.optimizer == 'adam': 188 | - optimizer = Adam(param_groups, 189 | - lr=args.lr, 190 | - 
weight_decay=args.weight_decay, 191 | - betas=(args.adam_beta1, args.adam_beta2), 192 | - eps=args.adam_eps) 193 | + if args.msamp: 194 | + exp_avg_dtype, exp_avg_sq_dtype = torch.uint8, torch.float16 195 | + optimizer = LBAdamW(param_groups, 196 | + lr=args.lr, 197 | + weight_decay=args.weight_decay, 198 | + betas=(args.adam_beta1, args.adam_beta2), 199 | + eps=args.adam_eps, 200 | + exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype, 201 | + tensor_scale=True) 202 | + else: 203 | + optimizer = Adam(param_groups, 204 | + lr=args.lr, 205 | + weight_decay=args.weight_decay, 206 | + betas=(args.adam_beta1, args.adam_beta2), 207 | + eps=args.adam_eps) 208 | elif args.optimizer == 'sgd': 209 | optimizer = SGD(param_groups, 210 | lr=args.lr, 211 | diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py 212 | index da9cd70f..414fd887 100644 213 | --- a/megatron/optimizer/optimizer.py 214 | +++ b/megatron/optimizer/optimizer.py 215 | @@ -13,13 +13,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 216 | from megatron import get_timers 217 | from megatron import print_rank_0 218 | from megatron.core import mpu, tensor_parallel 219 | -from megatron.model import DistributedDataParallel as LocalDDP 220 | +# from megatron.model import DistributedDataParallel as LocalDDP 221 | from megatron.model import Float16Module 222 | from megatron.model.module import param_is_not_shared 223 | from megatron.utils import unwrap_model 224 | 225 | -from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 226 | +from .clip_grads import count_zeros_fp32 227 | 228 | +from msamp.megatron import clip_grad_norm_fp32 229 | +from msamp.megatron import FP8DistributedDataParallel as LocalDDP 230 | 231 | def _zero_grad_group_helper(group, set_to_none): 232 | """Zero out the gradient for a group of parameters. 233 | diff --git a/megatron/training.py b/megatron/training.py 234 | index b821ae7b..99a7fadb 100644 235 | --- a/megatron/training.py 236 | +++ b/megatron/training.py 237 | @@ -33,7 +33,7 @@ from megatron.initialize import initialize_megatron 238 | from megatron.initialize import write_args_to_tensorboard 239 | from megatron.initialize import set_jit_fusion_options 240 | from megatron.optimizer_param_scheduler import OptimizerParamScheduler 241 | -from megatron.model import DistributedDataParallel as LocalDDP 242 | +# from megatron.model import DistributedDataParallel as LocalDDP 243 | from megatron.utils import check_adlr_autoresume_termination 244 | from megatron.utils import unwrap_model 245 | from megatron.data.data_samplers import build_pretraining_data_loader 246 | @@ -42,6 +42,10 @@ from megatron.core.pipeline_parallel import get_forward_backward_func 247 | from megatron.utils import report_memory 248 | from megatron.model.vision.knn_monitor import compute_feature_bank 249 | 250 | +from msamp.nn import LinearReplacer 251 | +from msamp.common.dtype import Dtypes 252 | +from msamp.nn.state import model_state 253 | +from msamp.megatron import FP8DistributedDataParallel as LocalDDP 254 | 255 | def print_datetime(string): 256 | """Note that this call will sync across all ranks.""" 257 | @@ -216,6 +220,9 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap 258 | args = get_args() 259 | args.model_type = model_type 260 | 261 | + if args.msamp and args.transformer_impl == 'transformer_engine': 262 | + import msamp.te 263 | + 264 | # Build model. 
265 | if mpu.get_pipeline_model_parallel_world_size() > 1 and \ 266 | args.virtual_pipeline_model_parallel_size is not None: 267 | @@ -296,6 +303,20 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap 268 | if args.fp16 or args.bf16: 269 | model = [Float16Module(model_module, args) for model_module in model] 270 | 271 | + if args.msamp: 272 | + print_rank_0("msamp is enabled") 273 | + model_state.use_fp8_ddp = True 274 | + for i in range(len(model)): 275 | + if args.transformer_impl == 'transformer_engine': 276 | + from msamp.te import TeReplacer 277 | + model[i] = TeReplacer.replace(model[i]) 278 | + else: 279 | + model[i] = LinearReplacer.replace(model[i], Dtypes.kfloat16, 280 | + src_rank=mpu.get_data_parallel_src_rank(), 281 | + group=mpu.get_data_parallel_group()) 282 | + 283 | + print_rank_0(model[i]) 284 | + 285 | if wrap_with_ddp: 286 | if args.DDP_impl == 'torch': 287 | i = torch.cuda.current_device() 288 | @@ -629,6 +650,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, 289 | if iteration % args.log_interval == 0: 290 | elapsed_time = timers('interval-time').elapsed(barrier=True) 291 | elapsed_time_per_iteration = elapsed_time / total_iterations 292 | + 293 | + # Compute throughput. 294 | + samples_per_sec = batch_size / elapsed_time_per_iteration 295 | + 296 | + # Compute tflops. 297 | + seq_len = args.seq_length 298 | + hidden_size = args.hidden_size 299 | + num_layers = args.num_layers 300 | + vocab_size = args.padded_vocab_size 301 | + 302 | + checkpoint_activations_factor = 4 if args.recompute_granularity else 3 303 | + print_rank_last(f'checkpoint_activations_factor: {checkpoint_activations_factor}') 304 | + coefficient = 24 305 | + flops_per_iteration = (coefficient * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (1. + (seq_len / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) 306 | + tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12)) 307 | + 308 | if writer: 309 | if args.log_timers_to_tensorboard: 310 | writer.add_scalar('iteration-time', 311 | @@ -660,6 +697,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, 312 | total_loss_dict[skipped_iters_key]) 313 | log_string += ' number of nan iterations: {:3d} |'.format( 314 | total_loss_dict[nan_iters_key]) 315 | + 316 | + log_string += ' samples per second: {:.3f} |'.format(samples_per_sec) 317 | + log_string += ' TFLOPs: {:.2f} |'.format(tflops) 318 | + 319 | total_loss_dict[advanced_iters_key] = 0 320 | total_loss_dict[skipped_iters_key] = 0 321 | total_loss_dict[nan_iters_key] = 0 322 | -------------------------------------------------------------------------------- /gpt3/README.md: -------------------------------------------------------------------------------- 1 | # This is an example of GPT3 using MS-AMP 2 | We support both of Megatron-DeepSpeed and Megatron-LM. You can choose either of them to run GPT-3. 3 | 4 | ## Install dependencies 5 | You need to install depedencies before training GPT3. It is recommended to use venv for virtual environments, but it is not strictly necessary. 6 | ```bash 7 | pip install einops nltk wikiextractor 8 | ``` 9 | 10 | ## Data preparation 11 | Currently we haven't published the data we use in this example. But we provide a script of preprocessing Wikipedia data from scatch. Make sure you have more than 40GB space on your disk and it may take ~4 hours. You can also also use your own data. 
12 |
13 | ```bash
14 | bash prepare_wikipedia.sh
15 | ```
16 | After running the above command, a folder named data will be generated. The file structure should look like:
17 | ```
18 | $ tree data/
19 | data
20 | ├── gpt2-merges.txt
21 | ├── gpt2-vocab.json
22 | ├── wikipedia_text_document.bin
23 | └── wikipedia_text_document.idx
24 | ```
25 |
26 | ## Using Megatron-LM
27 |
28 | ### Apply patch to Megatron-LM
29 | We made a few changes to the official Megatron-LM and packaged it into a patch. You need to apply this patch to third_party/Megatron-LM.
30 | ```bash
31 | cd ../third_party/Megatron-LM
32 | git apply ../../gpt3/Megatron-LM.patch
33 | cd ../../gpt3
34 | ```
35 |
36 | Please note that if you are using a GPU that does not support FP8 computation, such as Nvidia A100, you need to delete `--fp8-hybrid` in pretrain_xx_megatron.sh first.
37 |
38 | ### Pretrain GPT3-345m
39 | Run the following command to train 345M GPT3 using bf16, Transformer-Engine and MS-AMP:
40 | ```bash
41 | bash pretrain_345m_megatron.sh bf16
42 | bash pretrain_345m_megatron.sh te
43 | bash pretrain_345m_megatron.sh msamp
44 | ```
45 |
46 | Please note that currently MS-AMP may not outperform Transformer-Engine for small models.
47 |
48 | ### Pretrain GPT3-6.7b
49 | Run the following command to train 6.7B GPT3 using bf16, Transformer-Engine and MS-AMP:
50 | ```bash
51 | bash pretrain_6b7_megatron.sh bf16
52 | bash pretrain_6b7_megatron.sh te
53 | bash pretrain_6b7_megatron.sh msamp
54 | ```
55 |
56 | ### Pretrain GPT3-13b
57 | Run the following command to train 13B GPT3 using bf16, Transformer-Engine and MS-AMP:
58 | ```bash
59 | bash pretrain_13b_megatron.sh bf16
60 | bash pretrain_13b_megatron.sh te
61 | bash pretrain_13b_megatron.sh msamp
62 | ```
63 | You may get an out-of-memory error when using Transformer-Engine since Transformer-Engine consumes more memory than bf16 and MS-AMP.
64 |
65 | ## Using Megatron-DeepSpeed
66 |
67 | ### Apply patch to Megatron-DeepSpeed
68 | We made a few changes to the official Megatron-DeepSpeed and packaged it into a patch. You need to apply this patch to third_party/Megatron-DeepSpeed.
69 | ```bash
70 | cd ../third_party/Megatron-DeepSpeed
71 | git apply ../../gpt3/Megatron-DeepSpeed.patch
72 | cd ../../gpt3
73 | ```
74 |
75 | ### Pretrain GPT3-345m
76 | Run the following command to train 345M GPT3 using fp16 and MS-AMP:
77 | ```bash
78 | bash pretrain_345m_megatron_ds.sh fp16
79 | bash pretrain_345m_megatron_ds.sh msamp
80 | ```
81 |
82 | ### Pretrain GPT3-13b
83 |
84 | Run the following command to train 13B GPT3 using bf16 and MS-AMP:
85 | ```bash
86 | bash pretrain_13b_megatron_ds.sh bf16
87 | bash pretrain_13b_megatron_ds.sh msamp
88 | ```
89 |
90 | ## Multi-node training
91 | If you want to train GPT-3 with Megatron-LM using multiple nodes, you need to:
92 | - Upload training data to a shared storage and mount the shared storage to each node.
93 | - Change MASTER_ADDR, NNODES, NODE_RANK in the script.
94 | - [optional] Set some environment variables related to RDMA before running the script. For example, if you are using [ND H100 v5](https://learn.microsoft.com/en-us/azure/virtual-machines/nd-h100-v5-series), you need to set these environment variables:
95 | ```bash
96 | export NCCL_IB_PCI_RELAXED_ORDERING=1
97 | export NCCL_SOCKET_IFNAME=eth0
98 | export CUDA_DEVICE_ORDER=PCI_BUS_ID
99 | export NCCL_NET_GDR_LEVEL=5
100 | export NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml
101 | export NCCL_DEBUG=WARN
102 | ```
103 | - Use a parallel ssh tool to start the script on all nodes, for example with the sketch below.
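A minimal launch sketch with `pdsh` (the host file, repository path, and per-node NODE_RANK handling are illustrative assumptions, not part of the provided scripts):
```bash
# Run the MS-AMP 13B script on every node listed in hostfile.txt.
# Each node must already have MASTER_ADDR, NNODES and its own NODE_RANK
# edited in the script, as described above.
pdsh -R ssh -w ^hostfile.txt "cd /path/to/this/repo/gpt3 && bash pretrain_13b_megatron.sh msamp"
```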
104 | -------------------------------------------------------------------------------- /gpt3/prepare_wikipedia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Microsoft Corporation - All rights reserved 4 | # Licensed under the MIT License 5 | 6 | set -e 7 | 8 | mkdir -p data 9 | cd data 10 | 11 | echo "start to download Wikipedia dump" 12 | wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 13 | 14 | echo "download completed, start to extract json files" 15 | python -m wikiextractor.WikiExtractor --json enwiki-latest-pages-articles.xml.bz2 16 | rm -rf enwiki-latest-pages-articles.xml.bz2 17 | 18 | echo "extract completed, start to merge json files" 19 | ouput_json="wiki_all.json" 20 | 21 | find text/ -type f -print0 | 22 | while IFS= read -r -d '' line; do 23 | filename=$(echo "$line" | rev | cut -d'/' -f 1 | rev) 24 | subfilename=$(echo "$line" | rev | cut -d'/' -f 2 | rev) 25 | prefix="${subfilename}_${filename}" 26 | new_name=$(echo "$line") 27 | echo "Procesing $prefix, $filename, $new_name" 28 | cat $new_name >> $ouput_json 29 | done 30 | rm -rf text/ 31 | 32 | echo "merge completed, start to preprocess" 33 | wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json 34 | wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt 35 | python ../../third_party/Megatron-DeepSpeed/tools/preprocess_data.py \ 36 | --input $ouput_json \ 37 | --output-prefix wikipedia \ 38 | --vocab gpt2-vocab.json \ 39 | --dataset-impl mmap \ 40 | --tokenizer-type GPT2BPETokenizer \ 41 | --merge-file gpt2-merges.txt \ 42 | --append-eod \ 43 | --workers 70 44 | 45 | rm -rf $ouput_json 46 | 47 | cd ../ 48 | -------------------------------------------------------------------------------- /gpt3/pretrain_13b_megatron.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs the "13B" parameter model 4 | # GPT-13B: 40 layers, 5120 hidden size, 40 attention heads 5 | 6 | set -e 7 | 8 | USAGE="usage: bash pretrain_13b_megatron.sh [bf16|te|msamp]" 9 | 10 | if [ "$#" -ne 1 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | FP_TYPE=$1 16 | 17 | export CUDA_DEVICE_MAX_CONNECTIONS=1 18 | GPUS_PER_NODE=8 19 | # Change for multinode config 20 | MASTER_ADDR=localhost 21 | MASTER_PORT=6000 22 | NNODES=1 23 | NODE_RANK=0 24 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 25 | 26 | VOCAB_FILE=$PWD/data/gpt2-vocab.json 27 | MERGE_FILE=$PWD/data/gpt2-merges.txt 28 | DATA_PATH=$PWD/data/wikipedia_text_document 29 | 30 | DISTRIBUTED_ARGS=" 31 | --nproc_per_node $GPUS_PER_NODE \ 32 | --nnodes $NNODES \ 33 | --node_rank $NODE_RANK \ 34 | --master_addr $MASTER_ADDR \ 35 | --master_port $MASTER_PORT 36 | " 37 | 38 | GPT_ARGS=" 39 | --tensor-model-parallel-size 2 \ 40 | --pipeline-model-parallel-size 1 \ 41 | --distributed-backend nccl \ 42 | --no-query-key-layer-scaling \ 43 | --seed 43 \ 44 | --num-layers 40 \ 45 | --hidden-size 5120 \ 46 | --num-attention-heads 40 \ 47 | --seq-length 2048 \ 48 | --max-position-embeddings 2048 \ 49 | --train-samples 146484375 \ 50 | --lr-decay-samples 131835938 \ 51 | --lr-warmup-samples 4096000 \ 52 | --lr 2.0e-4 \ 53 | --min-lr 2.0e-5 \ 54 | --lr-decay-style cosine \ 55 | --micro-batch-size 1 \ 56 | --global-batch-size 1280 \ 57 | --clip-grad 1.0 \ 58 | --weight-decay 0.1 \ 59 | --attention-dropout 0.0 \ 60 | --hidden-dropout 0.0 \ 61 | --optimizer adam \ 62 | --adam-beta1 0.9 \ 63 | --adam-beta2 0.95 \ 64 
| --init-method-std 0.0099 \ 65 | --num-workers 1 \ 66 | --bf16 \ 67 | --sequence-parallel \ 68 | --use-flash-attn \ 69 | --no-gradient-accumulation-fusion \ 70 | --use-distributed-optimizer 71 | " 72 | 73 | DATA_ARGS=" 74 | --data-path $DATA_PATH \ 75 | --vocab-file $VOCAB_FILE \ 76 | --merge-file $MERGE_FILE \ 77 | --data-impl mmap \ 78 | --split 949,50,1 79 | " 80 | 81 | OUTPUT_ARGS=" 82 | --log-interval 1 \ 83 | --save-interval 2000 \ 84 | --eval-interval 200 \ 85 | --eval-iters 7 86 | " 87 | 88 | if [ "$FP_TYPE" = "bf16" ]; then 89 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_13b_bf16 90 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 91 | $GPT_ARGS \ 92 | $DATA_ARGS \ 93 | $OUTPUT_ARGS \ 94 | --save $CHECKPOINT_PATH \ 95 | --load $CHECKPOINT_PATH 96 | elif [ "$FP_TYPE" = "te" ]; then 97 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_13b_te 98 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 99 | $GPT_ARGS \ 100 | $DATA_ARGS \ 101 | $OUTPUT_ARGS \ 102 | --fp8-hybrid \ 103 | --transformer-impl transformer_engine \ 104 | --save $CHECKPOINT_PATH \ 105 | --load $CHECKPOINT_PATH 106 | elif [ "$FP_TYPE" = "msamp" ]; then 107 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_13b_msamp 108 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 109 | $GPT_ARGS \ 110 | $DATA_ARGS \ 111 | $OUTPUT_ARGS \ 112 | --fp8-hybrid \ 113 | --transformer-impl transformer_engine \ 114 | --msamp \ 115 | --save $CHECKPOINT_PATH \ 116 | --load $CHECKPOINT_PATH 117 | else 118 | echo $USAGE 119 | exit 1 120 | fi 121 | -------------------------------------------------------------------------------- /gpt3/pretrain_13b_megatron_ds.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Runs the "13B" parameter model 4 | # GPT-13B: 40 layers, 5120 hidden size, 40 attention heads 5 | 6 | set -e 7 | 8 | USAGE="usage: bash pretrain_13b_megatron_ds.sh [bf16|msamp]" 9 | 10 | if [ "$#" -ne 1 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | FP_TYPE=$1 16 | NODE_RANK=0 17 | NNODES=1 18 | GPUS_PER_NODE=8 19 | MASTER_ADDR=127.0.0.1 20 | MASTER_PORT=6001 21 | 22 | DATA_PATH=$PWD/data/wikipedia_text_document 23 | DATA_PATH=$PWD/data/wikipedia_text_document 24 | DATASET="1.0 ${DATA_PATH}" 25 | BS=4 26 | PP=1 27 | TP=2 28 | CLIP_GRAD=1.0 29 | GLOBAL_BATCH_SIZE=1280 30 | LOG_INTERVAL=1 31 | ZERO_STAGE=1 32 | VOCAB_FILE=$PWD/data/gpt2-vocab.json 33 | MERGE_FILE=$PWD/data/gpt2-merges.txt 34 | 35 | GPT_ARGS=" 36 | --tensor-model-parallel-size $TP \ 37 | --pipeline-model-parallel-size $PP \ 38 | --distributed-backend nccl \ 39 | --no-query-key-layer-scaling \ 40 | --seed 43 \ 41 | --num-layers 40 \ 42 | --hidden-size 5120 \ 43 | --num-attention-heads 40 \ 44 | --seq-length 2048 \ 45 | --max-position-embeddings 2048 \ 46 | --train-samples 146484375 \ 47 | --lr-decay-samples 131835938 \ 48 | --lr-warmup-samples 4096000 \ 49 | --lr 2.0e-4 \ 50 | --min-lr 2.0e-5 \ 51 | --lr-decay-style cosine \ 52 | --micro-batch-size $BS \ 53 | --global-batch-size $GLOBAL_BATCH_SIZE \ 54 | --clip-grad $CLIP_GRAD \ 55 | --weight-decay 0.1 \ 56 | --attention-dropout 0.0 \ 57 | --hidden-dropout 0.0 \ 58 | --optimizer adam \ 59 | --adam-beta1 0.9 \ 60 | --adam-beta2 0.95 \ 61 | --init-method-std 0.0099 \ 62 | --num-workers 1 \ 63 | --bf16 \ 64 | --checkpoint-activations \ 65 | " 66 | 67 | DATA_ARGS=" 68 | --data-path $DATASET \ 69 | --vocab-file $VOCAB_FILE \ 70 | --merge-file $MERGE_FILE \ 71 | --data-impl mmap \ 72 | --split 949,50,1 73 | " 74 | 75 | 
OUTPUT_ARGS=" 76 | --log-interval $LOG_INTERVAL \ 77 | --eval-iters 7 \ 78 | --eval-interval 200 \ 79 | --save-interval 2000 \ 80 | " 81 | 82 | DISTRIBUTED_ARGS=" 83 | --nproc_per_node $GPUS_PER_NODE \ 84 | --nnodes $NNODES \ 85 | --node_rank $NODE_RANK \ 86 | --master_addr $MASTER_ADDR \ 87 | --master_port $MASTER_PORT \ 88 | " 89 | 90 | config_json="./ds_config.json" 91 | 92 | cat <$config_json 93 | { 94 | "train_micro_batch_size_per_gpu": $BS, 95 | "train_batch_size": $GLOBAL_BATCH_SIZE, 96 | "gradient_clipping": $CLIP_GRAD, 97 | "zero_optimization": { 98 | "stage": $ZERO_STAGE 99 | }, 100 | "bf16": { 101 | "enabled": true 102 | }, 103 | "steps_per_print": $LOG_INTERVAL 104 | } 105 | EOT 106 | 107 | DEEPSPEED_ARGS=" \ 108 | --deepspeed \ 109 | --deepspeed_config ${config_json} \ 110 | --zero-stage ${ZERO_STAGE} \ 111 | --deepspeed-activation-checkpointing \ 112 | " 113 | 114 | export CUDA_DEVICE_MAX_CONNECTIONS=1 115 | 116 | if [ "$FP_TYPE" = "bf16" ]; then 117 | echo "run 13b GPT3 with bf16" 118 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_13b_bf16 119 | torchrun $DISTRIBUTED_ARGS \ 120 | ../third_party/Megatron-DeepSpeed/pretrain_gpt.py \ 121 | $GPT_ARGS \ 122 | $DATA_ARGS \ 123 | $OUTPUT_ARGS \ 124 | --save $CHECKPOINT_PATH \ 125 | --load $CHECKPOINT_PATH \ 126 | $DEEPSPEED_ARGS 127 | elif [ "$FP_TYPE" = "msamp" ]; then 128 | echo "run 13b GPT3 with msamp" 129 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_13b_msamp 130 | torchrun $DISTRIBUTED_ARGS \ 131 | ../third_party/Megatron-DeepSpeed/pretrain_gpt.py \ 132 | $GPT_ARGS \ 133 | $DATA_ARGS \ 134 | $OUTPUT_ARGS \ 135 | --save $CHECKPOINT_PATH \ 136 | --load $CHECKPOINT_PATH \ 137 | --msamp \ 138 | --msamp-opt-level O3 \ 139 | $DEEPSPEED_ARGS 140 | else 141 | echo $USAGE 142 | exit 1 143 | fi 144 | -------------------------------------------------------------------------------- /gpt3/pretrain_345m_megatron.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs the "345M" parameter model 4 | 5 | set -e 6 | 7 | USAGE="usage: bash pretrain_345m_megatron.sh [bf16|te|msamp]" 8 | 9 | if [ "$#" -ne 1 ]; then 10 | echo $USAGE 11 | exit 1 12 | fi 13 | 14 | FP_TYPE=$1 15 | 16 | export CUDA_DEVICE_MAX_CONNECTIONS=1 17 | GPUS_PER_NODE=8 18 | # Change for multinode config 19 | MASTER_ADDR=localhost 20 | MASTER_PORT=6000 21 | NNODES=1 22 | NODE_RANK=0 23 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 24 | 25 | VOCAB_FILE=$PWD/data/gpt2-vocab.json 26 | MERGE_FILE=$PWD/data/gpt2-merges.txt 27 | DATA_PATH=$PWD/data/wikipedia_text_document 28 | 29 | DISTRIBUTED_ARGS=" 30 | --nproc_per_node $GPUS_PER_NODE \ 31 | --nnodes $NNODES \ 32 | --node_rank $NODE_RANK \ 33 | --master_addr $MASTER_ADDR \ 34 | --master_port $MASTER_PORT 35 | " 36 | 37 | GPT_ARGS=" 38 | --num-layers 24 \ 39 | --hidden-size 1024 \ 40 | --num-attention-heads 16 \ 41 | --seq-length 1024 \ 42 | --max-position-embeddings 1024 \ 43 | --micro-batch-size 8 \ 44 | --global-batch-size 64 \ 45 | --lr 0.00015 \ 46 | --train-iters 500000 \ 47 | --lr-decay-iters 320000 \ 48 | --lr-decay-style cosine \ 49 | --min-lr 1.0e-5 \ 50 | --weight-decay 1e-2 \ 51 | --lr-warmup-fraction .01 \ 52 | --clip-grad 1.0 \ 53 | --bf16 \ 54 | --use-flash-attn \ 55 | --no-gradient-accumulation-fusion \ 56 | --use-distributed-optimizer 57 | " 58 | 59 | DATA_ARGS=" 60 | --data-path $DATA_PATH \ 61 | --vocab-file $VOCAB_FILE \ 62 | --merge-file $MERGE_FILE \ 63 | --data-impl mmap \ 64 | --split 949,50,1 65 | " 66 | 67 | OUTPUT_ARGS=" 68 | --log-interval 100 \ 69 | 
--save-interval 2000 \ 70 | --eval-interval 1000 \ 71 | --eval-iters 10 72 | " 73 | 74 | if [ "$FP_TYPE" = "bf16" ]; then 75 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_345m_bf16 76 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 77 | $GPT_ARGS \ 78 | $DATA_ARGS \ 79 | $OUTPUT_ARGS \ 80 | --save $CHECKPOINT_PATH \ 81 | --load $CHECKPOINT_PATH 82 | 83 | elif [ "$FP_TYPE" = "te" ]; then 84 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_345m_te 85 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 86 | $GPT_ARGS \ 87 | $DATA_ARGS \ 88 | $OUTPUT_ARGS \ 89 | --fp8-hybrid \ 90 | --transformer-impl transformer_engine \ 91 | --save $CHECKPOINT_PATH \ 92 | --load $CHECKPOINT_PATH 93 | 94 | elif [ "$FP_TYPE" = "msamp" ]; then 95 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_345m_msamp 96 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 97 | $GPT_ARGS \ 98 | $DATA_ARGS \ 99 | $OUTPUT_ARGS \ 100 | --fp8-hybrid \ 101 | --transformer-impl transformer_engine \ 102 | --msamp \ 103 | --save $CHECKPOINT_PATH \ 104 | --load $CHECKPOINT_PATH 105 | else 106 | echo $USAGE 107 | exit 1 108 | fi 109 | -------------------------------------------------------------------------------- /gpt3/pretrain_345m_megatron_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs the "345M" parameter model 4 | 5 | set -e 6 | 7 | USAGE="usage: bash pretrain_345m_megatron_ds.sh [fp16|msamp]" 8 | 9 | if [ "$#" -ne 1 ]; then 10 | echo $USAGE 11 | exit 1 12 | fi 13 | 14 | FP_TYPE=$1 15 | VOCAB_FILE=$PWD/data/gpt2-vocab.json 16 | MERGE_FILE=$PWD/data/gpt2-merges.txt 17 | DATA_PATH=$PWD/data/wikipedia_text_document 18 | BS=4 19 | GLOBAL_BS=8 20 | CLIP_GRAD=1.0 21 | LOG_INTERVAL=100 22 | 23 | GPT_ARGS=" 24 | --num-layers 24 \ 25 | --hidden-size 1024 \ 26 | --num-attention-heads 16 \ 27 | --seq-length 1024 \ 28 | --max-position-embeddings 1024 \ 29 | --micro-batch-size $BS \ 30 | --global-batch-size $GLOBAL_BS \ 31 | --lr 0.00015 \ 32 | --train-iters 500000 \ 33 | --lr-decay-iters 320000 \ 34 | --lr-decay-style cosine \ 35 | --min-lr 1.0e-5 \ 36 | --weight-decay 1e-2 \ 37 | --lr-warmup-fraction .01 \ 38 | --clip-grad $CLIP_GRAD \ 39 | --fp16 40 | " 41 | 42 | DATA_ARGS=" 43 | --data-path $DATA_PATH \ 44 | --vocab-file $VOCAB_FILE \ 45 | --merge-file $MERGE_FILE \ 46 | --data-impl mmap \ 47 | --split 949,50,1 48 | " 49 | 50 | OUTPUT_ARGS=" 51 | --log-interval $LOG_INTERVAL \ 52 | --save-interval 10000 \ 53 | --eval-interval 1000 \ 54 | --eval-iters 10 55 | " 56 | 57 | config_json="./ds_config.json" 58 | cat <<EOT > $config_json 59 | { 60 | "train_micro_batch_size_per_gpu": $BS, 61 | "train_batch_size": $GLOBAL_BS, 62 | "gradient_clipping": $CLIP_GRAD, 63 | "fp16": { 64 | "enabled": true 65 | }, 66 | "steps_per_print": 100 67 | } 68 | EOT 69 | 70 | export CUDA_DEVICE_MAX_CONNECTIONS=1 71 | if [ "$FP_TYPE" = "fp16" ]; then 72 | echo "run 345M gpt3 with fp16" 73 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_345m_fp16 74 | torchrun ../third_party/Megatron-DeepSpeed/pretrain_gpt.py \ 75 | $GPT_ARGS \ 76 | $DATA_ARGS \ 77 | $OUTPUT_ARGS \ 78 | --save $CHECKPOINT_PATH \ 79 | --load $CHECKPOINT_PATH 80 | 81 | elif [ "$FP_TYPE" = "msamp" ]; then 82 | echo "run 345M gpt3 with MS-AMP" 83 | 84 | DEEPSPEED_ARGS=" \ 85 | --deepspeed \ 86 | --deepspeed_config ${config_json} \ 87 | " 88 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_345m_msamp 89 | 90 | torchrun ../third_party/Megatron-DeepSpeed/pretrain_gpt.py \ 91 | $GPT_ARGS \ 92 | $DATA_ARGS \ 93 |
$OUTPUT_ARGS \ 94 | --save $CHECKPOINT_PATH \ 95 | --load $CHECKPOINT_PATH \ 96 | --msamp \ 97 | --msamp-opt-level O2 \ 98 | $DEEPSPEED_ARGS 99 | else 100 | echo $USAGE 101 | exit 1 102 | fi 103 | -------------------------------------------------------------------------------- /gpt3/pretrain_6b7_megatron.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs the "6.7B" parameter model 4 | # GPT-6.7B: 32 layers, 4096 hidden size, 32 attention heads 5 | 6 | set -e 7 | 8 | USAGE="usage: bash pretrain_6b7_megatron.sh [bf16|te|msamp]" 9 | 10 | if [ "$#" -ne 1 ]; then 11 | echo $USAGE 12 | exit 1 13 | fi 14 | 15 | FP_TYPE=$1 16 | 17 | export CUDA_DEVICE_MAX_CONNECTIONS=1 18 | GPUS_PER_NODE=8 19 | # Change for multinode config 20 | MASTER_ADDR=localhost 21 | MASTER_PORT=6000 22 | NNODES=1 23 | NODE_RANK=0 24 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 25 | 26 | VOCAB_FILE=$PWD/data/gpt2-vocab.json 27 | MERGE_FILE=$PWD/data/gpt2-merges.txt 28 | DATA_PATH=$PWD/data/wikipedia_text_document 29 | 30 | DISTRIBUTED_ARGS=" 31 | --nproc_per_node $GPUS_PER_NODE \ 32 | --nnodes $NNODES \ 33 | --node_rank $NODE_RANK \ 34 | --master_addr $MASTER_ADDR \ 35 | --master_port $MASTER_PORT 36 | " 37 | 38 | GPT_ARGS=" 39 | --tensor-model-parallel-size 1 \ 40 | --pipeline-model-parallel-size 1 \ 41 | --distributed-backend nccl \ 42 | --no-query-key-layer-scaling \ 43 | --seed 43 \ 44 | --num-layers 32 \ 45 | --hidden-size 4096 \ 46 | --num-attention-heads 32 \ 47 | --seq-length 2048 \ 48 | --max-position-embeddings 2048 \ 49 | --train-samples 48828125 \ 50 | --lr-decay-samples 43945312 \ 51 | --lr-warmup-samples 2048000 \ 52 | --lr 3.0e-4 \ 53 | --min-lr 3.0e-5 \ 54 | --lr-decay-style cosine \ 55 | --micro-batch-size 1 \ 56 | --global-batch-size 2048 \ 57 | --clip-grad 1.0 \ 58 | --weight-decay 0.1 \ 59 | --attention-dropout 0.0 \ 60 | --hidden-dropout 0.0 \ 61 | --optimizer adam \ 62 | --adam-beta1 0.9 \ 63 | --adam-beta2 0.95 \ 64 | --init-method-std 0.0099 \ 65 | --num-workers 1 \ 66 | --bf16 \ 67 | --sequence-parallel \ 68 | --use-flash-attn \ 69 | --no-gradient-accumulation-fusion \ 70 | --use-distributed-optimizer 71 | " 72 | 73 | 74 | DATA_ARGS=" 75 | --data-path $DATA_PATH \ 76 | --vocab-file $VOCAB_FILE \ 77 | --merge-file $MERGE_FILE \ 78 | --data-impl mmap \ 79 | --split 949,50,1 80 | " 81 | 82 | OUTPUT_ARGS=" 83 | --log-interval 1 \ 84 | --save-interval 1000 \ 85 | --eval-interval 500 \ 86 | --eval-iters 7 87 | " 88 | 89 | if [ "$FP_TYPE" = "bf16" ]; then 90 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_6b7_bf16 91 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 92 | $GPT_ARGS \ 93 | $DATA_ARGS \ 94 | $OUTPUT_ARGS \ 95 | --save $CHECKPOINT_PATH \ 96 | --load $CHECKPOINT_PATH 97 | elif [ "$FP_TYPE" = "te" ]; then 98 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_6b7_te 99 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 100 | $GPT_ARGS \ 101 | $DATA_ARGS \ 102 | $OUTPUT_ARGS \ 103 | --save $CHECKPOINT_PATH \ 104 | --load $CHECKPOINT_PATH \ 105 | --fp8-hybrid \ 106 | --transformer-impl transformer_engine 107 | elif [ "$FP_TYPE" = "msamp" ]; then 108 | CHECKPOINT_PATH=$PWD/checkpoints/gpt_6b7_msamp 109 | torchrun $DISTRIBUTED_ARGS ../third_party/Megatron-LM/pretrain_gpt.py \ 110 | $GPT_ARGS \ 111 | $DATA_ARGS \ 112 | $OUTPUT_ARGS \ 113 | --save $CHECKPOINT_PATH \ 114 | --load $CHECKPOINT_PATH \ 115 | --fp8-hybrid \ 116 | --transformer-impl transformer_engine \ 117 | --msamp 118 | else 119 | echo $USAGE 120 | exit 1
121 | fi 122 | --------------------------------------------------------------------------------