├── benchmarks ├── DASB │ ├── utils │ │ ├── __init__.py │ │ └── data.py │ ├── Libri2Mix │ │ └── separation │ │ │ ├── crdnn │ │ │ ├── utils.py │ │ │ ├── metrics │ │ │ │ ├── dnsmos.py │ │ │ │ ├── dwer.py │ │ │ │ └── spk_sim.py │ │ │ ├── custom_model.py │ │ │ └── librimix_prepare.py │ │ │ ├── conformer │ │ │ ├── utils.py │ │ │ ├── metrics │ │ │ │ ├── dwer.py │ │ │ │ ├── dnsmos.py │ │ │ │ └── spk_sim.py │ │ │ ├── custom_model.py │ │ │ └── librimix_prepare.py │ │ │ └── metrics │ │ │ ├── model_v8.onnx │ │ │ ├── dwer.py │ │ │ └── spk_sim.py │ ├── VoiceBank │ │ └── enhancement │ │ │ ├── crdnn │ │ │ ├── utils.py │ │ │ ├── metrics │ │ │ │ ├── dwer.py │ │ │ │ ├── dnsmos.py │ │ │ │ └── spk_sim.py │ │ │ ├── custom_model.py │ │ │ └── voicebank_prepare.py │ │ │ ├── conformer │ │ │ ├── utils.py │ │ │ ├── metrics │ │ │ │ ├── dnsmos.py │ │ │ │ ├── dwer.py │ │ │ │ └── spk_sim.py │ │ │ ├── custom_model.py │ │ │ └── voicebank_prepare.py │ │ │ └── metrics │ │ │ ├── model_v8.onnx │ │ │ ├── dwer.py │ │ │ └── spk_sim.py │ ├── LJSpeech │ │ ├── TTS │ │ │ └── tokotron │ │ │ │ ├── data.py │ │ │ │ ├── eval.py │ │ │ │ ├── Tokotron.py │ │ │ │ ├── audio_tokens.py │ │ │ │ ├── ljspeech_prepare.py │ │ │ │ ├── preparation.py │ │ │ │ ├── hparams │ │ │ │ ├── char_en.txt │ │ │ │ ├── arpabet.txt │ │ │ │ └── eval.yaml │ │ │ │ ├── train_encodec.py │ │ │ │ ├── train_speech_tokenizer.py │ │ │ │ ├── train_continuous_ssl.py │ │ │ │ ├── train_dac.py │ │ │ │ └── train_discrete_ssl.py │ │ └── quantization │ │ │ ├── ljspeech_prepare.py │ │ │ ├── extra-requirements.txt │ │ │ ├── README.md │ │ │ └── hparams │ │ │ └── train_discrete_ssl.yaml │ ├── IEMOCAP │ │ ├── quantization │ │ │ ├── extra-requirements.txt │ │ │ ├── iemocap_prepare.py │ │ │ ├── README.md │ │ │ └── hparams │ │ │ │ └── train_discrete_ssl.yaml │ │ └── emotion_recognition │ │ │ ├── linear │ │ │ ├── iemocap_prepare.py │ │ │ ├── custom_model.py │ │ │ └── hparams │ │ │ │ └── train_weighted_ssl.yaml │ │ │ └── ecapa_tdnn │ │ │ ├── iemocap_prepare.py │ │ │ └── custom_model.py │ ├── CommonVoice │ │ ├── ASR │ │ │ ├── LSTM │ │ │ │ ├── custom_model.py │ │ │ │ └── common_voice_prepare.py │ │ │ └── linear │ │ │ │ ├── custom_model.py │ │ │ │ └── common_voice_prepare.py │ │ └── quantization │ │ │ ├── extra-requirements.txt │ │ │ ├── common_voice_prepare.py │ │ │ ├── README.md │ │ │ └── hparams │ │ │ └── train_discrete_ssl.yaml │ ├── LibriSpeech │ │ ├── ASR │ │ │ ├── LSTM │ │ │ │ ├── custom_model.py │ │ │ │ └── librispeech_prepare.py │ │ │ └── contextnet │ │ │ │ ├── custom_model.py │ │ │ │ └── librispeech_prepare.py │ │ └── quantization │ │ │ ├── extra-requirements.txt │ │ │ ├── librispeech_prepare.py │ │ │ ├── README.md │ │ │ ├── hparams │ │ │ ├── train_discrete_ssl.yaml │ │ │ └── train_subwording.yaml │ │ │ └── train_subwording.py │ ├── VoiceCeleb1 │ │ ├── quantization │ │ │ ├── extra-requirements.txt │ │ │ ├── voxceleb_prepare.py │ │ │ ├── README.md │ │ │ └── hparams │ │ │ │ └── train_discrete_ssl.yaml │ │ └── speaker_ver │ │ │ ├── Xvector │ │ │ ├── custom_model.py │ │ │ └── voxceleb_prepare.py │ │ │ └── ecapa_tdnn │ │ │ ├── custom_model.py │ │ │ └── voxceleb_prepare.py │ ├── SLURP │ │ └── intent_classification │ │ │ ├── linear │ │ │ ├── slurp_prepare.py │ │ │ └── custom_model.py │ │ │ └── LSTM_linear │ │ │ ├── slurp_prepare.py │ │ │ └── custom_model.py │ ├── Google-speech-commands │ │ └── keyword-spotting │ │ │ ├── Xvector │ │ │ ├── prepare_GSC.py │ │ │ └── custom_model.py │ │ │ └── ecapa_tdnn │ │ │ ├── prepare_GSC.py │ │ │ └── custom_model.py │ ├── DASB_logo.png │ ├── 
extra_requirements.txt │ ├── run_discriminative_benchmark.sh │ ├── run_generative_benchmark.sh │ └── model │ │ └── custom_model.py ├── MP3S │ ├── SLURP │ │ └── linear │ │ │ └── prepare.py │ ├── Buckeye │ │ └── contextnet │ │ │ └── buckeye_prepare.py │ ├── IEMOCAP │ │ └── linear │ │ │ └── iemocap_prepare.py │ ├── VoxCeleb1 │ │ └── Xvectors │ │ │ └── voxceleb_prepare.py │ ├── CommonVoice │ │ └── linear │ │ │ └── common_voice_prepare.py │ ├── LibriSpeech │ │ └── contextnet │ │ │ └── librispeech_prepare.py │ ├── extra_requirements.txt │ └── run_benchmark.sh ├── CL_MASR │ ├── wavlm │ │ ├── common_voice_prepare.py │ │ └── hparams │ │ │ ├── pretrain.yaml │ │ │ ├── train_joint.yaml │ │ │ ├── train_ft.yaml │ │ │ ├── train_er.yaml │ │ │ ├── train_agem.yaml │ │ │ ├── train_lwf.yaml │ │ │ ├── train_der.yaml │ │ │ ├── train_pnn.yaml │ │ │ ├── train_ewc.yaml │ │ │ ├── train_mas.yaml │ │ │ ├── train_pb.yaml │ │ │ └── train_l2p.yaml │ ├── whisper │ │ ├── common_voice_prepare.py │ │ └── hparams │ │ │ ├── train_ft.yaml │ │ │ ├── train_joint.yaml │ │ │ ├── train_er.yaml │ │ │ ├── train_agem.yaml │ │ │ ├── train_pnn.yaml │ │ │ ├── train_der.yaml │ │ │ ├── train_pb.yaml │ │ │ ├── train_lwf.yaml │ │ │ ├── train_mas.yaml │ │ │ ├── train_ewc.yaml │ │ │ └── train_l2p.yaml │ └── .gitignore └── MOABB │ ├── hparams │ └── orion │ │ ├── hparams_random_search.yaml │ │ └── hparams_tpe.yaml │ ├── extra-requirements.txt │ └── models │ └── BraindecodeNN.py ├── requirements.txt ├── tests ├── __init__.py ├── samples │ ├── ASR │ │ ├── spk1_snt1.wav │ │ ├── spk1_snt2.wav │ │ ├── spk1_snt3.wav │ │ ├── spk1_snt4.wav │ │ ├── spk1_snt5.wav │ │ ├── spk1_snt6.wav │ │ ├── spk2_snt1.wav │ │ ├── spk2_snt2.wav │ │ ├── spk2_snt3.wav │ │ ├── spk2_snt4.wav │ │ ├── spk2_snt5.wav │ │ └── spk2_snt6.wav │ └── annotation │ │ ├── ASR_train.csv │ │ └── ASR_Buckeye.csv ├── .run-HF-checks.sh ├── .run-url-checks.sh ├── .run-recipe-tests.sh ├── .run-linters.sh ├── consistency │ ├── test_docstrings.py │ ├── test_yaml.py │ ├── README.md │ ├── DOCSTRINGS.md │ └── test_HF_repo.py ├── .run-doctests.sh ├── .run-load-yaml-tests.sh ├── utils │ ├── overrides.yaml │ ├── check_url.py │ └── check_HF_repo.py └── PRE-RELEASE-TESTS.md ├── .gitmodules ├── lint-requirements.txt ├── .flake8 ├── pyproject.toml ├── .yamllint.yaml ├── .github └── workflows │ └── pre-commit.yml ├── .pre-commit-config.yaml └── .gitignore /benchmarks/DASB/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/utils.py: -------------------------------------------------------------------------------- 1 | ../utils.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/utils.py: -------------------------------------------------------------------------------- 1 | ../utils.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/data.py: -------------------------------------------------------------------------------- 1 | ../../../utils/data.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/eval.py: -------------------------------------------------------------------------------- 1 | ../../../utils/eval.py -------------------------------------------------------------------------------- 
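Note that many of the one-line `.py` files in this listing contain only a relative path (e.g. `../utils.py`): in the repository itself these are git symlinks that resolve to a single shared implementation, so every recipe variant reuses the same code. A minimal sketch of how one could verify this in a local checkout (the path below is just one example from the tree):

```python
# Check that a one-line "file" from this listing is really a symlink
# pointing at the shared implementation (path is illustrative).
from pathlib import Path

p = Path("benchmarks/DASB/Libri2Mix/separation/crdnn/utils.py")
if p.is_symlink():
    # Expected to resolve to benchmarks/DASB/Libri2Mix/separation/utils.py
    print(p, "->", p.resolve())
```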
/benchmarks/DASB/Libri2Mix/separation/conformer/utils.py: -------------------------------------------------------------------------------- 1 | ../utils.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py: -------------------------------------------------------------------------------- 1 | ../utils.py -------------------------------------------------------------------------------- /benchmarks/MP3S/SLURP/linear/prepare.py: -------------------------------------------------------------------------------- 1 | ../LSTM_linear/prepare.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r lint-requirements.txt 2 | speechbrain>=0.5.14 3 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/quantization/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/quantization/iemocap_prepare.py: -------------------------------------------------------------------------------- 1 | ../iemocap_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/Tokotron.py: -------------------------------------------------------------------------------- 1 | ../../../model/Tokotron.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/ASR/LSTM/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/quantization/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/audio_tokens.py: -------------------------------------------------------------------------------- 1 | ../../../utils/audio_tokens.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/ljspeech_prepare.py: -------------------------------------------------------------------------------- 1 | ../../ljspeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/preparation.py: -------------------------------------------------------------------------------- 1 | ../../../utils/preparation.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/quantization/ljspeech_prepare.py: 
-------------------------------------------------------------------------------- 1 | ../ljspeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/conformer/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dwer.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/metrics/dnsmos.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dnsmos.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dwer.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/ASR/LSTM/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dwer.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/quantization/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/quantization/voxceleb_prepare.py: -------------------------------------------------------------------------------- 1 | ../voxceleb_prepare.py -------------------------------------------------------------------------------- /benchmarks/MP3S/Buckeye/contextnet/buckeye_prepare.py: -------------------------------------------------------------------------------- 1 | ../LSTM/buckeye_prepare.py -------------------------------------------------------------------------------- /benchmarks/MP3S/IEMOCAP/linear/iemocap_prepare.py: -------------------------------------------------------------------------------- 1 | ../ecapa_tdnn/iemocap_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/ASR/linear/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/conformer/metrics/dnsmos.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dnsmos.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/librimix_prepare.py: 
-------------------------------------------------------------------------------- 1 | ../../librimix_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/crdnn/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | ../../metrics/spk_sim.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/ASR/LSTM/librispeech_prepare.py: -------------------------------------------------------------------------------- 1 | ../../librispeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/ASR/contextnet/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/librispeech_prepare.py: -------------------------------------------------------------------------------- 1 | ../librispeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dnsmos.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dnsmos.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dwer.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dnsmos.py: -------------------------------------------------------------------------------- 1 | ../../metrics/dnsmos.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | ../../metrics/spk_sim.py -------------------------------------------------------------------------------- /benchmarks/MP3S/VoxCeleb1/Xvectors/voxceleb_prepare.py: -------------------------------------------------------------------------------- 1 | ../ecapa_tdnn/voxceleb_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/ASR/LSTM/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../../common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/ASR/linear/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../../common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/quantization/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/emotion_recognition/linear/iemocap_prepare.py: -------------------------------------------------------------------------------- 1 | ../../iemocap_prepare.py -------------------------------------------------------------------------------- 
/benchmarks/DASB/Libri2Mix/separation/conformer/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/conformer/librimix_prepare.py: -------------------------------------------------------------------------------- 1 | ../../librimix_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/conformer/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | ../../metrics/spk_sim.py -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/ASR/contextnet/librispeech_prepare.py: -------------------------------------------------------------------------------- 1 | ../../librispeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/SLURP/intent_classification/linear/slurp_prepare.py: -------------------------------------------------------------------------------- 1 | ../../slurp_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | ../../metrics/spk_sim.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/crdnn/voicebank_prepare.py: -------------------------------------------------------------------------------- 1 | ../../voicebank_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/speaker_ver/Xvector/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/speaker_ver/Xvector/voxceleb_prepare.py: -------------------------------------------------------------------------------- 1 | ../../voxceleb_prepare.py -------------------------------------------------------------------------------- /benchmarks/MP3S/CommonVoice/linear/common_voice_prepare.py: -------------------------------------------------------------------------------- 1 | ../LSTM/common_voice_prepare.py -------------------------------------------------------------------------------- /benchmarks/MP3S/LibriSpeech/contextnet/librispeech_prepare.py: -------------------------------------------------------------------------------- 1 | ../LSTM/librispeech_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/Google-speech-commands/keyword-spotting/Xvector/prepare_GSC.py: -------------------------------------------------------------------------------- 1 
| ../../prepare_GSC.py -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/emotion_recognition/ecapa_tdnn/iemocap_prepare.py: -------------------------------------------------------------------------------- 1 | ../../iemocap_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/emotion_recognition/linear/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/SLURP/intent_classification/LSTM_linear/slurp_prepare.py: -------------------------------------------------------------------------------- 1 | ../../slurp_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/SLURP/intent_classification/linear/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/conformer/voicebank_prepare.py: -------------------------------------------------------------------------------- 1 | ../../voicebank_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/speaker_ver/ecapa_tdnn/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/speaker_ver/ecapa_tdnn/voxceleb_prepare.py: -------------------------------------------------------------------------------- 1 | ../../voxceleb_prepare.py -------------------------------------------------------------------------------- /benchmarks/DASB/Google-speech-commands/keyword-spotting/ecapa_tdnn/prepare_GSC.py: -------------------------------------------------------------------------------- 1 | ../../prepare_GSC.py -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/emotion_recognition/ecapa_tdnn/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/quantization/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | tgt 3 | unidecode 4 | -------------------------------------------------------------------------------- /benchmarks/DASB/SLURP/intent_classification/LSTM_linear/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Makes scripts for testing importable, i.e., code in tests/utils.
2 | """ 3 | -------------------------------------------------------------------------------- /benchmarks/DASB/Google-speech-commands/keyword-spotting/Xvector/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/Google-speech-commands/keyword-spotting/ecapa_tdnn/custom_model.py: -------------------------------------------------------------------------------- 1 | ../../../model/custom_model.py -------------------------------------------------------------------------------- /benchmarks/DASB/DASB_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/benchmarks/DASB/DASB_logo.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "speechbrain"] 2 | path = speechbrain 3 | url = https://github.com/speechbrain/speechbrain.git 4 | -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt1.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt2.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt3.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt4.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt5.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk1_snt6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk1_snt6.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk2_snt1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt1.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk2_snt2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt2.wav -------------------------------------------------------------------------------- 
/tests/samples/ASR/spk2_snt3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt3.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk2_snt4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt4.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk2_snt5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt5.wav -------------------------------------------------------------------------------- /tests/samples/ASR/spk2_snt6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/tests/samples/ASR/spk2_snt6.wav -------------------------------------------------------------------------------- /lint-requirements.txt: -------------------------------------------------------------------------------- 1 | black==19.10b0 2 | click==8.0.4 3 | flake8==3.7.9 4 | pycodestyle==2.5.0 5 | pytest==5.4.1 6 | yamllint==1.23.0 7 | -------------------------------------------------------------------------------- /benchmarks/MOABB/hparams/orion/hparams_random_search.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | algorithms: 3 | random: 4 | seed: 1986 5 | -------------------------------------------------------------------------------- /benchmarks/MP3S/extra_requirements.txt: -------------------------------------------------------------------------------- 1 | kenlm # kenlm is only required for ASR recipes when decoding using a language model, which is not the default option. 
2 | -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/metrics/model_v8.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/benchmarks/DASB/Libri2Mix/separation/metrics/model_v8.onnx -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/metrics/model_v8.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/speechbrain/benchmarks/HEAD/benchmarks/DASB/VoiceBank/enhancement/metrics/model_v8.onnx -------------------------------------------------------------------------------- /tests/.run-HF-checks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -c 'from tests.utils.check_HF_repo import run_HF_check; print("TEST FAILED!") if not(run_HF_check()) else print("TEST PASSED!")' 3 | -------------------------------------------------------------------------------- /benchmarks/MOABB/extra-requirements.txt: -------------------------------------------------------------------------------- 1 | -e git+https://github.com/braindecode/braindecode.git#egg=braindecode 2 | mne 3 | moabb 4 | orion 5 | orion[profet] 6 | scikit-learn 7 | torchinfo 8 | -------------------------------------------------------------------------------- /benchmarks/MOABB/hparams/orion/hparams_tpe.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | algorithms: 3 | tpe: 4 | seed: 1986 5 | n_initial_points: 20 6 | n_ei_candidates: 24 7 | -------------------------------------------------------------------------------- /tests/.run-url-checks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install requests 3 | python -c 'from tests.utils.check_url import check_links; print("TEST FAILED!") if not(check_links()) else print("TEST PASSED!")' 4 | -------------------------------------------------------------------------------- /tests/.run-recipe-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -c 'from tests.utils.recipe_tests import run_recipe_tests; print("TEST FAILED!") if not(run_recipe_tests(run_opts="--device=cuda")) else print("TEST PASSED")' 3 | -------------------------------------------------------------------------------- /benchmarks/DASB/extra_requirements.txt: -------------------------------------------------------------------------------- 1 | beartype 2 | jsonlines 3 | librosa>=0.9.2 4 | onnxruntime>=1.16.3 5 | scikit-learn 6 | speechbrain>=1.0.0 7 | speechtokenizer>=0.1.2 8 | tensorboard 9 | tgt 10 | unidecode 11 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503 3 | # line length is intentionally set to 80 here because black uses Bugbear 4 | # See https://github.com/psf/black/blob/master/README.md#line-length for more details 5 | max-line-length = 80 6 | max-complexity = 18 7 | select = B,C,E,F,W,T4,B9 8 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/hparams/char_en.txt: -------------------------------------------------------------------------------- 1 | A 2 
| B 3 | C 4 | D 5 | E 6 | F 7 | G 8 | H 9 | I 10 | J 11 | K 12 | L 13 | M 14 | N 15 | O 16 | P 17 | Q 18 | R 19 | S 20 | T 21 | U 22 | V 23 | W 24 | X 25 | Y 26 | Z 27 | ' 28 | " 29 | ! 30 | ( 31 | ) 32 | , 33 | - 34 | . 35 | : 36 | ; 37 | ? 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 80 3 | target-version = ['py38'] 4 | exclude = ''' 5 | 6 | ( 7 | /( 8 | \.eggs # exclude a few common directories in the 9 | | \.git # root of the project 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | )/ 14 | ) 15 | ''' 16 | -------------------------------------------------------------------------------- /.yamllint.yaml: -------------------------------------------------------------------------------- 1 | extends: default 2 | 3 | rules: 4 | document-start: 5 | present: False 6 | truthy: 7 | allowed-values: 8 | - 'True' 9 | - 'False' 10 | level: error 11 | comments: disable 12 | comments-indentation: disable 13 | line-length: 14 | level: warning 15 | allow-non-breakable-inline-mappings: True 16 | -------------------------------------------------------------------------------- /tests/.run-linters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | 4 | echo "===Black===" 5 | git ls-files | grep -E "\.py$" | xargs black --check --diff 6 | echo "===Flake8===" 7 | git ls-files | grep -E "\.py$" | xargs flake8 --count --statistics 8 | echo "===Yamllint===" 9 | git ls-files | grep -E "\.yaml$|\.yml$" | xargs yamllint --no-warnings 10 | -------------------------------------------------------------------------------- /tests/consistency/test_docstrings.py: -------------------------------------------------------------------------------- 1 | """Tests for checking the docstrings of functions and classes. 2 | 3 | Authors 4 | * Mirco Ravanelli 2022 5 | """ 6 | from tests.utils.check_docstrings import check_docstrings 7 | 8 | 9 | def test_recipe_list(base_folder="."): 10 | check_folders = ["speechbrain", "tools", "templates"] 11 | assert check_docstrings(base_folder, check_folders) 12 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/hparams/arpabet.txt: -------------------------------------------------------------------------------- 1 | AA 2 | AE 3 | AH 4 | AO 5 | AW 6 | AY 7 | B 8 | CH 9 | D 10 | DH 11 | EH 12 | ER 13 | EY 14 | F 15 | G 16 | HH 17 | IH 18 | IY 19 | JH 20 | K 21 | L 22 | M 23 | N 24 | NG 25 | OW 26 | OY 27 | P 28 | R 29 | S 30 | SH 31 | T 32 | TH 33 | UH 34 | UW 35 | V 36 | W 37 | Y 38 | Z 39 | ZH 40 | ' 41 | " 42 | ! 43 | ( 44 | ) 45 | , 46 | - 47 | . 48 | : 49 | ; 50 | ? 
51 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: SpeechBrain pre-commit 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.12' 16 | - uses: pre-commit/action@v3.0.1 17 | -------------------------------------------------------------------------------- /tests/.run-doctests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | 4 | # To run doctests locally, the easiest approach is to do: 5 | # > pytest --doctest-modules speechbrain/ 6 | # However, we take this more complex approach to avoid testing files not 7 | # tracked by git. We filter out tests that require optional dependencies. 8 | avoid="transducer_loss.py\|fairseq_wav2vec.py\|huggingface_wav2vec.py\|bleu.py\|ctc_segmentation.py\|check_url.py\|huggingface_whisper.py" 9 | git ls-files speechbrain | grep -e "\.py$" | grep -v $avoid | xargs pytest --doctest-modules 10 | -------------------------------------------------------------------------------- /tests/.run-load-yaml-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install pesq 3 | pip install pystoi 4 | pip install librosa 5 | pip install tensorboard 6 | pip install transformers 7 | # avoid list: these yamls cause segfaults (on a GPU node) 8 | # pip install git+https://github.com/jfsantos/SRMRpy 9 | python -c 'from tests.utils.recipe_tests import load_yaml_test; print("TEST FAILED!") if not(load_yaml_test(avoid_list=["recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml", "recipes/Voicebank/dereverb/spectral_mask/hparams/train.yaml", "recipes/Voicebank/enhance/MetricGAN-U/hparams/train_dnsmos.yaml", "recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml", "recipes/Voicebank/enhance/spectral_mask/hparams/train.yaml", "recipes/Voicebank/enhance/waveform_map/hparams/train.yaml"])) else print("TEST PASSED")' 10 | -------------------------------------------------------------------------------- /benchmarks/MP3S/run_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | # Please consult the README.md file for instructions on how to run the benchmark. 
3 | 4 | hub='facebook/hubert-large-ll60k' 5 | num_layers='25' 6 | encoder_dim='1024' 7 | output_folder='/path/to/output' 8 | declare -a DatasetsFolders=('path/to/LibriSpeech' 'path/to/IEMOCAP') 9 | declare -a ConsideredTasks=('LibriSpeechASR' 'IEMOCAP') 10 | declare -a DownStreams=('BiLSTM' 'ecapa_tdnn') 11 | for i in "${!ConsideredTasks[@]}"; do 12 | task=${ConsideredTasks[i]} 13 | downstream=${DownStreams[i]} 14 | dataset_folder=${DatasetsFolders[i]} 15 | python $task/$downstream/train.py $task/$downstream/hparams/ssl.yaml --num_layers_ssl $num_layers --ssl_hub $hub --encoder_dim $encoder_dim --output_folder $output_folder/$task/$downstream --data_folder $dataset_folder 16 | done 17 | -------------------------------------------------------------------------------- /tests/utils/overrides.yaml: -------------------------------------------------------------------------------- 1 | LibriSpeech_data: !PLACEHOLDER 2 | CommonVoice_EN_data: !PLACEHOLDER 3 | CommonVoice_FR_data: !PLACEHOLDER 4 | IEMOCAP_data: !PLACEHOLDER 5 | 6 | new_interfaces_git: https://github.com/speechbrain/speechbrain 7 | new_interfaces_branch: hf-interface-testing 8 | new_interfaces_local_dir: tests/tmp/hf_interfaces 9 | 10 | # Filter HF repos (will be used in a local glob dir crawling) 11 | # glob_filter: "*wav2vec2*" 12 | # glob_filter: "*libri*" 13 | glob_filter: "*" 14 | 15 | # put False to test 'before' only, e.g. via override 16 | after: True 17 | 18 | LibriSpeech: 19 | data_folder: !ref <LibriSpeech_data> 20 | skip_prep: True 21 | 22 | CommonVoice_EN: 23 | data_folder: !ref <CommonVoice_EN_data> 24 | 25 | CommonVoice_FR: 26 | data_folder: !ref <CommonVoice_FR_data> 27 | 28 | IEMOCAP: 29 | data_folder: !ref <IEMOCAP_data> 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 # Use the ref you want to point at 4 | hooks: 5 | - id: trailing-whitespace 6 | types: [file, text] 7 | exclude: ".*char_[a-z]{2}.txt" 8 | - id: end-of-file-fixer 9 | types: [python] 10 | - id: requirements-txt-fixer 11 | - id: mixed-line-ending 12 | types: [python] 13 | args: [--fix=no] 14 | - id: check-added-large-files 15 | args: [--maxkb=1024] 16 | 17 | - repo: https://github.com/psf/black 18 | rev: 19.10b0 19 | hooks: 20 | - id: black 21 | types: [python] 22 | additional_dependencies: ['click==8.0.4'] 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 7.0.0 25 | hooks: 26 | - id: flake8 27 | types: [python] 28 | 29 | - repo: https://github.com/adrienverge/yamllint 30 | rev: v1.23.0 31 | hooks: 32 | - id: yamllint 33 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/train_encodec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Recipe for training a Text-to-Speech system based on tokenized audio - Encodec version 3 | 4 | Inspired by WhisperSpeech 5 | https://github.com/collabora/WhisperSpeech 6 | 7 | However, this is not an implementation of WhisperSpeech, but rather 8 | a radical simplification of it that uses only an acoustic model 9 | 10 | 11 | Authors 12 | * Artem Ploujnikov 2024 13 | """ 14 | 15 | from train import TokotronBrain, run_experiment 16 | from speechbrain.dataio.dataio import clean_padding_ 17 | 18 | 19 | class TokotronEncodecBrain(TokotronBrain): 20 | """Tokotron implementation for Encodec""" 21 | 22 | def create_waveform(self, audio,
length): 23 | """Creates a waveform from a discrete or continuous audio 24 | representation 25 | 26 | Arguments 27 | --------- 28 | audio : torch.Tensor 29 | An audio tensor (Batch x Length x Heads or Batch x Length x Heads x Features) 30 | length : torch.Tensor 31 | A 1-D tensor 32 | 33 | Returns 34 | ------- 35 | wav : torch.Tensor 36 | """ 37 | wav = self.modules.token_model.decode(audio) 38 | wav = wav.squeeze(1) 39 | clean_padding_(wav, length) 40 | return wav 41 | 42 | 43 | if __name__ == "__main__": 44 | run_experiment(TokotronEncodecBrain) 45 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/train_speech_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Recipe for training a Text-to-Speech system based on tokenized audio - SpeechTokenizer version 3 | 4 | Inspired by WhisperSpeech 5 | https://github.com/collabora/WhisperSpeech 6 | 7 | However, this is not an implementation of WhisperSpeech, but rather 8 | a radical simplification of it that uses only an acoustic model 9 | 10 | 11 | Authors 12 | * Artem Ploujnikov 2024 13 | """ 14 | 15 | from train import TokotronBrain, run_experiment 16 | from speechbrain.dataio.dataio import clean_padding_ 17 | 18 | 19 | class TokotronSTBrain(TokotronBrain): 20 | """Tokotron implementation for SpeechTokenizer""" 21 | 22 | def create_waveform(self, audio, length): 23 | """Creates a waveform from a discrete or continuous audio 24 | representation 25 | 26 | Arguments 27 | --------- 28 | audio : torch.Tensor 29 | An audio tensor (Batch x Length x Heads or Batch x Length x Heads x Features) 30 | length : torch.Tensor 31 | A 1-D tensor 32 | 33 | Returns 34 | ------- 35 | wav : torch.Tensor 36 | """ 37 | wav = self.modules.token_model.decode(audio) 38 | if length is not None: 39 | clean_padding_(wav, length) 40 | return wav 41 | 42 | 43 | if __name__ == "__main__": 44 | run_experiment(TokotronSTBrain) 45 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/train_continuous_ssl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Recipe for training a Text-to-Speech system based on tokenized audio 3 | Continuous SSL version 4 | 5 | Inspired by WhisperSpeech 6 | https://github.com/collabora/WhisperSpeech 7 | 8 | However, this is not an implementation of WhisperSpeech, but rather 9 | a radical simplification of it that uses only an acoustic model 10 | 11 | 12 | Authors 13 | * Artem Ploujnikov 2024 14 | """ 15 | 16 | from train import TokotronBrain, run_experiment 17 | from speechbrain.dataio.dataio import clean_padding_ 18 | 19 | 20 | class TokotronContinuousSSLBrain(TokotronBrain): 21 | """Tokotron implementation for continuous SSL representations""" 22 | 23 | def create_waveform(self, audio, length): 24 | """Creates a waveform from a discrete or continuous audio 25 | representation 26 | 27 | Arguments 28 | --------- 29 | audio : torch.Tensor 30 | An audio tensor (Batch x Length x Heads or Batch x Length x Heads x Features) 31 | length : torch.Tensor 32 | A 1-D tensor 33 | 34 | Returns 35 | ------- 36 | wav : torch.Tensor 37 | """ 38 | wav = self.modules.vocoder(audio) 39 | wav = wav.squeeze(1) 40 | clean_padding_(wav, length) 41 | return wav 42 | 43 | 44 | if __name__ == "__main__": 45 | run_experiment(TokotronContinuousSSLBrain) 46 | --------------------------------------------------------------------------------
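The training entry points above (and the DAC one further down) differ only in how `create_waveform` turns model outputs back into audio; everything else is inherited from the shared `TokotronBrain` in `train.py`. A hedged sketch of what adding a new token backend would involve, assuming the same interface (`TokotronMyBackendBrain` and the `token_model` module name are illustrative, not part of the recipes):

```python
# Hypothetical sketch of a new Tokotron backend, mirroring the recipes
# above: only the token-decoding step changes between backends.
from train import TokotronBrain, run_experiment
from speechbrain.dataio.dataio import clean_padding_


class TokotronMyBackendBrain(TokotronBrain):
    """Tokotron implementation for a hypothetical token backend."""

    def create_waveform(self, audio, length):
        # Decode tokens (Batch x Length x Heads) into waveforms, then
        # zero out the padded region so evaluation ignores it.
        wav = self.modules.token_model.decode(audio)
        wav = wav.squeeze(1)
        clean_padding_(wav, length)
        return wav


if __name__ == "__main__":
    run_experiment(TokotronMyBackendBrain)
```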
/tests/samples/annotation/ASR_train.csv: -------------------------------------------------------------------------------- 1 | ID, duration,start,stop,wav,clean_wav,noisy_wav,s1_wav,s2_wav,mix_wav,target,spk_id,ali,phn,char,wrd,text,transcript,semantics,command 2 | spk1_snt1,2.87,0,16000,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,spk1,tests/samples/ASR/spk1_snt1.pkl,dh ax cl ch ay l vcl d ao l m ow s cl t hh er cl t sil dh ax s m ao l vcl d ao vcl,t h e c h i l d a l m o s t h u r t t h e s m a l l d o g ,the child almost hurt the small dog,the child almost hurt the small dog,the child almost hurt the small dog,{'scenario': 'calendar'| 'action': 'set'| 'entities': []},{'scenario': 'calendar'| 'action': 'set'| 'entities': []} 3 | spk1_snt2,3.15,0,16000,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,spk1,tests/samples/ASR/spk1_snt2.pkl,vcl d r aa cl p dh ax cl t uw sil w eh n y uw ae vcl d dh ax f ih vcl g y er,d r o p t h e t u e w h e n y o u a d d t h e f i g u r e s,drop the tue when you add the figures,drop the tue when you add the figures,drop the tue when you add the figures,{'scenario': 'calendar'| 'action': 'set'| 'entities': []},{'scenario': 'calendar'| 'action': 'set'| 'entities': []} 4 | -------------------------------------------------------------------------------- /tests/samples/annotation/ASR_Buckeye.csv: -------------------------------------------------------------------------------- 1 | ID, duration,start_seg,end_seg,wav,clean_wav,noisy_wav,s1_wav,s2_wav,mix_wav,target,spk_id,ali,phn,char,wrd,text,transcript,semantics,command 2 | spk1_snt1,2.0,0,2.0,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,tests/samples/ASR/spk1_snt1.wav,spk1,tests/samples/ASR/spk1_snt1.pkl,dh ax cl ch ay l vcl d ao l m ow s cl t hh er cl t sil dh ax s m ao l vcl d ao vcl,t h e c h i l d a l m o s t h u r t t h e s m a l l d o g ,the child almost hurt the small dog,the child almost hurt the small dog,the child almost hurt the small dog,{'scenario': 'calendar'| 'action': 'set'| 'entities': []},{'scenario': 'calendar'| 'action': 'set'| 'entities': []} 3 | spk1_snt2,3.0,0,3.0,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,tests/samples/ASR/spk1_snt2.wav,spk1,tests/samples/ASR/spk1_snt2.pkl,vcl d r aa cl p dh ax cl t uw sil w eh n y uw ae vcl d dh ax f ih vcl g y er,d r o p t h e t u e w h e n y o u a d d t h e f i g u r e s,drop the tue when you add the figures,drop the tue when you add the figures,drop the tue when you add the figures,{'scenario': 'calendar'| 'action': 'set'| 'entities': []},{'scenario': 'calendar'| 'action': 'set'| 'entities': []} 4 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/train_dac.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3
2 | """Recipe for training a Text-to-Speech system based on tokenized audio - DAC version 3 | 4 | Inspired by WhisperSpeech 5 | https://github.com/collabora/WhisperSpeech 6 | 7 | However, this is not an implementation of WhisperSpeech, but rather 8 | a radical simplification of it that uses only an acoustic model 9 | 10 | 11 | Authors 12 | * Artem Ploujnikov 2024 13 | """ 14 | from train import TokotronBrain, run_experiment 15 | from speechbrain.dataio.dataio import clean_padding_ 16 | 17 | 18 | class TokotronDACBrain(TokotronBrain): 19 | """Tokotron implementation for Encodec""" 20 | 21 | def create_waveform(self, audio, length): 22 | """Creates a waveform from a discrete or continuous audio 23 | representation 24 | 25 | Arguments 26 | --------- 27 | audio : torch.Tensor 28 | An audio tensor (Batch x Length x Heads or Batch x Length x Heads x Features) 29 | lengths : torch.Tensor 30 | A 1-D tensor 31 | 32 | Returns 33 | ------- 34 | wav : torch.Tensor 35 | """ 36 | z, _, _ = self.modules.dac.quantizer.from_codes( 37 | audio.transpose(1, 2).int() 38 | ) 39 | wav = self.modules.dac.decode(z).squeeze(1) 40 | clean_padding_(wav, length) 41 | return wav 42 | 43 | 44 | if __name__ == "__main__": 45 | run_experiment(TokotronDACBrain) 46 | -------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/hparams/eval.yaml: -------------------------------------------------------------------------------- 1 | eval_sample_rate: 16000 2 | eval_samples: null 3 | eval_interval: 1 4 | eval_asr_type: whisper 5 | eval_asr_source: !apply:speechbrain.utils.hparams.choice 6 | value: !ref 7 | choices: 8 | encoder_decoder: speechbrain/asr-transformer-transformerlm-librispeech 9 | whisper: openai/whisper-small 10 | evaluations: utmos,asr 11 | tmp_folder: null 12 | utmos_batch_size: 8 13 | utmos_model_path: ./utmos 14 | utmos_ckpt_name: epoch=3-step=7459.ckpt 15 | utmos_ckpt_path: !ref / 16 | utmos_use_python: True 17 | utmos_script: predict.py 18 | 19 | 20 | eval_asr: !apply:speechbrain.utils.hparams.choice 21 | value: !ref 22 | choices: 23 | encoder_decoder: !name:eval.EncoderDecoderASRSpeechEvaluator 24 | source: !ref 25 | sample_rate: !ref 26 | overrides: 27 | lm_weight: 0.0 28 | whisper: !name:eval.WhisperASRSpeechEvaluator 29 | source: !ref 30 | sample_rate: !ref 31 | savedir: !ref 32 | 33 | evaluators: 34 | asr: !ref 35 | 36 | bulk_evaluators: 37 | utmos: !name:eval.UTMOSSpeechEvaluator 38 | model_path: !ref 39 | output_folder: !ref 40 | ckpt_path: !ref 41 | batch_size: !ref 42 | script: !ref 43 | use_python: !ref 44 | tmp_folder: !ref 45 | 46 | eval_summary: 47 | asr: 48 | descriptive: ["wer", "cer", "wer_ref", "cer_ref", "dwer", "dcer"] 49 | utmos: 50 | descriptive: ["utmos"] 51 | -------------------------------------------------------------------------------- /tests/PRE-RELEASE-TESTS.md: -------------------------------------------------------------------------------- 1 | # Pre-release Tests 2 | 3 | 1. Create a new environment. For instance, using conda: 4 | ``` 5 | conda create --name fresh_env python=3.9 6 | ``` 7 | 2. Activate the new environment 8 | ``` 9 | conda activate fresh_env 10 | ``` 11 | 3. Clone the dev version of SpeechBrain 12 | https://github.com/speechbrain/speechbrain 13 | 14 | 4. Install the extra-dependencies 15 | ``` 16 | cd speechbrain 17 | pip install -r requirements.txt 18 | ``` 19 | 5. Install SpeechBrain 20 | ``` 21 | pip install -e . 22 | ``` 23 | 6. 
/tests/PRE-RELEASE-TESTS.md: -------------------------------------------------------------------------------- 1 | # Pre-release Tests 2 | 3 | 1. Create a new environment. For instance, using conda: 4 | ``` 5 | conda create --name fresh_env python=3.9 6 | ``` 7 | 2. Activate the new environment 8 | ``` 9 | conda activate fresh_env 10 | ``` 11 | 3. Clone the dev version of SpeechBrain 12 | https://github.com/speechbrain/speechbrain 13 | 14 | 4. Install the extra-dependencies 15 | ``` 16 | cd speechbrain 17 | pip install -r requirements.txt 18 | ``` 19 | 5. Install SpeechBrain 20 | ``` 21 | pip install -e . 22 | ``` 23 | 6. Install all recipe extra-dependencies (check for latest/fixed versions) 24 | ``` 25 | find recipes | grep extra | xargs cat | sort -u | grep -v \# | xargs -I {} pip install {} 26 | pip install fairseq 27 | conda install 'ffmpeg<4.4' 28 | ``` 29 | 7. Run the basic tests by typing: 30 | ``` 31 | pytest 32 | ``` 33 | 8. Run load yaml test: 34 | ``` 35 | tests/.run-load-yaml-tests.sh 36 | ``` 37 | 9. Run recipe tests 38 | ``` 39 | tests/.run-recipe-tests.sh 40 | ``` 41 | 10. Make sure all HuggingFace repos are working 42 | ``` 43 | tests/.run-HF-checks.sh 44 | ``` 45 | 11. Check URLs 46 | ``` 47 | tests/.run-url-checks.sh 48 | ``` 49 | 50 | Make sure all the tests are passing. Also, make sure to check that the tutorials are working (we might set up an automatic test for that as well in the future). 51 | 52 | # Maintainer checks for releases 53 | 54 | Up until here, all the above madness should have settled. 55 | Commit logs outline what happened; features are summarized. 56 | 57 | _Note: this is a good point to check that https://speechbrain.github.io/ is up-to-date._ 58 | 59 | The task at hand is: 60 | * change the version number; 61 | * compile a change log, and 62 | * release the latest version on PyPI. 63 | 64 | Another CI/CD lifecycle begins. 65 | -------------------------------------------------------------------------------- /tests/consistency/test_yaml.py: -------------------------------------------------------------------------------- 1 | """Consistency check between yaml files and script files. 2 | 3 | Authors 4 | * Mirco Ravanelli 2022 5 | """ 6 | import os 7 | import csv 8 | from tests.consistency.test_recipe import __skip_list 9 | from tests.utils.check_yaml import check_yaml_vs_script 10 | 11 | 12 | def test_yaml_script_consistency(recipe_folder="tests/recipes"): 13 | """This test checks the consistency between yaml files (used to specify 14 | hyperparameters) and script files (that implement the training recipe). 15 | 16 | Arguments 17 | --------- 18 | recipe_folder : path 19 | Path of the folder with csv files containing the training scripts with their coupled 20 | yaml files (with columns called 'Hparam_file', 'Script_file', 'Data_prep_file') 21 | """ 22 | 23 | # Use this list to itemize special yaml for which we do not have to test 24 | avoid_check = [] 25 | 26 | # Loop over all recipe CSVs 27 | for recipe_csvfile in os.listdir(recipe_folder): 28 | if recipe_csvfile in __skip_list: 29 | continue 30 | with open( 31 | os.path.join(recipe_folder, recipe_csvfile), newline="" 32 | ) as csvfile: 33 | check = True 34 | reader = csv.DictReader( 35 | csvfile, delimiter=",", skipinitialspace=True 36 | ) 37 | for row in reader: 38 | 39 | # Avoid checks 40 | if row["Hparam_file"] in avoid_check: 41 | continue 42 | 43 | # Check yaml-script consistency 44 | if not ( 45 | check_yaml_vs_script(row["Hparam_file"], row["Script_file"]) 46 | ): 47 | check = False 48 | 49 | # Check module variables 50 | # if not (check_module_vars(row["Hparam_file"], row["Script_file"])): 51 | # check = False 52 | 53 | assert check 54 | --------------------------------------------------------------------------------
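For context, the `check_yaml_vs_script` helper imported above verifies that hyperparameters declared in a recipe's YAML file are actually referenced by the paired training script. A naive sketch of that kind of check, under the assumption that top-level YAML keys should appear in the script (the real `tests.utils.check_yaml` is more thorough):

```python
# Naive yaml-vs-script consistency check: flag top-level YAML keys that
# never appear in the training script. Illustrative sketch only.
import re


def naive_yaml_vs_script(hparam_file, script_file):
    with open(hparam_file) as f:
        keys = [
            line.split(":", 1)[0].strip()
            for line in f
            if re.match(r"^[A-Za-z_]\w*\s*:", line)
        ]
    with open(script_file) as f:
        script = f.read()
    missing = [key for key in keys if key not in script]
    for key in missing:
        print(f"{key} declared in {hparam_file} but not used in {script_file}")
    return not missing
```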
/benchmarks/DASB/run_discriminative_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Please consult the README.md file for instructions on how to run the benchmark. 3 | 4 | tokenizer_name=$1 5 | if [[ "$tokenizer_name" == "" ]]; then 6 | echo "Usage: run_discriminative_benchmark.sh <tokenizer_name>" 7 | exit 1 8 | fi 9 | 10 | output_folder='/path/to/output' 11 | declare -a DatasetsFolders=('path/to/LibriSpeech' 'path/to/CommonVoice' 'path/to/IEMOCAP' 'path/to/SLURP' 'path/to/Google-speech-commands' 'path/to/VoiceCeleb1') 12 | declare -a ConsideredTasks=('LibriSpeech/ASR' 'CommonVoice/ASR' 'IEMOCAP/emotion_recognition' 'SLURP/intent_classification' 'Google-speech-commands/keyword-spotting' 'VoiceCeleb1/speaker_ver') 13 | declare -a DownStreams=('LSTM' 'LSTM' 'ecapa_tdnn' 'LSTM_linear' 'Xvector' 'Xvector') 14 | declare -a Locales=('cy' 'eu') 15 | declare -a LocalesVobSize=(100 200) 16 | 17 | shift 18 | script_args="$@" 19 | 20 | for i in "${!ConsideredTasks[@]}"; do 21 | task=${ConsideredTasks[i]} 22 | downstream=${DownStreams[i]} 23 | dataset_folder=${DatasetsFolders[i]} 24 | recipe_extra_args="$script_args" 25 | set -- "$recipe_extra_args" 26 | if [[ "$task" == "CommonVoice/ASR" ]]; then 27 | echo "${tokenizer_name}/${task}/${downstream}" 28 | for j in "${!Locales[@]}"; do 29 | locale=${Locales[j]} 30 | vocab=${LocalesVobSize[j]} 31 | python $task/$downstream/train_$tokenizer_name.py $task/$downstream/hparams/train_$tokenizer_name.yaml --output_folder $output_folder/$tokenizer_name/$task/$downstream/$locale --data_folder $dataset_folder/$locale --language $locale --output_neurons $vocab $@ 32 | done 33 | else 34 | python $task/$downstream/train_$tokenizer_name.py $task/$downstream/hparams/train_$tokenizer_name.yaml --output_folder $output_folder/$tokenizer_name/$task/$downstream --data_folder $dataset_folder $@ 35 | fi 36 | done 37 | -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/quantization/README.md: -------------------------------------------------------------------------------- 1 | 2 | # K-means (Quantization) 3 | This folder contains recipes for training a K-means clustering model on the IEMOCAP dataset. 4 | The model quantizes self-supervised representations into discrete representations, which can then be used as discrete audio input for various tasks, including classification, ASR, and speech generation. 5 | It supports k-means models using features from HuBERT, WavLM, or Wav2Vec. 6 | 7 | You can download IEMOCAP at https://sail.usc.edu/iemocap/ 8 | 9 | ## Installing Extra Dependencies 10 | 11 | Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal: 12 | 13 | ``` 14 | pip install -r extra-requirements.txt 15 | ``` 16 | 17 | # How to run: 18 | ```shell 19 | python train.py hparams/train_with_{SSL_model}.yaml 20 | ``` 21 | 22 | # Results 23 | 24 | The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0). 25 | 26 | The checkpoints can also be found in [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository. 27 | 28 | 29 | 30 | # **About SpeechBrain** 31 | - Website: https://speechbrain.github.io/ 32 | - Code: https://github.com/speechbrain/speechbrain/ 33 | - HuggingFace: https://huggingface.co/speechbrain/ 34 | 35 | 36 | # **Citing SpeechBrain** 37 | Please cite SpeechBrain if you use it for your research or business.
38 | 
39 | ```bibtex
40 | @misc{speechbrain,
41 |   title={{SpeechBrain}: A General-Purpose Speech Toolkit},
42 |   author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
43 |   year={2021},
44 |   eprint={2106.04624},
45 |   archivePrefix={arXiv},
46 |   primaryClass={eess.AS},
47 |   note={arXiv:2106.04624}
48 | }
49 | ```
50 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/quantization/README.md: --------------------------------------------------------------------------------
1 | 
2 | # K-means (Quantization)
3 | This folder contains recipes for training a K-means clustering model on the LJSpeech dataset.
4 | The model quantizes self-supervised representations into discrete representations, which can then be used as discrete audio input for various tasks, including classification, ASR, and speech generation.
5 | It supports k-means models built on features from HuBERT, WavLM, or wav2vec 2.0.
6 | 
7 | You can download LJSpeech at https://keithito.com/LJ-Speech-Dataset/
8 | 
9 | ## Installing Extra Dependencies
10 | 
11 | Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
12 | 
13 | ```
14 | pip install -r extra-requirements.txt
15 | ```
16 | 
17 | # How to run:
18 | ```shell
19 | python train.py hparams/train_discrete_ssl.yaml  # pick the SSL model with --ssl_model_type {hubert,wavlm,wav2vec2}
20 | ```
21 | 
22 | # Results
23 | 
24 | The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
25 | 
26 | The checkpoints can also be found in [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
27 | 
28 | 
29 | 
30 | # **About SpeechBrain**
31 | - Website: https://speechbrain.github.io/
32 | - Code: https://github.com/speechbrain/speechbrain/
33 | - HuggingFace: https://huggingface.co/speechbrain/
34 | 
35 | 
36 | # **Citing SpeechBrain**
37 | Please cite SpeechBrain if you use it for your research or business.
38 | 
39 | ```bibtex
40 | @misc{speechbrain,
41 |   title={{SpeechBrain}: A General-Purpose Speech Toolkit},
42 |   author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
43 |   year={2021},
44 |   eprint={2106.04624},
45 |   archivePrefix={arXiv},
46 |   primaryClass={eess.AS},
47 |   note={arXiv:2106.04624}
48 | }
49 | ```
50 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/README.md: --------------------------------------------------------------------------------
1 | 
2 | # K-means (Quantization)
3 | This folder contains recipes for training a K-means clustering model on the LibriSpeech dataset.
4 | The model quantizes self-supervised representations into discrete representations, which can then be used as discrete audio input for various tasks, including classification, ASR, and speech generation.
5 | It supports k-means models built on features from HuBERT, WavLM, or wav2vec 2.0.
6 | 
7 | You can download LibriSpeech at http://www.openslr.org/12
8 | 
9 | ## Installing Extra Dependencies
10 | 
11 | Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
12 | 
13 | ```
14 | pip install -r extra-requirements.txt
15 | ```
16 | 
17 | # How to run:
18 | ```shell
19 | python train.py hparams/train_discrete_ssl.yaml  # pick the SSL model with --ssl_model_type {hubert,wavlm,wav2vec2}
20 | ```
21 | 
22 | # Results
23 | 
24 | The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
25 | 
26 | The checkpoints can also be found in [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
27 | 
28 | 
29 | 
30 | # **About SpeechBrain**
31 | - Website: https://speechbrain.github.io/
32 | - Code: https://github.com/speechbrain/speechbrain/
33 | - HuggingFace: https://huggingface.co/speechbrain/
34 | 
35 | 
36 | # **Citing SpeechBrain**
37 | Please cite SpeechBrain if you use it for your research or business.
38 | 
39 | ```bibtex
40 | @misc{speechbrain,
41 |   title={{SpeechBrain}: A General-Purpose Speech Toolkit},
42 |   author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
43 |   year={2021},
44 |   eprint={2106.04624},
45 |   archivePrefix={arXiv},
46 |   primaryClass={eess.AS},
47 |   note={arXiv:2106.04624}
48 | }
49 | ```
50 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/quantization/README.md: --------------------------------------------------------------------------------
1 | 
2 | # K-means (Quantization)
3 | This folder contains recipes for training a K-means clustering model on the VoxCeleb1 dataset.
4 | The model quantizes self-supervised representations into discrete representations, which can then be used as discrete audio input for various tasks, including classification, ASR, and speech generation.
5 | It supports k-means models built on features from HuBERT, WavLM, or wav2vec 2.0.
6 | 
7 | You can download VoxCeleb1 at https://www.robots.ox.ac.uk/~vgg/data/voxceleb/
8 | 
9 | ## Installing Extra Dependencies
10 | 
11 | Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:
12 | 
13 | ```
14 | pip install -r extra-requirements.txt
15 | ```
16 | 
17 | # How to run:
18 | ```shell
19 | python train.py hparams/train_discrete_ssl.yaml  # pick the SSL model with --ssl_model_type {hubert,wavlm,wav2vec2}
20 | ```
21 | 
22 | # Results
23 | 
24 | The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0).
25 | 
26 | The checkpoints can also be found in [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository.
27 | 28 | 29 | 30 | # **About SpeechBrain** 31 | - Website: https://speechbrain.github.io/ 32 | - Code: https://github.com/speechbrain/speechbrain/ 33 | - HuggingFace: https://huggingface.co/speechbrain/ 34 | 35 | 36 | # **Citing SpeechBrain** 37 | Please, cite SpeechBrain if you use it for your research or business. 38 | 39 | ```bibtex 40 | @misc{speechbrain, 41 | title={{SpeechBrain}: A General-Purpose Speech Toolkit}, 42 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, 43 | year={2021}, 44 | eprint={2106.04624}, 45 | archivePrefix={arXiv}, 46 | primaryClass={eess.AS}, 47 | note={arXiv:2106.04624} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/* 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # PyCharm 118 | .idea/ 119 | 120 | # Other 121 | random_idxes.txt 122 | tokenizer/ 123 | tokenizer.zip 124 | 125 | # Data 126 | CL-MASR/ 127 | 128 | # Results 129 | results/ 130 | -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/quantization/README.md: -------------------------------------------------------------------------------- 1 | 2 | # K-means (Quantization) 3 | This folder contains recipes for training K-means 
clustering model for the CommonVoice Dataset. 4 | The model serves to quantize self-supervised representations into discrete representation. Thus representations can be used as a discrete audio input for various tasks including classification, ASR and speech generation. 5 | It supports kmeans model using the features from HuBERT, WAVLM or Wav2Vec. 6 | 7 | You can download CommonVoice at https://commonvoice.mozilla.org/en 8 | 9 | ## Installing Extra Dependencies 10 | 11 | Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal: 12 | 13 | ``` 14 | pip install -r extra_requirements.txt 15 | ``` 16 | 17 | # How to run: 18 | ```shell 19 | python train.py hparams/train_with_{SSL_model}.yaml 20 | ``` 21 | 22 | # Results 23 | 24 | The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0). 25 | 26 | The checkpoints can be also found at [this](https://huggingface.co/speechbrain/SSL_Quantization) HuggingFace repository. 27 | 28 | 29 | 30 | # **About SpeechBrain** 31 | - Website: https://speechbrain.github.io/ 32 | - Code: https://github.com/speechbrain/speechbrain/ 33 | - HuggingFace: https://huggingface.co/speechbrain/ 34 | 35 | 36 | # **Citing SpeechBrain** 37 | Please, cite SpeechBrain if you use it for your research or business. 38 | 39 | ```bibtex 40 | @misc{speechbrain, 41 | title={{SpeechBrain}: A General-Purpose Speech Toolkit}, 42 | author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, 43 | year={2021}, 44 | eprint={2106.04624}, 45 | archivePrefix={arXiv}, 46 | primaryClass={eess.AS}, 47 | note={arXiv:2106.04624} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /benchmarks/DASB/run_generative_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Please consult the README.md file for instructions on how to run the benchmark. 
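# Example invocation (hypothetical values, mirroring run_discriminative_benchmark.sh;
# the tokenizer name selects train_<tokenizer_name>.py and its hparams file):
#
#   ./run_generative_benchmark.sh encodec --precision=fp16
#
# Everything after the tokenizer name is forwarded verbatim to every recipe.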
3 | 
4 | tokenizer_name=$1
5 | if [[ "$tokenizer_name" == "" ]]; then
6 |     echo "Usage: run_generative_benchmark.sh <tokenizer_name>"
7 |     exit 1
8 | fi
9 | 
10 | output_folder='path/to/output'
11 | librimix_path='path/to/Libri2Mix'
12 | voicebank_path='path/to/VoiceBank'
13 | ljspeech_path='path/to/ljspeech'
14 | utmos_path='path/to/utmos'
15 | tts_args="--token_list_file_text %recipe_root%/hparams/char_en.txt --utmos_model_path $utmos_path"
16 | 
17 | declare -a DatasetsFolders=(\
18 |     "$librimix_path" \
19 |     "$voicebank_path" \
20 |     "$ljspeech_path" \
21 |     "$ljspeech_path" \
22 | )
23 | declare -a ConsideredTasks=(\
24 |     'Libri2Mix/separation' \
25 |     'VoiceBank/enhancement' \
26 |     'LJSpeech/TTS' \
27 |     'LJSpeech/TTS' \
28 | )
29 | declare -a DownStreams=(\
30 |     'conformer' \
31 |     'conformer' \
32 |     'tokotron' \
33 |     'tokotron' \
34 | )
35 | declare -a ExtraArgs=(\
36 |     '' \
37 |     '' \
38 |     "$tts_args" \
39 |     "$tts_args --enc_num_layers 3 --dec_num_layers 6" \
40 | )
41 | 
42 | declare -a OutputSuffix=(\
43 |     '' \
44 |     '' \
45 |     '' \
46 |     '-small'
47 | )
48 | 
49 | shift
50 | script_args="$@"
51 | 
52 | for i in "${!ConsideredTasks[@]}"; do
53 |     task=${ConsideredTasks[i]}
54 |     downstream=${DownStreams[i]}
55 |     dataset_folder=${DatasetsFolders[i]}
56 |     extra_args=${ExtraArgs[i]}
57 |     suffix=${OutputSuffix[i]}
58 |     recipe_root="$task/$downstream"
59 |     recipe_extra_args="$script_args ${extra_args//%recipe_root%/$recipe_root}"
60 |     set -- "$recipe_extra_args"
61 |     echo "${tokenizer_name}/${task}/${downstream}"
62 |     python $task/$downstream/train_$tokenizer_name.py \
63 |         $task/$downstream/hparams/train_$tokenizer_name.yaml \
64 |         --output_folder $output_folder/$tokenizer_name/$task/$downstream$suffix \
65 |         --data_folder $dataset_folder \
66 |         $@
67 | done
68 | 
-------------------------------------------------------------------------------- /benchmarks/MOABB/models/BraindecodeNN.py: --------------------------------------------------------------------------------
1 | """Braindecode from https://braindecode.org/stable/index.html.
2 | Braindecode is an open-source Python toolbox for decoding raw electrophysiological brain data with
3 | deep learning models. It includes dataset fetchers, data preprocessing and visualization tools, as
4 | well as implementations of several deep learning architectures and data augmentations for analysis
5 | of EEG, ECoG and MEG.
6 | 
7 | This code is a SpeechBrain interface for the Braindecode models. This wrapper allows the usage of
8 | Braindecode models with the benchmarks pipeline for experiment reproducibility.
9 | 
10 | Note 1: We recommend installing braindecode from source to avoid compatibility issues:
11 | 
12 | ```bash
13 | pip install git+https://github.com/braindecode/braindecode.git#egg=braindecode
14 | ```
15 | 
16 | Note 2: Softmax is added to the model layer stack since NLL is used.
17 | 
18 | Authors
19 |  * Davide Borra, 2023
20 |  * Drew Wagner, 2024
21 |  * Victor Cruz, 2024
22 |  * Bruno Aristimunha, 2024
23 | """
24 | import torch
25 | from einops.layers.torch import Rearrange
26 | 
27 | 
28 | class BraindecodeNN(torch.nn.Module):
29 |     """Class for wrapping braindecode models.
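
    Expects input tensors shaped (batch, time, EEG channel, 1); the wrapper
    rearranges them to braindecode's (batch, EEG channel, time) layout, applies
    the wrapped model, and adds a LogSoftmax layer so the output can be used
    with NLL-based losses (see Note 2 in the module docstring).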
30 | 31 | Arguments 32 | --------- 33 | model: braindecode.model() 34 | Braindecode model class 35 | 36 | Example 37 | ------- 38 | >>> from benchmarks.MOABB.models.EEGConformer import EEGConformer 39 | >>> model = EEGConformer(input_shape=inp_tensor.shape) 40 | >>> model_braindecode = BraindecodeNN(model) 41 | """ 42 | 43 | def __init__(self, model): 44 | super().__init__() 45 | self.model = model 46 | self.input_layer = Rearrange("batch time chan 1 -> batch chan time") 47 | self.softmax = torch.nn.LogSoftmax(dim=1) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | """Returns the output of the model. 51 | 52 | Arguments 53 | --------- 54 | x : torch.Tensor (batch, time, EEG channel, channel) 55 | Input to convolve. 4d tensors are expected. 56 | """ 57 | # (batch, time_, EEG channel, channel) -> # (batch, EEG channel, time_, channel) 58 | x = self.input_layer(x) 59 | x = self.model(x) 60 | x = self.softmax(x) 61 | return x 62 | -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | """Differential WER (dWER) (see https://arxiv.org/abs/1911.07953). 2 | 3 | Authors 4 | * Luca Della Libera 2024 5 | """ 6 | 7 | import torch 8 | import torchaudio 9 | from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher 10 | from speechbrain.lobes.models.huggingface_transformers import Whisper 11 | from speechbrain.utils.metric_stats import ErrorRateStats, MetricStats 12 | 13 | 14 | __all__ = ["DWER"] 15 | 16 | 17 | SAMPLE_RATE = 16000 18 | 19 | 20 | class DWER(MetricStats): 21 | def __init__(self, model_hub, save_path, sample_rate): 22 | self.sample_rate = sample_rate 23 | self.model = Whisper( 24 | model_hub, save_path, SAMPLE_RATE, freeze=True, freeze_encoder=True, 25 | ).cpu() 26 | self.searcher = S2SWhisperGreedySearcher( 27 | self.model, min_decode_ratio=0.0, max_decode_ratio=1.0, 28 | ) 29 | self.model.tokenizer.set_prefix_tokens("english", "transcribe", False) 30 | self.wer_computer = ErrorRateStats() 31 | 32 | def clear(self): 33 | self.wer_computer.clear() 34 | 35 | @torch.no_grad() 36 | def append(self, ids, hyp_audio, ref_audio, lens=None): 37 | assert hyp_audio.shape == ref_audio.shape 38 | assert hyp_audio.ndim == 2 39 | 40 | # Concatenate 41 | audio = torch.cat([hyp_audio, ref_audio]) 42 | if lens is not None: 43 | lens = torch.cat([lens, lens]) 44 | 45 | # Resample 46 | audio = torchaudio.functional.resample( 47 | audio, self.sample_rate, SAMPLE_RATE 48 | ) 49 | 50 | self.model.to(hyp_audio.device) 51 | self.model.eval() 52 | 53 | # Forward 54 | enc_out = self.model.forward_encoder(self.model._get_mel(audio)) 55 | text, _, _, _ = self.searcher(enc_out, lens) 56 | text = self.model.tokenizer.batch_decode(text, skip_special_tokens=True) 57 | text = [self.model.tokenizer._normalize(x).split(" ") for x in text] 58 | hyp_text = text[: hyp_audio.shape[0]] 59 | ref_text = text[hyp_audio.shape[0] :] 60 | 61 | # Compute WER 62 | self.wer_computer.append(ids, hyp_text, ref_text) 63 | 64 | def summarize(self, field=None): 65 | return self.wer_computer.summarize(field) 66 | 67 | def write_stats(self, filestream, verbose=False): 68 | self.wer_computer.write_stats(filestream) 69 | -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/metrics/dwer.py: -------------------------------------------------------------------------------- 1 | """Differential WER 
(dWER) (see https://arxiv.org/abs/1911.07953).
2 | 
3 | Authors
4 |  * Luca Della Libera 2024
5 | """
6 | 
7 | import torch
8 | import torchaudio
9 | from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher
10 | from speechbrain.lobes.models.huggingface_transformers import Whisper
11 | from speechbrain.utils.metric_stats import ErrorRateStats, MetricStats
12 | 
13 | 
14 | __all__ = ["DWER"]
15 | 
16 | 
17 | SAMPLE_RATE = 16000
18 | 
19 | 
20 | class DWER(MetricStats):
21 |     def __init__(self, model_hub, save_path, sample_rate):
22 |         self.sample_rate = sample_rate
23 |         self.model = Whisper(
24 |             model_hub, save_path, SAMPLE_RATE, freeze=True, freeze_encoder=True,
25 |         ).cpu()
26 |         self.searcher = S2SWhisperGreedySearcher(
27 |             self.model, min_decode_ratio=0.0, max_decode_ratio=1.0,
28 |         )
29 |         self.model.tokenizer.set_prefix_tokens("english", "transcribe", False)
30 |         self.wer_computer = ErrorRateStats()
31 | 
32 |     def clear(self):
33 |         self.wer_computer.clear()
34 | 
35 |     @torch.no_grad()
36 |     def append(self, ids, hyp_audio, ref_audio, lens=None):
37 |         assert hyp_audio.shape == ref_audio.shape
38 |         assert hyp_audio.ndim == 2
39 | 
40 |         # Concatenate
41 |         audio = torch.cat([hyp_audio, ref_audio])
42 |         if lens is not None:
43 |             lens = torch.cat([lens, lens])
44 | 
45 |         # Resample
46 |         audio = torchaudio.functional.resample(
47 |             audio, self.sample_rate, SAMPLE_RATE
48 |         )
49 | 
50 |         self.model.to(hyp_audio.device)
51 |         self.model.eval()
52 | 
53 |         # Forward
54 |         enc_out = self.model.forward_encoder(self.model._get_mel(audio))
55 |         text, _, _, _ = self.searcher(enc_out, lens)
56 |         text = self.model.tokenizer.batch_decode(text, skip_special_tokens=True)
57 |         text = [self.model.tokenizer._normalize(x).split(" ") for x in text]
58 |         hyp_text = text[: hyp_audio.shape[0]]
59 |         ref_text = text[hyp_audio.shape[0] :]
60 | 
61 |         # Compute WER
62 |         self.wer_computer.append(ids, hyp_text, ref_text)
63 | 
64 |     def summarize(self, field=None):
65 |         return self.wer_computer.summarize(field)
66 | 
67 |     def write_stats(self, filestream, verbose=False):
68 |         self.wer_computer.write_stats(filestream)
69 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/utils/data.py: --------------------------------------------------------------------------------
1 | """Data utilities
2 | 
3 | Authors
4 |  * Artem Ploujnikov 2024
5 | """
6 | 
7 | import torch
8 | from speechbrain.dataio.batch import PaddedData
9 | 
10 | 
11 | def undo_batch(batch):
12 |     """Converts a padded batch or a dictionary to a list of
13 |     dictionaries. Any instances of PaddedData encountered will
14 |     be converted to plain tensors.
15 | 
16 |     Arguments
17 |     ---------
18 |     batch: dict|speechbrain.dataio.batch.PaddedBatch
19 |         the batch
20 | 
21 |     Returns
22 |     -------
23 |     result : list
24 |         a list of dictionaries, with one dictionary per batch
25 |         element
26 |     """
27 |     if hasattr(batch, "as_dict"):
28 |         batch = batch.as_dict()
29 |     keys = batch.keys()
30 |     return [
31 |         dict(zip(keys, item))
32 |         for item in zip(
33 |             *[_unpack_feature(feature) for feature in batch.values()]
34 |         )
35 |     ]
36 | 
37 | 
38 | def _unpack_feature(feature):
39 |     """Un-batches a single feature. If a PaddedData is provided, it will be converted
40 |     to a list of unpadded tensors. Otherwise, it is returned unmodified.
41 | 
42 |     Arguments
43 |     ---------
44 |     feature : any
45 |         The feature to un-batch
46 |     """
47 |     if isinstance(feature, PaddedData):
48 |         device = feature.data.device
49 |         feature = _undo_padding(feature.data, feature.lengths)
50 |         feature = [torch.tensor(item, device=device) for item in feature]
51 |     return feature
52 | 
53 | 
54 | # NOTE: Similar to the function in speechbrain.utils.data_utils
55 | # but it keeps values in tensor form
56 | def _undo_padding(batch, lengths):
57 |     """Produces Python lists given a batch of sentences with
58 |     their corresponding relative lengths.
59 | 
60 |     Arguments
61 |     ---------
62 |     batch : torch.Tensor
63 |         Batch of sentences gathered in a batch.
64 |     lengths : torch.Tensor
65 |         Relative length of each sentence in the batch.
66 | 
67 |     Returns
68 |     -------
69 |     as_list : list
70 |         A python list of the corresponding input tensor.
71 | 
72 |     Example
73 |     -------
74 |     >>> batch=torch.rand([4,100])
75 |     >>> lengths=torch.tensor([0.5,0.6,0.7,1.0])
76 |     >>> snt_list=_undo_padding(batch, lengths)
77 |     >>> len(snt_list)
78 |     4
79 |     """
80 |     batch_max_len = batch.shape[1]
81 |     as_list = []
82 |     for seq, seq_length in zip(batch, lengths):
83 |         actual_size = int(torch.round(seq_length * batch_max_len))
84 |         seq_true = seq[:actual_size]
85 |         as_list.append(seq_true)
86 |     return as_list
87 | 
88 | 
89 | def as_dict(batch):
90 |     """Converts a batch to a dictionary"""
91 |     return {key: getattr(batch, key) for key in batch._PaddedBatch__keys}
92 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/TTS/tokotron/train_discrete_ssl.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Recipe for training a Text-to-Speech system based on tokenized audio
3 | Discrete SSL version
4 | 
5 | Inspired by WhisperSpeech
6 | https://github.com/collabora/WhisperSpeech
7 | 
8 | However, this is not an implementation of WhisperSpeech, but rather
9 | a radical simplification of it that uses only an acoustic model
10 | 
11 | 
12 | Authors
13 |  * Artem Ploujnikov 2024
14 | """
15 | 
16 | import torch
17 | from train import TokotronBrain, run_experiment
18 | from speechbrain.dataio.dataio import clean_padding_
19 | 
20 | 
21 | class TokotronDiscreteSSLBrain(TokotronBrain):
22 |     """Tokotron implementation for discrete SSL tokens"""
23 | 
24 |     def on_stage_start(self, stage, epoch):
25 |         self.compute_offset()
26 |         return super().on_stage_start(stage, epoch)
27 | 
28 |     def compute_offset(self):
29 |         """Computes per-layer offsets"""
30 |         layers_set = set(self.hparams.token_model_layers)
31 |         available_layers_set = set(self.hparams.vocoder_available_layers)
32 |         if not layers_set.issubset(available_layers_set):
33 |             unavailable_layers = ",".join(
34 |                 str(layer) for layer in (layers_set - available_layers_set)
35 |             )
36 |             raise ValueError(f"Layers {unavailable_layers} are not supported")
37 |         self.num_units = self.hparams.audio_num_tokens
38 |         _, layers_idx = torch.where(
39 |             torch.tensor(
40 |                 self.hparams.vocoder_available_layers, device=self.device
41 |             ).unsqueeze(0)
42 |             == torch.tensor(
43 |                 self.hparams.token_model_layers, device=self.device
44 |             ).unsqueeze(1)
45 |         )
46 |         self.layer_offset = (
47 |             torch.tensor(layers_idx, device=self.device) * self.num_units
48 |         )[None, None, :]
49 |         self.offset = self.hparams.token_offset
50 |         self.modules.vocoder.tokenize = False
51 | 
52 |     def create_waveform(self, audio, length):
53 |         """Creates a waveform from a discrete or continuous audio
54 |         representation.
55 | 
56 |         Arguments
57 |         ---------
58 |         audio : torch.Tensor
59 |             An audio tensor (Batch x Length x Heads or Batch x Length x Heads x Features)
60 |         length : torch.Tensor
61 |             A 1-D tensor of relative lengths
62 | 
63 |         Returns
64 |         -------
65 |         wav : torch.Tensor
66 |         """
67 |         units_with_offset = (
68 |             audio + self.layer_offset.to(audio.device) + self.offset
69 |         )
70 |         wav = self.modules.vocoder(units_with_offset)
71 |         wav = wav.squeeze(1)
72 |         clean_padding_(wav, length)
73 |         return wav
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     run_experiment(TokotronDiscreteSSLBrain)
78 | 
-------------------------------------------------------------------------------- /benchmarks/DASB/LJSpeech/quantization/hparams/train_discrete_ssl.yaml: --------------------------------------------------------------------------------
1 | ################################
2 | # Recipe for Training K-Means Clustering on LJSpeech Data
3 | # Using Self-Supervised Model-Based Representations
4 | #
5 | # It is used for creating discrete audio representations from LJSpeech data.
6 | #
7 | # Author: Pooneh Mousavi (2023)
8 | ################################
9 | # Seed needs to be set at top of yaml, before objects with parameters are made
10 | seed: 1986
11 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
12 | output_folder: !ref results/LJSpeech/clustering/hubert/<seed>
13 | save_folder: !ref <output_folder>/save
14 | 
15 | # Data files
16 | data_folder: !PLACEHOLDER # e.g., /path/to/LJSpeech-1.1
17 | 
18 | train_json: !ref <save_folder>/train.json
19 | 
20 | splits: ["train"]
21 | split_ratio: [80]
22 | skip_prep: False
23 | sample_rate: 16000
24 | 
25 | # ssl_model_type: hubert, wavlm, wav2vec2
26 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large
27 | ssl_model_type: hubert # hubert, wavlm or wav2vec2
28 | ssl_hub: facebook/hubert-large-ll60k
29 | ssl_folder: !ref <save_folder>/ssl_checkpoint
30 | freeze_feature_extractor: True
31 | freeze_ssl: True
32 | ssl_layer_num: 7
33 | batch_size: 128 # batch size for loading and extracting features (different from kmeans_batch_size)
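# checkpoint_interval: save an intermediate k-means checkpoint every N
# feature-extraction batches (assumption based on how the recipe uses it)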
34 | checkpoint_interval: 100 35 | 36 | 37 | # Dataloader options 38 | train_dataloader_opts: 39 | batch_size: !ref 40 | drop_last: True 41 | 42 | ssl_model: !apply:speechbrain.utils.hparams.choice 43 | value: !ref 44 | choices: 45 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 46 | source: !ref 47 | output_norm: False 48 | freeze: !ref 49 | freeze_feature_extractor: !ref 50 | output_all_hiddens: True 51 | save_path: !ref 52 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 53 | source: !ref 54 | output_norm: False 55 | freeze: !ref 56 | freeze_feature_extractor: !ref 57 | output_all_hiddens: True 58 | save_path: !ref 59 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 60 | source: !ref 61 | output_norm: False 62 | freeze: !ref 63 | freeze_feature_extractor: !ref 64 | output_all_hiddens: True 65 | save_path: !ref 66 | 67 | #################### 68 | # Model Parameters # 69 | #################### 70 | num_clusters: 128 71 | init: k-means++ 72 | max_iter: 100 73 | kmeans_batch_size: 1000 # should be >= num_clusters 74 | tol: 0.0 75 | max_no_improvement: 100 76 | n_init: 20 77 | reassignment_ratio: 0.0 78 | -------------------------------------------------------------------------------- /tests/consistency/README.md: -------------------------------------------------------------------------------- 1 | # How to add your recipe in tests/recipes 2 | The folder `tests/recipes` is introduced for tracking all the recipes and their connected resources (e.g., HuggingFace repo, README files, recipe folders, etc). 3 | Each CSV file in that folder corresponds to one recipe dataset and enlists depending recipe tests. 4 | 5 | When you write a new recipe (e.g., recipes/your_dataset/) you need to: 6 | 1. ensure the CSV file `tests/recipes/your_dataset.csv` exists (simply copy the header from another CSV) 7 | 2. add a new line to the `tests/recipes/your_dataset.csv`. 8 | 9 | More specifically, you have to fill the following fields: 10 | 11 | - Task (mandatory): 12 | The task that the recipe is addressing (e.g.. `ASR`). 13 | - Dataset (mandatory): 14 | Dataset of the recipe (e.g. `LibriSpeech`). 15 | - Script_file (mandatory): 16 | Training script of the recipe (e.g., `recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py`) 17 | - Hparam_file (mandatory): 18 | Hyperparameter file of the recipe (e.g., `recipes/LibriSpeech/ASR/CTC/hparams/train_with_wav2vec.yaml`) 19 | - Data_prep_file (optional): 20 | Data preparation file (e.g., `recipes/LibriSpeech/librispeech_prepare.py`) 21 | - Readme_file (mandatory): 22 | Readme file describing the recipe (e.g., `recipes/LibriSpeech/ASR/CTC/README.md`) 23 | - Result_url (mandatory): 24 | URL where the output folder is stored (e.g., `https://drive.google.com/drive/folders/1pg0QzW-LqAISG8Viw_lUTGjXwOqh7gkl?usp=sharing` ). 25 | Note that with SpeechBrain we would like to make available the full output folder to the users. The output folder contains the logs, checkpoints, etc that help users debug and reproduce the results. 26 | Make sure this URL is mentioned in the README file. 27 | - HF_repo (optional): 28 | Link to the HuggingFace repository containing the pre-trained model (e.g., `https://huggingface.co/speechbrain/asr-wav2vec2-librispeech`). If specified, it must be mentioned in the README file. 29 | - test_debug_flags (optional): 30 | This optional field reports the flags to run recipe tests (see `tests/.run-recipe-tests.sh`). 
The goal of the recipe tests is to run an experiment with a tiny dataset and 1-make sure it runs 2-make sure it overfits properly. 31 | For instance, `--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=10 --skip_prep=True` will run an experiment with the given train and hparams files using a tiny dataset (ASR_train.csv) 32 | - test_debug_checks (optional) 33 | Checks if the recipe test produces the expected output. For instance,`file_exists=[env.log,hyperparams.yaml,log.txt,train_log.txt,train.py,wer_ASR_train.txt,save/lm.ckpt,save/tokenizer.ckpt] performance_check=[train_log.txt, train loss, <350, epoch: 10]` will first checks if the files in the file_exists list have been created and then checks if the training loss reported in train_log.txt is below a certain threshold. 34 | -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceCeleb1/quantization/hparams/train_discrete_ssl.yaml: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Recipe for Training K-Means Clustering on LibriSpeech Data 3 | # Using Self-Supervised Model-Based Representations 4 | # 5 | # It is used for creating discrete audio representations from LibriSpeech data. 6 | # 7 | # Author: Pooneh Mousavi (2023) 8 | ################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | seed: 1986 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | output_folder: !ref results/VoxCeleb/clustering/hubert/ 13 | save_folder: !ref /save 14 | 15 | # Data files 16 | data_folder: !PLACEHOLDER # e.g. /path/to/Voxcele 17 | train_annotation: !ref /train.csv 18 | verification_file: https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt 19 | skip_prep: False 20 | split_ratio: [90] 21 | sample_rate: 16000 22 | sentence_len: 3.0 # seconds 23 | shuffle: True 24 | random_chunk: True 25 | 26 | # ssl_model_type: hubert, wavlm, wav2vec2 27 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large 28 | ssl_model_type: hubert # hubert, wavml or wav2vec2 29 | ssl_hub: facebook/hubert-large-ll60k 30 | ssl_folder: !ref /ssl_checkpoint 31 | freeze_feature_extractor: True 32 | freeze_ssl: True 33 | ssl_layer_num: 7 34 | batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size. 
35 | checkpoint_interval: 100 36 | 37 | sorting: descending 38 | 39 | # Dataloader options 40 | train_dataloader_opts: 41 | batch_size: !ref 42 | drop_last: True 43 | 44 | ssl_model: !apply:speechbrain.utils.hparams.choice 45 | value: !ref 46 | choices: 47 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 48 | source: !ref 49 | output_norm: False 50 | freeze: !ref 51 | freeze_feature_extractor: !ref 52 | output_all_hiddens: True 53 | save_path: !ref 54 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 55 | source: !ref 56 | output_norm: False 57 | freeze: !ref 58 | freeze_feature_extractor: !ref 59 | output_all_hiddens: True 60 | save_path: !ref 61 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 62 | source: !ref 63 | output_norm: False 64 | freeze: !ref 65 | freeze_feature_extractor: !ref 66 | output_all_hiddens: True 67 | save_path: !ref 68 | 69 | #################### 70 | # Model Parameters # 71 | #################### 72 | num_clusters: 128 73 | init: k-means++ 74 | max_iter: 100 75 | kmeans_batch_size: 1000 # should be >= num_clusters 76 | tol: 0.0 77 | max_no_improvement: 100 78 | n_init: 20 79 | reassignment_ratio: 0.0 80 | -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/hparams/train_discrete_ssl.yaml: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Recipe for Training K-Means Clustering on LibriSpeech Data 3 | # Using Self-Supervised Model-Based Representations 4 | # 5 | # It is used for creating discrete audio representations from LibriSpeech data. 6 | # 7 | # Author: Pooneh Mousavi (2023) 8 | ################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | seed: 1986 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | output_folder: !ref results/LibriSpeech/clustering/hubert/ 13 | save_folder: !ref /save 14 | 15 | # Data files 16 | data_folder: data/test-clean/LibriSpeech # e,g./path/to/LibriSpeech 17 | train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] 18 | dev_splits: [] 19 | test_splits: [] 20 | skip_prep: False 21 | ckpt_interval_minutes: 25 # save checkpoint every N min 22 | train_csv: !ref /train.csv 23 | sample_rate: 16000 24 | 25 | # ssl_model_type: hubert, wavlm, wav2vec2 26 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large 27 | ssl_model_type: hubert # hubert, wavml or wav2vec2 28 | ssl_hub: facebook/hubert-large-ll60k 29 | ssl_folder: !ref /ssl_checkpoint 30 | freeze_feature_extractor: True 31 | freeze_ssl: True 32 | ssl_layer_num: 7 33 | batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size. 
34 | checkpoint_interval: 100 35 | 36 | sorting: ascending 37 | 38 | # Dataloader options 39 | train_dataloader_opts: 40 | batch_size: !ref 41 | drop_last: True 42 | 43 | ssl_model: !apply:speechbrain.utils.hparams.choice 44 | value: !ref 45 | choices: 46 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 47 | source: !ref 48 | output_norm: False 49 | freeze: !ref 50 | freeze_feature_extractor: !ref 51 | output_all_hiddens: True 52 | save_path: !ref 53 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 54 | source: !ref 55 | output_norm: False 56 | freeze: !ref 57 | freeze_feature_extractor: !ref 58 | output_all_hiddens: True 59 | save_path: !ref 60 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 61 | source: !ref 62 | output_norm: False 63 | freeze: !ref 64 | freeze_feature_extractor: !ref 65 | output_all_hiddens: True 66 | save_path: !ref 67 | 68 | #################### 69 | # Model Parameters # 70 | #################### 71 | num_clusters: 128 72 | init: k-means++ 73 | max_iter: 100 74 | kmeans_batch_size: 1000 # should be >= num_clusters 75 | tol: 0.0 76 | max_no_improvement: 100 77 | n_init: 20 78 | reassignment_ratio: 0.0 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | tests/tmp/ 55 | tests/download/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | docs/source/*.rst 77 | !docs/source/index.rst 78 | !docs/source/_templates 79 | !docs/source/_static 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 
101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # PyCharm project settings 123 | .idea 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Audio folders 147 | **/audio_cache/ 148 | 149 | # Pretrained & models folders 150 | **/model_checkpoints/ 151 | **/pretrained_model_checkpoints/ 152 | **/pretrained_models/ 153 | 154 | # Results folders 155 | **/results/ 156 | 157 | # Log folders 158 | **/log/ 159 | 160 | # Mac OS 161 | .DS_Store -------------------------------------------------------------------------------- /benchmarks/DASB/IEMOCAP/quantization/hparams/train_discrete_ssl.yaml: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Recipe for Training K-Means Clustering on IEMOCAP Data 3 | # Using Self-Supervised Model-Based Representations 4 | # 5 | # It is used for creating discrete audio representations from IEMOCAP data. 6 | # 7 | # Author: Pooneh Mousavi (2024) 8 | ################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | seed: 1986 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | output_folder: !ref results/IEMOCAP/clustering/hubert/ 13 | save_folder: !ref /save 14 | 15 | # Data files 16 | # Dataset will be downloaded to the `data_original` 17 | data_folder: !PLACEHOLDER # e.g., /path/to/IEMOCAP_full_release 18 | 19 | # different speakers for train, valid and test sets 20 | different_speakers: False 21 | # which speaker is used for test set, value from 1 to 10 22 | test_spk_id: 1 23 | # Path where data manifest files will be stored 24 | train_annotation: !ref /train.json 25 | valid_annotation: !ref /valid.json 26 | test_annotation: !ref /test.json 27 | split_ratio: [80, 10, 10] 28 | skip_prep: False 29 | sample_rate: 16000 30 | 31 | # ssl_model_type: hubert, wavlm, wav2vec2 32 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large 33 | ssl_model_type: hubert # hubert, wavml or wav2vec2 34 | ssl_hub: facebook/hubert-large-ll60k 35 | ssl_folder: !ref /ssl_checkpoint 36 | freeze_feature_extractor: True 37 | freeze_ssl: True 38 | ssl_layer_num: 7 39 | batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size. 
40 | checkpoint_interval: 100 41 | 42 | # Dataloader options 43 | train_dataloader_opts: 44 | batch_size: !ref 45 | drop_last: True 46 | 47 | ssl_model: !apply:speechbrain.utils.hparams.choice 48 | value: !ref 49 | choices: 50 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 51 | source: !ref 52 | output_norm: False 53 | freeze: !ref 54 | freeze_feature_extractor: !ref 55 | output_all_hiddens: True 56 | save_path: !ref 57 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 58 | source: !ref 59 | output_norm: False 60 | freeze: !ref 61 | freeze_feature_extractor: !ref 62 | output_all_hiddens: True 63 | save_path: !ref 64 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 65 | source: !ref 66 | output_norm: False 67 | freeze: !ref 68 | freeze_feature_extractor: !ref 69 | output_all_hiddens: True 70 | save_path: !ref 71 | 72 | #################### 73 | # Model Parameters # 74 | #################### 75 | num_clusters: 128 76 | init: k-means++ 77 | max_iter: 100 78 | kmeans_batch_size: 1000 # should be >= num_clusters 79 | tol: 0.0 80 | max_no_improvement: 100 81 | n_init: 20 82 | reassignment_ratio: 0.0 83 | -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/hparams/train_subwording.yaml: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Recipe for Training BPE tokenizer on discrete SSL tokens 3 | # Author: Pooneh Mousavi (2024) 4 | ################################ 5 | # Seed needs to be set at top of yaml, before objects with parameters are made 6 | seed: 1986 7 | __set_seed: !apply:torch.manual_seed [!ref ] 8 | output_folder: !ref results/LibriSpeech/subwording/discrete-ssl-bpe/ 9 | save_folder: !ref /save 10 | 11 | 12 | # Data files 13 | data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech 14 | train_splits: ["train-clean-100"] 15 | dev_splits: [] 16 | test_splits: [] 17 | skip_prep: False 18 | ckpt_interval_minutes: 25 # save checkpoint every N min 19 | train_csv: !ref /train.csv 20 | sample_rate: 16000 21 | tokenized_train: !ref /tokenized_train.csv 22 | vocab_size: 1000 23 | unk_id: 1 24 | pad_id: 0 25 | 26 | # ssl_model_type: hubert, wavlm, wav2vec2 27 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large 28 | ssl_model_type: hubert # hubert, wavml or wav2vec2 29 | ssl_hub: facebook/hubert-large-ll60k 30 | ssl_folder: !ref /ssl_checkpoint 31 | kmeans_repo_id: speechbrain/SSL_Quantization 32 | kmeans_cache_dir: !ref /kmeans_checkpoint 33 | kmeans_dataset: LibriSpeech-100-360-500 34 | num_clusters: 800 35 | freeze_ssl: True 36 | freeze_feature_extractor: True 37 | # Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) 38 | ssl_layer_num: [7, 23] 39 | deduplicate: [False, False] 40 | bpe_tokenizer_path: [null, null] 41 | 42 | 43 | tokenizer_config: 44 | SSL_layers: !ref 45 | deduplicates: !ref 46 | bpe_tokenizers: !ref 47 | 48 | 49 | ssl_model: !apply:speechbrain.utils.hparams.choice 50 | value: !ref 51 | choices: 52 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 53 | source: !ref 54 | output_norm: False 55 | freeze: !ref 56 | freeze_feature_extractor: !ref 57 | output_all_hiddens: True 58 | save_path: !ref 59 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 60 | source: !ref 61 | output_norm: False 
62 | freeze: !ref 63 | freeze_feature_extractor: !ref 64 | output_all_hiddens: True 65 | save_path: !ref 66 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 67 | source: !ref 68 | output_norm: False 69 | freeze: !ref 70 | freeze_feature_extractor: !ref 71 | output_all_hiddens: True 72 | save_path: !ref 73 | 74 | discrete_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL 75 | save_path: !ref 76 | ssl_model: !ref 77 | kmeans_dataset: !ref 78 | kmeans_repo_id: !ref 79 | num_clusters: !ref 80 | -------------------------------------------------------------------------------- /benchmarks/DASB/CommonVoice/quantization/hparams/train_discrete_ssl.yaml: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Recipe for Training K-Means Clustering on CommonVoice Data 3 | # Using Self-Supervised Model-Based Representations 4 | # 5 | # It is used for creating discrete audio representations from CommonVoice data. 6 | # 7 | # Author: Pooneh Mousavi (2023) 8 | ################################ 9 | # Seed needs to be set at top of yaml, before objects with parameters are made 10 | seed: 1986 11 | __set_seed: !apply:torch.manual_seed [!ref ] 12 | output_folder: !ref results/CommonVoice/clustering/hubert/ 13 | save_folder: !ref /save 14 | 15 | # Data files 16 | data_folder: !PLACEHOLDER # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr 17 | train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files 18 | accented_letters: False 19 | language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english 20 | train_csv: !ref /train.csv 21 | skip_prep: False # Skip data preparation 22 | sample_rate: 16000 23 | 24 | # We remove utterance slonger than 10s in the train/dev/test sets as 25 | # longer sentences certainly correspond to "open microphones". 26 | avoid_if_longer_than: 10.0 27 | 28 | # ssl_model_type: hubert, wavlm, wav2vec2 29 | # ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large 30 | ssl_model_type: hubert # hubert, wavml or wav2vec2 31 | ssl_hub: facebook/hubert-large-ll60k 32 | freeze_feature_extractor: True 33 | freeze_ssl: True 34 | ssl_folder: !ref /ssl_checkpoint 35 | ssl_layer_num: 7 36 | batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size. 
37 | dataloader_num_workers: 8 38 | sorting: ascending 39 | checkpoint_interval: 100 40 | 41 | 42 | # Dataloader options 43 | dataloader_options: 44 | batch_size: !ref 45 | num_workers: !ref 46 | drop_last: True 47 | 48 | ssl_model: !apply:speechbrain.utils.hparams.choice 49 | value: !ref 50 | choices: 51 | wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM 52 | source: !ref 53 | output_norm: False 54 | freeze: !ref 55 | freeze_feature_extractor: !ref 56 | output_all_hiddens: True 57 | save_path: !ref 58 | hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT 59 | source: !ref 60 | output_norm: False 61 | freeze: !ref 62 | freeze_feature_extractor: !ref 63 | output_all_hiddens: True 64 | save_path: !ref 65 | wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 66 | source: !ref 67 | output_norm: False 68 | freeze: !ref 69 | freeze_feature_extractor: !ref 70 | output_all_hiddens: True 71 | save_path: !ref 72 | 73 | 74 | #################### 75 | # Model Parameters # 76 | #################### 77 | num_clusters: 128 78 | init: k-means++ 79 | max_iter: 100 80 | kmeans_batch_size: 1000 # should be >= num_clusters 81 | tol: 0.0 82 | max_no_improvement: 100 83 | n_init: 20 84 | reassignment_ratio: 0.0 85 | -------------------------------------------------------------------------------- /benchmarks/DASB/Libri2Mix/separation/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | """Cosine similarity between speaker embeddings. 2 | 3 | Authors 4 | * Luca Della Libera 2024 5 | """ 6 | 7 | import torch 8 | import torchaudio 9 | from speechbrain.dataio.dataio import length_to_mask 10 | from speechbrain.inference.speaker import SpeakerRecognition 11 | from speechbrain.utils.metric_stats import MetricStats 12 | from transformers import AutoModelForAudioXVector 13 | 14 | 15 | __all__ = ["SpkSimECAPATDNN", "SpkSimWavLM"] 16 | 17 | 18 | SAMPLE_RATE = 16000 19 | 20 | 21 | class SpkSimECAPATDNN(MetricStats): 22 | def __init__(self, model_hub, save_path, sample_rate): 23 | self.sample_rate = sample_rate 24 | self.model = SpeakerRecognition.from_hparams( 25 | model_hub, savedir=save_path 26 | ).cpu() 27 | self.clear() 28 | 29 | @torch.no_grad() 30 | def append(self, ids, hyp_audio, ref_audio, lens=None): 31 | assert hyp_audio.shape == ref_audio.shape 32 | assert hyp_audio.ndim == 2 33 | 34 | # Concatenate 35 | audio = torch.cat([hyp_audio, ref_audio]) 36 | if lens is not None: 37 | lens = torch.cat([lens, lens]) 38 | 39 | # Resample 40 | audio = torchaudio.functional.resample( 41 | audio, self.sample_rate, SAMPLE_RATE 42 | ) 43 | 44 | self.model.device = hyp_audio.device 45 | self.model.to(hyp_audio.device) 46 | self.model.eval() 47 | 48 | # Forward 49 | embs = self.model.encode_batch(audio, lens, normalize=False) 50 | hyp_embs, ref_embs = embs.split([len(hyp_audio), len(ref_audio)]) 51 | scores = self.model.similarity(hyp_embs, ref_embs)[:, 0] 52 | 53 | self.ids += ids 54 | self.scores += scores.cpu().tolist() 55 | 56 | 57 | class SpkSimWavLM(MetricStats): 58 | def __init__(self, model_hub, save_path, sample_rate): 59 | self.sample_rate = sample_rate 60 | self.model = AutoModelForAudioXVector.from_pretrained( 61 | model_hub, cache_dir=save_path 62 | ) 63 | self.clear() 64 | 65 | @torch.no_grad() 66 | def append(self, ids, hyp_audio, ref_audio, lens=None): 67 | assert hyp_audio.shape == ref_audio.shape 68 | assert hyp_audio.ndim == 2 69 | 70 | # Concatenate 71 | audio = 
torch.cat([hyp_audio, ref_audio]) 72 | if lens is not None: 73 | lens = torch.cat([lens, lens]) 74 | 75 | # Resample 76 | audio = torchaudio.functional.resample( 77 | audio, self.sample_rate, SAMPLE_RATE 78 | ) 79 | 80 | self.model.to(hyp_audio.device) 81 | self.model.eval() 82 | 83 | # Attention mask 84 | attention_mask = None 85 | if lens is not None: 86 | abs_length = lens * audio.shape[-1] 87 | attention_mask = length_to_mask( 88 | abs_length.int() 89 | ).long() # 0 for masked tokens 90 | 91 | # Forward 92 | embs = self.model( 93 | input_values=audio, 94 | attention_mask=attention_mask, 95 | output_attentions=False, 96 | ).embeddings 97 | 98 | hyp_embs, ref_embs = embs.split([len(hyp_audio), len(ref_audio)]) 99 | scores = torch.nn.functional.cosine_similarity( 100 | hyp_embs, ref_embs, dim=-1 101 | ) 102 | 103 | self.ids += ids 104 | self.scores += scores.cpu().tolist() 105 | -------------------------------------------------------------------------------- /benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py: -------------------------------------------------------------------------------- 1 | """Cosine similarity between speaker embeddings. 2 | 3 | Authors 4 | * Luca Della Libera 2024 5 | """ 6 | 7 | import torch 8 | import torchaudio 9 | from speechbrain.dataio.dataio import length_to_mask 10 | from speechbrain.inference.speaker import SpeakerRecognition 11 | from speechbrain.utils.metric_stats import MetricStats 12 | from transformers import AutoModelForAudioXVector 13 | 14 | 15 | __all__ = ["SpkSimECAPATDNN", "SpkSimWavLM"] 16 | 17 | 18 | SAMPLE_RATE = 16000 19 | 20 | 21 | class SpkSimECAPATDNN(MetricStats): 22 | def __init__(self, model_hub, save_path, sample_rate): 23 | self.sample_rate = sample_rate 24 | self.model = SpeakerRecognition.from_hparams( 25 | model_hub, savedir=save_path 26 | ).cpu() 27 | self.clear() 28 | 29 | @torch.no_grad() 30 | def append(self, ids, hyp_audio, ref_audio, lens=None): 31 | assert hyp_audio.shape == ref_audio.shape 32 | assert hyp_audio.ndim == 2 33 | 34 | # Concatenate 35 | audio = torch.cat([hyp_audio, ref_audio]) 36 | if lens is not None: 37 | lens = torch.cat([lens, lens]) 38 | 39 | # Resample 40 | audio = torchaudio.functional.resample( 41 | audio, self.sample_rate, SAMPLE_RATE 42 | ) 43 | 44 | self.model.device = hyp_audio.device 45 | self.model.to(hyp_audio.device) 46 | self.model.eval() 47 | 48 | # Forward 49 | embs = self.model.encode_batch(audio, lens, normalize=False) 50 | hyp_embs, ref_embs = embs.split([len(hyp_audio), len(ref_audio)]) 51 | scores = self.model.similarity(hyp_embs, ref_embs)[:, 0] 52 | 53 | self.ids += ids 54 | self.scores += scores.cpu().tolist() 55 | 56 | 57 | class SpkSimWavLM(MetricStats): 58 | def __init__(self, model_hub, save_path, sample_rate): 59 | self.sample_rate = sample_rate 60 | self.model = AutoModelForAudioXVector.from_pretrained( 61 | model_hub, cache_dir=save_path 62 | ) 63 | self.clear() 64 | 65 | @torch.no_grad() 66 | def append(self, ids, hyp_audio, ref_audio, lens=None): 67 | assert hyp_audio.shape == ref_audio.shape 68 | assert hyp_audio.ndim == 2 69 | 70 | # Concatenate 71 | audio = torch.cat([hyp_audio, ref_audio]) 72 | if lens is not None: 73 | lens = torch.cat([lens, lens]) 74 | 75 | # Resample 76 | audio = torchaudio.functional.resample( 77 | audio, self.sample_rate, SAMPLE_RATE 78 | ) 79 | 80 | self.model.to(hyp_audio.device) 81 | self.model.eval() 82 | 83 | # Attention mask 84 | attention_mask = None 85 | if lens is not None: 86 | abs_length = lens * audio.shape[-1] 87 | attention_mask = 
length_to_mask( 88 | abs_length.int() 89 | ).long() # 0 for masked tokens 90 | 91 | # Forward 92 | embs = self.model( 93 | input_values=audio, 94 | attention_mask=attention_mask, 95 | output_attentions=False, 96 | ).embeddings 97 | 98 | hyp_embs, ref_embs = embs.split([len(hyp_audio), len(ref_audio)]) 99 | scores = torch.nn.functional.cosine_similarity( 100 | hyp_embs, ref_embs, dim=-1 101 | ) 102 | 103 | self.ids += ids 104 | self.scores += scores.cpu().tolist() 105 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_ft.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: FT 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 6 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | 41 | whisper_variant: whisper-large-v2 42 | encoder_only: False 43 | freeze: False 44 | freeze_encoder: True 45 | 46 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 47 | nonfinite_patience: 10 48 | max_grad_norm: 5.0 49 | precision: fp16 50 | gradient_checkpointing: False 51 | ckpt_interval_minutes: 600 52 | 53 | max_gen_tokens: 80 54 | forced_decoder_locale: null # Set dynamically 55 | normalize_transcripts: True 56 | 57 | # Dataloader options 58 | train_dataloader_kwargs: 59 | batch_size: !ref 60 | num_workers: !ref 61 | 62 | valid_dataloader_kwargs: 63 | batch_size: !ref 64 | num_workers: !ref 65 | 66 | # Modules 67 | whisper: !new:model.ProgressiveWhisper 68 | source: !ref openai/ 69 | save_path: !ref /checkpoint 70 | sampling_rate: !ref 71 | encoder_only: !ref 72 | freeze: !ref 73 | freeze_encoder: !ref 74 | 75 | ce_loss: !new:torch.nn.CrossEntropyLoss 76 | ignore_index: !ref 77 | label_smoothing: !ref 78 | 79 | modules: 80 | whisper: !ref 81 | 82 | # Optimizers 83 | opt_class: !name:torch.optim.AdamW 84 | lr: !ref 85 | 86 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 87 | initial_value: !ref 88 | improvement_threshold: !ref 89 | annealing_factor: !ref 90 | patient: 0 91 | 92 | # Performance metrics 93 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 94 | 95 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | split_tokens: True 97 | 98 | # Counters, checkpointers, loggers, etc. 
99 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 100 | limit: !ref 101 | 102 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 103 | checkpoints_dir: !ref 104 | recoverables: 105 | model: !ref 106 | scheduler: !ref 107 | counter: !ref 108 | 109 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 110 | save_file: !ref /.txt 111 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_joint.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: joint 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 6 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | 41 | whisper_variant: whisper-large-v2 42 | encoder_only: False 43 | freeze: False 44 | freeze_encoder: True 45 | 46 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 47 | nonfinite_patience: 10 48 | max_grad_norm: 5.0 49 | precision: fp16 50 | gradient_checkpointing: False 51 | ckpt_interval_minutes: 600 52 | 53 | max_gen_tokens: 80 54 | forced_decoder_locale: null # Set dynamically 55 | normalize_transcripts: True 56 | 57 | # Dataloader options 58 | train_dataloader_kwargs: 59 | batch_size: !ref 60 | num_workers: !ref 61 | 62 | valid_dataloader_kwargs: 63 | batch_size: !ref 64 | num_workers: !ref 65 | 66 | # Modules 67 | whisper: !new:model.ProgressiveWhisper 68 | source: !ref openai/ 69 | save_path: !ref /checkpoint 70 | sampling_rate: !ref 71 | encoder_only: !ref 72 | freeze: !ref 73 | freeze_encoder: !ref 74 | 75 | ce_loss: !new:torch.nn.CrossEntropyLoss 76 | ignore_index: !ref 77 | label_smoothing: !ref 78 | 79 | modules: 80 | whisper: !ref 81 | 82 | # Optimizers 83 | opt_class: !name:torch.optim.AdamW 84 | lr: !ref 85 | 86 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 87 | initial_value: !ref 88 | improvement_threshold: !ref 89 | annealing_factor: !ref 90 | patient: 0 91 | 92 | # Performance metrics 93 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 94 | 95 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | split_tokens: True 97 | 98 | # Counters, checkpointers, loggers, etc. 
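# Note: `joint` fine-tunes on the base and new locales together rather than
# sequentially, so it is commonly read as the upper bound against which the
# continual-learning strategies in the sibling hparams files are compared.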
99 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
100 |     limit: !ref
101 |
102 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
103 |     checkpoints_dir: !ref
104 |     recoverables:
105 |         model: !ref
106 |         scheduler: !ref
107 |         counter: !ref
108 |
109 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
110 |     save_file: !ref /.txt
111 |
-------------------------------------------------------------------------------- /benchmarks/DASB/model/custom_model.py: --------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AttentionMLP(torch.nn.Module):
5 |     def __init__(self, input_dim, hidden_dim):
6 |         super(AttentionMLP, self).__init__()
7 |         self.layers = torch.nn.Sequential(
8 |             torch.nn.Linear(input_dim, hidden_dim),
9 |             torch.nn.ReLU(),
10 |             torch.nn.Linear(hidden_dim, 1, bias=False),
11 |         )
12 |
13 |     def forward(self, x):
14 |         x = self.layers(x)
15 |         att_w = torch.nn.functional.softmax(x, dim=2)
16 |         return att_w
17 |
18 |
19 | class Discrete_EmbeddingLayer(torch.nn.Module):
20 |     """This class handles embedding layers for discrete tokens.
21 |
22 |     Arguments
23 |     ---------
24 |     num_codebooks : int
25 |         The number of codebooks of the tokenizer.
26 |     vocab_size : int
27 |         The size of the dictionary of embeddings.
28 |     emb_dim : int
29 |         The size of each embedding vector.
30 |     pad_index : int (default: 0)
31 |         If specified, the entries at pad_index do not contribute to the gradient.
32 |     init : bool (default: False)
33 |         If True, initialize the embedding with the tokenizer embedding; otherwise initialize it randomly.
34 |     freeze : bool (default: False)
35 |         If True, the embedding is frozen. If False, the embedding is trained
36 |         alongside the rest of the pipeline.
37 |
38 |     Example
39 |     -------
40 |     >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
41 |     >>> model_hub = "facebook/encodec_24khz"
42 |     >>> save_path = "savedir"
43 |     >>> model = Encodec(model_hub, save_path)
44 |     >>> audio = torch.randn(4, 1000)
45 |     >>> length = torch.tensor([1.0, .5, .75, 1.0])
46 |     >>> tokens, emb = model.encode(audio, length)
47 |     >>> print(tokens.shape)
48 |     torch.Size([4, 4, 2])
49 |     >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
50 |     >>> in_emb = emb(tokens)
51 |     >>> print(in_emb.shape)
52 |     torch.Size([4, 4, 2, 1024])
53 |     """
54 |
55 |     def __init__(
56 |         self,
57 |         num_codebooks,
58 |         vocab_size,
59 |         emb_dim,
60 |         pad_index=0,
61 |         init=False,
62 |         freeze=False,
63 |     ):
64 |         super(Discrete_EmbeddingLayer, self).__init__()
65 |         self.vocab_size = vocab_size
66 |         self.num_codebooks = num_codebooks
67 |         self.freeze = freeze
68 |         self.embedding = torch.nn.Embedding(
69 |             num_codebooks * vocab_size, emb_dim
70 |         ).requires_grad_(not self.freeze)
71 |         self.init = init
72 |
73 |     def init_embedding(self, weights):
74 |         with torch.no_grad():
75 |             self.embedding.weight = torch.nn.Parameter(weights)
76 |
77 |     def forward(self, in_tokens):
78 |         """Computes the embeddings for a batch of discrete tokens.
79 |
80 |         Arguments
81 |         ---------
82 |         in_tokens : torch.Tensor
83 |             A (Batch x Time x num_codebooks) tensor
84 |             of discrete token indices.
85 |
86 |         Returns
87 |         -------
88 |         in_embs : torch.Tensor
89 |             The embeddings, with shape (Batch x Time x num_codebooks x emb_dim).
90 |         """
91 |         with torch.set_grad_enabled(not self.freeze):
92 |             # Offset the IDs of codebook k by k * vocab_size, so that every
93 |             # codebook maps to its own slice of the shared embedding table
94 |             in_tokens = in_tokens + torch.arange(
95 |                 0,
96 |                 self.num_codebooks * self.vocab_size,
97 |                 self.vocab_size,
98 |                 device=in_tokens.device,
99 |             )
100 |             # Forward pass through the embedding layer
101 |             in_embs = self.embedding(in_tokens)
102 |             return in_embs
103 |
-------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_er.yaml: --------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: Whisper (encoder-decoder) + cross-entropy loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 |
6 | experiment_name: ER
7 |
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref ]
10 |
11 | skip_test: False
12 |
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 |
19 | # Output directories
20 | output_folder: !ref results///
21 | save_folder: !ref /save
22 |
23 | # Training parameters
24 | train_batch_size: 4
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 |
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 |
33 | ignore_index: -100 # For cross-entropy loss
34 | label_smoothing: 0
35 |
36 | num_epochs: 2
37 | lr: 0.0001
38 | improvement_threshold: 0.0025
39 | annealing_factor: 0.8
40 | replay_ratio: 0.1
41 |
42 | whisper_variant: whisper-large-v2
43 | encoder_only: False
44 | freeze: False
45 | freeze_encoder: True
46 |
47 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
48 | nonfinite_patience: 10
49 | max_grad_norm: 5.0
50 | precision: fp16
51 | gradient_checkpointing: False
52 | ckpt_interval_minutes: 600
53 |
54 | max_gen_tokens: 80
55 | forced_decoder_locale: null # Set dynamically
56 | normalize_transcripts: True
57 |
58 | # Dataloader options
59 | train_dataloader_kwargs:
60 |     batch_size: !ref
61 |     num_workers: !ref
62 |
63 | valid_dataloader_kwargs:
64 |     batch_size: !ref
65 |     num_workers: !ref
66 |
67 | # Modules
68 | whisper: !new:model.ProgressiveWhisper
69 |     source: !ref openai/
70 |     save_path: !ref /checkpoint
71 |     sampling_rate: !ref
72 |     encoder_only: !ref
73 |     freeze: !ref
74 |     freeze_encoder: !ref
75 |
76 | ce_loss: !new:torch.nn.CrossEntropyLoss
77 |     ignore_index: !ref
78 |     label_smoothing: !ref
79 |
80 | modules:
81 |     whisper: !ref
82 |
83 | # Optimizers
84 | opt_class: !name:torch.optim.AdamW
85 |     lr: !ref
86 |
87 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
88 |     initial_value: !ref
89 |     improvement_threshold: !ref
90 |     annealing_factor: !ref
91 |     patient: 0
92 |
93 | # Performance metrics
94 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
95 |
96 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
97 |     split_tokens: True
98 |
99 | # Counters, checkpointers, loggers, etc.
100 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
101 |     limit: !ref
102 |
103 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
104 |     checkpoints_dir: !ref
105 |     recoverables:
106 |         model: !ref
107 |         scheduler: !ref
108 |         counter: !ref
109 |
110 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
111 |     save_file: !ref /.txt
112 |
-------------------------------------------------------------------------------- /tests/consistency/DOCSTRINGS.md: --------------------------------------------------------------------------------
1 | # Docstrings in SpeechBrain
2 | All functions and classes in SpeechBrain must have a docstring. SpeechBrain adopts the NumPy style for docstrings.
3 | Here is an example of a class:
4 |
5 | > class SincConv(nn.Module):
6 | >     """This function implements SincConv (SincNet).
7 | >
8 | >     M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
9 | >     SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)
10 | >
11 | >     Arguments
12 | >     ---------
13 | >     input_shape : tuple
14 | >         The shape of the input. Alternatively use ``in_channels``.
15 | >     in_channels : int
16 | >         The number of input channels. Alternatively use ``input_shape``.
17 | >     out_channels : int
18 | >         It is the number of output channels.
19 | >     kernel_size : int
20 | >         Kernel size of the convolutional filters.
21 | >     stride : int
22 | >         Stride factor of the convolutional filters. When the stride factor > 1,
23 | >         a decimation in time is performed.
24 | >     dilation : int
25 | >         Dilation factor of the convolutional filters.
26 | >     padding : str
27 | >         (same, valid, causal). If "valid", no padding is performed.
28 | >         If "same" and stride is 1, the output shape is the same as the input shape.
29 | >         "causal" results in causal (dilated) convolutions.
30 | >     padding_mode : str
31 | >         This flag specifies the type of padding. See torch.nn documentation
32 | >         for more information.
33 | >     groups : int
34 | >         This option specifies the convolutional groups. See torch.nn
35 | >         documentation for more information.
36 | >     bias : bool
37 | >         If True, the additive bias b is adopted.
38 | >     sample_rate : int
39 | >         The sampling rate of the input signals. It is only used for sinc_conv.
40 | >     min_low_hz : float
41 | >         Lowest possible frequency (in Hz) for a filter. It is only used for
42 | >         sinc_conv.
43 | >     min_band_hz : float
44 | >         Lowest possible value (in Hz) for a filter bandwidth.
45 | >
46 | >     Example
47 | >     -------
48 | >     >>> inp_tensor = torch.rand([10, 16000])
49 | >     >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
50 | >     >>> out_tensor = conv(inp_tensor)
51 | >     >>> out_tensor.shape
52 | >     torch.Size([10, 16000, 25])
53 | >     """
54 |
55 | Here is an example of a function:
56 |
57 | > def ngram_perplexity(eval_details, logbase=10.0):
58 | >     """
59 | >     Computes perplexity from a list of individual sentence evaluations.
60 | >
61 | >     Arguments
62 | >     ---------
63 | >     eval_details : list
64 | >         List of individual sentence evaluations. As returned by
65 | >         `ngram_evaluation_details`.
66 | >     logbase : float
67 | >         The logarithm base to use.
68 | >
69 | >     Returns
70 | >     -------
71 | >     float
72 | >         The computed perplexity.
73 | >
74 | >     Example
75 | >     -------
76 | >     >>> eval_details = [
77 | >     ...     collections.Counter(neglogprob=5, num_tokens=5),
78 | >     ...     collections.Counter(neglogprob=15, num_tokens=15)]
79 | >     >>> ngram_perplexity(eval_details)
80 | >     10.0
81 | >     """
82 | We strongly encourage contributors to add a runnable example for the most important functions and classes. The examples are tested automatically with pytest and help clarify how each function or class should be used. We also encourage contributors to describe accurately the arguments and return values of a function (along with their types). Short docstrings (e.g., one line) are acceptable for minor functions only, but even then we encourage describing at least the inputs and outputs.
83 |
-------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_agem.yaml: --------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: Whisper (encoder-decoder) + cross-entropy loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 |
6 | experiment_name: A-GEM
7 |
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref ]
10 |
11 | skip_test: False
12 |
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 |
19 | # Output directories
20 | output_folder: !ref results///
21 | save_folder: !ref /save
22 |
23 | # Training parameters
24 | train_batch_size: 1
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 |
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 |
33 | ignore_index: -100 # For cross-entropy loss
34 | label_smoothing: 0
35 |
36 | num_epochs: 2
37 | lr: 0.0001
38 | improvement_threshold: 0.0025
39 | annealing_factor: 0.8
40 | replay_ratio: 0.1
41 |
42 | whisper_variant: whisper-large-v2
43 | encoder_only: False
44 | freeze: False
45 | freeze_encoder: True
46 |
47 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
48 | nonfinite_patience: 10
49 | max_grad_norm: 5.0
50 | precision: fp16
51 | gradient_checkpointing: False
52 | ckpt_interval_minutes: 600
53 |
54 | max_gen_tokens: 80
55 | forced_decoder_locale: null # Set dynamically
56 | normalize_transcripts: True
57 |
58 | # Dataloader options
59 | train_dataloader_kwargs:
60 |     batch_size: !ref
61 |     num_workers: !ref
62 |
63 | valid_dataloader_kwargs:
64 |     batch_size: !ref
65 |     num_workers: !ref
66 |
67 | # Modules
68 | whisper: !new:model.ProgressiveWhisper
69 |     source: !ref openai/
70 |     save_path: !ref /checkpoint
71 |     sampling_rate: !ref
72 |     encoder_only: !ref
73 |     freeze: !ref
74 |     freeze_encoder: !ref
75 |
76 | ce_loss: !new:torch.nn.CrossEntropyLoss
77 |     ignore_index: !ref
78 |     label_smoothing: !ref
79 |
80 | modules:
81 |     whisper: !ref
82 |
83 | # Optimizers
84 | opt_class: !name:torch.optim.AdamW
85 |     lr: !ref
86 |
87 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
88 |     initial_value: !ref
89 |     improvement_threshold: !ref
90 |     annealing_factor: !ref
91 |     patient: 0
92 |
93 | # Performance metrics
94 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
95 |
96 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
97 |     split_tokens: True
98 |
99 | # Counters,
checkpointers, loggers, etc. 100 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 101 | limit: !ref 102 | 103 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 104 | checkpoints_dir: !ref 105 | recoverables: 106 | model: !ref 107 | scheduler: !ref 108 | counter: !ref 109 | 110 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 111 | save_file: !ref /.txt 112 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_pnn.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: PNN 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | num_new_decoder_layers: 1 41 | 42 | whisper_variant: whisper-large-v2 43 | encoder_only: False 44 | freeze: False 45 | freeze_encoder: True 46 | 47 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 48 | nonfinite_patience: 10 49 | max_grad_norm: 5.0 50 | precision: fp16 51 | gradient_checkpointing: False 52 | ckpt_interval_minutes: 600 53 | 54 | max_gen_tokens: 80 55 | forced_decoder_locale: null # Set dynamically 56 | normalize_transcripts: True 57 | 58 | # Dataloader options 59 | train_dataloader_kwargs: 60 | batch_size: !ref 61 | num_workers: !ref 62 | 63 | valid_dataloader_kwargs: 64 | batch_size: !ref 65 | num_workers: !ref 66 | 67 | # Modules 68 | whisper: !new:model.ProgressiveWhisper 69 | source: !ref openai/ 70 | save_path: !ref /checkpoint 71 | sampling_rate: !ref 72 | encoder_only: !ref 73 | freeze: !ref 74 | freeze_encoder: !ref 75 | 76 | ce_loss: !new:torch.nn.CrossEntropyLoss 77 | ignore_index: !ref 78 | label_smoothing: !ref 79 | 80 | modules: 81 | whisper: !ref 82 | 83 | # Optimizers 84 | opt_class: !name:torch.optim.AdamW 85 | lr: !ref 86 | 87 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 88 | initial_value: !ref 89 | improvement_threshold: !ref 90 | annealing_factor: !ref 91 | patient: 0 92 | 93 | # Performance metrics 94 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 95 | 96 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 97 | split_tokens: True 98 | 99 | # Counters, checkpointers, loggers, etc. 
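# Note: for progressive neural networks (PNN), previously learned parameters
# stay frozen and `num_new_decoder_layers` fresh decoder layer(s) are grown
# for each new locale, so earlier locales cannot be overwritten; the exact
# growth rule lives in the accompanying training script.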
100 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 101 | limit: !ref 102 | 103 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 104 | checkpoints_dir: !ref 105 | recoverables: 106 | model: !ref 107 | scheduler: !ref 108 | counter: !ref 109 | 110 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 111 | save_file: !ref /.txt 112 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/pretrain.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: pretrain 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | data_folder: !PLACEHOLDER 16 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 17 | 18 | # Output directories 19 | output_folder: !ref results/// 20 | save_folder: !ref /save 21 | 22 | # Training parameters 23 | train_batch_size: 8 24 | valid_batch_size: 16 25 | train_num_workers: 6 26 | valid_num_workers: 6 27 | 28 | sample_rate: 16000 29 | sorting: ascending 30 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 31 | 32 | blank_index: 0 33 | 34 | num_epochs: 20 35 | lr: 0.0001 36 | improvement_threshold: 0.0025 37 | annealing_factor: 0.8 38 | 39 | wavlm_variant: wavlm-large 40 | output_norm: True 41 | freeze: False # WavLM + LSTM 42 | freeze_encoder: False # WavLM 43 | freeze_feature_extractor: True # Feature extractor of WavLM 44 | hidden_size: 1024 45 | num_layers: 2 46 | dropout: 0.0 47 | bidirectional: True 48 | 49 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 50 | nonfinite_patience: 10 51 | max_grad_norm: 5.0 52 | precision: fp16 53 | gradient_checkpointing: False 54 | ckpt_interval_minutes: 600 55 | 56 | # Dataloader options 57 | train_dataloader_kwargs: 58 | batch_size: !ref 59 | num_workers: !ref 60 | 61 | valid_dataloader_kwargs: 62 | batch_size: !ref 63 | num_workers: !ref 64 | 65 | # Modules 66 | wavlm: !new:model.ProgressiveWavLM 67 | source: !ref microsoft/ 68 | save_path: !ref /checkpoint 69 | output_norm: !ref 70 | freeze: !ref 71 | freeze_encoder: !ref 72 | freeze_feature_extractor: !ref 73 | hidden_size: !ref 74 | num_layers: !ref 75 | dropout: !ref 76 | bidirectional: !ref 77 | 78 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 79 | blank_index: !ref 80 | 81 | modules: 82 | wavlm: !ref 83 | 84 | # Optimizers 85 | opt_class: !name:torch.optim.AdamW 86 | lr: !ref 87 | 88 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 89 | initial_value: !ref 90 | improvement_threshold: !ref 91 | annealing_factor: !ref 92 | patient: 0 93 | 94 | # Performance metrics 95 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | 97 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | split_tokens: True 99 | 100 | # Counters, checkpointers, loggers, etc. 
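# Note: unlike the other recipes in this folder, this configuration trains
# the WavLM + LSTM CTC model on the 10 base locales only, for 20 epochs; the
# resulting checkpoint is presumably what the CL recipes load through their
# `pretrained_wavlm_path` field.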
101 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 102 | limit: !ref 103 | 104 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 105 | checkpoints_dir: !ref 106 | recoverables: 107 | model: !ref 108 | scheduler: !ref 109 | counter: !ref 110 | 111 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 112 | save_file: !ref /.txt 113 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_der.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: DER 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 4 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | replay_ratio: 0.1 41 | der_alpha: 1.0 42 | 43 | whisper_variant: whisper-large-v2 44 | encoder_only: False 45 | freeze: False 46 | freeze_encoder: True 47 | 48 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 49 | nonfinite_patience: 10 50 | max_grad_norm: 5.0 51 | precision: fp16 52 | gradient_checkpointing: False 53 | ckpt_interval_minutes: 600 54 | 55 | max_gen_tokens: 80 56 | forced_decoder_locale: null # Set dynamically 57 | normalize_transcripts: True 58 | 59 | # Dataloader options 60 | train_dataloader_kwargs: 61 | batch_size: !ref 62 | num_workers: !ref 63 | 64 | valid_dataloader_kwargs: 65 | batch_size: !ref 66 | num_workers: !ref 67 | 68 | # Modules 69 | whisper: !new:model.ProgressiveWhisper 70 | source: !ref openai/ 71 | save_path: !ref /checkpoint 72 | sampling_rate: !ref 73 | encoder_only: !ref 74 | freeze: !ref 75 | freeze_encoder: !ref 76 | 77 | ce_loss: !new:torch.nn.CrossEntropyLoss 78 | ignore_index: !ref 79 | label_smoothing: !ref 80 | 81 | modules: 82 | whisper: !ref 83 | 84 | # Optimizers 85 | opt_class: !name:torch.optim.AdamW 86 | lr: !ref 87 | 88 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 89 | initial_value: !ref 90 | improvement_threshold: !ref 91 | annealing_factor: !ref 92 | patient: 0 93 | 94 | # Performance metrics 95 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | 97 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | split_tokens: True 99 | 100 | # Counters, checkpointers, loggers, etc. 
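# Note: for dark experience replay (DER), the replay buffer stores past model
# logits together with the inputs, and `der_alpha` scales the extra
# logit-matching (distillation) loss computed on replayed samples; with
# `der_alpha: 1.0` that term weighs as much as the cross-entropy loss.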
101 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 102 | limit: !ref 103 | 104 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 105 | checkpoints_dir: !ref 106 | recoverables: 107 | model: !ref 108 | scheduler: !ref 109 | counter: !ref 110 | 111 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 112 | save_file: !ref /.txt 113 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_pb.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: PB 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 4 25 | valid_batch_size: 12 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | mask_init: 0.01 41 | mask_threshold: 0.005 42 | 43 | whisper_variant: whisper-large-v2 44 | encoder_only: False 45 | freeze: False 46 | freeze_encoder: True 47 | 48 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 49 | nonfinite_patience: 10 50 | max_grad_norm: 5.0 51 | precision: fp16 52 | gradient_checkpointing: False 53 | ckpt_interval_minutes: 600 54 | 55 | max_gen_tokens: 80 56 | forced_decoder_locale: null # Set dynamically 57 | normalize_transcripts: True 58 | 59 | # Dataloader options 60 | train_dataloader_kwargs: 61 | batch_size: !ref 62 | num_workers: !ref 63 | 64 | valid_dataloader_kwargs: 65 | batch_size: !ref 66 | num_workers: !ref 67 | 68 | # Modules 69 | whisper: !new:model.ProgressiveWhisper 70 | source: !ref openai/ 71 | save_path: !ref /checkpoint 72 | sampling_rate: !ref 73 | encoder_only: !ref 74 | freeze: !ref 75 | freeze_encoder: !ref 76 | 77 | ce_loss: !new:torch.nn.CrossEntropyLoss 78 | ignore_index: !ref 79 | label_smoothing: !ref 80 | 81 | modules: 82 | whisper: !ref 83 | 84 | # Optimizers 85 | opt_class: !name:torch.optim.AdamW 86 | lr: !ref 87 | 88 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 89 | initial_value: !ref 90 | improvement_threshold: !ref 91 | annealing_factor: !ref 92 | patient: 0 93 | 94 | # Performance metrics 95 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | 97 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | split_tokens: True 99 | 100 | # Counters, checkpointers, loggers, etc. 
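# Note: for PB (Piggyback-style masking), a real-valued mask over the frozen
# pretrained weights is learned per locale: entries start at `mask_init` and
# are binarized against `mask_threshold` in the forward pass, so only the
# small binary masks need to be stored for each locale.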
101 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 102 | limit: !ref 103 | 104 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 105 | checkpoints_dir: !ref 106 | recoverables: 107 | model: !ref 108 | scheduler: !ref 109 | counter: !ref 110 | 111 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 112 | save_file: !ref /.txt 113 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_lwf.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023, Pooneh Mousavi 2023 4 | # ############################################################################ 5 | 6 | experiment_name: LwF 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 1 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | ignore_index: -100 # For cross-entropy loss 34 | label_smoothing: 0 35 | 36 | num_epochs: 2 37 | lr: 0.0001 38 | improvement_threshold: 0.0025 39 | annealing_factor: 0.8 40 | lwf_lambda: 10.0 41 | lwf_T: 2.0 42 | 43 | whisper_variant: whisper-large-v2 44 | encoder_only: False 45 | freeze: False 46 | freeze_encoder: True 47 | 48 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 49 | nonfinite_patience: 10 50 | max_grad_norm: 5.0 51 | precision: fp16 52 | gradient_checkpointing: False 53 | ckpt_interval_minutes: 600 54 | 55 | max_gen_tokens: 80 56 | forced_decoder_locale: null # Set dynamically 57 | normalize_transcripts: True 58 | 59 | # Dataloader options 60 | train_dataloader_kwargs: 61 | batch_size: !ref 62 | num_workers: !ref 63 | 64 | valid_dataloader_kwargs: 65 | batch_size: !ref 66 | num_workers: !ref 67 | 68 | # Modules 69 | whisper: !new:model.ProgressiveWhisper 70 | source: !ref openai/ 71 | save_path: !ref /checkpoint 72 | sampling_rate: !ref 73 | encoder_only: !ref 74 | freeze: !ref 75 | freeze_encoder: !ref 76 | 77 | ce_loss: !new:torch.nn.CrossEntropyLoss 78 | ignore_index: !ref 79 | label_smoothing: !ref 80 | 81 | modules: 82 | whisper: !ref 83 | 84 | # Optimizers 85 | opt_class: !name:torch.optim.AdamW 86 | lr: !ref 87 | 88 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 89 | initial_value: !ref 90 | improvement_threshold: !ref 91 | annealing_factor: !ref 92 | patient: 0 93 | 94 | # Performance metrics 95 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 96 | 97 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | split_tokens: True 99 | 100 | # Counters, checkpointers, loggers, etc. 
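# Note: for learning without forgetting (LwF), a frozen copy of the model
# taken before the update acts as a teacher: its predictions are distilled
# into the current model with softmax temperature `lwf_T`, and `lwf_lambda`
# scales this distillation term relative to the cross-entropy loss.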
101 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 102 | limit: !ref 103 | 104 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 105 | checkpoints_dir: !ref 106 | recoverables: 107 | model: !ref 108 | scheduler: !ref 109 | counter: !ref 110 | 111 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 112 | save_file: !ref /.txt 113 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_mas.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: MAS 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | skip_mas: False 13 | 14 | # Data preparation 15 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 16 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 17 | data_folder: !PLACEHOLDER 18 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 19 | 20 | # Output directories 21 | output_folder: !ref results/// 22 | save_folder: !ref /save 23 | 24 | # Training parameters 25 | train_batch_size: 6 26 | valid_batch_size: 16 27 | train_num_workers: 6 28 | valid_num_workers: 6 29 | 30 | sample_rate: 16000 31 | sorting: ascending 32 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 33 | 34 | ignore_index: -100 # For cross-entropy loss 35 | label_smoothing: 0 36 | 37 | num_epochs: 2 38 | lr: 0.0001 39 | improvement_threshold: 0.0025 40 | annealing_factor: 0.8 41 | mas_lambda: 1.0 42 | mas_alpha: 0.5 43 | 44 | whisper_variant: whisper-large-v2 45 | encoder_only: False 46 | freeze: False 47 | freeze_encoder: True 48 | 49 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 50 | nonfinite_patience: 10 51 | max_grad_norm: 5.0 52 | precision: fp16 53 | gradient_checkpointing: False 54 | ckpt_interval_minutes: 600 55 | 56 | max_gen_tokens: 80 57 | forced_decoder_locale: null # Set dynamically 58 | normalize_transcripts: True 59 | 60 | # Dataloader options 61 | train_dataloader_kwargs: 62 | batch_size: !ref 63 | num_workers: !ref 64 | 65 | valid_dataloader_kwargs: 66 | batch_size: !ref 67 | num_workers: !ref 68 | 69 | # Modules 70 | whisper: !new:model.ProgressiveWhisper 71 | source: !ref openai/ 72 | save_path: !ref /checkpoint 73 | sampling_rate: !ref 74 | encoder_only: !ref 75 | freeze: !ref 76 | freeze_encoder: !ref 77 | 78 | ce_loss: !new:torch.nn.CrossEntropyLoss 79 | ignore_index: !ref 80 | label_smoothing: !ref 81 | 82 | modules: 83 | whisper: !ref 84 | 85 | # Optimizers 86 | opt_class: !name:torch.optim.AdamW 87 | lr: !ref 88 | 89 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 90 | initial_value: !ref 91 | improvement_threshold: !ref 92 | annealing_factor: !ref 93 | patient: 0 94 | 95 | # Performance metrics 96 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 97 | 98 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 99 | split_tokens: True 100 | 101 | # Counters, checkpointers, loggers, etc. 
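# Note: for memory aware synapses (MAS), parameter importance is estimated
# from the sensitivity of the model output norm to each parameter;
# `mas_lambda` scales the quadratic penalty that anchors important
# parameters, while `mas_alpha` presumably sets the running-average weight
# used when accumulating importances across locales (`skip_mas` skips the
# estimation step).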
102 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 103 | limit: !ref 104 | 105 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 106 | checkpoints_dir: !ref 107 | recoverables: 108 | model: !ref 109 | scheduler: !ref 110 | counter: !ref 111 | 112 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 113 | save_file: !ref /.txt 114 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/whisper/hparams/train_ewc.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: Whisper (encoder-decoder) + cross-entropy loss 3 | # Authors: Luca Della Libera 2023, Pooneh Mousavi 2023 4 | # ############################################################################ 5 | 6 | experiment_name: EWC 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | skip_ewc: False 13 | 14 | # Data preparation 15 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 16 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 17 | data_folder: !PLACEHOLDER 18 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 19 | 20 | # Output directories 21 | output_folder: !ref results/// 22 | save_folder: !ref /save 23 | 24 | # Training parameters 25 | train_batch_size: 6 26 | valid_batch_size: 16 27 | train_num_workers: 6 28 | valid_num_workers: 6 29 | 30 | sample_rate: 16000 31 | sorting: ascending 32 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 33 | 34 | ignore_index: -100 # For cross-entropy loss 35 | label_smoothing: 0 36 | 37 | num_epochs: 2 38 | lr: 0.0001 39 | improvement_threshold: 0.0025 40 | annealing_factor: 0.8 41 | ewc_lambda: 5.0 42 | ewc_alpha: 0.5 43 | 44 | whisper_variant: whisper-large-v2 45 | encoder_only: False 46 | freeze: False 47 | freeze_encoder: True 48 | 49 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 50 | nonfinite_patience: 10 51 | max_grad_norm: 5.0 52 | precision: fp16 53 | gradient_checkpointing: False 54 | ckpt_interval_minutes: 600 55 | 56 | max_gen_tokens: 80 57 | forced_decoder_locale: null # Set dynamically 58 | normalize_transcripts: True 59 | 60 | # Dataloader options 61 | train_dataloader_kwargs: 62 | batch_size: !ref 63 | num_workers: !ref 64 | 65 | valid_dataloader_kwargs: 66 | batch_size: !ref 67 | num_workers: !ref 68 | 69 | # Modules 70 | whisper: !new:model.ProgressiveWhisper 71 | source: !ref openai/ 72 | save_path: !ref /checkpoint 73 | sampling_rate: !ref 74 | encoder_only: !ref 75 | freeze: !ref 76 | freeze_encoder: !ref 77 | 78 | ce_loss: !new:torch.nn.CrossEntropyLoss 79 | ignore_index: !ref 80 | label_smoothing: !ref 81 | 82 | modules: 83 | whisper: !ref 84 | 85 | # Optimizers 86 | opt_class: !name:torch.optim.AdamW 87 | lr: !ref 88 | 89 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 90 | initial_value: !ref 91 | improvement_threshold: !ref 92 | annealing_factor: !ref 93 | patient: 0 94 | 95 | # Performance metrics 96 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 97 | 98 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 99 | split_tokens: True 100 | 101 | # Counters, checkpointers, loggers, etc. 
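# Note: for elastic weight consolidation (EWC), a diagonal Fisher-information
# estimate measures how important each parameter was for previous locales;
# `ewc_lambda` scales the resulting quadratic penalty, while `ewc_alpha`
# presumably sets the exponential decay used to merge Fisher estimates in the
# online variant (`skip_ewc` skips the estimation step).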
102 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 103 | limit: !ref 104 | 105 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 106 | checkpoints_dir: !ref 107 | recoverables: 108 | model: !ref 109 | scheduler: !ref 110 | counter: !ref 111 | 112 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 113 | save_file: !ref /.txt 114 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/train_joint.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: joint 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | blank_index: 0 34 | 35 | num_epochs: 2 36 | lr: 0.0001 37 | improvement_threshold: 0.0025 38 | annealing_factor: 0.8 39 | 40 | wavlm_variant: wavlm-large 41 | pretrained_wavlm_path: null 42 | output_norm: True 43 | freeze: False # WavLM + LSTM 44 | freeze_encoder: False # WavLM 45 | freeze_feature_extractor: True # Feature extractor of WavLM 46 | hidden_size: 1024 47 | num_layers: 2 48 | dropout: 0.0 49 | bidirectional: True 50 | 51 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 52 | nonfinite_patience: 10 53 | max_grad_norm: 5.0 54 | precision: fp16 55 | gradient_checkpointing: False 56 | ckpt_interval_minutes: 600 57 | 58 | # Dataloader options 59 | train_dataloader_kwargs: 60 | batch_size: !ref 61 | num_workers: !ref 62 | 63 | valid_dataloader_kwargs: 64 | batch_size: !ref 65 | num_workers: !ref 66 | 67 | # Modules 68 | wavlm: !new:model.ProgressiveWavLM 69 | source: !ref microsoft/ 70 | save_path: !ref /checkpoint 71 | output_norm: !ref 72 | freeze: !ref 73 | freeze_encoder: !ref 74 | freeze_feature_extractor: !ref 75 | hidden_size: !ref 76 | num_layers: !ref 77 | dropout: !ref 78 | bidirectional: !ref 79 | 80 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 81 | blank_index: !ref 82 | 83 | modules: 84 | wavlm: !ref 85 | 86 | # Optimizers 87 | opt_class: !name:torch.optim.AdamW 88 | lr: !ref 89 | 90 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 91 | initial_value: !ref 92 | improvement_threshold: !ref 93 | annealing_factor: !ref 94 | patient: 0 95 | 96 | # Performance metrics 97 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | 99 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 100 | split_tokens: True 101 | 102 | # Counters, checkpointers, loggers, etc. 
103 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 104 | limit: !ref 105 | 106 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 107 | checkpoints_dir: !ref 108 | recoverables: 109 | model: !ref 110 | scheduler: !ref 111 | counter: !ref 112 | 113 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 114 | save_file: !ref /.txt 115 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/train_ft.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023, Salah Zaiem 2023 4 | # ############################################################################ 5 | 6 | experiment_name: FT 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | blank_index: 0 34 | 35 | num_epochs: 2 36 | lr: 0.0001 37 | improvement_threshold: 0.0025 38 | annealing_factor: 0.8 39 | 40 | wavlm_variant: wavlm-large 41 | pretrained_wavlm_path: null 42 | output_norm: True 43 | freeze: False # WavLM + LSTM 44 | freeze_encoder: False # WavLM 45 | freeze_feature_extractor: True # Feature extractor of WavLM 46 | hidden_size: 1024 47 | num_layers: 2 48 | dropout: 0.0 49 | bidirectional: True 50 | 51 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 52 | nonfinite_patience: 10 53 | max_grad_norm: 5.0 54 | precision: fp16 55 | gradient_checkpointing: False 56 | ckpt_interval_minutes: 600 57 | 58 | # Dataloader options 59 | train_dataloader_kwargs: 60 | batch_size: !ref 61 | num_workers: !ref 62 | 63 | valid_dataloader_kwargs: 64 | batch_size: !ref 65 | num_workers: !ref 66 | 67 | # Modules 68 | wavlm: !new:model.ProgressiveWavLM 69 | source: !ref microsoft/ 70 | save_path: !ref /checkpoint 71 | output_norm: !ref 72 | freeze: !ref 73 | freeze_encoder: !ref 74 | freeze_feature_extractor: !ref 75 | hidden_size: !ref 76 | num_layers: !ref 77 | dropout: !ref 78 | bidirectional: !ref 79 | 80 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 81 | blank_index: !ref 82 | 83 | modules: 84 | wavlm: !ref 85 | 86 | # Optimizers 87 | opt_class: !name:torch.optim.AdamW 88 | lr: !ref 89 | 90 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 91 | initial_value: !ref 92 | improvement_threshold: !ref 93 | annealing_factor: !ref 94 | patient: 0 95 | 96 | # Performance metrics 97 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 98 | 99 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 100 | split_tokens: True 101 | 102 | # Counters, checkpointers, loggers, etc. 
103 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 104 | limit: !ref 105 | 106 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 107 | checkpoints_dir: !ref 108 | recoverables: 109 | model: !ref 110 | scheduler: !ref 111 | counter: !ref 112 | 113 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 114 | save_file: !ref /.txt 115 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/train_er.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: ER 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | blank_index: 0 34 | 35 | num_epochs: 2 36 | lr: 0.0001 37 | improvement_threshold: 0.0025 38 | annealing_factor: 0.8 39 | replay_ratio: 0.1 40 | 41 | wavlm_variant: wavlm-large 42 | pretrained_wavlm_path: null 43 | output_norm: True 44 | freeze: False # WavLM + LSTM 45 | freeze_encoder: False # WavLM 46 | freeze_feature_extractor: True # Feature extractor of WavLM 47 | hidden_size: 1024 48 | num_layers: 2 49 | dropout: 0.0 50 | bidirectional: True 51 | 52 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 53 | nonfinite_patience: 10 54 | max_grad_norm: 5.0 55 | precision: fp16 56 | gradient_checkpointing: False 57 | ckpt_interval_minutes: 600 58 | 59 | # Dataloader options 60 | train_dataloader_kwargs: 61 | batch_size: !ref 62 | num_workers: !ref 63 | 64 | valid_dataloader_kwargs: 65 | batch_size: !ref 66 | num_workers: !ref 67 | 68 | # Modules 69 | wavlm: !new:model.ProgressiveWavLM 70 | source: !ref microsoft/ 71 | save_path: !ref /checkpoint 72 | output_norm: !ref 73 | freeze: !ref 74 | freeze_encoder: !ref 75 | freeze_feature_extractor: !ref 76 | hidden_size: !ref 77 | num_layers: !ref 78 | dropout: !ref 79 | bidirectional: !ref 80 | 81 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 82 | blank_index: !ref 83 | 84 | modules: 85 | wavlm: !ref 86 | 87 | # Optimizers 88 | opt_class: !name:torch.optim.AdamW 89 | lr: !ref 90 | 91 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 92 | initial_value: !ref 93 | improvement_threshold: !ref 94 | annealing_factor: !ref 95 | patient: 0 96 | 97 | # Performance metrics 98 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 99 | 100 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 101 | split_tokens: True 102 | 103 | # Counters, checkpointers, loggers, etc. 
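# Note: for experience replay (ER), `replay_ratio` controls roughly how much
# previously seen base-locale data is mixed into each new-locale training
# pass (about 10% here); the exact sampling logic lives in the accompanying
# training script.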
104 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 105 | limit: !ref 106 | 107 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 108 | checkpoints_dir: !ref 109 | recoverables: 110 | model: !ref 111 | scheduler: !ref 112 | counter: !ref 113 | 114 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 115 | save_file: !ref /.txt 116 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/train_agem.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: A-GEM 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | blank_index: 0 34 | 35 | num_epochs: 2 36 | lr: 0.0001 37 | improvement_threshold: 0.0025 38 | annealing_factor: 0.8 39 | replay_ratio: 0.1 40 | 41 | wavlm_variant: wavlm-large 42 | pretrained_wavlm_path: null 43 | output_norm: True 44 | freeze: False # WavLM + LSTM 45 | freeze_encoder: False # WavLM 46 | freeze_feature_extractor: True # Feature extractor of WavLM 47 | hidden_size: 1024 48 | num_layers: 2 49 | dropout: 0.0 50 | bidirectional: True 51 | 52 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 53 | nonfinite_patience: 10 54 | max_grad_norm: 5.0 55 | precision: fp16 56 | gradient_checkpointing: False 57 | ckpt_interval_minutes: 600 58 | 59 | # Dataloader options 60 | train_dataloader_kwargs: 61 | batch_size: !ref 62 | num_workers: !ref 63 | 64 | valid_dataloader_kwargs: 65 | batch_size: !ref 66 | num_workers: !ref 67 | 68 | # Modules 69 | wavlm: !new:model.ProgressiveWavLM 70 | source: !ref microsoft/ 71 | save_path: !ref /checkpoint 72 | output_norm: !ref 73 | freeze: !ref 74 | freeze_encoder: !ref 75 | freeze_feature_extractor: !ref 76 | hidden_size: !ref 77 | num_layers: !ref 78 | dropout: !ref 79 | bidirectional: !ref 80 | 81 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 82 | blank_index: !ref 83 | 84 | modules: 85 | wavlm: !ref 86 | 87 | # Optimizers 88 | opt_class: !name:torch.optim.AdamW 89 | lr: !ref 90 | 91 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 92 | initial_value: !ref 93 | improvement_threshold: !ref 94 | annealing_factor: !ref 95 | patient: 0 96 | 97 | # Performance metrics 98 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 99 | 100 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 101 | split_tokens: True 102 | 103 | # Counters, checkpointers, loggers, etc. 
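# Note: for averaged GEM (A-GEM), each step first computes a reference
# gradient g_ref on a batch replayed from past locales (`replay_ratio`); if
# the current gradient g conflicts with it (g . g_ref < 0), g is projected to
# g - (g . g_ref / ||g_ref||^2) * g_ref before the update, which, to first
# order, prevents the step from increasing the replayed loss.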
104 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 105 | limit: !ref 106 | 107 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 108 | checkpoints_dir: !ref 109 | recoverables: 110 | model: !ref 111 | scheduler: !ref 112 | counter: !ref 113 | 114 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 115 | save_file: !ref /.txt 116 | -------------------------------------------------------------------------------- /benchmarks/CL_MASR/wavlm/hparams/train_lwf.yaml: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # Model: WavLM + LSTM + CTC loss 3 | # Authors: Luca Della Libera 2023 4 | # ############################################################################ 5 | 6 | experiment_name: LwF 7 | 8 | seed: 0 9 | __set_seed: !apply:torch.manual_seed [!ref ] 10 | 11 | skip_test: False 12 | 13 | # Data preparation 14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl] 15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia] 16 | data_folder: !PLACEHOLDER 17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale 18 | 19 | # Output directories 20 | output_folder: !ref results/// 21 | save_folder: !ref /save 22 | 23 | # Training parameters 24 | train_batch_size: 8 25 | valid_batch_size: 16 26 | train_num_workers: 6 27 | valid_num_workers: 6 28 | 29 | sample_rate: 16000 30 | sorting: ascending 31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones) 32 | 33 | blank_index: 0 34 | 35 | num_epochs: 2 36 | lr: 0.0001 37 | improvement_threshold: 0.0025 38 | annealing_factor: 0.8 39 | lwf_lambda: 10.0 40 | lwf_T: 2.0 41 | 42 | wavlm_variant: wavlm-large 43 | pretrained_wavlm_path: null 44 | output_norm: True 45 | freeze: False # WavLM + LSTM 46 | freeze_encoder: False # WavLM 47 | freeze_feature_extractor: True # Feature extractor of WavLM 48 | hidden_size: 1024 49 | num_layers: 2 50 | dropout: 0.0 51 | bidirectional: True 52 | 53 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length) 54 | nonfinite_patience: 10 55 | max_grad_norm: 5.0 56 | precision: fp16 57 | gradient_checkpointing: False 58 | ckpt_interval_minutes: 600 59 | 60 | # Dataloader options 61 | train_dataloader_kwargs: 62 | batch_size: !ref 63 | num_workers: !ref 64 | 65 | valid_dataloader_kwargs: 66 | batch_size: !ref 67 | num_workers: !ref 68 | 69 | # Modules 70 | wavlm: !new:model.ProgressiveWavLM 71 | source: !ref microsoft/ 72 | save_path: !ref /checkpoint 73 | output_norm: !ref 74 | freeze: !ref 75 | freeze_encoder: !ref 76 | freeze_feature_extractor: !ref 77 | hidden_size: !ref 78 | num_layers: !ref 79 | dropout: !ref 80 | bidirectional: !ref 81 | 82 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss 83 | blank_index: !ref 84 | 85 | modules: 86 | wavlm: !ref 87 | 88 | # Optimizers 89 | opt_class: !name:torch.optim.AdamW 90 | lr: !ref 91 | 92 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler 93 | initial_value: !ref 94 | improvement_threshold: !ref 95 | annealing_factor: !ref 96 | patient: 0 97 | 98 | # Performance metrics 99 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 100 | 101 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 102 | split_tokens: True 103 | 104 | # Counters, checkpointers, loggers, etc. 
105 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
106 |     limit: !ref <num_epochs>
107 | 
108 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
109 |     checkpoints_dir: !ref <save_folder>
110 |     recoverables:
111 |         model: !ref <wavlm>
112 |         scheduler: !ref <lr_annealing>
113 |         counter: !ref <epoch_counter>
114 | 
115 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
116 |     save_file: !ref <output_folder>/<experiment_name>.txt
117 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_der.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 | 
6 | experiment_name: DER
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | 
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 | 
19 | # Output directories
20 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
21 | save_folder: !ref <output_folder>/save
22 | 
23 | # Training parameters
24 | train_batch_size: 4
25 | valid_batch_size: 12
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 | 
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 | 
33 | blank_index: 0
34 | 
35 | num_epochs: 2
36 | lr: 0.0001
37 | improvement_threshold: 0.0025
38 | annealing_factor: 0.8
39 | replay_ratio: 0.1
40 | der_alpha: 1.0
41 | 
42 | wavlm_variant: wavlm-large
43 | pretrained_wavlm_path: null
44 | output_norm: True
45 | freeze: False # WavLM + LSTM
46 | freeze_encoder: False # WavLM
47 | freeze_feature_extractor: True # Feature extractor of WavLM
48 | hidden_size: 1024
49 | num_layers: 2
50 | dropout: 0.0
51 | bidirectional: True
52 | 
53 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
54 | nonfinite_patience: 10
55 | max_grad_norm: 5.0
56 | precision: fp16
57 | gradient_checkpointing: False
58 | ckpt_interval_minutes: 600
59 | 
60 | # Dataloader options
61 | train_dataloader_kwargs:
62 |     batch_size: !ref <train_batch_size>
63 |     num_workers: !ref <train_num_workers>
64 | 
65 | valid_dataloader_kwargs:
66 |     batch_size: !ref <valid_batch_size>
67 |     num_workers: !ref <valid_num_workers>
68 | 
69 | # Modules
70 | wavlm: !new:model.ProgressiveWavLM
71 |     source: !ref microsoft/<wavlm_variant>
72 |     save_path: !ref <save_folder>/checkpoint
73 |     output_norm: !ref <output_norm>
74 |     freeze: !ref <freeze>
75 |     freeze_encoder: !ref <freeze_encoder>
76 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
77 |     hidden_size: !ref <hidden_size>
78 |     num_layers: !ref <num_layers>
79 |     dropout: !ref <dropout>
80 |     bidirectional: !ref <bidirectional>
81 | 
82 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
83 |     blank_index: !ref <blank_index>
84 | 
85 | modules:
86 |     wavlm: !ref <wavlm>
87 | 
88 | # Optimizers
89 | opt_class: !name:torch.optim.AdamW
90 |     lr: !ref <lr>
91 | 
92 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
93 |     initial_value: !ref <lr>
94 |     improvement_threshold: !ref <improvement_threshold>
95 |     annealing_factor: !ref <annealing_factor>
96 |     patient: 0
97 | 
98 | # Performance metrics
99 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
100 | 
101 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
102 |     split_tokens: True
103 | 
104 | # Counters, checkpointers, loggers, etc.
105 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
106 |     limit: !ref <num_epochs>
107 | 
108 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
109 |     checkpoints_dir: !ref <save_folder>
110 |     recoverables:
111 |         model: !ref <wavlm>
112 |         scheduler: !ref <lr_annealing>
113 |         counter: !ref <epoch_counter>
114 | 
115 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
116 |     save_file: !ref <output_folder>/<experiment_name>.txt
117 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_pnn.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023, Salah Zaiem 2023
4 | # ############################################################################
5 | 
6 | experiment_name: PNN
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | 
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 | 
19 | # Output directories
20 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
21 | save_folder: !ref <output_folder>/save
22 | 
23 | # Training parameters
24 | train_batch_size: 8
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 | 
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 | 
33 | blank_index: 0
34 | 
35 | num_epochs: 2
36 | lr: 0.0001
37 | improvement_threshold: 0.0025
38 | annealing_factor: 0.8
39 | num_new_decoder_layers: 1
40 | 
41 | wavlm_variant: wavlm-large
42 | pretrained_wavlm_path: null
43 | output_norm: True
44 | freeze: False # WavLM + LSTM
45 | freeze_encoder: False # WavLM
46 | freeze_feature_extractor: True # Feature extractor of WavLM
47 | hidden_size: 1024
48 | num_layers: 2
49 | dropout: 0.0
50 | bidirectional: True
51 | 
52 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
53 | nonfinite_patience: 10
54 | max_grad_norm: 5.0
55 | precision: fp16
56 | gradient_checkpointing: False
57 | ckpt_interval_minutes: 600
58 | 
59 | # Dataloader options
60 | train_dataloader_kwargs:
61 |     batch_size: !ref <train_batch_size>
62 |     num_workers: !ref <train_num_workers>
63 | 
64 | valid_dataloader_kwargs:
65 |     batch_size: !ref <valid_batch_size>
66 |     num_workers: !ref <valid_num_workers>
67 | 
68 | # Modules
69 | wavlm: !new:model.ProgressiveWavLM
70 |     source: !ref microsoft/<wavlm_variant>
71 |     save_path: !ref <save_folder>/checkpoint
72 |     output_norm: !ref <output_norm>
73 |     freeze: !ref <freeze>
74 |     freeze_encoder: !ref <freeze_encoder>
75 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
76 |     hidden_size: !ref <hidden_size>
77 |     num_layers: !ref <num_layers>
78 |     dropout: !ref <dropout>
79 |     bidirectional: !ref <bidirectional>
80 | 
81 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
82 |     blank_index: !ref <blank_index>
83 | 
84 | modules:
85 |     wavlm: !ref <wavlm>
86 | 
87 | # Optimizers
88 | opt_class: !name:torch.optim.AdamW
89 |     lr: !ref <lr>
90 | 
91 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
92 |     initial_value: !ref <lr>
93 |     improvement_threshold: !ref <improvement_threshold>
94 |     annealing_factor: !ref <annealing_factor>
95 |     patient: 0
96 | 
97 | # Performance metrics
98 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
99 | 
100 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
101 |     split_tokens: True
102 | 
103 | # Counters, checkpointers, loggers, etc.
104 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
105 |     limit: !ref <num_epochs>
106 | 
107 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
108 |     checkpoints_dir: !ref <save_folder>
109 |     recoverables:
110 |         model: !ref <wavlm>
111 |         scheduler: !ref <lr_annealing>
112 |         counter: !ref <epoch_counter>
113 | 
114 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
115 |     save_file: !ref <output_folder>/<experiment_name>.txt
116 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_ewc.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 | 
6 | experiment_name: EWC
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | skip_ewc: False
13 | 
14 | # Data preparation
15 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
16 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
17 | data_folder: !PLACEHOLDER
18 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
19 | 
20 | # Output directories
21 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
22 | save_folder: !ref <output_folder>/save
23 | 
24 | # Training parameters
25 | train_batch_size: 8
26 | valid_batch_size: 16
27 | train_num_workers: 6
28 | valid_num_workers: 6
29 | 
30 | sample_rate: 16000
31 | sorting: ascending
32 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
33 | 
34 | blank_index: 0
35 | 
36 | num_epochs: 2
37 | lr: 0.0001
38 | improvement_threshold: 0.0025
39 | annealing_factor: 0.8
40 | ewc_lambda: 5.0
41 | ewc_alpha: 0.5
42 | 
43 | wavlm_variant: wavlm-large
44 | pretrained_wavlm_path: null
45 | output_norm: True
46 | freeze: False # WavLM + LSTM
47 | freeze_encoder: False # WavLM
48 | freeze_feature_extractor: True # Feature extractor of WavLM
49 | hidden_size: 1024
50 | num_layers: 2
51 | dropout: 0.0
52 | bidirectional: True
53 | 
54 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
55 | nonfinite_patience: 10
56 | max_grad_norm: 5.0
57 | precision: fp16
58 | gradient_checkpointing: False
59 | ckpt_interval_minutes: 600
60 | 
61 | # Dataloader options
62 | train_dataloader_kwargs:
63 |     batch_size: !ref <train_batch_size>
64 |     num_workers: !ref <train_num_workers>
65 | 
66 | valid_dataloader_kwargs:
67 |     batch_size: !ref <valid_batch_size>
68 |     num_workers: !ref <valid_num_workers>
69 | 
70 | # Modules
71 | wavlm: !new:model.ProgressiveWavLM
72 |     source: !ref microsoft/<wavlm_variant>
73 |     save_path: !ref <save_folder>/checkpoint
74 |     output_norm: !ref <output_norm>
75 |     freeze: !ref <freeze>
76 |     freeze_encoder: !ref <freeze_encoder>
77 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
78 |     hidden_size: !ref <hidden_size>
79 |     num_layers: !ref <num_layers>
80 |     dropout: !ref <dropout>
81 |     bidirectional: !ref <bidirectional>
82 | 
83 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
84 |     blank_index: !ref <blank_index>
85 | 
86 | modules:
87 |     wavlm: !ref <wavlm>
88 | 
89 | # Optimizers
90 | opt_class: !name:torch.optim.AdamW
91 |     lr: !ref <lr>
92 | 
93 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
94 |     initial_value: !ref <lr>
95 |     improvement_threshold: !ref <improvement_threshold>
96 |     annealing_factor: !ref <annealing_factor>
97 |     patient: 0
98 | 
99 | # Performance metrics
100 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
101 | 
102 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
103 |     split_tokens: True
104 | 
105 | # Counters, checkpointers, loggers, etc.
106 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
107 |     limit: !ref <num_epochs>
108 | 
109 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
110 |     checkpoints_dir: !ref <save_folder>
111 |     recoverables:
112 |         model: !ref <wavlm>
113 |         scheduler: !ref <lr_annealing>
114 |         counter: !ref <epoch_counter>
115 | 
116 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
117 |     save_file: !ref <output_folder>/<experiment_name>.txt
118 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_mas.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 | 
6 | experiment_name: MAS
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | skip_mas: False
13 | 
14 | # Data preparation
15 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
16 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
17 | data_folder: !PLACEHOLDER
18 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
19 | 
20 | # Output directories
21 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
22 | save_folder: !ref <output_folder>/save
23 | 
24 | # Training parameters
25 | train_batch_size: 8
26 | valid_batch_size: 16
27 | train_num_workers: 6
28 | valid_num_workers: 6
29 | 
30 | sample_rate: 16000
31 | sorting: ascending
32 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
33 | 
34 | blank_index: 0
35 | 
36 | num_epochs: 2
37 | lr: 0.0001
38 | improvement_threshold: 0.0025
39 | annealing_factor: 0.8
40 | mas_lambda: 1.0
41 | mas_alpha: 0.5
42 | 
43 | wavlm_variant: wavlm-large
44 | pretrained_wavlm_path: null
45 | output_norm: True
46 | freeze: False # WavLM + LSTM
47 | freeze_encoder: False # WavLM
48 | freeze_feature_extractor: True # Feature extractor of WavLM
49 | hidden_size: 1024
50 | num_layers: 2
51 | dropout: 0.0
52 | bidirectional: True
53 | 
54 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
55 | nonfinite_patience: 10
56 | max_grad_norm: 5.0
57 | precision: fp16
58 | gradient_checkpointing: False
59 | ckpt_interval_minutes: 600
60 | 
61 | # Dataloader options
62 | train_dataloader_kwargs:
63 |     batch_size: !ref <train_batch_size>
64 |     num_workers: !ref <train_num_workers>
65 | 
66 | valid_dataloader_kwargs:
67 |     batch_size: !ref <valid_batch_size>
68 |     num_workers: !ref <valid_num_workers>
69 | 
70 | # Modules
71 | wavlm: !new:model.ProgressiveWavLM
72 |     source: !ref microsoft/<wavlm_variant>
73 |     save_path: !ref <save_folder>/checkpoint
74 |     output_norm: !ref <output_norm>
75 |     freeze: !ref <freeze>
76 |     freeze_encoder: !ref <freeze_encoder>
77 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
78 |     hidden_size: !ref <hidden_size>
79 |     num_layers: !ref <num_layers>
80 |     dropout: !ref <dropout>
81 |     bidirectional: !ref <bidirectional>
82 | 
83 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
84 |     blank_index: !ref <blank_index>
85 | 
86 | modules:
87 |     wavlm: !ref <wavlm>
88 | 
89 | # Optimizers
90 | opt_class: !name:torch.optim.AdamW
91 |     lr: !ref <lr>
92 | 
93 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
94 |     initial_value: !ref <lr>
95 |     improvement_threshold: !ref <improvement_threshold>
96 |     annealing_factor: !ref <annealing_factor>
97 |     patient: 0
98 | 
99 | # Performance metrics
100 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
101 | 
102 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
103 |     split_tokens: True
104 | 
105 | # Counters, checkpointers, loggers, etc.
106 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
107 |     limit: !ref <num_epochs>
108 | 
109 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
110 |     checkpoints_dir: !ref <save_folder>
111 |     recoverables:
112 |         model: !ref <wavlm>
113 |         scheduler: !ref <lr_annealing>
114 |         counter: !ref <epoch_counter>
115 | 
116 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
117 |     save_file: !ref <output_folder>/<experiment_name>.txt
118 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_pb.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023, Salah Zaiem 2023
4 | # ############################################################################
5 | 
6 | experiment_name: PB
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | 
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 | 
19 | # Output directories
20 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
21 | save_folder: !ref <output_folder>/save
22 | 
23 | # Training parameters
24 | train_batch_size: 8
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 | 
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 | 
33 | blank_index: 0
34 | 
35 | num_epochs: 2
36 | lr: 0.0001
37 | improvement_threshold: 0.0025
38 | annealing_factor: 0.8
39 | mask_init: 0.01
40 | mask_threshold: 0.005
41 | 
42 | wavlm_variant: wavlm-large
43 | pretrained_wavlm_path: null
44 | output_norm: True
45 | freeze: False # WavLM + LSTM
46 | freeze_encoder: False # WavLM
47 | freeze_feature_extractor: True # Feature extractor of WavLM
48 | hidden_size: 1024
49 | num_layers: 2
50 | dropout: 0.0
51 | bidirectional: True
52 | 
53 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
54 | nonfinite_patience: 10
55 | max_grad_norm: 5.0
56 | precision: fp16
57 | gradient_checkpointing: False
58 | ckpt_interval_minutes: 600
59 | 
60 | # Dataloader options
61 | train_dataloader_kwargs:
62 |     batch_size: !ref <train_batch_size>
63 |     num_workers: !ref <train_num_workers>
64 | 
65 | valid_dataloader_kwargs:
66 |     batch_size: !ref <valid_batch_size>
67 |     num_workers: !ref <valid_num_workers>
68 | 
69 | # Modules
70 | wavlm: !new:model.ProgressiveWavLM
71 |     source: !ref microsoft/<wavlm_variant>
72 |     save_path: !ref <save_folder>/checkpoint
73 |     output_norm: !ref <output_norm>
74 |     freeze: !ref <freeze>
75 |     freeze_encoder: !ref <freeze_encoder>
76 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
77 |     hidden_size: !ref <hidden_size>
78 |     num_layers: !ref <num_layers>
79 |     dropout: !ref <dropout>
80 |     bidirectional: !ref <bidirectional>
81 | 
82 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
83 |     blank_index: !ref <blank_index>
84 | 
85 | modules:
86 |     wavlm: !ref <wavlm>
87 | 
88 | # Optimizers
89 | opt_class: !name:torch.optim.AdamW
90 |     lr: !ref <lr>
91 | 
92 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
93 |     initial_value: !ref <lr>
94 |     improvement_threshold: !ref <improvement_threshold>
95 |     annealing_factor: !ref <annealing_factor>
96 |     patient: 0
97 | 
98 | # Performance metrics
99 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
100 | 
101 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
102 |     split_tokens: True
103 | 
104 | # Counters, checkpointers, loggers, etc.
105 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
106 |     limit: !ref <num_epochs>
107 | 
108 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
109 |     checkpoints_dir: !ref <save_folder>
110 |     recoverables:
111 |         model: !ref <wavlm>
112 |         scheduler: !ref <lr_annealing>
113 |         counter: !ref <epoch_counter>
114 | 
115 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
116 |     save_file: !ref <output_folder>/<experiment_name>.txt
117 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/whisper/hparams/train_l2p.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: Whisper (encoder-decoder) + cross-entropy loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 | 
6 | experiment_name: L2P
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | 
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 | 
19 | # Output directories
20 | output_folder: !ref results/<whisper_variant>/<experiment_name>/<seed>
21 | save_folder: !ref <output_folder>/save
22 | 
23 | # Training parameters
24 | train_batch_size: 8
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 | 
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 | 
33 | ignore_index: -100 # For cross-entropy loss
34 | label_smoothing: 0
35 | 
36 | num_epochs: 2
37 | lr: 0.0001
38 | improvement_threshold: 0.0025
39 | annealing_factor: 0.8
40 | d_prompt: 1280
41 | 
42 | whisper_variant: whisper-large-v2
43 | encoder_only: False
44 | freeze: False
45 | freeze_encoder: True
46 | 
47 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
48 | nonfinite_patience: 10
49 | max_grad_norm: 5.0
50 | precision: fp16
51 | gradient_checkpointing: False
52 | ckpt_interval_minutes: 600
53 | 
54 | max_gen_tokens: 80
55 | forced_decoder_locale: null # Set dynamically
56 | normalize_transcripts: True
57 | 
58 | # Dataloader options
59 | train_dataloader_kwargs:
60 |     batch_size: !ref <train_batch_size>
61 |     num_workers: !ref <train_num_workers>
62 | 
63 | valid_dataloader_kwargs:
64 |     batch_size: !ref <valid_batch_size>
65 |     num_workers: !ref <valid_num_workers>
66 | 
67 | # Modules
68 | whisper: !new:model.ProgressiveWhisper
69 |     source: !ref openai/<whisper_variant>
70 |     save_path: !ref <save_folder>/checkpoint
71 |     sampling_rate: !ref <sample_rate>
72 |     encoder_only: !ref <encoder_only>
73 |     freeze: !ref <freeze>
74 |     freeze_encoder: !ref <freeze_encoder>
75 | 
76 | ce_loss: !new:torch.nn.CrossEntropyLoss
77 |     ignore_index: !ref <ignore_index>
78 |     label_smoothing: !ref <label_smoothing>
79 | 
80 | prompt_pool: !new:train_l2p.PromptPool
81 |     locales: !ref <new_locales>
82 |     d_prompt: !ref <d_prompt>
83 | 
84 | modules:
85 |     whisper: !ref <whisper>
86 |     prompt_pool: !ref <prompt_pool>
87 | 
88 | # Optimizers
89 | opt_class: !name:torch.optim.AdamW
90 |     lr: !ref <lr>
91 | 
92 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
93 |     initial_value: !ref <lr>
94 |     improvement_threshold: !ref <improvement_threshold>
95 |     annealing_factor: !ref <annealing_factor>
96 |     patient: 0
97 | 
98 | # Performance metrics
99 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
100 | 
101 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
102 |     split_tokens: True
103 | 
104 | # Counters, checkpointers, loggers, etc.
105 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
106 |     limit: !ref <num_epochs>
107 | 
108 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
109 |     checkpoints_dir: !ref <save_folder>
110 |     recoverables:
111 |         model: !ref <whisper>
112 |         prompt_pool: !ref <prompt_pool>
113 |         scheduler: !ref <lr_annealing>
114 |         counter: !ref <epoch_counter>
115 | 
116 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
117 |     save_file: !ref <output_folder>/<experiment_name>.txt
118 | 
--------------------------------------------------------------------------------
/benchmarks/CL_MASR/wavlm/hparams/train_l2p.yaml:
--------------------------------------------------------------------------------
1 | # ############################################################################
2 | # Model: WavLM + LSTM + CTC loss
3 | # Authors: Luca Della Libera 2023
4 | # ############################################################################
5 | 
6 | experiment_name: L2P
7 | 
8 | seed: 0
9 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
10 | 
11 | skip_test: False
12 | 
13 | # Data preparation
14 | base_locales: [en, zh-CN, de, es, ru, fr, pt, ja, tr, pl]
15 | new_locales: [rw, eo, kab, lg, mhr, ckb, ab, kmr, fy-NL, ia]
16 | data_folder: !PLACEHOLDER
17 | max_durations: [36000, 3600, 3600] # Maximum total durations in seconds for train, dev, and test splits for each locale
18 | 
19 | # Output directories
20 | output_folder: !ref results/<wavlm_variant>/<experiment_name>/<seed>
21 | save_folder: !ref <output_folder>/save
22 | 
23 | # Training parameters
24 | train_batch_size: 8
25 | valid_batch_size: 16
26 | train_num_workers: 6
27 | valid_num_workers: 6
28 | 
29 | sample_rate: 16000
30 | sorting: ascending
31 | avoid_if_longer_than: 10 # Remove utterances longer than 10s (open microphones)
32 | 
33 | blank_index: 0
34 | 
35 | num_epochs: 2
36 | lr: 0.0001
37 | improvement_threshold: 0.0025
38 | annealing_factor: 0.8
39 | d_prompt: 1024
40 | 
41 | wavlm_variant: wavlm-large
42 | pretrained_wavlm_path: null
43 | output_norm: True
44 | freeze: False # WavLM + LSTM
45 | freeze_encoder: False # WavLM
46 | freeze_feature_extractor: True # Feature extractor of WavLM
47 | hidden_size: 1024
48 | num_layers: 2
49 | dropout: 0.0
50 | bidirectional: True
51 | 
52 | max_target_length: 448 # Must be <= 448 (Whisper maximum target length)
53 | nonfinite_patience: 10
54 | max_grad_norm: 5.0
55 | precision: fp16
56 | gradient_checkpointing: False
57 | ckpt_interval_minutes: 600
58 | 
59 | forced_decoder_locale: null # Set dynamically
60 | 
61 | # Dataloader options
62 | train_dataloader_kwargs:
63 |     batch_size: !ref <train_batch_size>
64 |     num_workers: !ref <train_num_workers>
65 | 
66 | valid_dataloader_kwargs:
67 |     batch_size: !ref <valid_batch_size>
68 |     num_workers: !ref <valid_num_workers>
69 | 
70 | # Modules
71 | wavlm: !new:model.ProgressiveWavLM
72 |     source: !ref microsoft/<wavlm_variant>
73 |     save_path: !ref <save_folder>/checkpoint
74 |     output_norm: !ref <output_norm>
75 |     freeze: !ref <freeze>
76 |     freeze_encoder: !ref <freeze_encoder>
77 |     freeze_feature_extractor: !ref <freeze_feature_extractor>
78 |     hidden_size: !ref <hidden_size>
79 |     num_layers: !ref <num_layers>
80 |     dropout: !ref <dropout>
81 |     bidirectional: !ref <bidirectional>
82 | 
83 | ctc_loss: !name:speechbrain.nnet.losses.ctc_loss
84 |     blank_index: !ref <blank_index>
85 | 
86 | prompt_pool: !new:train_l2p.PromptPool
87 |     locales: !ref <new_locales>
88 |     d_prompt: !ref <d_prompt>
89 | 
90 | modules:
91 |     wavlm: !ref <wavlm>
92 |     prompt_pool: !ref <prompt_pool>
93 | 
94 | # Optimizers
95 | opt_class: !name:torch.optim.AdamW
96 |     lr: !ref <lr>
97 | 
98 | lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
99 |     initial_value: !ref <lr>
100 |     improvement_threshold: !ref <improvement_threshold>
101 |     annealing_factor: !ref <annealing_factor>
102 |     patient: 0
103 | 
104 | # Performance metrics
105 | wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
106 | 
107 | cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
108 |     split_tokens: True
109 | 
110 | # Counters, checkpointers, loggers, etc.
111 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
112 |     limit: !ref <num_epochs>
113 | 
114 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
115 |     checkpoints_dir: !ref <save_folder>
116 |     recoverables:
117 |         model: !ref <wavlm>
118 |         prompt_pool: !ref <prompt_pool>
119 |         scheduler: !ref <lr_annealing>
120 |         counter: !ref <epoch_counter>
121 | 
122 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
123 |     save_file: !ref <output_folder>/<experiment_name>.txt
124 | 
--------------------------------------------------------------------------------
/tests/consistency/test_HF_repo.py:
--------------------------------------------------------------------------------
1 | """Library for the HuggingFace (HF) repositories.
2 | 
3 | Authors
4 |  * Mirco Ravanelli 2022
5 |  * Andreas Nautsch 2022
6 | """
7 | import os
8 | import csv
9 | from speechbrain.utils.data_utils import download_file
10 | 
11 | 
12 | def run_HF_check(
13 |     recipe_folder="tests/recipes", field="HF_repo", output_folder="HF_repos",
14 | ):
15 |     """Checks if the code reported in the readme files of the HF repositories is
16 |     runnable. Note: the tests run the code marked as python in the readme file.
17 | 
18 |     Arguments
19 |     ---------
20 |     recipe_folder: path
21 |         Path of the folder with csv recipe files summarizing all the recipes in the repo.
22 |     field: string
23 |         Field of the csv recipe file containing the links to HF repos.
24 |     output_folder: path
25 |         Where to download the HF readme files.
26 | 
27 |     Returns
28 |     -------
29 |     check: bool
30 |         True if all the code runs, False otherwise.
31 |     """
32 |     # Detect the list of HF repositories (repo_list scans the whole folder)
33 |     HF_repos = repo_list(recipe_folder, field)
34 | 
35 |     # Set up the output folder once (chdir must not happen inside a loop)
36 |     os.makedirs(output_folder, exist_ok=True)
37 |     os.chdir(output_folder)
38 | 
39 |     # Checking all detected repos
40 |     check = True
41 |     for repo in HF_repos:
42 |         if not check_repo(repo):
43 |             check = False
44 |     return check
45 | 
46 | 
47 | def repo_list(recipe_folder="tests/recipes", field="HF_repo"):
48 |     """Get the list of HF recipes in the csv recipe file.
49 | 
50 |     Arguments
51 |     ---------
52 |     recipe_folder: path
53 |         Path of the folder with csv recipe files summarizing all the recipes in the repo.
54 |     field: string
55 |         Field of the csv recipe file containing the links to HF repos.
56 | 
57 |     Returns
58 |     -------
59 |     HF_repos: list
60 |         List of the detected HF repos.
61 |     """
62 |     HF_repos = []
63 | 
64 |     # Loop over all recipe CSVs
65 |     for recipe_csvfile in os.listdir(recipe_folder):
66 |         with open(
67 |             os.path.join(recipe_folder, recipe_csvfile), newline=""
68 |         ) as csvf:
69 |             reader = csv.DictReader(csvf, delimiter=",", skipinitialspace=True)
70 |             for row in reader:
71 |                 if len(row[field]) > 0:
72 |                     repos = row[field].split(" ")
73 |                     for repo in repos:
74 |                         HF_repos.append(repo)
75 |     HF_repos = sorted(set(HF_repos))
76 |     return HF_repos
77 | 
78 | 
79 | def check_repo(HF_repo):
80 |     """Runs the code reported in the README file of the given HF_repo. It checks
81 |     if the code runs without errors.
82 | 
83 |     Arguments
84 |     ---------
85 |     HF_repo: string
86 |         URL of the HF repository to check.
87 | 
88 |     Returns
89 |     -------
90 |     check: bool
91 |         True if all the code runs, False otherwise.
92 |     """
93 | """ 94 | exp_name = os.path.basename(HF_repo) 95 | readme_file = HF_repo + "/raw/main/README.md" 96 | 97 | dest_file = exp_name + ".md" 98 | download_file(readme_file, dest_file) 99 | 100 | code_snippets = [] 101 | code = [] 102 | flag = False 103 | with open(dest_file, "r") as f: 104 | for line in f: 105 | if "```python" in line: 106 | flag = True 107 | code = [] 108 | elif "```\n" in line and flag: 109 | flag = False 110 | code_snippets.append(code) 111 | elif flag: 112 | if len(line.strip()) > 0: 113 | code.append(line) 114 | print(line) 115 | 116 | for code in code_snippets: 117 | try: 118 | exec("\n".join(code)) 119 | except Exception as e: 120 | print("\t" + str(e)) 121 | check = False 122 | print("\tERROR: cannot run code snippet in %s" % (HF_repo)) 123 | return check 124 | -------------------------------------------------------------------------------- /benchmarks/DASB/LibriSpeech/quantization/train_subwording.py: -------------------------------------------------------------------------------- 1 | """ 2 | Recipe to train subwording tokenization on semantic tokens(Discrete SSL tokens). 3 | 4 | To run this recipe, do the following: 5 | > python train.py hparams/train_with_[SSL-model].yaml --data_folder=/path/to/LibriSPeech 6 | Author 7 | * Pooneh Mousavi 2023 8 | """ 9 | 10 | import os 11 | import sys 12 | import logging 13 | import speechbrain as sb 14 | from speechbrain.tokenizers.SentencePiece import SentencePiece 15 | from speechbrain.utils.distributed import run_on_main 16 | from hyperpyyaml import load_hyperpyyaml 17 | import torchaudio 18 | import csv 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | if __name__ == "__main__": 24 | # Load hyperparameters file with command-line overrides 25 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 26 | 27 | with open(hparams_file) as fin: 28 | hparams = load_hyperpyyaml(fin, overrides) 29 | 30 | # Create experiment directory 31 | sb.create_experiment_directory( 32 | experiment_directory=hparams["output_folder"], 33 | hyperparams_to_save=hparams_file, 34 | overrides=overrides, 35 | ) 36 | 37 | # Dataset prep (parsing Librispeech) 38 | from librispeech_prepare import prepare_librispeech # noqa 39 | 40 | if not os.path.exists(hparams["tokenized_train"]): 41 | # multi-gpu (ddp) save data preparation 42 | run_on_main( 43 | prepare_librispeech, 44 | kwargs={ 45 | "data_folder": hparams["data_folder"], 46 | "tr_splits": hparams["train_splits"], 47 | "dev_splits": hparams["dev_splits"], 48 | "te_splits": hparams["test_splits"], 49 | "save_folder": hparams["output_folder"], 50 | "merge_lst": hparams["train_splits"], 51 | "merge_name": "train.csv", 52 | "skip_prep": hparams["skip_prep"], 53 | }, 54 | ) 55 | 56 | with open(hparams["train_csv"], newline="") as csvfile: 57 | reader = csv.reader(csvfile, delimiter=",") 58 | next(reader, None) 59 | with open(hparams["tokenized_train"], "w", newline="") as csvwrite: 60 | # writer = csv.writer(csvwrite) 61 | header = ["id"] 62 | for layer in hparams["ssl_layer_num"]: 63 | header.append(f"textified_tokens_layer_{layer}") 64 | writer = csv.DictWriter(csvwrite, fieldnames=header) 65 | writer.writeheader() 66 | # writer.writerow(header) 67 | 68 | for row in reader: 69 | sig = sb.dataio.dataio.read_audio(row[2]) 70 | info = torchaudio.info(row[2]) 71 | resampled = torchaudio.transforms.Resample( 72 | info.sample_rate, hparams["sample_rate"], 73 | )(sig) 74 | discrete_units, _, _ = hparams["discrete_ssl_model"]( 75 | resampled.unsqueeze(0), 76 | None, 77 | 
**hparams["tokenizer_config"], 78 | ) 79 | row_dic = {} 80 | row_dic["id"] = row[0] 81 | for i, layer in enumerate(hparams["ssl_layer_num"]): 82 | tokens = (discrete_units[:, :, i]).squeeze(0) 83 | tokens_char = " ".join( 84 | [chr(token + 97) for token in tokens] 85 | ) 86 | row_dic[f"textified_tokens_layer_{layer}"] = tokens_char 87 | writer.writerow(row_dic) 88 | 89 | for layer in hparams["ssl_layer_num"]: 90 | model_dir = os.path.join( 91 | hparams["save_folder"], f"tokenizer_layer_{layer}" 92 | ) 93 | os.makedirs(model_dir, exist_ok=True) 94 | 95 | SentencePiece( 96 | model_dir=model_dir, 97 | vocab_size=hparams["vocab_size"], 98 | annotation_train=hparams["tokenized_train"], 99 | annotation_read=f"textified_tokens_layer_{layer}", 100 | annotation_format="csv", 101 | model_type="bpe", 102 | character_coverage=1.0, 103 | unk_id=hparams["unk_id"], 104 | pad_id=hparams["pad_id"], 105 | ) 106 | -------------------------------------------------------------------------------- /tests/utils/check_url.py: -------------------------------------------------------------------------------- 1 | """Libraries for automatic finding URLs in the files and checking if they are 2 | reachable. 3 | 4 | Authors 5 | * Mirco Ravanelli 2022 6 | """ 7 | import os 8 | import re 9 | import time 10 | import requests 11 | from tqdm.contrib import tqdm 12 | from speechbrain.utils.data_utils import get_all_files 13 | 14 | 15 | def get_url(path): 16 | """This function searches for the URLs in the specified file. 17 | 18 | Arguments 19 | --------- 20 | path: path 21 | Path of the file where to search for URLs. 22 | 23 | Returns 24 | ------- 25 | urls: list 26 | a list of all the URLs found in the specified path. 27 | """ 28 | # Check if files exist 29 | if not (os.path.exists(path)): 30 | print("File %s not found!" % (path)) 31 | return False 32 | 33 | # Read the file 34 | with open(path, "r") as file: 35 | text = file.read() 36 | 37 | # Set up Regex for URL detection 38 | url_regex = re.compile( 39 | r"((https?):((//)|(\\\\))+([\w\d:#@%/;$()~_?\+-=\\\.&](#!)?)*)", 40 | re.DOTALL, 41 | ) 42 | urls = re.findall(url_regex, text) 43 | 44 | return list(set(urls)) 45 | 46 | 47 | def get_all_urls(file_lst, avoid_urls): 48 | """This function searches for all the URLs in the specified file list 49 | 50 | Arguments 51 | --------- 52 | file_lst: list 53 | List of the files where to search for URLs. 54 | avoid_urls: list 55 | List of URLs to avoid. 56 | 57 | Returns 58 | ------- 59 | urls: dict 60 | A dictionary where the keys are the detected URLs and the values 61 | are the files where the URLs are found. 62 | """ 63 | all_urls = {} 64 | 65 | for path in file_lst: 66 | if ".gz" in path: 67 | continue 68 | 69 | urls = get_url(path) 70 | 71 | for url in urls: 72 | 73 | # Clean up urls 74 | url = url[0].split(")")[0] 75 | if ( 76 | url[-1] == "." 77 | or url[-1] == "," 78 | or url[-1] == " " 79 | or url[-1] == "/" 80 | ): 81 | url = url[:-1] 82 | 83 | if url in avoid_urls: 84 | continue 85 | 86 | if url not in all_urls: 87 | all_urls[url] = [] 88 | all_urls[url].append(path) 89 | return all_urls 90 | 91 | 92 | def check_url(url): 93 | """Cheks if an URL is broken 94 | 95 | Arguments 96 | --------- 97 | url: string 98 | URL to check 99 | 100 | Returns 101 | ------- 102 | Bool 103 | False if the URL is broken, True otherwise. 
104 | """ 105 | try: 106 | response = requests.head(url) 107 | if response.status_code == 404 or response.status_code > 499: 108 | return False 109 | else: 110 | return True 111 | except requests.ConnectionError: 112 | return False 113 | 114 | 115 | def check_links( 116 | folder=".", 117 | match_or=[".py", ".md", ".txt"], 118 | exclude_or=[".pyc"], 119 | avoid_files=[""], 120 | avoid_urls=["http:/", "http://", "https:/", "https://"], 121 | ): 122 | """This test checks if the files in the specified folders contain broken URLs 123 | 124 | Arguments 125 | --------- 126 | folder: path 127 | The top Folder for searching for the files. 128 | match_or: list 129 | Used to specify the extensions of the files to check. 130 | exclude_or: list 131 | Used to avoid some file extensions. 132 | avoid_files: list 133 | Used to avoid testing some specific file. 134 | """ 135 | 136 | check_test = True 137 | # Find all the files that potentially contain urls 138 | file_lst = get_all_files(folder, match_or=match_or, exclude_or=exclude_or) 139 | 140 | # Get urls for the list of files - unique list 141 | all_urls = get_all_urls(file_lst, avoid_urls) 142 | 143 | # Check all the urls 144 | with tqdm(all_urls) as all_urls_progressbar: 145 | for url in all_urls_progressbar: 146 | time.sleep(1) 147 | if not (check_url(url)): 148 | check_test = False 149 | print("WARNING: %s is DOWN!" % (url)) 150 | for path in all_urls[url]: 151 | print("\t link detected in %s" % (path)) 152 | return check_test 153 | -------------------------------------------------------------------------------- /tests/utils/check_HF_repo.py: -------------------------------------------------------------------------------- 1 | """Library for the HuggingFace (HF) repositories. 2 | 3 | Authors 4 | * Mirco Ravanelli 2022 5 | * Andreas Nautsch 2022, 2023 6 | """ 7 | import os 8 | import csv 9 | from speechbrain.utils.data_utils import download_file 10 | from tests.consistency.test_recipe import __skip_list 11 | 12 | 13 | def run_HF_check( 14 | recipe_folder="tests/recipes", field="HF_repo", output_folder="tests/tmp", 15 | ): 16 | """Checks if the code reported in the readme files of the HF repository is 17 | runnable. Note: the tests run the code marked as python in the readme file. 18 | 19 | Arguments 20 | --------- 21 | recipe_folder: path 22 | Path of the folder containing csv recipe files summarizing all the recipes in the repo. 23 | field: string 24 | Field of the csv recipe file containing the links to HF repos. 25 | output_folder: path 26 | Where to download the HF readme files. 27 | 28 | Returns 29 | --------- 30 | check: True 31 | True if all the code runs, False otherwise. 32 | """ 33 | # Detect list of HF repositories 34 | HF_repos = repo_list(recipe_folder, field) 35 | 36 | # Set up output folder 37 | os.makedirs(output_folder, exist_ok=True) 38 | os.chdir(output_folder) 39 | 40 | # Checking all detected repos 41 | check = True 42 | for i, repo in enumerate(HF_repos): 43 | print("(%i/%i) Checking %s..." % (i + 1, len(HF_repos), repo)) 44 | if not check_repo(repo): 45 | check = False 46 | return check 47 | 48 | 49 | def repo_list(recipe_folder="tests/recipes", field="HF_repo"): 50 | """Get the list of HF recipes in the csv recipe file. 51 | 52 | Arguments 53 | --------- 54 | recipe_folder: path 55 | Path of the folder containing csv recipe files summarizing all the recipes in the repo. 56 | field: string 57 | Field of the csv recipe file containing the links to HF repos. 
58 | 
59 |     Returns
60 |     -------
61 |     HF_repos: list
62 |         List of the detected HF repos.
63 |     """
64 |     HF_repos = []
65 | 
66 |     # Loop over all recipe CSVs
67 |     for recipe_csvfile in os.listdir(recipe_folder):
68 |         if recipe_csvfile in __skip_list:
69 |             continue
70 | 
71 |         with open(
72 |             os.path.join(recipe_folder, recipe_csvfile), newline=""
73 |         ) as csvf:
74 |             reader = csv.DictReader(csvf, delimiter=",", skipinitialspace=True)
75 |             for row in reader:
76 |                 if len(row[field]) > 0:
77 |                     repos = row[field].split(" ")
78 |                     for repo in repos:
79 |                         HF_repos.append(repo)
80 |     HF_repos = sorted(set(HF_repos))
81 |     return HF_repos
82 | 
83 | 
84 | def check_repo(HF_repo):
85 |     """Runs the code reported in the README file of the given HF_repo. It checks
86 |     if the code runs without errors.
87 | 
88 |     Arguments
89 |     ---------
90 |     HF_repo: string
91 |         URL of the HF repository to check.
92 | 
93 |     Returns
94 |     -------
95 |     check: bool
96 |         True if all the code runs, False otherwise.
97 |     """
98 |     exp_name = os.path.basename(HF_repo)
99 |     if HF_repo[-1] == "/":
100 |         readme_file = HF_repo + "raw/main/README.md"
101 |     else:
102 |         readme_file = HF_repo + "/raw/main/README.md"
103 | 
104 |     dest_file = exp_name + ".md"
105 |     download_file(readme_file, dest_file)
106 | 
107 |     code_snippets = []
108 |     code = []
109 |     flag = False
110 |     check = True
111 |     with open(dest_file, "r") as f:
112 |         for line in f:
113 |             if "```python" in line:
114 |                 flag = True
115 |                 code = []
116 |             elif "```" in line and flag:
117 |                 flag = False
118 |                 code_snippets.append(code)
119 |             elif flag:
120 |                 if len(line.strip()) > 0:
121 |                     # adjust local audio paths 'tests/samples' -> '../samples'
122 |                     if "tests/samples" in line:
123 |                         line = line.replace("tests/samples", "../samples")
124 |                     code.append(line)
125 | 
126 |     for code in code_snippets:
127 |         try:
128 |             exec("".join(code))
129 |         except Exception as e:
130 |             print("\t" + str(e))
131 |             check = False
132 |             print("\tERROR: cannot run code snippet in %s" % (HF_repo))
133 |             print("---\n" + "".join(code) + "---\n")
134 |     return check
--------------------------------------------------------------------------------
/benchmarks/DASB/IEMOCAP/emotion_recognition/linear/hparams/train_weighted_ssl.yaml:
--------------------------------------------------------------------------------
1 | # ########################################
2 | # Recipe for training an emotion recognition system from speech data
3 | # only using IEMOCAP and an SSL feature extractor.
4 | # The system classifies 4 emotions (anger, happiness, sadness, neutrality)
5 | # with a linear classifier on top of weighted SSL features.
6 | # Authors
7 | #  * Pooneh Mousavi 2024
8 | # ########################################
9 | 
10 | # Seed needs to be set at top of yaml, before objects with parameters are made
11 | seed: 1986
12 | __set_seed: !apply:torch.manual_seed [!ref <seed>]
13 | 
14 | # Dataset will be downloaded to the `data_original` folder
15 | data_folder: !PLACEHOLDER # e.g., /path/to/IEMOCAP_full_release
16 | output_folder: !ref results/IEMOCAP/IEMOCAP_full_release/linear/weighted_ssl/<seed>
17 | save_folder: !ref <output_folder>/save
18 | train_log: !ref <output_folder>/train_log.txt
19 | 
20 | # URL of the SSL encoder model; change it to benchmark different models.
21 | # Important: we use the base SSL encoder, not the one fine-tuned on an ASR task.
22 | # This gives a ~4% improvement.
23 | 
24 | ssl_hub: microsoft/wavlm-large
25 | ssl_folder: !ref <save_folder>/ssl_checkpoints
26 | encoder_dim: 1024
27 | 
28 | # Different speakers for train, valid and test sets
29 | different_speakers: True
30 | # Which speaker is used for the test set (value from 1 to 10).
31 | # Run the experiment once per value (10 runs) and take the mean of the results.
32 | test_spk_id: 1
33 | 
34 | # Path where data manifest files will be stored
35 | train_annotation: !ref <output_folder>/train.json
36 | valid_annotation: !ref <output_folder>/valid.json
37 | test_annotation: !ref <output_folder>/test.json
38 | skip_prep: False
39 | 
40 | # The train logger writes training statistics to a file, as well as stdout.
41 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
42 |     save_file: !ref <train_log>
43 | 
44 | ckpt_interval_minutes: 15 # save checkpoint every N min
45 | 
46 | # Training parameters
47 | precision: fp32
48 | number_of_epochs: 30
49 | batch_size: 2
50 | test_batch_size: 1
51 | 
52 | lr: 0.0002
53 | lr_weights: 0.01
54 | 
55 | # Number of emotions
56 | out_n_neurons: 4 # (anger, happiness, sadness, neutral)
57 | 
58 | # Dataloader options
59 | train_dataloader_opts:
60 |     batch_size: !ref <batch_size>
61 |     shuffle: True
62 |     num_workers: 2 # 2 on Linux, but 0 works on Windows
63 |     drop_last: False
64 | 
65 | valid_dataloader_opts:
66 |     batch_size: !ref <batch_size>
67 | 
68 | test_dataloader_opts:
69 |     batch_size: !ref <test_batch_size>
70 | 
71 | weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.WeightedSSLModel # yamllint disable-line rule:line-length
72 |     hub: !ref <ssl_hub>
73 |     save_path: !ref <ssl_folder>
74 | 
75 | avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
76 |     return_std: False
77 | 
78 | output_mlp: !new:speechbrain.nnet.linear.Linear
79 |     input_size: !ref <encoder_dim>
80 |     n_neurons: !ref <out_n_neurons>
81 |     bias: False
82 | 
83 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
84 |     limit: !ref <number_of_epochs>
85 | 
86 | modules:
87 |     output_mlp: !ref <output_mlp>
88 |     weighted_ssl_model: !ref <weighted_ssl_model>
89 | 
90 | model: !new:torch.nn.ModuleList
91 |     - [!ref <output_mlp>]
92 | 
93 | log_softmax: !new:speechbrain.nnet.activations.Softmax
94 |     apply_log: True
95 | 
96 | compute_cost: !name:speechbrain.nnet.losses.nll_loss
97 | 
98 | error_stats: !name:speechbrain.utils.metric_stats.MetricStats
99 |     metric: !name:speechbrain.nnet.losses.classification_error
100 |         reduction: batch
101 | 
102 | model_opt_class: !name:torch.optim.Adam
103 |     lr: !ref <lr>
104 | 
105 | weights_opt_class: !name:torch.optim.Adam
106 |     lr: !ref <lr_weights>
107 | 
108 | lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
109 |     initial_value: !ref <lr>
110 |     improvement_threshold: 0.0025
111 |     annealing_factor: 0.9
112 |     patient: 0
113 | 
114 | lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler
115 |     initial_value: !ref <lr_weights>
116 |     improvement_threshold: 0.0025
117 |     annealing_factor: 0.9
118 | 
119 | 
120 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
121 |     checkpoints_dir: !ref <save_folder>
122 |     recoverables:
123 |         model: !ref <model>
124 |         ssl_model: !ref <weighted_ssl_model>
125 |         scheduler_model: !ref <lr_annealing_model>
126 |         scheduler_encoder: !ref <lr_annealing_weights>
127 |         counter: !ref <epoch_counter>
128 | 
--------------------------------------------------------------------------------
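
Note on the hparams files above: every `!ref <name>` is resolved by hyperpyyaml against the key `name` defined earlier in the same file, and `!PLACEHOLDER` values such as `data_folder` must be supplied as command-line overrides. The following is a minimal sketch of how any of these files is loaded; it mirrors the pattern already used in train_subwording.py above, and the hparams file name and the override value are illustrative only:

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Hypothetical invocation:
# python recipe.py hparams/train_agem.yaml --data_folder=/path/to/CommonVoice
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

with open(hparams_file) as fin:
    # Overrides fill the !PLACEHOLDER entries (e.g., data_folder) before the
    # !ref references are resolved and the !new:/!name: objects are built.
    hparams = load_hyperpyyaml(fin, overrides)

print(hparams["output_folder"])  # e.g., results/wavlm-large/A-GEM/0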