├── .gitignore ├── README.md ├── main.py ├── model.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-ResNet-CIFAR10 2 | 3 | This is a PyTorch implementation of Residual Networks as described in the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Microsoft Research Asia. It is designed for the CIFAR-10 image classification task, following the ResNet architecture described on page 7 of the paper. This version allows use of dropout, arbitrary value of n, and a custom residual projection option. 4 | 5 | #### Motivation 6 | 7 | I completed this project in order to gain a better understanding of residual connections, which inspire the design of many state-of-the-art convnets at the present moment, as well as the gradient degradation problem. Having my own custom implementation made it easier to experiment with dropout and custom projection methods, and gave me practice with PyTorch. 8 | 9 | ## Usage 10 | 11 | To train the network, use the following command: 12 | 13 | ```python main.py [-n=7] [--res-option='B'] [--use-dropout]``` 14 | 15 | ### Default Hyperparameters 16 | 17 | Hyperparameter | Default Value | Description 18 | | - | - | - | 19 | n | 5 | parameter controlling depth of network given structure described in paper 20 | `res_option` | A | projection method when number of residual channels increases 21 | `batch_size` | 128 | - 22 | `weight_decay` | 0.0001 | - 23 | `use_dropout` | False | - 24 | 25 | ## Results 26 | 27 | Using `n=9` with otherwise default hyperparameters, the network achieves a test accuracy of 91.69%. This is somewhat lower than the result reported in the paper, likely because I used fewer training iterations due to compute limitations. 28 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import sampler 9 | 10 | import torchvision.datasets as dset 11 | import torchvision.transforms as T 12 | 13 | from model import ResNet 14 | 15 | import numpy as np 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--data-dir', default='./dataset', type=str, 20 | help='path to dataset') 21 | parser.add_argument('--weight-decay', default=0.0001, type=float, 22 | help='parameter to decay weights') 23 | parser.add_argument('--batch-size', default=128, type=int, 24 | help='size of each batch of cifar-10 training images') 25 | parser.add_argument('--print-every', default=100, type=int, 26 | help='number of iterations to wait before printing') 27 | parser.add_argument('-n', default=5, type=int, 28 | help='value of n to use for resnet configuration (see https://arxiv.org/pdf/1512.03385.pdf for details)') 29 | parser.add_argument('--use-dropout', default=False, const=True, nargs='?', 30 | help='whether to use dropout in network') 31 | parser.add_argument('--res-option', default='A', type=str, 32 | help='which projection method to use for changing number of channels in residual connections') 33 | 34 | def main(args): 35 | # define transforms for normalization and data augmentation 36 | transform_augment = T.Compose([ 37 | T.RandomHorizontalFlip(), 38 | T.RandomCrop(32, padding=4)]) 39 | transform_normalize = T.Compose([ 40 | T.ToTensor(), 41 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 42 | ]) 43 | # get CIFAR-10 data 44 | NUM_TRAIN = 45000 45 | NUM_VAL = 5000 46 | cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, 47 | transform=T.Compose([transform_augment, transform_normalize])) 48 | loader_train = DataLoader(cifar10_train, batch_size=args.batch_size, 49 | sampler=ChunkSampler(NUM_TRAIN)) 50 | cifar10_val = dset.CIFAR10('./dataset', train=True, download=True, 51 | transform=transform_normalize) 52 | loader_val = DataLoader(cifar10_train, batch_size=args.batch_size, 53 | sampler=ChunkSampler(NUM_VAL, start=NUM_TRAIN)) 54 | cifar10_test = dset.CIFAR10('./dataset', train=False, download=True, 55 | transform=transform_normalize) 56 | loader_test = DataLoader(cifar10_test, batch_size=args.batch_size) 57 | 58 | # load model 59 | model = ResNet(args.n, res_option=args.res_option, use_dropout=args.use_dropout) 60 | 61 | param_count = get_param_count(model) 62 | print('Parameter count: %d' % param_count) 63 | 64 | # use gpu for training 65 | if not torch.cuda.is_available(): 66 | print('Error: CUDA library unavailable on system') 67 | return 68 | global gpu_dtype 69 | gpu_dtype = torch.cuda.FloatTensor 70 | model = model.type(gpu_dtype) 71 | 72 | # setup loss function 73 | criterion = nn.CrossEntropyLoss().cuda() 74 | # train model 75 | SCHEDULE_EPOCHS = [50, 5, 5] # divide lr by 10 after each number of epochs 76 | # SCHEDULE_EPOCHS = [100, 50, 50] # divide lr by 10 after each number of epochs 77 | learning_rate = 0.1 78 | for num_epochs in SCHEDULE_EPOCHS: 79 | print('Training for %d epochs with learning rate %f' % (num_epochs, learning_rate)) 80 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, 81 | momentum=0.9, weight_decay=args.weight_decay) 82 | for epoch in range(num_epochs): 83 | check_accuracy(model, loader_val) 84 | print('Starting epoch %d / %d' % (epoch+1, num_epochs)) 85 | train(loader_train, model, criterion, optimizer) 86 | learning_rate *= 0.1 87 | 88 | print('Final test accuracy:') 89 | check_accuracy(model, loader_test) 90 | 91 | def check_accuracy(model, loader): 92 | num_correct = 0 93 | num_samples = 0 94 | model.eval() 95 | for X, y in loader: 96 | X_var = Variable(X.type(gpu_dtype), volatile=True) 97 | 98 | scores = model(X_var) 99 | _, preds = scores.data.cpu().max(1) 100 | 101 | num_correct += (preds == y).sum() 102 | num_samples += preds.size(0) 103 | 104 | acc = float(num_correct) / num_samples 105 | print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc)) 106 | 107 | def train(loader_train, model, criterion, optimizer): 108 | model.train() 109 | for t, (X, y) in enumerate(loader_train): 110 | X_var = Variable(X.type(gpu_dtype)) 111 | y_var = Variable(y.type(gpu_dtype)).long() 112 | 113 | scores = model(X_var) 114 | 115 | loss = criterion(scores, y_var) 116 | if (t+1) % args.print_every == 0: 117 | print('t = %d, loss = %.4f' % (t+1, loss.data[0])) 118 | 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | def get_param_count(model): 124 | param_counts = [np.prod(p.size()) for p in model.parameters()] 125 | return sum(param_counts) 126 | 127 | class ChunkSampler(sampler.Sampler): 128 | def __init__(self, num_samples, start=0): 129 | self.num_samples = num_samples 130 | self.start = start 131 | 132 | def __iter__(self): 133 | return iter(range(self.start, self.start+self.num_samples)) 134 | 135 | def __len__(self): 136 | return self.num_samples 137 | 138 | if __name__ == '__main__': 139 | args = parser.parse_args() 140 | main(args) 141 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ResNet(nn.Module): 6 | 7 | def __init__(self, n=7, res_option='A', use_dropout=False): 8 | super(ResNet, self).__init__() 9 | self.res_option = res_option 10 | self.use_dropout = use_dropout 11 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) 12 | self.norm1 = nn.BatchNorm2d(16) 13 | self.relu1 = nn.ReLU(inplace=True) 14 | self.layers1 = self._make_layer(n, 16, 16, 1) 15 | self.layers2 = self._make_layer(n, 32, 16, 2) 16 | self.layers3 = self._make_layer(n, 64, 32, 2) 17 | self.avgpool = nn.AvgPool2d(8) 18 | self.linear = nn.Linear(64, 10) 19 | 20 | def _make_layer(self, layer_count, channels, channels_in, stride): 21 | return nn.Sequential( 22 | ResBlock(channels, channels_in, stride, res_option=self.res_option, use_dropout=self.use_dropout), 23 | *[ResBlock(channels) for _ in range(layer_count-1)]) 24 | 25 | def forward(self, x): 26 | out = self.conv1(x) 27 | out = self.norm1(out) 28 | out = self.relu1(out) 29 | out = self.layers1(out) 30 | out = self.layers2(out) 31 | out = self.layers3(out) 32 | out = self.avgpool(out) 33 | out = out.view(out.size(0), -1) 34 | out = self.linear(out) 35 | return out 36 | 37 | class ResBlock(nn.Module): 38 | 39 | def __init__(self, num_filters, channels_in=None, stride=1, res_option='A', use_dropout=False): 40 | super(ResBlock, self).__init__() 41 | 42 | # uses 1x1 convolutions for downsampling 43 | if not channels_in or channels_in == num_filters: 44 | channels_in = num_filters 45 | self.projection = None 46 | else: 47 | if res_option == 'A': 48 | self.projection = IdentityPadding(num_filters, channels_in, stride) 49 | elif res_option == 'B': 50 | self.projection = ConvProjection(num_filters, channels_in, stride) 51 | elif res_option == 'C': 52 | self.projection = AvgPoolPadding(num_filters, channels_in, stride) 53 | self.use_dropout = use_dropout 54 | 55 | self.conv1 = nn.Conv2d(channels_in, num_filters, kernel_size=3, stride=stride, padding=1) 56 | self.bn1 = nn.BatchNorm2d(num_filters) 57 | self.relu1 = nn.ReLU(inplace=True) 58 | self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1) 59 | self.bn2 = nn.BatchNorm2d(num_filters) 60 | if self.use_dropout: 61 | self.dropout = nn.Dropout(inplace=True) 62 | self.relu2 = nn.ReLU(inplace=True) 63 | 64 | def forward(self, x): 65 | residual = x 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu1(out) 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | if self.use_dropout: 72 | out = self.dropout(out) 73 | if self.projection: 74 | residual = self.projection(x) 75 | out += residual 76 | out = self.relu2(out) 77 | return out 78 | 79 | 80 | # various projection options to change number of filters in residual connection 81 | # option A from paper 82 | class IdentityPadding(nn.Module): 83 | def __init__(self, num_filters, channels_in, stride): 84 | super(IdentityPadding, self).__init__() 85 | # with kernel_size=1, max pooling is equivalent to identity mapping with stride 86 | self.identity = nn.MaxPool2d(1, stride=stride) 87 | self.num_zeros = num_filters - channels_in 88 | 89 | def forward(self, x): 90 | out = F.pad(x, (0, 0, 0, 0, 0, self.num_zeros)) 91 | out = self.identity(out) 92 | return out 93 | 94 | # option B from paper 95 | class ConvProjection(nn.Module): 96 | 97 | def __init__(self, num_filters, channels_in, stride): 98 | super(ResA, self).__init__() 99 | self.conv = nn.Conv2d(channels_in, num_filters, kernel_size=1, stride=stride) 100 | 101 | def forward(self, x): 102 | out = self.conv(x) 103 | return out 104 | 105 | # experimental option C 106 | class AvgPoolPadding(nn.Module): 107 | 108 | def __init__(self, num_filters, channels_in, stride): 109 | super(AvgPoolPadding, self).__init__() 110 | self.identity = nn.AvgPool2d(stride, stride=stride) 111 | self.num_zeros = num_filters - channels_in 112 | 113 | def forward(self, x): 114 | out = F.pad(x, (0, 0, 0, 0, 0, self.num_zeros)) 115 | out = self.identity(out) 116 | return out 117 | 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | --------------------------------------------------------------------------------