├── misc
│   ├── moe.pdf
│   ├── capsnet.pdf
│   ├── self_routing.png
│   ├── routing_by_agreement.png
│   └── neurips2019-self_routing-poster.pdf
├── .gitignore
├── models
│   ├── __init__.py
│   ├── smallnet.py
│   ├── convnet.py
│   └── resnet.py
├── requirements.txt
├── loss.py
├── main.py
├── utils.py
├── README.md
├── attack.py
├── config.py
├── data_loader.py
├── modules.py
├── trainer.py
└── norb.py
/misc/moe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/moe.pdf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 |
4 | ckpt/**
5 | data/**
6 | logs/**
7 |
8 |
--------------------------------------------------------------------------------
/misc/capsnet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/capsnet.pdf
--------------------------------------------------------------------------------
/misc/self_routing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/self_routing.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .convnet import *
3 | from .smallnet import *
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.2.0
2 | torchvision==0.4.0
3 | tensorboardx==1.8
4 | scipy
5 | numpy
6 | tqdm
--------------------------------------------------------------------------------
/misc/routing_by_agreement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/routing_by_agreement.png
--------------------------------------------------------------------------------
/misc/neurips2019-self_routing-poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coder3000/SR-CapsNet/HEAD/misc/neurips2019-self_routing-poster.pdf
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.nn.modules.loss import _Loss
6 |
7 | from utils import one_hot
8 |
9 |
10 | class DynamicRoutingLoss(nn.Module):
11 |     def __init__(self):
12 |         super(DynamicRoutingLoss, self).__init__()
13 |
14 | def forward(self, x, target):
15 | target = one_hot(target, x.shape[1])
16 |
17 | left = F.relu(0.9 - x) ** 2
18 | right = F.relu(x - 0.1) ** 2
19 |
20 | margin_loss = target * left + 0.5 * (1. - target) * right
21 | margin_loss = margin_loss.sum(dim=1).mean()
22 | return margin_loss
23 |
24 |
25 | class EmRoutingLoss(nn.Module):
26 | def __init__(self, max_epoch):
27 | super(EmRoutingLoss, self).__init__()
28 | self.margin_init = 0.2
29 | self.margin_step = 0.2 / max_epoch
30 | self.max_epoch = max_epoch
31 |
32 | def forward(self, x, target, epoch=None):
33 | if epoch is None:
34 | margin = 0.9
35 | else:
36 | margin = self.margin_init + self.margin_step * min(epoch, self.max_epoch)
37 |
38 | b, E = x.shape
39 | at = x.new_zeros(b)
40 | for i, lb in enumerate(target):
41 | at[i] = x[i][lb]
42 | at = at.view(b, 1).repeat(1, E)
43 |
44 | zeros = x.new_zeros(x.shape)
45 | loss = torch.max(margin - (at - x), zeros)
46 | loss = loss**2
47 | loss = loss.sum(dim=1).mean()
48 | return loss
49 |
--------------------------------------------------------------------------------
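
As a quick sanity check of the margin loss above (a hypothetical snippet, run from the repo root; `one_hot` calls `.cuda()`, so a GPU is assumed):

```python
import torch
from loss import DynamicRoutingLoss

# Capsule lengths in [0, 1] for a batch of 2 examples and 3 classes.
lengths = torch.tensor([[0.9, 0.1, 0.2],
                        [0.2, 0.8, 0.1]]).cuda()
labels = torch.tensor([0, 1]).cuda()

criterion = DynamicRoutingLoss()
print(criterion(lengths, labels))  # small: correct capsules are long, others short
```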
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torchvision import datasets, transforms
4 |
5 | from trainer import Trainer
6 | from config import get_config
7 | from utils import prepare_dirs
8 | from data_loader import get_test_loader, get_train_valid_loader, VIEWPOINT_EXPS
9 |
10 |
11 | torch.backends.cudnn.deterministic = True
12 | torch.backends.cudnn.benchmark = False
13 |
14 |
15 | def main(config):
16 |
17 | # ensure directories are setup
18 | prepare_dirs(config)
19 |
20 | # ensure reproducibility
21 | torch.manual_seed(config.random_seed)
22 | kwargs = {}
23 | if torch.cuda.is_available():
24 | torch.cuda.manual_seed(config.random_seed)
25 | kwargs = {'num_workers': 4, 'pin_memory': False}
26 |
27 | # instantiate data loaders
28 | if config.is_train:
29 | data_loader = get_train_valid_loader(
30 | config.data_dir, config.dataset, config.batch_size,
31 | config.random_seed, config.exp, config.valid_size,
32 | config.shuffle, **kwargs
33 | )
34 | else:
35 | data_loader = get_test_loader(
36 | config.data_dir, config.dataset, config.batch_size, config.exp, config.familiar,
37 | **kwargs
38 | )
39 |
40 | # instantiate trainer
41 | trainer = Trainer(config, data_loader)
42 |
43 | if config.is_train:
44 | trainer.train()
45 | else:
46 | if config.attack:
47 | trainer.test_attack()
48 | else:
49 | trainer.test()
50 |
51 | if __name__ == '__main__':
52 | config, unparsed = get_config()
53 | main(config)
54 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import json
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | class AverageMeter(object):
9 | """
10 | Computes and stores the average and
11 | current value.
12 | """
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def prepare_dirs(config):
30 | for path in [config.data_dir, config.ckpt_dir, config.logs_dir]:
31 | if not os.path.exists(path):
32 | os.makedirs(path)
33 |
34 | def save_config(model_name, config):
35 | filename = model_name + '_params.json'
36 | param_path = os.path.join(config.ckpt_dir, filename)
37 |
38 | print("[*] Model Checkpoint Dir: {}".format(config.ckpt_dir))
39 | print("[*] Param Path: {}".format(param_path))
40 |
41 | with open(param_path, 'w') as fp:
42 | json.dump(config.__dict__, fp, indent=4, sort_keys=True)
43 |
44 | def one_hot(y, n_dims):
45 | scatter_dim = len(y.size())
46 | y_tensor = y.view(*y.size(), -1)
47 |     zeros = torch.zeros(*y.size(), n_dims).cuda()  # note: assumes a CUDA device
48 | return zeros.scatter(scatter_dim, y_tensor, 1)
49 |
50 | # dynamic routing
51 | def squash(s, dim=-1):
52 | mag_sq = torch.sum(s**2, dim=dim, keepdim=True)
53 | mag = torch.sqrt(mag_sq)
54 |     v = (mag_sq / (1.0 + mag_sq)) * (s / (mag + 1e-8))  # small eps avoids NaN when s == 0
55 | return v
56 |
57 | def weights_init(m):
58 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
59 | nn.init.kaiming_uniform_(m.weight.data)
60 | elif isinstance(m, nn.BatchNorm2d):
61 | nn.init.constant_(m.weight, 1)
62 | nn.init.constant_(m.bias, 0)
63 |
64 |
--------------------------------------------------------------------------------
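
The `squash` nonlinearity above maps a vector s to v = (||s||² / (1 + ||s||²)) · s/||s||, so output norms always lie in (0, 1). A quick check (hypothetical, run from the repo root):

```python
import torch
from utils import squash

s = torch.randn(4, 8) * 5
v = squash(s)
print(v.norm(dim=-1))  # all entries in (0, 1); long inputs saturate toward 1
```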
/README.md:
--------------------------------------------------------------------------------
1 | # SR-CapsNet
2 |
3 | PyTorch implementation of our NeurIPS 2019 paper [**Self-Routing Capsule Networks**](https://papers.nips.cc/paper/8982-self-routing-capsule-networks).
4 |
5 |
6 | [[poster]](https://github.com/coder3000/SR-CapsNet/blob/master/misc/neurips2019-self_routing-poster.pdf)
7 |
8 | ## Prerequisites
9 | - Python >= 3.5.2
10 | - A GPU with CUDA >= 9.0 support
11 |
12 | Install the required packages by running:
13 | ```
14 | pip3 install -r requirements.txt
15 | ```
16 |
17 |
18 | ## Training
19 | To train a model for CIFAR-10 or SVHN, run:
20 | ```
21 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --epochs=350
22 | python3 main.py --dataset=svhn --name=resnet_[routing_method] --epochs=200
23 | ```
24 |
25 | `routing_method` should be one of `[avg, max, fc, dynamic_routing, em_routing, self_routing]`. This modifies the last layers of the base model accordingly.
26 |
27 |
28 | For SmallNORB, run:
29 |
30 | ```
31 | python3 main.py --dataset=smallnorb --name=convnet_[routing_method] --epochs=200 --exp=elevation
32 | ```
33 |
34 | Here `--exp` denotes the viewpoint on which the data is split.
35 |
36 | See `config.py` for more options and their descriptions.
37 |
38 | ## Testing
39 | To test a model, simply run:
40 |
41 | ```
42 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --is_train=False
43 | ```
44 |
45 | You can perform adversarial attacks against a trained model by:
46 | ```
47 | python3 main.py --dataset=cifar10 --name=resnet_[routing_method] --is_train=False --attack=True --attack_type=bim --attack_eps=0.1 --targeted=False
48 | ```
49 |
50 | For SmallNORB, you can test against novel viewpoints by:
51 | ```
52 | python3 main.py --dataset=smallnorb --name=convnet_[routing_method] --is_train=False --familiar=False
53 | ```
54 |
55 |
56 | ## Citation
57 | ```
58 | @inproceedings{hahn2019,
59 | title={Self-Routing Capsule Networks},
60 | author={Hahn, Taeyoung and Pyeon, Myeongjang and Kim, Gunhee},
61 | booktitle={NeurIPS},
62 | year={2019}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
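
As a companion note to the README above: the `--name` suffix selects the routing head, while the prefix (`resnet`, `convnet`, `smallnet`) selects the backbone. A condensed sketch of the dispatch in `trainer.py` (the dict is our paraphrase; the actual code is an `if/elif` `endswith` chain):

```python
# Condensed from trainer.py: config.name.endswith(<suffix>) picks the mode.
SUFFIX_TO_MODE = {
    'dynamic_routing': 'DR',
    'em_routing': 'EM',
    'self_routing': 'SR',
    'max': 'MAX',
    'avg': 'AVG',
    'fc': 'FC',
}
mode = next(m for s, m in SUFFIX_TO_MODE.items()
            if 'resnet_self_routing'.endswith(s))
print(mode)  # SR
```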
/attack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 |
7 | import random
8 |
9 | random.seed(2019)
10 |
11 | class Attack(object):
12 | def __init__(self, net, criterion, attack_type, eps):
13 | self.net = net
14 | self.criterion = criterion
15 | self.attack_type = attack_type
16 |
17 | if attack_type not in ["bim", "fgsm"]:
18 | raise NotImplementedError("Unknown attack type")
19 |
20 | self.eps = eps
21 |
22 | def make(self, x, y, target):
23 | return getattr(self, self.attack_type)(x, y, target=target)
24 |
25 | def bim(self, x, y, target=None, x_val_min=-1, x_val_max=1):
26 | out = self.net(x)
27 | pred = torch.max(out, 1)[1]
28 |
29 | if pred.item() != y.item():
30 | return None
31 |
32 | eta = torch.zeros_like(x)
33 | iters = 10
34 | eps_iter = self.eps / iters
35 | for i in range(iters):
36 | nx = x + eta
37 | nx.requires_grad_()
38 |
39 | out = self.net(nx)
40 |
41 | self.net.zero_grad()
42 | if target is not None:
43 | cost = self.criterion(out, target)
44 | else:
45 | cost = -self.criterion(out, y)
46 | cost.backward()
47 |
48 | eta -= eps_iter * torch.sign(nx.grad.data)
49 | eta.clamp_(-self.eps, self.eps)
50 | nx.grad.data.zero_()
51 |
52 | x_adv = x + eta
53 | x_adv.clamp_(x_val_min, x_val_max)
54 |
55 | if target is not None:
56 | return x_adv.detach(), target
57 |
58 | return x_adv.detach(), y
59 |
60 | def fgsm(self, x, y, target=None, x_val_min=-1, x_val_max=1):
61 | data = Variable(x.data, requires_grad=True)
62 | out = self.net(data)
63 | pred = torch.max(out, 1)[1]
64 |
65 | if pred.item() != y.item():
66 | return None
67 |
68 | if target is not None:
69 | cost = self.criterion(out, target)
70 | else:
71 | cost = -self.criterion(out, y)
72 |
73 | self.net.zero_grad()
74 | if data.grad is not None:
75 | data.grad.data.fill_(0)
76 | cost.backward()
77 |
78 | data.grad.sign_()
79 | data = data - self.eps * data.grad
80 | x_adv = torch.clamp(data, x_val_min, x_val_max)
81 |
82 | if target is not None:
83 | return x_adv, target
84 |
85 | return x_adv, y
86 |
87 | def extract_adv_images(attacker, dataloader, targeted, classes=10):
88 | adv_images = []
89 | num_examples = 0
90 | for batch, (x, y) in enumerate(dataloader):
91 | x, y = x.cuda(), y.cuda()
92 | curr_x_adv_batch = []
93 | curr_y_batch = []
94 | for i in range(len(y)):
95 | if targeted:
96 | y_new = y[i] + 1
97 | if y_new == classes:
98 | y_new = 0
99 | target = y.new_zeros(1)
100 | target[0] = y_new
101 | gg = attacker.make(x[i:i+1], y[i:i+1], target=target)
102 | else:
103 | gg = attacker.make(x[i:i+1], y[i:i+1], target=None)
104 |
105 | if gg is not None:
106 | curr_x_adv_batch.append(gg[0])
107 | curr_y_batch.append(gg[1])
108 | num_examples += 1
109 |
110 |         if curr_x_adv_batch:  # a whole batch may already be misclassified
111 |             adv_images.append((torch.cat(curr_x_adv_batch, dim=0),
112 |                                torch.cat(curr_y_batch, dim=0)))
113 |
114 | if batch == 20:
115 | break
116 |
117 | return adv_images, num_examples
118 |
119 |
120 |
--------------------------------------------------------------------------------
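
A minimal usage sketch for the attack helpers above (hypothetical; the toy model and random data are stand-ins, and a GPU is assumed since `extract_adv_images` moves tensors to CUDA):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from attack import Attack, extract_adv_images

# Toy stand-ins: a linear classifier over 3x32x32 "images", 10 classes.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).cuda()
loader = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32),
                                  torch.randint(0, 10, (8,))), batch_size=4)

attacker = Attack(net, nn.CrossEntropyLoss(), attack_type='fgsm', eps=0.1)
adv_batches, num_examples = extract_adv_images(attacker, loader, targeted=False)
print(num_examples)  # only examples the net classified correctly are attacked
```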
/models/smallnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from modules import *
6 |
7 | from utils import weights_init
8 |
9 |
10 | class SmallNet(nn.Module):
11 | def __init__(self, cfg_data, mode='SR'):
12 | super(SmallNet, self).__init__()
13 | channels, classes = cfg_data['channels'], cfg_data['classes']
14 | self.conv1 = nn.Conv2d(channels, 256, kernel_size=7, stride=2, padding=1, bias=False)
15 | self.bn1 = nn.BatchNorm2d(256)
16 |
17 | self.mode = mode
18 |
19 | self.num_caps = 16
20 |
21 | planes = 16
22 | last_size = 6
23 | if self.mode == 'SR':
24 | self.conv_a = nn.Conv2d(256, self.num_caps, kernel_size=5, stride=1, padding=1, bias=False)
25 | self.conv_pose = nn.Conv2d(256, self.num_caps*planes, kernel_size=5, stride=1, padding=1, bias=False)
26 | self.bn_a = nn.BatchNorm2d(self.num_caps)
27 | self.bn_pose = nn.BatchNorm2d(self.num_caps*planes)
28 |
29 | self.conv_caps = SelfRouting2d(self.num_caps, self.num_caps, planes, planes, kernel_size=3, stride=2, padding=1, pose_out=True)
30 | self.bn_pose_conv_caps = nn.BatchNorm2d(self.num_caps*planes)
31 |
32 | self.fc_caps = SelfRouting2d(self.num_caps, classes, planes, 1, kernel_size=last_size, padding=0, pose_out=False)
33 |
34 | elif self.mode == 'DR':
35 | self.conv_pose = nn.Conv2d(256, self.num_caps*planes, kernel_size=5, stride=1, padding=1, bias=False)
36 | # self.bn_pose = nn.BatchNorm2d(self.num_caps*planes)
37 |
38 | self.conv_caps = DynamicRouting2d(self.num_caps, self.num_caps, 16, 16, kernel_size=3, stride=2, padding=1)
39 | nn.init.normal_(self.conv_caps.W, 0, 0.5)
40 |
41 | self.fc_caps = DynamicRouting2d(self.num_caps, classes, 16, 16, kernel_size=last_size, padding=0)
42 | nn.init.normal_(self.fc_caps.W, 0, 0.05)
43 |
44 | elif self.mode == 'EM':
45 | self.conv_a = nn.Conv2d(256, self.num_caps, kernel_size=5, stride=1, padding=1, bias=False)
46 | self.conv_pose = nn.Conv2d(256, self.num_caps*16, kernel_size=5, stride=1, padding=1, bias=False)
47 | self.bn_a = nn.BatchNorm2d(self.num_caps)
48 | self.bn_pose = nn.BatchNorm2d(self.num_caps*16)
49 |
50 | self.conv_caps = EmRouting2d(self.num_caps, self.num_caps, 16, kernel_size=3, stride=2, padding=1)
51 | self.bn_pose_conv_caps = nn.BatchNorm2d(self.num_caps*planes)
52 |
53 | self.fc_caps = EmRouting2d(self.num_caps, classes, 16, kernel_size=last_size, padding=0)
54 |
55 | else:
56 | raise NotImplementedError
57 |
58 | self.apply(weights_init)
59 |
60 | def forward(self, x):
61 | out = F.relu(self.bn1(self.conv1(x)))
62 |
63 | if self.mode == 'DR':
64 | # pose = self.bn_pose(self.conv_pose(out))
65 | pose = self.conv_pose(out)
66 |
67 | b, c, h, w = pose.shape
68 | pose = pose.permute(0, 2, 3, 1).contiguous()
69 | pose = squash(pose.view(b, h, w, self.num_caps, 16))
70 | pose = pose.view(b, h, w, -1)
71 | pose = pose.permute(0, 3, 1, 2)
72 |
73 | pose = self.conv_caps(pose)
74 |
75 | out = self.fc_caps(pose)
76 | out = out.view(b, -1, 16)
77 | out = out.norm(dim=-1)
78 |
79 | elif self.mode == 'EM':
80 | a, pose = self.conv_a(out), self.conv_pose(out)
81 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
82 |
83 | a, pose = self.conv_caps(a, pose)
84 | pose = self.bn_pose_conv_caps(pose)
85 |
86 | a, _ = self.fc_caps(a, pose)
87 | out = a.view(a.size(0), -1)
88 |
89 | elif self.mode == 'SR':
90 | a, pose = self.conv_a(out), self.conv_pose(out)
91 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
92 |
93 | a, pose = self.conv_caps(a, pose)
94 | pose = self.bn_pose_conv_caps(pose)
95 |
96 | a, _ = self.fc_caps(a, pose)
97 |
98 | out = a.view(a.size(0), -1)
99 | out = out.log()
100 |
101 | return out
102 |
103 |
104 |
--------------------------------------------------------------------------------
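
A shape sanity check for `SmallNet` above (hypothetical, run from the repo root so `modules` and `utils` resolve). The 7x7/stride-2 stem and 5x5 capsule convs take 32x32 inputs down to the `last_size = 6` grid assumed by `fc_caps`:

```python
import torch
from data_loader import DATASET_CONFIGS
from models.smallnet import SmallNet

net = SmallNet(DATASET_CONFIGS['smallnorb'], mode='SR')
out = net(torch.randn(2, 1, 32, 32))
print(out.shape)  # [2, 5] log-activations, one per smallNORB class
```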
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | arg_lists = []
4 | parser = argparse.ArgumentParser(description='CapsNet')
5 |
6 | def str2bool(v):
7 | return v.lower() in ('true', '1')
8 |
9 | def add_argument_group(name):
10 | arg = parser.add_argument_group(name)
11 | arg_lists.append(arg)
12 | return arg
13 |
14 |
15 | # data params
16 | data_arg = add_argument_group('Data Params')
17 | data_arg.add_argument('--valid_size', type=float, default=0.1,
18 | help='Proportion of training set used for validation')
19 | data_arg.add_argument('--batch_size', type=int, default=64,
20 | help='# of images in each batch of data')
21 | data_arg.add_argument('--num_workers', type=int, default=4,
22 | help='# of subprocesses to use for data loading')
23 | data_arg.add_argument('--shuffle', type=str2bool, default=True,
24 | help='Whether to shuffle the train and valid indices')
25 |
26 |
27 | # training params
28 | train_arg = add_argument_group('Training Params')
29 | train_arg.add_argument('--is_train', type=str2bool, default=True,
30 | help='Whether to train or test the model')
31 | train_arg.add_argument('--momentum', type=float, default=0.9,
32 | help='Momentum value')
33 | train_arg.add_argument('--weight_decay', type=float, default=1e-4,
34 | help='Weight decay value')
35 | train_arg.add_argument('--epochs', type=int, default=350,
36 | help='# of epochs to train for')
37 | train_arg.add_argument('--init_lr', type=float, default=0.1,
38 | help='Initial learning rate value')
39 | train_arg.add_argument('--train_patience', type=int, default=100,
40 | help='Number of epochs to wait before stopping train')
41 | train_arg.add_argument('--dataset', type=str, default='cifar10',
42 | help='Dataset for training: {mnist, cifar10}')
43 | train_arg.add_argument('--planes', type=int, default=16,
44 | help='starting layer width')
45 | train_arg.add_argument('--num_caps', type=int, default=32,
46 | help="# of capsules per layer")
47 | train_arg.add_argument('--caps_size', type=int, default=16,
48 | help="# of neurons per capsule")
49 | train_arg.add_argument('--depth', type=int, default=1,
50 | help="depth of additional layers")
51 |
52 |
53 | # other params
54 | misc_arg = add_argument_group('Misc.')
55 | misc_arg.add_argument('--name', type=str, default=None,
56 | help='Name of model to load / save')
57 | misc_arg.add_argument('--best', type=str2bool, default=True,
58 | help='Load best model or most recent for testing')
59 | misc_arg.add_argument('--random_seed', type=int, default=2018,
60 | help='Seed to ensure reproducibility')
61 | misc_arg.add_argument('--data_dir', type=str, default='./data',
62 | help='Directory in which data is stored')
63 | misc_arg.add_argument('--ckpt_dir', type=str, default='./ckpt',
64 | help='Directory in which to save model checkpoints')
65 | misc_arg.add_argument('--logs_dir', type=str, default='./logs/',
66 |                         help='Directory in which Tensorboard logs will be stored')
67 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True,
68 | help='Whether to use tensorboard for visualization')
69 | misc_arg.add_argument('--resume', type=str2bool, default=False,
70 | help='Whether to resume training from checkpoint')
71 | misc_arg.add_argument('--print_freq', type=int, default=10,
72 | help='How frequently to print training details')
73 |
74 | misc_arg.add_argument('--attack', type=str2bool, default=False,
75 | help='Whether to test against attack')
76 | misc_arg.add_argument('--attack_type', type=str, default='fgsm',
77 |                         help='Attack to perform: {fgsm, bim}')
78 | misc_arg.add_argument('--attack_eps', type=float, default=0.1,
79 | help='eps for adv attack')
80 | misc_arg.add_argument('--targeted', type=str2bool, default=False,
81 | help='if true, do targeted attack')
82 | train_arg.add_argument('--exp', type=str, default='',
83 | help="viewpoint exp name (NULL, azimuth, elevation, full)")
84 | train_arg.add_argument('--familiar', type=str2bool, default=True,
85 | help="viewpoint exp setting (novel, familiar)")
86 |
87 |
88 | def get_config():
89 | config, unparsed = parser.parse_known_args()
90 | return config, unparsed
91 |
--------------------------------------------------------------------------------
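
`get_config` uses `parse_known_args`, so unrecognized flags are returned in `unparsed` rather than raising an error. A quick check of the defaults (hypothetical, run as a plain script):

```python
from config import get_config

config, unparsed = get_config()
print(config.dataset, config.batch_size, config.init_lr)  # cifar10 64 0.1
```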
/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torchvision import datasets
5 | from torchvision import transforms
6 | from torch.utils.data import Subset
7 | from norb import smallNORBViewPoint, smallNORB
8 |
9 |
10 | def get_train_valid_loader(data_dir,
11 | dataset,
12 | batch_size,
13 | random_seed,
14 | exp='azimuth',
15 | valid_size=0.1,
16 | shuffle=True,
17 | num_workers=4,
18 | pin_memory=False):
19 |
20 | data_dir = data_dir + '/' + dataset
21 |
22 | if dataset == "cifar10":
23 | trans = [transforms.RandomCrop(32, padding=4),
24 | transforms.RandomHorizontalFlip(0.5),
25 | transforms.ToTensor(),
26 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
27 | dataset = datasets.CIFAR10(data_dir, train=True, download=True,
28 | transform=transforms.Compose(trans))
29 |
30 | elif dataset == "svhn":
31 |         normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
32 | std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
33 | trans = [transforms.RandomCrop(32, padding=4),
34 | transforms.ToTensor(),
35 | normalize]
36 | dataset = datasets.SVHN(data_dir, split='train', download=True,
37 | transform=transforms.Compose(trans))
38 |
39 | elif dataset == "smallnorb":
40 | trans = [transforms.Resize(48),
41 | transforms.RandomCrop(32),
42 | transforms.ColorJitter(brightness=32./255, contrast=0.3),
43 | transforms.ToTensor(),
44 | #transforms.Normalize((0.7199,), (0.117,))
45 | ]
46 | if exp in VIEWPOINT_EXPS:
47 | train_set = smallNORBViewPoint(data_dir, exp=exp, train=True, download=True,
48 | transform=transforms.Compose(trans))
49 |                 trans = trans[:1] + [transforms.CenterCrop(32)] + trans[3:]  # swap the train-time crop/jitter for a deterministic center crop
50 | valid_set = smallNORBViewPoint(data_dir, exp=exp, train=False, familiar=False, download=False,
51 | transform=transforms.Compose(trans))
52 | elif exp == "full":
53 | dataset = smallNORB(data_dir, train=True, download=True,
54 | transform = transforms.Compose(trans))
55 |
56 | if exp not in VIEWPOINT_EXPS:
57 | num_train = len(dataset)
58 | indices = list(range(num_train))
59 | split = int(np.floor(valid_size * num_train))
60 |
61 | if shuffle:
62 | np.random.seed(random_seed)
63 | np.random.shuffle(indices)
64 |
65 | train_idx = indices[split:]
66 | valid_idx = indices[:split]
67 |
68 | train_set = Subset(dataset, train_idx)
69 | valid_set = Subset(dataset, valid_idx)
70 |
71 | train_loader = torch.utils.data.DataLoader(
72 | train_set, batch_size=batch_size, shuffle=True,
73 | num_workers=num_workers, pin_memory=pin_memory,
74 | )
75 |
76 | valid_loader = torch.utils.data.DataLoader(
77 | valid_set, batch_size=batch_size, shuffle=False,
78 | num_workers=num_workers, pin_memory=pin_memory,
79 | )
80 |
81 | return train_loader, valid_loader
82 |
83 | def get_test_loader(data_dir,
84 | dataset,
85 | batch_size,
86 | exp='azimuth', # smallnorb only
87 | familiar=True, # smallnorb only
88 | num_workers=4,
89 | pin_memory=False):
90 |
91 | data_dir = data_dir + '/' + dataset
92 |
93 | if dataset == "cifar10":
94 | trans = [transforms.ToTensor(),
95 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
96 | dataset = datasets.CIFAR10(data_dir, train=False, download=False,
97 | transform=transforms.Compose(trans))
98 |
99 | elif dataset == "svhn":
100 |         normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
101 | std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
102 | trans = [transforms.ToTensor(),
103 | normalize]
104 | dataset = datasets.SVHN(data_dir, split='test', download=True,
105 | transform=transforms.Compose(trans))
106 |
107 | elif dataset == "smallnorb":
108 | trans = [transforms.Resize(48),
109 | transforms.CenterCrop(32),
110 | transforms.ToTensor(),
111 | #transforms.Normalize((0.7199,), (0.117,))
112 | ]
113 | if exp in VIEWPOINT_EXPS:
114 | dataset = smallNORBViewPoint(data_dir, exp=exp, familiar=familiar, train=False, download=True,
115 | transform=transforms.Compose(trans))
116 | elif exp == "full":
117 | dataset = smallNORB(data_dir, train=False, download=True,
118 | transform=transforms.Compose(trans))
119 |
120 | data_loader = torch.utils.data.DataLoader(
121 | dataset, batch_size=batch_size, shuffle=False,
122 | num_workers=num_workers, pin_memory=pin_memory,
123 | )
124 |
125 | return data_loader
126 |
127 | DATASET_CONFIGS = {
128 | 'cifar10': {'size': 32, 'channels': 3, 'classes': 10},
129 | 'svhn': {'size': 32, 'channels': 3, 'classes': 10},
130 | 'smallnorb': {'size': 32, 'channels': 1, 'classes': 5},
131 | }
132 |
133 | VIEWPOINT_EXPS = ['azimuth', 'elevation']
134 |
--------------------------------------------------------------------------------
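
A minimal sketch of the loader API above (hypothetical; note the CIFAR-10 test loader passes `download=False`, so the test set must already sit under `./data/cifar10`):

```python
from data_loader import get_test_loader, DATASET_CONFIGS

print(DATASET_CONFIGS['cifar10'])  # {'size': 32, 'channels': 3, 'classes': 10}
loader = get_test_loader('./data', 'cifar10', batch_size=64)
x, y = next(iter(loader))
print(x.shape, y.shape)  # [64, 3, 32, 32], [64]
```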
/models/convnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 |
7 | from modules import *
8 | from utils import weights_init
9 |
10 |
11 | class ConvNet(nn.Module):
12 | def __init__(self, planes, cfg_data, num_caps, caps_size, depth, mode):
13 |         caps_size = 16  # note: the caps_size argument is overridden here
14 | super(ConvNet, self).__init__()
15 | channels, classes = cfg_data['channels'], cfg_data['classes']
16 | self.num_caps = num_caps
17 | self.caps_size = caps_size
18 | self.depth = depth
19 | self.mode = mode
20 |
21 | self.layers = nn.Sequential(
22 | nn.Conv2d(channels, planes, kernel_size=3, stride=1, padding=1, bias=False),
23 | nn.BatchNorm2d(planes),
24 | nn.ReLU(True),
25 | nn.Conv2d(planes, planes*2, kernel_size=3, stride=2, padding=1, bias=False),
26 | nn.BatchNorm2d(planes*2),
27 | nn.ReLU(True),
28 | nn.Conv2d(planes*2, planes*2, kernel_size=3, stride=1, padding=1, bias=False),
29 | nn.BatchNorm2d(planes*2),
30 | nn.ReLU(True),
31 | nn.Conv2d(planes*2, planes*4, kernel_size=3, stride=2, padding=1, bias=False),
32 | nn.BatchNorm2d(planes*4),
33 | nn.ReLU(True),
34 | nn.Conv2d(planes*4, planes*4, kernel_size=3, stride=1, padding=1, bias=False),
35 | nn.BatchNorm2d(planes*4),
36 | nn.ReLU(True),
37 | nn.Conv2d(planes*4, planes*8, kernel_size=3, stride=2, padding=1, bias=False),
38 | nn.BatchNorm2d(planes*8),
39 | nn.ReLU(True),
40 | )
41 |
42 | self.conv_layers = nn.ModuleList()
43 | self.norm_layers = nn.ModuleList()
44 |
45 | #========= ConvCaps Layers
46 | for d in range(1, depth):
47 | if self.mode == 'DR':
48 | self.conv_layers.append(DynamicRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=1, padding=1))
49 |                 nn.init.normal_(self.conv_layers[-1].W, 0, 0.5)  # init the layer just appended, not always the first
50 | elif self.mode == 'EM':
51 | self.conv_layers.append(EmRouting2d(num_caps, num_caps, caps_size, kernel_size=3, stride=1, padding=1))
52 | self.norm_layers.append(nn.BatchNorm2d(4*4*num_caps))
53 | elif self.mode == 'SR':
54 | self.conv_layers.append(SelfRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=1, padding=1, pose_out=True))
55 |                 self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps))  # pose has caps_size*num_caps channels
56 | else:
57 | break
58 |
59 | final_shape = 4
60 |
61 | # DR
62 | if self.mode == 'DR':
63 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
64 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
65 | self.fc = DynamicRouting2d(num_caps, classes, caps_size, caps_size, kernel_size=final_shape, padding=0)
66 | # initialize so that output logits are in reasonable range (0.1-0.9)
67 | nn.init.normal_(self.fc.W, 0, 0.1)
68 |
69 | # EM
70 | elif self.mode == 'EM':
71 | self.conv_a = nn.Conv2d(8*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False)
72 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
73 | self.bn_a = nn.BatchNorm2d(num_caps)
74 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
75 | self.fc = EmRouting2d(num_caps, classes, caps_size, kernel_size=final_shape, padding=0)
76 |
77 | # SR
78 | elif self.mode == 'SR':
79 | self.conv_a = nn.Conv2d(8*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False)
80 | self.conv_pose = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
81 | self.bn_a = nn.BatchNorm2d(num_caps)
82 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
83 | self.fc = SelfRouting2d(num_caps, classes, caps_size, 1, kernel_size=final_shape, padding=0, pose_out=False)
84 |
85 | # avg pooling
86 | elif self.mode == 'AVG':
87 | self.pool = nn.AvgPool2d(final_shape)
88 | self.fc = nn.Linear(8*planes, classes)
89 |
90 | # max pooling
91 | elif self.mode == 'MAX':
92 | self.pool = nn.MaxPool2d(final_shape)
93 | self.fc = nn.Linear(8*planes, classes)
94 |
95 | elif self.mode == 'FC':
96 | self.conv_ = nn.Conv2d(8*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
97 | self.bn_ = nn.BatchNorm2d(num_caps*caps_size)
98 |
99 | self.fc = nn.Linear(num_caps*caps_size*final_shape*final_shape, classes)
100 |
101 | self.apply(weights_init)
102 |
103 | def forward(self, x):
104 | out = self.layers(x)
105 |
106 | # DR
107 | if self.mode == 'DR':
108 | pose = self.bn_pose(self.conv_pose(out))
109 |
110 | b, c, h, w = pose.shape
111 | pose = pose.permute(0, 2, 3, 1).contiguous()
112 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size))
113 | pose = pose.view(b, h, w, -1)
114 | pose = pose.permute(0, 3, 1, 2)
115 |
116 | for m in self.conv_layers:
117 | pose = m(pose)
118 |
119 | out = self.fc(pose)
120 | out = out.view(b, -1, self.caps_size)
121 | out = out.norm(dim=-1)
122 |
123 | # EM
124 | elif self.mode == 'EM':
125 | a, pose = self.conv_a(out), self.conv_pose(out)
126 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
127 |
128 | for m, bn in zip(self.conv_layers, self.norm_layers):
129 | a, pose = m(a, pose)
130 | pose = bn(pose)
131 |
132 | a, _ = self.fc(a, pose)
133 | out = a.view(a.size(0), -1)
134 |
135 | # ours
136 | elif self.mode == 'SR':
137 | a, pose = self.conv_a(out), self.conv_pose(out)
138 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
139 |
140 | for m, bn in zip(self.conv_layers, self.norm_layers):
141 | a, pose = m(a, pose)
142 | pose = bn(pose)
143 |
144 | a, _ = self.fc(a, pose)
145 | out = a.view(a.size(0), -1)
146 | out = out.log()
147 |
148 | elif self.mode == 'AVG' or self.mode =='MAX':
149 | out = self.pool(out)
150 | out = out.view(out.size(0), -1)
151 | out = self.fc(out)
152 |
153 | elif self.mode == 'FC':
154 | out = F.relu(self.bn_(self.conv_(out)))
155 | out = out.view(out.size(0), -1)
156 | out = self.fc(out)
157 |
158 | return out
159 |
160 |
--------------------------------------------------------------------------------
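
Instantiating `ConvNet` mirrors the call in `trainer.py` (a hypothetical sketch, run from the repo root; `caps_size` is hard-coded to 16 inside `__init__`). The three stride-2 convs take 32x32 inputs to the 4x4 grid assumed by `final_shape`:

```python
import torch
from data_loader import DATASET_CONFIGS
from models.convnet import ConvNet

net = ConvNet(16, DATASET_CONFIGS['smallnorb'], num_caps=32, caps_size=16,
              depth=1, mode='SR')
out = net(torch.randn(2, 1, 32, 32))
print(out.shape)  # [2, 5]
```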
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from modules import *
6 | from utils import weights_init
7 |
8 |
9 | class LambdaLayer(nn.Module):
10 | def __init__(self, lambd):
11 | super(LambdaLayer, self).__init__()
12 | self.lambd = lambd
13 |
14 | def forward(self, x):
15 | return self.lambd(x)
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, in_planes, planes, stride=1, option='A'):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 |
27 | self.shortcut = nn.Sequential()
28 | if stride != 1 or in_planes != planes:
29 | if option == 'A':
30 | """
31 | For CIFAR10 ResNet paper uses option A.
32 | """
33 | self.shortcut = LambdaLayer(lambda x:
34 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
35 | elif option == 'B':
36 | self.shortcut = nn.Sequential(
37 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
38 | nn.BatchNorm2d(self.expansion * planes)
39 | )
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.bn2(self.conv2(out))
44 | out += self.shortcut(x)
45 | out = F.relu(out)
46 | return out
47 |
48 |
49 | class ResNet(nn.Module):
50 | def __init__(self, block, num_blocks, planes, num_caps, caps_size, depth, cfg_data, mode):
51 | super(ResNet, self).__init__()
52 | self.in_planes = planes
53 | channels, classes = cfg_data['channels'], cfg_data['classes']
54 |
55 | self.num_caps = num_caps
56 | self.caps_size = caps_size
57 |
58 | self.depth = depth
59 |
60 | self.conv1 = nn.Conv2d(channels, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(self.in_planes)
62 | self.layer1 = self._make_layer(block, planes, num_blocks[0], stride=1)
63 | self.layer2 = self._make_layer(block, 2*planes, num_blocks[1], stride=2)
64 | self.layer3 = self._make_layer(block, 4*planes, num_blocks[2], stride=2)
65 |
66 | self.mode = mode
67 |
68 | self.conv_layers = nn.ModuleList()
69 | self.norm_layers = nn.ModuleList()
70 |
71 | for d in range(1, depth):
72 | stride = 2 if d == 1 else 1
73 | if self.mode == 'DR':
74 | self.conv_layers.append(DynamicRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=stride, padding=1))
75 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps))
76 | elif self.mode == 'EM':
77 | self.conv_layers.append(EmRouting2d(num_caps, num_caps, caps_size, kernel_size=3, stride=stride, padding=1))
78 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps))
79 | elif self.mode == 'SR':
80 | self.conv_layers.append(SelfRouting2d(num_caps, num_caps, caps_size, caps_size, kernel_size=3, stride=stride, padding=1, pose_out=True))
81 | self.norm_layers.append(nn.BatchNorm2d(caps_size*num_caps))
82 | else:
83 | break
84 |
85 | final_shape = 8 if depth == 1 else 4
86 |
87 | # DR
88 | if self.mode == 'DR':
89 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
90 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
91 | self.fc = DynamicRouting2d(num_caps, classes, caps_size, caps_size, kernel_size=final_shape, padding=0)
92 |
93 | # EM
94 | elif self.mode == 'EM':
95 | self.conv_a = nn.Conv2d(4*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False)
96 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
97 | self.bn_a = nn.BatchNorm2d(num_caps)
98 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
99 | self.fc = EmRouting2d(num_caps, classes, caps_size, kernel_size=final_shape, padding=0)
100 |
101 | # SR
102 | elif self.mode == 'SR':
103 | self.conv_a = nn.Conv2d(4*planes, num_caps, kernel_size=3, stride=1, padding=1, bias=False)
104 | self.conv_pose = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1, bias=False)
105 | self.bn_a = nn.BatchNorm2d(num_caps)
106 | self.bn_pose = nn.BatchNorm2d(num_caps*caps_size)
107 | self.fc = SelfRouting2d(num_caps, classes, caps_size, 1, kernel_size=final_shape, padding=0, pose_out=False)
108 |
109 | # avg pooling
110 | elif self.mode == 'AVG':
111 | self.pool = nn.AvgPool2d(final_shape)
112 | self.fc = nn.Linear(4*planes, classes)
113 |
114 | # max pooling
115 | elif self.mode == 'MAX':
116 | self.pool = nn.MaxPool2d(final_shape)
117 | self.fc = nn.Linear(4*planes, classes)
118 |
119 | elif self.mode == 'FC':
120 |             self.conv_ = nn.Conv2d(4*planes, num_caps*caps_size, kernel_size=3, stride=1, padding=1)  # stride=1: the loop variable 'stride' is undefined when depth == 1
121 | self.bn_ = nn.BatchNorm2d(num_caps*caps_size)
122 |
123 | self.fc = nn.Linear(num_caps*caps_size*final_shape*final_shape, classes)
124 |
125 | self.apply(weights_init)
126 |
127 | def _make_layer(self, block, planes, num_blocks, stride):
128 | strides = [stride] + [1]*(num_blocks-1)
129 | layers = []
130 | for stride in strides:
131 | layers.append(block(self.in_planes, planes, stride))
132 | self.in_planes = planes * block.expansion
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | out = F.relu(self.bn1(self.conv1(x)))
138 | out = self.layer1(out)
139 | out = self.layer2(out)
140 | out = self.layer3(out)
141 |
142 | # DR
143 | if self.mode == 'DR':
144 | pose = self.bn_pose(self.conv_pose(out))
145 |
146 | b, c, h, w = pose.shape
147 | pose = pose.permute(0, 2, 3, 1).contiguous()
148 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size))
149 | pose = pose.view(b, h, w, -1)
150 | pose = pose.permute(0, 3, 1, 2)
151 |
152 | for m in self.conv_layers:
153 | pose = m(pose)
154 |
155 | out = self.fc(pose).squeeze()
156 | out = out.view(b, -1, self.caps_size)
157 |
158 | out = out.norm(dim=-1)
159 | out = out / out.sum(dim=1, keepdim=True)
160 | out = out.log()
161 |
162 | # EM
163 | elif self.mode == 'EM':
164 | a, pose = self.conv_a(out), self.conv_pose(out)
165 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
166 |
167 | for m, bn in zip(self.conv_layers, self.norm_layers):
168 | a, pose = m(a, pose)
169 | pose = bn(pose)
170 |
171 | a, _ = self.fc(a, pose)
172 | out = a.view(a.size(0), -1)
173 | out = out / out.sum(dim=1, keepdim=True)
174 | out = out.log()
175 |
176 | # ours
177 |         elif self.mode == 'SR':
178 | a, pose = self.conv_a(out), self.conv_pose(out)
179 | a, pose = torch.sigmoid(self.bn_a(a)), self.bn_pose(pose)
180 |
181 | for m, bn in zip(self.conv_layers, self.norm_layers):
182 | a, pose = m(a, pose)
183 | pose = bn(pose)
184 |
185 | a, _ = self.fc(a, pose)
186 | out = a.view(a.size(0), -1)
187 | out = out.log()
188 |
189 | elif self.mode == 'AVG' or self.mode =='MAX':
190 | out = self.pool(out)
191 | out = out.view(out.size(0), -1)
192 | out = self.fc(out)
193 |
194 | elif self.mode == 'FC':
195 | out = F.relu(self.bn_(self.conv_(out)))
196 | out = out.view(out.size(0), -1)
197 | out = self.fc(out)
198 |
199 | return out
200 |
201 | def forward_activations(self, x):
202 | out = F.relu(self.bn1(self.conv1(x)))
203 | out = self.layer1(out)
204 | out = self.layer2(out)
205 | out = self.layer3(out)
206 |
207 | if self.mode == 'DR':
208 | pose = self.bn_pose(self.conv_pose(out))
209 |
210 | b, c, h, w = pose.shape
211 | pose = pose.permute(0, 2, 3, 1).contiguous()
212 | pose = squash(pose.view(b, h, w, self.num_caps, self.caps_size))
213 | pose = pose.view(b, h, w, -1)
214 | pose = pose.permute(0, 3, 1, 2)
215 | a = pose.norm(dim=1)
216 |
217 | elif self.mode == 'EM':
218 | a = torch.sigmoid(self.bn_a(self.conv_a(out)))
219 |
220 | elif self.mode == 'SR':
221 | a = torch.sigmoid(self.bn_a(self.conv_a(out)))
222 |
223 | else:
224 | raise NotImplementedError
225 |
226 | return a
227 |
228 |
229 | def resnet20(planes, cfg_data, num_caps, caps_size, depth, mode):
230 | return ResNet(BasicBlock, [3, 3, 3], planes, num_caps, caps_size, depth, cfg_data, mode)
231 |
232 | def resnet32(planes, cfg_data, num_caps, caps_size, depth, mode):
233 | return ResNet(BasicBlock, [5, 5, 5], planes, num_caps, caps_size, depth, cfg_data, mode)
234 |
235 | def resnet44(planes, cfg_data, num_caps, caps_size, depth, mode):
236 | return ResNet(BasicBlock, [7, 7, 7], planes, num_caps, caps_size, depth, cfg_data, mode)
237 |
238 | def resnet56(planes, cfg_data, num_caps, caps_size, depth, mode):
239 | return ResNet(BasicBlock, [9, 9, 9], planes, num_caps, caps_size, depth, cfg_data, mode)
240 |
241 | def resnet110(planes, cfg_data, num_caps, caps_size, depth, mode):
242 | return ResNet(BasicBlock, [18, 18, 18], planes, num_caps, caps_size, depth, cfg_data, mode)
243 |
244 |
245 |
--------------------------------------------------------------------------------
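
The factory functions above mirror the CIFAR ResNet family (20/32/44/56/110 layers). A forward-pass sketch (hypothetical, run from the repo root):

```python
import torch
from data_loader import DATASET_CONFIGS
from models.resnet import resnet20

net = resnet20(16, DATASET_CONFIGS['cifar10'], num_caps=32, caps_size=16,
               depth=1, mode='SR')
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # [2, 10] log-activations
```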
/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 |
7 | from utils import squash
8 |
9 |
10 | eps = 1e-12
11 |
12 |
13 | class DynamicRouting2d(nn.Module):
14 | def __init__(self, A, B, C, D, kernel_size=1, stride=1, padding=1, iters=3):
15 | super(DynamicRouting2d, self).__init__()
16 | self.A = A
17 | self.B = B
18 | self.C = C
19 | self.D = D
20 |
21 | self.k = kernel_size
22 | self.kk = kernel_size ** 2
23 | self.kkA = self.kk * A
24 |
25 | self.stride = stride
26 | self.pad = padding
27 |
28 | self.iters = iters
29 | self.W = nn.Parameter(torch.FloatTensor(self.kkA, B*D, C))
30 | nn.init.kaiming_uniform_(self.W)
31 |
32 | def forward(self, pose):
33 | # x: [b, AC, h, w]
34 | b, _, h, w = pose.shape
35 | # [b, ACkk, l]
36 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad)
37 | l = pose.shape[-1]
38 | # [b, A, C, kk, l]
39 | pose = pose.view(b, self.A, self.C, self.kk, l)
40 | # [b, l, kk, A, C]
41 | pose = pose.permute(0, 4, 3, 1, 2).contiguous()
42 | # [b, l, kkA, C, 1]
43 | pose = pose.view(b, l, self.kkA, self.C, 1)
44 |
45 | # [b, l, kkA, BD]
46 | pose_out = torch.matmul(self.W, pose).squeeze(-1)
47 | # [b, l, kkA, B, D]
48 | pose_out = pose_out.view(b, l, self.kkA, self.B, self.D)
49 |
50 |         # routing logits (named to avoid shadowing the batch size b): [b, l, kkA, B, 1]
51 |         logits = pose.new_zeros(b, l, self.kkA, self.B, 1)
52 |         for i in range(self.iters):
53 |             c = torch.softmax(logits, dim=3)
54 | 
55 |             # [b, l, 1, B, D]
56 |             s = (c * pose_out).sum(dim=2, keepdim=True)
57 |             # [b, l, 1, B, D]
58 |             v = squash(s)
59 | 
60 |             logits = logits + (v * pose_out).sum(dim=-1, keepdim=True)
61 |
62 | # [b, l, B, D]
63 | v = v.squeeze(2)
64 | # [b, l, BD]
65 | v = v.view(v.shape[0], l, -1)
66 | # [b, BD, l]
67 | v = v.transpose(1,2).contiguous()
68 |
69 | oh = ow = math.floor(l**(1/2))
70 |
71 | # [b, BD, oh, ow]
72 | return v.view(v.shape[0], -1, oh, ow)
73 |
74 |
75 | class EmRouting2d(nn.Module):
76 | def __init__(self, A, B, caps_size, kernel_size=3, stride=1, padding=1, iters=3, final_lambda=1e-2):
77 | super(EmRouting2d, self).__init__()
78 | self.A = A
79 | self.B = B
80 | self.psize = caps_size
81 | self.mat_dim = int(caps_size ** 0.5)
82 |
83 | self.k = kernel_size
84 | self.kk = kernel_size ** 2
85 | self.kkA = self.kk * A
86 |
87 | self.stride = stride
88 | self.pad = padding
89 |
90 | self.iters = iters
91 |
92 | self.W = nn.Parameter(torch.FloatTensor(self.kkA, B, self.mat_dim, self.mat_dim))
93 | nn.init.kaiming_uniform_(self.W.data)
94 |
95 | self.beta_u = nn.Parameter(torch.FloatTensor(1, 1, B, 1))
96 | self.beta_a = nn.Parameter(torch.FloatTensor(1, 1, B))
97 | nn.init.constant_(self.beta_u, 0)
98 | nn.init.constant_(self.beta_a, 0)
99 |
100 | self.final_lambda = final_lambda
101 | self.ln_2pi = math.log(2*math.pi)
102 |
103 | def m_step(self, v, a_in, r):
104 | # v: [b, l, kkA, B, psize]
105 | # a_in: [b, l, kkA]
106 | # r: [b, l, kkA, B, 1]
107 | b, l, _, _, _ = v.shape
108 |
109 | # r: [b, l, kkA, B, 1]
110 | r = r * a_in.view(b, l, -1, 1, 1)
111 | # r_sum: [b, l, 1, B, 1]
112 | r_sum = r.sum(dim=2, keepdim=True)
113 | # coeff: [b, l, kkA, B, 1]
114 | coeff = r / (r_sum + eps)
115 |
116 | # mu: [b, l, 1, B, psize]
117 | mu = torch.sum(coeff * v, dim=2, keepdim=True)
118 | # sigma_sq: [b, l, 1, B, psize]
119 | sigma_sq = torch.sum(coeff * (v - mu)**2, dim=2, keepdim=True) + eps
120 |
121 | # [b, l, B, 1]
122 | r_sum = r_sum.squeeze(2)
123 | # [b, l, B, psize]
124 | sigma_sq = sigma_sq.squeeze(2)
125 | # [1, 1, B, 1] + [b, l, B, psize] * [b, l, B, 1]
126 | cost_h = (self.beta_u + torch.log(sigma_sq.sqrt())) * r_sum
127 | # cost_h = (torch.log(sigma_sq.sqrt())) * r_sum
128 |
129 | # [b, l, B]
130 | a_out = torch.sigmoid(self.lambda_*(self.beta_a - cost_h.sum(dim=3)))
131 | # a_out = torch.sigmoid(self.lambda_*(-cost_h.sum(dim=3)))
132 |
133 | return a_out, mu, sigma_sq
134 |
135 | def e_step(self, v, a_out, mu, sigma_sq):
136 | b, l, _ = a_out.shape
137 | # v: [b, l, kkA, B, psize]
138 | # a_out: [b, l, B]
139 | # mu: [b, l, 1, B, psize]
140 | # sigma_sq: [b, l, B, psize]
141 |
142 | # [b, l, 1, B, psize]
143 | sigma_sq = sigma_sq.unsqueeze(2)
144 |
145 |         ln_p_j = -0.5 * torch.sum(self.ln_2pi + torch.log(sigma_sq), dim=-1) \
146 |             - torch.sum((v - mu)**2 / (2 * sigma_sq), dim=-1)  # Gaussian log-density: -0.5*sum(ln 2pi + ln sigma^2) - sum((v-mu)^2 / 2sigma^2)
147 |
148 | # [b, l, kkA, B]
149 | ln_ap = ln_p_j + torch.log(a_out.view(b, l, 1, self.B))
150 | # [b, l, kkA, B]
151 | r = torch.softmax(ln_ap, dim=-1)
152 | # [b, l, kkA, B, 1]
153 | return r.unsqueeze(-1)
154 |
155 | def forward(self, a_in, pose):
156 | # pose: [batch_size, A, psize]
157 | # a: [batch_size, A]
158 | batch_size = a_in.shape[0]
159 |
160 | # a: [b, A, h, w]
161 | # pose: [b, A*psize, h, w]
162 | b, _, h, w = a_in.shape
163 |
164 | # [b, A*psize*kk, l]
165 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad)
166 | l = pose.shape[-1]
167 | # [b, A, psize, kk, l]
168 | pose = pose.view(b, self.A, self.psize, self.kk, l)
169 | # [b, l, kk, A, psize]
170 | pose = pose.permute(0, 4, 3, 1, 2).contiguous()
171 | # [b, l, kkA, psize]
172 | pose = pose.view(b, l, self.kkA, self.psize)
173 | # [b, l, kkA, 1, mat_dim, mat_dim]
174 | pose = pose.view(batch_size, l, self.kkA, self.mat_dim, self.mat_dim).unsqueeze(3)
175 |
176 | # [b, l, kkA, B, mat_dim, mat_dim]
177 | pose_out = torch.matmul(pose, self.W)
178 |
179 | # [b, l, kkA, B, psize]
180 | v = pose_out.view(batch_size, l, self.kkA, self.B, -1)
181 |
182 | # [b, kkA, l]
183 | a_in = F.unfold(a_in, self.k, stride=self.stride, padding=self.pad)
184 | # [b, A, kk, l]
185 | a_in = a_in.view(b, self.A, self.kk, l)
186 | # [b, l, kk, A]
187 | a_in = a_in.permute(0, 3, 2, 1).contiguous()
188 | # [b, l, kkA]
189 | a_in = a_in.view(b, l, self.kkA)
190 |
191 | r = a_in.new_ones(batch_size, l, self.kkA, self.B, 1)
192 | for i in range(self.iters):
193 | # this is from open review
194 | self.lambda_ = self.final_lambda * (1 - 0.95 ** (i+1))
195 | a_out, pose_out, sigma_sq = self.m_step(v, a_in, r)
196 | if i < self.iters - 1:
197 | r = self.e_step(v, a_out, pose_out, sigma_sq)
198 |
199 | # [b, l, B*psize]
200 | pose_out = pose_out.squeeze(2).view(b, l, -1)
201 | # [b, B*psize, l]
202 | pose_out = pose_out.transpose(1, 2)
203 | # [b, B, l]
204 | a_out = a_out.transpose(1, 2).contiguous()
205 |
206 | oh = ow = math.floor(l**(1/2))
207 |
208 | a_out = a_out.view(b, -1, oh, ow)
209 | pose_out = pose_out.view(b, -1, oh, ow)
210 |
211 | return a_out, pose_out
212 |
213 |
214 | class SelfRouting2d(nn.Module):
215 | def __init__(self, A, B, C, D, kernel_size=3, stride=1, padding=1, pose_out=False):
216 | super(SelfRouting2d, self).__init__()
217 | self.A = A
218 | self.B = B
219 | self.C = C
220 | self.D = D
221 |
222 | self.k = kernel_size
223 | self.kk = kernel_size ** 2
224 | self.kkA = self.kk * A
225 |
226 | self.stride = stride
227 | self.pad = padding
228 |
229 | self.pose_out = pose_out
230 |
231 | if pose_out:
232 | self.W1 = nn.Parameter(torch.FloatTensor(self.kkA, B*D, C))
233 | nn.init.kaiming_uniform_(self.W1.data)
234 |
235 | self.W2 = nn.Parameter(torch.FloatTensor(self.kkA, B, C))
236 | self.b2 = nn.Parameter(torch.FloatTensor(1, 1, self.kkA, B))
237 |
238 | nn.init.constant_(self.W2.data, 0)
239 | nn.init.constant_(self.b2.data, 0)
240 |
241 | def forward(self, a, pose):
242 | # a: [b, A, h, w]
243 | # pose: [b, AC, h, w]
244 | b, _, h, w = a.shape
245 |
246 | # [b, ACkk, l]
247 | pose = F.unfold(pose, self.k, stride=self.stride, padding=self.pad)
248 | l = pose.shape[-1]
249 | # [b, A, C, kk, l]
250 | pose = pose.view(b, self.A, self.C, self.kk, l)
251 | # [b, l, kk, A, C]
252 | pose = pose.permute(0, 4, 3, 1, 2).contiguous()
253 | # [b, l, kkA, C, 1]
254 | pose = pose.view(b, l, self.kkA, self.C, 1)
255 |
256 | if hasattr(self, 'W1'):
257 | # [b, l, kkA, BD]
258 | pose_out = torch.matmul(self.W1, pose).squeeze(-1)
259 | # [b, l, kkA, B, D]
260 | pose_out = pose_out.view(b, l, self.kkA, self.B, self.D)
261 |
262 | # [b, l, kkA, B]
263 | logit = torch.matmul(self.W2, pose).squeeze(-1) + self.b2
264 |
265 | # [b, l, kkA, B]
266 | r = torch.softmax(logit, dim=3)
267 |
268 | # [b, kkA, l]
269 | a = F.unfold(a, self.k, stride=self.stride, padding=self.pad)
270 | # [b, A, kk, l]
271 | a = a.view(b, self.A, self.kk, l)
272 | # [b, l, kk, A]
273 | a = a.permute(0, 3, 2, 1).contiguous()
274 | # [b, l, kkA, 1]
275 | a = a.view(b, l, self.kkA, 1)
276 |
277 | # [b, l, kkA, B]
278 | ar = a * r
279 | # [b, l, 1, B]
280 | ar_sum = ar.sum(dim=2, keepdim=True)
281 | # [b, l, kkA, B, 1]
282 |         coeff = (ar / (ar_sum + eps)).unsqueeze(-1)  # eps guards against all-zero activations
283 |
284 | # [b, l, B]
285 | # a_out = ar_sum.squeeze(2)
286 | a_out = ar_sum / a.sum(dim=2, keepdim=True)
287 | a_out = a_out.squeeze(2)
288 |
289 | # [b, B, l]
290 | a_out = a_out.transpose(1,2)
291 |
292 | if hasattr(self, 'W1'):
293 | # [b, l, B, D]
294 | pose_out = (coeff * pose_out).sum(dim=2)
295 | # [b, l, BD]
296 | pose_out = pose_out.view(b, l, -1)
297 | # [b, BD, l]
298 | pose_out = pose_out.transpose(1,2)
299 |
300 | oh = ow = math.floor(l**(1/2))
301 |
302 | a_out = a_out.view(b, -1, oh, ow)
303 | if hasattr(self, 'W1'):
304 | pose_out = pose_out.view(b, -1, oh, ow)
305 | else:
306 | pose_out = None
307 |
308 | return a_out, pose_out
309 |
310 |
311 |
--------------------------------------------------------------------------------
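
Shape conventions for the routing layers above: activations are `[b, A, h, w]`, poses `[b, A*C, h, w]`, and the output grid shrinks like a normal convolution. A quick check for `SelfRouting2d` (hypothetical):

```python
import torch
from modules import SelfRouting2d

m = SelfRouting2d(A=8, B=4, C=16, D=16, kernel_size=3, stride=2, padding=1,
                  pose_out=True)
a = torch.rand(2, 8, 16, 16)            # input activations in [0, 1]
pose = torch.randn(2, 8 * 16, 16, 16)   # A*C pose channels
a_out, pose_out = m(a, pose)
print(a_out.shape, pose_out.shape)      # [2, 4, 8, 8] [2, 64, 8, 8]
```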
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 | import os
6 | import time
7 | import shutil
8 | import math
9 |
10 | from tqdm import tqdm
11 | from utils import AverageMeter, save_config
12 | from tensorboardX import SummaryWriter
13 |
14 | from models import *
15 | from loss import *
16 | from data_loader import DATASET_CONFIGS
17 |
18 | from attack import Attack, extract_adv_images
19 |
20 |
21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22 |
23 |
24 | class Trainer(object):
25 | """
26 | Trainer encapsulates all the logic necessary for
27 | training.
28 |
29 | All hyperparameters are provided by the user in the
30 | config file.
31 | """
32 | def __init__(self, config, data_loader):
33 | """
34 | Construct a new Trainer instance.
35 |
36 | Args
37 | ----
38 | - config: object containing command line arguments.
39 | - data_loader: data iterator
40 | """
41 | self.config = config
42 |
43 | # data params
44 | if config.is_train:
45 | self.train_loader = data_loader[0]
46 | self.valid_loader = data_loader[1]
47 | self.num_train = len(self.train_loader.dataset)
48 | self.num_valid = len(self.valid_loader.dataset)
49 | else:
50 | self.test_loader = data_loader
51 | self.num_test = len(self.test_loader.dataset)
52 |
53 | # training params
54 | self.epochs = config.epochs
55 | self.start_epoch = 0
56 | self.momentum = config.momentum
57 | self.weight_decay = config.weight_decay
58 | self.lr = config.init_lr
59 |
60 | # misc params
61 | self.best = config.best
62 | self.ckpt_dir = config.ckpt_dir
63 | self.logs_dir = config.logs_dir
64 | self.best_valid_acc = 0.
65 | self.counter = 0
66 | self.train_patience = config.train_patience
67 | self.use_tensorboard = config.use_tensorboard
68 | self.resume = config.resume
69 | self.print_freq = config.print_freq
70 |
71 | self.attack_type = config.attack_type
72 | self.attack_eps = config.attack_eps
73 | self.targeted = config.targeted
74 |
75 | self.name = config.name
76 |
77 | if config.name.endswith('dynamic_routing'):
78 | self.mode = 'DR'
79 | elif config.name.endswith('em_routing'):
80 | self.mode = 'EM'
81 | elif config.name.endswith('self_routing'):
82 | self.mode = 'SR'
83 | elif config.name.endswith('max'):
84 | self.mode = 'MAX'
85 | elif config.name.endswith('avg'):
86 | self.mode = 'AVG'
87 | elif config.name.endswith('fc'):
88 | self.mode = 'FC'
89 | else:
90 | raise NotImplementedError("Unknown model postfix")
91 |
92 | # initialize
93 | if config.name.startswith('resnet'):
94 | self.model = resnet20(config.planes, DATASET_CONFIGS[config.dataset], config.num_caps, config.caps_size, config.depth, mode=self.mode).to(device)
95 | elif config.name.startswith('convnet'):
96 | self.model = ConvNet(config.planes, DATASET_CONFIGS[config.dataset], config.num_caps, config.caps_size, config.depth, mode=self.mode).to(device)
97 | elif config.name.startswith('smallnet'):
98 | assert self.mode in ['SR', 'DR', 'EM']
99 | self.model = SmallNet(DATASET_CONFIGS[config.dataset], mode=self.mode).to(device)
100 | else:
101 | raise NotImplementedError("Unknown model prefix")
102 |
103 | if torch.cuda.device_count() > 1:
104 | print("Let's use", torch.cuda.device_count(), "GPUs!")
105 | self.model = nn.DataParallel(self.model)
106 |
107 | self.loss = nn.CrossEntropyLoss().to(device)
108 | if self.mode in ['DR', 'EM', 'SR']:
109 | if config.dataset in ['cifar10', 'svhn']:
110 | print("using NLL loss")
111 | self.loss = nn.NLLLoss().to(device)
112 | elif config.dataset == "smallnorb":
113 | if self.mode == 'DR':
114 | print("using DR loss")
115 | self.loss = DynamicRoutingLoss().to(device)
116 | elif self.mode == 'EM':
117 | print("using EM loss")
118 | self.loss = EmRoutingLoss(self.epochs).to(device)
119 | elif self.mode == 'SR':
120 | print("using NLL loss")
121 | self.loss = nn.NLLLoss().to(device)
122 |
123 | self.params = self.model.parameters()
124 | self.optimizer = optim.SGD(self.params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
125 |
126 | if config.dataset == "cifar10":
127 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[150, 250], gamma=0.1)
128 | elif config.dataset == "svhn":
129 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[100, 150], gamma=0.1)
130 | elif config.dataset == "smallnorb":
131 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[100, 150], gamma=0.1)
132 |
133 | # save config as json
134 | save_config(self.name, self.config)
135 |
136 | # configure tensorboard logging
137 | if self.use_tensorboard:
138 | tensorboard_dir = self.logs_dir + self.name
139 | print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
140 | if not os.path.exists(tensorboard_dir):
141 | os.makedirs(tensorboard_dir)
142 | self.writer = SummaryWriter(tensorboard_dir)
143 |
144 | print('[*] Number of model parameters: {:,}'.format(
145 | sum([p.data.nelement() for p in self.model.parameters()])))
146 |
147 | def train(self):
148 | """
149 | Train the model on the training set.
150 |
151 | A checkpoint of the model is saved after each epoch
152 | and if the validation accuracy is improved upon,
153 | a separate ckpt is created for use on the test set.
154 | """
155 | # load the most recent checkpoint
156 | if self.resume:
157 | self.load_checkpoint(best=False)
158 |
159 | print("\n[*] Train on {} samples, validate on {} samples".format(
160 | self.num_train, self.num_valid)
161 | )
162 |
163 | for epoch in range(self.start_epoch, self.epochs):
164 | # get current lr
165 | for i, param_group in enumerate(self.optimizer.param_groups):
166 | lr = float(param_group['lr'])
167 | break
168 |
169 | print(
170 | '\nEpoch: {}/{} - LR: {:.1e}'.format(epoch+1, self.epochs, lr)
171 | )
172 |
173 | # train for 1 epoch
174 | train_loss, train_acc = self.train_one_epoch(epoch)
175 |
176 | # evaluate on validation set
177 | with torch.no_grad():
178 | valid_loss, valid_acc = self.validate(epoch)
179 |
180 |
181 | msg1 = "train loss: {:.3f} - train acc: {:.3f}"
182 | msg2 = " - val loss: {:.3f} - val acc: {:.3f}"
183 |
184 | is_best = valid_acc > self.best_valid_acc
185 | if is_best:
186 | self.counter = 0
187 | msg2 += " [*]"
188 |
189 | msg = msg1 + msg2
190 | print(msg.format(train_loss, train_acc, valid_loss, valid_acc))
191 |
192 | # check for improvement
193 | if not is_best:
194 | self.counter += 1
195 | '''
196 | if self.counter > self.train_patience:
197 | print("[!] No improvement in a while, stopping training.")
198 | return
199 | '''
200 |
201 | # decay lr
202 | self.scheduler.step()
203 |
204 | self.best_valid_acc = max(valid_acc, self.best_valid_acc)
205 | self.save_checkpoint(
206 | {'epoch': epoch + 1,
207 | 'model_state': self.model.state_dict(),
208 | 'optim_state': self.optimizer.state_dict(),
209 | 'scheduler_state': self.scheduler.state_dict(),
210 | 'best_valid_acc': self.best_valid_acc
211 | }, is_best
212 | )
213 |
214 | if self.use_tensorboard:
215 | self.writer.close()
216 |
217 | print(self.best_valid_acc)
218 |
219 | def train_one_epoch(self, epoch):
220 | """
221 | Train the model for 1 epoch of the training set.
222 |
223 | An epoch corresponds to one full pass through the entire
224 | training set in successive mini-batches.
225 |
226 | This is used by train() and should not be called manually.
227 | """
228 | self.model.train()
229 |
230 | losses = AverageMeter()
231 | accs = AverageMeter()
232 |
233 | tic = time.time()
234 | with tqdm(total=self.num_train) as pbar:
235 | for i, (x, y) in enumerate(self.train_loader):
236 | x, y = x.to(device), y.to(device)
237 |
238 | b = x.shape[0]
239 | out = self.model(x)
240 | if isinstance(self.loss, EmRoutingLoss):
241 | loss = self.loss(out, y, epoch=epoch)
242 | else:
243 | loss = self.loss(out, y)
244 |
245 | # compute accuracy
246 | pred = torch.max(out, 1)[1]
247 | correct = (pred == y).float()
248 | acc = 100 * (correct.sum() / len(y))
249 |
250 | # store
251 | losses.update(loss.data.item(), x.size()[0])
252 | accs.update(acc.data.item(), x.size()[0])
253 |
254 | # compute gradients and update SGD
255 | self.optimizer.zero_grad()
256 | loss.backward()
257 | self.optimizer.step()
258 |
259 | # measure elapsed time
260 | toc = time.time()
261 | pbar.set_description(
262 | (
263 | "{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
264 | (toc-tic), loss.data.item(), acc.data.item()
265 | )
266 | )
267 | )
268 | pbar.update(b)
269 |
270 | if self.use_tensorboard:
271 | iteration = epoch*len(self.train_loader) + i
272 | self.writer.add_scalar('train_loss', loss.item(), iteration)
273 | self.writer.add_scalar('train_acc', acc.item(), iteration)
274 |
275 | return losses.avg, accs.avg
276 |
277 | def validate(self, epoch):
278 | """
279 | Evaluate the model on the validation set.
280 | """
281 | self.model.eval()
282 |
283 | losses = AverageMeter()
284 | accs = AverageMeter()
285 |
286 | for i, (x, y) in enumerate(self.valid_loader):
287 | x, y = x.to(device), y.to(device)
288 |
289 | out = self.model(x)
290 | if isinstance(self.loss, EmRoutingLoss):
291 | loss = self.loss(out, y, epoch=epoch)
292 | else:
293 | loss = self.loss(out, y)
294 |
295 | # compute accuracy
296 | pred = torch.max(out, 1)[1]
297 | correct = (pred == y).float()
298 | acc = 100 * (correct.sum() / len(y))
299 |
300 | # store
301 | losses.update(loss.data.item(), x.size()[0])
302 | accs.update(acc.data.item(), x.size()[0])
303 |
304 | # log to tensorboard
305 | if self.use_tensorboard:
306 | self.writer.add_scalar('valid_loss', losses.avg, epoch)
307 | self.writer.add_scalar('valid_acc', accs.avg, epoch)
308 |
309 | return losses.avg, accs.avg
310 |
311 | def test(self):
312 | """
313 | Test the model on the held-out test data.
314 | This function should only be called at the very
315 | end once the model has finished training.
316 | """
317 | correct = 0
318 |
319 | # load the best checkpoint
320 | self.load_checkpoint(best=self.best)
321 | self.model.eval()
322 |
323 | for i, (x, y) in enumerate(self.test_loader):
324 | x, y = x.to(device), y.to(device)
325 |
326 | out = self.model(x)
327 |
328 | # compute accuracy
329 | pred = torch.max(out, 1)[1]
330 | correct += pred.eq(y.data.view_as(pred)).cpu().sum()
331 |
332 | perc = (100. * correct.data.item()) / (self.num_test)
333 | error = 100 - perc
334 | print(
335 | '[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
336 | correct, self.num_test, perc, error)
337 | )
338 |
339 | def test_attack(self):
340 | correct = 0
341 | self.load_checkpoint(best=self.best)
342 | self.model.eval()
343 |
344 | # prepare adv attack
345 | attacker = Attack(self.model, self.loss, self.attack_type, self.attack_eps)
346 | adv_data, num_examples = extract_adv_images(attacker, self.test_loader, self.targeted, DATASET_CONFIGS[self.config.dataset]["classes"])
347 |
348 | with torch.no_grad():
349 | for i, (x, y) in enumerate(adv_data):
350 | x, y = x.to(device), y.to(device)
351 |
352 | out = self.model(x)
353 |
354 | # compute accuracy
355 | pred = torch.max(out, 1)[1]
356 | correct += pred.eq(y.data.view_as(pred)).cpu().sum()
357 |
358 | if self.targeted:
359 | success = correct
360 | else:
361 | success = num_examples - correct
362 |
363 | perc = (100. * success.data.item()) / (num_examples)
364 |
365 | print(
366 | '[*] Attack success rate ({}, targeted={}, eps={}): {}/{} ({:.2f}% - {:.2f}%)'.format(
367 | self.attack_type, self.targeted, self.attack_eps, success, num_examples, perc, 100. - perc)
368 | )
369 |
370 | def save_checkpoint(self, state, is_best):
371 | """
372 | Save a copy of the model so that it can be loaded at a future
373 | date. This function is used when the model is being evaluated
374 | on the test data.
375 |
376 | If this model has reached the best validation accuracy thus
377 | far, a separate file with the suffix `best` is created.
378 | """
379 | # print("[*] Saving model to {}".format(self.ckpt_dir))
380 |
381 | filename = self.name + '_ckpt.pth.tar'
382 | ckpt_path = os.path.join(self.ckpt_dir, filename)
383 | torch.save(state, ckpt_path)
384 |
385 | if is_best:
386 | filename = self.name + '_model_best.pth.tar'
387 | shutil.copyfile(
388 | ckpt_path, os.path.join(self.ckpt_dir, filename)
389 | )
390 |
391 | def load_checkpoint(self, best=False):
392 | """
393 | Load the best copy of a model. This is useful for 2 cases:
394 |
395 | - Resuming training with the most recent model checkpoint.
396 | - Loading the best validation model to evaluate on the test data.
397 |
398 | Params
399 | ------
400 | - best: if set to True, loads the best model. Use this if you want
401 | to evaluate your model on the test data. Else, set to False in
402 | which case the most recent version of the checkpoint is used.
403 | """
404 | print("[*] Loading model from {}".format(self.ckpt_dir))
405 |
406 | filename = self.name + '_ckpt.pth.tar'
407 | if best:
408 | filename = self.name + '_model_best.pth.tar'
409 | ckpt_path = os.path.join(self.ckpt_dir, filename)
410 | ckpt = torch.load(ckpt_path)
411 |
412 | # load variables from checkpoint
413 | self.start_epoch = ckpt['epoch']
414 | self.best_valid_acc = ckpt['best_valid_acc']
415 | self.model.load_state_dict(ckpt['model_state'])
416 | self.optimizer.load_state_dict(ckpt['optim_state'])
417 | self.scheduler.load_state_dict(ckpt['scheduler_state'])
418 |
419 | if best:
420 | print(
421 | "[*] Loaded {} checkpoint @ epoch {} "
422 | "with best valid acc of {:.3f}".format(
423 | filename, ckpt['epoch'], ckpt['best_valid_acc'])
424 | )
425 | else:
426 | print(
427 | "[*] Loaded {} checkpoint @ epoch {}".format(
428 | filename, ckpt['epoch'])
429 | )
430 |
431 |
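432 | # ------------------------------------------------------------------
433 | # Usage sketch (illustrative; the constructor signature is assumed
434 | # from the attributes used above, not shown in this excerpt):
435 | #
436 | #   trainer = Trainer(config, data_loader)
437 | #   trainer.train()         # checkpoints every epoch; best ckpt kept separately
438 | #   trainer.test()          # reloads the best checkpoint, reports test accuracy
439 | #   trainer.test_attack()   # adversarial success rate via attack.py's Attack
440 | # ------------------------------------------------------------------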
--------------------------------------------------------------------------------
/norb.py:
--------------------------------------------------------------------------------
1 | # Loader taken from https://github.com/mavanb/vision/blob/448fac0f38cab35a387666d553b9d5e4eec4c5e6/torchvision/datasets/utils.py
2 |
3 | from __future__ import print_function
4 | import os
5 | import errno
6 | import struct
7 |
8 | import torch
9 | import torch.utils.data as data
10 | import numpy as np
11 | from PIL import Image
12 | from torchvision.datasets.utils import download_url, check_integrity
13 |
14 |
15 | class smallNORB(data.Dataset):
16 | """`MNIST `_ Dataset.
17 | Args:
18 | root (string): Root directory of the dataset, where the processed
19 | and raw folders exist.
20 | train (bool, optional): If True, creates dataset from the training files,
21 | otherwise from the test files.
22 | download (bool, optional): If True, downloads the dataset from the internet
23 | and puts it in the root directory. If the dataset is already processed,
24 | it is neither processed nor downloaded again; if it has only been
25 | downloaded, it is not downloaded again.
26 | transform (callable, optional): A function/transform that takes in a PIL image
27 | and returns a transformed version, e.g. ``transforms.RandomCrop``.
28 | target_transform (callable, optional): A function/transform that takes in the
29 | target and transforms it.
30 | info_transform (callable, optional): A function/transform that takes in the
31 | info and transforms it.
32 | mode (string, optional): Denotes how the images in the data files are returned. Possible values:
33 | - all (default): both left and right are included separately.
34 | - stereo: left and right images are included as corresponding pairs.
35 | - left: only the left images are included.
36 | - right: only the right images are included.
37 | """
38 |
39 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
40 | data_files = {
41 | 'train': {
42 | 'dat': {
43 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
44 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
45 | "md5": "8138a0902307b32dfa0025a36dfa45ec"
46 | },
47 | 'info': {
48 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
49 | "md5_gz": "51dee1210a742582ff607dfd94e332e3",
50 | "md5": "19faee774120001fc7e17980d6960451"
51 | },
52 | 'cat': {
53 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
54 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
55 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
56 | },
57 | },
58 | 'test': {
59 | 'dat': {
60 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
61 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
62 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
63 | },
64 | 'info': {
65 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
66 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
67 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
68 | },
69 | 'cat': {
70 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
71 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
72 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
73 | },
74 | },
75 | }
76 |
77 | raw_folder = 'raw'
78 | processed_folder = 'processed'
79 | train_image_file = 'train_img'
80 | train_label_file = 'train_label'
81 | train_info_file = 'train_info'
82 | test_image_file = 'test_img'
83 | test_label_file = 'test_label'
84 | test_info_file = 'test_info'
85 | extension = '.pt'
86 |
87 | def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False,
88 | mode="all"):
89 |
90 | self.root = os.path.expanduser(root)
91 | self.transform = transform
92 | self.target_transform = target_transform
93 | self.info_transform = info_transform
94 | self.train = train # training set or test set
95 | self.mode = mode
96 |
97 | if download:
98 | self.download()
99 |
100 | if not self._check_exists():
101 | raise RuntimeError('Dataset not found or corrupted.' +
102 | ' You can use download=True to download it')
103 |
104 | # load test or train set
105 | image_file = self.train_image_file if self.train else self.test_image_file
106 | label_file = self.train_label_file if self.train else self.test_label_file
107 | info_file = self.train_info_file if self.train else self.test_info_file
108 |
109 | # load labels
110 | self.labels = self._load(label_file)
111 |
112 | # load info files
113 | self.infos = self._load(info_file)
114 |
115 | # load left set
116 | if self.mode == "left":
117 | self.data = self._load("{}_left".format(image_file))
118 |
119 | # load right set
120 | elif self.mode == "right":
121 | self.data = self._load("{}_right".format(image_file))
122 |
123 | elif self.mode == "all" or self.mode == "stereo":
124 | left_data = self._load("{}_left".format(image_file))
125 | right_data = self._load("{}_right".format(image_file))
126 |
127 | # load stereo
128 | if self.mode == "stereo":
129 | self.data = torch.stack((left_data, right_data), dim=1)
130 |
131 | # load all
132 | else:
133 | self.data = torch.cat((left_data, right_data), dim=0)
134 |
135 | def __getitem__(self, index):
136 | """
137 | Args:
138 | index (int): Index
139 | Returns:
140 | mode ``all``, ``left``, ``right``:
141 | tuple: (image, target)
142 | mode ``stereo``:
143 | tuple: (image left, image right, target, info)
144 | """
145 | target = self.labels[index % 24300] if self.mode == "all" else self.labels[index]  # 24300 images per eye; right-eye copies share labels
146 | if self.target_transform is not None:
147 | target = self.target_transform(target)
148 |
149 | info = self.infos[index % 24300] if self.mode == "all" else self.infos[index]
150 | if self.info_transform is not None:
151 | info = self.info_transform(info)
152 |
153 | if self.mode == "stereo":
154 | img_left = self._transform(self.data[index, 0])
155 | img_right = self._transform(self.data[index, 1])
156 | return img_left, img_right, target, info
157 |
158 | img = self._transform(self.data[index])
159 | return img, target
160 |
161 | def __len__(self):
162 | return len(self.data)
163 |
164 | def _transform(self, img):
165 | # convert to a PIL Image so the return type is consistent
166 | # with the other datasets
167 | img = Image.fromarray(img.numpy(), mode='L')
168 |
169 | if self.transform is not None:
170 | img = self.transform(img)
171 | return img
172 |
173 | def _load(self, file_name):
174 | return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
175 |
176 | def _save(self, file, file_name):
177 | with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
178 | torch.save(file, f)
179 |
180 | def _check_exists(self):
181 | """ Check if processed files exists."""
182 | files = (
183 | "{}_left".format(self.train_image_file),
184 | "{}_right".format(self.train_image_file),
185 | "{}_left".format(self.test_image_file),
186 | "{}_right".format(self.test_image_file),
187 | self.test_label_file,
188 | self.train_label_file
189 | )
190 | fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
191 | return all(fpaths)
192 |
193 | def _flat_data_files(self):
194 | return [j for i in self.data_files.values() for j in list(i.values())]
195 |
196 | def _check_integrity(self):
197 | """Check if unpacked files have correct md5 sum."""
198 | root = self.root
199 | for file_dict in self._flat_data_files():
200 | filename = file_dict["name"]
201 | md5 = file_dict["md5"]
202 | fpath = os.path.join(root, self.raw_folder, filename)
203 | if not check_integrity(fpath, md5):
204 | return False
205 | return True
206 |
207 | def download(self):
208 | """Download the SmallNORB data if it doesn't exist in processed_folder already."""
209 | import gzip
210 |
211 | if self._check_exists():
212 | return
213 |
214 | # check if already extracted and verified
215 | if self._check_integrity():
216 | print('Files already downloaded and verified')
217 | else:
218 | # download and extract
219 | for file_dict in self._flat_data_files():
220 | url = self.dataset_root + file_dict["name"] + '.gz'
221 | filename = file_dict["name"]
222 | gz_filename = filename + '.gz'
223 | md5 = file_dict["md5_gz"]
224 | fpath = os.path.join(self.root, self.raw_folder, filename)
225 | gz_fpath = fpath + '.gz'
226 |
227 | # download if the compressed file is missing or fails verification
228 | download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
229 |
230 | print('# Extracting data {}\n'.format(filename))
231 |
232 | with open(fpath, 'wb') as out_f, \
233 | gzip.GzipFile(gz_fpath) as zip_f:
234 | out_f.write(zip_f.read())
235 |
236 | os.unlink(gz_fpath)
237 |
238 | # process and save as torch files
239 | print('Processing...')
240 |
241 | # create processed folder
242 | try:
243 | os.makedirs(os.path.join(self.root, self.processed_folder))
244 | except OSError as e:
245 | if e.errno == errno.EEXIST:
246 | pass
247 | else:
248 | raise
249 |
250 | # read train files
251 | left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
252 | train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
253 | train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
254 |
255 | # read test files
256 | left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
257 | test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
258 | test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
259 |
260 | # save training files
261 | self._save(left_train_img, "{}_left".format(self.train_image_file))
262 | self._save(right_train_img, "{}_right".format(self.train_image_file))
263 | self._save(train_label, self.train_label_file)
264 | self._save(train_info, self.train_info_file)
265 |
266 | # save test files
267 | self._save(left_test_img, "{}_left".format(self.test_image_file))
268 | self._save(right_test_img, "{}_right".format(self.test_image_file))
269 | self._save(test_label, self.test_label_file)
270 | self._save(test_info, self.test_info_file)
271 |
272 | print('Done!')
273 |
274 | @staticmethod
275 | def _parse_header(file_pointer):
276 | # Read magic number and ignore
277 | struct.unpack('<BBBB', file_pointer.read(4))  # '<' denotes little-endian
347 | class smallNORBViewPoint(data.Dataset):
348 | """`smallNORB <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/>`_ Dataset, split by viewpoint for the novel-viewpoint experiments.
349 | Args:
350 | root (string): Root directory of the dataset, where the processed
351 | and raw folders exist.
352 | train (bool, optional): If True, creates dataset from the training files,
353 | otherwise from the test files.
354 | download (bool, optional): If True, downloads the dataset from the internet
355 | and puts it in the root directory. If the dataset is already processed,
356 | it is neither processed nor downloaded again; if it has only been
357 | downloaded, it is not downloaded again.
358 | transform (callable, optional): A function/transform that takes in a PIL image
359 | and returns a transformed version, e.g. ``transforms.RandomCrop``.
360 | target_transform (callable, optional): A function/transform that takes in the
361 | target and transforms it.
362 | info_transform (callable, optional): A function/transform that takes in the
363 | info and transforms it.
364 | mode (string, optional): Denotes how the images in the data files are returned. Possible values:
365 | - all (default): both left and right are included separately.
366 | - stereo: left and right images are included as corresponding pairs.
367 | - left: only the left images are included.
368 | - right: only the right images are included.
369 | """
370 |
371 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
372 | data_files = {
373 | 'train': {
374 | 'dat': {
375 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
376 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
377 | "md5": "8138a0902307b32dfa0025a36dfa45ec"
378 | },
379 | 'info': {
380 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
381 | "md5_gz": "51dee1210a742582ff607dfd94e332e3",
382 | "md5": "19faee774120001fc7e17980d6960451"
383 | },
384 | 'cat': {
385 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
386 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
387 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
388 | },
389 | },
390 | 'test': {
391 | 'dat': {
392 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
393 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
394 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
395 | },
396 | 'info': {
397 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
398 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
399 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
400 | },
401 | 'cat': {
402 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
403 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
404 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
405 | },
406 | },
407 | }
408 |
409 | raw_folder = 'raw'
410 | processed_folder = 'processed'
411 | train_image_file = 'train_img'
412 | train_label_file = 'train_label'
413 | train_info_file = 'train_info'
414 | test_image_file = 'test_img'
415 | test_label_file = 'test_label'
416 | test_info_file = 'test_info'
417 | extension = '.pt'
418 |
419 | def __init__(self, root, exp='azimuth', train=True, familiar=True, transform=None, target_transform=None, info_transform=None, download=False,
420 | mode="all"):
421 |
422 | self.root = os.path.expanduser(root)
423 | self.transform = transform
424 | self.target_transform = target_transform
425 | self.info_transform = info_transform
426 | self.train = train # training set or test set
427 | self.familiar = familiar
428 | self.mode = mode
429 |
430 | if download:
431 | self.download()
432 |
433 | if not self._check_exists():
434 | raise RuntimeError('Dataset not found or corrupted.' +
435 | ' You can use download=True to download it')
436 |
437 | # load test or train set
438 | image_file = self.train_image_file if self.train else self.test_image_file
439 | label_file = self.train_label_file if self.train else self.test_label_file
440 | info_file = self.train_info_file if self.train else self.test_info_file
441 |
442 | # load labels
443 | self.labels = self._load(label_file)
444 |
445 | # load info files
446 | self.infos = self._load(info_file)
447 |
448 | # load left set
449 | if self.mode == "left":
450 | self.data = self._load("{}_left".format(image_file))
451 |
452 | # load right set
453 | elif self.mode == "right":
454 | self.data = self._load("{}_right".format(image_file))
455 |
456 | elif self.mode == "all" or self.mode == "stereo":
457 | left_data = self._load("{}_left".format(image_file))
458 | right_data = self._load("{}_right".format(image_file))
459 |
460 | # load stereo
461 | if self.mode == "stereo":
462 | self.data = torch.stack((left_data, right_data), dim=1)
463 |
464 | # load all
465 | else:
466 | self.data = torch.cat((left_data, right_data), dim=0)
467 |
468 | # prepare exp
469 | img, tar, inf = [], [], []
470 | if exp == 'azimuth':
471 | self.anno_dim = 2
472 | self.train_anno = [0, 2, 4, 34, 32, 30]
473 | elif exp == 'elevation':
474 | self.anno_dim = 1
475 | self.train_anno = [0, 1, 2]
476 | else:
477 | raise NotImplementedError
478 |
479 | indices = []
480 | for i, info in enumerate(self.infos):
481 | info = info[self.anno_dim].data.item()
482 | if (info in self.train_anno) == (self.train or self.familiar):  # seen viewpoints for train/familiar test; unseen otherwise
483 | indices.append(i)
484 |
485 | self.data = self.data[indices + [i + 24300 for i in indices]] if self.mode == 'all' else self.data[indices]  # right-eye images sit 24300 rows after the left
486 | self.labels = self.labels[indices]
487 | self.infos = self.infos[indices]
488 |
489 | def __getitem__(self, index):
490 | """
491 | Args:
492 | index (int): Index
493 | Returns:
494 | mode ``all``, ``left``, ``right``:
495 | tuple: (image, target)
496 | mode ``stereo``:
497 | tuple: (image left, image right, target, info)
498 | """
499 | target = self.labels[index % len(self.infos)] if self.mode == "all" else self.labels[index]
500 | if self.target_transform is not None:
501 | target = self.target_transform(target)
502 |
503 | info = self.infos[index % len(self.infos)] if self.mode == "all" else self.infos[index]
504 | if self.info_transform is not None:
505 | info = self.info_transform(info)
506 |
507 | if self.mode == "stereo":
508 | img_left = self._transform(self.data[index, 0])
509 | img_right = self._transform(self.data[index, 1])
510 | return img_left, img_right, target, info
511 |
512 | img = self._transform(self.data[index])
513 | return img, target
514 |
515 | def __len__(self):
516 | return len(self.data)
517 |
518 | def _transform(self, img):
519 | # convert to a PIL Image so the return type is consistent
520 | # with the other datasets
521 | img = Image.fromarray(img.numpy(), mode='L')
522 |
523 | if self.transform is not None:
524 | img = self.transform(img)
525 | return img
526 |
527 | def _load(self, file_name):
528 | return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
529 |
530 | def _save(self, file, file_name):
531 | with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
532 | torch.save(file, f)
533 |
534 | def _check_exists(self):
535 | """ Check if processed files exists."""
536 | files = (
537 | "{}_left".format(self.train_image_file),
538 | "{}_right".format(self.train_image_file),
539 | "{}_left".format(self.test_image_file),
540 | "{}_right".format(self.test_image_file),
541 | self.test_label_file,
542 | self.train_label_file
543 | )
544 | fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
545 | return all(fpaths)
546 |
547 | def _flat_data_files(self):
548 | return [j for i in self.data_files.values() for j in list(i.values())]
549 |
550 | def _check_integrity(self):
551 | """Check if unpacked files have correct md5 sum."""
552 | root = self.root
553 | for file_dict in self._flat_data_files():
554 | filename = file_dict["name"]
555 | md5 = file_dict["md5"]
556 | fpath = os.path.join(root, self.raw_folder, filename)
557 | if not check_integrity(fpath, md5):
558 | return False
559 | return True
560 |
561 | def download(self):
562 | """Download the SmallNORB data if it doesn't exist in processed_folder already."""
563 | import gzip
564 |
565 | if self._check_exists():
566 | return
567 |
568 | # check if already extracted and verified
569 | if self._check_integrity():
570 | print('Files already downloaded and verified')
571 | else:
572 | # download and extract
573 | for file_dict in self._flat_data_files():
574 | url = self.dataset_root + file_dict["name"] + '.gz'
575 | filename = file_dict["name"]
576 | gz_filename = filename + '.gz'
577 | md5 = file_dict["md5_gz"]
578 | fpath = os.path.join(self.root, self.raw_folder, filename)
579 | gz_fpath = fpath + '.gz'
580 |
581 | # download if the compressed file is missing or fails verification
582 | download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
583 |
584 | print('# Extracting data {}\n'.format(filename))
585 |
586 | with open(fpath, 'wb') as out_f, \
587 | gzip.GzipFile(gz_fpath) as zip_f:
588 | out_f.write(zip_f.read())
589 |
590 | os.unlink(gz_fpath)
591 |
592 | # process and save as torch files
593 | print('Processing...')
594 |
595 | # create processed folder
596 | try:
597 | os.makedirs(os.path.join(self.root, self.processed_folder))
598 | except OSError as e:
599 | if e.errno == errno.EEXIST:
600 | pass
601 | else:
602 | raise
603 |
604 | # read train files
605 | left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
606 | train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
607 | train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
608 |
609 | # read test files
610 | left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
611 | test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
612 | test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
613 |
614 | # save training files
615 | self._save(left_train_img, "{}_left".format(self.train_image_file))
616 | self._save(right_train_img, "{}_right".format(self.train_image_file))
617 | self._save(train_label, self.train_label_file)
618 | self._save(train_info, self.train_info_file)
619 |
620 | # save test files
621 | self._save(left_test_img, "{}_left".format(self.test_image_file))
622 | self._save(right_test_img, "{}_right".format(self.test_image_file))
623 | self._save(test_label, self.test_label_file)
624 | self._save(test_info, self.test_info_file)
625 |
626 | print('Done!')
627 |
628 | @staticmethod
629 | def _parse_header(file_pointer):
630 | # Read magic number and ignore
631 | struct.unpack('<BBBB', file_pointer.read(4))  # '<' denotes little-endian
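632 |
633 | # ------------------------------------------------------------------
634 | # Usage sketch (illustrative; the data path and transform are
635 | # assumptions, not taken from data_loader.py):
636 | #
637 | #   from torchvision import transforms
638 | #   from torch.utils.data import DataLoader
639 | #
640 | #   train_set = smallNORB('./data/smallnorb', train=True, download=True,
641 | #                         mode='all', transform=transforms.ToTensor())
642 | #   img, target = train_set[0]   # modes 'all'/'left'/'right' yield (image, target)
643 | #   loader = DataLoader(train_set, batch_size=128, shuffle=True)
644 | # ------------------------------------------------------------------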