├── LICENCE ├── README.md ├── glimpseSensor.py ├── media ├── 2cross_entropy.png ├── VAM_gif.gif ├── attention.png ├── combined.png ├── glimpsenetwork.png ├── glimpses.png ├── glimpsesensor.png ├── model.png └── sampling.png ├── network.py ├── train.py ├── visualize.py └── weightdecay.py /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Alok Kumar Bishoyi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Attention Model 2 | Chainer implementation of DeepMind's Recurrent Models of Visual Attention. 3 | ![Image](media/VAM_gif.gif) 4 | Humans do not tend to process a whole scene in its entirety at once. Instead, we focus attention selectively on parts of the visual space to acquire information when and where it is needed, and combine information from different fixations over time to build up an internal representation of the scene. Focusing the computational resources on parts of a scene saves “bandwidth” 5 | as fewer “pixels” need to be processed. 6 | ![Image](media/attention.png) 7 | 8 | The model is a recurrent neural network (RNN) which processes inputs sequentially, attending to 9 | different locations within the images (or video frames) one at a time, and incrementally combines 10 | information from these fixations to build up a dynamic internal representation of the scene or 11 | environment. Instead of processing an entire image or even a bounding box at once, at each step the model 12 | selects the next location to attend to based on past information 13 | and 14 | the demands of the task. Both 15 | the number of parameters in the model and the amount of computation it performs can be controlled 16 | independently of the size of the input image, which is in contrast to convolutional networks, whose 17 | computational demands scale linearly with the number of image pixels. 18 | ## The Network Architecture 19 | 20 |
Network Architecture image from Sunner Li's Blogpost
21 |
Glimpse Sensor 22 | 23 | 24 |
25 | The Glimpse Sensor is the implementation of a retina. The idea is to allow our network to “take a glance” at the image around a given location, called a glimpse, then extract and resize this glimpse into various scales of image 26 | crops, where each scale uses the same resolution. For example, the glimpse in the example above contains 3 different scales, and each scale has the same resolution (a.k.a. sensor bandwidth), e.g. 12x12. Therefore, the smallest 27 | crop in the centre is the most detailed, whereas the largest crop in the outer ring is the most blurred. In summary, the Glimpse Sensor takes a full-sized image and a location, and outputs the “Retina-like” representation of the image 28 | around the given location. 29 |
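For concreteness, here is a minimal NumPy sketch of that idea. It is a toy illustration rather than the repository's `glimpseSensor.py` (which also handles batches, channels and GPU arrays), and the crude strided subsampling merely stands in for a proper resize:

```python
import numpy as np

def glimpse(image, center, g_size=8, depth=3, scale=2):
    """Extract `depth` concentric square crops around `center` (row, col);
    each crop is `scale` times larger than the previous one, and every crop
    is subsampled back to g_size x g_size (the fixed "sensor bandwidth")."""
    patches = []
    r, c = int(center[0]), int(center[1])
    for d in range(depth):
        half = (g_size * scale ** d) // 2
        padded = np.pad(image, half, mode='constant')      # keep border crops square
        crop = padded[r:r + 2 * half, c:c + 2 * half]
        patches.append(crop[::scale ** d, ::scale ** d])    # crude stand-in for resizing
    return np.stack(patches)                                # (depth, g_size, g_size)

# toy usage: a 28x28 "image" with a fixation at its centre
print(glimpse(np.random.rand(28, 28), center=(14, 14)).shape)   # (3, 8, 8)
```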

30 | Glimpse Network 31 | 32 | 33 | 34 | 35 | Once we have defined the Glimpse Sensor, the Glimpse Network is simply a wrapper around it: it takes a full-sized image and a location, extracts a retina representation of the image via the Glimpse Sensor, flattens it, and then combines the 36 | extracted retina representation with the glimpse location using hidden layers and ReLU, emitting a single vector g. This vector contains the information of both “what” (our retina representation) and “where” (the focused location within the image). 37 |
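A stripped-down Chainer sketch of that combination is shown below; the class and layer names are illustrative (the actual layers live in `RAM.__init__` in `network.py`), but the structure — one linear layer per modality, ReLU, then a merge into the glimpse feature g — is the same:

```python
import chainer
import chainer.functions as F
import chainer.links as L

class GlimpseNet(chainer.Chain):
    """Fuse "what" (flattened retina crops) with "where" (x, y in [-1, 1])."""
    def __init__(self, n_in, n_units=128):
        super(GlimpseNet, self).__init__(
            l_what=L.Linear(n_in, n_units),   # retina representation
            l_where=L.Linear(2, n_units),     # glimpse location
        )

    def __call__(self, rho, loc):
        g_what = F.relu(self.l_what(rho))
        g_where = F.relu(self.l_where(loc))
        return F.concat([g_what, g_where], axis=1)   # the glimpse feature vector g
```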

Recurrent Network
38 | The Recurrent Network takes the feature vector from the Glimpse Network as input and remembers the useful information via its hidden state (and memory cell). 39 |
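In Chainer this is just a stateful `L.LSTM` link that is reset once per image and then called once per glimpse. The sizes below mirror the repository default (`n_hidden = 256`); the random vector is only a stand-in for the projected glimpse feature:

```python
import numpy as np
import chainer.links as L

lstm = L.LSTM(256, 256)      # 256-d input and hidden state, as in network.py
lstm.reset_state()           # start of a new image / episode
for step in range(6):        # n_steps glimpses
    g = np.random.rand(1, 256).astype(np.float32)   # stand-in glimpse feature
    h = lstm(g)              # h accumulates evidence across fixations
print(h.shape)               # (1, 256)
```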

40 | Location Network
41 | The Location Network takes the hidden state from the Recurrent Network as input and tries to predict the next location to look at. This location prediction becomes the input to the Glimpse Network at the next time step of the unrolled recurrent network. The Location Network is the key component of this whole idea, since it directly determines where to pay attention in the next time step. To maximize the performance of this Location Network, the paper introduces a stochastic process (a Gaussian distribution) to generate the next location and uses reinforcement learning techniques to train it. This is also known as “hard” attention, since the stochastic process is non-differentiable (as opposed to “soft” attention). The intuition behind the stochasticity is to balance exploitation (predicting the future from history) against exploration (trying unprecedented locations). Note that this stochasticity makes the component non-differentiable, which is a problem during back-propagation; the REINFORCE policy gradient algorithm is used to solve it. 42 |
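The sketch below illustrates the stochastic policy in plain NumPy: the network's output is used as the mean of a Gaussian, a location is sampled from it during training, and the log-likelihood of that sample is what REINFORCE later weights by (reward − baseline). It follows the textbook convention in which `sigma` is the standard deviation; `network.py` keeps its own bookkeeping for the same quantity:

```python
import numpy as np

sigma = 0.03   # spread of the location policy (repo default for --sigma)

def sample_location(mean):
    """During training, sample l ~ N(mean, sigma^2) ("hard" attention)."""
    return mean + sigma * np.random.randn(*mean.shape).astype(np.float32)

def log_likelihood(l, mean):
    """log pi(l | mean) up to an additive constant; the REINFORCE loss is
    -log_likelihood * (reward - baseline), with (reward - baseline) held constant."""
    return -0.5 * np.sum((l - mean) ** 2, axis=1) / sigma ** 2

mean = np.zeros((1, 2), dtype=np.float32)    # predicted location in [-1, 1]^2
l = sample_location(mean)
print(l, log_likelihood(l, mean))
```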

Activation Network
43 | The Activation Network takes the hidden state from the Recurrent Network as input and tries to predict the digit. In addition, the prediction result is used to generate the reward, which is what trains the Location Network (since the stochasticity makes it non-differentiable). 44 | 45 |
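The sketch below shows the two roles of the prediction on a toy batch: the class scores feed a normal softmax cross-entropy loss (differentiable, so it trains everything upstream of the digit prediction), while the correctness of the prediction becomes the 0/1 reward that, after subtracting the learned baseline, drives the REINFORCE update of the Location Network. All tensors here are stand-ins:

```python
import numpy as np
import chainer.functions as F

y = np.random.rand(2, 10).astype(np.float32)           # class scores from the action head
t = np.array([3, 7], dtype=np.int32)                    # ground-truth digits

loss_action = F.softmax_cross_entropy(y, t)             # trains the differentiable part
reward = (y.argmax(axis=1) == t).astype(np.float32)     # 1 if the digit was right, else 0
# (reward - baseline) then scales the REINFORCE term for the location policy
print(float(loss_action.data), reward)
```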

Architecture Combined
46 | Combining all the elements illustrated above, we get the network architecture below. 47 | ![Image](media/combined.png) 48 | 49 | ## Experiments 50 | - [x] MNIST 51 | ![Image](media/2cross_entropy.png) 52 | - [ ] Translated MNIST 53 | - [ ] Cluttered MNIST 54 | - [ ] SVHN 55 | 56 | ## Credits 57 | Some of the text and images have been adapted from Medium posts by
Tristan and Sunner Li 58 | -------------------------------------------------------------------------------- /glimpseSensor.py: -------------------------------------------------------------------------------- 1 | from chainer import cuda 2 | from chainer import function 3 | import numpy as np 4 | 5 | 6 | class GlimpseSensor(function.Function): 7 | def __init__(self, center, output_size,depth=1, scale=2, using_conv = False, ): 8 | if type(output_size) is not tuple: 9 | self.output_size = output_size 10 | else: 11 | assert output_size[0] == output_size[1],"Output dims must be same" 12 | self.output_size = output_size[0] 13 | self.center = center 14 | self.depth = depth 15 | self.scale = scale 16 | self.using_conv = using_conv 17 | 18 | def forward(self, images): 19 | xp = cuda.get_array_module(*images) 20 | 21 | n, c, h_i, w_i = images[0].shape 22 | assert h_i == w_i, "Image should be square" 23 | size_i = h_i 24 | size_o = self.output_size 25 | 26 | # [-1, 1]^2 -> [0, size_i - 1]x[0, size_i - 1] 27 | center = (0.5 * (self.center + 1) * (size_i - 1)).data # center:shape -> [n X 2] 28 | y = xp.zeros(shape=(n, c*self.depth, size_o, size_o), dtype=xp.float32) 29 | 30 | xmin = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 31 | ymin = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 32 | xmax = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 33 | ymax = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 34 | 35 | xstart = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 36 | ystart = xp.zeros(shape=(self.depth, n), dtype=xp.int32) 37 | 38 | 39 | for depth in range(self.depth): 40 | xmin[depth] = xp.clip(xp.rint(center[:, 0]) - (0.5 * size_o * (np.power(self.scale,depth))), 0., size_i).astype(xp.int32) 41 | ymin[depth] = xp.clip(xp.rint(center[:, 1]) - (0.5 * size_o * (np.power(self.scale,depth))), 0., size_i).astype(xp.int32) 42 | xmax[depth] = xp.clip(xp.rint(center[:, 0]) + (0.5 * size_o * (np.power(self.scale,depth))), 0., size_i).astype(xp.int32) 43 | ymax[depth] = xp.clip(xp.rint(center[:, 1]) + (0.5 * size_o * (xp.power(self.scale,depth))), 0., size_i).astype(xp.int32) 44 | 45 | xstart[depth] = xmin[depth] - (xp.rint(center[:, 0]) - (0.5 * size_o * (np.power(self.scale,depth)))) 46 | ystart[depth] = ymin[depth] - (xp.rint(center[:, 1]) - (0.5 * size_o * (np.power(self.scale,depth)))) 47 | 48 | for i in range(n): 49 | for j in range(self.depth): 50 | 51 | cropped = images[0][i][:,xmin[j][i]:xmax[j][i], ymin[j][i]:ymax[j][i]] 52 | # TODO: resize images 53 | 54 | y[i][c*j: (c*j)+c, xstart[j][i]: xstart[j][i] + xmax[j][i] - xmin[j][i] , 55 | ystart[j][i]: ystart[j][i] + ymax[j][i] - ymin[j][i]] += cropped 56 | 57 | if self.using_conv: 58 | return y, 59 | else: 60 | return y.reshape(n,-1), 61 | 62 | def backward(self, images, gy): 63 | #return zero grad 64 | xp = cuda.get_array_module(*images) 65 | n, c_in ,h_i, w_i = images[0].shape 66 | gx = xp.zeros(shape=(n, c_in, h_i, w_i), dtype=xp.float32) 67 | return gx, 68 | 69 | 70 | def getGlimpses(x, center, size, depth=1, scale=2, using_conv = False): 71 | return GlimpseSensor(center, size, depth, scale)(x) 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /media/2cross_entropy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/2cross_entropy.png -------------------------------------------------------------------------------- /media/VAM_gif.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/VAM_gif.gif -------------------------------------------------------------------------------- /media/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/attention.png -------------------------------------------------------------------------------- /media/combined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/combined.png -------------------------------------------------------------------------------- /media/glimpsenetwork.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/glimpsenetwork.png -------------------------------------------------------------------------------- /media/glimpses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/glimpses.png -------------------------------------------------------------------------------- /media/glimpsesensor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/glimpsesensor.png -------------------------------------------------------------------------------- /media/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/model.png -------------------------------------------------------------------------------- /media/sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Visual-Attention-Model/f3a14f7399fbab18372b06bcf763fc608c35eb47/media/sampling.png -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import cuda, Variable 3 | import chainer.functions as F 4 | import chainer.links as L 5 | import numpy as np 6 | from glimpseSensor import getGlimpses 7 | from chainer import reporter 8 | 9 | 10 | class RAM(chainer.Chain): 11 | def __init__(self, n_hidden = 256, n_units = 128, sigma = 0.03, 12 | g_size=8, n_steps=6, n_depth=1, n_scale = 2, n_class = 10, using_conv = False): 13 | 14 | n_in = g_size * g_size * n_depth 15 | super(RAM, self).__init__( 16 | ll1=L.Linear(2, n_units), # 2 refers to x,y coordinate 17 | lrho1=L.Linear(n_in, n_units), 18 | lh1=L.Linear(n_units*2, n_hidden), 19 | lh2=L.Linear(n_hidden, n_hidden), 20 | lstm=L.LSTM(n_hidden, n_hidden), 21 | ly=L.Linear(n_hidden, n_class), # class/action output 22 | ll=L.Linear(n_hidden, 2), # location output 23 | lb=L.Linear(n_hidden, 1), # baseline output 24 | ) 25 | self.g_size = g_size 26 | self.n_depth = n_depth 27 | self.n_scale = n_scale 28 | 
self.sigma = sigma 29 | self.n_steps = n_steps 30 | self.using_conv = using_conv 31 | 32 | def reset_state(self): 33 | self.lstm.reset_state() 34 | 35 | def get_location_loss(self, l, mean): 36 | if chainer.config.train: 37 | term1 = 0.5 * (l - mean) ** 2 * self.sigma ** -2 38 | return F.sum(term1, axis=1).reshape(-1,1) 39 | else: 40 | xp = cuda.get_array_module(l) 41 | return Variable(xp.zeros(l.shape[0])) 42 | 43 | def sample_location(self, l): 44 | """ 45 | sample new location from center l_data 46 | """ 47 | if chainer.global_config.train: 48 | l_data = l.data 49 | bs = l_data.shape[0] 50 | xp = cuda.get_array_module(l_data) 51 | randomness = (xp.random.normal(0, 1, size=(bs, 2))).astype(np.float32) 52 | l_sampled = l_data + np.sqrt(self.sigma) * randomness 53 | return Variable(xp.array(l_sampled)) 54 | else: 55 | return l 56 | 57 | def forward(self, x, l, first=False): 58 | if not first: 59 | centers = self.sample_location(l) 60 | ln_pi = self.get_location_loss(centers, l) 61 | 62 | 63 | else: 64 | centers = l 65 | ln_pi = self.get_location_loss(l, l) # ==0's 66 | rho = getGlimpses(x, centers, self.g_size, self.n_depth, self.n_scale, self.using_conv) 67 | 68 | g0 = F.relu(self.ll1(centers)) 69 | g1 = F.relu(self.lrho1(rho)) 70 | h0 = F.concat([g0, g1], axis=1) 71 | h1 = F.relu(self.lh1(h0)) 72 | h2 = F.relu(self.lh2(h1)) 73 | h_out = self.lstm(h2) 74 | y = self.ly(h_out) 75 | l_out = F.tanh(self.ll(h_out)) 76 | b = F.sigmoid(self.lb(h_out)) 77 | return l_out, ln_pi, y, b 78 | 79 | 80 | 81 | def __call__(self, x, t): 82 | 83 | x = chainer.Variable(self.xp.asarray(x)) 84 | t = chainer.Variable(self.xp.asarray(t)) 85 | #print(x.shape) 86 | #print(t.shape) 87 | batchsize = x.data.shape[0] 88 | self.reset_state() 89 | 90 | # initial l 91 | l = np.random.uniform(-1, 1, size=(batchsize, 2)).astype(np.float32) 92 | l = chainer.Variable(self.xp.asarray(l)) 93 | 94 | sum_ln_pi = Variable((self.xp.zeros((batchsize,1)))) 95 | sum_ln_pi = F.cast(sum_ln_pi,'float32') 96 | l, ln_pi, y, b = self.forward(x, l, first=True) 97 | for i in range(1,self.n_steps): 98 | l, ln_pi, y, b = self.forward(x, l) 99 | sum_ln_pi += ln_pi 100 | self.loss_action = F.softmax_cross_entropy(y, t) 101 | self.loss = self.loss_action 102 | self.accuracy = F.accuracy(y, t) 103 | reporter.report({'accuracy': self.accuracy}, self) 104 | self.y = F.argmax(y, axis=1) 105 | if chainer.global_config.train: 106 | 107 | conditions = self.xp.argmax(y.data, axis=1) == t.data 108 | r = self.xp.where(conditions, 1., 0.).astype(self.xp.float32) 109 | r = self.xp.expand_dims(r, 1) 110 | # squared error between reward and baseline 111 | self.loss_baseline = F.mean_squared_error(r, b) 112 | self.loss += self.loss_baseline 113 | # loss with reinforce rule 114 | mean_ln_pi = sum_ln_pi / (self.n_steps - 1) 115 | a = F.sum(-mean_ln_pi * (r - b)) / batchsize 116 | self.reinforce_loss = F.sum(-mean_ln_pi * (r-b)) / batchsize 117 | self.loss += self.reinforce_loss 118 | reporter.report({'cross_entropy_loss': self.loss_action}, self) 119 | #reporter.report({'reinforce_loss': self.reinforce_loss}, self) 120 | #reporter.report({'total_loss': self.loss}, self) 121 | reporter.report({'training_accuracy': self.accuracy}, self) 122 | 123 | #print(self.loss) 124 | return self.loss 125 | 126 | 127 | if __name__ == "__main__": 128 | train, test = chainer.datasets.get_mnist(withlabel=True, ndim=3) 129 | train_data, train_targets = np.array(train).transpose() 130 | train_data = np.array(list(train_data)).reshape(train_data.shape[0], 1, 28, 28) 131 | 
train_targets = np.array(train_targets).astype(np.int32) 132 | x = train_data[0:2] 133 | t = train_targets[0:2] 134 | model = RAM() 135 | model.to_gpu() 136 | model(x, t) 137 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import chainer 4 | from chainer import training, datasets, iterators, optimizers, serializers 5 | from chainer import reporter 6 | from network import RAM 7 | from chainer.training import extensions 8 | from weightdecay import lr_drop 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='VAM in Chainer:MNIST') 13 | parser.add_argument('--batchsize', '-b', type=int, default=128, 14 | help='Number of images in each mini-batch') 15 | parser.add_argument('--epoch', '-e', type=int, default=1000, 16 | help='Number of sweeps over the dataset to train') 17 | parser.add_argument('--gpu', '-g', type=int, default=-1, 18 | help='GPU ID (negative value indicates CPU)') 19 | parser.add_argument('--out', '-o', default='result', 20 | help='Directory to output the result') 21 | parser.add_argument('--resume', '-r', default='', 22 | help='Resume the training from snapshot') 23 | parser.add_argument('--unit', '-u', type=int, default=128, 24 | help='Dimension of locator, glimpse hidden state') 25 | parser.add_argument('--hidden','-hi', type=int, default=256, 26 | help='Dimension of lstm hidden state') 27 | parser.add_argument('--g_size', '-g_size', type=int, default=8, 28 | help='Dimension of output') 29 | parser.add_argument('--len_seq', '-l', type=int, default=6, 30 | help='Length of action sequence') 31 | parser.add_argument('--depth', '-d', type=int, default=1, 32 | help='no of depths/glimpses to be taken at once') 33 | parser.add_argument('--scale', '-s', type=float, default=2, 34 | help='subsequent scales of cropped image for sequential depths (int>1)') 35 | parser.add_argument('--sigma', '-si',type=float, default=0.03, 36 | help='sigma of location sampling model') 37 | parser.add_argument('--evalm', '-evalm', type=str, default=None, 38 | help='Evaluation mode: path to saved model file') 39 | parser.add_argument('--evalo', '-eval0', type=str, default=None, 40 | help='Evaluation mode: path to saved optimizer file') 41 | args = parser.parse_args() 42 | 43 | print('GPU: {}'.format(args.gpu)) 44 | print('# n_units: {}'.format(args.unit)) 45 | print('# n_hidden: {}'.format(args.hidden)) 46 | print('# Length of action sequence: {}'.format(args.len_seq)) 47 | print('# sigma: {}'.format(args.sigma)) 48 | print('# Minibatch-size: {}'.format(args.batchsize)) 49 | print('# epoch: {}'.format(args.epoch)) 50 | print('') 51 | 52 | train, test = chainer.datasets.get_mnist() 53 | train_data, train_targets = np.array(train).transpose() 54 | test_data, test_targets = np.array(test).transpose() 55 | train_data = np.array(list(train_data)).reshape(train_data.shape[0], 1, 28, 28) 56 | test_data = np.array(list(test_data)).reshape(test_data.shape[0], 1, 28, 28) 57 | train_targets = np.array(train_targets).astype(np.int32) 58 | test_targets = np.array(test_targets).astype(np.int32) 59 | if args.evalm is not None: 60 | chainer.global_config.train = False 61 | 62 | model = RAM(args.hidden, args.unit, args.sigma, 63 | args.g_size, args.len_seq, args.depth, args.scale, using_conv = False) 64 | #model.to_gpu() 65 | optimizer = optimizers.NesterovAG() 66 | if args.evalm is not None: 67 | 
serializers.load_npz(args.evalm, model) 68 | print('model loaded') 69 | if args.evalo is not None: 70 | serializers.load_npz(args.evalo, optimizer) 71 | print('optimizer loaded') 72 | 73 | if args.gpu>=0: 74 | model.to_gpu() 75 | 76 | optimizer.setup(model) 77 | 78 | train_dataset = datasets.TupleDataset(train_data, train_targets) 79 | train_iter = iterators.SerialIterator(train_dataset, args.batchsize) 80 | test_dataset = datasets.TupleDataset(test_data, test_targets) 81 | test_iter = iterators.SerialIterator(test_dataset, 128) 82 | stop_trigger = (args.epoch, 'epoch') 83 | updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu) 84 | trainer = training.Trainer(updater, stop_trigger, out=args.out) 85 | trainer.extend(lr_drop) 86 | trainer.extend(extensions.snapshot_object(model, '2model{.updater.epoch}.npz'), trigger=(50,'epoch')) 87 | trainer.extend(extensions.snapshot_object(optimizer, '2opt{.updater.epoch}.npz'), trigger=(50, 'epoch')) 88 | trainer.extend(extensions.PlotReport(['main/training_accuracy'], 'epoch', trigger=(1, 'epoch'), file_name='2train_accuracy.png', 89 | marker=".")) 90 | trainer.extend(extensions.PlotReport(['main/cross_entropy_loss'], 'epoch', trigger=(1, 'epoch'), file_name='2cross_entropy.png', 91 | marker=".")) 92 | trainer.extend(extensions.ProgressBar((args.epoch,'epoch'),update_interval=50)) 93 | trainer.run() 94 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from matplotlib import gridspec 5 | import chainer 6 | from chainer import serializers 7 | import PIL 8 | from PIL import ImageDraw 9 | import numpy as np 10 | import argparse 11 | import chainer.functions as F 12 | from network import RAM 13 | 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='VAM in Chainer:MNIST') 17 | parser.add_argument('--batchsize', '-b', type=int, default=128, 18 | help='Number of images in each mini-batch') 19 | parser.add_argument('--gpu', '-g', type=int, default=-1, 20 | help='GPU ID (negative value indicates CPU)') 21 | parser.add_argument('--out', '-o', default='result', 22 | help='Directory to output the result') 23 | parser.add_argument('--unit', '-u', type=int, default=128, 24 | help='Dimension of locator, glimpse hidden state') 25 | parser.add_argument('--hidden','-hi', type=int, default=256, 26 | help='Dimension of lstm hidden state') 27 | parser.add_argument('--g_size', '-g_size', type=int, default=8, 28 | help='Dimension of output') 29 | parser.add_argument('--len_seq', '-l', type=int, default=6, 30 | help='Length of action sequence') 31 | parser.add_argument('--depth', '-d', type=int, default=1, 32 | help='no of depths/glimpses to be taken at once') 33 | parser.add_argument('--scale', '-s', type=float, default=2, 34 | help='subsequent scales of cropped image for sequential depths (int>1)') 35 | parser.add_argument('--sigma', '-si',type=float, default=0.03, 36 | help='sigma of location sampling model') 37 | parser.add_argument('--eval', '-eval', type=str, default=None, 38 | help='Evaluation mode: path to saved model file relative to current working dir') 39 | args = parser.parse_args() 40 | 41 | model = RAM(args.hidden, args.unit, args.sigma, args.g_size, args.len_seq, args.depth, args.scale) 42 | serializers.load_npz(os.getcwd() + args.eval, model) 43 | train, test = chainer.datasets.get_mnist() 44 | 
train_data, train_targets = np.array(train).transpose() 45 | test_data, test_targets = np.array(test).transpose() 46 | train_data = np.array(list(train_data)).reshape(train_data.shape[0], 1, 28, 28) 47 | test_data = np.array(list(test_data)).reshape(test_data.shape[0], 1, 28, 28) 48 | train_targets = np.array(train_targets).astype(np.int32) 49 | test_targets = np.array(test_targets).astype(np.int32) 50 | g_size = args.g_size 51 | 52 | 53 | def visualize(model): 54 | chainer.global_config.train = False 55 | index = np.random.randint(0, 9999) 56 | x_raw = train_data[index:index + 1] 57 | t_raw = train_targets[index] 58 | x = chainer.Variable(np.asarray(x_raw)) 59 | t = chainer.Variable(np.asarray(t_raw)) 60 | batchsize = x.data.shape[0] 61 | model.reset_state() 62 | ls = [] 63 | probs = [] 64 | 65 | l = np.random.uniform(-1, 1, size=(batchsize, 2)).astype(np.float32) 66 | l = chainer.Variable(np.asarray(l)) 67 | ls.append(l.data) 68 | for i in range(6): 69 | l, ln_pi, y, b = model.forward(x, l, first=(i == 0)) 70 | y = F.softmax(y) 71 | probs.append(y.data) 72 | ls.append(l.data) 73 | fig = plt.figure(figsize=(8, 6)) 74 | gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1]) 75 | ax0 = plt.subplot(gs[0]) 76 | image = PIL.Image.fromarray(train_data[index][0] * 255).convert('RGB') 77 | canvas = image.copy() 78 | draw = ImageDraw.Draw(canvas) 79 | 80 | locs = (((ls[i] + 1) / 2) * ((np.array([28, 28])) - 1)) 81 | 82 | color = (0, 255, 0) 83 | xy = np.array([locs[0][0], locs[0][1], locs[0][0], locs[0][1]]) 84 | wh = np.array([-g_size // 2, -g_size // 2, g_size // 2, g_size // 2]) 85 | xys = [xy + np.power(2, s) * wh for s in range(args.depth)] 86 | 87 | for xy in xys: 88 | draw.rectangle(xy=list(xy), outline=color) 89 | del draw 90 | 91 | plt.imshow(canvas) 92 | plt.axis('off') 93 | 94 | y_ticks = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0'] 95 | 96 | bar_lengths = probs[i][0] 97 | 98 | ax1 = plt.subplot(gs[1]) 99 | ax1.barh(y_ticks, bar_lengths, color='#006080') 100 | ax1.get_xaxis().set_ticks([]) 101 | plt.tight_layout() 102 | plt.savefig(args.out + str(i) + '.png') 103 | 104 | 105 | visualize(model) 106 | -------------------------------------------------------------------------------- /weightdecay.py: -------------------------------------------------------------------------------- 1 | from chainer import training 2 | 3 | @training.make_extension(trigger=(400, 'epoch')) 4 | def lr_drop(trainer): 5 | trainer.updater.get_optimizer('main').lr *= 0.1 --------------------------------------------------------------------------------