├── config
│   ├── dataset
│   │   ├── common_voice_czech.yaml
│   │   ├── common_voice_dutch.yaml
│   │   ├── common_voice_greek.yaml
│   │   ├── common_voice_hindi.yaml
│   │   ├── common_voice_tamil.yaml
│   │   ├── common_voice_thai.yaml
│   │   ├── common_voice_arabic.yaml
│   │   ├── common_voice_catalan.yaml
│   │   ├── common_voice_estonian.yaml
│   │   ├── common_voice_finnish.yaml
│   │   ├── common_voice_georgian.yaml
│   │   ├── common_voice_hungarian.yaml
│   │   ├── common_voice_indonesian.yaml
│   │   ├── common_voice_italian.yaml
│   │   ├── common_voice_japanese.yaml
│   │   ├── common_voice_latvian.yaml
│   │   ├── common_voice_lithuanian.yaml
│   │   ├── common_voice_portuguese.yaml
│   │   ├── common_voice_romanian.yaml
│   │   ├── common_voice_slovenian.yaml
│   │   ├── common_voice_turkish.yaml
│   │   ├── common_voice_ukranian.yaml
│   │   ├── common_voice_vietnamese.yaml
│   │   ├── mls_dutch_10h.yaml
│   │   ├── mls_dutch_1h.yaml
│   │   ├── mls_french_1h.yaml
│   │   ├── mls_german_1h.yaml
│   │   ├── mls_polish_1h.yaml
│   │   ├── mls_english_1h.yaml
│   │   ├── mls_french_10h.yaml
│   │   ├── mls_german_10h.yaml
│   │   ├── mls_italian_10h.yaml
│   │   ├── mls_italian_1h.yaml
│   │   ├── mls_polish_10h.yaml
│   │   ├── mls_spanish_10h.yaml
│   │   ├── mls_spanish_1h.yaml
│   │   ├── mls_portuguese_10h.yaml
│   │   └── mls_portuguese_1h.yaml
│   ├── distill
│   │   ├── vanilla.yaml
│   │   ├── random_init.yaml
│   │   ├── infoxlm_base_head.yaml
│   │   ├── xlm_roberta_base_body.yaml
│   │   ├── xlm_roberta_base_head.yaml
│   │   ├── xlm_roberta_base_tail.yaml
│   │   ├── infoxlm_large_head.yaml
│   │   ├── xlm_roberta_large_head.yaml
│   │   ├── xlm_roberta_base_headx2.yaml
│   │   ├── xlm_roberta_base_headx3.yaml
│   │   ├── interpolate_linear.yaml
│   │   ├── interpolate_nearest.yaml
│   │   ├── interpolate_linear_nonfilter.yaml
│   │   ├── interpolate_nearest_nonfilter.yaml
│   │   └── shrink.yaml
│   ├── xlsr
│   │   └── w2v2_xlsr.yaml
│   └── train
│       ├── v1.yaml
│       └── v1_large.yaml
├── resources
│   └── figure.png
├── requirements.txt
├── LICENSE
├── README.md
├── create_mls_csv.py
├── train.py
├── model_utils.py
└── data_utils.py
/config/dataset/common_voice_czech.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: cs 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_dutch.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: nl 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_greek.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: el 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_hindi.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: hi 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_tamil.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ta 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_thai.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: th 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_arabic.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ar 3 |
-------------------------------------------------------------------------------- /config/dataset/common_voice_catalan.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ca 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_estonian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: et 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_finnish.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: fi 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_georgian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ka 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_hungarian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: hu 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_indonesian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: id 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_italian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: it 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_japanese.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ja 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_latvian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: lv 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_lithuanian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: lt 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_portuguese.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: pt 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_romanian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: ro 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_slovenian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: sl 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_turkish.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: tr 3 | 
-------------------------------------------------------------------------------- /config/dataset/common_voice_ukranian.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: uk 3 | -------------------------------------------------------------------------------- /config/dataset/common_voice_vietnamese.yaml: -------------------------------------------------------------------------------- 1 | name: common_voice 2 | language: vi 3 | -------------------------------------------------------------------------------- /config/dataset/mls_dutch_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: dutch 3 | -------------------------------------------------------------------------------- /config/dataset/mls_dutch_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: dutch 3 | -------------------------------------------------------------------------------- /config/dataset/mls_french_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: french 3 | -------------------------------------------------------------------------------- /config/dataset/mls_german_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: german 3 | -------------------------------------------------------------------------------- /config/dataset/mls_polish_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: polish 3 | -------------------------------------------------------------------------------- /config/dataset/mls_english_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: english 3 | -------------------------------------------------------------------------------- /config/dataset/mls_french_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: french 3 | -------------------------------------------------------------------------------- /config/dataset/mls_german_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: german 3 | -------------------------------------------------------------------------------- /config/dataset/mls_italian_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: italian 3 | -------------------------------------------------------------------------------- /config/dataset/mls_italian_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: italian 3 | -------------------------------------------------------------------------------- /config/dataset/mls_polish_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: polish 3 | -------------------------------------------------------------------------------- /config/dataset/mls_spanish_10h.yaml: 
-------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: spanish 3 | -------------------------------------------------------------------------------- /config/dataset/mls_spanish_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: spanish 3 | -------------------------------------------------------------------------------- /resources/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juice500ml/xlm_to_xlsr/HEAD/resources/figure.png -------------------------------------------------------------------------------- /config/dataset/mls_portuguese_10h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_10h 2 | language: portuguese 3 | -------------------------------------------------------------------------------- /config/dataset/mls_portuguese_1h.yaml: -------------------------------------------------------------------------------- 1 | name: multilingual_librispeech_1h 2 | language: portuguese 3 | -------------------------------------------------------------------------------- /config/distill/vanilla.yaml: -------------------------------------------------------------------------------- 1 | name: vanilla 2 | lm_name: xlm-roberta-base 3 | attn_loss: 0.0 4 | feat_loss: 0.0 5 | random_init: False 6 | -------------------------------------------------------------------------------- /config/distill/random_init.yaml: -------------------------------------------------------------------------------- 1 | name: random_init 2 | lm_name: xlm-roberta-base 3 | attn_loss: 0.0 4 | feat_loss: 0.0 5 | random_init: True 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 2 | torchaudio==0.8.0 3 | datasets==1.13.3 4 | transformers==4.11.3 5 | librosa 6 | jiwer 7 | tensorboard==2.7.0 8 | hydra-core==1.1.1 9 | -------------------------------------------------------------------------------- /config/xlsr/w2v2_xlsr.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: facebook/wav2vec2-large-xlsr-53 2 | ctc_loss_reduction: mean 3 | attention_dropout: 0.1 4 | hidden_dropout: 0.1 5 | feat_proj_dropout: 0.0 6 | mask_time_prob: 0.05 7 | layerdrop: 0.1 8 | -------------------------------------------------------------------------------- /config/distill/infoxlm_base_head.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: infoxlm_base_head 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_base_body.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-base: 13 11 | 
name: xlm_roberta_base_body 12 | feat_target: 13 | - lm_index: 6 14 | sm_index: 12 15 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_base_head.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-base: 13 11 | name: xlm_roberta_base_head 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_base_tail.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-base: 13 11 | name: xlm_roberta_base_tail 12 | feat_target: 13 | - lm_index: 0 14 | sm_index: 0 15 | -------------------------------------------------------------------------------- /config/distill/infoxlm_large_head.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-large 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 1024 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, infoxlm-large: 25 11 | name: infoxlm_large_head 12 | feat_target: 13 | - lm_index: 24 14 | sm_index: 24 15 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_large_head.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-large 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 1024 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-large: 25 11 | name: xlm_roberta_large_head 12 | feat_target: 13 | - lm_index: 24 14 | sm_index: 24 15 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_base_headx2.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-base: 13 11 | name: xlm_roberta_base_headx2 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | - lm_index: 11 16 | sm_index: 23 17 | -------------------------------------------------------------------------------- /config/distill/xlm_roberta_base_headx3.yaml: -------------------------------------------------------------------------------- 1 | lm_name: xlm-roberta-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, xlm-roberta-base: 13 11 | name: xlm_roberta_base_headx3 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | - lm_index: 11 16 | sm_index: 23 17 | - lm_index: 10 18 | sm_index: 22 19 | -------------------------------------------------------------------------------- /config/distill/interpolate_linear.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 |
sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 0.25 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: interpolate_v2_linear 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | 16 | # feature interpolation 17 | interpolation: 18 | mode: linear # nearest, linear 19 | filter_out_pad: True 20 | -------------------------------------------------------------------------------- /config/distill/interpolate_nearest.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 0.25 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: interpolate_v2_nearest 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | 16 | # feature interpolation 17 | interpolation: 18 | mode: nearest # nearest, linear 19 | filter_out_pad: True 20 | -------------------------------------------------------------------------------- /config/distill/interpolate_linear_nonfilter.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 0.25 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: interpolate_v2_linear_nonfilter 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | 16 | # feature interpolation 17 | interpolation: 18 | mode: linear # nearest, linear 19 | filter_out_pad: False 20 | -------------------------------------------------------------------------------- /config/distill/interpolate_nearest_nonfilter.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 0.25 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: interpolate_v2_nearest_nonfilter 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | 16 | # feature interpolation 17 | interpolation: 18 | mode: nearest # nearest, linear 19 | filter_out_pad: False 20 | -------------------------------------------------------------------------------- /config/distill/shrink.yaml: -------------------------------------------------------------------------------- 1 | lm_name: microsoft/infoxlm-base 2 | lm_attn_size: 100 3 | sm_attn_size: 499 4 | lm_feat_size: 768 5 | sm_feat_size: 1024 6 | 7 | attn_loss: 0.0 8 | feat_loss: 1.0 9 | 10 | # w2v2: 25, infoxlm-base: 13 11 | name: shrink 12 | feat_target: 13 | - lm_index: 12 14 | sm_index: 24 15 | 16 | # feature interpolation 17 | interpolation: 18 | mode: nearest # nearest, linear 19 | filter_out_pad: False 20 | shrink: True 21 | 22 | random_init: False 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kwanghee Choi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 
| The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/train/v1.yaml: -------------------------------------------------------------------------------- 1 | group_by_length: True 2 | logging_steps: 10 3 | do_train: True 4 | do_eval: True 5 | do_predict: True 6 | seed: 42 7 | 8 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L10-L14 9 | evaluation_strategy: steps 10 | save_strategy: steps 11 | save_steps: 2000 12 | eval_steps: 2000 13 | metric_for_best_model: wer 14 | greater_is_better: False 15 | save_total_limit: 2 16 | 17 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L4 18 | fp16: True 19 | fp16_full_eval: True 20 | 21 | # https://github.com/pytorch/fairseq/tree/7f5ec30/examples/wav2vec/xlsr 22 | per_device_train_batch_size: 12 # Need 2 GPUs for batch_size=24 23 | per_device_eval_batch_size: 16 24 | gradient_accumulation_steps: 1 25 | 26 | # Inside the paper: 27 | # We determine the best learning rates setting in [2e-5, 6e-5] based on dev set error rate 28 | learning_rate: 2e-5 29 | 30 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L44-L47 31 | adam_beta1: 0.9 32 | adam_beta2: 0.98 33 | adam_epsilon: 1e-08 34 | 35 | # Inside the paper: 36 | # warm up for the first 10% of updates, 37 | # keep constant for 40% and then linearly decay for the remainder. 38 | # For CommonVoice, we fine-tune for 20k updates. 
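# (10% of the 20k updates = the 2,000 warmup_steps set below.)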
39 | # Most similar implementation w/o major code fixes: cosine + warmup 40 | max_steps: 20000 41 | warmup_steps: 2000 42 | lr_scheduler_type: cosine 43 | -------------------------------------------------------------------------------- /config/train/v1_large.yaml: -------------------------------------------------------------------------------- 1 | group_by_length: True 2 | logging_steps: 10 3 | do_train: True 4 | do_eval: True 5 | do_predict: True 6 | seed: 42 7 | 8 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L10-L14 9 | evaluation_strategy: steps 10 | save_strategy: steps 11 | save_steps: 2000 12 | eval_steps: 2000 13 | metric_for_best_model: wer 14 | greater_is_better: False 15 | save_total_limit: 2 16 | 17 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L4 18 | fp16: True 19 | fp16_full_eval: True 20 | 21 | # https://github.com/pytorch/fairseq/tree/7f5ec30/examples/wav2vec/xlsr 22 | per_device_train_batch_size: 3 # Need 8 GPUs for batch_size=24 23 | per_device_eval_batch_size: 6 24 | gradient_accumulation_steps: 1 25 | 26 | # Inside the paper: 27 | # We determine the best learning rates setting in [2e-5, 6e-5] based on dev set error rate 28 | learning_rate: 2e-5 29 | 30 | # https://github.com/pytorch/fairseq/blob/7f5ec30/examples/wav2vec/xlsr/config/finetune.yaml#L44-L47 31 | adam_beta1: 0.9 32 | adam_beta2: 0.98 33 | adam_epsilon: 1e-08 34 | 35 | # Inside the paper: 36 | # warm up for the first 10% of updates, 37 | # keep constant for 40% and then linearly decay for the remainder. 38 | # For CommonVoice, we fine-tune for 20k updates. 39 | # Most similar implementation w/o major code fixes: cosine + warmup 40 | max_steps: 20000 41 | warmup_steps: 2000 42 | lr_scheduler_type: cosine 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distilling a Pretrained Language Model to a Multilingual ASR Model 2 | ![plot](./resources/figure.png) 3 | - Official implementation of the paper: https://arxiv.org/abs/2206.12638 4 | - Accepted to Interspeech 2022. 5 | 6 | ## Oral presentation @ Interspeech 7 | presentation 10 | 11 | 12 | ## How to run experiments (Table 1) 13 | **Environments** 14 | - I used Python 3.8.12. 15 | - Check [requirements.txt](./requirements.txt) for additional requirements. 16 | 17 | **Supported datasets** 18 | - Check [configs](config/dataset) for supported datasets. 19 | - For example, if you want CommonVoice Czech, set `$dataset` as `common_voice_czech`. 20 | 21 | **From scratch** 22 | ```bash 23 | # If you change the # of GPUs, you have to fix per_device_train_batch_size in training config. 24 | CUDA_VISIBLE_DEVICES=0,1 python3 train.py \ 25 | +distill=random_init \ 26 | +dataset=$dataset \ 27 | +train=v1 \ 28 | +xlsr=w2v2_xlsr 29 | ``` 30 | **Fine-tuning** 31 | ```bash 32 | CUDA_VISIBLE_DEVICES=0,1 python3 train.py \ 33 | +distill=vanilla \ 34 | +dataset=$dataset \ 35 | +train=v1 \ 36 | +xlsr=w2v2_xlsr 37 | ``` 38 | **Fine-tuning + Distill-L2S** 39 | ```bash 40 | # You have to set $lambda as the trade-off hyperparameter, i.e., 0.25, 0.5 or 1.0. 
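# e.g. lambda=0.5 makes the final override in the command below distill.feat_loss=0.5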
41 | CUDA_VISIBLE_DEVICES=0,1 python3 train.py \ 42 | +distill=shrink \ 43 | +dataset=$dataset \ 44 | +train=v1 \ 45 | +xlsr=w2v2_xlsr \ 46 | distill.feat_loss=$lambda 47 | ``` 48 | -------------------------------------------------------------------------------- /create_mls_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | 6 | 7 | def _read_transcription(foldername): 8 | df = pd.read_csv(foldername / 'transcripts.txt', sep='\t', names=['audio', 'sentence']) 9 | df.audio = df.audio.apply(lambda s: f'{foldername}/audio/{"/".join(s.split("_")[:2])}/{s}.wav') 10 | return df 11 | 12 | 13 | def _prepare_handles(mls_root, hour): 14 | assert hour in (1, 10) 15 | one_handles = mls_root / 'train' / 'limited_supervision' / '1hr' 16 | s = set() 17 | 18 | for f in one_handles.glob('*/handles.txt'): 19 | s.update([line.strip() for line in open(f).readlines()]) 20 | 21 | if hour == 10: 22 | nine_handle = mls_root / 'train' / 'limited_supervision' / '9hr' / 'handles.txt' 23 | s.update([line.strip() for line in open(nine_handle).readlines()]) 24 | 25 | return s 26 | 27 | 28 | def _remove_handle_mask(df, handles): 29 | return df.apply( 30 | lambda x: x.audio.split('/')[-1].split('.')[0] in handles, axis=1 31 | ) 32 | 33 | 34 | if __name__ == '__main__': 35 | mls_root = Path(os.environ['MLS_OPUS_ROOT']) 36 | assert mls_root.exists() 37 | 38 | dataset_name = mls_root.name.split('_opus')[0] 39 | print(f'Create dataset: {dataset_name}') 40 | 41 | dataset_csv_root = Path('dataset_csv') 42 | dataset_csv_root.mkdir(exist_ok=True) 43 | 44 | for split in ('train', 'dev', 'test'): 45 | df = _read_transcription(mls_root / split) 46 | df.to_csv(dataset_csv_root / f'{dataset_name}_{split}.csv', index=False) 47 | print(f'{split}: {len(df)}') 48 | if split == 'train': 49 | for hour in (1, 10): 50 | hour_handles = _prepare_handles(mls_root, hour) 51 | hour_df = df[_remove_handle_mask(df, hour_handles)] 52 | assert len(hour_handles) == len(hour_df) 53 | hour_df.to_csv(dataset_csv_root / f'{dataset_name}_{split}_{hour}h.csv', index=False) 54 | print(f'{split}_{hour}h: {len(hour_df)}') 55 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | 4 | import hydra 5 | import numpy as np 6 | from datasets import load_metric 7 | from omegaconf import OmegaConf 8 | from transformers import (AutoTokenizer, Trainer, TrainingArguments) 9 | 10 | from data_utils import (get_output_dir, get_processor, load_datasets) 11 | from model_utils import Wav2Vec2ForDistill, DistillTrainer 12 | 13 | 14 | def get_compute_metrics(processor): 15 | wer_metric = load_metric("wer") 16 | cer_metric = load_metric("cer") 17 | 18 | def _compute_metrics(pred): 19 | pred_logits = pred.predictions 20 | pred_ids = np.argmax(pred_logits, axis=-1) 21 | 22 | pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id 23 | 24 | pred_str = processor.batch_decode(pred_ids) 25 | 26 | # we do not want to group tokens when computing the metrics 27 | label_str = processor.batch_decode(pred.label_ids, group_tokens=False) 28 | 29 | return { 30 | "wer": wer_metric.compute(predictions=pred_str, references=label_str), 31 | "cer": cer_metric.compute(predictions=pred_str, references=label_str), 32 | } 33 | return _compute_metrics 34 | 35 | 36 | @hydra.main(config_path="config") 37 | 
def main(cfg): 38 | output_dir = get_output_dir(cfg) 39 | (train_ds, eval_ds, test_ds), cleanser, collator = load_datasets(**cfg.dataset) 40 | 41 | (output_dir / "processor").mkdir(exist_ok=False, parents=False) 42 | processor = get_processor(output_dir / "processor", train_ds, eval_ds) 43 | lm_tokenizer = AutoTokenizer.from_pretrained(cfg.distill.lm_name) 44 | 45 | _cleanse_ds = partial(cleanser, processor=processor, lm_tokenizer=lm_tokenizer) 46 | train_ds, eval_ds, test_ds = _cleanse_ds(train_ds), _cleanse_ds(eval_ds), _cleanse_ds(test_ds) 47 | print(f"Preparing done: {len(train_ds)}, {len(eval_ds)}, {len(test_ds)}") 48 | 49 | data_collator = collator(processor=processor, lm_tokenizer=lm_tokenizer) 50 | 51 | model = Wav2Vec2ForDistill.from_pretrained( 52 | **cfg.xlsr, 53 | pad_token_id=processor.tokenizer.pad_token_id, 54 | vocab_size=len(processor.tokenizer), 55 | task_specific_params=OmegaConf.to_container(cfg.distill, resolve=True) 56 | ) 57 | if cfg.distill.random_init: 58 | model.apply(model._init_weights) 59 | else: 60 | model.freeze_feature_extractor() 61 | 62 | training_args = TrainingArguments( 63 | **cfg.train, 64 | output_dir=output_dir, 65 | load_best_model_at_end=True, 66 | ) 67 | trainer = DistillTrainer( 68 | model=model, 69 | data_collator=data_collator, 70 | args=training_args, 71 | compute_metrics=get_compute_metrics(processor), 72 | train_dataset=train_ds, 73 | eval_dataset=eval_ds, 74 | tokenizer=processor.feature_extractor, 75 | ) 76 | trainer.train() 77 | 78 | preds = trainer.predict(test_dataset=test_ds) 79 | trainer.log(preds.metrics) 80 | 81 | (output_dir / "outputs").mkdir(exist_ok=True, parents=True) 82 | with open(output_dir / "outputs" / "metrics.json", "w") as f: 83 | json.dump(preds.metrics, f) 84 | np.save(output_dir / "outputs" / "preds.pkl", preds.predictions) 85 | np.save(output_dir / "outputs" / "label_ids.pkl", preds.label_ids) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import AutoModelForMaskedLM, Wav2Vec2PreTrainedModel, Wav2Vec2ForCTC, Wav2Vec2Model, Trainer 8 | from transformers.file_utils import ModelOutput 9 | 10 | 11 | class DistillTrainer(Trainer): 12 | def compute_loss(self, model, inputs, return_outputs=False): 13 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True) 14 | 15 | log_data = {'ctc_loss': outputs.ctc_loss.mean().item()} 16 | if outputs.feat_loss is not None: 17 | log_data['feat_loss'] = outputs.feat_loss.mean().item() 18 | self.log(log_data) 19 | 20 | return (loss, {'logits': outputs.logits}) if return_outputs else loss 21 | 22 | 23 | @dataclass 24 | class DistillLMOutput(ModelOutput): 25 | loss: Optional[torch.FloatTensor] = None 26 | ctc_loss: Optional[torch.FloatTensor] = None 27 | feat_loss: Optional[torch.FloatTensor] = None 28 | logits: torch.FloatTensor = None 29 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 30 | attentions: Optional[Tuple[torch.FloatTensor]] = None 31 | 32 | 33 | class Wav2Vec2ForDistill(Wav2Vec2ForCTC): 34 | def __init__(self, config): 35 | super(Wav2Vec2PreTrainedModel, self).__init__(config) 36 | 37 | self.wav2vec2 = Wav2Vec2Model(config) 38 | self.dropout = nn.Dropout(config.final_dropout) 39 | self.lm_head = 
nn.Linear(config.hidden_size, config.vocab_size) 40 | self._vocab_size = config.vocab_size 41 | 42 | cfg = config.task_specific_params 43 | self._train_feat_loss = cfg['feat_loss'] > 0.0 44 | self._train_attn_loss = cfg['attn_loss'] > 0.0 45 | self._feat_loss_weight = cfg['feat_loss'] 46 | self._attn_loss_weight = cfg['attn_loss'] 47 | 48 | self.lm = None 49 | if self._train_feat_loss or self._train_attn_loss: 50 | self.lm = AutoModelForMaskedLM.from_pretrained(cfg['lm_name']).eval() 51 | 52 | if self._train_feat_loss: 53 | self._interpolation_do_filter = cfg['interpolation']['filter_out_pad'] 54 | self._interpolation_do_shrink = cfg['interpolation']['shrink'] 55 | self.temporal_adapter_kwargs = dict( 56 | mode=cfg['interpolation']['mode'], 57 | align_corners=True if cfg['interpolation']['mode'] == 'linear' else None) 58 | self.feat_adapter = nn.Parameter(torch.empty( 59 | cfg['sm_feat_size'], cfg['lm_feat_size'])) 60 | nn.init.kaiming_uniform_(self.feat_adapter, a=math.sqrt(5)) 61 | 62 | assert len(cfg['feat_target']) == 1 63 | self.feat_adapter_config = { 64 | 'lm_index': cfg['feat_target'][0]['lm_index'], 'sm_index': cfg['feat_target'][0]['sm_index']} 65 | 66 | if self._train_attn_loss: 67 | raise NotImplemented 68 | 69 | def forward( 70 | self, 71 | input_values, 72 | attention_mask=None, 73 | output_attentions=None, 74 | output_hidden_states=None, 75 | return_dict=None, 76 | labels=None, 77 | lm_input_ids=None, 78 | lm_attention_mask=None, 79 | ): 80 | outputs = self.wav2vec2( 81 | input_values, 82 | attention_mask=attention_mask, 83 | output_attentions=output_attentions or self._train_attn_loss, 84 | output_hidden_states=output_hidden_states or self._train_feat_loss, 85 | return_dict=return_dict, 86 | ) 87 | 88 | hidden_states = outputs[0] 89 | hidden_states = self.dropout(hidden_states) 90 | 91 | logits = self.lm_head(hidden_states) 92 | 93 | loss = None 94 | ctc_loss = None 95 | feat_loss = None 96 | 97 | if labels is not None: 98 | # retrieve loss input_lengths from attention_mask 99 | attention_mask = ( 100 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) 101 | ) 102 | input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 103 | 104 | # assuming that padded tokens are filled with -100 105 | # when not being attended to 106 | labels_mask = labels >= 0 107 | target_lengths = labels_mask.sum(-1) 108 | flattened_targets = labels.masked_select(labels_mask) 109 | 110 | # ctc_loss doesn't support fp16 111 | log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) 112 | 113 | with torch.backends.cudnn.flags(enabled=False): 114 | loss = ctc_loss = nn.functional.ctc_loss( 115 | log_probs, 116 | flattened_targets, 117 | input_lengths, 118 | target_lengths, 119 | blank=self.config.pad_token_id, 120 | reduction=self.config.ctc_loss_reduction, 121 | zero_infinity=self.config.ctc_zero_infinity, 122 | ) 123 | 124 | if self.lm: 125 | lm_outputs = self.lm( 126 | input_ids=lm_input_ids, 127 | attention_mask=lm_attention_mask, 128 | output_attentions=self._train_attn_loss, 129 | output_hidden_states=self._train_feat_loss, 130 | ) 131 | 132 | if self._train_feat_loss: 133 | feat_loss = 0.0 134 | sm_feat = outputs['hidden_states'][self.feat_adapter_config['sm_index']] 135 | lm_feat = lm_outputs['hidden_states'][self.feat_adapter_config['lm_index']] 136 | 137 | for sm_logit, sm_length, sm_f, lm_mask, lm_f in \ 138 | zip(logits, input_lengths, sm_feat, lm_attention_mask, lm_feat): 
139 | 140 | # Generate sm_mask to filter out speech features 141 | logit_mask = torch.ones(sm_logit.shape[0], dtype=bool, device=sm_length.device) 142 | logit_mask[sm_length:] = False 143 | 144 | sm_f = sm_f[logit_mask] 145 | sm_logit = sm_logit[logit_mask] 146 | lm_f = lm_f[lm_mask.bool()] 147 | 148 | if self._interpolation_do_filter: 149 | sm_mask = (sm_logit.argmax(1) < (self._vocab_size - 2)) 150 | if sm_mask.sum() > 1: 151 | sm_f = sm_f[sm_mask] 152 | 153 | if self._interpolation_do_shrink: 154 | sm_f = self._shrink(sm_logit.argmax(1), sm_f) 155 | 156 | # Feature interpolation (SM -> LM) 157 | feature_adapted_sm_f = torch.tensordot(sm_f, self.feat_adapter, dims=([1], [0])) 158 | 159 | # Time interpolation (LM -> SM) 160 | time_adapted_lm_f = nn.functional.interpolate( 161 | input=torch.unsqueeze(lm_f, 0).permute(0, 2, 1), 162 | size=sm_f.shape[0], 163 | **self.temporal_adapter_kwargs) 164 | time_adapted_lm_f = time_adapted_lm_f.squeeze().permute(1, 0) 165 | 166 | # MSE Loss 167 | feat_loss += nn.functional.mse_loss( 168 | feature_adapted_sm_f, time_adapted_lm_f, reduction='mean', 169 | ) 170 | 171 | loss += self._feat_loss_weight * feat_loss / logits.shape[0] 172 | 173 | return DistillLMOutput( 174 | loss=loss, 175 | ctc_loss=ctc_loss, 176 | feat_loss=feat_loss, 177 | logits=logits, 178 | hidden_states=outputs.hidden_states if output_hidden_states else None, 179 | attentions=outputs.attentions if output_attentions else None, 180 | ) 181 | 182 | def _shrink(self, logit_max, feats): 183 | aligned_feats = [] 184 | 185 | i = 0 186 | while i < len(logit_max): 187 | j = 1 188 | while (i + j) < len(logit_max) and logit_max[i + j].item() == logit_max[i].item(): 189 | j += 1 190 | if logit_max[i].item() < self._vocab_size - 2: 191 | aligned_feats.append(feats[i:i + j].mean(0)) 192 | i += j 193 | 194 | if len(aligned_feats) > 1: 195 | return torch.stack(aligned_feats) 196 | else: 197 | return feats 198 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import sys 4 | import unicodedata 5 | from collections import Counter 6 | from dataclasses import dataclass 7 | from functools import partial 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import pandas as pd 12 | from datasets import Audio, load_dataset 13 | from transformers import (AutoTokenizer, Wav2Vec2CTCTokenizer, 14 | Wav2Vec2FeatureExtractor, Wav2Vec2Processor) 15 | 16 | 17 | def get_output_dir(cfg): 18 | p = Path(f"./runs/{cfg.dataset.name}-{cfg.dataset.language}-{cfg.distill.name}-{cfg.distill.feat_loss}") 19 | assert not p.exists() 20 | p.mkdir(parents=True) 21 | return p 22 | 23 | 24 | def load_datasets(name, language): 25 | def _common_voice_process(sp, lang): 26 | ds = load_dataset( 27 | "common_voice", 28 | lang, 29 | split=sp, 30 | cache_dir="/data/dataset/public/huggingface_datasets", 31 | # download_mode="reuse_cache_if_exists", 32 | ) 33 | ds = ds.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]) 34 | ds = ds.map(remove_special_characters) 35 | show_random_elements(ds.remove_columns(["path", "audio"])) 36 | return ds 37 | 38 | def _mls_1h_process(sp, lang): 39 | name = {"train": "train_1h", "validation": "dev", "test": "test"}[sp] 40 | ds = load_dataset( 41 | "../../..", 42 | data_files=f"dataset_csv/mls_{lang}_{name}.csv", 43 | download_mode="force_redownload", 44 | split="train" 45 | 
) 46 | ds = ds.map(remove_special_characters) 47 | show_random_elements(ds) 48 | return ds 49 | 50 | def _mls_10h_process(sp, lang): 51 | name = {"train": "train_10h", "validation": "dev", "test": "test"}[sp] 52 | ds = load_dataset( 53 | "../../..", 54 | data_files=f"dataset_csv/mls_{lang}_{name}.csv", 55 | download_mode="force_redownload", 56 | split="train" 57 | ) 58 | ds = ds.map(remove_special_characters) 59 | show_random_elements(ds) 60 | return ds 61 | 62 | processor, max_seconds, num_proc = { 63 | "common_voice": (_common_voice_process, 10, 1), 64 | "multilingual_librispeech_1h": (_mls_1h_process, 20, 64), 65 | "multilingual_librispeech_10h": (_mls_10h_process, 20, 64), 66 | }[name] 67 | 68 | cleanser = partial(cleanse_dataset, max_seconds=max_seconds, num_proc=num_proc) 69 | collator = partial( 70 | DataCollatorCTCWithPadding, 71 | max_length=16000 * max_seconds, 72 | max_length_labels=20 * max_seconds, 73 | max_length_lm=20 * max_seconds) 74 | datasets = tuple( 75 | processor(split, language) 76 | for split in ("train", "validation", "test")) 77 | 78 | return datasets, cleanser, collator 79 | 80 | 81 | def get_processor(save_dir, train_ds, eval_ds): 82 | vocab_dict = get_vocab(train_ds, eval_ds) 83 | with open(save_dir / "vocab.json", "w") as vocab_file: 84 | json.dump(vocab_dict, vocab_file) 85 | 86 | tokenizer = Wav2Vec2CTCTokenizer( 87 | save_dir / "vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|") 88 | feature_extractor = Wav2Vec2FeatureExtractor( 89 | feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True) 90 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 91 | processor.save_pretrained(save_dir) 92 | 93 | return processor 94 | 95 | 96 | def cleanse_dataset(ds, processor, lm_tokenizer, max_seconds, num_proc): 97 | ds = ds.cast_column("audio", Audio(sampling_rate=16_000)) 98 | 99 | rand_int = random.randint(0, len(ds) - 1) 100 | 101 | print("Target text:", ds[rand_int]["sentence"]) 102 | print("Input array shape:", ds[rand_int]["audio"]["array"].shape) 103 | print("Sampling rate:", ds[rand_int]["audio"]["sampling_rate"]) 104 | 105 | ds = ds.map( 106 | partial(prepare_each_batch, processor=processor, lm_tokenizer=lm_tokenizer), 107 | remove_columns=ds.column_names, num_proc=num_proc) 108 | 109 | print(f"Original dataset size: {len(ds)}") 110 | ds = ds.filter(partial(filter_too_long_audio, max_seconds=max_seconds), num_proc=num_proc) 111 | print(f"Filtered dataset size: {len(ds)}") 112 | 113 | return ds 114 | 115 | 116 | @dataclass 117 | class DataCollatorCTCWithPadding: 118 | processor: Wav2Vec2Processor 119 | lm_tokenizer: AutoTokenizer 120 | max_length: Optional[int] = None 121 | max_length_labels: Optional[int] = None 122 | max_length_lm: Optional[int] = None 123 | pad_to_multiple_of: Optional[int] = None 124 | pad_to_multiple_of_labels: Optional[int] = None 125 | 126 | def __call__(self, features): 127 | input_features = [{"input_values": feature["input_values"]} for feature in features] 128 | label_features = [{"input_ids": feature["labels"]} for feature in features] 129 | 130 | batch = self.processor.pad( 131 | input_features, 132 | padding='max_length', 133 | max_length=self.max_length, 134 | pad_to_multiple_of=self.pad_to_multiple_of, 135 | return_tensors="pt", 136 | ) 137 | with self.processor.as_target_processor(): 138 | labels_batch = self.processor.pad( 139 | label_features, 140 | padding='max_length', 141 | max_length=self.max_length_labels, 142 | 
pad_to_multiple_of=self.pad_to_multiple_of_labels, 143 | return_tensors="pt", 144 | ) 145 | 146 | # replace padding with -100 to ignore loss correctly 147 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 148 | batch["labels"] = labels 149 | 150 | # Prepare XLM 151 | lm_input_features = [{'input_ids': feature['lm_input_ids']} for feature in features] 152 | lm_batch = self.lm_tokenizer.pad( 153 | lm_input_features, 154 | padding='max_length', 155 | max_length=self.max_length_lm, 156 | return_tensors='pt', 157 | ) 158 | 159 | batch["lm_input_ids"] = lm_batch["input_ids"] 160 | batch["lm_attention_mask"] = lm_batch["attention_mask"] 161 | 162 | return batch 163 | 164 | 165 | def show_random_elements(dataset, num_examples=10): 166 | assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset." 167 | picks = [] 168 | for _ in range(num_examples): 169 | pick = random.randint(0, len(dataset) - 1) 170 | while pick in picks: 171 | pick = random.randint(0, len(dataset) - 1) 172 | picks.append(pick) 173 | 174 | df = pd.DataFrame(dataset[picks]) 175 | print(df) 176 | 177 | 178 | def remove_special_characters( 179 | batch, 180 | punctuation_table=dict.fromkeys( 181 | i for i in range(sys.maxunicode) 182 | if (not unicodedata.category(chr(i)).startswith("L")) and (chr(i) != ' ') 183 | ) 184 | ): 185 | batch["sentence"] = unicodedata.normalize("NFKC", batch["sentence"]) 186 | batch["sentence"] = batch["sentence"].translate(punctuation_table).lower() + " " 187 | return batch 188 | 189 | 190 | def filter_too_long_audio(batch, max_seconds): 191 | return (len(batch["input_values"]) <= 16000 * max_seconds) \ 192 | and (len(batch["lm_input_ids"]) <= 20 * max_seconds) \ 193 | and (len(batch["labels"]) <= 20 * max_seconds) 194 | 195 | 196 | def extract_all_chars(batch): 197 | all_text = " ".join(batch["sentence"]) 198 | vocab = list(set(all_text)) 199 | return {"vocab": [vocab], "all_text": [all_text]} 200 | 201 | 202 | def get_vocab(train_dataset, test_dataset, threshold=0.9999): 203 | counter = Counter() 204 | for dataset in (train_dataset, test_dataset): 205 | for row in dataset: 206 | counter.update(row['sentence']) 207 | 208 | sum_count = 0 209 | total_count = sum(counter.values()) 210 | vocab_dict = {} 211 | 212 | for i, (char, count) in enumerate(counter.most_common()): 213 | sum_count += count 214 | print(f"[{char}]: {count} ({sum_count / total_count * 100:.6f}%)") 215 | if sum_count / total_count < threshold: 216 | vocab_dict[char] = i 217 | 218 | vocab_dict["|"] = vocab_dict[" "] 219 | del vocab_dict[" "] 220 | 221 | vocab_dict["[UNK]"] = len(vocab_dict) 222 | vocab_dict["[PAD]"] = len(vocab_dict) 223 | 224 | print("== Vocabulary ==") 225 | print(vocab_dict) 226 | 227 | return vocab_dict 228 | 229 | 230 | def prepare_each_batch(batch, processor, lm_tokenizer): 231 | audio = batch["audio"] 232 | batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0] 233 | 234 | with processor.as_target_processor(): 235 | batch["labels"] = processor(batch["sentence"]).input_ids 236 | 237 | lm_input = lm_tokenizer(batch["sentence"]) 238 | batch["lm_input_ids"] = lm_input["input_ids"] 239 | 240 | return batch 241 | --------------------------------------------------------------------------------
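For readers skimming the dump, here is a minimal, self-contained sketch of the Distill-L2S feature loss that `Wav2Vec2ForDistill.forward` in model_utils.py computes: shrink the wav2vec 2.0 frames along time using the CTC predictions, project them into the LM feature space, interpolate the XLM-R features to the same length, and take an MSE. It is an illustration only: the function names, random tensors, and fixed sizes below are stand-ins, and the real code additionally masks padded frames, honours the `mode` / `filter_out_pad` options from the distill configs, and averages the loss over the batch.

```python
# Illustrative sketch only -- shapes and the random "features" are made up;
# in the repo these come from wav2vec 2.0 / XLM-R hidden states and a learned adapter.
import torch
import torch.nn as nn


def shrink(pred_ids: torch.Tensor, feats: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Average consecutive frames with the same predicted token; drop [UNK]/[PAD] frames."""
    pooled, i = [], 0
    while i < len(pred_ids):
        j = 1
        while (i + j) < len(pred_ids) and pred_ids[i + j].item() == pred_ids[i].item():
            j += 1
        if pred_ids[i].item() < vocab_size - 2:  # the last two ids are [UNK] and [PAD]
            pooled.append(feats[i:i + j].mean(0))
        i += j
    return torch.stack(pooled) if len(pooled) > 1 else feats


def distill_feat_loss(sm_feat, sm_logits, lm_feat, feat_adapter, vocab_size):
    # 1) Shrink the speech frames along time using the CTC argmax path.
    sm_feat = shrink(sm_logits.argmax(-1), sm_feat, vocab_size)
    # 2) Feature adaptation: speech hidden size -> LM hidden size.
    sm_proj = sm_feat @ feat_adapter                              # (T', lm_dim)
    # 3) Time adaptation: resample the LM sequence to length T'.
    lm_interp = nn.functional.interpolate(
        lm_feat.t().unsqueeze(0), size=sm_feat.shape[0], mode="nearest"
    ).squeeze(0).t()                                              # (T', lm_dim)
    # 4) Match the two aligned sequences.
    return nn.functional.mse_loss(sm_proj, lm_interp)


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab, sm_dim, lm_dim = 32, 1024, 768         # cf. sm_feat_size / lm_feat_size in the configs
    sm_feat = torch.randn(499, sm_dim)            # dummy wav2vec 2.0 frames
    sm_logits = torch.randn(499, vocab)           # dummy CTC logits
    lm_feat = torch.randn(100, lm_dim)            # dummy XLM-R token features
    feat_adapter = torch.randn(sm_dim, lm_dim) * 0.01
    print(distill_feat_loss(sm_feat, sm_logits, lm_feat, feat_adapter, vocab))
```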