├── espnet2
│   ├── asr
│   │   ├── mamba_ssm
│   │   │   ├── ops
│   │   │   │   ├── __init__.py
│   │   │   │   ├── triton
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── __pycache__
│   │   │   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   │   │   ├── layernorm.cpython-39.pyc
│   │   │   │   │   │   └── selective_state_update.cpython-39.pyc
│   │   │   │   │   └── selective_state_update.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   │   └── selective_scan_interface.cpython-39.pyc
│   │   │   │   └── selective_scan_interface.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   └── mamba_simple.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hf.py
│   │   │   │   └── generation.py
│   │   │   └── __init__.py
│   │   ├── uma.py
│   │   ├── decoder
│   │   │   └── unimodal_attention_decoder.py
│   │   └── encoder
│   │       ├── mamba_encoder.py
│   │       └── conformer_encoder.py
│   └── bin
│       └── asr_unimodal_train.py
├── uma.png
├── mamba_uma.png
├── egs2
│   ├── aishell
│   │   ├── umaconf
│   │   │   ├── decode_asr_uma.yaml
│   │   │   ├── train_asr_uma_mamba.yaml
│   │   │   ├── train_asr_uma_conformer.yaml
│   │   │   └── train_asr_uma_conformer_condition.yaml
│   │   ├── exp_uma_mamba_0617
│   │   │   └── asr_train_asr_uma_mamba_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── text_vs_uma.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   ├── uma_reduction.png
│   │   │       │   └── gpu_max_cached_mem_GB.png
│   │   │       └── RESULTS.md
│   │   ├── exp_uma_conformer_12e_69
│   │   │   └── asr_train_asr_unimodal_conformer_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   └── gpu_max_cached_mem_GB.png
│   │   │       └── RESULTS.md
│   │   ├── exp_uma_conformer_condition0302_32_731
│   │   │   └── asr_train_asr_uma_conformer_condition_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   ├── gpu_max_cached_mem_GB.png
│   │   │       │   ├── cer_interctc_declayer2.png
│   │   │       │   ├── cer_interctc_declayer4.png
│   │   │       │   ├── cer_interctc_enclayer12.png
│   │   │       │   ├── cer_interctc_enclayer6.png
│   │   │       │   ├── cer_interctc_enclayer9.png
│   │   │       │   ├── loss_interctc_declayer2.png
│   │   │       │   ├── loss_interctc_declayer4.png
│   │   │       │   ├── loss_interctc_enclayer6.png
│   │   │       │   ├── loss_interctc_enclayer9.png
│   │   │       │   └── loss_interctc_enclayer12.png
│   │   │       └── RESULTS.md
│   │   └── run_unimodal.sh
│   ├── hkust
│   │   ├── umaconf
│   │   │   ├── decode_asr_uma.yaml
│   │   │   ├── train_asr_uma_conformer.yaml
│   │   │   ├── train_asr_uma_branchformer.yaml
│   │   │   ├── train_asr_uma_conformer_condition.yaml
│   │   │   └── train_asr_uma_branchformer_condition.yaml
│   │   ├── exp_uma_conformer_12e_67
│   │   │   └── asr_train_asr_uma_conformer_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   └── gpu_max_cached_mem_GB.png
│   │   │       └── RESULTS.md
│   │   ├── exp_uma_branchformer_12e_69
│   │   │   └── asr_train_asr_uma_branchformer_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   └── gpu_max_cached_mem_GB.png
│   │   │       └── RESULTS.md
│   │   ├── exp_uma_conformer_condition0302_32_712
│   │   │   └── asr_train_asr_uma_conformer_condition_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   ├── cer_interctc_declayer2.png
│   │   │       │   ├── cer_interctc_declayer4.png
│   │   │       │   ├── cer_interctc_enclayer12.png
│   │   │       │   ├── cer_interctc_enclayer6.png
│   │   │       │   ├── cer_interctc_enclayer9.png
│   │   │       │   ├── gpu_max_cached_mem_GB.png
│   │   │       │   ├── loss_interctc_declayer2.png
│   │   │       │   ├── loss_interctc_declayer4.png
│   │   │       │   ├── loss_interctc_enclayer6.png
│   │   │       │   ├── loss_interctc_enclayer9.png
│   │   │       │   └── loss_interctc_enclayer12.png
│   │   │       └── RESULTS.md
│   │   ├── exp_uma_branchformer_condition0302_32_711
│   │   │   └── asr_train_asr_uma_branchformer_condition_raw_zh_char_sp
│   │   │       ├── images
│   │   │       │   ├── cer.png
│   │   │       │   ├── loss.png
│   │   │       │   ├── cer_ctc.png
│   │   │       │   ├── iter_time.png
│   │   │       │   ├── loss_ctc.png
│   │   │       │   ├── forward_time.png
│   │   │       │   ├── optim0_lr0.png
│   │   │       │   ├── train_time.png
│   │   │       │   ├── backward_time.png
│   │   │       │   ├── optim_step_time.png
│   │   │       │   ├── cer_interctc_declayer2.png
│   │   │       │   ├── cer_interctc_declayer4.png
│   │   │       │   ├── cer_interctc_enclayer6.png
│   │   │       │   ├── cer_interctc_enclayer9.png
│   │   │       │   ├── gpu_max_cached_mem_GB.png
│   │   │       │   ├── cer_interctc_enclayer12.png
│   │   │       │   ├── loss_interctc_declayer2.png
│   │   │       │   ├── loss_interctc_declayer4.png
│   │   │       │   ├── loss_interctc_enclayer12.png
│   │   │       │   ├── loss_interctc_enclayer6.png
│   │   │       │   └── loss_interctc_enclayer9.png
│   │   │       └── RESULTS.md
│   │   └── run_unimodal.sh
│   └── aishell2
│       ├── umaconf
│       │   ├── decode_asr_uma.yaml
│       │   ├── train_asr_uma_mamba_b.yaml
│       │   ├── train_asr_uma_conformer.yaml
│       │   └── train_asr_uma_conformer_condition.yaml
│       ├── exp_uma_mamba_0819
│       │   └── asr_train_asr_uma_mamba_b_raw_zh_char_sp
│       │       ├── images
│       │       │   ├── cer.png
│       │       │   ├── loss.png
│       │       │   ├── cer_ctc.png
│       │       │   ├── iter_time.png
│       │       │   ├── loss_ctc.png
│       │       │   ├── forward_time.png
│       │       │   ├── optim0_lr0.png
│       │       │   ├── text_vs_uma.png
│       │       │   ├── train_time.png
│       │       │   ├── backward_time.png
│       │       │   ├── uma_reduction.png
│       │       │   ├── optim_step_time.png
│       │       │   └── gpu_max_cached_mem_GB.png
│       │       └── RESULTS.md
│       ├── exp_uma_conformer_12e_718
│       │   ├── asr_train_asr_uma_conformer_raw_zh_char_sp
│       │   │   ├── images
│       │   │   │   ├── cer.png
│       │   │   │   ├── loss.png
│       │   │   │   ├── cer_ctc.png
│       │   │   │   ├── loss_ctc.png
│       │   │   │   ├── iter_time.png
│       │   │   │   ├── optim0_lr0.png
│       │   │   │   ├── train_time.png
│       │   │   │   ├── backward_time.png
│       │   │   │   ├── forward_time.png
│       │   │   │   ├── optim_step_time.png
│       │   │   │   └── gpu_max_cached_mem_GB.png
│       │   │   └── RESULTS.md
│       │   └── asr_train_asr_uma_conformer_condition_raw_zh_char_sp
│       │       ├── images
│       │       │   ├── cer.png
│       │       │   ├── loss.png
│       │       │   ├── cer_ctc.png
│       │       │   ├── loss_ctc.png
│       │       │   ├── iter_time.png
│       │       │   ├── optim0_lr0.png
│       │       │   ├── train_time.png
│       │       │   ├── backward_time.png
│       │       │   ├── forward_time.png
│       │       │   ├── optim_step_time.png
│       │       │   ├── cer_interctc_declayer2.png
│       │       │   ├── cer_interctc_declayer4.png
│       │       │   ├── cer_interctc_enclayer12.png
│       │       │   ├── cer_interctc_enclayer6.png
│       │       │   ├── cer_interctc_enclayer9.png
│       │       │   ├── gpu_max_cached_mem_GB.png
│       │       │   ├── loss_interctc_declayer2.png
│       │       │   ├── loss_interctc_declayer4.png
│       │       │   ├── loss_interctc_enclayer6.png
│       │       │   ├── loss_interctc_enclayer9.png
│       │       │   └── loss_interctc_enclayer12.png
│       │       └── RESULTS.md
│       └── run_unimodal.sh
└── README.md
--------------------------------------------------------------------------------
The following package markers are empty files:
/espnet2/asr/mamba_ssm/ops/__init__.py
/espnet2/asr/mamba_ssm/modules/__init__.py
/espnet2/asr/mamba_ssm/utils/__init__.py
/espnet2/asr/mamba_ssm/ops/triton/__init__.py
--------------------------------------------------------------------------------
/uma.png:
--------------------------------------------------------------------------------
(binary figure; available at https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/uma.png)
--------------------------------------------------------------------------------
/mamba_uma.png:
--------------------------------------------------------------------------------
(binary figure; available at https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/mamba_uma.png)
--------------------------------------------------------------------------------
/egs2/aishell/umaconf/decode_asr_uma.yaml:
--------------------------------------------------------------------------------
beam_size: 1
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 1
lm_weight: 0.7
--------------------------------------------------------------------------------
/egs2/hkust/umaconf/decode_asr_uma.yaml:
--------------------------------------------------------------------------------
beam_size: 1
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 1
lm_weight: 0.3
--------------------------------------------------------------------------------
/egs2/aishell2/umaconf/decode_asr_uma.yaml:
--------------------------------------------------------------------------------
beam_size: 1
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 1
lm_weight: 0.3
--------------------------------------------------------------------------------
All remaining sections of this dump are placeholders for binary files: the compiled bytecode under espnet2/asr/mamba_ssm/ops/__pycache__/ and espnet2/asr/mamba_ssm/ops/triton/__pycache__/, and the training-curve plots under each exp_*/asr_train_*/images/ directory listed in the tree above (cer, loss, cer_ctc, loss_ctc, iter_time, train_time, forward_time, backward_time, optim_step_time, optim0_lr0, gpu_max_cached_mem_GB, uma_reduction, text_vs_uma, and the per-layer cer_interctc_*/loss_interctc_* curves). Each such file resolves to https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/<path>.
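For reference, the three decode_asr_uma.yaml files above map one-to-one onto the keyword arguments of ESPnet2's Speech2Text inference interface. Below is a minimal sketch of applying one of them outside the recipe scripts; the experiment and checkpoint paths are hypothetical placeholders (the recipes normally drive decoding through run_unimodal.sh), and it assumes this repository's UMA model classes are importable so the checkpoint can be restored.

import yaml
import soundfile
from espnet2.bin.asr_inference import Speech2Text

# Load the decode options (beam_size, penalty, maxlenratio, minlenratio,
# ctc_weight, lm_weight) straight from the recipe config.
with open("egs2/aishell/umaconf/decode_asr_uma.yaml") as f:
    decode_opts = yaml.safe_load(f)

# Hypothetical experiment paths; substitute a real trained model.
exp_dir = "egs2/aishell/exp_uma_conformer_12e_69/asr_train_asr_unimodal_conformer_raw_zh_char_sp"
speech2text = Speech2Text(
    asr_train_config=f"{exp_dir}/config.yaml",
    asr_model_file=f"{exp_dir}/valid.cer.ave.pth",
    **decode_opts,
)

speech, rate = soundfile.read("sample.wav")  # mono waveform
text, tokens, token_ids, hyp = speech2text(speech)[0]
print(text)

Note that with ctc_weight: 1 and beam_size: 1 the joint CTC/attention search reduces to pure CTC scoring with a greedy beam, consistent with the CTC-based UMA models; lm_weight only takes effect when a language model is also supplied.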
https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/gpu_max_cached_mem_GB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/gpu_max_cached_mem_GB.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer4.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer6.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer9.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_declayer4.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer6.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer9.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer12.png -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/images/loss_interctc_enclayer12.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_declayer2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_declayer4.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer6.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer9.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/gpu_max_cached_mem_GB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/gpu_max_cached_mem_GB.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/cer_interctc_enclayer12.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/UMA-ASR/HEAD/egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/loss_interctc_declayer2.png -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/images/loss_interctc_declayer4.png: 
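The entries that follow are the Python sources of the bundled mamba_ssm package. As orientation for the `__init__.py` below, which re-exports `Mamba` and `selective_scan_fn`, here is a minimal usage sketch; the constructor signature (`d_model`, `d_state`, `d_conv`, `expand`) and the CUDA requirement of the selective-scan kernels are assumptions taken from upstream mamba_ssm 1.2, which matches the bundled `__version__`:

import torch
from espnet2.asr.mamba_ssm import Mamba  # re-exported by __init__.py below

# One Mamba block over (batch, frames, d_model) features.
# d_state/d_conv/expand defaults are assumed from upstream mamba_ssm 1.2;
# the fused selective-scan kernels need a CUDA device.
block = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 100, 256, device="cuda")  # e.g. a batch of encoder frames
y = block(x)                                 # output keeps the input shape
assert y.shape == x.shape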
/espnet2/asr/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: FnoY fangying@westlake.edu.cn 3 | LastEditors: FnoY0723 fangying@westlake.edu.cn 4 | LastEditTime: 2024-03-15 20:55:13 5 | FilePath: /espnet/espnet2/asr/mamba_ssm/__init__.py 6 | ''' 7 | __version__ = "1.2.0.post1" 8 | 9 | from espnet2.asr.mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 10 | from espnet2.asr.mamba_ssm.modules.mamba_simple import Mamba 11 | 12 | -------------------------------------------------------------------------------- /espnet2/bin/asr_unimodal_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | Author: FnoY fangying@westlake.edu.cn 4 | LastEditTime: 2023-09-15 14:22:03 5 | FilePath: /espnet/espnet2/bin/asr_unimodal_train.py 6 | ''' 7 | from espnet2.tasks.asr_unimodal import ASRTask 8 | 9 | 10 | def get_parser(): 11 |     parser = ASRTask.get_parser() 12 |     return parser 13 | 14 | 15 | def main(cmd=None): 16 |     r"""ASR training.
17 | 18 |     Example: 19 | 20 |         % python asr_unimodal_train.py asr --print_config --optim adadelta \ 21 |                 > conf/train_asr.yaml 22 |         % python asr_unimodal_train.py --config conf/train_asr.yaml 23 |     """ 24 |     ASRTask.main(cmd=cmd) 25 | 26 | 27 | if __name__ == "__main__": 28 |     main() 29 | -------------------------------------------------------------------------------- /egs2/aishell2/exp_uma_mamba_0819/asr_train_asr_uma_mamba_b_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Thu Sep 5 11:44:58 CST 2024` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `8f063f87a8c5de189a5d092e050694a4fd5115d4` 9 | - Commit date: `Mon Jul 8 14:14:22 2024 +0800` 10 | 11 | ## exp_uma_mamba_0819/asr_train_asr_uma_mamba_b_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|5002|62.1|37.9|0.0|0.0|37.9|37.9| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|49534|94.1|5.6|0.3|0.2|6.1|37.9| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_12e_67/asr_train_asr_uma_conformer_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Mon Jun 12 20:55:04 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `69054b7b6203973d158be95a0816e551da4d4bd6` 9 | - Commit date: `Tue Jun 6 11:23:40 2023 +0800` 10 | 11 | ## exp_uma_conformer_12e_67/asr_train_asr_uma_conformer_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|5240|28.3|71.6|0.1|0.7|72.4|70.1| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|56154|81.8|15.6|2.7|3.2|21.4|68.7| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | -------------------------------------------------------------------------------- /espnet2/asr/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 |     resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 |     return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 |     # If not fp32, then we don't want to load directly to the GPU 16 |     mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 |     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False) 18 |     state_dict = torch.load(resolved_archive_file, map_location=mapped_device) 19 |     # Convert dtype before moving to GPU to save memory 20 |     if dtype is not None: 21 |         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 |     state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 |     return state_dict 24 | -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_12e_69/asr_train_asr_uma_branchformer_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Mon Jun 12 20:17:01 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `69054b7b6203973d158be95a0816e551da4d4bd6` 9 | - Commit date: `Tue Jun 6 11:23:40 2023 +0800` 10 | 11 | ## exp_uma_branchformer_12e_69/asr_train_asr_uma_branchformer_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|5240|29.0|70.8|0.2|0.6|71.5|69.3| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|56154|82.5|14.1|3.4|2.6|20.1|68.1| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Fri Jul 14 12:33:51 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `58d7c097f69a5ddc15aa4658e9462e028157f326` 9 | - Commit date: `Thu Jun 29 15:27:10 2023 +0800` 10 | 11 | ## exp_uma_conformer_condition0302_32_712/asr_train_asr_uma_conformer_condition_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|5240|29.6|70.2|0.1|0.6|71.0|68.7| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|56154|83.0|14.4|2.6|3.1|20.0|67.3| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | -------------------------------------------------------------------------------- /egs2/hkust/exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Wed Jul 12 11:10:43 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `58d7c097f69a5ddc15aa4658e9462e028157f326` 9 | - Commit date: `Thu Jun 29 15:27:10 2023 +0800` 10 | 11 | ##
exp_uma_branchformer_condition0302_32_711/asr_train_asr_uma_branchformer_condition_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|5240|30.8|69.0|0.2|0.4|69.6|67.3| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/dev|5413|56154|83.7|13.7|2.6|2.9|19.2|66.0| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_mamba_0617/asr_train_asr_uma_mamba_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Thu Oct 31 12:45:06 CST 2024` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `4f786daee971a55f5d2e299f071d5e661de4a3ca` 9 | - Commit date: `Thu Oct 10 20:19:06 2024 +0800` 10 | 11 | ## exp_uma_mamba_0617/asr_train_asr_uma_mamba_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|7176|59.0|41.0|0.0|0.0|41.0|41.0| 17 | |peak_decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|7176|59.0|41.0|0.0|0.0|41.0|41.0| 18 | 19 | ### CER 20 | 21 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 22 | |---|---|---|---|---|---|---|---|---| 23 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|104765|94.6|5.2|0.2|0.1|5.5|41.0| 24 | |peak_decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|104765|94.6|5.2|0.1|0.2|5.6|41.0| 25 | 26 | ### TER 27 | 28 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 29 | |---|---|---|---|---|---|---|---|---| 30 | -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_12e_69/asr_train_asr_unimodal_conformer_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Mon Jun 12 02:23:35 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `69054b7b6203973d158be95a0816e551da4d4bd6` 9 | - Commit date: `Tue Jun 6 11:23:40 2023 +0800` 10 | 11 | ## exp_uma_conformer_12e_69/asr_train_asr_unimodal_conformer_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|7176|62.7|37.3|0.0|0.0|37.3|37.3| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|104765|95.3|4.5|0.2|0.1|4.8|37.3| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | ## exp_uma_conformer_12e_69/asr_train_asr_unimodal_conformer_raw_zh_char_sp/decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best 29 | ### WER 30 | 31 | 
|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 32 | |---|---|---|---|---|---|---|---|---| 33 | |org/dev|14326|14326|64.3|35.7|0.0|0.0|35.7|35.7| 34 | 35 | ### CER 36 | 37 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 38 | |---|---|---|---|---|---|---|---|---| 39 | |org/dev|14326|205341|95.6|4.3|0.1|0.1|4.5|35.7| 40 | 41 | ### TER 42 | 43 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 44 | |---|---|---|---|---|---|---|---|---| 45 | -------------------------------------------------------------------------------- /egs2/aishell/exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Tue Aug 1 16:39:06 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `58d7c097f69a5ddc15aa4658e9462e028157f326` 9 | - Commit date: `Thu Jun 29 15:27:10 2023 +0800` 10 | 11 | ## exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|7176|64.1|35.9|0.0|0.0|35.9|35.9| 17 | 18 | ### CER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best/test|7176|104765|95.4|4.4|0.1|0.1|4.7|35.9| 23 | 24 | ### TER 25 | 26 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 27 | |---|---|---|---|---|---|---|---|---| 28 | ## exp_uma_conformer_condition0302_32_731/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/decode_asr_unimodal_attention_asr_model_valid.cer.ave_10best 29 | ### WER 30 | 31 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 32 | |---|---|---|---|---|---|---|---|---| 33 | |org/dev|14326|14326|66.0|34.0|0.0|0.0|34.0|34.0| 34 | 35 | ### CER 36 | 37 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 38 | |---|---|---|---|---|---|---|---|---| 39 | |org/dev|14326|205341|95.7|4.1|0.1|0.1|4.4|34.0| 40 | 41 | ### TER 42 | 43 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 44 | |---|---|---|---|---|---|---|---|---| 45 | -------------------------------------------------------------------------------- /egs2/aishell/umaconf/train_asr_uma_mamba.yaml: -------------------------------------------------------------------------------- 1 | # network architecture 2 | # encoder related 3 | encoder: mamba 4 | encoder_conf: 5 | output_size: 256 6 | num_blocks: 36 7 | dropout_rate: 0.1 8 | input_layer: causal_conv2d 9 | rms_norm: true 10 | fused_add_norm: true 11 | residual_in_fp32: true 12 | normalize_before: true 13 | lookahead_kernel: 17 14 | 15 | # decoder related 16 | decoder: unimodal_transformer 17 | decoder_conf: 18 | output_size: 256 19 | attention_heads: 4 20 | linear_units: 2048 21 | num_blocks: 6 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | 25 | # hybrid CTC/attention 26 | model_conf: 27 | ctc_weight: 1.0 28 | lsm_weight: 0.1 # label smoothing option 29 | length_normalized_loss: false 30 | 31 | 32 | # minibatch related 33 | batch_type: folded 34 | batch_size: 128 35 | 36 | # optimization related 37 | accum_grad: 1 38 | grad_clip: 5.0 39 | max_epoch: 50 40 | val_scheduler_criterion: 41 | - valid 42 | - loss 43 | best_model_criterion: 44 | - - valid 45 | - cer 46 | - min 47 | keep_nbest_models: 
10 48 | 49 | optim: adamw 50 | optim_conf: 51 | lr: 0.001 52 | weight_decay: 0.01 53 | scheduler: warmuplr 54 | scheduler_conf: 55 | warmup_steps: 25000 56 | 57 | 58 | num_workers: 4 # num of workers of data loader 59 | use_amp: true # automatic mixed precision 60 | unused_parameters: false # set as true if some params are unused in DDP 61 | 62 | specaug: specaug 63 | specaug_conf: 64 | apply_time_warp: true 65 | time_warp_window: 5 66 | time_warp_mode: bicubic 67 | apply_freq_mask: true 68 | freq_mask_width_range: 69 | - 0 70 | - 27 71 | num_freq_mask: 2 72 | apply_time_mask: true 73 | time_mask_width_ratio_range: 74 | - 0. 75 | - 0.05 76 | num_time_mask: 10 77 | -------------------------------------------------------------------------------- /egs2/aishell2/umaconf/train_asr_uma_mamba_b.yaml: -------------------------------------------------------------------------------- 1 | # network architecture 2 | # encoder related 3 | encoder: mamba 4 | encoder_conf: 5 | output_size: 512 6 | num_blocks: 36 7 | dropout_rate: 0.1 8 | input_layer: causal_conv2d 9 | rms_norm: true 10 | fused_add_norm: true 11 | residual_in_fp32: true 12 | normalize_before: true 13 | lookahead_kernel: 29 14 | 15 | # decoder related 16 | decoder: unimodal_transformer 17 | decoder_conf: 18 | output_size: 512 19 | attention_heads: 8 20 | linear_units: 2048 21 | num_blocks: 6 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | 25 | # hybrid CTC/attention 26 | model_conf: 27 | ctc_weight: 1.0 28 | lsm_weight: 0.1 # label smoothing option 29 | length_normalized_loss: false 30 | 31 | 32 | # minibatch related 33 | batch_type: folded 34 | batch_size: 128 35 | num_iters_per_epoch: 7126 36 | 37 | # optimization related 38 | accum_grad: 2 39 | grad_clip: 5.0 40 | max_epoch: 150 41 | log_interval: 200 42 | val_scheduler_criterion: 43 | - valid 44 | - loss 45 | best_model_criterion: 46 | - - valid 47 | - cer 48 | - min 49 | keep_nbest_models: 10 50 | 51 | optim: adamw 52 | optim_conf: 53 | lr: 0.0005 54 | weight_decay: 0.1 55 | scheduler: warmuplr 56 | scheduler_conf: 57 | warmup_steps: 30000 58 | 59 | 60 | num_workers: 4 # num of workers of data loader 61 | use_amp: true # automatic mixed precision 62 | unused_parameters: false # set as true if some params are unused in DDP 63 | 64 | specaug: specaug 65 | specaug_conf: 66 | apply_time_warp: true 67 | time_warp_window: 5 68 | time_warp_mode: bicubic 69 | apply_freq_mask: true 70 | freq_mask_width_range: 71 | - 0 72 | - 27 73 | num_freq_mask: 2 74 | apply_time_mask: true 75 | time_mask_width_ratio_range: 76 | - 0. 
77 | - 0.05 78 | num_time_mask: 10 79 | -------------------------------------------------------------------------------- /egs2/hkust/umaconf/train_asr_uma_conformer.yaml: -------------------------------------------------------------------------------- 1 | encoder: conformer 2 | encoder_conf: 3 |     # conformer encoder 4 |     output_size: 256    # dimension of attention 5 |     attention_heads: 4 6 |     linear_units: 2048  # the number of units of position-wise feed forward 7 |     num_blocks: 12      # the number of encoder blocks 8 |     dropout_rate: 0.1 9 |     positional_dropout_rate: 0.1 10 |     attention_dropout_rate: 0.0 11 |     input_layer: conv2d # encoder architecture type 12 |     normalize_before: true 13 |     rel_pos_type: latest 14 |     pos_enc_layer_type: rel_pos 15 |     selfattention_layer_type: rel_selfattn 16 |     activation_type: swish 17 |     macaron_style: true 18 |     use_cnn_module: true 19 |     cnn_module_kernel: 31 20 | 21 | # decoder related 22 | decoder: unimodal_transformer 23 | decoder_conf: 24 |     attention_heads: 4 25 |     linear_units: 2048 26 |     num_blocks: 6 27 |     dropout_rate: 0.1 28 |     positional_dropout_rate: 0.1 29 | 30 | # hybrid CTC/attention 31 | model_conf: 32 |     ctc_weight: 1 33 |     lsm_weight: 0.1     # label smoothing option 34 |     length_normalized_loss: false 35 | 36 | # minibatch related 37 | batch_type: numel 38 | batch_bins: 20000000 39 | 40 | # optimization related 41 | accum_grad: 2 42 | grad_clip: 5 43 | max_epoch: 70 44 | val_scheduler_criterion: 45 |     - valid 46 |     - loss 47 | best_model_criterion: 48 | -   - valid 49 |     - cer 50 |     - min 51 | keep_nbest_models: 10 52 | 53 | optim: adam 54 | optim_conf: 55 |     lr: 0.0005 56 | scheduler: warmuplr 57 | scheduler_conf: 58 |     warmup_steps: 30000 59 | 60 | specaug: specaug 61 | specaug_conf: 62 |     apply_time_warp: true 63 |     time_warp_window: 5 64 |     time_warp_mode: bicubic 65 |     apply_freq_mask: true 66 |     freq_mask_width_range: 67 |     - 0 68 |     - 30 69 |     num_freq_mask: 2 70 |     apply_time_mask: true 71 |     time_mask_width_range: 72 |     - 0 73 |     - 40 74 |     num_time_mask: 2 75 | -------------------------------------------------------------------------------- /egs2/aishell/run_unimodal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ### 3 | # @Author: FnoY fangying@westlake.edu.cn 4 | # @LastEditTime: 2023-09-15 13:28:12 5 | # @FilePath: /espnet/egs2/aishell/asr1/run_unimodal.sh 6 | ### 7 | # Set bash to 'debug' mode, it will exit on : 8 | # -e 'error', -u 'undefined variable', -o ...
'error in pipeline', -x 'print commands', 9 | # CUDA_VISIBLE_DEVICES=6 10 | set -e 11 | set -u 12 | set -o pipefail 13 | 14 | train_set=train 15 | valid_set=dev 16 | test_sets="dev test" 17 | 18 | asr_config=umaconf/train_asr_uma_conformer.yaml 19 | inference_config=umaconf/decode_asr_uma.yaml 20 | 21 | lm_config=conf/train_lm_transformer.yaml 22 | use_lm=false 23 | use_wordlm=false 24 | expdir=exp_uma_conformer 25 | inference_asr_model=valid.cer.ave_10best.pth 26 | 27 | # speed perturbation related 28 | # (train_set will be "${train_set}_sp" if speed_perturb_factors is specified) 29 | speed_perturb_factors="0.9 1.0 1.1" 30 | 31 | ./asr_unimodal.sh \ 32 | --nj 64 \ 33 | --inference_nj 64 \ 34 | --ngpu 1 \ 35 | --lang zh \ 36 | --audio_format "flac.ark" \ 37 | --feats_type raw \ 38 | --token_type char \ 39 | --use_lm ${use_lm} \ 40 | --use_word_lm ${use_wordlm} \ 41 | --expdir ${expdir} \ 42 | --inference_asr_model ${inference_asr_model} \ 43 | --lm_config "${lm_config}" \ 44 | --asr_config "${asr_config}" \ 45 | --inference_config "${inference_config}" \ 46 | --train_set "${train_set}" \ 47 | --valid_set "${valid_set}" \ 48 | --test_sets "${test_sets}" \ 49 | --speed_perturb_factors "${speed_perturb_factors}" \ 50 | --asr_speech_fold_length 512 \ 51 | --asr_text_fold_length 150 \ 52 | --lm_fold_length 150 \ 53 | --lm_train_text "data/${train_set}/text" "$@" 54 | -------------------------------------------------------------------------------- /egs2/aishell2/umaconf/train_asr_uma_conformer.yaml: -------------------------------------------------------------------------------- 1 | # Trained with A100 (80GB) x 2 GPUs. It takes about 6 days. 2 | encoder: conformer 3 | encoder_conf: 4 | output_size: 512 # dimension of attention 5 | attention_heads: 8 6 | linear_units: 2048 # the number of units of position-wise feed forward 7 | num_blocks: 12 # the number of encoder blocks 8 | dropout_rate: 0.1 9 | positional_dropout_rate: 0.1 10 | attention_dropout_rate: 0.0 11 | input_layer: conv2d # encoder architecture type 12 | normalize_before: true 13 | pos_enc_layer_type: rel_pos 14 | selfattention_layer_type: rel_selfattn 15 | activation_type: swish 16 | macaron_style: true 17 | use_cnn_module: true 18 | cnn_module_kernel: 31 19 | 20 | # decoder related 21 | decoder: unimodal_transformer 22 | decoder_conf: 23 | attention_heads: 4 24 | linear_units: 2048 25 | num_blocks: 6 26 | dropout_rate: 0.1 27 | positional_dropout_rate: 0.1 28 | 29 | # hybrid CTC/attention 30 | model_conf: 31 | ctc_weight: 1 32 | lsm_weight: 0.1 # label smoothing option 33 | length_normalized_loss: false 34 | 35 | # minibatch related 36 | batch_type: numel 37 | batch_bins: 20000000 38 | num_workers: 4 39 | 40 | # optimization related 41 | accum_grad: 4 42 | grad_clip: 5 43 | max_epoch: 50 44 | val_scheduler_criterion: 45 | - valid 46 | - loss 47 | best_model_criterion: 48 | - - valid 49 | - cer 50 | - min 51 | keep_nbest_models: 10 52 | 53 | optim: adam 54 | optim_conf: 55 | lr: 0.0005 56 | scheduler: warmuplr 57 | scheduler_conf: 58 | warmup_steps: 30000 59 | 60 | specaug: specaug 61 | specaug_conf: 62 | apply_time_warp: true 63 | time_warp_window: 5 64 | time_warp_mode: bicubic 65 | apply_freq_mask: true 66 | freq_mask_width_range: 67 | - 0 68 | - 30 69 | num_freq_mask: 2 70 | apply_time_mask: true 71 | time_mask_width_range: 72 | - 0 73 | - 40 74 | num_time_mask: 2 75 | -------------------------------------------------------------------------------- /egs2/hkust/run_unimodal.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ### 3 | # @Author: FnoY fangying@westlake.edu.cn 4 | # @LastEditTime: 2023-09-15 13:43:56 5 | # @FilePath: /espnet/egs2/hkust/asr1/run_unimodal.sh 6 | ### 7 | # Set bash to 'debug' mode, it will exit on : 8 | # -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', 9 | # CUDA_VISIBLE_DEVICES=7 10 | set -e 11 | set -u 12 | set -o pipefail 13 | 14 | train_set=train_nodup 15 | valid_set=train_dev 16 | test_sets="dev" 17 | 18 | 19 | asr_config=umaconf/train_asr_uma_conformer.yaml 20 | inference_config=umaconf/decode_asr_uma.yaml 21 | 22 | lm_config=conf/tuning/train_lm_transformer.yaml 23 | use_lm=false 24 | expdir=exp_uma_conformer 25 | inference_asr_model=valid.cer.ave_10best.pth 26 | 27 | # speed perturbation related 28 | # (train_set will be "${train_set}_sp" if speed_perturb_factors is specified) 29 | speed_perturb_factors="0.9 1.0 1.1" 30 | 31 | ./asr_unimodal.sh \ 32 | --nj 64 \ 33 | --inference_nj 1 \ 34 | --ngpu 1 \ 35 | --lang zh \ 36 | --audio_format flac \ 37 | --feats_type raw \ 38 | --token_type char \ 39 | --nlsyms_txt data/nlsyms.txt \ 40 | --use_lm ${use_lm} \ 41 | --expdir ${expdir} \ 42 | --inference_asr_model ${inference_asr_model} \ 43 | --lm_config "${lm_config}" \ 44 | --asr_config "${asr_config}" \ 45 | --inference_config "${inference_config}" \ 46 | --train_set "${train_set}" \ 47 | --valid_set "${valid_set}" \ 48 | --test_sets "${test_sets}" \ 49 | --speed_perturb_factors "${speed_perturb_factors}" \ 50 | --lm_train_text "data/${train_set}/text" "$@" -------------------------------------------------------------------------------- /egs2/aishell2/run_unimodal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ### 3 | # @Author: FnoY fangying@westlake.edu.cn 4 | # @LastEditTime: 2023-09-15 13:35:03 5 | # @FilePath: /espnet/egs2/aishell2/asr1/run_unimodal.sh 6 | ### 7 | # Set bash to 'debug' mode, it will exit on : 8 | # -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', 9 | # CUDA_VISIBLE_DEVICES=4,5 10 | set -e 11 | set -u 12 | set -o pipefail 13 | 14 | train_set=train_noeng 15 | valid_set=dev_ios 16 | test_sets="dev_ios test_android test_ios test_mic" 17 | 18 | asr_config=umaconf/train_asr_uma_conformer.yaml 19 | inference_config=umaconf/decode_asr_uma.yaml 20 | 21 | lm_config=conf/train_lm_transformer.yaml 22 | use_lm=false 23 | use_wordlm=false 24 | expdir=exp_uma_conformer_12e_718 25 | inference_asr_model=valid.cer.ave_10best.pth 26 | 27 | # speed perturbation related 28 | # (train_set will be "${train_set}_sp" if speed_perturb_factors is specified) 29 | speed_perturb_factors="0.9 1.0 1.1" 30 | 31 | ./asr_unimodal.sh \ 32 |     --nj 64 \ 33 |     --inference_nj 1 \ 34 |     --ngpu 2 \ 35 |     --lang zh \ 36 |     --audio_format wav \ 37 |     --feats_type raw \ 38 |     --token_type char \ 39 |     --use_lm ${use_lm} \ 40 |     --use_word_lm ${use_wordlm} \ 41 |     --expdir ${expdir} \ 42 |     --inference_asr_model ${inference_asr_model} \ 43 |     --lm_config "${lm_config}" \ 44 |     --asr_config "${asr_config}" \ 45 |     --inference_config "${inference_config}" \ 46 |     --train_set "${train_set}" \ 47 |     --valid_set "${valid_set}" \ 48 |     --test_sets "${test_sets}" \ 49 |     --speed_perturb_factors "${speed_perturb_factors}" \ 50 |     --asr_speech_fold_length 512 \ 51 |     --asr_text_fold_length 150 \ 52 |     --lm_fold_length 150 \ 53 |     --lm_train_text "data/${train_set}/text" "$@" 54 | -------------------------------------------------------------------------------- /egs2/hkust/umaconf/train_asr_uma_branchformer.yaml: -------------------------------------------------------------------------------- 1 | # network architecture 2 | # encoder related 3 | # encoder: unimodal_branchformer 4 | encoder: e_branchformer 5 | encoder_conf: 6 |     output_size: 256 7 |     attention_heads: 4 8 |     attention_layer_type: rel_selfattn 9 |     pos_enc_layer_type: rel_pos 10 |     rel_pos_type: latest 11 |     cgmlp_linear_units: 1024 12 |     cgmlp_conv_kernel: 31 13 |     use_linear_after_conv: false 14 |     gate_activation: identity 15 |     num_blocks: 12 16 |     dropout_rate: 0.1 17 |     positional_dropout_rate: 0.1 18 |     attention_dropout_rate: 0.1 19 |     input_layer: conv2d 20 |     layer_drop_rate: 0.0 21 |     linear_units: 1024 22 |     positionwise_layer_type: linear 23 |     use_ffn: true 24 |     macaron_ffn: true 25 |     merge_conv_kernel: 31 26 | 27 | # # decoder related 28 | decoder: unimodal_transformer 29 | decoder_conf: 30 |     attention_heads: 4 31 |     linear_units: 2048 32 |     num_blocks: 6 33 |     dropout_rate: 0.1 34 |     positional_dropout_rate: 0.1 35 | 36 | # hybrid CTC/attention 37 | model_conf: 38 |     ctc_weight: 1 39 |     lsm_weight: 0.1     # label smoothing option 40 |     length_normalized_loss: false 41 | 42 | # minibatch related 43 | batch_type: numel 44 | batch_bins: 40000000 45 | 46 | # optimization related 47 | accum_grad: 1 48 | grad_clip: 5 49 | max_epoch: 70 50 | best_model_criterion: 51 | -   - valid 52 |     - cer 53 |     - min 54 | keep_nbest_models: 10 55 | 56 | optim: adam 57 | optim_conf: 58 |     lr: 0.001 59 |     weight_decay: 0.000001 60 | scheduler: warmuplr 61 | scheduler_conf: 62 |     warmup_steps: 35000 63 | 64 | num_workers: 4           # num of workers of data loader 65 | use_amp: true            # automatic mixed precision 66 | unused_parameters: false # set as true if some params are unused in DDP 67 | 68 | specaug: specaug 69 | specaug_conf: 70 |     apply_time_warp: true 71 |     time_warp_window: 5 72 |     time_warp_mode: bicubic 73 |     apply_freq_mask: true 74 |     freq_mask_width_range: 75 |     - 0 76 |     - 27 77 |     num_freq_mask: 2 78 |     apply_time_mask: true 79 |
time_mask_width_ratio_range: 80 |     - 0. 81 |     - 0.05 82 |     num_time_mask: 10 83 | -------------------------------------------------------------------------------- /egs2/hkust/umaconf/train_asr_uma_conformer_condition.yaml: -------------------------------------------------------------------------------- 1 | encoder: conformer 2 | encoder_conf: 3 |     # conformer encoder 4 |     output_size: 256    # dimension of attention 5 |     attention_heads: 4 6 |     linear_units: 2048  # the number of units of position-wise feed forward 7 |     num_blocks: 12      # the number of encoder blocks 8 |     dropout_rate: 0.1 9 |     positional_dropout_rate: 0.1 10 |     attention_dropout_rate: 0.0 11 |     input_layer: conv2d # encoder architecture type 12 |     normalize_before: true 13 |     rel_pos_type: latest 14 |     pos_enc_layer_type: rel_pos 15 |     selfattention_layer_type: rel_selfattn 16 |     activation_type: swish 17 |     macaron_style: true 18 |     use_cnn_module: true 19 |     cnn_module_kernel: 31 20 |     interctc_layer_idx: [6,9,12] 21 |     interctc_use_conditioning: true 22 | 23 | # decoder related 24 | decoder: unimodal_transformer 25 | decoder_conf: 26 |     attention_heads: 4 27 |     linear_units: 2048 28 |     num_blocks: 6 29 |     dropout_rate: 0.1 30 |     positional_dropout_rate: 0.1 31 |     interctc_layer_idx: [2,4] 32 |     interctc_use_conditioning: true 33 | 34 | # hybrid CTC/attention 35 | model_conf: 36 |     ctc_weight: 1 37 |     interctc_weight_enc: 0.3 38 |     interctc_weight_dec: 0.2 39 |     lsm_weight: 0.1     # label smoothing option 40 |     length_normalized_loss: false 41 | 42 | # minibatch related 43 | batch_type: numel 44 | batch_bins: 20000000 45 | 46 | # optimization related 47 | accum_grad: 2 48 | grad_clip: 5 49 | max_epoch: 70 50 | val_scheduler_criterion: 51 |     - valid 52 |     - loss 53 | best_model_criterion: 54 | -   - valid 55 |     - cer 56 |     - min 57 | keep_nbest_models: 10 58 | 59 | optim: adam 60 | optim_conf: 61 |     lr: 0.0005 62 | scheduler: warmuplr 63 | scheduler_conf: 64 |     warmup_steps: 30000 65 | 66 | specaug: specaug 67 | specaug_conf: 68 |     apply_time_warp: true 69 |     time_warp_window: 5 70 |     time_warp_mode: bicubic 71 |     apply_freq_mask: true 72 |     freq_mask_width_range: 73 |     - 0 74 |     - 30 75 |     num_freq_mask: 2 76 |     apply_time_mask: true 77 |     time_mask_width_range: 78 |     - 0 79 |     - 40 80 |     num_time_mask: 2 81 | -------------------------------------------------------------------------------- /egs2/aishell2/umaconf/train_asr_uma_conformer_condition.yaml: -------------------------------------------------------------------------------- 1 | # Trained with A100 (80GB) x 2 GPUs. It takes about 6 days.
2 | encoder: conformer 3 | encoder_conf: 4 | output_size: 512 # dimension of attention 5 | attention_heads: 8 6 | linear_units: 2048 # the number of units of position-wise feed forward 7 | num_blocks: 12 # the number of encoder blocks 8 | dropout_rate: 0.1 9 | positional_dropout_rate: 0.1 10 | attention_dropout_rate: 0.0 11 | input_layer: conv2d # encoder architecture type 12 | normalize_before: true 13 | pos_enc_layer_type: rel_pos 14 | selfattention_layer_type: rel_selfattn 15 | activation_type: swish 16 | macaron_style: true 17 | use_cnn_module: true 18 | cnn_module_kernel: 31 19 | interctc_layer_idx: [6,9,12] 20 | interctc_use_conditioning: true 21 | 22 | # decoder related 23 | decoder: unimodal_transformer 24 | decoder_conf: 25 | attention_heads: 4 26 | linear_units: 2048 27 | num_blocks: 6 28 | dropout_rate: 0.1 29 | positional_dropout_rate: 0.1 30 | interctc_layer_idx: [2,4] 31 | interctc_use_conditioning: true 32 | 33 | # hybrid CTC/attention 34 | model_conf: 35 | ctc_weight: 1 36 | interctc_weight_enc: 0.3 37 | interctc_weight_dec: 0.2 38 | lsm_weight: 0.1 # label smoothing option 39 | length_normalized_loss: false 40 | 41 | # minibatch related 42 | batch_type: numel 43 | batch_bins: 20000000 44 | 45 | # optimization related 46 | accum_grad: 4 47 | grad_clip: 5 48 | max_epoch: 50 49 | val_scheduler_criterion: 50 | - valid 51 | - loss 52 | best_model_criterion: 53 | - - valid 54 | - cer 55 | - min 56 | keep_nbest_models: 10 57 | 58 | optim: adam 59 | optim_conf: 60 | lr: 0.0005 61 | scheduler: warmuplr 62 | scheduler_conf: 63 | warmup_steps: 30000 64 | 65 | specaug: specaug 66 | specaug_conf: 67 | apply_time_warp: true 68 | time_warp_window: 5 69 | time_warp_mode: bicubic 70 | apply_freq_mask: true 71 | freq_mask_width_range: 72 | - 0 73 | - 30 74 | num_freq_mask: 2 75 | apply_time_mask: true 76 | time_mask_width_range: 77 | - 0 78 | - 40 79 | num_time_mask: 2 80 | -------------------------------------------------------------------------------- /egs2/aishell2/exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 2 | # RESULTS 3 | ## Environments 4 | - date: `Wed Aug 23 15:31:26 CST 2023` 5 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 6 | - espnet version: `espnet 202301` 7 | - pytorch version: `pytorch 1.12.1` 8 | - Git hash: `58d7c097f69a5ddc15aa4658e9462e028157f326` 9 | - Commit date: `Thu Jun 29 15:27:10 2023 +0800` 10 | 11 | ## exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_condition_raw_zh_char_sp 12 | ### WER 13 | 14 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 15 | |---|---|---|---|---|---|---|---|---| 16 | |50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_android|5000|5002|63.4|36.5|0.0|0.0|36.6|36.6| 17 | |50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|5002|66.1|33.9|0.0|0.0|33.9|33.9| 18 | |50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_mic|5000|5002|63.7|36.2|0.0|0.0|36.3|36.3| 19 | |decode_uma_asr_model_valid.cer.ave_10best/test_mic|50|50|48.0|52.0|0.0|0.0|52.0|52.0| 20 | 21 | ### CER 22 | 23 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 24 | |---|---|---|---|---|---|---|---|---| 25 | |50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_android|5000|49534|94.1|5.6|0.3|0.2|6.0|36.6| 26 | |50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|49534|94.8|5.0|0.2|0.2|5.3|33.9| 27 | 
|50epoch_decode_uma_asr_model_valid.cer.ave_10best/test_mic|5000|49534|94.2|5.6|0.2|0.2|5.9|36.3| 28 | |decode_uma_asr_model_valid.cer.ave_10best/test_mic|50|458|89.3|10.5|0.2|0.2|10.9|52.0| 29 | 30 | ### TER 31 | 32 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 33 | |---|---|---|---|---|---|---|---|---| 34 | ## exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_condition_raw_zh_char_sp/50epoch_decode_uma_asr_model_valid.cer.ave_10best 35 | ### WER 36 | 37 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 38 | |---|---|---|---|---|---|---|---|---| 39 | |org/dev_ios|2500|2500|67.9|32.1|0.0|0.0|32.1|32.1| 40 | 41 | ### CER 42 | 43 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 44 | |---|---|---|---|---|---|---|---|---| 45 | |org/dev_ios|2500|24802|95.2|4.6|0.2|0.1|4.9|32.1| 46 | 47 | ### TER 48 | 49 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 50 | |---|---|---|---|---|---|---|---|---| 51 | -------------------------------------------------------------------------------- /egs2/aishell/umaconf/train_asr_uma_conformer.yaml: -------------------------------------------------------------------------------- 1 | encoder: conformer 2 | encoder_conf: 3 | # conformer encoder 4 | output_size: 256 # dimension of attention 5 | attention_heads: 4 6 | linear_units: 2048 # the number of units of position-wise feed forward 7 | num_blocks: 12 # the number of encoder blocks 8 | dropout_rate: 0.1 9 | positional_dropout_rate: 0.1 10 | attention_dropout_rate: 0.1 11 | input_layer: conv2d # encoder architecture type 12 | normalize_before: true 13 | rel_pos_type: latest 14 | pos_enc_layer_type: rel_pos 15 | selfattention_layer_type: rel_selfattn 16 | activation_type: swish 17 | macaron_style: true 18 | use_cnn_module: true 19 | cnn_module_kernel: 31 20 | 21 | # decoder related 22 | decoder: unimodal_transformer 23 | decoder_conf: 24 | attention_heads: 4 25 | linear_units: 2048 26 | num_blocks: 6 27 | dropout_rate: 0.1 28 | positional_dropout_rate: 0.1 29 | 30 | # hybrid CTC/attention 31 | model_conf: 32 | ctc_weight: 1 33 | lsm_weight: 0.1 # label smoothing option 34 | length_normalized_loss: false 35 | 36 | # minibatch related 37 | batch_type: numel 38 | batch_bins: 25000000 39 | 40 | # optimization related 41 | accum_grad: 1 42 | grad_clip: 5 43 | # patience: 3 44 | max_epoch: 60 45 | val_scheduler_criterion: 46 | - valid 47 | - loss 48 | best_model_criterion: 49 | - - valid 50 | - cer 51 | - min 52 | keep_nbest_models: 10 53 | 54 | # NoamLR is deprecated. Use WarmupLR. 55 | # The following is an equivalent setting for NoamLR: 56 | optim: adam 57 | optim_conf: 58 | lr: 0.001 59 | weight_decay: 0.000001 60 | scheduler: warmuplr # pytorch v1.1.0+ required 61 | scheduler_conf: 62 | warmup_steps: 35000 63 | 64 | num_workers: 4 # num of workers of data loader 65 | use_amp: true # automatic mixed precision 66 | unused_parameters: false # set as true if some params are unused in DDP 67 | 68 | specaug: specaug 69 | specaug_conf: 70 | apply_time_warp: true 71 | time_warp_window: 5 72 | time_warp_mode: bicubic 73 | apply_freq_mask: true 74 | freq_mask_width_range: 75 | - 0 76 | - 27 77 | num_freq_mask: 2 78 | apply_time_mask: true 79 | time_mask_width_ratio_range: 80 | - 0.
81 | - 0.05 82 | num_time_mask: 10 -------------------------------------------------------------------------------- /egs2/aishell2/exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_raw_zh_char_sp/RESULTS.md: -------------------------------------------------------------------------------- 1 | 7 | 8 | # RESULTS 9 | ## Environments 10 | - date: `Wed Aug 23 15:04:44 CST 2023` 11 | - python version: `3.9.16 (main, Jan 11 2023, 16:05:54) [GCC 11.2.0]` 12 | - espnet version: `espnet 202301` 13 | - pytorch version: `pytorch 1.12.1` 14 | - Git hash: `58d7c097f69a5ddc15aa4658e9462e028157f326` 15 | - Commit date: `Thu Jun 29 15:27:10 2023 +0800` 16 | 17 | ## exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_raw_zh_char_sp 18 | ### WER 19 | 20 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 21 | |---|---|---|---|---|---|---|---|---| 22 | |decode_uma_asr_model_valid.cer.ave_10best/test_android|5000|5002|62.7|37.3|0.0|0.0|37.3|37.3| 23 | |decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|5002|65.6|34.3|0.0|0.0|34.4|34.3| 24 | |decode_uma_asr_model_valid.cer.ave_10best/test_mic|5000|5002|62.6|37.4|0.0|0.0|37.4|37.4| 25 | 26 | ### CER 27 | 28 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 29 | |---|---|---|---|---|---|---|---|---| 30 | |decode_uma_asr_model_valid.cer.ave_10best/test_android|5000|49534|94.1|5.7|0.2|0.1|6.0|37.3| 31 | |decode_uma_asr_model_valid.cer.ave_10best/test_ios|5000|49534|94.8|5.0|0.2|0.1|5.3|34.3| 32 | |decode_uma_asr_model_valid.cer.ave_10best/test_mic|5000|49534|94.1|5.7|0.2|0.2|6.0|37.4| 33 | 34 | ### TER 35 | 36 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 37 | |---|---|---|---|---|---|---|---|---| 38 | ## exp_uma_conformer_12e_718/asr_train_asr_uma_conformer_raw_zh_char_sp/decode_uma_asr_model_valid.cer.ave_10best 39 | ### WER 40 | 41 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 42 | |---|---|---|---|---|---|---|---|---| 43 | |org/dev_ios|2500|2500|67.5|32.5|0.0|0.0|32.5|32.5| 44 | 45 | ### CER 46 | 47 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 48 | |---|---|---|---|---|---|---|---|---| 49 | |org/dev_ios|2500|24802|95.2|4.6|0.2|0.1|4.9|32.5| 50 | 51 | ### TER 52 | 53 | |dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| 54 | |---|---|---|---|---|---|---|---|---| 55 | -------------------------------------------------------------------------------- /egs2/hkust/umaconf/train_asr_uma_branchformer_condition.yaml: -------------------------------------------------------------------------------- 1 | # network architecture 2 | # encoder related 3 | # encoder: unimodal_branchformer 4 | encoder: e_branchformer 5 | encoder_conf: 6 | output_size: 256 7 | attention_heads: 4 8 | attention_layer_type: rel_selfattn 9 | pos_enc_layer_type: rel_pos 10 | rel_pos_type: latest 11 | cgmlp_linear_units: 1024 12 | cgmlp_conv_kernel: 31 13 | use_linear_after_conv: false 14 | gate_activation: identity 15 | num_blocks: 12 16 | dropout_rate: 0.1 17 | positional_dropout_rate: 0.1 18 | attention_dropout_rate: 0.1 19 | input_layer: conv2d 20 | layer_drop_rate: 0.0 21 | linear_units: 1024 22 | positionwise_layer_type: linear 23 | use_ffn: true 24 | macaron_ffn: true 25 | merge_conv_kernel: 31 26 | interctc_layer_idx: [6,9,12] 27 | interctc_use_conditioning: true 28 | 29 | # # decoder related 30 | decoder: unimodal_transformer 31 | decoder_conf: 32 | attention_heads: 4 33 | linear_units: 2048 34 | num_blocks: 6 35 | dropout_rate: 0.1 36 | positional_dropout_rate: 0.1 37 | interctc_layer_idx: [2,4] 38 | interctc_use_conditioning: true 39 | 40 | # hybrid CTC/attention 41 | model_conf: 42 | 
ctc_weight: 1 43 | interctc_weight_enc: 0.3 44 | interctc_weight_dec: 0.2 45 | lsm_weight: 0.1 # label smoothing option 46 | length_normalized_loss: false 47 | 48 | # minibatch related 49 | batch_type: numel 50 | batch_bins: 40000000 51 | 52 | # optimization related 53 | accum_grad: 1 54 | grad_clip: 5 55 | max_epoch: 70 56 | best_model_criterion: 57 | - - valid 58 | - cer 59 | - min 60 | keep_nbest_models: 10 61 | 62 | optim: adam 63 | optim_conf: 64 | lr: 0.001 65 | weight_decay: 0.000001 66 | scheduler: warmuplr 67 | scheduler_conf: 68 | warmup_steps: 35000 69 | 70 | num_workers: 4 # num of workers of data loader 71 | use_amp: true # automatic mixed precision 72 | unused_parameters: false # set as true if some params are unused in DDP 73 | 74 | specaug: specaug 75 | specaug_conf: 76 | apply_time_warp: true 77 | time_warp_window: 5 78 | time_warp_mode: bicubic 79 | apply_freq_mask: true 80 | freq_mask_width_range: 81 | - 0 82 | - 27 83 | num_freq_mask: 2 84 | apply_time_mask: true 85 | time_mask_width_ratio_range: 86 | - 0. 87 | - 0.05 88 | num_time_mask: 10 89 | -------------------------------------------------------------------------------- /egs2/aishell/umaconf/train_asr_uma_conformer_condition.yaml: -------------------------------------------------------------------------------- 1 | encoder: conformer 2 | encoder_conf: 3 | # conformer encoder 4 | output_size: 256 # dimension of attention 5 | attention_heads: 4 6 | linear_units: 2048 # the number of units of position-wise feed forward 7 | num_blocks: 12 # the number of encoder blocks 8 | dropout_rate: 0.1 9 | positional_dropout_rate: 0.1 10 | attention_dropout_rate: 0.0 11 | input_layer: conv2d # encoder architecture type 12 | normalize_before: true 13 | rel_pos_type: latest 14 | pos_enc_layer_type: rel_pos 15 | selfattention_layer_type: rel_selfattn 16 | activation_type: swish 17 | macaron_style: true 18 | use_cnn_module: true 19 | cnn_module_kernel: 31 20 | interctc_layer_idx: [6,9,12] 21 | interctc_use_conditioning: true 22 | 23 | # decoder related 24 | decoder: unimodal_transformer 25 | decoder_conf: 26 | attention_heads: 4 27 | linear_units: 2048 28 | num_blocks: 6 29 | dropout_rate: 0.1 30 | positional_dropout_rate: 0.1 31 | interctc_layer_idx: [2,4] 32 | interctc_use_conditioning: true 33 | 34 | # hybrid CTC/attention 35 | model_conf: 36 | ctc_weight: 1 37 | interctc_weight_enc: 0.3 38 | interctc_weight_dec: 0.2 39 | lsm_weight: 0.1 # label smoothing option 40 | length_normalized_loss: false 41 | 42 | # minibatch related 43 | batch_type: numel 44 | batch_bins: 25000000 45 | 46 | # optimization related 47 | accum_grad: 1 48 | grad_clip: 5 49 | # patience: 3 50 | max_epoch: 60 51 | val_scheduler_criterion: 52 | - valid 53 | - loss 54 | best_model_criterion: 55 | - - valid 56 | - cer 57 | - min 58 | keep_nbest_models: 10 59 | 60 | # NoamLR is deprecated. Use WarmupLR.
61 | # The following is equivalent setting for NoamLR: 62 | optim: adam 63 | optim_conf: 64 | lr: 0.001 65 | weight_decay: 0.000001 66 | scheduler: warmuplr # pytorch v1.1.0+ required 67 | scheduler_conf: 68 | warmup_steps: 35000 69 | 70 | num_workers: 4 # num of workers of data loader 71 | use_amp: true # automatic mixed precision 72 | unused_parameters: false # set as true if some params are unused in DDP 73 | 74 | specaug: specaug 75 | specaug_conf: 76 | apply_time_warp: true 77 | time_warp_window: 5 78 | time_warp_mode: bicubic 79 | apply_freq_mask: true 80 | freq_mask_width_range: 81 | - 0 82 | - 27 83 | num_freq_mask: 2 84 | apply_time_mask: true 85 | time_mask_width_ratio_range: 86 | - 0. 87 | - 0.05 88 | num_time_mask: 10 -------------------------------------------------------------------------------- /espnet2/asr/uma.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Author: FnoY fangying@westlake.edu.cn 4 | LastEditTime: 2024-10-09 14:49:17 5 | FilePath: \UMA-ASR\espnet2\asr\uma.py 6 | Notes: If the feature dimension changes from 256 to 512, just modify 'output_size: int = 256' to 'output_size: int = 512'; 7 | If you want to use the early termination during inference, just set 'self.EarlyTermination = True'. 8 | ''' 9 | # """Unimodal aggregation definition.""" 10 | import logging 11 | from typing import Optional, Tuple 12 | import torch 13 | from typeguard import check_argument_types 14 | 15 | 16 | class UMA(torch.nn.Module): 17 | """UMA module. 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | input_size: int = 512, 24 | output_size: int = 256, 25 | ): 26 | assert check_argument_types() 27 | super().__init__() 28 | self._output_size = output_size 29 | input_size = output_size 30 | 31 | self.linear_sigmoid = torch.nn.Sequential( 32 | torch.nn.Linear(input_size, 1), 33 | torch.nn.Sigmoid(), 34 | ) 35 | 36 | self.EarlyTermination = False 37 | 38 | def output_size(self) -> int: 39 | return self._output_size 40 | 41 | def forward( 42 | self, 43 | xs_pad: torch.Tensor, 44 | olens: torch.Tensor, 45 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 46 | """Calculate forward propagation. 47 | 48 | Args: 49 | xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). 50 | olens (torch.Tensor): Input length (#batch). 51 | prev_states (torch.Tensor): Not to be used now. 52 | Returns: 53 | torch.Tensor: Output tensor (#batch, I, output_size). 54 | torch.Tensor: Output length (#batch). 55 | torch.Tensor: Not to be used now. 
56 | """ 57 | 58 | batch, length, _ = xs_pad.size() 59 | # Use Linear-Sigmoid to generate unimodal aggregation weights 60 | # uma_weights: (#batch, L, 1) 61 | uma_weights = self.linear_sigmoid(xs_pad) 62 | 63 | # Unimodal Detection 64 | scalar_before = uma_weights[:,:-1,:].detach() # (#batch, L-1, 1) 65 | scalar_after = uma_weights[:,1:,:].detach() # (#batch, L-1, 1) 66 | scalar_before = torch.nn.functional.pad(scalar_before,(0,0,1,0)) # (#batch, L, 1) 67 | scalar_after = torch.nn.functional.pad(scalar_after,(0,0,0,1)) # (#batch, L, 1) 68 | 69 | mask = (uma_weights.lt(scalar_before)) & (uma_weights.lt(scalar_after)) # bool tensor (#batch, L, 1) 70 | 71 | if not self.training and self.EarlyTermination: 72 | mask2 = (uma_weights.gt(scalar_before)) & (uma_weights.gt(scalar_after)) # bool tensor (#batch, L, 1) 73 | mask = mask | mask2 74 | 75 | mask = mask.reshape(uma_weights.shape[0], -1) # bool tensor (#batch, L) 76 | mask[:,0] = True 77 | # mask.nonzero() is [[0,0],[0,3],[0,7],...,[1,0],[1,2],...,[2,0],[2,4],...,[#batch-1,0],...] 78 | # mask.nonzero() : (K,2); K is the total number of valleys in this batch 79 | batch_index = mask.nonzero()[:,0] # (k,1); [0,0,0,...,1,1,...,2,2,...,#batch-1,...] 80 | valley_index_start = mask.nonzero()[:,1] # (k,1); [0,3,7,...,0,2,...,0,4,...,0,...] 81 | mask[:,0] = False 82 | mask[:,-1] = True 83 | valley_index_end = mask.nonzero()[:,1] + 2 84 | # (k,1); [5,9,...,4,...,6,...] 85 | valley_index_end = torch.where(valley_index_end > (length) * torch.ones_like(valley_index_end), 86 | (length) * torch.ones_like(valley_index_end), valley_index_end) 87 | 88 | _,counts = torch.unique(batch_index, return_counts = True) # (#batch, 1); the number of valleys in each sample 89 | max_counts = (torch.max(counts)).item() 90 | 91 | utri_mat1 = torch.tril(torch.ones(max_counts+1,max_counts),-1).to(xs_pad.device) 92 | batch_index_mask = utri_mat1[counts] 93 | batch_index_mask = batch_index_mask.reshape(-1,1) 94 | batch_index_mask = batch_index_mask.nonzero()[:, 0] 95 | 96 | valleys = torch.zeros(batch * max_counts, 2).type_as(valley_index_start) 97 | valleys[batch_index_mask] = torch.cat((valley_index_start.unsqueeze(1), valley_index_end.unsqueeze(1)),1) 98 | # logging.info(str(valleys)) 99 | 100 | # utri_mat = torch.tril(torch.cuda.FloatTensor(length+1,length).fill_(1),-1) 101 | utri_mat = torch.tril(torch.ones(length+1,length),-1).to(xs_pad.device) 102 | output_mask = (utri_mat[valleys[:,1]]-utri_mat[valleys[:,0]]).reshape(batch, max_counts, length) 103 | output_mask = output_mask.detach() 104 | 105 | # Aggregation 106 | alpha_h = torch.mul(uma_weights, xs_pad) 107 | xs_pad = torch.bmm(output_mask, alpha_h) / torch.bmm(output_mask, uma_weights).clamp_(1e-6) 108 | 109 | # olens = (olens / olens[0] * xs_pad.shape[1]).type_as(olens) 110 | olens = counts 111 | 112 | # return xs_pad, olens, uma_weights 113 | return xs_pad, olens, None 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 7 | # UMA-ASR 8 | This repository is the official implementation of unimodal aggregation (UMA) for automaticspeech recognition (ASR). 9 | 10 | It consists of two works: 11 | 1. for non-autoregressive offline ASR: ["Unimodal Aggregation for CTC-based Speech Recognition" (ICASSP 2024)](https://ieeexplore.ieee.org/abstract/document/10448248) 12 | 2. 
for streaming ASR: ["Mamba for Streaming ASR Combined with Unimodal Aggregation" (submitted to ICASSP 2025)](https://arxiv.org/abs/2410.00070) 13 | 14 |
15 |

16 | version 17 | version 18 | python 19 |
20 | 21 | [Poster :star_struck:](https://sigport.org/sites/default/files/docs/fangying_UMA_poster4.0.pdf) **|** [Issues :sweat_smile:](https://github.com/Audio-WestlakeU/UMA-ASR/issues) 22 | **|** [Lab :hear_no_evil:](https://github.com/Audio-WestlakeU) **|** [Contact :kissing_heart:](fangying@westlake.edu.cn) 23 | 24 | ## Introduction 25 | 26 | ### For Non-autoregressive Offline ASR 27 | A unimodal aggregation (UMA) is proposed to segment and integrate the feature frames that belong to the same text token, and thus to learn better feature representations for text tokens. The frame-wise features and weights are both derived from an encoder. Then, the feature frames with unimodal weights are integrated and further processed by a decoder. Connectionist temporal classification (CTC) loss is applied for training. Moreover, by integrating self-conditioned CTC into the proposed framework, the performance can be further noticeably improved. 28 | 29 |
30 | The proposed UMA model 31 |
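
The following sketch is an added illustration rather than repository code; it assumes only the `UMA` module defined in `espnet2/asr/uma.py` of this repo. It shows the aggregation interface: frame-level encoder features go in, token-level features and per-utterance token counts come out.

```python
# Minimal sketch, assuming the UMA module from espnet2/asr/uma.py.
import torch
from espnet2.asr.uma import UMA

uma = UMA(output_size=256)              # aggregation weights come from a Linear + Sigmoid
encoder_out = torch.randn(2, 100, 256)  # (#batch, L, D) frame-level encoder output
encoder_lens = torch.tensor([100, 80])

aggregated, n_tokens, _ = uma(encoder_out, encoder_lens)
# aggregated: (#batch, I, 256), where I <= L is the number of detected unimodal
# segments (ideally one per text token); n_tokens: (#batch,) segment counts
print(aggregated.shape, n_tokens)
```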
32 | 33 | ### For Streaming ASR 34 | Mamba, a recently proposed state space model, has demonstrated the ability to match or surpass Transformers in various tasks while benefiting from a linear complexity advantage. We explore the efficiency of Mamba encoder for streaming ASR and propose an associated lookahead mechanism for leveraging controllable future information. Additionally, a streaming-style unimodal aggregation (UMA) method is 35 | implemented, which automatically detects token activity and streamingly triggers token output, and meanwhile aggregates feature frames for better learning token representation. Based on UMA, an early termination (ET) method is proposed to further reduce recognition latency. 36 | 37 |
38 | The proposed Mamba-UMA model 39 |
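
As a companion sketch (again an added example, not repository code), the early termination used for streaming can be toggled through the `EarlyTermination` flag documented in the header of `espnet2/asr/uma.py`: in eval mode the module then also splits at weight peaks, so a token can be emitted as soon as its unimodal weight turns downward.

```python
# Sketch, assuming the UMA module from espnet2/asr/uma.py.
import torch
from espnet2.asr.uma import UMA

uma = UMA(output_size=256)
uma.EarlyTermination = True   # switch documented in the uma.py file header
uma.eval()                    # the ET branch is only active outside training

with torch.no_grad():
    chunk = torch.randn(1, 60, 256)  # (#batch, L, D) chunk of encoder features
    feats, n_tokens, _ = uma(chunk, torch.tensor([60]))
```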
40 | 41 | 42 | ## Get started 43 | 1. The proposed method is implemented using [ESPnet2](https://github.com/espnet/espnet). So please make sure you have [installed ESPnet](https://espnet.github.io/espnet/installation.html#) successfully. 44 | 2. Roll back [espnet](https://github.com/espnet/espnet/tree/v.202304) to the specified version as follows: 45 | ``` 46 | git checkout v.202304 47 | ``` 48 | 3. Clone the UMA-ASR code by: 49 | ``` 50 | git clone https://github.com/Audio-WestlakeU/UMA-ASR 51 | ``` 52 | 4. Copy the configurations of the recipes in the [egs2](https://github.com/Audio-WestlakeU/UMA-ASR/tree/main/egs2) folder to the corresponding directory in "espnet/egs2/". At present, experiments have only been conducted on the AISHELL-1, AISHELL-2, and HKUST datasets. If you want to experiment on other Chinese datasets, you can refer to these configurations. 53 | 5. Copy the files in the [espnet2](https://github.com/Audio-WestlakeU/UMA-ASR/tree/main/espnet2) folder to the corresponding folder in "espnet/espnet2", and check that the path in each file-header comment matches your local path. 54 | 6. To experiment, follow [ESPnet's recipe steps](https://espnet.github.io/espnet/espnet2_tutorial.html#recipes-using-espnet2). You can run the UMA method by simply using our **run_unimodal.sh** in place of **run.sh** on the command line. For example: 55 | ``` 56 | ./run_unimodal.sh --stage 10 --stop_stage 13 57 | ``` 58 | Make sure the bash scripts have executable permissions: 59 | ``` 60 | chmod +x asr_unimodal.sh 61 | chmod +x run_unimodal.sh 62 | ``` 63 | 64 | ## Citation 65 | You can cite these papers as: 66 | 67 | ``` 68 | @inproceedings{fang2024unimodal, 69 | title={Unimodal aggregation for CTC-based speech recognition}, 70 | author={Fang, Ying and Li, Xiaofei}, 71 | booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 72 | pages={10591--10595}, 73 | year={2024}, 74 | organization={IEEE} 75 | } 76 | 77 | @article{fang2024mambauma, 78 | title={Mamba for Streaming ASR Combined with Unimodal Aggregation}, 79 | author={Ying Fang and Xiaofei Li}, 80 | journal={arXiv preprint arXiv:2410.00070}, 81 | year={2024} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /espnet2/asr/mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao.
2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, 
:] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- 
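
As a hedged sanity check (an added example, not repository code), the fused Triton kernel and the PyTorch reference above can be compared on random inputs; this assumes a CUDA device and the triton==2.1.0 requirement stated in the file header.

```python
# Equivalence check between the Triton kernel and the PyTorch reference.
import torch
from espnet2.asr.mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update,
    selective_state_update_ref,
)

torch.manual_seed(0)
batch, dim, dstate = 2, 64, 16
device = "cuda"  # the Triton kernel runs on GPU only

state = torch.randn(batch, dim, dstate, device=device)
state_ref = state.clone()  # the update is in-place, so keep a copy
x, dt, z = (torch.randn(batch, dim, device=device) for _ in range(3))
A = -torch.rand(dim, dstate, device=device)  # negative entries keep the SSM stable
B, C = (torch.randn(batch, dstate, device=device) for _ in range(2))
D = torch.randn(dim, device=device)
dt_bias = torch.randn(dim, device=device)

out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
                             dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z,
                                     dt_bias=dt_bias, dt_softplus=True)
# Both the returned output and the in-place state update should agree.
assert torch.allclose(out, out_ref, atol=1e-4)
assert torch.allclose(state, state_ref, atol=1e-4)
```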
/espnet2/asr/decoder/unimodal_attention_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Shigeki Karita 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | ''' 4 | Author: FnoY fangying@westlake.edu.cn 5 | LastEditTime: 2023-09-15 14:06:18 6 | FilePath: /espnet/espnet2/asr/decoder/unimodal_attention_decoder.py 7 | ''' 8 | 9 | """UMA Decoder definition.""" 10 | from typing import Any, List, Sequence, Tuple 11 | 12 | import torch 13 | from typeguard import check_argument_types 14 | 15 | from espnet2.asr.decoder.abs_decoder import AbsDecoder 16 | from espnet.nets.pytorch_backend.nets_utils import make_pad_mask 17 | from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention 18 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 19 | from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer 20 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 21 | from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( 22 | Conv1dLinear, 23 | MultiLayeredConv1d, 24 | ) 25 | from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( 26 | PositionwiseFeedForward, 27 | ) 28 | from espnet.nets.pytorch_backend.transformer.repeat import repeat 29 | from espnet.nets.pytorch_backend.transformer.subsampling import ( 30 | Conv2dSubsampling, 31 | Conv2dSubsampling1, 32 | Conv2dSubsampling2, 33 | Conv2dSubsampling6, 34 | Conv2dSubsampling8, 35 | TooShortUttError, 36 | check_short_utt, 37 | ) 38 | 39 | from espnet2.asr.ctc import CTC 40 | 41 | 42 | class UnimodalAttentionDecoder(AbsDecoder): 43 | """Transformer encoder module. 44 | 45 | Args: 46 | input_size: input dim 47 | output_size: dimension of attention 48 | attention_heads: the number of heads of multi head attention 49 | linear_units: the number of units of position-wise feed forward 50 | num_blocks: the number of decoder blocks 51 | dropout_rate: dropout rate 52 | attention_dropout_rate: dropout rate in attention 53 | positional_dropout_rate: dropout rate after adding positional encoding 54 | input_layer: input layer type 55 | pos_enc_class: PositionalEncoding or ScaledPositionalEncoding 56 | normalize_before: whether to use layer_norm before the first block 57 | concat_after: whether to concat attention layer's input and output 58 | if True, additional linear will be applied. 59 | i.e. x -> x + linear(concat(x, att(x))) 60 | if False, no additional linear will be applied. 61 | i.e. 
x -> x + att(x) 62 | positionwise_layer_type: "linear", "conv1d", or "conv1d-linear" 63 | positionwise_conv_kernel_size: kernel size of positionwise conv1d layer 64 | padding_idx: padding_idx for input_layer=embed 65 | """ 66 | 67 | def __init__( 68 | self, 69 | vocab_size: int, 70 | encoder_output_size: int, 71 | output_size: int = 256, 72 | attention_heads: int = 4, 73 | linear_units: int = 2048, 74 | num_blocks: int = 6, 75 | dropout_rate: float = 0.1, 76 | positional_dropout_rate: float = 0.1, 77 | attention_dropout_rate: float = 0.0, 78 | pos_enc_class=PositionalEncoding, 79 | normalize_before: bool = True, 80 | concat_after: bool = False, 81 | positionwise_layer_type: str = "linear", 82 | positionwise_conv_kernel_size: int = 1, 83 | padding_idx: int = -1, 84 | interctc_layer_idx: List[int] = [], 85 | interctc_use_conditioning: bool = False, 86 | ): 87 | assert check_argument_types() 88 | super().__init__() 89 | output_size = encoder_output_size 90 | self._output_size = output_size 91 | 92 | self.embed = torch.nn.Sequential( 93 | torch.nn.Linear(encoder_output_size, output_size), 94 | torch.nn.LayerNorm(output_size), 95 | torch.nn.Dropout(dropout_rate), 96 | torch.nn.ReLU(), 97 | pos_enc_class(output_size, positional_dropout_rate), 98 | ) 99 | 100 | self.normalize_before = normalize_before 101 | if positionwise_layer_type == "linear": 102 | positionwise_layer = PositionwiseFeedForward 103 | positionwise_layer_args = ( 104 | output_size, 105 | linear_units, 106 | dropout_rate, 107 | ) 108 | elif positionwise_layer_type == "conv1d": 109 | positionwise_layer = MultiLayeredConv1d 110 | positionwise_layer_args = ( 111 | output_size, 112 | linear_units, 113 | positionwise_conv_kernel_size, 114 | dropout_rate, 115 | ) 116 | elif positionwise_layer_type == "conv1d-linear": 117 | positionwise_layer = Conv1dLinear 118 | positionwise_layer_args = ( 119 | output_size, 120 | linear_units, 121 | positionwise_conv_kernel_size, 122 | dropout_rate, 123 | ) 124 | else: 125 | raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.") 126 | self.encoders = repeat( 127 | num_blocks, 128 | lambda lnum: EncoderLayer( 129 | output_size, 130 | MultiHeadedAttention( 131 | attention_heads, output_size, attention_dropout_rate 132 | ), 133 | positionwise_layer(*positionwise_layer_args), 134 | dropout_rate, 135 | normalize_before, 136 | concat_after, 137 | ), 138 | ) 139 | if self.normalize_before: 140 | self.after_norm = LayerNorm(output_size) 141 | 142 | self.interctc_layer_idx = interctc_layer_idx 143 | if len(interctc_layer_idx) > 0: 144 | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks 145 | self.interctc_use_conditioning = interctc_use_conditioning 146 | self.conditioning_layer = None 147 | 148 | 149 | def output_size(self) -> int: 150 | return self._output_size 151 | 152 | def forward( 153 | self, 154 | hs_pad: torch.Tensor, 155 | hlens: torch.Tensor, 156 | ys_in_pad: torch.Tensor, 157 | ys_in_lens: torch.Tensor, 158 | ctc: CTC = None, 159 | ) -> Tuple[torch.Tensor, torch.Tensor]: 160 | """Forward pass over the aggregated token-level features. 161 | 162 | Args: 163 | hs_pad: input tensor (B, L, D) 164 | hlens: input length (B) 165 | ctc: CTC module for intermediate conditioning (used when interctc_use_conditioning is set).
166 | Returns: 167 | position embedded tensor and mask 168 | """ 169 | 170 | masks = (~make_pad_mask(hlens)[:, None, :]).to(hs_pad.device) 171 | 172 | hs_pad = self.embed(hs_pad) 173 | 174 | 175 | intermediate_outs = [] 176 | if len(self.interctc_layer_idx) == 0: 177 | hs_pad, masks = self.encoders(hs_pad, masks) 178 | else: 179 | for layer_idx, encoder_layer in enumerate(self.encoders): 180 | hs_pad, masks = encoder_layer(hs_pad, masks) 181 | 182 | if layer_idx + 1 in self.interctc_layer_idx: 183 | encoder_out = hs_pad 184 | if isinstance(encoder_out, tuple): 185 | encoder_out = encoder_out[0] 186 | 187 | # intermediate outputs are also normalized 188 | if self.normalize_before: 189 | encoder_out = self.after_norm(encoder_out) 190 | 191 | intermediate_outs.append((layer_idx + 1, encoder_out)) 192 | 193 | if self.interctc_use_conditioning: 194 | ctc_out = ctc.softmax(encoder_out) 195 | 196 | if isinstance(hs_pad, tuple): 197 | x, pos_emb = hs_pad 198 | x = x + self.conditioning_layer(ctc_out) 199 | hs_pad = (x, pos_emb) 200 | else: 201 | hs_pad = hs_pad + self.conditioning_layer(ctc_out) 202 | 203 | if isinstance(hs_pad, tuple): 204 | hs_pad = hs_pad[0] 205 | 206 | if self.normalize_before: 207 | hs_pad = self.after_norm(hs_pad) 208 | 209 | olens = masks.squeeze(1).sum(1) 210 | 211 | if len(intermediate_outs) > 0: 212 | return (hs_pad, intermediate_outs), olens 213 | 214 | return hs_pad, olens 215 | -------------------------------------------------------------------------------- /espnet2/asr/mamba_ssm/modules/mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import math 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from einops import rearrange, repeat 12 | 13 | from espnet2.asr.mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 14 | 15 | try: 16 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 17 | except ImportError: 18 | causal_conv1d_fn, causal_conv1d_update = None, None 19 | # causal_conv1d_fn, causal_conv1d_update = None, None 20 | 21 | try: 22 | from espnet2.asr.mamba_ssm.ops.triton.selective_state_update import selective_state_update 23 | except ImportError: 24 | selective_state_update = None 25 | 26 | try: 27 | from espnet2.asr.mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 28 | except ImportError: 29 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 30 | 31 | 32 | class Mamba(nn.Module): 33 | def __init__( 34 | self, 35 | d_model, 36 | d_state=32, 37 | d_conv=4, 38 | expand=2, 39 | dt_rank="auto", 40 | dt_min=0.001, 41 | dt_max=0.1, 42 | dt_init="random", 43 | dt_scale=1.0, 44 | dt_init_floor=1e-4, 45 | conv_bias=True, 46 | bias=False, 47 | use_fast_path=False, # Fused kernel options 48 | layer_idx=None, 49 | device=None, 50 | dtype=None, 51 | ): 52 | factory_kwargs = {"device": device, "dtype": dtype} 53 | super().__init__() 54 | self.d_model = d_model 55 | self.d_state = d_state 56 | self.d_conv = d_conv 57 | self.expand = expand 58 | self.d_inner = int(self.expand * self.d_model) 59 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 60 | self.use_fast_path = use_fast_path 61 | self.layer_idx = layer_idx 62 | 63 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 64 | 65 | self.conv1d = nn.Conv1d( 66 | in_channels=self.d_inner, 
67 | out_channels=self.d_inner, 68 | bias=conv_bias, 69 | kernel_size=d_conv, 70 | groups=self.d_inner, 71 | padding=d_conv - 1, 72 | **factory_kwargs, 73 | ) 74 | 75 | self.activation = "silu" 76 | self.act = nn.SiLU() 77 | 78 | self.x_proj = nn.Linear( 79 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 80 | ) 81 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 82 | 83 | # Initialize special dt projection to preserve variance at initialization 84 | dt_init_std = self.dt_rank**-0.5 * dt_scale 85 | if dt_init == "constant": 86 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 87 | elif dt_init == "random": 88 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 89 | else: 90 | raise NotImplementedError 91 | 92 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 93 | dt = torch.exp( 94 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 95 | + math.log(dt_min) 96 | ).clamp(min=dt_init_floor) 97 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 98 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 99 | with torch.no_grad(): 100 | self.dt_proj.bias.copy_(inv_dt) 101 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 102 | self.dt_proj.bias._no_reinit = True 103 | 104 | # S4D real initialization 105 | A = repeat( 106 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 107 | "n -> d n", 108 | d=self.d_inner, 109 | ).contiguous() 110 | A_log = torch.log(A) # Keep A_log in fp32 111 | self.A_log = nn.Parameter(A_log) 112 | self.A_log._no_weight_decay = True 113 | 114 | # D "skip" parameter 115 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 116 | self.D._no_weight_decay = True 117 | 118 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 119 | 120 | def forward(self, hidden_states, inference_params=None): 121 | """ 122 | hidden_states: (B, L, D) 123 | Returns: same shape as hidden_states 124 | """ 125 | batch, seqlen, dim = hidden_states.shape 126 | 127 | conv_state, ssm_state = None, None 128 | if inference_params is not None: 129 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 130 | if inference_params.seqlen_offset > 0: 131 | # The states are updated inplace 132 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 133 | return out 134 | 135 | # We do matmul and transpose BLH -> HBL at the same time 136 | xz = rearrange( 137 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 138 | "d (b l) -> b d l", 139 | l=seqlen, 140 | ) 141 | if self.in_proj.bias is not None: 142 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 143 | 144 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 145 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 146 | if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states 147 | out = mamba_inner_fn( 148 | xz, 149 | self.conv1d.weight, 150 | self.conv1d.bias, 151 | self.x_proj.weight, 152 | self.dt_proj.weight, 153 | self.out_proj.weight, 154 | self.out_proj.bias, 155 | A, 156 | None, # input-dependent B 157 | None, # input-dependent C 158 | self.D.float(), 159 | delta_bias=self.dt_proj.bias.float(), 160 | delta_softplus=True, 161 | ) 162 | else: 163 | x, z = xz.chunk(2, dim=1) 164 | # Compute short 
convolution 165 | if conv_state is not None: 166 | # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv 167 | # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 168 | conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) 169 | if causal_conv1d_fn is None: 170 | x = self.act(self.conv1d(x)[..., :seqlen]) 171 | else: 172 | assert self.activation in ["silu", "swish"] 173 | x = causal_conv1d_fn( 174 | x=x, 175 | weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), 176 | bias=self.conv1d.bias, 177 | activation=self.activation, 178 | ) 179 | 180 | # We're careful here about the layout, to avoid extra transposes. 181 | # We want dt to have d as the slowest moving dimension 182 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 183 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 184 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 185 | dt = self.dt_proj.weight @ dt.t() 186 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 187 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 188 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 189 | assert self.activation in ["silu", "swish"] 190 | y = selective_scan_fn( 191 | x, 192 | dt, 193 | A, 194 | B, 195 | C, 196 | self.D.float(), 197 | z=z, 198 | delta_bias=self.dt_proj.bias.float(), 199 | delta_softplus=True, 200 | return_last_state=ssm_state is not None, 201 | ) 202 | if ssm_state is not None: 203 | y, last_state = y 204 | ssm_state.copy_(last_state) 205 | y = rearrange(y, "b d l -> b l d") 206 | out = self.out_proj(y) 207 | return out 208 | 209 | def step(self, hidden_states, conv_state, ssm_state): 210 | dtype = hidden_states.dtype 211 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 212 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 213 | x, z = xz.chunk(2, dim=-1) # (B D) 214 | 215 | # Conv step 216 | if causal_conv1d_update is None: 217 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 218 | conv_state[:, :, -1] = x 219 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 220 | if self.conv1d.bias is not None: 221 | x = x + self.conv1d.bias 222 | x = self.act(x).to(dtype=dtype) 223 | else: 224 | x = causal_conv1d_update( 225 | x, 226 | conv_state, 227 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 228 | self.conv1d.bias, 229 | self.activation, 230 | ) 231 | 232 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 233 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 234 | # Don't add dt_bias here 235 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 236 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 237 | 238 | # SSM step 239 | if selective_state_update is None: 240 | # Discretize A and B 241 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 242 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 243 | dB = torch.einsum("bd,bn->bdn", dt, B) 244 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 245 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 246 | y = y + self.D.to(dtype) * x 247 | y = y * self.act(z) # (B D) 248 | else: 249 | y = selective_state_update( 250 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 251 | ) 252 | 253 | out = self.out_proj(y) 254 | return 
out.unsqueeze(1), conv_state, ssm_state 255 | 256 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 257 | device = self.out_proj.weight.device 258 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 259 | conv_state = torch.zeros( 260 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 261 | ) 262 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 263 | # ssm_dtype = torch.float32 264 | ssm_state = torch.zeros( 265 | batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 266 | ) 267 | return conv_state, ssm_state 268 | 269 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 270 | assert self.layer_idx is not None 271 | if self.layer_idx not in inference_params.key_value_memory_dict: 272 | batch_shape = (batch_size,) 273 | conv_state = torch.zeros( 274 | batch_size, 275 | self.d_model * self.expand, 276 | self.d_conv, 277 | device=self.conv1d.weight.device, 278 | dtype=self.conv1d.weight.dtype, 279 | ) 280 | ssm_state = torch.zeros( 281 | batch_size, 282 | self.d_model * self.expand, 283 | self.d_state, 284 | device=self.dt_proj.weight.device, 285 | dtype=self.dt_proj.weight.dtype, 286 | # dtype=torch.float32, 287 | ) 288 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 289 | else: 290 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 291 | # TODO: What if batch size changes between generation, and we reuse the same states? 292 | if initialize_states: 293 | conv_state.zero_() 294 | ssm_state.zero_() 295 | return conv_state, ssm_state 296 | 297 | 298 | class Block(nn.Module): 299 | def __init__( 300 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 301 | ): 302 | """ 303 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 304 | 305 | This Block has a slightly different structure compared to a regular 306 | prenorm Transformer block. 307 | The standard block is: LN -> MHA/MLP -> Add. 308 | [Ref: https://arxiv.org/abs/2002.04745] 309 | Here we have: Add -> LN -> Mixer, returning both 310 | the hidden_states (output of the mixer) and the residual. 311 | This is purely for performance reasons, as we can fuse add and LayerNorm. 312 | The residual needs to be provided (except for the very first block). 313 | """ 314 | super().__init__() 315 | self.residual_in_fp32 = residual_in_fp32 316 | self.fused_add_norm = fused_add_norm 317 | self.mixer = mixer_cls(dim) 318 | self.norm = norm_cls(dim) 319 | if self.fused_add_norm: 320 | assert RMSNorm is not None, "RMSNorm import fails" 321 | assert isinstance( 322 | self.norm, (nn.LayerNorm, RMSNorm) 323 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 324 | 325 | def forward( 326 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 327 | ): 328 | r"""Pass the input through the encoder layer. 329 | 330 | Args: 331 | hidden_states: the sequence to the encoder layer (required). 
332 | residual: hidden_states = Mixer(LN(residual)) 333 | """ 334 | if not self.fused_add_norm: 335 | residual = (hidden_states + residual) if residual is not None else hidden_states 336 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 337 | if self.residual_in_fp32: 338 | residual = residual.to(torch.float32) 339 | else: 340 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 341 | hidden_states, residual = fused_add_norm_fn( 342 | hidden_states, 343 | self.norm.weight, 344 | self.norm.bias, 345 | residual=residual, 346 | prenorm=True, 347 | residual_in_fp32=self.residual_in_fp32, 348 | eps=self.norm.eps, 349 | ) 350 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 351 | return hidden_states, residual 352 | 353 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 354 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 355 | -------------------------------------------------------------------------------- /espnet2/asr/encoder/mamba_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Shigeki Karita 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | from typeguard import check_argument_types 8 | 9 | from espnet2.asr.ctc import CTC 10 | from espnet.nets.pytorch_backend.nets_utils import make_pad_mask 11 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 12 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 13 | from espnet.nets.pytorch_backend.nets_utils import get_activation 14 | from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule 15 | 16 | from espnet.nets.pytorch_backend.transformer.repeat import repeat 17 | from espnet.nets.pytorch_backend.transformer.subsampling import ( 18 | Conv2dSubsampling, 19 | TooShortUttError, 20 | check_short_utt, 21 | ) 22 | 23 | import logging 24 | import math 25 | from functools import partial 26 | 27 | import torch 28 | import torch.nn as nn 29 | 30 | from espnet2.asr.mamba_ssm.modules.mamba_simple import Mamba, Block 31 | 32 | try: 33 | from espnet2.asr.mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 34 | except ImportError: 35 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 36 | 37 | from dataclasses import dataclass, field 38 | 39 | def create_block( 40 | d_model, 41 | ssm_cfg=None, 42 | norm_epsilon=1e-12, 43 | rms_norm=False, 44 | residual_in_fp32=False, 45 | fused_add_norm=False, 46 | layer_idx=None, 47 | device=None, 48 | dtype=None, 49 | ): 50 | if ssm_cfg is None: 51 | ssm_cfg = {} 52 | factory_kwargs = {"device": device, "dtype": dtype} 53 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 54 | norm_cls = partial( 55 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 56 | ) 57 | block = Block( 58 | d_model, 59 | mixer_cls, 60 | norm_cls=norm_cls, 61 | fused_add_norm=fused_add_norm, 62 | residual_in_fp32=residual_in_fp32, 63 | ) 64 | block.layer_idx = layer_idx 65 | return block 66 | 67 | 68 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 69 | def _init_weights( 70 | module, 71 | n_layer, 72 | initializer_range=0.02, # Now only used for embedding layer. 
73 | rescale_prenorm_residual=True, 74 | n_residuals_per_layer=1, # Change to 2 if we have MLP 75 | ): 76 | if isinstance(module, nn.Linear): 77 | if module.bias is not None: 78 | if not getattr(module.bias, "_no_reinit", False): 79 | nn.init.zeros_(module.bias) 80 | elif isinstance(module, nn.Embedding): 81 | nn.init.normal_(module.weight, std=initializer_range) 82 | 83 | if rescale_prenorm_residual: 84 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 85 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 86 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 87 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 88 | # 89 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 90 | for name, p in module.named_parameters(): 91 | if name in ["out_proj.weight", "fc2.weight"]: 92 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 93 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 94 | # We need to reinit p since this code could be called multiple times 95 | # Having just p *= scale would repeatedly scale it down 96 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 97 | with torch.no_grad(): 98 | p /= math.sqrt(n_residuals_per_layer * n_layer) 99 | 100 | 101 | 102 | 103 | class CausalConv2dSubsampling(nn.Module): 104 | def __init__(self, in_channels, out_channels, kernel_size, bias=True): 105 | super(CausalConv2dSubsampling, self).__init__() 106 | self.padding = (kernel_size[0] - 1, 0) 107 | 108 | self.subsample1 = nn.Sequential( 109 | nn.Conv2d( 110 | 1, 111 | out_channels, 112 | kernel_size=kernel_size, 113 | padding=self.padding, 114 | stride=2, 115 | bias=bias, 116 | ), 117 | nn.ReLU(), 118 | ) 119 | 120 | self.subsample2 = nn.Sequential( 121 | nn.Conv2d( 122 | out_channels, 123 | out_channels, 124 | kernel_size=kernel_size, 125 | padding=self.padding, 126 | stride=2, 127 | bias=bias, 128 | ), 129 | nn.ReLU(), 130 | ) 131 | 132 | self.out = torch.nn.Sequential( 133 | torch.nn.Linear(out_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels), 134 | torch.nn.Dropout(0.1), 135 | ) 136 | 137 | def forward(self, x, x_mask): 138 | x = x.unsqueeze(1) # (b, c, t, f) 139 | x = self.subsample1(x) 140 | if self.padding[0] != 0: 141 | x = x[:, :, :-self.padding[0], :] 142 | x = self.subsample2(x) 143 | if self.padding[0] != 0: 144 | x = x[:, :, :-self.padding[0], :] 145 | b, c, t, f = x.size() 146 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 147 | if x_mask is None: 148 | return x, None 149 | return x, x_mask[:, :, :-2:2][:, :, :-2:2] 150 | 151 | 152 | class MambaEncoder(nn.Module): 153 | """Transformer encoder module. 
154 | 155 | Args: 156 | 157 | """ 158 | 159 | def __init__( 160 | self, 161 | input_size: int, 162 | output_size: int = 256, 163 | num_blocks: int = 6, 164 | dropout_rate: float = 0.1, 165 | input_layer: Optional[str] = "conv2d", 166 | normalize_before: bool = False, 167 | ssm_cfg=None, 168 | norm_epsilon: float = 1e-12, 169 | rms_norm: bool = False, 170 | initializer_cfg=None, 171 | fused_add_norm=False, 172 | residual_in_fp32=False, 173 | device=None, 174 | dtype=None, 175 | lookahead_kernel: int = 0, 176 | right_context: int = 0, 177 | interctc_layer_idx: List[int] = [], 178 | interctc_use_conditioning: bool = False, 179 | ): 180 | assert check_argument_types() 181 | super().__init__() 182 | self._output_size = output_size 183 | factory_kwargs = {"device": device, "dtype": dtype} 184 | self.residual_in_fp32 = residual_in_fp32 185 | 186 | if input_layer == "causal_conv2d": 187 | self.embed = CausalConv2dSubsampling(input_size, output_size, (3, 3), bias=True) 188 | elif input_layer == "conv2d": 189 | self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) 190 | elif input_layer is None: 191 | if input_size == output_size: 192 | self.embed = None 193 | else: 194 | self.embed = torch.nn.Linear(input_size, output_size) 195 | else: 196 | raise ValueError("unknown input_layer: " + input_layer) 197 | 198 | 199 | self.normalize_before = normalize_before 200 | if self.normalize_before: 201 | self.norm_before_mamba = LayerNorm(output_size) 202 | d_model = output_size 203 | n_layer = num_blocks 204 | # We change the order of residual and layer norm: 205 | # Instead of LN -> Attn / MLP -> Add, we do: 206 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 207 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 208 | # This is for performance reason: we can fuse add + layer_norm. 
209 | self.fused_add_norm = fused_add_norm 210 | if self.fused_add_norm: 211 | if layer_norm_fn is None or rms_norm_fn is None: 212 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 213 | 214 | self.layers = nn.ModuleList( 215 | [ 216 | create_block( 217 | d_model, 218 | ssm_cfg=ssm_cfg, 219 | norm_epsilon=norm_epsilon, 220 | rms_norm=rms_norm, 221 | residual_in_fp32=residual_in_fp32, 222 | fused_add_norm=fused_add_norm, 223 | layer_idx=i, 224 | **factory_kwargs, 225 | ) 226 | for i in range(n_layer) 227 | ] 228 | ) 229 | 230 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 231 | d_model, eps=norm_epsilon, **factory_kwargs 232 | ) 233 | 234 | 235 | self.lookahead_kernel = lookahead_kernel 236 | self.right_context = right_context 237 | self.left_context = lookahead_kernel - 1 - self.right_context 238 | 239 | if lookahead_kernel > 0: 240 | self.lookahead_cnn = nn.Conv1d( 241 | output_size, 242 | output_size, 243 | lookahead_kernel, 244 | stride=1, 245 | padding=lookahead_kernel//2, 246 | bias=True, 247 | ) 248 | 249 | activation_type = "swish" 250 | self.activation = get_activation(activation_type) 251 | 252 | self.lookahead_norm = LayerNorm(output_size) 253 | self.dropout = torch.nn.Dropout(dropout_rate) 254 | 255 | 256 | if not hasattr(self, "lookahead_cnn"): 257 | self.encoder_out_embed = torch.nn.Sequential( 258 | torch.nn.Linear(output_size, output_size), 259 | get_activation('swish'), 260 | LayerNorm(output_size), 261 | torch.nn.Dropout(dropout_rate), 262 | ) 263 | 264 | self.interctc_layer_idx = interctc_layer_idx 265 | if len(interctc_layer_idx) > 0: 266 | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) <= num_blocks 267 | self.interctc_use_conditioning = interctc_use_conditioning 268 | self.conditioning_layer = None 269 | 270 | 271 | self.apply( 272 | partial( 273 | _init_weights, 274 | n_layer=n_layer, 275 | **(initializer_cfg if initializer_cfg is not None else {}), 276 | ) 277 | ) 278 | 279 | 280 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 281 | return { 282 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 283 | for i, layer in enumerate(self.layers) 284 | } 285 | 286 | def output_size(self) -> int: 287 | return self._output_size 288 | 289 | def forward( 290 | self, 291 | xs_pad: torch.Tensor, 292 | ilens: torch.Tensor, 293 | prev_states: torch.Tensor = None, 294 | ctc: CTC = None, 295 | inference_params=None, 296 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 297 | """Embed positions in tensor. 298 | 299 | Args: 300 | xs_pad: input tensor (B, L, D) 301 | ilens: input length (B) 302 | prev_states: Not to be used now. 
303 |         Returns:
304 |             encoded tensor (B, L, D), output lengths (B), and None
305 |         """
306 | 
307 |         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
308 | 
309 |         if self.embed is None:
310 |             xs_pad = xs_pad
311 |         elif (
312 |             isinstance(self.embed, Conv2dSubsampling)
313 |             or isinstance(self.embed, CausalConv2dSubsampling)
314 |         ):
315 |             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
316 |             if short_status:
317 |                 raise TooShortUttError(
318 |                     f"has {xs_pad.size(1)} frames and is too short for subsampling "
319 |                     + f"(it needs more than {limit_size} frames), return empty results",
320 |                     xs_pad.size(1),
321 |                     limit_size,
322 |                 )
323 |             xs_pad, masks = self.embed(xs_pad, masks)
324 |         else:
325 |             xs_pad = self.embed(xs_pad)
326 | 
327 |         if hasattr(self, "norm_before_mamba"):
328 |             xs_pad = self.norm_before_mamba(xs_pad)
329 | 
330 |         residual = None
331 | 
332 |         intermediate_outs = []
333 |         if len(self.interctc_layer_idx) == 0:
334 |             for layer_idx, layer in enumerate(self.layers):
335 |                 xs_pad, residual = layer(xs_pad, residual, inference_params=inference_params)
336 |                 if hasattr(self, "mamba_layer_dropout"):
337 |                     xs_pad = self.mamba_layer_dropout(xs_pad)
338 | 
339 |         else:
340 |             for layer_idx, layer in enumerate(self.layers):
341 |                 xs_pad, residual = layer(xs_pad, residual, inference_params=inference_params)
342 | 
343 |                 if layer_idx + 1 in self.interctc_layer_idx:
344 |                     encoder_out = xs_pad
345 |                     if isinstance(encoder_out, tuple):
346 |                         encoder_out = encoder_out[0]
347 | 
348 |                     if not self.fused_add_norm:
349 |                         residual = (encoder_out + residual) if residual is not None else encoder_out
350 |                         encoder_out = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
351 |                     else:
352 |                         # Set prenorm=False here since we don't need the residual
353 |                         fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
354 |                         encoder_out = fused_add_norm_fn(
355 |                             encoder_out,
356 |                             self.norm_f.weight,
357 |                             self.norm_f.bias,
358 |                             eps=self.norm_f.eps,
359 |                             residual=residual,
360 |                             prenorm=False,
361 |                             residual_in_fp32=self.residual_in_fp32,
362 |                         )
363 | 
                    # encoder_out_embed exists only when the lookahead CNN is disabled,
                    # so guard the call (same pattern as the output path below)
                    if hasattr(self, "encoder_out_embed"):
364 |                         encoder_out = self.encoder_out_embed(encoder_out)
365 | 
366 |                     intermediate_outs.append((layer_idx + 1, encoder_out))
367 | 
368 |                     if self.interctc_use_conditioning:
369 |                         ctc_out = ctc.softmax(encoder_out)
370 | 
371 |                         if isinstance(xs_pad, tuple):
372 |                             x, pos_emb = xs_pad
373 |                             x = x + self.conditioning_layer(ctc_out)
374 |                             xs_pad = (x, pos_emb)
375 |                         else:
376 |                             xs_pad = xs_pad + self.conditioning_layer(ctc_out)
377 | 
378 |         if isinstance(xs_pad, tuple):
379 |             xs_pad = xs_pad[0]
380 | 
381 |         if not self.fused_add_norm:
382 |             residual = (xs_pad + residual) if residual is not None else xs_pad
383 |             xs_pad = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
384 |         else:
385 |             # Set prenorm=False here since we don't need the residual
386 |             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
387 |             xs_pad = fused_add_norm_fn(
388 |                 xs_pad,
389 |                 self.norm_f.weight,
390 |                 self.norm_f.bias,
391 |                 eps=self.norm_f.eps,
392 |                 residual=residual,
393 |                 prenorm=False,
394 |                 residual_in_fp32=self.residual_in_fp32,
395 |             )
396 | 
397 |         if hasattr(self, "lookahead_cnn"):
398 |             xs_pad = self.lookahead_cnn(xs_pad.transpose(1, 2)).transpose(1, 2)
399 |             xs_pad = self.activation(xs_pad)
400 |             xs_pad = self.lookahead_norm(xs_pad)
401 |             xs_pad = self.dropout(xs_pad)
402 | 
403 |         if hasattr(self, "encoder_out_embed"):
404 |             xs_pad = self.encoder_out_embed(xs_pad)
405 | 
406 | 
407 |         olens = masks.squeeze(1).sum(1)
408 |         if len(intermediate_outs) > 0:
409 |             return (xs_pad, intermediate_outs), olens, None
410 | 
411 |         return xs_pad, olens, None
412 | 
413 | 
414 | 
--------------------------------------------------------------------------------
/espnet2/asr/mamba_ssm/utils/generation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 | import gc
3 | import time
4 | from collections import namedtuple
5 | from dataclasses import dataclass, field
6 | from functools import partial
7 | from typing import Callable, Optional, Sequence, Union
8 | 
9 | import torch
10 | import torch.nn.functional as F
11 | from einops import rearrange, repeat
12 | from torch import Tensor
13 | from torch.profiler import ProfilerActivity, profile, record_function
14 | from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15 | 
16 | 
17 | @dataclass
18 | class InferenceParams:
19 |     """Inference parameters that are passed to the main model in order
20 |     to efficiently calculate and store the context during inference."""
21 | 
22 |     max_seqlen: int
23 |     max_batch_size: int
24 |     seqlen_offset: int = 0
25 |     batch_size_offset: int = 0
26 |     key_value_memory_dict: dict = field(default_factory=dict)
27 |     lengths_per_sample: Optional[Tensor] = None
28 | 
29 |     def reset(self, max_seqlen, max_batch_size):
30 |         self.max_seqlen = max_seqlen
31 |         self.max_batch_size = max_batch_size
32 |         self.seqlen_offset = 0
33 |         if self.lengths_per_sample is not None:
34 |             self.lengths_per_sample.zero_()
35 | 
36 | 
37 | def modify_logits_for_min_p_filtering(logits, min_p):
38 |     """Set the logits for non-min_p values to -inf. Done in-place."""
39 |     if min_p <= 0.0 or min_p >= 1.0:
40 |         return
41 |     indices_to_remove = logits < min_p
42 |     logits.masked_fill_(indices_to_remove, float("-Inf"))
43 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45 | def modify_logits_for_top_k_filtering(logits, top_k):
46 |     """Set the logits for non-top-k values to -inf. Done in-place."""
47 |     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48 |     logits.masked_fill_(indices_to_remove, float("-Inf"))
49 | 
50 | 
51 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53 | def modify_logits_for_top_p_filtering(logits, top_p):
54 |     """Set the logits for non-top-p values to -inf. Done in-place."""
55 |     if top_p <= 0.0 or top_p >= 1.0:
56 |         return
57 |     # First sort and calculate cumulative sum of probabilities.
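    # Editor's illustration (comment only): with ascending probs [0.1, 0.2, 0.3, 0.4]
    # and top_p=0.9, the cumulative sums are [0.1, 0.3, 0.6, 1.0]; entries with
    # cumsum <= 1 - 0.9 = 0.1 (just the 0.1 token) get masked to -inf, keeping the
    # smallest set of tokens whose total probability reaches top_p.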
58 |     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59 |     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60 |     # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
61 |     sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62 |     # scatter sorted tensors to original indexing
63 |     indices_to_remove = sorted_indices_to_remove.scatter(
64 |         1, sorted_indices, sorted_indices_to_remove
65 |     )
66 |     logits.masked_fill_(indices_to_remove, float("-inf"))
67 | 
68 | 
69 | def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70 |     """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71 |     logits: (batch_size, vocab_size)
72 |     prev_output_tokens: (batch_size, seq_len)
73 |     """
74 |     if repetition_penalty == 1.0:
75 |         return logits
76 |     score = torch.gather(logits, 1, prev_output_tokens)
77 |     # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
78 |     score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79 |     logits.scatter_(1, prev_output_tokens, score)
80 |     return logits
81 | 
82 | 
83 | def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84 |     """Sample from top-k logits.
85 |     Arguments:
86 |         logits: Tensor of shape (batch_size, vocab_size)
87 |     """
88 |     if top_k == 1:  # Short-circuit for greedy decoding
89 |         return logits.argmax(dim=-1)
90 |     else:
91 |         if top_p > 0.0:
92 |             assert top_p <= 1.0, "top-p should be in (0, 1]."
93 |         if top_k > 0:
94 |             top_k = min(top_k, logits.size(-1))  # Safety check
95 |             logits_top, indices = torch.topk(logits, top_k, dim=-1)
96 |             if temperature != 1.0:
97 |                 logits_top /= temperature
98 |             modify_logits_for_top_p_filtering(logits_top, top_p)
99 |             return indices[
100 |                 torch.arange(indices.shape[0], device=indices.device),
101 |                 torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102 |             ]
103 |         else:
104 |             if min_p > 0.0:
105 |                 logits_top = logits.clone()
106 |                 max_prob = logits_top[..., 0].item()
107 |                 min_prob = max_prob * min_p
108 |                 # pass the scaled threshold, not the raw min_p ratio
109 |                 modify_logits_for_min_p_filtering(logits_top, min_prob)
110 |                 if temperature != 1.0:
111 |                     logits_top /= temperature
112 |                 return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
113 |             # Clone so that when we modify for top_p we don't change the original logits
114 |             logits_top = logits / temperature if temperature != 1.0 else logits.clone()
115 |             modify_logits_for_top_p_filtering(logits_top, top_p)
116 |             return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
117 |                 dim=-1
118 |             )
119 | 
120 | @torch.inference_mode()
121 | def decode(
122 |     input_ids,
123 |     model,
124 |     max_length,
125 |     top_k=1,
126 |     top_p=0.0,
127 |     min_p=0.0,
128 |     temperature=1.0,
129 |     repetition_penalty=1.0,
130 |     eos_token_id=None,
131 |     teacher_outputs=None,
132 |     vocab_size=None,
133 |     cg=False,
134 |     enable_timing=False,
135 |     streamer: Optional[TextStreamer] = None
136 | ):
137 |     """Decoding, either greedy or with top-k or top-p sampling.
138 |     If top-k = 0, don't limit the number of candidates (pure sampling).
139 |     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
140 |     then top-p.
141 |     We assume that all sequences in the same batch have the same length.
142 | 
143 |     Arguments:
144 |         input_ids: (batch, seq_len)
145 |         max_length: int
146 |         teacher_outputs (optional): (batch, seq_len). 
If provided, instead of sampling from the 147 | logits, the next token is taken from the teacher_outputs. Useful for testing. 148 | Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: 149 | sequences: (batch, max_length) 150 | scores: tuples of (batch, vocab_size) 151 | """ 152 | if streamer is not None: 153 | streamer.put(input_ids.cpu()) 154 | 155 | batch_size, seqlen_og = input_ids.shape 156 | teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 157 | if cg: 158 | if not hasattr(model, "_decoding_cache"): 159 | model._decoding_cache = None 160 | model._decoding_cache = update_graph_cache( 161 | model, 162 | model._decoding_cache, 163 | batch_size, 164 | seqlen_og, 165 | max_length, 166 | ) 167 | inference_params = model._decoding_cache.inference_params 168 | inference_params.reset(max_length, batch_size) 169 | else: 170 | inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) 171 | 172 | def get_logits(input_ids, inference_params): 173 | decoding = inference_params.seqlen_offset > 0 174 | if decoding: 175 | position_ids = torch.full( 176 | (batch_size, 1), 177 | inference_params.seqlen_offset, 178 | dtype=torch.long, 179 | device=input_ids.device, 180 | ) 181 | else: 182 | position_ids = None 183 | if not cg or not decoding: 184 | logits = model( 185 | input_ids, 186 | position_ids=position_ids, 187 | inference_params=inference_params, 188 | num_last_tokens=1, 189 | ).logits.squeeze(dim=1) 190 | else: 191 | logits = model._decoding_cache.run( 192 | input_ids, position_ids, inference_params.seqlen_offset 193 | ).squeeze(dim=1) 194 | return logits[..., :vocab_size] if vocab_size is not None else logits 195 | 196 | def sample_tokens(logits, inference_params): 197 | if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: 198 | token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature) 199 | else: 200 | token = teacher_outputs[:, inference_params.seqlen_offset] 201 | # return rearrange(token, "b -> b 1") 202 | return token.unsqueeze(1) 203 | 204 | def should_stop(current_token, inference_params): 205 | if inference_params.seqlen_offset == 0: 206 | return False 207 | if eos_token_id is not None and (current_token == eos_token_id).all(): 208 | return True 209 | if inference_params.seqlen_offset >= max_length - 1: 210 | return True 211 | return False 212 | 213 | start = torch.cuda.Event(enable_timing=enable_timing) 214 | end = torch.cuda.Event(enable_timing=enable_timing) 215 | 216 | if enable_timing: 217 | start.record() 218 | scores, sequences = [], [input_ids] 219 | sequences_cat = input_ids 220 | while not should_stop(sequences[-1], inference_params): 221 | scores.append(get_logits(sequences[-1], inference_params)) 222 | inference_params.seqlen_offset += sequences[-1].shape[1] 223 | if repetition_penalty == 1.0: 224 | sampled_tokens = sample_tokens(scores[-1], inference_params) 225 | else: 226 | logits = modify_logit_for_repetition_penalty( 227 | scores[-1].clone(), sequences_cat, repetition_penalty 228 | ) 229 | sampled_tokens = sample_tokens(logits, inference_params) 230 | sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1) 231 | sequences.append(sampled_tokens) 232 | if streamer is not None: 233 | streamer.put(sampled_tokens.cpu()) 234 | if streamer is not None: 235 | streamer.end() 236 | if enable_timing: 237 | end.record() 238 | torch.cuda.synchronize() 239 | print(f"Prompt processing + decoding time: 
{(start.elapsed_time(end)):.0f}ms") 240 | output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput 241 | return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) 242 | 243 | 244 | class GenerationMixin: 245 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 246 | raise NotImplementedError 247 | 248 | def generate( 249 | self, 250 | input_ids, 251 | max_length, 252 | top_k=1, 253 | top_p=0.0, 254 | min_p=0.0, 255 | temperature=1.0, 256 | return_dict_in_generate=False, 257 | output_scores=False, 258 | **kwargs, 259 | ): 260 | output = decode( 261 | input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs 262 | ) 263 | if not output_scores: 264 | output.scores = None 265 | return output if return_dict_in_generate else output.sequences 266 | 267 | 268 | @dataclass 269 | class DecodingCGCache: 270 | max_batch_size: int = 0 271 | max_seqlen: int = 0 272 | device = None 273 | dtype = None 274 | callables: dict = field(default_factory=dict) 275 | mempool = None 276 | inference_params: Optional[InferenceParams] = None 277 | run: Optional[Callable] = None 278 | 279 | 280 | @torch.inference_mode() 281 | def update_graph_cache( 282 | model, 283 | cache, 284 | batch_size, 285 | seqlen_og, 286 | max_seqlen, 287 | decoding_seqlens=(1,), 288 | dtype=None, 289 | n_warmups=2, 290 | ): 291 | if cache is None: 292 | cache = DecodingCGCache() 293 | param_example = next(iter(model.parameters())) 294 | device = param_example.device 295 | if dtype is None: 296 | dtype = param_example.dtype 297 | if ( 298 | (device, dtype) != (cache.device, cache.dtype) 299 | or batch_size > cache.max_batch_size 300 | or max_seqlen > cache.max_seqlen 301 | ): # Invalidate the cache 302 | cache.callables = {} 303 | cache.mempool = None 304 | cache.inference_params = None 305 | gc.collect() 306 | cache.device, cache.dtype = device, dtype 307 | cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen 308 | assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache" 309 | inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) 310 | lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) 311 | cache.inference_params = InferenceParams( 312 | max_seqlen=max_seqlen, 313 | max_batch_size=batch_size, 314 | seqlen_offset=seqlen_og, 315 | key_value_memory_dict=inf_cache, 316 | lengths_per_sample=lengths_per_sample, 317 | ) 318 | cache.mempool = torch.cuda.graphs.graph_pool_handle() 319 | for decoding_seqlen in decoding_seqlens: 320 | if (batch_size, decoding_seqlen) not in cache.callables: 321 | cache.callables[batch_size, decoding_seqlen] = capture_graph( 322 | model, 323 | cache.inference_params, 324 | batch_size, 325 | max_seqlen, 326 | decoding_seqlen=decoding_seqlen, 327 | mempool=cache.mempool, 328 | n_warmups=n_warmups, 329 | ) 330 | 331 | def dispatch(input_ids, position_ids, seqlen): 332 | batch_size, decoding_seqlen = input_ids.shape[:2] 333 | return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) 334 | 335 | cache.run = dispatch 336 | cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing 337 | return cache 338 | 339 | 340 | def capture_graph( 341 | model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 342 | ): 343 | device = next(iter(model.parameters())).device 344 | input_ids = 
torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 345 | position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 346 | seqlen_offset_og = inference_params.seqlen_offset 347 | inference_params.seqlen_offset = max_seqlen - decoding_seqlen 348 | inference_params.lengths_per_sample[:] = inference_params.seqlen_offset 349 | 350 | # Warmup before capture 351 | s = torch.cuda.Stream() 352 | s.wait_stream(torch.cuda.current_stream()) 353 | with torch.cuda.stream(s): 354 | for _ in range(n_warmups): 355 | logits = model( 356 | input_ids, 357 | position_ids=position_ids, 358 | inference_params=inference_params, 359 | num_last_tokens=decoding_seqlen, 360 | ).logits 361 | s.synchronize() 362 | # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, 363 | # which requires that graph launch and non-captured launch to not overlap (I think, 364 | # that's how I interpret the documentation). I'm not sure if this is required. 365 | if torch.distributed.is_initialized(): 366 | torch.distributed.barrier() 367 | torch.cuda.current_stream().wait_stream(s) 368 | # Captures the graph 369 | # To allow capture, automatically sets a side stream as the current stream in the context 370 | graph = torch.cuda.CUDAGraph() 371 | with torch.cuda.graph(graph, pool=mempool): 372 | logits = model( 373 | input_ids, 374 | position_ids=position_ids, 375 | inference_params=inference_params, 376 | num_last_tokens=decoding_seqlen, 377 | ).logits 378 | 379 | def run(new_input_ids, new_position_ids, seqlen): 380 | inference_params.lengths_per_sample[:] = seqlen 381 | input_ids.copy_(new_input_ids) 382 | position_ids.copy_(new_position_ids) 383 | graph.replay() 384 | return logits.clone() 385 | 386 | inference_params.seqlen_offset = seqlen_offset_og 387 | return run 388 | -------------------------------------------------------------------------------- /espnet2/asr/encoder/conformer_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Tomoki Hayashi 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | ''' 4 | Author: FnoY fangying@westlake.edu.cn 5 | LastEditTime: 2023-09-15 14:28:25 6 | FilePath: /espnet/espnet2/asr/encoder/conformer_encoder.py 7 | Change: from 'assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks' 8 | to 'assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) <= num_blocks' 9 | ''' 10 | """Conformer encoder definition.""" 11 | 12 | import logging 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import torch 16 | from typeguard import check_argument_types 17 | 18 | from espnet2.asr.ctc import CTC 19 | from espnet2.asr.encoder.abs_encoder import AbsEncoder 20 | from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule 21 | from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer 22 | from espnet.nets.pytorch_backend.nets_utils import get_activation, make_pad_mask 23 | from espnet.nets.pytorch_backend.transformer.attention import ( 24 | LegacyRelPositionMultiHeadedAttention, 25 | MultiHeadedAttention, 26 | RelPositionMultiHeadedAttention, 27 | ) 28 | from espnet.nets.pytorch_backend.transformer.embedding import ( 29 | LegacyRelPositionalEncoding, 30 | PositionalEncoding, 31 | RelPositionalEncoding, 32 | ScaledPositionalEncoding, 33 | ) 34 | from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm 35 | from 
espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
36 |     Conv1dLinear,
37 |     MultiLayeredConv1d,
38 | )
39 | from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
40 |     PositionwiseFeedForward,
41 | )
42 | from espnet.nets.pytorch_backend.transformer.repeat import repeat
43 | from espnet.nets.pytorch_backend.transformer.subsampling import (
44 |     Conv2dSubsampling,
45 |     Conv2dSubsampling1,
46 |     Conv2dSubsampling2,
47 |     Conv2dSubsampling6,
48 |     Conv2dSubsampling8,
49 |     TooShortUttError,
50 |     check_short_utt,
51 | )
52 | 
53 | 
54 | class ConformerEncoder(AbsEncoder):
55 |     """Conformer encoder module.
56 | 
57 |     Args:
58 |         input_size (int): Input dimension.
59 |         output_size (int): Dimension of attention.
60 |         attention_heads (int): The number of heads of multi head attention.
61 |         linear_units (int): The number of units of position-wise feed forward.
62 |         num_blocks (int): The number of encoder blocks.
63 |         dropout_rate (float): Dropout rate.
64 |         attention_dropout_rate (float): Dropout rate in attention.
65 |         positional_dropout_rate (float): Dropout rate after adding positional encoding.
66 |         input_layer (Union[str, torch.nn.Module]): Input layer type.
67 |         normalize_before (bool): Whether to use layer_norm before the first block.
68 |         concat_after (bool): Whether to concat attention layer's input and output.
69 |             If True, additional linear will be applied.
70 |             i.e. x -> x + linear(concat(x, att(x)))
71 |             If False, no additional linear will be applied. i.e. x -> x + att(x)
72 |         positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
73 |         positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
74 |         rel_pos_type (str): Whether to use the latest relative positional encoding or
75 |             the legacy one. The legacy relative positional encoding will be deprecated
76 |             in the future. More details can be found in
77 |             https://github.com/espnet/espnet/pull/2816.
78 |         encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
79 |         encoder_attn_layer_type (str): Encoder attention layer type.
80 |         activation_type (str): Encoder activation function type.
81 |         macaron_style (bool): Whether to use macaron style for positionwise layer.
82 |         use_cnn_module (bool): Whether to use convolution module.
83 |         zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
84 |         cnn_module_kernel (int): Kernel size of convolution module.
85 |         padding_idx (int): Padding idx for input_layer=embed.
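        interctc_layer_idx (List[int]): Layer indices to attach intermediate CTC
            losses to; must satisfy 0 < min and max <= num_blocks.
        interctc_use_conditioning (bool): Whether to condition following layers
            on intermediate CTC posteriors (self-conditioned CTC).
        stochastic_depth_rate (Union[float, List[float]]): Stochastic depth rate,
            a single value or one value per layer.
        layer_drop_rate (float): Probability of dropping a whole encoder layer.
        max_pos_emb_len (int): Maximum length of the positional encoding table.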
86 | 87 | """ 88 | 89 | def __init__( 90 | self, 91 | input_size: int, 92 | output_size: int = 256, 93 | attention_heads: int = 4, 94 | linear_units: int = 2048, 95 | num_blocks: int = 6, 96 | dropout_rate: float = 0.1, 97 | positional_dropout_rate: float = 0.1, 98 | attention_dropout_rate: float = 0.0, 99 | input_layer: str = "conv2d", 100 | normalize_before: bool = True, 101 | concat_after: bool = False, 102 | positionwise_layer_type: str = "linear", 103 | positionwise_conv_kernel_size: int = 3, 104 | macaron_style: bool = False, 105 | rel_pos_type: str = "legacy", 106 | pos_enc_layer_type: str = "rel_pos", 107 | selfattention_layer_type: str = "rel_selfattn", 108 | activation_type: str = "swish", 109 | use_cnn_module: bool = True, 110 | zero_triu: bool = False, 111 | cnn_module_kernel: int = 31, 112 | padding_idx: int = -1, 113 | interctc_layer_idx: List[int] = [], 114 | interctc_use_conditioning: bool = False, 115 | stochastic_depth_rate: Union[float, List[float]] = 0.0, 116 | layer_drop_rate: float = 0.0, 117 | max_pos_emb_len: int = 5000, 118 | ): 119 | assert check_argument_types() 120 | super().__init__() 121 | self._output_size = output_size 122 | 123 | if rel_pos_type == "legacy": 124 | if pos_enc_layer_type == "rel_pos": 125 | pos_enc_layer_type = "legacy_rel_pos" 126 | if selfattention_layer_type == "rel_selfattn": 127 | selfattention_layer_type = "legacy_rel_selfattn" 128 | elif rel_pos_type == "latest": 129 | assert selfattention_layer_type != "legacy_rel_selfattn" 130 | assert pos_enc_layer_type != "legacy_rel_pos" 131 | else: 132 | raise ValueError("unknown rel_pos_type: " + rel_pos_type) 133 | 134 | activation = get_activation(activation_type) 135 | if pos_enc_layer_type == "abs_pos": 136 | pos_enc_class = PositionalEncoding 137 | elif pos_enc_layer_type == "scaled_abs_pos": 138 | pos_enc_class = ScaledPositionalEncoding 139 | elif pos_enc_layer_type == "rel_pos": 140 | assert selfattention_layer_type == "rel_selfattn" 141 | pos_enc_class = RelPositionalEncoding 142 | elif pos_enc_layer_type == "legacy_rel_pos": 143 | assert selfattention_layer_type == "legacy_rel_selfattn" 144 | pos_enc_class = LegacyRelPositionalEncoding 145 | logging.warning( 146 | "Using legacy_rel_pos and it will be deprecated in the future." 
147 | ) 148 | else: 149 | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) 150 | 151 | if input_layer == "linear": 152 | self.embed = torch.nn.Sequential( 153 | torch.nn.Linear(input_size, output_size), 154 | torch.nn.LayerNorm(output_size), 155 | torch.nn.Dropout(dropout_rate), 156 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 157 | ) 158 | elif input_layer == "conv2d": 159 | self.embed = Conv2dSubsampling( 160 | input_size, 161 | output_size, 162 | dropout_rate, 163 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 164 | ) 165 | elif input_layer == "conv2d1": 166 | self.embed = Conv2dSubsampling1( 167 | input_size, 168 | output_size, 169 | dropout_rate, 170 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 171 | ) 172 | elif input_layer == "conv2d2": 173 | self.embed = Conv2dSubsampling2( 174 | input_size, 175 | output_size, 176 | dropout_rate, 177 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 178 | ) 179 | elif input_layer == "conv2d6": 180 | self.embed = Conv2dSubsampling6( 181 | input_size, 182 | output_size, 183 | dropout_rate, 184 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 185 | ) 186 | elif input_layer == "conv2d8": 187 | self.embed = Conv2dSubsampling8( 188 | input_size, 189 | output_size, 190 | dropout_rate, 191 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 192 | ) 193 | elif input_layer == "embed": 194 | self.embed = torch.nn.Sequential( 195 | torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), 196 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 197 | ) 198 | elif isinstance(input_layer, torch.nn.Module): 199 | self.embed = torch.nn.Sequential( 200 | input_layer, 201 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), 202 | ) 203 | elif input_layer is None: 204 | self.embed = torch.nn.Sequential( 205 | pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len) 206 | ) 207 | else: 208 | raise ValueError("unknown input_layer: " + input_layer) 209 | self.normalize_before = normalize_before 210 | if positionwise_layer_type == "linear": 211 | positionwise_layer = PositionwiseFeedForward 212 | positionwise_layer_args = ( 213 | output_size, 214 | linear_units, 215 | dropout_rate, 216 | activation, 217 | ) 218 | elif positionwise_layer_type == "conv1d": 219 | positionwise_layer = MultiLayeredConv1d 220 | positionwise_layer_args = ( 221 | output_size, 222 | linear_units, 223 | positionwise_conv_kernel_size, 224 | dropout_rate, 225 | ) 226 | elif positionwise_layer_type == "conv1d-linear": 227 | positionwise_layer = Conv1dLinear 228 | positionwise_layer_args = ( 229 | output_size, 230 | linear_units, 231 | positionwise_conv_kernel_size, 232 | dropout_rate, 233 | ) 234 | else: 235 | raise NotImplementedError("Support only linear or conv1d.") 236 | 237 | if selfattention_layer_type == "selfattn": 238 | encoder_selfattn_layer = MultiHeadedAttention 239 | encoder_selfattn_layer_args = ( 240 | attention_heads, 241 | output_size, 242 | attention_dropout_rate, 243 | ) 244 | elif selfattention_layer_type == "legacy_rel_selfattn": 245 | assert pos_enc_layer_type == "legacy_rel_pos" 246 | encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention 247 | encoder_selfattn_layer_args = ( 248 | attention_heads, 249 | output_size, 250 | attention_dropout_rate, 251 | ) 252 | logging.warning( 253 | "Using legacy_rel_selfattn and it will be deprecated in the future." 
254 | ) 255 | elif selfattention_layer_type == "rel_selfattn": 256 | assert pos_enc_layer_type == "rel_pos" 257 | encoder_selfattn_layer = RelPositionMultiHeadedAttention 258 | encoder_selfattn_layer_args = ( 259 | attention_heads, 260 | output_size, 261 | attention_dropout_rate, 262 | zero_triu, 263 | ) 264 | else: 265 | raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) 266 | 267 | convolution_layer = ConvolutionModule 268 | convolution_layer_args = (output_size, cnn_module_kernel, activation) 269 | 270 | if isinstance(stochastic_depth_rate, float): 271 | stochastic_depth_rate = [stochastic_depth_rate] * num_blocks 272 | 273 | if len(stochastic_depth_rate) != num_blocks: 274 | raise ValueError( 275 | f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " 276 | f"should be equal to num_blocks ({num_blocks})" 277 | ) 278 | 279 | self.encoders = repeat( 280 | num_blocks, 281 | lambda lnum: EncoderLayer( 282 | output_size, 283 | encoder_selfattn_layer(*encoder_selfattn_layer_args), 284 | positionwise_layer(*positionwise_layer_args), 285 | positionwise_layer(*positionwise_layer_args) if macaron_style else None, 286 | convolution_layer(*convolution_layer_args) if use_cnn_module else None, 287 | dropout_rate, 288 | normalize_before, 289 | concat_after, 290 | stochastic_depth_rate[lnum], 291 | ), 292 | layer_drop_rate, 293 | ) 294 | if self.normalize_before: 295 | self.after_norm = LayerNorm(output_size) 296 | 297 | self.interctc_layer_idx = interctc_layer_idx 298 | if len(interctc_layer_idx) > 0: 299 | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) <= num_blocks 300 | self.interctc_use_conditioning = interctc_use_conditioning 301 | self.conditioning_layer = None 302 | 303 | def output_size(self) -> int: 304 | return self._output_size 305 | 306 | def forward( 307 | self, 308 | xs_pad: torch.Tensor, 309 | ilens: torch.Tensor, 310 | prev_states: torch.Tensor = None, 311 | ctc: CTC = None, 312 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 313 | """Calculate forward propagation. 314 | 315 | Args: 316 | xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). 317 | ilens (torch.Tensor): Input length (#batch). 318 | prev_states (torch.Tensor): Not to be used now. 319 | 320 | Returns: 321 | torch.Tensor: Output tensor (#batch, L, output_size). 322 | torch.Tensor: Output length (#batch). 323 | torch.Tensor: Not to be used now. 
324 | 325 | """ 326 | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) 327 | 328 | if ( 329 | isinstance(self.embed, Conv2dSubsampling) 330 | or isinstance(self.embed, Conv2dSubsampling1) 331 | or isinstance(self.embed, Conv2dSubsampling2) 332 | or isinstance(self.embed, Conv2dSubsampling6) 333 | or isinstance(self.embed, Conv2dSubsampling8) 334 | ): 335 | short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) 336 | if short_status: 337 | raise TooShortUttError( 338 | f"has {xs_pad.size(1)} frames and is too short for subsampling " 339 | + f"(it needs more than {limit_size} frames), return empty results", 340 | xs_pad.size(1), 341 | limit_size, 342 | ) 343 | xs_pad, masks = self.embed(xs_pad, masks) 344 | else: 345 | xs_pad = self.embed(xs_pad) 346 | 347 | intermediate_outs = [] 348 | if len(self.interctc_layer_idx) == 0: 349 | xs_pad, masks = self.encoders(xs_pad, masks) 350 | else: 351 | for layer_idx, encoder_layer in enumerate(self.encoders): 352 | xs_pad, masks = encoder_layer(xs_pad, masks) 353 | 354 | if layer_idx + 1 in self.interctc_layer_idx: 355 | encoder_out = xs_pad 356 | if isinstance(encoder_out, tuple): 357 | encoder_out = encoder_out[0] 358 | 359 | # intermediate outputs are also normalized 360 | if self.normalize_before: 361 | encoder_out = self.after_norm(encoder_out) 362 | 363 | intermediate_outs.append((layer_idx + 1, encoder_out)) 364 | 365 | if self.interctc_use_conditioning: 366 | ctc_out = ctc.softmax(encoder_out) 367 | 368 | if isinstance(xs_pad, tuple): 369 | x, pos_emb = xs_pad 370 | x = x + self.conditioning_layer(ctc_out) 371 | xs_pad = (x, pos_emb) 372 | else: 373 | xs_pad = xs_pad + self.conditioning_layer(ctc_out) 374 | 375 | if isinstance(xs_pad, tuple): 376 | xs_pad = xs_pad[0] 377 | if self.normalize_before: 378 | xs_pad = self.after_norm(xs_pad) 379 | 380 | olens = masks.squeeze(1).sum(1) 381 | if len(intermediate_outs) > 0: 382 | return (xs_pad, intermediate_outs), olens, None 383 | return xs_pad, olens, None 384 | -------------------------------------------------------------------------------- /espnet2/asr/mamba_ssm/ops/selective_scan_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
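# Editor's note (overview, not part of the original file): this module wraps the
# fused CUDA selective-scan kernel in autograd Functions (SelectiveScanFn /
# MambaInnerFn) and also keeps pure-PyTorch references (selective_scan_ref /
# mamba_inner_ref) that compute the same thing step by step, which is useful for
# validating the fused paths on small inputs.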
2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.cuda.amp import custom_bwd, custom_fwd 6 | 7 | from einops import rearrange, repeat 8 | 9 | try: 10 | from causal_conv1d import causal_conv1d_fn 11 | import causal_conv1d_cuda 12 | except ImportError: 13 | causal_conv1d_fn = None 14 | causal_conv1d_cuda = None 15 | 16 | import selective_scan_cuda 17 | 18 | 19 | class SelectiveScanFn(torch.autograd.Function): 20 | 21 | @staticmethod 22 | def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 23 | return_last_state=False): 24 | if u.stride(-1) != 1: 25 | u = u.contiguous() 26 | if delta.stride(-1) != 1: 27 | delta = delta.contiguous() 28 | if D is not None: 29 | D = D.contiguous() 30 | if B.stride(-1) != 1: 31 | B = B.contiguous() 32 | if C.stride(-1) != 1: 33 | C = C.contiguous() 34 | if z is not None and z.stride(-1) != 1: 35 | z = z.contiguous() 36 | if B.dim() == 3: 37 | B = rearrange(B, "b dstate l -> b 1 dstate l") 38 | ctx.squeeze_B = True 39 | if C.dim() == 3: 40 | C = rearrange(C, "b dstate l -> b 1 dstate l") 41 | ctx.squeeze_C = True 42 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) 43 | ctx.delta_softplus = delta_softplus 44 | ctx.has_z = z is not None 45 | last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) 46 | if not ctx.has_z: 47 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 48 | return out if not return_last_state else (out, last_state) 49 | else: 50 | ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) 51 | out_z = rest[0] 52 | return out_z if not return_last_state else (out_z, last_state) 53 | 54 | @staticmethod 55 | def backward(ctx, dout, *args): 56 | if not ctx.has_z: 57 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 58 | z = None 59 | out = None 60 | else: 61 | u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors 62 | if dout.stride(-1) != 1: 63 | dout = dout.contiguous() 64 | # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the 65 | # backward of selective_scan_cuda with the backward of chunk). 66 | # Here we just pass in None and dz will be allocated in the C++ code. 67 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 68 | u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, 69 | False # option to recompute out_z, not used here 70 | ) 71 | dz = rest[0] if ctx.has_z else None 72 | dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB 73 | dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC 74 | return (du, ddelta, dA, dB, dC, 75 | dD if D is not None else None, 76 | dz, 77 | ddelta_bias if delta_bias is not None else None, 78 | None, 79 | None) 80 | 81 | 82 | def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 83 | return_last_state=False): 84 | """if return_last_state is True, returns (out, last_state) 85 | last_state has shape (batch, dim, dstate). Note that the gradient of the last state is 86 | not considered in the backward pass. 
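    Shapes (mirroring the selective_scan_ref docstring below):
        u, delta, z: (batch, dim, seqlen); A: (dim, dstate);
        B, C: (dim, dstate) or batched/variable forms; D, delta_bias: (dim,).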
87 | """ 88 | return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) 89 | 90 | 91 | def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 92 | return_last_state=False): 93 | """ 94 | u: r(B D L) 95 | delta: r(B D L) 96 | A: c(D N) or r(D N) 97 | B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) 98 | C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) 99 | D: r(D) 100 | z: r(B D L) 101 | delta_bias: r(D), fp32 102 | 103 | out: r(B D L) 104 | last_state (optional): r(B D dstate) or c(B D dstate) 105 | """ 106 | dtype_in = u.dtype 107 | u = u.float() 108 | delta = delta.float() 109 | if delta_bias is not None: 110 | delta = delta + delta_bias[..., None].float() 111 | if delta_softplus: 112 | delta = F.softplus(delta) 113 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] 114 | is_variable_B = B.dim() >= 3 115 | is_variable_C = C.dim() >= 3 116 | if A.is_complex(): 117 | if is_variable_B: 118 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) 119 | if is_variable_C: 120 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) 121 | else: 122 | B = B.float() 123 | C = C.float() 124 | x = A.new_zeros((batch, dim, dstate)) 125 | ys = [] 126 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 127 | if not is_variable_B: 128 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) 129 | else: 130 | if B.dim() == 3: 131 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) 132 | else: 133 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) 134 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) 135 | if is_variable_C and C.dim() == 4: 136 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) 137 | last_state = None 138 | for i in range(u.shape[2]): 139 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 140 | if not is_variable_C: 141 | y = torch.einsum('bdn,dn->bd', x, C) 142 | else: 143 | if C.dim() == 3: 144 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) 145 | else: 146 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) 147 | if i == u.shape[2] - 1: 148 | last_state = x 149 | if y.is_complex(): 150 | y = y.real * 2 151 | ys.append(y) 152 | y = torch.stack(ys, dim=2) # (batch dim L) 153 | out = y if D is None else y + u * rearrange(D, "d -> d 1") 154 | if z is not None: 155 | out = out * F.silu(z) 156 | out = out.to(dtype=dtype_in) 157 | return out if not return_last_state else (out, last_state) 158 | 159 | 160 | class MambaInnerFn(torch.autograd.Function): 161 | 162 | @staticmethod 163 | @custom_fwd 164 | def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, 165 | out_proj_weight, out_proj_bias, 166 | A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, 167 | C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): 168 | """ 169 | xz: (batch, dim, seqlen) 170 | """ 171 | assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
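        # checkpoint_lvl 0 keeps conv1d_out and delta for the backward pass;
        # checkpoint_lvl 1 drops them after forward and recomputes them in
        # backward, trading compute for activation memory (see the
        # `if checkpoint_lvl >= 1` branch below).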
172 | assert checkpoint_lvl in [0, 1] 173 | L = xz.shape[-1] 174 | delta_rank = delta_proj_weight.shape[1] 175 | d_state = A.shape[-1] * (1 if not A.is_complex() else 2) 176 | if torch.is_autocast_enabled(): 177 | x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 178 | delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 179 | out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 180 | out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) 181 | if out_proj_bias is not None else None) 182 | if xz.stride(-1) != 1: 183 | xz = xz.contiguous() 184 | conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") 185 | x, z = xz.chunk(2, dim=1) 186 | conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None 187 | conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( 188 | x, conv1d_weight, conv1d_bias, None, None, None, True 189 | ) 190 | # We're being very careful here about the layout, to avoid extra transposes. 191 | # We want delta to have d as the slowest moving dimension 192 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 193 | x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) 194 | delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) 195 | ctx.is_variable_B = B is None 196 | ctx.is_variable_C = C is None 197 | ctx.B_proj_bias_is_None = B_proj_bias is None 198 | ctx.C_proj_bias_is_None = C_proj_bias is None 199 | if B is None: # variable B 200 | B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) 201 | if B_proj_bias is not None: 202 | B = B + B_proj_bias.to(dtype=B.dtype) 203 | if not A.is_complex(): 204 | # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() 205 | B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() 206 | else: 207 | B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() 208 | else: 209 | if B.stride(-1) != 1: 210 | B = B.contiguous() 211 | if C is None: # variable C 212 | C = x_dbl[:, -d_state:] # (bl dstate) 213 | if C_proj_bias is not None: 214 | C = C + C_proj_bias.to(dtype=C.dtype) 215 | if not A.is_complex(): 216 | # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() 217 | C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() 218 | else: 219 | C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() 220 | else: 221 | if C.stride(-1) != 1: 222 | C = C.contiguous() 223 | if D is not None: 224 | D = D.contiguous() 225 | out, scan_intermediates, out_z = selective_scan_cuda.fwd( 226 | conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus 227 | ) 228 | ctx.delta_softplus = delta_softplus 229 | ctx.out_proj_bias_is_None = out_proj_bias is None 230 | ctx.checkpoint_lvl = checkpoint_lvl 231 | if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass 232 | conv1d_out, delta = None, None 233 | ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, 234 | delta_proj_weight, out_proj_weight, conv1d_out, delta, 235 | A, B, C, D, delta_bias, scan_intermediates, out) 236 | return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) 237 | 238 | @staticmethod 239 | @custom_bwd 240 | def backward(ctx, dout): 241 | # dout: (batch, seqlen, dim) 242 | assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
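        # Unpack saved activations; conv1d_out and delta were saved as None when
        # checkpoint_lvl == 1 and are recomputed from x and x_dbl just below.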
243 | (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, 244 | conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors 245 | L = xz.shape[-1] 246 | delta_rank = delta_proj_weight.shape[1] 247 | d_state = A.shape[-1] * (1 if not A.is_complex() else 2) 248 | x, z = xz.chunk(2, dim=1) 249 | if dout.stride(-1) != 1: 250 | dout = dout.contiguous() 251 | if ctx.checkpoint_lvl == 1: 252 | conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( 253 | x, conv1d_weight, conv1d_bias, None, None, None, True 254 | ) 255 | delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), 256 | "d (b l) -> b d l", l = L) 257 | # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the 258 | # backward of selective_scan_cuda with the backward of chunk). 259 | dxz = torch.empty_like(xz) # (batch, dim, seqlen) 260 | dx, dz = dxz.chunk(2, dim=1) 261 | dout = rearrange(dout, "b l e -> e (b l)") 262 | dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) 263 | dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( 264 | conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, 265 | ctx.delta_softplus, 266 | True # option to recompute out_z 267 | ) 268 | dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) 269 | dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None 270 | dD = dD if D is not None else None 271 | dx_dbl = torch.empty_like(x_dbl) 272 | dB_proj_bias = None 273 | if ctx.is_variable_B: 274 | if not A.is_complex(): 275 | dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() 276 | else: 277 | dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() 278 | dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None 279 | dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) 280 | dB = None 281 | dC_proj_bias = None 282 | if ctx.is_variable_C: 283 | if not A.is_complex(): 284 | dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() 285 | else: 286 | dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() 287 | dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None 288 | dx_dbl[:, -d_state:] = dC # (bl d) 289 | dC = None 290 | ddelta = rearrange(ddelta, "b d l -> d (b l)") 291 | ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) 292 | dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) 293 | dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") 294 | dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) 295 | dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) 296 | dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) 297 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 298 | # backward of conv1d with the backward of chunk). 
299 |         dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
300 |             x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
301 |         )
302 |         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
303 |         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
304 |         return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
305 |                 dout_proj_weight, dout_proj_bias,
306 |                 dA, dB, dC, dD,
307 |                 ddelta_bias if delta_bias is not None else None,
308 |                 dB_proj_bias, dC_proj_bias, None)
309 | 
310 | 
311 | def mamba_inner_fn(
312 |     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
313 |     out_proj_weight, out_proj_bias,
314 |     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
315 |     C_proj_bias=None, delta_softplus=True
316 | ):
317 |     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
318 |                               out_proj_weight, out_proj_bias,
319 |                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
320 | 
321 | 
322 | def mamba_inner_ref(
323 |     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
324 |     out_proj_weight, out_proj_bias,
325 |     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
326 |     C_proj_bias=None, delta_softplus=True
327 | ):
328 |     assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
329 |     L = xz.shape[-1]
330 |     delta_rank = delta_proj_weight.shape[1]
331 |     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
332 |     x, z = xz.chunk(2, dim=1)
333 |     x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
334 |     # We're being very careful here about the layout, to avoid extra transposes.
335 |     # We want delta to have d as the slowest moving dimension
336 |     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
337 |     x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
338 |     delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
339 |     delta = rearrange(delta, "d (b l) -> b d l", l=L)
340 |     if B is None:  # variable B
341 |         B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
342 |         if B_proj_bias is not None:
343 |             B = B + B_proj_bias.to(dtype=B.dtype)
344 |         if not A.is_complex():
345 |             B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
346 |         else:
347 |             B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
348 |     if C is None:  # variable C
349 |         C = x_dbl[:, -d_state:]  # (bl d)
350 |         if C_proj_bias is not None:
351 |             C = C + C_proj_bias.to(dtype=C.dtype)
352 |         if not A.is_complex():
353 |             C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
354 |         else:
355 |             C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
356 |     y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)
357 |     return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
--------------------------------------------------------------------------------
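# Editor's addendum (illustrative sketch, not a file from this repository):
# mamba_inner_ref above computes the same mapping as the fused mamba_inner_fn,
# so a quick parity check on small random tensors can catch kernel regressions.
# All shapes below are assumptions for illustration (batch=2, d=4, seqlen=8,
# d_state=16, dt_rank=2, conv width=4); requires CUDA plus the causal-conv1d
# and selective-scan extensions.
#
#     import torch
#     d, N, R, L, b, w = 4, 16, 2, 8, 2, 4
#     xz = torch.randn(b, 2 * d, L, device="cuda")
#     conv1d_weight = torch.randn(d, 1, w, device="cuda")
#     conv1d_bias = torch.randn(d, device="cuda")
#     x_proj_weight = torch.randn(R + 2 * N, d, device="cuda")
#     delta_proj_weight = torch.randn(d, R, device="cuda")
#     out_proj_weight = torch.randn(d, d, device="cuda")
#     A = -torch.rand(d, N, device="cuda")
#     y_ref = mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight,
#                             delta_proj_weight, out_proj_weight, None, A)
#     y_fused = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight,
#                              delta_proj_weight, out_proj_weight, None, A)
#     torch.testing.assert_close(y_fused, y_ref, rtol=1e-3, atol=1e-3)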