├── requirements.txt ├── .gitignore ├── folder2lmdb.py ├── README.md └── main.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | tensorpack 5 | pyarrow 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # 
mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | tmp/ 105 | *.lmdb 106 | *-lock 107 | -------------------------------------------------------------------------------- /folder2lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import six 4 | import string 5 | import argparse 6 | 7 | import lmdb 8 | import pickle 9 | import msgpack 10 | import tqdm 11 | from PIL import Image 12 | 13 | import torch 14 | import torch.utils.data as data 15 | from torch.utils.data import DataLoader 16 | from torchvision.transforms import transforms 17 | from torchvision.datasets import ImageFolder 18 | from torchvision import transforms, datasets 19 | # This segfaults when imported before torch: https://github.com/apache/arrow/issues/2637 20 | import pyarrow as pa 21 | 22 | 23 | class ImageFolderLMDB(data.Dataset): 24 | def __init__(self, db_path, transform=None, target_transform=None): 25 | self.db_path = db_path 26 | self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path), 27 | readonly=True, lock=False, 28 | readahead=False, meminit=False) 29 | with self.env.begin(write=False) as txn: 30 | # self.length = txn.stat()['entries'] - 1 31 | self.length = pa.deserialize(txn.get(b'__len__')) 32 | self.keys = pa.deserialize(txn.get(b'__keys__')) 33 | 34 | self.transform = transform 35 | self.target_transform = target_transform 36 | 37 | def __getitem__(self, index): 38 | img, target = None, None 39 | env = self.env 40 | with env.begin(write=False) as txn: 41 | byteflow = txn.get(self.keys[index]) 42 | unpacked = pa.deserialize(byteflow) 43 | 44 | # load image 45 | imgbuf = unpacked[0] 46 | buf = six.BytesIO() 47 | buf.write(imgbuf) 48 | buf.seek(0) 49 | img = Image.open(buf).convert('RGB') 50 | 51 | # load label 52 | target = unpacked[1] 53 | 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | 57 | if self.target_transform is not None: 58 | target = 
self.target_transform(target) 59 | 60 | return img, target 61 | 62 | def __len__(self): 63 | return self.length 64 | 65 | def __repr__(self): 66 | return self.__class__.__name__ + ' (' + self.db_path + ')' 67 | 68 | 69 | def raw_reader(path): 70 | with open(path, 'rb') as f: 71 | bin_data = f.read() 72 | return bin_data 73 | 74 | 75 | def dumps_pyarrow(obj): 76 | """ 77 | Serialize an object. 78 | 79 | Returns: 80 | Implementation-dependent bytes-like object 81 | """ 82 | return pa.serialize(obj).to_buffer() 83 | 84 | 85 | def folder2lmdb(path, outpath, write_frequency=5000): 86 | directory = os.path.expanduser(path) 87 | print("Loading dataset from %s" % directory) 88 | dataset = ImageFolder(directory, loader=raw_reader) 89 | data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) 90 | 91 | lmdb_path = os.path.expanduser(outpath) 92 | isdir = os.path.isdir(lmdb_path) 93 | 94 | print("Generate LMDB to %s" % lmdb_path) 95 | db = lmdb.open(lmdb_path, subdir=isdir, 96 | map_size=1099511627776 * 2, readonly=False, 97 | meminit=False, map_async=True) 98 | 99 | txn = db.begin(write=True) 100 | for idx, data in enumerate(data_loader): 101 | image, label = data[0] 102 | txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label))) 103 | if idx % write_frequency == 0: 104 | print("[%d/%d]" % (idx, len(data_loader))) 105 | txn.commit() 106 | txn = db.begin(write=True) 107 | 108 | # finish iterating through dataset 109 | txn.commit() 110 | keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] 111 | with db.begin(write=True) as txn: 112 | txn.put(b'__keys__', dumps_pyarrow(keys)) 113 | txn.put(b'__len__', dumps_pyarrow(len(keys))) 114 | 115 | print("Flushing database ...") 116 | db.sync() 117 | db.close() 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("-d", "--dataset", help="Path to original image dataset folder") 123 | parser.add_argument("-o", "--outpath", help="Path to 
output LMDB file") 124 | args = parser.parse_args() 125 | folder2lmdb(args.dataset, args.outpath) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-LMDB 2 | 3 | Scripts to work with LMDB + PyTorch for Imagenet training 4 | 5 | > **NOTE**: This has only been tested in the [NGC PyTorch 19.11-py3 container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) 6 | > 7 | > Other environments have not been tested 8 | 9 | Much of this code and LMDB documentation was adopted from https://github.com/Lyken17/Efficient-PyTorch, so credits to @Lyken17. 10 | 11 | ## Quickstart 12 | 13 | 1. Start interactive PyTorch container 14 | 15 | ```bash 16 | nvidia-docker run -it -v ${PWD}:/mnt -v /imagenet:/imagenet --workdir=/mnt nvcr.io/nvidia/pytorch:19.11-py3 17 | ``` 18 | 19 | 2. Convert data to LMDB format 20 | 21 | ```bash 22 | mkdir -p train-lmdb/ 23 | python folder2lmdb.py --dataset /imagenet/train -o train-lmdb/ 24 | 25 | mkdir -p val-lmdb/ 26 | python folder2lmdb.py --dataset /imagenet/val -o val-lmdb/ 27 | ``` 28 | 29 | 3. Run training on LMDB data 30 | 31 | ```bash 32 | time python main.py --arch resnet50 --train train-lmdb/ --val val-lmdb/ --lmdb --epochs 2 33 | ``` 34 | 35 | 4. (Optional) Compare to JPEG data 36 | 37 | ```bash 38 | time python main.py --arch resnet50 --train /imagenet/train --val /imagenet/val --epochs 2 39 | ``` 40 | 41 | ## Multi-processing Distributed Data Parallel Training 42 | 43 | You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance. 

