├── misc
│   ├── moe.pdf
│   ├── capsnet.pdf
│   ├── self_routing.png
│   ├── routing_by_agreement.png
│   └── neurips2019-self_routing-poster.pdf
├── .gitignore
├── models
│   ├── __init__.py
│   ├── smallnet.py
│   ├── convnet.py
│   └── resnet.py
├── requirements.txt
├── loss.py
├── main.py
├── utils.py
├── README.md
├── attack.py
├── config.py
├── data_loader.py
├── modules.py
├── trainer.py
└── norb.py
/misc/moe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/moe.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | 4 | ckpt/** 5 | data/** 6 | logs/** 7 | 8 | -------------------------------------------------------------------------------- /misc/capsnet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/capsnet.pdf -------------------------------------------------------------------------------- /misc/self_routing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/self_routing.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .convnet import * 3 | from .smallnet import * 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | torchvision==0.4.0 3 | tensorboardx==1.8 4 | scipy 5 | numpy 6 | tqdm -------------------------------------------------------------------------------- /misc/routing_by_agreement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/routing_by_agreement.png -------------------------------------------------------------------------------- /misc/neurips2019-self_routing-poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/neurips2019-self_routing-poster.pdf -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn.modules.loss import _Loss 6 | 7 | from utils import one_hot 8 | 9 | 10 | class DynamicRoutingLoss(nn.Module): 11 | def __init__(self): 12 | super(DynamicRoutingLoss, self).__init__() 13 | 14 | def forward(self, x, target): 15 | target = one_hot(target, x.shape[1]) 16 | 17 | left = F.relu(0.9 - x) ** 2 18 | right = F.relu(x - 0.1) ** 2 19 | 20 | margin_loss = target * left + 0.5 * (1. - target) * right
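# The expression above is the margin loss of Sabour et al. (2017),
# "Dynamic Routing Between Capsules": L_k = T_k * max(0, 0.9 - ||v_k||)^2
# + 0.5 * (1 - T_k) * max(0, ||v_k|| - 0.1)^2, where x holds the predicted
# per-class capsule lengths ||v_k|| and T_k is the one-hot target.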
21 | margin_loss = margin_loss.sum(dim=1).mean() 22 | return margin_loss 23 | 24 | 25 | class EmRoutingLoss(nn.Module): 26 | def __init__(self, max_epoch): 27 | super(EmRoutingLoss, self).__init__() 28 | self.margin_init = 0.2 29 | self.margin_step = 0.2 / max_epoch 30 | self.max_epoch = max_epoch 31 | 32 | def forward(self, x, target, epoch=None): 33 | if epoch is None: 34 | margin = 0.9 35 | else: 36 | margin = self.margin_init + self.margin_step * min(epoch, self.max_epoch) 37 | 38 | b, E = x.shape 39 | at = x.new_zeros(b) 40 | for i, lb in enumerate(target): 41 | at[i] = x[i][lb] 42 | at = at.view(b, 1).repeat(1, E) 43 | 44 | zeros = x.new_zeros(x.shape) 45 | loss = torch.max(margin - (at - x), zeros) 46 | loss = loss**2 47 | loss = loss.sum(dim=1).mean() 48 | return loss 49 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchvision import datasets, transforms 4 | 5 | from trainer import Trainer 6 | from config import get_config 7 | from utils import prepare_dirs 8 | from data_loader import get_test_loader, get_train_valid_loader, VIEWPOINT_EXPS 9 | 10 | 11 | torch.backends.cudnn.deterministic = True 12 | torch.backends.cudnn.benchmark = False 13 | 14 | 15 | def main(config): 16 | 17 | # ensure directories are set up 18 | prepare_dirs(config) 19 | 20 | # ensure reproducibility 21 | torch.manual_seed(config.random_seed) 22 | kwargs = {} 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed(config.random_seed) 25 | kwargs = {'num_workers': 4, 'pin_memory': False} 26 | 27 | # instantiate data loaders 28 | if config.is_train: 29 | data_loader = get_train_valid_loader( 30 | config.data_dir, config.dataset, config.batch_size, 31 | config.random_seed, config.exp, config.valid_size, 32 | config.shuffle, **kwargs 33 | ) 34 | else: 35 | data_loader = get_test_loader( 36 | config.data_dir, config.dataset, config.batch_size, config.exp, config.familiar, 37 | **kwargs 38 | ) 39 | 40 | # instantiate trainer 41 | trainer = Trainer(config, data_loader) 42 | 43 | if config.is_train: 44 | trainer.train() 45 | else: 46 | if config.attack: 47 | trainer.test_attack() 48 | else: 49 | trainer.test() 50 | 51 | if __name__ == '__main__': 52 | config, unparsed = get_config() 53 | main(config) 54 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | class AverageMeter(object): 9 | """ 10 | Computes and stores the average and 11 | current value.
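Typical usage: call update(val, n) once per mini-batch; avg then holds the running mean.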
12 | """ 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def prepare_dirs(config): 30 | for path in [config.data_dir, config.ckpt_dir, config.logs_dir]: 31 | if not os.path.exists(path): 32 | os.makedirs(path) 33 | 34 | def save_config(model_name, config): 35 | filename = model_name + '_params.json' 36 | param_path = os.path.join(config.ckpt_dir, filename) 37 | 38 | print("[*] Model Checkpoint Dir: {}".format(config.ckpt_dir)) 39 | print("[*] Param Path: {}".format(param_path)) 40 | 41 | with open(param_path, 'w') as fp: 42 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 43 | 44 | def one_hot(y, n_dims): 45 | scatter_dim = len(y.size()) 46 | y_tensor = y.view(*y.size(), -1) 47 | zeros = torch.zeros(*y.size(), n_dims).cuda() 48 | return zeros.scatter(scatter_dim, y_tensor, 1) 49 | 50 | # dynamic routing 51 | def squash(s, dim=-1): 52 | mag_sq = torch.sum(s**2, dim=dim, keepdim=True) 53 | mag = torch.sqrt(mag_sq) 54 | v = (mag_sq / (1.0 + mag_sq)) * (s / mag) 55 | return v 56 | 57 | def weights_init(m): 58 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 59 | nn.init.kaiming_uniform_(m.weight.data) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | nn.init.constant_(m.weight, 1) 62 | nn.init.constant_(m.bias, 0) 63 | 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SR-CapsNet 2 | 3 | PyTorch implementation of our NeurIPS 2019 paper [**Self-Routing Capsule Networks**](https://papers.nips.cc/paper/8982-self-routing-capsule-networks). 4 | 5 | 6 | [[poster]](https://github.com/coder3000/SR-CapsNet/blob/master/misc/neurips2019-self_routing-poster.pdf) 7 | 8 | ## Prerequisites 9 | - Python >= 3.5.2 10 | - CUDA >= 9.0 supported GPU 11 | 12 | Install required packages by: 13 | ``` 14 | pip3 install -r requirements.txt 15 | ``` 16 | 17 | 18 | ## Training 19 | To train a model for CIFAR-10 or SVHN, run: 20 | ``` 21 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --epochs=350 22 | python3 main.py --dataset=svhn --name=resnet_[routing_method] --epochs=200 23 | ``` 24 | 25 | `routing_method` should be one of `[avg, max, fc, dynamic_routing, em_routing, self_routing]`. This will modify the final layers of the base model accordingly. 26 | 27 | 28 | For SmallNORB, run: 29 | 30 | ``` 31 | python3 main.py --dataset=smallnorb --name=convnet_[routing_method] --epochs=200 --exp=elevation 32 | ``` 33 | 34 | Here `--exp` denotes which viewpoint the data should be split on. 35 | 36 | See `config.py` for more options and their descriptions.
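Checkpoints are written to `./ckpt` as `[name]_ckpt.pth.tar` (most recent) and `[name]_model_best.pth.tar` (best validation accuracy), with the weights stored under the `model_state` key. Below is a minimal sketch, not part of the original scripts, of loading a trained model outside `trainer.py`; it assumes a `resnet_self_routing` run with the default hyperparameters, and note that models trained with `nn.DataParallel` save their keys with a `module.` prefix:

```
import torch

from models import resnet20
from data_loader import DATASET_CONFIGS

# defaults from config.py: --planes=16 --num_caps=32 --caps_size=16 --depth=1
model = resnet20(16, DATASET_CONFIGS['cifar10'], 32, 16, 1, mode='SR')
state = torch.load('./ckpt/resnet_self_routing_model_best.pth.tar', map_location='cpu')
model.load_state_dict(state['model_state'])
model.eval()
```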
37 | 38 | ## Testing 39 | To test a model, simply run: 40 | 41 | ``` 42 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --is_train=False 43 | ``` 44 | 45 | You can perform adversarial attacks against a trained model by: 46 | ``` 47 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --is_train=False --attack=True --attack_type=bim --attack_eps=0.1 --targeted=False 48 | ``` 49 | 50 | For SmallNORB, you can test against novel viewpoints by: 51 | ``` 52 | python3 main.py --dataset=smallnorb --name=convnet_[routing_method] --is_train=False --familiar=False 53 | ``` 54 | 55 | 56 | ## Citation 57 | ``` 58 | @inproceedings{hahn2019, 59 | title={Self-Routing Capsule Networks}, 60 | author={Hahn, Taeyoung and Pyeon, Myeongjang and Kim, Gunhee}, 61 | booktitle={NeurIPS}, 62 | year={2019} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | import random 8 | 9 | random.seed(2019) 10 | 11 | class Attack(object): 12 | def __init__(self, net, criterion, attack_type, eps): 13 | self.net = net 14 | self.criterion = criterion 15 | self.attack_type = attack_type 16 | 17 | if attack_type not in ["bim", "fgsm"]: 18 | raise NotImplementedError("Unknown attack type") 19 | 20 | self.eps = eps 21 | 22 | def make(self, x, y, target): 23 | return getattr(self, self.attack_type)(x, y, target=target) 24 | 25 | def bim(self, x, y, target=None, x_val_min=-1, x_val_max=1): 26 | out = self.net(x) 27 | pred = torch.max(out, 1)[1] 28 | 29 | if pred.item() != y.item(): 30 | return None 31 | 32 | eta = torch.zeros_like(x) 33 | iters = 10 34 | eps_iter = self.eps / iters 35 | for i in range(iters): 36 | nx = x + eta 37 | nx.requires_grad_() 38 | 39 | out = self.net(nx) 40 | 41 | self.net.zero_grad() 42 | if target is not None: 43 | cost = self.criterion(out, target) 44 | else: 45 | cost = -self.criterion(out, y) 46 | cost.backward() 47 | 48 | eta -= eps_iter * torch.sign(nx.grad.data) 49 | eta.clamp_(-self.eps, self.eps) 50 | nx.grad.data.zero_() 51 | 52 | x_adv = x + eta 53 | x_adv.clamp_(x_val_min, x_val_max) 54 | 55 | if target is not None: 56 | return x_adv.detach(), target 57 | 58 | return x_adv.detach(), y 59 | 60 | def fgsm(self, x, y, target=None, x_val_min=-1, x_val_max=1): 61 | data = Variable(x.data, requires_grad=True) 62 | out = self.net(data) 63 | pred = torch.max(out, 1)[1] 64 | 65 | if pred.item() != y.item(): 66 | return None 67 | 68 | if target is not None: 69 | cost = self.criterion(out, target) 70 | else: 71 | cost = -self.criterion(out, y) 72 | 73 | self.net.zero_grad() 74 | if data.grad is not None: 75 | data.grad.data.fill_(0) 76 | cost.backward() 77 | 78 | data.grad.sign_() 79 | data = data - self.eps * data.grad 80 | x_adv = torch.clamp(data, x_val_min, x_val_max) 81 | 82 | if target is not None: 83 | return x_adv, target 84 | 85 | return x_adv, y 86 | 87 | def extract_adv_images(attacker, dataloader, targeted, classes=10): 88 | adv_images = [] 89 | num_examples = 0 90 | for batch, (x, y) in enumerate(dataloader): 91 | x, y = x.cuda(), y.cuda() 92 | curr_x_adv_batch = [] 93 | curr_y_batch = [] 94 | for i in range(len(y)): 95 | if targeted: 96 | y_new = y[i] + 1 97 | if y_new == classes: 98 | y_new = 0 99 | target = y.new_zeros(1) 100 | target[0] = y_new 101 | gg = 
attacker.make(x[i:i+1], y[i:i+1], target=target) 102 | else: 103 | gg = attacker.make(x[i:i+1], y[i:i+1], target=None) 104 | 105 | if gg is not None: 106 | curr_x_adv_batch.append(gg[0]) 107 | curr_y_batch.append(gg[1]) 108 | num_examples += 1 109 | 110 | curr_x_adv_batch = torch.cat(curr_x_adv_batch, dim=0) 111 | curr_y_batch = torch.cat(curr_y_batch, dim=0) 112 | adv_images.append((curr_x_adv_batch, curr_y_batch)) 113 | 114 | if batch == 20: 115 | break 116 | 117 | return adv_images, num_examples 118 | 119 | 120 | -------------------------------------------------------------------------------- /models/smallnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules import * 6 | 7 | from utils import weights_init 8 | 9 | 10 | class SmallNet(nn.Module): 11 | def __init__(self, cfg_data, mode='SR'): 12 | super(SmallNet, self).__init__() 13 | channels, classes = cfg_data['channels'], cfg_data['classes'] 14 | self.conv1 = nn.Conv2d(channels, 256, kernel_size=7, stride=2, padding=1, bias=False) 15 | self.bn1 = nn.BatchNorm2d(256) 16 | 17 | self.mode = mode 18 | 19 | self.num_caps = 16 20 | 21 | planes = 16 22 | last_size = 6 23 | if self.mode == 'SR': 24 | self.conv_a = nn.Conv2d(256, self.num_caps, kernel_size=5, stride=1, padding=1, bias=False) 25 | self.conv_pose = nn.Conv2d(256, self.num_caps*planes, kernel_size=5, stride=1, padding=1, bias=False) 26 | self.bn_a = nn.BatchNorm2d(self.num_caps) 27 | self.bn_pose = nn.BatchNorm2d(self.num_caps*planes) 28 | 29 | self.conv_caps = SelfRouting2d(self.num_caps, self.num_caps, planes, planes, kernel_size=3, stride=2, padding=1, pose_out=True) 30 | self.bn_pose_conv_caps = nn.BatchNorm2d(self.num_caps*planes) 31 | 32 | self.fc_caps = SelfRouting2d(self.num_caps, classes, planes, 1, kernel_size=last_size, padding=0, pose_out=False) 33 | 34 | elif self.mode == 'DR': 35 | self.conv_pose = nn.Conv2d(256, self.num_caps*planes, kernel_size=5, stride=1, padding=1, bias=False) 36 | # self.bn_pose = nn.BatchNorm2d(self.num_caps*planes) 37 | 38 | self.conv_caps = DynamicRouting2d(self.num_caps, self.num_caps, 16, 16, kernel_size=3, stride=2, padding=1) 39 | nn.init.normal_(self.conv_caps.W, 0, 0.5) 40 | 41 | self.fc_caps = DynamicRouting2d(self.num_caps, classes, 16, 16, kernel_size=last_size, padding=0) 42 | nn.init.normal_(self.fc_caps.W, 0, 0.05) 43 | 44 | elif self.mode == 'EM': 45 | self.conv_a = nn.Conv2d(256, self.num_caps, kernel_size=5, stride=1, padding=1, bias=False) 46 | self.conv_pose = nn.Conv2d(256, self.num_caps*16, kernel_size=5, stride=1, padding=1, bias=False) 47 | self.bn_a = nn.BatchNorm2d(self.num_caps) 48 | self.bn_pose = nn.BatchNorm2d(self.num_caps*16) 49 | 50 | self.conv_caps = EmRouting2d(self.num_caps, self.num_caps, 16, kernel_size=3, stride=2, padding=1) 51 | self.bn_pose_conv_caps = nn.BatchNorm2d(self.num_caps*planes) 52 | 53 | self.fc_caps = EmRouting2d(self.num_caps, classes, 16, kernel_size=last_size, padding=0) 54 | 55 | else: 56 | raise NotImplementedError 57 | 58 | self.apply(weights_init) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | 63 | if self.mode == 'DR': 64 | # pose = self.bn_pose(self.conv_pose(out)) 65 | pose = self.conv_pose(out) 66 | 67 | b, c, h, w = pose.shape 68 | pose = pose.permute(0, 2, 3, 1).contiguous() 69 | pose = squash(pose.view(b, h, w, self.num_caps, 16)) 70 | pose = pose.view(b, h, w, -1) 71 | pose = pose.permute(0, 3, 1, 2) 72 | 73 | 
pose = self.conv_caps(pose) 74 | 75 | out = self.fc_caps(pose) 76 | out = out.view(b, -1, 16) 77 | out = out.norm(dim=-1) 78 | 79 | elif self.mode == 'EM': 80 | a, pose = self.conv_a(out), self.conv_pose(out) 81 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 82 | 83 | a, pose = self.conv_caps(a, pose) 84 | pose = self.bn_pose_conv_caps(pose) 85 | 86 | a, _ = self.fc_caps(a, pose) 87 | out = a.view(a.size(0), -1) 88 | 89 | elif self.mode == 'SR': 90 | a, pose = self.conv_a(out), self.conv_pose(out) 91 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 92 | 93 | a, pose = self.conv_caps(a, pose) 94 | pose = self.bn_pose_conv_caps(pose) 95 | 96 | a, _ = self.fc_caps(a, pose) 97 | 98 | out = a.view(a.size(0), -1) 99 | out = out.log() 100 | 101 | return out 102 | 103 | 104 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | arg_lists = [] 4 | parser = argparse.ArgumentParser(description='CapsNet') 5 | 6 | def str2bool(v): 7 | return v.lower() in ('true', '1') 8 | 9 | def add_argument_group(name): 10 | arg = parser.add_argument_group(name) 11 | arg_lists.append(arg) 12 | return arg 13 | 14 | 15 | # data params 16 | data_arg = add_argument_group('Data Params') 17 | data_arg.add_argument('--valid_size', type=float, default=0.1, 18 | help='Proportion of training set used for validation') 19 | data_arg.add_argument('--batch_size', type=int, default=64, 20 | help='# of images in each batch of data') 21 | data_arg.add_argument('--num_workers', type=int, default=4, 22 | help='# of subprocesses to use for data loading') 23 | data_arg.add_argument('--shuffle', type=str2bool, default=True, 24 | help='Whether to shuffle the train and valid indices') 25 | 26 | 27 | # training params 28 | train_arg = add_argument_group('Training Params') 29 | train_arg.add_argument('--is_train', type=str2bool, default=True, 30 | help='Whether to train or test the model') 31 | train_arg.add_argument('--momentum', type=float, default=0.9, 32 | help='Momentum value') 33 | train_arg.add_argument('--weight_decay', type=float, default=1e-4, 34 | help='Weight decay value') 35 | train_arg.add_argument('--epochs', type=int, default=350, 36 | help='# of epochs to train for') 37 | train_arg.add_argument('--init_lr', type=float, default=0.1, 38 | help='Initial learning rate value') 39 | train_arg.add_argument('--train_patience', type=int, default=100, 40 | help='Number of epochs to wait before stopping train') 41 | train_arg.add_argument('--dataset', type=str, default='cifar10', 42 | help='Dataset for training: {cifar10, svhn, smallnorb}') 43 | train_arg.add_argument('--planes', type=int, default=16, 44 | help='starting layer width') 45 | train_arg.add_argument('--num_caps', type=int, default=32, 46 | help="# of capsules per layer") 47 | train_arg.add_argument('--caps_size', type=int, default=16, 48 | help="# of neurons per capsule") 49 | train_arg.add_argument('--depth', type=int, default=1, 50 | help="depth of additional layers") 51 | 52 | 53 | # other params 54 | misc_arg = add_argument_group('Misc.') 55 | misc_arg.add_argument('--name', type=str, default=None, 56 | help='Name of model to load / save') 57 | misc_arg.add_argument('--best', type=str2bool, default=True, 58 | help='Load best model or most recent for testing') 59 | misc_arg.add_argument('--random_seed', type=int, default=2018, 60 | help='Seed to ensure reproducibility') 61 | misc_arg.add_argument('--data_dir', 
type=str, default='./data', 62 | help='Directory in which data is stored') 63 | misc_arg.add_argument('--ckpt_dir', type=str, default='./ckpt', 64 | help='Directory in which to save model checkpoints') 65 | misc_arg.add_argument('--logs_dir', type=str, default='./logs/', 66 | help='Directory in which Tensorboard logs will be stored') 67 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True, 68 | help='Whether to use tensorboard for visualization') 69 | misc_arg.add_argument('--resume', type=str2bool, default=False, 70 | help='Whether to resume training from checkpoint') 71 | misc_arg.add_argument('--print_freq', type=int, default=10, 72 | help='How frequently to print training details') 73 | 74 | misc_arg.add_argument('--attack', type=str2bool, default=False, 75 | help='Whether to test against attack') 76 | misc_arg.add_argument('--attack_type', type=str, default='fgsm', 77 | help='Attack to perform: {fgsm, bim}') 78 | misc_arg.add_argument('--attack_eps', type=float, default=0.1, 79 | help='eps for adv attack') 80 | misc_arg.add_argument('--targeted', type=str2bool, default=False, 81 | help='if true, do targeted attack') 82 | train_arg.add_argument('--exp', type=str, default='', 83 | help="viewpoint exp name (NULL, azimuth, elevation, full)") 84 | train_arg.add_argument('--familiar', type=str2bool, default=True, 85 | help="viewpoint exp setting (novel, familiar)") 86 | 87 | 88 | def get_config(): 89 | config, unparsed = parser.parse_known_args() 90 | return config, unparsed 91 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torchvision import datasets 5 | from torchvision import transforms 6 | from torch.utils.data import Subset 7 | from norb import smallNORBViewPoint, smallNORB 8 | 9 | 10 | def get_train_valid_loader(data_dir, 11 | dataset, 12 | batch_size, 13 | random_seed, 14 | exp='azimuth', 15 | valid_size=0.1, 16 | shuffle=True, 17 | num_workers=4, 18 | pin_memory=False): 19 | 20 | data_dir = data_dir + '/' + dataset 21 | 22 | if dataset == "cifar10": 23 | trans = [transforms.RandomCrop(32, padding=4), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.ToTensor(), 26 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 27 | dataset = datasets.CIFAR10(data_dir, train=True, download=True, 28 | transform=transforms.Compose(trans)) 29 | 30 | elif dataset == "svhn": 31 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]], 32 | std=[x / 255.0 for x in [50.1, 50.6, 50.8]]) 33 | trans = [transforms.RandomCrop(32, padding=4), 34 | transforms.ToTensor(), 35 | normalize] 36 | dataset = datasets.SVHN(data_dir, split='train', download=True, 37 | transform=transforms.Compose(trans)) 38 | 39 | elif dataset == "smallnorb": 40 | trans = [transforms.Resize(48), 41 | transforms.RandomCrop(32), 42 | transforms.ColorJitter(brightness=32./255, contrast=0.3), 43 | transforms.ToTensor(), 44 | #transforms.Normalize((0.7199,), (0.117,)) 45 | ] 46 | if exp in VIEWPOINT_EXPS: 47 | train_set = smallNORBViewPoint(data_dir, exp=exp, train=True, download=True, 48 | transform=transforms.Compose(trans)) 49 | trans = trans[:1] + [transforms.CenterCrop(32)] + trans[3:] 50 | valid_set = smallNORBViewPoint(data_dir, exp=exp, train=False, familiar=False, download=False, 51 | transform=transforms.Compose(trans)) 52 | elif exp == "full": 53 | dataset = 
smallNORB(data_dir, train=True, download=True, 54 | transform = transforms.Compose(trans)) 55 | 56 | if exp not in VIEWPOINT_EXPS: 57 | num_train = len(dataset) 58 | indices = list(range(num_train)) 59 | split = int(np.floor(valid_size * num_train)) 60 | 61 | if shuffle: 62 | np.random.seed(random_seed) 63 | np.random.shuffle(indices) 64 | 65 | train_idx = indices[split:] 66 | valid_idx = indices[:split] 67 | 68 | train_set = Subset(dataset, train_idx) 69 | valid_set = Subset(dataset, valid_idx) 70 | 71 | train_loader = torch.utils.data.DataLoader( 72 | train_set, batch_size=batch_size, shuffle=True, 73 | num_workers=num_workers, pin_memory=pin_memory, 74 | ) 75 | 76 | valid_loader = torch.utils.data.DataLoader( 77 | valid_set, batch_size=batch_size, shuffle=False, 78 | num_workers=num_workers, pin_memory=pin_memory, 79 | ) 80 | 81 | return train_loader, valid_loader 82 | 83 | def get_test_loader(data_dir, 84 | dataset, 85 | batch_size, 86 | exp='azimuth', # smallnorb only 87 | familiar=True, # smallnorb only 88 | num_workers=4, 89 | pin_memory=False): 90 | 91 | data_dir = data_dir + '/' + dataset 92 | 93 | if dataset == "cifar10": 94 | trans = [transforms.ToTensor(), 95 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 96 | dataset = datasets.CIFAR10(data_dir, train=False, download=False, 97 | transform=transforms.Compose(trans)) 98 | 99 | elif dataset == "svhn": 100 | normalize = transforms.Normalize(mean=[x / 255.0 for x in[109.9, 109.7, 113.8]], 101 | std=[x / 255.0 for x in [50.1, 50.6, 50.8]]) 102 | trans = [transforms.ToTensor(), 103 | normalize] 104 | dataset = datasets.SVHN(data_dir, split='test', download=True, 105 | transform=transforms.Compose(trans)) 106 | 107 | elif dataset == "smallnorb": 108 | trans = [transforms.Resize(48), 109 | transforms.CenterCrop(32), 110 | transforms.ToTensor(), 111 | #transforms.Normalize((0.7199,), (0.117,)) 112 | ] 113 | if exp in VIEWPOINT_EXPS: 114 | dataset = smallNORBViewPoint(data_dir, exp=exp, familiar=familiar, train=False, download=True, 115 | transform=transforms.Compose(trans)) 116 | elif exp == "full": 117 | dataset = smallNORB(data_dir, train=False, download=True, 118 | transform=transforms.Compose(trans)) 119 | 120 | data_loader = torch.utils.data.DataLoader( 121 | dataset, batch_size=batch_size, shuffle=False, 122 | num_workers=num_workers, pin_memory=pin_memory, 123 | ) 124 | 125 | return data_loader 126 | 127 | DATASET_CONFIGS = { 128 | 'cifar10': {'size': 32, 'channels': 3, 'classes': 10}, 129 | 'svhn': {'size': 32, 'channels': 3, 'classes': 10}, 130 | 'smallnorb': {'size': 32, 'channels': 1, 'classes': 5}, 131 | } 132 | 133 | VIEWPOINT_EXPS = ['azimuth', 'elevation'] 134 | -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | from modules import * 8 | from utils import weights_init 9 | 10 | 11 | class ConvNet(nn.Module): 12 | def __init__(self, planes, cfg_data, num_caps, caps_size, depth, mode): 13 | caps_size = 16 14 | super(ConvNet, self).__init__() 15 | channels, classes = cfg_data['channels'], cfg_data['classes'] 16 | self.num_caps = num_caps 17 | self.caps_size = caps_size 18 | self.depth = depth 19 | self.mode = mode 20 | 21 | self.layers = nn.Sequential( 22 | nn.Conv2d(channels, planes, kernel_size=3, stride=1, padding=1, bias=False), 23 | nn.BatchNorm2d(planes), 24 
| nn.ReLU(True), 25 | nn.Conv2d(planes, planes*2, kernel_size=3, stride=2, padding=1, bias=False), 26 | nn.BatchNorm2d(planes*2), 27 | nn.ReLU(True), 28 | nn.Conv2d(planes*2, planes*2, kernel_size=3, stride=1, padding=1, bias=False), 29 | nn.BatchNorm2d(planes*2), 30 | nn.ReLU(True), 31 | nn.Conv2d(planes*2, planes*4, kernel_size=3, stride=2, padding=1, bias=False), 32 | nn.BatchNorm2d(planes*4), 33 | nn.ReLU(True), 34 | nn.Conv2d(planes*4, planes*4, kernel_size=3, stride=1, padding=1, bias=False), 35 | nn.BatchNorm2d(planes*4), 36 | nn.ReLU(True), 37 | nn.Conv2d(planes*4, planes*8, kernel_size=3, stride=2, padding=1, bias=False), 38 | nn.BatchNorm2d(planes*8), 39 | nn.ReLU(True), 40 | ) 41 | 42 | self.conv_layers = nn.ModuleList() 43 | self.norm_layers = nn.ModuleList() 44 | 45 | #========= ConvCaps Layers 46 | for d in range(1, depth): 47 | if self.mode == 'DR': 48 | self.conv_layers.append(DynamicRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=1, padding=1)) 49 | nn.init.normal_(self.conv_layers[-1].W, 0, 0.5) 50 | elif self.mode == 'EM': 51 | self.conv_layers.append(EmRouting2d(num_caps, num_caps, caps_size, kernel_size=3, stride=1, padding=1)) 52 | self.norm_layers.append(nn.BatchNorm2d(4*4*num_caps)) 53 | elif self.mode == 'SR': 54 | self.conv_layers.append(SelfRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=1, padding=1, pose_out=True)) 55 | self.norm_layers.append(nn.BatchNorm2d(planes*num_caps)) 56 | else: 57 | break 58 | 59 | final_shape = 4 60 | 61 | # DR 62 | if self.mode == 'DR': 63 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 65 | self.fc = DynamicRouting2d(num_caps, classes, caps_size, caps_size, kernel_size=final_shape, padding=0) 66 | # initialize so that output logits are in reasonable range (0.1-0.9) 67 | nn.init.normal_(self.fc.W, 0, 0.1) 68 | 69 | # EM 70 | elif self.mode == 'EM': 71 | self.conv_a = nn.Conv2d(8*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn_a = nn.BatchNorm2d(num_caps) 74 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 75 | self.fc = EmRouting2d(num_caps, classes, caps_size, kernel_size=final_shape, padding=0) 76 | 77 | # SR 78 | elif self.mode == 'SR': 79 | self.conv_a = nn.Conv2d(8*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn_a = nn.BatchNorm2d(num_caps) 82 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 83 | self.fc = SelfRouting2d(num_caps, classes, caps_size, 1, kernel_size=final_shape, padding=0, pose_out=False) 84 | 85 | # avg pooling 86 | elif self.mode == 'AVG': 87 | self.pool = nn.AvgPool2d(final_shape) 88 | self.fc = nn.Linear(8*planes, classes) 89 | 90 | # max pooling 91 | elif self.mode == 'MAX': 92 | self.pool = nn.MaxPool2d(final_shape) 93 | self.fc = nn.Linear(8*planes, classes) 94 | 95 | elif self.mode == 'FC': 96 | self.conv_ = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 97 | self.bn_ = nn.BatchNorm2d(num_caps*caps_size) 98 | 99 | self.fc = nn.Linear(num_caps*caps_size*final_shape*final_shape, classes) 100 | 101 | self.apply(weights_init) 102 | 103 | def forward(self, x): 104 | out = self.layers(x) 105 | 106 | # DR 107 | if self.mode == 
'DR': 108 | pose = self.bn_pose(self.conv_pose(out)) 109 | 110 | b, c, h, w = pose.shape 111 | pose = pose.permute(0, 2, 3, 1).contiguous() 112 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size)) 113 | pose = pose.view(b, h, w, -1) 114 | pose = pose.permute(0, 3, 1, 2) 115 | 116 | for m in self.conv_layers: 117 | pose = m(pose) 118 | 119 | out = self.fc(pose) 120 | out = out.view(b, -1, self.caps_size) 121 | out = out.norm(dim=-1) 122 | 123 | # EM 124 | elif self.mode == 'EM': 125 | a, pose = self.conv_a(out), self.conv_pose(out) 126 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 127 | 128 | for m, bn in zip(self.conv_layers, self.norm_layers): 129 | a, pose = m(a, pose) 130 | pose = bn(pose) 131 | 132 | a, _ = self.fc(a, pose) 133 | out = a.view(a.size(0), -1) 134 | 135 | # ours 136 | elif self.mode == 'SR': 137 | a, pose = self.conv_a(out), self.conv_pose(out) 138 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 139 | 140 | for m, bn in zip(self.conv_layers, self.norm_layers): 141 | a, pose = m(a, pose) 142 | pose = bn(pose) 143 | 144 | a, _ = self.fc(a, pose) 145 | out = a.view(a.size(0), -1) 146 | out = out.log() 147 | 148 | elif self.mode == 'AVG' or self.mode == 'MAX': 149 | out = self.pool(out) 150 | out = out.view(out.size(0), -1) 151 | out = self.fc(out) 152 | 153 | elif self.mode == 'FC': 154 | out = F.relu(self.bn_(self.conv_(out))) 155 | out = out.view(out.size(0), -1) 156 | out = self.fc(out) 157 | 158 | return out 159 | 160 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules import * 6 | from utils import weights_init 7 | 8 | 9 | class LambdaLayer(nn.Module): 10 | def __init__(self, lambd): 11 | super(LambdaLayer, self).__init__() 12 | self.lambd = lambd 13 | 14 | def forward(self, x): 15 | return self.lambd(x) 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1, option='A'): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != planes: 29 | if option == 'A': 30 | """ 31 | For CIFAR-10, the ResNet paper uses option A.
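Option A subsamples the identity with stride 2 and zero-pads the extra channels, so the shortcut adds no parameters.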
32 | """ 33 | self.shortcut = LambdaLayer(lambda x: 34 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 35 | elif option == 'B': 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 38 | nn.BatchNorm2d(self.expansion * planes) 39 | ) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.bn2(self.conv2(out)) 44 | out += self.shortcut(x) 45 | out = F.relu(out) 46 | return out 47 | 48 | 49 | class ResNet(nn.Module): 50 | def __init__(self, block, num_blocks, planes, num_caps, caps_size, depth, cfg_data, mode): 51 | super(ResNet, self).__init__() 52 | self.in_planes = planes 53 | channels, classes = cfg_data['channels'], cfg_data['classes'] 54 | 55 | self.num_caps = num_caps 56 | self.caps_size = caps_size 57 | 58 | self.depth = depth 59 | 60 | self.conv1 = nn.Conv2d(channels, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(self.in_planes) 62 | self.layer1 = self._make_layer(block, planes, num_blocks[0], stride=1) 63 | self.layer2 = self._make_layer(block, 2*planes, num_blocks[1], stride=2) 64 | self.layer3 = self._make_layer(block, 4*planes, num_blocks[2], stride=2) 65 | 66 | self.mode = mode 67 | 68 | self.conv_layers = nn.ModuleList() 69 | self.norm_layers = nn.ModuleList() 70 | 71 | for d in range(1, depth): 72 | stride = 2 if d == 1 else 1 73 | if self.mode == 'DR': 74 | self.conv_layers.append(DynamicRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=stride, padding=1)) 75 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps)) 76 | elif self.mode == 'EM': 77 | self.conv_layers.append(EmRouting2d(num_caps, num_caps, caps_size, kernel_size=3, stride=stride, padding=1)) 78 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps)) 79 | elif self.mode == 'SR': 80 | self.conv_layers.append(SelfRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=stride, padding=1, pose_out=True)) 81 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps)) 82 | else: 83 | break 84 | 85 | final_shape = 8 if depth == 1 else 4 86 | 87 | # DR 88 | if self.mode == 'DR': 89 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 90 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 91 | self.fc = DynamicRouting2d(num_caps, classes, caps_size, caps_size, kernel_size=final_shape, padding=0) 92 | 93 | # EM 94 | elif self.mode == 'EM': 95 | self.conv_a = nn.Conv2d(4*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False) 96 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 97 | self.bn_a = nn.BatchNorm2d(num_caps) 98 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 99 | self.fc = EmRouting2d(num_caps, classes, caps_size, kernel_size=final_shape, padding=0) 100 | 101 | # SR 102 | elif self.mode == 'SR': 103 | self.conv_a = nn.Conv2d(4*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False) 104 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False) 105 | self.bn_a = nn.BatchNorm2d(num_caps) 106 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size) 107 | self.fc = SelfRouting2d(num_caps, classes, caps_size, 1, kernel_size=final_shape, padding=0, pose_out=False) 108 | 109 | # avg pooling 110 | elif self.mode == 'AVG': 111 | self.pool = nn.AvgPool2d(final_shape) 112 | self.fc = nn.Linear(4*planes, 
classes) 113 | 114 | # max pooling 115 | elif self.mode == 'MAX': 116 | self.pool = nn.MaxPool2d(final_shape) 117 | self.fc = nn.Linear(4*planes, classes) 118 | 119 | elif self.mode == 'FC': 120 | self.conv_ = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=(1 if depth == 1 else 2), padding=1) 121 | self.bn_ = nn.BatchNorm2d(num_caps*caps_size) 122 | 123 | self.fc = nn.Linear(num_caps*caps_size*final_shape*final_shape, classes) 124 | 125 | self.apply(weights_init) 126 | 127 | def _make_layer(self, block, planes, num_blocks, stride): 128 | strides = [stride] + [1]*(num_blocks-1) 129 | layers = [] 130 | for stride in strides: 131 | layers.append(block(self.in_planes, planes, stride)) 132 | self.in_planes = planes * block.expansion 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | out = F.relu(self.bn1(self.conv1(x))) 138 | out = self.layer1(out) 139 | out = self.layer2(out) 140 | out = self.layer3(out) 141 | 142 | # DR 143 | if self.mode == 'DR': 144 | pose = self.bn_pose(self.conv_pose(out)) 145 | 146 | b, c, h, w = pose.shape 147 | pose = pose.permute(0, 2, 3, 1).contiguous() 148 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size)) 149 | pose = pose.view(b, h, w, -1) 150 | pose = pose.permute(0, 3, 1, 2) 151 | 152 | for m in self.conv_layers: 153 | pose = m(pose) 154 | 155 | out = self.fc(pose).squeeze() 156 | out = out.view(b, -1, self.caps_size) 157 | 158 | out = out.norm(dim=-1) 159 | out = out / out.sum(dim=1, keepdim=True) 160 | out = out.log() 161 | 162 | # EM 163 | elif self.mode == 'EM': 164 | a, pose = self.conv_a(out), self.conv_pose(out) 165 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 166 | 167 | for m, bn in zip(self.conv_layers, self.norm_layers): 168 | a, pose = m(a, pose) 169 | pose = bn(pose) 170 | 171 | a, _ = self.fc(a, pose) 172 | out = a.view(a.size(0), -1) 173 | out = out / out.sum(dim=1, keepdim=True) 174 | out = out.log() 175 | 176 | # ours 177 | elif self.mode == 'SR': 178 | a, pose = self.conv_a(out), self.conv_pose(out) 179 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose) 180 | 181 | for m, bn in zip(self.conv_layers, self.norm_layers): 182 | a, pose = m(a, pose) 183 | pose = bn(pose) 184 | 185 | a, _ = self.fc(a, pose) 186 | out = a.view(a.size(0), -1) 187 | out = out.log() 188 | 189 | elif self.mode == 'AVG' or self.mode == 'MAX': 190 | out = self.pool(out) 191 | out = out.view(out.size(0), -1) 192 | out = self.fc(out) 193 | 194 | elif self.mode == 'FC': 195 | out = F.relu(self.bn_(self.conv_(out))) 196 | out = out.view(out.size(0), -1) 197 | out = self.fc(out) 198 | 199 | return out 200 | 201 | def forward_activations(self, x): 202 | out = F.relu(self.bn1(self.conv1(x))) 203 | out = self.layer1(out) 204 | out = self.layer2(out) 205 | out = self.layer3(out) 206 | 207 | if self.mode == 'DR': 208 | pose = self.bn_pose(self.conv_pose(out)) 209 | 210 | b, c, h, w = pose.shape 211 | pose = pose.permute(0, 2, 3, 1).contiguous() 212 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size)) 213 | pose = pose.view(b, h, w, -1) 214 | pose = pose.permute(0, 3, 1, 2) 215 | a = pose.norm(dim=1) 216 | 217 | elif self.mode == 'EM': 218 | a = torch.sigmoid(self.bn_a(self.conv_a(out))) 219 | 220 | elif self.mode == 'SR': 221 | a = torch.sigmoid(self.bn_a(self.conv_a(out))) 222 | 223 | else: 224 | raise NotImplementedError 225 | 226 | return a 227 | 228 | 229 | def resnet20(planes, cfg_data, num_caps, caps_size, depth, mode): 230 | return ResNet(BasicBlock, [3, 3, 3], planes, num_caps, caps_size, depth, 
cfg_data, mode) 231 | 232 | def resnet32(planes, cfg_data, num_caps, caps_size, depth, mode): 233 | return ResNet(BasicBlock, [5, 5, 5], planes, num_caps, caps_size, depth, cfg_data, mode) 234 | 235 | def resnet44(planes, cfg_data, num_caps, caps_size, depth, mode): 236 | return ResNet(BasicBlock, [7, 7, 7], planes, num_caps, caps_size, depth, cfg_data, mode) 237 | 238 | def resnet56(planes, cfg_data, num_caps, caps_size, depth, mode): 239 | return ResNet(BasicBlock, [9, 9, 9], planes, num_caps, caps_size, depth, cfg_data, mode) 240 | 241 | def resnet110(planes, cfg_data, num_caps, caps_size, depth, mode): 242 | return ResNet(BasicBlock, [18, 18, 18], planes, num_caps, caps_size, depth, cfg_data, mode) 243 | 244 | 245 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | from utils import squash 8 | 9 | 10 | eps = 1e-12 11 | 12 | 13 | class DynamicRouting2d(nn.Module): 14 | def __init__(self, A, B, C, D, kernel_size=1, stride=1, padding=1, iters=3): 15 | super(DynamicRouting2d, self).__init__() 16 | self.A = A 17 | self.B = B 18 | self.C = C 19 | self.D = D 20 | 21 | self.k = kernel_size 22 | self.kk = kernel_size ** 2 23 | self.kkA = self.kk * A 24 | 25 | self.stride = stride 26 | self.pad = padding 27 | 28 | self.iters = iters 29 | self.W = nn.Parameter(torch.FloatTensor(self.kkA, B*D, C)) 30 | nn.init.kaiming_uniform_(self.W) 31 | 32 | def forward(self, pose): 33 | # pose: [b, AC, h, w] 34 | b, _, h, w = pose.shape 35 | # [b, ACkk, l] 36 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad) 37 | l = pose.shape[-1] 38 | # [b, A, C, kk, l] 39 | pose = pose.view(b, self.A, self.C, self.kk, l) 40 | # [b, l, kk, A, C] 41 | pose = pose.permute(0, 4, 3, 1, 2).contiguous() 42 | # [b, l, kkA, C, 1] 43 | pose = pose.view(b, l, self.kkA, self.C, 1) 44 | 45 | # [b, l, kkA, BD] 46 | pose_out = torch.matmul(self.W, pose).squeeze(-1) 47 | # [b, l, kkA, B, D] 48 | pose_out = pose_out.view(b, l, self.kkA, self.B, self.D) 49 | 50 | # routing logits; [b, l, kkA, B, 1] 51 | logit = pose.new_zeros(b, l, self.kkA, self.B, 1) 52 | for i in range(self.iters): 53 | c = torch.softmax(logit, dim=3) 54 | 55 | # [b, l, 1, B, D] 56 | s = (c * pose_out).sum(dim=2, keepdim=True) 57 | # [b, l, 1, B, D] 58 | v = squash(s) 59 | 60 | logit = logit + (v * pose_out).sum(dim=-1, keepdim=True) 61 | 62 | # [b, l, B, D] 63 | v = v.squeeze(2) 64 | # [b, l, BD] 65 | v = v.view(v.shape[0], l, -1) 66 | # [b, BD, l] 67 | v = v.transpose(1,2).contiguous() 68 | 69 | oh = ow = math.floor(l**(1/2)) 70 | 71 | # [b, BD, oh, ow] 72 | return v.view(v.shape[0], -1, oh, ow) 73 | 74 | 75 | class EmRouting2d(nn.Module): 76 | def __init__(self, A, B, caps_size, kernel_size=3, stride=1, padding=1, iters=3, final_lambda=1e-2): 77 | super(EmRouting2d, self).__init__() 78 | self.A = A 79 | self.B = B 80 | self.psize = caps_size 81 | self.mat_dim = int(caps_size ** 0.5) 82 | 83 | self.k = kernel_size 84 | self.kk = kernel_size ** 2 85 | self.kkA = self.kk * A 86 | 87 | self.stride = stride 88 | self.pad = padding 89 | 90 | self.iters = iters 91 | 92 | self.W = nn.Parameter(torch.FloatTensor(self.kkA, B, self.mat_dim, self.mat_dim)) 93 | nn.init.kaiming_uniform_(self.W.data) 94 | 95 | self.beta_u = nn.Parameter(torch.FloatTensor(1, 1, B, 1)) 96 | self.beta_a = nn.Parameter(torch.FloatTensor(1, 1, B)) 97 | nn.init.constant_(self.beta_u, 0) 98 | 
nn.init.constant_(self.beta_a, 0) 99 | 100 | self.final_lambda = final_lambda 101 | self.ln_2pi = math.log(2*math.pi) 102 | 103 | def m_step(self, v, a_in, r): 104 | # v: [b, l, kkA, B, psize] 105 | # a_in: [b, l, kkA] 106 | # r: [b, l, kkA, B, 1] 107 | b, l, _, _, _ = v.shape 108 | 109 | # r: [b, l, kkA, B, 1] 110 | r = r * a_in.view(b, l, -1, 1, 1) 111 | # r_sum: [b, l, 1, B, 1] 112 | r_sum = r.sum(dim=2, keepdim=True) 113 | # coeff: [b, l, kkA, B, 1] 114 | coeff = r / (r_sum + eps) 115 | 116 | # mu: [b, l, 1, B, psize] 117 | mu = torch.sum(coeff * v, dim=2, keepdim=True) 118 | # sigma_sq: [b, l, 1, B, psize] 119 | sigma_sq = torch.sum(coeff * (v - mu)**2, dim=2, keepdim=True) + eps 120 | 121 | # [b, l, B, 1] 122 | r_sum = r_sum.squeeze(2) 123 | # [b, l, B, psize] 124 | sigma_sq = sigma_sq.squeeze(2) 125 | # [1, 1, B, 1] + [b, l, B, psize] * [b, l, B, 1] 126 | cost_h = (self.beta_u + torch.log(sigma_sq.sqrt())) * r_sum 127 | # cost_h = (torch.log(sigma_sq.sqrt())) * r_sum 128 | 129 | # [b, l, B] 130 | a_out = torch.sigmoid(self.lambda_*(self.beta_a - cost_h.sum(dim=3))) 131 | # a_out = torch.sigmoid(self.lambda_*(-cost_h.sum(dim=3))) 132 | 133 | return a_out, mu, sigma_sq 134 | 135 | def e_step(self, v, a_out, mu, sigma_sq): 136 | b, l, _ = a_out.shape 137 | # v: [b, l, kkA, B, psize] 138 | # a_out: [b, l, B] 139 | # mu: [b, l, 1, B, psize] 140 | # sigma_sq: [b, l, B, psize] 141 | 142 | # [b, l, 1, B, psize] 143 | sigma_sq = sigma_sq.unsqueeze(2) 144 | 145 | # Gaussian log-density: ln N(v; mu, sigma^2) 146 | ln_p_j = -0.5 * torch.sum(torch.log(sigma_sq) + self.ln_2pi, dim=-1) \ 147 | - torch.sum((v - mu)**2 / (2 * sigma_sq), dim=-1) 148 | 149 | # [b, l, kkA, B] 150 | ln_ap = ln_p_j + torch.log(a_out.view(b, l, 1, self.B)) 151 | # [b, l, kkA, B] 152 | r = torch.softmax(ln_ap, dim=-1) 153 | # [b, l, kkA, B, 1] 154 | return r.unsqueeze(-1) 155 | 156 | def forward(self, a_in, pose): 157 | # pose: [batch_size, A, psize] 158 | # a: [batch_size, A] 159 | batch_size = a_in.shape[0] 160 | 161 | # a: [b, A, h, w] 162 | # pose: [b, A*psize, h, w] 163 | b, _, h, w = a_in.shape 164 | 165 | # [b, A*psize*kk, l] 166 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad) 167 | l = pose.shape[-1] 168 | # [b, A, psize, kk, l] 169 | pose = pose.view(b, self.A, self.psize, self.kk, l) 170 | # [b, l, kk, A, psize] 171 | pose = pose.permute(0, 4, 3, 1, 2).contiguous() 172 | # [b, l, kkA, psize] 173 | pose = pose.view(b, l, self.kkA, self.psize) 174 | # [b, l, kkA, 1, mat_dim, mat_dim] 175 | pose = pose.view(batch_size, l, self.kkA, self.mat_dim, self.mat_dim).unsqueeze(3) 176 | 177 | # [b, l, kkA, B, mat_dim, mat_dim] 178 | pose_out = torch.matmul(pose, self.W) 179 | 180 | # [b, l, kkA, B, psize] 181 | v = pose_out.view(batch_size, l, self.kkA, self.B, -1) 182 | 183 | # [b, kkA, l] 184 | a_in = F.unfold(a_in, self.k, stride=self.stride, padding=self.pad) 185 | # [b, A, kk, l] 186 | a_in = a_in.view(b, self.A, self.kk, l) 187 | # [b, l, kk, A] 188 | a_in = a_in.permute(0, 3, 2, 1).contiguous() 189 | # [b, l, kkA] 190 | a_in = a_in.view(b, l, self.kkA) 191 | 192 | r = a_in.new_ones(batch_size, l, self.kkA, self.B, 1) 193 | for i in range(self.iters): 194 | # lambda annealing schedule, from the OpenReview discussion of the EM-routing paper 195 | self.lambda_ = self.final_lambda * (1 - 0.95 ** (i+1)) 196 | a_out, pose_out, sigma_sq = self.m_step(v, a_in, r) 197 | if i < self.iters - 1: 198 | r = self.e_step(v, a_out, pose_out, sigma_sq) 199 | 200 | # [b, l, B*psize] 201 | pose_out = pose_out.squeeze(2).view(b, l, -1) 202 | # [b, B*psize, l] 203 | pose_out = pose_out.transpose(1, 2) 204 | # [b, B, l] 205 | 
a_out = a_out.transpose(1, 2).contiguous() 206 | 207 | oh = ow = math.floor(l**(1/2)) 208 | 209 | a_out = a_out.view(b, -1, oh, ow) 210 | pose_out = pose_out.view(b, -1, oh, ow) 211 | 212 | return a_out, pose_out 213 | 214 | 215 | class SelfRouting2d(nn.Module): 216 | def __init__(self, A, B, C, D, kernel_size=3, stride=1, padding=1, pose_out=False): 217 | super(SelfRouting2d, self).__init__() 218 | self.A = A 219 | self.B = B 220 | self.C = C 221 | self.D = D 222 | 223 | self.k = kernel_size 224 | self.kk = kernel_size ** 2 225 | self.kkA = self.kk * A 226 | 227 | self.stride = stride 228 | self.pad = padding 229 | 230 | self.pose_out = pose_out 231 | 232 | if pose_out: 233 | self.W1 = nn.Parameter(torch.FloatTensor(self.kkA, B*D, C)) 234 | nn.init.kaiming_uniform_(self.W1.data) 235 | 236 | self.W2 = nn.Parameter(torch.FloatTensor(self.kkA, B, C)) 237 | self.b2 = nn.Parameter(torch.FloatTensor(1, 1, self.kkA, B)) 238 | 239 | nn.init.constant_(self.W2.data, 0) 240 | nn.init.constant_(self.b2.data, 0) 241 | 242 | def forward(self, a, pose): 243 | # a: [b, A, h, w] 244 | # pose: [b, AC, h, w] 245 | b, _, h, w = a.shape 246 | 247 | # [b, ACkk, l] 248 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad) 249 | l = pose.shape[-1] 250 | # [b, A, C, kk, l] 251 | pose = pose.view(b, self.A, self.C, self.kk, l) 252 | # [b, l, kk, A, C] 253 | pose = pose.permute(0, 4, 3, 1, 2).contiguous() 254 | # [b, l, kkA, C, 1] 255 | pose = pose.view(b, l, self.kkA, self.C, 1) 256 | 257 | if hasattr(self, 'W1'): 258 | # [b, l, kkA, BD] 259 | pose_out = torch.matmul(self.W1, pose).squeeze(-1) 260 | # [b, l, kkA, B, D] 261 | pose_out = pose_out.view(b, l, self.kkA, self.B, self.D) 262 | 263 | # [b, l, kkA, B] 264 | logit = torch.matmul(self.W2, pose).squeeze(-1) + self.b2 265 | 266 | # [b, l, kkA, B] 267 | r = torch.softmax(logit, dim=3) 268 | 269 | # [b, kkA, l] 270 | a = F.unfold(a, self.k, stride=self.stride, padding=self.pad) 271 | # [b, A, kk, l] 272 | a = a.view(b, self.A, self.kk, l) 273 | # [b, l, kk, A] 274 | a = a.permute(0, 3, 2, 1).contiguous() 275 | # [b, l, kkA, 1] 276 | a = a.view(b, l, self.kkA, 1) 277 | 278 | # [b, l, kkA, B] 279 | ar = a * r 280 | # [b, l, 1, B] 281 | ar_sum = ar.sum(dim=2, keepdim=True) 282 | # [b, l, kkA, B, 1] 283 | coeff = (ar / (ar_sum + eps)).unsqueeze(-1) 284 | 285 | # [b, l, B] 286 | # a_out = ar_sum.squeeze(2) 287 | a_out = ar_sum / a.sum(dim=2, keepdim=True) 288 | a_out = a_out.squeeze(2) 289 | 290 | # [b, B, l] 291 | a_out = a_out.transpose(1,2) 292 | 293 | if hasattr(self, 'W1'): 294 | # [b, l, B, D] 295 | pose_out = (coeff * pose_out).sum(dim=2) 296 | # [b, l, BD] 297 | pose_out = pose_out.view(b, l, -1) 298 | # [b, BD, l] 299 | pose_out = pose_out.transpose(1,2) 300 | 301 | oh = ow = math.floor(l**(1/2)) 302 | 303 | a_out = a_out.view(b, -1, oh, ow) 304 | if hasattr(self, 'W1'): 305 | pose_out = pose_out.view(b, -1, oh, ow) 306 | else: 307 | pose_out = None 308 | 309 | return a_out, pose_out 310 | 311 | 312 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | import os 6 | import time 7 | import shutil 8 | import math 9 | 10 | from tqdm import tqdm 11 | from utils import AverageMeter, save_config 12 | from tensorboardX import SummaryWriter 13 | 14 | from models import * 15 | from loss import * 16 | from data_loader import DATASET_CONFIGS 17 | 18 | from
attack import Attack, extract_adv_images 19 | 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | class Trainer(object): 25 | """ 26 | Trainer encapsulates all the logic necessary for 27 | training. 28 | 29 | All hyperparameters are provided by the user in the 30 | config file. 31 | """ 32 | def __init__(self, config, data_loader): 33 | """ 34 | Construct a new Trainer instance. 35 | 36 | Args 37 | ---- 38 | - config: object containing command line arguments. 39 | - data_loader: data iterator 40 | """ 41 | self.config = config 42 | 43 | # data params 44 | if config.is_train: 45 | self.train_loader = data_loader[0] 46 | self.valid_loader = data_loader[1] 47 | self.num_train = len(self.train_loader.dataset) 48 | self.num_valid = len(self.valid_loader.dataset) 49 | else: 50 | self.test_loader = data_loader 51 | self.num_test = len(self.test_loader.dataset) 52 | 53 | # training params 54 | self.epochs = config.epochs 55 | self.start_epoch = 0 56 | self.momentum = config.momentum 57 | self.weight_decay = config.weight_decay 58 | self.lr = config.init_lr 59 | 60 | # misc params 61 | self.best = config.best 62 | self.ckpt_dir = config.ckpt_dir 63 | self.logs_dir = config.logs_dir 64 | self.best_valid_acc = 0. 65 | self.counter = 0 66 | self.train_patience = config.train_patience 67 | self.use_tensorboard = config.use_tensorboard 68 | self.resume = config.resume 69 | self.print_freq = config.print_freq 70 | 71 | self.attack_type = config.attack_type 72 | self.attack_eps = config.attack_eps 73 | self.targeted = config.targeted 74 | 75 | self.name = config.name 76 | 77 | if config.name.endswith('dynamic_routing'): 78 | self.mode = 'DR' 79 | elif config.name.endswith('em_routing'): 80 | self.mode = 'EM' 81 | elif config.name.endswith('self_routing'): 82 | self.mode = 'SR' 83 | elif config.name.endswith('max'): 84 | self.mode = 'MAX' 85 | elif config.name.endswith('avg'): 86 | self.mode = 'AVG' 87 | elif config.name.endswith('fc'): 88 | self.mode = 'FC' 89 | else: 90 | raise NotImplementedError("Unknown model postfix") 91 | 92 | # initialize 93 | if config.name.startswith('resnet'): 94 | self.model = resnet20(config.planes, DATASET_CONFIGS[config.dataset], config.num_caps, config.caps_size, config.depth, mode=self.mode).to(device) 95 | elif config.name.startswith('convnet'): 96 | self.model = ConvNet(config.planes, DATASET_CONFIGS[config.dataset], config.num_caps, config.caps_size, config.depth, mode=self.mode).to(device) 97 | elif config.name.startswith('smallnet'): 98 | assert self.mode in ['SR', 'DR', 'EM'] 99 | self.model = SmallNet(DATASET_CONFIGS[config.dataset], mode=self.mode).to(device) 100 | else: 101 | raise NotImplementedError("Unknown model prefix") 102 | 103 | if torch.cuda.device_count() > 1: 104 | print("Let's use", torch.cuda.device_count(), "GPUs!") 105 | self.model = nn.DataParallel(self.model) 106 | 107 | self.loss = nn.CrossEntropyLoss().to(device) 108 | if self.mode in ['DR', 'EM', 'SR']: 109 | if config.dataset in ['cifar10', 'svhn']: 110 | print("using NLL loss") 111 | self.loss = nn.NLLLoss().to(device) 112 | elif config.dataset == "smallnorb": 113 | if self.mode == 'DR': 114 | print("using DR loss") 115 | self.loss = DynamicRoutingLoss().to(device) 116 | elif self.mode == 'EM': 117 | print("using EM loss") 118 | self.loss = EmRoutingLoss(self.epochs).to(device) 119 | elif self.mode == 'SR': 120 | print("using NLL loss") 121 | self.loss = nn.NLLLoss().to(device) 122 | 123 | self.params = self.model.parameters() 124 | self.optimizer 
= optim.SGD(self.params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay) 125 | 126 | if config.dataset == "cifar10": 127 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[150, 250], gamma=0.1) 128 | elif config.dataset == "svhn": 129 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[100, 150], gamma=0.1) 130 | elif config.dataset == "smallnorb": 131 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[100, 150], gamma=0.1) 132 | 133 | # save config as json 134 | save_config(self.name, self.config) 135 | 136 | # configure tensorboard logging 137 | if self.use_tensorboard: 138 | tensorboard_dir = self.logs_dir + self.name 139 | print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir)) 140 | if not os.path.exists(tensorboard_dir): 141 | os.makedirs(tensorboard_dir) 142 | self.writer = SummaryWriter(tensorboard_dir) 143 | 144 | print('[*] Number of model parameters: {:,}'.format( 145 | sum([p.data.nelement() for p in self.model.parameters()]))) 146 | 147 | def train(self): 148 | """ 149 | Train the model on the training set. 150 | 151 | A checkpoint of the model is saved after each epoch 152 | and if the validation accuracy is improved upon, 153 | a separate ckpt is created for use on the test set. 154 | """ 155 | # load the most recent checkpoint 156 | if self.resume: 157 | self.load_checkpoint(best=False) 158 | 159 | print("\n[*] Train on {} samples, validate on {} samples".format( 160 | self.num_train, self.num_valid) 161 | ) 162 | 163 | for epoch in range(self.start_epoch, self.epochs): 164 | # get current lr 165 | for i, param_group in enumerate(self.optimizer.param_groups): 166 | lr = float(param_group['lr']) 167 | break 168 | 169 | print( 170 | '\nEpoch: {}/{} - LR: {:.1e}'.format(epoch+1, self.epochs, lr) 171 | ) 172 | 173 | # train for 1 epoch 174 | train_loss, train_acc = self.train_one_epoch(epoch) 175 | 176 | # evaluate on validation set 177 | with torch.no_grad(): 178 | valid_loss, valid_acc = self.validate(epoch) 179 | 180 | 181 | msg1 = "train loss: {:.3f} - train acc: {:.3f}" 182 | msg2 = " - val loss: {:.3f} - val acc: {:.3f}" 183 | 184 | is_best = valid_acc > self.best_valid_acc 185 | if is_best: 186 | self.counter = 0 187 | msg2 += " [*]" 188 | 189 | msg = msg1 + msg2 190 | print(msg.format(train_loss, train_acc, valid_loss, valid_acc)) 191 | 192 | # check for improvement 193 | if not is_best: 194 | self.counter += 1 195 | ''' 196 | if self.counter > self.train_patience: 197 | print("[!] No improvement in a while, stopping training.") 198 | return 199 | ''' 200 | 201 | # decay lr 202 | self.scheduler.step() 203 | 204 | self.best_valid_acc = max(valid_acc, self.best_valid_acc) 205 | self.save_checkpoint( 206 | {'epoch': epoch + 1, 207 | 'model_state': self.model.state_dict(), 208 | 'optim_state': self.optimizer.state_dict(), 209 | 'scheduler_state': self.scheduler.state_dict(), 210 | 'best_valid_acc': self.best_valid_acc 211 | }, is_best 212 | ) 213 | 214 | if self.use_tensorboard: 215 | self.writer.close() 216 | 217 | print(self.best_valid_acc) 218 | 219 | def train_one_epoch(self, epoch): 220 | """ 221 | Train the model for 1 epoch of the training set. 222 | 223 | An epoch corresponds to one full pass through the entire 224 | training set in successive mini-batches. 225 | 226 | This is used by train() and should not be called manually. 
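Returns the average training loss and accuracy over the epoch.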
227 | """ 228 | self.model.train() 229 | 230 | losses = AverageMeter() 231 | accs = AverageMeter() 232 | 233 | tic = time.time() 234 | with tqdm(total=self.num_train) as pbar: 235 | for i, (x, y) in enumerate(self.train_loader): 236 | x, y = x.to(device), y.to(device) 237 | 238 | b = x.shape[0] 239 | out = self.model(x) 240 | if isinstance(self.loss, EmRoutingLoss): 241 | loss = self.loss(out, y, epoch=epoch) 242 | else: 243 | loss = self.loss(out, y) 244 | 245 | # compute accuracy 246 | pred = torch.max(out, 1)[1] 247 | correct = (pred == y).float() 248 | acc = 100 * (correct.sum() / len(y)) 249 | 250 | # store 251 | losses.update(loss.data.item(), x.size()[0]) 252 | accs.update(acc.data.item(), x.size()[0]) 253 | 254 | # compute gradients and update SGD 255 | self.optimizer.zero_grad() 256 | loss.backward() 257 | self.optimizer.step() 258 | 259 | # measure elapsed time 260 | toc = time.time() 261 | pbar.set_description( 262 | ( 263 | "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( 264 | (toc-tic), loss.data.item(), acc.data.item() 265 | ) 266 | ) 267 | ) 268 | pbar.update(b) 269 | 270 | if self.use_tensorboard: 271 | iteration = epoch*len(self.train_loader) + i 272 | self.writer.add_scalar('train_loss', loss, iteration) 273 | self.writer.add_scalar('train_acc', acc, iteration) 274 | 275 | return losses.avg, accs.avg 276 | 277 | def validate(self, epoch): 278 | """ 279 | Evaluate the model on the validation set. 280 | """ 281 | self.model.eval() 282 | 283 | losses = AverageMeter() 284 | accs = AverageMeter() 285 | 286 | for i, (x, y) in enumerate(self.valid_loader): 287 | x, y = x.to(device), y.to(device) 288 | 289 | out = self.model(x) 290 | if isinstance(self.loss, EmRoutingLoss): 291 | loss = self.loss(out, y, epoch=epoch) 292 | else: 293 | loss = self.loss(out, y) 294 | 295 | # compute accuracy 296 | pred = torch.max(out, 1)[1] 297 | correct = (pred == y).float() 298 | acc = 100 * (correct.sum() / len(y)) 299 | 300 | # store 301 | losses.update(loss.data.item(), x.size()[0]) 302 | accs.update(acc.data.item(), x.size()[0]) 303 | 304 | # log to tensorboard 305 | if self.use_tensorboard: 306 | self.writer.add_scalar('valid_loss', losses.avg, epoch) 307 | self.writer.add_scalar('valid_acc', accs.avg, epoch) 308 | 309 | return losses.avg, accs.avg 310 | 311 | def test(self): 312 | """ 313 | Test the model on the held-out test data. 314 | This function should only be called at the very 315 | end once the model has finished training. 316 | """ 317 | correct = 0 318 | 319 | # load the best checkpoint 320 | self.load_checkpoint(best=self.best) 321 | self.model.eval() 322 | 323 | for i, (x, y) in enumerate(self.test_loader): 324 | x, y = x.to(device), y.to(device) 325 | 326 | out = self.model(x) 327 | 328 | # compute accuracy 329 | pred = torch.max(out, 1)[1] 330 | correct += pred.eq(y.data.view_as(pred)).cpu().sum() 331 | 332 | perc = (100. 
* correct.data.item()) / (self.num_test)
333 |         error = 100 - perc
334 |         print(
335 |             '[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
336 |                 correct, self.num_test, perc, error)
337 |         )
338 | 
339 |     def test_attack(self):
340 |         correct = 0
341 |         self.load_checkpoint(best=self.best)
342 |         self.model.eval()
343 | 
344 |         # prepare adv attack
345 |         attacker = Attack(self.model, self.loss, self.attack_type, self.attack_eps)
346 |         adv_data, num_examples = extract_adv_images(attacker, self.test_loader, self.targeted, DATASET_CONFIGS[self.config.dataset]["classes"])
347 | 
348 |         with torch.no_grad():
349 |             for i, (x, y) in enumerate(adv_data):
350 |                 x, y = x.to(device), y.to(device)
351 | 
352 |                 out = self.model(x)
353 | 
354 |                 # compute accuracy
355 |                 pred = torch.max(out, 1)[1]
356 |                 correct += pred.eq(y.data.view_as(pred)).cpu().sum()
357 | 
358 |         if self.targeted:
359 |             success = correct
360 |         else:
361 |             success = num_examples - correct
362 | 
363 |         perc = (100. * success.data.item()) / (num_examples)
364 | 
365 |         print(
366 |             '[*] Attack success rate ({}, targeted={}, eps={}): {}/{} ({:.2f}% - {:.2f}%)'.format(
367 |                 self.attack_type, self.targeted, self.attack_eps, success, num_examples, perc, 100. - perc)
368 |         )
369 | 
370 |     def save_checkpoint(self, state, is_best):
371 |         """
372 |         Save a copy of the model so that it can be loaded at a future
373 |         date. This function is used when the model is being evaluated
374 |         on the test data.
375 | 
376 |         If this model has reached the best validation accuracy thus
377 |         far, a separate file with the suffix `best` is created.
378 |         """
379 |         # print("[*] Saving model to {}".format(self.ckpt_dir))
380 | 
381 |         filename = self.name + '_ckpt.pth.tar'
382 |         ckpt_path = os.path.join(self.ckpt_dir, filename)
383 |         torch.save(state, ckpt_path)
384 | 
385 |         if is_best:
386 |             filename = self.name + '_model_best.pth.tar'
387 |             shutil.copyfile(
388 |                 ckpt_path, os.path.join(self.ckpt_dir, filename)
389 |             )
390 | 
391 |     def load_checkpoint(self, best=False):
392 |         """
393 |         Load a saved copy of the model. This is useful for 2 cases:
394 | 
395 |         - Resuming training with the most recent model checkpoint.
396 |         - Loading the best validation model to evaluate on the test data.
397 | 
398 |         Params
399 |         ------
400 |         - best: if set to True, loads the best model. Use this if you want
401 |             to evaluate your model on the test data. Else, set to False in
402 |             which case the most recent version of the checkpoint is used.
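        Note: besides the model weights, the checkpoint restores the optimizer
        and scheduler state, so a resumed run continues with the same
        learning-rate schedule.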
403 | """ 404 | print("[*] Loading model from {}".format(self.ckpt_dir)) 405 | 406 | filename = self.name + '_ckpt.pth.tar' 407 | if best: 408 | filename = self.name + '_model_best.pth.tar' 409 | ckpt_path = os.path.join(self.ckpt_dir, filename) 410 | ckpt = torch.load(ckpt_path) 411 | 412 | # load variables from checkpoint 413 | self.start_epoch = ckpt['epoch'] 414 | self.best_valid_acc = ckpt['best_valid_acc'] 415 | self.model.load_state_dict(ckpt['model_state']) 416 | self.optimizer.load_state_dict(ckpt['optim_state']) 417 | self.scheduler.load_state_dict(ckpt['scheduler_state']) 418 | 419 | if best: 420 | print( 421 | "[*] Loaded {} checkpoint @ epoch {} " 422 | "with best valid acc of {:.3f}".format( 423 | filename, ckpt['epoch'], ckpt['best_valid_acc']) 424 | ) 425 | else: 426 | print( 427 | "[*] Loaded {} checkpoint @ epoch {}".format( 428 | filename, ckpt['epoch']) 429 | ) 430 | 431 | -------------------------------------------------------------------------------- /norb.py: -------------------------------------------------------------------------------- 1 | # Loader taken from https://github.com/mavanb/vision/blob/448fac0f38cab35a387666d553b9d5e4eec4c5e6/torchvision/datasets/utils.py 2 | 3 | from __future__ import print_function 4 | import os 5 | import errno 6 | import struct 7 | 8 | import torch 9 | import torch.utils.data as data 10 | import numpy as np 11 | from PIL import Image 12 | from torchvision.datasets.utils import download_url, check_integrity 13 | 14 | 15 | class smallNORB(data.Dataset): 16 | """`MNIST `_ Dataset. 17 | Args: 18 | root (string): Root directory of dataset where processed folder and 19 | and raw folder exist. 20 | train (bool, optional): If True, creates dataset from the training files, 21 | otherwise from the test files. 22 | download (bool, optional): If true, downloads the dataset from the internet and 23 | puts it in root directory. If the dataset is already processed, it is not processed 24 | and downloaded again. If dataset is only already downloaded, it is not 25 | downloaded again. 26 | transform (callable, optional): A function/transform that takes in an PIL image 27 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 28 | target_transform (callable, optional): A function/transform that takes in the 29 | target and transforms it. 30 | info_transform (callable, optional): A function/transform that takes in the 31 | info and transforms it. 32 | mode (string, optional): Denotes how the images in the data files are returned. Possible values: 33 | - all (default): both left and right are included separately. 34 | - stereo: left and right images are included as corresponding pairs. 35 | - left: only the left images are included. 36 | - right: only the right images are included. 
37 | """ 38 | 39 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/" 40 | data_files = { 41 | 'train': { 42 | 'dat': { 43 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat', 44 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2", 45 | "md5": "8138a0902307b32dfa0025a36dfa45ec" 46 | }, 47 | 'info': { 48 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat', 49 | "md5_gz": "51dee1210a742582ff607dfd94e332e3", 50 | "md5": "19faee774120001fc7e17980d6960451" 51 | }, 52 | 'cat': { 53 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat', 54 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9", 55 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 56 | }, 57 | }, 58 | 'test': { 59 | 'dat': { 60 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat', 61 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071", 62 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c" 63 | }, 64 | 'info': { 65 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat', 66 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e", 67 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc" 68 | }, 69 | 'cat': { 70 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat', 71 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603", 72 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 73 | }, 74 | }, 75 | } 76 | 77 | raw_folder = 'raw' 78 | processed_folder = 'processed' 79 | train_image_file = 'train_img' 80 | train_label_file = 'train_label' 81 | train_info_file = 'train_info' 82 | test_image_file = 'test_img' 83 | test_label_file = 'test_label' 84 | test_info_file = 'test_info' 85 | extension = '.pt' 86 | 87 | def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False, 88 | mode="all"): 89 | 90 | self.root = os.path.expanduser(root) 91 | self.transform = transform 92 | self.target_transform = target_transform 93 | self.info_transform = info_transform 94 | self.train = train # training set or test set 95 | self.mode = mode 96 | 97 | if download: 98 | self.download() 99 | 100 | if not self._check_exists(): 101 | raise RuntimeError('Dataset not found or corrupted.' 
+
102 |                                ' You can use download=True to download it')
103 | 
104 |         # load test or train set
105 |         image_file = self.train_image_file if self.train else self.test_image_file
106 |         label_file = self.train_label_file if self.train else self.test_label_file
107 |         info_file = self.train_info_file if self.train else self.test_info_file
108 | 
109 |         # load labels
110 |         self.labels = self._load(label_file)
111 | 
112 |         # load info files
113 |         self.infos = self._load(info_file)
114 | 
115 |         # load left set
116 |         if self.mode == "left":
117 |             self.data = self._load("{}_left".format(image_file))
118 | 
119 |         # load right set
120 |         elif self.mode == "right":
121 |             self.data = self._load("{}_right".format(image_file))
122 | 
123 |         elif self.mode == "all" or self.mode == "stereo":
124 |             left_data = self._load("{}_left".format(image_file))
125 |             right_data = self._load("{}_right".format(image_file))
126 | 
127 |             # load stereo
128 |             if self.mode == "stereo":
129 |                 self.data = torch.stack((left_data, right_data), dim=1)
130 | 
131 |             # load all
132 |             else:
133 |                 self.data = torch.cat((left_data, right_data), dim=0)
134 | 
135 |     def __getitem__(self, index):
136 |         """
137 |         Args:
138 |             index (int): Index
139 |         Returns:
140 |             mode ``all'', ``left'', ``right'':
141 |                 tuple: (image, target)
142 |             mode ``stereo'':
143 |                 tuple: (image left, image right, target, info)
144 |         """
145 |         target = self.labels[index % 24300] if self.mode == "all" else self.labels[index]  # left and right halves share the same 24300 labels
146 |         if self.target_transform is not None:
147 |             target = self.target_transform(target)
148 | 
149 |         info = self.infos[index % 24300] if self.mode == "all" else self.infos[index]
150 |         if self.info_transform is not None:
151 |             info = self.info_transform(info)
152 | 
153 |         if self.mode == "stereo":
154 |             img_left = self._transform(self.data[index, 0])
155 |             img_right = self._transform(self.data[index, 1])
156 |             return img_left, img_right, target, info
157 | 
158 |         img = self._transform(self.data[index])
159 |         return img, target
160 | 
161 |     def __len__(self):
162 |         return len(self.data)
163 | 
164 |     def _transform(self, img):
165 |         # doing this so that it is consistent with all other datasets
166 |         # to return a PIL Image
167 |         img = Image.fromarray(img.numpy(), mode='L')
168 | 
169 |         if self.transform is not None:
170 |             img = self.transform(img)
171 |         return img
172 | 
173 |     def _load(self, file_name):
174 |         return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
175 | 
176 |     def _save(self, file, file_name):
177 |         with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
178 |             torch.save(file, f)
179 | 
180 |     def _check_exists(self):
181 |         """ Check if processed files exist."""
182 |         files = (
183 |             "{}_left".format(self.train_image_file),
184 |             "{}_right".format(self.train_image_file),
185 |             "{}_left".format(self.test_image_file),
186 |             "{}_right".format(self.test_image_file),
187 |             self.test_label_file,
188 |             self.train_label_file
189 |         )
190 |         fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
191 |         return False not in fpaths
192 | 
193 |     def _flat_data_files(self):
194 |         return [j for i in self.data_files.values() for j in list(i.values())]
195 | 
196 |     def _check_integrity(self):
197 |         """Check if unpacked files have correct md5 sum."""
198 |         root = self.root
199 |         for file_dict in self._flat_data_files():
200 |             filename = file_dict["name"]
201 |             md5 = file_dict["md5"]
202 |             fpath = os.path.join(root, self.raw_folder, filename)
203 |             if not check_integrity(fpath, md5):
204 |                 return False
205 |         return True
206 | 
207 |     def download(self):
208 |         """Download the SmallNORB data if it doesn't exist in processed_folder already."""
209 |         import gzip
210 | 
211 |         if self._check_exists():
212 |             return
213 | 
214 |         # check if already extracted and verified
215 |         if self._check_integrity():
216 |             print('Files already downloaded and verified')
217 |         else:
218 |             # download and extract
219 |             for file_dict in self._flat_data_files():
220 |                 url = self.dataset_root + file_dict["name"] + '.gz'
221 |                 filename = file_dict["name"]
222 |                 gz_filename = filename + '.gz'
223 |                 md5 = file_dict["md5_gz"]
224 |                 fpath = os.path.join(self.root, self.raw_folder, filename)
225 |                 gz_fpath = fpath + '.gz'
226 | 
227 |                 # download unless the compressed file already exists and is verified
228 |                 download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
229 | 
230 |                 print('# Extracting data {}\n'.format(filename))
231 | 
232 |                 with open(fpath, 'wb') as out_f, \
233 |                         gzip.GzipFile(gz_fpath) as zip_f:
234 |                     out_f.write(zip_f.read())
235 | 
236 |                 os.unlink(gz_fpath)
237 | 
238 |         # process and save as torch files
239 |         print('Processing...')
240 | 
241 |         # create processed folder
242 |         try:
243 |             os.makedirs(os.path.join(self.root, self.processed_folder))
244 |         except OSError as e:
245 |             if e.errno == errno.EEXIST:
246 |                 pass
247 |             else:
248 |                 raise
249 | 
250 |         # read train files
251 |         left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
252 |         train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
253 |         train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
254 | 
255 |         # read test files
256 |         left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
257 |         test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
258 |         test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
259 | 
260 |         # save training files
261 |         self._save(left_train_img, "{}_left".format(self.train_image_file))
262 |         self._save(right_train_img, "{}_right".format(self.train_image_file))
263 |         self._save(train_label, self.train_label_file)
264 |         self._save(train_info, self.train_info_file)
265 | 
266 |         # save test files
267 |         self._save(left_test_img, "{}_left".format(self.test_image_file))
268 |         self._save(right_test_img, "{}_right".format(self.test_image_file))
269 |         self._save(test_label, self.test_label_file)
270 |         self._save(test_info, self.test_info_file)
271 | 
272 |         print('Done!')
273 | 
274 |     @staticmethod
275 |     def _parse_header(file_pointer):
276 |         # Read magic number and ignore
277 |         struct.unpack('<BBBB', file_pointer.read(4))  # '<' means little-endian
278 | 
279 |         # Read the number of dimensions, then each dimension
280 |         dimensions = []
281 |         num_dims, = struct.unpack('<i', file_pointer.read(4))
282 |         for _ in range(num_dims):
283 |             dimensions.extend(struct.unpack('<i', file_pointer.read(4)))
284 | 
285 |         return dimensions
286 | 
287 |     # the binary readers below follow the mavanb loader cited at the top of this file
288 |     def _read_image_file(self, file_name):
289 |         fpath = os.path.join(self.root, self.raw_folder, file_name)
290 |         with open(fpath, mode='rb') as f:
291 |             dimensions = self._parse_header(f)
292 |             assert dimensions == [24300, 2, 96, 96]
293 |             num_samples, _, height, width = dimensions
294 | 
295 |             left_samples = np.zeros(shape=(num_samples, height, width), dtype=np.uint8)
296 |             right_samples = np.zeros(shape=(num_samples, height, width), dtype=np.uint8)
297 | 
298 |             for i in range(num_samples):
299 |                 # the left and right images of each pair are stored consecutively
300 |                 left_samples[i, :, :] = self._read_image(f, height, width)
301 |                 right_samples[i, :, :] = self._read_image(f, height, width)
302 | 
303 |         return torch.ByteTensor(left_samples), torch.ByteTensor(right_samples)
304 | 
305 |     @staticmethod
306 |     def _read_image(file_pointer, height, width):
307 |         """Read a single raw image and restore its 2D shape."""
308 |         image = struct.unpack('<' + height * width * 'B', file_pointer.read(height * width))
309 |         image = np.uint8(np.reshape(image, newshape=(height, width)))
310 |         return image
311 | 
312 |     def _read_label_file(self, file_name):
313 |         fpath = os.path.join(self.root, self.raw_folder, file_name)
314 |         with open(fpath, mode='rb') as f:
315 |             dimensions = self._parse_header(f)
316 |             assert dimensions == [24300]
317 |             num_samples = dimensions[0]
318 | 
319 |             struct.unpack('<BBBB', f.read(4))  # ignore pad
320 |             struct.unpack('<BBBB', f.read(4))  # ignore pad
321 | 
322 |             labels = np.zeros(shape=num_samples, dtype=np.int32)
323 |             for i in range(num_samples):
324 |                 category, = struct.unpack('<i', f.read(4))
325 |                 labels[i] = category
326 | 
327 |         return torch.LongTensor(labels)
328 | 
329 |     def _read_info_file(self, file_name):
330 |         fpath = os.path.join(self.root, self.raw_folder, file_name)
331 |         with open(fpath, mode='rb') as f:
332 |             dimensions = self._parse_header(f)
333 |             assert dimensions == [24300, 4]
334 |             num_samples, num_info = dimensions
335 | 
336 |             struct.unpack('<BBBB', f.read(4))  # ignore pad
337 | 
338 |             infos = np.zeros(shape=(num_samples, num_info), dtype=np.int32)
339 |             for r in range(num_samples):
340 |                 for c in range(num_info):
341 |                     info, = struct.unpack('<i', f.read(4))
342 |                     infos[r, c] = info
343 | 
344 |         return torch.LongTensor(infos)
345 | 
346 | 
347 | class smallNORBViewPoint(data.Dataset):
348 |     """smallNORB Dataset split by viewpoint for the familiar/novel-viewpoint experiments.
349 |     Args:
350 |         root (string): Root directory of dataset where the processed
351 |             and raw folders exist.
352 |         train (bool, optional): If True, creates dataset from the training files,
353 |             otherwise from the test files.
354 |         download (bool, optional): If true, downloads the dataset from the internet and
355 |             puts it in root directory. If the dataset is already processed, it is not
356 |             processed or downloaded again. If the dataset is already downloaded, it is
357 |             not downloaded again.
358 |         transform (callable, optional): A function/transform that takes in a PIL image
359 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
360 |         target_transform (callable, optional): A function/transform that takes in the
361 |             target and transforms it.
362 | info_transform (callable, optional): A function/transform that takes in the 363 | info and transforms it. 364 | mode (string, optional): Denotes how the images in the data files are returned. Possible values: 365 | - all (default): both left and right are included separately. 366 | - stereo: left and right images are included as corresponding pairs. 367 | - left: only the left images are included. 368 | - right: only the right images are included. 369 | """ 370 | 371 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/" 372 | data_files = { 373 | 'train': { 374 | 'dat': { 375 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat', 376 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2", 377 | "md5": "8138a0902307b32dfa0025a36dfa45ec" 378 | }, 379 | 'info': { 380 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat', 381 | "md5_gz": "51dee1210a742582ff607dfd94e332e3", 382 | "md5": "19faee774120001fc7e17980d6960451" 383 | }, 384 | 'cat': { 385 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat', 386 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9", 387 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 388 | }, 389 | }, 390 | 'test': { 391 | 'dat': { 392 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat', 393 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071", 394 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c" 395 | }, 396 | 'info': { 397 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat', 398 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e", 399 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc" 400 | }, 401 | 'cat': { 402 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat', 403 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603", 404 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 405 | }, 406 | }, 407 | } 408 | 409 | raw_folder = 'raw' 410 | processed_folder = 'processed' 411 | train_image_file = 'train_img' 412 | train_label_file = 'train_label' 413 | train_info_file = 'train_info' 414 | test_image_file = 'test_img' 415 | test_label_file = 'test_label' 416 | test_info_file = 'test_info' 417 | extension = '.pt' 418 | 419 | def __init__(self, root, exp='azimuth', train=True, familiar=True, transform=None, target_transform=None, info_transform=None, download=False, 420 | mode="all"): 421 | 422 | self.root = os.path.expanduser(root) 423 | self.transform = transform 424 | self.target_transform = target_transform 425 | self.info_transform = info_transform 426 | self.train = train # training set or test set 427 | self.familiar = familiar 428 | self.mode = mode 429 | 430 | if download: 431 | self.download() 432 | 433 | if not self._check_exists(): 434 | raise RuntimeError('Dataset not found or corrupted.' 
+
435 |                                ' You can use download=True to download it')
436 | 
437 |         # load test or train set
438 |         image_file = self.train_image_file if self.train else self.test_image_file
439 |         label_file = self.train_label_file if self.train else self.test_label_file
440 |         info_file = self.train_info_file if self.train else self.test_info_file
441 | 
442 |         # load labels
443 |         self.labels = self._load(label_file)
444 | 
445 |         # load info files
446 |         self.infos = self._load(info_file)
447 | 
448 |         # load left set
449 |         if self.mode == "left":
450 |             self.data = self._load("{}_left".format(image_file))
451 | 
452 |         # load right set
453 |         elif self.mode == "right":
454 |             self.data = self._load("{}_right".format(image_file))
455 | 
456 |         elif self.mode == "all" or self.mode == "stereo":
457 |             left_data = self._load("{}_left".format(image_file))
458 |             right_data = self._load("{}_right".format(image_file))
459 | 
460 |             # load stereo
461 |             if self.mode == "stereo":
462 |                 self.data = torch.stack((left_data, right_data), dim=1)
463 | 
464 |             # load all
465 |             else:
466 |                 self.data = torch.cat((left_data, right_data), dim=0)
467 | 
468 |         # prepare exp: the train split (and the 'familiar' test split) keeps only
469 |         # the viewpoints listed in train_anno; the novel test split keeps the rest
470 |         if exp == 'azimuth':
471 |             self.anno_dim = 2
472 |             self.train_anno = [0, 2, 4, 34, 32, 30]
473 |         elif exp == 'elevation':
474 |             self.anno_dim = 1
475 |             self.train_anno = [0, 1, 2]
476 |         else:
477 |             raise NotImplementedError
478 | 
479 |         indices = []
480 |         for i, info in enumerate(self.infos):
481 |             info = info[self.anno_dim].data.item()
482 |             if (info in self.train_anno) == (self.train or self.familiar):
483 |                 indices.append(i)
484 | 
485 |         self.data = self.data[indices + [i + 24300 for i in indices]] if self.mode == 'all' else self.data[indices]
486 |         self.labels = self.labels[indices]
487 |         self.infos = self.infos[indices]
488 | 
489 |     def __getitem__(self, index):
490 |         """
491 |         Args:
492 |             index (int): Index
493 |         Returns:
494 |             mode ``all'', ``left'', ``right'':
495 |                 tuple: (image, target)
496 |             mode ``stereo'':
497 |                 tuple: (image left, image right, target, info)
498 |         """
499 |         target = self.labels[index % len(self.infos)] if self.mode == "all" else self.labels[index]
500 |         if self.target_transform is not None:
501 |             target = self.target_transform(target)
502 | 
503 |         info = self.infos[index % len(self.infos)] if self.mode == "all" else self.infos[index]
504 |         if self.info_transform is not None:
505 |             info = self.info_transform(info)
506 | 
507 |         if self.mode == "stereo":
508 |             img_left = self._transform(self.data[index, 0])
509 |             img_right = self._transform(self.data[index, 1])
510 |             return img_left, img_right, target, info
511 | 
512 |         img = self._transform(self.data[index])
513 |         return img, target
514 | 
515 |     def __len__(self):
516 |         return len(self.data)
517 | 
518 |     def _transform(self, img):
519 |         # doing this so that it is consistent with all other datasets
520 |         # to return a PIL Image
521 |         img = Image.fromarray(img.numpy(), mode='L')
522 | 
523 |         if self.transform is not None:
524 |             img = self.transform(img)
525 |         return img
526 | 
527 |     def _load(self, file_name):
528 |         return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
529 | 
530 |     def _save(self, file, file_name):
531 |         with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
532 |             torch.save(file, f)
533 | 
534 |     def _check_exists(self):
535 |         """ Check if processed files exist."""
536 |         files = (
537 |             "{}_left".format(self.train_image_file),
"{}_right".format(self.train_image_file), 539 | "{}_left".format(self.test_image_file), 540 | "{}_right".format(self.test_image_file), 541 | self.test_label_file, 542 | self.train_label_file 543 | ) 544 | fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files] 545 | return False not in fpaths 546 | 547 | def _flat_data_files(self): 548 | return [j for i in self.data_files.values() for j in list(i.values())] 549 | 550 | def _check_integrity(self): 551 | """Check if unpacked files have correct md5 sum.""" 552 | root = self.root 553 | for file_dict in self._flat_data_files(): 554 | filename = file_dict["name"] 555 | md5 = file_dict["md5"] 556 | fpath = os.path.join(root, self.raw_folder, filename) 557 | if not check_integrity(fpath, md5): 558 | return False 559 | return True 560 | 561 | def download(self): 562 | """Download the SmallNORB data if it doesn't exist in processed_folder already.""" 563 | import gzip 564 | 565 | if self._check_exists(): 566 | return 567 | 568 | # check if already extracted and verified 569 | if self._check_integrity(): 570 | print('Files already downloaded and verified') 571 | else: 572 | # download and extract 573 | for file_dict in self._flat_data_files(): 574 | url = self.dataset_root + file_dict["name"] + '.gz' 575 | filename = file_dict["name"] 576 | gz_filename = filename + '.gz' 577 | md5 = file_dict["md5_gz"] 578 | fpath = os.path.join(self.root, self.raw_folder, filename) 579 | gz_fpath = fpath + '.gz' 580 | 581 | # download if compressed file not exists and verified 582 | download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5) 583 | 584 | print('# Extracting data {}\n'.format(filename)) 585 | 586 | with open(fpath, 'wb') as out_f, \ 587 | gzip.GzipFile(gz_fpath) as zip_f: 588 | out_f.write(zip_f.read()) 589 | 590 | os.unlink(gz_fpath) 591 | 592 | # process and save as torch files 593 | print('Processing...') 594 | 595 | # create processed folder 596 | try: 597 | os.makedirs(os.path.join(self.root, self.processed_folder)) 598 | except OSError as e: 599 | if e.errno == errno.EEXIST: 600 | pass 601 | else: 602 | raise 603 | 604 | # read train files 605 | left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"]) 606 | train_info = self._read_info_file(self.data_files["train"]["info"]["name"]) 607 | train_label = self._read_label_file(self.data_files["train"]["cat"]["name"]) 608 | 609 | # read test files 610 | left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"]) 611 | test_info = self._read_info_file(self.data_files["test"]["info"]["name"]) 612 | test_label = self._read_label_file(self.data_files["test"]["cat"]["name"]) 613 | 614 | # save training files 615 | self._save(left_train_img, "{}_left".format(self.train_image_file)) 616 | self._save(right_train_img, "{}_right".format(self.train_image_file)) 617 | self._save(train_label, self.train_label_file) 618 | self._save(train_info, self.train_info_file) 619 | 620 | # save test files 621 | self._save(left_test_img, "{}_left".format(self.test_image_file)) 622 | self._save(right_test_img, "{}_right".format(self.test_image_file)) 623 | self._save(test_label, self.test_label_file) 624 | self._save(test_info, self.test_info_file) 625 | 626 | print('Done!') 627 | 628 | @staticmethod 629 | def _parse_header(file_pointer): 630 | # Read magic number and ignore 631 | struct.unpack('