├── Readme.md
├── attention.py
├── decoder.py
├── imgs
│   └── attention_img.png
└── test.py

/Readme.md:
--------------------------------------------------------------------------------
# PyTorch Implementation of [Monotonic Chunkwise Attention](https://openreview.net/forum?id=Hko85plCW)

## Requirements
- PyTorch 0.4

## TODOs
- [x] Soft MoChA
- [x] Hard MoChA
- [ ] Linear Time Decoding
- [ ] Experiment with a real-world dataset

## Model figure
![Model figure 1](imgs/attention_img.png)

## Linear Time Decoding
It is not clear whether the [authors' TF implementation](https://github.com/craffel/mad/blob/master/example_decoder.py#L235) supports decoding in linear time:
it computes energies over the **whole encoder output** instead of scanning forward from the previously attended encoder output.

## References
- Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss and Douglas Eck. [Online and Linear-Time Attention by Enforcing Monotonic Alignments](http://arxiv.org/abs/1704.00784) (ICML 2017)
- Chung-Cheng Chiu and Colin Raffel. [Monotonic Chunkwise Attention](https://openreview.net/forum?id=Hko85plCW) (ICLR 2018)

--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F


class Energy(nn.Module):
    def __init__(self, enc_dim=10, dec_dim=10, att_dim=10, init_r=-4):
        """
        [Modified Bahdanau attention] from
        "Online and Linear-Time Attention by Enforcing Monotonic Alignments" (ICML 2017)
        http://arxiv.org/abs/1704.00784

        Used for Monotonic Attention and Chunk Attention
        """
        super().__init__()
        self.tanh = nn.Tanh()
        self.W = nn.Linear(enc_dim, att_dim, bias=False)
        self.V = nn.Linear(dec_dim, att_dim, bias=False)
        self.b = nn.Parameter(torch.Tensor(att_dim).normal_())

        self.v = nn.utils.weight_norm(nn.Linear(att_dim, 1))
        self.v.weight_g.data = torch.Tensor([1 / att_dim]).sqrt()

        self.r = nn.Parameter(torch.Tensor([init_r]))

    def forward(self, encoder_outputs, decoder_h):
        """
        Args:
            encoder_outputs: [batch_size, sequence_length, enc_dim]
            decoder_h: [batch_size, dec_dim]
        Return:
            energy: [batch_size, sequence_length]
        """
        # Broadcast the decoder state over the time dimension instead of
        # flattening and repeating, so batch and time indices stay aligned.
        energy = self.tanh(self.W(encoder_outputs) +          # [batch_size, sequence_length, att_dim]
                           self.V(decoder_h).unsqueeze(1) +   # [batch_size, 1, att_dim]
                           self.b)
        energy = self.v(energy).squeeze(-1) + self.r          # [batch_size, sequence_length]
        return energy


class MonotonicAttention(nn.Module):
    def __init__(self):
        """
        [Monotonic Attention] from
        "Online and Linear-Time Attention by Enforcing Monotonic Alignments" (ICML 2017)
        http://arxiv.org/abs/1704.00784
        """
        super().__init__()

        self.monotonic_energy = Energy()
        self.sigmoid = nn.Sigmoid()

    def gaussian_noise(self, *size):
        """Additive Gaussian noise to encourage discreteness"""
        if torch.cuda.is_available():
            return torch.cuda.FloatTensor(*size).normal_()
        else:
            return torch.Tensor(*size).normal_()

    def safe_cumprod(self, x):
        """Numerically stable cumulative product via cumulative sum in log-space"""
        return torch.exp(torch.cumsum(torch.log(torch.clamp(x, min=1e-10, max=1)), dim=1))

    def exclusive_cumprod(self, x):
        """Exclusive cumulative product: [a, b, c] => [1, a, a * b]
        * TensorFlow: https://www.tensorflow.org/api_docs/python/tf/cumprod
        * PyTorch: https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614
        """
        batch_size, sequence_length = x.size()
        if torch.cuda.is_available():
            one_x = torch.cat([torch.ones(batch_size, 1).cuda(), x], dim=1)[:, :-1]
        else:
            one_x = torch.cat([torch.ones(batch_size, 1), x], dim=1)[:, :-1]
        return torch.cumprod(one_x, dim=1)

    def soft(self, encoder_outputs, decoder_h, previous_alpha=None):
        """
        Soft monotonic attention (Train)
        Args:
            encoder_outputs [batch_size, sequence_length, enc_dim]
            decoder_h [batch_size, dec_dim]
            previous_alpha [batch_size, sequence_length]
        Return:
            alpha [batch_size, sequence_length]
        """
        batch_size, sequence_length, enc_dim = encoder_outputs.size()

        monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h)
        p_select = self.sigmoid(monotonic_energy + self.gaussian_noise(*monotonic_energy.size()))
        cumprod_1_minus_p = self.safe_cumprod(1 - p_select)

        if previous_alpha is None:
            # First iteration => alpha = [1, 0, 0 ... 0]
            alpha = torch.zeros(batch_size, sequence_length)
            alpha[:, 0] = torch.ones(batch_size)
            if torch.cuda.is_available():
                alpha = alpha.cuda()
        else:
            alpha = p_select * cumprod_1_minus_p * \
                torch.cumsum(previous_alpha / cumprod_1_minus_p, dim=1)

        return alpha

    def hard(self, encoder_outputs, decoder_h, previous_attention=None):
        """
        Hard monotonic attention (Test)
        Args:
            encoder_outputs [batch_size, sequence_length, enc_dim]
            decoder_h [batch_size, dec_dim]
            previous_attention [batch_size, sequence_length]
        Return:
            attention (one-hot) [batch_size, sequence_length]
        """
        batch_size, sequence_length, enc_dim = encoder_outputs.size()

        if previous_attention is None:
            # First iteration => attention = [1, 0, 0 ... 0]
            attention = torch.zeros(batch_size, sequence_length)
            attention[:, 0] = torch.ones(batch_size)
            if torch.cuda.is_available():
                attention = attention.cuda()
        else:
            # TODO: Linear Time Decoding
            # It's not clear if the authors' TF implementation decodes in linear time.
            # https://github.com/craffel/mad/blob/master/example_decoder.py#L235
            # They calculate energies for the whole encoder output
            # instead of scanning from the previously attended encoder output.
            monotonic_energy = self.monotonic_energy(encoder_outputs, decoder_h)

            # Hard sigmoid:
            # attend when the monotonic energy is above the threshold (sigmoid > 0.5)
            above_threshold = (monotonic_energy > 0).float()

            p_select = above_threshold * torch.cumsum(previous_attention, dim=1)
            attention = p_select * self.exclusive_cumprod(1 - p_select)

            # Not attended => attend to the last encoder output
            # Assumes that encoder outputs are not padded
            attended = attention.sum(dim=1)
            for batch_i in range(batch_size):
                if not attended[batch_i]:
                    attention[batch_i, -1] = 1

            # Ex)
            # p_select                        = [0, 0, 0, 1, 1, 0, 1, 1]
            # 1 - p_select                    = [1, 1, 1, 0, 0, 1, 0, 0]
            # exclusive_cumprod(1 - p_select) = [1, 1, 1, 1, 0, 0, 0, 0]
            # attention: product of above     = [0, 0, 0, 1, 0, 0, 0, 0]
        return attention


class MoChA(MonotonicAttention):
    def __init__(self, chunk_size=3):
        """
        [Monotonic Chunkwise Attention] from
        "Monotonic Chunkwise Attention" (ICLR 2018)
        https://openreview.net/forum?id=Hko85plCW
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.chunk_energy = Energy()
        self.softmax = nn.Softmax(dim=1)

    def moving_sum(self, x, back, forward):
        """Parallel moving sum via 1D convolution"""
        # Pad the window before applying the convolution
        # [batch_size, back + sequence_length + forward]
        x_padded = F.pad(x, pad=[back, forward])

        # Fake channel dimension for conv1d
        # [batch_size, 1, back + sequence_length + forward]
        x_padded = x_padded.unsqueeze(1)

        # Apply conv1d with a filter of all ones to get the moving sum
        filters = torch.ones(1, 1, back + forward + 1)
        if torch.cuda.is_available():
            filters = filters.cuda()
        x_sum = F.conv1d(x_padded, filters)

        # Remove the fake channel dimension
        # [batch_size, sequence_length]
        return x_sum.squeeze(1)

    def chunkwise_attention_soft(self, alpha, u):
        """
        Args:
            alpha [batch_size, sequence_length]: emission probability in monotonic attention
            u [batch_size, sequence_length]: chunk energy
        Return:
            beta [batch_size, sequence_length]: MoChA weights
        """

        # Numerical stability:
        # subtracting the row-wise max doesn't change the softmax
        u = u - torch.max(u, dim=1, keepdim=True)[0]
        exp_u = torch.exp(u)
        # Lower-bound the exponentiated energies to avoid division by zero
        exp_u = torch.clamp(exp_u, min=1e-5)

        # Moving sum:
        # zero-pad (chunk_size - 1) on the left + 1D conv with a filter of ones
        # [batch_size, sequence_length]
        denominators = self.moving_sum(exp_u,
                                       back=self.chunk_size - 1, forward=0)

        # Compute beta (MoChA weights)
        beta = exp_u * self.moving_sum(alpha / denominators,
                                       back=0, forward=self.chunk_size - 1)
        return beta

    def chunkwise_attention_hard(self, monotonic_attention, chunk_energy):
        """
        Mask the non-attended area with '-inf'
        Args:
            monotonic_attention [batch_size, sequence_length]
            chunk_energy [batch_size, sequence_length]
        Return:
            masked_energy [batch_size, sequence_length]
        """
        batch_size, sequence_length = monotonic_attention.size()

        # [batch_size]
        attended_indices = monotonic_attention.nonzero().cpu().data[:, 1].tolist()

        # Build a sparse 0/1 mask covering the chunk window
        # [attended_idx - chunk_size + 1, attended_idx] for each example
        i = [[], []]
        total_i = 0
        for batch_i, attended_idx in enumerate(attended_indices):
            for window in range(self.chunk_size):
                if attended_idx - window >= 0:
                    i[0].append(batch_i)
                    i[1].append(attended_idx - window)
                    total_i += 1
        i = torch.LongTensor(i)
        v = torch.FloatTensor([1] * total_i)
        mask = torch.sparse.FloatTensor(i, v, monotonic_attention.size()).to_dense()
        if torch.cuda.is_available():
            mask = mask.cuda()
        # Invert: mask out every position outside the chunk window
        mask = mask == 0

        # Fill the masked positions with '-inf' energy before the softmax
        masked_energy = chunk_energy.masked_fill_(mask, -float('inf'))
        return masked_energy

    def soft(self, encoder_outputs, decoder_h, previous_alpha=None):
        """
        Soft monotonic chunkwise attention (Train)
        Args:
            encoder_outputs [batch_size, sequence_length, enc_dim]
            decoder_h [batch_size, dec_dim]
            previous_alpha [batch_size, sequence_length]
        Return:
            alpha [batch_size, sequence_length]
            beta [batch_size, sequence_length]
        """
        alpha = super().soft(encoder_outputs, decoder_h, previous_alpha)
        chunk_energy = self.chunk_energy(encoder_outputs, decoder_h)
        beta = self.chunkwise_attention_soft(alpha, chunk_energy)
        return alpha, beta

    def hard(self, encoder_outputs, decoder_h, previous_attention=None):
        """
        Hard monotonic chunkwise attention (Test)
        Args:
            encoder_outputs [batch_size, sequence_length, enc_dim]
            decoder_h [batch_size, dec_dim]
            previous_attention [batch_size, sequence_length]
        Return:
            monotonic_attention [batch_size, sequence_length]: hard alpha
            chunkwise_attention [batch_size, sequence_length]: hard beta
        """
        # Hard attention (one-hot)
        # [batch_size, sequence_length]
        monotonic_attention = super().hard(encoder_outputs, decoder_h, previous_attention)
        chunk_energy = self.chunk_energy(encoder_outputs, decoder_h)
        masked_energy = self.chunkwise_attention_hard(monotonic_attention, chunk_energy)
        chunkwise_attention = self.softmax(masked_energy)
        return monotonic_attention, chunkwise_attention

--------------------------------------------------------------------------------
/decoder.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from attention import MoChA


class MoChADecoder(nn.Module):
    def __init__(self, enc_dim=10, dec_dim=10, embedding_dim=10, att_dim=10,
                 out_dim=10, vocab_size=100, chunk_size=3):
        """RNN Decoder with Monotonic Chunkwise Attention"""
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.cell = nn.RNNCell(embedding_dim, dec_dim)
        self.attention = MoChA(chunk_size)

        # Attentional vector (Luong et al.,
        # "Effective Approaches to Attention-based Neural Machine Translation", EMNLP 2015)
        self.combine_c_h = nn.Linear(enc_dim + dec_dim, out_dim, bias=False)
        self.tanh = nn.Tanh()
        self.proj_vocab = nn.Linear(out_dim, vocab_size)

    def init_x(self, batch_size, sos_id=1):
        if torch.cuda.is_available():
            return torch.cuda.LongTensor([sos_id] * batch_size)
        else:
            return torch.LongTensor([sos_id] * batch_size)

    def init_h(self, batch_size, dec_dim=10):
        if torch.cuda.is_available():
            return torch.cuda.FloatTensor(batch_size, dec_dim).normal_()
        else:
            return torch.Tensor(batch_size, dec_dim).normal_()

    def forward_train(self, encoder_outputs, decoder_inputs):
        """
        Args:
            encoder_outputs [batch_size, enc_sequence_length, enc_dim]
            decoder_inputs [batch_size, dec_sequence_length]
        Return:
            logits: [batch_size, dec_sequence_length, vocab_size]
        """
        batch_size, enc_sequence_length, enc_dim = encoder_outputs.size()
        batch_size, dec_sequence_length = decoder_inputs.size()

        x = self.init_x(batch_size)
        h = self.init_h(batch_size)
        alpha = None
        logit_list = []
        for i in range(dec_sequence_length):
            x = self.embedding(x)
            h = self.cell(x, h)

            # alpha: [batch_size, sequence_length]
            # beta: [batch_size, sequence_length]
            alpha, beta = self.attention.soft(encoder_outputs, h, alpha)

            # Weighted sum of encoder outputs
            # [batch_size, enc_dim]
            context = torch.sum(beta.unsqueeze(-1) * encoder_outputs, dim=1)

            # [batch_size, out_dim]
            attentional = self.tanh(self.combine_c_h(torch.cat([context, h], dim=1)))

            # [batch_size, vocab_size]
            logit = self.proj_vocab(attentional)
            logit_list.append(logit)

            # Teacher forcing: feed the ground-truth token as the next input
            x = decoder_inputs[:, i]

        return torch.stack(logit_list, dim=1)

    def forward_test(self, encoder_outputs, max_dec_length=20):
        """
        Args:
            encoder_outputs [batch_size, enc_sequence_length, enc_dim]
            max_dec_length (int; default=20)
        Return:
            outputs: [batch_size, max_dec_length]
        """
        batch_size, enc_sequence_length, enc_dim = encoder_outputs.size()

        x = self.init_x(batch_size)
        h = self.init_h(batch_size)
        monotonic_attention = None
        output_list = []
        for i in range(max_dec_length):
            x = self.embedding(x)
            h = self.cell(x, h)

            # monotonic_attention (one-hot): [batch_size, sequence_length]
            # chunkwise_attention (nonzero within the chunk): [batch_size, sequence_length]
            monotonic_attention, chunkwise_attention = self.attention.hard(
                encoder_outputs, h, monotonic_attention)

            # Weighted sum of encoder outputs
            # [batch_size, enc_dim]
            context = torch.sum(chunkwise_attention.unsqueeze(-1) * encoder_outputs, dim=1)

            # [batch_size, out_dim]
            attentional = self.tanh(self.combine_c_h(torch.cat([context, h], dim=1)))

            # [batch_size, vocab_size]
            logit = self.proj_vocab(attentional)

            # Greedy decoding
            # [batch_size]
            x = torch.max(logit, dim=1)[1]
            output_list.append(x)

        return torch.stack(output_list, dim=1)

--------------------------------------------------------------------------------
/imgs/attention_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-min/MoChA-pytorch/94b54a7fa13e4ac6dc255b509dd0febc8c0a0ee6/imgs/attention_img.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import unittest

import torch

from decoder import MoChADecoder


class MoChATest(unittest.TestCase):

    def setUp(self):
        self.batch_size = 5
        self.sequence_length = 40
        self.chunk_size = 3
        self.dim = 10
        self.vocab_size = 100

    def test_soft(self):
        """Soft Monotonic Chunkwise Attention"""

        enc_outputs = torch.Tensor(
            self.batch_size, self.sequence_length, self.dim).normal_()
        dec_inputs = torch.LongTensor(
            self.batch_size, self.sequence_length).random_(0, self.vocab_size)
        decoder = MoChADecoder(vocab_size=self.vocab_size, chunk_size=self.chunk_size)

        if torch.cuda.is_available():
            enc_outputs = enc_outputs.cuda()
            dec_inputs = dec_inputs.cuda()
            decoder = decoder.cuda()

        decoder.forward_train(enc_outputs, dec_inputs)

    def test_hard(self):
        """Hard Monotonic Chunkwise Attention"""

        enc_outputs = torch.Tensor(
            self.batch_size, self.sequence_length, self.dim).normal_()
        decoder = MoChADecoder(
            vocab_size=self.vocab_size, chunk_size=self.chunk_size)

        if torch.cuda.is_available():
            enc_outputs = enc_outputs.cuda()
            decoder = decoder.cuda()

        decoder.forward_test(enc_outputs)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
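
The README's "Linear Time Decoding" section notes that the authors' TF implementation computes energies over the whole encoder output rather than scanning forward from the previously attended position. The sketch below is not part of the repository: the function name `linear_time_hard_attention`, the per-example Python loop, and the demo inputs are illustrative assumptions. It shows one way such a scan could look on top of this repo's `MonotonicAttention` / `Energy` modules: for each example, energies are evaluated one encoder position at a time starting from the previously attended index, and the scan stops at the first position whose energy crosses the hard-sigmoid threshold. Because the scan pointer only moves forward, the total number of energy evaluations over a whole decode is linear in the encoder length plus the number of decoder steps.

import torch

from attention import MonotonicAttention


def linear_time_hard_attention(attention, encoder_outputs, decoder_h, previous_indices):
    """
    Args:
        attention: a MonotonicAttention (or MoChA) module
        encoder_outputs: [batch_size, sequence_length, enc_dim]
        decoder_h: [batch_size, dec_dim]
        previous_indices: previously attended index per example
                          (index 0 on the first decoder step, as in MonotonicAttention.hard)
    Return:
        indices: newly attended index per example
        hard_attention: [batch_size, sequence_length] one-hot attention
    """
    batch_size, sequence_length, _ = encoder_outputs.size()
    hard_attention = encoder_outputs.new_zeros(batch_size, sequence_length)
    indices = []
    for b, start in enumerate(previous_indices):
        attended_index = sequence_length - 1  # fall back to the last encoder output
        for j in range(start, sequence_length):
            # Energy for a single encoder position => constant work per scanned frame
            energy = attention.monotonic_energy(
                encoder_outputs[b:b + 1, j:j + 1], decoder_h[b:b + 1])
            if energy.item() > 0:  # sigmoid(energy) > 0.5 => attend here
                attended_index = j
                break
        indices.append(attended_index)
        hard_attention[b, attended_index] = 1
    return indices, hard_attention


if __name__ == '__main__':
    attention = MonotonicAttention()          # default enc_dim = dec_dim = att_dim = 10
    encoder_outputs = torch.randn(2, 40, 10)
    decoder_h = torch.randn(2, 10)
    print(linear_time_hard_attention(attention, encoder_outputs, decoder_h, [0, 0]))

In a full MoChA decoder one would additionally compute chunk energies, but only over the `chunk_size` positions ending at the attended index, keeping the per-step cost constant.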