├── README.md ├── examples ├── cifar_vit │ ├── README.md │ ├── data.py │ └── train.py ├── glue_bert │ ├── README.md │ ├── data.py │ └── finetune.py └── imagenet_vit │ ├── loss.py │ ├── data.py │ └── train.py ├── benchmark └── gpt │ ├── data.py │ ├── callback.py │ ├── README.md │ ├── train.py │ └── model.py └── .gitignore /README.md: -------------------------------------------------------------------------------- 1 | # ColossalAI Pytorch-lightning strategy 2 | 3 | ## Usage 4 | 5 | ```python 6 | from pytorch_lightning.strategies.colossalai import ColossalAIStrategy 7 | 8 | trainer = Trainer(..., strategy=ColossalAIStrategy()) 9 | ``` -------------------------------------------------------------------------------- /examples/cifar_vit/README.md: -------------------------------------------------------------------------------- 1 | # Train ViT on CIFAR-10 2 | 3 | ## Run 4 | 5 | Prepare dataset: 6 | ```shell 7 | export DATA=/path/to/cifar10 8 | ``` 9 | 10 | Torch DDP: 11 | ```shell 12 | python train.py --np 4 13 | ``` 14 | 15 | ColossalAI ZeRO-DP: 16 | ```shell 17 | python train.py --np 4 --colossal 18 | ``` 19 | 20 | ## Result 21 | 22 | | Strategy | GPUs | Validation loss | Validation Accuracy | 23 | | --- | --- | --- | --- | 24 | | ddp | 4 | 0.650 | 0.834 | 25 | | colossalai | 4 | 0.651 | 0.833 | -------------------------------------------------------------------------------- /examples/glue_bert/README.md: -------------------------------------------------------------------------------- 1 | # Finetune BERT on GLUE 2 | 3 | ## Run 4 | 5 | Torch DDP: 6 | ```shell 7 | python finetune.py 8 | ``` 9 | 10 | ColossalAI ZeRO-DP: 11 | ```shell 12 | python finetune.py --colossal 13 | ``` 14 | 15 | ## Result 16 | 17 | | Strategy | GPUs | Validation loss | Validation Accuracy | Validation F1 | 18 | | --- | --- | --- | --- | --- | 19 | | ddp | 1 | 0.365 | 0.853 | 0.896 | 20 | | colossalai | 1 | 0.358 | 0.863 | 0.902 | 21 | | ddp | 2 | 0.375 | 0.848 | 0.893 | 22 | | colossalai | 2 | 0.368 | 0.848 | 0.892 | 23 | | ddp | 4 | 0.415 | 0.850 | 0.894 | 24 | | colossalai | 4 | 0.389 | 0.848 | 0.892 | -------------------------------------------------------------------------------- /benchmark/gpt/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ['RandomDataloader'] 4 | 5 | 6 | def get_data(batch_size, seq_len, vocab_size): 7 | input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) 8 | attention_mask = torch.ones_like(input_ids) 9 | return input_ids, attention_mask 10 | 11 | 12 | class RandomDataloader: 13 | def __init__(self, n_steps: int, batch_size: int, seq_len: int = 1024, vocab_size: int = 50257) -> None: 14 | self.n_steps = n_steps 15 | self.cur_step = 0 16 | self.batch_size = batch_size 17 | self.seq_len = seq_len 18 | self.vocab_size = vocab_size 19 | 20 | def __iter__(self): 21 | self.cur_step = 0 22 | return self 23 | 24 | def __next__(self): 25 | if self.cur_step >= self.n_steps: 26 | raise StopIteration 27 | self.cur_step += 1 28 | return get_data(self.batch_size, self.seq_len, self.vocab_size) 29 | 30 | def __len__(self): 31 | return self.n_steps 32 | -------------------------------------------------------------------------------- /benchmark/gpt/callback.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | import torch 3 | import torch.distributed as dist 4 | from pytorch_lightning.callbacks import Callback 5 | 6 | 7 | def 
print_rank_0(*args, **kwargs): 8 | if dist.get_rank() == 0: 9 | print(*args, **kwargs) 10 | dist.barrier() 11 | 12 | 13 | def get_cpu_mem(): 14 | return psutil.Process().memory_info().rss 15 | 16 | 17 | class MemoryMonitor(Callback): 18 | def __init__(self) -> None: 19 | super().__init__() 20 | self.max_cpu_mem = 0 21 | 22 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: 23 | self.max_cpu_mem = max(get_cpu_mem(), self.max_cpu_mem) 24 | 25 | def on_fit_start(self, trainer, pl_module) -> None: 26 | max_cuda_mem = torch.cuda.max_memory_allocated() 27 | cuda_mem = torch.cuda.memory_allocated() 28 | print_rank_0(f'CPU memory before training: {get_cpu_mem()/1024**2:.3f} MB') 29 | print_rank_0(f'CUDA memory before training: {cuda_mem/1024**2:.3f} MB') 30 | print_rank_0(f'Max CUDA memory before training: {max_cuda_mem/1024**2:.3f} MB') 31 | 32 | def on_fit_end(self, trainer, pl_module) -> None: 33 | max_cuda_mem = torch.cuda.max_memory_allocated() 34 | print_rank_0(f'Max CPU memory: {self.max_cpu_mem/1024**2:.3f} MB') 35 | print_rank_0(f'Max CUDA memory: {max_cuda_mem/1024**2:.3f} MB') 36 | -------------------------------------------------------------------------------- /benchmark/gpt/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark Results 2 | 3 | ## Model scaling 4 | 5 | RAM: 500 GB 6 | 7 | GPU: A100 (40 GB) 8 | 9 | We fix the batch size per GPU to 1. 10 | 11 | | Strategy | GPUs | Max model size (B params) | Max CUDA memory allocated (MB) | Step time (sec) | 12 | | --- | --- | --- | --- | --- | 13 | | deepspeed (zero3 offload) | 1 | 18 | 5699.500 | 39.32 | 14 | | colossalai (auto) | 1 | 24 | 36483.311 | 69.54 | 15 | | deepspeed (zero3) | 8 | 12 | 29751.203 | 9.06 | 16 | | colossalai (cuda) | 8 | 12 | 24504.032 | 7.07 | 17 | 18 | Commands: 19 | 20 | DeepSpeed: 21 | ```shell 22 | python train.py --epochs 1 --steps_per_epoch 3 --model gpt2_18B --strategy deepspeed --offload 23 | python train.py --epochs 1 --steps_per_epoch 3 --model gpt2_12B --strategy deepspeed --np 8 24 | ``` 25 | 26 | ColossalAI: 27 | ```shell 28 | python train.py --epochs 1 --steps_per_epoch 3 --model gpt2_24B --strategy colossal --placement_policy auto --opt_gpu_margin_rat 0.9 29 | python train.py --epochs 1 --steps_per_epoch 3 --model gpt2_12B --strategy colossal --np 8 30 | ``` 31 | 32 | ## Small model comparison 33 | 34 | We collected results using GPT2-XL (~1.6B), which is small enough to train with DDP (AMP). 35 | 36 | All experiments are run on 4x A100 (40 GB). The global batch size reported below is the per-GPU `--batch_size` multiplied by the number of processes (`--np`). 
37 | 38 | | Strategy | Global batch size | Global throughput (samples/sec) | Max CUDA memory allocated (MB) | 39 | | --- | --- | --- | --- | 40 | | ddp (AMP) | 24 | 11.76 | 37905.422 | 41 | | deepspeed (zero3) | 160 | 18.18 | 25360.968 | 42 | | colossalai (cuda) | 160 | 19.36 | 24003.394 | 43 | 44 | Commands: 45 | 46 | DDP: 47 | ```shell 48 | python train.py --np 4 --batch_size 6 49 | ``` 50 | 51 | DeepSpeed: 52 | ```shell 53 | python train.py --np 4 --strategy deepspeed --batch_size 40 54 | ``` 55 | 56 | ColossalAI: 57 | ```shell 58 | python train.py --np 4 --strategy colossal --batch_size 40 59 | ``` -------------------------------------------------------------------------------- /examples/cifar_vit/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | from torchvision.datasets import CIFAR10 6 | from typing import Optional, Callable, Tuple, Any 7 | 8 | 9 | class Cifar10Dataset(CIFAR10): 10 | def __init__(self, root: str, train: bool = True, pre_process: Optional[Callable] = None, transform=None, download: bool = False) -> None: 11 | super().__init__(root, train, transform, None, download) 12 | self.data = torch.tensor(self.data.transpose((0, 3, 1, 2))) 13 | self.targets = torch.tensor(self.targets) 14 | if pre_process: 15 | self.data = pre_process(self.data) 16 | 17 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 18 | img, target = self.data[index], self.targets[index] 19 | img = img.cuda() # move the image to GPU so the tensor transforms below run on device 20 | if self.transform is not None: 21 | img = self.transform(img) 22 | return img, target.cuda() 23 | 24 | 25 | def to_tensor(t): 26 | return t.float().div(255) 27 | 28 | 29 | def build_data(batch_size: int): 30 | root = os.environ['DATA'] 31 | transform_train = transforms.Compose([ 32 | to_tensor, 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 36 | ]) 37 | 38 | transform_test = transforms.Compose([ 39 | to_tensor, 40 | transforms.Resize(32), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 42 | ]) 43 | train_set = Cifar10Dataset(root, train=True, transform=transform_train) 44 | valid_set = Cifar10Dataset(root, train=False, transform=transform_test) 45 | train_loader = DataLoader(train_set, batch_size, shuffle=True, drop_last=True) 46 | valid_loader = DataLoader(valid_set, batch_size, shuffle=False, drop_last=False) 47 | return train_loader, valid_loader 48 | -------------------------------------------------------------------------------- /examples/imagenet_vit/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MulticlassBCEWithLogitsLoss(nn.Module): 7 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 8 | """ 9 | 10 | def __init__(self, smoothing=0.0): 11 | super().__init__() 12 | assert 0. 
<= smoothing < 1.0 13 | self.smoothing = smoothing 14 | 15 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 16 | assert x.shape[0] == target.shape[0] 17 | batch_size = x.size(0) 18 | if target.shape != x.shape: 19 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 20 | num_classes = x.shape[-1] 21 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 22 | off_value = self.smoothing / num_classes 23 | on_value = 1. - self.smoothing + off_value 24 | target = target.long().view(-1, 1) 25 | target = torch.full( 26 | (batch_size, num_classes), 27 | off_value, 28 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 29 | return F.binary_cross_entropy_with_logits(x, target, reduction='sum') / batch_size 30 | 31 | 32 | class LabelSmoothLoss(nn.Module): 33 | 34 | def __init__(self, smoothing=0.0): 35 | super(LabelSmoothLoss, self).__init__() 36 | self.smoothing = smoothing 37 | 38 | def forward(self, input, target): 39 | log_prob = F.log_softmax(input, dim=-1) 40 | weight = input.new_ones(input.size()) * \ 41 | self.smoothing / (input.size(-1) - 1.) 42 | weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing)) 43 | loss = (-weight * log_prob).sum(dim=-1).mean() 44 | return loss 45 | 46 | 47 | class MixupLoss(nn.Module): 48 | def __init__(self, loss_fn): 49 | super().__init__() 50 | self.loss_fn = loss_fn 51 | 52 | def forward(self, inputs, targets_a, targets_b, lam): 53 | return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b) 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/.build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # IDE 133 | .idea/ 134 | .vscode/ 135 | 136 | # macos 137 | .DS_Store 138 | #data/ 139 | 140 | docs/.build 141 | 142 | # pytorch checkpoint 143 | *.pt -------------------------------------------------------------------------------- /benchmark/gpt/train.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import argparse 3 | from data import RandomDataloader 4 | from model import GPTLitModule, get_optimizer 5 | from callback import MemoryMonitor 6 | from pytorch_lightning.strategies.ddp import DDPStrategy 7 | from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy 8 | from pytorch_lightning.strategies.colossalai import ColossalAIStrategy 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--epochs', type=int, default=2) 13 | parser.add_argument('--steps_per_epoch', type=int, default=4) 14 | parser.add_argument('--batch_size', type=int, default=1) 15 | parser.add_argument('--lr', type=float, default=1e-3) 16 | parser.add_argument('--model', default='gpt2_xl') 17 | parser.add_argument('--np', type=int, default=1) 18 | parser.add_argument('--no_activation_ckpt', action='store_true', default=False) 19 | parser.add_argument('--opt_nvme_offload_frac', type=float, default=0.0) 20 | parser.add_argument('--opt_nvme_offload_dir', default='./offload') 21 | parser.add_argument('--seq_len', type=int, default=1024) 22 | parser.add_argument('--placement_policy', default='cuda') 23 | parser.add_argument('--opt_gpu_margin_rat', type=float, default=0.0) 24 | parser.add_argument('--cuda_mem_frac', type=float, default=1.0) 25 | parser.add_argument('--strategy', default='ddp', choices=['ddp', 'colossal', 'deepspeed']) 26 | parser.add_argument('--offload', action='store_true', default=False) 27 | args = parser.parse_args() 28 | train_dataloader = RandomDataloader(args.steps_per_epoch, args.batch_size, args.seq_len) 29 | optimizer_cfg = {'lr': args.lr} 30 | if args.strategy == 'ddp': 31 | trainer_cfg = { 32 | 'accelerator': 'gpu', 33 | 'precision': 16, 34 | 'strategy': DDPStrategy(static_graph=True) 35 | } 36 | elif args.strategy == 'colossal': 37 | trainer_cfg = { 38 | 'accelerator': 'gpu', 39 | 'precision': 16, 40 | 'strategy': ColossalAIStrategy( 41 | placement_policy=args.placement_policy, 42 | gpu_margin_mem_ratio=args.opt_gpu_margin_rat, 43 | initial_scale=32 44 | ) 45 | } 46 | optimizer_cfg['nvme_offload_dir'] = args.opt_nvme_offload_dir 47 | optimizer_cfg['nvme_offload_fraction'] = args.opt_nvme_offload_frac 48 | elif args.strategy == 'deepspeed': 49 | trainer_cfg = { 50 | 'accelerator': 'gpu', 51 | 
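# NOTE: the DeepSpeedStrategy below runs ZeRO stage 3; passing --offload moves both parameters and
# optimizer state to CPU (get_optimizer then selects DeepSpeedCPUAdam instead of FusedAdam),
# which is the 'zero3 offload' configuration reported in the README.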
'precision': 16, 52 | 'strategy': DeepSpeedStrategy( 53 | stage=3, 54 | offload_parameters=args.offload, 55 | offload_optimizer=args.offload, 56 | initial_scale_power=5 57 | ) 58 | } 59 | optimizer_cfg['offload'] = args.offload 60 | opt_init_fn = get_optimizer(args.strategy, **optimizer_cfg) 61 | model = GPTLitModule(args.model, opt_init_fn, checkpoint=not args.no_activation_ckpt, 62 | cuda_mem_fraction=args.cuda_mem_frac) 63 | trainer = pl.Trainer( 64 | max_epochs=args.epochs, 65 | devices=args.np, 66 | enable_checkpointing=False, 67 | callbacks=[MemoryMonitor()], 68 | **trainer_cfg 69 | ) 70 | trainer.fit(model, train_dataloader) 71 | -------------------------------------------------------------------------------- /examples/imagenet_vit/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from pytorch_lightning import LightningDataModule 4 | from titans.dataloader.imagenet import DaliDataloaderWithRandAug, DaliDataloader 5 | 6 | 7 | class DaliImagenetDataModule(LightningDataModule): 8 | def __init__(self, 9 | root: str, 10 | batch_size=128, 11 | num_threads=4, 12 | resize=256, 13 | crop=224, 14 | prefetch=2, 15 | cuda=True, 16 | gpu_aug=True, 17 | mixup_alpha=0.0, 18 | randaug_magnitude=10, 19 | randaug_num_layers=0) -> None: 20 | super().__init__() 21 | self.train_files = sorted(glob.glob(os.path.join(root, 'train/*'))) 22 | self.train_idx_files = sorted(glob.glob(os.path.join(root, 'idx_files/train/*'))) 23 | self.valid_files = sorted(glob.glob(os.path.join(root, 'validation/*'))) 24 | self.valid_idx_files = sorted(glob.glob(os.path.join(root, 'idx_files/validation/*'))) 25 | self.dataloader_args = dict( 26 | batch_size=batch_size, 27 | num_threads=num_threads, 28 | resize=resize, 29 | crop=crop, 30 | prefetch=prefetch, 31 | cuda=cuda, 32 | gpu_aug=gpu_aug 33 | ) 34 | self.randaug_args = dict( 35 | mixup_alpha=mixup_alpha, 36 | randaug_magnitude=randaug_magnitude, 37 | randaug_num_layers=randaug_num_layers, 38 | ) 39 | 40 | @property 41 | def use_randaug(self): 42 | return self.randaug_args['randaug_magnitude'] > 0 or self.randaug_args['mixup_alpha'] > 0.0 43 | 44 | def train_dataloader(self): 45 | if self.use_randaug: 46 | return DaliDataloaderWithRandAug(self.train_files, 47 | self.train_idx_files, 48 | shard_id=self.trainer.global_rank, 49 | num_shards=self.trainer.world_size, 50 | training=True, 51 | **self.dataloader_args, 52 | **self.randaug_args) 53 | return DaliDataloader(self.train_files, 54 | self.train_idx_files, 55 | shard_id=self.trainer.global_rank, 56 | num_shards=self.trainer.world_size, 57 | training=True, 58 | **self.dataloader_args) 59 | 60 | def val_dataloader(self): 61 | if self.use_randaug: 62 | return DaliDataloaderWithRandAug(self.valid_files, 63 | self.valid_idx_files, 64 | shard_id=self.trainer.global_rank, 65 | num_shards=self.trainer.world_size, 66 | training=False, 67 | **self.dataloader_args, 68 | **self.randaug_args) 69 | return DaliDataloader(self.valid_files, 70 | self.valid_idx_files, 71 | shard_id=self.trainer.global_rank, 72 | num_shards=self.trainer.world_size, 73 | training=False, 74 | **self.dataloader_args) 75 | -------------------------------------------------------------------------------- /examples/cifar_vit/train.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import argparse 4 | import pytorch_lightning as pl 5 | from torchmetrics import Accuracy 6 | from 
pytorch_lightning import seed_everything 7 | from pytorch_lightning.callbacks import TQDMProgressBar 8 | from data import build_data 9 | from timm.models.vision_transformer import _create_vision_transformer, _cfg 10 | from colossalai.nn.optimizer import HybridAdam 11 | from colossalai.nn.lr_scheduler import LinearWarmupLR 12 | from pytorch_lightning.strategies.colossalai import ColossalAIStrategy 13 | 14 | 15 | def vit_cifar(**kwargs): 16 | pretrained_cfg = _cfg(num_classes=10, input_size=(3, 32, 32), crop_pct=1.0) 17 | model_kwargs = dict(patch_size=4, embed_dim=512, depth=6, num_heads=8, 18 | drop_rate=0.1, mlp_ratio=1.0, **kwargs) 19 | model = _create_vision_transformer('vit_cifar', pretrained_cfg=pretrained_cfg, **model_kwargs) 20 | return model 21 | 22 | 23 | class Cifar10PlModule(pl.LightningModule): 24 | def __init__(self, warmup_epochs: int, lr: float = 1e-3, adjust_lr_by_step: bool = False) -> None: 25 | super().__init__() 26 | self.warmup_epochs = warmup_epochs 27 | self.lr = lr 28 | self.adjust_lr_by_step = adjust_lr_by_step 29 | self.criterion = nn.CrossEntropyLoss() 30 | self.model = None 31 | self.top1_accuracy = Accuracy(top_k=1) 32 | 33 | def configure_sharded_model(self) -> None: 34 | self.model = vit_cifar() 35 | 36 | def configure_optimizers(self): 37 | opt = HybridAdam(self.model.parameters(), self.lr) 38 | total_steps = self.trainer.max_epochs 39 | warmup_steps = self.warmup_epochs 40 | interval = 'epoch' 41 | if self.adjust_lr_by_step: 42 | total_steps = self.trainer.estimated_stepping_batches # total optimizer steps over the whole run 43 | warmup_steps *= self.trainer.estimated_stepping_batches // self.trainer.max_epochs # warmup epochs converted to steps 44 | interval = 'step' 45 | scheduler = LinearWarmupLR(opt, total_steps, warmup_steps) 46 | return {'optimizer': opt, 'lr_scheduler': {'scheduler': scheduler, 'interval': interval}} 47 | 48 | def training_step(self, batch, batch_idx): 49 | data, labels = batch 50 | logits = self.model(data) 51 | loss = self.criterion(logits, labels) 52 | self.log('train_loss', loss) 53 | return loss 54 | 55 | def validation_step(self, batch, batch_idx): 56 | data, labels = batch 57 | logits = self.model(data) 58 | loss = self.criterion(logits, labels) 59 | self.top1_accuracy(F.log_softmax(logits, dim=1), labels) 60 | self.log_dict({'valid_loss': loss, 'valid_acc': self.top1_accuracy}, 61 | prog_bar=True, on_epoch=True, sync_dist=True) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--epochs', type=int, default=60) 67 | parser.add_argument('--warmup', type=int, default=5) 68 | parser.add_argument('--batch_size', type=int, default=512) 69 | parser.add_argument('--np', type=int, default=1) 70 | parser.add_argument('--lr', type=float, default=1e-3) 71 | parser.add_argument('--seed', type=int, default=42) 72 | parser.add_argument('--adjust_lr_by_step', action='store_true', default=False) 73 | parser.add_argument('--colossal', action='store_true', default=False) 74 | args = parser.parse_args() 75 | assert args.batch_size % args.np == 0 76 | seed_everything(args.seed) 77 | batch_size_per_dp = args.batch_size // args.np 78 | trainer_cfg = { 79 | 'strategy': 'ddp' 80 | } 81 | if args.colossal: 82 | trainer_cfg = { 83 | 'precision': 16, 84 | 'strategy': ColossalAIStrategy( 85 | use_chunk=True, 86 | enable_distributed_storage=True, 87 | placement_policy='cuda', 88 | initial_scale=32 89 | ) 90 | } 91 | trainer = pl.Trainer(accelerator='gpu', devices=args.np, 92 | max_epochs=args.epochs, callbacks=[TQDMProgressBar()], 93 | **trainer_cfg) 94 | trainloader, testloader = 
build_data(batch_size_per_dp) 95 | model = Cifar10PlModule(args.warmup, args.lr, args.adjust_lr_by_step) 96 | trainer.fit(model, trainloader, testloader) 97 | -------------------------------------------------------------------------------- /examples/imagenet_vit/train.py: -------------------------------------------------------------------------------- 1 | from data import DaliImagenetDataModule 2 | from loss import MulticlassBCEWithLogitsLoss, MixupLoss, LabelSmoothLoss 3 | from timm.models.vision_transformer import vit_base_patch16_224, vit_small_patch16_224 4 | from pytorch_lightning import seed_everything 5 | from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR 6 | from colossalai.nn.optimizer import HybridAdam 7 | from torchmetrics import Accuracy 8 | from pytorch_lightning.callbacks import LearningRateMonitor 9 | import pytorch_lightning as pl 10 | import torch 11 | import os 12 | from argparse import ArgumentParser 13 | 14 | 15 | class ViTModule(pl.LightningModule): 16 | def __init__(self, mixup: bool = True, schedule_lr_by_step: bool = False) -> None: 17 | super().__init__() 18 | self.mixup = mixup 19 | self.schedule_lr_by_step = schedule_lr_by_step 20 | self.criterion = MulticlassBCEWithLogitsLoss(0.1) 21 | if mixup: 22 | self.criterion = MixupLoss(self.criterion) 23 | self.acc = Accuracy() 24 | 25 | def configure_sharded_model(self) -> None: 26 | self.model = vit_small_patch16_224(drop_rate=0.1, weight_init='jax', num_classes=100) 27 | 28 | def configure_optimizers(self): 29 | opt = HybridAdam(self.model.parameters(), lr=3e-3, weight_decay=0.3) 30 | if self.schedule_lr_by_step: 31 | num_steps = self.trainer.estimated_stepping_batches 32 | interval = 'step' 33 | else: 34 | num_steps = self.trainer.max_epochs 35 | interval = 'epoch' 36 | num_warmups = int(num_steps * 0.1) 37 | scheduler = CosineAnnealingWarmupLR(opt, num_steps, num_warmups) 38 | scheduler = {'scheduler': scheduler, 'interval': interval} 39 | return [opt], [scheduler] 40 | 41 | def training_step(self, batch, batch_idx): 42 | inputs, targets = batch 43 | logits = self.model(inputs) 44 | if self.mixup: 45 | loss = self.criterion(logits, **targets) 46 | else: 47 | loss = self.criterion(logits, targets) 48 | return loss 49 | 50 | def validation_step(self, batch, batch_idx): 51 | inputs, targets = batch 52 | if self.mixup: 53 | if targets['targets_a'].ndim == 0: 54 | targets['targets_a'] = targets['targets_a'].unsqueeze(0) 55 | if targets['targets_b'].ndim == 0: 56 | targets['targets_b'] = targets['targets_b'].unsqueeze(0) 57 | logits = self.model(inputs) 58 | if self.mixup: 59 | loss = self.criterion(logits, **targets) 60 | else: 61 | loss = self.criterion(logits, targets) 62 | preds = torch.argmax(logits, 1) 63 | if self.mixup: 64 | self.acc(preds, targets['targets_a']) 65 | else: 66 | self.acc(preds, targets) 67 | self.log_dict({'valid_loss': loss, 'valid_acc': self.acc}, prog_bar=True, on_epoch=True, sync_dist=True) 68 | 69 | def on_load_checkpoint(self, checkpoint) -> None: 70 | if not hasattr(self, 'model'): 71 | self.configure_sharded_model() 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = ArgumentParser() 76 | parser.add_argument('--np', type=int, default=1) 77 | parser.add_argument('--batch_size', type=int, default=4096) 78 | parser.add_argument('--ckpt', default=None) 79 | parser.add_argument('--schedule_lr_by_step', action='store_true', default=False) 80 | args = parser.parse_args() 81 | assert args.batch_size % args.np == 0 82 | local_batch_size = args.batch_size // args.np 83 | print(f'local 
batch size: {local_batch_size}, total batch size: {args.batch_size}') 84 | seed_everything(42) 85 | dm = DaliImagenetDataModule(os.environ['DATA'], local_batch_size, 86 | mixup_alpha=0.2, randaug_magnitude=10, randaug_num_layers=2, 87 | gpu_aug=True) 88 | model = ViTModule(mixup=True, schedule_lr_by_step=args.schedule_lr_by_step) 89 | trainer_cfg = { 90 | 'strategy': 'ddp', 91 | } 92 | trainer = pl.Trainer( 93 | max_epochs=300, 94 | devices=args.np, 95 | gradient_clip_val=1.0, 96 | resume_from_checkpoint=args.ckpt, 97 | callbacks=[LearningRateMonitor('step')], 98 | **trainer_cfg 99 | ) 100 | trainer.fit(model, datamodule=dm) 101 | -------------------------------------------------------------------------------- /examples/glue_bert/data.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from pytorch_lightning import LightningDataModule 3 | from torch.utils.data import DataLoader 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class GLUEDataModule(LightningDataModule): 8 | 9 | task_text_field_map = { 10 | "cola": ["sentence"], 11 | "sst2": ["sentence"], 12 | "mrpc": ["sentence1", "sentence2"], 13 | "qqp": ["question1", "question2"], 14 | "stsb": ["sentence1", "sentence2"], 15 | "mnli": ["premise", "hypothesis"], 16 | "qnli": ["question", "sentence"], 17 | "rte": ["sentence1", "sentence2"], 18 | "wnli": ["sentence1", "sentence2"], 19 | "ax": ["premise", "hypothesis"], 20 | } 21 | 22 | glue_task_num_labels = { 23 | "cola": 2, 24 | "sst2": 2, 25 | "mrpc": 2, 26 | "qqp": 2, 27 | "stsb": 1, 28 | "mnli": 3, 29 | "qnli": 2, 30 | "rte": 2, 31 | "wnli": 2, 32 | "ax": 3, 33 | } 34 | 35 | loader_columns = [ 36 | "datasets_idx", 37 | "input_ids", 38 | "token_type_ids", 39 | "attention_mask", 40 | "start_positions", 41 | "end_positions", 42 | "labels", 43 | ] 44 | 45 | def __init__( 46 | self, 47 | model_name_or_path: str, 48 | task_name: str = "mrpc", 49 | max_seq_length: int = 128, 50 | train_batch_size: int = 32, 51 | eval_batch_size: int = 32, 52 | **kwargs, 53 | ): 54 | super().__init__() 55 | self.model_name_or_path = model_name_or_path 56 | self.task_name = task_name 57 | self.max_seq_length = max_seq_length 58 | self.train_batch_size = train_batch_size 59 | self.eval_batch_size = eval_batch_size 60 | 61 | self.text_fields = self.task_text_field_map[task_name] 62 | self.num_labels = self.glue_task_num_labels[task_name] 63 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) 64 | 65 | def setup(self, stage: str): 66 | self.dataset = datasets.load_dataset("glue", self.task_name) 67 | 68 | for split in self.dataset.keys(): 69 | self.dataset[split] = self.dataset[split].map( 70 | self.convert_to_features, 71 | batched=True, 72 | remove_columns=["label"], 73 | ) 74 | self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] 75 | self.dataset[split].set_format(type="torch", columns=self.columns) 76 | 77 | self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] 78 | 79 | def prepare_data(self): 80 | datasets.load_dataset("glue", self.task_name) 81 | AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) 82 | 83 | def train_dataloader(self): 84 | return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True) 85 | 86 | def val_dataloader(self): 87 | if len(self.eval_splits) == 1: 88 | return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size) 89 | elif len(self.eval_splits) > 1: 90 | return 
[DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits] 91 | 92 | def test_dataloader(self): 93 | if len(self.eval_splits) == 1: 94 | return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size) 95 | elif len(self.eval_splits) > 1: 96 | return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits] 97 | 98 | def convert_to_features(self, example_batch, indices=None): 99 | 100 | # Either encode single sentence or sentence pairs 101 | if len(self.text_fields) > 1: 102 | texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) 103 | else: 104 | texts_or_text_pairs = example_batch[self.text_fields[0]] 105 | 106 | # Tokenize the text/text pairs 107 | features = self.tokenizer.batch_encode_plus( 108 | texts_or_text_pairs, max_length=self.max_seq_length, pad_to_max_length=True, truncation=True 109 | ) 110 | 111 | # Rename label to labels to make it easier to pass to model forward 112 | features["labels"] = example_batch["label"] 113 | 114 | return features 115 | -------------------------------------------------------------------------------- /examples/glue_bert/finetune.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | import datasets 5 | import torch 6 | from pytorch_lightning import LightningModule, Trainer, seed_everything 7 | from transformers import ( 8 | AutoConfig, 9 | get_linear_schedule_with_warmup, 10 | BertForSequenceClassification 11 | ) 12 | from data import GLUEDataModule 13 | from argparse import ArgumentParser 14 | from colossalai.nn.optimizer import HybridAdam 15 | from pytorch_lightning.strategies.colossalai import ColossalAIStrategy 16 | 17 | 18 | class GLUETransformer(LightningModule): 19 | def __init__( 20 | self, 21 | model_name_or_path: str, 22 | num_labels: int, 23 | task_name: str, 24 | learning_rate: float = 2e-5, 25 | adam_epsilon: float = 1e-8, 26 | warmup_fraction: float = 0.0, 27 | weight_decay: float = 0.0, 28 | train_batch_size: int = 32, 29 | eval_batch_size: int = 32, 30 | eval_splits: Optional[list] = None, 31 | **kwargs, 32 | ): 33 | super().__init__() 34 | 35 | self.save_hyperparameters() 36 | 37 | self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels) 38 | self.metric = datasets.load_metric( 39 | "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S") 40 | ) 41 | 42 | def configure_sharded_model(self) -> None: 43 | self.model = BertForSequenceClassification.from_pretrained( 44 | self.hparams.model_name_or_path, config=self.config) 45 | 46 | def forward(self, **inputs): 47 | return self.model(**inputs) 48 | 49 | def training_step(self, batch, batch_idx): 50 | outputs = self(**batch) 51 | loss = outputs[0] 52 | return loss 53 | 54 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 55 | outputs = self(**batch) 56 | val_loss, logits = outputs[:2] 57 | 58 | if self.hparams.num_labels > 1: 59 | preds = torch.argmax(logits, axis=1) 60 | elif self.hparams.num_labels == 1: 61 | preds = logits.squeeze() 62 | 63 | labels = batch["labels"] 64 | 65 | return {"loss": val_loss, "preds": preds, "labels": labels} 66 | 67 | def validation_epoch_end(self, outputs): 68 | if self.hparams.task_name == "mnli": 69 | for i, output in enumerate(outputs): 70 | # matched or mismatched 71 | split = self.hparams.eval_splits[i].split("_")[-1] 72 | preds = torch.cat([x["preds"] for x 
in output]).detach().cpu().numpy() 73 | labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy() 74 | loss = torch.stack([x["loss"] for x in output]).mean() 75 | self.log(f"val_loss_{split}", loss, prog_bar=True) 76 | split_metrics = { 77 | f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items() 78 | } 79 | self.log_dict(split_metrics, prog_bar=True) 80 | return loss 81 | 82 | preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy() 83 | labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy() 84 | loss = torch.stack([x["loss"] for x in outputs]).mean() 85 | self.log("val_loss", loss, prog_bar=True, sync_dist=True) 86 | self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True, sync_dist=True) 87 | 88 | def configure_optimizers(self): 89 | """Prepare optimizer and schedule (linear warmup and decay)""" 90 | model = self.model 91 | no_decay = ["bias", "LayerNorm.weight"] 92 | optimizer_grouped_parameters = [ 93 | { 94 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 95 | "weight_decay": self.hparams.weight_decay, 96 | }, 97 | { 98 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 99 | "weight_decay": 0.0, 100 | }, 101 | ] 102 | optimizer = HybridAdam(optimizer_grouped_parameters, lr=self.hparams.learning_rate, 103 | eps=self.hparams.adam_epsilon) 104 | num_warmup_steps = int(self.trainer.estimated_stepping_batches * self.hparams.warmup_fraction) 105 | scheduler = get_linear_schedule_with_warmup( 106 | optimizer, 107 | num_warmup_steps=num_warmup_steps, 108 | num_training_steps=self.trainer.estimated_stepping_batches, 109 | ) 110 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 111 | return [optimizer], [scheduler] 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = ArgumentParser() 116 | parser.add_argument('--task', default='mrpc') 117 | parser.add_argument('--np', type=int, default=1) 118 | parser.add_argument('--epochs', type=int, default=3) 119 | parser.add_argument('--batch_size', type=int, default=32) 120 | parser.add_argument('--lr', type=float, default=2.4e-5) 121 | parser.add_argument('--weight_decay', type=float, default=0.01) 122 | parser.add_argument('--warmup_fraction', type=float, default=0.1) 123 | parser.add_argument('--colossal', action='store_true', default=False) 124 | args = parser.parse_args() 125 | assert args.batch_size % args.np == 0 126 | local_batch_size = args.batch_size // args.np 127 | seed_everything(42) 128 | model_name = 'bert-base-uncased' 129 | dm = GLUEDataModule( 130 | model_name_or_path=model_name, 131 | task_name=args.task, 132 | train_batch_size=local_batch_size, 133 | eval_batch_size=local_batch_size 134 | ) 135 | dm.setup("fit") 136 | model = GLUETransformer( 137 | model_name_or_path=model_name, 138 | num_labels=dm.num_labels, 139 | eval_splits=dm.eval_splits, 140 | task_name=dm.task_name, 141 | train_batch_size=local_batch_size, 142 | eval_batch_size=local_batch_size, 143 | learning_rate=args.lr, 144 | weight_decay=args.weight_decay, 145 | warmup_fraction=args.warmup_fraction, 146 | ) 147 | trainer_cfg = { 148 | 'strategy': 'ddp', 149 | } 150 | if args.colossal: 151 | trainer_cfg = { 152 | 'precision': 16, 153 | 'strategy': ColossalAIStrategy( 154 | use_chunk=True, 155 | enable_distributed_storage=True, 156 | placement_policy='cuda', 157 | initial_scale=32 158 | ) 159 | } 160 | trainer = Trainer( 161 | accelerator="gpu", 
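# one process per GPU; local_batch_size = args.batch_size // args.np above, so the
# effective global batch size stays at --batch_size regardless of --np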
162 | devices=args.np, 163 | max_epochs=args.epochs, 164 | **trainer_cfg 165 | ) 166 | trainer.fit(model, datamodule=dm) 167 | -------------------------------------------------------------------------------- /benchmark/gpt/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import pytorch_lightning as pl 3 | from transformers import GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel 4 | from colossalai.nn.optimizer import HybridAdam 5 | from colossalai.utils import colo_set_process_memory_fraction 6 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 7 | from torch.optim import Adam, Optimizer 8 | from functools import partial 9 | from typing import Callable, Iterable 10 | from contextlib import contextmanager 11 | __all__ = ['GPTLitModule', 'get_optimizer'] 12 | 13 | 14 | @contextmanager 15 | def no_init_weights(): 16 | def dummy_fn(*args): 17 | return 18 | try: 19 | old_init_weights = GPT2PreTrainedModel._init_weights 20 | GPT2PreTrainedModel._init_weights = dummy_fn 21 | yield 22 | finally: 23 | GPT2PreTrainedModel._init_weights = old_init_weights 24 | 25 | 26 | class GPTLMModel(nn.Module): 27 | def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False): 28 | super().__init__() 29 | self.checkpoint = checkpoint 30 | with no_init_weights(): 31 | self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers, 32 | n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size)) 33 | if checkpoint: 34 | self.model.gradient_checkpointing_enable() 35 | 36 | def forward(self, input_ids, attention_mask): 37 | # Only return lm_logits 38 | return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] 39 | 40 | 41 | def gpt2_tiny(checkpoint=True): 42 | return GPTLMModel(hidden_size=128, num_layers=4, num_attention_heads=4, checkpoint=checkpoint) 43 | 44 | 45 | def gpt2_small(checkpoint=True): 46 | return GPTLMModel(hidden_size=768, num_layers=12, num_attention_heads=12, checkpoint=checkpoint) 47 | 48 | 49 | def gpt2_medium(checkpoint=True): 50 | return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) 51 | 52 | 53 | def gpt2_large(checkpoint=True): 54 | return GPTLMModel(hidden_size=1280, num_layers=36, num_attention_heads=20, checkpoint=checkpoint) 55 | 56 | 57 | def gpt2_xl(checkpoint=True): 58 | return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=25, checkpoint=checkpoint) 59 | 60 | 61 | def gpt2_2B(checkpoint=True): 62 | return GPTLMModel(hidden_size=2048, num_layers=40, num_attention_heads=16, checkpoint=checkpoint) 63 | 64 | 65 | def gpt2_3B(checkpoint=True): 66 | return GPTLMModel(hidden_size=2304, num_layers=48, num_attention_heads=16, checkpoint=checkpoint) 67 | 68 | 69 | def gpt2_4B(checkpoint=True): 70 | return GPTLMModel(hidden_size=2304, num_layers=64, num_attention_heads=16, checkpoint=checkpoint) 71 | 72 | 73 | def gpt2_6B(checkpoint=True): 74 | return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) 75 | 76 | 77 | def gpt2_8B(checkpoint=True): 78 | return GPTLMModel(hidden_size=3072, num_layers=72, num_attention_heads=24, checkpoint=checkpoint) 79 | 80 | 81 | def gpt2_12B(checkpoint=True): 82 | return GPTLMModel(hidden_size=4096, num_layers=60, num_attention_heads=16, checkpoint=checkpoint) 83 | 84 | 85 | def gpt2_15B(checkpoint=True): 86 | return 
GPTLMModel(hidden_size=4096, num_layers=78, num_attention_heads=16, checkpoint=checkpoint) 87 | 88 | 89 | def gpt2_18B(checkpoint=True): 90 | return GPTLMModel(hidden_size=4096, num_layers=90, num_attention_heads=16, checkpoint=checkpoint) 91 | 92 | 93 | def gpt2_20B(checkpoint=True): 94 | return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint) 95 | 96 | 97 | def gpt2_24B(checkpoint=True): 98 | return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) 99 | 100 | 101 | def gpt2_28B(checkpoint=True): 102 | return GPTLMModel(hidden_size=8192, num_layers=35, num_attention_heads=16, checkpoint=checkpoint) 103 | 104 | 105 | def gpt2_32B(checkpoint=True): 106 | return GPTLMModel(hidden_size=8192, num_layers=40, num_attention_heads=16, checkpoint=checkpoint) 107 | 108 | 109 | def gpt2_36B(checkpoint=True): 110 | return GPTLMModel(hidden_size=8192, num_layers=45, num_attention_heads=16, checkpoint=checkpoint) 111 | 112 | 113 | def gpt2_40B(checkpoint=True): 114 | return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) 115 | 116 | 117 | def gpt2_45B(checkpoint=True): 118 | return GPTLMModel(hidden_size=8192, num_layers=56, num_attention_heads=16, checkpoint=checkpoint) 119 | 120 | 121 | def gpt3(checkpoint=True): 122 | return GPTLMModel(max_seq_len=2048, hidden_size=12288, num_layers=96, num_attention_heads=96, checkpoint=checkpoint) 123 | 124 | 125 | def get_gpt_model(model_name: str, checkpoint: bool = True) -> nn.Module: 126 | model_map = { 127 | 'gpt2_tiny': gpt2_tiny, 128 | 'gpt2_small': gpt2_small, 129 | 'gpt2_medium': gpt2_medium, 130 | 'gpt2_large': gpt2_large, 131 | 'gpt2_xl': gpt2_xl, 132 | 'gpt2_2B': gpt2_2B, 133 | 'gpt2_3B': gpt2_3B, 134 | 'gpt2_4B': gpt2_4B, 135 | 'gpt2_6B': gpt2_6B, 136 | 'gpt2_8B': gpt2_8B, 137 | 'gpt2_12B': gpt2_12B, 138 | 'gpt2_15B': gpt2_15B, 139 | 'gpt2_18B': gpt2_18B, 140 | 'gpt2_20B': gpt2_20B, 141 | 'gpt2_24B': gpt2_24B, 142 | 'gpt2_28B': gpt2_28B, 143 | 'gpt2_32B': gpt2_32B, 144 | 'gpt2_36B': gpt2_36B, 145 | 'gpt2_40B': gpt2_40B, 146 | 'gpt2_45B': gpt2_45B, 147 | 'gpt3': gpt3, 148 | } 149 | assert model_name in model_map 150 | return model_map[model_name](checkpoint) 151 | 152 | 153 | class GPTLMLoss(nn.Module): 154 | def __init__(self): 155 | super().__init__() 156 | self.loss = nn.CrossEntropyLoss() 157 | 158 | def forward(self, logits, labels): 159 | shift_logits = logits[..., :-1, :].contiguous() 160 | shift_labels = labels[..., 1:].contiguous() 161 | # Flatten the tokens 162 | return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 163 | 164 | 165 | def get_optimizer(strategy: str, **kwargs) -> Callable[[Iterable], Optimizer]: 166 | assert strategy in ('ddp', 'deepspeed', 'colossal') 167 | if strategy == 'ddp': 168 | opt_cls = Adam 169 | elif strategy == 'deepspeed': 170 | offload = kwargs.pop('offload') 171 | if offload: 172 | opt_cls = DeepSpeedCPUAdam 173 | else: 174 | opt_cls = FusedAdam 175 | else: 176 | opt_cls = HybridAdam 177 | return partial(opt_cls, **kwargs) 178 | 179 | 180 | class GPTLitModule(pl.LightningModule): 181 | def __init__(self, model_name: str, optimizer_init_fn: Callable[[Iterable], Optimizer], 182 | checkpoint: bool = True, cuda_mem_fraction: float = 1.0) -> None: 183 | super().__init__() 184 | self.model_name = model_name 185 | self.optimizer_init_fn = optimizer_init_fn 186 | self.checkpoint = checkpoint 187 | self.criterion = GPTLMLoss() 188 | self.cuda_mem_fraction = 
cuda_mem_fraction 189 | 190 | def configure_sharded_model(self) -> None: 191 | self.model = get_gpt_model(self.model_name, self.checkpoint) 192 | 193 | def on_load_checkpoint(self, checkpoint) -> None: 194 | if not hasattr(self, 'model'): 195 | self.configure_sharded_model() 196 | 197 | def configure_optimizers(self): 198 | return self.optimizer_init_fn(self.model.parameters()) 199 | 200 | def training_step(self, batch, batch_idx): 201 | input_ids, attention_mask = batch 202 | logits = self.model(input_ids, attention_mask) 203 | loss = self.criterion(logits, input_ids) 204 | return loss 205 | 206 | def on_fit_start(self) -> None: 207 | if self.cuda_mem_fraction < 1.0: 208 | colo_set_process_memory_fraction(self.cuda_mem_fraction) 209 | --------------------------------------------------------------------------------