├── cupy ├── __init__.py ├── .gitignore ├── __pycache__ │ ├── mnist.cpython-35.pyc │ ├── optim.cpython-35.pyc │ └── modules.cpython-35.pyc ├── optim.py ├── mnist.py ├── main.py └── modules.py ├── numpy ├── __init__.py ├── .gitignore ├── optim.py ├── mnist.py ├── main.py └── modules.py ├── .gitignore ├── imgs └── README.md ├── pytorch ├── __init__.py ├── .gitignore ├── modules.py └── main.py ├── capsnet.png ├── decoder.png ├── perturb.jpg ├── reconst.jpeg ├── compgraph_digitcaps.png ├── LICENSE └── README.md /cupy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /numpy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store -------------------------------------------------------------------------------- /imgs/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /cupy/.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | -------------------------------------------------------------------------------- /pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /numpy/.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | __pycache__/* 3 | -------------------------------------------------------------------------------- /pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | saved_models/* 3 | -------------------------------------------------------------------------------- /capsnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/capsnet.png -------------------------------------------------------------------------------- /decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/decoder.png -------------------------------------------------------------------------------- /perturb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/perturb.jpg -------------------------------------------------------------------------------- /reconst.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/reconst.jpeg -------------------------------------------------------------------------------- /compgraph_digitcaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/compgraph_digitcaps.png -------------------------------------------------------------------------------- /cupy/__pycache__/mnist.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/mnist.cpython-35.pyc 
-------------------------------------------------------------------------------- /cupy/__pycache__/optim.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/optim.cpython-35.pyc -------------------------------------------------------------------------------- /cupy/__pycache__/modules.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/modules.cpython-35.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Xander Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cupy/optim.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | class Optimizer: 4 | def __init__(self): 5 | self.t = 0 6 | 7 | def step(self): 8 | self.t += 1 9 | 10 | def update_val(self, x, dx): 11 | raise NotImplementedError 12 | 13 | def __call__(self, *input, **kwargs): 14 | return self.update_val(*input, **kwargs) 15 | 16 | 17 | class AdamOptimizer(Optimizer): 18 | def __init__(self, lr=1e-2, beta=(0.9,0.999), eps=1e-8): 19 | super(AdamOptimizer, self).__init__() 20 | self.lr = lr 21 | self.beta = beta 22 | self.eps = eps 23 | self.m = None 24 | self.v = None 25 | 26 | def update_val(self, x, dx): 27 | self.m = cp.zeros_like(x) 28 | self.v = cp.zeros_like(x) 29 | m,v,lr,eps = self.m,self.v,self.lr,self.eps 30 | beta1, beta2 = self.beta 31 | m = beta1 * m + (1 - beta1) * dx 32 | v = beta2 * v + (1 - beta2) * dx**2 33 | alpha = lr * cp.sqrt(1 - beta2 ** self.t) / (1 - beta1 ** self.t) 34 | x -= alpha * (m / (cp.sqrt(v) + eps)) 35 | self.m = m 36 | self.v = v 37 | return x 38 | 39 | -------------------------------------------------------------------------------- /numpy/optim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Optimizer: 4 | def __init__(self): 5 | self.t = 0 6 | 7 | def step(self): 8 | self.t += 1 9 | 10 | def update_val(self, x, dx): 11 | raise NotImplementedError 12 | 13 | def __call__(self, *input, **kwargs): 14 | return self.update_val(*input, **kwargs) 15 | 16 | 17 | class AdamOptimizer(Optimizer): 18 | def __init__(self, lr=1e-2, beta=(0.9,0.999), eps=1e-8): 19 | super(AdamOptimizer, self).__init__() 20 | self.lr = lr 21 | self.beta = beta 22 | self.eps = eps 23 | self.m = None 24 | self.v = None 25 | 26 | def update_val(self, x, dx): 27 | self.m = np.zeros_like(x) 28 | self.v = np.zeros_like(x) 29 | m,v,lr,eps = self.m,self.v,self.lr,self.eps 30 | beta1, beta2 = self.beta 31 | m = beta1 * m + (1 - beta1) * dx 32 | v = beta2 * v + (1 - beta2) * dx**2 33 | alpha = lr * np.sqrt(1 - beta2 ** self.t) / (1 - beta1 ** self.t) 34 | x -= alpha * (m / (np.sqrt(v) + eps)) 35 | self.m = m 36 | self.v = v 37 | return x 38 | 39 | -------------------------------------------------------------------------------- /numpy/mnist.py: -------------------------------------------------------------------------------- 1 | import time, os 2 | import numpy as np 3 | from urllib import request 4 | import gzip 5 | import pickle 6 | 7 | 8 | class MNIST: 9 | def __init__(self, path='data', bs=1, shuffle=False): 10 | self.filename = [ 11 | ["training_images","train-images-idx3-ubyte.gz"], 12 | ["test_images","t10k-images-idx3-ubyte.gz"], 13 | ["training_labels","train-labels-idx1-ubyte.gz"], 14 | ["test_labels","t10k-labels-idx1-ubyte.gz"] 15 | ] 16 | self.mean = 0.1307 17 | self.std = 0.3081 18 | self.num_classes = 10 19 | self.bs = bs 20 | self.path = path 21 | 22 | if not os.path.exists(self.path): 23 | os.mkdir(self.path) 24 | if not os.path.exists(self.path+'/mnist.pkl'): 25 | self.download_mnist() 26 | self.load(shuffle=shuffle) 27 | print('Loading complete.') 28 | 29 | def download_mnist(self): 30 | base_url = "http://yann.lecun.com/exdb/mnist/" 31 | for name in self.filename: 32 | print("Downloading "+name[1]+"...") 33 | request.urlretrieve(base_url+name[1], self.path+'/'+name[1]) 34 | print("Download complete.") 35 | self.save_mnist() 36 | 37 | def save_mnist(self): 38 | 
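        # Normalize the raw image bytes with the dataset mean/std set in __init__
        # and cache all four splits in a single mnist.pkl under self.path, so
        # later runs can skip the download entirely.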
mnist = {} 39 | for name in self.filename[:2]: 40 | with gzip.open(self.path+'/'+name[1], 'rb') as f: 41 | mnist[name[0]] = ((np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28))/255.-self.mean)/self.std 42 | for name in self.filename[-2:]: 43 | with gzip.open(self.path+'/'+name[1], 'rb') as f: 44 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8) 45 | with open(self.path+'/'+"mnist.pkl", 'wb') as f: 46 | pickle.dump(mnist,f) 47 | print("Save complete.") 48 | 49 | def chunks(self, l): 50 | for i in range(0, len(l), self.bs): 51 | yield l[i:i + self.bs] 52 | 53 | def load(self, shuffle=False): 54 | with open(self.path+"/mnist.pkl",'rb') as f: 55 | mnist = pickle.load(f) 56 | if shuffle: 57 | n = mnist['training_images'].shape[0] 58 | idxs = np.arange(n) 59 | np.random.shuffle(idxs) 60 | mnist['training_images'] = mnist['training_images'].reshape((-1,1,28,28)) 61 | mnist['training_images'] = list(self.chunks(mnist['training_images'][idxs])) 62 | mnist['training_labels'] = list(self.chunks(mnist['training_labels'][idxs])) 63 | self.train_dataset = zip(mnist['training_images'], mnist['training_labels']) 64 | 65 | n = mnist['test_images'].shape[0] 66 | idxs = np.arange(n) 67 | np.random.shuffle(idxs) 68 | mnist['test_images'] = mnist['test_images'].reshape((-1,1,28,28)) 69 | mnist['test_images'] = list(self.chunks(mnist['test_images'][idxs])) 70 | mnist['test_labels'] = list(self.chunks(mnist['test_labels'][idxs])) 71 | self.eval_dataset = zip(mnist['test_images'], mnist['test_labels']) 72 | 73 | -------------------------------------------------------------------------------- /cupy/mnist.py: -------------------------------------------------------------------------------- 1 | import time, os 2 | import numpy as np 3 | import cupy as cp 4 | from urllib import request 5 | import gzip 6 | import pickle 7 | 8 | 9 | class MNIST: 10 | def __init__(self, path='data', bs=1, shuffle=False): 11 | self.filename = [ 12 | ["training_images","train-images-idx3-ubyte.gz"], 13 | ["test_images","t10k-images-idx3-ubyte.gz"], 14 | ["training_labels","train-labels-idx1-ubyte.gz"], 15 | ["test_labels","t10k-labels-idx1-ubyte.gz"] 16 | ] 17 | self.mean = 0.1307 18 | self.std = 0.3081 19 | self.num_classes = 10 20 | self.bs = bs 21 | self.path = path 22 | if not os.path.exists(self.path): 23 | os.mkdir(self.path) 24 | if not os.path.exists(self.path+'/mnist.pkl'): 25 | self.download_mnist() 26 | self.load(shuffle=shuffle) 27 | print('Loading complete.') 28 | 29 | def download_mnist(self): 30 | base_url = "http://yann.lecun.com/exdb/mnist/" 31 | for name in self.filename: 32 | print("Downloading "+name[1]+"...") 33 | request.urlretrieve(base_url+name[1], self.path+'/'+name[1]) 34 | print("Download complete.") 35 | self.save_mnist() 36 | 37 | def save_mnist(self): 38 | mnist = {} 39 | for name in self.filename[:2]: 40 | with gzip.open(self.path+'/'+name[1], 'rb') as f: 41 | mnist[name[0]] = ((np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28))/255.-self.mean)/self.std 42 | for name in self.filename[-2:]: 43 | with gzip.open(self.path+'/'+name[1], 'rb') as f: 44 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8) 45 | with open(self.path+'/'+"mnist.pkl", 'wb') as f: 46 | pickle.dump(mnist,f) 47 | print("Save complete.") 48 | 49 | def chunks(self, l): 50 | for i in range(0, len(l), self.bs): 51 | yield l[i:i + self.bs] 52 | 53 | def load(self, shuffle=False): 54 | with open(self.path+"/mnist.pkl",'rb') as f: 55 | mnist = pickle.load(f) 56 | if shuffle: 57 | n = 
mnist['training_images'].shape[0] 58 | idxs = np.arange(n) 59 | np.random.shuffle(idxs) 60 | mnist['training_images'] = mnist['training_images'].reshape((-1,1,28,28)) 61 | mnist['training_images'] = list(self.chunks(mnist['training_images'][idxs])) 62 | mnist['training_labels'] = list(self.chunks(mnist['training_labels'][idxs])) 63 | self.train_dataset = zip(cp.array(mnist['training_images']), cp.array(mnist['training_labels'])) 64 | 65 | n = mnist['test_images'].shape[0] 66 | idxs = np.arange(n) 67 | np.random.shuffle(idxs) 68 | mnist['test_images'] = mnist['test_images'].reshape((-1,1,28,28)) 69 | mnist['test_images'] = list(self.chunks(mnist['test_images'][idxs])) 70 | mnist['test_labels'] = list(self.chunks(mnist['test_labels'][idxs])) 71 | self.eval_dataset = zip(cp.array(mnist['test_images']), cp.array(mnist['test_labels'])) 72 | 73 | -------------------------------------------------------------------------------- /cupy/main.py: -------------------------------------------------------------------------------- 1 | from modules import * 2 | import time, os, argparse 3 | import cupy as cp 4 | from mnist import MNIST 5 | from modules import CapsNet, CapsLoss 6 | from optim import AdamOptimizer 7 | 8 | 9 | def parse_args(): 10 | """ 11 | Parse input arguments 12 | """ 13 | parser = argparse.ArgumentParser(description='Cupy Capsnet') 14 | parser.add_argument('--bs', dest='bs', 15 | help='batch size', 16 | default='100', type=int) 17 | parser.add_argument('--lr', dest='lr', 18 | help='learning rate', 19 | default=1e-2, type=float) 20 | parser.add_argument('--opt', dest='opt', 21 | help='optimizer', 22 | default='adam', type=str) 23 | parser.add_argument('--disp', dest='disp_interval', 24 | help='interval to display training loss', 25 | default='10', type=int) 26 | parser.add_argument('--num_epochs', dest='num_epochs', 27 | help='num epochs to train', 28 | default='100', type=int) 29 | parser.add_argument('--val_epoch', dest='val_epoch', 30 | help='num epochs to run validation', 31 | default='1', type=int) 32 | 33 | args = parser.parse_args() 34 | 35 | return args 36 | 37 | if __name__ == '__main__': 38 | 39 | args = parse_args() 40 | 41 | mnist = MNIST(bs=args.bs, shuffle=True) 42 | eye = cp.eye(mnist.num_classes) 43 | model = CapsNet() 44 | 45 | criterion = CapsLoss() 46 | if args.opt == 'adam': 47 | optimizer = AdamOptimizer(lr=args.lr) 48 | 49 | print('Training started!') 50 | 51 | for epoch in range(args.num_epochs): 52 | start = time.time() 53 | 54 | # train 55 | correct = 0 56 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataset): 57 | optimizer.step() 58 | if imgs.shape[0] != args.bs: 59 | continue 60 | 61 | targets = eye[targets] 62 | scores, reconst = model(imgs) 63 | loss, grad = criterion(scores, targets, reconst, imgs) 64 | model.backward(grad, optimizer) 65 | 66 | classes = cp.argmax(scores, axis=1) 67 | predicted = eye[cp.squeeze(classes), :] 68 | 69 | predicted_idx = cp.argmax(predicted, 1) 70 | label_idx = cp.argmax(targets, 1) 71 | correct = cp.sum(predicted_idx == label_idx) 72 | 73 | # info 74 | if batch_idx % args.disp_interval == 0: 75 | end = time.time() 76 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \ 77 | % (epoch, batch_idx, loss, 100.*correct/args.bs, correct, args.bs)) 78 | 79 | # val 80 | if epoch % args.val_epoch == 0: 81 | print('Validating...') 82 | correct = 0 83 | total = 0 84 | 85 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataset): 86 | if imgs.shape[0] != args.bs: 87 | continue 88 | 89 | targets = eye[targets] 
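                # One-hot encode the integer labels for the margin loss.
                # Note: model.backward() below still receives the optimizer, so this
                # validation loop also updates the weights; a pure evaluation pass
                # would skip the backward call (or pass optimizer=None).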
90 | scores, reconst = model(imgs) 91 | loss, grad = criterion(scores, targets, reconst, imgs) 92 | model.backward(grad, optimizer) 93 | 94 | classes = cp.argmax(scores, axis=1) 95 | predicted = eye[cp.squeeze(classes, axis=1), :] 96 | 97 | predicted_idx = cp.argmax(predicted, 1) 98 | label_idx = cp.argmax(targets, 1) 99 | correct += cp.sum(predicted_idx == label_idx) 100 | total += targets.shape[0] 101 | 102 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \ 103 | % (epoch, 100.*correct/total, correct, total)) 104 | -------------------------------------------------------------------------------- /numpy/main.py: -------------------------------------------------------------------------------- 1 | from modules import * 2 | import time, os, argparse 3 | import numpy as np 4 | from mnist import MNIST 5 | from modules import CapsNet, CapsLoss 6 | from optim import AdamOptimizer 7 | import multiprocessing as mp 8 | 9 | def parse_args(): 10 | """ 11 | Parse input arguments 12 | """ 13 | parser = argparse.ArgumentParser(description='Cupy Capsnet') 14 | parser.add_argument('--bs', dest='bs', 15 | help='batch size', 16 | default='100', type=int) 17 | parser.add_argument('--lr', dest='lr', 18 | help='learning rate', 19 | default=1e-2, type=float) 20 | parser.add_argument('--opt', dest='opt', 21 | help='optimizer', 22 | default='adam', type=str) 23 | parser.add_argument('--disp', dest='disp_interval', 24 | help='interval to display training loss', 25 | default='1', type=int) 26 | parser.add_argument('--num_epochs', dest='num_epochs', 27 | help='num epochs to train', 28 | default='100', type=int) 29 | parser.add_argument('--val_epoch', dest='val_epoch', 30 | help='num epochs to run validation', 31 | default='1', type=int) 32 | 33 | args = parser.parse_args() 34 | 35 | return args 36 | 37 | if __name__ == '__main__': 38 | mp.set_start_method('spawn') 39 | args = parse_args() 40 | 41 | mnist = MNIST(bs=args.bs, shuffle=True) 42 | eye = np.eye(mnist.num_classes) 43 | model = CapsNet() 44 | 45 | criterion = CapsLoss() 46 | if args.opt == 'adam': 47 | optimizer = AdamOptimizer(lr=args.lr) 48 | 49 | print('Training started!') 50 | 51 | for epoch in range(args.num_epochs): 52 | start = time.time() 53 | 54 | # train 55 | correct = 0 56 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataset): 57 | optimizer.step() 58 | if imgs.shape[0] != args.bs: 59 | continue 60 | 61 | targets = eye[targets] 62 | scores, reconst = model(imgs) 63 | loss, grad = criterion(scores, targets, reconst, imgs) 64 | model.backward(grad, optimizer) 65 | 66 | classes = np.argmax(scores, axis=1) 67 | predicted = eye[np.squeeze(classes), :] 68 | 69 | predicted_idx = np.argmax(predicted, 1) 70 | label_idx = np.argmax(targets, 1) 71 | correct = np.sum(predicted_idx == label_idx) 72 | 73 | # info 74 | if batch_idx % args.disp_interval == 0: 75 | end = time.time() 76 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \ 77 | % (epoch, batch_idx, loss, 100.*correct/args.bs, correct, args.bs)) 78 | 79 | # val 80 | if epoch % args.val_epoch == 0: 81 | print('Validating...') 82 | correct = 0 83 | total = 0 84 | 85 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataset): 86 | if imgs.shape[0] != args.bs: 87 | continue 88 | 89 | targets = eye[targets] 90 | scores, reconst = model(imgs) 91 | loss, grad = criterion(scores, targets, reconst, imgs) 92 | model.backward(grad, optimizer) 93 | 94 | classes = np.argmax(scores, axis=1) 95 | predicted = eye[np.squeeze(classes, axis=1), :] 96 | 97 | predicted_idx = 
np.argmax(predicted, 1) 98 | label_idx = np.argmax(targets, 1) 99 | correct += np.sum(predicted_idx == label_idx) 100 | total += targets.shape[0] 101 | 102 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \ 103 | % (epoch, 100.*correct/total, correct, total)) 104 | -------------------------------------------------------------------------------- /pytorch/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import transforms, datasets 6 | import time, os 7 | from torch.autograd import Variable 8 | 9 | def squash(s, dim=-1): 10 | norm2 = torch.sum(s**2, dim=dim, keepdim=True) 11 | norm = torch.sqrt(norm2) 12 | return (norm2 / (1.0 + norm2)) * (s / norm) 13 | 14 | class PrimaryCaps(nn.Module): 15 | def __init__(self, use_cuda=False, out_channels=32, in_channels=256, ncaps=32*6*6, ndim=8, kernel_size=9, stride=2, padding=0): 16 | super(PrimaryCaps, self).__init__() 17 | self.ncaps = ncaps 18 | self.ndim = ndim 19 | self.caps = nn.ModuleList( 20 | [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) for _ in 21 | range(ndim)]) 22 | 23 | def forward(self, x): 24 | u = torch.cat([cap(x).view(x.size(0), -1, 1) for cap in self.caps], dim=-1) 25 | # output (bs, ncaps, ndim) 26 | return squash(u) 27 | 28 | 29 | class DigitCaps(nn.Module): 30 | def __init__(self, use_cuda=False, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16): 31 | super(DigitCaps, self).__init__() 32 | self.use_cuda = use_cuda 33 | self.ndim_prev = ndim_prev 34 | self.ncaps_prev = ncaps_prev 35 | self.ncaps = ncaps 36 | self.route_iter = 3 37 | self.W = nn.Parameter(torch.randn(1, ncaps_prev, ncaps, ndim, ndim_prev)) 38 | 39 | def forward(self, x): 40 | bs = x.size(0) 41 | x = torch.stack([x] * self.ncaps, dim=2).unsqueeze(-1) 42 | W = torch.cat([self.W] * bs, dim=0) 43 | u_hat = W @ x 44 | 45 | b = Variable(torch.zeros(1, self.ncaps_prev, self.ncaps, 1)) 46 | if self.use_cuda: 47 | b = b.cuda() 48 | 49 | for i in range(self.route_iter): 50 | c = F.softmax(b) 51 | c = torch.cat([c] * bs, dim=0).unsqueeze(-1) 52 | 53 | s = (c * u_hat).sum(dim=1, keepdim=True) 54 | v = squash(s) 55 | 56 | if i < self.route_iter - 1: 57 | b = b + torch.matmul(u_hat.transpose(-1, -2), torch.cat([v] * self.ncaps_prev, dim=1)) \ 58 | .squeeze(-1).mean(dim=0, keepdim=True) 59 | return v.squeeze(1) 60 | 61 | 62 | class Decoder(nn.Module): 63 | def __init__(self): 64 | super(Decoder, self).__init__() 65 | self.net = nn.Sequential( 66 | nn.Linear(16*10,512), 67 | nn.ReLU(inplace=True), 68 | nn.Linear(512,1024), 69 | nn.ReLU(inplace=True), 70 | nn.Linear(1024,784), 71 | nn.Sigmoid() 72 | ) 73 | 74 | def forward(self,x): 75 | x = x.view(x.size(0),-1) 76 | x = self.net(x) 77 | return x 78 | 79 | class CapsNet(nn.Module): 80 | def __init__(self, use_cuda=False, kernel_size=9, stride=1): 81 | super(CapsNet, self).__init__() 82 | 83 | self.conv1 = nn.Conv2d(1,256,kernel_size,stride=stride) 84 | self.primary_caps = PrimaryCaps(use_cuda=use_cuda) 85 | self.digit_caps = DigitCaps(use_cuda=use_cuda) 86 | self.decoder = Decoder() 87 | 88 | def forward(self, inpt): 89 | start = time.time() 90 | x = F.relu(self.conv1(inpt), inplace=True) 91 | x = self.primary_caps(x) 92 | x = self.digit_caps(x) 93 | reconst = self.decoder(x) 94 | return x, reconst 95 | 96 | class CapsLoss(nn.Module): 97 | def __init__(self): 98 | super(CapsLoss, self).__init__() 99 | self.mse_loss = 
nn.MSELoss() 100 | self.reconst_factor = 0.0005 101 | def forward(self, scores, labels, reconst, inpt): 102 | norms = torch.sqrt(scores).squeeze() 103 | margin_loss = labels * ( F.relu(0.9 - norms, inplace=True) )**2 + 0.5*(1-labels) * ( F.relu(norms - 0.1, inplace=True) )**2 104 | margin_loss = margin_loss.sum(dim=-1).mean() 105 | reconst_loss = self.mse_loss(reconst.view(reconst.size(0),-1), inpt.view(inpt.size(0),-1)) 106 | return margin_loss + self.reconst_factor * reconst_loss 107 | 108 | -------------------------------------------------------------------------------- /pytorch/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import transforms, datasets 6 | import time, os, argparse 7 | from torch.autograd import Variable 8 | from modules import * 9 | 10 | 11 | class MNIST: 12 | def __init__(self, bs=1): 13 | dataset_transform = transforms.Compose([ 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.1307,), (0.3081,)) 16 | ]) 17 | 18 | train_dataset = datasets.MNIST('data', train=True, download=True, transform=dataset_transform) 19 | eval_dataset = datasets.MNIST('data', train=False, download=True, transform=dataset_transform) 20 | 21 | self.num_classes = 10 22 | self.train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True) 23 | self.eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=bs, shuffle=True) 24 | 25 | def parse_args(): 26 | """ 27 | Parse input arguments 28 | """ 29 | parser = argparse.ArgumentParser(description='Cupy Capsnet') 30 | parser.add_argument('--bs', dest='bs', 31 | help='batch size', 32 | default='100', type=int) 33 | parser.add_argument('--lr', dest='lr', 34 | help='learning rate', 35 | default=1e-2, type=float) 36 | parser.add_argument('--opt', dest='optimizer', 37 | help='optimizer', 38 | default='adam', type=str) 39 | parser.add_argument('--disp', dest='disp_interval', 40 | help='interval to display training loss', 41 | default=1, type=int) 42 | parser.add_argument('--num_epochs', dest='num_epochs', 43 | help='num epochs to train', 44 | default=100, type=int) 45 | parser.add_argument('--val_epoch', dest='val_epoch', 46 | help='num epochs to run validation', 47 | default=1, type=int) 48 | parser.add_argument('--save_epoch', dest='save_epoch', 49 | help='num epochs to save model', 50 | default=1, type=int) 51 | parser.add_argument('--use_cuda', dest='use_cuda', 52 | help='whether or not to use cuda', 53 | default=True, type=bool) 54 | parser.add_argument('--save_dir', dest='save_dir', 55 | help='directory to save trained models', 56 | default=True, type=bool) 57 | 58 | args = parser.parse_args() 59 | 60 | return args 61 | 62 | if __name__ == '__main__': 63 | args = parse_args() 64 | 65 | if not os.path.exists(args.save_dir): 66 | os.makedirs(args.save_dir) 67 | 68 | mnist = MNIST(bs=args.bs) 69 | # Variables 70 | inputs = torch.FloatTensor(1) 71 | labels = torch.FloatTensor(1) 72 | eye = Variable(torch.eye(mnist.num_classes)) 73 | inputs = Variable(inputs) 74 | labels = Variable(labels) 75 | 76 | # Model 77 | model = CapsNet(use_cuda=args.use_cuda) 78 | 79 | # cuda 80 | if args.use_cuda: 81 | inputs = inputs.cuda() 82 | labels = labels.cuda() 83 | model = model.cuda() 84 | eye = eye.cuda() 85 | 86 | params = [] 87 | 88 | for key, value in dict(model.named_parameters()).items(): 89 | if value.requires_grad: 90 | params += 
[{'params':[value],'lr':args.lr}] 91 | 92 | # optimizer 93 | if args.optimizer == "adam": 94 | optimizer = torch.optim.Adam(model.parameters()) 95 | elif args.optimizer == "sgd": 96 | optimizer = torch.optim.SGD(params) 97 | 98 | criterion = CapsLoss() 99 | 100 | print('Training started!') 101 | 102 | for epoch in range(args.num_epochs): 103 | start = time.time() 104 | 105 | # train 106 | model.train() 107 | correct = 0 108 | train_loss = 0 109 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataloader): 110 | if imgs.size(0) != args.bs: 111 | continue 112 | 113 | targets = eye.cpu().data.index_select(dim=0, index=targets) 114 | inputs.data.resize_(imgs.size()).copy_(imgs) 115 | labels.data.resize_(targets.size()).copy_(targets) 116 | 117 | optimizer.zero_grad() 118 | outputs, reconst = model(inputs) 119 | 120 | scores = torch.sqrt((outputs ** 2).sum(2)) 121 | loss = criterion(scores, labels, reconst, inputs) 122 | train_loss = loss.data.cpu().numpy()[0] 123 | 124 | # backward 125 | loss.backward() 126 | optimizer.step() 127 | 128 | scores, classes = F.softmax(scores).max(dim=1) 129 | predicted = eye.index_select(dim=0, index=classes.squeeze(1)) 130 | 131 | predicted_idx = np.argmax(predicted.data.cpu().numpy(),1) 132 | label_idx = np.argmax(targets.numpy(), 1) 133 | correct = np.sum(predicted_idx == label_idx) 134 | 135 | # info 136 | if batch_idx % args.disp_interval == 0: 137 | end = time.time() 138 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \ 139 | % (epoch, batch_idx, train_loss/(batch_idx+1), 100.*correct/args.bs, correct, args.bs)) 140 | 141 | save_name = os.path.join(args.save_dir, '{}_{}.pth'.format(project_id, epoch)) 142 | if args.save_epoch > 0 and batch_idx % args.save_epoch == 0: 143 | torch.save({ 144 | 'epoch': epoch, 145 | }, save_name) 146 | 147 | # val 148 | if epoch % args.val_epoch == 0: 149 | print('Validating...') 150 | correct = 0 151 | total = 0 152 | model.eval() 153 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataloader): 154 | if imgs.size(0) != args.bs: 155 | continue 156 | targets = eye.cpu().data.index_select(dim=0, index=targets) 157 | inputs.data.resize_(imgs.size()).copy_(imgs) 158 | labels.data.resize_(targets.size()).copy_(targets) 159 | 160 | outputs, reconst = model(inputs) 161 | scores = torch.sqrt((outputs ** 2).sum(2)) 162 | scores, classes = F.softmax(scores).max(dim=1) 163 | predicted = eye.index_select(dim=0, index=classes.squeeze(1)) 164 | 165 | predicted_idx = np.argmax(predicted.data.cpu().numpy(),1) 166 | label_idx = np.argmax(targets.numpy(), 1) 167 | correct += np.sum(predicted_idx == label_idx) 168 | total += targets.size(0) 169 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \ 170 | % (epoch, 100.*correct/total, correct, total)) 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pyCapsNet 2 | 3 | [![License][license]][license-url] 4 | 5 | Pytorch, NumPy and CuPy implementations of Capsule Networks (CapsNet), based on the paper [Sabour, Sara, Nicholas Frosst, and Geoffrey E. Hinton. "Dynamic routing between capsules." Advances in Neural Information Processing Systems. 2017.] 
6 | 7 | ## Requirements 8 | 9 | * Python 3 10 | 11 | PyTorch Implementation: 12 | * PyTorch 13 | * Tested with PyTorch 0.3.0.post4 14 | * CUDA 8 (if using CUDA) 15 | 16 | CuPy Implementation: 17 | * CuPy 2.0.0 18 | * CUDA 8 19 | 20 | ## Motivation 21 | There are many great implementations of Capsule Networks [with PyTorch], [TensorFlow] and [Keras], so why do we need another one? This project actually provides three implementations of CapsNet: PyTorch, NumPy and CuPy. For the PyTorch version, I implemented CapsNet for performance check and visualizations; for the NumPy and CuPy ones, I implemented CapsNet purely from scratch, both forward and backpropagation, aiming to get a deeper understanding of the structure and the gradient flow of CapsNet. The computation graph that I used for this implementation is provided later in this document. 22 | 23 | The purpose of this project is not to shoot for better performance or optimizing the speed, but to offer a better understanding of CapsNet implementation-wise. Reading the paper thoroughly is a must, but it is easy to get confused when it comes to real implementation. I will provide my own understanding in CapsNet and implementation walkthrough in this document. 24 | 25 | This [video] really helped a lot for me to understand CapsNet. Take a minute and check it out. 26 | 27 | ## Challenges of Implementation 28 | * The 5-dimension tensor for CapsNet can be pretty confusing. Stick with one sequence of dimensions and mind the difference between element-wise and matrix multiplication. 29 | * CuPy and NumPy implementations are built from scratch. It is challenging to make sure gradient flows correctly; I drew computational graph and performed unit tests on each basic modules (e.g.: Squash, Conv2d, Linear, Sequence, losses) and composite ones (e.g.: PrimaryCaps, DigitCaps). Accumulating gradients for the iterative refinement especially requires a clear understanding on the computation flow of DigitCaps. 30 | 31 | ## To Run 32 | For NumPy and CuPy implementations, change into the corresponding directories, and run 33 | ``` 34 | python3 main.py --bs=100 --lr=1e-3 --opt='adam' --disp=10 --num_epochs=100 --val_epoch=1 35 | ``` 36 | For the PyTorch implementation, run 37 | ``` 38 | python3 main.py --bs=100 --lr=1e-3 --opt='adam' --disp=1 --num_epochs=100 --val_epoch=1 --use_cuda=True --save_dir='saved_models' 39 | ``` 40 | To visualize the reconstructed data, run the jupyter notebook in PyTorch/Visualization.ipnb. 41 | 42 | ## Capsule Networks: Key Points 43 | A capsule is a neuron that outputs activity vectors to represent the instantiation parameters of certain entities. The magnitude of the activation vector corresponds to the probability that such entity exists, and the orientation represents the instantiation parameters. The [paper] proposes a multi-layer capsule network for different image classification tasks, and achieved state-of-the-art performance on the MNIST dataset. 44 | 45 | ### Activity Vectors 46 | Unlike the neurons in most neural networks, the capsules in this architecture outputs an activity vector for each input. 
The paper introduces a nonlinear "squashing" function for the activity vector ![img](http://latex.codecogs.com/svg.latex?%5Ctextbf%7B%5Ctextit%7Bs%7D%7D): 47 | 48 | ![img](http://latex.codecogs.com/svg.latex?%5Ctext%7Bsquash%7D%28%5Ctextit%7B%5Ctextbf%7Bs%7D%7D%29%3D%5Cfrac%7B%7C%7C%5Ctextit%7B%5Ctextbf%7Bs%7D%7D%5E2%7C%7C%7D%7B1%2B%7C%7C%5Ctextit%7B%5Ctextbf%7Bs%7D%7D%5E2%7C%7C%7D%5Cfrac%7B%5Ctextit%7B%5Ctextbf%7Bs%7D%7D%7D%7B%7C%7C%5Ctextit%7B%5Ctextbf%7Bs%7D%7D%7C%7C%7D) 49 | 50 | ### Dynamic Routing Between Capsules 51 | In this paper, the authors replaces the conventional max-pooling layer with dynamic routing, iteratively refining the coupling coefficient to route the prediction made by the last layer to the next one. Such routing is achieved by "routing-by-agreement", where a capsule prefers to route its outputs to the capsules in the next layer whose output has a greater dot product with its own output. A result can be, for example, that the "5" capsule would receive features that agrees with "5". In the paper, the authors investigated this iterative refinement of routing coefficients, and found out that the number of iterations for finding the coefficients can indeed help achieve a lower loss and better stability. However, this sequential and iterative structure can make CapsNet very slow. In this project, I followed the paper and set the number of routing iterations to 3. 52 | 53 | ### The Architecture 54 |
![CapsNet architecture](capsnet.png)
55 | 56 | The Capsule Network in the figure above consists of three parts: a convolutional layer (Conv1) and two capsule layers (PrimaryCaps and DigitCaps). The DigitCaps layer yields a 16-dimensional vector for each of the 10 classes, and the L2 norm of these vectors becomes the class score. The decoder in the figure below consumes these vectors and tries to reconstruct the image. The final loss is the combination of the score loss ("margin loss" in the paper) and the reconstruction loss. 57 | 58 |
![Decoder architecture](decoder.png)
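For concreteness, the squash nonlinearity and the combined margin + reconstruction loss described above fit in a few lines. Here is a condensed, PyTorch-style sketch (the helper names and the small epsilon added to the norm are my own; the constants 0.9, 0.1, 0.5 and the 0.0005 reconstruction weight follow the paper and `pytorch/modules.py`):

```
import torch
import torch.nn.functional as F

def squash(s, dim=-1):
    # ||s||^2 / (1 + ||s||^2) * s / ||s||; eps avoids dividing by a zero norm
    norm2 = (s ** 2).sum(dim=dim, keepdim=True)
    return (norm2 / (1.0 + norm2)) * s / torch.sqrt(norm2 + 1e-9)

def capsnet_loss(scores, labels, reconst, imgs, reconst_factor=0.0005):
    # scores: (N, 10) L2 norms of the DigitCaps vectors; labels: (N, 10) one-hot
    margin = labels * F.relu(0.9 - scores) ** 2 \
             + 0.5 * (1.0 - labels) * F.relu(scores - 0.1) ** 2
    reconst_loss = F.mse_loss(reconst.view(reconst.size(0), -1),
                              imgs.view(imgs.size(0), -1))
    return margin.sum(dim=-1).mean() + reconst_factor * reconst_loss
```

Squashing keeps a vector's orientation while mapping its length into [0, 1), which is what lets the length act as an existence probability.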
59 | 60 | ### Detailed Walkthrough 61 | 62 | Here I'll provide a brief cheat sheet, which I would find extremely helpful if I saw this before implementation. 63 | 64 | * Conv1: 65 | * Input size `(N, C=1, H=28, W=28)` 66 | * `in_channel=1, out_channel=256, kernel_size=(9,9), stride=1, padding=0` 67 | * This convolution yields output size `(N, 256, 20, 20)` 68 | 69 | * PrimaryCaps: 70 | * Input size `(N, C=256, H=20, W=20)` 71 | * `ndim=8` Convolution kernels, each with `in_channel=256, out_channel=32, kernel_size=(9,9), stride=2, padding=0` 72 | * Each convolution yields output size `(N, 32, 6, 6)`; linearize for each batch and concatenate the output of each convolution and feed into DigitCaps. 73 | 74 | * DigitCaps: 75 | * Input size `(N, ncaps_prev=32*6*6, ndim_prev=8)` 76 | * For convenience, each tensor involved in this layer is reshaped into 5 dimensions, corresponding to the dimensions of the weight, which is of size `(1, ncaps_prev=32*6*6, ncaps=10, ndim=16, ndim_prev=8)`. Note that the input and the weight has the same dimensions in `ncaps_prev` and `ncaps`, and `N` can be handled by broadcasting. Focusing on the last two dimensions, the weight is of size `(16, 8)` and the input is of size `(8, 1)`, therefore the output size gives `(16, 1)` in these two dimensions. 77 | 78 | * Outputs size `(N, 1, 10, 16, 1)`. Getting the L2 norm across the 3rd dimension to get scores for each class. This tensor is fed into the decoder to get reconstructions. 79 | 80 | * Decoder: 81 | * Input size `(N, 10, 16)`, 82 | * Three linear layers, sizes `(16*10, 512)`, `(512, 1024)` and `(1024, 28*28)`. The first two are followed with ReLU and the last layer is followed by a Sigmoid layer. 83 | * Outputs `(N, 28*28)`, i.e. the reconstruction. 84 | 85 | ## Computation Graph for DigitCaps 86 | 87 |
![Computation graph for DigitCaps](compgraph_digitcaps.png)
88 | 89 | This computation graph was originally used for a better understanding of the gradient flow. I redrew the graph for DigitCaps in SVG to provide a clear illustration for people interested in implementing this part. If you want to implement backpropagation of DigitCaps, mind the accumulated gradient from each routing iteration. 90 | 91 | ## Results 92 | I achieved 99.41% validation accuracy at epoch 22 with the PyTorch implementation, which is close to the number reported on the paper. The CuPy implementation can quickly converge to 90%+, but overall trains slower than the PyTorch version. The NumPy implementation is trained purely on CPU; Though I used multiprocessing in the network, it much slower than the GPU implementations. The reconstructed images are given below. 93 | 94 |
![Reconstructed MNIST digits](reconst.jpeg)
95 | 96 | I also performed the experiment to perturb the 16-dimensional vectors of DigitCaps output, and feed in the pre-trained decoder and try to visualize the meaning of each dimension. The image given below shows that perturbing one dimension could change the orientation, width, stroke width and local features of the reconstructed image. 97 | 98 |
![Perturbation of individual DigitCaps dimensions](perturb.jpg)
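The perturbation itself takes only a few lines. A minimal sketch, assuming a reasonably recent PyTorch, a trained `Decoder` from `pytorch/modules.py`, and a `digit_caps` tensor of shape `(1, 10, 16)` holding the DigitCaps output for one image (function and variable names are illustrative; the [-0.25, 0.25] range in steps of 0.05 follows the paper):

```
import torch

def perturb_dimension(digit_caps, decoder, class_idx, dim):
    # Sweep one of the 16 dimensions of the chosen class capsule and decode.
    deltas = torch.arange(-0.25, 0.26, 0.05)
    reconstructions = []
    for d in deltas:
        v = digit_caps.clone()
        v[0, class_idx, dim] += d
        reconstructions.append(decoder(v).view(28, 28))
    return torch.stack(reconstructions)  # one 28x28 image per perturbation step
```

Plotting the returned stack side by side reproduces rows like the ones in the figure above.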
99 | 100 | 101 | ## To-dos 102 | 103 | - [x] Add visualization ipynb for PyTorch implementation 104 | - [ ] Add visualization ipynb for CuPy and NumPy implementations 105 | - [ ] Finish deformable convolution implementation in CuPy 106 | - [ ] Start a project on CuPy automatic differentiation, which could possibly benefit this project 107 | 108 | 109 | [license]: https://img.shields.io/github/license/mashape/apistatus.svg 110 | [license-url]: https://github.com/xanderchf/pyCapsNet/blob/master/LICENSE 111 | [Sabour, Sara, Nicholas Frosst, and Geoffrey E. Hinton. "Dynamic routing between capsules." Advances in Neural Information Processing Systems. 2017.]: https://arxiv.org/abs/1710.09829 112 | [paper]: https://arxiv.org/abs/1710.09829 113 | [with PyTorch]: https://github.com/gram-ai/capsule-networks 114 | [TensorFlow]: https://github.com/ageron/handson-ml 115 | [Keras]: https://github.com/XifengGuo/CapsNet-Keras 116 | [video]: https://www.youtube.com/watch?v=2Kawrd5szHE 117 | -------------------------------------------------------------------------------- /cupy/modules.py: -------------------------------------------------------------------------------- 1 | 2 | # im2col functions adapted from https://github.com/Burton2000/CS231n-2017/blob/master/assignment2/cs231n/im2col.py 3 | 4 | import cupy as cp 5 | import time, os 6 | 7 | 8 | def tile(arr, copy, axis): 9 | return cp.concatenate([arr] * copy, axis=axis) 10 | 11 | 12 | class Module(object): 13 | def __init__(self, trainable=False): 14 | self.trainable = trainable 15 | pass 16 | 17 | def forward(self, x): 18 | raise NotImplementedError 19 | 20 | def backward(self, grad, optimizer=None): 21 | raise NotImplementedError 22 | 23 | def __call__(self, *input, **kwargs): 24 | return self.forward(*input, **kwargs) 25 | 26 | 27 | class Sequence(Module): 28 | def __init__(self, modules): 29 | self._modules = modules 30 | 31 | def forward(self, inpt): 32 | t = time.time() 33 | for module in self._modules: 34 | inpt = module(inpt) 35 | cur = time.time() 36 | t = cur 37 | if module.trainable: 38 | self.trainable = True 39 | return inpt 40 | 41 | def backward(self, grad, optimizer=None): 42 | for module in self._modules[::-1]: 43 | if module.trainable: 44 | grad = module.backward(grad, optimizer) 45 | else: 46 | grad = module.backward(grad) 47 | 48 | return grad 49 | 50 | def modules(self): 51 | return self._modules 52 | 53 | def trainable_modules(self): 54 | return [i for i in self._modules if i.trainable] 55 | 56 | 57 | class Linear(Module): 58 | def __init__(self, in_channel, out_channel): 59 | super(Linear, self).__init__(trainable=True) 60 | std = 1/cp.sqrt(in_channel) 61 | self.w = cp.random.uniform(-std, std, (out_channel, in_channel)) 62 | self.b = cp.random.uniform(-std, std, (1, out_channel)) 63 | self.x = None 64 | 65 | def _set_params(self, params): 66 | w, b = params 67 | self.w = w 68 | self.b = b 69 | if len(self.b.shape) < 2: 70 | self.b = self.b[None,:] 71 | 72 | def forward(self, x): 73 | out = x.dot(self.w.T) + self.b 74 | self.x = x 75 | return out 76 | 77 | def backward(self, grad, optimizer=None): 78 | dw = (self.x.T @ grad).T 79 | db = cp.sum(grad, axis=0, keepdims=True) 80 | # update parameters 81 | if optimizer is not None: 82 | self.w = optimizer(self.w, dw) 83 | self.b = optimizer(self.b, db) 84 | 85 | dx = grad @ self.w 86 | dx = cp.reshape(dx, self.x.shape) 87 | return dx 88 | 89 | 90 | class ReLU(Module): 91 | def __init__(self, alpha=0): 92 | super(ReLU, self).__init__() 93 | self.alpha = alpha 94 | self.x = None 95 | 96 | 
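    # Note: only the default alpha=0 (plain ReLU) is exercised in this project;
    # the leaky branch in forward() assigns the full tensor self.alpha * x into a
    # boolean-masked selection (a shape mismatch in general), and backward()
    # ignores alpha as well.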
def forward(self, x): 97 | out = x.copy() 98 | if self.alpha > 0: 99 | out[out<0] = self.alpha*x 100 | else: 101 | out[out<0] = 0 102 | self.x = x 103 | return out 104 | 105 | def backward(self, grad): 106 | dx = grad.copy() 107 | dx[self.x < 0] = 0 108 | return dx 109 | 110 | class Sigmoid(Module): 111 | def __init__(self): 112 | super(Sigmoid, self).__init__() 113 | self.s = None 114 | 115 | def forward(self, x): 116 | self.s = 1/(1 + cp.exp(-x)) 117 | return self.s 118 | 119 | def backward(self, grad): 120 | return grad * (self.s * (1-self.s)) 121 | 122 | 123 | class Conv2d(Module): 124 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad=0, eps=1e-4): 125 | super(Conv2d, self).__init__(trainable=True) 126 | self.ic = in_channels 127 | self.oc = out_channels 128 | self.k = kernel_size 129 | self.s = stride 130 | self.p = pad 131 | 132 | std = 1/(cp.sqrt(self.ic* self.k**2)) 133 | self.W = cp.random.uniform(-std, std, (self.oc,self.ic,self.k,self.k)) 134 | self.b = cp.random.uniform(-std, std, (self.oc, 1)) 135 | 136 | self.X_col = None 137 | self.x_shape = None 138 | 139 | def _set_params(self, params): 140 | W, b = params 141 | self.W = W 142 | self.b = b 143 | 144 | def forward(self, X): 145 | NF, CF, HF, WF = self.W.shape 146 | NX, DX, HX, WX = X.shape 147 | self.x_shape = X.shape 148 | h_out = int((HX - HF + 2 * self.p) / self.s + 1) 149 | w_out = int((WX - WF + 2 * self.p) / self.s + 1) 150 | 151 | X_col = self.im2col_indices(X) 152 | self.X_col = X_col 153 | W_col = self.W.reshape(NF, -1) 154 | 155 | out = W_col @ self.X_col + self.b 156 | out = out.reshape(NF, h_out, w_out, NX) 157 | out = out.transpose(3, 0, 1, 2) 158 | 159 | return out 160 | 161 | 162 | def backward(self, dout, optimizer=None): 163 | NF, CF, HF, WF = self.W.shape 164 | 165 | db = cp.sum(dout, axis=(0, 2, 3)) 166 | db = db.reshape(NF, -1) 167 | 168 | dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(NF, -1) 169 | dW = dout_reshaped @ self.X_col.T 170 | dW = dW.reshape(self.W.shape) 171 | 172 | if optimizer is not None: 173 | self.b = optimizer(self.b, db) 174 | self.W = optimizer(self.W, dW) 175 | 176 | W_reshape = self.W.reshape(NF, -1) 177 | dX_col = W_reshape.T @ dout_reshaped 178 | dX = self.col2im_indices(dX_col.astype(cp.float32)) 179 | 180 | return dX 181 | 182 | def get_im2col_indices(self): 183 | padding, stride, field_height, field_width, x_shape = self.p, self.s, self.k, self.k, self.x_shape 184 | N, C, H, W = x_shape 185 | # assert (H + 2 * padding - field_height) % stride == 0 186 | # assert (W + 2 * padding - field_height) % stride == 0 187 | out_height = int((H + 2 * padding - field_height) / stride + 1) 188 | out_width = int((W + 2 * padding - field_width) / stride + 1) 189 | 190 | i0 = cp.repeat(cp.arange(field_height), field_width) 191 | i0 = cp.tile(i0, C) 192 | i1 = stride * cp.repeat(cp.arange(out_height), out_width) 193 | j0 = cp.tile(cp.arange(field_width), field_height * C) 194 | j1 = stride * cp.tile(cp.arange(out_width), out_height) 195 | i = i0.reshape(-1, 1) + i1.reshape(1, -1) 196 | j = j0.reshape(-1, 1) + j1.reshape(1, -1) 197 | 198 | k = cp.repeat(cp.arange(C), field_height * field_width).reshape(-1, 1) 199 | 200 | return (k.astype(cp.int32), i.astype(cp.int32), j.astype(cp.int32)) 201 | 202 | 203 | def im2col_indices(self, x): 204 | p, stride, field_height, field_width = self.p, self.s, self.k, self.k 205 | x_padded = cp.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant') 206 | 207 | k, i, j = self.get_im2col_indices() 208 | 209 | cols = x_padded[:, k, 
i, j] 210 | C = x.shape[1] 211 | cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1) 212 | return cols 213 | 214 | 215 | def col2im_indices(self, cols): 216 | field_height, field_width, padding, stride = self.k, self.k, self.p, self.s 217 | N, C, H, W = self.x_shape 218 | H_padded, W_padded = H + 2 * padding, W + 2 * padding 219 | x_padded = cp.zeros((N, C, H_padded, W_padded), dtype=cols.dtype) 220 | k, i, j = self.get_im2col_indices() 221 | cols_reshaped = cols.reshape(C * field_height * field_width, -1, N) 222 | cols_reshaped = cols_reshaped.transpose(2, 0, 1).astype(cp.float32) 223 | cp.scatter_add(x_padded, (slice(None), k, i, j), cols_reshaped) 224 | if padding == 0: 225 | return x_padded 226 | return x_padded[:, :, padding:-padding, padding:-padding] 227 | 228 | 229 | class Softmax(Module): 230 | def __init__(self, dim=-1): 231 | super(Softmax, self).__init__() 232 | self.s = None 233 | self.dim = dim 234 | self.squeeze_len = None 235 | 236 | def forward(self, x, dim=None): 237 | if dim is not None: 238 | self.dim = dim 239 | if self.dim < 0: 240 | self.dim = len(x.shape)+self.dim 241 | self.squeeze_len = x.shape[self.dim] 242 | y = cp.exp(x) 243 | s = y/cp.sum(y, axis=self.dim, keepdims=True) 244 | self.s = s 245 | return s 246 | 247 | def backward(self, grad): 248 | self.s = cp.expand_dims(self.s.swapaxes(self.dim,-1), -1) 249 | grad = cp.expand_dims(grad.swapaxes(self.dim,-1), -1) 250 | mat = self.s @ self.s.swapaxes(-1,-2) 251 | mat = (-mat + cp.eye(mat.shape[-1]) * (mat**0.5)) 252 | grad = mat @ grad 253 | self.s = self.s.swapaxes(self.dim,-1).squeeze(-1) 254 | return grad.swapaxes(self.dim,-2).squeeze(-1) 255 | 256 | 257 | class Squash(Module): 258 | def __init__(self, dim=-1): 259 | super(Squash, self).__init__() 260 | self.dim = dim 261 | self.squeeze_len = None 262 | self.s = None 263 | 264 | def forward(self, s): 265 | self.s = s 266 | self.squeeze_len = s.shape[self.dim] 267 | norm2 = cp.sum((s)**2, axis=self.dim, keepdims=True) 268 | return (cp.sqrt(norm2) / (1.0 + norm2)) * s 269 | 270 | def backward(self, grad): 271 | norm2 = cp.sum((self.s)**2, axis=self.dim, keepdims=True) 272 | norm = cp.sqrt(norm2) 273 | temp = tile((1/(2*(1.+norm2)*norm) - norm/(1.+norm2)**2), self.squeeze_len, self.dim) 274 | dnorm2 = cp.sum(self.s * temp, axis=-1, keepdims=True) 275 | factor = norm/(1+norm2) 276 | return grad * dnorm2 * (2.*self.s) + grad * factor 277 | 278 | class MSELoss(Module): 279 | def __init__(self): 280 | super(MSELoss, self).__init__() 281 | self.x = None 282 | self.y = None 283 | 284 | def forward(self, x, y): 285 | self.x = x 286 | self.y = y 287 | return cp.sum((x - y)**2)/float(x.size), 2*(x - y)/float(x.size) 288 | 289 | 290 | class PrimaryCaps(Module): 291 | def __init__(self, use_cuda=False, out_channels=32, in_channels=256, mapsize=6, ndim=8, kernel_size=9, stride=2, padding=0): 292 | super(PrimaryCaps, self).__init__(trainable=True) 293 | self.ndim = ndim 294 | self.caps = [Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad=padding) for _ in 295 | range(ndim)] 296 | 297 | self.out_channels = out_channels 298 | self.mapsize = mapsize 299 | self.ncaps = out_channels * mapsize**2 300 | self.squash = Squash() 301 | self.x_size = None 302 | 303 | def _set_params(self, params): 304 | for i, c in enumerate(self.caps): 305 | c._set_params(params[i]) 306 | 307 | def forward(self, x): 308 | t = time.time() 309 | # output (bs, ncaps, ndim) 310 | self.x_size = x.shape 311 | u = cp.concatenate([cap(x).reshape((x.shape[0], 
-1, 1)) for cap in self.caps], axis=-1) 312 | return self.squash(u) 313 | 314 | def backward(self, grads, optimizer=None): 315 | t = time.time() 316 | grads = self.squash.backward(grads) 317 | grads = grads.reshape((self.x_size[0],self.out_channels, self.mapsize, self.mapsize,-1)) 318 | grads = cp.concatenate([cp.expand_dims(self.caps[i].backward( 319 | grads[:,:,:,:,i], optimizer=optimizer), -1) for i in range(self.ndim)], axis=-1) 320 | out = cp.sum(grads, axis=-1) 321 | return out 322 | 323 | 324 | class Decoder(Module): 325 | def __init__(self): 326 | super(Decoder, self).__init__(trainable=True) 327 | self.net = Sequence([ 328 | Linear(16*10,512), 329 | ReLU(), 330 | Linear(512,1024), 331 | ReLU(), 332 | Linear(1024,784), 333 | Sigmoid() 334 | ]) 335 | self.x_shape = None 336 | 337 | def forward(self, x): 338 | self.x_shape = x.shape 339 | x = x.reshape(x.shape[0],-1) 340 | 341 | return self.net(x) 342 | 343 | def _set_params(self, params): 344 | for i, l in enumerate(self.net.trainable_modules()): 345 | l._set_params(params[i]) 346 | 347 | def backward(self, grad, optimizer): 348 | return self.net.backward(grad, optimizer).reshape(self.x_shape) 349 | 350 | 351 | class DigitCaps(Module): 352 | def __init__(self, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16): 353 | super(DigitCaps, self).__init__(trainable=True) 354 | self.ndim_prev = ndim_prev 355 | self.ncaps_prev = ncaps_prev 356 | self.ncaps = ncaps 357 | self.route_iter = 2 358 | self.W = cp.random.randn(1, ncaps_prev, ncaps, ndim, ndim_prev) 359 | self.softmaxs = [Softmax() for _ in range(self.route_iter)] 360 | self.squashs = [Squash() for _ in range(self.route_iter)] 361 | self.u_hat = None 362 | self.bs = None 363 | self.b = [None] * self.route_iter 364 | self.v = [None] * self.route_iter 365 | self.x = None 366 | 367 | def _set_params(self, params): 368 | self.W = params 369 | 370 | def forward(self, x): 371 | t = time.time() 372 | self.bs = x.shape[0] 373 | self.x = x 374 | x = tile(x[:,:,None,:,None], self.ncaps, 2) 375 | W = tile(self.W, self.bs, 0) 376 | u_hat = W @ x 377 | self.u_hat = u_hat 378 | b = cp.zeros((1, self.ncaps_prev, self.ncaps, 1, 1)) 379 | 380 | for r in range(self.route_iter): 381 | self.b[r] = b 382 | c = self.softmaxs[r](b, dim=1) 383 | 384 | c = tile(c, self.bs, 0) 385 | s = cp.sum(c * u_hat, axis=1, keepdims=True) 386 | v = self.squashs[r](s) 387 | if r == self.route_iter - 1: 388 | return cp.squeeze(v, axis=1) 389 | 390 | self.v[r] = v 391 | p = u_hat.swapaxes(-1, -2) @ tile(v, self.ncaps_prev, 1) 392 | b = b + cp.mean(p, axis=0, keepdims=True) 393 | 394 | 395 | def backward(self, grad, optimizer=None): 396 | t = time.time() 397 | grad_accum = cp.zeros_like(self.u_hat) 398 | b_grad_accum = None 399 | grad = grad[:,None,:,:,:] 400 | for r in range(self.route_iter)[::-1]: 401 | if r < self.route_iter-1: 402 | grad = b_grad_accum 403 | grad = tile(grad, self.bs, 0)/self.bs 404 | p_grad = tile(self.v[r], self.ncaps_prev, 1) * grad 405 | 406 | grad_accum += p_grad 407 | 408 | grad = self.u_hat * grad 409 | grad = cp.sum(grad, axis=1, keepdims=True) 410 | 411 | grad = self.squashs[r].backward(grad) 412 | grad = tile(grad, self.ncaps_prev, 1) 413 | c = self.softmaxs[r].s 414 | grad_accum += tile(c, self.bs, 0) * grad 415 | grad = self.u_hat.swapaxes(-1,-2) @ grad 416 | 417 | if r > 0: 418 | grad = cp.sum(grad, axis=0, keepdims=True) 419 | grad = self.softmaxs[r].backward(grad) 420 | if b_grad_accum is None: 421 | b_grad_accum = grad 422 | else: 423 | b_grad_accum += grad 424 | 425 | x = 
tile(self.x[:,:,None,:,None], self.ncaps, 2) 426 | dW = cp.sum(grad_accum @ x.swapaxes(-1,-2), axis=0, keepdims=True) 427 | if optimizer is not None: 428 | self.W = optimizer(self.W, dW) 429 | 430 | grad_accum = cp.squeeze(self.W.swapaxes(-1,-2) @ grad_accum, axis=-1) 431 | dx = cp.sum(grad_accum, axis=2) 432 | return dx 433 | 434 | 435 | class CapsNet(Module): 436 | def __init__(self, use_cuda=False, kernel_size=9, stride=1): 437 | super(CapsNet, self).__init__(trainable=True) 438 | self.net = Sequence([ 439 | Conv2d(1,256,kernel_size=kernel_size,stride=stride), 440 | ReLU(), 441 | PrimaryCaps(), 442 | DigitCaps() 443 | ]) 444 | self.decoder = Decoder() 445 | self.x = None 446 | self.digit_ndim = 16 447 | self.softmax = Softmax() 448 | 449 | def _set_params(self, params): 450 | for i, m in enumerate(self.net.trainable_modules() + [self.decoder]): 451 | m._set_params(params) 452 | 453 | def forward(self, x): 454 | x = self.net(x) 455 | self.x = x 456 | reconst = self.decoder(x) 457 | scores = cp.sqrt((x ** 2).sum(2)).squeeze() 458 | return scores, reconst 459 | 460 | def backward(self, grad, optimizer): 461 | scores_grad, reconst_grad = grad 462 | 463 | scores_grad = scores_grad[:,:,None, None] 464 | t = 0.5 * ((self.x ** 2).sum(2, keepdims=True) ** (-0.5)) 465 | scores_grad *= 0.5 * ((self.x ** 2).sum(2, keepdims=True) ** (-0.5)) 466 | scores_grad = tile(scores_grad, self.digit_ndim, 2) # tile at dimension 2 467 | scores_grad *= 2*self.x 468 | t = time.time() 469 | 470 | reconst_grad = self.decoder.backward(reconst_grad, optimizer) 471 | grad = scores_grad + reconst_grad 472 | 473 | grad = self.net.backward(grad, optimizer=optimizer) 474 | return grad 475 | 476 | 477 | class CapsLoss(Module): 478 | def __init__(self): 479 | super(CapsLoss, self).__init__() 480 | self.mse_loss = MSELoss() 481 | self.relu1 = ReLU() 482 | self.relu2 = ReLU() 483 | self.reconst_factor = 0.0005 484 | 485 | 486 | def forward(self, norms, labels, reconst, inpt): 487 | self.labels = labels 488 | 489 | int1 = self.relu1(0.9 - norms) 490 | int2 = self.relu2(norms - 0.1) 491 | margin_loss = labels * int1**2 + 0.5*(1-labels) * int2**2 492 | bs, ndim_prev = margin_loss.shape[0], margin_loss.shape[-1] 493 | margin_loss = cp.sum(margin_loss, axis=-1).mean() 494 | 495 | reconst_loss, reconst_grad = self.mse_loss(reconst.reshape(reconst.shape[0],-1), inpt.reshape(inpt.shape[0],-1)) 496 | loss = margin_loss + self.reconst_factor * reconst_loss 497 | 498 | margin_grad = cp.ones((bs, ndim_prev)) / float(bs) 499 | margin_grad_pos = -self.relu1.backward(margin_grad * labels * (2*int1)) 500 | margin_grad_neg = self.relu2.backward(margin_grad * 0.5*(1-labels) * (2*int2)) 501 | 502 | margin_grad = margin_grad_pos + margin_grad_neg 503 | reconst_grad *= self.reconst_factor 504 | 505 | return loss, (margin_grad, reconst_grad) -------------------------------------------------------------------------------- /numpy/modules.py: -------------------------------------------------------------------------------- 1 | 2 | # im2col functions adapted from https://github.com/Burton2000/CS231n-2017/blob/master/assignment2/cs231n/im2col.py 3 | 4 | import numpy as np 5 | import time, os 6 | import multiprocessing as mp 7 | from functools import partial 8 | 9 | def tile(arr, copy, axis): 10 | return np.concatenate([arr] * copy, axis=axis) 11 | 12 | 13 | class Module(object): 14 | def __init__(self, trainable=False): 15 | self.trainable = trainable 16 | pass 17 | 18 | def forward(self, x): 19 | raise NotImplementedError 20 | 21 | def backward(self, 
grad, optimizer=None): 22 | raise NotImplementedError 23 | 24 | def __call__(self, *input, **kwargs): 25 | return self.forward(*input, **kwargs) 26 | 27 | 28 | class Sequence(Module): 29 | def __init__(self, modules): 30 | self._modules = modules 31 | 32 | def forward(self, inpt): 33 | t = time.time() 34 | for module in self._modules: 35 | inpt = module(inpt) 36 | cur = time.time() 37 | t = cur 38 | if module.trainable: 39 | self.trainable = True 40 | return inpt 41 | 42 | def backward(self, grad, optimizer=None): 43 | for module in self._modules[::-1]: 44 | if module.trainable: 45 | grad = module.backward(grad, optimizer) 46 | else: 47 | grad = module.backward(grad) 48 | 49 | return grad 50 | 51 | def modules(self): 52 | return self._modules 53 | 54 | def trainable_modules(self): 55 | return [i for i in self._modules if i.trainable] 56 | 57 | 58 | class Linear(Module): 59 | def __init__(self, in_channel, out_channel): 60 | super(Linear, self).__init__(trainable=True) 61 | std = 1/np.sqrt(in_channel) 62 | self.w = np.random.uniform(-std, std, (out_channel, in_channel)) 63 | self.b = np.random.uniform(-std, std, (1, out_channel)) 64 | self.x = None 65 | 66 | def _set_params(self, params): 67 | w, b = params 68 | self.w = w 69 | self.b = b 70 | if len(self.b.shape) < 2: 71 | self.b = self.b[None,:] 72 | 73 | def forward(self, x): 74 | out = x.dot(self.w.T) + self.b 75 | self.x = x 76 | return out 77 | 78 | def backward(self, grad, optimizer=None): 79 | dw = (self.x.T @ grad).T 80 | db = np.sum(grad, axis=0, keepdims=True) 81 | # update parameters 82 | if optimizer is not None: 83 | self.w = optimizer(self.w, dw) 84 | self.b = optimizer(self.b, db) 85 | 86 | dx = grad @ self.w 87 | dx = np.reshape(dx, self.x.shape) 88 | return dx 89 | 90 | 91 | class ReLU(Module): 92 | def __init__(self, alpha=0): 93 | super(ReLU, self).__init__() 94 | self.alpha = alpha 95 | self.x = None 96 | 97 | def forward(self, x): 98 | out = x.copy() 99 | if self.alpha > 0: 100 | out[out<0] = self.alpha*x 101 | else: 102 | out[out<0] = 0 103 | self.x = x 104 | return out 105 | 106 | def backward(self, grad): 107 | dx = grad.copy() 108 | dx[self.x < 0] = 0 109 | return dx 110 | 111 | class Sigmoid(Module): 112 | def __init__(self): 113 | super(Sigmoid, self).__init__() 114 | self.s = None 115 | 116 | def forward(self, x): 117 | self.s = 1/(1 + np.exp(-x)) 118 | return self.s 119 | 120 | def backward(self, grad): 121 | return grad * (self.s * (1-self.s)) 122 | 123 | 124 | class Conv2d(Module): 125 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad=0, eps=1e-4): 126 | super(Conv2d, self).__init__(trainable=True) 127 | self.ic = in_channels 128 | self.oc = out_channels 129 | self.k = kernel_size 130 | self.s = stride 131 | self.p = pad 132 | 133 | std = 1/(np.sqrt(self.ic* self.k**2)) 134 | self.W = np.random.uniform(-std, std, (self.oc,self.ic,self.k,self.k)) 135 | self.b = np.random.uniform(-std, std, (self.oc, 1)) 136 | 137 | self.X_col = None 138 | self.x_shape = None 139 | 140 | def _set_params(self, params): 141 | W, b = params 142 | self.W = W 143 | self.b = b 144 | 145 | def _set_input(self, x): 146 | self.x_shape = x.shape 147 | self.X_col = self.im2col_indices(x) 148 | 149 | def forward(self, X): 150 | NF, CF, HF, WF = self.W.shape 151 | NX, DX, HX, WX = X.shape 152 | self.x_shape = X.shape 153 | h_out = int((HX - HF + 2 * self.p) / self.s + 1) 154 | w_out = int((WX - WF + 2 * self.p) / self.s + 1) 155 | 156 | X_col = self.im2col_indices(X) 157 | self.X_col = X_col 158 | W_col = 
159 | 
160 |         out = W_col @ self.X_col + self.b
161 |         out = out.reshape(NF, h_out, w_out, NX)
162 |         out = out.transpose(3, 0, 1, 2)
163 | 
164 |         return out
165 | 
166 | 
167 |     def backward(self, dout, optimizer=None):
168 |         NF, CF, HF, WF = self.W.shape
169 | 
170 |         db = np.sum(dout, axis=(0, 2, 3))
171 |         db = db.reshape(NF, -1)
172 | 
173 |         dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(NF, -1)
174 |         dW = dout_reshaped @ self.X_col.T
175 |         dW = dW.reshape(self.W.shape)
176 | 
177 |         if optimizer is not None:
178 |             self.b = optimizer(self.b, db)
179 |             self.W = optimizer(self.W, dW)
180 | 
181 |         W_reshape = self.W.reshape(NF, -1)
182 |         dX_col = W_reshape.T @ dout_reshaped
183 |         dX = self.col2im_indices(dX_col)
184 | 
185 |         return dX
186 | 
187 |     def get_im2col_indices(self):
188 |         padding, stride, field_height, field_width, x_shape = self.p, self.s, self.k, self.k, self.x_shape
189 |         N, C, H, W = x_shape
190 |         # assert (H + 2 * padding - field_height) % stride == 0
191 |         # assert (W + 2 * padding - field_height) % stride == 0
192 |         out_height = int((H + 2 * padding - field_height) / stride + 1)
193 |         out_width = int((W + 2 * padding - field_width) / stride + 1)
194 | 
195 |         i0 = np.repeat(np.arange(field_height), field_width)
196 |         i0 = np.tile(i0, C)
197 |         i1 = stride * np.repeat(np.arange(out_height), out_width)
198 |         j0 = np.tile(np.arange(field_width), field_height * C)
199 |         j1 = stride * np.tile(np.arange(out_width), out_height)
200 |         i = i0.reshape(-1, 1) + i1.reshape(1, -1)
201 |         j = j0.reshape(-1, 1) + j1.reshape(1, -1)
202 | 
203 |         k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
204 | 
205 |         return (k.astype(int), i.astype(int), j.astype(int))
206 | 
207 | 
208 |     def im2col_indices(self, x):
209 |         p, stride, field_height, field_width = self.p, self.s, self.k, self.k
210 |         x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
211 | 
212 |         k, i, j = self.get_im2col_indices()
213 | 
214 |         cols = x_padded[:, k, i, j]
215 |         C = x.shape[1]
216 |         cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
217 |         return cols
218 | 
219 | 
220 |     def col2im_indices(self, cols):
221 |         field_height, field_width, padding, stride = self.k, self.k, self.p, self.s
222 |         N, C, H, W = self.x_shape
223 |         H_padded, W_padded = H + 2 * padding, W + 2 * padding
224 |         x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
225 |         k, i, j = self.get_im2col_indices()
226 |         cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
227 |         cols_reshaped = cols_reshaped.transpose(2, 0, 1)
228 |         np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
229 |         if padding == 0:
230 |             return x_padded
231 |         return x_padded[:, :, padding:-padding, padding:-padding]
232 | 
233 | 
234 | class Softmax(Module):
235 |     def __init__(self, dim=-1):
236 |         super(Softmax, self).__init__()
237 |         self.s = None
238 |         self.dim = dim
239 |         self.squeeze_len = None
240 | 
241 |     def forward(self, x, dim=None):
242 |         if dim is not None:
243 |             self.dim = dim
244 |         if self.dim < 0:
245 |             self.dim = len(x.shape)+self.dim
246 |         self.squeeze_len = x.shape[self.dim]
247 |         y = np.exp(x)
248 |         s = y/np.sum(y, axis=self.dim, keepdims=True)
249 |         self.s = s
250 |         return s
251 | 
252 |     def backward(self, grad):
253 |         self.s = np.expand_dims(self.s.swapaxes(self.dim,-1), -1)
254 |         grad = np.expand_dims(grad.swapaxes(self.dim,-1), -1)
255 |         mat = self.s @ self.s.swapaxes(-1,-2)
256 |         mat = (-mat + np.eye(mat.shape[-1]) * (mat**0.5))
257 |         grad = mat @ grad
258 |         self.s = self.s.swapaxes(self.dim,-1).squeeze(-1)
259 |         return grad.swapaxes(self.dim,-2).squeeze(-1)
260 | 
261 | 
262 | class Squash(Module):
263 |     def __init__(self, dim=-1):
264 |         super(Squash, self).__init__()
265 |         self.dim = dim
266 |         self.squeeze_len = None
267 |         self.s = None
268 | 
269 |     def forward(self, s):
270 |         self.s = s
271 |         self.squeeze_len = s.shape[self.dim]
272 |         norm2 = np.sum((s)**2, axis=self.dim, keepdims=True)
273 |         return (np.sqrt(norm2) / (1.0 + norm2)) * s
274 | 
275 |     def backward(self, grad):
276 |         norm2 = np.sum((self.s)**2, axis=self.dim, keepdims=True)
277 |         norm = np.sqrt(norm2)
278 |         temp = tile((1/(2*(1.+norm2)*norm) - norm/(1.+norm2)**2), self.squeeze_len, self.dim)
279 |         dnorm2 = np.sum(self.s * temp, axis=-1, keepdims=True)
280 |         factor = norm/(1+norm2)
281 |         return grad * dnorm2 * (2.*self.s) + grad * factor
282 | 
283 | class MSELoss(Module):
284 |     def __init__(self):
285 |         super(MSELoss, self).__init__()
286 |         self.x = None
287 |         self.y = None
288 | 
289 |     def forward(self, x, y):
290 |         self.x = x
291 |         self.y = y
292 |         return np.sum((x - y)**2)/float(x.size), 2*(x - y)/float(x.size)
293 | 
294 | 
295 | class PrimaryCaps(Module):
296 |     def __init__(self, use_cuda=False, out_channels=32, in_channels=256, mapsize=6, ndim=8, kernel_size=9, stride=2, padding=0):
297 |         super(PrimaryCaps, self).__init__(trainable=True)
298 |         self.ndim = ndim
299 |         self.caps = [Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad=padding) for _ in
300 |                      range(ndim)]
301 | 
302 |         self.out_channels = out_channels
303 |         self.mapsize = mapsize
304 |         self.ncaps = out_channels * mapsize**2
305 |         self.squash = Squash()
306 |         self.x = None
307 | 
308 |     def _set_params(self, params):
309 |         for i, c in enumerate(self.caps):
310 |             c._set_params(params[i])
311 | 
312 |     def cap_forward(self, i, x):
313 |         out = self.caps[i](x).reshape((x.shape[0], -1, 1))
314 |         return out
315 | 
316 |     def cap_backward(self, i, grads, x, optimizer):
317 |         self.caps[i]._set_input(x)
318 |         out = np.expand_dims(self.caps[i].backward(
319 |             grads[:,:,:,:,i], optimizer=optimizer), -1)
320 |         return out
321 | 
322 |     def forward(self, x):
323 |         t = time.time()
324 |         # output (bs, ncaps, ndim)
325 |         self.x_size = x.shape
326 |         self.x = x
327 |         with mp.Pool() as pool:
328 |             u = pool.map(partial(self.cap_forward, x=x), np.arange(len(self.caps)))
329 |         u = np.concatenate(u, axis=-1)
330 | 
331 |         return self.squash(u)
332 | 
333 |     def backward(self, grads, optimizer=None):
334 |         t = time.time()
335 |         grads = self.squash.backward(grads)
336 |         grads = grads.reshape((self.x_size[0],self.out_channels, self.mapsize, self.mapsize,-1))
337 | 
338 |         with mp.Pool() as pool:
339 |             grads = pool.map(partial(self.cap_backward, grads=grads, x=self.x, optimizer=optimizer), np.arange(len(self.caps)))
340 |         grads = np.concatenate(grads, axis=-1)
341 |         out = np.sum(grads, axis=-1)
342 | 
343 |         return out
344 | 
345 | 
346 | class Decoder(Module):
347 |     def __init__(self):
348 |         super(Decoder, self).__init__(trainable=True)
349 |         self.net = Sequence([
350 |             Linear(16*10,512),
351 |             ReLU(),
352 |             Linear(512,1024),
353 |             ReLU(),
354 |             Linear(1024,784),
355 |             Sigmoid()
356 |         ])
357 |         self.x_shape = None
358 | 
359 |     def forward(self, x):
360 |         self.x_shape = x.shape
361 |         x = x.reshape(x.shape[0],-1)
362 | 
363 |         return self.net(x)
364 | 
365 |     def _set_params(self, params):
366 |         for i, l in enumerate(self.net.trainable_modules()):
367 |             l._set_params(params[i])
368 | 
369 |     def backward(self, grad, optimizer):
370 |         return self.net.backward(grad, optimizer).reshape(self.x_shape)
371 | 
372 | 
373 | class DigitCaps(Module):
374 |     def __init__(self, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16):
375 |         super(DigitCaps, self).__init__(trainable=True)
376 |         self.ndim_prev = ndim_prev
377 |         self.ncaps_prev = ncaps_prev
378 |         self.ncaps = ncaps
379 |         self.route_iter = 2
380 |         self.W = np.random.randn(1, ncaps_prev, ncaps, ndim, ndim_prev)
381 |         self.softmaxs = [Softmax() for _ in range(self.route_iter)]
382 |         self.squashs = [Squash() for _ in range(self.route_iter)]
383 |         self.u_hat = None
384 |         self.bs = None
385 |         self.b = [None] * self.route_iter
386 |         self.v = [None] * self.route_iter
387 |         self.x = None
388 | 
389 |     def _set_params(self, params):
390 |         self.W = params
391 | 
392 |     def forward(self, x):
393 |         t = time.time()
394 |         self.bs = x.shape[0]
395 |         self.x = x
396 |         x = tile(x[:,:,None,:,None], self.ncaps, 2)
397 |         W = tile(self.W, self.bs, 0)
398 |         u_hat = W @ x
399 |         self.u_hat = u_hat
400 |         b = np.zeros((1, self.ncaps_prev, self.ncaps, 1, 1))
401 | 
402 |         for r in range(self.route_iter):
403 |             self.b[r] = b
404 |             c = self.softmaxs[r](b, dim=1)
405 | 
406 |             c = tile(c, self.bs, 0)
407 |             s = np.sum(c * u_hat, axis=1, keepdims=True)
408 |             v = self.squashs[r](s)
409 |             if r == self.route_iter - 1:
410 |                 return np.squeeze(v, axis=1)
411 | 
412 |             self.v[r] = v
413 |             p = u_hat.swapaxes(-1, -2) @ tile(v, self.ncaps_prev, 1)
414 |             b = b + np.mean(p, axis=0, keepdims=True)
415 | 
416 | 
417 |     def backward(self, grad, optimizer=None):
418 |         t = time.time()
419 |         grad_accum = np.zeros_like(self.u_hat)
420 |         b_grad_accum = None
421 |         grad = grad[:,None,:,:,:]
422 |         for r in range(self.route_iter)[::-1]:
423 |             if r < self.route_iter-1:
424 |                 grad = b_grad_accum
425 |                 grad = tile(grad, self.bs, 0)/self.bs
426 |                 p_grad = tile(self.v[r], self.ncaps_prev, 1) * grad
427 | 
428 |                 grad_accum += p_grad
429 | 
430 |                 grad = self.u_hat * grad
431 |                 grad = np.sum(grad, axis=1, keepdims=True)
432 | 
433 |             grad = self.squashs[r].backward(grad)
434 |             grad = tile(grad, self.ncaps_prev, 1)
435 |             c = self.softmaxs[r].s
436 |             grad_accum += tile(c, self.bs, 0) * grad
437 |             grad = self.u_hat.swapaxes(-1,-2) @ grad
438 | 
439 |             if r > 0:
440 |                 grad = np.sum(grad, axis=0, keepdims=True)
441 |                 grad = self.softmaxs[r].backward(grad)
442 |                 if b_grad_accum is None:
443 |                     b_grad_accum = grad
444 |                 else:
445 |                     b_grad_accum += grad
446 | 
447 |         x = tile(self.x[:,:,None,:,None], self.ncaps, 2)
448 |         dW = np.sum(grad_accum @ x.swapaxes(-1,-2), axis=0, keepdims=True)
449 |         if optimizer is not None:
450 |             self.W = optimizer(self.W, dW)
451 | 
452 |         grad_accum = np.squeeze(self.W.swapaxes(-1,-2) @ grad_accum, axis=-1)
453 |         dx = np.sum(grad_accum, axis=2)
454 |         return dx
455 | 
456 | 
457 | class CapsNet(Module):
458 |     def __init__(self, use_cuda=False, kernel_size=9, stride=1):
459 |         super(CapsNet, self).__init__(trainable=True)
460 |         self.net = Sequence([
461 |             Conv2d(1,256,kernel_size=kernel_size,stride=stride),
462 |             ReLU(),
463 |             PrimaryCaps(),
464 |             DigitCaps()
465 |         ])
466 |         self.decoder = Decoder()
467 |         self.x = None
468 |         self.digit_ndim = 16
469 |         self.softmax = Softmax()
470 | 
471 |     def _set_params(self, params):
472 |         for i, m in enumerate(self.net.trainable_modules() + [self.decoder]):
473 |             m._set_params(params[i])
474 | 
475 |     def forward(self, x):
476 |         x = self.net(x)
477 |         self.x = x
478 |         reconst = self.decoder(x)
479 |         scores = np.sqrt((x ** 2).sum(2)).squeeze()
480 |         return scores, reconst
481 | 
482 |     def backward(self, grad, optimizer):
483 |         scores_grad, reconst_grad = grad
484 | 
485 |         scores_grad = scores_grad[:,:,None, None]
486 |         t = 0.5 * ((self.x ** 2).sum(2, keepdims=True) ** (-0.5))
487 |         scores_grad *= t
488 |         scores_grad = tile(scores_grad, self.digit_ndim, 2) # tile at dimension 2
489 |         scores_grad *= 2*self.x
490 |         t = time.time()
491 | 
492 |         reconst_grad = self.decoder.backward(reconst_grad, optimizer)
493 |         grad = scores_grad + reconst_grad
494 | 
495 |         grad = self.net.backward(grad, optimizer=optimizer)
496 |         return grad
497 | 
498 | 
499 | class CapsLoss(Module):
500 |     def __init__(self):
501 |         super(CapsLoss, self).__init__()
502 |         self.mse_loss = MSELoss()
503 |         self.relu1 = ReLU()
504 |         self.relu2 = ReLU()
505 |         self.reconst_factor = 0.0005
506 | 
507 | 
508 |     def forward(self, norms, labels, reconst, inpt):
509 |         self.labels = labels
510 | 
511 |         int1 = self.relu1(0.9 - norms)
512 |         int2 = self.relu2(norms - 0.1)
513 |         margin_loss = labels * int1**2 + 0.5*(1-labels) * int2**2
514 |         bs, ndim_prev = margin_loss.shape[0], margin_loss.shape[-1]
515 |         margin_loss = np.sum(margin_loss, axis=-1).mean()
516 | 
517 |         reconst_loss, reconst_grad = self.mse_loss(reconst.reshape(reconst.shape[0],-1), inpt.reshape(inpt.shape[0],-1))
518 |         loss = margin_loss + self.reconst_factor * reconst_loss
519 | 
520 |         margin_grad = np.ones((bs, ndim_prev)) / float(bs)
521 |         margin_grad_pos = -self.relu1.backward(margin_grad * labels * (2*int1))
522 |         margin_grad_neg = self.relu2.backward(margin_grad * 0.5*(1-labels) * (2*int2))
523 | 
524 |         margin_grad = margin_grad_pos + margin_grad_neg
525 |         reconst_grad *= self.reconst_factor
526 | 
527 |         return loss, (margin_grad, reconst_grad)
--------------------------------------------------------------------------------
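The NumPy modules above compose into a complete forward/backward step as shown below. This is only a hedged smoke-test sketch, not the project's actual training loop (that lives in numpy/main.py, with data loading in mnist.py and the real optimizers in optim.py, none of which are shown here). The file name smoke_test.py, the sgd callable, and the random 28x28 inputs are illustrative assumptions, and the script is assumed to be run from inside the numpy/ directory so that the import of modules resolves to this file.

# smoke_test.py -- hypothetical; run from inside the numpy/ directory.
import numpy as np
from modules import CapsNet, CapsLoss

def sgd(param, grad, lr=0.01):
    # modules.py only requires a callable optimizer(param, grad) -> updated param
    return param - lr * grad

if __name__ == "__main__":   # guard needed because PrimaryCaps uses multiprocessing
    model = CapsNet()
    criterion = CapsLoss()

    x = np.random.rand(2, 1, 28, 28)       # two fake 28x28 grayscale images
    labels = np.zeros((2, 10))
    labels[np.arange(2), [3, 7]] = 1.0     # one-hot targets

    scores, reconst = model(x)             # scores: (2, 10), reconst: (2, 784)
    loss, grads = criterion(scores, labels, reconst, x)
    model.backward(grads, sgd)             # backprop with in-place parameter updates
    print("loss:", loss)

Here scores are the digit-capsule lengths returned by CapsNet.forward, and grads is the (margin_grad, reconst_grad) pair that CapsNet.backward expects from CapsLoss.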