├── espnet ├── __init__.py ├── nets │ ├── __init__.py │ ├── pytorch_backend │ │ ├── __init__.py │ │ ├── lm │ │ │ ├── __init__.py │ │ │ ├── seq_rnn.py │ │ │ └── transformer.py │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── repeat.py │ │ │ ├── layer_norm.py │ │ │ ├── positionwise_feed_forward.py │ │ │ ├── add_sos_eos.py │ │ │ ├── mask.py │ │ │ ├── subsampling.py │ │ │ ├── optimizer.py │ │ │ ├── label_smoothing_loss.py │ │ │ ├── convolution.py │ │ │ ├── raw_embeddings.py │ │ │ ├── multi_layer_conv.py │ │ │ ├── decoder_layer.py │ │ │ ├── plot.py │ │ │ ├── encoder_layer.py │ │ │ ├── embedding.py │ │ │ ├── Bert_layers.py │ │ │ └── decoder.py │ │ ├── backbones │ │ │ ├── conv1d_extractor.py │ │ │ ├── conv3d_extractor.py │ │ │ └── modules │ │ │ │ ├── resnet.py │ │ │ │ ├── shufflenetv2.py │ │ │ │ └── resnet1d.py │ │ └── ctc.py │ ├── scorers │ │ ├── __init__.py │ │ ├── length_bonus.py │ │ └── ctc.py │ ├── lm_interface.py │ ├── scorer_interface.py │ └── e2e_asr_common.py └── utils │ ├── dynamic_import.py │ ├── fill_missing_args.py │ └── cli_utils.py ├── src ├── models │ ├── __init__.py │ ├── VCAFE.py │ ├── AVRelScore.py │ ├── Lip_reader.py │ └── model.json └── data │ ├── 20words_mean_face.npy │ ├── char_list.py │ ├── landmark_transform.py │ ├── visual_corruption.py │ ├── transforms.py │ └── dataset.py ├── occlusion_patch └── README.md ├── checkpoints ├── frontend │ └── README.md └── LM │ └── README.md ├── img └── IMG.png ├── preprocessing.py ├── test.py └── README.md /espnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /espnet/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /occlusion_patch/README.md: -------------------------------------------------------------------------------- 1 | Put occlusion folders here. 2 | -------------------------------------------------------------------------------- /espnet/nets/scorers/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /checkpoints/frontend/README.md: -------------------------------------------------------------------------------- 1 | Put pre-trained frontend files here. 2 | -------------------------------------------------------------------------------- /img/IMG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ms-dot-k/AVSR/HEAD/img/IMG.png -------------------------------------------------------------------------------- /checkpoints/LM/README.md: -------------------------------------------------------------------------------- 1 | Put pre-trained language model checkpoints here. 
2 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/lm/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /src/data/20words_mean_face.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ms-dot-k/AVSR/HEAD/src/data/20words_mean_face.npy -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/repeat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Repeat the same layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class MultiSequential(torch.nn.Sequential): 13 | """Multi-input multi-output torch.nn.Sequential.""" 14 | 15 | def forward(self, *args): 16 | """Repeat.""" 17 | for m in self: 18 | args = m(*args) 19 | return args 20 | 21 | 22 | def repeat(N, fn): 23 | """Repeat module N times. 24 | 25 | :param int N: repeat time 26 | :param function fn: function to generate module 27 | :return: repeated modules 28 | :rtype: MultiSequential 29 | """ 30 | return MultiSequential(*[fn() for _ in range(N)]) 31 | -------------------------------------------------------------------------------- /espnet/utils/dynamic_import.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def dynamic_import(import_path, alias=dict()): 5 | """dynamic import module and class 6 | 7 | :param str import_path: syntax 'module_name:class_name' 8 | e.g., 'espnet.transform.add_deltas:AddDeltas' 9 | :param dict alias: shortcut for registered class 10 | :return: imported class 11 | """ 12 | if import_path not in alias and ":" not in import_path: 13 | raise ValueError( 14 | "import_path should be one of {} or " 15 | 'include ":", e.g. 
"espnet.transform.add_deltas:AddDeltas" : ' 16 | "{}".format(set(alias), import_path) 17 | ) 18 | if ":" not in import_path: 19 | import_path = alias[import_path] 20 | 21 | module_name, objname = import_path.split(":") 22 | m = importlib.import_module(module_name) 23 | return getattr(m, objname) 24 | -------------------------------------------------------------------------------- /src/data/char_list.py: -------------------------------------------------------------------------------- 1 | char_list=[ 2 | "", 3 | "", 4 | "'", 5 | "0", 6 | "1", 7 | "2", 8 | "3", 9 | "4", 10 | "5", 11 | "6", 12 | "7", 13 | "8", 14 | "9", 15 | "", 16 | "A", 17 | "B", 18 | "C", 19 | "D", 20 | "E", 21 | "F", 22 | "G", 23 | "H", 24 | "I", 25 | "J", 26 | "K", 27 | "L", 28 | "M", 29 | "N", 30 | "O", 31 | "P", 32 | "Q", 33 | "R", 34 | "S", 35 | "T", 36 | "U", 37 | "V", 38 | "W", 39 | "X", 40 | "Y", 41 | "Z", 42 | "" 43 | ] -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/layer_norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer normalization module.""" 8 | 9 | import torch 10 | 11 | 12 | class LayerNorm(torch.nn.LayerNorm): 13 | """Layer normalization module. 14 | 15 | :param int nout: output dim size 16 | :param int dim: dimension to be normalized 17 | """ 18 | 19 | def __init__(self, nout, dim=-1): 20 | """Construct an LayerNorm object.""" 21 | super(LayerNorm, self).__init__(nout, eps=1e-12) 22 | self.dim = dim 23 | 24 | def forward(self, x): 25 | """Apply layer normalization. 26 | 27 | :param torch.Tensor x: input tensor 28 | :return: layer normalized tensor 29 | :rtype torch.Tensor 30 | """ 31 | if self.dim == -1: 32 | return super(LayerNorm, self).forward(x) 33 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 34 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positionwise feed forward layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class PositionwiseFeedForward(torch.nn.Module): 13 | """Positionwise feed forward layer. 
14 | 
15 |     :param int idim: input dimension
16 |     :param int hidden_units: number of hidden units
17 |     :param float dropout_rate: dropout rate
18 | 
19 |     """
20 | 
21 |     def __init__(self, idim, hidden_units, dropout_rate):
22 |         """Construct a PositionwiseFeedForward object."""
23 |         super(PositionwiseFeedForward, self).__init__()
24 |         self.w_1 = torch.nn.Linear(idim, hidden_units)
25 |         self.w_2 = torch.nn.Linear(hidden_units, idim)
26 |         self.dropout = torch.nn.Dropout(dropout_rate)
27 | 
28 |     def forward(self, x):
29 |         """Forward function."""
30 |         return self.w_2(self.dropout(torch.relu(self.w_1(x))))
31 | 
--------------------------------------------------------------------------------
/espnet/nets/pytorch_backend/transformer/add_sos_eos.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | # -*- coding: utf-8 -*-
 3 | 
 4 | # Copyright 2019 Shigeki Karita
 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 6 | 
 7 | """Utility functions for Transformer."""
 8 | 
 9 | import torch
10 | 
11 | 
12 | def add_sos_eos(ys_pad, sos, eos, ignore_id):
13 |     """Add <sos> and <eos> labels.
14 | 
15 |     :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
16 |     :param int sos: index of <sos>
17 |     :param int eos: index of <eos>
18 |     :param int ignore_id: index of padding
19 |     :return: padded tensor (B, Lmax)
20 |     :rtype: torch.Tensor
21 |     :return: padded tensor (B, Lmax)
22 |     :rtype: torch.Tensor
23 |     """
24 |     from espnet.nets.pytorch_backend.nets_utils import pad_list
25 | 
26 |     _sos = ys_pad.new([sos])
27 |     _eos = ys_pad.new([eos])
28 |     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
29 |     ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
30 |     ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
31 |     return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
32 | 
--------------------------------------------------------------------------------
/espnet/utils/fill_missing_args.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | # Copyright 2018 Nagoya University (Tomoki Hayashi)
 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 5 | 
 6 | import argparse
 7 | import logging
 8 | 
 9 | 
10 | def fill_missing_args(args, add_arguments):
11 |     """Fill missing arguments in args.
12 | 
13 |     Args:
14 |         args (Namespace or None): Namespace containing hyperparameters.
15 |         add_arguments (function): Function to add arguments.
16 | 
17 |     Returns:
18 |         Namespace: Arguments whose missing ones are filled with default value.
19 | 
20 |     Examples:
21 |         >>> from argparse import Namespace
22 |         >>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
23 |         >>> args = Namespace()
24 |         >>> fill_missing_args(args, Tacotron2.add_arguments)
25 |         Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
26 | 
27 |     """
28 |     # check argument type
29 |     assert isinstance(args, argparse.Namespace) or args is None
30 |     assert callable(add_arguments)
31 | 
32 |     # get default arguments
33 |     default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args()
34 | 
35 |     # convert to dict
36 |     args = {} if args is None else vars(args)
37 |     default_args = vars(default_args)
38 | 
39 |     for key, value in default_args.items():
40 |         if key not in args:
41 |             logging.info(
42 |                 'attribute "%s" does not exist. use default %s.'
% (key, str(value)) 43 | ) 44 | args[key] = value 45 | 46 | return argparse.Namespace(**args) 47 | -------------------------------------------------------------------------------- /espnet/utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from distutils.util import strtobool as dist_strtobool 3 | import sys 4 | 5 | import numpy 6 | 7 | 8 | def strtobool(x): 9 | # distutils.util.strtobool returns integer, but it's confusing, 10 | return bool(dist_strtobool(x)) 11 | 12 | 13 | def get_commandline_args(): 14 | extra_chars = [ 15 | " ", 16 | ";", 17 | "&", 18 | "(", 19 | ")", 20 | "|", 21 | "^", 22 | "<", 23 | ">", 24 | "?", 25 | "*", 26 | "[", 27 | "]", 28 | "$", 29 | "`", 30 | '"', 31 | "\\", 32 | "!", 33 | "{", 34 | "}", 35 | ] 36 | 37 | # Escape the extra characters for shell 38 | argv = [ 39 | arg.replace("'", "'\\''") 40 | if all(char not in arg for char in extra_chars) 41 | else "'" + arg.replace("'", "'\\''") + "'" 42 | for arg in sys.argv 43 | ] 44 | 45 | return sys.executable + " " + " ".join(argv) 46 | 47 | 48 | def is_scipy_wav_style(value): 49 | # If Tuple[int, numpy.ndarray] or not 50 | return ( 51 | isinstance(value, Sequence) 52 | and len(value) == 2 53 | and isinstance(value[0], int) 54 | and isinstance(value[1], numpy.ndarray) 55 | ) 56 | 57 | 58 | def assert_scipy_wav_style(value): 59 | assert is_scipy_wav_style( 60 | value 61 | ), "Must be Tuple[int, numpy.ndarray], but got {}".format( 62 | type(value) 63 | if not isinstance(value, Sequence) 64 | else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value)) 65 | ) 66 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/conv1d_extractor.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | import logging 7 | import torch 8 | import numpy as np 9 | from espnet.nets.pytorch_backend.backbones.modules.resnet1d import ResNet1D 10 | from espnet.nets.pytorch_backend.backbones.modules.resnet1d import BasicBlock1D 11 | 12 | 13 | class Conv1dResNet(torch.nn.Module): 14 | """Conv1dResNet 15 | """ 16 | 17 | def __init__(self, relu_type="swish", a_upsample_ratio=1): 18 | """__init__. 19 | 20 | :param relu_type: str, Activation function used in an audio front-end. 21 | :param a_upsample_ratio: int, The ratio related to the \ 22 | temporal resolution of output features of the frontend. \ 23 | a_upsample_ratio=1 produce features with a fps of 25. 24 | """ 25 | 26 | super(Conv1dResNet, self).__init__() 27 | self.a_upsample_ratio=a_upsample_ratio 28 | self.trunk = ResNet1D( 29 | BasicBlock1D, 30 | [2, 2, 2, 2], 31 | relu_type=relu_type, 32 | a_upsample_ratio=a_upsample_ratio 33 | ) 34 | 35 | def forward(self, xs_pad): 36 | """forward. 
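        The input waveform is truncated to a whole number of 640-sample chunks; with 16 kHz audio (as produced by preprocessing.py), 640 samples correspond to one video frame at 25 fps.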
37 | 38 | :param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim) 39 | """ 40 | B, T, C = xs_pad.size() 41 | xs_pad = xs_pad[:,:T//640*640,:] 42 | xs_pad = xs_pad.transpose(1, 2).contiguous() 43 | xs_pad = self.trunk(xs_pad) 44 | # -- from B x C x T to B x T x C 45 | xs_pad = xs_pad.transpose(1, 2).contiguous() 46 | return xs_pad 47 | -------------------------------------------------------------------------------- /src/models/VCAFE.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Korea Advanced Institute of Science and Technology (Joanna Hong, Minsu Kim) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | class VCA_Masking(nn.Module): 11 | def __init__(self, out_dim=512): 12 | super().__init__() 13 | 14 | self.softmax = nn.Softmax(2) 15 | self.k = nn.Linear(512, out_dim) 16 | self.v = nn.Linear(512, out_dim) 17 | self.q = nn.Linear(512, out_dim) 18 | self.out_dim = out_dim 19 | 20 | self.sigmoid = nn.Sigmoid() 21 | self.mask = nn.Sequential( 22 | nn.Conv1d(out_dim, 512, 3, padding=1), 23 | nn.ReLU(True), 24 | nn.Conv1d(512, 512, 1) 25 | ) 26 | self.dropout = nn.Dropout(0.3) 27 | self.fusion = nn.Linear(1024, 512) 28 | 29 | def forward(self, aud, vid, v_len): 30 | #aud: B,T,512 31 | #vid: B,S,512 32 | q = self.q(aud) # B,T,OD 33 | k = self.k(vid).transpose(1, 2).contiguous() # B,OD,S 34 | 35 | att = torch.bmm(q, k) / math.sqrt(self.out_dim) # B,T,S 36 | for i in range(att.size(0)): 37 | att[i, :, v_len[i]:] = float('-inf') 38 | att = self.softmax(att) # B,T,S 39 | 40 | v = self.v(vid) # B,S,OD 41 | value = torch.bmm(att, v) # B,T,OD 42 | 43 | mask = self.mask(value.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() 44 | for i in range(aud.size(0)): 45 | mask[i, v_len[i]:, :] = float('-inf') 46 | mask = self.sigmoid(mask) 47 | enhance_aud = aud * mask 48 | enhanced_aud = enhance_aud + aud 49 | 50 | fusion = torch.cat([enhanced_aud, vid], 2) 51 | fusion = self.fusion(self.dropout(fusion)) 52 | 53 | return fusion -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Shigeki Karita 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | """Mask module.""" 7 | 8 | from distutils.version import LooseVersion 9 | 10 | import torch 11 | 12 | is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0") 13 | # LooseVersion('1.2.0') == LooseVersion(torch.__version__) can't include e.g. 1.2.0+aaa 14 | is_torch_1_2 = ( 15 | LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2") 16 | ) 17 | datatype = torch.bool if is_torch_1_2_plus else torch.uint8 18 | 19 | 20 | def subsequent_mask(size, device="cpu", dtype=datatype): 21 | """Create mask for subsequent steps (1, size, size). 
22 | 23 | :param int size: size of mask 24 | :param str device: "cpu" or "cuda" or torch.Tensor.device 25 | :param torch.dtype dtype: result dtype 26 | :rtype: torch.Tensor 27 | >>> subsequent_mask(3) 28 | [[1, 0, 0], 29 | [1, 1, 0], 30 | [1, 1, 1]] 31 | """ 32 | if is_torch_1_2 and dtype == torch.bool: 33 | # torch=1.2 doesn't support tril for bool tensor 34 | ret = torch.ones(size, size, device=device, dtype=torch.uint8) 35 | return torch.tril(ret, out=ret).type(dtype) 36 | else: 37 | ret = torch.ones(size, size, device=device, dtype=dtype) 38 | return torch.tril(ret, out=ret) 39 | 40 | 41 | def target_mask(ys_in_pad, ignore_id): 42 | """Create mask for decoder self-attention. 43 | 44 | :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) 45 | :param int ignore_id: index of padding 46 | :param torch.dtype dtype: result dtype 47 | :rtype: torch.Tensor 48 | """ 49 | ys_mask = ys_in_pad != ignore_id 50 | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) 51 | return ys_mask.unsqueeze(-2) & m 52 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/subsampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Subsampling layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class Conv2dSubsampling(torch.nn.Module): 13 | """Convolutional 2D subsampling (to 1/4 length). 14 | 15 | :param int idim: input dim 16 | :param int odim: output dim 17 | :param flaot dropout_rate: dropout rate 18 | :param nn.Module pos_enc_class: positional encoding layer 19 | 20 | """ 21 | 22 | def __init__(self, idim, odim, dropout_rate, pos_enc_class): 23 | """Construct an Conv2dSubsampling object.""" 24 | super(Conv2dSubsampling, self).__init__() 25 | self.conv = torch.nn.Sequential( 26 | torch.nn.Conv2d(1, odim, 3, 2), 27 | torch.nn.ReLU(), 28 | torch.nn.Conv2d(odim, odim, 3, 2), 29 | torch.nn.ReLU(), 30 | ) 31 | self.out = torch.nn.Sequential( 32 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), pos_enc_class, 33 | ) 34 | 35 | def forward(self, x, x_mask): 36 | """Subsample x. 37 | 38 | :param torch.Tensor x: input tensor 39 | :param torch.Tensor x_mask: input mask 40 | :return: subsampled x and mask 41 | :rtype Tuple[torch.Tensor, torch.Tensor] 42 | or Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 43 | """ 44 | x = x.unsqueeze(1) # (b, c, t, f) 45 | x = self.conv(x) 46 | b, c, t, f = x.size() 47 | # if RelPositionalEncoding, x: Tuple[torch.Tensor, torch.Tensor] 48 | # else x: torch.Tensor 49 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 50 | if x_mask is None: 51 | return x, None 52 | return x, x_mask[:, :, :-2:2][:, :, :-2:2] 53 | -------------------------------------------------------------------------------- /espnet/nets/scorers/length_bonus.py: -------------------------------------------------------------------------------- 1 | """Length bonus module.""" 2 | from typing import Any 3 | from typing import List 4 | from typing import Tuple 5 | 6 | import torch 7 | 8 | from espnet.nets.scorer_interface import BatchScorerInterface 9 | 10 | 11 | class LengthBonus(BatchScorerInterface): 12 | """Length bonus in beam search.""" 13 | 14 | def __init__(self, n_vocab: int): 15 | """Initialize class. 
16 | 17 | Args: 18 | n_vocab (int): The number of tokens in vocabulary for beam search 19 | 20 | """ 21 | self.n = n_vocab 22 | 23 | def score(self, y, state, x): 24 | """Score new token. 25 | 26 | Args: 27 | y (torch.Tensor): 1D torch.int64 prefix tokens. 28 | state: Scorer state for prefix tokens 29 | x (torch.Tensor): 2D encoder feature that generates ys. 30 | 31 | Returns: 32 | tuple[torch.Tensor, Any]: Tuple of 33 | torch.float32 scores for next token (n_vocab) 34 | and None 35 | 36 | """ 37 | return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None 38 | 39 | def batch_score( 40 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 41 | ) -> Tuple[torch.Tensor, List[Any]]: 42 | """Score new token batch. 43 | 44 | Args: 45 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 46 | states (List[Any]): Scorer states for prefix tokens. 47 | xs (torch.Tensor): 48 | The encoder feature that generates ys (n_batch, xlen, n_feat). 49 | 50 | Returns: 51 | tuple[torch.Tensor, List[Any]]: Tuple of 52 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 53 | and next state list for ys. 54 | 55 | """ 56 | return ( 57 | torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand( 58 | ys.shape[0], self.n 59 | ), 60 | None, 61 | ) 62 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Optimizer module.""" 8 | 9 | import torch 10 | 11 | 12 | class NoamOpt(object): 13 | """Optim wrapper that implements rate.""" 14 | 15 | def __init__(self, model_size, factor, warmup, optimizer): 16 | """Construct an NoamOpt object.""" 17 | self.optimizer = optimizer 18 | self._step = 0 19 | self.warmup = warmup 20 | self.factor = factor 21 | self.model_size = model_size 22 | self._rate = 0 23 | 24 | @property 25 | def param_groups(self): 26 | """Return param_groups.""" 27 | return self.optimizer.param_groups 28 | 29 | def step(self): 30 | """Update parameters and rate.""" 31 | self._step += 1 32 | rate = self.rate() 33 | for p in self.optimizer.param_groups: 34 | p["lr"] = rate 35 | self._rate = rate 36 | self.optimizer.step() 37 | 38 | def rate(self, step=None): 39 | """Implement `lrate` above.""" 40 | if step is None: 41 | step = self._step 42 | return ( 43 | self.factor 44 | * self.model_size ** (-0.5) 45 | * min(step ** (-0.5), step * self.warmup ** (-1.5)) 46 | ) 47 | 48 | def zero_grad(self): 49 | """Reset gradient.""" 50 | self.optimizer.zero_grad() 51 | 52 | def state_dict(self): 53 | """Return state_dict.""" 54 | return { 55 | "_step": self._step, 56 | "warmup": self.warmup, 57 | "factor": self.factor, 58 | "model_size": self.model_size, 59 | "_rate": self._rate, 60 | "optimizer": self.optimizer.state_dict(), 61 | } 62 | 63 | def load_state_dict(self, state_dict): 64 | """Load state_dict.""" 65 | for key, value in state_dict.items(): 66 | if key == "optimizer": 67 | self.optimizer.load_state_dict(state_dict["optimizer"]) 68 | else: 69 | setattr(self, key, value) 70 | 71 | 72 | def get_std_opt(model, d_model, warmup, factor): 73 | """Get standard NoamOpt.""" 74 | base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9) 75 | return NoamOpt(d_model, factor, warmup, base) 76 | 
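# A minimal usage sketch of the Noam schedule defined above. The toy model and the
# hyper-parameter values (d_model=256, warmup=25000, factor=1.0) are illustrative
# assumptions only, not values taken from this repository.
if __name__ == "__main__":
    model = torch.nn.Linear(256, 256)
    optimizer = get_std_opt(model, d_model=256, warmup=25000, factor=1.0)
    loss = model(torch.randn(8, 256)).sum()
    optimizer.zero_grad()
    loss.backward()
    # step() updates the parameters and sets
    # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
    optimizer.step()
    print("current lr:", optimizer.rate())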
-------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Label smoothing module.""" 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class LabelSmoothingLoss(nn.Module): 14 | """Label-smoothing loss. 15 | 16 | :param int size: the number of class 17 | :param int padding_idx: ignored class id 18 | :param float smoothing: smoothing rate (0.0 means the conventional CE) 19 | :param bool normalize_length: normalize loss by sequence length if True 20 | :param torch.nn.Module criterion: loss function to be smoothed 21 | """ 22 | 23 | def __init__( 24 | self, 25 | size, 26 | padding_idx, 27 | smoothing, 28 | normalize_length=False, 29 | criterion=nn.KLDivLoss(reduction="none"), 30 | ): 31 | """Construct an LabelSmoothingLoss object.""" 32 | super(LabelSmoothingLoss, self).__init__() 33 | self.criterion = criterion 34 | self.padding_idx = padding_idx 35 | self.confidence = 1.0 - smoothing 36 | self.smoothing = smoothing 37 | self.size = size 38 | self.true_dist = None 39 | self.normalize_length = normalize_length 40 | 41 | def forward(self, x, target): 42 | """Compute loss between x and target. 43 | 44 | :param torch.Tensor x: prediction (batch, seqlen, class) 45 | :param torch.Tensor target: 46 | target signal masked with self.padding_id (batch, seqlen) 47 | :return: scalar float value 48 | :rtype torch.Tensor 49 | """ 50 | # assert x.size(2) == self.size 51 | batch_size = x.size(0) 52 | x = x.view(-1, self.size) 53 | target = target.view(-1) 54 | with torch.no_grad(): 55 | true_dist = x.clone() 56 | true_dist.fill_(self.smoothing / (self.size - 1)) 57 | ignore = target == self.padding_idx # (B,) 58 | total = len(target) - ignore.sum().item() 59 | target = target.masked_fill(ignore, 0) # avoid -1 index 60 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 61 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 62 | denom = total if self.normalize_length else batch_size 63 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 64 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | 8 | """ConvolutionModule definition.""" 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class ConvolutionModule(nn.Module): 15 | """ConvolutionModule in Conformer model. 
16 | 
17 |     :param int channels: channels of cnn
18 |     :param int kernel_size: kernel size of cnn
19 | 
20 |     """
21 | 
22 |     def __init__(self, channels, kernel_size, bias=True):
23 |         """Construct a ConvolutionModule object."""
24 |         super(ConvolutionModule, self).__init__()
25 |         # kernel_size should be an odd number for 'SAME' padding
26 |         assert (kernel_size - 1) % 2 == 0
27 | 
28 |         self.pointwise_cov1 = nn.Conv1d(
29 |             channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias,
30 |         )
31 |         self.depthwise_conv = nn.Conv1d(
32 |             channels,
33 |             channels,
34 |             kernel_size,
35 |             stride=1,
36 |             padding=(kernel_size - 1) // 2,
37 |             groups=channels,
38 |             bias=bias,
39 |         )
40 |         self.norm = nn.BatchNorm1d(channels)
41 |         self.pointwise_cov2 = nn.Conv1d(
42 |             channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
43 |         )
44 |         self.activation = Swish()
45 | 
46 |     def forward(self, x):
47 |         """Compute convolution module.
48 | 
49 |         :param torch.Tensor x: (batch, time, size)
50 |         :return torch.Tensor: convolved `value` (batch, time, d_model)
51 |         """
52 |         # exchange the temporal dimension and the feature dimension
53 |         x = x.transpose(1, 2)
54 | 
55 |         # GLU mechanism
56 |         x = self.pointwise_cov1(x)  # (batch, 2*channel, dim)
57 |         x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
58 | 
59 |         # 1D Depthwise Conv
60 |         x = self.depthwise_conv(x)
61 |         x = self.activation(self.norm(x))
62 | 
63 |         x = self.pointwise_cov2(x)
64 | 
65 |         return x.transpose(1, 2)
66 | 
67 | 
68 | class Swish(nn.Module):
69 |     """Construct a Swish object."""
70 | 
71 |     def forward(self, x):
72 |         """Return Swish activation function."""
73 |         return x * torch.sigmoid(x)
74 | 
--------------------------------------------------------------------------------
/src/models/AVRelScore.py:
--------------------------------------------------------------------------------
 1 | #!
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Korea Advanced Institute of Science and Technology (Joanna Hong, Minsu Kim) 5 | 6 | import torch 7 | from torch import nn 8 | 9 | class Scoring_Module(torch.nn.Module): 10 | def __init__(self, indim=512): 11 | super().__init__() 12 | self.score = nn.Sequential(nn.Conv1d(indim, indim*2, 7, padding=3), 13 | nn.BatchNorm1d(indim*2), 14 | nn.ReLU(True), 15 | nn.Conv1d(indim*2, indim*2, 7, padding=3), 16 | nn.BatchNorm1d(indim*2), 17 | nn.ReLU(True), 18 | nn.Conv1d(indim*2, indim, 7, padding=3), 19 | nn.BatchNorm1d(indim), 20 | nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | x = x.transpose(1,2).contiguous() 24 | x = self.score(x) 25 | x = x.transpose(1,2).contiguous() 26 | return x 27 | 28 | class Scoring(torch.nn.Module): 29 | def __init__(self, indim=512): 30 | super().__init__() 31 | self.score_vid = Scoring_Module(indim) 32 | self.score_aud = Scoring_Module(indim) 33 | 34 | def generate_key_mask(self, length, sz): 35 | masks = [] 36 | for i in range(length.size(0)): 37 | mask = [0] * length[i] 38 | mask += [1] * (sz - length[i]) 39 | masks += [torch.tensor(mask*2)] 40 | masks = torch.stack(masks, dim=0).bool() 41 | return masks 42 | 43 | def forward(self, v, a, v_len, is_residual=True, is_scoring=True): 44 | if not is_scoring: 45 | out = torch.cat([v, a], dim=1) 46 | merged_attention_mask = self.generate_key_mask(v_len, v.size(1)).to(out.device) 47 | return out, ~merged_attention_mask.unsqueeze(1).contiguous() 48 | else: 49 | vid_s = self.score_vid(v) 50 | if is_residual: 51 | vid_s = v * vid_s + v 52 | elif is_residual: 53 | vid_s = v * vid_s 54 | else: 55 | raise NotImplementedError 56 | aud_s = self.score_aud(a) 57 | if is_residual: 58 | aud_s = a * aud_s + a 59 | elif is_residual: 60 | aud_s = a * aud_s 61 | else: 62 | raise NotImplementedError 63 | 64 | out = torch.cat([vid_s, aud_s], dim=1) 65 | merged_attention_mask = self.generate_key_mask(v_len, vid_s.size(1)).to(out.device) 66 | return out, ~merged_attention_mask.unsqueeze(1).contiguous() 67 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/raw_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet 5 | from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet 6 | 7 | 8 | class VideoEmbedding(torch.nn.Module): 9 | """Video Embedding 10 | 11 | :param int idim: input dim 12 | :param int odim: output dim 13 | :param flaot dropout_rate: dropout rate 14 | """ 15 | 16 | def __init__(self, idim, odim, dropout_rate, pos_enc_class, backbone_type="resnet", relu_type="prelu"): 17 | super(VideoEmbedding, self).__init__() 18 | self.trunk = Conv3dResNet( 19 | backbone_type=backbone_type, 20 | relu_type=relu_type 21 | ) 22 | self.out = torch.nn.Sequential( 23 | torch.nn.Linear(idim, odim), 24 | pos_enc_class, 25 | ) 26 | 27 | def forward(self, x, x_mask, extract_feats=None): 28 | """video embedding for x 29 | 30 | :param torch.Tensor x: input tensor 31 | :param torch.Tensor x_mask: input mask 32 | :param str extract_features: the position for feature extraction 33 | :return: subsampled x and mask 34 | :rtype Tuple[torch.Tensor, torch.Tensor] 35 | """ 36 | x_resnet, x_mask = self.trunk(x, x_mask) 37 | x = self.out(x_resnet) 38 | if extract_feats: 39 | return x, x_mask, x_resnet 40 | else: 41 | return x, x_mask 42 | 
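# Note: AudioEmbedding below mirrors VideoEmbedding, swapping the Conv3dResNet
# visual trunk for a Conv1dResNet audio front-end; both project the trunk output
# through a Linear(idim, odim) layer followed by the supplied pos_enc_class module.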
43 | 44 | class AudioEmbedding(torch.nn.Module): 45 | """Audio Embedding 46 | 47 | :param int idim: input dim 48 | :param int odim: output dim 49 | :param flaot dropout_rate: dropout rate 50 | """ 51 | 52 | def __init__(self, idim, odim, dropout_rate, pos_enc_class, relu_type="prelu", a_upsample_ratio=1): 53 | super(AudioEmbedding, self).__init__() 54 | self.trunk = Conv1dResNet( 55 | relu_type=relu_type, 56 | a_upsample_ratio=a_upsample_ratio, 57 | ) 58 | self.out = torch.nn.Sequential( 59 | torch.nn.Linear(idim, odim), 60 | pos_enc_class, 61 | ) 62 | 63 | def forward(self, x, x_mask, extract_feats=None): 64 | """audio embedding for x 65 | 66 | :param torch.Tensor x: input tensor 67 | :param torch.Tensor x_mask: input mask 68 | :param str extract_features: the position for feature extraction 69 | :return: subsampled x and mask 70 | :rtype Tuple[torch.Tensor, torch.Tensor] 71 | """ 72 | x_resnet, x_mask = self.trunk(x, x_mask) 73 | x = self.out(x_resnet) 74 | if extract_feats: 75 | return x, x_mask, x_resnet 76 | else: 77 | return x, x_mask 78 | -------------------------------------------------------------------------------- /espnet/nets/lm_interface.py: -------------------------------------------------------------------------------- 1 | """Language model interface.""" 2 | 3 | import argparse 4 | 5 | from espnet.nets.scorer_interface import ScorerInterface 6 | from espnet.utils.dynamic_import import dynamic_import 7 | from espnet.utils.fill_missing_args import fill_missing_args 8 | 9 | 10 | class LMInterface(ScorerInterface): 11 | """LM Interface for ESPnet model implementation.""" 12 | 13 | @staticmethod 14 | def add_arguments(parser): 15 | """Add arguments to command line argument parser.""" 16 | return parser 17 | 18 | @classmethod 19 | def build(cls, n_vocab: int, **kwargs): 20 | """Initialize this class with python-level args. 21 | 22 | Args: 23 | idim (int): The number of vocabulary. 24 | 25 | Returns: 26 | LMinterface: A new instance of LMInterface. 27 | 28 | """ 29 | # local import to avoid cyclic import in lm_train 30 | from espnet.bin.lm_train import get_parser 31 | 32 | def wrap(parser): 33 | return get_parser(parser, required=False) 34 | 35 | args = argparse.Namespace(**kwargs) 36 | args = fill_missing_args(args, wrap) 37 | args = fill_missing_args(args, cls.add_arguments) 38 | return cls(n_vocab, args) 39 | 40 | def forward(self, x, t): 41 | """Compute LM loss value from buffer sequences. 42 | 43 | Args: 44 | x (torch.Tensor): Input ids. (batch, len) 45 | t (torch.Tensor): Target ids. (batch, len) 46 | 47 | Returns: 48 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 49 | loss to backward (scalar), 50 | negative log-likelihood of t: -log p(t) (scalar) and 51 | the number of elements in x (scalar) 52 | 53 | Notes: 54 | The last two return values are used 55 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 56 | 57 | """ 58 | raise NotImplementedError("forward method is not implemented") 59 | 60 | 61 | predefined_lms = { 62 | "pytorch": { 63 | "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM", 64 | "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM", 65 | "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM", 66 | }, 67 | "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"}, 68 | } 69 | 70 | 71 | def dynamic_import_lm(module, backend): 72 | """Import LM class dynamically. 73 | 74 | Args: 75 | module (str): module_name:class_name or alias in `predefined_lms` 76 | backend (str): NN backend. 
e.g., pytorch, chainer 77 | 78 | Returns: 79 | type: LM class 80 | 81 | """ 82 | model_class = dynamic_import(module, predefined_lms.get(backend, dict())) 83 | assert issubclass( 84 | model_class, LMInterface 85 | ), f"{module} does not implement LMInterface" 86 | return model_class 87 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/conv3d_extractor.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | import logging 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | from espnet.nets.pytorch_backend.backbones.modules.resnet import ResNet 13 | from espnet.nets.pytorch_backend.backbones.modules.resnet import BasicBlock 14 | from espnet.nets.pytorch_backend.backbones.modules.shufflenetv2 import ShuffleNetV2 15 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 16 | 17 | 18 | # -- auxiliary functions 19 | def threeD_to_2D_tensor(x): 20 | n_batch, n_channels, s_time, sx, sy = x.shape 21 | x = x.transpose(1, 2).contiguous() 22 | return x.reshape(n_batch*s_time, n_channels, sx, sy) 23 | 24 | 25 | class Conv3dResNet(torch.nn.Module): 26 | """Conv3dResNet module 27 | """ 28 | 29 | def __init__(self, backbone_type="resnet", relu_type="swish"): 30 | """__init__. 31 | 32 | :param backbone_type: str, the type of a visual front-end. 33 | :param relu_type: str, activation function used in an audio front-end. 34 | """ 35 | super(Conv3dResNet, self).__init__() 36 | 37 | self.backbone_type = backbone_type 38 | 39 | if self.backbone_type == "resnet": 40 | self.frontend_nout = 64 41 | self.trunk = ResNet( 42 | BasicBlock, 43 | [2, 2, 2, 2], 44 | relu_type=relu_type, 45 | ) 46 | elif self.backbone_type == "shufflenet": 47 | shufflenet = ShuffleNetV2( 48 | input_size=96, 49 | width_mult=1.0 50 | ) 51 | self.trunk = nn.Sequential( 52 | shufflenet.features, 53 | shufflenet.conv_last, 54 | shufflenet.globalpool, 55 | ) 56 | self.frontend_nout = 24 57 | self.stage_out_channels = shufflenet.stage_out_channels[-1] 58 | 59 | # -- frontend3D 60 | if relu_type == 'relu': 61 | frontend_relu = nn.ReLU(True) 62 | elif relu_type == 'prelu': 63 | frontend_relu = nn.PReLU( self.frontend_nout ) 64 | elif relu_type == 'swish': 65 | frontend_relu = Swish() 66 | 67 | self.frontend3D = nn.Sequential( 68 | nn.Conv3d( 69 | in_channels=1, 70 | out_channels=self.frontend_nout, 71 | kernel_size=(5, 7, 7), 72 | stride=(1, 2, 2), 73 | padding=(2, 3, 3), 74 | bias=False 75 | ), 76 | nn.BatchNorm3d(self.frontend_nout), 77 | frontend_relu, 78 | nn.MaxPool3d( 79 | kernel_size=(1, 3, 3), 80 | stride=(1, 2, 2), 81 | padding=(0, 1, 1), 82 | ) 83 | ) 84 | 85 | 86 | def forward(self, xs_pad): 87 | """forward. 88 | 89 | :param xs_pad: torch.Tensor, batch of padded input sequences. 
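            Expected shape: (B, T, H, W); a singleton channel axis is inserted inside this method (unsqueeze) before the 3D front-end.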
90 | """ 91 | # -- include Channel dimension 92 | xs_pad = xs_pad.unsqueeze(1) 93 | 94 | B, C, T, H, W = xs_pad.size() 95 | xs_pad = self.frontend3D(xs_pad) 96 | Tnew = xs_pad.shape[2] # outpu should be B x C2 x Tnew x H x W 97 | xs_pad = threeD_to_2D_tensor( xs_pad ) 98 | xs_pad = self.trunk(xs_pad) 99 | xs_pad = xs_pad.view(B, Tnew, xs_pad.size(1)) 100 | 101 | return xs_pad 102 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/multi_layer_conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" 8 | 9 | import torch 10 | 11 | 12 | class MultiLayeredConv1d(torch.nn.Module): 13 | """Multi-layered conv1d for Transformer block. 14 | 15 | This is a module of multi-leyered conv1d designed 16 | to replace positionwise feed-forward network 17 | in Transforner block, which is introduced in 18 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 19 | 20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 21 | https://arxiv.org/pdf/1905.09263.pdf 22 | 23 | """ 24 | 25 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 26 | """Initialize MultiLayeredConv1d module. 27 | 28 | Args: 29 | in_chans (int): Number of input channels. 30 | hidden_chans (int): Number of hidden channels. 31 | kernel_size (int): Kernel size of conv1d. 32 | dropout_rate (float): Dropout rate. 33 | 34 | """ 35 | super(MultiLayeredConv1d, self).__init__() 36 | self.w_1 = torch.nn.Conv1d( 37 | in_chans, 38 | hidden_chans, 39 | kernel_size, 40 | stride=1, 41 | padding=(kernel_size - 1) // 2, 42 | ) 43 | self.w_2 = torch.nn.Conv1d( 44 | hidden_chans, 45 | in_chans, 46 | kernel_size, 47 | stride=1, 48 | padding=(kernel_size - 1) // 2, 49 | ) 50 | self.dropout = torch.nn.Dropout(dropout_rate) 51 | 52 | def forward(self, x): 53 | """Calculate forward propagation. 54 | 55 | Args: 56 | x (Tensor): Batch of input tensors (B, ..., in_chans). 57 | 58 | Returns: 59 | Tensor: Batch of output tensors (B, ..., hidden_chans). 60 | 61 | """ 62 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 63 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) 64 | 65 | 66 | class Conv1dLinear(torch.nn.Module): 67 | """Conv1D + Linear for Transformer block. 68 | 69 | A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. 70 | 71 | """ 72 | 73 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 74 | """Initialize Conv1dLinear module. 75 | 76 | Args: 77 | in_chans (int): Number of input channels. 78 | hidden_chans (int): Number of hidden channels. 79 | kernel_size (int): Kernel size of conv1d. 80 | dropout_rate (float): Dropout rate. 81 | 82 | """ 83 | super(Conv1dLinear, self).__init__() 84 | self.w_1 = torch.nn.Conv1d( 85 | in_chans, 86 | hidden_chans, 87 | kernel_size, 88 | stride=1, 89 | padding=(kernel_size - 1) // 2, 90 | ) 91 | self.w_2 = torch.nn.Linear(hidden_chans, in_chans) 92 | self.dropout = torch.nn.Dropout(dropout_rate) 93 | 94 | def forward(self, x): 95 | """Calculate forward propagation. 96 | 97 | Args: 98 | x (Tensor): Batch of input tensors (B, ..., in_chans). 99 | 100 | Returns: 101 | Tensor: Batch of output tensors (B, ..., hidden_chans). 
102 | 103 | """ 104 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 105 | return self.w_2(self.dropout(x)) 106 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import argparse, cv2 3 | from src.data.landmark_transform import VideoProcess 4 | from tqdm import tqdm 5 | import numpy as np 6 | import pickle 7 | import shutil 8 | from joblib import Parallel, delayed 9 | 10 | def build_file_list(data_path, data_type): 11 | if data_type == 'LRS2': 12 | files = sorted(glob.glob(os.path.join(data_path, 'main', '*', '*.mp4'))) 13 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4'))) 14 | elif data_type == 'LRS3': 15 | files = sorted(glob.glob(os.path.join(data_path, 'trainval', '*', '*.mp4'))) 16 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4'))) 17 | files.extend(glob.glob(os.path.join(data_path, 'test', '*', '*.mp4'))) 18 | else: 19 | raise NotImplementedError 20 | return [f.replace(data_path + '/', '')[:-4] for f in files] 21 | 22 | def load_video(data_filename): 23 | """load_video. 24 | 25 | :param filename: str, the fileanme for a video sequence. 26 | """ 27 | frames = [] 28 | cap = cv2.VideoCapture(data_filename) 29 | while(cap.isOpened()): 30 | ret, frame = cap.read() # BGR 31 | if ret: 32 | frames.append(frame) 33 | else: 34 | break 35 | cap.release() 36 | return np.array(frames) 37 | 38 | 39 | def per_file(f, args, video_process): 40 | save_path = os.path.join(args.save_path, 'Video', f) 41 | if os.path.exists(save_path + '.mp4'): return 42 | lm_save_path = os.path.join(args.save_path, 'Transformed_LM', f) 43 | aud_save_path = os.path.join(args.save_path, 'Audio', f) 44 | txt_save_path = os.path.join(args.save_path, 'Text', f) 45 | if not os.path.exists(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) 46 | if not os.path.exists(os.path.dirname(lm_save_path)): os.makedirs(os.path.dirname(lm_save_path), exist_ok=True) 47 | if not os.path.exists(os.path.dirname(aud_save_path)): os.makedirs(os.path.dirname(aud_save_path), exist_ok=True) 48 | if not os.path.exists(os.path.dirname(txt_save_path)): os.makedirs(os.path.dirname(txt_save_path), exist_ok=True) 49 | if os.path.exists(os.path.join(args.landmark_path, f + '.pkl')): 50 | with open(os.path.join(args.landmark_path, f + '.pkl'), "rb") as pkl_file: 51 | lm = pickle.load(pkl_file) 52 | vid_name = os.path.join(args.data_path, f + '.mp4') 53 | vid = load_video(vid_name) 54 | 55 | if all(x is None for x in lm) or len(vid) == 0: 56 | return 57 | 58 | output = video_process(vid, lm) 59 | if output is None: 60 | return 61 | p_vid, yx_min, transformed_landmarks = output 62 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 63 | output = cv2.VideoWriter(save_path + '.mp4', fourcc, 25, (96, 96)) 64 | for v in p_vid: 65 | output.write(v) 66 | with open(lm_save_path + '.pkl', "wb") as pkl_file: 67 | pickle.dump({'landmarks': transformed_landmarks, 'yx_min': yx_min}, pkl_file) 68 | os.system(f'ffmpeg -loglevel panic -nostdin -y -i {vid_name} -acodec pcm_s16le -ar 16000 -ac 1 {aud_save_path}.wav') 69 | shutil.copy(vid_name[:-4] + '.txt', txt_save_path + '.txt') 70 | 71 | def main(): 72 | parser = get_parser() 73 | args = parser.parse_args() 74 | video_process = VideoProcess(mean_face_path='20words_mean_face.npy', convert_gray=False) 75 | file_lists = build_file_list(args.data_path, args.data_type) 76 | 
Parallel(n_jobs=3)(delayed(per_file)(f, args, video_process) for f in tqdm(file_lists)) 77 | 78 | def get_parser(): 79 | parser = argparse.ArgumentParser( 80 | description="Command-line script for preprocessing." 81 | ) 82 | parser.add_argument( 83 | "--data_path", type=str, required=True, help="path including video and split files like train.txt" 84 | ) 85 | parser.add_argument( 86 | "--landmark_path", type=str, required=True, help="path including landmark files" 87 | ) 88 | parser.add_argument( 89 | "--save_path", type=str, required=True, help="path for saving" 90 | ) 91 | parser.add_argument( 92 | "--data_type", type=str, required=True, help="LRS2 or LRS3" 93 | ) 94 | return parser 95 | 96 | 97 | if __name__ == "__main__": 98 | main() -------------------------------------------------------------------------------- /src/models/Lip_reader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from espnet.asr.asr_utils import torch_load 3 | from espnet.asr.asr_utils import get_model_conf 4 | from espnet.nets.lm_interface import dynamic_import_lm 5 | from espnet.nets.scorers.length_bonus import LengthBonus 6 | from espnet.nets.batch_beam_search import BatchBeamSearch 7 | from espnet.asr.asr_utils import add_results_to_json 8 | import numpy as np 9 | 10 | class Lipreading(torch.nn.Module): 11 | """Lipreading.""" 12 | 13 | def __init__(self, config, odim, model, char_list, feats_position="resnet"): 14 | """__init__. 15 | 16 | :param config: ConfigParser class, contains model's configuration. 17 | :param feats_position: str, the position to extract features. 18 | """ 19 | super(Lipreading, self).__init__() 20 | 21 | self.feats_position = feats_position 22 | 23 | self.odim = odim 24 | self.model = model 25 | self.char_list = char_list 26 | self.get_beam_search(config) 27 | 28 | self.beam_search.cuda().eval() 29 | 30 | def get_beam_search(self, config): 31 | """get_beam_search. 32 | 33 | :param config: ConfigParser Objects, the main configuration parser. 
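        The object is expected to expose the attributes read below: rnnlm, rnnlm_conf, penalty, maxlenratio, minlenratio, ctc_weight, lm_weight and beam_size.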
34 | """ 35 | 36 | rnnlm = config.rnnlm 37 | rnnlm_conf = config.rnnlm_conf 38 | 39 | penalty = config.penalty 40 | maxlenratio = config.maxlenratio 41 | minlenratio = config.minlenratio 42 | ctc_weight = config.ctc_weight 43 | lm_weight = config.lm_weight 44 | beam_size = config.beam_size 45 | 46 | print(f'Beam search with ctc_weight: {ctc_weight}, lm_weight: {lm_weight}, beam_size: {beam_size}') 47 | 48 | sos = self.odim - 1 49 | eos = self.odim - 1 50 | scorers = self.model.scorers() 51 | 52 | if not rnnlm: 53 | lm = None 54 | else: 55 | lm_args = get_model_conf(rnnlm, rnnlm_conf) 56 | lm_model_module = getattr(lm_args, "model_module", "default") 57 | lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) 58 | lm = lm_class(len(self.char_list), lm_args) 59 | torch_load(rnnlm, lm) 60 | print(f"load a pre-trained language model from: {rnnlm}") 61 | lm.eval() 62 | 63 | scorers["lm"] = lm 64 | scorers["length_bonus"] = LengthBonus(len(self.char_list)) 65 | weights = dict( 66 | decoder=1.0 - ctc_weight, 67 | ctc=ctc_weight, 68 | lm=lm_weight, 69 | length_bonus=penalty, 70 | ) 71 | 72 | # -- decoding config 73 | self.beam_size = beam_size 74 | self.nbest = 1 75 | self.weights = weights 76 | self.scorers = scorers 77 | self.sos = sos 78 | self.eos = eos 79 | self.ctc_weight = ctc_weight 80 | self.maxlenratio = maxlenratio 81 | self.minlenratio = minlenratio 82 | 83 | self.beam_search = BatchBeamSearch( 84 | beam_size=self.beam_size, 85 | vocab_size=len(self.char_list), 86 | weights=self.weights, 87 | scorers=self.scorers, 88 | sos=self.sos, 89 | eos=self.eos, 90 | token_list=self.char_list, 91 | pre_beam_score_key=None if self.ctc_weight == 1.0 else "decoder", 92 | ) 93 | 94 | def predict(self, sequence, sequence_aud, search='beam'): 95 | """predict. 96 | 97 | :param sequence: ndarray, the raw sequence saved in a format of numpy array. 98 | """ 99 | with torch.no_grad(): 100 | if isinstance(sequence, np.ndarray): 101 | sequence = (torch.FloatTensor(sequence).cuda()) 102 | sequence_aud = (torch.FloatTensor(sequence_aud).cuda()) 103 | 104 | if hasattr(self.model, "module"): 105 | enc_feats = self.model.module.encode(sequence, sequence_aud) 106 | else: 107 | enc_feats = self.model.encode(sequence, sequence_aud) 108 | 109 | if search=='beam': 110 | nbest_hyps = self.beam_search( 111 | x=enc_feats, 112 | maxlenratio=self.maxlenratio, 113 | minlenratio=self.minlenratio 114 | ) 115 | nbest_hyps = [ 116 | h.asdict() for h in nbest_hyps[:min(len(nbest_hyps), self.nbest)] 117 | ] 118 | 119 | transcription = add_results_to_json(nbest_hyps, self.char_list) 120 | 121 | return transcription.replace("", "") -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Decoder self-attention layer definition.""" 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 13 | 14 | 15 | class DecoderLayer(nn.Module): 16 | """Single decoder layer module. 
17 | :param int size: input dim 18 | :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention 19 | self_attn: self attention module 20 | :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention 21 | src_attn: source attention module 22 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. 23 | PositionwiseFeedForward feed_forward: feed forward layer module 24 | :param float dropout_rate: dropout rate 25 | :param bool normalize_before: whether to use layer_norm before the first block 26 | :param bool concat_after: whether to concat attention layer's input and output 27 | if True, additional linear will be applied. 28 | i.e. x -> x + linear(concat(x, att(x))) 29 | if False, no additional linear will be applied. i.e. x -> x + att(x) 30 | """ 31 | 32 | def __init__( 33 | self, 34 | size, 35 | self_attn, 36 | src_attn, 37 | feed_forward, 38 | dropout_rate, 39 | normalize_before=True, 40 | concat_after=False, 41 | ): 42 | """Construct an DecoderLayer object.""" 43 | super(DecoderLayer, self).__init__() 44 | self.size = size 45 | self.self_attn = self_attn 46 | self.src_attn = src_attn 47 | self.feed_forward = feed_forward 48 | self.norm1 = LayerNorm(size) 49 | self.norm2 = LayerNorm(size) 50 | self.norm3 = LayerNorm(size) 51 | self.dropout = nn.Dropout(dropout_rate) 52 | self.normalize_before = normalize_before 53 | self.concat_after = concat_after 54 | if self.concat_after: 55 | self.concat_linear1 = nn.Linear(size + size, size) 56 | self.concat_linear2 = nn.Linear(size + size, size) 57 | 58 | def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): 59 | """Compute decoded features. 60 | Args: 61 | tgt (torch.Tensor): 62 | decoded previous target features (batch, max_time_out, size) 63 | tgt_mask (torch.Tensor): mask for x (batch, max_time_out) 64 | memory (torch.Tensor): encoded source features (batch, max_time_in, size) 65 | memory_mask (torch.Tensor): mask for memory (batch, max_time_in) 66 | cache (torch.Tensor): cached output (batch, max_time_out-1, size) 67 | """ 68 | residual = tgt 69 | if self.normalize_before: 70 | tgt = self.norm1(tgt) 71 | 72 | if cache is None: 73 | tgt_q = tgt 74 | tgt_q_mask = tgt_mask 75 | else: 76 | # compute only the last frame query keeping dim: max_time_out -> 1 77 | assert cache.shape == ( 78 | tgt.shape[0], 79 | tgt.shape[1] - 1, 80 | self.size, 81 | ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 82 | tgt_q = tgt[:, -1:, :] 83 | residual = residual[:, -1:, :] 84 | tgt_q_mask = None 85 | if tgt_mask is not None: 86 | tgt_q_mask = tgt_mask[:, -1:, :] 87 | 88 | if self.concat_after: 89 | tgt_concat = torch.cat( 90 | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 91 | ) 92 | x = residual + self.concat_linear1(tgt_concat) 93 | else: 94 | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) 95 | if not self.normalize_before: 96 | x = self.norm1(x) 97 | 98 | residual = x 99 | if self.normalize_before: 100 | x = self.norm2(x) 101 | if self.concat_after: 102 | x_concat = torch.cat( 103 | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 104 | ) 105 | x = residual + self.concat_linear2(x_concat) 106 | else: 107 | x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) 108 | if not self.normalize_before: 109 | x = self.norm2(x) 110 | 111 | residual = x 112 | if self.normalize_before: 113 | x = self.norm3(x) 114 | x = residual + self.dropout(self.feed_forward(x)) 115 | if not self.normalize_before: 116 | x = 
self.norm3(x) 117 | 118 | if cache is not None: 119 | x = torch.cat([cache, x], dim=1) 120 | 121 | return x, tgt_mask, memory, memory_mask 122 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import logging 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy 11 | 12 | from espnet.asr import asr_utils 13 | 14 | 15 | def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None): 16 | # dynamically import matplotlib due to not found error 17 | from matplotlib.ticker import MaxNLocator 18 | import os 19 | 20 | d = os.path.dirname(filename) 21 | if not os.path.exists(d): 22 | os.makedirs(d) 23 | w, h = plt.figaspect(1.0 / len(att_w)) 24 | fig = plt.Figure(figsize=(w * 2, h * 2)) 25 | axes = fig.subplots(1, len(att_w)) 26 | if len(att_w) == 1: 27 | axes = [axes] 28 | for ax, aw in zip(axes, att_w): 29 | # plt.subplot(1, len(att_w), h) 30 | ax.imshow(aw.astype(numpy.float32), aspect="auto") 31 | ax.set_xlabel("Input") 32 | ax.set_ylabel("Output") 33 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 34 | ax.yaxis.set_major_locator(MaxNLocator(integer=True)) 35 | # Labels for major ticks 36 | if xtokens is not None: 37 | ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens))) 38 | ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True) 39 | ax.set_xticklabels(xtokens + [""], rotation=40) 40 | if ytokens is not None: 41 | ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens))) 42 | ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True) 43 | ax.set_yticklabels(ytokens + [""]) 44 | fig.tight_layout() 45 | return fig 46 | 47 | 48 | def savefig(plot, filename): 49 | plot.savefig(filename) 50 | plt.clf() 51 | 52 | 53 | def plot_multi_head_attention( 54 | data, 55 | attn_dict, 56 | outdir, 57 | suffix="png", 58 | savefn=savefig, 59 | ikey="input", 60 | iaxis=0, 61 | okey="output", 62 | oaxis=0, 63 | ): 64 | """Plot multi head attentions. 65 | 66 | :param dict data: utts info from json file 67 | :param dict[str, torch.Tensor] attn_dict: multi head attention dict. 
68 | values should be torch.Tensor (head, input_length, output_length) 69 | :param str outdir: dir to save fig 70 | :param str suffix: filename suffix including image type (e.g., png) 71 | :param savefn: function to save 72 | 73 | """ 74 | for name, att_ws in attn_dict.items(): 75 | for idx, att_w in enumerate(att_ws): 76 | filename = "%s/%s.%s.%s" % (outdir, data[idx][0], name, suffix) 77 | dec_len = int(data[idx][1][okey][oaxis]["shape"][0]) 78 | enc_len = int(data[idx][1][ikey][iaxis]["shape"][0]) 79 | xtokens, ytokens = None, None 80 | if "encoder" in name: 81 | att_w = att_w[:, :enc_len, :enc_len] 82 | # for MT 83 | if "token" in data[idx][1][ikey][iaxis].keys(): 84 | xtokens = data[idx][1][ikey][iaxis]["token"].split() 85 | ytokens = xtokens[:] 86 | elif "decoder" in name: 87 | if "self" in name: 88 | att_w = att_w[:, : dec_len + 1, : dec_len + 1] # +1 for 89 | else: 90 | att_w = att_w[:, : dec_len + 1, :enc_len] # +1 for 91 | # for MT 92 | if "token" in data[idx][1][ikey][iaxis].keys(): 93 | xtokens = data[idx][1][ikey][iaxis]["token"].split() 94 | # for ASR/ST/MT 95 | if "token" in data[idx][1][okey][oaxis].keys(): 96 | ytokens = [""] + data[idx][1][okey][oaxis]["token"].split() 97 | if "self" in name: 98 | xtokens = ytokens[:] 99 | else: 100 | logging.warning("unknown name for shaping attention") 101 | fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens) 102 | savefn(fig, filename) 103 | 104 | 105 | class PlotAttentionReport(asr_utils.PlotAttentionReport): 106 | def plotfn(self, *args, **kwargs): 107 | kwargs["ikey"] = self.ikey 108 | kwargs["iaxis"] = self.iaxis 109 | kwargs["okey"] = self.okey 110 | kwargs["oaxis"] = self.oaxis 111 | plot_multi_head_attention(*args, **kwargs) 112 | 113 | def __call__(self, trainer): 114 | attn_dict = self.get_attention_weights() 115 | suffix = "ep.{.updater.epoch}.png".format(trainer) 116 | self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig) 117 | 118 | def get_attention_weights(self): 119 | batch = self.converter([self.transform(self.data)], self.device) 120 | if isinstance(batch, tuple): 121 | att_ws = self.att_vis_fn(*batch) 122 | elif isinstance(batch, dict): 123 | att_ws = self.att_vis_fn(**batch) 124 | return att_ws 125 | 126 | def log_attentions(self, logger, step): 127 | def log_fig(plot, filename): 128 | from os.path import basename 129 | 130 | logger.add_figure(basename(filename), plot, step) 131 | plt.clf() 132 | 133 | attn_dict = self.get_attention_weights() 134 | self.plotfn(self.data, attn_dict, self.outdir, "", log_fig) 135 | -------------------------------------------------------------------------------- /espnet/nets/scorers/ctc.py: -------------------------------------------------------------------------------- 1 | """ScorerInterface implementation for CTC.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from espnet.nets.ctc_prefix_score import CTCPrefixScore 7 | from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH 8 | from espnet.nets.scorer_interface import BatchPartialScorerInterface 9 | 10 | 11 | class CTCPrefixScorer(BatchPartialScorerInterface): 12 | """Decoder interface wrapper for CTCPrefixScore.""" 13 | 14 | def __init__(self, ctc: torch.nn.Module, eos: int): 15 | """Initialize class. 16 | 17 | Args: 18 | ctc (torch.nn.Module): The CTC implementation. 19 | For example, :class:`espnet.nets.pytorch_backend.ctc.CTC` 20 | eos (int): The end-of-sequence id. 
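        Example (illustrative sketch only, not from the original source; assumes a
            trained CTC module ``ctc``, a vocabulary ``char_list`` whose last id is
            the end-of-sequence token, and an encoder output ``enc_out`` of shape
            ``(T, D)``)::

                scorer = CTCPrefixScorer(ctc=ctc, eos=len(char_list) - 1)
                state = scorer.init_state(enc_out)
                # y: 1D tensor of already decoded token ids,
                # ids: pre-pruned candidate ids from the main beam search
                scores, state = scorer.score_partial(y, ids, state, enc_out)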
21 | 22 | """ 23 | self.ctc = ctc 24 | self.eos = eos 25 | self.impl = None 26 | 27 | def init_state(self, x: torch.Tensor): 28 | """Get an initial state for decoding. 29 | 30 | Args: 31 | x (torch.Tensor): The encoded feature tensor 32 | 33 | Returns: initial state 34 | 35 | """ 36 | logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy() 37 | # TODO(karita): use CTCPrefixScoreTH 38 | self.impl = CTCPrefixScore(logp, 0, self.eos, np) 39 | return 0, self.impl.initial_state() 40 | 41 | def select_state(self, state, i, new_id=None): 42 | """Select state with relative ids in the main beam search. 43 | 44 | Args: 45 | state: Decoder state for prefix tokens 46 | i (int): Index to select a state in the main beam search 47 | new_id (int): New label id to select a state if necessary 48 | 49 | Returns: 50 | state: pruned state 51 | 52 | """ 53 | if type(state) == tuple: 54 | if len(state) == 2: # for CTCPrefixScore 55 | sc, st = state 56 | return sc[i], st[i] 57 | else: # for CTCPrefixScoreTH (need new_id > 0) 58 | r, log_psi, f_min, f_max, scoring_idmap = state 59 | s = log_psi[i, new_id].expand(log_psi.size(1)) 60 | if scoring_idmap is not None: 61 | return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max 62 | else: 63 | return r[:, :, i, new_id], s, f_min, f_max 64 | return None if state is None else state[i] 65 | 66 | def score_partial(self, y, ids, state, x): 67 | """Score new token. 68 | 69 | Args: 70 | y (torch.Tensor): 1D prefix token 71 | next_tokens (torch.Tensor): torch.int64 next token to score 72 | state: decoder state for prefix tokens 73 | x (torch.Tensor): 2D encoder feature that generates ys 74 | 75 | Returns: 76 | tuple[torch.Tensor, Any]: 77 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 78 | and next state for ys 79 | 80 | """ 81 | prev_score, state = state 82 | presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state) 83 | tscore = torch.as_tensor( 84 | presub_score - prev_score, device=x.device, dtype=x.dtype 85 | ) 86 | return tscore, (presub_score, new_st) 87 | 88 | def batch_init_state(self, x: torch.Tensor): 89 | """Get an initial state for decoding. 90 | 91 | Args: 92 | x (torch.Tensor): The encoded feature tensor 93 | 94 | Returns: initial state 95 | 96 | """ 97 | logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 98 | xlen = torch.tensor([logp.size(1)]) 99 | self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos) 100 | return None 101 | 102 | def batch_score_partial(self, y, ids, state, x): 103 | """Score new token. 104 | 105 | Args: 106 | y (torch.Tensor): 1D prefix token 107 | ids (torch.Tensor): torch.int64 next token to score 108 | state: decoder state for prefix tokens 109 | x (torch.Tensor): 2D encoder feature that generates ys 110 | 111 | Returns: 112 | tuple[torch.Tensor, Any]: 113 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 114 | and next state for ys 115 | 116 | """ 117 | batch_state = ( 118 | ( 119 | torch.stack([s[0] for s in state], dim=2), 120 | torch.stack([s[1] for s in state]), 121 | state[0][2], 122 | state[0][3], 123 | ) 124 | if state[0] is not None 125 | else None 126 | ) 127 | return self.impl(y, batch_state, ids) 128 | 129 | def extend_prob(self, x: torch.Tensor): 130 | """Extend probs for decoding. 
131 | 132 | This extension is for streaming decoding 133 | as in Eq (14) in https://arxiv.org/abs/2006.14941 134 | 135 | Args: 136 | x (torch.Tensor): The encoded feature tensor 137 | 138 | """ 139 | logp = self.ctc.log_softmax(x.unsqueeze(0)) 140 | self.impl.extend_prob(logp) 141 | 142 | def extend_state(self, state): 143 | """Extend state for decoding. 144 | 145 | This extension is for streaming decoding 146 | as in Eq (14) in https://arxiv.org/abs/2006.14941 147 | 148 | Args: 149 | state: The states of hyps 150 | 151 | Returns: exteded state 152 | 153 | """ 154 | new_state = [] 155 | for s in state: 156 | new_state.append(self.impl.extend_state(s)) 157 | 158 | return new_state 159 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E 8 | import os 9 | from src.data.dataset import AVDataset 10 | from src.models.Lip_reader import Lipreading 11 | import glob 12 | import editdistance 13 | import json 14 | from tqdm import tqdm 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_path', default="preprocessed_data_path") 19 | parser.add_argument('--split_file', default="./src/data/LRS2/test.ref") 20 | parser.add_argument('--data_type', default="LRS2") 21 | parser.add_argument('--model_conf', default="./src/models/model.json") 22 | 23 | parser.add_argument('--results_path', default='./test_results.txt') 24 | 25 | parser.add_argument("--checkpoint", type=str, default=None) 26 | 27 | parser.add_argument("--rnnlm", type=str, default='./checkpoints/LM/model.pth') 28 | parser.add_argument("--rnnlm_conf", type=str, default='./checkpoints/LM/model.json') 29 | 30 | parser.add_argument("--beam_size", type=int, default=40) 31 | parser.add_argument("--penalty", type=float, default=0.5) 32 | parser.add_argument("--maxlenratio", type=float, default=0) 33 | parser.add_argument("--minlenratio", type=float, default=0) 34 | parser.add_argument("--ctc_weight", type=float, default=0.1) 35 | parser.add_argument("--lm_weight", type=float, default=0.5) 36 | 37 | parser.add_argument("--architecture", default='AVRelScore', help='AVRelScore, VCAFE, Conformer') 38 | parser.add_argument("--gpu", type=str, default='0') 39 | parser.add_argument("--local_rank", type=int, default=0) 40 | args = parser.parse_args() 41 | return args 42 | 43 | def main(args): 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | torch.manual_seed(args.local_rank) 47 | torch.cuda.manual_seed_all(args.local_rank) 48 | random.seed(args.local_rank) 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 50 | 51 | with open(args.model_conf, "rb") as f: 52 | confs = json.load(f) 53 | if isinstance(confs, dict): 54 | model_args = confs 55 | else: 56 | _, odim, model_args = confs 57 | model_args = argparse.Namespace(**model_args) 58 | 59 | model = E2E(odim, model_args, architecture=args.architecture) 60 | 61 | if args.checkpoint is not None: 62 | if args.local_rank == 0: 63 | print(f"Loading checkpoint: {args.checkpoint}") 64 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda()) 65 | model.load_state_dict(checkpoint['state_dict']) 66 | del checkpoint 67 | 68 | model.cuda() 69 | 
model.eval() 70 | 71 | test_data = AVDataset( 72 | data_path=args.data_path, 73 | split_file=args.split_file, 74 | mode='test', 75 | data_type=args.data_type, 76 | ) 77 | 78 | Lip_reader = Lipreading(args, odim, model, test_data.char_list) 79 | test(Lip_reader, test_data) 80 | 81 | def test(Lip_reader, test_data): 82 | wer_list = AverageMeter() 83 | cer_list = AverageMeter() 84 | with torch.no_grad(): 85 | with open(args.split_file, 'r') as txt: 86 | lines = txt.readlines() 87 | for idx, line in enumerate(lines): 88 | basename, groundtruth = line.split()[0], " ".join(line.split()[1:]) 89 | data_filename = os.path.join(args.data_path, 'Video', basename + '.mp4') 90 | data_aud_filename = os.path.join(args.data_path, 'Audio', basename + '.wav') 91 | 92 | vid, aud = test_data.load_data(data_filename, data_aud_filename) 93 | output = Lip_reader.predict(vid, aud) 94 | if isinstance(output, str): 95 | print(f"hyp: {output}") 96 | if groundtruth is not None: 97 | print(f"ref: {groundtruth}") 98 | wer_list.update(*get_wer(output, groundtruth)) 99 | cer_list.update(*get_cer(output, groundtruth)) 100 | print( 101 | f"progress: {idx + 1}/{len(lines)}\tcur WER: {wer_list.val * 100:.1f}\t" 102 | f"cur CER: {cer_list.val * 100:.1f}\t" 103 | f"avg WER: {wer_list.avg * 100:.1f}\tavg CER: {cer_list.avg * 100:.1f}" 104 | ) 105 | 106 | if args.results_path is not None: 107 | with open(args.results_path, 'w') as txt: 108 | txt.write(f'WER: {wer_list.avg * 100:.3f}, CER: {cer_list.avg * 100:.3f}') 109 | 110 | def get_wer(predict, truth): 111 | predict = predict.split(' ') 112 | truth = truth.split(' ') 113 | err = editdistance.eval(predict, truth) 114 | num = len(truth) 115 | return err, num 116 | 117 | def get_cer(predict, truth): 118 | predict = predict.replace(' ', '') 119 | truth = truth.replace(' ', '') 120 | err = editdistance.eval(predict, truth) 121 | num = len(truth) 122 | return err, num 123 | 124 | class AverageMeter(object): 125 | """Computes and stores the average and current value""" 126 | 127 | def __init__(self): 128 | self.reset() 129 | 130 | def reset(self): 131 | self.val = 0 132 | self.avg = 0 133 | self.sum = 0 134 | self.count = 0 135 | 136 | def update(self, err, num=1): 137 | self.val = err / num 138 | self.sum += err 139 | self.count += num 140 | self.avg = self.sum / self.count 141 | 142 | if __name__ == "__main__": 143 | args = parse_args() 144 | main(args) 145 | 146 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """conv3x3. 10 | 11 | :param in_planes: int, number of channels in the input sequence. 12 | :param out_planes: int, number of channels produced by the convolution. 13 | :param stride: int, size of the convolving kernel. 14 | """ 15 | return nn.Conv2d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False, 22 | ) 23 | 24 | 25 | def downsample_basic_block(inplanes, outplanes, stride): 26 | """downsample_basic_block. 27 | 28 | :param inplanes: int, number of channels in the input sequence. 29 | :param outplanes: int, number of channels produced by the convolution. 30 | :param stride: int, size of the convolving kernel. 
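    Returns an ``nn.Sequential`` of a 1x1 ``nn.Conv2d`` followed by ``nn.BatchNorm2d``
    that projects the residual branch to the expected shape. For example (illustrative
    call only), ``downsample_basic_block(64, 128, stride=2)`` doubles the channel count
    and halves the spatial resolution.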
31 | """ 32 | return nn.Sequential( 33 | nn.Conv2d( 34 | inplanes, 35 | outplanes, 36 | kernel_size=1, 37 | stride=stride, 38 | bias=False, 39 | ), 40 | nn.BatchNorm2d(outplanes), 41 | ) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__( 48 | self, 49 | inplanes, 50 | planes, 51 | stride=1, 52 | downsample=None, 53 | relu_type="swish", 54 | ): 55 | """__init__. 56 | 57 | :param inplanes: int, number of channels in the input sequence. 58 | :param planes: int, number of channels produced by the convolution. 59 | :param stride: int, size of the convolving kernel. 60 | :param downsample: boolean, if True, the temporal resolution is downsampled. 61 | :param relu_type: str, type of activation function. 62 | """ 63 | super(BasicBlock, self).__init__() 64 | 65 | assert relu_type in ["relu", "prelu", "swish"] 66 | 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | 70 | if relu_type == "relu": 71 | self.relu1 = nn.ReLU(inplace=True) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | elif relu_type == "prelu": 74 | self.relu1 = nn.PReLU(num_parameters=planes) 75 | self.relu2 = nn.PReLU(num_parameters=planes) 76 | elif relu_type == "swish": 77 | self.relu1 = Swish() 78 | self.relu2 = Swish() 79 | else: 80 | raise NotImplementedError 81 | # -------- 82 | 83 | self.conv2 = conv3x3(planes, planes) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | """forward. 91 | 92 | :param x: torch.Tensor, input tensor with input size (B, C, T, H, W). 93 | """ 94 | residual = x 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu1(out) 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu2(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__( 112 | self, 113 | block, 114 | layers, 115 | relu_type="swish", 116 | ): 117 | super(ResNet, self).__init__() 118 | self.inplanes = 64 119 | self.relu_type = relu_type 120 | self.downsample_block = downsample_basic_block 121 | 122 | self.layer1 = self._make_layer(block, 64, layers[0]) 123 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 124 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 125 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 126 | self.avgpool = nn.AdaptiveAvgPool2d(1) 127 | 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | """_make_layer. 131 | 132 | :param block: torch.nn.Module, class of blocks. 133 | :param planes: int, number of channels produced by the convolution. 134 | :param blocks: int, number of layers in a block. 135 | :param stride: int, size of the convolving kernel. 
136 | """ 137 | downsample = None 138 | if stride != 1 or self.inplanes != planes * block.expansion: 139 | downsample = self.downsample_block( 140 | inplanes=self.inplanes, 141 | outplanes=planes*block.expansion, 142 | stride=stride, 143 | ) 144 | 145 | layers = [] 146 | layers.append( 147 | block( 148 | self.inplanes, 149 | planes, 150 | stride, 151 | downsample, 152 | relu_type=self.relu_type, 153 | ) 154 | ) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append( 158 | block( 159 | self.inplanes, 160 | planes, 161 | relu_type=self.relu_type, 162 | ) 163 | ) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, x): 168 | """forward. 169 | 170 | :param x: torch.Tensor, input tensor with input size (B, C, T, H, W). 171 | """ 172 | x = self.layer1(x) 173 | x = self.layer2(x) 174 | x = self.layer3(x) 175 | x = self.layer4(x) 176 | x = self.avgpool(x) 177 | x = x.view(x.size(0), -1) 178 | return x 179 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Encoder self-attention layer definition.""" 8 | 9 | import copy 10 | import torch 11 | 12 | from torch import nn 13 | 14 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 15 | 16 | 17 | class EncoderLayer(nn.Module): 18 | """Encoder layer module. 19 | 20 | :param int size: input dim 21 | :param espnet.nets.pytorch_backend.transformer.attention. 22 | MultiHeadedAttention self_attn: self attention module 23 | RelPositionMultiHeadedAttention self_attn: self attention module 24 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. 25 | PositionwiseFeedForward feed_forward: 26 | feed forward module 27 | :param espnet.nets.pytorch_backend.transformer.convolution. 28 | ConvolutionModule feed_foreard: 29 | feed forward module 30 | :param float dropout_rate: dropout rate 31 | :param bool normalize_before: whether to use layer_norm before the first block 32 | :param bool concat_after: whether to concat attention layer's input and output 33 | if True, additional linear will be applied. 34 | i.e. x -> x + linear(concat(x, att(x))) 35 | if False, no additional linear will be applied. i.e. 
x -> x + att(x) 36 | :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | size, 43 | self_attn, 44 | feed_forward, 45 | conv_module, 46 | dropout_rate, 47 | normalize_before=True, 48 | concat_after=False, 49 | macaron_style=False, 50 | ): 51 | """Construct an EncoderLayer object.""" 52 | super(EncoderLayer, self).__init__() 53 | self.self_attn = self_attn 54 | self.feed_forward = feed_forward 55 | self.ff_scale = 1.0 56 | self.conv_module = conv_module 57 | self.macaron_style = macaron_style 58 | self.norm_ff = LayerNorm(size) # for the FNN module 59 | self.norm_mha = LayerNorm(size) # for the MHA module 60 | if self.macaron_style: 61 | self.feed_forward_macaron = copy.deepcopy(feed_forward) 62 | self.ff_scale = 0.5 63 | # for another FNN module in macaron style 64 | self.norm_ff_macaron = LayerNorm(size) 65 | if self.conv_module is not None: 66 | self.norm_conv = LayerNorm(size) # for the CNN module 67 | self.norm_final = LayerNorm(size) # for the final output of the block 68 | self.dropout = nn.Dropout(dropout_rate) 69 | self.size = size 70 | self.normalize_before = normalize_before 71 | self.concat_after = concat_after 72 | if self.concat_after: 73 | self.concat_linear = nn.Linear(size + size, size) 74 | 75 | def forward(self, x_input, mask, cache=None): 76 | """Compute encoded features. 77 | 78 | :param torch.Tensor x_input: encoded source features (batch, max_time_in, size) 79 | :param torch.Tensor mask: mask for x (batch, max_time_in) 80 | :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size) 81 | :rtype: Tuple[torch.Tensor, torch.Tensor] 82 | """ 83 | if isinstance(x_input, tuple): 84 | x, pos_emb = x_input[0], x_input[1] 85 | else: 86 | x, pos_emb = x_input, None 87 | 88 | # whether to use macaron style 89 | if self.macaron_style: 90 | residual = x 91 | if self.normalize_before: 92 | x = self.norm_ff_macaron(x) 93 | x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) 94 | if not self.normalize_before: 95 | x = self.norm_ff_macaron(x) 96 | 97 | # multi-headed self-attention module 98 | residual = x 99 | if self.normalize_before: 100 | x = self.norm_mha(x) 101 | 102 | if cache is None: 103 | x_q = x 104 | else: 105 | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) 106 | x_q = x[:, -1:, :] 107 | residual = residual[:, -1:, :] 108 | mask = None if mask is None else mask[:, -1:, :] 109 | 110 | if pos_emb is not None: 111 | x_att = self.self_attn(x_q, x, x, pos_emb, mask) 112 | else: 113 | x_att = self.self_attn(x_q, x, x, mask) 114 | 115 | if self.concat_after: 116 | x_concat = torch.cat((x, x_att), dim=-1) 117 | x = residual + self.concat_linear(x_concat) 118 | else: 119 | x = residual + self.dropout(x_att) 120 | if not self.normalize_before: 121 | x = self.norm_mha(x) 122 | 123 | # convolution module 124 | if self.conv_module is not None: 125 | residual = x 126 | if self.normalize_before: 127 | x = self.norm_conv(x) 128 | x = residual + self.dropout(self.conv_module(x)) 129 | if not self.normalize_before: 130 | x = self.norm_conv(x) 131 | 132 | # feed forward module 133 | residual = x 134 | if self.normalize_before: 135 | x = self.norm_ff(x) 136 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 137 | if not self.normalize_before: 138 | x = self.norm_ff(x) 139 | 140 | if self.conv_module is not None: 141 | x = self.norm_final(x) 142 | 143 | if cache is not None: 144 | x = torch.cat([cache, x], dim=1) 145 | 146 | if pos_emb is not 
None: 147 | return (x, pos_emb), mask 148 | else: 149 | return x, mask 150 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | import math 8 | 9 | import pdb 10 | 11 | def conv_bn(inp, oup, stride): 12 | return nn.Sequential( 13 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 14 | nn.BatchNorm2d(oup), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | def conv_1x1_bn(inp, oup): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 22 | nn.BatchNorm2d(oup), 23 | nn.ReLU(inplace=True) 24 | ) 25 | 26 | def channel_shuffle(x, groups): 27 | batchsize, num_channels, height, width = x.data.size() 28 | 29 | channels_per_group = num_channels // groups 30 | 31 | # reshape 32 | x = x.view(batchsize, groups, 33 | channels_per_group, height, width) 34 | 35 | x = torch.transpose(x, 1, 2).contiguous() 36 | 37 | # flatten 38 | x = x.view(batchsize, -1, height, width) 39 | 40 | return x 41 | 42 | class InvertedResidual(nn.Module): 43 | def __init__(self, inp, oup, stride, benchmodel): 44 | super(InvertedResidual, self).__init__() 45 | self.benchmodel = benchmodel 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | oup_inc = oup//2 50 | 51 | if self.benchmodel == 1: 52 | #assert inp == oup_inc 53 | self.banch2 = nn.Sequential( 54 | # pw 55 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 56 | nn.BatchNorm2d(oup_inc), 57 | nn.ReLU(inplace=True), 58 | # dw 59 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 60 | nn.BatchNorm2d(oup_inc), 61 | # pw-linear 62 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 63 | nn.BatchNorm2d(oup_inc), 64 | nn.ReLU(inplace=True), 65 | ) 66 | else: 67 | self.banch1 = nn.Sequential( 68 | # dw 69 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 70 | nn.BatchNorm2d(inp), 71 | # pw-linear 72 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 73 | nn.BatchNorm2d(oup_inc), 74 | nn.ReLU(inplace=True), 75 | ) 76 | 77 | self.banch2 = nn.Sequential( 78 | # pw 79 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 80 | nn.BatchNorm2d(oup_inc), 81 | nn.ReLU(inplace=True), 82 | # dw 83 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 84 | nn.BatchNorm2d(oup_inc), 85 | # pw-linear 86 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 87 | nn.BatchNorm2d(oup_inc), 88 | nn.ReLU(inplace=True), 89 | ) 90 | 91 | @staticmethod 92 | def _concat(x, out): 93 | # concatenate along channel axis 94 | return torch.cat((x, out), 1) 95 | 96 | def forward(self, x): 97 | if 1==self.benchmodel: 98 | x1 = x[:, :(x.shape[1]//2), :, :] 99 | x2 = x[:, (x.shape[1]//2):, :, :] 100 | out = self._concat(x1, self.banch2(x2)) 101 | elif 2==self.benchmodel: 102 | out = self._concat(self.banch1(x), self.banch2(x)) 103 | 104 | return channel_shuffle(out, 2) 105 | 106 | 107 | class ShuffleNetV2(nn.Module): 108 | def __init__(self, n_class=1000, input_size=224, width_mult=2.): 109 | super(ShuffleNetV2, self).__init__() 110 | 111 | assert input_size % 32 == 0, "Input size needs to be divisible by 32" 112 | 113 | self.stage_repeats = [4, 8, 4] 114 | # index 0 is invalid and should never be called. 115 | # only used for indexing convenience. 
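        # stage_out_channels maps the chosen width multiplier to per-stage widths:
        # entry 1 is the stem conv, entries 2-4 are the output widths of the three
        # inverted-residual stages (repeated 4/8/4 times), and the last entry is the
        # final 1x1 conv.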
116 | if width_mult == 0.5: 117 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 118 | elif width_mult == 1.0: 119 | self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] 120 | elif width_mult == 1.5: 121 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 122 | elif width_mult == 2.0: 123 | self.stage_out_channels = [-1, 24, 244, 488, 976, 2048] 124 | else: 125 | raise ValueError( 126 | """Width multiplier should be in [0.5, 1.0, 1.5, 2.0]. Current value: {}""".format(width_mult)) 127 | 128 | # building first layer 129 | input_channel = self.stage_out_channels[1] 130 | self.conv1 = conv_bn(3, input_channel, 2) 131 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 132 | 133 | self.features = [] 134 | # building inverted residual blocks 135 | for idxstage in range(len(self.stage_repeats)): 136 | numrepeat = self.stage_repeats[idxstage] 137 | output_channel = self.stage_out_channels[idxstage+2] 138 | for i in range(numrepeat): 139 | if i == 0: 140 | #inp, oup, stride, benchmodel): 141 | self.features.append(InvertedResidual(input_channel, output_channel, 2, 2)) 142 | else: 143 | self.features.append(InvertedResidual(input_channel, output_channel, 1, 1)) 144 | input_channel = output_channel 145 | 146 | 147 | # make it nn.Sequential 148 | self.features = nn.Sequential(*self.features) 149 | 150 | # building last several layers 151 | self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1]) 152 | self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32))) 153 | 154 | # building classifier 155 | self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class)) 156 | 157 | def forward(self, x): 158 | x = self.conv1(x) 159 | x = self.maxpool(x) 160 | x = self.features(x) 161 | x = self.conv_last(x) 162 | x = self.globalpool(x) 163 | x = x.view(-1, self.stage_out_channels[-1]) 164 | x = self.classifier(x) 165 | return x 166 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/lm/seq_rnn.py: -------------------------------------------------------------------------------- 1 | """Sequential implementation of Recurrent Neural Network Language Model.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from espnet.nets.lm_interface import LMInterface 8 | 9 | 10 | class SequentialRNNLM(LMInterface, torch.nn.Module): 11 | """Sequential RNNLM. 12 | 13 | See also: 14 | https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py 15 | 16 | """ 17 | 18 | @staticmethod 19 | def add_arguments(parser): 20 | """Add arguments to command line argument parser.""" 21 | parser.add_argument( 22 | "--type", 23 | type=str, 24 | default="lstm", 25 | nargs="?", 26 | choices=["lstm", "gru"], 27 | help="Which type of RNN to use", 28 | ) 29 | parser.add_argument( 30 | "--layer", "-l", type=int, default=2, help="Number of hidden layers" 31 | ) 32 | parser.add_argument( 33 | "--unit", "-u", type=int, default=650, help="Number of hidden units" 34 | ) 35 | parser.add_argument( 36 | "--dropout-rate", type=float, default=0.5, help="dropout probability" 37 | ) 38 | return parser 39 | 40 | def __init__(self, n_vocab, args): 41 | """Initialize class. 42 | 43 | Args: 44 | n_vocab (int): The size of the vocabulary 45 | args (argparse.Namespace): configurations. 
see py:method:`add_arguments` 46 | 47 | """ 48 | torch.nn.Module.__init__(self) 49 | self._setup( 50 | rnn_type=args.type.upper(), 51 | ntoken=n_vocab, 52 | ninp=args.unit, 53 | nhid=args.unit, 54 | nlayers=args.layer, 55 | dropout=args.dropout_rate, 56 | ) 57 | 58 | def _setup( 59 | self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False 60 | ): 61 | self.drop = nn.Dropout(dropout) 62 | self.encoder = nn.Embedding(ntoken, ninp) 63 | if rnn_type in ["LSTM", "GRU"]: 64 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 65 | else: 66 | try: 67 | nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] 68 | except KeyError: 69 | raise ValueError( 70 | "An invalid option for `--model` was supplied, " 71 | "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']" 72 | ) 73 | self.rnn = nn.RNN( 74 | ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout 75 | ) 76 | self.decoder = nn.Linear(nhid, ntoken) 77 | 78 | # Optionally tie weights as in: 79 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 80 | # https://arxiv.org/abs/1608.05859 81 | # and 82 | # "Tying Word Vectors and Word Classifiers: 83 | # A Loss Framework for Language Modeling" (Inan et al. 2016) 84 | # https://arxiv.org/abs/1611.01462 85 | if tie_weights: 86 | if nhid != ninp: 87 | raise ValueError( 88 | "When using the tied flag, nhid must be equal to emsize" 89 | ) 90 | self.decoder.weight = self.encoder.weight 91 | 92 | self._init_weights() 93 | 94 | self.rnn_type = rnn_type 95 | self.nhid = nhid 96 | self.nlayers = nlayers 97 | 98 | def _init_weights(self): 99 | # NOTE: original init in pytorch/examples 100 | # initrange = 0.1 101 | # self.encoder.weight.data.uniform_(-initrange, initrange) 102 | # self.decoder.bias.data.zero_() 103 | # self.decoder.weight.data.uniform_(-initrange, initrange) 104 | # NOTE: our default.py:RNNLM init 105 | for param in self.parameters(): 106 | param.data.uniform_(-0.1, 0.1) 107 | 108 | def forward(self, x, t): 109 | """Compute LM loss value from buffer sequences. 110 | 111 | Args: 112 | x (torch.Tensor): Input ids. (batch, len) 113 | t (torch.Tensor): Target ids. (batch, len) 114 | 115 | Returns: 116 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 117 | loss to backward (scalar), 118 | negative log-likelihood of t: -log p(t) (scalar) and 119 | the number of elements in x (scalar) 120 | 121 | Notes: 122 | The last two return values are used 123 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 124 | 125 | """ 126 | y = self._before_loss(x, None)[0] 127 | mask = (x != 0).to(y.dtype) 128 | loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") 129 | logp = loss * mask.view(-1) 130 | logp = logp.sum() 131 | count = mask.sum() 132 | return logp / count, logp, count 133 | 134 | def _before_loss(self, input, hidden): 135 | emb = self.drop(self.encoder(input)) 136 | output, hidden = self.rnn(emb, hidden) 137 | output = self.drop(output) 138 | decoded = self.decoder( 139 | output.view(output.size(0) * output.size(1), output.size(2)) 140 | ) 141 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 142 | 143 | def init_state(self, x): 144 | """Get an initial state for decoding. 
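        Example of driving this LM as a beam-search scorer (illustrative sketch only;
        assumes ``args`` carries the fields registered by ``add_arguments`` and
        ``char_list`` is the model vocabulary)::

            lm = SequentialRNNLM(len(char_list), args)
            state = lm.init_state(x)  # x (encoder feature) is not used by this LM
            # prefix_tokens: 1D torch.int64 tensor of already decoded token ids
            logp, state = lm.score(prefix_tokens, state, x)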
145 | 146 | Args: 147 | x (torch.Tensor): The encoded feature tensor 148 | 149 | Returns: initial state 150 | 151 | """ 152 | bsz = 1 153 | weight = next(self.parameters()) 154 | if self.rnn_type == "LSTM": 155 | return ( 156 | weight.new_zeros(self.nlayers, bsz, self.nhid), 157 | weight.new_zeros(self.nlayers, bsz, self.nhid), 158 | ) 159 | else: 160 | return weight.new_zeros(self.nlayers, bsz, self.nhid) 161 | 162 | def score(self, y, state, x): 163 | """Score new token. 164 | 165 | Args: 166 | y (torch.Tensor): 1D torch.int64 prefix tokens. 167 | state: Scorer state for prefix tokens 168 | x (torch.Tensor): 2D encoder feature that generates ys. 169 | 170 | Returns: 171 | tuple[torch.Tensor, Any]: Tuple of 172 | torch.float32 scores for next token (n_vocab) 173 | and next state for ys 174 | 175 | """ 176 | y, new_state = self._before_loss(y[-1].view(1, 1), state) 177 | logp = y.log_softmax(dim=-1).view(-1) 178 | return logp, new_state 179 | -------------------------------------------------------------------------------- /espnet/nets/scorer_interface.py: -------------------------------------------------------------------------------- 1 | """Scorer interface module.""" 2 | 3 | from typing import Any 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import torch 8 | import warnings 9 | 10 | 11 | class ScorerInterface: 12 | """Scorer interface for beam search. 13 | 14 | The scorer performs scoring of the all tokens in vocabulary. 15 | 16 | Examples: 17 | * Search heuristics 18 | * :class:`espnet.nets.scorers.length_bonus.LengthBonus` 19 | * Decoder networks of the sequence-to-sequence models 20 | * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder` 21 | * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder` 22 | * Neural language models 23 | * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM` 24 | * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM` 25 | * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM` 26 | 27 | """ 28 | 29 | def init_state(self, x: torch.Tensor) -> Any: 30 | """Get an initial state for decoding (optional). 31 | 32 | Args: 33 | x (torch.Tensor): The encoded feature tensor 34 | 35 | Returns: initial state 36 | 37 | """ 38 | return None 39 | 40 | def select_state(self, state: Any, i: int, new_id: int = None) -> Any: 41 | """Select state with relative ids in the main beam search. 42 | 43 | Args: 44 | state: Decoder state for prefix tokens 45 | i (int): Index to select a state in the main beam search 46 | new_id (int): New label index to select a state if necessary 47 | 48 | Returns: 49 | state: pruned state 50 | 51 | """ 52 | return None if state is None else state[i] 53 | 54 | def score( 55 | self, y: torch.Tensor, state: Any, x: torch.Tensor 56 | ) -> Tuple[torch.Tensor, Any]: 57 | """Score new token (required). 58 | 59 | Args: 60 | y (torch.Tensor): 1D torch.int64 prefix tokens. 61 | state: Scorer state for prefix tokens 62 | x (torch.Tensor): The encoder feature that generates ys. 63 | 64 | Returns: 65 | tuple[torch.Tensor, Any]: Tuple of 66 | scores for next token that has a shape of `(n_vocab)` 67 | and next state for ys 68 | 69 | """ 70 | raise NotImplementedError 71 | 72 | def final_score(self, state: Any) -> float: 73 | """Score eos (optional). 
74 | 75 | Args: 76 | state: Scorer state for prefix tokens 77 | 78 | Returns: 79 | float: final score 80 | 81 | """ 82 | return 0.0 83 | 84 | 85 | class BatchScorerInterface(ScorerInterface): 86 | """Batch scorer interface.""" 87 | 88 | def batch_init_state(self, x: torch.Tensor) -> Any: 89 | """Get an initial state for decoding (optional). 90 | 91 | Args: 92 | x (torch.Tensor): The encoded feature tensor 93 | 94 | Returns: initial state 95 | 96 | """ 97 | return self.init_state(x) 98 | 99 | def batch_score( 100 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 101 | ) -> Tuple[torch.Tensor, List[Any]]: 102 | """Score new token batch (required). 103 | 104 | Args: 105 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 106 | states (List[Any]): Scorer states for prefix tokens. 107 | xs (torch.Tensor): 108 | The encoder feature that generates ys (n_batch, xlen, n_feat). 109 | 110 | Returns: 111 | tuple[torch.Tensor, List[Any]]: Tuple of 112 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 113 | and next state list for ys. 114 | 115 | """ 116 | warnings.warn( 117 | "{} batch score is implemented through for loop not parallelized".format( 118 | self.__class__.__name__ 119 | ) 120 | ) 121 | scores = list() 122 | outstates = list() 123 | for i, (y, state, x) in enumerate(zip(ys, states, xs)): 124 | score, outstate = self.score(y, state, x) 125 | outstates.append(outstate) 126 | scores.append(score) 127 | scores = torch.cat(scores, 0).view(ys.shape[0], -1) 128 | return scores, outstates 129 | 130 | 131 | class PartialScorerInterface(ScorerInterface): 132 | """Partial scorer interface for beam search. 133 | 134 | The partial scorer performs scoring when non-partial scorer finished scoring, 135 | and receives pre-pruned next tokens to score because it is too heavy to score 136 | all the tokens. 137 | 138 | Examples: 139 | * Prefix search for connectionist-temporal-classification models 140 | * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer` 141 | 142 | """ 143 | 144 | def score_partial( 145 | self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor 146 | ) -> Tuple[torch.Tensor, Any]: 147 | """Score new token (required). 148 | 149 | Args: 150 | y (torch.Tensor): 1D prefix token 151 | next_tokens (torch.Tensor): torch.int64 next token to score 152 | state: decoder state for prefix tokens 153 | x (torch.Tensor): The encoder feature that generates ys 154 | 155 | Returns: 156 | tuple[torch.Tensor, Any]: 157 | Tuple of a score tensor for y that has a shape `(len(next_tokens),)` 158 | and next state for ys 159 | 160 | """ 161 | raise NotImplementedError 162 | 163 | 164 | class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface): 165 | """Batch partial scorer interface for beam search.""" 166 | 167 | def batch_score_partial( 168 | self, 169 | ys: torch.Tensor, 170 | next_tokens: torch.Tensor, 171 | states: List[Any], 172 | xs: torch.Tensor, 173 | ) -> Tuple[torch.Tensor, Any]: 174 | """Score new token (required). 175 | 176 | Args: 177 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 178 | next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token). 179 | states (List[Any]): Scorer states for prefix tokens. 180 | xs (torch.Tensor): 181 | The encoder feature that generates ys (n_batch, xlen, n_feat). 
182 | 183 | Returns: 184 | tuple[torch.Tensor, Any]: 185 | Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)` 186 | and next states for ys 187 | """ 188 | raise NotImplementedError 189 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/backbones/modules/resnet1d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from espnet.nets.pytorch_backend.transformer.convolution import Swish 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """conv3x3. 10 | 11 | :param in_planes: int, number of channels in the input sequence. 12 | :param out_planes: int, number of channels produced by the convolution. 13 | :param stride: int, size of the convolving kernel. 14 | """ 15 | return nn.Conv1d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False, 22 | ) 23 | 24 | 25 | def downsample_basic_block(inplanes, outplanes, stride): 26 | """downsample_basic_block. 27 | 28 | :param inplanes: int, number of channels in the input sequence. 29 | :param outplanes: int, number of channels produced by the convolution. 30 | :param stride: int, size of the convolving kernel. 31 | """ 32 | return nn.Sequential( 33 | nn.Conv1d( 34 | inplanes, 35 | outplanes, 36 | kernel_size=1, 37 | stride=stride, 38 | bias=False, 39 | ), 40 | nn.BatchNorm1d(outplanes), 41 | ) 42 | 43 | 44 | class BasicBlock1D(nn.Module): 45 | expansion = 1 46 | 47 | def __init__( 48 | self, 49 | inplanes, 50 | planes, 51 | stride=1, 52 | downsample=None, 53 | relu_type="relu", 54 | ): 55 | """__init__. 56 | 57 | :param inplanes: int, number of channels in the input sequence. 58 | :param planes: int, number of channels produced by the convolution. 59 | :param stride: int, size of the convolving kernel. 60 | :param downsample: boolean, if True, the temporal resolution is downsampled. 61 | :param relu_type: str, type of activation function. 62 | """ 63 | super(BasicBlock1D, self).__init__() 64 | 65 | assert relu_type in ["relu","prelu", "swish"] 66 | 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = nn.BatchNorm1d(planes) 69 | 70 | # type of ReLU is an input option 71 | if relu_type == "relu": 72 | self.relu1 = nn.ReLU(inplace=True) 73 | self.relu2 = nn.ReLU(inplace=True) 74 | elif relu_type == "prelu": 75 | self.relu1 = nn.PReLU(num_parameters=planes) 76 | self.relu2 = nn.PReLU(num_parameters=planes) 77 | elif relu_type == "swish": 78 | self.relu1 = Swish() 79 | self.relu2 = Swish() 80 | else: 81 | raise NotImplementedError 82 | # -------- 83 | 84 | self.conv2 = conv3x3(planes, planes) 85 | self.bn2 = nn.BatchNorm1d(planes) 86 | 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x): 91 | """forward. 92 | 93 | :param x: torch.Tensor, input tensor with input size (B, C, T) 94 | """ 95 | residual = x 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu1(out) 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu2(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet1D(nn.Module): 111 | 112 | def __init__(self, 113 | block, 114 | layers, 115 | relu_type="swish", 116 | a_upsample_ratio=1, 117 | ): 118 | """__init__. 119 | 120 | :param block: torch.nn.Module, class of blocks. 
121 | :param layers: List, customised layers in each block. 122 | :param relu_type: str, type of activation function. 123 | :param a_upsample_ratio: int, The ratio related to the \ 124 | temporal resolution of output features of the frontend. \ 125 | a_upsample_ratio=1 produce features with a fps of 25. 126 | """ 127 | super(ResNet1D, self).__init__() 128 | self.inplanes = 64 129 | self.relu_type = relu_type 130 | self.downsample_block = downsample_basic_block 131 | self.a_upsample_ratio = a_upsample_ratio 132 | 133 | self.conv1 = nn.Conv1d( 134 | in_channels=1, 135 | out_channels=self.inplanes, 136 | kernel_size=80, 137 | stride=4, 138 | padding=38, 139 | bias=False, 140 | ) 141 | self.bn1 = nn.BatchNorm1d(self.inplanes) 142 | 143 | if relu_type == "relu": 144 | self.relu = nn.ReLU(inplace=True) 145 | elif relu_type == "prelu": 146 | self.relu = nn.PReLU(num_parameters=self.inplanes) 147 | elif relu_type == "swish": 148 | self.relu = Swish() 149 | 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 154 | self.avgpool = nn.AvgPool1d( 155 | kernel_size=20//self.a_upsample_ratio, 156 | stride=20//self.a_upsample_ratio, 157 | ) 158 | 159 | 160 | def _make_layer(self, block, planes, blocks, stride=1): 161 | """_make_layer. 162 | 163 | :param block: torch.nn.Module, class of blocks. 164 | :param planes: int, number of channels produced by the convolution. 165 | :param blocks: int, number of layers in a block. 166 | :param stride: int, size of the convolving kernel. 167 | """ 168 | 169 | downsample = None 170 | if stride != 1 or self.inplanes != planes * block.expansion: 171 | downsample = self.downsample_block( 172 | inplanes=self.inplanes, 173 | outplanes=planes*block.expansion, 174 | stride=stride, 175 | ) 176 | 177 | layers = [] 178 | layers.append( 179 | block( 180 | self.inplanes, 181 | planes, 182 | stride, 183 | downsample, 184 | relu_type=self.relu_type, 185 | ) 186 | ) 187 | self.inplanes = planes * block.expansion 188 | for i in range(1, blocks): 189 | layers.append( 190 | block( 191 | self.inplanes, 192 | planes, 193 | relu_type=self.relu_type, 194 | ) 195 | ) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | """forward. 
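        With the default ``a_upsample_ratio=1`` the overall temporal stride is
        4 * 2 * 2 * 2 * 20 = 640, so 16 kHz raw audio of shape (B, 1, T) yields roughly
        T // 640 output frames (about 25 feature frames per second) with
        ``512 * block.expansion`` channels.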
201 | 202 | :param x: torch.Tensor, input tensor with input size (B, C, T) 203 | """ 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | 208 | x = self.layer1(x) 209 | x = self.layer2(x) 210 | x = self.layer3(x) 211 | x = self.layer4(x) 212 | x = self.avgpool(x) 213 | return x 214 | -------------------------------------------------------------------------------- /src/models/model.json: -------------------------------------------------------------------------------- 1 | [ 2 | 9216, 3 | 41, 4 | { 5 | "a_upsample_ratio": 1, 6 | "accum_grad": 2, 7 | "adim": 256, 8 | "aheads": 4, 9 | "apply_uttmvn": true, 10 | "aux_lsm_weight": 0.0, 11 | "backend": "pytorch", 12 | "badim": 320, 13 | "batch_bins": 0, 14 | "batch_count": "auto", 15 | "batch_frames_in": 0, 16 | "batch_frames_inout": 0, 17 | "batch_frames_out": 0, 18 | "bdropout_rate": 0.0, 19 | "beam_size": 4, 20 | "blayers": 2, 21 | "bnmask": 2, 22 | "bprojs": 300, 23 | "btype": "blstmp", 24 | "bunits": 300, 25 | "char_list": [ 26 | "", 27 | "", 28 | "'", 29 | "0", 30 | "1", 31 | "2", 32 | "3", 33 | "4", 34 | "5", 35 | "6", 36 | "7", 37 | "8", 38 | "9", 39 | "", 40 | "A", 41 | "B", 42 | "C", 43 | "D", 44 | "E", 45 | "F", 46 | "G", 47 | "H", 48 | "I", 49 | "J", 50 | "K", 51 | "L", 52 | "M", 53 | "N", 54 | "O", 55 | "P", 56 | "Q", 57 | "R", 58 | "S", 59 | "T", 60 | "U", 61 | "V", 62 | "W", 63 | "X", 64 | "Y", 65 | "Z", 66 | "" 67 | ], 68 | "cnn_module_kernel": 31, 69 | "config2": null, 70 | "config3": null, 71 | "context_residual": false, 72 | "criterion": "acc", 73 | "ctc_type": "warpctc", 74 | "ctc_weight": 0.3, 75 | "debugmode": 1, 76 | "dec_init": null, 77 | "dec_init_mods": [ 78 | "att.", 79 | " dec." 80 | ], 81 | "dict": "data/lang_1char/units.txt", 82 | "dlayers": 6, 83 | "dropout_rate": 0.1, 84 | "dunits": 2048, 85 | "early_stop_criterion": "validation/main/acc", 86 | "elayers": 12, 87 | "enc_init": null, 88 | "enc_init_mods": [ 89 | "enc.enc." 
90 | ], 91 | "eps": 1e-08, 92 | "eps_decay": 0.01, 93 | "eunits": 2048, 94 | "fbank_fmax": null, 95 | "fbank_fmin": 0.0, 96 | "fbank_fs": 16000, 97 | "grad_clip": 5.0, 98 | "grad_noise": false, 99 | "lm_weight": 0.1, 100 | "lsm_weight": 0.1, 101 | "macaron_style": 1, 102 | "maxlen_in": 220, 103 | "maxlen_out": 220, 104 | "maxlenratio": 0.0, 105 | "minibatches": 0, 106 | "minlenratio": 0.0, 107 | "model_module": "espnet.nets.pytorch_backend.e2e_asr_transformer_multitask_dual:E2E", 108 | "mtl_custom_worker_l1_weight": 0.0, 109 | "mtl_custom_worker_length_normalized_loss": 0, 110 | "mtl_custom_worker_mlp_hdim": 256, 111 | "mtl_custom_worker_mlp_nlayers": 2, 112 | "mtl_custom_worker_mlp_nonlin_end": 0, 113 | "mtl_custom_worker_mlp_nonlin_type": "relu", 114 | "mtl_custom_worker_name": "patrickvonplaten/wav2vec2-base", 115 | "mtl_custom_worker_task_type": "", 116 | "mtl_custom_worker_tgt_type": "projected_quantized_states", 117 | "mtl_kl_weight": 0.0, 118 | "mtl_kl_weight_2": 0.0, 119 | "mtl_l1_weight": 0.4, 120 | "mtl_l1_weight_2": 0.4, 121 | "mtl_length_normalized_loss": 1, 122 | "mtl_length_normalized_loss_2": 1, 123 | "mtl_mlp_hdim": 256, 124 | "mtl_mlp_hdim_2": 256, 125 | "mtl_mlp_nlayers": 1, 126 | "mtl_mlp_nlayers_2": 1, 127 | "mtl_mlp_nonlin_end": 0, 128 | "mtl_mlp_nonlin_end_2": 0, 129 | "mtl_mlp_nonlin_type": "relu", 130 | "mtl_mlp_nonlin_type_2": "relu", 131 | "mtl_task_layer": "conformer6", 132 | "mtl_task_type": "l1", 133 | "mtl_task_type_2": "l1", 134 | "mtl_worker_source": "conv1d_lrs2_lrs3v04", 135 | "mtl_worker_source_2": "conv3d_lrs2_lrs3v04_dual", 136 | "mtlalpha": 0.1, 137 | "n_iter_processes": 12, 138 | "n_mels": 80, 139 | "nbest": 1, 140 | "ngpu": 1, 141 | "num_encs": 1, 142 | "num_input": 2, 143 | "num_save_attention": 3, 144 | "num_spkrs": 1, 145 | "opt": "noam", 146 | "patience": 0, 147 | "penalty": 0.0, 148 | "preprocess_conf": null, 149 | "pretrain_dataset": "lrs3_v04_full_dual_ignore", 150 | "raw_max_freq_width": 150, 151 | "raw_max_speed_rate": 1.1, 152 | "raw_max_time_width": 0.4, 153 | "raw_min_speed_rate": 0.9, 154 | "raw_n_freq_mask": 2, 155 | "raw_n_time_mask": 2, 156 | "raw_speech_do_normalize": false, 157 | "ref_channel": -1, 158 | "rel_pos_type": "latest", 159 | "relu_type": "swish", 160 | "report_cer": false, 161 | "report_interval_iters": 100, 162 | "report_wer": false, 163 | "rnnlm": null, 164 | "rnnlm_conf": null, 165 | "save_interval_iters": 0, 166 | "seed": 1, 167 | "sortagrad": 0, 168 | "specaug_max_freq_width": 30, 169 | "specaug_max_time_warp": 5, 170 | "specaug_max_time_width": 40, 171 | "specaug_n_freq_mask": 2, 172 | "specaug_n_time_mask": 2, 173 | "sr_interp_mode": "nearest", 174 | "sr_interp_scale_factor": 1.0, 175 | "stats_file": null, 176 | "sym_blank": "", 177 | "sym_space": "", 178 | "threshold": 0.0001, 179 | "train_dtype": "float32", 180 | "transformer_attn_dropout_rate": 0.1, 181 | "transformer_encoder_attn_layer_type": "rel_mha", 182 | "transformer_init": "pytorch", 183 | "transformer_input_layer": "conv3d", 184 | "transformer_length_normalized_loss": 0, 185 | "transformer_warmup_steps": 25000, 186 | "use_beamformer": true, 187 | "use_cnn_module": 1, 188 | "use_dnn_mask_for_wpe": false, 189 | "use_freqmask": false, 190 | "use_frontend": false, 191 | "use_noiseaug": false, 192 | "use_specaug": false, 193 | "use_speedaug": false, 194 | "use_timemask": false, 195 | "use_v_adaptive_timemask": true, 196 | "use_v_cutout": false, 197 | "use_v_timemask": false, 198 | "use_wpe": false, 199 | "uttmvn_norm_means": true, 200 | "uttmvn_norm_vars": 
false, 201 | "v_cutout_max_hole_length": 22, 202 | "v_cutout_n_holes": 1, 203 | "v_raw_max_time_width": 0.4, 204 | "v_raw_n_time_mask": 1, 205 | "v_timemask_replace_with_zero": false, 206 | "v_timemask_stride": 1.0, 207 | "verbose": 0, 208 | "wavaugments": null, 209 | "wdropout_rate": 0.0, 210 | "weight_decay": 0.0, 211 | "wlayers": 2, 212 | "wpe_delay": 3, 213 | "wpe_taps": 5, 214 | "wprojs": 300, 215 | "wtype": "blstmp", 216 | "wunits": 300, 217 | "zero_triu": false 218 | } 219 | ] -------------------------------------------------------------------------------- /src/data/landmark_transform.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2023 Imperial College London (Pingchuan Ma) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | import os 8 | import cv2 9 | import numpy as np 10 | from skimage import transform as tf 11 | 12 | 13 | def linear_interpolate(landmarks, start_idx, stop_idx): 14 | start_landmarks = landmarks[start_idx] 15 | stop_landmarks = landmarks[stop_idx] 16 | delta = stop_landmarks - start_landmarks 17 | for idx in range(1, stop_idx-start_idx): 18 | landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta 19 | return landmarks 20 | 21 | 22 | def warp_img(src, dst, img, std_size): 23 | tform = tf.estimate_transform('similarity', src, dst) 24 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=std_size) 25 | warped = (warped * 255).astype('uint8') 26 | return warped, tform 27 | 28 | 29 | def apply_transform(transform, img, std_size): 30 | warped = tf.warp(img, inverse_map=transform.inverse, output_shape=std_size) 31 | warped = (warped * 255).astype('uint8') 32 | return warped 33 | 34 | def cut_patch(img, landmarks, height, width, threshold=5): 35 | center_x, center_y = np.mean(landmarks, axis=0) 36 | # Check for too much bias in height and width 37 | if abs(center_y - img.shape[0] / 2) > height + threshold: 38 | return None 39 | if abs(center_x - img.shape[1] / 2) > width + threshold: 40 | return None 41 | # Calculate bounding box coordinates 42 | y_min = int(round(np.clip(center_y - height, 0, img.shape[0]))) 43 | y_max = int(round(np.clip(center_y + height, 0, img.shape[0]))) 44 | x_min = int(round(np.clip(center_x - width, 0, img.shape[1]))) 45 | x_max = int(round(np.clip(center_x + width, 0, img.shape[1]))) 46 | # Cut the image 47 | cutted_img = np.copy(img[y_min:y_max, x_min:x_max]) 48 | return cutted_img, y_min, x_min 49 | 50 | 51 | class VideoProcess: 52 | def __init__(self, mean_face_path="20words_mean_face.npy", crop_width=96, crop_height=96, 53 | start_idx=48, stop_idx=68, window_margin=12, convert_gray=True): 54 | self.reference = np.load(os.path.join(os.path.dirname(__file__), mean_face_path)) 55 | self.crop_width = crop_width 56 | self.crop_height = crop_height 57 | self.start_idx = start_idx 58 | self.stop_idx = stop_idx 59 | self.window_margin = window_margin 60 | self.convert_gray = convert_gray 61 | 62 | def __call__(self, video, landmarks): 63 | # Pre-process landmarks: interpolate frames that are not detected 64 | preprocessed_landmarks = self.interpolate_landmarks(landmarks) 65 | # Exclude corner cases: no landmark in all frames 66 | if not preprocessed_landmarks: 67 | return 68 | # Affine transformation and crop patch 69 | output = self.crop_patch(video, preprocessed_landmarks) 70 | if output is None: 71 | return None 72 | sequence, yx_min, transformed_landmarks = output 73 | assert 
sequence is not None, f"cannot crop a patch from." 74 | return sequence, yx_min, transformed_landmarks 75 | 76 | 77 | def crop_patch(self, video, landmarks): 78 | sequence = [] 79 | tf_landmarks = [] 80 | yx_min = [] 81 | for frame_idx, frame in enumerate(video): 82 | window_margin = min(self.window_margin // 2, frame_idx, len(landmarks) - 1 - frame_idx) 83 | smoothed_landmarks = np.mean([landmarks[x] for x in range(frame_idx - window_margin, frame_idx + window_margin + 1)], axis=0) 84 | smoothed_landmarks += landmarks[frame_idx].mean(axis=0) - smoothed_landmarks.mean(axis=0) 85 | transformed_frame, transformed_landmarks = self.affine_transform(frame, smoothed_landmarks, self.reference, grayscale=self.convert_gray) 86 | output = cut_patch(transformed_frame, transformed_landmarks[self.start_idx:self.stop_idx], self.crop_height//2, self.crop_width//2,) 87 | if output is None: 88 | return None 89 | patch, y_min, x_min = output 90 | sequence.append(patch) 91 | yx_min.append([y_min, x_min]) 92 | tf_landmarks.append(transformed_landmarks) 93 | return np.array(sequence), yx_min, tf_landmarks 94 | 95 | 96 | def interpolate_landmarks(self, landmarks): 97 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 98 | 99 | if not valid_frames_idx: 100 | return None 101 | 102 | for idx in range(1, len(valid_frames_idx)): 103 | if valid_frames_idx[idx] - valid_frames_idx[idx - 1] > 1: 104 | landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx]) 105 | 106 | valid_frames_idx = [idx for idx, lm in enumerate(landmarks) if lm is not None] 107 | 108 | # Handle corner case: keep frames at the beginning or at the end that failed to be detected 109 | if valid_frames_idx: 110 | landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0] 111 | landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1]) 112 | 113 | assert all(lm is not None for lm in landmarks), "not every frame has landmark" 114 | 115 | return landmarks 116 | 117 | 118 | def affine_transform(self, frame, landmarks, reference, grayscale=True, 119 | target_size=(256, 256), reference_size=(256, 256), stable_points=(28, 33, 36, 39, 42, 45, 48, 54), 120 | interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, border_value=0): 121 | if grayscale: 122 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 123 | stable_reference = self.get_stable_reference(reference, stable_points, reference_size, target_size) 124 | transform = self.estimate_affine_transform(landmarks, stable_points, stable_reference) 125 | transformed_frame, transformed_landmarks = self.apply_affine_transform(frame, landmarks, transform, target_size, interpolation, border_mode, border_value) 126 | 127 | return transformed_frame, transformed_landmarks 128 | 129 | 130 | def get_stable_reference(self, reference, stable_points, reference_size, target_size): 131 | stable_reference = np.vstack([reference[x] for x in stable_points]) 132 | stable_reference[:, 0] -= (reference_size[0] - target_size[0]) / 2.0 133 | stable_reference[:, 1] -= (reference_size[1] - target_size[1]) / 2.0 134 | return stable_reference 135 | 136 | 137 | def estimate_affine_transform(self, landmarks, stable_points, stable_reference): 138 | return cv2.estimateAffinePartial2D(np.vstack([landmarks[x] for x in stable_points]), stable_reference, method=cv2.LMEDS)[0] 139 | 140 | 141 | def apply_affine_transform(self, frame, landmarks, transform, target_size, interpolation, 
border_mode, border_value): 142 | transformed_frame = cv2.warpAffine(frame, transform, dsize=(target_size[0], target_size[1]), 143 | flags=interpolation, borderMode=border_mode, borderValue=border_value) 144 | transformed_landmarks = np.matmul(landmarks, transform[:, :2].transpose()) + transform[:, 2].transpose() 145 | return transformed_frame, transformed_landmarks -------------------------------------------------------------------------------- /src/data/visual_corruption.py: -------------------------------------------------------------------------------- 1 | import os, random, cv2 2 | import albumentations as A 3 | import numpy as np 4 | import torchvision 5 | import torch 6 | from skimage.util import random_noise 7 | 8 | class Visual_Corruption_Modeling: 9 | def __init__(self, d_image='./occlusion_patch/object_image_sr', d_mask='./occlusion_patch/object_mask_x4'): 10 | assert os.path.exists(d_image), "Please download coco_object.7z first" 11 | self.d_image = d_image 12 | self.d_mask = d_mask 13 | self.aug = get_occluder_augmentor() 14 | self.occlude_imgs = os.listdir(d_image) 15 | 16 | def get_occluders(self): 17 | occlude_img = random.choice(self.occlude_imgs) 18 | 19 | occlude_mask = occlude_img.replace('jpeg', 'png') 20 | 21 | ori_occluder_img = cv2.imread(os.path.join(self.d_image, occlude_img), -1) 22 | ori_occluder_img = cv2.cvtColor(ori_occluder_img, cv2.COLOR_BGR2RGB) 23 | 24 | occluder_mask = cv2.imread(os.path.join(self.d_mask, occlude_mask)) 25 | occluder_mask = cv2.cvtColor(occluder_mask, cv2.COLOR_BGR2GRAY) 26 | 27 | occluder_mask = cv2.resize(occluder_mask, (ori_occluder_img.shape[1], ori_occluder_img.shape[0]), 28 | interpolation=cv2.INTER_LANCZOS4) 29 | 30 | occluder_img = cv2.bitwise_and(ori_occluder_img, ori_occluder_img, mask=occluder_mask) 31 | 32 | transformed = self.aug(image=occluder_img, mask=occluder_mask) 33 | occluder_img, occluder_mask = transformed["image"], transformed["mask"] 34 | 35 | occluder_size = random.choice(range(20, 46)) 36 | 37 | occluder_img = cv2.resize(occluder_img, (occluder_size, occluder_size), interpolation= cv2.INTER_LANCZOS4) 38 | occluder_mask = cv2.resize(occluder_mask, (occluder_size, occluder_size), interpolation= cv2.INTER_LANCZOS4) 39 | 40 | return occlude_img, occluder_img, occluder_mask 41 | 42 | def noise_sequence(self, img_seq, freq=1): 43 | if freq == 1: 44 | len = img_seq.shape[0] 45 | occ_len = random.randint(int(len * 0.1), int(len * 0.5)) 46 | start_fr = random.randint(0, len-occ_len) 47 | 48 | raw_sequence = img_seq[start_fr:start_fr+occ_len] 49 | prob = random.random() 50 | if prob < 0.3: 51 | var = random.random() * 0.2 52 | raw_sequence = np.expand_dims(raw_sequence, 3) 53 | raw_sequence = random_noise(raw_sequence, mode='gaussian', mean=0, var=var, clip=True) * 255 54 | raw_sequence = raw_sequence.squeeze(3) 55 | elif prob < 0.6: 56 | blur = torchvision.transforms.GaussianBlur(kernel_size=(7, 7), sigma=(0.1, 2.0)) 57 | raw_sequence = np.expand_dims(raw_sequence, 3) 58 | raw_sequence = blur(torch.tensor(raw_sequence).permute(0, 3, 1, 2)).permute(0, 2, 3, 1).numpy() 59 | raw_sequence = raw_sequence.squeeze(3) 60 | else: 61 | pass 62 | 63 | img_seq[start_fr:start_fr + occ_len] = raw_sequence 64 | 65 | else: 66 | len_global = img_seq.shape[0] 67 | len = img_seq.shape[0] // freq 68 | for j in range(freq): 69 | try: 70 | occ_len = random.randint(int(len_global * 0.3), int(len_global * 0.5)) 71 | start_fr = random.randint(0, len*j + len - occ_len) 72 | if start_fr < len*j: 73 | assert 1==2 74 | except: 75 | occ_len = len // 
2 76 | start_fr = len * j 77 | 78 | raw_sequence = img_seq[start_fr:start_fr + occ_len] 79 | prob = random.random() 80 | if prob < 0.3: 81 | var = random.random() * 0.2 82 | raw_sequence = np.expand_dims(raw_sequence, 3) 83 | raw_sequence = random_noise(raw_sequence, mode='gaussian', mean=0, var=var, clip=True) * 255 84 | raw_sequence = raw_sequence.squeeze(3) 85 | elif prob < 0.6: 86 | blur = torchvision.transforms.GaussianBlur(kernel_size=(7, 7), sigma=(0.1, 2.0)) 87 | raw_sequence = np.expand_dims(raw_sequence, 3) 88 | raw_sequence = blur(torch.tensor(raw_sequence).permute(0, 3, 1, 2)).permute(0, 2, 3, 1).numpy() 89 | raw_sequence = raw_sequence.squeeze(3) 90 | else: 91 | pass 92 | 93 | img_seq[start_fr:start_fr + occ_len] = raw_sequence 94 | 95 | return img_seq 96 | 97 | 98 | def occlude_sequence(self, img_seq, landmarks, yx_min, freq=1): 99 | if freq == 1: 100 | occlude_img, occluder_img, occluder_mask = self.get_occluders() 101 | 102 | len = img_seq.shape[0] 103 | start_pt_idx = random.randint(48,67) 104 | offset = random.randint(20,60) 105 | occ_len = random.randint(int(len * 0.1), int(len * 0.5)) 106 | start_fr = random.randint(0, len-occ_len) 107 | 108 | for i in range(occ_len): 109 | fr = cv2.cvtColor(img_seq[i+start_fr], cv2.COLOR_GRAY2RGB) 110 | x, y = landmarks[i][start_pt_idx] 111 | 112 | alpha_mask = np.expand_dims(occluder_mask, axis=2) 113 | alpha_mask = np.repeat(alpha_mask, 3, axis=2) / 255.0 114 | 115 | fr = self.overlay_image_alpha(fr, occluder_img, int(y-yx_min[i][0]-offset), int(x-yx_min[i][1]-offset), alpha_mask) 116 | img_seq[i + start_fr] = cv2.cvtColor(fr, cv2.COLOR_RGB2GRAY) 117 | 118 | else: 119 | len_global = img_seq.shape[0] 120 | len = img_seq.shape[0] // freq 121 | for j in range(freq): 122 | occlude_img, occluder_img, occluder_mask = self.get_occluders() 123 | 124 | start_pt_idx = random.randint(48, 67) 125 | offset = random.randint(20, 40) 126 | try: 127 | occ_len = random.randint(int(len_global * 0.3), int(len_global * 0.5)) 128 | start_fr = random.randint(0, len*j + len - occ_len) 129 | if start_fr < len*j: 130 | assert 1==2 131 | except: 132 | occ_len = len // 2 133 | start_fr = len * j 134 | 135 | for i in range(occ_len): 136 | fr = cv2.cvtColor(img_seq[i + start_fr], cv2.COLOR_GRAY2RGB) 137 | x, y = landmarks[i][start_pt_idx] 138 | 139 | alpha_mask = np.expand_dims(occluder_mask, axis=2) 140 | alpha_mask = np.repeat(alpha_mask, 3, axis=2) / 255.0 141 | fr = self.overlay_image_alpha(fr, occluder_img, int(y - yx_min[i][0] - offset), int(x - yx_min[i][1] - offset), 142 | alpha_mask) 143 | img_seq[i + start_fr] = cv2.cvtColor(fr, cv2.COLOR_RGB2GRAY) 144 | return img_seq, occlude_img 145 | 146 | def overlay_image_alpha(self, img, img_overlay, x, y, alpha_mask): 147 | """Overlay `img_overlay` onto `img` at (x, y) and blend using `alpha_mask`. 148 | 149 | `alpha_mask` must have same HxW as `img_overlay` and values in range [0, 1]. 
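        The overlapping region is blended in place as `alpha * img_overlay + (1 - alpha) * img`; parts of the overlay that fall outside `img` are clipped, and if there is no overlap at all `img` is returned unchanged.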
150 | """ 151 | # Image ranges 152 | y1, y2 = max(0, y), min(img.shape[0], y + img_overlay.shape[0]) 153 | x1, x2 = max(0, x), min(img.shape[1], x + img_overlay.shape[1]) 154 | 155 | # Overlay ranges 156 | y1o, y2o = max(0, -y), min(img_overlay.shape[0], img.shape[0] - y) 157 | x1o, x2o = max(0, -x), min(img_overlay.shape[1], img.shape[1] - x) 158 | 159 | # Exit if nothing to do 160 | if y1 >= y2 or x1 >= x2 or y1o >= y2o or x1o >= x2o: 161 | return img 162 | 163 | # Blend overlay within the determined ranges 164 | img_crop = img[y1:y2, x1:x2] 165 | img_overlay_crop = img_overlay[y1o:y2o, x1o:x2o] 166 | 167 | alpha = alpha_mask[y1o:y2o, x1o:x2o] 168 | alpha_inv = 1.0 - alpha 169 | img_crop[:] = alpha * img_overlay_crop + alpha_inv * img_crop 170 | return img 171 | 172 | def get_occluder_augmentor(): 173 | """ 174 | Occludor augmentor 175 | """ 176 | aug=A.Compose([ 177 | A.AdvancedBlur(), 178 | A.OneOf([ 179 | A.ImageCompression (quality_lower=70,p=0.5), 180 | ], p=0.5), 181 | A.Affine ( 182 | scale=(0.8,1.2), 183 | rotate=(-15,15), 184 | shear=(-8,8), 185 | fit_output=True, 186 | p=0.7 187 | ), 188 | A.RandomBrightnessContrast(p=0.5,brightness_limit=0.1, contrast_limit=0.1, brightness_by_max=False), 189 | ]) 190 | return aug 191 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positional Encoding Module.""" 8 | 9 | import math 10 | 11 | import torch 12 | 13 | 14 | def _pre_hook( 15 | state_dict, 16 | prefix, 17 | local_metadata, 18 | strict, 19 | missing_keys, 20 | unexpected_keys, 21 | error_msgs, 22 | ): 23 | """Perform pre-hook in load_state_dict for backward compatibility. 24 | Note: 25 | We saved self.pe until v.0.5.2 but we have omitted it later. 26 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 27 | """ 28 | k = prefix + "pe" 29 | if k in state_dict: 30 | state_dict.pop(k) 31 | 32 | 33 | class PositionalEncoding(torch.nn.Module): 34 | """Positional encoding. 35 | Args: 36 | d_model (int): Embedding dimension. 37 | dropout_rate (float): Dropout rate. 38 | max_len (int): Maximum input length. 39 | reverse (bool): Whether to reverse the input position. Only for 40 | the class LegacyRelPositionalEncoding. We remove it in the current 41 | class RelPositionalEncoding. 
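    Note:
        The table built in `extend_pe` follows the standard sinusoidal scheme,
        pe[pos, 2i] = sin(pos / 10000^(2i / d_model)) and
        pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model)),
        and `forward` scales the input by sqrt(d_model) before adding the first `time` rows.
        For illustration, `PositionalEncoding(d_model=256, dropout_rate=0.1)(torch.zeros(4, 100, 256))`
        returns a tensor of shape (4, 100, 256).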
42 | """ 43 | 44 | def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): 45 | """Construct an PositionalEncoding object.""" 46 | super(PositionalEncoding, self).__init__() 47 | self.d_model = d_model 48 | self.reverse = reverse 49 | self.xscale = math.sqrt(self.d_model) 50 | self.dropout = torch.nn.Dropout(p=dropout_rate) 51 | self.pe = None 52 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 53 | self._register_load_state_dict_pre_hook(_pre_hook) 54 | 55 | def extend_pe(self, x): 56 | """Reset the positional encodings.""" 57 | if self.pe is not None: 58 | if self.pe.size(1) >= x.size(1): 59 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 60 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 61 | return 62 | pe = torch.zeros(x.size(1), self.d_model) 63 | if self.reverse: 64 | position = torch.arange( 65 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 66 | ).unsqueeze(1) 67 | else: 68 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 69 | div_term = torch.exp( 70 | torch.arange(0, self.d_model, 2, dtype=torch.float32) 71 | * -(math.log(10000.0) / self.d_model) 72 | ) 73 | pe[:, 0::2] = torch.sin(position * div_term) 74 | pe[:, 1::2] = torch.cos(position * div_term) 75 | pe = pe.unsqueeze(0) 76 | self.pe = pe.to(device=x.device, dtype=x.dtype) 77 | 78 | def forward(self, x: torch.Tensor): 79 | """Add positional encoding. 80 | Args: 81 | x (torch.Tensor): Input tensor (batch, time, `*`). 82 | Returns: 83 | torch.Tensor: Encoded tensor (batch, time, `*`). 84 | """ 85 | self.extend_pe(x) 86 | x = x * self.xscale + self.pe[:, : x.size(1)] 87 | return self.dropout(x) 88 | 89 | 90 | class ScaledPositionalEncoding(PositionalEncoding): 91 | """Scaled positional encoding module. 92 | See Sec. 3.2 https://arxiv.org/abs/1809.08895 93 | Args: 94 | d_model (int): Embedding dimension. 95 | dropout_rate (float): Dropout rate. 96 | max_len (int): Maximum input length. 97 | """ 98 | 99 | def __init__(self, d_model, dropout_rate, max_len=5000): 100 | """Initialize class.""" 101 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) 102 | self.alpha = torch.nn.Parameter(torch.tensor(1.0)) 103 | 104 | def reset_parameters(self): 105 | """Reset parameters.""" 106 | self.alpha.data = torch.tensor(1.0) 107 | 108 | def forward(self, x): 109 | """Add positional encoding. 110 | Args: 111 | x (torch.Tensor): Input tensor (batch, time, `*`). 112 | Returns: 113 | torch.Tensor: Encoded tensor (batch, time, `*`). 114 | """ 115 | self.extend_pe(x) 116 | x = x + self.alpha * self.pe[:, : x.size(1)] 117 | return self.dropout(x) 118 | 119 | 120 | class LegacyRelPositionalEncoding(PositionalEncoding): 121 | """Relative positional encoding module (old version). 122 | Details can be found in https://github.com/espnet/espnet/pull/2816. 123 | See : Appendix B in https://arxiv.org/abs/1901.02860 124 | Args: 125 | d_model (int): Embedding dimension. 126 | dropout_rate (float): Dropout rate. 127 | max_len (int): Maximum input length. 128 | """ 129 | 130 | def __init__(self, d_model, dropout_rate, max_len=5000): 131 | """Initialize class.""" 132 | super().__init__( 133 | d_model=d_model, 134 | dropout_rate=dropout_rate, 135 | max_len=max_len, 136 | reverse=True, 137 | ) 138 | 139 | def forward(self, x): 140 | """Compute positional encoding. 141 | Args: 142 | x (torch.Tensor): Input tensor (batch, time, `*`). 143 | Returns: 144 | torch.Tensor: Encoded tensor (batch, time, `*`). 145 | torch.Tensor: Positional embedding tensor (1, time, `*`). 
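        Note:
            Unlike `PositionalEncoding.forward`, the positional embedding is returned
            separately instead of being added to the input; only the sqrt(d_model)
            scaling is applied to `x` here.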
146 | """ 147 | self.extend_pe(x) 148 | x = x * self.xscale 149 | pos_emb = self.pe[:, : x.size(1)] 150 | return self.dropout(x), self.dropout(pos_emb) 151 | 152 | 153 | class RelPositionalEncoding(torch.nn.Module): 154 | """Relative positional encoding module (new implementation). 155 | Details can be found in https://github.com/espnet/espnet/pull/2816. 156 | See : Appendix B in https://arxiv.org/abs/1901.02860 157 | Args: 158 | d_model (int): Embedding dimension. 159 | dropout_rate (float): Dropout rate. 160 | max_len (int): Maximum input length. 161 | """ 162 | 163 | def __init__(self, d_model, dropout_rate, max_len=5000): 164 | """Construct an PositionalEncoding object.""" 165 | super(RelPositionalEncoding, self).__init__() 166 | self.d_model = d_model 167 | self.xscale = math.sqrt(self.d_model) 168 | self.dropout = torch.nn.Dropout(p=dropout_rate) 169 | self.pe = None 170 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 171 | 172 | def extend_pe(self, x): 173 | """Reset the positional encodings.""" 174 | if self.pe is not None: 175 | # self.pe contains both positive and negative parts 176 | # the length of self.pe is 2 * input_len - 1 177 | if self.pe.size(1) >= x.size(1) * 2 - 1: 178 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 179 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 180 | return 181 | # Suppose `i` means to the position of query vecotr and `j` means the 182 | # position of key vector. We use position relative positions when keys 183 | # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 145 | """Compute LM loss value from buffer sequences. 146 | 147 | Args: 148 | x (torch.Tensor): Input ids. (batch, len) 149 | t (torch.Tensor): Target ids. (batch, len) 150 | 151 | Returns: 152 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of 153 | loss to backward (scalar), 154 | negative log-likelihood of t: -log p(t) (scalar) and 155 | the number of elements in x (scalar) 156 | 157 | Notes: 158 | The last two return values are used 159 | in perplexity: p(t)^{-n} = exp(-log p(t) / n) 160 | 161 | """ 162 | xm = x != 0 163 | 164 | if self.embed_drop is not None: 165 | emb = self.embed_drop(self.embed(x)) 166 | else: 167 | emb = self.embed(x) 168 | 169 | h, _ = self.encoder(emb, self._target_mask(x)) 170 | y = self.decoder(h) 171 | loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") 172 | mask = xm.to(dtype=loss.dtype) 173 | logp = loss * mask.view(-1) 174 | logp = logp.sum() 175 | count = mask.sum() 176 | return logp / count, logp, count 177 | 178 | def score( 179 | self, y: torch.Tensor, state: Any, x: torch.Tensor 180 | ) -> Tuple[torch.Tensor, Any]: 181 | """Score new token. 182 | 183 | Args: 184 | y (torch.Tensor): 1D torch.int64 prefix tokens. 185 | state: Scorer state for prefix tokens 186 | x (torch.Tensor): encoder feature that generates ys. 
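                Note that `x` is not used when scoring with this language model; only the prefix tokens `y` and the cached `state` matter.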
187 | 188 | Returns: 189 | tuple[torch.Tensor, Any]: Tuple of 190 | torch.float32 scores for next token (n_vocab) 191 | and next state for ys 192 | 193 | """ 194 | y = y.unsqueeze(0) 195 | 196 | if self.embed_drop is not None: 197 | emb = self.embed_drop(self.embed(y)) 198 | else: 199 | emb = self.embed(y) 200 | 201 | h, _, cache = self.encoder.forward_one_step( 202 | emb, self._target_mask(y), cache=state 203 | ) 204 | h = self.decoder(h[:, -1]) 205 | logp = h.log_softmax(dim=-1).squeeze(0) 206 | return logp, cache 207 | 208 | # batch beam search API (see BatchScorerInterface) 209 | def batch_score( 210 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 211 | ) -> Tuple[torch.Tensor, List[Any]]: 212 | """Score new token batch (required). 213 | 214 | Args: 215 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 216 | states (List[Any]): Scorer states for prefix tokens. 217 | xs (torch.Tensor): 218 | The encoder feature that generates ys (n_batch, xlen, n_feat). 219 | 220 | Returns: 221 | tuple[torch.Tensor, List[Any]]: Tuple of 222 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 223 | and next state list for ys. 224 | 225 | """ 226 | # merge states 227 | n_batch = len(ys) 228 | n_layers = len(self.encoder.encoders) 229 | if states[0] is None: 230 | batch_state = None 231 | else: 232 | # transpose state of [batch, layer] into [layer, batch] 233 | batch_state = [ 234 | torch.stack([states[b][i] for b in range(n_batch)]) 235 | for i in range(n_layers) 236 | ] 237 | 238 | if self.embed_drop is not None: 239 | emb = self.embed_drop(self.embed(ys)) 240 | else: 241 | emb = self.embed(ys) 242 | 243 | # batch decoding 244 | h, _, states = self.encoder.forward_one_step( 245 | emb, self._target_mask(ys), cache=batch_state 246 | ) 247 | h = self.decoder(h[:, -1]) 248 | logp = h.log_softmax(dim=-1) 249 | 250 | # transpose state of [layer, batch] into [batch, layer] 251 | state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] 252 | return logp, state_list 253 | -------------------------------------------------------------------------------- /espnet/nets/e2e_asr_common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | 4 | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Common functions for ASR.""" 8 | 9 | import json 10 | import logging 11 | import sys 12 | 13 | from itertools import groupby 14 | import numpy as np 15 | import six 16 | 17 | 18 | def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): 19 | """End detection. 20 | 21 | described in Eq. (50) of S. 
Watanabe et al 22 | "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" 23 | 24 | :param ended_hyps: 25 | :param i: 26 | :param M: 27 | :param D_end: 28 | :return: 29 | """ 30 | if len(ended_hyps) == 0: 31 | return False 32 | count = 0 33 | best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0] 34 | for m in six.moves.range(M): 35 | # get ended_hyps with their length is i - m 36 | hyp_length = i - m 37 | hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length] 38 | if len(hyps_same_length) > 0: 39 | best_hyp_same_length = sorted( 40 | hyps_same_length, key=lambda x: x["score"], reverse=True 41 | )[0] 42 | if best_hyp_same_length["score"] - best_hyp["score"] < D_end: 43 | count += 1 44 | 45 | if count == M: 46 | return True 47 | else: 48 | return False 49 | 50 | 51 | # TODO(takaaki-hori): add different smoothing methods 52 | def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): 53 | """Obtain label distribution for loss smoothing. 54 | 55 | :param odim: 56 | :param lsm_type: 57 | :param blank: 58 | :param transcript: 59 | :return: 60 | """ 61 | if transcript is not None: 62 | with open(transcript, "rb") as f: 63 | trans_json = json.load(f)["utts"] 64 | 65 | if lsm_type == "unigram": 66 | assert transcript is not None, ( 67 | "transcript is required for %s label smoothing" % lsm_type 68 | ) 69 | labelcount = np.zeros(odim) 70 | for k, v in trans_json.items(): 71 | ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) 72 | # to avoid an error when there is no text in an uttrance 73 | if len(ids) > 0: 74 | labelcount[ids] += 1 75 | labelcount[odim - 1] = len(transcript) # count 76 | labelcount[labelcount == 0] = 1 # flooring 77 | labelcount[blank] = 0 # remove counts for blank 78 | labeldist = labelcount.astype(np.float32) / np.sum(labelcount) 79 | else: 80 | logging.error("Error: unexpected label smoothing type: %s" % lsm_type) 81 | sys.exit() 82 | 83 | return labeldist 84 | 85 | 86 | def get_vgg2l_odim(idim, in_channel=3, out_channel=128): 87 | """Return the output size of the VGG frontend. 88 | 89 | :param in_channel: input channel size 90 | :param out_channel: output channel size 91 | :return: output size 92 | :rtype int 93 | """ 94 | idim = idim / in_channel 95 | idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling 96 | idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling 97 | return int(idim) * out_channel # numer of channels 98 | 99 | 100 | class ErrorCalculator(object): 101 | """Calculate CER and WER for E2E_ASR and CTC models during training. 102 | 103 | :param y_hats: numpy array with predicted text 104 | :param y_pads: numpy array with true (target) text 105 | :param char_list: 106 | :param sym_space: 107 | :param sym_blank: 108 | :return: 109 | """ 110 | 111 | def __init__( 112 | self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False 113 | ): 114 | """Construct an ErrorCalculator object.""" 115 | super(ErrorCalculator, self).__init__() 116 | 117 | self.report_cer = report_cer 118 | self.report_wer = report_wer 119 | 120 | self.char_list = char_list 121 | self.space = sym_space 122 | self.blank = sym_blank 123 | self.idx_blank = self.char_list.index(self.blank) 124 | if self.space in self.char_list: 125 | self.idx_space = self.char_list.index(self.space) 126 | else: 127 | self.idx_space = None 128 | 129 | def __call__(self, ys_hat, ys_pad, is_ctc=False): 130 | """Calculate sentence-level WER/CER score. 
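        When `is_ctc` is True, only a CTC-based CER is computed and returned;
        otherwise CER and WER are computed according to the `report_cer` and
        `report_wer` flags (None is returned for metrics that are not requested).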
131 | 132 | :param torch.Tensor ys_hat: prediction (batch, seqlen) 133 | :param torch.Tensor ys_pad: reference (batch, seqlen) 134 | :param bool is_ctc: calculate CER score for CTC 135 | :return: sentence-level WER score 136 | :rtype float 137 | :return: sentence-level CER score 138 | :rtype float 139 | """ 140 | cer, wer = None, None 141 | if is_ctc: 142 | return self.calculate_cer_ctc(ys_hat, ys_pad) 143 | elif not self.report_cer and not self.report_wer: 144 | return cer, wer 145 | 146 | seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) 147 | if self.report_cer: 148 | cer = self.calculate_cer(seqs_hat, seqs_true) 149 | 150 | if self.report_wer: 151 | wer = self.calculate_wer(seqs_hat, seqs_true) 152 | return cer, wer 153 | 154 | def calculate_cer_ctc(self, ys_hat, ys_pad): 155 | """Calculate sentence-level CER score for CTC. 156 | 157 | :param torch.Tensor ys_hat: prediction (batch, seqlen) 158 | :param torch.Tensor ys_pad: reference (batch, seqlen) 159 | :return: average sentence-level CER score 160 | :rtype float 161 | """ 162 | import editdistance 163 | 164 | cers, char_ref_lens = [], [] 165 | for i, y in enumerate(ys_hat): 166 | y_hat = [x[0] for x in groupby(y)] 167 | y_true = ys_pad[i] 168 | seq_hat, seq_true = [], [] 169 | for idx in y_hat: 170 | idx = int(idx) 171 | if idx != -1 and idx != self.idx_blank and idx != self.idx_space: 172 | seq_hat.append(self.char_list[int(idx)]) 173 | 174 | for idx in y_true: 175 | idx = int(idx) 176 | if idx != -1 and idx != self.idx_blank and idx != self.idx_space: 177 | seq_true.append(self.char_list[int(idx)]) 178 | 179 | hyp_chars = "".join(seq_hat) 180 | ref_chars = "".join(seq_true) 181 | if len(ref_chars) > 0: 182 | cers.append(editdistance.eval(hyp_chars, ref_chars)) 183 | char_ref_lens.append(len(ref_chars)) 184 | 185 | cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None 186 | return cer_ctc 187 | 188 | def convert_to_char(self, ys_hat, ys_pad): 189 | """Convert index to character. 190 | 191 | :param torch.Tensor seqs_hat: prediction (batch, seqlen) 192 | :param torch.Tensor seqs_true: reference (batch, seqlen) 193 | :return: token list of prediction 194 | :rtype list 195 | :return: token list of reference 196 | :rtype list 197 | """ 198 | seqs_hat, seqs_true = [], [] 199 | for i, y_hat in enumerate(ys_hat): 200 | y_true = ys_pad[i] 201 | eos_true = np.where(y_true == -1)[0] 202 | ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) 203 | # NOTE: padding index (-1) in y_true is used to pad y_hat 204 | seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] 205 | seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] 206 | seq_hat_text = "".join(seq_hat).replace(self.space, " ") 207 | seq_hat_text = seq_hat_text.replace(self.blank, "") 208 | seq_true_text = "".join(seq_true).replace(self.space, " ") 209 | seqs_hat.append(seq_hat_text) 210 | seqs_true.append(seq_true_text) 211 | return seqs_hat, seqs_true 212 | 213 | def calculate_cer(self, seqs_hat, seqs_true): 214 | """Calculate sentence-level CER score. 
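        The score is the total character-level edit distance divided by the total
        number of reference characters, with spaces removed before comparison.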
215 | 216 | :param list seqs_hat: prediction 217 | :param list seqs_true: reference 218 | :return: average sentence-level CER score 219 | :rtype float 220 | """ 221 | import editdistance 222 | 223 | char_eds, char_ref_lens = [], [] 224 | for i, seq_hat_text in enumerate(seqs_hat): 225 | seq_true_text = seqs_true[i] 226 | hyp_chars = seq_hat_text.replace(" ", "") 227 | ref_chars = seq_true_text.replace(" ", "") 228 | char_eds.append(editdistance.eval(hyp_chars, ref_chars)) 229 | char_ref_lens.append(len(ref_chars)) 230 | return float(sum(char_eds)) / sum(char_ref_lens) 231 | 232 | def calculate_wer(self, seqs_hat, seqs_true): 233 | """Calculate sentence-level WER score. 234 | 235 | :param list seqs_hat: prediction 236 | :param list seqs_true: reference 237 | :return: average sentence-level WER score 238 | :rtype float 239 | """ 240 | import editdistance 241 | 242 | word_eds, word_ref_lens = [], [] 243 | for i, seq_hat_text in enumerate(seqs_hat): 244 | seq_true_text = seqs_true[i] 245 | hyp_words = seq_hat_text.split() 246 | ref_words = seq_true_text.split() 247 | word_eds.append(editdistance.eval(hyp_words, ref_words)) 248 | word_ref_lens.append(len(ref_words)) 249 | return float(sum(word_eds)) / sum(word_ref_lens) 250 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from einops import rearrange 9 | import sys 10 | from src.data.char_list import char_list 11 | from src.data.transforms import * 12 | from src.data.visual_corruption import * 13 | import librosa 14 | import cv2 15 | import pickle 16 | 17 | class AVDataset(Dataset): 18 | def __init__(self, data_path, split_file, mode, data_type, max_vid_len=600, max_txt_len=200, visual_corruption=True, fast_validate=False): 19 | assert mode in ['train', 'test', 'val'] 20 | self.mode = mode 21 | self.fast_validate = fast_validate 22 | self.visual_corruption = visual_corruption if mode == 'train' else False 23 | self.data_path = data_path 24 | self.file_paths = self.build_file_list(split_file, mode, data_type) 25 | self.char_list = char_list 26 | self.char2idx = {v: k for k, v in enumerate(char_list)} 27 | 28 | self.max_vid_len = max_vid_len 29 | self.max_txt_len = max_txt_len 30 | 31 | self._noise = np.load('./src/data/babbleNoise_resample_16K.npy') 32 | self.transform_aud = self.get_audio_transform(self._noise, split=mode) 33 | self.transform_vid = self.get_video_transform(split=mode) 34 | 35 | if visual_corruption: 36 | self.visual_corruption = Visual_Corruption_Modeling() 37 | 38 | def build_file_list(self, split_file, mode, data_type): 39 | datalist = open(split_file).read().splitlines() 40 | if mode == 'val': 41 | if data_type == 'LRS2': 42 | return [os.path.join('main', x.strip().split()[0]) for x in datalist] 43 | elif data_type == 'LRS3': 44 | return [os.path.join('trainval', x.strip().split()[0]) for x in datalist] 45 | else: 46 | raise NotImplementedError("data_type should be LRS2 or LRS3") 47 | else: 48 | return [x.strip().split()[0] for x in datalist] 49 | 50 | def __len__(self): 51 | return len(self.file_paths) 52 | 53 | def __getitem__(self, idx): 54 | f_name = self.file_paths[idx] 55 | 56 | vid_path = os.path.join(self.data_path, 'Video', f_name + '.mp4') 57 | aud_path = os.path.join(self.data_path, 'Audio', f_name + 
'.wav') 58 | txt_path = os.path.join(self.data_path, 'Text', f_name + '.txt') 59 | lm_path = os.path.join(self.data_path, 'Transformed_LM', f_name + '.pkl') 60 | if not (os.path.exists(vid_path) and os.path.exists(aud_path) and os.path.exists(txt_path) and os.path.exists(lm_path)): 61 | return None 62 | 63 | vid = self.load_video(vid_path) 64 | if len(vid) == 0: 65 | return None 66 | if len(vid) > self.max_vid_len: 67 | vid = vid[:self.max_vid_len] 68 | 69 | aud = self.load_audio(aud_path) 70 | aud = aud[:int(len(vid) / 25 * 16000)] 71 | if len(aud) < int(len(vid) / 25 * 16000): 72 | aud = np.concatenate([aud, np.zeros(int(len(vid) / 25 * 16000) - len(aud)), 0]) 73 | 74 | gt = self.load_txt(txt_path) 75 | text = np.array(self.parse_transcript(gt)) 76 | if len(text) > self.max_txt_len: 77 | text = text[:self.max_txt_len] 78 | 79 | if self.visual_corruption: 80 | with open(lm_path, "rb") as pkl_file: 81 | pkl = pickle.load(pkl_file) 82 | lm = pkl['landmarks'] 83 | yx_min = pkl['yx_min'] 84 | 85 | prob = random.random() 86 | if prob < 0.2: 87 | pass 88 | else: 89 | freq1, freq2 = [random.choice([1, 2, 3]) for _ in range(2)] 90 | vid, _ = self.visual_corruption.occlude_sequence(vid, lm, yx_min, freq=freq1) 91 | vid = self.visual_corruption.noise_sequence(vid, freq=freq2) 92 | 93 | vid = self.transform_vid(vid) 94 | aud = self.transform_aud(aud) 95 | 96 | return vid, aud, text 97 | 98 | def load_txt(self, filename): 99 | text = open(filename, 'r').readline().strip() 100 | return text[text.find(' '):].strip() 101 | 102 | def parse_transcript(self, transcript): 103 | idx_list = list() 104 | for char in transcript: 105 | if char==" ": 106 | char = "" 107 | idx = self.char2idx.get(char) 108 | idx = idx if idx is not None else self.char2idx[''] 109 | idx_list.append(idx) 110 | return idx_list 111 | 112 | def get_audio_transform(self, noise_data, split): 113 | """get_audio_transform. 114 | 115 | :param noise_data: numpy.ndarray, the noisy data to be injected to data. 
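        :param split: str, data split; additive babble noise (AddNoise) is applied only when split == 'train', before per-utterance normalization.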
116 | """ 117 | if split != 'train': 118 | return Compose([ 119 | NormalizeUtterance(), 120 | ExpandDims()] 121 | ) 122 | else: 123 | return Compose([ 124 | AddNoise( 125 | noise=noise_data 126 | ), 127 | NormalizeUtterance(), 128 | ExpandDims()] 129 | ) 130 | 131 | def get_video_transform(self, split='test'): 132 | crop_size = (88, 88) 133 | (mean, std) = (0.421, 0.165) 134 | 135 | return Compose([ 136 | Normalize(0.0, 255.0), 137 | CenterCrop(crop_size) if split=='test' else RandomCrop(crop_size), 138 | Identity() if split=='test' else HorizontalFlip(0.5), 139 | Normalize(mean,std), 140 | Identity() if split=='test' else TimeMask(max_mask_length=15), 141 | Identity() if split=='test' else CutoutHole(min_hole_length=22, max_hole_length=44) 142 | ]) 143 | 144 | def load_data(self, vid_path, aud_path): 145 | vid = self.load_video(vid_path) 146 | aud = self.load_audio(aud_path) 147 | aud = aud[:int(len(vid) / 25 * 16000)] 148 | if len(aud) < int(len(vid) / 25 * 16000): 149 | aud = np.concatenate([aud, np.zeros(int(len(vid) / 25 * 16000) - len(aud)), 0]) 150 | vid = self.transform_vid(vid) 151 | aud = self.transform_aud(aud) 152 | return vid, aud 153 | 154 | def load_video(self, data_filename): 155 | frames = [] 156 | cap = cv2.VideoCapture(data_filename) 157 | while(cap.isOpened()): 158 | ret, frame = cap.read() # BGR 159 | if ret: 160 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 161 | frames.append(frame) 162 | else: 163 | break 164 | cap.release() 165 | return np.array(frames) 166 | 167 | def load_audio(self, audio_filename, specified_sr=16000, int_16=False): 168 | """load_audio. 169 | 170 | :param audio_filename: str, the filename for an audio waveform. 171 | :param specified_sr: int, expected sampling rate, the default value is 16KHz. 172 | :param int_16: boolean, return 16-bit PCM if set it as True. 173 | """ 174 | try: 175 | if audio_filename.endswith('npy'): 176 | audio = np.load(audio_filename) 177 | elif audio_filename.endswith('npz'): 178 | audio = np.load(audio_filename)['data'] 179 | else: 180 | audio, sr = librosa.load(audio_filename, sr=None) 181 | audio = librosa.resample(audio, sr, specified_sr) if sr != specified_sr else audio 182 | except IOError: 183 | sys.exit() 184 | 185 | if int_16 and audio.dtype == np.float32: 186 | audio = ((audio - 1.) * (65535. / 2.) + 32767.).astype(np.int16) 187 | audio = np.array(np.clip(np.round(audio), -2 ** 15, 2 ** 15 - 1), dtype=np.int16) 188 | if not int_16 and audio.dtype == np.int16: 189 | audio = ((audio - 32767.) * 2 / 65535. 
+ 1).astype(np.float32) 190 | return audio 191 | 192 | def collate_fn(self, batch): 193 | vid_len, aud_len, text_len = [], [], [] 194 | for data in batch: 195 | if data is not None: 196 | vid_len.append(len(data[0])) 197 | aud_len.append(len(data[1])) 198 | text_len.append(len(data[2])) 199 | 200 | max_vid_len = max(vid_len) 201 | max_aud_len = max(aud_len) 202 | max_text_len = max(text_len) 203 | 204 | padded_vid = [] 205 | padded_aud = [] 206 | padded_text = [] 207 | 208 | for i, data in enumerate(batch): 209 | if data is not None: 210 | vid, aud, text = data 211 | padded_vid.append(torch.cat([torch.tensor(vid), torch.zeros([max_vid_len - len(vid), 88, 88])], 0)) 212 | padded_aud.append(torch.cat([torch.tensor(aud), torch.zeros([max_aud_len - len(aud), 1])], 0)) 213 | padded_text.append(torch.cat([torch.tensor(text), torch.ones([max_text_len - len(text)]) * -1], 0)) 214 | 215 | vid = torch.stack(padded_vid, 0).float() 216 | aud = torch.stack(padded_aud, 0).float() 217 | text = torch.stack(padded_text, 0).long() 218 | vid_len = torch.IntTensor(vid_len) 219 | aud_len = torch.IntTensor(aud_len) 220 | return vid, aud, vid_len, aud_len, text 221 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/Bert_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | BERT layers from the huggingface implementation 3 | (https://github.com/huggingface/transformers) 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import logging 21 | import math 22 | 23 | import torch 24 | from torch import nn 25 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def gelu(x): 32 | """Implementation of the gelu activation function. 
33 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 34 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 35 | Also see https://arxiv.org/abs/1606.08415 36 | """ 37 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 38 | 39 | 40 | def swish(x): 41 | return x * torch.sigmoid(x) 42 | 43 | 44 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 45 | 46 | 47 | class GELU(nn.Module): 48 | def forward(self, input_): 49 | output = gelu(input_) 50 | return output 51 | 52 | 53 | class BertSelfAttention(nn.Module): 54 | def __init__(self, config): 55 | super(BertSelfAttention, self).__init__() 56 | if config.hidden_size % config.num_attention_heads != 0: 57 | raise ValueError( 58 | "The hidden size (%d) is not a multiple of the number of attention " 59 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 60 | self.num_attention_heads = config.num_attention_heads 61 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 62 | self.all_head_size = self.num_attention_heads * self.attention_head_size 63 | 64 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 65 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 66 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 67 | 68 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 69 | 70 | def transpose_for_scores(self, x): 71 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 72 | x = x.view(*new_x_shape) 73 | return x.permute(0, 2, 1, 3) 74 | 75 | def forward(self, hidden_states, attention_mask): 76 | mixed_query_layer = self.query(hidden_states) 77 | mixed_key_layer = self.key(hidden_states) 78 | mixed_value_layer = self.value(hidden_states) 79 | 80 | query_layer = self.transpose_for_scores(mixed_query_layer) 81 | key_layer = self.transpose_for_scores(mixed_key_layer) 82 | value_layer = self.transpose_for_scores(mixed_value_layer) 83 | 84 | # Take the dot product between "query" and "key" to get the raw attention scores. 85 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 86 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 87 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 88 | attention_scores = attention_scores + attention_mask 89 | 90 | # Normalize the attention scores to probabilities. 91 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 92 | 93 | # This is actually dropping out entire tokens to attend to, which might 94 | # seem a bit unusual, but is taken from the original Transformer paper. 
95 | attention_probs = self.dropout(attention_probs) 96 | 97 | context_layer = torch.matmul(attention_probs, value_layer) 98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 100 | context_layer = context_layer.view(*new_context_layer_shape) 101 | return context_layer 102 | 103 | 104 | class BertSelfOutput(nn.Module): 105 | def __init__(self, config): 106 | super(BertSelfOutput, self).__init__() 107 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 108 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 109 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 110 | 111 | def forward(self, hidden_states, input_tensor): 112 | hidden_states = self.dense(hidden_states) 113 | hidden_states = self.dropout(hidden_states) 114 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 115 | return hidden_states 116 | 117 | 118 | class BertAttention(nn.Module): 119 | def __init__(self, config): 120 | super(BertAttention, self).__init__() 121 | self.self = BertSelfAttention(config) 122 | self.output = BertSelfOutput(config) 123 | 124 | def forward(self, input_tensor, attention_mask): 125 | self_output = self.self(input_tensor, attention_mask) 126 | attention_output = self.output(self_output, input_tensor) 127 | return attention_output 128 | 129 | 130 | class BertIntermediate(nn.Module): 131 | def __init__(self, config): 132 | super(BertIntermediate, self).__init__() 133 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 134 | if isinstance(config.hidden_act, str): 135 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 136 | else: 137 | self.intermediate_act_fn = config.hidden_act 138 | 139 | def forward(self, hidden_states): 140 | hidden_states = self.dense(hidden_states) 141 | hidden_states = self.intermediate_act_fn(hidden_states) 142 | return hidden_states 143 | 144 | 145 | class BertOutput(nn.Module): 146 | def __init__(self, config): 147 | super(BertOutput, self).__init__() 148 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 149 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 150 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 151 | 152 | def forward(self, hidden_states, input_tensor): 153 | hidden_states = self.dense(hidden_states) 154 | hidden_states = self.dropout(hidden_states) 155 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 156 | return hidden_states 157 | 158 | 159 | class BertLayer(nn.Module): 160 | def __init__(self, config): 161 | super(BertLayer, self).__init__() 162 | self.attention = BertAttention(config) 163 | self.intermediate = BertIntermediate(config) 164 | self.output = BertOutput(config) 165 | 166 | def forward(self, hidden_states, attention_mask): 167 | attention_output = self.attention(hidden_states, attention_mask) 168 | intermediate_output = self.intermediate(attention_output) 169 | layer_output = self.output(intermediate_output, attention_output) 170 | return layer_output 171 | 172 | 173 | class BertPooler(nn.Module): 174 | def __init__(self, config): 175 | super(BertPooler, self).__init__() 176 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 177 | self.activation = nn.Tanh() 178 | 179 | def forward(self, hidden_states): 180 | # We "pool" the model by simply taking the hidden state corresponding 181 | # to the first token. 
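        # (i.e. the [CLS] token); its hidden state is then passed through a Linear layer with a Tanh activation.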
182 | first_token_tensor = hidden_states[:, 0] 183 | pooled_output = self.dense(first_token_tensor) 184 | pooled_output = self.activation(pooled_output) 185 | return pooled_output 186 | 187 | 188 | class BertPredictionHeadTransform(nn.Module): 189 | def __init__(self, config): 190 | super(BertPredictionHeadTransform, self).__init__() 191 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 192 | if isinstance(config.hidden_act, str): 193 | self.transform_act_fn = ACT2FN[config.hidden_act] 194 | else: 195 | self.transform_act_fn = config.hidden_act 196 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 197 | 198 | def forward(self, hidden_states): 199 | hidden_states = self.dense(hidden_states) 200 | hidden_states = self.transform_act_fn(hidden_states) 201 | hidden_states = self.LayerNorm(hidden_states) 202 | return hidden_states 203 | 204 | 205 | class BertLMPredictionHead(nn.Module): 206 | def __init__(self, config, bert_model_embedding_weights): 207 | super(BertLMPredictionHead, self).__init__() 208 | self.transform = BertPredictionHeadTransform(config) 209 | 210 | # The output weights are the same as the input embeddings, but there is 211 | # an output-only bias for each token. 212 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 213 | bert_model_embedding_weights.size(0), 214 | bias=False) 215 | self.decoder.weight = bert_model_embedding_weights 216 | self.bias = nn.Parameter( 217 | torch.zeros(bert_model_embedding_weights.size(0))) 218 | 219 | def forward(self, hidden_states): 220 | hidden_states = self.transform(hidden_states) 221 | hidden_states = self.decoder(hidden_states) + self.bias 222 | return hidden_states 223 | 224 | 225 | class BertOnlyMLMHead(nn.Module): 226 | def __init__(self, config, bert_model_embedding_weights): 227 | super(BertOnlyMLMHead, self).__init__() 228 | self.predictions = BertLMPredictionHead(config, 229 | bert_model_embedding_weights) 230 | 231 | def forward(self, sequence_output): 232 | prediction_scores = self.predictions(sequence_output) 233 | return prediction_scores 234 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/transformer/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Decoder definition.""" 8 | 9 | from typing import Any 10 | from typing import List 11 | from typing import Tuple 12 | 13 | import torch 14 | 15 | from espnet.nets.pytorch_backend.nets_utils import rename_state_dict 16 | from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention 17 | from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer 18 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 19 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 20 | from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask 21 | from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( 22 | PositionwiseFeedForward, # noqa: H301 23 | ) 24 | from espnet.nets.pytorch_backend.transformer.repeat import repeat 25 | from espnet.nets.scorer_interface import BatchScorerInterface 26 | 27 | 28 | def _pre_hook( 29 | state_dict, 30 | prefix, 31 | local_metadata, 32 | strict, 33 | missing_keys, 34 | unexpected_keys, 35 | error_msgs, 36 | ): 37 | # 
https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563 38 | rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict) 39 | 40 | 41 | class Decoder(BatchScorerInterface, torch.nn.Module): 42 | """Transfomer decoder module. 43 | 44 | :param int odim: output dim 45 | :param int attention_dim: dimention of attention 46 | :param int attention_heads: the number of heads of multi head attention 47 | :param int linear_units: the number of units of position-wise feed forward 48 | :param int num_blocks: the number of decoder blocks 49 | :param float dropout_rate: dropout rate 50 | :param float attention_dropout_rate: dropout rate for attention 51 | :param str or torch.nn.Module input_layer: input layer type 52 | :param bool use_output_layer: whether to use output layer 53 | :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding 54 | :param bool normalize_before: whether to use layer_norm before the first block 55 | :param bool concat_after: whether to concat attention layer's input and output 56 | if True, additional linear will be applied. 57 | i.e. x -> x + linear(concat(x, att(x))) 58 | if False, no additional linear will be applied. i.e. x -> x + att(x) 59 | """ 60 | 61 | def __init__( 62 | self, 63 | odim, 64 | attention_dim=256, 65 | attention_heads=4, 66 | linear_units=2048, 67 | num_blocks=6, 68 | dropout_rate=0.1, 69 | positional_dropout_rate=0.1, 70 | self_attention_dropout_rate=0.0, 71 | src_attention_dropout_rate=0.0, 72 | input_layer="embed", 73 | use_output_layer=True, 74 | pos_enc_class=PositionalEncoding, 75 | normalize_before=True, 76 | concat_after=False, 77 | ): 78 | """Construct an Decoder object.""" 79 | torch.nn.Module.__init__(self) 80 | self._register_load_state_dict_pre_hook(_pre_hook) 81 | if input_layer == "embed": 82 | self.embed = torch.nn.Sequential( 83 | torch.nn.Embedding(odim, attention_dim), 84 | pos_enc_class(attention_dim, positional_dropout_rate), 85 | ) 86 | elif input_layer == "linear": 87 | self.embed = torch.nn.Sequential( 88 | torch.nn.Linear(odim, attention_dim), 89 | torch.nn.LayerNorm(attention_dim), 90 | torch.nn.Dropout(dropout_rate), 91 | torch.nn.ReLU(), 92 | pos_enc_class(attention_dim, positional_dropout_rate), 93 | ) 94 | elif isinstance(input_layer, torch.nn.Module): 95 | self.embed = torch.nn.Sequential( 96 | input_layer, pos_enc_class(attention_dim, positional_dropout_rate) 97 | ) 98 | else: 99 | raise NotImplementedError("only `embed` or torch.nn.Module is supported.") 100 | self.normalize_before = normalize_before 101 | self.decoders = repeat( 102 | num_blocks, 103 | lambda: DecoderLayer( 104 | attention_dim, 105 | MultiHeadedAttention( 106 | attention_heads, attention_dim, self_attention_dropout_rate 107 | ), 108 | MultiHeadedAttention( 109 | attention_heads, attention_dim, src_attention_dropout_rate 110 | ), 111 | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), 112 | dropout_rate, 113 | normalize_before, 114 | concat_after, 115 | ), 116 | ) 117 | if self.normalize_before: 118 | self.after_norm = LayerNorm(attention_dim) 119 | if use_output_layer: 120 | self.output_layer = torch.nn.Linear(attention_dim, odim) 121 | else: 122 | self.output_layer = None 123 | 124 | def forward(self, tgt, tgt_mask, memory, memory_mask): 125 | """Forward decoder. 
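        Embeds the target tokens, runs them through the stacked decoder blocks while attending to the encoder memory, and returns the pre-softmax token scores together with the target mask.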
126 | :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) 127 | if input_layer == "embed" 128 | input tensor (batch, maxlen_out, #mels) 129 | in the other cases 130 | :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out) 131 | dtype=torch.uint8 in PyTorch 1.2- 132 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 133 | :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat) 134 | :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in) 135 | dtype=torch.uint8 in PyTorch 1.2- 136 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 137 | :return x: decoded token score before softmax (batch, maxlen_out, token) 138 | if use_output_layer is True, 139 | final block outputs (batch, maxlen_out, attention_dim) 140 | in the other cases 141 | :rtype: torch.Tensor 142 | :return tgt_mask: score mask before softmax (batch, maxlen_out) 143 | :rtype: torch.Tensor 144 | """ 145 | x = self.embed(tgt) 146 | x, tgt_mask, memory, memory_mask = self.decoders( 147 | x, tgt_mask, memory, memory_mask 148 | ) 149 | if self.normalize_before: 150 | x = self.after_norm(x) 151 | if self.output_layer is not None: 152 | x = self.output_layer(x) 153 | return x, tgt_mask 154 | 155 | def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): 156 | """Forward one step. 157 | :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) 158 | :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out) 159 | dtype=torch.uint8 in PyTorch 1.2- 160 | dtype=torch.bool in PyTorch 1.2+ (include 1.2) 161 | :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat) 162 | :param List[torch.Tensor] cache: 163 | cached output list of (batch, max_time_out-1, size) 164 | :return y, cache: NN output value and cache per `self.decoders`. 165 | `y.shape` is (batch, maxlen_out, token) 166 | :rtype: Tuple[torch.Tensor, List[torch.Tensor]] 167 | """ 168 | x = self.embed(tgt) 169 | if cache is None: 170 | cache = [None] * len(self.decoders) 171 | new_cache = [] 172 | for c, decoder in zip(cache, self.decoders): 173 | x, tgt_mask, memory, memory_mask = decoder( 174 | x, tgt_mask, memory, memory_mask, cache=c 175 | ) 176 | new_cache.append(x) 177 | 178 | if self.normalize_before: 179 | y = self.after_norm(x[:, -1]) 180 | else: 181 | y = x[:, -1] 182 | if self.output_layer is not None: 183 | y = torch.log_softmax(self.output_layer(y), dim=-1) 184 | 185 | return y, new_cache 186 | 187 | # beam search API (see ScorerInterface) 188 | def score(self, ys, state, x): 189 | """Score.""" 190 | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) 191 | logp, state = self.forward_one_step( 192 | ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state 193 | ) 194 | return logp.squeeze(0), state 195 | 196 | # batch beam search API (see BatchScorerInterface) 197 | def batch_score( 198 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 199 | ) -> Tuple[torch.Tensor, List[Any]]: 200 | """Score new token batch (required). 201 | Args: 202 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 203 | states (List[Any]): Scorer states for prefix tokens. 204 | xs (torch.Tensor): 205 | The encoder feature that generates ys (n_batch, xlen, n_feat). 206 | Returns: 207 | tuple[torch.Tensor, List[Any]]: Tuple of 208 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 209 | and next state list for ys. 
210 | """ 211 | # merge states 212 | n_batch = len(ys) 213 | n_layers = len(self.decoders) 214 | if states[0] is None: 215 | batch_state = None 216 | else: 217 | # transpose state of [batch, layer] into [layer, batch] 218 | batch_state = [ 219 | torch.stack([states[b][l] for b in range(n_batch)]) 220 | for l in range(n_layers) 221 | ] 222 | 223 | # batch decoding 224 | ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) 225 | logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state) 226 | 227 | # transpose state of [layer, batch] into [batch, layer] 228 | state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)] 229 | return logp, state_list 230 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audio-Visual Speech Recognition (AVSR) - AVRelScore, VCAFE 2 | This repository contains the PyTorch implementation of the following papers: 3 | > **Watch or Listen: Robust Audio-Visual Speech Recognition with Visual Corruption Modeling and Reliability Scoring (CVPR2023) - AVRelScore**
4 | > Joanna Hong\*, Minsu Kim\*, Jeongsoo Choi, and Yong Man Ro (\*Equal contribution)
5 | > \[[Paper](https://arxiv.org/abs/2303.08536)\] \[[Demo Video](https://github.com/joannahong/AV-RelScore/tree/main/demo_video)\]

6 | > **Visual Context-driven Audio Feature Enhancement for Robust End-to-End Audio-Visual Speech Recognition (Interspeech 2022) - VCAFE**
7 | > Joanna Hong\*, Minsu Kim\*, and Yong Man Ro (\*Equal contribution) 8 | > \[[Paper](https://arxiv.org/abs/2207.06020)\] 9 | 10 |
11 | 12 | ## Requirements 13 | - python 3.8 14 | - pytorch 1.8 ~ 1.9 15 | - torchvision 16 | - torchaudio 17 | - ffmpeg 18 | - av 19 | - tensorboard 20 | - scikit-image 21 | - opencv-python 22 | - pillow 23 | - librosa 24 | - scipy 25 | - albumentations 26 | 27 | ## Preparation 28 | ### Dataset Download 29 | LRS2/LRS3 dataset can be downloaded from the below link. 30 | - https://www.robots.ox.ac.uk/~vgg/data/lip_reading/ 31 | 32 | ### Landmark Download 33 | For data preprocessing, download the landmark of LRS2 and LRS3 from the [repository](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages#Model-Zoo). 34 | (Landmarks for "VSR for multiple languages models") 35 | 36 | ### Occlusion Data Download 37 | For visual corruption modeling, download `coco_object.7z` from the [repository](https://github.com/kennyvoo/face-occlusion-generation). 38 | 39 | Unzip and put the files at 40 | ``` 41 | ./occlusion_patch/object_image_sr 42 | ./occlusion_patch/object_mask_x4 43 | ``` 44 | 45 | ### Babble Noise Download 46 | For audio corruption modeling, download babble noise file from [here](https://drive.google.com/file/d/15CSWCYz12CIsgFeDT139CiCc545jhbyK/view?usp=sharing). 47 | 48 | put the file at 49 | ``` 50 | ./src/data/babbleNoise_resample_16K.npy 51 | ``` 52 | 53 | ### Pre-trained Frontends 54 | For initializing visual frontend and audio frontend, please download the pre-trained models from the [repository](https://github.com/mpc001/Lipreading_using_Temporal_Convolutional_Networks#model-zoo). (resnet18_dctcn_audio/resnet18_dctcn_video) 55 | 56 | Put the .tar file at 57 | ``` 58 | ./checkpoints/frontend/lrw_resnet18_dctcn_audio.pth.tar 59 | ./checkpoints/frontend/lrw_resnet18_dctcn_video.pth.tar 60 | ``` 61 | 62 | ### Preprocessing 63 | After download the dataset and landmark, we 1) align and crop the lip centered video, 2) extract audio, 3) obtain aligned landmark. 64 | We suppose the data directory is constructed as 65 | ``` 66 | LRS2 67 | ├── main 68 | | ├── * 69 | | | └── *.mp4 70 | | | └── *.txt 71 | ├── pretrain 72 | | ├── * 73 | | | └── *.mp4 74 | | | └── *.txt 75 | ``` 76 | 77 | ``` 78 | LRS3 79 | ├── trainval 80 | | ├── * 81 | | | └── *.mp4 82 | | | └── *.txt 83 | ├── pretrain 84 | | ├── * 85 | | | └── *.mp4 86 | | | └── *.txt 87 | ├── test 88 | | ├── * 89 | | | └── *.mp4 90 | | | └── *.txt 91 | ``` 92 | 93 | Run preprocessing with the following commands: 94 | ```shell 95 | # For LRS2 96 | python preprocessing.py \ 97 | --data_path '/path_to/LRS2' \ 98 | --data_type LRS2 \ 99 | --landmark_path '/path_to/LRS2_landmarks' \ 100 | --save_path '/path_to/LRS2_processed' 101 | ``` 102 | ```shell 103 | # For LRS3 104 | python preprocessing.py \ 105 | --data_path '/path_to/LRS3' \ 106 | --data_type LRS3 \ 107 | --landmark_path '/path_to/LRS3_landmarks' \ 108 | --save_path '/path_to/LRS3_processed' 109 | ``` 110 | 111 | ## Training the Model 112 | Basically, you can choice model architecture with the parameter `architecture`.
113 | There are three options for the `architecture`: `AVRelScore`, `VCAFE`, `Conformer`.
114 | To train the model, run following command: 115 | 116 | ```shell 117 | # AVRelScore: Distributed training example using 2 GPUs on LRS2 (nproc_per_node should have the same number with gpus) 118 | python -m torch.distributed.launch --nproc_per_node=2 \ 119 | train.py \ 120 | --data_path '/path_to/LRS2_processed' \ 121 | --data_type LRS2 \ 122 | --split_file ./src/data/LRS2/0_600.txt \ 123 | --model_conf ./src/models/model.json \ 124 | --checkpoint_dir 'enter_the_path_to_save' \ 125 | --v_frontend_checkpoint ./checkpoints/frontend/lrw_resnet18_dctcn_video.pth.tar \ 126 | --a_frontend_checkpoint ./checkpoints/frontend/lrw_resnet18_dctcn_audio.pth.tar \ 127 | --wandb_project 'wandb_project_name' \ 128 | --batch_size 4 \ 129 | --update_frequency 1 \ 130 | --epochs 200 \ 131 | --eval_step 5000 \ 132 | --visual_corruption \ 133 | --architecture AVRelScore \ 134 | --distributed \ 135 | --gpu 0,1 136 | ``` 137 | 138 | ```shell 139 | # VCAFE: 1 GPU training example on LRS3 140 | python train.py \ 141 | --data_path '/path_to/LRS3_processed' \ 142 | --data_type LRS3 \ 143 | --split_file ./src/data/LRS3/0_600.txt \ 144 | --model_conf ./src/models/model.json \ 145 | --checkpoint_dir 'enter_the_path_to_save' \ 146 | --v_frontend_checkpoint ./checkpoints/frontend/lrw_resnet18_dctcn_video.pth.tar \ 147 | --a_frontend_checkpoint ./checkpoints/frontend/lrw_resnet18_dctcn_audio.pth.tar \ 148 | --wandb_project 'wandb_project_name' \ 149 | --batch_size 4 \ 150 | --update_frequency 1 \ 151 | --epochs 200 \ 152 | --eval_step 5000 \ 153 | --visual_corruption \ 154 | --architecture VCAFE \ 155 | --gpu 0 156 | ``` 157 | 158 | Descriptions of training parameters are as follows: 159 | - `--data_path`: Preprocessed Dataset location (LRS2 or LRS3) 160 | - `--data_type`: Choose to train on LRS2 or LRS3 161 | - `--split_file`: train and validation file lists (you can do curriculum learning by changing the split_file, 0_100.txt consists of files with frames between 0 to 100; training directly on 0_600.txt is also not too bad.) 162 | - `--checkpoint_dir`: directory for saving checkpoints 163 | - `--checkpoint`: saved checkpoint where the training is resumed from 164 | - `--model_conf`: model_configuration 165 | - `--wandb_project`: if want to use wandb, please set the project name here. 166 | - `--batch_size`: batch size 167 | - `--update_frequency`: update_frquency, if you use too small batch_size increase update_frequency. Training batch_size = batch_size * udpate_frequency 168 | - `--epochs`: number of epochs 169 | - `--tot_iters`: if set, the train is finished at the total iterations set 170 | - `--eval_step`: every step for performing evaluation 171 | - `--fast_validate`: if set, validation is performed for a subset of validation data 172 | - `--visual_corruption`: if set, we apply visual corruption modeling during training 173 | - `--architecture`: choose which architecture will be trained. (options: AVRelScore, VCAFE, Conformer) 174 | - `--gpu`: gpu number for training 175 | - `--distributed`: if set, distributed training is performed 176 | - Refer to `train.py` for the other training parameters 177 | 178 | ### check the training logs 179 | ```shell 180 | tensorboard --logdir='./runs/logs to watch' --host='ip address of the server' 181 | ``` 182 | The tensorboard shows the training and validation loss, evaluation metrics. 183 | Also, if you set `wandb_project`, you can check wandb log. 
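You can also quickly sanity-check the preprocessed data and the dataloader before running long trainings. The following is a minimal, illustrative sketch (it is not one of the provided scripts); it assumes the `AVDataset` class and its `collate_fn` from `src/data/dataset.py`, a preprocessed LRS2 directory, and the babble noise file placed as described above. The paths are placeholders.

```python
# Minimal dataloader sanity check (illustrative only).
import torch
from src.data.dataset import AVDataset

dataset = AVDataset(
    data_path='/path_to/LRS2_processed',     # preprocessed data root
    split_file='./src/data/LRS2/0_600.txt',  # same split file used for training
    mode='train',
    data_type='LRS2',
    visual_corruption=False,                 # keep False unless the occlusion patches are in place
)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn
)

# collate_fn pads each modality and also returns the original lengths
vid, aud, vid_len, aud_len, text = next(iter(loader))
print(vid.shape)   # (B, T_video, 88, 88)
print(aud.shape)   # (B, T_audio, 1)
print(text.shape)  # (B, L_text), padded with -1
```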
184 | 
185 | ## Testing the Model
186 | To test the model, run the following command:
187 | ```shell
188 | # AVRelScore: test example on LRS2
189 | python test.py \
190 | --data_path '/path_to/LRS2_processed' \
191 | --data_type LRS2 \
192 | --model_conf ./src/models/model.json \
193 | --split_file ./src/data/LRS2/test.ref \
194 | --checkpoint 'enter_the_checkpoint_path' \
195 | --architecture AVRelScore \
196 | --results_path './test_results.txt' \
197 | --rnnlm ./checkpoints/LM/model.pth \
198 | --rnnlm_conf ./checkpoints/LM/model.json \
199 | --beam_size 40 \
200 | --ctc_weight 0.1 \
201 | --lm_weight 0.5 \
202 | --gpu 0
203 | ```
204 | 
205 | Descriptions of testing parameters are as follows:
206 | - `--data_path`: preprocessed dataset location (LRS2 or LRS3)
207 | - `--data_type`: choose whether to test on LRS2 or LRS3
208 | - `--split_file`: set to test.ref (./src/data/LRS2/test.ref or ./src/data/LRS3/test.ref)
209 | - `--checkpoint`: model checkpoint for testing
210 | - `--model_conf`: model configuration file
211 | - `--architecture`: choose which architecture to test (options: AVRelScore, VCAFE, Conformer)
212 | - `--gpu`: gpu number for testing
213 | - `--rnnlm`: language model checkpoint
214 | - `--rnnlm_conf`: language model configuration
215 | - `--beam_size`: beam size
216 | - `--ctc_weight`: CTC weight for joint CTC/attention decoding
217 | - `--lm_weight`: language model weight for decoding
218 | - Refer to `test.py` for the other parameters
219 | 
220 | 
221 | ## Pre-trained model checkpoints
222 | We release the pre-trained AVSR models (VCAFE and AVRelScore) trained on the LRS2 and LRS3 databases. (The WERs below are obtained with `beam_size`: 40, `ctc_weight`: 0.1, `lm_weight`: 0.5)
223 | 
224 | | Model | Dataset | WER |
225 | |:-------------------:|:-------------------:|:--------:|
226 | |VCAFE|LRS2 | [4.459](https://drive.google.com/file/d/1509xCvaMgMwtfJxE04zPgWRgiYgy3fFD/view?usp=sharing) |
227 | |VCAFE|LRS3 | [2.821](https://drive.google.com/file/d/1539Td4FaBCta-1KOCKEz1QovAlzx-DkO/view?usp=sharing) |
228 | |AVRelScore|LRS2 | [4.129](https://drive.google.com/file/d/157fsllT8pldpuCFtVTRYUNd1yK5AoE-b/view?usp=sharing) |
229 | |AVRelScore|LRS3 | [2.770](https://drive.google.com/file/d/159gCUDJAKDIYchS5iNXQ1M5pHnPGIIDd/view?usp=sharing) |
230 | 
231 | You can find the pre-trained Language Model in the following [repository](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages#Model-Zoo).
232 | Put the language model at
233 | ```
234 | ./checkpoints/LM/model.pth
235 | ./checkpoints/LM/model.json
236 | ```
237 | 
238 | ## Testing under Audio-Visual Noise Conditions
239 | Please refer to the following [repository](https://github.com/joannahong/AV-RelScore) for creating the audio-visually corrupted dataset.
240 | 
241 | ## Acknowledgment
242 | The code is based on the following two repositories: [ESPNet](https://github.com/espnet/espnet) and [VSR for Multiple Languages](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages).
243 | 
244 | ## Citation
245 | If you find this work useful in your research, please cite the papers:
246 | ```
247 | @inproceedings{hong2023watch,
248 | title={Watch or Listen: Robust Audio-Visual Speech Recognition with Visual Corruption Modeling and Reliability Scoring},
249 | author={Hong, Joanna and Kim, Minsu and Choi, Jeongsoo and Ro, Yong Man},
250 | booktitle={Proc.
CVPR}, 251 | pages={18783--18794}, 252 | year={2023} 253 | } 254 | ``` 255 | ``` 256 | @inproceedings{hong2022visual, 257 | title={Visual Context-driven Audio Feature Enhancement for Robust End-to-End Audio-Visual Speech Recognition}, 258 | author={Hong, Joanna and Kim, Minsu and Ro, Yong Man}, 259 | booktitle={Proc. Interspeech}, 260 | pages={2838--2842}, 261 | year={2022}, 262 | organization={ISCA} 263 | } 264 | ``` 265 | 266 | -------------------------------------------------------------------------------- /espnet/nets/pytorch_backend/ctc.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | import logging 3 | 4 | import numpy as np 5 | import six 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from espnet.nets.pytorch_backend.nets_utils import to_device 10 | 11 | 12 | class CTC(torch.nn.Module): 13 | """CTC module 14 | 15 | :param int odim: dimension of outputs 16 | :param int eprojs: number of encoder projection units 17 | :param float dropout_rate: dropout rate (0.0 ~ 1.0) 18 | :param str ctc_type: builtin or warpctc 19 | :param bool reduce: reduce the CTC loss into a scalar 20 | """ 21 | 22 | def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True): 23 | super().__init__() 24 | self.dropout_rate = dropout_rate 25 | self.loss = None 26 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 27 | self.dropout = torch.nn.Dropout(dropout_rate) 28 | self.probs = None # for visualization 29 | 30 | # In case of Pytorch >= 1.7.0, CTC will be always builtin 31 | self.ctc_type = ( 32 | ctc_type 33 | if LooseVersion(torch.__version__) < LooseVersion("1.7.0") 34 | else "builtin" 35 | ) 36 | 37 | if self.ctc_type == "builtin": 38 | reduction_type = "sum" if reduce else "none" 39 | self.ctc_loss = torch.nn.CTCLoss( 40 | reduction=reduction_type, zero_infinity=True 41 | ) 42 | elif self.ctc_type == "cudnnctc": 43 | reduction_type = "sum" if reduce else "none" 44 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 45 | elif self.ctc_type == "warpctc": 46 | import warpctc_pytorch as warp_ctc 47 | 48 | self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) 49 | elif self.ctc_type == "gtnctc": 50 | from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction 51 | 52 | self.ctc_loss = GTNCTCLossFunction.apply 53 | else: 54 | raise ValueError( 55 | 'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type) 56 | ) 57 | 58 | self.ignore_id = -1 59 | self.reduce = reduce 60 | 61 | def loss_fn(self, th_pred, th_target, th_ilen, th_olen): 62 | if self.ctc_type in ["builtin", "cudnnctc"]: 63 | th_pred = th_pred.log_softmax(2) 64 | # Use the deterministic CuDNN implementation of CTC loss to avoid 65 | # [issue#17798](https://github.com/pytorch/pytorch/issues/17798) 66 | with torch.backends.cudnn.flags(deterministic=True): 67 | loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) 68 | # Batch-size average 69 | loss = loss / th_pred.size(1) 70 | return loss 71 | elif self.ctc_type == "warpctc": 72 | return self.ctc_loss(th_pred, th_target, th_ilen, th_olen) 73 | elif self.ctc_type == "gtnctc": 74 | targets = [t.tolist() for t in th_target] 75 | log_probs = torch.nn.functional.log_softmax(th_pred, dim=2) 76 | return self.ctc_loss(log_probs, targets, th_ilen, 0, "none") 77 | else: 78 | raise NotImplementedError 79 | 80 | def forward(self, hs_pad, hlens, ys_pad): 81 | """CTC forward 82 | 83 | :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, 
D) 84 | :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) 85 | :param torch.Tensor ys_pad: 86 | batch of padded character id sequence tensor (B, Lmax) 87 | :return: ctc loss value 88 | :rtype: torch.Tensor 89 | """ 90 | # TODO(kan-bayashi): need to make more smart way 91 | ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys 92 | 93 | # zero padding for hs 94 | ys_hat = self.ctc_lo(self.dropout(hs_pad)) 95 | if self.ctc_type != "gtnctc": 96 | ys_hat = ys_hat.transpose(0, 1) 97 | 98 | if self.ctc_type == "builtin": 99 | olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys])) 100 | hlens = hlens.long() 101 | ys_pad = torch.cat(ys) # without this the code breaks for asr_mix 102 | self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens) 103 | else: 104 | self.loss = None 105 | hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32)) 106 | olens = torch.from_numpy( 107 | np.fromiter((x.size(0) for x in ys), dtype=np.int32) 108 | ) 109 | # zero padding for ys 110 | ys_true = torch.cat(ys).cpu().int() # batch x olen 111 | # get ctc loss 112 | # expected shape of seqLength x batchSize x alphabet_size 113 | dtype = ys_hat.dtype 114 | if self.ctc_type == "warpctc" or dtype == torch.float16: 115 | # warpctc only supports float32 116 | # torch.ctc does not support float16 (#1751) 117 | ys_hat = ys_hat.to(dtype=torch.float32) 118 | if self.ctc_type == "cudnnctc": 119 | # use GPU when using the cuDNN implementation 120 | ys_true = to_device(hs_pad, ys_true) 121 | if self.ctc_type == "gtnctc": 122 | # keep as list for gtn 123 | ys_true = ys 124 | self.loss = to_device( 125 | hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens) 126 | ).to(dtype=dtype) 127 | 128 | # get length info 129 | logging.info( 130 | self.__class__.__name__ 131 | + " input lengths: " 132 | + "".join(str(hlens).split("\n")) 133 | ) 134 | logging.info( 135 | self.__class__.__name__ 136 | + " output lengths: " 137 | + "".join(str(olens).split("\n")) 138 | ) 139 | 140 | if self.reduce: 141 | # NOTE: sum() is needed to keep consistency 142 | # since warpctc return as tensor w/ shape (1,) 143 | # but builtin return as tensor w/o shape (scalar). 144 | self.loss = self.loss.sum() 145 | logging.info("ctc loss:" + str(float(self.loss))) 146 | 147 | return self.loss 148 | 149 | def softmax(self, hs_pad): 150 | """softmax of frame activations 151 | 152 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 153 | :return: log softmax applied 3d tensor (B, Tmax, odim) 154 | :rtype: torch.Tensor 155 | """ 156 | self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2) 157 | return self.probs 158 | 159 | def log_softmax(self, hs_pad): 160 | """log_softmax of frame activations 161 | 162 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 163 | :return: log softmax applied 3d tensor (B, Tmax, odim) 164 | :rtype: torch.Tensor 165 | """ 166 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 167 | 168 | def argmax(self, hs_pad): 169 | """argmax of frame activations 170 | 171 | :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 172 | :return: argmax applied 2d tensor (B, Tmax) 173 | :rtype: torch.Tensor 174 | """ 175 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 176 | 177 | def forced_align(self, h, y, blank_id=0): 178 | """forced alignment. 
179 | 180 | :param torch.Tensor h: hidden state sequence, 2d tensor (T, D) 181 | :param torch.Tensor y: id sequence tensor 1d tensor (L) 182 | :param int y: blank symbol index 183 | :return: best alignment results 184 | :rtype: list 185 | """ 186 | 187 | def interpolate_blank(label, blank_id=0): 188 | """Insert blank token between every two label token.""" 189 | label = np.expand_dims(label, 1) 190 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 191 | label = np.concatenate([blanks, label], axis=1) 192 | label = label.reshape(-1) 193 | label = np.append(label, label[0]) 194 | return label 195 | 196 | lpz = self.log_softmax(h) 197 | lpz = lpz.squeeze(0) 198 | 199 | y_int = interpolate_blank(y, blank_id) 200 | 201 | logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero 202 | state_path = ( 203 | np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1 204 | ) # state path 205 | 206 | logdelta[0, 0] = lpz[0][y_int[0]] 207 | logdelta[0, 1] = lpz[0][y_int[1]] 208 | 209 | for t in six.moves.range(1, lpz.size(0)): 210 | for s in six.moves.range(len(y_int)): 211 | if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]: 212 | candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]]) 213 | prev_state = [s, s - 1] 214 | else: 215 | candidates = np.array( 216 | [ 217 | logdelta[t - 1, s], 218 | logdelta[t - 1, s - 1], 219 | logdelta[t - 1, s - 2], 220 | ] 221 | ) 222 | prev_state = [s, s - 1, s - 2] 223 | logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]] 224 | state_path[t, s] = prev_state[np.argmax(candidates)] 225 | 226 | state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16) 227 | 228 | candidates = np.array( 229 | [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]] 230 | ) 231 | prev_state = [len(y_int) - 1, len(y_int) - 2] 232 | state_seq[-1] = prev_state[np.argmax(candidates)] 233 | for t in six.moves.range(lpz.size(0) - 2, -1, -1): 234 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 235 | 236 | output_state_seq = [] 237 | for t in six.moves.range(0, lpz.size(0)): 238 | output_state_seq.append(y_int[state_seq[t, 0]]) 239 | 240 | return output_state_seq 241 | 242 | 243 | def ctc_for(args, odim, reduce=True): 244 | """Returns the CTC module for the given args and output dimension 245 | 246 | :param Namespace args: the program args 247 | :param int odim : The output dimension 248 | :param bool reduce : return the CTC loss in a scalar 249 | :return: the corresponding CTC module 250 | """ 251 | num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility 252 | if num_encs == 1: 253 | # compatible with single encoder asr mode 254 | return CTC( 255 | odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce 256 | ) 257 | elif num_encs >= 1: 258 | ctcs_list = torch.nn.ModuleList() 259 | if args.share_ctc: 260 | # use dropout_rate of the first encoder 261 | ctc = CTC( 262 | odim, 263 | args.eprojs, 264 | args.dropout_rate[0], 265 | ctc_type=args.ctc_type, 266 | reduce=reduce, 267 | ) 268 | ctcs_list.append(ctc) 269 | else: 270 | for idx in range(num_encs): 271 | ctc = CTC( 272 | odim, 273 | args.eprojs, 274 | args.dropout_rate[idx], 275 | ctc_type=args.ctc_type, 276 | reduce=reduce, 277 | ) 278 | ctcs_list.append(ctc) 279 | return ctcs_list 280 | else: 281 | raise ValueError( 282 | "Number of encoders needs to be more than one. {}".format(num_encs) 283 | ) 284 | --------------------------------------------------------------------------------
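
As a quick sanity check of the `CTC` module defined in `espnet/nets/pytorch_backend/ctc.py` above, the following minimal sketch builds a `CTC` instance and computes a loss on random tensors. It assumes the `espnet` package in this repository is importable (in particular `espnet.nets.pytorch_backend.nets_utils.to_device`, which `ctc.py` imports); the dimensions and label values are arbitrary toy choices, not values used by the actual models.

```python
import torch
from espnet.nets.pytorch_backend.ctc import CTC

# Toy dimensions: 40 output symbols (index 0 is the CTC blank), 256-dim encoder projections.
ctc = CTC(odim=40, eprojs=256, dropout_rate=0.1, ctc_type="builtin")

hs_pad = torch.randn(2, 50, 256)                     # (B, Tmax, eprojs) padded encoder states
hlens = torch.tensor([50, 45])                       # valid length of each encoder sequence
ys_pad = torch.full((2, 10), -1, dtype=torch.long)   # -1 is the module's ignore_id padding value
ys_pad[0, :8] = torch.randint(1, 40, (8,))           # label indices avoid 0 (the blank symbol)
ys_pad[1, :6] = torch.randint(1, 40, (6,))

loss = ctc(hs_pad, hlens, ys_pad)                    # scalar CTC loss, averaged over the batch
greedy = ctc.argmax(hs_pad)                          # (B, Tmax) frame-wise argmax symbol indices
```
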