├── fig.png
├── checkpoints
│   ├── mnist_gdram.pth
│   ├── cifar10_gdram.pth
│   └── cifar100_gdram.pth
├── dataloader.py
├── LICENSE
├── mnist_generation.py
├── utils.py
├── inference.py
├── model.py
├── README.md
├── modules.py
└── train.py

/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/fig.png
--------------------------------------------------------------------------------
/checkpoints/mnist_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/mnist_gdram.pth
--------------------------------------------------------------------------------
/checkpoints/cifar10_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/cifar10_gdram.pth
--------------------------------------------------------------------------------
/checkpoints/cifar100_gdram.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dsshim0125/gaussian-ram/HEAD/checkpoints/cifar100_gdram.pth
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class MnistClutteredDataset(Dataset):
    """Cluttered-MNIST dataset indexed by a path.txt file that stores one
    "<image_path> <label>" pair per line (see mnist_generation.py)."""

    def __init__(self, data_path, type, transform=None):
        # `type` is the split name ('train', 'val' or 'test'); the argument
        # name shadows the builtin but is kept for the existing callers.
        self.root_dir = data_path + '/' + type + '/path.txt'
        self.transform = transform
        self.path = pd.read_csv(self.root_dir, sep=' ', header=None)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.path.iloc[idx, 0]
        image = Image.open(img_path)
        label = int(self.path.iloc[idx, 1])

        if self.transform:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return len(self.path)
--------------------------------------------------------------------------------
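For reference, a minimal usage sketch of the dataset class above, assuming the cluttered-MNIST files have already been generated by mnist_generation.py; the transform mirrors what train.py and inference.py actually apply, while the batch size is illustrative:

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import MnistClutteredDataset

# Same preprocessing as train.py: upscale to 128x128 and replicate
# the single grayscale channel to 3 channels.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.Grayscale(3),
    transforms.ToTensor(),
])

dataset = MnistClutteredDataset('data', type='train', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 128, 128])
```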
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Dongseok Shim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/mnist_generation.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import matplotlib.pyplot as plt

root_path = 'data'

data = np.load(root_path + '/mnist_sequence1_sample_5distortions5x5.npz')

# The .npz archive stores flattened 40x40 cluttered-MNIST images and the
# corresponding labels for each split.
splits = {
    'train': (data['X_train'], data['y_train']),
    'val': (data['X_valid'], data['y_valid']),
    'test': (data['X_test'], data['y_test']),
}

for split, (X, y) in splits.items():

    split_dir = root_path + '/' + split
    if not os.path.exists(split_dir):
        os.mkdir(split_dir)

    # path.txt holds one "<image_path> <label>" pair per line; it is the
    # index file consumed by MnistClutteredDataset in dataloader.py.
    with open(split_dir + '/path.txt', 'w') as f:
        for i in range(len(X)):
            img_path = split_dir + '/%05d.jpg' % i

            img = X[i].reshape(40, 40)
            plt.imsave(img_path, img)
            label = y[i, 0]

            f.write(img_path + ' %d\n' % label)
--------------------------------------------------------------------------------
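The resulting index file has one image path and one label per line; a sample of what it looks like (the labels shown are illustrative):

```
data/train/00000.jpg 4
data/train/00001.jpg 9
```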
/utils.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
from torch.nn import functional as F
import os


def get_glimpse(x, l, output_size, k, device):
    """Transform an image into its retina-like representation: k crops
    centred on location l, each with twice the field of view of the
    previous one, all resampled to output_size x output_size and
    concatenated along the channel axis.

    Assumes width = height.
    """
    batch_size, input_size = x.size(0), x.size(2) - 1
    assert output_size * 2**(k - 1) <= input_size, \
        "output_size * 2**(k-1) should be smaller than or equal to input_size"

    # Construct theta for the affine transformation; the translation
    # column carries the (normalized) glimpse centre.
    theta = torch.zeros(batch_size, 2, 3)
    theta[:, :, 2] = l

    scale = output_size / input_size
    osize = torch.Size([batch_size, 1, output_size, output_size])

    for i in range(k):
        theta[:, 0, 0] = scale
        theta[:, 1, 1] = scale
        grid = F.affine_grid(theta, osize, align_corners=False).to(device)
        glimpse = F.grid_sample(x, grid, align_corners=False)

        if i == 0:
            output = glimpse
        else:
            output = torch.cat((output, glimpse), dim=1)
        scale *= 2

    return output.detach()


def draw_locations(image, locations, weights=None, size=8, epoch=0, save_path='results'):
    image = np.transpose(image, (1, 2, 0))
    weights = weights.detach().cpu().numpy()

    # After a warm-up period, truncate the glimpse trajectory at the first
    # pair of consecutive low-confidence glimpses (mirrors the fast-exit
    # rule in model.py).
    if epoch > 50:
        for idx in range(len(weights[0]) - 1):
            if (weights[0][idx] < 0.5) and (weights[0][idx + 1] < 0.5):
                break

        locations = locations[:idx + 1]

    locations = list(locations)
    fig, ax = plt.subplots(1, len(locations))
    for i, location in enumerate(locations):
        if len(locations) == 1:
            subplot = ax
        else:
            subplot = ax[i]

        subplot.axis('off')
        subplot.imshow(image, cmap='gray')
        loc = ((location[0] + 1) * image.shape[1] / 2 - size / 2,
               (location[1] + 1) * image.shape[0] / 2 - size / 2)

        rect = patches.Rectangle(
            loc, size, size, linewidth=1, edgecolor='r', facecolor='none')
        subplot.add_patch(rect)
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)

    if not os.path.exists(save_path):
        os.mkdir(save_path)
    plt.savefig(save_path + '/glimpse_%d.png' % epoch, bbox_inches='tight')
    plt.close()


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img = torch.ones((3, 3, 28, 28), device=device)
    loc = torch.ones((3, 2), device=device)

    out = get_glimpse(img, loc, 8, 2, device)
    print(out.shape)  # (3, 6, 8, 8): 2 scales x 3 channels
--------------------------------------------------------------------------------
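The glimpse tensor returned by get_glimpse stacks the k progressively coarser crops along the channel axis, so downstream networks see C * k channels. A quick shape-only sanity check with the same configuration GDRAM uses (glimpse_size=12, num_scales=4); the random input is illustrative:

```python
import torch
from utils import get_glimpse

device = torch.device('cpu')
x = torch.rand(4, 3, 128, 128)   # batch of RGB images
l = torch.zeros(4, 2)            # glimpse centred on each image

g = get_glimpse(x, l, output_size=12, k=4, device=device)
print(g.shape)  # torch.Size([4, 12, 12, 12]) -> 4 scales x 3 channels
```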
/inference.py:
--------------------------------------------------------------------------------
import time
import argparse

import torch.utils.data
from torchvision import datasets, transforms

from model import GDRAM
from dataloader import MnistClutteredDataset


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser(description='Inference')

parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--fast', type=str2bool, default='False')
parser.add_argument('--random_seed', type=int, default=1)
args = parser.parse_args()

batch_size = 1

kwargs = {'num_workers': 64, 'pin_memory': True} if not args.device == 'cpu' else {}

device = torch.device(args.device)

model_path = 'checkpoints/' + args.dataset + '_gdram.pth'

img_size = 128

torch.manual_seed(args.random_seed)

##################################################

if args.dataset == 'cifar10':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.data_path, train=False, transform=transform),
        batch_size=batch_size, shuffle=False, **kwargs)

elif args.dataset == 'cifar100':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100(args.data_path, train=False, transform=transform),
        batch_size=batch_size, shuffle=False, **kwargs)

elif args.dataset == 'mnist':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.Grayscale(3), transforms.ToTensor()])
    test_set = MnistClutteredDataset(args.data_path, type='test', transform=transform)

    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, **kwargs)


model = GDRAM(device=device, dataset=args.dataset, Fast=args.fast).to(device)

pytorch_total_params = sum(p.numel() for p in model.parameters())

print('Model parameters: %d' % pytorch_total_params)

# map_location keeps CUDA-trained checkpoints loadable on CPU-only machines.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print('Model Loaded!')


def accuracy2(output, target, topk=(1,)):
    """Top-k accuracy (in percent) for each k in `topk`."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: slices of the transposed tensor are non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


accuracy1 = 0
accuracy5 = 0

start_time = time.time()

with torch.no_grad():
    for data, labels in test_loader:
        data = data.to(device)
        action_logits, location, _, _, weights = model(data)
        labels = labels.to(device)

        acc1, acc5 = accuracy2(action_logits, labels, topk=(1, 5))
        accuracy1 += acc1.cpu().numpy()
        accuracy5 += acc5.cpu().numpy()

acc1 = accuracy1 / len(test_loader)
acc5 = accuracy5 / len(test_loader)

# Throughput in images per second (the elapsed time divided by the dataset
# size would be seconds per image, i.e. the inverse).
fps = len(test_loader.dataset) / (time.time() - start_time)
print("Top1:%.2f Top5:%.2f fps:%.2f" % (acc1, acc5, fps))
--------------------------------------------------------------------------------
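A typical invocation of the script above, assuming the pretrained weights under checkpoints/ and a data/ directory prepared as described in the README; the flags are this script's actual arguments, the values are examples (--fast True enables the early-exit mode):

```bash
python inference.py --data_path data --dataset mnist --device cuda --fast True
```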
/model.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn

from modules import *
from utils import get_glimpse


class GDRAM(nn.Module):
    def __init__(self, device=None, dataset=None, Fast=False):
        super(GDRAM, self).__init__()

        self.glimpse_size = 12
        self.num_scales = 4

        self.img_size = 128

        self.class_num = 10
        if dataset == 'cifar100':
            self.class_num = 100

        self.normalized_glimpse_size = self.glimpse_size / (self.img_size / 2)

        self.glimpse_net = GlimpseNetwork(3 * self.num_scales, self.glimpse_size, 2, 128, 128)

        self.rnn1 = GlimpseLSTMCoreNetwork(128, 128)
        self.rnn2 = LocationLSTMCoreNetwork(128, 128, self.glimpse_size)

        self.class_net = ActionNetwork(128, self.class_num)
        self.emission_net = EmissionNetwork(128)

        self.baseline_net = BaselineNetwork(128 * 2, 1)

        self.num_glimpses = 8
        self.location_size = 2

        self.device = device

        self.fast = Fast

    def forward(self, x):

        batch_size = x.size(0)

        hidden1, cell_state1 = self.rnn1.init_hidden(batch_size)
        hidden1 = hidden1.to(self.device)
        cell_state1 = cell_state1.to(self.device)

        hidden2, cell_state2 = self.rnn2.init_hidden(x, batch_size)
        hidden2 = hidden2.to(self.device)
        cell_state2 = cell_state2.to(self.device)

        # Initial location proposed from the context features alone.
        location, std, log_prob = self.emission_net(hidden2)
        location = torch.clamp(location, min=-1 + self.normalized_glimpse_size / 2,
                               max=1 - self.normalized_glimpse_size / 2)

        location_log_probs = torch.empty(batch_size, self.num_glimpses).to(self.device)
        locations = torch.empty(batch_size, self.num_glimpses, self.location_size).to(self.device)
        baselines = torch.empty(batch_size, self.num_glimpses).to(self.device)
        weights = torch.empty(batch_size, self.num_glimpses).to(self.device)

        weight = torch.ones(batch_size).to(self.device)

        action_logits = 0
        weight_sum = 0

        for i in range(self.num_glimpses):

            locations[:, i] = location
            location_log_probs[:, i] = log_prob

            glimpse = get_glimpse(x, location.detach(), self.glimpse_size,
                                  self.num_scales, device=self.device).to(self.device)
            glimpse_feature = self.glimpse_net(glimpse, location)

            hidden1, cell_state1 = self.rnn1(glimpse_feature, (hidden1, cell_state1))
            hidden2, cell_state2 = self.rnn2(hidden1, (hidden2, cell_state2))

            # Emit a movement relative to the current location.
            loc_diff, std, log_prob = self.emission_net(hidden2)
            loc_diff *= (self.normalized_glimpse_size / 2 * 2**(self.num_scales - 1))
            new_location = location.detach() + loc_diff
            new_location = torch.clamp(new_location, min=-1 + self.normalized_glimpse_size / 2,
                                       max=1 - self.normalized_glimpse_size / 2)

            location = new_location

            hidden = torch.cat((hidden1, hidden2), dim=1)
            baseline = self.baseline_net(hidden)
            baselines[:, i] = baseline.squeeze()

            # Confidence-weighted accumulation of the per-glimpse logits.
            weight = weight.unsqueeze(1)
            action_logit = self.class_net(hidden1)
            action_logits += weight * action_logit

            weights[:, i] = weight.squeeze()
            weight_sum += weight

            # Fast inference: stop once two consecutive glimpses are
            # low-confidence. Index the columns written so far, not the
            # uninitialised tail of the preallocated buffer.
            if (not self.training and i > 1) and self.fast:
                if weights[0, i] < 0.5 and weights[0, i - 1] < 0.5:
                    break

            # Map std in [exp(-1/2), exp(1/2)] (bounded by the Tanh on
            # logvar) to a confidence weight in [0, 1].
            std = torch.mean(std, dim=1)
            normalized_std = (std - math.exp(-1 / 2)) / (math.exp(1 / 2) - math.exp(-1 / 2))
            weight = 1 - normalized_std

        action_logits /= weight_sum

        return action_logits, locations, location_log_probs, baselines, weights
--------------------------------------------------------------------------------
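The glimpse weight in GDRAM.forward comes straight from the emission network's predicted spread: Tanh bounds logvar to [-1, 1], so std = exp(logvar / 2) lies in [e^(-1/2), e^(1/2)], and the min-max normalisation maps it to a confidence in [0, 1]. A small standalone check of that mapping (the std values are illustrative):

```python
import math
import torch

# Tanh bounds logvar to [-1, 1], so std = exp(logvar / 2) is bounded too.
std_min, std_max = math.exp(-1 / 2), math.exp(1 / 2)   # ~0.607, ~1.649

std = torch.tensor([[0.61, 0.65], [1.60, 1.64]])        # (batch, 2) spreads
std = torch.mean(std, dim=1)
normalized_std = (std - std_min) / (std_max - std_min)
weight = 1 - normalized_std

print(weight)  # confident glimpse -> weight near 1, diffuse -> near 0
```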
/README.md:
--------------------------------------------------------------------------------
# Gaussian RAM

### ICROS ICCAS 2020 Student Best Paper Finalist

This repo is an official PyTorch implementation of "Gaussian RAM: Lightweight Image Classification via Stochastic Retina Inspired Glimpse and Reinforcement Learning". [[paper](https://arxiv.org/abs/2011.06190)]


## Abstract
Previous studies on image classification have mainly focused on the performance of the networks, not on real-time operation or model compression. We propose the Gaussian Deep Recurrent visual Attention Model (GDRAM), a reinforcement-learning-based lightweight deep neural network for large-scale image classification that outperforms a conventional CNN (Convolutional Neural Network) which uses the entire image as input. Highly inspired by the biological visual recognition process, our model mimics the stochastic location of the retina with a Gaussian distribution. We evaluate the model on the Large Cluttered MNIST, Large CIFAR-10 and Large CIFAR-100 datasets, which are resized to 128 in both width and height.

<img src="fig.png">
## Dataset
Cluttered MNIST ([download](https://drive.google.com/file/d/1nMO5XIFmjyPnJjfvBeFpujeuZ3Qk7vhd/view?usp=sharing)), CIFAR10 and CIFAR100 are used for training and evaluation. All images are resized to 128 in both height and width to generate the large-scale versions.

## Requirements
- Python3
- PyTorch (> 1.0)
- torchvision (> 0.2)
- PIL
- NumPy

## Training
```bash
python train.py --data_path --dataset --batch_size --lr --epochs --random_seed --log_interval --resume --checkpoint
```

## Inference
```bash
python inference.py --data_path --dataset --random_seed --fast
```
Filled-in example commands are given after the Acknowledgement below.

## Acknowledgement
This work was supported by Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-01367, Infant-Mimic Neurocognitive Developmental Machine Learning from Interaction Experience with Real World (BabyMind))
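For concreteness, here are the two commands above with the flags filled in. The values are simply the scripts' argparse defaults (see train.py and inference.py), except --fast True, which enables early exit; the data path is an example:

```bash
python train.py --data_path data --dataset mnist --batch_size 128 --lr 1e-3 --epochs 200 --random_seed 1 --log_interval 500
python inference.py --data_path data --dataset mnist --random_seed 1 --fast True
```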
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F


class GlimpseNetwork(nn.Module):

    def __init__(self, input_channel, glimpse_size, location_size, internal_size, output_size):
        super(GlimpseNetwork, self).__init__()

        self.fc_g = nn.Sequential(
            nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc_l = nn.Sequential(
            nn.Linear(location_size, internal_size),
            nn.ReLU())

        # Two 2x poolings shrink the glimpse by a factor of 4 per side.
        self.fc_gg = nn.Linear((glimpse_size // 4) ** 2 * 256, output_size)
        self.fc_lg = nn.Linear(internal_size, output_size)

    def forward(self, x, location):
        hg = self.fc_g(x).view(len(x), -1)
        hl = self.fc_l(location)

        output = F.relu(self.fc_gg(hg) * self.fc_lg(hl))

        return output


class CoreNetwork(nn.Module):

    def __init__(self, input_size, hidden_size):
        super(CoreNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.rnn_cell = nn.RNNCell(
            input_size, hidden_size, nonlinearity='relu')

    def forward(self, g, prev_h):
        h = self.rnn_cell(g, prev_h)
        return h

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)


class GRUCoreNetwork(nn.Module):

    def __init__(self, input_size, hidden_size):
        super(GRUCoreNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.rnn_cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, g, prev_h):
        h = self.rnn_cell(g, prev_h)
        return h

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)


class GlimpseLSTMCoreNetwork(nn.Module):

    def __init__(self, input_size, hidden_size):
        super(GlimpseLSTMCoreNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

    def forward(self, g, prev_h):
        h, c = self.lstm_cell(g, prev_h)
        return h, c

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size), torch.zeros(batch_size, self.hidden_size)


class LocationLSTMCoreNetwork(nn.Module):

    def __init__(self, input_size, hidden_size, glimpse_size):
        super(LocationLSTMCoreNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.glimpse_size = glimpse_size

        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        self.context_net1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2))

        self.context_net2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc = nn.Linear((glimpse_size // 4) ** 2 * 64, hidden_size)

    def forward(self, g, prev_h):
        h, c = self.lstm_cell(g, prev_h)
        return h, c

    def init_hidden(self, x, batch_size):
        # The initial hidden state is a context vector computed from a
        # downsampled view of the whole image.
        x = F.interpolate(x, (self.glimpse_size, self.glimpse_size))

        h = self.fc(self.context_net2(self.context_net1(x)).view(batch_size, -1))
        c = torch.zeros((batch_size, self.hidden_size))

        return h, c


class EmissionNetwork(nn.Module):

    def __init__(self, input_size, uniform=False, output_size=2, hidden=256):
        super(EmissionNetwork, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU())

        self.mu_net = nn.Sequential(
            nn.Linear(hidden, output_size),
            nn.Tanh()
        )

        self.logvar_net = nn.Sequential(
            nn.Linear(hidden, output_size),
            nn.Tanh()
        )

        self.uniform = uniform

    def forward(self, x):
        # Detach so that the emission head is trained only by REINFORCE,
        # not by gradients flowing back from the classifier.
        z = self.fc(x.detach())
        mu = self.mu_net(z)

        logvar = self.logvar_net(z)
        std = torch.exp(logvar * 0.5)

        if self.training:
            distribution = torch.distributions.Normal(mu, std)
            output = torch.clamp(distribution.sample(), -1.0, 1.0)
            log_p = distribution.log_prob(output)
            log_p = torch.sum(log_p, dim=1)
        else:
            # Deterministic at eval time: emit the mean location.
            output = mu
            log_p = torch.ones(output.size(0), device=x.device)

        return output, std, log_p


class ActionNetwork(nn.Module):

    def __init__(self, input_size, output_size, hidden=256):
        super(ActionNetwork, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_size)
        )

    def forward(self, x):
        logit = self.fc(x)
        return logit


class BaselineNetwork(nn.Module):

    def __init__(self, input_size, output_size, hidden_size=256):
        super(BaselineNetwork, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        # Detached input: the baseline is a reward predictor and should not
        # shape the core network's representation.
        output = torch.sigmoid(self.fc(x.detach()))
        return output
--------------------------------------------------------------------------------
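EmissionNetwork above is the stochastic half of the policy: in training mode it samples the next glimpse location from N(mu, std) and returns the log-probability needed by the REINFORCE term in train.py, while in eval mode it deterministically emits the Tanh-bounded mean. A minimal sketch of both modes, using a random fake hidden state and no checkpoint:

```python
import torch
from modules import EmissionNetwork

net = EmissionNetwork(input_size=128)
h = torch.randn(4, 128)           # fake hidden state from the location LSTM

net.train()
loc, std, log_p = net(h)          # sampled location, clamped to [-1, 1]
print(loc.shape, std.shape, log_p.shape)  # (4, 2) (4, 2) (4,)

net.eval()
loc, std, log_p = net(h)          # deterministic: loc is the bounded mean
```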
/train.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

from model import GDRAM
from utils import draw_locations
from dataloader import MnistClutteredDataset


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser(description='Gaussian-RAM')
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--random_seed', type=int, default=1)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--log_interval', type=int, default=500)
parser.add_argument('--resume', type=str2bool, default='False')
parser.add_argument('--checkpoint', type=str, default=None)

args = parser.parse_args()

assert args.dataset in ('mnist', 'cifar10', 'cifar100'), \
    'dataset must be one of mnist, cifar10 or cifar100'
torch.manual_seed(args.random_seed)

kwargs = {'num_workers': 32, 'pin_memory': True} if not args.device == 'cpu' else {}

device = torch.device(args.device)

img_size = 128

##################################################

if args.dataset == 'cifar10':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
    # training set : validation set : test set = 40000 : 10000 : 10000
    # (the validation set is carved out of the 50000 CIFAR training images)

    train_set = datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform)
    indices = list(range(len(train_set)))
    valid_size = 10000
    train_size = len(train_set) - valid_size

    train_idx, valid_idx = indices[valid_size:], indices[:valid_size]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, sampler=valid_sampler, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.data_path, train=False, transform=transform),
        batch_size=args.batch_size, shuffle=False, **kwargs)

elif args.dataset == 'cifar100':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
    # training set : validation set : test set = 40000 : 10000 : 10000

    train_set = datasets.CIFAR100(args.data_path, train=True, download=True, transform=transform)
    indices = list(range(len(train_set)))

    valid_size = 10000
    train_size = len(train_set) - valid_size

    train_idx, valid_idx = indices[valid_size:], indices[:valid_size]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, sampler=valid_sampler, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100(args.data_path, train=False, transform=transform),
        batch_size=args.batch_size, shuffle=False, **kwargs)

elif args.dataset == 'mnist':

    transform = transforms.Compose([transforms.Resize(img_size), transforms.Grayscale(3), transforms.ToTensor()])

    train_set = MnistClutteredDataset(args.data_path, type='train', transform=transform)
    valid_set = MnistClutteredDataset(args.data_path, type='val', transform=transform)
    test_set = MnistClutteredDataset(args.data_path, type='test', transform=transform)

    train_size = len(train_set)
    valid_size = len(valid_set)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, **kwargs)


model = GDRAM(device=device, dataset=args.dataset, Fast=False).to(device)

if args.resume:
    model.load_state_dict(torch.load(args.checkpoint))

pytorch_total_params = sum(p.numel() for p in model.parameters())

print('Model parameters: %d' % pytorch_total_params)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, verbose=True, patience=5)

prediction_loss_fn = nn.CrossEntropyLoss()


def loss_function(labels, action_logits, location_log_probs, baselines):
    """Hybrid objective: cross-entropy for the classifier, MSE for the
    baseline (reward predictor) and REINFORCE for the location policy."""

    pred_loss = prediction_loss_fn(action_logits, labels.squeeze())
    predictions = torch.argmax(action_logits, dim=1, keepdim=True)
    num_repeats = baselines.size(-1)
    rewards = (labels == predictions.detach()).float().repeat(1, num_repeats)

    baseline_loss = F.mse_loss(rewards, baselines)
    b_rewards = rewards - baselines.detach()
    reinforce_loss = torch.mean(
        torch.sum(-location_log_probs * b_rewards, dim=1))

    return pred_loss + baseline_loss + reinforce_loss


def train(epoch):
    model.train()
    train_loss = 0

    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()

        action_logits, loc, location_log_probs, baselines, _ = model(data)

        labels = labels.unsqueeze(dim=1).to(device)

        loss = loss_function(labels, action_logits, location_log_probs, baselines)

        loss.backward()

        train_loss += loss.item()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), train_size,
                100. * batch_idx / len(train_loader),
                loss.item()))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader)))


def test(epoch, data_source, size):
    model.eval()
    total_correct = 0.0
    with torch.no_grad():
        for i, (data, labels) in enumerate(data_source):
            data = data.to(device)
            action_logits, _, _, _, _ = model(data)
            predictions = torch.argmax(action_logits, dim=1)
            labels = labels.to(device)
            total_correct += torch.sum((labels == predictions)).item()
        accuracy = total_correct / size

    # Visualize the glimpse trajectory on the first image of the last batch.
    image = data[0:1]
    _, locations, _, _, weights = model(image)
    draw_locations(image.cpu().numpy()[0], locations.detach().cpu().numpy()[0], weights=weights, epoch=epoch)
    return accuracy


best_valid_accuracy, test_accuracy = 0, 0

for epoch in range(1, args.epochs + 1):
    accuracy = test(epoch, valid_loader, valid_size)
    scheduler.step(accuracy)
    print('====> Validation set accuracy: {:.2%}'.format(accuracy))
    if accuracy > best_valid_accuracy:
        best_valid_accuracy = accuracy
        test_accuracy = test(epoch, test_loader, len(test_loader.dataset))

        # Persist the best model under the name inference.py expects.
        os.makedirs('checkpoints', exist_ok=True)
        torch.save(model.state_dict(), 'checkpoints/' + args.dataset + '_gdram.pth')

        print('====> Test set accuracy: {:.2%}'.format(test_accuracy))
    train(epoch)

print('====> Test set accuracy: {:.2%}'.format(test_accuracy))
--------------------------------------------------------------------------------
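To make the hybrid objective in train.py's loss_function concrete, here is a toy-sized trace of the REINFORCE-with-baseline term it computes; the tensors are fabricated stand-ins for the model outputs, and the values are illustrative:

```python
import torch

batch, num_glimpses = 2, 4

# 0/1 reward per sample (was the final classification correct?),
# broadcast across all glimpse steps, exactly as in loss_function.
rewards = torch.tensor([[1.], [0.]]).repeat(1, num_glimpses)   # (2, 4)

baselines = torch.full((batch, num_glimpses), 0.5)  # learned value estimates
log_probs = torch.randn(batch, num_glimpses)        # log pi(l_t | h_t)

advantage = rewards - baselines.detach()            # centre the reward
reinforce_loss = torch.mean(torch.sum(-log_probs * advantage, dim=1))
baseline_loss = torch.nn.functional.mse_loss(rewards, baselines)

print(reinforce_loss.item(), baseline_loss.item())
```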