├── src ├── __init__.py ├── data │ ├── __init_.py │ └── in_memory_dataset.py ├── models │ ├── __init__.py │ └── lightning_modules.py ├── log.py ├── callbacks.py ├── train.py └── inference.py ├── requirements-dev.txt ├── requirements.txt ├── dockerfiles ├── cuda118 │ └── Dockerfile └── cuda120 │ └── Dockerfile ├── templates └── rtx_6000_ada.sh ├── .github └── workflows │ ├── linter.yml │ └── docker_push.yaml ├── LICENSE ├── .pre-commit-config.yaml ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init_.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit==3.* 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lightning==2.2.5 2 | protobuf==3.20.* 3 | segmentation-models-pytorch==0.3.3 4 | six==1.16.0 5 | torch==2.3.1 6 | torchvision==0.18.1 7 | -------------------------------------------------------------------------------- /dockerfiles/cuda118/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 4 | python3 \ 5 | python3-pip \ 6 | curl && \ 7 | apt clean && \ 8 | rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt /tmp/requirements.txt 11 | RUN pip3 install --no-cache-dir 
-r /tmp/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 12 | 13 | COPY ./src /workdir/src 14 | WORKDIR /workdir 15 | 16 | ENTRYPOINT [ "python3", "-m" ] 17 | -------------------------------------------------------------------------------- /dockerfiles/cuda120/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 4 | python3 \ 5 | python3-pip \ 6 | curl && \ 7 | apt clean && \ 8 | rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt /tmp/requirements.txt 11 | RUN pip3 install --no-cache-dir -r /tmp/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 12 | 13 | COPY ./src /workdir/src 14 | WORKDIR /workdir 15 | 16 | ENTRYPOINT [ "python3", "-m" ] 17 | -------------------------------------------------------------------------------- /templates/rtx_6000_ada.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # GPU: 1x NVIDIA RTX 6000 Ada, 48 GB VRAM 3 | 4 | #cfg32=("resnet50,644" "resnext50,512" "unet_resnet50,440" "swin,240" "convnext,256") 5 | cfg16=("resnet50,1280" "resnext50,1024" "unet_resnet50,880" "swin,360" "convnext,500") 6 | 7 | N_ITERS=300 8 | PRECISION="16-mixed" 9 | 10 | for str in ${cfg16[@]}; do 11 | IFS=',' read -r -a parts <<< "$str" 12 | 13 | model="${parts[0]}" 14 | batch="${parts[1]}" 15 | 16 | docker run --ipc=host --ulimit memlock=-1 --gpus '"device=1"' cv-benchmark --model $model --batch-size $batch --n-iter $N_ITERS --precision $PRECISION 17 | done 18 | -------------------------------------------------------------------------------- /src/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pip._internal.operations import freeze 4 | 5 | 6 | def setup_custom_logger(name: str = "benchmark"): 7 | logger = 
logging.getLogger(name) 8 | 9 | sh = logging.StreamHandler() 10 | 11 | formatter = logging.Formatter("%(asctime)s - %(message)s") 12 | 13 | sh.setFormatter(formatter) 14 | 15 | logger.addHandler(sh) 16 | logger.setLevel(level=logging.DEBUG) 17 | 18 | return logger 19 | 20 | 21 | def print_requirements(): 22 | pkgs = freeze.freeze() 23 | for pkg in pkgs: 24 | logger.info(pkg) 25 | 26 | 27 | logger = setup_custom_logger() 28 | -------------------------------------------------------------------------------- /.github/workflows/linter.yml: -------------------------------------------------------------------------------- 1 | name: Run pre-commit hooks 2 | 3 | on: 4 | push: 5 | pull_request: 6 | branches: [master] 7 | 8 | jobs: 9 | build: 10 | name: Run pre-commit hooks 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: "3.10" 21 | 22 | - name: Install pre-commit 23 | run: pip install -r requirements-dev.txt 24 | 25 | - name: Run pre-commit checks 26 | run: pre-commit run --all-files 27 | -------------------------------------------------------------------------------- /src/models/lightning_modules.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class LitClassification(L.LightningModule): 7 | def __init__(self, model: nn.Module, optimizer=torch.optim.Adam): 8 | super().__init__() 9 | self.model = model 10 | self.loss = nn.CrossEntropyLoss() 11 | self.optimizer = optimizer 12 | 13 | def training_step(self, batch, batch_idx) -> torch.Tensor: 14 | y_hat = self.model(batch) 15 | y = torch.rand_like(y_hat) 16 | 17 | loss = self.loss(y_hat, y) 18 | return loss 19 | 20 | def configure_optimizers(self): 21 | return self.optimizer(self.parameters(), lr=2e-5) 22 | 
# --------------------------------------------------------------------------------
# /src/data/in_memory_dataset.py:
# --------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Synthetic map-style dataset that fabricates a random image per lookup.

    Nothing is stored: every ``__getitem__`` call draws a fresh tensor with
    ``torch.rand``, so two lookups of the same index return different data.

    Args:
        width: Width of each generated image, in pixels.
        height: Height of each generated image, in pixels.
        n_channels: Number of channels of each generated image.
        dataset_size: Number of samples reported by ``len()``.
    """

    def __init__(
        self,
        width: int = 224,
        height: int = 224,
        n_channels: int = 3,
        dataset_size: int = int(1e7),
    ):
        super().__init__()
        self.width = width
        self.height = height
        self.n_channels = n_channels
        self.dataset_size = dataset_size

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.dataset_size

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return a random float32 tensor of shape C x H x W with values in [0, 1].

        Args:
            idx: Sample index. Negative values count from the end, as for
                built-in sequences. The index does not influence the data.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this check the legacy sequence-iteration protocol
        # (``for x in dataset`` / ``list(dataset)``) would never terminate,
        # because it only stops when ``__getitem__`` raises IndexError.
        size = self.dataset_size
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {size}"
            )
        return torch.rand(
            self.n_channels, self.height, self.width, dtype=torch.float32
        )


# The remainder of this physical line in the flattened dump is the beginning
# of the next file; it is preserved verbatim below.
# --------------------------------------------------------------------------------
# /LICENSE:
# --------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2023 TensorPix
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/PyCQA/isort 8 | rev: 5.13.2 9 | hooks: 10 | - id: isort 11 | name: Format imports 12 | args: ["--profile", "black"] 13 | 14 | - repo: https://github.com/psf/black 15 | rev: 24.1.1 16 | hooks: 17 | - id: black 18 | name: black 19 | entry: black 20 | types: [python] 21 | 22 | - repo: https://github.com/pre-commit/pre-commit-hooks 23 | rev: v4.5.0 24 | hooks: 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | - id: check-case-conflict 28 | - id: check-docstring-first 29 | - id: check-executables-have-shebangs 30 | - id: check-added-large-files 31 | args: ["--maxkb=350", "--enforce-all"] 32 | - id: detect-private-key 33 | - id: requirements-txt-fixer 34 | - id: mixed-line-ending 35 | - id: check-merge-conflict 36 | 37 | - repo: https://github.com/asottile/pyupgrade 38 | rev: v3.15.0 39 | hooks: 40 | - id: pyupgrade 41 | args: [--py38-plus] 42 | name: Upgrade code 43 | 44 | - repo: https://github.com/PyCQA/flake8 45 | rev: 7.0.0 46 | hooks: 47 | - id: flake8 48 | types: [python] 49 | args: ["--max-line-length=120", "--ignore=E203,W503"] 50 | -------------------------------------------------------------------------------- /.github/workflows/docker_push.yaml: -------------------------------------------------------------------------------- 1 | name: Create and push cuda118 + cuda120 docker images to 
this repo's packages 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | workflow_dispatch: {} 8 | 9 | env: 10 | REGISTRY: ghcr.io 11 | IMAGE_NAME: ${{ github.repository }} 12 | 13 | jobs: 14 | build-and-push-cuda-images: 15 | runs-on: ubuntu-latest 16 | permissions: 17 | contents: read 18 | packages: write 19 | steps: 20 | - 21 | name: Checkout repository 22 | uses: actions/checkout@v4 23 | 24 | - 25 | name: Log in to the Container registry 26 | uses: docker/login-action@v3 27 | with: 28 | registry: ${{ env.REGISTRY }} 29 | username: ${{ github.actor }} 30 | password: ${{ secrets.GITHUB_TOKEN }} 31 | #CUDA118 steps 32 | - 33 | name: Extract cuda118 image metadata 34 | id: meta_118 35 | uses: docker/metadata-action@v3 36 | with: 37 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 38 | tags: cuda118 39 | 40 | - 41 | name: Build and push cuda118 image 42 | uses: docker/build-push-action@v5 43 | with: 44 | context: . 45 | file: dockerfiles/cuda118/Dockerfile 46 | push: true 47 | tags: ${{ steps.meta_118.outputs.tags }} 48 | labels: ${{ steps.meta_118.outputs.labels }} 49 | #CUDA120 steps 50 | - 51 | name: Extract cuda120 image metadata 52 | id: meta_120 53 | uses: docker/metadata-action@v3 54 | with: 55 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 56 | tags: | 57 | type=raw,value=cuda120 58 | type=raw,value=latest 59 | - 60 | name: Build and push cuda120 image 61 | uses: docker/build-push-action@v5 62 | with: 63 | context: . 
64 | file: dockerfiles/cuda120/Dockerfile 65 | push: true 66 | tags: ${{ steps.meta_120.outputs.tags }} 67 | labels: ${{ steps.meta_120.outputs.labels }} 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | lightning_logs/ 2 | 3 | *.csv 4 | benchmarks/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 
94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
# (last entry of /.gitignore, kept from the flattened dump)
#.idea/

# --------------------------------------------------------------------------------
# /src/callbacks.py:
# --------------------------------------------------------------------------------
import csv
import logging
import os
import stat
import time
from datetime import datetime

import torch
from lightning.pytorch.callbacks import Callback

logger = logging.getLogger("benchmark")

# Column titles of ./benchmarks/benchmark.csv, in the order the row is written.
_CSV_COLUMNS = [
    "Datetime",
    "GPU",
    "cuDNN version",
    "N GPUs",
    "Data Loader workers",
    "Model",
    "Precision",
    "Minibatch",
    "Input width [px]",
    "Input height [px]",
    "Warmup steps",
    "Benchmark steps",
    "MPx/s",
    "images/s",
    "batches/s",
]


class BenchmarkCallback(Callback):
    """Time training after a warmup phase and log/append throughput figures.

    The timer starts at the batch whose index equals ``warmup_steps`` and stops
    in ``on_fit_end``. Results are logged and appended as one row to
    ``./benchmarks/benchmark.csv``.

    Args:
        model_name: Name of the benchmarked architecture (written to the CSV).
        precision: Precision setting of the run (written to the CSV).
        workers: Number of data-loader workers (written to the CSV).
        warmup_steps: Number of leading batches excluded from the timing.
    """

    def __init__(
        self,
        model_name: str,
        precision: str,
        workers: int,
        warmup_steps: int = 50,
    ):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.start_time = 0
        self.end_time = 0
        self.precision = precision
        self.model = model_name
        self.workers = workers

    def on_fit_start(self, trainer, pl_module):
        """Log the start of the benchmark and the warmup length."""
        logger.info(
            f"Benchmark started. Number of warmup iterations: {self.warmup_steps}"
        )

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int):
        """Start the timer once ``batch_idx`` reaches ``warmup_steps``."""
        if batch_idx == self.warmup_steps:
            logger.info(
                f"Completed {self.warmup_steps} warmup steps. Benchmark timer started."
            )
            self.start_time = time.time()

    def on_fit_end(self, trainer, pl_module):
        """Stop the timer, log throughput and append a row to the CSV file."""
        self.end_time = time.time()
        logger.info("Fit function finished")

        dataset = trainer.train_dataloader.dataset
        batch_size = trainer.train_dataloader.batch_size
        in_w, in_h = dataset.width, dataset.height

        # Only the steps after warmup count towards the measurement.
        benchmark_steps = trainer.global_step - self.warmup_steps
        processed_megapixels = (
            trainer.world_size * in_w * in_h * batch_size * benchmark_steps / 1e6
        )

        # Small epsilon keeps the divisions below well-defined.
        elapsed_time = (self.end_time - self.start_time) + 1e-7
        mpx_s = processed_megapixels / elapsed_time

        processed_imgs = batch_size * benchmark_steps * trainer.world_size
        images_s = processed_imgs / elapsed_time

        batches_s = benchmark_steps * trainer.world_size / elapsed_time

        logger.info(f"Benchmark finished in {elapsed_time:.1f} seconds")
        logger.info(
            f"Average training throughput: {mpx_s:.2f} MPx/s (megapixels per second) | "
            + f"{images_s:.2f} images/s | {batches_s:.2f} batches/s"
        )

        row = [
            datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
            torch.cuda.get_device_name(0),
            torch.backends.cudnn.version(),
            trainer.world_size,
            self.workers,
            self.model,
            self.precision,
            batch_size,
            in_w,
            in_h,
            self.warmup_steps,
            benchmark_steps,
            mpx_s,
            images_s,
            batches_s,
        ]

        # NOTE(review): nothing here restricts the write to a single process;
        # with several DDP processes each one presumably appends its own row —
        # confirm whether that is intended.
        os.makedirs("./benchmarks", exist_ok=True)
        csv_path = os.path.join("./benchmarks", "benchmark.csv")
        # The header is needed for a new file and for an existing but empty
        # one (the previous ``st_size >= 0`` test was always true, so an empty
        # file never received its header).
        needs_header = not os.path.isfile(csv_path) or os.stat(csv_path).st_size == 0
        # newline="" lets the csv module control line endings itself.
        with open(csv_path, "a", newline="") as file:
            writer = csv.writer(file)
            if needs_header:
                writer.writerow(_CSV_COLUMNS)
            writer.writerow(row)

        logger.info(
            "Written benchmark data to a CSV file. "
            + "See 'Logging Results to a Persisent CSV File' section to "
            + "save the file on your disk: "
            + "https://github.com/tensorpix/benchmarking-cv-models#logging-results-to-a-persistent-csv-file"
        )

        # Best effort: owner read/write, group read, others read/write.
        try:
            os.chmod(
                csv_path,
                stat.S_IRUSR
                | stat.S_IRGRP
                | stat.S_IWUSR
                | stat.S_IROTH
                | stat.S_IWOTH,
            )
        except OSError as e:
            logger.error(f"Failed to change csv permissions: {e}")


# --------------------------------------------------------------------------------
# /src/train.py:
# --------------------------------------------------------------------------------
import argparse

import segmentation_models_pytorch as smp
import torch
from lightning import Trainer
from torch.utils.data import DataLoader
from torchvision.models import (
    convnext_base,
    efficientnet_v2_m,
    mobilenet_v3_large,
    resnet50,
    resnext50_32x4d,
    swin_b,
    vgg16,
    vit_b_16,
)

from src import log
from src.callbacks import BenchmarkCallback
from src.data.in_memory_dataset import InMemoryDataset
from src.log import print_requirements
from src.models.lightning_modules import LitClassification

logger = log.logger

# Maps the value of --model to the callable that constructs the network.
ARCHITECTURES = {
    "resnet50": resnet50,
    "convnext": convnext_base,
    "vgg16": vgg16,
    "efficient_net_v2": efficientnet_v2_m,
    "mobilenet_v3": mobilenet_v3_large,
    "resnext50": resnext50_32x4d,
    "swin": swin_b,
    "vit": vit_b_16,
    "unet_resnet50": smp.Unet,
    # TODO "ssd_vgg16": ssd300_vgg16,
    # TODO "fasterrcnn_resnet50_v2": fasterrcnn_resnet50_fpn_v2,
}


def main(args):
    """Run one benchmark: synthetic data, Trainer, selected model, ``fit``.

    Args:
        args: Parsed command-line namespace. Reads ``list_requirements``,
            ``width``, ``height``, ``n_workers``, ``batch_size``,
            ``accelerator``, ``precision``, ``n_iters``, ``warmup_steps`` and
            ``model``.

    Raises:
        ValueError: If ``args.model`` is not a key of ``ARCHITECTURES``.
    """
    if args.list_requirements:
        print_requirements()

    args_dict = vars(args)
    logger.info(f"User Arguments {args_dict}")

    # Fixed: height previously received args.width, so --height was ignored.
    dataset = InMemoryDataset(width=args.width, height=args.height)
    data_loader = DataLoader(
        dataset,
        num_workers=args.n_workers,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )

    trainer = Trainer(
        accelerator=args.accelerator,
        strategy="ddp",
        precision=args.precision,
        # Warmup batches plus the timed batches, all within a single epoch.
        limit_train_batches=args.n_iters + args.warmup_steps,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        callbacks=[
            BenchmarkCallback(
                warmup_steps=args.warmup_steps,
                model_name=args.model,
                precision=args.precision,
                workers=args.n_workers,
            )
        ],
        devices=torch.cuda.device_count(),
    )

    if args.model not in ARCHITECTURES:
        raise ValueError("Architecture not supported.")

    if args.model == "unet_resnet50":
        # smp.Unet needs its encoder named explicitly; no pretrained weights.
        model = ARCHITECTURES[args.model](
            encoder_name="resnet50", encoder_weights=None
        )
    else:
        model = ARCHITECTURES[args.model]()

    model = LitClassification(model=model)
    trainer.fit(model=model, train_dataloaders=data_loader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark CV models training on GPU.")

    parser.add_argument(
        "--batch-size",
        type=int,
        required=True,
        help="Minibatch size. Set the value so >90%% VRAM is filled during benchmark for most representative results.",
    )
    parser.add_argument(
        "--n-iters",
        type=int,
        default=200,
        help="Number of training iterations to benchmark for. One iteration = one batch update.",
    )
    parser.add_argument(
        "--precision", choices=["32", "16", "16-mixed", "bf16-mixed"], default="32"
    )
    parser.add_argument(
        "--n-workers",
        type=int,
        default=4,
        help="Number of Data Loader workers. CPU shouldn't be a bottleneck with 4+.",
    )

    parser.add_argument("--width", type=int, default=224, help="Input width")
    parser.add_argument("--height", type=int, default=224, help="Input height")

    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=100,
        help=(
            "Number of training iterations to use for warmup. "
            + " The benchmark timer starts after the warmup iterations are finished."
        ),
    )
    parser.add_argument(
        "--accelerator", choices=["gpu"], default="gpu", help="Accelerator to use."
    )
    parser.add_argument(
        "--model",
        default="resnet50",
        choices=list(ARCHITECTURES.keys()),
        help="Architecture to benchmark.",
    )
    parser.add_argument(
        "--list-requirements",
        action="store_true",
        help="Prints all python packages along with their versions.",
    )

    args = parser.parse_args()

    if args.n_iters <= 0:
        raise ValueError("Number of iterations must be > 0")

    if args.warmup_steps <= 0:
        raise ValueError("Number of warmup steps must be > 0")

    logger.info("########## STARTING NEW BENCHMARK RUN ###########")

    if not torch.cuda.is_available():
        raise ValueError("CUDA device not found on this system.")
    else:
        logger.info(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
        logger.info(f"CUDNN version: {torch.backends.cudnn.version()}")
        logger.info(
            f"CUDA Device Total Memory: {(torch.cuda.get_device_properties(0).total_memory / 1e9):.2f} GB"
        )

    main(args=args)

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
2 |
3 |