├── .gitignore
├── README.md
├── gen_mean_std.py
├── main.py
├── models
│   ├── __init__.py
│   ├── densenet_cifar.py
│   ├── resnet_cifar.py
│   ├── resnext_cifar.py
│   └── wide_resnet_cifar.py
├── requirements.txt
└── run.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
data/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Experiments on CIFAR datasets with PyTorch

## Introduction
Reimplementations of state-of-the-art CNN models on the CIFAR datasets with PyTorch, currently including:

1. [ResNet](https://arxiv.org/abs/1512.03385v1)

2. [PreActResNet](https://arxiv.org/abs/1603.05027v3)

3. [WideResNet](https://arxiv.org/abs/1605.07146v4)

4. [ResNeXt](https://arxiv.org/abs/1611.05431v2)

5. [DenseNet](https://arxiv.org/abs/1608.06993v4)

Other results will be added later.

## Requirements: software
A working [PyTorch](http://pytorch.org/) installation (see ```requirements.txt```).

## Requirements: hardware
For most experiments, one or two K40 GPUs (~11 GB of memory each) are enough, since PyTorch is very memory efficient. However, to train DenseNet on CIFAR-10 or CIFAR-100 you need at least 4 K40 GPUs.

## Usage
1. Clone this repository

```
git clone https://github.com/junyuseu/pytorch-cifar-models.git
```

In this project, the network structures are defined in the ```models``` folder, and the script ```gen_mean_std.py``` is used to calculate the mean and standard deviation of the dataset.

2. Edit main.py and run.sh

In ```main.py```, specify the network you want to train, for example:

```
model = resnet20_cifar(num_classes=10)
...
fdir = 'result/resnet20_cifar10'
```

Then set the training parameters in ```run.sh```. For resnet20:

```
CUDA_VISIBLE_DEVICES=0 python main.py --epoch 160 --batch-size 128 --lr 0.1 --momentum 0.9 --wd 1e-4 -ct 10
```

3. Train

```
nohup sh run.sh > resnet20_cifar10.log &
```

After training, the training log is recorded in the .log file, and the best model (on the test set) is stored in ```fdir```.

**Note**: For the first training run, the CIFAR-10 or CIFAR-100 dataset will be downloaded, so make sure your computer is online. Otherwise, download the dataset yourself, decompress it, and put it in the ```data``` folder.

4. Test

```
CUDA_VISIBLE_DEVICES=0 python main.py -e --resume=fdir/model_best.pth.tar
```

5. CIFAR-100

The default setting in the code is for CIFAR-10; to train on CIFAR-100, you need to specify it explicitly in the code:

```
model = resnet20_cifar(num_classes=100)
```

**Note**: you should also change **fdir** in ```main.py```, and set ```-ct 100``` in ```run.sh```.

## Results
**Note**: each result below comes from a single run.

**We obtained results comparable to, or even better than, those of the original papers; the experimental settings follow the original ones.**

### ResNet

layers|#params|error(%)
:---:|:---:|:---:
20|0.27M|8.33
32|0.46M|7.36
44|0.66M|6.77
56|0.85M|6.73
110|1.7M|**6.13**
1202|19.4M|-

### PreActResNet

dataset|network|baseline unit|pre-activation unit
:---:|:---:|:---:|:---:
CIFAR-10|ResNet-110|6.13|6.13
CIFAR-10|ResNet-164|5.84|5.35
CIFAR-10|ResNet-1001|11.27|**5.13**
CIFAR-100|ResNet-164|24.99|24.50
CIFAR-100|ResNet-1001|31.73|**24.03**

### WideResNet

depth-k|#params|CIFAR-10|CIFAR-100
:---:|:---:|:---:|:---:
20-10|26.8M|4.27|19.73
26-10|36.5M|**3.89**|**19.51**

### ResNeXt

network|#params|CIFAR-10|CIFAR-100
:---:|:---:|:---:|:---:
ResNeXt-29,1x64d|4.9M|4.51|22.09
ResNeXt-29,8x64d|34.4M|3.78|17.44
ResNeXt-29,16x64d|68.1M|**3.69**|**17.11**

### DenseNet

network|depth|#params|CIFAR-10|CIFAR-100
:---:|:---:|:---:|:---:|:---:
DenseNet-BC(k=12)|100|0.8M|4.69|22.19
DenseNet-BC(k=24)|250|15.3M|3.44|**17.17**
DenseNet-BC(k=40)|190|25.6M|**3.41**|17.33

## References
[1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.

[2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.

[3] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.

[4] S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In CVPR, 2017.

[5] G. Huang, Z. Liu, L. van der Maaten, and K. Q. Weinberger. Densely connected convolutional networks. In CVPR, 2017.
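
The ```#params``` column can be double-checked with a quick parameter count. The snippet below is a minimal sketch (run it from the repository root so that the ```models``` package is importable); the expected value in the comment is taken from the ResNet table above.

```
from models import resnet20_cifar

# build the CIFAR-10 variant and count its trainable parameters
model = resnet20_cifar(num_classes=10)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('resnet20_cifar: %.2fM parameters' % (n_params / 1e6))  # ~0.27M, matching the table
```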

--------------------------------------------------------------------------------
/gen_mean_std.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

def gen_mean_std(dataset):
    # load the whole training set as a single batch and compute per-channel statistics
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=50000, shuffle=False, num_workers=2)
    train = next(iter(dataloader))[0]
    mean = np.mean(train.numpy(), axis=(0, 2, 3))
    std = np.std(train.numpy(), axis=(0, 2, 3))
    return mean, std

if __name__=='__main__':
    # cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    cifar100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    mean, std = gen_mean_std(cifar100)
    print(mean, std)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

from models import *


parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 128), only used for training')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('-ct', '--cifar-type', default='10', type=int, metavar='CT', help='10 for cifar10, 100 for cifar100 (default: 10)')

best_prec = 0

def main():
    global args, best_prec
    args = parser.parse_args()
    use_gpu = torch.cuda.is_available()

    # Model building
    print('=> Building model...')
    if use_gpu:
        # the model can be set to any of the architectures defined in the models folder;
        # note that the chosen model must match the cifar type (number of classes)!
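        # (illustrative pairing, not an exhaustive list) e.g. resnet20_cifar() is trained
        # with `-ct 10` in run.sh, while resnet164_cifar(num_classes=100) goes with `-ct 100`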
43 | 44 | model = resnet20_cifar() 45 | # model = resnet32_cifar() 46 | # model = resnet44_cifar() 47 | # model = resnet110_cifar() 48 | # model = preact_resnet110_cifar() 49 | # model = resnet164_cifar(num_classes=100) 50 | # model = resnet1001_cifar(num_classes=100) 51 | # model = preact_resnet164_cifar(num_classes=100) 52 | # model = preact_resnet1001_cifar(num_classes=100) 53 | 54 | # model = wide_resnet_cifar(depth=26, width=10, num_classes=100) 55 | 56 | # model = resneXt_cifar(depth=29, cardinality=16, baseWidth=64, num_classes=100) 57 | 58 | #model = densenet_BC_cifar(depth=190, k=40, num_classes=100) 59 | 60 | # mkdir a new folder to store the checkpoint and best model 61 | if not os.path.exists('result'): 62 | os.makedirs('result') 63 | fdir = 'result/resnet20_cifar10' 64 | if not os.path.exists(fdir): 65 | os.makedirs(fdir) 66 | 67 | # adjust the lr according to the model type 68 | if isinstance(model, (ResNet_Cifar, PreAct_ResNet_Cifar)): 69 | model_type = 1 70 | elif isinstance(model, Wide_ResNet_Cifar): 71 | model_type = 2 72 | elif isinstance(model, (ResNeXt_Cifar, DenseNet_Cifar)): 73 | model_type = 3 74 | else: 75 | print('model type unrecognized...') 76 | return 77 | 78 | model = nn.DataParallel(model).cuda() 79 | criterion = nn.CrossEntropyLoss().cuda() 80 | optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 81 | cudnn.benchmark = True 82 | else: 83 | print('Cuda is not available!') 84 | return 85 | 86 | if args.resume: 87 | if os.path.isfile(args.resume): 88 | print('=> loading checkpoint "{}"'.format(args.resume)) 89 | checkpoint = torch.load(args.resume) 90 | args.start_epoch = checkpoint['epoch'] 91 | best_prec = checkpoint['best_prec'] 92 | model.load_state_dict(checkpoint['state_dict']) 93 | optimizer.load_state_dict(checkpoint['optimizer']) 94 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 95 | else: 96 | print("=> no checkpoint found at '{}'".format(args.resume)) 97 | 98 | # Data loading and preprocessing 99 | # CIFAR10 100 | if args.cifar_type == 10: 101 | print('=> loading cifar10 data...') 102 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 103 | 104 | train_dataset = torchvision.datasets.CIFAR10( 105 | root='./data', 106 | train=True, 107 | download=True, 108 | transform=transforms.Compose([ 109 | transforms.RandomCrop(32, padding=4), 110 | transforms.RandomHorizontalFlip(), 111 | transforms.ToTensor(), 112 | normalize, 113 | ])) 114 | trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2) 115 | 116 | test_dataset = torchvision.datasets.CIFAR10( 117 | root='./data', 118 | train=False, 119 | download=True, 120 | transform=transforms.Compose([ 121 | transforms.ToTensor(), 122 | normalize, 123 | ])) 124 | testloader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2) 125 | # CIFAR100 126 | else: 127 | print('=> loading cifar100 data...') 128 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 129 | 130 | train_dataset = torchvision.datasets.CIFAR100( 131 | root='./data', 132 | train=True, 133 | download=True, 134 | transform=transforms.Compose([ 135 | transforms.RandomCrop(32, padding=4), 136 | transforms.RandomHorizontalFlip(), 137 | transforms.ToTensor(), 138 | normalize, 139 | ])) 140 | trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
num_workers=2) 141 | 142 | test_dataset = torchvision.datasets.CIFAR100( 143 | root='./data', 144 | train=False, 145 | download=True, 146 | transform=transforms.Compose([ 147 | transforms.ToTensor(), 148 | normalize, 149 | ])) 150 | testloader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2) 151 | 152 | if args.evaluate: 153 | validate(testloader, model, criterion) 154 | return 155 | 156 | for epoch in range(args.start_epoch, args.epochs): 157 | adjust_learning_rate(optimizer, epoch, model_type) 158 | 159 | # train for one epoch 160 | train(trainloader, model, criterion, optimizer, epoch) 161 | 162 | # evaluate on test set 163 | prec = validate(testloader, model, criterion) 164 | 165 | # remember best precision and save checkpoint 166 | is_best = prec > best_prec 167 | best_prec = max(prec,best_prec) 168 | save_checkpoint({ 169 | 'epoch': epoch + 1, 170 | 'state_dict': model.state_dict(), 171 | 'best_prec': best_prec, 172 | 'optimizer': optimizer.state_dict(), 173 | }, is_best, fdir) 174 | 175 | 176 | class AverageMeter(object): 177 | """Computes and stores the average and current value""" 178 | def __init__(self): 179 | self.reset() 180 | 181 | def reset(self): 182 | self.val = 0 183 | self.avg = 0 184 | self.sum = 0 185 | self.count = 0 186 | 187 | def update(self, val, n=1): 188 | self.val = val 189 | self.sum += val * n 190 | self.count += n 191 | self.avg = self.sum / self.count 192 | 193 | 194 | def train(trainloader, model, criterion, optimizer, epoch): 195 | batch_time = AverageMeter() 196 | data_time = AverageMeter() 197 | losses = AverageMeter() 198 | top1 = AverageMeter() 199 | 200 | model.train() 201 | 202 | end = time.time() 203 | for i, (input, target) in enumerate(trainloader): 204 | # measure data loading time 205 | data_time.update(time.time() - end) 206 | 207 | input, target = input.cuda(), target.cuda() 208 | 209 | # compute output 210 | output = model(input) 211 | loss = criterion(output, target) 212 | 213 | # measure accuracy and record loss 214 | prec = accuracy(output, target)[0] 215 | losses.update(loss.item(), input.size(0)) 216 | top1.update(prec.item(), input.size(0)) 217 | 218 | # compute gradient and do SGD step 219 | optimizer.zero_grad() 220 | loss.backward() 221 | optimizer.step() 222 | 223 | # measure elapsed time 224 | batch_time.update(time.time() - end) 225 | end = time.time() 226 | 227 | if i % args.print_freq == 0: 228 | print('Epoch: [{0}][{1}/{2}]\t' 229 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 230 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 231 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 232 | 'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format( 233 | epoch, i, len(trainloader), batch_time=batch_time, 234 | data_time=data_time, loss=losses, top1=top1)) 235 | 236 | 237 | def validate(val_loader, model, criterion): 238 | batch_time = AverageMeter() 239 | losses = AverageMeter() 240 | top1 = AverageMeter() 241 | 242 | # switch to evaluate mode 243 | model.eval() 244 | 245 | end = time.time() 246 | with torch.no_grad(): 247 | for i, (input, target) in enumerate(val_loader): 248 | input, target = input.cuda(), target.cuda() 249 | 250 | # compute output 251 | output = model(input) 252 | loss = criterion(output, target) 253 | 254 | # measure accuracy and record loss 255 | prec = accuracy(output, target)[0] 256 | losses.update(loss.item(), input.size(0)) 257 | top1.update(prec.item(), input.size(0)) 258 | 259 | # measure elapsed time 260 | batch_time.update(time.time() - end) 261 | end = time.time() 
262 | 263 | if i % args.print_freq == 0: 264 | print('Test: [{0}/{1}]\t' 265 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 266 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 267 | 'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format( 268 | i, len(val_loader), batch_time=batch_time, loss=losses, 269 | top1=top1)) 270 | 271 | print(' * Prec {top1.avg:.3f}% '.format(top1=top1)) 272 | 273 | return top1.avg 274 | 275 | 276 | def save_checkpoint(state, is_best, fdir): 277 | filepath = os.path.join(fdir, 'checkpoint.pth') 278 | torch.save(state, filepath) 279 | if is_best: 280 | shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar')) 281 | 282 | 283 | def adjust_learning_rate(optimizer, epoch, model_type): 284 | """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs""" 285 | if model_type == 1: 286 | if epoch < 80: 287 | lr = args.lr 288 | elif epoch < 120: 289 | lr = args.lr * 0.1 290 | else: 291 | lr = args.lr * 0.01 292 | elif model_type == 2: 293 | if epoch < 60: 294 | lr = args.lr 295 | elif epoch < 120: 296 | lr = args.lr * 0.2 297 | elif epoch < 160: 298 | lr = args.lr * 0.04 299 | else: 300 | lr = args.lr * 0.008 301 | elif model_type == 3: 302 | if epoch < 150: 303 | lr = args.lr 304 | elif epoch < 225: 305 | lr = args.lr * 0.1 306 | else: 307 | lr = args.lr * 0.01 308 | for param_group in optimizer.param_groups: 309 | param_group['lr'] = lr 310 | 311 | 312 | def accuracy(output, target, topk=(1,)): 313 | """Computes the precision@k for the specified values of k""" 314 | maxk = max(topk) 315 | batch_size = target.size(0) 316 | 317 | _, pred = output.topk(maxk, 1, True, True) 318 | pred = pred.t() 319 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 320 | 321 | res = [] 322 | for k in topk: 323 | correct_k = correct[:k].view(-1).float().sum(0) 324 | res.append(correct_k.mul_(100.0 / batch_size)) 325 | return res 326 | 327 | 328 | if __name__=='__main__': 329 | main() 330 | 331 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.resnet_cifar import * 2 | from models.wide_resnet_cifar import * 3 | from models.resnext_cifar import * 4 | from models.densenet_cifar import * 5 | -------------------------------------------------------------------------------- /models/densenet_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | DenseNet for cifar with pytorch 3 | 4 | Reference: 5 | [1] H. Gao, Z. Liu, L. Maaten and K. Weinberger. Densely connected convolutional networks. 
In CVPR, 2017 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from collections import OrderedDict 12 | 13 | import math 14 | 15 | class _DenseLayer(nn.Sequential): 16 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 17 | super(_DenseLayer, self).__init__() 18 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 19 | self.add_module('relu1', nn.ReLU(inplace=True)), 20 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 21 | growth_rate, kernel_size=1, stride=1, bias=False)), 22 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 23 | self.add_module('relu2', nn.ReLU(inplace=True)), 24 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 25 | kernel_size=3, stride=1, padding=1, bias=False)), 26 | self.drop_rate = drop_rate 27 | 28 | def forward(self, x): 29 | new_features = super(_DenseLayer, self).forward(x) 30 | if self.drop_rate > 0: 31 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 32 | return torch.cat([x, new_features], 1) 33 | 34 | 35 | class _DenseBlock(nn.Sequential): 36 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 37 | super(_DenseBlock, self).__init__() 38 | for i in range(num_layers): 39 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 40 | self.add_module('denselayer%d' % (i + 1), layer) 41 | 42 | 43 | class _Transition(nn.Sequential): 44 | def __init__(self, num_input_features, num_output_features): 45 | super(_Transition, self).__init__() 46 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 47 | self.add_module('relu', nn.ReLU(inplace=True)) 48 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 49 | kernel_size=1, stride=1, bias=False)) 50 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 51 | 52 | 53 | class DenseNet_Cifar(nn.Module): 54 | r"""Densenet-BC model class, based on 55 | `"Densely Connected Convolutional Networks" `_ 56 | 57 | Args: 58 | growth_rate (int) - how many filters to add each layer (`k` in paper) 59 | block_config (list of 4 ints) - how many layers in each pooling block 60 | num_init_features (int) - the number of filters to learn in the first convolution layer 61 | bn_size (int) - multiplicative factor for number of bottle neck layers 62 | (i.e. 
bn_size * k features in the bottleneck layer) 63 | drop_rate (float) - dropout rate after each dense layer 64 | num_classes (int) - number of classification classes 65 | """ 66 | def __init__(self, growth_rate=12, block_config=(16, 16, 16), 67 | num_init_features=24, bn_size=4, drop_rate=0, num_classes=10): 68 | 69 | super(DenseNet_Cifar, self).__init__() 70 | 71 | # First convolution 72 | self.features = nn.Sequential(OrderedDict([ 73 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 74 | ])) 75 | 76 | # Each denseblock 77 | num_features = num_init_features 78 | for i, num_layers in enumerate(block_config): 79 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 80 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 81 | self.features.add_module('denseblock%d' % (i + 1), block) 82 | num_features = num_features + num_layers * growth_rate 83 | if i != len(block_config) - 1: 84 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 85 | self.features.add_module('transition%d' % (i + 1), trans) 86 | num_features = num_features // 2 87 | 88 | # Final batch norm 89 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 90 | 91 | # Linear layer 92 | self.classifier = nn.Linear(num_features, num_classes) 93 | 94 | # initialize conv and bn parameters 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. / n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | 103 | def forward(self, x): 104 | features = self.features(x) 105 | out = F.relu(features, inplace=True) 106 | out = F.avg_pool2d(out, kernel_size=8, stride=1).view(features.size(0), -1) 107 | out = self.classifier(out) 108 | return out 109 | 110 | 111 | def densenet_BC_cifar(depth, k, **kwargs): 112 | N = (depth - 4) // 6 113 | model = DenseNet_Cifar(growth_rate=k, block_config=[N, N, N], num_init_features=2*k, **kwargs) 114 | return model 115 | 116 | 117 | if __name__ == '__main__': 118 | net = densenet_BC_cifar(190, 40, num_classes=100) 119 | input = torch.randn(1, 3, 32, 32) 120 | y = net(input) 121 | print(net) 122 | print(y.size()) 123 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | resnet for cifar in pytorch 3 | 4 | Reference: 5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016. 6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016. 
7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | " 3x3 convolution with padding " 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion=1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion=4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes*4) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class PreActBasicBlock(nn.Module): 90 | expansion = 1 91 | 92 | def __init__(self, inplanes, planes, stride=1, downsample=None): 93 | super(PreActBasicBlock, self).__init__() 94 | self.bn1 = nn.BatchNorm2d(inplanes) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.conv1 = conv3x3(inplanes, planes, stride) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv2 = conv3x3(planes, planes) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | residual = x 104 | 105 | out = self.bn1(x) 106 | out = self.relu(out) 107 | 108 | if self.downsample is not None: 109 | residual = self.downsample(out) 110 | 111 | out = self.conv1(out) 112 | 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | out = self.conv2(out) 116 | 117 | out += residual 118 | 119 | return out 120 | 121 | 122 | class PreActBottleneck(nn.Module): 123 | expansion = 4 124 | 125 | def __init__(self, inplanes, planes, stride=1, downsample=None): 126 | super(PreActBottleneck, self).__init__() 127 | self.bn1 = nn.BatchNorm2d(inplanes) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 130 | self.bn2 = nn.BatchNorm2d(planes) 131 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 132 | self.bn3 = nn.BatchNorm2d(planes) 133 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 134 | 
self.downsample = downsample 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | residual = x 139 | 140 | out = self.bn1(x) 141 | out = self.relu(out) 142 | 143 | if self.downsample is not None: 144 | residual = self.downsample(out) 145 | 146 | out = self.conv1(out) 147 | 148 | out = self.bn2(out) 149 | out = self.relu(out) 150 | out = self.conv2(out) 151 | 152 | out = self.bn3(out) 153 | out = self.relu(out) 154 | out = self.conv3(out) 155 | 156 | out += residual 157 | 158 | return out 159 | 160 | 161 | class ResNet_Cifar(nn.Module): 162 | 163 | def __init__(self, block, layers, num_classes=10): 164 | super(ResNet_Cifar, self).__init__() 165 | self.inplanes = 16 166 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 167 | self.bn1 = nn.BatchNorm2d(16) 168 | self.relu = nn.ReLU(inplace=True) 169 | self.layer1 = self._make_layer(block, 16, layers[0]) 170 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 172 | self.avgpool = nn.AvgPool2d(8, stride=1) 173 | self.fc = nn.Linear(64 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | m.weight.data.normal_(0, math.sqrt(2. / n)) 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | 183 | def _make_layer(self, block, planes, blocks, stride=1): 184 | downsample = None 185 | if stride != 1 or self.inplanes != planes * block.expansion: 186 | downsample = nn.Sequential( 187 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 188 | nn.BatchNorm2d(planes * block.expansion) 189 | ) 190 | 191 | layers = [] 192 | layers.append(block(self.inplanes, planes, stride, downsample)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x = self.relu(x) 203 | 204 | x = self.layer1(x) 205 | x = self.layer2(x) 206 | x = self.layer3(x) 207 | 208 | x = self.avgpool(x) 209 | x = x.view(x.size(0), -1) 210 | x = self.fc(x) 211 | 212 | return x 213 | 214 | 215 | class PreAct_ResNet_Cifar(nn.Module): 216 | 217 | def __init__(self, block, layers, num_classes=10): 218 | super(PreAct_ResNet_Cifar, self).__init__() 219 | self.inplanes = 16 220 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 221 | self.layer1 = self._make_layer(block, 16, layers[0]) 222 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 223 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 224 | self.bn = nn.BatchNorm2d(64*block.expansion) 225 | self.relu = nn.ReLU(inplace=True) 226 | self.avgpool = nn.AvgPool2d(8, stride=1) 227 | self.fc = nn.Linear(64*block.expansion, num_classes) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 232 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 233 | elif isinstance(m, nn.BatchNorm2d): 234 | m.weight.data.fill_(1) 235 | m.bias.data.zero_() 236 | 237 | def _make_layer(self, block, planes, blocks, stride=1): 238 | downsample = None 239 | if stride != 1 or self.inplanes != planes*block.expansion: 240 | downsample = nn.Sequential( 241 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 242 | ) 243 | 244 | layers = [] 245 | layers.append(block(self.inplanes, planes, stride, downsample)) 246 | self.inplanes = planes*block.expansion 247 | for _ in range(1, blocks): 248 | layers.append(block(self.inplanes, planes)) 249 | return nn.Sequential(*layers) 250 | 251 | def forward(self, x): 252 | x = self.conv1(x) 253 | 254 | x = self.layer1(x) 255 | x = self.layer2(x) 256 | x = self.layer3(x) 257 | 258 | x = self.bn(x) 259 | x = self.relu(x) 260 | x = self.avgpool(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.fc(x) 263 | 264 | return x 265 | 266 | 267 | 268 | def resnet20_cifar(**kwargs): 269 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs) 270 | return model 271 | 272 | 273 | def resnet32_cifar(**kwargs): 274 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs) 275 | return model 276 | 277 | 278 | def resnet44_cifar(**kwargs): 279 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs) 280 | return model 281 | 282 | 283 | def resnet56_cifar(**kwargs): 284 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs) 285 | return model 286 | 287 | 288 | def resnet110_cifar(**kwargs): 289 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs) 290 | return model 291 | 292 | 293 | def resnet1202_cifar(**kwargs): 294 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs) 295 | return model 296 | 297 | 298 | def resnet164_cifar(**kwargs): 299 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs) 300 | return model 301 | 302 | 303 | def resnet1001_cifar(**kwargs): 304 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs) 305 | return model 306 | 307 | 308 | def preact_resnet110_cifar(**kwargs): 309 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs) 310 | return model 311 | 312 | 313 | def preact_resnet164_cifar(**kwargs): 314 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs) 315 | return model 316 | 317 | 318 | def preact_resnet1001_cifar(**kwargs): 319 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs) 320 | return model 321 | 322 | 323 | if __name__ == '__main__': 324 | net = resnet20_cifar() 325 | y = net(torch.randn(1, 3, 64, 64)) 326 | print(net) 327 | print(y.size()) 328 | 329 | -------------------------------------------------------------------------------- /models/resnext_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | resneXt for cifar with pytorch 3 | 4 | Reference: 5 | [1] S. Xie, G. Ross, P. Dollar, Z. Tu and K. He Aggregated residual transformations for deep neural networks. 
In CVPR, 2017 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | 12 | 13 | class Bottleneck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__(self, inplanes, planes, cardinality, baseWidth, stride=1, downsample=None): 17 | super(Bottleneck, self).__init__() 18 | D = int(planes * (baseWidth / 64.)) 19 | C = cardinality 20 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(D*C) 22 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) 23 | self.bn2 = nn.BatchNorm2d(D*C) 24 | self.conv3 = nn.Conv2d(D*C, planes*4, kernel_size=1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes*4) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | if residual.size() != out.size(): 48 | print(out.size(), residual.size()) 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class ResNeXt_Cifar(nn.Module): 56 | 57 | def __init__(self, block, layers, cardinality, baseWidth, num_classes=10): 58 | super(ResNeXt_Cifar, self).__init__() 59 | self.inplanes = 64 60 | self.cardinality = cardinality 61 | self.baseWidth = baseWidth 62 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.layer1 = self._make_layer(block, 64, layers[0]) 66 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 67 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 68 | self.avgpool = nn.AvgPool2d(8, stride=1) 69 | self.fc = nn.Linear(256 * block.expansion, num_classes) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 74 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.cardinality, self.baseWidth))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resneXt_cifar(depth, cardinality, baseWidth, **kwargs):
    assert (depth - 2) % 9 == 0
    n = (depth - 2) // 9  # integer division so the per-stage block count is an int
    model = ResNeXt_Cifar(Bottleneck, [n, n, n], cardinality, baseWidth, **kwargs)
    return model


if __name__ == '__main__':
    net = resneXt_cifar(29, 16, 64)
    y = net(torch.randn(1, 3, 32, 32))
    print(net)
    print(y.size())
--------------------------------------------------------------------------------
/models/wide_resnet_cifar.py:
--------------------------------------------------------------------------------
"""
wide resnet for cifar in pytorch

Reference:
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
"""
import torch
import torch.nn as nn
import math
from models.resnet_cifar import BasicBlock


class Wide_ResNet_Cifar(nn.Module):

    def __init__(self, block, layers, wfactor, num_classes=10):
        super(Wide_ResNet_Cifar, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
        self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def wide_resnet_cifar(depth, width, **kwargs):
    assert (depth - 2) % 6 == 0
    n = (depth - 2) // 6  # integer division so the per-stage block count is an int
    return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)


if __name__=='__main__':
    net = wide_resnet_cifar(20, 10)
    y = net(torch.randn(1, 3, 32, 32))
    print(isinstance(net, Wide_ResNet_Cifar))
    print(y.size())

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch
torchvision
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
# for densenet
# CUDA_VISIBLE_DEVICES=0 python main.py --epoch 300 --batch-size 64 -ct 100

# for resnet
CUDA_VISIBLE_DEVICES=0 python main.py --epoch 160 --batch-size 128 --lr 0.1 --momentum 0.9 --wd 1e-4 -ct 10
--------------------------------------------------------------------------------
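
For reference, an interrupted run can be resumed through the ```--resume``` flag defined in ```main.py```. The command below is an illustrative sketch: it assumes the default ```fdir``` (```result/resnet20_cifar10```), where ```save_checkpoint``` writes ```checkpoint.pth```.

```
# resume training from the latest checkpoint
CUDA_VISIBLE_DEVICES=0 python main.py --epoch 160 --batch-size 128 --lr 0.1 --momentum 0.9 --wd 1e-4 -ct 10 --resume result/resnet20_cifar10/checkpoint.pth
```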