├── .gitignore
├── LICENSE
├── README.md
├── constants.py
├── data_loaders.py
├── model.py
├── notebooks
│   ├── Rotation.ipynb
│   ├── StatisticsPlotting.ipynb
│   └── Visualization.ipynb
├── options.py
├── pictures
│   ├── capsnet_deconv.png
│   ├── cifar_reconstruction_epoch_86.png
│   ├── primary_caps.png
│   ├── rec_visualization.gif
│   ├── reconstruction_epoch_50.png
│   ├── robust_rotation.gif
│   └── smallnorb_rec.png
├── smallNorb.py
├── stats.py
├── tools.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | saved_models
2 | .DS_Store
3 | logs
4 | reconstructions
5 | __pycache__
6 | .ipynb_checkpoints
7 | datasets
8 | options/
9 | logs_old/
10 | .vscode
11 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ethan Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet
2 | Capsule networks are a novel approach that has shown promising results on SmallNORB and MNIST. Here we reproduce and build upon the impressive results shown by [Sara Sabour et al.](https://arxiv.org/abs/1710.09829) We experiment with the capsule network architecture by visualizing exactly what the capsules in different layers represent and what information they store about 3D objects in an image, and we try to improve its classification results on CIFAR10 and SmallNORB with various methods, including some tricks with the reconstruction loss. Further, we present a deconvolution-based reconstruction module that reduces the number of learnable parameters by 80% compared to the fully connected module presented by Sara Sabour et al.
3 |
4 | ## Benchmarks
5 |
6 | Our baseline model is the same as in the original paper, but it is trained for only 113 epochs on MNIST, and we did not use a 7-model ensemble for CIFAR10 as was done in the paper.
7 |
8 | |Model | MNIST | SmallNORB | CIFAR10 |
9 | |:-------------|:-------:|:-----------:|:---------:|
10 | |Sabour et al. | 99.75% | 97.3% | 89.40% |
11 | |Baseline | 99.73% | 91.5% | 72.59% |
12 |
13 | ## Experiments
14 |
15 | We introduced a deconvolution-based reconstruction module and experimented with batch normalization and different network topologies.
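For reference, the margin loss (reported as MLoss in the table below) follows Sabour et al. and is implemented in `model.py` as:

```python
def margin_loss(self, x, labels):  # x: [batch, num_classes, 16, 1] class capsules
    batch_size = x.size(0)
    v_c = torch.norm(x, dim=2, keepdim=True)                       # capsule lengths
    left = functional.relu(0.9 - v_c).view(batch_size, -1) ** 2    # present-class term
    right = functional.relu(v_c - 0.1).view(batch_size, -1) ** 2   # absent-class term
    loss = labels * left + 0.5 * (1 - labels) * right              # labels are one-hot
    return loss.sum(dim=1)
```

The total training objective is this margin loss plus `alpha * reconstruction_loss`, with `alpha = 0.0005` by default.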
16 |
17 | ### Deconvolution-based Reconstruction
18 |
19 | The baseline model has 1.4M parameters in the fully connected decoder, while our deconvolution-based reconstruction module reduces the number of learnable parameters by 80%, down to 0.25M.
20 |
21 | ![](pictures/capsnet_deconv.png)
22 |
23 | Here is a comparison between the two reconstruction modules after training for 25 epochs on MNIST, where RLoss is the SSE reconstruction loss and MLoss is the margin loss.
24 |
25 | |Model | RLoss | MLoss | Accuracy |
26 | |:------------|:-------:|:-------:|:----------:|
27 | |FC | 21.62 | 0.0058 | 99.51% |
28 | |FC w/ BN | 13.12 | 0.0054 | 99.54% |
29 | |DeConv | 10.87 | 0.0050 | 99.54% |
30 | |DeConv w/ BN | 9.52 | 0.0044 | 99.55% |
31 |
32 | ## Visualization
33 |
34 | ### Reconstructions
35 |
36 | Here are the reconstruction results for SmallNORB and CIFAR10, after training for 186 epochs and 86 epochs, respectively.
37 |
38 | ![](pictures/smallnorb_rec.png)
39 | ![](pictures/cifar_reconstruction_epoch_86.png)
40 |
41 | ### Robustness to Affine Transformations
42 |
43 | We visualized how the network recognizes a rotated MNIST image when trained only on unmodified MNIST data, using an image of the digit 2 as an example. The network is confident about the result when the image is only slightly rotated, but as the image is rotated further, it starts to confuse the image with other digits. For example, at a certain angle it is very confident that the image is a 7, and it reconstructs a 7 that aligns quite well with the input. Due to its distinctive topology, the digit 2 is still recognized by the network when rotated by 180°.
44 |
45 | ![](pictures/robust_rotation.gif)
46 |
47 | ### Primary Capsule Reconstructions
48 |
49 | We used a pre-trained network to train a reconstruction module for the primary capsules. By scaling these capsules by their routing coefficients to the classified object, we were able to visualize reconstructions from primary capsules. Each row is reconstructed from a single capsule, and the routing coefficient increases from left to right.
50 |
51 | ![](pictures/primary_caps.png)
52 |
53 | ## Usage
54 |
55 | **Step 1. Install requirements**
56 |
57 | * Python 3
58 | * PyTorch 1.0.1
59 | * Torchvision 0.2.1
60 | * TQDM
61 |
62 | **Step 2. Adjust hyperparameters**
63 |
64 | In ```constants.py```:
65 | ```python
66 | DEFAULT_LEARNING_RATE = 0.001
67 | DEFAULT_ALPHA = 0.0005 # Scaling factor for reconstruction loss
68 | DEFAULT_DATASET = "small_norb" # 'mnist', 'small_norb'
69 | DEFAULT_DECODER = "FC" # 'FC' or 'Conv'
70 | DEFAULT_BATCH_SIZE = 128
71 | DEFAULT_EPOCHS = 300
72 | DEFAULT_USE_GPU = True
73 | DEFAULT_ROUTING_ITERATIONS = 3
74 | ```
75 |
76 | **Step 3. Start training**
77 |
78 | Training with default settings:
79 |
80 | ```console
81 | $ python train.py
82 | ```
83 |
84 | Training flags example:
85 |
86 | ```console
87 | $ python train.py --decoder=Conv --file=model32.pt --dataset=mnist
88 | ```
89 |
90 | Further help with training flags:
91 |
92 | ```console
93 | $ python train.py -h
94 | ```
95 |
96 |
97 | **Step 4. Get your results**
98 |
99 | Trained models are saved in the ```saved_models``` directory. TensorBoard logs are saved to ```logs/```. You can launch TensorBoard with
100 |
101 | ```bash
102 | tensorboard --logdir logs
103 | ```
104 |
105 |
106 | ## Future work
107 |
108 | * Fully develop notebooks for visualization and plotting.
109 | * Implement [EM routing](https://openreview.net/pdf?id=HJWLfGWRb).
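## Appendix: decoder parameter counts

As a quick sanity check of the parameter counts quoted above, here is a back-of-the-envelope calculation based on the layer shapes in `model.py` (MNIST, 10 classes, 16-dimensional class capsules):

```python
# FC decoder: 160 -> 512 -> 1024 -> 784, weights plus biases
fc = 160*512 + 512 + 512*1024 + 1024 + 1024*784 + 784   # = 1,411,344 (~1.4M)

# Deconv decoder: FC from 160 to a 10-channel 6x6 grid, then 3 transposed convs
deconv = (160*360 + 360        # nn.Linear(160, 10*6*6)
          + 10*32*9*9 + 32     # nn.ConvTranspose2d(10, 32, kernel_size=9, stride=2)
          + 32*64*9*9 + 64     # nn.ConvTranspose2d(32, 64, kernel_size=9)
          + 64*1*2*2 + 1)      # nn.ConvTranspose2d(64, 1, kernel_size=2)

print(fc, deconv)              # 1411344 250121, roughly the 80% reduction reported above
```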
110 |
111 |
112 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | # Directory to save models
3 | SAVE_DIR = "saved_models"
4 | # Directory to save plots
5 | PLOT_DIR = "plots"
6 | # Directory to save logs
7 | LOG_DIR = "logs"
8 | # Directory to save options
9 | OPTIONS_DIR = "options"
10 | # Directory to save images
11 | IMAGES_SAVE_DIR = "reconstructions"
12 | # Directory to save the smallNORB dataset
13 | SMALL_NORB_PATH = os.path.join("datasets", "smallNORB")
14 |
15 | # Default values for command arguments
16 | DEFAULT_LEARNING_RATE = 0.001
17 | DEFAULT_ANNEAL_TEMPERATURE = 8 # Temperature used when annealing alpha
18 | DEFAULT_ALPHA = 0.0005 # Scaling factor for reconstruction loss
19 | DEFAULT_DATASET = "small_norb" # 'mnist', 'small_norb'
20 | DEFAULT_DECODER = "FC" # 'FC' or 'Conv'
21 | DEFAULT_BATCH_SIZE = 128
22 | DEFAULT_EPOCHS = 300
23 | DEFAULT_USE_GPU = True
24 | DEFAULT_ROUTING_ITERATIONS = 3
25 | DEFAULT_VALIDATION_SIZE = 1000
26 |
27 | # Random seed for validation split
28 | VALIDATION_SEED = 889256487
--------------------------------------------------------------------------------
/data_loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision import datasets, transforms
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 | from constants import *
6 | from smallNorb import SmallNORB
7 |
8 | def build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset):
9 | # Compute validation split
10 | train_size = len(train_dataset)
11 | indices = list(range(train_size))
12 | split = int(np.floor(valid_size * train_size))
13 | np.random.shuffle(indices)
14 | train_idx, valid_idx = indices[split:], indices[:split]
15 | train_sampler = SubsetRandomSampler(train_idx)
16 | valid_sampler = SubsetRandomSampler(valid_idx)
17 |
18 | # Create dataloaders
19 | train_loader = torch.utils.data.DataLoader(train_dataset,
20 | batch_size=batch_size,
21 | sampler=train_sampler)
22 | valid_loader = torch.utils.data.DataLoader(valid_dataset,
23 | batch_size=batch_size,
24 | sampler=valid_sampler)
25 | test_loader = torch.utils.data.DataLoader(test_dataset,
26 | batch_size=batch_size,
27 | shuffle=False)
28 | return train_loader, valid_loader, test_loader
29 |
30 | def load_mnist(batch_size, valid_size=0.1):
31 | train_transform = transforms.Compose([
32 | transforms.RandomAffine(0, translate=[0.08,0.08]),
33 | transforms.ToTensor(),
34 | transforms.Normalize((0.1307,), (0.3081,))
35 | ])
36 | valid_transform = transforms.Compose([
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.1307,), (0.3081,))
39 | ])
40 | test_transform = transforms.Compose([
41 | transforms.ToTensor(),
42 | transforms.Normalize((0.1307,), (0.3081,))
43 | ])
44 |
45 | train_dataset = datasets.MNIST('../data',
46 | train=True,
47 | download=True,
48 | transform=train_transform)
49 | valid_dataset = datasets.MNIST('../data',
50 | train=True,
51 | download=True,
52 | transform=valid_transform)
53 | test_dataset = datasets.MNIST('../data',
54 | train=False,
55 | download=True,
56 | transform=test_transform)
57 |
58 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset)
59 |
60 |
61 |
62 | def load_small_norb(batch_size, valid_size=0.1):
63 | path = SMALL_NORB_PATH
64 | train_transform = transforms.Compose([
65 |
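# SmallNORB augmentation, roughly following Sabour et al.: downsample the
# 96x96 images to 48x48, take random 32x32 crops, and jitter brightness/contrast.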
transforms.Resize(48), 66 | transforms.RandomCrop(32), 67 | transforms.ColorJitter(brightness=32./255, contrast=0.5), 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.0,), (0.3081,)) 70 | ]) 71 | valid_transform = transforms.Compose([ 72 | transforms.Resize(48), 73 | transforms.CenterCrop(32), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.,), (0.3081,)) 76 | ]) 77 | test_transform = transforms.Compose([ 78 | transforms.Resize(48), 79 | transforms.CenterCrop(32), 80 | transforms.ToTensor(), 81 | transforms.Normalize((0.,), (0.3081,)) 82 | ]) 83 | 84 | train_dataset = SmallNORB(path, train=True, download=True, transform=train_transform) 85 | valid_dataset = SmallNORB(path, train=True, download=True, transform=valid_transform) 86 | test_dataset = SmallNORB(path, train=False, transform=test_transform) 87 | 88 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset) 89 | 90 | def load_cifar10(batch_size, valid_size=0.1): 91 | train_transform = transforms.Compose([ 92 | transforms.ColorJitter(brightness=63./255, contrast=0.8), 93 | transforms.RandomHorizontalFlip(), 94 | transforms.ToTensor(), 95 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5)) 96 | ]) 97 | valid_transform = transforms.Compose([ 98 | transforms.ToTensor(), 99 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5)) 100 | ]) 101 | test_transform = transforms.Compose([ 102 | transforms.ToTensor(), 103 | transforms.Normalize((0,0,0), (0.5, 0.5, 0.5)) 104 | ]) 105 | train_dataset = datasets.CIFAR10('../data', 106 | train=True, 107 | download=True, 108 | transform=train_transform) 109 | valid_dataset = datasets.CIFAR10('../data', 110 | train=True, 111 | download=True, 112 | transform=valid_transform) 113 | test_dataset = datasets.CIFAR10('../data', 114 | train=False, 115 | download=True, 116 | transform=test_transform) 117 | 118 | return build_dataloaders(batch_size, valid_size, train_dataset, valid_dataset, test_dataset) 119 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as functional 3 | from tools import squash 4 | import torch 5 | from torch.autograd import Variable 6 | USE_GPU=True 7 | 8 | def routing_algorithm(x, weight, bias, routing_iterations): 9 | """ 10 | x: [batch_size, num_capsules_in, capsule_dim] 11 | weight: [1,num_capsules_in,num_capsules_out,out_channels,in_channels] 12 | bias: [1,1, num_capsules_out, out_channels] 13 | """ 14 | num_capsules_in = x.shape[1] 15 | num_capsules_out = weight.shape[2] 16 | batch_size = x.size(0) 17 | 18 | x = x.unsqueeze(2).unsqueeze(4) 19 | 20 | #[batch_size, 32*6*6, 10, 16] 21 | u_hat = torch.matmul(weight, x).squeeze() 22 | 23 | b_ij = Variable(x.new(batch_size, num_capsules_in, num_capsules_out, 1).zero_()) 24 | 25 | 26 | for it in range(routing_iterations): 27 | c_ij = functional.softmax(b_ij, dim=2) 28 | 29 | # [batch_size, 1, num_classes, capsule_size] 30 | s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + bias 31 | # [batch_size, 1, num_capsules, out_channels] 32 | v_j = squash(s_j, dim=-1) 33 | 34 | if it < routing_iterations - 1: 35 | # [batch-size, 32*6*6, 10, 1] 36 | delta = (u_hat * v_j).sum(dim=-1, keepdim=True) 37 | b_ij = b_ij + delta 38 | 39 | return v_j.squeeze() 40 | 41 | # First Convolutional Layer 42 | class ConvLayer(nn.Module): 43 | def __init__(self, 44 | in_channels=1, 45 | out_channels=256, 46 | kernel_size=9, 47 | batchnorm=False): 
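# Initial feature extractor from Sabour et al.: a single 9x9, stride-1
# convolution to 256 channels followed by ReLU, with optional BatchNorm.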
48 | super(ConvLayer, self).__init__() 49 | 50 | if batchnorm: 51 | self.conv = nn.Sequential( 52 | nn.Conv2d(in_channels=in_channels, 53 | out_channels=out_channels, 54 | kernel_size=kernel_size, 55 | stride=1), 56 | nn.BatchNorm2d(out_channels), 57 | nn.ReLU() 58 | ) 59 | else: 60 | self.conv = nn.Sequential( 61 | nn.Conv2d(in_channels=in_channels, 62 | out_channels=out_channels, 63 | kernel_size=kernel_size, 64 | stride=1), 65 | nn.ReLU() 66 | ) 67 | def forward(self, x): 68 | output = self.conv(x) 69 | return output 70 | 71 | class PrimaryCapules(nn.Module): 72 | 73 | def __init__(self, 74 | num_capsules=32, 75 | in_channels=256, 76 | out_channels=8, 77 | kernel_size=9, 78 | primary_caps_gridsize=6, 79 | batchnorm=False): 80 | 81 | super(PrimaryCapules, self).__init__() 82 | self.gridsize = primary_caps_gridsize 83 | self.num_capsules = num_capsules 84 | if batchnorm: 85 | self.capsules = nn.ModuleList([ 86 | nn.Sequential( 87 | nn.Conv2d(in_channels=in_channels, 88 | out_channels=num_capsules, 89 | kernel_size=kernel_size, 90 | stride=2, 91 | padding=0), 92 | nn.BatchNorm2d(num_capsules) 93 | ) 94 | for i in range(out_channels) 95 | ]) 96 | else: 97 | self.capsules = nn.ModuleList([ 98 | nn.Sequential( 99 | nn.Conv2d(in_channels=in_channels, 100 | out_channels=num_capsules, 101 | kernel_size=kernel_size, 102 | stride=2, 103 | padding=0), 104 | 105 | ) 106 | for i in range(out_channels) 107 | ]) 108 | 109 | def forward(self, x): 110 | output = [caps(x) for caps in self.capsules] 111 | output = torch.stack(output, dim=1) 112 | output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1) 113 | 114 | return squash(output) 115 | 116 | 117 | class ClassCapsules(nn.Module): 118 | 119 | def __init__(self, 120 | num_capsules=10, 121 | num_routes = 32*6*6, 122 | in_channels=8, 123 | out_channels=16, 124 | routing_iterations=3, 125 | leaky=False): 126 | super(ClassCapsules, self).__init__() 127 | 128 | 129 | self.in_channels = in_channels 130 | self.num_routes = num_routes 131 | self.num_capsules = num_capsules 132 | self.routing_iterations = routing_iterations 133 | 134 | self.W = nn.Parameter(torch.rand(1,num_routes,num_capsules,out_channels,in_channels)) 135 | self.bias = nn.Parameter(torch.rand(1,1, num_capsules, out_channels)) 136 | 137 | 138 | # [batch_size, 10, 16, 1] 139 | def forward(self, x): 140 | v_j = routing_algorithm(x, self.W, self.bias, self.routing_iterations) 141 | return v_j.unsqueeze(-1) 142 | 143 | 144 | class ReconstructionModule(nn.Module): 145 | def __init__(self, capsule_size=16, num_capsules=10, imsize=28,img_channel=1, batchnorm=False): 146 | super(ReconstructionModule, self).__init__() 147 | 148 | self.num_capsules = num_capsules 149 | self.capsule_size = capsule_size 150 | self.imsize = imsize 151 | self.img_channel = img_channel 152 | if batchnorm: 153 | self.decoder = nn.Sequential( 154 | nn.Linear(capsule_size*num_capsules, 512), 155 | nn.BatchNorm1d(512), 156 | nn.ReLU(), 157 | nn.Linear(512, 1024), 158 | nn.BatchNorm1d(1024), 159 | nn.ReLU(), 160 | nn.Linear(1024, imsize*imsize*img_channel), 161 | nn.Sigmoid() 162 | ) 163 | else: 164 | self.decoder = nn.Sequential( 165 | nn.Linear(capsule_size*num_capsules, 512), 166 | nn.ReLU(), 167 | nn.Linear(512, 1024), 168 | nn.ReLU(), 169 | nn.Linear(1024, imsize*imsize*img_channel), 170 | nn.Sigmoid() 171 | ) 172 | 173 | def forward(self, x, target=None): 174 | batch_size = x.size(0) 175 | if target is None: 176 | classes = torch.norm(x, dim=2) 177 | max_length_indices = 
classes.max(dim=1)[1].squeeze() 178 | else: 179 | max_length_indices = target.max(dim=1)[1] 180 | 181 | masked = Variable(x.new_tensor(torch.eye(self.num_capsules))) 182 | 183 | masked = masked.index_select(dim=0, index=max_length_indices.data) 184 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1) 185 | 186 | reconstructions = self.decoder(decoder_input) 187 | reconstructions = reconstructions.view(-1, self.img_channel, self.imsize, self.imsize) 188 | return reconstructions, masked 189 | 190 | class ConvReconstructionModule(nn.Module): 191 | def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False): 192 | super(ConvReconstructionModule, self).__init__() 193 | self.num_capsules = num_capsules 194 | self.capsule_size = capsule_size 195 | self.imsize = imsize 196 | self.img_channels = img_channels 197 | self.grid_size = 6 198 | if batchnorm: 199 | self.FC = nn.Sequential( 200 | nn.Linear(capsule_size * num_capsules, num_capsules * (self.grid_size)**2 ), 201 | nn.BatchNorm1d(num_capsules * self.grid_size**2), 202 | nn.ReLU() 203 | ) 204 | self.decoder = nn.Sequential( 205 | nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2), 206 | nn.BatchNorm2d(32), 207 | nn.ReLU(), 208 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1), 209 | nn.BatchNorm2d(64), 210 | nn.ReLU(), 211 | nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1), 212 | nn.Sigmoid() 213 | ) 214 | else: 215 | self.FC = nn.Sequential( 216 | nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ), 217 | nn.ReLU() 218 | ) 219 | self.decoder = nn.Sequential( 220 | nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2), 221 | nn.ReLU(), 222 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1), 223 | nn.ReLU(), 224 | nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1), 225 | nn.Sigmoid() 226 | ) 227 | 228 | def forward(self, x, target=None): 229 | batch_size = x.size(0) 230 | if target is None: 231 | classes = torch.norm(x, dim=2) 232 | max_length_indices = classes.max(dim=1)[1].squeeze() 233 | else: 234 | max_length_indices = target.max(dim=1)[1] 235 | 236 | masked = x.new_tensor(torch.eye(self.num_capsules)) 237 | masked = masked.index_select(dim=0, index=max_length_indices.data) 238 | 239 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1) 240 | decoder_input = self.FC(decoder_input) 241 | decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size) 242 | reconstructions = self.decoder(decoder_input) 243 | reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize) 244 | 245 | return reconstructions, masked 246 | 247 | 248 | 249 | 250 | class SmallNorbConvReconstructionModule(nn.Module): 251 | def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False): 252 | super(SmallNorbConvReconstructionModule, self).__init__() 253 | self.num_capsules = num_capsules 254 | self.capsule_size = capsule_size 255 | self.imsize = imsize 256 | self.img_channels = img_channels 257 | 258 | self.grid_size = 4 259 | 260 | if batchnorm: 261 | self.FC = nn.Sequential( 262 | nn.Linear(capsule_size * num_capsules, num_capsules *self.grid_size*self.grid_size), 263 | nn.BatchNorm1d(num_capsules * self.grid_size**2), 264 | nn.ReLU() 265 | ) 266 | self.decoder = nn.Sequential( 267 | 
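# ConvTranspose2d output size is (in - 1)*stride + kernel_size with no padding,
# so the grid grows 4 -> 15 -> 23 -> 31 -> 32, matching the 32x32 SmallNORB crops.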
nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2), 268 | nn.BatchNorm2d(32), 269 | nn.ReLU(), 270 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1), 271 | nn.BatchNorm2d(64), 272 | nn.ReLU(), 273 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1), 274 | nn.BatchNorm2d(128), 275 | nn.ReLU(), 276 | nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1), 277 | nn.Sigmoid() 278 | ) 279 | else: 280 | self.FC = nn.Sequential( 281 | nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ), 282 | nn.ReLU() 283 | ) 284 | self.decoder = nn.Sequential( 285 | nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2), 286 | nn.ReLU(), 287 | nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1), 288 | nn.ReLU(), 289 | nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1), 290 | nn.ReLU(), 291 | nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1), 292 | nn.Sigmoid() 293 | ) 294 | 295 | def forward(self, x, target=None): 296 | batch_size = x.size(0) 297 | if target is None: 298 | classes = torch.norm(x, dim=2) 299 | max_length_indices = classes.max(dim=1)[1].squeeze() 300 | else: 301 | max_length_indices = target.max(dim=1)[1] 302 | masked = Variable(x.new_tensor(torch.eye(self.num_capsules))) 303 | masked = masked.index_select(dim=0, index=max_length_indices.data) 304 | 305 | decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1) 306 | decoder_input = self.FC(decoder_input) 307 | decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size) 308 | reconstructions = self.decoder(decoder_input) 309 | reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize) 310 | 311 | return reconstructions, masked 312 | 313 | 314 | 315 | 316 | class CapsNet(nn.Module): 317 | 318 | def __init__(self, 319 | reconstruction_type = "FC", 320 | imsize=28, 321 | num_classes=10, 322 | routing_iterations=3, 323 | primary_caps_gridsize=6, 324 | img_channels = 1, 325 | batchnorm = False, 326 | loss = "L2", 327 | num_primary_capsules=32, 328 | leaky_routing = False 329 | ): 330 | super(CapsNet, self).__init__() 331 | self.num_classes = num_classes 332 | if leaky_routing: 333 | num_classes += 1 334 | self.num_classes += 1 335 | 336 | self.imsize=imsize 337 | self.conv_layer = ConvLayer(in_channels=img_channels, batchnorm=batchnorm) 338 | self.leaky_routing = leaky_routing 339 | 340 | self.primary_capsules = PrimaryCapules(primary_caps_gridsize=primary_caps_gridsize, 341 | batchnorm=batchnorm, 342 | num_capsules = num_primary_capsules) 343 | 344 | self.digit_caps = ClassCapsules(num_capsules=num_classes, 345 | num_routes=num_primary_capsules*primary_caps_gridsize*primary_caps_gridsize, 346 | routing_iterations=routing_iterations, 347 | leaky=leaky_routing) 348 | 349 | if reconstruction_type == "FC": 350 | self.decoder = ReconstructionModule(imsize=imsize, 351 | num_capsules=num_classes, 352 | img_channel=img_channels, 353 | batchnorm=batchnorm) 354 | elif reconstruction_type == "Conv32": 355 | self.decoder = SmallNorbConvReconstructionModule(num_capsules=num_classes, 356 | imsize=imsize, 357 | img_channels=img_channels, 358 | batchnorm=batchnorm) 359 | else: 360 | self.decoder = ConvReconstructionModule(num_capsules=num_classes, 361 | imsize=imsize, 362 | img_channels=img_channels, 363 | 
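# any reconstruction_type other than "FC" or "Conv32" falls
# through to this MNIST-sized deconvolution decoder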
batchnorm=batchnorm) 364 | 365 | if loss == "L2": 366 | self.reconstruction_criterion = nn.MSELoss(reduction="none") 367 | if loss == "L1": 368 | self.reconstruction_criterion = nn.L1Loss(reduction="none") 369 | 370 | def forward(self, x, target=None): 371 | output = self.conv_layer(x) 372 | output = self.primary_capsules(output) 373 | output = self.digit_caps(output) 374 | reconstruction, masked = self.decoder(output, target) 375 | 376 | return output, reconstruction, masked 377 | 378 | def loss(self, images, labels, capsule_output, reconstruction, alpha): 379 | marg_loss = self.margin_loss(capsule_output, labels) 380 | rec_loss = self.reconstruction_loss(images, reconstruction) 381 | total_loss = (marg_loss + alpha * rec_loss).mean() 382 | return total_loss, rec_loss.mean(), marg_loss.mean() 383 | 384 | def margin_loss(self, x, labels): 385 | batch_size = x.size(0) 386 | v_c = torch.norm(x, dim=2, keepdim=True) 387 | 388 | left = functional.relu(0.9 - v_c).view(batch_size, -1) ** 2 389 | right = functional.relu(v_c - 0.1).view(batch_size, -1) ** 2 390 | 391 | loss = labels * left + 0.5 *(1-labels)*right 392 | loss = loss.sum(dim=1) 393 | return loss 394 | 395 | def reconstruction_loss(self, data, reconstructions): 396 | batch_size = reconstructions.size(0) 397 | reconstructions = reconstructions.view(batch_size, -1) 398 | data = data.view(batch_size, -1) 399 | loss = self.reconstruction_criterion(reconstructions, data) 400 | loss = loss.sum(dim=1) 401 | return loss 402 | -------------------------------------------------------------------------------- /notebooks/Rotation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from data_loaders import load_mnist\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import os\n", 15 | "from model import * \n", 16 | "import torch\n", 17 | "from PIL import Image\n", 18 | "import torchvision\n", 19 | "from torchvision import datasets, transforms\n", 20 | "import torch\n", 21 | "from constants import * \n", 22 | "import torch.nn.functional as functional\n", 23 | "from tqdm import tqdm\n", 24 | "import imageio" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "def load_mnist(batch_size, rotate=0, scale=1):\n", 36 | " dataset_transform = transforms.Compose([\n", 37 | " transforms.RandomAffine([rotate, rotate+1], scale=[scale, scale]),\n", 38 | " transforms.ToTensor(),\n", 39 | " transforms.Normalize((0.1307,), (0.3081,))\n", 40 | " ])\n", 41 | " \n", 42 | " train_dataset = datasets.MNIST('../data', \n", 43 | " train=True, \n", 44 | " download=True, \n", 45 | " transform=dataset_transform)\n", 46 | " test_dataset = datasets.MNIST('../data', \n", 47 | " train=False, \n", 48 | " download=True, \n", 49 | " transform=dataset_transform)\n", 50 | "\n", 51 | "\n", 52 | " train_loader = torch.utils.data.DataLoader(train_dataset, \n", 53 | " batch_size=batch_size,\n", 54 | " shuffle=True)\n", 55 | " test_loader = torch.utils.data.DataLoader(test_dataset, \n", 56 | " batch_size=batch_size,\n", 57 | " shuffle=False)\n", 58 | " return train_loader, test_loader\n", 59 | "\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | 
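"# Load a CapsNet checkpoint trained with the FC decoder; the filename below is an example, point it at your own model in saved_models/\n",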
"capsnet = CapsNet(reconstruction_type=\"FC\")\n", 69 | "capsnet.load_state_dict(torch.load(\"../saved_models/model36.pt\"))\n", 70 | "capsnet.cuda()\n", 71 | "\"\"" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "iter(load_mnist(20)[1]).next()[1]" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "collapsed": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | ", _" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "j = 1\n", 101 | "for i in tqdm(range(1, 361, 4)):\n", 102 | " _, test_loader = load_mnist(j+1, rotate=0, scale=i/64)\n", 103 | " images, targets = iter(test_loader).next()\n", 104 | "\n", 105 | " target = targets[j].item()\n", 106 | " output, reconstruction, _ = capsnet(images.cuda())\n", 107 | " output = torch.norm(output, dim=2).data.squeeze()\n", 108 | " pred = output.squeeze().max(dim=1)[1][j].item()\n", 109 | " im = images[j, 0].data.cpu().numpy()\n", 110 | " rec = reconstruction[j,0].data.cpu().numpy()\n", 111 | "\n", 112 | " plt.figure(figsize=(20,10))\n", 113 | " plt.subplot(1,3,1)\n", 114 | " plt.title(\"Confidence\")\n", 115 | " plt.ylim([0,1])\n", 116 | " plt.bar(range(0,10), output[j])\n", 117 | " plt.bar(pred, output[j,pred])\n", 118 | " plt.xticks(range(10))\n", 119 | " plt.subplot(1,3,2)\n", 120 | " plt.title(\"Input Image\")\n", 121 | " plt.axis('off')\n", 122 | " plt.imshow(im, cmap=\"gray\")\n", 123 | " plt.subplot(1,3,3)\n", 124 | " plt.title(\"Reconstructed Image\")\n", 125 | " plt.axis('off') \n", 126 | " plt.imshow(rec, cmap=\"gray\")\n", 127 | " plt.savefig(\"rotation/test{}.png\".format(i))\n", 128 | "\n", 129 | "\"\"\"\n", 130 | "fig = plt.figure()\n", 131 | "plt.subplot(1,2,1)\n", 132 | "plt.bar(range(0,10), output[j])\n", 133 | "pred = output[j].max(dim=0)[1].item()\n", 134 | "plt.bar(pred, output[j][pred])\n", 135 | "plt.xticks(range(0,10))\n", 136 | "plt.subplot(1,2,2)\n", 137 | "plt.imshow(im, cmap=\"gray\")\n", 138 | "plt.savefig(\"test.png\")\n", 139 | "\"\"\"" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "from tqdm import trange\n", 149 | "images = []\n", 150 | "for i in trange(1,361,4):\n", 151 | " images.append(imageio.imread(\"rotation/test{}.png\".format(i)))\n", 152 | "imageio.mimsave('./movie.gif', images)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "j = 1\n", 162 | "confidences_correct = []\n", 163 | "confidences_correct_i = []\n", 164 | "confidences_false = []\n", 165 | "confidences_false_i = []\n", 166 | "for i in tqdm(range(0, 360, 2)):\n", 167 | " _, test_loader = load_mnist(j+1, rotate=i)\n", 168 | " images, targets = iter(test_loader).next()\n", 169 | "\n", 170 | " target = targets[j].item()\n", 171 | " output, reconstruction, _ = capsnet(images.cuda())\n", 172 | " output = torch.norm(output, dim=2)\n", 173 | " pred = output.squeeze().max(dim=1)[1][j].item()\n", 174 | " \n", 175 | " if pred == target:\n", 176 | " confidences_correct.append(output[j,target,0].item())\n", 177 | " confidences_correct_i.append(i)\n", 178 | " else:\n", 179 | " confidences_false.append(output[j,target,0].item())\n", 180 | " confidences_false_i.append(i)\n", 181 | " \n", 182 | "# Show Image\n", 183 | "_, 
test_loader = load_mnist(j+1, rotate=0)\n", 184 | "images, targets = iter(test_loader).next()\n", 185 | "im = images[j, 0].data.numpy()\n", 186 | "plt.imshow(im, cmap=\"gray\")\n", 187 | "\n", 188 | "# Print graph\n", 189 | "print(targets[j])\n" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "_, test_loader = load_mnist(1+1, rotate=0)\n", 199 | "images, targets = iter(test_loader).next()\n", 200 | "im = images[1, 0].data.numpy()\n", 201 | "plt.axis('off')\n", 202 | "plt.imshow(im, cmap=\"gray\")\n" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "plt.figure(figsize=(20,10))\n", 214 | "plt.plot(confidences_correct_i, confidences_correct, '.')\n", 215 | "plt.plot(confidences_false_i, confidences_false, '.')\n", 216 | "plt.xlabel(\"Rotation degrees\")\n", 217 | "plt.ylabel(\"Confidence\")\n", 218 | "plt.xlim([0,360])\n", 219 | "plt.ylim([0,1])" 220 | ] 221 | } 222 | ], 223 | "metadata": { 224 | "kernelspec": { 225 | "display_name": "Python 3", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 3 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython3", 239 | "version": "3.6.2" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 2 244 | } 245 | -------------------------------------------------------------------------------- /notebooks/StatisticsPlotting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import os\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import numpy as np\n", 15 | "%matplotlib inline\n", 16 | "LOG_DIR = \"../logs\"\n", 17 | "filename= \"log-1528586274.9812639.txt\"\n", 18 | "path = os.path.join(LOG_DIR, filename)\n", 19 | "data = pd.read_csv(path,skipinitialspace=True)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "data[-10:]" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "plt.figure(figsize=(20,10))\n", 38 | "plt.title(\"Loss\")\n", 39 | "# plt.xlim([0, 80])\n", 40 | "# plt.ylim([0.0, 100.0])\n", 41 | "plt.plot(data.reconstruction_loss_test[0:], '--o', label='Test reconstruction loss')\n", 42 | "plt.plot(data.reconstruction_loss_train[0:], '--o', label='Train reconstruction loss')\n", 43 | "plt.legend()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "plt.figure(figsize=(20,10))\n", 53 | "plt.title(\"Loss\")\n", 54 | "# plt.xlim([0,80])\n", 55 | "plt.plot(data.test_loss[0:], '--o', label='Test loss')\n", 56 | "plt.plot(data.train_loss[0:], '--o', label='Train loss')\n", 57 | "plt.legend()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "plt.figure(figsize=(20,10))\n", 67 | "plt.title(\"Loss\")\n", 68 | "# 
plt.xlim([0,80])\n", 69 | "plt.plot(data.margin_loss_test[0:], '--o', label='Test loss')\n", 70 | "plt.plot(data.margin_loss_train[0:], '--o', label='Train loss')\n", 71 | "plt.legend()" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "[len(data.test_accuracy),len(data.test_accuracy)]" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "plt.figure(figsize=(20,10))\n", 90 | "plt.xlim([0, len(data.test_accuracy)])\n", 91 | "plt.plot([0,len(data.test_accuracy)], [99.5, 99.5], '--')\n", 92 | "plt.plot(data.test_accuracy[0:], '--o')" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "plt.figure(figsize=(20,10))\n", 102 | "plt.plot(data.time[0:])\n", 103 | "# plt.ylim([100,110])" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.2" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 2 137 | } 138 | -------------------------------------------------------------------------------- /notebooks/Visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [] 7 | }, 8 | { 9 | "cell_type": "code", 10 | "execution_count": null, 11 | "metadata": {}, 12 | "outputs": [], 13 | "source": [ 14 | "import os.path as path\n", 15 | "import numpy as np\n", 16 | "import torch.nn.functional as functional\n", 17 | "from IPython.display import display, clear_output\n", 18 | "from ipywidgets import FloatSlider, interactive, VBox\n", 19 | "import ipywidgets as widgets\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "\n", 22 | "import sys \n", 23 | "sys.path.append('..')\n", 24 | "from constants import *\n", 25 | "from data_loaders import *\n", 26 | "from model import CapsNet\n", 27 | "\n", 28 | "%matplotlib inline\n", 29 | "\n", 30 | "DEBUG_MODE = False\n", 31 | "USE_GPU = True\n", 32 | "MODEL = \"model577.pt\" # Specifies which model to load\n", 33 | "DATASET = \"small_norb\" # 'mnist', 'small_norb'\n", 34 | "RECONSTRUCTION_TYPE = \"FC\" # 'FC' or 'Conv'" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Re-run this block to reset your model outputs if you messed it up." 
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# Load model\n",
51 | "if DATASET == \"mnist\":\n",
52 | " capsnet = CapsNet(reconstruction_type=RECONSTRUCTION_TYPE)\n",
53 | " _, test_loader = load_mnist(DEFAULT_BATCH_SIZE)\n",
54 | "if DATASET == \"small_norb\":\n",
55 | " capsnet = CapsNet(reconstruction_type=RECONSTRUCTION_TYPE, imsize=28, num_classes=5)\n",
56 | " _, test_loader = load_small_norb(DEFAULT_BATCH_SIZE)\n",
57 | "if USE_GPU:\n",
58 | " capsnet.cuda()\n",
59 | "\n",
60 | "model_path = path.join(\"../\", SAVE_DIR, MODEL)\n",
61 | "capsnet.load_state_dict(torch.load(model_path))\n",
62 | "\n",
63 | "capsnet.eval()\n",
64 | "data, target = iter(test_loader).next()\n",
65 | "target = torch.eye(capsnet.num_classes).index_select(dim=0, index=target) # One-hot encode target\n",
66 | "output, reconstruction, masked = capsnet(data.cuda())"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "Here is where you choose which input image to play around with."
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "collapsed": true
81 | },
82 | "outputs": [],
83 | "source": [
84 | "i = np.random.randint(DEFAULT_BATCH_SIZE) # index of chosen image in last batch\n",
85 | "capsules = output[i:i+1] # capsules that correspond to this specific image\n",
86 | "\n",
87 | "# Find prediction\n",
88 | "classes = torch.sqrt((capsules**2).sum(2))\n",
89 | "classes = functional.softmax(classes, dim=1)\n",
90 | "_, prediction = classes.max(dim=1)\n",
91 | "\n",
92 | "if DEBUG_MODE:\n",
93 | " print(\"Image:{}\".format(i))\n",
94 | " print(\"Target:{}\".format(target[i:i+1,:].max(dim=1)[1].item()))\n",
95 | " print(\"Prediction:{}\".format(prediction.item()))\n",
96 | " print(capsules[:,prediction,:,:].shape)"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {
103 | "collapsed": true
104 | },
105 | "outputs": [],
106 | "source": [
107 | "# Dirty work here\n",
108 | "# TODO: Fix problems with capsules and prediction as parameters\n",
109 | "def reconstruct(prediction,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15):\n",
110 | " capsules[:,prediction,0,:] = c0\n",
111 | " capsules[:,prediction,1,:] = c1\n",
112 | " capsules[:,prediction,2,:] = c2\n",
113 | " capsules[:,prediction,3,:] = c3\n",
114 | " capsules[:,prediction,4,:] = c4\n",
115 | " capsules[:,prediction,5,:] = c5\n",
116 | " capsules[:,prediction,6,:] = c6\n",
117 | " capsules[:,prediction,7,:] = c7\n",
118 | " capsules[:,prediction,8,:] = c8\n",
119 | " capsules[:,prediction,9,:] = c9\n",
120 | " capsules[:,prediction,10,:] = c10\n",
121 | " capsules[:,prediction,11,:] = c11\n",
122 | " capsules[:,prediction,12,:] = c12\n",
123 | " capsules[:,prediction,13,:] = c13\n",
124 | " capsules[:,prediction,14,:] = c14\n",
125 | " capsules[:,prediction,15,:] = c15\n",
126 | " \n",
127 | " reconstruction, _ = capsnet.decoder(capsules, data, target[i:i+1].cuda())\n",
128 | " \n",
129 | " im = np.squeeze(reconstruction.data.cpu().numpy())\n",
130 | " im += abs(im.min())\n",
131 | " im /= im.max()\n",
132 | " plt.subplot(1,2,1)\n",
133 | " plt.title(\"Reconstruction\")\n",
134 | " plt.imshow(im, cmap=\"gray\");\n",
135 | " im2 = data[i, 0].data.cpu().numpy()\n",
136 | " im2 += abs(im2.min())\n",
137 | " im2 /= im2.max()\n",
138 | " plt.subplot(1,2,2)\n",
139 | " plt.title(\"Input\")\n",
140 | "
plt.imshow(im2, cmap=\"gray\");\n", 141 | " \n", 142 | "def build_widgets(capsule_init):\n", 143 | " return interactive(reconstruct,\n", 144 | " prediction=prediction,\n", 145 | " c0=FloatSlider(description=\"Capsule 0\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[0]),\n", 146 | " c1=FloatSlider(description=\"Capsule 1\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[1]),\n", 147 | " c2=FloatSlider(description=\"Capsule 2\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[2]),\n", 148 | " c3=FloatSlider(description=\"Capsule 3\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[3]),\n", 149 | " c4=FloatSlider(description=\"Capsule 4\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[4]),\n", 150 | " c5=FloatSlider(description=\"Capsule 5\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[5]),\n", 151 | " c6=FloatSlider(description=\"Capsule 6\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[6]),\n", 152 | " c7=FloatSlider(description=\"Capsule 7\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[7]),\n", 153 | " c8=FloatSlider(description=\"Capsule 8\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[8]),\n", 154 | " c9=FloatSlider(description=\"Capsule 9\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[9]),\n", 155 | " c10=FloatSlider(description=\"Capsule 10\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[10]),\n", 156 | " c11=FloatSlider(description=\"Capsule 11\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[11]),\n", 157 | " c12=FloatSlider(description=\"Capsule 12\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[12]),\n", 158 | " c13=FloatSlider(description=\"Capsule 13\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[13]),\n", 159 | " c14=FloatSlider(description=\"Capsule 14\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[14]),\n", 160 | " c15=FloatSlider(description=\"Capsule 15\",min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE, value=capsule_init[15]))" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Currently all sliders are initialized to zeros, which means the initial reconstruction is not correct at all. You can set debug mode to true, and adjust the parameters according to the model output vector." 
168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "if DEBUG_MODE:\n", 179 | " print(capsules[:,prediction,:,:])" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "Re-run this block to reset capsule" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "MIN = -1\n", 196 | "MAX = 1\n", 197 | "STEP = 0.05\n", 198 | "CONTINUOUS_UPDATE = True\n", 199 | "\n", 200 | "# Initial values\n", 201 | "capsule_init = capsules[:,prediction,:,:].squeeze()\n", 202 | "\n", 203 | "w = build_widgets(capsule_init)\n", 204 | "display(w)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "collapsed": true 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "# Experimental improvements for interaction with visualization\n", 216 | "# CURRENTLY NOT WORKING\n", 217 | "\n", 218 | "# def reconstruct(change, prediction, widgets_list):\n", 219 | "# for i, widget in enumerate(widgets_list):\n", 220 | "# capsules[:,prediction,i,:] = widget.value\n", 221 | " \n", 222 | "# reconstruction, _ = capsnet.decoder(capsules, data, target[i:i+1].cuda())\n", 223 | " \n", 224 | "# if DEBUG_MODE:\n", 225 | "# print(capsules)\n", 226 | "# print(target[i:i+1])\n", 227 | "# print(target[i:i+1].max(dim=1)[1].reshape(-1,1))\n", 228 | " \n", 229 | "# im = np.squeeze(reconstruction.data.cpu().numpy())\n", 230 | "# im += abs(im.min())\n", 231 | "# im /= im.max()\n", 232 | "# plt.subplot(1,2,1)\n", 233 | "# plt.title(\"Reconstruction\")\n", 234 | "# plt.imshow(im, cmap=\"gray\");\n", 235 | "# im2 = data[i, 0].data.cpu().numpy()\n", 236 | "# im2 += abs(im.min())\n", 237 | "# im2 /= im.max()\n", 238 | "# plt.subplot(1,2,2)\n", 239 | "# plt.title(\"Input\")\n", 240 | "# plt.imshow(im2, cmap=\"gray\");\n", 241 | "\n", 242 | "# MIN = -1\n", 243 | "# MAX = 1\n", 244 | "# STEP = 1e-1\n", 245 | "# CAPS_COUNT = 16\n", 246 | "# CONTINUOUS_UPDATE = True\n", 247 | "\n", 248 | "# # Credits to building these widgets: https://stackoverflow.com/questions/37622023\n", 249 | "# widgets_list = []\n", 250 | "# for i in range(CAPS_COUNT):\n", 251 | "# widgets_list.append(FloatSlider(description=\"Capsule \"+str(i),\n", 252 | "# min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE))\n", 253 | "# for widget in widgets_list:\n", 254 | "# widget.observe(lambda change:reconstruct(change, prediction, widgets_list))\n", 255 | " \n", 256 | "# w = VBox(children=widgets_list)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": { 263 | "collapsed": true 264 | }, 265 | "outputs": [], 266 | "source": [] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Python 3", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.6.2" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 2 290 | } 291 | -------------------------------------------------------------------------------- /options.py: 
--------------------------------------------------------------------------------
1 | from constants import *
2 | from optparse import OptionParser
3 |
4 |
5 | def print_options(options):
6 | print("-"*80)
7 | print("Using options:")
8 | values = vars(options)
9 | for key in values.keys():
10 | print("{:15s} {}".format(key, values[key]))
11 | print("-"*80)
12 |
13 | def log_options(options):
14 | logname = "{}.txt".format(options.model)
15 | log_file = os.path.join(OPTIONS_DIR, logname)
16 | os.makedirs(OPTIONS_DIR, exist_ok=True)
17 |
18 | f = open(log_file, 'w')
19 |
20 | f.write("Using options:\n")
21 | values = vars(options)
22 | for key in values.keys():
23 | f.write("{:15s} {}\n".format(key, values[key]))
24 | f.close()
25 |
26 | def create_options():
27 | parser = OptionParser()
28 | parser.add_option("-l", "--lr", dest="learning_rate", default=DEFAULT_LEARNING_RATE, type="float",
29 | help="Learning rate")
30 | parser.add_option("-d","--decoder", dest="decoder", default=DEFAULT_DECODER,
31 | help="Decoder structure: 'FC' or 'Conv'")
32 | parser.add_option("-b", "--batch_size", dest="batch_size", default=DEFAULT_BATCH_SIZE, type="int")
33 | parser.add_option("-e", "--epochs", dest="epochs", default=DEFAULT_EPOCHS, type="int",
34 | help="Number of epochs to train for")
35 | parser.add_option("-f", "--file", dest="filepath", default="", type="string",
36 | help="Name of the model to be loaded")
37 | parser.add_option("-g", "--use_gpu", dest="use_gpu", default=DEFAULT_USE_GPU, action="store_false",
38 | help="Pass this flag to disable the GPU; training uses the GPU by default")
39 | parser.add_option("--save_images", dest="save_images", default=True, action="store_false",
40 | help="Pass this flag to disable saving reconstruction images each epoch")
41 | parser.add_option("-a", "--alpha", dest="alpha", default=DEFAULT_ALPHA, type="float",
42 | help="Alpha constant from the paper (weight of the reconstruction loss)")
43 | parser.add_option("--dataset", dest="dataset", default=DEFAULT_DATASET, help="Dataset to train on. Options: [mnist, small_norb, cifar10]")
44 | parser.add_option("-r", "--routing", dest="routing_iterations", default=DEFAULT_ROUTING_ITERATIONS, type="int",
45 | help="Number of routing iterations to use")
46 | parser.add_option("--logfile", dest="log_filepath", default="", type="string",
47 | help="Path to previous logfile if continuing training")
48 | parser.add_option("--gpu_ids", dest="gpu_ids", default=None, type="str",
49 | help="GPU IDs to use when training on multiple GPUs, given as a comma-separated list")
50 | parser.add_option("--batch_norm", dest="batch_norm", default=False, type=int,
51 | help="Turn batch norm in the encoder/decoder on (1) or off (0)")
52 | parser.add_option("--loss", dest="loss_type", default="L2",
53 | help="Reconstruction loss type. Options: [L1, L2]")
54 | parser.add_option("--anneal", dest="anneal_alpha", default="none",
55 | help="Set annealing function for alpha.
Options: [none, 1, 2]") 56 | parser.add_option("--leaky", dest="leaky_routing", default=False, action="store_true", 57 | help="Turn on/off leaky routing (Add orphan class for reconstruction)") 58 | parser.add_option("--model", dest="model", help="Set model name") 59 | 60 | 61 | 62 | options, args = parser.parse_args() 63 | assert options.model is not None, "You have to set a model name with the argument --model" 64 | if options.gpu_ids: 65 | options.gpu_ids = [int(x) for x in options.gpu_ids.split(',')] 66 | print_options(options) 67 | log_options(options) 68 | 69 | return options 70 | 71 | 72 | 73 | if __name__ == '__main__': 74 | options = create_options() 75 | 76 | 77 | -------------------------------------------------------------------------------- /pictures/capsnet_deconv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/capsnet_deconv.png -------------------------------------------------------------------------------- /pictures/cifar_reconstruction_epoch_86.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/cifar_reconstruction_epoch_86.png -------------------------------------------------------------------------------- /pictures/primary_caps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/primary_caps.png -------------------------------------------------------------------------------- /pictures/rec_visualization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/rec_visualization.gif -------------------------------------------------------------------------------- /pictures/reconstruction_epoch_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/reconstruction_epoch_50.png -------------------------------------------------------------------------------- /pictures/robust_rotation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/robust_rotation.gif -------------------------------------------------------------------------------- /pictures/smallnorb_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethanleet/CapsNet/ac29d616e8f307037c65848864c1a8febb3aa6b7/pictures/smallnorb_rec.png -------------------------------------------------------------------------------- /smallNorb.py: -------------------------------------------------------------------------------- 1 | # Loader taken from https://github.com/mavanb/vision/blob/448fac0f38cab35a387666d553b9d5e4eec4c5e6/torchvision/datasets/utils.py 2 | 3 | from __future__ import print_function 4 | import os 5 | import errno 6 | import struct 7 | 8 | import torch 9 | import torch.utils.data as data 10 | import numpy as np 11 | from PIL import Image 12 | from torchvision.datasets.utils import download_url, check_integrity 13 | 14 | 15 | class 
SmallNORB(data.Dataset): 16 | """`MNIST `_ Dataset. 17 | Args: 18 | root (string): Root directory of dataset where processed folder and 19 | and raw folder exist. 20 | train (bool, optional): If True, creates dataset from the training files, 21 | otherwise from the test files. 22 | download (bool, optional): If true, downloads the dataset from the internet and 23 | puts it in root directory. If the dataset is already processed, it is not processed 24 | and downloaded again. If dataset is only already downloaded, it is not 25 | downloaded again. 26 | transform (callable, optional): A function/transform that takes in an PIL image 27 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 28 | target_transform (callable, optional): A function/transform that takes in the 29 | target and transforms it. 30 | info_transform (callable, optional): A function/transform that takes in the 31 | info and transforms it. 32 | mode (string, optional): Denotes how the images in the data files are returned. Possible values: 33 | - all (default): both left and right are included separately. 34 | - stereo: left and right images are included as corresponding pairs. 35 | - left: only the left images are included. 36 | - right: only the right images are included. 37 | """ 38 | 39 | dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/" 40 | data_files = { 41 | 'train': { 42 | 'dat': { 43 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat', 44 | "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2", 45 | "md5": "8138a0902307b32dfa0025a36dfa45ec" 46 | }, 47 | 'info': { 48 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat', 49 | "md5_gz": "51dee1210a742582ff607dfd94e332e3", 50 | "md5": "19faee774120001fc7e17980d6960451" 51 | }, 52 | 'cat': { 53 | "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat', 54 | "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9", 55 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 56 | }, 57 | }, 58 | 'test': { 59 | 'dat': { 60 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat', 61 | "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071", 62 | "md5": "e9920b7f7b2869a8f1a12e945b2c166c" 63 | }, 64 | 'info': { 65 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat', 66 | "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e", 67 | "md5": "7c5b871cc69dcadec1bf6a18141f5edc" 68 | }, 69 | 'cat': { 70 | "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat', 71 | "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603", 72 | "md5": "fd5120d3f770ad57ebe620eb61a0b633" 73 | }, 74 | }, 75 | } 76 | 77 | raw_folder = 'raw' 78 | processed_folder = 'processed' 79 | train_image_file = 'train_img' 80 | train_label_file = 'train_label' 81 | train_info_file = 'train_info' 82 | test_image_file = 'test_img' 83 | test_label_file = 'test_label' 84 | test_info_file = 'test_info' 85 | extension = '.pt' 86 | 87 | def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False, 88 | mode="all"): 89 | 90 | self.root = os.path.expanduser(root) 91 | self.transform = transform 92 | self.target_transform = target_transform 93 | self.info_transform = info_transform 94 | self.train = train # training set or test set 95 | self.mode = mode 96 | 97 | if download: 98 | self.download() 99 | 100 | if not self._check_exists(): 101 | raise RuntimeError('Dataset not found or corrupted.' 
103 | 
104 |         # load test or train set
105 |         image_file = self.train_image_file if self.train else self.test_image_file
106 |         label_file = self.train_label_file if self.train else self.test_label_file
107 |         info_file = self.train_info_file if self.train else self.test_info_file
108 | 
109 |         # load labels
110 |         self.labels = self._load(label_file)
111 | 
112 |         # load info files
113 |         self.infos = self._load(info_file)
114 | 
115 |         # load left set
116 |         if self.mode == "left":
117 |             self.data = self._load("{}_left".format(image_file))
118 | 
119 |         # load right set
120 |         elif self.mode == "right":
121 |             self.data = self._load("{}_right".format(image_file))
122 | 
123 |         elif self.mode == "all" or self.mode == "stereo":
124 |             left_data = self._load("{}_left".format(image_file))
125 |             right_data = self._load("{}_right".format(image_file))
126 | 
127 |             # load stereo
128 |             if self.mode == "stereo":
129 |                 self.data = torch.stack((left_data, right_data), dim=1)
130 | 
131 |             # load all
132 |             else:
133 |                 self.data = torch.cat((left_data, right_data), dim=0)
134 | 
135 |     def __getitem__(self, index):
136 |         """
137 |         Args:
138 |             index (int): Index
139 |         Returns:
140 |             modes "all", "left", "right":
141 |                 tuple: (image, target)
142 |             mode "stereo":
143 |                 tuple: (image left, image right, target, info)
144 |         """
145 |         target = self.labels[index % 24300] if self.mode == "all" else self.labels[index]
146 |         if self.target_transform is not None:
147 |             target = self.target_transform(target)
148 | 
149 |         info = self.infos[index % 24300] if self.mode == "all" else self.infos[index]
150 |         if self.info_transform is not None:
151 |             info = self.info_transform(info)
152 | 
153 |         if self.mode == "stereo":
154 |             img_left = self._transform(self.data[index, 0])
155 |             img_right = self._transform(self.data[index, 1])
156 |             return img_left, img_right, target, info
157 | 
158 |         img = self._transform(self.data[index])
159 |         return img, target
160 | 
161 |     def __len__(self):
162 |         return len(self.data)
163 | 
164 |     def _transform(self, img):
165 |         # convert to a PIL Image so that this dataset is consistent
166 |         # with all the other datasets
167 |         img = Image.fromarray(img.numpy(), mode='L')
168 | 
169 |         if self.transform is not None:
170 |             img = self.transform(img)
171 |         return img
172 | 
173 |     def _load(self, file_name):
174 |         return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))
175 | 
176 |     def _save(self, file, file_name):
177 |         with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
178 |             torch.save(file, f)
179 | 
180 |     def _check_exists(self):
181 |         """Check if processed files exist."""
182 |         files = (
183 |             "{}_left".format(self.train_image_file),
184 |             "{}_right".format(self.train_image_file),
185 |             "{}_left".format(self.test_image_file),
186 |             "{}_right".format(self.test_image_file),
187 |             self.test_label_file,
188 |             self.train_label_file
189 |         )
190 |         fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
191 |         return all(fpaths)
192 | 
193 |     def _flat_data_files(self):
194 |         return [j for i in self.data_files.values() for j in list(i.values())]
195 | 
196 |     def _check_integrity(self):
197 |         """Check if unpacked files have correct md5 sum."""
198 |         root = self.root
199 |         for file_dict in self._flat_data_files():
200 |             filename = file_dict["name"]
201 |             md5 = file_dict["md5"]
202 |             fpath = os.path.join(root, self.raw_folder, filename)
203 |             if not check_integrity(fpath, md5):
204 |                 return False
205 |         return True
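
    # ------------------------------------------------------------------
    # Usage sketch (editorial note, not part of the original file): a
    # minimal way to load this dataset, assuming the repository-level
    # "datasets" directory; the transform pipeline is illustrative only.
    #
    #     from torchvision import transforms
    #     from smallNorb import SmallNORB
    #
    #     train_set = SmallNORB("./datasets", train=True, download=True,
    #                           mode="all",
    #                           transform=transforms.Compose([
    #                               transforms.Resize(48),
    #                               transforms.ToTensor(),
    #                           ]))
    #     img, target = train_set[0]  # "all"/"left"/"right" yield (image, target)
    # ------------------------------------------------------------------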
206 | 
207 |     def download(self):
208 |         """Download the SmallNORB data if it doesn't exist in processed_folder already."""
209 |         import gzip
210 | 
211 |         if self._check_exists():
212 |             return
213 | 
214 |         # check if already extracted and verified
215 |         if self._check_integrity():
216 |             print('Files already downloaded and verified')
217 |         else:
218 |             # download and extract
219 |             for file_dict in self._flat_data_files():
220 |                 url = self.dataset_root + file_dict["name"] + '.gz'
221 |                 filename = file_dict["name"]
222 |                 gz_filename = filename + '.gz'
223 |                 md5 = file_dict["md5_gz"]
224 |                 fpath = os.path.join(self.root, self.raw_folder, filename)
225 |                 gz_fpath = fpath + '.gz'
226 | 
227 |                 # download the compressed file unless it already exists with a valid checksum
228 |                 download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)
229 | 
230 |                 print('# Extracting data {}\n'.format(filename))
231 | 
232 |                 with open(fpath, 'wb') as out_f, \
233 |                         gzip.GzipFile(gz_fpath) as zip_f:
234 |                     out_f.write(zip_f.read())
235 | 
236 |                 os.unlink(gz_fpath)
237 | 
238 |         # process and save as torch files
239 |         print('Processing...')
240 | 
241 |         # create processed folder
242 |         try:
243 |             os.makedirs(os.path.join(self.root, self.processed_folder))
244 |         except OSError as e:
245 |             if e.errno == errno.EEXIST:
246 |                 pass
247 |             else:
248 |                 raise
249 | 
250 |         # read train files
251 |         left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
252 |         train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
253 |         train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])
254 | 
255 |         # read test files
256 |         left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
257 |         test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
258 |         test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])
259 | 
260 |         # save training files
261 |         self._save(left_train_img, "{}_left".format(self.train_image_file))
262 |         self._save(right_train_img, "{}_right".format(self.train_image_file))
263 |         self._save(train_label, self.train_label_file)
264 |         self._save(train_info, self.train_info_file)
265 | 
266 |         # save test files
267 |         self._save(left_test_img, "{}_left".format(self.test_image_file))
268 |         self._save(right_test_img, "{}_right".format(self.test_image_file))
269 |         self._save(test_label, self.test_label_file)
270 |         self._save(test_info, self.test_info_file)
271 | 
272 |         print('Done!')
273 | 
274 |     @staticmethod
275 |     def _parse_header(file_pointer):
276 |         # Read magic number and ignore
277 |         struct.unpack('