├── .gitignore
├── networks
│   ├── .gitignore
│   ├── model_list
│   │   ├── __init__.py
│   │   └── alexnet.py
│   └── main.py
├── datasets
│   ├── __init__.py
│   ├── folder.py
│   └── transforms.py
├── tools
│   ├── fix_key.sh
│   └── convert_key.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
networks/data*
--------------------------------------------------------------------------------
/networks/.gitignore:
--------------------------------------------------------------------------------
*.pth.tar
--------------------------------------------------------------------------------
/networks/model_list/__init__.py:
--------------------------------------------------------------------------------
from alexnet import alexnet
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .folder import ImageFolder

__all__ = ('ImageFolder',)
--------------------------------------------------------------------------------
/tools/fix_key.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# original lmdb paths
Original_Train_LMDB=/data/jiecaoyu/imagenet/lmdb/ilsvrc12_train_lmdb_badkey/
Original_Val_LMDB=/data/jiecaoyu/imagenet/lmdb/ilsvrc12_val_lmdb_badkey/

# target lmdb paths
Target_Train_LMDB=/data/jiecaoyu/imagenet/lmdb/ilsvrc12_train_lmdb/
Target_Val_LMDB=/data/jiecaoyu/imagenet/lmdb/ilsvrc12_val_lmdb/

python convert_key.py --source $Original_Train_LMDB --target $Target_Train_LMDB
python convert_key.py --source $Original_Val_LMDB --target $Target_Val_LMDB
--------------------------------------------------------------------------------
/tools/convert_key.py:
--------------------------------------------------------------------------------
import lmdb
import argparse
import subprocess

parser = argparse.ArgumentParser(description='Convert dataset key')
parser.add_argument('--target', default=None, action='store',
                    help='target directory')
parser.add_argument('--source', default=None, action='store',
                    help='source directory')
args = parser.parse_args()

print '\n===========================\nConverting keys from:'
print '\t'+args.source
print ' to:'
print '\t'+args.target+'\n'

# read source
source_env = lmdb.open(args.source, readonly=True)
source_txn = source_env.begin()
source_cursor = source_txn.cursor()

# remove any stale target directory before writing
subprocess.call('rm -r '+args.target, shell=True)
target_env = lmdb.open(args.target, map_size=int(1e12))
target_txn = target_env.begin(write=True)
batch_size = 1000

item_id = -1
for key, value in source_cursor:
    item_id += 1
    # Caffe keys typically look like '00000042_filename'; keep only the
    # 8-digit, zero-padded index so samples can later be fetched by position.
    key_str = key[0:8]
    target_txn.put(key_str, value)

    # write batch
    if (item_id + 1) % batch_size == 0:
        target_txn.commit()
        target_txn = target_env.begin(write=True)
        print (item_id + 1)

# write last batch
if (item_id + 1) % batch_size != 0:
    target_txn.commit()
    print 'last batch'
    print (item_id + 1)
target_env.close()
source_env.close()
--------------------------------------------------------------------------------
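Before pointing the data loader at the converted databases, it is worth sanity-checking the new keys. A minimal sketch (Python 2, using the validation path from `fix_key.sh`; after conversion the keys should be 8-digit, zero-padded indices):

```python
import lmdb

env = lmdb.open('/data/jiecaoyu/imagenet/lmdb/ilsvrc12_val_lmdb/', readonly=True)
with env.begin() as txn:
    for i, (key, value) in enumerate(txn.cursor()):
        assert key == '{:08}'.format(i)  # '00000000', '00000001', ...
        if i >= 2:
            break
env.close()
```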
/README.md:
--------------------------------------------------------------------------------
The implementation of AlexNet in [PyTorch Vision](https://github.com/pytorch/vision) is not the original version of the network. This repository therefore reimplements some of the networks for the author's own use.

# Prerequisites
- [Caffe](https://github.com/BVLC/caffe)
- [PyTorch](https://github.com/pytorch/pytorch)

# Data Preparation
The original data loader ([link](https://github.com/pytorch/vision#imagenet-12)) is slow, so this repository provides a new data loader built on Caffe utilities.
### Generate LMDB
The preprocessed datasets can be found [here](https://drive.google.com/uc?export=download&id=0B-7I62GOSnZ8aENhOEtESVFHa2M). Please download and uncompress them into the ```./networks/data/``` directory.
To generate the datasets from raw images instead, follow the Caffe [instructions](http://caffe.berkeleyvision.org/gathered/examples/imagenet.html) for building the ImageNet LMDB. However, the ```key``` format in the resulting LMDB is not suitable for indexed access, so please run the script ```./tools/fix_key.sh``` to convert the keys.
### Load LMDB
Please point the data loader at the directory containing the training and validation LMDB datasets: the variable ```lmdb_dir``` in ```./datasets/folder.py``` is derived from the ```--data``` argument of ```./networks/main.py```, and the directory must contain ```ilsvrc12_train_lmdb/``` and ```ilsvrc12_val_lmdb/```.
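For reference, the loader can also be driven by hand. A rough sketch (run from the ```./networks/``` directory, mirroring the paths and validation transforms used in ```main.py```):

```python
import sys
sys.path.append('..')  # main.py does the same to reach the datasets package

import torch.utils.data
import datasets
import datasets.transforms as transforms

normalize = transforms.Normalize(meanfile='./data/imagenet_mean.binaryproto')
val_data = datasets.ImageFolder(
    './data/',
    transforms.Compose([transforms.ToTensor(), normalize, transforms.CenterCrop(227)]),
    Train=False)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=256, shuffle=False)
img, target = val_data[0]  # a (3, 227, 227) tensor and its class index
```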
# AlexNet
The implementation is in ```./networks/model_list/alexnet.py```. Since the version of PyTorch used here does not provide a local response normalization (LRN) layer, one is implemented in this repository as well.
### Training from scratch
```bash
$ cd networks
$ python main.py --arch alexnet
```

### Evaluate the pretrained model
The pretrained model is available [here](https://drive.google.com/uc?export=download&id=0B-7I62GOSnZ8NzVxZndDU2dYcHM); it achieves an accuracy of 57.494% (Top-1) and 80.588% (Top-5). Please download it and put it under the ```./networks/model_list/``` directory. Run the evaluation with:
```bash
$ python main.py --arch alexnet --pretrained --evaluate
```
--------------------------------------------------------------------------------
/networks/model_list/alexnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


__all__ = ['AlexNet', 'alexnet']

class LRN(nn.Module):
    def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        if ACROSS_CHANNELS:
            # Average x^2 over a window of neighboring channels; AvgPool3d
            # already divides by the window size n, so the result matches
            # Caffe's LRN: x / (1 + (alpha/n) * sum(x_i^2))^beta.
            self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                        stride=1,
                                        padding=(int((local_size-1.0)/2), 0, 0))
        else:
            self.average = nn.AvgPool2d(kernel_size=local_size,
                                        stride=1,
                                        padding=int((local_size-1.0)/2))
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        if self.ACROSS_CHANNELS:
            div = x.pow(2).unsqueeze(1)
            div = self.average(div).squeeze(1)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        else:
            div = x.pow(2)
            div = self.average(div)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        x = x.div(div)
        return x

class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
            nn.ReLU(inplace=True),
            LRN(local_size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the "ImageNet Classification with
    Deep Convolutional Neural Networks" (Krizhevsky et al., 2012) paper,
    i.e. the original variant with grouped convolutions and LRN rather than
    the simplified version shipped with torchvision.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    if pretrained:
        # expects the downloaded checkpoint under ./networks/model_list/ (see README)
        model_path = 'model_list/alexnet.pth.tar'
        pretrained_model = torch.load(model_path)
        model.load_state_dict(pretrained_model['state_dict'])
    return model
--------------------------------------------------------------------------------
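The ```LRN``` module above implements the AlexNet-style normalization x / (1 + (alpha/n) * sum(x_i^2))^beta over a window of n neighboring channels. A quick self-check against the formula written out by hand for an interior channel (a sketch; on very old PyTorch versions the tensors may need to be wrapped in ```Variable```):

```python
import torch
from alexnet import LRN  # assumes this file is on the import path

x = torch.randn(1, 8, 4, 4)
lrn = LRN(local_size=5, alpha=0.0001, beta=0.75)
y = lrn(x)

# Channel 4's window (channels 2..6) stays clear of the zero padding
# that AvgPool3d adds at the channel borders.
window_mean = x[0, 2:7].pow(2).mean(0)
manual = x[0, 4] / (1.0 + 0.0001 * window_mean).pow(0.75)
print(torch.allclose(y[0, 4], manual))  # expected: True
```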
/datasets/folder.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path
import lmdb
import caffe

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A data loader that reads ImageNet samples from Caffe LMDB databases
    rather than from a directory tree of images.

    Args:
        data_path (string): Directory containing the ``ilsvrc12_train_lmdb/``
            and ``ilsvrc12_val_lmdb/`` databases.
        transform (callable, optional): A function/transform that takes in a
            (C, H, W) image array and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        Train (bool): If True, read the training LMDB; otherwise the
            validation LMDB.
    """

    def __init__(self, data_path=None, transform=None, target_transform=None,
                 loader=default_loader, Train=True):
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.Train = Train
        self.lmdb_dir = data_path
        if self.Train:
            self.lmdb_dir = self.lmdb_dir+'/ilsvrc12_train_lmdb/'
        else:
            self.lmdb_dir = self.lmdb_dir+'/ilsvrc12_val_lmdb/'
        self.lmdb_env = lmdb.open(self.lmdb_dir, readonly=True)
        self.lmdb_txn = self.lmdb_env.begin()

        self.length = self.lmdb_env.stat()['entries']

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        datum = caffe.proto.caffe_pb2.Datum()
        lmdb_cursor = self.lmdb_txn.cursor()
        # 8-digit, zero-padded keys as produced by tools/convert_key.py
        key_index = '{:08}'.format(index)
        value = lmdb_cursor.get(key_index)
        datum.ParseFromString(value)
        data = caffe.io.datum_to_array(datum)
        if self.transform is not None:
            img = self.transform(data)
        else:
            img = data
        target = datum.label

        return img, target

    def __len__(self):
        return self.length
--------------------------------------------------------------------------------
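For reference, the raw access path that ```__getitem__``` wraps looks roughly like this (a sketch with a hypothetical database location; ```datum_to_array``` returns a (C, H, W) uint8 array and ```datum.label``` holds the class index):

```python
import caffe
import lmdb

env = lmdb.open('./data/ilsvrc12_val_lmdb/', readonly=True)  # hypothetical path
txn = env.begin()
value = txn.cursor().get('{:08}'.format(0))  # fetch sample 0 by its converted key

datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(value)
arr = caffe.io.datum_to_array(datum)
print(arr.shape, datum.label)
env.close()
```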
/datasets/transforms.py:
--------------------------------------------------------------------------------
from __future__ import division
import caffe
import torch
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types
import collections


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``numpy.ndarray`` to tensor."""

    def __call__(self, pic):
        """
        Args:
            pic (numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic)
            # backward compatibility
            return img.float()

class Normalize(object):
    """Normalize a tensor image by subtracting the mean.

    Given mean (R, G, B) values, or a Caffe mean image loaded from a
    binaryproto file, subtracts the mean from each channel of the
    torch.*Tensor, i.e. ``channel = channel - mean``.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        meanfile (string): Path to a Caffe mean binaryproto file; used when
            ``mean`` is not given.
    """

    def __init__(self, mean=None, meanfile=None):
        if mean:
            self.mean = mean
        else:
            with open(meanfile, 'rb') as f:
                data = f.read()
            blob = caffe.proto.caffe_pb2.BlobProto()
            blob.ParseFromString(data)
            arr = np.array(caffe.io.blobproto_to_array(blob))
            self.mean = torch.from_numpy(arr[0].astype('float32'))

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        # TODO: make efficient
        for t, m in zip(tensor, self.mean):
            t.sub_(m)
        return tensor
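
# Illustration (assuming the standard 3 x 256 x 256 Caffe mean image):
#   normalize = Normalize(meanfile='./data/imagenet_mean.binaryproto')
#   normalize(torch.zeros(3, 256, 256))  # returns minus the mean image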

class Scale(object):
    """Check the spatial size of the input image array.

    Images are stored pre-scaled in the LMDB, so no resampling is performed
    here; the transform only asserts that the input already has the
    expected size.

    Args:
        size (int): Expected spatial size of the (C, H, W) input.
        interpolation (int, optional): Unused; kept for API compatibility
            with ``torchvision.transforms.Scale``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray or Tensor): (C, H, W) image to be checked.

        Returns:
            The input, unchanged.
        """
        assert(img.shape[1] == self.size)
        assert(img.shape[2] == self.size)
        return img

class CenterCrop(object):
    """Crops the given (C, H, W) image array at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray or Tensor): (C, H, W) image to be cropped.

        Returns:
            The cropped image.
        """
        w, h = (img.shape[1], img.shape[2])
        th, tw = self.size
        w_off = int((w - tw) / 2.)
        h_off = int((h - th) / 2.)
        img = img[:, h_off:h_off+th, w_off:w_off+tw]
        return img


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or sequence): Padding on each border. If a sequence of
            length 4, it is used to pad left, top, right and bottom borders respectively.
        fill: Pixel fill value. Default is 0.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e., no padding. If a sequence of
            length 4 is provided, it is used to pad left, top, right, bottom
            borders respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            return img

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontalFlip(object):
    """Horizontally flip the given (C, H, W) image array randomly with a
    probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): Image to be flipped.

        Returns:
            numpy.ndarray: Randomly flipped image.
        """
        if random.random() < 0.5:
            img = np.flip(img, axis=2).copy()  # flip along the width axis
        return img


class RandomSizedCrop(object):
    """Crop a random (size x size) patch out of the given (C, H, W) array.

    Despite the torchvision-style name, no random scaling or aspect-ratio
    jitter is performed: the crop location is random but the crop size is
    fixed.

    Args:
        size: side length of the square crop
        interpolation: Unused; kept for API compatibility.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        h_off = random.randint(0, img.shape[1]-self.size)
        w_off = random.randint(0, img.shape[2]-self.size)
        img = img[:, h_off:h_off+self.size, w_off:w_off+self.size]
        return img
--------------------------------------------------------------------------------
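Taken together, these transforms operate on the raw (C, H, W) arrays decoded from the LMDB rather than on PIL images, and normalization runs before cropping because the Caffe mean image matches the full stored resolution. A minimal sketch of the training-side pipeline (mirroring ```./networks/main.py```, with the mean file under the default ```--data``` location; a recent PyTorch is assumed):

```python
import numpy as np
from datasets.transforms import (Compose, RandomHorizontalFlip, ToTensor,
                                 Normalize, RandomSizedCrop)

pipeline = Compose([
    RandomHorizontalFlip(),                  # on the numpy (C, H, W) array
    ToTensor(),                              # numpy -> FloatTensor
    Normalize(meanfile='./data/imagenet_mean.binaryproto'),
    RandomSizedCrop(227),                    # crop to the AlexNet input size
])

fake = np.random.randint(0, 256, size=(3, 256, 256)).astype('uint8')
out = pipeline(fake)  # FloatTensor of size (3, 227, 227)
```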
/networks/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
# the stock torchvision datasets/transforms are replaced by the LMDB-based
# versions in ../datasets:
# import torchvision.transforms as transforms
# import torchvision.datasets as datasets
import model_list

# set the seed
torch.manual_seed(1)
torch.cuda.manual_seed(1)

import sys
import gc
cwd = os.getcwd()
sys.path.append(cwd+'/../')
import datasets as datasets
import datasets.transforms as transforms

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--arch', '-a', metavar='ARCH', default='alexnet',
                    help='model architecture (default: alexnet)')
parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
                    help='path to imagenet data (default: ./data/)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=160, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.90, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    default=False, help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')

best_prec1 = 0
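
# Typical invocations (from the ./networks/ directory):
#   python main.py --arch alexnet                              # train from scratch
#   python main.py --arch alexnet --resume checkpoint.pth.tar  # resume training
#   python main.py --arch alexnet --pretrained --evaluate      # evaluate the pretrained model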


def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.arch == 'alexnet':
        model = model_list.alexnet(pretrained=args.pretrained)
        input_size = 227
    else:
        raise Exception('Model not supported yet')

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    if not os.path.exists(args.data+'/imagenet_mean.binaryproto'):
        print("==> Data directory "+args.data+" does not exist")
        print("==> Please specify the correct data path by")
        print("==> --data <DATA_PATH>")
        return

    normalize = transforms.Normalize(
        meanfile=args.data+'/imagenet_mean.binaryproto')

    # Normalization runs before the crop because the Caffe mean image
    # matches the full stored resolution (see datasets/transforms.py).
    train_dataset = datasets.ImageFolder(
        args.data,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            transforms.RandomSizedCrop(input_size),
        ]),
        Train=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.data, transforms.Compose([
            transforms.ToTensor(),
            normalize,
            transforms.CenterCrop(input_size),
        ]),
        Train=False),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print model

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            gc.collect()
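

# The AverageMeter instances above keep running statistics per epoch, e.g.:
#   meter = AverageMeter()
#   meter.update(0.5, n=256); meter.update(0.7, n=256)
#   meter.val == 0.7 (last batch), meter.avg == 0.6 (sample-weighted mean)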


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 40 epochs"""
    # e.g. with the default --lr 0.01: 0.01 for epochs 0-39, 0.001 for 40-79, ...
    lr = args.lr * (0.1 ** (epoch // 40))
    print 'Learning rate:', lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------