├── .gitignore ├── README.md ├── attention.py ├── cnn-with-attention.py ├── figures ├── attn1.png ├── attn2.png ├── attn3.png └── rnn-with-attention.png ├── functions.py ├── models.py └── rnn-with-attention.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-attention-mechanism 2 | my codes for learning attention mechanism 3 | 4 | ## CNN with attention 5 | 6 | *Apply spatial attention to CIFAR100 dataset* 7 | 8 | 9 | 10 | ### Usage 11 | 12 | Train the model: 13 | 14 | ```bash 15 | $ python cnn-with-attention.py --train 16 | ``` 17 | 18 | Visualize attention map: 19 | 20 | ```bash 21 | $ python cnn-with-attention.py --visualize 22 | ``` 23 | 24 | ## RNN with attention 25 | 26 | *Apply temporal attention to sequential data* 27 | 28 | *e.g. 
for a sequence of length 20, the output depends only on the 5th and the 13th positions*
29 | 
30 | 
31 | 
32 | ### Usage
33 | 
34 | Train the model:
35 | 
36 | ```bash
37 | $ python rnn-with-attention.py --train
38 | ```
39 | 
40 | Visualize attention map:
41 | 
42 | ```bash
43 | $ python rnn-with-attention.py --visualize
44 | ```
45 | 
46 | ## Todos
47 | 
48 | - [x] CNN+attention
49 | - [x] RNN+attention
50 | 
51 | ## References
52 | 
53 | - [Learn to Pay Attention! Trainable Visual Attention in CNNs](https://towardsdatascience.com/learn-to-pay-attention-trainable-visual-attention-in-cnns-87e2869f89f1)
54 | - [Attention in Neural Networks and How to Use It](http://akosiorek.github.io/ml/2017/10/14/visual-attention.html)
55 | - [Attention? Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)
56 | - [Tutorial - Visual Attention for Action Recognition](https://dtransposed.github.io/blog/Action-Recognition-Attention.html)
57 | - [An Introduction to the Bahdanau and Luong Attention Mechanisms](https://blog.csdn.net/u010960155/article/details/82853632)
58 | - [What Are the Mainstream Attention Methods Today?](https://www.zhihu.com/question/68482809)
59 | 
60 | - [Keras Attention Mechanism](https://github.com/philipperemy/keras-attention-mechanism)
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | """
6 | Attention blocks
7 | Reference: Learn To Pay Attention
8 | """
9 | class ProjectorBlock(nn.Module):
10 |     def __init__(self, in_features, out_features):
11 |         super(ProjectorBlock, self).__init__()
12 |         self.op = nn.Conv2d(in_channels=in_features, out_channels=out_features,
13 |             kernel_size=1, padding=0, bias=False)
14 | 
15 |     def forward(self, x):
16 |         return self.op(x)
17 | 
18 | 
19 | class SpatialAttn(nn.Module):
20 |     def __init__(self, in_features, normalize_attn=True):
21 |         super(SpatialAttn, self).__init__()
22 |         self.normalize_attn = normalize_attn
23 |         self.op = nn.Conv2d(in_channels=in_features, out_channels=1,
24 |             kernel_size=1, padding=0, bias=False)
25 | 
26 |     def forward(self, l, g):
27 |         N, C, H, W = l.size()
28 |         c = self.op(l+g) # (batch_size,1,H,W)
29 |         if self.normalize_attn:
30 |             a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,H,W)
31 |         else:
32 |             a = torch.sigmoid(c)
33 |         g = torch.mul(a.expand_as(l), l)
34 |         if self.normalize_attn:
35 |             g = g.view(N,C,-1).sum(dim=2) # (batch_size,C)
36 |         else:
37 |             g = F.adaptive_avg_pool2d(g, (1,1)).view(N,C)
38 |         return c.view(N,1,H,W), g
39 | 
40 | """
41 | Temporal attention block
42 | Reference: https://github.com/philipperemy/keras-attention-mechanism
43 | """
44 | class TemporalAttn(nn.Module):
45 |     def __init__(self, hidden_size):
46 |         super(TemporalAttn, self).__init__()
47 |         self.hidden_size = hidden_size
48 |         self.fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
49 |         self.fc2 = nn.Linear(self.hidden_size*2, self.hidden_size, bias=False)
50 | 
51 |     def forward(self, hidden_states):
52 |         # (batch_size, time_steps, hidden_size)
53 |         score_first_part = self.fc1(hidden_states)
54 |         # (batch_size, hidden_size)
55 |         h_t = hidden_states[:,-1,:]
56 |         # (batch_size, time_steps)
57 |         score = torch.bmm(score_first_part, h_t.unsqueeze(2)).squeeze(2)
58 |         attention_weights = F.softmax(score, dim=1)
59 |         # (batch_size, hidden_size)
60 |         context_vector = torch.bmm(hidden_states.permute(0,2,1), attention_weights.unsqueeze(2)).squeeze(2)
61 |         # (batch_size, hidden_size*2)
62 |         pre_activation = torch.cat((context_vector, h_t), dim=1)
63 |         # (batch_size, hidden_size)
64 |         attention_vector = self.fc2(pre_activation)
65 |         attention_vector = torch.tanh(attention_vector)
66 | 
67 |         return attention_vector, attention_weights
68 | 
69 | # Test
70 | if __name__ == '__main__':
71 |     # 2d block
72 |     spatial_block = SpatialAttn(in_features=3)
73 |     l = torch.randn(16, 3, 128, 128)
74 |     g = torch.randn(16, 3, 128, 128)
75 |     print(spatial_block(l, g))
76 |     # temporal block
77 |     temporal_block = TemporalAttn(hidden_size=256)
78 |     x = torch.randn(16, 30, 256)
79 |     # forward() returns (attention_vector, attention_weights), so unpack before printing shapes
80 |     attention_vector, attention_weights = temporal_block(x)
81 |     print(attention_vector.shape, attention_weights.shape)
82 | 
--------------------------------------------------------------------------------
/cnn-with-attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.optim.lr_scheduler as lr_scheduler
5 | from torch.utils.data import DataLoader
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import torchvision.utils as utils
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import argparse
12 | import numpy as np
13 | from datetime import datetime
14 | from models import AttnVGG
15 | from functions import train_epoch, val_epoch, visualize_attn
16 | 
17 | 
18 | # Parameters manager
19 | parser = argparse.ArgumentParser(description='CNN with Attention')
20 | parser.add_argument('--train', action='store_true',
21 |     help='Train the network')
22 | parser.add_argument('--visualize', action='store_true',
23 |     help='Visualize the attention vector')
24 | parser.add_argument('--no_save', action='store_true',
25 |     help='Do not save the model')
26 | parser.add_argument('--save_path', default='/home/haodong/Data/attention_models', type=str,
27 |     help='Path to save the model')
28 | parser.add_argument('--checkpoint', default='cnn_checkpoint.pth', type=str,
29 |     help='Path to checkpoint')
30 | parser.add_argument('--epochs', default=300, type=int,
31 |     help='Epochs for training')
32 | parser.add_argument('--batch_size', default=32, type=int,
33 |     help='Batch size for training or testing')
34 | parser.add_argument('--lr', default=1e-4, type=float,
35 |     help='Learning rate for training')
36 | parser.add_argument('--weight_decay', default=1e-4, type=float,
37 |     help='Weight decay for training')
38 | parser.add_argument('--device', default='0', type=str,
39 |     help='Cuda device to use')
40 | parser.add_argument('--log_interval', default=100, type=int,
41 |     help='Interval to print messages')
42 | args = parser.parse_args()
43 | 
44 | # Use specific gpus
45 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device
46 | # Device setting
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     # Load data
52 |     transform_train = transforms.Compose([
53 |         transforms.RandomCrop(32, padding=4),
54 |         transforms.RandomHorizontalFlip(),
55 |         transforms.ToTensor(),
56 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
57 |     ])
58 |     transform_test = transforms.Compose([
59 |         transforms.ToTensor(),
60 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
61 |     ])
62 |     train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
63 |     train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=16)
64 |     test_set = torchvision.datasets.CIFAR100(root='./data', 
train=False, download=True, transform=transform_test) 65 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=16) 66 | # Create model 67 | model = AttnVGG(sample_size=32, num_classes=100).to(device) 68 | # Run the model parallelly 69 | if torch.cuda.device_count() > 1: 70 | print("Using {} GPUs".format(torch.cuda.device_count())) 71 | model = nn.DataParallel(model) 72 | # Summary writer 73 | writer = SummaryWriter("runs/cnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now())) 74 | # Train 75 | if args.train: 76 | # Create loss criterion & optimizer 77 | criterion = nn.CrossEntropyLoss() 78 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 79 | # lr_lambda = lambda epoch : np.power(0.5, int(epoch/25)) 80 | # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 81 | 82 | for epoch in range(args.epochs): 83 | train_epoch(model, criterion, optimizer, train_loader, device, epoch, args.log_interval, writer) 84 | val_epoch(model, criterion, test_loader, device, epoch, writer) 85 | # adjust learning rate 86 | # scheduler.step() 87 | if not args.no_save: 88 | torch.save(model.state_dict(), os.path.join(args.save_path, "cnn_epoch{:03d}.pth".format(epoch+1))) 89 | print("Saving Model of Epoch {}".format(epoch+1)) 90 | 91 | # Visualize 92 | if args.visualize: 93 | # Load model 94 | model.load_state_dict(torch.load(args.checkpoint)) 95 | model.eval() 96 | 97 | with torch.no_grad(): 98 | for batch_idx, (inputs, labels) in enumerate(test_loader): 99 | # get images 100 | inputs = inputs.to(device) 101 | if batch_idx == 0: 102 | images = inputs[0:16,:,:,:] 103 | I = utils.make_grid(images, nrow=4, normalize=True, scale_each=True) 104 | writer.add_image('origin', I) 105 | _, c1, c2, c3 = model(images) 106 | # print(I.shape, c1.shape, c2.shape, c3.shape, c4.shape) 107 | attn1 = visualize_attn(I, c1) 108 | writer.add_image('attn1', attn1) 109 | attn2 = visualize_attn(I, c2) 110 | writer.add_image('attn2', attn2) 111 | attn3 = visualize_attn(I, c3) 112 | writer.add_image('attn3', attn3) 113 | break 114 | -------------------------------------------------------------------------------- /figures/attn1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn1.png -------------------------------------------------------------------------------- /figures/attn2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn2.png -------------------------------------------------------------------------------- /figures/attn3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn3.png -------------------------------------------------------------------------------- /figures/rnn-with-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/rnn-with-attention.png -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn.functional as F 3 | import torchvision.utils as utils 4 | from sklearn.metrics import accuracy_score 5 | import cv2 6 | 7 | def train_epoch(model, criterion, optimizer, dataloader, device, epoch, log_interval, writer): 8 | model.train() 9 | losses = [] 10 | all_label = [] 11 | all_pred = [] 12 | 13 | for batch_idx, (inputs, labels) in enumerate(dataloader): 14 | # get the inputs and labels 15 | inputs, labels = inputs.to(device), labels.to(device) 16 | 17 | optimizer.zero_grad() 18 | # forward 19 | outputs = model(inputs) 20 | if isinstance(outputs, list): 21 | outputs = outputs[0] 22 | 23 | # compute the loss 24 | loss = criterion(outputs, labels.squeeze()) 25 | losses.append(loss.item()) 26 | 27 | # compute the accuracy 28 | prediction = torch.max(outputs, 1)[1] 29 | all_label.extend(labels.squeeze()) 30 | all_pred.extend(prediction) 31 | score = accuracy_score(labels.squeeze().cpu().data.squeeze().numpy(), prediction.cpu().data.squeeze().numpy()) 32 | 33 | # backward & optimize 34 | loss.backward() 35 | optimizer.step() 36 | 37 | if (batch_idx + 1) % log_interval == 0: 38 | print("epoch {:3d} | iteration {:5d} | Loss {:.6f} | Acc {:.2f}%".format(epoch+1, batch_idx+1, loss.item(), score*100)) 39 | 40 | # Compute the average loss & accuracy 41 | training_loss = sum(losses)/len(losses) 42 | all_label = torch.stack(all_label, dim=0) 43 | all_pred = torch.stack(all_pred, dim=0) 44 | training_acc = accuracy_score(all_label.squeeze().cpu().data.squeeze().numpy(), all_pred.cpu().data.squeeze().numpy()) 45 | # Log 46 | writer.add_scalars('Loss', {'train': training_loss}, epoch+1) 47 | writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1) 48 | print("Average Training Loss of Epoch {}: {:.6f} | Acc: {:.2f}%".format(epoch+1, training_loss, training_acc*100)) 49 | 50 | 51 | def val_epoch(model, criterion, dataloader, device, epoch, writer): 52 | model.eval() 53 | losses = [] 54 | all_label = [] 55 | all_pred = [] 56 | 57 | with torch.no_grad(): 58 | for batch_idx, (inputs, labels) in enumerate(dataloader): 59 | # get the inputs and labels 60 | inputs, labels = inputs.to(device), labels.to(device) 61 | # forward 62 | outputs = model(inputs) 63 | if isinstance(outputs, list): 64 | outputs = outputs[0] 65 | # compute the loss 66 | loss = criterion(outputs, labels.squeeze()) 67 | losses.append(loss.item()) 68 | # collect labels & prediction 69 | prediction = torch.max(outputs, 1)[1] 70 | all_label.extend(labels.squeeze()) 71 | all_pred.extend(prediction) 72 | 73 | # Compute the average loss & accuracy 74 | val_loss = sum(losses)/len(losses) 75 | all_label = torch.stack(all_label, dim=0) 76 | all_pred = torch.stack(all_pred, dim=0) 77 | val_acc = accuracy_score(all_label.squeeze().cpu().data.squeeze().numpy(), all_pred.cpu().data.squeeze().numpy()) 78 | # Log 79 | writer.add_scalars('Loss', {'val': val_loss}, epoch+1) 80 | writer.add_scalars('Accuracy', {'val': val_acc}, epoch+1) 81 | print("Average Validation Loss: {:.6f} | Acc: {:.2f}%".format(val_loss, val_acc*100)) 82 | 83 | 84 | def visualize_attn(I, c): 85 | # Image 86 | img = I.permute((1,2,0)).cpu().numpy() 87 | # Heatmap 88 | N, C, H, W = c.size() 89 | a = F.softmax(c.view(N,C,-1), dim=2).view(N,C,H,W) 90 | up_factor = 32/H 91 | # print(up_factor, I.size(), c.size()) 92 | if up_factor > 1: 93 | a = F.interpolate(a, scale_factor=up_factor, mode='bilinear', align_corners=False) 94 | attn = utils.make_grid(a, nrow=4, normalize=True, scale_each=True) 95 | attn = 
attn.permute((1,2,0)).mul(255).byte().cpu().numpy() 96 | attn = cv2.applyColorMap(attn, cv2.COLORMAP_JET) 97 | attn = cv2.cvtColor(attn, cv2.COLOR_BGR2RGB) 98 | # Add the heatmap to the image 99 | vis = 0.6 * img + 0.4 * attn 100 | return torch.from_numpy(vis).permute(2,0,1) 101 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from attention import ProjectorBlock, SpatialAttn, TemporalAttn 5 | import math 6 | 7 | """ 8 | VGG-16 with attention 9 | """ 10 | class AttnVGG(nn.Module): 11 | def __init__(self, sample_size, num_classes, attention=True, normalize_attn=True, init_weights=True): 12 | super(AttnVGG, self).__init__() 13 | # conv blocks 14 | self.conv1 = self._make_layer(3, 64, 2) 15 | self.conv2 = self._make_layer(64, 128, 2) 16 | self.conv3 = self._make_layer(128, 256, 3) 17 | self.conv4 = self._make_layer(256, 512, 3) 18 | self.conv5 = self._make_layer(512, 512, 3) 19 | self.conv6 = self._make_layer(512, 512, 2, pool=True) 20 | self.dense = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=int(sample_size/32), padding=0, bias=True) 21 | # attention blocks 22 | self.attention = attention 23 | if self.attention: 24 | self.projector = ProjectorBlock(256, 512) 25 | self.attn1 = SpatialAttn(in_features=512, normalize_attn=normalize_attn) 26 | self.attn2 = SpatialAttn(in_features=512, normalize_attn=normalize_attn) 27 | self.attn3 = SpatialAttn(in_features=512, normalize_attn=normalize_attn) 28 | # final classification layer 29 | if self.attention: 30 | self.classify = nn.Linear(in_features=512*3, out_features=num_classes, bias=True) 31 | else: 32 | self.classify = nn.Linear(in_features=512, out_features=num_classes, bias=True) 33 | # if init_weights: 34 | # self._initialize_weights() 35 | 36 | def forward(self, x): 37 | x = self.conv1(x) 38 | x = self.conv2(x) 39 | l1 = self.conv3(x) 40 | x = F.max_pool2d(l1, kernel_size=2, stride=2, padding=0) 41 | l2 = self.conv4(x) 42 | x = F.max_pool2d(l2, kernel_size=2, stride=2, padding=0) 43 | l3 = self.conv5(x) 44 | x = F.max_pool2d(l3, kernel_size=2, stride=2, padding=0) 45 | x = self.conv6(x) 46 | g = self.dense(x) # batch_sizex512x1x1 47 | # attention 48 | if self.attention: 49 | c1, g1 = self.attn1(self.projector(l1), g) 50 | c2, g2 = self.attn2(l2, g) 51 | c3, g3 = self.attn3(l3, g) 52 | g = torch.cat((g1,g2,g3), dim=1) # batch_sizex3C 53 | # classification layer 54 | x = self.classify(g) # batch_sizexnum_classes 55 | else: 56 | c1, c2, c3 = None, None, None 57 | x = self.classify(torch.squeeze(g)) 58 | return [x, c1, c2, c3] 59 | 60 | def _make_layer(self, in_features, out_features, blocks, pool=False): 61 | layers = [] 62 | for i in range(blocks): 63 | conv2d = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=3, padding=1, bias=False) 64 | layers += [conv2d, nn.BatchNorm2d(out_features), nn.ReLU(inplace=True)] 65 | in_features = out_features 66 | if pool: 67 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 68 | return nn.Sequential(*layers) 69 | 70 | def _initialize_weights(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | if m.bias is not None: 75 | nn.init.constant_(m.bias, 0) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | nn.init.constant_(m.weight, 1) 78 | nn.init.constant_(m.bias, 0) 79 | elif 
isinstance(m, nn.Linear): 80 | nn.init.normal_(m.weight, 0, 0.01) 81 | nn.init.constant_(m.bias, 0) 82 | 83 | """ 84 | LSTM with attention 85 | """ 86 | class AttnLSTM(nn.Module): 87 | def __init__(self, input_size, hidden_size, num_layers): 88 | super(AttnLSTM, self).__init__() 89 | self.lstm = nn.LSTM( 90 | input_size=input_size, 91 | hidden_size=hidden_size, 92 | num_layers=num_layers, 93 | batch_first=True) 94 | self.attn = TemporalAttn(hidden_size=hidden_size) 95 | self.fc = nn.Linear(hidden_size, 1) 96 | 97 | def forward(self, x): 98 | x, (h_n, c_n) = self.lstm(x) 99 | x, weights = self.attn(x) 100 | x = self.fc(x) 101 | return x, weights 102 | 103 | # Test 104 | if __name__ == '__main__': 105 | model = AttnVGG(sample_size=128, num_classes=10) 106 | x = torch.randn(16,3,128,128) 107 | print(model(x)) 108 | model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1) 109 | x = torch.randn(16, 20, 1) 110 | print(model(x)) 111 | -------------------------------------------------------------------------------- /rnn-with-attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from tensorboardX import SummaryWriter 5 | import os 6 | import argparse 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from datetime import datetime 10 | from models import AttnLSTM 11 | 12 | # Parameters manager 13 | parser = argparse.ArgumentParser(description='RNN with Attention') 14 | parser.add_argument('--train', action='store_true', 15 | help='Train the network') 16 | parser.add_argument('--visualize', action='store_true', 17 | help='Visualize the attention vector') 18 | parser.add_argument('--no_save', action='store_true', 19 | help='Not save the model') 20 | parser.add_argument('--save_path', default='/home/haodong/Data/attention_models', type=str, 21 | help='Path to save the model') 22 | parser.add_argument('--checkpoint', default='rnn_checkpoint.pth', type=str, 23 | help='Path to checkpoint') 24 | parser.add_argument('--epochs', default=30, type=int, 25 | help='Epochs for training') 26 | parser.add_argument('--lr', default=1e-4, type=float, 27 | help='Learning rate for training') 28 | parser.add_argument('--weight_decay', default=1e-4, type=float, 29 | help='Weight decay for training') 30 | parser.add_argument('--device', default='0', type=str, 31 | help='Cuda device to use') 32 | parser.add_argument('--log_interval', default=1000, type=int, 33 | help='Interval to print messages') 34 | args = parser.parse_args() 35 | 36 | # Use specific gpus 37 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device 38 | # Device setting 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | def generate_data(n, seq_length, delimiter=0.0, index_1=None, index_2=None): 42 | x = np.random.uniform(0, 10, (n, seq_length)) 43 | y = np.zeros(shape=(n, 1)) 44 | for i in range(n): 45 | if index_1 is None and index_2 is None: 46 | a, b = np.random.choice(range(1, seq_length), size=2, replace=False) 47 | else: 48 | a, b = index_1, index_2 49 | y[i] = 0.5 * x[i, a] + 0.5 * x[i, b] 50 | x[i, a-1] = delimiter 51 | x[i, b-1] = delimiter 52 | x = np.expand_dims(x, axis=-1) 53 | return x, y 54 | 55 | 56 | if __name__ == '__main__': 57 | # Generate data 58 | seq_length, train_length, val_length, test_length = 20, 20000, 4000, 10 59 | x_train, y_train = generate_data(train_length, seq_length) 60 | x_val, y_val = generate_data(val_length, seq_length) 61 | x_test, y_test = 
generate_data(test_length, seq_length, index_1=5, index_2=13)
62 |     # Create the model
63 |     model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1).to(device)
64 |     # Run the model in parallel on multiple GPUs
65 |     if torch.cuda.device_count() > 1:
66 |         print("Using {} GPUs".format(torch.cuda.device_count()))
67 |         model = nn.DataParallel(model)
68 |     # Summary writer
69 |     writer = SummaryWriter("runs/rnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))
70 | 
71 |     if args.train:
72 |         # Create loss criterion & optimizer
73 |         criterion = nn.MSELoss()
74 |         optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
75 | 
76 |         for epoch in range(args.epochs):
77 |             # train model
78 |             model.train()
79 |             losses = []
80 |             for i in range(train_length):
81 |                 x = torch.Tensor(x_train[i, :]).unsqueeze(0).to(device)
82 |                 y = torch.Tensor(y_train[i, :]).unsqueeze(0).to(device)
83 |                 # print(x.shape, y.shape)
84 |                 optimizer.zero_grad()
85 |                 # forward
86 |                 pred, _ = model(x)
87 |                 # compute the loss
88 |                 loss = criterion(pred, y)
89 |                 losses.append(loss.item())
90 |                 # backward & optimize
91 |                 loss.backward()
92 |                 optimizer.step()
93 | 
94 |                 if (i + 1) % args.log_interval == 0:
95 |                     print("epoch {:3d} | iteration {:5d} | Loss {:.6f}".format(epoch+1, i+1, loss.item()))
96 | 
97 |             # calculate average loss
98 |             training_loss = sum(losses)/len(losses)
99 |             writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
100 |             print("Average Training Loss of Epoch {}: {:.6f}".format(epoch+1, training_loss))
101 | 
102 |             # save model
103 |             if not args.no_save:
104 |                 torch.save(model.state_dict(), os.path.join(args.save_path, "rnn_epoch{:03d}.pth".format(epoch+1)))
105 |                 print("Saving Model of Epoch {}".format(epoch+1))
106 | 
107 |             # validate model
108 |             model.eval()
109 |             losses = []
110 |             for i in range(val_length):
111 |                 x = torch.Tensor(x_val[i, :]).unsqueeze(0).to(device)
112 |                 y = torch.Tensor(y_val[i, :]).unsqueeze(0).to(device)
113 |                 # forward
114 |                 pred, _ = model(x)
115 |                 # compute the loss
116 |                 loss = criterion(pred, y)
117 |                 losses.append(loss.item())
118 | 
119 |             # calculate average loss
120 |             val_loss = sum(losses)/len(losses)
121 |             writer.add_scalars('Loss', {'val': val_loss}, epoch+1)
122 |             print("Average Validation Loss of Epoch {}: {:.6f}".format(epoch+1, val_loss))
123 | 
124 |     # Visualize attention map
125 |     if args.visualize:
126 |         model.load_state_dict(torch.load(args.checkpoint))
127 |         model.eval()
128 |         for i in range(test_length):
129 |             with torch.no_grad():
130 |                 x = torch.Tensor(x_test[i, :]).unsqueeze(0).to(device)
131 |                 y = torch.Tensor(y_test[i, :]).unsqueeze(0).to(device)
132 |                 # forward
133 |                 pred, weights = model(x)
134 |                 # print(y, pred, weights)
135 |                 plt.clf()  # start a fresh figure so bars from earlier sequences do not accumulate
136 |                 plt.title('Attention Weights')
137 |                 plt.xticks(np.arange(0, seq_length))
138 |                 plt.yticks(np.arange(0, 1, step=0.1))
139 |                 plt.bar(range(seq_length), weights.squeeze().cpu().numpy(), color='royalblue')
140 |                 plt.savefig('output_{}.png'.format(i))
141 | 
--------------------------------------------------------------------------------
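The snippet below is a minimal usage sketch, not part of the repository: assuming it is run from the repo root (so that `models.py` and `attention.py` are importable) it builds one synthetic sequence the same way `generate_data()` in rnn-with-attention.py does, runs it through `AttnLSTM`, and prints the prediction together with the temporal attention weights. The commented checkpoint path is hypothetical and only matches the script's default `--checkpoint` argument.

```python
import numpy as np
import torch

from models import AttnLSTM

# Build one synthetic sequence in the style of generate_data():
# the target depends on positions 5 and 13, with delimiters placed just before them.
seq_length = 20
x = np.random.uniform(0, 10, (1, seq_length))
a, b = 5, 13
y = 0.5 * x[0, a] + 0.5 * x[0, b]
x[0, a - 1] = 0.0
x[0, b - 1] = 0.0

model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1)
# model.load_state_dict(torch.load('rnn_checkpoint.pth'))  # optionally load a trained checkpoint
model.eval()

with torch.no_grad():
    inp = torch.from_numpy(x).float().unsqueeze(-1)  # (1, seq_length, 1)
    pred, weights = model(inp)                       # (1, 1) and (1, seq_length)

print("target:", y, "prediction:", pred.item())
print("attention weights:", weights.squeeze().numpy())
```

With a trained checkpoint, the largest attention weights should sit at the two positions the target depends on, which is what the `--visualize` branch of rnn-with-attention.py plots.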