├── .gitignore ├── LICENSE ├── README.md ├── cifar ├── README.md ├── easy_mixup.py ├── images │ ├── cifar10_wd1em4.png │ └── cifar10_wd5em4.png ├── models │ ├── __init__.py │ ├── densenet.py │ ├── dpn.py │ ├── googlenet.py │ ├── lenet.py │ ├── mobilenet.py │ ├── pnasnet.py │ ├── preact_resnet.py │ ├── resnet.py │ ├── resnext.py │ ├── senet.py │ ├── shufflenet.py │ └── vgg.py └── utils.py └── gan ├── README.md ├── example_gan.py └── images └── gan_results.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains demo reimplementations of the CIFAR-10 training code and the GAN experiment in PyTorch based on the following paper: 2 | > Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin and David Lopez-Paz. _mixup: Beyond Empirical Risk Minimization._ https://arxiv.org/abs/1710.09412 3 | 4 | ## CIFAR-10 5 | 6 | The following table shows the median test errors of the last 10 epochs in a 200-epoch training session. (Please refer to Section 3.2 in the paper for details.) 
7 | 8 | | Model | weight decay = 1e-4 | weight decay = 5e-4 | 9 | |:-------------------|---------------------:|---------------------:| 10 | | ERM | 5.53% | 5.18% | 11 | | _mixup_ | 4.24% | 4.68% | 12 | 13 | ## Generative Adversarial Networks (GAN) 14 | 15 | ![](gan/images/gan_results.png) 16 | 17 | ## Other implementations 18 | - [A TensorFlow implementation of mixup](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet#cifar10-preact18-mixuppy) which reproduces our results in [tensorpack](https://github.com/ppwwyyxx/tensorpack) 19 | - [Official Facebook implementation of the CIFAR-10 experiments](https://github.com/facebookresearch/mixup-cifar10) 20 | 21 | ## Acknowledgement 22 | The CIFAR-10 reimplementation of _mixup_ is adapted from the [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) repository by [kuangliu](https://github.com/kuangliu). 23 | -------------------------------------------------------------------------------- /cifar/README.md: -------------------------------------------------------------------------------- 1 | # mixup training for CIFAR-10 2 | 3 | ## Training 4 | ```bash 5 | # training with default parameters (weight_decay=1e-4 and alpha=1) 6 | python easy_mixup.py --sess my_session_1 --seed 11111 7 | ``` 8 | The above bash script will download the CIFAR data and train the network with 1e-4 weight decay and mixup parameter alpha=1.0; alternatively, we can experiment with other weight decay and alpha values using the corresponding options: 9 | ```bash 10 | # training with weight_decay=5e-4 and alpha=0 (no mixup) 11 | python easy_mixup.py --sess my_session_2 --seed 22222 --decay 5e-4 --alpha 0. 12 | ``` 13 | The other choices (network architecture, #epochs, learning rate schedule, momentum, data augmentation etc.) are hard coded but modifications are hopefully straightforward.
14 | 15 | By default, the trained model with the best validation accuracy resides in `./checkpoint` folder, and the training log (including training loss/accuracy and validation loss/accuracy for each epoch) is saved in `./results` as a `.csv` file. 16 | 17 | ## Results 18 | **_mixup_ reduces overfitting and improves generalization.** The following plots show test error curves of a typical training session using the PreAct ResNet-18 architecture (default; you can make changes [here](https://github.com/hongyi-zhang/mixup/blob/8b43d663501b10ccb8e21d88be9d42d3bab0fd2f/easy_mixup.py#L78)). Note that compared with the ERM baseline, **_mixup_ prefers a smaller weight decay** (1e-4 vs. 5e-4), indicating its regularization effects. 19 | 20 | | Model | weight decay = 1e-4 | weight decay = 5e-4 | 21 | |:-------------------|---------------------:|---------------------:| 22 | | ERM | 5.53% | 5.18% | 23 | | _mixup_ | 4.24% | 4.68% | 24 | 25 | ![](images/cifar10_wd1em4.png) 26 | ![](images/cifar10_wd5em4.png) 27 | 28 | ## Other frameworks 29 | - [A Tensorflow implementation of mixup](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet#cifar10-preact18-mixuppy) which reproduces our results in [tensorpack](https://github.com/ppwwyyxx/tensorpack) 30 | 31 | ## Acknowledgement 32 | This reimplementation is adapted from the [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) repository by [kuangliu](https://github.com/kuangliu). 
33 | -------------------------------------------------------------------------------- /cifar/easy_mixup.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | 13 | import os 14 | import argparse 15 | import csv 16 | 17 | from models import * 18 | from utils import progress_bar, mixup_data, mixup_criterion 19 | from torch.autograd import Variable 20 | 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 23 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 24 | parser.add_argument('--sess', default='mixup_default', type=str, help='session id') 25 | parser.add_argument('--seed', default=0, type=int, help='rng seed') 26 | parser.add_argument('--alpha', default=1., type=float, help='interpolation strength (uniform=1., ERM=0.)') 27 | parser.add_argument('--decay', default=1e-4, type=float, help='weight decay (default=1e-4)') 28 | args = parser.parse_args() 29 | 30 | torch.manual_seed(args.seed) 31 | 32 | use_cuda = torch.cuda.is_available() 33 | best_acc = 0 # best test accuracy 34 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 35 | batch_size = 128 36 | base_learning_rate = 0.1 37 | if use_cuda: 38 | # data parallel 39 | n_gpu = torch.cuda.device_count() 40 | batch_size *= n_gpu 41 | base_learning_rate *= n_gpu 42 | 43 | # Data 44 | print('==> Preparing data..') 45 | transform_train = transforms.Compose([ 46 | transforms.RandomCrop(32, padding=4), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 50 | ]) 51 | 52 | transform_test = 
transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 55 | ]) 56 | 57 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 58 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) 59 | 60 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 61 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 62 | 63 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 64 | 65 | # Model 66 | if args.resume: 67 | # Load checkpoint. 68 | print('==> Resuming from checkpoint..') 69 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 70 | checkpoint = torch.load('./checkpoint/ckpt.t7.' + args.sess + '_' + str(args.seed)) 71 | net = checkpoint['net'] 72 | best_acc = checkpoint['acc'] 73 | start_epoch = checkpoint['epoch'] + 1 74 | torch.set_rng_state(checkpoint['rng_state']) 75 | else: 76 | print('==> Building model..') 77 | # net = VGG('VGG19') 78 | net = PreActResNet18() 79 | # net = ResNet18() 80 | # net = GoogLeNet() 81 | # net = DenseNet121() 82 | # net = ResNeXt29_2x64d() 83 | # net = MobileNet() 84 | # net = DPN92() 85 | # net = ShuffleNetG2() 86 | # net = SENet18() 87 | 88 | result_folder = './results/' 89 | if not os.path.exists(result_folder): 90 | os.makedirs(result_folder) 91 | 92 | logname = result_folder + net.__class__.__name__ + '_' + args.sess + '_' + str(args.seed) + '.csv' 93 | 94 | if use_cuda: 95 | net.cuda() 96 | net = torch.nn.DataParallel(net) 97 | print('Using', torch.cuda.device_count(), 'GPUs.') 98 | cudnn.benchmark = True 99 | print('Using CUDA..') 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | optimizer = optim.SGD(net.parameters(), lr=base_learning_rate, momentum=0.9, weight_decay=args.decay) 103 | 104 | 
# Training 105 | def train(epoch): 106 | print('\nEpoch: %d' % epoch) 107 | net.train() 108 | train_loss = 0 109 | correct = 0 110 | total = 0 111 | for batch_idx, (inputs, targets) in enumerate(trainloader): 112 | if use_cuda: 113 | inputs, targets = inputs.cuda(), targets.cuda() 114 | # generate mixed inputs, two one-hot label vectors and mixing coefficient 115 | inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.alpha, use_cuda) 116 | optimizer.zero_grad() 117 | inputs, targets_a, targets_b = Variable(inputs), Variable(targets_a), Variable(targets_b) 118 | outputs = net(inputs) 119 | 120 | loss_func = mixup_criterion(targets_a, targets_b, lam) 121 | loss = loss_func(criterion, outputs) 122 | loss.backward() 123 | optimizer.step() 124 | 125 | train_loss += loss.data[0] 126 | _, predicted = torch.max(outputs.data, 1) 127 | total += targets.size(0) 128 | correct += lam * predicted.eq(targets_a.data).cpu().sum() + (1 - lam) * predicted.eq(targets_b.data).cpu().sum() 129 | 130 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 131 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 132 | return (train_loss/batch_idx, 100.*correct/total) 133 | 134 | def test(epoch): 135 | global best_acc 136 | net.eval() 137 | test_loss = 0 138 | correct = 0 139 | total = 0 140 | for batch_idx, (inputs, targets) in enumerate(testloader): 141 | if use_cuda: 142 | inputs, targets = inputs.cuda(), targets.cuda() 143 | inputs, targets = Variable(inputs, volatile=True), Variable(targets) 144 | outputs = net(inputs) 145 | loss = criterion(outputs, targets) 146 | 147 | test_loss += loss.data[0] 148 | _, predicted = torch.max(outputs.data, 1) 149 | total += targets.size(0) 150 | correct += predicted.eq(targets.data).cpu().sum() 151 | 152 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 153 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 154 | 155 | # Save checkpoint. 
156 | acc = 100.*correct/total 157 | if acc > best_acc: 158 | best_acc = acc 159 | checkpoint(acc, epoch) 160 | return (test_loss/batch_idx, 100.*correct/total) 161 | 162 | def checkpoint(acc, epoch): 163 | # Save checkpoint. 164 | print('Saving..') 165 | state = { 166 | 'net': net, 167 | 'acc': acc, 168 | 'epoch': epoch, 169 | 'rng_state': torch.get_rng_state() 170 | } 171 | if not os.path.isdir('checkpoint'): 172 | os.mkdir('checkpoint') 173 | torch.save(state, './checkpoint/ckpt.t7.' + args.sess + '_' + str(args.seed)) 174 | 175 | def adjust_learning_rate(optimizer, epoch): 176 | """decrease the learning rate at 100 and 150 epoch""" 177 | lr = base_learning_rate 178 | if epoch <= 9 and lr > 0.1: 179 | # warm-up training for large minibatch 180 | lr = 0.1 + (base_learning_rate - 0.1) * epoch / 10. 181 | if epoch >= 100: 182 | lr /= 10 183 | if epoch >= 150: 184 | lr /= 10 185 | for param_group in optimizer.param_groups: 186 | param_group['lr'] = lr 187 | 188 | if not os.path.exists(logname): 189 | with open(logname, 'w') as logfile: 190 | logwriter = csv.writer(logfile, delimiter=',') 191 | logwriter.writerow(['epoch', 'train loss', 'train acc', 'test loss', 'test acc']) 192 | 193 | for epoch in range(start_epoch, 200): 194 | adjust_learning_rate(optimizer, epoch) 195 | train_loss, train_acc = train(epoch) 196 | test_loss, test_acc = test(epoch) 197 | with open(logname, 'a') as logfile: 198 | logwriter = csv.writer(logfile, delimiter=',') 199 | logwriter.writerow([epoch, train_loss, train_acc, test_loss, test_acc]) 200 | -------------------------------------------------------------------------------- /cifar/images/cifar10_wd1em4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongyi-zhang/mixup/80000cea340bf829a52481ae45a317a487ce2deb/cifar/images/cifar10_wd1em4.png -------------------------------------------------------------------------------- /cifar/images/cifar10_wd5em4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongyi-zhang/mixup/80000cea340bf829a52481ae45a317a487ce2deb/cifar/images/cifar10_wd5em4.png -------------------------------------------------------------------------------- /cifar/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .resnet import * 6 | from .resnext import * 7 | from .pnasnet import * 8 | from .densenet import * 9 | from .googlenet import * 10 | from .mobilenet import * 11 | from .shufflenet import * 12 | from .preact_resnet import * 13 | -------------------------------------------------------------------------------- /cifar/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | def __init__(self, in_planes, growth_rate): 13 | super(Bottleneck, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 17 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 18 | 19 | def forward(self, x): 20 | out = self.conv1(F.relu(self.bn1(x))) 21 | out = self.conv2(F.relu(self.bn2(out))) 22 | out = torch.cat([out,x], 1) 23 | return out 24 | 25 | 26 | class Transition(nn.Module): 27 | def __init__(self, in_planes, out_planes): 28 | super(Transition, self).__init__() 29 | self.bn = nn.BatchNorm2d(in_planes) 30 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 31 | 32 | def forward(self, x): 33 | out = self.conv(F.relu(self.bn(x))) 34 | out = 
F.avg_pool2d(out, 2) 35 | return out 36 | 37 | 38 | class DenseNet(nn.Module): 39 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 40 | super(DenseNet, self).__init__() 41 | self.growth_rate = growth_rate 42 | 43 | num_planes = 2*growth_rate 44 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 45 | 46 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 47 | num_planes += nblocks[0]*growth_rate 48 | out_planes = int(math.floor(num_planes*reduction)) 49 | self.trans1 = Transition(num_planes, out_planes) 50 | num_planes = out_planes 51 | 52 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 53 | num_planes += nblocks[1]*growth_rate 54 | out_planes = int(math.floor(num_planes*reduction)) 55 | self.trans2 = Transition(num_planes, out_planes) 56 | num_planes = out_planes 57 | 58 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 59 | num_planes += nblocks[2]*growth_rate 60 | out_planes = int(math.floor(num_planes*reduction)) 61 | self.trans3 = Transition(num_planes, out_planes) 62 | num_planes = out_planes 63 | 64 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 65 | num_planes += nblocks[3]*growth_rate 66 | 67 | self.bn = nn.BatchNorm2d(num_planes) 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.trans3(self.dense3(out)) 82 | out = self.dense4(out) 83 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | return out 87 | 88 | def DenseNet121(): 89 | return DenseNet(Bottleneck, 
[6,12,24,16], growth_rate=32) 90 | 91 | def DenseNet169(): 92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 93 | 94 | def DenseNet201(): 95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 96 | 97 | def DenseNet161(): 98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 99 | 100 | def densenet_cifar(): 101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 102 | 103 | def test_densenet(): 104 | net = densenet_cifar() 105 | x = torch.randn(1,3,32,32) 106 | y = net(Variable(x)) 107 | print(y) 108 | 109 | # test_densenet() 110 | -------------------------------------------------------------------------------- /cifar/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 11 | super(Bottleneck, self).__init__() 12 | self.out_planes = out_planes 13 | self.dense_depth = dense_depth 14 | 15 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 18 | self.bn2 = nn.BatchNorm2d(in_planes) 19 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 21 | 22 | self.shortcut = nn.Sequential() 23 | if first_layer: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(out_planes+dense_depth) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = F.relu(self.bn2(self.conv2(out))) 32 | out = self.bn3(self.conv3(out)) 
33 | x = self.shortcut(x) 34 | d = self.out_planes 35 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class DPN(nn.Module): 41 | def __init__(self, cfg): 42 | super(DPN, self).__init__() 43 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 44 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 45 | 46 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(64) 48 | self.last_planes = 64 49 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 50 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 51 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 52 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 53 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 54 | 55 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for i,stride in enumerate(strides): 59 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 60 | self.last_planes = out_planes + (i+2) * dense_depth 61 | return nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.layer1(out) 66 | out = self.layer2(out) 67 | out = self.layer3(out) 68 | out = self.layer4(out) 69 | out = F.avg_pool2d(out, 4) 70 | out = out.view(out.size(0), -1) 71 | out = self.linear(out) 72 | return out 73 | 74 | 75 | def DPN26(): 76 | cfg = { 77 | 'in_planes': (96,192,384,768), 78 | 'out_planes': (256,512,1024,2048), 79 | 'num_blocks': (2,2,2,2), 80 | 'dense_depth': (16,32,24,128) 81 | } 82 | return DPN(cfg) 83 | 84 | def DPN92(): 85 | cfg = { 86 | 
'in_planes': (96,192,384,768), 87 | 'out_planes': (256,512,1024,2048), 88 | 'num_blocks': (3,4,20,3), 89 | 'dense_depth': (16,32,24,128) 90 | } 91 | return DPN(cfg) 92 | 93 | 94 | def test(): 95 | net = DPN92() 96 | x = Variable(torch.randn(1,3,32,32)) 97 | y = net(x) 98 | print(y) 99 | 100 | # test() 101 | -------------------------------------------------------------------------------- /cifar/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Inception(nn.Module): 10 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 11 | super(Inception, self).__init__() 12 | # 1x1 conv branch 13 | self.b1 = nn.Sequential( 14 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 15 | nn.BatchNorm2d(n1x1), 16 | nn.ReLU(True), 17 | ) 18 | 19 | # 1x1 conv -> 3x3 conv branch 20 | self.b2 = nn.Sequential( 21 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 22 | nn.BatchNorm2d(n3x3red), 23 | nn.ReLU(True), 24 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(n3x3), 26 | nn.ReLU(True), 27 | ) 28 | 29 | # 1x1 conv -> 5x5 conv branch 30 | self.b3 = nn.Sequential( 31 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 32 | nn.BatchNorm2d(n5x5red), 33 | nn.ReLU(True), 34 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(n5x5), 36 | nn.ReLU(True), 37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(n5x5), 39 | nn.ReLU(True), 40 | ) 41 | 42 | # 3x3 pool -> 1x1 conv branch 43 | self.b4 = nn.Sequential( 44 | nn.MaxPool2d(3, stride=1, padding=1), 45 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 46 | nn.BatchNorm2d(pool_planes), 47 | nn.ReLU(True), 48 | ) 49 | 50 | def forward(self, x): 51 | y1 = self.b1(x) 52 | y2 = self.b2(x) 53 | y3 = self.b3(x) 54 | y4 = self.b4(x) 55 | return 
torch.cat([y1,y2,y3,y4], 1) 56 | 57 | 58 | class GoogLeNet(nn.Module): 59 | def __init__(self): 60 | super(GoogLeNet, self).__init__() 61 | self.pre_layers = nn.Sequential( 62 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(192), 64 | nn.ReLU(True), 65 | ) 66 | 67 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 68 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 69 | 70 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 71 | 72 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 73 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 74 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 75 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 76 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 77 | 78 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 79 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 80 | 81 | self.avgpool = nn.AvgPool2d(8, stride=1) 82 | self.linear = nn.Linear(1024, 10) 83 | 84 | def forward(self, x): 85 | out = self.pre_layers(x) 86 | out = self.a3(out) 87 | out = self.b3(out) 88 | out = self.maxpool(out) 89 | out = self.a4(out) 90 | out = self.b4(out) 91 | out = self.c4(out) 92 | out = self.d4(out) 93 | out = self.e4(out) 94 | out = self.maxpool(out) 95 | out = self.a5(out) 96 | out = self.b5(out) 97 | out = self.avgpool(out) 98 | out = out.view(out.size(0), -1) 99 | out = self.linear(out) 100 | return out 101 | 102 | # net = GoogLeNet() 103 | # x = torch.randn(1,3,32,32) 104 | # y = net(Variable(x)) 105 | # print(y.size()) 106 | -------------------------------------------------------------------------------- /cifar/models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = 
nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /cifar/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | class Block(nn.Module): 14 | '''Depthwise conv + Pointwise conv''' 15 | def __init__(self, in_planes, out_planes, stride=1): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(self.conv1(x))) 24 | out = F.relu(self.bn2(self.conv2(out))) 25 | return out 26 | 27 | 28 | class MobileNet(nn.Module): 29 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 30 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 31 | 32 | def __init__(self, num_classes=10): 33 | super(MobileNet, self).__init__() 34 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(32) 36 | self.layers = self._make_layers(in_planes=32) 37 | self.linear = 
nn.Linear(1024, num_classes) 38 | 39 | def _make_layers(self, in_planes): 40 | layers = [] 41 | for x in self.cfg: 42 | out_planes = x if isinstance(x, int) else x[0] 43 | stride = 1 if isinstance(x, int) else x[1] 44 | layers.append(Block(in_planes, out_planes, stride)) 45 | in_planes = out_planes 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.layers(out) 51 | out = F.avg_pool2d(out, 2) 52 | out = out.view(out.size(0), -1) 53 | out = self.linear(out) 54 | return out 55 | 56 | 57 | def test(): 58 | net = MobileNet() 59 | x = torch.randn(1,3,32,32) 60 | y = net(Variable(x)) 61 | print(y.size()) 62 | 63 | # test() 64 | -------------------------------------------------------------------------------- /cifar/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class SepConv(nn.Module): 13 | '''Separable Convolution.''' 14 | def __init__(self, in_planes, out_planes, kernel_size, stride): 15 | super(SepConv, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, out_planes, 17 | kernel_size, stride, 18 | padding=(kernel_size-1)//2, 19 | bias=False, groups=in_planes) 20 | self.bn1 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | return self.bn1(self.conv1(x)) 24 | 25 | 26 | class CellA(nn.Module): 27 | def __init__(self, in_planes, out_planes, stride=1): 28 | super(CellA, self).__init__() 29 | self.stride = stride 30 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 31 | if stride==2: 32 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 33 | self.bn1 = nn.BatchNorm2d(out_planes) 34 | 35 | def forward(self, x): 36 | y1 = 
self.sep_conv1(x) 37 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 38 | if self.stride==2: 39 | y2 = self.bn1(self.conv1(y2)) 40 | return F.relu(y1+y2) 41 | 42 | class CellB(nn.Module): 43 | def __init__(self, in_planes, out_planes, stride=1): 44 | super(CellB, self).__init__() 45 | self.stride = stride 46 | # Left branch 47 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 48 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 49 | # Right branch 50 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 51 | if stride==2: 52 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 53 | self.bn1 = nn.BatchNorm2d(out_planes) 54 | # Reduce channels 55 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 56 | self.bn2 = nn.BatchNorm2d(out_planes) 57 | 58 | def forward(self, x): 59 | # Left branch 60 | y1 = self.sep_conv1(x) 61 | y2 = self.sep_conv2(x) 62 | # Right branch 63 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 64 | if self.stride==2: 65 | y3 = self.bn1(self.conv1(y3)) 66 | y4 = self.sep_conv3(x) 67 | # Concat & reduce channels 68 | b1 = F.relu(y1+y2) 69 | b2 = F.relu(y3+y4) 70 | y = torch.cat([b1,b2], 1) 71 | return F.relu(self.bn2(self.conv2(y))) 72 | 73 | class PNASNet(nn.Module): 74 | def __init__(self, cell_type, num_cells, num_planes): 75 | super(PNASNet, self).__init__() 76 | self.in_planes = num_planes 77 | self.cell_type = cell_type 78 | 79 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(num_planes) 81 | 82 | self.layer1 = self._make_layer(num_planes, num_cells=6) 83 | self.layer2 = self._downsample(num_planes*2) 84 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 85 | self.layer4 = self._downsample(num_planes*4) 86 | self.layer5 = 
self._make_layer(num_planes*4, num_cells=6) 87 | 88 | self.linear = nn.Linear(num_planes*4, 10) 89 | 90 | def _make_layer(self, planes, num_cells): 91 | layers = [] 92 | for _ in range(num_cells): 93 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 94 | self.in_planes = planes 95 | return nn.Sequential(*layers) 96 | 97 | def _downsample(self, planes): 98 | layer = self.cell_type(self.in_planes, planes, stride=2) 99 | self.in_planes = planes 100 | return layer 101 | 102 | def forward(self, x): 103 | out = F.relu(self.bn1(self.conv1(x))) 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = self.layer5(out) 109 | out = F.avg_pool2d(out, 8) 110 | out = self.linear(out.view(out.size(0), -1)) 111 | return out 112 | 113 | 114 | def PNASNetA(): 115 | return PNASNet(CellA, num_cells=6, num_planes=44) 116 | 117 | def PNASNetB(): 118 | return PNASNet(CellB, num_cells=6, num_planes=32) 119 | 120 | 121 | def test(): 122 | net = PNASNetB() 123 | print(net) 124 | x = Variable(torch.randn(1,3,32,32)) 125 | y = net(x) 126 | print(y) 127 | 128 | # test() 129 | -------------------------------------------------------------------------------- /cifar/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | class PreActBlock(nn.Module): 15 | '''Pre-activation version of the BasicBlock.''' 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(PreActBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(in_planes) 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(x)) 32 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 33 | out = self.conv1(out) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out += shortcut 36 | return out 37 | 38 | 39 | class PreActBottleneck(nn.Module): 40 | '''Pre-activation version of the original Bottleneck module.''' 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PreActBottleneck, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(x)) 59 | shortcut = self.shortcut(out) if 
hasattr(self, 'shortcut') else x 60 | out = self.conv1(out) 61 | out = self.conv2(F.relu(self.bn2(out))) 62 | out = self.conv3(F.relu(self.bn3(out))) 63 | out += shortcut 64 | return out 65 | 66 | 67 | class PreActResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(PreActResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512*block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def PreActResNet18(): 100 | return PreActResNet(PreActBlock, [2,2,2,2]) 101 | 102 | def PreActResNet34(): 103 | return PreActResNet(PreActBlock, [3,4,6,3]) 104 | 105 | def PreActResNet50(): 106 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 107 | 108 | def PreActResNet101(): 109 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 110 | 111 | def PreActResNet152(): 112 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 113 | 114 | 115 | def test(): 116 | net = PreActResNet18() 117 | y = net(Variable(torch.randn(1,3,32,32))) 118 | print(y.size()) 119 | 120 | # test() 121 | 
# --------------------------------------------------------------------------------
# /cifar/models/resnet.py:
# --------------------------------------------------------------------------------
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def _projection(in_planes, out_planes, stride):
    '''1x1 conv + BN that reshapes the identity path to match the residual path.'''
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_planes),
    )


