├── LICENSE
├── README.md
├── __init__.py
├── tdnn.py
└── tdnn_asym.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fast TDNN layer implementation

This is an alternative implementation of the TDNN layer proposed by Waibel _et al._ [1].
The main difference compared to other implementations is that it exploits the
[PyTorch Conv1d](https://pytorch.org/docs/stable/nn.html?highlight=conv1d#torch.nn.Conv1d) dilation argument,
making it many times faster than other popular implementations such as
[SiddGururani's PyTorch-TDNN](https://github.com/SiddGururani/Pytorch-TDNN).

## Usage
```python
# Create a TDNN layer
layer_context = [-2, 0, 2]
input_n_feat = previous_layer_n_feat
tdnn_layer = TDNN(context=layer_context, input_channels=input_n_feat, output_channels=512, full_context=False)

# Run a forward pass; batch.size() = [BATCH_SIZE, INPUT_CHANNELS, SEQUENCE_LENGTH]
out = tdnn_layer(batch)
```
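
## How it works
A symmetric, equally spaced context maps directly onto a single dilated convolution:
`kernel_size = len(context)` and `dilation` equals the spacing between consecutive
context elements. The snippet below (a minimal sketch that bypasses the `TDNN` class;
the channel and batch sizes are arbitrary choices) shows the `Conv1d` call that the
non-full context `[-6, -3, 0, 3, 6]` reduces to:

```python
import torch
import torch.nn as nn

context = [-6, -3, 0, 3, 6]
# one kernel tap per context element, spaced `dilation` time steps apart
conv = nn.Conv1d(in_channels=24, out_channels=512,
                 kernel_size=len(context), dilation=context[1] - context[0])

x = torch.randn(8, 24, 100)  # [batch_size, input_channels, sequence_length]
print(conv(x).shape)         # torch.Size([8, 512, 88]) -- receptive field of 13 steps
```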

## References
[\[1\] A. Waibel, T. Hanazawa, G. Hinton, and K. Shikano,
“Phoneme Recognition Using Time-Delay Neural Networks,” 1989](http://www.cs.toronto.edu/~fritz/absps/waibelTDNN.pdf)

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tdnn.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.utils import weight_norm

__author__ = 'Jonas Van Der Donckt'


class TDNN(nn.Module):
    def __init__(self, context: list, input_channels: int, output_channels: int, full_context: bool = True):
        """
        Implementation of a 'fast' TDNN layer that exploits the dilation argument of the PyTorch Conv1d class.

        This speed-up imposes two constraints on the context:
          * The context must be symmetric.
          * The context must have equal spacing between consecutive elements.

        For example, the non-full context {-3, -2, 0, +2, +3} is symmetric but not valid, since its
        spacing is not equal; the non-full context {-6, -3, 0, 3, 6} is both symmetric and equally
        spaced, and is therefore valid.

        :param context: The temporal context
        :param input_channels: The number of input channels
        :param output_channels: The number of channels produced by the temporal convolution
        :param full_context: Indicates whether a full context needs to be used
        """
        super(TDNN, self).__init__()
        self.full_context = full_context
        self.input_dim = input_channels
        self.output_dim = output_channels

        context = sorted(context)
        self.check_valid_context(context, full_context)

        if full_context:
            # a full context is a plain (dilation=1) convolution spanning the whole context window
            kernel_size = context[-1] - context[0] + 1 if len(context) > 1 else 1
            self.temporal_conv = weight_norm(nn.Conv1d(input_channels, output_channels, kernel_size))
        else:
            # a sparse context is realised through dilation: one kernel tap per context element
            delta = context[1] - context[0]
            self.temporal_conv = weight_norm(
                nn.Conv1d(input_channels, output_channels, kernel_size=len(context), dilation=delta))

    def forward(self, x):
        """
        :param x: one batch of data, x.size(): [batch_size, input_channels, sequence_length]
            sequence_length is the dimension of the arbitrary-length data
        :return: [batch_size, output_channels, out_sequence_length], where
            out_sequence_length = sequence_length - (kernel_size - 1) * dilation
        """
        return self.temporal_conv(x)

    @staticmethod
    def check_valid_context(context: list, full_context: bool) -> None:
        """
        Check whether the context is symmetric and whether it can be used to create
        a convolution kernel with the dilation argument.

        :param context: The context of the model; if no full context is used, it must be
            symmetric and equally spaced.
        :param full_context: Indicates whether the full context (dilation=1) will be used.
        """
        if full_context:
            assert len(context) <= 2, "If the full context is given, only define its smallest and largest element"
            if len(context) == 2:
                assert context[0] + context[-1] == 0, "The context must be symmetric"
        else:
            assert len(context) % 2 != 0, "The context size must be odd"
            assert context[len(context) // 2] == 0, "The context must contain 0 in its center"
            if len(context) > 1:
                delta = [context[i] - context[i - 1] for i in range(1, len(context))]
                assert all(delta[0] == delta[i] for i in range(1, len(delta))), "Intra-context spacing must be equal!"
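

# Usage sketch (illustrative only -- the sizes below are arbitrary choices):
if __name__ == '__main__':
    import torch

    # a symmetric, equally spaced context becomes a single dilated convolution
    layer = TDNN(context=[-2, 0, 2], input_channels=24, output_channels=512, full_context=False)
    batch = torch.randn(8, 24, 100)  # [batch_size, input_channels, sequence_length]
    print(layer(batch).shape)        # torch.Size([8, 512, 96])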
58 | """ 59 | if full_context: 60 | assert len(context) <= 2, "If the full context is given one must only define the smallest and largest" 61 | if len(context) == 2: 62 | assert context[0] + context[-1] == 0, "The context must be symmetric" 63 | else: 64 | assert len(context) % 2 != 0, "The context size must be odd" 65 | assert context[len(context) // 2] == 0, "The context contain 0 in the center" 66 | if len(context) > 1: 67 | delta = [context[i] - context[i - 1] for i in range(1, len(context))] 68 | assert all(delta[0] == delta[i] for i in range(1, len(delta))), "Intra context spacing must be equal!" 69 | -------------------------------------------------------------------------------- /tdnn_asym.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | 5 | __author__ = 'Jonas Van Der Donckt' 6 | 7 | 8 | class TDNN_ASYM(nn.Module): 9 | def __init__(self, input_channels: int, output_channels: int, context: list): 10 | """ 11 | Implementation of a TDNN layer which uses weight masking to create non symmetric convolutions 12 | 13 | :param input_channels: The number of input channels 14 | :param output_channels: The number of channels produced by the temporal convolution 15 | :param context: The temporal context 16 | """ 17 | super(TDNN_ASYM, self).__init__() 18 | 19 | # create the convolution mask 20 | self.conv_mask = self._create_conv_mask(context) 21 | 22 | # TDNN convolution 23 | self.temporal_conv = weight_norm(nn.Conv1d(input_channels, output_channels, 24 | kernel_size=self.conv_mask.size()[0])) 25 | 26 | # expand the mask and register a hook to zero gradients during backprop 27 | self.conv_mask = self.conv_mask.expand_as(self.temporal_conv.weight) 28 | self.temporal_conv.weight.register_backward_hook(lambda grad: grad * self.conv_mask) 29 | 30 | def forward(self, x): 31 | """ 32 | :param x: is one batch of data, x.size(): [batch_size, input_channels, sequence_length] 33 | sequence length is the dimension of the arbitrary length data 34 | :return: [batch_size, output_dim, sequence_length - kernel_size + 1] 35 | """ 36 | return self.temporal_conv(x) 37 | 38 | @staticmethod 39 | def _create_conv_mask(context: list) -> torch.Tensor: 40 | """ 41 | :param context: The temporal context 42 | TODO some more exlanation about the convolution 43 | :return: The convolutional mask 44 | """ 45 | context = sorted(context) 46 | min_pos, max_pos = context[0], context[-1] 47 | mask = torch.zeros(size=(max_pos - min_pos + 1,), dtype=torch.int8) 48 | context = list(map(lambda x: x-min_pos, context)) 49 | mask[context] = 1 50 | return mask 51 | --------------------------------------------------------------------------------