├── models ├── __init__.py └── cifar │ ├── alexnet.py │ ├── __init__.py │ ├── wrn.py │ ├── vgg.py │ ├── densenet.py │ ├── preresnet.py │ ├── resnext.py │ └── resnet.py ├── utils ├── __init__.py ├── images │ ├── cifar.png │ └── imagenet.png ├── eval.py ├── log.py ├── misc.py └── cifar_loader.py ├── .gitmodules ├── .gitignore ├── run_cifar.sh ├── README.md └── cifar.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .eval import * 5 | 6 | -------------------------------------------------------------------------------- /utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidfoe/AdjustBnd4Imbalance/HEAD/utils/images/cifar.png -------------------------------------------------------------------------------- /utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feidfoe/AdjustBnd4Imbalance/HEAD/utils/images/imagenet.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "utils/progress"] 2 | path = utils/progress 3 | url = https://github.com/verigak/progress.git 4 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | 
batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import time 5 | import logging 6 | 7 | def save_option(option): 8 | option_path = os.path.join(option.save_dir, option.exp_name, "options.json") 9 | 10 | with open(option_path, 'w') as fp: 11 | json.dump(option.__dict__, fp, indent=4, sort_keys=True) 12 | 13 | def logger_setting(exp_name, save_dir): 14 | logger = logging.getLogger(exp_name) 15 | formatter = logging.Formatter('[%(name)s] %(levelname)s: %(message)s') 16 | 17 | log_out = os.path.join(save_dir, 'train.log') 18 | file_handler = logging.FileHandler(log_out) 19 | stream_handler = logging.StreamHandler() 20 | 21 | file_handler.setFormatter(formatter) 22 | stream_handler.setFormatter(formatter) 23 | 24 | logger.addHandler(file_handler) 25 | logger.addHandler(stream_handler) 26 | 27 | logger.setLevel(logging.INFO) 28 | return logger 29 | 30 | -------------------------------------------------------------------------------- /models/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 
2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = ['alexnet'] 9 | 10 | 11 | class AlexNet(nn.Module): 12 | 13 | def __init__(self, num_classes=10): 14 | super(AlexNet, self).__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(kernel_size=2, stride=2), 19 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2, stride=2), 29 | ) 30 | self.classifier = nn.Linear(256, num_classes) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | 39 | def alexnet(**kwargs): 40 | r"""AlexNet model architecture from the 41 | `"One weird trick..." `_ paper. 
42 | """ 43 | model = AlexNet(**kwargs) 44 | return model 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # tmp dirs and files 2 | checkpoint 3 | checkpoints 4 | data 5 | cifar-debug.py 6 | test.eps 7 | dev 8 | monitor.py 9 | exp 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | env/ 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *,cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # IPython Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | -------------------------------------------------------------------------------- /run_cifar.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUID=0 4 | dataset_path= 5 | checkpoint_path= 6 | 7 | # Both host paths are mounted with docker -v below; fail early with a clear message if either is unset. 8 | if [ -z "${dataset_path}" ] || [ -z "${checkpoint_path}" ]; then 9 | echo "Set dataset_path and checkpoint_path at the top of $0 before running it." >&2 10 | exit 1 11 | fi 12 | 13 | dataset=cifar10 # NOTE(review): only used to build ExpName; no dataset flag is passed to cifar.py below 14 | arch=resnet 15 | depth=32 16 | imb=100 17 | 18 | ExpName=${dataset}_${arch}${depth}_imb${imb} 19 | 20 | 21 | TRAIN=true 22 | EVAL=false 23 | 24 | 25 | 26 | if $TRAIN; then 27 | NV_GPU=${GPUID} nvidia-docker run -v `pwd`:`pwd` \ 28 | -v ${dataset_path}:`pwd`/data/ \ 29 | -v ${checkpoint_path}:`pwd`/checkpoints/ \ 30 | -w `pwd` \ 31 | --rm -it \ 32 | --ipc=host \ 33 | --name ${ExpName} \ 34 | feidfoe/pytorch:v.2 \ 35 | python cifar.py -a ${arch} \ 36 | --depth $depth \ 37 | --imbalance $imb \ 38 | --WVN \ 39 | --checkpoint checkpoints/${ExpName} 40 | fi 41 | 42 | 43 | if $EVAL; then 44 | CKPT=checkpoints/${ExpName}/checkpoint.pth.tar 45 | NV_GPU=${GPUID} nvidia-docker run -v `pwd`:`pwd` \ 46 | -v ${dataset_path}:`pwd`/data/ \ 47 | -v ${checkpoint_path}:`pwd`/checkpoints/ \ 48 | -w `pwd` \ 49 | --rm -it \ 50 | --ipc=host \ 51 | --name ${ExpName} \ 52 | feidfoe/pytorch:v.2 \ 53 | python cifar.py -a ${arch} \ 54 | --depth $depth \ 55 | --imbalance $imb \ 56 | --RS 0.5 \ 57 | --checkpoint checkpoints/${ExpName} \ 58 | --resume ${CKPT} \ 59 | --evaluate 60 | 61 | fi 62 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - init_params: net parameter initialization. 4 | - mkdir_p, AverageMeter: directory creation and running-average bookkeeping.
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count 
-------------------------------------------------------------------------------- /models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """The models subpackage contains definitions for the following CIFAR10/CIFAR100 model 4 | architectures: 5 | 6 | - `AlexNet`_ 7 | - `VGG`_ 8 | - `ResNet`_ 9 | - `SqueezeNet`_ 10 | - `DenseNet`_ 11 | 12 | You can construct a model with random weights by calling its constructor: 13 | 14 | .. code:: python 15 | 16 | import torchvision.models as models 17 | resnet18 = models.resnet18() 18 | alexnet = models.alexnet() 19 | squeezenet = models.squeezenet1_0() 20 | densenet = models.densenet_161() 21 | 22 | torchvision provides pre-trained models for the ResNet variants and AlexNet, using the 23 | PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing 24 | ``pretrained=True``: 25 | 26 | .. code:: python 27 | 28 | import torchvision.models as models 29 | resnet18 = models.resnet18(pretrained=True) 30 | alexnet = models.alexnet(pretrained=True) 31 | 32 | ImageNet 1-crop error rates (224x224) 33 | 34 | ======================== ============= ============= 35 | Network Top-1 error Top-5 error 36 | ======================== ============= ============= 37 | ResNet-18 30.24 10.92 38 | ResNet-34 26.70 8.58 39 | ResNet-50 23.85 7.13 40 | ResNet-101 22.63 6.44 41 | ResNet-152 21.69 5.94 42 | Inception v3 22.55 6.44 43 | AlexNet 43.45 20.91 44 | VGG-11 30.98 11.37 45 | VGG-13 30.07 10.75 46 | VGG-16 28.41 9.62 47 | VGG-19 27.62 9.12 48 | SqueezeNet 1.0 41.90 19.58 49 | SqueezeNet 1.1 41.81 19.38 50 | Densenet-121 25.35 7.83 51 | Densenet-169 24.00 7.00 52 | Densenet-201 22.80 6.43 53 | Densenet-161 22.35 6.20 54 | ======================== ============= ============= 55 | 56 | 57 | .. _AlexNet: https://arxiv.org/abs/1404.5997 58 | .. _VGG: https://arxiv.org/abs/1409.1556 59 | .. _ResNet: https://arxiv.org/abs/1512.03385 60 | ..
_SqueezeNet: https://arxiv.org/abs/1602.07360 61 | .. _DenseNet: https://arxiv.org/abs/1608.06993 62 | """ 63 | 64 | from .alexnet import * 65 | from .vgg import * 66 | from .resnet import * 67 | from .resnext import * 68 | from .wrn import * 69 | from .preresnet import * 70 | from .densenet import * 71 | -------------------------------------------------------------------------------- /utils/cifar_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | from PIL import Image 7 | import os 8 | 9 | import pickle 10 | 11 | import torch.utils.data as data 12 | import torchvision.datasets as datasets 13 | 14 | 15 | class CIFARLoader(Dataset): 16 | def __init__(self, root='./data', train=True, 17 | imbalance=1, transform=None): 18 | self.T = transform 19 | dataset = root.split('/')[-1] 20 | 21 | if train: 22 | if dataset == 'cifar10': 23 | data_list = ['data_batch_%d'%(i+1) for i in range(5)] 24 | elif dataset == 'cifar100': 25 | data_list = ['train'] 26 | else: 27 | if dataset == 'cifar10': 28 | data_list = ['test_batch'] 29 | elif dataset == 'cifar100': 30 | data_list = ['test'] 31 | 32 | 33 | 34 | self.data = [] 35 | self.label = [] 36 | for filename in data_list: 37 | filepath = os.path.join(os.path.join(root,filename)) 38 | with open(filepath, 'rb') as f: 39 | entry = pickle.load(f,encoding='latin1') 40 | self.data.append(entry['data']) 41 | if 'labels' in entry: 42 | self.label.extend(entry['labels']) 43 | else: 44 | self.label.extend(entry['fine_labels']) 45 | 46 | data = np.vstack(self.data).reshape(-1,3,32,32) 47 | data = data.transpose((0,2,3,1)) #NHWC 48 | labels = np.array(self.label) 49 | 50 | n_class = np.max(labels) + 1 51 | img_max = data.shape[0] // n_class 52 | imb_factor = 1. 
/ imbalance 53 | img_list, lbl_list = [], [] 54 | for i in range(n_class): 55 | idx = np.squeeze(np.argwhere(labels == i)) 56 | img = data[idx] 57 | lbl = labels[idx] 58 | num_sample = int(img_max * (imb_factor**(i/(n_class - 1)))) 59 | img_list.append(img[:num_sample]) 60 | lbl_list.append(lbl[:num_sample]) 61 | 62 | self.images = np.concatenate(img_list) 63 | self.labels = np.concatenate(lbl_list) 64 | 65 | 66 | def __getitem__(self,index): 67 | image = self.images[index] 68 | label = self.labels[index] 69 | 70 | return self.T(image), label.astype(np.long) 71 | 72 | 73 | def __len__(self): 74 | return self.images.shape[0] 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adjusting Decision Boundary for Class Imbalanced Learning 2 | This repository is the official PyTorch implementation of WVN-RS, introduced in [Adjusting Decision Boundary for Class Imbalanced Learning](https://ieeexplore.ieee.org/document/9081988). 3 | 4 | 5 | ### Requirements 6 | 1. NVIDIA docker : Docker image will be pulled from cloud. 7 | 2. CIFAR dataset : The "dataset_path" in run_cifar.sh should be 8 | ``` 9 | cifar10/ 10 | data_batch_N 11 | test_batch 12 | cifar100/ 13 | train 14 | test 15 | ``` 16 | CIFAR datasets are available [here](https://www.cs.toronto.edu/~kriz/cifar.html). 17 | 18 | ### How to use 19 | Run the shell script. 20 | ``` 21 | bash run_cifar.sh 22 | ``` 23 | To use Weight Vector Normalization (WVN), use --WVN flag. (It is already in the script.) 24 | 25 | ### Results 26 | 1. 
*Validation error* on Long-Tailed CIFAR10 27 | 28 | Imbalance|200|100|50|20|10|1 29 | :---:|:---:|:---:|:---:|:---:|:---:|:---: 30 | Baseline | 35.67 | 29.71 | 22.91 | 16.04 | 13.26 | 6.83 31 | Over-sample| 32.19 | 28.27 | 21.40 | 15.23 | 12.24 | 6.61 32 | [Focal](https://arxiv.org/abs/1708.02002) | 34.71 | 29.62 | 23.28 | 16.77 | 13.19 | 6.60 33 | [CB](https://arxiv.org/abs/1901.05555) | 31.11 | 25.43 | 20.73 | 15.64 | 12.51 | 6.36 34 | [LDAM-DRW](https://arxiv.org/abs/1906.07413) | 28.09 | 22.97 | 17.83 | 14.53 | *11.84* | 6.32 35 | Baseline+RS| **27.02** | *21.36* | *17.16* | *13.46* | 11.86 | *6.32* 36 | WVN+RS | *27.23* | **20.17** | **16.80** | **12.76** | **10.71** | **6.29** 37 | 38 | 39 | 2. *Validation error* on Long-Tailed CIFAR100 40 | 41 | Imbalance|200|100|50|20|10|1 42 | :---:|:---:|:---:|:---:|:---:|:---:|:---: 43 | Baseline | 64.21 | 60.38 | 55.09 | 48.93 | 43.52 | 29.69 44 | Over-sample| 66.39 | 61.53 | 56.65 | 49.03 | 43.38 | 29.41 45 | [Focal](https://arxiv.org/abs/1708.02002) | 64.38 | 61.31 | 55.68 | 48.05 | 44.22 | *28.52* 46 | [CB](https://arxiv.org/abs/1901.05555) | 63.77 | 60.40 | 54.68 | 47.41 | 42.01 | **28.39** 47 | [LDAM-DRW](https://arxiv.org/abs/1906.07413) | 61.73 | 57.96 | 52.54 | 47.14 | *41.29* | 28.85 48 | Baseline+RS| *59.59* | *55.65* | *51.91* | **45.09** | 41.45 | 29.80 49 | WVN+RS | **59.48** | **55.50** | **51.80** | *46.12* | **41.02** | 29.22 50 | 51 | 52 | 53 | 54 | ### Notes 55 | This code uses the Docker image "feidfoe/pytorch:v.2" with PyTorch version '0.4.0a0+0640816'. 56 | The image provides only basic libraries such as NumPy and PIL. 57 | 58 | WVN is implemented for the ResNet architecture only. 59 | 60 | 61 | 62 | #### Baseline repository 63 | This repository is forked and modified from [original repo](https://github.com/bearpaw/pytorch-classification).
64 | 65 | 66 | ### Contact 67 | [Byungju Kim](https://feidfoe.github.io/) (byungju.kim@kaist.ac.kr) 68 | 69 | 70 | ### BibTeX for Citation 71 | ``` 72 | @ARTICLE{9081988, 73 | author={B. {Kim} and J. {Kim}}, 74 | journal={IEEE Access}, 75 | title={Adjusting Decision Boundary for Class Imbalanced Learning}, 76 | year={2020}, 77 | volume={8}, 78 | number={}, 79 | pages={81674-81685},} 80 | ``` 81 | -------------------------------------------------------------------------------- /models/cifar/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['wrn'] 7 | 8 | class BasicBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 10 | super(BasicBlock, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.relu1 = nn.ReLU(inplace=True) 13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(out_planes) 16 | self.relu2 = nn.ReLU(inplace=True) 17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 18 | padding=1, bias=False) 19 | self.droprate = dropRate 20 | self.equalInOut = (in_planes == out_planes) 21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 22 | padding=0, bias=False) or None 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 36 | 
super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 39 | layers = [] 40 | for i in range(nb_layers): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | class WideResNet(nn.Module): 47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 48 | super(WideResNet, self).__init__() 49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 51 | n = (depth - 4) // 6 52 | block = BasicBlock 53 | # 1st conv before any network block 54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 55 | padding=1, bias=False) 56 | # 1st block 57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 58 | # 2nd block 59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 60 | # 3rd block 61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 62 | # global average pooling and classifier 63 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.fc = nn.Linear(nChannels[3], num_classes) 66 | self.nChannels = nChannels[3] 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 71 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.weight.data.fill_(1) 74 | m.bias.data.zero_() 75 | elif isinstance(m, nn.Linear): 76 | m.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.block1(out) 81 | out = self.block2(out) 82 | out = self.block3(out) 83 | out = self.relu(self.bn1(out)) 84 | out = F.avg_pool2d(out, 8) 85 | out = out.view(-1, self.nChannels) 86 | return self.fc(out) 87 | 88 | def wrn(**kwargs): 89 | """ 90 | Constructs a Wide Residual Networks. 91 | """ 92 | model = WideResNet(**kwargs) 93 | return model 94 | -------------------------------------------------------------------------------- /models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.classifier = nn.Linear(512, num_classes) 29 | self._initialize_weights() 30 | 31 | def forward(self, x): 32 | x = self.features(x) 33 | x = x.view(x.size(0), -1) 34 | x = self.classifier(x) 35 | return x 36 | 37 | def _initialize_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | elif isinstance(m, nn.BatchNorm2d): 45 | m.weight.data.fill_(1) 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.Linear): 48 | n = m.weight.size(1) 49 | m.weight.data.normal_(0, 0.01) 50 | m.bias.data.zero_() 51 | 52 | 53 | def make_layers(cfg, batch_norm=False): 54 | layers = [] 55 | in_channels = 3 56 | for v in cfg: 57 | if v == 'M': 58 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 59 | else: 60 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 61 | if batch_norm: 62 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 63 | else: 64 | layers += [conv2d, nn.ReLU(inplace=True)] 65 | in_channels = v 66 | return nn.Sequential(*layers) 67 | 68 | 69 | cfg = { 70 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 71 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 72 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 73 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 74 | } 75 | 76 | 77 | def vgg11(**kwargs): 78 | """VGG 11-layer model (configuration "A") 79 | 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | """ 83 | model = VGG(make_layers(cfg['A']), **kwargs) 84 | return model 85 | 86 | 87 | def vgg11_bn(**kwargs): 88 | """VGG 11-layer model (configuration "A") with batch normalization""" 89 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 90 | return model 91 | 92 | 93 | def vgg13(**kwargs): 94 | """VGG 13-layer model (configuration "B") 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | model = VGG(make_layers(cfg['B']), **kwargs) 100 | return model 101 | 102 | 103 | def vgg13_bn(**kwargs): 104 | """VGG 13-layer model (configuration "B") with batch normalization""" 105 | model = VGG(make_layers(cfg['B'], 
batch_norm=True), **kwargs) 106 | return model 107 | 108 | 109 | def vgg16(**kwargs): 110 | """VGG 16-layer model (configuration "D") 111 | 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = VGG(make_layers(cfg['D']), **kwargs) 116 | return model 117 | 118 | 119 | def vgg16_bn(**kwargs): 120 | """VGG 16-layer model (configuration "D") with batch normalization""" 121 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 122 | return model 123 | 124 | 125 | def vgg19(**kwargs): 126 | """VGG 19-layer model (configuration "E") 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | model = VGG(make_layers(cfg['E']), **kwargs) 132 | return model 133 | 134 | 135 | def vgg19_bn(**kwargs): 136 | """VGG 19-layer model (configuration 'E') with batch normalization""" 137 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 138 | return model 139 | -------------------------------------------------------------------------------- /models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | __all__ = ['densenet'] 8 | 9 | 10 | from torch.autograd import Variable 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 14 | super(Bottleneck, self).__init__() 15 | planes = expansion * growthRate 16 | self.bn1 = nn.BatchNorm2d(inplanes) 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 20 | padding=1, bias=False) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.dropRate = dropRate 23 | 24 | def forward(self, x): 25 | out = self.bn1(x) 26 | out = self.relu(out) 27 | out = self.conv1(out) 28 | out = self.bn2(out) 29 | out = 
self.relu(out) 30 | out = self.conv2(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 41 | super(BasicBlock, self).__init__() 42 | planes = expansion * growthRate 43 | self.bn1 = nn.BatchNorm2d(inplanes) 44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 45 | padding=1, bias=False) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.dropRate = dropRate 48 | 49 | def forward(self, x): 50 | out = self.bn1(x) 51 | out = self.relu(out) 52 | out = self.conv1(out) 53 | if self.dropRate > 0: 54 | out = F.dropout(out, p=self.dropRate, training=self.training) 55 | 56 | out = torch.cat((x, out), 1) 57 | 58 | return out 59 | 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, inplanes, outplanes): 63 | super(Transition, self).__init__() 64 | self.bn1 = nn.BatchNorm2d(inplanes) 65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 66 | bias=False) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, x): 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | out = F.avg_pool2d(out, 2) 74 | return out 75 | 76 | 77 | class DenseNet(nn.Module): 78 | 79 | def __init__(self, depth=22, block=Bottleneck, 80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2): 81 | super(DenseNet, self).__init__() 82 | 83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 84 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 85 | 86 | self.growthRate = growthRate 87 | self.dropRate = dropRate 88 | 89 | # self.inplanes is a global variable used across multiple 90 | # helper functions 91 | self.inplanes = growthRate * 2 92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 93 | bias=False) 94 | self.dense1 = self._make_denseblock(block, n) 95 | self.trans1 = 
self._make_transition(compressionRate) 96 | self.dense2 = self._make_denseblock(block, n) 97 | self.trans2 = self._make_transition(compressionRate) 98 | self.dense3 = self._make_denseblock(block, n) 99 | self.bn = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(self.inplanes, num_classes) 103 | 104 | # Weight initialization 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_denseblock(self, block, blocks): 114 | layers = [] 115 | for i in range(blocks): 116 | # Currently we fix the expansion ratio as the default value 117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 118 | self.inplanes += self.growthRate 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def _make_transition(self, compressionRate): 123 | inplanes = self.inplanes 124 | outplanes = int(math.floor(self.inplanes // compressionRate)) 125 | self.inplanes = outplanes 126 | return Transition(inplanes, outplanes) 127 | 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | 132 | x = self.trans1(self.dense1(x)) 133 | x = self.trans2(self.dense2(x)) 134 | x = self.dense3(x) 135 | x = self.bn(x) 136 | x = self.relu(x) 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | 145 | def densenet(**kwargs): 146 | """ 147 | Constructs a ResNet model. 148 | """ 149 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /models/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | __all__ = ['preresnet'] 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(inplanes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.bn1(x) 39 | out = self.relu(out) 40 | out = self.conv1(out) 41 | 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | out = self.conv2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.bn1 = nn.BatchNorm2d(inplanes) 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.bn1(x) 74 | out = self.relu(out) 75 | 
out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | 90 | return out 91 | 92 | 93 | class PreResNet(nn.Module): 94 | 95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 96 | super(PreResNet, self).__init__() 97 | # Model type specifies number of layers for CIFAR-10 model 98 | if block_name.lower() == 'basicblock': 99 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 100 | n = (depth - 2) // 6 101 | block = BasicBlock 102 | elif block_name.lower() == 'bottleneck': 103 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 104 | n = (depth - 2) // 9 105 | block = Bottleneck 106 | else: 107 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 108 | 109 | self.inplanes = 16 110 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 111 | bias=False) 112 | self.layer1 = self._make_layer(block, 16, n) 113 | self.layer2 = self._make_layer(block, 32, n, stride=2) 114 | self.layer3 = self._make_layer(block, 64, n, stride=2) 115 | self.bn = nn.BatchNorm2d(64 * block.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(8) 118 | self.fc = nn.Linear(64 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.Conv2d(self.inplanes, planes * block.expansion, 133 | kernel_size=1, stride=stride, bias=False), 134 | ) 135 | 136 | layers = [] 137 | layers.append(block(self.inplanes, planes, stride, downsample)) 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append(block(self.inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | x = self.conv1(x) 146 | 147 | x = self.layer1(x) # 32x32 148 | x = self.layer2(x) # 16x16 149 | x = self.layer3(x) # 8x8 150 | x = self.bn(x) 151 | x = self.relu(x) 152 | 153 | x = self.avgpool(x) 154 | x = x.view(x.size(0), -1) 155 | x = self.fc(x) 156 | 157 | return x 158 | 159 | 160 | def preresnet(**kwargs): 161 | """ 162 | Constructs a ResNet model. 163 | """ 164 | return PreResNet(**kwargs) 165 | -------------------------------------------------------------------------------- /models/cifar/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 
7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py 8 | """ 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | 13 | __all__ = ['resnext'] 14 | 15 | class ResNeXtBottleneck(nn.Module): 16 | """ 17 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 18 | """ 19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): 20 | """ Constructor 21 | Args: 22 | in_channels: input channel dimensionality 23 | out_channels: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | cardinality: num of convolution groups. 26 | widen_factor: factor to reduce the input dimensionality before convolution. 27 | """ 28 | super(ResNeXtBottleneck, self).__init__() 29 | D = cardinality * out_channels // widen_factor 30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn_reduce = nn.BatchNorm2d(D) 32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 33 | self.bn = nn.BatchNorm2d(D) 34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 35 | self.bn_expand = nn.BatchNorm2d(out_channels) 36 | 37 | self.shortcut = nn.Sequential() 38 | if in_channels != out_channels: 39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)) 40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels)) 41 | 42 | def forward(self, x): 43 | bottleneck = self.conv_reduce.forward(x) 44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) 45 | bottleneck = self.conv_conv.forward(bottleneck) 46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) 47 | bottleneck = self.conv_expand.forward(bottleneck) 48 | bottleneck = self.bn_expand.forward(bottleneck) 49 
| residual = self.shortcut.forward(x) 50 | return F.relu(residual + bottleneck, inplace=True) 51 | 52 | 53 | class CifarResNeXt(nn.Module): 54 | """ 55 | ResNext optimized for the Cifar dataset, as specified in 56 | https://arxiv.org/pdf/1611.05431.pdf 57 | """ 58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0): 59 | """ Constructor 60 | Args: 61 | cardinality: number of convolution groups. 62 | depth: number of layers. 63 | num_classes: number of classes 64 | widen_factor: factor to adjust the channel dimensionality 65 | """ 66 | super(CifarResNeXt, self).__init__() 67 | self.cardinality = cardinality 68 | self.depth = depth 69 | self.block_depth = (self.depth - 2) // 9 70 | self.widen_factor = widen_factor 71 | self.num_classes = num_classes 72 | self.output_size = 64 73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor] 74 | 75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 76 | self.bn_1 = nn.BatchNorm2d(64) 77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 80 | self.classifier = nn.Linear(1024, num_classes) 81 | init.kaiming_normal(self.classifier.weight) 82 | 83 | for key in self.state_dict(): 84 | if key.split('.')[-1] == 'weight': 85 | if 'conv' in key: 86 | init.kaiming_normal(self.state_dict()[key], mode='fan_out') 87 | if 'bn' in key: 88 | self.state_dict()[key][...] = 1 89 | elif key.split('.')[-1] == 'bias': 90 | self.state_dict()[key][...] = 0 91 | 92 | def block(self, name, in_channels, out_channels, pool_stride=2): 93 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 94 | Args: 95 | name: string name of the current block. 
96 | in_channels: number of input channels 97 | out_channels: number of output channels 98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 99 | Returns: a Module consisting of n sequential bottlenecks. 100 | """ 101 | block = nn.Sequential() 102 | for bottleneck in range(self.block_depth): 103 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 104 | if bottleneck == 0: 105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality, 106 | self.widen_factor)) 107 | else: 108 | block.add_module(name_, 109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor)) 110 | return block 111 | 112 | def forward(self, x): 113 | x = self.conv_1_3x3.forward(x) 114 | x = F.relu(self.bn_1.forward(x), inplace=True) 115 | x = self.stage_1.forward(x) 116 | x = self.stage_2.forward(x) 117 | x = self.stage_3.forward(x) 118 | x = F.avg_pool2d(x, 8, 1) 119 | x = x.view(-1, 1024) 120 | return self.classifier(x) 121 | 122 | def resnext(**kwargs): 123 | """Constructs a ResNeXt. 124 | """ 125 | model = CifarResNeXt(**kwargs) 126 | return model -------------------------------------------------------------------------------- /models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import math 13 | 14 | 15 | __all__ = ['resnet'] 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = 
self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock', WVN=False): 97 | super(ResNet, self).__init__() 98 | # Model type specifies number of layers for CIFAR-10 model 99 | if block_name.lower() == 'basicblock': 100 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 101 | n = (depth - 2) // 6 102 | block = BasicBlock 103 | elif block_name.lower() == 'bottleneck': 104 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 105 | n = (depth - 2) // 9 106 | block = Bottleneck 107 | else: 108 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 109 | 110 | 111 | self.inplanes = 16 112 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 113 | bias=False) 114 | self.bn1 = nn.BatchNorm2d(16) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.layer1 = self._make_layer(block, 16, n) 117 | self.layer2 = self._make_layer(block, 32, n, stride=2) 118 | self.layer3 = self._make_layer(block, 64, n, stride=2) 119 | self.avgpool = nn.AvgPool2d(8) 120 | self.fc = nn.Linear(64 * block.expansion, num_classes, bias=False) 121 | 122 | if WVN: 123 | self.fc.register_backward_hook(self.__WVN__) 124 | 125 | 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | 135 | def __WVN__(self, module, grad_input, grad_output): 136 | W = module.weight.data 137 | W_norm = W / torch.norm(W, p=2, dim=1, keepdim=True) 138 | 139 | module.weight.data.copy_(W_norm) 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def forward(self, x): 159 | x = self.conv1(x) 160 | x = self.bn1(x) 161 | x = self.relu(x) # 32x32 162 | 163 | x = self.layer1(x) # 32x32 164 | x = self.layer2(x) # 16x16 165 | x = self.layer3(x) # 8x8 166 | 167 | x = self.avgpool(x) 168 | x = x.view(x.size(0), -1) 169 | x = self.fc(x) 170 | 171 | return x 172 | 173 | 174 | def resnet(**kwargs): 175 | """ 176 | Constructs a ResNet model. 
177 | """ 178 | return ResNet(**kwargs) 179 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Training script for CIFAR-10/100 3 | Copyright (c) Wei YANG, 2017 4 | ''' 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import os 9 | import shutil 10 | import time 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | import torch.utils.data as data 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import models.cifar as models 22 | 23 | 24 | from utils import AverageMeter, accuracy, mkdir_p, cifar_loader 25 | import utils.log 26 | 27 | 28 | model_names = sorted(name for name in models.__dict__ 29 | if name.islower() and not name.startswith("__") 30 | and callable(models.__dict__[name])) 31 | 32 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') 33 | # Datasets 34 | parser.add_argument('-d', '--dataset', default='cifar10', type=str) 35 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 36 | help='number of data loading workers (default: 8)') 37 | parser.add_argument('--imbalance', default=1, type=int, 38 | help='imbalance factor for cifar dataset') 39 | # Optimization options 40 | parser.add_argument('--epochs', default=180, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('--train-batch', default=128, type=int, metavar='N', 45 | help='train batchsize') 46 | parser.add_argument('--test-batch', default=100, type=int, metavar='N', 47 | help='test batchsize') 48 | parser.add_argument('--lr', '--learning-rate', default=0.1, 
type=float, 49 | metavar='LR', help='initial learning rate') 50 | parser.add_argument('--drop', '--dropout', default=0, type=float, 51 | metavar='Dropout', help='Dropout ratio') 52 | parser.add_argument('--schedule', type=int, nargs='+', default=[80, 150], 53 | help='Decrease learning rate at these epochs.') 54 | parser.add_argument('--gamma', type=float, default=0.1, 55 | help='LR is multiplied by gamma on schedule.') 56 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 57 | help='momentum') 58 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 59 | metavar='W', help='weight decay (default: 1e-4)') 60 | # Checkpoints 61 | parser.add_argument('-c', '--checkpoint', default='checkpoints', 62 | type=str, metavar='PATH', 63 | help='path to save checkpoint (default: checkpoint)') 64 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 65 | help='path to latest checkpoint (default: none)') 66 | # Architecture 67 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20', 68 | choices=model_names, 69 | help='model architecture: ' + 70 | ' | '.join(model_names) + 71 | ' (default: resnet18)') 72 | parser.add_argument('--depth', type=int, default=29, help='Model depth.') 73 | parser.add_argument('--block-name', type=str, default='BasicBlock', 74 | help='the building block for Resnet and Preresnet') 75 | parser.add_argument('--cardinality', type=int, default=8, 76 | help='Model cardinality (group).') 77 | parser.add_argument('--widen-factor', type=int, default=4, 78 | help='Widen factor. 
4 -> 64, 8 -> 128, ...') 79 | parser.add_argument('--growthRate', type=int, default=12, 80 | help='Growth rate for DenseNet.') 81 | parser.add_argument('--compressionRate', type=int, default=2, 82 | help='Compression Rate (theta) for DenseNet.') 83 | parser.add_argument('--WVN', dest='wvn', action='store_true', 84 | help='whether to use WVN or not') 85 | parser.add_argument('--RS', default=0.1, type=float, 86 | help='gamma for weight re-scaling') 87 | # Miscs 88 | parser.add_argument('--manualSeed', type=int, help='manual seed') 89 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', 90 | help='evaluate model on validation set') 91 | #Device options 92 | parser.add_argument('--gpu-id', default='0', type=str, 93 | help='id(s) for CUDA_VISIBLE_DEVICES') 94 | 95 | args = parser.parse_args() 96 | state = {k: v for k, v in args._get_kwargs()} 97 | 98 | # Validate dataset 99 | assert args.dataset == 'cifar10' or args.dataset == 'cifar100', \ 100 | 'Dataset can only be cifar10 or cifar100.' 
# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy


def main():
    """Build data, model and optimizer from ``args``, then train or evaluate.

    With ``--evaluate`` the (optionally resumed) model is tested twice: as-is,
    and after dividing each row of ``module.fc.weight`` by
    ``(n_c / n_max) ** args.RS``, where ``n_c`` is the per-class sample count
    implied by ``--imbalance``. Otherwise the model is trained for
    ``args.epochs`` epochs, logging metrics and checkpointing every epoch.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    # normpath strips a trailing '/', which would otherwise yield an empty name.
    exp_name = os.path.basename(os.path.normpath(args.checkpoint))
    logger = utils.log.logger_setting(exp_name, args.checkpoint)
    print('Experiment Name : %s' % exp_name)
    log_prefix = 'Epoch:[%3d | %d] LR: %.4f, ' + \
                 'Loss(Tr): %.4f, Loss(Tt): %.4f, ' + \
                 'Acc(Tr): %.4f, Acc(Tt): %.4f'

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        num_classes = 10
    else:
        num_classes = 100

    data_path = os.path.join('./data', args.dataset)
    trainset = utils.cifar_loader.CIFARLoader(root=data_path,
                                              train=True,
                                              imbalance=args.imbalance,
                                              transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch,
                                  shuffle=True, num_workers=args.workers)

    testset = utils.cifar_loader.CIFARLoader(root=data_path,
                                             train=False,
                                             transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch,
                                 shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch == 'resnet':
        # Only ResNet takes the WVN flag; PreResNet's constructor has no such
        # parameter, so 'preresnet' must not reach this branch.
        if args.wvn:
            print("Use Weight Vector Normalization")
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
            WVN=args.wvn,
        )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resume
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        print(args.resume)
        assert os.path.isfile(args.resume), \
            'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion,
                                   start_epoch, use_cuda)
        print('[w/o RS] Test Loss: %.8f, Test Acc: %.2f%%' % (test_loss, test_acc))

        current_state = model.state_dict()
        W = current_state['module.fc.weight']

        # Per-class sample counts of the exponentially imbalanced training set;
        # float() keeps the exponent fractional under Python 2 as well.
        imb_factor = 1. / args.imbalance
        img_max = 50000 / float(num_classes)
        num_sample = [img_max * (imb_factor ** (i / float(num_classes - 1)))
                      for i in range(num_classes)]

        max_sample = max(num_sample)
        ns = [(float(n) / max_sample) ** args.RS for n in num_sample]
        # Column vector so each fc row (one class) is divided by its own factor.
        ns = torch.FloatTensor(ns).unsqueeze(-1).cuda()
        new_W = W / ns

        current_state['module.fc.weight'] = new_W
        model.load_state_dict(current_state)

        test_loss, test_acc = test(testloader, model, criterion,
                                   start_epoch, use_cuda)
        print('[w/ RS] Test Loss: %.8f, Test Acc: %.2f%%' % (test_loss, test_acc))

        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        train_loss, train_acc = train(trainloader, model,
                                      criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model,
                                   criterion, epoch, use_cuda)

        msg = log_prefix % (epoch + 1, args.epochs, state['lr'],
                            train_loss, test_loss,
                            train_acc / 100, test_acc / 100)
        logger.info(msg)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    print('Best acc:')
    print(best_acc)