class BasicBlock(nn.Module):
    '''Two 3x3 conv-BN layers plus an identity (or projection) shortcut, then ReLU.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        width = self.expansion * planes
        reshape = stride != 1 or in_planes != width
        # An empty Sequential is the identity mapping.
        self.shortcut = _projection(in_planes, width, stride) if reshape else nn.Sequential()

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + self.shortcut(x))


class Bottleneck(nn.Module):
    '''1x1 reduce -> 3x3 -> 1x1 expand (4x) residual block.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)

        reshape = stride != 1 or in_planes != width
        self.shortcut = _projection(in_planes, width, stride) if reshape else nn.Sequential()

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        return F.relu(branch + self.shortcut(x))


class ResNet(nn.Module):
    '''CIFAR-style ResNet: 3x3 stem, four stages (64/128/256/512 planes), 4x4 pool, linear.'''
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage; only its first block uses ``stride``.

        Note: ``[1] * (num_blocks - 1)`` is empty for num_blocks <= 1, so a
        stage always contains at least one block.
        '''
        stage = []
        for step in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, step))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # A 32x32 input is 4x4 after three stride-2 stages.
        pooled = F.avg_pool2d(out, 4)
        return self.linear(pooled.view(pooled.size(0), -1))


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])

def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])

def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    '''Smoke test: print the logits shape for one random 32x32 image.'''
    print(ResNet18()(Variable(torch.randn(1, 3, 32, 32))).size())

# test()
# --------------------------------------------------------------------------------
# /cifar/models/resnext.py:
# --------------------------------------------------------------------------------
'''ResNeXt in PyTorch.

