├── .gitignore
├── LICENSE
├── README.md
├── figures
│   ├── ftdnn.png
│   └── ftdnn_arch.png
└── models.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 cvqluu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Factorized-TDNN

PyTorch implementation of the Factorized TDNN from ["Semi-Orthogonal Low-Rank Matrix Factorization for Deep Neural Networks"](http://danielpovey.com/files/2018_interspeech_tdnnf.pdf) [1], also known as TDNN-F in nnet3 of [Kaldi](https://github.com/kaldi-asr/kaldi).

![model_fig](figures/ftdnn.png?raw=true "ftdnn diag") Taken from [1]

A TDNN-F layer is implemented in the class `FTDNNLayer` of `models.py`. In the terminology of [1], it implements the **"3-stage splicing"** scheme, in which three convolutions are applied in sequence, the first two of which are constrained to be semi-orthogonal. These convolutions are followed by a ReLU and then a BatchNorm layer. The semi-orthogonal constraint is the **"floating case"** in [1]. (TODO: implement the scaled case like in Kaldi)
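
For reference, one constraint step in the floating case (a compact restatement of what `get_semi_orth_weight` in `models.py` computes) amounts to the following update, where M is the convolution weight reshaped, and transposed if necessary, so that it has at least as many columns as rows, and the update speed ν defaults to 0.125:

```latex
P = M M^{\top}, \qquad
\beta^{2} = \frac{\operatorname{tr}(P P^{\top})}{\operatorname{tr}(P)}, \qquad
M \leftarrow M - \frac{4\nu}{\beta^{2}} \left( P - \beta^{2} I \right) M
```

This nudges `M Mᵀ` towards `β²I`, i.e. towards semi-orthogonality up to a scale that is left free to "float" (the code also halves ν when a trace-ratio check indicates the weights have drifted far from the constraint).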

# Usage

## `FTDNNLayer`

The `FTDNNLayer` of `models.py` is used as follows:

```python
import torch
from models import FTDNNLayer, SOrthConv

tdnn_f = FTDNNLayer(1280, 512, 256, context_size=2, dilations=[2, 2, 2], paddings=[1, 1, 1])
# This is a sequence of three 2x1 convolutions
# dimensions go from 1280 -> 256 -> 256 -> 512
# dilations and paddings control how much each convolution is dilated and padded
# These are configurable to ensure the sequence length stays the same

test_input = torch.rand(5, 100, 1280)
# inputs to the FTDNNLayer must be (batch_size, seq_len, in_dim)

tdnn_f(test_input).shape  # returns (5, 100, 512)

tdnn_f.step_semi_orth()  # The key method that constrains the first two convolutions; call after every SGD step

tdnn_f.orth_error()  # Returns the orthogonality error of the constrained convolutions, useful for debugging
```

## `SOrthConv`

The components of `FTDNNLayer` which carry the semi-orthogonal constraint are based around the class `SOrthConv`, which is essentially an `nn.Conv1d` with a `.step_semi_orth()` method that performs the semi-orthogonal update as in [1].

```python
sorth_conv = SOrthConv(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, padding_mode='zeros')
```

The implementation of the `.step_semi_orth()` method is intended to be as close as possible to `ConstrainOrthonormalInternal` from [nnet-utils.cc](https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc) in Kaldi's `nnet3` module.
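
Below is a minimal sketch of how `SOrthConv` slots into a training loop. It is not taken from this repo: the channel sizes, data, optimizer and loss are all placeholders.

```python
import torch
from models import SOrthConv

conv = SOrthConv(64, 32, 3, padding=1)  # placeholder channel sizes / kernel width
optimizer = torch.optim.SGD(conv.parameters(), lr=0.1)

for step in range(100):
    x = torch.rand(8, 64, 50)  # placeholder batch: (batch, in_channels, seq_len)
    y = conv(x)
    loss = (y - torch.randn_like(y)).pow(2).mean()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    conv.step_semi_orth()  # re-apply the semi-orthogonal update after every optimizer step

print('Orth error:', conv.orth_error())  # should be small once stepping has taken effect
```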
From [1]: "If, for instance, a dimension is zeroed on a particular frame it will be zeroed on all frames of that sequence". 53 | 54 | ![model_fig](figures/ftdnn_arch.png?raw=true "ftdnn arch") The FTDNN x-vector architecture description taken from [2]. Up until layer 12 is implemented in `FTDNN` in `models.py`. 55 | 56 | 57 | # Demo [WIP] 58 | 59 | An demonstration of the `FTDNN` model being trained can be seen in the following output log (code not included, TODO: basic experiment demo): 60 | 61 | ``` 62 | exp/sp_ftdnn_bl: Wed Nov 20 14:21:15 2019: [10/120000] C-Loss:21.9116, AvgLoss:21.6991, lr: 0.2, bs: 400 63 | Orth error: 22.44341427081963 64 | exp/sp_ftdnn_bl: Wed Nov 20 14:21:29 2019: [20/120000] C-Loss:21.6260, AvgLoss:21.7459, lr: 0.2, bs: 400 65 | Orth error: 8.235212338215206 66 | exp/sp_ftdnn_bl: Wed Nov 20 14:21:43 2019: [30/120000] C-Loss:21.7663, AvgLoss:21.7525, lr: 0.2, bs: 400 67 | Orth error: 1.2611256236341433 68 | exp/sp_ftdnn_bl: Wed Nov 20 14:21:56 2019: [40/120000] C-Loss:21.6153, AvgLoss:21.6527, lr: 0.2, bs: 400 69 | Orth error: 0.005309408872562926 70 | exp/sp_ftdnn_bl: Wed Nov 20 14:22:14 2019: [50/120000] C-Loss:21.0997, AvgLoss:21.5722, lr: 0.2, bs: 400 71 | Orth error: 0.005543942232179688 72 | exp/sp_ftdnn_bl: Wed Nov 20 14:22:26 2019: [60/120000] C-Loss:21.2629, AvgLoss:21.5222, lr: 0.2, bs: 400 73 | Orth error: 0.004769200691953301 74 | exp/sp_ftdnn_bl: Wed Nov 20 14:22:40 2019: [70/120000] C-Loss:20.9551, AvgLoss:21.4158, lr: 0.2, bs: 400 75 | Orth error: 0.006055477493646322 76 | exp/sp_ftdnn_bl: Wed Nov 20 14:22:56 2019: [80/120000] C-Loss:20.4425, AvgLoss:21.3274, lr: 0.2, bs: 400 77 | Orth error: 0.009634702852054033 78 | exp/sp_ftdnn_bl: Wed Nov 20 14:23:09 2019: [90/120000] C-Loss:21.0025, AvgLoss:21.2727, lr: 0.2, bs: 400 79 | Orth error: 0.00611297079740325 80 | exp/sp_ftdnn_bl: Wed Nov 20 14:23:25 2019: [100/120000] C-Loss:20.6145, AvgLoss:21.1736, lr: 0.2, bs: 400 81 | Orth error: 0.008151484609697945 82 | exp/sp_ftdnn_bl: Wed Nov 20 14:23:38 2019: [110/120000] C-Loss:20.1985, AvgLoss:21.0890, lr: 0.2, bs: 400 83 | Orth error: 0.0072971017434610985 84 | exp/sp_ftdnn_bl: Wed Nov 20 14:23:53 2019: [120/120000] C-Loss:20.5698, AvgLoss:21.0300, lr: 0.2, bs: 400 85 | Orth error: 0.00629939052669215 86 | exp/sp_ftdnn_bl: Wed Nov 20 14:24:08 2019: [130/120000] C-Loss:20.2024, AvgLoss:20.9425, lr: 0.2, bs: 400 87 | Orth error: 0.008707787481398555 88 | exp/sp_ftdnn_bl: Wed Nov 20 14:24:21 2019: [140/120000] C-Loss:19.7034, AvgLoss:20.8641, lr: 0.2, bs: 400 89 | Orth error: 0.010941843771433923 90 | exp/sp_ftdnn_bl: Wed Nov 20 14:24:37 2019: [150/120000] C-Loss:19.9718, AvgLoss:20.8035, lr: 0.2, bs: 400 91 | Orth error: 0.00768740743296803 92 | ``` 93 | 94 | The FTDNN x-vector architecture seems to train successfully, and most importantly the Orth error is minimized. 95 | 96 | # TODOs 97 | 98 | * Implement 'scaled' case of semi-orthogonal constraint 99 | * Refactor so that seq_len is final dim (or not?) 100 | * Simple experiment/toy demo 101 | 102 | # References 103 | 104 | ``` 105 | [1] 106 | @inproceedings{Povey2018, 107 | author={Daniel Povey and Gaofeng Cheng and Yiming Wang and Ke Li and Hainan Xu and Mahsa Yarmohammadi and Sanjeev Khudanpur}, 108 | title={Semi-Orthogonal Low-Rank Matrix Factorization for Deep Neural Networks}, 109 | year=2018, 110 | booktitle={Proc. 

![model_fig](figures/ftdnn_arch.png?raw=true "ftdnn arch") The FTDNN x-vector architecture description taken from [2]. Everything up to layer 12 is implemented in `FTDNN` in `models.py`.

# Demo [WIP]

A demonstration of the `FTDNN` model being trained can be seen in the following output log (code not included, TODO: basic experiment demo):

```
exp/sp_ftdnn_bl: Wed Nov 20 14:21:15 2019: [10/120000] C-Loss:21.9116, AvgLoss:21.6991, lr: 0.2, bs: 400
Orth error: 22.44341427081963
exp/sp_ftdnn_bl: Wed Nov 20 14:21:29 2019: [20/120000] C-Loss:21.6260, AvgLoss:21.7459, lr: 0.2, bs: 400
Orth error: 8.235212338215206
exp/sp_ftdnn_bl: Wed Nov 20 14:21:43 2019: [30/120000] C-Loss:21.7663, AvgLoss:21.7525, lr: 0.2, bs: 400
Orth error: 1.2611256236341433
exp/sp_ftdnn_bl: Wed Nov 20 14:21:56 2019: [40/120000] C-Loss:21.6153, AvgLoss:21.6527, lr: 0.2, bs: 400
Orth error: 0.005309408872562926
exp/sp_ftdnn_bl: Wed Nov 20 14:22:14 2019: [50/120000] C-Loss:21.0997, AvgLoss:21.5722, lr: 0.2, bs: 400
Orth error: 0.005543942232179688
exp/sp_ftdnn_bl: Wed Nov 20 14:22:26 2019: [60/120000] C-Loss:21.2629, AvgLoss:21.5222, lr: 0.2, bs: 400
Orth error: 0.004769200691953301
exp/sp_ftdnn_bl: Wed Nov 20 14:22:40 2019: [70/120000] C-Loss:20.9551, AvgLoss:21.4158, lr: 0.2, bs: 400
Orth error: 0.006055477493646322
exp/sp_ftdnn_bl: Wed Nov 20 14:22:56 2019: [80/120000] C-Loss:20.4425, AvgLoss:21.3274, lr: 0.2, bs: 400
Orth error: 0.009634702852054033
exp/sp_ftdnn_bl: Wed Nov 20 14:23:09 2019: [90/120000] C-Loss:21.0025, AvgLoss:21.2727, lr: 0.2, bs: 400
Orth error: 0.00611297079740325
exp/sp_ftdnn_bl: Wed Nov 20 14:23:25 2019: [100/120000] C-Loss:20.6145, AvgLoss:21.1736, lr: 0.2, bs: 400
Orth error: 0.008151484609697945
exp/sp_ftdnn_bl: Wed Nov 20 14:23:38 2019: [110/120000] C-Loss:20.1985, AvgLoss:21.0890, lr: 0.2, bs: 400
Orth error: 0.0072971017434610985
exp/sp_ftdnn_bl: Wed Nov 20 14:23:53 2019: [120/120000] C-Loss:20.5698, AvgLoss:21.0300, lr: 0.2, bs: 400
Orth error: 0.00629939052669215
exp/sp_ftdnn_bl: Wed Nov 20 14:24:08 2019: [130/120000] C-Loss:20.2024, AvgLoss:20.9425, lr: 0.2, bs: 400
Orth error: 0.008707787481398555
exp/sp_ftdnn_bl: Wed Nov 20 14:24:21 2019: [140/120000] C-Loss:19.7034, AvgLoss:20.8641, lr: 0.2, bs: 400
Orth error: 0.010941843771433923
exp/sp_ftdnn_bl: Wed Nov 20 14:24:37 2019: [150/120000] C-Loss:19.9718, AvgLoss:20.8035, lr: 0.2, bs: 400
Orth error: 0.00768740743296803
```

The FTDNN x-vector architecture seems to train successfully, and, most importantly, the orthogonality error is minimized.
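
Until a proper demo is included, a single training step with `FTDNN` might look like the sketch below. The batch, optimizer and loss here are placeholders, not the setup that produced the log above.

```python
import torch
from models import FTDNN

model = FTDNN(in_dim=30)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
model.set_dropout_alpha(0.1)  # optional: turn on the shared-dim scaled dropout

feats = torch.rand(4, 200, 30)  # placeholder batch: (batch, seq_len, in_dim)
embeds = model(feats)  # (4, 512) embeddings
loss = embeds.pow(2).mean()  # placeholder loss; a real setup would add a classification head

optimizer.zero_grad()
loss.backward()
optimizer.step()
model.step_ftdnn_layers()  # semi-orthogonal step for every FTDNNLayer
print('Orth error:', model.get_orth_errors())
```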

# TODOs

* Implement the 'scaled' case of the semi-orthogonal constraint
* Refactor so that seq_len is the final dim (or not?)
* Simple experiment/toy demo

# References

```
[1]
@inproceedings{Povey2018,
  author={Daniel Povey and Gaofeng Cheng and Yiming Wang and Ke Li and Hainan Xu and Mahsa Yarmohammadi and Sanjeev Khudanpur},
  title={Semi-Orthogonal Low-Rank Matrix Factorization for Deep Neural Networks},
  year=2018,
  booktitle={Proc. Interspeech 2018},
  pages={3743--3747},
  doi={10.21437/Interspeech.2018-1417},
  url={http://dx.doi.org/10.21437/Interspeech.2018-1417}
}
```

```
[2]
@article{VILLALBA2020101026,
  title = "State-of-the-art speaker recognition with neural network embeddings in NIST SRE18 and Speakers in the Wild evaluations",
  journal = "Computer Speech & Language",
  volume = "60",
  pages = "101026",
  year = "2020",
  issn = "0885-2308",
  doi = "https://doi.org/10.1016/j.csl.2019.101026",
  url = "http://www.sciencedirect.com/science/article/pii/S0885230819302700",
  author = "Jesús Villalba and Nanxin Chen and David Snyder and Daniel Garcia-Romero and Alan McCree and Gregory Sell and Jonas Borgstrom and Leibny Paola García-Perera and Fred Richardson and Réda Dehak and Pedro A. Torres-Carrasquillo and Najim Dehak"
}
```
--------------------------------------------------------------------------------
/figures/ftdnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvqluu/Factorized-TDNN/13506c3cd3c3b6b41de560d2d2ca2ef09cf96438/figures/ftdnn.png
--------------------------------------------------------------------------------
/figures/ftdnn_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvqluu/Factorized-TDNN/13506c3cd3c3b6b41de560d2d2ca2ef09cf96438/figures/ftdnn_arch.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SOrthConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, padding_mode='zeros'):
        '''
        Conv1d with a method for stepping towards semi-orthogonality
        http://danielpovey.com/files/2018_interspeech_tdnnf.pdf
        '''
        super(SOrthConv, self).__init__()

        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              bias=False, padding_mode=padding_mode)
        self.reset_parameters()

    def forward(self, x):
        x = self.conv(x)
        return x

    def step_semi_orth(self):
        with torch.no_grad():
            M = self.get_semi_orth_weight(self.conv)
            self.conv.weight.copy_(M)

    def reset_parameters(self):
        # Standard dev of M init values is inverse of sqrt of num cols
        nn.init.normal_(self.conv.weight, 0.,
                        self.get_M_shape(self.conv.weight)[1]**-0.5)

    def orth_error(self):
        return self.get_semi_orth_error(self.conv).item()

    @staticmethod
    def get_semi_orth_weight(conv1dlayer):
        # updates conv1 weight M using update rule to make it more semi orthogonal
        # based off ConstrainOrthonormalInternal in nnet-utils.cc in Kaldi src/nnet3
        # includes the tweaks related to slowing the update speed
        # only an implementation of the 'floating scale' case
        with torch.no_grad():
            update_speed = 0.125
            orig_shape = conv1dlayer.weight.shape
            # a conv weight differs slightly from TDNN formulation:
            # Conv weight: (out_filters, in_filters, kernel_width)
            # TDNN weight M is of shape: (in_dim, out_dim) or [rows, cols]
            # the in_dim of the TDNN weight is equivalent to in_filters * kernel_width of the Conv
            M = conv1dlayer.weight.reshape(
                orig_shape[0], orig_shape[1]*orig_shape[2]).T
            # M now has shape (in_dim[rows], out_dim[cols])
            mshape = M.shape
            if mshape[0] > mshape[1]:  # semi orthogonal constraint for rows > cols
                M = M.T
            P = torch.mm(M, M.T)
            PP = torch.mm(P, P.T)
            trace_P = torch.trace(P)
            trace_PP = torch.trace(PP)
            ratio = trace_PP * P.shape[0] / (trace_P * trace_P)

            # the following is the tweak to avoid divergence (more info in Kaldi)
            assert ratio > 0.99
            if ratio > 1.02:
                update_speed *= 0.5
                if ratio > 1.1:
                    update_speed *= 0.5

            scale2 = trace_PP/trace_P
            update = P - (torch.matrix_power(P, 0) * scale2)
            alpha = update_speed / scale2
            update = (-4.0 * alpha) * torch.mm(update, M)
            updated = M + update
            # updated has shape (cols, rows) if rows > cols, else has shape (rows, cols)
            # Transpose (or not) to shape (cols, rows) (IMPORTANT, s.t. correct dimensions are reshaped)
            # Then reshape to (cols, in_filters, kernel_width)
            return updated.reshape(*orig_shape) if mshape[0] > mshape[1] else updated.T.reshape(*orig_shape)

    @staticmethod
    def get_M_shape(conv_weight):
        orig_shape = conv_weight.shape
        return (orig_shape[1]*orig_shape[2], orig_shape[0])

    @staticmethod
    def get_semi_orth_error(conv1dlayer):
        with torch.no_grad():
            orig_shape = conv1dlayer.weight.shape
            M = conv1dlayer.weight.reshape(
                orig_shape[0], orig_shape[1]*orig_shape[2]).T
            mshape = M.shape
            if mshape[0] > mshape[1]:  # semi orthogonal constraint for rows > cols
                M = M.T
            P = torch.mm(M, M.T)
            PP = torch.mm(P, P.T)
            trace_P = torch.trace(P)
            trace_PP = torch.trace(PP)
            scale2 = torch.sqrt(trace_PP/trace_P) ** 2
            update = P - (torch.matrix_power(P, 0) * scale2)
            return torch.norm(update, p='fro')

class SharedDimScaleDropout(nn.Module):
    def __init__(self, alpha: float = 0.5, dim=1):
        '''
        Continuous scaled dropout that is const over chosen dim (usually across time)
        Multiplies inputs by random mask taken from Uniform([1 - 2*alpha, 1 + 2*alpha])
        '''
        super(SharedDimScaleDropout, self).__init__()
        if alpha > 0.5 or alpha < 0:
            raise ValueError("alpha must be between 0 and 0.5")
        self.alpha = alpha
        self.dim = dim
        self.register_buffer('mask', torch.tensor(0.))

    def forward(self, X):
        if self.training:
            if self.alpha != 0.:
                # sample mask from uniform dist with dim of length 1 in self.dim and then repeat to match size
                tied_mask_shape = list(X.shape)
                tied_mask_shape[self.dim] = 1
                repeats = [1 if i != self.dim else X.shape[self.dim]
                           for i in range(len(X.shape))]
                return X * self.mask.repeat(tied_mask_shape).uniform_(1 - 2*self.alpha, 1 + 2*self.alpha).repeat(repeats)
        # expected value of dropout mask is 1 so no need to scale outputs like vanilla dropout
        return X


class FTDNNLayer(nn.Module):

    def __init__(self, in_dim, out_dim, bottleneck_dim, context_size=2, dilations=None, paddings=None, alpha=0.0):
        '''
        3 stage factorised TDNN http://danielpovey.com/files/2018_interspeech_tdnnf.pdf
        '''
        super(FTDNNLayer, self).__init__()
        paddings = [1, 1, 1] if not paddings else paddings
        dilations = [2, 2, 2] if not dilations else dilations
        assert len(paddings) == 3
        assert len(dilations) == 3
        self.factor1 = SOrthConv(
            in_dim, bottleneck_dim, context_size, padding=paddings[0], dilation=dilations[0])
        self.factor2 = SOrthConv(bottleneck_dim, bottleneck_dim,
                                 context_size, padding=paddings[1], dilation=dilations[1])
        self.factor3 = nn.Conv1d(bottleneck_dim, out_dim, context_size,
                                 padding=paddings[2], dilation=dilations[2], bias=False)
        self.nl = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_dim)
        self.dropout = SharedDimScaleDropout(alpha=alpha, dim=1)

    def forward(self, x):
        ''' input (batch_size, seq_len, in_dim) '''
        assert (x.shape[-1] == self.factor1.conv.weight.shape[1])
        x = self.factor1(x.transpose(1, 2))
        x = self.factor2(x)
        x = self.factor3(x)
        x = self.nl(x)
        x = self.bn(x).transpose(1, 2)
        x = self.dropout(x)
        return x

    def step_semi_orth(self):
        for layer in self.children():
            if isinstance(layer, SOrthConv):
                layer.step_semi_orth()

    def orth_error(self):
        orth_error = 0
        for layer in self.children():
            if isinstance(layer, SOrthConv):
                orth_error += layer.orth_error()
        return orth_error


class DenseReLU(nn.Module):

    def __init__(self, in_dim, out_dim):
        super(DenseReLU, self).__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.nl = nn.ReLU()

    def forward(self, x):
        x = self.fc(x)
        x = self.nl(x)
        if len(x.shape) > 2:
            x = self.bn(x.transpose(1, 2)).transpose(1, 2)
        else:
            x = self.bn(x)
        return x


class StatsPool(nn.Module):

    def __init__(self, floor=1e-10, bessel=False):
        super(StatsPool, self).__init__()
        self.floor = floor
        self.bessel = bessel

    def forward(self, x):
        means = torch.mean(x, dim=1)
        _, t, _ = x.shape
        if self.bessel:
            t = t - 1
        residuals = x - means.unsqueeze(1)
        numerator = torch.sum(residuals**2, dim=1)
        stds = torch.sqrt(torch.clamp(numerator, min=self.floor)/t)
        x = torch.cat([means, stds], dim=1)
        return x


class TDNN(nn.Module):

    def __init__(
        self,
        input_dim=23,
        output_dim=512,
        context_size=5,
        stride=1,
        dilation=1,
        batch_norm=True,
        dropout_p=0.0,
        padding=0
    ):
        super(TDNN, self).__init__()
        self.context_size = context_size
        self.stride = stride
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.dropout_p = dropout_p
        self.padding = padding

        self.kernel = nn.Conv1d(self.input_dim,
                                self.output_dim,
                                self.context_size,
                                stride=self.stride,
                                padding=self.padding,
                                dilation=self.dilation)

        self.nonlinearity = nn.ReLU()
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = nn.BatchNorm1d(output_dim)
        self.drop = nn.Dropout(p=self.dropout_p)

    def forward(self, x):
        '''
        input: size (batch, seq_len, input_features)
        output: size (batch, new_seq_len, output_features)
        '''

        _, _, d = x.shape
        assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(
            self.input_dim, d)

        x = self.kernel(x.transpose(1, 2))
        x = self.nonlinearity(x)
        x = self.drop(x)

        if self.batch_norm:
            x = self.bn(x)
        return x.transpose(1, 2)

class FTDNN(nn.Module):

    def __init__(self, in_dim=30):
        '''
        The FTDNN architecture from
        "State-of-the-art speaker recognition with neural network embeddings in
        NIST SRE18 and Speakers in the Wild evaluations"
        https://www.sciencedirect.com/science/article/pii/S0885230819302700
        '''
        super(FTDNN, self).__init__()

        self.layer01 = TDNN(input_dim=in_dim, output_dim=512,
                            context_size=5, padding=2)
        self.layer02 = FTDNNLayer(512, 1024, 256, context_size=2, dilations=[2, 2, 2], paddings=[1, 1, 1])
        self.layer03 = FTDNNLayer(1024, 1024, 256, context_size=1, dilations=[1, 1, 1], paddings=[0, 0, 0])
        self.layer04 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3, 3, 2], paddings=[2, 1, 1])
        self.layer05 = FTDNNLayer(2048, 1024, 256, context_size=1, dilations=[1, 1, 1], paddings=[0, 0, 0])
        self.layer06 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3, 3, 2], paddings=[2, 1, 1])
        self.layer07 = FTDNNLayer(3072, 1024, 256, context_size=2, dilations=[3, 3, 2], paddings=[2, 1, 1])
        self.layer08 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3, 3, 2], paddings=[2, 1, 1])
        self.layer09 = FTDNNLayer(3072, 1024, 256, context_size=1, dilations=[1, 1, 1], paddings=[0, 0, 0])
        self.layer10 = DenseReLU(1024, 2048)

        self.layer11 = StatsPool()

        self.layer12 = DenseReLU(4096, 512)

    def forward(self, x):
        '''
        Input must be (batch_size, seq_len, in_dim)
        '''
        x = self.layer01(x)
        x_2 = self.layer02(x)
        x_3 = self.layer03(x_2)
        x_4 = self.layer04(x_3)
        skip_5 = torch.cat([x_4, x_3], dim=-1)
        x = self.layer05(skip_5)
        x_6 = self.layer06(x)
        skip_7 = torch.cat([x_6, x_4, x_2], dim=-1)
        x = self.layer07(skip_7)
        x_8 = self.layer08(x)
        skip_9 = torch.cat([x_8, x_6, x_4], dim=-1)
        x = self.layer09(skip_9)
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)
        return x

    def step_ftdnn_layers(self):
        for layer in self.children():
            if isinstance(layer, FTDNNLayer):
                layer.step_semi_orth()

    def set_dropout_alpha(self, alpha):
        for layer in self.children():
            if isinstance(layer, FTDNNLayer):
                layer.dropout.alpha = alpha

    def get_orth_errors(self):
        errors = 0.
        with torch.no_grad():
            for layer in self.children():
                if isinstance(layer, FTDNNLayer):
                    errors += layer.orth_error()
        return errors
--------------------------------------------------------------------------------