├── seq2seq.png
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── seq2seq-speech.iml
│   ├── git_toolbox_prj.xml
│   ├── sshConfigs.xml
│   ├── deployment.xml
│   └── webServers.xml
├── models
│   ├── __init__.py
│   ├── configuration_speech_encoder_decoder.py
│   ├── configuration_bart.py
│   └── configuration_wav2vec2.py
├── scripts
│   ├── ctc_ngram
│   │   ├── run_ctc_ngram_ami.sh
│   │   ├── run_ctc_ngram_tedlium.sh
│   │   ├── run_ctc_ngram_kensho.sh
│   │   ├── run_ctc_ngram_gs.sh
│   │   ├── run_ctc_ngram_earnings22.sh
│   │   ├── run_ctc_ngram_voxpopuli.sh
│   │   ├── run_ctc_ngram_cv9.sh
│   │   ├── run_ctc_ngram_librispeech.sh
│   │   └── run_ctc_ngram_dummy.sh
│   ├── whisper
│   │   ├── run_tedlium.sh
│   │   ├── run_gigaspeech.sh
│   │   ├── run_spgispeech.sh
│   │   ├── run_earnings22.sh
│   │   ├── run_whisper_voxpopuli.sh
│   │   ├── run_ami.sh
│   │   ├── run_chime.sh
│   │   ├── run_switchboard.sh
│   │   ├── run_common_voice_9.sh
│   │   └── run_librispeech.sh
│   ├── ctc
│   │   ├── run_gigaspeech.sh
│   │   ├── run_spgispeech.sh
│   │   ├── run_voxpopuli.sh
│   │   ├── run_common_voice_9.sh
│   │   ├── run_tedlium.sh
│   │   ├── run_ami.sh
│   │   ├── run_switchboard.sh
│   │   ├── run_librispeech.sh
│   │   └── run_earnings22.sh
│   └── seq2seq
│       ├── run_ami.sh
│       ├── run_gigaspeech.sh
│       ├── run_spgispeech.sh
│       ├── run_common_voice_9.sh
│       ├── run_tedlium.sh
│       ├── run_switchboard.sh
│       ├── run_voxpopuli.sh
│       ├── run_librispeech.sh
│       └── run_earnings22.sh
├── run_ctc_dummy.sh
├── run_whisper_dummy.sh
├── run_seq2seq_dummy.sh
├── tests
│   ├── check_flax_ctc.py
│   ├── check_flax_ctc_cv9.py
│   ├── check_save_feature_encoder_output.ipynb
│   └── .ipynb_checkpoints
│       └── check_save_feature_encoder_output-checkpoint.ipynb
├── .gitignore
├── README.md
├── get_ctc_tokenizer.py
└── get_ctc_ngram.py
/seq2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanchit-gandhi/seq2seq-speech/HEAD/seq2seq.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/seq2seq-speech.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.configuration_bart import BartConfig
2 | from models.configuration_wav2vec2 import Wav2Vec2Config
3 | from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
4 | from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
5 | from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
6 | from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
7 |
--------------------------------------------------------------------------------
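
A hypothetical usage sketch (not a file in this repository): loading the custom Flax encoder-decoder exported by models/__init__.py. It assumes these classes follow the standard transformers from_pretrained API and reuses the checkpoint name that the seq2seq scripts below pass as --model_name_or_path.

# Hypothetical sketch; checkpoint name taken from the seq2seq scripts.
from models import FlaxSpeechEncoderDecoderModel, SpeechEncoderDecoderConfig

config = SpeechEncoderDecoderConfig.from_pretrained(
    "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan"
)
model = FlaxSpeechEncoderDecoderModel.from_pretrained(
    "sanchit-gandhi/flax-wav2vec2-2-bart-large-scan", config=config
)
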
/.idea/git_toolbox_prj.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_ami.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ami-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/ami" \
5 | --decoder_name="/home/patrick/ngrams/ami" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="speech-seq2seq/ami" \
8 | --dataset_config_name="ihm" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column="text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/ami/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --use_auth_token
18 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_tedlium.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-tedlium-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/tedlium" \
5 | --decoder_name="/home/patrick/ngrams/tedlium" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="LIUM/tedlium" \
8 | --dataset_config_name="release3" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column="text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/tedlium/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --use_auth_token
18 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_kensho.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-spgispeech-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/spgispeech" \
5 | --decoder_name="/home/patrick/ngrams/spgispeech" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="kensho/spgispeech" \
8 | --dataset_config_name="L" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column_name="transcript" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/spgispeech/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --use_auth_token
18 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_gs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-gs-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/gigaspeech" \
5 | --decoder_name="/home/patrick/ngrams/gigaspeech" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="speechcolab/gigaspeech" \
8 | --dataset_config_name="l" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column_name="text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/gigaspeech/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --use_auth_token \
18 | --do_lower_case
19 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_earnings22.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-earnings22-cased-hidden-activation-featproj-dropout-0.2" \
4 | --tokenizer_name="/home/patrick/ngrams/earnings22_robust_split" \
5 | --decoder_name="/home/patrick/ngrams/earnings22_robust_split" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="sanchit-gandhi/earnings22_robust_split" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column="sentence" \
11 | --preprocessing_num_workers="1" \
12 | --output_dir="/home/patrick/ngrams/earnings22_robust_split/evaluation" \
13 | --do_eval \
14 | --do_predict \
15 | --overwrite_output_dir \
16 | --use_auth_token
17 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_voxpopuli.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/voxpopuli" \
5 | --decoder_name="/home/patrick/ngrams/voxpopuli" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="polinaeterna/voxpopuli" \
8 | --dataset_config_name="en" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column_name="normalized_text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/voxpopuli/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --max_label_length=2056 \
18 | --use_auth_token \
19 | --do_lower_case
20 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_cv9.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-cv9-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/common_voice_9_0" \
5 | --decoder_name="/home/patrick/ngrams/common_voice_9_0" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="mozilla-foundation/common_voice_9_0" \
8 | --dataset_config_name="en" \
9 | --eval_split_name="validation" \
10 | --test_split_name="test" \
11 | --text_column_name="sentence" \
12 | --preprocessing_num_workers="1" \
13 | --max_eval_duration_in_seconds="20.0" \
14 | --output_dir="/home/patrick/ngrams/common_voice_9_0/evaluation" \
15 | --do_eval \
16 | --do_predict \
17 | --overwrite_output_dir \
18 | --use_auth_token
19 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_librispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-black-box" \
4 | --tokenizer_name="/home/patrick/ngrams/librispeech_asr" \
5 | --decoder_name="/home/patrick/ngrams/librispeech_asr" \
6 | --dataset_cache_dir="/home/patrick/.cache/huggingface/datasets" \
7 | --dataset_name="librispeech_asr" \
8 | --dataset_config_name="all" \
9 | --eval_split_name="validation.clean" \
10 | --test_split_name="validation.other+test.clean+test.other" \
11 | --text_column="text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="/home/patrick/ngrams/librispeech_asr/evaluation" \
14 | --do_eval \
15 | --do_predict \
16 | --overwrite_output_dir \
17 | --use_auth_token \
18 | --do_lower_case
19 |
--------------------------------------------------------------------------------
/scripts/ctc_ngram/run_ctc_ngram_dummy.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc_ngram.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline" \
4 | --tokenizer_name="sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline" \
5 | --decoder_name="patrickvonplaten/wav2vec2-base-100h-with-lm" \
6 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
7 | --dataset_name="librispeech_asr" \
8 | --dataset_config_name="all" \
9 | --eval_split_name="validation.clean" \
10 | --test_split_name="validation.other+test.clean+test.other" \
11 | --text_column="text" \
12 | --preprocessing_num_workers="1" \
13 | --output_dir="./ngram_output_dir" \
14 | --max_steps="50000" \
15 | --eval_steps="10000" \
16 | --save_steps="10000" \
17 | --wandb_project="librispeech_960h" \
18 | --wandb_name="flax-wav2vec2-ctc-ls-960h-with-lm-baseline" \
19 | --do_eval \
20 | --do_predict \
21 | --overwrite_output_dir \
22 | --use_auth_token
23 |
--------------------------------------------------------------------------------
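
The ctc_ngram scripts above evaluate CTC checkpoints with an n-gram language model folded into decoding via the --decoder_name repo. A minimal standalone sketch of that idea (not how run_flax_speech_recognition_ctc_ngram.py is wired internally), using the LM-equipped decoder repo from the dummy script and the transformers Wav2Vec2ProcessorWithLM class:

# Hypothetical sketch of n-gram boosted CTC decoding, shown in isolation.
import numpy as np
from transformers import Wav2Vec2ProcessorWithLM

processor = Wav2Vec2ProcessorWithLM.from_pretrained(
    "patrickvonplaten/wav2vec2-base-100h-with-lm"
)

def decode_with_lm(logits: np.ndarray):
    # logits: (batch, time, vocab) CTC emissions from the acoustic model;
    # batch_decode runs pyctcdecode beam search with the bundled KenLM model.
    return processor.batch_decode(logits).text
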
/run_ctc_dummy.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
5 | --dataset_name="librispeech_asr" \
6 | --dataset_config_name="clean" \
7 | --train_split_name="train.100" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="text" \
11 | --output_dir="./" \
12 | --wandb_project="librispeech_asr" \
13 | --wandb_name="flax-wav2vec2-ctc-ls-100h" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="500" \
21 | --preprocessing_num_workers="1" \
22 | --do_train \
23 | --do_eval \
24 | --do_predict \
25 | --overwrite_output_dir \
26 | --gradient_checkpointing \
27 | --freeze_feature_encoder \
28 | --push_to_hub \
29 | --use_auth_token
30 |
--------------------------------------------------------------------------------
/scripts/whisper/run_tedlium.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="LIUM/tedlium" \
5 | --dataset_config_name="release3" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --max_steps="2500" \
11 | --output_dir="./" \
12 | --run_name="whisper-tedlium" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="500" \
23 | --save_strategy="steps" \
24 | --save_steps="500" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="True" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
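
The whisper scripts pass --model_name_or_path values such as "medium.en", which match OpenAI's Whisper checkpoint names; run_speech_recognition_whisper.py fine-tunes those checkpoints through its own training loop. For orientation only, a standalone inference sketch with the upstream openai-whisper package (an assumption about where the names come from; "audio.wav" is a placeholder path):

# Hypothetical plain-inference sketch, separate from the fine-tuning scripts.
import whisper

model = whisper.load_model("medium.en")
result = model.transcribe("audio.wav")
print(result["text"])
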
/scripts/whisper/run_gigaspeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="speechcolab/gigaspeech" \
5 | --dataset_config_name="m" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --max_steps="5000" \
11 | --output_dir="./" \
12 | --run_name="whisper-gigaspeech-5k" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="1000" \
23 | --save_strategy="steps" \
24 | --save_steps="1000" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="True" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/whisper/run_spgispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=2 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="kensho/spgispeech" \
5 | --dataset_config_name="M" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="transcript" \
10 | --max_steps="5000" \
11 | --output_dir="./" \
12 | --run_name="whisper-spgispeech-5k" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="1000" \
23 | --save_strategy="steps" \
24 | --save_steps="1000" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="False" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/whisper/run_earnings22.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=4 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="sanchit-gandhi/earnings22_split" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="sentence" \
10 | --max_steps="2500" \
11 | --output_dir="./" \
12 | --run_name="whisper-earnings22" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="500" \
23 | --save_strategy="steps" \
24 | --save_steps="500" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="False" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/whisper/run_whisper_voxpopuli.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="polinaeterna/voxpopuli" \
5 | --dataset_config_name="en" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="normalized_text" \
10 | --max_steps="5000" \
11 | --output_dir="./" \
12 | --run_name="whisper-voxpopuli-5k" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="500" \
23 | --save_strategy="steps" \
24 | --save_steps="500" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="True" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/ctc/run_gigaspeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-gs-black-box-tokenizer" \
5 | --dataset_name="speechcolab/gigaspeech" \
6 | --dataset_config_name="l" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="text" \
11 | --output_dir="./flax-wav2vec2-ctc-gs-black-box" \
12 | --wandb_project="gigaspeech" \
13 | --wandb_name="flax-wav2vec2-ctc-gs-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --do_train \
23 | --do_eval \
24 | --do_predict \
25 | --overwrite_output_dir \
26 | --gradient_checkpointing \
27 | --freeze_feature_encoder \
28 | --push_to_hub \
29 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/whisper/run_ami.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=3 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="speech-seq2seq/ami" \
5 | --dataset_config_name="ihm" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --max_steps="2500" \
11 | --output_dir="./" \
12 | --run_name="whisper-ami-dropout-0.1" \
13 | --dropout_rate="0.1" \
14 | --wandb_project="whisper" \
15 | --per_device_train_batch_size="64" \
16 | --per_device_eval_batch_size="16" \
17 | --logging_steps="25" \
18 | --learning_rate="1e-4" \
19 | --warmup_steps="500" \
20 | --report_to="wandb" \
21 | --preprocessing_num_workers="16" \
22 | --evaluation_strategy="steps" \
23 | --eval_steps="500" \
24 | --save_strategy="steps" \
25 | --save_steps="500" \
26 | --generation_max_length="224" \
27 | --length_column_name="input_lengths" \
28 | --do_lower_case="False" \
29 | --push_to_hub="False" \
30 | --gradient_checkpointing \
31 | --group_by_length \
32 | --freeze_encoder \
33 | --fp16 \
34 | --overwrite_output_dir \
35 | --do_train \
36 | --do_eval \
37 | --do_predict \
38 | --predict_with_generate \
39 | --use_auth_token
40 |
--------------------------------------------------------------------------------
/run_whisper_dummy.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES="" python run_speech_recognition_whisper.py \
3 | --model_name_or_path="tiny.en" \
4 | --dataset_name="patrickvonplaten/librispeech_asr_dummy" \
5 | --num_train_epochs="2" \
6 | --evaluation_strategy="epoch" \
7 | --dataset_config_name="clean" \
8 | --dataset_cache_dir="/home/patrick_huggingface_co/hey" \
9 | --train_split_name="validation[:32]" \
10 | --eval_split_name="validation" \
11 | --test_split_name="validation[:90%]" \
12 | --text_column_name="text" \
13 | --output_dir="./output_dir" \
14 | --run_name="whisper-ls-dummy" \
15 | --wandb_project="whisper-dummy" \
16 | --per_device_train_batch_size="8" \
17 | --per_device_eval_batch_size="4" \
18 | --logging_steps="1" \
19 | --learning_rate="1e-4" \
20 | --warmup_steps="3" \
21 | --report_to="wandb" \
22 | --push_to_hub="False" \
23 | --preprocessing_num_workers="4" \
24 | --evaluation_strategy="epoch" \
25 | --max_eval_samples="8" \
26 | --max_predict_samples="8" \
27 | --length_column_name="input_lengths" \
28 | --save_strategy="no" \
29 | --group_by_length \
30 | --overwrite_output_dir \
31 | --freeze_encoder \
32 | --do_eval \
33 | --do_predict \
34 | --do_train \
35 | --predict_with_generate \
36 | --generation_max_length=224
37 |
--------------------------------------------------------------------------------
/scripts/whisper/run_chime.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=5 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="speech-seq2seq/chime4-raw" \
5 | --dataset_config_name="1-channel" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --max_steps="2500" \
11 | --output_dir="./" \
12 | --run_name="whisper-chime4-dropout-0.1" \
13 | --dropout_rate="0.1" \
14 | --wandb_project="whisper" \
15 | --per_device_train_batch_size="64" \
16 | --per_device_eval_batch_size="16" \
17 | --logging_steps="25" \
18 | --learning_rate="1e-4" \
19 | --warmup_steps="500" \
20 | --report_to="wandb" \
21 | --preprocessing_num_workers="16" \
22 | --evaluation_strategy="steps" \
23 | --eval_steps="500" \
24 | --save_strategy="steps" \
25 | --save_steps="500" \
26 | --generation_max_length="224" \
27 | --length_column_name="input_lengths" \
28 | --do_lower_case="False" \
29 | --push_to_hub="False" \
30 | --gradient_checkpointing \
31 | --group_by_length \
32 | --freeze_encoder \
33 | --fp16 \
34 | --overwrite_output_dir \
35 | --do_train \
36 | --do_eval \
37 | --do_predict \
38 | --predict_with_generate \
39 | --use_auth_token
40 |
--------------------------------------------------------------------------------
/scripts/whisper/run_switchboard.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=1 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="ldc/switchboard" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train.switchboard" \
7 | --eval_split_name="validation.switchboard" \
8 | --test_split_name="test.switchboard+test.callhome" \
9 | --text_column_name="text" \
10 | --max_steps="5000" \
11 | --output_dir="./" \
12 | --run_name="whisper-switchboard" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="1000" \
23 | --save_strategy="steps" \
24 | --save_steps="1000" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="True" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/whisper/run_common_voice_9.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="mozilla-foundation/common_voice_9_0" \
5 | --dataset_config_name="en" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="sentence" \
10 | --max_steps="5000" \
11 | --output_dir="./" \
12 | --run_name="whisper-cv9" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="1000" \
23 | --save_strategy="steps" \
24 | --save_steps="1000" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="False" \
28 | --push_to_hub="False" \
29 | --max_eval_duration_in_seconds="20" \
30 | --gradient_checkpointing \
31 | --group_by_length \
32 | --freeze_encoder \
33 | --fp16 \
34 | --overwrite_output_dir \
35 | --do_train \
36 | --do_eval \
37 | --do_predict \
38 | --predict_with_generate \
39 | --use_auth_token
40 |
--------------------------------------------------------------------------------
/scripts/ctc/run_spgispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-spgispeech-tokenizer" \
5 | --dataset_name="kensho/spgispeech" \
6 | --dataset_config_name="L" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="transcript" \
11 | --output_dir="./flax-wav2vec2-ctc-spgispeech-baseline" \
12 | --wandb_project="spgispeech" \
13 | --wandb_name="flax-wav2vec2-ctc-spgispeech-baseline" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --do_lower_case="False" \
23 | --do_train \
24 | --do_eval \
25 | --do_predict \
26 | --overwrite_output_dir \
27 | --gradient_checkpointing \
28 | --freeze_feature_encoder \
29 | --push_to_hub \
30 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/whisper/run_librispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python run_speech_recognition_whisper.py \
3 | --model_name_or_path="medium.en" \
4 | --dataset_name="librispeech_asr" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \
7 | --eval_split_name="validation.clean" \
8 | --test_split_name="validation.other+test.clean+test.other" \
9 | --max_steps="5000" \
10 | --text_column_name="text" \
11 | --output_dir="./" \
12 | --run_name="whisper-ls-960h-5k" \
13 | --wandb_project="whisper" \
14 | --per_device_train_batch_size="64" \
15 | --per_device_eval_batch_size="16" \
16 | --logging_steps="25" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --report_to="wandb" \
20 | --preprocessing_num_workers="16" \
21 | --evaluation_strategy="steps" \
22 | --eval_steps="1000" \
23 | --save_strategy="steps" \
24 | --save_steps="1000" \
25 | --generation_max_length="224" \
26 | --length_column_name="input_lengths" \
27 | --do_lower_case="True" \
28 | --push_to_hub="False" \
29 | --gradient_checkpointing \
30 | --group_by_length \
31 | --freeze_encoder \
32 | --fp16 \
33 | --overwrite_output_dir \
34 | --do_train \
35 | --do_eval \
36 | --do_predict \
37 | --predict_with_generate \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/ctc/run_voxpopuli.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \
4 | --tokenizer_name="sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box" \
5 | --dataset_name="polinaeterna/voxpopuli" \
6 | --dataset_config_name="en" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="normalized_text" \
11 | --output_dir="./flax-wav2vec2-ctc-voxpopuli-black-box" \
12 | --wandb_project="voxpopuli" \
13 | --wandb_name="flax-wav2vec2-ctc-voxpopuli-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --per_device_eval_batch_size="1" \
23 | --do_train \
24 | --do_eval \
25 | --do_predict \
26 | --overwrite_output_dir \
27 | --gradient_checkpointing \
28 | --freeze_feature_encoder \
29 | --push_to_hub \
30 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/ctc/run_common_voice_9.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-cv9-black-box-tokenizer" \
5 | --dataset_name="mozilla-foundation/common_voice_9_0" \
6 | --dataset_config_name="en" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="sentence" \
11 | --output_dir="./flax-wav2vec2-ctc-cv9-black-box" \
12 | --wandb_project="common_voice_9_0" \
13 | --wandb_name="flax-wav2vec2-ctc-cv9-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --do_lower_case="False" \
23 | --max_eval_duration_in_seconds="20" \
24 | --do_train \
25 | --do_eval \
26 | --do_predict \
27 | --overwrite_output_dir \
28 | --gradient_checkpointing \
29 | --freeze_feature_encoder \
30 | --push_to_hub \
31 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/ctc/run_tedlium.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-tedlium-black-box-tokenizer" \
5 | --dataset_name="LIUM/tedlium" \
6 | --dataset_config_name="release3" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="text" \
11 | --output_dir="./flax-wav2vec2-ctc-tedlium-black-box" \
12 | --wandb_project="tedlium" \
13 | --wandb_name="flax-wav2vec2-ctc-tedlium-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --hidden_dropout="0.2" \
23 | --activation_dropout="0.2" \
24 | --feat_proj_dropout="0.2" \
25 | --do_train \
26 | --do_eval \
27 | --do_predict \
28 | --overwrite_output_dir \
29 | --gradient_checkpointing \
30 | --freeze_feature_encoder \
31 | --push_to_hub \
32 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/ctc/run_ami.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="patrickvonplaten/wav2vec2_ctc_ami_tokenizer" \
5 | --dataset_name="speech-seq2seq/ami" \
6 | --dataset_config_name="ihm" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="text" \
11 | --output_dir="./flax-wav2vec2-ctc-ami-black-box" \
12 | --wandb_project="ami" \
13 | --wandb_name="flax-wav2vec2-ctc-ami-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --hidden_dropout="0.2" \
23 | --activation_dropout="0.2" \
24 | --feat_proj_dropout="0.2" \
25 | --do_lower_case="False" \
26 | --do_train \
27 | --do_eval \
28 | --do_predict \
29 | --overwrite_output_dir \
30 | --gradient_checkpointing \
31 | --freeze_feature_encoder \
32 | --push_to_hub \
33 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/ctc/run_switchboard.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-switchboard-black-box-tokenizer" \
5 | --dataset_name="ldc/switchboard" \
6 | --dataset_config_name="all" \
7 | --train_split_name="train.fisher+train.switchboard" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test.switchboard+test.callhome" \
10 | --text_column_name="text" \
11 | --output_dir="./flax-wav2vec2-ctc-switchboard-fisher-black-box" \
12 | --wandb_project="switchboard" \
13 | --wandb_name="flax-wav2vec2-ctc-switchboard-fisher-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --do_lower_case="False" \
23 | --torchaudio_resampler="True" \
24 | --do_train \
25 | --do_eval \
26 | --do_predict \
27 | --overwrite_output_dir \
28 | --gradient_checkpointing \
29 | --freeze_feature_encoder \
30 | --push_to_hub \
31 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/seq2seq/run_ami.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="speech-seq2seq/ami" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="ihm" \
6 | --id_column_name="audio_id" \
7 | --output_dir="./flax-wav2vec2-2-bart-large-ami-black-box" \
8 | --wandb_project="ami" \
9 | --wandb_name="flax-wav2vec2-2-bart-large-ami-black-box" \
10 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
11 | --per_device_train_batch_size="8" \
12 | --per_device_eval_batch_size="4" \
13 | --learning_rate="1e-4" \
14 | --warmup_steps="500" \
15 | --logging_steps="25" \
16 | --max_steps="50000" \
17 | --eval_steps="10000" \
18 | --save_steps="10000" \
19 | --generation_max_length="200" \
20 | --generation_num_beams="5" \
21 | --generation_length_penalty="1.2" \
22 | --hidden_dropout="0.2" \
23 | --activation_dropout="0.2" \
24 | --feat_proj_dropout="0.2" \
25 | --do_lower_case="False" \
26 | --overwrite_output_dir \
27 | --gradient_checkpointing \
28 | --freeze_feature_encoder \
29 | --predict_with_generate \
30 | --do_eval \
31 | --do_train \
32 | --do_predict \
33 | --push_to_hub \
34 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/seq2seq/run_gigaspeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="speechcolab/gigaspeech" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="l" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --id_column_name="segment_id" \
11 | --output_dir="./" \
12 | --wandb_project="gigaspeech" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-gs-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --overwrite_output_dir \
27 | --gradient_checkpointing \
28 | --freeze_feature_encoder \
29 | --predict_with_generate \
30 | --do_lower_case \
31 | --do_eval \
32 | --do_train \
33 | --do_predict \
34 | --push_to_hub \
35 | --use_auth_token
36 |
--------------------------------------------------------------------------------
/scripts/ctc/run_librispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-ls-black-box-tokenizer" \
5 | --dataset_name="librispeech_asr" \
6 | --dataset_config_name="all" \
7 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \
8 | --eval_split_name="validation.clean" \
9 | --test_split_name="validation.other+test.clean+test.other" \
10 | --text_column_name="text" \
11 | --output_dir="./flax-wav2vec2-ctc-ls-960h-black-box" \
12 | --wandb_project="librispeech_960h" \
13 | --wandb_name="flax-wav2vec2-ctc-ls-960h-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --hidden_dropout="0.2" \
23 | --activation_dropout="0.2" \
24 | --feat_proj_dropout="0.2" \
25 | --do_train \
26 | --do_eval \
27 | --do_predict \
28 | --overwrite_output_dir \
29 | --gradient_checkpointing \
30 | --freeze_feature_encoder \
31 | --push_to_hub \
32 | --use_auth_token
--------------------------------------------------------------------------------
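
A hypothetical greedy-decoding sketch for a checkpoint produced by the script above (not a repository file). The checkpoint name matches the script's output_dir; whether the pushed Hub repo bundles full processor files is an assumption.

# Hypothetical argmax CTC decoding with the trained LibriSpeech checkpoint.
import numpy as np
from transformers import Wav2Vec2Processor
from models import FlaxWav2Vec2ForCTC

ckpt = "sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-black-box"
model = FlaxWav2Vec2ForCTC.from_pretrained(ckpt)
processor = Wav2Vec2Processor.from_pretrained(ckpt)  # assumes processor files were pushed

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16_000, return_tensors="np")
pred_ids = np.argmax(model(**inputs).logits, axis=-1)
print(processor.batch_decode(pred_ids))
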
/scripts/seq2seq/run_spgispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="kensho/spgispeech" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="L" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="transcript" \
10 | --id_column_name="wav_filename" \
11 | --output_dir="./" \
12 | --wandb_project="spgispeech" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-spgispeech-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --do_lower_case="False" \
27 | --overwrite_output_dir \
28 | --gradient_checkpointing \
29 | --freeze_feature_encoder \
30 | --predict_with_generate \
31 | --do_eval \
32 | --do_train \
33 | --do_predict \
34 | --push_to_hub \
35 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/ctc/run_earnings22.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_ctc.py \
3 | --model_name_or_path="speech-seq2seq/flax-wav2vec2-large-lv60-scan" \
4 | --tokenizer_name="sanchit-gandhi/wav2vec2-ctc-earnings22-black-box-tokenizer" \
5 | --dataset_name="sanchit-gandhi/earnings22" \
6 | --dataset_config_name="all" \
7 | --train_split_name="train" \
8 | --eval_split_name="validation" \
9 | --test_split_name="test" \
10 | --text_column_name="sentence" \
11 | --output_dir="./flax-wav2vec2-ctc-earnings22-black-box" \
12 | --wandb_project="earnings22" \
13 | --wandb_name="flax-wav2vec2-ctc-earnings22-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --max_steps="50000" \
16 | --save_steps="10000" \
17 | --eval_steps="10000" \
18 | --learning_rate="3e-4" \
19 | --logging_steps="25" \
20 | --warmup_steps="5000" \
21 | --preprocessing_num_workers="1" \
22 | --do_lower_case="False" \
23 | --hidden_dropout="0.2" \
24 | --activation_dropout="0.2" \
25 | --feat_proj_dropout="0.2" \
26 | --ignore_verifications="False" \
27 | --do_train \
28 | --do_eval \
29 | --do_predict \
30 | --overwrite_output_dir \
31 | --gradient_checkpointing \
32 | --freeze_feature_encoder \
33 | --push_to_hub \
34 | --use_auth_token
--------------------------------------------------------------------------------
/run_seq2seq_dummy.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="librispeech_asr" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="clean" \
6 | --train_split_name="train.100" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --id_column_name="id" \
11 | --output_dir="./" \
12 | --wandb_project="librispeech_960h" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-baseline" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="4" \
17 | --logging_steps="25" \
18 | --max_steps="50000" \
19 | --eval_steps="10000" \
20 | --save_steps="10000" \
21 | --generation_max_length="40" \
22 | --generation_num_beams="1" \
23 | --generation_length_penalty="1.2" \
24 | --final_generation_max_length="200" \
25 | --final_generation_num_beams="5" \
26 | --learning_rate="1e-4" \
27 | --warmup_steps="500" \
28 | --overwrite_output_dir \
29 | --gradient_checkpointing \
30 | --freeze_feature_encoder \
31 | --predict_with_generate \
32 | --do_lower_case \
33 | --do_eval \
34 | --do_train \
35 | --do_predict \
36 | --push_to_hub \
37 | --use_auth_token
38 |
--------------------------------------------------------------------------------
/scripts/seq2seq/run_common_voice_9.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="mozilla-foundation/common_voice_9_0" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="en" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="sentence" \
10 | --id_column_name="client_id" \
11 | --output_dir="./" \
12 | --wandb_project="common_voice_9_0" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-cv9-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --do_lower_case="False" \
27 | --max_eval_duration_in_seconds="20" \
28 | --overwrite_output_dir \
29 | --gradient_checkpointing \
30 | --freeze_feature_encoder \
31 | --predict_with_generate \
32 | --do_eval \
33 | --do_train \
34 | --do_predict \
35 | --push_to_hub \
36 | --use_auth_token
37 |
--------------------------------------------------------------------------------
/scripts/seq2seq/run_tedlium.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="LIUM/tedlium" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="release3" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="text" \
10 | --id_column_name="id" \
11 | --output_dir="./" \
12 | --wandb_project="tedlium" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-tedlium-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --hidden_dropout="0.2" \
27 | --activation_dropout="0.2" \
28 | --feat_proj_dropout="0.2" \
29 | --overwrite_output_dir \
30 | --gradient_checkpointing \
31 | --freeze_feature_encoder \
32 | --predict_with_generate \
33 | --do_lower_case \
34 | --do_eval \
35 | --do_train \
36 | --do_predict \
37 | --push_to_hub \
38 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/seq2seq/run_switchboard.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="ldc/switchboard" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train.fisher+train.switchboard" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test.switchboard+test.callhome" \
9 | --text_column_name="text" \
10 | --id_column_name="id" \
11 | --output_dir="./flax-wav2vec2-2-bart-large-switchboard-fisher-black-box" \
12 | --wandb_project="switchboard" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-switchboard-fisher-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --torchaudio_resampler="True" \
27 | --overwrite_output_dir \
28 | --gradient_checkpointing \
29 | --freeze_feature_encoder \
30 | --predict_with_generate \
31 | --do_eval \
32 | --do_train \
33 | --do_predict \
34 | --push_to_hub \
35 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/seq2seq/run_voxpopuli.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="google/xtreme_s" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="voxpopuli.en" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="sentence" \
10 | --id_column_name="id" \
11 | --output_dir="./flax-wav2vec2-2-bart-large-voxpopuli-black-box" \
12 | --wandb_project="voxpopuli" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-voxpopuli-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="1" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --hidden_dropout="0.2" \
27 | --activation_dropout="0.2" \
28 | --feat_proj_dropout="0.2" \
29 | --overwrite_output_dir \
30 | --gradient_checkpointing \
31 | --freeze_feature_encoder \
32 | --predict_with_generate \
33 | --do_lower_case \
34 | --do_eval \
35 | --do_train \
36 | --do_predict \
37 | --push_to_hub \
38 | --use_auth_token
--------------------------------------------------------------------------------
/scripts/seq2seq/run_librispeech.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="librispeech_asr" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train.clean.100+train.clean.360+train.other.500" \
7 | --eval_split_name="validation.clean" \
8 | --test_split_name="validation.other+test.clean+test.other" \
9 | --text_column_name="text" \
10 | --id_column_name="id" \
11 | --output_dir="./" \
12 | --wandb_project="librispeech_960h" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-ls-960h-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="2" \
17 | --learning_rate="1e-4" \
18 | --warmup_steps="500" \
19 | --logging_steps="25" \
20 | --max_steps="50000" \
21 | --eval_steps="10000" \
22 | --save_steps="10000" \
23 | --generation_max_length="200" \
24 | --generation_num_beams="5" \
25 | --generation_length_penalty="1.2" \
26 | --hidden_dropout="0.2" \
27 | --activation_dropout="0.2" \
28 | --feat_proj_dropout="0.2" \
29 | --overwrite_output_dir \
30 | --gradient_checkpointing \
31 | --freeze_feature_encoder \
32 | --predict_with_generate \
33 | --do_lower_case \
34 | --do_eval \
35 | --do_train \
36 | --do_predict \
37 | --push_to_hub \
38 | --use_auth_token
39 |
--------------------------------------------------------------------------------
/scripts/seq2seq/run_earnings22.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_flax_speech_recognition_seq2seq.py \
3 | --dataset_name="sanchit-gandhi/earnings22" \
4 | --model_name_or_path="sanchit-gandhi/flax-wav2vec2-2-bart-large-scan" \
5 | --dataset_config_name="all" \
6 | --train_split_name="train" \
7 | --eval_split_name="validation" \
8 | --test_split_name="test" \
9 | --text_column_name="sentence" \
10 | --id_column_name="id" \
11 | --output_dir="./flax-wav2vec2-2-bart-large-earnings22-black-box" \
12 | --wandb_project="earnings22" \
13 | --wandb_name="flax-wav2vec2-2-bart-large-earnings22-black-box" \
14 | --dataset_cache_dir="/home/sanchitgandhi/cache/huggingface/datasets" \
15 | --per_device_train_batch_size="8" \
16 | --per_device_eval_batch_size="4" \
17 | --logging_steps="25" \
18 | --max_steps="50000" \
19 | --eval_steps="10000" \
20 | --save_steps="10000" \
21 | --generation_max_length="40" \
22 | --generation_num_beams="1" \
23 | --generation_length_penalty="1.2" \
24 | --final_generation_max_length="200" \
25 | --final_generation_num_beams="5" \
26 | --learning_rate="1e-4" \
27 | --warmup_steps="500" \
28 | --do_lower_case="False" \
29 | --hidden_dropout="0.2" \
30 | --activation_dropout="0.2" \
31 | --feat_proj_dropout="0.2" \
32 | --ignore_verifications="False" \
33 | --overwrite_output_dir \
34 | --gradient_checkpointing \
35 | --freeze_feature_encoder \
36 | --predict_with_generate \
37 | --do_eval \
38 | --do_train \
39 | --do_predict \
40 | --push_to_hub \
41 | --use_auth_token
--------------------------------------------------------------------------------
/tests/check_flax_ctc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import tempfile
3 | import numpy as np
4 | import torch
5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6 | from models.modeling_flax_wav2vec2 import FlaxWav2Vec2ForCTC
7 | import jax.numpy as jnp
8 | from datasets import load_dataset
9 | from run_flax_speech_recognition_ctc import ctc_loss
10 |
11 | model_id = "facebook/wav2vec2-large-lv60"
12 | # model_id = "hf-internal-testing/tiny-random-wav2vec2"
13 |
14 | processor = Wav2Vec2Processor.from_pretrained(model_id, return_attention_mask=True)
15 | # in PyTorch we always use 'mean' by default.
16 | # See: https://github.com/huggingface/transformers/blob/93b802c43e70f41edfb166ad6ead79de95a26c32/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L126
17 | model_pt = Wav2Vec2ForCTC.from_pretrained(model_id, ctc_loss_reduction="mean")
18 |
19 | with tempfile.TemporaryDirectory() as temp_folder:
20 | model_pt.save_pretrained(temp_folder)
21 | model_fx = FlaxWav2Vec2ForCTC.from_pretrained(temp_folder, from_pt=True)
22 |
23 | # load dummy dataset and read soundfiles
24 | ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
25 |
26 | # inputs pt
27 | samples = [d["array"] for d in ds[:4]["audio"]]
28 | inputs_pt = processor(samples, return_tensors="pt", padding="longest", sampling_rate=16_000)
29 |
30 | # inputs fx
31 | inputs_fx = processor(samples, return_tensors="np", padding="longest", sampling_rate=16_000)
32 |
33 | # labels
34 | transcription = ds[:4]["text"]
35 | labels_pt = processor.tokenizer(transcription, return_tensors="pt", padding="longest")
36 | labels_ids_pt = labels_pt["input_ids"].masked_fill(labels_pt.attention_mask.ne(1), -100)
37 |
38 | labels_fx = processor.tokenizer(transcription, return_tensors="np", padding="longest")
39 | labels_ids_fx = jnp.where(labels_fx.attention_mask == 0, -100, labels_fx.input_ids)
40 |
41 | # pytorch
42 | with torch.no_grad():
43 | outputs = model_pt(**inputs_pt, labels=labels_ids_pt)
44 |
45 | logits_pt = outputs.logits
46 | loss_pt = outputs.loss
47 |
48 | # flax
49 | logits_fx = model_fx(**inputs_fx).logits
50 | logits_attention_mask = model_fx._get_feature_vector_attention_mask(logits_fx.shape[1], inputs_fx.attention_mask)
51 |
52 | # Check logits the same
53 | logits_diff = np.abs((logits_pt.detach().numpy() - np.asarray(logits_fx))).max()
54 | assert logits_diff < 1e-3, "Logits don't match"
55 |
56 | # flax loss
57 | blank_id = model_fx.config.pad_token_id
58 | loss_fx = ctc_loss(logits_fx, logits_attention_mask, labels_ids_fx, blank_id)
59 |
60 | # Check loss is the same
61 | loss_diff = np.abs(np.asarray(loss_fx) - loss_pt.numpy())
62 | assert loss_diff < 1e-3, "Loss doesn't match"
63 |
--------------------------------------------------------------------------------
/.idea/sshConfigs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/tests/check_flax_ctc_cv9.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import tempfile
3 | import numpy as np
4 | import torch
5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoTokenizer, AutoFeatureExtractor
6 | from models.modeling_flax_wav2vec2 import FlaxWav2Vec2ForCTC
7 | import jax.numpy as jnp
8 | from datasets import load_dataset
9 | import datasets
10 | from run_flax_speech_recognition_ctc import ctc_loss
11 |
12 | model_id = "facebook/wav2vec2-large-lv60"
13 | # model_id = "hf-internal-testing/tiny-random-wav2vec2"
14 | tokenizer_id = "patrickvonplaten/wav2vec2_ctc_cv9_tokenizer"
15 |
16 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
17 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_id if tokenizer_id else model_id, return_attention_mask=True)
18 |
19 | # in PyTorch we always use 'mean' by default.
20 | # See: https://github.com/huggingface/transformers/blob/93b802c43e70f41edfb166ad6ead79de95a26c32/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L126
21 | model_pt = Wav2Vec2ForCTC.from_pretrained(model_id, ctc_loss_reduction="mean")
22 |
23 | with tempfile.TemporaryDirectory() as temp_folder:
24 | model_pt.save_pretrained(temp_folder)
25 | model_fx = FlaxWav2Vec2ForCTC.from_pretrained(temp_folder, from_pt=True)
26 |
27 |
28 | # load CV9 dataset and read soundfiles
29 | ds = load_dataset("mozilla-foundation/common_voice_9_0", "en", split="train[:1%]", cache_dir="/home/sanchitgandhi/cache/huggingface/datasets/")
30 |
31 | # resample dataset to 16kHz on the fly
32 | ds = ds.cast_column(
33 | "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
34 | )
35 |
36 | # inputs pt
37 | samples = [d["array"] for d in ds[:4]["audio"]]
38 | inputs_pt = feature_extractor(samples, return_tensors="pt", padding="longest", sampling_rate=16_000)
39 |
40 | # inputs fx
41 | inputs_fx = feature_extractor(samples, return_tensors="np", padding="longest", sampling_rate=16_000)
42 |
43 | # labels
44 | transcription = ds[:4]["sentence"]
45 | labels_pt = tokenizer(transcription, return_tensors="pt", padding="longest")
46 | labels_ids_pt = labels_pt["input_ids"].masked_fill(labels_pt.attention_mask.ne(1), -100)
47 |
48 | labels_fx = tokenizer(transcription, return_tensors="np", padding="longest")
49 | labels_ids_fx = jnp.where(labels_fx.attention_mask == 0, -100, labels_fx.input_ids)
50 |
51 | # set the pt model config to accommodate the CV9 tokenizer vocab size
52 | model_pt.config.vocab_size = tokenizer.vocab_size
53 |
54 | # pytorch
55 | with torch.no_grad():
56 | outputs = model_pt(**inputs_pt, labels=labels_ids_pt)
57 |
58 | logits_pt = outputs.logits
59 | loss_pt = outputs.loss
60 |
61 | # flax
62 | logits_fx = model_fx(**inputs_fx).logits
63 | logits_attention_mask = model_fx._get_feature_vector_attention_mask(logits_fx.shape[1], inputs_fx.attention_mask)
64 |
65 | # Check logits the same
66 | logits_diff = np.abs((logits_pt.detach().numpy() - np.asarray(logits_fx))).max()
67 | assert logits_diff < 1e-3, "Logits don't match"
68 |
69 | # flax loss
70 | blank_id = model_fx.config.pad_token_id
71 | loss_fx = ctc_loss(logits_fx, logits_attention_mask, labels_ids_fx, blank_id)
72 |
73 | # Check loss is the same
74 | loss_diff = np.abs(np.asarray(loss_fx) - loss_pt.numpy())
75 | assert loss_diff < 1e-3, "Loss doesn't match"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 |
9 | # Created by https://www.toptal.com/developers/gitignore/api/intellij
10 | # Edit at https://www.toptal.com/developers/gitignore?templates=intellij
11 |
12 | ### Intellij ###
13 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
14 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
15 |
16 | # User-specific stuff
17 | .idea/**/workspace.xml
18 | .idea/**/tasks.xml
19 | .idea/**/usage.statistics.xml
20 | .idea/**/dictionaries
21 | .idea/**/shelf
22 |
23 | # AWS User-specific
24 | .idea/**/aws.xml
25 |
26 | # Generated files
27 | .idea/**/contentModel.xml
28 |
29 | # Sensitive or high-churn files
30 | .idea/**/dataSources/
31 | .idea/**/dataSources.ids
32 | .idea/**/dataSources.local.xml
33 | .idea/**/sqlDataSources.xml
34 | .idea/**/dynamic.xml
35 | .idea/**/uiDesigner.xml
36 | .idea/**/dbnavigator.xml
37 |
38 | # Gradle
39 | .idea/**/gradle.xml
40 | .idea/**/libraries
41 |
42 | # Gradle and Maven with auto-import
43 | # When using Gradle or Maven with auto-import, you should exclude module files,
44 | # since they will be recreated, and may cause churn. Uncomment if using
45 | # auto-import.
46 | # .idea/artifacts
47 | # .idea/compiler.xml
48 | # .idea/jarRepositories.xml
49 | # .idea/modules.xml
50 | # .idea/*.iml
51 | # .idea/modules
52 | # *.iml
53 | # *.ipr
54 |
55 | # CMake
56 | cmake-build-*/
57 |
58 | # Mongo Explorer plugin
59 | .idea/**/mongoSettings.xml
60 |
61 | # File-based project format
62 | *.iws
63 |
64 | # IntelliJ
65 | out/
66 |
67 | # mpeltonen/sbt-idea plugin
68 | .idea_modules/
69 |
70 | # JIRA plugin
71 | atlassian-ide-plugin.xml
72 |
73 | # Cursive Clojure plugin
74 | .idea/replstate.xml
75 |
76 | # SonarLint plugin
77 | .idea/sonarlint/
78 |
79 | # Crashlytics plugin (for Android Studio and IntelliJ)
80 | com_crashlytics_export_strings.xml
81 | crashlytics.properties
82 | crashlytics-build.properties
83 | fabric.properties
84 |
85 | # Editor-based Rest Client
86 | .idea/httpRequests
87 |
88 | # Android studio 3.1+ serialized cache file
89 | .idea/caches/build_file_checksums.ser
90 |
91 | ### Intellij Patch ###
92 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
93 |
94 | # *.iml
95 | # modules.xml
96 | # .idea/misc.xml
97 | # *.ipr
98 |
99 | # Sonarlint plugin
100 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
101 | .idea/**/sonarlint/
102 |
103 | # SonarQube Plugin
104 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
105 | .idea/**/sonarIssues.xml
106 |
107 | # Markdown Navigator plugin
108 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
109 | .idea/**/markdown-navigator.xml
110 | .idea/**/markdown-navigator-enh.xml
111 | .idea/**/markdown-navigator/
112 |
113 | # Cache file creation bug
114 | # See https://youtrack.jetbrains.com/issue/JBR-2257
115 | .idea/$CACHE_FILE$
116 |
117 | # CodeStream plugin
118 | # https://plugins.jetbrains.com/plugin/12206-codestream
119 | .idea/codestream.xml
120 |
121 | # End of https://www.toptal.com/developers/gitignore/api/intellij
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Seq2Seq Speech in JAX
2 | A [JAX](https://jax.readthedocs.io/en/latest/)/[Flax](https://flax.readthedocs.io/en/latest/) repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model (e.g. GPT2, Bart) to yield a Speech Sequence-to-Sequence (Seq2Seq) model for automatic speech recognition.
3 |
4 | The script `run_flax_speech_recognition_seq2seq.py` can be used to fine-tune a Speech Seq2Seq model on one of the official speech recognition datasets or a custom dataset. It makes use of the `pmap` JAX operator to provide data parallelism across GPU/TPU devices (a toy `pmap` sketch is included just after this README).
5 |
6 | The modelling files are based very heavily on those from Hugging Face [Transformers 🤗](https://github.com/huggingface/transformers). This is a standalone repository to enable rapid prototyping and involvement with the community. The final modelling files and training script will be merged into Transformers 🤗 to be used with the rest of the open-source library. The final system weights will be made publicly available at [huggingface.co](https://huggingface.co) 🚀
7 |
8 | 
9 | **Figure 1:** Speech-encoder text-decoder style Seq2Seq model.
10 |
11 | ## Example Usage
12 | To instantiate a _Wav2Vec2-2-Bart_ model with the `FlaxSpeechEncoderDecoderModel` framework, run the following Python script inside the cloned repo:
13 | ```python
14 | from transformers import AutoFeatureExtractor, AutoTokenizer
15 | from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
16 | import numpy as np
17 |
18 | # checkpoints to leverage
19 | encoder_id = "facebook/wav2vec2-large-lv60"
20 | decoder_id = "facebook/bart-large"
21 |
22 | model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
23 | encoder_id, decoder_id, encoder_add_adapter=True, decoder_from_pt=True)
24 |
25 | model.config.decoder_start_token_id = model.config.decoder.bos_token_id
26 | model.config.pad_token_id = model.config.decoder.pad_token_id
27 | model.config.eos_token_id = model.config.decoder.eos_token_id
28 | model.config.use_cache = False
29 | model.config.processor_class = "Wav2Vec2Processor"
30 |
31 | # check if generation works
32 | out = model.generate(np.ones((1, 2000)))
33 |
34 | model.save_pretrained("./")
35 |
36 | feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
37 | feature_extractor.save_pretrained("./")
38 | tokenizer = AutoTokenizer.from_pretrained(decoder_id)
39 | tokenizer.save_pretrained("./")
40 | ```
41 |
42 | To train the model on [Librispeech ASR](https://huggingface.co/datasets/librispeech_asr), run the template bash script [`run_seq2seq_dummy.sh`](https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_seq2seq_dummy.sh).
43 |
44 | ## Flax Whisper Model
45 |
46 | ```bash
47 | #!/usr/bin/env bash
48 | python run_flax_speech_recognition_seq2seq.py \
49 | --dataset_name="librispeech_asr" \
50 | --model_name_or_path="openai/whisper-small" \
51 | --dataset_config_name="clean" \
52 | --train_split_name="train.100" \
53 | --eval_split_name="validation" \
54 | --test_split_name="test" \
55 | --text_column_name="text" \
56 | --id_column_name="id" \
57 | --output_dir="./flax-whisper-ft-librispeech-clean" \
58 | --wandb_project="librispeech_clean" \
59 | --wandb_name="flax-whisper-ft-librispeech-clean" \
60 | --per_device_train_batch_size="8" \
61 | --per_device_eval_batch_size="2" \
62 | --learning_rate="1e-4" \
63 | --warmup_steps="500" \
64 | --logging_steps="25" \
65 | --max_steps="50000" \
66 | --eval_steps="10000" \
67 | --save_steps="10000" \
68 | --generation_max_length="200" \
69 | --generation_num_beams="5" \
70 | --generation_length_penalty="1.2" \
71 | --hidden_dropout="0.2" \
72 | --activation_dropout="0.2" \
73 | --feat_proj_dropout="0.2" \
74 | --overwrite_output_dir \
75 | --gradient_checkpointing \
76 | --freeze_feature_encoder \
77 | --predict_with_generate \
78 | --do_lower_case \
79 | --do_eval \
80 | --do_train \
81 | --do_predict \
82 | --push_to_hub \
83 | --use_auth_token
84 | ```
85 |
86 | Control the precision through the `--precision` arg:
87 | * Full precision (weights and optimiser in fp32): `--precision=full`
88 | * Half-mixed (weights in bf16, optimiser in fp32): `--precision=half_mixed`
89 | * Full-mixed (weights and optimiser in bf16): `--precision=full_mixed`
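90 | 
91 | As a rough, standalone illustration (not the training script's exact implementation), "half-mixed" amounts to running the model in bf16 while the optimiser keeps fp32 state. A minimal sketch, assuming [Optax](https://github.com/deepmind/optax) is installed:
92 | ```python
93 | import jax
94 | import jax.numpy as jnp
95 | import optax
96 | 
97 | # fp32 "master" parameters
98 | params = {"w": jnp.ones((4, 4), dtype=jnp.float32)}
99 | 
100 | # bf16 copy of the weights used for the forward/backward pass
101 | params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
102 | 
103 | # optimiser state initialised from the fp32 parameters, so it stays in fp32
104 | opt_state = optax.adamw(learning_rate=1e-4).init(params)
105 | ```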
--------------------------------------------------------------------------------
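A toy, standalone sketch of the `pmap` data parallelism mentioned in the README above (illustrative only, not code taken from `run_flax_speech_recognition_seq2seq.py`): a training step is replicated across all local devices, the batch is sharded along the leading axis, and gradients are averaged with `jax.lax.pmean`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    preds = batch["inputs"] @ params["w"]
    return jnp.mean((preds - batch["labels"]) ** 2)


@partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # average gradients across devices so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name="batch")
    return loss, grads


n_devices = jax.local_device_count()
# replicate the parameters and shard the (toy) batch along the leading device axis
params = {"w": jnp.stack([jnp.ones((8, 2))] * n_devices)}
batch = {
    "inputs": jnp.ones((n_devices, 4, 8)),
    "labels": jnp.zeros((n_devices, 4, 2)),
}
loss, grads = train_step(params, batch)
```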
/models/configuration_speech_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import copy
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 | from models.configuration_wav2vec2 import Wav2Vec2Config
22 | from models.configuration_bart import BartConfig
23 | from transformers import AutoConfig
24 |
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 |
29 | class SpeechEncoderDecoderConfig(PretrainedConfig):
30 | r"""
31 | [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
32 | [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
33 | arguments, defining the encoder and decoder configs.
34 |
35 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36 | documentation from [`PretrainedConfig`] for more information.
37 |
38 | Args:
39 | kwargs (*optional*):
40 | Dictionary of keyword arguments. Notably:
41 |
42 | - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
43 | the encoder config.
44 | - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
45 | the decoder config.
46 |
47 | Examples:
48 |
49 | ```python
50 | >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
51 |
52 | >>> # Initializing a Wav2Vec2 & BERT style configuration
53 | >>> config_encoder = Wav2Vec2Config()
54 | >>> config_decoder = BertConfig()
55 |
56 | >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
57 |
58 | >>> # Initializing a Wav2Vec2Bert model from Wav2Vec2 & bert-base-uncased style configurations
59 | >>> model = SpeechEncoderDecoderModel(config=config)
60 |
61 | >>> # Accessing the model configuration
62 | >>> config_encoder = model.config.encoder
63 | >>> config_decoder = model.config.decoder
64 | >>> # set decoder config to causal lm
65 | >>> config_decoder.is_decoder = True
66 | >>> config_decoder.add_cross_attention = True
67 |
68 | >>> # Saving the model, including its configuration
69 | >>> model.save_pretrained("my-model")
70 |
71 | >>> # loading model and config from pretrained folder
72 | >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
73 | >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
74 | ```"""
75 | model_type = "speech-encoder-decoder"
76 | is_composition = True
77 |
78 | def __init__(self, **kwargs):
79 | super().__init__(**kwargs)
80 | if "encoder" not in kwargs or "decoder" not in kwargs:
81 | raise ValueError(
82 | f"A configuration of type {self.model_type} cannot be instantiated because both `encoder` and `decoder` sub-configurations must be passed, but only {kwargs} was passed"
83 | )
84 |
85 | encoder_config = kwargs.pop("encoder")
86 | decoder_config = kwargs.pop("decoder")
87 |
88 | # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
89 | self.encoder = Wav2Vec2Config(**encoder_config)
90 | self.decoder = BartConfig(**decoder_config)
91 | self.is_encoder_decoder = True
92 |
93 | @classmethod
94 | def from_encoder_decoder_configs(
95 | cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96 | ) -> PretrainedConfig:
97 | r"""
98 | Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
99 | configuration and decoder model configuration.
100 |
101 | Returns:
102 | [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
103 | """
104 | logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105 | decoder_config.is_decoder = True
106 | decoder_config.add_cross_attention = True
107 |
108 | return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109 |
110 | def to_dict(self):
111 | """
112 | Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
113 |
114 | Returns:
115 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
116 | """
117 | output = copy.deepcopy(self.__dict__)
118 | output["encoder"] = self.encoder.to_dict()
119 | output["decoder"] = self.decoder.to_dict()
120 | output["model_type"] = self.__class__.model_type
121 | return output
122 |
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
[IDE web-server configuration; the XML markup was lost in extraction. The file defined two deployment server entries, "TRC" and "GPUs".]
--------------------------------------------------------------------------------
/models/configuration_bart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ BART model configuration"""
16 | import warnings
17 |
18 | from transformers.configuration_utils import PretrainedConfig
19 | from transformers.utils import logging
20 |
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
26 | # See all BART models at https://huggingface.co/models?filter=bart
27 | }
28 |
29 |
30 | class BartConfig(PretrainedConfig):
31 | r"""
32 | This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
33 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
34 | defaults will yield a similar configuration to that of the BART
35 | [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
36 |
37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38 | documentation from [`PretrainedConfig`] for more information.
39 |
40 |
41 | Args:
42 | vocab_size (`int`, *optional*, defaults to 50265):
43 | Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
44 | `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
45 | d_model (`int`, *optional*, defaults to 1024):
46 | Dimensionality of the layers and the pooler layer.
47 | encoder_layers (`int`, *optional*, defaults to 12):
48 | Number of encoder layers.
49 | decoder_layers (`int`, *optional*, defaults to 12):
50 | Number of decoder layers.
51 | encoder_attention_heads (`int`, *optional*, defaults to 16):
52 | Number of attention heads for each attention layer in the Transformer encoder.
53 | decoder_attention_heads (`int`, *optional*, defaults to 16):
54 | Number of attention heads for each attention layer in the Transformer decoder.
55 | decoder_ffn_dim (`int`, *optional*, defaults to 4096):
56 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57 | encoder_ffn_dim (`int`, *optional*, defaults to 4096):
58 | Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
59 | activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
60 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61 | `"relu"`, `"silu"` and `"gelu_new"` are supported.
62 | dropout (`float`, *optional*, defaults to 0.1):
63 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
64 | attention_dropout (`float`, *optional*, defaults to 0.0):
65 | The dropout ratio for the attention probabilities.
66 | activation_dropout (`float`, *optional*, defaults to 0.0):
67 | The dropout ratio for activations inside the fully connected layer.
68 | classifier_dropout (`float`, *optional*, defaults to 0.0):
69 | The dropout ratio for classifier.
70 | max_position_embeddings (`int`, *optional*, defaults to 1024):
71 | The maximum sequence length that this model might ever be used with. Typically set this to something large
72 | just in case (e.g., 512 or 1024 or 2048).
73 | init_std (`float`, *optional*, defaults to 0.02):
74 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75 | encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
76 | The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
77 | for more details.
78 | decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
79 | The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
80 | for more details.
81 | scale_embedding (`bool`, *optional*, defaults to `False`):
82 | Scale embeddings by dividing by sqrt(d_model).
83 | use_cache (`bool`, *optional*, defaults to `True`):
84 | Whether or not the model should return the last key/values attentions (not used by all models).
85 | num_labels: (`int`, *optional*, defaults to 3):
86 | The number of labels to use in [`BartForSequenceClassification`].
87 | forced_eos_token_id (`int`, *optional*, defaults to 2):
88 | The id of the token to force as the last generated token when `max_length` is reached. Usually set to
89 | `eos_token_id`.
90 | use_scan (`bool`, *optional*, defaults to `False`):
91 | Whether or not to use nn.scan in the Flax Bart attention layers.
92 |
93 | Example:
94 |
95 | ```python
96 | >>> from transformers import BartModel, BartConfig
97 |
98 | >>> # Initializing a BART facebook/bart-large style configuration
99 | >>> configuration = BartConfig()
100 |
101 | >>> # Initializing a model from the facebook/bart-large style configuration
102 | >>> model = BartModel(configuration)
103 |
104 | >>> # Accessing the model configuration
105 | >>> configuration = model.config
106 | ```"""
107 | model_type = "bart"
108 | keys_to_ignore_at_inference = ["past_key_values"]
109 | attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
110 |
111 | def __init__(
112 | self,
113 | vocab_size=50265,
114 | max_position_embeddings=1024,
115 | encoder_layers=12,
116 | encoder_ffn_dim=4096,
117 | encoder_attention_heads=16,
118 | decoder_layers=12,
119 | decoder_ffn_dim=4096,
120 | decoder_attention_heads=16,
121 | encoder_layerdrop=0.0,
122 | decoder_layerdrop=0.0,
123 | activation_function="gelu",
124 | d_model=1024,
125 | dropout=0.1,
126 | attention_dropout=0.0,
127 | activation_dropout=0.0,
128 | init_std=0.02,
129 | classifier_dropout=0.0,
130 | scale_embedding=False,
131 | use_cache=True,
132 | use_scan=False,
133 | fuse_matmuls=False,
134 | num_labels=3,
135 | pad_token_id=1,
136 | bos_token_id=0,
137 | eos_token_id=2,
138 | is_encoder_decoder=True,
139 | decoder_start_token_id=2,
140 | forced_eos_token_id=2,
141 | **kwargs
142 | ):
143 | self.vocab_size = vocab_size
144 | self.max_position_embeddings = max_position_embeddings
145 | self.d_model = d_model
146 | self.encoder_ffn_dim = encoder_ffn_dim
147 | self.encoder_layers = encoder_layers
148 | self.encoder_attention_heads = encoder_attention_heads
149 | self.decoder_ffn_dim = decoder_ffn_dim
150 | self.decoder_layers = decoder_layers
151 | self.decoder_attention_heads = decoder_attention_heads
152 | self.dropout = dropout
153 | self.attention_dropout = attention_dropout
154 | self.activation_dropout = activation_dropout
155 | self.activation_function = activation_function
156 | self.init_std = init_std
157 | self.encoder_layerdrop = encoder_layerdrop
158 | self.decoder_layerdrop = decoder_layerdrop
159 | self.classifier_dropout = classifier_dropout
160 | self.use_cache = use_cache
161 | self.use_scan = use_scan
162 | self.fuse_matmuls = fuse_matmuls
163 | self.num_hidden_layers = encoder_layers
164 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
165 |
166 | super().__init__(
167 | num_labels=num_labels,
168 | pad_token_id=pad_token_id,
169 | bos_token_id=bos_token_id,
170 | eos_token_id=eos_token_id,
171 | is_encoder_decoder=is_encoder_decoder,
172 | decoder_start_token_id=decoder_start_token_id,
173 | forced_eos_token_id=forced_eos_token_id,
174 | **kwargs,
175 | )
176 |
177 | # ensure backward compatibility for BART CNN models
178 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
179 | self.forced_bos_token_id = self.bos_token_id
180 | warnings.warn(
181 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
182 | "The config can simply be saved and uploaded again to be fixed."
183 | )
184 |
--------------------------------------------------------------------------------
/get_ctc_tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from datasets import load_dataset
3 | from collections import Counter
4 | import json
5 | import os
6 | import re
7 | import tempfile
8 | from transformers import Wav2Vec2CTCTokenizer
9 |
10 | # which dataset
11 | dataset_name = "sanchit-gandhi/earnings22"
12 | # which config
13 | dataset_config = "all"
14 | # which split => @Sanchit, we should only use the train split for "fairness"
15 | split = "train"
16 | # in case the dataset requires access like CV9
17 | use_auth_token = True
18 | # name of the text data column
19 | text_column = "sentence"
20 | # name of tok to upload to the Hub
21 | tokenizer_name = "wav2vec2-ctc-earnings22-black-box-tokenizer"
22 | # only set to TRUE if dataset is NOT cased
23 | do_lower = False
24 | # ignore the verifications of the downloaded/processed dataset information in `load_dataset`, set to False in most cases, True for E22
25 | ignore_verifications = True
26 |
27 | # should be kept the same across datasets (except for ablation)
28 | do_upper = False
29 | remove_punctuation = False
30 | cutoff_freq = 0.01
31 | # dataset cache directory
32 | dataset_cache_dir = "/home/sanchitgandhi/cache/huggingface/datasets"
33 | # For GigaSpeech, we need to convert spelled out punctuation to symbolic form
34 | gigaspeech_punctuation = {"<comma>": ",", "<period>": ".", "<questionmark>": "?", "<exclamationpoint>": "!"}
36 | preprocessing_chars_to_remove = ['<unk>', '_', '[', ']']
37 | # additional chars to remove if `remove_punctuation` is set to True
38 | additional_chars_to_remove_regex = '[,?.!-;:"“%‘”�{}()<>' + "']"
39 |
40 | dataset = load_dataset(
41 | dataset_name,
42 | dataset_config,
43 | split=split,
44 | use_auth_token=use_auth_token,
45 | cache_dir=dataset_cache_dir,
46 | ignore_verifications=ignore_verifications,
47 | )
48 |
49 | # remove all data that is unnecessary to save RAM
50 | dataset = dataset.remove_columns(list(set(dataset.column_names) - set([text_column])))
51 |
52 | # define function to see stats about letters and to create vocab
53 | def create_vocabulary_from_data(dataset, word_delimiter_token="|", do_lower=False, do_upper=False, remove_punctuation=False, cutoff_freq=0.0):
54 | def extract_all_chars(batch):
55 | all_text = " ".join(batch[text_column])
56 |
57 | if do_lower and do_upper:
58 | raise ValueError("Cannot do uppercase and lowercase tokenization concurrently. Set at most one of `do_lower` or `do_upper` to `True`.")
59 | if do_lower:
60 | all_text = all_text.lower()
61 | if do_upper:
62 | all_text = all_text.upper()
63 | for punctuation, replacement in gigaspeech_punctuation.items():
64 | all_text = all_text.replace(punctuation.lower(), replacement)
65 | all_text = all_text.replace(punctuation.upper(), replacement)
66 | for char in preprocessing_chars_to_remove:
67 | all_text = all_text.replace(char, "")
68 | if remove_punctuation:
69 | all_text = re.sub(additional_chars_to_remove_regex, '', all_text)
70 |
71 | count_chars_dict = Counter(list(all_text))
72 | # sort by freq
73 | count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0]))
74 | # retrieve dict, freq
75 | vocab, freqs = zip(*count_chars_dict)
76 |
77 | return {"vocab": list(vocab), "freqs": list(freqs)}
78 |
79 | dataset = dataset.map(
80 | extract_all_chars,
81 | batched=True,
82 | batch_size=-1,
83 | remove_columns=dataset.column_names,
84 | )
85 |
86 | vocab, freqs = dataset["vocab"], dataset["freqs"]
87 | total_num_chars = sum(freqs)
88 | chars_to_remove = []
89 |
90 | print("Character Occurrences")
91 | print(f"Total characters in dataset: {total_num_chars}")
92 | print(50 * "-")
93 | print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
94 | print(50 * "-")
95 | for char, freq in zip(vocab, freqs):
96 | freq_in_percent = freq / total_num_chars * 100
97 | print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
98 | if freq_in_percent < cutoff_freq:
99 | chars_to_remove.append(char)
100 | print(50 * "-")
101 |
102 | vocab = list(set(vocab) - set(chars_to_remove))
103 |
104 | # Wav2Vec2CTC Tokenizers always have those as the first tokens (important for CTC)
105 | vocab = ["", "", "", ""] + vocab
106 |
107 | alphabet = list(map(chr, range(97, 123)))
108 |
109 | for char in alphabet:
110 | char = char.upper() if do_upper else char
111 | if char not in vocab:
112 | vocab.append(char)
113 |
114 | # create json dict
115 | vocab_dict = {v: k for k, v in enumerate(list(vocab))}
116 |
117 | # replace white space with delimiter token
118 | if word_delimiter_token is not None:
119 | vocab_dict[word_delimiter_token] = vocab_dict[" "]
120 | del vocab_dict[" "]
121 |
122 | return vocab_dict
123 |
124 | # Note that the functions accepts the following important args
125 | # 1. --do_lower
126 | # => whether to lowercase all letters or not.
127 | # Note that if you lowercase letters for the vocab, then you also need to
128 | # do so when preparing the data for the training, dev and test set
129 | # 2. --cutoff_freq
130 | # => This is very important! Lots of datasets will contain "wrong" characters in the training set, e.g.
131 | # characters that occur just a couple of times.
132 | # By default, the CTC vocab creation would simply add them to the vocab even if their occurrence is negligible
133 | # compared to the "super frequent" letters. We can regard such characters as "errors" or irrelevant to the
134 | # dataset, so we should delete them from the vocab. During training they are then simply classified as
135 | # unknown tokens, which the model can handle. In this script, we deploy a mechanism to remove all chars whose frequency in % is below a certain threshold (see the standalone sketch just after this file).
136 |
137 |
138 | # To begin with, let's take a look at the character distribution to decide whether to lowercase everything
139 | # and how many "incorrect" chars are in the dataset
140 |
141 | # do_lower = False
142 | # cutoff_freq = 0.0
143 | # create_vocabulary_from_data(dataset, do_lower=do_lower, cutoff_freq=cutoff_freq)
144 |
145 | """
146 | Total characters in dataset: 57415071
147 | --------------------------------------------------
148 | Char | Total occ | % of total occ |
149 | --------------------------------------------------
150 | | 9158936 | 15.952 |
151 | e | 5656975 | 9.853 |
152 | a | 3843802 | 6.695 |
153 | t | 3612796 | 6.292 |
154 | i | 3362877 | 5.857 |
155 | o | 3275590 | 5.705 |
156 | n | 3208804 | 5.589 |
157 | s | 3155007 | 5.495 |
158 | r | 3065229 | 5.339 |
159 | h | 2033409 | 3.542 |
160 | l | 1985225 | 3.458 |
161 | d | 1727989 | 3.01 |
162 | c | 1399592 | 2.438 |
163 | u | 1167110 | 2.033 |
164 | m | 1075733 | 1.874 |
165 | f | 884318 | 1.54 |
166 | . | 881733 | 1.536 |
167 | p | 846057 | 1.474 |
168 | g | 809581 | 1.41 |
169 | y | 740494 | 1.29 |
170 | w | 722667 | 1.259 |
171 | b | 606687 | 1.057 |
172 | " | 571968 | 0.996 |
173 | v | 477330 | 0.831 |
174 | , | 345764 | 0.602 |
175 | T | 332058 | 0.578 |
176 | k | 284615 | 0.496 |
177 | S | 174125 | 0.303 |
178 | A | 157656 | 0.275 |
179 | H | 156398 | 0.272 |
180 | C | 143595 | 0.25 |
181 | I | 141826 | 0.247 |
182 | M | 113026 | 0.197 |
183 | B | 102932 | 0.179 |
184 | ' | 88702 | 0.154 |
185 | - | 88461 | 0.154 |
186 | x | 85563 | 0.149 |
187 | P | 84495 | 0.147 |
188 | L | 67280 | 0.117 |
189 | R | 67254 | 0.117 |
190 | W | 66508 | 0.116 |
191 | D | 66094 | 0.115 |
192 | F | 61841 | 0.108 |
193 | G | 59031 | 0.103 |
194 | E | 54387 | 0.095 |
195 | N | 53495 | 0.093 |
196 | z | 47955 | 0.084 |
197 | O | 43305 | 0.075 |
198 | j | 41654 | 0.073 |
199 | q | 40510 | 0.071 |
200 | J | 39647 | 0.069 |
201 | K | 31204 | 0.054 |
202 | U | 23517 | 0.041 |
203 | V | 21380 | 0.037 |
204 | Y | 14863 | 0.026 |
205 | ? | 8574 | 0.015 |
206 | ! | 7327 | 0.013 |
207 | : | 5456 | 0.01 |
208 | ’ | 4864 | 0.008 |
209 | Z | 4735 | 0.008 |
210 | Q | 4488 | 0.008 |
211 | ; | 2781 | 0.005 |
212 | X | 1391 | 0.002 |
213 | ” | 1374 | 0.002 |
214 | “ | 1344 | 0.002 |
215 | ‘ | 1100 | 0.002 |
216 | — | 690 | 0.001 |
217 | é | 297 | 0.001 |
218 | ü | 177 | 0.0 |
219 | ) | 122 | 0.0 |
220 | ( | 121 | 0.0 |
221 | ä | 109 | 0.0 |
222 | ...
223 | 阪 | 1 | 0.0 |
224 | fl | 1 | 0.0 |
225 | """
226 | # All right, we see lots of "wrong" tokens and also see that there is a mix of upper-case and lower-case tokens
227 | # Let's lower-case all tokens and take a look again
228 |
229 | # do_lower = True
230 | # cutoff_freq = 0.0
231 |
232 | # create_vocabulary_from_data(dataset, do_lower=do_lower, cutoff_freq=cutoff_freq)
233 |
234 |
235 | """
236 | Character Occurrences
237 | Total characters in dataset: 57415071
238 | --------------------------------------------------
239 | Char | Total occ | % of total occ |
240 | --------------------------------------------------
241 | | 9158936 | 15.952 |
242 | e | 5711362 | 9.947 |
243 | a | 4001458 | 6.969 |
244 | t | 3944854 | 6.871 |
245 | i | 3504703 | 6.104 |
246 | s | 3329132 | 5.798 |
247 | o | 3318895 | 5.781 |
248 | n | 3262299 | 5.682 |
249 | r | 3132483 | 5.456 |
250 | h | 2189807 | 3.814 |
251 | l | 2052505 | 3.575 |
252 | d | 1794083 | 3.125 |
253 | c | 1543187 | 2.688 |
254 | u | 1190627 | 2.074 |
255 | m | 1188759 | 2.07 |
256 | f | 946159 | 1.648 |
257 | p | 930552 | 1.621 |
258 | . | 881733 | 1.536 |
259 | g | 868612 | 1.513 |
260 | w | 789175 | 1.375 |
261 | y | 755357 | 1.316 |
262 | b | 709619 | 1.236 |
263 | " | 571968 | 0.996 |
264 | v | 498710 | 0.869 |
265 | , | 345764 | 0.602 |
266 | k | 315819 | 0.55 |
267 | ' | 88702 | 0.154 |
268 | - | 88461 | 0.154 |
269 | x | 86954 | 0.151 |
270 | j | 81301 | 0.142 |
271 | z | 52690 | 0.092 |
272 | q | 44998 | 0.078 |
273 | ? | 8574 | 0.015 |
274 | ! | 7327 | 0.013 |
275 | : | 5456 | 0.01 |
276 | ’ | 4864 | 0.008 |
277 | ; | 2781 | 0.005 |
278 | ” | 1374 | 0.002 |
279 | “ | 1344 | 0.002 |
280 | ‘ | 1100 | 0.002 |
281 | — | 690 | 0.001 |
282 | é | 319 | 0.001 |
283 | ü | 182 | 0.0 |
284 | ) | 122 | 0.0 |
285 | ( | 121 | 0.0 |
286 | ...
287 | 阪 | 1 | 0.0 |
288 | fl | 1 | 0.0 |
289 | """
290 |
291 | # Cool, now let's remove very rare, "wrong" characters. Everything below 0.01% (note that's 1/10,000) seems like a good estimate.
292 | # It keeps all letters of the alphabet and some punctuation, but removes all clearly incorrect characters such as
293 | # accented letters from German or French, Chinese characters, ...
294 | # Running it once more and now keeping the dict
295 |
296 | vocab_dict = create_vocabulary_from_data(dataset, do_lower=do_lower, do_upper=do_upper, remove_punctuation=remove_punctuation, cutoff_freq=cutoff_freq)
297 |
298 | # save vocab dict to be loaded into tokenizer
299 | with tempfile.TemporaryDirectory() as tmp:
300 | with open(os.path.join(tmp, "vocab.json"), "w") as file:
301 | json.dump(vocab_dict, file)
302 |
303 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp, do_lower_case=do_upper)
304 |
305 | # push tokenizer to the Hub
306 | # E.g. see: https://huggingface.co/patrickvonplaten/wav2vec2_ctc_cv9_tokenizer
307 | tokenizer.push_to_hub(tokenizer_name)
308 |
--------------------------------------------------------------------------------
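The character-frequency cutoff described in the comments of `get_ctc_tokenizer.py` can be illustrated with a self-contained toy example (the text and numbers below are made up purely for illustration, they are not from any of the datasets above):

```python
from collections import Counter

# toy corpus: 'é' occurs once, far below a 0.01 % relative frequency
text = ("hello world " * 10_000) + "héllo"
counts = Counter(text.replace(" ", ""))
total = sum(counts.values())

cutoff_freq = 0.01  # threshold in percent, as in the script above
kept = [c for c, n in counts.items() if n / total * 100 >= cutoff_freq]
removed = [c for c, n in counts.items() if n / total * 100 < cutoff_freq]

print("kept:", sorted(kept))   # the frequent alphabet characters survive
print("removed:", removed)     # ['é'] is dropped and would later be treated as unknown
```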
/get_ctc_ngram.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from datasets import load_dataset
3 | from collections import Counter
4 | import re
5 | import os
6 | from transformers import Wav2Vec2ProcessorWithLM
7 | from transformers import AutoProcessor
8 | from pathlib import Path
9 | from pyctcdecode import build_ctcdecoder
10 |
11 | # adapt to dataset
12 | dataset_name = "polinaeterna/voxpopuli"
13 | dataset_config = "en"
14 | split = "train"
15 | text_column = "normalized_text"
16 | tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box"
17 | do_lower = True # only set to TRUE if dataset is NOT cased
18 |
19 | # should be kept the same across datasets (except for ablation)
20 | cutoff_freq = 0.01
21 | do_upper = False # only set to TRUE for ablation studies
22 | preprocessing_chars_to_remove = [] # only remove chars for ablation studies
23 | additional_chars_to_remove_regex = "" # only set to something for ablation studies
24 | remove_punctuation = additional_chars_to_remove_regex != ""
25 | # dataset specific "error correction"
26 | # For GigaSpeech, we need to convert spelled out punctuation to symbolic form
27 | gigaspeech_punctuation = {"": ",", "": ".", "": "?", "": ",", " ": ".", " ": "?", " ": "!"}
40 | gigaspeech_disfluencies = ["", ""]
41 | swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "", "", "", "[laughter-", "[vocalized-noise]", "_1"]
42 | swb_punctuations = ["{", "}", "[", "]-", "]"]
43 | earnings_disfluencies = ["", "", "", "inaudible", "", "", ""]
44 | ignore_segments = ["ignore_time_segment_in_scoring", "", "", "[noise]", "[laughter]", "[silence]", "[vocalized-noise]", "", "", "", "", "", "", ""]
45 |
46 |
47 | # in case the dataset requires access like CV9
48 | use_auth_token = True
49 |
50 | dataset = load_dataset(
51 | dataset_name,
52 | dataset_config,
53 | split=split,
54 | use_auth_token=use_auth_token,
55 | cache_dir=dataset_cache_dir,
56 | )
57 |
58 | # remove all data that is unnecessary to save RAM
59 | dataset = dataset.remove_columns(list(set(dataset.column_names) - set([text_column])))
60 |
61 | # define function to see stats about letters and to create vocab
62 | # NOTE: this function has to be 1-to-1 aligned with:
63 | # https://github.com/sanchit-gandhi/seq2seq-speech/blob/25d3af18d779d12cdb6c30040f30f51f5a6bb75b/get_ctc_tokenizer.py#L45
64 | def process_text(dataset, word_delimiter_token="|", do_lower=False, do_upper=False, remove_punctuation=False, cutoff_freq=0.0):
65 | def extract_all_chars(batch):
66 | all_text = " ".join(batch[text_column])
67 |
68 | if do_lower and do_upper:
69 | raise ValueError("Cannot do uppercase and lowercase tokenization concurrently. Set at most one of `do_lower` or `do_upper` to `True`.")
70 | if do_lower:
71 | all_text = all_text.lower()
72 | if do_upper:
73 | all_text = all_text.upper()
74 | for punctuation, replacement in gigaspeech_punctuation.items():
75 | all_text = all_text.replace(punctuation.lower(), replacement)
76 | all_text = all_text.replace(punctuation.upper(), replacement)
77 | for char in preprocessing_chars_to_remove:
78 | all_text = all_text.replace(char, "")
79 | # only used for ablation studies
80 | if remove_punctuation:
81 | all_text = re.sub(additional_chars_to_remove_regex, '', all_text)
82 |
83 | count_chars_dict = Counter(list(all_text))
84 | # sort by freq
85 | count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0]))
86 | # retrieve dict, freq
87 | vocab, freqs = zip(*count_chars_dict)
88 |
89 | result = {"vocab": [list(vocab)], "freqs": [list(freqs)]}
90 | result[text_column] = [all_text]
91 |
92 | return result
93 |
94 | dataset = dataset.map(
95 | extract_all_chars,
96 | batched=True,
97 | batch_size=-1,
98 | remove_columns=dataset.column_names,
99 | )
100 |
101 | vocab, freqs = dataset["vocab"][0], dataset["freqs"][0]
102 | total_num_chars = sum(freqs)
103 | chars_to_remove = []
104 |
105 | print("Character Occurrences")
106 | print(f"Total characters in dataset: {total_num_chars}")
107 | print(50 * "-")
108 | print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
109 | print(50 * "-")
110 | for char, freq in zip(vocab, freqs):
111 | freq_in_percent = freq / total_num_chars * 100
112 | print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
113 | if freq_in_percent < cutoff_freq:
114 | chars_to_remove.append(char)
115 | print(50 * "-")
116 | print(f"REMOVED CHARS: {chars_to_remove}")
117 | print(50 * "-")
118 |
119 |
120 | def correct_data(batch):
121 | # LibriSpeech ASR
122 | new_input_strings = []
123 | for input_str in batch[text_column]:
124 | if dataset_name == "librispeech_asr":
125 | pass # no error correction necessary
126 |
127 | # VoxPopuli
128 | if dataset_name == "google/xtreme_s":
129 | pass # no error correction necessary
130 |
131 | # Common Voice 9
132 | if dataset_name == "mozilla-foundation/common_voice_9_0":
133 | if input_str.startswith('"') and input_str.endswith('"'):
134 | # we can remove trailing quotation marks as they do not affect the transcription
135 | input_str = input_str[1:-1]
136 | # replace double quotation marks with single
137 | input_str = input_str.replace('""', '"')
138 |
139 | # TED-LIUM (Release 3)
140 | if dataset_name == "LIUM/tedlium":
141 | # delete the <unk> token from the text
142 | input_str = input_str.replace("<unk>", "")
143 | # replace spaced apostrophes with un-spaced (it 's -> it's)
144 | for contraction in tedlium_contractions:
145 | input_str = input_str.replace(contraction, contraction[1:])
146 |
147 | # GigaSpeech
148 | if dataset_name == "speechcolab/gigaspeech":
149 | for disfluency in gigaspeech_disfluencies:
150 | input_str = input_str.replace(disfluency, "")
151 | # convert spelled out punctuation to symbolic form
152 | for punctuation, replacement in gigaspeech_punctuation.items():
153 | input_str = input_str.replace(punctuation, replacement)
154 |
155 | # SWB: hide the path to the private HF dataset
156 | if "switchboard" in dataset_name:
157 | for disfluency in swb_disfluencies:
158 | input_str = input_str.replace(disfluency, "")
159 | # remove parenthesised text (test data only)
160 | input_str = re.sub("[\(].*?[\)]", "", input_str)
161 | for punctuation in swb_punctuations:
162 | input_str = input_str.replace(punctuation, "")
163 | # replace anomalous words with their correct transcriptions
164 | split_str = input_str.split("/")
165 | if len(split_str) > 1:
166 | input_str = " ".join(
167 | [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
168 |
169 | # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
170 | if "earnings22" in dataset_name:
171 | for disfluency in earnings_disfluencies:
172 | input_str = input_str.replace(disfluency, "")
173 |
174 | # SPGISpeech
175 | if dataset_name == "kensho/spgispeech":
176 | pass # no error correction necessary
177 |
178 | # JIWER compliance (for WER/CER calc.)
179 | # remove multiple spaces
180 | input_str = re.sub(r"\s\s+", " ", input_str)
181 |
182 | # strip trailing spaces
183 | input_str = input_str.strip()
184 |
185 | new_input_strings.append(input_str)
186 |
187 | all_text = " ".join(new_input_strings)
188 |
189 | for char in chars_to_remove:
190 | all_text = all_text.replace(char, "")
191 |
192 | result = {}
193 | result[text_column] = [all_text]
194 |
195 | return result
196 |
197 | dataset = dataset.map(
198 | correct_data,
199 | batched=True,
200 | batch_size=-1,
201 | remove_columns=dataset.column_names,
202 | )
203 |
204 | return dataset
205 |
206 |
207 | # Cool, now let's remove very rare, "wrong" characters. Everything below 0.01% (note that's 1/10,000) seems like a good estimate.
208 | # It keeps all letters of the alphabet and some punctuation, but removes all clearly incorrect characters such as
209 | # accented letters from German or French, Chinese characters, ...
210 | # Running it once more and now keeping the dict
211 |
212 | if not os.path.isfile(text_save_path):
213 | text_data = process_text(dataset, do_lower=do_lower, do_upper=do_upper, remove_punctuation=remove_punctuation, cutoff_freq=cutoff_freq)
214 |
215 | # save vocab dict to be loaded into tokenizer
216 | Path(dir_path).mkdir(parents=True, exist_ok=True)
217 | with open(text_save_path, "w") as file:
218 | file.write(" ".join(text_data[text_column]))
219 |
220 | if not os.path.isfile(ngram_save_path):
221 | ngram_command = f"{home_path}/kenlm/build/bin/lmplz -o 5 < '{text_save_path}' > '{ngram_save_path}' --skip_symbols"
222 | os.system(ngram_command)
223 |
224 | # correct the ARPA file by adding the missing "</s>" end-of-sentence token
225 | ngram_save_path_correct = ngram_save_path + "_correct.arpa"
226 | with open(ngram_save_path, "r") as read_file, open(ngram_save_path_correct, "w") as write_file:
227 | has_added_eos = False
228 | for line in read_file:
229 | if not has_added_eos and "ngram 1=" in line:
230 | count = line.strip().split("=")[-1]
231 | write_file.write(line.replace(f"{count}", f"{int(count)+1}"))
232 | elif not has_added_eos and "<s>" in line:
233 | write_file.write(line)
234 | write_file.write(line.replace("<s>", "</s>"))
235 | has_added_eos = True
236 | else:
237 | write_file.write(line)
238 |
239 | os.system(f"mv {ngram_save_path_correct} {ngram_save_path}")
240 |
241 |
242 | processor = AutoProcessor.from_pretrained(tokenizer_name)
243 | vocab_dict = processor.tokenizer.get_vocab()
244 | if do_lower:
245 | sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
246 | else:
247 | sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
248 |
249 | processor.tokenizer.encoder = sorted_vocab_dict
250 | processor.tokenizer.decoder = {v: k for k, v in processor.tokenizer.encoder.items()}
251 |
252 | decoder = build_ctcdecoder(
253 | labels=list(sorted_vocab_dict.keys()),
254 | kenlm_model_path=ngram_save_path,
255 | )
256 |
257 | processor_with_lm = Wav2Vec2ProcessorWithLM(
258 | feature_extractor=processor.feature_extractor,
259 | tokenizer=processor.tokenizer,
260 | decoder=decoder
261 | )
262 | processor_with_lm.save_pretrained(dir_path)
263 |
264 | new_ngram_path = os.path.join(dir_path, "language_model", file_name)
265 | bin_save_path = new_ngram_path.split(".")[0] + ".bin"
266 | os.system(f"{home_path}/kenlm/build/bin/build_binary '{new_ngram_path}' '{bin_save_path}'")
267 | os.system(f"mv '{bin_save_path}' '{new_ngram_path}'")
268 |
269 |
270 |
271 | # CONFIGS
272 | # ========================================================================
273 | # 1. LIBRISPEECH:
274 | # dataset_name = "librispeech_asr"
275 | # dataset_config = None
276 | # split = "train.clean.100+train.clean.360+train.other.500"
277 | # text_column = "text"
278 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-ls-960h-baseline"
279 | # do_lower = True # only set to TRUE if dataset is NOT cased
280 |
281 | # => no REMOVED CHARS: [] -> all chars are used
282 |
283 | # ========================================================================
284 | # 2. TEDLIUM
285 | # dataset_name = "LIUM/tedlium"
286 | # dataset_config = "release3"
287 | # split = "train"
288 | # text_column = "text"
289 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-tedlium-black-box"
290 | # do_lower = False # only set to TRUE if dataset is NOT cased
291 |
292 | # => REMOVED CHARS: ['0', '1', '2', '9', '[', ']', '3', '5', '8', '4', '$', '7', '6', '&', '+', '=', '#', '%', '@', '*', '\\', '^', 'ā']
293 |
294 | # ========================================================================
295 | # 3. AMI
296 | # dataset_name = "speech-seq2seq/ami"
297 | # dataset_config = "ihm"
298 | # split = "train"
299 | # text_column = "text"
300 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-ami-black-box"
301 | # do_lower = False # only set to TRUE if dataset is NOT cased
302 |
303 | # => REMOVED CHARS: ['X', 'Q', 'Z', '0', '!', '*', '1', '3', '@']
304 |
305 | # ========================================================================
306 | # 4. CV 9
307 | # dataset_name = "mozilla-foundation/common_voice_9_0"
308 | # dataset_config = "en"
309 | # split = "train"
310 | # text_column = "sentence"
311 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-cv9-black-box"
312 | # do_lower = False # only set to TRUE if dataset is NOT cased
313 |
314 | # => REMOVED CHARS: REMOVED CHARS: [':', '’', 'Z', 'Q', ';', 'X', '”', '“', '‘', '—', 'é', 'ü', ')', '(', 'ä', 'ö', 'á', 'ó', 'è', 'í', '–', '/', 'ç', '&', 'â', 'ō', 'ß', 'ñ', 'É', 'à', 'ï', 'ô', 'ú', 'ã', 'ê', 'ë', 'č', 'ł', '`', 'Ö', '…', '´', 'ø', 'ć', 'Š', 'ž', 'Ü', 'î', 'ð', 'û', 'ā', 'ă', 'ū', '%', 'Ä', 'ı', 'œ', 'š', '[', ']', '«', '»', 'Á', 'Ó', 'ò', 'ī', 'ș', '_', '¡', '·', 'Ç', 'Ú', 'æ', 'ý', 'ń', 'Ō', 'Œ', 'ř', 'ş', 'ʻ', 'α', 'κ', 'π', 'и', 'к', 'ạ', '#', '=', '~', '§', 'Ã', 'È', 'Î', 'Ø', 'å', 'õ', 'þ', 'Č', 'ē', 'ę', 'ě', 'ğ', 'ň', 'ő', 'Ș', 'ə', 'Α', 'Χ', 'В', 'а', 'е', 'з', 'й', 'л', 'н', 'ь', 'я', 'נ', 'ע', 'ṃ', 'ả', 'ị', 'ụ', '„', '€', '→', '≡', '京', '先', '大', '尚', '时', '生', '都', '阪', 'fl']
315 |
316 | # ========================================================================
317 | # 5. Gigaspeech
318 | # dataset_name = "speechcolab/gigaspeech"
319 | # dataset_config = "l"
320 | # split = "train"
321 | # text_column = "text"
322 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-gs-black-box"
323 | # do_lower = True # only set to TRUE if dataset is NOT cased
324 |
325 | # => REMOVED CHARS: []
326 |
327 | # ========================================================================
328 | # 6. SPGI Kensho Speech
329 | # dataset_name = "kensho/spgispeech"
330 | # dataset_config = "L"
331 | # split = "train"
332 | # text_column = "transcript"
333 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-spgispeech-black-box"
334 | # do_lower = False # only set to TRUE if dataset is NOT cased
335 |
336 | # => REMOVED CHARS: ['Q', 'V', 'U', 'K', '9', 'X', 'Z']
337 |
338 | # ========================================================================
339 | # 7. VoxPopuli
340 | # dataset_name = "polinaeterna/voxpopuli"
341 | # dataset_config = "en"
342 | # split = "train"
343 | # text_column = "normalized_text"
344 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-voxpopuli-black-box"
345 | # do_lower = True # only set to TRUE if dataset is NOT cased
346 |
347 | # => REMOVED CHARS: ['!', '1']
348 |
349 | # ========================================================================
350 | # 8. Earnings 22:
351 | # dataset_name = "sanchit-gandhi/earnings22_robust_split"
352 | # dataset_config = None
353 | # split = "train"
354 | # text_column = "sentence"
355 | # tokenizer_name = "sanchit-gandhi/flax-wav2vec2-ctc-earnings22-cased-hidden-activation-featproj-dropout-0.2"
356 | # do_lower = False # only set to TRUE if dataset is NOT cased
357 |
358 | # => REMOVED CHARS: ['&', 'X', ';', '€', 'Z', '/', ':', '*', '£', '¥', 'ł', '₽', '!', '+', '–', '@', '¢', '₵', '\\', '_', '#', '=', 'ı', '₦', '[', 'ø', '₱']
359 |
--------------------------------------------------------------------------------
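Once `get_ctc_ngram.py` has saved the `Wav2Vec2ProcessorWithLM`, the processor can decode CTC logits with the n-gram language model via `pyctcdecode`. A minimal usage sketch (the load path and the random logits are placeholders for illustration only, not values from this repository):

```python
import numpy as np
from transformers import Wav2Vec2ProcessorWithLM

# hypothetical path: wherever the processor_with_lm was saved by the script above
processor = Wav2Vec2ProcessorWithLM.from_pretrained("./path/to/processor_with_lm")

# dummy CTC logits of shape (batch, frames, vocab) standing in for real model output
vocab_size = len(processor.tokenizer)
logits = np.random.randn(1, 200, vocab_size).astype(np.float32)

# beam-search decoding boosted by the n-gram language model
transcription = processor.batch_decode(logits).text
print(transcription)
```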
/models/configuration_wav2vec2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Wav2Vec2 model configuration"""
16 |
17 | import functools
18 | import operator
19 |
20 | from transformers.configuration_utils import PretrainedConfig
21 | from transformers.utils import logging
22 |
23 |
24 | logger = logging.get_logger(__name__)
25 |
26 | WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28 | # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29 | }
30 |
31 |
32 | class Wav2Vec2Config(PretrainedConfig):
33 | r"""
34 | This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
35 | Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36 | with the defaults will yield a similar configuration to that of the Wav2Vec2
37 | [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38 |
39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40 | documentation from [`PretrainedConfig`] for more information.
41 |
42 |
43 | Args:
44 | vocab_size (`int`, *optional*, defaults to 32):
45 | Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46 |             the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`], i.e. the set of
47 |             different tokens that can be passed via *inputs_ids* to the forward method of
48 |             [`Wav2Vec2Model`].
49 | hidden_size (`int`, *optional*, defaults to 768):
50 | Dimensionality of the encoder layers and the pooler layer.
51 | num_hidden_layers (`int`, *optional*, defaults to 12):
52 | Number of hidden layers in the Transformer encoder.
53 | num_attention_heads (`int`, *optional*, defaults to 12):
54 | Number of attention heads for each attention layer in the Transformer encoder.
55 | intermediate_size (`int`, *optional*, defaults to 3072):
56 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59 | `"relu"`, `"selu"` and `"gelu_new"` are supported.
60 | hidden_dropout (`float`, *optional*, defaults to 0.1):
61 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62 | attention_dropout (`float`, *optional*, defaults to 0.1):
63 | The dropout ratio for the attention probabilities.
64 | final_dropout (`float`, *optional*, defaults to 0.1):
65 | The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66 | initializer_range (`float`, *optional*, defaults to 0.02):
67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68 |         layer_norm_eps (`float`, *optional*, defaults to 1e-5):
69 | The epsilon used by the layer normalization layers.
70 | feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71 | The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72 | normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73 | convolutional layers.
74 | feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75 | The dropout probability for output of the feature encoder.
76 |         feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
77 | The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78 | extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79 | feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80 |             The dropout probability for quantized feature encoder states.
81 | conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82 | A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83 | feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84 | conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85 | A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86 | of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87 |         conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
88 | A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89 | length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90 | *conv_dim*.
91 | conv_bias (`bool`, *optional*, defaults to `False`):
92 | Whether the 1D convolutional layers have a bias.
93 | num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94 | Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95 | embeddings layer.
96 | num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97 | Number of groups of 1D convolutional positional embeddings layer.
98 | do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99 | Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100 | True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101 | False` corresponds to applying layer norm after the attention layer.
102 | apply_spec_augment (`bool`, *optional*, defaults to `True`):
103 | Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104 | [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105 | Recognition](https://arxiv.org/abs/1904.08779).
106 | mask_time_prob (`float`, *optional*, defaults to 0.05):
107 | Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108 |             procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
109 |             reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
110 | masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111 | actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112 | mask_time_length (`int`, *optional*, defaults to 10):
113 | Length of vector span along the time axis.
114 |         mask_time_min_masks (`int`, *optional*, defaults to 2):
115 |             The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
116 |             irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
117 |             mask_time_min_masks`.
118 | mask_feature_prob (`float`, *optional*, defaults to 0.0):
119 | Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120 |             masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
121 |             the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
122 | span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123 | may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124 | True`.
125 | mask_feature_length (`int`, *optional*, defaults to 10):
126 | Length of vector span along the feature axis.
127 |         mask_feature_min_masks (`int`, *optional*, defaults to 0):
128 |             The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129 |             step, irrespective of `mask_feature_prob`. Only relevant if
130 |             `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
131 | num_codevectors_per_group (`int`, *optional*, defaults to 320):
132 | Number of entries in each quantization codebook (group).
133 | num_codevector_groups (`int`, *optional*, defaults to 2):
134 | Number of codevector groups for product codevector quantization.
135 | contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136 | The temperature *kappa* in the contrastive loss.
137 | feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138 |             The dropout probability for the output of the feature encoder that's used by the quantizer.
139 | num_negatives (`int`, *optional*, defaults to 100):
140 | Number of negative samples for the contrastive loss.
141 | codevector_dim (`int`, *optional*, defaults to 256):
142 | Dimensionality of the quantized feature vectors.
143 | proj_codevector_dim (`int`, *optional*, defaults to 256):
144 | Dimensionality of the final projection of both the quantized and the transformer features.
145 |         diversity_loss_weight (`float`, *optional*, defaults to 0.1):
146 | The weight of the codebook diversity loss component.
147 | ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148 | Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149 | instance of [`Wav2Vec2ForCTC`].
150 | ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151 | Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152 | occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153 | of [`Wav2Vec2ForCTC`].
154 | use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155 | Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156 | instance of [`Wav2Vec2ForSequenceClassification`].
157 | classifier_proj_size (`int`, *optional*, defaults to 256):
158 | Dimensionality of the projection before token mean-pooling for classification.
159 | tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160 | A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161 | module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162 | tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163 | A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164 | *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165 | tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166 | A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167 | *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168 | xvector_output_dim (`int`, *optional*, defaults to 512):
169 | Dimensionality of the *XVector* embedding vectors.
170 | add_adapter (`bool`, *optional*, defaults to `False`):
171 | Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172 | warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173 | adapter_kernel_size (`int`, *optional*, defaults to 3):
174 | Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175 | adapter_stride (`int`, *optional*, defaults to 2):
176 | Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177 | num_adapter_layers (`int`, *optional*, defaults to 3):
178 | Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179 | True`.
180 | output_hidden_size (`int`, *optional*):
181 |             Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
182 | if `add_adapter is True`.
183 | use_scan (`bool`, *optional*, defaults to `False`):
184 | Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185 |
186 | Example:
187 |
188 | ```python
189 | >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190 |
191 | >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192 | >>> configuration = Wav2Vec2Config()
193 |
194 | >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195 | >>> model = Wav2Vec2Model(configuration)
196 |
197 | >>> # Accessing the model configuration
198 | >>> configuration = model.config
199 | ```"""
200 | model_type = "wav2vec2"
201 |
202 | def __init__(
203 | self,
204 | vocab_size=32,
205 | hidden_size=768,
206 | num_hidden_layers=12,
207 | num_attention_heads=12,
208 | intermediate_size=3072,
209 | hidden_act="gelu",
210 | hidden_dropout=0.1,
211 | activation_dropout=0.1,
212 | attention_dropout=0.1,
213 | feat_proj_dropout=0.0,
214 | feat_quantizer_dropout=0.0,
215 | final_dropout=0.1,
216 | layerdrop=0.1,
217 | initializer_range=0.02,
218 | layer_norm_eps=1e-5,
219 | feat_extract_norm="group",
220 | feat_extract_activation="gelu",
221 | conv_dim=(512, 512, 512, 512, 512, 512, 512),
222 | conv_stride=(5, 2, 2, 2, 2, 2, 2),
223 | conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224 | conv_bias=False,
225 | num_conv_pos_embeddings=128,
226 | num_conv_pos_embedding_groups=16,
227 | do_stable_layer_norm=False,
228 | apply_spec_augment=True,
229 | mask_time_prob=0.05,
230 | mask_time_length=10,
231 | mask_time_min_masks=2,
232 | mask_feature_prob=0.0,
233 | mask_feature_length=10,
234 | mask_feature_min_masks=0,
235 | num_codevectors_per_group=320,
236 | num_codevector_groups=2,
237 | contrastive_logits_temperature=0.1,
238 | num_negatives=100,
239 | codevector_dim=256,
240 | proj_codevector_dim=256,
241 | diversity_loss_weight=0.1,
242 | ctc_loss_reduction="sum",
243 | ctc_zero_infinity=False,
244 | use_weighted_layer_sum=False,
245 | classifier_proj_size=256,
246 | tdnn_dim=(512, 512, 512, 512, 1500),
247 | tdnn_kernel=(5, 3, 3, 1, 1),
248 | tdnn_dilation=(1, 2, 3, 1, 1),
249 | xvector_output_dim=512,
250 | pad_token_id=0,
251 | bos_token_id=1,
252 | eos_token_id=2,
253 | add_adapter=False,
254 | adapter_kernel_size=3,
255 | adapter_stride=2,
256 | num_adapter_layers=3,
257 | output_hidden_size=None,
258 | use_scan=False,
259 | fuse_matmuls=False,
260 | **kwargs
261 | ):
262 | super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263 | self.hidden_size = hidden_size
264 | self.feat_extract_norm = feat_extract_norm
265 | self.feat_extract_activation = feat_extract_activation
266 | self.conv_dim = list(conv_dim)
267 | self.conv_stride = list(conv_stride)
268 | self.conv_kernel = list(conv_kernel)
269 | self.conv_bias = conv_bias
270 | self.num_conv_pos_embeddings = num_conv_pos_embeddings
271 | self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272 | self.num_feat_extract_layers = len(self.conv_dim)
273 | self.num_hidden_layers = num_hidden_layers
274 | self.intermediate_size = intermediate_size
275 | self.hidden_act = hidden_act
276 | self.num_attention_heads = num_attention_heads
277 | self.hidden_dropout = hidden_dropout
278 | self.attention_dropout = attention_dropout
279 | self.activation_dropout = activation_dropout
280 | self.feat_proj_dropout = feat_proj_dropout
281 | self.final_dropout = final_dropout
282 | self.layerdrop = layerdrop
283 | self.layer_norm_eps = layer_norm_eps
284 | self.initializer_range = initializer_range
285 | self.vocab_size = vocab_size
286 | self.do_stable_layer_norm = do_stable_layer_norm
287 | self.use_weighted_layer_sum = use_weighted_layer_sum
288 | self.use_scan = use_scan
289 | self.fuse_matmuls = fuse_matmuls
290 |
291 | if (
292 | (len(self.conv_stride) != self.num_feat_extract_layers)
293 | or (len(self.conv_kernel) != self.num_feat_extract_layers)
294 | or (len(self.conv_dim) != self.num_feat_extract_layers)
295 | ):
296 | raise ValueError(
297 | "Configuration for convolutional layers is incorrect. "
298 | "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299 | f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300 | f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301 | )
302 |
303 | # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304 | self.apply_spec_augment = apply_spec_augment
305 | self.mask_time_prob = mask_time_prob
306 | self.mask_time_length = mask_time_length
307 | self.mask_time_min_masks = mask_time_min_masks
308 | self.mask_feature_prob = mask_feature_prob
309 | self.mask_feature_length = mask_feature_length
310 | self.mask_feature_min_masks = mask_feature_min_masks
311 |
312 | # parameters for pretraining with codevector quantized representations
313 | self.num_codevectors_per_group = num_codevectors_per_group
314 | self.num_codevector_groups = num_codevector_groups
315 | self.contrastive_logits_temperature = contrastive_logits_temperature
316 | self.feat_quantizer_dropout = feat_quantizer_dropout
317 | self.num_negatives = num_negatives
318 | self.codevector_dim = codevector_dim
319 | self.proj_codevector_dim = proj_codevector_dim
320 | self.diversity_loss_weight = diversity_loss_weight
321 |
322 | # ctc loss
323 | self.ctc_loss_reduction = ctc_loss_reduction
324 | self.ctc_zero_infinity = ctc_zero_infinity
325 |
326 | # adapter
327 | self.add_adapter = add_adapter
328 | self.adapter_kernel_size = adapter_kernel_size
329 | self.adapter_stride = adapter_stride
330 | self.num_adapter_layers = num_adapter_layers
331 | self.output_hidden_size = output_hidden_size or hidden_size
332 |
333 | # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334 | self.classifier_proj_size = classifier_proj_size
335 |
336 | # XVector-specific parameters. Feel free to ignore for other classes.
337 | self.tdnn_dim = list(tdnn_dim)
338 | self.tdnn_kernel = list(tdnn_kernel)
339 | self.tdnn_dilation = list(tdnn_dilation)
340 | self.xvector_output_dim = xvector_output_dim
341 |
342 | @property
343 | def inputs_to_logits_ratio(self):
344 | return functools.reduce(operator.mul, self.conv_stride, 1)
345 |
--------------------------------------------------------------------------------
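A quick, hedged illustration of the configuration class above: the snippet instantiates the custom `Wav2Vec2Config` with the repository-specific `use_scan` and `fuse_matmuls` flags and reads the `inputs_to_logits_ratio` property, which is simply the product of `conv_stride`. It assumes the repository root is on `PYTHONPATH` so that `models.configuration_wav2vec2` is importable; it is not an official example from the repository.

```python
# Sketch: exercise the custom Wav2Vec2Config defined above.
from models.configuration_wav2vec2 import Wav2Vec2Config

# Defaults give a wav2vec2-base style config; `use_scan` and `fuse_matmuls` are the
# extra flags added for the Flax implementation in this repository.
config = Wav2Vec2Config(use_scan=True, fuse_matmuls=False)

# The feature encoder downsamples raw audio by the product of the conv strides:
# 5 * 2 * 2 * 2 * 2 * 2 * 2 = 320 input samples per output frame (20 ms at 16 kHz).
print(config.inputs_to_logits_ratio)           # 320
print(16000 // config.inputs_to_logits_ratio)  # ~50 encoder frames per second of audio

# SpecAugment arithmetic from the docstring: with mask_time_prob=0.05 and
# mask_time_length=10, a 1000-frame sequence gets about 0.05 * 1000 / 10 = 5 time
# masks, i.e. up to 50 frames (~5%) masked before accounting for overlap.
num_masks = config.mask_time_prob * 1000 / config.mask_time_length
print(num_masks)                               # 5.0
```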
/tests/check_save_feature_encoder_output.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "b1db3741-29dc-4d1f-943c-ecf2ede5ad66",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n",
14 | " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "from transformers import FlaxSpeechEncoderDecoderModel\n",
20 | "from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel\n",
21 | "import numpy as np"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "id": "d67ee694-b5e6-4c98-92f1-3f214500cf55",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'\n",
32 | "decoder_id = 'hf-internal-testing/tiny-random-bart'"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "id": "5b30237b-080d-4447-a19b-4f11c11603b3",
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stderr",
43 | "output_type": "stream",
44 | "text": [
45 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
46 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n",
47 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
48 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
49 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n",
50 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
51 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
52 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 
'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 
'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n",
53 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
54 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
55 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
56 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n",
57 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
58 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
59 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n",
60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
61 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
62 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 
'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 
'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n",
63 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
64 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
65 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)\n",
71 | "custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "id": "fdadc85b-768d-46c7-aa26-01b1f93fd635",
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "# create some dummy data\n",
82 | "inputs = np.random.randn(2, 2000)\n",
83 | "decoder_input_ids = np.arange(100).reshape(2,50)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "id": "caaff775-e082-419f-b737-b392667469d5",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "# get ground-truth outputs from Transformers 🤗 model\n",
94 | "hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 6,
100 | "id": "1ff2be6b-73ef-49fc-8dbe-e933b5d4f8b0",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "extract_features = custom_model.encode(inputs, output_features=True)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 7,
110 | "id": "0e2b2d35-dee8-4c47-be35-e543b0f9ef03",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "custom_outputs = custom_model(inputs, extract_features=extract_features, decoder_input_ids=decoder_input_ids, output_hidden_states=True)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 8,
120 | "id": "aefb168b-4882-41f1-ba88-39f298835dac",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# define a helper function for our analysis\n",
125 | "def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-9):\n",
126 | " diff = np.abs((a - b)).max()\n",
127 | " if diff <= tol:\n",
128 | " print(f\"✅ Difference between Flax and PyTorch is {diff} (< {tol})\")\n",
129 | " else:\n",
130 | " print(f\"❌ Difference between Flax and PyTorch is {diff} (>= {tol})\")"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 9,
136 | "id": "dea2f7c9-a462-40c1-b09b-7c8499318e87",
137 | "metadata": {},
138 | "outputs": [
139 | {
140 | "name": "stdout",
141 | "output_type": "stream",
142 | "text": [
143 | "--------------------------Checking encoder hidden states match--------------------------\n",
144 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
145 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
146 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
147 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
148 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
149 | "--------------------------Checking encoder last hidden states match--------------------------\n",
150 | "HF output shape: (2, 29, 16), custom output shape: (2, 29, 16)\n",
151 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
152 | "--------------------------Checking decoder hidden states match--------------------------\n",
153 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
154 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
155 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n",
156 | "--------------------------Checking logits match--------------------------\n",
157 | "HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)\n",
158 | "✅ Difference between Flax and PyTorch is 0.0 (< 1e-09)\n"
159 | ]
160 | }
161 | ],
162 | "source": [
163 | "print(\"--------------------------Checking encoder hidden states match--------------------------\")\n",
164 | "for hf_state, custom_state in zip(hf_outputs.encoder_hidden_states, custom_outputs.encoder_hidden_states):\n",
165 | " assert hf_state.shape == custom_state.shape\n",
166 | " assert_almost_equals(hf_state, custom_state)\n",
167 | "\n",
168 | "print(\"--------------------------Checking encoder last hidden states match--------------------------\")\n",
169 | "print(f\"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, custom output shape: {custom_outputs.encoder_last_hidden_state.shape}\")\n",
170 | "assert_almost_equals(hf_outputs.encoder_last_hidden_state, custom_outputs.encoder_last_hidden_state)\n",
171 | "\n",
172 | "print(\"--------------------------Checking decoder hidden states match--------------------------\")\n",
173 | "for hf_state, custom_state in zip(hf_outputs.decoder_hidden_states, custom_outputs.decoder_hidden_states):\n",
174 | " assert hf_state.shape == custom_state.shape\n",
175 | " assert_almost_equals(hf_state, custom_state)\n",
176 | "\n",
177 | "print(\"--------------------------Checking logits match--------------------------\")\n",
178 | "print(f\"HF logits shape: {hf_outputs.logits.shape}, Custom logits shape: {custom_outputs.logits.shape}\")\n",
179 | "assert_almost_equals(hf_outputs.logits, custom_outputs.logits)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "id": "36f0a6f8-fc48-4326-bbaf-4ab5194f84a6",
186 | "metadata": {},
187 | "outputs": [],
188 | "source": []
189 | }
190 | ],
191 | "metadata": {
192 | "kernelspec": {
193 | "display_name": "Python 3 (ipykernel)",
194 | "language": "python",
195 | "name": "python3"
196 | },
197 | "language_info": {
198 | "codemirror_mode": {
199 | "name": "ipython",
200 | "version": 3
201 | },
202 | "file_extension": ".py",
203 | "mimetype": "text/x-python",
204 | "name": "python",
205 | "nbconvert_exporter": "python",
206 | "pygments_lexer": "ipython3",
207 | "version": "3.8.9"
208 | }
209 | },
210 | "nbformat": 4,
211 | "nbformat_minor": 5
212 | }
213 |
--------------------------------------------------------------------------------
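Stripped of the weight-initialisation warnings, the notebook above reduces to the following check (a sketch assembled from its code cells, assuming the repository root is on `PYTHONPATH`): the custom model first returns the feature-encoder outputs via `encode(..., output_features=True)`, those features are passed back in through `extract_features`, and the resulting hidden states and logits are compared against the stock 🤗 Transformers model.

```python
# Sketch assembled from the notebook's code cells (not a new test).
import numpy as np
from transformers import FlaxSpeechEncoderDecoderModel
from models.modeling_flax_speech_encoder_decoder import (
    FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel,
)

encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
decoder_id = "hf-internal-testing/tiny-random-bart"

hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True
)
custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True
)

# dummy audio and decoder inputs, as in the notebook
inputs = np.random.randn(2, 2000)
decoder_input_ids = np.arange(100).reshape(2, 50)

# ground-truth outputs from the stock model
hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)

# custom model: save the feature-encoder output, then feed it back in
extract_features = custom_model.encode(inputs, output_features=True)
custom_outputs = custom_model(
    inputs,
    extract_features=extract_features,
    decoder_input_ids=decoder_input_ids,
    output_hidden_states=True,
)

# the notebook reports a maximum absolute difference of 0.0 for all of these
assert np.abs(hf_outputs.encoder_last_hidden_state - custom_outputs.encoder_last_hidden_state).max() < 1e-9
assert np.abs(hf_outputs.logits - custom_outputs.logits).max() < 1e-9
```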
/tests/.ipynb_checkpoints/check_save_feature_encoder_output-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "b1db3741-29dc-4d1f-943c-ecf2ede5ad66",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stderr",
11 | "output_type": "stream",
12 | "text": [
13 | "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n",
14 | " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "from transformers import FlaxSpeechEncoderDecoderModel\n",
20 | "from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel as CustomFlaxSpeechEncoderDecoderModel\n",
21 | "import numpy as np"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "id": "d67ee694-b5e6-4c98-92f1-3f214500cf55",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'\n",
32 | "decoder_id = 'hf-internal-testing/tiny-random-bart'"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "id": "5b30237b-080d-4447-a19b-4f11c11603b3",
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stderr",
43 | "output_type": "stream",
44 | "text": [
45 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
46 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n",
47 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
48 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
49 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n",
50 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
51 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
52 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 
'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 
'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n",
53 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
54 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
55 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
56 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 were not used when initializing FlaxWav2Vec2Model: {('quantizer', 'weight_proj', 'bias'), ('lm_head', 'bias'), ('project_q', 'bias'), ('quantizer', 'codevectors'), ('quantizer', 'weight_proj', 'kernel'), ('lm_head', 'kernel'), ('project_hid', 'bias'), ('project_q', 'kernel'), ('project_hid', 'kernel')}\n",
57 | "- This IS expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
58 | "- This IS NOT expected if you are initializing FlaxWav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
59 | "Some weights of FlaxWav2Vec2Model were not initialized from the model checkpoint at hf-internal-testing/tiny-random-wav2vec2 and are newly initialized: {('feature_extractor', 'conv_layers', '2', 'layer_norm', 'bias'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '2', 'layer_norm', 'scale'), ('feature_extractor', 'conv_layers', '1', 'layer_norm', 'bias')}\n",
60 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
61 | "You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.\n",
62 | "Some weights of the model checkpoint at hf-internal-testing/tiny-random-bart were not used when initializing FlaxBartForCausalLM: {('encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('encoder', 'layers', '0', 'fc1', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'shared', 'kernel'), ('shared', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('final_logits_bias',), ('decoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'layernorm_embedding', 'kernel'), ('encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'fc2', 'bias'), ('encoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('encoder', 'layernorm_embedding', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('lm_head', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '1', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('decoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('encoder', 'layernorm_embedding', 'bias'), ('decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 
'bias'), ('model', 'encoder', 'layers', '1', 'fc2', 'bias'), ('model', 'encoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc1', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('model', 'encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layernorm_embedding', 'scale'), ('decoder', 'embed_positions', 'embedding'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc1', 'kernel'), ('decoder', 'layernorm_embedding', 'bias'), ('classification_head', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('encoder', 'layers', '0', 'fc2', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('encoder', 'embed_positions', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'bias'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'bias'), ('qa_outputs', 'bias'), ('classification_head', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'fc1', 'kernel'), ('decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', '0', 'fc2', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'final_layer_norm', 'kernel'), ('qa_outputs', 'kernel'), ('classification_head', 'dense', 'bias'), ('decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('encoder', 'layers', '1', 'self_attn_layer_norm', 'kernel'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('decoder', 'layers', '1', 
'final_layer_norm', 'scale'), ('encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'encoder', 'layers', '0', 'fc2', 'kernel'), ('decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('classification_head', 'dense', 'kernel'), ('encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('encoder', 'layers', '0', 'self_attn_layer_norm', 'kernel'), ('decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'encoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel')}\n",
63 | "- This IS expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
64 | "- This IS NOT expected if you are initializing FlaxBartForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
65 | "You passed along `num_labels=3` with an incompatible id to label map: {0: 'LABEL_0', 1: 'LABEL_1'}. The number of labels wil be overwritten to 2.\n"
66 | ]
67 | }
68 | ],
69 | "source": [
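    "# instantiate the reference HF model and the custom model from the same pre-trained encoder and decoder checkpoints\n",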
70 | "hf_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)\n",
71 | "custom_model = CustomFlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "id": "fdadc85b-768d-46c7-aa26-01b1f93fd635",
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "# create some dummy data\n",
82 | "inputs = np.random.randn(2, 2000)\n",
    83 |     "decoder_input_ids = np.arange(100).reshape(2, 50)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 5,
89 | "id": "caaff775-e082-419f-b737-b392667469d5",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "# get ground-truth outputs from Transformers 🤗 model\n",
94 | "hf_outputs = hf_model(inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 6,
100 | "id": "1ff2be6b-73ef-49fc-8dbe-e933b5d4f8b0",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
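    "# encode the inputs with the custom model; output_features=True returns the feature-encoder outputs used below\n",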
104 | "extract_features = custom_model.encode(inputs, output_features=True)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 7,
110 | "id": "0e2b2d35-dee8-4c47-be35-e543b0f9ef03",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
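    "# forward pass of the custom model, feeding the pre-computed feature-encoder outputs via extract_features\n",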
114 | "custom_outputs = custom_model(inputs, extract_features=extract_features, decoder_input_ids=decoder_input_ids, output_hidden_states=True)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 8,
120 | "id": "aefb168b-4882-41f1-ba88-39f298835dac",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# define a helper function for our analysis\n",
125 | "def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-9):\n",
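    " # compare the max absolute element-wise difference against the tolerance and print a ✅/❌ verdict (does not raise)\n",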
126 | " diff = np.abs((a - b)).max()\n",
127 | " if diff <= tol:\n",
    128 |     " print(f\"✅ Difference between HF and custom model outputs is {diff} (< {tol})\")\n",
129 | " else:\n",
    130 |     " print(f\"❌ Difference between HF and custom model outputs is {diff} (>= {tol})\")\n",
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 9,
136 | "id": "dea2f7c9-a462-40c1-b09b-7c8499318e87",
137 | "metadata": {},
138 | "outputs": [
139 | {
140 | "name": "stdout",
141 | "output_type": "stream",
142 | "text": [
143 | "--------------------------Checking encoder hidden states match--------------------------\n",
    144 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    145 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    146 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    147 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    148 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
149 | "--------------------------Checking encoder last hidden states match--------------------------\n",
150 | "HF output shape: (2, 29, 16), custom output shape: (2, 29, 16)\n",
    151 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
152 | "--------------------------Checking decoder hidden states match--------------------------\n",
    153 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    154 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
    155 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n",
156 | "--------------------------Checking logits match--------------------------\n",
157 | "HF logits shape: (2, 50, 1000), Custom logits shape: (2, 50, 1000)\n",
    158 |     "✅ Difference between HF and custom model outputs is 0.0 (< 1e-09)\n"
159 | ]
160 | }
161 | ],
162 | "source": [
163 | "print(\"--------------------------Checking encoder hidden states match--------------------------\")\n",
164 | "for hf_state, custom_state in zip(hf_outputs.encoder_hidden_states, custom_outputs.encoder_hidden_states):\n",
165 | " assert hf_state.shape == custom_state.shape\n",
166 | " assert_almost_equals(hf_state, custom_state)\n",
167 | "\n",
168 | "print(\"--------------------------Checking encoder last hidden states match--------------------------\")\n",
169 | "print(f\"HF output shape: {hf_outputs.encoder_last_hidden_state.shape}, custom output shape: {custom_outputs.encoder_last_hidden_state.shape}\")\n",
170 | "assert_almost_equals(hf_outputs.encoder_last_hidden_state, custom_outputs.encoder_last_hidden_state)\n",
171 | "\n",
172 | "print(\"--------------------------Checking decoder hidden states match--------------------------\")\n",
173 | "for hf_state, custom_state in zip(hf_outputs.decoder_hidden_states, custom_outputs.decoder_hidden_states):\n",
174 | " assert hf_state.shape == custom_state.shape\n",
175 | " assert_almost_equals(hf_state, custom_state)\n",
176 | "\n",
177 | "print(\"--------------------------Checking logits match--------------------------\")\n",
178 | "print(f\"HF logits shape: {hf_outputs.logits.shape}, Custom logits shape: {custom_outputs.logits.shape}\")\n",
179 | "assert_almost_equals(hf_outputs.logits, custom_outputs.logits)"
180 | ]
    181 |   }
190 | ],
191 | "metadata": {
192 | "kernelspec": {
193 | "display_name": "Python 3 (ipykernel)",
194 | "language": "python",
195 | "name": "python3"
196 | },
197 | "language_info": {
198 | "codemirror_mode": {
199 | "name": "ipython",
200 | "version": 3
201 | },
202 | "file_extension": ".py",
203 | "mimetype": "text/x-python",
204 | "name": "python",
205 | "nbconvert_exporter": "python",
206 | "pygments_lexer": "ipython3",
207 | "version": "3.8.9"
208 | }
209 | },
210 | "nbformat": 4,
211 | "nbformat_minor": 5
212 | }
213 |
--------------------------------------------------------------------------------