├── .ipynb_checkpoints └── Capsule Network-checkpoint.ipynb ├── Capsule Network.ipynb └── README.md /.ipynb_checkpoints/Capsule Network-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "from torch.autograd import Variable\n", 17 | "from torch.optim import Adam\n", 18 | "from torchvision import datasets, transforms\n", 19 | "\n", 20 | "USE_CUDA = True" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "class Mnist:\n", 32 | " def __init__(self, batch_size):\n", 33 | " dataset_transform = transforms.Compose([\n", 34 | " transforms.ToTensor(),\n", 35 | " transforms.Normalize((0.1307,), (0.3081,))\n", 36 | " ])\n", 37 | "\n", 38 | " train_dataset = datasets.MNIST('../data', train=True, download=True, transform=dataset_transform)\n", 39 | " test_dataset = datasets.MNIST('../data', train=False, download=True, transform=dataset_transform)\n", 40 | " \n", 41 | " self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", 42 | " self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "class ConvLayer(nn.Module):\n", 54 | " def __init__(self, in_channels=1, out_channels=256, kernel_size=9):\n", 55 | " super(ConvLayer, self).__init__()\n", 56 | "\n", 57 | " self.conv = nn.Conv2d(in_channels=in_channels,\n", 58 | " out_channels=out_channels,\n", 59 | " 
kernel_size=kernel_size,\n", 60 | " stride=1\n", 61 | " )\n", 62 | "\n", 63 | " def forward(self, x):\n", 64 | " return F.relu(self.conv(x))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "class PrimaryCaps(nn.Module):\n", 76 | " def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):\n", 77 | " super(PrimaryCaps, self).__init__()\n", 78 | "\n", 79 | " self.capsules = nn.ModuleList([\n", 80 | " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) \n", 81 | " for _ in range(num_capsules)])\n", 82 | " \n", 83 | " def forward(self, x):\n", 84 | " u = [capsule(x) for capsule in self.capsules]\n", 85 | " u = torch.stack(u, dim=1)\n", 86 | " u = u.view(x.size(0), 32 * 6 * 6, -1)\n", 87 | " return self.squash(u)\n", 88 | " \n", 89 | " def squash(self, input_tensor):\n", 90 | " squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\n", 91 | " output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\n", 92 | " return output_tensor" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 29, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "class DigitCaps(nn.Module):\n", 104 | " def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):\n", 105 | " super(DigitCaps, self).__init__()\n", 106 | "\n", 107 | " self.in_channels = in_channels\n", 108 | " self.num_routes = num_routes\n", 109 | " self.num_capsules = num_capsules\n", 110 | "\n", 111 | " self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))\n", 112 | "\n", 113 | " def forward(self, x):\n", 114 | " batch_size = x.size(0)\n", 115 | " x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)\n", 116 | "\n", 117 | " W = torch.cat([self.W] * batch_size, dim=0)\n", 118 | " u_hat = torch.matmul(W, x)\n", 119 | "\n", 120 | " b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))\n", 121 | " if USE_CUDA:\n", 122 | " b_ij = b_ij.cuda()\n", 123 | "\n", 124 | " num_iterations = 3\n", 125 | " for iteration in range(num_iterations):\n", 126 | " c_ij = F.softmax(b_ij)\n", 127 | " c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)\n", 128 | "\n", 129 | " s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)\n", 130 | " v_j = self.squash(s_j)\n", 131 | " \n", 132 | " if iteration < num_iterations - 1:\n", 133 | " a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))\n", 134 | " b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)\n", 135 | "\n", 136 | " return v_j.squeeze(1)\n", 137 | " \n", 138 | " def squash(self, input_tensor):\n", 139 | " squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\n", 140 | " output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\n", 141 | " return output_tensor" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 130, 147 | "metadata": { 148 | "collapsed": true 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "class Decoder(nn.Module):\n", 153 | " def __init__(self):\n", 154 | " super(Decoder, self).__init__()\n", 155 | " \n", 156 | " self.reconstraction_layers = nn.Sequential(\n", 157 | " nn.Linear(16 * 10, 512),\n", 158 | " nn.ReLU(inplace=True),\n", 159 | " nn.Linear(512, 1024),\n", 160 | " nn.ReLU(inplace=True),\n", 161 | " nn.Linear(1024, 784),\n", 162 | " nn.Sigmoid()\n", 163 | " )\n", 164 | " \n", 165 | " def forward(self, x, data):\n", 166 | " classes = torch.sqrt((x ** 2).sum(2))\n", 167 | " classes = F.softmax(classes)\n", 168 | " \n", 169 | " _, max_length_indices = classes.max(dim=1)\n", 170 | " masked = Variable(torch.sparse.torch.eye(10))\n", 171 | " if USE_CUDA:\n", 172 | " masked = masked.cuda()\n", 173 | " masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)\n", 174 | " \n", 175 | " reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))\n", 176 | " reconstructions = reconstructions.view(-1, 1, 28, 28)\n", 177 | " \n", 178 | " return reconstructions, masked" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 131, 184 | "metadata": { 185 | "collapsed": true 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "class CapsNet(nn.Module):\n", 190 | " def __init__(self):\n", 191 | " super(CapsNet, self).__init__()\n", 192 | " self.conv_layer = ConvLayer()\n", 193 | " self.primary_capsules = PrimaryCaps()\n", 194 | " self.digit_capsules = DigitCaps()\n", 195 | " self.decoder = Decoder()\n", 196 | " \n", 197 | " self.mse_loss = nn.MSELoss()\n", 198 | " \n", 199 | " def forward(self, data):\n", 200 | " output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))\n", 201 | " reconstructions, masked = 
self.decoder(output, data)\n", 202 | " return output, reconstructions, masked\n", 203 | " \n", 204 | " def loss(self, data, x, target, reconstructions):\n", 205 | " return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)\n", 206 | " \n", 207 | " def margin_loss(self, x, labels, size_average=True):\n", 208 | " batch_size = x.size(0)\n", 209 | "\n", 210 | " v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))\n", 211 | "\n", 212 | " left = F.relu(0.9 - v_c).view(batch_size, -1)\n", 213 | " right = F.relu(v_c - 0.1).view(batch_size, -1)\n", 214 | "\n", 215 | " loss = labels * left + 0.5 * (1.0 - labels) * right\n", 216 | " loss = loss.sum(dim=1).mean()\n", 217 | "\n", 218 | " return loss\n", 219 | " \n", 220 | " def reconstruction_loss(self, data, reconstructions):\n", 221 | " loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))\n", 222 | " return loss * 0.0005" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 132, 228 | "metadata": { 229 | "collapsed": true 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "capsule_net = CapsNet()\n", 234 | "if USE_CUDA:\n", 235 | " capsule_net = capsule_net.cuda()\n", 236 | "optimizer = Adam(capsule_net.parameters())" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 133, 242 | "metadata": { 243 | "collapsed": false 244 | }, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "train accuracy: 0.12\n", 251 | "train accuracy: 0.9\n", 252 | "train accuracy: 0.94\n", 253 | "train accuracy: 0.96\n", 254 | "train accuracy: 0.99\n", 255 | "train accuracy: 0.96\n", 256 | "0.229411779922\n", 257 | "test accuracy: 0.96\n", 258 | "0.0547490972094\n", 259 | "train accuracy: 0.98\n", 260 | "train accuracy: 0.98\n", 261 | "train accuracy: 0.99\n", 262 | "train accuracy: 0.99\n", 263 | "train accuracy: 1.0\n", 264 | "train accuracy: 0.99\n", 265 | 
"0.0456192491871\n", 266 | "test accuracy: 0.98\n", 267 | "0.0390225026663\n", 268 | "train accuracy: 0.99\n", 269 | "train accuracy: 0.99\n", 270 | "train accuracy: 1.0\n" 271 | ] 272 | }, 273 | { 274 | "ename": "KeyboardInterrupt", 275 | "evalue": "", 276 | "output_type": "error", 277 | "traceback": [ 278 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 279 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 280 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_id\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m100\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 281 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "batch_size = 100\n", 287 | "mnist = Mnist(batch_size)\n", 288 | "\n", 289 | "n_epochs = 30\n", 290 | "\n", 291 | "\n", 292 | "for epoch in range(n_epochs):\n", 293 | " capsule_net.train()\n", 294 | " train_loss = 0\n", 295 | " for batch_id, (data, target) in enumerate(mnist.train_loader):\n", 296 | "\n", 297 | " target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n", 298 | " data, target = Variable(data), Variable(target)\n", 299 | "\n", 300 | " if USE_CUDA:\n", 301 | " data, target = data.cuda(), target.cuda()\n", 302 | "\n", 303 | " 
optimizer.zero_grad()\n", 304 | " output, reconstructions, masked = capsule_net(data)\n", 305 | " loss = capsule_net.loss(data, output, target, reconstructions)\n", 306 | " loss.backward()\n", 307 | " optimizer.step()\n", 308 | "\n", 309 | " train_loss += loss.data[0]\n", 310 | " \n", 311 | " if batch_id % 100 == 0:\n", 312 | " print \"train accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n", 313 | " np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)\n", 314 | " \n", 315 | " print train_loss / len(mnist.train_loader)\n", 316 | " \n", 317 | " capsule_net.eval()\n", 318 | " test_loss = 0\n", 319 | " for batch_id, (data, target) in enumerate(mnist.test_loader):\n", 320 | "\n", 321 | " target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n", 322 | " data, target = Variable(data), Variable(target)\n", 323 | "\n", 324 | " if USE_CUDA:\n", 325 | " data, target = data.cuda(), target.cuda()\n", 326 | "\n", 327 | " output, reconstructions, masked = capsule_net(data)\n", 328 | " loss = capsule_net.loss(data, output, target, reconstructions)\n", 329 | "\n", 330 | " test_loss += loss.data[0]\n", 331 | " \n", 332 | " if batch_id % 100 == 0:\n", 333 | " print \"test accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n", 334 | " np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)\n", 335 | " \n", 336 | " print test_loss / len(mnist.test_loader)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 134, 342 | "metadata": { 343 | "collapsed": false 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "import matplotlib\n", 348 | "import matplotlib.pyplot as plt\n", 349 | "\n", 350 | "def plot_images_separately(images):\n", 351 | " \"Plot the six MNIST images separately.\"\n", 352 | " fig = plt.figure()\n", 353 | " for j in xrange(1, 7):\n", 354 | " ax = fig.add_subplot(1, 6, j)\n", 355 | " ax.matshow(images[j-1], cmap = matplotlib.cm.binary)\n", 356 | " plt.xticks(np.array([]))\n", 357 | " 
plt.yticks(np.array([]))\n", 358 | " plt.show()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 135, 364 | "metadata": { 365 | "collapsed": false 366 | }, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABFCAYAAAB9nJwHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADTFJREFUeJzt3XuQ1WMcx/H3hkRKUiFd3RlEUtJFIVkzFV1YqmkIIdGW\nS0m5pCmkaHMJbSGXEmI2ZDCTu9g1IwyjVOs2ad1SZNu1/jjzfX7n7J7ddju35+z5vP7Z/Pbsen57\nznnO93me7/N9sioqKhARkdRrkOoGiIhIiDpkERFPqEMWEfGEOmQREU+oQxYR8YQ6ZBERT6hDFhHx\nhDpkERFPqEMWEfHEnnV5cIsWLSo6dOiQoKYkXmFhYUlFRUXLmh6je/Rfbe4RMuM+M+EeIXPus04d\ncocOHfj00093v1UplpWVtWlXj9E9+q829wiZcZ+ZcI+QOfepKQsREU+oQxYR8YQ6ZBERT9RpDjkV\ntm/fDsC///7rrjVt2hSAPff0vvkiIrWmCFlExBNehph//PEHkydPBuDdd98F4KuvvnLfv/nmmwHI\nzc0FoFWrVkluYXzZIQGPP/44ALNmzeK7774D4OqrrwYgLy+PPfbYIzUNlIz1zz//ADB79mzuu+8+\nAN5++20AOnfunLJ21VeKkEVEPOFVhFxSUgLAwIED+eijj6p93N133w3AN998A8CyZcvScj75k08+\nAWDq1KkArFq1CoCGDRuy9957A/Dwww8DMHToUM4888wUtHL3lZaWMnfuXCC4t/HjxwOh51j89/nn\nnwMwbdo0d+3bb78FMi9CXr9+Pa+99lqtH3/ttdfW+f/hVS+2cuVKgBo743ArVqwA4JZbbuGee+5J\nWLsS4fXXX2fo0KFAsHA5YMAAIDQ9se+++wLQvn17AH766acUtDI248ePdx8o5r333gPgueeeY/Dg\nwaloVtz89ddfQGg6LTs7G4DDDz8cIOLD057Xbt26AbDXXnsls5lSg59//hmAU089FYBJkyZV6Ujt\ng2fTpk38/vvvVX5H48aNAdz7ORaashAR8YRXEXLHjh0BaNSoEUcddRQQDHFPPvlkAIqLi7nssssA\n+PXXX4HQgkO6RMhvvfUWAMOGDaO0tBQI7tEWKdu1a+cef+SRRya5hfFjw91wBx10EBCKKtMxQt6y\nZQvr168H4KKLLgJCr0lj23vDt/naa/Oqq64CYO7cuTRq1Cgp7Y2Vvcfqq7Vr1wKh5xWgdevWbirU\nnl8bwZ5xxhmMHTsWgDZt2rjfYdOlNjqKhSJkERFPeBUh9+7dG4B169ax3377AcEmENOpUycOO+ww\nIL0+vcvKyoAgtW3btm3ceuutAEyfPr3K4y16tnnKdPLDDz8A8OOPP1b5Xn5+PgD9+vVLaptitXPn\nTgB69OjhFrWiOfjgg4HIUc66desAeOSRR4DQYvQzzzwT8XhfPfvss+7fzZo1A6Bv376pak7cvfDC\nCwBu8Xnw4MHMnDkTgObNmwNwySWXAHDTTTclvD1edcimdevW1X5v9erVFBUVRVzr0aNHopsUk7Ky\nMiZOnAiEFrMAxowZw5QpU6r9mVdffRWADRs2AHDEEUckuJXxY8M/+xru448/BtKnQ16+fDkADz30\nEEC1nbHli0+YMAGIfL6++OILAAoLCwG4//77ueCCCwBYsGABACeeeGK8mx4Ty423QAKge/fuQPrn\n/Zvi4mLeeOMNAM477zx3/eKLLwbgxhtvBJK7
I1hTFiIinvAyQo7GhowrV66kvLw84nvnnHNOKppU\na/Pnz2fevHkR14499tgaF3YsMrNIsmvXrolrYJzde++9QLAYEs4iSd/ZrrTZs2cDsHnz5iqPOemk\nkwBYvHgxRx99NEDU5/T444+P+Dpo0CAXEVsO+ssvvxzP5sfMRgI2ooMgQq4vDjjgAJdI8P3337vr\nqSyErwhZRMQTaRMhv/POO0AQsQD06dMHSM5k++6wRZypU6e6pHGbl6ppznDJkiUsW7YMgKeffhqA\nBg38/+y0OePKc/wA48aNA4KFEh/ZLqwZM2a4XZS2uBouJycHgAceeACo+5xqs2bN3Fz6nDlzdru9\niRQeGRtLRa0vmjRp4kYtW7duBUKL6E2aNElZm/x/l4uIZIikR8hWxay2W4HffPNNABYuXOiutWjR\nAoDbb78dwNV98IVFxhYNjx492qXS7LPPPtX+3N9//w3AY489xogRIwDSavPExo0bgaDGSDib58/K\nykpmk+rEUhLLy8ujRsYA2dnZrtpgeGRs6xoffvghEGwRnzBhAg0bNqzyew455BAgmG/3zciRIwG4\n7bbb3LX+/funqjkJM2nSJACXSrtz586Ie062hHbIlo+6cOFC17FaGlcstRkeffRRIMhb9oVNq1hH\nfO655wIwc+bMGjtiYyVHi4qKXAGldCm5WVZWxowZM6pct9xVnxeE/vzzTyA4BCG8lorVKbC0KJtK\nCrdjxw63ezQ8bxdCz/1LL70EBOmZvgUQ0Tz11FOpbkJSWIrtK6+8AkSfbksmTVmIiHgi7hFySUkJ\n8+fPB4Jphmg7tnZXt27dvBw6rVmzxlWJOvDAA4HQBgCoeZoCQot4EJTazMvL47TTTktUUxOipKQk\naupW5b+Jb5YuXepSz6Jt+rjhhhuAYHoMgkjaalrMmjWrSmRstm7dyllnnQWEKvxBegz9n3/++Yj/\n7tixY9Spl/rCRtsffPCBSxawTSPJvG9FyCIinohbhGxzw126dIlaM7Qyi5hGjhzp5uVqM6/82Wef\nuUUjS8z3wZNPPulSZywSOvTQQ2v8GauGlpeXBwQLKZdffnmimhl3//33HwB33nlnle/17NnTzYv7\nqqioKGpkPHz4cCA4Lsw2Ji1YsIDffvsNIKWLP4m0du1a9342/fr1czW66yPb0FNaWsrq1asB3Ndk\nbvOPuUO2BZELL7wQIKIztkLc3bt35+yzzwZwwzebTM/OznZFosPZgti2bduAYNW6tLTUvTl8YO3K\nz893+dDHHHPMLn9u+/btruym1QuwTIx0WciDoMZD5UL0EMo5Ttc3seWi2pvR8qgnTpzoFipry0qo\npstuyy+//LLKLsvRo0enqDXJZVkXECQlJJOmLEREPBFzhGxRnVWygqAIueVYjhgxwkXSL774IoBL\nEwrPWbWSm/3792fp0qVA8CmVyv3l0VjeqZ011q5du1oNzy3yyMnJcX8zO7oqHato2fA9mt05UyzZ\nbKdWZZZLbt5//333719++SXie9nZ2a5Mqo2Ywl133XVAqHZCOrD3Xjo75ZRTgNDuy7q8r5o2bepG\nR/E4kqmuFCGLiHgi5gi5crQAQY1XK9Kdn5/v9v3bkSnhbI7NikSH1yaN9vsLCgqA4GDCVLCFSIuI\nlixZUmPCf3hkDKF7sHnXnj17JrKpCRVeL9fYPJytF/hs+PDhLsq3Of2aNG7cmCuvvBIIdlF27drV\n1UGOFiH7mvJXnfA61vYes4jTd5bCaAvmkydPdpUWbZNPbVVe1EvGhh5FyCIinog5Ql60aBEQWaPA\nIg2rbRDNcccdB8A111zjKvRHm2NbvHhxlWvVzfslS3l5uRsF2PEulmUSzcaNG918qs0Xz5s3jyuu\nuCLBLU28aNXKTj/9dCA9KtQ1aNDAzfGagoIC9/ps2bIlENxTr169XPZQuOpek1lZWUk9cSLe7DlM\nl8yfu+66
Cwj6o0WLFrmMrRNOOAGIngVla0KFhYWuponVlrFo214LiRTzK2XUqFFAKA/XROuI7Y9g\n6UPDhg0DgkJBtdWqVSt3AnWqLF++3A1n7MMkGstHnjZtGl9//TUQ7F4cNWpU2rzIo7FhbfhhAXYS\nb+fOnVPSpt1lb97rr78+4mtdrFmzJur1tm3bute6JJ4Fg1Y+dPPmze70aFvcsw/XcDb1VlBQ4I6v\nstz6ZHTExv8QRkQkQ8QcIQ8YMAAI9veHp8wMGTIEgClTprgFvroWKLfdUPap1qdPnxoPQU2G4uJi\nlxoTnjBvu5usroGdJt22bVtX8SsdFrpqww4K2LFjh7tmOwx3tUOxPrKym5UNHDgwyS3JbJYYMGbM\nGCBUGfLBBx8EggSBFStWuMfvv//+ET/fpk0bN+1R0+g3URQhi4h4IuYI2aLg888/Hwi20kKw/TSW\nRQ2bv7HFMx9s2LDBjQguvfRSIFT5yzZ6WHFz+5SePn16UuehkqFydbPmzZszduzYFLUm9Wye8o47\n7gCC17yVDJDksjWrOXPm0KVLFyD6JiYbedtCrS38pUrcln9tgSpddiPFIicnxw17rHQmhHZsQbDS\na4uPPp+SES/jxo2r8wJtfVJ5J6kVqxk0aFAKWhObIUOGuOCiU6dOKW5N7HwK5nZFUxYiIp5I3wTJ\nFOrdu3dMR1DVB0888QQQjAbCq2RlIpu6s6OPqkuDSwe5ubnk5uamuhkZSRGyiIgnFCHLbunbt2/E\n10xnC9h21JONHETqQhGyiIgnFCGLxFGvXr0AWLVqVYpbIulIEbKIiCfUIYuIeCLLKhvV6sFZWVuA\nTYlrTsK1r6ioqHHLnO4xLezyHiEz7jMT7hEy6D7r0iGLiEjiaMpCRMQT6pBFRDyhDllExBPqkEVE\nPKEOWUTEE+qQRUQ8oQ5ZRMQT6pBFRDyhDllExBP/AwL3rM94Bza/AAAAAElFTkSuQmCC\n", 371 | "text/plain": [ 372 | "" 373 | ] 374 | }, 375 | "metadata": {}, 376 | "output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "plot_images_separately(data[:6,0].data.cpu().numpy())" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 136, 386 | "metadata": { 387 | "collapsed": false 388 | }, 389 | "outputs": [ 390 | { 391 | "data": { 392 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWQAAABFCAYAAAB9nJwHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD6RJREFUeJztnVtsVFUXgL/TltIC/UuhlGstiqJiRQSKKNRLtEYUFRRp\nURPUoJCIF9D4YBQxGKUxPggxSoJIMCGiVSKoCIqgooACCiq0CCIXqxYFBIqCc/kfJmufoZ2ZzrRz\n2bXrexktu6d7zz5nnXXfjt/vR1EURUk9aamegKIoihJABbKiKIolqEBWFEWxBBXIiqIolqACWVEU\nxRJUICuKoliCCmRFURRLUIGsKIpiCSqQFUVRLCEjlsH5+fn+vn37JmgqiWfz5s1/+P3+bpHG6Brt\nJ5o1QttYZ1tYI7SddcYkkPv27cumTZuaP6sU4zjO3qbG6BrtJ5o1QttYZ1tYI7SddarLQlEUxRJU\nICuKoliCCmRFURRLUIGsKIpiCSqQFUVRLCGmLItEMH36dH777TcADh48CMCxY8c4deoUABdffDEA\n999/PwCDBw9OwSxbht/vp76+HoAdO3YAsGfPHrPu4uJiAM4991wAevXqheM4KZiporQtGh7Q4fP5\nzM8yMpIvHlVDVhRFsYSkvwImTpwIwGeffQbAL7/8wr///ht2/DfffAPABx98AMDLL7/M1VdfDUBO\nTk4ip9pi9u4NpB5WVlby4YcfArB//34APB5Po/EDBw4EYObMmYwYMQKAgoKCZEw1LmzduhWAw4cP\nA/DXX38Zi6CmpgaA7OxsAEpKSrj77rtTMEslGr7//nsAjhw5AgSs1u3btwPwww8/AJCVlQW0vr08\nfvw4ABs2bGDLli0AdO3aFQjcsz/++CMQkE0A559/PgA9e/Zk6NChAIwcOTIhc0uaQL799tsBWL16\nNRBYOBBRGAcj5v2cOXOMiW+bQBZT58033wRg4cKFAKxcubKRaRSKbdu2AVBeXm6+lz///BOAzp07\nk5Zmr0Hj8/k4evQoAFdddVXYcYWFhQB07949KfOKN+vWreP5558HYNmyZWHHvfLKKwBUVFTQqVMn\nALN/trujfD6fcbGVlpaGHdda9tLr9QKwZMkSAD7++GMgIJBFYYjE8uXLAcjNzTWyZ/z48QBMnTo1\nrs+lvU+4oihKGyMpGvLatWvx+XwAtG/fHmi+lrB582befvttAB5++OHTrplK/H4/b7zxBgBPP/00\n4JrpsRJsNYgptX//fvr06dPCWSYOj8fD5Zdf3uQ4cdmcccYZiZ5Si6itrQXcPXzqqacA+Pzzz6P6\n/SlTpgBQVVXF3LlzATjnnHMASE9Pj+tc443X62X48OFNjpO9LCoqSvSUmo3X6zVa8IYNGwBYsWIF\n4Frd0XLixAl27doFYFyQq1at4r333ovXdFVDVhRFsYWkaMjdu3fn2muvBdygzr59+wD45JNPYrqW\n1+tl9+7dAPz8888A9O/fP+V+uV27dvHOO+8AzdeMI7FkyRIeeeSRuF83XsRqpUyYMCFBM2k5R48e\nZf369QCMGzeuRdeqrq6mqqoKgDvuuAOAPn360K5du5ZNMoFkZmbGNL6ioiJBM2k5fr+fb7/9FoCv\nv/4acOMysVJYWEiXLl0A9zs666yz4jBLl6QI5Ly8PGN6S5DgwIEDAJSVlZlc4xMnTgCwceNG3n33\nXSBwQwfj8/lMFF+EcCqFsQTrVq9ebR68SEgAoF27diYoIqaf5F6HCgDu3r2b33//HbAziFJfX8/s\n2bMBmDVrVthxr7/+OuBG6G1CvvdNmzaZoE0kxFTPzMw0e3jy5MnTrvXHH38YM1f+zWZhDIHn8Lnn\nngMi7+XixYsBO/dSOHDggMnQElkS7BIU2RH8zN14440ADBkyBIB+/foBgX2T3OTevXsDcMEFF5jf\njYccUpeFoiiKJSRFQ+7WrRujRo0C3BQU0RZOnTpltMZff/0VC
JgUeXl5Ya8nQREb0t6+++47AObN\nmxfVeNE4Hn/8cfOzOXPmAPDss88CGE0Y3Lfu0qVLTa6njRpyeno6X331VZPjrr/++iTMpnnI915Z\nWWmC0KGQoK3c04cPH2b69OmAm6MreDweY/n17Nkz7nNOBOnp6SYAFglZv43I/lVXVxsXp6RlBtNQ\nu509ezbnnXceAIMGDQJcN2tubq7RrsVlEW9rRzVkRVEUS0haYYi8SUS7lU+/32/+rUOHDkDgTSRB\nlYacPHnSVLSlsopNqn0kkCeBg3BIL45gzVh44IEHTrvmrFmzGvkijxw5YqoWS0pKWjr9uLNjxw5W\nrlzZ5DgJitiI+ETXrFkTdsyECRMoLy8H3LV4vV5++umnkOM9Ho+JFeTm5sZzugmjpqaGjz76qMlx\nnTt3TsJsmocUtqxfv95YsZGsnrKyMiAQF5CUTLHA5dNxHKNJJyoOoBqyoiiKJSRFQ05LS2uUOSC+\nmOzsbFNGLZHqUP4r8TP369fPdEVLZYK91LsvWLCgybHTpk3jnnvuAdy3dKhyy44dOwLQo0cPkxYo\neDwe6urqWjTnRCJFD60ZyYYIVc4v/QwmT57MsWPHALfPyqJFi/j7779DXtNxHJNF1FqQrILWiMgZ\nycSqrq4OuzfBrFq1Cgjss1iqgsibZHR/S4pAdhynkUA+dOgQEPjipJHJxo0bATfIFYri4mKGDRuW\noJlGjwhkeYlEoqysjAEDBgCuIPb7/eY7kUZD8mJq3769GSdBUMdxou77kQqam9tpAyJgpTovLS2t\nkXkrJvDixYvNf3/xxRcAjV6ewTiOE1XVm01IQ6HWjDxLsVa3vvjiiyZoPnbsWCDwEoZAden//ve/\nOM6yMeqyUBRFsYSkBfWkqk60EDnSu6amxmga4nwPhWjFF110UcrT3Y4cOcJLL70U9fgrrrgiZAK6\nIBqa9EsIRn4vLS0tJQ2zoyWSlmg7YqJKMCczM5N//vnntDFiCS1YsMDsSTQWy8SJE40rqrUgbSdb\nI7I3UrQzaNAgY21GCuoFI9ae9Ka59NJLAairqzNyKFEBTdWQFUVRLCEpKteaNWtMWtiiRYsAtwF2\nqEbtoRDNul+/flH1Fk4k1dXVUSXOC+3atTNv5+Byb1mHNHYPhYzp2LEjF154YXOnnHBuvfXWVE+h\n2Ui65WWXXQYEeh7L/Srff0N/f7SMHDky5t4Qqeamm25K9RRajFjRw4YN46GHHgLcQy6kQMTn8xnr\nKLgbpVhHMm7SpEkAPPPMM3Tr1g1wi0bi3bYhoQJ53bp1QKDB89KlSwG3Gi/WG1u+uPT0dBNUEbMh\n2b0s9u/fb/pOhEI2UNouZmRkmDnKxnu9XhMJFpdNKMTcKisrs/I8QTFvbc4AaQpxKUjw7dFHH2X+\n/PkAfPnllwCn7Xc0CoHs2+DBg1uNy0IqFWNtS2kj8rz179/fuAKnTZsGuHt54sQJdu7cCbjP4Nat\nW03AXmSV1DvU1taaYGE8+1cEoy4LRVEUS0iIhizuCDn6pKqqKqr0sEiIVrxw4ULTE0C0kJycnKQG\n+g4dOhQySCfzEU0reExw+hoEAkJynNWMGTMa/Q25lqy1pKQk4Sk3zSHa5typdjNFQoKlUlE3atQo\nc+zSfffdB7idCDt06GDcbtLoPBTBzfptOEAhGqTpelPYvJeCzNHn85m9lE/BcRzTt6JXr15A4B54\n6623ADeQK8/izp07zb0i8qhTp05x1ZJVQ1YURbGEuGvIu3btMn1H5WRp6X3cEkTD3L59u3mTS5pc\nSUlJUpPvwzWgl7ey+OIknS0nJ8f0ppCqoeXLl3PvvfeG/RtShTh69GgALrnkEqt6B0iam3SoC0es\nsYJUIt95Xl6e6Ykr85f9q6+vN6cvR9KQJQBbVFSU8sMTmkJ8pTNnzow4Tp5Bm5FnUJ69tLS0RkHV\n4FRSWZP4+evq6kxQr2F64
9GjRxulzqkPWVEU5T9K3DXkffv28emnnwKuZpyRkdHsst/gtxkENBbp\nrCVvq2T3UcjMzAzpR5OfiQ9KtOEBAwaYnqwy99deey3i35C0GulCZUuHN1mjnOgSqSBk0qRJRrOU\n1LLWgtx34jMU7fngwYNs2bIl7O9JJoxoyDZ3eJO9lDiAFG+FYvLkySY7wdYTQvx+v8maECu6ffv2\nJpvp5ptvBtwufcePHzfHOomsWrNmjUlDlXtWrKQuXbqY+yJRcYG4C+SCggKTvycLaYlaLwuXByI/\nP984588880wArrvuumZfvzmMHj2aysrKsP8uOaxNteQMR1ZWlnFnSHDIlsCQNKGXwFYk5s+fb9LH\nWkMgKBLy8t+2bRvvv/9+2HHShEjOWgvVRMoWpDnSq6++2uTYefPmmUMYbN5LeeaWLVsGwJ49e4z8\nESVIXH9+v99UDIdCFCppUN+7d28jpBOVW27v3aIoitLGiLuG3KNHD/P2yM/PBwIHPcaKtLw7++yz\nAbe+/MorrzRayJ133tni+TaHIUOGmPr2cI30m4Oktc2dO9ec0i2VQTbg8XiMOypaF1TDxu2JSqhP\nNE01JJf0xHifQpwoPB4Pa9euBdyAZVPs3bv3tP+3cS8bHu9WW1sb9frCId3fhg8fblyJiUI1ZEVR\nFEuIu4acn5/Pgw8+CLhvqaKiIuM8Fwd7JA1r+fLljbRs8cXl5eWZTk6pIjs72/hQ5WimtWvXRiyn\njsQNN9wAwGOPPQbA0KFDrQyCpaenm7LaaLSOJ598stEhAjZpU9HQsFl5KAYOHGiOhe/fvz/gWni2\nkp6ebgJZDTvbhWLGjBmN/OG27aXjOKaHsaTy1dfXmwKsWH3fYgVXVFQAMG7cuIQfipGQSj2JML/w\nwgtAoHZc3BaSHyiCubCwkK5duwKhHeVyE9i2+eJKkeDBihUrqKqqAjBBn1CNvuU07dtuu40xY8YA\nATcMuMEDW3Ecx5xnKAFbCTZ26dLFmHalpaUAjBgxwrS0bK3IQ3zLLbcAAReMuJZkvQUFBfTo0QNw\nq/1ibYyebBzHobi4GHB7O4hbpmvXrsZVJkHl0tJS69cErsti6tSpQGAfZJ+WLFkCRM6Nz8rKYsqU\nKYCrKF1zzTUJm29D1GWhKIpiCQnt9hacsycmXTR4vd6UnpcXC6Ihjhkzxmi8DVs2+v3+Ruaex+NJ\n2Mm1ieSuu+467TMS0TYEtxnZt/HjxwOBvGux8oSsrKyQp4nbjpzzKJ+RaG17KbJn7Nixxo3xxBNP\nAJhKy7q6OvOM9u3bFwho1Klsc6sasqIoiiVYeSZQa9GOwxHchD4crVE7jhWbiyKiRYqQysvLUzyT\n1PJf2Evp7CafNtL6v2VFUZT/CCqQFUVRLEEFsqIoiiWoQFYURbEEJ5bqFcdxDgJ7mxxoL0V+vz9i\ncwhdY6ugyTVC21hnW1gjtKF12txKT1EUpS2hLgtFURRLUIGsKIpiCSqQFUVRLEEFsqIoiiWoQFYU\nRbEEFciKoiiWoAJZURTFElQgK4qiWIIKZEVRFEv4P2LsVpWGtLqLAAAAAElFTkSuQmCC\n", 393 | "text/plain": [ 394 | "" 395 | ] 396 | }, 397 | "metadata": {}, 398 | "output_type": "display_data" 399 | } 400 | ], 401 | "source": [ 402 | "plot_images_separately(reconstructions[:6,0].data.cpu().numpy())" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "collapsed": true 410 | }, 411 | "outputs": [], 412 | "source": 
[] 413 | } 414 | ], 415 | "metadata": { 416 | "kernelspec": { 417 | "display_name": "Python 2", 418 | "language": "python", 419 | "name": "python2" 420 | }, 421 | "language_info": { 422 | "codemirror_mode": { 423 | "name": "ipython", 424 | "version": 2 425 | }, 426 | "file_extension": ".py", 427 | "mimetype": "text/x-python", 428 | "name": "python", 429 | "nbconvert_exporter": "python", 430 | "pygments_lexer": "ipython2", 431 | "version": "2.7.13" 432 | } 433 | }, 434 | "nbformat": 4, 435 | "nbformat_minor": 2 436 | } 437 | -------------------------------------------------------------------------------- /Capsule Network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "from torch.autograd import Variable\n", 17 | "from torch.optim import Adam\n", 18 | "from torchvision import datasets, transforms\n", 19 | "\n", 20 | "USE_CUDA = True" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "class Mnist:\n", 32 | " def __init__(self, batch_size):\n", 33 | " dataset_transform = transforms.Compose([\n", 34 | " transforms.ToTensor(),\n", 35 | " transforms.Normalize((0.1307,), (0.3081,))\n", 36 | " ])\n", 37 | "\n", 38 | " train_dataset = datasets.MNIST('../data', train=True, download=True, transform=dataset_transform)\n", 39 | " test_dataset = datasets.MNIST('../data', train=False, download=True, transform=dataset_transform)\n", 40 | " \n", 41 | " self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", 42 | " self.test_loader = torch.utils.data.DataLoader(test_dataset, 
batch_size=batch_size, shuffle=True) " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "class ConvLayer(nn.Module):\n", 54 | " def __init__(self, in_channels=1, out_channels=256, kernel_size=9):\n", 55 | " super(ConvLayer, self).__init__()\n", 56 | "\n", 57 | " self.conv = nn.Conv2d(in_channels=in_channels,\n", 58 | " out_channels=out_channels,\n", 59 | " kernel_size=kernel_size,\n", 60 | " stride=1\n", 61 | " )\n", 62 | "\n", 63 | " def forward(self, x):\n", 64 | " return F.relu(self.conv(x))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "class PrimaryCaps(nn.Module):\n", 76 | " def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):\n", 77 | " super(PrimaryCaps, self).__init__()\n", 78 | "\n", 79 | " self.capsules = nn.ModuleList([\n", 80 | " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) \n", 81 | " for _ in range(num_capsules)])\n", 82 | " \n", 83 | " def forward(self, x):\n", 84 | " u = [capsule(x) for capsule in self.capsules]\n", 85 | " u = torch.stack(u, dim=1)\n", 86 | " u = u.view(x.size(0), 32 * 6 * 6, -1)\n", 87 | " return self.squash(u)\n", 88 | " \n", 89 | " def squash(self, input_tensor):\n", 90 | " squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\n", 91 | " output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\n", 92 | " return output_tensor" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 29, 98 | "metadata": { 99 | "collapsed": true 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "class DigitCaps(nn.Module):\n", 104 | " def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):\n", 105 | " super(DigitCaps, self).__init__()\n", 106 | "\n", 107 | " self.in_channels = in_channels\n", 108 | " self.num_routes = num_routes\n", 109 | " self.num_capsules = num_capsules\n", 110 | "\n", 111 | " self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))\n", 112 | "\n", 113 | " def forward(self, x):\n", 114 | " batch_size = x.size(0)\n", 115 | " x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)\n", 116 | "\n", 117 | " W = torch.cat([self.W] * batch_size, dim=0)\n", 118 | " u_hat = torch.matmul(W, x)\n", 119 | "\n", 120 | " b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))\n", 121 | " if USE_CUDA:\n", 122 | " b_ij = b_ij.cuda()\n", 123 | "\n", 124 | " num_iterations = 3\n", 125 | " for iteration in range(num_iterations):\n", 126 | " c_ij = F.softmax(b_ij)\n", 127 | " c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)\n", 128 | "\n", 129 | " s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)\n", 130 | " v_j = self.squash(s_j)\n", 131 | " \n", 132 | " if iteration < num_iterations - 1:\n", 133 | " a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))\n", 134 | " b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)\n", 135 | "\n", 136 | " return v_j.squeeze(1)\n", 137 | " \n", 138 | " def squash(self, input_tensor):\n", 139 | " squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)\n", 140 | " output_tensor = squared_norm * input_tensor / ((1. 
+ squared_norm) * torch.sqrt(squared_norm))\n", 141 | " return output_tensor" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 130, 147 | "metadata": { 148 | "collapsed": true 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "class Decoder(nn.Module):\n", 153 | " def __init__(self):\n", 154 | " super(Decoder, self).__init__()\n", 155 | " \n", 156 | " self.reconstraction_layers = nn.Sequential(\n", 157 | " nn.Linear(16 * 10, 512),\n", 158 | " nn.ReLU(inplace=True),\n", 159 | " nn.Linear(512, 1024),\n", 160 | " nn.ReLU(inplace=True),\n", 161 | " nn.Linear(1024, 784),\n", 162 | " nn.Sigmoid()\n", 163 | " )\n", 164 | " \n", 165 | " def forward(self, x, data):\n", 166 | " classes = torch.sqrt((x ** 2).sum(2))\n", 167 | " classes = F.softmax(classes)\n", 168 | " \n", 169 | " _, max_length_indices = classes.max(dim=1)\n", 170 | " masked = Variable(torch.sparse.torch.eye(10))\n", 171 | " if USE_CUDA:\n", 172 | " masked = masked.cuda()\n", 173 | " masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)\n", 174 | " \n", 175 | " reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))\n", 176 | " reconstructions = reconstructions.view(-1, 1, 28, 28)\n", 177 | " \n", 178 | " return reconstructions, masked" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 131, 184 | "metadata": { 185 | "collapsed": true 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "class CapsNet(nn.Module):\n", 190 | " def __init__(self):\n", 191 | " super(CapsNet, self).__init__()\n", 192 | " self.conv_layer = ConvLayer()\n", 193 | " self.primary_capsules = PrimaryCaps()\n", 194 | " self.digit_capsules = DigitCaps()\n", 195 | " self.decoder = Decoder()\n", 196 | " \n", 197 | " self.mse_loss = nn.MSELoss()\n", 198 | " \n", 199 | " def forward(self, data):\n", 200 | " output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))\n", 201 | " reconstructions, masked = 
self.decoder(output, data)\n", 202 | " return output, reconstructions, masked\n", 203 | " \n", 204 | " def loss(self, data, x, target, reconstructions):\n", 205 | " return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)\n", 206 | " \n", 207 | " def margin_loss(self, x, labels, size_average=True):\n", 208 | " batch_size = x.size(0)\n", 209 | "\n", 210 | " v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))\n", 211 | "\n", 212 | " left = F.relu(0.9 - v_c).view(batch_size, -1)\n", 213 | " right = F.relu(v_c - 0.1).view(batch_size, -1)\n", 214 | "\n", 215 | " loss = labels * left + 0.5 * (1.0 - labels) * right\n", 216 | " loss = loss.sum(dim=1).mean()\n", 217 | "\n", 218 | " return loss\n", 219 | " \n", 220 | " def reconstruction_loss(self, data, reconstructions):\n", 221 | " loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))\n", 222 | " return loss * 0.0005" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 132, 228 | "metadata": { 229 | "collapsed": true 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "capsule_net = CapsNet()\n", 234 | "if USE_CUDA:\n", 235 | " capsule_net = capsule_net.cuda()\n", 236 | "optimizer = Adam(capsule_net.parameters())" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 133, 242 | "metadata": { 243 | "collapsed": false 244 | }, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "train accuracy: 0.12\n", 251 | "train accuracy: 0.9\n", 252 | "train accuracy: 0.94\n", 253 | "train accuracy: 0.96\n", 254 | "train accuracy: 0.99\n", 255 | "train accuracy: 0.96\n", 256 | "0.229411779922\n", 257 | "test accuracy: 0.96\n", 258 | "0.0547490972094\n", 259 | "train accuracy: 0.98\n", 260 | "train accuracy: 0.98\n", 261 | "train accuracy: 0.99\n", 262 | "train accuracy: 0.99\n", 263 | "train accuracy: 1.0\n", 264 | "train accuracy: 0.99\n", 265 | 
"0.0456192491871\n", 266 | "test accuracy: 0.98\n", 267 | "0.0390225026663\n", 268 | "train accuracy: 0.99\n", 269 | "train accuracy: 0.99\n", 270 | "train accuracy: 1.0\n" 271 | ] 272 | }, 273 | { 274 | "ename": "KeyboardInterrupt", 275 | "evalue": "", 276 | "output_type": "error", 277 | "traceback": [ 278 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 279 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 280 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_id\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m100\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 281 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "batch_size = 100\n", 287 | "mnist = Mnist(batch_size)\n", 288 | "\n", 289 | "n_epochs = 30\n", 290 | "\n", 291 | "\n", 292 | "for epoch in range(n_epochs):\n", 293 | " capsule_net.train()\n", 294 | " train_loss = 0\n", 295 | " for batch_id, (data, target) in enumerate(mnist.train_loader):\n", 296 | "\n", 297 | " target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n", 298 | " data, target = Variable(data), Variable(target)\n", 299 | "\n", 300 | " if USE_CUDA:\n", 301 | " data, target = data.cuda(), target.cuda()\n", 302 | "\n", 303 | " 
optimizer.zero_grad()\n", 304 | " output, reconstructions, masked = capsule_net(data)\n", 305 | " loss = capsule_net.loss(data, output, target, reconstructions)\n", 306 | " loss.backward()\n", 307 | " optimizer.step()\n", 308 | "\n", 309 | " train_loss += loss.data[0]\n", 310 | " \n", 311 | " if batch_id % 100 == 0:\n", 312 | " print \"train accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n", 313 | " np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)\n", 314 | " \n", 315 | " print train_loss / len(mnist.train_loader)\n", 316 | " \n", 317 | " capsule_net.eval()\n", 318 | " test_loss = 0\n", 319 | " for batch_id, (data, target) in enumerate(mnist.test_loader):\n", 320 | "\n", 321 | " target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)\n", 322 | " data, target = Variable(data), Variable(target)\n", 323 | "\n", 324 | " if USE_CUDA:\n", 325 | " data, target = data.cuda(), target.cuda()\n", 326 | "\n", 327 | " output, reconstructions, masked = capsule_net(data)\n", 328 | " loss = capsule_net.loss(data, output, target, reconstructions)\n", 329 | "\n", 330 | " test_loss += loss.data[0]\n", 331 | " \n", 332 | " if batch_id % 100 == 0:\n", 333 | " print \"test accuracy:\", sum(np.argmax(masked.data.cpu().numpy(), 1) == \n", 334 | " np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)\n", 335 | " \n", 336 | " print test_loss / len(mnist.test_loader)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 134, 342 | "metadata": { 343 | "collapsed": false 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "import matplotlib\n", 348 | "import matplotlib.pyplot as plt\n", 349 | "\n", 350 | "def plot_images_separately(images):\n", 351 | " \"Plot the six MNIST images separately.\"\n", 352 | " fig = plt.figure()\n", 353 | " for j in xrange(1, 7):\n", 354 | " ax = fig.add_subplot(1, 6, j)\n", 355 | " ax.matshow(images[j-1], cmap = matplotlib.cm.binary)\n", 356 | " plt.xticks(np.array([]))\n", 357 | " 
plt.yticks(np.array([]))\n", 358 | " plt.show()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 135, 364 | "metadata": { 365 | "collapsed": false 366 | }, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABFCAYAAAB9nJwHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADTFJREFUeJzt3XuQ1WMcx/H3hkRKUiFd3RlEUtJFIVkzFV1YqmkIIdGW\nS0m5pCmkaHMJbSGXEmI2ZDCTu9g1IwyjVOs2ad1SZNu1/jjzfX7n7J7ddju35+z5vP7Z/Pbsen57\nznnO93me7/N9sioqKhARkdRrkOoGiIhIiDpkERFPqEMWEfGEOmQREU+oQxYR8YQ6ZBERT6hDFhHx\nhDpkERFPqEMWEfHEnnV5cIsWLSo6dOiQoKYkXmFhYUlFRUXLmh6je/Rfbe4RMuM+M+EeIXPus04d\ncocOHfj00093v1UplpWVtWlXj9E9+q829wiZcZ+ZcI+QOfepKQsREU+oQxYR8YQ6ZBERT9RpDjkV\ntm/fDsC///7rrjVt2hSAPff0vvkiIrWmCFlExBNehph//PEHkydPBuDdd98F4KuvvnLfv/nmmwHI\nzc0FoFWrVkluYXzZIQGPP/44ALNmzeK7774D4OqrrwYgLy+PPfbYIzUNlIz1zz//ADB79mzuu+8+\nAN5++20AOnfunLJ21VeKkEVEPOFVhFxSUgLAwIED+eijj6p93N133w3AN998A8CyZcvScj75k08+\nAWDq1KkArFq1CoCGDRuy9957A/Dwww8DMHToUM4888wUtHL3lZaWMnfuXCC4t/HjxwOh51j89/nn\nnwMwbdo0d+3bb78FMi9CXr9+Pa+99lqtH3/ttdfW+f/hVS+2cuVKgBo743ArVqwA4JZbbuGee+5J\nWLsS4fXXX2fo0KFAsHA5YMAAIDQ9se+++wLQvn17AH766acUtDI248ePdx8o5r333gPgueeeY/Dg\nwaloVtz89ddfQGg6LTs7G4DDDz8cIOLD057Xbt26AbDXXnsls5lSg59//hmAU089FYBJkyZV6Ujt\ng2fTpk38/vvvVX5H48aNAdz7ORaashAR8YRXEXLHjh0BaNSoEUcddRQQDHFPPvlkAIqLi7nssssA\n+PXXX4HQgkO6RMhvvfUWAMOGDaO0tBQI7tEWKdu1a+cef+SRRya5hfFjw91wBx10EBCKKtMxQt6y\nZQvr168H4KKLLgJCr0lj23vDt/naa/Oqq64CYO7cuTRq1Cgp7Y2Vvcfqq7Vr1wKh5xWgdevWbirU\nnl8bwZ5xxhmMHTsWgDZt2rjfYdOlNjqKhSJkERFPeBUh9+7dG4B169ax3377AcEmENOpUycOO+ww\nIL0+vcvKyoAgtW3btm3ceuutAEyfPr3K4y16tnnKdPLDDz8A8OOPP1b5Xn5+PgD9+vVLaptitXPn\nTgB69OjhFrWiOfjgg4HIUc66desAeOSRR4DQYvQzzzwT8XhfPfvss+7fzZo1A6Bv376pak7cvfDC\nCwBu8Xnw4MHMnDkTgObNmwNwySWXAHDTTTclvD1edcimdevW1X5v9erVFBUVRVzr0aNHopsUk7Ky\nMiZOnAiEFrMAxowZw5QpU6r9mVdffRWADRs2AHDEEUckuJXxY8M/+xru448/BtKnQ16+fDkADz30\nEEC1nbHli0+YMAGIfL6++OILAAoLCwG4//77ueCCCwBYsGABACeeeGK8mx4Ty423QAKge/fuQPrn\n/Zvi4mLeeOMNAM477zx3/eKLLwbgxhtvBJK7
I1hTFiIinvAyQo7GhowrV66kvLw84nvnnHNOKppU\na/Pnz2fevHkR14499tgaF3YsMrNIsmvXrolrYJzde++9QLAYEs4iSd/ZrrTZs2cDsHnz5iqPOemk\nkwBYvHgxRx99NEDU5/T444+P+Dpo0CAXEVsO+ssvvxzP5sfMRgI2ooMgQq4vDjjgAJdI8P3337vr\nqSyErwhZRMQTaRMhv/POO0AQsQD06dMHSM5k++6wRZypU6e6pHGbl6ppznDJkiUsW7YMgKeffhqA\nBg38/+y0OePKc/wA48aNA4KFEh/ZLqwZM2a4XZS2uBouJycHgAceeACo+5xqs2bN3Fz6nDlzdru9\niRQeGRtLRa0vmjRp4kYtW7duBUKL6E2aNElZm/x/l4uIZIikR8hWxay2W4HffPNNABYuXOiutWjR\nAoDbb78dwNV98IVFxhYNjx492qXS7LPPPtX+3N9//w3AY489xogRIwDSavPExo0bgaDGSDib58/K\nykpmk+rEUhLLy8ujRsYA2dnZrtpgeGRs6xoffvghEGwRnzBhAg0bNqzyew455BAgmG/3zciRIwG4\n7bbb3LX+/funqjkJM2nSJACXSrtz586Ie062hHbIlo+6cOFC17FaGlcstRkeffRRIMhb9oVNq1hH\nfO655wIwc+bMGjtiYyVHi4qKXAGldCm5WVZWxowZM6pct9xVnxeE/vzzTyA4BCG8lorVKbC0KJtK\nCrdjxw63ezQ8bxdCz/1LL70EBOmZvgUQ0Tz11FOpbkJSWIrtK6+8AkSfbksmTVmIiHgi7hFySUkJ\n8+fPB4Jphmg7tnZXt27dvBw6rVmzxlWJOvDAA4HQBgCoeZoCQot4EJTazMvL47TTTktUUxOipKQk\naupW5b+Jb5YuXepSz6Jt+rjhhhuAYHoMgkjaalrMmjWrSmRstm7dyllnnQWEKvxBegz9n3/++Yj/\n7tixY9Spl/rCRtsffPCBSxawTSPJvG9FyCIinohbhGxzw126dIlaM7Qyi5hGjhzp5uVqM6/82Wef\nuUUjS8z3wZNPPulSZywSOvTQQ2v8GauGlpeXBwQLKZdffnmimhl3//33HwB33nlnle/17NnTzYv7\nqqioKGpkPHz4cCA4Lsw2Ji1YsIDffvsNIKWLP4m0du1a9342/fr1czW66yPb0FNaWsrq1asB3Ndk\nbvOPuUO2BZELL7wQIKIztkLc3bt35+yzzwZwwzebTM/OznZFosPZgti2bduAYNW6tLTUvTl8YO3K\nz893+dDHHHPMLn9u+/btruym1QuwTIx0WciDoMZD5UL0EMo5Ttc3seWi2pvR8qgnTpzoFipry0qo\npstuyy+//LLKLsvRo0enqDXJZVkXECQlJJOmLEREPBFzhGxRnVWygqAIueVYjhgxwkXSL774IoBL\nEwrPWbWSm/3792fp0qVA8CmVyv3l0VjeqZ011q5du1oNzy3yyMnJcX8zO7oqHato2fA9mt05UyzZ\nbKdWZZZLbt5//333719++SXie9nZ2a5Mqo2Ywl133XVAqHZCOrD3Xjo75ZRTgNDuy7q8r5o2bepG\nR/E4kqmuFCGLiHgi5gi5crQAQY1XK9Kdn5/v9v3bkSnhbI7NikSH1yaN9vsLCgqA4GDCVLCFSIuI\nlixZUmPCf3hkDKF7sHnXnj17JrKpCRVeL9fYPJytF/hs+PDhLsq3Of2aNG7cmCuvvBIIdlF27drV\n1UGOFiH7mvJXnfA61vYes4jTd5bCaAvmkydPdpUWbZNPbVVe1EvGhh5FyCIinog5Ql60aBEQWaPA\nIg2rbRDNcccdB8A111zjKvRHm2NbvHhxlWvVzfslS3l5uRsF2PEulmUSzcaNG918qs0Xz5s3jyuu\nuCLBLU28aNXKTj/9dCA9KtQ1aNDAzfGagoIC9/ps2bIlENxTr169XPZQuOpek1lZWUk9cSLe7DlM\nl8yfu+66
Cwj6o0WLFrmMrRNOOAGIngVla0KFhYWuponVlrFo214LiRTzK2XUqFFAKA/XROuI7Y9g\n6UPDhg0DgkJBtdWqVSt3AnWqLF++3A1n7MMkGstHnjZtGl9//TUQ7F4cNWpU2rzIo7FhbfhhAXYS\nb+fOnVPSpt1lb97rr78+4mtdrFmzJur1tm3bute6JJ4Fg1Y+dPPmze70aFvcsw/XcDb1VlBQ4I6v\nstz6ZHTExv8QRkQkQ8QcIQ8YMAAI9veHp8wMGTIEgClTprgFvroWKLfdUPap1qdPnxoPQU2G4uJi\nlxoTnjBvu5usroGdJt22bVtX8SsdFrpqww4K2LFjh7tmOwx3tUOxPrKym5UNHDgwyS3JbJYYMGbM\nGCBUGfLBBx8EggSBFStWuMfvv//+ET/fpk0bN+1R0+g3URQhi4h4IuYI2aLg888/Hwi20kKw/TSW\nRQ2bv7HFMx9s2LDBjQguvfRSIFT5yzZ6WHFz+5SePn16UuehkqFydbPmzZszduzYFLUm9Wye8o47\n7gCC17yVDJDksjWrOXPm0KVLFyD6JiYbedtCrS38pUrcln9tgSpddiPFIicnxw17rHQmhHZsQbDS\na4uPPp+SES/jxo2r8wJtfVJ5J6kVqxk0aFAKWhObIUOGuOCiU6dOKW5N7HwK5nZFUxYiIp5I3wTJ\nFOrdu3dMR1DVB0888QQQjAbCq2RlIpu6s6OPqkuDSwe5ubnk5uamuhkZSRGyiIgnFCHLbunbt2/E\n10xnC9h21JONHETqQhGyiIgnFCGLxFGvXr0AWLVqVYpbIulIEbKIiCfUIYuIeCLLKhvV6sFZWVuA\nTYlrTsK1r6ioqHHLnO4xLezyHiEz7jMT7hEy6D7r0iGLiEjiaMpCRMQT6pBFRDyhDllExBPqkEVE\nPKEOWUTEE+qQRUQ8oQ5ZRMQT6pBFRDyhDllExBP/AwL3rM94Bza/AAAAAElFTkSuQmCC\n", 371 | "text/plain": [ 372 | "" 373 | ] 374 | }, 375 | "metadata": {}, 376 | "output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "plot_images_separately(data[:6,0].data.cpu().numpy())" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 136, 386 | "metadata": { 387 | "collapsed": false 388 | }, 389 | "outputs": [ 390 | { 391 | "data": { 392 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWQAAABFCAYAAAB9nJwHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD6RJREFUeJztnVtsVFUXgL/TltIC/UuhlGstiqJiRQSKKNRLtEYUFRRp\nURPUoJCIF9D4YBQxGKUxPggxSoJIMCGiVSKoCIqgooACCiq0CCIXqxYFBIqCc/kfJmufoZ2ZzrRz\n2bXrexktu6d7zz5nnXXfjt/vR1EURUk9aamegKIoihJABbKiKIolqEBWFEWxBBXIiqIolqACWVEU\nxRJUICuKoliCCmRFURRLUIGsKIpiCSqQFUVRLCEjlsH5+fn+vn37JmgqiWfz5s1/+P3+bpHG6Brt\nJ5o1QttYZ1tYI7SddcYkkPv27cumTZuaP6sU4zjO3qbG6BrtJ5o1QttYZ1tYI7SddarLQlEUxRJU\nICuKoliCCmRFURRLUIGsKIpiCSqQFUVRLCGmLItEMH36dH777TcADh48CMCxY8c4deoUABdffDEA\n999/PwCDBw9OwSxbht/vp76+HoAdO3YAsGfPHrPu4uJiAM4991wAevXqheM4KZiporQtGh7Q4fP5\nzM8yMpIvHlVDVhRFsYSkvwImTpwIwGeffQbAL7/8wr///ht2/DfffAPABx98AMDLL7/M1VdfDUBO\nTk4ip9pi9u4NpB5WVlby4YcfArB//34APB5Po/EDBw4EYObMmYwYMQKAgoKCZEw1LmzduhWAw4cP\nA/DXX38Zi6CmpgaA7OxsAEpKSrj77rtTMEslGr7//nsAjhw5AgSs1u3btwPwww8/AJCVlQW0vr08\nfvw4ABs2bGDLli0AdO3aFQjcsz/++CMQkE0A559/PgA9e/Zk6NChAIwcOTIhc0uaQL799tsBWL16\nNRBYOBBRGAcj5v2cOXOMiW+bQBZT58033wRg4cKFAKxcubKRaRSKbdu2AVBeXm6+lz///BOAzp07\nk5Zmr0Hj8/k4evQoAFdddVXYcYWFhQB07949KfOKN+vWreP5558HYNmyZWHHvfLKKwBUVFTQqVMn\nALN/trujfD6fcbGVlpaGHdda9tLr9QKwZMkSAD7++GMgIJBFYYjE8uXLAcjNzTWyZ/z48QBMnTo1\nrs+lvU+4oihKGyMpGvLatWvx+XwAtG/fHmi+lrB582befvttAB5++OHTrplK/H4/b7zxBgBPP/00\n4JrpsRJsNYgptX//fvr06dPCWSYOj8fD5Zdf3uQ4cdmcccYZiZ5Si6itrQXcPXzqqacA+Pzzz6P6\n/SlTpgBQVVXF3LlzATjnnHMASE9Pj+tc443X62X48OFNjpO9LCoqSvSUmo3X6zVa8IYNGwBYsWIF\n4Frd0XLixAl27doFYFyQq1at4r333ovXdFVDVhRFsYWkaMjdu3fn2muvBdygzr59+wD45JNPYrqW\n1+tl9+7dAPz8888A9O/fP+V+uV27dvHOO+8AzdeMI7FkyRIeeeSRuF83XsRqpUyYMCFBM2k5R48e\nZf369QCMGzeuRdeqrq6mqqoKgDvuuAOAPn360K5du5ZNMoFkZmbGNL6ioiJBM2k5fr+fb7/9FoCv\nv/4acOMysVJYWEiXLl0A9zs666yz4jBLl6QI5Ly8PGN6S5DgwIEDAJSVlZlc4xMnTgCwceNG3n33\nXSBwQwfj8/lMFF+EcCqFsQTrVq9ebR68SEgAoF27diYoIqaf5F6HCgDu3r2b33//HbAziFJfX8/s\n2bMBmDVrVthxr7/+OuBG6G1CvvdNmzaZoE0kxFTPzMw0e3jy5MnTrvXHH38YM1f+zWZhDIHn8Lnn\nngMi7+XixYsBO/dSOHDggMnQElkS7BIU2RH8zN14440ADBkyBIB+/foBgX2T3OTevXsDcMEFF5jf\njYccUpeFoiiKJSRFQ+7WrRujRo0C3BQU0RZOnTpltMZff/0VC
JgUeXl5Ya8nQREb0t6+++47AObN\nmxfVeNE4Hn/8cfOzOXPmAPDss88CGE0Y3Lfu0qVLTa6njRpyeno6X331VZPjrr/++iTMpnnI915Z\nWWmC0KGQoK3c04cPH2b69OmAm6MreDweY/n17Nkz7nNOBOnp6SYAFglZv43I/lVXVxsXp6RlBtNQ\nu509ezbnnXceAIMGDQJcN2tubq7RrsVlEW9rRzVkRVEUS0haYYi8SUS7lU+/32/+rUOHDkDgTSRB\nlYacPHnSVLSlsopNqn0kkCeBg3BIL45gzVh44IEHTrvmrFmzGvkijxw5YqoWS0pKWjr9uLNjxw5W\nrlzZ5DgJitiI+ETXrFkTdsyECRMoLy8H3LV4vV5++umnkOM9Ho+JFeTm5sZzugmjpqaGjz76qMlx\nnTt3TsJsmocUtqxfv95YsZGsnrKyMiAQF5CUTLHA5dNxHKNJJyoOoBqyoiiKJSRFQ05LS2uUOSC+\nmOzsbFNGLZHqUP4r8TP369fPdEVLZYK91LsvWLCgybHTpk3jnnvuAdy3dKhyy44dOwLQo0cPkxYo\neDwe6urqWjTnRCJFD60ZyYYIVc4v/QwmT57MsWPHALfPyqJFi/j7779DXtNxHJNF1FqQrILWiMgZ\nycSqrq4OuzfBrFq1Cgjss1iqgsibZHR/S4pAdhynkUA+dOgQEPjipJHJxo0bATfIFYri4mKGDRuW\noJlGjwhkeYlEoqysjAEDBgCuIPb7/eY7kUZD8mJq3769GSdBUMdxou77kQqam9tpAyJgpTovLS2t\nkXkrJvDixYvNf3/xxRcAjV6ewTiOE1XVm01IQ6HWjDxLsVa3vvjiiyZoPnbsWCDwEoZAden//ve/\nOM6yMeqyUBRFsYSkBfWkqk60EDnSu6amxmga4nwPhWjFF110UcrT3Y4cOcJLL70U9fgrrrgiZAK6\nIBqa9EsIRn4vLS0tJQ2zoyWSlmg7YqJKMCczM5N//vnntDFiCS1YsMDsSTQWy8SJE40rqrUgbSdb\nI7I3UrQzaNAgY21GCuoFI9ae9Ka59NJLAairqzNyKFEBTdWQFUVRLCEpKteaNWtMWtiiRYsAtwF2\nqEbtoRDNul+/flH1Fk4k1dXVUSXOC+3atTNv5+Byb1mHNHYPhYzp2LEjF154YXOnnHBuvfXWVE+h\n2Ui65WWXXQYEeh7L/Srff0N/f7SMHDky5t4Qqeamm25K9RRajFjRw4YN46GHHgLcQy6kQMTn8xnr\nKLgbpVhHMm7SpEkAPPPMM3Tr1g1wi0bi3bYhoQJ53bp1QKDB89KlSwG3Gi/WG1u+uPT0dBNUEbMh\n2b0s9u/fb/pOhEI2UNouZmRkmDnKxnu9XhMJFpdNKMTcKisrs/I8QTFvbc4AaQpxKUjw7dFHH2X+\n/PkAfPnllwCn7Xc0CoHs2+DBg1uNy0IqFWNtS2kj8rz179/fuAKnTZsGuHt54sQJdu7cCbjP4Nat\nW03AXmSV1DvU1taaYGE8+1cEoy4LRVEUS0iIhizuCDn6pKqqKqr0sEiIVrxw4ULTE0C0kJycnKQG\n+g4dOhQySCfzEU0reExw+hoEAkJynNWMGTMa/Q25lqy1pKQk4Sk3zSHa5typdjNFQoKlUlE3atQo\nc+zSfffdB7idCDt06GDcbtLoPBTBzfptOEAhGqTpelPYvJeCzNHn85m9lE/BcRzTt6JXr15A4B54\n6623ADeQK8/izp07zb0i8qhTp05x1ZJVQ1YURbGEuGvIu3btMn1H5WRp6X3cEkTD3L59u3mTS5pc\nSUlJUpPvwzWgl7ey+OIknS0nJ8f0ppCqoeXLl3PvvfeG/RtShTh69GgALrnkEqt6B0iam3SoC0es\nsYJUIt95Xl6e6Ykr85f9q6+vN6cvR9KQJQBbVFSU8sMTmkJ8pTNnzow4Tp5Bm5FnUJ69tLS0RkHV\n4FRSWZP4+evq6kxQr2F64
9GjRxulzqkPWVEU5T9K3DXkffv28emnnwKuZpyRkdHsst/gtxkENBbp\nrCVvq2T3UcjMzAzpR5OfiQ9KtOEBAwaYnqwy99deey3i35C0GulCZUuHN1mjnOgSqSBk0qRJRrOU\n1LLWgtx34jMU7fngwYNs2bIl7O9JJoxoyDZ3eJO9lDiAFG+FYvLkySY7wdYTQvx+v8maECu6ffv2\nJpvp5ptvBtwufcePHzfHOomsWrNmjUlDlXtWrKQuXbqY+yJRcYG4C+SCggKTvycLaYlaLwuXByI/\nP984588880wArrvuumZfvzmMHj2aysrKsP8uOaxNteQMR1ZWlnFnSHDIlsCQNKGXwFYk5s+fb9LH\nWkMgKBLy8t+2bRvvv/9+2HHShEjOWgvVRMoWpDnSq6++2uTYefPmmUMYbN5LeeaWLVsGwJ49e4z8\nESVIXH9+v99UDIdCFCppUN+7d28jpBOVW27v3aIoitLGiLuG3KNHD/P2yM/PBwIHPcaKtLw7++yz\nAbe+/MorrzRayJ133tni+TaHIUOGmPr2cI30m4Oktc2dO9ec0i2VQTbg8XiMOypaF1TDxu2JSqhP\nNE01JJf0xHifQpwoPB4Pa9euBdyAZVPs3bv3tP+3cS8bHu9WW1sb9frCId3fhg8fblyJiUI1ZEVR\nFEuIu4acn5/Pgw8+CLhvqaKiIuM8Fwd7JA1r+fLljbRs8cXl5eWZTk6pIjs72/hQ5WimtWvXRiyn\njsQNN9wAwGOPPQbA0KFDrQyCpaenm7LaaLSOJ598stEhAjZpU9HQsFl5KAYOHGiOhe/fvz/gWni2\nkp6ebgJZDTvbhWLGjBmN/OG27aXjOKaHsaTy1dfXmwKsWH3fYgVXVFQAMG7cuIQfipGQSj2JML/w\nwgtAoHZc3BaSHyiCubCwkK5duwKhHeVyE9i2+eJKkeDBihUrqKqqAjBBn1CNvuU07dtuu40xY8YA\nATcMuMEDW3Ecx5xnKAFbCTZ26dLFmHalpaUAjBgxwrS0bK3IQ3zLLbcAAReMuJZkvQUFBfTo0QNw\nq/1ibYyebBzHobi4GHB7O4hbpmvXrsZVJkHl0tJS69cErsti6tSpQGAfZJ+WLFkCRM6Nz8rKYsqU\nKYCrKF1zzTUJm29D1GWhKIpiCQnt9hacsycmXTR4vd6UnpcXC6Ihjhkzxmi8DVs2+v3+Ruaex+NJ\n2Mm1ieSuu+467TMS0TYEtxnZt/HjxwOBvGux8oSsrKyQp4nbjpzzKJ+RaG17KbJn7Nixxo3xxBNP\nAJhKy7q6OvOM9u3bFwho1Klsc6sasqIoiiVYeSZQa9GOwxHchD4crVE7jhWbiyKiRYqQysvLUzyT\n1PJf2Evp7CafNtL6v2VFUZT/CCqQFUVRLEEFsqIoiiWoQFYURbEEJ5bqFcdxDgJ7mxxoL0V+vz9i\ncwhdY6ugyTVC21hnW1gjtKF12txKT1EUpS2hLgtFURRLUIGsKIpiCSqQFUVRLEEFsqIoiiWoQFYU\nRbEEFciKoiiWoAJZURTFElQgK4qiWIIKZEVRFEv4P2LsVpWGtLqLAAAAAElFTkSuQmCC\n", 393 | "text/plain": [ 394 | "" 395 | ] 396 | }, 397 | "metadata": {}, 398 | "output_type": "display_data" 399 | } 400 | ], 401 | "source": [ 402 | "plot_images_separately(reconstructions[:6,0].data.cpu().numpy())" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "collapsed": true 410 | }, 411 | "outputs": [], 412 | "source": 
[] 413 | } 414 | ], 415 | "metadata": { 416 | "kernelspec": { 417 | "display_name": "Python 2", 418 | "language": "python", 419 | "name": "python2" 420 | }, 421 | "language_info": { 422 | "codemirror_mode": { 423 | "name": "ipython", 424 | "version": 2 425 | }, 426 | "file_extension": ".py", 427 | "mimetype": "text/x-python", 428 | "name": "python", 429 | "nbconvert_exporter": "python", 430 | "pygments_lexer": "ipython2", 431 | "version": "2.7.13" 432 | } 433 | }, 434 | "nbformat": 4, 435 | "nbformat_minor": 2 436 | } 437 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Capsule-Network-Tutorial 2 | This is an easy-to-follow Capsule Network tutorial with clean, readable code: [Capsule Network.ipynb 3 | ](https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb) 4 | 5 | [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829) 6 | 7 | **Understanding Hinton’s Capsule Networks** blog posts: 8 | 9 | [Part I: Intuition.](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b) 10 | 11 | [Part II: How Capsules Work.](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66) 12 | 13 | [Part III: Dynamic Routing Between Capsules.](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-iii-dynamic-routing-between-capsules-349f6d30418) 14 | --------------------------------------------------------------------------------