See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class Block(nn.Module):
    '''Grouped convolution block.

    1x1 conv to ``cardinality * bottleneck_width`` channels, a grouped 3x3
    conv (``groups=cardinality``), then a 1x1 conv expanding 2x, plus a
    shortcut.
    '''
    expansion = 2

    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
        super(Block, self).__init__()
        group_width = cardinality * bottleneck_width
        width = self.expansion * group_width
        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(group_width)
        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(group_width)
        self.conv3 = nn.Conv2d(group_width, width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)

        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        return F.relu(branch + self.shortcut(x))


class ResNeXt(nn.Module):
    '''Three-stage ResNeXt for 32x32 inputs; the bottleneck width doubles after each stage.'''
    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.bottleneck_width = bottleneck_width
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(num_blocks[0], 1)
        self.layer2 = self._make_layer(num_blocks[1], 2)
        self.layer3 = self._make_layer(num_blocks[2], 2)
        # Only three stages are built. The last stage outputs
        # expansion(2) * cardinality * (4 * bottleneck_width) channels.
        self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)

    def _make_layer(self, num_blocks, stride):
        '''Build one stage (first block uses ``stride``), then double ``self.bottleneck_width``.'''
        stage = []
        for step in [stride] + [1] * (num_blocks - 1):
            stage.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, step))
            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
        self.bottleneck_width *= 2
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        # Two stride-2 stages turn a 32x32 input into 8x8.
        pooled = F.avg_pool2d(out, 8)
        return self.linear(pooled.view(pooled.size(0), -1))


