├── .gitignore
├── README.md
├── capsule_network.py
├── capsule_network_svhn.py
├── epochs
│   └── .gitkeep
└── media
    └── Benchmark.png


/.gitignore:
--------------------------------------------------------------------------------
.idea/
data/
epochs/*
!epochs/.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Dynamic Routing Between Capsules

A barebones, CUDA-enabled PyTorch implementation of the CapsNet architecture from the paper "Dynamic Routing Between Capsules", written by [Kenta Iwasaki](https://github.com/iwasaki-kenta) on behalf of Gram.AI.

The model is trained with [TorchNet](https://github.com/pytorch/tnt); MNIST loading and preprocessing are done with [TorchVision](https://github.com/pytorch/vision).

## Description

> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or object part. We use the length of the activity vector to represent the probability that the entity exists and its orientation to represent the instantiation parameters. Active capsules at one level make predictions, via transformation matrices, for the instantiation parameters of higher-level capsules. When multiple predictions agree, a higher level capsule becomes active. We show that a discriminatively trained, multi-layer capsule system achieves state-of-the-art performance on MNIST and is considerably better than a convolutional net at recognizing highly overlapping digits. To achieve these results we use an iterative routing-by-agreement mechanism: A lower-level capsule prefers to send its output to higher level capsules whose activity vectors have a big scalar product with the prediction coming from the lower-level capsule.

Paper written by Sara Sabour, Nicholas Frosst, and Geoffrey E. Hinton.
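
The routing-by-agreement loop described above can be illustrated with a small, dependency-free sketch. It follows the procedure as stated in the paper: coupling coefficients are a softmax of the routing logits over the higher-level capsules, each weighted sum is squashed, and the logits grow by the dot product between a prediction and the resulting output. It assumes the prediction vectors have already been produced by the transformation matrices. The helper names `squash`, `softmax` and `route`, the `eps` term and the toy inputs are illustrative assumptions; `capsule_network.py` implements these steps on batched PyTorch tensors with its own axis layout (its routing softmax runs over `dim=2` of `logits`).

```python
import math
from typing import List

Vector = List[float]


def squash(s: Vector, eps: float = 1e-8) -> Vector:
    """Shrink s to length |s|^2 / (1 + |s|^2), keeping its direction.

    eps is an assumption (not from the paper) that avoids dividing by zero when s is all zeros.
    """
    squared_norm = sum(x * x for x in s)
    scale = squared_norm / (1.0 + squared_norm) / math.sqrt(squared_norm + eps)
    return [scale * x for x in s]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax over a list of logits."""
    largest = max(logits)
    exps = [math.exp(b - largest) for b in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(predictions: List[List[Vector]], num_iterations: int = 3) -> List[Vector]:
    """Routing by agreement.

    predictions[i][j] is the prediction u_hat(j|i) that lower-level capsule i
    makes for higher-level capsule j; returns the output vectors v_j.
    """
    num_lower, num_higher = len(predictions), len(predictions[0])
    dim = len(predictions[0][0])
    logits = [[0.0] * num_higher for _ in range(num_lower)]  # b_ij start at zero

    outputs: List[Vector] = []
    for iteration in range(num_iterations):
        # c_ij: each lower capsule distributes its output over the higher capsules.
        couplings = [softmax(row) for row in logits]
        # s_j = sum_i c_ij * u_hat(j|i), then v_j = squash(s_j).
        outputs = []
        for j in range(num_higher):
            s_j = [sum(couplings[i][j] * predictions[i][j][d] for i in range(num_lower))
                   for d in range(dim)]
            outputs.append(squash(s_j))
        # Agreement update b_ij += u_hat(j|i) . v_j, skipped after the final pass.
        if iteration < num_iterations - 1:
            for i in range(num_lower):
                for j in range(num_higher):
                    logits[i][j] += sum(p * v for p, v in zip(predictions[i][j], outputs[j]))
    return outputs


# Two lower-level capsules agree about higher capsule 0 but contradict each other about capsule 1.
predictions = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, -1.0]],
]
lengths = [math.sqrt(sum(x * x for x in v)) for v in route(predictions)]
print(lengths)  # capsule 0 is clearly active; capsule 1's opposing predictions cancel to length 0
```

Because both predictions for capsule 0 agree, every routing pass raises their logits towards capsule 0 and its output grows longer, while the opposing predictions for capsule 1 cancel out. This mirrors the abstract's statement that a higher-level capsule becomes active when multiple predictions agree.
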
For more information, please check out the paper [here](https://arxiv.org/abs/1710.09829).

## Requirements

* Python 3
* PyTorch
* TorchVision
* TorchNet
* NumPy
* TQDM
* Visdom

## Usage

**Step 1** Adjust the number of training epochs, batch size, etc. at the top of `capsule_network.py`.

```python
BATCH_SIZE = 100
NUM_CLASSES = 10
NUM_EPOCHS = 500
NUM_ROUTING_ITERATIONS = 3
```

**Step 2** Start training. The MNIST dataset is downloaded into `./data` (relative to the directory the script is run from) if it is not already there. Make sure the Visdom server is running!

```console
$ sudo python3 -m visdom.server & python3 capsule_network.py
```

## Benchmarks

The highest test accuracy was 99.7%, reached at epoch 443. The trend of the test accuracy/loss graphs below suggests the model may reach a higher accuracy with further training.

![Training progress.](media/Benchmark.png)

Default PyTorch Adam optimizer hyperparameters were used with no learning rate scheduling.
Each epoch with a batch size of 100 takes ~3 minutes on a Razer Blade with a GTX 1050 and ~2 minutes on an NVIDIA Titan Xp.

## TODO

* Extension to other datasets apart from MNIST.

## Credits

This implementation primarily referenced these two Keras and TensorFlow implementations:
1. [Keras implementation by @XifengGuo](https://github.com/XifengGuo/CapsNet-Keras)
2. [TensorFlow implementation by @naturomics](https://github.com/naturomics/CapsNet-Tensorflow)

Many thanks to [@InnerPeace-Wu](https://github.com/InnerPeace-Wu) for a [discussion on the dynamic routing procedure](https://github.com/XifengGuo/CapsNet-Keras/issues/1) outlined in the paper.

## Contact/Support

Gram.AI is currently developing a wide range of AI models to be either open-sourced or released for free to the community, which is why we cannot guarantee complete support for this work.

If any issues come up while using this implementation, or if you would like to contribute in any way, please feel free to send an e-mail to [kenta@perlin.net](mailto:kenta@perlin.net) or open a new GitHub issue on this repository.
--------------------------------------------------------------------------------
/capsule_network.py:
--------------------------------------------------------------------------------
"""
Dynamic Routing Between Capsules
https://arxiv.org/abs/1710.09829

PyTorch implementation by Kenta Iwasaki @ Gram.AI.
"""
import sys
sys.setrecursionlimit(15000)

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

BATCH_SIZE = 100
NUM_CLASSES = 10
NUM_EPOCHS = 500
NUM_ROUTING_ITERATIONS = 3


def softmax(input, dim=1):
    # Apply F.softmax along `dim` by moving that dimension to the last axis and back.
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)


def augmentation(x, max_shift=2):
    # Translate the whole batch by one random shift of up to `max_shift` pixels per axis; uncovered pixels stay zero.
    _, _, height, width = x.size()

    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    source_height_slice = slice(max(0, h_shift), h_shift + height)
    source_width_slice = slice(max(0, w_shift), w_shift + width)
    target_height_slice = slice(max(0, -h_shift), -h_shift + height)
    target_width_slice = slice(max(0, -w_shift), -w_shift + width)

    shifted_image = torch.zeros(*x.size())
    shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice]
    return shifted_image.float()


class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=NUM_ROUTING_ITERATIONS):
        super(CapsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations

        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            # Routed layer: one transformation matrix per (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            # Primary layer: one Conv2d per entry of the capsule vector; outputs are concatenated on the last axis.
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
                 range(num_capsules)])

    def squash(self, tensor, dim=-1):
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        # A small epsilon avoids division by zero (and NaN gradients) for all-zero vectors.
        return scale * tensor / torch.sqrt(squared_norm + 1e-8)

    def forward(self, x):
        if self.num_route_nodes != -1:
            # priors: (num_capsules, batch, num_route_nodes, 1, out_channels)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Routing logits start at zero on the same device and dtype as the priors.
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squash(outputs)

        return outputs


class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
                                             kernel_size=9, stride=2)
        # conv1 (9x9, stride 1) maps 28x28 to 20x20; the primary capsules (9x9, stride 2) map 20x20 to 6x6,
        # giving 32 * 6 * 6 = 1152 eight-dimensional capsules to route from.
        self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8,
                                           out_channels=16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x).squeeze().transpose(0, 1)

        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # In all batches, get the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(NUM_CLASSES, device=x.device).index_select(dim=0, index=max_length_indices)

        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions


class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        left = F.relu(0.9 - classes, inplace=True) ** 2
        right = F.relu(classes - 0.1, inplace=True) ** 2

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        assert torch.numel(images) == torch.numel(reconstructions)
        images = images.view(reconstructions.size()[0], -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)


if __name__ == "__main__":
    from torch.autograd import Variable
    from torch.optim import Adam
    from torchnet.engine import Engine
    from torchnet.logger import VisdomPlotLogger, VisdomLogger
    from torchvision.utils import make_grid
    from torchvision.datasets.mnist import MNIST
    from tqdm import tqdm
    import torchnet as tnt

    model = CapsuleNet()
    # model.load_state_dict(torch.load('epochs/epoch_327.pt'))
    model.cuda()

    print("# parameters:", sum(param.numel() for param in model.parameters()))

    optimizer = Adam(model.parameters())

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
    train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
    test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
    test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
    confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                     'columnnames': list(range(NUM_CLASSES)),
                                                     'rownames': list(range(NUM_CLASSES))})
    ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
    reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

    capsule_loss = CapsuleLoss()


    def get_iterator(mode):
        dataset = MNIST(root='./data', download=True, train=mode)
        data = getattr(dataset, 'train_data' if mode else 'test_data')
        labels = getattr(dataset, 'train_labels' if mode else 'test_labels')
        tensor_dataset = tnt.dataset.TensorDataset([data, labels])

        return tensor_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


    def processor(sample):
        data, labels, training = sample

        data = augmentation(data.unsqueeze(1).float() / 255.0)
        labels = torch.LongTensor(labels)

        labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)

        data = Variable(data).cuda()
        labels = Variable(labels).cuda()

        if training:
            classes, reconstructions = model(data, labels)
        else:
            classes, reconstructions = model(data)

        loss = capsule_loss(data, labels, classes, reconstructions)

        return loss, classes


    def reset_meters():
        meter_accuracy.reset()
        meter_loss.reset()
        confusion_meter.reset()


    def on_sample(state):
        state['sample'].append(state['train'])


    def on_forward(state):
        meter_accuracy.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        confusion_meter.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].item())


    def on_start_epoch(state):
        reset_meters()
        state['iterator'] = tqdm(state['iterator'])


    def on_end_epoch(state):
        print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

        train_loss_logger.log(state['epoch'], meter_loss.value()[0])
        train_error_logger.log(state['epoch'], meter_accuracy.value()[0])

        reset_meters()

        engine.test(processor, get_iterator(False))
        test_loss_logger.log(state['epoch'], meter_loss.value()[0])
        test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
        confusion_logger.log(confusion_meter.value())

        print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

        torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])

        # Reconstruction visualization.

        test_sample = next(iter(get_iterator(False)))

        ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0)
        _, reconstructions = model(Variable(ground_truth).cuda())
        reconstruction = reconstructions.cpu().view_as(ground_truth).data

        ground_truth_logger.log(
            make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
        reconstruction_logger.log(
            make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())

    # def on_start(state):
    #     state['epoch'] = 327
    #
    # engine.hooks['on_start'] = on_start
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch

    engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)
--------------------------------------------------------------------------------
/capsule_network_svhn.py:
--------------------------------------------------------------------------------
"""
Dynamic Routing Between Capsules
https://arxiv.org/abs/1710.09829

PyTorch implementation by Kenta Iwasaki @ Gram.AI.
"""
import sys
sys.setrecursionlimit(15000)

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

BATCH_SIZE = 100
NUM_CLASSES = 10
NUM_EPOCHS = 500
NUM_ROUTING_ITERATIONS = 3


def softmax(input, dim=1):
    # Apply F.softmax along `dim` by moving that dimension to the last axis and back.
    transposed_input = input.transpose(dim, len(input.size()) - 1)
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) - 1)


def augmentation(x, max_shift=2):
    # Translate the whole batch by one random shift of up to `max_shift` pixels per axis; uncovered pixels stay zero.
    _, _, height, width = x.size()

    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    source_height_slice = slice(max(0, h_shift), h_shift + height)
    source_width_slice = slice(max(0, w_shift), w_shift + width)
    target_height_slice = slice(max(0, -h_shift), -h_shift + height)
    target_width_slice = slice(max(0, -w_shift), -w_shift + width)

    shifted_image = torch.zeros(*x.size())
    shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice]
    return shifted_image.float()


class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=NUM_ROUTING_ITERATIONS):
        super(CapsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations

        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            # Routed layer: one transformation matrix per (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            # Primary layer: one Conv2d per entry of the capsule vector; outputs are concatenated on the last axis.
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
                 range(num_capsules)])

    def squash(self, tensor, dim=-1):
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        # A small epsilon avoids division by zero (and NaN gradients) for all-zero vectors.
        return scale * tensor / torch.sqrt(squared_norm + 1e-8)

    def forward(self, x):
        if self.num_route_nodes != -1:
            # priors: (num_capsules, batch, num_route_nodes, 1, out_channels)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Routing logits start at zero on the same device and dtype as the priors.
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squash(outputs)

        return outputs


class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet, self).__init__()

        # SVHN images are 32x32 RGB, hence 3 input channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
                                             kernel_size=9, stride=2)
        # 32x32 becomes 24x24 after conv1 and 8x8 after the primary capsules: 32 * 8 * 8 = 2048 route nodes.
        self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES, num_route_nodes=2048, in_channels=8,
                                           out_channels=16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            # 3072 = 3 * 32 * 32 reconstructed pixel values.
            nn.Linear(1024, 3072),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x).squeeze().transpose(0, 1)

        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # In all batches, get the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(NUM_CLASSES, device=x.device).index_select(dim=0, index=max_length_indices)

        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions


class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        left = F.relu(0.9 - classes, inplace=True) ** 2
        right = F.relu(classes - 0.1, inplace=True) ** 2

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        assert torch.numel(images) == torch.numel(reconstructions)
        images = images.view(reconstructions.size()[0], -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)


if __name__ == "__main__":
    from torch.autograd import Variable
    from torch.optim import Adam
    from torchnet.engine import Engine
    from torchnet.logger import VisdomPlotLogger, VisdomLogger
    from torchvision.utils import make_grid
    from torchvision.datasets.svhn import SVHN
    from tqdm import tqdm
    import torchnet as tnt

    model = CapsuleNet()
    # model.load_state_dict(torch.load('epochs/epoch_327.pt'))
    model.cuda()

    print("# parameters:", sum(param.numel() for param in model.parameters()))

    optimizer = Adam(model.parameters())

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
    train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
    test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
    test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
    confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                     'columnnames': list(range(NUM_CLASSES)),
                                                     'rownames': list(range(NUM_CLASSES))})
    ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
    reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

    capsule_loss = CapsuleLoss()


    def get_iterator(mode):
        # `mode` is True for the training split and False for the test split.
        dataset = SVHN(root='./data', download=True, split='train' if mode else 'test')
        data = dataset.data
        labels = dataset.labels

        tensor_dataset = tnt.dataset.TensorDataset([data, labels])

        return tensor_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


    def processor(sample):
        data, labels, training = sample

        # SVHN pixels arrive as uint8 in [0, 255]; scale them to [0, 1] to match the decoder's sigmoid output.
        data = augmentation(data.float() / 255.0)
        labels = torch.LongTensor(labels)

        labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)

        data = Variable(data).cuda()
        labels = Variable(labels).cuda()

        if training:
            classes, reconstructions = model(data, labels)
        else:
            classes, reconstructions = model(data)

        loss = capsule_loss(data, labels, classes, reconstructions)

        return loss, classes


    def reset_meters():
        meter_accuracy.reset()
        meter_loss.reset()
        confusion_meter.reset()


    def on_sample(state):
        state['sample'].append(state['train'])


    def on_forward(state):
        meter_accuracy.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        confusion_meter.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].item())


    def on_start_epoch(state):
        reset_meters()
        state['iterator'] = tqdm(state['iterator'])


    def on_end_epoch(state):
        print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

        train_loss_logger.log(state['epoch'], meter_loss.value()[0])
        train_error_logger.log(state['epoch'], meter_accuracy.value()[0])

        reset_meters()

        engine.test(processor, get_iterator(False))
        test_loss_logger.log(state['epoch'], meter_loss.value()[0])
        test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
        confusion_logger.log(confusion_meter.value())

        print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

        torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])

        # Reconstruction visualization.

        test_sample = next(iter(get_iterator(False)))

        # Scale uint8 pixels to [0, 1], matching the preprocessing in `processor`.
        ground_truth = test_sample[0].float() / 255.0
        _, reconstructions = model(Variable(ground_truth).cuda())
        reconstruction = reconstructions.cpu().view_as(ground_truth).data

        ground_truth_logger.log(
            make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
        reconstruction_logger.log(
            make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())

    # def on_start(state):
    #     state['epoch'] = 327
    #
    # engine.hooks['on_start'] = on_start
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch

    engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)
--------------------------------------------------------------------------------
/epochs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gram-ai/capsule-networks/6eeda0882bff819b2e788be40cf87fe69f20f26a/epochs/.gitkeep
--------------------------------------------------------------------------------
/media/Benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gram-ai/capsule-networks/6eeda0882bff819b2e788be40cf87fe69f20f26a/media/Benchmark.png
--------------------------------------------------------------------------------