├── LICENSE ├── .gitignore ├── train.py ├── README.md ├── visualize.py └── nets.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sosuke Kobayashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import print_function 2 | 3 | import argparse 4 | import json 5 | import numpy as np 6 | 7 | import chainer 8 | from chainer.dataset.convert import concat_examples 9 | from chainer import serializers 10 | 11 | import nets 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description='CapsNet: MNIST') 16 | parser.add_argument('--batchsize', '-b', type=int, default=256) 17 | parser.add_argument('--decay', '-d', type=float, default=0.95) 18 | parser.add_argument('--epoch', '-e', type=int, default=500) 19 | parser.add_argument('--gpu', '-g', type=int, default=-1) 20 | parser.add_argument('--seed', '-s', type=int, default=789) 21 | parser.add_argument('--reconstruct', '--recon', action='store_true') 22 | parser.add_argument('--save') 23 | args = parser.parse_args() 24 | print(json.dumps(args.__dict__, indent=2)) 25 | 26 | # Set up a neural network to train 27 | np.random.seed(args.seed) 28 | model = nets.CapsNet(use_reconstruction=args.reconstruct) 29 | if args.gpu >= 0: 30 | # Make a speciied GPU current 31 | chainer.cuda.get_device_from_id(args.gpu).use() 32 | model.to_gpu() # Copy the model to the GPU 33 | np.random.seed(args.seed) 34 | model.xp.random.seed(args.seed) 35 | 36 | # Setup an optimizer 37 | optimizer = chainer.optimizers.Adam(alpha=1e-3) 38 | optimizer.setup(model) 39 | 40 | # Load the MNIST dataset 41 | train, test = chainer.datasets.get_mnist(ndim=3) 42 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 43 | test_iter = chainer.iterators.SerialIterator(test, 100, 44 | repeat=False, shuffle=False) 45 | 46 | def report(epoch, result): 47 | mode = 'train' if chainer.config.train else 'test ' 48 | print('epoch {:2d}\t{} mean loss: {}, accuracy: {}'.format( 49 | train_iter.epoch, mode, result['mean_loss'], result['accuracy'])) 50 | if args.reconstruct: 51 | print('\t\t\tclassification: {}, reconstruction: {}'.format( 52 | result['cls_loss'], result['rcn_loss'])) 53 | 54 | best = 0. 
55 | best_epoch = 0 56 | print('TRAINING starts') 57 | while train_iter.epoch < args.epoch: 58 | batch = train_iter.next() 59 | x, t = concat_examples(batch, args.gpu) 60 | optimizer.update(model, x, t) 61 | 62 | # evaluation 63 | if train_iter.is_new_epoch: 64 | result = model.pop_results() 65 | report(train_iter.epoch, result) 66 | 67 | with chainer.no_backprop_mode(): 68 | with chainer.using_config('train', False): 69 | for batch in test_iter: 70 | x, t = concat_examples(batch, args.gpu) 71 | loss = model(x, t) 72 | result = model.pop_results() 73 | report(train_iter.epoch, result) 74 | if result['accuracy'] > best: 75 | best, best_epoch = result['accuracy'], train_iter.epoch 76 | serializers.save_npz(args.save, model) 77 | 78 | optimizer.alpha *= args.decay 79 | optimizer.alpha = max(optimizer.alpha, 1e-5) 80 | print('\t\t# optimizer alpha', optimizer.alpha) 81 | test_iter.reset() 82 | print('Finish: Best accuray: {} at {} epoch'.format(best, best_epoch)) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Routing Between Capsules 2 | 3 | Chainer implementation of CapsNet for MNIST. 4 | 5 | For the detail, see [Dynamic Routing Between Capsules](https://arxiv.org/pdf/1710.09829.pdf), Sara Sabour, Nicholas Frosst, Geoffrey E Hinton, NIPS 2017. 6 | 7 | ``` 8 | python -u train.py -g 0 --save saved_model --reconstruct 9 | ``` 10 | 11 | Test accuracy of a trained model (without reconstruction) reached 99.60%. 12 | The paper does not provide detailed information about initialization and optimization, so the performance might not reach that in the paper. For alleviating those issues, I replaced relu with leaky relu with a very small slope (0.05). The modified model achieved 99.66% (i.e. error rate is 0.34%), as the paper reported. 
13 | 14 | 15 | ## Visualization through Reconstruction 16 | 17 | ``` 18 | python visualize.py -g 0 --load saved_model 19 | ``` 20 | 21 | produces several images for analyzing the digit capsules. 22 | 23 | ### Different masks 24 | 25 | ![vis_all.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_all.png) 26 | 27 | The green images in the top row are the real images given to the model. The blue images in the i-th row below them (counting from 0) are reconstructions from the capsule of digit "i". 28 | 29 | If the correct digit is selected as the target, the model reconstructs the image well (see the diagonal cells). 30 | 31 | If an irrelevant target is selected, the reconstructed image is spoiled (e.g., in the leftmost column, whose input is a "0", compare the "0" row with the other rows), probably because the target's digit capsule lacks the necessary information. However, reconstruction toward a related target is not always spoiled, even if that target is incorrect (see the "8" and "9" rows in the rightmost column). 32 | 33 | 34 | ### Interpolation of values in digit capsules 35 | 36 | Here, I show reconstructed images after linearly tweaking the value of one dimension of the capsule (as in Section 5.1 and Figure 4 of the paper). The green images in the center row are reconstructions without perturbation. Note that the same dimension can encode a different factor of variation in each digit capsule, because the reconstruction weights for each digit are not shared. 37 | 38 | You can find several interesting factors of variation.
39 | 40 | ![vis_tweaked0.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked0.png) 41 | 42 | ![vis_tweaked1.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked1.png) 43 | 44 | ![vis_tweaked2.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked2.png) 45 | 46 | ![vis_tweaked3.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked3.png) 47 | 48 | ![vis_tweaked4.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked4.png) 49 | 50 | ![vis_tweaked5.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked5.png) 51 | 52 | ![vis_tweaked6.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked6.png) 53 | 54 | ![vis_tweaked7.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked7.png) 55 | 56 | ![vis_tweaked8.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked8.png) 57 | 58 | ![vis_tweaked9.png](https://raw.githubusercontent.com/soskek/dynamic_routing_between_capsules/upload-imgs/data/vis_imgs/vis_tweaked9.png) 59 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | 6 | import argparse 7 | import json 8 | import numpy as np 9 | 10 | import chainer 11 | from chainer.dataset.convert import concat_examples 12 | from chainer import serializers 13 | 14 | import nets 15 | 16 
| 17 | def save_images(xs, filename, marked_row=0): 18 | width = xs[0].shape[0] 19 | height = len(xs) 20 | 21 | xs = [np.array(x.tolist(), np.float32) for x in xs] 22 | # subplots with many figs are very slow 23 | fig, ax = plt.subplots( 24 | height, width, figsize=(1 * width / 2.5, height / 2.5)) 25 | xs = np.concatenate(xs, axis=0) 26 | for i, (ai, xi) in enumerate(zip(ax.ravel(), xs)): 27 | ai.set_xticklabels([]) 28 | ai.set_yticklabels([]) 29 | ai.set_axis_off() 30 | color = 'Greens_r' if i // width == marked_row else 'Blues_r' 31 | ai.imshow(xi.reshape(28, 28), cmap=color, vmin=0., vmax=1.) 32 | 33 | plt.subplots_adjust( 34 | left=None, bottom=None, right=None, top=None, wspace=0.05, hspace=0.05) 35 | # saving and clearing subplots with many figs are also very slow 36 | fig.savefig(filename, bbox_inches='tight', pad=0.) 37 | plt.clf() 38 | plt.close('all') 39 | 40 | 41 | def visualize_reconstruction(model, x, t, filename='vis.png'): 42 | print('visualize', filename) 43 | vs_norm, vs = model.output(x) 44 | x_recon = model.reconstruct(vs, t) 45 | save_images([x, x_recon.data], 46 | filename) 47 | 48 | 49 | def visualize_reconstruction_alldigits(model, x, t, filename='vis_all.png'): 50 | print('visualize', filename) 51 | x_recon_list = [] 52 | vs_norm, vs = model.output(x) 53 | for i in range(10): 54 | pseudo_t = model.xp.full(t.shape, i).astype('i') 55 | x_recon = model.reconstruct(vs, pseudo_t).data 56 | x_recon_list.append(x_recon) 57 | save_images([x] + x_recon_list, 58 | filename) 59 | 60 | 61 | def visualize_reconstruction_tweaked(model, x, t, filename='vis_tweaked.png'): 62 | print('visualize', filename) 63 | x_recon_list = [] 64 | vs_norm, vs = model.output(x) 65 | vs = vs.data 66 | vs = model.xp.concatenate([vs] * 16, axis=0) 67 | t = model.xp.concatenate([t] * 16, axis=0) 68 | I = model.xp.arange(16) 69 | for i in range(9): 70 | tweaked_vs = model.xp.array(vs) 71 | tweaked_vs[I, I, :] += (i - 4.) 
* 0.075 # raw + [-0.30, 0.30] 72 | x_recon = model.reconstruct(tweaked_vs, t).data 73 | x_recon_list.append(x_recon) 74 | x_recon = model.reconstruct(vs, t).data 75 | save_images(x_recon_list, 76 | filename, 77 | marked_row=4) 78 | 79 | 80 | def get_samples(dataset): 81 | # 2 samples for each digit 82 | samples = [] 83 | for i, (x, t) in enumerate(dataset): 84 | if t == len(samples) // 2: 85 | print('{}-th sample is used'.format(i)) 86 | samples.append((x, t)) 87 | if len(samples) >= 20: 88 | break 89 | return samples 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser( 94 | description='CapsNet: MNIST reconstruction') 95 | parser.add_argument('--gpu', '-g', type=int, default=-1) 96 | parser.add_argument('--load') 97 | args = parser.parse_args() 98 | print(json.dumps(args.__dict__, indent=2)) 99 | 100 | model = nets.CapsNet(use_reconstruction=True) 101 | serializers.load_npz(args.load, model) 102 | if args.gpu >= 0: 103 | chainer.cuda.get_device_from_id(args.gpu).use() 104 | model.to_gpu() 105 | _, test = chainer.datasets.get_mnist(ndim=3) 106 | 107 | batch = get_samples(test) 108 | x, t = concat_examples(batch, args.gpu) 109 | 110 | with chainer.no_backprop_mode(): 111 | with chainer.using_config('train', False): 112 | visualize_reconstruction(model, x, t) 113 | visualize_reconstruction_alldigits(model, x, t) 114 | for i in range(10): 115 | visualize_reconstruction_tweaked( 116 | model, x[i * 2: i * 2 + 1], t[i * 2: i * 2 + 1], 117 | filename='vis_tweaked{}.png'.format(i)) 118 | -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import chainer 4 | from chainer import cuda 5 | import chainer.functions as F 6 | import chainer.links as L 7 | 8 | 9 | def _augmentation(x): 10 | xp = cuda.get_array_module(x) 11 | MAX_SHIFT = 2 12 | batchsize, ch, h, w = x.shape 13 | h_shift, w_shift = 
xp.random.randint(-MAX_SHIFT, MAX_SHIFT + 1, size=2) 14 | a_h_sl = slice(max(0, h_shift), h_shift + h) 15 | a_w_sl = slice(max(0, w_shift), w_shift + w) 16 | x_h_sl = slice(max(0, - h_shift), - h_shift + h) 17 | x_w_sl = slice(max(0, - w_shift), - w_shift + w) 18 | a = xp.zeros(x.shape) 19 | a[:, :, a_h_sl, a_w_sl] = x[:, :, x_h_sl, x_w_sl] 20 | return a.astype(x.dtype) 21 | 22 | 23 | def _count_params(m, n_grids=6): 24 | print('# of params', sum(param.size for param in m.params())) 25 | # The number of parameters in the paper (11.36M) might be 26 | # of the model with unshared matrices over primary capsules in a same grid 27 | # when input data are 36x36 images of MultiMNIST (n_grids = 10). 28 | # Our model with n_grids=10 has 11.349008M parameters. 29 | # (In the Sec. 4, the paper says "each capsule in the [6, 6] grid 30 | # is sharing their weights with each other.") 31 | print('# of params if unshared', 32 | sum(param.size for param in m.params()) + 33 | sum(param.size for param in m.Ws.params()) * 34 | (n_grids * n_grids - 1)) 35 | 36 | 37 | def squash(ss): 38 | ss_norm2 = F.sum(ss ** 2, axis=1, keepdims=True) 39 | """ 40 | # ss_norm2 = F.broadcast_to(ss_norm2, ss.shape) 41 | # vs = ss_norm2 / (1. + ss_norm2) * ss / F.sqrt(ss_norm2): naive 42 | """ 43 | norm_div_1pnorm2 = F.sqrt(ss_norm2) / (1. 
+ ss_norm2) 44 | norm_div_1pnorm2 = F.broadcast_to(norm_div_1pnorm2, ss.shape) 45 | vs = norm_div_1pnorm2 * ss # :efficient 46 | # (batchsize, 16, 10) 47 | return vs 48 | 49 | 50 | def get_norm(vs): 51 | return F.sqrt(F.sum(vs ** 2, axis=1)) 52 | 53 | 54 | init = chainer.initializers.Uniform(scale=0.05) 55 | 56 | 57 | class CapsNet(chainer.Chain): 58 | 59 | def __init__(self, use_reconstruction=False): 60 | super(CapsNet, self).__init__() 61 | self.n_iterations = 3 # dynamic routing 62 | self.n_grids = 6 # grid width of primary capsules layer 63 | self.n_raw_grids = self.n_grids 64 | self.use_reconstruction = use_reconstruction 65 | with self.init_scope(): 66 | self.conv1 = L.Convolution2D(1, 256, ksize=9, stride=1, 67 | initialW=init) 68 | self.conv2 = L.Convolution2D(256, 32 * 8, ksize=9, stride=2, 69 | initialW=init) 70 | self.Ws = chainer.ChainList( 71 | *[L.Convolution2D(8, 16 * 10, ksize=1, stride=1, initialW=init) 72 | for i in range(32)]) 73 | 74 | self.fc1 = L.Linear(16 * 10, 512, initialW=init) 75 | self.fc2 = L.Linear(512, 1024, initialW=init) 76 | self.fc3 = L.Linear(1024, 784, initialW=init) 77 | 78 | _count_params(self, n_grids=self.n_grids) 79 | self.results = {'N': 0., 'loss': [], 'correct': [], 80 | 'cls_loss': [], 'rcn_loss': []} 81 | 82 | def pop_results(self): 83 | merge = dict() 84 | merge['mean_loss'] = sum(self.results['loss']) / self.results['N'] 85 | merge['cls_loss'] = sum(self.results['cls_loss']) / self.results['N'] 86 | merge['rcn_loss'] = sum(self.results['rcn_loss']) / self.results['N'] 87 | merge['accuracy'] = sum(self.results['correct']) / self.results['N'] 88 | self.results = {'N': 0., 'loss': [], 'correct': [], 89 | 'cls_loss': [], 'rcn_loss': []} 90 | return merge 91 | 92 | def __call__(self, x, t): 93 | if chainer.config.train: 94 | x = _augmentation(x) 95 | vs_norm, vs = self.output(x) 96 | self.loss = self.calculate_loss(vs_norm, t, vs, x) 97 | 98 | self.results['loss'].append(self.loss.data * t.shape[0]) 99 | 
self.results['correct'].append(self.calculate_correct(vs_norm, t)) 100 | self.results['N'] += t.shape[0] 101 | return self.loss 102 | 103 | def output(self, x): 104 | batchsize = x.shape[0] 105 | n_iters = self.n_iterations 106 | gg = self.n_grids * self.n_grids 107 | 108 | # h1 = F.relu(self.conv1(x)) 109 | h1 = F.leaky_relu(self.conv1(x), 0.05) 110 | pr_caps = F.split_axis(self.conv2(h1), 32, axis=1) 111 | # shapes if MNIST. -> if MultiMNIST 112 | # x (batchsize, 1, 28, 28) -> (:, :, 36, 36) 113 | # h1 (batchsize, 256, 20, 20) -> (:, :, 28, 28) 114 | # pr_cap (batchsize, 8, 6, 6) -> (:, :, 10, 10) 115 | 116 | Preds = [] 117 | for i in range(32): 118 | pred = self.Ws[i](pr_caps[i]) 119 | Pred = pred.reshape((batchsize, 16, 10, gg)) 120 | Preds.append(Pred) 121 | Preds = F.stack(Preds, axis=3) 122 | assert(Preds.shape == (batchsize, 16, 10, 32, gg)) 123 | 124 | bs = self.xp.zeros((batchsize, 10, 32, gg), dtype='f') 125 | for i_iter in range(n_iters): 126 | cs = F.softmax(bs, axis=1) 127 | Cs = F.broadcast_to(cs[:, None], Preds.shape) 128 | assert(Cs.shape == (batchsize, 16, 10, 32, gg)) 129 | ss = F.sum(Cs * Preds, axis=(3, 4)) 130 | vs = squash(ss) 131 | assert(vs.shape == (batchsize, 16, 10)) 132 | 133 | if i_iter != n_iters - 1: 134 | Vs = F.broadcast_to(vs[:, :, :, None, None], Preds.shape) 135 | assert(Vs.shape == (batchsize, 16, 10, 32, gg)) 136 | bs = bs + F.sum(Vs * Preds, axis=1) 137 | assert(bs.shape == (batchsize, 10, 32, gg)) 138 | 139 | vs_norm = get_norm(vs) 140 | return vs_norm, vs 141 | 142 | def reconstruct(self, vs, t): 143 | xp = self.xp 144 | batchsize = t.shape[0] 145 | I = xp.arange(batchsize) 146 | mask = xp.zeros(vs.shape, dtype='f') 147 | mask[I, :, t] = 1. 
148 | masked_vs = mask * vs 149 | 150 | x_recon = F.sigmoid( 151 | self.fc3(F.relu( 152 | self.fc2(F.relu( 153 | self.fc1(masked_vs)))))).reshape((batchsize, 1, 28, 28)) 154 | return x_recon 155 | 156 | def calculate_loss(self, vs_norm, t, vs, x): 157 | class_loss = self.calculate_classification_loss(vs_norm, t) 158 | self.results['cls_loss'].append(class_loss.data * t.shape[0]) 159 | if self.use_reconstruction: 160 | recon_loss = self.calculate_reconstruction_loss(vs, t, x) 161 | self.results['rcn_loss'].append(recon_loss.data * t.shape[0]) 162 | return class_loss + 0.0005 * recon_loss 163 | else: 164 | return class_loss 165 | 166 | def calculate_classification_loss(self, vs_norm, t): 167 | xp = self.xp 168 | batchsize = t.shape[0] 169 | I = xp.arange(batchsize) 170 | T = xp.zeros(vs_norm.shape, dtype='f') 171 | T[I, t] = 1. 172 | m = xp.full(vs_norm.shape, 0.1, dtype='f') 173 | m[I, t] = 0.9 174 | 175 | loss = T * F.relu(m - vs_norm) ** 2 + \ 176 | 0.5 * (1. - T) * F.relu(vs_norm - m) ** 2 177 | return F.sum(loss) / batchsize 178 | 179 | def calculate_reconstruction_loss(self, vs, t, x): 180 | batchsize = t.shape[0] 181 | x_recon = self.reconstruct(vs, t) 182 | loss = (x_recon - x) ** 2 183 | return F.sum(loss) / batchsize 184 | 185 | def calculate_correct(self, v, t): 186 | return (self.xp.argmax(v.data, axis=1) == t).sum() 187 | --------------------------------------------------------------------------------