├── speech_transformer
│   ├── __init__.py
│   ├── mask.py
│   ├── embeddings.py
│   ├── sublayers.py
│   ├── layers.py
│   ├── modules.py
│   ├── encoder.py
│   ├── attention.py
│   ├── model.py
│   ├── convolution.py
│   ├── decoder.py
│   └── beam_decoder.py
├── LICENSE
└── README.md
/speech_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SpeechTransformer
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Soohwan Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/speech_transformer/mask.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from torch import Tensor
18 |
19 |
20 | def get_attn_pad_mask(inputs, input_lengths, expand_length):
21 | """ mask position is set to 1 """
22 | def get_transformer_non_pad_mask(inputs: Tensor, input_lengths: Tensor) -> Tensor:
23 | """ Padding position is set to 0, either use input_lengths or pad_id """
24 | batch_size = inputs.size(0)
25 |
26 | if len(inputs.size()) == 2:
27 | non_pad_mask = inputs.new_ones(inputs.size()) # B x T
28 | elif len(inputs.size()) == 3:
29 | non_pad_mask = inputs.new_ones(inputs.size()[:-1]) # B x T
30 | else:
31 | raise ValueError(f"Unsupported input shape {inputs.size()}")
32 |
33 | for i in range(batch_size):
34 | non_pad_mask[i, input_lengths[i]:] = 0
35 |
36 | return non_pad_mask
37 |
38 | non_pad_mask = get_transformer_non_pad_mask(inputs, input_lengths)
39 | pad_mask = non_pad_mask.lt(1)
40 | attn_pad_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
41 | return attn_pad_mask
42 |
43 |
44 | def get_attn_subsequent_mask(seq):
45 | assert seq.dim() == 2
46 | attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
47 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)
48 |
49 | if seq.is_cuda:
50 | subsequent_mask = subsequent_mask.cuda()
51 |
52 | return subsequent_mask
53 |
--------------------------------------------------------------------------------
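A minimal usage sketch (illustrative only, not part of the repository) of the two masking helpers above; the 80-dim feature batch and the lengths are assumed example values:

```python
import torch
from speech_transformer.mask import get_attn_pad_mask, get_attn_subsequent_mask

inputs = torch.rand(2, 6, 80)              # B x T x D acoustic features
input_lengths = torch.IntTensor([6, 4])    # valid length of each sequence

pad_mask = get_attn_pad_mask(inputs, input_lengths, expand_length=6)
print(pad_mask.shape)          # torch.Size([2, 6, 6]); True marks padded key positions

targets = torch.LongTensor([[1, 3, 3, 2], [1, 3, 2, 0]])
subsequent_mask = get_attn_subsequent_mask(targets)
print(subsequent_mask.shape)   # torch.Size([2, 4, 4]); upper triangle masks future steps
```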
/speech_transformer/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import torch.nn as nn
18 | from torch import Tensor
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 | """
23 | Positional Encoding proposed in "Attention Is All You Need".
24 | Since speech_transformer contains no recurrence and no convolution, in order for the model to make
25 | use of the order of the sequence, we must add some positional information.
26 |
27 | "Attention Is All You Need" use sine and cosine functions of different frequencies:
28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model))
29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
30 | """
31 | def __init__(self, d_model: int = 512, max_len: int = 5000) -> None:
32 | super(PositionalEncoding, self).__init__()
33 | pe = torch.zeros(max_len, d_model, requires_grad=False)
34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
36 | pe[:, 0::2] = torch.sin(position * div_term)
37 | pe[:, 1::2] = torch.cos(position * div_term)
38 | pe = pe.unsqueeze(0)
39 | self.register_buffer('pe', pe)
40 |
41 | def forward(self, length: int) -> Tensor:
42 | return self.pe[:, :length]
43 |
44 |
45 | class Embedding(nn.Module):
46 | """
47 |     Embedding layer. Similar to other sequence transduction models, speech_transformer uses learned embeddings
48 |     to convert the input tokens and output tokens to vectors of dimension d_model.
49 |     In the embedding layers, speech_transformer multiplies those weights by sqrt(d_model).
50 | """
51 | def __init__(self, num_embeddings: int, pad_id: int, d_model: int = 512) -> None:
52 | super(Embedding, self).__init__()
53 | self.sqrt_dim = math.sqrt(d_model)
54 | self.embedding = nn.Embedding(num_embeddings, d_model, padding_idx=pad_id)
55 |
56 | def forward(self, inputs: Tensor) -> Tensor:
57 | return self.embedding(inputs) * self.sqrt_dim
58 |
--------------------------------------------------------------------------------
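A short sketch (illustrative only, not part of the repository) of the two embedding modules above; the vocabulary size and token IDs are assumed example values:

```python
import torch
from speech_transformer.embeddings import Embedding, PositionalEncoding

pos_encoding = PositionalEncoding(d_model=512)
print(pos_encoding(100).shape)        # torch.Size([1, 100, 512]); broadcast over the batch

embedding = Embedding(num_embeddings=10, pad_id=0, d_model=512)
tokens = torch.LongTensor([[1, 3, 3, 2]])
print(embedding(tokens).shape)        # torch.Size([1, 4, 512]); scaled by sqrt(d_model)
```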
/README.md:
--------------------------------------------------------------------------------
1 | # Speech-Transformer
2 |
3 | PyTorch implementation of [The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition](https://ieeexplore.ieee.org/document/8682586).
4 |
5 |
6 |
7 | Speech Transformer is a transformer framework specialized in speech recognition tasks.
8 | This repository contains only the model code, but you can train the Speech Transformer with this [repository](https://github.com/sooftware/KoSpeech).
9 | I appreciate any kind of [feedback or contribution](https://github.com/sooftware/Speech-Transformer/issues).
10 |
11 | ## Usage
12 | - Training
13 | ```python
14 | import torch
15 | from speech_transformer import SpeechTransformer
16 |
17 | BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 4
18 |
19 | cuda = torch.cuda.is_available()
20 | device = torch.device('cuda' if cuda else 'cpu')
21 |
22 | inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device)
23 | input_lengths = torch.IntTensor([100, 50, 8])
24 | targets = torch.LongTensor([[2, 3, 3, 3, 3, 3, 2, 2, 1, 0],
25 | [2, 3, 3, 3, 3, 3, 2, 1, 2, 0],
26 | [2, 3, 3, 3, 3, 3, 2, 2, 0, 1]]).to(device) # 1 means
27 | target_lengths = torch.IntTensor([10, 9, 8])
28 |
29 | model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM).to(device)
30 | predictions, logits = model(inputs, input_lengths, targets, target_lengths)
31 | ```
32 | - Beam Search Decoding
33 | ```python
34 | import torch
35 | from speech_transformer import SpeechTransformer
36 |
37 | BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 10
38 |
39 | cuda = torch.cuda.is_available()
40 | device = torch.device('cuda' if cuda else 'cpu')
41 |
42 | inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device) # BxTxD
43 | input_lengths = torch.LongTensor([SEQ_LENGTH, SEQ_LENGTH - 10, SEQ_LENGTH - 20]).to(device)
44 |
45 | model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM).to(device)
46 | model.set_beam_decoder(batch_size=BATCH_SIZE, beam_size=3)
47 | predictions, _ = model(inputs, input_lengths)
48 | ```
49 |
50 | ## Troubleshoots and Contributing
51 | If you have any questions, bug reports, or feature requests, please [open an issue](https://github.com/sooftware/Jasper-pytorch/issues) on GitHub or
52 | contact sh951011@gmail.com.
53 |
54 | I appreciate any kind of feedback or contribution. Feel free to proceed with small issues like bug fixes and documentation improvements. For major contributions and new features, please discuss with the collaborators in the corresponding issues.
55 |
56 | ## Code Style
57 | I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. The docstring style is especially important, since it is used to generate documentation.
58 |
59 | ## Reference
60 | - [The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition (Yuanyuan Zhao et al., 2019)](https://ieeexplore.ieee.org/document/8682586)
61 | - [kaituoxu/Speech-Transformer](https://github.com/kaituoxu/Speech-Transformer)
62 |
63 | ## Author
64 |
65 | * Soohwan Kim [@sooftware](https://github.com/sooftware)
66 | * Contacts: sh951011@gmail.com
67 |
--------------------------------------------------------------------------------
/speech_transformer/sublayers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 |
17 | from torch import Tensor
18 | from typing import Any, Optional
19 | from speech_transformer.modules import (
20 | Linear,
21 | MaskConv2d,
22 | )
23 |
24 |
25 | class AddNorm(nn.Module):
26 | """
27 | Add & Normalization layer proposed in "Attention Is All You Need".
28 |     The Transformer employs a residual connection around each of the two sub-layers
29 |     (Multi-Head Attention & Feed-Forward), followed by layer normalization.
30 | """
31 | def __init__(self, sublayer: nn.Module, d_model: int = 512) -> None:
32 | super(AddNorm, self).__init__()
33 | self.sublayer = sublayer
34 | self.layer_norm = nn.LayerNorm(d_model)
35 |
36 | def forward(self, *args):
37 | residual = args[0]
38 | output = self.sublayer(*args)
39 |
40 | if isinstance(output, tuple):
41 | return self.layer_norm(output[0] + residual), output[1]
42 |
43 | return self.layer_norm(output + residual)
44 |
45 |
46 | class PositionWiseFeedForwardNet(nn.Module):
47 | """
48 | Position-wise Feedforward Networks proposed in "Attention Is All You Need".
49 | Fully connected feed-forward network, which is applied to each position separately and identically.
50 | This consists of two linear transformations with a ReLU activation in between.
51 | Another way of describing this is as two convolutions with kernel size 1.
52 | """
53 | def __init__(self, d_model: int = 512, d_ff: int = 2048,
54 | dropout_p: float = 0.3, ffnet_style: str = 'ff') -> None:
55 | super(PositionWiseFeedForwardNet, self).__init__()
56 | self.ffnet_style = ffnet_style.lower()
57 | if self.ffnet_style == 'ff':
58 | self.feed_forward = nn.Sequential(
59 | Linear(d_model, d_ff),
60 | nn.Dropout(dropout_p),
61 | nn.ReLU(),
62 | Linear(d_ff, d_model),
63 | nn.Dropout(dropout_p),
64 | )
65 |
66 | elif self.ffnet_style == 'conv':
67 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
68 | self.relu = nn.ReLU()
69 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
70 |
71 | else:
72 |             raise ValueError("Unsupported ffnet_style: {0}".format(self.ffnet_style))
73 |
74 | def forward(self, inputs: Tensor) -> Tensor:
75 | if self.ffnet_style == 'conv':
76 | output = self.conv1(inputs.transpose(1, 2))
77 | output = self.relu(output)
78 | return self.conv2(output).transpose(1, 2)
79 |
80 | return self.feed_forward(inputs)
81 |
--------------------------------------------------------------------------------
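A minimal sketch (illustrative only, not part of the repository) of wrapping the position-wise feed-forward sub-layer in AddNorm; the batch shape is an assumed example:

```python
import torch
from speech_transformer.sublayers import AddNorm, PositionWiseFeedForwardNet

d_model = 512
inputs = torch.rand(2, 10, d_model)   # B x T x d_model

# residual connection + layer normalization around the feed-forward sub-layer
block = AddNorm(PositionWiseFeedForwardNet(d_model=d_model, d_ff=2048, ffnet_style='ff'), d_model)
print(block(inputs).shape)            # torch.Size([2, 10, 512]); shape is preserved
```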
/speech_transformer/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | from torch import Tensor
17 | from typing import Tuple, Optional, Any
18 | from speech_transformer.attention import MultiHeadAttention
19 | from speech_transformer.sublayers import (
20 | PositionWiseFeedForwardNet,
21 | AddNorm,
22 | )
23 |
24 |
25 | class SpeechTransformerEncoderLayer(nn.Module):
26 | """
27 | EncoderLayer is made up of self-attention and feedforward network.
28 | This standard encoder layer is based on the paper "Attention Is All You Need".
29 |
30 | Args:
31 | d_model: dimension of model (default: 512)
32 | num_heads: number of attention heads (default: 8)
33 | d_ff: dimension of feed forward network (default: 2048)
34 | dropout_p: probability of dropout (default: 0.3)
35 | ffnet_style: style of feed forward network [ff, conv] (default: ff)
36 | """
37 |
38 | def __init__(
39 | self,
40 | d_model: int = 512, # dimension of model
41 | num_heads: int = 8, # number of attention heads
42 | d_ff: int = 2048, # dimension of feed forward network
43 | dropout_p: float = 0.3, # probability of dropout
44 | ffnet_style: str = 'ff' # style of feed forward network
45 | ) -> None:
46 | super(SpeechTransformerEncoderLayer, self).__init__()
47 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
48 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)
49 |
50 | def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
51 | output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
52 | output = self.feed_forward(output)
53 | return output, attn
54 |
55 |
56 | class SpeechTransformerDecoderLayer(nn.Module):
57 | """
58 | DecoderLayer is made up of self-attention, multi-head attention and feedforward network.
59 | This standard decoder layer is based on the paper "Attention Is All You Need".
60 |
61 | Args:
62 | d_model: dimension of model (default: 512)
63 | num_heads: number of attention heads (default: 8)
64 | d_ff: dimension of feed forward network (default: 2048)
65 | dropout_p: probability of dropout (default: 0.3)
66 | ffnet_style: style of feed forward network [ff, conv] (default: ff)
67 | """
68 |
69 | def __init__(
70 | self,
71 | d_model: int = 512, # dimension of model
72 | num_heads: int = 8, # number of attention heads
73 | d_ff: int = 2048, # dimension of feed forward network
74 | dropout_p: float = 0.3, # probability of dropout
75 | ffnet_style: str = 'ff' # style of feed forward network
76 | ) -> None:
77 | super(SpeechTransformerDecoderLayer, self).__init__()
78 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
79 | self.memory_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
80 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)
81 |
82 | def forward(
83 | self,
84 | inputs: Tensor, # tensor contains target sequence
85 | memory: Tensor, # tensor contains encoder outputs
86 | self_attn_mask: Optional[Any] = None, # tensor contains mask of self attention
87 | memory_mask: Optional[Any] = None # tensor contains mask of encoder outputs
88 | ) -> Tuple[Tensor, Tensor, Tensor]:
89 | output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
90 | output, memory_attn = self.memory_attention(output, memory, memory, memory_mask)
91 | output = self.feed_forward(output)
92 | return output, self_attn, memory_attn
93 |
--------------------------------------------------------------------------------
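A minimal sketch (illustrative only, not part of the repository) of a single encoder layer from this file; the input shape is an assumed example:

```python
import torch
from speech_transformer.layers import SpeechTransformerEncoderLayer

layer = SpeechTransformerEncoderLayer(d_model=512, num_heads=8)
inputs = torch.rand(2, 32, 512)       # B x T x d_model

outputs, attn = layer(inputs)         # self_attn_mask is optional
print(outputs.shape, attn.shape)      # torch.Size([2, 32, 512]) torch.Size([16, 32, 32])
```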
/speech_transformer/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.init as init
18 |
19 | from torch import Tensor
20 | from typing import Tuple
21 |
22 |
23 | class MaskConv2d(nn.Module):
24 | """
25 | Masking Convolutional Neural Network
26 | Adds padding to the output of the module based on the given lengths.
27 | This is to ensure that the results of the model do not change when batch sizes change during inference.
28 | Input needs to be in the shape of (batch_size, channel, hidden_dim, seq_len)
29 | Refer to https://github.com/SeanNaren/deepspeech.pytorch/blob/master/model.py
30 | Copyright (c) 2017 Sean Naren
31 | MIT License
32 | Args:
33 | sequential (torch.nn): sequential list of convolution layer
34 | Inputs: inputs, seq_lengths
35 | - **inputs** (torch.FloatTensor): The input of size BxCxHxT
36 | - **seq_lengths** (torch.IntTensor): The actual length of each sequence in the batch
37 | Returns: output, seq_lengths
38 | - **output**: Masked output from the sequential
39 | - **seq_lengths**: Sequence length of output from the sequential
40 | """
41 | def __init__(self, sequential: nn.Sequential) -> None:
42 | super(MaskConv2d, self).__init__()
43 | self.sequential = sequential
44 |
45 | def forward(self, inputs: Tensor, seq_lengths: Tensor) -> Tuple[Tensor, Tensor]:
46 | output = None
47 |
48 | for module in self.sequential:
49 | output = module(inputs)
50 | mask = torch.BoolTensor(output.size()).fill_(0)
51 |
52 | if output.is_cuda:
53 | mask = mask.cuda()
54 |
55 | seq_lengths = self.get_sequence_lengths(module, seq_lengths)
56 |
57 | for idx, length in enumerate(seq_lengths):
58 | length = length.item()
59 |
60 | if (mask[idx].size(2) - length) > 0:
61 | mask[idx].narrow(dim=2, start=length, length=mask[idx].size(2) - length).fill_(1)
62 |
63 | output = output.masked_fill(mask, 0)
64 | inputs = output
65 |
66 | return output, seq_lengths
67 |
68 | def get_sequence_lengths(self, module: nn.Module, seq_lengths: Tensor) -> Tensor:
69 | """
70 | Calculate convolutional neural network receptive formula
71 | Args:
72 | module (torch.nn.Module): module of CNN
73 | seq_lengths (torch.IntTensor): The actual length of each sequence in the batch
74 | Returns: seq_lengths
75 | - **seq_lengths**: Sequence length of output from the module
76 | """
77 | if isinstance(module, nn.Conv2d):
78 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1
79 | seq_lengths = numerator.float() / float(module.stride[1])
80 | seq_lengths = seq_lengths.int() + 1
81 |
82 | elif isinstance(module, nn.MaxPool2d):
83 | seq_lengths >>= 1
84 |
85 | return seq_lengths.int()
86 |
87 |
88 | class Linear(nn.Module):
89 | """
90 | Wrapper class of torch.nn.Linear
91 | Weight initialize by xavier initialization and bias initialize to zeros.
92 | """
93 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
94 | super(Linear, self).__init__()
95 | self.linear = nn.Linear(in_features, out_features, bias=bias)
96 | init.xavier_uniform_(self.linear.weight)
97 | if bias:
98 | init.zeros_(self.linear.bias)
99 |
100 | def forward(self, x: Tensor) -> Tensor:
101 | return self.linear(x)
102 |
103 |
104 | class Transpose(nn.Module):
105 | """ Wrapper class of torch.transpose() for Sequential module. """
106 | def __init__(self, shape: tuple):
107 | super(Transpose, self).__init__()
108 | self.shape = shape
109 |
110 | def forward(self, inputs: Tensor):
111 | return inputs.transpose(*self.shape)
112 |
--------------------------------------------------------------------------------
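A minimal sketch (illustrative only, not part of the repository) of MaskConv2d with a single assumed Conv2d layer, showing how get_sequence_lengths applies the standard length formula L_out = floor((L_in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1:

```python
import torch
import torch.nn as nn
from speech_transformer.modules import MaskConv2d

conv = MaskConv2d(nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1),
))

inputs = torch.rand(2, 1, 80, 101)          # B x C x D x T
seq_lengths = torch.IntTensor([101, 50])

outputs, output_lengths = conv(inputs, seq_lengths)
print(outputs.shape)       # torch.Size([2, 4, 40, 51])
print(output_lengths)      # tensor([51, 25], dtype=torch.int32); frames past each length are zeroed
```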
/speech_transformer/encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | from torch import Tensor
17 | from typing import Tuple, Optional, Any
18 |
19 | from speech_transformer.attention import MultiHeadAttention
20 | from speech_transformer.convolution import VGGExtractor
21 | from speech_transformer.embeddings import PositionalEncoding
22 | from speech_transformer.mask import get_attn_pad_mask
23 | from speech_transformer.modules import Linear
24 | from speech_transformer.sublayers import AddNorm, PositionWiseFeedForwardNet
25 |
26 |
27 | class SpeechTransformerEncoderLayer(nn.Module):
28 | """
29 | EncoderLayer is made up of self-attention and feedforward network.
30 | This standard encoder layer is based on the paper "Attention Is All You Need".
31 |
32 | Args:
33 | d_model: dimension of model (default: 512)
34 | num_heads: number of attention heads (default: 8)
35 | d_ff: dimension of feed forward network (default: 2048)
36 | dropout_p: probability of dropout (default: 0.3)
37 | ffnet_style: style of feed forward network [ff, conv] (default: ff)
38 | """
39 |
40 | def __init__(
41 | self,
42 | d_model: int = 512, # dimension of model
43 | num_heads: int = 8, # number of attention heads
44 | d_ff: int = 2048, # dimension of feed forward network
45 | dropout_p: float = 0.3, # probability of dropout
46 | ffnet_style: str = 'ff' # style of feed forward network
47 | ) -> None:
48 | super(SpeechTransformerEncoderLayer, self).__init__()
49 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
50 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)
51 |
52 | def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]:
53 | output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
54 | output = self.feed_forward(output)
55 | return output, attn
56 |
57 |
58 | class SpeechTransformerEncoder(nn.Module):
59 | """
60 | The TransformerEncoder is composed of a stack of N identical layers.
61 | Each layer has two sub-layers. The first is a multi-head self-attention mechanism,
62 | and the second is a simple, position-wise fully connected feed-forward network.
63 |
64 | Args:
65 | d_model: dimension of model (default: 512)
66 | input_dim: dimension of feature vector (default: 80)
67 | d_ff: dimension of feed forward network (default: 2048)
68 | num_layers: number of encoder layers (default: 6)
69 | num_heads: number of attention heads (default: 8)
70 | ffnet_style: style of feed forward network [ff, conv] (default: ff)
71 | dropout_p: probability of dropout (default: 0.3)
72 | pad_id: identification of pad token (default: 0)
73 |
74 | Inputs:
75 |         - **inputs**: tensor of padded acoustic feature sequences of shape (batch, time, dimension)
76 |         - **input_lengths**: tensor containing the valid length of each sequence in the batch
77 | """
78 |
79 | def __init__(
80 | self,
81 | d_model: int = 512,
82 | input_dim: int = 80,
83 | d_ff: int = 2048,
84 | num_layers: int = 6,
85 | num_heads: int = 8,
86 | ffnet_style: str = 'ff',
87 | dropout_p: float = 0.3,
88 | pad_id: int = 0,
89 | ) -> None:
90 | super(SpeechTransformerEncoder, self).__init__()
91 | self.d_model = d_model
92 | self.num_layers = num_layers
93 | self.num_heads = num_heads
94 | self.pad_id = pad_id
95 | self.conv = VGGExtractor(input_dim)
96 | self.input_proj = Linear(self.conv.get_output_dim(), d_model)
97 | self.input_dropout = nn.Dropout(p=dropout_p)
98 | self.positional_encoding = PositionalEncoding(d_model)
99 | self.layers = nn.ModuleList(
100 | [SpeechTransformerEncoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers)]
101 | )
102 |
103 | def forward(self, inputs: Tensor, input_lengths: Tensor = None) -> Tuple[Tensor, Tensor]:
104 | conv_outputs, output_lengths = self.conv(inputs, input_lengths)
105 |
106 | self_attn_mask = get_attn_pad_mask(conv_outputs, output_lengths, conv_outputs.size(1))
107 |
108 | outputs = self.input_proj(conv_outputs)
109 | outputs += self.positional_encoding(outputs.size(1))
110 | outputs = self.input_dropout(outputs)
111 |
112 | for layer in self.layers:
113 | outputs, attn = layer(outputs, self_attn_mask)
114 |
115 | return outputs, output_lengths
116 |
--------------------------------------------------------------------------------
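A minimal sketch (illustrative only, not part of the repository) of the encoder above; two layers and a 128-frame batch are assumed example values:

```python
import torch
from speech_transformer.encoder import SpeechTransformerEncoder

encoder = SpeechTransformerEncoder(d_model=512, input_dim=80, num_layers=2)

inputs = torch.rand(2, 128, 80)             # B x T x D log-mel style features
input_lengths = torch.IntTensor([128, 100])

outputs, output_lengths = encoder(inputs, input_lengths)
print(outputs.shape)       # torch.Size([2, 32, 512]); VGGExtractor downsamples time by 4
print(output_lengths)      # tensor([32, 25], dtype=torch.int32)
```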
/speech_transformer/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import numpy as np
19 |
20 | from speech_transformer.modules import Linear
21 | from torch import Tensor
22 | from typing import Optional, Tuple
23 |
24 |
25 | class ScaledDotProductAttention(nn.Module):
26 | """
27 | Scaled Dot-Product Attention proposed in "Attention Is All You Need"
28 | Compute the dot products of the query with all keys, divide each by sqrt(dim),
29 | and apply a softmax function to obtain the weights on the values
30 |
31 | Args: dim, mask
32 | dim (int): dimension of attention
33 | mask (torch.Tensor): tensor containing indices to be masked
34 |
35 | Inputs: query, key, value, mask
36 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
37 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
38 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
39 | - **mask** (-): tensor containing indices to be masked
40 |
41 | Returns: context, attn
42 | - **context**: tensor containing the context vector from attention mechanism.
43 | - **attn**: tensor containing the attention (alignment) from the encoder outputs.
44 | """
45 | def __init__(self, dim: int) -> None:
46 | super(ScaledDotProductAttention, self).__init__()
47 | self.sqrt_dim = np.sqrt(dim)
48 |
49 | def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
50 | score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
51 |
52 | if mask is not None:
53 | score.masked_fill_(mask, -1e9)
54 |
55 | attn = F.softmax(score, -1)
56 | context = torch.bmm(attn, value)
57 | return context, attn
58 |
59 |
60 | class MultiHeadAttention(nn.Module):
61 | """
62 | Multi-Head Attention proposed in "Attention Is All You Need"
63 | Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
64 | project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
65 | These are concatenated and once again projected, resulting in the final values.
66 | Multi-head attention allows the model to jointly attend to information from different representation
67 | subspaces at different positions.
68 |
69 | MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
70 | where head_i = Attention(Q · W_q, K · W_k, V · W_v)
71 |
72 | Args:
73 |         d_model (int): The dimension of keys / values / queries (default: 512)
74 | num_heads (int): The number of attention heads. (default: 8)
75 |
76 | Inputs: query, key, value, mask
77 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
78 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
79 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
80 | - **mask** (-): tensor containing indices to be masked
81 |
82 | Returns: output, attn
83 | - **output** (batch, output_len, dimensions): tensor containing the attended output features.
84 |         - **attn** (batch * num_heads, q_len, k_len): tensor containing the attention (alignment) weights.
85 | """
86 | def __init__(self, d_model: int = 512, num_heads: int = 8) -> None:
87 | super(MultiHeadAttention, self).__init__()
88 |
89 | assert d_model % num_heads == 0, "hidden_dim % num_heads should be zero."
90 |
91 | self.d_head = int(d_model / num_heads)
92 | self.num_heads = num_heads
93 | self.query_proj = Linear(d_model, self.d_head * num_heads)
94 | self.key_proj = Linear(d_model, self.d_head * num_heads)
95 | self.value_proj = Linear(d_model, self.d_head * num_heads)
96 | self.sqrt_dim = np.sqrt(d_model)
97 | self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
98 |
99 | def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
100 | batch_size = value.size(0)
101 |
102 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) # BxQ_LENxNxD
103 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head) # BxK_LENxNxD
104 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head) # BxV_LENxNxD
105 |
106 | query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxQ_LENxD
107 | key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxK_LENxD
108 | value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxV_LENxD
109 |
110 | if mask is not None:
111 | mask = mask.repeat(self.num_heads, 1, 1)
112 |
113 | context, attn = self.scaled_dot_attn(query, key, value, mask)
114 | context = context.view(self.num_heads, batch_size, -1, self.d_head)
115 | context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head) # BxTxND
116 |
117 | return context, attn
118 |
--------------------------------------------------------------------------------
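A minimal sketch (illustrative only, not part of the repository) of MultiHeadAttention; the query and key lengths are assumed example values:

```python
import torch
from speech_transformer.attention import MultiHeadAttention

attention = MultiHeadAttention(d_model=512, num_heads=8)

query = torch.rand(2, 10, 512)        # B x Q_LEN x d_model (e.g. decoder states)
key = value = torch.rand(2, 32, 512)  # B x K_LEN x d_model (e.g. encoder outputs)

context, attn = attention(query, key, value)
print(context.shape)   # torch.Size([2, 10, 512])
print(attn.shape)      # torch.Size([16, 10, 32]) = (batch * num_heads, q_len, k_len)
```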
/speech_transformer/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | from torch import Tensor
17 | from typing import Optional, Union
18 |
19 | from speech_transformer.beam_decoder import BeamTransformerDecoder
20 | from speech_transformer.decoder import SpeechTransformerDecoder
21 | from speech_transformer.encoder import SpeechTransformerEncoder
22 |
23 |
24 | class SpeechTransformer(nn.Module):
25 | """
26 | A Speech Transformer model. User is able to modify the attributes as needed.
27 | The model is based on the paper "Attention Is All You Need".
28 |
29 | Args:
30 |         num_classes (int): the number of classification classes
31 | d_model (int): dimension of model (default: 512)
32 | input_dim (int): dimension of input
33 |         pad_id (int): identification of the pad token
34 |         eos_id (int): identification of the end of sentence token
35 | d_ff (int): dimension of feed forward network (default: 2048)
36 | num_encoder_layers (int): number of encoder layers (default: 6)
37 | num_decoder_layers (int): number of decoder layers (default: 6)
38 | num_heads (int): number of attention heads (default: 8)
39 | dropout_p (float): dropout probability (default: 0.3)
40 |         ffnet_style (str): if ffnet_style is 'ff', the position-wise feed forward network is a feed forward layer,
41 |             otherwise, the position-wise feed forward network is a convolution layer. (default: ff)
42 |
43 |     Inputs: inputs, input_lengths, targets, target_lengths
44 |         - **inputs** (torch.Tensor): tensor of padded acoustic feature sequences of shape
45 |             (batch, time, dimension). This information is forwarded to the encoder.
46 |         - **input_lengths** (torch.Tensor): tensor containing the valid length of each input sequence.
47 | - **targets** (torch.Tensor): tensor of sequences, whose length is the batch size and within which
48 | each sequence is a list of token IDs. This information is forwarded to the decoder.
49 |
50 |     Returns: predictions, logits
51 |         - **predictions**: tensor of predicted token IDs; **logits**: log probabilities of the predictions (None when the beam decoder is used)
52 | """
53 |
54 | def __init__(
55 | self,
56 | num_classes: int,
57 | d_model: int = 512,
58 | input_dim: int = 80,
59 | pad_id: int = 0,
60 | sos_id: int = 1,
61 | eos_id: int = 2,
62 | d_ff: int = 2048,
63 | num_heads: int = 8,
64 | num_encoder_layers: int = 6,
65 | num_decoder_layers: int = 6,
66 | dropout_p: float = 0.3,
67 | ffnet_style: str = 'ff',
68 | extractor: str = 'vgg',
69 | joint_ctc_attention: bool = False,
70 | max_length: int = 128,
71 | ) -> None:
72 | super(SpeechTransformer, self).__init__()
73 |
74 | assert d_model % num_heads == 0, "d_model % num_heads should be zero."
75 |
76 | self.num_classes = num_classes
77 | self.extractor = extractor
78 | self.joint_ctc_attention = joint_ctc_attention
79 | self.sos_id = sos_id
80 | self.eos_id = eos_id
81 | self.pad_id = pad_id
82 | self.max_length = max_length
83 |
84 | self.encoder = SpeechTransformerEncoder(
85 | d_model=d_model,
86 | input_dim=input_dim,
87 | d_ff=d_ff,
88 | num_layers=num_encoder_layers,
89 | num_heads=num_heads,
90 | ffnet_style=ffnet_style,
91 | dropout_p=dropout_p,
92 | pad_id=pad_id,
93 | )
94 |
95 | self.decoder = SpeechTransformerDecoder(
96 | num_classes=num_classes,
97 | d_model=d_model,
98 | d_ff=d_ff,
99 | num_layers=num_decoder_layers,
100 | num_heads=num_heads,
101 | ffnet_style=ffnet_style,
102 | dropout_p=dropout_p,
103 | pad_id=pad_id,
104 | sos_id=sos_id,
105 |             eos_id=eos_id,
105 |             max_length=max_length,
106 | )
107 |
108 | def set_beam_decoder(self, batch_size: int = None, beam_size: int = 3):
109 | """ Setting beam search decoder """
110 | self.decoder = BeamTransformerDecoder(
111 | decoder=self.decoder,
112 | batch_size=batch_size,
113 | beam_size=beam_size,
114 | )
115 |
116 | def forward(
117 | self,
118 | inputs: Tensor,
119 | input_lengths: Tensor,
120 | targets: Optional[Tensor] = None,
121 | target_lengths: Optional[Tensor] = None,
122 | ) -> Union[Tensor, tuple]:
123 | """
124 | inputs (torch.FloatTensor): (batch_size, sequence_length, dimension)
125 | input_lengths (torch.LongTensor): (batch_size)
126 | """
127 | logits = None
128 |         encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
129 | if isinstance(self.decoder, BeamTransformerDecoder):
130 | predictions = self.decoder(encoder_outputs, encoder_output_lengths)
131 | else:
132 | logits = self.decoder(
133 | encoder_outputs=encoder_outputs,
134 | encoder_output_lengths=encoder_output_lengths,
135 | targets=targets,
136 |                 teacher_forcing_ratio=1.0,
137 | target_lengths=target_lengths,
138 | )
139 | predictions = logits.max(-1)[1]
140 |
141 | return predictions, logits
142 |
--------------------------------------------------------------------------------
/speech_transformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 |
21 | class MaskCNN(nn.Module):
22 | r"""
23 | Masking Convolutional Neural Network
24 |
25 | Adds padding to the output of the module based on the given lengths.
26 | This is to ensure that the results of the model do not change when batch sizes change during inference.
27 | Input needs to be in the shape of (batch_size, channel, hidden_dim, seq_len)
28 |
29 | Refer to https://github.com/SeanNaren/deepspeech.pytorch/blob/master/model.py
30 | Copyright (c) 2017 Sean Naren
31 | MIT License
32 |
33 | Args:
34 | sequential (torch.nn): sequential list of convolution layer
35 |
36 | Inputs: inputs, seq_lengths
37 | - **inputs** (torch.FloatTensor): The input of size BxCxHxT
38 | - **seq_lengths** (torch.IntTensor): The actual length of each sequence in the batch
39 |
40 | Returns: output, seq_lengths
41 | - **output**: Masked output from the sequential
42 | - **seq_lengths**: Sequence length of output from the sequential
43 | """
44 | def __init__(self, sequential: nn.Sequential) -> None:
45 | super(MaskCNN, self).__init__()
46 | self.sequential = sequential
47 |
48 | def forward(self, inputs: Tensor, seq_lengths: Tensor) -> Tuple[Tensor, Tensor]:
49 | output = None
50 |
51 | for module in self.sequential:
52 | output = module(inputs)
53 | mask = torch.BoolTensor(output.size()).fill_(0)
54 |
55 | if output.is_cuda:
56 | mask = mask.cuda()
57 |
58 | seq_lengths = self._get_sequence_lengths(module, seq_lengths)
59 |
60 | for idx, length in enumerate(seq_lengths):
61 | length = length.item()
62 |
63 | if (mask[idx].size(2) - length) > 0:
64 | mask[idx].narrow(dim=2, start=length, length=mask[idx].size(2) - length).fill_(1)
65 |
66 | output = output.masked_fill(mask, 0)
67 | inputs = output
68 |
69 | return output, seq_lengths
70 |
71 | def _get_sequence_lengths(self, module: nn.Module, seq_lengths: Tensor) -> Tensor:
72 | r"""
73 | Calculate convolutional neural network receptive formula
74 |
75 | Args:
76 | module (torch.nn.Module): module of CNN
77 | seq_lengths (torch.IntTensor): The actual length of each sequence in the batch
78 |
79 | Returns: seq_lengths
80 | - **seq_lengths**: Sequence length of output from the module
81 | """
82 | if isinstance(module, nn.Conv2d):
83 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1
84 | seq_lengths = numerator.float() / float(module.stride[1])
85 | seq_lengths = seq_lengths.int() + 1
86 |
87 | elif isinstance(module, nn.MaxPool2d):
88 | seq_lengths >>= 1
89 |
90 | return seq_lengths.int()
91 |
92 |
93 | class VGGExtractor(nn.Module):
94 | r"""
95 | VGG extractor for automatic speech recognition described in
96 | "Advances in Joint CTC-Attention based End-to-End Speech Recognition with a Deep CNN Encoder and RNN-LM" paper
97 | - https://arxiv.org/pdf/1706.02737.pdf
98 |
99 | Args:
100 | input_dim (int): Dimension of input vector
101 | in_channels (int): Number of channels in the input image
102 | out_channels (int or tuple): Number of channels produced by the convolution
103 |
104 | Inputs: inputs, input_lengths
105 | - **inputs** (batch, time, dim): Tensor containing input vectors
106 |         - **input_lengths**: Tensor containing sequence lengths
107 |
108 | Returns: outputs, output_lengths
109 | - **outputs**: Tensor produced by the convolution
110 | - **output_lengths**: Tensor containing sequence lengths produced by the convolution
111 | """
112 | def __init__(
113 | self,
114 | input_dim: int,
115 | in_channels: int = 1,
116 |             out_channels: Tuple[int, int] = (64, 128),
117 | ):
118 | super(VGGExtractor, self).__init__()
119 | self.input_dim = input_dim
120 | self.in_channels = in_channels
121 | self.out_channels = out_channels
122 | self.conv = MaskCNN(
123 | nn.Sequential(
124 | nn.Conv2d(in_channels, out_channels[0], kernel_size=3, stride=1, padding=1, bias=False),
125 | nn.BatchNorm2d(num_features=out_channels[0]),
126 | nn.ReLU(),
127 | nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1, bias=False),
128 | nn.BatchNorm2d(num_features=out_channels[0]),
129 | nn.ReLU(),
130 | nn.MaxPool2d(2, stride=2),
131 | nn.Conv2d(out_channels[0], out_channels[1], kernel_size=3, stride=1, padding=1, bias=False),
132 | nn.BatchNorm2d(num_features=out_channels[1]),
133 | nn.ReLU(),
134 | nn.Conv2d(out_channels[1], out_channels[1], kernel_size=3, stride=1, padding=1, bias=False),
135 | nn.BatchNorm2d(num_features=out_channels[1]),
136 | nn.ReLU(),
137 | nn.MaxPool2d(2, stride=2),
138 | )
139 | )
140 |
141 | def get_output_lengths(self, seq_lengths: Tensor):
142 | assert self.conv is not None, "self.conv should be defined"
143 |
144 |         for module in self.conv.sequential:
145 | if isinstance(module, nn.Conv2d):
146 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1
147 | seq_lengths = numerator.float() / float(module.stride[1])
148 | seq_lengths = seq_lengths.int() + 1
149 |
150 | elif isinstance(module, nn.MaxPool2d):
151 | seq_lengths >>= 1
152 |
153 | return seq_lengths.int()
154 |
155 | def get_output_dim(self):
156 | return (self.input_dim - 1) << 5 if self.input_dim % 2 else self.input_dim << 5
157 |
158 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
159 | r"""
160 | inputs: torch.FloatTensor (batch, time, dimension)
161 | input_lengths: torch.IntTensor (batch)
162 | """
163 | outputs, output_lengths = self.conv(inputs.unsqueeze(1).transpose(2, 3), input_lengths)
164 |
165 | batch_size, channels, dimension, seq_lengths = outputs.size()
166 | outputs = outputs.permute(0, 3, 1, 2)
167 | outputs = outputs.view(batch_size, seq_lengths, channels * dimension)
168 |
169 | return outputs, output_lengths
170 |
--------------------------------------------------------------------------------
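A minimal sketch (illustrative only, not part of the repository) of VGGExtractor; the input sizes are assumed example values:

```python
import torch
from speech_transformer.convolution import VGGExtractor

extractor = VGGExtractor(input_dim=80)

inputs = torch.rand(2, 128, 80)            # B x T x D
input_lengths = torch.IntTensor([128, 64])

outputs, output_lengths = extractor(inputs, input_lengths)
print(outputs.shape)                 # torch.Size([2, 32, 2560]); 128 channels x (80 / 4) frequency bins
print(output_lengths)                # tensor([32, 16], dtype=torch.int32); time is downsampled by 4
print(extractor.get_output_dim())    # 2560
```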
/speech_transformer/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import random
18 | from torch import Tensor
19 | from typing import Optional, Any, Tuple
20 |
21 | from speech_transformer.attention import MultiHeadAttention
22 | from speech_transformer.embeddings import Embedding, PositionalEncoding
23 | from speech_transformer.mask import get_attn_pad_mask, get_attn_subsequent_mask
24 | from speech_transformer.modules import Linear
25 | from speech_transformer.sublayers import AddNorm, PositionWiseFeedForwardNet
26 |
27 |
28 | class SpeechTransformerDecoderLayer(nn.Module):
29 | """
30 | DecoderLayer is made up of self-attention, multi-head attention and feedforward network.
31 | This standard decoder layer is based on the paper "Attention Is All You Need".
32 |
33 | Args:
34 | d_model: dimension of model (default: 512)
35 | num_heads: number of attention heads (default: 8)
36 | d_ff: dimension of feed forward network (default: 2048)
37 | dropout_p: probability of dropout (default: 0.3)
38 | ffnet_style: style of feed forward network [ff, conv] (default: ff)
39 | """
40 |
41 | def __init__(
42 | self,
43 | d_model: int = 512, # dimension of model
44 | num_heads: int = 8, # number of attention heads
45 | d_ff: int = 2048, # dimension of feed forward network
46 | dropout_p: float = 0.3, # probability of dropout
47 | ffnet_style: str = 'ff' # style of feed forward network
48 | ) -> None:
49 | super(SpeechTransformerDecoderLayer, self).__init__()
50 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
51 | self.memory_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
52 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model)
53 |
54 | def forward(
55 | self,
56 | inputs: Tensor,
57 | memory: Tensor,
58 | self_attn_mask: Optional[Any] = None,
59 | memory_mask: Optional[Any] = None
60 | ) -> Tuple[Tensor, Tensor, Tensor]:
61 | output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask)
62 | output, memory_attn = self.memory_attention(output, memory, memory, memory_mask)
63 | output = self.feed_forward(output)
64 | return output, self_attn, memory_attn
65 |
66 |
67 | class SpeechTransformerDecoder(nn.Module):
68 | r"""
69 | The TransformerDecoder is composed of a stack of N identical layers.
70 | Each layer has three sub-layers. The first is a multi-head self-attention mechanism,
71 | and the second is a multi-head attention mechanism, third is a feed-forward network.
72 |
73 | Args:
74 |         num_classes: number of classes
75 | d_model: dimension of model
76 | d_ff: dimension of feed forward network
77 | num_layers: number of decoder layers
78 | num_heads: number of attention heads
79 | ffnet_style: style of feed forward network
80 | dropout_p: probability of dropout
81 | pad_id: identification of pad token
82 | eos_id: identification of end of sentence token
83 | """
84 |
85 | def __init__(
86 | self,
87 | num_classes: int,
88 | d_model: int = 512,
89 | d_ff: int = 2048,
90 | num_layers: int = 6,
91 | num_heads: int = 8,
92 | ffnet_style: str = 'ff',
93 | dropout_p: float = 0.3,
94 | pad_id: int = 0,
95 | sos_id: int = 1,
96 |             eos_id: int = 2,
96 |             max_length: int = 128,
97 | ) -> None:
98 | super(SpeechTransformerDecoder, self).__init__()
99 | self.d_model = d_model
100 | self.num_layers = num_layers
101 | self.num_heads = num_heads
102 | self.embedding = Embedding(num_classes, pad_id, d_model)
103 | self.positional_encoding = PositionalEncoding(d_model)
104 | self.input_dropout = nn.Dropout(p=dropout_p)
105 | self.layers = nn.ModuleList([
106 | SpeechTransformerDecoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers)
107 | ])
108 | self.pad_id = pad_id
109 | self.sos_id = sos_id
110 |         self.eos_id = eos_id
110 |         self.max_length = max_length
111 | self.fc = nn.Sequential(
112 | nn.LayerNorm(d_model),
113 | Linear(d_model, num_classes, bias=False),
114 | )
115 |
116 | def forward_step(
117 | self,
118 | decoder_inputs,
119 | decoder_input_lengths,
120 | encoder_outputs,
121 | encoder_output_lengths,
122 | positional_encoding_length,
123 | ) -> Tensor:
124 | dec_self_attn_pad_mask = get_attn_pad_mask(
125 | decoder_inputs, decoder_input_lengths, decoder_inputs.size(1)
126 | )
127 | dec_self_attn_subsequent_mask = get_attn_subsequent_mask(decoder_inputs)
128 | self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
129 |
130 | encoder_attn_mask = get_attn_pad_mask(
131 | encoder_outputs, encoder_output_lengths, decoder_inputs.size(1)
132 | )
133 |
134 | outputs = self.embedding(decoder_inputs) + self.positional_encoding(positional_encoding_length)
135 | outputs = self.input_dropout(outputs)
136 |
137 | for layer in self.layers:
138 | outputs, self_attn, memory_attn = layer(
139 | inputs=outputs,
140 |                 memory=encoder_outputs,
141 | self_attn_mask=self_attn_mask,
142 |                 memory_mask=encoder_attn_mask,
143 | )
144 |
145 | return outputs
146 |
147 | def forward(
148 | self,
149 | encoder_outputs: Tensor,
150 | targets: Optional[torch.LongTensor] = None,
151 | encoder_output_lengths: Tensor = None,
152 | target_lengths: Tensor = None,
153 | teacher_forcing_ratio: float = 1.0,
154 | ) -> Tensor:
155 | r"""
156 |         Forward propagate `encoder_outputs` (and, during training, `targets`) through the decoder.
157 |
158 | Args:
159 | targets (torch.LongTensor): A target sequence passed to decoders. `IntTensor` of size
160 | ``(batch, seq_length)``
161 | encoder_outputs (torch.FloatTensor): A output sequence of encoders. `FloatTensor` of size
162 | ``(batch, seq_length, dimension)``
163 | encoder_output_lengths (torch.LongTensor): The length of encoders outputs. ``(batch)``
164 | teacher_forcing_ratio (float): ratio of teacher forcing
165 |
166 | Returns:
167 | * logits (torch.FloatTensor): Log probability of model predictions.
168 | """
169 | batch_size = encoder_outputs.size(0)
170 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
171 |
172 | if targets is not None and use_teacher_forcing:
173 | targets = targets[targets != self.eos_id].view(batch_size, -1)
174 | target_length = targets.size(1)
175 |
176 | outputs = self.forward_step(
177 | decoder_inputs=targets,
178 | decoder_input_lengths=target_lengths,
179 | encoder_outputs=encoder_outputs,
180 | encoder_output_lengths=encoder_output_lengths,
181 | positional_encoding_length=target_length,
182 | )
183 | return self.fc(outputs).log_softmax(dim=-1)
184 |
185 | # Inference
186 | else:
187 | logits = list()
188 |
189 | input_var = encoder_outputs.new_zeros(batch_size, self.max_length).long()
190 | input_var = input_var.fill_(self.pad_id)
191 | input_var[:, 0] = self.sos_id
192 |
193 | for di in range(1, self.max_length):
194 | input_lengths = torch.IntTensor(batch_size).fill_(di)
195 |
196 | outputs = self.forward_step(
197 | decoder_inputs=input_var[:, :di],
198 | decoder_input_lengths=input_lengths,
199 | encoder_outputs=encoder_outputs,
200 | encoder_output_lengths=encoder_output_lengths,
201 | positional_encoding_length=di,
202 | )
203 | step_output = self.fc(outputs).log_softmax(dim=-1)
204 |
205 | logits.append(step_output[:, -1, :])
206 |                 input_var[:, di] = logits[-1].topk(1)[1].squeeze(-1)
207 |
208 | return torch.stack(logits, dim=1)
209 |
--------------------------------------------------------------------------------
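A minimal sketch (illustrative only, not part of the repository) of the decoder's teacher-forcing path; the class count, encoder shapes, and target token IDs are assumed example values (each target row contains exactly one eos token, which the forward pass strips):

```python
import torch
from speech_transformer.decoder import SpeechTransformerDecoder

decoder = SpeechTransformerDecoder(num_classes=10, d_model=512, num_layers=2)

encoder_outputs = torch.rand(2, 32, 512)          # B x T x d_model from the encoder
encoder_output_lengths = torch.IntTensor([32, 25])
targets = torch.LongTensor([[1, 3, 4, 5, 2],      # 1 = sos, 2 = eos, 0 = pad
                            [1, 3, 4, 2, 0]])
target_lengths = torch.IntTensor([5, 4])

logits = decoder(
    encoder_outputs=encoder_outputs,
    targets=targets,
    encoder_output_lengths=encoder_output_lengths,
    target_lengths=target_lengths,
    teacher_forcing_ratio=1.0,
)
print(logits.shape)   # torch.Size([2, 4, 10]); log probabilities per position after eos removal
```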
/speech_transformer/beam_decoder.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch import Tensor
5 |
6 | from speech_transformer.decoder import SpeechTransformerDecoder
7 |
8 |
9 | class BeamTransformerDecoder(nn.Module):
10 | def __init__(self, decoder: SpeechTransformerDecoder, batch_size: int, beam_size: int = 3) -> None:
11 | super(BeamTransformerDecoder, self).__init__()
12 | self.decoder = decoder
13 | self.beam_size = beam_size
14 | self.sos_id = decoder.sos_id
15 | self.pad_id = decoder.pad_id
16 | self.eos_id = decoder.eos_id
17 | self.ongoing_beams = None
18 | self.cumulative_ps = None
19 | self.finished = [[] for _ in range(batch_size)]
20 | self.finished_ps = [[] for _ in range(batch_size)]
21 | self.forward_step = decoder.forward_step
22 | self.use_cuda = True if torch.cuda.is_available() else False
23 |
24 | def _inflate(self, tensor: Tensor, n_repeat: int, dim: int) -> Tensor:
25 | repeat_dims = [1] * len(tensor.size())
26 | repeat_dims[dim] *= n_repeat
27 |
28 | return tensor.repeat(*repeat_dims)
29 |
30 | def _get_successor(
31 | self,
32 | current_ps: Tensor,
33 | current_vs: Tensor,
34 | finished_ids: tuple,
35 | num_successor: int,
36 | eos_count: int,
37 | k: int
38 | ) -> int:
39 | finished_batch_idx, finished_idx = finished_ids
40 |
41 | successor_ids = current_ps.topk(k + num_successor)[1]
42 | successor_idx = successor_ids[finished_batch_idx, -1]
43 |
44 | successor_p = current_ps[finished_batch_idx, successor_idx]
45 | successor_v = current_vs[finished_batch_idx, successor_idx]
46 |
47 | prev_status_idx = (successor_idx // k)
48 | prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx]
49 | prev_status = prev_status.view(-1)[:-1]
50 |
51 | successor = torch.cat([prev_status, successor_v.view(1)])
52 |
53 | if int(successor_v) == self.eos_id:
54 | self.finished[finished_batch_idx].append(successor)
55 | self.finished_ps[finished_batch_idx].append(successor_p)
56 | eos_count = self._get_successor(
57 | current_ps=current_ps,
58 | current_vs=current_vs,
59 | finished_ids=finished_ids,
60 | num_successor=num_successor + eos_count,
61 | eos_count=eos_count + 1,
62 | k=k,
63 | )
64 |
65 | else:
66 | self.ongoing_beams[finished_batch_idx, finished_idx] = successor
67 | self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p
68 |
69 | return eos_count
70 |
71 | def _get_hypothesis(self):
72 | predictions = list()
73 |
74 | for batch_idx, batch in enumerate(self.finished):
75 | # if there is no terminated sentences, bring ongoing sentence which has the highest probability instead
76 | if len(batch) == 0:
77 | prob_batch = self.cumulative_ps[batch_idx]
78 | top_beam_idx = int(prob_batch.topk(1)[1])
79 | predictions.append(self.ongoing_beams[batch_idx, top_beam_idx])
80 |
81 | # bring highest probability sentence
82 | else:
83 | top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1])
84 | predictions.append(self.finished[batch_idx][top_beam_idx])
85 |
86 | predictions = self._fill_sequence(predictions)
87 | return predictions
88 |
89 | def _is_all_finished(self, k: int) -> bool:
90 | for done in self.finished:
91 | if len(done) < k:
92 | return False
93 |
94 | return True
95 |
96 | def _fill_sequence(self, y_hats: list) -> Tensor:
97 | batch_size = len(y_hats)
98 | max_length = -1
99 |
100 | for y_hat in y_hats:
101 | if len(y_hat) > max_length:
102 | max_length = len(y_hat)
103 |
104 | matched = torch.zeros((batch_size, max_length), dtype=torch.long)
105 |
106 | for batch_idx, y_hat in enumerate(y_hats):
107 | matched[batch_idx, :len(y_hat)] = y_hat
108 | matched[batch_idx, len(y_hat):] = int(self.pad_id)
109 |
110 | return matched
111 |
112 | def forward(self, encoder_outputs: torch.FloatTensor, encoder_output_lengths: torch.FloatTensor):
113 | batch_size = encoder_outputs.size(0)
114 |
115 | decoder_inputs = torch.IntTensor(batch_size, self.decoder.max_length).fill_(self.sos_id).long()
116 | decoder_input_lengths = torch.IntTensor(batch_size).fill_(1)
117 |
118 | outputs = self.forward_step(
119 | decoder_inputs=decoder_inputs[:, :1],
120 | decoder_input_lengths=decoder_input_lengths,
121 | encoder_outputs=encoder_outputs,
122 | encoder_output_lengths=encoder_output_lengths,
123 | positional_encoding_length=1,
124 | )
125 | step_outputs = self.decoder.fc(outputs).log_softmax(dim=-1)
126 | self.cumulative_ps, self.ongoing_beams = step_outputs.topk(self.beam_size)
127 |
128 | self.ongoing_beams = self.ongoing_beams.view(batch_size * self.beam_size, 1)
129 | self.cumulative_ps = self.cumulative_ps.view(batch_size * self.beam_size, 1)
130 |
131 |         decoder_inputs = torch.LongTensor(batch_size * self.beam_size, 1).fill_(self.sos_id)
132 | decoder_inputs = torch.cat((decoder_inputs, self.ongoing_beams), dim=-1) # bsz * beam x 2
133 |
134 | encoder_dim = encoder_outputs.size(2)
135 | encoder_outputs = self._inflate(encoder_outputs, self.beam_size, dim=0)
136 | encoder_outputs = encoder_outputs.view(self.beam_size, batch_size, -1, encoder_dim)
137 | encoder_outputs = encoder_outputs.transpose(0, 1)
138 | encoder_outputs = encoder_outputs.reshape(batch_size * self.beam_size, -1, encoder_dim)
139 |
140 | encoder_output_lengths = encoder_output_lengths.unsqueeze(1).repeat(1, self.beam_size).view(-1)
141 |
142 | for di in range(2, self.decoder.max_length):
143 | if self._is_all_finished(self.beam_size):
144 | break
145 |
146 | decoder_input_lengths = torch.LongTensor(batch_size * self.beam_size).fill_(di)
147 |
148 | step_outputs = self.forward_step(
149 | decoder_inputs=decoder_inputs[:, :di],
150 | decoder_input_lengths=decoder_input_lengths,
151 | encoder_outputs=encoder_outputs,
152 | encoder_output_lengths=encoder_output_lengths,
153 | positional_encoding_length=di,
154 | )
155 | step_outputs = self.decoder.fc(step_outputs).log_softmax(dim=-1)
156 |
157 |             step_outputs = step_outputs.view(batch_size, self.beam_size, -1, step_outputs.size(-1))
158 | current_ps, current_vs = step_outputs.topk(self.beam_size)
159 |
160 | # TODO: Check transformer's beam search
161 | current_ps = current_ps[:, :, -1, :]
162 | current_vs = current_vs[:, :, -1, :]
163 |
164 | self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
165 | self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)
166 |
167 | current_ps = (current_ps.permute(0, 2, 1) + self.cumulative_ps.unsqueeze(1)).permute(0, 2, 1)
168 | current_ps = current_ps.view(batch_size, self.beam_size ** 2)
169 | current_vs = current_vs.contiguous().view(batch_size, self.beam_size ** 2)
170 |
171 | self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
172 | self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)
173 |
174 | topk_current_ps, topk_status_ids = current_ps.topk(self.beam_size)
175 | prev_status_ids = (topk_status_ids // self.beam_size)
176 |
177 | topk_current_vs = torch.zeros((batch_size, self.beam_size), dtype=torch.long)
178 | prev_status = torch.zeros(self.ongoing_beams.size(), dtype=torch.long)
179 |
180 | for batch_idx, batch in enumerate(topk_status_ids):
181 | for idx, topk_status_idx in enumerate(batch):
182 | topk_current_vs[batch_idx, idx] = current_vs[batch_idx, topk_status_idx]
183 | prev_status[batch_idx, idx] = self.ongoing_beams[batch_idx, prev_status_ids[batch_idx, idx]]
184 |
185 | self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2)
186 | self.cumulative_ps = topk_current_ps
187 |
188 | if torch.any(topk_current_vs == self.eos_id):
189 | finished_ids = torch.where(topk_current_vs == self.eos_id)
190 | num_successors = [1] * batch_size
191 |
192 | for (batch_idx, idx) in zip(*finished_ids):
193 | self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx])
194 | self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx])
195 |
196 | if self.beam_size != 1:
197 | eos_count = self._get_successor(
198 | current_ps=current_ps,
199 | current_vs=current_vs,
200 | finished_ids=(batch_idx, idx),
201 | num_successor=num_successors[batch_idx],
202 | eos_count=1,
203 | k=self.beam_size,
204 | )
205 | num_successors[batch_idx] += eos_count
206 |
207 | ongoing_beams = self.ongoing_beams.clone().view(batch_size * self.beam_size, -1)
208 | decoder_inputs = torch.cat((decoder_inputs, ongoing_beams[:, :-1]), dim=-1)
209 |
210 | return self._get_hypothesis()
211 |
--------------------------------------------------------------------------------