├── cupy
│   ├── __init__.py
│   ├── .gitignore
│   ├── __pycache__
│   │   ├── mnist.cpython-35.pyc
│   │   ├── optim.cpython-35.pyc
│   │   └── modules.cpython-35.pyc
│   ├── optim.py
│   ├── mnist.py
│   ├── main.py
│   └── modules.py
├── numpy
│   ├── __init__.py
│   ├── .gitignore
│   ├── optim.py
│   ├── mnist.py
│   ├── main.py
│   └── modules.py
├── .gitignore
├── imgs
│   └── README.md
├── pytorch
│   ├── __init__.py
│   ├── .gitignore
│   ├── modules.py
│   └── main.py
├── capsnet.png
├── decoder.png
├── perturb.jpg
├── reconst.jpeg
├── compgraph_digitcaps.png
├── LICENSE
└── README.md
/cupy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/numpy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
--------------------------------------------------------------------------------
/imgs/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cupy/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 |
--------------------------------------------------------------------------------
/pytorch/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/numpy/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | __pycache__/*
3 |
--------------------------------------------------------------------------------
/pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | saved_models/*
3 |
--------------------------------------------------------------------------------
/capsnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/capsnet.png
--------------------------------------------------------------------------------
/decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/decoder.png
--------------------------------------------------------------------------------
/perturb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/perturb.jpg
--------------------------------------------------------------------------------
/reconst.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/reconst.jpeg
--------------------------------------------------------------------------------
/compgraph_digitcaps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/compgraph_digitcaps.png
--------------------------------------------------------------------------------
/cupy/__pycache__/mnist.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/mnist.cpython-35.pyc
--------------------------------------------------------------------------------
/cupy/__pycache__/optim.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/optim.cpython-35.pyc
--------------------------------------------------------------------------------
/cupy/__pycache__/modules.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/D-K-E/pyCapsNet/master/cupy/__pycache__/modules.cpython-35.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Xander Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cupy/optim.py:
--------------------------------------------------------------------------------
1 | import cupy as cp
2 |
3 | class Optimizer:
4 | def __init__(self):
5 | self.t = 0
6 |
7 | def step(self):
8 | self.t += 1
9 |
10 | def update_val(self, x, dx):
11 | raise NotImplementedError
12 |
13 | def __call__(self, *input, **kwargs):
14 | return self.update_val(*input, **kwargs)
15 |
16 |
17 | class AdamOptimizer(Optimizer):
18 | def __init__(self, lr=1e-2, beta=(0.9,0.999), eps=1e-8):
19 | super(AdamOptimizer, self).__init__()
20 | self.lr = lr
21 | self.beta = beta
22 | self.eps = eps
23 |         self.m = {}  # per-parameter first moments, keyed by id(param)
24 |         self.v = {}  # per-parameter second moments (one optimizer instance serves all layers)
25 |
26 |     def update_val(self, x, dx):
27 |         key = id(x)  # parameters are updated in place, so their ids stay stable
28 |         if key not in self.m:
29 |             self.m[key], self.v[key] = cp.zeros_like(x), cp.zeros_like(x)
30 |         beta1, beta2 = self.beta
31 |         m = beta1 * self.m[key] + (1 - beta1) * dx
32 |         v = beta2 * self.v[key] + (1 - beta2) * dx**2
33 |         alpha = self.lr * cp.sqrt(1 - beta2 ** self.t) / (1 - beta1 ** self.t)
34 |         x -= alpha * (m / (cp.sqrt(v) + self.eps))
35 |         self.m[key] = m
36 |         self.v[key] = v
37 |         return x
38 |
39 |
--------------------------------------------------------------------------------
/numpy/optim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Optimizer:
4 | def __init__(self):
5 | self.t = 0
6 |
7 | def step(self):
8 | self.t += 1
9 |
10 | def update_val(self, x, dx):
11 | raise NotImplementedError
12 |
13 | def __call__(self, *input, **kwargs):
14 | return self.update_val(*input, **kwargs)
15 |
16 |
17 | class AdamOptimizer(Optimizer):
18 | def __init__(self, lr=1e-2, beta=(0.9,0.999), eps=1e-8):
19 | super(AdamOptimizer, self).__init__()
20 | self.lr = lr
21 | self.beta = beta
22 | self.eps = eps
23 |         self.m = {}  # per-parameter first moments, keyed by id(param)
24 |         self.v = {}  # per-parameter second moments (one optimizer instance serves all layers)
25 |
26 |     def update_val(self, x, dx):
27 |         key = id(x)  # parameters are updated in place, so their ids stay stable
28 |         if key not in self.m:
29 |             self.m[key], self.v[key] = np.zeros_like(x), np.zeros_like(x)
30 |         beta1, beta2 = self.beta
31 |         m = beta1 * self.m[key] + (1 - beta1) * dx
32 |         v = beta2 * self.v[key] + (1 - beta2) * dx**2
33 |         alpha = self.lr * np.sqrt(1 - beta2 ** self.t) / (1 - beta1 ** self.t)
34 |         x -= alpha * (m / (np.sqrt(v) + self.eps))
35 |         self.m[key] = m
36 |         self.v[key] = v
37 |         return x
38 |
39 |
--------------------------------------------------------------------------------
/numpy/mnist.py:
--------------------------------------------------------------------------------
1 | import time, os
2 | import numpy as np
3 | from urllib import request
4 | import gzip
5 | import pickle
6 |
7 |
8 | class MNIST:
9 | def __init__(self, path='data', bs=1, shuffle=False):
10 | self.filename = [
11 | ["training_images","train-images-idx3-ubyte.gz"],
12 | ["test_images","t10k-images-idx3-ubyte.gz"],
13 | ["training_labels","train-labels-idx1-ubyte.gz"],
14 | ["test_labels","t10k-labels-idx1-ubyte.gz"]
15 | ]
16 | self.mean = 0.1307
17 | self.std = 0.3081
18 | self.num_classes = 10
19 | self.bs = bs
20 | self.path = path
21 |
22 | if not os.path.exists(self.path):
23 | os.mkdir(self.path)
24 | if not os.path.exists(self.path+'/mnist.pkl'):
25 | self.download_mnist()
26 | self.load(shuffle=shuffle)
27 | print('Loading complete.')
28 |
29 | def download_mnist(self):
30 | base_url = "http://yann.lecun.com/exdb/mnist/"
31 | for name in self.filename:
32 | print("Downloading "+name[1]+"...")
33 | request.urlretrieve(base_url+name[1], self.path+'/'+name[1])
34 | print("Download complete.")
35 | self.save_mnist()
36 |
37 | def save_mnist(self):
38 | mnist = {}
39 | for name in self.filename[:2]:
40 | with gzip.open(self.path+'/'+name[1], 'rb') as f:
41 | mnist[name[0]] = ((np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28))/255.-self.mean)/self.std
42 | for name in self.filename[-2:]:
43 | with gzip.open(self.path+'/'+name[1], 'rb') as f:
44 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
45 | with open(self.path+'/'+"mnist.pkl", 'wb') as f:
46 | pickle.dump(mnist,f)
47 | print("Save complete.")
48 |
49 | def chunks(self, l):
50 | for i in range(0, len(l), self.bs):
51 | yield l[i:i + self.bs]
52 |
53 | def load(self, shuffle=False):
54 | with open(self.path+"/mnist.pkl",'rb') as f:
55 | mnist = pickle.load(f)
56 | if shuffle:
57 | n = mnist['training_images'].shape[0]
58 | idxs = np.arange(n)
59 | np.random.shuffle(idxs)
60 | mnist['training_images'] = mnist['training_images'].reshape((-1,1,28,28))
61 | mnist['training_images'] = list(self.chunks(mnist['training_images'][idxs]))
62 | mnist['training_labels'] = list(self.chunks(mnist['training_labels'][idxs]))
63 | self.train_dataset = zip(mnist['training_images'], mnist['training_labels'])
64 |
65 | n = mnist['test_images'].shape[0]
66 | idxs = np.arange(n)
67 | np.random.shuffle(idxs)
68 | mnist['test_images'] = mnist['test_images'].reshape((-1,1,28,28))
69 | mnist['test_images'] = list(self.chunks(mnist['test_images'][idxs]))
70 | mnist['test_labels'] = list(self.chunks(mnist['test_labels'][idxs]))
71 | self.eval_dataset = zip(mnist['test_images'], mnist['test_labels'])
72 |
73 |
--------------------------------------------------------------------------------
/cupy/mnist.py:
--------------------------------------------------------------------------------
1 | import time, os
2 | import numpy as np
3 | import cupy as cp
4 | from urllib import request
5 | import gzip
6 | import pickle
7 |
8 |
9 | class MNIST:
10 | def __init__(self, path='data', bs=1, shuffle=False):
11 | self.filename = [
12 | ["training_images","train-images-idx3-ubyte.gz"],
13 | ["test_images","t10k-images-idx3-ubyte.gz"],
14 | ["training_labels","train-labels-idx1-ubyte.gz"],
15 | ["test_labels","t10k-labels-idx1-ubyte.gz"]
16 | ]
17 | self.mean = 0.1307
18 | self.std = 0.3081
19 | self.num_classes = 10
20 | self.bs = bs
21 | self.path = path
22 | if not os.path.exists(self.path):
23 | os.mkdir(self.path)
24 | if not os.path.exists(self.path+'/mnist.pkl'):
25 | self.download_mnist()
26 | self.load(shuffle=shuffle)
27 | print('Loading complete.')
28 |
29 | def download_mnist(self):
30 | base_url = "http://yann.lecun.com/exdb/mnist/"
31 | for name in self.filename:
32 | print("Downloading "+name[1]+"...")
33 | request.urlretrieve(base_url+name[1], self.path+'/'+name[1])
34 | print("Download complete.")
35 | self.save_mnist()
36 |
37 | def save_mnist(self):
38 | mnist = {}
39 | for name in self.filename[:2]:
40 | with gzip.open(self.path+'/'+name[1], 'rb') as f:
41 | mnist[name[0]] = ((np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28))/255.-self.mean)/self.std
42 | for name in self.filename[-2:]:
43 | with gzip.open(self.path+'/'+name[1], 'rb') as f:
44 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
45 | with open(self.path+'/'+"mnist.pkl", 'wb') as f:
46 | pickle.dump(mnist,f)
47 | print("Save complete.")
48 |
49 | def chunks(self, l):
50 | for i in range(0, len(l), self.bs):
51 | yield l[i:i + self.bs]
52 |
53 | def load(self, shuffle=False):
54 | with open(self.path+"/mnist.pkl",'rb') as f:
55 | mnist = pickle.load(f)
56 | if shuffle:
57 | n = mnist['training_images'].shape[0]
58 | idxs = np.arange(n)
59 | np.random.shuffle(idxs)
60 | mnist['training_images'] = mnist['training_images'].reshape((-1,1,28,28))
61 | mnist['training_images'] = list(self.chunks(mnist['training_images'][idxs]))
62 | mnist['training_labels'] = list(self.chunks(mnist['training_labels'][idxs]))
63 | self.train_dataset = zip(cp.array(mnist['training_images']), cp.array(mnist['training_labels']))
64 |
65 | n = mnist['test_images'].shape[0]
66 | idxs = np.arange(n)
67 | np.random.shuffle(idxs)
68 | mnist['test_images'] = mnist['test_images'].reshape((-1,1,28,28))
69 | mnist['test_images'] = list(self.chunks(mnist['test_images'][idxs]))
70 | mnist['test_labels'] = list(self.chunks(mnist['test_labels'][idxs]))
71 | self.eval_dataset = zip(cp.array(mnist['test_images']), cp.array(mnist['test_labels']))
72 |
73 |
--------------------------------------------------------------------------------
/cupy/main.py:
--------------------------------------------------------------------------------
1 | from modules import *
2 | import time, os, argparse
3 | import cupy as cp
4 | from mnist import MNIST
5 | from modules import CapsNet, CapsLoss
6 | from optim import AdamOptimizer
7 |
8 |
9 | def parse_args():
10 | """
11 | Parse input arguments
12 | """
13 |     parser = argparse.ArgumentParser(description='CuPy CapsNet')
14 | parser.add_argument('--bs', dest='bs',
15 | help='batch size',
16 | default='100', type=int)
17 | parser.add_argument('--lr', dest='lr',
18 | help='learning rate',
19 | default=1e-2, type=float)
20 | parser.add_argument('--opt', dest='opt',
21 | help='optimizer',
22 | default='adam', type=str)
23 | parser.add_argument('--disp', dest='disp_interval',
24 | help='interval to display training loss',
25 | default='10', type=int)
26 | parser.add_argument('--num_epochs', dest='num_epochs',
27 | help='num epochs to train',
28 | default='100', type=int)
29 | parser.add_argument('--val_epoch', dest='val_epoch',
30 | help='num epochs to run validation',
31 | default='1', type=int)
32 |
33 | args = parser.parse_args()
34 |
35 | return args
36 |
37 | if __name__ == '__main__':
38 |
39 | args = parse_args()
40 |
41 | mnist = MNIST(bs=args.bs, shuffle=True)
42 | eye = cp.eye(mnist.num_classes)
43 | model = CapsNet()
44 |
45 | criterion = CapsLoss()
46 | if args.opt == 'adam':
47 | optimizer = AdamOptimizer(lr=args.lr)
48 |
49 | print('Training started!')
50 |
51 | for epoch in range(args.num_epochs):
52 | start = time.time()
53 |
54 | # train
55 | correct = 0
56 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataset):
57 | optimizer.step()
58 | if imgs.shape[0] != args.bs:
59 | continue
60 |
61 | targets = eye[targets]
62 | scores, reconst = model(imgs)
63 | loss, grad = criterion(scores, targets, reconst, imgs)
64 | model.backward(grad, optimizer)
65 |
66 | classes = cp.argmax(scores, axis=1)
67 | predicted = eye[cp.squeeze(classes), :]
68 |
69 | predicted_idx = cp.argmax(predicted, 1)
70 | label_idx = cp.argmax(targets, 1)
71 | correct = cp.sum(predicted_idx == label_idx)
72 |
73 | # info
74 | if batch_idx % args.disp_interval == 0:
75 | end = time.time()
76 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \
77 | % (epoch, batch_idx, loss, 100.*correct/args.bs, correct, args.bs))
78 |
79 | # val
80 | if epoch % args.val_epoch == 0:
81 | print('Validating...')
82 | correct = 0
83 | total = 0
84 |
85 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataset):
86 | if imgs.shape[0] != args.bs:
87 | continue
88 |
89 | targets = eye[targets]
90 | scores, reconst = model(imgs)
91 | loss, grad = criterion(scores, targets, reconst, imgs)
92 |                 # no backward pass during validation; weights are not updated here
93 |
94 | classes = cp.argmax(scores, axis=1)
95 |                 predicted = eye[cp.squeeze(classes), :]
96 |
97 | predicted_idx = cp.argmax(predicted, 1)
98 | label_idx = cp.argmax(targets, 1)
99 | correct += cp.sum(predicted_idx == label_idx)
100 | total += targets.shape[0]
101 |
102 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \
103 | % (epoch, 100.*correct/total, correct, total))
104 |
--------------------------------------------------------------------------------
/numpy/main.py:
--------------------------------------------------------------------------------
1 | from modules import *
2 | import time, os, argparse
3 | import numpy as np
4 | from mnist import MNIST
5 | from modules import CapsNet, CapsLoss
6 | from optim import AdamOptimizer
7 | import multiprocessing as mp
8 |
9 | def parse_args():
10 | """
11 | Parse input arguments
12 | """
13 |     parser = argparse.ArgumentParser(description='NumPy CapsNet')
14 | parser.add_argument('--bs', dest='bs',
15 | help='batch size',
16 | default='100', type=int)
17 | parser.add_argument('--lr', dest='lr',
18 | help='learning rate',
19 | default=1e-2, type=float)
20 | parser.add_argument('--opt', dest='opt',
21 | help='optimizer',
22 | default='adam', type=str)
23 | parser.add_argument('--disp', dest='disp_interval',
24 | help='interval to display training loss',
25 | default='1', type=int)
26 | parser.add_argument('--num_epochs', dest='num_epochs',
27 | help='num epochs to train',
28 | default='100', type=int)
29 | parser.add_argument('--val_epoch', dest='val_epoch',
30 | help='num epochs to run validation',
31 | default='1', type=int)
32 |
33 | args = parser.parse_args()
34 |
35 | return args
36 |
37 | if __name__ == '__main__':
38 | mp.set_start_method('spawn')
39 | args = parse_args()
40 |
41 | mnist = MNIST(bs=args.bs, shuffle=True)
42 | eye = np.eye(mnist.num_classes)
43 | model = CapsNet()
44 |
45 | criterion = CapsLoss()
46 | if args.opt == 'adam':
47 | optimizer = AdamOptimizer(lr=args.lr)
48 |
49 | print('Training started!')
50 |
51 | for epoch in range(args.num_epochs):
52 | start = time.time()
53 |
54 | # train
55 | correct = 0
56 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataset):
57 | optimizer.step()
58 | if imgs.shape[0] != args.bs:
59 | continue
60 |
61 | targets = eye[targets]
62 | scores, reconst = model(imgs)
63 | loss, grad = criterion(scores, targets, reconst, imgs)
64 | model.backward(grad, optimizer)
65 |
66 | classes = np.argmax(scores, axis=1)
67 | predicted = eye[np.squeeze(classes), :]
68 |
69 | predicted_idx = np.argmax(predicted, 1)
70 | label_idx = np.argmax(targets, 1)
71 | correct = np.sum(predicted_idx == label_idx)
72 |
73 | # info
74 | if batch_idx % args.disp_interval == 0:
75 | end = time.time()
76 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \
77 | % (epoch, batch_idx, loss, 100.*correct/args.bs, correct, args.bs))
78 |
79 | # val
80 | if epoch % args.val_epoch == 0:
81 | print('Validating...')
82 | correct = 0
83 | total = 0
84 |
85 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataset):
86 | if imgs.shape[0] != args.bs:
87 | continue
88 |
89 | targets = eye[targets]
90 | scores, reconst = model(imgs)
91 | loss, grad = criterion(scores, targets, reconst, imgs)
92 |                 # no backward pass during validation; weights are not updated here
93 |
94 | classes = np.argmax(scores, axis=1)
95 |                 predicted = eye[np.squeeze(classes), :]
96 |
97 | predicted_idx = np.argmax(predicted, 1)
98 | label_idx = np.argmax(targets, 1)
99 | correct += np.sum(predicted_idx == label_idx)
100 | total += targets.shape[0]
101 |
102 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \
103 | % (epoch, 100.*correct/total, correct, total))
104 |
--------------------------------------------------------------------------------
/pytorch/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import transforms, datasets
6 | import time, os
7 | from torch.autograd import Variable
8 |
9 | def squash(s, dim=-1):
10 | norm2 = torch.sum(s**2, dim=dim, keepdim=True)
11 | norm = torch.sqrt(norm2)
12 | return (norm2 / (1.0 + norm2)) * (s / norm)
13 |
14 | class PrimaryCaps(nn.Module):
15 | def __init__(self, use_cuda=False, out_channels=32, in_channels=256, ncaps=32*6*6, ndim=8, kernel_size=9, stride=2, padding=0):
16 | super(PrimaryCaps, self).__init__()
17 | self.ncaps = ncaps
18 | self.ndim = ndim
19 | self.caps = nn.ModuleList(
20 | [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) for _ in
21 | range(ndim)])
22 |
23 | def forward(self, x):
24 | u = torch.cat([cap(x).view(x.size(0), -1, 1) for cap in self.caps], dim=-1)
25 | # output (bs, ncaps, ndim)
26 | return squash(u)
27 |
28 |
29 | class DigitCaps(nn.Module):
30 | def __init__(self, use_cuda=False, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16):
31 | super(DigitCaps, self).__init__()
32 | self.use_cuda = use_cuda
33 | self.ndim_prev = ndim_prev
34 | self.ncaps_prev = ncaps_prev
35 | self.ncaps = ncaps
36 | self.route_iter = 3
37 | self.W = nn.Parameter(torch.randn(1, ncaps_prev, ncaps, ndim, ndim_prev))
38 |
39 | def forward(self, x):
40 | bs = x.size(0)
41 | x = torch.stack([x] * self.ncaps, dim=2).unsqueeze(-1)
42 | W = torch.cat([self.W] * bs, dim=0)
43 | u_hat = W @ x
44 |
45 | b = Variable(torch.zeros(1, self.ncaps_prev, self.ncaps, 1))
46 | if self.use_cuda:
47 | b = b.cuda()
48 |
49 | for i in range(self.route_iter):
50 |             c = F.softmax(b, dim=1)  # explicit dim (the implicit choice for 4-D input in PyTorch 0.3)
51 | c = torch.cat([c] * bs, dim=0).unsqueeze(-1)
52 |
53 | s = (c * u_hat).sum(dim=1, keepdim=True)
54 | v = squash(s)
55 |
56 | if i < self.route_iter - 1:
57 | b = b + torch.matmul(u_hat.transpose(-1, -2), torch.cat([v] * self.ncaps_prev, dim=1)) \
58 | .squeeze(-1).mean(dim=0, keepdim=True)
59 | return v.squeeze(1)
60 |
61 |
62 | class Decoder(nn.Module):
63 | def __init__(self):
64 | super(Decoder, self).__init__()
65 | self.net = nn.Sequential(
66 | nn.Linear(16*10,512),
67 | nn.ReLU(inplace=True),
68 | nn.Linear(512,1024),
69 | nn.ReLU(inplace=True),
70 | nn.Linear(1024,784),
71 | nn.Sigmoid()
72 | )
73 |
74 | def forward(self,x):
75 | x = x.view(x.size(0),-1)
76 | x = self.net(x)
77 | return x
78 |
79 | class CapsNet(nn.Module):
80 | def __init__(self, use_cuda=False, kernel_size=9, stride=1):
81 | super(CapsNet, self).__init__()
82 |
83 | self.conv1 = nn.Conv2d(1,256,kernel_size,stride=stride)
84 | self.primary_caps = PrimaryCaps(use_cuda=use_cuda)
85 | self.digit_caps = DigitCaps(use_cuda=use_cuda)
86 | self.decoder = Decoder()
87 |
88 | def forward(self, inpt):
89 | start = time.time()
90 | x = F.relu(self.conv1(inpt), inplace=True)
91 | x = self.primary_caps(x)
92 | x = self.digit_caps(x)
93 | reconst = self.decoder(x)
94 | return x, reconst
95 |
96 | class CapsLoss(nn.Module):
97 | def __init__(self):
98 | super(CapsLoss, self).__init__()
99 | self.mse_loss = nn.MSELoss()
100 | self.reconst_factor = 0.0005
101 | def forward(self, scores, labels, reconst, inpt):
102 |         norms = scores.squeeze()  # `scores` from main.py are already the capsule L2 norms
103 | margin_loss = labels * ( F.relu(0.9 - norms, inplace=True) )**2 + 0.5*(1-labels) * ( F.relu(norms - 0.1, inplace=True) )**2
104 | margin_loss = margin_loss.sum(dim=-1).mean()
105 | reconst_loss = self.mse_loss(reconst.view(reconst.size(0),-1), inpt.view(inpt.size(0),-1))
106 | return margin_loss + self.reconst_factor * reconst_loss
107 |
108 |
--------------------------------------------------------------------------------
/pytorch/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import transforms, datasets
6 | import time, os, argparse
7 | from torch.autograd import Variable
8 | from modules import *
9 |
10 |
11 | class MNIST:
12 | def __init__(self, bs=1):
13 | dataset_transform = transforms.Compose([
14 | transforms.ToTensor(),
15 | transforms.Normalize((0.1307,), (0.3081,))
16 | ])
17 |
18 | train_dataset = datasets.MNIST('data', train=True, download=True, transform=dataset_transform)
19 | eval_dataset = datasets.MNIST('data', train=False, download=True, transform=dataset_transform)
20 |
21 | self.num_classes = 10
22 | self.train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
23 | self.eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=bs, shuffle=True)
24 |
25 | def parse_args():
26 | """
27 | Parse input arguments
28 | """
29 |     parser = argparse.ArgumentParser(description='PyTorch CapsNet')
30 | parser.add_argument('--bs', dest='bs',
31 | help='batch size',
32 | default='100', type=int)
33 | parser.add_argument('--lr', dest='lr',
34 | help='learning rate',
35 | default=1e-2, type=float)
36 | parser.add_argument('--opt', dest='optimizer',
37 | help='optimizer',
38 | default='adam', type=str)
39 | parser.add_argument('--disp', dest='disp_interval',
40 | help='interval to display training loss',
41 | default=1, type=int)
42 | parser.add_argument('--num_epochs', dest='num_epochs',
43 | help='num epochs to train',
44 | default=100, type=int)
45 | parser.add_argument('--val_epoch', dest='val_epoch',
46 | help='num epochs to run validation',
47 | default=1, type=int)
48 | parser.add_argument('--save_epoch', dest='save_epoch',
49 | help='num epochs to save model',
50 | default=1, type=int)
51 | parser.add_argument('--use_cuda', dest='use_cuda',
52 | help='whether or not to use cuda',
53 | default=True, type=bool)
54 | parser.add_argument('--save_dir', dest='save_dir',
55 | help='directory to save trained models',
56 |                         default='saved_models', type=str)
57 |
58 | args = parser.parse_args()
59 |
60 | return args
61 |
62 | if __name__ == '__main__':
63 | args = parse_args()
64 |
65 | if not os.path.exists(args.save_dir):
66 | os.makedirs(args.save_dir)
67 |
68 | mnist = MNIST(bs=args.bs)
69 | # Variables
70 | inputs = torch.FloatTensor(1)
71 | labels = torch.FloatTensor(1)
72 | eye = Variable(torch.eye(mnist.num_classes))
73 | inputs = Variable(inputs)
74 | labels = Variable(labels)
75 |
76 | # Model
77 | model = CapsNet(use_cuda=args.use_cuda)
78 |
79 | # cuda
80 | if args.use_cuda:
81 | inputs = inputs.cuda()
82 | labels = labels.cuda()
83 | model = model.cuda()
84 | eye = eye.cuda()
85 |
86 | params = []
87 |
88 | for key, value in dict(model.named_parameters()).items():
89 | if value.requires_grad:
90 | params += [{'params':[value],'lr':args.lr}]
91 |
92 | # optimizer
93 | if args.optimizer == "adam":
94 | optimizer = torch.optim.Adam(model.parameters())
95 | elif args.optimizer == "sgd":
96 | optimizer = torch.optim.SGD(params)
97 |
98 | criterion = CapsLoss()
99 |
100 | print('Training started!')
101 |
102 | for epoch in range(args.num_epochs):
103 | start = time.time()
104 |
105 | # train
106 | model.train()
107 | correct = 0
108 | train_loss = 0
109 | for batch_idx, (imgs, targets) in enumerate(mnist.train_dataloader):
110 | if imgs.size(0) != args.bs:
111 | continue
112 |
113 | targets = eye.cpu().data.index_select(dim=0, index=targets)
114 | inputs.data.resize_(imgs.size()).copy_(imgs)
115 | labels.data.resize_(targets.size()).copy_(targets)
116 |
117 | optimizer.zero_grad()
118 | outputs, reconst = model(inputs)
119 |
120 | scores = torch.sqrt((outputs ** 2).sum(2))
121 | loss = criterion(scores, labels, reconst, inputs)
122 |             train_loss += loss.data.cpu().numpy()[0]  # running sum for the averaged display below
123 |
124 | # backward
125 | loss.backward()
126 | optimizer.step()
127 |
128 |             scores, classes = F.softmax(scores, dim=1).max(dim=1)
129 | predicted = eye.index_select(dim=0, index=classes.squeeze(1))
130 |
131 | predicted_idx = np.argmax(predicted.data.cpu().numpy(),1)
132 | label_idx = np.argmax(targets.numpy(), 1)
133 | correct = np.sum(predicted_idx == label_idx)
134 |
135 | # info
136 | if batch_idx % args.disp_interval == 0:
137 | end = time.time()
138 | print("[epoch %2d][iter %4d] loss: %.4f, acc: %.4f%% (%d/%d)" \
139 | % (epoch, batch_idx, train_loss/(batch_idx+1), 100.*correct/args.bs, correct, args.bs))
140 |
141 |             save_name = os.path.join(args.save_dir, 'capsnet_{}.pth'.format(epoch))  # 'capsnet' replaces the undefined project_id
142 |             if args.save_epoch > 0 and batch_idx % args.save_epoch == 0:
143 |                 torch.save({
144 |                     'epoch': epoch,
145 |                     'model': model.state_dict()}, save_name)
146 |
147 | # val
148 | if epoch % args.val_epoch == 0:
149 | print('Validating...')
150 | correct = 0
151 | total = 0
152 | model.eval()
153 | for batch_idx, (imgs, targets) in enumerate(mnist.eval_dataloader):
154 | if imgs.size(0) != args.bs:
155 | continue
156 | targets = eye.cpu().data.index_select(dim=0, index=targets)
157 | inputs.data.resize_(imgs.size()).copy_(imgs)
158 | labels.data.resize_(targets.size()).copy_(targets)
159 |
160 | outputs, reconst = model(inputs)
161 | scores = torch.sqrt((outputs ** 2).sum(2))
162 |             scores, classes = F.softmax(scores, dim=1).max(dim=1)
163 | predicted = eye.index_select(dim=0, index=classes.squeeze(1))
164 |
165 | predicted_idx = np.argmax(predicted.data.cpu().numpy(),1)
166 | label_idx = np.argmax(targets.numpy(), 1)
167 | correct += np.sum(predicted_idx == label_idx)
168 | total += targets.size(0)
169 | print("[epoch %2d] val acc: %.4f%% (%d/%d)" \
170 | % (epoch, 100.*correct/total, correct, total))
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pyCapsNet
2 |
3 | [![License][license]][license-url]
4 |
5 | PyTorch, NumPy and CuPy implementations of Capsule Networks (CapsNet), based on the paper [Sabour, Sara, Nicholas Frosst, and Geoffrey E. Hinton. "Dynamic routing between capsules." Advances in Neural Information Processing Systems. 2017.]
6 |
7 | ## Requirements
8 |
9 | * Python 3
10 |
11 | PyTorch Implementation:
12 | * PyTorch
13 | * Tested with PyTorch 0.3.0.post4
14 | * CUDA 8 (if using CUDA)
15 |
16 | CuPy Implementation:
17 | * CuPy 2.0.0
18 | * CUDA 8
19 |
20 | ## Motivation
21 | There are many great implementations of Capsule Networks [with PyTorch], [TensorFlow] and [Keras], so why do we need another one? This project provides three implementations of CapsNet: PyTorch, NumPy and CuPy. I implemented the PyTorch version as a performance check and for visualizations; for the NumPy and CuPy versions, I implemented both the forward pass and backpropagation purely from scratch, aiming to gain a deeper understanding of the structure and the gradient flow of CapsNet. The computation graph I used for this implementation is provided later in this document.
22 |
23 | The purpose of this project is not to chase better performance or to optimize speed, but to offer a better understanding of how CapsNet is implemented. Reading the paper thoroughly is a must, but it is easy to get confused when it comes to the actual implementation. In this document I provide my own understanding of CapsNet along with an implementation walkthrough.
24 |
25 | This [video] helped me a lot in understanding CapsNet. Take a minute and check it out.
26 |
27 | ## Challenges of Implementation
28 | * The 5-dimensional tensors in CapsNet can be pretty confusing. Stick with one ordering of dimensions and mind the difference between element-wise and matrix multiplication.
29 | * The CuPy and NumPy implementations are built from scratch, so it is challenging to make sure the gradient flows correctly. I drew a computational graph and performed unit tests on each basic module (e.g. Squash, Conv2d, Linear, Sequence, losses) and each composite module (e.g. PrimaryCaps, DigitCaps). Accumulating gradients across the routing iterations especially requires a clear understanding of the computation flow of DigitCaps. A minimal example of such a check is sketched below.
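As an illustration (not the project's actual test harness), a finite-difference check of a squash layer's backward pass can look like this self-contained sketch:

```python
import numpy as np

def squash_forward(s):
    n2 = np.sum(s ** 2, axis=-1, keepdims=True)
    return (np.sqrt(n2) / (1.0 + n2)) * s          # squash scale factor: ||s|| / (1 + ||s||^2)

def squash_backward(s, grad):
    n2 = np.sum(s ** 2, axis=-1, keepdims=True)
    norm = np.sqrt(n2)
    dfac = 1 / (2 * (1 + n2) * norm) - norm / (1 + n2) ** 2   # d[norm/(1+n2)] / d n2
    return dfac * np.sum(s * grad, axis=-1, keepdims=True) * 2 * s + grad * norm / (1 + n2)

np.random.seed(0)
s, g = np.random.randn(3, 8), np.random.randn(3, 8)
eps, idx = 1e-6, (1, 4)
sp, sm = s.copy(), s.copy()
sp[idx] += eps
sm[idx] -= eps
# numerical gradient of sum(squash(s) * g) w.r.t. one entry of s
num = (np.sum(squash_forward(sp) * g) - np.sum(squash_forward(sm) * g)) / (2 * eps)
print(squash_backward(s, g)[idx], num)  # analytic vs. numerical gradient, should agree
```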
30 |
31 | ## To Run
32 | For the NumPy and CuPy implementations, change into the corresponding directory and run
33 | ```
34 | python3 main.py --bs=100 --lr=1e-3 --opt='adam' --disp=10 --num_epochs=100 --val_epoch=1
35 | ```
36 | For the PyTorch implementation, run
37 | ```
38 | python3 main.py --bs=100 --lr=1e-3 --opt='adam' --disp=1 --num_epochs=100 --val_epoch=1 --use_cuda=True --save_dir='saved_models'
39 | ```
40 | To visualize the reconstructed data, run the Jupyter notebook `pytorch/Visualization.ipynb`.
41 |
42 | ## Capsule Networks: Key Points
43 | A capsule is a group of neurons whose output is an activity vector representing the instantiation parameters of a particular entity. The length of the activity vector encodes the probability that the entity exists, and its orientation encodes the instantiation parameters. The [paper] proposes a multi-layer capsule network for image classification tasks and achieves state-of-the-art performance on the MNIST dataset.
44 |
45 | ### Activity Vectors
46 | Unlike the neurons in most neural networks, the capsules in this architecture output an activity vector for each input. The paper introduces a nonlinear "squashing" function for the activity vectors:
47 |
48 | $$ v_j = \frac{\lVert s_j \rVert^{2}}{1 + \lVert s_j \rVert^{2}} \cdot \frac{s_j}{\lVert s_j \rVert} $$
49 |
50 | ### Dynamic Routing Between Capsules
51 | In this paper, the authors replace the conventional max-pooling layer with dynamic routing, iteratively refining the coupling coefficients that route the predictions made by one capsule layer to the next. Such routing is achieved by "routing-by-agreement": a capsule prefers to route its output to the capsules in the next layer whose outputs have a large dot product with its own prediction. As a result, for example, the "5" capsule tends to receive features that agree with "5". The authors investigated this iterative refinement of the routing coefficients and found that more routing iterations can indeed achieve a lower loss and better stability. However, this sequential, iterative structure can make CapsNet very slow. In this project, I followed the paper and set the number of routing iterations to 3. A minimal sketch of the routing loop is shown below.
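The loop below is a minimal NumPy sketch of routing-by-agreement, not the project's exact code: the shapes mirror the DigitCaps layer described later (32*6*6 input capsules, 10 output capsules, 16 output dimensions), and the softmax is taken over the output capsules as in the paper.

```python
import numpy as np

def squash(s, axis=-1):
    # squashing nonlinearity; the small epsilon only guards against division by zero
    norm2 = np.sum(s ** 2, axis=axis, keepdims=True)
    return (norm2 / (1.0 + norm2)) * (s / np.sqrt(norm2 + 1e-9))

def route(u_hat, iters=3):
    # u_hat: (batch, ncaps_prev, ncaps, ndim) predictions from the previous capsule layer
    b = np.zeros(u_hat.shape[:3] + (1,))                        # routing logits b_ij
    for it in range(iters):
        e = np.exp(b - b.max(axis=2, keepdims=True))
        c = e / e.sum(axis=2, keepdims=True)                    # coupling coefficients c_ij
        v = squash(np.sum(c * u_hat, axis=1, keepdims=True))    # candidate outputs v_j
        if it < iters - 1:
            b = b + np.sum(u_hat * v, axis=-1, keepdims=True)   # agreement <u_hat_ij, v_j>
    return v[:, 0]                                              # (batch, ncaps, ndim)

v = route(np.random.randn(2, 32 * 6 * 6, 10, 16))
print(v.shape)  # (2, 10, 16)
```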
52 |
53 | ### The Architecture
54 |
55 | 
56 | The Capsule Network in the figure above consists of three parts: a convolutional layer (Conv1) and two capsule layers (PrimaryCaps and DigitCaps). The DigitCaps layer yields a 16-dimensional vector for each of the 10 classes, and the L2 norm of these vectors becomes the class score. The decoder in the figure below consumes these vectors and tries to reconstruct the image. The final loss is the combination of the score loss ("margin loss" in the paper) and the reconstruction loss.
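Concretely, the margin loss from the paper (implemented in `CapsLoss`) uses m⁺ = 0.9, m⁻ = 0.1 and λ = 0.5, and the reconstruction loss is weighted by 0.0005:

$$ L_k = T_k \max(0,\, m^{+} - \lVert v_k \rVert)^{2} + \lambda (1 - T_k) \max(0,\, \lVert v_k \rVert - m^{-})^{2}, \qquad L = \sum_k L_k + 0.0005\, L_{\text{reconstruction}} $$

where `T_k = 1` if and only if class `k` is present.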
57 |
58 | 
59 |
60 | ### Detailed Walkthrough
61 |
62 | Here is a brief cheat sheet that I would have found extremely helpful before implementing; a quick shape check follows the list.
63 |
64 | * Conv1:
65 | * Input size `(N, C=1, H=28, W=28)`
66 | * `in_channel=1, out_channel=256, kernel_size=(9,9), stride=1, padding=0`
67 | * This convolution yields output size `(N, 256, 20, 20)`
68 |
69 | * PrimaryCaps:
70 | * Input size `(N, C=256, H=20, W=20)`
71 | * `ndim=8` Convolution kernels, each with `in_channel=256, out_channel=32, kernel_size=(9,9), stride=2, padding=0`
72 |   * Each convolution yields output of size `(N, 32, 6, 6)`; flatten each example, concatenate the outputs of the `ndim=8` convolutions, and feed the result into DigitCaps.
73 |
74 | * DigitCaps:
75 | * Input size `(N, ncaps_prev=32*6*6, ndim_prev=8)`
76 |   * For convenience, each tensor involved in this layer is reshaped to 5 dimensions, matching the dimensions of the weight, which has size `(1, ncaps_prev=32*6*6, ncaps=10, ndim=16, ndim_prev=8)`. Note that the input and the weight have the same sizes in `ncaps_prev` and `ncaps`, and `N` can be handled by broadcasting. Focusing on the last two dimensions, the weight is of size `(16, 8)` and the input is of size `(8, 1)`, so the output has size `(16, 1)` in these two dimensions.
77 |
78 |   * The output has size `(N, 1, 10, 16, 1)`. Taking the L2 norm over the `ndim=16` dimension gives the score for each class. This tensor is also fed into the decoder to get the reconstructions.
79 |
80 | * Decoder:
81 | * Input size `(N, 10, 16)`,
82 |   * Three linear layers, of sizes `(16*10, 512)`, `(512, 1024)` and `(1024, 28*28)`. The first two are followed by ReLU and the last is followed by a Sigmoid.
83 | * Outputs `(N, 28*28)`, i.e. the reconstruction.
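As a quick sanity check of the sizes above, this hedged sketch pushes a random batch through the PyTorch model from `pytorch/modules.py` (it assumes it is run from inside the `pytorch/` directory, using the PyTorch 0.3-style API this repo targets):

```python
import torch
from torch.autograd import Variable
from modules import CapsNet  # pytorch/modules.py

model = CapsNet(use_cuda=False)
x = Variable(torch.randn(2, 1, 28, 28))  # a random MNIST-sized batch, N=2
caps_out, reconst = model(x)
print(caps_out.size())   # DigitCaps activity vectors: (2, 10, 16, 1)
print(reconst.size())    # decoder reconstruction:     (2, 784)
```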
84 |
85 | ## Computation Graph for DigitCaps
86 |
87 | 
88 |
89 | I originally drew this computation graph to better understand the gradient flow. I redrew the DigitCaps part in SVG to provide a clear illustration for anyone interested in implementing this part themselves. If you implement backpropagation through DigitCaps, mind the gradients accumulated across the routing iterations.
90 |
91 | ## Results
92 | I achieved 99.41% validation accuracy at epoch 22 with the PyTorch implementation, which is close to the number reported in the paper. The CuPy implementation quickly converges to 90%+, but overall trains more slowly than the PyTorch version. The NumPy implementation runs purely on the CPU; although I used multiprocessing in the network, it is much slower than the GPU implementations. The reconstructed images are shown below.
93 |
94 | 
95 |
96 | I also experimented with perturbing the 16-dimensional DigitCaps output vectors, feeding them into the pre-trained decoder, and visualizing the meaning of each dimension. The image below shows that perturbing a single dimension can change the orientation, width, stroke thickness and local features of the reconstructed image.
97 |
98 | 
99 |
100 |
101 | ## To-dos
102 |
103 | - [x] Add visualization ipynb for PyTorch implementation
104 | - [ ] Add visualization ipynb for CuPy and NumPy implementations
105 | - [ ] Finish deformable convolution implementation in CuPy
106 | - [ ] Start a project on CuPy automatic differentiation, which could possibly benefit this project
107 |
108 |
109 | [license]: https://img.shields.io/github/license/mashape/apistatus.svg
110 | [license-url]: https://github.com/xanderchf/pyCapsNet/blob/master/LICENSE
111 | [Sabour, Sara, Nicholas Frosst, and Geoffrey E. Hinton. "Dynamic routing between capsules." Advances in Neural Information Processing Systems. 2017.]: https://arxiv.org/abs/1710.09829
112 | [paper]: https://arxiv.org/abs/1710.09829
113 | [with PyTorch]: https://github.com/gram-ai/capsule-networks
114 | [TensorFlow]: https://github.com/ageron/handson-ml
115 | [Keras]: https://github.com/XifengGuo/CapsNet-Keras
116 | [video]: https://www.youtube.com/watch?v=2Kawrd5szHE
117 |
--------------------------------------------------------------------------------
/cupy/modules.py:
--------------------------------------------------------------------------------
1 |
2 | # im2col functions adapted from https://github.com/Burton2000/CS231n-2017/blob/master/assignment2/cs231n/im2col.py
3 |
4 | import cupy as cp
5 | import time, os
6 |
7 |
8 | def tile(arr, copy, axis):
9 | return cp.concatenate([arr] * copy, axis=axis)
10 |
11 |
12 | class Module(object):
13 | def __init__(self, trainable=False):
14 | self.trainable = trainable
15 | pass
16 |
17 | def forward(self, x):
18 | raise NotImplementedError
19 |
20 | def backward(self, grad, optimizer=None):
21 | raise NotImplementedError
22 |
23 | def __call__(self, *input, **kwargs):
24 | return self.forward(*input, **kwargs)
25 |
26 |
27 | class Sequence(Module):
28 | def __init__(self, modules):
29 | self._modules = modules
30 |
31 | def forward(self, inpt):
32 | t = time.time()
33 | for module in self._modules:
34 | inpt = module(inpt)
35 | cur = time.time()
36 | t = cur
37 | if module.trainable:
38 | self.trainable = True
39 | return inpt
40 |
41 | def backward(self, grad, optimizer=None):
42 | for module in self._modules[::-1]:
43 | if module.trainable:
44 | grad = module.backward(grad, optimizer)
45 | else:
46 | grad = module.backward(grad)
47 |
48 | return grad
49 |
50 | def modules(self):
51 | return self._modules
52 |
53 | def trainable_modules(self):
54 | return [i for i in self._modules if i.trainable]
55 |
56 |
57 | class Linear(Module):
58 | def __init__(self, in_channel, out_channel):
59 | super(Linear, self).__init__(trainable=True)
60 | std = 1/cp.sqrt(in_channel)
61 | self.w = cp.random.uniform(-std, std, (out_channel, in_channel))
62 | self.b = cp.random.uniform(-std, std, (1, out_channel))
63 | self.x = None
64 |
65 | def _set_params(self, params):
66 | w, b = params
67 | self.w = w
68 | self.b = b
69 | if len(self.b.shape) < 2:
70 | self.b = self.b[None,:]
71 |
72 | def forward(self, x):
73 | out = x.dot(self.w.T) + self.b
74 | self.x = x
75 | return out
76 |
77 | def backward(self, grad, optimizer=None):
78 | dw = (self.x.T @ grad).T
79 | db = cp.sum(grad, axis=0, keepdims=True)
80 | # update parameters
81 | if optimizer is not None:
82 | self.w = optimizer(self.w, dw)
83 | self.b = optimizer(self.b, db)
84 |
85 | dx = grad @ self.w
86 | dx = cp.reshape(dx, self.x.shape)
87 | return dx
88 |
89 |
90 | class ReLU(Module):
91 | def __init__(self, alpha=0):
92 | super(ReLU, self).__init__()
93 | self.alpha = alpha
94 | self.x = None
95 |
96 | def forward(self, x):
97 | out = x.copy()
98 | if self.alpha > 0:
99 |             out[x<0] = self.alpha*x[x<0]
100 | else:
101 | out[out<0] = 0
102 | self.x = x
103 | return out
104 |
105 | def backward(self, grad):
106 | dx = grad.copy()
107 |         dx[self.x < 0] *= self.alpha  # 0 for plain ReLU, alpha for leaky ReLU
108 | return dx
109 |
110 | class Sigmoid(Module):
111 | def __init__(self):
112 | super(Sigmoid, self).__init__()
113 | self.s = None
114 |
115 | def forward(self, x):
116 | self.s = 1/(1 + cp.exp(-x))
117 | return self.s
118 |
119 | def backward(self, grad):
120 | return grad * (self.s * (1-self.s))
121 |
122 |
123 | class Conv2d(Module):
124 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad=0, eps=1e-4):
125 | super(Conv2d, self).__init__(trainable=True)
126 | self.ic = in_channels
127 | self.oc = out_channels
128 | self.k = kernel_size
129 | self.s = stride
130 | self.p = pad
131 |
132 | std = 1/(cp.sqrt(self.ic* self.k**2))
133 | self.W = cp.random.uniform(-std, std, (self.oc,self.ic,self.k,self.k))
134 | self.b = cp.random.uniform(-std, std, (self.oc, 1))
135 |
136 | self.X_col = None
137 | self.x_shape = None
138 |
139 | def _set_params(self, params):
140 | W, b = params
141 | self.W = W
142 | self.b = b
143 |
144 | def forward(self, X):
145 | NF, CF, HF, WF = self.W.shape
146 | NX, DX, HX, WX = X.shape
147 | self.x_shape = X.shape
148 | h_out = int((HX - HF + 2 * self.p) / self.s + 1)
149 | w_out = int((WX - WF + 2 * self.p) / self.s + 1)
150 |
151 | X_col = self.im2col_indices(X)
152 | self.X_col = X_col
153 | W_col = self.W.reshape(NF, -1)
154 |
155 | out = W_col @ self.X_col + self.b
156 | out = out.reshape(NF, h_out, w_out, NX)
157 | out = out.transpose(3, 0, 1, 2)
158 |
159 | return out
160 |
161 |
162 | def backward(self, dout, optimizer=None):
163 | NF, CF, HF, WF = self.W.shape
164 |
165 | db = cp.sum(dout, axis=(0, 2, 3))
166 | db = db.reshape(NF, -1)
167 |
168 | dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(NF, -1)
169 | dW = dout_reshaped @ self.X_col.T
170 | dW = dW.reshape(self.W.shape)
171 |
172 | if optimizer is not None:
173 | self.b = optimizer(self.b, db)
174 | self.W = optimizer(self.W, dW)
175 |
176 | W_reshape = self.W.reshape(NF, -1)
177 | dX_col = W_reshape.T @ dout_reshaped
178 | dX = self.col2im_indices(dX_col.astype(cp.float32))
179 |
180 | return dX
181 |
182 | def get_im2col_indices(self):
183 | padding, stride, field_height, field_width, x_shape = self.p, self.s, self.k, self.k, self.x_shape
184 | N, C, H, W = x_shape
185 | # assert (H + 2 * padding - field_height) % stride == 0
186 | # assert (W + 2 * padding - field_height) % stride == 0
187 | out_height = int((H + 2 * padding - field_height) / stride + 1)
188 | out_width = int((W + 2 * padding - field_width) / stride + 1)
189 |
190 | i0 = cp.repeat(cp.arange(field_height), field_width)
191 | i0 = cp.tile(i0, C)
192 | i1 = stride * cp.repeat(cp.arange(out_height), out_width)
193 | j0 = cp.tile(cp.arange(field_width), field_height * C)
194 | j1 = stride * cp.tile(cp.arange(out_width), out_height)
195 | i = i0.reshape(-1, 1) + i1.reshape(1, -1)
196 | j = j0.reshape(-1, 1) + j1.reshape(1, -1)
197 |
198 | k = cp.repeat(cp.arange(C), field_height * field_width).reshape(-1, 1)
199 |
200 | return (k.astype(cp.int32), i.astype(cp.int32), j.astype(cp.int32))
201 |
202 |
203 | def im2col_indices(self, x):
204 | p, stride, field_height, field_width = self.p, self.s, self.k, self.k
205 | x_padded = cp.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
206 |
207 | k, i, j = self.get_im2col_indices()
208 |
209 | cols = x_padded[:, k, i, j]
210 | C = x.shape[1]
211 | cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
212 | return cols
213 |
214 |
215 | def col2im_indices(self, cols):
216 | field_height, field_width, padding, stride = self.k, self.k, self.p, self.s
217 | N, C, H, W = self.x_shape
218 | H_padded, W_padded = H + 2 * padding, W + 2 * padding
219 | x_padded = cp.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
220 | k, i, j = self.get_im2col_indices()
221 | cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
222 | cols_reshaped = cols_reshaped.transpose(2, 0, 1).astype(cp.float32)
223 | cp.scatter_add(x_padded, (slice(None), k, i, j), cols_reshaped)
224 | if padding == 0:
225 | return x_padded
226 | return x_padded[:, :, padding:-padding, padding:-padding]
227 |
228 |
229 | class Softmax(Module):
230 | def __init__(self, dim=-1):
231 | super(Softmax, self).__init__()
232 | self.s = None
233 | self.dim = dim
234 | self.squeeze_len = None
235 |
236 | def forward(self, x, dim=None):
237 | if dim is not None:
238 | self.dim = dim
239 | if self.dim < 0:
240 | self.dim = len(x.shape)+self.dim
241 | self.squeeze_len = x.shape[self.dim]
242 |         y = cp.exp(x - cp.max(x, axis=self.dim, keepdims=True))  # shift for numerical stability
243 | s = y/cp.sum(y, axis=self.dim, keepdims=True)
244 | self.s = s
245 | return s
246 |
247 | def backward(self, grad):
248 | self.s = cp.expand_dims(self.s.swapaxes(self.dim,-1), -1)
249 | grad = cp.expand_dims(grad.swapaxes(self.dim,-1), -1)
250 | mat = self.s @ self.s.swapaxes(-1,-2)
251 | mat = (-mat + cp.eye(mat.shape[-1]) * (mat**0.5))
252 | grad = mat @ grad
253 | self.s = self.s.swapaxes(self.dim,-1).squeeze(-1)
254 | return grad.swapaxes(self.dim,-2).squeeze(-1)
255 |
256 |
257 | class Squash(Module):
258 | def __init__(self, dim=-1):
259 | super(Squash, self).__init__()
260 | self.dim = dim
261 | self.squeeze_len = None
262 | self.s = None
263 |
264 | def forward(self, s):
265 | self.s = s
266 | self.squeeze_len = s.shape[self.dim]
267 | norm2 = cp.sum((s)**2, axis=self.dim, keepdims=True)
268 | return (cp.sqrt(norm2) / (1.0 + norm2)) * s
269 |
270 |     def backward(self, grad):
271 |         norm2 = cp.sum((self.s)**2, axis=self.dim, keepdims=True)
272 |         norm = cp.sqrt(norm2)
273 |         # derivative of the squash scale factor norm/(1+norm2) w.r.t. norm2
274 |         dfactor = 1/(2*(1.+norm2)*norm) - norm/(1.+norm2)**2
275 |         factor = norm/(1+norm2)
276 |         return dfactor * cp.sum(self.s * grad, axis=self.dim, keepdims=True) * (2.*self.s) + grad * factor
277 |
278 | class MSELoss(Module):
279 | def __init__(self):
280 | super(MSELoss, self).__init__()
281 | self.x = None
282 | self.y = None
283 |
284 | def forward(self, x, y):
285 | self.x = x
286 | self.y = y
287 | return cp.sum((x - y)**2)/float(x.size), 2*(x - y)/float(x.size)
288 |
289 |
290 | class PrimaryCaps(Module):
291 | def __init__(self, use_cuda=False, out_channels=32, in_channels=256, mapsize=6, ndim=8, kernel_size=9, stride=2, padding=0):
292 | super(PrimaryCaps, self).__init__(trainable=True)
293 | self.ndim = ndim
294 | self.caps = [Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad=padding) for _ in
295 | range(ndim)]
296 |
297 | self.out_channels = out_channels
298 | self.mapsize = mapsize
299 | self.ncaps = out_channels * mapsize**2
300 | self.squash = Squash()
301 | self.x_size = None
302 |
303 | def _set_params(self, params):
304 | for i, c in enumerate(self.caps):
305 | c._set_params(params[i])
306 |
307 | def forward(self, x):
308 | t = time.time()
309 | # output (bs, ncaps, ndim)
310 | self.x_size = x.shape
311 | u = cp.concatenate([cap(x).reshape((x.shape[0], -1, 1)) for cap in self.caps], axis=-1)
312 | return self.squash(u)
313 |
314 | def backward(self, grads, optimizer=None):
315 | t = time.time()
316 | grads = self.squash.backward(grads)
317 | grads = grads.reshape((self.x_size[0],self.out_channels, self.mapsize, self.mapsize,-1))
318 | grads = cp.concatenate([cp.expand_dims(self.caps[i].backward(
319 | grads[:,:,:,:,i], optimizer=optimizer), -1) for i in range(self.ndim)], axis=-1)
320 | out = cp.sum(grads, axis=-1)
321 | return out
322 |
323 |
324 | class Decoder(Module):
325 | def __init__(self):
326 | super(Decoder, self).__init__(trainable=True)
327 | self.net = Sequence([
328 | Linear(16*10,512),
329 | ReLU(),
330 | Linear(512,1024),
331 | ReLU(),
332 | Linear(1024,784),
333 | Sigmoid()
334 | ])
335 | self.x_shape = None
336 |
337 | def forward(self, x):
338 | self.x_shape = x.shape
339 | x = x.reshape(x.shape[0],-1)
340 |
341 | return self.net(x)
342 |
343 | def _set_params(self, params):
344 | for i, l in enumerate(self.net.trainable_modules()):
345 | l._set_params(params[i])
346 |
347 | def backward(self, grad, optimizer):
348 | return self.net.backward(grad, optimizer).reshape(self.x_shape)
349 |
350 |
351 | class DigitCaps(Module):
352 | def __init__(self, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16):
353 | super(DigitCaps, self).__init__(trainable=True)
354 | self.ndim_prev = ndim_prev
355 | self.ncaps_prev = ncaps_prev
356 | self.ncaps = ncaps
357 | self.route_iter = 2
358 | self.W = cp.random.randn(1, ncaps_prev, ncaps, ndim, ndim_prev)
359 | self.softmaxs = [Softmax() for _ in range(self.route_iter)]
360 | self.squashs = [Squash() for _ in range(self.route_iter)]
361 | self.u_hat = None
362 | self.bs = None
363 | self.b = [None] * self.route_iter
364 | self.v = [None] * self.route_iter
365 | self.x = None
366 |
367 | def _set_params(self, params):
368 | self.W = params
369 |
370 | def forward(self, x):
371 | t = time.time()
372 | self.bs = x.shape[0]
373 | self.x = x
374 | x = tile(x[:,:,None,:,None], self.ncaps, 2)
375 | W = tile(self.W, self.bs, 0)
376 | u_hat = W @ x
377 | self.u_hat = u_hat
378 | b = cp.zeros((1, self.ncaps_prev, self.ncaps, 1, 1))
379 |
380 | for r in range(self.route_iter):
381 | self.b[r] = b
382 | c = self.softmaxs[r](b, dim=1)
383 |
384 | c = tile(c, self.bs, 0)
385 | s = cp.sum(c * u_hat, axis=1, keepdims=True)
386 | v = self.squashs[r](s)
387 | if r == self.route_iter - 1:
388 | return cp.squeeze(v, axis=1)
389 |
390 | self.v[r] = v
391 | p = u_hat.swapaxes(-1, -2) @ tile(v, self.ncaps_prev, 1)
392 | b = b + cp.mean(p, axis=0, keepdims=True)
393 |
394 |
395 | def backward(self, grad, optimizer=None):
396 | t = time.time()
397 | grad_accum = cp.zeros_like(self.u_hat)
398 | b_grad_accum = None
399 | grad = grad[:,None,:,:,:]
400 | for r in range(self.route_iter)[::-1]:
401 | if r < self.route_iter-1:
402 | grad = b_grad_accum
403 | grad = tile(grad, self.bs, 0)/self.bs
404 | p_grad = tile(self.v[r], self.ncaps_prev, 1) * grad
405 |
406 | grad_accum += p_grad
407 |
408 | grad = self.u_hat * grad
409 | grad = cp.sum(grad, axis=1, keepdims=True)
410 |
411 | grad = self.squashs[r].backward(grad)
412 | grad = tile(grad, self.ncaps_prev, 1)
413 | c = self.softmaxs[r].s
414 | grad_accum += tile(c, self.bs, 0) * grad
415 | grad = self.u_hat.swapaxes(-1,-2) @ grad
416 |
417 | if r > 0:
418 | grad = cp.sum(grad, axis=0, keepdims=True)
419 | grad = self.softmaxs[r].backward(grad)
420 | if b_grad_accum is None:
421 | b_grad_accum = grad
422 | else:
423 | b_grad_accum += grad
424 |
425 | x = tile(self.x[:,:,None,:,None], self.ncaps, 2)
426 | dW = cp.sum(grad_accum @ x.swapaxes(-1,-2), axis=0, keepdims=True)
427 | if optimizer is not None:
428 | self.W = optimizer(self.W, dW)
429 |
430 | grad_accum = cp.squeeze(self.W.swapaxes(-1,-2) @ grad_accum, axis=-1)
431 | dx = cp.sum(grad_accum, axis=2)
432 | return dx
433 |
434 |
435 | class CapsNet(Module):
436 | def __init__(self, use_cuda=False, kernel_size=9, stride=1):
437 | super(CapsNet, self).__init__(trainable=True)
438 | self.net = Sequence([
439 | Conv2d(1,256,kernel_size=kernel_size,stride=stride),
440 | ReLU(),
441 | PrimaryCaps(),
442 | DigitCaps()
443 | ])
444 | self.decoder = Decoder()
445 | self.x = None
446 | self.digit_ndim = 16
447 | self.softmax = Softmax()
448 |
449 | def _set_params(self, params):
450 | for i, m in enumerate(self.net.trainable_modules() + [self.decoder]):
451 |             m._set_params(params[i])
452 |
453 | def forward(self, x):
454 | x = self.net(x)
455 | self.x = x
456 | reconst = self.decoder(x)
457 | scores = cp.sqrt((x ** 2).sum(2)).squeeze()
458 | return scores, reconst
459 |
460 | def backward(self, grad, optimizer):
461 | scores_grad, reconst_grad = grad
462 |
463 | scores_grad = scores_grad[:,:,None, None]
464 |         # chain rule through the L2 norm: d||x|| / dx = x / ||x||
465 |         scores_grad *= 0.5 * ((self.x ** 2).sum(2, keepdims=True) ** (-0.5))
466 | scores_grad = tile(scores_grad, self.digit_ndim, 2) # tile at dimension 2
467 | scores_grad *= 2*self.x
468 | t = time.time()
469 |
470 | reconst_grad = self.decoder.backward(reconst_grad, optimizer)
471 | grad = scores_grad + reconst_grad
472 |
473 | grad = self.net.backward(grad, optimizer=optimizer)
474 | return grad
475 |
476 |
477 | class CapsLoss(Module):
478 | def __init__(self):
479 | super(CapsLoss, self).__init__()
480 | self.mse_loss = MSELoss()
481 | self.relu1 = ReLU()
482 | self.relu2 = ReLU()
483 | self.reconst_factor = 0.0005
484 |
485 |
486 | def forward(self, norms, labels, reconst, inpt):
487 | self.labels = labels
488 |
489 | int1 = self.relu1(0.9 - norms)
490 | int2 = self.relu2(norms - 0.1)
491 | margin_loss = labels * int1**2 + 0.5*(1-labels) * int2**2
492 | bs, ndim_prev = margin_loss.shape[0], margin_loss.shape[-1]
493 | margin_loss = cp.sum(margin_loss, axis=-1).mean()
494 |
495 | reconst_loss, reconst_grad = self.mse_loss(reconst.reshape(reconst.shape[0],-1), inpt.reshape(inpt.shape[0],-1))
496 | loss = margin_loss + self.reconst_factor * reconst_loss
497 |
498 | margin_grad = cp.ones((bs, ndim_prev)) / float(bs)
499 | margin_grad_pos = -self.relu1.backward(margin_grad * labels * (2*int1))
500 | margin_grad_neg = self.relu2.backward(margin_grad * 0.5*(1-labels) * (2*int2))
501 |
502 | margin_grad = margin_grad_pos + margin_grad_neg
503 | reconst_grad *= self.reconst_factor
504 |
505 | return loss, (margin_grad, reconst_grad)
--------------------------------------------------------------------------------
/numpy/modules.py:
--------------------------------------------------------------------------------
1 |
2 | # im2col functions adapted from https://github.com/Burton2000/CS231n-2017/blob/master/assignment2/cs231n/im2col.py
3 |
4 | import numpy as np
5 | import time, os
6 | import multiprocessing as mp
7 | from functools import partial
8 |
9 | def tile(arr, copy, axis):
10 | return np.concatenate([arr] * copy, axis=axis)
11 |
12 |
13 | class Module(object):
14 | def __init__(self, trainable=False):
15 | self.trainable = trainable
16 | pass
17 |
18 | def forward(self, x):
19 | raise NotImplementedError
20 |
21 | def backward(self, grad, optimizer=None):
22 | raise NotImplementedError
23 |
24 | def __call__(self, *input, **kwargs):
25 | return self.forward(*input, **kwargs)
26 |
27 |
28 | class Sequence(Module):
29 | def __init__(self, modules):
30 | self._modules = modules
31 |
32 | def forward(self, inpt):
33 | t = time.time()
34 | for module in self._modules:
35 | inpt = module(inpt)
36 | cur = time.time()
37 | t = cur
38 | if module.trainable:
39 | self.trainable = True
40 | return inpt
41 |
42 | def backward(self, grad, optimizer=None):
43 | for module in self._modules[::-1]:
44 | if module.trainable:
45 | grad = module.backward(grad, optimizer)
46 | else:
47 | grad = module.backward(grad)
48 |
49 | return grad
50 |
51 | def modules(self):
52 | return self._modules
53 |
54 | def trainable_modules(self):
55 | return [i for i in self._modules if i.trainable]
56 |
57 |
58 | class Linear(Module):
59 | def __init__(self, in_channel, out_channel):
60 | super(Linear, self).__init__(trainable=True)
61 | std = 1/np.sqrt(in_channel)
62 | self.w = np.random.uniform(-std, std, (out_channel, in_channel))
63 | self.b = np.random.uniform(-std, std, (1, out_channel))
64 | self.x = None
65 |
66 | def _set_params(self, params):
67 | w, b = params
68 | self.w = w
69 | self.b = b
70 | if len(self.b.shape) < 2:
71 | self.b = self.b[None,:]
72 |
73 | def forward(self, x):
74 | out = x.dot(self.w.T) + self.b
75 | self.x = x
76 | return out
77 |
78 | def backward(self, grad, optimizer=None):
79 | dw = (self.x.T @ grad).T
80 | db = np.sum(grad, axis=0, keepdims=True)
81 | # update parameters
82 | if optimizer is not None:
83 | self.w = optimizer(self.w, dw)
84 | self.b = optimizer(self.b, db)
85 |
86 | dx = grad @ self.w
87 | dx = np.reshape(dx, self.x.shape)
88 | return dx
89 |
90 |
91 | class ReLU(Module):
92 | def __init__(self, alpha=0):
93 | super(ReLU, self).__init__()
94 | self.alpha = alpha
95 | self.x = None
96 |
97 | def forward(self, x):
98 | out = x.copy()
99 | if self.alpha > 0:
100 | out[out<0] = self.alpha*x
101 | else:
102 | out[out<0] = 0
103 | self.x = x
104 | return out
105 |
106 | def backward(self, grad):
107 | dx = grad.copy()
108 | dx[self.x < 0] = 0
109 | return dx
110 |
111 | class Sigmoid(Module):
112 | def __init__(self):
113 | super(Sigmoid, self).__init__()
114 | self.s = None
115 |
116 | def forward(self, x):
117 | self.s = 1/(1 + np.exp(-x))
118 | return self.s
119 |
120 | def backward(self, grad):
121 | return grad * (self.s * (1-self.s))
122 |
123 |
124 | class Conv2d(Module):
125 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad=0, eps=1e-4):
126 | super(Conv2d, self).__init__(trainable=True)
127 | self.ic = in_channels
128 | self.oc = out_channels
129 | self.k = kernel_size
130 | self.s = stride
131 | self.p = pad
132 |
133 | std = 1/(np.sqrt(self.ic* self.k**2))
134 | self.W = np.random.uniform(-std, std, (self.oc,self.ic,self.k,self.k))
135 | self.b = np.random.uniform(-std, std, (self.oc, 1))
136 |
137 | self.X_col = None
138 | self.x_shape = None
139 |
140 | def _set_params(self, params):
141 | W, b = params
142 | self.W = W
143 | self.b = b
144 |
145 | def _set_input(self, x):
146 | self.x_shape = x.shape
147 | self.X_col = self.im2col_indices(x)
148 |
149 | def forward(self, X):
150 | NF, CF, HF, WF = self.W.shape
151 | NX, DX, HX, WX = X.shape
152 | self.x_shape = X.shape
153 | h_out = int((HX - HF + 2 * self.p) / self.s + 1)
154 | w_out = int((WX - WF + 2 * self.p) / self.s + 1)
155 |
156 | X_col = self.im2col_indices(X)
157 | self.X_col = X_col
158 | W_col = self.W.reshape(NF, -1)
159 |
160 | out = W_col @ self.X_col + self.b
161 | out = out.reshape(NF, h_out, w_out, NX)
162 | out = out.transpose(3, 0, 1, 2)
163 |
164 | return out
165 |
166 |
167 | def backward(self, dout, optimizer=None):
168 | NF, CF, HF, WF = self.W.shape
169 |
170 | db = np.sum(dout, axis=(0, 2, 3))
171 | db = db.reshape(NF, -1)
172 |
173 | dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(NF, -1)
174 | dW = dout_reshaped @ self.X_col.T
175 | dW = dW.reshape(self.W.shape)
176 |
177 | if optimizer is not None:
178 | self.b = optimizer(self.b, db)
179 | self.W = optimizer(self.W, dW)
180 |
181 | W_reshape = self.W.reshape(NF, -1)
182 | dX_col = W_reshape.T @ dout_reshaped
183 | dX = self.col2im_indices(dX_col)
184 |
185 | return dX
186 |
187 | def get_im2col_indices(self):
188 | padding, stride, field_height, field_width, x_shape = self.p, self.s, self.k, self.k, self.x_shape
189 | N, C, H, W = x_shape
190 | # assert (H + 2 * padding - field_height) % stride == 0
191 | # assert (W + 2 * padding - field_height) % stride == 0
192 | out_height = int((H + 2 * padding - field_height) / stride + 1)
193 | out_width = int((W + 2 * padding - field_width) / stride + 1)
194 |
195 | i0 = np.repeat(np.arange(field_height), field_width)
196 | i0 = np.tile(i0, C)
197 | i1 = stride * np.repeat(np.arange(out_height), out_width)
198 | j0 = np.tile(np.arange(field_width), field_height * C)
199 | j1 = stride * np.tile(np.arange(out_width), out_height)
200 | i = i0.reshape(-1, 1) + i1.reshape(1, -1)
201 | j = j0.reshape(-1, 1) + j1.reshape(1, -1)
202 |
203 | k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
204 |
205 |         return (k.astype(int), i.astype(int), j.astype(int))  # builtin int: np.int is removed in recent NumPy releases
206 |
207 |
208 | def im2col_indices(self, x):
209 | p, stride, field_height, field_width = self.p, self.s, self.k, self.k
210 | x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
211 |
212 | k, i, j = self.get_im2col_indices()
213 |
214 | cols = x_padded[:, k, i, j]
215 | C = x.shape[1]
216 | cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
217 | return cols
218 |
219 |
220 | def col2im_indices(self, cols):
221 | field_height, field_width, padding, stride = self.k, self.k, self.p, self.s
222 | N, C, H, W = self.x_shape
223 | H_padded, W_padded = H + 2 * padding, W + 2 * padding
224 | x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
225 | k, i, j = self.get_im2col_indices()
226 | cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
227 | cols_reshaped = cols_reshaped.transpose(2, 0, 1)
228 | np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
229 | if padding == 0:
230 | return x_padded
231 | return x_padded[:, :, padding:-padding, padding:-padding]
232 |
233 |
234 | class Softmax(Module):
235 | def __init__(self, dim=-1):
236 | super(Softmax, self).__init__()
237 | self.s = None
238 | self.dim = dim
239 | self.squeeze_len = None
240 |
241 | def forward(self, x, dim=None):
242 | if dim is not None:
243 | self.dim = dim
244 | if self.dim < 0:
245 | self.dim = len(x.shape)+self.dim
246 | self.squeeze_len = x.shape[self.dim]
247 |         y = np.exp(x - np.max(x, axis=self.dim, keepdims=True))  # subtract the max for numerical stability; softmax is shift-invariant
248 | s = y/np.sum(y, axis=self.dim, keepdims=True)
249 | self.s = s
250 | return s
251 |
252 | def backward(self, grad):
253 | self.s = np.expand_dims(self.s.swapaxes(self.dim,-1), -1)
254 | grad = np.expand_dims(grad.swapaxes(self.dim,-1), -1)
255 | mat = self.s @ self.s.swapaxes(-1,-2)
256 |         mat = (-mat + np.eye(mat.shape[-1]) * (mat**0.5))  # softmax Jacobian diag(s) - s s^T (the diagonal of s s^T is s_i^2)
257 |         grad = mat @ grad
258 |         self.s = self.s.squeeze(-1).swapaxes(self.dim,-1)  # undo the expand/swap so s is restored to its forward shape
259 | return grad.swapaxes(self.dim,-2).squeeze(-1)
260 |
261 |
262 | class Squash(Module):
263 | def __init__(self, dim=-1):
264 | super(Squash, self).__init__()
265 | self.dim = dim
266 | self.squeeze_len = None
267 | self.s = None
268 |
269 | def forward(self, s):
270 | self.s = s
271 | self.squeeze_len = s.shape[self.dim]
272 | norm2 = np.sum((s)**2, axis=self.dim, keepdims=True)
273 |         return (np.sqrt(norm2) / (1.0 + norm2)) * s  # equals (norm2/(1+norm2)) * s/||s||, the squash nonlinearity
274 |
275 | def backward(self, grad):
276 |         norm2 = np.sum((self.s)**2, axis=self.dim, keepdims=True)
277 |         norm = np.sqrt(norm2)
278 |         # chain rule for v = f(norm2)*s with f = norm/(1+norm2): dL/ds = f*grad + 2*s*f'(norm2)*<grad, s>
279 |         dfactor = 1/(2*(1.+norm2)*norm) - norm/(1.+norm2)**2
280 |         dnorm2 = np.sum(grad * self.s, axis=self.dim, keepdims=True) * dfactor
281 |         return dnorm2 * (2.*self.s) + grad * (norm/(1+norm2))
282 |
283 | class MSELoss(Module):
284 | def __init__(self):
285 | super(MSELoss, self).__init__()
286 | self.x = None
287 | self.y = None
288 |
289 | def forward(self, x, y):
290 | self.x = x
291 | self.y = y
292 | return np.sum((x - y)**2)/float(x.size), 2*(x - y)/float(x.size)
293 |
294 |
295 | class PrimaryCaps(Module):
296 | def __init__(self, use_cuda=False, out_channels=32, in_channels=256, mapsize=6, ndim=8, kernel_size=9, stride=2, padding=0):
297 | super(PrimaryCaps, self).__init__(trainable=True)
298 | self.ndim = ndim
299 | self.caps = [Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad=padding) for _ in
300 | range(ndim)]
301 |
302 | self.out_channels = out_channels
303 | self.mapsize = mapsize
304 | self.ncaps = out_channels * mapsize**2
305 | self.squash = Squash()
306 | self.x = None
307 |
308 | def _set_params(self, params):
309 | for i, c in enumerate(self.caps):
310 | c._set_params(params[i])
311 |
312 | def cap_forward(self, i, x):
313 | out = self.caps[i](x).reshape((x.shape[0], -1, 1))
314 | return out
315 |
316 | def cap_backward(self, i, grads, x, optimizer):
317 | self.caps[i]._set_input(x)
318 | out = np.expand_dims(self.caps[i].backward(
319 | grads[:,:,:,:,i], optimizer=optimizer), -1)
320 |         return out, (self.caps[i].W, self.caps[i].b)  # also return the updated weights: Pool workers mutate a copy of self
321 |
322 | def forward(self, x):
323 | t = time.time()
324 | # output (bs, ncaps, ndim)
325 | self.x_size = x.shape
326 | self.x = x
327 | with mp.Pool() as pool:
328 | u = pool.map(partial(self.cap_forward, x=x), np.arange(len(self.caps)))
329 | u = np.concatenate(u, axis=-1)
330 |
331 | return self.squash(u)
332 |
333 | def backward(self, grads, optimizer=None):
334 | t = time.time()
335 | grads = self.squash.backward(grads)
336 | grads = grads.reshape((self.x_size[0],self.out_channels, self.mapsize, self.mapsize,-1))
337 |
338 | with mp.Pool() as pool:
339 |             results = pool.map(partial(self.cap_backward, grads=grads, x=self.x, optimizer=optimizer), np.arange(len(self.caps)))
340 |         for i, (_, p) in enumerate(results):
341 |             self.caps[i]._set_params(p)  # copy the optimizer-updated conv weights back into the parent process
342 |         out = np.sum(np.concatenate([g for g, _ in results], axis=-1), axis=-1)
343 |         return out
344 |
345 |
346 | class Decoder(Module):
347 | def __init__(self):
348 | super(Decoder, self).__init__(trainable=True)
349 | self.net = Sequence([
350 | Linear(16*10,512),
351 | ReLU(),
352 | Linear(512,1024),
353 | ReLU(),
354 | Linear(1024,784),
355 | Sigmoid()
356 | ])
357 | self.x_shape = None
358 |
359 | def forward(self, x):
360 | self.x_shape = x.shape
361 | x = x.reshape(x.shape[0],-1)
362 |
363 | return self.net(x)
364 |
365 | def _set_params(self, params):
366 | for i, l in enumerate(self.net.trainable_modules()):
367 | l._set_params(params[i])
368 |
369 | def backward(self, grad, optimizer):
370 | return self.net.backward(grad, optimizer).reshape(self.x_shape)
371 |
372 |
373 | class DigitCaps(Module):
374 | def __init__(self, ncaps=10, ncaps_prev=32 * 6 * 6, ndim_prev=8, ndim=16):
375 | super(DigitCaps, self).__init__(trainable=True)
376 | self.ndim_prev = ndim_prev
377 | self.ncaps_prev = ncaps_prev
378 | self.ncaps = ncaps
379 | self.route_iter = 2
380 | self.W = np.random.randn(1, ncaps_prev, ncaps, ndim, ndim_prev)
381 | self.softmaxs = [Softmax() for _ in range(self.route_iter)]
382 | self.squashs = [Squash() for _ in range(self.route_iter)]
383 | self.u_hat = None
384 | self.bs = None
385 | self.b = [None] * self.route_iter
386 | self.v = [None] * self.route_iter
387 | self.x = None
388 |
389 | def _set_params(self, params):
390 | self.W = params
391 |
392 | def forward(self, x):
393 | t = time.time()
394 | self.bs = x.shape[0]
395 | self.x = x
396 | x = tile(x[:,:,None,:,None], self.ncaps, 2)
397 | W = tile(self.W, self.bs, 0)
398 | u_hat = W @ x
399 | self.u_hat = u_hat
400 | b = np.zeros((1, self.ncaps_prev, self.ncaps, 1, 1))
401 |
402 | for r in range(self.route_iter):
403 | self.b[r] = b
404 | c = self.softmaxs[r](b, dim=1)
405 |
406 | c = tile(c, self.bs, 0)
407 | s = np.sum(c * u_hat, axis=1, keepdims=True)
408 | v = self.squashs[r](s)
409 | if r == self.route_iter - 1:
410 | return np.squeeze(v, axis=1)
411 |
412 | self.v[r] = v
413 | p = u_hat.swapaxes(-1, -2) @ tile(v, self.ncaps_prev, 1)
414 | b = b + np.mean(p, axis=0, keepdims=True)
415 |
416 |
417 | def backward(self, grad, optimizer=None):
418 | t = time.time()
419 | grad_accum = np.zeros_like(self.u_hat)
420 | b_grad_accum = None
421 | grad = grad[:,None,:,:,:]
422 | for r in range(self.route_iter)[::-1]:
423 | if r < self.route_iter-1:
424 | grad = b_grad_accum
425 | grad = tile(grad, self.bs, 0)/self.bs
426 |                 # only intermediate iterations have an agreement update b += mean(u_hat^T v); the final v
427 |                 # is returned directly (and never stored), so these terms apply only when r < route_iter-1
428 |                 p_grad = tile(self.v[r], self.ncaps_prev, 1) * grad
429 |                 grad_accum += p_grad
430 |                 grad = self.u_hat * grad
431 |                 grad = np.sum(grad, axis=1, keepdims=True)
432 |
433 | grad = self.squashs[r].backward(grad)
434 | grad = tile(grad, self.ncaps_prev, 1)
435 | c = self.softmaxs[r].s
436 | grad_accum += tile(c, self.bs, 0) * grad
437 | grad = self.u_hat.swapaxes(-1,-2) @ grad
438 |
439 | if r > 0:
440 | grad = np.sum(grad, axis=0, keepdims=True)
441 | grad = self.softmaxs[r].backward(grad)
442 | if b_grad_accum is None:
443 | b_grad_accum = grad
444 | else:
445 | b_grad_accum += grad
446 |
447 | x = tile(self.x[:,:,None,:,None], self.ncaps, 2)
448 | dW = np.sum(grad_accum @ x.swapaxes(-1,-2), axis=0, keepdims=True)
449 | if optimizer is not None:
450 | self.W = optimizer(self.W, dW)
451 |
452 | grad_accum = np.squeeze(self.W.swapaxes(-1,-2) @ grad_accum, axis=-1)
453 | dx = np.sum(grad_accum, axis=2)
454 | return dx
455 |
456 |
457 | class CapsNet(Module):
458 | def __init__(self, use_cuda=False, kernel_size=9, stride=1):
459 | super(CapsNet, self).__init__(trainable=True)
460 | self.net = Sequence([
461 | Conv2d(1,256,kernel_size=kernel_size,stride=stride),
462 | ReLU(),
463 | PrimaryCaps(),
464 | DigitCaps()
465 | ])
466 | self.decoder = Decoder()
467 | self.x = None
468 | self.digit_ndim = 16
469 | self.softmax = Softmax()
470 |
471 | def _set_params(self, params):
472 | for i, m in enumerate(self.net.trainable_modules() + [self.decoder]):
473 |             m._set_params(params[i])
474 |
475 | def forward(self, x):
476 | x = self.net(x)
477 | self.x = x
478 | reconst = self.decoder(x)
479 |         scores = np.sqrt((x ** 2).sum(2)).squeeze(-1)  # capsule lengths (batch, 10); squeeze only the trailing axis so a batch of 1 survives
480 | return scores, reconst
481 |
482 | def backward(self, grad, optimizer):
483 | scores_grad, reconst_grad = grad
484 |
485 | scores_grad = scores_grad[:,:,None, None]
486 |         # d||x||/dx = x / ||x||, applied below as 0.5*(sum x^2)^(-1/2) times 2x
487 | scores_grad *= 0.5 * ((self.x ** 2).sum(2, keepdims=True) ** (-0.5))
488 | scores_grad = tile(scores_grad, self.digit_ndim, 2) # tile at dimension 2
489 | scores_grad *= 2*self.x
490 | t = time.time()
491 |
492 | reconst_grad = self.decoder.backward(reconst_grad, optimizer)
493 | grad = scores_grad + reconst_grad
494 |
495 | grad = self.net.backward(grad, optimizer=optimizer)
496 | return grad
497 |
498 |
499 | class CapsLoss(Module):
500 | def __init__(self):
501 | super(CapsLoss, self).__init__()
502 | self.mse_loss = MSELoss()
503 | self.relu1 = ReLU()
504 | self.relu2 = ReLU()
505 | self.reconst_factor = 0.0005
506 |
507 |
508 | def forward(self, norms, labels, reconst, inpt):
509 | self.labels = labels
510 |
511 | int1 = self.relu1(0.9 - norms)
512 | int2 = self.relu2(norms - 0.1)
513 | margin_loss = labels * int1**2 + 0.5*(1-labels) * int2**2
514 | bs, ndim_prev = margin_loss.shape[0], margin_loss.shape[-1]
515 | margin_loss = np.sum(margin_loss, axis=-1).mean()
516 |
517 | reconst_loss, reconst_grad = self.mse_loss(reconst.reshape(reconst.shape[0],-1), inpt.reshape(inpt.shape[0],-1))
518 | loss = margin_loss + self.reconst_factor * reconst_loss
519 |
520 | margin_grad = np.ones((bs, ndim_prev)) / float(bs)
521 | margin_grad_pos = -self.relu1.backward(margin_grad * labels * (2*int1))
522 | margin_grad_neg = self.relu2.backward(margin_grad * 0.5*(1-labels) * (2*int2))
523 |
524 | margin_grad = margin_grad_pos + margin_grad_neg
525 | reconst_grad *= self.reconst_factor
526 |
527 | return loss, (margin_grad, reconst_grad)
--------------------------------------------------------------------------------
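For reference, a minimal sketch of how the NumPy modules above fit together for one training step. The actual entry point is `numpy/main.py` (with the MNIST loader from `mnist.py` and the optimizer from `optim.py`, none of which are shown here); the plain-SGD callable below is only a stand-in whose interface (`optimizer(param, grad)` returning the updated parameter) is inferred from how `Linear`, `Conv2d`, and `DigitCaps` call it. Run from inside `numpy/` so that `modules.py` is importable:

    import numpy as np
    from modules import CapsNet, CapsLoss

    def sgd(param, grad, lr=0.01):
        # stand-in for the optimizer in optim.py: maps (param, grad) to the updated parameter
        return param - lr * grad

    if __name__ == "__main__":    # guard required because PrimaryCaps uses multiprocessing.Pool
        model = CapsNet()
        criterion = CapsLoss()

        x = np.random.rand(2, 1, 28, 28)        # a fake batch of 2 MNIST-sized images
        labels = np.zeros((2, 10))
        labels[np.arange(2), [3, 7]] = 1.0      # one-hot targets

        scores, reconst = model(x)              # capsule lengths (2, 10) and reconstructions (2, 784)
        loss, grads = criterion(scores, labels, reconst, x)
        model.backward(grads, optimizer=sgd)    # backprop and in-place parameter updates via the optimizer
        print("loss:", loss)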