├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── codebase ├── __init__.py ├── config │ └── __init__.py ├── criterion │ ├── __init__.py │ ├── label_smooth.py │ └── register.py ├── data │ ├── __init__.py │ ├── cifar.py │ ├── imagenet │ │ ├── __init__.py │ │ ├── dali.py │ │ └── native.py │ ├── register.py │ ├── synthetic_dataset.py │ └── utils.py ├── engine.py ├── main.py ├── models │ ├── __init__.py │ ├── dummy_model.py │ └── register.py ├── optimizer │ ├── __init__.py │ └── register.py └── scheduler │ ├── __init__.py │ ├── register.py │ ├── warmup_cosine_annealing.py │ └── warmup_exponential.py ├── conf ├── base.conf ├── cifar10.conf ├── cifar100.conf ├── resnet50-benchmark.conf ├── resnet50-tfrec-v1_5.conf └── vit_cifar10.conf ├── doc └── benchmark.md ├── entry └── run.py ├── requirements.txt ├── tests ├── __init__.py ├── codebase │ ├── __init__.py │ ├── test_criterion.py │ └── test_main.py └── resources │ ├── __init__.py │ └── test.conf └── tools ├── make_tfrecord.py └── make_wds.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/code,python 3 | # Edit at https://www.gitignore.io/?templates=code,python 4 | 5 | ### Code ### 6 | .vscode/* 7 | # !.vscode/settings.json 8 | # !.vscode/tasks.json 9 | # !.vscode/launch.json 10 | # !.vscode/extensions.json 11 | 12 | ### Python ### 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # pipenv 82 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 83 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 84 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 85 | # install all needed dependencies. 86 | #Pipfile.lock 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # Mr Developer 102 | .mr.developer.cfg 103 | .project 104 | .pydevproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # End of https://www.gitignore.io/api/code,python 118 | 119 | *.pth 120 | jobs -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "codebase/torchutils"] 2 | path = codebase/torchutils 3 | url = https://github.com/chenyaofo/torchutils 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause 
License 2 | 3 | Copyright (c) chenyaofo 2021, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Classification Codebase 2 | 3 | This project **aims** to provide a codebase for the image classification task implemented by PyTorch. 
4 | It does not use any high-level deep learning libraries (such as pytorch-lightning or MMClassification). 5 | Thus, it should be easy to follow and modify. 6 | 7 | ## Requirements 8 | 9 | The code is tested on `python==3.9, pyhocon==0.3.57, torch==1.8.0, torchvision==0.9.0` 10 | 11 | ## Get Started 12 | 13 | You can get started with a resnet20 convolutional network on cifar10 with the following command. 14 | 15 | **Single node, single GPU:** 16 | 17 | ```bash 18 | CUDA_VISIBLE_DEVICES=0 python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 19 | ``` 20 | 21 | > Tips: run `CUDA_VISIBLE_DEVICES=0 python -m entry.run --conf conf/resnet50-benchmark.conf -o output/benchmark` to check throughput performance; more details can be found at [doc/benchmark.md](doc/benchmark.md) 22 | 23 | You can use multiple GPUs to accelerate the training with distributed data parallel: 24 | 25 | **Single node, multiple GPUs:** 26 | 27 | ```bash 28 | CUDA_VISIBLE_DEVICES=0,1 python -m entry.run --world-size 2 \ 29 | --conf conf/cifar10.conf -o output/cifar10/resnet20 30 | ``` 31 | 32 | **Multiple nodes:** 33 | 34 | Node 0: 35 | ```bash 36 | CUDA_VISIBLE_DEVICES=0,1 python -m entry.run --world-size 4 --dist-url \ 37 | 'tcp://IP_OF_NODE0:FREEPORT' --node-rank 0 --conf conf/cifar10.conf -o output/cifar10/resnet20 38 | ``` 39 | 40 | Node 1: 41 | ```bash 42 | CUDA_VISIBLE_DEVICES=0,1 python -m entry.run --world-size 4 --dist-url \ 43 | 'tcp://IP_OF_NODE0:FREEPORT' --node-rank 1 --conf conf/cifar10.conf -o output/cifar10/resnet20 44 | ``` 45 | 46 | 47 | ## Features 48 | 49 | This codebase adopts a configuration file (`.hocon`) to store the hyperparameters (such as the learning rate, training epochs, etc.). 50 | If you want to modify the configuration hyperparameters, you have two ways: 51 | 52 | 1. Modify the configuration file to generate a new file. 53 | 54 | 2. You can add `-M` in the running command line to modify the hyperparameters temporarily. 
55 | 56 | 57 | For example, if you hope to modify the total training epochs to 100 and the learning rate to 0.05. You can run the following command: 58 | 59 | ```bash 60 | CUDA_VISIBLE_DEVICES=0 python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M max_epochs=100 optimizer.lr=0.05 61 | ``` 62 | 63 | If you modify a non existing hyperparameter, the code will raise an exception. 64 | 65 | To list all valid hyperparameters names, you can run the following command: 66 | 67 | ```bash 68 | pyhocon -i conf/cifar10.conf -f properties 69 | ``` 70 | 71 | 3. We use NVIDIA DALI to accelerate the data preprocessing on ImageNet (use it by the flag `data.use_dali`) and tfrecord format to store the ImageNet (create the tfrecords by `tools/make_tfrecord.py` and use it by the flag `data.use_tfrecord`). 72 | 73 | 74 | Finally, enjoy the code. 75 | 76 | ## Cite 77 | 78 | ``` 79 | @misc{chen2020image, 80 | author = {Yaofo Chen}, 81 | title = {Image Classification Codebase}, 82 | year = {2021}, 83 | howpublished = {\url{https://github.com/chenyaofo/image-classification-codebase}} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /codebase/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /codebase/config/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import sys 4 | import pathlib 5 | from dataclasses import dataclass 6 | from typing import List, Optional 7 | 8 | from pyhocon import ConfigFactory, ConfigTree 9 | 10 | from codebase.torchutils.common import get_free_port, apply_modifications 11 | from codebase.torchutils.typed_args import TypedArgs, add_argument 12 | 13 | 14 | def is_valid_domain(value): 15 | pattern = re.compile( 16 | r'^(([a-zA-Z]{1})|([a-zA-Z]{1}[a-zA-Z]{1})|' 17 | 
# Dotted-quad IPv4 pattern, compiled once at import time instead of on every
# call. It is a raw string so that ``\d`` and ``\.`` reach the regex engine
# untouched, and ``re.ASCII`` restricts ``\d`` to 0-9 so that non-ASCII digits
# (e.g. full-width digits) are not accepted as part of an address.
_IPV4_PATTERN = re.compile(
    r'(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)',
    re.ASCII,
)


def is_valid_ip(str):
    """Return True if ``str`` is a dotted-quad IPv4 address, False otherwise.

    Each of the four octets must be in the range 0-255; leading zeros such as
    ``01`` are tolerated. The whole string has to match, so scheme prefixes,
    ports, and surrounding whitespace (including a trailing newline) are
    rejected.

    Args:
        str: The candidate address. The parameter name shadows the builtin but
            is kept unchanged for backward compatibility with existing callers.

    Returns:
        bool: Whether the text is a valid IPv4 address.
    """
    # fullmatch (rather than match + "$") also refuses a trailing "\n".
    return _IPV4_PATTERN.fullmatch(str) is not None
