├── README.md ├── extract_features.py ├── logger.py ├── main.py ├── models ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── resnext.cpython-36.pyc ├── caffe_cifar.py ├── caffe_cifar.pyc ├── densenet.py ├── densenet.pyc ├── imagenet_resnet.py ├── imagenet_resnet.pyc ├── ops.py ├── ops.pyc ├── preresnet.py ├── preresnet.pyc ├── res_utils.py ├── res_utils.pyc ├── resnet.py ├── resnet.pyc ├── resnet_mod.py ├── resnet_mod.pyc ├── resnext.py └── resnext.pyc ├── pretrained_model ├── checkpoint.pth.tar ├── curve.png ├── log_seed_9828.txt └── model_best.pth.tar ├── utils.py └── view ├── class0.png ├── class1.png ├── class2.png ├── class3.png ├── class4.png ├── class5.png ├── class6.png ├── class7.png ├── class8.png ├── class9.png ├── features.npy ├── labels.npy ├── log_seed_7102.txt ├── view.py └── weights.npy /README.md: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py $PATH_TO_CIFAR10$ --dataset cifar10 --arch resnet110 --save_path ./110epoch500D8 --epochs 500 --schedule 250 375 --gammas 0.1 0.1 --learning_rate 0.1 --decay 0.0001 --batch_size 128 --Ddim 8 2 | 3 | pre-trained model: 4 | ./pretrained_model 5 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os, sys, pdb, shutil, time, random 4 | import argparse 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as transforms 9 | from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time 10 | import models 11 | import numpy as np 12 | 13 | model_names = sorted(name for name in models.__dict__ 14 | if name.islower() and not name.startswith("__") 15 | and callable(models.__dict__[name])) 16 | 17 | parser = argparse.ArgumentParser(description='Extracts features from a trained network on CIFAR', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('data_path', type=str, help='Path to dataset') 19 | parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10'], help='Choose between CIFAR-10/100, SVHN, STL-10 and ImageNet.') 20 | parser.add_argument('--arch', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') 21 | # Optimization options 22 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 23 | # Checkpoints 24 | parser.add_argument('--save_path', type=str, default='./', help='Folder to save checkpoints and log.') 25 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 26 | # Acceleration 27 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.') 28 | parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') 29 | # random seed 30 | parser.add_argument('--manualSeed', type=int, help='manual seed') 31 | parser.add_argument('--Ddim', type=int, default=4, help='dimension of each capsule subspace') 32 | args = parser.parse_args() 33 | args.use_cuda = args.ngpu>0 and torch.cuda.is_available() 34 | 35 | if args.manualSeed is None: 36 | args.manualSeed = random.randint(1, 10000) 37 | random.seed(args.manualSeed) 38 | torch.manual_seed(args.manualSeed) 39 | if args.use_cuda: 40 | torch.cuda.manual_seed_all(args.manualSeed) 41 | cudnn.benchmark = True
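# A hypothetical invocation (assumed from the argparse options above and the shipped
# ./pretrained_model and ./view directories, not taken from the README): to dump
# features, labels and classifier weights for the pretrained CIFAR-10 model, something like
#   CUDA_VISIBLE_DEVICES=0 python extract_features.py $PATH_TO_CIFAR10$ --dataset cifar10 \
#       --arch resnet110 --resume ./pretrained_model/model_best.pth.tar --save_path ./view --Ddim 8
# should work; validate() below then writes features.npy, labels.npy and weights.npy into --save_path.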
42 | 43 | def main(): 44 | # Init logger 45 | if not os.path.isdir(args.save_path): 46 | os.makedirs(args.save_path) 47 | log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w') 48 | print_log('save path : {}'.format(args.save_path), log) 49 | state = {k: v for k, v in args._get_kwargs()} 50 | print_log(state, log) 51 | print_log("Random Seed: {}".format(args.manualSeed), log) 52 | print_log("python version : {}".format(sys.version.replace('\n', ' ')), log) 53 | print_log("torch version : {}".format(torch.__version__), log) 54 | print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log) 55 | 56 | # Init dataset 57 | if not os.path.isdir(args.data_path): 58 | os.makedirs(args.data_path) 59 | 60 | if args.dataset == 'cifar10': 61 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 62 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 63 | elif args.dataset == 'cifar100': 64 | mean = [x / 255 for x in [129.3, 124.1, 112.4]] 65 | std = [x / 255 for x in [68.2, 65.4, 70.4]] 66 | else: 67 | assert False, "Unknown dataset : {}".format(args.dataset) 68 | 69 | train_transform = transforms.Compose( 70 | [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), 71 | transforms.Normalize(mean, std)]) 72 | test_transform = transforms.Compose( 73 | [transforms.ToTensor(), transforms.Normalize(mean, std)]) 74 | 75 | if args.dataset == 'cifar10': 76 | train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True) 77 | test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True) 78 | num_classes = 10 79 | elif args.dataset == 'cifar100': 80 | train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True) 81 | test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True) 82 | num_classes = 100 83 | elif args.dataset == 'svhn': 84 | train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True) 85 | test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True) 86 | num_classes = 10 87 | elif args.dataset == 'stl10': 88 | train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True) 89 | test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True) 90 | num_classes = 10 91 | elif args.dataset == 'imagenet': 92 | assert False, 'ImageNet support is not implemented yet' 93 | else: 94 | assert False, 'Unsupported dataset : {}'.format(args.dataset) 95 | 96 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, 97 | num_workers=args.workers, pin_memory=True) 98 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, 99 | num_workers=args.workers, pin_memory=True) 100 | 101 | print_log("=> creating model '{}'".format(args.arch), log) 102 | # Init model, criterion, and optimizer 103 | net = models.__dict__[args.arch](num_classes, args.Ddim) 104 | print_log("=> network :\n {}".format(net), log) 105 | 106 | net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) 107 | 108 | if args.use_cuda: 109 | net.cuda() 110 | 111 | # optionally resume from a checkpoint 112 | if args.resume: 113 | if os.path.isfile(args.resume): 114 | print_log("=> loading checkpoint '{}'".format(args.resume), log) 115 | checkpoint = torch.load(args.resume) 116 | recorder = checkpoint['recorder']
117 | net.load_state_dict(checkpoint['state_dict']) 118 | print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log) 119 | else: 120 | print_log("=> no checkpoint found at '{}'".format(args.resume), log) 121 | else: 122 | print_log("=> do not use any checkpoint for {} model".format(args.arch), log) 123 | 124 | # Main loop 125 | 126 | val_acc = validate(train_loader, net, log, args.Ddim, args.save_path) 127 | print(val_acc) 128 | 129 | # for epoch in range(args.start_epoch, args.epochs): 130 | # # train for one epoch 131 | # train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, args.Ddim) 132 | # 133 | # # evaluate on validation set 134 | # #val_acc, val_los = extract_features(test_loader, net, criterion, log) 135 | # val_acc, val_los = validate(test_loader, net, criterion, log, args.Ddim) 136 | # is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc) 137 | # 138 | # save_checkpoint({ 139 | # 'epoch': epoch + 1, 140 | # 'arch': args.arch, 141 | # 'state_dict': net.state_dict(), 142 | # 'recorder': recorder, 143 | # 'optimizer' : optimizer.state_dict(), 144 | # }, is_best, args.save_path, 'checkpoint.pth.tar') 145 | # 146 | # # measure elapsed time 147 | # epoch_time.update(time.time() - start_time) 148 | # start_time = time.time() 149 | # recorder.plot_curve( os.path.join(args.save_path, 'curve.png') ) 150 | 151 | log.close() 152 | 153 | 154 | def validate(val_loader, model, log, Ddim, save_path): 155 | losses = AverageMeter() 156 | top1 = AverageMeter() 157 | top5 = AverageMeter() 158 | 159 | # switch to evaluate mode 160 | model.eval() 161 | 162 | for i, (input, target) in enumerate(val_loader): 163 | eye = torch.eye(Ddim) 164 | if args.use_cuda: 165 | target = target.cuda(non_blocking=True) 166 | input = input.cuda() 167 | eye = eye.cuda() 168 | input_var = torch.autograd.Variable(input, volatile=True) 169 | target_var = torch.autograd.Variable(target, volatile=True) 170 | eye_var = torch.autograd.Variable(eye, volatile=True) 171 | 172 | # compute output 173 | features, output = model(input_var, eye_var) 174 | if i == 0: 175 | features_data = features.cpu().data.numpy() 176 | target_data = target.cpu().numpy() 177 | else: 178 | features_data = np.concatenate((features_data, features.cpu().data.numpy()), axis = 0) 179 | target_data = np.concatenate((target_data, target.cpu().numpy()), axis = 0) 180 | 181 | # measure accuracy and record loss 182 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 183 | top1.update(prec1[0], input.size(0)) 184 | top5.update(prec5[0], input.size(0)) 185 | 186 | np.save(save_path+'/features',features_data) 187 | np.save(save_path+'/labels',target_data) 188 | np.save(save_path+'/weights',model.module.classifier.weight.cpu().data.numpy()) 189 | print(features_data.shape) 190 | print(target_data.shape) 191 | return top1.avg 192 | 193 | def extract_features(val_loader, model, criterion, log): 194 | losses = AverageMeter() 195 | top1 = AverageMeter() 196 | top5 = AverageMeter() 197 | 198 | # switch to evaluate mode 199 | model.eval() 200 | 201 | for i, (input, target) in enumerate(val_loader): 202 | if args.use_cuda: 203 | target = target.cuda(non_blocking=True) 204 | input = input.cuda() 205 | input_var = torch.autograd.Variable(input, volatile=True) 206 | target_var = torch.autograd.Variable(target, volatile=True) 207 | 208 | # compute output 209 | output, features = model([input_var]) 210 | 211 | pdb.set_trace() 212 | 213 | loss = criterion(output, target_var) 214 | 215 | # measure accuracy and record loss 216 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 217 | losses.update(loss.data[0], input.size(0)) 218 | top1.update(prec1[0], input.size(0)) 219 | top5.update(prec5[0], input.size(0)) 220 | 221 | print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log) 222 | 223 | return top1.avg, losses.avg 224 | 225 | def print_log(print_string, log): 226 | print("{}".format(print_string)) 227 | log.write('{}\n'.format(print_string)) 228 | log.flush() 229 | 230 | def save_checkpoint(state, is_best, save_path, filename): 231 | filename = os.path.join(save_path, filename) 232 | torch.save(state, filename) 233 | if is_best: 234 | bestname = os.path.join(save_path, 'model_best.pth.tar') 235 | shutil.copyfile(filename, bestname) 236 | 237 | def adjust_learning_rate(optimizer, epoch, gammas, schedule): 238 | """Decays the learning rate by the given gammas at the epochs listed in schedule""" 239 | lr = args.learning_rate 240 | assert len(gammas) == len(schedule), "length of gammas and schedule should be equal" 241 | for (gamma, step) in zip(gammas, schedule): 242 | if (epoch >= step): 243 | lr = lr * gamma 244 | else: 245 | break 246 | for param_group in optimizer.param_groups: 247 | param_group['lr'] = lr 248 | return lr 249 | 250 | def accuracy(output, target, topk=(1,)): 251 | """Computes the precision@k for the specified values of k""" 252 | maxk = max(topk) 253 | batch_size = target.size(0) 254 | 255 | _, pred = output.topk(maxk, 1, True, True) 256 | pred = pred.t() 257 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 258 | 259 | res = [] 260 | for k in topk: 261 | correct_k = correct[:k].view(-1).float().sum(0) 262 | res.append(correct_k.mul_(100.0 / batch_size)) 263 | return res 264 | 265 | if __name__ == '__main__': 266 | main() 267 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | import sys 6 | if sys.version[0] == '2': 7 | from StringIO import StringIO # Python 2.x 8 | elif sys.version[0] == '3': 9 | from io import BytesIO # Python 3.x 10 | 11 | 12 | class Logger(object): 13 | 14 | def __init__(self, log_dir): 15 | """Create a summary writer logging to log_dir.""" 16 | self.writer = tf.summary.FileWriter(log_dir) 17 | 18 | def scalar_summary(self, tag, value, step): 19 | """Log a scalar variable.""" 20 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 21 | self.writer.add_summary(summary, step) 22 | 23 | def image_summary(self, tag, images, step): 24 | """Log a list of images.""" 25 | 26 | img_summaries = [] 27 | for i, img in enumerate(images): 28 | # Write the image to a string 29 | try: 30 | s = StringIO() 31 | except NameError: # StringIO is only imported on Python 2; fall back to BytesIO on Python 3 32 | s = BytesIO() 33 | scipy.misc.toimage(img).save(s, format="png") 34 | 35 | # Create an Image object 36 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 37 | height=img.shape[0], 38 | width=img.shape[1]) 39 | # Create a Summary value 40 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 41 | 42 | # Create and write Summary 43 | summary = tf.Summary(value=img_summaries) 44 | self.writer.add_summary(summary, step) 45 | 46 | def histo_summary(self, tag, values, step, bins=1000): 47 | """Log a 
histogram of the tensor of values.""" 48 | 49 | # Create a histogram using numpy 50 | counts, bin_edges = np.histogram(values, bins=bins) 51 | 52 | # Fill the fields of the histogram proto 53 | hist = tf.HistogramProto() 54 | hist.min = float(np.min(values)) 55 | hist.max = float(np.max(values)) 56 | hist.num = int(np.prod(values.shape)) 57 | hist.sum = float(np.sum(values)) 58 | hist.sum_squares = float(np.sum(values**2)) 59 | 60 | # Drop the start of the first bin 61 | bin_edges = bin_edges[1:] 62 | 63 | # Add bin edges and counts 64 | for edge in bin_edges: 65 | hist.bucket_limit.append(edge) 66 | for c in counts: 67 | hist.bucket.append(c) 68 | 69 | # Create and write Summary 70 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 71 | self.writer.add_summary(summary, step) 72 | self.writer.flush() 73 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os, sys, pdb, shutil, time, random 4 | import argparse 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as transforms 9 | from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time 10 | import models 11 | 12 | model_names = sorted(name for name in models.__dict__ 13 | if name.islower() and not name.startswith("__") 14 | and callable(models.__dict__[name])) 15 | 16 | parser = argparse.ArgumentParser(description='Trains ResNeXt on CIFAR or ImageNet', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('data_path', type=str, help='Path to dataset') 18 | parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10'], help='Choose between CIFAR-10/100, SVHN, STL-10 and ImageNet.') 19 | parser.add_argument('--arch', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') 20 | # Optimization options 21 | parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.') 22 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 23 | parser.add_argument('--learning_rate', type=float, default=0.1, help='The Learning Rate.') 24 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 25 | parser.add_argument('--decay', type=float, default=0.0005, help='Weight decay (L2 penalty).') 26 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], help='Decrease learning rate at these epochs.') 27 | parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1], help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule') 28 | # Checkpoints 29 | parser.add_argument('--print_freq', default=200, type=int, metavar='N', help='print frequency (default: 200)') 30 | parser.add_argument('--save_path', type=str, default='./', help='Folder to save checkpoints and log.') 31 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 32 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 33 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') 34 | # Acceleration 35 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.') 36 | parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') 37 | # random seed 38 | parser.add_argument('--manualSeed', type=int, help='manual seed') 39 | parser.add_argument('--Ddim', type=int, default=4, help='dimension of each capsule subspace') 40 | args = parser.parse_args() 41 | args.use_cuda = args.ngpu>0 and torch.cuda.is_available() 42 | 43 | if args.manualSeed is None: 44 | args.manualSeed = random.randint(1, 10000) 45 | random.seed(args.manualSeed) 46 | torch.manual_seed(args.manualSeed) 47 | if args.use_cuda: 48 | torch.cuda.manual_seed_all(args.manualSeed) 49 | cudnn.benchmark = True 50 | 51 | def main(): 52 | # Init logger 53 | if not os.path.isdir(args.save_path): 54 | os.makedirs(args.save_path) 55 | log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w') 56 | print_log('save path : {}'.format(args.save_path), log) 57 | state = {k: v for k, v in args._get_kwargs()} 58 | print_log(state, log) 59 | print_log("Random Seed: {}".format(args.manualSeed), log) 60 | print_log("python version : {}".format(sys.version.replace('\n', ' ')), log) 61 | print_log("torch version : {}".format(torch.__version__), log) 62 | print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log) 63 | 64 | # Init dataset 65 | if not os.path.isdir(args.data_path): 66 | os.makedirs(args.data_path) 67 | 68 | if args.dataset == 'cifar10': 69 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 70 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 71 | elif args.dataset == 'cifar100': 72 | mean = [x / 255 for x in [129.3, 124.1, 112.4]] 73 | std = [x / 255 for x in [68.2, 65.4, 70.4]] 74 | else: 75 | assert False, "Unknown dataset : {}".format(args.dataset) 76 | 77 | train_transform = transforms.Compose( 78 | [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), 79 | transforms.Normalize(mean, std)]) 80 | test_transform = transforms.Compose( 81 | [transforms.ToTensor(), transforms.Normalize(mean, std)]) 82 | 83 | if args.dataset == 'cifar10': 84 | train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True) 85 | test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True) 86 | num_classes = 10 87 | elif args.dataset == 'cifar100': 88 | train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True) 89 | test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True) 90 | num_classes = 100 91 | elif args.dataset == 'svhn': 92 | train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True) 93 | test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True) 94 | num_classes = 10 95 | elif args.dataset == 'stl10': 96 | train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True) 97 | test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True) 98 | num_classes = 10 99 | elif args.dataset == 'imagenet': 100 | assert False, 'ImageNet support is not implemented yet' 101 | else: 102 | assert False, 'Unsupported dataset : {}'.format(args.dataset) 103 | 104 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, 105 | num_workers=args.workers, pin_memory=True) 106 | test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, 107 | num_workers=args.workers, pin_memory=True) 108 | 
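# Note: every factory exposed through models.__dict__ is assumed here to share the
# two-argument signature (num_classes, Ddim); the returned network's forward(x, eye)
# expects a Ddim x Ddim identity matrix, which train() and validate() below construct
# with torch.eye(Ddim) and thread through to the LinearCapsPro classifier head.
109 | 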
print_log("=> creating model '{}'".format(args.arch), log) 110 | # Init model, criterion, and optimizer 111 | net = models.__dict__[args.arch](num_classes, args.Ddim) 112 | print_log("=> network :\n {}".format(net), log) 113 | 114 | net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) 115 | 116 | # define loss function (criterion) and optimizer 117 | criterion = torch.nn.CrossEntropyLoss() 118 | 119 | optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'], 120 | weight_decay=state['decay'], nesterov=True) 121 | 122 | if args.use_cuda: 123 | net.cuda() 124 | criterion.cuda() 125 | 126 | recorder = RecorderMeter(args.epochs) 127 | # optionally resume from a checkpoint 128 | if args.resume: 129 | if os.path.isfile(args.resume): 130 | print_log("=> loading checkpoint '{}'".format(args.resume), log) 131 | checkpoint = torch.load(args.resume) 132 | recorder = checkpoint['recorder'] 133 | args.start_epoch = checkpoint['epoch'] 134 | net.load_state_dict(checkpoint['state_dict']) 135 | optimizer.load_state_dict(checkpoint['optimizer']) 136 | print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log) 137 | else: 138 | print_log("=> no checkpoint found at '{}'".format(args.resume), log) 139 | else: 140 | print_log("=> do not use any checkpoint for {} model".format(args.arch), log) 141 | 142 | if args.evaluate: 143 | validate(test_loader, net, criterion, log) 144 | return 145 | 146 | # Main loop 147 | start_time = time.time() 148 | epoch_time = AverageMeter() 149 | for epoch in range(args.start_epoch, args.epochs): 150 | current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule) 151 | 152 | need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch)) 153 | need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 154 | 155 | print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \ 156 | + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log) 157 | 158 | # train for one epoch 159 | train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, args.Ddim) 160 | 161 | # evaluate on validation set 162 | #val_acc, val_los = extract_features(test_loader, net, criterion, log) 163 | val_acc, val_los = validate(test_loader, net, criterion, log, args.Ddim) 164 | is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc) 165 | 166 | save_checkpoint({ 167 | 'epoch': epoch + 1, 168 | 'arch': args.arch, 169 | 'state_dict': net.state_dict(), 170 | 'recorder': recorder, 171 | 'optimizer' : optimizer.state_dict(), 172 | }, is_best, args.save_path, 'checkpoint.pth.tar') 173 | 174 | # measure elapsed time 175 | epoch_time.update(time.time() - start_time) 176 | start_time = time.time() 177 | recorder.plot_curve( os.path.join(args.save_path, 'curve.png') ) 178 | 179 | log.close() 180 | 181 | # train function (forward, backward, update) 182 | def train(train_loader, model, criterion, optimizer, epoch, log, Ddim): 183 | batch_time = AverageMeter() 184 | data_time = AverageMeter() 185 | losses = AverageMeter() 186 | top1 = AverageMeter() 187 | top5 = AverageMeter() 188 | # switch to train mode 189 | model.train() 190 | 191 | end = time.time() 192 | for i, (input, target) in enumerate(train_loader): 193 | # measure data loading time 194 | 
data_time.update(time.time() - end) 195 | 196 | eye = torch.eye(Ddim) 197 | 198 | if args.use_cuda: 199 | target = target.cuda(non_blocking=True) 200 | input = input.cuda() 201 | eye = eye.cuda() 202 | input_var = torch.autograd.Variable(input) 203 | target_var = torch.autograd.Variable(target) 204 | eye_var = torch.autograd.Variable(eye) 205 | 206 | # compute output 207 | output = model(input_var, eye_var) 208 | loss = criterion(output, target_var) 209 | 210 | # measure accuracy and record loss 211 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 212 | losses.update(loss.data[0], input.size(0)) 213 | top1.update(prec1[0], input.size(0)) 214 | top5.update(prec5[0], input.size(0)) 215 | 216 | # compute gradient and do SGD step 217 | optimizer.zero_grad() 218 | loss.backward() 219 | optimizer.step() 220 | 221 | # measure elapsed time 222 | batch_time.update(time.time() - end) 223 | end = time.time() 224 | 225 | if i % args.print_freq == 0: 226 | print_log(' Epoch: [{:03d}][{:03d}/{:03d}] ' 227 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 228 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 229 | 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 230 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f}) ' 231 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f}) '.format( 232 | epoch, i, len(train_loader), batch_time=batch_time, 233 | data_time=data_time, loss=losses, top1=top1, top5=top5) + time_string(), log) 234 | print_log(' **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log) 235 | return top1.avg, losses.avg 236 | 237 | def validate(val_loader, model, criterion, log, Ddim): 238 | losses = AverageMeter() 239 | top1 = AverageMeter() 240 | top5 = AverageMeter() 241 | 242 | # switch to evaluate mode 243 | model.eval() 244 | 245 | for i, (input, target) in enumerate(val_loader): 246 | eye = torch.eye(Ddim) 247 | if args.use_cuda: 248 | target = target.cuda(non_blocking=True) 249 | input = input.cuda() 250 | eye = eye.cuda() 251 | input_var = torch.autograd.Variable(input, volatile=True) 252 | target_var = torch.autograd.Variable(target, volatile=True) 253 | eye_var = torch.autograd.Variable(eye, volatile=True) 254 | 255 | # compute output 256 | output = model(input_var, eye_var) 257 | loss = criterion(output, target_var) 258 | 259 | # measure accuracy and record loss 260 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 261 | losses.update(loss.data[0], input.size(0)) 262 | top1.update(prec1[0], input.size(0)) 263 | top5.update(prec5[0], input.size(0)) 264 | 265 | print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log) 266 | 267 | return top1.avg, losses.avg 268 | 269 | def extract_features(val_loader, model, criterion, log): 270 | losses = AverageMeter() 271 | top1 = AverageMeter() 272 | top5 = AverageMeter() 273 | 274 | # switch to evaluate mode 275 | model.eval() 276 | 277 | for i, (input, target) in enumerate(val_loader): 278 | if args.use_cuda: 279 | target = target.cuda(non_blocking=True) 280 | input = input.cuda() 281 | input_var = torch.autograd.Variable(input, volatile=True) 282 | target_var = torch.autograd.Variable(target, volatile=True) 283 | 284 | # compute output 285 | output, features = model([input_var]) 286 | 287 | pdb.set_trace() 288 | 289 | loss = criterion(output, target_var) 290 | 291 | # measure accuracy and record loss 292 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 293 | losses.update(loss.data[0], input.size(0)) 294 | top1.update(prec1[0], input.size(0)) 295 | top5.update(prec5[0], input.size(0)) 296 | 297 | print_log(' **Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log) 298 | 299 | return top1.avg, losses.avg 300 | 301 | def print_log(print_string, log): 302 | print("{}".format(print_string)) 303 | log.write('{}\n'.format(print_string)) 304 | log.flush() 305 | 306 | def save_checkpoint(state, is_best, save_path, filename): 307 | filename = os.path.join(save_path, filename) 308 | torch.save(state, filename) 309 | if is_best: 310 | bestname = os.path.join(save_path, 'model_best.pth.tar') 311 | shutil.copyfile(filename, bestname) 312 | 313 | def adjust_learning_rate(optimizer, epoch, gammas, schedule): 314 | """Decays the learning rate by the given gammas at the epochs listed in schedule""" 315 | lr = args.learning_rate 316 | assert len(gammas) == len(schedule), "length of gammas and schedule should be equal" 317 | for (gamma, step) in zip(gammas, schedule): 318 | if (epoch >= step): 319 | lr = lr * gamma 320 | else: 321 | break 322 | for param_group in optimizer.param_groups: 323 | param_group['lr'] = lr 324 | return lr 325 | 326 | def accuracy(output, target, topk=(1,)): 327 | """Computes the precision@k for the specified values of k""" 328 | maxk = max(topk) 329 | batch_size = target.size(0) 330 | 331 | _, pred = output.topk(maxk, 1, True, True) 332 | pred = pred.t() 333 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 334 | 335 | res = [] 336 | for k in topk: 337 | correct_k = correct[:k].view(-1).float().sum(0) 338 | res.append(correct_k.mul_(100.0 / batch_size)) 339 | return res 340 | 341 | if __name__ == '__main__': 342 | main() 343 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """The models subpackage contains definitions for the following model 2 | architectures: 3 | - `ResNeXt` for CIFAR10 CIFAR100 4 | You can construct a model with random weights by calling its constructor: 5 | .. code:: python 6 | import models 7 | resnext29_16_64 = models.resnext29_16_64(num_classes, Ddim) 8 | resnext29_8_64 = models.resnext29_8_64(num_classes, Ddim) 9 | resnet20 = models.resnet20(num_classes, Ddim) 10 | resnet32 = models.resnet32(num_classes, Ddim) 11 | 12 | 13 | .. 
ResNext: https://arxiv.org/abs/1611.05431 14 | """ 15 | 16 | from .resnext import resnext29_8_64, resnext29_16_64 17 | from .resnet import resnet20, resnet32, resnet44, resnet56, resnet110, resnet110_valid 18 | from .preresnet import preresnet20, preresnet32, preresnet44, preresnet56, preresnet110, preresnet164 19 | from .caffe_cifar import caffe_cifar 20 | from .densenet import densenet100_12 21 | from .resnet_mod import resnet_mod20, resnet_mod32, resnet_mod44, resnet_mod56, resnet_mod110 22 | 23 | from .imagenet_resnet import resnet18, resnet34, resnet50, resnet101, resnet152 24 | -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/__init__.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnext.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/__pycache__/resnext.cpython-36.pyc -------------------------------------------------------------------------------- /models/caffe_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | import math 8 | 9 | ## http://torch.ch/blog/2015/07/30/cifar.html 10 | class CifarCaffeNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(CifarCaffeNet, self).__init__() 13 | 14 | self.num_classes = num_classes 15 | 16 | self.block_1 = nn.Sequential( 17 | nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), 18 | nn.MaxPool2d(kernel_size=3, stride=2), 19 | nn.ReLU(), 20 | nn.BatchNorm2d(32)) 21 | 22 | self.block_2 = nn.Sequential( 23 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 24 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 25 | nn.ReLU(), 26 | nn.AvgPool2d(kernel_size=3, stride=2), 27 | nn.BatchNorm2d(64)) 28 | 29 | self.block_3 = nn.Sequential( 30 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 31 | nn.Conv2d(64,128, kernel_size=3, stride=1, padding=1), 32 | nn.ReLU(), 33 | nn.AvgPool2d(kernel_size=3, stride=2), 34 | nn.BatchNorm2d(128)) 35 | 36 | self.classifier = nn.Linear(128*9, self.num_classes) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | init.kaiming_normal(m.weight) 47 | m.bias.data.zero_() 48 | 49 | def forward(self, x): 50 | x = self.block_1.forward(x) 51 | x = self.block_2.forward(x) 52 | x = self.block_3.forward(x) 53 | x = x.view(x.size(0), -1) 54 | #print ('{}'.format(x.size())) 55 | return self.classifier(x) 56 | 57 | def caffe_cifar(num_classes=10): 58 | model = CifarCaffeNet(num_classes) 59 | return model 60 | -------------------------------------------------------------------------------- /models/caffe_cifar.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/caffe_cifar.pyc -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import math, torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from . import ops 5 | 6 | class Bottleneck(nn.Module): 7 | def __init__(self, nChannels, growthRate): 8 | super(Bottleneck, self).__init__() 9 | interChannels = 4*growthRate 10 | self.bn1 = nn.BatchNorm2d(nChannels) 11 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) 12 | self.bn2 = nn.BatchNorm2d(interChannels) 13 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) 14 | 15 | def forward(self, x): 16 | out = self.conv1(F.relu(self.bn1(x))) 17 | out = self.conv2(F.relu(self.bn2(out))) 18 | out = torch.cat((x, out), 1) 19 | return out 20 | 21 | class SingleLayer(nn.Module): 22 | def __init__(self, nChannels, growthRate): 23 | super(SingleLayer, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(nChannels) 25 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) 26 | 27 | def forward(self, x): 28 | out = self.conv1(F.relu(self.bn1(x))) 29 | out = torch.cat((x, out), 1) 30 | return out 31 | 32 | class Transition(nn.Module): 33 | def __init__(self, nChannels, nOutChannels): 34 | super(Transition, self).__init__() 35 | self.bn1 = nn.BatchNorm2d(nChannels) 36 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) 37 | 38 | def forward(self, x): 39 | out = self.conv1(F.relu(self.bn1(x))) 40 | out = F.avg_pool2d(out, 2) 41 | return out 42 | 43 | class DenseNet(nn.Module): 44 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck, Ddim): 45 | super(DenseNet, self).__init__() 46 | 47 | if bottleneck: nDenseBlocks = int( (depth-4) / 6 ) 48 | else : nDenseBlocks = int( (depth-4) / 3 ) 49 | 50 | nChannels = 2*growthRate 51 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) 52 | 53 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 54 | nChannels += nDenseBlocks*growthRate 55 | nOutChannels = int(math.floor(nChannels*reduction)) 56 | self.trans1 = Transition(nChannels, nOutChannels) 57 | 58 | nChannels = nOutChannels 59 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 60 | nChannels += nDenseBlocks*growthRate 61 | nOutChannels = int(math.floor(nChannels*reduction)) 62 | self.trans2 = Transition(nChannels, nOutChannels) 63 | 64 | nChannels = nOutChannels 65 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 66 | nChannels += nDenseBlocks*growthRate 67 | 68 | 
self.bn1 = nn.BatchNorm2d(nChannels) 69 | self.fc = ops.LinearCapsPro(nChannels, nClasses, Ddim) 70 | # self.fc = nn.Linear(nChannels, nClasses) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 75 | m.weight.data.normal_(0, math.sqrt(2. / n)) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | elif isinstance(m, nn.Linear): 80 | m.bias.data.zero_() 81 | 82 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 83 | layers = [] 84 | for i in range(int(nDenseBlocks)): 85 | if bottleneck: 86 | layers.append(Bottleneck(nChannels, growthRate)) 87 | else: 88 | layers.append(SingleLayer(nChannels, growthRate)) 89 | nChannels += growthRate 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x, eye): 93 | out = self.conv1(x) 94 | out = self.trans1(self.dense1(out)) 95 | out = self.trans2(self.dense2(out)) 96 | out = self.dense3(out) 97 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 98 | out = self.fc(out, eye) 99 | # out = F.log_softmax(self.fc(out)) 100 | return out 101 | 102 | def densenet100_12(num_classes=10, Ddim=4): 103 | model = DenseNet(12, 100, 0.5, num_classes, True, Ddim) 104 | return model 105 | -------------------------------------------------------------------------------- /models/densenet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/densenet.pyc -------------------------------------------------------------------------------- /models/imagenet_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | "3x3 convolution with padding" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual 
= x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet(nn.Module): 83 | 84 | def __init__(self, block, layers, num_classes=1000): 85 | self.inplanes = 64 86 | super(ResNet, self).__init__() 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 88 | bias=False) 89 | self.bn1 = nn.BatchNorm2d(64) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 92 | self.layer1 = self._make_layer(block, 64, layers[0]) 93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 94 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 95 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 96 | self.avgpool = nn.AvgPool2d(7) 97 | self.fc = nn.Linear(512 * block.expansion, num_classes) 98 | 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. / n)) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(1) 105 | m.bias.data.zero_() 106 | 107 | def _make_layer(self, block, planes, blocks, stride=1): 108 | downsample = None 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = nn.Sequential( 111 | nn.Conv2d(self.inplanes, planes * block.expansion, 112 | kernel_size=1, stride=stride, bias=False), 113 | nn.BatchNorm2d(planes * block.expansion), 114 | ) 115 | 116 | layers = [] 117 | layers.append(block(self.inplanes, planes, stride, downsample)) 118 | self.inplanes = planes * block.expansion 119 | for i in range(1, blocks): 120 | layers.append(block(self.inplanes, planes)) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | x = self.maxpool(x) 129 | 130 | x = self.layer1(x) 131 | x = self.layer2(x) 132 | x = self.layer3(x) 133 | x = self.layer4(x) 134 | 135 | x = self.avgpool(x) 136 | x = x.view(x.size(0), -1) 137 | x = self.fc(x) 138 | 139 | return x 140 | 141 | 142 | def resnet18(num_classes=1000): 143 | """Constructs a ResNet-18 model. 144 | 145 | Args: 146 | num_classes (int): number of output classes 147 | """ 148 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes) 149 | return model 150 | 151 | 152 | def resnet34(num_classes=1000): 153 | """Constructs a ResNet-34 model. 154 | 155 | Args: 156 | num_classes (int): number of output classes 157 | """ 158 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes) 159 | return model 160 | 161 | 162 | def resnet50(num_classes=1000): 163 | """Constructs a ResNet-50 model. 164 | 165 | Args: 166 | num_classes (int): number of output classes 167 | """ 168 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes) 169 | return model 170 | 171 | 172 | def resnet101(num_classes=1000): 173 | """Constructs a ResNet-101 model. 174 | 175 | Args: 176 | num_classes (int): number of output classes 177 | """ 178 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes) 179 | return model 180 | 181 | 182 | def resnet152(num_classes=1000): 183 | """Constructs a ResNet-152 model. 184 | 185 | Args: 186 | num_classes (int): number of output classes 187 | """ 188 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes) 189 | return model 190 | -------------------------------------------------------------------------------- /models/imagenet_resnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/imagenet_resnet.pyc -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | from torch.nn.parameter import Parameter 8 | 9 | class Caps_BN(nn.Module): 10 | ''' 11 | Input variable N*CD*H*W 12 | First perform normal BN without learnable affine parameters, then apply a C group convolution to perform per-capsule 13 | linear transformation 14 | ''' 15 | def __init__(self, num_C, num_D): 16 | super(Caps_BN, self).__init__() 17 | self.BN = nn.BatchNorm2d(num_C*num_D, affine=False) 18 | self.conv = nn.Conv2d(num_C*num_D, num_C*num_D, 1, groups=num_C) 19 | 20 | eye = torch.FloatTensor(num_C, num_D, num_D).copy_(torch.eye(num_D), broadcast = True).view(num_C*num_D, num_D, 1, 1) 21 | self.conv.weight.data.copy_(eye) 22 | self.conv.bias.data.zero_() 23 | 24 | def forward(self, x): 25 | output = self.BN(x) 26 | output = self.conv(output) 27 | 28 | return output 29 | 30 | class Caps_MaxPool(nn.Module): 31 | ''' 32 | Input variable N*CD*H*W 33 | First get the argmax indices of capsule lengths, then tile the indices D time and apply the tiled indices to capsules 34 | ''' 35 | def __init__(self, num_C, num_D, kernel_size, stride=None, padding=0, dilation=1): 36 | super(Caps_MaxPool, self).__init__() 37 | self.num_C = num_C 38 | self.num_D = num_D 39 | self.maxpool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=True) 40 | 41 | def forward(self, x): 42 | B = x.shape[0] 43 | H, W = x.shape[2:] 44 | x_caps = x.view(B, self.num_C, self.num_D, H, W) 45 | x_length = torch.sum(x_caps * x_caps, dim=2) 46 | x_length_pool, indices = self.maxpool(x_length) 47 | H_pool, W_pool = x_length_pool.shape[2:] 48 | indices_tile = torch.unsqueeze(indices, 2).expand(-1, -1, self.num_D, -1, -1).contiguous() 49 | indices_tile = indices_tile.view(B, self.num_C * self.num_D, -1) 50 | x_flatten = x.view(B, self.num_C*self.num_D, -1) 51 | output = torch.gather(x_flatten, 2, indices_tile).view(B, self.num_C*self.num_D, H_pool, W_pool) 52 | 53 | return output 54 | 55 | class Caps_Conv(nn.Module): 56 | def __init__(self, in_C, in_D, out_C, out_D, kernel_size, stride=1, padding=0, dilation=1, bias=False): 57 | super(Caps_Conv, self).__init__() 58 | self.in_C = in_C 59 | self.in_D = in_D 60 | self.out_C = out_C 61 | self.out_D = out_D 62 | self.conv_D = nn.Conv2d(in_C*in_D, in_C*out_D, 1, groups=in_C, bias=False) 63 | self.conv_C = nn.Conv2d(in_C, out_C, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 64 | 65 | m = self.conv_D.kernel_size[0] * self.conv_D.kernel_size[1] * self.conv_D.out_channels 66 | self.conv_D.weight.data.normal_(0, math.sqrt(2. / m)) 67 | n = self.conv_C.kernel_size[0] * self.conv_C.kernel_size[1] * self.conv_C.out_channels 68 | self.conv_C.weight.data.normal_(0, math.sqrt(2. 
/ n)) 69 | if bias: 70 | self.conv_C.bias.data.zero_() 71 | 72 | def forward(self, x): 73 | x = self.conv_D(x) 74 | x = x.view(x.shape[0], self.in_C, self.out_D, x.shape[2], x.shape[3]) 75 | x = torch.transpose(x, 1, 2).contiguous() 76 | x = x.view(-1, self.in_C, x.shape[3], x.shape[4]) 77 | x = self.conv_C(x) 78 | x = x.view(-1, self.out_D, self.out_C, x.shape[2], x.shape[3]) 79 | x = torch.transpose(x, 1, 2).contiguous() 80 | x = x.view(-1, self.out_C*self.out_D, x.shape[3], x.shape[4]) 81 | 82 | return x 83 | 84 | class Squash(nn.Module): 85 | def __init__(self, num_C, num_D, eps=0.0001): 86 | super(Squash, self).__init__() 87 | self.num_C = num_C 88 | self.num_D = num_D 89 | self.eps = eps 90 | 91 | def forward(self, x): 92 | x_caps = x.view(x.shape[0], self.num_C, self.num_D, x.shape[2], x.shape[3]) 93 | x_length = torch.sqrt(torch.sum(x_caps * x_caps, dim=2)) 94 | x_length = torch.unsqueeze(x_length, 2) 95 | x_caps = x_caps * x_length / (1+self.eps+x_length*x_length) 96 | x = x_caps.view(x.shape[0], -1, x.shape[2], x.shape[3]) 97 | return x 98 | 99 | class Relu_Caps(nn.Module): 100 | def __init__(self, num_C, num_D, theta=0.2, eps=0.0001): 101 | super(Relu_Caps, self).__init__() 102 | self.num_C = num_C 103 | self.num_D = num_D 104 | self.theta = theta 105 | self.eps = eps 106 | 107 | def forward(self, x): 108 | x_caps = x.view(x.shape[0], self.num_C, self.num_D, x.shape[2], x.shape[3]) 109 | x_length = torch.sqrt(torch.sum(x_caps * x_caps, dim=2)) 110 | x_length = torch.unsqueeze(x_length, 2) 111 | x_caps = F.relu(x_length - self.theta) * x_caps / (x_length + self.eps) 112 | x = x_caps.view(x.shape[0], -1, x.shape[2], x.shape[3]) 113 | return x 114 | 115 | class Relu_Adpt(nn.Module): 116 | def __init__(self, num_C, num_D, eps=0.0001): 117 | super(Relu_Adpt, self).__init__() 118 | self.num_C = num_C 119 | self.num_D = num_D 120 | self.eps = eps 121 | 122 | self.theta = Parameter(torch.Tensor(1, self.num_C, 1, 1, 1)) 123 | self.theta.data.fill_(0.) 124 | 125 | def forward(self, x): 126 | x_caps = x.view(x.shape[0], self.num_C, self.num_D, x.shape[2], x.shape[3]) 127 | x_length = torch.sqrt(torch.sum(x_caps * x_caps, dim=2)) 128 | x_length = torch.unsqueeze(x_length, 2) 129 | x_caps = F.relu(x_length - self.theta) * x_caps / (x_length + self.eps) 130 | x = x_caps.view(x.shape[0], -1, x.shape[2], x.shape[3]) 131 | return x 132 | 133 | class LinearCaps(nn.Module): 134 | def __init__(self, in_features, num_C, num_D, bias = False, eps=0.0001): 135 | super(LinearCaps, self).__init__() 136 | self.in_features = in_features 137 | self.num_C = num_C 138 | self.num_D = num_D 139 | self.eps = eps 140 | self.weight = Parameter(torch.Tensor(num_C*num_D, in_features)) 141 | if bias: 142 | self.bias = Parameter(torch.Tensor(num_C*num_D)) 143 | else: 144 | self.register_parameter('bias', None) 145 | 146 | self.reset_parameters() 147 | 148 | def reset_parameters(self): 149 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 150 | self.weight.data.uniform_(-stdv, stdv) 151 | if self.bias is not None: 152 | self.bias.data.uniform_(-stdv, stdv) 153 | 154 | # weights_reduce = torch.sqrt(torch.sum(self.weight*self.weight, dim=1)) 155 | # weights_reduce = torch.reciprocal(weights_reduce + self.eps) 156 | # weights_reduce = torch.unsqueeze(weights_reduce, dim=0) 157 | # 158 | # self.scalar.data.copy_(weights_reduce.data) 159 | # del weights_reduce 160 | 161 | def forward(self, x): 162 | scalar = torch.sqrt(torch.sum(self.weight * self.weight, dim=1)) 163 | scalar = torch.reciprocal(scalar + self.eps) 164 | scalar = torch.unsqueeze(scalar, dim=1) 165 | 166 | output = F.linear(x, scalar * self.weight, self.bias) 167 | 168 | return output 169 | 170 | class LinearCapsPro(nn.Module): 171 | def __init__(self, in_features, num_C, num_D, eps=0.0001): 172 | super(LinearCapsPro, self).__init__() 173 | self.in_features = in_features 174 | self.num_C = num_C 175 | self.num_D = num_D 176 | self.eps = eps 177 | self.weight = Parameter(torch.Tensor(num_C*num_D, in_features)) 178 | 179 | self.reset_parameters() 180 | 181 | def reset_parameters(self): 182 | stdv = 1. / math.sqrt(self.weight.size(1)) 183 | self.weight.data.uniform_(-stdv, stdv) 184 | 185 | def forward(self, x, eye): 186 | weight_caps = self.weight[:self.num_D] 187 | sigma = torch.inverse(torch.mm(weight_caps, torch.t(weight_caps))+self.eps*eye) 188 | sigma = torch.unsqueeze(sigma, dim=0) 189 | for ii in range(1, self.num_C): 190 | weight_caps = self.weight[ii*self.num_D:(ii+1)*self.num_D] 191 | sigma_ = torch.inverse(torch.mm(weight_caps, torch.t(weight_caps))+self.eps*eye) 192 | sigma_ = torch.unsqueeze(sigma_, dim=0) 193 | sigma = torch.cat((sigma, sigma_)) 194 | 195 | out = torch.matmul(x, torch.t(self.weight)) 196 | out = out.view(out.shape[0], self.num_C, 1, self.num_D) 197 | out = torch.matmul(out, sigma) 198 | out = torch.matmul(out, self.weight.view(self.num_C, self.num_D, self.in_features)) 199 | out = torch.squeeze(out, dim=2) 200 | out = torch.matmul(out, torch.unsqueeze(x, dim=2)) 201 | out = torch.squeeze(out, dim=2) 202 | 203 | return torch.sqrt(out) -------------------------------------------------------------------------------- /models/ops.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/ops.pyc -------------------------------------------------------------------------------- /models/preresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from .res_utils import DownsampleA, DownsampleC 6 | import math 7 | from . import ops 8 | 9 | 10 | class ResNetBasicblock(nn.Module): 11 | expansion = 1 12 | def __init__(self, inplanes, planes, stride, downsample, Type): 13 | super(ResNetBasicblock, self).__init__() 14 | 15 | self.Type = Type 16 | 17 | self.bn_a = nn.BatchNorm2d(inplanes) 18 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | self.bn_b = nn.BatchNorm2d(planes) 21 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | self.relu = nn.ReLU(inplace=True) 24 | self.downsample = downsample 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | basicblock = self.bn_a(x) 30 | basicblock = self.relu(basicblock) 31 | 32 | if self.Type == 'both_preact': 33 | residual = basicblock 34 | elif self.Type != 'normal': 35 | assert False, 'Unknown type : {}'.format(self.Type) 36 | 37 | basicblock = self.conv_a(basicblock) 38 | 39 | basicblock = self.bn_b(basicblock) 40 | basicblock = self.relu(basicblock) 41 | basicblock = self.conv_b(basicblock) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(residual) 45 | 46 | return residual + basicblock 47 | 48 | class CifarPreResNet(nn.Module): 49 | """ 50 | ResNet optimized for the Cifar dataset, as specified in 51 | https://arxiv.org/abs/1512.03385.pdf 52 | """ 53 | def __init__(self, block, depth, num_classes, Ddim): 54 | """ Constructor 55 | Args: 56 | depth: number of layers. 57 | num_classes: number of classes 58 | base_width: base width 59 | """ 60 | super(CifarPreResNet, self).__init__() 61 | 62 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 63 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 164' 64 | layer_blocks = (depth - 2) // 6 65 | print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) 66 | 67 | self.num_classes = num_classes 68 | self.Ddim = Ddim 69 | 70 | self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 71 | 72 | self.inplanes = 16 73 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 74 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 75 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 76 | self.lastact = nn.Sequential(nn.BatchNorm2d(64*block.expansion), nn.ReLU(inplace=True)) 77 | self.avgpool = nn.AvgPool2d(8) 78 | # self.classifier = nn.Linear(64*block.expansion, num_classes) 79 | self.classifier = ops.LinearCapsPro(64*block.expansion, num_classes, Ddim) 80 | 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 85 | #m.bias.data.zero_() 86 | elif isinstance(m, nn.BatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.Linear): 90 | init.kaiming_normal(m.weight) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, stride=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, downsample, 'both_preact')) 100 | self.inplanes = planes * block.expansion 101 | for i in range(1, blocks): 102 | layers.append(block(self.inplanes, planes, 1, None, 'normal')) 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x, eye): 107 | x = self.conv_3x3(x) 108 | x = self.stage_1(x) 109 | x = self.stage_2(x) 110 | x = self.stage_3(x) 111 | x = self.lastact(x) 112 | x = self.avgpool(x) 113 | x = x.view(x.size(0), -1) 114 | return self.classifier(x, eye) 115 | 116 | def preresnet20(num_classes=10): 117 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 118 | Args: 119 | num_classes (uint): number of classes 120 | """ 121 | model = CifarPreResNet(ResNetBasicblock, 20, num_classes) 122 | return model 123 | 124 | def preresnet32(num_classes=10): 125 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 126 | Args: 127 | num_classes (uint): number of classes 128 | """ 129 | model = CifarPreResNet(ResNetBasicblock, 32, num_classes) 130 | return model 131 | 132 | def preresnet44(num_classes=10): 133 | """Constructs a ResNet-44 model for CIFAR-10 (by default) 134 | Args: 135 | num_classes (uint): number of classes 136 | """ 137 | model = CifarPreResNet(ResNetBasicblock, 44, num_classes) 138 | return model 139 | 140 | def preresnet56(num_classes=10): 141 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 142 | Args: 143 | num_classes (uint): number of classes 144 | """ 145 | model = CifarPreResNet(ResNetBasicblock, 56, num_classes) 146 | return model 147 | 148 | def preresnet110(num_classes=10): 149 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 150 | Args: 151 | num_classes (uint): number of classes 152 | """ 153 | model = CifarPreResNet(ResNetBasicblock, 110, num_classes) 154 | return model 155 | 156 | def preresnet164(num_classes=10, Ddim=4): 157 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 158 | Args: 159 | num_classes (uint): number of classes 160 | """ 161 | model = CifarPreResNet(ResNetBasicblock, 164, num_classes, Ddim) 162 | return model 163 | -------------------------------------------------------------------------------- /models/preresnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/preresnet.pyc -------------------------------------------------------------------------------- /models/res_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DownsampleA(nn.Module): 5 | 6 | def __init__(self, nIn, nOut, stride): 7 | super(DownsampleA, self).__init__() 8 | assert stride == 2 9 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 10 | 11 | def forward(self, x): 12 | x = self.avg(x) 13 | return torch.cat((x, x.mul(0)), 1) 14 | 15 | class DownsampleC(nn.Module): 16 | 17 | def __init__(self, nIn, nOut, stride): 18 | super(DownsampleC, 
--------------------------------------------------------------------------------
/models/preresnet.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/preresnet.pyc
--------------------------------------------------------------------------------
/models/res_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class DownsampleA(nn.Module):
5 | 
6 |     def __init__(self, nIn, nOut, stride):
7 |         super(DownsampleA, self).__init__()
8 |         assert stride == 2
9 |         self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
10 | 
11 |     def forward(self, x):
12 |         x = self.avg(x)
13 |         return torch.cat((x, x.mul(0)), 1)
14 | 
15 | class DownsampleC(nn.Module):
16 | 
17 |     def __init__(self, nIn, nOut, stride):
18 |         super(DownsampleC, self).__init__()
19 |         assert stride != 1 or nIn != nOut
20 |         self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
21 | 
22 |     def forward(self, x):
23 |         x = self.conv(x)
24 |         return x
25 | 
26 | class DownsampleD(nn.Module):
27 | 
28 |     def __init__(self, nIn, nOut, stride):
29 |         super(DownsampleD, self).__init__()
30 |         assert stride == 2
31 |         self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
32 |         self.bn = nn.BatchNorm2d(nOut)
33 | 
34 |     def forward(self, x):
35 |         x = self.conv(x)
36 |         x = self.bn(x)
37 |         return x
38 | 
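
DownsampleA implements the zero-padding shortcut (option A): a stride-2 average pool halves the spatial size, and concatenating a zeroed copy doubles the channel count. A shape check:

import torch
from models.res_utils import DownsampleA

shortcut = DownsampleA(nIn=16, nOut=32, stride=2)
x = torch.randn(1, 16, 32, 32)
y = shortcut(x)
print(y.size())                 # torch.Size([1, 32, 16, 16])
print(y[:, 16:].abs().sum())    # tensor(0.) -- the new channels are all zeros
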
--------------------------------------------------------------------------------
/models/res_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/res_utils.pyc
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from .res_utils import DownsampleA, DownsampleC, DownsampleD
6 | import math
7 | import ops
8 | 
9 | class ResNetBasicblock(nn.Module):
10 |     expansion = 1
11 |     """
12 |     ResNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
13 |     """
14 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
15 |         super(ResNetBasicblock, self).__init__()
16 | 
17 |         self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |         self.bn_a = nn.BatchNorm2d(planes)
19 | 
20 |         self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21 |         self.bn_b = nn.BatchNorm2d(planes)
22 | 
23 |         self.downsample = downsample
24 | 
25 |     def forward(self, x):
26 |         residual = x
27 | 
28 |         basicblock = self.conv_a(x)
29 |         basicblock = self.bn_a(basicblock)
30 |         basicblock = F.relu(basicblock, inplace=True)
31 | 
32 |         basicblock = self.conv_b(basicblock)
33 |         basicblock = self.bn_b(basicblock)
34 | 
35 |         if self.downsample is not None:
36 |             residual = self.downsample(x)
37 | 
38 |         return F.relu(residual + basicblock, inplace=True)
39 | 
40 | class CifarResNet(nn.Module):
41 |     """
42 |     ResNet optimized for the Cifar dataset, as specified in
43 |     https://arxiv.org/abs/1512.03385.pdf
44 |     """
45 |     def __init__(self, block, depth, num_classes, Ddim):
46 |         """ Constructor
47 |         Args:
48 |             depth: number of layers.
49 |             num_classes: number of classes
50 |             Ddim: dimension of each capsule subspace
51 |         """
52 |         super(CifarResNet, self).__init__()
53 | 
54 |         #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
55 |         assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
56 |         layer_blocks = (depth - 2) // 6
57 |         print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))
58 | 
59 |         self.num_classes = num_classes
60 |         self.Ddim = Ddim
61 | 
62 |         self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
63 |         self.bn_1 = nn.BatchNorm2d(16)
64 | 
65 |         self.inplanes = 16
66 |         self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
67 |         self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
68 |         self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
69 |         self.avgpool = nn.AvgPool2d(8)
70 |         self.classifier = ops.LinearCapsPro(64*block.expansion, num_classes, Ddim)
71 |         # self.classifier = nn.Linear(64*block.expansion, num_classes)
72 | 
73 |         for m in self.modules():
74 |             if isinstance(m, nn.Conv2d):
75 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
76 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
77 |                 #m.bias.data.zero_()
78 |             elif isinstance(m, nn.BatchNorm2d):
79 |                 m.weight.data.fill_(1)
80 |                 m.bias.data.zero_()
81 |             elif isinstance(m, nn.Linear):
82 |                 init.kaiming_normal(m.weight)
83 |                 m.bias.data.zero_()
84 | 
85 |     def _make_layer(self, block, planes, blocks, stride=1):
86 |         downsample = None
87 |         if stride != 1 or self.inplanes != planes * block.expansion:
88 |             downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
89 | 
90 |         layers = []
91 |         layers.append(block(self.inplanes, planes, stride, downsample))
92 |         self.inplanes = planes * block.expansion
93 |         for i in range(1, blocks):
94 |             layers.append(block(self.inplanes, planes))
95 | 
96 |         return nn.Sequential(*layers)
97 | 
98 |     def forward(self, x, eye):
99 |         x = self.conv_1_3x3(x)
100 |         x = F.relu(self.bn_1(x), inplace=True)
101 |         x = self.stage_1(x)
102 |         x = self.stage_2(x)
103 |         x = self.stage_3(x)
104 |         x = self.avgpool(x)
105 |         x = x.view(x.size(0), -1)
106 |         return self.classifier(x, eye)
107 | 
108 | class CifarResNet_valid(nn.Module):
109 |     """
110 |     ResNet optimized for the Cifar dataset, as specified in
111 |     https://arxiv.org/abs/1512.03385.pdf
112 |     """
113 |     def __init__(self, block, depth, num_classes, Ddim):
114 |         """ Constructor
115 |         Args:
116 |             depth: number of layers.
117 |             num_classes: number of classes
118 |             Ddim: dimension of each capsule subspace
119 |         """
120 |         super(CifarResNet_valid, self).__init__()
121 | 
122 |         #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
123 |         assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
124 |         layer_blocks = (depth - 2) // 6
125 |         print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))
126 | 
127 |         self.num_classes = num_classes
128 |         self.Ddim = Ddim
129 | 
130 |         self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
131 |         self.bn_1 = nn.BatchNorm2d(16)
132 | 
133 |         self.inplanes = 16
134 |         self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
135 |         self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
136 |         self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
137 |         self.avgpool = nn.AvgPool2d(8)
138 |         self.classifier = ops.LinearCapsPro(64*block.expansion, num_classes, Ddim)
139 |         # self.classifier = nn.Linear(64*block.expansion, num_classes)
140 | 
141 |         for m in self.modules():
142 |             if isinstance(m, nn.Conv2d):
143 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
144 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
145 |                 #m.bias.data.zero_()
146 |             elif isinstance(m, nn.BatchNorm2d):
147 |                 m.weight.data.fill_(1)
148 |                 m.bias.data.zero_()
149 |             elif isinstance(m, nn.Linear):
150 |                 init.kaiming_normal(m.weight)
151 |                 m.bias.data.zero_()
152 | 
153 |     def _make_layer(self, block, planes, blocks, stride=1):
154 |         downsample = None
155 |         if stride != 1 or self.inplanes != planes * block.expansion:
156 |             downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
157 | 
158 |         layers = []
159 |         layers.append(block(self.inplanes, planes, stride, downsample))
160 |         self.inplanes = planes * block.expansion
161 |         for i in range(1, blocks):
162 |             layers.append(block(self.inplanes, planes))
163 | 
164 |         return nn.Sequential(*layers)
165 | 
166 |     def forward(self, x, eye):
167 |         x = self.conv_1_3x3(x)
168 |         x = F.relu(self.bn_1(x), inplace=True)
169 |         x = self.stage_1(x)
170 |         x = self.stage_2(x)
171 |         x = self.stage_3(x)
172 |         x = self.avgpool(x)
173 |         x = x.view(x.size(0), -1)
174 |         return x, self.classifier(x, eye)
175 | 
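
The _valid variant differs only in its forward, which also returns the flattened 64-d feature so it can be dumped for visualization (this is the architecture in the ./view log). A sketch, with the same `eye` assumption as above:

import torch
from models.resnet import resnet110_valid

model = resnet110_valid(num_classes=10, Ddim=2)   # Ddim=2 matches view/log_seed_7102.txt
eye = torch.eye(2)                                # assumed
x = torch.randn(4, 3, 32, 32)
features, scores = model(x, eye)
print(features.size())                            # torch.Size([4, 64])
print(scores.size())                              # expected (4, 10)
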
176 | def resnet20(num_classes=10, Ddim=4):
177 |     """Constructs a ResNet-20 model for CIFAR-10 (by default)
178 |     Args:
179 |         num_classes (uint): number of classes
180 |     """
181 |     model = CifarResNet(ResNetBasicblock, 20, num_classes, Ddim)
182 |     return model
183 | 
184 | def resnet32(num_classes=10, Ddim=4):
185 |     """Constructs a ResNet-32 model for CIFAR-10 (by default)
186 |     Args:
187 |         num_classes (uint): number of classes
188 |     """
189 |     model = CifarResNet(ResNetBasicblock, 32, num_classes, Ddim)
190 |     return model
191 | 
192 | def resnet44(num_classes=10, Ddim=4):
193 |     """Constructs a ResNet-44 model for CIFAR-10 (by default)
194 |     Args:
195 |         num_classes (uint): number of classes
196 |     """
197 |     model = CifarResNet(ResNetBasicblock, 44, num_classes, Ddim)
198 |     return model
199 | 
200 | def resnet56(num_classes=10, Ddim=4):
201 |     """Constructs a ResNet-56 model for CIFAR-10 (by default)
202 |     Args:
203 |         num_classes (uint): number of classes
204 |     """
205 |     model = CifarResNet(ResNetBasicblock, 56, num_classes, Ddim)
206 |     return model
207 | 
208 | def resnet110(num_classes=10, Ddim=4):
209 |     """Constructs a ResNet-110 model for CIFAR-10 (by default)
210 |     Args:
211 |         num_classes (uint): number of classes
212 |     """
213 |     model = CifarResNet(ResNetBasicblock, 110, num_classes, Ddim)
214 |     return model
215 | 
216 | def resnet110_valid(num_classes=10, Ddim=4):
217 |     """Constructs a ResNet-110 model for CIFAR-10 (by default)
218 |     Args:
219 |         num_classes (uint): number of classes
220 |     """
221 |     model = CifarResNet_valid(ResNetBasicblock, 110, num_classes, Ddim)
222 |     return model
223 | 
--------------------------------------------------------------------------------
/models/resnet.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/resnet.pyc
--------------------------------------------------------------------------------
/models/resnet_mod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | from .res_utils import DownsampleA, DownsampleC, DownsampleD
6 | import math
7 | 
8 | 
9 | class ResNetBasicblock(nn.Module):
10 |     expansion = 1
11 |     """
12 |     ResNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
13 |     """
14 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
15 |         super(ResNetBasicblock, self).__init__()
16 | 
17 |         self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |         self.bn_a = nn.BatchNorm2d(planes)
19 | 
20 |         self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21 |         self.bn_b = nn.BatchNorm2d(planes)
22 | 
23 |         self.downsample = downsample
24 | 
25 |     def forward(self, x):
26 |         if isinstance(x, list):
27 |             x, is_list, features = x[0], True, x[1:]
28 |         else:
29 |             is_list, features = False, None
30 |         residual = x
31 | 
32 |         conv_a = self.conv_a(x)
33 |         bn_a = self.bn_a(conv_a)
34 |         relu_a = F.relu(bn_a, inplace=True)
35 | 
36 |         conv_b = self.conv_b(relu_a)
37 |         bn_b = self.bn_b(conv_b)
38 | 
39 |         if self.downsample is not None:
40 |             residual = self.downsample(x)
41 | 
42 |         output = F.relu(residual + bn_b, inplace=True)
43 | 
44 |         if is_list:
45 |             return [output] + features + [bn_a, bn_b]
46 |         else:
47 |             return output
48 | 
49 | class CifarResNet(nn.Module):
50 |     """
51 |     ResNet optimized for the Cifar dataset, as specified in
52 |     https://arxiv.org/abs/1512.03385.pdf
53 |     """
54 |     def __init__(self, block, depth, num_classes):
55 |         """ Constructor
56 |         Args:
57 |             depth: number of layers.
58 |             num_classes: number of classes
59 |             base_width: base width
60 |         """
61 |         super(CifarResNet, self).__init__()
62 | 
63 |         #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
64 |         assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
65 |         layer_blocks = (depth - 2) // 6
66 |         print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks))
67 | 
68 |         self.num_classes = num_classes
69 | 
70 |         self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
71 |         self.bn_1 = nn.BatchNorm2d(16)
72 | 
73 |         self.inplanes = 16
74 |         self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
75 |         self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
76 |         self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
77 |         self.avgpool = nn.AvgPool2d(8)
78 |         self.classifier = nn.Linear(64*block.expansion, num_classes)
79 | 
80 |         for m in self.modules():
81 |             if isinstance(m, nn.Conv2d):
82 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
83 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
84 |                 #m.bias.data.zero_()
85 |             elif isinstance(m, nn.BatchNorm2d):
86 |                 m.weight.data.fill_(1)
87 |                 m.bias.data.zero_()
88 |             elif isinstance(m, nn.Linear):
89 |                 init.kaiming_normal(m.weight)
90 |                 m.bias.data.zero_()
91 | 
92 |     def _make_layer(self, block, planes, blocks, stride=1):
93 |         downsample = None
94 |         if stride != 1 or self.inplanes != planes * block.expansion:
95 |             downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
96 | 
97 |         layers = []
98 |         layers.append(block(self.inplanes, planes, stride, downsample))
99 |         self.inplanes = planes * block.expansion
100 |         for i in range(1, blocks):
101 |             layers.append(block(self.inplanes, planes))
102 | 
103 |         return nn.Sequential(*layers)
104 | 
105 |     def forward(self, x):
106 |         if isinstance(x, list):
107 |             assert len(x) == 1, 'The length of inputs must be one vs {}'.format(len(x))
108 |             x, is_list = x[0], True
109 |         else:
110 |             x, is_list = x, False
111 |         x = self.conv_1_3x3(x)
112 |         x = F.relu(self.bn_1(x), inplace=True)
113 | 
114 |         if is_list: x = [x]
115 |         x = self.stage_1(x)
116 |         x = self.stage_2(x)
117 |         x = self.stage_3(x)
118 |         if is_list:
119 |             x, features = x[0], x[1:]
120 |         else:
121 |             features = None
122 |         x = self.avgpool(x)
123 |         x = x.view(x.size(0), -1)
124 |         cls = self.classifier(x)
125 | 
126 |         if is_list: return cls, features
127 |         else: return cls
128 | 
129 | def resnet_mod20(num_classes=10):
130 |     """Constructs a ResNet-20 model for CIFAR-10 (by default)
131 |     Args:
132 |         num_classes (uint): number of classes
133 |     """
134 |     model = CifarResNet(ResNetBasicblock, 20, num_classes)
135 |     return model
136 | 
137 | def resnet_mod32(num_classes=10):
138 |     """Constructs a ResNet-32 model for CIFAR-10 (by default)
139 |     Args:
140 |         num_classes (uint): number of classes
141 |     """
142 |     model = CifarResNet(ResNetBasicblock, 32, num_classes)
143 |     return model
144 | 
145 | def resnet_mod44(num_classes=10):
146 |     """Constructs a ResNet-44 model for CIFAR-10 (by default)
147 |     Args:
148 |         num_classes (uint): number of classes
149 |     """
150 |     model = CifarResNet(ResNetBasicblock, 44, num_classes)
151 |     return model
152 | 
153 | def resnet_mod56(num_classes=10):
154 |     """Constructs a ResNet-56 model for CIFAR-10 (by default)
155 |     Args:
156 |         num_classes (uint): number of classes
157 |     """
158 |     model = CifarResNet(ResNetBasicblock, 56, num_classes)
159 |     return model
160 | 
161 | def resnet_mod110(num_classes=10):
162 |     """Constructs a ResNet-110 model for CIFAR-10 (by default)
163 |     Args:
164 |         num_classes (uint): number of classes
165 |     """
166 |     model = CifarResNet(ResNetBasicblock, 110, num_classes)
167 |     return model
168 | 
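
The list-input mode is what distinguishes resnet_mod: wrapping the input in a list makes every basic block append its bn_a/bn_b pre-activation maps, and the network returns them alongside the class scores. A sketch:

import torch
from models.resnet_mod import resnet_mod20

model = resnet_mod20(num_classes=10)
x = torch.randn(2, 3, 32, 32)
logits = model(x)                      # plain tensor input: just class scores
cls, features = model([x])             # list input: scores plus feature taps
print(logits.size())                   # torch.Size([2, 10])
print(len(features))                   # 18 -- two taps for each of the 9 blocks
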
--------------------------------------------------------------------------------
/models/resnet_mod.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/resnet_mod.pyc
--------------------------------------------------------------------------------
/models/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torch.nn import init
4 | import math
5 | import ops
6 | 
7 | class ResNeXtBottleneck(nn.Module):
8 |     expansion = 4
9 |     """
10 |     ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
11 |     """
12 |     def __init__(self, inplanes, planes, cardinality, base_width, stride=1, downsample=None):
13 |         super(ResNeXtBottleneck, self).__init__()
14 | 
15 |         D = int(math.floor(planes * (base_width/64.0)))
16 |         C = cardinality
17 | 
18 |         self.conv_reduce = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
19 |         self.bn_reduce = nn.BatchNorm2d(D*C)
20 | 
21 |         self.conv_conv = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
22 |         self.bn = nn.BatchNorm2d(D*C)
23 | 
24 |         self.conv_expand = nn.Conv2d(D*C, planes*4, kernel_size=1, stride=1, padding=0, bias=False)
25 |         self.bn_expand = nn.BatchNorm2d(planes*4)
26 | 
27 |         self.downsample = downsample
28 | 
29 |     def forward(self, x):
30 |         residual = x
31 | 
32 |         bottleneck = self.conv_reduce(x)
33 |         bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
34 | 
35 |         bottleneck = self.conv_conv(bottleneck)
36 |         bottleneck = F.relu(self.bn(bottleneck), inplace=True)
37 | 
38 |         bottleneck = self.conv_expand(bottleneck)
39 |         bottleneck = self.bn_expand(bottleneck)
40 | 
41 |         if self.downsample is not None:
42 |             residual = self.downsample(x)
43 | 
44 |         return F.relu(residual + bottleneck, inplace=True)
45 | 
46 | 
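
The bottleneck's grouped 3x3 convolution runs at width D*C, with D = floor(planes * base_width / 64) and C = cardinality. For the first stage of resnext29_8_64 (planes=64, base_width=64, C=8) that works out as:

# Worked width arithmetic for resnext29_8_64, stage 1.
import math

planes, base_width, cardinality = 64, 64, 8
D = int(math.floor(planes * (base_width / 64.0)))   # 64
width = D * cardinality                             # 512 channels in the grouped 3x3 conv
out = planes * 4                                    # 256 channels after conv_expand
print(D, width, out)                                # 64 512 256
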
47 | class CifarResNeXt(nn.Module):
48 |     """
49 |     ResNext optimized for the Cifar dataset, as specified in
50 |     https://arxiv.org/pdf/1611.05431.pdf
51 |     """
52 |     def __init__(self, block, depth, cardinality, base_width, num_classes, Ddim):
53 |         super(CifarResNeXt, self).__init__()
54 | 
55 |         #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
56 |         assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
57 |         layer_blocks = (depth - 2) // 9
58 | 
59 |         self.cardinality = cardinality
60 |         self.base_width = base_width
61 |         self.num_classes = num_classes
62 |         self.Ddim = Ddim
63 | 
64 |         self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
65 |         self.bn_1 = nn.BatchNorm2d(64)
66 | 
67 |         self.inplanes = 64
68 |         self.stage_1 = self._make_layer(block, 64 , layer_blocks, 1)
69 |         self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
70 |         self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
71 |         self.avgpool = nn.AvgPool2d(8)
72 |         # self.classifier = nn.Linear(256*block.expansion, num_classes)
73 |         self.classifier = ops.LinearCapsPro(256*block.expansion, num_classes, Ddim)
74 | 
75 |         for m in self.modules():
76 |             if isinstance(m, nn.Conv2d):
77 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
78 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
79 |             elif isinstance(m, nn.BatchNorm2d):
80 |                 m.weight.data.fill_(1)
81 |                 m.bias.data.zero_()
82 |             elif isinstance(m, nn.Linear):
83 |                 init.kaiming_normal(m.weight)
84 |                 m.bias.data.zero_()
85 | 
86 |     def _make_layer(self, block, planes, blocks, stride=1):
87 |         downsample = None
88 |         if stride != 1 or self.inplanes != planes * block.expansion:
89 |             downsample = nn.Sequential(
90 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
91 |                           kernel_size=1, stride=stride, bias=False),
92 |                 nn.BatchNorm2d(planes * block.expansion),
93 |             )
94 | 
95 |         layers = []
96 |         layers.append(block(self.inplanes, planes, self.cardinality, self.base_width, stride, downsample))
97 |         self.inplanes = planes * block.expansion
98 |         for i in range(1, blocks):
99 |             layers.append(block(self.inplanes, planes, self.cardinality, self.base_width))
100 | 
101 |         return nn.Sequential(*layers)
102 | 
103 |     def forward(self, x, eye):
104 |         x = self.conv_1_3x3(x)
105 |         x = F.relu(self.bn_1(x), inplace=True)
106 |         x = self.stage_1(x)
107 |         x = self.stage_2(x)
108 |         x = self.stage_3(x)
109 |         x = self.avgpool(x)
110 |         x = x.view(x.size(0), -1)
111 |         return self.classifier(x, eye)
112 | 
113 | def resnext29_16_64(num_classes=10, Ddim=4):
114 |     """Constructs a ResNeXt-29, 16*64d model for CIFAR-10 (by default)
115 | 
116 |     Args:
117 |         num_classes (uint): number of classes
118 |     """
119 |     model = CifarResNeXt(ResNeXtBottleneck, 29, 16, 64, num_classes, Ddim)
120 |     return model
121 | 
122 | def resnext29_8_64(num_classes=10, Ddim=4):
123 |     """Constructs a ResNeXt-29, 8*64d model for CIFAR-10 (by default)
124 | 
125 |     Args:
126 |         num_classes (uint): number of classes
127 |     """
128 |     model = CifarResNeXt(ResNeXtBottleneck, 29, 8, 64, num_classes, Ddim)
129 |     return model
130 | 
--------------------------------------------------------------------------------
/models/resnext.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/models/resnext.pyc
--------------------------------------------------------------------------------
/pretrained_model/checkpoint.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/pretrained_model/checkpoint.pth.tar
--------------------------------------------------------------------------------
/pretrained_model/curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/pretrained_model/curve.png
--------------------------------------------------------------------------------
/pretrained_model/model_best.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/pretrained_model/model_best.pth.tar
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys, time, random
2 | import numpy as np
3 | import matplotlib
4 | matplotlib.use('agg')
5 | import matplotlib.pyplot as plt
6 | 
7 | class AverageMeter(object):
8 |     """Computes and stores the average and current value"""
9 |     def __init__(self):
10 |         self.reset()
11 | 
12 |     def reset(self):
13 |         self.val = 0
14 |         self.avg = 0
15 |         self.sum = 0
16 |         self.count = 0
17 | 
18 |     def update(self, val, n=1):
19 |         self.val = val
20 |         self.sum += val * n
21 |         self.count += n
22 |         self.avg = self.sum / self.count
23 | 
24 | 
25 | class RecorderMeter(object):
26 |     """Computes and stores the minimum loss value and its epoch index"""
27 |     def __init__(self, total_epoch):
28 |         self.reset(total_epoch)
29 | 
30 |     def reset(self, total_epoch):
31 |         assert total_epoch > 0
32 |         self.total_epoch = total_epoch
33 |         self.current_epoch = 0
34 |         self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
35 |         self.epoch_losses = self.epoch_losses - 1
36 | 
37 |         self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
38 |         self.epoch_accuracy= self.epoch_accuracy
39 | 
40 |     def update(self, idx, train_loss, train_acc, val_loss, val_acc):
41 |         assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
42 |         self.epoch_losses [idx, 0] = train_loss
43 |         self.epoch_losses [idx, 1] = val_loss
44 |         self.epoch_accuracy[idx, 0] = train_acc
45 |         self.epoch_accuracy[idx, 1] = val_acc
46 |         self.current_epoch = idx + 1
47 |         return self.max_accuracy(False) == val_acc
48 | 
49 |     def max_accuracy(self, istrain):
50 |         if self.current_epoch <= 0: return 0
51 |         if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
52 |         else: return self.epoch_accuracy[:self.current_epoch, 1].max()
53 | 
54 |     def plot_curve(self, save_path):
55 |         title = 'the accuracy/loss curve of train/val'
56 |         dpi = 80
57 |         width, height = 1200, 800
58 |         legend_fontsize = 10
59 |         scale_distance = 48.8
60 |         figsize = width / float(dpi), height / float(dpi)
61 | 
62 |         fig = plt.figure(figsize=figsize)
63 |         x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
64 |         y_axis = np.zeros(self.total_epoch)
65 | 
66 |         plt.xlim(0, self.total_epoch)
67 |         plt.ylim(0, 100)
68 |         interval_y = 5
69 |         interval_x = 5
70 |         plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
71 |         plt.yticks(np.arange(0, 100 + interval_y, interval_y))
72 |         plt.grid()
73 |         plt.title(title, fontsize=20)
74 |         plt.xlabel('the training epoch', fontsize=16)
75 |         plt.ylabel('accuracy', fontsize=16)
76 | 
77 |         y_axis[:] = self.epoch_accuracy[:, 0]
78 |         plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
79 |         plt.legend(loc=4, fontsize=legend_fontsize)
80 | 
81 |         y_axis[:] = self.epoch_accuracy[:, 1]
82 |         plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
83 |         plt.legend(loc=4, fontsize=legend_fontsize)
84 | 
85 | 
86 |         y_axis[:] = self.epoch_losses[:, 0]
87 |         plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
88 |         plt.legend(loc=4, fontsize=legend_fontsize)
89 | 
90 |         y_axis[:] = self.epoch_losses[:, 1]
91 |         plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
92 |         plt.legend(loc=4, fontsize=legend_fontsize)
93 | 
94 |         if save_path is not None:
95 |             fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
96 |             print ('---- save figure {} into {}'.format(title, save_path))
97 |         plt.close(fig)
98 | 
99 | 
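
Typical use of the two meters: RecorderMeter.update returns True exactly when the new validation accuracy is the best so far, which a training loop can use to decide when to save model_best.pth.tar. Values below are made up:

from utils import AverageMeter, RecorderMeter

losses = AverageMeter()
losses.update(0.52, n=128)      # batch loss, batch size
losses.update(0.40, n=128)
print(losses.avg)               # ~0.46, the sample-weighted running mean

recorder = RecorderMeter(total_epoch=2)
is_best = recorder.update(0, train_loss=0.5, train_acc=81.0, val_loss=0.6, val_acc=78.5)
print(is_best, recorder.max_accuracy(istrain=False))   # True 78.5
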
100 | def time_string():
101 |     ISOTIMEFORMAT='%Y-%m-%d %X'
102 |     string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
103 |     return string
104 | 
105 | def convert_secs2time(epoch_time):
106 |     need_hour = int(epoch_time / 3600)
107 |     need_mins = int((epoch_time - 3600*need_hour) / 60)
108 |     need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
109 |     return need_hour, need_mins, need_secs
110 | 
111 | def time_file_str():
112 |     ISOTIMEFORMAT='%Y-%m-%d'
113 |     string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
114 |     return string + '-{}'.format(random.randint(1, 10000))
115 | 
--------------------------------------------------------------------------------
/view/class0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class0.png
--------------------------------------------------------------------------------
/view/class1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class1.png
--------------------------------------------------------------------------------
/view/class2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class2.png
--------------------------------------------------------------------------------
/view/class3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class3.png
--------------------------------------------------------------------------------
/view/class4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class4.png
--------------------------------------------------------------------------------
/view/class5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class5.png
--------------------------------------------------------------------------------
/view/class6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class6.png
--------------------------------------------------------------------------------
/view/class7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class7.png
--------------------------------------------------------------------------------
/view/class8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class8.png
--------------------------------------------------------------------------------
/view/class9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/class9.png
--------------------------------------------------------------------------------
/view/features.npy:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/features.npy -------------------------------------------------------------------------------- /view/labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/labels.npy -------------------------------------------------------------------------------- /view/log_seed_7102.txt: -------------------------------------------------------------------------------- 1 | save path : ./view 2 | {'manualSeed': 7102, 'workers': 2, 'data_path': '/home/liheng/cifar10', 'batch_size': 128, 'dataset': 'cifar10', 'ngpu': 1, 'use_cuda': True, 'save_path': './view', 'resume': '/home/liheng/deep_capsule/cifar10/v6_resnet110_original/110epoch500D2/checkpoint.pth.tar', 'Ddim': 2, 'arch': 'resnet110_valid'} 3 | Random Seed: 7102 4 | python version : 2.7.13 |Anaconda custom (64-bit)| (default, Dec 20 2016, 23:09:15) [GCC 4.4.7 20120313 (Red Hat 4.4.7-1)] 5 | torch version : 0.3.1.post3 6 | cudnn version : 7102 7 | => creating model 'resnet110_valid' 8 | => network : 9 | CifarResNet_valid( 10 | (conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 11 | (bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 12 | (stage_1): Sequential( 13 | (0): ResNetBasicblock( 14 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 15 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 16 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 17 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 18 | ) 19 | (1): ResNetBasicblock( 20 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 21 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 22 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 23 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 24 | ) 25 | (2): ResNetBasicblock( 26 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 27 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 28 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 29 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 30 | ) 31 | (3): ResNetBasicblock( 32 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 33 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 34 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 35 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 36 | ) 37 | (4): ResNetBasicblock( 38 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 39 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 40 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 41 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 42 | ) 43 | (5): ResNetBasicblock( 44 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 45 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 46 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1), bias=False) 47 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 48 | ) 49 | (6): ResNetBasicblock( 50 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 51 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 52 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 53 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 54 | ) 55 | (7): ResNetBasicblock( 56 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 57 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 58 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 59 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 60 | ) 61 | (8): ResNetBasicblock( 62 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 63 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 64 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 65 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 66 | ) 67 | (9): ResNetBasicblock( 68 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 69 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 70 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 71 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 72 | ) 73 | (10): ResNetBasicblock( 74 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 75 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 76 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 77 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 78 | ) 79 | (11): ResNetBasicblock( 80 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 81 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 82 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 83 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 84 | ) 85 | (12): ResNetBasicblock( 86 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 87 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 88 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 89 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 90 | ) 91 | (13): ResNetBasicblock( 92 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 93 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 94 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 95 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 96 | ) 97 | (14): ResNetBasicblock( 98 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 99 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 100 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 101 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 102 | ) 103 | (15): ResNetBasicblock( 104 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 105 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 106 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) 107 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 108 | ) 109 | (16): ResNetBasicblock( 110 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 111 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 112 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 113 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 114 | ) 115 | (17): ResNetBasicblock( 116 | (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 117 | (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 118 | (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 119 | (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True) 120 | ) 121 | ) 122 | (stage_2): Sequential( 123 | (0): ResNetBasicblock( 124 | (conv_a): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 125 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 126 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 127 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 128 | (downsample): DownsampleA( 129 | (avg): AvgPool2d(kernel_size=1, stride=2, padding=0, ceil_mode=False, count_include_pad=True) 130 | ) 131 | ) 132 | (1): ResNetBasicblock( 133 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 134 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 135 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 136 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 137 | ) 138 | (2): ResNetBasicblock( 139 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 140 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 141 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 142 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 143 | ) 144 | (3): ResNetBasicblock( 145 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 146 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 147 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 148 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 149 | ) 150 | (4): ResNetBasicblock( 151 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 152 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 153 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 154 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 155 | ) 156 | (5): ResNetBasicblock( 157 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 158 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 159 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 160 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 161 | ) 162 | (6): ResNetBasicblock( 163 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 164 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 165 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 166 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 167 | ) 168 | (7): 
ResNetBasicblock( 169 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 170 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 171 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 172 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 173 | ) 174 | (8): ResNetBasicblock( 175 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 176 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 177 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 178 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 179 | ) 180 | (9): ResNetBasicblock( 181 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 182 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 183 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 184 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 185 | ) 186 | (10): ResNetBasicblock( 187 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 188 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 189 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 190 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 191 | ) 192 | (11): ResNetBasicblock( 193 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 194 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 195 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 196 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 197 | ) 198 | (12): ResNetBasicblock( 199 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 200 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 201 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 202 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 203 | ) 204 | (13): ResNetBasicblock( 205 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 206 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 207 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 208 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 209 | ) 210 | (14): ResNetBasicblock( 211 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 212 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 213 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 214 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 215 | ) 216 | (15): ResNetBasicblock( 217 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 218 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 219 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 220 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 221 | ) 222 | (16): ResNetBasicblock( 223 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 224 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 225 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 226 | (bn_b): 
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 227 | ) 228 | (17): ResNetBasicblock( 229 | (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 230 | (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 231 | (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 232 | (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True) 233 | ) 234 | ) 235 | (stage_3): Sequential( 236 | (0): ResNetBasicblock( 237 | (conv_a): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 238 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 239 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 240 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 241 | (downsample): DownsampleA( 242 | (avg): AvgPool2d(kernel_size=1, stride=2, padding=0, ceil_mode=False, count_include_pad=True) 243 | ) 244 | ) 245 | (1): ResNetBasicblock( 246 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 247 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 248 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 249 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 250 | ) 251 | (2): ResNetBasicblock( 252 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 253 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 254 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 255 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 256 | ) 257 | (3): ResNetBasicblock( 258 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 259 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 260 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 261 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 262 | ) 263 | (4): ResNetBasicblock( 264 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 265 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 266 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 267 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 268 | ) 269 | (5): ResNetBasicblock( 270 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 271 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 272 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 273 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 274 | ) 275 | (6): ResNetBasicblock( 276 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 277 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 278 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 279 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 280 | ) 281 | (7): ResNetBasicblock( 282 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 283 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 284 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 285 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 286 | ) 287 | (8): ResNetBasicblock( 288 | (conv_a): Conv2d(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 289 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 290 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 291 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 292 | ) 293 | (9): ResNetBasicblock( 294 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 295 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 296 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 297 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 298 | ) 299 | (10): ResNetBasicblock( 300 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 301 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 302 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 303 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 304 | ) 305 | (11): ResNetBasicblock( 306 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 307 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 308 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 309 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 310 | ) 311 | (12): ResNetBasicblock( 312 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 313 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 314 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 315 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 316 | ) 317 | (13): ResNetBasicblock( 318 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 319 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 320 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 321 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 322 | ) 323 | (14): ResNetBasicblock( 324 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 325 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 326 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 327 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 328 | ) 329 | (15): ResNetBasicblock( 330 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 331 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 332 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 333 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 334 | ) 335 | (16): ResNetBasicblock( 336 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 337 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 338 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 339 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 340 | ) 341 | (17): ResNetBasicblock( 342 | (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 343 | (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 344 | (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 345 | (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 346 
| )
347 |   )
348 |   (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0, ceil_mode=False, count_include_pad=True)
349 |   (classifier): LinearCapsPro(
350 |   )
351 | )
352 | => loading checkpoint '/home/liheng/deep_capsule/cifar10/v6_resnet110_original/110epoch500D2/checkpoint.pth.tar'
353 | => loaded checkpoint '/home/liheng/deep_capsule/cifar10/v6_resnet110_original/110epoch500D2/checkpoint.pth.tar' (epoch 500)
354 | 
--------------------------------------------------------------------------------
/view/view.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.linalg import fractional_matrix_power
3 | import matplotlib.pyplot as plt
4 | 
5 | features = np.load('features.npy')
6 | labels = np.load('labels.npy')
7 | weights = np.load('weights.npy')
8 | 
9 | for cls in range(10):
10 |     fig = plt.figure(cls+1)
11 | 
12 |     w = np.transpose(weights[cls*2:(cls+1)*2])
13 |     proj = fractional_matrix_power(np.matmul(np.transpose(w), w), -0.5)
14 |     proj = np.matmul(proj, np.transpose(w))
15 | 
16 |     for ii in range(500):
17 |         x = features[ii][:, None]
18 |         u = np.matmul(proj, x)
19 |         if labels[ii]==cls:
20 |             plt.plot(u[0], u[1], 'ro')
21 |         else:
22 |             plt.plot(u[0], u[1], 'g^')
23 |     plt.title('Capsule subspace '+str(cls))
24 |     ax = fig.add_subplot(1, 1, 1)
25 | 
26 |     ax.spines['left'].set_position(('data', 0.0))
27 |     ax.spines['bottom'].set_position(('data', 0.0))
28 |     ax.spines['right'].set_color('none')
29 |     ax.spines['top'].set_color('none')
30 |     ax.set_axisbelow(False)
31 | 
32 |     ax.yaxis.tick_left()
33 |     ax.xaxis.tick_bottom()
34 |     plt.show()
35 | 
36 | #w = np.transpose(weights[:2])
37 | #proj = fractional_matrix_power(np.matmul(np.transpose(w), w), -0.5)
38 | #proj = np.matmul(proj, np.transpose(w))
39 | #for ii in range(300):
40 | #    x = features[ii][:, None]
41 | #    u = np.matmul(proj, x)
42 | #    if labels[ii]==0:
43 | #        plt.plot(np.abs(u[0]), np.abs(u[1]), 'ro')
44 | #    else:
45 | #        plt.plot(np.abs(u[0]), np.abs(u[1]), 'g^')
46 | #    print(u)
47 | #    print(labels[ii])
48 | #plt.show()
--------------------------------------------------------------------------------
/view/weights.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maple-research-lab/CapProNet-Pytorch/f8affe045f93847b2ba5c1c16722f1032286a378/view/weights.npy
--------------------------------------------------------------------------------
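
The projection view.py builds for each class is P = (W^T W)^(-1/2) W^T, where the Ddim rows of weights[cls*Ddim:(cls+1)*Ddim] span that class's capsule subspace. The rows of P form an orthonormal basis of that subspace, so ||P x|| is the length of a feature's projection onto it. A small numerical check of that property (random data standing in for the saved arrays):

# Sanity check: proj @ proj.T should be the identity, i.e. the basis is orthonormal.
import numpy as np
from scipy.linalg import fractional_matrix_power

rng = np.random.RandomState(0)
w = rng.randn(64, 2)                                  # one class's capsule basis (Ddim=2)
proj = fractional_matrix_power(w.T @ w, -0.5) @ w.T   # shape (2, 64)
print(np.allclose(proj @ proj.T, np.eye(2)))          # True
x = rng.randn(64)
print(np.linalg.norm(proj @ x))                       # projection length onto the subspace
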