├── .gitignore
├── README.md
├── docs
│   ├── mfa_conformer.png
│   └── results.png
├── loss
│   ├── __init__.py
│   ├── amsoftmax.py
│   ├── softmax.py
│   └── utils.py
├── main.py
├── module
│   ├── __init__.py
│   ├── _pooling.py
│   ├── augment.py
│   ├── conformer.py
│   ├── conformer_cat.py
│   ├── conformer_weight.py
│   ├── dataset.py
│   ├── ecapa_tdnn.py
│   ├── feature.py
│   ├── loader.py
│   ├── resnet.py
│   ├── transformer_cat.py
│   └── utils.py
├── score
│   ├── __init__.py
│   ├── cosine.py
│   └── utils.py
├── scripts
│   ├── build_datalist.py
│   ├── format_trials.py
│   ├── make_balanced_data.py
│   ├── make_cohort_set.py
│   ├── make_tsne_set.py
│   └── plot_score.py
├── start.sh
└── wenet
    ├── transformer
    │   ├── attention.py
    │   ├── cmvn.py
    │   ├── convolution.py
    │   ├── embedding.py
    │   ├── encoder.py
    │   ├── encoder_cat.py
    │   ├── encoder_layer.py
    │   ├── encoder_weight.py
    │   ├── label_smoothing_loss.py
    │   ├── positionwise_feed_forward.py
    │   ├── subsampling.py
    │   └── swish.py
    └── utils
        ├── checkpoint.py
        ├── cmvn.py
        ├── common.py
        ├── ctc_util.py
        ├── executor.py
        ├── mask.py
        └── scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | config.json
2 | experiment
3 | data
4 | test.sh
5 | meta
6 | *.wav
7 | lightning_logs
8 | *.ckpt
9 | *.pt
10 | *.lst
11 | *.txt
12 | data/
13 | *.onnx
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | share/python-wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .nox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # celery beat schedule file
99 | celerybeat-schedule
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MFA-Conformer
2 |
3 | This repository contains the training code accompanying the paper "MFA-Conformer: Multi-scale Feature Aggregation Conformer for Automatic Speaker Verification", which has been submitted to Interspeech 2022.
4 |
5 | ![image](docs/mfa_conformer.png)
6 |
7 | The architecture of the MFA-Conformer is inspired by recent state-of-the-art models in speech recognition and speaker verification. Firstly, we introduce a convolution subsampling layer to decrease the computational cost of the model. Secondly, we adopt Conformer blocks, which combine Transformers and convolutional neural networks (CNNs), to capture global and local features effectively. Finally, the output feature maps from all Conformer blocks are concatenated to aggregate multi-scale representations before the final pooling. The best system obtains 0.64%, 1.29% and 1.63% EER on the VoxCeleb1-O, SITW.Dev and SITW.Eval sets, respectively.
8 |
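The core idea fits in a few lines of PyTorch. The following is only an illustrative sketch of the multi-scale aggregation; the projection, encoder blocks, and pooling are generic stand-ins rather than the modules used in this repository (see `module/conformer_cat.py` for the real encoder):

```python
# Sketch: concatenate every block's output feature map before pooling (MFA).
# The layers below are illustrative stand-ins, not the repository's modules.
import torch
import torch.nn as nn

class ToyMFAEncoder(nn.Module):
    def __init__(self, feat_dim=80, d_model=256, num_blocks=6, embedding_dim=192):
        super().__init__()
        self.proj = nn.Linear(feat_dim, d_model)   # stand-in for the conv subsampling front-end
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
            for _ in range(num_blocks)             # stand-in for the Conformer blocks
        )
        self.pool = nn.AdaptiveAvgPool1d(1)        # stand-in for attentive statistics pooling
        self.fc = nn.Linear(d_model * num_blocks, embedding_dim)

    def forward(self, x):                          # x: (batch, frames, feat_dim)
        x = self.proj(x)
        outs = []
        for block in self.blocks:
            x = block(x)
            outs.append(x)                         # keep every block's output feature map
        x = torch.cat(outs, dim=-1)                # (batch, frames, d_model * num_blocks)
        x = self.pool(x.transpose(1, 2)).squeeze(-1)   # aggregate over frames
        return self.fc(x)                          # speaker embedding

print(ToyMFAEncoder()(torch.randn(2, 200, 80)).shape)  # torch.Size([2, 192])
```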
9 | ## Data Preparation
10 |
11 | * [VoxCeleb 1&2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/)
12 | * [SITW](http://www.speech.sri.com/projects/sitw/)
13 |
14 | ```bash
15 | # format Voxceleb test trial list
16 | rm -rf data; mkdir data
17 | wget -P data/ https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt
18 | python3 scripts/format_trials.py \
19 | --voxceleb1_root $voxceleb1_dir \
20 | --src_trials_path data/veri_test.txt \
22 | --dst_trials_path data/vox1_test.txt
22 |
23 | # make csv for voxceleb1&2 dev audio (train_dir)
24 | python3 scripts/build_datalist.py \
25 | --extension wav \
26 | --dataset_dir data/$train_dir \
27 | --data_list_path data/train.csv
28 | ```
29 |
30 | ## Model Training
31 |
32 | ```bash
33 | python3 main.py \
34 | --batch_size 200 \
35 | --num_workers 40 \
36 | --max_epochs 30 \
37 | --embedding_dim $embedding_dim \
38 | --save_dir $save_dir \
39 | --encoder_name $encoder_name \
40 | --train_csv_path $train_csv_path \
41 | --learning_rate 0.001 \
43 | --num_classes $num_classes \
44 | --trial_path $trial_path \
45 | --loss_name $loss_name \
46 | --num_blocks $num_blocks \
47 | --step_size 4 \
48 | --gamma 0.5 \
49 | --weight_decay 0.0000001 \
50 | --input_layer $input_layer \
51 | --pos_enc_layer_type $pos_enc_layer_type
52 | ```
53 |
54 | ## Results
55 |
56 | The training results of the default configuration are presented below (VoxCeleb1-test):
57 |
58 | 
59 |
60 | ## Others
61 |
62 | In addition, here are some tips that might be useful:
63 |
64 | 1. **The Conformer block**: We borrow a lot of code from the [WeNet](https://github.com/wenet-e2e/wenet) toolkit.
65 | 2. **Average the checkpoint weights**: When model training is done, we average the parameters of the last 3~10 checkpoints to generate a new checkpoint, which usually achieves better recognition performance (see the sketch below).
66 | 3. **Warmup**: We apply a linear learning rate warmup over the first 2k training steps and find that this warmup procedure is very helpful for model training.
67 | 4. **AS-norm**: Adaptive score normalization (AS-norm) is a common trick in speaker verification. In our experiments, it yields a 5%-10% relative improvement in EER.
68 |
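As a reference for tip 2, here is a minimal checkpoint-averaging sketch. It assumes PyTorch Lightning checkpoints that store the model weights under a `state_dict` key (as `main.py` saves them); the checkpoint file names and the output path are placeholders:

```python
# Average the weights of the last few checkpoints into a single checkpoint.
# File names are placeholders; each checkpoint is assumed to contain a
# "state_dict" entry (the PyTorch Lightning format written by main.py).
import torch

ckpt_paths = ["epoch=28.ckpt", "epoch=29.ckpt", "epoch=30.ckpt"]

avg_state = None
for path in ckpt_paths:
    state = torch.load(path, map_location="cpu")["state_dict"]
    if avg_state is None:
        avg_state = {k: v.clone().float() for k, v in state.items()}
    else:
        for k in avg_state:
            avg_state[k] += state[k].float()

# Divide by the number of checkpoints and restore the original dtypes
# (e.g. integer BatchNorm counters).
for k in avg_state:
    avg_state[k] = (avg_state[k] / len(ckpt_paths)).to(state[k].dtype)

torch.save({"state_dict": avg_state}, "avg.ckpt")
```

The averaged checkpoint can then be evaluated with `python3 main.py --eval --checkpoint_path avg.ckpt ...` together with the usual data and model arguments.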
69 | ## Citation
70 |
71 | If you find this code useful for your research, please cite our paper.
72 |
73 | ```
74 | @article{zhang2022mfa,
75 | title={MFA-Conformer: Multi-scale Feature Aggregation Conformer for Automatic Speaker Verification},
76 | author={Zhang, Yang and Lv, Zhiqiang and Wu, Haibin and Zhang, Shanshan and Hu, Pengfei and Wu, Zhiyong and Lee, Hung-yi and Meng, Helen},
77 | journal={arXiv preprint arXiv:2203.15249},
78 | year={2022}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/docs/mfa_conformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/docs/mfa_conformer.png
--------------------------------------------------------------------------------
/docs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/docs/results.png
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Yang Zhang.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .softmax import softmax
16 | from .amsoftmax import amsoftmax
17 |
18 |
--------------------------------------------------------------------------------
/loss/amsoftmax.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | # Adapted from https://github.com/CoinCheung/pytorch-loss (MIT License)
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from .utils import accuracy
9 |
10 | class amsoftmax(nn.Module):
11 | def __init__(self, embedding_dim, num_classes, margin=0.2, scale=30, **kwargs):
12 | super(amsoftmax, self).__init__()
13 |
14 | self.m = margin
15 | self.s = scale
16 | self.in_feats = embedding_dim
17 | self.W = torch.nn.Parameter(torch.randn(embedding_dim, num_classes), requires_grad=True)
18 | self.ce = nn.CrossEntropyLoss()
19 | nn.init.xavier_normal_(self.W, gain=1)
20 |
21 | print('Initialised AM-Softmax m=%.3f s=%.3f'%(self.m, self.s))
22 | print('Embedding dim is {}, number of speakers is {}'.format(embedding_dim, num_classes))
23 |
24 | def forward(self, x, label=None):
25 | assert x.size()[0] == label.size()[0]
26 | assert x.size()[1] == self.in_feats
27 |
28 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
29 | x_norm = torch.div(x, x_norm)
30 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
31 | w_norm = torch.div(self.W, w_norm)
32 | costh = torch.mm(x_norm, w_norm)
33 | label_view = label.view(-1, 1)
34 | if label_view.is_cuda: label_view = label_view.cpu()
35 | delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m)
36 | if x.is_cuda: delt_costh = delt_costh.cuda()
37 | costh_m = costh - delt_costh
38 | costh_m_s = self.s * costh_m
39 | loss = self.ce(costh_m_s, label)
40 | acc = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0]
41 | return loss, acc
42 |
43 |
--------------------------------------------------------------------------------
/loss/softmax.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .utils import accuracy
7 |
8 | class softmax(nn.Module):
9 | def __init__(self, embedding_dim, num_classes, **kwargs):
10 | super(softmax, self).__init__()
11 | self.embedding_dim = embedding_dim
12 | self.fc = nn.Linear(embedding_dim, num_classes)
13 | self.criterion = nn.CrossEntropyLoss()
14 |
15 | print('init softmax')
16 | print('Embedding dim is {}, number of speakers is {}'.format(embedding_dim, num_classes))
17 |
18 | def forward(self, x, label=None):
19 | assert x.size()[0] == label.size()[0]
20 | assert x.size()[1] == self.embedding_dim
21 |
22 | x = F.normalize(x, dim=1)
23 | x = self.fc(x)
24 | loss = self.criterion(x, label)
25 | acc1 = accuracy(x.detach(), label.detach(), topk=(1,))[0]
26 | return loss, acc1
27 |
28 |
29 | if __name__ == "__main__":
30 | model = softmax(10, 100)
31 | data = torch.randn((2, 10))
32 | label = torch.tensor([0, 1])
33 | loss, acc = model(data, label)
34 |
35 | print(data.shape)
36 | print(loss)
37 | print(acc)
38 |
39 |
--------------------------------------------------------------------------------
/loss/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def accuracy(output, target, topk=(1,)):
9 | """Computes the precision@k for the specified values of k"""
10 | maxk = max(topk)
11 | batch_size = target.size(0)
12 |
13 | _, pred = output.topk(maxk, 1, True, True)
14 | pred = pred.t()
15 | correct = pred.eq(target.view(1, -1).expand_as(pred))
16 |
17 | res = []
18 | for k in topk:
19 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
20 | res.append(correct_k.mul_(100.0 / batch_size))
21 | return res
22 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from copy import deepcopy
3 | from typing import Any, Union
4 | import torch.distributed as dist
5 | from pytorch_lightning.plugins import DDPPlugin
6 | import random
7 |
8 | import torch
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | from pytorch_lightning import LightningModule, Trainer, seed_everything
13 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
14 | from torch.nn import functional as F
15 | from torch.optim import AdamW
16 | from torch.optim.lr_scheduler import StepLR, CyclicLR
17 |
18 | from module.feature import Mel_Spectrogram
19 | from module.loader import SPK_datamodule
20 | import score as score
21 | from loss import softmax, amsoftmax
22 |
23 | class Task(LightningModule):
24 | def __init__(
25 | self,
26 | learning_rate: float = 0.2,
27 | weight_decay: float = 1.5e-6,
28 | batch_size: int = 32,
29 | num_workers: int = 10,
30 | max_epochs: int = 1000,
31 | trial_path: str = "data/vox1_test.txt",
32 | **kwargs
33 | ):
34 | super().__init__()
35 | self.save_hyperparameters()
36 | self.trials = np.loadtxt(self.hparams.trial_path, str)
37 | self.mel_trans = Mel_Spectrogram()
38 |
39 | from module.resnet import resnet34, resnet18, resnet34_large
40 | from module.ecapa_tdnn import ecapa_tdnn, ecapa_tdnn_large
41 | from module.transformer_cat import transformer_cat
42 | from module.conformer import conformer
43 | from module.conformer_cat import conformer_cat
44 | from module.conformer_weight import conformer_weight
45 |
46 | if self.hparams.encoder_name == "resnet18":
47 | self.encoder = resnet18(embedding_dim=self.hparams.embedding_dim)
48 |
49 | elif self.hparams.encoder_name == "resnet34":
50 | self.encoder = resnet34_large(embedding_dim=self.hparams.embedding_dim)
51 |
52 | elif self.hparams.encoder_name == "ecapa_tdnn":
53 | self.encoder = ecapa_tdnn(embedding_dim=self.hparams.embedding_dim)
54 |
55 | elif self.hparams.encoder_name == "ecapa_tdnn_large":
56 | self.encoder = ecapa_tdnn_large(embedding_dim=self.hparams.embedding_dim)
57 |
58 | elif self.hparams.encoder_name == "conformer":
59 | print("num_blocks is {}".format(self.hparams.num_blocks))
60 | self.encoder = conformer(embedding_dim=self.hparams.embedding_dim,
61 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer)
62 |
63 | elif self.hparams.encoder_name == "transformer_cat":
64 | print("num_blocks is {}".format(self.hparams.num_blocks))
65 | self.encoder = transformer_cat(embedding_dim=self.hparams.embedding_dim,
66 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer)
67 |
68 | elif self.hparams.encoder_name == "conformer_cat":
69 | print("num_blocks is {}".format(self.hparams.num_blocks))
70 | self.encoder = conformer_cat(embedding_dim=self.hparams.embedding_dim,
71 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer,
72 | pos_enc_layer_type=self.hparams.pos_enc_layer_type)
73 |
74 | elif self.hparams.encoder_name == "conformer_weight":
75 | print("num_blocks is {}".format(self.hparams.num_blocks))
76 | self.encoder = conformer_weight(embedding_dim=self.hparams.embedding_dim,
77 | num_blocks=self.hparams.num_blocks, input_layer=self.hparams.input_layer)
78 |
79 | else:
80 | raise ValueError("encoder name error")
81 |
82 | if self.hparams.loss_name == "amsoftmax":
83 | self.loss_fun = amsoftmax(embedding_dim=self.hparams.embedding_dim, num_classes=self.hparams.num_classes)
84 | else:
85 | self.loss_fun = softmax(embedding_dim=self.hparams.embedding_dim, num_classes=self.hparams.num_classes)
86 |
87 | def forward(self, x):
88 | feature = self.mel_trans(x)
89 | embedding = self.encoder(feature)
90 | return embedding
91 |
92 | def training_step(self, batch, batch_idx):
93 | waveform, label = batch
94 | feature = self.mel_trans(waveform)
95 | embedding = self.encoder(feature)
96 | loss, acc = self.loss_fun(embedding, label)
97 | self.log('train_loss', loss, prog_bar=True)
98 | self.log('acc', acc, prog_bar=True)
99 | return loss
100 |
101 | def on_test_epoch_start(self):
102 | return self.on_validation_epoch_start()
103 |
104 | def on_validation_epoch_start(self):
105 | self.index_mapping = {}
106 | self.eval_vectors = []
107 |
108 | def test_step(self, batch, batch_idx):
109 | self.validation_step(batch, batch_idx)
110 |
111 | def validation_step(self, batch, batch_idx):
112 | x, path = batch
113 | path = path[0]
114 | with torch.no_grad():
115 | x = self.mel_trans(x)
116 | self.encoder.eval()
117 | x = self.encoder(x)
118 | x = x.detach().cpu().numpy()[0]
119 | self.eval_vectors.append(x)
120 | self.index_mapping[path] = batch_idx
121 |
122 | def test_epoch_end(self, outputs):
123 | return self.validation_epoch_end(outputs)
124 |
125 | def validation_epoch_end(self, outputs):
126 | num_gpus = torch.cuda.device_count()
127 | eval_vectors = [None for _ in range(num_gpus)]
128 | dist.all_gather_object(eval_vectors, self.eval_vectors)
129 | eval_vectors = np.vstack(eval_vectors)
130 |
131 | table = [None for _ in range(num_gpus)]
132 | dist.all_gather_object(table, self.index_mapping)
133 |
134 | index_mapping = {}
135 | for i in table:
136 | index_mapping.update(i)
137 |
138 | eval_vectors = eval_vectors - np.mean(eval_vectors, axis=0)
139 | labels, scores = score.cosine_score(
140 | self.trials, index_mapping, eval_vectors)
141 | EER, threshold = score.compute_eer(labels, scores)
142 |
143 | print("\ncosine EER: {:.2f}% with threshold {:.2f}".format(EER*100, threshold))
144 | self.log("cosine_eer", EER*100)
145 |
146 | minDCF, threshold = score.compute_minDCF(labels, scores, p_target=0.01)
147 | print("cosine minDCF(10-2): {:.2f} with threshold {:.2f}".format(minDCF, threshold))
148 | self.log("cosine_minDCF(10-2)", minDCF)
149 |
150 | minDCF, threshold = score.compute_minDCF(labels, scores, p_target=0.001)
151 | print("cosine minDCF(10-3): {:.2f} with threshold {:.2f}".format(minDCF, threshold))
152 | self.log("cosine_minDCF(10-3)", minDCF)
153 |
154 |
155 | def configure_optimizers(self):
156 | optimizer = torch.optim.Adam(
157 | self.parameters(),
158 | self.hparams.learning_rate,
159 | weight_decay=self.hparams.weight_decay
160 | )
161 | scheduler = StepLR(optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma)
162 | return [optimizer], [scheduler]
163 |
164 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
165 | optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
166 | # warm up learning_rate
167 | if self.trainer.global_step < self.hparams.warmup_step:
168 | lr_scale = min(1., float(self.trainer.global_step +
169 | 1) / float(self.hparams.warmup_step))
170 | for idx, pg in enumerate(optimizer.param_groups):
171 | pg['lr'] = lr_scale * self.hparams.learning_rate
172 | # update params
173 | optimizer.step(closure=optimizer_closure)
174 | optimizer.zero_grad()
175 |
176 | @staticmethod
177 | def add_model_specific_args(parent_parser):
178 | parser = ArgumentParser(parents=[parent_parser], add_help=False)
179 | (args, _) = parser.parse_known_args()
180 |
181 | parser.add_argument("--num_workers", default=40, type=int)
182 | parser.add_argument("--embedding_dim", default=256, type=int)
183 | parser.add_argument("--num_classes", type=int, default=1211)
184 | parser.add_argument("--num_blocks", type=int, default=6)
185 |
186 | parser.add_argument("--input_layer", type=str, default="conv2d")
187 | parser.add_argument("--pos_enc_layer_type", type=str, default="abs_pos")
188 |
189 | parser.add_argument("--second", type=int, default=3)
190 | parser.add_argument('--step_size', type=int, default=1)
191 | parser.add_argument('--gamma', type=float, default=0.9)
192 | parser.add_argument("--batch_size", type=int, default=80)
193 | parser.add_argument("--learning_rate", type=float, default=0.001)
194 | parser.add_argument("--warmup_step", type=float, default=2000)
195 | parser.add_argument("--weight_decay", type=float, default=0.000001)
196 |
197 | parser.add_argument("--save_dir", type=str, default=None)
198 | parser.add_argument("--checkpoint_path", type=str, default=None)
199 | parser.add_argument("--loss_name", type=str, default="amsoftmax")
200 | parser.add_argument("--encoder_name", type=str, default="resnet34")
201 |
202 | parser.add_argument("--train_csv_path", type=str, default="data/train.csv")
203 | parser.add_argument("--trial_path", type=str, default="data/vox1_test.txt")
204 | parser.add_argument("--score_save_path", type=str, default=None)
205 |
206 | parser.add_argument('--eval', action='store_true')
207 | parser.add_argument('--aug', action='store_true')
208 | return parser
209 |
210 |
211 | def cli_main():
212 | parser = ArgumentParser()
213 | # trainer args
214 | parser = Trainer.add_argparse_args(parser)
215 |
216 | # model args
217 | parser = Task.add_model_specific_args(parser)
218 | args = parser.parse_args()
219 |
220 | model = Task(**args.__dict__)
221 |
222 | if args.checkpoint_path is not None:
223 | state_dict = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"]
224 | model.load_state_dict(state_dict, strict=True)
225 | print("load weight from {}".format(args.checkpoint_path))
226 |
227 | assert args.save_dir is not None
228 | checkpoint_callback = ModelCheckpoint(monitor='cosine_eer', save_top_k=100,
229 | filename="{epoch}_{cosine_eer:.2f}", dirpath=args.save_dir)
230 | lr_monitor = LearningRateMonitor(logging_interval='step')
231 |
232 | # init default datamodule
233 | print("data augmentation {}".format(args.aug))
234 | dm = SPK_datamodule(train_csv_path=args.train_csv_path, trial_path=args.trial_path, second=args.second,
235 | aug=args.aug, batch_size=args.batch_size, num_workers=args.num_workers, pairs=False)
236 | AVAIL_GPUS = torch.cuda.device_count()
237 | trainer = Trainer(
238 | max_epochs=args.max_epochs,
239 | plugins=DDPPlugin(find_unused_parameters=False),
240 | gpus=AVAIL_GPUS,
241 | num_sanity_val_steps=-1,
242 | sync_batchnorm=True,
243 | callbacks=[checkpoint_callback, lr_monitor],
244 | default_root_dir=args.save_dir,
245 | reload_dataloaders_every_n_epochs=1,
246 | accumulate_grad_batches=1,
247 | log_every_n_steps=25,
248 | )
249 | if args.eval:
250 | trainer.test(model, datamodule=dm)
251 | else:
252 | trainer.fit(model, datamodule=dm)
253 |
254 |
255 | if __name__ == "__main__":
256 | cli_main()
257 |
258 |
--------------------------------------------------------------------------------
/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyzisyz/mfa_conformer/1b9c229948f8dbdbe9370937813ec75d4b06b097/module/__init__.py
--------------------------------------------------------------------------------
/module/_pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Temporal_Average_Pooling(nn.Module):
6 | def __init__(self, **kwargs):
7 | """TAP
8 | Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
9 | Link: https://arxiv.org/pdf/1903.12058.pdf
10 | """
11 | super(Temporal_Average_Pooling, self).__init__()
12 |
13 | def forward(self, x):
14 | """Computes Temporal Average Pooling Module
15 | Args:
16 | x (torch.Tensor): Input tensor (#batch, channels, frames).
17 | Returns:
18 | torch.Tensor: Output tensor (#batch, channels)
19 | """
20 | x = torch.mean(x, axis=2)
21 | return x
22 |
23 |
24 | class Temporal_Statistics_Pooling(nn.Module):
25 | def __init__(self, **kwargs):
26 | """TSP
27 | Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
28 | Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
29 | """
30 | super(Temporal_Statistics_Pooling, self).__init__()
31 |
32 | def forward(self, x):
33 | """Computes Temporal Statistics Pooling Module
34 | Args:
35 | x (torch.Tensor): Input tensor (#batch, channels, frames).
36 | Returns:
37 | torch.Tensor: Output tensor (#batch, channels*2)
38 | """
39 | mean = torch.mean(x, axis=2)
40 | var = torch.var(x, axis=2)
41 | x = torch.cat((mean, var), axis=1)
42 | return x
43 |
44 |
45 | class Self_Attentive_Pooling(nn.Module):
46 | def __init__(self, dim):
47 | """SAP
48 | Paper: Self-Attentive Speaker Embeddings for Text-Independent Speaker Verification
49 | Link: https://danielpovey.com/files/2018_interspeech_xvector_attention.pdf
50 | Args:
51 | dim (pair): the size of attention weights
52 | """
53 | super(Self_Attentive_Pooling, self).__init__()
54 | self.sap_linear = nn.Linear(dim, dim)
55 | self.attention = nn.Parameter(torch.FloatTensor(dim, 1))
56 |
57 | def forward(self, x):
58 | """Computes Self-Attentive Pooling Module
59 | Args:
60 | x (torch.Tensor): Input tensor (#batch, dim, frames).
61 | Returns:
62 | torch.Tensor: Output tensor (#batch, dim)
63 | """
64 | x = x.permute(0, 2, 1)
65 | h = torch.tanh(self.sap_linear(x))
66 | w = torch.matmul(h, self.attention).squeeze(dim=2)
67 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
68 | x = torch.sum(x * w, dim=1)
69 | return x
70 |
71 |
72 | class Attentive_Statistics_Pooling(nn.Module):
73 | def __init__(self, dim):
74 | """ASP
75 | Paper: Attentive Statistics Pooling for Deep Speaker Embedding
76 | Link: https://arxiv.org/pdf/1803.10963.pdf
77 | Args:
78 | dim (pair): the size of attention weights
79 | """
80 | super(Attentive_Statistics_Pooling, self).__init__()
81 | self.sap_linear = nn.Linear(dim, dim)
82 | self.attention = nn.Parameter(torch.FloatTensor(dim, 1))
83 |
84 | def forward(self, x):
85 | """Computes Attentive Statistics Pooling Module
86 | Args:
87 | x (torch.Tensor): Input tensor (#batch, dim, frames).
88 | Returns:
89 | torch.Tensor: Output tensor (#batch, dim*2)
90 | """
91 | x = x.permute(0, 2, 1)
92 | h = torch.tanh(self.sap_linear(x))
93 | w = torch.matmul(h, self.attention).squeeze(dim=2)
94 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
95 | mu = torch.sum(x * w, dim=1)
96 | rh = torch.sqrt( ( torch.sum((x**2) * w, dim=1) - mu**2 ).clamp(min=1e-5) )
97 | x = torch.cat((mu, rh), 1)
98 | return x
99 |
100 |
101 | if __name__ == "__main__":
102 | data = torch.randn(10, 128, 100)
103 | pooling = Self_Attentive_Pooling(128)
104 | out = pooling(data)
105 | print(data.shape)
106 | print(out.shape)
107 |
--------------------------------------------------------------------------------
/module/augment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import pandas as pd
5 |
6 | from scipy.io import wavfile
7 | from scipy import signal
8 | import soundfile
9 |
10 | def compute_dB(waveform):
11 | """
12 | Args:
13 | x (numpy.array): Input waveform (#length).
14 | Returns:
15 | numpy.array: Output array (#length).
16 | """
17 | val = max(0.0, np.mean(np.power(waveform, 2)))
18 | dB = 10*np.log10(val+1e-4)
19 | return dB
20 |
21 | class WavAugment(object):
22 | def __init__(self, noise_csv_path="data/noise.csv", rir_csv_path="data/rir.csv"):
23 | self.noise_paths = pd.read_csv(noise_csv_path)["utt_paths"].values
24 | self.noise_names = pd.read_csv(noise_csv_path)["speaker_name"].values
25 | self.rir_paths = pd.read_csv(rir_csv_path)["utt_paths"].values
26 |
27 | def __call__(self, waveform):
28 | idx = np.random.randint(0, 10)
29 | if idx == 0:
30 | waveform = self.add_gaussian_noise(waveform)
31 | waveform = self.add_real_noise(waveform)
32 |
33 | if idx == 1 or idx == 2 or idx == 3:
34 | waveform = self.add_real_noise(waveform)
35 |
36 | if idx == 4 or idx == 5 or idx == 6:
37 | waveform = self.reverberate(waveform)
38 |
39 | if idx == 7:
40 | waveform = self.change_volum(waveform)
41 | waveform = self.reverberate(waveform)
42 |
43 | if idx == 6:
44 | waveform = self.change_volum(waveform)
45 | waveform = self.add_real_noise(waveform)
46 |
47 | if idx == 8:
48 | waveform = self.add_gaussian_noise(waveform)
49 | waveform = self.reverberate(waveform)
50 |
51 | return waveform
52 |
53 | def add_gaussian_noise(self, waveform):
54 | """
55 | Args:
56 | x (numpy.array): Input waveform array (#length).
57 | Returns:
58 | numpy.array: Output waveform array (#length).
59 | """
60 | snr = np.random.uniform(low=10, high=25)
61 | clean_dB = compute_dB(waveform)
62 | noise = np.random.randn(len(waveform))
63 | noise_dB = compute_dB(noise)
64 | noise = np.sqrt(10 ** ((clean_dB - noise_dB - snr) / 10)) * noise
65 | waveform = (waveform + noise)
66 | return waveform
67 |
68 | def change_volum(self, waveform):
69 | """
70 | Args:
71 | x (numpy.array): Input waveform array (#length).
72 | Returns:
73 | numpy.array: Output waveform array (#length).
74 | """
75 | volum = np.random.uniform(low=0.8, high=1.0005)
76 | waveform = waveform * volum
77 | return waveform
78 |
79 | def add_real_noise(self, waveform):
80 | """
81 | Args:
82 | x (numpy.array): Input length (#length).
83 | Returns:
84 | numpy.array: Output waveform array (#length).
85 | """
86 | clean_dB = compute_dB(waveform)
87 |
88 | idx = np.random.randint(0, len(self.noise_paths))
89 | sample_rate, noise = wavfile.read(self.noise_paths[idx])
90 | noise = noise.astype(np.float64)
91 |
92 | snr = np.random.uniform(15, 25)
93 |
94 | noise_length = len(noise)
95 | audio_length = len(waveform)
96 |
97 | if audio_length >= noise_length:
98 | shortage = audio_length - noise_length
99 | noise = np.pad(noise, (0, shortage), 'wrap')
100 | else:
101 | start = np.random.randint(0, (noise_length-audio_length))
102 | noise = noise[start:start+audio_length]
103 |
104 | noise_dB = compute_dB(noise)
105 | noise = np.sqrt(10 ** ((clean_dB - noise_dB - snr) / 10)) * noise
106 | waveform = (waveform + noise)
107 | return waveform
108 |
109 | def reverberate(self, waveform):
110 | """
111 | Args:
112 | x (numpy.array): Input length (#length).
113 | Returns:
114 | numpy.array: Output waveform array (#length).
115 | """
116 | audio_length = len(waveform)
117 | idx = np.random.randint(0, len(self.rir_paths))
118 |
119 | path = self.rir_paths[idx]
120 | rir, sample_rate = soundfile.read(path)
121 | rir = rir/np.sqrt(np.sum(rir**2))
122 |
123 | waveform = signal.convolve(waveform, rir, mode='full')
124 | return waveform[:audio_length]
125 |
126 |
127 | if __name__ == "__main__":
128 | aug = WavAugment()
129 | sample_rate, waveform = wavfile.read("input.wav")
130 | waveform = waveform.astype(np.float64)
131 |
132 | gaussian_noise_wave = aug.add_gaussian_noise(waveform)
133 | print(gaussian_noise_wave.dtype)
134 | wavfile.write("gaussian_noise_wave.wav", 16000, gaussian_noise_wave.astype(np.int16))
135 |
136 | real_noise_wave = aug.add_real_noise(waveform)
137 | print(real_noise_wave.dtype)
138 | wavfile.write("real_noise_wave.wav", 16000, real_noise_wave.astype(np.int16))
139 |
140 | change_volum_wave = aug.change_volum(waveform)
141 | print(change_volum_wave.dtype)
142 | wavfile.write("change_volum_wave.wav", 16000, change_volum_wave.astype(np.int16))
143 |
144 | reverberate_wave = aug.reverberate(waveform)
145 | print(reverberate_wave.dtype)
146 | wavfile.write("reverberate_wave.wav", 16000, reverberate_wave.astype(np.int16))
147 |
148 | reverb_noise_wave = aug.reverberate(waveform)
149 | reverb_noise_wave = aug.add_real_noise(reverb_noise_wave)
150 | print(reverb_noise_wave.dtype)
151 | wavfile.write("reverb_noise_wave.wav", 16000, reverb_noise_wave.astype(np.int16))
152 |
153 | noise_reverb_wave = aug.add_real_noise(waveform)
154 | noise_reverb_wave = aug.reverberate(noise_reverb_wave)
155 | print(noise_reverb_wave.dtype)
156 | wavfile.write("noise_reverb_wave.wav", 16000, noise_reverb_wave.astype(np.int16))
157 |
158 | a = torch.FloatTensor(noise_reverb_wave)
159 | print(a.dtype)
160 |
--------------------------------------------------------------------------------
/module/conformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wenet.transformer.encoder import ConformerEncoder
3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d
5 |
6 | class Conformer(torch.nn.Module):
7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2",
8 | pos_enc_layer_type="rel_pos"):
9 | super(Conformer, self).__init__()
10 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks,
11 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
12 | self.pooling = AttentiveStatisticsPooling(output_size)
13 | self.bn = BatchNorm1d(input_size=output_size*2)
14 | self.fc = torch.nn.Linear(output_size*2, embedding_dim)
15 |
16 | def forward(self, feat):
17 | feat = feat.squeeze(1).permute(0, 2, 1)
18 | lens = torch.ones(feat.shape[0]).to(feat.device)
19 | lens = torch.round(lens*feat.shape[1]).int()
20 | x, masks = self.conformer(feat, lens)
21 | x = x.permute(0, 2, 1)
22 | x = self.pooling(x)
23 | x = self.bn(x)
24 | x = x.permute(0, 2, 1)
25 | x = self.fc(x)
26 | x = x.squeeze(1)
27 | return x
28 |
29 | def conformer(n_mels=80, num_blocks=6, output_size=256,
30 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"):
31 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size,
32 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
33 | return model
34 |
35 |
36 |
37 |
38 | if __name__ == "__main__":
39 | for i in range(6, 7):
40 | print("num_blocks is {}".format(i))
41 | model = conformer(num_blocks=i)
42 |
43 | import time
44 | model = model.eval()
45 | time1 = time.time()
46 | with torch.no_grad():
47 | for i in range(100):
48 | data = torch.randn(1, 1, 80, 500)
49 | embedding = model(data)
50 | time2 = time.time()
51 | val = (time2 - time1)/100
52 | rtf = val / 5
53 |
54 | total = sum([param.nelement() for param in model.parameters()])
55 | print("total param: {:.2f}M".format(total/1e6))
56 | print("RTF {:.4f}".format(rtf))
57 |
58 |
--------------------------------------------------------------------------------
/module/conformer_cat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wenet.transformer.encoder_cat import ConformerEncoder
3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d
5 |
6 | class Conformer(torch.nn.Module):
7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2",
8 | pos_enc_layer_type="rel_pos"):
9 |
10 | super(Conformer, self).__init__()
11 | print("input_layer: {}".format(input_layer))
12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type))
13 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks,
14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
15 | self.pooling = AttentiveStatisticsPooling(output_size*num_blocks)
16 | self.bn = BatchNorm1d(input_size=output_size*num_blocks*2)
17 | self.fc = torch.nn.Linear(output_size*num_blocks*2, embedding_dim)
18 |
19 | def forward(self, feat):
20 | feat = feat.squeeze(1).permute(0, 2, 1)
21 | lens = torch.ones(feat.shape[0]).to(feat.device)
22 | lens = torch.round(lens*feat.shape[1]).int()
23 | x, masks = self.conformer(feat, lens)
24 | x = x.permute(0, 2, 1)
25 | x = self.pooling(x)
26 | x = self.bn(x)
27 | x = x.permute(0, 2, 1)
28 | x = self.fc(x)
29 | x = x.squeeze(1)
30 | return x
31 |
32 | def conformer_cat(n_mels=80, num_blocks=6, output_size=256,
33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"):
34 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size,
35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
36 | return model
37 |
38 |
39 |
--------------------------------------------------------------------------------
/module/conformer_weight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wenet.transformer.encoder_weight import ConformerEncoder
3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d
5 |
6 | class Conformer(torch.nn.Module):
7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2",
8 | pos_enc_layer_type="rel_pos"):
9 |
10 | super(Conformer, self).__init__()
11 | print("input_layer: {}".format(input_layer))
12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type))
13 | self.conformer = ConformerEncoder(input_size=n_mels, num_blocks=num_blocks,
14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
15 | self.pooling = AttentiveStatisticsPooling(output_size)
16 | self.bn = BatchNorm1d(input_size=output_size*2)
17 | self.fc = torch.nn.Linear(output_size*2, embedding_dim)
18 |
19 | def forward(self, feat):
20 | feat = feat.squeeze(1).permute(0, 2, 1)
21 | lens = torch.ones(feat.shape[0]).to(feat.device)
22 | lens = torch.round(lens*feat.shape[1]).int()
23 | x, masks = self.conformer(feat, lens)
24 | x = x.permute(0, 2, 1)
25 | x = self.pooling(x)
26 | x = self.bn(x)
27 | x = x.permute(0, 2, 1)
28 | x = self.fc(x)
29 | x = x.squeeze(1)
30 | return x
31 |
32 | def conformer_weight(n_mels=80, num_blocks=6, output_size=256,
33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"):
34 | model = Conformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size,
35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
36 | return model
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/module/dataset.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from scipy import signal
9 | from scipy.io import wavfile
10 | from sklearn.utils import shuffle
11 | from torch.utils.data import DataLoader, Dataset
12 | from .augment import WavAugment
13 |
14 |
15 | def load_audio(filename, second=2):
16 | sample_rate, waveform = wavfile.read(filename)
17 | audio_length = waveform.shape[0]
18 |
19 | if second <= 0:
20 | return waveform.astype(np.float64).copy()
21 |
22 | length = np.int64(sample_rate * second)
23 |
24 | if audio_length <= length:
25 | shortage = length - audio_length
26 | waveform = np.pad(waveform, (0, shortage), 'wrap')
27 | waveform = waveform.astype(np.float64)
28 | else:
29 | start = np.int64(random.random()*(audio_length-length))
30 | waveform = waveform[start:start+length].astype(np.float64)
31 | return waveform.copy()
32 |
33 | class Train_Dataset(Dataset):
34 | def __init__(self, train_csv_path, second=3, pairs=True, aug=False, **kwargs):
35 | self.second = second
36 | self.pairs = pairs
37 |
38 | df = pd.read_csv(train_csv_path)
39 | self.labels = df["utt_spk_int_labels"].values
40 | self.paths = df["utt_paths"].values
41 | self.labels, self.paths = shuffle(self.labels, self.paths)
42 | self.aug = aug
43 | if aug:
44 | self.wav_aug = WavAugment()
45 |
46 | print("Train Dataset load {} speakers".format(len(set(self.labels))))
47 | print("Train Dataset load {} utterance".format(len(self.labels)))
48 |
49 | def __getitem__(self, index):
50 | waveform_1 = load_audio(self.paths[index], self.second)
51 | if self.aug == True:
52 | waveform_1 = self.wav_aug(waveform_1)
53 | if self.pairs == False:
54 | return torch.FloatTensor(waveform_1), self.labels[index]
55 |
56 | else:
57 | waveform_2 = load_audio(self.paths[index], self.second)
58 | if self.aug == True:
59 | waveform_2 = self.wav_aug(waveform_2)
60 | return torch.FloatTensor(waveform_1), torch.FloatTensor(waveform_2), self.labels[index]
61 |
62 | def __len__(self):
63 | return len(self.paths)
64 |
65 |
66 | class Semi_Dataset(Dataset):
67 | def __init__(self, label_csv_path, unlabel_csv_path, second=2, pairs=True, aug=False, **kwargs):
68 | self.second = second
69 | self.pairs = pairs
70 |
71 | df = pd.read_csv(label_csv_path)
72 | self.labels = df["utt_spk_int_labels"].values
73 | self.paths = df["utt_paths"].values
74 |
75 | self.aug = aug
76 | if aug:
77 | self.wav_aug = WavAugment()
78 |
79 | df = pd.read_csv(unlabel_csv_path)
80 | self.u_paths = df["utt_paths"].values
81 | self.u_paths_length = len(self.u_paths)
82 |
83 | if label_csv_path != unlabel_csv_path:
84 | self.labels, self.paths = shuffle(self.labels, self.paths)
85 | self.u_paths = shuffle(self.u_paths)
86 |
87 | # self.labels = self.labels[:self.u_paths_length]
88 | # self.paths = self.paths[:self.u_paths_length]
89 | print("Semi Dataset load {} speakers".format(len(set(self.labels))))
90 | print("Semi Dataset load {} utterance".format(len(self.labels)))
91 |
92 | def __getitem__(self, index):
93 | waveform_l = load_audio(self.paths[index], self.second)
94 |
95 | idx = np.random.randint(0, self.u_paths_length)
96 | waveform_u_1 = load_audio(self.u_paths[idx], self.second)
97 | if self.aug == True:
98 | waveform_u_1 = self.wav_aug(waveform_u_1)
99 |
100 | if self.pairs == False:
101 | return torch.FloatTensor(waveform_l), self.labels[index], torch.FloatTensor(waveform_u_1)
102 |
103 | else:
104 | waveform_u_2 = load_audio(self.u_paths[idx], self.second)
105 | if self.aug == True:
106 | waveform_u_2 = self.wav_aug(waveform_u_2)
107 | return torch.FloatTensor(waveform_l), self.labels[index], torch.FloatTensor(waveform_u_1), torch.FloatTensor(waveform_u_2)
108 |
109 | def __len__(self):
110 | return len(self.paths)
111 |
112 |
113 | class Evaluation_Dataset(Dataset):
114 | def __init__(self, paths, second=-1, **kwargs):
115 | self.paths = paths
116 | self.second = second
117 | print("load {} utterance".format(len(self.paths)))
118 |
119 | def __getitem__(self, index):
120 | waveform = load_audio(self.paths[index], self.second)
121 | return torch.FloatTensor(waveform), self.paths[index]
122 |
123 | def __len__(self):
124 | return len(self.paths)
125 |
126 | if __name__ == "__main__":
127 | dataset = Train_Dataset(train_csv_path="data/train.csv", second=3)
128 | loader = DataLoader(
129 | dataset,
130 | batch_size=10,
131 | shuffle=False
132 | )
133 | for x, label in loader:
134 | pass
135 |
136 |
--------------------------------------------------------------------------------
/module/ecapa_tdnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
3 |
4 | class Model(torch.nn.Module):
5 | def __init__(self, n_mels=80, embedding_dim=192, channel=512):
6 | super(Model, self).__init__()
7 | channels = [channel for _ in range(4)]
8 | channels.append(channel*3)
9 | self.model = ECAPA_TDNN(input_size=n_mels, lin_neurons=embedding_dim, channels=channels)
10 |
11 | def forward(self, x):
12 | x = x.squeeze(1)
13 | x = x.permute(0, 2, 1)
14 | x = self.model(x)
15 | x = x.squeeze(1)
16 | return x
17 |
18 | def ecapa_tdnn(n_mels=80, embedding_dim=192, channel=512):
19 | model = Model(n_mels=n_mels, embedding_dim=embedding_dim, channel=channel)
20 | return model
21 |
22 | def ecapa_tdnn_large(n_mels=80, embedding_dim=192, channel=1024):
23 | model = Model(n_mels=n_mels, embedding_dim=embedding_dim, channel=channel)
24 | return model
25 |
26 |
27 |
--------------------------------------------------------------------------------
/module/feature.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class PreEmphasis(torch.nn.Module):
7 | def __init__(self, coef: float = 0.97):
8 | super(PreEmphasis, self).__init__()
9 | self.coef = coef
10 | # make kernel
11 | # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
12 | self.register_buffer(
13 | 'flipped_filter', torch.FloatTensor(
14 | [-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
15 | )
16 |
17 | def forward(self, inputs: torch.tensor) -> torch.tensor:
18 | assert len(
19 | inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
20 | # reflect padding to match lengths of in/out
21 | inputs = inputs.unsqueeze(1)
22 | inputs = F.pad(inputs, (1, 0), 'reflect')
23 | return F.conv1d(inputs, self.flipped_filter).squeeze(1)
24 |
25 |
26 | class Mel_Spectrogram(nn.Module):
27 | def __init__(self, sample_rate=16000, n_fft=512, win_length=400, hop=160, n_mels=80, coef=0.97, requires_grad=False):
28 | super(Mel_Spectrogram, self).__init__()
29 | self.n_fft = n_fft
30 | self.n_mels = n_mels
31 | self.win_length = win_length
32 | self.hop = hop
33 |
34 | self.pre_emphasis = PreEmphasis(coef)
35 | mel_basis = librosa.filters.mel(
36 | sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
37 | self.mel_basis = nn.Parameter(
38 | torch.FloatTensor(mel_basis), requires_grad=requires_grad)
39 | self.instance_norm = nn.InstanceNorm1d(num_features=n_mels)
40 | window = torch.hamming_window(self.win_length)
41 | self.window = nn.Parameter(
42 | torch.FloatTensor(window), requires_grad=False)
43 |
44 | def forward(self, x):
45 | x = self.pre_emphasis(x)
46 | x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop,
47 | window=self.window, win_length=self.win_length, return_complex=True)
48 | x = torch.abs(x)
49 | x += 1e-9
50 | x = torch.log(x)
51 | x = torch.matmul(self.mel_basis, x)
52 | x = self.instance_norm(x)
53 | x = x.unsqueeze(1)
54 | return x
55 |
56 |
57 | if __name__ == "__main__":
58 | from scipy.io import wavfile
59 | import matplotlib.pyplot as plt
60 | from torchvision import transforms as transforms
61 |
62 | sample_rate, sig = wavfile.read("test.wav")
63 | sig = torch.FloatTensor(sig.copy())
64 | sig = sig.repeat(10, 1)
65 |
66 | spec = Mel_Spectrogram()
67 | out = spec(sig)
68 | out = out
69 | print(out.shape)
70 |
71 | plt.subplot(211)
72 | plt.imshow(out[0][0])
73 |
74 | trans = transforms.RandomResizedCrop((80, 200))
75 | out = trans(out)
76 | print(out.shape)
77 |
78 | plt.subplot(212)
79 | plt.imshow(out[0][0])
80 |
81 | plt.savefig("test.png")
82 |
--------------------------------------------------------------------------------
/module/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, Optional
3 |
4 | import numpy as np
5 | import torch
6 | from pytorch_lightning import LightningDataModule
7 | from torch.utils.data import DataLoader
8 |
11 |
12 | from .dataset import Evaluation_Dataset, Train_Dataset, Semi_Dataset
13 |
14 |
15 | class SPK_datamodule(LightningDataModule):
16 | def __init__(
17 | self,
18 | train_csv_path,
19 | trial_path,
20 | unlabel_csv_path = None,
21 | second: int = 2,
22 | num_workers: int = 16,
23 | batch_size: int = 32,
24 | shuffle: bool = True,
25 | pin_memory: bool = True,
26 | drop_last: bool = True,
27 | pairs: bool = True,
28 | aug: bool = False,
29 | semi: bool = False,
30 | *args: Any,
31 | **kwargs: Any,
32 | ) -> None:
33 | super().__init__(*args, **kwargs)
34 |
35 | self.train_csv_path = train_csv_path
36 | self.unlabel_csv_path = unlabel_csv_path
37 | self.second = second
38 | self.num_workers = num_workers
39 | self.batch_size = batch_size
40 | self.trial_path = trial_path
41 | self.pairs = pairs
42 | self.aug = aug
43 | print("second is {:.2f}".format(second))
44 |
45 | def train_dataloader(self) -> DataLoader:
46 | if self.unlabel_csv_path is None:
47 | train_dataset = Train_Dataset(self.train_csv_path, self.second, self.pairs, self.aug)
48 | else:
49 | train_dataset = Semi_Dataset(self.train_csv_path, self.unlabel_csv_path, self.second, self.pairs, self.aug)
50 | loader = torch.utils.data.DataLoader(
51 | train_dataset,
52 | shuffle=True,
53 | num_workers=self.num_workers,
54 | batch_size=self.batch_size,
55 | pin_memory=True,
56 | drop_last=False,
57 | )
58 | return loader
59 |
60 | def val_dataloader(self) -> DataLoader:
61 | trials = np.loadtxt(self.trial_path, str)
62 | self.trials = trials
63 | eval_path = np.unique(np.concatenate((trials.T[1], trials.T[2])))
64 | print("number of enroll: {}".format(len(set(trials.T[1]))))
65 | print("number of test: {}".format(len(set(trials.T[2]))))
66 | print("number of evaluation: {}".format(len(eval_path)))
67 | eval_dataset = Evaluation_Dataset(eval_path, second=-1)
68 | loader = torch.utils.data.DataLoader(eval_dataset,
69 | num_workers=10,
70 | shuffle=False,
71 | batch_size=1)
72 | return loader
73 |
74 | def test_dataloader(self) -> DataLoader:
75 | return self.val_dataloader()
76 |
77 |
78 |
--------------------------------------------------------------------------------
/module/resnet.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Optional, Type, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | try:
8 | from ._pooling import *
9 | except:
10 | from _pooling import *
11 |
12 |
13 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=dilation, groups=groups, bias=False, dilation=dilation)
17 |
18 |
19 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
20 | """1x1 convolution"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion: int = 1
26 |
27 | def __init__(
28 | self,
29 | inplanes: int,
30 | planes: int,
31 | stride: int = 1,
32 | downsample: Optional[nn.Module] = None,
33 | groups: int = 1,
34 | base_width: int = 64,
35 | dilation: int = 1,
36 | norm_layer: Optional[Callable[..., nn.Module]] = None
37 | ) -> None:
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError(
43 | 'BasicBlock only supports groups=1 and base_width=64')
44 | if dilation > 1:
45 | raise NotImplementedError(
46 | "Dilation > 1 not supported in BasicBlock")
47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
48 | self.conv1 = conv3x3(inplanes, planes, stride)
49 | self.bn1 = norm_layer(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes)
52 | self.bn2 = norm_layer(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x: Tensor) -> Tensor:
57 | identity = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | identity = self.downsample(x)
68 |
69 | out += identity
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
77 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
79 | # This variant is also known as ResNet V1.5 and improves accuracy according to
80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81 |
82 | expansion: int = 4
83 |
84 | def __init__(
85 | self,
86 | inplanes: int,
87 | planes: int,
88 | stride: int = 1,
89 | downsample: Optional[nn.Module] = None,
90 | groups: int = 1,
91 | base_width: int = 64,
92 | dilation: int = 1,
93 | norm_layer: Optional[Callable[..., nn.Module]] = None
94 | ) -> None:
95 | super(Bottleneck, self).__init__()
96 | if norm_layer is None:
97 | norm_layer = nn.BatchNorm2d
98 | width = int(planes * (base_width / 64.)) * groups
99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
100 | self.conv1 = conv1x1(inplanes, width)
101 | self.bn1 = norm_layer(width)
102 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
103 | self.bn2 = norm_layer(width)
104 | self.conv3 = conv1x1(width, planes * self.expansion)
105 | self.bn3 = norm_layer(planes * self.expansion)
106 | self.relu = nn.ReLU(inplace=True)
107 | self.downsample = downsample
108 | self.stride = stride
109 |
110 | def forward(self, x: Tensor) -> Tensor:
111 | identity = x
112 |
113 | out = self.conv1(x)
114 | out = self.bn1(out)
115 | out = self.relu(out)
116 |
117 | out = self.conv2(out)
118 | out = self.bn2(out)
119 | out = self.relu(out)
120 |
121 | out = self.conv3(out)
122 | out = self.bn3(out)
123 |
124 | if self.downsample is not None:
125 | identity = self.downsample(x)
126 |
127 | out += identity
128 | out = self.relu(out)
129 |
130 | return out
131 |
132 |
133 | class ResNet(nn.Module):
134 |
135 | def __init__(
136 | self,
137 | block: Type[Union[BasicBlock, Bottleneck]],
138 | layers: List[int],
139 | num_channels: List[int] = [1, 32, 64, 128, 256],
140 | embedding_dim: int = 256,
141 | n_mels: int = 80,
142 | pooling_type="TSP",
143 | zero_init_residual: bool = False,
144 | groups: int = 1,
145 | width_per_group: int = 64,
146 | replace_stride_with_dilation: Optional[List[bool]] = None,
147 | norm_layer: Optional[Callable[..., nn.Module]] = None,
148 | **kwargs
149 | ) -> None:
150 | super(ResNet, self).__init__()
151 | if norm_layer is None:
152 | norm_layer = nn.BatchNorm2d
153 | self._norm_layer = norm_layer
154 |
155 | self.inplanes = 64
156 | self.dilation = 1
157 | if replace_stride_with_dilation is None:
158 | # each element in the tuple indicates if we should replace
159 | # the 2x2 stride with a dilated convolution instead
160 | replace_stride_with_dilation = [False, False, False]
161 | if len(replace_stride_with_dilation) != 3:
162 | raise ValueError("replace_stride_with_dilation should be None "
163 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
164 | self.groups = groups
165 | self.base_width = width_per_group
166 | self.conv1 = nn.Conv2d(num_channels[0], self.inplanes, kernel_size=3, stride=1, padding=1,
167 | bias=False)
168 | self.bn1 = norm_layer(self.inplanes)
169 | self.relu = nn.ReLU(inplace=True)
170 | self.layer1 = self._make_layer(block, num_channels[1], layers[0])
171 | self.layer2 = self._make_layer(block, num_channels[2], layers[1], stride=2,
172 | dilate=replace_stride_with_dilation[0])
173 | self.layer3 = self._make_layer(block, num_channels[3], layers[2], stride=2,
174 | dilate=replace_stride_with_dilation[1])
175 | self.layer4 = self._make_layer(block, num_channels[4], layers[3], stride=2,
176 | dilate=replace_stride_with_dilation[2])
177 |
178 | out_dim = num_channels[4] * block.expansion * (n_mels//8)
179 | if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP":
180 | self.pooling = Temporal_Average_Pooling()
181 | self.fc = nn.Linear(out_dim, embedding_dim)
182 |
183 | elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP":
184 | self.pooling = Temporal_Statistics_Pooling()
185 | self.fc = nn.Linear(out_dim*2, embedding_dim)
186 |
187 | elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP":
188 | self.pooling = Self_Attentive_Pooling(out_dim)
189 | self.fc = nn.Linear(out_dim, embedding_dim)
190 |
191 | elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP":
192 | self.pooling = Attentive_Statistics_Pooling(out_dim)
193 | self.fc = nn.Linear(out_dim*2, embedding_dim)
194 |
195 | else:
196 | raise ValueError(
197 | '{} pooling type is not defined'.format(pooling_type))
198 |
199 | print("resnet num_channels: {}".format(num_channels))
200 | print("n_mels: {}".format(n_mels))
201 | print("embedding_dim: {}".format(embedding_dim))
202 | print("pooling_type: {}".format(pooling_type))
203 |
204 | for m in self.modules():
205 | if isinstance(m, nn.Conv2d):
206 | nn.init.kaiming_normal_(
207 | m.weight, mode='fan_out', nonlinearity='relu')
208 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
209 | nn.init.constant_(m.weight, 1)
210 | nn.init.constant_(m.bias, 0)
211 |
212 | # Zero-initialize the last BN in each residual branch,
213 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
214 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
215 | if zero_init_residual:
216 | for m in self.modules():
217 | if isinstance(m, Bottleneck):
218 | # type: ignore[arg-type]
219 | nn.init.constant_(m.bn3.weight, 0)
220 | elif isinstance(m, BasicBlock):
221 | # type: ignore[arg-type]
222 | nn.init.constant_(m.bn2.weight, 0)
223 |
224 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
225 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
226 | norm_layer = self._norm_layer
227 | downsample = None
228 | previous_dilation = self.dilation
229 | if dilate:
230 | self.dilation *= stride
231 | stride = 1
232 | if stride != 1 or self.inplanes != planes * block.expansion:
233 | downsample = nn.Sequential(
234 | conv1x1(self.inplanes, planes * block.expansion, stride),
235 | norm_layer(planes * block.expansion),
236 | )
237 |
238 | layers = []
239 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
240 | self.base_width, previous_dilation, norm_layer))
241 | self.inplanes = planes * block.expansion
242 | for _ in range(1, blocks):
243 | layers.append(block(self.inplanes, planes, groups=self.groups,
244 | base_width=self.base_width, dilation=self.dilation,
245 | norm_layer=norm_layer))
246 |
247 | return nn.Sequential(*layers)
248 |
249 | def _forward_impl(self, x: Tensor) -> Tensor:
250 | # See note [TorchScript super()]
251 | x = self.conv1(x)
252 | x = self.bn1(x)
253 | x = self.relu(x)
254 |
255 | x = self.layer1(x)
256 | x = self.layer2(x)
257 | x = self.layer3(x)
258 | x = self.layer4(x)
259 |
260 | x = x.reshape(x.shape[0], -1, x.shape[-1])
261 |
262 | x = self.pooling(x)
263 |
264 | x = torch.flatten(x, 1)
265 | x = self.fc(x)
266 |
267 | return x
268 |
269 | def forward(self, x: Tensor) -> Tensor:
270 | return self._forward_impl(x)
271 |
272 |
273 | def _resnet(
274 | arch: str,
275 | block: Type[Union[BasicBlock, Bottleneck]],
276 | layers: List[int],
277 | **kwargs: Any
278 | ) -> ResNet:
279 | model = ResNet(block, layers, **kwargs)
280 | return model
281 |
282 |
283 | def resnet18(**kwargs: Any) -> ResNet:
284 | r"""ResNet-18 model from
285 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
286 | Args:
287 | **kwargs: Any
288 | """
289 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], num_channels=[1, 64, 128, 256, 512], **kwargs)
290 |
291 |
292 | def resnet34(**kwargs: Any) -> ResNet:
293 | r"""ResNet-34 model from
294 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
295 |
296 | Args:
297 | **kwargs: Any
298 | """
299 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], **kwargs)
300 |
301 |
302 | def resnet34_large(**kwargs: Any) -> ResNet:
303 | r"""ResNet-34 model from
304 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
305 |
306 | Args:
307 | **kwargs: Any
308 | """
309 | model = _resnet('resnet34', BasicBlock, [3, 4, 6, 3], num_channels=[1, 64, 128, 256, 512], **kwargs)
310 | return model
311 |
312 | def resnet50(**kwargs: Any) -> ResNet:
313 | r"""ResNet-50 model from
314 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
315 |
316 | Args:
317 | **kwargs: Any
318 | """
319 | model = _resnet('resnet50', Bottleneck, [3, 4, 6, 3], num_channels=[1, 64, 128, 256, 512], **kwargs)
320 | return model
321 |
322 |
323 | def resnext50_32x4d(**kwargs: Any) -> ResNet:
324 | r"""ResNeXt-50 32x4d model from
325 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
326 |
327 | Args:
328 | **kwargs: Any
329 | """
330 | kwargs['groups'] = 32
331 | kwargs['width_per_group'] = 4
332 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], **kwargs)
333 |
334 |
335 |
336 |
--------------------------------------------------------------------------------
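A quick way to sanity-check the speaker-embedding ResNet in `module/resnet.py` above is to push a random mel-spectrogram batch through it. This is a hypothetical smoke test (the `(batch, 1, n_mels, frames)` layout follows the reshape/pooling logic in `_forward_impl`; the 200-frame length and batch size are arbitrary):

    import torch
    from module.resnet import resnet34

    # 2 utterances, 1 channel, 80 mel bins, 200 frames (all values arbitrary)
    model = resnet34(embedding_dim=256, n_mels=80, pooling_type="TSP")
    feats = torch.randn(2, 1, 80, 200)
    emb = model(feats)
    print(emb.shape)  # expected: torch.Size([2, 256])
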
/module/transformer_cat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wenet.transformer.encoder_cat import TransformerEncoder
3 | from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling
4 | from speechbrain.lobes.models.ECAPA_TDNN import BatchNorm1d
5 |
6 | class Transformer(torch.nn.Module):
7 | def __init__(self, n_mels=80, num_blocks=6, output_size=256, embedding_dim=192, input_layer="conv2d2",
8 | pos_enc_layer_type="rel_pos"):
9 |
10 | super(Transformer, self).__init__()
11 | print("input_layer: {}".format(input_layer))
12 | print("pos_enc_layer_type: {}".format(pos_enc_layer_type))
13 | self.conformer = TransformerEncoder(input_size=n_mels, num_blocks=num_blocks,
14 | output_size=output_size, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
15 | self.pooling = AttentiveStatisticsPooling(output_size*num_blocks)
16 | self.bn = BatchNorm1d(input_size=output_size*num_blocks*2)
17 | self.fc = torch.nn.Linear(output_size*num_blocks*2, embedding_dim)
18 |
19 | def forward(self, feat):
20 | feat = feat.squeeze(1).permute(0, 2, 1)
21 | lens = torch.ones(feat.shape[0]).to(feat.device)
22 | lens = torch.round(lens*feat.shape[1]).int()
23 | x, masks = self.conformer(feat, lens)
24 | x = x.permute(0, 2, 1)
25 | x = self.pooling(x)
26 | x = self.bn(x)
27 | x = x.permute(0, 2, 1)
28 | x = self.fc(x)
29 | x = x.squeeze(1)
30 | return x
31 |
32 | def transformer_cat(n_mels=80, num_blocks=6, output_size=256,
33 | embedding_dim=192, input_layer="conv2d", pos_enc_layer_type="rel_pos"):
34 | model = Transformer(n_mels=n_mels, num_blocks=num_blocks, output_size=output_size,
35 | embedding_dim=embedding_dim, input_layer=input_layer, pos_enc_layer_type=pos_enc_layer_type)
36 | return model
37 |
38 |
39 |
--------------------------------------------------------------------------------
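A rough usage sketch for `module/transformer_cat.py` above, assuming speechbrain and the bundled wenet package are importable and the same `(batch, 1, n_mels, frames)` feature layout as the other encoders (all shapes are illustrative):

    import torch
    from module.transformer_cat import transformer_cat

    model = transformer_cat(n_mels=80, num_blocks=6, output_size=256,
                            embedding_dim=192, input_layer="conv2d2",
                            pos_enc_layer_type="rel_pos")
    feat = torch.randn(2, 1, 80, 300)   # (batch, 1, n_mels, frames)
    emb = model(feat)
    print(emb.shape)  # expected: torch.Size([2, 192])
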
/module/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def compute_dB(waveform):
5 | """
6 | Args:
7 |         waveform (torch.Tensor): Input waveform (#length).
8 |     Returns:
9 |         torch.Tensor: Signal energy in dB (a scalar tensor).
10 | """
11 | val = max(0.0, torch.mean(torch.pow(waveform, 2)))
12 | dB = 10*torch.log10(val+1e-4)
13 | return dB
14 |
15 | def compute_SNR(waveform, noise):
16 | """
17 | Args:
18 |         waveform, noise (numpy.array): Signal and noise waveforms (#length).
19 |     Returns:
20 |         float: Signal-to-noise ratio in dB.
21 | """
22 | SNR = 10*np.log10(np.mean(waveform**2)/np.mean(noise**2)+1e-9)
23 | return SNR
24 |
25 |
26 |
--------------------------------------------------------------------------------
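For reference, a small illustrative example of the two helpers in `module/utils.py` above on synthetic signals (the 16000-sample length and the 0.1 noise scale are arbitrary):

    import numpy as np
    import torch
    from module.utils import compute_dB, compute_SNR

    waveform = torch.randn(16000)            # ~1 s of unit-variance noise at 16 kHz
    print(compute_dB(waveform))              # roughly 0 dB for unit signal power

    clean = np.random.randn(16000)
    noise = 0.1 * np.random.randn(16000)
    print(compute_SNR(clean, noise))         # roughly 20 dB for a 10x amplitude ratio
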
/score/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine import cosine_score
2 | from .utils import compute_eer, compute_minDCF
3 |
--------------------------------------------------------------------------------
/score/cosine.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def cosine_score(trials, index_mapping, eval_vectors):
4 | labels = []
5 | scores = []
6 | for item in trials:
7 | enroll_vector = eval_vectors[index_mapping[item[1]]]
8 | test_vector = eval_vectors[index_mapping[item[2]]]
9 | score = enroll_vector.dot(test_vector.T)
10 | denom = np.linalg.norm(enroll_vector) * np.linalg.norm(test_vector)
11 | score = score/denom
12 | labels.append(int(item[0]))
13 | scores.append(score)
14 | return labels, scores
15 |
16 |
--------------------------------------------------------------------------------
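`cosine_score` above expects each trial as `(label, enroll_utt, test_utt)` and an `index_mapping` from utterance key to the row of its embedding in `eval_vectors`. A minimal illustrative call (the paths and the 192-dim embeddings are hypothetical):

    import numpy as np
    from score.cosine import cosine_score

    trials = [("1", "spk1/a.wav", "spk1/b.wav"),
              ("0", "spk1/a.wav", "spk2/c.wav")]
    index_mapping = {"spk1/a.wav": 0, "spk1/b.wav": 1, "spk2/c.wav": 2}
    eval_vectors = np.random.randn(3, 192)   # one embedding per utterance

    labels, scores = cosine_score(trials, index_mapping, eval_vectors)
    print(labels, scores)                    # [1, 0] and two cosine similarities
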
/score/utils.py:
--------------------------------------------------------------------------------
1 | from scipy.interpolate import interp1d
2 | from sklearn.metrics import roc_curve
3 | from scipy.optimize import brentq
4 |
5 | def compute_eer(labels, scores):
6 |     """Compute the equal error rate (EER) using sklearn's ROC curve.
7 | """
8 | fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
9 | eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
10 | threshold = interp1d(fpr, thresholds)(eer)
11 | return eer, threshold
12 |
13 |
14 | def compute_minDCF(labels, scores, p_target=0.01, c_miss=1, c_fa=1):
15 | """MinDCF
16 | Computes the minimum of the detection cost function. The comments refer to
17 | equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
18 | """
19 | fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
20 | fnr = 1.0 - tpr
21 |
22 | min_c_det = float("inf")
23 | min_c_det_threshold = thresholds[0]
24 | for i in range(0, len(fnr)):
25 | c_det = c_miss * fnr[i] * p_target + c_fa * fpr[i] * (1 - p_target)
26 | if c_det < min_c_det:
27 | min_c_det = c_det
28 | min_c_det_threshold = thresholds[i]
29 | c_def = min(c_miss * p_target, c_fa * (1 - p_target))
30 | min_dcf = min_c_det / c_def
31 | return min_dcf, min_c_det_threshold
32 |
33 |
--------------------------------------------------------------------------------
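A minimal sketch of how the two metrics above are typically called, on synthetic well-separated scores (the score distributions are made up purely for illustration):

    import numpy as np
    from score import compute_eer, compute_minDCF

    labels = [1] * 100 + [0] * 100
    scores = list(np.random.normal(0.8, 0.1, 100)) + list(np.random.normal(0.2, 0.1, 100))

    eer, eer_threshold = compute_eer(labels, scores)
    min_dcf, dcf_threshold = compute_minDCF(labels, scores, p_target=0.01)
    print("EER: {:.2f}%, minDCF(0.01): {:.3f}".format(eer * 100, min_dcf))
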
/scripts/build_datalist.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 | import argparse
5 | import os
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import tqdm
10 |
11 |
12 | def findAllSeqs(dirName,
13 | extension='.wav',
14 | load_data_list=False,
15 | speaker_level=1):
16 | r"""
17 | Lists all the sequences with the given extension in the dirName directory.
18 | Output:
19 | outSequences, speakers
20 | outSequence
21 | A list of tuples seq_path, speaker where:
22 | - seq_path is the relative path of each sequence relative to the
23 | parent directory
24 | - speaker is the corresponding speaker index
25 | outSpeakers
26 | The speaker labels (in order)
27 | The speaker labels are organized the following way
28 | \dirName
29 | \speaker_label
30 | \..
31 | ...
32 | seqName.extension
33 | Adjust the value of speaker_level if you want to choose which level of
34 | directory defines the speaker label. Ex if speaker_level == 2 then the
35 | dataset should be organized in the following fashion
36 | \dirName
37 | \crappy_label
38 | \speaker_label
39 | \..
40 | ...
41 | seqName.extension
42 |     Set speaker_level == 0 if no speaker label should be retrieved, regardless of
43 |     the organization of the dataset.
44 | """
45 | if dirName[-1] != os.sep:
46 | dirName += os.sep
47 | prefixSize = len(dirName)
48 | speakersTarget = {}
49 | outSequences = []
50 |     print("finding {} files, please wait...".format(extension))
51 | for root, dirs, filenames in tqdm.tqdm(os.walk(dirName, followlinks=True)):
52 | filtered_files = [f for f in filenames if f.endswith(extension)]
53 | if len(filtered_files) > 0:
54 | speakerStr = (os.sep).join(
55 | root[prefixSize:].split(os.sep)[:speaker_level])
56 | if speakerStr not in speakersTarget:
57 | speakersTarget[speakerStr] = len(speakersTarget)
58 | speaker = speakersTarget[speakerStr]
59 | for filename in filtered_files:
60 | full_path = os.path.join(root, filename)
61 | outSequences.append((speaker, full_path))
62 | outSpeakers = [None for x in speakersTarget]
63 |
64 | for key, index in speakersTarget.items():
65 | outSpeakers[index] = key
66 |
67 |     print("found {} speakers".format(len(outSpeakers)))
68 |     print("found {} utterances".format(len(outSequences)))
69 |
70 | return outSequences, outSpeakers
71 |
72 |
73 | if __name__ == "__main__":
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument(
76 | '--extension', help='file extension name', type=str, default="wav")
77 | parser.add_argument('--dataset_dir', help='dataset dir',
78 | type=str, default="data")
79 | parser.add_argument('--data_list_path',
80 | help='list save path', type=str, default="data_list")
81 |     parser.add_argument('--speaker_level',
82 |                         help='directory level that defines the speaker label', type=int, default=1)
83 | args = parser.parse_args()
84 |
85 | outSequences, outSpeakers = findAllSeqs(args.dataset_dir,
86 | extension=args.extension,
87 | load_data_list=False,
88 |                                             speaker_level=args.speaker_level)
89 |
90 | outSequences = np.array(outSequences, dtype=str)
91 | utt_spk_int_labels = outSequences.T[0].astype(int)
92 | utt_paths = outSequences.T[1]
93 | utt_spk_str_labels = []
94 | for i in utt_spk_int_labels:
95 | utt_spk_str_labels.append(outSpeakers[i])
96 |
97 | csv_dict = {"speaker_name": utt_spk_str_labels,
98 | "utt_paths": utt_paths,
99 | "utt_spk_int_labels": utt_spk_int_labels
100 | }
101 | df = pd.DataFrame(data=csv_dict)
102 |
103 | try:
104 | df.to_csv(args.data_list_path)
105 | print(f'Saved data list file at {args.data_list_path}')
106 | except OSError as err:
107 | print(f'Ran in an error while saving {args.data_list_path}: {err}')
108 |
--------------------------------------------------------------------------------
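`findAllSeqs` above can also be called directly from Python. A hypothetical example for a VoxCeleb-style layout `<root>/<speaker_id>/<session>/<utt>.wav` (the dataset path is made up, and the import assumes the snippet is run from the repository root):

    from scripts.build_datalist import findAllSeqs

    sequences, speakers = findAllSeqs("datasets/voxceleb2/dev/aac",
                                      extension=".wav", speaker_level=1)
    print(len(speakers), "speakers,", len(sequences), "utterances")
    print(sequences[0])   # (speaker_index, full_wav_path)
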
/scripts/format_trials.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 | import argparse
5 | import os
6 |
7 | import numpy as np
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--voxceleb1_root', help='voxceleb1_root', type=str,
12 | default="datasets/VoxCeleb/voxceleb1/")
13 | parser.add_argument('--src_trials_path', help='src_trials_path',
14 | type=str, default="voxceleb1_test_v2.txt")
15 | parser.add_argument('--dst_trials_path', help='dst_trials_path',
16 | type=str, default="data/trial.lst")
17 | args = parser.parse_args()
18 |
19 | trials = np.loadtxt(args.src_trials_path, dtype=str)
20 |
21 | f = open(args.dst_trials_path, "w")
22 | for item in trials:
23 | enroll_path = os.path.join(
24 | args.voxceleb1_root, "wav", item[1])
25 | test_path = os.path.join(args.voxceleb1_root, "wav", item[2])
26 |         f.write("{} {} {}\n".format(item[0], enroll_path, test_path))
27 |     f.close()
27 |
--------------------------------------------------------------------------------
/scripts/make_balanced_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import random
4 |
5 | def build_balance_data(args):
6 | data = pd.read_csv(args.data_path)
7 |
8 | labels = data["utt_spk_int_labels"].values
9 | name = data["speaker_name"].values
10 | paths = data["utt_paths"].values
11 | durations = data["durations"].values
12 |
13 |     # Index each column by speaker label, building dicts so entries can be looked up by the label value
14 | dict_name = {}
15 | dict_paths = {}
16 | dict_durations = {}
17 | for idx, label in enumerate(labels):
18 | if label not in dict_paths:
19 | dict_name[label] = name[idx]
20 | dict_paths[label] = []
21 | dict_durations[label] = []
22 |         if abs(durations[idx] - 9) < 3: # filter by duration so no single utterance is too long and lengths stay close to the average
23 | dict_paths[label].append(paths[idx])
24 | dict_durations[label].append(durations[idx])
25 |
26 |
27 |     # Randomly pick args.num_spk distinct speakers (labels) and store them in random_num_spk
28 | candi_spk = []
29 | for label in range(max(labels) + 1):
30 |         if args.utt_per_spk <= len(dict_paths[label]): # keep only speakers with enough utterances to sample from
31 | candi_spk.append(label)
32 | random_num_spk = random.sample(candi_spk, args.num_spk)
33 |
34 |
35 | result_name = []
36 | result_paths = []
37 | result_durations = []
38 | result_labels = []
39 | for label in random_num_spk: #dict_name[label] dict_paths[label] label dict_durations[label]
40 |         # For each selected speaker (label), randomly pick utt_per_spk distinct utterance indices into random_utt_per_spk
41 | candi_utt = [i for i in range(len(dict_paths[label]))]
42 | random_utt_per_spk = random.sample(candi_utt, args.utt_per_spk)
43 |         # save the results
44 | result_labels.extend([label] * args.utt_per_spk)
45 | for idx in random_utt_per_spk:
46 | result_name.append(dict_name[label])
47 | result_paths.append(dict_paths[label][idx])
48 | result_durations.append(dict_durations[label][idx])
49 |
50 | table = {}
51 | for idx, label in enumerate(set(result_labels)):
52 | table[label] = idx
53 |
54 | labels = []
55 | for label in result_labels:
56 | labels.append(table[label])
57 |
58 |     # write the result to a csv file
59 | dic = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': labels, 'durations': result_durations}
60 | df = pd.DataFrame(dic)
61 | df.to_csv(args.save_path)
62 |
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--data_path', type=str, default="data/train.csv")
67 | parser.add_argument('--save_path', type=str, default="balance.csv")
68 | parser.add_argument('--num_spk', type=int, default=1211)
69 | parser.add_argument('--utt_per_spk', type=int, default=122)
70 | args = parser.parse_args()
71 |
72 | build_balance_data(args)
73 |
--------------------------------------------------------------------------------
/scripts/make_cohort_set.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import numpy as np
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--data_list_path', type=str, default="data/train.csv")
8 | parser.add_argument('--cohort_save_path', type=str, default="data/cohort.csv")
9 | parser.add_argument('--num_cohort', type=int, default=3000)
10 | args = parser.parse_args()
11 |
12 | data = pd.read_csv(args.data_list_path)
13 | utt_paths = data["utt_paths"].values
14 | np.random.shuffle(utt_paths)
15 | utt_paths = utt_paths[:args.num_cohort]
16 | with open(args.cohort_save_path, "w") as f:
17 | for item in utt_paths:
18 | f.write(item)
19 | f.write("\n")
20 |
--------------------------------------------------------------------------------
/scripts/make_tsne_set.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import numpy as np
4 | import random
5 |
6 | def random_dataset(args):
7 | data = pd.read_csv(args.data_list_path)
8 |
9 | labels = data["utt_spk_int_labels"].values
10 | name = data["speaker_name"].values
11 | paths = data["utt_paths"].values
12 | #durations = data["durations"].values
13 |
14 |     # Index each column by speaker label, building dicts so entries can be looked up by the label value
15 | dict_name = {}
16 | dict_paths = {}
17 | #dict_durations = {}
18 | for idx, label in enumerate(labels):
19 | if label not in dict_paths:
20 | dict_name[label] = name[idx]
21 | dict_paths[label] = []
22 | #dict_durations[label] = []
23 | dict_paths[label].append(paths[idx])
24 | #dict_durations[label].append(durations[idx])
25 |
26 |
27 |     # Randomly pick args.num_spk distinct speakers (labels) and store them in random_num_spk
28 | candi_spk = []
29 | for label in range(max(labels) + 1):
30 |         if args.utt_per_spk <= len(dict_paths[label]): # keep only speakers with enough utterances to sample from
31 | candi_spk.append(label)
32 |
33 | random_num_spk = random.sample(candi_spk, args.num_spk)
34 |
35 |
36 | result_name = []
37 | result_paths = []
38 | #result_durations = []
39 | result_labels = []
40 | for label in random_num_spk: #dict_name[label] dict_paths[label] label dict_durations[label]
41 |         # For each selected speaker (label), randomly pick utt_per_spk distinct utterance indices into random_utt_per_spk
42 | candi_utt = [i for i in range(len(dict_paths[label]))]
43 | random_utt_per_spk = random.sample(candi_utt, args.utt_per_spk)
44 |         # save the results
45 | result_labels.extend([label] * args.utt_per_spk)
46 | for idx in random_utt_per_spk:
47 | result_name.append(dict_name[label])
48 | result_paths.append(dict_paths[label][idx])
49 | #result_durations.append(dict_durations[label][idx])
50 |
51 |     # write the result to a csv file
52 | #dict = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': result_labels, 'durations': result_durations}
53 |
54 | label_set = set(result_labels)
55 | table = {}
56 | for idx, s in enumerate(label_set):
57 | table[s] = idx
58 |
59 | new_labels = []
60 | for label in result_labels:
61 | new_labels.append(table[label])
62 |
63 | dic = {'speaker_name': result_name, 'utt_paths': result_paths, 'utt_spk_int_labels': new_labels}
64 | df = pd.DataFrame(dic)
65 | df.to_csv(args.tsne_set_save_path)
66 |
67 |
68 | if __name__ == "__main__":
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('--data_list_path', type=str, default="data.csv")
71 | parser.add_argument('--tsne_set_save_path', type=str, default="tsne.csv")
72 | parser.add_argument('--num_spk', type=int, default=20)
73 | parser.add_argument('--utt_per_spk', type=int, default=200)
74 | args = parser.parse_args()
75 |
76 | random_dataset(args)
77 |
--------------------------------------------------------------------------------
/scripts/plot_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | labels, scores = np.loadtxt("score.txt").T
5 |
6 | target_score = []
7 | nontarget_score = []
8 |
9 | for idx,i in enumerate(labels):
10 | if i == 0:
11 | nontarget_score.append(scores[idx])
12 | else:
13 | target_score.append(scores[idx])
14 |
15 | print(scores.shape)
16 | print(labels.shape)
17 |
18 | plt.hist(target_score, bins=100, label="target score")
19 | plt.hist(nontarget_score, bins=100, label="nontarget score")
20 | plt.legend()
21 | plt.tight_layout()
22 | plt.savefig("test.png")
23 |
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | encoder_name="conformer_cat" # conformer_cat | ecapa_tdnn_large | resnet34
2 | embedding_dim=192
3 | loss_name="amsoftmax"
4 |
5 | dataset="vox"
6 | num_classes=7205
7 | num_blocks=6
8 | train_csv_path="data/train.csv"
9 |
10 | input_layer=conv2d2
11 | pos_enc_layer_type=rel_pos # no_pos| rel_pos
12 | save_dir=experiment/${input_layer}/${encoder_name}_${num_blocks}_${embedding_dim}_${loss_name}
13 | trial_path=data/vox1_test.txt
14 |
15 | mkdir -p $save_dir
16 | cp start.sh $save_dir
17 | cp main.py $save_dir
18 | cp -r module $save_dir
19 | cp -r wenet $save_dir
20 | cp -r scripts $save_dir
21 | cp -r loss $save_dir
22 | echo save_dir: $save_dir
23 |
24 | export CUDA_VISIBLE_DEVICES=0
25 | python3 main.py \
26 | --batch_size 200 \
27 | --num_workers 40 \
28 | --max_epochs 30 \
29 | --embedding_dim $embedding_dim \
30 | --save_dir $save_dir \
31 | --encoder_name $encoder_name \
32 | --train_csv_path $train_csv_path \
33 | --learning_rate 0.001 \
34 | --encoder_name ${encoder_name} \
35 | --num_classes $num_classes \
36 | --trial_path $trial_path \
37 | --loss_name $loss_name \
38 | --num_blocks $num_blocks \
39 | --step_size 4 \
40 | --gamma 0.5 \
41 | --weight_decay 0.0000001 \
42 | --input_layer $input_layer \
43 | --pos_enc_layer_type $pos_enc_layer_type
44 |
45 |
--------------------------------------------------------------------------------
/wenet/transformer/attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 | """Multi-Head Attention layer definition."""
7 |
8 | import math
9 | from typing import Optional, Tuple
10 |
11 | import torch
12 | from torch import nn
13 |
14 |
15 | class MultiHeadedAttention(nn.Module):
16 | """Multi-Head Attention layer.
17 |
18 | Args:
19 | n_head (int): The number of heads.
20 | n_feat (int): The number of features.
21 | dropout_rate (float): Dropout rate.
22 |
23 | """
24 | def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
25 |         """Construct a MultiHeadedAttention object."""
26 | super().__init__()
27 | assert n_feat % n_head == 0
28 | # We assume d_v always equals d_k
29 | self.d_k = n_feat // n_head
30 | self.h = n_head
31 | self.linear_q = nn.Linear(n_feat, n_feat)
32 | self.linear_k = nn.Linear(n_feat, n_feat)
33 | self.linear_v = nn.Linear(n_feat, n_feat)
34 | self.linear_out = nn.Linear(n_feat, n_feat)
35 | self.dropout = nn.Dropout(p=dropout_rate)
36 |
37 | def forward_qkv(
38 | self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
39 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
40 | """Transform query, key and value.
41 |
42 | Args:
43 | query (torch.Tensor): Query tensor (#batch, time1, size).
44 | key (torch.Tensor): Key tensor (#batch, time2, size).
45 | value (torch.Tensor): Value tensor (#batch, time2, size).
46 |
47 | Returns:
48 | torch.Tensor: Transformed query tensor, size
49 | (#batch, n_head, time1, d_k).
50 | torch.Tensor: Transformed key tensor, size
51 | (#batch, n_head, time2, d_k).
52 | torch.Tensor: Transformed value tensor, size
53 | (#batch, n_head, time2, d_k).
54 |
55 | """
56 | n_batch = query.size(0)
57 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
58 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
59 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
60 | q = q.transpose(1, 2) # (batch, head, time1, d_k)
61 | k = k.transpose(1, 2) # (batch, head, time2, d_k)
62 | v = v.transpose(1, 2) # (batch, head, time2, d_k)
63 |
64 | return q, k, v
65 |
66 | def forward_attention(self, value: torch.Tensor, scores: torch.Tensor,
67 | mask: Optional[torch.Tensor]) -> torch.Tensor:
68 | """Compute attention context vector.
69 |
70 | Args:
71 | value (torch.Tensor): Transformed value, size
72 | (#batch, n_head, time2, d_k).
73 | scores (torch.Tensor): Attention score, size
74 | (#batch, n_head, time1, time2).
75 | mask (torch.Tensor): Mask, size (#batch, 1, time2) or
76 | (#batch, time1, time2).
77 |
78 | Returns:
79 | torch.Tensor: Transformed value (#batch, time1, d_model)
80 | weighted by the attention score (#batch, time1, time2).
81 |
82 | """
83 | n_batch = value.size(0)
84 | if mask is not None:
85 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
86 | scores = scores.masked_fill(mask, -float('inf'))
87 | attn = torch.softmax(scores, dim=-1).masked_fill(
88 | mask, 0.0) # (batch, head, time1, time2)
89 | else:
90 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
91 |
92 | p_attn = self.dropout(attn)
93 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
94 | x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
95 | self.h * self.d_k)
96 | ) # (batch, time1, d_model)
97 |
98 | return self.linear_out(x) # (batch, time1, d_model)
99 |
100 | def forward(self, query: torch.Tensor, key: torch.Tensor,
101 | value: torch.Tensor,
102 | mask: Optional[torch.Tensor],
103 | pos_emb: torch.Tensor = torch.empty(0),) -> torch.Tensor:
104 | """Compute scaled dot product attention.
105 |
106 | Args:
107 | query (torch.Tensor): Query tensor (#batch, time1, size).
108 | key (torch.Tensor): Key tensor (#batch, time2, size).
109 | value (torch.Tensor): Value tensor (#batch, time2, size).
110 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
111 | (#batch, time1, time2).
112 | 1.When applying cross attention between decoder and encoder,
113 | the batch padding mask for input is in (#batch, 1, T) shape.
114 | 2.When applying self attention of encoder,
115 | the mask is in (#batch, T, T) shape.
116 | 3.When applying self attention of decoder,
117 | the mask is in (#batch, L, L) shape.
118 | 4.If the different position in decoder see different block
119 | of the encoder, such as Mocha, the passed in mask could be
120 | in (#batch, L, T) shape. But there is no such case in current
121 | Wenet.
122 |
123 |
124 | Returns:
125 | torch.Tensor: Output tensor (#batch, time1, d_model).
126 |
127 | """
128 | q, k, v = self.forward_qkv(query, key, value)
129 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
130 | return self.forward_attention(v, scores, mask)
131 |
132 |
133 | class RelPositionMultiHeadedAttention(MultiHeadedAttention):
134 | """Multi-Head Attention layer with relative position encoding.
135 | Paper: https://arxiv.org/abs/1901.02860
136 | Args:
137 | n_head (int): The number of heads.
138 | n_feat (int): The number of features.
139 | dropout_rate (float): Dropout rate.
140 | """
141 | def __init__(self, n_head, n_feat, dropout_rate):
142 |         """Construct a RelPositionMultiHeadedAttention object."""
143 | super().__init__(n_head, n_feat, dropout_rate)
144 | # linear transformation for positional encoding
145 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
146 |         # these two learnable biases are used in matrix c and matrix d
147 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3
148 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
149 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
150 | torch.nn.init.xavier_uniform_(self.pos_bias_u)
151 | torch.nn.init.xavier_uniform_(self.pos_bias_v)
152 |
153 | def rel_shift(self, x, zero_triu: bool = False):
154 |         """Compute relative positional encoding.
155 | Args:
156 | x (torch.Tensor): Input tensor (batch, time, size).
157 | zero_triu (bool): If true, return the lower triangular part of
158 | the matrix.
159 | Returns:
160 | torch.Tensor: Output tensor.
161 | """
162 |
163 | zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
164 | device=x.device,
165 | dtype=x.dtype)
166 | x_padded = torch.cat([zero_pad, x], dim=-1)
167 |
168 | x_padded = x_padded.view(x.size()[0],
169 | x.size()[1],
170 | x.size(3) + 1, x.size(2))
171 | x = x_padded[:, :, 1:].view_as(x)
172 |
173 | if zero_triu:
174 | ones = torch.ones((x.size(2), x.size(3)))
175 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
176 |
177 | return x
178 |
179 | def forward(self, query: torch.Tensor, key: torch.Tensor,
180 | value: torch.Tensor, mask: Optional[torch.Tensor],
181 | pos_emb: torch.Tensor):
182 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
183 | Args:
184 | query (torch.Tensor): Query tensor (#batch, time1, size).
185 | key (torch.Tensor): Key tensor (#batch, time2, size).
186 | value (torch.Tensor): Value tensor (#batch, time2, size).
187 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
188 | (#batch, time1, time2).
189 | pos_emb (torch.Tensor): Positional embedding tensor
190 | (#batch, time2, size).
191 | Returns:
192 | torch.Tensor: Output tensor (#batch, time1, d_model).
193 | """
194 | q, k, v = self.forward_qkv(query, key, value)
195 | q = q.transpose(1, 2) # (batch, time1, head, d_k)
196 |
197 | n_batch_pos = pos_emb.size(0)
198 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
199 | p = p.transpose(1, 2) # (batch, head, time1, d_k)
200 |
201 | # (batch, head, time1, d_k)
202 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
203 | # (batch, head, time1, d_k)
204 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
205 |
206 | # compute attention score
207 | # first compute matrix a and matrix c
208 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3
209 | # (batch, head, time1, time2)
210 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
211 |
212 | # compute matrix b and matrix d
213 | # (batch, head, time1, time2)
214 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
215 | # Remove rel_shift since it is useless in speech recognition,
216 | # and it requires special attention for streaming.
217 | # matrix_bd = self.rel_shift(matrix_bd)
218 |
219 | scores = (matrix_ac + matrix_bd) / math.sqrt(
220 | self.d_k) # (batch, head, time1, time2)
221 |
222 | return self.forward_attention(v, scores, mask)
223 |
--------------------------------------------------------------------------------
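A small shape check for the two attention variants above (batch size, sequence length, and model dimension are arbitrary; the relative-position variant takes its `pos_emb` from `RelPositionalEncoding`):

    import torch
    from wenet.transformer.attention import MultiHeadedAttention
    from wenet.transformer.attention import RelPositionMultiHeadedAttention
    from wenet.transformer.embedding import RelPositionalEncoding

    x = torch.randn(2, 50, 256)
    mask = torch.ones(2, 1, 50, dtype=torch.bool)   # no padding

    mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    print(mha(x, x, x, mask).shape)                 # torch.Size([2, 50, 256])

    pos_enc = RelPositionalEncoding(256, dropout_rate=0.1)
    _, pos_emb = pos_enc(x)                         # (1, 50, 256)
    rel_mha = RelPositionMultiHeadedAttention(4, 256, 0.1)
    print(rel_mha(x, x, x, mask, pos_emb).shape)    # torch.Size([2, 50, 256])
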
/wenet/transformer/cmvn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 |
19 | class GlobalCMVN(torch.nn.Module):
20 | def __init__(self,
21 | mean: torch.Tensor,
22 | istd: torch.Tensor,
23 | norm_var: bool = True):
24 | """
25 | Args:
26 | mean (torch.Tensor): mean stats
27 |             istd (torch.Tensor): inverse std, i.e. 1.0 / std
28 | """
29 | super().__init__()
30 | assert mean.shape == istd.shape
31 | self.norm_var = norm_var
32 | # The buffer can be accessed from this module using self.mean
33 | self.register_buffer("mean", mean)
34 | self.register_buffer("istd", istd)
35 |
36 | def forward(self, x: torch.Tensor):
37 | """
38 | Args:
39 | x (torch.Tensor): (batch, max_len, feat_dim)
40 |
41 | Returns:
42 | (torch.Tensor): normalized feature
43 | """
44 | x = x - self.mean
45 | if self.norm_var:
46 | x = x * self.istd
47 | return x
48 |
--------------------------------------------------------------------------------
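Illustrative use of `GlobalCMVN` above with dummy statistics (in the real pipeline `mean` and `istd` come from precomputed CMVN stats; the 80-dim, 200-frame shapes are arbitrary):

    import torch
    from wenet.transformer.cmvn import GlobalCMVN

    mean = torch.zeros(80)
    istd = torch.ones(80)
    cmvn = GlobalCMVN(mean, istd)

    feats = torch.randn(4, 200, 80)   # (batch, frames, feat_dim)
    print(cmvn(feats).shape)          # torch.Size([4, 200, 80])
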
/wenet/transformer/convolution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2021 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """ConvolutionModule definition."""
7 |
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 | from torch import nn
12 | from typeguard import check_argument_types
13 |
14 |
15 | class ConvolutionModule(nn.Module):
16 | """ConvolutionModule in Conformer model."""
17 | def __init__(self,
18 | channels: int,
19 | kernel_size: int = 15,
20 | activation: nn.Module = nn.ReLU(),
21 | norm: str = "batch_norm",
22 | causal: bool = False,
23 | bias: bool = True):
24 |         """Construct a ConvolutionModule object.
25 | Args:
26 | channels (int): The number of channels of conv layers.
27 | kernel_size (int): Kernel size of conv layers.
28 |             causal (bool): Whether to use causal convolution or not
29 | """
30 | assert check_argument_types()
31 | super().__init__()
32 |
33 | self.pointwise_conv1 = nn.Conv1d(
34 | channels,
35 | 2 * channels,
36 | kernel_size=1,
37 | stride=1,
38 | padding=0,
39 | bias=bias,
40 | )
41 | # self.lorder is used to distinguish if it's a causal convolution,
42 | # if self.lorder > 0: it's a causal convolution, the input will be
43 | # padded with self.lorder frames on the left in forward.
44 | # else: it's a symmetrical convolution
45 | if causal:
46 | padding = 0
47 | self.lorder = kernel_size - 1
48 | else:
49 |             # kernel_size should be an odd number for a non-causal convolution
50 | assert (kernel_size - 1) % 2 == 0
51 | padding = (kernel_size - 1) // 2
52 | self.lorder = 0
53 | self.depthwise_conv = nn.Conv1d(
54 | channels,
55 | channels,
56 | kernel_size,
57 | stride=1,
58 | padding=padding,
59 | groups=channels,
60 | bias=bias,
61 | )
62 |
63 | assert norm in ['batch_norm', 'layer_norm']
64 | if norm == "batch_norm":
65 | self.use_layer_norm = False
66 | self.norm = nn.BatchNorm1d(channels)
67 | else:
68 | self.use_layer_norm = True
69 | self.norm = nn.LayerNorm(channels)
70 |
71 | self.pointwise_conv2 = nn.Conv1d(
72 | channels,
73 | channels,
74 | kernel_size=1,
75 | stride=1,
76 | padding=0,
77 | bias=bias,
78 | )
79 | self.activation = activation
80 |
81 | def forward(
82 | self,
83 | x: torch.Tensor,
84 | mask_pad: Optional[torch.Tensor] = None,
85 | cache: Optional[torch.Tensor] = None,
86 | ) -> Tuple[torch.Tensor, torch.Tensor]:
87 | """Compute convolution module.
88 | Args:
89 | x (torch.Tensor): Input tensor (#batch, time, channels).
90 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time)
91 | cache (torch.Tensor): left context cache, it is only
92 | used in causal convolution
93 | Returns:
94 | torch.Tensor: Output tensor (#batch, time, channels).
95 | """
96 | # exchange the temporal dimension and the feature dimension
97 | x = x.transpose(1, 2) # (#batch, channels, time)
98 |
99 | # mask batch padding
100 | if mask_pad is not None:
101 | x.masked_fill_(~mask_pad, 0.0)
102 |
103 | if self.lorder > 0:
104 | if cache is None:
105 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
106 | else:
107 | assert cache.size(0) == x.size(0)
108 | assert cache.size(1) == x.size(1)
109 | x = torch.cat((cache, x), dim=2)
110 | assert (x.size(2) > self.lorder)
111 | new_cache = x[:, :, -self.lorder:]
112 | else:
113 |             # It would be better to just return None if no cache is required;
114 |             # however, for JIT export, we just fake one tensor instead of
115 | # None.
116 | new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
117 |
118 | # GLU mechanism
119 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
120 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
121 |
122 | # 1D Depthwise Conv
123 | x = self.depthwise_conv(x)
124 | if self.use_layer_norm:
125 | x = x.transpose(1, 2)
126 | x = self.activation(self.norm(x))
127 | if self.use_layer_norm:
128 | x = x.transpose(1, 2)
129 | x = self.pointwise_conv2(x)
130 | # mask batch padding
131 | if mask_pad is not None:
132 | x.masked_fill_(~mask_pad, 0.0)
133 |
134 | return x.transpose(1, 2), new_cache
135 |
--------------------------------------------------------------------------------
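A quick shape check of the Conformer convolution module above (the channel count, length, and all-true padding mask are illustrative):

    import torch
    from wenet.transformer.convolution import ConvolutionModule

    conv = ConvolutionModule(channels=256, kernel_size=15)
    x = torch.randn(2, 50, 256)                        # (batch, time, channels)
    mask_pad = torch.ones(2, 1, 50, dtype=torch.bool)  # no padding
    y, _ = conv(x, mask_pad)
    print(y.shape)                                     # torch.Size([2, 50, 256])
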
/wenet/transformer/embedding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Positonal Encoding Module."""
7 |
8 | import math
9 | from typing import Tuple
10 |
11 | import torch
12 |
13 |
14 | class PositionalEncoding(torch.nn.Module):
15 | """Positional encoding.
16 |
17 | :param int d_model: embedding dim
18 | :param float dropout_rate: dropout rate
19 | :param int max_len: maximum input length
20 |
21 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
22 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
23 | """
24 | def __init__(self,
25 | d_model: int,
26 | dropout_rate: float,
27 | max_len: int = 50000,
28 | reverse: bool = False):
29 |         """Construct a PositionalEncoding object."""
30 | super().__init__()
31 | self.d_model = d_model
32 | self.xscale = math.sqrt(self.d_model)
33 | self.dropout = torch.nn.Dropout(p=dropout_rate)
34 | self.max_len = max_len
35 |
36 | self.pe = torch.zeros(self.max_len, self.d_model)
37 | position = torch.arange(0, self.max_len,
38 | dtype=torch.float32).unsqueeze(1)
39 | div_term = torch.exp(
40 | torch.arange(0, self.d_model, 2, dtype=torch.float32) *
41 | -(math.log(10000.0) / self.d_model))
42 | self.pe[:, 0::2] = torch.sin(position * div_term)
43 | self.pe[:, 1::2] = torch.cos(position * div_term)
44 | self.pe = self.pe.unsqueeze(0)
45 |
46 | def forward(self,
47 | x: torch.Tensor,
48 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
49 | """Add positional encoding.
50 |
51 | Args:
52 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
53 | offset (int): position offset
54 |
55 | Returns:
56 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
57 | torch.Tensor: for compatibility to RelPositionalEncoding
58 | """
59 | assert offset + x.size(1) < self.max_len
60 | self.pe = self.pe.to(x.device)
61 | pos_emb = self.pe[:, offset:offset + x.size(1)]
62 | x = x * self.xscale + pos_emb
63 | return self.dropout(x), self.dropout(pos_emb)
64 |
65 | def position_encoding(self, offset: int, size: int) -> torch.Tensor:
66 | """ For getting encoding in a streaming fashion
67 |
68 | Attention!!!!!
69 |         we apply dropout only once at the whole utterance level in a
70 |         non-streaming way, but will call this function several times with
71 | increasing input size in a streaming scenario, so the dropout will
72 | be applied several times.
73 |
74 | Args:
75 | offset (int): start offset
76 |             size (int): required size of position encoding
77 |
78 | Returns:
79 | torch.Tensor: Corresponding encoding
80 | """
81 | assert offset + size < self.max_len
82 | return self.dropout(self.pe[:, offset:offset + size])
83 |
84 |
85 | class RelPositionalEncoding(PositionalEncoding):
86 | """Relative positional encoding module.
87 | See : Appendix B in https://arxiv.org/abs/1901.02860
88 | Args:
89 | d_model (int): Embedding dimension.
90 | dropout_rate (float): Dropout rate.
91 | max_len (int): Maximum input length.
92 | """
93 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 100000):
94 | """Initialize class."""
95 | super().__init__(d_model, dropout_rate, max_len, reverse=True)
96 |
97 | def forward(self,
98 | x: torch.Tensor,
99 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
100 | """Compute positional encoding.
101 | Args:
102 | x (torch.Tensor): Input tensor (batch, time, `*`).
103 | Returns:
104 | torch.Tensor: Encoded tensor (batch, time, `*`).
105 | torch.Tensor: Positional embedding tensor (1, time, `*`).
106 | """
107 | assert offset + x.size(1) < self.max_len
108 | self.pe = self.pe.to(x.device)
109 | x = x * self.xscale
110 | pos_emb = self.pe[:, offset:offset + x.size(1)]
111 | return self.dropout(x), self.dropout(pos_emb)
112 |
113 |
114 | class NoPositionalEncoding(torch.nn.Module):
115 | """ No position encoding
116 | """
117 | def __init__(self, d_model: int, dropout_rate: float):
118 | super().__init__()
119 | self.d_model = d_model
120 | self.dropout = torch.nn.Dropout(p=dropout_rate)
121 |
122 | def forward(self,
123 | x: torch.Tensor,
124 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
125 | """ Just return zero vector for interface compatibility
126 | """
127 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
128 | return self.dropout(x), pos_emb
129 |
130 | def position_encoding(self, offset: int, size: int) -> torch.Tensor:
131 | return torch.zeros(1, size, self.d_model)
132 |
--------------------------------------------------------------------------------
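The practical difference between the two encodings above is whether positions are added to the (scaled) input or returned separately for relative attention; a small illustrative comparison (dimensions arbitrary):

    import torch
    from wenet.transformer.embedding import PositionalEncoding, RelPositionalEncoding

    x = torch.randn(2, 50, 256)

    abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.1)
    x_abs, pos = abs_pe(x)            # positions are added to the scaled input
    print(x_abs.shape, pos.shape)     # torch.Size([2, 50, 256]) torch.Size([1, 50, 256])

    rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
    x_rel, pos = rel_pe(x)            # input is only scaled; positions returned separately
    print(x_rel.shape, pos.shape)     # torch.Size([2, 50, 256]) torch.Size([1, 50, 256])
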
/wenet/transformer/encoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Encoder definition."""
7 | from typing import Tuple, List, Optional
8 |
9 | import torch
10 | from typeguard import check_argument_types
11 |
12 | from wenet.transformer.attention import MultiHeadedAttention
13 | from wenet.transformer.attention import RelPositionMultiHeadedAttention
14 | from wenet.transformer.convolution import ConvolutionModule
15 | from wenet.transformer.embedding import PositionalEncoding
16 | from wenet.transformer.embedding import RelPositionalEncoding
17 | from wenet.transformer.embedding import NoPositionalEncoding
18 | from wenet.transformer.encoder_layer import TransformerEncoderLayer
19 | from wenet.transformer.encoder_layer import ConformerEncoderLayer
20 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
21 | from wenet.transformer.subsampling import Conv2dSubsampling2
22 | from wenet.transformer.subsampling import Conv2dSubsampling4
23 | from wenet.transformer.subsampling import Conv2dSubsampling6
24 | from wenet.transformer.subsampling import Conv2dSubsampling8
25 |
26 | from wenet.transformer.subsampling import LinearNoSubsampling
27 | from wenet.utils.common import get_activation
28 | from wenet.utils.mask import make_pad_mask
29 | from wenet.utils.mask import add_optional_chunk_mask
30 |
31 |
32 | class BaseEncoder(torch.nn.Module):
33 | def __init__(
34 | self,
35 | input_size: int,
36 | output_size: int = 256,
37 | attention_heads: int = 4,
38 | linear_units: int = 2048,
39 | num_blocks: int = 6,
40 | dropout_rate: float = 0.1,
41 | positional_dropout_rate: float = 0.1,
42 | attention_dropout_rate: float = 0.0,
43 | input_layer: str = "conv2d",
44 | pos_enc_layer_type: str = "abs_pos",
45 | normalize_before: bool = True,
46 | concat_after: bool = False,
47 | static_chunk_size: int = 0,
48 | use_dynamic_chunk: bool = False,
49 | global_cmvn: torch.nn.Module = None,
50 | use_dynamic_left_chunk: bool = False,
51 | ):
52 | """
53 | Args:
54 | input_size (int): input dim
55 | output_size (int): dimension of attention
56 | attention_heads (int): the number of heads of multi head attention
57 | linear_units (int): the hidden units number of position-wise feed
58 | forward
59 |             num_blocks (int): the number of encoder blocks
60 | dropout_rate (float): dropout rate
61 | attention_dropout_rate (float): dropout rate in attention
62 | positional_dropout_rate (float): dropout rate after adding
63 | positional encoding
64 | input_layer (str): input layer type.
65 |                 optional [linear, conv2d2, conv2d, conv2d6, conv2d8]
66 | pos_enc_layer_type (str): Encoder positional encoding layer type.
67 |                 optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
68 | normalize_before (bool):
69 | True: use layer_norm before each sub-block of a layer.
70 | False: use layer_norm after each sub-block of a layer.
71 | concat_after (bool): whether to concat attention layer's input
72 | and output.
73 | True: x -> x + linear(concat(x, att(x)))
74 | False: x -> x + att(x)
75 | static_chunk_size (int): chunk size for static chunk training and
76 | decoding
77 |             use_dynamic_chunk (bool): whether to use a dynamic chunk size for
78 |                 training or not. You can only use a fixed chunk (chunk_size > 0)
79 |                 or a dynamic chunk size (use_dynamic_chunk = True)
80 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
81 | use_dynamic_left_chunk (bool): whether use dynamic left chunk in
82 | dynamic chunk training
83 | """
84 | assert check_argument_types()
85 | super().__init__()
86 | self._output_size = output_size
87 |
88 | if pos_enc_layer_type == "abs_pos":
89 | pos_enc_class = PositionalEncoding
90 | elif pos_enc_layer_type == "rel_pos":
91 | pos_enc_class = RelPositionalEncoding
92 | elif pos_enc_layer_type == "no_pos":
93 | pos_enc_class = NoPositionalEncoding
94 | else:
95 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
96 |
97 | if input_layer == "linear":
98 | subsampling_class = LinearNoSubsampling
99 | elif input_layer == "conv2d":
100 | subsampling_class = Conv2dSubsampling4
101 | elif input_layer == "conv2d6":
102 | subsampling_class = Conv2dSubsampling6
103 | elif input_layer == "conv2d8":
104 | subsampling_class = Conv2dSubsampling8
105 | elif input_layer == "conv2d2":
106 | subsampling_class = Conv2dSubsampling2
107 | else:
108 | raise ValueError("unknown input_layer: " + input_layer)
109 |
110 | self.global_cmvn = global_cmvn
111 | self.embed = subsampling_class(
112 | input_size,
113 | output_size,
114 | dropout_rate,
115 | pos_enc_class(output_size, positional_dropout_rate),
116 | )
117 |
118 | self.normalize_before = normalize_before
119 | self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12)
120 | self.static_chunk_size = static_chunk_size
121 | self.use_dynamic_chunk = use_dynamic_chunk
122 | self.use_dynamic_left_chunk = use_dynamic_left_chunk
123 |
124 | def output_size(self) -> int:
125 | return self._output_size
126 |
127 | def forward(
128 | self,
129 | xs: torch.Tensor,
130 | xs_lens: torch.Tensor,
131 | decoding_chunk_size: int = 0,
132 | num_decoding_left_chunks: int = -1,
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """Embed positions in tensor.
135 |
136 | Args:
137 | xs: padded input tensor (B, T, D)
138 | xs_lens: input length (B)
139 | decoding_chunk_size: decoding chunk size for dynamic chunk
140 | 0: default for training, use random dynamic chunk.
141 | <0: for decoding, use full chunk.
142 | >0: for decoding, use fixed chunk size as set.
143 | num_decoding_left_chunks: number of left chunks, this is for decoding,
144 | the chunk size is decoding_chunk_size.
145 | >=0: use num_decoding_left_chunks
146 | <0: use all left chunks
147 | Returns:
148 | encoder output tensor xs, and subsampled masks
149 | xs: padded output tensor (B, T' ~= T/subsample_rate, D)
150 | masks: torch.Tensor batch padding mask after subsample
151 | (B, 1, T' ~= T/subsample_rate)
152 | """
153 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
154 | if self.global_cmvn is not None:
155 | xs = self.global_cmvn(xs)
156 |
157 | xs, pos_emb, masks = self.embed(xs, masks)
158 | mask_pad = masks # (B, 1, T/subsample_rate)
159 | chunk_masks = add_optional_chunk_mask(xs, masks,
160 | self.use_dynamic_chunk,
161 | self.use_dynamic_left_chunk,
162 | decoding_chunk_size,
163 | self.static_chunk_size,
164 | num_decoding_left_chunks)
165 | for layer in self.encoders:
166 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
167 | if self.normalize_before:
168 | xs = self.after_norm(xs)
169 | # Here we assume the mask is not changed in encoder layers, so just
170 | # return the masks before encoder layers, and the masks will be used
171 | # for cross attention with decoder later
172 | return xs, masks
173 |
174 | def forward_chunk(
175 | self,
176 | xs: torch.Tensor,
177 | offset: int,
178 | required_cache_size: int,
179 | subsampling_cache: Optional[torch.Tensor] = None,
180 | elayers_output_cache: Optional[List[torch.Tensor]] = None,
181 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
182 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
183 | List[torch.Tensor]]:
184 | """ Forward just one chunk
185 |
186 | Args:
187 | xs (torch.Tensor): chunk input
188 | offset (int): current offset in encoder output time stamp
189 | required_cache_size (int): cache size required for next chunk
190 |                 computation
191 | >=0: actual cache size
192 | <0: means all history cache is required
193 | subsampling_cache (Optional[torch.Tensor]): subsampling cache
194 | elayers_output_cache (Optional[List[torch.Tensor]]):
195 | transformer/conformer encoder layers output cache
196 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer
197 | cnn cache
198 |
199 | Returns:
200 | torch.Tensor: output of current input xs
201 | torch.Tensor: subsampling cache required for next chunk computation
202 | List[torch.Tensor]: encoder layers output cache required for next
203 | chunk computation
204 | List[torch.Tensor]: conformer cnn cache
205 |
206 | """
207 | assert xs.size(0) == 1
208 | # tmp_masks is just for interface compatibility
209 | tmp_masks = torch.ones(1,
210 | xs.size(1),
211 | device=xs.device,
212 | dtype=torch.bool)
213 | tmp_masks = tmp_masks.unsqueeze(1)
214 | if self.global_cmvn is not None:
215 | xs = self.global_cmvn(xs)
216 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
217 | if subsampling_cache is not None:
218 | cache_size = subsampling_cache.size(1)
219 | xs = torch.cat((subsampling_cache, xs), dim=1)
220 | else:
221 | cache_size = 0
222 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
223 | if required_cache_size < 0:
224 | next_cache_start = 0
225 | elif required_cache_size == 0:
226 | next_cache_start = xs.size(1)
227 | else:
228 | next_cache_start = max(xs.size(1) - required_cache_size, 0)
229 | r_subsampling_cache = xs[:, next_cache_start:, :]
230 | # Real mask for transformer/conformer layers
231 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
232 | masks = masks.unsqueeze(1)
233 | r_elayers_output_cache = []
234 | r_conformer_cnn_cache = []
235 | for i, layer in enumerate(self.encoders):
236 | if elayers_output_cache is None:
237 | attn_cache = None
238 | else:
239 | attn_cache = elayers_output_cache[i]
240 | if conformer_cnn_cache is None:
241 | cnn_cache = None
242 | else:
243 | cnn_cache = conformer_cnn_cache[i]
244 | xs, _, new_cnn_cache = layer(xs,
245 | masks,
246 | pos_emb,
247 | output_cache=attn_cache,
248 | cnn_cache=cnn_cache)
249 | r_elayers_output_cache.append(xs[:, next_cache_start:, :])
250 | r_conformer_cnn_cache.append(new_cnn_cache)
251 | if self.normalize_before:
252 | xs = self.after_norm(xs)
253 |
254 | return (xs[:, cache_size:, :], r_subsampling_cache,
255 | r_elayers_output_cache, r_conformer_cnn_cache)
256 |
257 | def forward_chunk_by_chunk(
258 | self,
259 | xs: torch.Tensor,
260 | decoding_chunk_size: int,
261 | num_decoding_left_chunks: int = -1,
262 | ) -> Tuple[torch.Tensor, torch.Tensor]:
263 | """ Forward input chunk by chunk with chunk_size like a streaming
264 | fashion
265 |
266 | Here we should pay special attention to computation cache in the
267 | streaming style forward chunk by chunk. Three things should be taken
268 | into account for computation in the current network:
269 | 1. transformer/conformer encoder layers output cache
270 | 2. convolution in conformer
271 | 3. convolution in subsampling
272 |
273 | However, we don't implement subsampling cache for:
274 | 1. We can control subsampling module to output the right result by
275 | overlapping input instead of cache left context, even though it
276 | wastes some computation, but subsampling only takes a very
277 | small fraction of computation in the whole model.
278 |             2. Typically, there are several convolution layers with subsampling
279 | in subsampling module, it is tricky and complicated to do cache
280 | with different convolution layers with different subsampling
281 | rate.
282 | 3. Currently, nn.Sequential is used to stack all the convolution
283 | layers in subsampling, we need to rewrite it to make it work
284 |                 with cache, which is not preferred.
285 | Args:
286 | xs (torch.Tensor): (1, max_len, dim)
287 |             decoding_chunk_size (int): decoding chunk size
288 | """
289 | assert decoding_chunk_size > 0
290 | # The model is trained by static or dynamic chunk
291 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk
292 | subsampling = self.embed.subsampling_rate
293 | context = self.embed.right_context + 1 # Add current frame
294 | stride = subsampling * decoding_chunk_size
295 | decoding_window = (decoding_chunk_size - 1) * subsampling + context
296 | num_frames = xs.size(1)
297 | subsampling_cache: Optional[torch.Tensor] = None
298 | elayers_output_cache: Optional[List[torch.Tensor]] = None
299 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None
300 | outputs = []
301 | offset = 0
302 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks
303 |
304 | # Feed forward overlap input step by step
305 | for cur in range(0, num_frames - context + 1, stride):
306 | end = min(cur + decoding_window, num_frames)
307 | chunk_xs = xs[:, cur:end, :]
308 | (y, subsampling_cache, elayers_output_cache,
309 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
310 | required_cache_size,
311 | subsampling_cache,
312 | elayers_output_cache,
313 | conformer_cnn_cache)
314 | outputs.append(y)
315 | offset += y.size(1)
316 | ys = torch.cat(outputs, 1)
317 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
318 | masks = masks.unsqueeze(1)
319 | return ys, masks
320 |
321 |
322 | class TransformerEncoder(BaseEncoder):
323 | """Transformer encoder module."""
324 | def __init__(
325 | self,
326 | input_size: int,
327 | output_size: int = 256,
328 | attention_heads: int = 4,
329 | linear_units: int = 2048,
330 | num_blocks: int = 6,
331 | dropout_rate: float = 0.1,
332 | positional_dropout_rate: float = 0.1,
333 | attention_dropout_rate: float = 0.0,
334 | input_layer: str = "conv2d",
335 | pos_enc_layer_type: str = "abs_pos",
336 | normalize_before: bool = True,
337 | concat_after: bool = False,
338 | static_chunk_size: int = 0,
339 | use_dynamic_chunk: bool = False,
340 | global_cmvn: torch.nn.Module = None,
341 | use_dynamic_left_chunk: bool = False,
342 | ):
343 | """ Construct TransformerEncoder
344 |
345 | See Encoder for the meaning of each parameter.
346 | """
347 | assert check_argument_types()
348 | super().__init__(input_size, output_size, attention_heads,
349 | linear_units, num_blocks, dropout_rate,
350 | positional_dropout_rate, attention_dropout_rate,
351 | input_layer, pos_enc_layer_type, normalize_before,
352 | concat_after, static_chunk_size, use_dynamic_chunk,
353 | global_cmvn, use_dynamic_left_chunk)
354 | self.encoders = torch.nn.ModuleList([
355 | TransformerEncoderLayer(
356 | output_size,
357 | MultiHeadedAttention(attention_heads, output_size,
358 | attention_dropout_rate),
359 | PositionwiseFeedForward(output_size, linear_units,
360 | dropout_rate), dropout_rate,
361 | normalize_before, concat_after) for _ in range(num_blocks)
362 | ])
363 |
364 |
365 | class ConformerEncoder(BaseEncoder):
366 | """Conformer encoder module."""
367 | def __init__(
368 | self,
369 | input_size: int,
370 | output_size: int = 256,
371 | attention_heads: int = 4,
372 | linear_units: int = 2048,
373 | num_blocks: int = 6,
374 | dropout_rate: float = 0.1,
375 | positional_dropout_rate: float = 0.1,
376 | attention_dropout_rate: float = 0.0,
377 | input_layer: str = "conv2d",
378 | pos_enc_layer_type: str = "rel_pos",
379 | normalize_before: bool = True,
380 | concat_after: bool = False,
381 | static_chunk_size: int = 0,
382 | use_dynamic_chunk: bool = False,
383 | global_cmvn: torch.nn.Module = None,
384 | use_dynamic_left_chunk: bool = False,
385 | positionwise_conv_kernel_size: int = 1,
386 | macaron_style: bool = True,
387 | selfattention_layer_type: str = "rel_selfattn",
388 | activation_type: str = "swish",
389 | use_cnn_module: bool = True,
390 | cnn_module_kernel: int = 15,
391 | causal: bool = False,
392 | cnn_module_norm: str = "batch_norm",
393 | ):
394 | """Construct ConformerEncoder
395 |
396 | Args:
397 |             input_size to use_dynamic_chunk: see BaseEncoder.
398 | positionwise_conv_kernel_size (int): Kernel size of positionwise
399 | conv1d layer.
400 | macaron_style (bool): Whether to use macaron style for
401 | positionwise layer.
402 |             selfattention_layer_type (str): Encoder attention layer type;
403 |                 this parameter currently has no effect and is kept only for
404 |                 configuration compatibility.
405 | activation_type (str): Encoder activation function type.
406 | use_cnn_module (bool): Whether to use convolution module.
407 | cnn_module_kernel (int): Kernel size of convolution module.
408 | causal (bool): whether to use causal convolution or not.
409 | """
410 | assert check_argument_types()
411 | super().__init__(input_size, output_size, attention_heads,
412 | linear_units, num_blocks, dropout_rate,
413 | positional_dropout_rate, attention_dropout_rate,
414 | input_layer, pos_enc_layer_type, normalize_before,
415 | concat_after, static_chunk_size, use_dynamic_chunk,
416 | global_cmvn, use_dynamic_left_chunk)
417 | activation = get_activation(activation_type)
418 |
419 | # self-attention module definition
420 | if pos_enc_layer_type == "no_pos":
421 | encoder_selfattn_layer = MultiHeadedAttention
422 | else:
423 | encoder_selfattn_layer = RelPositionMultiHeadedAttention
424 | encoder_selfattn_layer_args = (
425 | attention_heads,
426 | output_size,
427 | attention_dropout_rate,
428 | )
429 | # feed-forward module definition
430 | positionwise_layer = PositionwiseFeedForward
431 | positionwise_layer_args = (
432 | output_size,
433 | linear_units,
434 | dropout_rate,
435 | activation,
436 | )
437 | # convolution module definition
438 | convolution_layer = ConvolutionModule
439 | convolution_layer_args = (output_size, cnn_module_kernel, activation,
440 | cnn_module_norm, causal)
441 |
442 | self.encoders = torch.nn.ModuleList([
443 | ConformerEncoderLayer(
444 | output_size,
445 | encoder_selfattn_layer(*encoder_selfattn_layer_args),
446 | positionwise_layer(*positionwise_layer_args),
447 | positionwise_layer(
448 | *positionwise_layer_args) if macaron_style else None,
449 | convolution_layer(
450 | *convolution_layer_args) if use_cnn_module else None,
451 | dropout_rate,
452 | normalize_before,
453 | concat_after,
454 | ) for _ in range(num_blocks)
455 | ])
456 |
--------------------------------------------------------------------------------
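
As a quick illustration of the window arithmetic used by forward_chunk_by_chunk above, here is a minimal standalone sketch, assuming the Conv2dSubsampling4 defaults (subsampling_rate=4, right_context=6) defined in wenet/transformer/subsampling.py; the concrete numbers are illustrative, only the formulas mirror the method:

# Sketch of the chunk-by-chunk streaming window arithmetic.
subsampling_rate = 4        # Conv2dSubsampling4
right_context = 6           # Conv2dSubsampling4
decoding_chunk_size = 16    # encoder output frames produced per chunk
num_frames = 1000           # total input feature frames

context = right_context + 1                       # include the current frame
stride = subsampling_rate * decoding_chunk_size   # input frames consumed per step
decoding_window = (decoding_chunk_size - 1) * subsampling_rate + context

for cur in range(0, num_frames - context + 1, stride):
    end = min(cur + decoding_window, num_frames)
    n_out = (end - cur - context) // subsampling_rate + 1
    print(f"input frames [{cur}, {end}) -> ~{n_out} encoder frames")

Consecutive windows overlap by decoding_window - stride input frames, which is how the subsampling convolutions get the extra context they need without a cache.
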
/wenet/transformer/encoder_cat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Encoder definition."""
7 | from typing import Tuple, List, Optional
8 |
9 | import torch
10 | from typeguard import check_argument_types
11 |
12 | from wenet.transformer.attention import MultiHeadedAttention
13 | from wenet.transformer.attention import RelPositionMultiHeadedAttention
14 | from wenet.transformer.convolution import ConvolutionModule
15 | from wenet.transformer.embedding import PositionalEncoding
16 | from wenet.transformer.embedding import RelPositionalEncoding
17 | from wenet.transformer.embedding import NoPositionalEncoding
18 | from wenet.transformer.encoder_layer import TransformerEncoderLayer
19 | from wenet.transformer.encoder_layer import ConformerEncoderLayer
20 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
21 | from wenet.transformer.subsampling import Conv2dSubsampling4
22 | from wenet.transformer.subsampling import Conv2dSubsampling6
23 | from wenet.transformer.subsampling import Conv2dSubsampling8
24 | from wenet.transformer.subsampling import Conv2dSubsampling2
25 | from wenet.transformer.subsampling import LinearNoSubsampling
26 | from wenet.utils.common import get_activation
27 | from wenet.utils.mask import make_pad_mask
28 | from wenet.utils.mask import add_optional_chunk_mask
29 |
30 |
31 | class BaseEncoder(torch.nn.Module):
32 | def __init__(
33 | self,
34 | input_size: int,
35 | output_size: int = 256,
36 | attention_heads: int = 4,
37 | linear_units: int = 2048,
38 | num_blocks: int = 6,
39 | dropout_rate: float = 0.1,
40 | positional_dropout_rate: float = 0.1,
41 | attention_dropout_rate: float = 0.0,
42 | input_layer: str = "conv2d",
43 | pos_enc_layer_type: str = "abs_pos",
44 | normalize_before: bool = True,
45 | concat_after: bool = False,
46 | static_chunk_size: int = 0,
47 | use_dynamic_chunk: bool = False,
48 | global_cmvn: torch.nn.Module = None,
49 | use_dynamic_left_chunk: bool = False,
50 | ):
51 | """
52 | Args:
53 | input_size (int): input dim
54 | output_size (int): dimension of attention
55 | attention_heads (int): the number of heads of multi head attention
56 | linear_units (int): the hidden units number of position-wise feed
57 | forward
58 | num_blocks (int): the number of decoder blocks
59 | dropout_rate (float): dropout rate
60 | attention_dropout_rate (float): dropout rate in attention
61 | positional_dropout_rate (float): dropout rate after adding
62 | positional encoding
63 | input_layer (str): input layer type.
64 | optional [linear, conv2d, conv2d6, conv2d8]
65 | pos_enc_layer_type (str): Encoder positional encoding layer type.
66 |                 optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
67 | normalize_before (bool):
68 | True: use layer_norm before each sub-block of a layer.
69 | False: use layer_norm after each sub-block of a layer.
70 | concat_after (bool): whether to concat attention layer's input
71 | and output.
72 | True: x -> x + linear(concat(x, att(x)))
73 | False: x -> x + att(x)
74 | static_chunk_size (int): chunk size for static chunk training and
75 | decoding
76 |             use_dynamic_chunk (bool): whether to use dynamic chunk size for
77 |                 training or not. You can only use a fixed chunk (chunk_size > 0)
78 |                 or a dynamic chunk size (use_dynamic_chunk = True)
79 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
80 |             use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
81 | dynamic chunk training
82 | """
83 | assert check_argument_types()
84 | super().__init__()
85 | self._output_size = output_size * num_blocks
86 |
87 | if pos_enc_layer_type == "abs_pos":
88 | pos_enc_class = PositionalEncoding
89 | elif pos_enc_layer_type == "rel_pos":
90 | pos_enc_class = RelPositionalEncoding
91 | elif pos_enc_layer_type == "no_pos":
92 | pos_enc_class = NoPositionalEncoding
93 | else:
94 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
95 |
96 | if input_layer == "linear":
97 | subsampling_class = LinearNoSubsampling
98 | elif input_layer == "conv2d":
99 | subsampling_class = Conv2dSubsampling4
100 | elif input_layer == "conv2d6":
101 | subsampling_class = Conv2dSubsampling6
102 | elif input_layer == "conv2d8":
103 | subsampling_class = Conv2dSubsampling8
104 | elif input_layer == "conv2d2":
105 | subsampling_class = Conv2dSubsampling2
106 | else:
107 | raise ValueError("unknown input_layer: " + input_layer)
108 |
109 | self.global_cmvn = global_cmvn
110 | self.embed = subsampling_class(
111 | input_size,
112 | output_size,
113 | dropout_rate,
114 | pos_enc_class(output_size, positional_dropout_rate),
115 | )
116 |
117 | self.normalize_before = normalize_before
118 | self.after_norm = torch.nn.LayerNorm(output_size * num_blocks, eps=1e-12)
119 | self.static_chunk_size = static_chunk_size
120 | self.use_dynamic_chunk = use_dynamic_chunk
121 | self.use_dynamic_left_chunk = use_dynamic_left_chunk
122 |
123 | def output_size(self) -> int:
124 | return self._output_size
125 |
126 | def forward(
127 | self,
128 | xs: torch.Tensor,
129 | xs_lens: torch.Tensor,
130 | decoding_chunk_size: int = 0,
131 | num_decoding_left_chunks: int = -1,
132 | ) -> Tuple[torch.Tensor, torch.Tensor]:
133 | """Embed positions in tensor.
134 |
135 | Args:
136 | xs: padded input tensor (B, T, D)
137 | xs_lens: input length (B)
138 | decoding_chunk_size: decoding chunk size for dynamic chunk
139 | 0: default for training, use random dynamic chunk.
140 | <0: for decoding, use full chunk.
141 | >0: for decoding, use fixed chunk size as set.
142 | num_decoding_left_chunks: number of left chunks, this is for decoding,
143 | the chunk size is decoding_chunk_size.
144 | >=0: use num_decoding_left_chunks
145 | <0: use all left chunks
146 | Returns:
147 | encoder output tensor xs, and subsampled masks
148 | xs: padded output tensor (B, T' ~= T/subsample_rate, D)
149 | masks: torch.Tensor batch padding mask after subsample
150 | (B, 1, T' ~= T/subsample_rate)
151 | """
152 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
153 | if self.global_cmvn is not None:
154 | xs = self.global_cmvn(xs)
155 | xs, pos_emb, masks = self.embed(xs, masks)
156 | mask_pad = masks # (B, 1, T/subsample_rate)
157 | chunk_masks = add_optional_chunk_mask(xs, masks,
158 | self.use_dynamic_chunk,
159 | self.use_dynamic_left_chunk,
160 | decoding_chunk_size,
161 | self.static_chunk_size,
162 | num_decoding_left_chunks)
163 | out = []
164 | for layer in self.encoders:
165 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
166 | out.append(xs)
167 | xs = torch.cat(out, dim=-1)
168 | if self.normalize_before:
169 | xs = self.after_norm(xs)
170 | # Here we assume the mask is not changed in encoder layers, so just
171 | # return the masks before encoder layers, and the masks will be used
172 | # for cross attention with decoder later
173 | return xs, masks
174 |
175 | def forward_chunk(
176 | self,
177 | xs: torch.Tensor,
178 | offset: int,
179 | required_cache_size: int,
180 | subsampling_cache: Optional[torch.Tensor] = None,
181 | elayers_output_cache: Optional[List[torch.Tensor]] = None,
182 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
183 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
184 | List[torch.Tensor]]:
185 | """ Forward just one chunk
186 |
187 | Args:
188 | xs (torch.Tensor): chunk input
189 | offset (int): current offset in encoder output time stamp
190 | required_cache_size (int): cache size required for next chunk
191 |                 computation
192 | >=0: actual cache size
193 | <0: means all history cache is required
194 | subsampling_cache (Optional[torch.Tensor]): subsampling cache
195 | elayers_output_cache (Optional[List[torch.Tensor]]):
196 | transformer/conformer encoder layers output cache
197 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer
198 | cnn cache
199 |
200 | Returns:
201 | torch.Tensor: output of current input xs
202 | torch.Tensor: subsampling cache required for next chunk computation
203 | List[torch.Tensor]: encoder layers output cache required for next
204 | chunk computation
205 | List[torch.Tensor]: conformer cnn cache
206 |
207 | """
208 | assert xs.size(0) == 1
209 | # tmp_masks is just for interface compatibility
210 | tmp_masks = torch.ones(1,
211 | xs.size(1),
212 | device=xs.device,
213 | dtype=torch.bool)
214 | tmp_masks = tmp_masks.unsqueeze(1)
215 | if self.global_cmvn is not None:
216 | xs = self.global_cmvn(xs)
217 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
218 | if subsampling_cache is not None:
219 | cache_size = subsampling_cache.size(1)
220 | xs = torch.cat((subsampling_cache, xs), dim=1)
221 | else:
222 | cache_size = 0
223 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
224 | if required_cache_size < 0:
225 | next_cache_start = 0
226 | elif required_cache_size == 0:
227 | next_cache_start = xs.size(1)
228 | else:
229 | next_cache_start = max(xs.size(1) - required_cache_size, 0)
230 | r_subsampling_cache = xs[:, next_cache_start:, :]
231 | # Real mask for transformer/conformer layers
232 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
233 | masks = masks.unsqueeze(1)
234 | r_elayers_output_cache = []
235 | r_conformer_cnn_cache = []
236 | for i, layer in enumerate(self.encoders):
237 | if elayers_output_cache is None:
238 | attn_cache = None
239 | else:
240 | attn_cache = elayers_output_cache[i]
241 | if conformer_cnn_cache is None:
242 | cnn_cache = None
243 | else:
244 | cnn_cache = conformer_cnn_cache[i]
245 | xs, _, new_cnn_cache = layer(xs,
246 | masks,
247 | pos_emb,
248 | output_cache=attn_cache,
249 | cnn_cache=cnn_cache)
250 | r_elayers_output_cache.append(xs[:, next_cache_start:, :])
251 | r_conformer_cnn_cache.append(new_cnn_cache)
252 | if self.normalize_before:
253 | xs = self.after_norm(xs)
254 |
255 | return (xs[:, cache_size:, :], r_subsampling_cache,
256 | r_elayers_output_cache, r_conformer_cnn_cache)
257 |
258 | def forward_chunk_by_chunk(
259 | self,
260 | xs: torch.Tensor,
261 | decoding_chunk_size: int,
262 | num_decoding_left_chunks: int = -1,
263 | ) -> Tuple[torch.Tensor, torch.Tensor]:
264 | """ Forward input chunk by chunk with chunk_size like a streaming
265 | fashion
266 |
267 | Here we should pay special attention to computation cache in the
268 | streaming style forward chunk by chunk. Three things should be taken
269 | into account for computation in the current network:
270 | 1. transformer/conformer encoder layers output cache
271 | 2. convolution in conformer
272 | 3. convolution in subsampling
273 |
274 |         However, we don't implement a subsampling cache because:
275 |             1. We can make the subsampling module output the right result by
276 |                overlapping the input instead of caching left context. Although
277 |                this wastes some computation, subsampling accounts for only a
278 |                small fraction of the computation in the whole model.
279 |             2. Typically, there are several convolution layers with subsampling
280 |                in the subsampling module; it is tricky and complicated to cache
281 |                across convolution layers with different subsampling
282 |                rates.
283 |             3. Currently, nn.Sequential is used to stack all the convolution
284 |                layers in subsampling; we would need to rewrite it to make it
285 |                work with a cache, which is not preferred.
286 | Args:
287 | xs (torch.Tensor): (1, max_len, dim)
288 |             decoding_chunk_size (int): decoding chunk size
289 | """
290 | assert decoding_chunk_size > 0
291 | # The model is trained by static or dynamic chunk
292 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk
293 | subsampling = self.embed.subsampling_rate
294 | context = self.embed.right_context + 1 # Add current frame
295 | stride = subsampling * decoding_chunk_size
296 | decoding_window = (decoding_chunk_size - 1) * subsampling + context
297 | num_frames = xs.size(1)
298 | subsampling_cache: Optional[torch.Tensor] = None
299 | elayers_output_cache: Optional[List[torch.Tensor]] = None
300 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None
301 | outputs = []
302 | offset = 0
303 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks
304 |
305 | # Feed forward overlap input step by step
306 | for cur in range(0, num_frames - context + 1, stride):
307 | end = min(cur + decoding_window, num_frames)
308 | chunk_xs = xs[:, cur:end, :]
309 | (y, subsampling_cache, elayers_output_cache,
310 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
311 | required_cache_size,
312 | subsampling_cache,
313 | elayers_output_cache,
314 | conformer_cnn_cache)
315 | outputs.append(y)
316 | offset += y.size(1)
317 | ys = torch.cat(outputs, 1)
318 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
319 | masks = masks.unsqueeze(1)
320 | return ys, masks
321 |
322 |
323 | class TransformerEncoder(BaseEncoder):
324 | """Transformer encoder module."""
325 | def __init__(
326 | self,
327 | input_size: int,
328 | output_size: int = 256,
329 | attention_heads: int = 4,
330 | linear_units: int = 2048,
331 | num_blocks: int = 6,
332 | dropout_rate: float = 0.1,
333 | positional_dropout_rate: float = 0.1,
334 | attention_dropout_rate: float = 0.0,
335 | input_layer: str = "conv2d",
336 | pos_enc_layer_type: str = "abs_pos",
337 | normalize_before: bool = True,
338 | concat_after: bool = False,
339 | static_chunk_size: int = 0,
340 | use_dynamic_chunk: bool = False,
341 | global_cmvn: torch.nn.Module = None,
342 | use_dynamic_left_chunk: bool = False,
343 | ):
344 | """ Construct TransformerEncoder
345 |
346 |         See BaseEncoder for the meaning of each parameter.
347 | """
348 | assert check_argument_types()
349 | super().__init__(input_size, output_size, attention_heads,
350 | linear_units, num_blocks, dropout_rate,
351 | positional_dropout_rate, attention_dropout_rate,
352 | input_layer, pos_enc_layer_type, normalize_before,
353 | concat_after, static_chunk_size, use_dynamic_chunk,
354 | global_cmvn, use_dynamic_left_chunk)
355 | self.encoders = torch.nn.ModuleList([
356 | TransformerEncoderLayer(
357 | output_size,
358 | MultiHeadedAttention(attention_heads, output_size,
359 | attention_dropout_rate),
360 | PositionwiseFeedForward(output_size, linear_units,
361 | dropout_rate), dropout_rate,
362 | normalize_before, concat_after) for _ in range(num_blocks)
363 | ])
364 |
365 |
366 | class ConformerEncoder(BaseEncoder):
367 | """Conformer encoder module."""
368 | def __init__(
369 | self,
370 | input_size: int,
371 | output_size: int = 256,
372 | attention_heads: int = 4,
373 | linear_units: int = 2048,
374 | num_blocks: int = 6,
375 | dropout_rate: float = 0.1,
376 | positional_dropout_rate: float = 0.1,
377 | attention_dropout_rate: float = 0.0,
378 | input_layer: str = "conv2d",
379 | pos_enc_layer_type: str = "rel_pos",
380 | normalize_before: bool = True,
381 | concat_after: bool = False,
382 | static_chunk_size: int = 0,
383 | use_dynamic_chunk: bool = False,
384 | global_cmvn: torch.nn.Module = None,
385 | use_dynamic_left_chunk: bool = False,
386 | positionwise_conv_kernel_size: int = 1,
387 | macaron_style: bool = True,
388 | selfattention_layer_type: str = "rel_selfattn",
389 | activation_type: str = "swish",
390 | use_cnn_module: bool = True,
391 | cnn_module_kernel: int = 15,
392 | causal: bool = False,
393 | cnn_module_norm: str = "batch_norm",
394 | ):
395 | """Construct ConformerEncoder
396 |
397 | Args:
398 |             input_size to use_dynamic_chunk: see BaseEncoder.
399 | positionwise_conv_kernel_size (int): Kernel size of positionwise
400 | conv1d layer.
401 | macaron_style (bool): Whether to use macaron style for
402 | positionwise layer.
403 |             selfattention_layer_type (str): Encoder attention layer type;
404 |                 this parameter currently has no effect and is kept only for
405 |                 configuration compatibility.
406 | activation_type (str): Encoder activation function type.
407 | use_cnn_module (bool): Whether to use convolution module.
408 | cnn_module_kernel (int): Kernel size of convolution module.
409 | causal (bool): whether to use causal convolution or not.
410 | """
411 | assert check_argument_types()
412 | super().__init__(input_size, output_size, attention_heads,
413 | linear_units, num_blocks, dropout_rate,
414 | positional_dropout_rate, attention_dropout_rate,
415 | input_layer, pos_enc_layer_type, normalize_before,
416 | concat_after, static_chunk_size, use_dynamic_chunk,
417 | global_cmvn, use_dynamic_left_chunk)
418 | activation = get_activation(activation_type)
419 |
420 | # self-attention module definition
421 | if pos_enc_layer_type == "no_pos":
422 | encoder_selfattn_layer = MultiHeadedAttention
423 | else:
424 | encoder_selfattn_layer = RelPositionMultiHeadedAttention
425 | encoder_selfattn_layer_args = (
426 | attention_heads,
427 | output_size,
428 | attention_dropout_rate,
429 | )
430 | # feed-forward module definition
431 | positionwise_layer = PositionwiseFeedForward
432 | positionwise_layer_args = (
433 | output_size,
434 | linear_units,
435 | dropout_rate,
436 | activation,
437 | )
438 | # convolution module definition
439 | convolution_layer = ConvolutionModule
440 | convolution_layer_args = (output_size, cnn_module_kernel, activation,
441 | cnn_module_norm, causal)
442 |
443 | self.encoders = torch.nn.ModuleList([
444 | ConformerEncoderLayer(
445 | output_size,
446 | encoder_selfattn_layer(*encoder_selfattn_layer_args),
447 | positionwise_layer(*positionwise_layer_args),
448 | positionwise_layer(
449 | *positionwise_layer_args) if macaron_style else None,
450 | convolution_layer(
451 | *convolution_layer_args) if use_cnn_module else None,
452 | dropout_rate,
453 | normalize_before,
454 | concat_after,
455 | ) for _ in range(num_blocks)
456 | ])
457 |
--------------------------------------------------------------------------------
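
The BaseEncoder in encoder_cat.py realizes multi-scale feature aggregation by concatenating the outputs of all blocks along the feature dimension, so its output dimension is output_size * num_blocks, followed by a LayerNorm. A minimal sketch of that aggregation with toy tensors (shapes only, not the real modules):

import torch

# Stand-ins for the per-block outputs collected in BaseEncoder.forward:
# num_blocks tensors of shape (B, T, D).
B, T, D, num_blocks = 2, 50, 256, 6
block_outputs = [torch.randn(B, T, D) for _ in range(num_blocks)]

# Concatenation-based aggregation, as in encoder_cat.py:
xs = torch.cat(block_outputs, dim=-1)                      # (B, T, D * num_blocks)
after_norm = torch.nn.LayerNorm(D * num_blocks, eps=1e-12)
xs = after_norm(xs)
print(xs.shape)  # torch.Size([2, 50, 1536])
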
/wenet/transformer/encoder_layer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Encoder self-attention layer definition."""
7 |
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 | from torch import nn
12 |
13 |
14 | class TransformerEncoderLayer(nn.Module):
15 | """Encoder layer module.
16 |
17 | Args:
18 | size (int): Input dimension.
19 | self_attn (torch.nn.Module): Self-attention module instance.
20 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
21 | instance can be used as the argument.
22 | feed_forward (torch.nn.Module): Feed-forward module instance.
23 | `PositionwiseFeedForward`, instance can be used as the argument.
24 | dropout_rate (float): Dropout rate.
25 | normalize_before (bool):
26 | True: use layer_norm before each sub-block.
27 | False: to use layer_norm after each sub-block.
28 | concat_after (bool): Whether to concat attention layer's input and
29 | output.
30 | True: x -> x + linear(concat(x, att(x)))
31 | False: x -> x + att(x)
32 |
33 | """
34 | def __init__(
35 | self,
36 | size: int,
37 | self_attn: torch.nn.Module,
38 | feed_forward: torch.nn.Module,
39 | dropout_rate: float,
40 | normalize_before: bool = True,
41 | concat_after: bool = False,
42 | ):
43 | """Construct an EncoderLayer object."""
44 | super().__init__()
45 | self.self_attn = self_attn
46 | self.feed_forward = feed_forward
47 | self.norm1 = nn.LayerNorm(size, eps=1e-12)
48 | self.norm2 = nn.LayerNorm(size, eps=1e-12)
49 | self.dropout = nn.Dropout(dropout_rate)
50 | self.size = size
51 | self.normalize_before = normalize_before
52 | self.concat_after = concat_after
53 |
54 | def forward(
55 | self,
56 | x: torch.Tensor,
57 | mask: torch.Tensor,
58 | pos_emb: torch.Tensor,
59 | mask_pad: Optional[torch.Tensor] = None,
60 | output_cache: Optional[torch.Tensor] = None,
61 | cnn_cache: Optional[torch.Tensor] = None,
62 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63 | """Compute encoded features.
64 |
65 | Args:
66 | x (torch.Tensor): Input tensor (#batch, time, size).
67 | mask (torch.Tensor): Mask tensor for the input (#batch, time).
68 | pos_emb (torch.Tensor): just for interface compatibility
69 | to ConformerEncoderLayer
70 |             mask_pad (torch.Tensor): not used in the transformer layer,
71 |                 just kept for a unified API with the conformer layer.
72 | output_cache (torch.Tensor): Cache tensor of the output
73 | (#batch, time2, size), time2 < time in x.
74 | cnn_cache (torch.Tensor): not used here, it's for interface
75 | compatibility to ConformerEncoderLayer
76 | Returns:
77 | torch.Tensor: Output tensor (#batch, time, size).
78 | torch.Tensor: Mask tensor (#batch, time).
79 |
80 | """
81 | residual = x
82 | if self.normalize_before:
83 | x = self.norm1(x)
84 |
85 | if output_cache is None:
86 | x_q = x
87 | else:
88 | assert output_cache.size(0) == x.size(0)
89 | assert output_cache.size(2) == self.size
90 | assert output_cache.size(1) < x.size(1)
91 | chunk = x.size(1) - output_cache.size(1)
92 | x_q = x[:, -chunk:, :]
93 | residual = residual[:, -chunk:, :]
94 | mask = mask[:, -chunk:, :]
95 |
96 | x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
97 | if not self.normalize_before:
98 | x = self.norm1(x)
99 |
100 | residual = x
101 | if self.normalize_before:
102 | x = self.norm2(x)
103 | x = residual + self.dropout(self.feed_forward(x))
104 | if not self.normalize_before:
105 | x = self.norm2(x)
106 |
107 | if output_cache is not None:
108 | x = torch.cat([output_cache, x], dim=1)
109 |
110 | fake_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
111 | return x, mask, fake_cnn_cache
112 |
113 |
114 | class ConformerEncoderLayer(nn.Module):
115 | """Encoder layer module.
116 | Args:
117 | size (int): Input dimension.
118 | self_attn (torch.nn.Module): Self-attention module instance.
119 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
120 | instance can be used as the argument.
121 | feed_forward (torch.nn.Module): Feed-forward module instance.
122 | `PositionwiseFeedForward` instance can be used as the argument.
123 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module
124 | instance.
125 | `PositionwiseFeedForward` instance can be used as the argument.
126 | conv_module (torch.nn.Module): Convolution module instance.
127 |             `ConvolutionModule` instance can be used as the argument.
128 | dropout_rate (float): Dropout rate.
129 | normalize_before (bool):
130 | True: use layer_norm before each sub-block.
131 | False: use layer_norm after each sub-block.
132 | concat_after (bool): Whether to concat attention layer's input and
133 | output.
134 | True: x -> x + linear(concat(x, att(x)))
135 | False: x -> x + att(x)
136 | """
137 | def __init__(
138 | self,
139 | size: int,
140 | self_attn: torch.nn.Module,
141 | feed_forward: Optional[nn.Module] = None,
142 | feed_forward_macaron: Optional[nn.Module] = None,
143 | conv_module: Optional[nn.Module] = None,
144 | dropout_rate: float = 0.1,
145 | normalize_before: bool = True,
146 | concat_after: bool = False,
147 | ):
148 | """Construct an EncoderLayer object."""
149 | super().__init__()
150 | self.self_attn = self_attn
151 | self.feed_forward = feed_forward
152 | self.feed_forward_macaron = feed_forward_macaron
153 | self.conv_module = conv_module
154 | self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
155 | self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
156 | if feed_forward_macaron is not None:
157 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
158 | self.ff_scale = 0.5
159 | else:
160 | self.ff_scale = 1.0
161 | if self.conv_module is not None:
162 | self.norm_conv = nn.LayerNorm(size,
163 | eps=1e-12) # for the CNN module
164 | self.norm_final = nn.LayerNorm(
165 | size, eps=1e-12) # for the final output of the block
166 | self.dropout = nn.Dropout(dropout_rate)
167 | self.size = size
168 | self.normalize_before = normalize_before
169 | self.concat_after = concat_after
170 |
171 | def forward(
172 | self,
173 | x: torch.Tensor,
174 | mask: torch.Tensor,
175 | pos_emb: torch.Tensor,
176 | mask_pad: Optional[torch.Tensor] = None,
177 | output_cache: Optional[torch.Tensor] = None,
178 | cnn_cache: Optional[torch.Tensor] = None,
179 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
180 | """Compute encoded features.
181 |
182 | Args:
183 | x (torch.Tensor): (#batch, time, size)
184 |             mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
185 | pos_emb (torch.Tensor): positional encoding, must not be None
186 | for ConformerEncoderLayer.
187 | mask_pad (torch.Tensor): batch padding mask used for conv module.
188 |                 (#batch, 1, time)
189 | output_cache (torch.Tensor): Cache tensor of the output
190 | (#batch, time2, size), time2 < time in x.
191 | cnn_cache (torch.Tensor): Convolution cache in conformer layer
192 | Returns:
193 | torch.Tensor: Output tensor (#batch, time, size).
194 | torch.Tensor: Mask tensor (#batch, time).
195 | """
196 |
197 | # whether to use macaron style
198 | if self.feed_forward_macaron is not None:
199 | residual = x
200 | if self.normalize_before:
201 | x = self.norm_ff_macaron(x)
202 | x = residual + self.ff_scale * self.dropout(
203 | self.feed_forward_macaron(x))
204 | if not self.normalize_before:
205 | x = self.norm_ff_macaron(x)
206 |
207 | # multi-headed self-attention module
208 | residual = x
209 | if self.normalize_before:
210 | x = self.norm_mha(x)
211 |
212 | if output_cache is None:
213 | x_q = x
214 | else:
215 | assert output_cache.size(0) == x.size(0)
216 | assert output_cache.size(2) == self.size
217 | assert output_cache.size(1) < x.size(1)
218 | chunk = x.size(1) - output_cache.size(1)
219 | x_q = x[:, -chunk:, :]
220 | residual = residual[:, -chunk:, :]
221 | mask = mask[:, -chunk:, :]
222 |
223 | x_att = self.self_attn(x_q, x, x, mask, pos_emb)
224 | x = residual + self.dropout(x_att)
225 | if not self.normalize_before:
226 | x = self.norm_mha(x)
227 |
228 | # convolution module
229 | # Fake new cnn cache here, and then change it in conv_module
230 | new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
231 | if self.conv_module is not None:
232 | residual = x
233 | if self.normalize_before:
234 | x = self.norm_conv(x)
235 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
236 | x = residual + self.dropout(x)
237 |
238 | if not self.normalize_before:
239 | x = self.norm_conv(x)
240 |
241 | # feed forward module
242 | residual = x
243 | if self.normalize_before:
244 | x = self.norm_ff(x)
245 |
246 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
247 | if not self.normalize_before:
248 | x = self.norm_ff(x)
249 |
250 | if self.conv_module is not None:
251 | x = self.norm_final(x)
252 |
253 | if output_cache is not None:
254 | x = torch.cat([output_cache, x], dim=1)
255 |
256 | return x, mask, new_cnn_cache
257 |
--------------------------------------------------------------------------------
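
A single ConformerEncoderLayer can be exercised on its own. The sketch below is a minimal example, assuming the companion wenet modules keep the constructor and forward signatures used in the encoder files above (ReLU is used here only for brevity; the encoders normally pass Swish):

import torch
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

size, heads, linear_units = 256, 4, 2048
layer = ConformerEncoderLayer(
    size,
    RelPositionMultiHeadedAttention(heads, size, 0.0),
    PositionwiseFeedForward(size, linear_units, 0.1),
    PositionwiseFeedForward(size, linear_units, 0.1),   # macaron branch
    ConvolutionModule(size, 15, torch.nn.ReLU(), "batch_norm", False),
    dropout_rate=0.1,
)

x = torch.randn(2, 50, size)                       # (B, T, D)
mask = torch.ones(2, 1, 50, dtype=torch.bool)      # full (non-chunked) attention mask
_, pos_emb = RelPositionalEncoding(size, 0.1)(x)   # relative positional embedding
y, _, _ = layer(x, mask, pos_emb, mask_pad=mask)
print(y.shape)  # torch.Size([2, 50, 256]) -- the layer preserves the shape
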
/wenet/transformer/encoder_weight.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Encoder definition."""
7 | from typing import Tuple, List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from typeguard import check_argument_types
13 |
14 | from wenet.transformer.attention import MultiHeadedAttention
15 | from wenet.transformer.attention import RelPositionMultiHeadedAttention
16 | from wenet.transformer.convolution import ConvolutionModule
17 | from wenet.transformer.embedding import PositionalEncoding
18 | from wenet.transformer.embedding import RelPositionalEncoding
19 | from wenet.transformer.embedding import NoPositionalEncoding
20 | from wenet.transformer.encoder_layer import TransformerEncoderLayer
21 | from wenet.transformer.encoder_layer import ConformerEncoderLayer
22 | from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
23 | from wenet.transformer.subsampling import Conv2dSubsampling2
24 | from wenet.transformer.subsampling import Conv2dSubsampling4
25 | from wenet.transformer.subsampling import Conv2dSubsampling6
26 | from wenet.transformer.subsampling import Conv2dSubsampling8
27 | from wenet.transformer.subsampling import LinearNoSubsampling
28 | from wenet.utils.common import get_activation
29 | from wenet.utils.mask import make_pad_mask
30 | from wenet.utils.mask import add_optional_chunk_mask
31 |
32 |
33 | class BaseEncoder(torch.nn.Module):
34 | def __init__(
35 | self,
36 | input_size: int,
37 | output_size: int = 256,
38 | attention_heads: int = 4,
39 | linear_units: int = 2048,
40 | num_blocks: int = 6,
41 | dropout_rate: float = 0.1,
42 | positional_dropout_rate: float = 0.1,
43 | attention_dropout_rate: float = 0.0,
44 | input_layer: str = "conv2d",
45 | pos_enc_layer_type: str = "abs_pos",
46 | normalize_before: bool = True,
47 | concat_after: bool = False,
48 | static_chunk_size: int = 0,
49 | use_dynamic_chunk: bool = False,
50 | global_cmvn: torch.nn.Module = None,
51 | use_dynamic_left_chunk: bool = False,
52 | ):
53 | """
54 | Args:
55 | input_size (int): input dim
56 | output_size (int): dimension of attention
57 | attention_heads (int): the number of heads of multi head attention
58 | linear_units (int): the hidden units number of position-wise feed
59 | forward
60 | num_blocks (int): the number of decoder blocks
61 | dropout_rate (float): dropout rate
62 | attention_dropout_rate (float): dropout rate in attention
63 | positional_dropout_rate (float): dropout rate after adding
64 | positional encoding
65 | input_layer (str): input layer type.
66 | optional [linear, conv2d, conv2d6, conv2d8]
67 | pos_enc_layer_type (str): Encoder positional encoding layer type.
68 |                 optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
69 | normalize_before (bool):
70 | True: use layer_norm before each sub-block of a layer.
71 | False: use layer_norm after each sub-block of a layer.
72 | concat_after (bool): whether to concat attention layer's input
73 | and output.
74 | True: x -> x + linear(concat(x, att(x)))
75 | False: x -> x + att(x)
76 | static_chunk_size (int): chunk size for static chunk training and
77 | decoding
78 |             use_dynamic_chunk (bool): whether to use dynamic chunk size for
79 |                 training or not. You can only use a fixed chunk (chunk_size > 0)
80 |                 or a dynamic chunk size (use_dynamic_chunk = True)
81 | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
82 |             use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
83 | dynamic chunk training
84 | """
85 | assert check_argument_types()
86 | super().__init__()
87 | self._output_size = output_size
88 |
89 | if pos_enc_layer_type == "abs_pos":
90 | pos_enc_class = PositionalEncoding
91 | elif pos_enc_layer_type == "rel_pos":
92 | pos_enc_class = RelPositionalEncoding
93 | elif pos_enc_layer_type == "no_pos":
94 | pos_enc_class = NoPositionalEncoding
95 | else:
96 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
97 |
98 | if input_layer == "linear":
99 | subsampling_class = LinearNoSubsampling
100 | elif input_layer == "conv2d":
101 | subsampling_class = Conv2dSubsampling4
102 | elif input_layer == "conv2d6":
103 | subsampling_class = Conv2dSubsampling6
104 | elif input_layer == "conv2d8":
105 | subsampling_class = Conv2dSubsampling8
106 | elif input_layer == "conv2d2":
107 | subsampling_class = Conv2dSubsampling2
108 | else:
109 | raise ValueError("unknown input_layer: " + input_layer)
110 |
111 | self.global_cmvn = global_cmvn
112 | self.embed = subsampling_class(
113 | input_size,
114 | output_size,
115 | dropout_rate,
116 | pos_enc_class(output_size, positional_dropout_rate),
117 | )
118 |
119 | self.normalize_before = normalize_before
120 | self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12)
121 | self.static_chunk_size = static_chunk_size
122 | self.use_dynamic_chunk = use_dynamic_chunk
123 | self.use_dynamic_left_chunk = use_dynamic_left_chunk
124 | self.num_blocks = num_blocks
125 | self.feature_weight = nn.Parameter(torch.ones(self.num_blocks))
126 |
127 | def output_size(self) -> int:
128 | return self._output_size
129 |
130 | def forward(
131 | self,
132 | xs: torch.Tensor,
133 | xs_lens: torch.Tensor,
134 | decoding_chunk_size: int = 0,
135 | num_decoding_left_chunks: int = -1,
136 | ) -> Tuple[torch.Tensor, torch.Tensor]:
137 | """Embed positions in tensor.
138 |
139 | Args:
140 | xs: padded input tensor (B, T, D)
141 | xs_lens: input length (B)
142 | decoding_chunk_size: decoding chunk size for dynamic chunk
143 | 0: default for training, use random dynamic chunk.
144 | <0: for decoding, use full chunk.
145 | >0: for decoding, use fixed chunk size as set.
146 | num_decoding_left_chunks: number of left chunks, this is for decoding,
147 | the chunk size is decoding_chunk_size.
148 | >=0: use num_decoding_left_chunks
149 | <0: use all left chunks
150 | Returns:
151 | encoder output tensor xs, and subsampled masks
152 | xs: padded output tensor (B, T' ~= T/subsample_rate, D)
153 | masks: torch.Tensor batch padding mask after subsample
154 | (B, 1, T' ~= T/subsample_rate)
155 | """
156 | masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
157 | if self.global_cmvn is not None:
158 | xs = self.global_cmvn(xs)
159 | xs, pos_emb, masks = self.embed(xs, masks)
160 | mask_pad = masks # (B, 1, T/subsample_rate)
161 | chunk_masks = add_optional_chunk_mask(xs, masks,
162 | self.use_dynamic_chunk,
163 | self.use_dynamic_left_chunk,
164 | decoding_chunk_size,
165 | self.static_chunk_size,
166 | num_decoding_left_chunks)
167 | out = []
168 | for layer in self.encoders:
169 | xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170 | out.append(xs)
171 |         # stack per-block outputs: (B, T, D, num_blocks), then weighted-sum over blocks
172 |         xs = torch.stack(out, dim=-1)
173 |         norm_weights = F.softmax(self.feature_weight, dim=-1)
174 |         xs = xs.matmul(norm_weights)
175 |
176 | if self.normalize_before:
177 | xs = self.after_norm(xs)
178 | # Here we assume the mask is not changed in encoder layers, so just
179 | # return the masks before encoder layers, and the masks will be used
180 | # for cross attention with decoder later
181 | return xs, masks
182 |
183 | def forward_chunk(
184 | self,
185 | xs: torch.Tensor,
186 | offset: int,
187 | required_cache_size: int,
188 | subsampling_cache: Optional[torch.Tensor] = None,
189 | elayers_output_cache: Optional[List[torch.Tensor]] = None,
190 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
191 | ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
192 | List[torch.Tensor]]:
193 | """ Forward just one chunk
194 |
195 | Args:
196 | xs (torch.Tensor): chunk input
197 | offset (int): current offset in encoder output time stamp
198 | required_cache_size (int): cache size required for next chunk
199 |                 computation
200 | >=0: actual cache size
201 | <0: means all history cache is required
202 | subsampling_cache (Optional[torch.Tensor]): subsampling cache
203 | elayers_output_cache (Optional[List[torch.Tensor]]):
204 | transformer/conformer encoder layers output cache
205 | conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer
206 | cnn cache
207 |
208 | Returns:
209 | torch.Tensor: output of current input xs
210 | torch.Tensor: subsampling cache required for next chunk computation
211 | List[torch.Tensor]: encoder layers output cache required for next
212 | chunk computation
213 | List[torch.Tensor]: conformer cnn cache
214 |
215 | """
216 | assert xs.size(0) == 1
217 | # tmp_masks is just for interface compatibility
218 | tmp_masks = torch.ones(1,
219 | xs.size(1),
220 | device=xs.device,
221 | dtype=torch.bool)
222 | tmp_masks = tmp_masks.unsqueeze(1)
223 | if self.global_cmvn is not None:
224 | xs = self.global_cmvn(xs)
225 | xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
226 | if subsampling_cache is not None:
227 | cache_size = subsampling_cache.size(1)
228 | xs = torch.cat((subsampling_cache, xs), dim=1)
229 | else:
230 | cache_size = 0
231 | pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
232 | if required_cache_size < 0:
233 | next_cache_start = 0
234 | elif required_cache_size == 0:
235 | next_cache_start = xs.size(1)
236 | else:
237 | next_cache_start = max(xs.size(1) - required_cache_size, 0)
238 | r_subsampling_cache = xs[:, next_cache_start:, :]
239 | # Real mask for transformer/conformer layers
240 | masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
241 | masks = masks.unsqueeze(1)
242 | r_elayers_output_cache = []
243 | r_conformer_cnn_cache = []
244 | for i, layer in enumerate(self.encoders):
245 | if elayers_output_cache is None:
246 | attn_cache = None
247 | else:
248 | attn_cache = elayers_output_cache[i]
249 | if conformer_cnn_cache is None:
250 | cnn_cache = None
251 | else:
252 | cnn_cache = conformer_cnn_cache[i]
253 | xs, _, new_cnn_cache = layer(xs,
254 | masks,
255 | pos_emb,
256 | output_cache=attn_cache,
257 | cnn_cache=cnn_cache)
258 | r_elayers_output_cache.append(xs[:, next_cache_start:, :])
259 | r_conformer_cnn_cache.append(new_cnn_cache)
260 | if self.normalize_before:
261 | xs = self.after_norm(xs)
262 |
263 | return (xs[:, cache_size:, :], r_subsampling_cache,
264 | r_elayers_output_cache, r_conformer_cnn_cache)
265 |
266 | def forward_chunk_by_chunk(
267 | self,
268 | xs: torch.Tensor,
269 | decoding_chunk_size: int,
270 | num_decoding_left_chunks: int = -1,
271 | ) -> Tuple[torch.Tensor, torch.Tensor]:
272 | """ Forward input chunk by chunk with chunk_size like a streaming
273 | fashion
274 |
275 | Here we should pay special attention to computation cache in the
276 | streaming style forward chunk by chunk. Three things should be taken
277 | into account for computation in the current network:
278 | 1. transformer/conformer encoder layers output cache
279 | 2. convolution in conformer
280 | 3. convolution in subsampling
281 |
282 |         However, we don't implement a subsampling cache because:
283 |             1. We can make the subsampling module output the right result by
284 |                overlapping the input instead of caching left context. Although
285 |                this wastes some computation, subsampling accounts for only a
286 |                small fraction of the computation in the whole model.
287 |             2. Typically, there are several convolution layers with subsampling
288 |                in the subsampling module; it is tricky and complicated to cache
289 |                across convolution layers with different subsampling
290 |                rates.
291 |             3. Currently, nn.Sequential is used to stack all the convolution
292 |                layers in subsampling; we would need to rewrite it to make it
293 |                work with a cache, which is not preferred.
294 | Args:
295 | xs (torch.Tensor): (1, max_len, dim)
296 |             decoding_chunk_size (int): decoding chunk size
297 | """
298 | assert decoding_chunk_size > 0
299 | # The model is trained by static or dynamic chunk
300 | assert self.static_chunk_size > 0 or self.use_dynamic_chunk
301 | subsampling = self.embed.subsampling_rate
302 | context = self.embed.right_context + 1 # Add current frame
303 | stride = subsampling * decoding_chunk_size
304 | decoding_window = (decoding_chunk_size - 1) * subsampling + context
305 | num_frames = xs.size(1)
306 | subsampling_cache: Optional[torch.Tensor] = None
307 | elayers_output_cache: Optional[List[torch.Tensor]] = None
308 | conformer_cnn_cache: Optional[List[torch.Tensor]] = None
309 | outputs = []
310 | offset = 0
311 | required_cache_size = decoding_chunk_size * num_decoding_left_chunks
312 |
313 | # Feed forward overlap input step by step
314 | for cur in range(0, num_frames - context + 1, stride):
315 | end = min(cur + decoding_window, num_frames)
316 | chunk_xs = xs[:, cur:end, :]
317 | (y, subsampling_cache, elayers_output_cache,
318 | conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
319 | required_cache_size,
320 | subsampling_cache,
321 | elayers_output_cache,
322 | conformer_cnn_cache)
323 | outputs.append(y)
324 | offset += y.size(1)
325 | ys = torch.cat(outputs, 1)
326 | masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
327 | masks = masks.unsqueeze(1)
328 | return ys, masks
329 |
330 |
331 | class TransformerEncoder(BaseEncoder):
332 | """Transformer encoder module."""
333 | def __init__(
334 | self,
335 | input_size: int,
336 | output_size: int = 256,
337 | attention_heads: int = 4,
338 | linear_units: int = 2048,
339 | num_blocks: int = 6,
340 | dropout_rate: float = 0.1,
341 | positional_dropout_rate: float = 0.1,
342 | attention_dropout_rate: float = 0.0,
343 | input_layer: str = "conv2d",
344 | pos_enc_layer_type: str = "abs_pos",
345 | normalize_before: bool = True,
346 | concat_after: bool = False,
347 | static_chunk_size: int = 0,
348 | use_dynamic_chunk: bool = False,
349 | global_cmvn: torch.nn.Module = None,
350 | use_dynamic_left_chunk: bool = False,
351 | ):
352 | """ Construct TransformerEncoder
353 |
354 |         See BaseEncoder for the meaning of each parameter.
355 | """
356 | assert check_argument_types()
357 | super().__init__(input_size, output_size, attention_heads,
358 | linear_units, num_blocks, dropout_rate,
359 | positional_dropout_rate, attention_dropout_rate,
360 | input_layer, pos_enc_layer_type, normalize_before,
361 | concat_after, static_chunk_size, use_dynamic_chunk,
362 | global_cmvn, use_dynamic_left_chunk)
363 | self.encoders = torch.nn.ModuleList([
364 | TransformerEncoderLayer(
365 | output_size,
366 | MultiHeadedAttention(attention_heads, output_size,
367 | attention_dropout_rate),
368 | PositionwiseFeedForward(output_size, linear_units,
369 | dropout_rate), dropout_rate,
370 | normalize_before, concat_after) for _ in range(num_blocks)
371 | ])
372 |
373 |
374 | class ConformerEncoder(BaseEncoder):
375 | """Conformer encoder module."""
376 | def __init__(
377 | self,
378 | input_size: int,
379 | output_size: int = 256,
380 | attention_heads: int = 4,
381 | linear_units: int = 2048,
382 | num_blocks: int = 6,
383 | dropout_rate: float = 0.1,
384 | positional_dropout_rate: float = 0.1,
385 | attention_dropout_rate: float = 0.0,
386 | input_layer: str = "conv2d",
387 | pos_enc_layer_type: str = "rel_pos",
388 | normalize_before: bool = True,
389 | concat_after: bool = False,
390 | static_chunk_size: int = 0,
391 | use_dynamic_chunk: bool = False,
392 | global_cmvn: torch.nn.Module = None,
393 | use_dynamic_left_chunk: bool = False,
394 | positionwise_conv_kernel_size: int = 1,
395 | macaron_style: bool = True,
396 | selfattention_layer_type: str = "rel_selfattn",
397 | activation_type: str = "swish",
398 | use_cnn_module: bool = True,
399 | cnn_module_kernel: int = 15,
400 | causal: bool = False,
401 | cnn_module_norm: str = "batch_norm",
402 | ):
403 | """Construct ConformerEncoder
404 |
405 | Args:
406 |             input_size to use_dynamic_chunk: see BaseEncoder.
407 | positionwise_conv_kernel_size (int): Kernel size of positionwise
408 | conv1d layer.
409 | macaron_style (bool): Whether to use macaron style for
410 | positionwise layer.
411 |             selfattention_layer_type (str): Encoder attention layer type;
412 |                 this parameter currently has no effect and is kept only for
413 |                 configuration compatibility.
414 | activation_type (str): Encoder activation function type.
415 | use_cnn_module (bool): Whether to use convolution module.
416 | cnn_module_kernel (int): Kernel size of convolution module.
417 | causal (bool): whether to use causal convolution or not.
418 | """
419 | assert check_argument_types()
420 | super().__init__(input_size, output_size, attention_heads,
421 | linear_units, num_blocks, dropout_rate,
422 | positional_dropout_rate, attention_dropout_rate,
423 | input_layer, pos_enc_layer_type, normalize_before,
424 | concat_after, static_chunk_size, use_dynamic_chunk,
425 | global_cmvn, use_dynamic_left_chunk)
426 | activation = get_activation(activation_type)
427 |
428 | # self-attention module definition
429 | if pos_enc_layer_type == "no_pos":
430 | encoder_selfattn_layer = MultiHeadedAttention
431 | else:
432 | encoder_selfattn_layer = RelPositionMultiHeadedAttention
433 | encoder_selfattn_layer_args = (
434 | attention_heads,
435 | output_size,
436 | attention_dropout_rate,
437 | )
438 | # feed-forward module definition
439 | positionwise_layer = PositionwiseFeedForward
440 | positionwise_layer_args = (
441 | output_size,
442 | linear_units,
443 | dropout_rate,
444 | activation,
445 | )
446 | # convolution module definition
447 | convolution_layer = ConvolutionModule
448 | convolution_layer_args = (output_size, cnn_module_kernel, activation,
449 | cnn_module_norm, causal)
450 |
451 | self.encoders = torch.nn.ModuleList([
452 | ConformerEncoderLayer(
453 | output_size,
454 | encoder_selfattn_layer(*encoder_selfattn_layer_args),
455 | positionwise_layer(*positionwise_layer_args),
456 | positionwise_layer(
457 | *positionwise_layer_args) if macaron_style else None,
458 | convolution_layer(
459 | *convolution_layer_args) if use_cnn_module else None,
460 | dropout_rate,
461 | normalize_before,
462 | concat_after,
463 | ) for _ in range(num_blocks)
464 | ])
465 |
--------------------------------------------------------------------------------
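
In contrast to encoder_cat.py, the BaseEncoder in encoder_weight.py aggregates the block outputs with a learnable, softmax-normalized weight per block, so the output dimension stays output_size. A minimal sketch of that weighted sum with toy tensors:

import torch
import torch.nn.functional as F

B, T, D, num_blocks = 2, 50, 256, 6
block_outputs = [torch.randn(B, T, D) for _ in range(num_blocks)]

feature_weight = torch.nn.Parameter(torch.ones(num_blocks))  # learned during training
norm_weights = F.softmax(feature_weight, dim=-1)             # sums to 1 over blocks

xs = torch.stack(block_outputs, dim=-1)   # (B, T, D, num_blocks)
xs = xs.matmul(norm_weights)              # weighted sum over the block axis -> (B, T, D)
print(xs.shape)  # torch.Size([2, 50, 256])
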
/wenet/transformer/label_smoothing_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 | """Label smoothing module."""
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
12 | class LabelSmoothingLoss(nn.Module):
13 | """Label-smoothing loss.
14 |
15 | In a standard CE loss, the label's data distribution is:
16 | [0,1,2] ->
17 | [
18 | [1.0, 0.0, 0.0],
19 | [0.0, 1.0, 0.0],
20 |         [0.0, 0.0, 1.0],
21 | ]
22 |
23 |     In the label-smoothing version of the CE loss, some probability mass
24 |     is taken from the true label's probability (1.0) and distributed
25 |     among the other labels.
26 |
27 | e.g.
28 | smoothing=0.1
29 | [0,1,2] ->
30 | [
31 | [0.9, 0.05, 0.05],
32 | [0.05, 0.9, 0.05],
33 | [0.05, 0.05, 0.9],
34 | ]
35 |
36 | Args:
37 |         size (int): the number of classes
38 | padding_idx (int): padding class id which will be ignored for loss
39 | smoothing (float): smoothing rate (0.0 means the conventional CE)
40 | normalize_length (bool):
41 | normalize loss by sequence length if True
42 | normalize loss by batch size if False
43 | """
44 | def __init__(self,
45 | size: int,
46 | padding_idx: int,
47 | smoothing: float,
48 | normalize_length: bool = False):
49 | """Construct an LabelSmoothingLoss object."""
50 | super(LabelSmoothingLoss, self).__init__()
51 | self.criterion = nn.KLDivLoss(reduction="none")
52 | self.padding_idx = padding_idx
53 | self.confidence = 1.0 - smoothing
54 | self.smoothing = smoothing
55 | self.size = size
56 | self.normalize_length = normalize_length
57 |
58 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
59 | """Compute loss between x and target.
60 |
61 |         The model output and data label tensors are flattened to
62 |         (batch*seqlen, class) shape, and a mask is applied to the
63 |         padding positions, which should not contribute to the loss.
64 |
65 | Args:
66 | x (torch.Tensor): prediction (batch, seqlen, class)
67 | target (torch.Tensor):
68 | target signal masked with self.padding_id (batch, seqlen)
69 | Returns:
70 | loss (torch.Tensor) : The KL loss, scalar float value
71 | """
72 | assert x.size(2) == self.size
73 | batch_size = x.size(0)
74 | x = x.view(-1, self.size)
75 | target = target.view(-1)
76 | # use zeros_like instead of torch.no_grad() for true_dist,
77 | # since no_grad() can not be exported by JIT
78 | true_dist = torch.zeros_like(x)
79 | true_dist.fill_(self.smoothing / (self.size - 1))
80 | ignore = target == self.padding_idx # (B,)
81 | total = len(target) - ignore.sum().item()
82 | target = target.masked_fill(ignore, 0) # avoid -1 index
83 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
84 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
85 | denom = total if self.normalize_length else batch_size
86 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
87 |
--------------------------------------------------------------------------------
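
A small worked example of LabelSmoothingLoss: with 5 classes and smoothing=0.1, each non-padded target becomes a distribution with 0.9 on the true class and 0.1 / (5 - 1) = 0.025 on every other class, while positions equal to padding_idx are masked out of the sum:

import torch
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)

logits = torch.randn(2, 3, 5)                    # (batch, seqlen, class)
target = torch.tensor([[0, 2, 4], [1, 3, -1]])   # -1 marks a padded position

loss = criterion(logits, target)                 # KL divergence to the smoothed targets
print(loss.item())
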
/wenet/transformer/positionwise_feed_forward.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 | """Positionwise feed forward layer definition."""
7 |
8 | import torch
9 |
10 |
11 | class PositionwiseFeedForward(torch.nn.Module):
12 | """Positionwise feed forward layer.
13 |
14 |     The feed-forward layer is applied to each position of the sequence.
15 |     The output dimension is the same as the input dimension.
16 |
17 | Args:
18 |         idim (int): Input dimension.
19 | hidden_units (int): The number of hidden units.
20 | dropout_rate (float): Dropout rate.
21 | activation (torch.nn.Module): Activation function
22 | """
23 | def __init__(self,
24 | idim: int,
25 | hidden_units: int,
26 | dropout_rate: float,
27 | activation: torch.nn.Module = torch.nn.ReLU()):
28 | """Construct a PositionwiseFeedForward object."""
29 | super(PositionwiseFeedForward, self).__init__()
30 | self.w_1 = torch.nn.Linear(idim, hidden_units)
31 | self.activation = activation
32 | self.dropout = torch.nn.Dropout(dropout_rate)
33 | self.w_2 = torch.nn.Linear(hidden_units, idim)
34 |
35 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
36 | """Forward function.
37 |
38 | Args:
39 | xs: input tensor (B, L, D)
40 | Returns:
41 | output tensor, (B, L, D)
42 | """
43 | return self.w_2(self.dropout(self.activation(self.w_1(xs))))
44 |
--------------------------------------------------------------------------------
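
Because the feed-forward block maps back to the input dimension, it can be inserted anywhere inside an encoder layer without changing tensor shapes; a minimal usage sketch:

import torch
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
xs = torch.randn(2, 50, 256)   # (B, L, D)
ys = ffn(xs)
print(ys.shape)                # torch.Size([2, 50, 256]) -- same shape as the input
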
/wenet/transformer/subsampling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
5 | # Author: di.wu@mobvoi.com (DI WU)
6 | """Subsampling layer definition."""
7 |
8 | from typing import Tuple
9 |
10 | import torch
11 |
12 | class BaseSubsampling(torch.nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 | self.right_context = 0
16 | self.subsampling_rate = 1
17 |
18 | def position_encoding(self, offset: int, size: int) -> torch.Tensor:
19 | return self.pos_enc.position_encoding(offset, size)
20 |
21 |
22 | class LinearNoSubsampling(BaseSubsampling):
23 | """Linear transform the input without subsampling
24 |
25 | Args:
26 | idim (int): Input dimension.
27 | odim (int): Output dimension.
28 | dropout_rate (float): Dropout rate.
29 |
30 | """
31 | def __init__(self, idim: int, odim: int, dropout_rate: float,
32 | pos_enc_class: torch.nn.Module):
33 | """Construct an linear object."""
34 | super().__init__()
35 | self.out = torch.nn.Sequential(
36 | torch.nn.Linear(idim, odim),
37 | torch.nn.LayerNorm(odim, eps=1e-12),
38 | torch.nn.Dropout(dropout_rate),
39 | )
40 | self.pos_enc = pos_enc_class
41 | self.right_context = 0
42 | self.subsampling_rate = 1
43 |
44 | def forward(
45 | self,
46 | x: torch.Tensor,
47 | x_mask: torch.Tensor,
48 | offset: int = 0
49 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
50 | """Input x.
51 |
52 | Args:
53 | x (torch.Tensor): Input tensor (#batch, time, idim).
54 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
55 |
56 | Returns:
57 | torch.Tensor: linear input tensor (#batch, time', odim),
58 |                 where time' = time.
59 |             torch.Tensor: linear input mask (#batch, 1, time'),
60 |                 where time' = time.
61 |
62 | """
63 | x = self.out(x)
64 | x, pos_emb = self.pos_enc(x, offset)
65 | return x, pos_emb, x_mask
66 |
67 | class Conv2dSubsampling2(BaseSubsampling):
68 | """Convolutional 2D subsampling (to 1/2 length).
69 |
70 | Args:
71 | idim (int): Input dimension.
72 | odim (int): Output dimension.
73 | dropout_rate (float): Dropout rate.
74 |
75 | """
76 | def __init__(self, idim: int, odim: int, dropout_rate: float,
77 | pos_enc_class: torch.nn.Module):
78 | """Construct an Conv2dSubsampling4 object."""
79 | super().__init__()
80 | self.conv = torch.nn.Sequential(
81 | torch.nn.Conv2d(1, odim, 3, 2),
82 | torch.nn.ReLU(),
83 | )
84 | self.out = torch.nn.Sequential(
85 | torch.nn.Linear(odim * (idim // 2 - 1), odim))
86 | self.pos_enc = pos_enc_class
87 | # The right context for every conv layer is computed by:
88 | # (kernel_size - 1) * frame_rate_of_this_layer
89 | self.subsampling_rate = 2
90 | # 2 = (3 - 1) * 1
91 | self.right_context = 2
92 |
93 | def forward(
94 | self,
95 | x: torch.Tensor,
96 | x_mask: torch.Tensor,
97 | offset: int = 0
98 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
99 | """Subsample x.
100 |
101 | Args:
102 | x (torch.Tensor): Input tensor (#batch, time, idim).
103 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
104 |
105 | Returns:
106 | torch.Tensor: Subsampled tensor (#batch, time', odim),
107 |                 where time' = time // 2.
108 |             torch.Tensor: Subsampled mask (#batch, 1, time'),
109 |                 where time' = time // 2.
110 | torch.Tensor: positional encoding
111 |
112 | """
113 | x = x.unsqueeze(1) # (b, c=1, t, f)
114 | x = self.conv(x)
115 | b, c, t, f = x.size()
116 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
117 | x, pos_emb = self.pos_enc(x, offset)
118 | return x, pos_emb, x_mask[:, :, :-2:2]
119 |
120 |
121 | class Conv2dSubsampling4(BaseSubsampling):
122 | """Convolutional 2D subsampling (to 1/4 length).
123 |
124 | Args:
125 | idim (int): Input dimension.
126 | odim (int): Output dimension.
127 | dropout_rate (float): Dropout rate.
128 |
129 | """
130 | def __init__(self, idim: int, odim: int, dropout_rate: float,
131 | pos_enc_class: torch.nn.Module):
132 | """Construct an Conv2dSubsampling4 object."""
133 | super().__init__()
134 | self.conv = torch.nn.Sequential(
135 | torch.nn.Conv2d(1, odim, 3, 2),
136 | torch.nn.ReLU(),
137 | torch.nn.Conv2d(odim, odim, 3, 2),
138 | torch.nn.ReLU(),
139 | )
140 | self.out = torch.nn.Sequential(
141 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
142 | self.pos_enc = pos_enc_class
143 | # The right context for every conv layer is computed by:
144 | # (kernel_size - 1) * frame_rate_of_this_layer
145 | self.subsampling_rate = 4
146 | # 6 = (3 - 1) * 1 + (3 - 1) * 2
147 | self.right_context = 6
148 |
149 | def forward(
150 | self,
151 | x: torch.Tensor,
152 | x_mask: torch.Tensor,
153 | offset: int = 0
154 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
155 | """Subsample x.
156 |
157 | Args:
158 | x (torch.Tensor): Input tensor (#batch, time, idim).
159 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
160 |
161 | Returns:
162 | torch.Tensor: Subsampled tensor (#batch, time', odim),
163 | where time' = time // 4.
164 | torch.Tensor: Subsampled mask (#batch, 1, time'),
165 | where time' = time // 4.
166 | torch.Tensor: positional encoding
167 |
168 | """
169 | x = x.unsqueeze(1) # (b, c=1, t, f)
170 | x = self.conv(x)
171 | b, c, t, f = x.size()
172 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
173 | x, pos_emb = self.pos_enc(x, offset)
174 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
175 |
176 |
177 | class Conv2dSubsampling6(BaseSubsampling):
178 | """Convolutional 2D subsampling (to 1/6 length).
179 | Args:
180 | idim (int): Input dimension.
181 | odim (int): Output dimension.
182 | dropout_rate (float): Dropout rate.
183 | pos_enc (torch.nn.Module): Custom position encoding layer.
184 | """
185 | def __init__(self, idim: int, odim: int, dropout_rate: float,
186 | pos_enc_class: torch.nn.Module):
187 | """Construct an Conv2dSubsampling6 object."""
188 | super().__init__()
189 | self.conv = torch.nn.Sequential(
190 | torch.nn.Conv2d(1, odim, 3, 2),
191 | torch.nn.ReLU(),
192 | torch.nn.Conv2d(odim, odim, 5, 3),
193 | torch.nn.ReLU(),
194 | )
195 | self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
196 | odim)
197 | self.pos_enc = pos_enc_class
198 | self.subsampling_rate = 6
199 | # 10 = (3 - 1) * 1 + (5 - 1) * 2
200 | self.right_context = 10
201 |
202 | def forward(
203 | self,
204 | x: torch.Tensor,
205 | x_mask: torch.Tensor,
206 | offset: int = 0
207 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208 | """Subsample x.
209 | Args:
210 | x (torch.Tensor): Input tensor (#batch, time, idim).
211 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
212 |
213 | Returns:
214 | torch.Tensor: Subsampled tensor (#batch, time', odim),
215 | where time' = time // 6.
216 | torch.Tensor: Subsampled mask (#batch, 1, time'),
217 | where time' = time // 6.
218 | torch.Tensor: positional encoding
219 | """
220 | x = x.unsqueeze(1) # (b, c, t, f)
221 | x = self.conv(x)
222 | b, c, t, f = x.size()
223 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
224 | x, pos_emb = self.pos_enc(x, offset)
225 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
226 |
227 |
228 | class Conv2dSubsampling8(BaseSubsampling):
229 | """Convolutional 2D subsampling (to 1/8 length).
230 |
231 | Args:
232 | idim (int): Input dimension.
233 | odim (int): Output dimension.
234 | dropout_rate (float): Dropout rate.
235 |
236 | """
237 | def __init__(self, idim: int, odim: int, dropout_rate: float,
238 | pos_enc_class: torch.nn.Module):
239 | """Construct an Conv2dSubsampling8 object."""
240 | super().__init__()
241 | self.conv = torch.nn.Sequential(
242 | torch.nn.Conv2d(1, odim, 3, 2),
243 | torch.nn.ReLU(),
244 | torch.nn.Conv2d(odim, odim, 3, 2),
245 | torch.nn.ReLU(),
246 | torch.nn.Conv2d(odim, odim, 3, 2),
247 | torch.nn.ReLU(),
248 | )
249 | self.linear = torch.nn.Linear(
250 | odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
251 | self.pos_enc = pos_enc_class
252 | self.subsampling_rate = 8
253 | # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
254 | self.right_context = 14
255 |
256 | def forward(
257 | self,
258 | x: torch.Tensor,
259 | x_mask: torch.Tensor,
260 | offset: int = 0
261 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262 | """Subsample x.
263 |
264 | Args:
265 | x (torch.Tensor): Input tensor (#batch, time, idim).
266 | x_mask (torch.Tensor): Input mask (#batch, 1, time).
267 |
268 | Returns:
269 | torch.Tensor: Subsampled tensor (#batch, time', odim),
270 | where time' = time // 8.
271 | torch.Tensor: Subsampled mask (#batch, 1, time'),
272 | where time' = time // 8.
273 | torch.Tensor: positional encoding
274 | """
275 | x = x.unsqueeze(1) # (b, c, t, f)
276 | x = self.conv(x)
277 | b, c, t, f = x.size()
278 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
279 | x, pos_emb = self.pos_enc(x, offset)
280 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
281 |
282 |
283 |
284 |
--------------------------------------------------------------------------------
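As a quick sanity check of the subsampling shapes above, here is a minimal sketch (not repository code) that runs `Conv2dSubsampling4` with a stub positional-encoding module standing in for the real classes in `wenet/transformer/embedding.py`; it assumes the repository root is on `PYTHONPATH`:

```python
import torch
from wenet.transformer.subsampling import Conv2dSubsampling4


class StubPosEnc(torch.nn.Module):
    # Hypothetical stand-in for wenet's positional encoding: passes the
    # input through unchanged and returns a dummy positional-embedding tensor.
    def forward(self, x, offset=0):
        return x, torch.zeros(1, x.size(1), x.size(2))


sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=StubPosEnc())
x = torch.randn(4, 100, 80)                      # (batch, time, feat)
mask = torch.ones(4, 1, 100, dtype=torch.bool)   # (batch, 1, time)
y, pos_emb, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)  # torch.Size([4, 24, 256]) torch.Size([4, 1, 24])
```

With two stride-2 convolutions, 100 input frames shrink to 24, and the mask is decimated with the same `[:, :, :-2:2]` slicing, so features and mask stay aligned.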
/wenet/transformer/swish.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
5 | # Northwestern Polytechnical University (Pengcheng Guo)
6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
7 | """Swish() activation function for Conformer."""
8 |
9 | import torch
10 |
11 |
12 | class Swish(torch.nn.Module):
13 | """Construct an Swish object."""
14 | def forward(self, x: torch.Tensor) -> torch.Tensor:
15 | """Return Swish activation function."""
16 | return x * torch.sigmoid(x)
17 |
--------------------------------------------------------------------------------
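For reference, a tiny check (a sketch, not repository code) that `Swish` matches PyTorch's built-in SiLU, which computes the same x * sigmoid(x):

```python
import torch
from wenet.transformer.swish import Swish

x = torch.randn(5)
# torch.nn.SiLU is available in PyTorch >= 1.7 and is numerically identical.
assert torch.allclose(Swish()(x), torch.nn.SiLU()(x))
```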
/wenet/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang)
3 |
4 | import logging
5 | import os
6 | import re
7 |
8 | import yaml
9 | import torch
10 |
11 |
12 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
13 | if torch.cuda.is_available():
14 | logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
15 | checkpoint = torch.load(path)
16 | else:
17 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
18 | checkpoint = torch.load(path, map_location='cpu')
19 | model.load_state_dict(checkpoint)
20 | info_path = re.sub(r'\.pt$', '.yaml', path)
21 | configs = {}
22 | if os.path.exists(info_path):
23 | with open(info_path, 'r') as fin:
24 | configs = yaml.load(fin, Loader=yaml.FullLoader)
25 | return configs
26 |
27 |
28 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
29 | '''
30 | Args:
31 | infos (dict or None): any info you want to save.
32 | '''
33 | logging.info('Checkpoint: save to checkpoint %s' % path)
34 | if isinstance(model, torch.nn.DataParallel):
35 | state_dict = model.module.state_dict()
36 | elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
37 | state_dict = model.module.state_dict()
38 | else:
39 | state_dict = model.state_dict()
40 | torch.save(state_dict, path)
41 | info_path = re.sub(r'\.pt$', '.yaml', path)
42 | if infos is None:
43 | infos = {}
44 | with open(info_path, 'w') as fout:
45 | data = yaml.dump(infos)
46 | fout.write(data)
47 |
--------------------------------------------------------------------------------
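A minimal round-trip sketch of the two helpers above (file names are illustrative): `save_checkpoint` writes the state dict plus a side-car YAML, and `load_checkpoint` restores the weights and returns whatever was stored in that YAML:

```python
import torch
from wenet.utils.checkpoint import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 2)
# Writes toy_model.pt (weights) and toy_model.yaml (the infos dict).
save_checkpoint(model, 'toy_model.pt', infos={'epoch': 3, 'cv_loss': 1.23})

restored = torch.nn.Linear(4, 2)
infos = load_checkpoint(restored, 'toy_model.pt')
print(infos)  # {'cv_loss': 1.23, 'epoch': 3}
```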
/wenet/utils/cmvn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | import logging
18 | import math
19 | import sys
20 | import numpy as np
21 |
22 | def _load_json_cmvn(json_cmvn_file):
23 | """ Load the json format cmvn stats file and calculate cmvn
24 |
25 | Args:
26 | json_cmvn_file: cmvn stats file in json format
27 |
28 | Returns:
29 | a numpy array of [means, vars]
30 | """
31 | with open(json_cmvn_file) as f:
32 | cmvn_stats = json.load(f)
33 |
34 | means = cmvn_stats['mean_stat']
35 | variance = cmvn_stats['var_stat']
36 | count = cmvn_stats['frame_num']
37 | for i in range(len(means)):
38 | means[i] /= count
39 | variance[i] = variance[i] / count - means[i] * means[i]
40 | if variance[i] < 1.0e-20:
41 | variance[i] = 1.0e-20
42 | variance[i] = 1.0 / math.sqrt(variance[i])
43 | cmvn = np.array([means, variance])
44 | return cmvn
45 |
46 |
47 | def _load_kaldi_cmvn(kaldi_cmvn_file):
48 | """ Load the kaldi format cmvn stats file and calculate cmvn
49 |
50 | Args:
51 | kaldi_cmvn_file: kaldi text style global cmvn file, which
52 | is generated by:
53 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
54 |
55 | Returns:
56 | a numpy array of [means, vars]
57 | """
58 | means = []
59 | variance = []
60 | with open(kaldi_cmvn_file, 'r') as fid:
61 | # kaldi binary file start with '\0B'
62 | if fid.read(2) == '\0B':
63 | logging.error('kaldi cmvn binary file is not supported, please '
64 | 'recompute it by: compute-cmvn-stats --binary=false '
65 | ' scp:feats.scp global_cmvn')
66 | sys.exit(1)
67 | fid.seek(0)
68 | arr = fid.read().split()
69 | assert (arr[0] == '[')
70 | assert (arr[-2] == '0')
71 | assert (arr[-1] == ']')
72 | feat_dim = int((len(arr) - 2 - 2) / 2)
73 | for i in range(1, feat_dim + 1):
74 | means.append(float(arr[i]))
75 | count = float(arr[feat_dim + 1])
76 | for i in range(feat_dim + 2, 2 * feat_dim + 2):
77 | variance.append(float(arr[i]))
78 |
79 | for i in range(len(means)):
80 | means[i] /= count
81 | variance[i] = variance[i] / count - means[i] * means[i]
82 | if variance[i] < 1.0e-20:
83 | variance[i] = 1.0e-20
84 | variance[i] = 1.0 / math.sqrt(variance[i])
85 | cmvn = np.array([means, variance])
86 | return cmvn
87 |
88 |
89 | def load_cmvn(cmvn_file, is_json):
90 | if is_json:
91 | cmvn = _load_json_cmvn(cmvn_file)
92 | else:
93 | cmvn = _load_kaldi_cmvn(cmvn_file)
94 | return cmvn[0], cmvn[1]
95 |
--------------------------------------------------------------------------------
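A small sketch of the JSON path (toy numbers, hypothetical file name): the stats file stores sums (`mean_stat`), sums of squares (`var_stat`) and the frame count, and `load_cmvn` turns them into per-dimension means and inverse standard deviations:

```python
import json
from wenet.utils.cmvn import load_cmvn

# Toy global CMVN stats for 2-dim features accumulated over 4 frames.
stats = {'mean_stat': [4.0, 8.0], 'var_stat': [8.0, 24.0], 'frame_num': 4}
with open('toy_cmvn.json', 'w') as f:
    json.dump(stats, f)

means, istd = load_cmvn('toy_cmvn.json', is_json=True)
print(means)  # [1. 2.]          per-dimension means
print(istd)   # [1. 0.7071...]   1 / sqrt(variance)
```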
/wenet/utils/common.py:
--------------------------------------------------------------------------------
1 | """Unility functions for Transformer."""
2 |
3 | import math
4 | from typing import Tuple, List
5 |
6 | import torch
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 | IGNORE_ID = -1
10 |
11 |
12 | def pad_list(xs: List[torch.Tensor], pad_value: int):
13 | """Perform padding for the list of tensors.
14 |
15 | Args:
16 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
17 | pad_value (float): Value for padding.
18 |
19 | Returns:
20 | Tensor: Padded tensor (B, Tmax, `*`).
21 |
22 | Examples:
23 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
24 | >>> x
25 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
26 | >>> pad_list(x, 0)
27 | tensor([[1., 1., 1., 1.],
28 | [1., 1., 0., 0.],
29 | [1., 0., 0., 0.]])
30 |
31 | """
32 | n_batch = len(xs)
33 | max_len = max([x.size(0) for x in xs])
34 | pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
35 | pad = pad.fill_(pad_value)
36 | for i in range(n_batch):
37 | pad[i, :xs[i].size(0)] = xs[i]
38 |
39 | return pad
40 |
41 |
42 | def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
43 | ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
44 | """Add and labels.
45 |
46 | Args:
47 | ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
48 | sos (int): index of <sos>
49 | eos (int): index of <eos>
50 | ignore_id (int): index of padding
51 |
52 | Returns:
53 | ys_in (torch.Tensor) : (B, Lmax + 1)
54 | ys_out (torch.Tensor) : (B, Lmax + 1)
55 |
56 | Examples:
57 | >>> sos_id = 10
58 | >>> eos_id = 11
59 | >>> ignore_id = -1
60 | >>> ys_pad
61 | tensor([[ 1, 2, 3, 4, 5],
62 | [ 4, 5, 6, -1, -1],
63 | [ 7, 8, 9, -1, -1]], dtype=torch.int32)
64 | >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
65 | >>> ys_in
66 | tensor([[10, 1, 2, 3, 4, 5],
67 | [10, 4, 5, 6, 11, 11],
68 | [10, 7, 8, 9, 11, 11]])
69 | >>> ys_out
70 | tensor([[ 1, 2, 3, 4, 5, 11],
71 | [ 4, 5, 6, 11, -1, -1],
72 | [ 7, 8, 9, 11, -1, -1]])
73 | """
74 | _sos = torch.tensor([sos],
75 | dtype=torch.long,
76 | requires_grad=False,
77 | device=ys_pad.device)
78 | _eos = torch.tensor([eos],
79 | dtype=torch.long,
80 | requires_grad=False,
81 | device=ys_pad.device)
82 | ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
83 | ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
84 | ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
85 | return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
86 |
87 |
88 | def reverse_pad_list(ys_pad: torch.Tensor,
89 | ys_lens: torch.Tensor,
90 | pad_value: float = -1.0) -> torch.Tensor:
91 | """Reverse padding for the list of tensors.
92 |
93 | Args:
94 | ys_pad (tensor): The padded tensor (B, Tokenmax).
95 | ys_lens (tensor): The lens of token seqs (B)
96 | pad_value (int): Value for padding.
97 |
98 | Returns:
99 | Tensor: Padded tensor (B, Tokenmax).
100 |
101 | Examples:
102 | >>> x
103 | tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
104 | >>> reverse_pad_list(x, torch.tensor([4, 3, 2]))
105 | tensor([[4, 3, 2, 1],
106 | [7, 6, 5, 0],
107 | [9, 8, 0, 0]])
108 |
109 | """
110 | r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
111 | for y, i in zip(ys_pad, ys_lens)], True,
112 | pad_value)
113 | return r_ys_pad
114 |
115 |
116 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
117 | ignore_label: int) -> float:
118 | """Calculate accuracy.
119 |
120 | Args:
121 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
122 | pad_targets (LongTensor): Target label tensors (B, Lmax).
123 | ignore_label (int): Ignore label id.
124 |
125 | Returns:
126 | float: Accuracy value (0.0 - 1.0).
127 |
128 | """
129 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
130 | pad_outputs.size(1)).argmax(2)
131 | mask = pad_targets != ignore_label
132 | numerator = torch.sum(
133 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
134 | denominator = torch.sum(mask)
135 | return float(numerator) / float(denominator)
136 |
137 |
138 | def get_activation(act):
139 | """Return activation function."""
140 | # Lazy load to avoid unused import
141 | from wenet.transformer.swish import Swish
142 |
143 | activation_funcs = {
144 | "hardtanh": torch.nn.Hardtanh,
145 | "tanh": torch.nn.Tanh,
146 | "relu": torch.nn.ReLU,
147 | "selu": torch.nn.SELU,
148 | "swish": Swish,
149 | "gelu": torch.nn.GELU
150 | }
151 |
152 | return activation_funcs[act]()
153 |
154 |
155 | def get_subsample(config):
156 | input_layer = config["encoder_conf"]["input_layer"]
157 | assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
158 | if input_layer == "conv2d":
159 | return 4
160 | elif input_layer == "conv2d6":
161 | return 6
162 | elif input_layer == "conv2d8":
163 | return 8
164 |
165 |
166 | def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
167 | new_hyp: List[int] = []
168 | cur = 0
169 | while cur < len(hyp):
170 | if hyp[cur] != 0:
171 | new_hyp.append(hyp[cur])
172 | prev = cur
173 | while cur < len(hyp) and hyp[cur] == hyp[prev]:
174 | cur += 1
175 | return new_hyp
176 |
177 |
178 | def log_add(args: List[float]) -> float:
179 | """
180 | Stable log add
181 | """
182 | if all(a == -float('inf') for a in args):
183 | return -float('inf')
184 | a_max = max(args)
185 | lsp = math.log(sum(math.exp(a - a_max) for a in args))
186 | return a_max + lsp
187 |
--------------------------------------------------------------------------------
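Two of the helpers above have no doctest, so here is a short illustrative sketch of `remove_duplicates_and_blank` (CTC-style collapsing with blank id 0) and `log_add` (numerically stable log-sum-exp):

```python
import math
from wenet.utils.common import remove_duplicates_and_blank, log_add

# Merge repeated tokens, then drop blanks (id 0).
print(remove_duplicates_and_blank([0, 3, 3, 0, 0, 4, 4, 5]))  # [3, 4, 5]

# log(0.5 + 0.25) computed entirely in the log domain.
print(log_add([math.log(0.5), math.log(0.25)]))  # -0.2877 == log(0.75)
```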
/wenet/utils/ctc_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Mobvoi Inc. All Rights Reserved.
2 | # Author: binbinzhang@mobvoi.com (Di Wu)
3 |
4 | import numpy as np
5 | import torch
6 |
7 | def insert_blank(label, blank_id=0):
8 | """Insert blank token between every two label token."""
9 | label = np.expand_dims(label, 1)
10 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
11 | label = np.concatenate([blanks, label], axis=1)
12 | label = label.reshape(-1)
13 | label = np.append(label, label[0])
14 | return label
15 |
16 | def forced_align(ctc_probs: torch.Tensor,
17 | y: torch.Tensor,
18 | blank_id=0) -> list:
19 | """ctc forced alignment.
20 |
21 | Args:
22 | ctc_probs (torch.Tensor): hidden state sequence, 2d tensor (T, D)
23 | y (torch.Tensor): id sequence tensor, 1d tensor (L)
24 | blank_id (int): blank symbol index
25 | Returns:
26 | List[int]: alignment result
27 | """
28 | y_insert_blank = insert_blank(y, blank_id)
29 |
30 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
31 | log_alpha = log_alpha - float('inf') # log of zero
32 | state_path = (torch.zeros(
33 | (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
34 | ) # state path
35 |
36 | # init start state
37 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
38 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
39 |
40 | for t in range(1, ctc_probs.size(0)):
41 | for s in range(len(y_insert_blank)):
42 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
43 | s] == y_insert_blank[s - 2]:
44 | candidates = torch.tensor(
45 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
46 | prev_state = [s, s - 1]
47 | else:
48 | candidates = torch.tensor([
49 | log_alpha[t - 1, s],
50 | log_alpha[t - 1, s - 1],
51 | log_alpha[t - 1, s - 2],
52 | ])
53 | prev_state = [s, s - 1, s - 2]
54 | log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
55 | state_path[t, s] = prev_state[torch.argmax(candidates)]
56 |
57 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
58 |
59 | candidates = torch.tensor([
60 | log_alpha[-1, len(y_insert_blank) - 1],
61 | log_alpha[-1, len(y_insert_blank) - 2]
62 | ])
63 | prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
64 | state_seq[-1] = prev_state[torch.argmax(candidates)]
65 | for t in range(ctc_probs.size(0) - 2, -1, -1):
66 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
67 |
68 | output_alignment = []
69 | for t in range(0, ctc_probs.size(0)):
70 | output_alignment.append(y_insert_blank[state_seq[t, 0]])
71 |
72 | return output_alignment
73 |
--------------------------------------------------------------------------------
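A toy sketch of `forced_align` (numbers are illustrative): with four frames of CTC log-posteriors over {blank=0, 1, 2} and the reference sequence [1, 2], the Viterbi pass recovers the frame-level alignment:

```python
import torch
from wenet.utils.ctc_util import forced_align

ctc_probs = torch.log(torch.tensor([
    [0.8, 0.1, 0.1],   # frame 0: blank is most likely
    [0.1, 0.8, 0.1],   # frame 1: token 1
    [0.1, 0.1, 0.8],   # frame 2: token 2
    [0.8, 0.1, 0.1],   # frame 3: blank
]))
y = torch.tensor([1, 2])
print([int(t) for t in forced_align(ctc_probs, y)])  # [0, 1, 2, 0]
```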
/wenet/utils/executor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved.
2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang)
3 |
4 | import logging
5 | from contextlib import nullcontext
6 | # if your python version < 3.7 use the below one
7 | # from contextlib import suppress as nullcontext
8 | import torch
9 | from torch.nn.utils import clip_grad_norm_
10 |
11 |
12 | class Executor:
13 | def __init__(self):
14 | self.step = 0
15 |
16 | def train(self, model, optimizer, scheduler, data_loader, device, writer,
17 | args, scaler):
18 | ''' Train one epoch
19 | '''
20 | model.train()
21 | clip = args.get('grad_clip', 50.0)
22 | log_interval = args.get('log_interval', 10)
23 | rank = args.get('rank', 0)
24 | accum_grad = args.get('accum_grad', 1)
25 | is_distributed = args.get('is_distributed', True)
26 | use_amp = args.get('use_amp', False)
27 | logging.info('using accumulate grad, new batch size is {} times '
28 | 'larger than before'.format(accum_grad))
29 | if use_amp:
30 | assert scaler is not None
31 | num_seen_utts = 0
32 | num_total_batch = len(data_loader)
33 | for batch_idx, batch in enumerate(data_loader):
34 | key, feats, target, feats_lengths, target_lengths = batch
35 | feats = feats.to(device)
36 | target = target.to(device)
37 | feats_lengths = feats_lengths.to(device)
38 | target_lengths = target_lengths.to(device)
39 | num_utts = target_lengths.size(0)
40 | if num_utts == 0:
41 | continue
42 | context = None
43 | # Disable gradient synchronizations across DDP processes.
44 | # Within this context, gradients will be accumulated on module
45 | # variables, which will later be synchronized.
46 | if is_distributed and batch_idx % accum_grad != 0:
47 | context = model.no_sync
48 | # Used for single gpu training and DDP gradient synchronization
49 | # processes.
50 | else:
51 | context = nullcontext
52 | with context():
53 | # autocast context
54 | # The more details about amp can be found in
55 | # https://pytorch.org/docs/stable/notes/amp_examples.html
56 | with torch.cuda.amp.autocast(scaler is not None):
57 | loss, loss_att, loss_ctc = model(feats, feats_lengths,
58 | target, target_lengths)
59 | loss = loss / accum_grad
60 | if use_amp:
61 | scaler.scale(loss).backward()
62 | else:
63 | loss.backward()
64 |
65 | num_seen_utts += num_utts
66 | if batch_idx % accum_grad == 0:
67 | if rank == 0 and writer is not None:
68 | writer.add_scalar('train_loss', loss, self.step)
69 | # Use mixed precision training
70 | if use_amp:
71 | scaler.unscale_(optimizer)
72 | grad_norm = clip_grad_norm_(model.parameters(), clip)
73 | # Must invoke scaler.update() if unscale_() is used in the
74 | # iteration to avoid the following error:
75 | # RuntimeError: unscale_() has already been called
76 | # on this optimizer since the last update().
77 | # We don't check the gradient here because if it has
78 | # inf/nan values, scaler.step() will skip optimizer.step().
79 | scaler.step(optimizer)
80 | scaler.update()
81 | else:
82 | grad_norm = clip_grad_norm_(model.parameters(), clip)
83 | if torch.isfinite(grad_norm):
84 | optimizer.step()
85 | optimizer.zero_grad()
86 | scheduler.step()
87 | self.step += 1
88 | if batch_idx % log_interval == 0:
89 | lr = optimizer.param_groups[0]['lr']
90 | log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
91 | batch_idx, num_total_batch,
92 | loss.item() * accum_grad)
93 | if loss_att is not None:
94 | log_str += 'loss_att {:.6f} '.format(loss_att.item())
95 | if loss_ctc is not None:
96 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
97 | log_str += 'lr {:.8f} rank {}'.format(lr, rank)
98 | logging.debug(log_str)
99 |
100 | def cv(self, model, data_loader, device, args):
101 | ''' Cross validation on the given data loader
102 | '''
103 | model.eval()
104 | log_interval = args.get('log_interval', 10)
105 | # in order to avoid division by 0
106 | num_seen_utts = 1
107 | total_loss = 0.0
108 | num_total_batch = len(data_loader)
109 | with torch.no_grad():
110 | for batch_idx, batch in enumerate(data_loader):
111 | key, feats, target, feats_lengths, target_lengths = batch
112 | feats = feats.to(device)
113 | target = target.to(device)
114 | feats_lengths = feats_lengths.to(device)
115 | target_lengths = target_lengths.to(device)
116 | num_utts = target_lengths.size(0)
117 | if num_utts == 0:
118 | continue
119 | loss, loss_att, loss_ctc = model(feats, feats_lengths, target,
120 | target_lengths)
121 | if torch.isfinite(loss):
122 | num_seen_utts += num_utts
123 | total_loss += loss.item() * num_utts
124 | if batch_idx % log_interval == 0:
125 | log_str = 'CV Batch {}/{} loss {:.6f} '.format(
126 | batch_idx, num_total_batch, loss.item())
127 | if loss_att is not None:
128 | log_str += 'loss_att {:.6f} '.format(loss_att.item())
129 | if loss_ctc is not None:
130 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
131 | log_str += 'history loss {:.6f}'.format(total_loss /
132 | num_seen_utts)
133 | logging.debug(log_str)
134 |
135 | return total_loss, num_seen_utts
136 |
--------------------------------------------------------------------------------
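`Executor.train` only needs an iterable of `(key, feats, target, feats_lengths, target_lengths)` batches, a model whose forward returns `(loss, loss_att, loss_ctc)`, and a dict of options. The sketch below uses a hypothetical `ToyModel` and an in-memory batch list to exercise one CPU epoch; all names and hyper-parameters are illustrative and not taken from this repository:

```python
import torch
from wenet.utils.executor import Executor


class ToyModel(torch.nn.Module):
    # Hypothetical stand-in exposing the interface Executor.train expects:
    # forward(feats, feats_lengths, target, target_lengths)
    #   -> (loss, loss_att, loss_ctc)
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 4)

    def forward(self, feats, feats_lengths, target, target_lengths):
        logits = self.proj(feats).mean(dim=1)  # (B, 4)
        loss = torch.nn.functional.cross_entropy(logits, target[:, 0])
        return loss, None, None


model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# A plain list works as data_loader: Executor only needs len() and iteration.
batches = [(
    ['utt%d' % i],
    torch.randn(1, 50, 80),      # feats (B, T, feat_dim)
    torch.tensor([[1, 2, 3]]),   # target (B, Lmax)
    torch.tensor([50]),          # feats_lengths
    torch.tensor([3]),           # target_lengths
) for i in range(4)]

args = {'grad_clip': 5.0, 'accum_grad': 1, 'log_interval': 2,
        'rank': 0, 'is_distributed': False, 'use_amp': False}
Executor().train(model, optimizer, scheduler, batches, torch.device('cpu'),
                 writer=None, args=args, scaler=None)
```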
/wenet/utils/mask.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2019 Shigeki Karita
4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | import torch
7 |
8 |
9 | def subsequent_mask(
10 | size: int,
11 | device: torch.device = torch.device("cpu"),
12 | ) -> torch.Tensor:
13 | """Create mask for subsequent steps (size, size).
14 |
15 | This mask is used only in decoder which works in an auto-regressive mode.
16 | This means the current step could only do attention with its left steps.
17 |
18 | In encoder, fully attention is used when streaming is not necessary and
19 | the sequence is not long. In this case, no attention mask is needed.
20 |
21 | When streaming is need, chunk-based attention is used in encoder. See
22 | subsequent_chunk_mask for the chunk-based attention mask.
23 |
24 | Args:
25 | size (int): size of mask
26 | device (torch.device): "cpu" or "cuda" or
27 | torch.Tensor.device of the returned mask
28 |
29 | Returns:
30 | torch.Tensor: mask
31 |
32 | Examples:
33 | >>> subsequent_mask(3)
34 | [[1, 0, 0],
35 | [1, 1, 0],
36 | [1, 1, 1]]
37 | """
38 | ret = torch.ones(size, size, device=device, dtype=torch.bool)
39 | return torch.tril(ret, out=ret)
40 |
41 |
42 | def subsequent_chunk_mask(
43 | size: int,
44 | chunk_size: int,
45 | num_left_chunks: int = -1,
46 | device: torch.device = torch.device("cpu"),
47 | ) -> torch.Tensor:
48 | """Create mask for subsequent steps (size, size) with chunk size,
49 | this is for streaming encoder
50 |
51 | Args:
52 | size (int): size of mask
53 | chunk_size (int): size of chunk
54 | num_left_chunks (int): number of left chunks
55 | <0: use full chunk
56 | >=0: use num_left_chunks
57 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device
58 |
59 | Returns:
60 | torch.Tensor: mask
61 |
62 | Examples:
63 | >>> subsequent_chunk_mask(4, 2)
64 | [[1, 1, 0, 0],
65 | [1, 1, 0, 0],
66 | [1, 1, 1, 1],
67 | [1, 1, 1, 1]]
68 | """
69 | ret = torch.zeros(size, size, device=device, dtype=torch.bool)
70 | for i in range(size):
71 | if num_left_chunks < 0:
72 | start = 0
73 | else:
74 | start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
75 | ending = min((i // chunk_size + 1) * chunk_size, size)
76 | ret[i, start:ending] = True
77 | return ret
78 |
79 |
80 | def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
81 | use_dynamic_chunk: bool,
82 | use_dynamic_left_chunk: bool,
83 | decoding_chunk_size: int, static_chunk_size: int,
84 | num_decoding_left_chunks: int):
85 | """ Apply optional mask for encoder.
86 |
87 | Args:
88 | xs (torch.Tensor): padded input, (B, L, D), L for max length
89 | masks (torch.Tensor): mask for xs, (B, 1, L)
90 | use_dynamic_chunk (bool): whether to use dynamic chunk or not
91 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
92 | training.
93 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
94 | 0: default for training, use random dynamic chunk.
95 | <0: for decoding, use full chunk.
96 | >0: for decoding, use fixed chunk size as set.
97 | static_chunk_size (int): chunk size for static chunk training/decoding
98 | if it's greater than 0, if use_dynamic_chunk is true,
99 | this parameter will be ignored
100 | num_decoding_left_chunks: number of left chunks, this is for decoding,
101 | the chunk size is decoding_chunk_size.
102 | >=0: use num_decoding_left_chunks
103 | <0: use all left chunks
104 |
105 | Returns:
106 | torch.Tensor: chunk mask of the input xs.
107 | """
108 | # Whether to use chunk mask or not
109 | if use_dynamic_chunk:
110 | max_len = xs.size(1)
111 | if decoding_chunk_size < 0:
112 | chunk_size = max_len
113 | num_left_chunks = -1
114 | elif decoding_chunk_size > 0:
115 | chunk_size = decoding_chunk_size
116 | num_left_chunks = num_decoding_left_chunks
117 | else:
118 | # chunk size is either [1, 25] or full context(max_len).
119 | # Since we use 4 times subsampling and allow up to 1s(100 frames)
120 | # delay, the maximum frame is 100 / 4 = 25.
121 | chunk_size = torch.randint(1, max_len, (1, )).item()
122 | num_left_chunks = -1
123 | if chunk_size > max_len // 2:
124 | chunk_size = max_len
125 | else:
126 | chunk_size = chunk_size % 25 + 1
127 | if use_dynamic_left_chunk:
128 | max_left_chunks = (max_len - 1) // chunk_size
129 | num_left_chunks = torch.randint(0, max_left_chunks,
130 | (1, )).item()
131 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
132 | num_left_chunks,
133 | xs.device) # (L, L)
134 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
135 | chunk_masks = masks & chunk_masks # (B, L, L)
136 | elif static_chunk_size > 0:
137 | num_left_chunks = num_decoding_left_chunks
138 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
139 | num_left_chunks,
140 | xs.device) # (L, L)
141 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
142 | chunk_masks = masks & chunk_masks # (B, L, L)
143 | else:
144 | chunk_masks = masks
145 | return chunk_masks
146 |
147 |
148 | def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
149 | """Make mask tensor containing indices of padded part.
150 |
151 | See description of make_non_pad_mask.
152 |
153 | Args:
154 | lengths (torch.Tensor): Batch of lengths (B,).
155 | Returns:
156 | torch.Tensor: Mask tensor containing indices of padded part.
157 |
158 | Examples:
159 | >>> lengths = [5, 3, 2]
160 | >>> make_pad_mask(lengths)
161 | masks = [[0, 0, 0, 0 ,0],
162 | [0, 0, 0, 1, 1],
163 | [0, 0, 1, 1, 1]]
164 | """
165 | batch_size = int(lengths.size(0))
166 | max_len = int(lengths.max().item())
167 | seq_range = torch.arange(0,
168 | max_len,
169 | dtype=torch.int64,
170 | device=lengths.device)
171 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
172 | seq_length_expand = lengths.unsqueeze(-1)
173 | mask = seq_range_expand >= seq_length_expand
174 | return mask
175 |
176 |
177 | def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
178 | """Make mask tensor containing indices of non-padded part.
179 |
180 | The sequences in a batch may have different lengths. To enable
181 | batch computing, padding is needed to make all sequences the same
182 | size. To prevent the padding part from passing values to context
183 | dependent blocks such as attention or convolution, the padding
184 | part is masked.
185 |
186 | This pad_mask is used in both encoder and decoder.
187 |
188 | 1 for non-padded part and 0 for padded part.
189 |
190 | Args:
191 | lengths (torch.Tensor): Batch of lengths (B,).
192 | Returns:
193 | torch.Tensor: mask tensor containing indices of padded part.
194 |
195 | Examples:
196 | >>> lengths = [5, 3, 2]
197 | >>> make_non_pad_mask(lengths)
198 | masks = [[1, 1, 1, 1 ,1],
199 | [1, 1, 1, 0, 0],
200 | [1, 1, 0, 0, 0]]
201 | """
202 | return ~make_pad_mask(lengths)
203 |
204 |
205 | def mask_finished_scores(score: torch.Tensor,
206 | flag: torch.Tensor) -> torch.Tensor:
207 | """
208 | If a sequence is finished, we only allow one alive branch. This function
209 | aims to give one branch a zero score and the rest -inf score.
210 |
211 | Args:
212 | score (torch.Tensor): A real value array with shape
213 | (batch_size * beam_size, beam_size).
214 | flag (torch.Tensor): A bool array with shape
215 | (batch_size * beam_size, 1).
216 |
217 | Returns:
218 | torch.Tensor: (batch_size * beam_size, beam_size).
219 | """
220 | beam_size = score.size(-1)
221 | zero_mask = torch.zeros_like(flag, dtype=torch.bool)
222 | if beam_size > 1:
223 | unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
224 | dim=1)
225 | finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
226 | dim=1)
227 | else:
228 | unfinished = zero_mask
229 | finished = flag
230 | score.masked_fill_(unfinished, -float('inf'))
231 | score.masked_fill_(finished, 0)
232 | return score
233 |
234 |
235 | def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
236 | eos: int) -> torch.Tensor:
237 | """
238 | If a sequence is finished, all of its branch should be
239 |
240 | Args:
241 | pred (torch.Tensor): A int array with shape
242 | (batch_size * beam_size, beam_size).
243 | flag (torch.Tensor): A bool array with shape
244 | (batch_size * beam_size, 1).
245 |
246 | Returns:
247 | torch.Tensor: (batch_size * beam_size).
248 | """
249 | beam_size = pred.size(-1)
250 | finished = flag.repeat([1, beam_size])
251 | return pred.masked_fill_(finished, eos)
252 |
--------------------------------------------------------------------------------
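`add_optional_chunk_mask` has no example in its docstring, so here is a small sketch combining it with `make_non_pad_mask` for static chunk-based (streaming-style) training; all numbers are illustrative:

```python
import torch
from wenet.utils.mask import make_non_pad_mask, add_optional_chunk_mask

xs = torch.randn(2, 6, 80)                                    # (B, L, D)
masks = make_non_pad_mask(torch.tensor([6, 4])).unsqueeze(1)  # (B, 1, L)

chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2,
    num_decoding_left_chunks=-1)
print(chunk_masks.shape)  # torch.Size([2, 6, 6])
# Each frame may attend to its own 2-frame chunk and every chunk to its left,
# and the padded frames of the shorter utterance stay masked out.
```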
/wenet/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | from torch.optim.lr_scheduler import _LRScheduler
5 |
6 | from typeguard import check_argument_types
7 |
8 |
9 | class WarmupLR(_LRScheduler):
10 | """The WarmupLR scheduler
11 |
12 | This scheduler is almost the same as the NoamLR scheduler except for
13 | the following difference:
14 |
15 | NoamLR:
16 | lr = optimizer.lr * model_size ** -0.5
17 | * min(step ** -0.5, step * warmup_step ** -1.5)
18 | WarmupLR:
19 | lr = optimizer.lr * warmup_step ** 0.5
20 | * min(step ** -0.5, step * warmup_step ** -1.5)
21 |
22 | Note that the maximum lr equals optimizer.lr in this scheduler.
23 |
24 | """
25 |
26 | def __init__(
27 | self,
28 | optimizer: torch.optim.Optimizer,
29 | warmup_steps: Union[int, float] = 25000,
30 | last_epoch: int = -1,
31 | ):
32 | assert check_argument_types()
33 | self.warmup_steps = warmup_steps
34 |
35 | # __init__() must be invoked before setting field
36 | # because step() is also invoked in __init__()
37 | super().__init__(optimizer, last_epoch)
38 |
39 | def __repr__(self):
40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
41 |
42 | def get_lr(self):
43 | step_num = self.last_epoch + 1
44 | return [
45 | lr
46 | * self.warmup_steps ** 0.5
47 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
48 | for lr in self.base_lrs
49 | ]
50 |
51 | def set_step(self, step: int):
52 | self.last_epoch = step
53 |
--------------------------------------------------------------------------------
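A minimal usage sketch of `WarmupLR` (toy hyper-parameters): the learning rate ramps up for `warmup_steps` optimizer steps and then decays proportionally to step ** -0.5, peaking at the optimizer's base lr:

```python
import torch
from wenet.utils.scheduler import WarmupLR

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=4)

for step in range(8):
    optimizer.step()      # normally preceded by a forward/backward pass
    scheduler.step()
    print(step, optimizer.param_groups[0]['lr'])
# lr rises towards 1e-3 during the first 4 steps, then decays as step ** -0.5.
```

Note that `WarmupLR` also requires the `typeguard` package (it calls `check_argument_types` at construction time).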