├── .gitignore ├── LICENSE ├── README.md ├── benchmarks └── LRS3 │ ├── language_models │ └── README.md │ └── models │ └── README.md ├── configs └── LRS3_V_WER19.1.ini ├── espnet ├── asr │ └── asr_utils.py ├── nets │ ├── batch_beam_search.py │ ├── beam_search.py │ ├── ctc_prefix_score.py │ ├── e2e_asr_common.py │ ├── lm_interface.py │ ├── pytorch_backend │ │ ├── backbones │ │ │ ├── conv1d_extractor.py │ │ │ ├── conv3d_extractor.py │ │ │ └── modules │ │ │ │ ├── resnet.py │ │ │ │ ├── resnet1d.py │ │ │ │ └── shufflenetv2.py │ │ ├── ctc.py │ │ ├── e2e_asr_transformer.py │ │ ├── e2e_asr_transformer_av.py │ │ ├── lm │ │ │ ├── __init__.py │ │ │ ├── default.py │ │ │ ├── seq_rnn.py │ │ │ └── transformer.py │ │ ├── nets_utils.py │ │ └── transformer │ │ │ ├── __init__.py │ │ │ ├── add_sos_eos.py │ │ │ ├── attention.py │ │ │ ├── convolution.py │ │ │ ├── decoder.py │ │ │ ├── decoder_layer.py │ │ │ ├── embedding.py │ │ │ ├── encoder.py │ │ │ ├── encoder_layer.py │ │ │ ├── label_smoothing_loss.py │ │ │ ├── layer_norm.py │ │ │ ├── mask.py │ │ │ ├── multi_layer_conv.py │ │ │ ├── optimizer.py │ │ │ ├── plot.py │ │ │ ├── positionwise_feed_forward.py │ │ │ ├── raw_embeddings.py │ │ │ ├── repeat.py │ │ │ └── subsampling.py │ ├── scorer_interface.py │ └── scorers │ │ ├── __init__.py │ │ ├── ctc.py │ │ └── length_bonus.py └── utils │ ├── cli_utils.py │ ├── dynamic_import.py │ └── fill_missing_args.py ├── hydra_configs └── default.yaml ├── main.py ├── pipelines ├── data │ ├── data_module.py │ ├── noise │ │ ├── babble_noise.wav │ │ ├── pink_noise.wav │ │ └── white_noise.wav │ └── transforms.py ├── detectors │ ├── mediapipe │ │ ├── 20words_mean_face.npy │ │ ├── detector.py │ │ └── video_process.py │ └── retinaface │ │ ├── 20words_mean_face.npy │ │ ├── detector.py │ │ └── video_process.py ├── metrics │ └── measures.py ├── model.py ├── pipeline.py └── tokens │ └── unigram5000_units.txt ├── requirements.txt └── thumbnail.png /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | lib/ 3 | include/ 4 | pyvenv.cfg 5 | .Python 6 | *.py[cod] 7 | __pycache__/ 8 | *.so 9 | .Python 10 | pip-log.txt 11 | .env 12 | .venv 13 | env/ 14 | venv/ 15 | ENV/ 16 | env.bak/ 17 | venv.bak/ 18 | .env.local 19 | .env.development.local 20 | .env.test.local 21 | .env.production.local 22 | *.pyc 23 | .DS_Store 24 | .idea/ 25 | .vscode/ 26 | *.swp 27 | *.swo 28 | dist/ 29 | build/ 30 | *.egg 31 | *.egg-info/ 32 | benchmarks/LRS3/language_models/lm_en_subword/ 33 | benchmarks/LRS3/models/LRS3_V_WER19.1/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Amanvir Parhar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chaplin 2 | 3 | ![Chaplin Thumbnail](./thumbnail.png) 4 | 5 | A visual speech recognition (VSR) tool that reads your lips in real-time and types whatever you silently mouth. Runs fully locally. 6 | 7 | Relies on a [model](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages?tab=readme-ov-file#autoavsr-models) trained on the [Lip Reading Sentences 3](https://mmai.io/datasets/lip_reading/) dataset as part of the [Auto-AVSR](https://github.com/mpc001/auto_avsr) project. 8 | 9 | Watch a demo of Chaplin [here](https://youtu.be/qlHi0As2alQ). 10 | 11 | ## Setup 12 | 13 | 1. Clone the repository, and `cd` into it: 14 | ```bash 15 | git clone https://github.com/amanvirparhar/chaplin 16 | cd chaplin 17 | ``` 18 | 2. Download the required model components: [LRS3_V_WER19.1](https://drive.google.com/file/d/1t8RHhzDTTvOQkLQhmK1LZGnXRRXOXGi6/view) and [lm_en_subword](https://drive.google.com/file/d/1g31HGxJnnOwYl17b70ObFQZ1TSnPvRQv/view). 19 | 3. Unzip both folders, and place them in their respective directories: 20 | ``` 21 | chaplin/ 22 | ├── benchmarks/ 23 | ├── LRS3/ 24 | ├── language_models/ 25 | ├── lm_en_subword/ 26 | ├── models/ 27 | ├── LRS3_V_WER19.1/ 28 | ├── ... 29 | ``` 30 | 4. Install and run `ollama`, and pull the [`llama3.2`](https://ollama.com/library/llama3.2) model. 31 | 5. Install [`uv`](https://github.com/astral-sh/uv). 32 | 33 | ## Usage 34 | 35 | 1. Run the following command: 36 | ```bash 37 | sudo uv run --with-requirements requirements.txt --python 3.12 main.py config_filename=./configs/LRS3_V_WER19.1.ini detector=mediapipe 38 | ``` 39 | 2. Once the camera feed is displayed, you can start "recording" by pressing the `option` key (Mac) or the `alt` key (Windows/Linux), and start mouthing words. 40 | 3. To stop recording, press the `option` key (Mac) or the `alt` key (Windows/Linux) again. You should see some text being typed out wherever your cursor is. 41 | 4. To exit gracefully, focus on the window displaying the camera feed and press `q`. 42 | -------------------------------------------------------------------------------- /benchmarks/LRS3/language_models/README.md: -------------------------------------------------------------------------------- 1 | Put the `lm_en_subword` folder in this directory. -------------------------------------------------------------------------------- /benchmarks/LRS3/models/README.md: -------------------------------------------------------------------------------- 1 | Put the `LRS3_V_WER19.1` folder in this directory. 
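
After placing both downloaded folders, it can help to confirm that the paths referenced by `configs/LRS3_V_WER19.1.ini` actually resolve before running `main.py`. The snippet below is an illustrative sketch, not part of the repository; the config keys (`model_path`, `model_conf`, `rnnlm`, `rnnlm_conf`) are the ones that appear in the `[model]` section of the shipped config file.

```python
# check_models.py -- illustrative sketch, not shipped with this repo.
# Verifies that the model and language-model files referenced by the
# .ini config exist in the directories described in the setup steps.
import configparser
from pathlib import Path

config = configparser.ConfigParser()
config.read("configs/LRS3_V_WER19.1.ini")

# Keys taken from the [model] section of configs/LRS3_V_WER19.1.ini.
for key in ("model_path", "model_conf", "rnnlm", "rnnlm_conf"):
    path = Path(config["model"][key])
    status = "ok" if path.is_file() else "MISSING"
    print(f"{key:12s} {status}  {path}")
```

If any entry prints `MISSING`, re-check that the unzipped `LRS3_V_WER19.1` and `lm_en_subword` folders sit at the locations shown in the setup tree above.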
-------------------------------------------------------------------------------- /configs/LRS3_V_WER19.1.ini: -------------------------------------------------------------------------------- 1 | [input] 2 | modality=video 3 | v_fps=25 4 | 5 | [model] 6 | v_fps=25 7 | model_path=benchmarks/LRS3/models/LRS3_V_WER19.1/model.pth 8 | model_conf=benchmarks/LRS3/models/LRS3_V_WER19.1/model.json 9 | rnnlm=benchmarks/LRS3/language_models/lm_en_subword/model.pth 10 | rnnlm_conf=benchmarks/LRS3/language_models/lm_en_subword/model.json 11 | 12 | [decode] 13 | beam_size=40 14 | penalty=0.0 15 | maxlenratio=0.0 16 | minlenratio=0.0 17 | ctc_weight=0.1 18 | lm_weight=0.3 19 | -------------------------------------------------------------------------------- /espnet/nets/e2e_asr_common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Common functions for ASR.""" 8 | 9 | import json 10 | import logging 11 | import sys 12 | 13 | from itertools import groupby 14 | import numpy as np 15 | import six 16 | 17 | 18 | def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): 19 | """End detection. 20 | 21 | described in Eq. (50) of S. Watanabe et al 22 | "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" 23 | 24 | :param ended_hyps: 25 | :param i: 26 | :param M: 27 | :param D_end: 28 | :return: 29 | """ 30 | if len(ended_hyps) == 0: 31 | return False 32 | count = 0 33 | best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0] 34 | for m in six.moves.range(M): 35 | # get ended_hyps with their length is i - m 36 | hyp_length = i - m 37 | hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length] 38 | if len(hyps_same_length) > 0: 39 | best_hyp_same_length = sorted( 40 | hyps_same_length, key=lambda x: x["score"], reverse=True 41 | )[0] 42 | if best_hyp_same_length["score"] - best_hyp["score"] < D_end: 43 | count += 1 44 | 45 | if count == M: 46 | return True 47 | else: 48 | return False 49 | 50 | 51 | # TODO(takaaki-hori): add different smoothing methods 52 | def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): 53 | """Obtain label distribution for loss smoothing. 54 | 55 | :param odim: 56 | :param lsm_type: 57 | :param blank: 58 | :param transcript: 59 | :return: 60 | """ 61 | if transcript is not None: 62 | with open(transcript, "rb") as f: 63 | trans_json = json.load(f)["utts"] 64 | 65 | if lsm_type == "unigram": 66 | assert transcript is not None, ( 67 | "transcript is required for %s label smoothing" % lsm_type 68 | ) 69 | labelcount = np.zeros(odim) 70 | for k, v in trans_json.items(): 71 | ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) 72 | # to avoid an error when there is no text in an uttrance 73 | if len(ids) > 0: 74 | labelcount[ids] += 1 75 | labelcount[odim - 1] = len(transcript) # count 76 | labelcount[labelcount == 0] = 1 # flooring 77 | labelcount[blank] = 0 # remove counts for blank 78 | labeldist = labelcount.astype(np.float32) / np.sum(labelcount) 79 | else: 80 | logging.error("Error: unexpected label smoothing type: %s" % lsm_type) 81 | sys.exit() 82 | 83 | return labeldist 84 | 85 | 86 | def get_vgg2l_odim(idim, in_channel=3, out_channel=128): 87 | """Return the output size of the VGG frontend. 
88 | 89 | :param in_channel: input channel size 90 | :param out_channel: output channel size 91 | :return: output size 92 | :rtype int 93 | """ 94 | idim = idim / in_channel 95 | idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling 96 | idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling 97 | return int(idim) * out_channel # numer of channels 98 | 99 | 100 | class ErrorCalculator(object): 101 | """Calculate CER and WER for E2E_ASR and CTC models during training. 102 | 103 | :param y_hats: numpy array with predicted text 104 | :param y_pads: numpy array with true (target) text 105 | :param char_list: 106 | :param sym_space: 107 | :param sym_blank: 108 | :return: 109 | """ 110 | 111 | def __init__( 112 | self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False 113 | ): 114 | """Construct an ErrorCalculator object.""" 115 | super(ErrorCalculator, self).__init__() 116 | 117 | self.report_cer = report_cer 118 | self.report_wer = report_wer 119 | 120 | self.char_list = char_list 121 | self.space = sym_space 122 | self.blank = sym_blank 123 | self.idx_blank = self.char_list.index(self.blank) 124 | if self.space in self.char_list: 125 | self.idx_space = self.char_list.index(self.space) 126 | else: 127 | self.idx_space = None 128 | 129 | def __call__(self, ys_hat, ys_pad, is_ctc=False): 130 | """Calculate sentence-level WER/CER score. 131 | 132 | :param torch.Tensor ys_hat: prediction (batch, seqlen) 133 | :param torch.Tensor ys_pad: reference (batch, seqlen) 134 | :param bool is_ctc: calculate CER score for CTC 135 | :return: sentence-level WER score 136 | :rtype float 137 | :return: sentence-level CER score 138 | :rtype float 139 | """ 140 | cer, wer = None, None 141 | if is_ctc: 142 | return self.calculate_cer_ctc(ys_hat, ys_pad) 143 | elif not self.report_cer and not self.report_wer: 144 | return cer, wer 145 | 146 | seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) 147 | if self.report_cer: 148 | cer = self.calculate_cer(seqs_hat, seqs_true) 149 | 150 | if self.report_wer: 151 | wer = self.calculate_wer(seqs_hat, seqs_true) 152 | return cer, wer 153 | 154 | def calculate_cer_ctc(self, ys_hat, ys_pad): 155 | """Calculate sentence-level CER score for CTC. 156 | 157 | :param torch.Tensor ys_hat: prediction (batch, seqlen) 158 | :param torch.Tensor ys_pad: reference (batch, seqlen) 159 | :return: average sentence-level CER score 160 | :rtype float 161 | """ 162 | import editdistance 163 | 164 | cers, char_ref_lens = [], [] 165 | for i, y in enumerate(ys_hat): 166 | y_hat = [x[0] for x in groupby(y)] 167 | y_true = ys_pad[i] 168 | seq_hat, seq_true = [], [] 169 | for idx in y_hat: 170 | idx = int(idx) 171 | if idx != -1 and idx != self.idx_blank and idx != self.idx_space: 172 | seq_hat.append(self.char_list[int(idx)]) 173 | 174 | for idx in y_true: 175 | idx = int(idx) 176 | if idx != -1 and idx != self.idx_blank and idx != self.idx_space: 177 | seq_true.append(self.char_list[int(idx)]) 178 | 179 | hyp_chars = "".join(seq_hat) 180 | ref_chars = "".join(seq_true) 181 | if len(ref_chars) > 0: 182 | cers.append(editdistance.eval(hyp_chars, ref_chars)) 183 | char_ref_lens.append(len(ref_chars)) 184 | 185 | cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None 186 | return cer_ctc 187 | 188 | def convert_to_char(self, ys_hat, ys_pad): 189 | """Convert index to character. 
190 | 191 | :param torch.Tensor seqs_hat: prediction (batch, seqlen) 192 | :param torch.Tensor seqs_true: reference (batch, seqlen) 193 | :return: token list of prediction 194 | :rtype list 195 | :return: token list of reference 196 | :rtype list 197 | """ 198 | seqs_hat, seqs_true = [], [] 199 | for i, y_hat in enumerate(ys_hat): 200 | y_true = ys_pad[i] 201 | eos_true = np.where(y_true == -1)[0] 202 | ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) 203 | # NOTE: padding index (-1) in y_true is used to pad y_hat 204 | seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] 205 | seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] 206 | seq_hat_text = "".join(seq_hat).replace(self.space, " ") 207 | seq_hat_text = seq_hat_text.replace(self.blank, "") 208 | seq_true_text = "".join(seq_true).replace(self.space, " ") 209 | seqs_hat.append(seq_hat_text) 210 | seqs_true.append(seq_true_text) 211 | return seqs_hat, seqs_true 212 | 213 | def calculate_cer(self, seqs_hat, seqs_true): 214 | """Calculate sentence-level CER score. 215 | 216 | :param list seqs_hat: prediction 217 | :param list seqs_true: reference 218 | :return: average sentence-level CER score 219 | :rtype float 220 | """ 221 | import editdistance 222 | 223 | char_eds, char_ref_lens = [], [] 224 | for i, seq_hat_text in enumerate(seqs_hat): 225 | seq_true_text = seqs_true[i] 226 | hyp_chars = seq_hat_text.replace(" ", "") 227 | ref_chars = seq_true_text.replace(" ", "") 228 | char_eds.append(editdistance.eval(hyp_chars, ref_chars)) 229 | char_ref_lens.append(len(ref_chars)) 230 | return float(sum(char_eds)) / sum(char_ref_lens) 231 | 232 | def calculate_wer(self, seqs_hat, seqs_true): 233 | """Calculate sentence-level WER score. 234 | 235 | :param list seqs_hat: prediction 236 | :param list seqs_true: reference 237 | :return: average sentence-level WER score 238 | :rtype float 239 | """ 240 | import editdistance 241 | 242 | word_eds, word_ref_lens = [], [] 243 | for i, seq_hat_text in enumerate(seqs_hat): 244 | seq_true_text = seqs_true[i] 245 | hyp_words = seq_hat_text.split() 246 | ref_words = seq_true_text.split() 247 | word_eds.append(editdistance.eval(hyp_words, ref_words)) 248 | word_ref_lens.append(len(ref_words)) 249 | return float(sum(word_eds)) / sum(word_ref_lens) 250 | -------------------------------------------------------------------------------- /espnet/nets/lm_interface.py: -------------------------------------------------------------------------------- 1 | """Language model interface.""" 2 | 3 | import argparse 4 | 5 | from espnet.nets.scorer_interface import ScorerInterface 6 | from espnet.utils.dynamic_import import dynamic_import 7 | from espnet.utils.fill_missing_args import fill_missing_args 8 | 9 | 10 | class LMInterface(ScorerInterface): 11 | """LM Interface for ESPnet model implementation.""" 12 | 13 | @staticmethod 14 | def add_arguments(parser): 15 | """Add arguments to command line argument parser.""" 16 | return parser 17 | 18 | @classmethod 19 | def build(cls, n_vocab: int, **kwargs): 20 | """Initialize this class with python-level args. 21 | 22 | Args: 23 | idim (int): The number of vocabulary. 24 | 25 | Returns: 26 | LMinterface: A new instance of LMInterface. 
27 | 28 | """ 29 | # local import to avoid cyclic import in lm_train 30 | from espnet.bin.lm_train import get_parser 31 | 32 | def wrap(parser): 33 | return get_parser(parser, required=False) 34 | 35 | args = argparse.Namespace(**kwargs) 36 | args = fill_missing_args(args, wrap) 37 | args = fill_missing_args(args, cls.add_arguments) 38 | return cls(n_vocab, args) 39 | 40 | def forward(self, x, t): 41 | """Compute LM loss value from buffer sequences. 42 | 43 | Args: 44 | x (torch.Tensor): Input ids. (batch, len) 45 | t (torch.Tensor): Target ids. (batch, len) 46 | 47 | Returns: 48 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 49 | loss to backward (scalar), 50 | negative log-likelihood of t: -log p(t) (scalar) and 51 | the number of elements in x (scalar) 52 | 53 | Notes: 54 | The last two return values are used 55 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 56 | 57 | """ 58 | raise NotImplementedError("forward method is not implemented") 59 | 60 | 61 | predefined_lms = { 62 | "pytorch": { 63 | "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM", 64 | "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM", 65 | "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM", 66 | }, 67 | "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"}, 68 | } 69 | 70 | 71 | def dynamic_import_lm(module, backend): 72 | """Import LM class dynamically. 73 | 74 | Args: 75 | module (str): module_name:class_name or alias in `predefined_lms` 76 | backend (str): NN backend. e.g., pytorch, chainer 77 | 78 | Returns: 79 | type: LM class 80 | 81 | """ 82 | model_class = dynamic_import(module, predefined_lms.get(backend, dict())) 83 | assert issubclass( 84 | model_class, LMInterface 85 | ), f"{module} does not implement LMInterface" 86 | return model_class 87 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/conv1d_extractor.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | import torch 7 | from espnet.nets.pytorch_backend.backbones.modules.resnet1d import ResNet1D, BasicBlock1D 8 | 9 | class Conv1dResNet(torch.nn.Module): 10 | def __init__(self, relu_type="swish", a_upsample_ratio=1): 11 | super().__init__() 12 | self.a_upsample_ratio = a_upsample_ratio 13 | self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio) 14 | 15 | 16 | def forward(self, xs_pad): 17 | """forward. 18 | 19 | :param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim) 20 | """ 21 | B, T, C = xs_pad.size() 22 | xs_pad = xs_pad[:, :T // 640 * 640, :] 23 | xs_pad = xs_pad.transpose(1, 2) 24 | xs_pad = self.trunk(xs_pad) 25 | return xs_pad.transpose(1, 2) 26 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/conv3d_extractor.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import torch 8 | import torch.nn as nn 9 | from espnet.nets.pytorch_backend.backbones.modules.resnet import ResNet, BasicBlock 10 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 11 | 12 | 13 | def threeD_to_2D_tensor(x): 14 | n_batch, n_channels, s_time, sx, sy = x.shape 15 | x = x.transpose(1, 2) 16 | return x.reshape(n_batch * s_time, n_channels, sx, sy) 17 | 18 | 19 | 20 | class Conv3dResNet(torch.nn.Module): 21 | """Conv3dResNet module 22 | """ 23 | 24 | def __init__(self, backbone_type="resnet", relu_type="swish"): 25 | """__init__. 26 | 27 | :param backbone_type: str, the type of a visual front-end. 28 | :param relu_type: str, activation function used in an audio front-end. 29 | """ 30 | super(Conv3dResNet, self).__init__() 31 | self.frontend_nout = 64 32 | self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type) 33 | self.frontend3D = nn.Sequential( 34 | nn.Conv3d(1, self.frontend_nout, (5, 7, 7), (1, 2, 2), (2, 3, 3), bias=False), 35 | nn.BatchNorm3d(self.frontend_nout), 36 | Swish(), 37 | nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1)) 38 | ) 39 | 40 | 41 | def forward(self, xs_pad): 42 | B, C, T, H, W = xs_pad.size() 43 | xs_pad = self.frontend3D(xs_pad) 44 | Tnew = xs_pad.shape[2] 45 | xs_pad = threeD_to_2D_tensor(xs_pad) 46 | xs_pad = self.trunk(xs_pad) 47 | return xs_pad.view(B, Tnew, xs_pad.size(1)) 48 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """conv3x3. 10 | 11 | :param in_planes: int, number of channels in the input sequence. 12 | :param out_planes: int, number of channels produced by the convolution. 13 | :param stride: int, size of the convolving kernel. 14 | """ 15 | return nn.Conv2d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False, 22 | ) 23 | 24 | 25 | def downsample_basic_block(inplanes, outplanes, stride): 26 | """downsample_basic_block. 27 | 28 | :param inplanes: int, number of channels in the input sequence. 29 | :param outplanes: int, number of channels produced by the convolution. 30 | :param stride: int, size of the convolving kernel. 31 | """ 32 | return nn.Sequential( 33 | nn.Conv2d( 34 | inplanes, 35 | outplanes, 36 | kernel_size=1, 37 | stride=stride, 38 | bias=False, 39 | ), 40 | nn.BatchNorm2d(outplanes), 41 | ) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__( 48 | self, 49 | inplanes, 50 | planes, 51 | stride=1, 52 | downsample=None, 53 | relu_type="swish", 54 | ): 55 | """__init__. 56 | 57 | :param inplanes: int, number of channels in the input sequence. 58 | :param planes: int, number of channels produced by the convolution. 59 | :param stride: int, size of the convolving kernel. 60 | :param downsample: boolean, if True, the temporal resolution is downsampled. 61 | :param relu_type: str, type of activation function. 
62 | """ 63 | super(BasicBlock, self).__init__() 64 | 65 | assert relu_type in ["relu", "prelu", "swish"] 66 | 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | 70 | if relu_type == "relu": 71 | self.relu1 = nn.ReLU(inplace=True) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | elif relu_type == "prelu": 74 | self.relu1 = nn.PReLU(num_parameters=planes) 75 | self.relu2 = nn.PReLU(num_parameters=planes) 76 | elif relu_type == "swish": 77 | self.relu1 = Swish() 78 | self.relu2 = Swish() 79 | else: 80 | raise NotImplementedError 81 | # -------- 82 | 83 | self.conv2 = conv3x3(planes, planes) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | """forward. 91 | 92 | :param x: torch.Tensor, input tensor with input size (B, C, T, H, W). 93 | """ 94 | residual = x 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu1(out) 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu2(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__( 112 | self, 113 | block, 114 | layers, 115 | relu_type="swish", 116 | ): 117 | super(ResNet, self).__init__() 118 | self.inplanes = 64 119 | self.relu_type = relu_type 120 | self.downsample_block = downsample_basic_block 121 | 122 | self.layer1 = self._make_layer(block, 64, layers[0]) 123 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 124 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 125 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 126 | self.avgpool = nn.AdaptiveAvgPool2d(1) 127 | 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | """_make_layer. 131 | 132 | :param block: torch.nn.Module, class of blocks. 133 | :param planes: int, number of channels produced by the convolution. 134 | :param blocks: int, number of layers in a block. 135 | :param stride: int, size of the convolving kernel. 136 | """ 137 | downsample = None 138 | if stride != 1 or self.inplanes != planes * block.expansion: 139 | downsample = self.downsample_block( 140 | inplanes=self.inplanes, 141 | outplanes=planes*block.expansion, 142 | stride=stride, 143 | ) 144 | 145 | layers = [] 146 | layers.append( 147 | block( 148 | self.inplanes, 149 | planes, 150 | stride, 151 | downsample, 152 | relu_type=self.relu_type, 153 | ) 154 | ) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append( 158 | block( 159 | self.inplanes, 160 | planes, 161 | relu_type=self.relu_type, 162 | ) 163 | ) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, x): 168 | """forward. 169 | 170 | :param x: torch.Tensor, input tensor with input size (B, C, T, H, W). 171 | """ 172 | x = self.layer1(x) 173 | x = self.layer2(x) 174 | x = self.layer3(x) 175 | x = self.layer4(x) 176 | x = self.avgpool(x) 177 | x = x.view(x.size(0), -1) 178 | return x 179 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/resnet1d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """conv3x3. 
10 | 11 | :param in_planes: int, number of channels in the input sequence. 12 | :param out_planes: int, number of channels produced by the convolution. 13 | :param stride: int, size of the convolving kernel. 14 | """ 15 | return nn.Conv1d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False, 22 | ) 23 | 24 | 25 | def downsample_basic_block(inplanes, outplanes, stride): 26 | """downsample_basic_block. 27 | 28 | :param inplanes: int, number of channels in the input sequence. 29 | :param outplanes: int, number of channels produced by the convolution. 30 | :param stride: int, size of the convolving kernel. 31 | """ 32 | return nn.Sequential( 33 | nn.Conv1d( 34 | inplanes, 35 | outplanes, 36 | kernel_size=1, 37 | stride=stride, 38 | bias=False, 39 | ), 40 | nn.BatchNorm1d(outplanes), 41 | ) 42 | 43 | 44 | class BasicBlock1D(nn.Module): 45 | expansion = 1 46 | 47 | def __init__( 48 | self, 49 | inplanes, 50 | planes, 51 | stride=1, 52 | downsample=None, 53 | relu_type="relu", 54 | ): 55 | """__init__. 56 | 57 | :param inplanes: int, number of channels in the input sequence. 58 | :param planes: int, number of channels produced by the convolution. 59 | :param stride: int, size of the convolving kernel. 60 | :param downsample: boolean, if True, the temporal resolution is downsampled. 61 | :param relu_type: str, type of activation function. 62 | """ 63 | super(BasicBlock1D, self).__init__() 64 | 65 | assert relu_type in ["relu","prelu", "swish"] 66 | 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = nn.BatchNorm1d(planes) 69 | 70 | # type of ReLU is an input option 71 | if relu_type == "relu": 72 | self.relu1 = nn.ReLU(inplace=True) 73 | self.relu2 = nn.ReLU(inplace=True) 74 | elif relu_type == "prelu": 75 | self.relu1 = nn.PReLU(num_parameters=planes) 76 | self.relu2 = nn.PReLU(num_parameters=planes) 77 | elif relu_type == "swish": 78 | self.relu1 = Swish() 79 | self.relu2 = Swish() 80 | else: 81 | raise NotImplementedError 82 | # -------- 83 | 84 | self.conv2 = conv3x3(planes, planes) 85 | self.bn2 = nn.BatchNorm1d(planes) 86 | 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | """forward. 92 | 93 | :param x: torch.Tensor, input tensor with input size (B, C, T) 94 | """ 95 | residual = x 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu1(out) 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu2(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet1D(nn.Module): 111 | 112 | def __init__(self, 113 | block, 114 | layers, 115 | relu_type="swish", 116 | a_upsample_ratio=1, 117 | ): 118 | """__init__. 119 | 120 | :param block: torch.nn.Module, class of blocks. 121 | :param layers: List, customised layers in each block. 122 | :param relu_type: str, type of activation function. 123 | :param a_upsample_ratio: int, The ratio related to the \ 124 | temporal resolution of output features of the frontend. \ 125 | a_upsample_ratio=1 produce features with a fps of 25. 
126 | """ 127 | super(ResNet1D, self).__init__() 128 | self.inplanes = 64 129 | self.relu_type = relu_type 130 | self.downsample_block = downsample_basic_block 131 | self.a_upsample_ratio = a_upsample_ratio 132 | 133 | self.conv1 = nn.Conv1d( 134 | in_channels=1, 135 | out_channels=self.inplanes, 136 | kernel_size=80, 137 | stride=4, 138 | padding=38, 139 | bias=False, 140 | ) 141 | self.bn1 = nn.BatchNorm1d(self.inplanes) 142 | 143 | if relu_type == "relu": 144 | self.relu = nn.ReLU(inplace=True) 145 | elif relu_type == "prelu": 146 | self.relu = nn.PReLU(num_parameters=self.inplanes) 147 | elif relu_type == "swish": 148 | self.relu = Swish() 149 | 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 154 | self.avgpool = nn.AvgPool1d( 155 | kernel_size=20//self.a_upsample_ratio, 156 | stride=20//self.a_upsample_ratio, 157 | ) 158 | 159 | 160 | def _make_layer(self, block, planes, blocks, stride=1): 161 | """_make_layer. 162 | 163 | :param block: torch.nn.Module, class of blocks. 164 | :param planes: int, number of channels produced by the convolution. 165 | :param blocks: int, number of layers in a block. 166 | :param stride: int, size of the convolving kernel. 167 | """ 168 | 169 | downsample = None 170 | if stride != 1 or self.inplanes != planes * block.expansion: 171 | downsample = self.downsample_block( 172 | inplanes=self.inplanes, 173 | outplanes=planes*block.expansion, 174 | stride=stride, 175 | ) 176 | 177 | layers = [] 178 | layers.append( 179 | block( 180 | self.inplanes, 181 | planes, 182 | stride, 183 | downsample, 184 | relu_type=self.relu_type, 185 | ) 186 | ) 187 | self.inplanes = planes * block.expansion 188 | for i in range(1, blocks): 189 | layers.append( 190 | block( 191 | self.inplanes, 192 | planes, 193 | relu_type=self.relu_type, 194 | ) 195 | ) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | """forward. 
201 | 202 | :param x: torch.Tensor, input tensor with input size (B, C, T) 203 | """ 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | 208 | x = self.layer1(x) 209 | x = self.layer2(x) 210 | x = self.layer3(x) 211 | x = self.layer4(x) 212 | x = self.avgpool(x) 213 | return x 214 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | import math 8 | 9 | import pdb 10 | 11 | def conv_bn(inp, oup, stride): 12 | return nn.Sequential( 13 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 14 | nn.BatchNorm2d(oup), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | def conv_1x1_bn(inp, oup): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 22 | nn.BatchNorm2d(oup), 23 | nn.ReLU(inplace=True) 24 | ) 25 | 26 | def channel_shuffle(x, groups): 27 | batchsize, num_channels, height, width = x.data.size() 28 | 29 | channels_per_group = num_channels // groups 30 | 31 | # reshape 32 | x = x.view(batchsize, groups, 33 | channels_per_group, height, width) 34 | 35 | x = torch.transpose(x, 1, 2).contiguous() 36 | 37 | # flatten 38 | x = x.view(batchsize, -1, height, width) 39 | 40 | return x 41 | 42 | class InvertedResidual(nn.Module): 43 | def __init__(self, inp, oup, stride, benchmodel): 44 | super(InvertedResidual, self).__init__() 45 | self.benchmodel = benchmodel 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | oup_inc = oup//2 50 | 51 | if self.benchmodel == 1: 52 | #assert inp == oup_inc 53 | self.banch2 = nn.Sequential( 54 | # pw 55 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 56 | nn.BatchNorm2d(oup_inc), 57 | nn.ReLU(inplace=True), 58 | # dw 59 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 60 | nn.BatchNorm2d(oup_inc), 61 | # pw-linear 62 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 63 | nn.BatchNorm2d(oup_inc), 64 | nn.ReLU(inplace=True), 65 | ) 66 | else: 67 | self.banch1 = nn.Sequential( 68 | # dw 69 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 70 | nn.BatchNorm2d(inp), 71 | # pw-linear 72 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 73 | nn.BatchNorm2d(oup_inc), 74 | nn.ReLU(inplace=True), 75 | ) 76 | 77 | self.banch2 = nn.Sequential( 78 | # pw 79 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 80 | nn.BatchNorm2d(oup_inc), 81 | nn.ReLU(inplace=True), 82 | # dw 83 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 84 | nn.BatchNorm2d(oup_inc), 85 | # pw-linear 86 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 87 | nn.BatchNorm2d(oup_inc), 88 | nn.ReLU(inplace=True), 89 | ) 90 | 91 | @staticmethod 92 | def _concat(x, out): 93 | # concatenate along channel axis 94 | return torch.cat((x, out), 1) 95 | 96 | def forward(self, x): 97 | if 1==self.benchmodel: 98 | x1 = x[:, :(x.shape[1]//2), :, :] 99 | x2 = x[:, (x.shape[1]//2):, :, :] 100 | out = self._concat(x1, self.banch2(x2)) 101 | elif 2==self.benchmodel: 102 | out = self._concat(self.banch1(x), self.banch2(x)) 103 | 104 | return channel_shuffle(out, 2) 105 | 106 | 107 | class ShuffleNetV2(nn.Module): 108 | def __init__(self, n_class=1000, input_size=224, width_mult=2.): 109 | super(ShuffleNetV2, self).__init__() 110 | 111 | assert 
input_size % 32 == 0, "Input size needs to be divisible by 32" 112 | 113 | self.stage_repeats = [4, 8, 4] 114 | # index 0 is invalid and should never be called. 115 | # only used for indexing convenience. 116 | if width_mult == 0.5: 117 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 118 | elif width_mult == 1.0: 119 | self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] 120 | elif width_mult == 1.5: 121 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 122 | elif width_mult == 2.0: 123 | self.stage_out_channels = [-1, 24, 244, 488, 976, 2048] 124 | else: 125 | raise ValueError( 126 | """Width multiplier should be in [0.5, 1.0, 1.5, 2.0]. Current value: {}""".format(width_mult)) 127 | 128 | # building first layer 129 | input_channel = self.stage_out_channels[1] 130 | self.conv1 = conv_bn(3, input_channel, 2) 131 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 132 | 133 | self.features = [] 134 | # building inverted residual blocks 135 | for idxstage in range(len(self.stage_repeats)): 136 | numrepeat = self.stage_repeats[idxstage] 137 | output_channel = self.stage_out_channels[idxstage+2] 138 | for i in range(numrepeat): 139 | if i == 0: 140 | #inp, oup, stride, benchmodel): 141 | self.features.append(InvertedResidual(input_channel, output_channel, 2, 2)) 142 | else: 143 | self.features.append(InvertedResidual(input_channel, output_channel, 1, 1)) 144 | input_channel = output_channel 145 | 146 | 147 | # make it nn.Sequential 148 | self.features = nn.Sequential(*self.features) 149 | 150 | # building last several layers 151 | self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1]) 152 | self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32))) 153 | 154 | # building classifier 155 | self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class)) 156 | 157 | def forward(self, x): 158 | x = self.conv1(x) 159 | x = self.maxpool(x) 160 | x = self.features(x) 161 | x = self.conv_last(x) 162 | x = self.globalpool(x) 163 | x = x.view(-1, self.stage_out_channels[-1]) 164 | x = self.classifier(x) 165 | return x 166 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/ctc.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | import logging 3 | 4 | import numpy as np 5 | import six 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from espnet.nets.pytorch_backend.nets_utils import to_device 10 | 11 | 12 | class CTC(torch.nn.Module): 13 | """CTC module 14 | 15 | :param int odim: dimension of outputs 16 | :param int eprojs: number of encoder projection units 17 | :param float dropout_rate: dropout rate (0.0 ~ 1.0) 18 | :param str ctc_type: builtin or warpctc 19 | :param bool reduce: reduce the CTC loss into a scalar 20 | """ 21 | 22 | def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True): 23 | super().__init__() 24 | self.dropout_rate = dropout_rate 25 | self.loss = None 26 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 27 | self.dropout = torch.nn.Dropout(dropout_rate) 28 | self.probs = None # for visualization 29 | 30 | # In case of Pytorch >= 1.7.0, CTC will be always builtin 31 | self.ctc_type = ( 32 | ctc_type 33 | if LooseVersion(torch.__version__) < LooseVersion("1.7.0") 34 | else "builtin" 35 | ) 36 | 37 | if self.ctc_type == "builtin": 38 | reduction_type = "sum" if reduce else "none" 39 | self.ctc_loss = torch.nn.CTCLoss( 40 | 
reduction=reduction_type, zero_infinity=True 41 | ) 42 | elif self.ctc_type == "cudnnctc": 43 | reduction_type = "sum" if reduce else "none" 44 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 45 | elif self.ctc_type == "warpctc": 46 | import warpctc_pytorch as warp_ctc 47 | 48 | self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) 49 | elif self.ctc_type == "gtnctc": 50 | from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction 51 | 52 | self.ctc_loss = GTNCTCLossFunction.apply 53 | else: 54 | raise ValueError( 55 | 'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type) 56 | ) 57 | 58 | self.ignore_id = -1 59 | self.reduce = reduce 60 | 61 | def loss_fn(self, th_pred, th_target, th_ilen, th_olen): 62 | if self.ctc_type in ["builtin", "cudnnctc"]: 63 | th_pred = th_pred.log_softmax(2) 64 | # Use the deterministic CuDNN implementation of CTC loss to avoid 65 | # [issue#17798](https://github.com/pytorch/pytorch/issues/17798) 66 | with torch.backends.cudnn.flags(deterministic=True): 67 | loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) 68 | # Batch-size average 69 | loss = loss / th_pred.size(1) 70 | return loss 71 | elif self.ctc_type == "warpctc": 72 | return self.ctc_loss(th_pred, th_target, th_ilen, th_olen) 73 | elif self.ctc_type == "gtnctc": 74 | targets = [t.tolist() for t in th_target] 75 | log_probs = torch.nn.functional.log_softmax(th_pred, dim=2) 76 | return self.ctc_loss(log_probs, targets, th_ilen, 0, "none") 77 | else: 78 | raise NotImplementedError 79 | 80 | def forward(self, hs_pad, hlens, ys_pad): 81 | """CTC forward 82 | 83 | :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) 84 | :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) 85 | :param torch.Tensor ys_pad: 86 | batch of padded character id sequence tensor (B, Lmax) 87 | :return: ctc loss value 88 | :rtype: torch.Tensor 89 | """ 90 | # TODO(kan-bayashi): need to make more smart way 91 | ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys 92 | 93 | # zero padding for hs 94 | ys_hat = self.ctc_lo(self.dropout(hs_pad)) 95 | if self.ctc_type != "gtnctc": 96 | ys_hat = ys_hat.transpose(0, 1) 97 | 98 | if self.ctc_type == "builtin": 99 | olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys])) 100 | hlens = hlens.long() 101 | ys_pad = torch.cat(ys) # without this the code breaks for asr_mix 102 | self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens) 103 | else: 104 | self.loss = None 105 | hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32)) 106 | olens = torch.from_numpy( 107 | np.fromiter((x.size(0) for x in ys), dtype=np.int32) 108 | ) 109 | # zero padding for ys 110 | ys_true = torch.cat(ys).cpu().int() # batch x olen 111 | # get ctc loss 112 | # expected shape of seqLength x batchSize x alphabet_size 113 | dtype = ys_hat.dtype 114 | if self.ctc_type == "warpctc" or dtype == torch.float16: 115 | # warpctc only supports float32 116 | # torch.ctc does not support float16 (#1751) 117 | ys_hat = ys_hat.to(dtype=torch.float32) 118 | if self.ctc_type == "cudnnctc": 119 | # use GPU when using the cuDNN implementation 120 | ys_true = to_device(hs_pad, ys_true) 121 | if self.ctc_type == "gtnctc": 122 | # keep as list for gtn 123 | ys_true = ys 124 | self.loss = to_device( 125 | hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens) 126 | ).to(dtype=dtype) 127 | 128 | # get length info 129 | logging.info( 130 | self.__class__.__name__ 131 | + " input lengths: " 132 | + 
"".join(str(hlens).split("\n")) 133 | ) 134 | logging.info( 135 | self.__class__.__name__ 136 | + " output lengths: " 137 | + "".join(str(olens).split("\n")) 138 | ) 139 | 140 | if self.reduce: 141 | # NOTE: sum() is needed to keep consistency 142 | # since warpctc return as tensor w/ shape (1,) 143 | # but builtin return as tensor w/o shape (scalar). 144 | self.loss = self.loss.sum() 145 | logging.info("ctc loss:" + str(float(self.loss))) 146 | 147 | return self.loss 148 | 149 | def softmax(self, hs_pad): 150 | """softmax of frame activations 151 | 152 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 153 | :return: log softmax applied 3d tensor (B, Tmax, odim) 154 | :rtype: torch.Tensor 155 | """ 156 | self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2) 157 | return self.probs 158 | 159 | def log_softmax(self, hs_pad): 160 | """log_softmax of frame activations 161 | 162 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 163 | :return: log softmax applied 3d tensor (B, Tmax, odim) 164 | :rtype: torch.Tensor 165 | """ 166 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 167 | 168 | def argmax(self, hs_pad): 169 | """argmax of frame activations 170 | 171 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 172 | :return: argmax applied 2d tensor (B, Tmax) 173 | :rtype: torch.Tensor 174 | """ 175 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 176 | 177 | def forced_align(self, h, y, blank_id=0): 178 | """forced alignment. 179 | 180 | :param torch.Tensor h: hidden state sequence, 2d tensor (T, D) 181 | :param torch.Tensor y: id sequence tensor 1d tensor (L) 182 | :param int y: blank symbol index 183 | :return: best alignment results 184 | :rtype: list 185 | """ 186 | 187 | def interpolate_blank(label, blank_id=0): 188 | """Insert blank token between every two label token.""" 189 | label = np.expand_dims(label, 1) 190 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 191 | label = np.concatenate([blanks, label], axis=1) 192 | label = label.reshape(-1) 193 | label = np.append(label, label[0]) 194 | return label 195 | 196 | lpz = self.log_softmax(h) 197 | lpz = lpz.squeeze(0) 198 | 199 | y_int = interpolate_blank(y, blank_id) 200 | 201 | logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero 202 | state_path = ( 203 | np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1 204 | ) # state path 205 | 206 | logdelta[0, 0] = lpz[0][y_int[0]] 207 | logdelta[0, 1] = lpz[0][y_int[1]] 208 | 209 | for t in six.moves.range(1, lpz.size(0)): 210 | for s in six.moves.range(len(y_int)): 211 | if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]: 212 | candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]]) 213 | prev_state = [s, s - 1] 214 | else: 215 | candidates = np.array( 216 | [ 217 | logdelta[t - 1, s], 218 | logdelta[t - 1, s - 1], 219 | logdelta[t - 1, s - 2], 220 | ] 221 | ) 222 | prev_state = [s, s - 1, s - 2] 223 | logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]] 224 | state_path[t, s] = prev_state[np.argmax(candidates)] 225 | 226 | state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16) 227 | 228 | candidates = np.array( 229 | [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]] 230 | ) 231 | prev_state = [len(y_int) - 1, len(y_int) - 2] 232 | state_seq[-1] = prev_state[np.argmax(candidates)] 233 | for t in six.moves.range(lpz.size(0) - 2, -1, -1): 234 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 235 | 236 | output_state_seq = [] 237 | for t in six.moves.range(0, lpz.size(0)): 238 | 
output_state_seq.append(y_int[state_seq[t, 0]]) 239 | 240 | return output_state_seq 241 | 242 | 243 | def ctc_for(args, odim, reduce=True): 244 | """Returns the CTC module for the given args and output dimension 245 | 246 | :param Namespace args: the program args 247 | :param int odim : The output dimension 248 | :param bool reduce : return the CTC loss in a scalar 249 | :return: the corresponding CTC module 250 | """ 251 | num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility 252 | if num_encs == 1: 253 | # compatible with single encoder asr mode 254 | return CTC( 255 | odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce 256 | ) 257 | elif num_encs >= 1: 258 | ctcs_list = torch.nn.ModuleList() 259 | if args.share_ctc: 260 | # use dropout_rate of the first encoder 261 | ctc = CTC( 262 | odim, 263 | args.eprojs, 264 | args.dropout_rate[0], 265 | ctc_type=args.ctc_type, 266 | reduce=reduce, 267 | ) 268 | ctcs_list.append(ctc) 269 | else: 270 | for idx in range(num_encs): 271 | ctc = CTC( 272 | odim, 273 | args.eprojs, 274 | args.dropout_rate[idx], 275 | ctc_type=args.ctc_type, 276 | reduce=reduce, 277 | ) 278 | ctcs_list.append(ctc) 279 | return ctcs_list 280 | else: 281 | raise ValueError( 282 | "Number of encoders needs to be more than one. {}".format(num_encs) 283 | ) 284 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/e2e_asr_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Shigeki Karita 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | """Transformer speech recognition model (pytorch).""" 5 | 6 | from argparse import Namespace 7 | from distutils.util import strtobool 8 | import logging 9 | import math 10 | 11 | import numpy 12 | import torch 13 | 14 | from espnet.nets.ctc_prefix_score import CTCPrefixScore 15 | from espnet.nets.e2e_asr_common import end_detect 16 | from espnet.nets.e2e_asr_common import ErrorCalculator 17 | from espnet.nets.pytorch_backend.ctc import CTC 18 | from espnet.nets.pytorch_backend.nets_utils import get_subsample 19 | from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask 20 | from espnet.nets.pytorch_backend.nets_utils import th_accuracy 21 | from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos 22 | from espnet.nets.pytorch_backend.transformer.attention import ( 23 | MultiHeadedAttention, # noqa: H301 24 | RelPositionMultiHeadedAttention, # noqa: H301 25 | ) 26 | from espnet.nets.pytorch_backend.transformer.decoder import Decoder 27 | from espnet.nets.pytorch_backend.transformer.encoder import Encoder 28 | from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( 29 | LabelSmoothingLoss, # noqa: H301 30 | ) 31 | from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask 32 | from espnet.nets.pytorch_backend.transformer.mask import target_mask 33 | from espnet.nets.scorers.ctc import CTCPrefixScorer 34 | 35 | 36 | class E2E(torch.nn.Module): 37 | """E2E module. 
38 | 39 | :param int idim: dimension of inputs 40 | :param int odim: dimension of outputs 41 | :param Namespace args: argument Namespace containing options 42 | 43 | """ 44 | 45 | @staticmethod 46 | def add_arguments(parser): 47 | """Add arguments.""" 48 | group = parser.add_argument_group("transformer model setting") 49 | 50 | group.add_argument( 51 | "--transformer-init", 52 | type=str, 53 | default="pytorch", 54 | choices=[ 55 | "pytorch", 56 | "xavier_uniform", 57 | "xavier_normal", 58 | "kaiming_uniform", 59 | "kaiming_normal", 60 | ], 61 | help="how to initialize transformer parameters", 62 | ) 63 | group.add_argument( 64 | "--transformer-input-layer", 65 | type=str, 66 | default="conv2d", 67 | choices=["conv3d", "conv2d", "conv1d", "linear", "embed"], 68 | help="transformer input layer type", 69 | ) 70 | group.add_argument( 71 | "--transformer-encoder-attn-layer-type", 72 | type=str, 73 | default="mha", 74 | choices=["mha", "rel_mha", "legacy_rel_mha"], 75 | help="transformer encoder attention layer type", 76 | ) 77 | group.add_argument( 78 | "--transformer-attn-dropout-rate", 79 | default=None, 80 | type=float, 81 | help="dropout in transformer attention. use --dropout-rate if None is set", 82 | ) 83 | group.add_argument( 84 | "--transformer-lr", 85 | default=10.0, 86 | type=float, 87 | help="Initial value of learning rate", 88 | ) 89 | group.add_argument( 90 | "--transformer-warmup-steps", 91 | default=25000, 92 | type=int, 93 | help="optimizer warmup steps", 94 | ) 95 | group.add_argument( 96 | "--transformer-length-normalized-loss", 97 | default=True, 98 | type=strtobool, 99 | help="normalize loss by length", 100 | ) 101 | group.add_argument( 102 | "--dropout-rate", 103 | default=0.0, 104 | type=float, 105 | help="Dropout rate for the encoder", 106 | ) 107 | group.add_argument( 108 | "--macaron-style", 109 | default=False, 110 | type=strtobool, 111 | help="Whether to use macaron style for positionwise layer", 112 | ) 113 | # -- input 114 | group.add_argument( 115 | "--a-upsample-ratio", 116 | default=1, 117 | type=int, 118 | help="Upsample rate for audio", 119 | ) 120 | group.add_argument( 121 | "--relu-type", 122 | default="swish", 123 | type=str, 124 | help="the type of activation layer", 125 | ) 126 | # Encoder 127 | group.add_argument( 128 | "--elayers", 129 | default=4, 130 | type=int, 131 | help="Number of encoder layers (for shared recognition part " 132 | "in multi-speaker asr mode)", 133 | ) 134 | group.add_argument( 135 | "--eunits", 136 | "-u", 137 | default=300, 138 | type=int, 139 | help="Number of encoder hidden units", 140 | ) 141 | group.add_argument( 142 | "--use-cnn-module", 143 | default=False, 144 | type=strtobool, 145 | help="Use convolution module or not", 146 | ) 147 | group.add_argument( 148 | "--cnn-module-kernel", 149 | default=31, 150 | type=int, 151 | help="Kernel size of convolution module.", 152 | ) 153 | # Attention 154 | group.add_argument( 155 | "--adim", 156 | default=320, 157 | type=int, 158 | help="Number of attention transformation dimensions", 159 | ) 160 | group.add_argument( 161 | "--aheads", 162 | default=4, 163 | type=int, 164 | help="Number of heads for multi head attention", 165 | ) 166 | group.add_argument( 167 | "--zero-triu", 168 | default=False, 169 | type=strtobool, 170 | help="If true, zero the uppper triangular part of attention matrix.", 171 | ) 172 | # Relative positional encoding 173 | group.add_argument( 174 | "--rel-pos-type", 175 | type=str, 176 | default="legacy", 177 | choices=["legacy", "latest"], 178 | help="Whether to 
use the latest relative positional encoding or the legacy one." 179 | "The legacy relative positional encoding will be deprecated in the future." 180 | "More Details can be found in https://github.com/espnet/espnet/pull/2816.", 181 | ) 182 | # Decoder 183 | group.add_argument( 184 | "--dlayers", default=1, type=int, help="Number of decoder layers" 185 | ) 186 | group.add_argument( 187 | "--dunits", default=320, type=int, help="Number of decoder hidden units" 188 | ) 189 | # -- pretrain 190 | group.add_argument("--pretrain-dataset", 191 | default="", 192 | type=str, 193 | help='pre-trained dataset for encoder' 194 | ) 195 | # -- custom name 196 | group.add_argument("--custom-pretrain-name", 197 | default="", 198 | type=str, 199 | help='pre-trained model for encoder' 200 | ) 201 | return parser 202 | 203 | @property 204 | def attention_plot_class(self): 205 | """Return PlotAttentionReport.""" 206 | return PlotAttentionReport 207 | 208 | def __init__(self, odim, args, ignore_id=-1): 209 | """Construct an E2E object. 210 | :param int odim: dimension of outputs 211 | :param Namespace args: argument Namespace containing options 212 | """ 213 | torch.nn.Module.__init__(self) 214 | if args.transformer_attn_dropout_rate is None: 215 | args.transformer_attn_dropout_rate = args.dropout_rate 216 | # Check the relative positional encoding type 217 | self.rel_pos_type = getattr(args, "rel_pos_type", None) 218 | if self.rel_pos_type is None and args.transformer_encoder_attn_layer_type == "rel_mha": 219 | args.transformer_encoder_attn_layer_type = "legacy_rel_mha" 220 | logging.warning( 221 | "Using legacy_rel_pos and it will be deprecated in the future." 222 | ) 223 | 224 | idim = 80 225 | 226 | self.encoder = Encoder( 227 | idim=idim, 228 | attention_dim=args.adim, 229 | attention_heads=args.aheads, 230 | linear_units=args.eunits, 231 | num_blocks=args.elayers, 232 | input_layer=args.transformer_input_layer, 233 | dropout_rate=args.dropout_rate, 234 | positional_dropout_rate=args.dropout_rate, 235 | attention_dropout_rate=args.transformer_attn_dropout_rate, 236 | encoder_attn_layer_type=args.transformer_encoder_attn_layer_type, 237 | macaron_style=args.macaron_style, 238 | use_cnn_module=args.use_cnn_module, 239 | cnn_module_kernel=args.cnn_module_kernel, 240 | zero_triu=getattr(args, "zero_triu", False), 241 | a_upsample_ratio=args.a_upsample_ratio, 242 | relu_type=getattr(args, "relu_type", "swish"), 243 | ) 244 | 245 | self.transformer_input_layer = args.transformer_input_layer 246 | self.a_upsample_ratio = args.a_upsample_ratio 247 | 248 | if args.mtlalpha < 1: 249 | self.decoder = Decoder( 250 | odim=odim, 251 | attention_dim=args.adim, 252 | attention_heads=args.aheads, 253 | linear_units=args.dunits, 254 | num_blocks=args.dlayers, 255 | dropout_rate=args.dropout_rate, 256 | positional_dropout_rate=args.dropout_rate, 257 | self_attention_dropout_rate=args.transformer_attn_dropout_rate, 258 | src_attention_dropout_rate=args.transformer_attn_dropout_rate, 259 | ) 260 | else: 261 | self.decoder = None 262 | self.blank = 0 263 | self.sos = odim - 1 264 | self.eos = odim - 1 265 | self.odim = odim 266 | self.ignore_id = ignore_id 267 | self.subsample = get_subsample(args, mode="asr", arch="transformer") 268 | 269 | # self.lsm_weight = a 270 | self.criterion = LabelSmoothingLoss( 271 | self.odim, 272 | self.ignore_id, 273 | args.lsm_weight, 274 | args.transformer_length_normalized_loss, 275 | ) 276 | 277 | self.adim = args.adim 278 | self.mtlalpha = args.mtlalpha 279 | if args.mtlalpha > 0.0: 280 | 
self.ctc = CTC( 281 | odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True 282 | ) 283 | else: 284 | self.ctc = None 285 | 286 | if args.report_cer or args.report_wer: 287 | self.error_calculator = ErrorCalculator( 288 | args.char_list, 289 | args.sym_space, 290 | args.sym_blank, 291 | args.report_cer, 292 | args.report_wer, 293 | ) 294 | else: 295 | self.error_calculator = None 296 | self.rnnlm = None 297 | 298 | def scorers(self): 299 | """Scorers.""" 300 | return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) 301 | 302 | def encode(self, x, extract_resnet_feats=False): 303 | """Encode acoustic features. 304 | 305 | :param ndarray x: source acoustic feature (T, D) 306 | :return: encoder outputs 307 | :rtype: torch.Tensor 308 | """ 309 | self.eval() 310 | x = torch.as_tensor(x).unsqueeze(0) 311 | if extract_resnet_feats: 312 | resnet_feats = self.encoder( 313 | x, 314 | None, 315 | extract_resnet_feats=extract_resnet_feats, 316 | ) 317 | return resnet_feats.squeeze(0) 318 | else: 319 | enc_output, _ = self.encoder(x, None) 320 | return enc_output.squeeze(0) 321 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/lm/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/lm/seq_rnn.py: -------------------------------------------------------------------------------- 1 | """Sequential implementation of Recurrent Neural Network Language Model.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from espnet.nets.lm_interface import LMInterface 8 | 9 | 10 | class SequentialRNNLM(LMInterface, torch.nn.Module): 11 | """Sequential RNNLM. 12 | 13 | See also: 14 | https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py 15 | 16 | """ 17 | 18 | @staticmethod 19 | def add_arguments(parser): 20 | """Add arguments to command line argument parser.""" 21 | parser.add_argument( 22 | "--type", 23 | type=str, 24 | default="lstm", 25 | nargs="?", 26 | choices=["lstm", "gru"], 27 | help="Which type of RNN to use", 28 | ) 29 | parser.add_argument( 30 | "--layer", "-l", type=int, default=2, help="Number of hidden layers" 31 | ) 32 | parser.add_argument( 33 | "--unit", "-u", type=int, default=650, help="Number of hidden units" 34 | ) 35 | parser.add_argument( 36 | "--dropout-rate", type=float, default=0.5, help="dropout probability" 37 | ) 38 | return parser 39 | 40 | def __init__(self, n_vocab, args): 41 | """Initialize class. 42 | 43 | Args: 44 | n_vocab (int): The size of the vocabulary 45 | args (argparse.Namespace): configurations. 
see py:method:`add_arguments` 46 | 47 | """ 48 | torch.nn.Module.__init__(self) 49 | self._setup( 50 | rnn_type=args.type.upper(), 51 | ntoken=n_vocab, 52 | ninp=args.unit, 53 | nhid=args.unit, 54 | nlayers=args.layer, 55 | dropout=args.dropout_rate, 56 | ) 57 | 58 | def _setup( 59 | self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False 60 | ): 61 | self.drop = nn.Dropout(dropout) 62 | self.encoder = nn.Embedding(ntoken, ninp) 63 | if rnn_type in ["LSTM", "GRU"]: 64 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 65 | else: 66 | try: 67 | nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] 68 | except KeyError: 69 | raise ValueError( 70 | "An invalid option for `--model` was supplied, " 71 | "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']" 72 | ) 73 | self.rnn = nn.RNN( 74 | ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout 75 | ) 76 | self.decoder = nn.Linear(nhid, ntoken) 77 | 78 | # Optionally tie weights as in: 79 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 80 | # https://arxiv.org/abs/1608.05859 81 | # and 82 | # "Tying Word Vectors and Word Classifiers: 83 | # A Loss Framework for Language Modeling" (Inan et al. 2016) 84 | # https://arxiv.org/abs/1611.01462 85 | if tie_weights: 86 | if nhid != ninp: 87 | raise ValueError( 88 | "When using the tied flag, nhid must be equal to emsize" 89 | ) 90 | self.decoder.weight = self.encoder.weight 91 | 92 | self._init_weights() 93 | 94 | self.rnn_type = rnn_type 95 | self.nhid = nhid 96 | self.nlayers = nlayers 97 | 98 | def _init_weights(self): 99 | # NOTE: original init in pytorch/examples 100 | # initrange = 0.1 101 | # self.encoder.weight.data.uniform_(-initrange, initrange) 102 | # self.decoder.bias.data.zero_() 103 | # self.decoder.weight.data.uniform_(-initrange, initrange) 104 | # NOTE: our default.py:RNNLM init 105 | for param in self.parameters(): 106 | param.data.uniform_(-0.1, 0.1) 107 | 108 | def forward(self, x, t): 109 | """Compute LM loss value from buffer sequences. 110 | 111 | Args: 112 | x (torch.Tensor): Input ids. (batch, len) 113 | t (torch.Tensor): Target ids. (batch, len) 114 | 115 | Returns: 116 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 117 | loss to backward (scalar), 118 | negative log-likelihood of t: -log p(t) (scalar) and 119 | the number of elements in x (scalar) 120 | 121 | Notes: 122 | The last two return values are used 123 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 124 | 125 | """ 126 | y = self._before_loss(x, None)[0] 127 | mask = (x != 0).to(y.dtype) 128 | loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") 129 | logp = loss * mask.view(-1) 130 | logp = logp.sum() 131 | count = mask.sum() 132 | return logp / count, logp, count 133 | 134 | def _before_loss(self, input, hidden): 135 | emb = self.drop(self.encoder(input)) 136 | output, hidden = self.rnn(emb, hidden) 137 | output = self.drop(output) 138 | decoded = self.decoder( 139 | output.view(output.size(0) * output.size(1), output.size(2)) 140 | ) 141 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 142 | 143 | def init_state(self, x): 144 | """Get an initial state for decoding. 
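        For the LSTM variant this is a tuple ``(h0, c0)`` of zero tensors,
        each of shape ``(nlayers, 1, nhid)``; for the GRU/vanilla RNN variants
        it is a single zero tensor of that shape (see the implementation
        below).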
145 | 146 | Args: 147 | x (torch.Tensor): The encoded feature tensor 148 | 149 | Returns: initial state 150 | 151 | """ 152 | bsz = 1 153 | weight = next(self.parameters()) 154 | if self.rnn_type == "LSTM": 155 | return ( 156 | weight.new_zeros(self.nlayers, bsz, self.nhid), 157 | weight.new_zeros(self.nlayers, bsz, self.nhid), 158 | ) 159 | else: 160 | return weight.new_zeros(self.nlayers, bsz, self.nhid) 161 | 162 | def score(self, y, state, x): 163 | """Score new token. 164 | 165 | Args: 166 | y (torch.Tensor): 1D torch.int64 prefix tokens. 167 | state: Scorer state for prefix tokens 168 | x (torch.Tensor): 2D encoder feature that generates ys. 169 | 170 | Returns: 171 | tuple[torch.Tensor, Any]: Tuple of 172 | torch.float32 scores for next token (n_vocab) 173 | and next state for ys 174 | 175 | """ 176 | y, new_state = self._before_loss(y[-1].view(1, 1), state) 177 | logp = y.log_softmax(dim=-1).view(-1) 178 | return logp, new_state 179 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/lm/transformer.py: -------------------------------------------------------------------------------- 1 | """Transformer language model.""" 2 | 3 | from typing import Any 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import logging 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from espnet.nets.lm_interface import LMInterface 13 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 14 | from espnet.nets.pytorch_backend.transformer.encoder import Encoder 15 | from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask 16 | from espnet.nets.scorer_interface import BatchScorerInterface 17 | from espnet.utils.cli_utils import strtobool 18 | 19 | 20 | class TransformerLM(nn.Module, LMInterface, BatchScorerInterface): 21 | """Transformer language model.""" 22 | 23 | @staticmethod 24 | def add_arguments(parser): 25 | """Add arguments to command line argument parser.""" 26 | parser.add_argument( 27 | "--layer", type=int, default=4, help="Number of hidden layers" 28 | ) 29 | parser.add_argument( 30 | "--unit", 31 | type=int, 32 | default=1024, 33 | help="Number of hidden units in feedforward layer", 34 | ) 35 | parser.add_argument( 36 | "--att-unit", 37 | type=int, 38 | default=256, 39 | help="Number of hidden units in attention layer", 40 | ) 41 | parser.add_argument( 42 | "--embed-unit", 43 | type=int, 44 | default=128, 45 | help="Number of hidden units in embedding layer", 46 | ) 47 | parser.add_argument( 48 | "--head", type=int, default=2, help="Number of multi head attention" 49 | ) 50 | parser.add_argument( 51 | "--dropout-rate", type=float, default=0.5, help="dropout probability" 52 | ) 53 | parser.add_argument( 54 | "--att-dropout-rate", 55 | type=float, 56 | default=0.0, 57 | help="att dropout probability", 58 | ) 59 | parser.add_argument( 60 | "--emb-dropout-rate", 61 | type=float, 62 | default=0.0, 63 | help="emb dropout probability", 64 | ) 65 | parser.add_argument( 66 | "--tie-weights", 67 | type=strtobool, 68 | default=False, 69 | help="Tie input and output embeddings", 70 | ) 71 | parser.add_argument( 72 | "--pos-enc", 73 | default="sinusoidal", 74 | choices=["sinusoidal", "none"], 75 | help="positional encoding", 76 | ) 77 | return parser 78 | 79 | def __init__(self, n_vocab, args): 80 | """Initialize class. 81 | 82 | Args: 83 | n_vocab (int): The size of the vocabulary 84 | args (argparse.Namespace): configurations. 
see py:method:`add_arguments` 85 | 86 | """ 87 | nn.Module.__init__(self) 88 | 89 | # NOTE: for a compatibility with less than 0.9.7 version models 90 | emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0) 91 | # NOTE: for a compatibility with less than 0.9.7 version models 92 | tie_weights = getattr(args, "tie_weights", False) 93 | # NOTE: for a compatibility with less than 0.9.7 version models 94 | att_dropout_rate = getattr(args, "att_dropout_rate", 0.0) 95 | 96 | if args.pos_enc == "sinusoidal": 97 | pos_enc_class = PositionalEncoding 98 | elif args.pos_enc == "none": 99 | 100 | def pos_enc_class(*args, **kwargs): 101 | return nn.Sequential() # indentity 102 | 103 | else: 104 | raise ValueError(f"unknown pos-enc option: {args.pos_enc}") 105 | 106 | self.embed = nn.Embedding(n_vocab, args.embed_unit) 107 | 108 | if emb_dropout_rate == 0.0: 109 | self.embed_drop = None 110 | else: 111 | self.embed_drop = nn.Dropout(emb_dropout_rate) 112 | 113 | self.encoder = Encoder( 114 | idim=args.embed_unit, 115 | attention_dim=args.att_unit, 116 | attention_heads=args.head, 117 | linear_units=args.unit, 118 | num_blocks=args.layer, 119 | dropout_rate=args.dropout_rate, 120 | attention_dropout_rate=att_dropout_rate, 121 | input_layer="linear", 122 | pos_enc_class=pos_enc_class, 123 | ) 124 | self.decoder = nn.Linear(args.att_unit, n_vocab) 125 | 126 | logging.info("Tie weights set to {}".format(tie_weights)) 127 | logging.info("Dropout set to {}".format(args.dropout_rate)) 128 | logging.info("Emb Dropout set to {}".format(emb_dropout_rate)) 129 | logging.info("Att Dropout set to {}".format(att_dropout_rate)) 130 | 131 | if tie_weights: 132 | assert ( 133 | args.att_unit == args.embed_unit 134 | ), "Tie Weights: True need embedding and final dimensions to match" 135 | self.decoder.weight = self.embed.weight 136 | 137 | def _target_mask(self, ys_in_pad): 138 | ys_mask = ys_in_pad != 0 139 | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) 140 | return ys_mask.unsqueeze(-2) & m 141 | 142 | def forward( 143 | self, x: torch.Tensor, t: torch.Tensor 144 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 145 | """Compute LM loss value from buffer sequences. 146 | 147 | Args: 148 | x (torch.Tensor): Input ids. (batch, len) 149 | t (torch.Tensor): Target ids. (batch, len) 150 | 151 | Returns: 152 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 153 | loss to backward (scalar), 154 | negative log-likelihood of t: -log p(t) (scalar) and 155 | the number of elements in x (scalar) 156 | 157 | Notes: 158 | The last two return values are used 159 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 160 | 161 | """ 162 | xm = x != 0 163 | 164 | if self.embed_drop is not None: 165 | emb = self.embed_drop(self.embed(x)) 166 | else: 167 | emb = self.embed(x) 168 | 169 | h, _ = self.encoder(emb, self._target_mask(x)) 170 | y = self.decoder(h) 171 | loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") 172 | mask = xm.to(dtype=loss.dtype) 173 | logp = loss * mask.view(-1) 174 | logp = logp.sum() 175 | count = mask.sum() 176 | return logp / count, logp, count 177 | 178 | def score( 179 | self, y: torch.Tensor, state: Any, x: torch.Tensor 180 | ) -> Tuple[torch.Tensor, Any]: 181 | """Score new token. 182 | 183 | Args: 184 | y (torch.Tensor): 1D torch.int64 prefix tokens. 185 | state: Scorer state for prefix tokens 186 | x (torch.Tensor): encoder feature that generates ys. 
187 | 188 | Returns: 189 | tuple[torch.Tensor, Any]: Tuple of 190 | torch.float32 scores for next token (n_vocab) 191 | and next state for ys 192 | 193 | """ 194 | y = y.unsqueeze(0) 195 | 196 | if self.embed_drop is not None: 197 | emb = self.embed_drop(self.embed(y)) 198 | else: 199 | emb = self.embed(y) 200 | 201 | h, _, cache = self.encoder.forward_one_step( 202 | emb, self._target_mask(y), cache=state 203 | ) 204 | h = self.decoder(h[:, -1]) 205 | logp = h.log_softmax(dim=-1).squeeze(0) 206 | return logp, cache 207 | 208 | # batch beam search API (see BatchScorerInterface) 209 | def batch_score( 210 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 211 | ) -> Tuple[torch.Tensor, List[Any]]: 212 | """Score new token batch (required). 213 | 214 | Args: 215 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 216 | states (List[Any]): Scorer states for prefix tokens. 217 | xs (torch.Tensor): 218 | The encoder feature that generates ys (n_batch, xlen, n_feat). 219 | 220 | Returns: 221 | tuple[torch.Tensor, List[Any]]: Tuple of 222 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 223 | and next state list for ys. 224 | 225 | """ 226 | # merge states 227 | n_batch = len(ys) 228 | n_layers = len(self.encoder.encoders) 229 | if states[0] is None: 230 | batch_state = None 231 | else: 232 | # transpose state of [batch, layer] into [layer, batch] 233 | batch_state = [ 234 | torch.stack([states[b][i] for b in range(n_batch)]) 235 | for i in range(n_layers) 236 | ] 237 | 238 | if self.embed_drop is not None: 239 | emb = self.embed_drop(self.embed(ys)) 240 | else: 241 | emb = self.embed(ys) 242 | 243 | # batch decoding 244 | h, _, states = self.encoder.forward_one_step( 245 | emb, self._target_mask(ys), cache=batch_state 246 | ) 247 | h = self.decoder(h[:, -1]) 248 | logp = h.log_softmax(dim=-1) 249 | 250 | # transpose state of [layer, batch] into [batch, layer] 251 | state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] 252 | return logp, state_list 253 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/add_sos_eos.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Unility funcitons for Transformer.""" 8 | 9 | import torch 10 | 11 | 12 | def add_sos_eos(ys_pad, sos, eos, ignore_id): 13 | """Add and labels. 
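    For example, with illustrative token ids (``-1`` as the padding index and
    ``7`` standing in for both <sos> and <eos>):

    >>> ys_pad = torch.tensor([[3, 5, -1]])
    >>> ys_in, ys_out = add_sos_eos(ys_pad, 7, 7, -1)
    >>> ys_in    # <sos> prepended to the unpadded targets
    tensor([[7, 3, 5]])
    >>> ys_out   # <eos> appended
    tensor([[3, 5, 7]])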
14 | 15 | :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) 16 | :param int sos: index of 17 | :param int eos: index of 18 | :param int ignore_id: index of padding 19 | :return: padded tensor (B, Lmax) 20 | :rtype: torch.Tensor 21 | :return: padded tensor (B, Lmax) 22 | :rtype: torch.Tensor 23 | """ 24 | from espnet.nets.pytorch_backend.nets_utils import pad_list 25 | 26 | _sos = ys_pad.new([sos]) 27 | _eos = ys_pad.new([eos]) 28 | ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys 29 | ys_in = [torch.cat([_sos, y], dim=0) for y in ys] 30 | ys_out = [torch.cat([y, _eos], dim=0) for y in ys] 31 | return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) 32 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Multi-Head Attention layer definition.""" 8 | 9 | import math 10 | 11 | import numpy 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class MultiHeadedAttention(nn.Module): 17 | """Multi-Head Attention layer. 18 | Args: 19 | n_head (int): The number of heads. 20 | n_feat (int): The number of features. 21 | dropout_rate (float): Dropout rate. 22 | """ 23 | 24 | def __init__(self, n_head, n_feat, dropout_rate): 25 | """Construct an MultiHeadedAttention object.""" 26 | super(MultiHeadedAttention, self).__init__() 27 | assert n_feat % n_head == 0 28 | # We assume d_v always equals d_k 29 | self.d_k = n_feat // n_head 30 | self.h = n_head 31 | self.linear_q = nn.Linear(n_feat, n_feat) 32 | self.linear_k = nn.Linear(n_feat, n_feat) 33 | self.linear_v = nn.Linear(n_feat, n_feat) 34 | self.linear_out = nn.Linear(n_feat, n_feat) 35 | self.attn = None 36 | self.dropout = nn.Dropout(p=dropout_rate) 37 | 38 | def forward_qkv(self, query, key, value): 39 | """Transform query, key and value. 40 | Args: 41 | query (torch.Tensor): Query tensor (#batch, time1, size). 42 | key (torch.Tensor): Key tensor (#batch, time2, size). 43 | value (torch.Tensor): Value tensor (#batch, time2, size). 44 | Returns: 45 | torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). 46 | torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). 47 | torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). 48 | """ 49 | n_batch = query.size(0) 50 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 51 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 52 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 53 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 54 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 55 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 56 | 57 | return q, k, v 58 | 59 | def forward_attention(self, value, scores, mask, rtn_attn=False): 60 | """Compute attention context vector. 61 | Args: 62 | value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). 63 | scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). 64 | mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). 65 | rtn_attn (boolean): Flag of return attention score 66 | Returns: 67 | torch.Tensor: Transformed value (#batch, time1, d_model) 68 | weighted by the attention score (#batch, time1, time2). 
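        Example (illustrative shapes, exercising this method via ``forward``):
            >>> mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
            >>> q = k = v = torch.randn(2, 10, 256)
            >>> mask = torch.ones(2, 1, 10, dtype=torch.bool)
            >>> mha(q, k, v, mask).shape
            torch.Size([2, 10, 256])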
69 | """ 70 | n_batch = value.size(0) 71 | if mask is not None: 72 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 73 | min_value = float( 74 | numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min 75 | ) 76 | scores = scores.masked_fill(mask, min_value) 77 | self.attn = torch.softmax(scores, dim=-1).masked_fill( 78 | mask, 0.0 79 | ) # (batch, head, time1, time2) 80 | else: 81 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 82 | 83 | p_attn = self.dropout(self.attn) 84 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 85 | x = ( 86 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 87 | ) # (batch, time1, d_model) 88 | if rtn_attn: 89 | return self.linear_out(x), self.attn 90 | return self.linear_out(x) # (batch, time1, d_model) 91 | 92 | def forward(self, query, key, value, mask, rtn_attn=False): 93 | """Compute scaled dot product attention. 94 | Args: 95 | query (torch.Tensor): Query tensor (#batch, time1, size). 96 | key (torch.Tensor): Key tensor (#batch, time2, size). 97 | value (torch.Tensor): Value tensor (#batch, time2, size). 98 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 99 | (#batch, time1, time2). 100 | rtn_attn (boolean): Flag of return attention score 101 | Returns: 102 | torch.Tensor: Output tensor (#batch, time1, d_model). 103 | """ 104 | q, k, v = self.forward_qkv(query, key, value) 105 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 106 | return self.forward_attention(v, scores, mask, rtn_attn) 107 | 108 | 109 | class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): 110 | """Multi-Head Attention layer with relative position encoding (old version). 111 | Details can be found in https://github.com/espnet/espnet/pull/2816. 112 | Paper: https://arxiv.org/abs/1901.02860 113 | Args: 114 | n_head (int): The number of heads. 115 | n_feat (int): The number of features. 116 | dropout_rate (float): Dropout rate. 117 | zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 118 | """ 119 | 120 | def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): 121 | """Construct an RelPositionMultiHeadedAttention object.""" 122 | super().__init__(n_head, n_feat, dropout_rate) 123 | self.zero_triu = zero_triu 124 | # linear transformation for positional encoding 125 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 126 | # these two learnable bias are used in matrix c and matrix d 127 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 128 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 129 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 130 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 131 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 132 | 133 | def rel_shift(self, x): 134 | """Compute relative positional encoding. 135 | Args: 136 | x (torch.Tensor): Input tensor (batch, head, time1, time2). 137 | Returns: 138 | torch.Tensor: Output tensor. 
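        A toy example of the shift (illustrative numbers): for a single head
        with scores ``[[1, 2, 3], [4, 5, 6], [7, 8, 9]]``, padding a zero
        column and re-viewing the tensor yields
        ``[[3, 0, 4], [5, 6, 0], [7, 8, 9]]``, i.e. the Transformer-XL trick
        of realigning the relative-position scores along the diagonal.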
139 | """ 140 | zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) 141 | x_padded = torch.cat([zero_pad, x], dim=-1) 142 | 143 | x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) 144 | x = x_padded[:, :, 1:].view_as(x) 145 | 146 | if self.zero_triu: 147 | ones = torch.ones((x.size(2), x.size(3))) 148 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 149 | 150 | return x 151 | 152 | def forward(self, query, key, value, pos_emb, mask): 153 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 154 | Args: 155 | query (torch.Tensor): Query tensor (#batch, time1, size). 156 | key (torch.Tensor): Key tensor (#batch, time2, size). 157 | value (torch.Tensor): Value tensor (#batch, time2, size). 158 | pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). 159 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 160 | (#batch, time1, time2). 161 | Returns: 162 | torch.Tensor: Output tensor (#batch, time1, d_model). 163 | """ 164 | q, k, v = self.forward_qkv(query, key, value) 165 | q = q.transpose(1, 2) # (batch, time1, head, d_k) 166 | 167 | n_batch_pos = pos_emb.size(0) 168 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 169 | p = p.transpose(1, 2) # (batch, head, time1, d_k) 170 | 171 | # (batch, head, time1, d_k) 172 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 173 | # (batch, head, time1, d_k) 174 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 175 | 176 | # compute attention score 177 | # first compute matrix a and matrix c 178 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 179 | # (batch, head, time1, time2) 180 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 181 | 182 | # compute matrix b and matrix d 183 | # (batch, head, time1, time1) 184 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 185 | matrix_bd = self.rel_shift(matrix_bd) 186 | 187 | scores = (matrix_ac + matrix_bd) / math.sqrt( 188 | self.d_k 189 | ) # (batch, head, time1, time2) 190 | 191 | return self.forward_attention(v, scores, mask) 192 | 193 | 194 | class RelPositionMultiHeadedAttention(MultiHeadedAttention): 195 | """Multi-Head Attention layer with relative position encoding (new implementation). 196 | Details can be found in https://github.com/espnet/espnet/pull/2816. 197 | Paper: https://arxiv.org/abs/1901.02860 198 | Args: 199 | n_head (int): The number of heads. 200 | n_feat (int): The number of features. 201 | dropout_rate (float): Dropout rate. 202 | zero_triu (bool): Whether to zero the upper triangular part of attention matrix. 203 | """ 204 | 205 | def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): 206 | """Construct an RelPositionMultiHeadedAttention object.""" 207 | super().__init__(n_head, n_feat, dropout_rate) 208 | self.zero_triu = zero_triu 209 | # linear transformation for positional encoding 210 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 211 | # these two learnable bias are used in matrix c and matrix d 212 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 213 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 214 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 215 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 216 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 217 | 218 | def rel_shift(self, x): 219 | """Compute relative positional encoding. 220 | Args: 221 | x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). 
222 | time1 means the length of query vector. 223 | Returns: 224 | torch.Tensor: Output tensor. 225 | """ 226 | zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) 227 | x_padded = torch.cat([zero_pad, x], dim=-1) 228 | 229 | x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) 230 | x = x_padded[:, :, 1:].view_as(x)[ 231 | :, :, :, : x.size(-1) // 2 + 1 232 | ] # only keep the positions from 0 to time2 233 | 234 | if self.zero_triu: 235 | ones = torch.ones((x.size(2), x.size(3)), device=x.device) 236 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 237 | 238 | return x 239 | 240 | def forward(self, query, key, value, pos_emb, mask): 241 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 242 | Args: 243 | query (torch.Tensor): Query tensor (#batch, time1, size). 244 | key (torch.Tensor): Key tensor (#batch, time2, size). 245 | value (torch.Tensor): Value tensor (#batch, time2, size). 246 | pos_emb (torch.Tensor): Positional embedding tensor 247 | (#batch, 2*time1-1, size). 248 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 249 | (#batch, time1, time2). 250 | Returns: 251 | torch.Tensor: Output tensor (#batch, time1, d_model). 252 | """ 253 | q, k, v = self.forward_qkv(query, key, value) 254 | q = q.transpose(1, 2) # (batch, time1, head, d_k) 255 | 256 | n_batch_pos = pos_emb.size(0) 257 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 258 | p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) 259 | 260 | # (batch, head, time1, d_k) 261 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 262 | # (batch, head, time1, d_k) 263 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 264 | 265 | # compute attention score 266 | # first compute matrix a and matrix c 267 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 268 | # (batch, head, time1, time2) 269 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 270 | 271 | # compute matrix b and matrix d 272 | # (batch, head, time1, 2*time1-1) 273 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 274 | matrix_bd = self.rel_shift(matrix_bd) 275 | 276 | scores = (matrix_ac + matrix_bd) / math.sqrt( 277 | self.d_k 278 | ) # (batch, head, time1, time2) 279 | 280 | return self.forward_attention(v, scores, mask) 281 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | 8 | """ConvolutionModule definition.""" 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class ConvolutionModule(nn.Module): 15 | """ConvolutionModule in Conformer model. 
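    The module is shape preserving; for instance (illustrative sizes):

        >>> conv = ConvolutionModule(channels=256, kernel_size=31)
        >>> x = torch.randn(2, 50, 256)    # (batch, time, channels)
        >>> conv(x).shape
        torch.Size([2, 50, 256])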
16 | 17 | :param int channels: channels of cnn 18 | :param int kernel_size: kernerl size of cnn 19 | 20 | """ 21 | 22 | def __init__(self, channels, kernel_size, bias=True): 23 | """Construct an ConvolutionModule object.""" 24 | super(ConvolutionModule, self).__init__() 25 | # kernerl_size should be a odd number for 'SAME' padding 26 | assert (kernel_size - 1) % 2 == 0 27 | 28 | self.pointwise_cov1 = nn.Conv1d( 29 | channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, 30 | ) 31 | self.depthwise_conv = nn.Conv1d( 32 | channels, 33 | channels, 34 | kernel_size, 35 | stride=1, 36 | padding=(kernel_size - 1) // 2, 37 | groups=channels, 38 | bias=bias, 39 | ) 40 | self.norm = nn.BatchNorm1d(channels) 41 | self.pointwise_cov2 = nn.Conv1d( 42 | channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, 43 | ) 44 | self.activation = Swish() 45 | 46 | def forward(self, x): 47 | """Compute covolution module. 48 | 49 | :param torch.Tensor x: (batch, time, size) 50 | :return torch.Tensor: convoluted `value` (batch, time, d_model) 51 | """ 52 | # exchange the temporal dimension and the feature dimension 53 | x = x.transpose(1, 2) 54 | 55 | # GLU mechanism 56 | x = self.pointwise_cov1(x) # (batch, 2*channel, dim) 57 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 58 | 59 | # 1D Depthwise Conv 60 | x = self.depthwise_conv(x) 61 | x = self.activation(self.norm(x)) 62 | 63 | x = self.pointwise_cov2(x) 64 | 65 | return x.transpose(1, 2) 66 | 67 | 68 | class Swish(nn.Module): 69 | """Construct an Swish object.""" 70 | 71 | def forward(self, x): 72 | """Return Swich activation function.""" 73 | return x * torch.sigmoid(x) 74 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Decoder definition.""" 8 | 9 | from typing import Any 10 | from typing import List 11 | from typing import Tuple 12 | 13 | import torch 14 | 15 | from espnet.nets.pytorch_backend.nets_utils import rename_state_dict 16 | from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention 17 | from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer 18 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 19 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 20 | from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask 21 | from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( 22 | PositionwiseFeedForward, # noqa: H301 23 | ) 24 | from espnet.nets.pytorch_backend.transformer.repeat import repeat 25 | from espnet.nets.scorer_interface import BatchScorerInterface 26 | 27 | 28 | def _pre_hook( 29 | state_dict, 30 | prefix, 31 | local_metadata, 32 | strict, 33 | missing_keys, 34 | unexpected_keys, 35 | error_msgs, 36 | ): 37 | # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563 38 | rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict) 39 | 40 | 41 | class Decoder(BatchScorerInterface, torch.nn.Module): 42 | """Transfomer decoder module. 
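    Example (illustrative sizes, not taken from a real configuration):

        >>> decoder = Decoder(odim=1000, attention_dim=256, attention_heads=4)
        >>> tgt = torch.randint(0, 1000, (2, 12))        # (batch, maxlen_out)
        >>> tgt_mask = subsequent_mask(12).unsqueeze(0)   # (1, maxlen_out, maxlen_out)
        >>> memory = torch.randn(2, 40, 256)              # encoder outputs
        >>> out, _ = decoder(tgt, tgt_mask, memory, None)
        >>> out.shape
        torch.Size([2, 12, 1000])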
43 | 44 | :param int odim: output dim 45 | :param int attention_dim: dimention of attention 46 | :param int attention_heads: the number of heads of multi head attention 47 | :param int linear_units: the number of units of position-wise feed forward 48 | :param int num_blocks: the number of decoder blocks 49 | :param float dropout_rate: dropout rate 50 | :param float attention_dropout_rate: dropout rate for attention 51 | :param str or torch.nn.Module input_layer: input layer type 52 | :param bool use_output_layer: whether to use output layer 53 | :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding 54 | :param bool normalize_before: whether to use layer_norm before the first block 55 | :param bool concat_after: whether to concat attention layer's input and output 56 | if True, additional linear will be applied. 57 | i.e. x -> x + linear(concat(x, att(x))) 58 | if False, no additional linear will be applied. i.e. x -> x + att(x) 59 | """ 60 | 61 | def __init__( 62 | self, 63 | odim, 64 | attention_dim=256, 65 | attention_heads=4, 66 | linear_units=2048, 67 | num_blocks=6, 68 | dropout_rate=0.1, 69 | positional_dropout_rate=0.1, 70 | self_attention_dropout_rate=0.0, 71 | src_attention_dropout_rate=0.0, 72 | input_layer="embed", 73 | use_output_layer=True, 74 | pos_enc_class=PositionalEncoding, 75 | normalize_before=True, 76 | concat_after=False, 77 | ): 78 | """Construct an Decoder object.""" 79 | torch.nn.Module.__init__(self) 80 | self._register_load_state_dict_pre_hook(_pre_hook) 81 | if input_layer == "embed": 82 | self.embed = torch.nn.Sequential( 83 | torch.nn.Embedding(odim, attention_dim), 84 | pos_enc_class(attention_dim, positional_dropout_rate), 85 | ) 86 | elif input_layer == "linear": 87 | self.embed = torch.nn.Sequential( 88 | torch.nn.Linear(odim, attention_dim), 89 | torch.nn.LayerNorm(attention_dim), 90 | torch.nn.Dropout(dropout_rate), 91 | torch.nn.ReLU(), 92 | pos_enc_class(attention_dim, positional_dropout_rate), 93 | ) 94 | elif isinstance(input_layer, torch.nn.Module): 95 | self.embed = torch.nn.Sequential( 96 | input_layer, pos_enc_class(attention_dim, positional_dropout_rate) 97 | ) 98 | else: 99 | raise NotImplementedError("only `embed` or torch.nn.Module is supported.") 100 | self.normalize_before = normalize_before 101 | self.decoders = repeat( 102 | num_blocks, 103 | lambda: DecoderLayer( 104 | attention_dim, 105 | MultiHeadedAttention( 106 | attention_heads, attention_dim, self_attention_dropout_rate 107 | ), 108 | MultiHeadedAttention( 109 | attention_heads, attention_dim, src_attention_dropout_rate 110 | ), 111 | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), 112 | dropout_rate, 113 | normalize_before, 114 | concat_after, 115 | ), 116 | ) 117 | if self.normalize_before: 118 | self.after_norm = LayerNorm(attention_dim) 119 | if use_output_layer: 120 | self.output_layer = torch.nn.Linear(attention_dim, odim) 121 | else: 122 | self.output_layer = None 123 | 124 | def forward(self, tgt, tgt_mask, memory, memory_mask): 125 | """Forward decoder. 
126 | :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) 127 | if input_layer == "embed" 128 | input tensor (batch, maxlen_out, #mels) 129 | in the other cases 130 | :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out) 131 | dtype=torch.uint8 in PyTorch 1.2- 132 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 133 | :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat) 134 | :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in) 135 | dtype=torch.uint8 in PyTorch 1.2- 136 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 137 | :return x: decoded token score before softmax (batch, maxlen_out, token) 138 | if use_output_layer is True, 139 | final block outputs (batch, maxlen_out, attention_dim) 140 | in the other cases 141 | :rtype: torch.Tensor 142 | :return tgt_mask: score mask before softmax (batch, maxlen_out) 143 | :rtype: torch.Tensor 144 | """ 145 | x = self.embed(tgt) 146 | x, tgt_mask, memory, memory_mask = self.decoders( 147 | x, tgt_mask, memory, memory_mask 148 | ) 149 | if self.normalize_before: 150 | x = self.after_norm(x) 151 | if self.output_layer is not None: 152 | x = self.output_layer(x) 153 | return x, tgt_mask 154 | 155 | def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): 156 | """Forward one step. 157 | :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) 158 | :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out) 159 | dtype=torch.uint8 in PyTorch 1.2- 160 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 161 | :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat) 162 | :param List[torch.Tensor] cache: 163 | cached output list of (batch, max_time_out-1, size) 164 | :return y, cache: NN output value and cache per `self.decoders`. 165 | `y.shape` is (batch, maxlen_out, token) 166 | :rtype: Tuple[torch.Tensor, List[torch.Tensor]] 167 | """ 168 | x = self.embed(tgt) 169 | if cache is None: 170 | cache = [None] * len(self.decoders) 171 | new_cache = [] 172 | for c, decoder in zip(cache, self.decoders): 173 | x, tgt_mask, memory, memory_mask = decoder( 174 | x, tgt_mask, memory, memory_mask, cache=c 175 | ) 176 | new_cache.append(x) 177 | 178 | if self.normalize_before: 179 | y = self.after_norm(x[:, -1]) 180 | else: 181 | y = x[:, -1] 182 | if self.output_layer is not None: 183 | y = torch.log_softmax(self.output_layer(y), dim=-1) 184 | 185 | return y, new_cache 186 | 187 | # beam search API (see ScorerInterface) 188 | def score(self, ys, state, x): 189 | """Score.""" 190 | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) 191 | logp, state = self.forward_one_step( 192 | ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state 193 | ) 194 | return logp.squeeze(0), state 195 | 196 | # batch beam search API (see BatchScorerInterface) 197 | def batch_score( 198 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 199 | ) -> Tuple[torch.Tensor, List[Any]]: 200 | """Score new token batch (required). 201 | Args: 202 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 203 | states (List[Any]): Scorer states for prefix tokens. 204 | xs (torch.Tensor): 205 | The encoder feature that generates ys (n_batch, xlen, n_feat). 206 | Returns: 207 | tuple[torch.Tensor, List[Any]]: Tuple of 208 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 209 | and next state list for ys. 
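        For instance (illustrative sizes, reusing the ``decoder`` sketched in
        the class docstring example), scoring five partial hypotheses of
        three tokens each:
            >>> ys = torch.randint(0, 1000, (5, 3))
            >>> xs = torch.randn(5, 40, 256)
            >>> logp, states = decoder.batch_score(ys, [None] * 5, xs)
            >>> logp.shape
            torch.Size([5, 1000])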
210 | """ 211 | # merge states 212 | n_batch = len(ys) 213 | n_layers = len(self.decoders) 214 | if states[0] is None: 215 | batch_state = None 216 | else: 217 | # transpose state of [batch, layer] into [layer, batch] 218 | batch_state = [ 219 | torch.stack([states[b][l] for b in range(n_batch)]) 220 | for l in range(n_layers) 221 | ] 222 | 223 | # batch decoding 224 | ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) 225 | logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state) 226 | 227 | # transpose state of [layer, batch] into [batch, layer] 228 | state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)] 229 | return logp, state_list 230 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Decoder self-attention layer definition.""" 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 13 | 14 | 15 | class DecoderLayer(nn.Module): 16 | """Single decoder layer module. 17 | :param int size: input dim 18 | :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention 19 | self_attn: self attention module 20 | :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention 21 | src_attn: source attention module 22 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. 23 | PositionwiseFeedForward feed_forward: feed forward layer module 24 | :param float dropout_rate: dropout rate 25 | :param bool normalize_before: whether to use layer_norm before the first block 26 | :param bool concat_after: whether to concat attention layer's input and output 27 | if True, additional linear will be applied. 28 | i.e. x -> x + linear(concat(x, att(x))) 29 | if False, no additional linear will be applied. i.e. x -> x + att(x) 30 | """ 31 | 32 | def __init__( 33 | self, 34 | size, 35 | self_attn, 36 | src_attn, 37 | feed_forward, 38 | dropout_rate, 39 | normalize_before=True, 40 | concat_after=False, 41 | ): 42 | """Construct an DecoderLayer object.""" 43 | super(DecoderLayer, self).__init__() 44 | self.size = size 45 | self.self_attn = self_attn 46 | self.src_attn = src_attn 47 | self.feed_forward = feed_forward 48 | self.norm1 = LayerNorm(size) 49 | self.norm2 = LayerNorm(size) 50 | self.norm3 = LayerNorm(size) 51 | self.dropout = nn.Dropout(dropout_rate) 52 | self.normalize_before = normalize_before 53 | self.concat_after = concat_after 54 | if self.concat_after: 55 | self.concat_linear1 = nn.Linear(size + size, size) 56 | self.concat_linear2 = nn.Linear(size + size, size) 57 | 58 | def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): 59 | """Compute decoded features. 
60 | Args: 61 | tgt (torch.Tensor): 62 | decoded previous target features (batch, max_time_out, size) 63 | tgt_mask (torch.Tensor): mask for x (batch, max_time_out) 64 | memory (torch.Tensor): encoded source features (batch, max_time_in, size) 65 | memory_mask (torch.Tensor): mask for memory (batch, max_time_in) 66 | cache (torch.Tensor): cached output (batch, max_time_out-1, size) 67 | """ 68 | residual = tgt 69 | if self.normalize_before: 70 | tgt = self.norm1(tgt) 71 | 72 | if cache is None: 73 | tgt_q = tgt 74 | tgt_q_mask = tgt_mask 75 | else: 76 | # compute only the last frame query keeping dim: max_time_out -> 1 77 | assert cache.shape == ( 78 | tgt.shape[0], 79 | tgt.shape[1] - 1, 80 | self.size, 81 | ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 82 | tgt_q = tgt[:, -1:, :] 83 | residual = residual[:, -1:, :] 84 | tgt_q_mask = None 85 | if tgt_mask is not None: 86 | tgt_q_mask = tgt_mask[:, -1:, :] 87 | 88 | if self.concat_after: 89 | tgt_concat = torch.cat( 90 | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 91 | ) 92 | x = residual + self.concat_linear1(tgt_concat) 93 | else: 94 | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) 95 | if not self.normalize_before: 96 | x = self.norm1(x) 97 | 98 | residual = x 99 | if self.normalize_before: 100 | x = self.norm2(x) 101 | if self.concat_after: 102 | x_concat = torch.cat( 103 | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 104 | ) 105 | x = residual + self.concat_linear2(x_concat) 106 | else: 107 | x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) 108 | if not self.normalize_before: 109 | x = self.norm2(x) 110 | 111 | residual = x 112 | if self.normalize_before: 113 | x = self.norm3(x) 114 | x = residual + self.dropout(self.feed_forward(x)) 115 | if not self.normalize_before: 116 | x = self.norm3(x) 117 | 118 | if cache is not None: 119 | x = torch.cat([cache, x], dim=1) 120 | 121 | return x, tgt_mask, memory, memory_mask 122 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positional Encoding Module.""" 8 | 9 | import math 10 | 11 | import torch 12 | 13 | 14 | def _pre_hook( 15 | state_dict, 16 | prefix, 17 | local_metadata, 18 | strict, 19 | missing_keys, 20 | unexpected_keys, 21 | error_msgs, 22 | ): 23 | """Perform pre-hook in load_state_dict for backward compatibility. 24 | Note: 25 | We saved self.pe until v.0.5.2 but we have omitted it later. 26 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 27 | """ 28 | k = prefix + "pe" 29 | if k in state_dict: 30 | state_dict.pop(k) 31 | 32 | 33 | class PositionalEncoding(torch.nn.Module): 34 | """Positional encoding. 35 | Args: 36 | d_model (int): Embedding dimension. 37 | dropout_rate (float): Dropout rate. 38 | max_len (int): Maximum input length. 39 | reverse (bool): Whether to reverse the input position. Only for 40 | the class LegacyRelPositionalEncoding. We remove it in the current 41 | class RelPositionalEncoding. 
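    The table follows the usual sinusoidal definition,
    ``pe[t, 2i] = sin(t / 10000^(2i / d_model))`` and
    ``pe[t, 2i + 1] = cos(t / 10000^(2i / d_model))``, and the input is
    scaled by ``sqrt(d_model)`` before the encoding is added.
    Example (illustrative sizes):
        >>> pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
        >>> x = torch.randn(2, 50, 256)
        >>> pos_enc(x).shape
        torch.Size([2, 50, 256])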
42 | """ 43 | 44 | def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): 45 | """Construct an PositionalEncoding object.""" 46 | super(PositionalEncoding, self).__init__() 47 | self.d_model = d_model 48 | self.reverse = reverse 49 | self.xscale = math.sqrt(self.d_model) 50 | self.dropout = torch.nn.Dropout(p=dropout_rate) 51 | self.pe = None 52 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 53 | self._register_load_state_dict_pre_hook(_pre_hook) 54 | 55 | def extend_pe(self, x): 56 | """Reset the positional encodings.""" 57 | if self.pe is not None: 58 | if self.pe.size(1) >= x.size(1): 59 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 60 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 61 | return 62 | pe = torch.zeros(x.size(1), self.d_model) 63 | if self.reverse: 64 | position = torch.arange( 65 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 66 | ).unsqueeze(1) 67 | else: 68 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 69 | div_term = torch.exp( 70 | torch.arange(0, self.d_model, 2, dtype=torch.float32) 71 | * -(math.log(10000.0) / self.d_model) 72 | ) 73 | pe[:, 0::2] = torch.sin(position * div_term) 74 | pe[:, 1::2] = torch.cos(position * div_term) 75 | pe = pe.unsqueeze(0) 76 | self.pe = pe.to(device=x.device, dtype=x.dtype) 77 | 78 | def forward(self, x: torch.Tensor): 79 | """Add positional encoding. 80 | Args: 81 | x (torch.Tensor): Input tensor (batch, time, `*`). 82 | Returns: 83 | torch.Tensor: Encoded tensor (batch, time, `*`). 84 | """ 85 | self.extend_pe(x) 86 | x = x * self.xscale + self.pe[:, : x.size(1)] 87 | return self.dropout(x) 88 | 89 | 90 | class ScaledPositionalEncoding(PositionalEncoding): 91 | """Scaled positional encoding module. 92 | See Sec. 3.2 https://arxiv.org/abs/1809.08895 93 | Args: 94 | d_model (int): Embedding dimension. 95 | dropout_rate (float): Dropout rate. 96 | max_len (int): Maximum input length. 97 | """ 98 | 99 | def __init__(self, d_model, dropout_rate, max_len=5000): 100 | """Initialize class.""" 101 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) 102 | self.alpha = torch.nn.Parameter(torch.tensor(1.0)) 103 | 104 | def reset_parameters(self): 105 | """Reset parameters.""" 106 | self.alpha.data = torch.tensor(1.0) 107 | 108 | def forward(self, x): 109 | """Add positional encoding. 110 | Args: 111 | x (torch.Tensor): Input tensor (batch, time, `*`). 112 | Returns: 113 | torch.Tensor: Encoded tensor (batch, time, `*`). 114 | """ 115 | self.extend_pe(x) 116 | x = x + self.alpha * self.pe[:, : x.size(1)] 117 | return self.dropout(x) 118 | 119 | 120 | class LegacyRelPositionalEncoding(PositionalEncoding): 121 | """Relative positional encoding module (old version). 122 | Details can be found in https://github.com/espnet/espnet/pull/2816. 123 | See : Appendix B in https://arxiv.org/abs/1901.02860 124 | Args: 125 | d_model (int): Embedding dimension. 126 | dropout_rate (float): Dropout rate. 127 | max_len (int): Maximum input length. 128 | """ 129 | 130 | def __init__(self, d_model, dropout_rate, max_len=5000): 131 | """Initialize class.""" 132 | super().__init__( 133 | d_model=d_model, 134 | dropout_rate=dropout_rate, 135 | max_len=max_len, 136 | reverse=True, 137 | ) 138 | 139 | def forward(self, x): 140 | """Compute positional encoding. 141 | Args: 142 | x (torch.Tensor): Input tensor (batch, time, `*`). 143 | Returns: 144 | torch.Tensor: Encoded tensor (batch, time, `*`). 145 | torch.Tensor: Positional embedding tensor (1, time, `*`). 
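        Example (illustrative sizes):
            >>> rel_pos = LegacyRelPositionalEncoding(d_model=256, dropout_rate=0.1)
            >>> x = torch.randn(2, 50, 256)
            >>> out, pos_emb = rel_pos(x)
            >>> out.shape, pos_emb.shape
            (torch.Size([2, 50, 256]), torch.Size([1, 50, 256]))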
146 | """ 147 | self.extend_pe(x) 148 | x = x * self.xscale 149 | pos_emb = self.pe[:, : x.size(1)] 150 | return self.dropout(x), self.dropout(pos_emb) 151 | 152 | 153 | class RelPositionalEncoding(torch.nn.Module): 154 | """Relative positional encoding module (new implementation). 155 | Details can be found in https://github.com/espnet/espnet/pull/2816. 156 | See : Appendix B in https://arxiv.org/abs/1901.02860 157 | Args: 158 | d_model (int): Embedding dimension. 159 | dropout_rate (float): Dropout rate. 160 | max_len (int): Maximum input length. 161 | """ 162 | 163 | def __init__(self, d_model, dropout_rate, max_len=5000): 164 | """Construct an PositionalEncoding object.""" 165 | super(RelPositionalEncoding, self).__init__() 166 | self.d_model = d_model 167 | self.xscale = math.sqrt(self.d_model) 168 | self.dropout = torch.nn.Dropout(p=dropout_rate) 169 | self.pe = None 170 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 171 | 172 | def extend_pe(self, x): 173 | """Reset the positional encodings.""" 174 | if self.pe is not None: 175 | # self.pe contains both positive and negative parts 176 | # the length of self.pe is 2 * input_len - 1 177 | if self.pe.size(1) >= x.size(1) * 2 - 1: 178 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 179 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 180 | return 181 | # Suppose `i` means to the position of query vecotr and `j` means the 182 | # position of key vector. We use position relative positions when keys 183 | # are to the left (i>j) and negative relative positions otherwise (i x + linear(concat(x, att(x))) 71 | if False, no additional linear will be applied. i.e. x -> x + att(x) 72 | :param str positionwise_layer_type: linear of conv1d 73 | :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer 74 | :param str encoder_attn_layer_type: encoder attention layer type 75 | :param bool macaron_style: whether to use macaron style for positionwise layer 76 | :param bool use_cnn_module: whether to use convolution module 77 | :param bool zero_triu: whether to zero the upper triangular part of attention matrix 78 | :param int cnn_module_kernel: kernerl size of convolution module 79 | :param int padding_idx: padding_idx for input_layer=embed 80 | """ 81 | 82 | def __init__( 83 | self, 84 | idim, 85 | attention_dim=256, 86 | attention_heads=4, 87 | linear_units=2048, 88 | num_blocks=6, 89 | dropout_rate=0.1, 90 | positional_dropout_rate=0.1, 91 | attention_dropout_rate=0.0, 92 | input_layer="conv2d", 93 | pos_enc_class=PositionalEncoding, 94 | normalize_before=True, 95 | concat_after=False, 96 | positionwise_layer_type="linear", 97 | positionwise_conv_kernel_size=1, 98 | macaron_style=False, 99 | encoder_attn_layer_type="mha", 100 | use_cnn_module=False, 101 | zero_triu=False, 102 | cnn_module_kernel=31, 103 | padding_idx=-1, 104 | relu_type="prelu", 105 | a_upsample_ratio=1, 106 | ): 107 | """Construct an Encoder object.""" 108 | super(Encoder, self).__init__() 109 | self._register_load_state_dict_pre_hook(_pre_hook) 110 | 111 | if encoder_attn_layer_type == "rel_mha": 112 | pos_enc_class = RelPositionalEncoding 113 | elif encoder_attn_layer_type == "legacy_rel_mha": 114 | pos_enc_class = LegacyRelPositionalEncoding 115 | # -- frontend module. 
116 | if input_layer == "conv1d": 117 | self.frontend = Conv1dResNet( 118 | relu_type=relu_type, 119 | a_upsample_ratio=a_upsample_ratio, 120 | ) 121 | elif input_layer == "conv3d": 122 | self.frontend = Conv3dResNet(relu_type=relu_type) 123 | else: 124 | self.frontend = None 125 | # -- backend module. 126 | if input_layer == "linear": 127 | self.embed = torch.nn.Sequential( 128 | torch.nn.Linear(idim, attention_dim), 129 | torch.nn.LayerNorm(attention_dim), 130 | torch.nn.Dropout(dropout_rate), 131 | torch.nn.ReLU(), 132 | pos_enc_class(attention_dim, positional_dropout_rate), 133 | ) 134 | elif input_layer == "conv2d": 135 | self.embed = Conv2dSubsampling( 136 | idim, 137 | attention_dim, 138 | dropout_rate, 139 | pos_enc_class(attention_dim, dropout_rate), 140 | ) 141 | elif input_layer == "vgg2l": 142 | self.embed = VGG2L(idim, attention_dim) 143 | elif input_layer == "embed": 144 | self.embed = torch.nn.Sequential( 145 | torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), 146 | pos_enc_class(attention_dim, positional_dropout_rate), 147 | ) 148 | elif isinstance(input_layer, torch.nn.Module): 149 | self.embed = torch.nn.Sequential( 150 | input_layer, pos_enc_class(attention_dim, positional_dropout_rate), 151 | ) 152 | elif input_layer in ["conv1d", "conv3d"]: 153 | self.embed = torch.nn.Sequential( 154 | torch.nn.Linear(512, attention_dim), 155 | pos_enc_class(attention_dim, positional_dropout_rate) 156 | ) 157 | elif input_layer is None: 158 | self.embed = torch.nn.Sequential( 159 | pos_enc_class(attention_dim, positional_dropout_rate) 160 | ) 161 | else: 162 | raise ValueError("unknown input_layer: " + input_layer) 163 | self.normalize_before = normalize_before 164 | if positionwise_layer_type == "linear": 165 | positionwise_layer = PositionwiseFeedForward 166 | positionwise_layer_args = (attention_dim, linear_units, dropout_rate) 167 | elif positionwise_layer_type == "conv1d": 168 | positionwise_layer = MultiLayeredConv1d 169 | positionwise_layer_args = ( 170 | attention_dim, 171 | linear_units, 172 | positionwise_conv_kernel_size, 173 | dropout_rate, 174 | ) 175 | elif positionwise_layer_type == "conv1d-linear": 176 | positionwise_layer = Conv1dLinear 177 | positionwise_layer_args = ( 178 | attention_dim, 179 | linear_units, 180 | positionwise_conv_kernel_size, 181 | dropout_rate, 182 | ) 183 | else: 184 | raise NotImplementedError("Support only linear or conv1d.") 185 | 186 | if encoder_attn_layer_type == "mha": 187 | encoder_attn_layer = MultiHeadedAttention 188 | encoder_attn_layer_args = ( 189 | attention_heads, 190 | attention_dim, 191 | attention_dropout_rate, 192 | ) 193 | elif encoder_attn_layer_type == "legacy_rel_mha": 194 | encoder_attn_layer = LegacyRelPositionMultiHeadedAttention 195 | encoder_attn_layer_args = ( 196 | attention_heads, 197 | attention_dim, 198 | attention_dropout_rate, 199 | ) 200 | elif encoder_attn_layer_type == "rel_mha": 201 | encoder_attn_layer = RelPositionMultiHeadedAttention 202 | encoder_attn_layer_args = ( 203 | attention_heads, 204 | attention_dim, 205 | attention_dropout_rate, 206 | zero_triu, 207 | ) 208 | else: 209 | raise ValueError("unknown encoder_attn_layer: " + encoder_attn_layer) 210 | 211 | convolution_layer = ConvolutionModule 212 | convolution_layer_args = (attention_dim, cnn_module_kernel) 213 | 214 | self.encoders = repeat( 215 | num_blocks, 216 | lambda: EncoderLayer( 217 | attention_dim, 218 | encoder_attn_layer(*encoder_attn_layer_args), 219 | positionwise_layer(*positionwise_layer_args), 220 | 
convolution_layer(*convolution_layer_args) if use_cnn_module else None, 221 | dropout_rate, 222 | normalize_before, 223 | concat_after, 224 | macaron_style, 225 | ), 226 | ) 227 | if self.normalize_before: 228 | self.after_norm = LayerNorm(attention_dim) 229 | 230 | def forward(self, xs, masks, extract_resnet_feats=False): 231 | """Encode input sequence. 232 | 233 | :param torch.Tensor xs: input tensor 234 | :param torch.Tensor masks: input mask 235 | :param str extract_features: the position for feature extraction 236 | :return: position embedded tensor and mask 237 | :rtype Tuple[torch.Tensor, torch.Tensor]: 238 | """ 239 | if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)): 240 | xs = self.frontend(xs) 241 | if extract_resnet_feats: 242 | return xs 243 | 244 | if isinstance(self.embed, Conv2dSubsampling): 245 | xs, masks = self.embed(xs, masks) 246 | else: 247 | xs = self.embed(xs) 248 | 249 | xs, masks = self.encoders(xs, masks) 250 | 251 | if isinstance(xs, tuple): 252 | xs = xs[0] 253 | 254 | if self.normalize_before: 255 | xs = self.after_norm(xs) 256 | 257 | return xs, masks 258 | 259 | def forward_one_step(self, xs, masks, cache=None): 260 | """Encode input frame. 261 | 262 | :param torch.Tensor xs: input tensor 263 | :param torch.Tensor masks: input mask 264 | :param List[torch.Tensor] cache: cache tensors 265 | :return: position embedded tensor, mask and new cache 266 | :rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: 267 | """ 268 | if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)): 269 | xs = self.frontend(xs) 270 | 271 | if isinstance(self.embed, Conv2dSubsampling): 272 | xs, masks = self.embed(xs, masks) 273 | else: 274 | xs = self.embed(xs) 275 | if cache is None: 276 | cache = [None for _ in range(len(self.encoders))] 277 | new_cache = [] 278 | for c, e in zip(cache, self.encoders): 279 | xs, masks = e(xs, masks, cache=c) 280 | new_cache.append(xs) 281 | if self.normalize_before: 282 | xs = self.after_norm(xs) 283 | return xs, masks, new_cache 284 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Encoder self-attention layer definition.""" 8 | 9 | import copy 10 | import torch 11 | 12 | from torch import nn 13 | 14 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 15 | 16 | 17 | class EncoderLayer(nn.Module): 18 | """Encoder layer module. 19 | 20 | :param int size: input dim 21 | :param espnet.nets.pytorch_backend.transformer.attention. 22 | MultiHeadedAttention self_attn: self attention module 23 | RelPositionMultiHeadedAttention self_attn: self attention module 24 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. 25 | PositionwiseFeedForward feed_forward: 26 | feed forward module 27 | :param espnet.nets.pytorch_backend.transformer.convolution. 28 | ConvolutionModule feed_foreard: 29 | feed forward module 30 | :param float dropout_rate: dropout rate 31 | :param bool normalize_before: whether to use layer_norm before the first block 32 | :param bool concat_after: whether to concat attention layer's input and output 33 | if True, additional linear will be applied. 34 | i.e. x -> x + linear(concat(x, att(x))) 35 | if False, no additional linear will be applied. 
i.e. x -> x + att(x) 36 | :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | size, 43 | self_attn, 44 | feed_forward, 45 | conv_module, 46 | dropout_rate, 47 | normalize_before=True, 48 | concat_after=False, 49 | macaron_style=False, 50 | ): 51 | """Construct an EncoderLayer object.""" 52 | super(EncoderLayer, self).__init__() 53 | self.self_attn = self_attn 54 | self.feed_forward = feed_forward 55 | self.ff_scale = 1.0 56 | self.conv_module = conv_module 57 | self.macaron_style = macaron_style 58 | self.norm_ff = LayerNorm(size) # for the FNN module 59 | self.norm_mha = LayerNorm(size) # for the MHA module 60 | if self.macaron_style: 61 | self.feed_forward_macaron = copy.deepcopy(feed_forward) 62 | self.ff_scale = 0.5 63 | # for another FNN module in macaron style 64 | self.norm_ff_macaron = LayerNorm(size) 65 | if self.conv_module is not None: 66 | self.norm_conv = LayerNorm(size) # for the CNN module 67 | self.norm_final = LayerNorm(size) # for the final output of the block 68 | self.dropout = nn.Dropout(dropout_rate) 69 | self.size = size 70 | self.normalize_before = normalize_before 71 | self.concat_after = concat_after 72 | if self.concat_after: 73 | self.concat_linear = nn.Linear(size + size, size) 74 | 75 | def forward(self, x_input, mask, cache=None): 76 | """Compute encoded features. 77 | 78 | :param torch.Tensor x_input: encoded source features (batch, max_time_in, size) 79 | :param torch.Tensor mask: mask for x (batch, max_time_in) 80 | :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size) 81 | :rtype: Tuple[torch.Tensor, torch.Tensor] 82 | """ 83 | if isinstance(x_input, tuple): 84 | x, pos_emb = x_input[0], x_input[1] 85 | else: 86 | x, pos_emb = x_input, None 87 | 88 | # whether to use macaron style 89 | if self.macaron_style: 90 | residual = x 91 | if self.normalize_before: 92 | x = self.norm_ff_macaron(x) 93 | x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) 94 | if not self.normalize_before: 95 | x = self.norm_ff_macaron(x) 96 | 97 | # multi-headed self-attention module 98 | residual = x 99 | if self.normalize_before: 100 | x = self.norm_mha(x) 101 | 102 | if cache is None: 103 | x_q = x 104 | else: 105 | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) 106 | x_q = x[:, -1:, :] 107 | residual = residual[:, -1:, :] 108 | mask = None if mask is None else mask[:, -1:, :] 109 | 110 | if pos_emb is not None: 111 | x_att = self.self_attn(x_q, x, x, pos_emb, mask) 112 | else: 113 | x_att = self.self_attn(x_q, x, x, mask) 114 | 115 | if self.concat_after: 116 | x_concat = torch.cat((x, x_att), dim=-1) 117 | x = residual + self.concat_linear(x_concat) 118 | else: 119 | x = residual + self.dropout(x_att) 120 | if not self.normalize_before: 121 | x = self.norm_mha(x) 122 | 123 | # convolution module 124 | if self.conv_module is not None: 125 | residual = x 126 | if self.normalize_before: 127 | x = self.norm_conv(x) 128 | x = residual + self.dropout(self.conv_module(x)) 129 | if not self.normalize_before: 130 | x = self.norm_conv(x) 131 | 132 | # feed forward module 133 | residual = x 134 | if self.normalize_before: 135 | x = self.norm_ff(x) 136 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 137 | if not self.normalize_before: 138 | x = self.norm_ff(x) 139 | 140 | if self.conv_module is not None: 141 | x = self.norm_final(x) 142 | 143 | if cache is not None: 144 | x = torch.cat([cache, x], dim=1) 145 | 146 | if pos_emb is 
not None: 147 | return (x, pos_emb), mask 148 | else: 149 | return x, mask 150 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Label smoothing module.""" 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class LabelSmoothingLoss(nn.Module): 14 | """Label-smoothing loss. 15 | 16 | :param int size: the number of class 17 | :param int padding_idx: ignored class id 18 | :param float smoothing: smoothing rate (0.0 means the conventional CE) 19 | :param bool normalize_length: normalize loss by sequence length if True 20 | :param torch.nn.Module criterion: loss function to be smoothed 21 | """ 22 | 23 | def __init__( 24 | self, 25 | size, 26 | padding_idx, 27 | smoothing, 28 | normalize_length=False, 29 | criterion=nn.KLDivLoss(reduction="none"), 30 | ): 31 | """Construct an LabelSmoothingLoss object.""" 32 | super(LabelSmoothingLoss, self).__init__() 33 | self.criterion = criterion 34 | self.padding_idx = padding_idx 35 | self.confidence = 1.0 - smoothing 36 | self.smoothing = smoothing 37 | self.size = size 38 | self.true_dist = None 39 | self.normalize_length = normalize_length 40 | 41 | def forward(self, x, target): 42 | """Compute loss between x and target. 43 | 44 | :param torch.Tensor x: prediction (batch, seqlen, class) 45 | :param torch.Tensor target: 46 | target signal masked with self.padding_id (batch, seqlen) 47 | :return: scalar float value 48 | :rtype torch.Tensor 49 | """ 50 | assert x.size(2) == self.size 51 | batch_size = x.size(0) 52 | x = x.view(-1, self.size) 53 | target = target.view(-1) 54 | with torch.no_grad(): 55 | true_dist = x.clone() 56 | true_dist.fill_(self.smoothing / (self.size - 1)) 57 | ignore = target == self.padding_idx # (B,) 58 | total = len(target) - ignore.sum().item() 59 | target = target.masked_fill(ignore, 0) # avoid -1 index 60 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 61 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 62 | denom = total if self.normalize_length else batch_size 63 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 64 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/layer_norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer normalization module.""" 8 | 9 | import torch 10 | 11 | 12 | class LayerNorm(torch.nn.LayerNorm): 13 | """Layer normalization module. 14 | 15 | :param int nout: output dim size 16 | :param int dim: dimension to be normalized 17 | """ 18 | 19 | def __init__(self, nout, dim=-1): 20 | """Construct an LayerNorm object.""" 21 | super(LayerNorm, self).__init__(nout, eps=1e-12) 22 | self.dim = dim 23 | 24 | def forward(self, x): 25 | """Apply layer normalization. 
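# Usage sketch for the LayerNorm wrapper above; sizes below are illustrative.
# With the default dim=-1 it behaves exactly like torch.nn.LayerNorm; for any
# other value of dim the forward pass swaps dim 1 with the last dim, so in
# practice it normalizes the channel axis of (batch, channels, time) tensors.
import torch
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm

x = torch.randn(2, 256, 50)           # (batch, channels, time)
channel_norm = LayerNorm(256, dim=1)  # normalize over the channel axis
y = channel_norm(x)                   # shape preserved: (2, 256, 50)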
26 | 27 | :param torch.Tensor x: input tensor 28 | :return: layer normalized tensor 29 | :rtype torch.Tensor 30 | """ 31 | if self.dim == -1: 32 | return super(LayerNorm, self).forward(x) 33 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 34 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Shigeki Karita 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | """Mask module.""" 7 | 8 | from distutils.version import LooseVersion 9 | 10 | import torch 11 | 12 | is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0") 13 | # LooseVersion('1.2.0') == LooseVersion(torch.__version__) can't include e.g. 1.2.0+aaa 14 | is_torch_1_2 = ( 15 | LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2") 16 | ) 17 | datatype = torch.bool if is_torch_1_2_plus else torch.uint8 18 | 19 | 20 | def subsequent_mask(size, device="cpu", dtype=datatype): 21 | """Create mask for subsequent steps (1, size, size). 22 | 23 | :param int size: size of mask 24 | :param str device: "cpu" or "cuda" or torch.Tensor.device 25 | :param torch.dtype dtype: result dtype 26 | :rtype: torch.Tensor 27 | >>> subsequent_mask(3) 28 | [[1, 0, 0], 29 | [1, 1, 0], 30 | [1, 1, 1]] 31 | """ 32 | if is_torch_1_2 and dtype == torch.bool: 33 | # torch=1.2 doesn't support tril for bool tensor 34 | ret = torch.ones(size, size, device=device, dtype=torch.uint8) 35 | return torch.tril(ret, out=ret).type(dtype) 36 | else: 37 | ret = torch.ones(size, size, device=device, dtype=dtype) 38 | return torch.tril(ret, out=ret) 39 | 40 | 41 | def target_mask(ys_in_pad, ignore_id): 42 | """Create mask for decoder self-attention. 43 | 44 | :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) 45 | :param int ignore_id: index of padding 46 | :param torch.dtype dtype: result dtype 47 | :rtype: torch.Tensor 48 | """ 49 | ys_mask = ys_in_pad != ignore_id 50 | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) 51 | return ys_mask.unsqueeze(-2) & m 52 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/multi_layer_conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" 8 | 9 | import torch 10 | 11 | 12 | class MultiLayeredConv1d(torch.nn.Module): 13 | """Multi-layered conv1d for Transformer block. 14 | 15 | This is a module of multi-leyered conv1d designed 16 | to replace positionwise feed-forward network 17 | in Transforner block, which is introduced in 18 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 19 | 20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 21 | https://arxiv.org/pdf/1905.09263.pdf 22 | 23 | """ 24 | 25 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 26 | """Initialize MultiLayeredConv1d module. 27 | 28 | Args: 29 | in_chans (int): Number of input channels. 30 | hidden_chans (int): Number of hidden channels. 31 | kernel_size (int): Kernel size of conv1d. 32 | dropout_rate (float): Dropout rate. 
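# Usage sketch for the mask helpers in mask.py above; token ids are made up.
# subsequent_mask() builds the lower-triangular mask used for decoder
# self-attention, and target_mask() combines it with the padding mask.
import torch
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask, target_mask

print(subsequent_mask(3))                    # 3x3 lower-triangular boolean mask
ys_in_pad = torch.tensor([[5, 7, 9, -1]])    # -1 marks a padded position
print(target_mask(ys_in_pad, ignore_id=-1))  # (1, 4, 4) combined boolean mask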
33 | 34 | """ 35 | super(MultiLayeredConv1d, self).__init__() 36 | self.w_1 = torch.nn.Conv1d( 37 | in_chans, 38 | hidden_chans, 39 | kernel_size, 40 | stride=1, 41 | padding=(kernel_size - 1) // 2, 42 | ) 43 | self.w_2 = torch.nn.Conv1d( 44 | hidden_chans, 45 | in_chans, 46 | kernel_size, 47 | stride=1, 48 | padding=(kernel_size - 1) // 2, 49 | ) 50 | self.dropout = torch.nn.Dropout(dropout_rate) 51 | 52 | def forward(self, x): 53 | """Calculate forward propagation. 54 | 55 | Args: 56 | x (Tensor): Batch of input tensors (B, ..., in_chans). 57 | 58 | Returns: 59 | Tensor: Batch of output tensors (B, ..., hidden_chans). 60 | 61 | """ 62 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 63 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) 64 | 65 | 66 | class Conv1dLinear(torch.nn.Module): 67 | """Conv1D + Linear for Transformer block. 68 | 69 | A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. 70 | 71 | """ 72 | 73 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 74 | """Initialize Conv1dLinear module. 75 | 76 | Args: 77 | in_chans (int): Number of input channels. 78 | hidden_chans (int): Number of hidden channels. 79 | kernel_size (int): Kernel size of conv1d. 80 | dropout_rate (float): Dropout rate. 81 | 82 | """ 83 | super(Conv1dLinear, self).__init__() 84 | self.w_1 = torch.nn.Conv1d( 85 | in_chans, 86 | hidden_chans, 87 | kernel_size, 88 | stride=1, 89 | padding=(kernel_size - 1) // 2, 90 | ) 91 | self.w_2 = torch.nn.Linear(hidden_chans, in_chans) 92 | self.dropout = torch.nn.Dropout(dropout_rate) 93 | 94 | def forward(self, x): 95 | """Calculate forward propagation. 96 | 97 | Args: 98 | x (Tensor): Batch of input tensors (B, ..., in_chans). 99 | 100 | Returns: 101 | Tensor: Batch of output tensors (B, ..., hidden_chans). 
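# Usage sketch for MultiLayeredConv1d above; sizes are illustrative. The block
# maps (batch, time, in_chans) back to (batch, time, in_chans); hidden_chans
# is only used inside the two convolutions.
import torch
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d

ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
x = torch.randn(4, 100, 256)  # (batch, time, channels)
y = ffn(x)                    # (4, 100, 256)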
102 | 103 | """ 104 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 105 | return self.w_2(self.dropout(x)) 106 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Optimizer module.""" 8 | 9 | import torch 10 | 11 | 12 | class NoamOpt(object): 13 | """Optim wrapper that implements rate.""" 14 | 15 | def __init__(self, model_size, factor, warmup, optimizer): 16 | """Construct an NoamOpt object.""" 17 | self.optimizer = optimizer 18 | self._step = 0 19 | self.warmup = warmup 20 | self.factor = factor 21 | self.model_size = model_size 22 | self._rate = 0 23 | 24 | @property 25 | def param_groups(self): 26 | """Return param_groups.""" 27 | return self.optimizer.param_groups 28 | 29 | def step(self): 30 | """Update parameters and rate.""" 31 | self._step += 1 32 | rate = self.rate() 33 | for p in self.optimizer.param_groups: 34 | p["lr"] = rate 35 | self._rate = rate 36 | self.optimizer.step() 37 | 38 | def rate(self, step=None): 39 | """Implement `lrate` above.""" 40 | if step is None: 41 | step = self._step 42 | return ( 43 | self.factor 44 | * self.model_size ** (-0.5) 45 | * min(step ** (-0.5), step * self.warmup ** (-1.5)) 46 | ) 47 | 48 | def zero_grad(self): 49 | """Reset gradient.""" 50 | self.optimizer.zero_grad() 51 | 52 | def state_dict(self): 53 | """Return state_dict.""" 54 | return { 55 | "_step": self._step, 56 | "warmup": self.warmup, 57 | "factor": self.factor, 58 | "model_size": self.model_size, 59 | "_rate": self._rate, 60 | "optimizer": self.optimizer.state_dict(), 61 | } 62 | 63 | def load_state_dict(self, state_dict): 64 | """Load state_dict.""" 65 | for key, value in state_dict.items(): 66 | if key == "optimizer": 67 | self.optimizer.load_state_dict(state_dict["optimizer"]) 68 | else: 69 | setattr(self, key, value) 70 | 71 | 72 | def get_std_opt(model, d_model, warmup, factor): 73 | """Get standard NoamOpt.""" 74 | base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) 75 | return NoamOpt(d_model, factor, warmup, base) 76 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import logging 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy 11 | 12 | from espnet.asr import asr_utils 13 | 14 | 15 | def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None): 16 | # dynamically import matplotlib due to not found error 17 | from matplotlib.ticker import MaxNLocator 18 | import os 19 | 20 | d = os.path.dirname(filename) 21 | if not os.path.exists(d): 22 | os.makedirs(d) 23 | w, h = plt.figaspect(1.0 / len(att_w)) 24 | fig = plt.Figure(figsize=(w * 2, h * 2)) 25 | axes = fig.subplots(1, len(att_w)) 26 | if len(att_w) == 1: 27 | axes = [axes] 28 | for ax, aw in zip(axes, att_w): 29 | # plt.subplot(1, len(att_w), h) 30 | ax.imshow(aw.astype(numpy.float32), aspect="auto") 31 | ax.set_xlabel("Input") 32 | ax.set_ylabel("Output") 33 | 
ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 34 | ax.yaxis.set_major_locator(MaxNLocator(integer=True)) 35 | # Labels for major ticks 36 | if xtokens is not None: 37 | ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens))) 38 | ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True) 39 | ax.set_xticklabels(xtokens + [""], rotation=40) 40 | if ytokens is not None: 41 | ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens))) 42 | ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True) 43 | ax.set_yticklabels(ytokens + [""]) 44 | fig.tight_layout() 45 | return fig 46 | 47 | 48 | def savefig(plot, filename): 49 | plot.savefig(filename) 50 | plt.clf() 51 | 52 | 53 | def plot_multi_head_attention( 54 | data, 55 | attn_dict, 56 | outdir, 57 | suffix="png", 58 | savefn=savefig, 59 | ikey="input", 60 | iaxis=0, 61 | okey="output", 62 | oaxis=0, 63 | ): 64 | """Plot multi head attentions. 65 | 66 | :param dict data: utts info from json file 67 | :param dict[str, torch.Tensor] attn_dict: multi head attention dict. 68 | values should be torch.Tensor (head, input_length, output_length) 69 | :param str outdir: dir to save fig 70 | :param str suffix: filename suffix including image type (e.g., png) 71 | :param savefn: function to save 72 | 73 | """ 74 | for name, att_ws in attn_dict.items(): 75 | for idx, att_w in enumerate(att_ws): 76 | filename = "%s/%s.%s.%s" % (outdir, data[idx][0], name, suffix) 77 | dec_len = int(data[idx][1][okey][oaxis]["shape"][0]) 78 | enc_len = int(data[idx][1][ikey][iaxis]["shape"][0]) 79 | xtokens, ytokens = None, None 80 | if "encoder" in name: 81 | att_w = att_w[:, :enc_len, :enc_len] 82 | # for MT 83 | if "token" in data[idx][1][ikey][iaxis].keys(): 84 | xtokens = data[idx][1][ikey][iaxis]["token"].split() 85 | ytokens = xtokens[:] 86 | elif "decoder" in name: 87 | if "self" in name: 88 | att_w = att_w[:, : dec_len + 1, : dec_len + 1] # +1 for 89 | else: 90 | att_w = att_w[:, : dec_len + 1, :enc_len] # +1 for 91 | # for MT 92 | if "token" in data[idx][1][ikey][iaxis].keys(): 93 | xtokens = data[idx][1][ikey][iaxis]["token"].split() 94 | # for ASR/ST/MT 95 | if "token" in data[idx][1][okey][oaxis].keys(): 96 | ytokens = [""] + data[idx][1][okey][oaxis]["token"].split() 97 | if "self" in name: 98 | xtokens = ytokens[:] 99 | else: 100 | logging.warning("unknown name for shaping attention") 101 | fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens) 102 | savefn(fig, filename) 103 | 104 | 105 | class PlotAttentionReport(asr_utils.PlotAttentionReport): 106 | def plotfn(self, *args, **kwargs): 107 | kwargs["ikey"] = self.ikey 108 | kwargs["iaxis"] = self.iaxis 109 | kwargs["okey"] = self.okey 110 | kwargs["oaxis"] = self.oaxis 111 | plot_multi_head_attention(*args, **kwargs) 112 | 113 | def __call__(self, trainer): 114 | attn_dict = self.get_attention_weights() 115 | suffix = "ep.{.updater.epoch}.png".format(trainer) 116 | self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig) 117 | 118 | def get_attention_weights(self): 119 | batch = self.converter([self.transform(self.data)], self.device) 120 | if isinstance(batch, tuple): 121 | att_ws = self.att_vis_fn(*batch) 122 | elif isinstance(batch, dict): 123 | att_ws = self.att_vis_fn(**batch) 124 | return att_ws 125 | 126 | def log_attentions(self, logger, step): 127 | def log_fig(plot, filename): 128 | from os.path import basename 129 | 130 | logger.add_figure(basename(filename), plot, step) 131 | plt.clf() 132 | 133 | attn_dict = self.get_attention_weights() 134 
| self.plotfn(self.data, attn_dict, self.outdir, "", log_fig) 135 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positionwise feed forward layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class PositionwiseFeedForward(torch.nn.Module): 13 | """Positionwise feed forward layer. 14 | 15 | :param int idim: input dimenstion 16 | :param int hidden_units: number of hidden units 17 | :param float dropout_rate: dropout rate 18 | 19 | """ 20 | 21 | def __init__(self, idim, hidden_units, dropout_rate): 22 | """Construct an PositionwiseFeedForward object.""" 23 | super(PositionwiseFeedForward, self).__init__() 24 | self.w_1 = torch.nn.Linear(idim, hidden_units) 25 | self.w_2 = torch.nn.Linear(hidden_units, idim) 26 | self.dropout = torch.nn.Dropout(dropout_rate) 27 | 28 | def forward(self, x): 29 | """Forward funciton.""" 30 | return self.w_2(self.dropout(torch.relu(self.w_1(x)))) 31 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/raw_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet 5 | from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet 6 | 7 | 8 | class VideoEmbedding(torch.nn.Module): 9 | """Video Embedding 10 | 11 | :param int idim: input dim 12 | :param int odim: output dim 13 | :param flaot dropout_rate: dropout rate 14 | """ 15 | 16 | def __init__(self, idim, odim, dropout_rate, pos_enc_class, backbone_type="resnet", relu_type="prelu"): 17 | super(VideoEmbedding, self).__init__() 18 | self.trunk = Conv3dResNet( 19 | backbone_type=backbone_type, 20 | relu_type=relu_type 21 | ) 22 | self.out = torch.nn.Sequential( 23 | torch.nn.Linear(idim, odim), 24 | pos_enc_class, 25 | ) 26 | 27 | def forward(self, x, x_mask, extract_feats=None): 28 | """video embedding for x 29 | 30 | :param torch.Tensor x: input tensor 31 | :param torch.Tensor x_mask: input mask 32 | :param str extract_features: the position for feature extraction 33 | :return: subsampled x and mask 34 | :rtype Tuple[torch.Tensor, torch.Tensor] 35 | """ 36 | x_resnet, x_mask = self.trunk(x, x_mask) 37 | x = self.out(x_resnet) 38 | if extract_feats: 39 | return x, x_mask, x_resnet 40 | else: 41 | return x, x_mask 42 | 43 | 44 | class AudioEmbedding(torch.nn.Module): 45 | """Audio Embedding 46 | 47 | :param int idim: input dim 48 | :param int odim: output dim 49 | :param flaot dropout_rate: dropout rate 50 | """ 51 | 52 | def __init__(self, idim, odim, dropout_rate, pos_enc_class, relu_type="prelu", a_upsample_ratio=1): 53 | super(AudioEmbedding, self).__init__() 54 | self.trunk = Conv1dResNet( 55 | relu_type=relu_type, 56 | a_upsample_ratio=a_upsample_ratio, 57 | ) 58 | self.out = torch.nn.Sequential( 59 | torch.nn.Linear(idim, odim), 60 | pos_enc_class, 61 | ) 62 | 63 | def forward(self, x, x_mask, extract_feats=None): 64 | """audio embedding for x 65 | 66 | :param torch.Tensor x: input tensor 67 | :param torch.Tensor x_mask: input mask 68 | :param str extract_features: the position for feature extraction 69 | 
:return: subsampled x and mask 70 | :rtype Tuple[torch.Tensor, torch.Tensor] 71 | """ 72 | x_resnet, x_mask = self.trunk(x, x_mask) 73 | x = self.out(x_resnet) 74 | if extract_feats: 75 | return x, x_mask, x_resnet 76 | else: 77 | return x, x_mask 78 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/repeat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Repeat the same layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class MultiSequential(torch.nn.Sequential): 13 | """Multi-input multi-output torch.nn.Sequential.""" 14 | 15 | def forward(self, *args): 16 | """Repeat.""" 17 | for m in self: 18 | args = m(*args) 19 | return args 20 | 21 | 22 | def repeat(N, fn): 23 | """Repeat module N times. 24 | 25 | :param int N: repeat time 26 | :param function fn: function to generate module 27 | :return: repeated modules 28 | :rtype: MultiSequential 29 | """ 30 | return MultiSequential(*[fn() for _ in range(N)]) 31 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/subsampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Subsampling layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class Conv2dSubsampling(torch.nn.Module): 13 | """Convolutional 2D subsampling (to 1/4 length). 14 | 15 | :param int idim: input dim 16 | :param int odim: output dim 17 | :param flaot dropout_rate: dropout rate 18 | :param nn.Module pos_enc_class: positional encoding layer 19 | 20 | """ 21 | 22 | def __init__(self, idim, odim, dropout_rate, pos_enc_class): 23 | """Construct an Conv2dSubsampling object.""" 24 | super(Conv2dSubsampling, self).__init__() 25 | self.conv = torch.nn.Sequential( 26 | torch.nn.Conv2d(1, odim, 3, 2), 27 | torch.nn.ReLU(), 28 | torch.nn.Conv2d(odim, odim, 3, 2), 29 | torch.nn.ReLU(), 30 | ) 31 | self.out = torch.nn.Sequential( 32 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), pos_enc_class, 33 | ) 34 | 35 | def forward(self, x, x_mask): 36 | """Subsample x. 37 | 38 | :param torch.Tensor x: input tensor 39 | :param torch.Tensor x_mask: input mask 40 | :return: subsampled x and mask 41 | :rtype Tuple[torch.Tensor, torch.Tensor] 42 | or Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 43 | """ 44 | x = x.unsqueeze(1) # (b, c, t, f) 45 | x = self.conv(x) 46 | b, c, t, f = x.size() 47 | # if RelPositionalEncoding, x: Tuple[torch.Tensor, torch.Tensor] 48 | # else x: torch.Tensor 49 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 50 | if x_mask is None: 51 | return x, None 52 | return x, x_mask[:, :, :-2:2][:, :, :-2:2] 53 | -------------------------------------------------------------------------------- /espnet/nets/scorer_interface.py: -------------------------------------------------------------------------------- 1 | """Scorer interface module.""" 2 | 3 | from typing import Any 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import torch 8 | import warnings 9 | 10 | 11 | class ScorerInterface: 12 | """Scorer interface for beam search. 
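# Usage sketch for Conv2dSubsampling above; sizes are illustrative, and a plain
# torch.nn.Identity() stands in for the positional-encoding module just to keep
# the snippet self-contained.
import torch
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling

sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1,
                        pos_enc_class=torch.nn.Identity())
feats = torch.randn(2, 100, 80)                 # (batch, time, feature)
masks = torch.ones(2, 1, 100, dtype=torch.bool)
out, out_mask = sub(feats, masks)               # out: (2, 24, 256), out_mask: (2, 1, 24)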
13 | 14 | The scorer performs scoring of the all tokens in vocabulary. 15 | 16 | Examples: 17 | * Search heuristics 18 | * :class:`espnet.nets.scorers.length_bonus.LengthBonus` 19 | * Decoder networks of the sequence-to-sequence models 20 | * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder` 21 | * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder` 22 | * Neural language models 23 | * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM` 24 | * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM` 25 | * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM` 26 | 27 | """ 28 | 29 | def init_state(self, x: torch.Tensor) -> Any: 30 | """Get an initial state for decoding (optional). 31 | 32 | Args: 33 | x (torch.Tensor): The encoded feature tensor 34 | 35 | Returns: initial state 36 | 37 | """ 38 | return None 39 | 40 | def select_state(self, state: Any, i: int, new_id: int = None) -> Any: 41 | """Select state with relative ids in the main beam search. 42 | 43 | Args: 44 | state: Decoder state for prefix tokens 45 | i (int): Index to select a state in the main beam search 46 | new_id (int): New label index to select a state if necessary 47 | 48 | Returns: 49 | state: pruned state 50 | 51 | """ 52 | return None if state is None else state[i] 53 | 54 | def score( 55 | self, y: torch.Tensor, state: Any, x: torch.Tensor 56 | ) -> Tuple[torch.Tensor, Any]: 57 | """Score new token (required). 58 | 59 | Args: 60 | y (torch.Tensor): 1D torch.int64 prefix tokens. 61 | state: Scorer state for prefix tokens 62 | x (torch.Tensor): The encoder feature that generates ys. 63 | 64 | Returns: 65 | tuple[torch.Tensor, Any]: Tuple of 66 | scores for next token that has a shape of `(n_vocab)` 67 | and next state for ys 68 | 69 | """ 70 | raise NotImplementedError 71 | 72 | def final_score(self, state: Any) -> float: 73 | """Score eos (optional). 74 | 75 | Args: 76 | state: Scorer state for prefix tokens 77 | 78 | Returns: 79 | float: final score 80 | 81 | """ 82 | return 0.0 83 | 84 | 85 | class BatchScorerInterface(ScorerInterface): 86 | """Batch scorer interface.""" 87 | 88 | def batch_init_state(self, x: torch.Tensor) -> Any: 89 | """Get an initial state for decoding (optional). 90 | 91 | Args: 92 | x (torch.Tensor): The encoded feature tensor 93 | 94 | Returns: initial state 95 | 96 | """ 97 | return self.init_state(x) 98 | 99 | def batch_score( 100 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 101 | ) -> Tuple[torch.Tensor, List[Any]]: 102 | """Score new token batch (required). 103 | 104 | Args: 105 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 106 | states (List[Any]): Scorer states for prefix tokens. 107 | xs (torch.Tensor): 108 | The encoder feature that generates ys (n_batch, xlen, n_feat). 109 | 110 | Returns: 111 | tuple[torch.Tensor, List[Any]]: Tuple of 112 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 113 | and next state list for ys. 
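# Illustrative sketch of the ScorerInterface above: a toy scorer that assigns
# the same score to every token. UniformScorer is a hypothetical class written
# only for illustration; it is not part of ESPnet.
import torch
from espnet.nets.scorer_interface import ScorerInterface

class UniformScorer(ScorerInterface):
    def __init__(self, n_vocab):
        self.n_vocab = n_vocab

    def score(self, y, state, x):
        # one score per vocabulary entry, plus the (unused) next state
        return torch.zeros(self.n_vocab, device=x.device), None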
114 | 115 | """ 116 | warnings.warn( 117 | "{} batch score is implemented through for loop not parallelized".format( 118 | self.__class__.__name__ 119 | ) 120 | ) 121 | scores = list() 122 | outstates = list() 123 | for i, (y, state, x) in enumerate(zip(ys, states, xs)): 124 | score, outstate = self.score(y, state, x) 125 | outstates.append(outstate) 126 | scores.append(score) 127 | scores = torch.cat(scores, 0).view(ys.shape[0], -1) 128 | return scores, outstates 129 | 130 | 131 | class PartialScorerInterface(ScorerInterface): 132 | """Partial scorer interface for beam search. 133 | 134 | The partial scorer performs scoring when non-partial scorer finished scoring, 135 | and receives pre-pruned next tokens to score because it is too heavy to score 136 | all the tokens. 137 | 138 | Examples: 139 | * Prefix search for connectionist-temporal-classification models 140 | * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer` 141 | 142 | """ 143 | 144 | def score_partial( 145 | self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor 146 | ) -> Tuple[torch.Tensor, Any]: 147 | """Score new token (required). 148 | 149 | Args: 150 | y (torch.Tensor): 1D prefix token 151 | next_tokens (torch.Tensor): torch.int64 next token to score 152 | state: decoder state for prefix tokens 153 | x (torch.Tensor): The encoder feature that generates ys 154 | 155 | Returns: 156 | tuple[torch.Tensor, Any]: 157 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 158 | and next state for ys 159 | 160 | """ 161 | raise NotImplementedError 162 | 163 | 164 | class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface): 165 | """Batch partial scorer interface for beam search.""" 166 | 167 | def batch_score_partial( 168 | self, 169 | ys: torch.Tensor, 170 | next_tokens: torch.Tensor, 171 | states: List[Any], 172 | xs: torch.Tensor, 173 | ) -> Tuple[torch.Tensor, Any]: 174 | """Score new token (required). 175 | 176 | Args: 177 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 178 | next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token). 179 | states (List[Any]): Scorer states for prefix tokens. 180 | xs (torch.Tensor): 181 | The encoder feature that generates ys (n_batch, xlen, n_feat). 182 | 183 | Returns: 184 | tuple[torch.Tensor, Any]: 185 | Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)` 186 | and next states for ys 187 | """ 188 | raise NotImplementedError 189 | -------------------------------------------------------------------------------- /espnet/nets/scorers/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /espnet/nets/scorers/ctc.py: -------------------------------------------------------------------------------- 1 | """ScorerInterface implementation for CTC.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from espnet.nets.ctc_prefix_score import CTCPrefixScore 7 | from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH 8 | from espnet.nets.scorer_interface import BatchPartialScorerInterface 9 | 10 | 11 | class CTCPrefixScorer(BatchPartialScorerInterface): 12 | """Decoder interface wrapper for CTCPrefixScore.""" 13 | 14 | def __init__(self, ctc: torch.nn.Module, eos: int): 15 | """Initialize class. 16 | 17 | Args: 18 | ctc (torch.nn.Module): The CTC implementation. 
19 | For example, :class:`espnet.nets.pytorch_backend.ctc.CTC` 20 | eos (int): The end-of-sequence id. 21 | 22 | """ 23 | self.ctc = ctc 24 | self.eos = eos 25 | self.impl = None 26 | 27 | def init_state(self, x: torch.Tensor): 28 | """Get an initial state for decoding. 29 | 30 | Args: 31 | x (torch.Tensor): The encoded feature tensor 32 | 33 | Returns: initial state 34 | 35 | """ 36 | logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy() 37 | # TODO(karita): use CTCPrefixScoreTH 38 | self.impl = CTCPrefixScore(logp, 0, self.eos, np) 39 | return 0, self.impl.initial_state() 40 | 41 | def select_state(self, state, i, new_id=None): 42 | """Select state with relative ids in the main beam search. 43 | 44 | Args: 45 | state: Decoder state for prefix tokens 46 | i (int): Index to select a state in the main beam search 47 | new_id (int): New label id to select a state if necessary 48 | 49 | Returns: 50 | state: pruned state 51 | 52 | """ 53 | if type(state) == tuple: 54 | if len(state) == 2: # for CTCPrefixScore 55 | sc, st = state 56 | return sc[i], st[i] 57 | else: # for CTCPrefixScoreTH (need new_id > 0) 58 | r, log_psi, f_min, f_max, scoring_idmap = state 59 | s = log_psi[i, new_id].expand(log_psi.size(1)) 60 | if scoring_idmap is not None: 61 | return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max 62 | else: 63 | return r[:, :, i, new_id], s, f_min, f_max 64 | return None if state is None else state[i] 65 | 66 | def score_partial(self, y, ids, state, x): 67 | """Score new token. 68 | 69 | Args: 70 | y (torch.Tensor): 1D prefix token 71 | next_tokens (torch.Tensor): torch.int64 next token to score 72 | state: decoder state for prefix tokens 73 | x (torch.Tensor): 2D encoder feature that generates ys 74 | 75 | Returns: 76 | tuple[torch.Tensor, Any]: 77 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 78 | and next state for ys 79 | 80 | """ 81 | prev_score, state = state 82 | presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state) 83 | tscore = torch.as_tensor( 84 | presub_score - prev_score, device=x.device, dtype=x.dtype 85 | ) 86 | return tscore, (presub_score, new_st) 87 | 88 | def batch_init_state(self, x: torch.Tensor): 89 | """Get an initial state for decoding. 90 | 91 | Args: 92 | x (torch.Tensor): The encoded feature tensor 93 | 94 | Returns: initial state 95 | 96 | """ 97 | logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 98 | xlen = torch.tensor([logp.size(1)]) 99 | self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos) 100 | return None 101 | 102 | def batch_score_partial(self, y, ids, state, x): 103 | """Score new token. 104 | 105 | Args: 106 | y (torch.Tensor): 1D prefix token 107 | ids (torch.Tensor): torch.int64 next token to score 108 | state: decoder state for prefix tokens 109 | x (torch.Tensor): 2D encoder feature that generates ys 110 | 111 | Returns: 112 | tuple[torch.Tensor, Any]: 113 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 114 | and next state for ys 115 | 116 | """ 117 | batch_state = ( 118 | ( 119 | torch.stack([s[0] for s in state], dim=2), 120 | torch.stack([s[1] for s in state]), 121 | state[0][2], 122 | state[0][3], 123 | ) 124 | if state[0] is not None 125 | else None 126 | ) 127 | return self.impl(y, batch_state, ids) 128 | 129 | def extend_prob(self, x: torch.Tensor): 130 | """Extend probs for decoding. 
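# Illustrative sketch of constructing the CTCPrefixScorer above: it only needs
# an object that exposes log_softmax(). DummyCTC is a hypothetical stand-in
# written for this sketch; it is not the real ESPnet CTC module.
import torch
from espnet.nets.scorers.ctc import CTCPrefixScorer

class DummyCTC(torch.nn.Module):
    def __init__(self, adim=256, odim=500):
        super().__init__()
        self.lin = torch.nn.Linear(adim, odim)

    def log_softmax(self, hs_pad):
        return torch.log_softmax(self.lin(hs_pad), dim=-1)

scorer = CTCPrefixScorer(ctc=DummyCTC(), eos=499)
enc_out = torch.randn(30, 256)      # encoder output for one utterance (time, feature)
state = scorer.init_state(enc_out)  # wraps a CTCPrefixScore over the CTC posteriors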
131 | 132 | This extension is for streaming decoding 133 | as in Eq (14) in https://arxiv.org/abs/2006.14941 134 | 135 | Args: 136 | x (torch.Tensor): The encoded feature tensor 137 | 138 | """ 139 | logp = self.ctc.log_softmax(x.unsqueeze(0)) 140 | self.impl.extend_prob(logp) 141 | 142 | def extend_state(self, state): 143 | """Extend state for decoding. 144 | 145 | This extension is for streaming decoding 146 | as in Eq (14) in https://arxiv.org/abs/2006.14941 147 | 148 | Args: 149 | state: The states of hyps 150 | 151 | Returns: exteded state 152 | 153 | """ 154 | new_state = [] 155 | for s in state: 156 | new_state.append(self.impl.extend_state(s)) 157 | 158 | return new_state 159 | -------------------------------------------------------------------------------- /espnet/nets/scorers/length_bonus.py: -------------------------------------------------------------------------------- 1 | """Length bonus module.""" 2 | from typing import Any 3 | from typing import List 4 | from typing import Tuple 5 | 6 | import torch 7 | 8 | from espnet.nets.scorer_interface import BatchScorerInterface 9 | 10 | 11 | class LengthBonus(BatchScorerInterface): 12 | """Length bonus in beam search.""" 13 | 14 | def __init__(self, n_vocab: int): 15 | """Initialize class. 16 | 17 | Args: 18 | n_vocab (int): The number of tokens in vocabulary for beam search 19 | 20 | """ 21 | self.n = n_vocab 22 | 23 | def score(self, y, state, x): 24 | """Score new token. 25 | 26 | Args: 27 | y (torch.Tensor): 1D torch.int64 prefix tokens. 28 | state: Scorer state for prefix tokens 29 | x (torch.Tensor): 2D encoder feature that generates ys. 30 | 31 | Returns: 32 | tuple[torch.Tensor, Any]: Tuple of 33 | torch.float32 scores for next token (n_vocab) 34 | and None 35 | 36 | """ 37 | return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None 38 | 39 | def batch_score( 40 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 41 | ) -> Tuple[torch.Tensor, List[Any]]: 42 | """Score new token batch. 43 | 44 | Args: 45 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 46 | states (List[Any]): Scorer states for prefix tokens. 47 | xs (torch.Tensor): 48 | The encoder feature that generates ys (n_batch, xlen, n_feat). 49 | 50 | Returns: 51 | tuple[torch.Tensor, List[Any]]: Tuple of 52 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 53 | and next state list for ys. 
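# Usage sketch for LengthBonus above; the vocabulary size is illustrative.
# Every candidate token receives the same +1.0 bonus, which beam search later
# weights against the other scorers.
import torch
from espnet.nets.scorers.length_bonus import LengthBonus

bonus = LengthBonus(n_vocab=500)
enc_out = torch.randn(30, 256)                     # encoder output (time, feature)
scores, state = bonus.score(torch.tensor([1, 42]), None, enc_out)
print(scores.shape, scores[0].item())              # torch.Size([500]) 1.0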
54 | 55 | """ 56 | return ( 57 | torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand( 58 | ys.shape[0], self.n 59 | ), 60 | None, 61 | ) 62 | -------------------------------------------------------------------------------- /espnet/utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from distutils.util import strtobool as dist_strtobool 3 | import sys 4 | 5 | import numpy 6 | 7 | 8 | def strtobool(x): 9 | # distutils.util.strtobool returns integer, but it's confusing, 10 | return bool(dist_strtobool(x)) 11 | 12 | 13 | def get_commandline_args(): 14 | extra_chars = [ 15 | " ", 16 | ";", 17 | "&", 18 | "(", 19 | ")", 20 | "|", 21 | "^", 22 | "<", 23 | ">", 24 | "?", 25 | "*", 26 | "[", 27 | "]", 28 | "$", 29 | "`", 30 | '"', 31 | "\\", 32 | "!", 33 | "{", 34 | "}", 35 | ] 36 | 37 | # Escape the extra characters for shell 38 | argv = [ 39 | arg.replace("'", "'\\''") 40 | if all(char not in arg for char in extra_chars) 41 | else "'" + arg.replace("'", "'\\''") + "'" 42 | for arg in sys.argv 43 | ] 44 | 45 | return sys.executable + " " + " ".join(argv) 46 | 47 | 48 | def is_scipy_wav_style(value): 49 | # If Tuple[int, numpy.ndarray] or not 50 | return ( 51 | isinstance(value, Sequence) 52 | and len(value) == 2 53 | and isinstance(value[0], int) 54 | and isinstance(value[1], numpy.ndarray) 55 | ) 56 | 57 | 58 | def assert_scipy_wav_style(value): 59 | assert is_scipy_wav_style( 60 | value 61 | ), "Must be Tuple[int, numpy.ndarray], but got {}".format( 62 | type(value) 63 | if not isinstance(value, Sequence) 64 | else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value)) 65 | ) 66 | -------------------------------------------------------------------------------- /espnet/utils/dynamic_import.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def dynamic_import(import_path, alias=dict()): 5 | """dynamic import module and class 6 | 7 | :param str import_path: syntax 'module_name:class_name' 8 | e.g., 'espnet.transform.add_deltas:AddDeltas' 9 | :param dict alias: shortcut for registered class 10 | :return: imported class 11 | """ 12 | if import_path not in alias and ":" not in import_path: 13 | raise ValueError( 14 | "import_path should be one of {} or " 15 | 'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : ' 16 | "{}".format(set(alias), import_path) 17 | ) 18 | if ":" not in import_path: 19 | import_path = alias[import_path] 20 | 21 | module_name, objname = import_path.split(":") 22 | m = importlib.import_module(module_name) 23 | return getattr(m, objname) 24 | -------------------------------------------------------------------------------- /espnet/utils/fill_missing_args.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2018 Nagoya University (Tomoki Hayashi) 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | import argparse 7 | import logging 8 | 9 | 10 | def fill_missing_args(args, add_arguments): 11 | """Fill missing arguments in args. 12 | 13 | Args: 14 | args (Namespace or None): Namesapce containing hyperparameters. 15 | add_arguments (function): Function to add arguments. 16 | 17 | Returns: 18 | Namespace: Arguments whose missing ones are filled with default value. 
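# Usage sketch for dynamic_import above: any "module:Class" string resolves a
# class at runtime; the path below points at a class defined earlier in this
# repository.
from espnet.utils.dynamic_import import dynamic_import

length_bonus_class = dynamic_import("espnet.nets.scorers.length_bonus:LengthBonus")
print(length_bonus_class(5).n)  # 5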
19 | 20 | Examples: 21 | >>> from argparse import Namespace 22 | >>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2 23 | >>> args = Namespace() 24 | >>> fill_missing_args(args, Tacotron2.add_arguments_fn) 25 | Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...) 26 | 27 | """ 28 | # check argument type 29 | assert isinstance(args, argparse.Namespace) or args is None 30 | assert callable(add_arguments) 31 | 32 | # get default arguments 33 | default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args() 34 | 35 | # convert to dict 36 | args = {} if args is None else vars(args) 37 | default_args = vars(default_args) 38 | 39 | for key, value in default_args.items(): 40 | if key not in args: 41 | logging.info( 42 | 'attribute "%s" does not exist. use default %s.' % (key, str(value)) 43 | ) 44 | args[key] = value 45 | 46 | return argparse.Namespace(**args) 47 | -------------------------------------------------------------------------------- /hydra_configs/default.yaml: -------------------------------------------------------------------------------- 1 | # config.yaml 2 | 3 | defaults: 4 | - _self_ 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | hydra: 8 | output_subdir: null 9 | run: 10 | dir: . 11 | config_filename: null 12 | data_dir: null 13 | data_filename: null 14 | data_ext: ".mp4" 15 | landmarks_dir: null 16 | landmarks_filename: null 17 | landmarks_ext: ".pkl" 18 | labels_filename: null 19 | detector: retinaface 20 | dst_filename: null 21 | gpu_idx: 0 22 | output_subdir: null 23 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import cv2 4 | import time 5 | from pipelines.pipeline import InferencePipeline 6 | import numpy as np 7 | from datetime import datetime 8 | from ollama import chat 9 | from pydantic import BaseModel 10 | import keyboard 11 | from concurrent.futures import ThreadPoolExecutor 12 | import os 13 | 14 | 15 | # pydantic model for the chat output 16 | class ChaplinOutput(BaseModel): 17 | list_of_changes: str 18 | corrected_text: str 19 | 20 | 21 | class Chaplin: 22 | def __init__(self): 23 | self.vsr_model = None 24 | 25 | # flag to toggle recording 26 | self.recording = False 27 | 28 | # thread stuff 29 | self.executor = ThreadPoolExecutor(max_workers=1) 30 | 31 | # video params 32 | self.output_prefix = "webcam" 33 | self.res_factor = 3 34 | self.fps = 16 35 | self.frame_interval = 1 / self.fps 36 | self.frame_compression = 25 37 | 38 | def perform_inference(self, video_path): 39 | # perform inference on the video with the vsr model 40 | output = self.vsr_model(video_path) 41 | 42 | # write the raw output 43 | keyboard.write(output) 44 | 45 | # shift left to select the entire output 46 | cmd = "" 47 | for i in range(len(output)): 48 | cmd += 'shift+left, ' 49 | cmd = cmd[:-2] 50 | keyboard.press_and_release(cmd) 51 | 52 | # perform inference on the raw output to get back a "correct" version 53 | response = chat( 54 | model='llama3.2', 55 | messages=[ 56 | { 57 | 'role': 'system', 58 | 'content': f"You are an assistant that helps make corrections to the output of a lipreading model. The text you will receive was transcribed using a video-to-text system that attempts to lipread the subject speaking in the video, so the text will likely be imperfect.\n\nIf something seems unusual, assume it was mistranscribed. 
Do your best to infer the words actually spoken, and make changes to the mistranscriptions in your response. Do not add more words or content, just change the ones that seem to be out of place (and, therefore, mistranscribed). Do not change even the wording of sentences, just individual words that look nonsensical in the context of all of the other words in the sentence.\n\nAlso, add correct punctuation to the entire text. ALWAYS end each sentence with the appropriate sentence ending: '.', '?', or '!'. The input text in all-caps, although your respose should be capitalized correctly and should NOT be in all-caps.\n\nReturn the corrected text in the format of 'list_of_changes' and 'corrected_text'." 59 | }, 60 | { 61 | 'role': 'user', 62 | 'content': f"Transcription:\n\n{output}" 63 | } 64 | ], 65 | format=ChaplinOutput.model_json_schema() 66 | ) 67 | 68 | # get only the corrected text 69 | chat_output = ChaplinOutput.model_validate_json( 70 | response.message.content) 71 | 72 | # if last character isn't a sentence ending (happens sometimes), add a period 73 | if chat_output.corrected_text[-1] not in ['.', '?', '!']: 74 | chat_output.corrected_text += '.' 75 | 76 | # write the corrected text 77 | keyboard.write(chat_output.corrected_text + " ") 78 | 79 | # return the corrected text and the video path 80 | return { 81 | "output": chat_output.corrected_text, 82 | "video_path": video_path 83 | } 84 | 85 | def start_webcam(self): 86 | # init webcam 87 | cap = cv2.VideoCapture(0) 88 | 89 | # set webcam resolution, and get frame dimensions 90 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640 // self.res_factor) 91 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480 // self.res_factor) 92 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 93 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 94 | 95 | last_frame_time = time.time() 96 | 97 | futures = [] 98 | output_path = "" 99 | out = None 100 | frame_count = 0 101 | 102 | while True: 103 | key = cv2.waitKey(1) & 0xFF 104 | if key == ord('q'): 105 | # remove any remaining videos that were saved to disk 106 | for file in os.listdir(): 107 | if file.startswith(self.output_prefix) and file.endswith('.mp4'): 108 | os.remove(file) 109 | break 110 | 111 | current_time = time.time() 112 | 113 | # conditional ensures that the video is recorded at the correct frame rate 114 | if current_time - last_frame_time >= self.frame_interval: 115 | ret, frame = cap.read() 116 | if ret: 117 | # frame compression 118 | encode_param = [ 119 | int(cv2.IMWRITE_JPEG_QUALITY), self.frame_compression] 120 | _, buffer = cv2.imencode('.jpg', frame, encode_param) 121 | compressed_frame = cv2.imdecode( 122 | buffer, cv2.IMREAD_GRAYSCALE) 123 | 124 | if self.recording: 125 | if out is None: 126 | output_path = self.output_prefix + \ 127 | str(time.time_ns() // 1_000_000) + '.mp4' 128 | out = cv2.VideoWriter( 129 | output_path, 130 | cv2.VideoWriter_fourcc(*'mp4v'), 131 | self.fps, 132 | (frame_width, frame_height), 133 | False # isColor 134 | ) 135 | 136 | out.write(compressed_frame) 137 | 138 | last_frame_time = current_time 139 | 140 | # circle to indicate recording, only appears in the window and is not present in video saved to disk 141 | cv2.circle(compressed_frame, (frame_width - 142 | 20, 20), 10, (0, 0, 0), -1) 143 | 144 | frame_count += 1 145 | # check if not recording AND video is at least 2 seconds long 146 | elif not self.recording and frame_count > 0: 147 | if out is not None: 148 | out.release() 149 | 150 | # only run inference if the video is at least 2 seconds long 151 | 
if frame_count >= self.fps * 2: 152 | futures.append(self.executor.submit( 153 | self.perform_inference, output_path)) 154 | else: 155 | os.remove(output_path) 156 | 157 | output_path = self.output_prefix + \ 158 | str(time.time_ns() // 1_000_000) + '.mp4' 159 | out = cv2.VideoWriter( 160 | output_path, 161 | cv2.VideoWriter_fourcc(*'mp4v'), 162 | self.fps, 163 | (frame_width, frame_height), 164 | False # isColor 165 | ) 166 | 167 | frame_count = 0 168 | 169 | # display the frame in the window 170 | cv2.imshow('Chaplin', cv2.flip(compressed_frame, 1)) 171 | 172 | # ensures that videos are handled in the order they were recorded 173 | for fut in futures: 174 | if fut.done(): 175 | result = fut.result() 176 | # once done processing, delete the video with the video path 177 | os.remove(result["video_path"]) 178 | futures.remove(fut) 179 | else: 180 | break 181 | 182 | # release everything 183 | cap.release() 184 | if out: 185 | out.release() 186 | cv2.destroyAllWindows() 187 | 188 | def on_action(self, event): 189 | # toggles recording when alt key is pressed 190 | if event.event_type == keyboard.KEY_DOWN and event.name == 'alt': 191 | self.recording = not self.recording 192 | 193 | 194 | @hydra.main(version_base=None, config_path="hydra_configs", config_name="default") 195 | def main(cfg): 196 | chaplin = Chaplin() 197 | 198 | # hook to toggle recording 199 | keyboard.hook(lambda e: chaplin.on_action(e)) 200 | 201 | # load the model 202 | chaplin.vsr_model = InferencePipeline( 203 | cfg.config_filename, device=torch.device(f"cuda:{cfg.gpu_idx}" if torch.cuda.is_available( 204 | ) and cfg.gpu_idx >= 0 else "cpu"), detector=cfg.detector, face_track=True) 205 | print("Model loaded successfully!") 206 | 207 | # start the webcam video capture 208 | chaplin.start_webcam() 209 | 210 | 211 | if __name__ == '__main__': 212 | main() 213 | -------------------------------------------------------------------------------- /pipelines/data/data_module.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import torch 8 | import torchaudio 9 | import torchvision 10 | from .transforms import AudioTransform, VideoTransform 11 | 12 | 13 | class AVSRDataLoader: 14 | def __init__(self, modality, speed_rate=1, transform=True, detector="retinaface", convert_gray=True): 15 | self.modality = modality 16 | self.transform = transform 17 | if self.modality in ["audio", "audiovisual"]: 18 | self.audio_transform = AudioTransform() 19 | if self.modality in ["video", "audiovisual"]: 20 | if detector == "mediapipe": 21 | from pipelines.detectors.mediapipe.video_process import VideoProcess 22 | self.video_process = VideoProcess(convert_gray=convert_gray) 23 | if detector == "retinaface": 24 | from pipelines.detectors.retinaface.video_process import VideoProcess 25 | self.video_process = VideoProcess(convert_gray=convert_gray) 26 | self.video_transform = VideoTransform(speed_rate=speed_rate) 27 | 28 | 29 | def load_data(self, data_filename, landmarks=None, transform=True): 30 | if self.modality == "audio": 31 | audio, sample_rate = self.load_audio(data_filename) 32 | audio = self.audio_process(audio, sample_rate) 33 | return self.audio_transform(audio) if self.transform else audio 34 | if self.modality == "video": 35 | video = self.load_video(data_filename) 36 | video = self.video_process(video, landmarks) 37 | video = torch.tensor(video) 38 | return self.video_transform(video) if self.transform else video 39 | if self.modality == "audiovisual": 40 | rate_ratio = 640 41 | audio, sample_rate = self.load_audio(data_filename) 42 | audio = self.audio_process(audio, sample_rate) 43 | video = self.load_video(data_filename) 44 | video = self.video_process(video, landmarks) 45 | video = torch.tensor(video) 46 | min_t = min(len(video), audio.size(1) // rate_ratio) 47 | audio = audio[:, :min_t*rate_ratio] 48 | video = video[:min_t] 49 | if self.transform: 50 | audio = self.audio_transform(audio) 51 | video = self.video_transform(video) 52 | return video, audio 53 | 54 | 55 | def load_audio(self, data_filename): 56 | waveform, sample_rate = torchaudio.load(data_filename, normalize=True) 57 | return waveform, sample_rate 58 | 59 | 60 | def load_video(self, data_filename): 61 | return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy() 62 | 63 | 64 | def audio_process(self, waveform, sample_rate, target_sample_rate=16000): 65 | if sample_rate != target_sample_rate: 66 | waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate) 67 | waveform = torch.mean(waveform, dim=0, keepdim=True) 68 | return waveform 69 | -------------------------------------------------------------------------------- /pipelines/data/noise/babble_noise.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/pipelines/data/noise/babble_noise.wav -------------------------------------------------------------------------------- /pipelines/data/noise/pink_noise.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/pipelines/data/noise/pink_noise.wav -------------------------------------------------------------------------------- /pipelines/data/noise/white_noise.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/pipelines/data/noise/white_noise.wav -------------------------------------------------------------------------------- /pipelines/data/transforms.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import torch 8 | import torchaudio 9 | import torchvision 10 | 11 | 12 | class FunctionalModule(torch.nn.Module): 13 | def __init__(self, functional): 14 | super().__init__() 15 | self.functional = functional 16 | 17 | def forward(self, input): 18 | return self.functional(input) 19 | 20 | 21 | class VideoTransform: 22 | def __init__(self, speed_rate): 23 | self.video_pipeline = torch.nn.Sequential( 24 | FunctionalModule(lambda x: x.unsqueeze(-1)), 25 | FunctionalModule(lambda x: x if speed_rate == 1 else torch.index_select(x, dim=0, index=torch.linspace(0, x.shape[0]-1, int(x.shape[0] / speed_rate), dtype=torch.int64))), 26 | FunctionalModule(lambda x: x.permute(3, 0, 1, 2)), 27 | FunctionalModule(lambda x: x / 255.), 28 | torchvision.transforms.CenterCrop(88), 29 | torchvision.transforms.Normalize(0.421, 0.165), 30 | ) 31 | 32 | def __call__(self, sample): 33 | return self.video_pipeline(sample) 34 | 35 | 36 | class AudioTransform: 37 | def __init__(self): 38 | self.audio_pipeline = torch.nn.Sequential( 39 | FunctionalModule(lambda x: torch.nn.functional.layer_norm(x, x.shape, eps=0)), 40 | FunctionalModule(lambda x: x.transpose(0, 1)), 41 | ) 42 | 43 | def __call__(self, sample): 44 | return self.audio_pipeline(sample) 45 | -------------------------------------------------------------------------------- /pipelines/detectors/mediapipe/20words_mean_face.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/pipelines/detectors/mediapipe/20words_mean_face.npy -------------------------------------------------------------------------------- /pipelines/detectors/mediapipe/detector.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import warnings 8 | import torchvision 9 | import mediapipe as mp 10 | import os 11 | import cv2 12 | import numpy as np 13 | 14 | 15 | class LandmarksDetector: 16 | def __init__(self): 17 | self.mp_face_detection = mp.solutions.face_detection 18 | self.short_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=0) 19 | self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1) 20 | 21 | def __call__(self, filename): 22 | video_frames = torchvision.io.read_video(filename, pts_unit='sec')[0].numpy() 23 | landmarks = self.detect(video_frames, self.full_range_detector) 24 | if all(element is None for element in landmarks): 25 | landmarks = self.detect(video_frames, self.short_range_detector) 26 | assert any(l is not None for l in landmarks), "Cannot detect any frames in the video" 27 | return landmarks 28 | 29 | def detect(self, video_frames, detector): 30 | landmarks = [] 31 | for frame in video_frames: 32 | results = detector.process(frame) 33 | if not results.detections: 34 | landmarks.append(None) 35 | continue 36 | face_points = [] 37 | for idx, detected_faces in enumerate(results.detections): 38 | max_id, max_size = 0, 0 39 | bboxC = detected_faces.location_data.relative_bounding_box 40 | ih, iw, ic = frame.shape 41 | bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih) 42 | bbox_size = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1]) 43 | if bbox_size > max_size: 44 | max_id, max_size = idx, bbox_size 45 | lmx = [ 46 | [int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(0).value].x * iw), 47 | int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(0).value].y * ih)], 48 | [int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(1).value].x * iw), 49 | int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(1).value].y * ih)], 50 | [int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(2).value].x * iw), 51 | int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(2).value].y * ih)], 52 | [int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(3).value].x * iw), 53 | int(detected_faces.location_data.relative_keypoints[self.mp_face_detection.FaceKeyPoint(3).value].y * ih)], 54 | ] 55 | face_points.append(lmx) 56 | landmarks.append(np.array(face_points[max_id])) 57 | return landmarks 58 | -------------------------------------------------------------------------------- /pipelines/detectors/mediapipe/video_process.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import os 8 | import cv2 9 | import numpy as np 10 | from skimage import transform as tf 11 | 12 | 13 | def linear_interpolate(landmarks, start_idx, stop_idx): 14 | start_landmarks = landmarks[start_idx] 15 | stop_landmarks = landmarks[stop_idx] 16 | delta = stop_landmarks - start_landmarks 17 | for idx in range(1, stop_idx-start_idx): 18 | landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta 19 | return landmarks 20 | 21 | 22 | def warp_img(src, dst, img, std_size): 23 | tform = tf.estimate_transform('similarity', src, dst) 24 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) 25 | warped = (warped * 255).astype('uint8') 26 | return warped, tform 27 | 28 | 29 | def apply_transform(transform, img, std_size): 30 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size) 31 | warped = (warped * 255).astype('uint8') 32 | return warped 33 | 34 | 35 | def cut_patch(img, landmarks, height, width, threshold=5): 36 | center_x, center_y = np.mean(landmarks, axis=0) 37 | # Check for too much bias in height and width 38 | if abs(center_y - img.shape[0] / 2) > height + threshold: 39 | raise Exception('too much bias in height') 40 | if abs(center_x - img.shape[1] / 2) > width + threshold: 41 | raise Exception('too much bias in width') 42 | # Calculate bounding box coordinates 43 | y_min = int(round(np.clip(center_y - height, 0, img.shape[0]))) 44 | y_max = int(round(np.clip(center_y + height, 0, img.shape[0]))) 45 | x_min = int(round(np.clip(center_x - width, 0, img.shape[1]))) 46 | x_max = int(round(np.clip(center_x + width, 0, img.shape[1]))) 47 | # Cut the image 48 | cutted_img = np.copy(img[y_min:y_max, x_min:x_max]) 49 | return cutted_img 50 | 51 | 52 | class VideoProcess: 53 | def __init__(self, mean_face_path="20words_mean_face.npy", crop_width=96, crop_height=96, 54 | start_idx=3, stop_idx=4, window_margin=12, convert_gray=True): 55 | self.reference = np.load(os.path.join(os.path.dirname(__file__), mean_face_path)) 56 | self.crop_width = crop_width 57 | self.crop_height = crop_height 58 | self.start_idx = start_idx 59 | self.stop_idx = stop_idx 60 | self.window_margin = window_margin 61 | self.convert_gray = convert_gray 62 | 63 | def __call__(self, video, landmarks): 64 | # Pre-process landmarks: interpolate frames that are not detected 65 | preprocessed_landmarks = self.interpolate_landmarks(landmarks) 66 | # Exclude corner cases: no landmark in all frames 67 | if not preprocessed_landmarks: 68 | return 69 | # Affine transformation and crop patch 70 | sequence = self.crop_patch(video, preprocessed_landmarks) 71 | assert sequence is not None, f"cannot crop a patch from {filename}." 
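# Usage sketch combining the mediapipe LandmarksDetector above with this
# VideoProcess, mirroring how data_module.py drives them; "example.mp4" is a
# placeholder path.
import torch
import torchvision
from pipelines.detectors.mediapipe.detector import LandmarksDetector
from pipelines.detectors.mediapipe.video_process import VideoProcess

landmarks_detector = LandmarksDetector()
video_process = VideoProcess(convert_gray=True)
frames = torchvision.io.read_video("example.mp4", pts_unit="sec")[0].numpy()
landmarks = landmarks_detector("example.mp4")
mouth_crops = video_process(frames, landmarks)   # stack of 96x96 grayscale mouth patches
mouth_crops = torch.tensor(mouth_crops)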
72 | return sequence 73 | 74 | 75 | def crop_patch(self, video, landmarks): 76 | sequence = [] 77 | for frame_idx, frame in enumerate(video): 78 | window_margin = min(self.window_margin // 2, frame_idx, len(landmarks) - 1 - frame_idx) 79 | smoothed_landmarks = np.mean([landmarks[x] for x in range(frame_idx - window_margin, frame_idx + window_margin + 1)], axis=0) 80 | smoothed_landmarks += landmarks[frame_idx].mean(axis=0) - smoothed_landmarks.mean(axis=0) 81 | transformed_frame, transformed_landmarks = self.affine_transform(frame,smoothed_landmarks,self.reference,grayscale=self.convert_gray) 82 | patch = cut_patch(transformed_frame, transformed_landmarks[self.start_idx:self.stop_idx], self.crop_height//2, self.crop_width//2,) 83 | sequence.append(patch) 84 | return np.array(sequence) 85 | 86 | 87 | def interpolate_landmarks(self, landmarks): 88 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 89 | 90 | if not valid_frames_idx: 91 | return None 92 | 93 | for idx in range(1, len(valid_frames_idx)): 94 | if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1: 95 | landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx]) 96 | 97 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 98 | 99 | # Handle corner case: keep frames at the beginning or at the end that failed to be detected 100 | if valid_frames_idx: 101 | landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0] 102 | landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1]) 103 | 104 | assert all(lm is not None for lm in landmarks), "not every frame has landmark" 105 | 106 | return landmarks 107 | 108 | 109 | def affine_transform(self, frame, landmarks, reference, grayscale=False, 110 | target_size=(256, 256), reference_size=(256, 256), stable_points=(0, 1, 2, 3), 111 | interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, border_value=0): 112 | if grayscale: 113 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 114 | stable_reference = self.get_stable_reference(reference, reference_size, target_size) 115 | transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference) 116 | transformed_frame, transformed_landmarks = self.apply_affine_transform(frame, landmarks, transform, target_size, interpolation, border_mode, border_value) 117 | 118 | return transformed_frame, transformed_landmarks 119 | 120 | 121 | def get_stable_reference(self, reference, reference_size, target_size): 122 | # -- right eye, left eye, nose tip, mouth center 123 | stable_reference = np.vstack([ 124 | np.mean(reference[36:42], axis=0), 125 | np.mean(reference[42:48], axis=0), 126 | np.mean(reference[31:36], axis=0), 127 | np.mean(reference[48:68], axis=0) 128 | ]) 129 | stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0 130 | stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0 131 | return stable_reference 132 | 133 | 134 | def estimate_affine_transform(self, landmarks, stable_points, stable_reference): 135 | return cv2.estimateAffinePartial2D(np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS)[0] 136 | 137 | 138 | def apply_affine_transform(self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value): 139 | transformed_frame = cv2.warpAffine(frame, transform, dsize=(target_size[0], target_size[1]), 140 | flags=interpolation, borderMode=border_mode, 
borderValue=border_value) 141 | transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose() 142 | return transformed_frame, transformed_landmarks 143 | -------------------------------------------------------------------------------- /pipelines/detectors/retinaface/20words_mean_face.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/pipelines/detectors/retinaface/20words_mean_face.npy -------------------------------------------------------------------------------- /pipelines/detectors/retinaface/detector.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import warnings 8 | import torchvision 9 | from ibug.face_detection import RetinaFacePredictor 10 | from ibug.face_alignment import FANPredictor 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | class LandmarksDetector: 15 | def __init__(self, device="cuda:0", model_name='resnet50'): 16 | self.face_detector = RetinaFacePredictor( 17 | device=device, 18 | threshold=0.8, 19 | model=RetinaFacePredictor.get_model(model_name) 20 | ) 21 | self.landmark_detector = FANPredictor(device=device, model=None) 22 | 23 | def __call__(self, filename): 24 | video_frames = torchvision.io.read_video(filename, pts_unit='sec')[0].numpy() 25 | landmarks = [] 26 | for frame in video_frames: 27 | detected_faces = self.face_detector(frame, rgb=False) 28 | face_points, _ = self.landmark_detector(frame, detected_faces, rgb=True) 29 | if len(detected_faces) == 0: 30 | landmarks.append(None) 31 | else: 32 | max_id, max_size = 0, 0 33 | for idx, bbox in enumerate(detected_faces): 34 | bbox_size = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1]) 35 | if bbox_size > max_size: 36 | max_id, max_size = idx, bbox_size 37 | landmarks.append(face_points[max_id]) 38 | return landmarks 39 | -------------------------------------------------------------------------------- /pipelines/detectors/retinaface/video_process.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import os 8 | import cv2 9 | import numpy as np 10 | from skimage import transform as tf 11 | 12 | 13 | def linear_interpolate(landmarks, start_idx, stop_idx): 14 | start_landmarks = landmarks[start_idx] 15 | stop_landmarks = landmarks[stop_idx] 16 | delta = stop_landmarks - start_landmarks 17 | for idx in range(1, stop_idx-start_idx): 18 | landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta 19 | return landmarks 20 | 21 | 22 | def warp_img(src, dst, img, std_size): 23 | tform = tf.estimate_transform('similarity', src, dst) 24 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) 25 | warped = (warped * 255).astype('uint8') 26 | return warped, tform 27 | 28 | 29 | def apply_transform(transform, img, std_size): 30 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size) 31 | warped = (warped * 255).astype('uint8') 32 | return warped 33 | 34 | 35 | def cut_patch(img, landmarks, height, width, threshold=5): 36 | center_x, center_y = np.mean(landmarks, axis=0) 37 | # Check for too much bias in height and width 38 | if abs(center_y - img.shape[0] / 2) > height + threshold: 39 | raise Exception('too much bias in height') 40 | if abs(center_x - img.shape[1] / 2) > width + threshold: 41 | raise Exception('too much bias in width') 42 | # Calculate bounding box coordinates 43 | y_min = int(round(np.clip(center_y - height, 0, img.shape[0]))) 44 | y_max = int(round(np.clip(center_y + height, 0, img.shape[0]))) 45 | x_min = int(round(np.clip(center_x - width, 0, img.shape[1]))) 46 | x_max = int(round(np.clip(center_x + width, 0, img.shape[1]))) 47 | # Cut the image 48 | cutted_img = np.copy(img[y_min:y_max, x_min:x_max]) 49 | return cutted_img 50 | 51 | 52 | class VideoProcess: 53 | def __init__(self, mean_face_path="20words_mean_face.npy", crop_width=96, crop_height=96, 54 | start_idx=48, stop_idx=68, window_margin=12, convert_gray=True): 55 | self.reference = np.load(os.path.join(os.path.dirname(__file__), mean_face_path)) 56 | self.crop_width = crop_width 57 | self.crop_height = crop_height 58 | self.start_idx = start_idx 59 | self.stop_idx = stop_idx 60 | self.window_margin = window_margin 61 | self.convert_gray = convert_gray 62 | 63 | def __call__(self, video, landmarks): 64 | # Pre-process landmarks: interpolate frames that are not detected 65 | preprocessed_landmarks = self.interpolate_landmarks(landmarks) 66 | # Exclude corner cases: no frame has landmarks, or the clip is shorter than the smoothing window 67 | if not preprocessed_landmarks or len(preprocessed_landmarks) < self.window_margin: 68 | return 69 | # Affine transformation and crop patch 70 | sequence = self.crop_patch(video, preprocessed_landmarks) 71 | assert sequence is not None, "cannot crop a mouth patch from the video."
72 | return sequence 73 | 74 | 75 | def crop_patch(self, video, landmarks): 76 | sequence = [] 77 | for frame_idx, frame in enumerate(video): 78 | window_margin = min(self.window_margin // 2, frame_idx, len(landmarks) - 1 - frame_idx) 79 | smoothed_landmarks = np.mean([landmarks[x] for x in range(frame_idx - window_margin, frame_idx + window_margin + 1)], axis=0) 80 | smoothed_landmarks += landmarks[frame_idx].mean(axis=0) - smoothed_landmarks.mean(axis=0) 81 | transformed_frame, transformed_landmarks = self.affine_transform(frame,smoothed_landmarks,self.reference,grayscale=self.convert_gray) 82 | patch = cut_patch(transformed_frame, transformed_landmarks[self.start_idx:self.stop_idx], self.crop_height//2, self.crop_width//2,) 83 | sequence.append(patch) 84 | return np.array(sequence) 85 | 86 | 87 | def interpolate_landmarks(self, landmarks): 88 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 89 | 90 | if not valid_frames_idx: 91 | return None 92 | 93 | for idx in range(1, len(valid_frames_idx)): 94 | if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1: 95 | landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx]) 96 | 97 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 98 | 99 | # Handle corner case: keep frames at the beginning or at the end that failed to be detected 100 | if valid_frames_idx: 101 | landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0] 102 | landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1]) 103 | 104 | assert all(lm is not None for lm in landmarks), "not every frame has landmark" 105 | 106 | return landmarks 107 | 108 | 109 | def affine_transform(self, frame, landmarks, reference, grayscale=True, 110 | target_size=(256, 256), reference_size=(256, 256), stable_points=(28, 33, 36, 39, 42, 45, 48, 54), 111 | interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, border_value=0): 112 | if grayscale: 113 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 114 | stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size) 115 | transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference) 116 | transformed_frame, transformed_landmarks = self.apply_affine_transform(frame, landmarks, transform, target_size, interpolation, border_mode, border_value) 117 | 118 | return transformed_frame, transformed_landmarks 119 | 120 | 121 | def get_stable_reference(self, reference, stable_points, reference_size, target_size): 122 | stable_reference = np.vstack([reference[x] for x in stable_points]) 123 | stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0 124 | stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0 125 | return stable_reference 126 | 127 | 128 | def estimate_affine_transform(self, landmarks, stable_points, stable_reference): 129 | return cv2.estimateAffinePartial2D(np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS)[0] 130 | 131 | 132 | def apply_affine_transform(self, frame, landmarks, transform, target_size, interpolation, border_mode, border_value): 133 | transformed_frame = cv2.warpAffine(frame, transform, dsize=(target_size[0], target_size[1]), 134 | flags=interpolation, borderMode=border_mode, borderValue=border_value) 135 | transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose() 136 | 
return transformed_frame, transformed_landmarks 137 | -------------------------------------------------------------------------------- /pipelines/metrics/measures.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | # This code refers to https://github.com/espnet/espnet/blob/24c3676a8d4c2e60d2726e9bcd9bdbed740610e0/espnet/nets/e2e_asr_common.py#L213-L249 8 | 9 | import numpy as np 10 | 11 | def get_wer(s, ref): 12 | return get_er(s.split(), ref.split()) 13 | 14 | def get_cer(s, ref): 15 | return get_er(s.replace(" ", ""), ref.replace(" ", "")) 16 | 17 | def get_er(s, ref): 18 | """ 19 | Levenshtein-distance error rate (algorithm adapted from Wikipedia). 20 | s: list of words/chars in the hypothesis to measure 21 | ref: list of words/chars in the reference 22 | """ 23 | 24 | costs = np.zeros((len(s) + 1, len(ref) + 1)) 25 | for i in range(len(s) + 1): 26 | costs[i, 0] = i 27 | for j in range(len(ref) + 1): 28 | costs[0, j] = j 29 | 30 | for j in range(1, len(ref) + 1): 31 | for i in range(1, len(s) + 1): 32 | cost = None 33 | if s[i-1] == ref[j-1]: 34 | cost = 0 35 | else: 36 | cost = 1 37 | costs[i,j] = min( 38 | costs[i-1, j] + 1, 39 | costs[i, j-1] + 1, 40 | costs[i-1, j-1] + cost 41 | ) 42 | 43 | return costs[-1,-1] / len(ref) 44 | -------------------------------------------------------------------------------- /pipelines/model.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import os 8 | import json 9 | import torch 10 | import argparse 11 | import numpy as np 12 | 13 | from espnet.asr.asr_utils import torch_load 14 | from espnet.asr.asr_utils import get_model_conf 15 | from espnet.asr.asr_utils import add_results_to_json 16 | from espnet.nets.batch_beam_search import BatchBeamSearch 17 | from espnet.nets.lm_interface import dynamic_import_lm 18 | from espnet.nets.scorers.length_bonus import LengthBonus 19 | from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E 20 | 21 | 22 | class AVSR(torch.nn.Module): 23 | def __init__(self, modality, model_path, model_conf, rnnlm=None, rnnlm_conf=None, 24 | penalty=0., ctc_weight=0.1, lm_weight=0., beam_size=40, device="cuda:0"): 25 | super(AVSR, self).__init__() 26 | self.device = device 27 | 28 | if modality == "audiovisual": 29 | from espnet.nets.pytorch_backend.e2e_asr_transformer_av import E2E 30 | else: 31 | from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E 32 | 33 | with open(model_conf, "rb") as f: 34 | confs = json.load(f) 35 | args = confs if isinstance(confs, dict) else confs[2] 36 | self.train_args = argparse.Namespace(**args) 37 | 38 | labels_type = getattr(self.train_args, "labels_type", "char") 39 | if labels_type == "char": 40 | self.token_list = self.train_args.char_list 41 | elif labels_type == "unigram5000": 42 | file_path = os.path.join(os.path.dirname(__file__), "tokens", "unigram5000_units.txt") 43 | self.token_list = ['<blank>'] + [word.split()[0] for word in open(file_path).read().splitlines()] + ['<eos>']  # CTC blank token first, end-of-sentence token last 44 | self.odim = len(self.token_list) 45 | 46 | self.model = E2E(self.odim, self.train_args) 47 | self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 48 |
self.model.to(device=self.device).eval() 49 | 50 | self.beam_search = get_beam_search_decoder(self.model, self.token_list, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size) 51 | self.beam_search.to(device=self.device).eval() 52 | 53 | def infer(self, data): 54 | with torch.no_grad(): 55 | if isinstance(data, tuple): 56 | enc_feats = self.model.encode(data[0].to(self.device), data[1].to(self.device)) 57 | else: 58 | enc_feats = self.model.encode(data.to(self.device)) 59 | nbest_hyps = self.beam_search(enc_feats) 60 | nbest_hyps = [h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), 1)]] 61 | transcription = add_results_to_json(nbest_hyps, self.token_list) 62 | transcription = transcription.replace("▁", " ").strip() 63 | return transcription.replace("<eos>", "") 64 | 65 | 66 | def get_beam_search_decoder(model, token_list, rnnlm=None, rnnlm_conf=None, penalty=0, ctc_weight=0.1, lm_weight=0., beam_size=40): 67 | sos = model.odim - 1 68 | eos = model.odim - 1 69 | scorers = model.scorers() 70 | 71 | if not rnnlm: 72 | lm = None 73 | else: 74 | lm_args = get_model_conf(rnnlm, rnnlm_conf) 75 | lm_model_module = getattr(lm_args, "model_module", "default") 76 | lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) 77 | lm = lm_class(len(token_list), lm_args) 78 | torch_load(rnnlm, lm) 79 | lm.eval() 80 | 81 | scorers["lm"] = lm 82 | scorers["length_bonus"] = LengthBonus(len(token_list)) 83 | weights = dict( 84 | decoder=1.0 - ctc_weight, 85 | ctc=ctc_weight, 86 | lm=lm_weight, 87 | length_bonus=penalty, 88 | ) 89 | 90 | return BatchBeamSearch( 91 | beam_size=beam_size, 92 | vocab_size=len(token_list), 93 | weights=weights, 94 | scorers=scorers, 95 | sos=sos, 96 | eos=eos, 97 | token_list=token_list, 98 | pre_beam_score_key=None if ctc_weight == 1.0 else "decoder", 99 | ) 100 | -------------------------------------------------------------------------------- /pipelines/pipeline.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import os 8 | import torch 9 | import pickle 10 | from configparser import ConfigParser 11 | 12 | from pipelines.model import AVSR 13 | from pipelines.data.data_module import AVSRDataLoader 14 | 15 | 16 | class InferencePipeline(torch.nn.Module): 17 | def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"): 18 | super(InferencePipeline, self).__init__() 19 | assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
20 | 21 | config = ConfigParser() 22 | config.read(config_filename) 23 | 24 | # modality configuration 25 | modality = config.get("input", "modality") 26 | 27 | self.modality = modality 28 | # data configuration 29 | input_v_fps = config.getfloat("input", "v_fps") 30 | model_v_fps = config.getfloat("model", "v_fps") 31 | 32 | # model configuration 33 | model_path = config.get("model","model_path") 34 | model_conf = config.get("model","model_conf") 35 | 36 | # language model configuration 37 | rnnlm = config.get("model", "rnnlm") 38 | rnnlm_conf = config.get("model", "rnnlm_conf") 39 | penalty = config.getfloat("decode", "penalty") 40 | ctc_weight = config.getfloat("decode", "ctc_weight") 41 | lm_weight = config.getfloat("decode", "lm_weight") 42 | beam_size = config.getint("decode", "beam_size") 43 | 44 | self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps/model_v_fps, detector=detector) 45 | self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device) 46 | if face_track and self.modality in ["video", "audiovisual"]: 47 | if detector == "mediapipe": 48 | from pipelines.detectors.mediapipe.detector import LandmarksDetector 49 | self.landmarks_detector = LandmarksDetector() 50 | if detector == "retinaface": 51 | from pipelines.detectors.retinaface.detector import LandmarksDetector 52 | self.landmarks_detector = LandmarksDetector(device="cuda:0") 53 | else: 54 | self.landmarks_detector = None 55 | 56 | 57 | def process_landmarks(self, data_filename, landmarks_filename): 58 | if self.modality == "audio": 59 | return None 60 | if self.modality in ["video", "audiovisual"]: 61 | if isinstance(landmarks_filename, str): 62 | landmarks = pickle.load(open(landmarks_filename, "rb")) 63 | else: 64 | landmarks = self.landmarks_detector(data_filename) 65 | return landmarks 66 | 67 | 68 | def forward(self, data_filename, landmarks_filename=None): 69 | assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist." 70 | landmarks = self.process_landmarks(data_filename, landmarks_filename) 71 | data = self.dataloader.load_data(data_filename, landmarks) 72 | transcript = self.model.infer(data) 73 | return transcript -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core >= 1.3.2 2 | opencv-python >= 4.5.5.62 3 | scipy >= 1.3.0 4 | scikit-image >= 0.13.0 5 | av >= 10.0.0 6 | six >= 1.16.0 7 | mediapipe 8 | torch 9 | torchvision 10 | torchaudio 11 | ollama 12 | pydantic 13 | keyboard -------------------------------------------------------------------------------- /thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanvirparhar/chaplin/b68158d7d7b56fbc7631a1df47c7c9e4e5f23e2f/thumbnail.png --------------------------------------------------------------------------------
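For orientation, here is a minimal usage sketch of the inference entry points serialized above (`InferencePipeline` in `pipelines/pipeline.py` and `get_wer` in `pipelines/metrics/measures.py`). It is illustrative only and not part of the repository: it assumes the model and language-model folders are already unpacked and the script is run from the repository root, the clip path `./sample.mp4` and the reference transcript are hypothetical, and the CPU fallback is an assumption rather than a documented configuration.

```python
# Minimal sketch, not repository code; paths and reference text below are hypothetical.
import torch

from pipelines.pipeline import InferencePipeline
from pipelines.metrics.measures import get_wer

device = "cuda:0" if torch.cuda.is_available() else "cpu"

pipeline = InferencePipeline(
    config_filename="./configs/LRS3_V_WER19.1.ini",
    detector="mediapipe",   # the retinaface detector is constructed with a CUDA device
    face_track=True,        # run the bundled landmark detector instead of expecting a landmarks file
    device=device,
)

hypothesis = pipeline("./sample.mp4")  # InferencePipeline.forward returns the transcript string
print(hypothesis)

# Word error rate against a known reference transcript (hypothetical reference).
reference = "HELLO WORLD"
print("WER:", get_wer(hypothesis, reference))
```

Passing `face_track=True` matters here: `process_landmarks` only falls back to the in-process landmark detector when no pre-computed landmarks file is supplied.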
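Under the same caveats, the sketch below shows how the mediapipe `LandmarksDetector` and `VideoProcess` classes above fit together, which is presumably what `AVSRDataLoader` drives internally: detect four keypoints per frame, align each frame to the mean face, and cut a 96x96 grayscale mouth crop. The path `./sample.mp4` is again hypothetical.

```python
# Minimal sketch, not repository code; ./sample.mp4 is a hypothetical path.
import torchvision

from pipelines.detectors.mediapipe.detector import LandmarksDetector
from pipelines.detectors.mediapipe.video_process import VideoProcess

landmarks_detector = LandmarksDetector()
video_process = VideoProcess(convert_gray=True)

# One 4-point array (right eye, left eye, nose tip, mouth center) per frame,
# or None for frames where no face is found.
landmarks = landmarks_detector("./sample.mp4")

video = torchvision.io.read_video("./sample.mp4", pts_unit="sec")[0].numpy()
mouth_rois = video_process(video, landmarks)  # ndarray of shape (T, 96, 96), or None

print(None if mouth_rois is None else mouth_rois.shape)
```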