├── my_classifier_template
│   ├── __init__.py
│   ├── .DS_Store
│   ├── dataset.py
│   └── model.py
├── requirements.txt
├── .gitignore
├── README.md
└── main.py

/my_classifier_template/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pytorch-lightning==1.6.3
torch==1.11
torchvision==0.12.0
watermark==2.3.1
--------------------------------------------------------------------------------
/my_classifier_template/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rasbt/b3-basic-batchsize-benchmark/HEAD/my_classifier_template/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/my_classifier_template/dataset.py:
--------------------------------------------------------------------------------
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms


class Cifar10DataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train_transform=None,
        test_transform=None,
        num_workers=4,
        data_path="./",
    ):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.custom_train_transform = train_transform
        self.custom_test_transform = test_transform

    def prepare_data(self):
        datasets.CIFAR10(root=self.data_path, download=True)
        return

    def setup(self, stage=None):

        if self.custom_train_transform is None:
            self.train_transform = transforms.Compose(
                [
                    transforms.Resize((70, 70)),
                    transforms.RandomCrop((64, 64)),
                    transforms.ToTensor(),
                ]
            )
        else:
            self.train_transform = self.custom_train_transform

        if self.custom_test_transform is None:
            self.test_transform = transforms.Compose(
                [
                    transforms.Resize((70, 70)),
                    transforms.CenterCrop((64, 64)),
                    transforms.ToTensor(),
                ]
            )
        else:
            self.test_transform = self.custom_test_transform

        train = datasets.CIFAR10(
            root=self.data_path,
            train=True,
            transform=self.train_transform,
            download=False,
        )

        self.test = datasets.CIFAR10(
            root=self.data_path,
            train=False,
            transform=self.test_transform,
            download=False,
        )

        self.train, self.valid = random_split(train, lengths=[45000, 5000])

    def train_dataloader(self):
        train_loader = DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            drop_last=True,
            shuffle=True,
            persistent_workers=True,
            num_workers=self.num_workers,
        )
        return train_loader

    def val_dataloader(self):
        valid_loader = DataLoader(
            dataset=self.valid,
            batch_size=self.batch_size,
            drop_last=False,
            persistent_workers=True,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return valid_loader

    def test_dataloader(self):
        test_loader = DataLoader(
            dataset=self.test,
            batch_size=self.batch_size,
            drop_last=False,
            persistent_workers=True,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return test_loader
--------------------------------------------------------------------------------
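For orientation, here is a minimal standalone usage sketch of the data module above. This snippet is not part of the repository; the batch size and data path are arbitrary examples, and the default 64×64 transforms are assumed:

```python
# Hypothetical standalone usage of Cifar10DataModule (not part of the repo).
from my_classifier_template.dataset import Cifar10DataModule

dm = Cifar10DataModule(batch_size=128, num_workers=4, data_path="./data")
dm.prepare_data()  # downloads CIFAR-10 once
dm.setup()         # builds the 45k/5k train/valid split and the test set

features, labels = next(iter(dm.train_dataloader()))
print(features.shape)  # expected: torch.Size([128, 3, 64, 64]) with the default transforms
print(labels.shape)    # expected: torch.Size([128])
```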
/my_classifier_template/model.py:
--------------------------------------------------------------------------------
import pytorch_lightning as pl
import torch
import torchmetrics


# LightningModule that receives a PyTorch model as input
class LightningClassifier(pl.LightningModule):
    def __init__(self, model, learning_rate, log_accuracy):
        super().__init__()

        self.log_accuracy = log_accuracy

        # Note that the other __init__ parameters will be available as
        # self.hparams.argname after calling self.save_hyperparameters below

        # The inherited PyTorch module
        self.model = model
        if hasattr(model, "dropout_proba"):
            self.dropout_proba = model.dropout_proba

        # Save settings and hyperparameters to the log directory
        # but skip the model parameters
        self.save_hyperparameters(ignore=["model"])

        # Set up attributes for computing the accuracy
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    # Defining the forward method is only necessary
    # if you want to use a Trainer's .predict() method (optional)
    def forward(self, x):
        return self.model(x)

    # A common forward step to compute the loss and labels;
    # this is used for training, validation, and testing below
    def _shared_step(self, batch):
        features, true_labels = batch
        logits = self(features)
        loss = torch.nn.functional.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)

        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss)

        # Do another forward pass in .eval() mode to compute the accuracy
        # with Dropout, BatchNorm, etc. switched to their evaluation
        # (inference) behavior
        self.model.eval()
        with torch.no_grad():
            _, true_labels, predicted_labels = self._shared_step(batch)

        if self.log_accuracy:
            self.train_acc(predicted_labels, true_labels)
            self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.model.train()

        return loss  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("valid_loss", loss)
        self.valid_acc(predicted_labels, true_labels)

        if self.log_accuracy:
            self.log(
                "valid_acc",
                self.valid_acc,
                on_epoch=True,
                on_step=False,
                prog_bar=True,
            )

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
--------------------------------------------------------------------------------
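Similarly, a minimal sketch of how the LightningModule above wraps an arbitrary classifier. This is not part of the repository; `main.py` uses MobileNetV3-Large, and the ResNet-18 backbone here is only a placeholder for illustration:

```python
# Hypothetical example: wrapping a torchvision backbone with LightningClassifier.
import torch
from torchvision import models

from my_classifier_template.model import LightningClassifier

backbone = models.resnet18(pretrained=False)
backbone.fc = torch.nn.Linear(backbone.fc.in_features, 10)  # 10 CIFAR-10 classes

lightning_model = LightningClassifier(
    model=backbone, learning_rate=0.0005, log_accuracy=True
)

# The wrapped model behaves like any LightningModule, e.g. for a plain forward pass:
dummy_batch = torch.randn(4, 3, 64, 64)
print(lightning_model(dummy_batch).shape)  # expected: torch.Size([4, 10])
```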
/README.md:
--------------------------------------------------------------------------------
# B3 -- Basic Batchsize Benchmark

A quick benchmark with different batch sizes that was prompted by the discussion [here](https://twitter.com/rasbt/status/1542882893181108227?s=20&t=96dUITuyaNJUfw1TWxDLng), which was in turn prompted by the [Do Batch Sizes Actually Need to be Powers of 2?](https://wandb.ai/datenzauberai/Batch-Size-Testing/reports/Do-Batch-Sizes-Actually-Need-to-be-Powers-of-2---VmlldzoyMDkwNDQx) article.

Right now, this benchmark trains a [MobileNetV3 (large)](https://arxiv.org/abs/1905.02244) on CIFAR-10 (the images are resized to 224×224 to reach proper GPU utilization). You can run it as follows:

**Step 1: Initial Setup**

```bash
git clone https://github.com/rasbt/b3-basic-batchsize-benchmark.git
cd b3-basic-batchsize-benchmark
conda create -n benchmark python=3.8
conda activate benchmark
pip install -r requirements.txt
```

**Step 2: Running the Training Script**

```bash
python main.py --num_epochs 10 --batch_size 127 --mixed_precision true
```

### Additional Resources

- [Ross Wightman mentioning](https://twitter.com/wightmanr/status/1542917523556904960?s=20&t=96dUITuyaNJUfw1TWxDLng) that it might matter more for TPUs
- [Nvidia's Deep Learning Performance Documentation on matrix multiplication](https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html) explaining the theoretical rationale behind choosing batch sizes as multiples of 8 for tensor cores

### Results

| batch size | train time | inf. time | epochs | GPU  | mixed prec. |
| ---------- | ---------- | --------- | ------ | ---- | ----------- |
| 100        | 10.50 min  | 0.15 min  | 10     | V100 | Yes         |
| 127        | 9.80 min   | 0.15 min  | 10     | V100 | Yes         |
| 128        | 9.78 min   | 0.15 min  | 10     | V100 | Yes         |
| 129        | 9.92 min   | 0.15 min  | 10     | V100 | Yes         |
| 156        | 9.38 min   | 0.16 min  | 10     | V100 | Yes         |
|            |            |           |        |      |             |
| 511        | 8.74 min   | 0.17 min  | 10     | V100 | Yes         |
| 512        | 8.71 min   | 0.17 min  | 10     | V100 | Yes         |
| 513        | 8.72 min   | 0.17 min  | 10     | V100 | Yes         |

Below, I trained the same neural network using 4 V100 GPUs with the distributed data parallel strategy:

```bash
python main.py --num_epochs 10 --batch_size 255 --mixed_precision true --num_workers 4 --strategy ddp
```

| batch size | train time | epochs | GPU    | mixed prec. |
| ---------- | ---------- | ------ | ------ | ----------- |
| 255        | 2.95 min   | 10     | 4xV100 | Yes         |
| 256        | 2.87 min   | 10     | 4xV100 | Yes         |
| 257        | 2.86 min   | 10     | 4xV100 | Yes         |

Note that I removed the inference time (here: evaluation on the test set) from this table because, in practice, you would still use a single V100 for inference.

Also note that all of these numbers are from a single run each. To get more reliable statistics, repeating the runs several times and reporting the average ± standard deviation would be worthwhile. However, even from the numbers above, it is apparent that there is only a small, barely noticeable difference between batch sizes 127, 128, and 129.

**Or in other words: do you have a batch size of 128 in mind that does not fit into memory? It is probably okay to train that model with a batch size of 120 or 100 instead of scaling it all the way down to 64** 😊.
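If you want to reproduce neighboring rows of the tables above in one go, a small driver loop along the following lines can be used. This is only a sketch, not part of the repository; it simply calls `main.py` with the flags shown above, and the batch sizes are examples:

```python
# Hypothetical driver script: time main.py for several neighboring batch sizes.
import subprocess
import time

for batch_size in (127, 128, 129):
    start = time.time()
    subprocess.run(
        [
            "python", "main.py",
            "--num_epochs", "10",
            "--batch_size", str(batch_size),
            "--mixed_precision", "true",
        ],
        check=True,  # stop if a run fails
    )
    print(f"batch_size={batch_size}: {(time.time() - start) / 60:.2f} min total")
```

Note that this measures the total wall-clock time of each run (training plus evaluation), whereas the tables above report the training and inference times separately, as printed by `main.py`.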
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import time

import pytorch_lightning as pl
import torch
from my_classifier_template.dataset import Cifar10DataModule
from my_classifier_template.model import LightningClassifier
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
from watermark import watermark


def parse_cmdline_args(parser=None):

    if parser is None:
        parser = argparse.ArgumentParser()

    parser.add_argument("--accelerator", type=str, default="auto")

    parser.add_argument("--batch_size", type=int, default=32)

    parser.add_argument("--data_path", type=str, default="./data")

    parser.add_argument("--learning_rate", type=float, default=0.0005)

    parser.add_argument(
        "--log_accuracy", type=str, choices=("true", "false"), default="true"
    )

    parser.add_argument(
        "--mixed_precision", type=str, choices=("true", "false"), default="false"
    )

    parser.add_argument("--num_epochs", type=int, default=10)

    parser.add_argument("--num_workers", type=int, default=3)

    parser.add_argument("--output_path", type=str, default="")

    parser.add_argument(
        "--pretrained", type=str, choices=("true", "false"), default="false"
    )

    parser.add_argument("--num_devices", nargs="+", default="auto")

    parser.add_argument("--device_numbers", type=str, default="")

    parser.add_argument("--random_seed", type=int, default=-1)

    parser.add_argument("--strategy", type=str, default="")

    parser.set_defaults(feature=True)
    args = parser.parse_args()

    if not args.strategy:
        args.strategy = None

    if args.num_devices != "auto":
        args.num_devices = int(args.num_devices[0])
    if args.device_numbers:
        args.num_devices = [int(i) for i in args.device_numbers.split(",")]

    d = {"true": True, "false": False}

    args.log_accuracy = d[args.log_accuracy]
    args.pretrained = d[args.pretrained]
    args.mixed_precision = d[args.mixed_precision]
    if args.mixed_precision:
        args.mixed_precision = 16
    else:
        args.mixed_precision = 32

    return args


if __name__ == "__main__":

    print(watermark())
    print(watermark(packages="torch,pytorch_lightning"))

    parser = argparse.ArgumentParser()
    args = parse_cmdline_args(parser)

    torch.manual_seed(args.random_seed)

    custom_train_transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
        ]
    )

    custom_test_transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
        ]
    )

    data_module = Cifar10DataModule(
        batch_size=args.batch_size,
        data_path=args.data_path,
        num_workers=args.num_workers,
        train_transform=custom_train_transform,
        test_transform=custom_test_transform,
    )

    pytorch_model = torch.hub.load(
        "pytorch/vision:v0.11.0", "mobilenet_v3_large", pretrained=args.pretrained
    )

    pytorch_model.classifier[-1] = torch.nn.Linear(
        in_features=1280,  # as in the original classifier head
        out_features=10,  # number of class labels in CIFAR-10
    )

    lightning_model = LightningClassifier(
        pytorch_model, learning_rate=args.learning_rate, log_accuracy=args.log_accuracy
    )

    if args.log_accuracy:
        callbacks = [
            ModelCheckpoint(
                save_top_k=1, mode="max", monitor="valid_acc"
            )  # save top 1 model
        ]
    else:
        callbacks = [
            ModelCheckpoint(
                save_top_k=1, mode="min", monitor="valid_loss"
            )  # save top 1 model
        ]

    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        callbacks=callbacks,
        accelerator=args.accelerator,
        devices=args.num_devices,
        default_root_dir=args.output_path,
        strategy=args.strategy,
        precision=args.mixed_precision,
        deterministic=False,
        log_every_n_steps=10,
    )

    start_time = time.time()
    trainer.fit(model=lightning_model, datamodule=data_module)

    train_time = time.time()
    runtime = (train_time - start_time) / 60
    print(f"Training took {runtime:.2f} min.")

    # set up the data on the host machine
    data_module.prepare_data()
    data_module.setup()

    before = time.time()
    val_acc = trainer.test(dataloaders=data_module.val_dataloader())
    runtime = (time.time() - before) / 60
    print(f"Inference on the validation set took {runtime:.2f} min.")

    before = time.time()
    test_acc = trainer.test(dataloaders=data_module.test_dataloader())
    runtime = (time.time() - before) / 60
    print(f"Inference on the test set took {runtime:.2f} min.")

    runtime = (time.time() - start_time) / 60
    print(f"The total runtime was {runtime:.2f} min.")

    print("Validation accuracy:", val_acc)
    print("Test accuracy:", test_acc)
--------------------------------------------------------------------------------