├── .gitignore
├── README.md
├── attention.py
├── cnn-with-attention.py
├── figures
│   ├── attn1.png
│   ├── attn2.png
│   ├── attn3.png
│   └── rnn-with-attention.png
├── functions.py
├── models.py
└── rnn-with-attention.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-attention-mechanism
2 | My code for learning attention mechanisms.
3 |
4 | ## CNN with attention
5 |
6 | *Applies spatial attention to the CIFAR-100 dataset*
7 |
8 |
9 |
10 | ### Usage
11 |
12 | Train the model:
13 |
14 | ```bash
15 | $ python cnn-with-attention.py --train
16 | ```
17 |
18 | Visualize attention map:
19 |
20 | ```bash
21 | $ python cnn-with-attention.py --visualize
22 | ```
23 |
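For quick experimentation outside the training script, the sketch below (illustrative shapes, not taken from `cnn-with-attention.py`) drives the `SpatialAttn` block from `attention.py` directly; here `g` stands in for the 1x1 global feature that the model's `dense` layer produces:

```python
import torch
from attention import SpatialAttn

# l: a local feature map, g: a 1x1 global descriptor broadcast over l's spatial dims
l = torch.randn(8, 512, 16, 16)
g = torch.randn(8, 512, 1, 1)

attn = SpatialAttn(in_features=512, normalize_attn=True)
c, g_attended = attn(l, g)
print(c.shape)           # attention logits per spatial position: (8, 1, 16, 16)
print(g_attended.shape)  # attention-weighted feature vector: (8, 512)
```
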
24 | ## RNN with attention
25 |
26 | *Applies temporal attention to sequential data*
27 |
28 | *e.g. for a sequence of length 20, the target depends only on the values at the 5th and 13th positions*
29 |
30 |
31 |
32 | ### Usage
33 |
34 | Train the model:
35 |
36 | ```bash
37 | $ python rnn-with-attention.py --train
38 | ```
39 |
40 | Visualize attention map:
41 |
42 | ```bash
43 | $ python rnn-with-attention.py --visualize
44 | ```
45 |
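As a quick sanity check, here is a minimal sketch (again, not part of `rnn-with-attention.py`) of running the `AttnLSTM` model from `models.py` on a toy batch shaped like the synthetic task above:

```python
import torch
from models import AttnLSTM

# 16 sequences of length 20 with one feature per time step
x = torch.randn(16, 20, 1)

model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1)
pred, weights = model(x)
print(pred.shape)     # regression output: (16, 1)
print(weights.shape)  # attention weights over time steps: (16, 20)
```
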
46 | ## Todos
47 |
48 | - [x] CNN+attention
49 | - [x] RNN+attention
50 |
51 | ## References
52 |
53 | - [Learn to Pay Attention! Trainable Visual Attention in CNNs](https://towardsdatascience.com/learn-to-pay-attention-trainable-visual-attention-in-cnns-87e2869f89f1)
54 | - [Attention in Neural Networks and How to Use It](http://akosiorek.github.io/ml/2017/10/14/visual-attention.html)
55 | - [Attention? Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)
56 | - [Tutorial - Visual Attention for Action Recognition](https://dtransposed.github.io/blog/Action-Recognition-Attention.html)
57 | - [A Brief Introduction to the Bahdanau and Luong Attention Mechanisms](https://blog.csdn.net/u010960155/article/details/82853632)
58 | - [What are the current mainstream attention methods?](https://www.zhihu.com/question/68482809)
59 |
60 | - [Keras Attention Mechanism](https://github.com/philipperemy/keras-attention-mechanism)
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Attention blocks
7 | Reference: Learn To Pay Attention
8 | """
9 | class ProjectorBlock(nn.Module):
10 | def __init__(self, in_features, out_features):
11 | super(ProjectorBlock, self).__init__()
12 | self.op = nn.Conv2d(in_channels=in_features, out_channels=out_features,
13 | kernel_size=1, padding=0, bias=False)
14 |
15 | def forward(self, x):
16 | return self.op(x)
17 |
18 |
19 | class SpatialAttn(nn.Module):
20 | def __init__(self, in_features, normalize_attn=True):
21 | super(SpatialAttn, self).__init__()
22 | self.normalize_attn = normalize_attn
23 | self.op = nn.Conv2d(in_channels=in_features, out_channels=1,
24 | kernel_size=1, padding=0, bias=False)
25 |
26 | def forward(self, l, g):
27 | N, C, H, W = l.size()
28 | c = self.op(l+g) # (batch_size,1,H,W)
29 | if self.normalize_attn:
30 | a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,H,W)
31 | else:
32 | a = torch.sigmoid(c)
33 | g = torch.mul(a.expand_as(l), l)
34 | if self.normalize_attn:
35 | g = g.view(N,C,-1).sum(dim=2) # (batch_size,C)
36 | else:
37 | g = F.adaptive_avg_pool2d(g, (1,1)).view(N,C)
38 | return c.view(N,1,H,W), g
39 |
40 | """
41 | Temporal attention block
42 | Reference: https://github.com/philipperemy/keras-attention-mechanism
43 | """
44 | class TemporalAttn(nn.Module):
45 | def __init__(self, hidden_size):
46 | super(TemporalAttn, self).__init__()
47 | self.hidden_size = hidden_size
48 | self.fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
49 | self.fc2 = nn.Linear(self.hidden_size*2, self.hidden_size, bias=False)
50 |
51 | def forward(self, hidden_states):
52 | # (batch_size, time_steps, hidden_size)
53 | score_first_part = self.fc1(hidden_states)
54 | # (batch_size, hidden_size)
55 | h_t = hidden_states[:,-1,:]
56 | # (batch_size, time_steps)
57 | score = torch.bmm(score_first_part, h_t.unsqueeze(2)).squeeze(2)
58 | attention_weights = F.softmax(score, dim=1)
59 | # (batch_size, hidden_size)
60 | context_vector = torch.bmm(hidden_states.permute(0,2,1), attention_weights.unsqueeze(2)).squeeze(2)
61 | # (batch_size, hidden_size*2)
62 | pre_activation = torch.cat((context_vector, h_t), dim=1)
63 | # (batch_size, hidden_size)
64 | attention_vector = self.fc2(pre_activation)
65 | attention_vector = torch.tanh(attention_vector)
66 |
67 | return attention_vector, attention_weights
68 |
69 | # Test
70 | if __name__ == '__main__':
71 | # 2d block
72 | spatial_block = SpatialAttn(in_features=3)
73 | l = torch.randn(16, 3, 128, 128)
74 | g = torch.randn(16, 3, 128, 128)
75 | print(spatial_block(l, g))
76 | # temporal block
77 | temporal_block = TemporalAttn(hidden_size=256)
78 | x = torch.randn(16, 30, 256)
79 | attn_vec, attn_weights = temporal_block(x)
80 | print(attn_vec.shape, attn_weights.shape)
81 |
--------------------------------------------------------------------------------
/cnn-with-attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.optim.lr_scheduler as lr_scheduler
5 | from torch.utils.data import DataLoader
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import torchvision.utils as utils
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import argparse
12 | import numpy as np
13 | from datetime import datetime
14 | from models import AttnVGG
15 | from functions import train_epoch, val_epoch, visualize_attn
16 |
17 |
18 | # Parameters manager
19 | parser = argparse.ArgumentParser(description='CNN with Attention')
20 | parser.add_argument('--train', action='store_true',
21 | help='Train the network')
22 | parser.add_argument('--visualize', action='store_true',
23 | help='Visualize the attention maps')
24 | parser.add_argument('--no_save', action='store_true',
25 | help='Do not save the model')
26 | parser.add_argument('--save_path', default='/home/haodong/Data/attention_models', type=str,
27 | help='Path to save the model')
28 | parser.add_argument('--checkpoint', default='cnn_checkpoint.pth', type=str,
29 | help='Path to checkpoint')
30 | parser.add_argument('--epochs', default=300, type=int,
31 | help='Epochs for training')
32 | parser.add_argument('--batch_size', default=32, type=int,
33 | help='Batch size for training or testing')
34 | parser.add_argument('--lr', default=1e-4, type=float,
35 | help='Learning rate for training')
36 | parser.add_argument('--weight_decay', default=1e-4, type=float,
37 | help='Weight decay for training')
38 | parser.add_argument('--device', default='0', type=str,
39 | help='Cuda device to use')
40 | parser.add_argument('--log_interval', default=100, type=int,
41 | help='Interval to print messages')
42 | args = parser.parse_args()
43 |
44 | # Use specific gpus
45 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device
46 | # Device setting
47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48 |
49 |
50 | if __name__ == '__main__':
51 | # Load data
52 | transform_train = transforms.Compose([
53 | transforms.RandomCrop(32, padding=4),
54 | transforms.RandomHorizontalFlip(),
55 | transforms.ToTensor(),
56 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
57 | ])
58 | transform_test = transforms.Compose([
59 | transforms.ToTensor(),
60 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
61 | ])
62 | train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
63 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=16)
64 | test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
65 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=16)
66 | # Create model
67 | model = AttnVGG(sample_size=32, num_classes=100).to(device)
68 | # Run the model in parallel across available GPUs
69 | if torch.cuda.device_count() > 1:
70 | print("Using {} GPUs".format(torch.cuda.device_count()))
71 | model = nn.DataParallel(model)
72 | # Summary writer
73 | writer = SummaryWriter("runs/cnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))
74 | # Train
75 | if args.train:
76 | # Create loss criterion & optimizer
77 | criterion = nn.CrossEntropyLoss()
78 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
79 | # lr_lambda = lambda epoch : np.power(0.5, int(epoch/25))
80 | # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
81 |
82 | for epoch in range(args.epochs):
83 | train_epoch(model, criterion, optimizer, train_loader, device, epoch, args.log_interval, writer)
84 | val_epoch(model, criterion, test_loader, device, epoch, writer)
85 | # adjust learning rate
86 | # scheduler.step()
87 | if not args.no_save:
88 | torch.save(model.state_dict(), os.path.join(args.save_path, "cnn_epoch{:03d}.pth".format(epoch+1)))
89 | print("Saving Model of Epoch {}".format(epoch+1))
90 |
91 | # Visualize
92 | if args.visualize:
93 | # Load model
94 | model.load_state_dict(torch.load(args.checkpoint))
95 | model.eval()
96 |
97 | with torch.no_grad():
98 | for batch_idx, (inputs, labels) in enumerate(test_loader):
99 | # get images
100 | inputs = inputs.to(device)
101 | if batch_idx == 0:
102 | images = inputs[0:16,:,:,:]
103 | I = utils.make_grid(images, nrow=4, normalize=True, scale_each=True)
104 | writer.add_image('origin', I)
105 | _, c1, c2, c3 = model(images)
106 | # print(I.shape, c1.shape, c2.shape, c3.shape, c4.shape)
107 | attn1 = visualize_attn(I, c1)
108 | writer.add_image('attn1', attn1)
109 | attn2 = visualize_attn(I, c2)
110 | writer.add_image('attn2', attn2)
111 | attn3 = visualize_attn(I, c3)
112 | writer.add_image('attn3', attn3)
113 | break
114 |
--------------------------------------------------------------------------------
/figures/attn1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn1.png
--------------------------------------------------------------------------------
/figures/attn2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn2.png
--------------------------------------------------------------------------------
/figures/attn3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/attn3.png
--------------------------------------------------------------------------------
/figures/rnn-with-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0aqz0/pytorch-attention-mechanism/3625e7ad82ea5e7c01e1558e469883518b25ce31/figures/rnn-with-attention.png
--------------------------------------------------------------------------------
/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchvision.utils as utils
4 | from sklearn.metrics import accuracy_score
5 | import cv2
6 |
7 | def train_epoch(model, criterion, optimizer, dataloader, device, epoch, log_interval, writer):
8 | model.train()
9 | losses = []
10 | all_label = []
11 | all_pred = []
12 |
13 | for batch_idx, (inputs, labels) in enumerate(dataloader):
14 | # get the inputs and labels
15 | inputs, labels = inputs.to(device), labels.to(device)
16 |
17 | optimizer.zero_grad()
18 | # forward
19 | outputs = model(inputs)
20 | if isinstance(outputs, list):
21 | outputs = outputs[0]
22 |
23 | # compute the loss
24 | loss = criterion(outputs, labels.squeeze())
25 | losses.append(loss.item())
26 |
27 | # compute the accuracy
28 | prediction = torch.max(outputs, 1)[1]
29 | all_label.extend(labels.squeeze())
30 | all_pred.extend(prediction)
31 | score = accuracy_score(labels.squeeze().cpu().data.squeeze().numpy(), prediction.cpu().data.squeeze().numpy())
32 |
33 | # backward & optimize
34 | loss.backward()
35 | optimizer.step()
36 |
37 | if (batch_idx + 1) % log_interval == 0:
38 | print("epoch {:3d} | iteration {:5d} | Loss {:.6f} | Acc {:.2f}%".format(epoch+1, batch_idx+1, loss.item(), score*100))
39 |
40 | # Compute the average loss & accuracy
41 | training_loss = sum(losses)/len(losses)
42 | all_label = torch.stack(all_label, dim=0)
43 | all_pred = torch.stack(all_pred, dim=0)
44 | training_acc = accuracy_score(all_label.squeeze().cpu().data.squeeze().numpy(), all_pred.cpu().data.squeeze().numpy())
45 | # Log
46 | writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
47 | writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1)
48 | print("Average Training Loss of Epoch {}: {:.6f} | Acc: {:.2f}%".format(epoch+1, training_loss, training_acc*100))
49 |
50 |
51 | def val_epoch(model, criterion, dataloader, device, epoch, writer):
52 | model.eval()
53 | losses = []
54 | all_label = []
55 | all_pred = []
56 |
57 | with torch.no_grad():
58 | for batch_idx, (inputs, labels) in enumerate(dataloader):
59 | # get the inputs and labels
60 | inputs, labels = inputs.to(device), labels.to(device)
61 | # forward
62 | outputs = model(inputs)
63 | if isinstance(outputs, list):
64 | outputs = outputs[0]
65 | # compute the loss
66 | loss = criterion(outputs, labels.squeeze())
67 | losses.append(loss.item())
68 | # collect labels & prediction
69 | prediction = torch.max(outputs, 1)[1]
70 | all_label.extend(labels.squeeze())
71 | all_pred.extend(prediction)
72 |
73 | # Compute the average loss & accuracy
74 | val_loss = sum(losses)/len(losses)
75 | all_label = torch.stack(all_label, dim=0)
76 | all_pred = torch.stack(all_pred, dim=0)
77 | val_acc = accuracy_score(all_label.squeeze().cpu().data.squeeze().numpy(), all_pred.cpu().data.squeeze().numpy())
78 | # Log
79 | writer.add_scalars('Loss', {'val': val_loss}, epoch+1)
80 | writer.add_scalars('Accuracy', {'val': val_acc}, epoch+1)
81 | print("Average Validation Loss: {:.6f} | Acc: {:.2f}%".format(val_loss, val_acc*100))
82 |
83 |
84 | def visualize_attn(I, c):
85 | # Image
86 | img = I.permute((1,2,0)).cpu().numpy()
87 | # Heatmap
88 | N, C, H, W = c.size()
89 | a = F.softmax(c.view(N,C,-1), dim=2).view(N,C,H,W)
90 | up_factor = 32/H
91 | # print(up_factor, I.size(), c.size())
92 | if up_factor > 1:
93 | a = F.interpolate(a, scale_factor=up_factor, mode='bilinear', align_corners=False)
94 | attn = utils.make_grid(a, nrow=4, normalize=True, scale_each=True)
95 | attn = attn.permute((1,2,0)).mul(255).byte().cpu().numpy()
96 | attn = cv2.applyColorMap(attn, cv2.COLORMAP_JET)
97 | attn = cv2.cvtColor(attn, cv2.COLOR_BGR2RGB)
98 | # Rescale the heatmap to [0, 1] and blend it with the image
99 | vis = 0.6 * img + 0.4 * (attn / 255.0)
100 | return torch.from_numpy(vis).permute(2,0,1)
101 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from attention import ProjectorBlock, SpatialAttn, TemporalAttn
5 | import math
6 |
7 | """
8 | VGG-16 with attention
9 | """
10 | class AttnVGG(nn.Module):
11 | def __init__(self, sample_size, num_classes, attention=True, normalize_attn=True, init_weights=True):
12 | super(AttnVGG, self).__init__()
13 | # conv blocks
14 | self.conv1 = self._make_layer(3, 64, 2)
15 | self.conv2 = self._make_layer(64, 128, 2)
16 | self.conv3 = self._make_layer(128, 256, 3)
17 | self.conv4 = self._make_layer(256, 512, 3)
18 | self.conv5 = self._make_layer(512, 512, 3)
19 | self.conv6 = self._make_layer(512, 512, 2, pool=True)
20 | self.dense = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=int(sample_size/16), padding=0, bias=True)  # kernel spans the whole remaining (sample_size/16) feature map, so g is batch_sizex512x1x1 and broadcasts against the attention inputs
21 | # attention blocks
22 | self.attention = attention
23 | if self.attention:
24 | self.projector = ProjectorBlock(256, 512)
25 | self.attn1 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
26 | self.attn2 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
27 | self.attn3 = SpatialAttn(in_features=512, normalize_attn=normalize_attn)
28 | # final classification layer
29 | if self.attention:
30 | self.classify = nn.Linear(in_features=512*3, out_features=num_classes, bias=True)
31 | else:
32 | self.classify = nn.Linear(in_features=512, out_features=num_classes, bias=True)
33 | # if init_weights:
34 | # self._initialize_weights()
35 |
36 | def forward(self, x):
37 | x = self.conv1(x)
38 | x = self.conv2(x)
39 | l1 = self.conv3(x)
40 | x = F.max_pool2d(l1, kernel_size=2, stride=2, padding=0)
41 | l2 = self.conv4(x)
42 | x = F.max_pool2d(l2, kernel_size=2, stride=2, padding=0)
43 | l3 = self.conv5(x)
44 | x = F.max_pool2d(l3, kernel_size=2, stride=2, padding=0)
45 | x = self.conv6(x)
46 | g = self.dense(x) # batch_sizex512x1x1
47 | # attention
48 | if self.attention:
49 | c1, g1 = self.attn1(self.projector(l1), g)
50 | c2, g2 = self.attn2(l2, g)
51 | c3, g3 = self.attn3(l3, g)
52 | g = torch.cat((g1,g2,g3), dim=1) # batch_sizex3C
53 | # classification layer
54 | x = self.classify(g) # batch_sizexnum_classes
55 | else:
56 | c1, c2, c3 = None, None, None
57 | x = self.classify(torch.squeeze(g))
58 | return [x, c1, c2, c3]
59 |
60 | def _make_layer(self, in_features, out_features, blocks, pool=False):
61 | layers = []
62 | for i in range(blocks):
63 | conv2d = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=3, padding=1, bias=False)
64 | layers += [conv2d, nn.BatchNorm2d(out_features), nn.ReLU(inplace=True)]
65 | in_features = out_features
66 | if pool:
67 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
68 | return nn.Sequential(*layers)
69 |
70 | def _initialize_weights(self):
71 | for m in self.modules():
72 | if isinstance(m, nn.Conv2d):
73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
74 | if m.bias is not None:
75 | nn.init.constant_(m.bias, 0)
76 | elif isinstance(m, nn.BatchNorm2d):
77 | nn.init.constant_(m.weight, 1)
78 | nn.init.constant_(m.bias, 0)
79 | elif isinstance(m, nn.Linear):
80 | nn.init.normal_(m.weight, 0, 0.01)
81 | nn.init.constant_(m.bias, 0)
82 |
83 | """
84 | LSTM with attention
85 | """
86 | class AttnLSTM(nn.Module):
87 | def __init__(self, input_size, hidden_size, num_layers):
88 | super(AttnLSTM, self).__init__()
89 | self.lstm = nn.LSTM(
90 | input_size=input_size,
91 | hidden_size=hidden_size,
92 | num_layers=num_layers,
93 | batch_first=True)
94 | self.attn = TemporalAttn(hidden_size=hidden_size)
95 | self.fc = nn.Linear(hidden_size, 1)
96 |
97 | def forward(self, x):
98 | x, (h_n, c_n) = self.lstm(x)
99 | x, weights = self.attn(x)
100 | x = self.fc(x)
101 | return x, weights
102 |
103 | # Test
104 | if __name__ == '__main__':
105 | model = AttnVGG(sample_size=128, num_classes=10)
106 | x = torch.randn(16,3,128,128)
107 | print(model(x))
108 | model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1)
109 | x = torch.randn(16, 20, 1)
110 | print(model(x))
111 |
--------------------------------------------------------------------------------
/rnn-with-attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from tensorboardX import SummaryWriter
5 | import os
6 | import argparse
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from datetime import datetime
10 | from models import AttnLSTM
11 |
12 | # Parameters manager
13 | parser = argparse.ArgumentParser(description='RNN with Attention')
14 | parser.add_argument('--train', action='store_true',
15 | help='Train the network')
16 | parser.add_argument('--visualize', action='store_true',
17 | help='Visualize the attention vector')
18 | parser.add_argument('--no_save', action='store_true',
19 | help='Do not save the model')
20 | parser.add_argument('--save_path', default='/home/haodong/Data/attention_models', type=str,
21 | help='Path to save the model')
22 | parser.add_argument('--checkpoint', default='rnn_checkpoint.pth', type=str,
23 | help='Path to checkpoint')
24 | parser.add_argument('--epochs', default=30, type=int,
25 | help='Epochs for training')
26 | parser.add_argument('--lr', default=1e-4, type=float,
27 | help='Learning rate for training')
28 | parser.add_argument('--weight_decay', default=1e-4, type=float,
29 | help='Weight decay for training')
30 | parser.add_argument('--device', default='0', type=str,
31 | help='Cuda device to use')
32 | parser.add_argument('--log_interval', default=1000, type=int,
33 | help='Interval to print messages')
34 | args = parser.parse_args()
35 |
36 | # Use specific gpus
37 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device
38 | # Device setting
39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 |
41 | def generate_data(n, seq_length, delimiter=0.0, index_1=None, index_2=None):
42 | x = np.random.uniform(0, 10, (n, seq_length))
43 | y = np.zeros(shape=(n, 1))
44 | for i in range(n):
45 | if index_1 is None and index_2 is None:
46 | a, b = np.random.choice(range(1, seq_length), size=2, replace=False)
47 | else:
48 | a, b = index_1, index_2
49 | y[i] = 0.5 * x[i, a] + 0.5 * x[i, b]  # target is the mean of the values at positions a and b
50 | x[i, a-1] = delimiter  # mark the position just before each informative value with the delimiter
51 | x[i, b-1] = delimiter
52 | x = np.expand_dims(x, axis=-1)
53 | return x, y
54 |
55 |
56 | if __name__ == '__main__':
57 | # Generate data
58 | seq_length, train_length, val_length, test_length = 20, 20000, 4000, 10
59 | x_train, y_train = generate_data(train_length, seq_length)
60 | x_val, y_val = generate_data(val_length, seq_length)
61 | x_test, y_test = generate_data(test_length, seq_length, index_1=5, index_2=13)
62 | # Create the model
63 | model = AttnLSTM(input_size=1, hidden_size=128, num_layers=1).to(device)
64 | # Run the model in parallel across available GPUs
65 | if torch.cuda.device_count() > 1:
66 | print("Using {} GPUs".format(torch.cuda.device_count()))
67 | model = nn.DataParallel(model)
68 | # Summary writer
69 | writer = SummaryWriter("runs/rnn_attention_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))
70 |
71 | if args.train:
72 | # Create loss criterion & optimizer
73 | criterion = nn.MSELoss()
74 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
75 |
76 | for epoch in range(args.epochs):
77 | # train model
78 | model.train()
79 | losses = []
80 | for i in range(train_length):
81 | x = torch.Tensor(x_train[i, :]).unsqueeze(0).to(device)
82 | y = torch.Tensor(y_train[i, :]).unsqueeze(0).to(device)
83 | # print(x.shape, y.shape)
84 | optimizer.zero_grad()
85 | # forward
86 | pred, _ = model(x)
87 | # compute the loss
88 | loss = criterion(pred, y)
89 | losses.append(loss.item())
90 | # backward & optimize
91 | loss.backward()
92 | optimizer.step()
93 |
94 | if (i + 1) % args.log_interval == 0:
95 | print("epoch {:3d} | iteration {:5d} | Loss {:.6f}".format(epoch+1, i+1, loss.item()))
96 |
97 | # calculate average loss
98 | training_loss = sum(losses)/len(losses)
99 | writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
100 | print("Average Training Loss of Epoch {}: {:.6f}".format(epoch+1, training_loss))
101 |
102 | # save model
103 | if not args.no_save:
104 | torch.save(model.state_dict(), os.path.join(args.save_path, "rnn_epoch{:03d}.pth".format(epoch+1)))
105 | print("Saving Model of Epoch {}".format(epoch+1))
106 |
107 | # validate model
108 | model.eval()
109 | losses = []
110 | for i in range(val_length):
111 | x = torch.Tensor(x_val[i, :]).unsqueeze(0).to(device)
112 | y = torch.Tensor(y_val[i, :]).unsqueeze(0).to(device)
113 | # forward
114 | pred, _ = model(x)
115 | # compute the loss
116 | loss = criterion(pred, y)
117 | losses.append(loss.item())
118 |
119 | # calculate average loss
120 | val_loss = sum(losses)/len(losses)
121 | writer.add_scalars('Loss', {'val': val_loss}, epoch+1)
122 | print("Average Validation Loss of Epoch {}: {:.6f}".format(epoch+1, val_loss))
123 |
124 | # Visualize attention map
125 | if args.visualize:
126 | model.load_state_dict(torch.load(args.checkpoint))
127 | model.eval()
128 | for i in range(test_length):
129 | with torch.no_grad():
130 | x = torch.Tensor(x_test[i, :]).unsqueeze(0).to(device)
131 | y = torch.Tensor(y_test[i, :]).unsqueeze(0).to(device)
132 | # forward
133 | pred, weights = model(x)
134 | # print(y, pred, weights)
135 | plt.figure()  # start a fresh figure for each test sequence so bars do not accumulate
136 | plt.title('Attention Weights')
137 | plt.xticks(np.arange(0, seq_length))
138 | plt.yticks(np.arange(0, 1, step=0.1))
139 | plt.bar(range(seq_length), weights.squeeze().cpu().numpy(), color='royalblue')
140 | plt.savefig('output_{}.png'.format(i))
141 |
--------------------------------------------------------------------------------