├── .gitignore
├── LICENSE
├── README.md
├── constants.py
├── data_loaders.py
├── model.py
├── notebooks
│   ├── Rotation.ipynb
│   ├── StatisticsPlotting.ipynb
│   └── Visualization.ipynb
├── options.py
├── pictures
│   ├── capsnet_deconv.png
│   ├── cifar_reconstruction_epoch_86.png
│   ├── primary_caps.png
│   ├── rec_visualization.gif
│   ├── reconstruction_epoch_50.png
│   ├── robust_rotation.gif
│   └── smallnorb_rec.png
├── smallNorb.py
├── stats.py
├── tools.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | saved_models
2 | .DS_Store
3 | logs
4 | reconstructions
5 | __pycache__
6 | .ipynb_checkpoints
7 | datasets
8 | options/
9 | logs_old/
10 | .vscode
11 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ethan Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet
2 | Capsule networks are a novel approach showing promising results on smallNORB and MNIST. Here we reproduce and build upon the impressive results of [Sara Sabour et al.](https://arxiv.org/abs/1710.09829) We experiment with the Capsule Network architecture by visualizing exactly what the capsules in different layers represent and what information they store about 3D objects in an image, and we try to improve its classification results on CIFAR10 and smallNORB with various methods, including some tricks with the reconstruction loss. Further, we present a deconvolution-based reconstruction module that reduces the number of learnable parameters by 80% compared with the fully-connected module presented by Sara Sabour et al.
3 |
4 | ## Benchmarks
5 |
6 | Our baseline model is the same as in the original paper, but it was trained for only 113 epochs on MNIST, and we did not use a 7-model ensemble for CIFAR10 as was done in the paper.
7 |
8 | |Model | MNIST | SmallNORB | CIFAR10 |
9 | |:-------------|:-------:|:-----------:|:---------:|
10 | |Sabour et al. | 99.75% | 97.3% | 89.40% |
11 | |Baseline | 99.73% | 91.5% | 72.59% |
12 |
13 | ## Experiments
14 |
15 | We introduced a deconvolution-based reconstruction module and experimented with batch normalization and different network topologies.
16 |
17 | ### Deconvolution-based Reconstruction
18 |
19 | The baseline model has 1.4M parameters in its fully connected decoder, while our deconvolution-based reconstruction module reduces the number of learnable parameters by 80%, down to 0.25M.
20 |
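As a rough sanity check of the 80% figure, here is a small sketch that tallies the learnable parameters of both decoders, reading the layer shapes off ```model.py``` for the MNIST setting (10 classes with 16-dimensional capsules in, 28×28 single-channel images out):

```python
# Fully connected decoder: Linear(160, 512) -> Linear(512, 1024) -> Linear(1024, 784)
fc_params = (160*512 + 512) + (512*1024 + 1024) + (1024*784 + 784)

# Deconvolution decoder: Linear(160, 360) into a 10x6x6 grid, then three ConvTranspose2d layers
deconv_params = ((160*360 + 360)
                 + (10*32*9*9 + 32)    # ConvTranspose2d(10, 32, kernel_size=9)
                 + (32*64*9*9 + 64)    # ConvTranspose2d(32, 64, kernel_size=9)
                 + (64*1*2*2 + 1))     # ConvTranspose2d(64, 1, kernel_size=2)

print(fc_params, deconv_params)  # 1411344 vs. 250121, an ~82% reduction
```
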
21 | 
22 |
23 | Here is a comparison between the two reconstruction modules after training for 25 epochs on MNIST, where RLoss is the SSE reconstruction loss and MLoss is the margin loss.
24 |
25 | |Model | RLoss | MLoss | Accuracy |
26 | |:------------|:-------:|:-------:|:----------:|
27 | |FC | 21.62 | 0.0058 | 99.51% |
28 | |FC w/ BN | 13.12 | 0.0054 | 99.54% |
29 | |DeConv | 10.87 | 0.0050 | 99.54% |
30 | |DeConv w/ BN | 9.52 | 0.0044 | 99.55% |
31 |
32 | ## Visualization
33 |
34 | ### Reconstructions
35 |
36 | Here are the reconstruction results for smallNORB and CIFAR10 after training for 186 and 86 epochs, respectively.
37 |
38 | 
39 | 
40 |
41 | ### Robustness to Affine Transformations
42 |
43 | We visualized how the network recognizes a rotated MNIST image when trained only on unmodified MNIST data, using an image of the digit 2 as an example. The network is confident about the result when the image is only slightly rotated, but as the image is rotated further, it starts to confuse the input with other digits. For example, at a certain angle it is very confident that the image is a 7, and it reconstructs a 7 that aligns quite well with the input. Due to its special topological features, the input digit 2 is still recognized by the network when rotated by 180°.
44 |
45 | 
46 |
47 | ### Primary Capsules Reconstructions
48 |
49 | We used a pre-trained network to train a reconstruction module for the Primary Capsules. By scaling these capsules by their routing coefficients to the classified object, we were able to visualize reconstructions from individual Primary Capsules. Each row is reconstructed from a single capsule, and the routing coefficient increases from left to right.
50 |
51 | 
52 |
53 | ## Usage
54 |
55 | **Step 1. Install requirements**
56 |
57 | * Python 3
58 | * PyTorch 1.0.1
59 | * Torchvision 0.2.1
60 | * TQDM
61 |
62 | **Step 2. Adjust hyperparameters**
63 |
64 | In ```constants.py```:
65 | ```python
66 | DEFAULT_LEARNING_RATE = 0.001
67 | DEFAULT_ALPHA = 0.0005 # Scaling factor for reconstruction loss
68 | DEFAULT_DATASET = "small_norb" # 'mnist', 'small_norb'
69 | DEFAULT_DECODER = "FC" # 'FC' or 'Conv'
70 | DEFAULT_BATCH_SIZE = 128
71 | DEFAULT_EPOCHS = 300
72 | DEFAULT_USE_GPU = True
73 | DEFAULT_ROUTING_ITERATIONS = 3
74 | ```
75 |
76 | **Step 3. Start training**
77 |
78 | Training with default settings:
79 |
80 | ```console
81 | $ python train.py
82 | ```
83 |
84 | Example with training flags:
85 |
86 | ```console
87 | $ python train.py --decoder=Conv --file=model32.pt --dataset=mnist
88 | ```
89 |
90 | Further help with training flags:
91 |
92 | ```console
93 | $ python train.py -h
94 | ```
95 |
96 |
97 | **Step 4. Get your results**
98 |
99 | Trained models are saved in the ```saved_models``` directory. TensorBoard logs are saved to ```logs/```. You can launch TensorBoard with:
100 |
101 | ```bash
102 | tensorboard --logdir logs
103 | ```
104 |
105 |
106 | ## Future work
107 |
108 | * Fully develop notebooks for visualization and plotting.
109 | * Implement [EM routing](https://openreview.net/pdf?id=HJWLfGWRb).
110 |
111 |
112 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | # Directory to save models
3 | SAVE_DIR = "saved_models"
4 | # Directory to save plots
5 | PLOT_DIR = "plots"
6 | # Directory to save logs
7 | LOG_DIR = "logs"
8 | # Directory to save options
9 | OPTIONS_DIR = "options"
10 | # Directory to save images
11 | IMAGES_SAVE_DIR = "reconstructions"
12 | # Directory to save smallNorb Dataset
13 | SMALL_NORB_PATH = os.path.join("datasets", "smallNORB")
14 |
15 | # Default values for command arguments
16 | DEFAULT_LEARNING_RATE = 0.001
17 | DEFAULT_ANNEAL_TEMPERATURE = 8 # Temperature used when annealing alpha
18 | DEFAULT_ALPHA = 0.0005 # Scaling factor for reconstruction loss
19 | DEFAULT_DATASET = "small_norb" # 'mnist', 'small_norb'
20 | DEFAULT_DECODER = "FC" # 'FC' or 'Conv'
21 | DEFAULT_BATCH_SIZE = 128
22 | DEFAULT_EPOCHS = 300
23 | DEFAULT_USE_GPU = True
24 | DEFAULT_ROUTING_ITERATIONS = 3
25 | DEFAULT_VALIDATION_SIZE = 1000
26 |
27 | # Random seed for validation split
28 | VALIDATION_SEED = 889256487
--------------------------------------------------------------------------------
/data_loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision import datasets, transforms
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 | from constants import *
6 | from smallNorb import SmallNORB
7 |
8 | def build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset):
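    # Note: train_dataset and valid_dataset wrap the same underlying data with
    # different transforms; the two samplers below keep their index sets disjoint.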
9 | # Compute validation split
10 | train_size = len(train_dataset)
11 | indices = list(range(train_size))
12 | split = int(np.floor(valid_size * train_size))
13 | np.random.shuffle(indices)
14 | train_idx, valid_idx = indices[split:], indices[:split]
15 | train_sampler = SubsetRandomSampler(train_idx)
16 | valid_sampler = SubsetRandomSampler(valid_idx)
17 |
18 | # Create dataloaders
19 | train_loader = torch.utils.data.DataLoader(train_dataset,
20 | batch_size=batch_size,
21 | sampler=train_sampler)
22 | valid_loader = torch.utils.data.DataLoader(valid_dataset,
23 | batch_size=batch_size,
24 | sampler=valid_sampler)
25 | test_loader = torch.utils.data.DataLoader(test_dataset,
26 | batch_size=batch_size,
27 | shuffle=False)
28 | return train_loader, valid_loader, test_loader
29 |
30 | def load_mnist(batch_size, valid_size=0.1):
31 | train_transform = transforms.Compose([
32 | transforms.RandomAffine(0, translate=[0.08,0.08]),
33 | transforms.ToTensor(),
34 | transforms.Normalize((0.1307,), (0.3081,))
35 | ])
36 | valid_transform = transforms.Compose([
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.1307,), (0.3081,))
39 | ])
40 | test_transform = transforms.Compose([
41 | transforms.ToTensor(),
42 | transforms.Normalize((0.1307,), (0.3081,))
43 | ])
44 |
45 | train_dataset = datasets.MNIST('../data',
46 | train=True,
47 | download=True,
48 | transform=train_transform)
49 | valid_dataset = datasets.MNIST('../data',
50 | train=True,
51 | download=True,
52 | transform=valid_transform)
53 | test_dataset = datasets.MNIST('../data',
54 | train=False,
55 | download=True,
56 | transform=test_transform)
57 |
58 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset)
59 |
60 |
61 |
62 | def load_small_norb(batch_size, valid_size=0.1):
63 | path = SMALL_NORB_PATH
64 | train_transform = transforms.Compose([
65 | transforms.Resize(48),
66 | transforms.RandomCrop(32),
67 | transforms.ColorJitter(brightness=32./255, contrast=0.5),
68 | transforms.ToTensor(),
69 | transforms.Normalize((0.0,), (0.3081,))
70 | ])
71 | valid_transform = transforms.Compose([
72 | transforms.Resize(48),
73 | transforms.CenterCrop(32),
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.,), (0.3081,))
76 | ])
77 | test_transform = transforms.Compose([
78 | transforms.Resize(48),
79 | transforms.CenterCrop(32),
80 | transforms.ToTensor(),
81 | transforms.Normalize((0.,), (0.3081,))
82 | ])
83 |
84 | train_dataset = SmallNORB(path, train=True, download=True, transform=train_transform)
85 | valid_dataset = SmallNORB(path, train=True, download=True, transform=valid_transform)
86 | test_dataset = SmallNORB(path, train=False, transform=test_transform)
87 |
88 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset)
89 |
90 | def load_cifar10(batch_size, valid_size=0.1):
91 | train_transform = transforms.Compose([
92 | transforms.ColorJitter(brightness=63./255, contrast=0.8),
93 | transforms.RandomHorizontalFlip(),
94 | transforms.ToTensor(),
95 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5))
96 | ])
97 | valid_transform = transforms.Compose([
98 | transforms.ToTensor(),
99 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5))
100 | ])
101 | test_transform = transforms.Compose([
102 | transforms.ToTensor(),
103 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5))
104 | ])
105 | train_dataset = datasets.CIFAR10('../data',
106 | train=True,
107 | download=True,
108 | transform=train_transform)
109 | valid_dataset = datasets.CIFAR10('../data',
110 | train=True,
111 | download=True,
112 | transform=valid_transform)
113 | test_dataset = datasets.CIFAR10('../data',
114 | train=False,
115 | download=True,
116 | transform=test_transform)
117 |
118 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset)
119 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as functional
3 | from tools import squash
4 | import torch
5 | from torch.autograd import Variable
6 | USE_GPU=True
7 |
8 | def routing_algorithm(x, weight, bias, routing_iterations):
9 | """
10 | x: [batch_size, num_capsules_in, capsule_dim]
11 | weight: [1,num_capsules_in,num_capsules_out,out_channels,in_channels]
12 | bias: [1,1, num_capsules_out, out_channels]
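
    Implements dynamic routing-by-agreement (Sabour et al., 2017): the coupling
    logits b_ij start at zero, c_ij is their softmax over output capsules, and
    b_ij is increased by the agreement <u_hat_ij, v_j> after each iteration.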
13 | """
14 | num_capsules_in = x.shape[1]
15 | num_capsules_out = weight.shape[2]
16 | batch_size = x.size(0)
17 |
18 | x = x.unsqueeze(2).unsqueeze(4)
19 |
20 | #[batch_size, 32*6*6, 10, 16]
21 | u_hat = torch.matmul(weight, x).squeeze()
22 |
23 | b_ij = Variable(x.new(batch_size, num_capsules_in, num_capsules_out, 1).zero_())
24 |
25 |
26 | for it in range(routing_iterations):
27 | c_ij = functional.softmax(b_ij, dim=2)
28 |
29 | # [batch_size, 1, num_classes, capsule_size]
30 | s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + bias
31 | # [batch_size, 1, num_capsules, out_channels]
32 | v_j = squash(s_j, dim=-1)
33 |
34 | if it < routing_iterations - 1:
35 | # [batch-size, 32*6*6, 10, 1]
36 | delta = (u_hat * v_j).sum(dim=-1, keepdim=True)
37 | b_ij = b_ij + delta
38 |
39 | return v_j.squeeze()
40 |
41 | # First Convolutional Layer
42 | class ConvLayer(nn.Module):
43 | def __init__(self,
44 | in_channels=1,
45 | out_channels=256,
46 | kernel_size=9,
47 | batchnorm=False):
48 | super(ConvLayer, self).__init__()
49 |
50 | if batchnorm:
51 | self.conv = nn.Sequential(
52 | nn.Conv2d(in_channels=in_channels,
53 | out_channels=out_channels,
54 | kernel_size=kernel_size,
55 | stride=1),
56 | nn.BatchNorm2d(out_channels),
57 | nn.ReLU()
58 | )
59 | else:
60 | self.conv = nn.Sequential(
61 | nn.Conv2d(in_channels=in_channels,
62 | out_channels=out_channels,
63 | kernel_size=kernel_size,
64 | stride=1),
65 | nn.ReLU()
66 | )
67 | def forward(self, x):
68 | output = self.conv(x)
69 | return output
70 |
71 | class PrimaryCapsules(nn.Module):
72 |
73 | def __init__(self,
74 | num_capsules=32,
75 | in_channels=256,
76 | out_channels=8,
77 | kernel_size=9,
78 | primary_caps_gridsize=6,
79 | batchnorm=False):
80 |
81 |         super(PrimaryCapsules, self).__init__()
82 | self.gridsize = primary_caps_gridsize
83 | self.num_capsules = num_capsules
84 | if batchnorm:
85 | self.capsules = nn.ModuleList([
86 | nn.Sequential(
87 | nn.Conv2d(in_channels=in_channels,
88 | out_channels=num_capsules,
89 | kernel_size=kernel_size,
90 | stride=2,
91 | padding=0),
92 | nn.BatchNorm2d(num_capsules)
93 | )
94 | for i in range(out_channels)
95 | ])
96 | else:
97 | self.capsules = nn.ModuleList([
98 | nn.Sequential(
99 | nn.Conv2d(in_channels=in_channels,
100 | out_channels=num_capsules,
101 | kernel_size=kernel_size,
102 | stride=2,
103 | padding=0),
104 |
105 | )
106 | for i in range(out_channels)
107 | ])
108 |
109 | def forward(self, x):
110 | output = [caps(x) for caps in self.capsules]
111 | output = torch.stack(output, dim=1)
112 | output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1)
113 |
114 | return squash(output)
115 |
116 |
117 | class ClassCapsules(nn.Module):
118 |
119 | def __init__(self,
120 | num_capsules=10,
121 | num_routes = 32*6*6,
122 | in_channels=8,
123 | out_channels=16,
124 | routing_iterations=3,
125 | leaky=False):
126 | super(ClassCapsules, self).__init__()
127 |
128 |
129 | self.in_channels = in_channels
130 | self.num_routes = num_routes
131 | self.num_capsules = num_capsules
132 | self.routing_iterations = routing_iterations
133 |
134 | self.W = nn.Parameter(torch.rand(1,num_routes,num_capsules,out_channels,in_channels))
135 | self.bias = nn.Parameter(torch.rand(1,1, num_capsules, out_channels))
136 |
137 |
138 | # [batch_size, 10, 16, 1]
139 | def forward(self, x):
140 | v_j = routing_algorithm(x, self.W, self.bias, self.routing_iterations)
141 | return v_j.unsqueeze(-1)
142 |
143 |
144 | class ReconstructionModule(nn.Module):
145 | def __init__(self, capsule_size=16, num_capsules=10, imsize=28,img_channel=1, batchnorm=False):
146 | super(ReconstructionModule, self).__init__()
147 |
148 | self.num_capsules = num_capsules
149 | self.capsule_size = capsule_size
150 | self.imsize = imsize
151 | self.img_channel = img_channel
152 | if batchnorm:
153 | self.decoder = nn.Sequential(
154 | nn.Linear(capsule_size*num_capsules, 512),
155 | nn.BatchNorm1d(512),
156 | nn.ReLU(),
157 | nn.Linear(512, 1024),
158 | nn.BatchNorm1d(1024),
159 | nn.ReLU(),
160 | nn.Linear(1024, imsize*imsize*img_channel),
161 | nn.Sigmoid()
162 | )
163 | else:
164 | self.decoder = nn.Sequential(
165 | nn.Linear(capsule_size*num_capsules, 512),
166 | nn.ReLU(),
167 | nn.Linear(512, 1024),
168 | nn.ReLU(),
169 | nn.Linear(1024, imsize*imsize*img_channel),
170 | nn.Sigmoid()
171 | )
172 |
173 | def forward(self, x, target=None):
174 | batch_size = x.size(0)
175 | if target is None:
176 | classes = torch.norm(x, dim=2)
177 | max_length_indices = classes.max(dim=1)[1].squeeze()
178 | else:
179 | max_length_indices = target.max(dim=1)[1]
180 |
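        # Rows of an identity matrix act as one-hot masks: only the capsule of
        # the target (or predicted) class is kept as input to the decoder.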
181 | masked = Variable(x.new_tensor(torch.eye(self.num_capsules)))
182 |
183 | masked = masked.index_select(dim=0, index=max_length_indices.data)
184 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)
185 |
186 | reconstructions = self.decoder(decoder_input)
187 | reconstructions = reconstructions.view(-1, self.img_channel, self.imsize, self.imsize)
188 | return reconstructions, masked
189 |
190 | class ConvReconstructionModule(nn.Module):
191 | def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False):
192 | super(ConvReconstructionModule, self).__init__()
193 | self.num_capsules = num_capsules
194 | self.capsule_size = capsule_size
195 | self.imsize = imsize
196 | self.img_channels = img_channels
197 | self.grid_size = 6
198 | if batchnorm:
199 | self.FC = nn.Sequential(
200 | nn.Linear(capsule_size * num_capsules, num_capsules * (self.grid_size)**2 ),
201 | nn.BatchNorm1d(num_capsules * self.grid_size**2),
202 | nn.ReLU()
203 | )
204 | self.decoder = nn.Sequential(
205 | nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2),
206 | nn.BatchNorm2d(32),
207 | nn.ReLU(),
208 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
209 | nn.BatchNorm2d(64),
210 | nn.ReLU(),
211 | nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1),
212 | nn.Sigmoid()
213 | )
214 | else:
215 | self.FC = nn.Sequential(
216 | nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ),
217 | nn.ReLU()
218 | )
219 | self.decoder = nn.Sequential(
220 | nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2),
221 | nn.ReLU(),
222 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
223 | nn.ReLU(),
224 | nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1),
225 | nn.Sigmoid()
226 | )
227 |
228 | def forward(self, x, target=None):
229 | batch_size = x.size(0)
230 | if target is None:
231 | classes = torch.norm(x, dim=2)
232 | max_length_indices = classes.max(dim=1)[1].squeeze()
233 | else:
234 | max_length_indices = target.max(dim=1)[1]
235 |
236 | masked = x.new_tensor(torch.eye(self.num_capsules))
237 | masked = masked.index_select(dim=0, index=max_length_indices.data)
238 |
239 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)
240 | decoder_input = self.FC(decoder_input)
241 | decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size)
242 | reconstructions = self.decoder(decoder_input)
243 | reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize)
244 |
245 | return reconstructions, masked
246 |
247 |
248 |
249 |
250 | class SmallNorbConvReconstructionModule(nn.Module):
251 | def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False):
252 | super(SmallNorbConvReconstructionModule, self).__init__()
253 | self.num_capsules = num_capsules
254 | self.capsule_size = capsule_size
255 | self.imsize = imsize
256 | self.img_channels = img_channels
257 |
258 | self.grid_size = 4
259 |
260 | if batchnorm:
261 | self.FC = nn.Sequential(
262 | nn.Linear(capsule_size * num_capsules, num_capsules *self.grid_size*self.grid_size),
263 | nn.BatchNorm1d(num_capsules * self.grid_size**2),
264 | nn.ReLU()
265 | )
266 | self.decoder = nn.Sequential(
267 | nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2),
268 | nn.BatchNorm2d(32),
269 | nn.ReLU(),
270 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
271 | nn.BatchNorm2d(64),
272 | nn.ReLU(),
273 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1),
274 | nn.BatchNorm2d(128),
275 | nn.ReLU(),
276 | nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1),
277 | nn.Sigmoid()
278 | )
279 | else:
280 | self.FC = nn.Sequential(
281 | nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ),
282 | nn.ReLU()
283 | )
284 | self.decoder = nn.Sequential(
285 | nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2),
286 | nn.ReLU(),
287 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
288 | nn.ReLU(),
289 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1),
290 | nn.ReLU(),
291 | nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1),
292 | nn.Sigmoid()
293 | )
294 |
295 | def forward(self, x, target=None):
296 | batch_size = x.size(0)
297 | if target is None:
298 | classes = torch.norm(x, dim=2)
299 | max_length_indices = classes.max(dim=1)[1].squeeze()
300 | else:
301 | max_length_indices = target.max(dim=1)[1]
302 | masked = Variable(x.new_tensor(torch.eye(self.num_capsules)))
303 | masked = masked.index_select(dim=0, index=max_length_indices.data)
304 |
305 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)
306 | decoder_input = self.FC(decoder_input)
307 | decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size)
308 | reconstructions = self.decoder(decoder_input)
309 | reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize)
310 |
311 | return reconstructions, masked
312 |
313 |
314 |
315 |
316 | class CapsNet(nn.Module):
317 |
318 | def __init__(self,
319 | reconstruction_type = "FC",
320 | imsize=28,
321 | num_classes=10,
322 | routing_iterations=3,
323 | primary_caps_gridsize=6,
324 | img_channels = 1,
325 | batchnorm = False,
326 | loss = "L2",
327 | num_primary_capsules=32,
328 | leaky_routing = False
329 | ):
330 | super(CapsNet, self).__init__()
331 | self.num_classes = num_classes
332 | if leaky_routing:
333 | num_classes += 1
334 | self.num_classes += 1
335 |
336 | self.imsize=imsize
337 | self.conv_layer = ConvLayer(in_channels=img_channels, batchnorm=batchnorm)
338 | self.leaky_routing = leaky_routing
339 |
340 |         self.primary_capsules = PrimaryCapsules(primary_caps_gridsize=primary_caps_gridsize,
341 |                                                  batchnorm=batchnorm,
342 |                                                  num_capsules=num_primary_capsules)
343 |
344 | self.digit_caps = ClassCapsules(num_capsules=num_classes,
345 | num_routes=num_primary_capsules*primary_caps_gridsize*primary_caps_gridsize,
346 | routing_iterations=routing_iterations,
347 | leaky=leaky_routing)
348 |
349 | if reconstruction_type == "FC":
350 | self.decoder = ReconstructionModule(imsize=imsize,
351 | num_capsules=num_classes,
352 | img_channel=img_channels,
353 | batchnorm=batchnorm)
354 | elif reconstruction_type == "Conv32":
355 | self.decoder = SmallNorbConvReconstructionModule(num_capsules=num_classes,
356 | imsize=imsize,
357 | img_channels=img_channels,
358 | batchnorm=batchnorm)
359 | else:
360 | self.decoder = ConvReconstructionModule(num_capsules=num_classes,
361 | imsize=imsize,
362 | img_channels=img_channels,
363 | batchnorm=batchnorm)
364 |
365 |         if loss == "L2":
366 |             self.reconstruction_criterion = nn.MSELoss(reduction="none")
367 |         elif loss == "L1":
368 |             self.reconstruction_criterion = nn.L1Loss(reduction="none")
369 |
370 | def forward(self, x, target=None):
371 | output = self.conv_layer(x)
372 | output = self.primary_capsules(output)
373 | output = self.digit_caps(output)
374 | reconstruction, masked = self.decoder(output, target)
375 |
376 | return output, reconstruction, masked
377 |
378 | def loss(self, images, labels, capsule_output, reconstruction, alpha):
379 | marg_loss = self.margin_loss(capsule_output, labels)
380 | rec_loss = self.reconstruction_loss(images, reconstruction)
381 | total_loss = (marg_loss + alpha * rec_loss).mean()
382 | return total_loss, rec_loss.mean(), marg_loss.mean()
383 |
384 | def margin_loss(self, x, labels):
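        # Margin loss from Sabour et al. (2017), with m+ = 0.9, m- = 0.1, lambda = 0.5:
        # L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2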
385 | batch_size = x.size(0)
386 | v_c = torch.norm(x, dim=2, keepdim=True)
387 |
388 | left = functional.relu(0.9 - v_c).view(batch_size, -1) ** 2
389 | right = functional.relu(v_c - 0.1).view(batch_size, -1) ** 2
390 |
391 | loss = labels * left + 0.5 *(1-labels)*right
392 | loss = loss.sum(dim=1)
393 | return loss
394 |
395 | def reconstruction_loss(self, data, reconstructions):
396 | batch_size = reconstructions.size(0)
397 | reconstructions = reconstructions.view(batch_size, -1)
398 | data = data.view(batch_size, -1)
399 | loss = self.reconstruction_criterion(reconstructions, data)
400 | loss = loss.sum(dim=1)
401 | return loss
402 |
--------------------------------------------------------------------------------
/notebooks/Rotation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "from data_loaders import load_mnist\n",
12 | "import numpy as np\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import os\n",
15 | "from model import * \n",
16 | "import torch\n",
17 | "from PIL import Image\n",
18 | "import torchvision\n",
19 | "from torchvision import datasets, transforms\n",
20 | "import torch\n",
21 | "from constants import * \n",
22 | "import torch.nn.functional as functional\n",
23 | "from tqdm import tqdm\n",
24 | "import imageio"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "def load_mnist(batch_size, rotate=0, scale=1):\n",
36 | " dataset_transform = transforms.Compose([\n",
37 | " transforms.RandomAffine([rotate, rotate+1], scale=[scale, scale]),\n",
38 | " transforms.ToTensor(),\n",
39 | " transforms.Normalize((0.1307,), (0.3081,))\n",
40 | " ])\n",
41 | " \n",
42 | " train_dataset = datasets.MNIST('../data', \n",
43 | " train=True, \n",
44 | " download=True, \n",
45 | " transform=dataset_transform)\n",
46 | " test_dataset = datasets.MNIST('../data', \n",
47 | " train=False, \n",
48 | " download=True, \n",
49 | " transform=dataset_transform)\n",
50 | "\n",
51 | "\n",
52 | " train_loader = torch.utils.data.DataLoader(train_dataset, \n",
53 | " batch_size=batch_size,\n",
54 | " shuffle=True)\n",
55 | " test_loader = torch.utils.data.DataLoader(test_dataset, \n",
56 | " batch_size=batch_size,\n",
57 | " shuffle=False)\n",
58 | " return train_loader, test_loader\n",
59 | "\n"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "capsnet = CapsNet(reconstruction_type=\"FC\")\n",
69 | "capsnet.load_state_dict(torch.load(\"../saved_models/model36.pt\"))\n",
70 | "capsnet.cuda()\n",
71 | "\"\""
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "iter(load_mnist(20)[1]).next()[1]"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {
87 | "collapsed": true
88 | },
89 | "outputs": [],
90 | "source": [
91 | ", _"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "j = 1\n",
101 | "for i in tqdm(range(1, 361, 4)):\n",
102 | " _, test_loader = load_mnist(j+1, rotate=0, scale=i/64)\n",
103 | " images, targets = iter(test_loader).next()\n",
104 | "\n",
105 | " target = targets[j].item()\n",
106 | " output, reconstruction, _ = capsnet(images.cuda())\n",
107 | " output = torch.norm(output, dim=2).data.squeeze()\n",
108 | " pred = output.squeeze().max(dim=1)[1][j].item()\n",
109 | " im = images[j, 0].data.cpu().numpy()\n",
110 | " rec = reconstruction[j,0].data.cpu().numpy()\n",
111 | "\n",
112 | " plt.figure(figsize=(20,10))\n",
113 | " plt.subplot(1,3,1)\n",
114 | " plt.title(\"Confidence\")\n",
115 | " plt.ylim([0,1])\n",
116 | " plt.bar(range(0,10), output[j])\n",
117 | " plt.bar(pred, output[j,pred])\n",
118 | " plt.xticks(range(10))\n",
119 | " plt.subplot(1,3,2)\n",
120 | " plt.title(\"Input Image\")\n",
121 | " plt.axis('off')\n",
122 | " plt.imshow(im, cmap=\"gray\")\n",
123 | " plt.subplot(1,3,3)\n",
124 | " plt.title(\"Reconstructed Image\")\n",
125 | " plt.axis('off') \n",
126 | " plt.imshow(rec, cmap=\"gray\")\n",
127 | " plt.savefig(\"rotation/test{}.png\".format(i))\n",
128 | "\n",
129 | "\"\"\"\n",
130 | "fig = plt.figure()\n",
131 | "plt.subplot(1,2,1)\n",
132 | "plt.bar(range(0,10), output[j])\n",
133 | "pred = output[j].max(dim=0)[1].item()\n",
134 | "plt.bar(pred, output[j][pred])\n",
135 | "plt.xticks(range(0,10))\n",
136 | "plt.subplot(1,2,2)\n",
137 | "plt.imshow(im, cmap=\"gray\")\n",
138 | "plt.savefig(\"test.png\")\n",
139 | "\"\"\""
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "from tqdm import trange\n",
149 | "images = []\n",
150 | "for i in trange(1,361,4):\n",
151 | " images.append(imageio.imread(\"rotation/test{}.png\".format(i)))\n",
152 | "imageio.mimsave('./movie.gif', images)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "j = 1\n",
162 | "confidences_correct = []\n",
163 | "confidences_correct_i = []\n",
164 | "confidences_false = []\n",
165 | "confidences_false_i = []\n",
166 | "for i in tqdm(range(0, 360, 2)):\n",
167 | " _, test_loader = load_mnist(j+1, rotate=i)\n",
168 | " images, targets = iter(test_loader).next()\n",
169 | "\n",
170 | " target = targets[j].item()\n",
171 | " output, reconstruction, _ = capsnet(images.cuda())\n",
172 | " output = torch.norm(output, dim=2)\n",
173 | " pred = output.squeeze().max(dim=1)[1][j].item()\n",
174 | " \n",
175 | " if pred == target:\n",
176 | " confidences_correct.append(output[j,target,0].item())\n",
177 | " confidences_correct_i.append(i)\n",
178 | " else:\n",
179 | " confidences_false.append(output[j,target,0].item())\n",
180 | " confidences_false_i.append(i)\n",
181 | " \n",
182 | "# Show Image\n",
183 | "_, test_loader = load_mnist(j+1, rotate=0)\n",
184 | "images, targets = iter(test_loader).next()\n",
185 | "im = images[j, 0].data.numpy()\n",
186 | "plt.imshow(im, cmap=\"gray\")\n",
187 | "\n",
188 | "# Print graph\n",
189 | "print(targets[j])\n"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "_, test_loader = load_mnist(1+1, rotate=0)\n",
199 | "images, targets = iter(test_loader).next()\n",
200 | "im = images[1, 0].data.numpy()\n",
201 | "plt.axis('off')\n",
202 | "plt.imshow(im, cmap=\"gray\")\n"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {
209 | "collapsed": true
210 | },
211 | "outputs": [],
212 | "source": [
213 | "plt.figure(figsize=(20,10))\n",
214 | "plt.plot(confidences_correct_i, confidences_correct, '.')\n",
215 | "plt.plot(confidences_false_i, confidences_false, '.')\n",
216 | "plt.xlabel(\"Rotation degrees\")\n",
217 | "plt.ylabel(\"Confidence\")\n",
218 | "plt.xlim([0,360])\n",
219 | "plt.ylim([0,1])"
220 | ]
221 | }
222 | ],
223 | "metadata": {
224 | "kernelspec": {
225 | "display_name": "Python 3",
226 | "language": "python",
227 | "name": "python3"
228 | },
229 | "language_info": {
230 | "codemirror_mode": {
231 | "name": "ipython",
232 | "version": 3
233 | },
234 | "file_extension": ".py",
235 | "mimetype": "text/x-python",
236 | "name": "python",
237 | "nbconvert_exporter": "python",
238 | "pygments_lexer": "ipython3",
239 | "version": "3.6.2"
240 | }
241 | },
242 | "nbformat": 4,
243 | "nbformat_minor": 2
244 | }
245 |
--------------------------------------------------------------------------------
/notebooks/StatisticsPlotting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import pandas as pd\n",
12 | "import os\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import numpy as np\n",
15 | "%matplotlib inline\n",
16 | "LOG_DIR = \"../logs\"\n",
17 | "filename= \"log-1528586274.9812639.txt\"\n",
18 | "path = os.path.join(LOG_DIR, filename)\n",
19 | "data = pd.read_csv(path,skipinitialspace=True)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "data[-10:]"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "plt.figure(figsize=(20,10))\n",
38 | "plt.title(\"Loss\")\n",
39 | "# plt.xlim([0, 80])\n",
40 | "# plt.ylim([0.0, 100.0])\n",
41 | "plt.plot(data.reconstruction_loss_test[0:], '--o', label='Test reconstruction loss')\n",
42 | "plt.plot(data.reconstruction_loss_train[0:], '--o', label='Train reconstruction loss')\n",
43 | "plt.legend()"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "plt.figure(figsize=(20,10))\n",
53 | "plt.title(\"Loss\")\n",
54 | "# plt.xlim([0,80])\n",
55 | "plt.plot(data.test_loss[0:], '--o', label='Test loss')\n",
56 | "plt.plot(data.train_loss[0:], '--o', label='Train loss')\n",
57 | "plt.legend()"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "plt.figure(figsize=(20,10))\n",
67 | "plt.title(\"Loss\")\n",
68 | "# plt.xlim([0,80])\n",
69 | "plt.plot(data.margin_loss_test[0:], '--o', label='Test loss')\n",
70 | "plt.plot(data.margin_loss_train[0:], '--o', label='Train loss')\n",
71 | "plt.legend()"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "[len(data.test_accuracy),len(data.test_accuracy)]"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "plt.figure(figsize=(20,10))\n",
90 | "plt.xlim([0, len(data.test_accuracy)])\n",
91 | "plt.plot([0,len(data.test_accuracy)], [99.5, 99.5], '--')\n",
92 | "plt.plot(data.test_accuracy[0:], '--o')"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "plt.figure(figsize=(20,10))\n",
102 | "plt.plot(data.time[0:])\n",
103 | "# plt.ylim([100,110])"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {
110 | "collapsed": true
111 | },
112 | "outputs": [],
113 | "source": []
114 | }
115 | ],
116 | "metadata": {
117 | "kernelspec": {
118 | "display_name": "Python 3",
119 | "language": "python",
120 | "name": "python3"
121 | },
122 | "language_info": {
123 | "codemirror_mode": {
124 | "name": "ipython",
125 | "version": 3
126 | },
127 | "file_extension": ".py",
128 | "mimetype": "text/x-python",
129 | "name": "python",
130 | "nbconvert_exporter": "python",
131 | "pygments_lexer": "ipython3",
132 | "version": "3.6.2"
133 | }
134 | },
135 | "nbformat": 4,
136 | "nbformat_minor": 2
137 | }
138 |
--------------------------------------------------------------------------------
/notebooks/Visualization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": []
7 | },
8 | {
9 | "cell_type": "code",
10 | "execution_count": null,
11 | "metadata": {},
12 | "outputs": [],
13 | "source": [
14 | "import os.path as path\n",
15 | "import numpy as np\n",
16 | "import torch.nn.functional as functional\n",
17 | "from IPython.display import display, clear_output\n",
18 | "from ipywidgets import FloatSlider, interactive, VBox\n",
19 | "import ipywidgets as widgets\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "import sys \n",
23 | "sys.path.append('..')\n",
24 | "from constants import *\n",
25 | "from data_loaders import *\n",
26 | "from model import CapsNet\n",
27 | "\n",
28 | "%matplotlib inline\n",
29 | "\n",
30 | "DEBUG_MODE = False\n",
31 | "USE_GPU = True\n",
32 | "MODEL = \"model577.pt\" # Specifies which model to load\n",
33 | "DATASET = \"small_norb\" # 'mnist', 'small_norb'\n",
34 | "RECONSTRUCTION_TYPE = \"FC\" # 'FC' or 'Conv'"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "Re-run this block to reset the model outputs if you have modified them."
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# Load model\n",
51 | "if DATASET == \"mnist\":\n",
52 | "    capsnet = CapsNet(reconstruction_type=RECONSTRUCTION_TYPE)\n",
53 | " _, test_loader = load_mnist(DEFAULT_BATCH_SIZE)\n",
54 | "if DATASET == \"small_norb\":\n",
55 | "    capsnet = CapsNet(reconstruction_type=RECONSTRUCTION_TYPE, imsize=28, num_classes=5)\n",
56 | " _, test_loader = load_small_norb(DEFAULT_BATCH_SIZE)\n",
57 | "if USE_GPU:\n",
58 | " capsnet.cuda()\n",
59 | "\n",
60 | "model_path = path.join(\"../\", SAVE_DIR, MODEL)\n",
61 | "capsnet.load_state_dict(torch.load(model_path))\n",
62 | "\n",
63 | "capsnet.eval()\n",
64 | "data, target = iter(test_loader).next()\n",
65 | "target = torch.eye(capsnet.num_classes).index_select(dim=0, index=target) # One-hot encode target\n",
66 | "output, reconstruction, masked = capsnet(data.cuda())"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "Here is where you choose which input image to play around with."
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "collapsed": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "i = np.random.randint(DEFAULT_BATCH_SIZE) # index of chosen image in last batch\n",
85 | "capsules = output[i:i+1] # capsules that correspond to this specific image\n",
86 | "\n",
87 | "# Find prediction\n",
88 | "classes = torch.sqrt((capsules**2).sum(2))\n",
89 | "classes = functional.softmax(classes, dim=1)\n",
90 | "_, prediction = classes.max(dim=1)\n",
91 | "\n",
92 | "if DEBUG_MODE:\n",
93 | " print(\"Image:{}\".format(i))\n",
94 | " print(\"Target:{}\".format(target[i:i+1,:].max(dim=1)[1].item()))\n",
95 | " print(\"Prediction:{}\".format(prediction.item()))\n",
96 | " print(capsules[:,prediction,:,:].shape)"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {
103 | "collapsed": true
104 | },
105 | "outputs": [],
106 | "source": [
107 | "# Dirty work here\n",
108 | "# TODO: Fix problems with capsules and prediction as parameters\n",
109 | "def reconstruct(prediction,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15):\n",
110 | " capsules[:,prediction,0,:] = c0\n",
111 | " capsules[:,prediction,1,:] = c1\n",
112 | " capsules[:,prediction,2,:] = c2\n",
113 | " capsules[:,prediction,3,:] = c3\n",
114 | " capsules[:,prediction,4,:] = c4\n",
115 | " capsules[:,prediction,5,:] = c5\n",
116 | " capsules[:,prediction,6,:] = c6\n",
117 | " capsules[:,prediction,7,:] = c7\n",
118 | " capsules[:,prediction,8,:] = c8\n",
119 | " capsules[:,prediction,9,:] = c9\n",
120 | " capsules[:,prediction,10,:] = c10\n",
121 | " capsules[:,prediction,11,:] = c11\n",
122 | " capsules[:,prediction,12,:] = c12\n",
123 | " capsules[:,prediction,13,:] = c13\n",
124 | " capsules[:,prediction,14,:] = c14\n",
125 | " capsules[:,prediction,15,:] = c15\n",
126 | " \n",
127 | "    reconstruction, _ = capsnet.decoder(capsules, target[i:i+1].cuda())\n",
128 | " \n",
129 | " im = np.squeeze(reconstruction.data.cpu().numpy())\n",
130 | " im += abs(im.min())\n",
131 | " im /= im.max()\n",
132 | " plt.subplot(1,2,1)\n",
133 | " plt.title(\"Reconstruction\")\n",
134 | " plt.imshow(im, cmap=\"gray\");\n",
135 | " im2 = data[i, 0].data.cpu().numpy()\n",
136 | " im2 += abs(im.min())\n",
137 | " im2 /= im.max()\n",
138 | " plt.subplot(1,2,2)\n",
139 | " plt.title(\"Input\")\n",
140 | " plt.imshow(im2, cmap=\"gray\");\n",
141 | " \n",
142 | "def build_widgets(capsule_init):\n",
143 | " return interactive(reconstruct,\n",
144 | " prediction=prediction,\n",
145 | " c0=FloatSlider(description=\"Capsule 0\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[0]),\n",
146 | " c1=FloatSlider(description=\"Capsule 1\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[1]),\n",
147 | " c2=FloatSlider(description=\"Capsule 2\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[2]),\n",
148 | " c3=FloatSlider(description=\"Capsule 3\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[3]),\n",
149 | " c4=FloatSlider(description=\"Capsule 4\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[4]),\n",
150 | " c5=FloatSlider(description=\"Capsule 5\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[5]),\n",
151 | " c6=FloatSlider(description=\"Capsule 6\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[6]),\n",
152 | " c7=FloatSlider(description=\"Capsule 7\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[7]),\n",
153 | " c8=FloatSlider(description=\"Capsule 8\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[8]),\n",
154 | " c9=FloatSlider(description=\"Capsule 9\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[9]),\n",
155 | " c10=FloatSlider(description=\"Capsule 10\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[10]),\n",
156 | " c11=FloatSlider(description=\"Capsule 11\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[11]),\n",
157 | " c12=FloatSlider(description=\"Capsule 12\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[12]),\n",
158 | " c13=FloatSlider(description=\"Capsule 13\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[13]),\n",
159 | " c14=FloatSlider(description=\"Capsule 14\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[14]),\n",
160 | " c15=FloatSlider(description=\"Capsule 15\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[15]))"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "The sliders are initialized from the predicted capsule's current values (built in the block below). You can set DEBUG_MODE to True and adjust the parameters while comparing against the model's output vector."
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {
174 | "collapsed": true
175 | },
176 | "outputs": [],
177 | "source": [
178 | "if DEBUG_MODE:\n",
179 | " print(capsules[:,prediction,:,:])"
180 | ]
181 | },
182 | {
183 | "cell_type": "markdown",
184 | "metadata": {},
185 | "source": [
186 | "Re-run this block to reset the capsule values."
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "MIN = -1\n",
196 | "MAX = 1\n",
197 | "STEP = 0.05\n",
198 | "CONTINUOUS_UPDATE = True\n",
199 | "\n",
200 | "# Initial values\n",
201 | "capsule_init = capsules[:,prediction,:,:].squeeze()\n",
202 | "\n",
203 | "w = build_widgets(capsule_init)\n",
204 | "display(w)"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "collapsed": true
212 | },
213 | "outputs": [],
214 | "source": [
215 | "# Experimental improvements for interaction with visualization\n",
216 | "# CURRENTLY NOT WORKING\n",
217 | "\n",
218 | "# def reconstruct(change, prediction, widgets_list):\n",
219 | "# for i, widget in enumerate(widgets_list):\n",
220 | "# capsules[:,prediction,i,:] = widget.value\n",
221 | " \n",
222 | "# reconstruction, _ = capsnet.decoder(capsules, data, target[i:i+1].cuda())\n",
223 | " \n",
224 | "# if DEBUG_MODE:\n",
225 | "# print(capsules)\n",
226 | "# print(target[i:i+1])\n",
227 | "# print(target[i:i+1].max(dim=1)[1].reshape(-1,1))\n",
228 | " \n",
229 | "# im = np.squeeze(reconstruction.data.cpu().numpy())\n",
230 | "# im += abs(im.min())\n",
231 | "# im /= im.max()\n",
232 | "# plt.subplot(1,2,1)\n",
233 | "# plt.title(\"Reconstruction\")\n",
234 | "# plt.imshow(im, cmap=\"gray\");\n",
235 | "# im2 = data[i, 0].data.cpu().numpy()\n",
236 | "# im2 += abs(im.min())\n",
237 | "# im2 /= im.max()\n",
238 | "# plt.subplot(1,2,2)\n",
239 | "# plt.title(\"Input\")\n",
240 | "# plt.imshow(im2, cmap=\"gray\");\n",
241 | "\n",
242 | "# MIN = -1\n",
243 | "# MAX = 1\n",
244 | "# STEP = 1e-1\n",
245 | "# CAPS_COUNT = 16\n",
246 | "# CONTINUOUS_UPDATE = True\n",
247 | "\n",
248 | "# # Credits to building these widgets: https://stackoverflow.com/questions/37622023\n",
249 | "# widgets_list = []\n",
250 | "# for i in range(CAPS_COUNT):\n",
251 | "# widgets_list.append(FloatSlider(description=\"Capsule \"+str(i),\n",
252 | "# min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE))\n",
253 | "# for widget in widgets_list:\n",
254 | "# widget.observe(lambda change:reconstruct(change, prediction, widgets_list))\n",
255 | " \n",
256 | "# w = VBox(children=widgets_list)"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {
263 | "collapsed": true
264 | },
265 | "outputs": [],
266 | "source": []
267 | }
268 | ],
269 | "metadata": {
270 | "kernelspec": {
271 | "display_name": "Python 3",
272 | "language": "python",
273 | "name": "python3"
274 | },
275 | "language_info": {
276 | "codemirror_mode": {
277 | "name": "ipython",
278 | "version": 3
279 | },
280 | "file_extension": ".py",
281 | "mimetype": "text/x-python",
282 | "name": "python",
283 | "nbconvert_exporter": "python",
284 | "pygments_lexer": "ipython3",
285 | "version": "3.6.2"
286 | }
287 | },
288 | "nbformat": 4,
289 | "nbformat_minor": 2
290 | }
291 |
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | from optparse import OptionParser
3 |
4 |
5 | def print_options(options):
6 | print("-"*80)
7 | print("Using options:")
8 | values = vars(options)
9 | for key in values.keys():
10 | print("{:15s} {}".format(key, values[key]))
11 | print("-"*80)
12 |
13 | def log_options(options):
14 | logname = "{}.txt".format(options.model)
15 | log_file = os.path.join(OPTIONS_DIR, logname)
16 | os.makedirs(OPTIONS_DIR, exist_ok=True)
17 |
18 | f = open(log_file, 'w')
19 |
20 | f.write("Using options:\n")
21 | values = vars(options)
22 | for key in values.keys():
23 | f.write("{:15s} {}\n".format(key, values[key]))
24 | f.close()
25 |
26 | def create_options():
27 | parser = OptionParser()
28 | parser.add_option("-l", "--lr", dest="learning_rate", default=DEFAULT_LEARNING_RATE, type="float",
29 | help="learning rate")
30 | parser.add_option("-d","--decoder", dest="decoder", default=DEFAULT_DECODER,
31 | help="Decoder structure 'FC' or 'Conv'")
32 | parser.add_option("-b", "--batch_size", dest="batch_size", default=DEFAULT_BATCH_SIZE, type="int")
33 | parser.add_option("-e", "--epochs", dest="epochs", default=DEFAULT_EPOCHS, type="int",
34 | help="Number of epochs to train for")
35 | parser.add_option("-f", "--file", dest="filepath", default="", type="string",
36 | help="Name of the model to be loaded")
37 |     parser.add_option("-g", "--use_gpu", dest="use_gpu", default=DEFAULT_USE_GPU, action="store_false",
38 |                       help="Pass this flag to disable GPU usage (enabled by default)")
39 |     parser.add_option("--save_images", dest="save_images", default=True, action="store_false",
40 |                       help="Pass this flag to disable saving reconstruction results each epoch (enabled by default)")
41 | parser.add_option("-a", "--alpha", dest="alpha", default=DEFAULT_ALPHA, type="float",
42 | help="Alpha constant from paper (Amount of reconstruction loss)")
43 |     parser.add_option("--dataset", dest="dataset", default=DEFAULT_DATASET, help="Set the dataset to use. Options: [mnist, small_norb, cifar10]")
44 | parser.add_option("-r", "--routing", dest="routing_iterations", default=DEFAULT_ROUTING_ITERATIONS, type="int",
45 | help="Number of routing iterations to use")
46 | parser.add_option("--logfile", dest="log_filepath", default="", type="string",
47 | help="Path to previous logfile if continuing training")
48 |     parser.add_option("--gpu_ids", dest="gpu_ids", default=None, type="str",
49 |                       help="GPU IDs to use when training on multiple GPUs, given as a comma-separated list.")
50 |     parser.add_option("--batch_norm", dest="batch_norm", default=False, type=int,
51 |                       help="Turn batch norm in the encoder/decoder on (1) or off (0)")
52 | parser.add_option("--loss", dest="loss_type", default="L2",
53 | help="Define reconstruction loss. Types: [L1, L2]")
54 | parser.add_option("--anneal", dest="anneal_alpha", default="none",
55 | help="Set annealing function for alpha. Options: [none, 1, 2]")
56 | parser.add_option("--leaky", dest="leaky_routing", default=False, action="store_true",
57 | help="Turn on/off leaky routing (Add orphan class for reconstruction)")
58 | parser.add_option("--model", dest="model", help="Set model name")
59 |
60 |
61 |
62 | options, args = parser.parse_args()
63 | assert options.model is not None, "You have to set a model name with the argument --model"
64 | if options.gpu_ids:
65 | options.gpu_ids = [int(x) for x in options.gpu_ids.split(',')]
66 | print_options(options)
67 | log_options(options)
68 |
69 | return options
70 |
71 |
72 |
73 | if __name__ == '__main__':
74 | options = create_options()
75 |
76 |
77 |
--------------------------------------------------------------------------------
/pictures/capsnet_deconv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/capsnet_deconv.png
--------------------------------------------------------------------------------
/pictures/cifar_reconstruction_epoch_86.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/cifar_reconstruction_epoch_86.png
--------------------------------------------------------------------------------
/pictures/primary_caps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/primary_caps.png
--------------------------------------------------------------------------------
/pictures/rec_visualization.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/rec_visualization.gif
--------------------------------------------------------------------------------
/pictures/reconstruction_epoch_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/reconstruction_epoch_50.png
--------------------------------------------------------------------------------
/pictures/robust_rotation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/robust_rotation.gif
--------------------------------------------------------------------------------
/pictures/smallnorb_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/smallnorb_rec.png
--------------------------------------------------------------------------------
/smallNorb.py:
--------------------------------------------------------------------------------
1 | # Loader taken from https://github.com/mavanb/vision/blob/448fac0f38cab35a387666d553b9d5e4eec4c5e6/torchvision/datasets/utils.py
2 |
3 | from __future__ import print_function
4 | import os
5 | import errno
6 | import struct
7 |
8 | import torch
9 | import torch.utils.data as data
10 | import numpy as np
11 | from PIL import Image
12 | from torchvision.datasets.utils import download_url, check_integrity
13 |
14 |
15 | class SmallNORB(data.Dataset):
16 |     """`smallNORB <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/>`_ Dataset.
17 |     Args:
18 |         root (string): Root directory of dataset where the processed folder
19 |             and the raw folder exist.
20 | train (bool, optional): If True, creates dataset from the training files,
21 | otherwise from the test files.
22 | download (bool, optional): If true, downloads the dataset from the internet and
23 | puts it in root directory. If the dataset is already processed, it is not processed
24 | and downloaded again. If dataset is only already downloaded, it is not
25 | downloaded again.
26 | transform (callable, optional): A function/transform that takes in an PIL image
27 | and returns a transformed version. E.g, ``transforms.RandomCrop``
28 | target_transform (callable, optional): A function/transform that takes in the
29 | target and transforms it.
30 | info_transform (callable, optional): A function/transform that takes in the
31 | info and transforms it.
32 | mode (string, optional): Denotes how the images in the data files are returned. Possible values:
33 | - all (default): both left and right are included separately.
34 | - stereo: left and right images are included as corresponding pairs.
35 | - left: only the left images are included.
36 | - right: only the right images are included.
37 | """
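
    # A minimal usage sketch (assumes the files are downloaded/processed under
    # ./datasets/smallNORB; in "all" mode __getitem__ returns an (image, target) pair):
    #   dataset = SmallNORB("datasets/smallNORB", train=True, download=True, mode="all")
    #   img, target = dataset[0]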
38 |
39 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
40 | data_files = {
41 | 'train': {
42 | 'dat': {
43 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
44 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
45 | "md5": "8138a0902307b32dfa0025a36dfa45ec"
46 | },
47 | 'info': {
48 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
49 | "md5_gz": "51dee1210a742582ff607dfd94e332e3",
50 | "md5": "19faee774120001fc7e17980d6960451"
51 | },
52 | 'cat': {
53 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
54 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
55 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
56 | },
57 | },
58 | 'test': {
59 | 'dat': {
60 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
61 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
62 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
63 | },
64 | 'info': {
65 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
66 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
67 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
68 | },
69 | 'cat': {
70 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
71 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
72 | "md5": "fd5120d3f770ad57ebe620eb61a0b633"
73 | },
74 | },
75 | }
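   |
   | # "md5_gz" verifies each downloaded .gz archive (passed to download_url);
   | # "md5" verifies the extracted .mat file (checked by check_integrity).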
76 |
77 | raw_folder = 'raw'
78 | processed_folder = 'processed'
79 | train_image_file = 'train_img'
80 | train_label_file = 'train_label'
81 | train_info_file = 'train_info'
82 | test_image_file = 'test_img'
83 | test_label_file = 'test_label'
84 | test_info_file = 'test_info'
85 | extension = '.pt'
86 |
87 | def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False,
88 | mode="all"):
89 |
90 | self.root = os.path.expanduser(root)
91 | self.transform = transform
92 | self.target_transform = target_transform
93 | self.info_transform = info_transform
94 | self.train = train # training set or test set
95 | self.mode = mode
96 |
97 | if download:
98 | self.download()
99 |
100 | if not self._check_exists():
101 | raise RuntimeError('Dataset not found or corrupted.' +
102 | ' You can use download=True to download it')
103 |
104 | # load test or train set
105 | image_file = self.train_image_file if self.train else self.test_image_file
106 | label_file = self.train_label_file if self.train else self.test_label_file
107 | info_file = self.train_info_file if self.train else self.test_info_file
108 |
109 | # load labels
110 | self.labels = self._load(label_file)
111 |
112 | # load info files
113 | self.infos = self._load(info_file)
114 |
115 | # load left set
116 | if self.mode == "left":
117 | self.data = self._load("{}_left".format(image_file))
118 |
119 | # load right set
120 | elif self.mode == "right":
121 | self.data = self._load("{}_right".format(image_file))
122 |
123 | elif self.mode == "all" or self.mode == "stereo":
124 | left_data = self._load("{}_left".format(image_file))
125 | right_data = self._load("{}_right".format(image_file))
126 |
127 | # load stereo
128 | if self.mode == "stereo":
129 | self.data = torch.stack((left_data, right_data), dim=1)
130 |
131 | # load all
132 | else:
133 | self.data = torch.cat((left_data, right_data), dim=0)
134 |
135 | def __getitem__(self, index):
136 | """
137 | Args:
138 | index (int): Index
139 | Returns:
140 | mode ``all``, ``left``, ``right``:
141 | tuple: (image, target)
142 | mode ``stereo``:
143 | tuple: (image left, image right, target, info)
144 | """
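   | # In "all" mode, self.data is the left images followed by the right images
   | # (torch.cat in __init__), so labels and infos repeat with period 24300,
   | # the number of stereo pairs per split.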
145 | target = self.labels[index % 24300] if self.mode == "all" else self.labels[index]
146 | if self.target_transform is not None:
147 | target = self.target_transform(target)
148 |
149 | info = self.infos[index % 24300] if self.mode == "all" else self.infos[index]
150 | if self.info_transform is not None:
151 | info = self.info_transform(info)
152 |
153 | if self.mode == "stereo":
154 | img_left = self._transform(self.data[index, 0])
155 | img_right = self._transform(self.data[index, 1])
156 | return img_left, img_right, target, info
157 |
158 | img = self._transform(self.data[index])
159 | return img, target
160 |
161 | def __len__(self):
162 | return len(self.data)
163 |
164 | def _transform(self, img):
165 | # return a PIL Image so that this dataset is consistent with
166 | # all other torchvision datasets
167 | img = Image.fromarray(img.numpy(), mode='L')
168 |
169 | if self.transform is not None:
170 | img = self.transform(img)
171 | return img
172 |
173 | def _load(self, file_name):
174 | return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
175 |
176 | def _save(self, file, file_name):
177 | with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
178 | torch.save(file, f)
179 |
180 | def _check_exists(self):
181 | """Check if the processed files exist."""
182 | files = (
183 | "{}_left".format(self.train_image_file),
184 | "{}_right".format(self.train_image_file),
185 | "{}_left".format(self.test_image_file),
186 | "{}_right".format(self.test_image_file),
187 | self.test_label_file,
188 | self.train_label_file
189 | )
190 | exists = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
191 | return all(exists)
192 |
193 | def _flat_data_files(self):
194 | return [j for i in self.data_files.values() for j in list(i.values())]
195 |
196 | def _check_integrity(self):
197 | """Check if unpacked files have correct md5 sum."""
198 | root = self.root
199 | for file_dict in self._flat_data_files():
200 | filename = file_dict["name"]
201 | md5 = file_dict["md5"]
202 | fpath = os.path.join(root, self.raw_folder, filename)
203 | if not check_integrity(fpath, md5):
204 | return False
205 | return True
206 |
207 | def download(self):
208 | """Download the SmallNORB data if it doesn't exist in processed_folder already."""
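   | # Flow: fetch each *.mat.gz into raw/, extract it to a .mat file, then
   | # parse the binary matrices and save them as .pt tensors in processed/.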
209 | import gzip
210 |
211 | if self._check_exists():
212 | return
213 |
214 | # check if already extracted and verified
215 | if self._check_integrity():
216 | print('Files already downloaded and verified')
217 | else:
218 | # download and extract
219 | for file_dict in self._flat_data_files():
220 | url = self.dataset_root + file_dict["name"] + '.gz'
221 | filename = file_dict["name"]
222 | gz_filename = filename + '.gz'
223 | md5 = file_dict["md5_gz"]
224 | fpath = os.path.join(self.root, self.raw_folder, filename)
225 | gz_fpath = fpath + '.gz'
226 |
227 | # download (skipped by download_url if the archive already exists and verifies)
228 | download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
229 |
230 | print('# Extracting data {}\n'.format(filename))
231 |
232 | with open(fpath, 'wb') as out_f, \
233 | gzip.GzipFile(gz_fpath) as zip_f:
234 | out_f.write(zip_f.read())
235 |
236 | os.unlink(gz_fpath)
237 |
238 | # process and save as torch files
239 | print('Processing...')
240 |
241 | # create processed folder
242 | try:
243 | os.makedirs(os.path.join(self.root, self.processed_folder))
244 | except OSError as e:
245 | if e.errno == errno.EEXIST:
246 | pass
247 | else:
248 | raise
249 |
250 | # read train files
251 | left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
252 | train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
253 | train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
254 |
255 | # read test files
256 | left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
257 | test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
258 | test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
259 |
260 | # save training files
261 | self._save(left_train_img, "{}_left".format(self.train_image_file))
262 | self._save(right_train_img, "{}_right".format(self.train_image_file))
263 | self._save(train_label, self.train_label_file)
264 | self._save(train_info, self.train_info_file)
265 |
266 | # save test files
267 | self._save(left_test_img, "{}_left".format(self.test_image_file))
268 | self._save(right_test_img, "{}_right".format(self.test_image_file))
269 | self._save(test_label, self.test_label_file)
270 | self._save(test_info, self.test_info_file)
271 |
272 | print('Done!')
273 |
274 | @staticmethod
275 | def _parse_header(file_pointer):
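   | # NORB .mat files begin with a header: a 4-byte magic number, the number
   | # of dimensions, then the size of each dimension (little-endian int32s).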
276 | # Read magic number and ignore
277 | struct.unpack('<BBBB', file_pointer.read(4))