def ResNeXt29_2x64d():
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=2, bottleneck_width=64)

def ResNeXt29_4x64d():
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=4, bottleneck_width=64)

def ResNeXt29_8x64d():
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=8, bottleneck_width=64)

def ResNeXt29_32x4d():
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=32, bottleneck_width=4)

def test_resnext():
    '''Smoke test: print the logits shape for one random 32x32 image.'''
    model = ResNeXt29_2x64d()
    sample = torch.randn(1, 3, 32, 32)
    print(model(Variable(sample)).size())

# test_resnext()
# --------------------------------------------------------------------------------
# /cifar/models/senet.py:
# --------------------------------------------------------------------------------
'''SENet in PyTorch.

SENet won ImageNet-2017. Reference: Hu et al., "Squeeze-and-Excitation
Networks", arXiv:1709.01507.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class BasicBlock(nn.Module):
    '''Post-activation residual block with a squeeze-and-excitation gate.'''
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        # SE layers: 1x1 convs act as fully-connected layers on the pooled
        # (N, C, 1, 1) descriptor, with a channel reduction ratio of 16.
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Squeeze: average-pool with a kernel equal to the feature-map height
        # (a global pool when the map is square).
        w = F.avg_pool2d(out, out.size(2))
        w = F.relu(self.fc1(w))
        # `.sigmoid()` gives the same values as the deprecated F.sigmoid.
        w = self.fc2(w).sigmoid()
        # Excitation: rescale each channel by broadcasting.
        out = out * w

        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    '''Pre-activation residual block with a squeeze-and-excitation gate.'''
    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # Only created when shapes differ; forward() checks for it with hasattr.
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            )

        # SE layers
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze
        w = F.avg_pool2d(out, out.size(2))
        w = F.relu(self.fc1(w))
        w = self.fc2(w).sigmoid()
        # Excitation
        out = out * w

        out += shortcut
        return out


class SENet(nn.Module):
    '''SE-ResNet for 32x32 inputs; every stage uses ``block`` and outputs ``planes`` channels.'''
    def __init__(self, block, num_blocks, num_classes=10):
        super(SENet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def SENet18():
    return SENet(PreActBlock, [2,2,2,2])


def test():
    net = SENet18()
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()
# --------------------------------------------------------------------------------
# /cifar/models/shufflenet.py:
# --------------------------------------------------------------------------------
'''ShuffleNet in PyTorch.