@CRITERION.register
class LabelSmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy loss computed against label-smoothed target distributions.

    The target class receives probability ``1 - epsilon`` and the remaining
    ``epsilon`` is spread uniformly over the other ``num_classes - 1`` classes.

    Args:
        num_classes: Number of classes, i.e. the size of the last dimension of
            the logits passed to ``forward``.
        epsilon: Total probability mass moved away from the target class.
        weight: Optional per-class rescaling weights of shape
            ``(num_classes,)``. Each sample's loss is multiplied by the weight
            of its target class.
        size_average: Legacy argument forwarded to ``_WeightedLoss``.
        reduce: Legacy argument forwarded to ``_WeightedLoss``.
        reduction: One of ``'none'``, ``'mean'`` or ``'sum'``.
    """

    def __init__(self, num_classes, epsilon=0.1, weight=None, size_average=None,
                 reduce=None, reduction='mean'):
        super().__init__(weight, size_average, reduce, reduction)
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, input, target):
        """Compute the label-smoothed cross-entropy.

        Args:
            input: Logits; the class dimension is the last one. The scatter
                below indexes with ``target.unsqueeze(1)``, so this presumably
                expects ``(batch, num_classes)`` logits with ``(batch,)``
                integer targets.
            target: Integer class indices.

        Returns:
            The per-sample losses for ``reduction='none'``, otherwise their
            sum or plain (unweighted) mean as a scalar tensor.

        Raises:
            ValueError: If ``self.reduction`` is not a supported mode.
        """
        logprobs = F.log_softmax(input, dim=-1)
        # The smoothed targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            target_probs = torch.full_like(logprobs, self.epsilon/(self.num_classes-1))
            target_probs.scatter_(dim=-1, index=target.unsqueeze(1), value=1.0-self.epsilon)

        losses = -(target_probs*logprobs).sum(dim=-1)
        if self.weight is not None:
            # ``weight`` is per class while ``losses`` is per sample: look up
            # the weight of each sample's target class instead of multiplying
            # two tensors of unrelated shapes.
            losses = losses * self.weight[target]

        if self.reduction == "none":
            return losses
        if self.reduction == "sum":
            return losses.sum()
        if self.reduction == "mean":
            return losses.mean()
        raise ValueError(
            f"The parameter 'reduction' must be in ['none','mean','sum'], but got {self.reduction}"
        )
synthetic_data 6 | -------------------------------------------------------------------------------- /codebase/data/cifar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.utils.data as data 3 | import torchvision.transforms as T 4 | from torchvision.datasets import CIFAR10, CIFAR100 5 | 6 | from .utils import get_samplers 7 | from .register import DATA 8 | 9 | 10 | _logger = logging.getLogger(__name__) 11 | 12 | 13 | def get_train_transforms(mean, std): 14 | return T.Compose([ 15 | T.RandomCrop(32, padding=4), 16 | T.RandomHorizontalFlip(), 17 | T.ToTensor(), 18 | T.Normalize(mean=mean, std=std) 19 | ]) 20 | 21 | 22 | def get_val_transforms(mean, std): 23 | return T.Compose([ 24 | T.ToTensor(), 25 | T.Normalize(mean=mean, std=std) 26 | ]) 27 | 28 | 29 | def get_vit_train_transforms(mean, std, img_size): 30 | return T.Compose([ 31 | T.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)), 32 | T.ToTensor(), 33 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 34 | ]) 35 | 36 | 37 | def get_vit_val_transforms(mean, std, img_size): 38 | return T.Compose([ 39 | T.Resize((img_size, img_size)), 40 | T.ToTensor(), 41 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 42 | ]) 43 | 44 | 45 | def _cifar(root, image_size, mean, std, batch_size, num_workers, is_vit, dataset_builder, **kwargs): 46 | if is_vit: 47 | train_transforms = get_vit_train_transforms(mean, std, image_size) 48 | val_transforms = get_vit_val_transforms(mean, std, image_size) 49 | else: 50 | train_transforms = get_train_transforms(mean, std) 51 | val_transforms = get_val_transforms(mean, std) 52 | 53 | trainset = dataset_builder(root, train=True, transform=train_transforms, download=True) 54 | valset = dataset_builder(root, train=False, transform=val_transforms, download=True) 55 | 56 | _logger.info(f"Loading {dataset_builder.__name__} dataset with trainset (len={len(trainset)}) and valset (len={len(valset)})") 57 | 
58 | train_sampler = get_samplers(trainset, is_training=True) 59 | val_sampler = get_samplers(valset, is_training=False) 60 | 61 | train_loader = data.DataLoader(trainset, batch_size=batch_size, 62 | shuffle=(train_sampler is None), 63 | sampler=train_sampler, 64 | num_workers=num_workers, 65 | persistent_workers=True) 66 | val_loader = data.DataLoader(valset, batch_size=batch_size, 67 | shuffle=(val_sampler is None), 68 | sampler=val_sampler, 69 | num_workers=num_workers, 70 | persistent_workers=True) 71 | 72 | return train_loader, val_loader 73 | 74 | 75 | @DATA.register 76 | def cifar10(root, image_size, mean, std, batch_size, num_workers, is_vit, **kwargs): 77 | return _cifar( 78 | root, image_size, mean, std, batch_size, num_workers, is_vit, CIFAR10, **kwargs 79 | ) 80 | 81 | 82 | @DATA.register 83 | def cifar100(root, image_size, mean, std, batch_size, num_workers, is_vit, **kwargs): 84 | return _cifar( 85 | root, image_size, mean, std, batch_size, num_workers, is_vit, CIFAR100, **kwargs 86 | ) 87 | -------------------------------------------------------------------------------- /codebase/data/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .native import build_imagenet_loader 2 | from .dali import build_imagenet_dali_loader 3 | 4 | from ..register import DATA 5 | 6 | 7 | @DATA.register 8 | def imagenet2012(root, image_size, mean, std, batch_size, num_workers, use_dali, 9 | use_tfrecord, local_rank, **kwargs): 10 | if use_dali: 11 | return build_imagenet_dali_loader(root, image_size, mean, std, batch_size, num_workers, 12 | use_tfrecord, local_rank) 13 | else: 14 | return build_imagenet_loader(root, image_size, mean, std, batch_size, num_workers) 15 | -------------------------------------------------------------------------------- /codebase/data/imagenet/dali.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import logging 3 | import 
pathlib 4 | 5 | try: 6 | import nvidia.dali.types as types 7 | import nvidia.dali.fn as fn 8 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy 9 | from nvidia.dali.pipeline import Pipeline 10 | import nvidia.dali.tfrecord as tfrec 11 | except ImportError: 12 | warnings.warn("NVIDIA DALI library is unavailable, cannot load and preprocess dataset with DALI.") 13 | 14 | from codebase.torchutils.distributed import world_size, rank 15 | from ..utils import glob_by_suffix 16 | 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | def create_dali_pipeline(reader, image_size, batch_size, mean, std, num_workers, local_rank, 22 | use_tfrecord, dali_cpu=False, is_training=True): 23 | # refer to https://github.com/NVIDIA/DALI/blob/54034c4ddd7cfe2b6dda7e8cec5f91ae18f7ad39/docs/examples/use_cases/pytorch/resnet50/main.py 24 | pipe = Pipeline(batch_size, num_workers, device_id=local_rank) 25 | with pipe: 26 | if use_tfrecord: 27 | images = reader["image"] 28 | labels = reader["label"] 29 | else: 30 | images, labels = reader 31 | # images, labels = fn.external_source(source=eii, num_outputs=2) 32 | dali_device = 'cpu' if dali_cpu else 'gpu' 33 | decoder_device = 'cpu' if dali_cpu else 'mixed' 34 | # ask nvJPEG to preallocate memory for the biggest sample in ImageNet for CPU and GPU to avoid reallocations in runtime 35 | device_memory_padding = 211025920 if decoder_device == 'mixed' else 0 36 | host_memory_padding = 140544512 if decoder_device == 'mixed' else 0 37 | # ask HW NVJPEG to allocate memory ahead for the biggest image in the data set to avoid reallocations in runtime 38 | preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0 39 | preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0 40 | 41 | if is_training: 42 | images = fn.decoders.image_random_crop(images, 43 | device=decoder_device, output_type=types.RGB, 44 | device_memory_padding=device_memory_padding, 45 | host_memory_padding=host_memory_padding, 46 | 
preallocate_width_hint=preallocate_width_hint, 47 | preallocate_height_hint=preallocate_height_hint, 48 | random_aspect_ratio=[0.75, 4.0 / 3.0], 49 | random_area=[0.08, 1.0], 50 | num_attempts=100) 51 | images = fn.resize(images, 52 | device=dali_device, 53 | resize_x=image_size, 54 | resize_y=image_size, 55 | interp_type=types.INTERP_LINEAR) 56 | mirror = fn.random.coin_flip(probability=0.5) 57 | else: 58 | images = fn.decoders.image(images, 59 | device=decoder_device, 60 | output_type=types.RGB) 61 | images = fn.resize(images, 62 | device=dali_device, 63 | size=int(image_size/7*8), 64 | mode="not_smaller", 65 | interp_type=types.INTERP_LINEAR) 66 | mirror = False 67 | 68 | images = fn.crop_mirror_normalize(images.gpu(), 69 | dtype=types.FLOAT, 70 | output_layout="CHW", 71 | crop=(image_size, image_size), 72 | mean=[item * 255 for item in mean], 73 | std=[item * 255 for item in std], 74 | mirror=mirror) 75 | labels = labels.gpu() 76 | pipe.set_outputs(images, labels) 77 | return pipe 78 | 79 | 80 | class DALIWrapper: 81 | def __init__(self, daliiterator): 82 | self.daliiterator = daliiterator 83 | 84 | def __iter__(self): 85 | self._iter = iter(self.daliiterator) 86 | return self 87 | 88 | def __next__(self): 89 | datas = next(self._iter) 90 | inputs = datas[0]["images"] 91 | targets = datas[0]["targets"].squeeze(-1).long() 92 | return inputs, targets 93 | 94 | def __len__(self): 95 | return len(self.daliiterator) 96 | 97 | 98 | def _build_imagenet_dali_loader(root, is_training, image_size, mean, std, batch_size, num_workers, 99 | use_tfrecord, local_rank=None): 100 | 101 | if use_tfrecord: 102 | reader = fn.readers.tfrecord( 103 | path=glob_by_suffix( 104 | pathlib.Path(root)/("train" if is_training else "val"), 105 | "*.tfrecord" 106 | ), 107 | index_path=glob_by_suffix( 108 | pathlib.Path(root)/("train" if is_training else "val"), 109 | "*.idx" 110 | ), 111 | features={ 112 | "fname": tfrec.FixedLenFeature((), tfrec.string, ""), 113 | "image": 
tfrec.FixedLenFeature((), tfrec.string, ""), 114 | "label": tfrec.FixedLenFeature([1], tfrec.int64, -1), 115 | }, 116 | shard_id=rank(), 117 | num_shards=world_size(), 118 | random_shuffle=is_training, 119 | initial_fill=3000, 120 | pad_last_batch=True, 121 | dont_use_mmap=True, # If set to True, the Loader will use plain file I/O 122 | # instead of trying to map the file in memory. Mapping provides a small 123 | # performance benefit when accessing a local file system, but most network 124 | # file systems, do not provide optimum performance. 125 | prefetch_queue_depth=2, 126 | read_ahead=True, 127 | name="Reader") 128 | else: 129 | reader = fn.readers.file( 130 | file_root=pathlib.Path(root)/("train" if is_training else "val"), 131 | shard_id=rank(), 132 | num_shards=world_size(), 133 | random_shuffle=is_training, 134 | pad_last_batch=True, 135 | name="Reader" 136 | ) 137 | pipe = create_dali_pipeline(reader, image_size, batch_size, mean, std, num_workers, local_rank, 138 | use_tfrecord=use_tfrecord, is_training=is_training) 139 | loader = DALIGenericIterator(pipe, 140 | output_map=["images", "targets"], 141 | auto_reset=True, 142 | last_batch_policy=LastBatchPolicy.DROP if is_training else LastBatchPolicy.PARTIAL, 143 | reader_name="Reader") 144 | 145 | loader = DALIWrapper(loader) 146 | 147 | _logger.info(f"Loading ImageNet dataset using DALI from {'tfrecord' if use_tfrecord else 'folder'}" 148 | f" with {'trainset' if is_training else 'valset'} (len={pipe.reader_meta()['Reader']['epoch_size']})") 149 | _logger.info(f"Total batch_size={batch_size*world_size()} with world_size={world_size()}, run with {len(loader)} iters per epoch") 150 | return loader 151 | 152 | 153 | def build_imagenet_dali_loader(root, image_size, mean, std, batch_size, num_workers, 154 | use_tfrecord, local_rank): 155 | return _build_imagenet_dali_loader(root, True, image_size, mean, std, batch_size, num_workers, 156 | use_tfrecord, local_rank),\ 157 | _build_imagenet_dali_loader(root, 
False, image_size, mean, std, batch_size, num_workers, 158 | use_tfrecord, local_rank) 159 | -------------------------------------------------------------------------------- /codebase/data/imagenet/native.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import torch.utils.data as data 5 | import torchvision.transforms as T 6 | from torchvision.datasets import ImageFolder 7 | 8 | from codebase.torchutils.distributed import world_size 9 | from ..utils import get_samplers 10 | 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_train_transforms(crop_size, mean, std, is_training): 16 | pipelines = [] 17 | if is_training: 18 | pipelines.append(T.RandomResizedCrop(crop_size)) 19 | pipelines.append(T.RandomHorizontalFlip()) 20 | else: 21 | pipelines.append(T.Resize(int(crop_size/7*8))) 22 | pipelines.append(T.CenterCrop(crop_size)) 23 | pipelines.append(T.ToTensor()) 24 | pipelines.append(T.Normalize(mean=mean, std=std)) 25 | return T.Compose(pipelines) 26 | 27 | 28 | def _build_imagenet_loader(root, is_training, image_size, mean, std, batch_size, num_workers): 29 | transforms = get_train_transforms(image_size, mean, std, is_training) 30 | 31 | dataset = ImageFolder(pathlib.Path(root)/("train" if is_training else "val"), transform=transforms) 32 | sampler = get_samplers(dataset, is_training) 33 | loader = data.DataLoader(dataset, batch_size=batch_size, 34 | shuffle=(sampler is None), 35 | sampler=sampler, 36 | num_workers=num_workers, 37 | persistent_workers=True, 38 | drop_last=is_training) 39 | _logger.info(f"Loading ImageNet dataset using torchvision from folder" 40 | f" with {'trainset' if is_training else 'valset'} (len={len(dataset)})") 41 | _logger.info(f"Total batch_size={batch_size*world_size()} with world_size={world_size()}, run with {len(loader)} iters per epoch") 42 | 43 | return loader 44 | 45 | 46 | def build_imagenet_loader(root, image_size, mean, std, 
class SyntheticDataLoader:
    """Loader-like iterable that yields one fixed random batch repeatedly.

    A single ``(images, targets)`` pair is created up front on ``device`` and
    the very same tensors are returned on every step, so iterating involves no
    data loading or preprocessing at all.

    Args:
        input_size: Shape of the image batch tensor.
        target_size: Shape of the target tensor.
        num_classes: Exclusive upper bound of the random integer labels.
        length: Number of batches produced per pass; also what ``len()`` returns.
        device: Device on which both tensors are allocated.
    """

    def __init__(self, input_size, target_size, num_classes, length: int, device="cuda"):
        self.images = torch.rand(input_size, device=device, dtype=torch.float)
        self.targets = torch.randint(0, num_classes, target_size, device=device, dtype=torch.long)
        self.length = length
        # Defined here so that next() before iter() does not hit a missing attribute.
        self.n = 0

    def __iter__(self):
        # Restart the counter so the loader can be iterated once per epoch.
        self.n = 0
        return self

    def __next__(self):
        # Check before counting the batch: exactly ``length`` batches are
        # produced (matching ``len()``), and ``length == 0`` stops immediately
        # instead of iterating forever.
        if self.n >= self.length:
            raise StopIteration
        self.n += 1
        return self.images, self.targets

    def __len__(self):
        return self.length
def _run_one_epoch(is_training: bool,
                   epoch: int,
                   model: nn.Module,
                   loader: data.DataLoader,
                   criterion: nn.modules.loss._Loss,
                   optimizer: optim.Optimizer,
                   scheduler: optim.lr_scheduler._LRScheduler,
                   use_amp: bool,
                   accmulated_steps: int,
                   device: str,
                   memory_format: str,
                   log_interval: int):
    """Run one full pass over ``loader`` in training or evaluation mode.

    Args:
        is_training: When true the model is put in train mode, gradients and
            autocast (if ``use_amp``) are enabled and the scheduler is stepped;
            otherwise the pass is a pure evaluation.
        epoch: Epoch index, used for logging and passed to ``scheduler.step``.
        model: Network to run.
        loader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer; its first param group's ``lr`` is logged and returned.
        scheduler: Optional scheduler; stepped once per training epoch when not None.
        use_amp: Enables autocast / gradient scaling (training passes only).
        accmulated_steps: Forwarded to ``GradientAccumulator(steps=...)``.
        device: Target device for inputs and targets.
        memory_format: Forwarded to ``inputs.to(memory_format=...)``.
            NOTE(review): annotated ``str`` here, but the caller visible in this
            code base passes ``getattr(torch, ...)``, i.e. a torch memory format
            object — confirm and tighten the annotation.
        log_interval: Progress is logged every ``log_interval`` iterations and
            on the last iteration.

    Returns:
        dict with keys ``"<phase>/lr"``, ``"<phase>/loss"``, ``"<phase>/top1_acc"``
        and ``"<phase>/top5_acc"`` where ``<phase>`` is ``"train"`` or ``"eval"``.
    """
    phase = "train" if is_training else "eval"
    model.train(mode=is_training)

    # The scaler is a module-level singleton created lazily on the first call
    # and reused afterwards.
    # NOTE(review): its `enabled` flag is therefore fixed by whichever call
    # happens first (train or eval) — confirm this is intended.
    global scaler
    if scaler is None:
        scaler = GradScaler(enabled=use_amp and is_training)

    # Presumably performs backward / optimizer stepping every
    # `accmulated_steps` iterations and does nothing when disabled — verify in
    # codebase.torchutils.common.
    gradident_accumulator = GradientAccumulator(steps=accmulated_steps, enabled=is_training)

    time_cost_metric = AverageMetric("time_cost")
    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    eta = EstimatedTimeArrival(len(loader))
    speed_tester = ThroughputTester()

    # The learning rate is updated once per epoch, before any batch is processed.
    if is_training and scheduler is not None:
        scheduler.step(epoch)

    # Only the first param group's learning rate is reported.
    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"{phase.upper()} start, epoch={epoch:04d}, lr={lr:.6f}")

    # time_enumerate yields (time_cost, index, item) with indices starting at 1;
    # time_cost is logged below as the data-fetch time.
    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs = inputs.to(device=device, non_blocking=True, memory_format=memory_format)
        targets = targets.to(device=device, non_blocking=True)

        # Gradients and mixed precision are only active during training.
        with torch.set_grad_enabled(mode=is_training):
            with autocast(enabled=use_amp and is_training):
                outputs = model(inputs)
                loss: torch.Tensor = criterion(outputs, targets)

        gradident_accumulator.backward_step(model, loss, optimizer, scaler)

        time_cost_metric.update(time_cost)
        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        eta.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == len(loader):
            _logger.info(", ".join([
                phase.upper(),
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{len(loader):05d}",
                f"fetch data time cost={time_cost_metric.compute()*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{eta}",
            ]))
            # Fetch time and throughput are reported per logging window, so
            # they are reset after each log line; loss/accuracy keep
            # accumulating over the whole epoch.
            time_cost_metric.reset()
            speed_tester.reset()

    _logger.info(", ".join([
        phase.upper(),
        f"epoch={epoch:04d} {phase} complete",
        f"{loss_metric}",
        f"{accuracy_metric}",
    ]))

    return {
        f"{phase}/lr": lr,
        f"{phase}/loss": loss_metric.compute(),
        f"{phase}/top1_acc": accuracy_metric.at(1).rate,
        f"{phase}/top5_acc": accuracy_metric.at(5).rate,
    }
def excute_pipeline(
    only_evaluate: bool,
    start_epoch: int,
    max_epochs: int,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    writer: SummaryWriter,
    state_ckpt: StateCheckPoint,
    states: dict,
    metric_store: MetricsStore,
    **kwargs
):
    """Run the evaluate-only pass or the full train/evaluate loop.

    Args:
        only_evaluate: When true, evaluate once on ``val_loader`` (epoch 0),
            record the metrics and return without training.
        start_epoch: Number of epochs already completed; training resumes at
            ``start_epoch + 1``.
        max_epochs: Last epoch index (inclusive).
        train_loader: Loader used by ``train_one_epoch``.
        val_loader: Loader used by ``evaluate_one_epoch``.
        writer: Receives the latest metrics through ``add_scalar`` each epoch.
        state_ckpt: Checkpoint helper; ``save`` is called after every epoch.
        states: Objects handed to ``state_ckpt.save``.
        metric_store: Accumulates the per-epoch metric dicts via ``+=``.
        **kwargs: Forwarded unchanged to ``train_one_epoch`` and
            ``evaluate_one_epoch`` (model, optimizer, criterion, ...).
    """
    if only_evaluate:
        metric_store += evaluate_one_epoch(
            epoch=0,
            loader=val_loader,
            **kwargs
        )
        return

    eta = EstimatedTimeArrival(max_epochs)

    for epoch in range(start_epoch+1, max_epochs+1):
        if is_dist_avail_and_init():
            # Each loader is checked on its own: a loader may have no
            # `sampler`, or a sampler without `set_epoch`, and one loader
            # lacking it must not affect the other.
            for epoch_loader in (train_loader, val_loader):
                sampler = getattr(epoch_loader, "sampler", None)
                set_epoch = getattr(sampler, "set_epoch", None)
                if callable(set_epoch):
                    set_epoch(epoch)

        metric_store += train_one_epoch(
            epoch=epoch,
            loader=train_loader,
            **kwargs
        )

        metric_store += evaluate_one_epoch(
            epoch=epoch,
            loader=val_loader,
            **kwargs
        )

        for name, metric in metric_store.get_last_metrics().items():
            writer.add_scalar(name, metric, epoch)

        state_ckpt.save(metric_store=metric_store, states=states)

        eta.step()

        best_metrics = metric_store.get_best_metrics()
        _logger.info(f"Epoch={epoch:04d} complete, best val top1-acc={best_metrics['eval/top1_acc']*100:.2f}%, "
                     f"top5-acc={best_metrics['eval/top5_acc']*100:.2f}% (epoch={metric_store.best_epoch+1}), {eta}")
def _init(local_rank: int, ngpus_per_node: int, args: Args):
    """Per-process initialisation: device, logging, determinism, distributed.

    Args:
        local_rank: Index of this process on the current node.
        ngpus_per_node: Number of processes/GPUs per node, used to derive the
            global rank as ``node_rank * ngpus_per_node + local_rank``.
        args: Parsed arguments; ``node_rank``, ``output_dir``, ``conf``,
            ``dist_backend``, ``dist_url`` and ``world_size`` are read.
    """
    set_proper_device(local_rank)
    rank = args.node_rank*ngpus_per_node+local_rank
    # NOTE(review): `filenmae` looks like a misspelling of `filename`, but it
    # is the keyword init_logger is called with here; confirm against
    # codebase.torchutils.logging before renaming.
    init_logger(rank=rank, filenmae=args.output_dir/"default.log")

    # patch_download_in_cn()

    if StateCheckPoint(args.output_dir).is_ckpt_exists():
        _logger.info("-"*30+"Resume from the last training checkpoints."+"-"*30)

    # Honour the `set_reproducible` configuration flag. The previous
    # `if set_reproducible:` tested the imported function object, which is
    # always truthy, so the auto-tune branch could never run.
    if args.conf.get_bool("set_reproducible", False):
        set_reproducible(generate_random_seed())
    else:
        set_cudnn_auto_tune()
        disable_debug_api()

    create_code_snapshot(name="code", include_suffix=[".py", ".conf"],
                         source_directory=".", store_directory=args.output_dir)

    _logger.info("Collect envs from system:\n" + get_pretty_env_info())
    _logger.info("Args:\n" + pprint.pformat(dataclasses.asdict(args)))

    distributed_init(dist_backend=args.dist_backend, init_method=args.dist_url,
                     world_size=args.world_size, rank=rank)
def main(args: Args):
    """Entry point: run a single worker in-process or spawn one per GPU.

    With ``args.world_size`` greater than one, ``main_worker`` is spawned
    once per visible CUDA device; otherwise it is called directly with
    local rank 0.
    """
    single_process = not args.world_size > 1
    gpu_count = torch.cuda.device_count()

    if single_process:
        main_worker(0, gpu_count, args, args.conf)
        return

    # mp.spawn supplies the process index as the first positional argument,
    # which main_worker receives as its local rank.
    mp.spawn(main_worker, nprocs=gpu_count, args=(gpu_count, args, args.conf))
@OPTIMIZER.register
def CustomSGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, bn_weight_decay=None, **kwargs):
    """Build ``torch.optim.SGD``, optionally with a separate weight decay for BN layers.

    Args:
        params: Iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``. It may be a one-shot generator.
        lr: Learning rate.
        momentum: SGD momentum.
        dampening: SGD dampening.
        weight_decay: Weight decay for every parameter outside the BN group.
        nesterov: Enables Nesterov momentum.
        bn_weight_decay: When not None, parameters whose name contains ``"bn"``
            are placed in a second param group using this weight decay.
        **kwargs: Accepted and ignored.

    Returns:
        The configured ``torch.optim.SGD`` instance.
    """
    # Materialise once: a generator would be exhausted by the first pass,
    # leaving the BN group empty and the BN parameters out of the optimizer.
    named_params = list(params)

    if bn_weight_decay is None:
        return optim.SGD([p for _, p in named_params], lr, momentum, dampening, weight_decay, nesterov)

    # NOTE(review): selection is by the substring "bn" in the parameter name;
    # normalisation layers named differently stay in the basic group.
    basic_params = [p for name, p in named_params if "bn" not in name]
    bn_params = [p for name, p in named_params if "bn" in name]

    param_groups = [
        dict(params=basic_params),
        dict(params=bn_params, weight_decay=bn_weight_decay),
    ]
    return optim.SGD(param_groups, lr, momentum, dampening, weight_decay, nesterov)
class WarmupExponentialLR(object):
    """Linear warm-up followed by exponential decay of the learning rate.

    The base learning rate is read from the optimizer's first param group at
    construction time. ``step(epoch)`` writes the same learning rate into
    every param group:

    * ``epoch < T_warmup``: ``base_lr * epoch / T_warmup``
    * otherwise: ``base_lr * lambda_ ** (epoch - T_warmup)``

    ``last_epoch`` is accepted but not used.
    """

    def __init__(self, optimizer, T_warmup, lambda_, last_epoch=-1):
        first_group = optimizer.param_groups[0]

        self.optimizer = optimizer
        self.base_lr = first_group["lr"]
        self.T_warmup = T_warmup
        self.lambda_ = lambda_

    def step(self, epoch):
        in_warmup = epoch < self.T_warmup
        new_lr = (
            self.base_lr * epoch / self.T_warmup
            if in_warmup
            else self.base_lr * self.lambda_**(epoch - self.T_warmup)
        )
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
-------------------------------------------------------------------------------- /conf/cifar10.conf: -------------------------------------------------------------------------------- 1 | include "base.conf" 2 | 3 | max_epochs: 200 4 | 5 | # refer to https://pytorch.org/docs/stable/tensor_attributes.html?highlight=memory_format#torch.torch.memory_format 6 | # highly recommend using 'channels_last' on NVIDIA Tesla A100, V100 and RTX 3090 with typical CNNs 7 | memory_format: "contiguous_format" # select from 'contiguous_format' | 'channels_last' | 'preserve_format' 8 | 9 | use_amp: false # if true, it will train in automatic mixed precision mode 10 | 11 | data { 12 | type_: cifar10 13 | 14 | is_vit: false 15 | 16 | image_size: 32 17 | num_classes: 10 18 | 19 | root: data/cifar10 20 | 21 | mean: [0.4914, 0.4822, 0.4465] 22 | std: [0.2023, 0.1994, 0.2010] 23 | 24 | batch_size: 256 25 | num_workers: 4 26 | } 27 | 28 | model { 29 | type_: PyTorchHub 30 | force_reload: false 31 | repo: chenyaofo/pytorch-cifar-models 32 | name: cifar10_resnet20 33 | pretrained: false 34 | load_from: null 35 | } 36 | 37 | optimizer { 38 | type_: CustomSGD 39 | basic_bs: 256 40 | lr: 0.1 41 | momentum: 0.9 42 | dampening: 0 43 | weight_decay: 5e-4 44 | nesterov: true 45 | } 46 | 47 | scheduler { 48 | type_: CosineAnnealingLR 49 | T_max: ${max_epochs} 50 | eta_min: 0 51 | } 52 | 53 | criterion { 54 | type_: CrossEntropyLoss 55 | } 56 | -------------------------------------------------------------------------------- /conf/cifar100.conf: -------------------------------------------------------------------------------- 1 | include "cifar10.conf" 2 | 3 | data { 4 | type_: cifar100 5 | 6 | image_size: 32 7 | num_classes: 100 8 | 9 | root: data/cifar100 10 | 11 | mean: [0.5070, 0.4865, 0.4409] 12 | std: [0.2673, 0.2564, 0.2761] 13 | } 14 | 15 | model { 16 | type_: PyTorchHub 17 | force_reload: false 18 | repo: chenyaofo/pytorch-cifar-models 19 | name: cifar100_resnet20 20 | } 21 |
-------------------------------------------------------------------------------- /conf/resnet50-benchmark.conf: -------------------------------------------------------------------------------- 1 | include "base.conf" 2 | 3 | max_epochs: 1 4 | 5 | memory_format: "channels_last" # select from 'contiguous_format' | 'channels_last' | 'preserve_format' 6 | 7 | use_amp: true # if true, it will train in automatic mixed precision mode 8 | 9 | data { 10 | type_: synthetic_data 11 | 12 | image_size: 224 13 | batch_size: 256 14 | 15 | input_size: [256,3,224,224] 16 | target_size: [256] 17 | device: cuda 18 | length: 12800 19 | 20 | num_classes: 1000 21 | } 22 | 23 | model { 24 | type_: resnet50 25 | pretrained: false 26 | load_from: null 27 | } 28 | 29 | optimizer { 30 | type_: CustomSGD 31 | lr: 0.256 32 | basic_bs: 256 # we set lr=0.256 for 256 batch size, for other batch sizes we linearly scale the learning rate. 33 | momentum: 0.875 34 | dampening: 0 35 | weight_decay: 3.0517578125e-05 # refer to https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5 36 | nesterov: false 37 | bn_weight_decay: 0 38 | } 39 | 40 | scheduler { 41 | type_: WarmupCosineAnnealingLR 42 | T_warmup: 5 43 | T_max: ${max_epochs} 44 | eta_min: 0 45 | } 46 | 47 | 48 | criterion { 49 | type_: CrossEntropyLoss 50 | } 51 | -------------------------------------------------------------------------------- /conf/resnet50-tfrec-v1_5.conf: -------------------------------------------------------------------------------- 1 | include "base.conf" 2 | 3 | max_epochs: 90 4 | 5 | memory_format: "channels_last" # select from 'contiguous_format' | 'channels_last' | 'preserve_format' 6 | 7 | use_amp: true # if true, it will train in automatic mixed precision mode 8 | 9 | data { 10 | type_: imagenet2012 11 | 12 | image_size: 224 13 | num_classes: 1000 14 | 15 | root: data/imagenet2012-tfrec 16 | 17 | mean: [0.485, 0.456, 0.406] 18 | std: [0.229, 0.224, 0.225] 19 | 20 | 
batch_size: 256 21 | num_workers: 8 22 | 23 | use_dali: true # if true, use NVIDIA DALI to preprocess the images instead of torchvision 24 | use_tfrecord: true # if true, use tfrecord format to load images instead of loading from image folder 25 | 26 | } 27 | 28 | model { 29 | type_: resnet50 30 | pretrained: false 31 | load_from: null 32 | } 33 | 34 | optimizer { 35 | type_: CustomSGD 36 | lr: 0.256 37 | basic_bs: 256 # we set lr=0.256 for 256 batch size, for other batch sizes we linearly scale the learning rate. 38 | momentum: 0.875 39 | dampening: 0 40 | weight_decay: 3.0517578125e-05 # refer to https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5 41 | nesterov: false 42 | bn_weight_decay: 0 43 | } 44 | 45 | scheduler { 46 | type_: WarmupCosineAnnealingLR 47 | T_warmup: 5 48 | T_max: ${max_epochs} 49 | eta_min: 0 50 | } 51 | 52 | criterion { 53 | type_: LabelSmoothCrossEntropyLoss 54 | num_classes: ${data.num_classes} 55 | epsilon: 0.1 56 | } 57 | -------------------------------------------------------------------------------- /conf/vit_cifar10.conf: -------------------------------------------------------------------------------- 1 | include "cifar10.conf" 2 | 3 | data { 4 | type_: cifar10 5 | 6 | is_vit: true 7 | 8 | image_size: 224 9 | num_classes: 10 10 | 11 | root: data/cifar10 12 | 13 | mean: [0.4914, 0.4822, 0.4465] 14 | std: [0.2023, 0.1994, 0.2010] 15 | 16 | batch_size: 32 17 | num_workers: 4 18 | } 19 | 20 | model { 21 | type_: PyTorchHub 22 | force_reload: false 23 | repo: chenyaofo/pytorch-cifar-models 24 | name: cifar10_vit_b16 25 | pretrained: false 26 | } 27 | -------------------------------------------------------------------------------- /doc/benchmark.md: -------------------------------------------------------------------------------- 1 | ## Throughput Benchmark 2 | 3 | We test this code on NVIDIA A100 and report the throughput below.
4 | 5 | | settings | throughput (imgs/s) | 6 | | --- | --- | 7 | | baseline | 928 | 8 | | +channels_last | 992 | 9 | | +amp | 1459 | 10 | | +channels_last&& | 2260 | 11 | 12 | > Check for NVIDIA impl and **Throughput Benchmark** at https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ConvNets/resnet50v1.5/README.md#training-performance-results 13 | 14 | Test environment: 15 | 16 | ``` 17 | PyTorch version: 1.12.1+cu113 18 | Is debug build: False 19 | CUDA used to build PyTorch: 11.3 20 | ROCM used to build PyTorch: N/A 21 | 22 | OS: Ubuntu 20.04.4 LTS (x86_64) 23 | GCC version: Could not collect 24 | Clang version: Could not collect 25 | CMake version: Could not collect 26 | Libc version: glibc-2.31 27 | 28 | Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] (64-bit runtime) 29 | Python platform: Linux-4.15.0-192-generic-x86_64-with-glibc2.31 30 | Is CUDA available: True 31 | CUDA runtime version: Could not collect 32 | GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB 33 | Nvidia driver version: 470.129.06 34 | cuDNN version: Could not collect 35 | HIP runtime version: N/A 36 | MIOpen runtime version: N/A 37 | Is XNNPACK available: True 38 | 39 | Versions of relevant libraries: 40 | [pip3] numpy==1.23.3 41 | [pip3] pytorch-lightning==1.7.6 42 | [pip3] torch==1.12.1+cu113 43 | [pip3] torchaudio==0.12.1+cu113 44 | [pip3] torchmetrics==0.9.3 45 | [pip3] torchvision==0.13.1+cu113 46 | [conda] numpy 1.23.3 py310h53a5b5f_0 conda-forge 47 | [conda] pytorch-lightning 1.7.6 pypi_0 pypi 48 | [conda] torch 1.12.1+cu113 pypi_0 pypi 49 | [conda] torchaudio 0.12.1+cu113 pypi_0 pypi 50 | [conda] torchmetrics 0.9.3 pypi_0 pypi 51 | [conda] torchvision 0.13.1+cu113 pypi_0 pypi 52 | ``` 53 | 54 | ## More Benchmarks on PyTorch 2.0 55 | 56 | We test this code on NVIDIA V100 and report the throughput in the followings. 
57 | 58 | - Benchmarks on ResNet-50 59 | 60 | | settings | throughput (imgs/s) | 61 | | --- | --- | 62 | | baseline | 345 | 63 | | +channels_last | 345 | 64 | | +amp | 774 | 65 | | +channels_last&& | 1175 | 66 | | +channels_last&&&&compile(default) | 1228 | 67 | | +channels_last&&&&compile(default+fullgraph) | 1228 | 68 | | +channels_last&&&&compile(reduce-overhead) | 1234 | 69 | | +channels_last&&&&compile(max-autotune) | FAIL | 70 | 71 | - Benchmarks on MobileNetV2 72 | 73 | | settings | throughput (imgs/s) | 74 | | --- | --- | 75 | | baseline | 813 | 76 | | +channels_last | 420 | 77 | | +amp | 1315 | 78 | | +channels_last&& | 2100 | 79 | | +channels_last&&&&compile(default) | 2316 | 80 | 81 | - Benchmarks on ShuffleNetV2 82 | 83 | | settings | throughput (imgs/s) | 84 | | --- | --- | 85 | | baseline | 2342 | 86 | | +channels_last | 1854 | 87 | | +amp | 3250 | 88 | | +channels_last&& | 3862 | 89 | | +channels_last&&&&compile(default) | 4711 | 90 | 91 | - Benchmarks on ViT-B16 92 | 93 | | settings | throughput (imgs/s) | 94 | | --- | --- | 95 | | baseline | 102 | 96 | | +amp | 360 | 97 | | +amp&&compile(default) | 289 | 98 | 99 | - Benchmarks on SwinTransformer-tiny 100 | 101 | | settings | throughput (imgs/s) | 102 | | --- | --- | 103 | | baseline | 264 | 104 | | +amp | 499 | 105 | | +amp&&compile(default) | 789 | 106 | 107 | Test environment: 108 | 109 | ``` 110 | PyTorch version: 2.0.0+cu118 111 | Is debug build: False 112 | CUDA used to build PyTorch: 11.8 113 | ROCM used to build PyTorch: N/A 114 | 115 | OS: Ubuntu 22.04.1 LTS (x86_64) 116 | GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0 117 | Clang version: Could not collect 118 | CMake version: version 3.25.0 119 | Libc version: glibc-2.35 120 | 121 | Python version: 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0] (64-bit runtime) 122 | Python platform: Linux-5.4.0-139-generic-x86_64-with-glibc2.35 123 | Is CUDA available: True 124 | CUDA runtime version: 11.8.89 125 | 
CUDA_MODULE_LOADING set to: LAZY 126 | GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB 127 | Nvidia driver version: 525.85.12 128 | cuDNN version: Probably one of the following: 129 | /usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0 130 | /usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0 131 | /usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0 132 | /usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0 133 | /usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0 134 | /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0 135 | /usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0 136 | HIP runtime version: N/A 137 | MIOpen runtime version: N/A 138 | Is XNNPACK available: True 139 | 140 | CPU: 141 | Architecture: x86_64 142 | CPU op-mode(s): 32-bit, 64-bit 143 | Address sizes: 46 bits physical, 48 bits virtual 144 | Byte Order: Little Endian 145 | CPU(s): 10 146 | On-line CPU(s) list: 0-9 147 | Vendor ID: GenuineIntel 148 | Model name: Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz 149 | CPU family: 6 150 | Model: 85 151 | Thread(s) per core: 1 152 | Core(s) per socket: 10 153 | Socket(s): 1 154 | Stepping: 5 155 | BogoMIPS: 4999.99 156 | Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat avx512_vnni 157 | Hypervisor vendor: KVM 158 | Virtualization type: full 159 | L1d cache: 320 KiB (10 instances) 160 | L1i cache: 320 KiB (10 instances) 161 | L2 cache: 40 MiB (10 instances) 162 | L3 cache: 35.8 MiB (1 instance) 163 | NUMA node(s): 1 164 | NUMA node0 CPU(s): 0-9 165 | Vulnerability Itlb multihit: 
KVM: Vulnerable 166 | Vulnerability L1tf: Mitigation; PTE Inversion 167 | Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown 168 | Vulnerability Meltdown: Mitigation; PTI 169 | Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown 170 | Vulnerability Retbleed: Vulnerable 171 | Vulnerability Spec store bypass: Vulnerable 172 | Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization 173 | Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected 174 | Vulnerability Srbds: Not affected 175 | Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown 176 | 177 | Versions of relevant libraries: 178 | [pip3] numpy==1.23.5 179 | [pip3] torch==2.0.0+cu118 180 | [pip3] torchaudio==2.0.1+cu118 181 | [pip3] torchdata==0.6.0 182 | [pip3] torchvision==0.15.1+cu118 183 | [conda] numpy 1.23.5 pypi_0 pypi 184 | [conda] torch 2.0.0+cu118 pypi_0 pypi 185 | [conda] torchaudio 2.0.1+cu118 pypi_0 pypi 186 | [conda] torchdata 0.6.0 pypi_0 pypi 187 | [conda] torchvision 0.15.1+cu118 pypi_0 pypi 188 | ``` -------------------------------------------------------------------------------- /entry/run.py: -------------------------------------------------------------------------------- 1 | from codebase.config import get_args 2 | from codebase.main import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main(get_args()) 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyhocon 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chenyaofo/image-classification-codebase/6bce2f0224552e1666ae254c362e6f03a38c0f8a/tests/__init__.py -------------------------------------------------------------------------------- /tests/codebase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyaofo/image-classification-codebase/6bce2f0224552e1666ae254c362e6f03a38c0f8a/tests/codebase/__init__.py -------------------------------------------------------------------------------- /tests/codebase/test_criterion.py: -------------------------------------------------------------------------------- 1 | from codebase.criterion import CRITERION 2 | 3 | 4 | def test_criterion(): 5 | CRITERION.build_from( 6 | dict( 7 | type_="LabelSmoothCrossEntropyLoss", 8 | num_classes=1000 9 | ) 10 | ) 11 | -------------------------------------------------------------------------------- /tests/codebase/test_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import pkg_resources 4 | import tempfile 5 | 6 | from codebase.config import Args, get_args 7 | from codebase.main import main 8 | 9 | 10 | def test_main(): 11 | with tempfile.TemporaryDirectory() as tmpdirname: 12 | args = get_args( 13 | [ 14 | "-o", 15 | tmpdirname, 16 | "--conf", 17 | pkg_resources.resource_filename('tests.resources', 'test.conf'), 18 | "-M", 19 | "max_epochs=3" 20 | ] 21 | ) 22 | 23 | main(args) 24 | -------------------------------------------------------------------------------- /tests/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyaofo/image-classification-codebase/6bce2f0224552e1666ae254c362e6f03a38c0f8a/tests/resources/__init__.py -------------------------------------------------------------------------------- /tests/resources/test.conf: 
-------------------------------------------------------------------------------- 1 | max_epochs: 2 2 | 3 | log_interval: 10 4 | 5 | # refer to https://pytorch.org/docs/stable/tensor_attributes.html?highlight=memory_format#torch.torch.memory_format 6 | # highly recommended to use 'channels_last' on NVIDIA Tesla A100, V100 and RTX 3090 with typical CNNs 7 | memory_format: "channels_last" # select from 'contiguous_format' | 'channels_last' | 'preserve_format' 8 | 9 | use_amp: false # if true, it will train in automatic mixed precision mode 10 | 11 | only_evaluate: false # if true, it will only evaluate the model on the validation set and exit 12 | 13 | sync_batchnorm: false # if true, it will convert all the batchnorm layers into torch.nn.SyncBatchNorm 14 | 15 | accmulated_steps: 1 16 | 17 | set_reproducible: false # if true, the training will be set to reproducible (refer to https://pytorch.org/docs/stable/notes/randomness.html) 18 | # else torch.backends.cudnn.benchmark will be set to True for largest throughput 19 | 20 | data { 21 | type_: synthetic_data 22 | 23 | image_size: 224 24 | batch_size: 16 25 | 26 | input_size: [16,3,224,224] 27 | target_size: [16] 28 | device: cpu 29 | length: 128 30 | 31 | num_classes: 1000 32 | } 33 | 34 | model { 35 | type_: dummy_model 36 | load_from: null 37 | } 38 | 39 | optimizer { 40 | type_: CustomSGD 41 | lr: 0.256 42 | basic_bs: 256 # we set lr=0.256 for a batch size of 256; for other batch sizes we linearly scale the learning rate. 
43 | momentum: 0.875 44 | dampening: 0 45 | weight_decay: 3.0517578125e-05 # refer to https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5 46 | nesterov: false 47 | bn_weight_decay: 0 48 | } 49 | 50 | scheduler { 51 | type_: WarmupCosineAnnealingLR 52 | T_warmup: 5 53 | T_max: ${max_epochs} 54 | eta_min: 0 55 | } 56 | 57 | criterion { 58 | type_: CrossEntropyLoss 59 | } 60 | -------------------------------------------------------------------------------- /tools/make_tfrecord.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script aims to create tfrecord tar shards with multi-processing. 3 | ''' 4 | 5 | import os 6 | import random 7 | import datetime 8 | from multiprocessing import Process 9 | from torchvision.datasets.folder import ImageFolder 10 | 11 | from codebase.torchutils.serialization import jsonpack 12 | 13 | import struct 14 | import tfrecord 15 | 16 | def create_index(tfrecord_file: str, index_file: str) -> None: 17 | """ 18 | refer to https://github.com/vahidk/tfrecord/blob/master/tfrecord/tools/tfrecord2idx.py 19 | Create index from the tfrecords file. 20 | Stores starting location (byte) and length (in bytes) of each 21 | serialized record. 22 | Params: 23 | ------- 24 | tfrecord_file: str 25 | Path to the TFRecord file. 26 | index_file: str 27 | Path where to store the index file. 
28 | """ 29 | infile = open(tfrecord_file, "rb") 30 | outfile = open(index_file, "w") 31 | 32 | while True: 33 | current = infile.tell() 34 | try: 35 | byte_len = infile.read(8) 36 | if len(byte_len) == 0: 37 | break 38 | infile.read(4) 39 | proto_len = struct.unpack("q", byte_len)[0] 40 | infile.read(proto_len) 41 | infile.read(4) 42 | outfile.write(str(current) + " " + str(infile.tell() - current) + "\n") 43 | except: 44 | print("Failed to parse TFRecord.") 45 | break 46 | infile.close() 47 | outfile.close() 48 | 49 | 50 | def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs): 51 | random.shuffle(samples) 52 | samples_per_shards = [samples[i::num_shards] for i in range(num_shards)] 53 | shard_ids = list(range(num_shards)) 54 | processes = [ 55 | Process( 56 | target=write_partial_samples, 57 | args=( 58 | pattern, 59 | shard_ids[i::num_workers], 60 | samples_per_shards[i::num_workers], 61 | map_func, 62 | kwargs 63 | ) 64 | ) 65 | for i in range(num_workers)] 66 | for p in processes: 67 | p.start() 68 | for p in processes: 69 | p.join() 70 | 71 | 72 | def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs): 73 | for shard_id, samples in zip(shard_ids, samples): 74 | write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs) 75 | 76 | 77 | def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs): 78 | fname = pattern % shard_id 79 | print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}") 80 | # stream = TarWriter(fname, **kwargs) 81 | writer = tfrecord.TFRecordWriter(fname) 82 | size = 0 83 | for i, item in enumerate(samples): 84 | raw_data = map_func(item) 85 | size += len(raw_data["image"][0]) 86 | writer.write(raw_data) 87 | 88 | if i % 1000 == 0: 89 | print(f"[{datetime.datetime.now()}] complete to write {i:06d} samples to shard {fname}") 90 | writer.close() 91 | print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}!!!") 92 | 
create_index(fname, fname+".idx") 93 | print(f"[{datetime.datetime.now()}] complete tfrecord2idx to shard {fname}!!!") 94 | return size 95 | 96 | 97 | def main(dataset_root, dataset_split_root, dest, pattern, num_shards, num_workers): 98 | items = [] 99 | dataset = ImageFolder(root=dataset_split_root, loader=lambda x: x) 100 | for i in range(len(dataset)): 101 | path, class_idx = dataset[i] 102 | relpath = os.path.relpath(path, dataset_root) 103 | items.append((path, relpath, class_idx)) 104 | 105 | def map_func(item): 106 | path, relpath, class_idx = item 107 | with open(os.path.join(path), "rb") as stream: 108 | image = stream.read() 109 | 110 | sample = { 111 | # "fname": (bytes(os.path.splitext(os.path.basename(name))[0], "utf-8"), "byte"), 112 | "metadata": ( 113 | jsonpack(dict(path=relpath), maxlen=128), 114 | "byte" 115 | ), 116 | "image": (image, "byte"), 117 | "label": (class_idx, "int") 118 | } 119 | return sample 120 | 121 | os.makedirs(dest, exist_ok=False) 122 | make_wds_shards( 123 | pattern=os.path.join(dest, pattern), 124 | num_shards=num_shards, 125 | num_workers=num_workers, 126 | samples=items, 127 | map_func=map_func, 128 | ) 129 | 130 | 131 | if __name__ == "__main__": 132 | source = os.path.expanduser("/home/chenyaofo/datasets/imagenet") 133 | dest = os.path.expanduser("/home/chenyaofo/datasets/imagenet-tfrec") 134 | main( 135 | dataset_root=source, 136 | dataset_split_root=os.path.join(source, "train"), 137 | dest=os.path.join(dest, "train"), 138 | pattern="%06d.tfrecord", 139 | num_shards=1024, 140 | num_workers=8 141 | ) 142 | main( 143 | dataset_root=source, 144 | dataset_split_root=os.path.join(source, "val"), 145 | dest=os.path.join(dest, "val"), 146 | pattern="%06d.tfrecord", 147 | num_shards=256, 148 | num_workers=8 149 | ) 150 | -------------------------------------------------------------------------------- /tools/make_wds.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script 
aims to create webdataset tar shards with multi-processing. 3 | ''' 4 | 5 | import os 6 | import random 7 | import datetime 8 | from multiprocessing import Process 9 | from torchvision.datasets.folder import ImageFolder 10 | from webdataset import TarWriter 11 | 12 | 13 | def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs): 14 | random.shuffle(samples) 15 | samples_per_shards = [samples[i::num_shards] for i in range(num_shards)] 16 | shard_ids = list(range(num_shards)) 17 | processes = [ 18 | Process( 19 | target=write_partial_samples, 20 | args=( 21 | pattern, 22 | shard_ids[i::num_workers], 23 | samples_per_shards[i::num_workers], 24 | map_func, 25 | kwargs 26 | ) 27 | ) 28 | for i in range(num_workers)] 29 | for p in processes: 30 | p.start() 31 | for p in processes: 32 | p.join() 33 | 34 | 35 | def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs): 36 | for shard_id, samples in zip(shard_ids, samples): 37 | write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs) 38 | 39 | 40 | def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs): 41 | fname = pattern % shard_id 42 | print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}") 43 | stream = TarWriter(fname, **kwargs) 44 | size = 0 45 | for i, item in enumerate(samples): 46 | size += stream.write(map_func(item)) 47 | if i % 1000 == 0: 48 | print(f"[{datetime.datetime.now()}] complete to write {i:06d} samples to shard {fname}") 49 | stream.close() 50 | print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}!!!") 51 | return size 52 | 53 | 54 | def main(source, dest, num_shards, num_workers): 55 | root = source 56 | items = [] 57 | dataset = ImageFolder(root=root, loader=lambda x: x) 58 | for i in range(len(dataset)): 59 | items.append(dataset[i]) 60 | 61 | def map_func(item): 62 | name, class_idx = item 63 | with open(os.path.join(name), "rb") as stream: 64 | image = 
stream.read() 65 | sample = { 66 | "__key__": os.path.splitext(os.path.basename(name))[0], 67 | "jpg": image, 68 | "cls": str(class_idx).encode("ascii") 69 | } 70 | return sample 71 | make_wds_shards( 72 | pattern=dest, 73 | num_shards=num_shards, # 设置分片数量 74 | num_workers=num_workers, # 设置创建wds数据集的进程数 75 | samples=items, 76 | map_func=map_func, 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | main( 82 | source="/mnt/cephfs/dataset/imagenet/train", 83 | dest="/mnt/cephfs/home/chenyaofo/datasets/imagenet-wds/train/imagenet-1k-train-%06d.tar", 84 | num_shards=256, 85 | num_workers=8 86 | ) 87 | main( 88 | source="/mnt/cephfs/dataset/imagenet/val", 89 | dest="/mnt/cephfs/home/chenyaofo/datasets/imagenet-wds/val/imagenet-1k-val-%06d.tar", 90 | num_shards=256, 91 | num_workers=8 92 | ) 93 | --------------------------------------------------------------------------------