def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch; return (average loss, average top-1 accuracy in %)."""
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    max_iter = len(trainloader)
    num_print = 13
    # Clamp to 1: with fewer than num_print batches the integer division is 0
    # and the modulo below would raise ZeroDivisionError.
    denom = max(max_iter // num_print, 1)
    per = 100.0 / num_print

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss (.item() works on 0-dim tensors)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not (batch_idx + 1) % denom:
            print('[%2dE.%3d%%] Train Loss: %.4f Elapsed Time: %.3f' %
                  (epoch, per * (batch_idx // denom + 1), loss.item(),
                   time.time() - end))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    return (losses.avg, top1.avg)


def test(testloader, model, criterion, epoch, use_cuda):
    """Evaluate without gradients; return (average loss, average top-1 accuracy in %)."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # no_grad replaces the removed Variable(volatile=True) mechanism.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return (losses.avg, top1.avg)


def save_checkpoint(state, is_best, checkpoint='checkpoints',
                    filename='checkpoint.pth.tar'):
    """Save ``state`` to ``checkpoint/filename``; copy it to model_best.pth.tar if best."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        dst = os.path.join(checkpoint, 'model_best.pth.tar')
        shutil.copyfile(filepath, dst)


def adjust_learning_rate(optimizer, epoch):
    """Multiply the LR by ``args.gamma`` at each epoch listed in ``args.schedule``."""
    # state is the module-level options dict; mutating it needs no `global`.
    if epoch in args.schedule:
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------