See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class ShuffleBlock(nn.Module):
    '''Channel shuffle across ``groups`` (C must be divisible by ``groups``).'''
    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        # Floor division keeps the view shape integral; `/` yields a float on
        # Python 3 and makes `view` fail.
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)


class Bottleneck(nn.Module):
    '''ShuffleNet unit: grouped 1x1 -> shuffle -> depthwise 3x3 -> grouped 1x1.

    At stride 2 the output is concatenated with a 3x3 avg-pooled shortcut
    (so callers pass ``out_planes`` minus the input width). At stride 1 the
    output is added to the identity.
    '''
    def __init__(self, in_planes, out_planes, stride, groups):
        super(Bottleneck, self).__init__()
        self.stride = stride

        # Integer channel counts are required by Conv2d/BatchNorm2d.
        mid_planes = out_planes // 4
        # The first unit after the 24-channel stem uses an ungrouped 1x1 conv.
        g = 1 if in_planes==24 else groups
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=g)
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
        return out


class ShuffleNet(nn.Module):
    '''ShuffleNet for 32x32 inputs.

    Args:
        cfg: dict with ``out_planes`` (3 stage widths), ``num_blocks``
            (3 stage depths) and ``groups``.
        num_classes: classifier output size (default 10).
    '''
    def __init__(self, cfg, num_classes=10):
        super(ShuffleNet, self).__init__()
        out_planes = cfg['out_planes']
        num_blocks = cfg['num_blocks']
        groups = cfg['groups']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_planes = 24
        self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
        self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], num_classes)

    def _make_layer(self, out_planes, num_blocks, groups):
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            # The stride-2 unit concatenates its input, so it only needs to
            # produce the remaining channels.
            cat_planes = self.in_planes if i == 0 else 0
            layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Three stride-2 stages turn a 32x32 input into 4x4.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ShuffleNetG2():
    cfg = {
        'out_planes': [200,400,800],
        'num_blocks': [4,8,4],
        'groups': 2
    }
    return ShuffleNet(cfg)

