├── README.md ├── cifar.py ├── imagenet.py ├── imagenet_dataloader.py ├── models ├── __init__.py ├── __init__.pyc ├── cifar │ ├── __init__.py │ ├── __init__.pyc │ ├── mixnet.py │ └── mixnet.pyc └── imagenet │ ├── __init__.py │ └── mixnet.py ├── pretrained └── README.md └── utils ├── __init__.py ├── __init__.pyc ├── eval.py ├── eval.pyc ├── logger.py ├── logger.pyc ├── misc.py ├── misc.pyc ├── visualize.py └── visualize.pyc /README.md: -------------------------------------------------------------------------------- 1 | # Mixed Link Networks 2 | MixNet: [[Arxiv](https://arxiv.org/abs/1802.01808)] 3 | 4 | by Wenhai Wang, Xiang Li, Jian Yang, Tong Lu 5 | 6 | IMAGINE Lab@National Key Laboratory for Novel Software Technology, Nanjing University. 7 | DeepInsight@PCALab, Nanjing University of Science and Technology. 8 | 9 | ## Requirements 10 | * Install [PyTorch v0.2.0](http://pytorch.org/) 11 | * Clone recursively 12 | ``` 13 | git clone --recursive https://github.com/DeepInsight-PCALab/MixNet.git 14 | ``` 15 | * Download the ImageNet dataset and move validation images to labeled subfolders 16 | * To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 17 | 18 | ## Training 19 | ### CIFAR-10 20 | ``` 21 | CUDA_VISIBLE_DEVICES=0 python cifar.py --dataset cifar10 --depth 100 --k1 12 --k2 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/mixnet-100/ 22 | ``` 23 | 24 | ### ImageNet 25 | ``` 26 | CUDA_VISIBLE_DEVICES=0,1,2,3 python imagenet.py -d ../imagenet/ -j 4 --arch mixnet105 --train-batch 200 --checkpoint checkpoints/imagenet/mixnet-105/ 27 | ``` 28 | 29 | ## Testing on ImageNet 30 | ``` 31 | CUDA_VISIBLE_DEVICES=0 python imagenet.py -d ../imagenet/ -j 4 --arch mixnet105 --test-batch 20 --pretrained pretrained/mixnet105.pth.tar --evaluate 32 | ``` 33 | 34 | ## Results on CIFAR 35 | | Model | Parameters | CIFAR-10 | CIFAR-100 | 36 | | - | - | - | - | 37 | | MixNet-100 (k1 = 12, k2 = 12) | 1.5M | 4.19 | 21.12 | 38 | | MixNet-250 (k1 = 24, k2 = 24) | 29.0M | 3.32 | 17.06 | 39 | | MixNet-190 (k1 = 40, k2 = 40) | 48.5M | 3.13 | 16.96 | 40 | 41 | ## Results on ImageNet and Pretrained Models 42 | 43 | | Method | Parameters | Top-1 error | Pretrained model | 44 | | - | - | - | - | 45 | | MixNet-105 (k1 = 32, k2 = 32) | 11.16M | 23.3 | [baidu](https://pan.baidu.com/s/1q-LjwofEu2nM7feZClTA7w), [onedrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkigERtzmTEFTjN89) | 46 | | MixNet-121 (k1 = 40, k2 = 40) | 21.86M | 21.9 | [baidu](https://pan.baidu.com/s/1wIzkO0UVIXd_BPx_lmT7_w), [onedrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkiniBUJ50Stp4sRP) | 47 | | MixNet-141 (k1 = 48, k2 = 48) | 41.07M | 20.4 | [baidu](https://pan.baidu.com/s/1lYczUcAczhkQqpEwjZT66Q), [onedrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkioUUToxJ1m-VYR2) | 48 | 49 | ## Citation 50 | ``` 51 | @inproceedings{wang2018mixed, 52 | title={Mixed link networks}, 53 | author={Wang, Wenhai and Li, Xiang and Lu, Tong and Yang, Jian}, 54 | booktitle={Proceedings of the 27th International Joint Conference on Artificial Intelligence}, 55 | pages={2819--2825}, 56 | year={2018}, 57 | organization={AAAI Press} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim as optim 14 | import torch.utils.data as data 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import models.cifar as models 18 | 19 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 20 | 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | print(model_names) 27 | 28 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') 29 | # Datasets 30 | parser.add_argument('-d', '--dataset', default='cifar10', type=str) 31 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 32 | help='number of data loading workers (default: 4)') 33 | # Optimization options 34 | parser.add_argument('--epochs', default=300, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('--train-batch', default=64, type=int, metavar='N', 39 | help='train batchsize') 40 | parser.add_argument('--test-batch', default=100, type=int, metavar='N', 41 | help='test batchsize') 42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 43 | metavar='LR', help='initial learning rate') 44 | parser.add_argument('--drop', '--dropout', default=0, type=float, 45 | metavar='Dropout', help='Dropout ratio') 46 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], 47 | help='Decrease learning rate at these epochs.') 48 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)') 53 | # Checkpoints 54 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 55 | help='path to save checkpoint (default: checkpoint)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | # Architecture 59 | parser.add_argument('--arch', '-a', metavar='ARCH', default='mixnet', 60 | choices=model_names, 61 | help='model architecture: ' + 62 | ' | '.join(model_names) + 63 | ' (default: mixnet)') 64 | parser.add_argument('--depth', type=int, default=29, help='Model depth.') 65 | parser.add_argument('--k1', type=int, default=12, help='Inner link parameter.') 66 | parser.add_argument('--k2', type=int, default=12, help='Outer link parameter.') 67 | parser.add_argument('--compressionRate', type=int, default=2, help='Compression Rate.') 68 | # Miscs 69 | parser.add_argument('--manualSeed', type=int, help='manual seed') 70 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 71 | help='evaluate model on validation set') 72 | 73 | args = parser.parse_args() 74 | state = {k: v for k, v in args._get_kwargs()} 75 | 76 | # Validate dataset 77 | assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.' 78 | 79 | # Use CUDA 80 | use_cuda = torch.cuda.is_available() 81 | 82 | # Random seed 83 | if args.manualSeed is None: 84 | args.manualSeed = random.randint(1, 10000) 85 | random.seed(args.manualSeed) 86 | torch.manual_seed(args.manualSeed) 87 | if use_cuda: 88 | torch.cuda.manual_seed_all(args.manualSeed) 89 | 90 | best_acc = 0 # best test accuracy 91 | 92 | def main(): 93 | global best_acc 94 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 95 | 96 | if not os.path.isdir(args.checkpoint): 97 | mkdir_p(args.checkpoint) 98 | 99 | # Data 100 | print('==> Preparing dataset %s' % args.dataset) 101 | transform_train = transforms.Compose([ 102 | transforms.RandomCrop(32, padding=4), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor(), 105 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 106 | ]) 107 | 108 | transform_test = transforms.Compose([ 109 | transforms.ToTensor(), 110 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 111 | ]) 112 | if args.dataset == 'cifar10': 113 | dataloader = datasets.CIFAR10 114 | num_classes = 10 115 | else: 116 | dataloader = datasets.CIFAR100 117 | num_classes = 100 118 | 119 | 120 | trainset = dataloader(root='/home/shared/CIFAR/', train=True, download=True, transform=transform_train) 121 | trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers) 122 | 123 | testset = dataloader(root='/home/shared/CIFAR/', train=False, download=False, transform=transform_test) 124 | testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 125 | 126 | # Model 127 | print("==> creating model '{}'".format(args.arch)) 128 | 129 | model = models.__dict__[args.arch]( 130 | num_classes=num_classes, 131 | depth=args.depth, 132 | k1 = args.k1, 133 | k2=args.k2, 134 | compressionRate=args.compressionRate, 135 | dropRate=args.drop, 136 | ) 137 | 138 | model = torch.nn.DataParallel(model).cuda() 139 | cudnn.benchmark = True 140 | model_size = (sum(p.numel() for p in model.parameters())/1000000.0) 141 | print(' Total params: %.2fM' % model_size) 142 | 143 | criterion = nn.CrossEntropyLoss() 144 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 145 | 146 | # Resume 147 | title = 'cifar-10-' + args.arch 148 | if args.resume: 149 | # Load checkpoint. 150 | print('==> Resuming from checkpoint..') 151 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 152 | args.checkpoint = os.path.dirname(args.resume) 153 | checkpoint = torch.load(args.resume) 154 | best_acc = checkpoint['best_acc'] 155 | start_epoch = checkpoint['epoch'] 156 | model.load_state_dict(checkpoint['state_dict']) 157 | optimizer.load_state_dict(checkpoint['optimizer']) 158 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 159 | else: 160 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 161 | logger.set_names(['Total params.', 'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 162 | 163 | if args.evaluate: 164 | print('\nEvaluation only') 165 | test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda) 166 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) 167 | return 168 | 169 | # Train and val 170 | for epoch in range(start_epoch, args.epochs): 171 | adjust_learning_rate(optimizer, epoch) 172 | 173 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) 174 | 175 | train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda) 176 | test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda) 177 | 178 | # save model 179 | is_best = test_acc > best_acc 180 | best_acc = max(test_acc, best_acc) 181 | save_checkpoint({ 182 | 'epoch': epoch + 1, 183 | 'state_dict': model.state_dict(), 184 | 'acc': test_acc, 185 | 'best_acc': best_acc, 186 | 'optimizer' : optimizer.state_dict(), 187 | }, is_best, checkpoint=args.checkpoint) 188 | 189 | # append logger file 190 | logger.append([model_size, state['lr'], train_loss, test_loss, train_acc, test_acc]) 191 | 192 | logger.close() 193 | 194 | print('Best acc:') 195 | print(best_acc) 196 | 197 | def train(trainloader, model, criterion, optimizer, epoch, use_cuda): 198 | # switch to train mode 199 | model.train() 200 | 201 | batch_time = AverageMeter() 202 | data_time = AverageMeter() 203 | losses = AverageMeter() 204 | top1 = AverageMeter() 205 | top5 = AverageMeter() 206 | end = time.time() 207 | 208 | bar = Bar('Processing', max=len(trainloader)) 209 | for batch_idx, (inputs, targets) in enumerate(trainloader): 210 | # measure data loading time 211 | data_time.update(time.time() - end) 212 | 213 | if use_cuda: 214 | inputs, targets = inputs.cuda(), targets.cuda(async=True) 215 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 216 | 217 | # compute output 218 | outputs = model(inputs) 219 | loss = criterion(outputs, targets) 220 | 221 | # measure accuracy and record loss 222 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 223 | losses.update(loss.data[0], inputs.size(0)) 224 | top1.update(prec1[0], inputs.size(0)) 225 | top5.update(prec5[0], inputs.size(0)) 226 | 227 | # compute gradient and do SGD step 228 | optimizer.zero_grad() 229 | loss.backward() 230 | optimizer.step() 231 | 232 | # measure elapsed time 233 | batch_time.update(time.time() - end) 234 | end = time.time() 235 | 236 | # plot progress 237 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 238 | batch=batch_idx + 1, 239 | size=len(trainloader), 240 | data=data_time.avg, 241 | bt=batch_time.avg, 242 | total=bar.elapsed_td, 243 | eta=bar.eta_td, 244 | loss=losses.avg, 245 | top1=top1.avg, 246 | top5=top5.avg, 247 | ) 248 | bar.next() 249 | bar.finish() 250 | return (losses.avg, top1.avg) 251 | 252 | def test(testloader, model, criterion, epoch, use_cuda): 253 | global best_acc 254 | 255 | batch_time = AverageMeter() 256 | data_time = AverageMeter() 257 | losses = AverageMeter() 258 | top1 = AverageMeter() 259 | top5 = AverageMeter() 260 | 261 | # switch to evaluate mode 262 | model.eval() 263 | 264 | end = time.time() 265 | bar = Bar('Processing', max=len(testloader)) 266 | for batch_idx, (inputs, targets) in enumerate(testloader): 267 | # measure data loading time 268 | data_time.update(time.time() - end) 269 | 270 | if use_cuda: 271 | inputs, targets = inputs.cuda(), targets.cuda() 272 | inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) 273 | 274 | # compute output 275 | outputs = model(inputs) 276 | loss = criterion(outputs, targets) 277 | 278 | # measure accuracy and record loss 279 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 280 | losses.update(loss.data[0], inputs.size(0)) 281 | top1.update(prec1[0], inputs.size(0)) 282 | top5.update(prec5[0], inputs.size(0)) 283 | 284 | # measure elapsed time 285 | batch_time.update(time.time() - end) 286 | end = time.time() 287 | 288 | # plot progress 289 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 290 | batch=batch_idx + 1, 291 | size=len(testloader), 292 | data=data_time.avg, 293 | bt=batch_time.avg, 294 | total=bar.elapsed_td, 295 | eta=bar.eta_td, 296 | loss=losses.avg, 297 | top1=top1.avg, 298 | top5=top5.avg, 299 | ) 300 | bar.next() 301 | bar.finish() 302 | return (losses.avg, top1.avg) 303 | 304 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 305 | filepath = os.path.join(checkpoint, filename) 306 | torch.save(state, filepath) 307 | if is_best: 308 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 309 | 310 | def adjust_learning_rate(optimizer, epoch): 311 | global state 312 | if epoch in args.schedule: 313 | state['lr'] *= args.gamma 314 | for param_group in optimizer.param_groups: 315 | param_group['lr'] = state['lr'] 316 | 317 | if __name__ == '__main__': 318 | main() -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Training script for ImageNet 3 | Copyright (c) Wei YANG, 2017 4 | ''' 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import os 9 | import shutil 10 | import time 11 | import random 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | import torch.utils.data as data 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | # import torchvision.models as models 22 | import models.imagenet as customized_models 23 | import models.imagenet as models 24 | import imagenet_dataloader as dataloader 25 | 26 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 27 | 28 | # Models 29 | default_model_names = sorted(name for name in models.__dict__ 30 | if name.islower() and not name.startswith("__") 31 | and callable(models.__dict__[name])) 32 | 33 | customized_models_names = sorted(name for name in customized_models.__dict__ 34 | if name.islower() and not name.startswith("__") 35 | and callable(customized_models.__dict__[name])) 36 | 37 | for name in customized_models.__dict__: 38 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 39 | models.__dict__[name] = customized_models.__dict__[name] 40 | 41 | model_names = default_model_names + customized_models_names 42 | 43 | # Parse arguments 44 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 45 | 46 | # Datasets 47 | parser.add_argument('-d', '--data', default='path to dataset', type=str) 48 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 49 | help='number of data loading workers (default: 4)') 50 | # Optimization options 51 | parser.add_argument('--epochs', default=110, type=int, metavar='N', 52 | help='number of total epochs to run') 53 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 54 | help='manual epoch number (useful on restarts)') 55 | parser.add_argument('--train-batch', default=100, type=int, metavar='N', 56 | help='train batchsize (default: 256)') 57 | parser.add_argument('--test-batch', default=200, type=int, metavar='N', 58 | help='test batchsize (default: 200)') 59 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 60 | metavar='LR', help='initial learning rate') 61 | parser.add_argument('--drop', '--dropout', default=0, type=float, 62 | metavar='Dropout', help='Dropout ratio') 63 | parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90, 100], 64 | help='Decrease learning rate at these epochs.') 65 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 66 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 67 | help='momentum') 68 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 69 | metavar='W', help='weight decay (default: 1e-4)') 70 | # Checkpoints 71 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 72 | help='path to save checkpoint (default: checkpoint)') 73 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 74 | help='path to latest checkpoint (default: none)') 75 | # Architecture 76 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 77 | choices=model_names, 78 | help='model architecture: ' + 79 | ' | '.join(model_names) + 80 | ' (default: resnet18)') 81 | parser.add_argument('--depth', type=int, default=29, help='Model depth.') 82 | parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') 83 | parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') 84 | parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') 85 | # Miscs 86 | parser.add_argument('--manualSeed', type=int, help='manual seed') 87 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 88 | help='evaluate model on validation set') 89 | parser.add_argument('--pretrained', dest='pretrained', default=None, type=str, metavar='PATH', 90 | help='use pre-trained model') 91 | 92 | args = parser.parse_args() 93 | state = {k: v for k, v in args._get_kwargs()} 94 | 95 | # Use CUDA 96 | use_cuda = torch.cuda.is_available() 97 | 98 | # Random seed 99 | if args.manualSeed is None: 100 | args.manualSeed = random.randint(1, 10000) 101 | random.seed(args.manualSeed) 102 | torch.manual_seed(args.manualSeed) 103 | if use_cuda: 104 | torch.cuda.manual_seed_all(args.manualSeed) 105 | 106 | best_acc = 0 # best test accuracy 107 | 108 | def main(): 109 | print('schedule:', args.schedule) 110 | print('batch size:', args.train_batch) 111 | global best_acc 112 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 113 | 114 | if not os.path.isdir(args.checkpoint): 115 | mkdir_p(args.checkpoint) 116 | 117 | # Data loading code 118 | traindir = os.path.join(args.data, 'train') 119 | valdir = os.path.join(args.data, 'val') 120 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 121 | std=[0.229, 0.224, 0.225]) 122 | 123 | train_loader = torch.utils.data.DataLoader( 124 | dataloader.ImageFolder(traindir, transforms.Compose([ 125 | transforms.RandomSizedCrop(224), 126 | transforms.RandomHorizontalFlip(), 127 | transforms.ToTensor(), 128 | normalize, 129 | ])), 130 | batch_size=args.train_batch, shuffle=True, 131 | num_workers=args.workers, pin_memory=True) 132 | 133 | val_loader = torch.utils.data.DataLoader( 134 | dataloader.ImageFolder(valdir, transforms.Compose([ 135 | transforms.Scale(256), 136 | transforms.CenterCrop(224), 137 | transforms.ToTensor(), 138 | normalize, 139 | ])), 140 | batch_size=args.test_batch, shuffle=False, 141 | num_workers=args.workers, pin_memory=True) 142 | 143 | # create model 144 | print("=> creating model '{}'".format(args.arch)) 145 | model = models.__dict__[args.arch]() 146 | model = torch.nn.DataParallel(model).cuda() 147 | 148 | if args.pretrained != None: 149 | print("=> using pre-trained model '{}'".format(args.pretrained)) 150 | checkpoint = torch.load(args.pretrained) 151 | model.load_state_dict(checkpoint['state_dict']) 152 | 153 | cudnn.benchmark = True 154 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 155 | 156 | # define loss function (criterion) and optimizer 157 | criterion = nn.CrossEntropyLoss().cuda() 158 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 159 | 160 | # Resume 161 | title = 'ImageNet-' + args.arch 162 | if args.resume: 163 | # Load checkpoint. 164 | print('==> Resuming from checkpoint..') 165 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 166 | args.checkpoint = os.path.dirname(args.resume) 167 | checkpoint = torch.load(args.resume) 168 | best_acc = checkpoint['best_acc'] 169 | start_epoch = checkpoint['epoch'] 170 | optimizer.load_state_dict(checkpoint['optimizer']) 171 | model.load_state_dict(checkpoint['state_dict']) 172 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 173 | else: 174 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 175 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 176 | 177 | 178 | if args.evaluate: 179 | print('\nEvaluation only') 180 | test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) 181 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) 182 | return 183 | 184 | # Train and val 185 | for epoch in range(start_epoch, args.epochs): 186 | adjust_learning_rate(optimizer, epoch) 187 | 188 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) 189 | 190 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda) 191 | test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) 192 | 193 | # append logger file 194 | logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) 195 | 196 | # save model 197 | is_best = test_acc > best_acc 198 | best_acc = max(test_acc, best_acc) 199 | save_checkpoint({ 200 | 'epoch': epoch + 1, 201 | 'state_dict': model.state_dict(), 202 | 'acc': test_acc, 203 | 'best_acc': best_acc, 204 | 'optimizer' : optimizer.state_dict(), 205 | }, is_best, checkpoint=args.checkpoint, filename='checkpoint' + '.pth.tar') 206 | 207 | logger.close() 208 | logger.plot() 209 | savefig(os.path.join(args.checkpoint, 'log.eps')) 210 | 211 | print('Best acc:') 212 | print(best_acc) 213 | 214 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda): 215 | # switch to train mode 216 | model.train() 217 | 218 | batch_time = AverageMeter() 219 | data_time = AverageMeter() 220 | losses = AverageMeter() 221 | top1 = AverageMeter() 222 | top5 = AverageMeter() 223 | end = time.time() 224 | 225 | bar = Bar('Processing', max=len(train_loader)) 226 | for batch_idx, (inputs, targets) in enumerate(train_loader): 227 | # measure data loading time 228 | data_time.update(time.time() - end) 229 | 230 | if use_cuda: 231 | inputs, targets = inputs.cuda(), targets.cuda(async=True) 232 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 233 | 234 | # compute output 235 | outputs = model(inputs) 236 | loss = criterion(outputs, targets) 237 | 238 | # measure accuracy and record loss 239 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 240 | losses.update(loss.data[0], inputs.size(0)) 241 | top1.update(prec1[0], inputs.size(0)) 242 | top5.update(prec5[0], inputs.size(0)) 243 | 244 | # compute gradient and do SGD step 245 | optimizer.zero_grad() 246 | loss.backward() 247 | optimizer.step() 248 | 249 | # measure elapsed time 250 | batch_time.update(time.time() - end) 251 | end = time.time() 252 | 253 | # plot progress 254 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 255 | batch=batch_idx + 1, 256 | size=len(train_loader), 257 | data=data_time.val, 258 | bt=batch_time.val, 259 | total=bar.elapsed_td, 260 | eta=bar.eta_td, 261 | loss=losses.avg, 262 | top1=top1.avg, 263 | top5=top5.avg, 264 | ) 265 | bar.next() 266 | bar.finish() 267 | return (losses.avg, top1.avg) 268 | 269 | def test(val_loader, model, criterion, epoch, use_cuda): 270 | global best_acc 271 | 272 | batch_time = AverageMeter() 273 | data_time = AverageMeter() 274 | losses = AverageMeter() 275 | top1 = AverageMeter() 276 | top5 = AverageMeter() 277 | 278 | # switch to evaluate mode 279 | model.eval() 280 | 281 | end = time.time() 282 | bar = Bar('Processing', max=len(val_loader)) 283 | for batch_idx, (inputs, targets) in enumerate(val_loader): 284 | # measure data loading time 285 | data_time.update(time.time() - end) 286 | 287 | if use_cuda: 288 | inputs, targets = inputs.cuda(), targets.cuda() 289 | inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) 290 | 291 | # compute output 292 | outputs = model(inputs) 293 | loss = criterion(outputs, targets) 294 | 295 | # measure accuracy and record loss 296 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 297 | losses.update(loss.data[0], inputs.size(0)) 298 | top1.update(prec1[0], inputs.size(0)) 299 | top5.update(prec5[0], inputs.size(0)) 300 | 301 | # measure elapsed time 302 | batch_time.update(time.time() - end) 303 | end = time.time() 304 | 305 | # plot progress 306 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 307 | batch=batch_idx + 1, 308 | size=len(val_loader), 309 | data=data_time.avg, 310 | bt=batch_time.avg, 311 | total=bar.elapsed_td, 312 | eta=bar.eta_td, 313 | loss=losses.avg, 314 | top1=top1.avg, 315 | top5=top5.avg, 316 | ) 317 | bar.next() 318 | bar.finish() 319 | return (losses.avg, top1.avg) 320 | 321 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 322 | print (checkpoint, filename) 323 | filepath = os.path.join(checkpoint, filename) 324 | torch.save(state, filepath) 325 | if is_best: 326 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 327 | 328 | def adjust_learning_rate(optimizer, epoch): 329 | global state 330 | print('epoch:', epoch) 331 | if epoch in args.schedule: 332 | state['lr'] *= args.gamma 333 | print('lr:', state['lr']) 334 | for param_group in optimizer.param_groups: 335 | param_group['lr'] = state['lr'] 336 | 337 | if __name__ == '__main__': 338 | main() 339 | -------------------------------------------------------------------------------- /imagenet_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import cv2 7 | 8 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 9 | 10 | 11 | def is_image_file(filename): 12 | """Checks if a file is an image. 13 | 14 | Args: 15 | filename (string): path to a file 16 | 17 | Returns: 18 | bool: True if the filename ends with a known image extension 19 | """ 20 | filename_lower = filename.lower() 21 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 22 | 23 | 24 | def find_classes(dir): 25 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 26 | classes.sort() 27 | class_to_idx = {classes[i]: i for i in range(len(classes))} 28 | return classes, class_to_idx 29 | 30 | 31 | def make_dataset(dir, class_to_idx): 32 | images = [] 33 | dir = os.path.expanduser(dir) 34 | for target in sorted(os.listdir(dir)): 35 | d = os.path.join(dir, target) 36 | if not os.path.isdir(d): 37 | continue 38 | 39 | for root, _, fnames in sorted(os.walk(d)): 40 | for fname in sorted(fnames): 41 | if is_image_file(fname): 42 | path = os.path.join(root, fname) 43 | item = (path, class_to_idx[target]) 44 | images.append(item) 45 | 46 | return images 47 | 48 | 49 | def pil_loader(path): 50 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 51 | # with open(path, 'rb') as f: 52 | # img = Image.open(f) 53 | # return img.convert('RGB') 54 | img = cv2.imread(path) 55 | img = img[:, :, [2, 1, 0]] 56 | img = Image.fromarray(img) 57 | return img.convert('RGB') 58 | 59 | 60 | def accimage_loader(path): 61 | import accimage 62 | try: 63 | return accimage.Image(path) 64 | except IOError: 65 | # Potentially a decoding problem, fall back to PIL.Image 66 | return pil_loader(path) 67 | 68 | 69 | def default_loader(path): 70 | from torchvision import get_image_backend 71 | if get_image_backend() == 'accimage': 72 | return accimage_loader(path) 73 | else: 74 | return pil_loader(path) 75 | 76 | 77 | class ImageFolder(data.Dataset): 78 | """A generic data loader where the images are arranged in this way: :: 79 | 80 | root/dog/xxx.png 81 | root/dog/xxy.png 82 | root/dog/xxz.png 83 | 84 | root/cat/123.png 85 | root/cat/nsdf3.png 86 | root/cat/asd932_.png 87 | 88 | Args: 89 | root (string): Root directory path. 90 | transform (callable, optional): A function/transform that takes in an PIL image 91 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 92 | target_transform (callable, optional): A function/transform that takes in the 93 | target and transforms it. 94 | loader (callable, optional): A function to load an image given its path. 95 | 96 | Attributes: 97 | classes (list): List of the class names. 98 | class_to_idx (dict): Dict with items (class_name, class_index). 99 | imgs (list): List of (image path, class_index) tuples 100 | """ 101 | 102 | def __init__(self, root, transform=None, target_transform=None, 103 | loader=default_loader): 104 | classes, class_to_idx = find_classes(root) 105 | imgs = make_dataset(root, class_to_idx) 106 | if len(imgs) == 0: 107 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 108 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 109 | 110 | self.root = root 111 | self.imgs = imgs 112 | self.classes = classes 113 | self.class_to_idx = class_to_idx 114 | self.transform = transform 115 | self.target_transform = target_transform 116 | self.loader = loader 117 | 118 | def __getitem__(self, index): 119 | """ 120 | Args: 121 | index (int): Index 122 | 123 | Returns: 124 | tuple: (image, target) where target is class_index of the target class. 125 | """ 126 | path, target = self.imgs[index] 127 | img = self.loader(path) 128 | if self.transform is not None: 129 | img = self.transform(img) 130 | if self.target_transform is not None: 131 | target = self.target_transform(target) 132 | 133 | return img, target 134 | 135 | def __len__(self): 136 | return len(self.imgs) 137 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/models/__init__.py -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/models/__init__.pyc -------------------------------------------------------------------------------- /models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from mixnet import * -------------------------------------------------------------------------------- /models/cifar/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/models/cifar/__init__.pyc -------------------------------------------------------------------------------- /models/cifar/mixnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | __all__ = ['mixnet'] 8 | 9 | 10 | from torch.autograd import Variable 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, inplanes, expansion=4, k1=12, k2=12, dropRate=0): 14 | super(Bottleneck, self).__init__() 15 | # inner link module 16 | if k1 > 0: 17 | planes = expansion * k1 18 | self.bn1_1 = nn.BatchNorm2d(inplanes) 19 | self.conv1_1 = nn.Conv2d(inplanes, planes, kernel_size = 1, bias = False) 20 | self.bn1_2 = nn.BatchNorm2d(planes) 21 | self.conv1_2 = nn.Conv2d(planes, k1, kernel_size = 3, padding = 1, bias = False) 22 | 23 | # outer link module 24 | if k2 > 0: 25 | planes = expansion * k2 26 | self.bn2_1 = nn.BatchNorm2d(inplanes) 27 | self.conv2_1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 28 | self.bn2_2 = nn.BatchNorm2d(planes) 29 | self.conv2_2 = nn.Conv2d(planes, k2, kernel_size=3, padding=1, bias=False) 30 | 31 | self.dropRate = dropRate 32 | self.relu = nn.ReLU(inplace=True) 33 | self.k1 = k1 34 | self.k2 = k2 35 | 36 | def forward(self, x): 37 | if self.k1 > 0: 38 | inner_link = self.bn1_1(x) 39 | inner_link = self.relu(inner_link) 40 | inner_link = self.conv1_1(inner_link) 41 | inner_link = self.bn1_2(inner_link) 42 | inner_link = self.relu(inner_link) 43 | inner_link = self.conv1_2(inner_link) 44 | 45 | if self.k2 > 0: 46 | outer_link = self.bn2_1(x) 47 | outer_link = self.relu(outer_link) 48 | outer_link = self.conv2_1(outer_link) 49 | outer_link = self.bn2_2(outer_link) 50 | outer_link = self.relu(outer_link) 51 | outer_link = self.conv2_2(outer_link) 52 | 53 | if self.dropRate > 0: 54 | inner_link = F.dropout(inner_link, p=self.dropRate, training=self.training) 55 | outer_link = F.dropout(outer_link, p=self.dropRate, training=self.training) 56 | 57 | 58 | c = x.size(1) 59 | if self.k1 > 0 and self.k1 < c: 60 | xl = x[:, 0: c - self.k1, :, :] 61 | xr = x[:, c - self.k1: c, :, :] + inner_link 62 | x = torch.cat((xl, xr), 1) 63 | elif self.k1 == c: 64 | x = x + inner_link 65 | 66 | if self.k2 > 0: 67 | out = torch.cat((x, outer_link), 1) 68 | else: 69 | out = x 70 | 71 | return out 72 | 73 | class Transition(nn.Module): 74 | def __init__(self, inplanes, outplanes): 75 | super(Transition, self).__init__() 76 | self.bn = nn.BatchNorm2d(inplanes) 77 | self.conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, 78 | bias=False) 79 | self.relu = nn.ReLU(inplace=True) 80 | 81 | def forward(self, x): 82 | out = self.bn(x) 83 | out = self.relu(out) 84 | out = self.conv(out) 85 | out = F.avg_pool2d(out, 2) 86 | return out 87 | 88 | class MixNet(nn.Module): 89 | 90 | def __init__(self, 91 | depth=100, 92 | unit=Bottleneck, 93 | dropRate=0, 94 | num_classes=10, 95 | k1=12, 96 | k2=12, 97 | compressionRate=2): 98 | super(MixNet, self).__init__() 99 | 100 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 101 | n = (depth - 4) // 6 102 | 103 | self.k2 = k2 104 | self.k1 = k1 105 | self.dropRate = dropRate 106 | 107 | self.inplanes = max(k2 * 2, k1) 108 | 109 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 110 | bias=False) 111 | self.block1 = self._make_block(unit, n) 112 | self.trans1 = self._make_transition(compressionRate) 113 | self.block2 = self._make_block(unit, n) 114 | self.trans2 = self._make_transition(compressionRate) 115 | self.block3 = self._make_block(unit, n) 116 | self.bn = nn.BatchNorm2d(self.inplanes) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(8) 119 | self.fc = nn.Linear(self.inplanes, num_classes) 120 | 121 | # Weight initialization 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 125 | m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | 130 | def _make_block(self, unit, unit_num): 131 | layers = [] 132 | for i in range(unit_num): 133 | # Currently we fix the expansion ratio as the default value 134 | layers.append(unit(self.inplanes, k1=self.k1, k2=self.k2, dropRate=self.dropRate)) 135 | self.inplanes += self.k2 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def _make_transition(self, compressionRate): 140 | inplanes = self.inplanes 141 | outplanes = max(int(math.floor(self.inplanes // compressionRate)), self.k1) 142 | self.inplanes = outplanes 143 | return Transition(inplanes, outplanes) 144 | 145 | 146 | def forward(self, x): 147 | x = self.conv1(x) 148 | x = self.block1(x) 149 | x = self.trans1(x) 150 | x = self.block2(x) 151 | x = self.trans2(x) 152 | x = self.block3(x) 153 | x = self.bn(x) 154 | x = self.relu(x) 155 | x = self.avgpool(x) 156 | x = x.view(x.size(0), -1) 157 | x = self.fc(x) 158 | 159 | return x 160 | 161 | 162 | def mixnet(**kwargs): 163 | """ 164 | Constructs a ResNet model. 165 | """ 166 | return MixNet(**kwargs) 167 | -------------------------------------------------------------------------------- /models/cifar/mixnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/models/cifar/mixnet.pyc -------------------------------------------------------------------------------- /models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from mixnet import * -------------------------------------------------------------------------------- /models/imagenet/mixnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | from collections import OrderedDict 6 | 7 | __all__ = ['MixNet', 'mixnet105', 'mixnet121', 'mixnet141'] 8 | 9 | 10 | def mixnet105(**kwargs): 11 | model = MixNet(num_init_features=64, k1=32, k2=32, block_config=(6, 12, 20, 12), **kwargs) 12 | return model 13 | 14 | def mixnet121(**kwargs): 15 | model = MixNet(num_init_features=80, k1=40, k2=40, block_config=(6, 12, 24, 16), **kwargs) 16 | return model 17 | 18 | def mixnet141(**kwargs): 19 | model = MixNet(num_init_features=96, k1=48, k2=48, block_config=(6, 12, 30, 20), **kwargs) 20 | return model 21 | 22 | class _MixLayer(nn.Sequential): 23 | def __init__(self, num_input_features, expansion, k1, k2, drop_rate): 24 | super(_MixLayer, self).__init__() 25 | if k1 > 0: 26 | self.bn1_1 = nn.BatchNorm2d(num_input_features) 27 | self.conv1_1 = nn.Conv2d(num_input_features, expansion * k1, kernel_size = 1, stride=1, bias = False) 28 | self.bn1_2 = nn.BatchNorm2d(expansion * k1) 29 | self.conv1_2 = nn.Conv2d(expansion * k1, k1, kernel_size = 3, stride=1, padding = 1, bias = False) 30 | 31 | if k2 > 0: 32 | self.bn2_1 = nn.BatchNorm2d(num_input_features) 33 | self.conv2_1 = nn.Conv2d(num_input_features, expansion * k2, kernel_size = 1, stride=1, bias = False) 34 | self.bn2_2 = nn.BatchNorm2d(expansion * k2) 35 | self.conv2_2 = nn.Conv2d(expansion * k2, k2, kernel_size = 3, stride=1, padding = 1, bias = False) 36 | 37 | self.drop_rate = drop_rate 38 | self.relu = nn.ReLU(inplace=True) 39 | self.k1 = k1 40 | self.k2 = k2 41 | 42 | def forward(self, x): 43 | if self.k1 > 0: 44 | inner_link = self.bn1_1(x) 45 | inner_link = self.relu(inner_link) 46 | inner_link = self.conv1_1(inner_link) 47 | inner_link = self.bn1_2(inner_link) 48 | inner_link = self.relu(inner_link) 49 | inner_link = self.conv1_2(inner_link) 50 | 51 | if self.k2 > 0: 52 | outer_link = self.bn2_1(x) 53 | outer_link = self.relu(outer_link) 54 | outer_link = self.conv2_1(outer_link) 55 | outer_link = self.bn2_2(outer_link) 56 | outer_link = self.relu(outer_link) 57 | outer_link = self.conv2_2(outer_link) 58 | 59 | if self.drop_rate > 0: 60 | inner_link = F.dropout(inner_link, p=self.drop_rate, training=self.training) 61 | outer_link = F.dropout(outer_link, p=self.drop_rate, training=self.training) 62 | 63 | c = x.size(1) 64 | if self.k1 > 0 and self.k1 < c: 65 | xl = x[:, 0: c - self.k1, :, :] 66 | xr = x[:, c - self.k1: c, :, :] + inner_link 67 | x = torch.cat((xl, xr), 1) 68 | elif self.k1 == c: 69 | x = x + inner_link 70 | 71 | if self.k2 > 0: 72 | out = torch.cat((x, outer_link), 1) 73 | else: 74 | out = x 75 | 76 | return out 77 | 78 | 79 | class Block(nn.Sequential): 80 | def __init__(self, num_layers, num_input_features, expansion, k1, k2, drop_rate): 81 | super(Block, self).__init__() 82 | for i in range(num_layers): 83 | layer = _MixLayer(num_input_features + i * k2, expansion, k1, k2, drop_rate) 84 | self.add_module('mixlayer%d' % (i + 1), layer) 85 | 86 | 87 | class _Transition(nn.Sequential): 88 | def __init__(self, num_input_features, num_output_features): 89 | super(_Transition, self).__init__() 90 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 91 | self.add_module('relu', nn.ReLU(inplace=True)) 92 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) 93 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 94 | 95 | class MixNet(nn.Module): 96 | def __init__(self, block_config=(6, 12, 24, 16), num_init_features=64, expansion=4, k1=32, k2=32, drop_rate=0, num_classes=1000): 97 | 98 | super(MixNet, self).__init__() 99 | print('k1: ', k1, 'k2: ', k2) 100 | 101 | # First convolution 102 | self.features = nn.Sequential(OrderedDict([ 103 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 104 | ('norm0', nn.BatchNorm2d(num_init_features)), 105 | ('relu0', nn.ReLU(inplace=True)), 106 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 107 | ])) 108 | 109 | # Each block 110 | num_features = num_init_features 111 | for i, num_layers in enumerate(block_config): 112 | block = Block(num_layers=num_layers, num_input_features=num_features, expansion=expansion, k1=k1, k2=k2, drop_rate=drop_rate) 113 | self.features.add_module('block%d' % (i + 1), block) 114 | num_features = num_features + num_layers * k2 115 | if i != len(block_config) - 1: 116 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 117 | self.features.add_module('transition%d' % (i + 1), trans) 118 | num_features = max(num_features // 2, k1) 119 | 120 | # Final batch norm 121 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 122 | 123 | # Linear layer 124 | self.classifier = nn.Linear(num_features, num_classes) 125 | 126 | # Official init from torch repo. 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal(m.weight.data) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | elif isinstance(m, nn.Linear): 134 | m.bias.data.zero_() 135 | 136 | def forward(self, x): 137 | features = self.features(x) 138 | out = F.relu(features, inplace=True) 139 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 140 | out = self.classifier(out) 141 | return out 142 | -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | ## Pretrained Models on ImageNet 2 | 3 | | Method | Pretrained model | 4 | | - | - | 5 | | MixNet-105 (k1 = 32, k2 = 32) | [Download(43.2M)](https://pan.baidu.com/s/1q-LjwofEu2nM7feZClTA7w) | 6 | | MixNet-121 (k1 = 40, k2 = 40) | [Download(84.3M)](https://pan.baidu.com/s/1wIzkO0UVIXd_BPx_lmT7_w) | 7 | | MixNet-141 (k1 = 48, k2 = 48) | [Download(158.1M)](https://pan.baidu.com/s/1lYczUcAczhkQqpEwjZT66Q) | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/utils/__init__.pyc -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/eval.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/utils/eval.pyc -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | # import matplotlib.pyplot as plt 5 | import matplotlib 6 | matplotlib.use('pdf') 7 | import matplotlib.pyplot as plt 8 | import os 9 | import sys 10 | import numpy as np 11 | 12 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 13 | 14 | def savefig(fname, dpi=None): 15 | dpi = 150 if dpi == None else dpi 16 | plt.savefig(fname, dpi=dpi) 17 | 18 | def plot_overlap(logger, names=None): 19 | names = logger.names if names == None else names 20 | numbers = logger.numbers 21 | for _, name in enumerate(names): 22 | x = np.arange(len(numbers[name])) 23 | plt.plot(x, np.asarray(numbers[name])) 24 | return [logger.title + '(' + name + ')' for name in names] 25 | 26 | class Logger(object): 27 | '''Save training process to log file with simple plot function.''' 28 | def __init__(self, fpath, title=None, resume=False): 29 | self.file = None 30 | self.resume = resume 31 | self.title = '' if title == None else title 32 | if fpath is not None: 33 | if resume: 34 | self.file = open(fpath, 'r') 35 | name = self.file.readline() 36 | self.names = name.rstrip().split('\t') 37 | self.numbers = {} 38 | for _, name in enumerate(self.names): 39 | self.numbers[name] = [] 40 | 41 | for numbers in self.file: 42 | numbers = numbers.rstrip().split('\t') 43 | for i in range(0, len(numbers)): 44 | self.numbers[self.names[i]].append(numbers[i]) 45 | self.file.close() 46 | self.file = open(fpath, 'a') 47 | else: 48 | self.file = open(fpath, 'w') 49 | 50 | def set_names(self, names): 51 | if self.resume: 52 | pass 53 | # initialize numbers as empty list 54 | self.numbers = {} 55 | self.names = names 56 | for _, name in enumerate(self.names): 57 | self.file.write(name) 58 | self.file.write('\t') 59 | self.numbers[name] = [] 60 | self.file.write('\n') 61 | self.file.flush() 62 | 63 | 64 | def append(self, numbers): 65 | assert len(self.names) == len(numbers), 'Numbers do not match names' 66 | for index, num in enumerate(numbers): 67 | self.file.write("{0:.6f}".format(num)) 68 | self.file.write('\t') 69 | self.numbers[self.names[index]].append(num) 70 | self.file.write('\n') 71 | self.file.flush() 72 | 73 | def plot(self, names=None): 74 | print 'plot' 75 | ''' 76 | names = self.names if names == None else names 77 | numbers = self.numbers 78 | for _, name in enumerate(names): 79 | x = np.arange(len(numbers[name])) 80 | plt.plot(x, np.asarray(numbers[name])) 81 | plt.legend([self.title + '(' + name + ')' for name in names]) 82 | plt.grid(True) 83 | ''' 84 | 85 | def close(self): 86 | if self.file is not None: 87 | self.file.close() 88 | 89 | class LoggerMonitor(object): 90 | '''Load and visualize multiple logs.''' 91 | def __init__ (self, paths): 92 | '''paths is a distionary with {name:filepath} pair''' 93 | self.loggers = [] 94 | for title, path in paths.items(): 95 | logger = Logger(path, title=title, resume=True) 96 | self.loggers.append(logger) 97 | 98 | def plot(self, names=None): 99 | plt.figure() 100 | plt.subplot(121) 101 | legend_text = [] 102 | for logger in self.loggers: 103 | legend_text += plot_overlap(logger, names) 104 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 105 | plt.grid(True) 106 | 107 | if __name__ == '__main__': 108 | # # Example 109 | # logger = Logger('test.txt') 110 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 111 | 112 | # length = 100 113 | # t = np.arange(length) 114 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 116 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 117 | 118 | # for i in range(0, length): 119 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 120 | # logger.plot() 121 | 122 | # Example: logger monitor 123 | paths = { 124 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 125 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 126 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 127 | } 128 | 129 | field = ['Valid Acc.'] 130 | 131 | monitor = LoggerMonitor(paths) 132 | monitor.plot(names=field) 133 | savefig('test.eps') -------------------------------------------------------------------------------- /utils/logger.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/utils/logger.pyc -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value""" 61 | def __init__(self): 62 | self.reset() 63 | 64 | def reset(self): 65 | self.val = 0 66 | self.avg = 0 67 | self.sum = 0 68 | self.count = 0 69 | 70 | def update(self, val, n=1): 71 | self.val = val 72 | self.sum += val * n 73 | self.count += n 74 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/misc.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/utils/misc.pyc -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /utils/visualize.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepInsight-PCALab/MixNet/5cd7150f70eadb0c18d6025af9bce07c462f017c/utils/visualize.pyc --------------------------------------------------------------------------------