├── LICENSE
├── README.md
├── figs
│   ├── 115k_without_sma.png
│   ├── 125k_with_sma.png
│   └── figure1.png
├── hparams.py
└── sma.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Keon Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stepwise_Monotonic_Multihead_Attention
2 |
3 | PyTorch implementation of Stepwise Monotonic Multihead Attention (SMA), similar to [Enhancing Monotonicity for Robust Autoregressive Transformer TTS](https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1751.pdf)
4 |
5 |
6 | ![Stepwise Monotonic Multihead Attention](figs/figure1.png)
7 |
8 |
9 | # Example Results
10 |
11 | You can apply SMA to align a mel-spectrogram with text so that the output matches the text length. Below are some results showing the effectiveness of SMA. The first figure shows the alignment without SMA (`hp.sma_tunable=False`) at 115k steps; the second shows the alignment with SMA tuning (`hp.sma_tunable=True`) at 125k steps.
12 |
13 |
14 | ![Alignment at 115k steps without SMA](figs/115k_without_sma.png)
15 |
16 | ![Alignment at 125k steps with SMA tuning](figs/125k_with_sma.png)
17 |
18 | As shown above, the alignment becomes considerably stronger than with normal multihead attention once SMA tuning is applied.
19 |
20 |
21 | # Usage
22 | First, define the SMA module. Let's say we have a 256-dimensional encoding and 4 attention heads.
23 | ```python
24 | from sma import StepwiseMonotonicMultiheadAttention
25 |
26 | ref_attention = StepwiseMonotonicMultiheadAttention(256, 256//4, 256//4)
27 | ```
28 | Then you can apply the attention and obtain an alignment as follows. `mel_len` is the frame length of the reference audio, and `seq_len` is the length of the input text (usually a sequence of phonemes). `fr_max` is the maximum focus-rate value returned by the `focused_head()` function. Both `text_mask` and `attn_mask` mark positions to be masked out with `1` (True) and positions to keep with `0` (False); see the mask-construction sketch below.
29 | ```python
30 | """
31 | enc_out --- [batch, seq_len, 256]
32 | attn --- [batch, seq_len, mel_len]
33 | enc_text --- [batch, seq_len, 256]
34 | enc_audio --- [batch, mel_len, 256]
35 | text_mask --- [batch, seq_len, 1]
36 | attn_mask --- [batch, seq_len, mel_len]
37 | """
38 |
39 | # Attention
40 | enc_out, attn, fr_max = ref_attention(enc_text, enc_audio, enc_audio,\
41 | mel_len, mask=attn_mask, query_mask=text_mask)
42 | ```
43 | As shown, SMA returns the text-audio fusion at the text length (`seq_len`), regardless of the audio length (`mel_len`).
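
For reference, below is a minimal, self-contained sketch of how the inputs and masks could be prepared from padded batches. The `get_mask_from_lengths` helper and the dummy tensors are illustrative assumptions, not part of this repository.

```python
import torch

from sma import StepwiseMonotonicMultiheadAttention

# Hypothetical helper (not in this repo): True/1 marks padded positions to mask out.
def get_mask_from_lengths(lengths, max_len):
    ids = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # [1, max_len]
    return ids >= lengths.unsqueeze(1)                               # [batch, max_len], bool

batch, d_model, n_head = 2, 256, 4
seq_len, mel_len_max = 10, 50

# Dummy encodings standing in for the text encoder / reference encoder outputs.
enc_text = torch.randn(batch, seq_len, d_model)       # [batch, seq_len, 256]
enc_audio = torch.randn(batch, mel_len_max, d_model)  # [batch, mel_len, 256]
text_len = torch.tensor([10, 7])
mel_len = torch.tensor([50, 32])

text_mask = get_mask_from_lengths(text_len, seq_len).unsqueeze(-1)                 # [batch, seq_len, 1]
attn_mask = text_mask | get_mask_from_lengths(mel_len, mel_len_max).unsqueeze(1)   # [batch, seq_len, mel_len]

ref_attention = StepwiseMonotonicMultiheadAttention(d_model, d_model // n_head, d_model // n_head)
enc_out, attn, fr_max = ref_attention(enc_text, enc_audio, enc_audio,
                                      mel_len, mask=attn_mask, query_mask=text_mask)
print(enc_out.shape, attn.shape)  # torch.Size([2, 10, 256]) torch.Size([2, 10, 50])
```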
44 |
45 | # Notes
46 |
47 | 1. `hp.sma_tunable` is the hyperparameter that toggles the tuning scheme of stepwise monotonic multihead attention. If set to `True`, stepwise monotonic multihead attention is activated; otherwise, it acts as normal multihead attention, just as in Transformer. As in [Enhancing Monotonicity for Robust Autoregressive Transformer TTS](https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1751.pdf) (referred to as the 'reference paper' below), you may, for example, train the module without SMA for a certain number of steps for faster training and convergence, and then activate SMA by setting `sma_tunable=True` to obtain a strongly monotonic alignment within a few additional steps (see the sketch after this list).
48 | 2. `expectation()` is the function that computes the stepwise monotonic expectation score, denoted as `alpha` in the reference paper.
49 | 3. In the current implementation, the query comes from the text encoding (the output of the `encoder` in a general TTS framework), and the key and value come from the mel-spectrogram encoding (the output of a `reference encoder` in a general mel-spectrogram encoding framework, e.g., the reference encoder in the GST scheme). As a result, the current SMA module converts the mel-spectrogram encoding from the length of the mel-spectrogram to the length of the text. You MUST carefully adjust the dimensions (especially in the `expectation` function) of the query, key, and value depending on the task.
50 | 4. During the tuning phase (monotonic enhancement) with SMA, the `focused_head` function selects the most diagonal (monotonically increasing) alignment among the heads. It follows the 'focus rate' from the [FastSpeech](https://arxiv.org/pdf/1905.09263.pdf) framework, as in the reference paper. Unlike the reference paper, the head with the maximum focus rate is selected rather than applying a threshold; however, you can adopt thresholding by simply adding a `prefixed_threshold` (e.g., `0.5`) to the `focused_head` function.
51 | 5. Enjoy the code; any suggestions are appreciated.
52 |
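As a concrete illustration of note 1, here is a minimal sketch of the two-stage scheme. Flipping the module's `is_tunable` attribute at runtime has the same effect as constructing it with `hp.sma_tunable=True`; the dummy tensors and the point at which tuning is enabled are assumptions for illustration.

```python
import torch

from sma import StepwiseMonotonicMultiheadAttention

ref_attention = StepwiseMonotonicMultiheadAttention(256, 64, 64)  # hp.sma_tunable=False by default

enc_text = torch.randn(2, 10, 256)   # dummy text encoding
enc_audio = torch.randn(2, 50, 256)  # dummy mel-spectrogram encoding
mel_len = torch.tensor([50, 50])

# Stage 1: plain multihead attention for faster training and convergence.
out, attn, fr_max = ref_attention(enc_text, enc_audio, enc_audio, mel_len)

# Stage 2: after enough steps, switch on stepwise monotonic tuning to
# enforce a strongly monotonic alignment.
ref_attention.is_tunable = True
out, attn, fr_max = ref_attention(enc_text, enc_audio, enc_audio, mel_len)
```
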
53 | # Citation
54 |
55 | ```
56 | @misc{lee2021sma,
57 | author = {Lee, Keon},
58 | title = {Stepwise_Monotonic_Multihead_Attention},
59 | year = {2021},
60 | publisher = {GitHub},
61 | journal = {GitHub repository},
62 | howpublished = {\url{https://github.com/keonlee9420/Stepwise_Monotonic_Multihead_Attention}}
63 | }
64 | ```
65 |
66 | # References
67 |
68 | - [Online and Linear-Time Attention by Enforcing Monotonic Alignments](https://arxiv.org/pdf/1704.00784.pdf)
69 | - [Robust Sequence-to-Sequence Acoustic Modeling with Stepwise Monotonic
70 | Attention for Neural TTS](https://arxiv.org/pdf/1906.00672.pdf) [[author's code](https://gist.github.com/mutiann/38a7638f75c21479582d7391490df37c)]
71 | - [Monotonic Multihead Attention](https://arxiv.org/pdf/1909.12406.pdf)
72 | - [Enhancing Monotonicity for Robust Autoregressive Transformer TTS](https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1751.pdf)
73 | - [hirofumi0810's neural_sp](https://github.com/hirofumi0810/neural_sp) implementation of monotonic (multihead) chunkwise attention
74 |
--------------------------------------------------------------------------------
/figs/115k_without_sma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Stepwise_Monotonic_Multihead_Attention/3278cb54f2b923b57a9d6db2304d9cfb2e1328e1/figs/115k_without_sma.png
--------------------------------------------------------------------------------
/figs/125k_with_sma.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Stepwise_Monotonic_Multihead_Attention/3278cb54f2b923b57a9d6db2304d9cfb2e1328e1/figs/125k_with_sma.png
--------------------------------------------------------------------------------
/figs/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keonlee9420/Stepwise_Monotonic_Multihead_Attention/3278cb54f2b923b57a9d6db2304d9cfb2e1328e1/figs/figure1.png
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | """
2 | You may change these hyperparameters depending on the task.
3 | """
4 | sma_head = 4
5 | sma_dropout = 0.1
6 | sma_tunable = False # If True, the stepwise monotonic multihead attention is activated. Else, it is a normal multihead attention just like in Transformer.
--------------------------------------------------------------------------------
/sma.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import hparams as hp
6 |
7 |
8 | class StepwiseMonotonicMultiheadAttention(nn.Module):
9 | """ Stepwise Monotonic Multihead Attention
10 | args:
11 | n_head (int): number of monotonic attention heads
12 | d_model (int): dimension of model (attention)
13 | d_k (int): dimension of key
14 | d_v (int): dimension of value
15 | noise_std (float): standard deviation for input noise
16 | dropout (float): dropout probability for attention weights
17 | """
18 |
19 | def __init__(self, d_model, d_k, d_v,
20 | noise_std=1.0,
21 | n_head=hp.sma_head,
22 | dropout=hp.sma_dropout,
23 | is_tunable=hp.sma_tunable):
24 | super(StepwiseMonotonicMultiheadAttention, self).__init__()
25 | self.n_head = n_head
26 | self.noise_std = noise_std
27 | self.energy = MultiheadEnergy(n_head, d_model, d_k, d_v)
28 |
29 | self.dropout = nn.Dropout(dropout)
30 | self.last_layer = nn.Linear(n_head*d_v, d_model)
31 | self.layer_norm = nn.LayerNorm(d_model)
32 |
33 | self.is_tunable = is_tunable
34 |
35 | def add_gaussian_noise(self, xs, std):
36 | """Add Gaussian noise to encourage discreteness."""
37 | noise = xs.new_zeros(xs.size()).normal_(std=std)
38 | return xs + noise
39 |
40 | def expectation(self, e, aw_prev, n_head):
41 | """
42 | e --- [batch*n_head, qlen, klen]
43 | aw_prev --- [batch*n_head, qlen, 1]
44 | Recurrence over mel frames i: alpha_i[q] = alpha_{i-1}[q] * p_i[q] + alpha_{i-1}[q-1] * (1 - p_i[q-1]). See https://gist.github.com/mutiann/38a7638f75c21479582d7391490df37c
45 | See https://github.com/hirofumi0810/neural_sp/blob/093bfade110d5a15a4f7a58fffe8d235acbfe14f/neural_sp/models/modules/mocha.py#L430
46 | """
47 | batch_size, qlen, klen = aw_prev.size(0)//n_head, e.size(1), e.size(2)
48 |
49 | # Compute probability sampling matrix P
50 | p_sample = torch.sigmoid(self.add_gaussian_noise(e, self.noise_std) if self.training else e) # [batch*n_head, qlen, klen]
51 |
52 | alpha = []
53 | # Compute recurrence relation solution along mel frame domain
54 | for i in range(klen):
55 | p_sample_i = p_sample[:, :, i:i + 1]
56 | pad = torch.zeros([batch_size*n_head, 1, 1], dtype=aw_prev.dtype).to(aw_prev.device)
57 | aw_prev = aw_prev * p_sample_i + torch.cat(
58 | (pad, aw_prev[:, :-1, :] * (1.0 - p_sample_i[:, :-1, :])), dim=1)
59 | alpha.append(aw_prev)
60 |
61 | alpha = torch.cat(alpha, dim=-1) if klen > 1 else alpha[-1] # [batch*n_head, qlen, klen]
62 |
63 | assert not torch.isnan(alpha).any(), "NaN detected in alpha."
64 |
65 | return alpha, p_sample
66 |
67 | def focused_head(self, multihead, mel_len):
68 | """
69 | Apply focus rate to select the best diagonal head.
70 | multihead --- [batch*n_heads, seq_len, mel_len]
71 | mel_len --- [batch,]
72 | return --- [batch, seq_len, mel_len]
73 | """
74 | # [batch*n_heads, seq_len, mel_len] -> [batch, n_heads, seq_len, mel_len]
75 | multihead = multihead.reshape(self.n_head, -1, multihead.size(1), multihead.size(2)).transpose(0, 1)
76 | focus_rate = torch.max(multihead, dim=2)[0].sum(dim=-1)/(mel_len.unsqueeze(1)) # [batch, n_heads]
77 | h_idx = torch.argmax(focus_rate, dim=1) # [batch,]
78 | batch=list()
79 | fr_max=0
80 | for b, fr, i in zip(multihead, focus_rate, h_idx):
81 | batch.append(b[i])
82 | fr_max += fr[i].detach().item()
83 | return torch.stack(batch), fr_max/h_idx.size(0)
84 |
85 | def repeat_mask_multihead(self, mask):
86 | """
87 | Repeat mask over multihead.
88 | mask --- [batch, qlen, klen]
89 | return --- [batch*n_head, qlen, klen]
90 | """
91 | return mask.repeat(self.n_head, 1, 1)
92 |
93 | def forward(self, q, k, v, mel_len, mask=None, query_mask=None, aw_prev=None):
94 | batch_size, qlen, klen = q.size(0), q.size(1), k.size(1)
95 | if mask is not None:
96 | mask = self.repeat_mask_multihead(mask)
97 |
98 | # Calculate energy
99 | e, v = self.energy(q, k, v, mask) # [batch*n_head, qlen, klen], [batch*n_head, klen, d_v]
100 |
101 | # Get alpha
102 | alpha_cv = F.softmax(e, dim=-1) # [batch*n_head, qlen, klen]
103 |
104 | # Masking to ignore padding (query side)
105 | if query_mask is not None:
106 | query_mask = self.repeat_mask_multihead(query_mask.repeat(1, 1, klen))
107 | alpha_cv = alpha_cv.masked_fill(query_mask, 0.)
108 |
109 | # Get focused alpha
110 | alpha_fc, fr_max = self.focused_head(alpha_cv, mel_len) # [batch, qlen, klen]
111 |
112 | if self.is_tunable:
113 | # Monotonic enhancement
114 | if aw_prev is None:
115 | aw_prev = k.new_zeros(batch_size, qlen, 1) # [batch, qlen, 1]
116 | aw_prev[:, 0:1] = k.new_ones(batch_size, 1, 1) # initialize with [1, 0, 0 ... 0]
117 | alpha_me, _ = self.expectation(alpha_fc, aw_prev, 1) # [batch, qlen, klen]
118 |
119 | # Calculate context vector
120 | v = v.reshape(self.n_head, batch_size, klen, -1).permute(1, 2, 0, 3) # [batch, klen, n_head, d_v]
121 | cv = torch.bmm(alpha_me, v.reshape(batch_size, klen, -1)) # [batch, qlen, n_head*d_v]
122 | else:
123 | # Calculate normal multihead attention
124 | cv = torch.bmm(alpha_cv, v).reshape(self.n_head, batch_size, qlen, -1).permute(1, 2, 0, 3) # [batch, qlen, n_head, d_v]
125 | cv = cv.reshape(batch_size, qlen, -1) # [batch, qlen, n_head*d_v]
126 |
127 | cv = self.dropout(self.last_layer(cv))
128 | cv = self.layer_norm(cv)
129 | return cv, alpha_fc, fr_max
130 |
131 |
132 | class MultiheadEnergy(nn.Module):
133 | """ Energy function for the (monotonic) multihead attention """
134 |
135 | def __init__(self, n_head, d_model, d_k, d_v):
136 | super(MultiheadEnergy, self).__init__()
137 |
138 | self.n_head = n_head
139 | self.d_k = d_k
140 | self.d_v = d_v
141 |
142 | self.w_qs = nn.Linear(d_model, n_head * d_k)
143 | self.w_ks = nn.Linear(d_model, n_head * d_k)
144 | self.w_vs = nn.Linear(d_model, n_head * d_v)
145 |
146 | self.temperature = np.power(d_k, 0.5)
147 |
148 | def scaled_dot_product(self, q, k):
149 | sdp = torch.bmm(q, k.transpose(1, 2)) # (n*b) x lq x lk
150 | sdp = sdp / self.temperature
151 | return sdp
152 |
153 | def forward(self, q, k, v, mask=None):
154 |
155 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
156 |
157 | sz_b, len_q, _ = q.size()
158 | sz_b, len_k, _ = k.size()
159 | sz_b, len_v, _ = v.size()
160 |
161 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
162 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
163 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
164 | q = q.permute(2, 0, 1, 3).contiguous().view(-1,
165 | len_q, d_k) # (n*b) x lq x dk
166 | k = k.permute(2, 0, 1, 3).contiguous().view(-1,
167 | len_k, d_k) # (n*b) x lk x dk
168 | v = v.permute(2, 0, 1, 3).contiguous().view(-1,
169 | len_v, d_v) # (n*b) x lv x dv
170 |
171 | # Compute monotonic multihead energy
172 | e = self.scaled_dot_product(q, k) # (n*b) x lq x lk
173 |
174 | # Masking to ignore padding
175 | if mask is not None:
176 | NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
177 | e = e.masked_fill(mask, NEG_INF)
178 |
179 | return e, v
180 |
--------------------------------------------------------------------------------