├── speech_transformer
│   ├── __init__.py
│   ├── mask.py
│   ├── embeddings.py
│   ├── sublayers.py
│   ├── layers.py
│   ├── modules.py
│   ├── encoder.py
│   ├── attention.py
│   ├── model.py
│   ├── convolution.py
│   ├── decoder.py
│   └── beam_decoder.py
├── LICENSE
└── README.md
/speech_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SpeechTransformer 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Soohwan Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /speech_transformer/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
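# Worked illustration of the helpers in this file: for inputs of shape (B, T) = (2, 4)
# with input_lengths = [4, 2], the non-pad mask is [[1, 1, 1, 1], [1, 1, 0, 0]];
# get_attn_pad_mask() inverts it and expands it to (B, expand_length, T) so that padded
# key positions are masked out in attention, while get_attn_subsequent_mask() returns an
# upper-triangular (B, T, T) mask that hides future positions during decoding.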
14 | 15 | import torch 16 | 17 | from torch import Tensor 18 | 19 | 20 | def get_attn_pad_mask(inputs, input_lengths, expand_length): 21 | """ mask position is set to 1 """ 22 | def get_transformer_non_pad_mask(inputs: Tensor, input_lengths: Tensor) -> Tensor: 23 | """ Padding position is set to 0, either use input_lengths or pad_id """ 24 | batch_size = inputs.size(0) 25 | 26 | if len(inputs.size()) == 2: 27 | non_pad_mask = inputs.new_ones(inputs.size()) # B x T 28 | elif len(inputs.size()) == 3: 29 | non_pad_mask = inputs.new_ones(inputs.size()[:-1]) # B x T 30 | else: 31 | raise ValueError(f"Unsupported input shape {inputs.size()}") 32 | 33 | for i in range(batch_size): 34 | non_pad_mask[i, input_lengths[i]:] = 0 35 | 36 | return non_pad_mask 37 | 38 | non_pad_mask = get_transformer_non_pad_mask(inputs, input_lengths) 39 | pad_mask = non_pad_mask.lt(1) 40 | attn_pad_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1) 41 | return attn_pad_mask 42 | 43 | 44 | def get_attn_subsequent_mask(seq): 45 | assert seq.dim() == 2 46 | attn_shape = [seq.size(0), seq.size(1), seq.size(1)] 47 | subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) 48 | 49 | if seq.is_cuda: 50 | subsequent_mask = subsequent_mask.cuda() 51 | 52 | return subsequent_mask 53 | -------------------------------------------------------------------------------- /speech_transformer/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | from torch import Tensor 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | """ 23 | Positional Encoding proposed in "Attention Is All You Need". 24 | Since speech_transformer contains no recurrence and no convolution, in order for the model to make 25 | use of the order of the sequence, we must add some positional information. 26 | 27 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 30 | """ 31 | def __init__(self, d_model: int = 512, max_len: int = 5000) -> None: 32 | super(PositionalEncoding, self).__init__() 33 | pe = torch.zeros(max_len, d_model, requires_grad=False) 34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 36 | pe[:, 0::2] = torch.sin(position * div_term) 37 | pe[:, 1::2] = torch.cos(position * div_term) 38 | pe = pe.unsqueeze(0) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, length: int) -> Tensor: 42 | return self.pe[:, :length] 43 | 44 | 45 | class Embedding(nn.Module): 46 | """ 47 | Embedding layer. 
Similarly to other sequence transduction models, the Speech Transformer uses learned embeddings 48 | to convert the input tokens and output tokens to vectors of dimension d_model. 49 | In the embedding layers, those weights are multiplied by sqrt(d_model). 50 | """ 51 | def __init__(self, num_embeddings: int, pad_id: int, d_model: int = 512) -> None: 52 | super(Embedding, self).__init__() 53 | self.sqrt_dim = math.sqrt(d_model) 54 | self.embedding = nn.Embedding(num_embeddings, d_model, padding_idx=pad_id) 55 | 56 | def forward(self, inputs: Tensor) -> Tensor: 57 | return self.embedding(inputs) * self.sqrt_dim 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speech-Transformer 2 | 3 | PyTorch implementation of [The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition](https://ieeexplore.ieee.org/document/8682586). 4 | 5 | 6 | 7 | Speech Transformer is a transformer framework specialized in speech recognition tasks. 8 | This repository contains only the model code, but you can train the Speech Transformer with this [repository](https://github.com/sooftware/KoSpeech). 9 | I appreciate any kind of [feedback or contribution](https://github.com/sooftware/Speech-Transformer/issues). 10 | 11 | ## Usage 12 | - Training (a sketch of computing a loss from the returned `logits` appears after the Code Style section) 13 | ```python 14 | import torch 15 | from speech_transformer import SpeechTransformer 16 | 17 | BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 4 18 | 19 | cuda = torch.cuda.is_available() 20 | device = torch.device('cuda' if cuda else 'cpu') 21 | 22 | inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device) 23 | input_lengths = torch.IntTensor([100, 50, 8]) 24 | targets = torch.LongTensor([[2, 3, 3, 3, 3, 3, 2, 2, 1, 0], 25 | [2, 3, 3, 3, 3, 3, 2, 1, 2, 0], 26 | [2, 3, 3, 3, 3, 3, 2, 2, 0, 1]]).to(device) # 1 means 27 | target_lengths = torch.IntTensor([10, 9, 8]) 28 | 29 | model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM).to(device) 30 | predictions, logits = model(inputs, input_lengths, targets, target_lengths) 31 | ``` 32 | - Beam Search Decoding 33 | ```python 34 | import torch 35 | from speech_transformer import SpeechTransformer 36 | 37 | BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 10 38 | 39 | cuda = torch.cuda.is_available() 40 | device = torch.device('cuda' if cuda else 'cpu') 41 | 42 | inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device) # BxTxD 43 | input_lengths = torch.LongTensor([SEQ_LENGTH, SEQ_LENGTH - 10, SEQ_LENGTH - 20]).to(device) 44 | 45 | model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM).to(device) 46 | model.set_beam_decoder(batch_size=BATCH_SIZE, beam_size=3) 47 | predictions, _ = model(inputs, input_lengths) 48 | ``` 49 | 50 | ## Troubleshooting and Contributing 51 | If you have any questions, bug reports, or feature requests, please [open an issue](https://github.com/sooftware/Speech-Transformer/issues) on GitHub or 52 | contact sh951011@gmail.com. 53 | 54 | I appreciate any kind of feedback or contribution. Feel free to start with small issues such as bug fixes or documentation improvements. For major contributions and new features, please discuss them with the collaborators in the corresponding issues. 55 | 56 | ## Code Style 57 | I follow [PEP-8](https://www.python.org/dev/peps/pep-0008/) for code style. In particular, the docstring style matters, since the documentation is generated from the docstrings.
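## Loss Computation (Sketch)
The training example above returns `logits` that already contain log-probabilities (the decoder applies a final `log_softmax`). Below is a minimal sketch of how a loss might be computed from them. The exact alignment between `logits` and `targets` depends on how the decoder shifts and trims the target sequence, so the slicing here is an assumption for illustration rather than the repository's training recipe.

```python
import torch.nn.functional as F

# logits: (batch, decoder_length, num_classes), already log-probabilities
# targets: (batch, target_length); 0 is assumed to be the pad id (the model default)
flat_logits = logits.contiguous().view(-1, logits.size(-1))
flat_targets = targets[:, :logits.size(1)].contiguous().view(-1)  # illustrative alignment

loss = F.nll_loss(flat_logits, flat_targets, ignore_index=0)  # ignore pad positions
loss.backward()
```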
58 | 59 | ## Reference 60 | - [The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition (Yuanyuan Zhao et al, 2019)](https://ieeexplore.ieee.org/document/8682586) 61 | - [kaituoxu/Speech-Transformer](https://github.com/kaituoxu/Speech-Transformer) 62 | 63 | ## Author 64 | 65 | * Soohwan Kim [@sooftware](https://github.com/sooftware) 66 | * Contacts: sh951011@gmail.com 67 | -------------------------------------------------------------------------------- /speech_transformer/sublayers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | 17 | from torch import Tensor 18 | from typing import Any, Optional 19 | from speech_transformer.modules import ( 20 | Linear, 21 | MaskConv2d, 22 | ) 23 | 24 | 25 | class AddNorm(nn.Module): 26 | """ 27 | Add & Normalization layer proposed in "Attention Is All You Need". 28 | Transformer employ a residual connection around each of the two sub-layers, 29 | (Multi-Head Attention & Feed-Forward) followed by layer normalization. 30 | """ 31 | def __init__(self, sublayer: nn.Module, d_model: int = 512) -> None: 32 | super(AddNorm, self).__init__() 33 | self.sublayer = sublayer 34 | self.layer_norm = nn.LayerNorm(d_model) 35 | 36 | def forward(self, *args): 37 | residual = args[0] 38 | output = self.sublayer(*args) 39 | 40 | if isinstance(output, tuple): 41 | return self.layer_norm(output[0] + residual), output[1] 42 | 43 | return self.layer_norm(output + residual) 44 | 45 | 46 | class PositionWiseFeedForwardNet(nn.Module): 47 | """ 48 | Position-wise Feedforward Networks proposed in "Attention Is All You Need". 49 | Fully connected feed-forward network, which is applied to each position separately and identically. 50 | This consists of two linear transformations with a ReLU activation in between. 51 | Another way of describing this is as two convolutions with kernel size 1. 
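In symbols: FFN(x) = max(0, x · W_1 + b_1) · W_2 + b_2, where W_1 has shape (d_model, d_ff) and W_2 has shape (d_ff, d_model).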
52 | """ 53 | def __init__(self, d_model: int = 512, d_ff: int = 2048, 54 | dropout_p: float = 0.3, ffnet_style: str = 'ff') -> None: 55 | super(PositionWiseFeedForwardNet, self).__init__() 56 | self.ffnet_style = ffnet_style.lower() 57 | if self.ffnet_style == 'ff': 58 | self.feed_forward = nn.Sequential( 59 | Linear(d_model, d_ff), 60 | nn.Dropout(dropout_p), 61 | nn.ReLU(), 62 | Linear(d_ff, d_model), 63 | nn.Dropout(dropout_p), 64 | ) 65 | 66 | elif self.ffnet_style == 'conv': 67 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 68 | self.relu = nn.ReLU() 69 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 70 | 71 | else: 72 | raise ValueError("Unsupported mode: {0}".format(self.mode)) 73 | 74 | def forward(self, inputs: Tensor) -> Tensor: 75 | if self.ffnet_style == 'conv': 76 | output = self.conv1(inputs.transpose(1, 2)) 77 | output = self.relu(output) 78 | return self.conv2(output).transpose(1, 2) 79 | 80 | return self.feed_forward(inputs) 81 | -------------------------------------------------------------------------------- /speech_transformer/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | from typing import Tuple, Optional, Any 18 | from speech_transformer.attention import MultiHeadAttention 19 | from speech_transformer.sublayers import ( 20 | PositionWiseFeedForwardNet, 21 | AddNorm, 22 | ) 23 | 24 | 25 | class SpeechTransformerEncoderLayer(nn.Module): 26 | """ 27 | EncoderLayer is made up of self-attention and feedforward network. 28 | This standard encoder layer is based on the paper "Attention Is All You Need". 
29 | 30 | Args: 31 | d_model: dimension of model (default: 512) 32 | num_heads: number of attention heads (default: 8) 33 | d_ff: dimension of feed forward network (default: 2048) 34 | dropout_p: probability of dropout (default: 0.3) 35 | ffnet_style: style of feed forward network [ff, conv] (default: ff) 36 | """ 37 | 38 | def __init__( 39 | self, 40 | d_model: int = 512, # dimension of model 41 | num_heads: int = 8, # number of attention heads 42 | d_ff: int = 2048, # dimension of feed forward network 43 | dropout_p: float = 0.3, # probability of dropout 44 | ffnet_style: str = 'ff' # style of feed forward network 45 | ) -> None: 46 | super(SpeechTransformerEncoderLayer, self).__init__() 47 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 48 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model) 49 | 50 | def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]: 51 | output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask) 52 | output = self.feed_forward(output) 53 | return output, attn 54 | 55 | 56 | class SpeechTransformerDecoderLayer(nn.Module): 57 | """ 58 | DecoderLayer is made up of self-attention, multi-head attention and feedforward network. 59 | This standard decoder layer is based on the paper "Attention Is All You Need". 60 | 61 | Args: 62 | d_model: dimension of model (default: 512) 63 | num_heads: number of attention heads (default: 8) 64 | d_ff: dimension of feed forward network (default: 2048) 65 | dropout_p: probability of dropout (default: 0.3) 66 | ffnet_style: style of feed forward network [ff, conv] (default: ff) 67 | """ 68 | 69 | def __init__( 70 | self, 71 | d_model: int = 512, # dimension of model 72 | num_heads: int = 8, # number of attention heads 73 | d_ff: int = 2048, # dimension of feed forward network 74 | dropout_p: float = 0.3, # probability of dropout 75 | ffnet_style: str = 'ff' # style of feed forward network 76 | ) -> None: 77 | super(SpeechTransformerDecoderLayer, self).__init__() 78 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 79 | self.memory_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 80 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model) 81 | 82 | def forward( 83 | self, 84 | inputs: Tensor, # tensor contains target sequence 85 | memory: Tensor, # tensor contains encoder outputs 86 | self_attn_mask: Optional[Any] = None, # tensor contains mask of self attention 87 | memory_mask: Optional[Any] = None # tensor contains mask of encoder outputs 88 | ) -> Tuple[Tensor, Tensor, Tensor]: 89 | output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask) 90 | output, memory_attn = self.memory_attention(output, memory, memory, memory_mask) 91 | output = self.feed_forward(output) 92 | return output, self_attn, memory_attn 93 | -------------------------------------------------------------------------------- /speech_transformer/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | 19 | from torch import Tensor 20 | from typing import Tuple 21 | 22 | 23 | class MaskConv2d(nn.Module): 24 | """ 25 | Masking Convolutional Neural Network 26 | Adds padding to the output of the module based on the given lengths. 27 | This is to ensure that the results of the model do not change when batch sizes change during inference. 28 | Input needs to be in the shape of (batch_size, channel, hidden_dim, seq_len) 29 | Refer to https://github.com/SeanNaren/deepspeech.pytorch/blob/master/model.py 30 | Copyright (c) 2017 Sean Naren 31 | MIT License 32 | Args: 33 | sequential (torch.nn): sequential list of convolution layer 34 | Inputs: inputs, seq_lengths 35 | - **inputs** (torch.FloatTensor): The input of size BxCxHxT 36 | - **seq_lengths** (torch.IntTensor): The actual length of each sequence in the batch 37 | Returns: output, seq_lengths 38 | - **output**: Masked output from the sequential 39 | - **seq_lengths**: Sequence length of output from the sequential 40 | """ 41 | def __init__(self, sequential: nn.Sequential) -> None: 42 | super(MaskConv2d, self).__init__() 43 | self.sequential = sequential 44 | 45 | def forward(self, inputs: Tensor, seq_lengths: Tensor) -> Tuple[Tensor, Tensor]: 46 | output = None 47 | 48 | for module in self.sequential: 49 | output = module(inputs) 50 | mask = torch.BoolTensor(output.size()).fill_(0) 51 | 52 | if output.is_cuda: 53 | mask = mask.cuda() 54 | 55 | seq_lengths = self.get_sequence_lengths(module, seq_lengths) 56 | 57 | for idx, length in enumerate(seq_lengths): 58 | length = length.item() 59 | 60 | if (mask[idx].size(2) - length) > 0: 61 | mask[idx].narrow(dim=2, start=length, length=mask[idx].size(2) - length).fill_(1) 62 | 63 | output = output.masked_fill(mask, 0) 64 | inputs = output 65 | 66 | return output, seq_lengths 67 | 68 | def get_sequence_lengths(self, module: nn.Module, seq_lengths: Tensor) -> Tensor: 69 | """ 70 | Calculate convolutional neural network receptive formula 71 | Args: 72 | module (torch.nn.Module): module of CNN 73 | seq_lengths (torch.IntTensor): The actual length of each sequence in the batch 74 | Returns: seq_lengths 75 | - **seq_lengths**: Sequence length of output from the module 76 | """ 77 | if isinstance(module, nn.Conv2d): 78 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1 79 | seq_lengths = numerator.float() / float(module.stride[1]) 80 | seq_lengths = seq_lengths.int() + 1 81 | 82 | elif isinstance(module, nn.MaxPool2d): 83 | seq_lengths >>= 1 84 | 85 | return seq_lengths.int() 86 | 87 | 88 | class Linear(nn.Module): 89 | """ 90 | Wrapper class of torch.nn.Linear 91 | Weight initialize by xavier initialization and bias initialize to zeros. 
92 | """ 93 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 94 | super(Linear, self).__init__() 95 | self.linear = nn.Linear(in_features, out_features, bias=bias) 96 | init.xavier_uniform_(self.linear.weight) 97 | if bias: 98 | init.zeros_(self.linear.bias) 99 | 100 | def forward(self, x: Tensor) -> Tensor: 101 | return self.linear(x) 102 | 103 | 104 | class Transpose(nn.Module): 105 | """ Wrapper class of torch.transpose() for Sequential module. """ 106 | def __init__(self, shape: tuple): 107 | super(Transpose, self).__init__() 108 | self.shape = shape 109 | 110 | def forward(self, inputs: Tensor): 111 | return inputs.transpose(*self.shape) 112 | -------------------------------------------------------------------------------- /speech_transformer/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | from typing import Tuple, Optional, Any 18 | 19 | from speech_transformer.attention import MultiHeadAttention 20 | from speech_transformer.convolution import VGGExtractor 21 | from speech_transformer.embeddings import PositionalEncoding 22 | from speech_transformer.mask import get_attn_pad_mask 23 | from speech_transformer.modules import Linear 24 | from speech_transformer.sublayers import AddNorm, PositionWiseFeedForwardNet 25 | 26 | 27 | class SpeechTransformerEncoderLayer(nn.Module): 28 | """ 29 | EncoderLayer is made up of self-attention and feedforward network. 30 | This standard encoder layer is based on the paper "Attention Is All You Need". 31 | 32 | Args: 33 | d_model: dimension of model (default: 512) 34 | num_heads: number of attention heads (default: 8) 35 | d_ff: dimension of feed forward network (default: 2048) 36 | dropout_p: probability of dropout (default: 0.3) 37 | ffnet_style: style of feed forward network [ff, conv] (default: ff) 38 | """ 39 | 40 | def __init__( 41 | self, 42 | d_model: int = 512, # dimension of model 43 | num_heads: int = 8, # number of attention heads 44 | d_ff: int = 2048, # dimension of feed forward network 45 | dropout_p: float = 0.3, # probability of dropout 46 | ffnet_style: str = 'ff' # style of feed forward network 47 | ) -> None: 48 | super(SpeechTransformerEncoderLayer, self).__init__() 49 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 50 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model) 51 | 52 | def forward(self, inputs: Tensor, self_attn_mask: Optional[Any] = None) -> Tuple[Tensor, Tensor]: 53 | output, attn = self.self_attention(inputs, inputs, inputs, self_attn_mask) 54 | output = self.feed_forward(output) 55 | return output, attn 56 | 57 | 58 | class SpeechTransformerEncoder(nn.Module): 59 | """ 60 | The TransformerEncoder is composed of a stack of N identical layers. 
61 | Each layer has two sub-layers. The first is a multi-head self-attention mechanism, 62 | and the second is a simple, position-wise fully connected feed-forward network. 63 | 64 | Args: 65 | d_model: dimension of model (default: 512) 66 | input_dim: dimension of feature vector (default: 80) 67 | d_ff: dimension of feed forward network (default: 2048) 68 | num_layers: number of encoder layers (default: 6) 69 | num_heads: number of attention heads (default: 8) 70 | ffnet_style: style of feed forward network [ff, conv] (default: ff) 71 | dropout_p: probability of dropout (default: 0.3) 72 | pad_id: identification of pad token (default: 0) 73 | 74 | Inputs: 75 | - **inputs**: list of sequences, whose length is the batch size and within which each sequence is list of tokens 76 | - **input_lengths**: list of sequence lengths 77 | """ 78 | 79 | def __init__( 80 | self, 81 | d_model: int = 512, 82 | input_dim: int = 80, 83 | d_ff: int = 2048, 84 | num_layers: int = 6, 85 | num_heads: int = 8, 86 | ffnet_style: str = 'ff', 87 | dropout_p: float = 0.3, 88 | pad_id: int = 0, 89 | ) -> None: 90 | super(SpeechTransformerEncoder, self).__init__() 91 | self.d_model = d_model 92 | self.num_layers = num_layers 93 | self.num_heads = num_heads 94 | self.pad_id = pad_id 95 | self.conv = VGGExtractor(input_dim) 96 | self.input_proj = Linear(self.conv.get_output_dim(), d_model) 97 | self.input_dropout = nn.Dropout(p=dropout_p) 98 | self.positional_encoding = PositionalEncoding(d_model) 99 | self.layers = nn.ModuleList( 100 | [SpeechTransformerEncoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers)] 101 | ) 102 | 103 | def forward(self, inputs: Tensor, input_lengths: Tensor = None) -> Tuple[Tensor, Tensor]: 104 | conv_outputs, output_lengths = self.conv(inputs, input_lengths) 105 | 106 | self_attn_mask = get_attn_pad_mask(conv_outputs, output_lengths, conv_outputs.size(1)) 107 | 108 | outputs = self.input_proj(conv_outputs) 109 | outputs += self.positional_encoding(outputs.size(1)) 110 | outputs = self.input_dropout(outputs) 111 | 112 | for layer in self.layers: 113 | outputs, attn = layer(outputs, self_attn_mask) 114 | 115 | return outputs, output_lengths 116 | -------------------------------------------------------------------------------- /speech_transformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
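# Shape walk-through for this file (illustration): with batch size B, number of heads N,
# query length L_q, key/value length L_k and head dimension d_head = d_model // N,
# MultiHeadAttention projects its inputs and reshapes them to (B*N, L_q, d_head) queries
# and (B*N, L_k, d_head) keys/values; ScaledDotProductAttention then computes
# softmax(Q K^T / sqrt(d_head)) V, returning a context of shape (B*N, L_q, d_head) and
# attention weights of shape (B*N, L_q, L_k).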
14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import numpy as np 19 | 20 | from speech_transformer.modules import Linear 21 | from torch import Tensor 22 | from typing import Optional, Tuple 23 | 24 | 25 | class ScaledDotProductAttention(nn.Module): 26 | """ 27 | Scaled Dot-Product Attention proposed in "Attention Is All You Need" 28 | Compute the dot products of the query with all keys, divide each by sqrt(dim), 29 | and apply a softmax function to obtain the weights on the values 30 | 31 | Args: dim, mask 32 | dim (int): dimension of attention 33 | mask (torch.Tensor): tensor containing indices to be masked 34 | 35 | Inputs: query, key, value, mask 36 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoder. 37 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoder. 38 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence. 39 | - **mask** (-): tensor containing indices to be masked 40 | 41 | Returns: context, attn 42 | - **context**: tensor containing the context vector from attention mechanism. 43 | - **attn**: tensor containing the attention (alignment) from the encoder outputs. 44 | """ 45 | def __init__(self, dim: int) -> None: 46 | super(ScaledDotProductAttention, self).__init__() 47 | self.sqrt_dim = np.sqrt(dim) 48 | 49 | def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 50 | score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim 51 | 52 | if mask is not None: 53 | score.masked_fill_(mask, -1e9) 54 | 55 | attn = F.softmax(score, -1) 56 | context = torch.bmm(attn, value) 57 | return context, attn 58 | 59 | 60 | class MultiHeadAttention(nn.Module): 61 | """ 62 | Multi-Head Attention proposed in "Attention Is All You Need" 63 | Instead of performing a single attention function with d_model-dimensional keys, values, and queries, 64 | project the queries, keys and values h times with different, learned linear projections to d_head dimensions. 65 | These are concatenated and once again projected, resulting in the final values. 66 | Multi-head attention allows the model to jointly attend to information from different representation 67 | subspaces at different positions. 68 | 69 | MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o 70 | where head_i = Attention(Q · W_q, K · W_k, V · W_v) 71 | 72 | Args: 73 | d_model (int): The dimension of keys / values / quries (default: 512) 74 | num_heads (int): The number of attention heads. (default: 8) 75 | 76 | Inputs: query, key, value, mask 77 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoder. 78 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoder. 79 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence. 80 | - **mask** (-): tensor containing indices to be masked 81 | 82 | Returns: output, attn 83 | - **output** (batch, output_len, dimensions): tensor containing the attended output features. 84 | - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs. 85 | """ 86 | def __init__(self, d_model: int = 512, num_heads: int = 8) -> None: 87 | super(MultiHeadAttention, self).__init__() 88 | 89 | assert d_model % num_heads == 0, "hidden_dim % num_heads should be zero." 
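        # each of the num_heads attention heads operates on a separate d_head = d_model // num_heads dimensional subspace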
90 | 91 | self.d_head = int(d_model / num_heads) 92 | self.num_heads = num_heads 93 | self.query_proj = Linear(d_model, self.d_head * num_heads) 94 | self.key_proj = Linear(d_model, self.d_head * num_heads) 95 | self.value_proj = Linear(d_model, self.d_head * num_heads) 96 | self.sqrt_dim = np.sqrt(d_model) 97 | self.scaled_dot_attn = ScaledDotProductAttention(self.d_head) 98 | 99 | def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 100 | batch_size = value.size(0) 101 | 102 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) # BxQ_LENxNxD 103 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head) # BxK_LENxNxD 104 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head) # BxV_LENxNxD 105 | 106 | query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxQ_LENxD 107 | key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxK_LENxD 108 | value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxV_LENxD 109 | 110 | if mask is not None: 111 | mask = mask.repeat(self.num_heads, 1, 1) 112 | 113 | context, attn = self.scaled_dot_attn(query, key, value, mask) 114 | context = context.view(self.num_heads, batch_size, -1, self.d_head) 115 | context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head) # BxTxND 116 | 117 | return context, attn 118 | -------------------------------------------------------------------------------- /speech_transformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | from typing import Optional, Union 18 | 19 | from speech_transformer.beam_decoder import BeamTransformerDecoder 20 | from speech_transformer.decoder import SpeechTransformerDecoder 21 | from speech_transformer.encoder import SpeechTransformerEncoder 22 | 23 | 24 | class SpeechTransformer(nn.Module): 25 | """ 26 | A Speech Transformer model. User is able to modify the attributes as needed. 27 | The model is based on the paper "Attention Is All You Need". 
28 | 29 | Args: 30 | num_classes (int): the number of classes 31 | d_model (int): dimension of model (default: 512) 32 | input_dim (int): dimension of input 33 | pad_id (int): identification of the pad token | sos_id (int): identification of the start of sentence token 34 | eos_id (int): identification of the end of sentence token 35 | d_ff (int): dimension of feed forward network (default: 2048) 36 | num_encoder_layers (int): number of encoder layers (default: 6) 37 | num_decoder_layers (int): number of decoder layers (default: 6) 38 | num_heads (int): number of attention heads (default: 8) 39 | dropout_p (float): dropout probability (default: 0.3) 40 | ffnet_style (str): if ffnet_style is 'ff', the position-wise feed forward network is a feed forward layer, 41 | otherwise, the position-wise feed forward network is a convolution layer. (default: ff) 42 | 43 | Inputs: inputs, input_lengths, targets, teacher_forcing_ratio 44 | - **inputs** (torch.Tensor): tensor of sequences, whose length is the batch size and within which 45 | each sequence is a list of acoustic feature vectors. This information is forwarded to the encoder. 46 | - **input_lengths** (torch.Tensor): tensor containing the lengths of the input sequences. 47 | - **targets** (torch.Tensor): tensor of sequences, whose length is the batch size and within which 48 | each sequence is a list of token IDs. This information is forwarded to the decoder. 49 | 50 | Returns: output 51 | - **output**: tensor containing the outputs 52 | """ 53 | 54 | def __init__( 55 | self, 56 | num_classes: int, 57 | d_model: int = 512, 58 | input_dim: int = 80, 59 | pad_id: int = 0, 60 | sos_id: int = 1, 61 | eos_id: int = 2, 62 | d_ff: int = 2048, 63 | num_heads: int = 8, 64 | num_encoder_layers: int = 6, 65 | num_decoder_layers: int = 6, 66 | dropout_p: float = 0.3, 67 | ffnet_style: str = 'ff', 68 | extractor: str = 'vgg', 69 | joint_ctc_attention: bool = False, 70 | max_length: int = 128, 71 | ) -> None: 72 | super(SpeechTransformer, self).__init__() 73 | 74 | assert d_model % num_heads == 0, "d_model % num_heads should be zero."
75 | 76 | self.num_classes = num_classes 77 | self.extractor = extractor 78 | self.joint_ctc_attention = joint_ctc_attention 79 | self.sos_id = sos_id 80 | self.eos_id = eos_id 81 | self.pad_id = pad_id 82 | self.max_length = max_length 83 | 84 | self.encoder = SpeechTransformerEncoder( 85 | d_model=d_model, 86 | input_dim=input_dim, 87 | d_ff=d_ff, 88 | num_layers=num_encoder_layers, 89 | num_heads=num_heads, 90 | ffnet_style=ffnet_style, 91 | dropout_p=dropout_p, 92 | pad_id=pad_id, 93 | ) 94 | 95 | self.decoder = SpeechTransformerDecoder( 96 | num_classes=num_classes, 97 | d_model=d_model, 98 | d_ff=d_ff, 99 | num_layers=num_decoder_layers, 100 | num_heads=num_heads, 101 | ffnet_style=ffnet_style, 102 | dropout_p=dropout_p, 103 | pad_id=pad_id, 104 | sos_id=sos_id, 105 | eos_id=eos_id, | max_length=max_length, 106 | ) 107 | 108 | def set_beam_decoder(self, batch_size: int = None, beam_size: int = 3): 109 | """ Setting beam search decoder """ 110 | self.decoder = BeamTransformerDecoder( 111 | decoder=self.decoder, 112 | batch_size=batch_size, 113 | beam_size=beam_size, 114 | ) 115 | 116 | def forward( 117 | self, 118 | inputs: Tensor, 119 | input_lengths: Tensor, 120 | targets: Optional[Tensor] = None, 121 | target_lengths: Optional[Tensor] = None, 122 | ) -> Union[Tensor, tuple]: 123 | """ 124 | inputs (torch.FloatTensor): (batch_size, sequence_length, dimension) 125 | input_lengths (torch.LongTensor): (batch_size) 126 | """ 127 | logits = None 128 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 129 | if isinstance(self.decoder, BeamTransformerDecoder): 130 | predictions = self.decoder(encoder_outputs, encoder_output_lengths) 131 | else: 132 | logits = self.decoder( 133 | encoder_outputs=encoder_outputs, 134 | encoder_output_lengths=encoder_output_lengths, 135 | targets=targets, 136 | teacher_forcing_ratio=1.0, 137 | target_lengths=target_lengths, 138 | ) 139 | predictions = logits.max(-1)[1] 140 | 141 | return predictions, logits 142 | -------------------------------------------------------------------------------- /speech_transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | 21 | class MaskCNN(nn.Module): 22 | r""" 23 | Masking Convolutional Neural Network 24 | 25 | Adds padding to the output of the module based on the given lengths. 26 | This is to ensure that the results of the model do not change when batch sizes change during inference.
27 | Input needs to be in the shape of (batch_size, channel, hidden_dim, seq_len) 28 | 29 | Refer to https://github.com/SeanNaren/deepspeech.pytorch/blob/master/model.py 30 | Copyright (c) 2017 Sean Naren 31 | MIT License 32 | 33 | Args: 34 | sequential (torch.nn): sequential list of convolution layer 35 | 36 | Inputs: inputs, seq_lengths 37 | - **inputs** (torch.FloatTensor): The input of size BxCxHxT 38 | - **seq_lengths** (torch.IntTensor): The actual length of each sequence in the batch 39 | 40 | Returns: output, seq_lengths 41 | - **output**: Masked output from the sequential 42 | - **seq_lengths**: Sequence length of output from the sequential 43 | """ 44 | def __init__(self, sequential: nn.Sequential) -> None: 45 | super(MaskCNN, self).__init__() 46 | self.sequential = sequential 47 | 48 | def forward(self, inputs: Tensor, seq_lengths: Tensor) -> Tuple[Tensor, Tensor]: 49 | output = None 50 | 51 | for module in self.sequential: 52 | output = module(inputs) 53 | mask = torch.BoolTensor(output.size()).fill_(0) 54 | 55 | if output.is_cuda: 56 | mask = mask.cuda() 57 | 58 | seq_lengths = self._get_sequence_lengths(module, seq_lengths) 59 | 60 | for idx, length in enumerate(seq_lengths): 61 | length = length.item() 62 | 63 | if (mask[idx].size(2) - length) > 0: 64 | mask[idx].narrow(dim=2, start=length, length=mask[idx].size(2) - length).fill_(1) 65 | 66 | output = output.masked_fill(mask, 0) 67 | inputs = output 68 | 69 | return output, seq_lengths 70 | 71 | def _get_sequence_lengths(self, module: nn.Module, seq_lengths: Tensor) -> Tensor: 72 | r""" 73 | Calculate convolutional neural network receptive formula 74 | 75 | Args: 76 | module (torch.nn.Module): module of CNN 77 | seq_lengths (torch.IntTensor): The actual length of each sequence in the batch 78 | 79 | Returns: seq_lengths 80 | - **seq_lengths**: Sequence length of output from the module 81 | """ 82 | if isinstance(module, nn.Conv2d): 83 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1 84 | seq_lengths = numerator.float() / float(module.stride[1]) 85 | seq_lengths = seq_lengths.int() + 1 86 | 87 | elif isinstance(module, nn.MaxPool2d): 88 | seq_lengths >>= 1 89 | 90 | return seq_lengths.int() 91 | 92 | 93 | class VGGExtractor(nn.Module): 94 | r""" 95 | VGG extractor for automatic speech recognition described in 96 | "Advances in Joint CTC-Attention based End-to-End Speech Recognition with a Deep CNN Encoder and RNN-LM" paper 97 | - https://arxiv.org/pdf/1706.02737.pdf 98 | 99 | Args: 100 | input_dim (int): Dimension of input vector 101 | in_channels (int): Number of channels in the input image 102 | out_channels (int or tuple): Number of channels produced by the convolution 103 | 104 | Inputs: inputs, input_lengths 105 | - **inputs** (batch, time, dim): Tensor containing input vectors 106 | - **input_lengths**: Tensor containing containing sequence lengths 107 | 108 | Returns: outputs, output_lengths 109 | - **outputs**: Tensor produced by the convolution 110 | - **output_lengths**: Tensor containing sequence lengths produced by the convolution 111 | """ 112 | def __init__( 113 | self, 114 | input_dim: int, 115 | in_channels: int = 1, 116 | out_channels: int or tuple = (64, 128), 117 | ): 118 | super(VGGExtractor, self).__init__() 119 | self.input_dim = input_dim 120 | self.in_channels = in_channels 121 | self.out_channels = out_channels 122 | self.conv = MaskCNN( 123 | nn.Sequential( 124 | nn.Conv2d(in_channels, out_channels[0], kernel_size=3, stride=1, padding=1, 
bias=False), 125 | nn.BatchNorm2d(num_features=out_channels[0]), 126 | nn.ReLU(), 127 | nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1, bias=False), 128 | nn.BatchNorm2d(num_features=out_channels[0]), 129 | nn.ReLU(), 130 | nn.MaxPool2d(2, stride=2), 131 | nn.Conv2d(out_channels[0], out_channels[1], kernel_size=3, stride=1, padding=1, bias=False), 132 | nn.BatchNorm2d(num_features=out_channels[1]), 133 | nn.ReLU(), 134 | nn.Conv2d(out_channels[1], out_channels[1], kernel_size=3, stride=1, padding=1, bias=False), 135 | nn.BatchNorm2d(num_features=out_channels[1]), 136 | nn.ReLU(), 137 | nn.MaxPool2d(2, stride=2), 138 | ) 139 | ) 140 | 141 | def get_output_lengths(self, seq_lengths: Tensor): 142 | assert self.conv is not None, "self.conv should be defined" 143 | 144 | for module in self.conv: 145 | if isinstance(module, nn.Conv2d): 146 | numerator = seq_lengths + 2 * module.padding[1] - module.dilation[1] * (module.kernel_size[1] - 1) - 1 147 | seq_lengths = numerator.float() / float(module.stride[1]) 148 | seq_lengths = seq_lengths.int() + 1 149 | 150 | elif isinstance(module, nn.MaxPool2d): 151 | seq_lengths >>= 1 152 | 153 | return seq_lengths.int() 154 | 155 | def get_output_dim(self): 156 | return (self.input_dim - 1) << 5 if self.input_dim % 2 else self.input_dim << 5 157 | 158 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 159 | r""" 160 | inputs: torch.FloatTensor (batch, time, dimension) 161 | input_lengths: torch.IntTensor (batch) 162 | """ 163 | outputs, output_lengths = self.conv(inputs.unsqueeze(1).transpose(2, 3), input_lengths) 164 | 165 | batch_size, channels, dimension, seq_lengths = outputs.size() 166 | outputs = outputs.permute(0, 3, 1, 2) 167 | outputs = outputs.view(batch_size, seq_lengths, channels * dimension) 168 | 169 | return outputs, output_lengths 170 | -------------------------------------------------------------------------------- /speech_transformer/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import random 18 | from torch import Tensor 19 | from typing import Optional, Any, Tuple 20 | 21 | from speech_transformer.attention import MultiHeadAttention 22 | from speech_transformer.embeddings import Embedding, PositionalEncoding 23 | from speech_transformer.mask import get_attn_pad_mask, get_attn_subsequent_mask 24 | from speech_transformer.modules import Linear 25 | from speech_transformer.sublayers import AddNorm, PositionWiseFeedForwardNet 26 | 27 | 28 | class SpeechTransformerDecoderLayer(nn.Module): 29 | """ 30 | DecoderLayer is made up of self-attention, multi-head attention and feedforward network. 31 | This standard decoder layer is based on the paper "Attention Is All You Need". 
32 | 33 | Args: 34 | d_model: dimension of model (default: 512) 35 | num_heads: number of attention heads (default: 8) 36 | d_ff: dimension of feed forward network (default: 2048) 37 | dropout_p: probability of dropout (default: 0.3) 38 | ffnet_style: style of feed forward network [ff, conv] (default: ff) 39 | """ 40 | 41 | def __init__( 42 | self, 43 | d_model: int = 512, # dimension of model 44 | num_heads: int = 8, # number of attention heads 45 | d_ff: int = 2048, # dimension of feed forward network 46 | dropout_p: float = 0.3, # probability of dropout 47 | ffnet_style: str = 'ff' # style of feed forward network 48 | ) -> None: 49 | super(SpeechTransformerDecoderLayer, self).__init__() 50 | self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 51 | self.memory_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model) 52 | self.feed_forward = AddNorm(PositionWiseFeedForwardNet(d_model, d_ff, dropout_p, ffnet_style), d_model) 53 | 54 | def forward( 55 | self, 56 | inputs: Tensor, 57 | memory: Tensor, 58 | self_attn_mask: Optional[Any] = None, 59 | memory_mask: Optional[Any] = None 60 | ) -> Tuple[Tensor, Tensor, Tensor]: 61 | output, self_attn = self.self_attention(inputs, inputs, inputs, self_attn_mask) 62 | output, memory_attn = self.memory_attention(output, memory, memory, memory_mask) 63 | output = self.feed_forward(output) 64 | return output, self_attn, memory_attn 65 | 66 | 67 | class SpeechTransformerDecoder(nn.Module): 68 | r""" 69 | The TransformerDecoder is composed of a stack of N identical layers. 70 | Each layer has three sub-layers. The first is a multi-head self-attention mechanism, 71 | and the second is a multi-head attention mechanism, third is a feed-forward network. 72 | 73 | Args: 74 | num_classes: number of classes 75 | d_model: dimension of model 76 | d_ff: dimension of feed forward network 77 | num_layers: number of decoder layers 78 | num_heads: number of attention heads 79 | ffnet_style: style of feed forward network 80 | dropout_p: probability of dropout 81 | pad_id: identification of pad token 82 | eos_id: identification of end of sentence token | sos_id: identification of start of sentence token | max_length: maximum length of a decoded sequence (default: 128) 83 | """ 84 | 85 | def __init__( 86 | self, 87 | num_classes: int, 88 | d_model: int = 512, 89 | d_ff: int = 2048, 90 | num_layers: int = 6, 91 | num_heads: int = 8, 92 | ffnet_style: str = 'ff', 93 | dropout_p: float = 0.3, 94 | pad_id: int = 0, 95 | sos_id: int = 1, 96 | eos_id: int = 2, | max_length: int = 128, 97 | ) -> None: 98 | super(SpeechTransformerDecoder, self).__init__() 99 | self.d_model = d_model 100 | self.num_layers = num_layers 101 | self.num_heads = num_heads 102 | self.embedding = Embedding(num_classes, pad_id, d_model) 103 | self.positional_encoding = PositionalEncoding(d_model) 104 | self.input_dropout = nn.Dropout(p=dropout_p) 105 | self.layers = nn.ModuleList([ 106 | SpeechTransformerDecoderLayer(d_model, num_heads, d_ff, dropout_p, ffnet_style) for _ in range(num_layers) 107 | ]) 108 | self.pad_id = pad_id 109 | self.sos_id = sos_id 110 | self.eos_id = eos_id | self.max_length = max_length 111 | self.fc = nn.Sequential( 112 | nn.LayerNorm(d_model), 113 | Linear(d_model, num_classes, bias=False), 114 | ) 115 | 116 | def forward_step( 117 | self, 118 | decoder_inputs, 119 | decoder_input_lengths, 120 | encoder_outputs, 121 | encoder_output_lengths, 122 | positional_encoding_length, 123 | ) -> Tensor: 124 | dec_self_attn_pad_mask = get_attn_pad_mask( 125 | decoder_inputs, decoder_input_lengths, decoder_inputs.size(1) 126 | ) 127 | dec_self_attn_subsequent_mask = get_attn_subsequent_mask(decoder_inputs) 128 |
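        # a position is masked in decoder self-attention if it is a pad position or lies in the future (union of the two masks above)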
self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) 129 | 130 | encoder_attn_mask = get_attn_pad_mask( 131 | encoder_outputs, encoder_output_lengths, decoder_inputs.size(1) 132 | ) 133 | 134 | outputs = self.embedding(decoder_inputs) + self.positional_encoding(positional_encoding_length) 135 | outputs = self.input_dropout(outputs) 136 | 137 | for layer in self.layers: 138 | outputs, self_attn, memory_attn = layer( 139 | inputs=outputs, 140 | memory=encoder_outputs, 141 | self_attn_mask=self_attn_mask, 142 | memory_mask=encoder_attn_mask, 143 | ) 144 | 145 | return outputs 146 | 147 | def forward( 148 | self, 149 | encoder_outputs: Tensor, 150 | targets: Optional[torch.LongTensor] = None, 151 | encoder_output_lengths: Tensor = None, 152 | target_lengths: Tensor = None, 153 | teacher_forcing_ratio: float = 1.0, 154 | ) -> Tensor: 155 | r""" 156 | Forward propagate `encoder_outputs` for training. 157 | 158 | Args: 159 | targets (torch.LongTensor): A target sequence passed to decoders. `IntTensor` of size 160 | ``(batch, seq_length)`` 161 | encoder_outputs (torch.FloatTensor): An output sequence of encoders. `FloatTensor` of size 162 | ``(batch, seq_length, dimension)`` 163 | encoder_output_lengths (torch.LongTensor): The length of encoders outputs. ``(batch)`` 164 | teacher_forcing_ratio (float): ratio of teacher forcing 165 | 166 | Returns: 167 | * logits (torch.FloatTensor): Log probability of model predictions. 168 | """ 169 | batch_size = encoder_outputs.size(0) 170 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 171 | 172 | if targets is not None and use_teacher_forcing: 173 | targets = targets[targets != self.eos_id].view(batch_size, -1) 174 | target_length = targets.size(1) 175 | 176 | outputs = self.forward_step( 177 | decoder_inputs=targets, 178 | decoder_input_lengths=target_lengths, 179 | encoder_outputs=encoder_outputs, 180 | encoder_output_lengths=encoder_output_lengths, 181 | positional_encoding_length=target_length, 182 | ) 183 | return self.fc(outputs).log_softmax(dim=-1) 184 | 185 | # Inference 186 | else: 187 | logits = list() 188 | 189 | input_var = encoder_outputs.new_zeros(batch_size, self.max_length).long() 190 | input_var = input_var.fill_(self.pad_id) 191 | input_var[:, 0] = self.sos_id 192 | 193 | for di in range(1, self.max_length): 194 | input_lengths = torch.IntTensor(batch_size).fill_(di) 195 | 196 | outputs = self.forward_step( 197 | decoder_inputs=input_var[:, :di], 198 | decoder_input_lengths=input_lengths, 199 | encoder_outputs=encoder_outputs, 200 | encoder_output_lengths=encoder_output_lengths, 201 | positional_encoding_length=di, 202 | ) 203 | step_output = self.fc(outputs).log_softmax(dim=-1) 204 | 205 | logits.append(step_output[:, -1, :]) 206 | input_var[:, di] = logits[-1].topk(1)[1].squeeze(-1) 207 | 208 | return torch.stack(logits, dim=1) 209 | -------------------------------------------------------------------------------- /speech_transformer/beam_decoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | from speech_transformer.decoder import SpeechTransformerDecoder 7 | 8 | 9 | class BeamTransformerDecoder(nn.Module): 10 | def __init__(self, decoder: SpeechTransformerDecoder, batch_size: int, beam_size: int = 3) -> None: 11 | super(BeamTransformerDecoder, self).__init__() 12 | self.decoder = decoder 13 | self.beam_size = beam_size 14 | self.sos_id = decoder.sos_id 15
| self.pad_id = decoder.pad_id 16 | self.eos_id = decoder.eos_id 17 | self.ongoing_beams = None 18 | self.cumulative_ps = None 19 | self.finished = [[] for _ in range(batch_size)] 20 | self.finished_ps = [[] for _ in range(batch_size)] 21 | self.forward_step = decoder.forward_step 22 | self.use_cuda = True if torch.cuda.is_available() else False 23 | 24 | def _inflate(self, tensor: Tensor, n_repeat: int, dim: int) -> Tensor: 25 | repeat_dims = [1] * len(tensor.size()) 26 | repeat_dims[dim] *= n_repeat 27 | 28 | return tensor.repeat(*repeat_dims) 29 | 30 | def _get_successor( 31 | self, 32 | current_ps: Tensor, 33 | current_vs: Tensor, 34 | finished_ids: tuple, 35 | num_successor: int, 36 | eos_count: int, 37 | k: int 38 | ) -> int: 39 | finished_batch_idx, finished_idx = finished_ids 40 | 41 | successor_ids = current_ps.topk(k + num_successor)[1] 42 | successor_idx = successor_ids[finished_batch_idx, -1] 43 | 44 | successor_p = current_ps[finished_batch_idx, successor_idx] 45 | successor_v = current_vs[finished_batch_idx, successor_idx] 46 | 47 | prev_status_idx = (successor_idx // k) 48 | prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx] 49 | prev_status = prev_status.view(-1)[:-1] 50 | 51 | successor = torch.cat([prev_status, successor_v.view(1)]) 52 | 53 | if int(successor_v) == self.eos_id: 54 | self.finished[finished_batch_idx].append(successor) 55 | self.finished_ps[finished_batch_idx].append(successor_p) 56 | eos_count = self._get_successor( 57 | current_ps=current_ps, 58 | current_vs=current_vs, 59 | finished_ids=finished_ids, 60 | num_successor=num_successor + eos_count, 61 | eos_count=eos_count + 1, 62 | k=k, 63 | ) 64 | 65 | else: 66 | self.ongoing_beams[finished_batch_idx, finished_idx] = successor 67 | self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p 68 | 69 | return eos_count 70 | 71 | def _get_hypothesis(self): 72 | predictions = list() 73 | 74 | for batch_idx, batch in enumerate(self.finished): 75 | # if there is no terminated sentences, bring ongoing sentence which has the highest probability instead 76 | if len(batch) == 0: 77 | prob_batch = self.cumulative_ps[batch_idx] 78 | top_beam_idx = int(prob_batch.topk(1)[1]) 79 | predictions.append(self.ongoing_beams[batch_idx, top_beam_idx]) 80 | 81 | # bring highest probability sentence 82 | else: 83 | top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1]) 84 | predictions.append(self.finished[batch_idx][top_beam_idx]) 85 | 86 | predictions = self._fill_sequence(predictions) 87 | return predictions 88 | 89 | def _is_all_finished(self, k: int) -> bool: 90 | for done in self.finished: 91 | if len(done) < k: 92 | return False 93 | 94 | return True 95 | 96 | def _fill_sequence(self, y_hats: list) -> Tensor: 97 | batch_size = len(y_hats) 98 | max_length = -1 99 | 100 | for y_hat in y_hats: 101 | if len(y_hat) > max_length: 102 | max_length = len(y_hat) 103 | 104 | matched = torch.zeros((batch_size, max_length), dtype=torch.long) 105 | 106 | for batch_idx, y_hat in enumerate(y_hats): 107 | matched[batch_idx, :len(y_hat)] = y_hat 108 | matched[batch_idx, len(y_hat):] = int(self.pad_id) 109 | 110 | return matched 111 | 112 | def forward(self, encoder_outputs: torch.FloatTensor, encoder_output_lengths: torch.FloatTensor): 113 | batch_size = encoder_outputs.size(0) 114 | 115 | decoder_inputs = torch.IntTensor(batch_size, self.decoder.max_length).fill_(self.sos_id).long() 116 | decoder_input_lengths = torch.IntTensor(batch_size).fill_(1) 117 | 118 | outputs = self.forward_step( 
119 | decoder_inputs=decoder_inputs[:, :1], 120 | decoder_input_lengths=decoder_input_lengths, 121 | encoder_outputs=encoder_outputs, 122 | encoder_output_lengths=encoder_output_lengths, 123 | positional_encoding_length=1, 124 | ) 125 | step_outputs = self.decoder.fc(outputs).log_softmax(dim=-1) 126 | self.cumulative_ps, self.ongoing_beams = step_outputs.topk(self.beam_size) 127 | 128 | self.ongoing_beams = self.ongoing_beams.view(batch_size * self.beam_size, 1) 129 | self.cumulative_ps = self.cumulative_ps.view(batch_size * self.beam_size, 1) 130 | 131 | decoder_inputs = torch.IntTensor(batch_size * self.beam_size, 1).fill_(self.sos_id) 132 | decoder_inputs = torch.cat((decoder_inputs, self.ongoing_beams), dim=-1) # bsz * beam x 2 133 | 134 | encoder_dim = encoder_outputs.size(2) 135 | encoder_outputs = self._inflate(encoder_outputs, self.beam_size, dim=0) 136 | encoder_outputs = encoder_outputs.view(self.beam_size, batch_size, -1, encoder_dim) 137 | encoder_outputs = encoder_outputs.transpose(0, 1) 138 | encoder_outputs = encoder_outputs.reshape(batch_size * self.beam_size, -1, encoder_dim) 139 | 140 | encoder_output_lengths = encoder_output_lengths.unsqueeze(1).repeat(1, self.beam_size).view(-1) 141 | 142 | for di in range(2, self.decoder.max_length): 143 | if self._is_all_finished(self.beam_size): 144 | break 145 | 146 | decoder_input_lengths = torch.LongTensor(batch_size * self.beam_size).fill_(di) 147 | 148 | step_outputs = self.forward_step( 149 | decoder_inputs=decoder_inputs[:, :di], 150 | decoder_input_lengths=decoder_input_lengths, 151 | encoder_outputs=encoder_outputs, 152 | encoder_output_lengths=encoder_output_lengths, 153 | positional_encoding_length=di, 154 | ) 155 | step_outputs = self.decoder.fc(step_outputs).log_softmax(dim=-1) 156 | 157 | step_outputs = step_outputs.view(batch_size, self.beam_size, -1, 10) 158 | current_ps, current_vs = step_outputs.topk(self.beam_size) 159 | 160 | # TODO: Check transformer's beam search 161 | current_ps = current_ps[:, :, -1, :] 162 | current_vs = current_vs[:, :, -1, :] 163 | 164 | self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size) 165 | self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1) 166 | 167 | current_ps = (current_ps.permute(0, 2, 1) + self.cumulative_ps.unsqueeze(1)).permute(0, 2, 1) 168 | current_ps = current_ps.view(batch_size, self.beam_size ** 2) 169 | current_vs = current_vs.contiguous().view(batch_size, self.beam_size ** 2) 170 | 171 | self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size) 172 | self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1) 173 | 174 | topk_current_ps, topk_status_ids = current_ps.topk(self.beam_size) 175 | prev_status_ids = (topk_status_ids // self.beam_size) 176 | 177 | topk_current_vs = torch.zeros((batch_size, self.beam_size), dtype=torch.long) 178 | prev_status = torch.zeros(self.ongoing_beams.size(), dtype=torch.long) 179 | 180 | for batch_idx, batch in enumerate(topk_status_ids): 181 | for idx, topk_status_idx in enumerate(batch): 182 | topk_current_vs[batch_idx, idx] = current_vs[batch_idx, topk_status_idx] 183 | prev_status[batch_idx, idx] = self.ongoing_beams[batch_idx, prev_status_ids[batch_idx, idx]] 184 | 185 | self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2) 186 | self.cumulative_ps = topk_current_ps 187 | 188 | if torch.any(topk_current_vs == self.eos_id): 189 | finished_ids = torch.where(topk_current_vs == self.eos_id) 190 | num_successors = [1] * batch_size 
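                # each finished hypothesis is recorded and, when beam_size > 1, replaced in the beam by its next-best successor so the beam stays full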
191 | 192 | for (batch_idx, idx) in zip(*finished_ids): 193 | self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx]) 194 | self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx]) 195 | 196 | if self.beam_size != 1: 197 | eos_count = self._get_successor( 198 | current_ps=current_ps, 199 | current_vs=current_vs, 200 | finished_ids=(batch_idx, idx), 201 | num_successor=num_successors[batch_idx], 202 | eos_count=1, 203 | k=self.beam_size, 204 | ) 205 | num_successors[batch_idx] += eos_count 206 | 207 | ongoing_beams = self.ongoing_beams.clone().view(batch_size * self.beam_size, -1) 208 | decoder_inputs = torch.cat((decoder_inputs, ongoing_beams[:, :-1]), dim=-1) 209 | 210 | return self._get_hypothesis() 211 | --------------------------------------------------------------------------------