├── .gitignore ├── README.md ├── docs ├── mfa_conformer.png └── results.png ├── loss ├── __init__.py ├── amsoftmax.py ├── softmax.py └── utils.py ├── main.py ├── module ├── __init__.py ├── _pooling.py ├── augment.py ├── conformer.py ├── conformer_cat.py ├── conformer_weight.py ├── dataset.py ├── ecapa_tdnn.py ├── feature.py ├── loader.py ├── resnet.py ├── transformer_cat.py └── utils.py ├── score ├── __init__.py ├── cosine.py └── utils.py ├── scripts ├── build_datalist.py ├── format_trials.py ├── make_balanced_data.py ├── make_cohort_set.py ├── make_tsne_set.py └── plot_score.py ├── start.sh └── wenet ├── transformer ├── attention.py ├── cmvn.py ├── convolution.py ├── embedding.py ├── encoder.py ├── encoder_cat.py ├── encoder_layer.py ├── encoder_weight.py ├── label_smoothing_loss.py ├── positionwise_feed_forward.py ├── subsampling.py └── swish.py └── utils ├── checkpoint.py ├── cmvn.py ├── common.py ├── ctc_util.py ├── executor.py ├── mask.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | config.json 2 | experiment 3 | data 4 | test.sh 5 | meta 6 | *.wav 7 | lightning_logs 8 | *.ckpt 9 | *.pt 10 | *.lst 11 | *.txt 12 | data/ 13 | *.onnx 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MFA-Conformer 2 | 3 | This repository contains the training code accompanying the paper "MFA-Conformer: Multi-scale Feature Aggregation Conformer for Automatic Speaker Verification", which is submitted to Interspeech 2022. 4 | 5 |
![MFA-Conformer architecture](docs/mfa_conformer.png)
6 | 7 | The architecture of the MFA-Conformer is inspired by recent state-of-the-art models in speech recognition and speaker verification. Firstly, we introduce a convolution subsampling layer to decrease the computational cost of the model. Secondly, we adopt Conformer blocks, which combine Transformers and convolutional neural networks (CNNs), to capture global and local features effectively. Finally, the output feature maps from all Conformer blocks are concatenated to aggregate multi-scale representations before final pooling. The best system obtains 0.64%, 1.29%, and 1.63% EER on the VoxCeleb1-O, SITW.Dev, and SITW.Eval sets, respectively. 8 | 9 | ## Data Preparation 10 | 11 | * [VoxCeleb 1&2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/) 12 | * [SITW](http://www.speech.sri.com/projects/sitw/) 13 | 14 | ```bash 15 | # format the VoxCeleb test trial list 16 | rm -rf data; mkdir data 17 | wget -P data/ https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt 18 | python3 scripts/format_trials.py \ 19 | --voxceleb1_root $voxceleb1_dir \ 20 | --src_trials_path data/veri_test2.txt \ 21 | --dst_trials_path data/vox1_test.txt 22 | 23 | # make a csv for the VoxCeleb 1&2 dev audio (train_dir) 24 | python3 scripts/build_datalist.py \ 25 | --extension wav \ 26 | --dataset_dir data/$train_dir \ 27 | --data_list_path data/train.csv 28 | ``` 29 | 30 | ## Model Training 31 | 32 | ```bash 33 | python3 main.py \ 34 | --batch_size 200 \ 35 | --num_workers 40 \ 36 | --max_epochs 30 \ 37 | --embedding_dim $embedding_dim \ 38 | --save_dir $save_dir \ 39 | --encoder_name $encoder_name \ 40 | --train_csv_path $train_csv_path \ 41 | --learning_rate 0.001 \ 42 | --encoder_name ${encoder_name} \ 43 | --num_classes $num_classes \ 44 | --trial_path $trial_path \ 45 | --loss_name $loss_name \ 46 | --num_blocks $num_blocks \ 47 | --step_size 4 \ 48 | --gamma 0.5 \ 49 | --weight_decay 0.0000001 \ 50 | --input_layer $input_layer \ 51 | --pos_enc_layer_type $pos_enc_layer_type 52 | ``` 53 | 54 | ## Results 55 | 56 | The training results of the default configuration are presented below (VoxCeleb1-test): 57 | 58 |
![Results on VoxCeleb1-test](docs/results.png)
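To score a trained checkpoint on the trial list, `main.py` can be run with its `--eval` flag. The command below is a minimal sketch that assumes a checkpoint saved under `$save_dir` and the same encoder settings that were used for training:

```bash
python3 main.py \
    --eval \
    --checkpoint_path $checkpoint_path \
    --encoder_name $encoder_name \
    --embedding_dim $embedding_dim \
    --num_classes $num_classes \
    --num_blocks $num_blocks \
    --input_layer $input_layer \
    --pos_enc_layer_type $pos_enc_layer_type \
    --trial_path data/vox1_test.txt \
    --save_dir $save_dir
```

This reports the cosine-scoring EER and minDCF on the trial list given by `--trial_path`.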
59 | 60 | ## Others 61 | 62 | In addition, here are some tips that might be useful: 63 | 64 | 1. **The Conformer block**: We borrow a lot of code from the [WeNet](https://github.com/wenet-e2e/wenet) toolkit. 65 | 2. **Average the checkpoint weights**: When the model training is done, we average the parameters of the last 3~10 checkpoints to generate a new checkpoint. The averaged checkpoint usually achieves better recognition performance (see the averaging sketch after `main.py` below). 66 | 3. **Warmup**: We apply a linear learning-rate warmup over the first 2k training steps, and we find that this warmup procedure is very helpful for model training. 67 | 4. **AS-norm**: Adaptive score normalization (AS-norm) is a common trick for speaker recognition. In our experiments, it leads to a 5%-10% relative improvement in EER. 68 | 69 | ## Citation 70 | 71 | If you find this code useful for your research, please cite our paper. 72 | 73 | ``` 74 | @article{zhang2022mfa, 75 | title={MFA-Conformer: Multi-scale Feature Aggregation Conformer for Automatic Speaker Verification}, 76 | author={Zhang, Yang and Lv, Zhiqiang and Wu, Haibin and Zhang, Shanshan and Hu, Pengfei and Wu, Zhiyong and Lee, Hung-yi and Meng, Helen}, 77 | journal={arXiv preprint arXiv:2203.15249}, 78 | year={2022} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /docs/mfa_conformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/docs/mfa_conformer.png -------------------------------------------------------------------------------- /docs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/docs/results.png -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Yang Zhang. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .softmax import softmax 16 | from .amsoftmax import amsoftmax 17 | 18 | -------------------------------------------------------------------------------- /loss/amsoftmax.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | # Adapted from https://github.com/CoinCheung/pytorch-loss (MIT License) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from .utils import accuracy 9 | 10 | class amsoftmax(nn.Module): 11 | def __init__(self, embedding_dim, num_classes, margin=0.2, scale=30, **kwargs): 12 | super(amsoftmax, self).__init__() 13 | 14 | self.m = margin 15 | self.s = scale 16 | self.in_feats = embedding_dim 17 | self.W = torch.nn.Parameter(torch.randn(embedding_dim, num_classes), requires_grad=True) 18 | self.ce = nn.CrossEntropyLoss() 19 | nn.init.xavier_normal_(self.W, gain=1) 20 | 21 | print('Initialised AM-Softmax m=%.3f s=%.3f'%(self.m, self.s)) 22 | print('Embedding dim is {}, number of speakers is {}'.format(embedding_dim, num_classes)) 23 | 24 | def forward(self, x, label=None): 25 | assert x.size()[0] == label.size()[0] 26 | assert x.size()[1] == self.in_feats 27 | 28 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 29 | x_norm = torch.div(x, x_norm) 30 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 31 | w_norm = torch.div(self.W, w_norm) 32 | costh = torch.mm(x_norm, w_norm) 33 | label_view = label.view(-1, 1) 34 | if label_view.is_cuda: label_view = label_view.cpu() 35 | delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m) 36 | if x.is_cuda: delt_costh = delt_costh.cuda() 37 | costh_m = costh - delt_costh 38 | costh_m_s = self.s * costh_m 39 | loss = self.ce(costh_m_s, label) 40 | acc = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0] 41 | return loss, acc 42 | 43 | -------------------------------------------------------------------------------- /loss/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .utils import accuracy 7 | 8 | class softmax(nn.Module): 9 | def __init__(self, embedding_dim, num_classes, **kwargs): 10 | super(softmax, self).__init__() 11 | self.embedding_dim = embedding_dim 12 | self.fc = nn.Linear(embedding_dim, num_classes) 13 | self.criertion = nn.CrossEntropyLoss() 14 | 15 | print('init softmax') 16 | print('Embedding dim is {}, number of speakers is {}'.format(embedding_dim, num_classes)) 17 | 18 | def forward(self, x, label=None): 19 | assert x.size()[0] == label.size()[0] 20 | assert x.size()[1] == self.embedding_dim 21 | 22 | x = F.normalize(x, dim=1) 23 | x = self.fc(x) 24 | loss = self.criertion(x, label) 25 | acc1 = accuracy(x.detach(), label.detach(), topk=(1,))[0] 26 | return loss, acc1 27 | 28 | 29 | if __name__ == "__main__": 30 | model = softmax(10, 100) 31 | data = torch.randn((2, 10)) 32 | label = torch.tensor([0, 1]) 33 | loss, acc = model(data, label) 34 | 35 | print(data.shape) 36 | print(loss) 37 | print(acc) 38 | 39 | -------------------------------------------------------------------------------- /loss/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def accuracy(output, target, topk=(1,)): 9 | """Computes the precision@k for the specified values of k""" 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | res = [] 18 | for k in topk: 19 | 
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | return res 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from copy import deepcopy 3 | from typing import Any, Union 4 | import torch.distributed as dist 5 | from pytorch_lightning.plugins import DDPPlugin 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | from pytorch_lightning import LightningModule, Trainer, seed_everything 13 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 14 | from torch.nn import functional as F 15 | from torch.optim import AdamW 16 | from torch.optim.lr_scheduler import StepLR, CyclicLR 17 | 18 | from module.feature import Mel_Spectrogram 19 | from module.loader import SPK_datamodule 20 | import score as score 21 | from loss import softmax, amsoftmax 22 | 23 | class Task(LightningModule): 24 | def __init__( 25 | self, 26 | learning_rate: float = 0.2, 27 | weight_decay: float = 1.5e-6, 28 | batch_size: int = 32, 29 | num_workers: int = 10, 30 | max_epochs: int = 1000, 31 | trial_path: str = "data/vox1_test.txt", 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self.save_hyperparameters() 36 | self.trials = np.loadtxt(self.hparams.trial_path, str) 37 | self.mel_trans = Mel_Spectrogram() 38 | 39 | from module.resnet import resnet34, resnet18, resnet34_large 40 | from module.ecapa_tdnn import ecapa_tdnn, ecapa_tdnn_large 41 | from module.transformer_cat import transformer_cat 42 | from module.conformer import conformer 43 | from module.conformer_cat import conformer_cat 44 | from module.conformer_weight import conformer_weight 45 | 46 | if self.hparams.encoder_name == "resnet18": 47 | self.encoder = resnet18(embedding_dim=self.hparams.embedding_dim) 48 | 49 | elif self.hparams.encoder_name == "resnet34": 50 | self.encoder = resnet34_large(embedding_dim=self.hparams.embedding_dim) 51 | 52 | elif self.hparams.encoder_name == "ecapa_tdnn": 53 | self.encoder = ecapa_tdnn(embedding_dim=self.hparams.embedding_dim) 54 | 55 | elif self.hparams.encoder_name == "ecapa_tdnn_large": 56 | self.encoder = ecapa_tdnn_large(embedding_dim=self.hparams.embedding_dim) 57 | 58 | elif self.hparams.encoder_name == "conformer": 59 | print("num_blocks is {}".format(self.hparams.num_blocks)) 60 | self.encoder = conformer(embedding_dim=self.hparams.embedding_dim, 61 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer) 62 | 63 | elif self.hparams.encoder_name == "transformer_cat": 64 | print("num_blocks is {}".format(self.hparams.num_blocks)) 65 | self.encoder = transformer_cat(embedding_dim=self.hparams.embedding_dim, 66 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer) 67 | 68 | elif self.hparams.encoder_name == "conformer_cat": 69 | print("num_blocks is {}".format(self.hparams.num_blocks)) 70 | self.encoder = conformer_cat(embedding_dim=self.hparams.embedding_dim, 71 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer, 72 | pos_enc_layer_type=self.hparams.pos_enc_layer_type) 73 | 74 | elif self.hparams.encoder_name == "conformer_weight": 75 | print("num_blocks is {}".format(self.hparams.num_blocks)) 76 | self.encoder = conformer_weight(embedding_dim=self.hparams.embedding_dim, 77 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer) 
78 | 79 | else: 80 | raise ValueError("encoder name error") 81 | 82 | if self.hparams.loss_name == "amsoftmax": 83 | self.loss_fun = amsoftmax(embedding_dim=self.hparams.embedding_dim, num_classes=self.hparams.num_classes) 84 | else: 85 | self.loss_fun = softmax(embedding_dim=self.hparams.embedding_dim, num_classes=self.hparams.num_classes) 86 | 87 | def forward(self, x): 88 | feature = self.mel_trans(x) 89 | embedding = self.encoder(feature) 90 | return embedding 91 | 92 | def training_step(self, batch, batch_idx): 93 | waveform, label = batch 94 | feature = self.mel_trans(waveform) 95 | embedding = self.encoder(feature) 96 | loss, acc = self.loss_fun(embedding, label) 97 | self.log('train_loss', loss, prog_bar=True) 98 | self.log('acc', acc, prog_bar=True) 99 | return loss 100 | 101 | def on_test_epoch_start(self): 102 | return self.on_validation_epoch_start() 103 | 104 | def on_validation_epoch_start(self): 105 | self.index_mapping = {} 106 | self.eval_vectors = [] 107 | 108 | def test_step(self, batch, batch_idx): 109 | self.validation_step(batch, batch_idx) 110 | 111 | def validation_step(self, batch, batch_idx): 112 | x, path = batch 113 | path = path[0] 114 | with torch.no_grad(): 115 | x = self.mel_trans(x) 116 | self.encoder.eval() 117 | x = self.encoder(x) 118 | x = x.detach().cpu().numpy()[0] 119 | self.eval_vectors.append(x) 120 | self.index_mapping[path] = batch_idx 121 | 122 | def test_epoch_end(self, outputs): 123 | return self.validation_epoch_end(outputs) 124 | 125 | def validation_epoch_end(self, outputs): 126 | num_gpus = torch.cuda.device_count() 127 | eval_vectors = [None for _ in range(num_gpus)] 128 | dist.all_gather_object(eval_vectors, self.eval_vectors) 129 | eval_vectors = np.vstack(eval_vectors) 130 | 131 | table = [None for _ in range(num_gpus)] 132 | dist.all_gather_object(table, self.index_mapping) 133 | 134 | index_mapping = {} 135 | for i in table: 136 | index_mapping.update(i) 137 | 138 | eval_vectors = eval_vectors - np.mean(eval_vectors, axis=0) 139 | labels, scores = score.cosine_score( 140 | self.trials, index_mapping, eval_vectors) 141 | EER, threshold = score.compute_eer(labels, scores) 142 | 143 | print("\ncosine EER: {:.2f}% with threshold {:.2f}".format(EER*100, threshold)) 144 | self.log("cosine_eer", EER*100) 145 | 146 | minDCF, threshold = score.compute_minDCF(labels, scores, p_target=0.01) 147 | print("cosine minDCF(10-2): {:.2f} with threshold {:.2f}".format(minDCF, threshold)) 148 | self.log("cosine_minDCF(10-2)", minDCF) 149 | 150 | minDCF, threshold = score.compute_minDCF(labels, scores, p_target=0.001) 151 | print("cosine minDCF(10-3): {:.2f} with threshold {:.2f}".format(minDCF, threshold)) 152 | self.log("cosine_minDCF(10-3)", minDCF) 153 | 154 | 155 | def configure_optimizers(self): 156 | optimizer = torch.optim.Adam( 157 | self.parameters(), 158 | self.hparams.learning_rate, 159 | weight_decay=self.hparams.weight_decay 160 | ) 161 | scheduler = StepLR(optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma) 162 | return [optimizer], [scheduler] 163 | 164 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, 165 | optimizer_closure, on_tpu, using_native_amp, using_lbfgs): 166 | # warm up learning_rate 167 | if self.trainer.global_step < self.hparams.warmup_step: 168 | lr_scale = min(1., float(self.trainer.global_step + 169 | 1) / float(self.hparams.warmup_step)) 170 | for idx, pg in enumerate(optimizer.param_groups): 171 | pg['lr'] = lr_scale * self.hparams.learning_rate 172 | # update params 173 | 
optimizer.step(closure=optimizer_closure) 174 | optimizer.zero_grad() 175 | 176 | @staticmethod 177 | def add_model_specific_args(parent_parser): 178 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 179 | (args, _) = parser.parse_known_args() 180 | 181 | parser.add_argument("--num_workers", default=40, type=int) 182 | parser.add_argument("--embedding_dim", default=256, type=int) 183 | parser.add_argument("--num_classes", type=int, default=1211) 184 | parser.add_argument("--num_blocks", type=int, default=6) 185 | 186 | parser.add_argument("--input_layer", type=str, default="conv2d") 187 | parser.add_argument("--pos_enc_layer_type", type=str, default="abs_pos") 188 | 189 | parser.add_argument("--second", type=int, default=3) 190 | parser.add_argument('--step_size', type=int, default=1) 191 | parser.add_argument('--gamma', type=float, default=0.9) 192 | parser.add_argument("--batch_size", type=int, default=80) 193 | parser.add_argument("--learning_rate", type=float, default=0.001) 194 | parser.add_argument("--warmup_step", type=float, default=2000) 195 | parser.add_argument("--weight_decay", type=float, default=0.000001) 196 | 197 | parser.add_argument("--save_dir", type=str, default=None) 198 | parser.add_argument("--checkpoint_path", type=str, default=None) 199 | parser.add_argument("--loss_name", type=str, default="amsoftmax") 200 | parser.add_argument("--encoder_name", type=str, default="resnet34") 201 | 202 | parser.add_argument("--train_csv_path", type=str, default="data/train.csv") 203 | parser.add_argument("--trial_path", type=str, default="data/vox1_test.txt") 204 | parser.add_argument("--score_save_path", type=str, default=None) 205 | 206 | parser.add_argument('--eval', action='store_true') 207 | parser.add_argument('--aug', action='store_true') 208 | return parser 209 | 210 | 211 | def cli_main(): 212 | parser = ArgumentParser() 213 | # trainer args 214 | parser = Trainer.add_argparse_args(parser) 215 | 216 | # model args 217 | parser = Task.add_model_specific_args(parser) 218 | args = parser.parse_args() 219 | 220 | model = Task(**args.__dict__) 221 | 222 | if args.checkpoint_path is not None: 223 | state_dict = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"] 224 | model.load_state_dict(state_dict, strict=True) 225 | print("load weight from {}".format(args.checkpoint_path)) 226 | 227 | assert args.save_dir is not None 228 | checkpoint_callback = ModelCheckpoint(monitor='cosine_eer', save_top_k=100, 229 | filename="{epoch}_{cosine_eer:.2f}", dirpath=args.save_dir) 230 | lr_monitor = LearningRateMonitor(logging_interval='step') 231 | 232 | # init default datamodule 233 | print("data augmentation {}".format(args.aug)) 234 | dm = SPK_datamodule(train_csv_path=args.train_csv_path, trial_path=args.trial_path, second=args.second, 235 | aug=args.aug, batch_size=args.batch_size, num_workers=args.num_workers, pairs=False) 236 | AVAIL_GPUS = torch.cuda.device_count() 237 | trainer = Trainer( 238 | max_epochs=args.max_epochs, 239 | plugins=DDPPlugin(find_unused_parameters=False), 240 | gpus=AVAIL_GPUS, 241 | num_sanity_val_steps=-1, 242 | sync_batchnorm=True, 243 | callbacks=[checkpoint_callback, lr_monitor], 244 | default_root_dir=args.save_dir, 245 | reload_dataloaders_every_n_epochs=1, 246 | accumulate_grad_batches=1, 247 | log_every_n_steps=25, 248 | ) 249 | if args.eval: 250 | trainer.test(model, datamodule=dm) 251 | else: 252 | trainer.fit(model, datamodule=dm) 253 | 254 | 255 | if __name__ == "__main__": 256 | cli_main() 257 | 258 | 
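The checkpoint averaging described in the README's "Others" section has no corresponding script in the tree above. Below is a minimal sketch of one way to average the Lightning checkpoints written by `main.py`; the script name `average_ckpt.py` and the command-line layout are illustrative assumptions. It keeps the `state_dict` layout so the averaged file can be loaded back through `--checkpoint_path`.

```python
# average_ckpt.py -- illustrative sketch: average the "state_dict" weights of
# several PyTorch Lightning checkpoints (e.g. the last 3~10 saved by main.py).
import sys

import torch


def average_checkpoints(ckpt_paths, out_path):
    assert len(ckpt_paths) > 0, "give at least one checkpoint to average"
    sums, dtypes = None, None
    for path in ckpt_paths:
        # main.py itself reads checkpoints as torch.load(...)["state_dict"]
        state = torch.load(path, map_location="cpu")["state_dict"]
        if sums is None:
            dtypes = {k: v.dtype for k, v in state.items()}
            sums = {k: v.clone().double() for k, v in state.items()}
        else:
            for k in sums:
                sums[k] += state[k].double()
    # cast back to the original dtypes so the averaged checkpoint loads cleanly
    avg = {k: (v / len(ckpt_paths)).to(dtypes[k]) for k, v in sums.items()}
    torch.save({"state_dict": avg}, out_path)


if __name__ == "__main__":
    # usage: python3 average_ckpt.py averaged.ckpt ckpt1.ckpt ckpt2.ckpt ...
    average_checkpoints(sys.argv[2:], sys.argv[1])
```

The resulting `averaged.ckpt` can then be evaluated with `python3 main.py --eval --checkpoint_path averaged.ckpt ...` as in the evaluation sketch above.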
-------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/module/__init__.py -------------------------------------------------------------------------------- /module/_pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Temporal_Average_Pooling(nn.Module): 6 | def __init__(self, **kwargs): 7 | """TAP 8 | Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification 9 | Link: https://arxiv.org/pdf/1903.12058.pdf 10 | """ 11 | super(Temporal_Average_Pooling, self).__init__() 12 | 13 | def forward(self, x): 14 | """Computes Temporal Average Pooling Module 15 | Args: 16 | x (torch.Tensor): Input tensor (#batch, channels, frames). 17 | Returns: 18 | torch.Tensor: Output tensor (#batch, channels) 19 | """ 20 | x = torch.mean(x, axis=2) 21 | return x 22 | 23 | 24 | class Temporal_Statistics_Pooling(nn.Module): 25 | def __init__(self, **kwargs): 26 | """TSP 27 | Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition 28 | Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf 29 | """ 30 | super(Temporal_Statistics_Pooling, self).__init__() 31 | 32 | def forward(self, x): 33 | """Computes Temporal Statistics Pooling Module 34 | Args: 35 | x (torch.Tensor): Input tensor (#batch, channels, frames). 36 | Returns: 37 | torch.Tensor: Output tensor (#batch, channels*2) 38 | """ 39 | mean = torch.mean(x, axis=2) 40 | var = torch.var(x, axis=2) 41 | x = torch.cat((mean, var), axis=1) 42 | return x 43 | 44 | 45 | class Self_Attentive_Pooling(nn.Module): 46 | def __init__(self, dim): 47 | """SAP 48 | Paper: Self-Attentive Speaker Embeddings for Text-Independent Speaker Verification 49 | Link: https://danielpovey.com/files/2018_interspeech_xvector_attention.pdf 50 | Args: 51 | dim (pair): the size of attention weights 52 | """ 53 | super(Self_Attentive_Pooling, self).__init__() 54 | self.sap_linear = nn.Linear(dim, dim) 55 | self.attention = nn.Parameter(torch.FloatTensor(dim, 1)) 56 | 57 | def forward(self, x): 58 | """Computes Self-Attentive Pooling Module 59 | Args: 60 | x (torch.Tensor): Input tensor (#batch, dim, frames). 61 | Returns: 62 | torch.Tensor: Output tensor (#batch, dim) 63 | """ 64 | x = x.permute(0, 2, 1) 65 | h = torch.tanh(self.sap_linear(x)) 66 | w = torch.matmul(h, self.attention).squeeze(dim=2) 67 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 68 | x = torch.sum(x * w, dim=1) 69 | return x 70 | 71 | 72 | class Attentive_Statistics_Pooling(nn.Module): 73 | def __init__(self, dim): 74 | """ASP 75 | Paper: Attentive Statistics Pooling for Deep Speaker Embedding 76 | Link: https://arxiv.org/pdf/1803.10963.pdf 77 | Args: 78 | dim (pair): the size of attention weights 79 | """ 80 | super(Attentive_Statistics_Pooling, self).__init__() 81 | self.sap_linear = nn.Linear(dim, dim) 82 | self.attention = nn.Parameter(torch.FloatTensor(dim, 1)) 83 | 84 | def forward(self, x): 85 | """Computes Attentive Statistics Pooling Module 86 | Args: 87 | x (torch.Tensor): Input tensor (#batch, dim, frames). 
88 | Returns: 89 | torch.Tensor: Output tensor (#batch, dim*2) 90 | """ 91 | x = x.permute(0, 2, 1) 92 | h = torch.tanh(self.sap_linear(x)) 93 | w = torch.matmul(h, self.attention).squeeze(dim=2) 94 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 95 | mu = torch.sum(x * w, dim=1) 96 | rh = torch.sqrt( ( torch.sum((x**2) * w, dim=1) - mu**2 ).clamp(min=1e-5) ) 97 | x = torch.cat((mu, rh), 1) 98 | return x 99 | 100 | 101 | if __name__ == "__main__": 102 | data = torch.randn(10, 128, 100) 103 | pooling = Self_Attentive_Pooling(128) 104 | out = pooling(data) 105 | print(data.shape) 106 | print(out.shape) 107 | -------------------------------------------------------------------------------- /module/augment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import pandas as pd 5 | 6 | from scipy.io import wavfile 7 | from scipy import signal 8 | import soundfile 9 | 10 | def compute_dB(waveform): 11 | """ 12 | Args: 13 | x (numpy.array): Input waveform (#length). 14 | Returns: 15 | numpy.array: Output array (#length). 16 | """ 17 | val = max(0.0, np.mean(np.power(waveform, 2))) 18 | dB = 10*np.log10(val+1e-4) 19 | return dB 20 | 21 | class WavAugment(object): 22 | def __init__(self, noise_csv_path="data/noise.csv", rir_csv_path="data/rir.csv"): 23 | self.noise_paths = pd.read_csv(noise_csv_path)["utt_paths"].values 24 | self.noise_names = pd.read_csv(noise_csv_path)["speaker_name"].values 25 | self.rir_paths = pd.read_csv(rir_csv_path)["utt_paths"].values 26 | 27 | def __call__(self, waveform): 28 | idx = np.random.randint(0, 10) 29 | if idx == 0: 30 | waveform = self.add_gaussian_noise(waveform) 31 | waveform = self.add_real_noise(waveform) 32 | 33 | if idx == 1 or idx == 2 or idx == 3: 34 | waveform = self.add_real_noise(waveform) 35 | 36 | if idx == 4 or idx == 5 or idx == 6: 37 | waveform = self.reverberate(waveform) 38 | 39 | if idx == 7: 40 | waveform = self.change_volum(waveform) 41 | waveform = self.reverberate(waveform) 42 | 43 | if idx == 6: 44 | waveform = self.change_volum(waveform) 45 | waveform = self.add_real_noise(waveform) 46 | 47 | if idx == 8: 48 | waveform = self.add_gaussian_noise(waveform) 49 | waveform = self.reverberate(waveform) 50 | 51 | return waveform 52 | 53 | def add_gaussian_noise(self, waveform): 54 | """ 55 | Args: 56 | x (numpy.array): Input waveform array (#length). 57 | Returns: 58 | numpy.array: Output waveform array (#length). 59 | """ 60 | snr = np.random.uniform(low=10, high=25) 61 | clean_dB = compute_dB(waveform) 62 | noise = np.random.randn(len(waveform)) 63 | noise_dB = compute_dB(noise) 64 | noise = np.sqrt(10 ** ((clean_dB - noise_dB - snr) / 10)) * noise 65 | waveform = (waveform + noise) 66 | return waveform 67 | 68 | def change_volum(self, waveform): 69 | """ 70 | Args: 71 | x (numpy.array): Input waveform array (#length). 72 | Returns: 73 | numpy.array: Output waveform array (#length). 74 | """ 75 | volum = np.random.uniform(low=0.8, high=1.0005) 76 | waveform = waveform * volum 77 | return waveform 78 | 79 | def add_real_noise(self, waveform): 80 | """ 81 | Args: 82 | x (numpy.array): Input length (#length). 83 | Returns: 84 | numpy.array: Output waveform array (#length). 
85 | """ 86 | clean_dB = compute_dB(waveform) 87 | 88 | idx = np.random.randint(0, len(self.noise_paths)) 89 | sample_rate, noise = wavfile.read(self.noise_paths[idx]) 90 | noise = noise.astype(np.float64) 91 | 92 | snr = np.random.uniform(15, 25) 93 | 94 | noise_length = len(noise) 95 | audio_length = len(waveform) 96 | 97 | if audio_length >= noise_length: 98 | shortage = audio_length - noise_length 99 | noise = np.pad(noise, (0, shortage), 'wrap') 100 | else: 101 | start = np.random.randint(0, (noise_length-audio_length)) 102 | noise = noise[start:start+audio_length] 103 | 104 | noise_dB = compute_dB(noise) 105 | noise = np.sqrt(10 ** ((clean_dB - noise_dB - snr) / 10)) * noise 106 | waveform = (waveform + noise) 107 | return waveform 108 | 109 | def reverberate(self, waveform): 110 | """ 111 | Args: 112 | x (numpy.array): Input length (#length). 113 | Returns: 114 | numpy.array: Output waveform array (#length). 115 | """ 116 | audio_length = len(waveform) 117 | idx = np.random.randint(0, len(self.rir_paths)) 118 | 119 | path = self.rir_paths[idx] 120 | rir, sample_rate = soundfile.read(path) 121 | rir = rir/np.sqrt(np.sum(rir**2)) 122 | 123 | waveform = signal.convolve(waveform, rir, mode='full') 124 | return waveform[:audio_length] 125 | 126 | 127 | if __name__ == "__main__": 128 | aug = WavAugment() 129 | sample_rate, waveform = wavfile.read("input.wav") 130 | waveform = waveform.astype(np.float64) 131 | 132 | gaussian_noise_wave = aug.add_gaussian_noise(waveform) 133 | print(gaussian_noise_wave.dtype) 134 | wavfile.write("gaussian_noise_wave.wav", 16000, gaussian_noise_wave.astype(np.int16)) 135 | 136 | real_noise_wave = aug.add_real_noise(waveform) 137 | print(real_noise_wave.dtype) 138 | wavfile.write("real_noise_wave.wav", 16000, real_noise_wave.astype(np.int16)) 139 | 140 | change_volum_wave = aug.change_volum(waveform) 141 | print(change_volum_wave.dtype) 142 | wavfile.write("change_volum_wave.wav", 16000, change_volum_wave.astype(np.int16)) 143 | 144 | reverberate_wave = aug.reverberate(waveform) 145 | print(reverberate_wave.dtype) 146 | wavfile.write("reverberate_wave.wav", 16000, reverberate_wave.astype(np.int16)) 147 | 148 | reverb_noise_wave = aug.reverberate(waveform) 149 | reverb_noise_wave = aug.add_real_noise(waveform) 150 | print(reverb_noise_wave.dtype) 151 | wavfile.write("reverb_noise_wave.wav", 16000, reverb_noise_wave.astype(np.int16)) 152 | 153 | noise_reverb_wave = aug.add_real_noise(waveform) 154 | noise_reverb_wave = aug.reverberate(waveform) 155 | print(noise_reverb_wave.dtype) 156 | wavfile.write("noise_reverb_wave.wav", 16000, reverb_noise_wave.astype(np.int16)) 157 | 158 | a = torch.FloatTensor(noise_reverb_wave) 159 | print(a.dtype) 160 | -------------------------------------------------------------------------------- /module/conformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wenet.transformer.encoder import ConformerEncoder 3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling 4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d 5 | 6 | class Conformer(torch.nn.Module): 7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2", 8 | pos_enc_layer_type="rel_pos"): 9 | super(Conformer, self).__init__() 10 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks, 11 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 12 | self.pooling = 
AttentiveStatisticsPooling(output_size) 13 | self.bn = BatchNorm1d(input_size=output_size*2) 14 | self.fc = torch.nn.Linear(output_size*2, embedding_dim) 15 | 16 | def forward(self, feat): 17 | feat = feat.squeeze(1).permute(0, 2, 1) 18 | lens = torch.ones(feat.shape[0]).to(feat.device) 19 | lens = torch.round(lens*feat.shape[1]).int() 20 | x, masks = self.conformer(feat, lens) 21 | x = x.permute(0, 2, 1) 22 | x = self.pooling(x) 23 | x = self.bn(x) 24 | x = x.permute(0, 2, 1) 25 | x = self.fc(x) 26 | x = x.squeeze(1) 27 | return x 28 | 29 | def conformer(n_mels=80, num_blocks=6, output_size=256, 30 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"): 31 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size, 32 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 33 | return model 34 | 35 | 36 | 37 | 38 | if __name__ == "__main__": 39 | for i in range(6, 7): 40 | print("num_blocks is {}".format(i)) 41 | model = conformer(num_blocks=i) 42 | 43 | import time 44 | model = model.eval() 45 | time1 = time.time() 46 | with torch.no_grad(): 47 | for i in range(100): 48 | data = torch.randn(1, 1, 80, 500) 49 | embedding = model(data) 50 | time2 = time.time() 51 | val = (time2 - time1)/100 52 | rtf = val / 5 53 | 54 | total = sum([param.nelement() for param in model.parameters()]) 55 | print("total param: {:.2f}M".format(total/1e6)) 56 | print("RTF {:.4f}".format(rtf)) 57 | 58 | -------------------------------------------------------------------------------- /module/conformer_cat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wenet.transformer.encoder_cat import ConformerEncoder 3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling 4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d 5 | 6 | class Conformer(torch.nn.Module): 7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2", 8 | pos_enc_layer_type="rel_pos"): 9 | 10 | super(Conformer, self).__init__() 11 | print("input_layer: {}".format(input_layer)) 12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type)) 13 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks, 14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 15 | self.pooling = AttentiveStatisticsPooling(output_size*num_blocks) 16 | self.bn = BatchNorm1d(input_size=output_size*num_blocks*2) 17 | self.fc = torch.nn.Linear(output_size*num_blocks*2, embedding_dim) 18 | 19 | def forward(self, feat): 20 | feat = feat.squeeze(1).permute(0, 2, 1) 21 | lens = torch.ones(feat.shape[0]).to(feat.device) 22 | lens = torch.round(lens*feat.shape[1]).int() 23 | x, masks = self.conformer(feat, lens) 24 | x = x.permute(0, 2, 1) 25 | x = self.pooling(x) 26 | x = self.bn(x) 27 | x = x.permute(0, 2, 1) 28 | x = self.fc(x) 29 | x = x.squeeze(1) 30 | return x 31 | 32 | def conformer_cat(n_mels=80, num_blocks=6, output_size=256, 33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"): 34 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size, 35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 36 | return model 37 | 38 | 39 | -------------------------------------------------------------------------------- /module/conformer_weight.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from wenet.transformer.encoder_weight import ConformerEncoder 3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling 4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d 5 | 6 | class Conformer(torch.nn.Module): 7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2", 8 | pos_enc_layer_type="rel_pos"): 9 | 10 | super(Conformer, self).__init__() 11 | print("input_layer: {}".format(input_layer)) 12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type)) 13 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks, 14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 15 | self.pooling = AttentiveStatisticsPooling(output_size) 16 | self.bn = BatchNorm1d(input_size=output_size*2) 17 | self.fc = torch.nn.Linear(output_size*2, embedding_dim) 18 | 19 | def forward(self, feat): 20 | feat = feat.squeeze(1).permute(0, 2, 1) 21 | lens = torch.ones(feat.shape[0]).to(feat.device) 22 | lens = torch.round(lens*feat.shape[1]).int() 23 | x, masks = self.conformer(feat, lens) 24 | x = x.permute(0, 2, 1) 25 | x = self.pooling(x) 26 | x = self.bn(x) 27 | x = x.permute(0, 2, 1) 28 | x = self.fc(x) 29 | x = x.squeeze(1) 30 | return x 31 | 32 | def conformer_weight(n_mels=80, num_blocks=6, output_size=256, 33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"): 34 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size, 35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 36 | return model 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /module/dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from scipy import signal 9 | from scipy.io import wavfile 10 | from sklearn.utils import shuffle 11 | from torch.utils.data import DataLoader, Dataset 12 | from .augment import WavAugment 13 | 14 | 15 | def load_audio(filename, second=2): 16 | sample_rate, waveform = wavfile.read(filename) 17 | audio_length = waveform.shape[0] 18 | 19 | if second <= 0: 20 | return waveform.astype(np.float64).copy() 21 | 22 | length = np.int64(sample_rate * second) 23 | 24 | if audio_length <= length: 25 | shortage = length - audio_length 26 | waveform = np.pad(waveform, (0, shortage), 'wrap') 27 | waveform = waveform.astype(np.float64) 28 | else: 29 | start = np.int64(random.random()*(audio_length-length)) 30 | waveform = waveform[start:start+length].astype(np.float64) 31 | return waveform.copy() 32 | 33 | class Train_Dataset(Dataset): 34 | def __init__(self, train_csv_path, second=3, pairs=True, aug=False, **kwargs): 35 | self.second = second 36 | self.pairs = pairs 37 | 38 | df = pd.read_csv(train_csv_path) 39 | self.labels = df["utt_spk_int_labels"].values 40 | self.paths = df["utt_paths"].values 41 | self.labels, self.paths = shuffle(self.labels, self.paths) 42 | self.aug = aug 43 | if aug: 44 | self.wav_aug = WavAugment() 45 | 46 | print("Train Dataset load {} speakers".format(len(set(self.labels)))) 47 | print("Train Dataset load {} utterance".format(len(self.labels))) 48 | 49 | def __getitem__(self, index): 50 | waveform_1 = load_audio(self.paths[index], self.second) 51 | if 
self.aug == True: 52 | waveform_1 = self.wav_aug(waveform_1) 53 | if self.pairs == False: 54 | return torch.FloatTensor(waveform_1), self.labels[index] 55 | 56 | else: 57 | waveform_2 = load_audio(self.paths[index], self.second) 58 | if self.aug == True: 59 | waveform_2 = self.wav_aug(waveform_2) 60 | return torch.FloatTensor(waveform_1), torch.FloatTensor(waveform_2), self.labels[index] 61 | 62 | def __len__(self): 63 | return len(self.paths) 64 | 65 | 66 | class Semi_Dataset(Dataset): 67 | def __init__(self, label_csv_path, unlabel_csv_path, second=2, pairs=True, aug=False, **kwargs): 68 | self.second = second 69 | self.pairs = pairs 70 | 71 | df = pd.read_csv(label_csv_path) 72 | self.labels = df["utt_spk_int_labels"].values 73 | self.paths = df["utt_paths"].values 74 | 75 | self.aug = aug 76 | if aug: 77 | self.wav_aug = WavAugment() 78 | 79 | df = pd.read_csv(unlabel_csv_path) 80 | self.u_paths = df["utt_paths"].values 81 | self.u_paths_length = len(self.u_paths) 82 | 83 | if label_csv_path != unlabel_csv_path: 84 | self.labels, self.paths = shuffle(self.labels, self.paths) 85 | self.u_paths = shuffle(self.u_paths) 86 | 87 | # self.labels = self.labels[:self.u_paths_length] 88 | # self.paths = self.paths[:self.u_paths_length] 89 | print("Semi Dataset load {} speakers".format(len(set(self.labels)))) 90 | print("Semi Dataset load {} utterance".format(len(self.labels))) 91 | 92 | def __getitem__(self, index): 93 | waveform_l = load_audio(self.paths[index], self.second) 94 | 95 | idx = np.random.randint(0, self.u_paths_length) 96 | waveform_u_1 = load_audio(self.u_paths[idx], self.second) 97 | if self.aug == True: 98 | waveform_u_1 = self.wav_aug(waveform_u_1) 99 | 100 | if self.pairs == False: 101 | return torch.FloatTensor(waveform_l), self.labels[index], torch.FloatTensor(waveform_u_1) 102 | 103 | else: 104 | waveform_u_2 = load_audio(self.u_paths[idx], self.second) 105 | if self.aug == True: 106 | waveform_u_2 = self.wav_aug(waveform_u_2) 107 | return torch.FloatTensor(waveform_l), self.labels[index], torch.FloatTensor(waveform_u_1), torch.FloatTensor(waveform_u_2) 108 | 109 | def __len__(self): 110 | return len(self.paths) 111 | 112 | 113 | class Evaluation_Dataset(Dataset): 114 | def __init__(self, paths, second=-1, **kwargs): 115 | self.paths = paths 116 | self.second = second 117 | print("load {} utterance".format(len(self.paths))) 118 | 119 | def __getitem__(self, index): 120 | waveform = load_audio(self.paths[index], self.second) 121 | return torch.FloatTensor(waveform), self.paths[index] 122 | 123 | def __len__(self): 124 | return len(self.paths) 125 | 126 | if __name__ == "__main__": 127 | dataset = Train_Dataset(train_csv_path="data/train.csv", second=3) 128 | loader = DataLoader( 129 | dataset, 130 | batch_size=10, 131 | shuffle=False 132 | ) 133 | for x, label in loader: 134 | pass 135 | 136 | -------------------------------------------------------------------------------- /module/ecapa_tdnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN 3 | 4 | class Model(torch.nn.Module): 5 | def __init__(self, n_mels=80, embedding_dim=192, channel=512): 6 | super(Model, self).__init__() 7 | channels = [channel for _ in range(4)] 8 | channels.append(channel*3) 9 | self.model = ECAPA_TDNN(input_size=n_mels, lin_neurons=embedding_dim, channels=channels) 10 | 11 | def forward(self, x): 12 | x = x.squeeze(1) 13 | x = x.permute(0, 2, 1) 14 | x = self.model(x) 15 | x = 
x.squeeze(1) 16 | return x 17 | 18 | def ecapa_tdnn(n_mels=80, embedding_dim=192, channel=512): 19 | model = Model(n_mels=n_mels, embedding_dim=embedding_dim, channel=channel) 20 | return model 21 | 22 | def ecapa_tdnn_large(n_mels=80, embedding_dim=192, channel=1024): 23 | model = Model(n_mels=n_mels, embedding_dim=embedding_dim, channel=channel) 24 | return model 25 | 26 | 27 | -------------------------------------------------------------------------------- /module/feature.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class PreEmphasis(torch.nn.Module): 7 | def __init__(self, coef: float = 0.97): 8 | super(PreEmphasis, self).__init__() 9 | self.coef = coef 10 | # make kernel 11 | # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped. 12 | self.register_buffer( 13 | 'flipped_filter', torch.FloatTensor( 14 | [-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 15 | ) 16 | 17 | def forward(self, inputs: torch.tensor) -> torch.tensor: 18 | assert len( 19 | inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!' 20 | # reflect padding to match lengths of in/out 21 | inputs = inputs.unsqueeze(1) 22 | inputs = F.pad(inputs, (1, 0), 'reflect') 23 | return F.conv1d(inputs, self.flipped_filter).squeeze(1) 24 | 25 | 26 | class Mel_Spectrogram(nn.Module): 27 | def __init__(self, sample_rate=16000, n_fft=512, win_length=400, hop=160, n_mels=80, coef=0.97, requires_grad=False): 28 | super(Mel_Spectrogram, self).__init__() 29 | self.n_fft = n_fft 30 | self.n_mels = n_mels 31 | self.win_length = win_length 32 | self.hop = hop 33 | 34 | self.pre_emphasis = PreEmphasis(coef) 35 | mel_basis = librosa.filters.mel( 36 | sr=sample_rate, n_fft=n_fft, n_mels=n_mels) 37 | self.mel_basis = nn.Parameter( 38 | torch.FloatTensor(mel_basis), requires_grad=requires_grad) 39 | self.instance_norm = nn.InstanceNorm1d(num_features=n_mels) 40 | window = torch.hamming_window(self.win_length) 41 | self.window = nn.Parameter( 42 | torch.FloatTensor(window), requires_grad=False) 43 | 44 | def forward(self, x): 45 | x = self.pre_emphasis(x) 46 | x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, 47 | window=self.window, win_length=self.win_length, return_complex=True) 48 | x = torch.abs(x) 49 | x += 1e-9 50 | x = torch.log(x) 51 | x = torch.matmul(self.mel_basis, x) 52 | x = self.instance_norm(x) 53 | x = x.unsqueeze(1) 54 | return x 55 | 56 | 57 | if __name__ == "__main__": 58 | from scipy.io import wavfile 59 | import matplotlib.pyplot as plt 60 | from torchvision import transforms as transforms 61 | 62 | sample_rate, sig = wavfile.read("test.wav") 63 | sig = torch.FloatTensor(sig.copy()) 64 | sig = sig.repeat(10, 1) 65 | 66 | spec = Mel_Spectrogram() 67 | out = spec(sig) 68 | out = out 69 | print(out.shape) 70 | 71 | plt.subplot(211) 72 | plt.imshow(out[0][0]) 73 | 74 | trans = transforms.RandomResizedCrop((80, 200)) 75 | out = trans(out) 76 | print(out.shape) 77 | 78 | plt.subplot(212) 79 | plt.imshow(out[0][0]) 80 | 81 | plt.savefig("test.png") 82 | -------------------------------------------------------------------------------- /module/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import DataLoader 8 | 9 
| from pl_bolts.datasets import UnlabeledImagenet 10 | from pl_bolts.utils.warnings import warn_missing_pkg 11 | 12 | from .dataset import Evaluation_Dataset, Train_Dataset, Semi_Dataset 13 | 14 | 15 | class SPK_datamodule(LightningDataModule): 16 | def __init__( 17 | self, 18 | train_csv_path, 19 | trial_path, 20 | unlabel_csv_path = None, 21 | second: int = 2, 22 | num_workers: int = 16, 23 | batch_size: int = 32, 24 | shuffle: bool = True, 25 | pin_memory: bool = True, 26 | drop_last: bool = True, 27 | pairs: bool = True, 28 | aug: bool = False, 29 | semi: bool = False, 30 | *args: Any, 31 | **kwargs: Any, 32 | ) -> None: 33 | super().__init__(*args, **kwargs) 34 | 35 | self.train_csv_path = train_csv_path 36 | self.unlabel_csv_path = unlabel_csv_path 37 | self.second = second 38 | self.num_workers = num_workers 39 | self.batch_size = batch_size 40 | self.trial_path = trial_path 41 | self.pairs = pairs 42 | self.aug = aug 43 | print("second is {:.2f}".format(second)) 44 | 45 | def train_dataloader(self) -> DataLoader: 46 | if self.unlabel_csv_path is None: 47 | train_dataset = Train_Dataset(self.train_csv_path, self.second, self.pairs, self.aug) 48 | else: 49 | train_dataset = Semi_Dataset(self.train_csv_path, self.unlabel_csv_path, self.second, self.pairs, self.aug) 50 | loader = torch.utils.data.DataLoader( 51 | train_dataset, 52 | shuffle=True, 53 | num_workers=self.num_workers, 54 | batch_size=self.batch_size, 55 | pin_memory=True, 56 | drop_last=False, 57 | ) 58 | return loader 59 | 60 | def val_dataloader(self) -> DataLoader: 61 | trials = np.loadtxt(self.trial_path, str) 62 | self.trials = trials 63 | eval_path = np.unique(np.concatenate((trials.T[1], trials.T[2]))) 64 | print("number of enroll: {}".format(len(set(trials.T[1])))) 65 | print("number of test: {}".format(len(set(trials.T[2])))) 66 | print("number of evaluation: {}".format(len(eval_path))) 67 | eval_dataset = Evaluation_Dataset(eval_path, second=-1) 68 | loader = torch.utils.data.DataLoader(eval_dataset, 69 | num_workers=10, 70 | shuffle=False, 71 | batch_size=1) 72 | return loader 73 | 74 | def test_dataloader(self) -> DataLoader: 75 | return self.val_dataloader() 76 | 77 | 78 | -------------------------------------------------------------------------------- /module/resnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Type, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | try: 8 | from ._pooling import * 9 | except: 10 | from _pooling import * 11 | 12 | 13 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=dilation, groups=groups, bias=False, dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion: int = 1 26 | 27 | def __init__( 28 | self, 29 | inplanes: int, 30 | planes: int, 31 | stride: int = 1, 32 | downsample: Optional[nn.Module] = None, 33 | groups: int = 1, 34 | base_width: int = 64, 35 | dilation: int = 1, 36 | norm_layer: Optional[Callable[..., nn.Module]] = None 37 | ) -> None: 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = 
nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError( 43 | 'BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError( 46 | "Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 77 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 79 | # This variant is also known as ResNet V1.5 and improves accuracy according to 80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 81 | 82 | expansion: int = 4 83 | 84 | def __init__( 85 | self, 86 | inplanes: int, 87 | planes: int, 88 | stride: int = 1, 89 | downsample: Optional[nn.Module] = None, 90 | groups: int = 1, 91 | base_width: int = 64, 92 | dilation: int = 1, 93 | norm_layer: Optional[Callable[..., nn.Module]] = None 94 | ) -> None: 95 | super(Bottleneck, self).__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.downsample = downsample 108 | self.stride = stride 109 | 110 | def forward(self, x: Tensor) -> Tensor: 111 | identity = x 112 | 113 | out = self.conv1(x) 114 | out = self.bn1(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv2(out) 118 | out = self.bn2(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv3(out) 122 | out = self.bn3(out) 123 | 124 | if self.downsample is not None: 125 | identity = self.downsample(x) 126 | 127 | out += identity 128 | out = self.relu(out) 129 | 130 | return out 131 | 132 | 133 | class ResNet(nn.Module): 134 | 135 | def __init__( 136 | self, 137 | block: Type[Union[BasicBlock, Bottleneck]], 138 | layers: List[int], 139 | num_channels: List[int] = [1, 32, 64, 128, 256], 140 | embedding_dim: int = 256, 141 | n_mels: int = 80, 142 | pooling_type="TSP", 143 | zero_init_residual: bool = False, 144 | groups: int = 1, 145 | width_per_group: int = 64, 146 | replace_stride_with_dilation: Optional[List[bool]] = None, 147 | norm_layer: Optional[Callable[..., nn.Module]] = None, 148 | **kwargs 149 | ) -> None: 150 | super(ResNet, self).__init__() 151 | if norm_layer is None: 152 | norm_layer = nn.BatchNorm2d 153 | self._norm_layer = 
norm_layer 154 | 155 | self.inplanes = 64 156 | self.dilation = 1 157 | if replace_stride_with_dilation is None: 158 | # each element in the tuple indicates if we should replace 159 | # the 2x2 stride with a dilated convolution instead 160 | replace_stride_with_dilation = [False, False, False] 161 | if len(replace_stride_with_dilation) != 3: 162 | raise ValueError("replace_stride_with_dilation should be None " 163 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 164 | self.groups = groups 165 | self.base_width = width_per_group 166 | self.conv1 = nn.Conv2d(num_channels[0], self.inplanes, kernel_size=3, stride=1, padding=1, 167 | bias=False) 168 | self.bn1 = norm_layer(self.inplanes) 169 | self.relu = nn.ReLU(inplace=True) 170 | self.layer1 = self._make_layer(block, num_channels[1], layers[0]) 171 | self.layer2 = self._make_layer(block, num_channels[2], layers[1], stride=2, 172 | dilate=replace_stride_with_dilation[0]) 173 | self.layer3 = self._make_layer(block, num_channels[3], layers[2], stride=2, 174 | dilate=replace_stride_with_dilation[1]) 175 | self.layer4 = self._make_layer(block, num_channels[4], layers[3], stride=2, 176 | dilate=replace_stride_with_dilation[2]) 177 | 178 | out_dim = num_channels[4] * block.expansion * (n_mels//8) 179 | if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP": 180 | self.pooling = Temporal_Average_Pooling() 181 | self.fc = nn.Linear(out_dim, embedding_dim) 182 | 183 | elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP": 184 | self.pooling = Temporal_Statistics_Pooling() 185 | self.fc = nn.Linear(out_dim*2, embedding_dim) 186 | 187 | elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP": 188 | self.pooling = Self_Attentive_Pooling(out_dim) 189 | self.fc = nn.Linear(out_dim, embedding_dim) 190 | 191 | elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP": 192 | self.pooling = Attentive_Statistics_Pooling(out_dim) 193 | self.fc = nn.Linear(out_dim*2, embedding_dim) 194 | 195 | else: 196 | raise ValueError( 197 | '{} pooling type is not defined'.format(pooling_type)) 198 | 199 | print("resnet num_channels: {}".format(num_channels)) 200 | print("n_mels: {}".format(n_mels)) 201 | print("embedding_dim: {}".format(embedding_dim)) 202 | print("pooling_type: {}".format(pooling_type)) 203 | 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.kaiming_normal_( 207 | m.weight, mode='fan_out', nonlinearity='relu') 208 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 209 | nn.init.constant_(m.weight, 1) 210 | nn.init.constant_(m.bias, 0) 211 | 212 | # Zero-initialize the last BN in each residual branch, 213 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
214 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 215 | if zero_init_residual: 216 | for m in self.modules(): 217 | if isinstance(m, Bottleneck): 218 | # type: ignore[arg-type] 219 | nn.init.constant_(m.bn3.weight, 0) 220 | elif isinstance(m, BasicBlock): 221 | # type: ignore[arg-type] 222 | nn.init.constant_(m.bn2.weight, 0) 223 | 224 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 225 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 226 | norm_layer = self._norm_layer 227 | downsample = None 228 | previous_dilation = self.dilation 229 | if dilate: 230 | self.dilation *= stride 231 | stride = 1 232 | if stride != 1 or self.inplanes != planes * block.expansion: 233 | downsample = nn.Sequential( 234 | conv1x1(self.inplanes, planes * block.expansion, stride), 235 | norm_layer(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 240 | self.base_width, previous_dilation, norm_layer)) 241 | self.inplanes = planes * block.expansion 242 | for _ in range(1, blocks): 243 | layers.append(block(self.inplanes, planes, groups=self.groups, 244 | base_width=self.base_width, dilation=self.dilation, 245 | norm_layer=norm_layer)) 246 | 247 | return nn.Sequential(*layers) 248 | 249 | def _forward_impl(self, x: Tensor) -> Tensor: 250 | # See note [TorchScript super()] 251 | x = self.conv1(x) 252 | x = self.bn1(x) 253 | x = self.relu(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.layer4(x) 259 | 260 | x = x.reshape(x.shape[0], -1, x.shape[-1]) 261 | 262 | x = self.pooling(x) 263 | 264 | x = torch.flatten(x, 1) 265 | x = self.fc(x) 266 | 267 | return x 268 | 269 | def forward(self, x: Tensor) -> Tensor: 270 | return self._forward_impl(x) 271 | 272 | 273 | def _resnet( 274 | arch: str, 275 | block: Type[Union[BasicBlock, Bottleneck]], 276 | layers: List[int], 277 | **kwargs: Any 278 | ) -> ResNet: 279 | model = ResNet(block, layers, **kwargs) 280 | return model 281 | 282 | 283 | def resnet18(**kwargs: Any) -> ResNet: 284 | r"""ResNet-18 model from 285 | `"Deep Residual Learning for Image Recognition" `_. 286 | Args: 287 | **kwargs: Any 288 | """ 289 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], num_channels=[1, 64, 128, 256, 512], **kwargs) 290 | 291 | 292 | def resnet34(**kwargs: Any) -> ResNet: 293 | r"""ResNet-34 model from 294 | `"Deep Residual Learning for Image Recognition" `_. 295 | 296 | Args: 297 | **kwargs: Any 298 | """ 299 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], **kwargs) 300 | 301 | 302 | def resnet34_large(**kwargs: Any) -> ResNet: 303 | r"""ResNet-34 model from 304 | `"Deep Residual Learning for Image Recognition" `_. 305 | 306 | Args: 307 | **kwargs: Any 308 | """ 309 | model = _resnet('resnet34', BasicBlock, [3, 4, 6, 3], num_channels=[1, 64, 128, 256, 512], **kwargs) 310 | return model 311 | 312 | def resnet50(**kwargs: Any) -> ResNet: 313 | r"""ResNet-50 model from 314 | `"Deep Residual Learning for Image Recognition" `_. 315 | 316 | Args: 317 | **kwargs: Any 318 | """ 319 | model = _resnet('resnet50', Bottleneck, [3, 4, 6, 3], num_channels=[1, 64, 128, 256, 512], **kwargs) 320 | return model 321 | 322 | 323 | def resnext50_32x4d(**kwargs: Any) -> ResNet: 324 | r"""ResNeXt-50 32x4d model from 325 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
326 | 327 | Args: 328 | **kwargs: Any 329 | """ 330 | kwargs['groups'] = 32 331 | kwargs['width_per_group'] = 4 332 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], **kwargs) 333 | 334 | 335 | 336 | -------------------------------------------------------------------------------- /module/transformer_cat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wenet.transformer.encoder_cat import TransformerEncoder 3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling 4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d 5 | 6 | class Transformer(torch.nn.Module): 7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2", 8 | pos_enc_layer_type="rel_pos"): 9 | 10 | super(Transformer, self).__init__() 11 | print("input_layer: {}".format(input_layer)) 12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type)) 13 | self.conformer = TransformerEncoder(input_size=n_mels, num_blocks=num_blocks, 14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 15 | self.pooling = AttentiveStatisticsPooling(output_size*num_blocks) 16 | self.bn = BatchNorm1d(input_size=output_size*num_blocks*2) 17 | self.fc = torch.nn.Linear(output_size*num_blocks*2, embedding_dim) 18 | 19 | def forward(self, feat): 20 | feat = feat.squeeze(1).permute(0, 2, 1) 21 | lens = torch.ones(feat.shape[0]).to(feat.device) 22 | lens = torch.round(lens*feat.shape[1]).int() 23 | x, masks = self.conformer(feat, lens) 24 | x = x.permute(0, 2, 1) 25 | x = self.pooling(x) 26 | x = self.bn(x) 27 | x = x.permute(0, 2, 1) 28 | x = self.fc(x) 29 | x = x.squeeze(1) 30 | return x 31 | 32 | def transformer_cat(n_mels=80, num_blocks=6, output_size=256, 33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"): 34 | model = Transformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size, 35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type) 36 | return model 37 | 38 | 39 | -------------------------------------------------------------------------------- /module/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def compute_dB(waveform): 5 | """ 6 | Args: 7 | x (torch.tensor): Input waveform (#length). 8 | Returns: 9 | torch.tensor: Output array (#length). 10 | """ 11 | val = max(0.0, torch.mean(torch.pow(waveform, 2))) 12 | dB = 10*torch.log10(val+1e-4) 13 | return dB 14 | 15 | def compute_SNR(waveform, noise): 16 | """ 17 | Args: 18 | x (numpy.array): Input waveform (#length). 19 | Returns: 20 | numpy.array: Output array (#length). 
21 | """ 22 | SNR = 10*np.log10(np.mean(waveform**2)/np.mean(noise**2)+1e-9) 23 | return SNR 24 | 25 | 26 | -------------------------------------------------------------------------------- /score/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine import cosine_score 2 | from .utils import compute_eer, compute_minDCF 3 | -------------------------------------------------------------------------------- /score/cosine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def cosine_score(trials, index_mapping, eval_vectors): 4 | labels = [] 5 | scores = [] 6 | for item in trials: 7 | enroll_vector = eval_vectors[index_mapping[item[1]]] 8 | test_vector = eval_vectors[index_mapping[item[2]]] 9 | score = enroll_vector.dot(test_vector.T) 10 | denom = np.linalg.norm(enroll_vector) * np.linalg.norm(test_vector) 11 | score = score/denom 12 | labels.append(int(item[0])) 13 | scores.append(score) 14 | return labels, scores 15 | 16 | -------------------------------------------------------------------------------- /score/utils.py: -------------------------------------------------------------------------------- 1 | from scipy.interpolate import interp1d 2 | from sklearn.metrics import roc_curve 3 | from scipy.optimize import brentq 4 | 5 | def compute_eer(labels, scores): 6 | """sklearn style compute eer 7 | """ 8 | fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1) 9 | eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) 10 | threshold = interp1d(fpr, thresholds)(eer) 11 | return eer, threshold 12 | 13 | 14 | def compute_minDCF(labels, scores, p_target=0.01, c_miss=1, c_fa=1): 15 | """MinDCF 16 | Computes the minimum of the detection cost function. The comments refer to 17 | equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan. 18 | """ 19 | fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1) 20 | fnr = 1.0 - tpr 21 | 22 | min_c_det = float("inf") 23 | min_c_det_threshold = thresholds[0] 24 | for i in range(0, len(fnr)): 25 | c_det = c_miss * fnr[i] * p_target + c_fa * fpr[i] * (1 - p_target) 26 | if c_det < min_c_det: 27 | min_c_det = c_det 28 | min_c_det_threshold = thresholds[i] 29 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 30 | min_dcf = min_c_det / c_def 31 | return min_dcf, min_c_det_threshold 32 | 33 | -------------------------------------------------------------------------------- /scripts/build_datalist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import argparse 5 | import os 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import tqdm 10 | 11 | 12 | def findAllSeqs(dirName, 13 | extension='.wav', 14 | load_data_list=False, 15 | speaker_level=1): 16 | r""" 17 | Lists all the sequences with the given extension in the dirName directory. 18 | Output: 19 | outSequences, speakers 20 | outSequence 21 | A list of tuples seq_path, speaker where: 22 | - seq_path is the relative path of each sequence relative to the 23 | parent directory 24 | - speaker is the corresponding speaker index 25 | outSpeakers 26 | The speaker labels (in order) 27 | The speaker labels are organized the following way 28 | \dirName 29 | \speaker_label 30 | \.. 31 | ... 32 | seqName.extension 33 | Adjust the value of speaker_level if you want to choose which level of 34 | directory defines the speaker label. 
Ex if speaker_level == 2 then the 35 | dataset should be organized in the following fashion 36 | \dirName 37 | \crappy_label 38 | \speaker_label 39 | \.. 40 | ... 41 | seqName.extension 42 | Set speaker_label == 0 if no speaker label will be retrieved no matter the 43 | organization of the dataset. 44 | """ 45 | if dirName[-1] != os.sep: 46 | dirName += os.sep 47 | prefixSize = len(dirName) 48 | speakersTarget = {} 49 | outSequences = [] 50 | print("finding {}, Waiting...".format(extension)) 51 | for root, dirs, filenames in tqdm.tqdm(os.walk(dirName, followlinks=True)): 52 | filtered_files = [f for f in filenames if f.endswith(extension)] 53 | if len(filtered_files) > 0: 54 | speakerStr = (os.sep).join( 55 | root[prefixSize:].split(os.sep)[:speaker_level]) 56 | if speakerStr not in speakersTarget: 57 | speakersTarget[speakerStr] = len(speakersTarget) 58 | speaker = speakersTarget[speakerStr] 59 | for filename in filtered_files: 60 | full_path = os.path.join(root, filename) 61 | outSequences.append((speaker, full_path)) 62 | outSpeakers = [None for x in speakersTarget] 63 | 64 | for key, index in speakersTarget.items(): 65 | outSpeakers[index] = key 66 | 67 | print("find {} speakers".format(len(outSpeakers))) 68 | print("find {} utterance".format(len(outSequences))) 69 | 70 | return outSequences, outSpeakers 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument( 76 | '--extension', help='file extension name', type=str, default="wav") 77 | parser.add_argument('--dataset_dir', help='dataset dir', 78 | type=str, default="data") 79 | parser.add_argument('--data_list_path', 80 | help='list save path', type=str, default="data_list") 81 | parser.add_argument('--speaker_level', 82 | help='list save path', type=int, default=1) 83 | args = parser.parse_args() 84 | 85 | outSequences, outSpeakers = findAllSeqs(args.dataset_dir, 86 | extension=args.extension, 87 | load_data_list=False, 88 | speaker_level=1) 89 | 90 | outSequences = np.array(outSequences, dtype=str) 91 | utt_spk_int_labels = outSequences.T[0].astype(int) 92 | utt_paths = outSequences.T[1] 93 | utt_spk_str_labels = [] 94 | for i in utt_spk_int_labels: 95 | utt_spk_str_labels.append(outSpeakers[i]) 96 | 97 | csv_dict = {"speaker_name": utt_spk_str_labels, 98 | "utt_paths": utt_paths, 99 | "utt_spk_int_labels": utt_spk_int_labels 100 | } 101 | df = pd.DataFrame(data=csv_dict) 102 | 103 | try: 104 | df.to_csv(args.data_list_path) 105 | print(f'Saved data list file at {args.data_list_path}') 106 | except OSError as err: 107 | print(f'Ran in an error while saving {args.data_list_path}: {err}') 108 | -------------------------------------------------------------------------------- /scripts/format_trials.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import argparse 5 | import os 6 | 7 | import numpy as np 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--voxceleb1_root', help='voxceleb1_root', type=str, 12 | default="datasets/VoxCeleb/voxceleb1/") 13 | parser.add_argument('--src_trials_path', help='src_trials_path', 14 | type=str, default="voxceleb1_test_v2.txt") 15 | parser.add_argument('--dst_trials_path', help='dst_trials_path', 16 | type=str, default="data/trial.lst") 17 | args = parser.parse_args() 18 | 19 | trials = np.loadtxt(args.src_trials_path, dtype=str) 20 | 21 | f = open(args.dst_trials_path, "w") 22 | for item in trials: 23 | 
enroll_path = os.path.join( 24 | args.voxceleb1_root, "wav", item[1]) 25 | test_path = os.path.join(args.voxceleb1_root, "wav", item[2]) 26 | f.write("{} {} {}\n".format(item[0], enroll_path, test_path)) 27 | -------------------------------------------------------------------------------- /scripts/make_balanced_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import random 4 | 5 | def build_balance_data(args): 6 | data = pd.read_csv(args.data_path) 7 | 8 | labels = data["utt_spk_int_labels"].values 9 | name = data["speaker_name"].values 10 | paths = data["utt_paths"].values 11 | durations = data["durations"].values 12 | 13 | # Index every column by speaker label: build dictionaries so entries can be looked up by label 14 | dict_name = {} 15 | dict_paths = {} 16 | dict_durations = {} 17 | for idx, label in enumerate(labels): 18 | if label not in dict_paths: 19 | dict_name[label] = name[idx] 20 | dict_paths[label] = [] 21 | dict_durations[label] = [] 22 | if abs(durations[idx] - 9) < 3: # filter by duration so no single utterance is too long and lengths stay close to the average 23 | dict_paths[label].append(paths[idx]) 24 | dict_durations[label].append(durations[idx]) 25 | 26 | 27 | # Randomly pick speakers (args.num_spk distinct labels) and store them in the list random_num_spk 28 | candi_spk = [] 29 | for label in range(max(labels) + 1): 30 | if args.utt_per_spk <= len(dict_paths[label]): # keep only candidates with enough utterances to sample from 31 | candi_spk.append(label) 32 | random_num_spk = random.sample(candi_spk, args.num_spk) 33 | 34 | 35 | result_name = [] 36 | result_paths = [] 37 | result_durations = [] 38 | result_labels = [] 39 | for label in random_num_spk: #dict_name[label] dict_paths[label] label dict_durations[label] 40 | # For each randomly chosen speaker (label), randomly pick utt_per_spk distinct utterance indices and store them in random_utt_per_spk 41 | candi_utt = [i for i in range(len(dict_paths[label]))] 42 | random_utt_per_spk = random.sample(candi_utt, args.utt_per_spk) 43 | # save the results 44 | result_labels.extend([label] * args.utt_per_spk) 45 | for idx in random_utt_per_spk: 46 | result_name.append(dict_name[label]) 47 | result_paths.append(dict_paths[label][idx]) 48 | result_durations.append(dict_durations[label][idx]) 49 | 50 | table = {} 51 | for idx, label in enumerate(set(result_labels)): 52 | table[label] = idx 53 | 54 | labels = [] 55 | for label in result_labels: 56 | labels.append(table[label]) 57 | 58 | # write to a csv file 59 | dic = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': labels, 'durations': result_durations} 60 | df = pd.DataFrame(dic) 61 | df.to_csv(args.save_path) 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--data_path', type=str, default="data/train.csv") 67 | parser.add_argument('--save_path', type=str, default="balance.csv") 68 | parser.add_argument('--num_spk', type=int, default=1211) 69 | parser.add_argument('--utt_per_spk', type=int, default=122) 70 | args = parser.parse_args() 71 | 72 | build_balance_data(args) 73 | -------------------------------------------------------------------------------- /scripts/make_cohort_set.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--data_list_path', type=str, default="data/train.csv") 8 | parser.add_argument('--cohort_save_path', type=str, default="data/cohort.csv") 9 | parser.add_argument('--num_cohort', type=int, default=3000) 10 | args = parser.parse_args() 11 | 12 | data = 
pd.read_csv(args.data_list_path) 13 | utt_paths = data["utt_paths"].values 14 | np.random.shuffle(utt_paths) 15 | utt_paths = utt_paths[:args.num_cohort] 16 | with open(args.cohort_save_path, "w") as f: 17 | for item in utt_paths: 18 | f.write(item) 19 | f.write("\n") 20 | -------------------------------------------------------------------------------- /scripts/make_tsne_set.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | import random 5 | 6 | def random_dataset(args): 7 | data = pd.read_csv(args.data_list_path) 8 | 9 | labels = data["utt_spk_int_labels"].values 10 | name = data["speaker_name"].values 11 | paths = data["utt_paths"].values 12 | #durations = data["durations"].values 13 | 14 | # Index every column by speaker label: build dictionaries so entries can be looked up by label 15 | dict_name = {} 16 | dict_paths = {} 17 | #dict_durations = {} 18 | for idx, label in enumerate(labels): 19 | if label not in dict_paths: 20 | dict_name[label] = name[idx] 21 | dict_paths[label] = [] 22 | #dict_durations[label] = [] 23 | dict_paths[label].append(paths[idx]) 24 | #dict_durations[label].append(durations[idx]) 25 | 26 | 27 | # Randomly pick speakers (args.num_spk distinct labels) and store them in the list random_num_spk 28 | candi_spk = [] 29 | for label in range(max(labels) + 1): 30 | if args.utt_per_spk <= len(dict_paths[label]): # keep only candidates with enough utterances to sample from 31 | candi_spk.append(label) 32 | 33 | random_num_spk = random.sample(candi_spk, args.num_spk) 34 | 35 | 36 | result_name = [] 37 | result_paths = [] 38 | #result_durations = [] 39 | result_labels = [] 40 | for label in random_num_spk: #dict_name[label] dict_paths[label] label dict_durations[label] 41 | # For each randomly chosen speaker (label), randomly pick utt_per_spk distinct utterance indices and store them in random_utt_per_spk 42 | candi_utt = [i for i in range(len(dict_paths[label]))] 43 | random_utt_per_spk = random.sample(candi_utt, args.utt_per_spk) 44 | # save the results 45 | result_labels.extend([label] * args.utt_per_spk) 46 | for idx in random_utt_per_spk: 47 | result_name.append(dict_name[label]) 48 | result_paths.append(dict_paths[label][idx]) 49 | #result_durations.append(dict_durations[label][idx]) 50 | 51 | # write to a csv file 52 | #dict = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': result_labels, 'durations': result_durations} 53 | 54 | label_set = set(result_labels) 55 | table = {} 56 | for idx, s in enumerate(label_set): 57 | table[s] = idx 58 | 59 | new_labels = [] 60 | for label in result_labels: 61 | new_labels.append(table[label]) 62 | 63 | dic = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': new_labels} 64 | df = pd.DataFrame(dic) 65 | df.to_csv(args.tsne_set_save_path) 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--data_list_path', type=str, default="data.csv") 71 | parser.add_argument('--tsne_set_save_path', type=str, default="tsne.csv") 72 | parser.add_argument('--num_spk', type=int, default=20) 73 | parser.add_argument('--utt_per_spk', type=int, default=200) 74 | args = parser.parse_args() 75 | 76 | random_dataset(args) 77 | -------------------------------------------------------------------------------- /scripts/plot_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | labels, scores = np.loadtxt("score.txt").T 5 | 6 | target_score = [] 7 | nontarget_score = [] 8 | 9 | for idx,i in enumerate(labels): 10 | if i == 0: 11 | nontarget_score.append(scores[idx]) 12 | 
else: 13 | target_score.append(scores[idx]) 14 | 15 | print(scores.shape) 16 | print(labels.shape) 17 | 18 | plt.hist(target_score, bins=100, label="target score") 19 | plt.hist(nontarget_score, bins=100, label="nontarget score") 20 | plt.legend() 21 | plt.tight_layout() 22 | plt.savefig("test.png") 23 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | encoder_name="conformer_cat" # conformer_cat | ecapa_tdnn_large | resnet34 2 | embedding_dim=192 3 | loss_name="amsoftmax" 4 | 5 | dataset="vox" 6 | num_classes=7205 7 | num_blocks=6 8 | train_csv_path="data/train.csv" 9 | 10 | input_layer=conv2d2 11 | pos_enc_layer_type=rel_pos # no_pos| rel_pos 12 | save_dir=experiment/${input_layer}/${encoder_name}_${num_blocks}_${embedding_dim}_${loss_name} 13 | trial_path=data/vox1_test.txt 14 | 15 | mkdir -p $save_dir 16 | cp start.sh $save_dir 17 | cp main.py $save_dir 18 | cp -r module $save_dir 19 | cp -r wenet $save_dir 20 | cp -r scripts $save_dir 21 | cp -r loss $save_dir 22 | echo save_dir: $save_dir 23 | 24 | export CUDA_VISIBLE_DEVICES=0 25 | python3 main.py \ 26 | --batch_size 200 \ 27 | --num_workers 40 \ 28 | --max_epochs 30 \ 29 | --embedding_dim $embedding_dim \ 30 | --save_dir $save_dir \ 31 | --encoder_name $encoder_name \ 32 | --train_csv_path $train_csv_path \ 33 | --learning_rate 0.001 \ 34 | --encoder_name ${encoder_name} \ 35 | --num_classes $num_classes \ 36 | --trial_path $trial_path \ 37 | --loss_name $loss_name \ 38 | --num_blocks $num_blocks \ 39 | --step_size 4 \ 40 | --gamma 0.5 \ 41 | --weight_decay 0.0000001 \ 42 | --input_layer $input_layer \ 43 | --pos_enc_layer_type $pos_enc_layer_type 44 | 45 | -------------------------------------------------------------------------------- /wenet/transformer/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Multi-Head Attention layer definition.""" 7 | 8 | import math 9 | from typing import Optional, Tuple 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | class MultiHeadedAttention(nn.Module): 16 | """Multi-Head Attention layer. 17 | 18 | Args: 19 | n_head (int): The number of heads. 20 | n_feat (int): The number of features. 21 | dropout_rate (float): Dropout rate. 22 | 23 | """ 24 | def __init__(self, n_head: int, n_feat: int, dropout_rate: float): 25 | """Construct an MultiHeadedAttention object.""" 26 | super().__init__() 27 | assert n_feat % n_head == 0 28 | # We assume d_v always equals d_k 29 | self.d_k = n_feat // n_head 30 | self.h = n_head 31 | self.linear_q = nn.Linear(n_feat, n_feat) 32 | self.linear_k = nn.Linear(n_feat, n_feat) 33 | self.linear_v = nn.Linear(n_feat, n_feat) 34 | self.linear_out = nn.Linear(n_feat, n_feat) 35 | self.dropout = nn.Dropout(p=dropout_rate) 36 | 37 | def forward_qkv( 38 | self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor 39 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 40 | """Transform query, key and value. 41 | 42 | Args: 43 | query (torch.Tensor): Query tensor (#batch, time1, size). 44 | key (torch.Tensor): Key tensor (#batch, time2, size). 45 | value (torch.Tensor): Value tensor (#batch, time2, size). 46 | 47 | Returns: 48 | torch.Tensor: Transformed query tensor, size 49 | (#batch, n_head, time1, d_k). 
50 | torch.Tensor: Transformed key tensor, size 51 | (#batch, n_head, time2, d_k). 52 | torch.Tensor: Transformed value tensor, size 53 | (#batch, n_head, time2, d_k). 54 | 55 | """ 56 | n_batch = query.size(0) 57 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 58 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 59 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 60 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 61 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 62 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 63 | 64 | return q, k, v 65 | 66 | def forward_attention(self, value: torch.Tensor, scores: torch.Tensor, 67 | mask: Optional[torch.Tensor]) -> torch.Tensor: 68 | """Compute attention context vector. 69 | 70 | Args: 71 | value (torch.Tensor): Transformed value, size 72 | (#batch, n_head, time2, d_k). 73 | scores (torch.Tensor): Attention score, size 74 | (#batch, n_head, time1, time2). 75 | mask (torch.Tensor): Mask, size (#batch, 1, time2) or 76 | (#batch, time1, time2). 77 | 78 | Returns: 79 | torch.Tensor: Transformed value (#batch, time1, d_model) 80 | weighted by the attention score (#batch, time1, time2). 81 | 82 | """ 83 | n_batch = value.size(0) 84 | if mask is not None: 85 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 86 | scores = scores.masked_fill(mask, -float('inf')) 87 | attn = torch.softmax(scores, dim=-1).masked_fill( 88 | mask, 0.0) # (batch, head, time1, time2) 89 | else: 90 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 91 | 92 | p_attn = self.dropout(attn) 93 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 94 | x = (x.transpose(1, 2).contiguous().view(n_batch, -1, 95 | self.h * self.d_k) 96 | ) # (batch, time1, d_model) 97 | 98 | return self.linear_out(x) # (batch, time1, d_model) 99 | 100 | def forward(self, query: torch.Tensor, key: torch.Tensor, 101 | value: torch.Tensor, 102 | mask: Optional[torch.Tensor], 103 | pos_emb: torch.Tensor = torch.empty(0),) -> torch.Tensor: 104 | """Compute scaled dot product attention. 105 | 106 | Args: 107 | query (torch.Tensor): Query tensor (#batch, time1, size). 108 | key (torch.Tensor): Key tensor (#batch, time2, size). 109 | value (torch.Tensor): Value tensor (#batch, time2, size). 110 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 111 | (#batch, time1, time2). 112 | 1.When applying cross attention between decoder and encoder, 113 | the batch padding mask for input is in (#batch, 1, T) shape. 114 | 2.When applying self attention of encoder, 115 | the mask is in (#batch, T, T) shape. 116 | 3.When applying self attention of decoder, 117 | the mask is in (#batch, L, L) shape. 118 | 4.If the different position in decoder see different block 119 | of the encoder, such as Mocha, the passed in mask could be 120 | in (#batch, L, T) shape. But there is no such case in current 121 | Wenet. 122 | 123 | 124 | Returns: 125 | torch.Tensor: Output tensor (#batch, time1, d_model). 126 | 127 | """ 128 | q, k, v = self.forward_qkv(query, key, value) 129 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 130 | return self.forward_attention(v, scores, mask) 131 | 132 | 133 | class RelPositionMultiHeadedAttention(MultiHeadedAttention): 134 | """Multi-Head Attention layer with relative position encoding. 135 | Paper: https://arxiv.org/abs/1901.02860 136 | Args: 137 | n_head (int): The number of heads. 138 | n_feat (int): The number of features. 139 | dropout_rate (float): Dropout rate. 
140 | """ 141 | def __init__(self, n_head, n_feat, dropout_rate): 142 | """Construct an RelPositionMultiHeadedAttention object.""" 143 | super().__init__(n_head, n_feat, dropout_rate) 144 | # linear transformation for positional encoding 145 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 146 | # these two learnable bias are used in matrix c and matrix d 147 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 148 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 149 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 150 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 151 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 152 | 153 | def rel_shift(self, x, zero_triu: bool = False): 154 | """Compute relative positinal encoding. 155 | Args: 156 | x (torch.Tensor): Input tensor (batch, time, size). 157 | zero_triu (bool): If true, return the lower triangular part of 158 | the matrix. 159 | Returns: 160 | torch.Tensor: Output tensor. 161 | """ 162 | 163 | zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), 164 | device=x.device, 165 | dtype=x.dtype) 166 | x_padded = torch.cat([zero_pad, x], dim=-1) 167 | 168 | x_padded = x_padded.view(x.size()[0], 169 | x.size()[1], 170 | x.size(3) + 1, x.size(2)) 171 | x = x_padded[:, :, 1:].view_as(x) 172 | 173 | if zero_triu: 174 | ones = torch.ones((x.size(2), x.size(3))) 175 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 176 | 177 | return x 178 | 179 | def forward(self, query: torch.Tensor, key: torch.Tensor, 180 | value: torch.Tensor, mask: Optional[torch.Tensor], 181 | pos_emb: torch.Tensor): 182 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 183 | Args: 184 | query (torch.Tensor): Query tensor (#batch, time1, size). 185 | key (torch.Tensor): Key tensor (#batch, time2, size). 186 | value (torch.Tensor): Value tensor (#batch, time2, size). 187 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 188 | (#batch, time1, time2). 189 | pos_emb (torch.Tensor): Positional embedding tensor 190 | (#batch, time2, size). 191 | Returns: 192 | torch.Tensor: Output tensor (#batch, time1, d_model). 193 | """ 194 | q, k, v = self.forward_qkv(query, key, value) 195 | q = q.transpose(1, 2) # (batch, time1, head, d_k) 196 | 197 | n_batch_pos = pos_emb.size(0) 198 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 199 | p = p.transpose(1, 2) # (batch, head, time1, d_k) 200 | 201 | # (batch, head, time1, d_k) 202 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 203 | # (batch, head, time1, d_k) 204 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 205 | 206 | # compute attention score 207 | # first compute matrix a and matrix c 208 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 209 | # (batch, head, time1, time2) 210 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 211 | 212 | # compute matrix b and matrix d 213 | # (batch, head, time1, time2) 214 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 215 | # Remove rel_shift since it is useless in speech recognition, 216 | # and it requires special attention for streaming. 
217 | # matrix_bd = self.rel_shift(matrix_bd) 218 | 219 | scores = (matrix_ac + matrix_bd) / math.sqrt( 220 | self.d_k) # (batch, head, time1, time2) 221 | 222 | return self.forward_attention(v, scores, mask) 223 | -------------------------------------------------------------------------------- /wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | 19 | class GlobalCMVN(torch.nn.Module): 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /wenet/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """ConvolutionModule definition.""" 7 | 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | from torch import nn 12 | from typeguard import check_argument_types 13 | 14 | 15 | class ConvolutionModule(nn.Module): 16 | """ConvolutionModule in Conformer model.""" 17 | def __init__(self, 18 | channels: int, 19 | kernel_size: int = 15, 20 | activation: nn.Module = nn.ReLU(), 21 | norm: str = "batch_norm", 22 | causal: bool = False, 23 | bias: bool = True): 24 | """Construct an ConvolutionModule object. 25 | Args: 26 | channels (int): The number of channels of conv layers. 27 | kernel_size (int): Kernel size of conv layers. 28 | causal (int): Whether use causal convolution or not 29 | """ 30 | assert check_argument_types() 31 | super().__init__() 32 | 33 | self.pointwise_conv1 = nn.Conv1d( 34 | channels, 35 | 2 * channels, 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, 39 | bias=bias, 40 | ) 41 | # self.lorder is used to distinguish if it's a causal convolution, 42 | # if self.lorder > 0: it's a causal convolution, the input will be 43 | # padded with self.lorder frames on the left in forward. 
44 | # else: it's a symmetrical convolution 45 | if causal: 46 | padding = 0 47 | self.lorder = kernel_size - 1 48 | else: 49 | # kernel_size should be an odd number for none causal convolution 50 | assert (kernel_size - 1) % 2 == 0 51 | padding = (kernel_size - 1) // 2 52 | self.lorder = 0 53 | self.depthwise_conv = nn.Conv1d( 54 | channels, 55 | channels, 56 | kernel_size, 57 | stride=1, 58 | padding=padding, 59 | groups=channels, 60 | bias=bias, 61 | ) 62 | 63 | assert norm in ['batch_norm', 'layer_norm'] 64 | if norm == "batch_norm": 65 | self.use_layer_norm = False 66 | self.norm = nn.BatchNorm1d(channels) 67 | else: 68 | self.use_layer_norm = True 69 | self.norm = nn.LayerNorm(channels) 70 | 71 | self.pointwise_conv2 = nn.Conv1d( 72 | channels, 73 | channels, 74 | kernel_size=1, 75 | stride=1, 76 | padding=0, 77 | bias=bias, 78 | ) 79 | self.activation = activation 80 | 81 | def forward( 82 | self, 83 | x: torch.Tensor, 84 | mask_pad: Optional[torch.Tensor] = None, 85 | cache: Optional[torch.Tensor] = None, 86 | ) -> Tuple[torch.Tensor, torch.Tensor]: 87 | """Compute convolution module. 88 | Args: 89 | x (torch.Tensor): Input tensor (#batch, time, channels). 90 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time) 91 | cache (torch.Tensor): left context cache, it is only 92 | used in causal convolution 93 | Returns: 94 | torch.Tensor: Output tensor (#batch, time, channels). 95 | """ 96 | # exchange the temporal dimension and the feature dimension 97 | x = x.transpose(1, 2) # (#batch, channels, time) 98 | 99 | # mask batch padding 100 | if mask_pad is not None: 101 | x.masked_fill_(~mask_pad, 0.0) 102 | 103 | if self.lorder > 0: 104 | if cache is None: 105 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 106 | else: 107 | assert cache.size(0) == x.size(0) 108 | assert cache.size(1) == x.size(1) 109 | x = torch.cat((cache, x), dim=2) 110 | assert (x.size(2) > self.lorder) 111 | new_cache = x[:, :, -self.lorder:] 112 | else: 113 | # It's better we just return None if no cache is requried, 114 | # However, for JIT export, here we just fake one tensor instead of 115 | # None. 116 | new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 117 | 118 | # GLU mechanism 119 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 120 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 121 | 122 | # 1D Depthwise Conv 123 | x = self.depthwise_conv(x) 124 | if self.use_layer_norm: 125 | x = x.transpose(1, 2) 126 | x = self.activation(self.norm(x)) 127 | if self.use_layer_norm: 128 | x = x.transpose(1, 2) 129 | x = self.pointwise_conv2(x) 130 | # mask batch padding 131 | if mask_pad is not None: 132 | x.masked_fill_(~mask_pad, 0.0) 133 | 134 | return x.transpose(1, 2), new_cache 135 | -------------------------------------------------------------------------------- /wenet/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Positonal Encoding Module.""" 7 | 8 | import math 9 | from typing import Tuple 10 | 11 | import torch 12 | 13 | 14 | class PositionalEncoding(torch.nn.Module): 15 | """Positional encoding. 
16 | 17 | :param int d_model: embedding dim 18 | :param float dropout_rate: dropout rate 19 | :param int max_len: maximum input length 20 | 21 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 22 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 23 | """ 24 | def __init__(self, 25 | d_model: int, 26 | dropout_rate: float, 27 | max_len: int = 50000, 28 | reverse: bool = False): 29 | """Construct an PositionalEncoding object.""" 30 | super().__init__() 31 | self.d_model = d_model 32 | self.xscale = math.sqrt(self.d_model) 33 | self.dropout = torch.nn.Dropout(p=dropout_rate) 34 | self.max_len = max_len 35 | 36 | self.pe = torch.zeros(self.max_len, self.d_model) 37 | position = torch.arange(0, self.max_len, 38 | dtype=torch.float32).unsqueeze(1) 39 | div_term = torch.exp( 40 | torch.arange(0, self.d_model, 2, dtype=torch.float32) * 41 | -(math.log(10000.0) / self.d_model)) 42 | self.pe[:, 0::2] = torch.sin(position * div_term) 43 | self.pe[:, 1::2] = torch.cos(position * div_term) 44 | self.pe = self.pe.unsqueeze(0) 45 | 46 | def forward(self, 47 | x: torch.Tensor, 48 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 49 | """Add positional encoding. 50 | 51 | Args: 52 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 53 | offset (int): position offset 54 | 55 | Returns: 56 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 57 | torch.Tensor: for compatibility to RelPositionalEncoding 58 | """ 59 | assert offset + x.size(1) < self.max_len 60 | self.pe = self.pe.to(x.device) 61 | pos_emb = self.pe[:, offset:offset + x.size(1)] 62 | x = x * self.xscale + pos_emb 63 | return self.dropout(x), self.dropout(pos_emb) 64 | 65 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 66 | """ For getting encoding in a streaming fashion 67 | 68 | Attention!!!!! 69 | we apply dropout only once at the whole utterance level in a none 70 | streaming way, but will call this function several times with 71 | increasing input size in a streaming scenario, so the dropout will 72 | be applied several times. 73 | 74 | Args: 75 | offset (int): start offset 76 | size (int): requried size of position encoding 77 | 78 | Returns: 79 | torch.Tensor: Corresponding encoding 80 | """ 81 | assert offset + size < self.max_len 82 | return self.dropout(self.pe[:, offset:offset + size]) 83 | 84 | 85 | class RelPositionalEncoding(PositionalEncoding): 86 | """Relative positional encoding module. 87 | See : Appendix B in https://arxiv.org/abs/1901.02860 88 | Args: 89 | d_model (int): Embedding dimension. 90 | dropout_rate (float): Dropout rate. 91 | max_len (int): Maximum input length. 92 | """ 93 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 100000): 94 | """Initialize class.""" 95 | super().__init__(d_model, dropout_rate, max_len, reverse=True) 96 | 97 | def forward(self, 98 | x: torch.Tensor, 99 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 100 | """Compute positional encoding. 101 | Args: 102 | x (torch.Tensor): Input tensor (batch, time, `*`). 103 | Returns: 104 | torch.Tensor: Encoded tensor (batch, time, `*`). 105 | torch.Tensor: Positional embedding tensor (1, time, `*`). 
106 | """ 107 | assert offset + x.size(1) < self.max_len 108 | self.pe = self.pe.to(x.device) 109 | x = x * self.xscale 110 | pos_emb = self.pe[:, offset:offset + x.size(1)] 111 | return self.dropout(x), self.dropout(pos_emb) 112 | 113 | 114 | class NoPositionalEncoding(torch.nn.Module): 115 | """ No position encoding 116 | """ 117 | def __init__(self, d_model: int, dropout_rate: float): 118 | super().__init__() 119 | self.d_model = d_model 120 | self.dropout = torch.nn.Dropout(p=dropout_rate) 121 | 122 | def forward(self, 123 | x: torch.Tensor, 124 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """ Just return zero vector for interface compatibility 126 | """ 127 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) 128 | return self.dropout(x), pos_emb 129 | 130 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 131 | return torch.zeros(1, size, self.d_model) 132 | -------------------------------------------------------------------------------- /wenet/transformer/encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Encoder definition.""" 7 | from typing import Tuple, List, Optional 8 | 9 | import torch 10 | from typeguard import check_argument_types 11 | 12 | from wenet.transformer.attention import MultiHeadedAttention 13 | from wenet.transformer.attention import RelPositionMultiHeadedAttention 14 | from wenet.transformer.convolution import ConvolutionModule 15 | from wenet.transformer.embedding import PositionalEncoding 16 | from wenet.transformer.embedding import RelPositionalEncoding 17 | from wenet.transformer.embedding import NoPositionalEncoding 18 | from wenet.transformer.encoder_layer import TransformerEncoderLayer 19 | from wenet.transformer.encoder_layer import ConformerEncoderLayer 20 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward 21 | from wenet.transformer.subsampling import Conv2dSubsampling2 22 | from wenet.transformer.subsampling import Conv2dSubsampling4 23 | from wenet.transformer.subsampling import Conv2dSubsampling6 24 | from wenet.transformer.subsampling import Conv2dSubsampling8 25 | 26 | from wenet.transformer.subsampling import LinearNoSubsampling 27 | from wenet.utils.common import get_activation 28 | from wenet.utils.mask import make_pad_mask 29 | from wenet.utils.mask import add_optional_chunk_mask 30 | 31 | 32 | class BaseEncoder(torch.nn.Module): 33 | def __init__( 34 | self, 35 | input_size: int, 36 | output_size: int = 256, 37 | attention_heads: int = 4, 38 | linear_units: int = 2048, 39 | num_blocks: int = 6, 40 | dropout_rate: float = 0.1, 41 | positional_dropout_rate: float = 0.1, 42 | attention_dropout_rate: float = 0.0, 43 | input_layer: str = "conv2d", 44 | pos_enc_layer_type: str = "abs_pos", 45 | normalize_before: bool = True, 46 | concat_after: bool = False, 47 | static_chunk_size: int = 0, 48 | use_dynamic_chunk: bool = False, 49 | global_cmvn: torch.nn.Module = None, 50 | use_dynamic_left_chunk: bool = False, 51 | ): 52 | """ 53 | Args: 54 | input_size (int): input dim 55 | output_size (int): dimension of attention 56 | attention_heads (int): the number of heads of multi head attention 57 | linear_units (int): the hidden units number of position-wise feed 58 | forward 59 | num_blocks (int): the number of decoder blocks 60 | dropout_rate (float): dropout rate 61 | 
attention_dropout_rate (float): dropout rate in attention 62 | positional_dropout_rate (float): dropout rate after adding 63 | positional encoding 64 | input_layer (str): input layer type. 65 | optional [linear, conv2d, conv2d6, conv2d8] 66 | pos_enc_layer_type (str): Encoder positional encoding layer type. 67 | opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] 68 | normalize_before (bool): 69 | True: use layer_norm before each sub-block of a layer. 70 | False: use layer_norm after each sub-block of a layer. 71 | concat_after (bool): whether to concat attention layer's input 72 | and output. 73 | True: x -> x + linear(concat(x, att(x))) 74 | False: x -> x + att(x) 75 | static_chunk_size (int): chunk size for static chunk training and 76 | decoding 77 | use_dynamic_chunk (bool): whether use dynamic chunk size for 78 | training or not, You can only use fixed chunk(chunk_size > 0) 79 | or dyanmic chunk size(use_dynamic_chunk = True) 80 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module 81 | use_dynamic_left_chunk (bool): whether use dynamic left chunk in 82 | dynamic chunk training 83 | """ 84 | assert check_argument_types() 85 | super().__init__() 86 | self._output_size = output_size 87 | 88 | if pos_enc_layer_type == "abs_pos": 89 | pos_enc_class = PositionalEncoding 90 | elif pos_enc_layer_type == "rel_pos": 91 | pos_enc_class = RelPositionalEncoding 92 | elif pos_enc_layer_type == "no_pos": 93 | pos_enc_class = NoPositionalEncoding 94 | else: 95 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) 96 | 97 | if input_layer == "linear": 98 | subsampling_class = LinearNoSubsampling 99 | elif input_layer == "conv2d": 100 | subsampling_class = Conv2dSubsampling4 101 | elif input_layer == "conv2d6": 102 | subsampling_class = Conv2dSubsampling6 103 | elif input_layer == "conv2d8": 104 | subsampling_class = Conv2dSubsampling8 105 | elif input_layer == "conv2d2": 106 | subsampling_class = Conv2dSubsampling2 107 | else: 108 | raise ValueError("unknown input_layer: " + input_layer) 109 | 110 | self.global_cmvn = global_cmvn 111 | self.embed = subsampling_class( 112 | input_size, 113 | output_size, 114 | dropout_rate, 115 | pos_enc_class(output_size, positional_dropout_rate), 116 | ) 117 | 118 | self.normalize_before = normalize_before 119 | self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12) 120 | self.static_chunk_size = static_chunk_size 121 | self.use_dynamic_chunk = use_dynamic_chunk 122 | self.use_dynamic_left_chunk = use_dynamic_left_chunk 123 | 124 | def output_size(self) -> int: 125 | return self._output_size 126 | 127 | def forward( 128 | self, 129 | xs: torch.Tensor, 130 | xs_lens: torch.Tensor, 131 | decoding_chunk_size: int = 0, 132 | num_decoding_left_chunks: int = -1, 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """Embed positions in tensor. 135 | 136 | Args: 137 | xs: padded input tensor (B, T, D) 138 | xs_lens: input length (B) 139 | decoding_chunk_size: decoding chunk size for dynamic chunk 140 | 0: default for training, use random dynamic chunk. 141 | <0: for decoding, use full chunk. 142 | >0: for decoding, use fixed chunk size as set. 143 | num_decoding_left_chunks: number of left chunks, this is for decoding, 144 | the chunk size is decoding_chunk_size. 
145 | >=0: use num_decoding_left_chunks 146 | <0: use all left chunks 147 | Returns: 148 | encoder output tensor xs, and subsampled masks 149 | xs: padded output tensor (B, T' ~= T/subsample_rate, D) 150 | masks: torch.Tensor batch padding mask after subsample 151 | (B, 1, T' ~= T/subsample_rate) 152 | """ 153 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T) 154 | if self.global_cmvn is not None: 155 | xs = self.global_cmvn(xs) 156 | 157 | xs, pos_emb, masks = self.embed(xs, masks) 158 | mask_pad = masks # (B, 1, T/subsample_rate) 159 | chunk_masks = add_optional_chunk_mask(xs, masks, 160 | self.use_dynamic_chunk, 161 | self.use_dynamic_left_chunk, 162 | decoding_chunk_size, 163 | self.static_chunk_size, 164 | num_decoding_left_chunks) 165 | for layer in self.encoders: 166 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad) 167 | if self.normalize_before: 168 | xs = self.after_norm(xs) 169 | # Here we assume the mask is not changed in encoder layers, so just 170 | # return the masks before encoder layers, and the masks will be used 171 | # for cross attention with decoder later 172 | return xs, masks 173 | 174 | def forward_chunk( 175 | self, 176 | xs: torch.Tensor, 177 | offset: int, 178 | required_cache_size: int, 179 | subsampling_cache: Optional[torch.Tensor] = None, 180 | elayers_output_cache: Optional[List[torch.Tensor]] = None, 181 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None, 182 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], 183 | List[torch.Tensor]]: 184 | """ Forward just one chunk 185 | 186 | Args: 187 | xs (torch.Tensor): chunk input 188 | offset (int): current offset in encoder output time stamp 189 | required_cache_size (int): cache size required for next chunk 190 | compuation 191 | >=0: actual cache size 192 | <0: means all history cache is required 193 | subsampling_cache (Optional[torch.Tensor]): subsampling cache 194 | elayers_output_cache (Optional[List[torch.Tensor]]): 195 | transformer/conformer encoder layers output cache 196 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer 197 | cnn cache 198 | 199 | Returns: 200 | torch.Tensor: output of current input xs 201 | torch.Tensor: subsampling cache required for next chunk computation 202 | List[torch.Tensor]: encoder layers output cache required for next 203 | chunk computation 204 | List[torch.Tensor]: conformer cnn cache 205 | 206 | """ 207 | assert xs.size(0) == 1 208 | # tmp_masks is just for interface compatibility 209 | tmp_masks = torch.ones(1, 210 | xs.size(1), 211 | device=xs.device, 212 | dtype=torch.bool) 213 | tmp_masks = tmp_masks.unsqueeze(1) 214 | if self.global_cmvn is not None: 215 | xs = self.global_cmvn(xs) 216 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) 217 | if subsampling_cache is not None: 218 | cache_size = subsampling_cache.size(1) 219 | xs = torch.cat((subsampling_cache, xs), dim=1) 220 | else: 221 | cache_size = 0 222 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1)) 223 | if required_cache_size < 0: 224 | next_cache_start = 0 225 | elif required_cache_size == 0: 226 | next_cache_start = xs.size(1) 227 | else: 228 | next_cache_start = max(xs.size(1) - required_cache_size, 0) 229 | r_subsampling_cache = xs[:, next_cache_start:, :] 230 | # Real mask for transformer/conformer layers 231 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) 232 | masks = masks.unsqueeze(1) 233 | r_elayers_output_cache = [] 234 | r_conformer_cnn_cache = [] 235 | for i, layer in 
enumerate(self.encoders): 236 | if elayers_output_cache is None: 237 | attn_cache = None 238 | else: 239 | attn_cache = elayers_output_cache[i] 240 | if conformer_cnn_cache is None: 241 | cnn_cache = None 242 | else: 243 | cnn_cache = conformer_cnn_cache[i] 244 | xs, _, new_cnn_cache = layer(xs, 245 | masks, 246 | pos_emb, 247 | output_cache=attn_cache, 248 | cnn_cache=cnn_cache) 249 | r_elayers_output_cache.append(xs[:, next_cache_start:, :]) 250 | r_conformer_cnn_cache.append(new_cnn_cache) 251 | if self.normalize_before: 252 | xs = self.after_norm(xs) 253 | 254 | return (xs[:, cache_size:, :], r_subsampling_cache, 255 | r_elayers_output_cache, r_conformer_cnn_cache) 256 | 257 | def forward_chunk_by_chunk( 258 | self, 259 | xs: torch.Tensor, 260 | decoding_chunk_size: int, 261 | num_decoding_left_chunks: int = -1, 262 | ) -> Tuple[torch.Tensor, torch.Tensor]: 263 | """ Forward input chunk by chunk with chunk_size like a streaming 264 | fashion 265 | 266 | Here we should pay special attention to computation cache in the 267 | streaming style forward chunk by chunk. Three things should be taken 268 | into account for computation in the current network: 269 | 1. transformer/conformer encoder layers output cache 270 | 2. convolution in conformer 271 | 3. convolution in subsampling 272 | 273 | However, we don't implement subsampling cache for: 274 | 1. We can control subsampling module to output the right result by 275 | overlapping input instead of cache left context, even though it 276 | wastes some computation, but subsampling only takes a very 277 | small fraction of computation in the whole model. 278 | 2. Typically, there are several covolution layers with subsampling 279 | in subsampling module, it is tricky and complicated to do cache 280 | with different convolution layers with different subsampling 281 | rate. 282 | 3. Currently, nn.Sequential is used to stack all the convolution 283 | layers in subsampling, we need to rewrite it to make it work 284 | with cache, which is not prefered. 
285 | Args: 286 | xs (torch.Tensor): (1, max_len, dim) 287 | chunk_size (int): decoding chunk size 288 | """ 289 | assert decoding_chunk_size > 0 290 | # The model is trained by static or dynamic chunk 291 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk 292 | subsampling = self.embed.subsampling_rate 293 | context = self.embed.right_context + 1 # Add current frame 294 | stride = subsampling * decoding_chunk_size 295 | decoding_window = (decoding_chunk_size - 1) * subsampling + context 296 | num_frames = xs.size(1) 297 | subsampling_cache: Optional[torch.Tensor] = None 298 | elayers_output_cache: Optional[List[torch.Tensor]] = None 299 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None 300 | outputs = [] 301 | offset = 0 302 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks 303 | 304 | # Feed forward overlap input step by step 305 | for cur in range(0, num_frames - context + 1, stride): 306 | end = min(cur + decoding_window, num_frames) 307 | chunk_xs = xs[:, cur:end, :] 308 | (y, subsampling_cache, elayers_output_cache, 309 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset, 310 | required_cache_size, 311 | subsampling_cache, 312 | elayers_output_cache, 313 | conformer_cnn_cache) 314 | outputs.append(y) 315 | offset += y.size(1) 316 | ys = torch.cat(outputs, 1) 317 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool) 318 | masks = masks.unsqueeze(1) 319 | return ys, masks 320 | 321 | 322 | class TransformerEncoder(BaseEncoder): 323 | """Transformer encoder module.""" 324 | def __init__( 325 | self, 326 | input_size: int, 327 | output_size: int = 256, 328 | attention_heads: int = 4, 329 | linear_units: int = 2048, 330 | num_blocks: int = 6, 331 | dropout_rate: float = 0.1, 332 | positional_dropout_rate: float = 0.1, 333 | attention_dropout_rate: float = 0.0, 334 | input_layer: str = "conv2d", 335 | pos_enc_layer_type: str = "abs_pos", 336 | normalize_before: bool = True, 337 | concat_after: bool = False, 338 | static_chunk_size: int = 0, 339 | use_dynamic_chunk: bool = False, 340 | global_cmvn: torch.nn.Module = None, 341 | use_dynamic_left_chunk: bool = False, 342 | ): 343 | """ Construct TransformerEncoder 344 | 345 | See Encoder for the meaning of each parameter. 
346 | """ 347 | assert check_argument_types() 348 | super().__init__(input_size, output_size, attention_heads, 349 | linear_units, num_blocks, dropout_rate, 350 | positional_dropout_rate, attention_dropout_rate, 351 | input_layer, pos_enc_layer_type, normalize_before, 352 | concat_after, static_chunk_size, use_dynamic_chunk, 353 | global_cmvn, use_dynamic_left_chunk) 354 | self.encoders = torch.nn.ModuleList([ 355 | TransformerEncoderLayer( 356 | output_size, 357 | MultiHeadedAttention(attention_heads, output_size, 358 | attention_dropout_rate), 359 | PositionwiseFeedForward(output_size, linear_units, 360 | dropout_rate), dropout_rate, 361 | normalize_before, concat_after) for _ in range(num_blocks) 362 | ]) 363 | 364 | 365 | class ConformerEncoder(BaseEncoder): 366 | """Conformer encoder module.""" 367 | def __init__( 368 | self, 369 | input_size: int, 370 | output_size: int = 256, 371 | attention_heads: int = 4, 372 | linear_units: int = 2048, 373 | num_blocks: int = 6, 374 | dropout_rate: float = 0.1, 375 | positional_dropout_rate: float = 0.1, 376 | attention_dropout_rate: float = 0.0, 377 | input_layer: str = "conv2d", 378 | pos_enc_layer_type: str = "rel_pos", 379 | normalize_before: bool = True, 380 | concat_after: bool = False, 381 | static_chunk_size: int = 0, 382 | use_dynamic_chunk: bool = False, 383 | global_cmvn: torch.nn.Module = None, 384 | use_dynamic_left_chunk: bool = False, 385 | positionwise_conv_kernel_size: int = 1, 386 | macaron_style: bool = True, 387 | selfattention_layer_type: str = "rel_selfattn", 388 | activation_type: str = "swish", 389 | use_cnn_module: bool = True, 390 | cnn_module_kernel: int = 15, 391 | causal: bool = False, 392 | cnn_module_norm: str = "batch_norm", 393 | ): 394 | """Construct ConformerEncoder 395 | 396 | Args: 397 | input_size to use_dynamic_chunk, see in BaseEncoder 398 | positionwise_conv_kernel_size (int): Kernel size of positionwise 399 | conv1d layer. 400 | macaron_style (bool): Whether to use macaron style for 401 | positionwise layer. 402 | selfattention_layer_type (str): Encoder attention layer type, 403 | the parameter has no effect now, it's just for configure 404 | compatibility. 405 | activation_type (str): Encoder activation function type. 406 | use_cnn_module (bool): Whether to use convolution module. 407 | cnn_module_kernel (int): Kernel size of convolution module. 408 | causal (bool): whether to use causal convolution or not. 
409 | """ 410 | assert check_argument_types() 411 | super().__init__(input_size, output_size, attention_heads, 412 | linear_units, num_blocks, dropout_rate, 413 | positional_dropout_rate, attention_dropout_rate, 414 | input_layer, pos_enc_layer_type, normalize_before, 415 | concat_after, static_chunk_size, use_dynamic_chunk, 416 | global_cmvn, use_dynamic_left_chunk) 417 | activation = get_activation(activation_type) 418 | 419 | # self-attention module definition 420 | if pos_enc_layer_type == "no_pos": 421 | encoder_selfattn_layer = MultiHeadedAttention 422 | else: 423 | encoder_selfattn_layer = RelPositionMultiHeadedAttention 424 | encoder_selfattn_layer_args = ( 425 | attention_heads, 426 | output_size, 427 | attention_dropout_rate, 428 | ) 429 | # feed-forward module definition 430 | positionwise_layer = PositionwiseFeedForward 431 | positionwise_layer_args = ( 432 | output_size, 433 | linear_units, 434 | dropout_rate, 435 | activation, 436 | ) 437 | # convolution module definition 438 | convolution_layer = ConvolutionModule 439 | convolution_layer_args = (output_size, cnn_module_kernel, activation, 440 | cnn_module_norm, causal) 441 | 442 | self.encoders = torch.nn.ModuleList([ 443 | ConformerEncoderLayer( 444 | output_size, 445 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 446 | positionwise_layer(*positionwise_layer_args), 447 | positionwise_layer( 448 | *positionwise_layer_args) if macaron_style else None, 449 | convolution_layer( 450 | *convolution_layer_args) if use_cnn_module else None, 451 | dropout_rate, 452 | normalize_before, 453 | concat_after, 454 | ) for _ in range(num_blocks) 455 | ]) 456 | -------------------------------------------------------------------------------- /wenet/transformer/encoder_cat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 
5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Encoder definition.""" 7 | from typing import Tuple, List, Optional 8 | 9 | import torch 10 | from typeguard import check_argument_types 11 | 12 | from wenet.transformer.attention import MultiHeadedAttention 13 | from wenet.transformer.attention import RelPositionMultiHeadedAttention 14 | from wenet.transformer.convolution import ConvolutionModule 15 | from wenet.transformer.embedding import PositionalEncoding 16 | from wenet.transformer.embedding import RelPositionalEncoding 17 | from wenet.transformer.embedding import NoPositionalEncoding 18 | from wenet.transformer.encoder_layer import TransformerEncoderLayer 19 | from wenet.transformer.encoder_layer import ConformerEncoderLayer 20 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward 21 | from wenet.transformer.subsampling import Conv2dSubsampling4 22 | from wenet.transformer.subsampling import Conv2dSubsampling6 23 | from wenet.transformer.subsampling import Conv2dSubsampling8 24 | from wenet.transformer.subsampling import Conv2dSubsampling2 25 | from wenet.transformer.subsampling import LinearNoSubsampling 26 | from wenet.utils.common import get_activation 27 | from wenet.utils.mask import make_pad_mask 28 | from wenet.utils.mask import add_optional_chunk_mask 29 | 30 | 31 | class BaseEncoder(torch.nn.Module): 32 | def __init__( 33 | self, 34 | input_size: int, 35 | output_size: int = 256, 36 | attention_heads: int = 4, 37 | linear_units: int = 2048, 38 | num_blocks: int = 6, 39 | dropout_rate: float = 0.1, 40 | positional_dropout_rate: float = 0.1, 41 | attention_dropout_rate: float = 0.0, 42 | input_layer: str = "conv2d", 43 | pos_enc_layer_type: str = "abs_pos", 44 | normalize_before: bool = True, 45 | concat_after: bool = False, 46 | static_chunk_size: int = 0, 47 | use_dynamic_chunk: bool = False, 48 | global_cmvn: torch.nn.Module = None, 49 | use_dynamic_left_chunk: bool = False, 50 | ): 51 | """ 52 | Args: 53 | input_size (int): input dim 54 | output_size (int): dimension of attention 55 | attention_heads (int): the number of heads of multi head attention 56 | linear_units (int): the hidden units number of position-wise feed 57 | forward 58 | num_blocks (int): the number of decoder blocks 59 | dropout_rate (float): dropout rate 60 | attention_dropout_rate (float): dropout rate in attention 61 | positional_dropout_rate (float): dropout rate after adding 62 | positional encoding 63 | input_layer (str): input layer type. 64 | optional [linear, conv2d, conv2d6, conv2d8] 65 | pos_enc_layer_type (str): Encoder positional encoding layer type. 66 | opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] 67 | normalize_before (bool): 68 | True: use layer_norm before each sub-block of a layer. 69 | False: use layer_norm after each sub-block of a layer. 70 | concat_after (bool): whether to concat attention layer's input 71 | and output. 
72 | True: x -> x + linear(concat(x, att(x))) 73 | False: x -> x + att(x) 74 | static_chunk_size (int): chunk size for static chunk training and 75 | decoding 76 | use_dynamic_chunk (bool): whether use dynamic chunk size for 77 | training or not, You can only use fixed chunk(chunk_size > 0) 78 | or dyanmic chunk size(use_dynamic_chunk = True) 79 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module 80 | use_dynamic_left_chunk (bool): whether use dynamic left chunk in 81 | dynamic chunk training 82 | """ 83 | assert check_argument_types() 84 | super().__init__() 85 | self._output_size = output_size * num_blocks 86 | 87 | if pos_enc_layer_type == "abs_pos": 88 | pos_enc_class = PositionalEncoding 89 | elif pos_enc_layer_type == "rel_pos": 90 | pos_enc_class = RelPositionalEncoding 91 | elif pos_enc_layer_type == "no_pos": 92 | pos_enc_class = NoPositionalEncoding 93 | else: 94 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) 95 | 96 | if input_layer == "linear": 97 | subsampling_class = LinearNoSubsampling 98 | elif input_layer == "conv2d": 99 | subsampling_class = Conv2dSubsampling4 100 | elif input_layer == "conv2d6": 101 | subsampling_class = Conv2dSubsampling6 102 | elif input_layer == "conv2d8": 103 | subsampling_class = Conv2dSubsampling8 104 | elif input_layer == "conv2d2": 105 | subsampling_class = Conv2dSubsampling2 106 | else: 107 | raise ValueError("unknown input_layer: " + input_layer) 108 | 109 | self.global_cmvn = global_cmvn 110 | self.embed = subsampling_class( 111 | input_size, 112 | output_size, 113 | dropout_rate, 114 | pos_enc_class(output_size, positional_dropout_rate), 115 | ) 116 | 117 | self.normalize_before = normalize_before 118 | self.after_norm = torch.nn.LayerNorm(output_size * num_blocks, eps=1e-12) 119 | self.static_chunk_size = static_chunk_size 120 | self.use_dynamic_chunk = use_dynamic_chunk 121 | self.use_dynamic_left_chunk = use_dynamic_left_chunk 122 | 123 | def output_size(self) -> int: 124 | return self._output_size 125 | 126 | def forward( 127 | self, 128 | xs: torch.Tensor, 129 | xs_lens: torch.Tensor, 130 | decoding_chunk_size: int = 0, 131 | num_decoding_left_chunks: int = -1, 132 | ) -> Tuple[torch.Tensor, torch.Tensor]: 133 | """Embed positions in tensor. 134 | 135 | Args: 136 | xs: padded input tensor (B, T, D) 137 | xs_lens: input length (B) 138 | decoding_chunk_size: decoding chunk size for dynamic chunk 139 | 0: default for training, use random dynamic chunk. 140 | <0: for decoding, use full chunk. 141 | >0: for decoding, use fixed chunk size as set. 142 | num_decoding_left_chunks: number of left chunks, this is for decoding, 143 | the chunk size is decoding_chunk_size. 
144 | >=0: use num_decoding_left_chunks 145 | <0: use all left chunks 146 | Returns: 147 | encoder output tensor xs, and subsampled masks 148 | xs: padded output tensor (B, T' ~= T/subsample_rate, D) 149 | masks: torch.Tensor batch padding mask after subsample 150 | (B, 1, T' ~= T/subsample_rate) 151 | """ 152 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T) 153 | if self.global_cmvn is not None: 154 | xs = self.global_cmvn(xs) 155 | xs, pos_emb, masks = self.embed(xs, masks) 156 | mask_pad = masks # (B, 1, T/subsample_rate) 157 | chunk_masks = add_optional_chunk_mask(xs, masks, 158 | self.use_dynamic_chunk, 159 | self.use_dynamic_left_chunk, 160 | decoding_chunk_size, 161 | self.static_chunk_size, 162 | num_decoding_left_chunks) 163 | out = [] 164 | for layer in self.encoders: 165 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad) 166 | out.append(xs) 167 | xs = torch.cat(out, dim=-1) 168 | if self.normalize_before: 169 | xs = self.after_norm(xs) 170 | # Here we assume the mask is not changed in encoder layers, so just 171 | # return the masks before encoder layers, and the masks will be used 172 | # for cross attention with decoder later 173 | return xs, masks 174 | 175 | def forward_chunk( 176 | self, 177 | xs: torch.Tensor, 178 | offset: int, 179 | required_cache_size: int, 180 | subsampling_cache: Optional[torch.Tensor] = None, 181 | elayers_output_cache: Optional[List[torch.Tensor]] = None, 182 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None, 183 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], 184 | List[torch.Tensor]]: 185 | """ Forward just one chunk 186 | 187 | Args: 188 | xs (torch.Tensor): chunk input 189 | offset (int): current offset in encoder output time stamp 190 | required_cache_size (int): cache size required for next chunk 191 | compuation 192 | >=0: actual cache size 193 | <0: means all history cache is required 194 | subsampling_cache (Optional[torch.Tensor]): subsampling cache 195 | elayers_output_cache (Optional[List[torch.Tensor]]): 196 | transformer/conformer encoder layers output cache 197 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer 198 | cnn cache 199 | 200 | Returns: 201 | torch.Tensor: output of current input xs 202 | torch.Tensor: subsampling cache required for next chunk computation 203 | List[torch.Tensor]: encoder layers output cache required for next 204 | chunk computation 205 | List[torch.Tensor]: conformer cnn cache 206 | 207 | """ 208 | assert xs.size(0) == 1 209 | # tmp_masks is just for interface compatibility 210 | tmp_masks = torch.ones(1, 211 | xs.size(1), 212 | device=xs.device, 213 | dtype=torch.bool) 214 | tmp_masks = tmp_masks.unsqueeze(1) 215 | if self.global_cmvn is not None: 216 | xs = self.global_cmvn(xs) 217 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) 218 | if subsampling_cache is not None: 219 | cache_size = subsampling_cache.size(1) 220 | xs = torch.cat((subsampling_cache, xs), dim=1) 221 | else: 222 | cache_size = 0 223 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1)) 224 | if required_cache_size < 0: 225 | next_cache_start = 0 226 | elif required_cache_size == 0: 227 | next_cache_start = xs.size(1) 228 | else: 229 | next_cache_start = max(xs.size(1) - required_cache_size, 0) 230 | r_subsampling_cache = xs[:, next_cache_start:, :] 231 | # Real mask for transformer/conformer layers 232 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) 233 | masks = masks.unsqueeze(1) 234 | r_elayers_output_cache = [] 235 | 
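        # The per-layer loop below threads two caches through the encoder stack:
        # elayers_output_cache[i] holds earlier output frames of layer i and is
        # passed as output_cache so self-attention can attend over the kept left
        # context, while conformer_cnn_cache[i] carries the convolution module's
        # left context. After each layer, its output (from next_cache_start
        # onward) and its new CNN cache are collected as the caches for the
        # next chunk.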
r_conformer_cnn_cache = [] 236 | for i, layer in enumerate(self.encoders): 237 | if elayers_output_cache is None: 238 | attn_cache = None 239 | else: 240 | attn_cache = elayers_output_cache[i] 241 | if conformer_cnn_cache is None: 242 | cnn_cache = None 243 | else: 244 | cnn_cache = conformer_cnn_cache[i] 245 | xs, _, new_cnn_cache = layer(xs, 246 | masks, 247 | pos_emb, 248 | output_cache=attn_cache, 249 | cnn_cache=cnn_cache) 250 | r_elayers_output_cache.append(xs[:, next_cache_start:, :]) 251 | r_conformer_cnn_cache.append(new_cnn_cache) 252 | if self.normalize_before: 253 | xs = self.after_norm(xs) 254 | 255 | return (xs[:, cache_size:, :], r_subsampling_cache, 256 | r_elayers_output_cache, r_conformer_cnn_cache) 257 | 258 | def forward_chunk_by_chunk( 259 | self, 260 | xs: torch.Tensor, 261 | decoding_chunk_size: int, 262 | num_decoding_left_chunks: int = -1, 263 | ) -> Tuple[torch.Tensor, torch.Tensor]: 264 | """ Forward input chunk by chunk with chunk_size like a streaming 265 | fashion 266 | 267 | Here we should pay special attention to computation cache in the 268 | streaming style forward chunk by chunk. Three things should be taken 269 | into account for computation in the current network: 270 | 1. transformer/conformer encoder layers output cache 271 | 2. convolution in conformer 272 | 3. convolution in subsampling 273 | 274 | However, we don't implement subsampling cache for: 275 | 1. We can control subsampling module to output the right result by 276 | overlapping input instead of cache left context, even though it 277 | wastes some computation, but subsampling only takes a very 278 | small fraction of computation in the whole model. 279 | 2. Typically, there are several covolution layers with subsampling 280 | in subsampling module, it is tricky and complicated to do cache 281 | with different convolution layers with different subsampling 282 | rate. 283 | 3. Currently, nn.Sequential is used to stack all the convolution 284 | layers in subsampling, we need to rewrite it to make it work 285 | with cache, which is not prefered. 
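            A worked example with illustrative numbers (assuming the default
            conv2d input layer, i.e. Conv2dSubsampling4 with subsampling_rate 4
            and right_context 6): for decoding_chunk_size = 16 the loop below
            uses context = 7, stride = 4 * 16 = 64 and
            decoding_window = (16 - 1) * 4 + 7 = 67, so each step feeds 67
            input frames to forward_chunk(), advances by 64 frames, and yields
            16 subsampled frames per chunk.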
286 | Args: 287 | xs (torch.Tensor): (1, max_len, dim) 288 | chunk_size (int): decoding chunk size 289 | """ 290 | assert decoding_chunk_size > 0 291 | # The model is trained by static or dynamic chunk 292 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk 293 | subsampling = self.embed.subsampling_rate 294 | context = self.embed.right_context + 1 # Add current frame 295 | stride = subsampling * decoding_chunk_size 296 | decoding_window = (decoding_chunk_size - 1) * subsampling + context 297 | num_frames = xs.size(1) 298 | subsampling_cache: Optional[torch.Tensor] = None 299 | elayers_output_cache: Optional[List[torch.Tensor]] = None 300 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None 301 | outputs = [] 302 | offset = 0 303 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks 304 | 305 | # Feed forward overlap input step by step 306 | for cur in range(0, num_frames - context + 1, stride): 307 | end = min(cur + decoding_window, num_frames) 308 | chunk_xs = xs[:, cur:end, :] 309 | (y, subsampling_cache, elayers_output_cache, 310 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset, 311 | required_cache_size, 312 | subsampling_cache, 313 | elayers_output_cache, 314 | conformer_cnn_cache) 315 | outputs.append(y) 316 | offset += y.size(1) 317 | ys = torch.cat(outputs, 1) 318 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool) 319 | masks = masks.unsqueeze(1) 320 | return ys, masks 321 | 322 | 323 | class TransformerEncoder(BaseEncoder): 324 | """Transformer encoder module.""" 325 | def __init__( 326 | self, 327 | input_size: int, 328 | output_size: int = 256, 329 | attention_heads: int = 4, 330 | linear_units: int = 2048, 331 | num_blocks: int = 6, 332 | dropout_rate: float = 0.1, 333 | positional_dropout_rate: float = 0.1, 334 | attention_dropout_rate: float = 0.0, 335 | input_layer: str = "conv2d", 336 | pos_enc_layer_type: str = "abs_pos", 337 | normalize_before: bool = True, 338 | concat_after: bool = False, 339 | static_chunk_size: int = 0, 340 | use_dynamic_chunk: bool = False, 341 | global_cmvn: torch.nn.Module = None, 342 | use_dynamic_left_chunk: bool = False, 343 | ): 344 | """ Construct TransformerEncoder 345 | 346 | See Encoder for the meaning of each parameter. 
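        A usage sketch (shapes and hyper-parameters are illustrative; they
        assume the default conv2d front-end, which subsamples time by 4, and
        the feature axis is the concatenation of all block outputs):

            encoder = TransformerEncoder(input_size=80, output_size=256,
                                         num_blocks=6)
            xs = torch.randn(4, 200, 80)                  # (B, T, D) features
            xs_lens = torch.tensor([200, 200, 180, 160])  # valid lengths
            ys, masks = encoder(xs, xs_lens)              # ys: (4, 49, 1536)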
347 | """ 348 | assert check_argument_types() 349 | super().__init__(input_size, output_size, attention_heads, 350 | linear_units, num_blocks, dropout_rate, 351 | positional_dropout_rate, attention_dropout_rate, 352 | input_layer, pos_enc_layer_type, normalize_before, 353 | concat_after, static_chunk_size, use_dynamic_chunk, 354 | global_cmvn, use_dynamic_left_chunk) 355 | self.encoders = torch.nn.ModuleList([ 356 | TransformerEncoderLayer( 357 | output_size, 358 | MultiHeadedAttention(attention_heads, output_size, 359 | attention_dropout_rate), 360 | PositionwiseFeedForward(output_size, linear_units, 361 | dropout_rate), dropout_rate, 362 | normalize_before, concat_after) for _ in range(num_blocks) 363 | ]) 364 | 365 | 366 | class ConformerEncoder(BaseEncoder): 367 | """Conformer encoder module.""" 368 | def __init__( 369 | self, 370 | input_size: int, 371 | output_size: int = 256, 372 | attention_heads: int = 4, 373 | linear_units: int = 2048, 374 | num_blocks: int = 6, 375 | dropout_rate: float = 0.1, 376 | positional_dropout_rate: float = 0.1, 377 | attention_dropout_rate: float = 0.0, 378 | input_layer: str = "conv2d", 379 | pos_enc_layer_type: str = "rel_pos", 380 | normalize_before: bool = True, 381 | concat_after: bool = False, 382 | static_chunk_size: int = 0, 383 | use_dynamic_chunk: bool = False, 384 | global_cmvn: torch.nn.Module = None, 385 | use_dynamic_left_chunk: bool = False, 386 | positionwise_conv_kernel_size: int = 1, 387 | macaron_style: bool = True, 388 | selfattention_layer_type: str = "rel_selfattn", 389 | activation_type: str = "swish", 390 | use_cnn_module: bool = True, 391 | cnn_module_kernel: int = 15, 392 | causal: bool = False, 393 | cnn_module_norm: str = "batch_norm", 394 | ): 395 | """Construct ConformerEncoder 396 | 397 | Args: 398 | input_size to use_dynamic_chunk, see in BaseEncoder 399 | positionwise_conv_kernel_size (int): Kernel size of positionwise 400 | conv1d layer. 401 | macaron_style (bool): Whether to use macaron style for 402 | positionwise layer. 403 | selfattention_layer_type (str): Encoder attention layer type, 404 | the parameter has no effect now, it's just for configure 405 | compatibility. 406 | activation_type (str): Encoder activation function type. 407 | use_cnn_module (bool): Whether to use convolution module. 408 | cnn_module_kernel (int): Kernel size of convolution module. 409 | causal (bool): whether to use causal convolution or not. 
410 | """ 411 | assert check_argument_types() 412 | super().__init__(input_size, output_size, attention_heads, 413 | linear_units, num_blocks, dropout_rate, 414 | positional_dropout_rate, attention_dropout_rate, 415 | input_layer, pos_enc_layer_type, normalize_before, 416 | concat_after, static_chunk_size, use_dynamic_chunk, 417 | global_cmvn, use_dynamic_left_chunk) 418 | activation = get_activation(activation_type) 419 | 420 | # self-attention module definition 421 | if pos_enc_layer_type == "no_pos": 422 | encoder_selfattn_layer = MultiHeadedAttention 423 | else: 424 | encoder_selfattn_layer = RelPositionMultiHeadedAttention 425 | encoder_selfattn_layer_args = ( 426 | attention_heads, 427 | output_size, 428 | attention_dropout_rate, 429 | ) 430 | # feed-forward module definition 431 | positionwise_layer = PositionwiseFeedForward 432 | positionwise_layer_args = ( 433 | output_size, 434 | linear_units, 435 | dropout_rate, 436 | activation, 437 | ) 438 | # convolution module definition 439 | convolution_layer = ConvolutionModule 440 | convolution_layer_args = (output_size, cnn_module_kernel, activation, 441 | cnn_module_norm, causal) 442 | 443 | self.encoders = torch.nn.ModuleList([ 444 | ConformerEncoderLayer( 445 | output_size, 446 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 447 | positionwise_layer(*positionwise_layer_args), 448 | positionwise_layer( 449 | *positionwise_layer_args) if macaron_style else None, 450 | convolution_layer( 451 | *convolution_layer_args) if use_cnn_module else None, 452 | dropout_rate, 453 | normalize_before, 454 | concat_after, 455 | ) for _ in range(num_blocks) 456 | ]) 457 | -------------------------------------------------------------------------------- /wenet/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Encoder self-attention layer definition.""" 7 | 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class TransformerEncoderLayer(nn.Module): 15 | """Encoder layer module. 16 | 17 | Args: 18 | size (int): Input dimension. 19 | self_attn (torch.nn.Module): Self-attention module instance. 20 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 21 | instance can be used as the argument. 22 | feed_forward (torch.nn.Module): Feed-forward module instance. 23 | `PositionwiseFeedForward`, instance can be used as the argument. 24 | dropout_rate (float): Dropout rate. 25 | normalize_before (bool): 26 | True: use layer_norm before each sub-block. 27 | False: to use layer_norm after each sub-block. 28 | concat_after (bool): Whether to concat attention layer's input and 29 | output. 
30 | True: x -> x + linear(concat(x, att(x))) 31 | False: x -> x + att(x) 32 | 33 | """ 34 | def __init__( 35 | self, 36 | size: int, 37 | self_attn: torch.nn.Module, 38 | feed_forward: torch.nn.Module, 39 | dropout_rate: float, 40 | normalize_before: bool = True, 41 | concat_after: bool = False, 42 | ): 43 | """Construct an EncoderLayer object.""" 44 | super().__init__() 45 | self.self_attn = self_attn 46 | self.feed_forward = feed_forward 47 | self.norm1 = nn.LayerNorm(size, eps=1e-12) 48 | self.norm2 = nn.LayerNorm(size, eps=1e-12) 49 | self.dropout = nn.Dropout(dropout_rate) 50 | self.size = size 51 | self.normalize_before = normalize_before 52 | self.concat_after = concat_after 53 | 54 | def forward( 55 | self, 56 | x: torch.Tensor, 57 | mask: torch.Tensor, 58 | pos_emb: torch.Tensor, 59 | mask_pad: Optional[torch.Tensor] = None, 60 | output_cache: Optional[torch.Tensor] = None, 61 | cnn_cache: Optional[torch.Tensor] = None, 62 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 63 | """Compute encoded features. 64 | 65 | Args: 66 | x (torch.Tensor): Input tensor (#batch, time, size). 67 | mask (torch.Tensor): Mask tensor for the input (#batch, time). 68 | pos_emb (torch.Tensor): just for interface compatibility 69 | to ConformerEncoderLayer 70 | mask_pad (torch.Tensor): does not used in transformer layer, 71 | just for unified api with conformer. 72 | output_cache (torch.Tensor): Cache tensor of the output 73 | (#batch, time2, size), time2 < time in x. 74 | cnn_cache (torch.Tensor): not used here, it's for interface 75 | compatibility to ConformerEncoderLayer 76 | Returns: 77 | torch.Tensor: Output tensor (#batch, time, size). 78 | torch.Tensor: Mask tensor (#batch, time). 79 | 80 | """ 81 | residual = x 82 | if self.normalize_before: 83 | x = self.norm1(x) 84 | 85 | if output_cache is None: 86 | x_q = x 87 | else: 88 | assert output_cache.size(0) == x.size(0) 89 | assert output_cache.size(2) == self.size 90 | assert output_cache.size(1) < x.size(1) 91 | chunk = x.size(1) - output_cache.size(1) 92 | x_q = x[:, -chunk:, :] 93 | residual = residual[:, -chunk:, :] 94 | mask = mask[:, -chunk:, :] 95 | 96 | x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) 97 | if not self.normalize_before: 98 | x = self.norm1(x) 99 | 100 | residual = x 101 | if self.normalize_before: 102 | x = self.norm2(x) 103 | x = residual + self.dropout(self.feed_forward(x)) 104 | if not self.normalize_before: 105 | x = self.norm2(x) 106 | 107 | if output_cache is not None: 108 | x = torch.cat([output_cache, x], dim=1) 109 | 110 | fake_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 111 | return x, mask, fake_cnn_cache 112 | 113 | 114 | class ConformerEncoderLayer(nn.Module): 115 | """Encoder layer module. 116 | Args: 117 | size (int): Input dimension. 118 | self_attn (torch.nn.Module): Self-attention module instance. 119 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 120 | instance can be used as the argument. 121 | feed_forward (torch.nn.Module): Feed-forward module instance. 122 | `PositionwiseFeedForward` instance can be used as the argument. 123 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module 124 | instance. 125 | `PositionwiseFeedForward` instance can be used as the argument. 126 | conv_module (torch.nn.Module): Convolution module instance. 127 | `ConvlutionModule` instance can be used as the argument. 128 | dropout_rate (float): Dropout rate. 129 | normalize_before (bool): 130 | True: use layer_norm before each sub-block. 
131 | False: use layer_norm after each sub-block. 132 | concat_after (bool): Whether to concat attention layer's input and 133 | output. 134 | True: x -> x + linear(concat(x, att(x))) 135 | False: x -> x + att(x) 136 | """ 137 | def __init__( 138 | self, 139 | size: int, 140 | self_attn: torch.nn.Module, 141 | feed_forward: Optional[nn.Module] = None, 142 | feed_forward_macaron: Optional[nn.Module] = None, 143 | conv_module: Optional[nn.Module] = None, 144 | dropout_rate: float = 0.1, 145 | normalize_before: bool = True, 146 | concat_after: bool = False, 147 | ): 148 | """Construct an EncoderLayer object.""" 149 | super().__init__() 150 | self.self_attn = self_attn 151 | self.feed_forward = feed_forward 152 | self.feed_forward_macaron = feed_forward_macaron 153 | self.conv_module = conv_module 154 | self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module 155 | self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module 156 | if feed_forward_macaron is not None: 157 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) 158 | self.ff_scale = 0.5 159 | else: 160 | self.ff_scale = 1.0 161 | if self.conv_module is not None: 162 | self.norm_conv = nn.LayerNorm(size, 163 | eps=1e-12) # for the CNN module 164 | self.norm_final = nn.LayerNorm( 165 | size, eps=1e-12) # for the final output of the block 166 | self.dropout = nn.Dropout(dropout_rate) 167 | self.size = size 168 | self.normalize_before = normalize_before 169 | self.concat_after = concat_after 170 | 171 | def forward( 172 | self, 173 | x: torch.Tensor, 174 | mask: torch.Tensor, 175 | pos_emb: torch.Tensor, 176 | mask_pad: Optional[torch.Tensor] = None, 177 | output_cache: Optional[torch.Tensor] = None, 178 | cnn_cache: Optional[torch.Tensor] = None, 179 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 180 | """Compute encoded features. 181 | 182 | Args: 183 | x (torch.Tensor): (#batch, time, size) 184 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time). 185 | pos_emb (torch.Tensor): positional encoding, must not be None 186 | for ConformerEncoderLayer. 187 | mask_pad (torch.Tensor): batch padding mask used for conv module. 188 | (#batch, 1,time) 189 | output_cache (torch.Tensor): Cache tensor of the output 190 | (#batch, time2, size), time2 < time in x. 191 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 192 | Returns: 193 | torch.Tensor: Output tensor (#batch, time, size). 194 | torch.Tensor: Mask tensor (#batch, time). 
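            In summary, with normalize_before=True the body below applies, as
            residual branches with pre-LayerNorm and dropout: the macaron
            feed-forward (scaled by ff_scale = 0.5 when present), multi-headed
            self-attention, the convolution module (when present), the second
            feed-forward (also scaled), and finally norm_final when a
            convolution module is used.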
195 | """ 196 | 197 | # whether to use macaron style 198 | if self.feed_forward_macaron is not None: 199 | residual = x 200 | if self.normalize_before: 201 | x = self.norm_ff_macaron(x) 202 | x = residual + self.ff_scale * self.dropout( 203 | self.feed_forward_macaron(x)) 204 | if not self.normalize_before: 205 | x = self.norm_ff_macaron(x) 206 | 207 | # multi-headed self-attention module 208 | residual = x 209 | if self.normalize_before: 210 | x = self.norm_mha(x) 211 | 212 | if output_cache is None: 213 | x_q = x 214 | else: 215 | assert output_cache.size(0) == x.size(0) 216 | assert output_cache.size(2) == self.size 217 | assert output_cache.size(1) < x.size(1) 218 | chunk = x.size(1) - output_cache.size(1) 219 | x_q = x[:, -chunk:, :] 220 | residual = residual[:, -chunk:, :] 221 | mask = mask[:, -chunk:, :] 222 | 223 | x_att = self.self_attn(x_q, x, x, mask, pos_emb) 224 | x = residual + self.dropout(x_att) 225 | if not self.normalize_before: 226 | x = self.norm_mha(x) 227 | 228 | # convolution module 229 | # Fake new cnn cache here, and then change it in conv_module 230 | new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 231 | if self.conv_module is not None: 232 | residual = x 233 | if self.normalize_before: 234 | x = self.norm_conv(x) 235 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 236 | x = residual + self.dropout(x) 237 | 238 | if not self.normalize_before: 239 | x = self.norm_conv(x) 240 | 241 | # feed forward module 242 | residual = x 243 | if self.normalize_before: 244 | x = self.norm_ff(x) 245 | 246 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 247 | if not self.normalize_before: 248 | x = self.norm_ff(x) 249 | 250 | if self.conv_module is not None: 251 | x = self.norm_final(x) 252 | 253 | if output_cache is not None: 254 | x = torch.cat([output_cache, x], dim=1) 255 | 256 | return x, mask, new_cnn_cache 257 | -------------------------------------------------------------------------------- /wenet/transformer/encoder_weight.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 
5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Encoder definition.""" 7 | from typing import Tuple, List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from typeguard import check_argument_types 13 | 14 | from wenet.transformer.attention import MultiHeadedAttention 15 | from wenet.transformer.attention import RelPositionMultiHeadedAttention 16 | from wenet.transformer.convolution import ConvolutionModule 17 | from wenet.transformer.embedding import PositionalEncoding 18 | from wenet.transformer.embedding import RelPositionalEncoding 19 | from wenet.transformer.embedding import NoPositionalEncoding 20 | from wenet.transformer.encoder_layer import TransformerEncoderLayer 21 | from wenet.transformer.encoder_layer import ConformerEncoderLayer 22 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward 23 | from wenet.transformer.subsampling import Conv2dSubsampling2 24 | from wenet.transformer.subsampling import Conv2dSubsampling4 25 | from wenet.transformer.subsampling import Conv2dSubsampling6 26 | from wenet.transformer.subsampling import Conv2dSubsampling8 27 | from wenet.transformer.subsampling import LinearNoSubsampling 28 | from wenet.utils.common import get_activation 29 | from wenet.utils.mask import make_pad_mask 30 | from wenet.utils.mask import add_optional_chunk_mask 31 | 32 | 33 | class BaseEncoder(torch.nn.Module): 34 | def __init__( 35 | self, 36 | input_size: int, 37 | output_size: int = 256, 38 | attention_heads: int = 4, 39 | linear_units: int = 2048, 40 | num_blocks: int = 6, 41 | dropout_rate: float = 0.1, 42 | positional_dropout_rate: float = 0.1, 43 | attention_dropout_rate: float = 0.0, 44 | input_layer: str = "conv2d", 45 | pos_enc_layer_type: str = "abs_pos", 46 | normalize_before: bool = True, 47 | concat_after: bool = False, 48 | static_chunk_size: int = 0, 49 | use_dynamic_chunk: bool = False, 50 | global_cmvn: torch.nn.Module = None, 51 | use_dynamic_left_chunk: bool = False, 52 | ): 53 | """ 54 | Args: 55 | input_size (int): input dim 56 | output_size (int): dimension of attention 57 | attention_heads (int): the number of heads of multi head attention 58 | linear_units (int): the hidden units number of position-wise feed 59 | forward 60 | num_blocks (int): the number of decoder blocks 61 | dropout_rate (float): dropout rate 62 | attention_dropout_rate (float): dropout rate in attention 63 | positional_dropout_rate (float): dropout rate after adding 64 | positional encoding 65 | input_layer (str): input layer type. 66 | optional [linear, conv2d, conv2d6, conv2d8] 67 | pos_enc_layer_type (str): Encoder positional encoding layer type. 68 | opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] 69 | normalize_before (bool): 70 | True: use layer_norm before each sub-block of a layer. 71 | False: use layer_norm after each sub-block of a layer. 72 | concat_after (bool): whether to concat attention layer's input 73 | and output. 
74 | True: x -> x + linear(concat(x, att(x))) 75 | False: x -> x + att(x) 76 | static_chunk_size (int): chunk size for static chunk training and 77 | decoding 78 | use_dynamic_chunk (bool): whether use dynamic chunk size for 79 | training or not, You can only use fixed chunk(chunk_size > 0) 80 | or dyanmic chunk size(use_dynamic_chunk = True) 81 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module 82 | use_dynamic_left_chunk (bool): whether use dynamic left chunk in 83 | dynamic chunk training 84 | """ 85 | assert check_argument_types() 86 | super().__init__() 87 | self._output_size = output_size 88 | 89 | if pos_enc_layer_type == "abs_pos": 90 | pos_enc_class = PositionalEncoding 91 | elif pos_enc_layer_type == "rel_pos": 92 | pos_enc_class = RelPositionalEncoding 93 | elif pos_enc_layer_type == "no_pos": 94 | pos_enc_class = NoPositionalEncoding 95 | else: 96 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) 97 | 98 | if input_layer == "linear": 99 | subsampling_class = LinearNoSubsampling 100 | elif input_layer == "conv2d": 101 | subsampling_class = Conv2dSubsampling4 102 | elif input_layer == "conv2d6": 103 | subsampling_class = Conv2dSubsampling6 104 | elif input_layer == "conv2d8": 105 | subsampling_class = Conv2dSubsampling8 106 | elif input_layer == "conv2d2": 107 | subsampling_class = Conv2dSubsampling2 108 | else: 109 | raise ValueError("unknown input_layer: " + input_layer) 110 | 111 | self.global_cmvn = global_cmvn 112 | self.embed = subsampling_class( 113 | input_size, 114 | output_size, 115 | dropout_rate, 116 | pos_enc_class(output_size, positional_dropout_rate), 117 | ) 118 | 119 | self.normalize_before = normalize_before 120 | self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12) 121 | self.static_chunk_size = static_chunk_size 122 | self.use_dynamic_chunk = use_dynamic_chunk 123 | self.use_dynamic_left_chunk = use_dynamic_left_chunk 124 | self.num_blocks = num_blocks 125 | self.feature_weight = nn.Parameter(torch.ones(self.num_blocks)) 126 | 127 | def output_size(self) -> int: 128 | return self._output_size 129 | 130 | def forward( 131 | self, 132 | xs: torch.Tensor, 133 | xs_lens: torch.Tensor, 134 | decoding_chunk_size: int = 0, 135 | num_decoding_left_chunks: int = -1, 136 | ) -> Tuple[torch.Tensor, torch.Tensor]: 137 | """Embed positions in tensor. 138 | 139 | Args: 140 | xs: padded input tensor (B, T, D) 141 | xs_lens: input length (B) 142 | decoding_chunk_size: decoding chunk size for dynamic chunk 143 | 0: default for training, use random dynamic chunk. 144 | <0: for decoding, use full chunk. 145 | >0: for decoding, use fixed chunk size as set. 146 | num_decoding_left_chunks: number of left chunks, this is for decoding, 147 | the chunk size is decoding_chunk_size. 
148 | >=0: use num_decoding_left_chunks 149 | <0: use all left chunks 150 | Returns: 151 | encoder output tensor xs, and subsampled masks 152 | xs: padded output tensor (B, T' ~= T/subsample_rate, D) 153 | masks: torch.Tensor batch padding mask after subsample 154 | (B, 1, T' ~= T/subsample_rate) 155 | """ 156 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T) 157 | if self.global_cmvn is not None: 158 | xs = self.global_cmvn(xs) 159 | xs, pos_emb, masks = self.embed(xs, masks) 160 | mask_pad = masks # (B, 1, T/subsample_rate) 161 | chunk_masks = add_optional_chunk_mask(xs, masks, 162 | self.use_dynamic_chunk, 163 | self.use_dynamic_left_chunk, 164 | decoding_chunk_size, 165 | self.static_chunk_size, 166 | num_decoding_left_chunks) 167 | out = [] 168 | for layer in self.encoders: 169 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad) 170 | out.append(xs) 171 | xs = torch.cat(out, dim=-1) 172 | xs = xs.reshape(xs.shape[0], xs.shape[1], self.output_size(), self.num_blocks) 173 | norm_weights = F.softmax(self.feature_weight, dim=-1) 174 | xs = xs.matmul(norm_weights) 175 | 176 | if self.normalize_before: 177 | xs = self.after_norm(xs) 178 | # Here we assume the mask is not changed in encoder layers, so just 179 | # return the masks before encoder layers, and the masks will be used 180 | # for cross attention with decoder later 181 | return xs, masks 182 | 183 | def forward_chunk( 184 | self, 185 | xs: torch.Tensor, 186 | offset: int, 187 | required_cache_size: int, 188 | subsampling_cache: Optional[torch.Tensor] = None, 189 | elayers_output_cache: Optional[List[torch.Tensor]] = None, 190 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None, 191 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], 192 | List[torch.Tensor]]: 193 | """ Forward just one chunk 194 | 195 | Args: 196 | xs (torch.Tensor): chunk input 197 | offset (int): current offset in encoder output time stamp 198 | required_cache_size (int): cache size required for next chunk 199 | compuation 200 | >=0: actual cache size 201 | <0: means all history cache is required 202 | subsampling_cache (Optional[torch.Tensor]): subsampling cache 203 | elayers_output_cache (Optional[List[torch.Tensor]]): 204 | transformer/conformer encoder layers output cache 205 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer 206 | cnn cache 207 | 208 | Returns: 209 | torch.Tensor: output of current input xs 210 | torch.Tensor: subsampling cache required for next chunk computation 211 | List[torch.Tensor]: encoder layers output cache required for next 212 | chunk computation 213 | List[torch.Tensor]: conformer cnn cache 214 | 215 | """ 216 | assert xs.size(0) == 1 217 | # tmp_masks is just for interface compatibility 218 | tmp_masks = torch.ones(1, 219 | xs.size(1), 220 | device=xs.device, 221 | dtype=torch.bool) 222 | tmp_masks = tmp_masks.unsqueeze(1) 223 | if self.global_cmvn is not None: 224 | xs = self.global_cmvn(xs) 225 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) 226 | if subsampling_cache is not None: 227 | cache_size = subsampling_cache.size(1) 228 | xs = torch.cat((subsampling_cache, xs), dim=1) 229 | else: 230 | cache_size = 0 231 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1)) 232 | if required_cache_size < 0: 233 | next_cache_start = 0 234 | elif required_cache_size == 0: 235 | next_cache_start = xs.size(1) 236 | else: 237 | next_cache_start = max(xs.size(1) - required_cache_size, 0) 238 | r_subsampling_cache = xs[:, next_cache_start:, :] 239 | # Real mask for 
transformer/conformer layers 240 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) 241 | masks = masks.unsqueeze(1) 242 | r_elayers_output_cache = [] 243 | r_conformer_cnn_cache = [] 244 | for i, layer in enumerate(self.encoders): 245 | if elayers_output_cache is None: 246 | attn_cache = None 247 | else: 248 | attn_cache = elayers_output_cache[i] 249 | if conformer_cnn_cache is None: 250 | cnn_cache = None 251 | else: 252 | cnn_cache = conformer_cnn_cache[i] 253 | xs, _, new_cnn_cache = layer(xs, 254 | masks, 255 | pos_emb, 256 | output_cache=attn_cache, 257 | cnn_cache=cnn_cache) 258 | r_elayers_output_cache.append(xs[:, next_cache_start:, :]) 259 | r_conformer_cnn_cache.append(new_cnn_cache) 260 | if self.normalize_before: 261 | xs = self.after_norm(xs) 262 | 263 | return (xs[:, cache_size:, :], r_subsampling_cache, 264 | r_elayers_output_cache, r_conformer_cnn_cache) 265 | 266 | def forward_chunk_by_chunk( 267 | self, 268 | xs: torch.Tensor, 269 | decoding_chunk_size: int, 270 | num_decoding_left_chunks: int = -1, 271 | ) -> Tuple[torch.Tensor, torch.Tensor]: 272 | """ Forward input chunk by chunk with chunk_size like a streaming 273 | fashion 274 | 275 | Here we should pay special attention to computation cache in the 276 | streaming style forward chunk by chunk. Three things should be taken 277 | into account for computation in the current network: 278 | 1. transformer/conformer encoder layers output cache 279 | 2. convolution in conformer 280 | 3. convolution in subsampling 281 | 282 | However, we don't implement subsampling cache for: 283 | 1. We can control subsampling module to output the right result by 284 | overlapping input instead of cache left context, even though it 285 | wastes some computation, but subsampling only takes a very 286 | small fraction of computation in the whole model. 287 | 2. Typically, there are several covolution layers with subsampling 288 | in subsampling module, it is tricky and complicated to do cache 289 | with different convolution layers with different subsampling 290 | rate. 291 | 3. Currently, nn.Sequential is used to stack all the convolution 292 | layers in subsampling, we need to rewrite it to make it work 293 | with cache, which is not prefered. 
294 | Args: 295 | xs (torch.Tensor): (1, max_len, dim) 296 | chunk_size (int): decoding chunk size 297 | """ 298 | assert decoding_chunk_size > 0 299 | # The model is trained by static or dynamic chunk 300 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk 301 | subsampling = self.embed.subsampling_rate 302 | context = self.embed.right_context + 1 # Add current frame 303 | stride = subsampling * decoding_chunk_size 304 | decoding_window = (decoding_chunk_size - 1) * subsampling + context 305 | num_frames = xs.size(1) 306 | subsampling_cache: Optional[torch.Tensor] = None 307 | elayers_output_cache: Optional[List[torch.Tensor]] = None 308 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None 309 | outputs = [] 310 | offset = 0 311 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks 312 | 313 | # Feed forward overlap input step by step 314 | for cur in range(0, num_frames - context + 1, stride): 315 | end = min(cur + decoding_window, num_frames) 316 | chunk_xs = xs[:, cur:end, :] 317 | (y, subsampling_cache, elayers_output_cache, 318 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset, 319 | required_cache_size, 320 | subsampling_cache, 321 | elayers_output_cache, 322 | conformer_cnn_cache) 323 | outputs.append(y) 324 | offset += y.size(1) 325 | ys = torch.cat(outputs, 1) 326 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool) 327 | masks = masks.unsqueeze(1) 328 | return ys, masks 329 | 330 | 331 | class TransformerEncoder(BaseEncoder): 332 | """Transformer encoder module.""" 333 | def __init__( 334 | self, 335 | input_size: int, 336 | output_size: int = 256, 337 | attention_heads: int = 4, 338 | linear_units: int = 2048, 339 | num_blocks: int = 6, 340 | dropout_rate: float = 0.1, 341 | positional_dropout_rate: float = 0.1, 342 | attention_dropout_rate: float = 0.0, 343 | input_layer: str = "conv2d", 344 | pos_enc_layer_type: str = "abs_pos", 345 | normalize_before: bool = True, 346 | concat_after: bool = False, 347 | static_chunk_size: int = 0, 348 | use_dynamic_chunk: bool = False, 349 | global_cmvn: torch.nn.Module = None, 350 | use_dynamic_left_chunk: bool = False, 351 | ): 352 | """ Construct TransformerEncoder 353 | 354 | See Encoder for the meaning of each parameter. 
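        A usage sketch (illustrative shapes; unlike the encoder_cat variant,
        the block outputs are merged by learned weights, so the feature size
        stays at output_size):

            encoder = TransformerEncoder(input_size=80, output_size=256,
                                         num_blocks=6)
            xs = torch.randn(4, 200, 80)
            xs_lens = torch.tensor([200, 200, 180, 160])
            ys, masks = encoder(xs, xs_lens)              # ys: (4, 49, 256)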
355 | """ 356 | assert check_argument_types() 357 | super().__init__(input_size, output_size, attention_heads, 358 | linear_units, num_blocks, dropout_rate, 359 | positional_dropout_rate, attention_dropout_rate, 360 | input_layer, pos_enc_layer_type, normalize_before, 361 | concat_after, static_chunk_size, use_dynamic_chunk, 362 | global_cmvn, use_dynamic_left_chunk) 363 | self.encoders = torch.nn.ModuleList([ 364 | TransformerEncoderLayer( 365 | output_size, 366 | MultiHeadedAttention(attention_heads, output_size, 367 | attention_dropout_rate), 368 | PositionwiseFeedForward(output_size, linear_units, 369 | dropout_rate), dropout_rate, 370 | normalize_before, concat_after) for _ in range(num_blocks) 371 | ]) 372 | 373 | 374 | class ConformerEncoder(BaseEncoder): 375 | """Conformer encoder module.""" 376 | def __init__( 377 | self, 378 | input_size: int, 379 | output_size: int = 256, 380 | attention_heads: int = 4, 381 | linear_units: int = 2048, 382 | num_blocks: int = 6, 383 | dropout_rate: float = 0.1, 384 | positional_dropout_rate: float = 0.1, 385 | attention_dropout_rate: float = 0.0, 386 | input_layer: str = "conv2d", 387 | pos_enc_layer_type: str = "rel_pos", 388 | normalize_before: bool = True, 389 | concat_after: bool = False, 390 | static_chunk_size: int = 0, 391 | use_dynamic_chunk: bool = False, 392 | global_cmvn: torch.nn.Module = None, 393 | use_dynamic_left_chunk: bool = False, 394 | positionwise_conv_kernel_size: int = 1, 395 | macaron_style: bool = True, 396 | selfattention_layer_type: str = "rel_selfattn", 397 | activation_type: str = "swish", 398 | use_cnn_module: bool = True, 399 | cnn_module_kernel: int = 15, 400 | causal: bool = False, 401 | cnn_module_norm: str = "batch_norm", 402 | ): 403 | """Construct ConformerEncoder 404 | 405 | Args: 406 | input_size to use_dynamic_chunk, see in BaseEncoder 407 | positionwise_conv_kernel_size (int): Kernel size of positionwise 408 | conv1d layer. 409 | macaron_style (bool): Whether to use macaron style for 410 | positionwise layer. 411 | selfattention_layer_type (str): Encoder attention layer type, 412 | the parameter has no effect now, it's just for configure 413 | compatibility. 414 | activation_type (str): Encoder activation function type. 415 | use_cnn_module (bool): Whether to use convolution module. 416 | cnn_module_kernel (int): Kernel size of convolution module. 417 | causal (bool): whether to use causal convolution or not. 
418 | """ 419 | assert check_argument_types() 420 | super().__init__(input_size, output_size, attention_heads, 421 | linear_units, num_blocks, dropout_rate, 422 | positional_dropout_rate, attention_dropout_rate, 423 | input_layer, pos_enc_layer_type, normalize_before, 424 | concat_after, static_chunk_size, use_dynamic_chunk, 425 | global_cmvn, use_dynamic_left_chunk) 426 | activation = get_activation(activation_type) 427 | 428 | # self-attention module definition 429 | if pos_enc_layer_type == "no_pos": 430 | encoder_selfattn_layer = MultiHeadedAttention 431 | else: 432 | encoder_selfattn_layer = RelPositionMultiHeadedAttention 433 | encoder_selfattn_layer_args = ( 434 | attention_heads, 435 | output_size, 436 | attention_dropout_rate, 437 | ) 438 | # feed-forward module definition 439 | positionwise_layer = PositionwiseFeedForward 440 | positionwise_layer_args = ( 441 | output_size, 442 | linear_units, 443 | dropout_rate, 444 | activation, 445 | ) 446 | # convolution module definition 447 | convolution_layer = ConvolutionModule 448 | convolution_layer_args = (output_size, cnn_module_kernel, activation, 449 | cnn_module_norm, causal) 450 | 451 | self.encoders = torch.nn.ModuleList([ 452 | ConformerEncoderLayer( 453 | output_size, 454 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 455 | positionwise_layer(*positionwise_layer_args), 456 | positionwise_layer( 457 | *positionwise_layer_args) if macaron_style else None, 458 | convolution_layer( 459 | *convolution_layer_args) if use_cnn_module else None, 460 | dropout_rate, 461 | normalize_before, 462 | concat_after, 463 | ) for _ in range(num_blocks) 464 | ]) 465 | -------------------------------------------------------------------------------- /wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Label smoothing module.""" 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class LabelSmoothingLoss(nn.Module): 13 | """Label-smoothing loss. 14 | 15 | In a standard CE loss, the label's data distribution is: 16 | [0,1,2] -> 17 | [ 18 | [1.0, 0.0, 0.0], 19 | [0.0, 1.0, 0.0], 20 | [1.0, 0.0, 1.0], 21 | ] 22 | 23 | In the smoothing version CE Loss,some probabilities 24 | are taken from the true label prob (1.0) and are divided 25 | among other labels. 26 | 27 | e.g. 28 | smoothing=0.1 29 | [0,1,2] -> 30 | [ 31 | [0.9, 0.05, 0.05], 32 | [0.05, 0.9, 0.05], 33 | [0.05, 0.05, 0.9], 34 | ] 35 | 36 | Args: 37 | size (int): the number of class 38 | padding_idx (int): padding class id which will be ignored for loss 39 | smoothing (float): smoothing rate (0.0 means the conventional CE) 40 | normalize_length (bool): 41 | normalize loss by sequence length if True 42 | normalize loss by batch size if False 43 | """ 44 | def __init__(self, 45 | size: int, 46 | padding_idx: int, 47 | smoothing: float, 48 | normalize_length: bool = False): 49 | """Construct an LabelSmoothingLoss object.""" 50 | super(LabelSmoothingLoss, self).__init__() 51 | self.criterion = nn.KLDivLoss(reduction="none") 52 | self.padding_idx = padding_idx 53 | self.confidence = 1.0 - smoothing 54 | self.smoothing = smoothing 55 | self.size = size 56 | self.normalize_length = normalize_length 57 | 58 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 59 | """Compute loss between x and target. 
60 | 61 | The model outputs and data labels tensors are flatten to 62 | (batch*seqlen, class) shape and a mask is applied to the 63 | padding part which should not be calculated for loss. 64 | 65 | Args: 66 | x (torch.Tensor): prediction (batch, seqlen, class) 67 | target (torch.Tensor): 68 | target signal masked with self.padding_id (batch, seqlen) 69 | Returns: 70 | loss (torch.Tensor) : The KL loss, scalar float value 71 | """ 72 | assert x.size(2) == self.size 73 | batch_size = x.size(0) 74 | x = x.view(-1, self.size) 75 | target = target.view(-1) 76 | # use zeros_like instead of torch.no_grad() for true_dist, 77 | # since no_grad() can not be exported by JIT 78 | true_dist = torch.zeros_like(x) 79 | true_dist.fill_(self.smoothing / (self.size - 1)) 80 | ignore = target == self.padding_idx # (B,) 81 | total = len(target) - ignore.sum().item() 82 | target = target.masked_fill(ignore, 0) # avoid -1 index 83 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 84 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 85 | denom = total if self.normalize_length else batch_size 86 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 87 | -------------------------------------------------------------------------------- /wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Positionwise feed forward layer definition.""" 7 | 8 | import torch 9 | 10 | 11 | class PositionwiseFeedForward(torch.nn.Module): 12 | """Positionwise feed forward layer. 13 | 14 | FeedForward are appied on each position of the sequence. 15 | The output dim is same with the input dim. 16 | 17 | Args: 18 | idim (int): Input dimenstion. 19 | hidden_units (int): The number of hidden units. 20 | dropout_rate (float): Dropout rate. 21 | activation (torch.nn.Module): Activation function 22 | """ 23 | def __init__(self, 24 | idim: int, 25 | hidden_units: int, 26 | dropout_rate: float, 27 | activation: torch.nn.Module = torch.nn.ReLU()): 28 | """Construct a PositionwiseFeedForward object.""" 29 | super(PositionwiseFeedForward, self).__init__() 30 | self.w_1 = torch.nn.Linear(idim, hidden_units) 31 | self.activation = activation 32 | self.dropout = torch.nn.Dropout(dropout_rate) 33 | self.w_2 = torch.nn.Linear(hidden_units, idim) 34 | 35 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 36 | """Forward function. 37 | 38 | Args: 39 | xs: input tensor (B, L, D) 40 | Returns: 41 | output tensor, (B, L, D) 42 | """ 43 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 44 | -------------------------------------------------------------------------------- /wenet/transformer/subsampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 
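# A worked frame-count example for the default Conv2dSubsampling4 below
# (illustrative T = 200 input frames): two stride-2 3x3 convolutions give
# T' = ((200 - 1) // 2 - 1) // 2 = 49 output frames, subsampling_rate = 4,
# and right_context = (3 - 1) * 1 + (3 - 1) * 2 = 6 input frames.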
5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Subsampling layer definition.""" 7 | 8 | from typing import Tuple 9 | 10 | import torch 11 | 12 | class BaseSubsampling(torch.nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.right_context = 0 16 | self.subsampling_rate = 1 17 | 18 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 19 | return self.pos_enc.position_encoding(offset, size) 20 | 21 | 22 | class LinearNoSubsampling(BaseSubsampling): 23 | """Linear transform the input without subsampling 24 | 25 | Args: 26 | idim (int): Input dimension. 27 | odim (int): Output dimension. 28 | dropout_rate (float): Dropout rate. 29 | 30 | """ 31 | def __init__(self, idim: int, odim: int, dropout_rate: float, 32 | pos_enc_class: torch.nn.Module): 33 | """Construct an linear object.""" 34 | super().__init__() 35 | self.out = torch.nn.Sequential( 36 | torch.nn.Linear(idim, odim), 37 | torch.nn.LayerNorm(odim, eps=1e-12), 38 | torch.nn.Dropout(dropout_rate), 39 | ) 40 | self.pos_enc = pos_enc_class 41 | self.right_context = 0 42 | self.subsampling_rate = 1 43 | 44 | def forward( 45 | self, 46 | x: torch.Tensor, 47 | x_mask: torch.Tensor, 48 | offset: int = 0 49 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 50 | """Input x. 51 | 52 | Args: 53 | x (torch.Tensor): Input tensor (#batch, time, idim). 54 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 55 | 56 | Returns: 57 | torch.Tensor: linear input tensor (#batch, time', odim), 58 | where time' = time . 59 | torch.Tensor: linear input mask (#batch, 1, time'), 60 | where time' = time . 61 | 62 | """ 63 | x = self.out(x) 64 | x, pos_emb = self.pos_enc(x, offset) 65 | return x, pos_emb, x_mask 66 | 67 | class Conv2dSubsampling2(BaseSubsampling): 68 | """Convolutional 2D subsampling (to 1/2 length). 69 | 70 | Args: 71 | idim (int): Input dimension. 72 | odim (int): Output dimension. 73 | dropout_rate (float): Dropout rate. 74 | 75 | """ 76 | def __init__(self, idim: int, odim: int, dropout_rate: float, 77 | pos_enc_class: torch.nn.Module): 78 | """Construct an Conv2dSubsampling4 object.""" 79 | super().__init__() 80 | self.conv = torch.nn.Sequential( 81 | torch.nn.Conv2d(1, odim, 3, 2), 82 | torch.nn.ReLU(), 83 | ) 84 | self.out = torch.nn.Sequential( 85 | torch.nn.Linear(odim * (idim // 2 - 1), odim)) 86 | self.pos_enc = pos_enc_class 87 | # The right context for every conv layer is computed by: 88 | # (kernel_size - 1) * frame_rate_of_this_layer 89 | self.subsampling_rate = 2 90 | # 2 = (3 - 1) * 1 91 | self.right_context = 2 92 | 93 | def forward( 94 | self, 95 | x: torch.Tensor, 96 | x_mask: torch.Tensor, 97 | offset: int = 0 98 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 99 | """Subsample x. 100 | 101 | Args: 102 | x (torch.Tensor): Input tensor (#batch, time, idim). 103 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 104 | 105 | Returns: 106 | torch.Tensor: Subsampled tensor (#batch, time', odim), 107 | where time' = time // 4. 108 | torch.Tensor: Subsampled mask (#batch, 1, time'), 109 | where time' = time // 4. 110 | torch.Tensor: positional encoding 111 | 112 | """ 113 | x = x.unsqueeze(1) # (b, c=1, t, f) 114 | x = self.conv(x) 115 | b, c, t, f = x.size() 116 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 117 | x, pos_emb = self.pos_enc(x, offset) 118 | return x, pos_emb, x_mask[:, :, :-2:2] 119 | 120 | 121 | class Conv2dSubsampling4(BaseSubsampling): 122 | """Convolutional 2D subsampling (to 1/4 length). 123 | 124 | Args: 125 | idim (int): Input dimension. 
126 | odim (int): Output dimension. 127 | dropout_rate (float): Dropout rate. 128 | 129 | """ 130 | def __init__(self, idim: int, odim: int, dropout_rate: float, 131 | pos_enc_class: torch.nn.Module): 132 | """Construct an Conv2dSubsampling4 object.""" 133 | super().__init__() 134 | self.conv = torch.nn.Sequential( 135 | torch.nn.Conv2d(1, odim, 3, 2), 136 | torch.nn.ReLU(), 137 | torch.nn.Conv2d(odim, odim, 3, 2), 138 | torch.nn.ReLU(), 139 | ) 140 | self.out = torch.nn.Sequential( 141 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) 142 | self.pos_enc = pos_enc_class 143 | # The right context for every conv layer is computed by: 144 | # (kernel_size - 1) * frame_rate_of_this_layer 145 | self.subsampling_rate = 4 146 | # 6 = (3 - 1) * 1 + (3 - 1) * 2 147 | self.right_context = 6 148 | 149 | def forward( 150 | self, 151 | x: torch.Tensor, 152 | x_mask: torch.Tensor, 153 | offset: int = 0 154 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 155 | """Subsample x. 156 | 157 | Args: 158 | x (torch.Tensor): Input tensor (#batch, time, idim). 159 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 160 | 161 | Returns: 162 | torch.Tensor: Subsampled tensor (#batch, time', odim), 163 | where time' = time // 4. 164 | torch.Tensor: Subsampled mask (#batch, 1, time'), 165 | where time' = time // 4. 166 | torch.Tensor: positional encoding 167 | 168 | """ 169 | x = x.unsqueeze(1) # (b, c=1, t, f) 170 | x = self.conv(x) 171 | b, c, t, f = x.size() 172 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 173 | x, pos_emb = self.pos_enc(x, offset) 174 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] 175 | 176 | 177 | class Conv2dSubsampling6(BaseSubsampling): 178 | """Convolutional 2D subsampling (to 1/6 length). 179 | Args: 180 | idim (int): Input dimension. 181 | odim (int): Output dimension. 182 | dropout_rate (float): Dropout rate. 183 | pos_enc (torch.nn.Module): Custom position encoding layer. 184 | """ 185 | def __init__(self, idim: int, odim: int, dropout_rate: float, 186 | pos_enc_class: torch.nn.Module): 187 | """Construct an Conv2dSubsampling6 object.""" 188 | super().__init__() 189 | self.conv = torch.nn.Sequential( 190 | torch.nn.Conv2d(1, odim, 3, 2), 191 | torch.nn.ReLU(), 192 | torch.nn.Conv2d(odim, odim, 5, 3), 193 | torch.nn.ReLU(), 194 | ) 195 | self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), 196 | odim) 197 | self.pos_enc = pos_enc_class 198 | # 10 = (3 - 1) * 1 + (5 - 1) * 2 199 | self.subsampling_rate = 6 200 | self.right_context = 10 201 | 202 | def forward( 203 | self, 204 | x: torch.Tensor, 205 | x_mask: torch.Tensor, 206 | offset: int = 0 207 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 208 | """Subsample x. 209 | Args: 210 | x (torch.Tensor): Input tensor (#batch, time, idim). 211 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 212 | 213 | Returns: 214 | torch.Tensor: Subsampled tensor (#batch, time', odim), 215 | where time' = time // 6. 216 | torch.Tensor: Subsampled mask (#batch, 1, time'), 217 | where time' = time // 6. 218 | torch.Tensor: positional encoding 219 | """ 220 | x = x.unsqueeze(1) # (b, c, t, f) 221 | x = self.conv(x) 222 | b, c, t, f = x.size() 223 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 224 | x, pos_emb = self.pos_enc(x, offset) 225 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] 226 | 227 | 228 | class Conv2dSubsampling8(BaseSubsampling): 229 | """Convolutional 2D subsampling (to 1/8 length). 230 | 231 | Args: 232 | idim (int): Input dimension. 
233 | odim (int): Output dimension. 234 | dropout_rate (float): Dropout rate. 235 | 236 | """ 237 | def __init__(self, idim: int, odim: int, dropout_rate: float, 238 | pos_enc_class: torch.nn.Module): 239 | """Construct an Conv2dSubsampling8 object.""" 240 | super().__init__() 241 | self.conv = torch.nn.Sequential( 242 | torch.nn.Conv2d(1, odim, 3, 2), 243 | torch.nn.ReLU(), 244 | torch.nn.Conv2d(odim, odim, 3, 2), 245 | torch.nn.ReLU(), 246 | torch.nn.Conv2d(odim, odim, 3, 2), 247 | torch.nn.ReLU(), 248 | ) 249 | self.linear = torch.nn.Linear( 250 | odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) 251 | self.pos_enc = pos_enc_class 252 | self.subsampling_rate = 8 253 | # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 254 | self.right_context = 14 255 | 256 | def forward( 257 | self, 258 | x: torch.Tensor, 259 | x_mask: torch.Tensor, 260 | offset: int = 0 261 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 262 | """Subsample x. 263 | 264 | Args: 265 | x (torch.Tensor): Input tensor (#batch, time, idim). 266 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 267 | 268 | Returns: 269 | torch.Tensor: Subsampled tensor (#batch, time', odim), 270 | where time' = time // 8. 271 | torch.Tensor: Subsampled mask (#batch, 1, time'), 272 | where time' = time // 8. 273 | torch.Tensor: positional encoding 274 | """ 275 | x = x.unsqueeze(1) # (b, c, t, f) 276 | x = self.conv(x) 277 | b, c, t, f = x.size() 278 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 279 | x, pos_emb = self.pos_enc(x, offset) 280 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] 281 | 282 | 283 | 284 | -------------------------------------------------------------------------------- /wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | """Swish() activation function for Conformer.""" 8 | 9 | import torch 10 | 11 | 12 | class Swish(torch.nn.Module): 13 | """Construct an Swish object.""" 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | """Return Swish activation function.""" 16 | return x * torch.sigmoid(x) 17 | -------------------------------------------------------------------------------- /wenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 3 | 4 | import logging 5 | import os 6 | import re 7 | 8 | import yaml 9 | import torch 10 | 11 | 12 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict: 13 | if torch.cuda.is_available(): 14 | logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) 15 | checkpoint = torch.load(path) 16 | else: 17 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) 18 | checkpoint = torch.load(path, map_location='cpu') 19 | model.load_state_dict(checkpoint) 20 | info_path = re.sub('.pt$', '.yaml', path) 21 | configs = {} 22 | if os.path.exists(info_path): 23 | with open(info_path, 'r') as fin: 24 | configs = yaml.load(fin, Loader=yaml.FullLoader) 25 | return configs 26 | 27 | 28 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None): 29 | ''' 30 | Args: 31 | infos (dict or None): any info you want to save. 
32 |     '''
33 |     logging.info('Checkpoint: save to checkpoint %s' % path)
34 |     if isinstance(model, torch.nn.DataParallel):
35 |         state_dict = model.module.state_dict()
36 |     elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
37 |         state_dict = model.module.state_dict()
38 |     else:
39 |         state_dict = model.state_dict()
40 |     torch.save(state_dict, path)
41 |     info_path = re.sub('.pt$', '.yaml', path)
42 |     if infos is None:
43 |         infos = {}
44 |     with open(info_path, 'w') as fout:
45 |         data = yaml.dump(infos)
46 |         fout.write(data)
47 | 
--------------------------------------------------------------------------------
/wenet/utils/cmvn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #   http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import json
17 | import logging
18 | import math
19 | import sys
20 | import numpy as np
21 | 
22 | def _load_json_cmvn(json_cmvn_file):
23 |     """ Load the json format cmvn stats file and calculate cmvn
24 | 
25 |     Args:
26 |         json_cmvn_file: cmvn stats file in json format
27 | 
28 |     Returns:
29 |         a numpy array of [means, vars]
30 |     """
31 |     with open(json_cmvn_file) as f:
32 |         cmvn_stats = json.load(f)
33 | 
34 |     means = cmvn_stats['mean_stat']
35 |     variance = cmvn_stats['var_stat']
36 |     count = cmvn_stats['frame_num']
37 |     for i in range(len(means)):
38 |         means[i] /= count
39 |         variance[i] = variance[i] / count - means[i] * means[i]
40 |         if variance[i] < 1.0e-20:
41 |             variance[i] = 1.0e-20
42 |         variance[i] = 1.0 / math.sqrt(variance[i])
43 |     cmvn = np.array([means, variance])
44 |     return cmvn
45 | 
46 | 
47 | def _load_kaldi_cmvn(kaldi_cmvn_file):
48 |     """ Load the kaldi format cmvn stats file and calculate cmvn
49 | 
50 |     Args:
51 |         kaldi_cmvn_file: kaldi text style global cmvn file, which
52 |             is generated by:
53 |             compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
54 | 
55 |     Returns:
56 |         a numpy array of [means, vars]
57 |     """
58 |     means = []
59 |     variance = []
60 |     with open(kaldi_cmvn_file, 'r') as fid:
61 |         # a kaldi binary file starts with '\0B'
62 |         if fid.read(2) == '\0B':
63 |             logging.error('kaldi cmvn binary file is not supported, please '
64 |                           'recompute it by: compute-cmvn-stats --binary=false '
65 |                           ' scp:feats.scp global_cmvn')
66 |             sys.exit(1)
67 |         fid.seek(0)
68 |         arr = fid.read().split()
69 |         assert (arr[0] == '[')
70 |         assert (arr[-2] == '0')
71 |         assert (arr[-1] == ']')
72 |         feat_dim = int((len(arr) - 2 - 2) / 2)
73 |         for i in range(1, feat_dim + 1):
74 |             means.append(float(arr[i]))
75 |         count = float(arr[feat_dim + 1])
76 |         for i in range(feat_dim + 2, 2 * feat_dim + 2):
77 |             variance.append(float(arr[i]))
78 | 
79 |     for i in range(len(means)):
80 |         means[i] /= count
81 |         variance[i] = variance[i] / count - means[i] * means[i]
82 |         if variance[i] < 1.0e-20:
83 |             variance[i] = 1.0e-20
84 |         variance[i] = 1.0 / math.sqrt(variance[i])
85 |     cmvn = np.array([means, variance])
86 |     return cmvn
87 | 
88 | 
89 | def load_cmvn(cmvn_file, is_json):
90 |     if is_json:
91 |         cmvn = _load_json_cmvn(cmvn_file)
92 |     else:
93 |         cmvn = _load_kaldi_cmvn(cmvn_file)
94 |     return cmvn[0], cmvn[1]
95 | 
--------------------------------------------------------------------------------
/wenet/utils/common.py:
--------------------------------------------------------------------------------
1 | """Utility functions for Transformer."""
2 | 
3 | import math
4 | from typing import Tuple, List
5 | 
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 | 
9 | IGNORE_ID = -1
10 | 
11 | 
12 | def pad_list(xs: List[torch.Tensor], pad_value: int):
13 |     """Perform padding for the list of tensors.
14 | 
15 |     Args:
16 |         xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
17 |         pad_value (float): Value for padding.
18 | 
19 |     Returns:
20 |         Tensor: Padded tensor (B, Tmax, `*`).
21 | 
22 |     Examples:
23 |         >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
24 |         >>> x
25 |         [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
26 |         >>> pad_list(x, 0)
27 |         tensor([[1., 1., 1., 1.],
28 |                 [1., 1., 0., 0.],
29 |                 [1., 0., 0., 0.]])
30 | 
31 |     """
32 |     n_batch = len(xs)
33 |     max_len = max([x.size(0) for x in xs])
34 |     pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
35 |     pad = pad.fill_(pad_value)
36 |     for i in range(n_batch):
37 |         pad[i, :xs[i].size(0)] = xs[i]
38 | 
39 |     return pad
40 | 
41 | 
42 | def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
43 |                 ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
44 |     """Add <sos> and <eos> labels.
45 | 
46 |     Args:
47 |         ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
48 |         sos (int): index of <sos>
49 |         eos (int): index of <eos>
50 |         ignore_id (int): index of padding
51 | 
52 |     Returns:
53 |         ys_in (torch.Tensor) : (B, Lmax + 1)
54 |         ys_out (torch.Tensor) : (B, Lmax + 1)
55 | 
56 |     Examples:
57 |         >>> sos_id = 10
58 |         >>> eos_id = 11
59 |         >>> ignore_id = -1
60 |         >>> ys_pad
61 |         tensor([[ 1, 2, 3, 4, 5],
62 |                 [ 4, 5, 6, -1, -1],
63 |                 [ 7, 8, 9, -1, -1]], dtype=torch.int32)
64 |         >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
65 |         >>> ys_in
66 |         tensor([[10, 1, 2, 3, 4, 5],
67 |                 [10, 4, 5, 6, 11, 11],
68 |                 [10, 7, 8, 9, 11, 11]])
69 |         >>> ys_out
70 |         tensor([[ 1, 2, 3, 4, 5, 11],
71 |                 [ 4, 5, 6, 11, -1, -1],
72 |                 [ 7, 8, 9, 11, -1, -1]])
73 |     """
74 |     _sos = torch.tensor([sos],
75 |                         dtype=torch.long,
76 |                         requires_grad=False,
77 |                         device=ys_pad.device)
78 |     _eos = torch.tensor([eos],
79 |                         dtype=torch.long,
80 |                         requires_grad=False,
81 |                         device=ys_pad.device)
82 |     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
83 |     ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
84 |     ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
85 |     return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
86 | 
87 | 
88 | def reverse_pad_list(ys_pad: torch.Tensor,
89 |                      ys_lens: torch.Tensor,
90 |                      pad_value: float = -1.0) -> torch.Tensor:
91 |     """Reverse padding for the list of tensors.
92 | 
93 |     Args:
94 |         ys_pad (tensor): The padded tensor (B, Tokenmax).
95 |         ys_lens (tensor): The lens of token seqs (B)
96 |         pad_value (int): Value for padding.
97 | 
98 |     Returns:
99 |         Tensor: Padded tensor (B, Tokenmax).
100 | 101 | Examples: 102 | >>> x 103 | tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) 104 | >>> pad_list(x, 0) 105 | tensor([[4, 3, 2, 1], 106 | [7, 6, 5, 0], 107 | [9, 8, 0, 0]]) 108 | 109 | """ 110 | r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0])) 111 | for y, i in zip(ys_pad, ys_lens)], True, 112 | pad_value) 113 | return r_ys_pad 114 | 115 | 116 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 117 | ignore_label: int) -> float: 118 | """Calculate accuracy. 119 | 120 | Args: 121 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 122 | pad_targets (LongTensor): Target label tensors (B, Lmax, D). 123 | ignore_label (int): Ignore label id. 124 | 125 | Returns: 126 | float: Accuracy value (0.0 - 1.0). 127 | 128 | """ 129 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 130 | pad_outputs.size(1)).argmax(2) 131 | mask = pad_targets != ignore_label 132 | numerator = torch.sum( 133 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 134 | denominator = torch.sum(mask) 135 | return float(numerator) / float(denominator) 136 | 137 | 138 | def get_activation(act): 139 | """Return activation function.""" 140 | # Lazy load to avoid unused import 141 | from wenet.transformer.swish import Swish 142 | 143 | activation_funcs = { 144 | "hardtanh": torch.nn.Hardtanh, 145 | "tanh": torch.nn.Tanh, 146 | "relu": torch.nn.ReLU, 147 | "selu": torch.nn.SELU, 148 | "swish": Swish, 149 | "gelu": torch.nn.GELU 150 | } 151 | 152 | return activation_funcs[act]() 153 | 154 | 155 | def get_subsample(config): 156 | input_layer = config["encoder_conf"]["input_layer"] 157 | assert input_layer in ["conv2d", "conv2d6", "conv2d8"] 158 | if input_layer == "conv2d": 159 | return 4 160 | elif input_layer == "conv2d6": 161 | return 6 162 | elif input_layer == "conv2d8": 163 | return 8 164 | 165 | 166 | def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: 167 | new_hyp: List[int] = [] 168 | cur = 0 169 | while cur < len(hyp): 170 | if hyp[cur] != 0: 171 | new_hyp.append(hyp[cur]) 172 | prev = cur 173 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 174 | cur += 1 175 | return new_hyp 176 | 177 | 178 | def log_add(args: List[int]) -> float: 179 | """ 180 | Stable log add 181 | """ 182 | if all(a == -float('inf') for a in args): 183 | return -float('inf') 184 | a_max = max(args) 185 | lsp = math.log(sum(math.exp(a - a_max) for a in args)) 186 | return a_max + lsp 187 | -------------------------------------------------------------------------------- /wenet/utils/ctc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Di Wu) 3 | 4 | import numpy as np 5 | import torch 6 | 7 | def insert_blank(label, blank_id=0): 8 | """Insert blank token between every two label token.""" 9 | label = np.expand_dims(label, 1) 10 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 11 | label = np.concatenate([blanks, label], axis=1) 12 | label = label.reshape(-1) 13 | label = np.append(label, label[0]) 14 | return label 15 | 16 | def forced_align(ctc_probs: torch.Tensor, 17 | y: torch.Tensor, 18 | blank_id=0) -> list: 19 | """ctc forced alignment. 
20 | 21 | Args: 22 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 23 | torch.Tensor y: id sequence tensor 1d tensor (L) 24 | int blank_id: blank symbol index 25 | Returns: 26 | torch.Tensor: alignment result 27 | """ 28 | y_insert_blank = insert_blank(y, blank_id) 29 | 30 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 31 | log_alpha = log_alpha - float('inf') # log of zero 32 | state_path = (torch.zeros( 33 | (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1 34 | ) # state path 35 | 36 | # init start state 37 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 38 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 39 | 40 | for t in range(1, ctc_probs.size(0)): 41 | for s in range(len(y_insert_blank)): 42 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 43 | s] == y_insert_blank[s - 2]: 44 | candidates = torch.tensor( 45 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 46 | prev_state = [s, s - 1] 47 | else: 48 | candidates = torch.tensor([ 49 | log_alpha[t - 1, s], 50 | log_alpha[t - 1, s - 1], 51 | log_alpha[t - 1, s - 2], 52 | ]) 53 | prev_state = [s, s - 1, s - 2] 54 | log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 55 | state_path[t, s] = prev_state[torch.argmax(candidates)] 56 | 57 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 58 | 59 | candidates = torch.tensor([ 60 | log_alpha[-1, len(y_insert_blank) - 1], 61 | log_alpha[-1, len(y_insert_blank) - 2] 62 | ]) 63 | prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 64 | state_seq[-1] = prev_state[torch.argmax(candidates)] 65 | for t in range(ctc_probs.size(0) - 2, -1, -1): 66 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 67 | 68 | output_alignment = [] 69 | for t in range(0, ctc_probs.size(0)): 70 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 71 | 72 | return output_alignment 73 | -------------------------------------------------------------------------------- /wenet/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 
2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 3 | 4 | import logging 5 | from contextlib import nullcontext 6 | # if your python version < 3.7 use the below one 7 | # from contextlib import suppress as nullcontext 8 | import torch 9 | from torch.nn.utils import clip_grad_norm_ 10 | 11 | 12 | class Executor: 13 | def __init__(self): 14 | self.step = 0 15 | 16 | def train(self, model, optimizer, scheduler, data_loader, device, writer, 17 | args, scaler): 18 | ''' Train one epoch 19 | ''' 20 | model.train() 21 | clip = args.get('grad_clip', 50.0) 22 | log_interval = args.get('log_interval', 10) 23 | rank = args.get('rank', 0) 24 | accum_grad = args.get('accum_grad', 1) 25 | is_distributed = args.get('is_distributed', True) 26 | use_amp = args.get('use_amp', False) 27 | logging.info('using accumulate grad, new batch size is {} times' 28 | 'larger than before'.format(accum_grad)) 29 | if use_amp: 30 | assert scaler is not None 31 | num_seen_utts = 0 32 | num_total_batch = len(data_loader) 33 | for batch_idx, batch in enumerate(data_loader): 34 | key, feats, target, feats_lengths, target_lengths = batch 35 | feats = feats.to(device) 36 | target = target.to(device) 37 | feats_lengths = feats_lengths.to(device) 38 | target_lengths = target_lengths.to(device) 39 | num_utts = target_lengths.size(0) 40 | if num_utts == 0: 41 | continue 42 | context = None 43 | # Disable gradient synchronizations across DDP processes. 44 | # Within this context, gradients will be accumulated on module 45 | # variables, which will later be synchronized. 46 | if is_distributed and batch_idx % accum_grad != 0: 47 | context = model.no_sync 48 | # Used for single gpu training and DDP gradient synchronization 49 | # processes. 50 | else: 51 | context = nullcontext 52 | with context(): 53 | # autocast context 54 | # The more details about amp can be found in 55 | # https://pytorch.org/docs/stable/notes/amp_examples.html 56 | with torch.cuda.amp.autocast(scaler is not None): 57 | loss, loss_att, loss_ctc = model(feats, feats_lengths, 58 | target, target_lengths) 59 | loss = loss / accum_grad 60 | if use_amp: 61 | scaler.scale(loss).backward() 62 | else: 63 | loss.backward() 64 | 65 | num_seen_utts += num_utts 66 | if batch_idx % accum_grad == 0: 67 | if rank == 0 and writer is not None: 68 | writer.add_scalar('train_loss', loss, self.step) 69 | # Use mixed precision training 70 | if use_amp: 71 | scaler.unscale_(optimizer) 72 | grad_norm = clip_grad_norm_(model.parameters(), clip) 73 | # Must invoke scaler.update() if unscale_() is used in the 74 | # iteration to avoid the following error: 75 | # RuntimeError: unscale_() has already been called 76 | # on this optimizer since the last update(). 77 | # We don't check grad here since that if the gradient has 78 | # inf/nan values, scaler.step will skip optimizer.step(). 
79 | scaler.step(optimizer) 80 | scaler.update() 81 | else: 82 | grad_norm = clip_grad_norm_(model.parameters(), clip) 83 | if torch.isfinite(grad_norm): 84 | optimizer.step() 85 | optimizer.zero_grad() 86 | scheduler.step() 87 | self.step += 1 88 | if batch_idx % log_interval == 0: 89 | lr = optimizer.param_groups[0]['lr'] 90 | log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( 91 | batch_idx, num_total_batch, 92 | loss.item() * accum_grad) 93 | if loss_att is not None: 94 | log_str += 'loss_att {:.6f} '.format(loss_att.item()) 95 | if loss_ctc is not None: 96 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) 97 | log_str += 'lr {:.8f} rank {}'.format(lr, rank) 98 | logging.debug(log_str) 99 | 100 | def cv(self, model, data_loader, device, args): 101 | ''' Cross validation on 102 | ''' 103 | model.eval() 104 | log_interval = args.get('log_interval', 10) 105 | # in order to avoid division by 0 106 | num_seen_utts = 1 107 | total_loss = 0.0 108 | num_total_batch = len(data_loader) 109 | with torch.no_grad(): 110 | for batch_idx, batch in enumerate(data_loader): 111 | key, feats, target, feats_lengths, target_lengths = batch 112 | feats = feats.to(device) 113 | target = target.to(device) 114 | feats_lengths = feats_lengths.to(device) 115 | target_lengths = target_lengths.to(device) 116 | num_utts = target_lengths.size(0) 117 | if num_utts == 0: 118 | continue 119 | loss, loss_att, loss_ctc = model(feats, feats_lengths, target, 120 | target_lengths) 121 | if torch.isfinite(loss): 122 | num_seen_utts += num_utts 123 | total_loss += loss.item() * num_utts 124 | if batch_idx % log_interval == 0: 125 | log_str = 'CV Batch {}/{} loss {:.6f} '.format( 126 | batch_idx, num_total_batch, loss.item()) 127 | if loss_att is not None: 128 | log_str += 'loss_att {:.6f} '.format(loss_att.item()) 129 | if loss_ctc is not None: 130 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) 131 | log_str += 'history loss {:.6f}'.format(total_loss / 132 | num_seen_utts) 133 | logging.debug(log_str) 134 | 135 | return total_loss, num_seen_utts 136 | -------------------------------------------------------------------------------- /wenet/utils/mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Shigeki Karita 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | import torch 7 | 8 | 9 | def subsequent_mask( 10 | size: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> torch.Tensor: 13 | """Create mask for subsequent steps (size, size). 14 | 15 | This mask is used only in decoder which works in an auto-regressive mode. 16 | This means the current step could only do attention with its left steps. 17 | 18 | In encoder, fully attention is used when streaming is not necessary and 19 | the sequence is not long. In this case, no attention mask is needed. 20 | 21 | When streaming is need, chunk-based attention is used in encoder. See 22 | subsequent_chunk_mask for the chunk-based attention mask. 
23 | 24 | Args: 25 | size (int): size of mask 26 | str device (str): "cpu" or "cuda" or torch.Tensor.device 27 | dtype (torch.device): result dtype 28 | 29 | Returns: 30 | torch.Tensor: mask 31 | 32 | Examples: 33 | >>> subsequent_mask(3) 34 | [[1, 0, 0], 35 | [1, 1, 0], 36 | [1, 1, 1]] 37 | """ 38 | ret = torch.ones(size, size, device=device, dtype=torch.bool) 39 | return torch.tril(ret, out=ret) 40 | 41 | 42 | def subsequent_chunk_mask( 43 | size: int, 44 | chunk_size: int, 45 | num_left_chunks: int = -1, 46 | device: torch.device = torch.device("cpu"), 47 | ) -> torch.Tensor: 48 | """Create mask for subsequent steps (size, size) with chunk size, 49 | this is for streaming encoder 50 | 51 | Args: 52 | size (int): size of mask 53 | chunk_size (int): size of chunk 54 | num_left_chunks (int): number of left chunks 55 | <0: use full chunk 56 | >=0: use num_left_chunks 57 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device 58 | 59 | Returns: 60 | torch.Tensor: mask 61 | 62 | Examples: 63 | >>> subsequent_chunk_mask(4, 2) 64 | [[1, 1, 0, 0], 65 | [1, 1, 0, 0], 66 | [1, 1, 1, 1], 67 | [1, 1, 1, 1]] 68 | """ 69 | ret = torch.zeros(size, size, device=device, dtype=torch.bool) 70 | for i in range(size): 71 | if num_left_chunks < 0: 72 | start = 0 73 | else: 74 | start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) 75 | ending = min((i // chunk_size + 1) * chunk_size, size) 76 | ret[i, start:ending] = True 77 | return ret 78 | 79 | 80 | def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, 81 | use_dynamic_chunk: bool, 82 | use_dynamic_left_chunk: bool, 83 | decoding_chunk_size: int, static_chunk_size: int, 84 | num_decoding_left_chunks: int): 85 | """ Apply optional mask for encoder. 86 | 87 | Args: 88 | xs (torch.Tensor): padded input, (B, L, D), L for max length 89 | mask (torch.Tensor): mask for xs, (B, 1, L) 90 | use_dynamic_chunk (bool): whether to use dynamic chunk or not 91 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for 92 | training. 93 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 94 | 0: default for training, use random dynamic chunk. 95 | <0: for decoding, use full chunk. 96 | >0: for decoding, use fixed chunk size as set. 97 | static_chunk_size (int): chunk size for static chunk training/decoding 98 | if it's greater than 0, if use_dynamic_chunk is true, 99 | this parameter will be ignored 100 | num_decoding_left_chunks: number of left chunks, this is for decoding, 101 | the chunk size is decoding_chunk_size. 102 | >=0: use num_decoding_left_chunks 103 | <0: use all left chunks 104 | 105 | Returns: 106 | torch.Tensor: chunk mask of the input xs. 107 | """ 108 | # Whether to use chunk mask or not 109 | if use_dynamic_chunk: 110 | max_len = xs.size(1) 111 | if decoding_chunk_size < 0: 112 | chunk_size = max_len 113 | num_left_chunks = -1 114 | elif decoding_chunk_size > 0: 115 | chunk_size = decoding_chunk_size 116 | num_left_chunks = num_decoding_left_chunks 117 | else: 118 | # chunk size is either [1, 25] or full context(max_len). 119 | # Since we use 4 times subsampling and allow up to 1s(100 frames) 120 | # delay, the maximum frame is 100 / 4 = 25. 
121 | chunk_size = torch.randint(1, max_len, (1, )).item() 122 | num_left_chunks = -1 123 | if chunk_size > max_len // 2: 124 | chunk_size = max_len 125 | else: 126 | chunk_size = chunk_size % 25 + 1 127 | if use_dynamic_left_chunk: 128 | max_left_chunks = (max_len - 1) // chunk_size 129 | num_left_chunks = torch.randint(0, max_left_chunks, 130 | (1, )).item() 131 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, 132 | num_left_chunks, 133 | xs.device) # (L, L) 134 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 135 | chunk_masks = masks & chunk_masks # (B, L, L) 136 | elif static_chunk_size > 0: 137 | num_left_chunks = num_decoding_left_chunks 138 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, 139 | num_left_chunks, 140 | xs.device) # (L, L) 141 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 142 | chunk_masks = masks & chunk_masks # (B, L, L) 143 | else: 144 | chunk_masks = masks 145 | return chunk_masks 146 | 147 | 148 | def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: 149 | """Make mask tensor containing indices of padded part. 150 | 151 | See description of make_non_pad_mask. 152 | 153 | Args: 154 | lengths (torch.Tensor): Batch of lengths (B,). 155 | Returns: 156 | torch.Tensor: Mask tensor containing indices of padded part. 157 | 158 | Examples: 159 | >>> lengths = [5, 3, 2] 160 | >>> make_pad_mask(lengths) 161 | masks = [[0, 0, 0, 0 ,0], 162 | [0, 0, 0, 1, 1], 163 | [0, 0, 1, 1, 1]] 164 | """ 165 | batch_size = int(lengths.size(0)) 166 | max_len = int(lengths.max().item()) 167 | seq_range = torch.arange(0, 168 | max_len, 169 | dtype=torch.int64, 170 | device=lengths.device) 171 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 172 | seq_length_expand = lengths.unsqueeze(-1) 173 | mask = seq_range_expand >= seq_length_expand 174 | return mask 175 | 176 | 177 | def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor: 178 | """Make mask tensor containing indices of non-padded part. 179 | 180 | The sequences in a batch may have different lengths. To enable 181 | batch computing, padding is need to make all sequence in same 182 | size. To avoid the padding part pass value to context dependent 183 | block such as attention or convolution , this padding part is 184 | masked. 185 | 186 | This pad_mask is used in both encoder and decoder. 187 | 188 | 1 for non-padded part and 0 for padded part. 189 | 190 | Args: 191 | lengths (torch.Tensor): Batch of lengths (B,). 192 | Returns: 193 | torch.Tensor: mask tensor containing indices of padded part. 194 | 195 | Examples: 196 | >>> lengths = [5, 3, 2] 197 | >>> make_non_pad_mask(lengths) 198 | masks = [[1, 1, 1, 1 ,1], 199 | [1, 1, 1, 0, 0], 200 | [1, 1, 0, 0, 0]] 201 | """ 202 | return ~make_pad_mask(lengths) 203 | 204 | 205 | def mask_finished_scores(score: torch.Tensor, 206 | flag: torch.Tensor) -> torch.Tensor: 207 | """ 208 | If a sequence is finished, we only allow one alive branch. This function 209 | aims to give one branch a zero score and the rest -inf score. 210 | 211 | Args: 212 | score (torch.Tensor): A real value array with shape 213 | (batch_size * beam_size, beam_size). 214 | flag (torch.Tensor): A bool array with shape 215 | (batch_size * beam_size, 1). 216 | 217 | Returns: 218 | torch.Tensor: (batch_size * beam_size, beam_size). 
219 | """ 220 | beam_size = score.size(-1) 221 | zero_mask = torch.zeros_like(flag, dtype=torch.bool) 222 | if beam_size > 1: 223 | unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])), 224 | dim=1) 225 | finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])), 226 | dim=1) 227 | else: 228 | unfinished = zero_mask 229 | finished = flag 230 | score.masked_fill_(unfinished, -float('inf')) 231 | score.masked_fill_(finished, 0) 232 | return score 233 | 234 | 235 | def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor, 236 | eos: int) -> torch.Tensor: 237 | """ 238 | If a sequence is finished, all of its branch should be 239 | 240 | Args: 241 | pred (torch.Tensor): A int array with shape 242 | (batch_size * beam_size, beam_size). 243 | flag (torch.Tensor): A bool array with shape 244 | (batch_size * beam_size, 1). 245 | 246 | Returns: 247 | torch.Tensor: (batch_size * beam_size). 248 | """ 249 | beam_size = pred.size(-1) 250 | finished = flag.repeat([1, beam_size]) 251 | return pred.masked_fill_(finished, eos) 252 | -------------------------------------------------------------------------------- /wenet/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | from typeguard import check_argument_types 7 | 8 | 9 | class WarmupLR(_LRScheduler): 10 | """The WarmupLR scheduler 11 | 12 | This scheduler is almost same as NoamLR Scheduler except for following 13 | difference: 14 | 15 | NoamLR: 16 | lr = optimizer.lr * model_size ** -0.5 17 | * min(step ** -0.5, step * warmup_step ** -1.5) 18 | WarmupLR: 19 | lr = optimizer.lr * warmup_step ** 0.5 20 | * min(step ** -0.5, step * warmup_step ** -1.5) 21 | 22 | Note that the maximum lr equals to optimizer.lr in this scheduler. 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | optimizer: torch.optim.Optimizer, 29 | warmup_steps: Union[int, float] = 25000, 30 | last_epoch: int = -1, 31 | ): 32 | assert check_argument_types() 33 | self.warmup_steps = warmup_steps 34 | 35 | # __init__() must be invoked before setting field 36 | # because step() is also invoked in __init__() 37 | super().__init__(optimizer, last_epoch) 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" 41 | 42 | def get_lr(self): 43 | step_num = self.last_epoch + 1 44 | return [ 45 | lr 46 | * self.warmup_steps ** 0.5 47 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) 48 | for lr in self.base_lrs 49 | ] 50 | 51 | def set_step(self, step: int): 52 | self.last_epoch = step 53 | --------------------------------------------------------------------------------