def ShuffleNetG3():
    cfg = {
        'out_planes': [240,480,960],
        'num_blocks': [4,8,4],
        'groups': 3
    }
    return ShuffleNet(cfg)


def test():
    net = ShuffleNetG2()
    x = Variable(torch.randn(1,3,32,32))
    y = net(x)
    print(y)

# test()
# --------------------------------------------------------------------------------
# /cifar/models/vgg.py:
# --------------------------------------------------------------------------------
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


# Layer specs: ints are 3x3 conv output channels, 'M' is a 2x2 max-pool.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    '''VGG with BatchNorm for 32x32 inputs; five max-pools reduce the map to 1x1x512.'''
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # 1x1 average pool is an identity op, kept so layer indices stay stable.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())
# --------------------------------------------------------------------------------
# /cifar/utils.py:
# --------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - mixup_data / mixup_criterion: mixup input mixing and the matching loss.
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    ``lam`` is drawn from Beta(alpha, alpha), or is 1 (no mixing) when
    ``alpha <= 0``. Each sample is mixed with a partner chosen by a random
    permutation of the batch.
    '''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(y_a, y_b, lam):
    '''Return ``f(criterion, pred) = lam*criterion(pred, y_a) + (1-lam)*criterion(pred, y_b)``.'''
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.

    Both values are per-channel averages of per-image statistics. Note that
    the std is the mean of per-image stds, not the std over all pixels.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def _init_fn(name):
    '''Return ``init.<name>_`` (PyTorch >= 0.4), falling back to the deprecated ``init.<name>``.'''
    return getattr(init, name + '_', None) or getattr(init, name)

def init_params(net):
    '''Init layer parameters in place.

    Conv2d: Kaiming-normal weights (fan_out), zero bias.
    BatchNorm2d: weight 1, bias 0 (skipped when the layer is not affine).
    Linear: N(0, 1e-3) weights, zero bias.
    '''
    kaiming_normal = _init_fn('kaiming_normal')
    constant = _init_fn('constant')
    normal = _init_fn('normal')
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            kaiming_normal(m.weight, mode='fan_out')
            # `if m.bias:` raised on multi-element tensors; test for presence instead.
            if m.bias is not None:
                constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                constant(m.weight, 1)
            if m.bias is not None:
                constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            normal(m.weight, std=1e-3)
            if m.bias is not None:
                constant(m.bias, 0)


def _terminal_width(default=80):
    '''Return the terminal width in columns, or ``default`` when it cannot be determined.

    The previous ``stty size`` parsing raised ValueError at import time when
    no terminal was attached (CI, nohup, redirected output).
    '''
    try:
        from shutil import get_terminal_size  # Python 3.3+
    except ImportError:
        get_terminal_size = None
    if get_terminal_size is not None:
        return get_terminal_size((default, 24)).columns
    try:
        _, width = os.popen('stty size', 'r').read().split()
        return int(width)
    except ValueError:
        return default


term_width = _terminal_width()

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    '''Draw an in-place text progress bar (mimics xlua.progress).

    Args:
        current: zero-based index of the step just completed (0 resets the total timer).
        total: number of steps.
        msg: optional status text appended after the timing info.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    sys.stdout.write('=' * cur_len)
    sys.stdout.write('>')
    sys.stdout.write('.' * rest_len)
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # Pad to the terminal width (a negative count writes nothing).
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    '''Format a duration compactly using at most its two largest non-zero units (D, h, m, s, ms).'''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
# --------------------------------------------------------------------------------
# /gan/README.md:
# --------------------------------------------------------------------------------
# # mixup training of GAN
# This is a demo implementation of using _mixup_ in GAN training of two 2-d toy
# examples as shown in the [paper](https://arxiv.org/abs/1710.09412).
#
# ## Training
# Simply run
# ```
# python example_gan.py
# ```
#
# You will need PyTorch, `matplotlib`, and the `tqdm` package to run this script.
#
# ## Results
# It may take a few hours (about 5 hours on an NVIDIA GTX 1070) to run all the
# settings in this script. After the experiments finish, you should see a set of
# images similar to the ones shown here in the `images` folder, which are
# visualizations of the target distribution and generated samples during the
# training process.
# The `samples` folder contains the `(x, y)` values of the corresponding samples.
#
# ![](images/gan_results.png)
#
# --------------------------------------------------------------------------------
# /gan/example_gan.py:
# --------------------------------------------------------------------------------
'''mixup GAN training on 2-D toy distributions (ring and grid of Gaussians).

Sweeps target shape x latent size x mixup strength. At selected iterations it
saves real/fake samples to ./samples and scatter plots to ./images.
'''
import matplotlib
matplotlib.use('Agg')

