├── seq2seq.png ├── .idea ├── vcs.xml ├── misc.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── seq2seq-speech.iml ├── git_toolbox_prj.xml ├── sshConfigs.xml ├── deployment.xml └── webServers.xml ├── models ├── __init__.py ├── configuration_speech_encoder_decoder.py ├── configuration_bart.py └── configuration_wav2vec2.py ├── scripts ├── ctc_ngram │ ├── run_ctc_ngram_ami.sh │ ├── run_ctc_ngram_tedlium.sh │ ├── run_ctc_ngram_kensho.sh │ ├── run_ctc_ngram_gs.sh │ ├── run_ctc_ngram_earnings22.sh │ ├── run_ctc_ngram_voxpopuli.sh │ ├── run_ctc_ngram_cv9.sh │ ├── run_ctc_ngram_librispeech.sh │ └── run_ctc_ngram_dummy.sh ├── whisper │ ├── run_tedlium.sh │ ├── run_gigaspeech.sh │ ├── run_spgispeech.sh │ ├── run_earnings22.sh │ ├── run_whisper_voxpopuli.sh │ ├── run_ami.sh │ ├── run_chime.sh │ ├── run_switchboard.sh │ ├── run_common_voice_9.sh │ └── run_librispeech.sh ├── ctc │ ├── run_gigaspeech.sh │ ├── run_spgispeech.sh │ ├── run_voxpopuli.sh │ ├── run_common_voice_9.sh │ ├── run_tedlium.sh │ ├── run_ami.sh │ ├── run_switchboard.sh │ ├── run_librispeech.sh │ └── run_earnings22.sh └── seq2seq │ ├── run_ami.sh │ ├── run_gigaspeech.sh │ ├── run_spgispeech.sh │ ├── run_common_voice_9.sh │ ├── run_tedlium.sh │ ├── run_switchboard.sh │ ├── run_voxpopuli.sh │ ├── run_librispeech.sh │ └── run_earnings22.sh ├── run_ctc_dummy.sh ├── run_whisper_dummy.sh ├── run_seq2seq_dummy.sh ├── tests ├── check_flax_ctc.py ├── check_flax_ctc_cv9.py ├── check_save_feature_encoder_output.ipynb └── .ipynb_checkpoints │ └── check_save_feature_encoder_output-checkpoint.ipynb ├── .gitignore ├── README.md ├── get_ctc_tokenizer.py └── get_ctc_ngram.py /seq2seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanchit-gandhi/seq2seq-speech/HEAD/seq2seq.png -------------------------------------------------------------------------------- /.idea/vcs.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/seq2seq-speech.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.configuration_bart import BartConfig 2 | from models.configuration_wav2vec2 import Wav2Vec2Config 3 | from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig 4 | from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule 5 | from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule 6 | from 
models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel 7 | -------------------------------------------------------------------------------- /.idea/git_toolbox_prj.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 9 | 14 | 15 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_ami.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ami-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/ami" \ 5 | --decoder_name="/home/patrick/ngrams/ami" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="speech-seq2seq/ami" \ 8 | --dataset_config_name="ihm" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column="text" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="/home/patrick/ngrams/ami/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --use_auth_token 18 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_tedlium.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-tedlium-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/tedlium" \ 5 | --decoder_name="/home/patrick/ngrams/tedlium" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="LIUM/tedlium" \ 8 | --dataset_config_name="release3" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column="text" \ 12 | --preprocessing_num_workers="1" \ 13 | 
--output_dir="/home/patrick/ngrams/tedlium/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --use_auth_token 18 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_kensho.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-spgispeech-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/spgispeech" \ 5 | --decoder_name="/home/patrick/ngrams/spgispeech" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="kensho/spgispeech" \ 8 | --dataset_config_name="L" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column_name="transcript" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="/home/patrick/ngrams/spgispeech/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --use_auth_token \ 18 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_gs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-gs-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/gigaspeech" \ 5 | --decoder_name="/home/patrick/ngrams/gigaspeech" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="speechcolab/gigaspeech" \ 8 | --dataset_config_name="l" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column_name="text" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="/home/patrick/ngrams/gigaspeech/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --use_auth_token \ 18 
| --do_lower_case 19 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_earnings22.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-earnings22-cased-hidden-activation-featproj-dropout-0.2" \ 4 | --tokenizer_name="/home/patrick/ngrams/earnings22_robust_split" \ 5 | --decoder_name="/home/patrick/ngrams/earnings22_robust_split" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="sanchit-gandhi/earnings22_robust_split" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column="sentence" \ 11 | --preprocessing_num_workers="1" \ 12 | --output_dir="/home/patrick/ngrams/earnings22_robust_split/evaluation" \ 13 | --do_eval \ 14 | --do_predict \ 15 | --overwrite_output_dir \ 16 | --use_auth_token 17 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_voxpopuli.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/voxpopuli" \ 5 | --decoder_name="/home/patrick/ngrams/voxpopuli" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="polinaeterna/voxpopuli" \ 8 | --dataset_config_name="en" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column_name="normalized_text" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="/home/patrick/ngrams/voxpopuli/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --max_label_length=2056 \ 18 | --use_auth_token \ 19 | 
--do_lower_case 20 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_cv9.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-cv9-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/common_voice_9_0" \ 5 | --decoder_name="/home/patrick/ngrams/common_voice_9_0" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="mozilla-foundation/common_voice_9_0" \ 8 | --dataset_config_name="en" \ 9 | --eval_split_name="validation" \ 10 | --test_split_name="test" \ 11 | --text_column_name="sentence" \ 12 | --preprocessing_num_workers="1" \ 13 | --max_eval_duration_in_seconds="20.0" \ 14 | --output_dir="/home/patrick/ngrams/common_voice_9_0/evaluation" \ 15 | --do_eval \ 16 | --do_predict \ 17 | --overwrite_output_dir \ 18 | --use_auth_token 19 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-black-box" \ 4 | --tokenizer_name="/home/patrick/ngrams/librispeech_asr" \ 5 | --decoder_name="/home/patrick/ngrams/librispeech_asr" \ 6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \ 7 | --dataset_name="librispeech_asr" \ 8 | --dataset_config_name="all" \ 9 | --eval_split_name="validation.clean" \ 10 | --test_split_name="validation.other+test.clean+test.other" \ 11 | --text_column="text" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="/home/patrick/ngrams/librispeech_asr/evaluation" \ 14 | --do_eval \ 15 | --do_predict \ 16 | --overwrite_output_dir \ 17 | --use_auth_token \ 18 | 
--do_lower_case 19 | -------------------------------------------------------------------------------- /scripts/ctc_ngram/run_ctc_ngram_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc_ngram.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline" \ 4 | --tokenizer_name="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline" \ 5 | --decoder_name="patrickvonplaten/wav2vec2-base-100h-with-lm" \ 6 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 7 | --dataset_name="librispeech_asr" \ 8 | --dataset_config_name="all" \ 9 | --eval_split_name="validation.clean" \ 10 | --test_split_name="validation.other+test.clean+test.other" \ 11 | --text_column="text" \ 12 | --preprocessing_num_workers="1" \ 13 | --output_dir="./ngram_output_dir" \ 14 | --max_steps="50000" \ 15 | --eval_steps="10000" \ 16 | --save_steps="10000" \ 17 | --wandb_project="librispeech_960h" \ 18 | --wandb_name="flax-wav2vec2-ctc-ls-960h-with-lm-baseline" \ 19 | --do_eval \ 20 | --do_predict \ 21 | --overwrite_output_dir \ 22 | --use_auth_token 23 | -------------------------------------------------------------------------------- /run_ctc_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python ./run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 5 | --dataset_name="librispeech_asr" \ 6 | --dataset_config_name="clean" \ 7 | --train_split_name="train.100" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="text" \ 11 | --output_dir="./" \ 12 | --wandb_project="librispeech_asr" \ 13 | --wandb_name="flax-wav2vec2-ctc-ls-100h" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | 
--max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="500" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_train \ 23 | --do_eval \ 24 | --do_predict \ 25 | --overwrite_output_dir \ 26 | --gradient_checkpointing \ 27 | --freeze_feature_encoder \ 28 | --push_to_hub \ 29 | --use_auth_token 30 | -------------------------------------------------------------------------------- /scripts/whisper/run_tedlium.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="LIUM/tedlium" \ 5 | --dataset_config_name="release3" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --max_steps="2500" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-tedlium" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="500" \ 23 | --save_strategy="steps" \ 24 | --save_steps="500" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="True" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/whisper/run_gigaspeech.sh: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="speechcolab/gigaspeech" \ 5 | --dataset_config_name="m" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --max_steps="5000" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-gigaspeech-5k" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="1000" \ 23 | --save_strategy="steps" \ 24 | --save_steps="1000" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="True" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/whisper/run_spgispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=2 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="kensho/spgispeech" \ 5 | --dataset_config_name="M" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="transcript" \ 10 | --max_steps="5000" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-spgispeech-5k" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | 
--learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="1000" \ 23 | --save_strategy="steps" \ 24 | --save_steps="1000" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="False" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/whisper/run_earnings22.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=4 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="sanchit-gandhi/earnings22_split" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="sentence" \ 10 | --max_steps="2500" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-earnings22" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="500" \ 23 | --save_strategy="steps" \ 24 | --save_steps="500" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="False" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | 
--predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/whisper/run_whisper_voxpopuli.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="polinaeterna/voxpopuli" \ 5 | --dataset_config_name="en" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="normalized_text" \ 10 | --max_steps="5000" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-voxpopuli-5k" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="500" \ 23 | --save_strategy="steps" \ 24 | --save_steps="500" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="True" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/ctc/run_gigaspeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-gs-black-box-tokenizer" \ 5 | --dataset_name="speechcolab/gigaspeech" \ 6 | --dataset_config_name="l" \ 7 | --train_split_name="train" \ 
8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="text" \ 11 | --output_dir="./flax-wav2vec2-ctc-gs-black-box" \ 12 | --wandb_project="gigaspeech" \ 13 | --wandb_name="flax-wav2vec2-ctc-gs-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_train \ 23 | --do_eval \ 24 | --do_predict \ 25 | --overwrite_output_dir \ 26 | --gradient_checkpointing \ 27 | --freeze_feature_encoder \ 28 | --push_to_hub \ 29 | --use_auth_token -------------------------------------------------------------------------------- /scripts/whisper/run_ami.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=3 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="speech-seq2seq/ami" \ 5 | --dataset_config_name="ihm" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --max_steps="2500" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-ami-dropout-0.1" \ 13 | --dropout_rate="0.1" \ 14 | --wandb_project="whisper" \ 15 | --per_device_train_batch_size="64" \ 16 | --per_device_eval_batch_size="16" \ 17 | --logging_steps="25" \ 18 | --learning_rate="1e-4" \ 19 | --warmup_steps="500" \ 20 | --report_to="wandb" \ 21 | --preprocessing_num_workers="16" \ 22 | --evaluation_strategy="steps" \ 23 | --eval_steps="500" \ 24 | --save_strategy="steps" \ 25 | --save_steps="500" \ 26 | --generation_max_length="224" \ 27 | --length_column_name="input_lengths" \ 28 | --do_lower_case="False" \ 29 | --push_to_hub="False" \ 30 | --gradient_checkpointing \ 31 | --group_by_length \ 32 | --freeze_encoder \ 33 | 
--fp16 \ 34 | --overwrite_output_dir \ 35 | --do_train \ 36 | --do_eval \ 37 | --do_predict \ 38 | --predict_with_generate \ 39 | --use_auth_token 40 | -------------------------------------------------------------------------------- /run_whisper_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES="" python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="tiny.en" \ 4 | --dataset_name="patrickvonplaten/librispeech_asr_dummy" \ 5 | --num_train_epochs="2" \ 6 | --evaluation_strategy="epoch" \ 7 | --dataset_config_name="clean" \ 8 | --dataset_cache_dir="/home/patrick_huggingface_co/hey" \ 9 | --train_split_name="validation[:32]" \ 10 | --eval_split_name="validation" \ 11 | --test_split_name="validation[:90%]" \ 12 | --text_column_name="text" \ 13 | --output_dir="./output_dir" \ 14 | --run_name="whisper-ls-dummy" \ 15 | --wandb_project="whisper-dummy" \ 16 | --per_device_train_batch_size="8" \ 17 | --per_device_eval_batch_size="4" \ 18 | --logging_steps="1" \ 19 | --learning_rate="1e-4" \ 20 | --warmup_steps="3" \ 21 | --report_to="wandb" \ 22 | --push_to_hub="False" \ 23 | --preprocessing_num_workers="4" \ 24 | --evaluation_strategy="epoch" \ 25 | --max_eval_samples="8" \ 26 | --max_predict_samples="8" \ 27 | --length_column_name="input_lengths" \ 28 | --save_strategy="no" \ 29 | --group_by_length \ 30 | --overwrite_output_dir \ 31 | --freeze_encoder \ 32 | --do_eval \ 33 | --do_predict \ 34 | --do_train \ 35 | --predict_with_generate \ 36 | --generation_max_length=224 \ 37 | -------------------------------------------------------------------------------- /scripts/whisper/run_chime.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=5 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="speech-seq2seq/chime4-raw" \ 5 | 
--dataset_config_name="1-channel" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --max_steps="2500" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-chime4-dropout-0.1" \ 13 | --dropout_rate="0.1" \ 14 | --wandb_project="whisper" \ 15 | --per_device_train_batch_size="64" \ 16 | --per_device_eval_batch_size="16" \ 17 | --logging_steps="25" \ 18 | --learning_rate="1e-4" \ 19 | --warmup_steps="500" \ 20 | --report_to="wandb" \ 21 | --preprocessing_num_workers="16" \ 22 | --evaluation_strategy="steps" \ 23 | --eval_steps="500" \ 24 | --save_strategy="steps" \ 25 | --save_steps="500" \ 26 | --generation_max_length="224" \ 27 | --length_column_name="input_lengths" \ 28 | --do_lower_case="False" \ 29 | --push_to_hub="False" \ 30 | --gradient_checkpointing \ 31 | --group_by_length \ 32 | --freeze_encoder \ 33 | --fp16 \ 34 | --overwrite_output_dir \ 35 | --do_train \ 36 | --do_eval \ 37 | --do_predict \ 38 | --predict_with_generate \ 39 | --use_auth_token 40 | -------------------------------------------------------------------------------- /scripts/whisper/run_switchboard.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="ldc/switchboard" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train.switchboard" \ 7 | --eval_split_name="validation.switchboard" \ 8 | --test_split_name="test.switchboard+test.callhome" \ 9 | --text_column_name="text" \ 10 | --max_steps="5000" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-switchboard" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | 
--preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="1000" \ 23 | --save_strategy="steps" \ 24 | --save_steps="1000" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="True" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/whisper/run_common_voice_9.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="mozilla-foundation/common_voice_9_0" \ 5 | --dataset_config_name="en" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="sentence" \ 10 | --max_steps="5000" \ 11 | --output_dir="./" \ 12 | --run_name="whisper-cv9" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="1000" \ 23 | --save_strategy="steps" \ 24 | --save_steps="1000" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="False" \ 28 | --push_to_hub="False" \ 29 | --max_eval_duration_in_seconds="20" \ 30 | --gradient_checkpointing \ 31 | --group_by_length \ 32 | --freeze_encoder \ 33 | --fp16 \ 34 | --overwrite_output_dir \ 35 | --do_train \ 36 | --do_eval \ 37 | --do_predict \ 38 | --predict_with_generate \ 39 | --use_auth_token 40 | 
-------------------------------------------------------------------------------- /scripts/ctc/run_spgispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-spgispeech-tokenizer" \ 5 | --dataset_name="kensho/spgispeech" \ 6 | --dataset_config_name="L" \ 7 | --train_split_name="train" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="transcript" \ 11 | --output_dir="./flax-wav2vec2-ctc-spgispeech-baseline" \ 12 | --wandb_project="spgispeech" \ 13 | --wandb_name="flax-wav2vec2-ctc-spgispeech-baseline" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_lower_case="False" \ 23 | --do_train \ 24 | --do_eval \ 25 | --do_predict \ 26 | --overwrite_output_dir \ 27 | --gradient_checkpointing \ 28 | --freeze_feature_encoder \ 29 | --push_to_hub \ 30 | --use_auth_token -------------------------------------------------------------------------------- /scripts/whisper/run_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \ 3 | --model_name_or_path="medium.en" \ 4 | --dataset_name="librispeech_asr" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \ 7 | --eval_split_name="validation.clean" \ 8 | --test_split_name="validation.other+test.clean+test.other" \ 9 | --max_steps="5000" \ 10 | --text_column_name="text" \ 11 | --output_dir="./" \ 12 | 
--run_name="whisper-ls-960h-5k" \ 13 | --wandb_project="whisper" \ 14 | --per_device_train_batch_size="64" \ 15 | --per_device_eval_batch_size="16" \ 16 | --logging_steps="25" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --report_to="wandb" \ 20 | --preprocessing_num_workers="16" \ 21 | --evaluation_strategy="steps" \ 22 | --eval_steps="1000" \ 23 | --save_strategy="steps" \ 24 | --save_steps="1000" \ 25 | --generation_max_length="224" \ 26 | --length_column_name="input_lengths" \ 27 | --do_lower_case="True" \ 28 | --push_to_hub="False" \ 29 | --gradient_checkpointing \ 30 | --group_by_length \ 31 | --freeze_encoder \ 32 | --fp16 \ 33 | --overwrite_output_dir \ 34 | --do_train \ 35 | --do_eval \ 36 | --do_predict \ 37 | --predict_with_generate \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/ctc/run_voxpopuli.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \ 4 | --tokenizer_name="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \ 5 | --dataset_name="polinaeterna/voxpopuli" \ 6 | --dataset_config_name="en" \ 7 | --train_split_name="train" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="normalized_text" \ 11 | --output_dir="./flax-wav2vec2-ctc-voxpopuli-black-box" \ 12 | --wandb_project="voxpopuli" \ 13 | --wandb_name="flax-wav2vec2-ctc-voxpopuli-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --per_device_eval_batch_size="1" \ 23 | --do_train \ 24 | --do_eval \ 25 | --do_predict \ 26 | 
--overwrite_output_dir \ 27 | --gradient_checkpointing \ 28 | --freeze_feature_encoder \ 29 | --push_to_hub \ 30 | --use_auth_token -------------------------------------------------------------------------------- /scripts/ctc/run_common_voice_9.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-cv9-black-box-tokenizer" \ 5 | --dataset_name="mozilla-foundation/common_voice_9_0" \ 6 | --dataset_config_name="en" \ 7 | --train_split_name="train" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="sentence" \ 11 | --output_dir="./flax-wav2vec2-ctc-cv9-black-box" \ 12 | --wandb_project="common_voice_9_0" \ 13 | --wandb_name="flax-wav2vec2-ctc-cv9-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_lower_case="False" \ 23 | --max_eval_duration_in_seconds="20" \ 24 | --do_train \ 25 | --do_eval \ 26 | --do_predict \ 27 | --overwrite_output_dir \ 28 | --gradient_checkpointing \ 29 | --freeze_feature_encoder \ 30 | --push_to_hub \ 31 | --use_auth_token -------------------------------------------------------------------------------- /scripts/ctc/run_tedlium.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-tedlium-black-box-tokenizer" \ 5 | --dataset_name="LIUM/tedlium" \ 6 | --dataset_config_name="release3" \ 7 | 
--train_split_name="train" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="text" \ 11 | --output_dir="./flax-wav2vec2-ctc-tedlium-black-box" \ 12 | --wandb_project="tedlium" \ 13 | --wandb_name="flax-wav2vec2-ctc-tedlium-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --hidden_dropout="0.2" \ 23 | --activation_dropout="0.2" \ 24 | --feat_proj_dropout="0.2" \ 25 | --do_train \ 26 | --do_eval \ 27 | --do_predict \ 28 | --overwrite_output_dir \ 29 | --gradient_checkpointing \ 30 | --freeze_feature_encoder \ 31 | --push_to_hub \ 32 | --use_auth_token -------------------------------------------------------------------------------- /scripts/ctc/run_ami.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="patrickvonplaten/wav2vec2_ctc_ami_tokenizer" \ 5 | --dataset_name="speech-seq2seq/ami" \ 6 | --dataset_config_name="ihm" \ 7 | --train_split_name="train" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test" \ 10 | --text_column_name="text" \ 11 | --output_dir="./flax-wav2vec2-ctc-ami-black-box" \ 12 | --wandb_project="ami" \ 13 | --wandb_name="flax-wav2vec2-ctc-ami-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --hidden_dropout="0.2" \ 23 | --activation_dropout="0.2" \ 24 | --feat_proj_dropout="0.2" \ 25 | 
--do_lower_case="False" \ 26 | --do_train \ 27 | --do_eval \ 28 | --do_predict \ 29 | --overwrite_output_dir \ 30 | --gradient_checkpointing \ 31 | --freeze_feature_encoder \ 32 | --push_to_hub \ 33 | --use_auth_token -------------------------------------------------------------------------------- /scripts/ctc/run_switchboard.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-switchboard-black-box-tokenizer" \ 5 | --dataset_name="ldc/switchboard" \ 6 | --dataset_config_name="all" \ 7 | --train_split_name="train.fisher+train.switchboard" \ 8 | --eval_split_name="validation" \ 9 | --test_split_name="test.switchboard+test.callhome" \ 10 | --text_column_name="text" \ 11 | --output_dir="./flax-wav2vec2-ctc-switchboard-fisher-black-box" \ 12 | --wandb_project="switchboard" \ 13 | --wandb_name="flax-wav2vec2-ctc-switchboard-fisher-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_lower_case="False" \ 23 | --torchaudio_resampler="True" \ 24 | --do_train \ 25 | --do_eval \ 26 | --do_predict \ 27 | --overwrite_output_dir \ 28 | --gradient_checkpointing \ 29 | --freeze_feature_encoder \ 30 | --push_to_hub \ 31 | --use_auth_token -------------------------------------------------------------------------------- /scripts/seq2seq/run_ami.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="speech-seq2seq/ami" \ 4 |
--model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="ihm" \ 6 | --id_column_name="audio_id" \ 7 | --output_dir="./flax-wav2vec2-2-bart-large-ami-black-box" \ 8 | --wandb_project="ami" \ 9 | --wandb_name="flax-wav2vec2-2-bart-large-ami-black-box" \ 10 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 11 | --per_device_train_batch_size="8" \ 12 | --per_device_eval_batch_size="4" \ 13 | --learning_rate="1e-4" \ 14 | --warmup_steps="500" \ 15 | --logging_steps="25" \ 16 | --max_steps="50000" \ 17 | --eval_steps="10000" \ 18 | --save_steps="10000" \ 19 | --generation_max_length="200" \ 20 | --generation_num_beams="5" \ 21 | --generation_length_penalty="1.2" \ 22 | --hidden_dropout="0.2" \ 23 | --activation_dropout="0.2" \ 24 | --feat_proj_dropout="0.2" \ 25 | --do_lower_case="False" \ 26 | --overwrite_output_dir \ 27 | --gradient_checkpointing \ 28 | --freeze_feature_encoder \ 29 | --predict_with_generate \ 30 | --do_eval \ 31 | --do_train \ 32 | --do_predict \ 33 | --push_to_hub \ 34 | --use_auth_token -------------------------------------------------------------------------------- /scripts/seq2seq/run_gigaspeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="speechcolab/gigaspeech" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="l" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --id_column_name="segment_id" \ 11 | --output_dir="./" \ 12 | --wandb_project="gigaspeech" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-gs-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | 
--warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --overwrite_output_dir \ 27 | --gradient_checkpointing \ 28 | --freeze_feature_encoder \ 29 | --predict_with_generate \ 30 | --do_lower_case \ 31 | --do_eval \ 32 | --do_train \ 33 | --do_predict \ 34 | --push_to_hub \ 35 | --use_auth_token 36 | -------------------------------------------------------------------------------- /scripts/ctc/run_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-ls-black-box-tokenizer" \ 5 | --dataset_name="librispeech_asr" \ 6 | --dataset_config_name="all" \ 7 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \ 8 | --eval_split_name="validation.clean" \ 9 | --test_split_name="validation.other+test.clean+test.other" \ 10 | --text_column_name="text" \ 11 | --output_dir="./flax-wav2vec2-ctc-ls-960h-black-box" \ 12 | --wandb_project="librispeech_960h" \ 13 | --wandb_name="flax-wav2vec2-ctc-ls-960h-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --hidden_dropout="0.2" \ 23 | --activation_dropout="0.2" \ 24 | --feat_proj_dropout="0.2" \ 25 | --do_train \ 26 | --do_eval \ 27 | --do_predict \ 28 | --overwrite_output_dir \ 29 | --gradient_checkpointing \ 30 | --freeze_feature_encoder \ 31 | --push_to_hub \ 32 | --use_auth_token 
-------------------------------------------------------------------------------- /scripts/seq2seq/run_spgispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="kensho/spgispeech" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="L" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="transcript" \ 10 | --id_column_name="wav_filename" \ 11 | --output_dir="./" \ 12 | --wandb_project="spgispeech" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-spgispeech-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --do_lower_case="False" \ 27 | --overwrite_output_dir \ 28 | --gradient_checkpointing \ 29 | --freeze_feature_encoder \ 30 | --predict_with_generate \ 31 | --do_eval \ 32 | --do_train \ 33 | --do_predict \ 34 | --push_to_hub \ 35 | --use_auth_token -------------------------------------------------------------------------------- /scripts/ctc/run_earnings22.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_ctc.py \ 3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \ 4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-earnings22-black-box-tokenizer" \ 5 | --dataset_name="sanchit-gandhi/earnings22" \ 6 | --dataset_config_name="all" \ 7 | --train_split_name="train" \ 8 | --eval_split_name="validation" 
\ 9 | --test_split_name="test" \ 10 | --text_column_name="sentence" \ 11 | --output_dir="./flax-wav2vec2-ctc-earnings22-black-box" \ 12 | --wandb_project="earnings22" \ 13 | --wandb_name="flax-wav2vec2-ctc-earnings22-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --max_steps="50000" \ 16 | --save_steps="10000" \ 17 | --eval_steps="10000" \ 18 | --learning_rate="3e-4" \ 19 | --logging_steps="25" \ 20 | --warmup_steps="5000" \ 21 | --preprocessing_num_workers="1" \ 22 | --do_lower_case="False" \ 23 | --hidden_dropout="0.2" \ 24 | --activation_dropout="0.2" \ 25 | --feat_proj_dropout="0.2" \ 26 | --ignore_verifications="False" \ 27 | --do_train \ 28 | --do_eval \ 29 | --do_predict \ 30 | --overwrite_output_dir \ 31 | --gradient_checkpointing \ 32 | --freeze_feature_encoder \ 33 | --push_to_hub \ 34 | --use_auth_token -------------------------------------------------------------------------------- /run_seq2seq_dummy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="librispeech_asr" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="clean" \ 6 | --train_split_name="train.100" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --id_column_name="id" \ 11 | --output_dir="./" \ 12 | --wandb_project="librispeech_960h" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-baseline" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="4" \ 17 | --logging_steps="25" \ 18 | --max_steps="50000" \ 19 | --eval_steps="10000" \ 20 | --save_steps="10000" \ 21 | --generation_max_length="40" \ 22 | --generation_num_beams="1" \ 23 | --generation_length_penalty="1.2" \ 24 | 
--final_generation_max_length="200" \ 25 | --final_generation_num_beams="5" \ 26 | --learning_rate="1e-4" \ 27 | --warmup_steps="500" \ 28 | --overwrite_output_dir \ 29 | --gradient_checkpointing \ 30 | --freeze_feature_encoder \ 31 | --predict_with_generate \ 32 | --do_lower_case \ 33 | --do_eval \ 34 | --do_train \ 35 | --do_predict \ 36 | --push_to_hub \ 37 | --use_auth_token 38 | -------------------------------------------------------------------------------- /scripts/seq2seq/run_common_voice_9.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="mozilla-foundation/common_voice_9_0" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="en" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="sentence" \ 10 | --id_column_name="client_id" \ 11 | --output_dir="./" \ 12 | --wandb_project="common_voice_9_0" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-cv9-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --do_lower_case="False" \ 27 | --max_eval_duration_in_seconds="20" \ 28 | --overwrite_output_dir \ 29 | --gradient_checkpointing \ 30 | --freeze_feature_encoder \ 31 | --predict_with_generate \ 32 | --do_eval \ 33 | --do_train \ 34 | --do_predict \ 35 | --push_to_hub \ 36 | --use_auth_token 37 | -------------------------------------------------------------------------------- /scripts/seq2seq/run_tedlium.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="LIUM/tedlium" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="release3" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="text" \ 10 | --id_column_name="id" \ 11 | --output_dir="./" \ 12 | --wandb_project="tedlium" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-tedlium-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --hidden_dropout="0.2" \ 27 | --activation_dropout="0.2" \ 28 | --feat_proj_dropout="0.2" \ 29 | --overwrite_output_dir \ 30 | --gradient_checkpointing \ 31 | --freeze_feature_encoder \ 32 | --predict_with_generate \ 33 | --do_lower_case \ 34 | --do_eval \ 35 | --do_train \ 36 | --do_predict \ 37 | --push_to_hub \ 38 | --use_auth_token -------------------------------------------------------------------------------- /scripts/seq2seq/run_switchboard.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="ldc/switchboard" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train.fisher+train.switchboard" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test.switchboard+test.callhome" \ 9 | --text_column_name="text" \ 10 | 
--id_column_name="id" \ 11 | --output_dir="./flax-wav2vec2-2-bart-large-switchboard-fisher-black-box" \ 12 | --wandb_project="switchboard" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-switchboard-fisher-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --torchaudio_resampler="True" \ 27 | --overwrite_output_dir \ 28 | --gradient_checkpointing \ 29 | --freeze_feature_encoder \ 30 | --predict_with_generate \ 31 | --do_eval \ 32 | --do_train \ 33 | --do_predict \ 34 | --push_to_hub \ 35 | --use_auth_token -------------------------------------------------------------------------------- /scripts/seq2seq/run_voxpopuli.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="google/xtreme_s" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="voxpopuli.en" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="sentence" \ 10 | --id_column_name="id" \ 11 | --output_dir="./flax-wav2vec2-2-bart-large-voxpopuli-black-box" \ 12 | --wandb_project="voxpopuli" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-voxpopuli-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="1" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | 
--save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --hidden_dropout="0.2" \ 27 | --activation_dropout="0.2" \ 28 | --feat_proj_dropout="0.2" \ 29 | --overwrite_output_dir \ 30 | --gradient_checkpointing \ 31 | --freeze_feature_encoder \ 32 | --predict_with_generate \ 33 | --do_lower_case \ 34 | --do_eval \ 35 | --do_train \ 36 | --do_predict \ 37 | --push_to_hub \ 38 | --use_auth_token -------------------------------------------------------------------------------- /scripts/seq2seq/run_librispeech.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="librispeech_asr" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \ 7 | --eval_split_name="validation.clean" \ 8 | --test_split_name="validation.other+test.clean+test.other" \ 9 | --text_column_name="text" \ 10 | --id_column_name="id" \ 11 | --output_dir="./" \ 12 | --wandb_project="librispeech_960h" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="2" \ 17 | --learning_rate="1e-4" \ 18 | --warmup_steps="500" \ 19 | --logging_steps="25" \ 20 | --max_steps="50000" \ 21 | --eval_steps="10000" \ 22 | --save_steps="10000" \ 23 | --generation_max_length="200" \ 24 | --generation_num_beams="5" \ 25 | --generation_length_penalty="1.2" \ 26 | --hidden_dropout="0.2" \ 27 | --activation_dropout="0.2" \ 28 | --feat_proj_dropout="0.2" \ 29 | --overwrite_output_dir \ 30 | --gradient_checkpointing \ 31 | --freeze_feature_encoder \ 32 | --predict_with_generate \ 33 | --do_lower_case \ 34 | --do_eval \ 35 | --do_train \ 
36 | --do_predict \ 37 | --push_to_hub \ 38 | --use_auth_token 39 | -------------------------------------------------------------------------------- /scripts/seq2seq/run_earnings22.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_flax_speech_recognition_seq2seq.py \ 3 | --dataset_name="sanchit-gandhi/earnings22" \ 4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \ 5 | --dataset_config_name="all" \ 6 | --train_split_name="train" \ 7 | --eval_split_name="validation" \ 8 | --test_split_name="test" \ 9 | --text_column_name="sentence" \ 10 | --id_column_name="id" \ 11 | --output_dir="./flax-wav2vec2-2-bart-large-earnings22-black-box" \ 12 | --wandb_project="earnings22" \ 13 | --wandb_name="flax-wav2vec2-2-bart-large-earnings22-black-box" \ 14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \ 15 | --per_device_train_batch_size="8" \ 16 | --per_device_eval_batch_size="4" \ 17 | --logging_steps="25" \ 18 | --max_steps="50000" \ 19 | --eval_steps="10000" \ 20 | --save_steps="10000" \ 21 | --generation_max_length="40" \ 22 | --generation_num_beams="1" \ 23 | --generation_length_penalty="1.2" \ 24 | --final_generation_max_length="200" \ 25 | --final_generation_num_beams="5" \ 26 | --learning_rate="1e-4" \ 27 | --warmup_steps="500" \ 28 | --do_lower_case="False" \ 29 | --hidden_dropout="0.2" \ 30 | --activation_dropout="0.2" \ 31 | --feat_proj_dropout="0.2" \ 32 | --ignore_verifications="False" \ 33 | --overwrite_output_dir \ 34 | --gradient_checkpointing \ 35 | --freeze_feature_encoder \ 36 | --predict_with_generate \ 37 | --do_eval \ 38 | --do_train \ 39 | --do_predict \ 40 | --push_to_hub \ 41 | --use_auth_token -------------------------------------------------------------------------------- /tests/check_flax_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tempfile 3 | import 
numpy as np 4 | import torch 5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 6 | from models.modeling_flax_wav2vec2 import FlaxWav2Vec2ForCTC 7 | import jax.numpy as jnp 8 | from datasets import load_dataset 9 | from run_flax_speech_recognition_ctc import ctc_loss 10 | 11 | model_id = "facebook/wav2vec2-large-lv60" 12 | # model_id = "hf-internal-testing/tiny-random-wav2vec2" 13 | 14 | processor = Wav2Vec2Processor.from_pretrained(model_id, return_attention_mask=True) 15 | # in PyTorch we always use 'mean' by default. 16 | # See: https://github.com/huggingface/transformers/blob/93b802c43e70f41edfb166ad6ead79de95a26c32/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L126 17 | model_pt = Wav2Vec2ForCTC.from_pretrained(model_id, ctc_loss_reduction="mean") 18 | 19 | with tempfile.TemporaryDirectory() as temp_folder: 20 | model_pt.save_pretrained(temp_folder) 21 | model_fx = FlaxWav2Vec2ForCTC.from_pretrained(temp_folder, from_pt=True) 22 | 23 | # load dummy dataset and read soundfiles 24 | ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") 25 | 26 | # inputs pt 27 | samples = [d["array"] for d in ds[:4]["audio"]] 28 | inputs_pt = processor(samples, return_tensors="pt", padding="longest", sampling_rate=16_000) 29 | 30 | # inputs fx 31 | inputs_fx = processor(samples, return_tensors="np", padding="longest", sampling_rate=16_000) 32 | 33 | # labels 34 | transcription = ds[:4]["text"] 35 | labels_pt = processor.tokenizer(transcription, return_tensors="pt", padding="longest") 36 | labels_ids_pt = labels_pt["input_ids"].masked_fill(labels_pt.attention_mask.ne(1), -100) 37 | 38 | labels_fx = processor.tokenizer(transcription, return_tensors="np", padding="longest") 39 | labels_ids_fx = jnp.where(labels_fx.attention_mask == 0, -100, labels_fx.input_ids) 40 | 41 | # pytorch 42 | with torch.no_grad(): 43 | outputs = model_pt(**inputs_pt, labels=labels_ids_pt) 44 | 45 | logits_pt = outputs.logits 46 | loss_pt = 
outputs.loss 47 | 48 | # flax 49 | logits_fx = model_fx(**inputs_fx).logits 50 | logits_attention_mask = model_fx._get_feature_vector_attention_mask(logits_fx.shape[1], inputs_fx.attention_mask) 51 | 52 | # Check logits the same 53 | logits_diff = np.abs((logits_pt.detach().numpy() - np.asarray(logits_fx))).max() 54 | assert logits_diff < 1e-3, "Logits don't match" 55 | 56 | # flax loss 57 | blank_id = model_fx.config.pad_token_id 58 | loss_fx = ctc_loss(logits_fx, logits_attention_mask, labels_ids_fx, blank_id) 59 | 60 | # Check loss is the same 61 | loss_diff = np.asarray(loss_fx) - loss_pt.numpy() 62 | assert loss_diff < 1e-3, "Loss doesn't match" 63 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 11 | 12 | 14 | 15 | 17 | 18 | 20 | 21 | 23 | 24 | 26 | 27 | 29 | 30 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 85 | -------------------------------------------------------------------------------- /tests/check_flax_ctc_cv9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import tempfile 3 | import numpy as np 4 | import torch 5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoFeatureExtractor 6 | from models.modeling_flax_wav2vec2 import 
FlaxWav2Vec2ForCTC 7 | import jax.numpy as jnp 8 | from datasets import load_dataset 9 | import datasets 10 | from run_flax_speech_recognition_ctc import ctc_loss 11 | 12 | model_id = "facebook/wav2vec2-large-lv60" 13 | # model_id = "hf-internal-testing/tiny-random-wav2vec2" 14 | tokenizer_id = "patrickvonplaten/wav2vec2_ctc_cv9_tokenizer" 15 | 16 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) 17 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_id if tokenizer_id else model_id, return_attention_mask=True) 18 | 19 | # in PyTorch we always use 'mean' by default. 20 | # See: https://github.com/huggingface/transformers/blob/93b802c43e70f41edfb166ad6ead79de95a26c32/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L126 21 | model_pt = Wav2Vec2ForCTC.from_pretrained(model_id, ctc_loss_reduction="mean") 22 | 23 | with tempfile.TemporaryDirectory() as temp_folder: 24 | model_pt.save_pretrained(temp_folder) 25 | model_fx = FlaxWav2Vec2ForCTC.from_pretrained(temp_folder, from_pt=True) 26 | 27 | 28 | # load CV9 dataset and read soundfiles 29 | ds = load_dataset("mozilla-foundation/common_voice_9_0", "en", split="train[:1%]", cache_dir="/home/sanchitgandhi/cache/huggingface/datasets/") 30 | 31 | # resample dataset to 16kHz on the fly 32 | ds = ds.cast_column( 33 | "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) 34 | ) 35 | 36 | # inputs pt 37 | samples = [d["array"] for d in ds[:4]["audio"]] 38 | inputs_pt = feature_extractor(samples, return_tensors="pt", padding="longest", sampling_rate=16_000) 39 | 40 | # inputs fx 41 | inputs_fx = feature_extractor(samples, return_tensors="np", padding="longest", sampling_rate=16_000) 42 | 43 | # labels 44 | transcription = ds[:4]["sentence"] 45 | labels_pt = tokenizer(transcription, return_tensors="pt", padding="longest") 46 | labels_ids_pt = labels_pt["input_ids"].masked_fill(labels_pt.attention_mask.ne(1), -100) 47 | 48 | labels_fx = tokenizer(transcription, 
return_tensors="np", padding="longest") 49 | labels_ids_fx = jnp.where(labels_fx.attention_mask == 0, -100, labels_fx.input_ids) 50 | 51 | # set pt model config to accommodate for CV9 tokenizer vocab size 52 | model_pt.config.vocab_size = tokenizer.vocab_size 53 | 54 | # pytorch 55 | with torch.no_grad(): 56 | outputs = model_pt(**inputs_pt, labels=labels_ids_pt) 57 | 58 | logits_pt = outputs.logits 59 | loss_pt = outputs.loss 60 | 61 | # flax 62 | logits_fx = model_fx(**inputs_fx).logits 63 | logits_attention_mask = model_fx._get_feature_vector_attention_mask(logits_fx.shape[1], inputs_fx.attention_mask) 64 | 65 | # Check logits the same 66 | logits_diff = np.abs((logits_pt.detach().numpy() - np.asarray(logits_fx))).max() 67 | assert logits_diff < 1e-3, "Logits don't match" 68 | 69 | # flax loss 70 | blank_id = model_fx.config.pad_token_id 71 | loss_fx = ctc_loss(logits_fx, logits_attention_mask, labels_ids_fx, blank_id) 72 | 73 | # Check loss is the same 74 | loss_diff = np.asarray(loss_fx) - loss_pt.numpy() 75 | assert loss_diff < 1e-3, "Loss doesn't match" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | 9 | # Created by https://www.toptal.com/developers/gitignore/api/intellij 10 | # Edit at https://www.toptal.com/developers/gitignore?templates=intellij 11 | 12 | ### Intellij ### 13 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 14 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 15 | 16 | # User-specific stuff 17 | .idea/**/workspace.xml 18 | .idea/**/tasks.xml 19 | .idea/**/usage.statistics.xml 20 | .idea/**/dictionaries 21 | .idea/**/shelf 22 | 23 | # AWS User-specific 24 | 
.idea/**/aws.xml 25 | 26 | # Generated files 27 | .idea/**/contentModel.xml 28 | 29 | # Sensitive or high-churn files 30 | .idea/**/dataSources/ 31 | .idea/**/dataSources.ids 32 | .idea/**/dataSources.local.xml 33 | .idea/**/sqlDataSources.xml 34 | .idea/**/dynamic.xml 35 | .idea/**/uiDesigner.xml 36 | .idea/**/dbnavigator.xml 37 | 38 | # Gradle 39 | .idea/**/gradle.xml 40 | .idea/**/libraries 41 | 42 | # Gradle and Maven with auto-import 43 | # When using Gradle or Maven with auto-import, you should exclude module files, 44 | # since they will be recreated, and may cause churn. Uncomment if using 45 | # auto-import. 46 | # .idea/artifacts 47 | # .idea/compiler.xml 48 | # .idea/jarRepositories.xml 49 | # .idea/modules.xml 50 | # .idea/*.iml 51 | # .idea/modules 52 | # *.iml 53 | # *.ipr 54 | 55 | # CMake 56 | cmake-build-*/ 57 | 58 | # Mongo Explorer plugin 59 | .idea/**/mongoSettings.xml 60 | 61 | # File-based project format 62 | *.iws 63 | 64 | # IntelliJ 65 | out/ 66 | 67 | # mpeltonen/sbt-idea plugin 68 | .idea_modules/ 69 | 70 | # JIRA plugin 71 | atlassian-ide-plugin.xml 72 | 73 | # Cursive Clojure plugin 74 | .idea/replstate.xml 75 | 76 | # SonarLint plugin 77 | .idea/sonarlint/ 78 | 79 | # Crashlytics plugin (for Android Studio and IntelliJ) 80 | com_crashlytics_export_strings.xml 81 | crashlytics.properties 82 | crashlytics-build.properties 83 | fabric.properties 84 | 85 | # Editor-based Rest Client 86 | .idea/httpRequests 87 | 88 | # Android studio 3.1+ serialized cache file 89 | .idea/caches/build_file_checksums.ser 90 | 91 | ### Intellij Patch ### 92 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 93 | 94 | # *.iml 95 | # modules.xml 96 | # .idea/misc.xml 97 | # *.ipr 98 | 99 | # Sonarlint plugin 100 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 101 | .idea/**/sonarlint/ 102 | 103 | # SonarQube Plugin 104 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 105 | 
.idea/**/sonarIssues.xml 106 | 107 | # Markdown Navigator plugin 108 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 109 | .idea/**/markdown-navigator.xml 110 | .idea/**/markdown-navigator-enh.xml 111 | .idea/**/markdown-navigator/ 112 | 113 | # Cache file creation bug 114 | # See https://youtrack.jetbrains.com/issue/JBR-2257 115 | .idea/$CACHE_FILE$ 116 | 117 | # CodeStream plugin 118 | # https://plugins.jetbrains.com/plugin/12206-codestream 119 | .idea/codestream.xml 120 | 121 | # End of https://www.toptal.com/developers/gitignore/api/intellij 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq Speech in JAX 2 | A [JAX](https://jax.readthedocs.io/en/latest/)/[Flax](https://flax.readthedocs.io/en/latest/) repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model (e.g. GPT2, Bart) to yield a Speech Sequence-to-Sequence (Seq2Seq) model for automatic speech recognition. 3 | 4 | The script `run_flax_speech_recognition_seq2seq.py` can be used to fine-tune a Speech Seq2Seq model on one of the official speech recognition datasets or a custom dataset. It makes use of the `pmap` JAX operator to provide data parallelism across GPU/TPU devices. 5 | 6 | The modelling files are based very heavily on those from Hugging Face [Transformers 🤗](https://github.com/huggingface/transformers). This is a standalone repository to enable rapid prototyping and involvement with the community. The final modelling files and training script will be merged into Transformers 🤗 to be used with the rest of the open-source library. The final system weights will be made publicly available at [huggingface.co](https://huggingface.co) 🚀 7 | 8 | ![Seq2SeqModel](seq2seq.png) 9 | **Figure 1:** Speech-encoder text-decoder style Seq2Seq model.
10 | 11 | ## Example Usage 12 | To instantiate a _Wav2Vec2-2-Bart_ model with the `FlaxSpeechEncoderDecoderModel` framework, run the following Python script inside the cloned repo: 13 | ```python 14 | from transformers import AutoFeatureExtractor, AutoTokenizer 15 | from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel 16 | import numpy as np 17 | 18 | # checkpoints to leverage 19 | encoder_id = "facebook/wav2vec2-large-lv60" 20 | decoder_id = "facebook/bart-large" 21 | 22 | model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( 23 | encoder_id, decoder_id, encoder_add_adapter=True, decoder_from_pt=True) 24 | 25 | model.config.decoder_start_token_id = model.config.decoder.bos_token_id 26 | model.config.pad_token_id = model.config.decoder.pad_token_id 27 | model.config.eos_token_id = model.config.decoder.eos_token_id 28 | model.config.use_cache = False 29 | model.config.processor_class = "Wav2Vec2Processor" 30 | 31 | # check if generation works 32 | out = model.generate(np.ones((1, 2000))) 33 | 34 | model.save_pretrained("./") 35 | 36 | feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id) 37 | feature_extractor.save_pretrained("./") 38 | tokenizer = AutoTokenizer.from_pretrained(decoder_id) 39 | tokenizer.save_pretrained("./") 40 | ``` 41 | 42 | To train the model on [Librispeech ASR](https://huggingface.co/datasets/librispeech_asr), run the template bash script [`run_seq2seq_dummy.sh`](https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_seq2seq_dummy.sh). 
43 | 44 | ## Flax Whisper Model 45 | 46 | ```bash 47 | #!/usr/bin/env bash 48 | python run_flax_speech_recognition_seq2seq.py \ 49 | --dataset_name="librispeech_asr" \ 50 | --model_name_or_path="openai/whisper-small" \ 51 | --dataset_config_name="clean" \ 52 | --train_split_name="train.100" \ 53 | --eval_split_name="validation" \ 54 | --test_split_name="test" \ 55 | --text_column_name="text" \ 56 | --id_column_name="id" \ 57 | --output_dir="./flax-whisper-ft-librispeech-clean" \ 58 | --wandb_project="librispeech_clean" \ 59 | --wandb_name="flax-whisper-ft-librispeech-clean" \ 60 | --per_device_train_batch_size="8" \ 61 | --per_device_eval_batch_size="2" \ 62 | --learning_rate="1e-4" \ 63 | --warmup_steps="500" \ 64 | --logging_steps="25" \ 65 | --max_steps="50000" \ 66 | --eval_steps="10000" \ 67 | --save_steps="10000" \ 68 | --generation_max_length="200" \ 69 | --generation_num_beams="5" \ 70 | --generation_length_penalty="1.2" \ 71 | --hidden_dropout="0.2" \ 72 | --activation_dropout="0.2" \ 73 | --feat_proj_dropout="0.2" \ 74 | --overwrite_output_dir \ 75 | --gradient_checkpointing \ 76 | --freeze_feature_encoder \ 77 | --predict_with_generate \ 78 | --do_lower_case \ 79 | --do_eval \ 80 | --do_train \ 81 | --do_predict \ 82 | --push_to_hub \ 83 | --use_auth_token 84 | ``` 85 | 86 | Control the precision through the `--precision` arg: 87 | * Full precision (weights and optimiser in fp32): `--precision=full` 88 | * Half-mixed (weights in bf16, optimiser in fp32 ): `--precision=half_mixed` 89 | * Full-mixed (weights and optimiser in bf16): `--precision=full_mixed` -------------------------------------------------------------------------------- /models/configuration_speech_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import copy 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | from models.configuration_wav2vec2 import Wav2Vec2Config 22 | from models.configuration_bart import BartConfig 23 | from transformers import AutoConfig 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | 29 | class SpeechEncoderDecoderConfig(PretrainedConfig): 30 | r""" 31 | [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a 32 | [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified 33 | arguments, defining the encoder and decoder configs. 34 | 35 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 36 | documentation from [`PretrainedConfig`] for more information. 37 | 38 | Args: 39 | kwargs (*optional*): 40 | Dictionary of keyword arguments. Notably: 41 | 42 | - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines 43 | the encoder config. 44 | - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines 45 | the decoder config. 
46 | 47 | Examples: 48 | 49 | ```python 50 | >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel 51 | 52 | >>> # Initializing a Wav2Vec2 & BERT style configuration 53 | >>> config_encoder = Wav2Vec2Config() 54 | >>> config_decoder = BertConfig() 55 | 56 | >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) 57 | 58 | >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations 59 | >>> model = SpeechEncoderDecoderModel(config=config) 60 | 61 | >>> # Accessing the model configuration 62 | >>> config_encoder = model.config.encoder 63 | >>> config_decoder = model.config.decoder 64 | >>> # set decoder config to causal lm 65 | >>> config_decoder.is_decoder = True 66 | >>> config_decoder.add_cross_attention = True 67 | 68 | >>> # Saving the model, including its configuration 69 | >>> model.save_pretrained("my-model") 70 | 71 | >>> # loading model and config from pretrained folder 72 | >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model") 73 | >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config) 74 | ```""" 75 | model_type = "speech-encoder-decoder" 76 | is_composition = True 77 | 78 | def __init__(self, **kwargs): 79 | super().__init__(**kwargs) 80 | if "encoder" not in kwargs or "decoder" not in kwargs: 81 | raise ValueError( 82 | f"A configuraton of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}" 83 | ) 84 | 85 | encoder_config = kwargs.pop("encoder") 86 | decoder_config = kwargs.pop("decoder") 87 | 88 | # TODO: Load configs from AutoConfig (as done in Transformers 🤗) 89 | self.encoder = Wav2Vec2Config(**encoder_config) 90 | self.decoder = BartConfig(**decoder_config) 91 | self.is_encoder_decoder = True 92 | 93 | @classmethod 94 | def from_encoder_decoder_configs( 95 | 
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs 96 | ) -> PretrainedConfig: 97 | r""" 98 | Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model 99 | configuration and decoder model configuration. 100 | 101 | Returns: 102 | [`SpeechEncoderDecoderConfig`]: An instance of a configuration object 103 | """ 104 | logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") 105 | decoder_config.is_decoder = True 106 | decoder_config.add_cross_attention = True 107 | 108 | return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) 109 | 110 | def to_dict(self): 111 | """ 112 | Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*. 113 | 114 | Returns: 115 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 116 | """ 117 | output = copy.deepcopy(self.__dict__) 118 | output["encoder"] = self.encoder.to_dict() 119 | output["decoder"] = self.decoder.to_dict() 120 | output["model_type"] = self.__class__.model_type 121 | return output 122 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 83 | 84 | 85 | TRC 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | GPUs 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /models/configuration_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ BART model configuration""" 16 | import warnings 17 | 18 | from transformers.configuration_utils import PretrainedConfig 19 | from transformers.utils import logging 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json", 26 | # See all BART models at https://huggingface.co/models?filter=bart 27 | } 28 | 29 | 30 | class BartConfig(PretrainedConfig): 31 | r""" 32 | This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART 33 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 34 | defaults will yield a similar configuration to that of the BART 35 | [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 50265): 43 | Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`]. 45 | d_model (`int`, *optional*, defaults to 1024): 46 | Dimensionality of the layers and the pooler layer. 47 | encoder_layers (`int`, *optional*, defaults to 12): 48 | Number of encoder layers. 
49 | decoder_layers (`int`, *optional*, defaults to 12): 50 | Number of decoder layers. 51 | encoder_attention_heads (`int`, *optional*, defaults to 16): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | decoder_attention_heads (`int`, *optional*, defaults to 16): 54 | Number of attention heads for each attention layer in the Transformer decoder. 55 | decoder_ffn_dim (`int`, *optional*, defaults to 4096): 56 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 57 | encoder_ffn_dim (`int`, *optional*, defaults to 4096): 58 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 59 | activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): 60 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 61 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 62 | dropout (`float`, *optional*, defaults to 0.1): 63 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 64 | attention_dropout (`float`, *optional*, defaults to 0.0): 65 | The dropout ratio for the attention probabilities. 66 | activation_dropout (`float`, *optional*, defaults to 0.0): 67 | The dropout ratio for activations inside the fully connected layer. 68 | classifier_dropout (`float`, *optional*, defaults to 0.0): 69 | The dropout ratio for classifier. 70 | max_position_embeddings (`int`, *optional*, defaults to 1024): 71 | The maximum sequence length that this model might ever be used with. Typically set this to something large 72 | just in case (e.g., 512 or 1024 or 2048). 73 | init_std (`float`, *optional*, defaults to 0.02): 74 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 75 | encoder_layerdrop: (`float`, *optional*, defaults to 0.0): 76 | The LayerDrop probability for the encoder. 
See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 77 | for more details. 78 | decoder_layerdrop: (`float`, *optional*, defaults to 0.0): 79 | The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 80 | for more details. 81 | scale_embedding (`bool`, *optional*, defaults to `False`): 82 | Scale embeddings by diving by sqrt(d_model). 83 | use_cache (`bool`, *optional*, defaults to `True`): 84 | Whether or not the model should return the last key/values attentions (not used by all models). 85 | num_labels: (`int`, *optional*, defaults to 3): 86 | The number of labels to use in [`BartForSequenceClassification`]. 87 | forced_eos_token_id (`int`, *optional*, defaults to 2): 88 | The id of the token to force as the last generated token when `max_length` is reached. Usually set to 89 | `eos_token_id`. 90 | use_scan (`bool`, *optional*, defaults to `False`): 91 | Whether or not to use nn.scan in the Flax Bart attention layers. 92 | 93 | Example: 94 | 95 | ```python 96 | >>> from transformers import BartModel, BartConfig 97 | 98 | >>> # Initializing a BART facebook/bart-large style configuration 99 | >>> configuration = BartConfig() 100 | 101 | >>> # Initializing a model from the facebook/bart-large style configuration 102 | >>> model = BartModel(configuration) 103 | 104 | >>> # Accessing the model configuration 105 | >>> configuration = model.config 106 | ```""" 107 | model_type = "bart" 108 | keys_to_ignore_at_inference = ["past_key_values"] 109 | attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} 110 | 111 | def __init__( 112 | self, 113 | vocab_size=50265, 114 | max_position_embeddings=1024, 115 | encoder_layers=12, 116 | encoder_ffn_dim=4096, 117 | encoder_attention_heads=16, 118 | decoder_layers=12, 119 | decoder_ffn_dim=4096, 120 | decoder_attention_heads=16, 121 | encoder_layerdrop=0.0, 122 | decoder_layerdrop=0.0, 123 | activation_function="gelu", 124 | 
d_model=1024, 125 | dropout=0.1, 126 | attention_dropout=0.0, 127 | activation_dropout=0.0, 128 | init_std=0.02, 129 | classifier_dropout=0.0, 130 | scale_embedding=False, 131 | use_cache=True, 132 | use_scan=False, 133 | fuse_matmuls=False, 134 | num_labels=3, 135 | pad_token_id=1, 136 | bos_token_id=0, 137 | eos_token_id=2, 138 | is_encoder_decoder=True, 139 | decoder_start_token_id=2, 140 | forced_eos_token_id=2, 141 | **kwargs 142 | ): 143 | self.vocab_size = vocab_size 144 | self.max_position_embeddings = max_position_embeddings 145 | self.d_model = d_model 146 | self.encoder_ffn_dim = encoder_ffn_dim 147 | self.encoder_layers = encoder_layers 148 | self.encoder_attention_heads = encoder_attention_heads 149 | self.decoder_ffn_dim = decoder_ffn_dim 150 | self.decoder_layers = decoder_layers 151 | self.decoder_attention_heads = decoder_attention_heads 152 | self.dropout = dropout 153 | self.attention_dropout = attention_dropout 154 | self.activation_dropout = activation_dropout 155 | self.activation_function = activation_function 156 | self.init_std = init_std 157 | self.encoder_layerdrop = encoder_layerdrop 158 | self.decoder_layerdrop = decoder_layerdrop 159 | self.classifier_dropout = classifier_dropout 160 | self.use_cache = use_cache 161 | self.use_scan = use_scan 162 | self.fuse_matmuls = fuse_matmuls 163 | self.num_hidden_layers = encoder_layers 164 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 165 | 166 | super().__init__( 167 | num_labels=num_labels, 168 | pad_token_id=pad_token_id, 169 | bos_token_id=bos_token_id, 170 | eos_token_id=eos_token_id, 171 | is_encoder_decoder=is_encoder_decoder, 172 | decoder_start_token_id=decoder_start_token_id, 173 | forced_eos_token_id=forced_eos_token_id, 174 | **kwargs, 175 | ) 176 | 177 | # ensure backward compatibility for BART CNN models 178 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): 179 | self.forced_bos_token_id = 
self.bos_token_id 180 | warnings.warn( 181 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " 182 | "The config can simply be saved and uploaded again to be fixed." 183 | ) 184 | -------------------------------------------------------------------------------- /get_ctc_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from datasets import load_dataset 3 | from collections import Counter 4 | import json 5 | import os 6 | import re 7 | import tempfile 8 | from transformers import Wav2Vec2CTCTokenizer 9 | 10 | # which dataset 11 | dataset_name = "sanchit-gandhi/earnings22" 12 | # which config 13 | dataset_config = "all" 14 | # which split => @Sanchit, we should only use the train split for "fairness" 15 | split = "train" 16 | # in case the dataset requires access like CV9 17 | use_auth_token = True 18 | # name of the text data column 19 | text_column = "sentence" 20 | # name of tok to upload to the Hub 21 | tokenizer_name = "wav2vec2-ctc-earnings22-black-box-tokenizer" 22 | # only set to TRUE if dataset is NOT cased 23 | do_lower = False 24 | # ignore the verifications of the downloaded/processed dataset information in `load_dataset`, set to False in most cases, True for E22 25 | ignore_verifications = True 26 | 27 | # should be kept the same across datasets (except for ablation) 28 | do_upper = False 29 | remove_punctuation = False 30 | cutoff_freq = 0.01 31 | # dataset cache directory 32 | dataset_cache_dir = "/home/sanchitgandhi/cache/huggingface/datasets" 33 | # For GigaSpeech, we need to convert spelled out punctuation to symbolic form 34 | gigaspeech_punctuation = {"": ",", "": ".", "": "?", "', '_', '[', ']'] 37 | # additional chars to remove if `remove_punctuation` is set to True 38 | additional_chars_to_remove_regex = '[,?.!-;:"“%‘”�{}()<>' + "']" 39 | 40 | dataset = load_dataset( 41 | dataset_name, 42 | dataset_config, 43 | 
split=split, 44 | use_auth_token=use_auth_token, 45 | cache_dir=dataset_cache_dir, 46 | ignore_verifications= ignore_verifications, 47 | ) 48 | 49 | # remove all data that is unnecessary to save RAM 50 | dataset = dataset.remove_columns(list(set(dataset.column_names) - set([text_column]))) 51 | 52 | # define function to see stats about letters and to create vocab 53 | def create_vocabulary_from_data(dataset, word_delimiter_token="|", do_lower=False, do_upper=False, remove_punctuation=False, cutoff_freq=0.0): 54 | def extract_all_chars(batch): 55 | all_text = " ".join(batch[text_column]) 56 | 57 | if do_lower and do_upper: 58 | raise ValueError("Cannot do uppercase and lowercase tokenization concurrently. Set at most one of `do_lower` or `do_upper` to `True`.") 59 | if do_lower: 60 | all_text = all_text.lower() 61 | if do_upper: 62 | all_text = all_text.upper() 63 | for punctuation, replacement in gigaspeech_punctuation.items(): 64 | all_text = all_text.replace(punctuation.lower(), replacement) 65 | all_text = all_text.replace(punctuation.upper(), replacement) 66 | for char in preprocessing_chars_to_remove: 67 | all_text = all_text.replace(char, "") 68 | if remove_punctuation: 69 | all_text = re.sub(additional_chars_to_remove_regex, '', all_text) 70 | 71 | count_chars_dict = Counter(list(all_text)) 72 | # sort by freq 73 | count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0])) 74 | # retrieve dict, freq 75 | vocab, freqs = zip(*count_chars_dict) 76 | 77 | return {"vocab": list(vocab), "freqs": list(freqs)} 78 | 79 | dataset = dataset.map( 80 | extract_all_chars, 81 | batched=True, 82 | batch_size=-1, 83 | remove_columns=dataset.column_names, 84 | ) 85 | 86 | vocab, freqs = dataset["vocab"], dataset["freqs"] 87 | total_num_chars = sum(freqs) 88 | chars_to_remove = [] 89 | 90 | print("Character Occurences") 91 | print(f"Total characters in dataset: {total_num_chars}") 92 | print(50 * "-") 93 | print(f"{'Char'.rjust(5)} | {'Total 
occ'.rjust(10)} | {'% of total occ'.rjust(20)} |") 94 | print(50 * "-") 95 | for char, freq in zip(vocab, freqs): 96 | freq_in_percent = freq / total_num_chars * 100 97 | print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |") 98 | if freq_in_percent < cutoff_freq: 99 | chars_to_remove.append(char) 100 | print(50 * "-") 101 | 102 | vocab = list(set(vocab) - set(chars_to_remove)) 103 | 104 | # Wav2Vec2CTC Tokenizers always have those as the first tokens (important for CTC) 105 | vocab = ["", "", "", ""] + vocab 106 | 107 | alphabet = list(map(chr, range(97, 123))) 108 | 109 | for char in alphabet: 110 | char = char.upper() if do_upper else char 111 | if char not in vocab: 112 | vocab.append(char) 113 | 114 | # create json dict 115 | vocab_dict = {v: k for k, v in enumerate(list(vocab))} 116 | 117 | # replace white space with delimiter token 118 | if word_delimiter_token is not None: 119 | vocab_dict[word_delimiter_token] = vocab_dict[" "] 120 | del vocab_dict[" "] 121 | 122 | return vocab_dict 123 | 124 | # Note that the functions accepts the following important args 125 | # 1. --do_lower 126 | # => whether to lowercase all letters or not. 127 | # Note that if you lowercase letters for the vocab, then you also need to 128 | # do so when preparing the data for the training, dev and test set 129 | # 2. --cutoff_freq 130 | # => This is very important! Lots of datasets will contain "wrong" characters in the training set, e.g. 131 | # characters that just occur a couple of times. 132 | # By default, the CTC vocab creation would just add them to the vocab even if their occurance is neglectible # compared to the "super frequent" letters. We can see such characters as "errors" or irrelevant in the 133 | # dataset, so that we should delete them from the vocab. During training they would then just be classified 134 | # unkown tokens which the model can handle. 
135 | # In this script, we deploy a mechanism to remove all chars whose freq in % is below a certain threshold. 136 | 137 | 138 | # To begin with, let's take a look into the charecter distribution to decide whether to lowercase everything 139 | # and how many "incorrect" chars are in the dataset 140 | 141 | # do_lower = False 142 | # cutoff_freq = 0.0 143 | # create_vocabulary_from_data(dataset, do_lower=do_lower, cutoff_freq=cutoff_freq) 144 | 145 | """ 146 | Total characters in dataset: 57415071 147 | -------------------------------------------------- 148 | Char | Total occ | % of total occ | 149 | -------------------------------------------------- 150 | | 9158936 | 15.952 | 151 | e | 5656975 | 9.853 | 152 | a | 3843802 | 6.695 | 153 | t | 3612796 | 6.292 | 154 | i | 3362877 | 5.857 | 155 | o | 3275590 | 5.705 | 156 | n | 3208804 | 5.589 | 157 | s | 3155007 | 5.495 | 158 | r | 3065229 | 5.339 | 159 | h | 2033409 | 3.542 | 160 | l | 1985225 | 3.458 | 161 | d | 1727989 | 3.01 | 162 | c | 1399592 | 2.438 | 163 | u | 1167110 | 2.033 | 164 | m | 1075733 | 1.874 | 165 | f | 884318 | 1.54 | 166 | . 
| 881733 | 1.536 | 167 | p | 846057 | 1.474 | 168 | g | 809581 | 1.41 | 169 | y | 740494 | 1.29 | 170 | w | 722667 | 1.259 | 171 | b | 606687 | 1.057 | 172 | " | 571968 | 0.996 | 173 | v | 477330 | 0.831 | 174 | , | 345764 | 0.602 | 175 | T | 332058 | 0.578 | 176 | k | 284615 | 0.496 | 177 | S | 174125 | 0.303 | 178 | A | 157656 | 0.275 | 179 | H | 156398 | 0.272 | 180 | C | 143595 | 0.25 | 181 | I | 141826 | 0.247 | 182 | M | 113026 | 0.197 | 183 | B | 102932 | 0.179 | 184 | ' | 88702 | 0.154 | 185 | - | 88461 | 0.154 | 186 | x | 85563 | 0.149 | 187 | P | 84495 | 0.147 | 188 | L | 67280 | 0.117 | 189 | R | 67254 | 0.117 | 190 | W | 66508 | 0.116 | 191 | D | 66094 | 0.115 | 192 | F | 61841 | 0.108 | 193 | G | 59031 | 0.103 | 194 | E | 54387 | 0.095 | 195 | N | 53495 | 0.093 | 196 | z | 47955 | 0.084 | 197 | O | 43305 | 0.075 | 198 | j | 41654 | 0.073 | 199 | q | 40510 | 0.071 | 200 | J | 39647 | 0.069 | 201 | K | 31204 | 0.054 | 202 | U | 23517 | 0.041 | 203 | V | 21380 | 0.037 | 204 | Y | 14863 | 0.026 | 205 | ? | 8574 | 0.015 | 206 | ! | 7327 | 0.013 | 207 | : | 5456 | 0.01 | 208 | ’ | 4864 | 0.008 | 209 | Z | 4735 | 0.008 | 210 | Q | 4488 | 0.008 | 211 | ; | 2781 | 0.005 | 212 | X | 1391 | 0.002 | 213 | ” | 1374 | 0.002 | 214 | “ | 1344 | 0.002 | 215 | ‘ | 1100 | 0.002 | 216 | — | 690 | 0.001 | 217 | é | 297 | 0.001 | 218 | ü | 177 | 0.0 | 219 | ) | 122 | 0.0 | 220 | ( | 121 | 0.0 | 221 | ä | 109 | 0.0 | 222 | ... 
223 | 阪 | 1 | 0.0 | 224 | fl | 1 | 0.0 | 225 | """ 226 | # All right, we see lots of "wrong" tokens and also see that there is a mix of upper-case and lower-case tokens 227 | # Let's lower-case all tokens and take a look again 228 | 229 | # do_lower = True 230 | # cutoff_freq = 0.0 231 | 232 | # create_vocabulary_from_data(dataset, do_lower=do_lower, cutoff_freq=cutoff_freq) 233 | 234 | 235 | """ 236 | Character Occurences 237 | Total characters in dataset: 57415071 238 | -------------------------------------------------- 239 | Char | Total occ | % of total occ | 240 | -------------------------------------------------- 241 | | 9158936 | 15.952 | 242 | e | 5711362 | 9.947 | 243 | a | 4001458 | 6.969 | 244 | t | 3944854 | 6.871 | 245 | i | 3504703 | 6.104 | 246 | s | 3329132 | 5.798 | 247 | o | 3318895 | 5.781 | 248 | n | 3262299 | 5.682 | 249 | r | 3132483 | 5.456 | 250 | h | 2189807 | 3.814 | 251 | l | 2052505 | 3.575 | 252 | d | 1794083 | 3.125 | 253 | c | 1543187 | 2.688 | 254 | u | 1190627 | 2.074 | 255 | m | 1188759 | 2.07 | 256 | f | 946159 | 1.648 | 257 | p | 930552 | 1.621 | 258 | . | 881733 | 1.536 | 259 | g | 868612 | 1.513 | 260 | w | 789175 | 1.375 | 261 | y | 755357 | 1.316 | 262 | b | 709619 | 1.236 | 263 | " | 571968 | 0.996 | 264 | v | 498710 | 0.869 | 265 | , | 345764 | 0.602 | 266 | k | 315819 | 0.55 | 267 | ' | 88702 | 0.154 | 268 | - | 88461 | 0.154 | 269 | x | 86954 | 0.151 | 270 | j | 81301 | 0.142 | 271 | z | 52690 | 0.092 | 272 | q | 44998 | 0.078 | 273 | ? | 8574 | 0.015 | 274 | ! | 7327 | 0.013 | 275 | : | 5456 | 0.01 | 276 | ’ | 4864 | 0.008 | 277 | ; | 2781 | 0.005 | 278 | ” | 1374 | 0.002 | 279 | “ | 1344 | 0.002 | 280 | ‘ | 1100 | 0.002 | 281 | — | 690 | 0.001 | 282 | é | 319 | 0.001 | 283 | ü | 182 | 0.0 | 284 | ) | 122 | 0.0 | 285 | ( | 121 | 0.0 | 286 | ... 287 | 阪 | 1 | 0.0 | 288 | fl | 1 | 0.0 | 289 | """ 290 | 291 | # Cool, now let's remove very rare, "wrong" characters. 
Everything belowe 0.01% (note that's 1/10,000) seems like a good estimate 292 | # It keeps all letters of the alphabet and some punctuation, but removes clearly all incorrect letters like 293 | # accentuated letters from German or French, Chinese letters, ... 294 | # Running it once more and now keeping the dict 295 | 296 | vocab_dict = create_vocabulary_from_data(dataset, do_lower=do_lower, do_upper=do_upper, remove_punctuation=remove_punctuation, cutoff_freq=cutoff_freq) 297 | 298 | # save vocab dict to be loaded into tokenizer 299 | with tempfile.TemporaryDirectory() as tmp: 300 | with open(os.path.join(tmp, "vocab.json"), "w") as file: 301 | json.dump(vocab_dict, file) 302 | 303 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp, do_lower_case=do_upper) 304 | 305 | # push tokenizer to the Hub 306 | # E.g. see: https://huggingface.co/patrickvonplaten/wav2vec2_ctc_cv9_tokenizer 307 | tokenizer.push_to_hub(tokenizer_name) 308 | -------------------------------------------------------------------------------- /get_ctc_ngram.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from datasets import load_dataset 3 | from collections import Counter 4 | import re 5 | import os 6 | from transformers import Wav2Vec2ProcessorWithLM 7 | from transformers import AutoProcessor 8 | from pathlib import Path 9 | from pyctcdecode import build_ctcdecoder 10 | 11 | # adapt to dataset 12 | dataset_name = "polinaeterna/voxpopuli" 13 | dataset_config = "en" 14 | split = "train" 15 | text_column = "normalized_text" 16 | tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" 17 | do_lower = True # only set to TRUE if dataset is NOT cased 18 | 19 | # should be kept the same across datasets (except for ablation) 20 | cutoff_freq = 0.01 21 | do_upper = False # only set to TRUE for ablation studies 22 | preprocessing_chars_to_remove = [] # only remove chars for ablation studies 23 | 
additional_chars_to_remove_regex = "" # only set to something for ablation studies 24 | remove_punctuation = additional_chars_to_remove_regex != "" 25 | # dataset specific "error correction" 26 | # For GigaSpeech, we need to convert spelled out punctuation to symbolic form 27 | gigaspeech_punctuation = {"": ",", "": ".", "": "?", "": ",", " ": ".", " ": "?", " ": "!"} 40 | gigaspeech_disfluencies = ["", ""] 41 | swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", "[vocalized-noise]", "_1"] 42 | swb_punctuations = ["{", "}", "[", "]-", "]"] 43 | earnings_disfluencies = ["", "", "", "inaudible", "", "", ""] 44 | ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", "[vocalized-noise]", "", "", "", "", "", "", ""] 45 | 46 | 47 | # in case the dataset requires access like CV9 48 | use_auth_token = True 49 | 50 | dataset = load_dataset( 51 | dataset_name, 52 | dataset_config, 53 | split=split, 54 | use_auth_token=use_auth_token, 55 | cache_dir=dataset_cache_dir, 56 | ) 57 | 58 | # remove all data that is unnecessary to save RAM 59 | dataset = dataset.remove_columns(list(set(dataset.column_names) - set([text_column]))) 60 | 61 | # define function to see stats about letters and to create vocab 62 | # NOTE: this function has to be 1-to-1 aligned with: 63 | # https://github.com/sanchit-gandhi/seq2seq-speech/blob/25d3af18d779d12cdb6c30040f30f51f5a6bb75b/get_ctc_tokenizer.py#L45 64 | def process_text(dataset, word_delimiter_token="|", do_lower=False, do_upper=False, remove_punctuation=False, cutoff_freq=0.0): 65 | def extract_all_chars(batch): 66 | all_text = " ".join(batch[text_column]) 67 | 68 | if do_lower and do_upper: 69 | raise ValueError("Cannot do uppercase and lowercase tokenization concurrently. 
Set at most one of `do_lower` or `do_upper` to `True`.") 70 | if do_lower: 71 | all_text = all_text.lower() 72 | if do_upper: 73 | all_text = all_text.upper() 74 | for punctuation, replacement in gigaspeech_punctuation.items(): 75 | all_text = all_text.replace(punctuation.lower(), replacement) 76 | all_text = all_text.replace(punctuation.upper(), replacement) 77 | for char in preprocessing_chars_to_remove: 78 | all_text = all_text.replace(char, "") 79 | # only used for ablation studies 80 | if remove_punctuation: 81 | all_text = re.sub(additional_chars_to_remove_regex, '', all_text) 82 | 83 | count_chars_dict = Counter(list(all_text)) 84 | # sort by freq 85 | count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0])) 86 | # retrieve dict, freq 87 | vocab, freqs = zip(*count_chars_dict) 88 | 89 | result = {"vocab": [list(vocab)], "freqs": [list(freqs)]} 90 | result[text_column] = [all_text] 91 | 92 | return result 93 | 94 | dataset = dataset.map( 95 | extract_all_chars, 96 | batched=True, 97 | batch_size=-1, 98 | remove_columns=dataset.column_names, 99 | ) 100 | 101 | vocab, freqs = dataset["vocab"][0], dataset["freqs"][0] 102 | total_num_chars = sum(freqs) 103 | chars_to_remove = [] 104 | 105 | print("Character Occurences") 106 | print(f"Total characters in dataset: {total_num_chars}") 107 | print(50 * "-") 108 | print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |") 109 | print(50 * "-") 110 | for char, freq in zip(vocab, freqs): 111 | freq_in_percent = freq / total_num_chars * 100 112 | print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |") 113 | if freq_in_percent < cutoff_freq: 114 | chars_to_remove.append(char) 115 | print(50 * "-") 116 | print(f"REMOVED CHARS: {chars_to_remove}") 117 | print(50 * "-") 118 | 119 | 120 | def correct_data(batch): 121 | # LibriSpeech ASR 122 | new_input_strings = [] 123 | for input_str in batch[text_column]: 124 | if 
dataset_name == "librispeech_asr": 125 | pass # no error correction necessary 126 | 127 | # VoxPopuli 128 | if dataset_name == "google/xtreme_s": 129 | pass # no error correction necessary 130 | 131 | # Common Voice 9 132 | if dataset_name == "mozilla-foundation/common_voice_9_0": 133 | if input_str.startswith('"') and input_str.endswith('"'): 134 | # we can remove trailing quotation marks as they do not affect the transcription 135 | input_str = input_str[1:-1] 136 | # replace double quotation marks with single 137 | input_str = input_str.replace('""', '"') 138 | 139 | # TED-LIUM (Release 3) 140 | if dataset_name == "LIUM/tedlium": 141 | # delete the token from the text 142 | input_str = input_str.replace("", "") 143 | # replace spaced apostrophes with un-spaced (it 's -> it's) 144 | for contraction in tedlium_contractions: 145 | input_str = input_str.replace(contraction, contraction[1:]) 146 | 147 | # GigaSpeech 148 | if dataset_name == "speechcolab/gigaspeech": 149 | for disfluency in gigaspeech_disfluencies: 150 | input_str = input_str.replace(disfluency, "") 151 | # convert spelled out punctuation to symbolic form 152 | for punctuation, replacement in gigaspeech_punctuation.items(): 153 | input_str = input_str.replace(punctuation, replacement) 154 | 155 | # SWB: hide the path to the private HF dataset 156 | if "switchboard" in dataset_name: 157 | for disfluency in swb_disfluencies: 158 | input_str = input_str.replace(disfluency, "") 159 | # remove parenthesised text (test data only) 160 | input_str = re.sub("[\(].*?[\)]", "", input_str) 161 | for punctuation in swb_punctuations: 162 | input_str = input_str.replace(punctuation, "") 163 | # replace anomalous words with their correct transcriptions 164 | split_str = input_str.split("/") 165 | if len(split_str) > 1: 166 | input_str = " ".join( 167 | [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]]) 168 | 169 | # Earnings 22: still figuring out best segmenting method. 
Thus, dataset name subject to change 170 | if "earnings22" in dataset_name: 171 | for disfluency in earnings_disfluencies: 172 | input_str = input_str.replace(disfluency, "") 173 | 174 | # SPGISpeech 175 | if dataset_name == "kensho/spgispeech": 176 | pass # no error correction necessary 177 | 178 | # JIWER compliance (for WER/CER calc.) 179 | # remove multiple spaces 180 | input_str = re.sub(r"\s\s+", " ", input_str) 181 | 182 | # strip trailing spaces 183 | input_str = input_str.strip() 184 | 185 | new_input_strings.append(input_str) 186 | 187 | all_text = " ".join(new_input_strings) 188 | 189 | for char in chars_to_remove: 190 | all_text = all_text.replace(char, "") 191 | 192 | result = {} 193 | result[text_column] = [all_text] 194 | 195 | return result 196 | 197 | dataset = dataset.map( 198 | correct_data, 199 | batched=True, 200 | batch_size=-1, 201 | remove_columns=dataset.column_names, 202 | ) 203 | 204 | return dataset 205 | 206 | 207 | # Cool, now let's remove very rare, "wrong" characters. Everything belowe 0.01% (note that's 1/10,000) seems like a good estimate 208 | # It keeps all letters of the alphabet and some punctuation, but removes clearly all incorrect letters like 209 | # accentuated letters from German or French, Chinese letters, ... 
210 | # Running it once more and now keeping the dict 211 | 212 | if not os.path.isfile(text_save_path): 213 | text_data = process_text(dataset, do_lower=do_lower, do_upper=do_upper, remove_punctuation=remove_punctuation, cutoff_freq=cutoff_freq) 214 | 215 | # save vocab dict to be loaded into tokenizer 216 | Path(dir_path).mkdir(parents=True, exist_ok=True) 217 | with open(text_save_path, "w") as file: 218 | file.write(" ".join(text_data[text_column])) 219 | 220 | if not os.path.isfile(ngram_save_path): 221 | ngram_command = f"{home_path}/kenlm/build/bin/lmplz -o 5 < '{text_save_path}' > '{ngram_save_path}' --skip_symbols" 222 | os.system(ngram_command) 223 | 224 | # correct with "" 225 | ngram_save_path_correct = ngram_save_path + "_correct.arpa" 226 | with open(ngram_save_path, "r") as read_file, open(ngram_save_path_correct, "w") as write_file: 227 | has_added_eos = False 228 | for line in read_file: 229 | if not has_added_eos and "ngram 1=" in line: 230 | count = line.strip().split("=")[-1] 231 | write_file.write(line.replace(f"{count}", f"{int(count)+1}")) 232 | elif not has_added_eos and "" in line: 233 | write_file.write(line) 234 | write_file.write(line.replace("", "")) 235 | has_added_eos = True 236 | else: 237 | write_file.write(line) 238 | 239 | os.system(f"mv {ngram_save_path_correct} {ngram_save_path}") 240 | 241 | 242 | processor = AutoProcessor.from_pretrained(tokenizer_name) 243 | vocab_dict = processor.tokenizer.get_vocab() 244 | if do_lower: 245 | sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])} 246 | else: 247 | sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])} 248 | 249 | processor.tokenizer.encoder = sorted_vocab_dict 250 | processor.tokenizer.decoder = {v: k for k, v in processor.tokenizer.encoder.items()} 251 | 252 | decoder = build_ctcdecoder( 253 | labels=list(sorted_vocab_dict.keys()), 254 | kenlm_model_path=ngram_save_path, 255 | ) 256 | 257 | 
processor_with_lm = Wav2Vec2ProcessorWithLM( 258 | feature_extractor=processor.feature_extractor, 259 | tokenizer=processor.tokenizer, 260 | decoder=decoder 261 | ) 262 | processor_with_lm.save_pretrained(dir_path) 263 | 264 | new_ngram_path = os.path.join(dir_path, "language_model", file_name) 265 | bin_save_path = new_ngram_path.split(".")[0] + ".bin" 266 | os.system(f"{home_path}/kenlm/build/bin/build_binary '{new_ngram_path}' '{bin_save_path}'") 267 | os.system(f"mv '{bin_save_path}' '{new_ngram_path}'") 268 | 269 | 270 | 271 | # CONFIGS 272 | # ======================================================================== 273 | # 1. LIBRISPEECH: 274 | # dataset_name = "librispeech_asr" 275 | # dataset_config = None 276 | # split = "train.clean.100+train.clean.360+train.other.500" 277 | # text_column = "text" 278 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline" 279 | # do_lower = True # only set to TRUE if dataset is NOT cased 280 | 281 | # => no REMOVED CHARS: [] -> all chars are used 282 | 283 | # ======================================================================== 284 | # 2. TEDLIUM 285 | # dataset_name = "LIUM/tedlium" 286 | # dataset_config = "release3" 287 | # split = "train" 288 | # text_column = "text" 289 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-tedlium-black-box" 290 | # do_lower = False # only set to TRUE if dataset is NOT cased 291 | 292 | # => REMOVED CHARS: ['0', '1', '2', '9', '[', ']', '3', '5', '8', '4', '$', '7', '6', '&', '+', '=', '#', '%', '@', '*', '\\', '^', 'ā'] 293 | 294 | # ======================================================================== 295 | # 3. 
AMI 296 | # dataset_name = "speech-seq2seq/ami" 297 | # dataset_config = "ihm" 298 | # split = "train" 299 | # text_column = "text" 300 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-ami-black-box" 301 | # do_lower = False # only set to TRUE if dataset is NOT cased 302 | 303 | # => REMOVED CHARS: ['X', 'Q', 'Z', '0', '!', '*', '1', '3', '@'] 304 | 305 | # ======================================================================== 306 | # 4. CV 9 307 | # dataset_name = "mozilla-foundation/common_voice_9_0" 308 | # dataset_config = "en" 309 | # split = "train" 310 | # text_column = "sentence" 311 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-cv9-black-box" 312 | # do_lower = False # only set to TRUE if dataset is NOT cased 313 | 314 | # => REMOVED CHARS: REMOVED CHARS: [':', '’', 'Z', 'Q', ';', 'X', '”', '“', '‘', '—', 'é', 'ü', ')', '(', 'ä', 'ö', 'á', 'ó', 'è', 'í', '–', '/', 'ç', '&', 'â', 'ō', 'ß', 'ñ', 'É', 'à', 'ï', 'ô', 'ú', 'ã', 'ê', 'ë', 'č', 'ł', '`', 'Ö', '…', '´', 'ø', 'ć', 'Š', 'ž', 'Ü', 'î', 'ð', 'û', 'ā', 'ă', 'ū', '%', 'Ä', 'ı', 'œ', 'š', '[', ']', '«', '»', 'Á', 'Ó', 'ò', 'ī', 'ș', '_', '¡', '·', 'Ç', 'Ú', 'æ', 'ý', 'ń', 'Ō', 'Œ', 'ř', 'ş', 'ʻ', 'α', 'κ', 'π', 'и', 'к', 'ạ', '#', '=', '~', '§', 'Ã', 'È', 'Î', 'Ø', 'å', 'õ', 'þ', 'Č', 'ē', 'ę', 'ě', 'ğ', 'ň', 'ő', 'Ș', 'ə', 'Α', 'Χ', 'В', 'а', 'е', 'з', 'й', 'л', 'н', 'ь', 'я', 'נ', 'ע', 'ṃ', 'ả', 'ị', 'ụ', '„', '€', '→', '≡', '京', '先', '大', '尚', '时', '生', '都', '阪', 'fl'] 315 | 316 | # ======================================================================== 317 | # 5. Gigaspeech 318 | # dataset_name = "speechcolab/gigaspeech" 319 | # dataset_config = "l" 320 | # split = "train" 321 | # text_column = "text" 322 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-gs-black-box" 323 | # do_lower = True # only set to TRUE if dataset is NOT cased 324 | 325 | # => REMOVED CHARS: [] 326 | 327 | # ======================================================================== 328 | # 6. 
SPGI Kensho Speech 329 | # dataset_name = "kensho/spgispeech" 330 | # dataset_config = "L" 331 | # split = "train" 332 | # text_column = "transcript" 333 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-spgispeech-black-box" 334 | # do_lower = False # only set to TRUE if dataset is NOT cased 335 | 336 | # => REMOVED CHARS: ['Q', 'V', 'U', 'K', '9', 'X', 'Z'] 337 | 338 | # ======================================================================== 339 | # 7. VoxPopuli 340 | # dataset_name = "polinaeterna/voxpopuli" 341 | # dataset_config = "en" 342 | # split = "train" 343 | # text_column = "normalized_text" 344 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" 345 | # do_lower = True # only set to TRUE if dataset is NOT cased 346 | 347 | # => REMOVED CHARS: ['!', '1'] 348 | 349 | # ======================================================================== 350 | # 8. Earnings 22: 351 | # dataset_name = "sanchit-gandhi/earnings22_robust_split" 352 | # dataset_config = None 353 | # split = "train" 354 | # text_column = "sentence" 355 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-earnings22-cased-hidden-activation-featproj-dropout-0.2" 356 | # do_lower = False # only set to TRUE if dataset is NOT cased 357 | 358 | # => REMOVED CHARS: ['&', 'X', ';', '€', 'Z', '/', ':', '*', '£', '¥', 'ł', '₽', '!', '+', '–', '@', '¢', '₵', '\\', '_', '#', '=', 'ı', '₦', '[', 'ø', '₱'] 359 | -------------------------------------------------------------------------------- /models/configuration_wav2vec2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Wav2Vec2 model configuration""" 16 | 17 | import functools 18 | import operator 19 | 20 | from transformers.configuration_utils import PretrainedConfig 21 | from transformers.utils import logging 22 | 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json", 28 | # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2 29 | } 30 | 31 | 32 | class Wav2Vec2Config(PretrainedConfig): 33 | r""" 34 | This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an 35 | Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration 36 | with the defaults will yield a similar configuration to that of the Wav2Vec2 37 | [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture. 38 | 39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 40 | documentation from [`PretrainedConfig`] for more information. 41 | 42 | 43 | Args: 44 | vocab_size (`int`, *optional*, defaults to 32): 45 | Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by 46 | the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the 47 | model. 
Defines the different tokens that can be represented by the *inputs_ids* passed to the forward 48 | method of [`Wav2Vec2Model`]. 49 | hidden_size (`int`, *optional*, defaults to 768): 50 | Dimensionality of the encoder layers and the pooler layer. 51 | num_hidden_layers (`int`, *optional*, defaults to 12): 52 | Number of hidden layers in the Transformer encoder. 53 | num_attention_heads (`int`, *optional*, defaults to 12): 54 | Number of attention heads for each attention layer in the Transformer encoder. 55 | intermediate_size (`int`, *optional*, defaults to 3072): 56 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 57 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 58 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 59 | `"relu"`, `"selu"` and `"gelu_new"` are supported. 60 | hidden_dropout (`float`, *optional*, defaults to 0.1): 61 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 62 | attention_dropout (`float`, *optional*, defaults to 0.1): 63 | The dropout ratio for the attention probabilities. 64 | final_dropout (`float`, *optional*, defaults to 0.1): 65 | The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`]. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | layer_norm_eps (`float`, *optional*, defaults to 1e-5): 69 | The epsilon used by the layer normalization layers. 70 | feat_extract_norm (`str`, *optional*, defaults to `"group"`): 71 | The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group 72 | normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D 73 | convolutional layers.
74 | feat_proj_dropout (`float`, *optional*, defaults to 0.0): 75 | The dropout probability for output of the feature encoder. 76 | feat_extract_activation (`str`, *optional*, defaults to `"gelu"`): 77 | The non-linear activation function (function or string) in the 1D convolutional layers of the feature 78 | extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. 79 | feat_quantizer_dropout (`float`, *optional*, defaults to 0.0): 80 | The dropout probability for quantized feature encoder states. 81 | conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): 82 | A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the 83 | feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers. 84 | conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): 85 | A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length 86 | of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*. 87 | conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`): 88 | A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The 89 | length of *conv_kernel* defines the number of convolutional layers and has to match the length of 90 | *conv_dim*. 91 | conv_bias (`bool`, *optional*, defaults to `False`): 92 | Whether the 1D convolutional layers have a bias. 93 | num_conv_pos_embeddings (`int`, *optional*, defaults to 128): 94 | Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional 95 | embeddings layer. 96 | num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): 97 | Number of groups of 1D convolutional positional embeddings layer.
98 | do_stable_layer_norm (`bool`, *optional*, defaults to `False`): 99 | Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is 100 | True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is 101 | False` corresponds to applying layer norm after the attention layer. 102 | apply_spec_augment (`bool`, *optional*, defaults to `True`): 103 | Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see 104 | [SpecAugment: A Simple Data Augmentation Method for Automatic Speech 105 | Recognition](https://arxiv.org/abs/1904.08779). 106 | mask_time_prob (`float`, *optional*, defaults to 0.05): 107 | Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking 108 | procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If 109 | reasoning from the probability of each feature vector to be chosen as the start of the vector span to be 110 | masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the 111 | actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. 112 | mask_time_length (`int`, *optional*, defaults to 10): 113 | Length of vector span along the time axis. 114 | mask_time_min_masks (`int`, *optional*, defaults to 2): 115 | The minimum number of masks of length `mask_time_length` generated along the time axis, each time step, 116 | irrespective of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < 117 | mask_time_min_masks'' 118 | mask_feature_prob (`float`, *optional*, defaults to 0.0): 119 | Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked.
The 120 | masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over 121 | the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector 122 | span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap 123 | may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is 124 | True`. 125 | mask_feature_length (`int`, *optional*, defaults to 10): 126 | Length of vector span along the feature axis. 127 | mask_feature_min_masks (`int`, *optional*, defaults to 0): 128 | The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time 129 | step, irrespective of `mask_feature_prob`. Only relevant if 130 | ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks'' 131 | num_codevectors_per_group (`int`, *optional*, defaults to 320): 132 | Number of entries in each quantization codebook (group). 133 | num_codevector_groups (`int`, *optional*, defaults to 2): 134 | Number of codevector groups for product codevector quantization. 135 | contrastive_logits_temperature (`float`, *optional*, defaults to 0.1): 136 | The temperature *kappa* in the contrastive loss. 137 | feat_quantizer_dropout (`float`, *optional*, defaults to 0.0): 138 | The dropout probability for the output of the feature encoder that's used by the quantizer. 139 | num_negatives (`int`, *optional*, defaults to 100): 140 | Number of negative samples for the contrastive loss. 141 | codevector_dim (`int`, *optional*, defaults to 256): 142 | Dimensionality of the quantized feature vectors. 143 | proj_codevector_dim (`int`, *optional*, defaults to 256): 144 | Dimensionality of the final projection of both the quantized and the transformer features.
145 | diversity_loss_weight (`float`, *optional*, defaults to 0.1): 146 | The weight of the codebook diversity loss component. 147 | ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`): 148 | Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an 149 | instance of [`Wav2Vec2ForCTC`]. 150 | ctc_zero_infinity (`bool`, *optional*, defaults to `False`): 151 | Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly 152 | occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance 153 | of [`Wav2Vec2ForCTC`]. 154 | use_weighted_layer_sum (`bool`, *optional*, defaults to `False`): 155 | Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an 156 | instance of [`Wav2Vec2ForSequenceClassification`]. 157 | classifier_proj_size (`int`, *optional*, defaults to 256): 158 | Dimensionality of the projection before token mean-pooling for classification. 159 | tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`): 160 | A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN* 161 | module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers. 162 | tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`): 163 | A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the 164 | *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*. 165 | tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`): 166 | A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the 167 | *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168 | xvector_output_dim (`int`, *optional*, defaults to 512): 169 | Dimensionality of the *XVector* embedding vectors. 170 | add_adapter (`bool`, *optional*, defaults to `False`): 171 | Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for 172 | warm-starting Wav2Vec2 for SpeechEncoderDecoder models. 173 | adapter_kernel_size (`int`, *optional*, defaults to 3): 174 | Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. 175 | adapter_stride (`int`, *optional*, defaults to 2): 176 | Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. 177 | num_adapter_layers (`int`, *optional*, defaults to 3): 178 | Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is 179 | True`. 180 | output_hidden_size (`int`, *optional*): 181 | Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant 182 | if `add_adapter is True`. 183 | use_scan (`bool`, *optional*, defaults to `False`): 184 | Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers. 
185 | 186 | Example: 187 | 188 | ```python 189 | >>> from transformers import Wav2Vec2Model, Wav2Vec2Config 190 | 191 | >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration 192 | >>> configuration = Wav2Vec2Config() 193 | 194 | >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration 195 | >>> model = Wav2Vec2Model(configuration) 196 | 197 | >>> # Accessing the model configuration 198 | >>> configuration = model.config 199 | ```""" 200 | model_type = "wav2vec2" 201 | 202 | def __init__( 203 | self, 204 | vocab_size=32, 205 | hidden_size=768, 206 | num_hidden_layers=12, 207 | num_attention_heads=12, 208 | intermediate_size=3072, 209 | hidden_act="gelu", 210 | hidden_dropout=0.1, 211 | activation_dropout=0.1, 212 | attention_dropout=0.1, 213 | feat_proj_dropout=0.0, 214 | feat_quantizer_dropout=0.0, 215 | final_dropout=0.1, 216 | layerdrop=0.1, 217 | initializer_range=0.02, 218 | layer_norm_eps=1e-5, 219 | feat_extract_norm="group", 220 | feat_extract_activation="gelu", 221 | conv_dim=(512, 512, 512, 512, 512, 512, 512), 222 | conv_stride=(5, 2, 2, 2, 2, 2, 2), 223 | conv_kernel=(10, 3, 3, 3, 3, 2, 2), 224 | conv_bias=False, 225 | num_conv_pos_embeddings=128, 226 | num_conv_pos_embedding_groups=16, 227 | do_stable_layer_norm=False, 228 | apply_spec_augment=True, 229 | mask_time_prob=0.05, 230 | mask_time_length=10, 231 | mask_time_min_masks=2, 232 | mask_feature_prob=0.0, 233 | mask_feature_length=10, 234 | mask_feature_min_masks=0, 235 | num_codevectors_per_group=320, 236 | num_codevector_groups=2, 237 | contrastive_logits_temperature=0.1, 238 | num_negatives=100, 239 | codevector_dim=256, 240 | proj_codevector_dim=256, 241 | diversity_loss_weight=0.1, 242 | ctc_loss_reduction="sum", 243 | ctc_zero_infinity=False, 244 | use_weighted_layer_sum=False, 245 | classifier_proj_size=256, 246 | tdnn_dim=(512, 512, 512, 512, 1500), 247 | tdnn_kernel=(5, 3, 3, 1, 1), 248 | tdnn_dilation=(1, 2, 3, 1, 1), 249 | 
xvector_output_dim=512, 250 | pad_token_id=0, 251 | bos_token_id=1, 252 | eos_token_id=2, 253 | add_adapter=False, 254 | adapter_kernel_size=3, 255 | adapter_stride=2, 256 | num_adapter_layers=3, 257 | output_hidden_size=None, 258 | use_scan=False, 259 | fuse_matmuls=False, 260 | **kwargs 261 | ): 262 | super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) 263 | self.hidden_size = hidden_size 264 | self.feat_extract_norm = feat_extract_norm 265 | self.feat_extract_activation = feat_extract_activation 266 | self.conv_dim = list(conv_dim) 267 | self.conv_stride = list(conv_stride) 268 | self.conv_kernel = list(conv_kernel) 269 | self.conv_bias = conv_bias 270 | self.num_conv_pos_embeddings = num_conv_pos_embeddings 271 | self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups 272 | self.num_feat_extract_layers = len(self.conv_dim) 273 | self.num_hidden_layers = num_hidden_layers 274 | self.intermediate_size = intermediate_size 275 | self.hidden_act = hidden_act 276 | self.num_attention_heads = num_attention_heads 277 | self.hidden_dropout = hidden_dropout 278 | self.attention_dropout = attention_dropout 279 | self.activation_dropout = activation_dropout 280 | self.feat_proj_dropout = feat_proj_dropout 281 | self.final_dropout = final_dropout 282 | self.layerdrop = layerdrop 283 | self.layer_norm_eps = layer_norm_eps 284 | self.initializer_range = initializer_range 285 | self.vocab_size = vocab_size 286 | self.do_stable_layer_norm = do_stable_layer_norm 287 | self.use_weighted_layer_sum = use_weighted_layer_sum 288 | self.use_scan = use_scan 289 | self.fuse_matmuls = fuse_matmuls 290 | 291 | if ( 292 | (len(self.conv_stride) != self.num_feat_extract_layers) 293 | or (len(self.conv_kernel) != self.num_feat_extract_layers) 294 | or (len(self.conv_dim) != self.num_feat_extract_layers) 295 | ): 296 | raise ValueError( 297 | "Configuration for convolutional layers is incorrect. 
" 298 | "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, " 299 | f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) " 300 | f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." 301 | ) 302 | 303 | # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 304 | self.apply_spec_augment = apply_spec_augment 305 | self.mask_time_prob = mask_time_prob 306 | self.mask_time_length = mask_time_length 307 | self.mask_time_min_masks = mask_time_min_masks 308 | self.mask_feature_prob = mask_feature_prob 309 | self.mask_feature_length = mask_feature_length 310 | self.mask_feature_min_masks = mask_feature_min_masks 311 | 312 | # parameters for pretraining with codevector quantized representations 313 | self.num_codevectors_per_group = num_codevectors_per_group 314 | self.num_codevector_groups = num_codevector_groups 315 | self.contrastive_logits_temperature = contrastive_logits_temperature 316 | self.feat_quantizer_dropout = feat_quantizer_dropout 317 | self.num_negatives = num_negatives 318 | self.codevector_dim = codevector_dim 319 | self.proj_codevector_dim = proj_codevector_dim 320 | self.diversity_loss_weight = diversity_loss_weight 321 | 322 | # ctc loss 323 | self.ctc_loss_reduction = ctc_loss_reduction 324 | self.ctc_zero_infinity = ctc_zero_infinity 325 | 326 | # adapter 327 | self.add_adapter = add_adapter 328 | self.adapter_kernel_size = adapter_kernel_size 329 | self.adapter_stride = adapter_stride 330 | self.num_adapter_layers = num_adapter_layers 331 | self.output_hidden_size = output_hidden_size or hidden_size 332 | 333 | # SequenceClassification-specific parameter. Feel free to ignore for other classes. 334 | self.classifier_proj_size = classifier_proj_size 335 | 336 | # XVector-specific parameters. Feel free to ignore for other classes. 
337 | self.tdnn_dim = list(tdnn_dim) 338 | self.tdnn_kernel = list(tdnn_kernel) 339 | self.tdnn_dilation = list(tdnn_dilation) 340 | self.xvector_output_dim = xvector_output_dim 341 | 342 | @property 343 | def inputs_to_logits_ratio(self): 344 | return functools.reduce(operator.mul, self.conv_stride, 1) 345 | -------------------------------------------------------------------------------- /tests/check_save_feature_encoder_output.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b1db3741-29dc-4d1f-943c-ecf2ede5ad66", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n", 14 | " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "from transformers import FlaxSpeechEncoderDecoderModel\n", 20 | "from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel\n", 21 | "import numpy as np" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "d67ee694-b5e6-4c98-92f1-3f214500cf55", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'\n", 32 | "decoder_id = 'hf-internal-testing/tiny-random-bart'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "5b30237b-080d-4447-a19b-4f11c11603b3", 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", 46 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n", 47 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 48 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 49 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n", 50 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 51 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. 
The number of labels wil be overwritten to 2.\n", 52 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 
'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', 
'1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 
'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 
'kernel')}\n", 53 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 54 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 55 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n", 56 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n", 57 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 58 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 59 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n", 60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 61 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n", 62 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), 
('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 
'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 
'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), 
('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n", 63 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 64 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 65 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. 
The number of labels wil be overwritten to 2.\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)\n", 71 | "custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "id": "fdadc85b-768d-46c7-aa26-01b1f93fd635", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# create some dummy data\n", 82 | "inputs = np.random.randn(2, 2000)\n", 83 | "decoder_input_ids = np.arange(100).reshape(2,50)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "id": "caaff775-e082-419f-b737-b392667469d5", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# get ground-truth outputs from Transformers 🤗 model\n", 94 | "hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "id": "1ff2be6b-73ef-49fc-8dbe-e933b5d4f8b0", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "extract_features = custom_model.encode(inputs, output_features=True)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 7, 110 | "id": "0e2b2d35-dee8-4c47-be35-e543b0f9ef03", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "custom_outputs = custom_model(inputs, extract_features=extract_features, decoder_input_ids=decoder_input_ids, output_hidden_states=True)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "id": "aefb168b-4882-41f1-ba88-39f298835dac", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# define a helper function for our analysis\n", 125 | "def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 
1e-9):\n", 126 | " diff = np.abs((a - b)).max()\n", 127 | " if diff <= tol:\n", 128 | " print(f\"✅ Difference between Flax and PyTorch is {diff} (< {tol})\")\n", 129 | " else:\n", 130 | " print(f\"❌ Difference between Flax and PyTorch is {diff} (>= {tol})\")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "id": "dea2f7c9-a462-40c1-b09b-7c8499318e87", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "--------------------------Checking encoder hidden states match--------------------------\n", 144 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 145 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 146 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 147 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 148 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 149 | "--------------------------Checking encoder last hidden states match--------------------------\n", 150 | "HF output shape: (2, 29, 16), custom output shape: (2, 29, 16)\n", 151 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 152 | "--------------------------Checking decoder hidden states match--------------------------\n", 153 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 154 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 155 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 156 | "--------------------------Checking logits match--------------------------\n", 157 | "HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)\n", 158 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "print(\"--------------------------Checking encoder hidden states match--------------------------\")\n", 164 | "for hf_state, custom_state in zip(hf_outputs.encoder_hidden_states, custom_outputs.encoder_hidden_states):\n", 165 | " assert 
hf_state.shape == custom_state.shape\n", 166 | " assert_almost_equals(hf_state, custom_state)\n", 167 | "\n", 168 | "print(\"--------------------------Checking encoder last hidden states match--------------------------\")\n", 169 | "print(f\"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, custom output shape: {custom_outputs.encoder_last_hidden_state.shape}\")\n", 170 | "assert_almost_equals(hf_outputs.encoder_last_hidden_state, custom_outputs.encoder_last_hidden_state)\n", 171 | "\n", 172 | "print(\"--------------------------Checking decoder hidden states match--------------------------\")\n", 173 | "for hf_state, custom_state in zip(hf_outputs.decoder_hidden_states, custom_outputs.decoder_hidden_states):\n", 174 | " assert hf_state.shape == custom_state.shape\n", 175 | " assert_almost_equals(hf_state, custom_state)\n", 176 | "\n", 177 | "print(\"--------------------------Checking logits match--------------------------\")\n", 178 | "print(f\"HF logits shape: {hf_outputs.logits.shape}, Custom logits shape: {custom_outputs.logits.shape}\")\n", 179 | "assert_almost_equals(hf_outputs.logits, custom_outputs.logits)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "36f0a6f8-fc48-4326-bbaf-4ab5194f84a6", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3 (ipykernel)", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.8.9" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 5 212 | } 213 | -------------------------------------------------------------------------------- 
/tests/.ipynb_checkpoints/check_save_feature_encoder_output-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b1db3741-29dc-4d1f-943c-ecf2ede5ad66", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n", 14 | " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "from transformers import FlaxSpeechEncoderDecoderModel\n", 20 | "from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel\n", 21 | "import numpy as np" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "d67ee694-b5e6-4c98-92f1-3f214500cf55", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'\n", 32 | "decoder_id = 'hf-internal-testing/tiny-random-bart'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "5b30237b-080d-4447-a19b-4f11c11603b3", 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", 46 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n", 47 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 48 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 49 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n", 50 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 51 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. 
The number of labels wil be overwritten to 2.\n", 52 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 
'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', 
'1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 
'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 
'kernel')}\n", 53 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 54 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 55 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n", 56 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n", 57 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 58 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 59 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n", 60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 61 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n", 62 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), 
('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 
'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 
'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), 
('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n", 63 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 64 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 65 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. 
The number of labels wil be overwritten to 2.\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)\n", 71 | "custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "id": "fdadc85b-768d-46c7-aa26-01b1f93fd635", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# create some dummy data\n", 82 | "inputs = np.random.randn(2, 2000)\n", 83 | "decoder_input_ids = np.arange(100).reshape(2,50)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "id": "caaff775-e082-419f-b737-b392667469d5", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# get ground-truth outputs from Transformers 🤗 model\n", 94 | "hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "id": "1ff2be6b-73ef-49fc-8dbe-e933b5d4f8b0", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "extract_features = custom_model.encode(inputs, output_features=True)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 7, 110 | "id": "0e2b2d35-dee8-4c47-be35-e543b0f9ef03", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "custom_outputs = custom_model(inputs, extract_features=extract_features, decoder_input_ids=decoder_input_ids, output_hidden_states=True)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 8, 120 | "id": "aefb168b-4882-41f1-ba88-39f298835dac", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# define a helper function for our analysis\n", 125 | "def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 
1e-9):\n", 126 | " diff = np.abs((a - b)).max()\n", 127 | " if diff <= tol:\n", 128 | " print(f\"✅ Difference between Flax and PyTorch is {diff} (< {tol})\")\n", 129 | " else:\n", 130 | " print(f\"❌ Difference between Flax and PyTorch is {diff} (>= {tol})\")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "id": "dea2f7c9-a462-40c1-b09b-7c8499318e87", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "--------------------------Checking encoder hidden states match--------------------------\n", 144 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 145 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 146 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 147 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 148 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 149 | "--------------------------Checking encoder last hidden states match--------------------------\n", 150 | "HF output shape: (2, 29, 16), custom output shape: (2, 29, 16)\n", 151 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 152 | "--------------------------Checking decoder hidden states match--------------------------\n", 153 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 154 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 155 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n", 156 | "--------------------------Checking logits match--------------------------\n", 157 | "HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)\n", 158 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "print(\"--------------------------Checking encoder hidden states match--------------------------\")\n", 164 | "for hf_state, custom_state in zip(hf_outputs.encoder_hidden_states, custom_outputs.encoder_hidden_states):\n", 165 | " assert 
hf_state.shape == custom_state.shape\n", 166 | " assert_almost_equals(hf_state, custom_state)\n", 167 | "\n", 168 | "print(\"--------------------------Checking encoder last hidden states match--------------------------\")\n", 169 | "print(f\"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, custom output shape: {custom_outputs.encoder_last_hidden_state.shape}\")\n", 170 | "assert_almost_equals(hf_outputs.encoder_last_hidden_state, custom_outputs.encoder_last_hidden_state)\n", 171 | "\n", 172 | "print(\"--------------------------Checking decoder hidden states match--------------------------\")\n", 173 | "for hf_state, custom_state in zip(hf_outputs.decoder_hidden_states, custom_outputs.decoder_hidden_states):\n", 174 | " assert hf_state.shape == custom_state.shape\n", 175 | " assert_almost_equals(hf_state, custom_state)\n", 176 | "\n", 177 | "print(\"--------------------------Checking logits match--------------------------\")\n", 178 | "print(f\"HF logits shape: {hf_outputs.logits.shape}, Custom logits shape: {custom_outputs.logits.shape}\")\n", 179 | "assert_almost_equals(hf_outputs.logits, custom_outputs.logits)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "36f0a6f8-fc48-4326-bbaf-4ab5194f84a6", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3 (ipykernel)", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.8.9" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 5 212 | } 213 | --------------------------------------------------------------------------------