### Single node, multiple GPUs:

JPEG
```bash
python main.py -a resnet50 --dist-url 'tcp://127.0.0.1:9999' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 --train /imagenet/train --val /imagenet/val
```

LMDB
* NOTE: Since LMDB can't be pickled, you need to hack the `folder2lmdb.ImageFolderLMDB` to delay the loading of the environment, such as below:

```python
class ImageFolderLMDB(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        # https://github.com/chainer/chainermn/issues/129
        # Delay loading LMDB data until after initialization to avoid "can't pickle Environment Object error"
        self.env = None

        # Workaround to have length from the start for ImageNet since we don't have LMDB at initialization time
        if 'train' in self.db_path:
            self.length = 1281167
        elif 'val' in self.db_path:
            self.length = 50000
        else:
            raise NotImplementedError
        ...

    def _init_db(self):
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            # self.length = txn.stat()['entries'] - 1
            self.length = pa.deserialize(txn.get(b'__len__'))
            self.keys = pa.deserialize(txn.get(b'__keys__'))

    def __getitem__(self, index):
        # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
        if self.env is None:
            self._init_db()
        ...
```

Now we can launch the LMDB version with `torch.multiprocessing` using the above workaround:
```bash
python main.py -a resnet50 --dist-url 'tcp://127.0.0.1:9999' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 --train /imagenet/train-lmdb --val /imagenet/val-lmdb --lmdb
```

## LMDB

LMDB is a key-value store: JSON-like in structure, but stored as a binary stream. In this design, the format of the converted LMDB is defined as follows.

key | value
--- | ---
img-id1 | (jpeg_raw1, label1)
img-id2 | (jpeg_raw2, label2)
img-id3 | (jpeg_raw3, label3)
... | ...
img-idn | (jpeg_rawn, labeln)
`__keys__` | [img-id1, img-id2, ... img-idn]
`__len__` | n

As for details of reading/writing, please refer to [code](folder2lmdb.py).

### LMDB Dataset / DataLoader

`folder2lmdb.py` has an implementation of a PyTorch `ImageFolder` for LMDB data to be passed into the `torch.utils.data.DataLoader`.

