├── .gitignore ├── fs_plugins ├── scripts │ ├── substitute_target.py │ ├── average_checkpoints.py │ └── prep_mustc_data.py ├── __init__.py ├── optim │ ├── __init__.py │ └── radam.py ├── models │ ├── __init__.py │ └── transducer │ │ ├── __init__.py │ │ ├── transducer.py │ │ ├── attention_transducer.py │ │ └── transducer_config.py ├── tasks │ ├── __init__.py │ └── transducer_speech_to_text.py ├── criterions │ ├── __init__.py │ ├── transducer_loss.py │ └── transducer_loss_asr.py ├── utils.py ├── modules │ ├── rand_pos.py │ ├── transducer_monotonic_multihead_attention.py │ ├── audio_encoder.py │ ├── monotonic_transformer_layer.py │ ├── attention_transducer_decoder.py │ ├── monotonic_transducer_decoder.py │ └── multihead_attention_patched.py ├── datasets │ └── transducer_speech_to_text_dataset.py └── agents │ ├── transducer_agent.py │ ├── transducer_agent_v2.py │ ├── attention_transducer_agent.py │ └── monotonic_transducer_agent.py ├── .gitmodules └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /fs_plugins/scripts/substitute_target.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fairseq"] 2 | path = fairseq 3 | url = git@github.com:facebookresearch/fairseq.git 4 | -------------------------------------------------------------------------------- /fs_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * 4 | 5 | print("fairseq plugins loaded...") -------------------------------------------------------------------------------- /fs_plugins/optim/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.optim." + file_name) -------------------------------------------------------------------------------- /fs_plugins/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.models." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.tasks." 
+ file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.criterions." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/models/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("fs_plugins.models.transducer." + file_name) 9 | -------------------------------------------------------------------------------- /fs_plugins/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Dict, Optional, Union 3 | from fairseq.models import FairseqDecoder, FairseqEncoder 4 | from fairseq.file_io import PathManager 5 | from fairseq import utils, checkpoint_utils 6 | from collections import OrderedDict 7 | 8 | def load_pretrained_component_from_model_modified( 9 | component: Union[FairseqEncoder, FairseqDecoder], 10 | checkpoint: str, 11 | strict: bool = True, 12 | ): 13 | """ 14 | Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the 15 | provided `component` object. If state_dict fails to load, there may be a 16 | mismatch in the architecture of the corresponding `component` found in the 17 | `checkpoint` file. 18 | """ 19 | if not PathManager.exists(checkpoint): 20 | raise IOError("Model file not found: {}".format(checkpoint)) 21 | state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint) 22 | if isinstance(component, FairseqEncoder): 23 | component_type = "encoder" 24 | elif isinstance(component, FairseqDecoder): 25 | component_type = "decoder" 26 | else: 27 | raise ValueError( 28 | "component to load must be either a FairseqEncoder or " 29 | "FairseqDecoder. Loading other component types are not supported." 30 | ) 31 | component_state_dict = OrderedDict() 32 | for key in state["model"].keys(): 33 | if key.startswith(component_type): 34 | # encoder.input_layers.0.0.weight --> input_layers.0.0.weight 35 | component_subkey = key[len(component_type) + 1 :] 36 | component_state_dict[component_subkey] = state["model"][key] 37 | keys_info = component.load_state_dict(component_state_dict, strict=strict) 38 | return component, keys_info -------------------------------------------------------------------------------- /fs_plugins/optim/radam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
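The `__init__.py` files above follow fairseq's user-plugin convention: every non-underscore module in each package is imported so that its `@register_*` decorators run. A minimal sketch (not part of this repo) of how that import is usually triggered, assuming fairseq is installed and with an illustrative path to `fs_plugins`:

```python
# Importing the plugin package runs fs_plugins/__init__.py and the auto-import
# loops above, which executes every @register_model / @register_criterion /
# @register_task decorator before training starts. This is equivalent to
# passing --user-dir to the fairseq CLI tools.
from argparse import Namespace
from fairseq import utils

utils.import_user_module(Namespace(user_dir="/path/to/fs_plugins"))  # illustrative path
```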
7 | 8 | import math 9 | import types 10 | 11 | import torch 12 | import torch.optim 13 | # from ipdb import set_trace 14 | from fairseq.optim import FairseqOptimizer, register_optimizer 15 | 16 | # from tensorboardX import SummaryWriter 17 | # # writer = SummaryWriter(logdir='./log/wmt/') 18 | # writer = SummaryWriter(logdir='./log/ada/') 19 | # iter_idx = 0 20 | 21 | @register_optimizer('radam') 22 | class FairseqRAdam(FairseqOptimizer): 23 | 24 | def __init__(self, args, params): 25 | super().__init__(args) 26 | 27 | self._optimizer = RAdam(params, **self.optimizer_config) 28 | 29 | @staticmethod 30 | def add_args(parser): 31 | """Add optimizer-specific arguments to the parser.""" 32 | # fmt: off 33 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 34 | help='betas for Adam optimizer') 35 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 36 | help='epsilon for Adam optimizer') 37 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 38 | help='weight decay') 39 | 40 | @property 41 | def optimizer_config(self): 42 | """ 43 | Return a kwarg dictionary that will be used to override optimizer 44 | args stored in checkpoints. This allows us to load a checkpoint and 45 | resume training using a different set of optimizer args, e.g., with a 46 | different learning rate. 47 | """ 48 | return { 49 | 'lr': self.args.lr[0], 50 | 'betas': eval(self.args.adam_betas), 51 | 'eps': self.args.adam_eps, 52 | 'weight_decay': self.args.weight_decay, 53 | } 54 | 55 | class RAdam(torch.optim.Optimizer): 56 | 57 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 58 | weight_decay=0, amsgrad=False): 59 | defaults = dict(lr=lr, betas=betas, eps=eps, 60 | weight_decay=weight_decay, amsgrad=amsgrad) 61 | 62 | super(RAdam, self).__init__(params, defaults) 63 | 64 | @property 65 | def supports_memory_efficient_fp16(self): 66 | return True 67 | 68 | def step(self, closure=None): 69 | """Performs a single optimization step. 70 | 71 | Arguments: 72 | closure (callable, optional): A closure that reevaluates the model 73 | and returns the loss. 
74 | """ 75 | 76 | loss = None 77 | if closure is not None: 78 | loss = closure() 79 | 80 | for group in self.param_groups: 81 | 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data.float() 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | amsgrad = group['amsgrad'] 89 | 90 | p_data_fp32 = p.data.float() 91 | 92 | state = self.state[p] 93 | 94 | if len(state) == 0: 95 | state['step'] = 0 96 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 97 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 98 | else: 99 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 100 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 101 | 102 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 103 | beta1, beta2 = group['betas'] 104 | 105 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 106 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 107 | 108 | state['step'] += 1 109 | 110 | beta2_t = beta2 ** state['step'] 111 | N_sma_max = 2 / (1 - beta2) - 1 112 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 113 | 114 | if group['weight_decay'] != 0: 115 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 116 | 117 | if N_sma >= 5: 118 | step_size = group['lr'] * math.sqrt((1 - beta2_t ) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) * (N_sma_max) / N_sma / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 119 | denom = exp_avg_sq.sqrt().add_(group['eps']) 120 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 121 | # p.data.copy_(p_data_fp32) 122 | else: 123 | step_size = group['lr'] / (1 - beta1 ** state['step']) 124 | # p_data_fp32.add_(-step_size, exp_avg) 125 | 126 | p.data.copy_(p_data_fp32) 127 | 128 | return loss -------------------------------------------------------------------------------- /fs_plugins/modules/rand_pos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import fairseq 5 | from fairseq import utils 6 | from typing import Any, Optional 7 | import math 8 | from fairseq.modules.positional_embedding import ( 9 | SinusoidalPositionalEmbedding, 10 | LearnedPositionalEmbedding 11 | ) 12 | 13 | def PositionalEmbedding( 14 | num_embeddings: int, 15 | embedding_dim: int, 16 | padding_idx: int, 17 | rand_max:int = 0, 18 | learned: bool = False, 19 | ): 20 | if rand_max > 0: 21 | assert learned ==False, "rand_start with learned positional embedding not implemented" 22 | m= RandStartSinPositionalEmbedding( 23 | embedding_dim, 24 | padding_idx, 25 | rand_max = rand_max, 26 | init_size = num_embeddings + padding_idx + 1 27 | ) 28 | elif learned: 29 | if padding_idx is not None: 30 | num_embeddings = num_embeddings + padding_idx + 1 31 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) 32 | nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) 33 | if padding_idx is not None: 34 | nn.init.constant_(m.weight[padding_idx], 0) 35 | else: 36 | m = SinusoidalPositionalEmbedding( 37 | embedding_dim, 38 | padding_idx, 39 | init_size=num_embeddings + padding_idx + 1, 40 | ) 41 | return m 42 | 43 | class RandStartSinPositionalEmbedding(nn.Module): 44 | """ 45 | positional embedding starts index from a random number during training, 46 | which is more robust for speech encoder compared to starts from 0 47 | """ 48 | def __init__(self, embedding_dim, padding_idx, rand_max=1, init_size=1024): 
49 | super().__init__() 50 | self.embedding_dim = embedding_dim 51 | self.padding_idx = padding_idx 52 | self.weights = RandStartSinPositionalEmbedding.get_embedding( 53 | init_size, embedding_dim, padding_idx 54 | ) 55 | 56 | self.register_buffer("_float_tensor", torch.FloatTensor(1)) 57 | self.rand_max= rand_max 58 | self.max_positions = int(1e5) 59 | self.onnx_trace = False 60 | 61 | @staticmethod 62 | def get_embedding( 63 | num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None 64 | ): 65 | """Build sinusoidal embeddings. 66 | 67 | This matches the implementation in tensor2tensor, but differs slightly 68 | from the description in Section 3.5 of "Attention Is All You Need". 69 | """ 70 | half_dim = embedding_dim // 2 71 | emb = math.log(10000) / (half_dim - 1) 72 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 73 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( 74 | 1 75 | ) * emb.unsqueeze(0) 76 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( 77 | num_embeddings, -1 78 | ) 79 | if embedding_dim % 2 == 1: 80 | # zero pad 81 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 82 | if padding_idx is not None: 83 | emb[padding_idx, :] = 0 84 | return emb 85 | 86 | def forward( 87 | self, 88 | input, 89 | incremental_state: Optional[Any] = None, 90 | timestep: Optional[Tensor] = None, 91 | positions: Optional[Any] = None, 92 | ): 93 | """Input is expected to be of size [bsz x seqlen].""" 94 | bspair = torch.onnx.operators.shape_as_tensor(input) 95 | bsz, seq_len = bspair[0], bspair[1] 96 | max_pos = self.padding_idx + 1 + seq_len 97 | if self.weights is None or max_pos > self.weights.size(0): 98 | # recompute/expand embeddings if needed 99 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 100 | max_pos, self.embedding_dim, self.padding_idx 101 | ) 102 | self.weights = self.weights.to(self._float_tensor) 103 | 104 | if incremental_state is not None: 105 | # positions is the same for every token when decoding a single step 106 | pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len 107 | if self.onnx_trace: 108 | return ( 109 | self.weights.index_select(index=self.padding_idx + pos, dim=0) 110 | .unsqueeze(1) 111 | .repeat(bsz, 1, 1) 112 | ) 113 | return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) 114 | 115 | positions = utils.make_positions( 116 | input, self.padding_idx, onnx_trace=self.onnx_trace 117 | ) 118 | if self.training: 119 | rand_max= min(self.weights.shape[0] - max_pos, self.rand_max) 120 | bsz = positions.shape[0] 121 | rand_pos = (torch.rand(bsz)*rand_max).long().to(positions.device) 122 | positions += rand_pos.unsqueeze(1) 123 | 124 | return ( 125 | self.weights.index_select(0, positions.view(-1)) 126 | .view(bsz, seq_len, -1) 127 | .detach() 128 | ) -------------------------------------------------------------------------------- /fs_plugins/modules/transducer_monotonic_multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
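The key idea in `RandStartSinPositionalEmbedding` above is that, during training, every sentence's position indices are shifted by a random offset in `[0, rand_max)` before the sinusoidal table is indexed, so the speech encoder sees absolute positions it would otherwise only encounter for very long inputs. A minimal sketch of that shift, with hypothetical shapes:

```python
import torch

bsz, seq_len, rand_max = 2, 5, 10
# consecutive position indices for a batch with no padding (start value is illustrative)
positions = torch.arange(1, seq_len + 1).unsqueeze(0).repeat(bsz, 1)  # (bsz, seq_len)
offset = (torch.rand(bsz) * rand_max).long()   # one random start per sentence
positions = positions + offset.unsqueeze(1)    # whole sequence shifted by the same offset
```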
5 | 6 | import math 7 | 8 | import torch 9 | from torch import Tensor 10 | import torch.nn as nn 11 | 12 | from examples.simultaneous_translation.utils.monotonic_attention import expected_soft_attention 13 | #from ..utils.monotonic_attention import expected_soft_attention 14 | from fairseq.modules import MultiheadAttention 15 | 16 | from typing import Dict, Optional 17 | 18 | 19 | class MonotonicAttention(MultiheadAttention): 20 | """ 21 | Abstract class of monotonic attentions 22 | """ 23 | 24 | def __init__(self, cfg): 25 | super().__init__( 26 | embed_dim=cfg.decoder.embed_dim, 27 | num_heads=cfg.decoder.attention_heads, 28 | kdim=cfg.encoder_embed_dim, 29 | vdim=cfg.encoder_embed_dim, 30 | dropout=cfg.attention_dropout, 31 | encoder_decoder_attention=True, 32 | ) 33 | 34 | self.eps = 1e-6 35 | 36 | 37 | def energy_from_qk( 38 | self, 39 | query: Tensor, 40 | key: Tensor, 41 | ): 42 | """ 43 | Compute energy from query and key 44 | q_tensor size: bsz, tgt_len, emb_dim 45 | k_tensor size: bsz, src_len, emb_dim 46 | """ 47 | 48 | length, bsz, _ = query.size() 49 | q = self.q_proj.forward(query) 50 | q = ( 51 | q.contiguous() 52 | .view(length, bsz * self.num_heads, self.head_dim) 53 | .transpose(0, 1) 54 | ) 55 | q = q * self.scaling 56 | length, bsz, _ = key.size() 57 | k = self.k_proj.forward(key) 58 | k = ( 59 | k.contiguous() 60 | .view(length, bsz * self.num_heads, self.head_dim) 61 | .transpose(0, 1) 62 | ) 63 | 64 | energy = torch.bmm(q, k.transpose(1, 2)) 65 | 66 | return energy 67 | 68 | 69 | 70 | def monotonic_attention_process_infer( 71 | self, 72 | query: Optional[Tensor], 73 | key: Optional[Tensor], 74 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], 75 | ): 76 | """ 77 | Monotonic attention at inference time 78 | Notice that this function is designed for simuleval not sequence_generator 79 | """ 80 | assert query is not None 81 | assert key is not None 82 | 83 | soft_energy = self.energy_from_qk( 84 | query, 85 | key, 86 | ) 87 | # TODO: beta_mask 88 | # soft_energy = soft_energy.masked_fill(beta_mask, -float("inf")) 89 | beta = torch.nn.functional.softmax(soft_energy, dim=-1) 90 | 91 | 92 | return beta 93 | 94 | def monotonic_attention_process_train( 95 | self, 96 | query: Optional[Tensor], 97 | key: Optional[Tensor], 98 | key_padding_mask: Optional[Tensor] = None, 99 | posterior: Optional[Tensor] = None, # posterior: B × (U+1) × T 100 | ): 101 | """ 102 | Calculating monotonic attention process for training 103 | Including: 104 | expected soft attention: beta 105 | """ 106 | assert query is not None 107 | assert key is not None 108 | 109 | 110 | soft_energy = self.energy_from_qk( 111 | query, 112 | key, 113 | ) 114 | 115 | beta = expected_soft_attention( 116 | posterior, 117 | soft_energy, 118 | padding_mask=key_padding_mask, 119 | chunk_size=None, 120 | eps=self.eps, 121 | ) 122 | 123 | return beta 124 | 125 | def forward( 126 | self, 127 | query: Optional[Tensor], 128 | key: Optional[Tensor], 129 | value: Optional[Tensor], 130 | key_padding_mask: Optional[Tensor] = None, 131 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 132 | need_weights: bool = True, 133 | static_kv: bool = False, 134 | attn_mask: Optional[Tensor] = None, 135 | need_head_weights: bool = False, 136 | posterior: Optional[Tensor] = None, 137 | ): 138 | """ 139 | query: tgt_len, bsz, embed_dim 140 | key: src_len, bsz, embed_dim 141 | value: src_len, bsz, embed_dim 142 | """ 143 | 144 | assert attn_mask is None 145 | assert query is not None 146 | assert 
key is not None 147 | assert value is not None 148 | 149 | if need_head_weights: 150 | need_weights = True 151 | 152 | tgt_len, bsz, embed_dim = query.size() 153 | src_len = value.size(0) 154 | 155 | if key_padding_mask is not None: 156 | assert not key_padding_mask[:, 0].any(), ( 157 | "Only right padding is supported." 158 | ) 159 | key_padding_mask = ( 160 | key_padding_mask 161 | .unsqueeze(1) 162 | .expand([bsz, self.num_heads, src_len]) 163 | .contiguous() 164 | .view(-1, src_len) 165 | ) 166 | 167 | if incremental_state is not None: 168 | # Inference 169 | beta = self.monotonic_attention_process_infer(query, key, incremental_state) 170 | 171 | else: 172 | # Train 173 | beta = self.monotonic_attention_process_train(query, key, key_padding_mask, posterior.unsqueeze(1).expand([bsz, self.num_heads, -1, -1]).contiguous().view(-1, posterior.size(1), posterior.size(2))) 174 | 175 | v = self.v_proj(value) 176 | length, bsz, _ = v.size() 177 | v = ( 178 | v.contiguous() 179 | .view(length, bsz * self.num_heads, self.head_dim) 180 | .transpose(0, 1) 181 | ) 182 | 183 | attn = torch.bmm(beta.type_as(v), v) 184 | 185 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 186 | 187 | attn = self.out_proj(attn) 188 | attn_weights = None 189 | if need_weights: 190 | attn_weights = beta.view( 191 | bsz, self.num_heads, tgt_len, src_len 192 | ).transpose(1, 0) 193 | if not need_head_weights: 194 | # average attention weights over heads 195 | attn_weights = attn_weights.mean(dim=0) 196 | 197 | return attn, attn_weights 198 | 199 | 200 | -------------------------------------------------------------------------------- /fs_plugins/modules/audio_encoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Dict, List, Optional, Tuple 3 | import torch 4 | import os 5 | from torch import Tensor 6 | import torch.nn as nn 7 | from fairseq import options, utils, checkpoint_utils 8 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 9 | from fairseq.models import FairseqEncoder 10 | from fairseq.data import Dictionary 11 | 12 | from fairseq.modules import ( 13 | FairseqDropout, 14 | LayerNorm,Fp32LayerNorm, 15 | LayerDropModuleList, 16 | TransformerEncoderLayer, 17 | ) 18 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper 19 | 20 | #from fairseq.modules import PositionalEmbedding 21 | from .rand_pos import PositionalEmbedding 22 | from .audio_convs import get_conv 23 | 24 | 25 | class AudioTransformerEncoder(FairseqEncoder): 26 | def __init__(self, args): 27 | super().__init__(Dictionary()) 28 | self.dropout_module = FairseqDropout( 29 | args.dropout, module_name=self.__class__.__name__ 30 | ) 31 | self.encoder_layerdrop = args.encoder_layerdrop 32 | embed_dim = args.encoder_embed_dim 33 | self.padding_idx = self.dictionary.pad() 34 | self.max_source_positions = args.max_source_positions 35 | self.conv_layers = self.build_conv_layers(args) 36 | 37 | self.embed_positions = ( 38 | PositionalEmbedding( 39 | args.max_source_positions, 40 | embed_dim, 41 | self.padding_idx, 42 | rand_max = args.rand_pos_encoder, 43 | learned=args.encoder_learned_pos, 44 | ) 45 | if not args.no_audio_positional_embeddings 46 | else None 47 | ) 48 | self.layernorm_embedding = LayerNorm(embed_dim) 49 | 50 | 51 | if self.encoder_layerdrop > 0.0: 52 | self.layers = LayerDropModuleList(p=self.encoder_layerdrop) 53 | else: 54 | self.layers = nn.ModuleList([]) 55 | self.layers.extend( 56 | 
[self.build_encoder_layer(args) for i in range(args.encoder_layers)] 57 | ) 58 | self.num_layers = len(self.layers) 59 | 60 | if args.encoder_normalize_before: 61 | self.layer_norm = LayerNorm(embed_dim) 62 | else: 63 | self.layer_norm = None 64 | 65 | def build_conv_layers(self, args): 66 | return get_conv(args.conv_type)(args.input_feat_per_channel, args.encoder_embed_dim) 67 | 68 | def build_encoder_layer(self, args): 69 | layer = TransformerEncoderLayer(args) 70 | if getattr(args, "checkpoint_activations", False): 71 | layer = checkpoint_wrapper(layer) 72 | return layer 73 | 74 | def forward( 75 | self, 76 | fbank:torch.Tensor, 77 | fbk_lengths:torch.Tensor, 78 | **kwargs 79 | ): 80 | # x is already TBC 81 | x, padding_mask = self.conv_layers(fbank, fbk_lengths) 82 | 83 | fake_tokens = padding_mask.long() 84 | # layernorm after garbage convs 85 | x = self.layernorm_embedding(x) 86 | if self.embed_positions is not None: 87 | x = x + self.embed_positions(fake_tokens).transpose(0,1) 88 | 89 | # encoder layers 90 | for layer in self.layers: 91 | x = layer(x, padding_mask) 92 | 93 | if self.layer_norm is not None: 94 | x = self.layer_norm(x) 95 | 96 | return { 97 | "encoder_out": [x], # T x B x C 98 | "encoder_padding_mask": [padding_mask], # B x T 99 | "encoder_embedding": [], # B x T x C 100 | "encoder_states": [], # List[T x B x C] 101 | "src_tokens": [], 102 | "src_lengths": [], 103 | "dec1_state":[], # reserved for joint decoding 104 | "dec1_padding_mask":[], 105 | } 106 | 107 | @torch.jit.export 108 | def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): 109 | 110 | if len(encoder_out["encoder_out"]) == 0: 111 | new_encoder_out = [] 112 | else: 113 | new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] 114 | if len(encoder_out["encoder_padding_mask"]) == 0: 115 | new_encoder_padding_mask = [] 116 | else: 117 | new_encoder_padding_mask = [ 118 | encoder_out["encoder_padding_mask"][0].index_select(0, new_order) 119 | ] 120 | if len(encoder_out["encoder_embedding"]) == 0: 121 | new_encoder_embedding = [] 122 | else: 123 | new_encoder_embedding = [ 124 | encoder_out["encoder_embedding"][0].index_select(0, new_order) 125 | ] 126 | 127 | if len(encoder_out["src_tokens"]) == 0: 128 | src_tokens = [] 129 | else: 130 | src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] 131 | 132 | if len(encoder_out["src_lengths"]) == 0: 133 | src_lengths = [] 134 | else: 135 | src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] 136 | 137 | if len(encoder_out["dec1_state"]) ==0: 138 | dec1_states=[] 139 | else: 140 | dec1_states = [(encoder_out["dec1_state"][0]).index_select(1, new_order)] 141 | 142 | if len(encoder_out["dec1_padding_mask"]) == 0: 143 | dec1_padding_mask= [] 144 | else: 145 | dec1_padding_mask = [encoder_out["dec1_padding_mask"][0].index_select(0,new_order)] 146 | 147 | encoder_states = encoder_out["encoder_states"] 148 | if len(encoder_states) > 0: 149 | for idx, state in enumerate(encoder_states): 150 | encoder_states[idx] = state.index_select(1, new_order) 151 | 152 | return { 153 | "encoder_out": new_encoder_out, # T x B x C 154 | "encoder_padding_mask": new_encoder_padding_mask, # B x T 155 | "encoder_embedding": new_encoder_embedding, # B x T x C 156 | "encoder_states": encoder_states, # List[T x B x C] 157 | "src_tokens": src_tokens, # B x T 158 | "src_lengths": src_lengths, # B x 1 159 | "dec1_state":dec1_states, # TxBxC 160 | "dec1_padding_mask":dec1_padding_mask, # BxT 161 | } 162 | 163 | 
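`AudioTransformerEncoder.forward` and `reorder_encoder_out` above both use fairseq's dictionary-of-lists convention for encoder output: empty lists mark absent fields, and consumers index element 0 when a field is present. A minimal sketch of that structure, with assumed shapes:

```python
import torch

T, B, C = 50, 2, 256
encoder_out = {
    "encoder_out": [torch.randn(T, B, C)],                          # T x B x C hidden states
    "encoder_padding_mask": [torch.zeros(B, T, dtype=torch.bool)],  # True at padded frames
    "encoder_embedding": [], "encoder_states": [],
    "src_tokens": [], "src_lengths": [],
    "dec1_state": [], "dec1_padding_mask": [],                      # reserved for joint decoding
}
states = encoder_out["encoder_out"][0] if encoder_out["encoder_out"] else None
```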
def max_positions(self): 164 | """Maximum input length supported by the encoder.""" 165 | if self.embed_positions is None: 166 | return self.max_source_positions 167 | return min(self.max_source_positions, self.embed_positions.max_positions) 168 | 169 | -------------------------------------------------------------------------------- /fs_plugins/modules/monotonic_transformer_layer.py: -------------------------------------------------------------------------------- 1 | from fairseq.modules.transformer_layer import TransformerDecoderLayerBase 2 | 3 | from .transducer_monotonic_multihead_attention import MonotonicAttention 4 | 5 | from typing import Dict, Optional, List 6 | 7 | from torch import Tensor 8 | import torch 9 | 10 | class MonotonicTransformerDecoderLayer(TransformerDecoderLayerBase): 11 | def __init__(self, cfg): 12 | super().__init__(cfg) 13 | self.encoder_attn = MonotonicAttention(cfg) 14 | 15 | def forward( 16 | self, 17 | x, 18 | encoder_out: Optional[Tensor] = None, 19 | encoder_padding_mask: Optional[Tensor] = None, 20 | posterior: Optional[Tensor] = None, 21 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 22 | prev_self_attn_state: Optional[List[Tensor]] = None, 23 | prev_attn_state: Optional[List[Tensor]] = None, 24 | self_attn_mask: Optional[Tensor] = None, 25 | self_attn_padding_mask: Optional[Tensor] = None, 26 | need_attn: bool = False, 27 | need_head_weights: bool = False, 28 | ): 29 | """ 30 | Args: 31 | x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` 32 | encoder_padding_mask (ByteTensor, optional): binary 33 | ByteTensor of shape `(batch, src_len)` where padding 34 | elements are indicated by ``1``. 35 | need_attn (bool, optional): return attention weights 36 | need_head_weights (bool, optional): return attention weights 37 | for each head (default: return average over heads). 
38 | 39 | Returns: 40 | encoded output of shape `(seq_len, batch, embed_dim)` 41 | """ 42 | if need_head_weights: 43 | need_attn = True 44 | 45 | residual = x 46 | if self.normalize_before: 47 | x = self.self_attn_layer_norm(x) 48 | if prev_self_attn_state is not None: 49 | prev_key, prev_value = prev_self_attn_state[:2] 50 | saved_state: Dict[str, Optional[Tensor]] = { 51 | "prev_key": prev_key, 52 | "prev_value": prev_value, 53 | } 54 | if len(prev_self_attn_state) >= 3: 55 | saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] 56 | assert incremental_state is not None 57 | self.self_attn._set_input_buffer(incremental_state, saved_state) 58 | _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) 59 | if self.cross_self_attention and not ( 60 | incremental_state is not None 61 | and _self_attn_input_buffer is not None 62 | and "prev_key" in _self_attn_input_buffer 63 | ): 64 | if self_attn_mask is not None: 65 | assert encoder_out is not None 66 | self_attn_mask = torch.cat( 67 | (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 68 | ) 69 | if self_attn_padding_mask is not None: 70 | if encoder_padding_mask is None: 71 | assert encoder_out is not None 72 | encoder_padding_mask = self_attn_padding_mask.new_zeros( 73 | encoder_out.size(1), encoder_out.size(0) 74 | ) 75 | self_attn_padding_mask = torch.cat( 76 | (encoder_padding_mask, self_attn_padding_mask), dim=1 77 | ) 78 | assert encoder_out is not None 79 | y = torch.cat((encoder_out, x), dim=0) 80 | else: 81 | y = x 82 | 83 | x, attn = self.self_attn( 84 | query=x, 85 | key=y, 86 | value=y, 87 | key_padding_mask=self_attn_padding_mask, 88 | incremental_state=incremental_state, 89 | need_weights=False, 90 | attn_mask=self_attn_mask, 91 | ) 92 | x = self.dropout_module(x) 93 | x = self.residual_connection(x, residual) 94 | if not self.normalize_before: 95 | x = self.self_attn_layer_norm(x) 96 | 97 | assert self.encoder_attn is not None 98 | if encoder_out is not None: 99 | residual = x 100 | if self.normalize_before: 101 | x = self.encoder_attn_layer_norm(x) 102 | if prev_attn_state is not None: 103 | prev_key, prev_value = prev_attn_state[:2] 104 | saved_state: Dict[str, Optional[Tensor]] = { 105 | "prev_key": prev_key, 106 | "prev_value": prev_value, 107 | } 108 | if len(prev_attn_state) >= 3: 109 | saved_state["prev_key_padding_mask"] = prev_attn_state[2] 110 | assert incremental_state is not None 111 | self.encoder_attn._set_input_buffer(incremental_state, saved_state) 112 | 113 | x, attn = self.encoder_attn( 114 | query=x, 115 | key=encoder_out, 116 | value=encoder_out, 117 | key_padding_mask=encoder_padding_mask, 118 | incremental_state=incremental_state, 119 | static_kv=False, 120 | need_weights=need_attn or (not self.training and self.need_attn), 121 | need_head_weights=need_head_weights, 122 | posterior=posterior, 123 | ) 124 | x = self.dropout_module(x) 125 | x = self.residual_connection(x, residual) 126 | if not self.normalize_before: 127 | x = self.encoder_attn_layer_norm(x) 128 | 129 | residual = x 130 | if self.normalize_before: 131 | x = self.final_layer_norm(x) 132 | 133 | x = self.activation_fn(self.fc1(x)) 134 | x = self.activation_dropout_module(x) 135 | x = self.fc2(x) 136 | x = self.dropout_module(x) 137 | x = self.residual_connection(x, residual) 138 | if not self.normalize_before: 139 | x = self.final_layer_norm(x) 140 | if self.onnx_trace and incremental_state is not None: 141 | saved_state = self.self_attn._get_input_buffer(incremental_state) 142 | assert 
saved_state is not None 143 | if self_attn_padding_mask is not None: 144 | self_attn_state = [ 145 | saved_state["prev_key"], 146 | saved_state["prev_value"], 147 | saved_state["prev_key_padding_mask"], 148 | ] 149 | else: 150 | self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] 151 | return x, attn, self_attn_state 152 | return x, attn, None -------------------------------------------------------------------------------- /fs_plugins/modules/attention_transducer_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | import torch 3 | 4 | from torch import Tensor 5 | import torch.nn as nn 6 | 7 | from fairseq import utils 8 | from fairseq.models import FairseqIncrementalDecoder 9 | from fairseq.models.transformer import TransformerDecoder 10 | 11 | 12 | 13 | 14 | 15 | class AttentionDecoder(TransformerDecoder): 16 | def __init__(self, args, dictionary, embed_tokens): 17 | super().__init__( 18 | args, dictionary, embed_tokens, no_encoder_attn=False 19 | ) 20 | self.output_projection= None 21 | 22 | def forward( 23 | self, 24 | prev_output_tokens, 25 | encoder_out, 26 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 27 | ): 28 | """ 29 | for transducer, prev_output_tokens should be [bos] concat target 30 | """ 31 | x, extra = self.extract_features( 32 | prev_output_tokens, 33 | encoder_out=encoder_out, 34 | incremental_state=incremental_state, 35 | full_context_alignment=False, 36 | alignment_layer=None, 37 | alignment_heads=None, 38 | ) 39 | return x 40 | 41 | 42 | class AddJointNet(nn.Module): 43 | def __init__( 44 | self, 45 | encoder_dim, 46 | decoder_dim, 47 | hid_dim, 48 | activation="tanh", 49 | downsample=1, 50 | ): 51 | super().__init__() 52 | self.downsample = downsample 53 | self.encoder_proj = nn.Linear(encoder_dim, hid_dim) 54 | self.decoder_proj = nn.Linear(decoder_dim, hid_dim) 55 | self.activation_fn = utils.get_activation_fn(activation) 56 | #self.joint_proj = nn.Linear(hid_dim, hid_dim) 57 | #self.layer_norm = LayerNorm(hid_dim) 58 | if downsample < 1: 59 | raise ValueError("downsample should be more than 1 for add_joint") 60 | 61 | def forward(self, encoder_out:Dict[str, List[Tensor]], decoder_state, padding_idx): 62 | """ 63 | use dimension same as transformer 64 | Args: 65 | encoder_out: "encoder_out": TxBxC 66 | decoder_state: BxUxC 67 | """ 68 | encoder_state = encoder_out["encoder_out"][0] 69 | encoder_state = encoder_state[::self.downsample].contiguous() 70 | encoder_state = encoder_state.transpose(0,1) 71 | 72 | h_enc = self.encoder_proj(encoder_state) 73 | h_dec = self.decoder_proj(decoder_state) 74 | h_joint = h_enc.unsqueeze(2) + h_dec.unsqueeze(1) 75 | h_joint = self.activation_fn(h_joint) 76 | #h_joint = self.joint_proj(h_joint) 77 | #h_joint = self.layer_norm(h_joint) 78 | 79 | fake_src_tokens = (encoder_out["encoder_padding_mask"][0]).long() 80 | fake_src_lengths = fake_src_tokens.ne(padding_idx).sum(dim=-1) 81 | fake_src_lengths = (fake_src_lengths / self.downsample).ceil().long() 82 | 83 | return h_joint, fake_src_lengths 84 | 85 | def infer(self, encoder_state, decoder_state): 86 | """ 87 | use dimension same as transformer 88 | Args: 89 | encoder_out: "encoder_out": C 90 | decoder_state: C 91 | """ 92 | 93 | h_enc = self.encoder_proj(encoder_state) 94 | h_dec = self.decoder_proj(decoder_state) 95 | h_joint = h_enc + h_dec 96 | h_joint = self.activation_fn(h_joint) 97 | #h_joint = self.joint_proj(h_joint) 98 | #h_joint = 
self.layer_norm(h_joint) 99 | 100 | return h_joint 101 | 102 | 103 | 104 | class ConcatJointNet(nn.Module): 105 | def __init__( 106 | self, 107 | encoder_dim, 108 | decoder_dim, 109 | hid_dim, 110 | activation="tanh", 111 | downsample=1, 112 | ) -> None: 113 | super().__init__() 114 | self.fc1 = nn.Linear((encoder_dim+decoder_dim), hid_dim) 115 | self.downsample = downsample 116 | self.activation_fn = utils.get_activation_fn(activation) 117 | if downsample < 1: 118 | raise ValueError("downsample should be more than 1 for concat_joint") 119 | 120 | def forward(self, encoder_out:Dict[str, List[Tensor]], decoder_state, padding_idx): 121 | 122 | encoder_state = encoder_out["encoder_out"][0] 123 | encoder_state = encoder_state[::self.downsample].contiguous() #TODO: downsample 124 | encoder_state = encoder_state.transpose(0,1) 125 | 126 | seq_lens = encoder_state.size(1) 127 | target_lens = decoder_state.size(1) 128 | 129 | encoder_state = encoder_state.unsqueeze(2) 130 | decoder_state = decoder_state.unsqueeze(1) 131 | 132 | encoder_state = encoder_state.expand(-1, -1, target_lens, -1) 133 | decoder_state = decoder_state.expand(-1, seq_lens, -1, -1) 134 | 135 | h_joint = torch.cat((encoder_state, decoder_state), dim=-1) 136 | 137 | h_joint = self.fc1(h_joint) 138 | h_joint = self.activation_fn(h_joint) 139 | 140 | fake_src_tokens = (encoder_out["encoder_padding_mask"][0]).long() 141 | fake_src_lengths = fake_src_tokens.ne(padding_idx).sum(dim=-1) 142 | fake_src_lengths = (fake_src_lengths / self.downsample).ceil().long() 143 | 144 | return h_joint, fake_src_lengths 145 | 146 | 147 | 148 | class AttentionTransducerDecoder(FairseqIncrementalDecoder): 149 | def __init__(self, args, dictionary, embed_tokens): 150 | super().__init__(dictionary) 151 | self.lm = AttentionDecoder(args, dictionary, embed_tokens) 152 | self.output_embed_dim = args.decoder_output_dim 153 | self.out_proj = nn.Linear(args.decoder_output_dim, len(dictionary), bias=False) 154 | if args.share_decoder_input_output_embed: 155 | self.out_proj.weight= embed_tokens.weight 156 | else: 157 | nn.init.normal_( 158 | self.out_proj.weight, mean=0, std=self.output_embed_dim ** -0.5 159 | ) 160 | self.blank_idx= dictionary.blank_index 161 | self.padding_idx = dictionary.pad() 162 | self.downsample = getattr(args, "transducer_downsample", 1) 163 | #self.jointer = ConcatJointNet(args.encoder_embed_dim, args.decoder_output_dim, args.decoder_output_dim, downsample=self.downsample) 164 | self.jointer = AddJointNet(args.encoder_embed_dim, args.decoder_output_dim, args.decoder_output_dim, downsample=self.downsample) 165 | 166 | def forward( 167 | self, 168 | prev_output_tokens:Tensor, 169 | encoder_out:Dict[str, List[Tensor]], 170 | ): 171 | h_lm = self.lm(prev_output_tokens, encoder_out) 172 | 173 | joint_result, fake_src_lengths = self.jointer(encoder_out, h_lm, self.padding_idx) 174 | 175 | joint_result = self.out_proj(joint_result) # it is logits, no logsoftmax performed 176 | 177 | return joint_result, fake_src_lengths 178 | -------------------------------------------------------------------------------- /fs_plugins/criterions/transducer_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
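`AddJointNet` above builds the transducer lattice by projecting encoder frames and decoder states to a common width and broadcast-adding them, yielding one hidden vector per (source frame, target step) pair. A minimal sketch with assumed shapes (the real module additionally handles downsampling and padded lengths):

```python
import torch
import torch.nn as nn

B, T, U, H = 2, 7, 5, 16
enc = torch.randn(B, T, H)                 # encoder states, batch-first for clarity
dec = torch.randn(B, U, H)                 # decoder (prediction network) states
proj_enc, proj_dec = nn.Linear(H, H), nn.Linear(H, H)
lattice = torch.tanh(proj_enc(enc).unsqueeze(2) + proj_dec(dec).unsqueeze(1))  # (B, T, U, H)
```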
5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from fairseq import metrics, utils 11 | from fairseq.criterions import FairseqCriterion, register_criterion 12 | from fairseq.dataclass import FairseqDataclass 13 | from torch import Tensor 14 | 15 | from dataclasses import dataclass, field 16 | 17 | 18 | @dataclass 19 | class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass): 20 | label_smoothing: float = field( 21 | default=0.0, 22 | metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, 23 | ) 24 | 25 | 26 | 27 | @register_criterion("transducer_loss", dataclass=LabelSmoothedDualImitationCriterionConfig) 28 | class LabelSmoothedDualImitationCriterion(FairseqCriterion): 29 | def __init__(self, task, label_smoothing): 30 | super().__init__(task) 31 | self.label_smoothing = label_smoothing 32 | 33 | def _compute_loss( 34 | self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 35 | ): 36 | """ 37 | outputs: batch x len x d_model 38 | targets: batch x len 39 | masks: batch x len 40 | 41 | policy_logprob: if there is some policy 42 | depends on the likelihood score as rewards. 43 | """ 44 | 45 | def mean_ds(x: Tensor, dim=None) -> Tensor: 46 | return ( 47 | x.float().mean().type_as(x) 48 | if dim is None 49 | else x.float().mean(dim).type_as(x) 50 | ) 51 | 52 | if masks is not None: 53 | outputs, targets = outputs[masks], targets[masks] 54 | 55 | if masks is not None and not masks.any(): 56 | nll_loss = torch.tensor(0) 57 | loss = nll_loss 58 | else: 59 | logits = F.log_softmax(outputs, dim=-1) 60 | if targets.dim() == 1: 61 | losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") 62 | 63 | else: # soft-labels 64 | losses = F.kl_div(logits, targets.to(logits.device), reduction="none") 65 | losses = losses.sum(-1) 66 | 67 | nll_loss = mean_ds(losses) 68 | if label_smoothing > 0: 69 | loss = ( 70 | nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing 71 | ) 72 | else: 73 | loss = nll_loss 74 | 75 | loss = loss * factor 76 | return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor} 77 | 78 | def _custom_loss(self, loss, name="loss", factor=1.0): 79 | return {"name": name, "loss": loss, "factor": factor} 80 | 81 | def forward(self, model, sample, reduce=True): 82 | """Compute the loss for the given sample. 
83 | Returns a tuple with three elements: 84 | 1) the loss 85 | 2) the sample size, which is used as the denominator for the gradient 86 | 3) logging outputs to display while training 87 | """ 88 | nsentences, ntokens = sample["nsentences"], sample["ntokens"] 89 | 90 | # B x T 91 | src_tokens, src_lengths = ( 92 | sample["net_input"]["src_tokens"], 93 | sample["net_input"]["src_lengths"], 94 | ) 95 | tgt_tokens, prev_output_tokens = sample["target"], sample["net_input"]["prev_output_tokens"] 96 | 97 | outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens) 98 | losses, nll_loss = [], [] 99 | 100 | for obj in outputs: 101 | if outputs[obj].get("loss", None) is None: 102 | _losses = self._compute_loss( 103 | outputs[obj].get("out"), 104 | outputs[obj].get("tgt"), 105 | outputs[obj].get("mask", None), 106 | outputs[obj].get("ls", 0.0), 107 | name=obj + "-loss", 108 | factor=outputs[obj].get("factor", 1.0), 109 | ) 110 | else: 111 | _losses = self._custom_loss( 112 | outputs[obj].get("loss"), 113 | name=obj + "-loss", 114 | factor=outputs[obj].get("factor", 1.0), 115 | ) 116 | 117 | losses += [_losses] 118 | if outputs[obj].get("nll_loss", False): 119 | nll_loss += [_losses.get("nll_loss", 0.0)] 120 | 121 | loss = sum(l["loss"] for l in losses) 122 | nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0) 123 | 124 | # NOTE: 125 | # we don't need to use sample_size as denominator for the gradient 126 | # here sample_size is just used for logging 127 | sample_size = 1 128 | logging_output = { 129 | "loss": loss.data, 130 | "nll_loss": nll_loss.data, 131 | "ntokens": ntokens, 132 | "nsentences": nsentences, 133 | "sample_size": sample_size, 134 | } 135 | 136 | for l in losses: 137 | logging_output[l["name"]] = ( 138 | utils.item(l["loss"].data / l["factor"]) 139 | if reduce 140 | else l[["loss"]].data / l["factor"] 141 | ) 142 | 143 | return loss, sample_size, logging_output 144 | 145 | @staticmethod 146 | def reduce_metrics(logging_outputs) -> None: 147 | """Aggregate logging outputs from data parallel training.""" 148 | sample_size = utils.item( 149 | sum(log.get("sample_size", 0) for log in logging_outputs) 150 | ) 151 | loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) 152 | nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs)) 153 | 154 | metrics.log_scalar( 155 | "loss", loss / sample_size / math.log(2), sample_size, round=3 156 | ) 157 | metrics.log_scalar( 158 | "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3 159 | ) 160 | metrics.log_derived( 161 | "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) 162 | ) 163 | 164 | for key in logging_outputs[0]: 165 | if key[-5:] == "-loss": 166 | val = sum(log.get(key, 0) for log in logging_outputs) 167 | metrics.log_scalar( 168 | key[:-5], 169 | val / sample_size / math.log(2) if sample_size > 0 else 0.0, 170 | sample_size, 171 | round=3, 172 | ) 173 | 174 | @staticmethod 175 | def logging_outputs_can_be_summed() -> bool: 176 | """ 177 | Whether the logging outputs returned by `forward` can be summed 178 | across workers prior to calling `reduce_metrics`. Setting this 179 | to True will improves distributed training speed. 180 | """ 181 | return True 182 | -------------------------------------------------------------------------------- /fs_plugins/criterions/transducer_loss_asr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
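In `_compute_loss` above, the variable named `logits` actually holds log-probabilities, so the label-smoothed loss mixes the target NLL with the mean negative log-probability over the whole vocabulary. A minimal, self-contained sketch of that computation (shapes assumed):

```python
import torch
import torch.nn.functional as F

eps = 0.1                                   # label_smoothing
out = torch.randn(8, 100)                   # (tokens, vocab) model outputs
tgt = torch.randint(0, 100, (8,))
logp = F.log_softmax(out, dim=-1)
nll = F.nll_loss(logp, tgt, reduction="mean")
loss = nll * (1 - eps) - logp.mean() * eps  # same form as the criterion above
```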
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from fairseq import metrics, utils 11 | from fairseq.criterions import FairseqCriterion, register_criterion 12 | from fairseq.dataclass import FairseqDataclass 13 | from torch import Tensor 14 | 15 | from dataclasses import dataclass, field 16 | 17 | 18 | @dataclass 19 | class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass): 20 | label_smoothing: float = field( 21 | default=0.0, 22 | metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, 23 | ) 24 | 25 | 26 | 27 | @register_criterion("transducer_loss_asr", dataclass=LabelSmoothedDualImitationCriterionConfig) 28 | class LabelSmoothedDualImitationCriterion(FairseqCriterion): 29 | def __init__(self, task, label_smoothing): 30 | super().__init__(task) 31 | self.label_smoothing = label_smoothing 32 | 33 | def _compute_loss( 34 | self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 35 | ): 36 | """ 37 | outputs: batch x len x d_model 38 | targets: batch x len 39 | masks: batch x len 40 | 41 | policy_logprob: if there is some policy 42 | depends on the likelihood score as rewards. 43 | """ 44 | 45 | def mean_ds(x: Tensor, dim=None) -> Tensor: 46 | return ( 47 | x.float().mean().type_as(x) 48 | if dim is None 49 | else x.float().mean(dim).type_as(x) 50 | ) 51 | 52 | if masks is not None: 53 | outputs, targets = outputs[masks], targets[masks] 54 | 55 | if masks is not None and not masks.any(): 56 | nll_loss = torch.tensor(0) 57 | loss = nll_loss 58 | else: 59 | logits = F.log_softmax(outputs, dim=-1) 60 | if targets.dim() == 1: 61 | losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") 62 | 63 | else: # soft-labels 64 | losses = F.kl_div(logits, targets.to(logits.device), reduction="none") 65 | losses = losses.sum(-1) 66 | 67 | nll_loss = mean_ds(losses) 68 | if label_smoothing > 0: 69 | loss = ( 70 | nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing 71 | ) 72 | else: 73 | loss = nll_loss 74 | 75 | loss = loss * factor 76 | return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor} 77 | 78 | def _custom_loss(self, loss, name="loss", factor=1.0): 79 | return {"name": name, "loss": loss, "factor": factor} 80 | 81 | def forward(self, model, sample, reduce=True): 82 | """Compute the loss for the given sample. 
83 | Returns a tuple with three elements: 84 | 1) the loss 85 | 2) the sample size, which is used as the denominator for the gradient 86 | 3) logging outputs to display while training 87 | """ 88 | nsentences, ntokens = sample["nsentences"], sample["ntokens"] 89 | 90 | # B x T 91 | src_tokens, src_lengths = ( 92 | sample["net_input"]["src_tokens"], 93 | sample["net_input"]["src_lengths"], 94 | ) 95 | tgt_tokens, prev_output_tokens = sample["transcript"], sample["net_input"]["prev_output_tokens_transcript"] 96 | 97 | outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens) 98 | losses, nll_loss = [], [] 99 | 100 | for obj in outputs: 101 | if outputs[obj].get("loss", None) is None: 102 | _losses = self._compute_loss( 103 | outputs[obj].get("out"), 104 | outputs[obj].get("tgt"), 105 | outputs[obj].get("mask", None), 106 | outputs[obj].get("ls", 0.0), 107 | name=obj + "-loss", 108 | factor=outputs[obj].get("factor", 1.0), 109 | ) 110 | else: 111 | _losses = self._custom_loss( 112 | outputs[obj].get("loss"), 113 | name=obj + "-loss", 114 | factor=outputs[obj].get("factor", 1.0), 115 | ) 116 | 117 | losses += [_losses] 118 | if outputs[obj].get("nll_loss", False): 119 | nll_loss += [_losses.get("nll_loss", 0.0)] 120 | 121 | loss = sum(l["loss"] for l in losses) 122 | nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0) 123 | 124 | # NOTE: 125 | # we don't need to use sample_size as denominator for the gradient 126 | # here sample_size is just used for logging 127 | sample_size = 1 128 | logging_output = { 129 | "loss": loss.data, 130 | "nll_loss": nll_loss.data, 131 | "ntokens": ntokens, 132 | "nsentences": nsentences, 133 | "sample_size": sample_size, 134 | } 135 | 136 | for l in losses: 137 | logging_output[l["name"]] = ( 138 | utils.item(l["loss"].data / l["factor"]) 139 | if reduce 140 | else l[["loss"]].data / l["factor"] 141 | ) 142 | 143 | return loss, sample_size, logging_output 144 | 145 | @staticmethod 146 | def reduce_metrics(logging_outputs) -> None: 147 | """Aggregate logging outputs from data parallel training.""" 148 | sample_size = utils.item( 149 | sum(log.get("sample_size", 0) for log in logging_outputs) 150 | ) 151 | loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) 152 | nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs)) 153 | 154 | metrics.log_scalar( 155 | "loss", loss / sample_size / math.log(2), sample_size, round=3 156 | ) 157 | metrics.log_scalar( 158 | "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3 159 | ) 160 | metrics.log_derived( 161 | "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) 162 | ) 163 | 164 | for key in logging_outputs[0]: 165 | if key[-5:] == "-loss": 166 | val = sum(log.get(key, 0) for log in logging_outputs) 167 | metrics.log_scalar( 168 | key[:-5], 169 | val / sample_size / math.log(2) if sample_size > 0 else 0.0, 170 | sample_size, 171 | round=3, 172 | ) 173 | 174 | @staticmethod 175 | def logging_outputs_can_be_summed() -> bool: 176 | """ 177 | Whether the logging outputs returned by `forward` can be summed 178 | across workers prior to calling `reduce_metrics`. Setting this 179 | to True will improves distributed training speed. 
180 | """ 181 | return True 182 | -------------------------------------------------------------------------------- /fs_plugins/scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import collections 9 | import os 10 | import re 11 | 12 | import torch 13 | from fairseq.file_io import PathManager 14 | 15 | 16 | def average_checkpoints(inputs): 17 | """Loads checkpoints from inputs and returns a model with averaged weights. 18 | 19 | Args: 20 | inputs: An iterable of string paths of checkpoints to load from. 21 | 22 | Returns: 23 | A dict of string keys mapping to various values. The 'model' key 24 | from the returned dict should correspond to an OrderedDict mapping 25 | string parameter names to torch Tensors. 26 | """ 27 | params_dict = collections.OrderedDict() 28 | params_keys = None 29 | new_state = None 30 | num_models = len(inputs) 31 | 32 | for fpath in inputs: 33 | with PathManager.open(fpath, "rb") as f: 34 | state = torch.load( 35 | f, 36 | map_location=( 37 | lambda s, _: torch.serialization.default_restore_location(s, "cpu") 38 | ), 39 | ) 40 | # Copies over the settings from the first checkpoint 41 | if new_state is None: 42 | new_state = state 43 | 44 | model_params = state["model"] 45 | 46 | model_params_keys = list(model_params.keys()) 47 | if params_keys is None: 48 | params_keys = model_params_keys 49 | elif params_keys != model_params_keys: 50 | raise KeyError( 51 | "For checkpoint {}, expected list of params: {}, " 52 | "but found: {}".format(f, params_keys, model_params_keys) 53 | ) 54 | 55 | for k in params_keys: 56 | p = model_params[k] 57 | if isinstance(p, torch.HalfTensor): 58 | p = p.float() 59 | if k not in params_dict: 60 | params_dict[k] = p.clone() 61 | # NOTE: clone() is needed in case of p is a shared parameter 62 | else: 63 | params_dict[k] += p 64 | 65 | averaged_params = collections.OrderedDict() 66 | for k, v in params_dict.items(): 67 | averaged_params[k] = v 68 | if averaged_params[k].is_floating_point(): 69 | averaged_params[k].div_(num_models) 70 | else: 71 | averaged_params[k] //= num_models 72 | new_state["model"] = averaged_params 73 | return new_state 74 | 75 | 76 | def last_n_checkpoints(path, n, update_based, upper_bound=None): 77 | # assert len(paths) == 1 78 | # path = paths[0] 79 | if update_based: 80 | pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt") 81 | else: 82 | pt_regexp = re.compile(r"checkpoint(\d+)\.pt") 83 | files = PathManager.ls(path) 84 | 85 | entries = [] 86 | for f in files: 87 | m = pt_regexp.fullmatch(f) 88 | if m is not None: 89 | sort_key = int(m.group(1)) 90 | if upper_bound is None or sort_key <= upper_bound: 91 | entries.append((sort_key, m.group(0))) 92 | if len(entries) < n: 93 | raise Exception( 94 | "Found {} checkpoint files but need at least {}", len(entries), n 95 | ) 96 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 97 | 98 | def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): 99 | """Retrieves all checkpoints found in `path` directory. 100 | 101 | Checkpoints are identified by matching filename to the specified pattern. If 102 | the pattern contains groups, the result will be sorted by the first group in 103 | descending order. 
104 | """ 105 | pt_regexp = re.compile(pattern) 106 | files = PathManager.ls(path) 107 | 108 | entries = [] 109 | for i, f in enumerate(files): 110 | m = pt_regexp.fullmatch(f) 111 | if m is not None: 112 | idx = float(m.group(1)) if len(m.groups()) > 0 else i 113 | entries.append((idx, m.group(0))) 114 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] 115 | 116 | def best_n_checkpoints(paths, n, max_metric, best_checkpoints_metric): 117 | checkpoints = checkpoint_paths( 118 | paths, 119 | pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( 120 | best_checkpoints_metric 121 | ), 122 | ) 123 | 124 | if not max_metric: 125 | checkpoints = checkpoints[::-1] 126 | 127 | if len(checkpoints) < n: 128 | raise RuntimeError(f"num is too large, not enough checkpoints: {str(checkpoints)}") 129 | return checkpoints[:n] 130 | 131 | def main(): 132 | parser = argparse.ArgumentParser( 133 | description="Tool to average the params of input checkpoints to " 134 | "produce a new checkpoint", 135 | ) 136 | # fmt: off 137 | parser.add_argument('--inputs', required=True, nargs='+', 138 | help='Input checkpoint file paths.') 139 | parser.add_argument('--output', required=True, metavar='FILE', 140 | help='Write the new checkpoint containing the averaged weights to this path.') 141 | num_group = parser.add_mutually_exclusive_group() 142 | num_group.add_argument('--num-epoch-checkpoints', type=int, 143 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the ' 144 | 'path specified by input, and average last this many of them.') 145 | num_group.add_argument('--num-update-checkpoints', type=int, 146 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by' 147 | ' input, and average last this many of them.') 148 | parser.add_argument('--checkpoint-upper-bound', type=int, 149 | help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, ' 150 | 'when using --num-update-checkpoints, this will set an upper bound on which update to use' 151 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be' 152 | ' averaged.' 
153 | 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would' 154 | ' be averaged assuming --save-interval-updates 500' 155 | ) 156 | parser.add_argument('--best-checkpoints-metric', type=str, default=None) 157 | parser.add_argument('--max-metric', action="store_true", default=False) 158 | parser.add_argument('--num-best-checkpoints-metric', type=int, default=None) 159 | parser.add_argument('--debug', action="store_true") 160 | # fmt: on 161 | args = parser.parse_args() 162 | print(args) 163 | 164 | if args.debug: 165 | import ptvsd 166 | ptvsd.enable_attach() 167 | import logging 168 | logging.warning("wait debug") 169 | ptvsd.wait_for_attach() 170 | 171 | num = None 172 | is_update_based = "epoch" 173 | if args.num_update_checkpoints is not None: 174 | num = args.num_update_checkpoints 175 | is_update_based = "update" 176 | elif args.num_epoch_checkpoints is not None: 177 | num = args.num_epoch_checkpoints 178 | elif args.num_best_checkpoints_metric is not None: 179 | num = args.num_best_checkpoints_metric 180 | is_update_based = "metric" 181 | 182 | 183 | assert args.checkpoint_upper_bound is None or ( 184 | args.num_epoch_checkpoints is not None 185 | or args.num_update_checkpoints is not None 186 | ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints" 187 | assert ( 188 | args.num_epoch_checkpoints is None or args.num_update_checkpoints is None 189 | ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints" 190 | 191 | if num is not None: 192 | if is_update_based == "metric": 193 | args.inputs = best_n_checkpoints( 194 | args.inputs[0], num, args.max_metric, args.best_checkpoints_metric 195 | ) 196 | else: 197 | args.inputs = last_n_checkpoints( 198 | args.inputs[0], num, is_update_based == "update", upper_bound=args.checkpoint_upper_bound, 199 | ) 200 | print("averaging checkpoints: ", args.inputs) 201 | 202 | new_state = average_checkpoints(args.inputs) 203 | with PathManager.open(args.output, "wb") as f: 204 | torch.save(new_state, f) 205 | print("Finished writing averaged checkpoint to {}".format(args.output)) 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /fs_plugins/models/transducer/transducer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
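`average_checkpoints.py` above is normally invoked as a script with `--inputs`, `--output`, and one of the `--num-*-checkpoints` options, but its helpers can also be called directly. A hypothetical programmatic sketch, assuming the `scripts` directory is importable as a package and with illustrative paths:

```python
import torch
from fs_plugins.scripts.average_checkpoints import average_checkpoints, last_n_checkpoints

# pick the last 5 epoch checkpoints (checkpointNN.pt) and write their average
ckpts = last_n_checkpoints("checkpoints/", n=5, update_based=False)
torch.save(average_checkpoints(ckpts), "checkpoints/averaged.pt")
```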
5 | 6 | import math 7 | import numpy as np 8 | from pathlib import Path 9 | 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import Tensor 14 | 15 | from fairseq.dataclass.utils import gen_parser_from_dataclass 16 | from fairseq import utils, checkpoint_utils 17 | from fairseq.distributed import fsdp_wrap 18 | from fairseq.models import register_model, register_model_architecture 19 | from fairseq.models import FairseqEncoderDecoderModel 20 | from fairseq.models.transformer import Embedding 21 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper 22 | 23 | 24 | 25 | 26 | import logging 27 | logger = logging.getLogger(__name__) 28 | 29 | from fs_plugins.models.transducer.transducer_config import TransducerConfig 30 | from fs_plugins.models.transducer.transducer_loss import TransducerLoss 31 | from fs_plugins.modules.unidirectional_encoder import UnidirectionalAudioTransformerEncoder 32 | from fs_plugins.modules.transducer_decoder import TransducerDecoder 33 | 34 | import pdb 35 | 36 | DEFAULT_MAX_TEXT_POSITIONS = 1024 37 | DEFAULT_MAX_AUDIO_POSITIONS = 6000 38 | 39 | DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) 40 | 41 | 42 | 43 | @register_model("transformer_transducer") 44 | class TransducerModel(FairseqEncoderDecoderModel): 45 | def __init__(self, args, encoder, decoder): 46 | super().__init__(encoder, decoder) 47 | self.args = args 48 | self.criterion = TransducerLoss(blank=self.decoder.blank_idx) 49 | self.padding_idx = decoder.dictionary.pad() 50 | 51 | @classmethod 52 | def add_args(cls, parser): 53 | """Add model-specific arguments to the parser.""" 54 | # we want to build the args recursively in this case. 55 | # do not set defaults so that settings defaults from various architectures still works 56 | gen_parser_from_dataclass( 57 | parser, TransducerConfig(), delete_default=True, with_prefix="" 58 | ) 59 | 60 | @classmethod 61 | def build_model(cls, args, task): 62 | """Build a new model instance.""" 63 | 64 | # make sure all arguments are present in older models 65 | # base_architecture(args) 66 | 67 | if getattr(args, "max_source_positions", None) is None: 68 | args.max_source_positions = DEFAULT_MAX_AUDIO_POSITIONS 69 | if getattr(args, "max_target_positions", None) is None: 70 | args.max_target_positions = DEFAULT_MAX_TEXT_POSITIONS 71 | 72 | 73 | decoder_embed_tokens = cls.build_embedding( 74 | args, task.target_dictionary, args.decoder_embed_dim 75 | ) 76 | 77 | encoder = cls.build_encoder(args) 78 | decoder = cls.build_decoder(args, task.target_dictionary, decoder_embed_tokens) 79 | 80 | model = cls(args, encoder, decoder) 81 | 82 | return model 83 | 84 | 85 | @classmethod 86 | def build_embedding(cls, args, dictionary, embed_dim, path=None): 87 | num_embeddings = len(dictionary) 88 | padding_idx = dictionary.pad() 89 | 90 | emb = Embedding(num_embeddings, embed_dim, padding_idx) 91 | # if provided, load from preloaded dictionaries 92 | if path: 93 | embed_dict = utils.parse_embedding(path) 94 | utils.load_embedding(embed_dict, dictionary, emb) 95 | return emb 96 | 97 | 98 | @classmethod 99 | def build_encoder(cls, args): 100 | encoder = UnidirectionalAudioTransformerEncoder(args) 101 | pretraining_path = getattr(args, "load_pretrained_encoder_from", None) 102 | if pretraining_path is not None: 103 | if not Path(pretraining_path).exists(): 104 | logger.warning( 105 | f"skipped pretraining because {pretraining_path} does not exist" 106 | ) 107 | else: 108 | encoder = checkpoint_utils.load_pretrained_component_from_model( 109 | 
component=encoder, checkpoint=pretraining_path 110 | ) 111 | logger.info(f"loaded pretrained encoder from: {pretraining_path}") 112 | 113 | return encoder 114 | 115 | @classmethod 116 | def build_decoder(cls, args, tgt_dict, embed_tokens): 117 | model= TransducerDecoder(args, tgt_dict, embed_tokens) 118 | 119 | #embed_dim= embed_tokens.embedding_dim 120 | 121 | #if model.lm.embed_positions is not None and args.rand_pos_decoder >0: 122 | # model.lm.embed_positions = PositionalEmbedding( 123 | # model.lm.max_target_positions, 124 | # embed_dim, 125 | # model.lm.padding_idx, 126 | # rand_max = args.rand_pos_decoder, 127 | # learned=args.decoder_learned_pos, 128 | # ) 129 | return model 130 | 131 | 132 | def forward( 133 | self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens 134 | ): 135 | encoder_out = self.encoder(src_tokens, fbk_lengths=src_lengths) 136 | logits, fake_src_lengths = self.decoder(prev_output_tokens, encoder_out) 137 | 138 | 139 | tgt_lengths = tgt_tokens.ne(self.padding_idx).sum(dim=-1) 140 | 141 | rnn_t_loss = self.criterion(logits, tgt_tokens, fake_src_lengths, tgt_lengths) 142 | 143 | ret_val = { 144 | "rnn_t_loss": {"loss": rnn_t_loss}, 145 | } 146 | 147 | return ret_val 148 | 149 | 150 | 151 | 152 | @register_model_architecture( 153 | "transformer_transducer", "transformer_transducer" 154 | ) 155 | def base_architecture(args): 156 | args.encoder_embed_path = getattr(args, "encoder_embed_path", None) 157 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 158 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 159 | args.encoder_layers = getattr(args, "encoder_layers", 6) 160 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) 161 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) 162 | args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) 163 | args.decoder_embed_path = getattr(args, "decoder_embed_path", None) 164 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) 165 | args.decoder_ffn_embed_dim = getattr( 166 | args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim 167 | ) 168 | args.decoder_layers = getattr(args, "decoder_layers", 6) 169 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) 170 | args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) 171 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) 172 | args.attention_dropout = getattr(args, "attention_dropout", 0.0) 173 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 174 | args.activation_fn = getattr(args, "activation_fn", "relu") 175 | args.dropout = getattr(args, "dropout", 0.1) 176 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) 177 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) 178 | args.share_decoder_input_output_embed = getattr( 179 | args, "share_decoder_input_output_embed", False 180 | ) 181 | args.no_token_positional_embeddings = getattr( 182 | args, "no_token_positional_embeddings", False 183 | ) 184 | args.adaptive_input = getattr(args, "adaptive_input", False) 185 | args.apply_bert_init = getattr(args, "apply_bert_init", False) 186 | 187 | args.decoder_output_dim = getattr( 188 | args, "decoder_output_dim", args.decoder_embed_dim 189 | ) 190 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 191 | 192 | # --- speech arguments --- 193 | 
args.rand_pos_encoder = getattr(args, "rand_pos_encoder", 300) 194 | args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) 195 | args.conv_type= getattr(args, "conv_type", "shallow2d_base") 196 | args.no_audio_positional_embeddings = getattr( 197 | args, "no_audio_positional_embeddings", False 198 | ) 199 | args.main_context = getattr(args, "main_context", 32) 200 | args.right_context = getattr(args, "right_context", 16) 201 | args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 32) 202 | args.no_scale_embedding = getattr(args, "no_scale_embedding", False) 203 | args.transducer_downsample = getattr(args, "transducer_downsample", 1) 204 | 205 | @register_model_architecture( 206 | "transformer_transducer", "t_t" 207 | ) 208 | def t_t_architecture(args): 209 | args.encoder_layers = getattr(args, "encoder_layers", 16) 210 | args.decoder_layers = getattr(args, "decoder_layers", 2) 211 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) 212 | base_architecture(args) -------------------------------------------------------------------------------- /fs_plugins/models/transducer/attention_transducer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import numpy as np 8 | from pathlib import Path 9 | 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import Tensor 14 | 15 | from fairseq.dataclass.utils import gen_parser_from_dataclass 16 | from fairseq import utils, checkpoint_utils 17 | from fairseq.models import register_model, register_model_architecture 18 | from fairseq.models import FairseqEncoderDecoderModel 19 | from fairseq.models.transformer import Embedding 20 | 21 | 22 | 23 | 24 | import logging 25 | logger = logging.getLogger(__name__) 26 | 27 | from fs_plugins.models.transducer.transducer_config import TransducerConfig 28 | from fs_plugins.models.transducer.transducer_loss import TransducerLoss 29 | from fs_plugins.modules.unidirectional_encoder import UnidirectionalAudioTransformerEncoder 30 | from fs_plugins.modules.attention_transducer_decoder import AttentionTransducerDecoder 31 | from fs_plugins.utils import load_pretrained_component_from_model_modified 32 | 33 | import pdb 34 | 35 | DEFAULT_MAX_TEXT_POSITIONS = 1024 36 | DEFAULT_MAX_AUDIO_POSITIONS = 6000 37 | 38 | DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) 39 | 40 | 41 | 42 | @register_model("attention_transformer_transducer") 43 | class TransducerModel(FairseqEncoderDecoderModel): 44 | def __init__(self, args, encoder, decoder): 45 | super().__init__(encoder, decoder) 46 | self.args = args 47 | self.criterion = TransducerLoss(blank=self.decoder.blank_idx) 48 | self.padding_idx = decoder.dictionary.pad() 49 | 50 | @classmethod 51 | def add_args(cls, parser): 52 | """Add model-specific arguments to the parser.""" 53 | # we want to build the args recursively in this case. 
54 | # do not set defaults so that settings defaults from various architectures still works 55 | gen_parser_from_dataclass( 56 | parser, TransducerConfig(), delete_default=True, with_prefix="" 57 | ) 58 | 59 | @classmethod 60 | def build_model(cls, args, task): 61 | """Build a new model instance.""" 62 | 63 | # make sure all arguments are present in older models 64 | # base_architecture(args) 65 | 66 | if getattr(args, "max_source_positions", None) is None: 67 | args.max_source_positions = DEFAULT_MAX_AUDIO_POSITIONS 68 | if getattr(args, "max_target_positions", None) is None: 69 | args.max_target_positions = DEFAULT_MAX_TEXT_POSITIONS 70 | 71 | 72 | decoder_embed_tokens = cls.build_embedding( 73 | args, task.target_dictionary, args.decoder_embed_dim 74 | ) 75 | 76 | encoder = cls.build_encoder(args) 77 | decoder = cls.build_decoder(args, task.target_dictionary, decoder_embed_tokens) 78 | 79 | model = cls(args, encoder, decoder) 80 | 81 | return model 82 | 83 | 84 | @classmethod 85 | def build_embedding(cls, args, dictionary, embed_dim, path=None): 86 | num_embeddings = len(dictionary) 87 | padding_idx = dictionary.pad() 88 | 89 | emb = Embedding(num_embeddings, embed_dim, padding_idx) 90 | # if provided, load from preloaded dictionaries 91 | if path: 92 | embed_dict = utils.parse_embedding(path) 93 | utils.load_embedding(embed_dict, dictionary, emb) 94 | return emb 95 | 96 | 97 | @classmethod 98 | def build_encoder(cls, args): 99 | encoder = UnidirectionalAudioTransformerEncoder(args) 100 | pretraining_path = getattr(args, "load_pretrained_encoder_from", None) 101 | if pretraining_path is not None: 102 | if not Path(pretraining_path).exists(): 103 | logger.warning( 104 | f"skipped loading pretrained encoder because {pretraining_path} does not exist" 105 | ) 106 | else: 107 | encoder = checkpoint_utils.load_pretrained_component_from_model( 108 | component=encoder, checkpoint=pretraining_path 109 | ) 110 | logger.info(f"loaded pretrained encoder from: {pretraining_path}") 111 | 112 | return encoder 113 | 114 | @classmethod 115 | def build_decoder(cls, args, tgt_dict, embed_tokens): 116 | decoder = AttentionTransducerDecoder(args, tgt_dict, embed_tokens) 117 | 118 | pretraining_path = getattr(args, "load_pretrained_decoder_from", None) 119 | if pretraining_path is not None: 120 | if not Path(pretraining_path).exists(): 121 | logger.warning( 122 | f"skipped loading pretrained decoder because {pretraining_path} does not exist" 123 | ) 124 | else: 125 | #decoder = checkpoint_utils.load_pretrained_component_from_model( 126 | # component=decoder, checkpoint=pretraining_path, strict=False 127 | #) 128 | decoder, keys_info = load_pretrained_component_from_model_modified( 129 | component=decoder, checkpoint=pretraining_path, strict=False 130 | ) 131 | logger.info(f"loaded pretrained decoder from: {pretraining_path}") 132 | logger.info(f"keys information: {keys_info}") 133 | 134 | return decoder 135 | 136 | 137 | def forward( 138 | self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens 139 | ): 140 | encoder_out = self.encoder(src_tokens, fbk_lengths=src_lengths) 141 | logits, fake_src_lengths = self.decoder(prev_output_tokens, encoder_out) 142 | 143 | 144 | tgt_lengths = tgt_tokens.ne(self.padding_idx).sum(dim=-1) 145 | 146 | rnn_t_loss = self.criterion(logits, tgt_tokens, fake_src_lengths, tgt_lengths) 147 | 148 | 149 | ret_val = { 150 | "rnn_t_loss": {"loss": rnn_t_loss}, 151 | } 152 | 153 | return ret_val 154 | 155 | 156 | 157 | 158 | @register_model_architecture( 159 | 
"attention_transformer_transducer", "attention_transformer_transducer" 160 | ) 161 | def base_architecture(args): 162 | args.encoder_embed_path = getattr(args, "encoder_embed_path", None) 163 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 164 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 165 | args.encoder_layers = getattr(args, "encoder_layers", 6) 166 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) 167 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) 168 | args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) 169 | args.decoder_embed_path = getattr(args, "decoder_embed_path", None) 170 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) 171 | args.decoder_ffn_embed_dim = getattr( 172 | args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim 173 | ) 174 | args.decoder_layers = getattr(args, "decoder_layers", 6) 175 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) 176 | args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) 177 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) 178 | args.attention_dropout = getattr(args, "attention_dropout", 0.0) 179 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 180 | args.activation_fn = getattr(args, "activation_fn", "relu") 181 | args.dropout = getattr(args, "dropout", 0.1) 182 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) 183 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) 184 | args.share_decoder_input_output_embed = getattr( 185 | args, "share_decoder_input_output_embed", False 186 | ) 187 | args.no_token_positional_embeddings = getattr( 188 | args, "no_token_positional_embeddings", False 189 | ) 190 | args.adaptive_input = getattr(args, "adaptive_input", False) 191 | args.apply_bert_init = getattr(args, "apply_bert_init", False) 192 | 193 | args.decoder_output_dim = getattr( 194 | args, "decoder_output_dim", args.decoder_embed_dim 195 | ) 196 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 197 | 198 | # --- speech arguments --- 199 | args.rand_pos_encoder = getattr(args, "rand_pos_encoder", 300) 200 | args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) 201 | args.conv_type= getattr(args, "conv_type", "shallow2d_base") 202 | args.no_audio_positional_embeddings = getattr( 203 | args, "no_audio_positional_embeddings", False 204 | ) 205 | args.main_context = getattr(args, "main_context", 32) 206 | args.right_context = getattr(args, "right_context", 16) 207 | args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 32) 208 | args.no_scale_embedding = getattr(args, "no_scale_embedding", False) 209 | args.transducer_downsample = getattr(args, "transducer_downsample", 1) 210 | 211 | @register_model_architecture( 212 | "attention_transformer_transducer", "attention_t_t" 213 | ) 214 | def t_t_architecture(args): 215 | args.encoder_layers = getattr(args, "encoder_layers", 16) 216 | args.decoder_layers = getattr(args, "decoder_layers", 2) 217 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) 218 | base_architecture(args) -------------------------------------------------------------------------------- /fs_plugins/tasks/transducer_speech_to_text.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import json 7 | import torch 8 | import logging 9 | import numpy as np 10 | 11 | from argparse import Namespace 12 | from fairseq import utils, metrics 13 | from fairseq.tasks import register_task 14 | from fairseq.utils import new_arange 15 | from fairseq.optim.amp_optimizer import AMPOptimizer 16 | 17 | 18 | from fairseq.tasks.speech_to_text import SpeechToTextTask 19 | 20 | from fs_plugins.datasets.transducer_speech_to_text_dataset import ( 21 | TransducerSpeechToTextDataset, 22 | TransducerSpeechToTextDatasetCreator, 23 | ) 24 | 25 | 26 | logger = logging.getLogger(__name__) 27 | EVAL_BLEU_ORDER = 4 28 | 29 | @register_task("transducer_speech_to_text") 30 | class TransducerSpeechToTextTask(SpeechToTextTask): 31 | 32 | def __init__(self, args, tgt_dict): 33 | super().__init__(args, tgt_dict) 34 | self.pre_tokenizer = self.build_tokenizer(self.args) 35 | self.bpe_tokenizer = self.build_bpe(self.args) 36 | blank_index = self.tgt_dict.add_symbol("") 37 | self.tgt_dict.blank_index = blank_index 38 | 39 | def build_model(self, args, from_checkpoint=False): 40 | model = super().build_model(args, from_checkpoint) 41 | if self.args.eval_bleu: 42 | gen_args = json.loads(self.args.eval_bleu_args) 43 | self.sequence_generator = self.build_generator([model], Namespace(**gen_args)) 44 | return model 45 | 46 | @classmethod 47 | def add_args(cls, parser): 48 | SpeechToTextTask.add_args(parser) 49 | 50 | 51 | # options for reporting BLEU during validation 52 | parser.add_argument( 53 | "--eval-bleu", 54 | action="store_true", 55 | help="evaluation with BLEU scores", 56 | ) 57 | parser.add_argument( 58 | "--eval-bleu-detok", 59 | type=str, 60 | default="space", 61 | help="detokenize before computing BLEU (e.g., 'moses'); " 62 | "required if using --eval-bleu; use 'space' to " 63 | "disable detokenization; see fairseq.data.encoders " 64 | "for other options", 65 | ) 66 | parser.add_argument( 67 | "--eval-bleu-detok-args", 68 | type=str, 69 | metavar="JSON", 70 | help="args for building the tokenizer, if needed", 71 | ) 72 | parser.add_argument( 73 | "--eval-tokenized-bleu", 74 | action="store_true", 75 | default=False, 76 | help="compute tokenized BLEU instead of sacrebleu", 77 | ) 78 | parser.add_argument( 79 | "--eval-bleu-remove-bpe", 80 | nargs="?", 81 | const="@@ ", 82 | default=None, 83 | help="remove BPE before computing BLEU", 84 | ) 85 | parser.add_argument( 86 | "--eval-bleu-args", 87 | type=str, 88 | metavar="JSON", 89 | help="generation args for BLUE scoring, " 90 | "e.g., '{\"beam\": 4, \"lenpen\": 0.6}'", 91 | ) 92 | parser.add_argument( 93 | "--eval-bleu-print-samples", 94 | action="store_true", 95 | help="print sample generations during validation", 96 | ) 97 | parser.add_argument( 98 | "--eval-bleu-bpe", 99 | type=str, 100 | metavar="BPE", 101 | default=None, 102 | help="args for building the bpe, if needed", 103 | ) 104 | parser.add_argument( 105 | "--eval-bleu-bpe-path", 106 | type=str, 107 | metavar='BPE', 108 | help="args for building the bpe, if needed", 109 | ) 110 | 111 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 112 | is_train_split = split.startswith("train") 113 | pre_tokenizer = self.build_tokenizer(self.args) 114 | bpe_tokenizer = self.build_bpe(self.args) 115 | self.datasets[split] = TransducerSpeechToTextDatasetCreator.from_tsv( 116 | root=self.args.data, 
117 | cfg=self.data_cfg, 118 | splits=split, 119 | tgt_dict=self.tgt_dict, 120 | pre_tokenizer=pre_tokenizer, 121 | bpe_tokenizer=bpe_tokenizer, 122 | is_train_split=is_train_split, 123 | epoch=epoch, 124 | seed=self.args.seed, 125 | speaker_to_id=self.speaker_to_id, 126 | multitask=self.multitask_tasks, 127 | ) 128 | 129 | 130 | 131 | def build_generator(self, models, args, **unused): 132 | # add models input to match the API for SequenceGenerator 133 | #TODO 134 | raise NotImplementedError 135 | 136 | 137 | def train_step( 138 | self, sample, model, criterion, optimizer, update_num, ignore_grad=False 139 | ): 140 | model.train() 141 | sample["update_num"] = update_num 142 | 143 | 144 | model.set_num_updates(update_num) 145 | with torch.autograd.profiler.record_function("forward"): 146 | with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): 147 | loss, sample_size, logging_output = criterion(model, sample) 148 | if ignore_grad: 149 | loss *= 0 150 | with torch.autograd.profiler.record_function("backward"): 151 | optimizer.backward(loss) 152 | return loss, sample_size, logging_output 153 | 154 | def valid_step(self, sample, model, criterion): 155 | model.eval() 156 | 157 | with torch.no_grad(): 158 | loss, sample_size, logging_output = criterion(model, sample) 159 | if self.args.eval_bleu: 160 | bleu = self._inference_with_bleu(self.sequence_generator, sample, model) 161 | logging_output["_bleu_sys_len"] = bleu.sys_len 162 | logging_output["_bleu_ref_len"] = bleu.ref_len 163 | # we split counts into separate entries so that they can be 164 | # summed efficiently across workers using fast-stat-sync 165 | assert len(bleu.counts) == EVAL_BLEU_ORDER 166 | for i in range(EVAL_BLEU_ORDER): 167 | logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] 168 | logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] 169 | 170 | return loss, sample_size, logging_output 171 | 172 | def reduce_metrics(self, logging_outputs, criterion): 173 | super().reduce_metrics(logging_outputs, criterion) 174 | if self.args.eval_bleu: 175 | def sum_logs(key): 176 | if key in logging_outputs[0]: 177 | return sum(log[key].cpu().numpy() for log in logging_outputs) 178 | return sum(log.get(key, 0) for log in logging_outputs) 179 | 180 | counts, totals = [], [] 181 | for i in range(EVAL_BLEU_ORDER): 182 | counts.append(sum_logs("_bleu_counts_" + str(i))) 183 | totals.append(sum_logs("_bleu_totals_" + str(i))) 184 | 185 | if max(totals) > 0: 186 | # log counts as numpy arrays -- log_scalar will sum them correctly 187 | metrics.log_scalar("_bleu_counts", np.array(counts)) 188 | metrics.log_scalar("_bleu_totals", np.array(totals)) 189 | metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) 190 | metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) 191 | 192 | def compute_bleu(meters): 193 | import inspect 194 | import sacrebleu 195 | 196 | fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] 197 | if "smooth_method" in fn_sig: 198 | smooth = {"smooth_method": "exp"} 199 | else: 200 | smooth = {"smooth": "exp"} 201 | bleu = sacrebleu.compute_bleu( 202 | correct=meters["_bleu_counts"].sum, 203 | total=meters["_bleu_totals"].sum, 204 | sys_len=meters["_bleu_sys_len"].sum, 205 | ref_len=meters["_bleu_ref_len"].sum, 206 | **smooth 207 | ) 208 | return round(bleu.score, 2) 209 | 210 | metrics.log_derived("bleu", compute_bleu) 211 | 212 | def _inference_with_bleu(self, generator, sample, model): 213 | import sacrebleu 214 | 215 | def decode(toks, escape_unk=False): 216 | s = 
self.tgt_dict.string( 217 | toks.int().cpu(), 218 | self.args.eval_bleu_remove_bpe, 219 | unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), 220 | ) 221 | if self.bpe_tokenizer is not None: 222 | s = self.bpe_tokenizer.decode(s) 223 | if self.pre_tokenizer is not None: 224 | s = self.pre_tokenizer.decode(s) 225 | return s 226 | 227 | gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) 228 | hyps, refs = [], [] 229 | for i in range(len(gen_out)): 230 | hyp = decode(gen_out[i][0]["tokens"]) 231 | ref = decode( 232 | utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), 233 | escape_unk=True, # don't count as matches to the hypo 234 | ) 235 | hyps.append(hyp) 236 | refs.append(ref) 237 | 238 | if self.args.eval_bleu_print_samples: 239 | logger.info("example hypothesis: " + hyps[0]) 240 | logger.info("example reference: " + refs[0]) 241 | if self.args.eval_tokenized_bleu: 242 | return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") 243 | else: 244 | return sacrebleu.corpus_bleu(hyps, [refs]) 245 | 246 | -------------------------------------------------------------------------------- /fs_plugins/scripts/prep_mustc_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import logging 9 | import os 10 | from pathlib import Path 11 | import shutil 12 | from itertools import groupby 13 | from tempfile import NamedTemporaryFile 14 | from typing import Tuple 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import soundfile as sf 19 | from examples.speech_to_text.data_utils import ( 20 | create_zip, 21 | extract_fbank_features, 22 | filter_manifest_df, 23 | gen_config_yaml, 24 | gen_vocab, 25 | get_zip_manifest, 26 | load_df_from_tsv, 27 | save_df_to_tsv, 28 | cal_gcmvn_stats, 29 | ) 30 | import torch 31 | from torch.utils.data import Dataset 32 | from tqdm import tqdm 33 | 34 | from fairseq.data.audio.audio_utils import get_waveform, convert_waveform 35 | 36 | 37 | log = logging.getLogger(__name__) 38 | 39 | 40 | MANIFEST_COLUMNS = ["id", "audio", "n_frames", "src_text", "tgt_text", "speaker"] 41 | 42 | 43 | class MUSTC(Dataset): 44 | """ 45 | Create a Dataset for MuST-C. 
Each item is a tuple of the form: 46 | waveform, sample_rate, source utterance, target utterance, speaker_id, 47 | utterance_id 48 | """ 49 | 50 | SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"] 51 | LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"] 52 | 53 | def __init__(self, root: str, lang: str, split: str) -> None: 54 | assert split in self.SPLITS and lang in self.LANGUAGES 55 | _root = Path(root) / f"en-{lang}" / "data" / split 56 | wav_root, txt_root = _root / "wav", _root / "txt" 57 | assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir() 58 | # Load audio segments 59 | try: 60 | import yaml 61 | except ImportError: 62 | print("Please install PyYAML to load the MuST-C YAML files") 63 | with open(txt_root / f"{split}.yaml") as f: 64 | segments = yaml.load(f, Loader=yaml.BaseLoader) 65 | # Load source and target utterances 66 | for _lang in ["en", lang]: 67 | with open(txt_root / f"{split}.{_lang}") as f: 68 | utterances = [r.strip() for r in f] 69 | assert len(segments) == len(utterances) 70 | for i, u in enumerate(utterances): 71 | segments[i][_lang] = u 72 | # Gather info 73 | self.data = [] 74 | for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): 75 | wav_path = wav_root / wav_filename 76 | sample_rate = sf.info(wav_path.as_posix()).samplerate 77 | seg_group = sorted(_seg_group, key=lambda x: x["offset"]) 78 | for i, segment in enumerate(seg_group): 79 | offset = int(float(segment["offset"]) * sample_rate) 80 | n_frames = int(float(segment["duration"]) * sample_rate) 81 | _id = f"{wav_path.stem}_{i}" 82 | self.data.append( 83 | ( 84 | wav_path.as_posix(), 85 | offset, 86 | n_frames, 87 | sample_rate, 88 | segment["en"], 89 | segment[lang], 90 | segment["speaker_id"], 91 | _id, 92 | ) 93 | ) 94 | 95 | def __getitem__( 96 | self, n: int 97 | ) -> Tuple[torch.Tensor, int, str, str, str, str]: 98 | wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, \ 99 | utt_id = self.data[n] 100 | waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) 101 | waveform = torch.from_numpy(waveform) 102 | return waveform, sr, src_utt, tgt_utt, spk_id, utt_id 103 | 104 | def __len__(self) -> int: 105 | return len(self.data) 106 | 107 | 108 | def process(args): 109 | root = Path(args.data_root).absolute() 110 | for lang in MUSTC.LANGUAGES: 111 | cur_root = root / f"en-{lang}" 112 | if not cur_root.is_dir(): 113 | print(f"{cur_root.as_posix()} does not exist. 
Skipped.") 114 | continue 115 | # Extract features 116 | audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80") 117 | audio_root.mkdir(exist_ok=True) 118 | 119 | for split in MUSTC.SPLITS: 120 | print(f"Fetching split {split}...") 121 | dataset = MUSTC(root.as_posix(), lang, split) 122 | if args.use_audio_input: 123 | print("Converting audios...") 124 | for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): 125 | tgt_sample_rate = 16_000 126 | _wavform, _ = convert_waveform( 127 | waveform, sample_rate, to_mono=True, 128 | to_sample_rate=tgt_sample_rate 129 | ) 130 | sf.write( 131 | (audio_root / f"{utt_id}.flac").as_posix(), 132 | _wavform.T.numpy(), tgt_sample_rate 133 | ) 134 | else: 135 | print("Extracting log mel filter bank features...") 136 | gcmvn_feature_list = [] 137 | if split == 'train' and args.cmvn_type == "global": 138 | print("And estimating cepstral mean and variance stats...") 139 | 140 | for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): 141 | features = extract_fbank_features( 142 | waveform, sample_rate, audio_root / f"{utt_id}.npy" 143 | ) 144 | if split == 'train' and args.cmvn_type == "global": 145 | if len(gcmvn_feature_list) < args.gcmvn_max_num: 146 | gcmvn_feature_list.append(features) 147 | 148 | if split == 'train' and args.cmvn_type == "global": 149 | # Estimate and save cmv 150 | stats = cal_gcmvn_stats(gcmvn_feature_list) 151 | with open(cur_root / "gcmvn.npz", "wb") as f: 152 | np.savez(f, mean=stats["mean"], std=stats["std"]) 153 | 154 | # Pack features into ZIP 155 | zip_path = cur_root / f"{audio_root.name}.zip" 156 | print("ZIPing audios/features...") 157 | create_zip(audio_root, zip_path) 158 | print("Fetching ZIP manifest...") 159 | audio_paths, audio_lengths = get_zip_manifest( 160 | zip_path, 161 | is_audio=args.use_audio_input, 162 | ) 163 | # Generate TSV manifest 164 | print("Generating manifest...") 165 | train_text = [] 166 | for split in MUSTC.SPLITS: 167 | is_train_split = split.startswith("train") 168 | manifest = {c: [] for c in MANIFEST_COLUMNS} 169 | dataset = MUSTC(args.data_root, lang, split) 170 | for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset): 171 | manifest["id"].append(utt_id) 172 | manifest["audio"].append(audio_paths[utt_id]) 173 | manifest["n_frames"].append(audio_lengths[utt_id]) 174 | manifest["src_text"].append(src_utt) 175 | manifest["tgt_text"].append(tgt_utt) 176 | manifest["speaker"].append(speaker_id) 177 | if is_train_split: 178 | train_text.extend(manifest["tgt_text"]) 179 | train_text.extend(manifest["src_text"]) 180 | df = pd.DataFrame.from_dict(manifest) 181 | df = filter_manifest_df(df, is_train_split=is_train_split) 182 | save_df_to_tsv(df, cur_root / f"{split}.tsv") 183 | # Generate vocab 184 | v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) 185 | spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}" 186 | with NamedTemporaryFile(mode="w") as f: 187 | for t in train_text: 188 | f.write(t + "\n") 189 | gen_vocab( 190 | Path(f.name), 191 | cur_root / spm_filename_prefix, 192 | args.vocab_type, 193 | args.vocab_size, 194 | ) 195 | # Generate config YAML 196 | if args.use_audio_input: 197 | gen_config_yaml( 198 | cur_root, 199 | spm_filename=spm_filename_prefix + ".model", 200 | yaml_filename=f"config.yaml", 201 | specaugment_policy=None, 202 | extra={"use_audio_input": True} 203 | ) 204 | else: 205 | gen_config_yaml( 206 | cur_root, 207 | spm_filename=spm_filename_prefix + ".model", 208 | yaml_filename=f"config.yaml", 209 | 
specaugment_policy="lb", 210 | cmvn_type=args.cmvn_type, 211 | gcmvn_path=( 212 | cur_root / "gcmvn.npz" if args.cmvn_type == "global" 213 | else None 214 | ), 215 | ) 216 | # Clean up 217 | shutil.rmtree(audio_root) 218 | 219 | 220 | 221 | def main(): 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument("--data-root", "-d", required=True, type=str) 224 | parser.add_argument( 225 | "--vocab-type", 226 | default="unigram", 227 | required=True, 228 | type=str, 229 | choices=["bpe", "unigram", "char"], 230 | ), 231 | parser.add_argument("--vocab-size", default=8000, type=int) 232 | parser.add_argument( 233 | "--cmvn-type", default="utterance", 234 | choices=["global", "utterance"], 235 | help="The type of cepstral mean and variance normalization" 236 | ) 237 | parser.add_argument( 238 | "--gcmvn-max-num", default=150000, type=int, 239 | help="Maximum number of sentences to use to estimate global mean and " 240 | "variance" 241 | ) 242 | parser.add_argument("--use-audio-input", action="store_true") 243 | args = parser.parse_args() 244 | 245 | 246 | process(args) 247 | 248 | 249 | if __name__ == "__main__": 250 | main() 251 | -------------------------------------------------------------------------------- /fs_plugins/models/transducer/transducer_config.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from dataclasses import dataclass, field 3 | 4 | from omegaconf import II 5 | 6 | from fairseq import utils 7 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 8 | 9 | from fs_plugins.modules.audio_convs import get_available_convs 10 | 11 | DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) 12 | 13 | @dataclass 14 | class SpeechTransformerModelConfig(FairseqDataclass): 15 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( # type: ignore 16 | default="relu", 17 | metadata={"help": "activation function to use"}, 18 | ) 19 | dropout: float = field(default=0.0, metadata={"help": "dropout probability"}) 20 | attention_dropout: float = field( 21 | default=0.0, metadata={"help": "dropout probability for attention weights"} 22 | ) 23 | activation_dropout: float = field( 24 | default=0.0, 25 | metadata={ 26 | "help": "dropout probability after activation in FFN.", 27 | "alias": "--relu-dropout", 28 | }, 29 | ) 30 | relu_dropout: float = 0.0 31 | adaptive_input: bool = False 32 | 33 | # Relative Position 34 | encoder_max_relative_position: int = field( 35 | default=32, metadata={"help": "max_relative_position for encoder Relative attention, <0 for traditional attention"} 36 | ) 37 | decoder_max_relative_position: int = field( 38 | default=-1, metadata={"help": "max_relative_position for decoder Relative attention, <0 for traditional attention"} 39 | ) 40 | 41 | # Support Length 42 | max_audio_positions: Optional[int] = II("task.max_audio_positions") 43 | max_text_positions: Optional[int] = II("task.max_text_positions") 44 | max_source_positions: Optional[int] = II("task.max_audio_positions") 45 | max_target_positions: Optional[int] = II("task.max_text_positions") 46 | 47 | # Encoder Configuration 48 | conv_type: ChoiceEnum(get_available_convs()) = field( # type: ignore 49 | default= "shallow2d_base", metadata= {"help": "convolution type for speech encoder"} 50 | ) 51 | encoder_embed_path: Optional[str] = field( 52 | default=None, metadata={"help": "path to pre-trained encoder embedding"} 53 | ) 54 | encoder_embed_dim: int = field( 55 | default= 512, metadata={"help":"encoder embedding dimension"} 56 | 
) 57 | encoder_ffn_embed_dim: int = field( 58 | default=2048, metadata={"help": "encoder embedding dimension for FFN"} 59 | ) 60 | encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"}) 61 | encoder_attention_heads: int = field( 62 | default=8, metadata={"help": "num encoder attention heads"} 63 | ) 64 | encoder_normalize_before: bool = field( 65 | default=False, metadata={"help": "apply layernorm before each encoder block"} 66 | ) 67 | encoder_learned_pos: bool = field( 68 | default=False, 69 | metadata={"help": "use learned positional embeddings in the encoder"}, 70 | ) 71 | encoder_layerdrop: float = field( 72 | default=0.0, metadata={"help": "LayerDrop probability for encoder"} 73 | ) 74 | encoder_layers_to_keep: Optional[str] = field( 75 | default=None, 76 | metadata={ 77 | "help": "which layers to *keep* when pruning as a comma-separated list" 78 | }, 79 | ) 80 | 81 | # Decoder Configuration 82 | decoder_embed_path: Optional[str] = field( 83 | default=None, metadata={"help": "path to pre-trained decoder embedding"} 84 | ) 85 | decoder_embed_dim: int = field( 86 | default=512, metadata={"help": "decoder embedding dimension"} 87 | ) 88 | decoder_output_dim: int = field( 89 | default=512, metadata={"help": "decoder output dimension"} 90 | ) 91 | decoder_input_dim: int = field( 92 | default=512, metadata={"help": "decoder input dimension"} 93 | ) 94 | decoder_ffn_embed_dim: int = field( 95 | default=2048, metadata={"help": "decoder embedding dimension for FFN"} 96 | ) 97 | decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) #TODO 98 | decoder_attention_heads: int = field( 99 | default=8, metadata={"help": "num decoder attention heads"} 100 | ) 101 | decoder_normalize_before: bool = field( 102 | default=False, metadata={"help": "apply layernorm before each decoder block"} 103 | ) 104 | decoder_learned_pos: bool = field( 105 | default=False, 106 | metadata={"help": "use learned positional embeddings in the decoder"}, 107 | ) 108 | decoder_layerdrop: float = field( 109 | default=0.0, metadata={"help": "LayerDrop probability for decoder"} 110 | ) 111 | decoder_layers_to_keep: Optional[str] = field( 112 | default=None, 113 | metadata={ 114 | "help": "which layers to *keep* when pruning as a comma-separated list" 115 | }, 116 | ) 117 | no_decoder_final_norm: bool = field( 118 | default=False, 119 | metadata={"help": "don't add an extra layernorm after the last decoder block"}, #TODO 120 | ) 121 | share_decoder_input_output_embed: bool = field( 122 | default=False, metadata={"help": "share decoder input and output embeddings"} 123 | ) 124 | share_all_embeddings: bool = field( 125 | default=False, metadata={"help":"share encoder, decoder and output embeddings (requires shared dictionary and embed dim)"} 126 | ) 127 | 128 | 129 | no_token_positional_embeddings: bool = field( 130 | default=False, 131 | metadata={ 132 | "help": "if True, disables positional embeddings (outside self attention)" 133 | }, 134 | ) 135 | no_audio_positional_embeddings:bool = field( 136 | default = False, 137 | metadata={"help":"if True, disables positional embeddings in audio encoder"} 138 | ) 139 | adaptive_softmax_cutoff: Optional[str] = field( 140 | default=None, 141 | metadata={ 142 | "help": "comma separated list of adaptive softmax cutoff points. 
" 143 | "Must be used with adaptive_loss criterion" 144 | }, 145 | ) 146 | adaptive_softmax_dropout: float = field( 147 | default=0, 148 | metadata={"help": "sets adaptive softmax dropout for the tail projections"}, 149 | ) 150 | adaptive_softmax_factor: float = field( 151 | default=4, metadata={"help": "adaptive input factor"} 152 | ) 153 | layernorm_embedding: bool = field( 154 | default=False, metadata={"help": "add layernorm to embedding"} 155 | ) 156 | tie_adaptive_weights: bool = field( 157 | default=False, 158 | metadata={ 159 | "help": "if set, ties the weights of adaptive softmax and adaptive input" 160 | }, 161 | ) 162 | tie_adaptive_proj: bool = field( 163 | default=False, 164 | metadata={ 165 | "help": "if set, ties the projection weights of adaptive softmax and adaptive input" 166 | }, 167 | ) 168 | no_scale_embedding: bool = field( 169 | default=False, metadata={"help": "if True, dont scale embeddings"} 170 | ) 171 | checkpoint_activations: bool = field( 172 | default=False, 173 | metadata={ 174 | "help": "checkpoint activations at each layer, which saves GPU memory usage at the cost of some additional compute" 175 | }, 176 | ) 177 | offload_activations: bool = field( 178 | default=False, 179 | metadata={ 180 | "help": "checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations." 181 | }, 182 | ) 183 | 184 | # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) 185 | no_cross_attention: bool = field( 186 | default=False, metadata={"help": "do not perform cross-attention"} 187 | ) 188 | cross_self_attention: bool = field( 189 | default=False, metadata={"help": "perform cross+self-attention"} 190 | ) 191 | 192 | # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) 193 | quant_noise_pq: float = field( 194 | default=0.0, 195 | metadata={"help": "iterative PQ quantization noise at training time"}, 196 | ) 197 | quant_noise_pq_block_size: int = field( 198 | default=8, 199 | metadata={"help": "block size of quantization noise at training time"}, 200 | ) 201 | quant_noise_scalar: float = field( 202 | default=0.0, 203 | metadata={ 204 | "help": "scalar quantization noise and scalar quantization at training time" 205 | }, 206 | ) 207 | min_params_to_wrap: int = field( 208 | default=DEFAULT_MIN_PARAMS_TO_WRAP, 209 | metadata={ 210 | "help": "minimum number of params for a layer to be wrapped with FSDP() when " 211 | "training with --ddp-backend=fully_sharded. Smaller values will " 212 | "improve memory efficiency, but may make torch.distributed " 213 | "communication less efficient due to smaller input sizes. This option " 214 | "is set to 0 (i.e., always wrap) when --checkpoint-activations or " 215 | "--offload-activations are passed." 
216 | }, 217 | ) 218 | 219 | # Speech Encoder Configuration 220 | rand_pos_encoder: int = field( 221 | default=300, 222 | metadata={ 223 | "help":"max random start for encoder position embedding" 224 | } 225 | ) 226 | rand_pos_decoder: int = field( 227 | default=0, 228 | metadata={ 229 | "help":"max random start for encoder position embedding" 230 | } 231 | ) 232 | load_pretrained_encoder_from: Optional[str] = field( 233 | default=None, metadata={"help":"pretrained_encoder_path"} 234 | ) 235 | load_pretrained_decoder_from: Optional[str] = field( 236 | default=None, metadata={"help":"pretrained_decoder_path"} 237 | ) 238 | 239 | # params for online 240 | main_context:int = field( 241 | default=16, metadata={"help":"main context frame"} 242 | ) 243 | right_context :int = field( 244 | default=16, metadata={"help":"right context frame"} 245 | ) 246 | 247 | 248 | 249 | @dataclass 250 | class TransducerConfig(SpeechTransformerModelConfig): 251 | alpha: float = field( 252 | default=1.0, 253 | metadata = {"help": "hyperparamter for constructing prior alignment"} 254 | ) 255 | transducer_downsample: int = field( 256 | default=4, 257 | metadata = {"help": "source downsample ratio for transducer"} 258 | ) 259 | transducer_activation: ChoiceEnum(utils.get_available_activation_fns()) = field( # type: ignore 260 | default="tanh", metadata={"help": "activation function to use"} 261 | ) 262 | transducer_smoothing: float = field( 263 | default= 0., metadata = {"help":"label smoothing for transducer loss"} 264 | ) 265 | tokens_per_step:int = field( 266 | default=20000, 267 | metadata={"help":"tokens per step for output head splitting"} 268 | ) 269 | delay_scale:float = field( 270 | default=1.0, 271 | metadata={"help":"scale for delay loss"} 272 | ) 273 | delay_func: ChoiceEnum(['zero', 'diag_positive', 'diagonal']) =field( # type: ignore 274 | default="diag_positive", metadata= {"help":"function for delay loss"} 275 | ) 276 | transducer_ce_scale: float= field( 277 | default=1.0, metadata= {'help':'scale for ce loss'} 278 | ) 279 | transducer_label_smoothing:float= field( 280 | default=0.1, metadata ={'help':"label smoothing for ce loss"} 281 | ) 282 | transducer_temperature:float= field( 283 | default=1.0, metadata={"help":"temperature for output probs"} 284 | ) 285 | 286 | 287 | -------------------------------------------------------------------------------- /fs_plugins/datasets/transducer_speech_to_text_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from pathlib import Path 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | from dataclasses import dataclass 7 | from fairseq.data import data_utils as fairseq_data_utils 8 | from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset 9 | from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment 10 | from fairseq.data.audio.speech_to_text_dataset import ( 11 | _collate_frames, 12 | S2TDataConfig, 13 | SpeechToTextDatasetItem, 14 | SpeechToTextDataset, 15 | SpeechToTextDatasetCreator, 16 | TextTargetMultitaskData, 17 | _is_int_or_np_int, 18 | ) 19 | from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import ( 20 | NoisyOverlapAugment, 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class TransducerSpeechToTextDatasetItem(SpeechToTextDatasetItem): 28 | transcript: Optional[torch.Tensor] = None 29 | 30 | 31 | class TransducerSpeechToTextDataset(SpeechToTextDataset): 32 | 
""" 33 | Modified from SpeechToTextDataset. 34 | Prepend and append for prev_target. 35 | Only append for target. 36 | """ 37 | def __getitem__(self, index: int) -> TransducerSpeechToTextDatasetItem: 38 | has_concat = self.dataset_transforms.has_transform(ConcatAugment) 39 | if has_concat: 40 | concat = self.dataset_transforms.get_transform(ConcatAugment) 41 | indices = concat.find_indices(index, self.n_frames, self.n_samples) 42 | 43 | source = self._get_source_audio(indices if has_concat else index) 44 | source = self.pack_frames(source) 45 | 46 | target = None 47 | if self.tgt_texts is not None: 48 | tokenized = self.get_tokenized_tgt_text(indices if has_concat else index) 49 | target = self.tgt_dict.encode_line( 50 | tokenized, add_if_not_exist=False, append_eos=True, 51 | ).long() 52 | #bos = torch.LongTensor([self.tgt_dict.bos()]) 53 | #target = torch.cat((bos, target), 0) 54 | 55 | transcript = None 56 | if self.src_texts is not None: 57 | tokenized_transcript = self.get_tokenized_src_text(indices if has_concat else index) 58 | transcript = self.tgt_dict.encode_line( 59 | tokenized_transcript, add_if_not_exist=False, append_eos=True, 60 | ).long() 61 | #bos = torch.LongTensor([self.tgt_dict.bos()]) 62 | #transcript = torch.cat((bos, transcript), 0) 63 | 64 | speaker_id = None 65 | if self.speaker_to_id is not None: 66 | speaker_id = self.speaker_to_id[self.speakers[index]] 67 | return TransducerSpeechToTextDatasetItem( 68 | index=index, source=source, transcript=transcript, target=target, speaker_id=speaker_id 69 | ) 70 | 71 | def get_tokenized_src_text(self, index: Union[int, List[int]]): 72 | if _is_int_or_np_int(index): 73 | text = self.src_texts[index] 74 | else: 75 | text = " ".join([self.src_texts[i] for i in index]) 76 | 77 | text = self.tokenize(self.pre_tokenizer, text) 78 | text = self.tokenize(self.bpe_tokenizer, text) 79 | return text 80 | 81 | 82 | def collater( 83 | self, samples: List[TransducerSpeechToTextDatasetItem], return_order: bool = False 84 | ) -> Dict: 85 | if len(samples) == 0: 86 | return {} 87 | indices = torch.tensor([x.index for x in samples], dtype=torch.long) 88 | 89 | sources = [x.source for x in samples] 90 | has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment) 91 | if has_NOAug and self.cfg.use_audio_input: 92 | NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment) 93 | sources = NOAug(sources) 94 | 95 | frames = _collate_frames(sources, self.cfg.use_audio_input) 96 | # sort samples by descending number of frames 97 | n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long) 98 | n_frames, order = n_frames.sort(descending=True) 99 | indices = indices.index_select(0, order) 100 | frames = frames.index_select(0, order) 101 | 102 | target, target_lengths = None, None 103 | prev_output_tokens = None 104 | ntokens = None 105 | if self.tgt_texts is not None: 106 | target = fairseq_data_utils.collate_tokens( 107 | [x.target for x in samples], 108 | self.tgt_dict.pad(), 109 | self.tgt_dict.eos(), 110 | left_pad=False, 111 | move_eos_to_beginning=False, 112 | ) 113 | target = target.index_select(0, order) 114 | target_lengths = torch.tensor( 115 | [x.target.size(0) for x in samples], dtype=torch.long 116 | ).index_select(0, order) 117 | ntokens = sum(x.target.size(0) for x in samples) 118 | B = target.size(0) 119 | bos = torch.LongTensor([self.tgt_dict.bos()]).expand(B, 1) 120 | prev_output_tokens = torch.cat((bos, target), dim=-1) 121 | 122 | transcript, transcript_lengths = None, None 123 | 
prev_output_tokens_transcript = None 124 | ntokens_transcript = None 125 | if self.src_texts is not None: 126 | transcript = fairseq_data_utils.collate_tokens( 127 | [x.transcript for x in samples], 128 | self.tgt_dict.pad(), 129 | self.tgt_dict.eos(), 130 | left_pad=False, 131 | move_eos_to_beginning=False, 132 | ) 133 | transcript = transcript.index_select(0, order) 134 | transcript_lengths = torch.tensor( 135 | [x.transcript.size(0) for x in samples], dtype=torch.long 136 | ).index_select(0, order) 137 | ntokens_transcript = sum(x.transcript.size(0) for x in samples) 138 | B = transcript.size(0) 139 | bos = torch.LongTensor([self.tgt_dict.bos()]).expand(B, 1) 140 | prev_output_tokens_transcript = torch.cat((bos, transcript), dim=-1) 141 | 142 | speaker = None 143 | if self.speaker_to_id is not None: 144 | speaker = ( 145 | torch.tensor([s.speaker_id for s in samples], dtype=torch.long) 146 | .index_select(0, order) 147 | .view(-1, 1) 148 | ) 149 | 150 | net_input = { 151 | "src_tokens": frames, 152 | "src_lengths": n_frames, 153 | "prev_output_tokens": prev_output_tokens, 154 | "prev_output_tokens_transcript": prev_output_tokens_transcript, 155 | } 156 | out = { 157 | "id": indices, 158 | "net_input": net_input, 159 | "speaker": speaker, 160 | "target": target, 161 | "target_lengths": target_lengths, 162 | "transcript": transcript, 163 | "transcript_lengths": transcript_lengths, 164 | "ntokens": ntokens, 165 | "nsentences": len(samples), 166 | } 167 | if return_order: 168 | out["order"] = order 169 | return out 170 | 171 | 172 | class TransducerSpeechToTextDatasetCreator(SpeechToTextDatasetCreator): 173 | DEFAULT_TGT_TEXT = "" 174 | 175 | @classmethod 176 | def _from_list( 177 | cls, 178 | split_name: str, 179 | is_train_split, 180 | samples: List[Dict], 181 | cfg: S2TDataConfig, 182 | tgt_dict, 183 | pre_tokenizer, 184 | bpe_tokenizer, 185 | n_frames_per_step, 186 | speaker_to_id, 187 | multitask: Optional[Dict] = None, 188 | ) -> TransducerSpeechToTextDataset: 189 | audio_root = Path(cfg.audio_root) 190 | ids = [s[cls.KEY_ID] for s in samples] 191 | audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] 192 | n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] 193 | tgt_texts = [s.get(cls.KEY_TGT_TEXT, cls.DEFAULT_TGT_TEXT) for s in samples] 194 | src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] 195 | speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] 196 | src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] 197 | tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] 198 | 199 | #has_multitask = multitask is not None and len(multitask.keys()) > 0 200 | #dataset_cls = ( 201 | # NATSpeechToTextMultitaskDataset if has_multitask else NATSpeechToTextDataset 202 | #) 203 | dataset_cls = TransducerSpeechToTextDataset 204 | 205 | ds = dataset_cls( 206 | split=split_name, 207 | is_train_split=is_train_split, 208 | cfg=cfg, 209 | audio_paths=audio_paths, 210 | n_frames=n_frames, 211 | src_texts=src_texts, 212 | tgt_texts=tgt_texts, 213 | speakers=speakers, 214 | src_langs=src_langs, 215 | tgt_langs=tgt_langs, 216 | ids=ids, 217 | tgt_dict=tgt_dict, 218 | pre_tokenizer=pre_tokenizer, 219 | bpe_tokenizer=bpe_tokenizer, 220 | n_frames_per_step=n_frames_per_step, 221 | speaker_to_id=speaker_to_id, 222 | ) 223 | 224 | return ds 225 | 226 | @classmethod 227 | def _from_tsv( 228 | cls, 229 | root: str, 230 | cfg: S2TDataConfig, 231 | split: str, 232 | tgt_dict, 233 | is_train_split: bool, 
234 | pre_tokenizer, 235 | bpe_tokenizer, 236 | n_frames_per_step, 237 | speaker_to_id, 238 | multitask: Optional[Dict] = None, 239 | ) -> TransducerSpeechToTextDataset: 240 | samples = cls._load_samples_from_tsv(root, split) 241 | return cls._from_list( 242 | split, 243 | is_train_split, 244 | samples, 245 | cfg, 246 | tgt_dict, 247 | pre_tokenizer, 248 | bpe_tokenizer, 249 | n_frames_per_step, 250 | speaker_to_id, 251 | multitask, 252 | ) 253 | 254 | @classmethod 255 | def from_tsv( 256 | cls, 257 | root: str, 258 | cfg: S2TDataConfig, 259 | splits: str, 260 | tgt_dict, 261 | pre_tokenizer, 262 | bpe_tokenizer, 263 | is_train_split: bool, 264 | epoch: int, 265 | seed: int, 266 | n_frames_per_step: int = 1, 267 | speaker_to_id=None, 268 | multitask: Optional[Dict] = None, 269 | ) -> TransducerSpeechToTextDataset: 270 | datasets = [ 271 | cls._from_tsv( 272 | root=root, 273 | cfg=cfg, 274 | split=split, 275 | tgt_dict=tgt_dict, 276 | is_train_split=is_train_split, 277 | pre_tokenizer=pre_tokenizer, 278 | bpe_tokenizer=bpe_tokenizer, 279 | n_frames_per_step=n_frames_per_step, 280 | speaker_to_id=speaker_to_id, 281 | multitask=multitask, 282 | ) 283 | for split in splits.split(",") 284 | ] 285 | 286 | if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: 287 | # temperature-based sampling 288 | size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) 289 | datasets = [ 290 | ResamplingDataset( 291 | d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) 292 | ) 293 | for r, d in zip(size_ratios, datasets) 294 | ] 295 | 296 | return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] -------------------------------------------------------------------------------- /fs_plugins/agents/transducer_agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | from typing import Dict, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | import torchaudio.compliance.kaldi as kaldi 11 | import yaml 12 | 13 | from fairseq import checkpoint_utils, tasks, utils 14 | from fairseq.file_io import PathManager 15 | from examples.speech_to_text.data_utils import extract_fbank_features 16 | 17 | 18 | from simuleval.utils import entrypoint 19 | from simuleval.data.segments import EmptySegment, TextSegment, SpeechSegment 20 | from simuleval.agents import SpeechToTextAgent 21 | from simuleval.agents.states import AgentStates 22 | from simuleval.agents.actions import WriteAction, ReadAction 23 | 24 | 25 | import pdb 26 | 27 | SHIFT_SIZE = 10 28 | WINDOW_SIZE = 25 29 | SAMPLE_RATE = 16000 30 | FEATURE_DIM = 80 31 | BOW_PREFIX = "\u2581" 32 | DEFAULT_BOS = 0 33 | DEFAULT_EOS = 2 34 | 35 | 36 | class OfflineFeatureExtractor: 37 | """ 38 | Extract speech feature from sequence prefix. 
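    Each call receives the audio samples collected so far, trims them to a
    whole number of analysis windows (shift_size and window_size are given in
    milliseconds), recomputes log-mel filterbank features with
    extract_fbank_features, applies global mean/variance normalization when
    global_cmvn stats are configured, and returns the features as a tensor on
    the agent's device.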
39 | """ 40 | 41 | def __init__(self, args): 42 | self.shift_size = args.shift_size 43 | self.window_size = args.window_size 44 | assert self.window_size >= self.shift_size 45 | 46 | self.sample_rate = args.sample_rate 47 | self.feature_dim = args.feature_dim 48 | self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) 49 | self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) 50 | self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 51 | self.global_cmvn = args.global_cmvn 52 | self.device = 'cuda' if args.device == 'gpu' else 'cpu' 53 | 54 | 55 | def __call__(self, new_samples): 56 | samples = new_samples 57 | 58 | assert len(samples) >= self.num_samples_per_window 59 | 60 | num_frames = math.floor( 61 | (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) 62 | / self.num_samples_per_shift 63 | ) 64 | 65 | effective_num_samples = int( 66 | num_frames * self.len_ms_to_samples(self.shift_size) 67 | + self.len_ms_to_samples(self.window_size - self.shift_size) 68 | ) 69 | 70 | input_samples = samples[:effective_num_samples] 71 | 72 | torch.manual_seed(1) 73 | output = extract_fbank_features(torch.FloatTensor(input_samples).unsqueeze(0), self.sample_rate) 74 | 75 | output = self.transform(output) 76 | 77 | return torch.from_numpy(output).to(self.device) 78 | 79 | def transform(self, input): 80 | if self.global_cmvn is None: 81 | return input 82 | 83 | mean = self.global_cmvn["mean"] 84 | std = self.global_cmvn["std"] 85 | 86 | x = np.subtract(input, mean) 87 | x = np.divide(x, std) 88 | return x 89 | 90 | 91 | class TransducerSpeechToTextAgentStates(AgentStates): 92 | 93 | def __init__(self, device) -> None: 94 | self.device = device 95 | self.reset() 96 | 97 | def reset(self) -> None: 98 | """Reset Agent states""" 99 | 100 | super().reset() 101 | 102 | self.num_complete_chunk = 0 103 | self.prev_output_tokens = torch.tensor([[DEFAULT_BOS]]).long().to(self.device) 104 | 105 | self.unfinished_subword = [] 106 | self.incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) 107 | 108 | 109 | 110 | @entrypoint 111 | class TransducerSpeechToTextAgent(SpeechToTextAgent): 112 | 113 | speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size 114 | 115 | def __init__(self, args): 116 | super().__init__(args) 117 | 118 | self.device ='cuda' if args.device == 'gpu' else 'cpu' 119 | self.states = self.build_states() 120 | args.global_cmvn = None 121 | if args.config_yaml: 122 | with open(os.path.join(args.data_bin, args.config_yaml), "r") as f: 123 | config = yaml.load(f, Loader=yaml.BaseLoader) 124 | 125 | if "global_cmvn" in config: 126 | args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) 127 | 128 | 129 | self.load_model_vocab(args) 130 | #utils.import_user_module(args) 131 | 132 | self.feature_extractor = OfflineFeatureExtractor(args) 133 | 134 | self.downsample = args.transducer_downsample 135 | self.main_context = args.main_context 136 | self.right_context = args.right_context 137 | 138 | torch.set_grad_enabled(False) 139 | self.reset() 140 | 141 | 142 | def build_states(self) -> TransducerSpeechToTextAgentStates: 143 | """ 144 | Build states instance for agent 145 | 146 | Returns: 147 | TransducerSpeechToTextAgentStates: agent states 148 | """ 149 | return TransducerSpeechToTextAgentStates(self.device) 150 | 151 | @staticmethod 152 | def add_args(parser): 153 | # fmt: off 154 | parser.add_argument('--model-path', type=str, required=True, 155 | help='path to your 
pretrained model.') 156 | parser.add_argument("--data-bin", type=str, required=True, 157 | help="Path of data binary") 158 | parser.add_argument("--config-yaml", type=str, default=None, 159 | help="Path to config yaml file") 160 | parser.add_argument("--global-stats", type=str, default=None, 161 | help="Path to json file containing cmvn stats") 162 | parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", 163 | help="Subword splitter type for target text") 164 | parser.add_argument("--tgt-splitter-path", type=str, default=None, 165 | help="Subword splitter model path for target text") 166 | parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", 167 | help="User directory for simultaneous translation") 168 | parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, 169 | help="Shift size of feature extraction window.") 170 | parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, 171 | help="Window size of feature extraction window.") 172 | parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, 173 | help="Sample rate") 174 | parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, 175 | help="Acoustic feature dimension.") 176 | parser.add_argument("--main-context", type=int, default=32) 177 | parser.add_argument("--right-context", type=int, default=16) 178 | parser.add_argument("--transducer-downsample", type=int, default=1) 179 | parser.add_argument("--device", type=str, default='gpu') 180 | 181 | # fmt: on 182 | return parser 183 | 184 | def load_model_vocab(self, args): 185 | 186 | filename = args.model_path 187 | if not os.path.exists(filename): 188 | raise IOError("Model file not found: {}".format(filename)) 189 | 190 | state = checkpoint_utils.load_checkpoint_to_cpu(filename) 191 | utils.import_user_module(state["cfg"].common) 192 | 193 | task_args = state["cfg"]["task"] 194 | task_args.data = args.data_bin 195 | 196 | if args.config_yaml is not None: 197 | task_args.config_yaml = args.config_yaml 198 | 199 | task = tasks.setup_task(task_args) 200 | 201 | self.model = task.build_model(state["cfg"]["model"]) 202 | self.model.load_state_dict(state["model"], strict=True) 203 | self.model.eval() 204 | self.model.share_memory() 205 | 206 | if self.device == 'cuda': 207 | self.model.cuda() 208 | 209 | # Set dictionary 210 | self.tgt_dict = task.target_dictionary 211 | 212 | @torch.inference_mode() 213 | def policy(self): 214 | 215 | num_frames = math.floor( 216 | (len(self.states.source) - self.feature_extractor.len_ms_to_samples(self.feature_extractor.window_size - self.feature_extractor.shift_size)) 217 | / self.feature_extractor.num_samples_per_shift 218 | ) 219 | 220 | # at least a new complete chunk is received if not finished 221 | if not self.states.source_finished: 222 | if num_frames < self.main_context * (self.states.num_complete_chunk + 1) + self.right_context: 223 | return ReadAction() 224 | 225 | # this is used to caluculate self.states.num_complete_chunk 226 | num_complete_new_chunk = math.floor((num_frames - self.right_context) / self.main_context) - self.states.num_complete_chunk 227 | 228 | # Calculated the number of frames to make decisions 229 | if not self.states.source_finished: 230 | num_decision = num_complete_new_chunk * int(self.main_context / 4 / self.downsample) 231 | else: 232 | num_decision = None 233 | 234 | feature = self.feature_extractor(self.states.source) # prefix feature: T × C 235 | assert num_frames == feature.size(0) 236 | src_tokens = 
feature.unsqueeze(0) # 1 × T × C 237 | src_lengths = torch.tensor([feature.size(0)], device=self.device).long() # 1 238 | 239 | encoder_out = self.model.encoder(src_tokens, src_lengths) 240 | 241 | downsampled_encoder_out = encoder_out["encoder_out"][0][self.states.num_complete_chunk * (self.main_context // 4):][::self.downsample].squeeze(1) # num_decision × C 242 | downsampled_encoder_out = downsampled_encoder_out[:num_decision] 243 | 244 | final_output_tokens = [] 245 | 246 | for i in range(downsampled_encoder_out.size(0)): 247 | ii=0 # max emit per frame 248 | while True: 249 | #pdb.set_trace() 250 | h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, self.states.incremental_state)[:, -1].squeeze() # V 251 | joint_result = self.model.decoder.jointer.infer(downsampled_encoder_out[i], h_lm_last) # V 252 | 253 | log_probs = F.log_softmax(self.model.decoder.out_proj(joint_result), dim=-1) 254 | select_word = log_probs.argmax(dim=-1).item() 255 | 256 | if select_word == self.tgt_dict.blank_index: 257 | break 258 | self.states.prev_output_tokens = torch.cat([self.states.prev_output_tokens, torch.tensor([[select_word]], device=self.states.prev_output_tokens.device)], dim=-1) 259 | final_output_tokens.append(select_word) 260 | ii += 1 261 | if ii == 5: 262 | break 263 | #if select_word == 416: 264 | # break 265 | 266 | self.states.num_complete_chunk = self.states.num_complete_chunk + num_complete_new_chunk 267 | 268 | 269 | detok_output_tokens = [] 270 | 271 | for index in final_output_tokens: 272 | token = self.tgt_dict.string([index]) #return a string 273 | if token.startswith(BOW_PREFIX): 274 | if len(self.states.unfinished_subword) != 0: 275 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 276 | self.states.unfinished_subword = [] 277 | self.states.unfinished_subword += [token.replace(BOW_PREFIX, "")] 278 | else: 279 | self.states.unfinished_subword += [token] 280 | 281 | if self.states.source_finished: 282 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 283 | self.states.unfinished_subword = [] 284 | 285 | 286 | detok_output_string = " ".join(detok_output_tokens) 287 | 288 | return WriteAction(TextSegment(content=detok_output_string, finished=self.states.source_finished), finished=self.states.source_finished) 289 | 290 | 291 | -------------------------------------------------------------------------------- /fs_plugins/agents/transducer_agent_v2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | from typing import Dict, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | import torchaudio.compliance.kaldi as kaldi 11 | import yaml 12 | 13 | from fairseq import checkpoint_utils, tasks, utils 14 | from fairseq.file_io import PathManager 15 | from examples.speech_to_text.data_utils import extract_fbank_features 16 | 17 | 18 | from simuleval.utils import entrypoint 19 | from simuleval.data.segments import EmptySegment, TextSegment, SpeechSegment 20 | from simuleval.agents import SpeechToTextAgent 21 | from simuleval.agents.states import AgentStates 22 | from simuleval.agents.actions import WriteAction, ReadAction 23 | 24 | 25 | import pdb 26 | 27 | SHIFT_SIZE = 10 28 | WINDOW_SIZE = 25 29 | SAMPLE_RATE = 16000 30 | FEATURE_DIM = 80 31 | BOW_PREFIX = "\u2581" 32 | DEFAULT_BOS = 0 33 | DEFAULT_EOS = 2 34 | 35 | 36 | class OfflineFeatureExtractor: 37 | """ 38 | Extract speech feature from 
sequence prefix. 39 | """ 40 | 41 | def __init__(self, args): 42 | self.shift_size = args.shift_size 43 | self.window_size = args.window_size 44 | assert self.window_size >= self.shift_size 45 | 46 | self.sample_rate = args.sample_rate 47 | self.feature_dim = args.feature_dim 48 | self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) 49 | self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) 50 | self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 51 | self.global_cmvn = args.global_cmvn 52 | self.device = 'cuda' if args.device == 'gpu' else 'cpu' 53 | 54 | 55 | def __call__(self, new_samples): 56 | samples = new_samples 57 | 58 | assert len(samples) >= self.num_samples_per_window 59 | 60 | num_frames = math.floor( 61 | (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) 62 | / self.num_samples_per_shift 63 | ) 64 | 65 | effective_num_samples = int( 66 | num_frames * self.len_ms_to_samples(self.shift_size) 67 | + self.len_ms_to_samples(self.window_size - self.shift_size) 68 | ) 69 | 70 | input_samples = samples[:effective_num_samples] 71 | 72 | torch.manual_seed(1) 73 | output = extract_fbank_features(torch.FloatTensor(input_samples).unsqueeze(0), self.sample_rate) 74 | 75 | output = self.transform(output) 76 | 77 | return torch.from_numpy(output).to(self.device) 78 | 79 | def transform(self, input): 80 | if self.global_cmvn is None: 81 | return input 82 | 83 | mean = self.global_cmvn["mean"] 84 | std = self.global_cmvn["std"] 85 | 86 | x = np.subtract(input, mean) 87 | x = np.divide(x, std) 88 | return x 89 | 90 | 91 | class TransducerSpeechToTextAgentStates(AgentStates): 92 | 93 | def __init__(self, device) -> None: 94 | self.device = device 95 | self.reset() 96 | 97 | def reset(self) -> None: 98 | """Reset Agent states""" 99 | 100 | super().reset() 101 | 102 | self.num_complete_chunk = 0 103 | self.prev_output_tokens = torch.tensor([[DEFAULT_BOS]]).long().to(self.device) 104 | 105 | self.unfinished_subword = [] 106 | self.incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) 107 | self.h_lm_last = None 108 | 109 | 110 | 111 | @entrypoint 112 | class TransducerSpeechToTextAgent(SpeechToTextAgent): 113 | 114 | speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size 115 | 116 | def __init__(self, args): 117 | super().__init__(args) 118 | 119 | self.device ='cuda' if args.device == 'gpu' else 'cpu' 120 | self.states = self.build_states() 121 | args.global_cmvn = None 122 | if args.config_yaml: 123 | with open(os.path.join(args.data_bin, args.config_yaml), "r") as f: 124 | config = yaml.load(f, Loader=yaml.BaseLoader) 125 | 126 | if "global_cmvn" in config: 127 | args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) 128 | 129 | 130 | self.load_model_vocab(args) 131 | #utils.import_user_module(args) 132 | 133 | self.feature_extractor = OfflineFeatureExtractor(args) 134 | 135 | self.downsample = args.transducer_downsample 136 | self.main_context = args.main_context 137 | self.right_context = args.right_context 138 | 139 | torch.set_grad_enabled(False) 140 | self.reset() 141 | 142 | 143 | def build_states(self) -> TransducerSpeechToTextAgentStates: 144 | """ 145 | Build states instance for agent 146 | 147 | Returns: 148 | TransducerSpeechToTextAgentStates: agent states 149 | """ 150 | return TransducerSpeechToTextAgentStates(self.device) 151 | 152 | @staticmethod 153 | def add_args(parser): 154 | # fmt: off 155 | parser.add_argument('--model-path', type=str, 
required=True, 156 | help='path to your pretrained model.') 157 | parser.add_argument("--data-bin", type=str, required=True, 158 | help="Path of data binary") 159 | parser.add_argument("--config-yaml", type=str, default=None, 160 | help="Path to config yaml file") 161 | parser.add_argument("--global-stats", type=str, default=None, 162 | help="Path to json file containing cmvn stats") 163 | parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", 164 | help="Subword splitter type for target text") 165 | parser.add_argument("--tgt-splitter-path", type=str, default=None, 166 | help="Subword splitter model path for target text") 167 | parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", 168 | help="User directory for simultaneous translation") 169 | parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, 170 | help="Shift size of feature extraction window.") 171 | parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, 172 | help="Window size of feature extraction window.") 173 | parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, 174 | help="Sample rate") 175 | parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, 176 | help="Acoustic feature dimension.") 177 | parser.add_argument("--main-context", type=int, default=32) 178 | parser.add_argument("--right-context", type=int, default=16) 179 | parser.add_argument("--transducer-downsample", type=int, default=1) 180 | parser.add_argument("--device", type=str, default='gpu') 181 | 182 | # fmt: on 183 | return parser 184 | 185 | def load_model_vocab(self, args): 186 | 187 | filename = args.model_path 188 | if not os.path.exists(filename): 189 | raise IOError("Model file not found: {}".format(filename)) 190 | 191 | state = checkpoint_utils.load_checkpoint_to_cpu(filename) 192 | utils.import_user_module(state["cfg"].common) 193 | 194 | task_args = state["cfg"]["task"] 195 | task_args.data = args.data_bin 196 | 197 | if args.config_yaml is not None: 198 | task_args.config_yaml = args.config_yaml 199 | 200 | task = tasks.setup_task(task_args) 201 | 202 | self.model = task.build_model(state["cfg"]["model"]) 203 | self.model.load_state_dict(state["model"], strict=True) 204 | self.model.eval() 205 | self.model.share_memory() 206 | 207 | if self.device == 'cuda': 208 | self.model.cuda() 209 | 210 | # Set dictionary 211 | self.tgt_dict = task.target_dictionary 212 | 213 | @torch.inference_mode() 214 | def policy(self): 215 | 216 | num_frames = math.floor( 217 | (len(self.states.source) - self.feature_extractor.len_ms_to_samples(self.feature_extractor.window_size - self.feature_extractor.shift_size)) 218 | / self.feature_extractor.num_samples_per_shift 219 | ) 220 | 221 | # at least a new complete chunk is received if not finished 222 | if not self.states.source_finished: 223 | if num_frames < self.main_context * (self.states.num_complete_chunk + 1) + self.right_context: 224 | return ReadAction() 225 | 226 | # this is used to caluculate self.states.num_complete_chunk 227 | num_complete_new_chunk = math.floor((num_frames - self.right_context) / self.main_context) - self.states.num_complete_chunk 228 | 229 | # Calculated the number of frames to make decisions 230 | if not self.states.source_finished: 231 | num_decision = num_complete_new_chunk * int(self.main_context / 4 / self.downsample) 232 | else: 233 | num_decision = None 234 | 235 | feature = self.feature_extractor(self.states.source) # prefix feature: T × C 236 | assert num_frames == 
feature.size(0) 237 | src_tokens = feature.unsqueeze(0) # 1 × T × C 238 | src_lengths = torch.tensor([feature.size(0)], device=self.device).long() # 1 239 | 240 | encoder_out = self.model.encoder(src_tokens, src_lengths) 241 | 242 | downsampled_encoder_out = encoder_out["encoder_out"][0][self.states.num_complete_chunk * (self.main_context // 4):][::self.downsample].squeeze(1) # num_decision × C 243 | downsampled_encoder_out = downsampled_encoder_out[:num_decision] 244 | 245 | final_output_tokens = [] 246 | 247 | if self.states.h_lm_last is None: 248 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, self.states.incremental_state)[:, -1].squeeze() # V 249 | 250 | 251 | for i in range(downsampled_encoder_out.size(0)): 252 | ii=0 # max emit per frame 253 | while True: 254 | #pdb.set_trace() 255 | joint_result = self.model.decoder.jointer.infer(downsampled_encoder_out[i], self.states.h_lm_last) # V 256 | 257 | log_probs = F.log_softmax(self.model.decoder.out_proj(joint_result), dim=-1) 258 | select_word = log_probs.argmax(dim=-1).item() 259 | 260 | if select_word == self.tgt_dict.blank_index: 261 | break 262 | self.states.prev_output_tokens = torch.cat([self.states.prev_output_tokens, torch.tensor([[select_word]], device=self.states.prev_output_tokens.device)], dim=-1) 263 | final_output_tokens.append(select_word) 264 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, self.states.incremental_state)[:, -1].squeeze() # V 265 | ii += 1 266 | if ii == 5: 267 | break 268 | #if select_word == 416: 269 | # break 270 | 271 | self.states.num_complete_chunk = self.states.num_complete_chunk + num_complete_new_chunk 272 | 273 | 274 | detok_output_tokens = [] 275 | 276 | for index in final_output_tokens: 277 | token = self.tgt_dict.string([index]) #return a string 278 | if token.startswith(BOW_PREFIX): 279 | if len(self.states.unfinished_subword) != 0: 280 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 281 | self.states.unfinished_subword = [] 282 | self.states.unfinished_subword += [token.replace(BOW_PREFIX, "")] 283 | else: 284 | self.states.unfinished_subword += [token] 285 | 286 | if self.states.source_finished: 287 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 288 | self.states.unfinished_subword = [] 289 | 290 | 291 | detok_output_string = " ".join(detok_output_tokens) 292 | 293 | return WriteAction(TextSegment(content=detok_output_string, finished=self.states.source_finished), finished=self.states.source_finished) 294 | 295 | 296 | -------------------------------------------------------------------------------- /fs_plugins/agents/attention_transducer_agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | from typing import Dict, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | import torchaudio.compliance.kaldi as kaldi 11 | import yaml 12 | 13 | from fairseq import checkpoint_utils, tasks, utils 14 | from fairseq.file_io import PathManager 15 | from examples.speech_to_text.data_utils import extract_fbank_features 16 | 17 | 18 | from simuleval.utils import entrypoint 19 | from simuleval.data.segments import EmptySegment, TextSegment, SpeechSegment 20 | from simuleval.agents import SpeechToTextAgent 21 | from simuleval.agents.states import AgentStates 22 | from simuleval.agents.actions import WriteAction, ReadAction 23 | 24 | 25 | import pdb 26 | 
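# Shared feature-extraction and vocabulary constants: frame shift and analysis
# window of the filterbank frontend (in ms), input sample rate (Hz), number of
# filterbank channels, the SentencePiece word-boundary marker "\u2581", and the
# default fairseq dictionary indices used for <s> and </s>.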
27 | SHIFT_SIZE = 10 28 | WINDOW_SIZE = 25 29 | SAMPLE_RATE = 16000 30 | FEATURE_DIM = 80 31 | BOW_PREFIX = "\u2581" 32 | DEFAULT_BOS = 0 33 | DEFAULT_EOS = 2 34 | 35 | 36 | class OfflineFeatureExtractor: 37 | """ 38 | Extract speech feature from sequence prefix. 39 | """ 40 | 41 | def __init__(self, args): 42 | self.shift_size = args.shift_size 43 | self.window_size = args.window_size 44 | assert self.window_size >= self.shift_size 45 | 46 | self.sample_rate = args.sample_rate 47 | self.feature_dim = args.feature_dim 48 | self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) 49 | self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) 50 | self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 51 | self.global_cmvn = args.global_cmvn 52 | self.device = 'cuda' if args.device == 'gpu' else 'cpu' 53 | 54 | 55 | def __call__(self, new_samples): 56 | samples = new_samples 57 | 58 | assert len(samples) >= self.num_samples_per_window 59 | 60 | num_frames = math.floor( 61 | (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) 62 | / self.num_samples_per_shift 63 | ) 64 | 65 | effective_num_samples = int( 66 | num_frames * self.len_ms_to_samples(self.shift_size) 67 | + self.len_ms_to_samples(self.window_size - self.shift_size) 68 | ) 69 | 70 | input_samples = samples[:effective_num_samples] 71 | 72 | torch.manual_seed(1) 73 | output = extract_fbank_features(torch.FloatTensor(input_samples).unsqueeze(0), self.sample_rate) 74 | 75 | output = self.transform(output) 76 | 77 | return torch.from_numpy(output).to(self.device) 78 | 79 | def transform(self, input): 80 | if self.global_cmvn is None: 81 | return input 82 | 83 | mean = self.global_cmvn["mean"] 84 | std = self.global_cmvn["std"] 85 | 86 | x = np.subtract(input, mean) 87 | x = np.divide(x, std) 88 | return x 89 | 90 | 91 | class TransducerSpeechToTextAgentStates(AgentStates): 92 | 93 | def __init__(self, device) -> None: 94 | self.device = device 95 | self.reset() 96 | 97 | def reset(self) -> None: 98 | """Reset Agent states""" 99 | 100 | super().reset() 101 | 102 | self.num_complete_chunk = 0 103 | self.prev_output_tokens = torch.tensor([[DEFAULT_BOS]]).long().to(self.device) 104 | 105 | self.unfinished_subword = [] 106 | self.incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) 107 | self.h_lm_last = None 108 | 109 | 110 | 111 | @entrypoint 112 | class TransducerSpeechToTextAgent(SpeechToTextAgent): 113 | 114 | speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size 115 | 116 | def __init__(self, args): 117 | super().__init__(args) 118 | 119 | self.device ='cuda' if args.device == 'gpu' else 'cpu' 120 | self.states = self.build_states() 121 | args.global_cmvn = None 122 | if args.config_yaml: 123 | with open(os.path.join(args.data_bin, args.config_yaml), "r") as f: 124 | config = yaml.load(f, Loader=yaml.BaseLoader) 125 | 126 | if "global_cmvn" in config: 127 | args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) 128 | 129 | 130 | self.load_model_vocab(args) 131 | #utils.import_user_module(args) 132 | 133 | self.feature_extractor = OfflineFeatureExtractor(args) 134 | 135 | self.downsample = args.transducer_downsample 136 | self.main_context = args.main_context 137 | self.right_context = args.right_context 138 | 139 | torch.set_grad_enabled(False) 140 | self.reset() 141 | 142 | 143 | def build_states(self) -> TransducerSpeechToTextAgentStates: 144 | """ 145 | Build states instance for agent 146 | 147 | Returns: 
148 | TransducerSpeechToTextAgentStates: agent states 149 | """ 150 | return TransducerSpeechToTextAgentStates(self.device) 151 | 152 | @staticmethod 153 | def add_args(parser): 154 | # fmt: off 155 | parser.add_argument('--model-path', type=str, required=True, 156 | help='path to your pretrained model.') 157 | parser.add_argument("--data-bin", type=str, required=True, 158 | help="Path of data binary") 159 | parser.add_argument("--config-yaml", type=str, default=None, 160 | help="Path to config yaml file") 161 | parser.add_argument("--global-stats", type=str, default=None, 162 | help="Path to json file containing cmvn stats") 163 | parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", 164 | help="Subword splitter type for target text") 165 | parser.add_argument("--tgt-splitter-path", type=str, default=None, 166 | help="Subword splitter model path for target text") 167 | parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", 168 | help="User directory for simultaneous translation") 169 | parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, 170 | help="Shift size of feature extraction window.") 171 | parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, 172 | help="Window size of feature extraction window.") 173 | parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, 174 | help="Sample rate") 175 | parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, 176 | help="Acoustic feature dimension.") 177 | parser.add_argument("--main-context", type=int, default=32) 178 | parser.add_argument("--right-context", type=int, default=16) 179 | parser.add_argument("--transducer-downsample", type=int, default=1) 180 | parser.add_argument("--device", type=str, default='gpu') 181 | 182 | # fmt: on 183 | return parser 184 | 185 | def load_model_vocab(self, args): 186 | 187 | filename = args.model_path 188 | if not os.path.exists(filename): 189 | raise IOError("Model file not found: {}".format(filename)) 190 | 191 | state = checkpoint_utils.load_checkpoint_to_cpu(filename) 192 | utils.import_user_module(state["cfg"].common) 193 | 194 | task_args = state["cfg"]["task"] 195 | task_args.data = args.data_bin 196 | 197 | if args.config_yaml is not None: 198 | task_args.config_yaml = args.config_yaml 199 | 200 | task = tasks.setup_task(task_args) 201 | 202 | self.model = task.build_model(state["cfg"]["model"]) 203 | self.model.load_state_dict(state["model"], strict=True) 204 | self.model.eval() 205 | self.model.share_memory() 206 | 207 | if self.device == 'cuda': 208 | self.model.cuda() 209 | 210 | # Set dictionary 211 | self.tgt_dict = task.target_dictionary 212 | 213 | @torch.inference_mode() 214 | def policy(self): 215 | 216 | num_frames = math.floor( 217 | (len(self.states.source) - self.feature_extractor.len_ms_to_samples(self.feature_extractor.window_size - self.feature_extractor.shift_size)) 218 | / self.feature_extractor.num_samples_per_shift 219 | ) 220 | 221 | # at least a new complete chunk is received if not finished 222 | if not self.states.source_finished: 223 | if num_frames < self.main_context * (self.states.num_complete_chunk + 1) + self.right_context: 224 | return ReadAction() 225 | 226 | # this is used to caluculate self.states.num_complete_chunk 227 | num_complete_new_chunk = math.floor((num_frames - self.right_context) / self.main_context) - self.states.num_complete_chunk 228 | 229 | # Calculated the number of frames to make decisions 230 | if not self.states.source_finished: 231 | 
num_decision = num_complete_new_chunk * int(self.main_context / 4 / self.downsample) 232 | else: 233 | num_decision = None 234 | 235 | feature = self.feature_extractor(self.states.source) # prefix feature: T × C 236 | assert num_frames == feature.size(0) 237 | src_tokens = feature.unsqueeze(0) # 1 × T × C 238 | src_lengths = torch.tensor([feature.size(0)], device=self.device).long() # 1 239 | 240 | encoder_out = self.model.encoder(src_tokens, src_lengths) 241 | 242 | downsampled_encoder_out = encoder_out["encoder_out"][0][self.states.num_complete_chunk * (self.main_context // 4):][::self.downsample].squeeze(1) # num_decision × C 243 | downsampled_encoder_out = downsampled_encoder_out[:num_decision] 244 | 245 | final_output_tokens = [] 246 | 247 | if self.states.h_lm_last is None: 248 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, encoder_out, incremental_state=self.states.incremental_state)[:, -1].squeeze() # V 249 | 250 | for i in range(downsampled_encoder_out.size(0)): 251 | ii=0 # max emit per frame 252 | while True: 253 | #pdb.set_trace() 254 | joint_result = self.model.decoder.jointer.infer(downsampled_encoder_out[i], self.states.h_lm_last) # V 255 | 256 | log_probs = F.log_softmax(self.model.decoder.out_proj(joint_result), dim=-1) 257 | select_word = log_probs.argmax(dim=-1).item() 258 | 259 | if select_word == self.tgt_dict.blank_index: 260 | break 261 | self.states.prev_output_tokens = torch.cat([self.states.prev_output_tokens, torch.tensor([[select_word]], device=self.states.prev_output_tokens.device)], dim=-1) 262 | final_output_tokens.append(select_word) 263 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, encoder_out, incremental_state=self.states.incremental_state)[:, -1].squeeze() # V 264 | ii += 1 265 | if ii == 5: 266 | break 267 | #if select_word == 416: 268 | # break 269 | 270 | self.states.num_complete_chunk = self.states.num_complete_chunk + num_complete_new_chunk 271 | 272 | 273 | detok_output_tokens = [] 274 | 275 | for index in final_output_tokens: 276 | token = self.tgt_dict.string([index]) #return a string 277 | if token.startswith(BOW_PREFIX): 278 | if len(self.states.unfinished_subword) != 0: 279 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 280 | self.states.unfinished_subword = [] 281 | self.states.unfinished_subword += [token.replace(BOW_PREFIX, "")] 282 | else: 283 | self.states.unfinished_subword += [token] 284 | 285 | if self.states.source_finished: 286 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 287 | self.states.unfinished_subword = [] 288 | 289 | 290 | detok_output_string = " ".join(detok_output_tokens) 291 | 292 | return WriteAction(TextSegment(content=detok_output_string, finished=self.states.source_finished), finished=self.states.source_finished) 293 | 294 | 295 | -------------------------------------------------------------------------------- /fs_plugins/agents/monotonic_transducer_agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | from typing import Dict, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | import torchaudio.compliance.kaldi as kaldi 11 | import yaml 12 | 13 | from fairseq import checkpoint_utils, tasks, utils 14 | from fairseq.file_io import PathManager 15 | from examples.speech_to_text.data_utils import extract_fbank_features 16 | 17 | 18 | from simuleval.utils import 
entrypoint 19 | from simuleval.data.segments import EmptySegment, TextSegment, SpeechSegment 20 | from simuleval.agents import SpeechToTextAgent 21 | from simuleval.agents.states import AgentStates 22 | from simuleval.agents.actions import WriteAction, ReadAction 23 | 24 | 25 | import pdb 26 | 27 | SHIFT_SIZE = 10 28 | WINDOW_SIZE = 25 29 | SAMPLE_RATE = 16000 30 | FEATURE_DIM = 80 31 | BOW_PREFIX = "\u2581" 32 | DEFAULT_BOS = 0 33 | DEFAULT_EOS = 2 34 | 35 | 36 | class OfflineFeatureExtractor: 37 | """ 38 | Extract speech feature from sequence prefix. 39 | """ 40 | 41 | def __init__(self, args): 42 | self.shift_size = args.shift_size 43 | self.window_size = args.window_size 44 | assert self.window_size >= self.shift_size 45 | 46 | self.sample_rate = args.sample_rate 47 | self.feature_dim = args.feature_dim 48 | self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) 49 | self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) 50 | self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 51 | self.global_cmvn = args.global_cmvn 52 | self.device = 'cuda' if args.device == 'gpu' else 'cpu' 53 | 54 | 55 | def __call__(self, new_samples): 56 | samples = new_samples 57 | 58 | assert len(samples) >= self.num_samples_per_window 59 | 60 | num_frames = math.floor( 61 | (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) 62 | / self.num_samples_per_shift 63 | ) 64 | 65 | effective_num_samples = int( 66 | num_frames * self.len_ms_to_samples(self.shift_size) 67 | + self.len_ms_to_samples(self.window_size - self.shift_size) 68 | ) 69 | 70 | input_samples = samples[:effective_num_samples] 71 | 72 | torch.manual_seed(1) 73 | output = extract_fbank_features(torch.FloatTensor(input_samples).unsqueeze(0), self.sample_rate) 74 | 75 | output = self.transform(output) 76 | 77 | return torch.from_numpy(output).to(self.device) 78 | 79 | def transform(self, input): 80 | if self.global_cmvn is None: 81 | return input 82 | 83 | mean = self.global_cmvn["mean"] 84 | std = self.global_cmvn["std"] 85 | 86 | x = np.subtract(input, mean) 87 | x = np.divide(x, std) 88 | return x 89 | 90 | 91 | class TransducerSpeechToTextAgentStates(AgentStates): 92 | 93 | def __init__(self, device) -> None: 94 | self.device = device 95 | self.reset() 96 | 97 | def reset(self) -> None: 98 | """Reset Agent states""" 99 | 100 | super().reset() 101 | 102 | self.num_complete_chunk = 0 103 | self.prev_output_tokens = torch.tensor([[DEFAULT_BOS]]).long().to(self.device) 104 | 105 | self.unfinished_subword = [] 106 | self.incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) 107 | self.h_lm_last = None 108 | 109 | 110 | 111 | @entrypoint 112 | class TransducerSpeechToTextAgent(SpeechToTextAgent): 113 | 114 | speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size 115 | 116 | def __init__(self, args): 117 | super().__init__(args) 118 | 119 | self.device ='cuda' if args.device == 'gpu' else 'cpu' 120 | self.states = self.build_states() 121 | args.global_cmvn = None 122 | if args.config_yaml: 123 | with open(os.path.join(args.data_bin, args.config_yaml), "r") as f: 124 | config = yaml.load(f, Loader=yaml.BaseLoader) 125 | 126 | if "global_cmvn" in config: 127 | args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) 128 | 129 | 130 | self.load_model_vocab(args) 131 | #utils.import_user_module(args) 132 | 133 | self.feature_extractor = OfflineFeatureExtractor(args) 134 | 135 | self.downsample = args.transducer_downsample 136 | 
self.main_context = args.main_context 137 | self.right_context = args.right_context 138 | 139 | torch.set_grad_enabled(False) 140 | self.reset() 141 | 142 | 143 | def build_states(self) -> TransducerSpeechToTextAgentStates: 144 | """ 145 | Build states instance for agent 146 | 147 | Returns: 148 | TransducerSpeechToTextAgentStates: agent states 149 | """ 150 | return TransducerSpeechToTextAgentStates(self.device) 151 | 152 | @staticmethod 153 | def add_args(parser): 154 | # fmt: off 155 | parser.add_argument('--model-path', type=str, required=True, 156 | help='path to your pretrained model.') 157 | parser.add_argument("--data-bin", type=str, required=True, 158 | help="Path of data binary") 159 | parser.add_argument("--config-yaml", type=str, default=None, 160 | help="Path to config yaml file") 161 | parser.add_argument("--global-stats", type=str, default=None, 162 | help="Path to json file containing cmvn stats") 163 | parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", 164 | help="Subword splitter type for target text") 165 | parser.add_argument("--tgt-splitter-path", type=str, default=None, 166 | help="Subword splitter model path for target text") 167 | parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", 168 | help="User directory for simultaneous translation") 169 | parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, 170 | help="Shift size of feature extraction window.") 171 | parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, 172 | help="Window size of feature extraction window.") 173 | parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, 174 | help="Sample rate") 175 | parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, 176 | help="Acoustic feature dimension.") 177 | parser.add_argument("--main-context", type=int, default=32) 178 | parser.add_argument("--right-context", type=int, default=16) 179 | parser.add_argument("--transducer-downsample", type=int, default=1) 180 | parser.add_argument("--device", type=str, default='gpu') 181 | 182 | # fmt: on 183 | return parser 184 | 185 | def load_model_vocab(self, args): 186 | 187 | filename = args.model_path 188 | if not os.path.exists(filename): 189 | raise IOError("Model file not found: {}".format(filename)) 190 | 191 | state = checkpoint_utils.load_checkpoint_to_cpu(filename) 192 | utils.import_user_module(state["cfg"].common) 193 | 194 | task_args = state["cfg"]["task"] 195 | task_args.data = args.data_bin 196 | 197 | if args.config_yaml is not None: 198 | task_args.config_yaml = args.config_yaml 199 | 200 | task = tasks.setup_task(task_args) 201 | 202 | self.model = task.build_model(state["cfg"]["model"]) 203 | self.model.load_state_dict(state["model"], strict=True) 204 | self.model.eval() 205 | self.model.share_memory() 206 | 207 | if self.device == 'cuda': 208 | self.model.cuda() 209 | 210 | # Set dictionary 211 | self.tgt_dict = task.target_dictionary 212 | 213 | @torch.inference_mode() 214 | def policy(self): 215 | 216 | num_frames = math.floor( 217 | (len(self.states.source) - self.feature_extractor.len_ms_to_samples(self.feature_extractor.window_size - self.feature_extractor.shift_size)) 218 | / self.feature_extractor.num_samples_per_shift 219 | ) 220 | 221 | # at least a new complete chunk is received if not finished 222 | if not self.states.source_finished: 223 | if num_frames < self.main_context * (self.states.num_complete_chunk + 1) + self.right_context: 224 | return ReadAction() 225 | 226 | # this is 
used to caluculate self.states.num_complete_chunk 227 | num_complete_new_chunk = math.floor((num_frames - self.right_context) / self.main_context) - self.states.num_complete_chunk 228 | 229 | # Calculated the number of frames to make decisions 230 | if not self.states.source_finished: 231 | num_decision = num_complete_new_chunk * int(self.main_context / 4 / self.downsample) 232 | else: 233 | num_decision = None 234 | 235 | feature = self.feature_extractor(self.states.source) # prefix feature: T × C 236 | assert num_frames == feature.size(0) 237 | src_tokens = feature.unsqueeze(0) # 1 × T × C 238 | src_lengths = torch.tensor([feature.size(0)], device=self.device).long() # 1 239 | 240 | encoder_out = self.model.encoder(src_tokens, src_lengths) 241 | 242 | downsampled_encoder_out = encoder_out["encoder_out"][0][self.states.num_complete_chunk * (self.main_context // 4):][::self.downsample].squeeze(1) # num_decision × C 243 | downsampled_encoder_out = downsampled_encoder_out[:num_decision] 244 | 245 | final_output_tokens = [] 246 | 247 | if self.states.h_lm_last is None: 248 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, encoder_out, incremental_state=self.states.incremental_state)[:, -1].squeeze() # V 249 | 250 | for i in range(downsampled_encoder_out.size(0)): 251 | ii=0 # max emit per frame 252 | while True: 253 | #pdb.set_trace() 254 | joint_result = self.model.decoder.jointer.infer(downsampled_encoder_out[i], self.states.h_lm_last) # V 255 | 256 | log_probs = F.log_softmax(self.model.decoder.out_proj(joint_result), dim=-1) 257 | select_word = log_probs.argmax(dim=-1).item() 258 | 259 | if select_word == self.tgt_dict.blank_index: 260 | break 261 | self.states.prev_output_tokens = torch.cat([self.states.prev_output_tokens, torch.tensor([[select_word]], device=self.states.prev_output_tokens.device)], dim=-1) 262 | final_output_tokens.append(select_word) 263 | self.states.h_lm_last = self.model.decoder.lm(self.states.prev_output_tokens, encoder_out, incremental_state=self.states.incremental_state)[:, -1].squeeze() # V 264 | ii += 1 265 | if ii == 5: 266 | break 267 | #if select_word == 416: 268 | # break 269 | 270 | self.states.num_complete_chunk = self.states.num_complete_chunk + num_complete_new_chunk 271 | 272 | 273 | detok_output_tokens = [] 274 | 275 | for index in final_output_tokens: 276 | token = self.tgt_dict.string([index]) #return a string 277 | if token.startswith(BOW_PREFIX): 278 | if len(self.states.unfinished_subword) != 0: 279 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 280 | self.states.unfinished_subword = [] 281 | self.states.unfinished_subword += [token.replace(BOW_PREFIX, "")] 282 | else: 283 | self.states.unfinished_subword += [token] 284 | 285 | if self.states.source_finished: 286 | detok_output_tokens += ["".join(self.states.unfinished_subword)] 287 | self.states.unfinished_subword = [] 288 | 289 | 290 | detok_output_string = " ".join(detok_output_tokens) 291 | 292 | return WriteAction(TextSegment(content=detok_output_string, finished=self.states.source_finished), finished=self.states.source_finished) 293 | 294 | 295 | -------------------------------------------------------------------------------- /fs_plugins/modules/monotonic_transducer_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | import torch 3 | 4 | from torch import Tensor 5 | import torch.nn as nn 6 | 7 | from fairseq import utils 8 | from fairseq.models import 
FairseqIncrementalDecoder 9 | from fairseq.models.transformer import TransformerDecoder 10 | from fairseq.modules import LayerNorm 11 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper 12 | from fairseq.distributed import fsdp_wrap 13 | import pdb 14 | from .monotonic_transformer_layer import MonotonicTransformerDecoderLayer 15 | 16 | 17 | 18 | 19 | class MonotonicDecoder(TransformerDecoder): 20 | def __init__(self, args, dictionary, embed_tokens): 21 | super().__init__( 22 | args, dictionary, embed_tokens, no_encoder_attn=False 23 | ) 24 | self.output_projection= None 25 | 26 | def build_decoder_layer(self, cfg, no_encoder_attn=False): 27 | layer = MonotonicTransformerDecoderLayer(cfg) 28 | checkpoint = cfg.checkpoint_activations 29 | if checkpoint: 30 | offload_to_cpu = cfg.offload_activations 31 | layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) 32 | # if we are checkpointing, enforce that FSDP always wraps the 33 | # checkpointed layer, regardless of layer size 34 | min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 35 | layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) 36 | return layer 37 | 38 | def extract_features( 39 | self, 40 | prev_output_tokens, 41 | encoder_out: Optional[Dict[str, List[Tensor]]], 42 | posterior: Tensor, 43 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 44 | full_context_alignment: bool = False, 45 | alignment_layer: Optional[int] = None, 46 | alignment_heads: Optional[int] = None, 47 | ): 48 | return self.extract_features_scriptable( 49 | prev_output_tokens, 50 | encoder_out, 51 | posterior, 52 | incremental_state, 53 | full_context_alignment, 54 | alignment_layer, 55 | alignment_heads, 56 | ) 57 | 58 | """ 59 | A scriptable subclass of this class has an extract_features method and calls 60 | super().extract_features, but super() is not supported in torchscript. A copy of 61 | this function is made to be used in the subclass instead. 
62 | """ 63 | 64 | def extract_features_scriptable( 65 | self, 66 | prev_output_tokens, 67 | encoder_out: Optional[Dict[str, List[Tensor]]], 68 | posterior: Tensor, 69 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 70 | full_context_alignment: bool = False, 71 | alignment_layer: Optional[int] = None, 72 | alignment_heads: Optional[int] = None, 73 | ): 74 | if alignment_layer is None: 75 | alignment_layer = self.num_layers - 1 76 | 77 | # embed positions 78 | positions = ( 79 | self.embed_positions( 80 | prev_output_tokens, incremental_state=incremental_state 81 | ) 82 | if self.embed_positions is not None 83 | else None 84 | ) 85 | 86 | if incremental_state is not None: 87 | prev_output_tokens = prev_output_tokens[:, -1:] 88 | if positions is not None: 89 | positions = positions[:, -1:] 90 | 91 | # embed tokens and positions 92 | x = self.embed_scale * self.embed_tokens(prev_output_tokens) 93 | 94 | if self.quant_noise is not None: 95 | x = self.quant_noise(x) 96 | 97 | if self.project_in_dim is not None: 98 | x = self.project_in_dim(x) 99 | 100 | if positions is not None: 101 | x += positions 102 | 103 | if self.layernorm_embedding is not None: 104 | x = self.layernorm_embedding(x) 105 | 106 | x = self.dropout_module(x) 107 | 108 | # B x T x C -> T x B x C 109 | x = x.transpose(0, 1) 110 | 111 | self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) 112 | 113 | # decoder layers 114 | attn: Optional[Tensor] = None 115 | inner_states: List[Optional[Tensor]] = [x] 116 | for idx, layer in enumerate(self.layers): 117 | if incremental_state is None and not full_context_alignment: 118 | self_attn_mask = self.buffered_future_mask(x) 119 | else: 120 | self_attn_mask = None 121 | 122 | x, layer_attn, _ = layer( 123 | x, 124 | encoder_out["encoder_out"][0] 125 | if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) 126 | else None, 127 | encoder_out["encoder_padding_mask"][0] 128 | if ( 129 | encoder_out is not None 130 | and len(encoder_out["encoder_padding_mask"]) > 0 131 | ) 132 | else None, 133 | posterior, 134 | incremental_state, 135 | self_attn_mask=self_attn_mask, 136 | self_attn_padding_mask=self_attn_padding_mask, 137 | need_attn=bool((idx == alignment_layer)), 138 | need_head_weights=bool((idx == alignment_layer)), 139 | ) 140 | inner_states.append(x) 141 | if layer_attn is not None and idx == alignment_layer: 142 | attn = layer_attn.float().to(x) 143 | 144 | if attn is not None: 145 | if alignment_heads is not None: 146 | attn = attn[:alignment_heads] 147 | 148 | # average probabilities over heads 149 | attn = attn.mean(dim=0) 150 | 151 | if self.layer_norm is not None: 152 | x = self.layer_norm(x) 153 | 154 | # T x B x C -> B x T x C 155 | x = x.transpose(0, 1) 156 | 157 | if self.project_out_dim is not None: 158 | x = self.project_out_dim(x) 159 | 160 | return x, {"attn": [attn], "inner_states": inner_states} 161 | 162 | def forward( 163 | self, 164 | prev_output_tokens, 165 | encoder_out: Optional[Dict[str, List[Tensor]]] = None, 166 | posterior: Tensor = None, 167 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 168 | ): 169 | """ 170 | for transducer, prev_output_tokens should be [bos] concat target (including [eos]) 171 | """ 172 | x, extra = self.extract_features( 173 | prev_output_tokens, 174 | encoder_out=encoder_out, 175 | posterior=posterior, 176 | incremental_state=incremental_state, 177 | full_context_alignment=False, 178 | alignment_layer=None, 179 | alignment_heads=None, 180 | ) 181 | 
return x 182 | 183 | 184 | class AddJointNet(nn.Module): 185 | def __init__( 186 | self, 187 | encoder_dim, 188 | decoder_dim, 189 | hid_dim, 190 | activation="tanh", 191 | downsample=1, 192 | ): 193 | super().__init__() 194 | self.downsample = downsample 195 | self.encoder_proj = nn.Linear(encoder_dim, hid_dim) 196 | self.decoder_proj = nn.Linear(decoder_dim, hid_dim) 197 | self.activation_fn = utils.get_activation_fn(activation) 198 | #self.joint_proj = nn.Linear(hid_dim, hid_dim) 199 | #self.layer_norm = LayerNorm(hid_dim) 200 | if downsample < 1: 201 | raise ValueError("downsample should be more than 1 for add_joint") 202 | 203 | def forward(self, encoder_out:Dict[str, List[Tensor]], decoder_state, padding_idx): 204 | """ 205 | use dimension same as transformer 206 | Args: 207 | encoder_out: "encoder_out": TxBxC 208 | decoder_state: BxUxC 209 | """ 210 | encoder_state = encoder_out["encoder_out"][0] 211 | encoder_state = encoder_state[::self.downsample].contiguous() 212 | encoder_state = encoder_state.transpose(0,1) 213 | 214 | h_enc = self.encoder_proj(encoder_state) 215 | h_dec = self.decoder_proj(decoder_state) 216 | h_joint = h_enc.unsqueeze(2) + h_dec.unsqueeze(1) 217 | h_joint = self.activation_fn(h_joint) 218 | #h_joint = self.joint_proj(h_joint) 219 | #h_joint = self.layer_norm(h_joint) 220 | 221 | fake_src_tokens = (encoder_out["encoder_padding_mask"][0]).long() 222 | fake_src_lengths = fake_src_tokens.ne(padding_idx).sum(dim=-1) 223 | fake_src_lengths = (fake_src_lengths / self.downsample).ceil().long() 224 | 225 | return h_joint, fake_src_lengths 226 | 227 | def infer(self, encoder_state, decoder_state): 228 | """ 229 | use dimension same as transformer 230 | Args: 231 | encoder_out: "encoder_out": C 232 | decoder_state: C 233 | """ 234 | 235 | h_enc = self.encoder_proj(encoder_state) 236 | h_dec = self.decoder_proj(decoder_state) 237 | h_joint = h_enc + h_dec 238 | h_joint = self.activation_fn(h_joint) 239 | #h_joint = self.joint_proj(h_joint) 240 | #h_joint = self.layer_norm(h_joint) 241 | 242 | return h_joint 243 | 244 | 245 | 246 | class ConcatJointNet(nn.Module): 247 | def __init__( 248 | self, 249 | encoder_dim, 250 | decoder_dim, 251 | hid_dim, 252 | activation="tanh", 253 | downsample=1, 254 | ) -> None: 255 | super().__init__() 256 | self.fc1 = nn.Linear((encoder_dim+decoder_dim), hid_dim) 257 | self.downsample = downsample 258 | self.activation_fn = utils.get_activation_fn(activation) 259 | if downsample < 1: 260 | raise ValueError("downsample should be more than 1 for concat_joint") 261 | 262 | def forward(self, encoder_out:Dict[str, List[Tensor]], decoder_state, padding_idx): 263 | 264 | encoder_state = encoder_out["encoder_out"][0] 265 | encoder_state = encoder_state[::self.downsample].contiguous() #TODO: downsample 266 | encoder_state = encoder_state.transpose(0,1) 267 | 268 | seq_lens = encoder_state.size(1) 269 | target_lens = decoder_state.size(1) 270 | 271 | encoder_state = encoder_state.unsqueeze(2) 272 | decoder_state = decoder_state.unsqueeze(1) 273 | 274 | encoder_state = encoder_state.expand(-1, -1, target_lens, -1) 275 | decoder_state = decoder_state.expand(-1, seq_lens, -1, -1) 276 | 277 | h_joint = torch.cat((encoder_state, decoder_state), dim=-1) 278 | 279 | h_joint = self.fc1(h_joint) 280 | h_joint = self.activation_fn(h_joint) 281 | 282 | fake_src_tokens = (encoder_out["encoder_padding_mask"][0]).long() 283 | fake_src_lengths = fake_src_tokens.ne(padding_idx).sum(dim=-1) 284 | fake_src_lengths = (fake_src_lengths / self.downsample).ceil().long() 
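        # Effective source length per utterance after temporal downsampling:
        # the count of non-padded encoder states divided by the downsample
        # factor, rounded up; returned together with the joint features.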
285 | 286 | return h_joint, fake_src_lengths 287 | 288 | 289 | 290 | class MonotonicTransducerDecoder(FairseqIncrementalDecoder): 291 | def __init__(self, args, dictionary, embed_tokens): 292 | super().__init__(dictionary) 293 | self.lm = MonotonicDecoder(args, dictionary, embed_tokens) 294 | self.output_embed_dim = args.decoder_output_dim 295 | self.out_proj = nn.Linear(args.decoder_output_dim, len(dictionary), bias=False) 296 | if args.share_decoder_input_output_embed: 297 | self.out_proj.weight= embed_tokens.weight 298 | else: 299 | nn.init.normal_( 300 | self.out_proj.weight, mean=0, std=self.output_embed_dim ** -0.5 301 | ) 302 | self.blank_idx= dictionary.blank_index 303 | self.padding_idx = dictionary.pad() 304 | self.downsample = getattr(args, "transducer_downsample", 1) 305 | #self.jointer = ConcatJointNet(args.encoder_embed_dim, args.decoder_output_dim, args.decoder_output_dim, downsample=self.downsample) 306 | self.jointer = AddJointNet(args.encoder_embed_dim, args.decoder_output_dim, args.decoder_output_dim, downsample=self.downsample) 307 | 308 | def forward_naive( 309 | self, 310 | prev_output_tokens:Tensor, 311 | encoder_out:Dict[str, List[Tensor]], 312 | ): 313 | h_lm = self.lm(prev_output_tokens) 314 | 315 | joint_result, fake_src_lengths = self.jointer(encoder_out, h_lm, self.padding_idx) 316 | 317 | joint_result = self.out_proj(joint_result) # it is logits, no logsoftmax performed 318 | 319 | return joint_result, fake_src_lengths 320 | 321 | def forward( 322 | self, 323 | prev_output_tokens:Tensor, 324 | encoder_out:Dict[str, List[Tensor]], 325 | posterior:Tensor, 326 | ): 327 | h_lm = self.lm(prev_output_tokens, encoder_out, posterior) 328 | 329 | joint_result, fake_src_lengths = self.jointer(encoder_out, h_lm, self.padding_idx) 330 | 331 | joint_result = self.out_proj(joint_result) # it is logits, no logsoftmax performed 332 | 333 | return joint_result, fake_src_lengths 334 | -------------------------------------------------------------------------------- /fs_plugins/modules/multihead_attention_patched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
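# Patched copy of fairseq's MultiheadAttention. The substantive change is how key
# padding is masked at inference time: padded positions receive a large negative
# value (-1e10) instead of float("-inf"), presumably so that a query whose visible
# context is entirely padding still yields finite attention weights, e.g.
# softmax([-inf, -inf]) is [nan, nan] while softmax([-1e10, -1e10]) is [0.5, 0.5].
# See MultiheadAttentionPatched.forward below.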
5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | from fairseq import pdb 13 | from fairseq import utils 14 | from fairseq.incremental_decoding_utils import with_incremental_state 15 | from fairseq.modules.fairseq_dropout import FairseqDropout 16 | from fairseq.modules.quant_noise import quant_noise 17 | from fairseq.modules import MultiheadAttention 18 | from torch import Tensor, nn 19 | from torch.nn import Parameter 20 | from typing import Dict, Optional, Tuple 21 | 22 | 23 | 24 | class MultiheadAttentionPatched(MultiheadAttention): 25 | """ 26 | small modify on padding_mask: during inference, set padding to -1e10 instead of -inf, for CAAT decoding 27 | """ 28 | def forward( 29 | self, 30 | query, 31 | key: Optional[Tensor], 32 | value: Optional[Tensor], 33 | key_padding_mask: Optional[Tensor] = None, 34 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 35 | need_weights: bool = True, 36 | static_kv: bool = False, 37 | attn_mask: Optional[Tensor] = None, 38 | before_softmax: bool = False, 39 | need_head_weights: bool = False, 40 | ) -> Tuple[Tensor, Optional[Tensor]]: 41 | """Input shape: Time x Batch x Channel 42 | 43 | Args: 44 | key_padding_mask (ByteTensor, optional): mask to exclude 45 | keys that are pads, of shape `(batch, src_len)`, where 46 | padding elements are indicated by 1s. 47 | need_weights (bool, optional): return the attention weights, 48 | averaged over heads (default: False). 49 | attn_mask (ByteTensor, optional): typically used to 50 | implement causal attention, where the mask prevents the 51 | attention from looking forward in time (default: None). 52 | before_softmax (bool, optional): return the raw attention 53 | weights and values before the attention softmax. 54 | need_head_weights (bool, optional): return the attention 55 | weights for each head. Implies *need_weights*. Default: 56 | return the average attention weights over all heads. 57 | """ 58 | if need_head_weights: 59 | need_weights = True 60 | 61 | is_tpu = query.device.type == "xla" 62 | 63 | tgt_len, bsz, embed_dim = query.size() 64 | src_len = tgt_len 65 | assert embed_dim == self.embed_dim 66 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 67 | if key is not None: 68 | src_len, key_bsz, _ = key.size() 69 | if not torch.jit.is_scripting(): 70 | assert key_bsz == bsz 71 | assert value is not None 72 | assert src_len, bsz == value.shape[:2] 73 | 74 | if ( 75 | not self.onnx_trace 76 | and not is_tpu # don't use PyTorch version on TPUs 77 | and incremental_state is None 78 | and not static_kv 79 | # A workaround for quantization to work. Otherwise JIT compilation 80 | # treats bias in linear module as method. 
81 | and not torch.jit.is_scripting() 82 | ): 83 | assert key is not None and value is not None 84 | return F.multi_head_attention_forward( 85 | query, 86 | key, 87 | value, 88 | self.embed_dim, 89 | self.num_heads, 90 | torch.empty([0]), 91 | torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), 92 | self.bias_k, 93 | self.bias_v, 94 | self.add_zero_attn, 95 | self.dropout_module.p, 96 | self.out_proj.weight, 97 | self.out_proj.bias, 98 | self.training or self.dropout_module.apply_during_inference, 99 | key_padding_mask, 100 | need_weights, 101 | attn_mask, 102 | use_separate_proj_weight=True, 103 | q_proj_weight=self.q_proj.weight, 104 | k_proj_weight=self.k_proj.weight, 105 | v_proj_weight=self.v_proj.weight, 106 | ) 107 | 108 | if incremental_state is not None: 109 | saved_state = self._get_input_buffer(incremental_state) 110 | if saved_state is not None and "prev_key" in saved_state: 111 | # previous time steps are cached - no need to recompute 112 | # key and value if they are static 113 | if static_kv: 114 | assert self.encoder_decoder_attention and not self.self_attention 115 | key = value = None 116 | else: 117 | saved_state = None 118 | 119 | if self.self_attention: 120 | q = self.q_proj(query) 121 | k = self.k_proj(query) 122 | v = self.v_proj(query) 123 | elif self.encoder_decoder_attention: 124 | # encoder-decoder attention 125 | q = self.q_proj(query) 126 | if key is None: 127 | assert value is None 128 | k = v = None 129 | else: 130 | k = self.k_proj(key) 131 | v = self.v_proj(key) 132 | 133 | else: 134 | assert key is not None and value is not None 135 | q = self.q_proj(query) 136 | k = self.k_proj(key) 137 | v = self.v_proj(value) 138 | q *= self.scaling 139 | 140 | if self.bias_k is not None: 141 | assert self.bias_v is not None 142 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 143 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 144 | if attn_mask is not None: 145 | attn_mask = torch.cat( 146 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 147 | ) 148 | if key_padding_mask is not None: 149 | key_padding_mask = torch.cat( 150 | [ 151 | key_padding_mask, 152 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1), 153 | ], 154 | dim=1, 155 | ) 156 | 157 | q = ( 158 | q.contiguous() 159 | .view(tgt_len, bsz * self.num_heads, self.head_dim) 160 | .transpose(0, 1) 161 | ) 162 | if k is not None: 163 | k = ( 164 | k.contiguous() 165 | .view(-1, bsz * self.num_heads, self.head_dim) 166 | .transpose(0, 1) 167 | ) 168 | if v is not None: 169 | v = ( 170 | v.contiguous() 171 | .view(-1, bsz * self.num_heads, self.head_dim) 172 | .transpose(0, 1) 173 | ) 174 | 175 | if saved_state is not None: 176 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 177 | if "prev_key" in saved_state: 178 | _prev_key = saved_state["prev_key"] 179 | assert _prev_key is not None 180 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) 181 | if static_kv: 182 | k = prev_key 183 | else: 184 | assert k is not None 185 | k = torch.cat([prev_key, k], dim=1) 186 | src_len = k.size(1) 187 | if "prev_value" in saved_state: 188 | _prev_value = saved_state["prev_value"] 189 | assert _prev_value is not None 190 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) 191 | if static_kv: 192 | v = prev_value 193 | else: 194 | assert v is not None 195 | v = torch.cat([prev_value, v], dim=1) 196 | prev_key_padding_mask: Optional[Tensor] = None 197 | if "prev_key_padding_mask" in saved_state: 198 | prev_key_padding_mask = 
saved_state["prev_key_padding_mask"] 199 | assert k is not None and v is not None 200 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( 201 | key_padding_mask=key_padding_mask, 202 | prev_key_padding_mask=prev_key_padding_mask, 203 | batch_size=bsz, 204 | src_len=k.size(1), 205 | static_kv=static_kv, 206 | ) 207 | 208 | saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) 209 | saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) 210 | saved_state["prev_key_padding_mask"] = key_padding_mask 211 | # In this branch incremental_state is never None 212 | assert incremental_state is not None 213 | incremental_state = self._set_input_buffer(incremental_state, saved_state) 214 | assert k is not None 215 | assert k.size(1) == src_len 216 | 217 | # This is part of a workaround to get around fork/join parallelism 218 | # not supporting Optional types. 219 | if key_padding_mask is not None and key_padding_mask.dim() == 0: 220 | key_padding_mask = None 221 | 222 | if key_padding_mask is not None: 223 | assert key_padding_mask.size(0) == bsz 224 | assert key_padding_mask.size(1) == src_len 225 | 226 | if self.add_zero_attn: 227 | assert v is not None 228 | src_len += 1 229 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 230 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 231 | if attn_mask is not None: 232 | attn_mask = torch.cat( 233 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 234 | ) 235 | if key_padding_mask is not None: 236 | key_padding_mask = torch.cat( 237 | [ 238 | key_padding_mask, 239 | torch.zeros(key_padding_mask.size(0), 1).type_as( 240 | key_padding_mask 241 | ), 242 | ], 243 | dim=1, 244 | ) 245 | 246 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 247 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) 248 | 249 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 250 | 251 | if attn_mask is not None: 252 | attn_mask = attn_mask.unsqueeze(0) 253 | if self.onnx_trace: 254 | attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) 255 | attn_weights += attn_mask 256 | 257 | if key_padding_mask is not None: 258 | # don't attend to padding symbols 259 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 260 | if not is_tpu: 261 | if self.training or attn_weights.dtype==torch.float16: 262 | attn_weights = attn_weights.masked_fill( 263 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 264 | float("-inf"), 265 | ) 266 | else: 267 | # for attn_mask may only see pad 268 | attn_weights = attn_weights.masked_fill( 269 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 270 | -1e10, 271 | ) 272 | else: 273 | attn_weights = attn_weights.transpose(0, 2) 274 | attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) 275 | attn_weights = attn_weights.transpose(0, 2) 276 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 277 | 278 | if before_softmax: 279 | return attn_weights, v 280 | 281 | attn_weights_float = utils.softmax( 282 | attn_weights, dim=-1, onnx_trace=self.onnx_trace 283 | ) 284 | attn_weights = attn_weights_float.type_as(attn_weights) 285 | attn_probs = self.dropout_module(attn_weights) 286 | 287 | assert v is not None 288 | attn = torch.bmm(attn_probs, v) 289 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 290 | if self.onnx_trace and attn.size(1) == 1: 291 | # when ONNX tracing a single decoder step 
292 |             # the transpose is a no-op copy before view, thus unnecessary
293 |             attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
294 |         else:
295 |             attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
296 |         attn = self.out_proj(attn)
297 |         attn_weights: Optional[Tensor] = None
298 |         if need_weights:
299 |             attn_weights = attn_weights_float.view(
300 |                 bsz, self.num_heads, tgt_len, src_len
301 |             ).transpose(1, 0)
302 |             if not need_head_weights:
303 |                 # average attention weights over heads
304 |                 attn_weights = attn_weights.mean(dim=0)
305 | 
306 |         return attn, attn_weights
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overcoming Non-monotonicity in Transducer-based Streaming Generation
2 | > **Authors: [Zhengrui Ma](https://scholar.google.com/citations?user=dUgq6tEAAAAJ), [Yang Feng*](https://people.ucas.edu.cn/~yangfeng?language=en), [Min Zhang](https://scholar.google.com/citations?user=CncXH-YAAAAJ&hl=en)**
3 | 
4 | **Files**:
5 | 
6 | - We mainly provide the following files in the [``fs_plugins``](https://github.com/ictnlp/MonoAttn-Transducer/tree/main/fs_plugins) directory, used as plugins for [``fairseq:920a54``](https://github.com/facebookresearch/fairseq/tree/920a548ca770fb1a951f7f4289b4d3a0c1bc226f).
7 | 
8 | ```
9 | fs_plugins
10 | ├── agents
11 | │   ├── attention_transducer_agent.py
12 | │   ├── monotonic_transducer_agent.py
13 | │   ├── transducer_agent.py
14 | │   └── transducer_agent_v2.py
15 | ├── criterions
16 | │   ├── __init__.py
17 | │   ├── transducer_loss.py
18 | │   └── transducer_loss_asr.py
19 | ├── datasets
20 | │   └── transducer_speech_to_text_dataset.py
21 | ├── models
22 | │   ├── transducer
23 | │   │   ├── __init__.py
24 | │   │   ├── attention_transducer.py
25 | │   │   ├── monotonic_transducer.py
26 | │   │   ├── monotonic_transducer_chunk_diagonal_prior.py
27 | │   │   ├── monotonic_transducer_chunk_diagonal_prior_only.py
28 | │   │   ├── monotonic_transducer_diagonal_prior.py
29 | │   │   ├── transducer.py
30 | │   │   ├── transducer_config.py
31 | │   │   └── transducer_loss.py
32 | │   └── __init__.py
33 | ├── modules
34 | │   ├── attention_transducer_decoder.py
35 | │   ├── audio_convs.py
36 | │   ├── audio_encoder.py
37 | │   ├── monotonic_transducer_decoder.py
38 | │   ├── monotonic_transformer_layer.py
39 | │   ├── multihead_attention_patched.py
40 | │   ├── multihead_attention_relative.py
41 | │   ├── rand_pos.py
42 | │   ├── transducer_decoder.py
43 | │   ├── transducer_monotonic_multihead_attention.py
44 | │   └── unidirectional_encoder.py
45 | ├── optim
46 | │   ├── __init__.py
47 | │   └── radam.py
48 | ├── scripts
49 | │   ├── average_checkpoints.py
50 | │   ├── prep_mustc_data.py
51 | │   └── substitute_target.py
52 | ├── tasks
53 | │   ├── __init__.py
54 | │   └── transducer_speech_to_text.py
55 | ├── __init__.py
56 | └── utils.py
57 | ```
58 | ## Data Preparation
59 | Please refer to [Fairseq's speech-to-text modeling tutorial](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_text/README.md).
60 | 
61 | ## Training Transformer-Transducer
62 | ### ASR Pretraining
63 | We use a batch size of approximately 160k tokens **(number of GPUs * max_tokens * update_freq == 160k)**.
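As a quick sanity check on that constraint (an illustrative note of ours, not part of the original recipe): the ASR command below sets `--max-tokens 8000` and `--update-freq 20`, so the product reaches 160k tokens on a single GPU; with more GPUs, scale `--update-freq` down so the product stays at 160k.

```bash
# Effective batch size = num_GPUs * max_tokens * update_freq (illustrative check only).
num_gpus=1        # assumption: the update_freq used below implies a single-GPU run
max_tokens=8000
update_freq=20
echo $((num_gpus * max_tokens * update_freq))   # prints 160000
```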
64 | 
65 | ```bash
66 | main=64
67 | downsample=4
68 | lr=5e-4
69 | warm=4000
70 | dropout=0.1
71 | tokens=8000
72 | language=es
73 | 
74 | exp=en${language}.asr.cs_${main}.ds_${downsample}.kd.t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
75 | MUSTC_ROOT=/path_to_your_dataset/mustc/
76 | checkpoint_dir=./checkpoints/$exp
77 | 
78 | nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
79 |     --amp \
80 |     --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
81 |     --user-dir fs_plugins \
82 |     --task transducer_speech_to_text --arch t_t \
83 |     --max-source-positions 6000 --max-target-positions 1024 \
84 |     --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
85 |     --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
86 |     --activation-dropout 0.1 --attention-dropout 0.1 \
87 |     --criterion transducer_loss_asr \
88 |     --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
89 |     --optimizer adam --adam-betas '(0.9,0.98)' \
90 |     --lr ${lr} --lr-scheduler inverse_sqrt \
91 |     --warmup-init-lr '1e-07' --warmup-updates ${warm} \
92 |     --stop-min-lr '1e-09' --max-update 150000 \
93 |     --max-tokens ${tokens} --update-freq 20 --grouped-shuffling \
94 |     --save-dir ${checkpoint_dir} \
95 |     --ddp-backend=legacy_ddp \
96 |     --no-progress-bar --log-format json --log-interval 100 \
97 |     --save-interval-updates 2000 --keep-interval-updates 10 \
98 |     --save-interval 1000 --keep-last-epochs 10 \
99 |     --fixed-validation-seed 7 \
100 |     --skip-invalid-size-inputs-valid-test \
101 |     --validate-interval 1000 --validate-interval-updates 2000 \
102 |     --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
103 |     --patience 20 --num-workers 8 \
104 |     --tensorboard-logdir logs_board/$exp >> logs/$exp.txt &
105 | ```
106 | 
107 | ### ST Training
108 | We use a batch size of approximately 160k tokens **(number of GPUs * max_tokens * update_freq == 160k)**.
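The ST stage below warm-starts its encoder from an averaged ASR checkpoint (the `pretrained_path` variable in the next block). As an illustrative sketch only: the repository ships `fs_plugins/scripts/average_checkpoints.py`, and assuming it keeps fairseq's standard `average_checkpoints` CLI, such a file could be produced roughly as follows. Here `asr_exp` is a placeholder of ours for the experiment directory created by the ASR command above, and averaging the last 5 update checkpoints simply mirrors the `average_last_5_*` naming used later in the inference section.

```bash
# Illustrative sketch, assuming the bundled script follows fairseq's average_checkpoints CLI.
# Averages the 5 most recent update checkpoints of the ASR run into average.pt.
python fs_plugins/scripts/average_checkpoints.py \
    --inputs ./checkpoints/${asr_exp} \
    --num-update-checkpoints 5 \
    --output /path_to_asr_pretrained_checkpoint/average.pt
```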
109 | 
110 | ```bash
111 | main=64
112 | downsample=4
113 | lr=5e-4
114 | warm=4000
115 | dropout=0.1
116 | tokens=8000
117 | language=es
118 | pretrained_path=/path_to_asr_pretrained_checkpoint/average.pt
119 | 
120 | 
121 | exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
122 | MUSTC_ROOT=/path_to_your_dataset/mustc/
123 | checkpoint_dir=./checkpoints/en-${language}/st/$exp
124 | 
125 | nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
126 |     --load-pretrained-encoder-from ${pretrained_path} \
127 |     --amp \
128 |     --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
129 |     --user-dir fs_plugins \
130 |     --task transducer_speech_to_text --arch t_t \
131 |     --max-source-positions 6000 --max-target-positions 1024 \
132 |     --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
133 |     --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
134 |     --activation-dropout 0.1 --attention-dropout 0.1 \
135 |     --criterion transducer_loss \
136 |     --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
137 |     --optimizer adam --adam-betas '(0.9,0.98)' \
138 |     --lr ${lr} --lr-scheduler inverse_sqrt \
139 |     --warmup-init-lr '1e-07' --warmup-updates ${warm} \
140 |     --stop-min-lr '1e-09' --max-update 150000 \
141 |     --max-tokens ${tokens} --update-freq 10 --grouped-shuffling \
142 |     --save-dir ${checkpoint_dir} \
143 |     --ddp-backend=legacy_ddp \
144 |     --no-progress-bar --log-format json --log-interval 100 \
145 |     --save-interval-updates 2000 --keep-interval-updates 10 \
146 |     --save-interval 1000 --keep-last-epochs 10 \
147 |     --fixed-validation-seed 7 \
148 |     --skip-invalid-size-inputs-valid-test \
149 |     --validate-interval 1000 --validate-interval-updates 2000 \
150 |     --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
151 |     --patience 20 --num-workers 8 \
152 |     --tensorboard-logdir logs_board/$exp >> logs/$exp.txt &
153 | ```
154 | 
155 | ## Training MonoAttn-Transducer
156 | 
157 | ### Offline-Attn Pretraining
158 | We use a batch size of approximately 160k tokens **(number of GPUs * max_tokens * update_freq == 160k)**.
159 | 
160 | ```bash
161 | main=64
162 | downsample=4
163 | lr=5e-4
164 | warm=4000
165 | dropout=0.1
166 | tokens=8000
167 | language=es
168 | pretrained_path=/path_to_asr_pretrained_checkpoint/average.pt # Use the Transformer-Transducer ASR pretrained model
169 | 
170 | exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.attn_t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
171 | MUSTC_ROOT=/path_to_your_dataset/mustc/
172 | checkpoint_dir=./checkpoints/en-${language}/st/$exp
173 | 
174 | nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
175 |     --load-pretrained-encoder-from ${pretrained_path} \
176 |     --amp \
177 |     --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
178 |     --user-dir fs_plugins \
179 |     --task transducer_speech_to_text --arch attention_t_t \
180 |     --max-source-positions 6000 --max-target-positions 1024 \
181 |     --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
182 |     --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
183 |     --activation-dropout 0.1 --attention-dropout 0.1 \
184 |     --criterion transducer_loss \
185 |     --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
186 |     --optimizer adam --adam-betas '(0.9,0.98)' \
187 |     --lr ${lr} --lr-scheduler inverse_sqrt \
188 |     --warmup-init-lr '1e-07' --warmup-updates ${warm} \
189 |     --stop-min-lr '1e-09' --max-update 50000 \
190 |     --max-tokens ${tokens} --update-freq 5 --grouped-shuffling \
191 |     --save-dir ${checkpoint_dir} \
192 |     --ddp-backend=legacy_ddp \
193 |     --no-progress-bar --log-format json --log-interval 100 \
194 |     --save-interval-updates 2000 --keep-interval-updates 10 \
195 |     --save-interval 1000 --keep-last-epochs 10 \
196 |     --fixed-validation-seed 7 \
197 |     --skip-invalid-size-inputs-valid-test \
198 |     --validate-interval 1000 --validate-interval-updates 2000 \
199 |     --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
200 |     --patience 20 --num-workers 8 --max-tokens-valid 4800 \
201 |     --tensorboard-logdir logs_board/$exp > logs/$exp.txt &
202 | ```
203 | 
204 | 
205 | ### Mono-Attn Training
206 | We use a batch size of approximately 160k tokens **(number of GPUs * max_tokens * update_freq == 160k)**.
207 | 
208 | ```bash
209 | main=64
210 | downsample=4
211 | lr=5e-4
212 | warm=4000
213 | dropout=0.1
214 | tokens=10000
215 | language=es
216 | 
217 | pretrained_path=/path_to_offline_attn_trained_model/average.pt
218 | 
219 | exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.mono_t_t_chunk_dia_prior.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
220 | MUSTC_ROOT=/path_to_your_dataset/mustc/
221 | checkpoint_dir=./checkpoints/en-${language}/st/$exp
222 | 
223 | nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
224 |     --load-pretrained-encoder-from ${pretrained_path} \
225 |     --load-pretrained-decoder-from ${pretrained_path} \
226 |     --amp \
227 |     --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
228 |     --user-dir fs_plugins \
229 |     --task transducer_speech_to_text --arch monotonic_t_t_chunk_diagonal_prior \
230 |     --max-source-positions 6000 --max-target-positions 1024 \
231 |     --main-context ${main} --right-context ${main} --transducer-downsample ${downsample} \
232 |     --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
233 |     --activation-dropout 0.1 --attention-dropout 0.1 \
234 |     --criterion transducer_loss \
235 |     --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
236 |     --optimizer adam --adam-betas '(0.9,0.98)' \
237 |     --lr ${lr} --lr-scheduler inverse_sqrt \
238 |     --warmup-init-lr '1e-07' --warmup-updates ${warm} \
239 |     --stop-min-lr '1e-09' --max-update 20000 \
240 |     --max-tokens ${tokens} --update-freq 8 --grouped-shuffling \
241 |     --save-dir ${checkpoint_dir} \
242 |     --ddp-backend=legacy_ddp \
243 |     --no-progress-bar --log-format json --log-interval 100 \
244 |     --save-interval-updates 2000 --keep-interval-updates 20 \
245 |     --save-interval 1000 --keep-last-epochs 10 \
246 |     --fixed-validation-seed 7 \
247 |     --skip-invalid-size-inputs-valid-test \
248 |     --validate-interval 1000 --validate-interval-updates 2000 \
249 |     --best-checkpoint-metric montonic_rnn_t_loss --keep-best-checkpoints 5 \
250 |     --patience 20 --num-workers 8 \
251 |     --tensorboard-logdir logs_board/$exp > logs/$exp.txt &
252 | ```
253 | 
254 | ## Inference
255 | ### Testing Transformer-Transducer
256 | Use the agent `transducer_agent_v2`.
257 | ```bash
258 | LANGUAGE=es
259 | exp=enes.s2t.cs_64.ds_4.kd.t_t.add.prenorm.amp.adam.lr_5e-4.warm_4000.drop_0.1.tk_10000.bsz_160k
260 | ckpt=average_last_5_40000
261 | file=./checkpoints/en-${LANGUAGE}/st/${exp}/${ckpt}.pt
262 | output_dir=./results/en-${LANGUAGE}/st
263 | main_context=64
264 | downsample=4
265 | 
266 | simuleval \
267 |     --data-bin /dataset/mustc/en-${LANGUAGE} \
268 |     --source /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.wav_list --target /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.${LANGUAGE} \
269 |     --model-path $file \
270 |     --config-yaml config_st.yaml \
271 |     --agent ./fs_plugins/agents/transducer_agent_v2.py \
272 |     --transducer-downsample ${downsample} --main-context ${main_context} --right-context ${main_context} \
273 |     --source-segment-size ${main_context}0 \
274 |     --output $output_dir/${exp}_${ckpt} \
275 |     --quality-metrics BLEU --latency-metrics AL \
276 |     --device gpu
277 | ```
278 | 
279 | 
280 | ### Testing MonoAttn-Transducer
281 | Use the agent `monotonic_transducer_agent`.
282 | ```bash
283 | LANGUAGE=es
284 | exp=enes.s2t.cs_64.ds_4.kd.mono_t_t.add.prenorm.amp.adam.lr_5e-4.warm_4000.drop_0.1.tk_10000.bsz_160k
285 | ckpt=average_last_5_40000
286 | file=./checkpoints/en-${LANGUAGE}/st/${exp}/${ckpt}.pt
287 | output_dir=./results/en-${LANGUAGE}/st
288 | main_context=64
289 | downsample=4
290 | 
291 | simuleval \
292 |     --data-bin /dataset/mustc/en-${LANGUAGE} \
293 |     --source /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.wav_list --target /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.${LANGUAGE} \
294 |     --model-path $file \
295 |     --config-yaml config_st.yaml \
296 |     --agent ./fs_plugins/agents/monotonic_transducer_agent.py \
297 |     --transducer-downsample ${downsample} --main-context ${main_context} --right-context ${main_context} \
298 |     --source-segment-size ${main_context}0 \
299 |     --output $output_dir/${exp}_${ckpt} \
300 |     --quality-metrics BLEU --latency-metrics AL \
301 |     --device gpu
302 | ```
303 | ## Citing
304 | 
305 | Please cite our paper if you find our work or code useful.
306 | 
307 | ```
308 | @inproceedings{
309 |   ma2025overcoming,
310 |   title={Overcoming Non-monotonicity in Transducer-based Streaming Generation},
311 |   author={Zhengrui Ma and Yang Feng and Min Zhang},
312 |   booktitle={Proceedings of the 42nd International Conference on Machine Learning},
313 |   year={2025},
314 |   url={https://arxiv.org/abs/2411.17170}
315 | }
316 | ```
317 | 
--------------------------------------------------------------------------------