├── modules
│   ├── __init__.py
│   ├── mamba
│   │   ├── __init__.py
│   │   ├── mamba_blocks.py
│   │   └── bimamba.py
│   └── Conmamba.py
├── figures
│   ├── conmamba.png
│   ├── performance.png
│   └── mamba_encoder_decoder.png
├── requirement.txt
├── README.md
├── hparams
│   ├── CTC
│   │   ├── conformer_large.yaml
│   │   └── conmamba_large.yaml
│   └── S2S
│       ├── conformer_large.yaml
│       ├── conformer_small.yaml
│       ├── conmamba_small.yaml
│       ├── conmambamamba_small.yaml
│       ├── conmamba_large.yaml
│       └── conmambamamba_large.yaml
├── train_CTC.py
├── librispeech_prepare.py
└── train_S2S.py

/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/modules/mamba/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/conmamba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xi-j/Mamba-ASR/HEAD/figures/conmamba.png
--------------------------------------------------------------------------------
/figures/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xi-j/Mamba-ASR/HEAD/figures/performance.png
--------------------------------------------------------------------------------
/figures/mamba_encoder_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xi-j/Mamba-ASR/HEAD/figures/mamba_encoder_decoder.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | librosa
4 | soundfile
5 | wandb
6 |
7 | causal-conv1d==1.1.3.post1
8 | mamba-ssm==1.1.3.post1
9 | torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
10 | torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
11 | speechbrain==1.0.0
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConMamba [![arXiv](https://img.shields.io/badge/arXiv-2407.09732-.svg)](https://arxiv.org/abs/2407.09732)
2 |
3 | An official implementation of convolution-augmented Mamba (ConMamba) for speech recognition.
4 |
5 | ## Architecture
6 |
7 | ![conmamba](figures/conmamba.png)
8 | ![layers](figures/mamba_encoder_decoder.png)
9 |
10 | ## Prerequisites
11 |
12 | 1. Download the LibriSpeech [corpus](https://www.openslr.org/12).
13 |
14 | 2. Install packages.
15 | ```
16 | conda create --name Slytherin python=3.9
17 | conda activate Slytherin
18 | pip install -r requirement.txt
19 | ```
20 | You may need to install lower or higher versions of torch, torchaudio, causal-conv1d, and mamba-ssm depending on your hardware and system. Make sure they are mutually compatible.
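As a quick sanity check of the environment (an illustrative snippet, not part of the recipes; the tensor sizes below are arbitrary), you can run a tiny Mamba forward pass on the GPU:
```python
# sanity_check.py -- verify that torch and mamba-ssm work together (illustrative only)
import torch
from mamba_ssm import Mamba

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")

# Tiny Mamba layer: batch=2, sequence length=16, model dimension=64
layer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 16, 64, device="cuda")
with torch.no_grad():
    y = layer(x)
print("Mamba forward OK:", y.shape)  # expected: torch.Size([2, 16, 64])
```
If this prints the expected shape, the CUDA build of torch and the mamba-ssm / causal-conv1d kernels are consistent.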
21 |
22 |
23 | ## Training
24 | To train a ConMamba Encoder-Transformer Decoder model on one GPU:
25 | ```
26 | python train_S2S.py hparams/S2S/conmamba_large(small).yaml --data_folder <path-to-LibriSpeech> --precision bf16
27 | ```
28 | To train a ConMamba Encoder-Mamba Decoder model on one GPU:
29 | ```
30 | python train_S2S.py hparams/S2S/conmambamamba_large(small).yaml --data_folder <path-to-LibriSpeech> --precision bf16
31 | ```
32 | To train a ConMamba Encoder model with a character-level CTC loss on four GPUs:
33 | ```
34 | torchrun --nproc-per-node 4 train_CTC.py hparams/CTC/conmamba_large.yaml --data_folder <path-to-LibriSpeech> --precision bf16
35 | ```
36 |
37 | ## Inference and Checkpoints (Later)
38 |
39 | ## Performance (Word Error Rate %)
40 | ![performance](figures/performance.png)
41 |
42 | ## Acknowledgement
43 |
44 | We acknowledge the wonderful work of [Mamba](https://arxiv.org/abs/2312.00752) and [Vision Mamba](https://arxiv.org/abs/2401.09417). We borrowed their implementations of [Mamba](https://github.com/state-spaces/mamba) and [bidirectional Mamba](https://github.com/hustvl/Vim). The training recipes are adapted from [SpeechBrain](https://speechbrain.github.io).
45 |
46 | ## Citation
47 | If you find this work helpful, please consider citing:
48 |
49 | ```bibtex
50 | @misc{jiang2024speechslytherin,
51 |       title={Speech Slytherin: Examining the Performance and Efficiency of Mamba for Speech Separation, Recognition, and Synthesis},
52 |       author={Xilin Jiang and Yinghao Aaron Li and Adrian Nicolas Florea and Cong Han and Nima Mesgarani},
53 |       year={2024},
54 |       eprint={2407.09732},
55 |       archivePrefix={arXiv},
56 |       primaryClass={eess.AS},
57 |       url={https://arxiv.org/abs/2407.09732},
58 | }
59 | ```
60 |
61 | You may also like our Mamba for speech separation: https://github.com/xi-j/Mamba-TasNet
62 |
63 |
--------------------------------------------------------------------------------
/modules/mamba/mamba_blocks.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied and modified from
3 | https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
4 | '''
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 |
10 | from functools import partial
11 |
12 | from mamba_ssm import Mamba
13 | from modules.mamba.bimamba import Mamba as BiMamba
14 | from modules.mamba.bimamba import Block as PreNormBlock
15 |
16 | try:
17 |     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18 | except ImportError:
19 |     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
20 |
21 |
22 | def create_block(
23 |     d_model,
24 |     ssm_cls=None,
25 |     ssm_cfg=None,
26 |     norm_epsilon=1e-5,
27 |     rms_norm=False,
28 |     residual_in_fp32=False,
29 |     fused_add_norm=True,
30 |     layer_idx=None,
31 |     device=None,
32 |     dtype=None,
33 | ):
34 |     if ssm_cfg is None:
35 |         ssm_cfg = {}
36 |     factory_kwargs = {"device": device, "dtype": dtype}
37 |     mixer_cls = partial(ssm_cls, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
38 |     norm_cls = partial(
39 |         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
40 |     )
41 |     block = PreNormBlock(
42 |         d_model,
43 |         mixer_cls,
44 |         norm_cls=norm_cls,
45 |         fused_add_norm=fused_add_norm,
46 |         residual_in_fp32=residual_in_fp32,
47 |     )
48 |     block.layer_idx = layer_idx
49 |     return block
50 |
51 |
52 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
53 | def _init_weights(
54 |     module,
55 |     n_layer,
56 |     initializer_range=0.02,  # Now only used for
embedding layer. 57 | rescale_prenorm_residual=True, 58 | n_residuals_per_layer=1, # Change to 2 if we have MLP 59 | ): 60 | if isinstance(module, nn.Linear): 61 | if module.bias is not None: 62 | if not getattr(module.bias, "_no_reinit", False): 63 | nn.init.zeros_(module.bias) 64 | elif isinstance(module, nn.Embedding): 65 | nn.init.normal_(module.weight, std=initializer_range) 66 | 67 | if rescale_prenorm_residual: 68 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 69 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 70 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 71 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 72 | # 73 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 74 | for name, p in module.named_parameters(): 75 | if name in ["out_proj.weight", "fc2.weight"]: 76 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 77 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 78 | # We need to reinit p since this code could be called multiple times 79 | # Having just p *= scale would repeatedly scale it down 80 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 81 | with torch.no_grad(): 82 | p /= math.sqrt(n_residuals_per_layer * n_layer) 83 | 84 | 85 | class LnMambaAdd(nn.Module): 86 | 87 | def __init__(self, 88 | d_model, 89 | ssm_cls, 90 | ssm_cfg, 91 | rms_norm=False, 92 | layer_idx=None 93 | ): 94 | super().__init__() 95 | if rms_norm: 96 | self.norm = RMSNorm(d_model) 97 | else: 98 | self.norm = nn.LayerNorm(d_model) 99 | self.mamba = ssm_cls(d_model=d_model, **ssm_cfg) 100 | 101 | print(type(self.mamba)) 102 | 103 | print('Created LnMambaAdd.') 104 | 105 | def forward(self, x, residual=None, inference_params=None): 106 | if residual != None: 107 | x = x + residual 108 | return self.mamba(self.norm(x)), x 109 | 110 | 111 | class MambaBlocksSequential(nn.Module): 112 | """ 113 | A wrapper for the Mamba block to replicate it 114 | 115 | Arguments 116 | --------- 117 | n_mamba : int 118 | Number of Mamba blocks 119 | d_model : int 120 | Input dimension to Mamba (bottleneck dimension). 121 | d_state : int 122 | Mamba state dimension 123 | expand: int 124 | First linear projection d_model -> d_model * expand 125 | d_conv: int 126 | Kernel size of Mamba conv 127 | norm type : str 128 | The type of normalization, in ['gLN', 'cLN']. 129 | --------- 130 | """ 131 | 132 | def __init__(self, 133 | n_mamba: int, 134 | bidirectional: bool, 135 | d_model: int, # bottleneck dimension (B) 136 | d_state: int = 16, 137 | expand: int = 2, 138 | d_conv: int = 4, # kernel_size of 'Conv' in Mamba 139 | dt_rank: str="auto", 140 | conv_bias: bool = True, 141 | bias: bool = False, 142 | fused_add_norm: bool = True, 143 | rms_norm: bool = False, 144 | norm_epsilon: float = 1e-5, 145 | initializer_cfg=None, 146 | residual_in_fp32=False, 147 | use_simple_block=False 148 | ): 149 | super().__init__() 150 | self.residual_in_fp32 = residual_in_fp32 151 | self.bidirectional = bidirectional 152 | 153 | # We change the order of residual and layer norm: 154 | # Instead of LN -> Attn / MLP -> Add, we do: 155 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 156 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 157 | # This is for performance reason: we can fuse add + layer_norm. 
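# Illustrative sketch (not executed here, and only an assumption about the wrapped
# Block): with fused_add_norm=False each block computes roughly
#     residual = x if residual is None else x + residual
#     x = mixer(norm(residual))
# and returns (x, residual); with fused_add_norm=True the residual addition and the
# (RMS)LayerNorm are instead fused into a single Triton kernel via the
# layer_norm_fn / rms_norm_fn imported above.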
158 | self.fused_add_norm = fused_add_norm 159 | if self.fused_add_norm: 160 | if layer_norm_fn is None or rms_norm_fn is None: 161 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 162 | 163 | self.use_simple_block = use_simple_block 164 | 165 | ssm_cfg = { 166 | "d_state": d_state, 167 | "expand": expand, 168 | "d_conv": d_conv, 169 | "dt_rank": dt_rank, 170 | "conv_bias": conv_bias, 171 | "bias": bias 172 | } 173 | if bidirectional: 174 | ssm_cfg["bimamba_type"] = "v2" 175 | 176 | if use_simple_block: 177 | self.layers = nn.Sequential( 178 | *[ 179 | LnMambaAdd( 180 | d_model=d_model, 181 | ssm_cls=BiMamba if bidirectional else Mamba, 182 | ssm_cfg=ssm_cfg, 183 | rms_norm=rms_norm, 184 | layer_idx=i 185 | ) 186 | for i in range(n_mamba) 187 | ] 188 | ) 189 | else: 190 | self.layers = nn.Sequential( 191 | *[ 192 | create_block( 193 | d_model=d_model, 194 | ssm_cls=BiMamba if bidirectional else Mamba, 195 | ssm_cfg=ssm_cfg, 196 | norm_epsilon=norm_epsilon, 197 | rms_norm=rms_norm, 198 | residual_in_fp32=residual_in_fp32, 199 | fused_add_norm=fused_add_norm, 200 | layer_idx=i, 201 | ) 202 | for i in range(n_mamba) 203 | ] 204 | ) 205 | 206 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 207 | d_model, eps=norm_epsilon 208 | ) 209 | 210 | self.apply( 211 | partial( 212 | _init_weights, 213 | n_layer=n_mamba, 214 | **(initializer_cfg if initializer_cfg is not None else {}), 215 | ) 216 | ) 217 | 218 | 219 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 220 | return { 221 | i: block.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 222 | for i, layer in enumerate(self.layers) 223 | } 224 | 225 | def forward(self, x, inference_params=None): 226 | 227 | hidden_states = x 228 | residual = None 229 | for i, layer in enumerate(self.layers): 230 | hidden_states, residual = layer( 231 | hidden_states, residual, inference_params=inference_params 232 | ) 233 | 234 | if not self.fused_add_norm: 235 | residual = (hidden_states + residual) if residual is not None else hidden_states 236 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 237 | else: 238 | # Set prenorm=False here since we don't need the residual 239 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 240 | 241 | hidden_states = fused_add_norm_fn( 242 | hidden_states, 243 | self.norm_f.weight, 244 | self.norm_f.bias, 245 | eps=self.norm_f.eps, 246 | residual=residual, 247 | prenorm=False, 248 | residual_in_fp32=self.residual_in_fp32, 249 | ) 250 | 251 | return hidden_states 252 | -------------------------------------------------------------------------------- /hparams/CTC/conformer_large.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with CTC 3 | # Encoder: Conformer Encoder 4 | # Decoder: CTC beam searcher and greedy searcher 5 | # Tokens: character 6 | # Training: Librispeech 960h 7 | # Authors: Xilin Jiang 8 | # ############################################################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | 11 | seed: 3402 12 | __set_seed: !apply:torch.manual_seed [!ref ] 13 | project: Mamba-ASR 14 | experiment: conformer_L_CTC 15 | output_folder: !ref results/CTC_char// 16 | save_folder: !ref /save 17 | train_log: !ref /train_log.txt 18 | 19 | 20 | # Data files 21 | 
data_folder: !PLACEHOLDER 22 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 23 | # then data_folder_rirs should be /localscratch/xxx_corpus 24 | # otherwise the dataset will automatically be downloaded 25 | # data_folder_rirs: !ref 26 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 27 | dev_splits: ["dev-clean"] 28 | test_splits: ["dev-clean", "test-clean", "test-other"] 29 | skip_prep: False 30 | train_csv: !ref /train.csv 31 | valid_csv: !ref /dev-clean.csv 32 | test_csv: 33 | - !ref /test-clean.csv 34 | - !ref /test-other.csv 35 | 36 | skip_train: False 37 | precision: bf16 38 | 39 | ####################### Training Parameters #################################### 40 | 41 | number_of_epochs: 500 42 | batch_size: 32 # This works for 2x GPUs with 32GB 43 | grad_accumulation_factor: 4 44 | max_grad_norm: 5.0 45 | sorting: random 46 | num_workers: 16 47 | loss_reduction: batchmean 48 | valid_search_interval: 1 49 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 50 | 51 | lr_model: 0.001 52 | weight_decay: 0.0005 53 | 54 | # Feature parameters 55 | sample_rate: 16000 56 | n_fft: 512 57 | n_mels: 80 58 | win_length: 25 59 | 60 | # Training parameters 61 | # To make Transformers converge, the global bath size should be large enough. 62 | # The global batch size is max_batch_len * n_gpus * gradient_accumulation. 63 | # Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. 64 | # Please, set your parameters accordingly. 65 | dynamic_batching: True 66 | max_batch_length_train: 850 67 | max_batch_len_val: 100 68 | num_bucket: 200 69 | shuffle: False # if true re-creates batches at each epoch shuffling examples. 70 | max_batch_ex: 128 71 | batch_ordering: random 72 | 73 | dynamic_batch_sampler_train: 74 | max_batch_length: !ref 75 | num_buckets: !ref 76 | shuffle: !ref 77 | batch_ordering: !ref 78 | max_batch_ex: !ref 79 | 80 | dynamic_batch_sampler_val: 81 | max_batch_length: !ref 82 | num_buckets: !ref 83 | shuffle: !ref 84 | batch_ordering: !ref 85 | max_batch_ex: !ref 86 | 87 | # Dataloader options 88 | train_dataloader_opts: 89 | batch_size: !ref 90 | shuffle: True 91 | num_workers: !ref 92 | 93 | valid_dataloader_opts: 94 | batch_size: 1 95 | 96 | test_dataloader_opts: 97 | batch_size: 1 98 | 99 | ####################### Model Parameters ####################################### 100 | 101 | # Transformer 102 | attention_type: RelPosMHAXL 103 | d_model: 256 104 | nhead: 4 105 | d_ffn: 1024 106 | num_encoder_layers: 18 107 | num_decoder_layers: 0 108 | transformer_dropout: 0.1 109 | activation: !name:torch.nn.GELU 110 | output_neurons: 31 111 | 112 | # Outputs 113 | token_type: char # ["unigram", "bpe", "char"] 114 | character_coverage: 1.0 115 | blank_index: 0 116 | bos_index: 1 117 | eos_index: 2 118 | 119 | # Decoding parameters 120 | beam_size: 100 121 | beam_prune_logp: -12.0 122 | token_prune_min_logp: -1.2 123 | prune_history: False 124 | 125 | ############################## models ################################ 126 | 127 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 128 | input_shape: (8, 10, 80) 129 | num_blocks: 2 130 | num_layers_per_block: 1 131 | out_channels: (64, 32) 132 | kernel_sizes: (3, 3) 133 | strides: (2, 2) 134 | residuals: (False, False) 135 | 136 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 137 | input_size: 640 138 | tgt_vocab: !ref 139 | d_model: !ref 140 | nhead: !ref 141 | num_encoder_layers: !ref 142 | 
num_decoder_layers: !ref 143 | d_ffn: !ref 144 | dropout: !ref 145 | activation: !ref 146 | encoder_module: conformer 147 | attention_type: !ref 148 | normalize_before: True 149 | causal: False 150 | 151 | ctc_lin: !new:speechbrain.nnet.linear.Linear 152 | input_size: !ref 153 | n_neurons: !ref 154 | 155 | normalize: !new:speechbrain.processing.features.InputNormalization 156 | norm_type: global 157 | update_until_epoch: 4 158 | 159 | modules: 160 | CNN: !ref 161 | Transformer: !ref 162 | ctc_lin: !ref 163 | normalize: !ref 164 | 165 | model: !new:torch.nn.ModuleList 166 | - [!ref , !ref , !ref ] 167 | 168 | ####################### Decoding & optimiser ########################### 169 | 170 | # Decoding parameters 171 | test_beam_search: 172 | blank_index: !ref 173 | beam_size: !ref 174 | beam_prune_logp: !ref 175 | token_prune_min_logp: !ref 176 | prune_history: !ref 177 | 178 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 179 | blank_index: !ref 180 | reduction: !ref 181 | 182 | n_warmup_steps: 7500 183 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 184 | lr_initial: !ref 185 | n_warmup_steps: !ref 186 | 187 | model_opt_class: !name:torch.optim.AdamW 188 | lr: !ref 189 | betas: (0.9, 0.98) 190 | eps: 0.000000001 191 | weight_decay: !ref 192 | 193 | log_softmax: !new:torch.nn.LogSoftmax 194 | dim: -1 195 | 196 | ############################## Augmentations ################################### 197 | 198 | # Speed perturbation 199 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 200 | 201 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 202 | orig_freq: !ref 203 | speeds: !ref 204 | 205 | # Time Drop 206 | # Roughly translated from 207 | # drop_chunk: !new:speechbrain.augment.time_domain.DropChunk 208 | # drop_length_low: 1000 209 | # drop_length_high: 2000 210 | # drop_count_low: 1 211 | # drop_count_high: 5 212 | time_drop_length_low: 6 # Min length for temporal chunk to drop in spectrogram 213 | time_drop_length_high: 12 # Max length for temporal chunk to drop in spectrogram 214 | time_drop_count_low: 1 # Min number of chunks to drop in time in the spectrogram 215 | time_drop_count_high: 5 # Max number of chunks to drop in time in the spectrogram 216 | time_drop_replace: "mean" # Method of dropping chunks 217 | 218 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 219 | drop_length_low: !ref 220 | drop_length_high: !ref 221 | drop_count_low: !ref 222 | drop_count_high: !ref 223 | replace: !ref 224 | dim: 1 225 | 226 | # Frequency Drop 227 | # Roughly translated from 228 | # drop_freq: !new:speechbrain.augment.time_domain.DropFreq 229 | # drop_freq_low: 0 230 | # drop_freq_high: 1 231 | # drop_freq_count_low: 1 232 | # drop_freq_count_high: 3 233 | # drop_freq_width: 0.05 234 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 235 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 236 | freq_drop_count_low: 1 # Min number of chunks to drop in frequency in the spectrogram 237 | freq_drop_count_high: 3 # Max number of chunks to drop in frequency in the spectrogram 238 | freq_drop_replace: "mean" # Method of dropping chunks 239 | 240 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 241 | drop_length_low: !ref 242 | drop_length_high: !ref 243 | drop_count_low: !ref 244 | drop_count_high: !ref 245 | replace: !ref 246 | dim: 2 247 | 248 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 249 | parallel_augment: 
False 250 | concat_original: False 251 | repeat_augment: 1 252 | shuffle_augmentations: False 253 | min_augmentations: 2 254 | max_augmentations: 2 255 | augment_prob: 1.0 256 | augmentations: [ 257 | !ref , 258 | !ref , 259 | ] 260 | 261 | compute_features: !new:speechbrain.lobes.features.Fbank 262 | sample_rate: !ref 263 | n_fft: !ref 264 | n_mels: !ref 265 | win_length: !ref 266 | 267 | ############################## Logging and Pretrainer ########################## 268 | 269 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 270 | checkpoints_dir: !ref 271 | recoverables: 272 | model: !ref 273 | noam_scheduler: !ref 274 | normalizer: !ref 275 | counter: !ref 276 | 277 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 278 | limit: !ref 279 | 280 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 281 | save_file: !ref 282 | 283 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 284 | split_tokens: True 285 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 286 | 287 | use_wandb: False 288 | resume: False 289 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 290 | initializer: !name:wandb.init 291 | entity: xj-audio 292 | project: !ref 293 | name: !ref 294 | dir: !ref 295 | reinit: true 296 | resume: !ref 297 | 298 | 299 | fixed_sec: null 300 | -------------------------------------------------------------------------------- /hparams/CTC/conmamba_large.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with CTC 3 | # Encoder: ConMamba Encoder 4 | # Decoder: CTC beam searcher and greedy searcher 5 | # Tokens: character 6 | # Training: Librispeech 960h 7 | # Authors: Xilin Jiang 8 | # ############################################################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | 11 | seed: 3402 12 | __set_seed: !apply:torch.manual_seed [!ref ] 13 | project: Mamba-ASR 14 | experiment: conmamba_L_CTC 15 | output_folder: !ref results/CTC_char// 16 | save_folder: !ref /save 17 | train_log: !ref /train_log.txt 18 | 19 | 20 | # Data files 21 | data_folder: !PLACEHOLDER 22 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 23 | # then data_folder_rirs should be /localscratch/xxx_corpus 24 | # otherwise the dataset will automatically be downloaded 25 | # data_folder_rirs: !ref 26 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 27 | dev_splits: ["dev-clean"] 28 | test_splits: ["dev-clean", "test-clean", "test-other"] 29 | skip_prep: False 30 | train_csv: !ref /train.csv 31 | valid_csv: !ref /dev-clean.csv 32 | test_csv: 33 | - !ref /test-clean.csv 34 | - !ref /test-other.csv 35 | 36 | skip_train: False 37 | precision: bf16 38 | 39 | ####################### Training Parameters #################################### 40 | 41 | number_of_epochs: 500 42 | batch_size: 32 # This works for 2x GPUs with 32GB 43 | grad_accumulation_factor: 4 44 | max_grad_norm: 5.0 45 | sorting: random 46 | num_workers: 16 47 | loss_reduction: batchmean 48 | valid_search_interval: 1 49 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 50 | 51 | lr_model: 0.001 52 | weight_decay: 0.0005 53 | 54 | # Feature parameters 55 | sample_rate: 16000 56 | n_fft: 512 57 | n_mels: 80 58 | win_length: 25 59 | 60 | # Training parameters 61 | # To make Transformers converge, the 
global bath size should be large enough. 62 | # The global batch size is max_batch_len * n_gpus * gradient_accumulation. 63 | # Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. 64 | # Please, set your parameters accordingly. 65 | dynamic_batching: True 66 | max_batch_length_train: 850 67 | max_batch_len_val: 100 68 | num_bucket: 200 69 | shuffle: False # if true re-creates batches at each epoch shuffling examples. 70 | max_batch_ex: 128 71 | batch_ordering: random 72 | 73 | dynamic_batch_sampler_train: 74 | max_batch_length: !ref 75 | num_buckets: !ref 76 | shuffle: !ref 77 | batch_ordering: !ref 78 | max_batch_ex: !ref 79 | 80 | dynamic_batch_sampler_val: 81 | max_batch_length: !ref 82 | num_buckets: !ref 83 | shuffle: !ref 84 | batch_ordering: !ref 85 | max_batch_ex: !ref 86 | 87 | # Dataloader options 88 | train_dataloader_opts: 89 | batch_size: !ref 90 | shuffle: True 91 | num_workers: !ref 92 | 93 | valid_dataloader_opts: 94 | batch_size: 1 95 | 96 | test_dataloader_opts: 97 | batch_size: 1 98 | 99 | ####################### Model Parameters ####################################### 100 | 101 | # Transformer dummy 102 | attention_type: RelPosMHAXL # unused 103 | nhead: 4 # unused 104 | 105 | # Common 106 | d_model: 256 107 | d_ffn: 1024 108 | num_encoder_layers: 18 109 | num_decoder_layers: 0 110 | transformer_dropout: 0.1 111 | activation: !name:torch.nn.GELU 112 | output_neurons: 31 113 | 114 | # Outputs 115 | token_type: char # ["unigram", "bpe", "char"] 116 | character_coverage: 1.0 117 | blank_index: 0 118 | bos_index: 1 119 | eos_index: 2 120 | 121 | # Decoding parameters 122 | beam_size: 100 123 | beam_prune_logp: -12.0 124 | token_prune_min_logp: -1.2 125 | prune_history: False 126 | 127 | # Mamba parameters 128 | d_state: 16 129 | expand: 2 130 | d_conv: 4 131 | bidirectional: True 132 | mamba_config: 133 | d_state: !ref 134 | expand: !ref 135 | d_conv: !ref 136 | bidirectional: !ref 137 | 138 | ############################## models ################################ 139 | 140 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 141 | input_shape: (8, 10, 80) 142 | num_blocks: 2 143 | num_layers_per_block: 1 144 | out_channels: (64, 32) 145 | kernel_sizes: (3, 3) 146 | strides: (2, 2) 147 | residuals: (False, False) 148 | 149 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 150 | input_size: 640 151 | tgt_vocab: !ref 152 | d_model: !ref 153 | nhead: !ref 154 | num_encoder_layers: !ref 155 | num_decoder_layers: !ref 156 | d_ffn: !ref 157 | dropout: !ref 158 | activation: !ref 159 | encoder_module: conmamba 160 | attention_type: !ref 161 | normalize_before: True 162 | causal: False 163 | mamba_config: !ref 164 | 165 | ctc_lin: !new:speechbrain.nnet.linear.Linear 166 | input_size: !ref 167 | n_neurons: !ref 168 | 169 | normalize: !new:speechbrain.processing.features.InputNormalization 170 | norm_type: global 171 | update_until_epoch: 4 172 | 173 | modules: 174 | CNN: !ref 175 | Transformer: !ref 176 | ctc_lin: !ref 177 | normalize: !ref 178 | 179 | model: !new:torch.nn.ModuleList 180 | - [!ref , !ref , !ref ] 181 | 182 | ####################### Decoding & optimiser ########################### 183 | 184 | # Decoding parameters 185 | test_beam_search: 186 | blank_index: !ref 187 | beam_size: !ref 188 | beam_prune_logp: !ref 189 | token_prune_min_logp: !ref 190 | prune_history: !ref 191 | 192 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 193 | blank_index: !ref 194 | reduction: !ref 195 | 196 | 
n_warmup_steps: 7500 197 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 198 | lr_initial: !ref 199 | n_warmup_steps: !ref 200 | 201 | model_opt_class: !name:torch.optim.AdamW 202 | lr: !ref 203 | betas: (0.9, 0.98) 204 | eps: 0.000000001 205 | weight_decay: !ref 206 | 207 | log_softmax: !new:torch.nn.LogSoftmax 208 | dim: -1 209 | 210 | ############################## Augmentations ################################### 211 | 212 | # Speed perturbation 213 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 214 | 215 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 216 | orig_freq: !ref 217 | speeds: !ref 218 | 219 | # Time Drop 220 | # Roughly translated from 221 | # drop_chunk: !new:speechbrain.augment.time_domain.DropChunk 222 | # drop_length_low: 1000 223 | # drop_length_high: 2000 224 | # drop_count_low: 1 225 | # drop_count_high: 5 226 | time_drop_length_low: 6 # Min length for temporal chunk to drop in spectrogram 227 | time_drop_length_high: 12 # Max length for temporal chunk to drop in spectrogram 228 | time_drop_count_low: 1 # Min number of chunks to drop in time in the spectrogram 229 | time_drop_count_high: 5 # Max number of chunks to drop in time in the spectrogram 230 | time_drop_replace: "mean" # Method of dropping chunks 231 | 232 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 233 | drop_length_low: !ref 234 | drop_length_high: !ref 235 | drop_count_low: !ref 236 | drop_count_high: !ref 237 | replace: !ref 238 | dim: 1 239 | 240 | # Frequency Drop 241 | # Roughly translated from 242 | # drop_freq: !new:speechbrain.augment.time_domain.DropFreq 243 | # drop_freq_low: 0 244 | # drop_freq_high: 1 245 | # drop_freq_count_low: 1 246 | # drop_freq_count_high: 3 247 | # drop_freq_width: 0.05 248 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 249 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 250 | freq_drop_count_low: 1 # Min number of chunks to drop in frequency in the spectrogram 251 | freq_drop_count_high: 3 # Max number of chunks to drop in frequency in the spectrogram 252 | freq_drop_replace: "mean" # Method of dropping chunks 253 | 254 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 255 | drop_length_low: !ref 256 | drop_length_high: !ref 257 | drop_count_low: !ref 258 | drop_count_high: !ref 259 | replace: !ref 260 | dim: 2 261 | 262 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 263 | parallel_augment: False 264 | concat_original: False 265 | repeat_augment: 1 266 | shuffle_augmentations: False 267 | min_augmentations: 2 268 | max_augmentations: 2 269 | augment_prob: 1.0 270 | augmentations: [ 271 | !ref , 272 | !ref , 273 | ] 274 | 275 | compute_features: !new:speechbrain.lobes.features.Fbank 276 | sample_rate: !ref 277 | n_fft: !ref 278 | n_mels: !ref 279 | win_length: !ref 280 | 281 | ############################## Logging and Pretrainer ########################## 282 | 283 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 284 | checkpoints_dir: !ref 285 | recoverables: 286 | model: !ref 287 | noam_scheduler: !ref 288 | normalizer: !ref 289 | counter: !ref 290 | 291 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 292 | limit: !ref 293 | 294 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 295 | save_file: !ref 296 | 297 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 298 | split_tokens: True 299 | wer_computer: 
!name:speechbrain.utils.metric_stats.ErrorRateStats 300 | 301 | use_wandb: False 302 | resume: False 303 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 304 | initializer: !name:wandb.init 305 | entity: xj-audio 306 | project: !ref 307 | name: !ref 308 | dir: !ref 309 | reinit: true 310 | resume: !ref 311 | 312 | fixed_sec: null 313 | -------------------------------------------------------------------------------- /hparams/S2S/conformer_large.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: Conformer Encoder 4 | # Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Jianyuan Zhong, Titouan Parcollet, Samuele Cornell 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 3407 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conformer_L_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["dev-clean", "test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 53 | number_of_epochs: 120 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 8 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 4 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | lr_adam: 0.0008 66 | 67 | # Feature parameters 68 | sample_rate: 16000 69 | n_fft: 512 70 | n_mels: 80 71 | win_length: 32 72 | 73 | # This setup works well for A100 80GB GPU, adapts it to your needs. 
74 | # Or turn it off (but training speed will decrease) 75 | dynamic_batching: True 76 | max_batch_length_train: 500 77 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 78 | num_bucket: 200 79 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 80 | batch_ordering: random 81 | max_batch_ex: 256 82 | 83 | dynamic_batch_sampler_train: 84 | max_batch_length: !ref 85 | num_buckets: !ref 86 | shuffle: !ref 87 | batch_ordering: !ref 88 | max_batch_ex: !ref 89 | 90 | dynamic_batch_sampler_valid: 91 | max_batch_length: !ref 92 | num_buckets: !ref 93 | shuffle: !ref 94 | batch_ordering: !ref 95 | max_batch_ex: !ref 96 | 97 | # Dataloader options 98 | train_dataloader_opts: 99 | batch_size: !ref 100 | shuffle: True 101 | num_workers: !ref 102 | 103 | valid_dataloader_opts: 104 | batch_size: 1 105 | 106 | test_dataloader_opts: 107 | batch_size: 1 108 | 109 | ####################### Model parameters ########################### 110 | # Transformer 111 | d_model: 512 112 | nhead: 8 113 | num_encoder_layers: 12 114 | num_decoder_layers: 6 115 | d_ffn: 2048 116 | transformer_dropout: 0.1 117 | activation: !name:torch.nn.GELU 118 | output_neurons: 5000 119 | 120 | # Outputs 121 | blank_index: 0 122 | label_smoothing: 0.1 123 | pad_index: 0 124 | bos_index: 1 125 | eos_index: 2 126 | 127 | # Decoding parameters 128 | min_decode_ratio: 0.0 129 | max_decode_ratio: 1.0 130 | valid_search_interval: 10 131 | valid_beam_size: 10 132 | test_beam_size: 66 133 | lm_weight: 0.60 134 | ctc_weight_decode: 0.40 135 | 136 | ############################## models ################################ 137 | 138 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 139 | input_shape: (8, 10, 80) 140 | num_blocks: 2 141 | num_layers_per_block: 1 142 | out_channels: (64, 32) 143 | kernel_sizes: (3, 3) 144 | strides: (2, 2) 145 | residuals: (False, False) 146 | 147 | Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 148 | input_size: 640 149 | tgt_vocab: !ref 150 | d_model: !ref 151 | nhead: !ref 152 | num_encoder_layers: !ref 153 | num_decoder_layers: !ref 154 | d_ffn: !ref 155 | dropout: !ref 156 | activation: !ref 157 | encoder_module: conformer 158 | attention_type: RelPosMHAXL 159 | normalize_before: True 160 | causal: False 161 | 162 | # This is the TransformerLM that is used according to the Huggingface repository 163 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 164 | # For more details about the model! 165 | # NB: It has to match the pre-trained TransformerLM!! 
166 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 167 | vocab: !ref 168 | d_model: 768 169 | nhead: 12 170 | num_encoder_layers: 12 171 | num_decoder_layers: 0 172 | d_ffn: 3072 173 | dropout: 0.0 174 | activation: !name:torch.nn.GELU 175 | normalize_before: False 176 | 177 | tokenizer: !new:sentencepiece.SentencePieceProcessor 178 | 179 | ctc_lin: !new:speechbrain.nnet.linear.Linear 180 | input_size: !ref 181 | n_neurons: !ref 182 | 183 | seq_lin: !new:speechbrain.nnet.linear.Linear 184 | input_size: !ref 185 | n_neurons: !ref 186 | 187 | normalize: !new:speechbrain.processing.features.InputNormalization 188 | norm_type: global 189 | update_until_epoch: 4 190 | 191 | modules: 192 | CNN: !ref 193 | Transformer: !ref 194 | seq_lin: !ref 195 | ctc_lin: !ref 196 | normalize: !ref 197 | 198 | # define two optimizers here for two-stage training 199 | Adam: !name:torch.optim.AdamW 200 | lr: !ref 201 | betas: (0.9, 0.98) 202 | eps: 0.000000001 203 | 204 | model: !new:torch.nn.ModuleList 205 | - [!ref , !ref , !ref , !ref ] 206 | 207 | # Scorer 208 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 209 | eos_index: !ref 210 | blank_index: !ref 211 | ctc_fc: !ref 212 | 213 | 214 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 215 | language_model: !ref 216 | temperature: 1.15 217 | 218 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 219 | full_scorers: [!ref , !ref ] 220 | weights: 221 | ctc: !ref 222 | transformerlm: !ref 223 | 224 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 225 | full_scorers: [!ref ] 226 | weights: 227 | ctc: !ref 228 | 229 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 230 | modules: [!ref , !ref ] 231 | bos_index: !ref 232 | eos_index: !ref 233 | min_decode_ratio: !ref 234 | max_decode_ratio: !ref 235 | beam_size: !ref 236 | using_eos_threshold: False 237 | length_normalization: True 238 | scorer: !ref 239 | 240 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 241 | modules: [!ref , !ref ] 242 | bos_index: !ref 243 | eos_index: !ref 244 | min_decode_ratio: !ref 245 | max_decode_ratio: !ref 246 | beam_size: !ref 247 | temperature: 1.15 248 | using_eos_threshold: False 249 | length_normalization: True 250 | scorer: !ref 251 | 252 | log_softmax: !new:torch.nn.LogSoftmax 253 | dim: -1 254 | 255 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 256 | blank_index: !ref 257 | reduction: !ref 258 | 259 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 260 | label_smoothing: !ref 261 | reduction: !ref 262 | 263 | n_warmup_steps: 30000 264 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 265 | lr_initial: !ref 266 | n_warmup_steps: !ref 267 | 268 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 269 | checkpoints_dir: !ref 270 | recoverables: 271 | model: !ref 272 | noam_scheduler: !ref 273 | normalizer: !ref 274 | counter: !ref 275 | 276 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 277 | limit: !ref 278 | 279 | # Speed perturbation 280 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 281 | 282 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 283 | orig_freq: !ref 284 | speeds: !ref 285 | 286 | # Time Drop 287 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 288 | time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 289 | 
time_drop_count_low: 4 # Min number of chunks to drop in time in the spectrogram 290 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 291 | time_drop_replace: "mean" # Method of dropping chunks 292 | 293 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 294 | drop_length_low: !ref 295 | drop_length_high: !ref 296 | drop_count_low: !ref 297 | drop_count_high: !ref 298 | replace: !ref 299 | dim: 1 300 | 301 | # Frequency Drop 302 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 303 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 304 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 305 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 306 | freq_drop_replace: "mean" # Method of dropping chunks 307 | 308 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 309 | drop_length_low: !ref 310 | drop_length_high: !ref 311 | drop_count_low: !ref 312 | drop_count_high: !ref 313 | replace: !ref 314 | dim: 2 315 | 316 | # Time warp 317 | time_warp_window: 5 # Length of time warping window 318 | time_warp_mode: "bicubic" # Time warping method 319 | 320 | time_warp: !new:speechbrain.augment.freq_domain.Warping 321 | warp_window: !ref 322 | warp_mode: !ref 323 | dim: 1 324 | 325 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 326 | parallel_augment: False 327 | concat_original: False 328 | repeat_augment: 1 329 | shuffle_augmentations: False 330 | min_augmentations: 3 331 | max_augmentations: 3 332 | augment_prob: 1.0 333 | augmentations: [ 334 | !ref , 335 | !ref , 336 | !ref ] 337 | 338 | compute_features: !new:speechbrain.lobes.features.Fbank 339 | sample_rate: !ref 340 | n_fft: !ref 341 | n_mels: !ref 342 | win_length: !ref 343 | 344 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 345 | save_file: !ref 346 | 347 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 348 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 349 | 350 | # The pretrainer allows a mapping between pretrained files and instances that 351 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 352 | # and it will be loaded into "lm" which is pointing to the defined 353 | # before. 
354 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 355 | collect_in: !ref 356 | loadables: 357 | lm: !ref 358 | tokenizer: !ref 359 | paths: 360 | lm: !ref /lm.ckpt 361 | tokenizer: !ref /tokenizer.ckpt 362 | 363 | use_wandb: False 364 | resume: False 365 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 366 | initializer: !name:wandb.init 367 | entity: xj-audio 368 | project: !ref 369 | name: !ref 370 | dir: !ref 371 | reinit: true 372 | resume: !ref 373 | -------------------------------------------------------------------------------- /hparams/S2S/conformer_small.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: Conformer Encoder 4 | # Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Jianyuan Zhong, Titouan Parcollet, Samuele Cornell 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 7775 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conformer_S_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 
53 | number_of_epochs: 110 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 8 # 1 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 4 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | # stage_one_epochs: 90 66 | lr_adam: !ref 0.001 67 | # lr_sgd: 0.000025 68 | 69 | # Feature parameters 70 | sample_rate: 16000 71 | n_fft: 400 72 | n_mels: 80 73 | 74 | # This setup works well for V100 32GB GPU, adapts it to your needs. 75 | # Or turn it off (but training speed will decrease) 76 | dynamic_batching: True 77 | max_batch_length_train: 900 78 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 79 | num_bucket: 200 80 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 81 | batch_ordering: random 82 | max_batch_ex: 128 83 | 84 | dynamic_batch_sampler_train: 85 | max_batch_length: !ref 86 | num_buckets: !ref 87 | shuffle: !ref 88 | batch_ordering: !ref 89 | max_batch_ex: !ref 90 | 91 | dynamic_batch_sampler_valid: 92 | max_batch_length: !ref 93 | num_buckets: !ref 94 | shuffle: !ref 95 | batch_ordering: !ref 96 | max_batch_ex: !ref 97 | 98 | # Dataloader options 99 | train_dataloader_opts: 100 | batch_size: !ref 101 | shuffle: True 102 | num_workers: !ref 103 | 104 | valid_dataloader_opts: 105 | batch_size: 1 106 | 107 | test_dataloader_opts: 108 | batch_size: 1 109 | 110 | ####################### Model parameters ########################### 111 | # Transformer 112 | d_model: 144 113 | nhead: 4 114 | num_encoder_layers: 12 115 | num_decoder_layers: 4 116 | d_ffn: 1024 117 | transformer_dropout: 0.1 118 | activation: !name:torch.nn.GELU 119 | output_neurons: 5000 120 | 121 | # Outputs 122 | blank_index: 0 123 | label_smoothing: 0.0 124 | pad_index: 0 125 | bos_index: 1 126 | eos_index: 2 127 | 128 | # Decoding parameters 129 | min_decode_ratio: 0.0 130 | max_decode_ratio: 1.0 131 | valid_search_interval: 10 132 | valid_beam_size: 10 133 | test_beam_size: 66 134 | lm_weight: 0.60 135 | ctc_weight_decode: 0.40 136 | 137 | ############################## models ################################ 138 | 139 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 140 | input_shape: (8, 10, 80) 141 | num_blocks: 2 142 | num_layers_per_block: 1 143 | out_channels: (64, 32) 144 | kernel_sizes: (3, 3) 145 | strides: (2, 2) 146 | residuals: (False, False) 147 | 148 | Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 149 | input_size: 640 150 | tgt_vocab: !ref 151 | d_model: !ref 152 | nhead: !ref 153 | num_encoder_layers: !ref 154 | num_decoder_layers: !ref 155 | d_ffn: !ref 156 | dropout: !ref 157 | activation: !ref 158 | encoder_module: conformer 159 | attention_type: RelPosMHAXL 160 | normalize_before: True 161 | causal: False 162 | 163 | # This is the TransformerLM that is used according to the Huggingface repository 164 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 165 | # For more details about the model! 166 | # NB: It has to match the pre-trained TransformerLM!! 
167 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 168 | vocab: !ref 169 | d_model: 768 170 | nhead: 12 171 | num_encoder_layers: 12 172 | num_decoder_layers: 0 173 | d_ffn: 3072 174 | dropout: 0.0 175 | activation: !name:torch.nn.GELU 176 | normalize_before: False 177 | 178 | tokenizer: !new:sentencepiece.SentencePieceProcessor 179 | 180 | ctc_lin: !new:speechbrain.nnet.linear.Linear 181 | input_size: !ref 182 | n_neurons: !ref 183 | 184 | seq_lin: !new:speechbrain.nnet.linear.Linear 185 | input_size: !ref 186 | n_neurons: !ref 187 | 188 | normalize: !new:speechbrain.processing.features.InputNormalization 189 | norm_type: global 190 | # update_until_epoch: 4 191 | 192 | modules: 193 | CNN: !ref 194 | Transformer: !ref 195 | seq_lin: !ref 196 | ctc_lin: !ref 197 | normalize: !ref 198 | 199 | model: !new:torch.nn.ModuleList 200 | - [!ref , !ref , !ref , !ref ] 201 | 202 | # define two optimizers here for two-stage training 203 | Adam: !name:torch.optim.Adam 204 | lr: !ref 205 | betas: (0.9, 0.98) 206 | eps: 0.000000001 207 | 208 | #SGD: !name:torch.optim.SGD 209 | # lr: !ref 210 | # momentum: 0.99 211 | # nesterov: True 212 | 213 | # Scorer 214 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 215 | eos_index: !ref 216 | blank_index: !ref 217 | ctc_fc: !ref 218 | 219 | 220 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 221 | language_model: !ref 222 | temperature: 1.15 223 | 224 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 225 | full_scorers: [!ref , !ref ] 226 | weights: 227 | ctc: !ref 228 | transformerlm: !ref 229 | 230 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 231 | full_scorers: [!ref ] 232 | weights: 233 | ctc: !ref 234 | 235 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 236 | modules: [!ref , !ref ] 237 | bos_index: !ref 238 | eos_index: !ref 239 | min_decode_ratio: !ref 240 | max_decode_ratio: !ref 241 | beam_size: !ref 242 | using_eos_threshold: False 243 | length_normalization: True 244 | scorer: !ref 245 | 246 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 247 | modules: [!ref , !ref ] 248 | bos_index: !ref 249 | eos_index: !ref 250 | min_decode_ratio: !ref 251 | max_decode_ratio: !ref 252 | beam_size: !ref 253 | temperature: 1.15 254 | using_eos_threshold: False 255 | length_normalization: True 256 | scorer: !ref 257 | 258 | 259 | log_softmax: !new:torch.nn.LogSoftmax 260 | dim: -1 261 | 262 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 263 | blank_index: !ref 264 | reduction: !ref 265 | 266 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 267 | label_smoothing: !ref 268 | reduction: !ref 269 | 270 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 271 | lr_initial: !ref 272 | n_warmup_steps: 3125 # !ref 25000 // 273 | 274 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 275 | checkpoints_dir: !ref 276 | recoverables: 277 | model: !ref 278 | noam_scheduler: !ref 279 | normalize: !ref 280 | counter: !ref 281 | 282 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 283 | limit: !ref 284 | 285 | # Speed perturbation 286 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 287 | 288 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 289 | orig_freq: !ref 290 | speeds: !ref 291 | 292 | # Time Drop 293 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 294 | 
time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 295 | time_drop_count_low: 4 # Min number of chunks to drop in time in the spectrogram 296 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 297 | time_drop_replace: "mean" # Method of dropping chunks 298 | 299 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 300 | drop_length_low: !ref 301 | drop_length_high: !ref 302 | drop_count_low: !ref 303 | drop_count_high: !ref 304 | replace: !ref 305 | dim: 1 306 | 307 | # Frequency Drop 308 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 309 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 310 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 311 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 312 | freq_drop_replace: "mean" # Method of dropping chunks 313 | 314 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 315 | drop_length_low: !ref 316 | drop_length_high: !ref 317 | drop_count_low: !ref 318 | drop_count_high: !ref 319 | replace: !ref 320 | dim: 2 321 | 322 | # Time warp 323 | time_warp_window: 5 # Length of time warping window 324 | time_warp_mode: "bicubic" # Time warping method 325 | 326 | time_warp: !new:speechbrain.augment.freq_domain.Warping 327 | warp_window: !ref 328 | warp_mode: !ref 329 | dim: 1 330 | 331 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 332 | parallel_augment: False 333 | concat_original: False 334 | repeat_augment: 1 335 | shuffle_augmentations: False 336 | min_augmentations: 3 337 | max_augmentations: 3 338 | augment_prob: 1.0 339 | augmentations: [ 340 | !ref , 341 | !ref , 342 | !ref ] 343 | 344 | compute_features: !new:speechbrain.lobes.features.Fbank 345 | sample_rate: !ref 346 | n_fft: !ref 347 | n_mels: !ref 348 | 349 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 350 | save_file: !ref 351 | 352 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 353 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 354 | 355 | # The pretrainer allows a mapping between pretrained files and instances that 356 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 357 | # and it will be loaded into "lm" which is pointing to the defined 358 | # before. 
359 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 360 | collect_in: !ref 361 | loadables: 362 | lm: !ref 363 | tokenizer: !ref 364 | paths: 365 | lm: !ref /lm.ckpt 366 | tokenizer: !ref /tokenizer.ckpt 367 | 368 | use_wandb: False 369 | resume: False 370 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 371 | initializer: !name:wandb.init 372 | entity: xj-audio 373 | project: !ref 374 | name: !ref 375 | dir: !ref 376 | reinit: true 377 | resume: !ref -------------------------------------------------------------------------------- /hparams/S2S/conmamba_small.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: ConMamba Encoder 4 | # Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Xilin Jiang 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 7775 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conmamba_S_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 
53 | number_of_epochs: 110 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 1 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 8 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | # stage_one_epochs: 90 66 | lr_adam: 0.001 67 | # lr_sgd: 0.000025 68 | 69 | # Feature parameters 70 | sample_rate: 16000 71 | n_fft: 400 72 | n_mels: 80 73 | 74 | # This setup works well for V100 32GB GPU, adapts it to your needs. 75 | # Or turn it off (but training speed will decrease) 76 | dynamic_batching: True 77 | max_batch_length_train: 900 78 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 79 | num_bucket: 200 80 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 81 | batch_ordering: random 82 | max_batch_ex: 128 83 | 84 | dynamic_batch_sampler_train: 85 | max_batch_length: !ref 86 | num_buckets: !ref 87 | shuffle: !ref 88 | batch_ordering: !ref 89 | max_batch_ex: !ref 90 | 91 | dynamic_batch_sampler_valid: 92 | max_batch_length: !ref 93 | num_buckets: !ref 94 | shuffle: !ref 95 | batch_ordering: !ref 96 | max_batch_ex: !ref 97 | 98 | # Dataloader options 99 | train_dataloader_opts: 100 | batch_size: !ref 101 | shuffle: True 102 | num_workers: !ref 103 | 104 | valid_dataloader_opts: 105 | batch_size: 1 106 | 107 | test_dataloader_opts: 108 | batch_size: 1 109 | 110 | ####################### Model parameters ########################### 111 | # Transformer 112 | d_model: 144 113 | nhead: 4 114 | num_encoder_layers: 12 115 | num_decoder_layers: 4 116 | d_ffn: 1024 117 | transformer_dropout: 0.1 118 | activation: !name:torch.nn.GELU 119 | output_neurons: 5000 120 | 121 | # Outputs 122 | blank_index: 0 123 | label_smoothing: 0.0 124 | pad_index: 0 125 | bos_index: 1 126 | eos_index: 2 127 | 128 | # Decoding parameters 129 | min_decode_ratio: 0.0 130 | max_decode_ratio: 1.0 131 | valid_search_interval: 10 132 | valid_beam_size: 10 133 | test_beam_size: 66 134 | lm_weight: 0.60 135 | ctc_weight_decode: 0.40 136 | 137 | # Mamba parameters 138 | 139 | d_state: 16 140 | expand: 2 141 | d_conv: 4 142 | bidirectional: True 143 | mamba_config: 144 | d_state: !ref 145 | expand: !ref 146 | d_conv: !ref 147 | bidirectional: !ref 148 | 149 | ############################## models ################################ 150 | 151 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 152 | input_shape: (8, 10, 80) 153 | num_blocks: 2 154 | num_layers_per_block: 1 155 | out_channels: (64, 32) 156 | kernel_sizes: (3, 3) 157 | strides: (2, 2) 158 | residuals: (False, False) 159 | 160 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 161 | input_size: 640 162 | tgt_vocab: !ref 163 | d_model: !ref 164 | nhead: !ref # unused 165 | num_encoder_layers: !ref 166 | num_decoder_layers: !ref 167 | d_ffn: !ref 168 | dropout: !ref 169 | activation: !ref 170 | encoder_module: conmamba 171 | attention_type: RelPosMHAXL # unused 172 | normalize_before: True 173 | causal: False 174 | mamba_config: !ref 175 | 176 | # This is the TransformerLM that is used according to the Huggingface repository 177 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 178 | # For more details about the model! 179 | # NB: It has to match the pre-trained TransformerLM!! 
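A note on input_size: 640 in the TransformerASR block above: the two stride-(2, 2) blocks of the ConvolutionFrontEnd downsample both time and the 80 mel bins by a factor of 4, and the 32 output channels of the last block are flattened with the remaining 20 frequency bins per frame. A small sketch of that arithmetic (assuming padding such that each stride-2 block exactly halves the frequency axis):

```python
# Why TransformerASR.input_size is 640, given the ConvolutionFrontEnd above.
n_mels = 80
per_block_freq_stride = (2, 2)   # one stride-2 per block along the feature axis
last_out_channels = 32           # out_channels: (64, 32)

freq_bins = n_mels
for s in per_block_freq_stride:
    freq_bins = -(-freq_bins // s)   # ceil division: 80 -> 40 -> 20

input_size = last_out_channels * freq_bins
print(input_size)                    # 32 * 20 = 640 (time is downsampled by 4 as well)
```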
180 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 181 | vocab: !ref 182 | d_model: 768 183 | nhead: 12 184 | num_encoder_layers: 12 185 | num_decoder_layers: 0 186 | d_ffn: 3072 187 | dropout: 0.0 188 | activation: !name:torch.nn.GELU 189 | normalize_before: False 190 | 191 | tokenizer: !new:sentencepiece.SentencePieceProcessor 192 | 193 | ctc_lin: !new:speechbrain.nnet.linear.Linear 194 | input_size: !ref 195 | n_neurons: !ref 196 | 197 | seq_lin: !new:speechbrain.nnet.linear.Linear 198 | input_size: !ref 199 | n_neurons: !ref 200 | 201 | normalize: !new:speechbrain.processing.features.InputNormalization 202 | norm_type: global 203 | update_until_epoch: 4 204 | 205 | modules: 206 | CNN: !ref 207 | Transformer: !ref 208 | seq_lin: !ref 209 | ctc_lin: !ref 210 | normalize: !ref 211 | 212 | model: !new:torch.nn.ModuleList 213 | - [!ref , !ref , !ref , !ref ] 214 | 215 | # define two optimizers here for two-stage training 216 | Adam: !name:torch.optim.Adam 217 | lr: !ref 218 | betas: (0.9, 0.98) 219 | eps: 0.000000001 220 | 221 | # Scorer 222 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 223 | eos_index: !ref 224 | blank_index: !ref 225 | ctc_fc: !ref 226 | 227 | 228 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 229 | language_model: !ref 230 | temperature: 1.15 231 | 232 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 233 | full_scorers: [!ref , !ref ] 234 | weights: 235 | ctc: !ref 236 | transformerlm: !ref 237 | 238 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 239 | full_scorers: [!ref ] 240 | weights: 241 | ctc: !ref 242 | 243 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 244 | modules: [!ref , !ref ] 245 | bos_index: !ref 246 | eos_index: !ref 247 | min_decode_ratio: !ref 248 | max_decode_ratio: !ref 249 | beam_size: !ref 250 | using_eos_threshold: False 251 | length_normalization: True 252 | scorer: !ref 253 | 254 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 255 | modules: [!ref , !ref ] 256 | bos_index: !ref 257 | eos_index: !ref 258 | min_decode_ratio: !ref 259 | max_decode_ratio: !ref 260 | beam_size: !ref 261 | temperature: 1.15 262 | using_eos_threshold: False 263 | length_normalization: True 264 | scorer: !ref 265 | 266 | 267 | log_softmax: !new:torch.nn.LogSoftmax 268 | dim: -1 269 | 270 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 271 | blank_index: !ref 272 | reduction: !ref 273 | 274 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 275 | label_smoothing: !ref 276 | reduction: !ref 277 | 278 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 279 | lr_initial: !ref 280 | n_warmup_steps: !ref 25000 // 281 | 282 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 283 | checkpoints_dir: !ref 284 | recoverables: 285 | model: !ref 286 | noam_scheduler: !ref 287 | normalizer: !ref 288 | counter: !ref 289 | 290 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 291 | limit: !ref 292 | 293 | # Speed perturbation 294 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 295 | 296 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 297 | orig_freq: !ref 298 | speeds: !ref 299 | 300 | # Time Drop 301 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 302 | time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 303 | time_drop_count_low: 4 # 
Min number of chunks to drop in time in the spectrogram 304 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 305 | time_drop_replace: "mean" # Method of dropping chunks 306 | 307 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 308 | drop_length_low: !ref 309 | drop_length_high: !ref 310 | drop_count_low: !ref 311 | drop_count_high: !ref 312 | replace: !ref 313 | dim: 1 314 | 315 | # Frequency Drop 316 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 317 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 318 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 319 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 320 | freq_drop_replace: "mean" # Method of dropping chunks 321 | 322 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 323 | drop_length_low: !ref 324 | drop_length_high: !ref 325 | drop_count_low: !ref 326 | drop_count_high: !ref 327 | replace: !ref 328 | dim: 2 329 | 330 | # Time warp 331 | time_warp_window: 5 # Length of time warping window 332 | time_warp_mode: "bicubic" # Time warping method 333 | 334 | time_warp: !new:speechbrain.augment.freq_domain.Warping 335 | warp_window: !ref 336 | warp_mode: !ref 337 | dim: 1 338 | 339 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 340 | parallel_augment: False 341 | concat_original: False 342 | repeat_augment: 1 343 | shuffle_augmentations: False 344 | min_augmentations: 3 345 | max_augmentations: 3 346 | augment_prob: 1.0 347 | augmentations: [ 348 | !ref , 349 | !ref , 350 | !ref ] 351 | 352 | compute_features: !new:speechbrain.lobes.features.Fbank 353 | sample_rate: !ref 354 | n_fft: !ref 355 | n_mels: !ref 356 | 357 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 358 | save_file: !ref 359 | 360 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 361 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 362 | 363 | # The pretrainer allows a mapping between pretrained files and instances that 364 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 365 | # and it will be loaded into "lm" which is pointing to the defined 366 | # before. 
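The SpectrogramDrop, Warping, and Augmenter entries above are applied to the log-mel features during training; the call pattern matches the train scripts (features in, augmented features and lengths out). Below is a standalone sketch with the same hyperparameters and a random batch, purely for illustration; the real recipe feeds Fbank features computed from LibriSpeech audio.

```python
# Standalone sketch of the feature-augmentation pipeline declared above.
import torch
from speechbrain.lobes.features import Fbank
from speechbrain.augment.freq_domain import SpectrogramDrop, Warping
from speechbrain.augment.augmenter import Augmenter

compute_features = Fbank(sample_rate=16000, n_fft=400, n_mels=80)
time_drop = SpectrogramDrop(drop_length_low=15, drop_length_high=25,
                            drop_count_low=4, drop_count_high=4,
                            replace="mean", dim=1)   # dim=1: time
freq_drop = SpectrogramDrop(drop_length_low=10, drop_length_high=20,
                            drop_count_low=4, drop_count_high=4,
                            replace="mean", dim=2)   # dim=2: frequency
time_warp = Warping(warp_window=5, warp_mode="bicubic", dim=1)

fea_augment = Augmenter(parallel_augment=False, concat_original=False,
                        repeat_augment=1, shuffle_augmentations=False,
                        min_augmentations=3, max_augmentations=3,
                        augment_prob=1.0,
                        augmentations=[time_drop, freq_drop, time_warp])

wavs = torch.randn(4, 16000)            # (batch, samples): 1 s of dummy audio
lens = torch.ones(4)                    # relative lengths
feats = compute_features(wavs)          # (batch, frames, 80)
feats, lens = fea_augment(feats, lens)  # all three augmentations applied
```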
367 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 368 | collect_in: !ref 369 | loadables: 370 | lm: !ref 371 | tokenizer: !ref 372 | paths: 373 | lm: !ref /lm.ckpt 374 | tokenizer: !ref /tokenizer.ckpt 375 | 376 | use_wandb: False 377 | resume: False 378 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 379 | initializer: !name:wandb.init 380 | entity: xj-audio 381 | project: !ref 382 | name: !ref 383 | dir: !ref 384 | reinit: true 385 | resume: !ref 386 | -------------------------------------------------------------------------------- /hparams/S2S/conmambamamba_small.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: ConMamba Encoder 4 | # Decoder: Mamba Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Xilin Jiang 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 7775 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conmambamamba_S_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 
53 | number_of_epochs: 110 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 1 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 8 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | # stage_one_epochs: 90 66 | lr_adam: 0.001 67 | # lr_sgd: 0.000025 68 | 69 | # Feature parameters 70 | sample_rate: 16000 71 | n_fft: 400 72 | n_mels: 80 73 | 74 | # This setup works well for V100 32GB GPU, adapts it to your needs. 75 | # Or turn it off (but training speed will decrease) 76 | dynamic_batching: True 77 | max_batch_length_train: 900 78 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 79 | num_bucket: 200 80 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 81 | batch_ordering: random 82 | max_batch_ex: 128 83 | 84 | dynamic_batch_sampler_train: 85 | max_batch_length: !ref 86 | num_buckets: !ref 87 | shuffle: !ref 88 | batch_ordering: !ref 89 | max_batch_ex: !ref 90 | 91 | dynamic_batch_sampler_valid: 92 | max_batch_length: !ref 93 | num_buckets: !ref 94 | shuffle: !ref 95 | batch_ordering: !ref 96 | max_batch_ex: !ref 97 | 98 | # Dataloader options 99 | train_dataloader_opts: 100 | batch_size: !ref 101 | shuffle: True 102 | num_workers: !ref 103 | 104 | valid_dataloader_opts: 105 | batch_size: 1 106 | 107 | test_dataloader_opts: 108 | batch_size: 1 109 | 110 | ####################### Model parameters ########################### 111 | # Transformer 112 | d_model: 144 113 | nhead: 4 114 | num_encoder_layers: 12 115 | num_decoder_layers: 4 116 | d_ffn: 1024 117 | transformer_dropout: 0.1 118 | activation: !name:torch.nn.GELU 119 | output_neurons: 5000 120 | 121 | # Outputs 122 | blank_index: 0 123 | label_smoothing: 0.0 124 | pad_index: 0 125 | bos_index: 1 126 | eos_index: 2 127 | 128 | # Decoding parameters 129 | min_decode_ratio: 0.0 130 | max_decode_ratio: 1.0 131 | valid_search_interval: 10 132 | valid_beam_size: 10 133 | test_beam_size: 66 134 | lm_weight: 0.60 135 | ctc_weight_decode: 0.40 136 | 137 | # Mamba parameters 138 | 139 | d_state: 16 140 | expand: 2 141 | d_conv: 4 142 | bidirectional: True 143 | mamba_config: 144 | d_state: !ref 145 | expand: !ref 146 | d_conv: !ref 147 | bidirectional: !ref 148 | 149 | ############################## models ################################ 150 | 151 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 152 | input_shape: (8, 10, 80) 153 | num_blocks: 2 154 | num_layers_per_block: 1 155 | out_channels: (64, 32) 156 | kernel_sizes: (3, 3) 157 | strides: (2, 2) 158 | residuals: (False, False) 159 | 160 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 161 | input_size: 640 162 | tgt_vocab: !ref 163 | d_model: !ref 164 | nhead: !ref # unused 165 | num_encoder_layers: !ref 166 | num_decoder_layers: !ref 167 | d_ffn: !ref 168 | dropout: !ref 169 | activation: !ref 170 | encoder_module: conmamba 171 | decoder_module: mamba 172 | attention_type: RelPosMHAXL # unused 173 | normalize_before: True 174 | causal: False 175 | mamba_config: !ref 176 | 177 | # This is the TransformerLM that is used according to the Huggingface repository 178 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 179 | # For more details about the model! 180 | # NB: It has to match the pre-trained TransformerLM!! 
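The decoding weights above (ctc_weight_decode: 0.40, lm_weight: 0.60) control how the CTC scorer and the TransformerLM scorer contribute to the beam search on top of the attention decoder. The snippet below is only a schematic of that idea with dummy scores, not the exact ScorerBuilder arithmetic; see speechbrain.decoders.scorer for the real implementation.

```python
# Schematic of joint CTC / attention / LM scoring during test-time beam search.
import torch

ctc_weight_decode = 0.40
lm_weight = 0.60

def combined_step_score(att_logp, ctc_logp, lm_logp):
    """Illustrative per-token score used to rank beam extensions."""
    return att_logp + ctc_weight_decode * ctc_logp + lm_weight * lm_logp

vocab = 5000
att = torch.log_softmax(torch.randn(vocab), dim=-1)   # attention decoder
ctc = torch.log_softmax(torch.randn(vocab), dim=-1)   # CTC scorer
lm = torch.log_softmax(torch.randn(vocab), dim=-1)    # TransformerLM scorer
best_next_token = combined_step_score(att, ctc, lm).argmax()
```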
181 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 182 | vocab: !ref 183 | d_model: 768 184 | nhead: 12 185 | num_encoder_layers: 12 186 | num_decoder_layers: 0 187 | d_ffn: 3072 188 | dropout: 0.0 189 | activation: !name:torch.nn.GELU 190 | normalize_before: False 191 | 192 | tokenizer: !new:sentencepiece.SentencePieceProcessor 193 | 194 | ctc_lin: !new:speechbrain.nnet.linear.Linear 195 | input_size: !ref 196 | n_neurons: !ref 197 | 198 | seq_lin: !new:speechbrain.nnet.linear.Linear 199 | input_size: !ref 200 | n_neurons: !ref 201 | 202 | normalize: !new:speechbrain.processing.features.InputNormalization 203 | norm_type: global 204 | update_until_epoch: 4 205 | 206 | modules: 207 | CNN: !ref 208 | Transformer: !ref 209 | seq_lin: !ref 210 | ctc_lin: !ref 211 | normalize: !ref 212 | 213 | model: !new:torch.nn.ModuleList 214 | - [!ref , !ref , !ref , !ref ] 215 | 216 | # define two optimizers here for two-stage training 217 | Adam: !name:torch.optim.Adam 218 | lr: !ref 219 | betas: (0.9, 0.98) 220 | eps: 0.000000001 221 | 222 | # Scorer 223 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 224 | eos_index: !ref 225 | blank_index: !ref 226 | ctc_fc: !ref 227 | 228 | 229 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 230 | language_model: !ref 231 | temperature: 1.15 232 | 233 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 234 | full_scorers: [!ref , !ref ] 235 | weights: 236 | ctc: !ref 237 | transformerlm: !ref 238 | 239 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 240 | full_scorers: [!ref ] 241 | weights: 242 | ctc: !ref 243 | 244 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 245 | modules: [!ref , !ref ] 246 | bos_index: !ref 247 | eos_index: !ref 248 | min_decode_ratio: !ref 249 | max_decode_ratio: !ref 250 | beam_size: !ref 251 | using_eos_threshold: False 252 | length_normalization: True 253 | scorer: !ref 254 | 255 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 256 | modules: [!ref , !ref ] 257 | bos_index: !ref 258 | eos_index: !ref 259 | min_decode_ratio: !ref 260 | max_decode_ratio: !ref 261 | beam_size: !ref 262 | temperature: 1.15 263 | using_eos_threshold: False 264 | length_normalization: True 265 | scorer: !ref 266 | 267 | 268 | log_softmax: !new:torch.nn.LogSoftmax 269 | dim: -1 270 | 271 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 272 | blank_index: !ref 273 | reduction: !ref 274 | 275 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 276 | label_smoothing: !ref 277 | reduction: !ref 278 | 279 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 280 | lr_initial: !ref 281 | n_warmup_steps: !ref 25000 // 282 | 283 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 284 | checkpoints_dir: !ref 285 | recoverables: 286 | model: !ref 287 | noam_scheduler: !ref 288 | normalizer: !ref 289 | counter: !ref 290 | 291 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 292 | limit: !ref 293 | 294 | # Speed perturbation 295 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 296 | 297 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 298 | orig_freq: !ref 299 | speeds: !ref 300 | 301 | # Time Drop 302 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 303 | time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 304 | time_drop_count_low: 4 # 
Min number of chunks to drop in time in the spectrogram 305 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 306 | time_drop_replace: "mean" # Method of dropping chunks 307 | 308 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 309 | drop_length_low: !ref 310 | drop_length_high: !ref 311 | drop_count_low: !ref 312 | drop_count_high: !ref 313 | replace: !ref 314 | dim: 1 315 | 316 | # Frequency Drop 317 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 318 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 319 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 320 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 321 | freq_drop_replace: "mean" # Method of dropping chunks 322 | 323 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 324 | drop_length_low: !ref 325 | drop_length_high: !ref 326 | drop_count_low: !ref 327 | drop_count_high: !ref 328 | replace: !ref 329 | dim: 2 330 | 331 | # Time warp 332 | time_warp_window: 5 # Length of time warping window 333 | time_warp_mode: "bicubic" # Time warping method 334 | 335 | time_warp: !new:speechbrain.augment.freq_domain.Warping 336 | warp_window: !ref 337 | warp_mode: !ref 338 | dim: 1 339 | 340 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 341 | parallel_augment: False 342 | concat_original: False 343 | repeat_augment: 1 344 | shuffle_augmentations: False 345 | min_augmentations: 3 346 | max_augmentations: 3 347 | augment_prob: 1.0 348 | augmentations: [ 349 | !ref , 350 | !ref , 351 | !ref ] 352 | 353 | compute_features: !new:speechbrain.lobes.features.Fbank 354 | sample_rate: !ref 355 | n_fft: !ref 356 | n_mels: !ref 357 | 358 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 359 | save_file: !ref 360 | 361 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 362 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 363 | 364 | # The pretrainer allows a mapping between pretrained files and instances that 365 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 366 | # and it will be loaded into "lm" which is pointing to the defined 367 | # before. 
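The noam_annealing scheduler above ramps the learning rate up to lr_adam over the warm-up period (25000 optimizer steps here, with grad_accumulation_factor 1) and then decays it. The sketch below assumes the common Noam normalization in which the learning rate peaks at lr_initial at the end of warm-up; check speechbrain.nnet.schedulers.NoamScheduler for the exact variant used.

```python
# Sketch of the Noam learning-rate schedule configured above (assumed form).
def noam_lr(step, lr_initial=0.001, n_warmup_steps=25000):
    scale = (n_warmup_steps ** 0.5) * min(step ** -0.5, step * n_warmup_steps ** -1.5)
    return lr_initial * scale

for step in (1_000, 25_000, 100_000):
    print(step, f"{noam_lr(step):.2e}")   # ~4e-05, 1e-03 (peak), then ~1/sqrt(step) decay
```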
368 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 369 | collect_in: !ref 370 | loadables: 371 | lm: !ref 372 | tokenizer: !ref 373 | paths: 374 | lm: !ref /lm.ckpt 375 | tokenizer: !ref /tokenizer.ckpt 376 | 377 | use_wandb: False 378 | resume: False 379 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 380 | initializer: !name:wandb.init 381 | entity: xj-audio 382 | project: !ref 383 | name: !ref 384 | dir: !ref 385 | reinit: true 386 | resume: !ref 387 | -------------------------------------------------------------------------------- /hparams/S2S/conmamba_large.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: ConMamba Encoder 4 | # Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Xilin Jiang 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 3407 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conmamba_L_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 
53 | number_of_epochs: 120 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 8 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 4 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | lr_adam: 0.0008 66 | 67 | # Feature parameters 68 | sample_rate: 16000 69 | n_fft: 512 70 | n_mels: 80 71 | win_length: 32 72 | 73 | # This setup works well for A100 80GB GPU, adapts it to your needs. 74 | # Or turn it off (but training speed will decrease) 75 | dynamic_batching: True 76 | max_batch_length_train: 500 77 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 78 | num_bucket: 200 79 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 80 | batch_ordering: random 81 | max_batch_ex: 256 82 | 83 | dynamic_batch_sampler_train: 84 | max_batch_length: !ref 85 | num_buckets: !ref 86 | shuffle: !ref 87 | batch_ordering: !ref 88 | max_batch_ex: !ref 89 | 90 | dynamic_batch_sampler_valid: 91 | max_batch_length: !ref 92 | num_buckets: !ref 93 | shuffle: !ref 94 | batch_ordering: !ref 95 | max_batch_ex: !ref 96 | 97 | # Dataloader options 98 | train_dataloader_opts: 99 | batch_size: !ref 100 | shuffle: True 101 | num_workers: !ref 102 | 103 | valid_dataloader_opts: 104 | batch_size: 1 105 | 106 | test_dataloader_opts: 107 | batch_size: 1 108 | 109 | ####################### Model parameters ########################### 110 | # Transformer dummy 111 | d_model: 512 112 | 113 | # Common 114 | nhead: 8 115 | num_encoder_layers: 12 116 | num_decoder_layers: 6 117 | d_ffn: 2048 118 | transformer_dropout: 0.1 119 | activation: !name:torch.nn.GELU 120 | output_neurons: 5000 121 | 122 | # Outputs 123 | blank_index: 0 124 | label_smoothing: 0.1 125 | pad_index: 0 126 | bos_index: 1 127 | eos_index: 2 128 | 129 | # Decoding parameters 130 | min_decode_ratio: 0.0 131 | max_decode_ratio: 1.0 132 | valid_search_interval: 10 133 | valid_beam_size: 10 134 | test_beam_size: 66 135 | lm_weight: 0.60 136 | ctc_weight_decode: 0.40 137 | 138 | # Mamba parameters 139 | d_state: 16 140 | expand: 2 141 | d_conv: 4 142 | bidirectional: True 143 | mamba_config: 144 | d_state: !ref 145 | expand: !ref 146 | d_conv: !ref 147 | bidirectional: !ref 148 | 149 | ############################## models ################################ 150 | 151 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 152 | input_shape: (8, 10, 80) 153 | num_blocks: 2 154 | num_layers_per_block: 1 155 | out_channels: (64, 32) 156 | kernel_sizes: (3, 3) 157 | strides: (2, 2) 158 | residuals: (False, False) 159 | 160 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 161 | input_size: 640 162 | tgt_vocab: !ref 163 | d_model: !ref 164 | nhead: !ref # unused 165 | num_encoder_layers: !ref 166 | num_decoder_layers: !ref 167 | d_ffn: !ref 168 | dropout: !ref 169 | activation: !ref 170 | encoder_module: conmamba 171 | attention_type: RelPosMHAXL 172 | normalize_before: True 173 | causal: False 174 | mamba_config: !ref 175 | 176 | # This is the TransformerLM that is used according to the Huggingface repository 177 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 178 | # For more details about the model! 179 | # NB: It has to match the pre-trained TransformerLM!! 
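The dynamic_batch_sampler_train block above is consumed by dataio_prepare() in the train scripts, which wraps the training dataset in a DynamicBatchSampler so that each batch packs roughly max_batch_length seconds of audio. A sketch of that wiring, assuming train_data is the DynamicItemDataset built from train.csv:

```python
# Sketch of how dynamic_batch_sampler_train becomes an actual sampler
# (mirrors dataio_prepare() in train_CTC.py / train_S2S.py).
from speechbrain.dataio.sampler import DynamicBatchSampler

def make_train_sampler(train_data, hparams):
    dynamic_hparams_train = hparams["dynamic_batch_sampler_train"]
    return DynamicBatchSampler(
        train_data,
        length_func=lambda x: x["duration"],   # batch budget measured in seconds
        **dynamic_hparams_train,               # max_batch_length=500, num_buckets=200, ...
    )

# The sampler then replaces the fixed batch_size in the dataloader options:
# {"batch_sampler": make_train_sampler(train_data, hparams), "num_workers": hparams["num_workers"]}
```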
180 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 181 | vocab: !ref 182 | d_model: 768 183 | nhead: 12 184 | num_encoder_layers: 12 185 | num_decoder_layers: 0 186 | d_ffn: 3072 187 | dropout: 0.0 188 | activation: !name:torch.nn.GELU 189 | normalize_before: False 190 | 191 | tokenizer: !new:sentencepiece.SentencePieceProcessor 192 | 193 | ctc_lin: !new:speechbrain.nnet.linear.Linear 194 | input_size: !ref 195 | n_neurons: !ref 196 | 197 | seq_lin: !new:speechbrain.nnet.linear.Linear 198 | input_size: !ref 199 | n_neurons: !ref 200 | 201 | normalize: !new:speechbrain.processing.features.InputNormalization 202 | norm_type: global 203 | update_until_epoch: 4 204 | 205 | modules: 206 | CNN: !ref 207 | Transformer: !ref 208 | seq_lin: !ref 209 | ctc_lin: !ref 210 | normalize: !ref 211 | 212 | model: !new:torch.nn.ModuleList 213 | - [!ref , !ref , !ref , !ref ] 214 | 215 | # define two optimizers here for two-stage training 216 | Adam: !name:torch.optim.AdamW 217 | lr: !ref 218 | betas: (0.9, 0.98) 219 | eps: 0.000000001 220 | 221 | # Scorer 222 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 223 | eos_index: !ref 224 | blank_index: !ref 225 | ctc_fc: !ref 226 | 227 | 228 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 229 | language_model: !ref 230 | temperature: 1.15 231 | 232 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 233 | full_scorers: [!ref , !ref ] 234 | weights: 235 | ctc: !ref 236 | transformerlm: !ref 237 | 238 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 239 | full_scorers: [!ref ] 240 | weights: 241 | ctc: !ref 242 | 243 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 244 | modules: [!ref , !ref ] 245 | bos_index: !ref 246 | eos_index: !ref 247 | min_decode_ratio: !ref 248 | max_decode_ratio: !ref 249 | beam_size: !ref 250 | using_eos_threshold: False 251 | length_normalization: True 252 | scorer: !ref 253 | 254 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 255 | modules: [!ref , !ref ] 256 | bos_index: !ref 257 | eos_index: !ref 258 | min_decode_ratio: !ref 259 | max_decode_ratio: !ref 260 | beam_size: !ref 261 | temperature: 1.15 262 | using_eos_threshold: False 263 | length_normalization: True 264 | scorer: !ref 265 | 266 | log_softmax: !new:torch.nn.LogSoftmax 267 | dim: -1 268 | 269 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 270 | blank_index: !ref 271 | reduction: !ref 272 | 273 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 274 | label_smoothing: !ref 275 | reduction: !ref 276 | 277 | n_warmup_steps: !ref 30000 // 278 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 279 | lr_initial: !ref 280 | n_warmup_steps: !ref 281 | 282 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 283 | checkpoints_dir: !ref 284 | recoverables: 285 | model: !ref 286 | noam_scheduler: !ref 287 | normalizer: !ref 288 | counter: !ref 289 | 290 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 291 | limit: !ref 292 | 293 | # Speed perturbation 294 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 295 | 296 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 297 | orig_freq: !ref 298 | speeds: !ref 299 | 300 | # Time Drop 301 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 302 | time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 303 | 
time_drop_count_low: 4 # Min number of chunks to drop in time in the spectrogram 304 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 305 | time_drop_replace: "mean" # Method of dropping chunks 306 | 307 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 308 | drop_length_low: !ref 309 | drop_length_high: !ref 310 | drop_count_low: !ref 311 | drop_count_high: !ref 312 | replace: !ref 313 | dim: 1 314 | 315 | # Frequency Drop 316 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 317 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 318 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 319 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 320 | freq_drop_replace: "mean" # Method of dropping chunks 321 | 322 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 323 | drop_length_low: !ref 324 | drop_length_high: !ref 325 | drop_count_low: !ref 326 | drop_count_high: !ref 327 | replace: !ref 328 | dim: 2 329 | 330 | # Time warp 331 | time_warp_window: 5 # Length of time warping window 332 | time_warp_mode: "bicubic" # Time warping method 333 | 334 | time_warp: !new:speechbrain.augment.freq_domain.Warping 335 | warp_window: !ref 336 | warp_mode: !ref 337 | dim: 1 338 | 339 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 340 | parallel_augment: False 341 | concat_original: False 342 | repeat_augment: 1 343 | shuffle_augmentations: False 344 | min_augmentations: 3 345 | max_augmentations: 3 346 | augment_prob: 1.0 347 | augmentations: [ 348 | !ref , 349 | !ref , 350 | !ref ] 351 | 352 | compute_features: !new:speechbrain.lobes.features.Fbank 353 | sample_rate: !ref 354 | n_fft: !ref 355 | n_mels: !ref 356 | win_length: !ref 357 | 358 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 359 | save_file: !ref 360 | 361 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 362 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 363 | 364 | # The pretrainer allows a mapping between pretrained files and instances that 365 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 366 | # and it will be loaded into "lm" which is pointing to the defined 367 | # before. 
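avg_checkpoints: 10 above works together with the checkpointer: the training loop keeps the ten best checkpoints by validation WER, and before evaluation their model weights are averaged. A sketch mirroring on_evaluate_start() in the train scripts, where hparams is assumed to be this YAML already loaded with hyperpyyaml:

```python
# Sketch of the checkpoint averaging behind avg_checkpoints: 10
# (mirrors on_evaluate_start() in train_CTC.py; `hparams` is assumed loaded).
import speechbrain as sb

ckpts = hparams["checkpointer"].find_checkpoints(min_key="WER")
averaged = sb.utils.checkpoints.average_checkpoints(ckpts, recoverable_name="model")
hparams["model"].load_state_dict(averaged, strict=True)
hparams["model"].eval()
# Only the best `avg_checkpoints` checkpoints are kept on disk
# (save_and_keep_only in the training loop), so those are what get averaged.
```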
368 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 369 | collect_in: !ref 370 | loadables: 371 | lm: !ref 372 | tokenizer: !ref 373 | paths: 374 | lm: !ref /lm.ckpt 375 | tokenizer: !ref /tokenizer.ckpt 376 | 377 | use_wandb: False 378 | resume: False 379 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 380 | initializer: !name:wandb.init 381 | entity: xj-audio 382 | project: !ref 383 | name: !ref 384 | dir: !ref 385 | reinit: true 386 | resume: !ref 387 | -------------------------------------------------------------------------------- /hparams/S2S/conmambamamba_large.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: E2E ASR with Transformer 3 | # Encoder: ConMamba Encoder 4 | # Decoder: Mamba Decoder + (CTC/ATT joint) beamsearch + TransformerLM 5 | # Tokens: unigram 6 | # losses: CTC + KLdiv (Label Smoothing loss) 7 | # Training: Librispeech 960h 8 | # Authors: Xilin Jiang 9 | # ############################################################################ 10 | # Seed needs to be set at top of yaml, before objects with parameters are made 11 | 12 | seed: 3407 13 | __set_seed: !apply:torch.manual_seed [!ref ] 14 | project: Mamba-ASR 15 | experiment: conmambamamba_L_S2S 16 | output_folder: !ref results/S2S// 17 | output_wer_folder: !ref / 18 | save_folder: !ref /save 19 | train_log: !ref /train_log.txt 20 | 21 | # Language model (LM) pretraining 22 | # NB: To avoid mismatch, the speech recognizer must be trained with the same 23 | # tokenizer used for LM training. Here, we download everything from the 24 | # speechbrain HuggingFace repository. However, a local path pointing to a 25 | # directory containing the lm.ckpt and tokenizer.ckpt may also be specified 26 | # instead. E.g if you want to use your own LM / tokenizer. 27 | pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech 28 | 29 | # Data files 30 | data_folder: !PLACEHOLDER 31 | # If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES 32 | # then data_folder_rirs should be /localscratch/xxx_corpus 33 | # otherwise the dataset will automatically be downloaded 34 | # data_folder_rirs: !ref 35 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 36 | dev_splits: ["dev-clean"] 37 | test_splits: ["test-clean", "test-other"] 38 | skip_prep: False 39 | train_csv: !ref /train.csv 40 | valid_csv: !ref /dev-clean.csv 41 | test_csv: 42 | - !ref /test-clean.csv 43 | - !ref /test-other.csv 44 | 45 | skip_train: False 46 | no_lm: False 47 | 48 | # Training parameters 49 | # To make Transformers converge, the global bath size should be large enough. 50 | # The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. 51 | # Empirically, we found that this value should be >= 128. 52 | # Please, set your parameters accordingly. 
53 | number_of_epochs: 120 54 | batch_size: 16 # This works for 2x GPUs with 32GB 55 | ctc_weight: 0.3 56 | grad_accumulation_factor: 8 57 | max_grad_norm: 5.0 58 | loss_reduction: 'batchmean' 59 | sorting: random 60 | num_workers: 4 61 | precision: bf16 # bf16, fp16 or fp32 62 | avg_checkpoints: 10 # Number of checkpoints to average for evaluation 63 | 64 | # stages related parameters 65 | lr_adam: 0.0008 66 | 67 | # Feature parameters 68 | sample_rate: 16000 69 | n_fft: 512 70 | n_mels: 80 71 | win_length: 32 72 | 73 | # This setup works well for A100 80GB GPU, adapts it to your needs. 74 | # Or turn it off (but training speed will decrease) 75 | dynamic_batching: True 76 | max_batch_length_train: 500 77 | max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) 78 | num_bucket: 200 79 | shuffle: True # if true re-creates batches at each epoch shuffling examples. 80 | batch_ordering: random 81 | max_batch_ex: 256 82 | 83 | dynamic_batch_sampler_train: 84 | max_batch_length: !ref 85 | num_buckets: !ref 86 | shuffle: !ref 87 | batch_ordering: !ref 88 | max_batch_ex: !ref 89 | 90 | dynamic_batch_sampler_valid: 91 | max_batch_length: !ref 92 | num_buckets: !ref 93 | shuffle: !ref 94 | batch_ordering: !ref 95 | max_batch_ex: !ref 96 | 97 | # Dataloader options 98 | train_dataloader_opts: 99 | batch_size: !ref 100 | shuffle: True 101 | num_workers: !ref 102 | 103 | valid_dataloader_opts: 104 | batch_size: 1 105 | 106 | test_dataloader_opts: 107 | batch_size: 1 108 | 109 | ####################### Model parameters ########################### 110 | # Transformer dummy 111 | d_model: 512 112 | 113 | # Common 114 | nhead: 8 115 | num_encoder_layers: 12 116 | num_decoder_layers: 6 117 | d_ffn: 2048 118 | transformer_dropout: 0.1 119 | activation: !name:torch.nn.GELU 120 | output_neurons: 5000 121 | 122 | # Outputs 123 | blank_index: 0 124 | label_smoothing: 0.1 125 | pad_index: 0 126 | bos_index: 1 127 | eos_index: 2 128 | 129 | # Decoding parameters 130 | min_decode_ratio: 0.0 131 | max_decode_ratio: 1.0 132 | valid_search_interval: 10 133 | valid_beam_size: 10 134 | test_beam_size: 66 135 | lm_weight: 0.60 136 | ctc_weight_decode: 0.40 137 | 138 | # Mamba parameters 139 | d_state: 16 140 | expand: 2 141 | d_conv: 4 142 | bidirectional: True 143 | mamba_config: 144 | d_state: !ref 145 | expand: !ref 146 | d_conv: !ref 147 | bidirectional: !ref 148 | 149 | ############################## models ################################ 150 | 151 | CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd 152 | input_shape: (8, 10, 80) 153 | num_blocks: 2 154 | num_layers_per_block: 1 155 | out_channels: (64, 32) 156 | kernel_sizes: (3, 3) 157 | strides: (2, 2) 158 | residuals: (False, False) 159 | 160 | Transformer: !new:modules.TransformerASR.TransformerASR # yamllint disable-line rule:line-length 161 | input_size: 640 162 | tgt_vocab: !ref 163 | d_model: !ref 164 | nhead: !ref # unused 165 | num_encoder_layers: !ref 166 | num_decoder_layers: !ref 167 | d_ffn: !ref 168 | dropout: !ref 169 | activation: !ref 170 | encoder_module: conmamba 171 | decoder_module: mamba 172 | attention_type: RelPosMHAXL 173 | normalize_before: True 174 | causal: False 175 | mamba_config: !ref 176 | 177 | # This is the TransformerLM that is used according to the Huggingface repository 178 | # Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path 179 | # For more details about the model! 180 | # NB: It has to match the pre-trained TransformerLM!! 
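ctc_weight: 0.3 above sets the interpolation between the CTC loss and the label-smoothed KL-divergence sequence loss during training. The snippet below illustrates the weighting used in the standard SpeechBrain S2S recipes that this code adapts; shapes and token values are dummies for illustration only.

```python
# Illustration of the CTC + KLdiv objective implied by ctc_weight: 0.3
# (standard SpeechBrain S2S weighting; dummy shapes and labels).
import torch
from speechbrain.nnet.losses import ctc_loss, kldiv_loss

ctc_weight = 0.3
B, T, U, V = 2, 50, 12, 5000          # batch, encoder frames, target length, vocab

p_ctc = torch.log_softmax(torch.randn(B, T, V), dim=-1)   # CTC head output
p_seq = torch.log_softmax(torch.randn(B, U, V), dim=-1)   # decoder output
tokens = torch.randint(3, V, (B, U))                      # targets for CTC
tokens_eos = torch.randint(3, V, (B, U))                  # targets with eos for the decoder
wav_lens = torch.ones(B)                                  # relative lengths
tok_lens = torch.ones(B)

loss_ctc = ctc_loss(p_ctc, tokens, wav_lens, tok_lens,
                    blank_index=0, reduction="batchmean")
loss_seq = kldiv_loss(p_seq, tokens_eos, length=tok_lens,
                      label_smoothing=0.1, reduction="batchmean")
loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_seq
```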
181 | lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length 182 | vocab: !ref 183 | d_model: 768 184 | nhead: 12 185 | num_encoder_layers: 12 186 | num_decoder_layers: 0 187 | d_ffn: 3072 188 | dropout: 0.0 189 | activation: !name:torch.nn.GELU 190 | normalize_before: False 191 | 192 | tokenizer: !new:sentencepiece.SentencePieceProcessor 193 | 194 | ctc_lin: !new:speechbrain.nnet.linear.Linear 195 | input_size: !ref 196 | n_neurons: !ref 197 | 198 | seq_lin: !new:speechbrain.nnet.linear.Linear 199 | input_size: !ref 200 | n_neurons: !ref 201 | 202 | normalize: !new:speechbrain.processing.features.InputNormalization 203 | norm_type: global 204 | update_until_epoch: 4 205 | 206 | modules: 207 | CNN: !ref 208 | Transformer: !ref 209 | seq_lin: !ref 210 | ctc_lin: !ref 211 | normalize: !ref 212 | 213 | model: !new:torch.nn.ModuleList 214 | - [!ref , !ref , !ref , !ref ] 215 | 216 | # define two optimizers here for two-stage training 217 | Adam: !name:torch.optim.AdamW 218 | lr: !ref 219 | betas: (0.9, 0.98) 220 | eps: 0.000000001 221 | 222 | # Scorer 223 | ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer 224 | eos_index: !ref 225 | blank_index: !ref 226 | ctc_fc: !ref 227 | 228 | 229 | transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer 230 | language_model: !ref 231 | temperature: 1.15 232 | 233 | scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder 234 | full_scorers: [!ref , !ref ] 235 | weights: 236 | ctc: !ref 237 | transformerlm: !ref 238 | 239 | scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder 240 | full_scorers: [!ref ] 241 | weights: 242 | ctc: !ref 243 | 244 | valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 245 | modules: [!ref , !ref ] 246 | bos_index: !ref 247 | eos_index: !ref 248 | min_decode_ratio: !ref 249 | max_decode_ratio: !ref 250 | beam_size: !ref 251 | using_eos_threshold: False 252 | length_normalization: True 253 | scorer: !ref 254 | 255 | test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher 256 | modules: [!ref , !ref ] 257 | bos_index: !ref 258 | eos_index: !ref 259 | min_decode_ratio: !ref 260 | max_decode_ratio: !ref 261 | beam_size: !ref 262 | temperature: 1.15 263 | using_eos_threshold: False 264 | length_normalization: True 265 | scorer: !ref 266 | 267 | log_softmax: !new:torch.nn.LogSoftmax 268 | dim: -1 269 | 270 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 271 | blank_index: !ref 272 | reduction: !ref 273 | 274 | seq_cost: !name:speechbrain.nnet.losses.kldiv_loss 275 | label_smoothing: !ref 276 | reduction: !ref 277 | 278 | n_warmup_steps: !ref 30000 // 279 | noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler 280 | lr_initial: !ref 281 | n_warmup_steps: !ref 282 | 283 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 284 | checkpoints_dir: !ref 285 | recoverables: 286 | model: !ref 287 | noam_scheduler: !ref 288 | normalizer: !ref 289 | counter: !ref 290 | 291 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 292 | limit: !ref 293 | 294 | # Speed perturbation 295 | speed_changes: [95, 100, 105] # List of speed changes for time-stretching 296 | 297 | speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb 298 | orig_freq: !ref 299 | speeds: !ref 300 | 301 | # Time Drop 302 | time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram 303 | time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram 304 | 
time_drop_count_low: 4 # Min number of chunks to drop in time in the spectrogram 305 | time_drop_count_high: 4 # Max number of chunks to drop in time in the spectrogram 306 | time_drop_replace: "mean" # Method of dropping chunks 307 | 308 | time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 309 | drop_length_low: !ref 310 | drop_length_high: !ref 311 | drop_count_low: !ref 312 | drop_count_high: !ref 313 | replace: !ref 314 | dim: 1 315 | 316 | # Frequency Drop 317 | freq_drop_length_low: 10 # Min length for chunks to drop in frequency in the spectrogram 318 | freq_drop_length_high: 20 # Max length for chunks to drop in frequency in the spectrogram 319 | freq_drop_count_low: 4 # Min number of chunks to drop in frequency in the spectrogram 320 | freq_drop_count_high: 4 # Max number of chunks to drop in frequency in the spectrogram 321 | freq_drop_replace: "mean" # Method of dropping chunks 322 | 323 | freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop 324 | drop_length_low: !ref 325 | drop_length_high: !ref 326 | drop_count_low: !ref 327 | drop_count_high: !ref 328 | replace: !ref 329 | dim: 2 330 | 331 | # Time warp 332 | time_warp_window: 5 # Length of time warping window 333 | time_warp_mode: "bicubic" # Time warping method 334 | 335 | time_warp: !new:speechbrain.augment.freq_domain.Warping 336 | warp_window: !ref 337 | warp_mode: !ref 338 | dim: 1 339 | 340 | fea_augment: !new:speechbrain.augment.augmenter.Augmenter 341 | parallel_augment: False 342 | concat_original: False 343 | repeat_augment: 1 344 | shuffle_augmentations: False 345 | min_augmentations: 3 346 | max_augmentations: 3 347 | augment_prob: 1.0 348 | augmentations: [ 349 | !ref , 350 | !ref , 351 | !ref ] 352 | 353 | compute_features: !new:speechbrain.lobes.features.Fbank 354 | sample_rate: !ref 355 | n_fft: !ref 356 | n_mels: !ref 357 | win_length: !ref 358 | 359 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 360 | save_file: !ref 361 | 362 | error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 363 | acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats 364 | 365 | # The pretrainer allows a mapping between pretrained files and instances that 366 | # are declared in the yaml. E.g here, we will download the file lm.ckpt 367 | # and it will be loaded into "lm" which is pointing to the defined 368 | # before. 369 | pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer 370 | collect_in: !ref 371 | loadables: 372 | lm: !ref 373 | tokenizer: !ref 374 | paths: 375 | lm: !ref /lm.ckpt 376 | tokenizer: !ref /tokenizer.ckpt 377 | 378 | use_wandb: False 379 | resume: False 380 | wandb_logger: !name:speechbrain.utils.train_logger.WandBLogger 381 | initializer: !name:wandb.init 382 | entity: xj-audio 383 | project: !ref 384 | name: !ref 385 | dir: !ref 386 | reinit: true 387 | resume: !ref 388 | -------------------------------------------------------------------------------- /train_CTC.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied and modified from 3 | https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/CTC/train.py 4 | ''' 5 | 6 | 7 | """Recipe for training a Transformer ASR system with librispeech. 8 | The system employs an encoder and CTC greedy decoding. 
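Greedy CTC decoding, mentioned above and used in the validation branch of compute_forward() below, takes the argmax token per frame, collapses repeated tokens, and removes blanks. A tiny standalone sketch with random log-probabilities:

```python
# Standalone sketch of the greedy CTC decoding used at validation time below.
import torch
import speechbrain as sb

B, T, V = 2, 60, 31                      # batch, frames, vocab size (dummy values)
p_ctc = torch.log_softmax(torch.randn(B, T, V), dim=-1)
wav_lens = torch.ones(B)                 # relative lengths

hyps = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=0)
print(hyps[0])                           # token ids for the first utterance
```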
9 | 10 | To run this recipe, do the following: 11 | > python train_from_scratch.py hparams/train_conformer.yaml 12 | or 13 | > python train_from_scratch.py hparams/train_branchformer.yaml 14 | 15 | With the default hyperparameters, the system employs a convolutional frontend and a transformer. 16 | Training is performed on the full LibriSpeech dataset (960 h). 17 | 18 | Authors 19 | * Titouan Parcollet 2021, 2022 20 | * Shucong Zhang 2023 21 | * Adel Moumen 2024 22 | """ 23 | 24 | import logging 25 | import os 26 | import sys 27 | from pathlib import Path 28 | 29 | import torch 30 | from hyperpyyaml import load_hyperpyyaml 31 | 32 | import speechbrain as sb 33 | from speechbrain.tokenizers.SentencePiece import SentencePiece 34 | from speechbrain.utils.distributed import if_main_process, run_on_main 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | # Define training procedure 40 | class ASR(sb.core.Brain): 41 | def compute_forward(self, batch, stage): 42 | """Forward computations from the waveform batches to the output probabilities.""" 43 | batch = batch.to(self.device) 44 | wavs, wav_lens = batch.sig 45 | 46 | # compute features 47 | feats = self.hparams.compute_features(wavs) # (B, T, 80) 48 | current_epoch = self.hparams.epoch_counter.current 49 | feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch) 50 | 51 | # Add feature augmentation if specified. 52 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"): 53 | feats, fea_lens = self.hparams.fea_augment(feats, wav_lens) 54 | 55 | # forward modules 56 | src = self.modules.CNN(feats) 57 | 58 | enc_out, pred = self.modules.Transformer( 59 | src, tgt=None, wav_len=wav_lens 60 | ) 61 | 62 | # output layer for ctc log-probabilities 63 | logits = self.modules.ctc_lin(enc_out) 64 | p_ctc = self.hparams.log_softmax(logits) 65 | 66 | p_tokens = None 67 | if stage == sb.Stage.VALID: 68 | p_tokens = sb.decoders.ctc_greedy_decode( 69 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 70 | ) 71 | elif stage == sb.Stage.TEST: 72 | p_tokens = test_searcher(p_ctc, wav_lens) 73 | 74 | return p_ctc, wav_lens, p_tokens 75 | 76 | def compute_objectives(self, predictions, batch, stage): 77 | """Computes the loss (CTC) given predictions and targets.""" 78 | 79 | p_ctc, wav_lens, predicted_tokens = predictions 80 | 81 | ids = batch.id 82 | tokens, tokens_lens = batch.tokens 83 | 84 | old_tokens = tokens.detach().clone() 85 | old_tokens_lens = tokens_lens.detach().clone() 86 | 87 | # Label Augmentation 88 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"): 89 | tokens = self.hparams.fea_augment.replicate_labels(tokens) 90 | tokens_lens = self.hparams.fea_augment.replicate_labels(tokens_lens) 91 | 92 | loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens).sum() 93 | 94 | if stage == sb.Stage.VALID: 95 | # Decode token terms to words 96 | predicted_words = self.tokenizer( 97 | predicted_tokens, task="decode_from_list" 98 | ) 99 | elif stage == sb.Stage.TEST: 100 | predicted_words = [ 101 | hyp[0].text.split(" ") for hyp in predicted_tokens 102 | ] 103 | 104 | if stage != sb.Stage.TRAIN: 105 | target_words = [wrd.split(" ") for wrd in batch.wrd] 106 | self.wer_metric.append(ids, predicted_words, target_words) 107 | self.cer_metric.append(ids, predicted_words, target_words) 108 | 109 | return loss 110 | 111 | def on_evaluate_start(self, max_key=None, min_key=None): 112 | """perform checkpoint averge if needed""" 113 | super().on_evaluate_start() 114 | 115 | ckpts = 
self.checkpointer.find_checkpoints( 116 | max_key=max_key, min_key=min_key 117 | ) 118 | ckpt = sb.utils.checkpoints.average_checkpoints( 119 | ckpts, 120 | recoverable_name="model", 121 | ) 122 | 123 | self.hparams.model.load_state_dict(ckpt, strict=True) 124 | self.hparams.model.eval() 125 | print("Loaded the average") 126 | 127 | def on_stage_start(self, stage, epoch): 128 | """Gets called at the beginning of each epoch""" 129 | if stage != sb.Stage.TRAIN: 130 | self.cer_metric = self.hparams.cer_computer() 131 | self.wer_metric = self.hparams.wer_computer() 132 | 133 | def on_stage_end(self, stage, stage_loss, epoch): 134 | """Gets called at the end of a epoch.""" 135 | # Compute/store important stats 136 | stage_stats = {"loss": stage_loss} 137 | if stage == sb.Stage.TRAIN: 138 | self.train_stats = stage_stats 139 | else: 140 | stage_stats["CER"] = self.cer_metric.summarize("error_rate") 141 | stage_stats["WER"] = self.wer_metric.summarize("error_rate") 142 | current_epoch = self.hparams.epoch_counter.current 143 | valid_search_interval = self.hparams.valid_search_interval 144 | if ( 145 | current_epoch % valid_search_interval == 0 146 | or stage == sb.Stage.TEST 147 | ): 148 | stage_stats["WER"] = self.wer_metric.summarize("error_rate") 149 | 150 | # log stats and save checkpoint at end-of-epoch 151 | if stage == sb.Stage.VALID: 152 | 153 | lr = self.hparams.noam_annealing.current_lr 154 | steps = self.optimizer_step 155 | optimizer = self.optimizer.__class__.__name__ 156 | 157 | epoch_stats = { 158 | "epoch": epoch, 159 | "lr": lr, 160 | "steps": steps, 161 | "optimizer": optimizer, 162 | } 163 | self.hparams.train_logger.log_stats( 164 | stats_meta=epoch_stats, 165 | train_stats=self.train_stats, 166 | valid_stats=stage_stats, 167 | ) 168 | self.checkpointer.save_and_keep_only( 169 | meta={"WER": stage_stats["WER"], "epoch": epoch}, 170 | min_keys=["WER"], 171 | num_to_keep=self.hparams.avg_checkpoints, 172 | ) 173 | 174 | elif stage == sb.Stage.TEST: 175 | self.hparams.train_logger.log_stats( 176 | stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, 177 | test_stats=stage_stats, 178 | ) 179 | if if_main_process(): 180 | with open(self.hparams.wer_file, "w") as w: 181 | self.wer_metric.write_stats(w) 182 | 183 | def on_fit_batch_end(self, batch, outputs, loss, should_step): 184 | if should_step: 185 | self.hparams.noam_annealing(self.optimizer) 186 | 187 | 188 | def dataio_prepare(hparams, tokenizer): 189 | """This function prepares the datasets to be used in the brain class. 190 | It also defines the data processing pipeline through user-defined functions. 191 | """ 192 | data_folder = hparams["data_folder"] 193 | 194 | train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( 195 | csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, 196 | ) 197 | 198 | if hparams["sorting"] == "ascending": 199 | # we sort training data to speed up training and get better results. 200 | train_data = train_data.filtered_sorted(sort_key="duration") 201 | # when sorting do not shuffle in dataloader ! otherwise is pointless 202 | hparams["train_dataloader_opts"]["shuffle"] = False 203 | 204 | elif hparams["sorting"] == "descending": 205 | train_data = train_data.filtered_sorted( 206 | sort_key="duration", reverse=True 207 | ) 208 | # when sorting do not shuffle in dataloader ! 
otherwise is pointless 209 | hparams["train_dataloader_opts"]["shuffle"] = False 210 | 211 | elif hparams["sorting"] == "random": 212 | pass 213 | 214 | else: 215 | raise NotImplementedError( 216 | "sorting must be random, ascending or descending" 217 | ) 218 | valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( 219 | csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, 220 | ) 221 | valid_data = valid_data.filtered_sorted(sort_key="duration") 222 | 223 | # test is separate 224 | test_datasets = {} 225 | for csv_file in hparams["test_csv"]: 226 | name = Path(csv_file).stem 227 | test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( 228 | csv_path=csv_file, replacements={"data_root": data_folder} 229 | ) 230 | test_datasets[name] = test_datasets[name].filtered_sorted( 231 | sort_key="duration" 232 | ) 233 | 234 | datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] 235 | valtest_datasets = [valid_data] + [i for k, i in test_datasets.items()] 236 | 237 | # 2. Define audio pipeline: 238 | @sb.utils.data_pipeline.takes("wav") 239 | @sb.utils.data_pipeline.provides("sig") 240 | def audio_pipeline(wav): 241 | sig = sb.dataio.dataio.read_audio(wav) 242 | return sig 243 | 244 | sb.dataio.dataset.add_dynamic_item(valtest_datasets, audio_pipeline) 245 | 246 | @sb.utils.data_pipeline.takes("wav") 247 | @sb.utils.data_pipeline.provides("sig") 248 | def audio_pipeline_train(wav): 249 | # Speed Perturb is done here so it is multi-threaded with the 250 | # workers of the dataloader (faster). 251 | if hparams["speed_perturb"]: 252 | sig = sb.dataio.dataio.read_audio(wav) 253 | 254 | sig = hparams["speed_perturb"](sig.unsqueeze(0)).squeeze(0) 255 | else: 256 | sig = sb.dataio.dataio.read_audio(wav) 257 | return sig 258 | 259 | sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train) 260 | 261 | # 3. Define text pipeline: 262 | @sb.utils.data_pipeline.takes("wrd") 263 | @sb.utils.data_pipeline.provides( 264 | "wrd", "char_list", "tokens_list", "tokens" 265 | ) 266 | def text_pipeline(wrd): 267 | yield wrd 268 | char_list = list(wrd) 269 | yield char_list 270 | tokens_list = tokenizer.sp.encode_as_ids(wrd) 271 | yield tokens_list 272 | tokens = torch.LongTensor(tokens_list) 273 | yield tokens 274 | 275 | sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) 276 | 277 | # 4. Set output: 278 | sb.dataio.dataset.set_output_keys( 279 | datasets, ["id", "sig", "wrd", "char_list", "tokens"], 280 | ) 281 | 282 | # 5. If Dynamic Batching is used, we instantiate the needed samplers. 
283 | train_batch_sampler = None 284 | valid_batch_sampler = None 285 | if hparams["dynamic_batching"]: 286 | from speechbrain.dataio.sampler import DynamicBatchSampler # noqa 287 | 288 | dynamic_hparams_train = hparams["dynamic_batch_sampler_train"] 289 | dynamic_hparams_val = hparams["dynamic_batch_sampler_val"] 290 | 291 | train_batch_sampler = DynamicBatchSampler( 292 | train_data, 293 | length_func=lambda x: x["duration"], 294 | **dynamic_hparams_train, 295 | ) 296 | 297 | valid_batch_sampler = DynamicBatchSampler( 298 | valid_data, 299 | length_func=lambda x: x["duration"], 300 | **dynamic_hparams_val, 301 | ) 302 | 303 | return ( 304 | train_data, 305 | valid_data, 306 | test_datasets, 307 | train_batch_sampler, 308 | valid_batch_sampler, 309 | ) 310 | 311 | 312 | if __name__ == "__main__": 313 | # CLI: 314 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 315 | with open(hparams_file) as fin: 316 | hparams = load_hyperpyyaml(fin, overrides) 317 | 318 | # If --distributed_launch then 319 | # create ddp_group with the right communication protocol 320 | sb.utils.distributed.ddp_init_group(run_opts) 321 | 322 | # 1. # Dataset prep (parsing Librispeech) 323 | from librispeech_prepare import prepare_librispeech # noqa 324 | 325 | # Create experiment directory 326 | sb.create_experiment_directory( 327 | experiment_directory=hparams["output_folder"], 328 | hyperparams_to_save=hparams_file, 329 | overrides=overrides, 330 | ) 331 | 332 | # multi-gpu (ddp) save data preparation 333 | run_on_main( 334 | prepare_librispeech, 335 | kwargs={ 336 | "data_folder": hparams["data_folder"], 337 | "tr_splits": hparams["train_splits"], 338 | "dev_splits": hparams["dev_splits"], 339 | "te_splits": hparams["test_splits"], 340 | "save_folder": hparams["output_folder"], 341 | "merge_lst": hparams["train_splits"], 342 | "merge_name": "train.csv", 343 | "skip_prep": hparams["skip_prep"], 344 | }, 345 | ) 346 | 347 | # Defining tokenizer and loading it 348 | tokenizer = SentencePiece( 349 | model_dir=hparams["save_folder"], 350 | vocab_size=hparams["output_neurons"], 351 | annotation_train=hparams["train_csv"], 352 | annotation_read="wrd", 353 | model_type=hparams["token_type"], 354 | character_coverage=hparams["character_coverage"], 355 | bos_id=hparams["bos_index"], 356 | eos_id=hparams["eos_index"], 357 | ) 358 | 359 | # here we create the datasets objects as well as tokenization and encoding 360 | ( 361 | train_data, 362 | valid_data, 363 | test_datasets, 364 | train_bsampler, 365 | valid_bsampler, 366 | ) = dataio_prepare(hparams, tokenizer) 367 | 368 | # Init wandb 369 | if hparams['use_wandb']: 370 | hparams['train_logger'] = hparams['wandb_logger']() 371 | 372 | # Trainer initialization 373 | asr_brain = ASR( 374 | modules=hparams["modules"], 375 | opt_class=hparams["model_opt_class"], 376 | hparams=hparams, 377 | run_opts=run_opts, 378 | checkpointer=hparams["checkpointer"], 379 | ) 380 | 381 | # Adding objects to trainer. 
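# The tokenizer is attached to the Brain instance so hypotheses can be decoded
# back to text during evaluation, and vocab_list maps every SentencePiece id to
# its piece string, which the CTC beam searcher instantiated below requires to
# turn token sequences into words. The search options themselves (e.g. the beam
# size) are read from the "test_beam_search" section of the YAML file.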
382 | asr_brain.tokenizer = tokenizer 383 | vocab_list = [ 384 | tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size()) 385 | ] 386 | 387 | from speechbrain.decoders.ctc import CTCBeamSearcher 388 | 389 | test_searcher = CTCBeamSearcher( 390 | **hparams["test_beam_search"], 391 | vocab_list=vocab_list, 392 | ) 393 | 394 | train_dataloader_opts = hparams["train_dataloader_opts"] 395 | valid_dataloader_opts = hparams["valid_dataloader_opts"] 396 | 397 | if train_bsampler is not None: 398 | train_dataloader_opts = { 399 | "batch_sampler": train_bsampler, 400 | "num_workers": hparams["num_workers"], 401 | } 402 | 403 | if valid_bsampler is not None: 404 | valid_dataloader_opts = {"batch_sampler": valid_bsampler} 405 | 406 | if not hparams['skip_train']: 407 | # Training 408 | asr_brain.fit( 409 | asr_brain.hparams.epoch_counter, 410 | train_data, 411 | valid_data, 412 | train_loader_kwargs=train_dataloader_opts, 413 | valid_loader_kwargs=valid_dataloader_opts, 414 | ) 415 | 416 | # Testing 417 | for k in test_datasets.keys(): # keys are test_clean, test_other etc 418 | asr_brain.hparams.wer_file = os.path.join( 419 | hparams["output_folder"], f"wer_{k}.txt" 420 | ) 421 | asr_brain.evaluate( 422 | test_datasets[k], 423 | min_key="WER", 424 | test_loader_kwargs=hparams["test_dataloader_opts"], 425 | ) 426 | -------------------------------------------------------------------------------- /librispeech_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data preparation. 3 | 4 | Download: http://www.openslr.org/12 5 | 6 | Author 7 | ------ 8 | * Mirco Ravanelli, 2020 9 | * Ju-Chieh Chou, 2020 10 | * Loren Lugosch, 2020 11 | * Pierre Champion, 2023 12 | * Adel Moumen, 2024 13 | """ 14 | 15 | import csv 16 | import functools 17 | import logging 18 | import os 19 | import random 20 | from collections import Counter 21 | from dataclasses import dataclass 22 | 23 | from speechbrain.dataio.dataio import ( 24 | load_pkl, 25 | merge_csvs, 26 | read_audio_info, 27 | save_pkl, 28 | ) 29 | from speechbrain.utils.data_utils import download_file, get_all_files 30 | from speechbrain.utils.parallel import parallel_map 31 | 32 | logger = logging.getLogger(__name__) 33 | OPT_FILE = "opt_librispeech_prepare.pkl" 34 | SAMPLERATE = 16000 35 | OPEN_SLR_11_LINK = "http://www.openslr.org/resources/11/" 36 | OPEN_SLR_11_NGRAM_MODELs = [ 37 | "3-gram.arpa.gz", 38 | "3-gram.pruned.1e-7.arpa.gz", 39 | "3-gram.pruned.3e-7.arpa.gz", 40 | "4-gram.arpa.gz", 41 | ] 42 | 43 | 44 | def prepare_librispeech( 45 | data_folder, 46 | save_folder, 47 | tr_splits=[], 48 | dev_splits=[], 49 | te_splits=[], 50 | select_n_sentences=None, 51 | merge_lst=[], 52 | merge_name=None, 53 | create_lexicon=False, 54 | skip_prep=False, 55 | ): 56 | """ 57 | This class prepares the csv files for the LibriSpeech dataset. 58 | Download link: http://www.openslr.org/12 59 | 60 | Arguments 61 | --------- 62 | data_folder : str 63 | Path to the folder where the original LibriSpeech dataset is stored. 64 | save_folder : str 65 | The directory where to store the csv files. 66 | tr_splits : list 67 | List of train splits to prepare from ['test-others','train-clean-100', 68 | 'train-clean-360','train-other-500']. 69 | dev_splits : list 70 | List of dev splits to prepare from ['dev-clean','dev-others']. 71 | te_splits : list 72 | List of test splits to prepare from ['test-clean','test-others']. 73 | select_n_sentences : int 74 | Default : None 75 | If not None, only pick this many sentences. 
76 | merge_lst : list 77 | List of librispeech splits (e.g, train-clean, train-clean-360,..) to 78 | merge in a single csv file. 79 | merge_name: str 80 | Name of the merged csv file. 81 | create_lexicon: bool 82 | If True, it outputs csv files containing mapping between grapheme 83 | to phonemes. Use it for training a G2P system. 84 | skip_prep: bool 85 | If True, data preparation is skipped. 86 | 87 | Returns 88 | ------- 89 | None 90 | 91 | Example 92 | ------- 93 | >>> data_folder = 'datasets/LibriSpeech' 94 | >>> tr_splits = ['train-clean-100'] 95 | >>> dev_splits = ['dev-clean'] 96 | >>> te_splits = ['test-clean'] 97 | >>> save_folder = 'librispeech_prepared' 98 | >>> prepare_librispeech(data_folder, save_folder, tr_splits, dev_splits, te_splits) 99 | """ 100 | 101 | if skip_prep: 102 | return 103 | data_folder = data_folder 104 | splits = tr_splits + dev_splits + te_splits 105 | save_folder = save_folder 106 | select_n_sentences = select_n_sentences 107 | conf = { 108 | "select_n_sentences": select_n_sentences, 109 | } 110 | 111 | # Other variables 112 | # Saving folder 113 | if not os.path.exists(save_folder): 114 | os.makedirs(save_folder) 115 | 116 | save_opt = os.path.join(save_folder, OPT_FILE) 117 | 118 | # Check if this phase is already done (if so, skip it) 119 | if skip(splits, save_folder, conf): 120 | logger.info("Skipping preparation, completed in previous run.") 121 | return 122 | else: 123 | logger.info("Data_preparation...") 124 | 125 | # Additional checks to make sure the data folder contains Librispeech 126 | check_librispeech_folders(data_folder, splits) 127 | 128 | # create csv files for each split 129 | all_texts = {} 130 | for split_index in range(len(splits)): 131 | split = splits[split_index] 132 | 133 | wav_lst = get_all_files( 134 | os.path.join(data_folder, split), match_and=[".flac"] 135 | ) 136 | 137 | text_lst = get_all_files( 138 | os.path.join(data_folder, split), match_and=["trans.txt"] 139 | ) 140 | 141 | text_dict = text_to_dict(text_lst) 142 | all_texts.update(text_dict) 143 | 144 | if select_n_sentences is not None: 145 | n_sentences = select_n_sentences[split_index] 146 | else: 147 | n_sentences = len(wav_lst) 148 | 149 | create_csv(save_folder, wav_lst, text_dict, split, n_sentences) 150 | 151 | # Merging csv file if needed 152 | if merge_lst and merge_name is not None: 153 | merge_files = [split_libri + ".csv" for split_libri in merge_lst] 154 | merge_csvs( 155 | data_folder=save_folder, csv_lst=merge_files, merged_csv=merge_name 156 | ) 157 | 158 | # Create lexicon.csv and oov.csv 159 | if create_lexicon: 160 | create_lexicon_and_oov_csv(all_texts, save_folder) 161 | 162 | # saving options 163 | save_pkl(conf, save_opt) 164 | 165 | 166 | def create_lexicon_and_oov_csv(all_texts, save_folder): 167 | """ 168 | Creates lexicon csv files useful for training and testing a 169 | grapheme-to-phoneme (G2P) model. 170 | 171 | Arguments 172 | --------- 173 | all_texts : dict 174 | Dictionary containing text from the librispeech transcriptions 175 | save_folder : str 176 | The directory where to store the csv files. 177 | """ 178 | # If the lexicon file does not exist, download it 179 | lexicon_url = "http://www.openslr.org/resources/11/librispeech-lexicon.txt" 180 | lexicon_path = os.path.join(save_folder, "librispeech-lexicon.txt") 181 | 182 | if not os.path.isfile(lexicon_path): 183 | logger.info( 184 | "Lexicon file not found. Downloading from %s." 
% lexicon_url 185 | ) 186 | download_file(lexicon_url, lexicon_path) 187 | 188 | # Get list of all words in the transcripts 189 | transcript_words = Counter() 190 | for key in all_texts: 191 | transcript_words.update(all_texts[key].split("_")) 192 | 193 | # Get list of all words in the lexicon 194 | lexicon_words = [] 195 | lexicon_pronunciations = [] 196 | with open(lexicon_path, "r") as f: 197 | lines = f.readlines() 198 | for line in lines: 199 | word = line.split()[0] 200 | pronunciation = line.split()[1:] 201 | lexicon_words.append(word) 202 | lexicon_pronunciations.append(pronunciation) 203 | 204 | # Create lexicon.csv 205 | header = "ID,duration,char,phn\n" 206 | lexicon_csv_path = os.path.join(save_folder, "lexicon.csv") 207 | with open(lexicon_csv_path, "w") as f: 208 | f.write(header) 209 | for idx in range(len(lexicon_words)): 210 | separated_graphemes = [c for c in lexicon_words[idx]] 211 | duration = len(separated_graphemes) 212 | graphemes = " ".join(separated_graphemes) 213 | pronunciation_no_numbers = [ 214 | p.strip("0123456789") for p in lexicon_pronunciations[idx] 215 | ] 216 | phonemes = " ".join(pronunciation_no_numbers) 217 | line = ( 218 | ",".join([str(idx), str(duration), graphemes, phonemes]) + "\n" 219 | ) 220 | f.write(line) 221 | logger.info("Lexicon written to %s." % lexicon_csv_path) 222 | 223 | # Split lexicon.csv in train, validation, and test splits 224 | split_lexicon(save_folder, [98, 1, 1]) 225 | 226 | 227 | def split_lexicon(data_folder, split_ratio): 228 | """ 229 | Splits the lexicon.csv file into train, validation, and test csv files 230 | 231 | Arguments 232 | --------- 233 | data_folder : str 234 | Path to the folder containing the lexicon.csv file to split. 235 | split_ratio : list 236 | List containing the training, validation, and test split ratio. Set it 237 | to [80, 10, 10] for having 80% of material for training, 10% for valid, 238 | and 10 for test. 
239 | """ 240 | # Reading lexicon.csv 241 | lexicon_csv_path = os.path.join(data_folder, "lexicon.csv") 242 | with open(lexicon_csv_path, "r") as f: 243 | lexicon_lines = f.readlines() 244 | # Remove header 245 | lexicon_lines = lexicon_lines[1:] 246 | 247 | # Shuffle entries 248 | random.shuffle(lexicon_lines) 249 | 250 | # Selecting lines 251 | header = "ID,duration,char,phn\n" 252 | 253 | tr_snts = int(0.01 * split_ratio[0] * len(lexicon_lines)) 254 | train_lines = [header] + lexicon_lines[0:tr_snts] 255 | valid_snts = int(0.01 * split_ratio[1] * len(lexicon_lines)) 256 | valid_lines = [header] + lexicon_lines[tr_snts : tr_snts + valid_snts] 257 | test_lines = [header] + lexicon_lines[tr_snts + valid_snts :] 258 | 259 | # Saving files 260 | with open(os.path.join(data_folder, "lexicon_tr.csv"), "w") as f: 261 | f.writelines(train_lines) 262 | with open(os.path.join(data_folder, "lexicon_dev.csv"), "w") as f: 263 | f.writelines(valid_lines) 264 | with open(os.path.join(data_folder, "lexicon_test.csv"), "w") as f: 265 | f.writelines(test_lines) 266 | 267 | 268 | @dataclass 269 | class LSRow: 270 | snt_id: str 271 | spk_id: str 272 | duration: float 273 | file_path: str 274 | words: str 275 | 276 | 277 | def process_line(wav_file, text_dict) -> LSRow: 278 | snt_id = wav_file.split("/")[-1].replace(".flac", "") 279 | spk_id = "-".join(snt_id.split("-")[0:2]) 280 | wrds = text_dict[snt_id] 281 | wrds = " ".join(wrds.split("_")) 282 | 283 | info = read_audio_info(wav_file) 284 | duration = info.num_frames / info.sample_rate 285 | 286 | return LSRow( 287 | snt_id=snt_id, 288 | spk_id=spk_id, 289 | duration=duration, 290 | file_path=wav_file, 291 | words=wrds, 292 | ) 293 | 294 | 295 | def create_csv(save_folder, wav_lst, text_dict, split, select_n_sentences): 296 | """ 297 | Create the dataset csv file given a list of wav files. 298 | 299 | Arguments 300 | --------- 301 | save_folder : str 302 | Location of the folder for storing the csv. 303 | wav_lst : list 304 | The list of wav files of a given data split. 305 | text_dict : list 306 | The dictionary containing the text of each sentence. 307 | split : str 308 | The name of the current data split. 309 | select_n_sentences : int, optional 310 | The number of sentences to select. 311 | 312 | Returns 313 | ------- 314 | None 315 | """ 316 | # Setting path for the csv file 317 | csv_file = os.path.join(save_folder, split + ".csv") 318 | if os.path.exists(csv_file): 319 | logger.info("Csv file %s already exists, not recreating." % csv_file) 320 | return 321 | 322 | # Preliminary prints 323 | msg = "Creating csv lists in %s..." 
% (csv_file) 324 | logger.info(msg) 325 | 326 | csv_lines = [["ID", "duration", "wav", "spk_id", "wrd"]] 327 | 328 | snt_cnt = 0 329 | line_processor = functools.partial(process_line, text_dict=text_dict) 330 | # Processing all the wav files in wav_lst 331 | # FLAC metadata reading is already fast, so we set a high chunk size 332 | # to limit main thread CPU bottlenecks 333 | for row in parallel_map(line_processor, wav_lst, chunk_size=8192): 334 | csv_line = [ 335 | row.snt_id, 336 | str(row.duration), 337 | row.file_path, 338 | row.spk_id, 339 | row.words, 340 | ] 341 | 342 | # Appending current file to the csv_lines list 343 | csv_lines.append(csv_line) 344 | 345 | snt_cnt = snt_cnt + 1 346 | 347 | # parallel_map guarantees element ordering so we're OK 348 | if snt_cnt == select_n_sentences: 349 | break 350 | 351 | # Writing the csv_lines 352 | with open(csv_file, mode="w") as csv_f: 353 | csv_writer = csv.writer( 354 | csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL 355 | ) 356 | 357 | for line in csv_lines: 358 | csv_writer.writerow(line) 359 | 360 | # Final print 361 | msg = "%s successfully created!" % (csv_file) 362 | logger.info(msg) 363 | 364 | 365 | def skip(splits, save_folder, conf): 366 | """ 367 | Detect when the librispeech data prep can be skipped. 368 | 369 | Arguments 370 | --------- 371 | splits : list 372 | A list of the splits expected in the preparation. 373 | save_folder : str 374 | The location of the save directory 375 | conf : dict 376 | The configuration options to ensure they haven't changed. 377 | 378 | Returns 379 | ------- 380 | bool 381 | if True, the preparation phase can be skipped. 382 | if False, it must be done. 383 | """ 384 | 385 | # Checking csv files 386 | skip = True 387 | 388 | for split in splits: 389 | if not os.path.isfile(os.path.join(save_folder, split + ".csv")): 390 | skip = False 391 | 392 | # Checking saved options 393 | save_opt = os.path.join(save_folder, OPT_FILE) 394 | if skip is True: 395 | if os.path.isfile(save_opt): 396 | opts_old = load_pkl(save_opt) 397 | if opts_old == conf: 398 | skip = True 399 | else: 400 | skip = False 401 | else: 402 | skip = False 403 | 404 | return skip 405 | 406 | 407 | def text_to_dict(text_lst): 408 | """ 409 | This converts lines of text into a dictionary- 410 | 411 | Arguments 412 | --------- 413 | text_lst : str 414 | Path to the file containing the librispeech text transcription. 415 | 416 | Returns 417 | ------- 418 | dict 419 | The dictionary containing the text transcriptions for each sentence. 420 | 421 | """ 422 | # Initialization of the text dictionary 423 | text_dict = {} 424 | # Reading all the transcription files is text_lst 425 | for file in text_lst: 426 | with open(file, "r") as f: 427 | # Reading all line of the transcription file 428 | for line in f: 429 | line_lst = line.strip().split(" ") 430 | text_dict[line_lst[0]] = "_".join(line_lst[1:]) 431 | return text_dict 432 | 433 | 434 | def check_librispeech_folders(data_folder, splits): 435 | """ 436 | Check if the data folder actually contains the LibriSpeech dataset. 437 | 438 | If it does not, an error is raised. 439 | 440 | Arguments 441 | --------- 442 | data_folder : str 443 | The path to the directory with the data. 444 | splits : list 445 | The portions of the data to check. 446 | 447 | Raises 448 | ------ 449 | OSError 450 | If LibriSpeech is not found at the specified path. 
451 | """ 452 | # Checking if all the splits exist 453 | for split in splits: 454 | split_folder = os.path.join(data_folder, split) 455 | if not os.path.exists(split_folder): 456 | err_msg = ( 457 | "the folder %s does not exist (it is expected in the " 458 | "Librispeech dataset)" % split_folder 459 | ) 460 | raise OSError(err_msg) 461 | 462 | 463 | def download_librispeech_vocab_text(destination): 464 | """Download librispeech vocab file and unpack it. 465 | 466 | Arguments 467 | --------- 468 | destination : str 469 | Place to put vocab file. 470 | """ 471 | f = "librispeech-vocab.txt" 472 | download_file(OPEN_SLR_11_LINK + f, destination) 473 | 474 | 475 | def download_openslr_librispeech_lm(destination, rescoring_lm=True): 476 | """Download openslr librispeech lm and unpack it. 477 | 478 | Arguments 479 | --------- 480 | destination : str 481 | Place to put lm. 482 | rescoring_lm : bool 483 | Also download bigger 4grams model 484 | """ 485 | os.makedirs(destination, exist_ok=True) 486 | for f in OPEN_SLR_11_NGRAM_MODELs: 487 | if f.startswith("4") and not rescoring_lm: 488 | continue 489 | d = os.path.join(destination, f) 490 | download_file(OPEN_SLR_11_LINK + f, d, unpack=True) 491 | 492 | 493 | def download_sb_librispeech_lm(destination, rescoring_lm=True): 494 | """Download sb librispeech lm and unpack it. 495 | 496 | Arguments 497 | --------- 498 | destination : str 499 | Place to put lm. 500 | rescoring_lm : bool 501 | Also download bigger 4grams model 502 | """ 503 | os.makedirs(destination, exist_ok=True) 504 | download_file( 505 | "https://www.dropbox.com/scl/fi/3fkkdlliavhveb5n3nsow/3gram_lm.arpa?rlkey=jgdrluppfut1pjminf3l3y106&dl=1", 506 | os.path.join(destination, "3-gram_sb.arpa"), 507 | ) 508 | if rescoring_lm: 509 | download_file( 510 | "https://www.dropbox.com/scl/fi/roz46ee0ah2lvy5csno4z/4gram_lm.arpa?rlkey=2wt8ozb1mqgde9h9n9rp2yppz&dl=1", 511 | os.path.join(destination, "4-gram_sb.arpa"), 512 | ) 513 | -------------------------------------------------------------------------------- /train_S2S.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied and modified from 3 | https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/transformer/train.py 4 | ''' 5 | 6 | 7 | #!/usr/bin/env python3 8 | """Recipe for training a Transformer ASR system with librispeech. 9 | The system employs an encoder, a decoder, and an attention mechanism 10 | between them. Decoding is performed with (CTC/Att joint) beamsearch coupled with a neural 11 | language model. 12 | 13 | To run this recipe, do the following: 14 | > python train.py hparams/transformer.yaml 15 | > python train.py hparams/conformer.yaml 16 | 17 | With the default hyperparameters, the system employs a convolutional frontend and a transformer. 18 | The decoder is based on a Transformer decoder. Beamsearch coupled with a Transformer 19 | language model is used on the top of decoder probabilities. 20 | 21 | The neural network is trained on both CTC and negative-log likelihood 22 | targets and sub-word units estimated with Byte Pairwise Encoding (BPE) 23 | are used as basic recognition tokens. Training is performed on the full 24 | LibriSpeech dataset (960 h). 25 | 26 | The best model is the average of the checkpoints from last 5 epochs. 27 | 28 | The experiment file is flexible enough to support a large variety of 29 | different systems. 
By properly changing the parameter files, you can try 30 | different encoders, decoders, tokens (e.g, characters instead of BPE), 31 | training split (e.g, train-clean 100 rather than the full one), and many 32 | other possible variations. 33 | 34 | 35 | Authors 36 | * Jianyuan Zhong 2020 37 | * Mirco Ravanelli 2020 38 | * Peter Plantinga 2020 39 | * Samuele Cornell 2020, 2021, 2022 40 | * Titouan Parcollet 2021, 2022 41 | """ 42 | 43 | import os 44 | import sys 45 | import torch 46 | import logging 47 | from pathlib import Path 48 | import speechbrain as sb 49 | from hyperpyyaml import load_hyperpyyaml 50 | from speechbrain.utils.distributed import run_on_main, if_main_process 51 | 52 | logger = logging.getLogger(__name__) 53 | 54 | os.environ['WANDB__SERVICE_WAIT'] = '999999' 55 | 56 | # Define training procedure 57 | class ASR(sb.core.Brain): 58 | def compute_forward(self, batch, stage): 59 | """Forward computations from the waveform batches to the output probabilities.""" 60 | batch = batch.to(self.device) 61 | wavs, wav_lens = batch.sig # (B, N) 62 | tokens_bos, _ = batch.tokens_bos 63 | 64 | # compute features 65 | feats = self.hparams.compute_features(wavs) # (B, T, 80) 66 | current_epoch = self.hparams.epoch_counter.current 67 | feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch) 68 | 69 | # Add feature augmentation if specified. 70 | if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"): 71 | feats, fea_lens = self.hparams.fea_augment(feats, wav_lens) 72 | tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos) 73 | 74 | # forward modules 75 | src = self.modules.CNN(feats) # (B, L, 20, 32) -> (B, L, 640) 76 | 77 | enc_out, pred = self.modules.Transformer( 78 | src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index, 79 | ) 80 | 81 | # output layer for ctc log-probabilities 82 | logits = self.modules.ctc_lin(enc_out) 83 | p_ctc = self.hparams.log_softmax(logits) 84 | 85 | # output layer for seq2seq log-probabilities 86 | pred = self.modules.seq_lin(pred) 87 | p_seq = self.hparams.log_softmax(pred) 88 | 89 | # Compute outputs 90 | hyps = None 91 | current_epoch = self.hparams.epoch_counter.current 92 | is_valid_search = ( 93 | stage == sb.Stage.VALID 94 | and current_epoch % self.hparams.valid_search_interval == 0 95 | ) 96 | is_test_search = stage == sb.Stage.TEST 97 | 98 | if any([is_valid_search, is_test_search]): 99 | # Note: For valid_search, for the sake of efficiency, we only perform beamsearch with 100 | # limited capacity and no LM to give user some idea of how the AM is doing 101 | 102 | # Decide searcher for inference: valid or test search 103 | if stage == sb.Stage.VALID: 104 | hyps, _, _, _ = self.hparams.valid_search( 105 | enc_out.detach(), wav_lens 106 | ) 107 | else: 108 | hyps, _, _, _ = self.hparams.test_search( 109 | enc_out.detach(), wav_lens 110 | ) 111 | 112 | return p_ctc, p_seq, wav_lens, hyps 113 | 114 | def compute_objectives(self, predictions, batch, stage): 115 | """Computes the loss (CTC+NLL) given predictions and targets.""" 116 | 117 | (p_ctc, p_seq, wav_lens, hyps,) = predictions 118 | 119 | ids = batch.id 120 | tokens_eos, tokens_eos_lens = batch.tokens_eos 121 | tokens, tokens_lens = batch.tokens 122 | 123 | if stage == sb.Stage.TRAIN: 124 | if hasattr(self.hparams, "fea_augment"): 125 | tokens = self.hparams.fea_augment.replicate_labels(tokens) 126 | tokens_lens = self.hparams.fea_augment.replicate_labels( 127 | tokens_lens 128 | ) 129 | tokens_eos = self.hparams.fea_augment.replicate_labels( 130 | 
tokens_eos 131 | ) 132 | tokens_eos_lens = self.hparams.fea_augment.replicate_labels( 133 | tokens_eos_lens 134 | ) 135 | 136 | loss_seq = self.hparams.seq_cost( 137 | p_seq, tokens_eos, length=tokens_eos_lens 138 | ).sum() 139 | 140 | loss_ctc = self.hparams.ctc_cost( 141 | p_ctc, tokens, wav_lens, tokens_lens 142 | ).sum() 143 | 144 | loss = ( 145 | self.hparams.ctc_weight * loss_ctc 146 | + (1 - self.hparams.ctc_weight) * loss_seq 147 | ) 148 | 149 | if stage != sb.Stage.TRAIN: 150 | current_epoch = self.hparams.epoch_counter.current 151 | valid_search_interval = self.hparams.valid_search_interval 152 | if current_epoch % valid_search_interval == 0 or ( 153 | stage == sb.Stage.TEST 154 | ): 155 | # Decode token terms to words 156 | predicted_words = [ 157 | tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps 158 | ] 159 | target_words = [wrd.split(" ") for wrd in batch.wrd] 160 | self.wer_metric.append(ids, predicted_words, target_words) 161 | 162 | # compute the accuracy of the one-step-forward prediction 163 | self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens) 164 | return loss 165 | 166 | def on_evaluate_start(self, max_key=None, min_key=None): 167 | """perform checkpoint averge if needed""" 168 | super().on_evaluate_start() 169 | 170 | ckpts = self.checkpointer.find_checkpoints( 171 | max_key=max_key, min_key=min_key 172 | ) 173 | ckpt = sb.utils.checkpoints.average_checkpoints( 174 | ckpts, recoverable_name="model", 175 | ) 176 | 177 | self.hparams.model.load_state_dict(ckpt, strict=True) 178 | self.hparams.model.eval() 179 | print("Loaded the average") 180 | 181 | def on_stage_start(self, stage, epoch): 182 | """Gets called at the beginning of each epoch""" 183 | if stage != sb.Stage.TRAIN: 184 | self.acc_metric = self.hparams.acc_computer() 185 | self.wer_metric = self.hparams.error_rate_computer() 186 | 187 | def on_stage_end(self, stage, stage_loss, epoch): 188 | """Gets called at the end of a epoch.""" 189 | # Compute/store important stats 190 | stage_stats = {"loss": stage_loss} 191 | if stage == sb.Stage.TRAIN: 192 | self.train_stats = stage_stats 193 | else: 194 | stage_stats["ACC"] = self.acc_metric.summarize() 195 | current_epoch = self.hparams.epoch_counter.current 196 | valid_search_interval = self.hparams.valid_search_interval 197 | if ( 198 | current_epoch % valid_search_interval == 0 199 | or stage == sb.Stage.TEST 200 | ): 201 | stage_stats["WER"] = self.wer_metric.summarize("error_rate") 202 | 203 | # log stats and save checkpoint at end-of-epoch 204 | if stage == sb.Stage.VALID: 205 | 206 | lr = self.hparams.noam_annealing.current_lr 207 | steps = self.optimizer_step 208 | optimizer = self.optimizer.__class__.__name__ 209 | 210 | epoch_stats = { 211 | "epoch": epoch, 212 | "lr": lr, 213 | "steps": steps, 214 | "optimizer": optimizer, 215 | } 216 | self.hparams.train_logger.log_stats( 217 | stats_meta={"epoch": epoch, "lr": lr}, 218 | train_stats=self.train_stats, 219 | valid_stats=stage_stats, 220 | ) 221 | 222 | self.checkpointer.save_and_keep_only( 223 | meta={"ACC": stage_stats["ACC"], "epoch": epoch}, 224 | max_keys=["ACC"], 225 | num_to_keep=self.hparams.avg_checkpoints, 226 | ) 227 | 228 | elif stage == sb.Stage.TEST: 229 | self.hparams.train_logger.log_stats( 230 | stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, 231 | test_stats=stage_stats, 232 | ) 233 | if if_main_process(): 234 | with open(self.hparams.test_wer_file, "w") as w: 235 | self.wer_metric.write_stats(w) 236 | 237 | # save the averaged checkpoint at the end of 
the evaluation stage 238 | # delete the rest of the intermediate checkpoints 239 | # ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint 240 | self.checkpointer.save_and_keep_only( 241 | meta={"ACC": 1.1, "epoch": epoch}, 242 | max_keys=["ACC"], 243 | num_to_keep=1, 244 | ) 245 | 246 | def on_fit_batch_end(self, batch, outputs, loss, should_step): 247 | """At the end of the optimizer step, apply noam annealing.""" 248 | if should_step: 249 | self.hparams.noam_annealing(self.optimizer) 250 | 251 | 252 | def dataio_prepare(hparams): 253 | """This function prepares the datasets to be used in the brain class. 254 | It also defines the data processing pipeline through user-defined functions.""" 255 | data_folder = hparams["data_folder"] 256 | 257 | train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( 258 | csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, 259 | ) 260 | 261 | if hparams["sorting"] == "ascending": 262 | # we sort training data to speed up training and get better results. 263 | train_data = train_data.filtered_sorted(sort_key="duration") 264 | # when sorting do not shuffle in dataloader ! otherwise is pointless 265 | hparams["train_dataloader_opts"]["shuffle"] = False 266 | 267 | elif hparams["sorting"] == "descending": 268 | train_data = train_data.filtered_sorted( 269 | sort_key="duration", reverse=True 270 | ) 271 | # when sorting do not shuffle in dataloader ! otherwise is pointless 272 | hparams["train_dataloader_opts"]["shuffle"] = False 273 | 274 | elif hparams["sorting"] == "random": 275 | pass 276 | 277 | else: 278 | raise NotImplementedError( 279 | "sorting must be random, ascending or descending" 280 | ) 281 | valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( 282 | csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, 283 | ) 284 | valid_data = valid_data.filtered_sorted(sort_key="duration") 285 | 286 | # test is separate 287 | test_datasets = {} 288 | for csv_file in hparams["test_csv"]: 289 | name = Path(csv_file).stem 290 | test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( 291 | csv_path=csv_file, replacements={"data_root": data_folder} 292 | ) 293 | test_datasets[name] = test_datasets[name].filtered_sorted( 294 | sort_key="duration" 295 | ) 296 | 297 | datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] 298 | valtest_datasets = [valid_data] + [i for k, i in test_datasets.items()] 299 | 300 | # We get the tokenizer as we need it to encode the labels when creating 301 | # mini-batches. 302 | tokenizer = hparams["tokenizer"] 303 | 304 | # 2. Define audio pipeline: 305 | @sb.utils.data_pipeline.takes("wav") 306 | @sb.utils.data_pipeline.provides("sig") 307 | def audio_pipeline(wav): 308 | sig = sb.dataio.dataio.read_audio(wav) 309 | return sig 310 | 311 | sb.dataio.dataset.add_dynamic_item(valtest_datasets, audio_pipeline) 312 | 313 | @sb.utils.data_pipeline.takes("wav") 314 | @sb.utils.data_pipeline.provides("sig") 315 | def audio_pipeline_train(wav): 316 | # Speed Perturb is done here so it is multi-threaded with the 317 | # workers of the dataloader (faster). 318 | if "speed_perturb" in hparams: 319 | sig = sb.dataio.dataio.read_audio(wav) 320 | 321 | sig = hparams["speed_perturb"](sig.unsqueeze(0)).squeeze(0) 322 | else: 323 | sig = sb.dataio.dataio.read_audio(wav) 324 | return sig 325 | 326 | sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline_train) 327 | 328 | # 3. 
Define text pipeline: 329 | @sb.utils.data_pipeline.takes("wrd") 330 | @sb.utils.data_pipeline.provides( 331 | "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens" 332 | ) 333 | def text_pipeline(wrd): 334 | yield wrd 335 | tokens_list = tokenizer.encode_as_ids(wrd) 336 | yield tokens_list 337 | tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list)) 338 | yield tokens_bos 339 | tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]]) 340 | yield tokens_eos 341 | tokens = torch.LongTensor(tokens_list) 342 | yield tokens 343 | 344 | sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) 345 | 346 | # 4. Set output: 347 | sb.dataio.dataset.set_output_keys( 348 | datasets, ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"], 349 | ) 350 | 351 | # 5. If Dynamic Batching is used, we instantiate the needed samplers. 352 | train_batch_sampler = None 353 | valid_batch_sampler = None 354 | if hparams["dynamic_batching"]: 355 | from speechbrain.dataio.sampler import DynamicBatchSampler # noqa 356 | 357 | dynamic_hparams_train = hparams["dynamic_batch_sampler_train"] 358 | dynamic_hparams_valid = hparams["dynamic_batch_sampler_valid"] 359 | 360 | # print(dynamic_hparams_train) 361 | 362 | train_batch_sampler = DynamicBatchSampler( 363 | train_data, 364 | length_func=lambda x: x["duration"], 365 | **dynamic_hparams_train, 366 | ) 367 | valid_batch_sampler = DynamicBatchSampler( 368 | valid_data, 369 | length_func=lambda x: x["duration"], 370 | **dynamic_hparams_valid, 371 | ) 372 | 373 | return ( 374 | train_data, 375 | valid_data, 376 | test_datasets, 377 | tokenizer, 378 | train_batch_sampler, 379 | valid_batch_sampler, 380 | ) 381 | 382 | 383 | if __name__ == "__main__": 384 | # CLI: 385 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 386 | with open(hparams_file) as fin: 387 | hparams = load_hyperpyyaml(fin, overrides) 388 | 389 | # create ddp_group with the right communication protocol 390 | sb.utils.distributed.ddp_init_group(run_opts) 391 | 392 | # 1. # Dataset prep (parsing Librispeech) 393 | from librispeech_prepare import prepare_librispeech # noqa 394 | 395 | # Create experiment directory 396 | sb.create_experiment_directory( 397 | experiment_directory=hparams["output_folder"], 398 | hyperparams_to_save=hparams_file, 399 | overrides=overrides, 400 | ) 401 | 402 | # multi-gpu (ddp) save data preparation 403 | run_on_main( 404 | prepare_librispeech, 405 | kwargs={ 406 | "data_folder": hparams["data_folder"], 407 | "tr_splits": hparams["train_splits"], 408 | "dev_splits": hparams["dev_splits"], 409 | "te_splits": hparams["test_splits"], 410 | "save_folder": hparams["output_folder"], 411 | "merge_lst": hparams["train_splits"], 412 | "merge_name": "train.csv", 413 | "skip_prep": hparams["skip_prep"], 414 | }, 415 | ) 416 | 417 | # here we create the datasets objects as well as tokenization and encoding 418 | ( 419 | train_data, 420 | valid_data, 421 | test_datasets, 422 | tokenizer, 423 | train_bsampler, 424 | valid_bsampler, 425 | ) = dataio_prepare(hparams) 426 | 427 | # We download the pretrained LM from HuggingFace (or elsewhere depending on 428 | # the path given in the YAML file). The tokenizer is loaded at the same time. 
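# The "pretrainer" object is defined in the YAML file: collect_files fetches the
# listed artifacts (here, the pretrained language model and the tokenizer) on
# the main process only, and load_collected then loads them into the
# corresponding modules on every process.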
429 | run_on_main(hparams["pretrainer"].collect_files) 430 | hparams["pretrainer"].load_collected() 431 | 432 | # Init wandb 433 | if hparams['use_wandb']: 434 | hparams['train_logger'] = hparams['wandb_logger']() 435 | 436 | if hparams['no_lm']: 437 | print('Evaluate without LM.') 438 | hparams['test_search'] = hparams['valid_search'] 439 | hparams["output_wer_folder"] = os.path.join(hparams["output_wer_folder"], 'no_lm') 440 | 441 | # Trainer initialization 442 | asr_brain = ASR( 443 | modules=hparams["modules"], 444 | opt_class=hparams["Adam"], 445 | hparams=hparams, 446 | run_opts=run_opts, 447 | checkpointer=hparams["checkpointer"], 448 | ) 449 | 450 | # adding objects to trainer: 451 | asr_brain.tokenizer = hparams["tokenizer"] 452 | train_dataloader_opts = hparams["train_dataloader_opts"] 453 | valid_dataloader_opts = hparams["valid_dataloader_opts"] 454 | 455 | if train_bsampler is not None: 456 | collate_fn = None 457 | if "collate_fn" in train_dataloader_opts: 458 | collate_fn = train_dataloader_opts["collate_fn"] 459 | 460 | train_dataloader_opts = { 461 | "batch_sampler": train_bsampler, 462 | "num_workers": hparams["num_workers"], 463 | } 464 | 465 | if collate_fn is not None: 466 | train_dataloader_opts["collate_fn"] = collate_fn 467 | 468 | if valid_bsampler is not None: 469 | collate_fn = None 470 | if "collate_fn" in valid_dataloader_opts: 471 | collate_fn = valid_dataloader_opts["collate_fn"] 472 | 473 | valid_dataloader_opts = {"batch_sampler": valid_bsampler} 474 | 475 | if collate_fn is not None: 476 | valid_dataloader_opts["collate_fn"] = collate_fn 477 | 478 | if not hparams['skip_train']: 479 | # Training 480 | asr_brain.fit( 481 | asr_brain.hparams.epoch_counter, 482 | train_data, 483 | valid_data, 484 | train_loader_kwargs=train_dataloader_opts, 485 | valid_loader_kwargs=valid_dataloader_opts, 486 | ) 487 | 488 | # Testing 489 | if not os.path.exists(hparams["output_wer_folder"]): 490 | os.makedirs(hparams["output_wer_folder"]) 491 | 492 | for k in test_datasets.keys(): # keys are test_clean, test_other etc 493 | asr_brain.hparams.test_wer_file = os.path.join( 494 | hparams["output_wer_folder"], f"wer_{k}.txt" 495 | ) 496 | asr_brain.evaluate( 497 | test_datasets[k], 498 | max_key="ACC", 499 | test_loader_kwargs=hparams["test_dataloader_opts"], 500 | ) 501 | -------------------------------------------------------------------------------- /modules/mamba/bimamba.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied and modified from 3 | https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py 4 | ''' 5 | 6 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
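# This module defines a bidirectional variant of Mamba. In the "v2" scheme used
# here, the layer keeps two sets of SSM parameters: the forward branch (conv1d,
# x_proj, dt_proj, A_log, D) processes the sequence as-is, while the backward
# branch (conv1d_b, x_proj_b, dt_proj_b, A_b_log, D_b) processes the
# time-reversed sequence. The backward output is flipped back and the two
# streams are combined (averaged when if_devide_out is True) before the shared
# output projection out_proj.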
7 | 8 | import math 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch import Tensor 15 | 16 | from einops import rearrange, repeat 17 | 18 | try: 19 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 20 | except ImportError: 21 | causal_conv1d_fn, causal_conv1d_update = None 22 | 23 | try: 24 | from modules.mamba.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj 25 | except ImportError: 26 | selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None 27 | 28 | try: 29 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update 30 | except ImportError: 31 | selective_state_update = None 32 | 33 | try: 34 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 35 | except ImportError: 36 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 37 | 38 | 39 | class Mamba(nn.Module): 40 | def __init__( 41 | self, 42 | d_model, 43 | d_state=16, 44 | d_conv=4, 45 | expand=2, 46 | dt_rank="auto", 47 | dt_min=0.001, 48 | dt_max=0.1, 49 | dt_init="random", 50 | dt_scale=1.0, 51 | dt_init_floor=1e-4, 52 | conv_bias=True, 53 | bias=False, 54 | use_fast_path=True, # Fused kernel options 55 | layer_idx=None, 56 | device=None, 57 | dtype=None, 58 | bimamba_type="none", 59 | if_devide_out=True, # False 60 | init_layer_scale=None, 61 | ): 62 | factory_kwargs = {"device": device, "dtype": dtype} 63 | super().__init__() 64 | self.d_model = d_model 65 | self.d_state = d_state 66 | self.d_conv = d_conv 67 | self.expand = expand 68 | self.d_inner = int(self.expand * self.d_model) 69 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 70 | self.use_fast_path = use_fast_path 71 | self.layer_idx = layer_idx 72 | self.bimamba_type = bimamba_type 73 | self.if_devide_out = if_devide_out 74 | 75 | assert bimamba_type == 'v2' 76 | 77 | self.init_layer_scale = init_layer_scale 78 | if init_layer_scale is not None: 79 | self.gamma = nn.Parameter(init_layer_scale * torch.ones((d_model)), requires_grad=True) 80 | 81 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 82 | 83 | self.conv1d = nn.Conv1d( 84 | in_channels=self.d_inner, 85 | out_channels=self.d_inner, 86 | bias=conv_bias, 87 | kernel_size=d_conv, 88 | groups=self.d_inner, 89 | padding=d_conv - 1, 90 | **factory_kwargs, 91 | ) 92 | 93 | self.activation = "silu" 94 | self.act = nn.SiLU() 95 | 96 | self.x_proj = nn.Linear( 97 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 98 | ) 99 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 100 | 101 | # Initialize special dt projection to preserve variance at initialization 102 | dt_init_std = self.dt_rank**-0.5 * dt_scale 103 | if dt_init == "constant": 104 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 105 | elif dt_init == "random": 106 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 107 | else: 108 | raise NotImplementedError 109 | 110 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 111 | dt = torch.exp( 112 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 113 | + math.log(dt_min) 114 | ).clamp(min=dt_init_floor) 115 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 116 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 117 | 
with torch.no_grad(): 118 | self.dt_proj.bias.copy_(inv_dt) 119 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 120 | self.dt_proj.bias._no_reinit = True 121 | 122 | # S4D real initialization 123 | A = repeat( 124 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 125 | "n -> d n", 126 | d=self.d_inner, 127 | ).contiguous() 128 | A_log = torch.log(A) # Keep A_log in fp32 129 | self.A_log = nn.Parameter(A_log) 130 | self.A_log._no_weight_decay = True 131 | 132 | # D "skip" parameter 133 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 134 | self.D._no_weight_decay = True 135 | 136 | # bidirectional 137 | if bimamba_type == "v1": 138 | A_b = repeat( 139 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 140 | "n -> d n", 141 | d=self.d_inner, 142 | ).contiguous() 143 | A_b_log = torch.log(A_b) # Keep A_b_log in fp32 144 | self.A_b_log = nn.Parameter(A_b_log) 145 | self.A_b_log._no_weight_decay = True 146 | elif bimamba_type == "v2": 147 | A_b = repeat( 148 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 149 | "n -> d n", 150 | d=self.d_inner, 151 | ).contiguous() 152 | A_b_log = torch.log(A_b) # Keep A_b_log in fp32 153 | self.A_b_log = nn.Parameter(A_b_log) 154 | self.A_b_log._no_weight_decay = True 155 | 156 | self.conv1d_b = nn.Conv1d( 157 | in_channels=self.d_inner, 158 | out_channels=self.d_inner, 159 | bias=conv_bias, 160 | kernel_size=d_conv, 161 | groups=self.d_inner, 162 | padding=d_conv - 1, 163 | **factory_kwargs, 164 | ) 165 | 166 | self.x_proj_b = nn.Linear( 167 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 168 | ) 169 | self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 170 | 171 | self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 172 | self.D_b._no_weight_decay = True 173 | 174 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 175 | 176 | def forward(self, hidden_states, inference_params=None): 177 | """ 178 | hidden_states: (B, L, D) 179 | Returns: same shape as hidden_states 180 | """ 181 | batch, seqlen, dim = hidden_states.shape 182 | conv_state, ssm_state = None, None 183 | 184 | if inference_params is not None: 185 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 186 | if inference_params.seqlen_offset > 0: 187 | # The states are updated inplace 188 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 189 | return out 190 | 191 | # We do matmul and transpose BLH -> HBL at the same time 192 | xz = rearrange( 193 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 194 | "d (b l) -> b d l", 195 | l=seqlen, 196 | ) 197 | if self.in_proj.bias is not None: 198 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 199 | 200 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 201 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 202 | if self.use_fast_path and inference_params is None: # Doesn't support outputting the states 203 | if self.bimamba_type == "v1": 204 | A_b = -torch.exp(self.A_b_log.float()) 205 | out = bimamba_inner_fn( 206 | xz, 207 | self.conv1d.weight, 208 | self.conv1d.bias, 209 | self.x_proj.weight, 210 | self.dt_proj.weight, 211 | self.out_proj.weight, 212 | self.out_proj.bias, 213 | A, 214 | A_b, 215 | None, # input-dependent B 216 | None, # input-dependent C 217 | 
self.D.float(), 218 | delta_bias=self.dt_proj.bias.float(), 219 | delta_softplus=True, 220 | ) 221 | elif self.bimamba_type == "v2": 222 | A_b = -torch.exp(self.A_b_log.float()) 223 | out = mamba_inner_fn_no_out_proj( 224 | xz, 225 | self.conv1d.weight, 226 | self.conv1d.bias, 227 | self.x_proj.weight, 228 | self.dt_proj.weight, 229 | A, 230 | None, # input-dependent B 231 | None, # input-dependent C 232 | self.D.float(), 233 | delta_bias=self.dt_proj.bias.float(), 234 | delta_softplus=True, 235 | ) 236 | out_b = mamba_inner_fn_no_out_proj( 237 | xz.flip([-1]), 238 | self.conv1d_b.weight, 239 | self.conv1d_b.bias, 240 | self.x_proj_b.weight, 241 | self.dt_proj_b.weight, 242 | A_b, 243 | None, 244 | None, 245 | self.D_b.float(), 246 | delta_bias=self.dt_proj_b.bias.float(), 247 | delta_softplus=True, 248 | ) 249 | 250 | if not self.if_devide_out: 251 | out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) 252 | else: 253 | out = F.linear(rearrange(0.5*out + 0.5*out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) 254 | 255 | else: 256 | out = mamba_inner_fn( 257 | xz, 258 | self.conv1d.weight, 259 | self.conv1d.bias, 260 | self.x_proj.weight, 261 | self.dt_proj.weight, 262 | self.out_proj.weight, 263 | self.out_proj.bias, 264 | A, 265 | None, # input-dependent B 266 | None, # input-dependent C 267 | self.D.float(), 268 | delta_bias=self.dt_proj.bias.float(), 269 | delta_softplus=True, 270 | ) 271 | else: 272 | x, z = xz.chunk(2, dim=1) 273 | # Compute short convolution 274 | if conv_state is not None: 275 | # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv 276 | # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 277 | conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) 278 | if causal_conv1d_fn is None: 279 | x = self.act(self.conv1d(x)[..., :seqlen]) 280 | else: 281 | assert self.activation in ["silu", "swish"] 282 | x = causal_conv1d_fn( 283 | x=x, 284 | weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), 285 | bias=self.conv1d.bias, 286 | activation=self.activation, 287 | ) 288 | 289 | # We're careful here about the layout, to avoid extra transposes. 290 | # We want dt to have d as the slowest moving dimension 291 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
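# Reference (non-fused) path: x_proj produces per-step (dt, B, C); dt is mapped
# to d_inner dimensions by dt_proj and passed through softplus inside
# selective_scan_fn (via delta_bias / delta_softplus). The scan then runs the
# discretized SSM recurrence, roughly
#   h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
#   y_t = C_t . h_t + D * x_t,
# with the result gated by SiLU(z), mirroring the explicit step() method below.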
292 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 293 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 294 | dt = self.dt_proj.weight @ dt.t() 295 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 296 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 297 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 298 | assert self.activation in ["silu", "swish"] 299 | y = selective_scan_fn( 300 | x, 301 | dt, 302 | A, 303 | B, 304 | C, 305 | self.D.float(), 306 | z=z, 307 | delta_bias=self.dt_proj.bias.float(), 308 | delta_softplus=True, 309 | return_last_state=ssm_state is not None, 310 | ) 311 | if ssm_state is not None: 312 | y, last_state = y 313 | ssm_state.copy_(last_state) 314 | y = rearrange(y, "b d l -> b l d") 315 | out = self.out_proj(y) 316 | if self.init_layer_scale is not None: 317 | out = out * self.gamma 318 | return out 319 | 320 | def step(self, hidden_states, conv_state, ssm_state): 321 | dtype = hidden_states.dtype 322 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 323 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 324 | x, z = xz.chunk(2, dim=-1) # (B D) 325 | 326 | # Conv step 327 | if causal_conv1d_update is None: 328 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 329 | conv_state[:, :, -1] = x 330 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 331 | if self.conv1d.bias is not None: 332 | x = x + self.conv1d.bias 333 | x = self.act(x).to(dtype=dtype) 334 | else: 335 | x = causal_conv1d_update( 336 | x, 337 | conv_state, 338 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 339 | self.conv1d.bias, 340 | self.activation, 341 | ) 342 | 343 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 344 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 345 | # Don't add dt_bias here 346 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 347 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 348 | 349 | # SSM step 350 | if selective_state_update is None: 351 | # Discretize A and B 352 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 353 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 354 | dB = torch.einsum("bd,bn->bdn", dt, B) 355 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 356 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 357 | y = y + self.D.to(dtype) * x 358 | y = y * self.act(z) # (B D) 359 | else: 360 | y = selective_state_update( 361 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 362 | ) 363 | 364 | out = self.out_proj(y) 365 | return out.unsqueeze(1), conv_state, ssm_state 366 | 367 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 368 | device = self.out_proj.weight.device 369 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 370 | conv_state = torch.zeros( 371 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 372 | ) 373 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 374 | # ssm_dtype = torch.float32 375 | ssm_state = torch.zeros( 376 | batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 377 | ) 378 | return conv_state, ssm_state 379 | 380 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 381 | assert self.layer_idx is not None 382 | if 
self.layer_idx not in inference_params.key_value_memory_dict: 383 | batch_shape = (batch_size,) 384 | conv_state = torch.zeros( 385 | batch_size, 386 | self.d_model * self.expand, 387 | self.d_conv, 388 | device=self.conv1d.weight.device, 389 | dtype=self.conv1d.weight.dtype, 390 | ) 391 | ssm_state = torch.zeros( 392 | batch_size, 393 | self.d_model * self.expand, 394 | self.d_state, 395 | device=self.dt_proj.weight.device, 396 | dtype=self.dt_proj.weight.dtype, 397 | # dtype=torch.float32, 398 | ) 399 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 400 | else: 401 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 402 | # TODO: What if batch size changes between generation, and we reuse the same states? 403 | if initialize_states: 404 | conv_state.zero_() 405 | ssm_state.zero_() 406 | return conv_state, ssm_state 407 | 408 | 409 | class Block(nn.Module): 410 | def __init__( 411 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 412 | ): 413 | """ 414 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 415 | 416 | This Block has a slightly different structure compared to a regular 417 | prenorm Transformer block. 418 | The standard block is: LN -> MHA/MLP -> Add. 419 | [Ref: https://arxiv.org/abs/2002.04745] 420 | Here we have: Add -> LN -> Mixer, returning both 421 | the hidden_states (output of the mixer) and the residual. 422 | This is purely for performance reasons, as we can fuse add and LayerNorm. 423 | The residual needs to be provided (except for the very first block). 424 | """ 425 | super().__init__() 426 | self.residual_in_fp32 = residual_in_fp32 427 | self.fused_add_norm = fused_add_norm 428 | self.mixer = mixer_cls(dim) 429 | self.norm = norm_cls(dim) 430 | if self.fused_add_norm: 431 | assert RMSNorm is not None, "RMSNorm import fails" 432 | assert isinstance( 433 | self.norm, (nn.LayerNorm, RMSNorm) 434 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 435 | 436 | def forward( 437 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 438 | ): 439 | r"""Pass the input through the encoder layer. 440 | 441 | Args: 442 | hidden_states: the sequence to the encoder layer (required). 
443 | residual: hidden_states = Mixer(LN(residual)) 444 | """ 445 | if not self.fused_add_norm: 446 | residual = (hidden_states + residual) if residual is not None else hidden_states 447 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 448 | if self.residual_in_fp32: 449 | residual = residual.to(torch.float32) 450 | else: 451 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 452 | hidden_states, residual = fused_add_norm_fn( 453 | hidden_states, 454 | self.norm.weight, 455 | self.norm.bias, 456 | residual=residual, 457 | prenorm=True, 458 | residual_in_fp32=self.residual_in_fp32, 459 | eps=self.norm.eps, 460 | ) 461 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 462 | return hidden_states, residual 463 | 464 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 465 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 466 | -------------------------------------------------------------------------------- /modules/Conmamba.py: -------------------------------------------------------------------------------- 1 | """ConMamba encoder and Mamba decoder implementation. 2 | 3 | Authors 4 | ------- 5 | * Xilin Jiang 2024 6 | """ 7 | 8 | import warnings 9 | from dataclasses import dataclass 10 | from typing import List, Optional 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | import speechbrain as sb 17 | from speechbrain.nnet.activations import Swish 18 | from speechbrain.nnet.attention import ( 19 | MultiheadAttention, 20 | PositionalwiseFeedForward, 21 | RelPosMHAXL, 22 | ) 23 | from speechbrain.nnet.hypermixing import HyperMixing 24 | from speechbrain.nnet.normalization import LayerNorm 25 | from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig 26 | 27 | # Mamba 28 | from mamba_ssm import Mamba 29 | from modules.mamba.bimamba import Mamba as BiMamba 30 | 31 | 32 | class ConvolutionModule(nn.Module): 33 | """This is an implementation of convolution module in Conmamba. 
34 | """ 35 | 36 | def __init__( 37 | self, 38 | input_size, 39 | kernel_size=31, 40 | bias=True, 41 | activation=Swish, 42 | dropout=0.0, 43 | causal=False, 44 | dilation=1, 45 | ): 46 | super().__init__() 47 | 48 | self.kernel_size = kernel_size 49 | self.causal = causal 50 | self.dilation = dilation 51 | 52 | if self.causal: 53 | self.padding = (kernel_size - 1) * 2 ** (dilation - 1) 54 | else: 55 | self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2 56 | 57 | self.layer_norm = nn.LayerNorm(input_size) 58 | self.bottleneck = nn.Sequential( 59 | # pointwise 60 | nn.Conv1d( 61 | input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias 62 | ), 63 | nn.GLU(dim=1), 64 | ) 65 | # depthwise 66 | self.conv = nn.Conv1d( 67 | input_size, 68 | input_size, 69 | kernel_size=kernel_size, 70 | stride=1, 71 | padding=self.padding, 72 | dilation=dilation, 73 | groups=input_size, 74 | bias=bias, 75 | ) 76 | 77 | # BatchNorm in the original Conformer replaced with a LayerNorm due to 78 | # https://github.com/speechbrain/speechbrain/pull/1329 79 | # see discussion 80 | # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884 81 | 82 | self.after_conv = nn.Sequential( 83 | nn.LayerNorm(input_size), 84 | activation(), 85 | # pointwise 86 | nn.Linear(input_size, input_size, bias=bias), 87 | nn.Dropout(dropout), 88 | ) 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask: Optional[torch.Tensor] = None, 94 | dynchunktrain_config: Optional[DynChunkTrainConfig] = None, 95 | ): 96 | """Applies the convolution to an input tensor `x`. 97 | """ 98 | 99 | if dynchunktrain_config is not None: 100 | # chances are chunking+causal is unintended; i don't know where it 101 | # may make sense, but if it does to you, feel free to implement it. 102 | assert ( 103 | not self.causal 104 | ), "Chunked convolution not supported with causal padding" 105 | 106 | assert ( 107 | self.dilation == 1 108 | ), "Current DynChunkTrain logic does not support dilation != 1" 109 | 110 | # in a causal convolution, which is not the case here, an output 111 | # frame would never be able to depend on a input frame from any 112 | # point in the future. 113 | 114 | # but with the dynamic chunk convolution, we instead use a "normal" 115 | # convolution but where, for any output frame, the future beyond the 116 | # "current" chunk gets masked. 117 | # see the paper linked in the documentation for details. 118 | 119 | chunk_size = dynchunktrain_config.chunk_size 120 | batch_size = x.shape[0] 121 | 122 | # determine the amount of padding we need to insert at the right of 123 | # the last chunk so that all chunks end up with the same size. 124 | if x.shape[1] % chunk_size != 0: 125 | final_right_padding = chunk_size - (x.shape[1] % chunk_size) 126 | else: 127 | final_right_padding = 0 128 | 129 | # -> [batch_size, t, in_channels] 130 | out = self.layer_norm(x) 131 | 132 | # -> [batch_size, in_channels, t] for the CNN 133 | out = out.transpose(1, 2) 134 | 135 | # -> [batch_size, in_channels, t] (pointwise) 136 | out = self.bottleneck(out) 137 | 138 | # -> [batch_size, in_channels, lc+t+final_right_padding] 139 | out = F.pad(out, (self.padding, final_right_padding), value=0) 140 | 141 | # now, make chunks with left context. 142 | # as a recap to what the above padding and this unfold do, consider 143 | # each a/b/c letter represents a frame as part of chunks a, b, c. 
144 | # consider a chunk size of 4 and a kernel size of 5 (padding=2): 145 | # 146 | # input seq: 00aaaabbbbcc00 147 | # chunk #1: 00aaaa 148 | # chunk #2: aabbbb 149 | # chunk #3: bbcc00 150 | # 151 | # a few remarks here: 152 | # - the left padding gets inserted early so that the unfold logic 153 | # works trivially 154 | # - the right 0-padding got inserted as the number of time steps 155 | # could not be evenly split in `chunk_size` chunks 156 | 157 | # -> [batch_size, in_channels, num_chunks, lc+chunk_size] 158 | out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size) 159 | 160 | # as we manually disable padding in the convolution below, we insert 161 | # right 0-padding to the chunks, e.g. reusing the above example: 162 | # 163 | # chunk #1: 00aaaa00 164 | # chunk #2: aabbbb00 165 | # chunk #3: bbcc0000 166 | 167 | # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad] 168 | out = F.pad(out, (0, self.padding), value=0) 169 | 170 | # the transpose+flatten effectively flattens chunks into the batch 171 | # dimension to be processed into the time-wise convolution. the 172 | # chunks will later on be unflattened. 173 | 174 | # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad] 175 | out = out.transpose(1, 2) 176 | 177 | # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad] 178 | out = out.flatten(start_dim=0, end_dim=1) 179 | 180 | # TODO: experiment around reflect padding, which is difficult 181 | # because small chunks have too little time steps to reflect from 182 | 183 | # let's keep backwards compat by pointing at the weights from the 184 | # already declared Conv1d. 185 | # 186 | # still reusing the above example, the convolution will be applied, 187 | # with the padding truncated on both ends. the following example 188 | # shows the letter corresponding to the input frame on which the 189 | # convolution was centered. 190 | # 191 | # as you can see, the sum of lengths of all chunks is equal to our 192 | # input sequence length + `final_right_padding`. 193 | # 194 | # chunk #1: aaaa 195 | # chunk #2: bbbb 196 | # chunk #3: cc00 197 | 198 | # -> [batch_size * num_chunks, out_channels, chunk_size] 199 | out = F.conv1d( 200 | out, 201 | weight=self.conv.weight, 202 | bias=self.conv.bias, 203 | stride=self.conv.stride, 204 | padding=0, 205 | dilation=self.conv.dilation, 206 | groups=self.conv.groups, 207 | ) 208 | 209 | # -> [batch_size * num_chunks, chunk_size, out_channels] 210 | out = out.transpose(1, 2) 211 | 212 | out = self.after_conv(out) 213 | 214 | # -> [batch_size, num_chunks, chunk_size, out_channels] 215 | out = torch.unflatten(out, dim=0, sizes=(batch_size, -1)) 216 | 217 | # -> [batch_size, t + final_right_padding, out_channels] 218 | out = torch.flatten(out, start_dim=1, end_dim=2) 219 | 220 | # -> [batch_size, t, out_channels] 221 | if final_right_padding > 0: 222 | out = out[:, :-final_right_padding, :] 223 | else: 224 | out = self.layer_norm(x) 225 | out = out.transpose(1, 2) 226 | out = self.bottleneck(out) 227 | out = self.conv(out) 228 | 229 | if self.causal: 230 | # chomp 231 | out = out[..., : -self.padding] 232 | 233 | out = out.transpose(1, 2) 234 | out = self.after_conv(out) 235 | 236 | if mask is not None: 237 | out.masked_fill_(mask, 0.0) 238 | 239 | return out 240 | 241 | 242 | class ConmambaEncoderLayer(nn.Module): 243 | """This is an implementation of Conmamba encoder layer. 
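
    The layer follows the Conformer Macaron layout with Mamba replacing
    self-attention: a half-step feed-forward module, a (bidirectional) Mamba
    block with a pre-norm residual, the convolution module with a residual,
    and a second half-step feed-forward module followed by a final LayerNorm.

    Arguments
    ----------
    d_model : int
        Embedding dimension of the input.
    d_ffn : int
        Hidden size of the two feed-forward modules.
    kernel_size : int, optional
        Kernel size of the convolution module (default 31).
    activation : torch.nn.Module, optional
        Activation used in the feed-forward and convolution modules.
    bias : bool, optional
        Whether linear and convolution layers use a bias term.
    dropout : float, optional
        Dropout rate for the feed-forward and convolution modules.
    causal : bool, optional
        If True, a unidirectional Mamba and a causal convolution are used.
    mamba_config : dict
        Keyword arguments forwarded to Mamba/BiMamba. Must contain a
        'bidirectional' key; when it is True (and causal is False), the
        bidirectional BiMamba ('v2') is used.

    The sketch below is illustrative only: it assumes a CUDA device with
    mamba-ssm installed, and the mamba_config values are placeholders rather
    than the settings from the YAML files under hparams/.

        mamba_cfg = {'d_state': 16, 'd_conv': 4, 'expand': 2,
                     'bidirectional': True}
        layer = ConmambaEncoderLayer(256, 1024, mamba_config=mamba_cfg).cuda()
        x = torch.rand(8, 60, 256, device='cuda')
        y = layer(x)  # -> [8, 60, 256]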
244 | """ 245 | 246 | def __init__( 247 | self, 248 | d_model, 249 | d_ffn, 250 | kernel_size=31, 251 | activation=Swish, 252 | bias=True, 253 | dropout=0.0, 254 | causal=False, 255 | mamba_config=None 256 | ): 257 | super().__init__() 258 | assert mamba_config != None 259 | 260 | bidirectional = mamba_config.pop('bidirectional') 261 | if causal or (not bidirectional): 262 | self.mamba = Mamba( 263 | d_model=d_model, 264 | **mamba_config 265 | ) 266 | else: 267 | self.mamba = BiMamba( 268 | d_model=d_model, 269 | bimamba_type='v2', 270 | **mamba_config 271 | ) 272 | mamba_config['bidirectional'] = bidirectional 273 | 274 | self.convolution_module = ConvolutionModule( 275 | d_model, kernel_size, bias, activation, dropout, causal=causal 276 | ) 277 | 278 | self.ffn_module1 = nn.Sequential( 279 | nn.LayerNorm(d_model), 280 | PositionalwiseFeedForward( 281 | d_ffn=d_ffn, 282 | input_size=d_model, 283 | dropout=dropout, 284 | activation=activation, 285 | ), 286 | nn.Dropout(dropout), 287 | ) 288 | 289 | self.ffn_module2 = nn.Sequential( 290 | nn.LayerNorm(d_model), 291 | PositionalwiseFeedForward( 292 | d_ffn=d_ffn, 293 | input_size=d_model, 294 | dropout=dropout, 295 | activation=activation, 296 | ), 297 | nn.Dropout(dropout), 298 | ) 299 | 300 | self.norm1 = LayerNorm(d_model) 301 | self.norm2 = LayerNorm(d_model) 302 | self.drop = nn.Dropout(dropout) 303 | 304 | def forward( 305 | self, 306 | x, 307 | src_mask: Optional[torch.Tensor] = None, 308 | src_key_padding_mask: Optional[torch.Tensor] = None, 309 | pos_embs: torch.Tensor = None, 310 | dynchunktrain_config: Optional[DynChunkTrainConfig] = None, 311 | ): 312 | conv_mask: Optional[torch.Tensor] = None 313 | if src_key_padding_mask is not None: 314 | conv_mask = src_key_padding_mask.unsqueeze(-1) 315 | 316 | conv_mask = None 317 | 318 | # ffn module 319 | x = x + 0.5 * self.ffn_module1(x) 320 | # mamba module 321 | skip = x 322 | x = self.norm1(x) 323 | x = self.mamba(x) 324 | x = x + skip 325 | # convolution module 326 | x = x + self.convolution_module( 327 | x, conv_mask, dynchunktrain_config=dynchunktrain_config 328 | ) 329 | # ffn module 330 | x = self.norm2(x + 0.5 * self.ffn_module2(x)) 331 | return x 332 | 333 | 334 | class ConmambaEncoder(nn.Module): 335 | """This class implements the Conmamba encoder. 336 | """ 337 | 338 | def __init__( 339 | self, 340 | num_layers, 341 | d_model, 342 | d_ffn, 343 | kernel_size=31, 344 | activation=Swish, 345 | bias=True, 346 | dropout=0.0, 347 | causal=False, 348 | mamba_config=None 349 | ): 350 | super().__init__() 351 | print(f'dropout={str(dropout)} is not used in Mamba.') 352 | 353 | self.layers = torch.nn.ModuleList( 354 | [ 355 | ConmambaEncoderLayer( 356 | d_model=d_model, 357 | d_ffn=d_ffn, 358 | dropout=dropout, 359 | activation=activation, 360 | kernel_size=kernel_size, 361 | bias=bias, 362 | causal=causal, 363 | mamba_config=mamba_config, 364 | ) 365 | for i in range(num_layers) 366 | ] 367 | ) 368 | self.norm = LayerNorm(d_model, eps=1e-6) 369 | 370 | def forward( 371 | self, 372 | src, 373 | src_mask: Optional[torch.Tensor] = None, 374 | src_key_padding_mask: Optional[torch.Tensor] = None, 375 | pos_embs: Optional[torch.Tensor] = None, 376 | dynchunktrain_config: Optional[DynChunkTrainConfig] = None, 377 | ): 378 | """ 379 | Arguments 380 | ---------- 381 | src : torch.Tensor 382 | The sequence to the encoder layer. 383 | src_mask : torch.Tensor, optional 384 | The mask for the src sequence. 
385 | src_key_padding_mask : torch.Tensor, optional 386 | The mask for the src keys per batch. 387 | pos_embs: torch.Tensor, torch.nn.Module, 388 | Module or tensor containing the input sequence positional embeddings 389 | If custom pos_embs are given it needs to have the shape (1, 2*S-1, E) 390 | where S is the sequence length, and E is the embedding dimension. 391 | dynchunktrain_config: Optional[DynChunkTrainConfig] 392 | Dynamic Chunk Training configuration object for streaming, 393 | specifically involved here to apply Dynamic Chunk Convolution to the 394 | convolution module. 395 | """ 396 | 397 | output = src 398 | for enc_layer in self.layers: 399 | output = enc_layer( 400 | output, 401 | src_mask=src_mask, 402 | src_key_padding_mask=src_key_padding_mask, 403 | pos_embs=pos_embs, 404 | dynchunktrain_config=dynchunktrain_config, 405 | ) 406 | output = self.norm(output) 407 | 408 | return output, None 409 | 410 | 411 | class MambaDecoderLayer(nn.Module): 412 | """This class implements the Mamba decoder layer. 413 | """ 414 | 415 | def __init__( 416 | self, 417 | d_model, 418 | d_ffn, 419 | activation=nn.ReLU, 420 | dropout=0.0, 421 | normalize_before=False, 422 | mamba_config=None 423 | ): 424 | super().__init__() 425 | 426 | assert mamba_config != None 427 | 428 | bidirectional = mamba_config.pop('bidirectional') 429 | 430 | self.self_mamba = Mamba( 431 | d_model=d_model, 432 | **mamba_config 433 | ) 434 | 435 | self.cross_mamba = Mamba( 436 | d_model=d_model, 437 | **mamba_config 438 | ) 439 | 440 | mamba_config['bidirectional'] = bidirectional 441 | 442 | self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward( 443 | d_ffn=d_ffn, 444 | input_size=d_model, 445 | dropout=dropout, 446 | activation=activation, 447 | ) 448 | 449 | # normalization layers 450 | self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) 451 | self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) 452 | self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) 453 | self.dropout1 = torch.nn.Dropout(dropout) 454 | self.dropout2 = torch.nn.Dropout(dropout) 455 | self.dropout3 = torch.nn.Dropout(dropout) 456 | 457 | self.normalize_before = normalize_before 458 | 459 | def forward( 460 | self, 461 | tgt, 462 | memory, 463 | tgt_mask=None, 464 | memory_mask=None, 465 | tgt_key_padding_mask=None, 466 | memory_key_padding_mask=None, 467 | pos_embs_tgt=None, 468 | pos_embs_src=None, 469 | ): 470 | """ 471 | Arguments 472 | ---------- 473 | tgt: torch.Tensor 474 | The sequence to the decoder layer (required). 475 | memory: torch.Tensor 476 | The sequence from the last layer of the encoder (required). 477 | tgt_mask: torch.Tensor 478 | The mask for the tgt sequence (optional). 479 | memory_mask: torch.Tensor 480 | The mask for the memory sequence (optional). 481 | tgt_key_padding_mask: torch.Tensor 482 | The mask for the tgt keys per batch (optional). 483 | memory_key_padding_mask: torch.Tensor 484 | The mask for the memory keys per batch (optional). 485 | pos_embs_tgt: torch.Tensor 486 | The positional embeddings for the target (optional). 487 | pos_embs_src: torch.Tensor 488 | The positional embeddings for the source (optional). 
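
        Returns
        -------
        tgt : torch.Tensor
            The transformed target sequence.
        None, None
            Placeholders where a Transformer decoder layer would return
            self- and cross-attention weights.

        Note
        ----
        The mask and positional-embedding arguments are accepted for
        interface compatibility with SpeechBrain's Transformer decoder but
        are not used inside the Mamba layers. Cross-attention is replaced by
        running `self.cross_mamba` over the concatenation of `memory` and the
        (normalized) target along the time axis and keeping only the last
        `tgt.shape[1]` frames: for a memory of length S and a target of
        length T, the Mamba input has length S + T and the output is sliced
        with `[:, -T:]`.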
489 | """ 490 | if self.normalize_before: 491 | tgt1 = self.norm1(tgt) 492 | else: 493 | tgt1 = tgt 494 | 495 | # Mamba over the target sequence 496 | tgt2 = self.self_mamba(tgt1) 497 | 498 | # add & norm 499 | tgt = tgt + self.dropout1(tgt2) 500 | if not self.normalize_before: 501 | tgt = self.norm1(tgt) 502 | 503 | if self.normalize_before: 504 | tgt1 = self.norm2(tgt) 505 | else: 506 | tgt1 = tgt 507 | 508 | # Mamba over key=value + query 509 | # and only take the last len(query) tokens 510 | tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[:, -tgt1.shape[1]:] 511 | 512 | # add & norm 513 | tgt = tgt + self.dropout2(tgt2) 514 | if not self.normalize_before: 515 | tgt = self.norm2(tgt) 516 | 517 | if self.normalize_before: 518 | tgt1 = self.norm3(tgt) 519 | else: 520 | tgt1 = tgt 521 | 522 | tgt2 = self.pos_ffn(tgt1) 523 | 524 | # add & norm 525 | tgt = tgt + self.dropout3(tgt2) 526 | if not self.normalize_before: 527 | tgt = self.norm3(tgt) 528 | 529 | return tgt, None, None 530 | 531 | 532 | class MambaDecoder(nn.Module): 533 | """This class implements the Mamba decoder. 534 | """ 535 | 536 | def __init__( 537 | self, 538 | num_layers, 539 | d_model, 540 | d_ffn, 541 | activation=nn.ReLU, 542 | dropout=0.0, 543 | normalize_before=False, 544 | mamba_config=None 545 | ): 546 | super().__init__() 547 | self.layers = torch.nn.ModuleList( 548 | [ 549 | MambaDecoderLayer( 550 | d_model=d_model, 551 | d_ffn=d_ffn, 552 | activation=activation, 553 | dropout=dropout, 554 | normalize_before=normalize_before, 555 | mamba_config=mamba_config 556 | ) 557 | for _ in range(num_layers) 558 | ] 559 | ) 560 | self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) 561 | 562 | def forward( 563 | self, 564 | tgt, 565 | memory, 566 | tgt_mask=None, 567 | memory_mask=None, 568 | tgt_key_padding_mask=None, 569 | memory_key_padding_mask=None, 570 | pos_embs_tgt=None, 571 | pos_embs_src=None, 572 | ): 573 | """ 574 | Arguments 575 | ---------- 576 | tgt : torch.Tensor 577 | The sequence to the decoder layer (required). 578 | memory : torch.Tensor 579 | The sequence from the last layer of the encoder (required). 580 | tgt_mask : torch.Tensor 581 | The mask for the tgt sequence (optional). 582 | memory_mask : torch.Tensor 583 | The mask for the memory sequence (optional). 584 | tgt_key_padding_mask : torch.Tensor 585 | The mask for the tgt keys per batch (optional). 586 | memory_key_padding_mask : torch.Tensor 587 | The mask for the memory keys per batch (optional). 588 | pos_embs_tgt : torch.Tensor 589 | The positional embeddings for the target (optional). 590 | pos_embs_src : torch.Tensor 591 | The positional embeddings for the source (optional). 592 | """ 593 | output = tgt 594 | for dec_layer in self.layers: 595 | output, _, _ = dec_layer( 596 | output, 597 | memory, 598 | tgt_mask=tgt_mask, 599 | memory_mask=memory_mask, 600 | tgt_key_padding_mask=tgt_key_padding_mask, 601 | memory_key_padding_mask=memory_key_padding_mask, 602 | pos_embs_tgt=pos_embs_tgt, 603 | pos_embs_src=pos_embs_src, 604 | ) 605 | output = self.norm(output) 606 | 607 | return output, [None], [None] 608 | --------------------------------------------------------------------------------
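For orientation, the snippet below is a minimal usage sketch of the two modules defined in `modules/Conmamba.py`. It assumes a CUDA device with `mamba-ssm` and `causal-conv1d` installed; the layer counts, dimensions, and `mamba_config` values are illustrative placeholders, not the settings from the YAML files under `hparams/`, which the training recipes use to configure these modules.

```python
import torch
from modules.Conmamba import ConmambaEncoder, MambaDecoder

# Illustrative hyperparameters only; see hparams/ for the real configurations.
mamba_cfg = {'d_state': 16, 'd_conv': 4, 'expand': 2, 'bidirectional': True}

# Bidirectional ConMamba encoder over acoustic frames.
encoder = ConmambaEncoder(
    num_layers=12, d_model=256, d_ffn=1024, mamba_config=dict(mamba_cfg)
).cuda()

# Unidirectional Mamba decoder over embedded target tokens.
decoder = MambaDecoder(
    num_layers=4, d_model=256, d_ffn=1024,
    mamba_config={**mamba_cfg, 'bidirectional': False},
).cuda()

feats = torch.rand(2, 200, 256, device='cuda')   # [batch, frames, d_model]
tokens = torch.rand(2, 30, 256, device='cuda')   # [batch, tokens, d_model]

enc_out, _ = encoder(feats)                      # [2, 200, 256]
dec_out, _, _ = decoder(tokens, enc_out)         # [2, 30, 256]
```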