from matplotlib import pyplot as plt
import torch
from torch.autograd import Variable, grad
from torch.nn.functional import binary_cross_entropy_with_logits as bce

from tqdm import tqdm

import os

# Create both output folders; plot() writes into ./images.
for _folder in ('./samples', './images'):
    if not os.path.exists(_folder):
        os.makedirs(_folder)

n_iterations = 20001
n_latent = 2  # overridden by the sweep below
n_layers = 3
n_hidden = 512
bs = 128
extraD = 5  # discriminator updates per generator update
# Fall back to CPU instead of crashing on machines without a GPU.
use_cuda = torch.cuda.is_available()


class Perceptron(torch.nn.Module):
    '''MLP over ``sizes`` with ReLU between layers and an optional ``final`` activation class.'''
    def __init__(self, sizes, final=None):
        super(Perceptron, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(torch.nn.Linear(sizes[i], sizes[i + 1]))
            if i != (len(sizes) - 2):
                layers.append(torch.nn.ReLU(inplace=True))
        if final is not None:
            layers.append(final())
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def plot(x, y, mixup, iteration, n_latent, shape):
    '''Scatter real points ``x`` and fake points ``y`` (numpy (n, 2) arrays) and save a PNG to ./images.'''
    lims = (x.min() - .25, x.max() + .25)
    plt.figure(figsize=(2, 2))
    plt.plot(x[:, 0], x[:, 1], '.', label='real')
    plt.plot(y[:, 0], y[:, 1], '.', alpha=0.25, label='fake')
    plt.axis('off')
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.gca().axes.get_yaxis().set_visible(False)
    plt.xlim(*lims)
    plt.ylim(*lims)
    # Keyword form: newer Matplotlib rejects positional tight_layout arguments.
    plt.tight_layout(pad=0, h_pad=0, w_pad=0)
    plt.savefig("images/example_z=%d_%s_%1.1f_%06d.png" %
                (n_latent, shape, mixup, iteration),
                bbox_inches='tight', pad_inches=0)
    plt.close()


def means_circle(k=8):
    '''Return k unit-circle points as a (k, 2) tensor of (sin t, cos t).'''
    p = 3.14159265359
    t = torch.linspace(0, 2 * p - (2 * p / k), k)
    m = torch.cat((torch.sin(t).view(-1, 1),
                   torch.cos(t).view(-1, 1)), 1)
    return m


def means_grid(k=25):
    '''Return a sqrt(k) x sqrt(k) grid of points in [-1, 1]^2 as a (k, 2) tensor.'''
    m = torch.zeros(k, 2)
    s = int(torch.sqrt(torch.Tensor([k]))[0] / 2)
    cnt = 0
    for i in range(- s, s + 1):
        for j in range(- s, s + 1):
            m[cnt][0] = i
            m[cnt][1] = j
            cnt += 1
    return m / s


def sample_real(n, shape, std=0.01):
    '''Draw n points from a mixture of Gaussians centred on the ring or grid means.'''
    if shape == 'ring':
        m = means_circle()
    else:
        m = means_grid()
    i = torch.zeros(n).random_(m.size(0)).long()
    s = torch.randn(n, 2) * std + m[i]
    s = Variable(s, requires_grad=True)
    if use_cuda:
        s = s.cuda()
    return s


def sample_noise(bs, d):
    '''Draw a (bs, d) standard-normal latent batch.'''
    z = torch.randn(bs, d)
    z = Variable(z, requires_grad=True)
    if use_cuda:
        z = z.cuda()
    return z


def mixup_batch(netG, shape, n_latent, mixup=0.0):
    '''Return a shuffled real/fake batch and labels (1 = real, 0 = fake).

    With ``mixup > 0`` two such batches are blended with per-sample weights
    drawn from U(0, mixup), applied to both inputs and labels.
    '''
    def one_batch():
        real = sample_real(bs, shape)
        fake = netG(sample_noise(bs, n_latent))
        data = torch.cat((real, fake))
        ones = Variable(torch.ones(real.size(0), 1))
        zeros = Variable(torch.zeros(fake.size(0), 1))
        perm = torch.randperm(data.size(0)).view(-1).long()
        if use_cuda:
            ones = ones.cuda()
            zeros = zeros.cuda()
            perm = perm.cuda()
        labels = torch.cat((ones, zeros))
        return data[perm], labels[perm]

    d1, l1 = one_batch()
    if mixup == 0:
        return d1, l1
    d2, l2 = one_batch()
    alpha = Variable(torch.randn(d1.size(0), 1).uniform_(0, mixup))
    if use_cuda:
        alpha = alpha.cuda()
    d = alpha * d1 + (1. - alpha) * d2
    l = alpha * l1 + (1. - alpha) * l2
    return d, l


for shape in ['ring', 'grid']:
    for n_latent in [2, 5, 10]:
        for mixup in [0, 0.1, 0.2, 0.5, 1]:
            netD = Perceptron([2] + [n_hidden] * n_layers + [1])
            netG = Perceptron([n_latent] + [n_hidden] * n_layers + [2])

            if use_cuda:
                netD.cuda()
                netG.cuda()

            optD = torch.optim.Adam(netD.parameters())
            optG = torch.optim.Adam(netG.parameters())

            # Fixed evaluation points reused for every snapshot of this run.
            p_real = sample_real(1000, shape)
            p_nois = sample_noise(1000, n_latent)

            for iteration in tqdm(range(n_iterations)):
                for extra in range(extraD):
                    data, labels = mixup_batch(netG, shape, n_latent, mixup)

                    optD.zero_grad()
                    lossD = bce(netD(data), labels)
                    lossD.backward()
                    optD.step()

                # The generator step always uses un-mixed data and maximizes D's loss.
                data, labels = mixup_batch(netG, shape, n_latent, 0)

                optG.zero_grad()
                lossG = - bce(netD(data), labels)
                lossG.backward()
                optG.step()

                if iteration in [10, 100, 1000, 10000, 20000]:
                    plot_real = p_real.cpu().data.numpy()
                    plot_fake = netG(p_nois).cpu().data.numpy()
                    torch.save((plot_real, plot_fake),
                               'samples/example_z=%d_%s_%1.1f_%06d.pt' %
                               (n_latent, shape, mixup, iteration))
                    plot(plot_real, plot_fake, mixup, iteration, n_latent, shape)
# --------------------------------------------------------------------------------
# /gan/images/gan_results.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/hongyi-zhang/mixup/80000cea340bf829a52481ae45a317a487ce2deb/gan/images/gan_results.png
# --------------------------------------------------------------------------------