├── fig.png
├── checkpoints
│   ├── mnist_gdram.pth
│   ├── cifar10_gdram.pth
│   └── cifar100_gdram.pth
├── dataloader.py
├── LICENSE
├── mnist_generation.py
├── utils.py
├── inference.py
├── model.py
├── README.md
├── modules.py
└── train.py
/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/fig.png
--------------------------------------------------------------------------------
/checkpoints/mnist_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/mnist_gdram.pth
--------------------------------------------------------------------------------
/checkpoints/cifar10_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/cifar10_gdram.pth
--------------------------------------------------------------------------------
/checkpoints/cifar100_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/cifar100_gdram.pth
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torchvision import transforms
4 | from PIL import Image
5 | import pandas as pd
6 |
7 |
8 |
9 |
10 | class MnistClutteredDataset(Dataset):
11 |     """Cluttered MNIST dataset indexed by <data_path>/<type>/path.txt, where each line is "<image_path> <label>"."""
12 | def __init__(self, data_path, type, transform=None):
13 |
14 | self.root_dir = data_path +'/'+ type + '/path.txt'
15 | self.transform = transform
16 | self.path = pd.read_csv(self.root_dir, sep=' ', header=None)
17 |
18 | def __getitem__(self, idx):
19 | if torch.is_tensor(idx):
20 | idx = idx.tolist()
21 |
22 | img_path = self.path.iloc[idx,0]
23 |
24 | image = Image.open(img_path)
25 |
26 | label = int(self.path.iloc[idx,1])
27 |
28 | if self.transform:
29 | image = self.transform(image)
30 |
31 | return image, label
32 |
33 | def __len__(self):
34 | return len(self.path)
35 |
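36 |
37 | if __name__ == '__main__':
38 |     # Minimal smoke test; assumes data/test/path.txt exists (e.g. produced
39 |     # by mnist_generation.py). Illustrative only.
40 |     transform = transforms.Compose([transforms.Resize(128),
41 |                                     transforms.Grayscale(3),
42 |                                     transforms.ToTensor()])
43 |     dataset = MnistClutteredDataset('data', type='test', transform=transform)
44 |     image, label = dataset[0]
45 |     print(len(dataset), image.shape, label)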
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Dongseok Shim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/mnist_generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | root_path = 'data'
6 |
7 | data = np.load(root_path + '/mnist_sequence1_sample_5distortions5x5.npz')
8 |
9 | X_train = data['X_train']
10 | y_train = data['y_train']
11 |
12 | X_val = data['X_valid']
13 | y_val = data['y_valid']
14 |
15 | X_test = data['X_test']
16 | y_test = data['y_test']
17 |
18 | if not os.path.exists(root_path + '/train'):
19 |     os.mkdir(root_path + '/train')
20 |
21 | f = open(root_path + '/train/path.txt', 'w')
22 |
23 | for i in range(len(X_train)):
24 |
25 | img_path = root_path + '/train/%05d.jpg'%i
26 |
27 | img = X_train[i].reshape(40,40)
28 | plt.imsave(img_path, img)
29 | label = y_train[i,0]
30 |
31 | f.write(img_path+ ' %d\n'%(label))
32 |
33 | f.close()
34 |
35 | if not os.path.exists(root_path + '/val'):
36 |     os.mkdir(root_path + '/val')
37 |
38 | f = open(root_path + '/val/path.txt', 'w')
39 |
40 | for i in range(len(X_val)):
41 |
42 | img_path = root_path + '/val/%05d.jpg'%i
43 |
44 | img = X_val[i].reshape(40,40)
45 | plt.imsave(img_path, img)
46 | label = y_val[i,0]
47 |
48 | f.write(img_path+ ' %d\n'%(label))
49 |
50 | f.close()
51 |
52 |
53 | if not os.path.exists(root_path + '/test'):
54 |     os.mkdir(root_path + '/test')
55 |
56 | f = open(root_path + '/test/path.txt', 'w')
57 |
58 | for i in range(len(X_test)):
59 |
60 | img_path = root_path + '/test/%05d.jpg'%i
61 |
62 | img = X_test[i].reshape(40,40)
63 | plt.imsave(img_path, img)
64 | label = y_test[i,0]
65 |
66 | f.write(img_path+ ' %d\n'%(label))
67 |
68 | f.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('agg')
3 | from matplotlib.animation import FuncAnimation
4 | import matplotlib.pyplot as plt
5 | import matplotlib.patches as patches
6 | import numpy as np
7 | import torch
8 | from torch.nn import functional as F
9 | from torchvision.utils import save_image
10 | import os
11 |
12 |
13 | def get_glimpse(x, l, output_size, k, device):
14 | """Transform image to retina representation
15 |
16 | Assume that width = height and channel = 1
17 | """
18 | batch_size, input_size = x.size(0), x.size(2) - 1
19 |
20 | assert output_size * 2**(k - 1) <= input_size, \
21 | "output_size * 2**(k-1) should smaller than or equal to input_size"
22 |
23 | # construct theta for affine transformation
24 | theta = torch.zeros(batch_size, 2, 3)
25 | theta[:, :, 2] = l
26 |
27 | scale = output_size / input_size
28 | osize = torch.Size([batch_size, 1, output_size, output_size])
29 |
30 | for i in range(k):
31 | theta[:, 0, 0] = scale
32 | theta[:, 1, 1] = scale
33 | grid = F.affine_grid(theta, osize, align_corners=False).to(device)
34 | glimpse = F.grid_sample(x, grid, align_corners=False)
35 |
36 | if i==0:
37 | output = glimpse
38 | else:
39 | output = torch.cat((output, glimpse), dim=1)
40 | scale *= 2
41 |
42 | return output.detach()
43 |
44 |
45 | def draw_locations(image, locations, weights=None, size=8, epoch=0, save_path='results'):
46 |     image = np.transpose(image, (1,2,0))
47 |     if weights is not None:
48 |         weights = weights.detach().cpu().numpy()
49 |
50 |     if epoch > 50 and weights is not None:
51 |         # keep glimpses only up to the first pair of consecutive low-confidence weights
52 |         for idx in range(len(weights[0])-1):
53 |             if (weights[0][idx] < 0.5) and (weights[0][idx+1] < 0.5):
54 |                 break
55 |         locations = locations[:idx+1]
56 |
57 |
58 |
59 | locations = list(locations)
60 | fig, ax = plt.subplots(1, len(locations))
61 | for i, location in enumerate(locations):
62 | if len(locations) == 1:
63 | subplot = ax
64 | else:
65 | subplot = ax[i]
66 |
67 | subplot.axis('off')
68 | subplot.imshow(image, cmap='gray')
69 | loc = ((location[0] + 1) * image.shape[1] / 2 - size / 2,
70 | (location[1] + 1) * image.shape[0] / 2 - size / 2)
71 |
72 | rect = patches.Rectangle(
73 | loc, size, size, linewidth=1, edgecolor='r', facecolor='none')
74 | subplot.add_patch(rect)
75 | fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
76 |
77 |
78 | if not os.path.exists(save_path):
79 | os.mkdir(save_path)
80 | plt.savefig(save_path+ '/glimpse_%d.png'%epoch, bbox_inches='tight')
81 | plt.close()
82 |
83 | if __name__ == '__main__':
84 | img = np.ones((3,3,28,28))
85 |
86 | loc = np.ones((3,2))
87 |
88 | img = torch.Tensor(img).cuda()
89 | loc = torch.Tensor(loc).cuda()
90 |
91 |     out = get_glimpse(img, loc, 8, 2, torch.device('cuda'))
92 | print(out.shape)
93 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from torchvision import datasets, transforms
3 | from model import GDRAM
4 | from dataloader import MnistClutteredDataset
5 | import time
6 | import argparse
7 |
8 |
9 | def str2bool(v):
10 | if isinstance(v, bool):
11 | return v
12 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
13 | return True
14 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
15 | return False
16 | else:
17 | raise argparse.ArgumentTypeError('Boolean value expected.')
18 |
19 |
20 | parser = argparse.ArgumentParser(description='Inference')
21 |
22 | parser.add_argument('--data_path', type=str, default='data')
23 | parser.add_argument('--dataset', type=str, default='mnist')
24 | parser.add_argument('--device', type=str, default='cuda')
25 | parser.add_argument('--fast', type=str2bool, default='False')
26 | parser.add_argument('--random_seed', type=int, default=1)
27 | args = parser.parse_args()
28 |
29 | batch_size = 1
30 |
31 | kwargs = {'num_workers': 64, 'pin_memory': True} if not args.device=='cpu' else {}
32 |
33 | device = torch.device(args.device)
34 |
35 | model_path = 'checkpoints/'+args.dataset+'_gdram.pth'
36 |
37 | img_size = 128
38 |
39 | torch.manual_seed(args.random_seed)
40 |
41 | ##################################################
42 |
43 | if args.dataset == 'cifar10':
44 |
45 | transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor()])
46 |
47 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_path, train=False,\
48 | transform=transform),batch_size=batch_size, shuffle=False, **kwargs)
49 |
50 | elif args.dataset == 'cifar100':
51 |
52 | transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor()])
53 |
54 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR100(args.data_path, train=False,\
55 | transform=transform),batch_size=batch_size, shuffle=False, **kwargs)
56 |
57 | elif args.dataset == 'mnist':
58 |
59 | transform = transforms.Compose([transforms.Resize(img_size),transforms.Grayscale(3), transforms.ToTensor()])
60 | test_set = MnistClutteredDataset(args.data_path, type='test',transform=transform)
61 |
62 | test_loader = torch.utils.data.DataLoader(
63 | test_set, batch_size=batch_size, shuffle=False, **kwargs
64 | )
65 |
66 |
67 | model = GDRAM(device=device, dataset=args.dataset, Fast=args.fast).to(device)
68 | model.eval()
69 |
70 | pytorch_total_params = sum(p.numel() for p in model.parameters())
71 |
72 | print('Model parameters: %d'%pytorch_total_params)
73 |
74 | model.load_state_dict(torch.load(model_path, map_location=device))
75 | print('Model Loaded!')
76 |
77 | total_correct = 0.0
78 |
79 | def accuracy2(output, target, topk=(1,)):
80 | maxk = max(topk)
81 | batch_size = target.size(0)
82 |
83 | _, pred = output.topk(maxk, 1, True, True)
84 | pred = pred.t()
85 | correct = pred.eq(target.view(1, -1).expand_as(pred))
86 |
87 | res = []
88 | for k in topk:
89 |         correct_k = correct[:k].reshape(-1).float().sum(0)
90 | res.append(correct_k.mul_(100.0 / batch_size))
91 | return res
92 |
93 |
94 | accuracy1 = 0
95 | accuracy5 = 0
96 |
97 | start_time = time.time()
98 |
99 | with torch.no_grad():
100 |     for data, labels in test_loader:
101 |         data = data.to(device)
102 |         action_logits, location, _, _, weights = model(data)
103 |         predictions = torch.argmax(action_logits, dim=1)
104 |         labels = labels.to(device)
105 |         total_correct += torch.sum(labels == predictions).item()
106 |
107 |         acc1, acc5 = accuracy2(action_logits, labels, topk=(1, 5))
108 |         accuracy1 += acc1.cpu().numpy()
109 |         accuracy5 += acc5.cpu().numpy()
110 |
111 | acc1 = accuracy1 / len(test_loader)
112 | acc5 = accuracy5 / len(test_loader)
113 |
114 | print("Top1: %.2f  Top5: %.2f  sec/image: %.5f" % (acc1, acc5, (time.time() - start_time) / len(test_loader.dataset)))
115 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from modules import *
4 | from utils import get_glimpse
5 | import math
6 |
7 |
8 | class GDRAM(nn.Module):
9 | def __init__(self, device=None, dataset=None, Fast=False):
10 | super(GDRAM, self).__init__()
11 |
12 | self.glimpse_size = 12
13 | self.num_scales = 4
14 |
15 | self.img_size = 128
16 |
17 | self.class_num = 10
18 |
19 | if dataset == 'cifar100':
20 | self.class_num = 100
21 |
22 |
23 | self.normalized_glimpse_size = self.glimpse_size/(self.img_size/2)
24 |
25 | self.glimpse_net = GlimpseNetwork(3*self.num_scales,self.glimpse_size,2,128,128)
26 |
27 | self.rnn1 = GlimpseLSTMCoreNetwork(128,128)
28 | self.rnn2 = LocationLSTMCoreNetwork(128,128,self.glimpse_size)
29 |
30 | self.class_net = ActionNetwork(128, self.class_num)
31 | self.emission_net = EmissionNetwork(128)
32 |
33 | self.baseline_net = BaselineNetwork(128*2,1)
34 |
35 | self.num_glimpses = 8
36 | self.location_size = 2
37 |
38 | self.device = device
39 |
40 | self.fast = Fast
41 |
42 | def forward(self, x):
43 |
44 | batch_size = x.size(0)
45 |
46 | hidden1, cell_state1 = self.rnn1.init_hidden(batch_size)
47 | hidden1 = hidden1.to(self.device)
48 | cell_state1 = cell_state1.to(self.device)
49 |
50 |
51 | hidden2, cell_state2 = self.rnn2.init_hidden(x, batch_size)
52 |
53 | hidden2 = hidden2.to(self.device)
54 | cell_state2 = cell_state2.to(self.device)
55 |
56 |         # the initial location, std and log-prob come from the emission
57 |         # network applied to the context-initialized hidden state
58 |
59 | location, std, log_prob = self.emission_net(hidden2)
60 | location = torch.clamp(location, min=-1 + self.normalized_glimpse_size / 2,
61 | max=1 - self.normalized_glimpse_size / 2)
62 |
63 | location_log_probs = torch.empty(batch_size, self.num_glimpses).to(self.device)
64 | locations = torch.empty(batch_size, self.num_glimpses, self.location_size).to(self.device)
65 | baselines = torch.empty(batch_size, self.num_glimpses).to(self.device)
66 | weights = torch.empty(batch_size, self.num_glimpses).to(self.device)
67 |
68 | weight = torch.ones(batch_size).to(self.device)
69 |
70 | action_logits = 0
71 | weight_sum = 0
72 |
73 |
74 | for i in range(self.num_glimpses):
75 |
76 |             # record the location visited at this step and its log-probability
77 |             # under the Gaussian location policy
78 |             locations[:, i] = location
79 |
80 |             location_log_probs[:, i] = log_prob
81 |
82 | glimpse = get_glimpse(x, location.detach(), self.glimpse_size, self.num_scales, device=self.device).to(self.device)
83 | glimpse_feature = self.glimpse_net(glimpse, location)
84 |
85 | hidden1, cell_state1 = self.rnn1(glimpse_feature, (hidden1, cell_state1))
86 | hidden2, cell_state2 = self.rnn2(hidden1, (hidden2, cell_state2))
87 |
88 | loc_diff, std, log_prob = self.emission_net(hidden2)
89 | loc_diff *= (self.normalized_glimpse_size/2 * 2**(self.num_scales - 1))
90 | new_location = location.detach() + loc_diff
91 | new_location = torch.clamp(new_location, min = -1 + self.normalized_glimpse_size/2 , max= 1 - self.normalized_glimpse_size/2)
92 |
93 |
94 | location = new_location
95 |
96 | hidden = torch.cat((hidden1, hidden2), dim=1)
97 | baseline = self.baseline_net(hidden)
98 |
99 |             # baseline estimate for REINFORCE variance reduction
100 | baselines[:, i] = baseline.squeeze()
101 |
102 | weight = weight.unsqueeze(1)
103 | action_logit = self.class_net(hidden1)
104 |
105 | action_logits += weight*action_logit
106 |
107 | weights[:,i] = weight.squeeze()
108 |
109 | weight_sum += weight
110 |
111 |             if (not self.training and i > 1) and self.fast:
112 |                 # stop early once two consecutive glimpses have low confidence
113 |                 if weights[0, i] < 0.5 and weights[0, i - 1] < 0.5:
114 |                     break
115 | std = torch.mean(std, dim=1)
116 | normalized_std = (std-math.exp(-1/2))/(math.exp(1/2)-math.exp(-1/2))
117 | weight = 1 - normalized_std
118 |
119 | action_logits /= weight_sum
120 |
121 | return action_logits, locations, location_log_probs, baselines, weights
122 |
123 |
124 |
125 |
126 |
127 |
128 |
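129 | if __name__ == '__main__':
130 |     # Shape smoke test on random data (CPU); illustrative only, not part
131 |     # of the training or evaluation pipeline.
132 |     device = torch.device('cpu')
133 |     net = GDRAM(device=device, dataset='cifar10').to(device)
134 |     net.eval()
135 |     x = torch.rand(2, 3, 128, 128)
136 |     logits, locations, log_probs, baselines, weights = net(x)
137 |     print(logits.shape, locations.shape, weights.shape)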
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gaussian RAM
2 |
3 | ### ICROS ICCAS 2020 Student Best Paper Finalist
4 |
5 | This repo is an official PyTorch implementation of "Gaussian RAM: Lightweight Image Classification via Stochastic Retina Inspired Glimpse and Reinforcement Learning". [[paper](https://arxiv.org/abs/2011.06190)]
6 |
7 |
8 | ## Abstract
9 | Previous studies on image classification have mainly focused on network performance rather than real-time operation or model compression. We propose the Gaussian Deep Recurrent visual Attention Model (GDRAM), a reinforcement-learning-based lightweight deep neural network for large-scale image classification that outperforms a conventional CNN (Convolutional Neural Network) taking the entire image as input. Inspired by the biological visual recognition process, our model mimics the stochastic fixation locations of the retina with a Gaussian distribution. We evaluate the model on large cluttered MNIST, large CIFAR-10, and large CIFAR-100, whose images are resized to 128 in both width and height.
10 |
11 |
12 |
13 |
14 |
15 | ## Dataset
16 | Cluttered MNIST ([download](https://drive.google.com/file/d/1nMO5XIFmjyPnJjfvBeFpujeuZ3Qk7vhd/view?usp=sharing)), CIFAR10, and CIFAR100 are used for training and evaluation. All images are resized to 128 in both height and width to generate high-scale inputs.
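17 |
18 | To build the cluttered MNIST splits expected by `dataloader.py`, place the downloaded archive at `data/mnist_sequence1_sample_5distortions5x5.npz` and run the generation script, which writes the `train`/`val`/`test` image folders and their `path.txt` index files under `data/`:
19 | ```bash
20 | python mnist_generation.py
21 | ```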
22 | ## Requirements
23 | - Python3
24 | - PyTorch (> 1.0)
25 | - torchvision (> 0.2)
26 | - PIL
27 | - NumPy
28 |
29 | ## Training
30 | ```bash
31 | python train.py --data_path <dir> --dataset <mnist|cifar10|cifar100> --batch_size <int> --lr <float> --epochs <int> --random_seed <int> --log_interval <int> --resume <bool> --checkpoint <path>
32 | ```
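33 |
34 | For example, to train a model on CIFAR-10 from scratch (the values shown are the defaults from `train.py`; adjust them to your setup):
35 | ```bash
36 | python train.py --data_path data --dataset cifar10 --batch_size 128 --lr 1e-3 --epochs 200
37 | ```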
38 |
39 | ## Inference
40 | ```bash
41 | python inference.py --data_path <dir> --dataset <mnist|cifar10|cifar100> --random_seed <int> --fast <bool>
42 | ```
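43 |
44 | For example, to evaluate the bundled cluttered MNIST checkpoint (`checkpoints/mnist_gdram.pth`) with the adaptive early-stopping glimpse policy enabled:
45 | ```bash
46 | python inference.py --data_path data --dataset mnist --fast True
47 | ```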
48 |
49 | ## Acknowledgement
50 | This work was supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-01367, Infant-Mimic Neurocognitive Developmental Machine Learning from Interaction Experience with Real World (BabyMind)).
51 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 |
7 | class GlimpseNetwork(nn.Module):
8 |
9 | def __init__(self, input_channel, glimpse_size, location_size, internal_size, output_size):
10 | super(GlimpseNetwork, self).__init__()
11 |
12 | self.fc_g = nn.Sequential(
13 | nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1),
14 | nn.BatchNorm2d(128),
15 | nn.ReLU(),
16 | nn.MaxPool2d(2),
17 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
18 | nn.BatchNorm2d(256),
19 | nn.ReLU(),
20 | nn.MaxPool2d(2)
21 | )
22 |
23 | self.fc_l = nn.Sequential(
24 | nn.Linear(location_size, internal_size),
25 | nn.ReLU())
26 |
27 |         self.fc_gg = nn.Linear((glimpse_size // 4) * (glimpse_size // 4) * 256, output_size)
28 | self.fc_lg = nn.Linear(internal_size, output_size)
29 |
30 | def forward(self, x, location):
31 | hg = self.fc_g(x).view(len(x), -1)
32 | hl = self.fc_l(location)
33 |
34 | output = F.relu(self.fc_gg(hg) * self.fc_lg(hl))
35 |
36 | return output
37 |
38 |
39 |
40 | class CoreNetwork(nn.Module):
41 |
42 | def __init__(self, input_size, hidden_size):
43 | super(CoreNetwork, self).__init__()
44 |
45 | self.hidden_size = hidden_size
46 | self.rnn_cell = nn.RNNCell(
47 | input_size, hidden_size, nonlinearity='relu')
48 |
49 | def forward(self, g, prev_h):
50 | h = self.rnn_cell(g, prev_h)
51 | return h
52 |
53 | def init_hidden(self, batch_size):
54 | return torch.zeros(batch_size, self.hidden_size)
55 |
56 |
57 | class GRUCoreNetwork(nn.Module):
58 |
59 | def __init__(self, input_size, hidden_size):
60 | super(GRUCoreNetwork, self).__init__()
61 |
62 | self.hidden_size = hidden_size
63 | self.rnn_cell = nn.GRUCell(
64 | input_size, hidden_size)
65 |
66 | def forward(self, g, prev_h):
67 | h = self.rnn_cell(g, prev_h)
68 | return h
69 |
70 | def init_hidden(self, batch_size):
71 | return torch.zeros(batch_size, self.hidden_size)
72 |
73 |
74 | class GlimpseLSTMCoreNetwork(nn.Module):
75 |
76 | def __init__(self, input_size, hidden_size):
77 | super(GlimpseLSTMCoreNetwork, self).__init__()
78 |
79 | self.hidden_size = hidden_size
80 | self.lstm_cell = nn.LSTMCell(
81 | input_size, hidden_size)
82 |
83 | def forward(self, g, prev_h):
84 | h, c = self.lstm_cell(g, prev_h)
85 | return h, c
86 |
87 | def init_hidden(self, batch_size):
88 | return torch.zeros(batch_size, self.hidden_size), torch.zeros(batch_size, self.hidden_size)
89 |
90 |
91 | class LocationLSTMCoreNetwork(nn.Module):
92 |
93 | def __init__(self, input_size, hidden_size, glimpse_size):
94 | super(LocationLSTMCoreNetwork, self).__init__()
95 |
96 | self.hidden_size = hidden_size
97 | self.glimpse_size = glimpse_size
98 |
99 | self.lstm_cell = nn.LSTMCell(
100 | input_size, hidden_size)
101 |
102 | self.context_net1 = nn.Sequential(
103 | nn.Conv2d(3,64,3,padding=1),
104 | nn.ReLU(),
105 | nn.MaxPool2d(2))
106 |
107 | self.context_net2 = nn.Sequential(
108 | nn.Conv2d(64,64,3,padding=1),
109 | nn.ReLU(),
110 | nn.MaxPool2d(2)
111 | )
112 |
113 |         self.fc = nn.Linear((glimpse_size // 4) * (glimpse_size // 4) * 64, hidden_size)
114 |
115 | def forward(self, g, prev_h):
116 | h, c = self.lstm_cell(g, prev_h)
117 | return h, c
118 |
119 | def init_hidden(self, x, batch_size):
120 | x = F.interpolate(x, (self.glimpse_size,self.glimpse_size))
121 |
122 | h = self.fc(self.context_net2(self.context_net1(x)).view(batch_size,-1))
123 | c = torch.zeros((batch_size, self.hidden_size))
124 |
125 | return h, c
126 |
127 |
128 | class EmissionNetwork(nn.Module):
129 |
130 | def __init__(self, input_size, uniform=False, output_size=2, hidden=256):
131 | super(EmissionNetwork, self).__init__()
132 |
133 | self.fc = nn.Sequential(
134 | nn.Linear(input_size, hidden),
135 | nn.BatchNorm1d(hidden),
136 | nn.ReLU())
137 |
138 | self.mu_net = nn.Sequential(
139 | nn.Linear(hidden, output_size),
140 | nn.Tanh()
141 | )
142 |
143 | self.logvar_net = nn.Sequential(
144 | nn.Linear(hidden, output_size),
145 | nn.Tanh()
146 | )
147 |
148 |         self.uniform = uniform
149 |
150 | def forward(self, x):
151 |
152 | z = self.fc(x.detach())
153 | mu = self.mu_net(z)
154 |
155 | logvar = self.logvar_net(z)
156 | std = torch.exp(logvar*0.5)
157 |
158 | if self.training:
159 |
160 |             # sample a stochastic location and score it under the Gaussian policy
161 |             distribution = torch.distributions.Normal(mu, std)
162 | output = torch.clamp(distribution.sample(), -1.0, 1.0)
163 | log_p = distribution.log_prob(output)
164 | log_p = torch.sum(log_p, dim=1)
165 |
166 | else:
167 |
168 |             # deterministic location at evaluation time
169 |             output = mu
170 |             log_p = torch.ones(output.size(0), device=output.device)
171 |
172 | return output, std, log_p
173 |
174 |
175 | class ActionNetwork(nn.Module):
176 |
177 | def __init__(self, input_size, output_size, hidden=256):
178 | super(ActionNetwork, self).__init__()
179 |
180 | self.fc = nn.Sequential(
181 | nn.Linear(input_size, hidden),
182 | nn.ReLU(),
183 | nn.Linear(hidden, output_size)
184 | )
185 |
186 | def forward(self, x):
187 | logit = self.fc(x)
188 |
189 | return logit
190 |
191 |
192 | class BaselineNetwork(nn.Module):
193 |
194 | def __init__(self, input_size, output_size, hidden_size=256):
195 | super(BaselineNetwork, self).__init__()
196 |
197 | self.fc = nn.Sequential(
198 | nn.Linear(input_size, hidden_size),
199 | nn.Linear(hidden_size, output_size)
200 | )
201 |
202 | def forward(self, x):
203 | output = torch.sigmoid(self.fc(x.detach()))
204 | return output
205 |
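206 |
207 | if __name__ == '__main__':
208 |     # Quick sanity check of the emission head (illustrative only; 128 is the
209 |     # hidden size used by model.GDRAM).
210 |     net = EmissionNetwork(128)
211 |     net.eval()  # eval branch: deterministic output = mu, log_p = ones
212 |     h = torch.rand(4, 128)
213 |     loc, std, log_p = net(h)
214 |     print(loc.shape, std.shape, log_p.shape)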
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | from torch import nn, optim
5 | from torch.nn import functional as F
6 | import torch.utils.data
7 | from torchvision import datasets, transforms
8 | from torch.utils.data.sampler import SubsetRandomSampler
9 | import torchvision
10 | from model import GDRAM
11 | from utils import draw_locations
12 | from dataloader import MnistClutteredDataset
13 |
14 | def str2bool(v):
15 | if isinstance(v, bool):
16 | return v
17 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
18 | return True
19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20 | return False
21 | else:
22 | raise argparse.ArgumentTypeError('Boolean value expected.')
23 |
24 | parser = argparse.ArgumentParser(description='Gaussian-RAM')
25 | parser.add_argument('--data_path', type=str, default='data')
26 | parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
27 | parser.add_argument('--batch_size', type=int, default = 128)
28 | parser.add_argument('--dataset', type=str, default='mnist')
29 | parser.add_argument('--lr', type=float, default=1e-3)
30 | parser.add_argument('--random_seed', type=int, default=1)
31 | parser.add_argument('--epochs', type=int, default=200)
32 | parser.add_argument('--log_interval', type=int, default=500)
33 | parser.add_argument('--resume', type=str2bool, default='False')
34 | parser.add_argument('--checkpoint', type=str, default=None)
35 |
36 | args = parser.parse_args()
37 |
38 | assert (args.dataset=='mnist' or args.dataset=='cifar10') or args.dataset=='cifar100', 'please use dataset in mnist, cifar10 or cifar100'
39 | torch.manual_seed(args.random_seed)
40 |
41 | kwargs = {'num_workers': 32, 'pin_memory': True} if not args.device=='cpu' else {}
42 |
43 | device = torch.device(args.device)
44 |
45 |
46 | img_size = 128
47 |
48 |
49 |
50 | ##################################################
51 |
52 | if args.dataset == 'cifar10':
53 |
54 | transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor()])
55 | # training set : validation set : test set = 50000 : 10000 : 10000
56 |
57 | train_set = datasets.CIFAR10(args.data_path,train=True, download=True, transform=transform)
58 | indices = list(range(len(train_set)))
59 | valid_size = 10000
60 | train_size = len(train_set) - valid_size
61 |
62 | train_idx, valid_idx = indices[valid_size:], indices[:valid_size]
63 |
64 | train_sampler = SubsetRandomSampler(train_idx)
65 | valid_sampler = SubsetRandomSampler(valid_idx)
66 |
67 | train_loader = torch.utils.data.DataLoader(
68 | train_set, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
69 | valid_loader = torch.utils.data.DataLoader(
70 | train_set, batch_size=args.batch_size, sampler=valid_sampler, **kwargs)
71 |
72 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_path, train=False,\
73 | transform=transform),batch_size=args.batch_size, shuffle=False, **kwargs)
74 | elif args.dataset == 'cifar100':
75 |
76 | transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor()])
77 | # training set : validation set : test set = 50000 : 10000 : 10000
78 |
79 | train_set = datasets.CIFAR100(args.data_path,train=True, download=True, transform=transform)
80 | indices = list(range(len(train_set)))
81 |
82 | valid_size = 10000
83 | train_size = len(train_set) - valid_size
84 |
85 | train_idx, valid_idx = indices[valid_size:], indices[:valid_size]
86 |
87 | train_sampler = SubsetRandomSampler(train_idx)
88 | valid_sampler = SubsetRandomSampler(valid_idx)
89 |
90 | train_loader = torch.utils.data.DataLoader(
91 | train_set, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
92 | valid_loader = torch.utils.data.DataLoader(
93 | train_set, batch_size=args.batch_size, sampler=valid_sampler, **kwargs)
94 |
95 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR100(args.data_path, train=False,\
96 | transform=transform),batch_size=args.batch_size, shuffle=False, **kwargs)
97 |
98 | elif args.dataset == 'mnist':
99 |
100 | transform = transforms.Compose([transforms.Resize(img_size),transforms.Grayscale(3), transforms.ToTensor()])
101 |
102 | train_set = MnistClutteredDataset(args.data_path, type='train', transform=transform)
103 | valid_set = MnistClutteredDataset(args.data_path, type='val', transform= transform)
104 | test_set = MnistClutteredDataset(args.data_path, type='test',transform=transform)
105 |
106 | train_size = len(train_set)
107 | valid_size = len(valid_set)
108 |
109 |
110 | train_loader = torch.utils.data.DataLoader(
111 | train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
112 | valid_loader = torch.utils.data.DataLoader(
113 | valid_set, batch_size=args.batch_size, shuffle=True, **kwargs)
114 | test_loader = torch.utils.data.DataLoader(
115 | test_set, batch_size=args.batch_size, shuffle=False, **kwargs
116 | )
117 |
118 |
119 | model = GDRAM(device=device, dataset=args.dataset, Fast = False).to(device)
120 |
121 | if args.resume:
122 | model.load_state_dict(torch.load(args.checkpoint))
123 |
124 | pytorch_total_params = sum(p.numel() for p in model.parameters())
125 |
126 | print('Model parameters: %d'%pytorch_total_params)
127 |
128 |
129 |
130 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
131 |
132 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, verbose=True, patience=5)
133 |
134 | prediction_loss_fn = nn.CrossEntropyLoss()
135 |
136 | def loss_function(labels, action_logits, location_log_probs, baselines):
137 |
138 |     pred_loss = prediction_loss_fn(action_logits, labels.squeeze(1))
139 | predictions = torch.argmax(action_logits, dim=1, keepdim=True)
140 | num_repeats = baselines.size(-1)
141 | rewards = (labels == predictions.detach()).float().repeat(1, num_repeats)
142 |
143 |
144 | baseline_loss = F.mse_loss(rewards, baselines)
145 | b_rewards = rewards - baselines.detach()
146 | reinforce_loss = torch.mean(
147 | torch.sum(-location_log_probs * b_rewards, dim=1))
148 |
149 | return pred_loss + baseline_loss + reinforce_loss
150 |
151 |
152 | def train(epoch):
153 | model.train()
154 | train_loss = 0
155 |
156 | for batch_idx, (data, labels) in enumerate(train_loader):
157 | data = data.to(device)
158 |
159 | optimizer.zero_grad()
160 |
161 | action_logits, loc, location_log_probs, baselines, _ = model(data)
162 |
163 | labels = labels.unsqueeze(dim=1).to(device)
164 |
165 | loss = loss_function(labels, action_logits, location_log_probs, baselines)
166 |
167 | loss.backward()
168 |
169 | train_loss += loss.item()
170 | optimizer.step()
171 |
172 |         if batch_idx % args.log_interval == 0:
173 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
174 |                 epoch, batch_idx * len(data), train_size,
175 |                 100. * batch_idx / len(train_loader),
176 |                 loss.item()))
177 |
178 |     print('====> Epoch: {} Average loss: {:.4f}'.format(
179 |         epoch, train_loss / len(train_loader)))
180 |
181 |
182 |
183 | def test(epoch, data_source, size):
184 | model.eval()
185 | total_correct = 0.0
186 | with torch.no_grad():
187 | for i, (data, labels) in enumerate(data_source):
188 | data = data.to(device)
189 | action_logits, _, _, _, _= model(data)
190 | predictions = torch.argmax(action_logits, dim=1)
191 | labels = labels.to(device)
192 | total_correct += torch.sum((labels == predictions)).item()
193 | accuracy = total_correct / size
194 |
195 | image = data[0:1]
196 | _, locations, _, _, weights = model(image)
197 | draw_locations(image.cpu().numpy()[0], locations.detach().cpu().numpy()[0], weights=weights, epoch=epoch)
198 | return accuracy
199 |
200 |
201 | best_valid_accuracy, test_accuracy = 0, 0
202 |
203 | for epoch in range(1, args.epochs + 1):
204 | accuracy = test(epoch, valid_loader, valid_size)
205 | scheduler.step(accuracy)
206 | print('====> Validation set accuracy: {:.2%}'.format(accuracy))
207 | if accuracy > best_valid_accuracy:
208 | best_valid_accuracy = accuracy
209 | test_accuracy = test(epoch, test_loader, len(test_loader.dataset))
210 |
211 |         torch.save(model.state_dict(), 'checkpoints/' + args.dataset + '_gdram.pth')
212 |
213 | print('====> Test set accuracy: {:.2%}'.format(test_accuracy))
214 | train(epoch)
215 |
216 | print('====> Test set accuracy: {:.2%}'.format(test_accuracy))
217 |
--------------------------------------------------------------------------------