├── README.md
└── main.py

/README.md:
--------------------------------------------------------------------------------
This is a PyTorch implementation of EM capsules from the paper ["Matrix capsules with EM routing"](https://openreview.net/forum?id=HJWLfGWRb).

The code is based on this repository: https://github.com/shzygmyx/Matrix-Capsules-pytorch.

You need to install torchnet (TNT) for logging and visualization; follow the instructions at [`https://github.com/pytorch/tnt`](https://github.com/pytorch/tnt). You also need Visdom; follow the instructions at [`https://github.com/facebookresearch/visdom`](https://github.com/facebookresearch/visdom).

Some improvements:
+ The biggest improvement is that all for-loops in the routing are replaced by matrix multiplications
+ Visdom is used to log and visualize the training and testing phases
+ More losses are available: cross_entropy_loss, margin_loss (from the "Dynamic Routing Between Capsules" paper), and reconstruction_loss
+ An additional routing method is available: angle_routing (based on the "Dynamic Routing Between Capsules" paper)
+ Multiple workers can be used to load data much faster

As a result, performance is much better than with the original code.

For usage, read main.py for the argparse options and the network sizes A, B, C, D used when training.
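For example, a run that switches to EM routing and the spread loss might look like this (see the argparse section of main.py for the full list of flags and their defaults):

```
python main.py -batch_size 64 --routing EM_routing --loss spread_loss --num-classes 10
```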
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

'''
The Capsules layer.
@author: Yuxian Meng
'''
# TODO: use fewer permute() and contiguous() calls


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim import lr_scheduler
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid
import numpy as np
import random
import os

import matplotlib.pyplot as plt
import argparse
from tqdm import tqdm
import torchnet as tnt
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchnet.engine import Engine

torch.manual_seed(1991)
torch.cuda.manual_seed(1991)
random.seed(1991)
np.random.seed(1991)


def print_mat(x):
    for i in range(x.size(1)):
        plt.matshow(x[0, i].data.cpu().numpy())

    plt.show()


class PrimaryCaps(nn.Module):
    """
    The primary capsule layer simply runs several convolutional layers in
    parallel and concatenates their outputs.
    Args:
        A: number of input channels.
        B: number of capsule types.
    """

    def __init__(self, A=32, B=32):
        super(PrimaryCaps, self).__init__()
        self.B = B
        # one 1x1 conv per capsule type for the 4x4 pose matrix ...
        self.capsules_pose = nn.ModuleList([nn.Conv2d(in_channels=A, out_channels=4 * 4,
                                                      kernel_size=1, stride=1)
                                            for _ in range(self.B)])
        # ... and one 1x1 conv per capsule type for the scalar activation
        self.capsules_activation = nn.ModuleList([nn.Conv2d(in_channels=A, out_channels=1,
                                                            kernel_size=1, stride=1) for _
                                                  in range(self.B)])

    def forward(self, x):  # b,A,12,12
        poses = [self.capsules_pose[i](x) for i in range(self.B)]  # (b,16,12,12) * B
        poses = torch.cat(poses, dim=1)  # b,16*B,12,12
        activations = [self.capsules_activation[i](x) for i in range(self.B)]  # (b,1,12,12) * B
        activations = torch.sigmoid(torch.cat(activations, dim=1))  # b,B,12,12
        return poses, activations
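
# A quick shape sketch of PrimaryCaps (illustrative only, never executed during
# training; the 12x12 grid matches the feature map produced by CapsNet.conv1
# on 28x28 MNIST images):
#   x = torch.randn(2, 32, 12, 12)            # (b, A, h, w)
#   poses, activations = PrimaryCaps(A=32, B=32)(x)
#   poses.shape        -> (2, 512, 12, 12)    # (b, 16*B, h, w)
#   activations.shape  -> (2, 32, 12, 12)     # (b, B, h, w)
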
class ConvCaps(nn.Module):
    """
    Convolutional Capsule Layer.
    Args:
        B: number of input capsule types.
        C: number of output capsule types.
        kernel: kernel size of the convolution. kernel=0 means each capsule in
            layer L+1 has a receptive field containing all capsules in layer L;
            kernel=0 is used in the final ClassCaps layer.
        stride: stride of the convolution.
        iteration: number of EM iterations.
        coordinate_add: whether to use Coordinate Addition.
        transform_share: whether to share the transformation matrices.
    """

    def __init__(self, B=32, C=32, kernel=3, stride=2, iteration=3,
                 coordinate_add=False, transform_share=False):
        super(ConvCaps, self).__init__()
        self.B = B
        self.C = C
        self.K = kernel  # kernel = 0 means a full receptive field, as in class capsules
        self.Bkk = None
        self.Cww = None
        self.b = args.batch_size
        self.stride = stride
        self.coordinate_add = coordinate_add
        self.transform_share = transform_share
        self.beta_v = None
        self.beta_a = None
        if not transform_share:
            self.W = nn.Parameter(torch.randn(B, kernel, kernel, C,
                                              4, 4))  # B,K,K,C,4,4
        else:
            self.W = nn.Parameter(torch.randn(B, C, 4, 4))  # B,C,4,4

        self.iteration = iteration

    def coordinate_addition(self, width_in, votes):
        # scaled (x, y) coordinates of each position in the input grid
        add = [[i / width_in, j / width_in] for i in range(width_in) for j in range(width_in)]
        add = torch.tensor(add, device=votes.device).view(1, 1, self.K, self.K, 1, 1, 1, 2)
        add = add.expand(self.b, self.B, self.K, self.K, self.C, 1, 1, 2).contiguous()
        # add the coordinates to the first two entries of the right-hand
        # column of each vote matrix
        votes[:, :, :, :, :, :, :, :2, -1] = votes[:, :, :, :, :, :, :, :2, -1] + add
        return votes

    def down_w(self, w):
        return range(w * self.stride, w * self.stride + self.K)

    def EM_routing(self, lambda_, a_, V):
        # routing coefficients, initialized uniformly over output capsules
        R = torch.ones([self.b, self.Bkk, self.Cww], device=V.device) / self.Cww

        for i in range(self.iteration):
            # M-step: refit each output capsule's Gaussian to its votes
            R = (R * a_)[..., None]
            sum_R = R.sum(1)
            mu = ((R * V).sum(1) / sum_R)[:, None, :, :]
            sigma_square = (R * (V - mu) ** 2).sum(1) / sum_R

            # E-step: update the routing coefficients from the Gaussians
            if i != self.iteration - 1:
                mu, sigma_square, V_, a__ = mu.detach(), sigma_square.detach(), V.detach(), a_.detach()
                normal = Normal(mu, sigma_square[:, None, :, :] ** (1 / 2))
                p = torch.exp(normal.log_prob(V_))
                # NOTE: the paper takes the product of the per-dimension
                # densities; this implementation sums them instead
                ap = a__ * p.sum(-1)
                R = ap / torch.sum(ap, -1)[..., None]
            else:
                const = (self.beta_v.expand_as(sigma_square) + torch.log(sigma_square)) * sum_R
                a = torch.sigmoid(lambda_ * (self.beta_a.repeat(self.b, 1) - const.sum(2)))

        return a, mu
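
    # For reference, the updates implemented above are roughly (R_ij: routing
    # weight from input capsule i to output capsule j; a_i: input activation;
    # V_ij: the 16-dimensional vote of capsule i for capsule j):
    #   M-step: mu_j      = sum_i(R_ij a_i V_ij) / sum_i(R_ij a_i)
    #           sigma_j^2 = sum_i(R_ij a_i (V_ij - mu_j)^2) / sum_i(R_ij a_i)
    #           cost_j    = (beta_v + log sigma_j^2) * sum_i(R_ij a_i)
    #           a_j       = sigmoid(lambda * (beta_a - sum_h cost_jh))
    #   E-step: R_ij is proportional to a_j * N(V_ij | mu_j, sigma_j^2),
    #           normalized over the output capsules j
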
    def angle_routing(self, lambda_, a_, V):
        # routing coefficients, initialized to zero as in dynamic routing
        R = torch.zeros([self.b, self.Bkk, self.Cww], device=V.device)

        for i in range(self.iteration):
            # softmax over input capsules (note: the dynamic-routing paper
            # normalizes over output capsules instead)
            R = F.softmax(R, dim=1)
            R = (R * a_)[..., None]
            sum_R = R.sum(1)
            mu = ((R * V).sum(1) / sum_R)[:, None, :, :]

            if i != self.iteration - 1:
                # cosine agreement between each vote and the current output pose
                u_v = mu.permute(0, 2, 1, 3) @ V.permute(0, 2, 3, 1)
                u_v = u_v.squeeze().permute(0, 2, 1) / V.norm(2, -1) / mu.norm(2, -1)
                R = R.squeeze() + u_v
            else:
                sigma_square = (R * (V - mu) ** 2).sum(1) / sum_R
                const = (self.beta_v.expand_as(sigma_square) + torch.log(sigma_square)) * sum_R
                a = torch.sigmoid(lambda_ * (self.beta_a.repeat(self.b, 1) - const.sum(2)))

        return a, mu

    def forward(self, x, lambda_):
        poses, activations = x
        width_in = poses.size(2)
        w = int((width_in - self.K) / self.stride + 1) if self.K else 1  # output width
        self.Cww = w * w * self.C
        self.b = poses.size(0)

        if self.beta_v is None:
            # NOTE: beta_v and beta_a are created lazily on the first forward
            # pass, so an optimizer built before that pass will not update them
            self.beta_v = nn.Parameter(torch.randn(1, self.Cww, 1, device=poses.device))
            self.beta_a = nn.Parameter(torch.randn(1, self.Cww, device=poses.device))

        if self.transform_share:
            if self.K == 0:
                self.K = width_in  # class capsules' kernel spans the whole input
            W = self.W.view(self.B, 1, 1, self.C, 4, 4).expand(self.B, self.K, self.K, self.C, 4, 4).contiguous()
        else:
            W = self.W  # B,K,K,C,4,4

        self.Bkk = self.K * self.K * self.B

        # gather every capsule i's pose within each output capsule's receptive field;
        # stack at dim=4 so positions come before the 16 pose values, matching the view below
        pose = poses.contiguous()  # b,16*B,w_in,w_in
        pose = pose.view(self.b, 16, self.B, width_in, width_in).permute(0, 2, 3, 4, 1).contiguous()  # b,B,w_in,w_in,16
        poses = torch.stack([pose[:, :, self.stride * i:self.stride * i + self.K,
                             self.stride * j:self.stride * j + self.K, :] for i in range(w) for j in range(w)],
                            dim=4)  # b,B,K,K,w*w,16
        poses = poses.view(self.b, self.B, self.K, self.K, 1, w, w, 4, 4)  # b,B,K,K,1,w,w,4,4
        W_hat = W[None, :, :, :, :, None, None, :, :]  # 1,B,K,K,C,1,1,4,4
        votes = W_hat @ poses  # b,B,K,K,C,w,w,4,4

        if self.coordinate_add:
            votes = self.coordinate_addition(width_in, votes)
            activation = activations.view(self.b, -1)[..., None].repeat(1, 1, self.Cww)
        else:
            activations_ = [activations[:, :, self.down_w(x), :][:, :, :, self.down_w(y)]
                            for x in range(w) for y in range(w)]
            activation = torch.stack(
                activations_, dim=4).view(self.b, self.Bkk, 1, -1) \
                .repeat(1, 1, self.C, 1).view(self.b, self.Bkk, self.Cww)

        votes = votes.view(self.b, self.Bkk, self.Cww, 16)
        activations, poses = getattr(self, args.routing)(lambda_, activation, votes)
        return poses.view(self.b, self.C, w, w, -1), activations.view(self.b, self.C, w, w)


class CapsNet(nn.Module):
    def __init__(self, A=32, B=32, C=32, D=32, E=10, r=3):
        super(CapsNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=A,
                               kernel_size=5, stride=2)
        self.primary_caps = PrimaryCaps(A, B)
        self.convcaps1 = ConvCaps(B, C, kernel=3, stride=2, iteration=r,
                                  coordinate_add=False, transform_share=False)
        self.convcaps2 = ConvCaps(C, D, kernel=3, stride=1, iteration=r,
                                  coordinate_add=False, transform_share=False)
        self.classcaps = ConvCaps(D, E, kernel=0, stride=1, iteration=r,
                                  coordinate_add=True, transform_share=True)
        self.decoder = nn.Sequential(
            nn.Linear(16 * args.num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, lambda_, y=None):  # b,1,28,28
        x = F.relu(self.conv1(x))  # b,A,12,12
        x = self.primary_caps(x)  # poses: b,16*B,12,12; activations: b,B,12,12
        x = self.convcaps1(x, lambda_)  # poses: b,C,5,5,16; activations: b,C,5,5
        x = self.convcaps2(x, lambda_)  # poses: b,D,3,3,16; activations: b,D,3,3
        p, a = self.classcaps(x, lambda_)  # poses: b,E,1,1,16; activations: b,E,1,1

        p = p.squeeze()

        if y is None:
            # at test time, reconstruct from the most active capsule
            _, y = a.max(dim=1)
            y = y.squeeze()

        # convert to one-hot
        y = torch.eye(args.num_classes, device=p.device).index_select(dim=0, index=y)

        reconstructions = self.decoder((p * y[:, :, None]).view(p.size(0), -1))

        return a.squeeze(), reconstructions
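
# A minimal smoke test of the full network (hypothetical; assumes CUDA is
# available and that `args` has already been parsed, since the capsule layers
# read args.batch_size, args.routing, and args.num_classes):
#   model = CapsNet(A=64, B=8, C=16, D=16, E=10, r=3).cuda()
#   imgs = torch.randn(2, 1, 28, 28).cuda()
#   activations, reconstructions = model(imgs, 0.9)
#   activations.shape      -> (2, 10)
#   reconstructions.shape  -> (2, 784)
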
class CapsuleLoss(nn.Module):
    def __init__(self):
        super(CapsuleLoss, self).__init__()
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    @staticmethod
    def spread_loss(x, target, m):  # x: b,10; target: b
        loss = F.multi_margin_loss(x, target, p=2, margin=m)
        return loss

    @staticmethod
    def cross_entropy_loss(x, target, m):
        loss = F.cross_entropy(x, target)
        return loss

    @staticmethod
    def margin_loss(x, labels, m):
        left = F.relu(0.9 - x, inplace=True) ** 2
        right = F.relu(x - 0.1, inplace=True) ** 2

        labels = torch.eye(args.num_classes, device=x.device).index_select(dim=0, index=labels)

        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()
        return margin_loss / x.size(0)

    def forward(self, images, output, labels, m, recon):
        main_loss = getattr(self, args.loss)(output, labels, m)

        if args.use_recon:
            recon_loss = self.reconstruction_loss(recon, images)
            main_loss += 0.0005 * recon_loss

        return main_loss
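
# For reference, margin_loss above is the loss from "Dynamic Routing Between
# Capsules" with m+ = 0.9, m- = 0.1, and down-weighting factor 0.5:
#   L_k = T_k * max(0, m+ - a_k)^2 + 0.5 * (1 - T_k) * max(0, a_k - m-)^2
# where T_k = 1 iff class k is present and a_k is the predicted activation;
# the total is averaged over the batch.
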
def reset_meters():
    meter_accuracy.reset()
    meter_loss.reset()
    confusion_meter.reset()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CapsNet')

    parser.add_argument('-batch_size', type=int, default=64)
    parser.add_argument('-num_epochs', type=int, default=500)
    parser.add_argument('-lr', type=float, default=2e-2)
    parser.add_argument('-clip', type=float, default=5)
    parser.add_argument('-r', type=int, default=3)
    parser.add_argument('-disable_cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument('-print_freq', type=int, default=10)
    parser.add_argument('-pretrained', type=str, default="")
    parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                        help='number of output classes (default: 10)')
    parser.add_argument('-gpu', type=int, default=0, help="which gpu to use")
    parser.add_argument('--env-name', type=str, default='main',
                        metavar='N', help='environment name for displaying plots')
    parser.add_argument('--loss', type=str, default='margin_loss', metavar='N',
                        help='loss to use: cross_entropy_loss, margin_loss, spread_loss')
    parser.add_argument('--routing', type=str, default='angle_routing', metavar='N',
                        help='routing to use: angle_routing, EM_routing')
    # NOTE: argparse's type=bool treats any non-empty string as True,
    # so only the default is reliable here
    parser.add_argument('--use-recon', type=bool, default=True, metavar='N',
                        help='use reconstruction loss or not')
    parser.add_argument('--num-workers', type=int, default=4, metavar='N',
                        help='number of workers to fetch data')
    args = parser.parse_args()
    args.use_cuda = not args.disable_cuda and torch.cuda.is_available()

    use_cuda = args.use_cuda
    lambda_ = 1e-3  # TODO: find a good schedule to increase lambda and m
    m = 0.2

    A, B, C, D, E, r = 64, 8, 16, 16, args.num_classes, args.r  # a small CapsNet
    # A, B, C, D, E, r = 32, 32, 32, 32, args.num_classes, args.r  # the CapsNet from the paper

    model = CapsNet(A, B, C, D, E, r)
    capsule_loss = CapsuleLoss()

    meter_loss = tnt.meter.AverageValueMeter()
    meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
    confusion_meter = tnt.meter.ConfusionMeter(args.num_classes, normalized=True)

    setting_logger = VisdomLogger('text', opts={'title': 'Settings'}, env=args.env_name)
    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'}, env=args.env_name)
    train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'}, env=args.env_name)
    test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'}, env=args.env_name)
    test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'}, env=args.env_name)
    confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                     'columnnames': list(range(args.num_classes)),
                                                     'rownames': list(range(args.num_classes))}, env=args.env_name)
    ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'}, env=args.env_name)
    reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'}, env=args.env_name)

    weight_folder = 'weights/{}'.format(args.env_name.replace(' ', '_'))
    if not os.path.isdir(weight_folder):
        os.makedirs(weight_folder)

    setting_logger.log(str(args))

    print("# parameters:", sum(param.numel() for param in model.parameters()))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=1)

    train_dataset = datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)

    test_dataset = datasets.MNIST(root='./data/',
                                  train=False,
                                  transform=transforms.ToTensor())

    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              shuffle=True)

    steps = len(train_dataset) // args.batch_size

    if args.pretrained:
        model.load_state_dict(torch.load(args.pretrained))
        m = 0.8
        lambda_ = 0.9

    with torch.cuda.device(args.gpu):
        if use_cuda:
            print("activating cuda")
            model.cuda()

        for epoch in range(args.num_epochs):
            reset_meters()

            # Train
            print("Epoch {}".format(epoch))

            with tqdm(total=steps) as pbar:
                for data in train_loader:
                    # gradually increase lambda and m each step, capped at 1 and 0.9
                    if lambda_ < 1:
                        lambda_ += 2e-1 / steps
                    if m < 0.9:
                        m += 2e-1 / steps

                    optimizer.zero_grad()

                    imgs, labels = data  # b,1,28,28; b
                    if use_cuda:
                        imgs = imgs.cuda()
                        labels = labels.cuda()

                    out_labels, recon = model(imgs, lambda_, labels)

                    recon = recon.view_as(imgs)
                    loss = capsule_loss(imgs, out_labels, labels, m, recon)

                    loss.backward()
                    optimizer.step()

                    meter_accuracy.add(out_labels.data, labels.data)
                    meter_loss.add(loss.item())
                    pbar.set_postfix(loss=meter_loss.value()[0], acc=meter_accuracy.value()[0])
                    pbar.update()

            loss = meter_loss.value()[0]
            acc = meter_accuracy.value()[0]

            train_loss_logger.log(epoch, loss)
            train_error_logger.log(epoch, acc)

            print("Epoch{} Train acc:{:.4f}, loss:{:.4f}".format(epoch, acc, loss))
            scheduler.step(acc)
            torch.save(model.state_dict(), os.path.join(weight_folder, "model_{}.pth".format(epoch)))

            reset_meters()
            # Test
            print('Testing...')
            with torch.no_grad():
                for i, data in enumerate(test_loader):
                    imgs, labels = data  # b,1,28,28; b
                    if use_cuda:
                        imgs = imgs.cuda()
                        labels = labels.cuda()
                    out_labels, recon = model(imgs, lambda_)  # b,10; b,784

                    recon = recon.view_as(imgs)
                    loss = capsule_loss(imgs, out_labels, labels, m, recon)

                    # visualize reconstructions for the first batch
                    if i == 0:
                        ground_truth_logger.log(
                            make_grid(imgs.data, nrow=int(args.batch_size ** 0.5), normalize=True,
                                      range=(0, 1)).cpu().numpy())
                        reconstruction_logger.log(
                            make_grid(recon.data, nrow=int(args.batch_size ** 0.5), normalize=True,
                                      range=(0, 1)).cpu().numpy())

                    meter_accuracy.add(out_labels.data, labels.data)
                    confusion_meter.add(out_labels.data, labels.data)
                    meter_loss.add(loss.item())

            loss = meter_loss.value()[0]
            acc = meter_accuracy.value()[0]

            test_loss_logger.log(epoch, loss)
            test_accuracy_logger.log(epoch, acc)
            confusion_logger.log(confusion_meter.value())

            print("Epoch{} Test acc:{:.4f}, loss:{:.4f}".format(epoch, acc, loss))

--------------------------------------------------------------------------------