In `main.py`, passing the `--lmdb` flag specifies to use `folder2lmdb.ImageFolderLMDB` instead of the default
`torchvision.datasets.ImageFolder` when setting up the data.
114 | 115 | ```python 116 | # Data loading code 117 | if args.lmdb: 118 | import folder2lmdb 119 | image_folder = folder2lmdb.ImageFolderLMDB 120 | else: 121 | image_folder = datasets.ImageFolder 122 | 123 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 124 | std=[0.229, 0.224, 0.225]) 125 | 126 | train_dataset = image_folder( 127 | args.train, 128 | transforms.Compose([ 129 | transforms.RandomResizedCrop(224), 130 | transforms.RandomHorizontalFlip(), 131 | transforms.ToTensor(), 132 | normalize, 133 | ])) 134 | 135 | val_dataset = image_folder( 136 | args.val, 137 | transforms.Compose([ 138 | transforms.Resize(256), 139 | transforms.CenterCrop(224), 140 | transforms.ToTensor(), 141 | normalize, 142 | ])) 143 | ``` 144 | 145 | `ImageFolderLMDB` can be simply used in place of `ImageFolder` like below: 146 | 147 | ```python 148 | from folder2lmdb import ImageFolderLMDB 149 | from torch.utils.data import DataLoader 150 | dataset = ImageFolderLMDB(path, transform, target_transform) 151 | loader = DataLoader(dataset, batch_size=64) 152 | ``` 153 | 154 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | parser = 
argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('--train', metavar='PATH', required=True, help='path to training dataset') 27 | parser.add_argument('--val', metavar='PATH', required=True, help='path to validation dataset') 28 | parser.add_argument('--lmdb', action='store_true', 29 | help='Using LMDB format instead of raw images') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 31 | choices=model_names, 32 | help='model architecture: ' + 33 | ' | '.join(model_names) + 34 | ' (default: resnet18)') 35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 36 | help='number of data loading workers (default: 4)') 37 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=256, type=int, 42 | metavar='N', 43 | help='mini-batch size (default: 256), this is the total ' 44 | 'batch size of all GPUs on the current node when ' 45 | 'using Data Parallel or Distributed Data Parallel') 46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 47 | metavar='LR', help='initial learning rate', dest='lr') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 51 | metavar='W', help='weight decay (default: 1e-4)', 52 | dest='weight_decay') 53 | parser.add_argument('-p', '--print-freq', default=10, type=int, 54 | metavar='N', help='print frequency (default: 10)') 55 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 56 | help='path to latest checkpoint (default: none)') 57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 58 | help='evaluate model on validation set') 59 | 
parser.add_argument('--pretrained', dest='pretrained', action='store_true', 60 | help='use pre-trained model') 61 | parser.add_argument('--world-size', default=-1, type=int, 62 | help='number of nodes for distributed training') 63 | parser.add_argument('--rank', default=-1, type=int, 64 | help='node rank for distributed training') 65 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 66 | help='url used to set up distributed training') 67 | parser.add_argument('--dist-backend', default='nccl', type=str, 68 | help='distributed backend') 69 | parser.add_argument('--seed', default=None, type=int, 70 | help='seed for initializing training. ') 71 | parser.add_argument('--gpu', default=None, type=int, 72 | help='GPU id to use.') 73 | parser.add_argument('--multiprocessing-distributed', action='store_true', 74 | help='Use multi-processing distributed training to launch ' 75 | 'N processes per node, which has N GPUs. This is the ' 76 | 'fastest way to use PyTorch for either single node or ' 77 | 'multi node data parallel training') 78 | 79 | best_acc1 = 0 80 | 81 | 82 | def main(): 83 | args = parser.parse_args() 84 | 85 | if args.seed is not None: 86 | random.seed(args.seed) 87 | torch.manual_seed(args.seed) 88 | cudnn.deterministic = True 89 | warnings.warn('You have chosen to seed training. ' 90 | 'This will turn on the CUDNN deterministic setting, ' 91 | 'which can slow down your training considerably! ' 92 | 'You may see unexpected behavior when restarting ' 93 | 'from checkpoints.') 94 | 95 | if args.gpu is not None: 96 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 97 | 'disable data parallelism.') 98 | 99 | if args.dist_url == "env://" and args.world_size == -1: 100 | args.world_size = int(os.environ["WORLD_SIZE"]) 101 | 102 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 103 | 104 | ngpus_per_node = torch.cuda.device_count() 105 | if args.multiprocessing_distributed: 106 | # Since we have ngpus_per_node processes per node, the total world_size 107 | # needs to be adjusted accordingly 108 | args.world_size = ngpus_per_node * args.world_size 109 | # Use torch.multiprocessing.spawn to launch distributed processes: the 110 | # main_worker process function 111 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 112 | else: 113 | # Simply call main_worker function 114 | main_worker(args.gpu, ngpus_per_node, args) 115 | 116 | 117 | def main_worker(gpu, ngpus_per_node, args): 118 | global best_acc1 119 | args.gpu = gpu 120 | 121 | if args.gpu is not None: 122 | print("Use GPU: {} for training".format(args.gpu)) 123 | 124 | if args.distributed: 125 | if args.dist_url == "env://" and args.rank == -1: 126 | args.rank = int(os.environ["RANK"]) 127 | if args.multiprocessing_distributed: 128 | # For multiprocessing distributed training, rank needs to be the 129 | # global rank among all the processes 130 | args.rank = args.rank * ngpus_per_node + gpu 131 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 132 | world_size=args.world_size, rank=args.rank) 133 | # create model 134 | if args.pretrained: 135 | print("=> using pre-trained model '{}'".format(args.arch)) 136 | model = models.__dict__[args.arch](pretrained=True) 137 | else: 138 | print("=> creating model '{}'".format(args.arch)) 139 | model = models.__dict__[args.arch]() 140 | 141 | if args.distributed: 142 | # For multiprocessing distributed, DistributedDataParallel constructor 143 | # should always set the single device scope, otherwise, 144 | # DistributedDataParallel will 
use all available devices. 145 | if args.gpu is not None: 146 | torch.cuda.set_device(args.gpu) 147 | model.cuda(args.gpu) 148 | # When using a single GPU per process and per 149 | # DistributedDataParallel, we need to divide the batch size 150 | # ourselves based on the total number of GPUs we have 151 | args.batch_size = int(args.batch_size / ngpus_per_node) 152 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 153 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 154 | else: 155 | model.cuda() 156 | # DistributedDataParallel will divide and allocate batch_size to all 157 | # available GPUs if device_ids are not set 158 | model = torch.nn.parallel.DistributedDataParallel(model) 159 | elif args.gpu is not None: 160 | torch.cuda.set_device(args.gpu) 161 | model = model.cuda(args.gpu) 162 | else: 163 | # DataParallel will divide and allocate batch_size to all available GPUs 164 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 165 | model.features = torch.nn.DataParallel(model.features) 166 | model.cuda() 167 | else: 168 | model = torch.nn.DataParallel(model).cuda() 169 | 170 | # define loss function (criterion) and optimizer 171 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 172 | 173 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 174 | momentum=args.momentum, 175 | weight_decay=args.weight_decay) 176 | 177 | # optionally resume from a checkpoint 178 | if args.resume: 179 | if os.path.isfile(args.resume): 180 | print("=> loading checkpoint '{}'".format(args.resume)) 181 | if args.gpu is None: 182 | checkpoint = torch.load(args.resume) 183 | else: 184 | # Map model to be loaded to specified single gpu. 
185 | loc = 'cuda:{}'.format(args.gpu) 186 | checkpoint = torch.load(args.resume, map_location=loc) 187 | args.start_epoch = checkpoint['epoch'] 188 | best_acc1 = checkpoint['best_acc1'] 189 | if args.gpu is not None: 190 | # best_acc1 may be from a checkpoint from a different GPU 191 | best_acc1 = best_acc1.to(args.gpu) 192 | model.load_state_dict(checkpoint['state_dict']) 193 | optimizer.load_state_dict(checkpoint['optimizer']) 194 | print("=> loaded checkpoint '{}' (epoch {})" 195 | .format(args.resume, checkpoint['epoch'])) 196 | else: 197 | print("=> no checkpoint found at '{}'".format(args.resume)) 198 | 199 | cudnn.benchmark = True 200 | 201 | # Data loading code 202 | if args.lmdb: 203 | import folder2lmdb 204 | image_folder = folder2lmdb.ImageFolderLMDB 205 | 206 | else: 207 | image_folder = datasets.ImageFolder 208 | 209 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 210 | std=[0.229, 0.224, 0.225]) 211 | 212 | train_dataset = image_folder( 213 | args.train, 214 | transforms.Compose([ 215 | transforms.RandomResizedCrop(224), 216 | transforms.RandomHorizontalFlip(), 217 | transforms.ToTensor(), 218 | normalize, 219 | ])) 220 | 221 | val_dataset = image_folder( 222 | args.val, 223 | transforms.Compose([ 224 | transforms.Resize(256), 225 | transforms.CenterCrop(224), 226 | transforms.ToTensor(), 227 | normalize, 228 | ])) 229 | 230 | if args.distributed: 231 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 232 | else: 233 | train_sampler = None 234 | 235 | train_loader = torch.utils.data.DataLoader( 236 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 237 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 238 | 239 | val_loader = torch.utils.data.DataLoader( 240 | val_dataset, batch_size=args.batch_size, shuffle=False, 241 | num_workers=args.workers, pin_memory=True) 242 | 243 | if args.evaluate: 244 | validate(val_loader, model, criterion, args) 245 | return 246 | 
247 | for epoch in range(args.start_epoch, args.epochs): 248 | if args.distributed: 249 | train_sampler.set_epoch(epoch) 250 | adjust_learning_rate(optimizer, epoch, args) 251 | 252 | # train for one epoch 253 | train(train_loader, model, criterion, optimizer, epoch, args) 254 | 255 | # evaluate on validation set 256 | acc1 = validate(val_loader, model, criterion, args) 257 | 258 | # remember best acc@1 and save checkpoint 259 | is_best = acc1 > best_acc1 260 | best_acc1 = max(acc1, best_acc1) 261 | 262 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 263 | and args.rank % ngpus_per_node == 0): 264 | save_checkpoint({ 265 | 'epoch': epoch + 1, 266 | 'arch': args.arch, 267 | 'state_dict': model.state_dict(), 268 | 'best_acc1': best_acc1, 269 | 'optimizer' : optimizer.state_dict(), 270 | }, is_best) 271 | 272 | 273 | def train(train_loader, model, criterion, optimizer, epoch, args): 274 | batch_time = AverageMeter('Time', ':6.3f') 275 | data_time = AverageMeter('Data', ':6.3f') 276 | losses = AverageMeter('Loss', ':.4e') 277 | top1 = AverageMeter('Acc@1', ':6.2f') 278 | top5 = AverageMeter('Acc@5', ':6.2f') 279 | progress = ProgressMeter( 280 | len(train_loader), 281 | [batch_time, data_time, losses, top1, top5], 282 | prefix="Epoch: [{}]".format(epoch)) 283 | 284 | # switch to train mode 285 | model.train() 286 | 287 | end = time.time() 288 | for i, (images, target) in enumerate(train_loader): 289 | # measure data loading time 290 | data_time.update(time.time() - end) 291 | 292 | if args.gpu is not None: 293 | images = images.cuda(args.gpu, non_blocking=True) 294 | target = target.cuda(args.gpu, non_blocking=True) 295 | 296 | # compute output 297 | output = model(images) 298 | loss = criterion(output, target) 299 | 300 | # measure accuracy and record loss 301 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 302 | losses.update(loss.item(), images.size(0)) 303 | top1.update(acc1[0], images.size(0)) 304 | top5.update(acc5[0], 
images.size(0)) 305 | 306 | # compute gradient and do SGD step 307 | optimizer.zero_grad() 308 | loss.backward() 309 | optimizer.step() 310 | 311 | # measure elapsed time 312 | batch_time.update(time.time() - end) 313 | end = time.time() 314 | 315 | if i % args.print_freq == 0: 316 | progress.display(i) 317 | 318 | 319 | def validate(val_loader, model, criterion, args): 320 | batch_time = AverageMeter('Time', ':6.3f') 321 | losses = AverageMeter('Loss', ':.4e') 322 | top1 = AverageMeter('Acc@1', ':6.2f') 323 | top5 = AverageMeter('Acc@5', ':6.2f') 324 | progress = ProgressMeter( 325 | len(val_loader), 326 | [batch_time, losses, top1, top5], 327 | prefix='Test: ') 328 | 329 | # switch to evaluate mode 330 | model.eval() 331 | 332 | with torch.no_grad(): 333 | end = time.time() 334 | for i, (images, target) in enumerate(val_loader): 335 | if args.gpu is not None: 336 | images = images.cuda(args.gpu, non_blocking=True) 337 | target = target.cuda(args.gpu, non_blocking=True) 338 | 339 | # compute output 340 | output = model(images) 341 | loss = criterion(output, target) 342 | 343 | # measure accuracy and record loss 344 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 345 | losses.update(loss.item(), images.size(0)) 346 | top1.update(acc1[0], images.size(0)) 347 | top5.update(acc5[0], images.size(0)) 348 | 349 | # measure elapsed time 350 | batch_time.update(time.time() - end) 351 | end = time.time() 352 | 353 | if i % args.print_freq == 0: 354 | progress.display(i) 355 | 356 | # TODO: this should also be done with the ProgressMeter 357 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 358 | .format(top1=top1, top5=top5)) 359 | 360 | return top1.avg 361 | 362 | 363 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 364 | torch.save(state, filename) 365 | if is_best: 366 | shutil.copyfile(filename, 'model_best.pth.tar') 367 | 368 | 369 | class AverageMeter(object): 370 | """Computes and stores the average and current value""" 371 | def 
__init__(self, name, fmt=':f'): 372 | self.name = name 373 | self.fmt = fmt 374 | self.reset() 375 | 376 | def reset(self): 377 | self.val = 0 378 | self.avg = 0 379 | self.sum = 0 380 | self.count = 0 381 | 382 | def update(self, val, n=1): 383 | self.val = val 384 | self.sum += val * n 385 | self.count += n 386 | self.avg = self.sum / self.count 387 | 388 | def __str__(self): 389 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 390 | return fmtstr.format(**self.__dict__) 391 | 392 | 393 | class ProgressMeter(object): 394 | def __init__(self, num_batches, meters, prefix=""): 395 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 396 | self.meters = meters 397 | self.prefix = prefix 398 | 399 | def display(self, batch): 400 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 401 | entries += [str(meter) for meter in self.meters] 402 | print('\t'.join(entries)) 403 | 404 | def _get_batch_fmtstr(self, num_batches): 405 | num_digits = len(str(num_batches // 1)) 406 | fmt = '{:' + str(num_digits) + 'd}' 407 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 408 | 409 | 410 | def adjust_learning_rate(optimizer, epoch, args): 411 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 412 | lr = args.lr * (0.1 ** (epoch // 30)) 413 | for param_group in optimizer.param_groups: 414 | param_group['lr'] = lr 415 | 416 | 417 | def accuracy(output, target, topk=(1,)): 418 | """Computes the accuracy over the k top predictions for the specified values of k""" 419 | with torch.no_grad(): 420 | maxk = max(topk) 421 | batch_size = target.size(0) 422 | 423 | _, pred = output.topk(maxk, 1, True, True) 424 | pred = pred.t() 425 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 426 | 427 | res = [] 428 | for k in topk: 429 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 430 | res.append(correct_k.mul_(100.0 / batch_size)) 431 | return res 432 | 433 | 434 | if __name__ == '__main__': 435 | main() 436 | 
--------------------------------------------------------------------------------