├── run_all.sh ├── playground ├── spellcheck_tests_3.py ├── spellcheck_tests.py ├── augmentation_tests.py └── spellcheck_tests_2.py ├── requirements.txt ├── finetune.sh ├── LICENSE ├── finetune_with_params.sh ├── generate_all_trainings.py ├── common_voice_usage.py ├── wav2vec_languages.csv ├── sweep.yaml ├── Dockerfile ├── common_voice_eval.py ├── home-server.html ├── .gitignore ├── README.md ├── cer.py ├── wer.py ├── MODEL_CARD.md ├── run_common_voice.py └── dataset_ext.py /run_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | supervisord -n -u 42420 -c /etc/supervisor/supervisor.conf 3 | -------------------------------------------------------------------------------- /playground/spellcheck_tests_3.py: -------------------------------------------------------------------------------- 1 | from autocorrect import Speller 2 | 3 | spell = Speller('pl') 4 | print(spell('ptaaki latatją kluczmm')) 5 | 6 | spell = Speller('fr') 7 | print(spell("CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".lower())) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | torchaudio==0.8.1 3 | datasets==1.5.0 4 | jiwer==2.2.0 5 | soundfile==0.10.3.post1 6 | lang-trans==0.6.0 7 | librosa==0.8.0 8 | samplerate==0.1.0 9 | git+https://github.com/huggingface/transformers.git 10 | scipy==1.5.4 11 | audiomentations==0.16.0 12 | torch-audiomentations==0.6.0 13 | pyloudnorm==0.1.0 14 | wandb==0.10.23 15 | homoglyphs==2.0.4 16 | pandas==1.2.3 17 | matplotlib==3.3.4 18 | gdown==3.12.2 -------------------------------------------------------------------------------- /playground/spellcheck_tests.py: -------------------------------------------------------------------------------- 1 | from spellchecker import SpellChecker 2 | 3 | # turn off loading a built language dictionary, case sensitive on (if desired) 4 | spell = SpellChecker(language="fr") 5 | 6 | # CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ACHÉMÉNIDE ET SEPT DES SASSANIDES. 7 | words = "CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".split() 8 | 9 | for word in words: 10 | word = word.lower() 11 | if word in spell: 12 | print("'{}' is spelled correctly!".format(word)) 13 | else: 14 | cor = spell.correction(word) 15 | print("The best spelling for '{}' is '{}'".format(word, cor)) 16 | 17 | print("If that is not enough; here are all possible candidate words:") 18 | print(spell.candidates(word)) -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python run_common_voice.py \ 3 | --model_name_or_path="facebook/wav2vec2-large-xlsr-53" \ 4 | --dataset_config_name="tr" \ 5 | --output_dir=/workspace/container_0/wav2vec2-large-xlsr-turkish-demo \ 6 | --cache_dir=/workspace/container_0 \ 7 | --overwrite_output_dir \ 8 | --num_train_epochs="1" \ 9 | --per_device_train_batch_size="32" \ 10 | --per_device_train_batch_size="32" \ 11 | --evaluation_strategy="steps" \ 12 | --learning_rate="3e-4" \ 13 | --warmup_steps="500" \ 14 | --fp16 \ 15 | --freeze_feature_extractor \ 16 | --save_steps="10" \ 17 | --eval_steps="10" \ 18 | --save_total_limit="1" \ 19 | --logging_steps="10" \ 20 | --group_by_length \ 21 | --feat_proj_dropout="0.0" \ 22 | --layerdrop="0.1" \ 23 | --gradient_checkpointing \ 24 | --do_train --do_eval \ 25 | --max_train_samples 100 --max_val_samples 100 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jonatas Grosman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /finetune_with_params.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | python /workspace/wav2vec/run_common_voice.py \ 5 | --model_name_or_path=$model_name_or_path \ 6 | --dataset_config_name=$dataset_config_name \ 7 | --output_dir=$output_dir \ 8 | --cache_dir=$cache_dir \ 9 | --overwrite_output_dir \ 10 | --num_train_epochs=$num_train_epochs \ 11 | --per_device_train_batch_size=$per_device_train_batch_size \ 12 | --evaluation_strategy=$evaluation_strategy \ 13 | --learning_rate=$learning_rate \ 14 | --warmup_steps=$warmup_steps \ 15 | --fp16 \ 16 | --freeze_feature_extractor \ 17 | --save_steps=$save_steps \ 18 | --eval_steps=$eval_steps \ 19 | --save_total_limit=$save_total_limit \ 20 | --logging_steps=$logging_steps \ 21 | --group_by_length \ 22 | --feat_proj_dropout=$feat_proj_dropout \ 23 | --layerdrop=$layerdrop \ 24 | --gradient_checkpointing \ 25 | --do_train \ 26 | --do_eval \ 27 | --max_train_samples $max_train_samples \ 28 | --max_val_samples $max_val_samples \ 29 | --report_to $report_to \ 30 | --run_name $run_name \ 31 | --augmentation_factor $augmentation_factor 32 | 33 | -------------------------------------------------------------------------------- /generate_all_trainings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[8]: 5 | 6 | 7 | import os 8 | import csv 9 | 10 | 11 | # In[20]: 12 | 13 | 14 | with open('wav2vec_languages.csv') as csv_file: 15 | csv_reader = csv.reader(csv_file, delimiter=',') 16 | # This skips the first row of the CSV file because it's a header 17 | next(csv_reader) 18 | for (language_code, language_full_name) in csv_reader: 19 | print(f"#Launching Training for {language_code}-{language_full_name}") 20 | cmd = f"ovhai job run --gpu 1 --name '{language_code}-{language_full_name}' --volume output_models@GRA/{language_code}:/workspace/output_models:RW:cache -e model_name_or_path='facebook/wav2vec2-large-xlsr-53' -e dataset_config_name={language_code} -e output_dir='/workspace/output_models/wav2vec2-large-xlsr-{language_code}-{language_full_name}-demo' -e cache_dir='/workspace/data' -e num_train_epochs=10 databuzzword/hf-wav2vec -- sh /workspace/wav2vec/finetune_with_params.sh" 21 | print(cmd) 22 | stream = os.popen(cmd) 23 | output = stream.read() 24 | output 25 | 26 | 27 | # In[3]: 28 | 29 | 30 | 31 | 32 | 33 | # In[ ]: 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /common_voice_usage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | import warnings 4 | from datasets import load_dataset 5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 6 | 7 | LANG_ID = "pt" 8 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese" 9 | SAMPLES = 10 10 | 11 | test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]") 12 | 13 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID) 14 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID) 15 | 16 | # Preprocessing the datasets. 17 | # We need to read the audio files as arrays 18 | def speech_file_to_array_fn(batch): 19 | with warnings.catch_warnings(): 20 | warnings.simplefilter("ignore") 21 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000) 22 | batch["speech"] = speech_array 23 | batch["sentence"] = batch["sentence"].upper() 24 | return batch 25 | 26 | test_dataset = test_dataset.map(speech_file_to_array_fn) 27 | inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True) 28 | 29 | with torch.no_grad(): 30 | logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits 31 | 32 | predicted_ids = torch.argmax(logits, dim=-1) 33 | predicted_sentences = processor.batch_decode(predicted_ids) 34 | 35 | for i, predicted_sentence in enumerate(predicted_sentences): 36 | print("-" * 100) 37 | print("Reference:", test_dataset[i]["sentence"]) 38 | print("Prediction:", predicted_sentence) 39 | -------------------------------------------------------------------------------- /wav2vec_languages.csv: -------------------------------------------------------------------------------- 1 | language_code;language_full_name;valid_hours_size;has_homoglyphs; 2 | ab;abkhazian;0.05;False; 3 | vi;vietnamese;0.74;True;x 4 | as;assamese;0.74;False; 5 | br;breton;7;False; 6 | en;english;1686;True; 7 | cnh;cnh;2;False; 8 | cs;czech;36;False; 9 | cv;chuvash;4;False; 10 | cy;welsh;95;False; 11 | de;german;777;True; 12 | dv;divehi;18;False; 13 | ca;catalan;623;True; 14 | fr;french;623;True; 15 | es;spanish;324;True; 16 | it;italian;158;True; 17 | ru;russian;111;True; 18 | eu;basque;89;False; 19 | fa;persian;282;False; 20 | pl;polish;108;True; 21 | lv;latvian;99;True; 22 | fy-NL;western_frisian-netherlands;14;False; 23 | ga-IE;irish-ireland;3;False; 24 | hi;hindi;0.54;False; 25 | hsb;upper_sorbian;2;False; 26 | eo;esperanto;90;True; 27 | ia;interlingua;6;False; 28 | id;indonesian;9;False; 29 | nl;dutch;59;True; 30 | ja;japanese;3;False; 31 | ka;georgian;3;False; 32 | kab;kabyle;525;False; 33 | ky;kyrgyz;11;False; 34 | lg;ganda;3;False; 35 | pt;portuguese;50;True; 36 | ar;arabic;49;True; 37 | mn;mongolian;11;False; 38 | mt;maltese;7;False; 39 | tr;turkish;20;True; 40 | or;odia;0.87;False; 41 | pa-IN;punjabi-india;0.5;False; 42 | et;estonian;19;True; 43 | hu;hungarian;8;True;x 44 | rm-sursilv;romansh_sursilv;5;False; 45 | rm-vallader;romansh_vallader;2;False; 46 | th;thai;8;True;x 47 | el;greek;6;True; 48 | rw;kinyarwanda;1183;False; 49 | sah;sakha;4;False; 50 | ro;romanian;6;True;x 51 | sv-SE;swedish-sweden;12;False; 52 | ta;tamil;14;False; 53 | sl;slovenian;5;True;x 54 | lt;lithuanian;2;True;x 55 | tt;tatar;26;False; 56 | uk;ukrainian;30;False; 57 | fi;finnish;1;True;x 58 | vot;votic;0;False; 59 | zh-CN;chinese-china;56;False; 60 | zh-HK;chinese-hong_kong_sar_china;50;False; 61 | zh-TW;chinese-taiwan;55;False; -------------------------------------------------------------------------------- /sweep.yaml: -------------------------------------------------------------------------------- 1 | program: run_common_voice.py 2 | name: hf-wav2vec-sprint-fi 3 | method: random 4 | metric: 5 | goal: minimize 6 | name: eval/loss 7 | parameters: 8 | seed: 9 | value: 42 10 | report_to: 11 | value: wandb 12 | model_name_or_path: 13 | value: facebook/wav2vec2-large-xlsr-53 14 | dataset_config_name: 15 | value: fi 16 | output_dir: 17 | value: ../models/fi/wav2vec2-large-xlsr-fi-sweep 18 | cache_dir: 19 | value: ../data/fi 20 | overwrite_output_dir: 21 | value: True 22 | fp16: 23 | value: True 24 | max_steps: 25 | value: 500 26 | eval_steps: 27 | value: 100 28 | logging_steps: 29 | value: 100 30 | do_eval: 31 | value: True 32 | do_train: 33 | value: True 34 | per_device_train_batch_size: 35 | value: 16 36 | per_device_eval_batch_size: 37 | value: 16 38 | dataloader_num_workers: 39 | value: 10 40 | preprocessing_num_workers: 41 | value: 10 42 | load_best_model_at_end: 43 | value: True 44 | save_total_limit: 45 | value: 1 46 | evaluation_strategy: 47 | value: steps 48 | freeze_feature_extractor: 49 | value: True 50 | group_by_length: 51 | value: True 52 | min_duration: 53 | value: 2.0 54 | max_duration: 55 | value: 9.0 56 | lr_warmup_ratio: 57 | value: 0.5 58 | lr_constant_ratio: 59 | value: 0.0 60 | augmentation_factor: 61 | values: [0, 1] 62 | layerdrop: 63 | value: 0.0 64 | learning_rate: 65 | values: [1e-4, 3e-4, 6e-4, 1e-3] 66 | attention_dropout: 67 | values: [0.05, 0.1, 0.2] 68 | activation_dropout: 69 | values: [0.05, 0.1, 0.2] 70 | hidden_dropout: 71 | values: [0.05, 0.1, 0.2] 72 | feat_proj_dropout: 73 | values: [0.05, 0.1, 0.2] 74 | mask_time_prob: 75 | values: [0.05, 0.1, 0.2] 76 | early_terminate: 77 | type: hyperband 78 | min_iter: 200 79 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ovhcom/ai-training-one-for-all 2 | 3 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - 4 | 5 | RUN apt-get update && \ 6 | apt install -y bash \ 7 | build-essential \ 8 | libsndfile1-dev \ 9 | git-lfs \ 10 | ffmpeg \ 11 | sox \ 12 | libsox-fmt-mp3 13 | 14 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ 15 | apt-get install git-lfs && \ 16 | git lfs install 17 | 18 | RUN python3 -m pip install --no-cache-dir --upgrade pip && \ 19 | python3 -m pip install --no-cache-dir \ 20 | torch==1.8.1 \ 21 | torchaudio==0.8.1 \ 22 | datasets==1.5.0 \ 23 | jiwer==2.2.0 \ 24 | soundfile==0.10.3.post1 \ 25 | lang-trans==0.6.0 \ 26 | librosa==0.8.0 \ 27 | samplerate==0.1.0 \ 28 | scipy==1.5.4 \ 29 | audiomentations==0.16.0 \ 30 | torch-audiomentations==0.6.0 \ 31 | pyloudnorm==0.1.0 \ 32 | wandb==0.10.23 \ 33 | homoglyphs==2.0.4 \ 34 | gdown 35 | 36 | RUN pip3 uninstall -y typing allennlp 37 | 38 | RUN pip3 install git+https://github.com/huggingface/transformers.git 39 | 40 | RUN mkdir -p /workspace/wav2vec/ 41 | 42 | COPY finetune.sh finetune_with_params.sh run_common_voice.py dataset_ext.py cer.py wer.py common_voice_eval.py common_voice_usage.py /workspace/wav2vec/ 43 | 44 | COPY home-server.html run_all.sh /usr/bin/ 45 | 46 | RUN chown -R 42420:42420 /workspace 47 | 48 | RUN chown -R 42420:42420 /usr/bin/run_all.sh 49 | 50 | #Default training env variables 51 | ENV model_name_or_path="facebook/wav2vec2-large-xlsr-53" \ 52 | dataset_config_name="fr" \ 53 | output_dir="/workspace/output_models/wav2vec2-large-xlsr-french-demo" \ 54 | cache_dir="/workspace/data" \ 55 | num_train_epochs="1" \ 56 | per_device_train_batch_size="32" \ 57 | evaluation_strategy="steps" \ 58 | learning_rate="3e-4" \ 59 | warmup_steps="500" \ 60 | save_steps="10" \ 61 | eval_steps="10" \ 62 | save_total_limit="1" \ 63 | logging_steps="10" \ 64 | feat_proj_dropout="0.0" \ 65 | layerdrop="0.1" \ 66 | max_train_samples=100 \ 67 | max_val_samples=100 68 | 69 | WORKDIR /workspace 70 | ENTRYPOINT [] 71 | #CMD ["sh", "/usr/bin/run_all.sh"] 72 | CMD ["supervisord", "-n", "-u", "42420", "-c", "/etc/supervisor/supervisor.conf"] 73 | -------------------------------------------------------------------------------- /common_voice_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | import re 4 | import warnings 5 | from datasets import load_dataset, load_metric 6 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 7 | 8 | LANG_ID = "pt" 9 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese" 10 | DEVICE = "cuda" 11 | 12 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞", 13 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]", 14 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", 15 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", 16 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 17 | 18 | test_dataset = load_dataset("common_voice", LANG_ID, split="test") 19 | 20 | wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py 21 | cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py 22 | 23 | chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 24 | 25 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID) 26 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID) 27 | model.to(DEVICE) 28 | 29 | # Preprocessing the datasets. 30 | # We need to read the audio files as arrays 31 | def speech_file_to_array_fn(batch): 32 | with warnings.catch_warnings(): 33 | warnings.simplefilter("ignore") 34 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000) 35 | batch["speech"] = speech_array 36 | batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper() 37 | return batch 38 | 39 | test_dataset = test_dataset.map(speech_file_to_array_fn) 40 | 41 | # Preprocessing the datasets. 42 | # We need to read the audio files as arrays 43 | def evaluate(batch): 44 | inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True) 45 | 46 | with torch.no_grad(): 47 | logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits 48 | 49 | pred_ids = torch.argmax(logits, dim=-1) 50 | batch["pred_strings"] = processor.batch_decode(pred_ids) 51 | return batch 52 | 53 | result = test_dataset.map(evaluate, batched=True, batch_size=8) 54 | 55 | predictions = [x.upper() for x in result["pred_strings"]] 56 | references = [x.upper() for x in result["sentence"]] 57 | 58 | print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}") 59 | print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}") 60 | -------------------------------------------------------------------------------- /home-server.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | OVH AI Training Job 4 | 5 | 6 | 7 |
8 |
9 |
10 |
11 |
12 | 13 |
14 |

Jupyter Lab

16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | 26 |
27 |

Visual Studio Code

29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 | 37 | 40 |
41 |
42 |
43 |
44 |
45 |
46 | 47 | 50 |
51 |
52 |
53 |
54 |
55 | 56 | 57 | -------------------------------------------------------------------------------- /playground/augmentation_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import datasets 4 | import torch 5 | import torchaudio 6 | import soundfile as sf 7 | import librosa 8 | 9 | from audiomentations import ( 10 | Compose, 11 | AddGaussianNoise, 12 | AddGaussianSNR, 13 | ClippingDistortion, 14 | FrequencyMask, 15 | Gain, 16 | LoudnessNormalization, 17 | Normalize, 18 | PitchShift, 19 | PolarityInversion, 20 | Shift, 21 | TimeMask, 22 | TimeStretch, 23 | ) 24 | 25 | os.makedirs("_ignore_data", exist_ok=True) 26 | 27 | # creating augmentation pipeline 28 | 29 | augmentator = Compose([ 30 | AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.5), 31 | Gain(min_gain_in_db=-1, max_gain_in_db=1, p=0.5), 32 | PitchShift(min_semitones=-2, max_semitones=2, p=0.5), 33 | # TimeStretch(min_rate=0.7, max_rate=1.3, leave_length_unchanged=False, p=0.5), 34 | # FrequencyMask(min_frequency_band=0.0, max_frequency_band=0.5, p=0.5), 35 | # TimeMask(min_band_part=0.0, max_band_part=0.01, fade=True, p=0.5), 36 | # ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=5, p=0.5), 37 | # LoudnessNormalization(min_lufs_in_db=-31, max_lufs_in_db=-13, p=0.5) 38 | # PolarityInversion(p=0.5), 39 | # Shift(min_fraction=-0.01, max_fraction=0.01, rollover=False, p=0.5), 40 | # AddGaussianSNR(min_SNR=0.001, max_SNR=1.0, p=0.5), 41 | # Normalize(p=0.5), 42 | ]) 43 | 44 | # Get the dataset 45 | 46 | # train_dataset = datasets.load_dataset( 47 | # "common_voice_ext.py", "fi", augmentation_factor=2, split="train+validation", cache_dir="_ignore_data/cache/fi" 48 | # ) 49 | 50 | train_dataset = datasets.load_dataset( 51 | "common_voice", "pt", split="train+validation", cache_dir="_ignore_data/cache/pt" 52 | ) 53 | 54 | # Getting one sample per gender 55 | 56 | male_sample = None 57 | female_sample = None 58 | 59 | for sample in train_dataset: 60 | 61 | if sample.get("gender") == "male": 62 | male_sample = sample 63 | elif sample.get("gender") == "female": 64 | female_sample = sample 65 | 66 | if male_sample is not None and female_sample is not None: 67 | break 68 | 69 | # Augmenting data and saving it 70 | 71 | speech_array, sample_rate = librosa.load(male_sample["path"], sr = 16000, res_type="zero_order_hold") 72 | speech_array_augmented = augmentator(samples=speech_array, sample_rate=sample_rate) 73 | sf.write("_ignore_data/male_original.wav", speech_array, sample_rate, subtype="PCM_24") 74 | sf.write("_ignore_data/male_augmented.flac", speech_array_augmented, sample_rate) 75 | 76 | speech_array, sample_rate = librosa.load(female_sample["path"], sr = 16000, res_type="zero_order_hold") 77 | speech_array_augmented = augmentator(samples=speech_array, sample_rate=sample_rate) 78 | sf.write("_ignore_data/female_original.wav", speech_array, sample_rate, subtype="PCM_24") 79 | sf.write("_ignore_data/female_augmented.wav", speech_array_augmented, sample_rate, subtype="PCM_24") 80 | 81 | print(":)") 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # SO 132 | .DS_Store 133 | 134 | # IDE 135 | .vscode 136 | 137 | # ignorable folder/files pattern 138 | _ignore* 139 | 140 | # lock files 141 | *.lock 142 | 143 | # wandB files 144 | wandb 145 | 146 | # training files 147 | all_results.json 148 | config,json 149 | eval_results.json 150 | preprocessor_config.json 151 | pytorch_model.bin 152 | special_tokens_map.json 153 | tokenizer_config.json 154 | train_results.json 155 | trainer_state.json 156 | training_args.args 157 | vocab.json 158 | *.zip 159 | 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /playground/spellcheck_tests_2.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import contextualSpellCheck 3 | 4 | # ENGLISH 5 | 6 | # python -m spacy download en_core_web_sm 7 | # nlp = spacy.load("en_core_web_sm") 8 | 9 | # nlp.add_pipe("contextual spellchecker", config={"max_edit_dist": 100}) 10 | 11 | # doc = nlp("Icome was $9.4 milion compared to the last yiear of $2.7 milion.") 12 | # print(doc._.performed_spellCheck) 13 | # print(doc._.outcome_spellCheck) 14 | 15 | # # Doc Extention 16 | # print(doc._.contextual_spellCheck) 17 | 18 | # print(doc._.performed_spellCheck) 19 | 20 | # print(doc._.suggestions_spellCheck) 21 | 22 | # print(doc._.outcome_spellCheck) 23 | 24 | # print(doc._.score_spellCheck) 25 | 26 | # # Token Extention 27 | # print(doc[4]._.get_require_spellCheck) 28 | 29 | # print(doc[4]._.get_suggestion_spellCheck) 30 | 31 | # print(doc[4]._.score_spellCheck) 32 | 33 | # # Span Extention 34 | # print(doc[2:6]._.get_has_spellCheck) 35 | 36 | # print(doc[2:6]._.score_spellCheck) 37 | 38 | # JAPANESE 39 | 40 | # python -m spacy download ja_core_news_sm 41 | # pip install mecab-python3==0.996.5 42 | # pip install ipadic==1.0.0 43 | # pip install unidic-lite==1.0.6 44 | # pip install fugashi==1.1.0 45 | # nlp = spacy.load("ja_core_news_sm") 46 | 47 | # nlp.add_pipe( 48 | # "contextual spellchecker", 49 | # config={ 50 | # "model_name": "cl-tohoku/bert-base-japanese-whole-word-masking", 51 | # "max_edit_dist": 2, 52 | # }, 53 | # ) 54 | 55 | # doc = nlp("しかし大勢においては、ここような事故はウィキペディアの拡大には影響を及ぼしていない。") 56 | # print(doc._.performed_spellCheck) 57 | # print(doc._.outcome_spellCheck) 58 | 59 | # PORTUGUESE 60 | 61 | # python -m spacy download pt_core_news_lg 62 | # nlp = spacy.load("pt_core_news_lg") 63 | 64 | # nlp.add_pipe( 65 | # "contextual spellchecker", 66 | # config={ 67 | # "model_name": "neuralmind/bert-large-portuguese-cased", 68 | # "max_edit_dist": 5, 69 | # }, 70 | # ) 71 | 72 | ## PEDIR DINHEIRO EMPRESTADO ÀS PESSOAS DA ALDEIA 73 | # doc = nlp("EDIR DINHEIRO EMPRESTADO ÀS PESSOAS DO ALDEIRA".lower()) 74 | # print(doc._.performed_spellCheck) 75 | # print(doc._.outcome_spellCheck) 76 | 77 | 78 | # SPANISH 79 | 80 | # # python -m spacy download es_dep_news_trf 81 | # nlp = spacy.load("es_dep_news_trf") 82 | 83 | # nlp.add_pipe( 84 | # "contextual spellchecker", 85 | # config={ 86 | # "model_name": "Geotrend/bert-base-es-cased", 87 | # "max_edit_dist": 5, 88 | # }, 89 | # ) 90 | 91 | ## HABITAN EN AGUAS POCO PROFUNDAS Y ROCOSAS 92 | # doc = nlp("HABITAN AGUAS POCO PROFUNDAS Y ROCOSAS".lower()) 93 | # print(doc._.performed_spellCheck) 94 | # print(doc._.outcome_spellCheck) 95 | 96 | 97 | # FRENCH 98 | 99 | # python -m spacy download fr_core_news_sm 100 | nlp = spacy.load("fr_core_news_sm") 101 | 102 | nlp.add_pipe( 103 | "contextual spellchecker", 104 | config={ 105 | "model_name": "camembert-base", 106 | "max_edit_dist": 5, 107 | }, 108 | ) 109 | 110 | # CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ACHÉMÉNIDE ET SEPT DES SASSANIDES. 111 | doc = nlp("CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".lower()) 112 | doc = nlp("CE SITE CONTIENT QUATRE TOMBEAUX".lower()) 113 | print(doc._.performed_spellCheck) 114 | print(doc._.outcome_spellCheck) 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo is deprecated in favor of https://github.com/jonatasgrosman/huggingsound 2 | 3 | # Wav2Vec Trainer 4 | 5 | This repository is based on https://github.com/jqueguiner/wav2vec2-sprint 6 | 7 | ## Building docker image 8 | 9 | Dockerhub available at https://hub.docker.com/r/patilsuraj/hf-wav2vec 10 | 11 | to build the docker : 12 | 13 | ``` 14 | $ docker build -t hf-wav2vec-sprint -f Dockerfile . 15 | ``` 16 | 17 | to push it to dockerhub 18 | First create a repository on dockerhub 19 | ``` 20 | $ docker tag hf-wav2vec-sprint your-dockerhub-user/hf-wav2vec-sprint 21 | ``` 22 | 23 | to push it to dockerhub 24 | 25 | ``` 26 | $ docker push your-dockerhub-user/hf-wav2vec-sprint 27 | ``` 28 | 29 | ## Running WandB sweep 30 | 31 | Initialize your sweep from any machine... 32 | 33 | ``` 34 | $ export WANDB_API_KEY=YOUR_WANDB_API_KEY 35 | $ export WANDB_ENTITY=YOUR_WANDB_ENTITY 36 | $ export WANDB_PROJECT=YOUR_WANDB_PROJECT 37 | 38 | $ wandb sweep sweep.yaml 39 | ``` 40 | ... the execution above will give you a sweep id, save it and on the training machine run: 41 | 42 | ``` 43 | $ export WANDB_API_KEY=YOUR_WANDB_API_KEY 44 | $ export WANDB_ENTITY=YOUR_WANDB_ENTITY 45 | $ export WANDB_PROJECT=YOUR_WANDB_PROJECT 46 | 47 | $ wandb agent YOUR_SWEEP_ID 48 | ``` 49 | 50 | ## Uploading model to HF 51 | 52 | You need to upload the following files to the HF repository 53 | 54 | - preprocessor_config.json 55 | - special_tokens_map.json 56 | - tokenizer_config.json 57 | - vocab.json 58 | - config.json 59 | - pytorch_model.bin 60 | - README.md (create this file based on the MODEL_CARD.md) 61 | 62 | ``` 63 | $ git config --global user.email "email@example.com" 64 | 65 | $ git config --global user.name "Your name" 66 | 67 | $ transformers-cli login 68 | 69 | $ transformers-cli repo create your-model-name 70 | 71 | $ git clone https://username:password_or_token@huggingface.co/username/your-model-name 72 | 73 | $ git add . 74 | 75 | $ git commit -m "Initial commit" 76 | 77 | $ git push 78 | 79 | ``` 80 | 81 | ## Troubleshooting 82 | 83 | - audioread.exceptions.NoBackendError: `$ sudo apt-get install ffmpeg sox libsox-fmt-mp3` 84 | 85 | 86 | ## Finetuned models 87 | 88 | ### Wav2Vec2-XLSR-53 89 | 90 | - [Arabic](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-arabic) 91 | - [Chinese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn) 92 | - [Dutch](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-dutch) 93 | - [Finnish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-finnish) 94 | - [French](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-french) 95 | - [German](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-german) 96 | - [Greek](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-greek) 97 | - [Hungarian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-hungarian) 98 | - [Italian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-italian) 99 | - [Japanese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-japanese) 100 | - [Persian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-persian) 101 | - [Polish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-polish) 102 | - [Portuguese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-portuguese) 103 | - [Russian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-russian) 104 | - [Spanish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish) 105 | -------------------------------------------------------------------------------- /cer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Character Error Ratio (CER) metric. """ 16 | import jiwer 17 | import jiwer.transforms as tr 18 | from typing import List 19 | import datasets 20 | import gc 21 | 22 | _CITATION = """\ 23 | @inproceedings{inproceedings, 24 | author = {Morris, Andrew and Maier, Viktoria and Green, Phil}, 25 | year = {2004}, 26 | month = {01}, 27 | pages = {}, 28 | title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.} 29 | } 30 | """ 31 | _DESCRIPTION = """\ 32 | Character error rate (CER) is a common metric of the performance of an automatic speech recognition system. 33 | CER is similar to Word Error Rate (WER), but operate on character insted of word. Please refer to docs of WER for further information. 34 | Character error rate can be computed as: 35 | CER = (S + D + I) / N = (S + D + I) / (S + D + C) 36 | where 37 | S is the number of substitutions, 38 | D is the number of deletions, 39 | I is the number of insertions, 40 | C is the number of correct characters, 41 | N is the number of characters in the reference (N=S+D+C). 42 | CER's output is always a number between 0 and 1. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the 43 | performance of the ASR system with a CER of 0 being a perfect score. 44 | """ 45 | _KWARGS_DESCRIPTION = """ 46 | Computes CER score of transcribed segments against references. 47 | Args: 48 | references: list of references for each speech input. 49 | predictions: list of transcribtions to score. 50 | Returns: 51 | (float): the character error rate 52 | Examples: 53 | >>> predictions = ["this is the prediction", "there is an other sample"] 54 | >>> references = ["this is the reference", "there is another one"] 55 | >>> cer = datasets.load_metric("cer") 56 | >>> cer_score = cer.compute(predictions=predictions, references=references) 57 | >>> print(cer_score) 58 | 0.34146341463414637 59 | """ 60 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 61 | class CER(datasets.Metric): 62 | def _info(self): 63 | return datasets.MetricInfo( 64 | description=_DESCRIPTION, 65 | citation=_CITATION, 66 | inputs_description=_KWARGS_DESCRIPTION, 67 | features=datasets.Features( 68 | { 69 | "predictions": datasets.Value("string", id="sequence"), 70 | "references": datasets.Value("string", id="sequence"), 71 | } 72 | ), 73 | codebase_urls=["https://github.com/jitsi/jiwer/"], 74 | reference_urls=[ 75 | "https://en.wikipedia.org/wiki/Word_error_rate", 76 | ], 77 | ) 78 | def _compute(self, predictions, references, chunk_size=None): 79 | if chunk_size is None: 80 | preds = [char for seq in predictions for char in list(seq)] 81 | refs = [char for seq in references for char in list(seq)] 82 | return jiwer.wer(refs, preds) 83 | start = 0 84 | end = chunk_size 85 | H, S, D, I = 0, 0, 0, 0 86 | while start < len(references): 87 | preds = [char for seq in predictions[start:end] for char in list(seq)] 88 | refs = [char for seq in references[start:end] for char in list(seq)] 89 | chunk_metrics = jiwer.compute_measures(refs, preds) 90 | H = H + chunk_metrics["hits"] 91 | S = S + chunk_metrics["substitutions"] 92 | D = D + chunk_metrics["deletions"] 93 | I = I + chunk_metrics["insertions"] 94 | start += chunk_size 95 | end += chunk_size 96 | del preds 97 | del refs 98 | del chunk_metrics 99 | gc.collect() 100 | return float(S + D + I) / float(H + S + D) 101 | -------------------------------------------------------------------------------- /wer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Word Error Ratio (WER) metric. """ 16 | 17 | import jiwer 18 | import datasets 19 | 20 | _CITATION = """\ 21 | @inproceedings{inproceedings, 22 | author = {Morris, Andrew and Maier, Viktoria and Green, Phil}, 23 | year = {2004}, 24 | month = {01}, 25 | pages = {}, 26 | title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.} 27 | } 28 | """ 29 | 30 | _DESCRIPTION = """\ 31 | Word error rate (WER) is a common metric of the performance of an automatic speech recognition system. 32 | The general difficulty of measuring performance lies in the fact that the recognized word sequence can have a different length from the reference word sequence (supposedly the correct one). The WER is derived from the Levenshtein distance, working at the word level instead of the phoneme level. The WER is a valuable tool for comparing different systems as well as for evaluating improvements within one system. This kind of measurement, however, provides no details on the nature of translation errors and further work is therefore required to identify the main source(s) of error and to focus any research effort. 33 | This problem is solved by first aligning the recognized word sequence with the reference (spoken) word sequence using dynamic string alignment. Examination of this issue is seen through a theory called the power law that states the correlation between perplexity and word error rate. 34 | Word error rate can then be computed as: 35 | WER = (S + D + I) / N = (S + D + I) / (S + D + C) 36 | where 37 | S is the number of substitutions, 38 | D is the number of deletions, 39 | I is the number of insertions, 40 | C is the number of correct words, 41 | N is the number of words in the reference (N=S+D+C). 42 | WER's output is always a number between 0 and 1. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the 43 | performance of the ASR system with a WER of 0 being a perfect score. 44 | """ 45 | 46 | _KWARGS_DESCRIPTION = """ 47 | Computes WER score of transcribed segments against references. 48 | Args: 49 | references: list of references for each speech input. 50 | predictions: list of transcribtions to score. 51 | Returns: 52 | (float): the word error rate 53 | Examples: 54 | >>> predictions = ["this is the prediction", "there is an other sample"] 55 | >>> references = ["this is the reference", "there is another one"] 56 | >>> wer = datasets.load_metric("wer") 57 | >>> wer_score = wer.compute(predictions=predictions, references=references) 58 | >>> print(wer_score) 59 | 0.5 60 | """ 61 | 62 | 63 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 64 | class WER(datasets.Metric): 65 | def _info(self): 66 | return datasets.MetricInfo( 67 | description=_DESCRIPTION, 68 | citation=_CITATION, 69 | inputs_description=_KWARGS_DESCRIPTION, 70 | features=datasets.Features( 71 | { 72 | "predictions": datasets.Value("string", id="sequence"), 73 | "references": datasets.Value("string", id="sequence"), 74 | } 75 | ), 76 | codebase_urls=["https://github.com/jitsi/jiwer/"], 77 | reference_urls=[ 78 | "https://en.wikipedia.org/wiki/Word_error_rate", 79 | ], 80 | ) 81 | 82 | def _compute(self, predictions, references, chunk_size=None): 83 | if chunk_size is None: return jiwer.wer(references, predictions) 84 | start = 0 85 | end = chunk_size 86 | H, S, D, I = 0, 0, 0, 0 87 | while start < len(references): 88 | chunk_metrics = jiwer.compute_measures(references[start:end], predictions[start:end]) 89 | H = H + chunk_metrics["hits"] 90 | S = S + chunk_metrics["substitutions"] 91 | D = D + chunk_metrics["deletions"] 92 | I = I + chunk_metrics["insertions"] 93 | start += chunk_size 94 | end += chunk_size 95 | return float(S + D + I) / float(H + S + D) 96 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | --- 2 | language: pt 3 | datasets: 4 | - common_voice 5 | metrics: 6 | - wer 7 | - cer 8 | tags: 9 | - audio 10 | - automatic-speech-recognition 11 | - speech 12 | - xlsr-fine-tuning-week 13 | license: apache-2.0 14 | model-index: 15 | - name: XLSR Wav2Vec2 Portuguese by Jonatas Grosman 16 | results: 17 | - task: 18 | name: Speech Recognition 19 | type: automatic-speech-recognition 20 | dataset: 21 | name: Common Voice pt 22 | type: common_voice 23 | args: pt 24 | metrics: 25 | - name: Test WER 26 | type: wer 27 | value: {wer_result_on_test} 28 | - name: Test CER 29 | type: cer 30 | value: {cer_result_on_test} #TODO (IMPORTANT): replace {wer_result_on_test} with the WER error rate you achieved on the common_voice test set. It should be in the format XX.XX (don't add the % sign here). **Please** remember to fill out this value after you evaluated your model, so that your model appears on the leaderboard. If you fill out this model card before evaluating your model, please remember to edit the model card afterward to fill in your value 31 | --- 32 | 33 | # Wav2Vec2-Large-XLSR-53-Portuguese 34 | 35 | Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Portuguese using the [Common Voice](https://huggingface.co/datasets/common_voice). 36 | When using this model, make sure that your speech input is sampled at 16kHz. 37 | 38 | The script used for training can be found here: https://github.com/jonatasgrosman/wav2vec2-sprint 39 | ## Usage 40 | 41 | The model can be used directly (without a language model) as follows: 42 | 43 | ```python 44 | import torch 45 | import librosa 46 | from datasets import load_dataset 47 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 48 | 49 | LANG_ID = "pt" 50 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese" 51 | SAMPLES = 5 52 | 53 | test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]") 54 | 55 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID) 56 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID) 57 | 58 | # Preprocessing the datasets. 59 | # We need to read the audio files as arrays 60 | def speech_file_to_array_fn(batch): 61 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000) 62 | batch["speech"] = speech_array 63 | batch["sentence"] = batch["sentence"].upper() 64 | return batch 65 | 66 | test_dataset = test_dataset.map(speech_file_to_array_fn) 67 | inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True) 68 | 69 | with torch.no_grad(): 70 | logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits 71 | 72 | predicted_ids = torch.argmax(logits, dim=-1) 73 | predicted_sentences = processor.batch_decode(predicted_ids) 74 | 75 | for i, predicted_sentence in enumerate(predicted_sentences): 76 | print("-" * 100) 77 | print("Reference:", test_dataset[i]["sentence"]) 78 | print("Prediction:", predicted_sentence) 79 | ``` 80 | 81 | | Reference | Prediction | 82 | | ------------- | ------------- | 83 | | NEM O RADAR NEM OS OUTROS INSTRUMENTOS DETECTARAM O BOMBARDEIRO STEALTH. | NEM UM VADA ME OS OUTOS INSTRUMENTOS DE TETERAM UM BAMBEDER OSTAU | 84 | | PEDIR DINHEIRO EMPRESTADO ÀS PESSOAS DA ALDEIA | PEDIAR DINHEIRO EMPRESTADO DÀS PESSOAS DA ALDEIA | 85 | | OITO | OITO | 86 | | TRANCÁ-LOS | TRAM CALDOS | 87 | | REALIZAR UMA INVESTIGAÇÃO PARA RESOLVER O PROBLEMA | REALIZARAMA INVESTIGAÇÃO PARA RESOLVER O PROBLEMA | 88 | 89 | ## Evaluation 90 | 91 | The model can be evaluated as follows on the Portuguese test data of Common Voice. 92 | 93 | ```python 94 | import torch 95 | import librosa 96 | import re 97 | from datasets import load_dataset, load_metric 98 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor 99 | 100 | LANG_ID = "pt" 101 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese" 102 | DEVICE = "cuda" 103 | 104 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞", 105 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]", 106 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", 107 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", 108 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"] 109 | 110 | test_dataset = load_dataset("common_voice", LANG_ID, split="test") 111 | 112 | wer = load_metric("wer") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py 113 | cer = load_metric("cer") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py 114 | 115 | chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]" 116 | 117 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID) 118 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID) 119 | model.to(DEVICE) 120 | 121 | # Preprocessing the datasets. 122 | # We need to read the audio files as arrays 123 | def speech_file_to_array_fn(batch): 124 | batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper() 125 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000) 126 | batch["speech"] = speech_array 127 | return batch 128 | 129 | test_dataset = test_dataset.map(speech_file_to_array_fn) 130 | 131 | # Preprocessing the datasets. 132 | # We need to read the audio files as arrays 133 | def evaluate(batch): 134 | inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True) 135 | 136 | with torch.no_grad(): 137 | logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits 138 | 139 | pred_ids = torch.argmax(logits, dim=-1) 140 | batch["pred_strings"] = processor.batch_decode(pred_ids) 141 | return batch 142 | 143 | result = test_dataset.map(evaluate, batched=True, batch_size=8) 144 | 145 | predictions = [x.upper() for x in result["pred_strings"]] 146 | references = [x.upper() for x in result["sentence"]] 147 | 148 | print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}") 149 | print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}") 150 | ``` 151 | 152 | **Test Result**: 153 | 154 | In the table below I report the Word Error Rate (WER) and the Character Error Rate (CER) of the model. I ran the evaluation script described above on other models as well (on YYYY-MM-DD). Note that the table below may show different results from those already reported, this may have been caused due to some specificity of the other evaluation scripts used. 155 | 156 | | Model | WER | CER | 157 | | ------------- | ------------- | ------------- | 158 | | **jonatasgrosman/wav2vec2-large-xlsr-53-portuguese** | XX.XX% | XX.XX% | 159 | | username/model_name | XX.XX% | XX.XX% | 160 | -------------------------------------------------------------------------------- /run_common_voice.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import collections 8 | from dataclasses import dataclass, field 9 | from typing import Any, Dict, List, Optional, Union 10 | 11 | import datasets 12 | import numpy as np 13 | import torch 14 | import soundfile as sf 15 | from packaging import version 16 | from torch import nn 17 | from pathlib import Path 18 | import wandb 19 | 20 | from torch_audiomentations import Compose, Gain 21 | from audiomentations import ( 22 | Compose, 23 | AddGaussianNoise, 24 | AddGaussianSNR, 25 | ClippingDistortion, 26 | FrequencyMask, 27 | Gain, 28 | LoudnessNormalization, 29 | Normalize, 30 | PitchShift, 31 | PolarityInversion, 32 | Shift, 33 | TimeMask, 34 | TimeStretch, 35 | ) 36 | 37 | import transformers 38 | from transformers import ( 39 | HfArgumentParser, 40 | Trainer, 41 | TrainingArguments, 42 | Wav2Vec2CTCTokenizer, 43 | Wav2Vec2FeatureExtractor, 44 | Wav2Vec2ForCTC, 45 | Wav2Vec2Processor, 46 | is_apex_available, 47 | set_seed, 48 | ) 49 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 50 | from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler 51 | 52 | 53 | PRETRAINED_MODELS = [ 54 | "facebook/wav2vec2-large", 55 | "facebook/wav2vec2-large-xlsr-53", 56 | "facebook/wav2vec2-large-es-voxpopuli", 57 | "facebook/wav2vec2-large-fr-voxpopuli", 58 | "facebook/wav2vec2-large-it-voxpopuli", 59 | "facebook/wav2vec2-large-nl-voxpopuli", 60 | "facebook/wav2vec2-large-sv-voxpopuli", 61 | "facebook/wav2vec2-large-10k-voxpopuli", 62 | "facebook/wav2vec2-large-100k-voxpopuli" 63 | ] 64 | 65 | 66 | if is_apex_available(): 67 | from apex import amp 68 | 69 | 70 | if version.parse(torch.__version__) >= version.parse("1.6"): 71 | _is_native_amp_available = True 72 | from torch.cuda.amp import autocast 73 | 74 | logger = logging.getLogger(__name__) 75 | 76 | 77 | def list_field(default=None, metadata=None): 78 | return field(default_factory=lambda: default, metadata=metadata) 79 | 80 | 81 | @dataclass 82 | class AdditionalTrainingArguments: 83 | """ 84 | Additional training arguments 85 | """ 86 | 87 | lr_warmup_ratio: Optional[float] = field( 88 | default=0.1, 89 | metadata={"help": "Percentage of steps for LR warmup phase"}, 90 | ) 91 | lr_constant_ratio: Optional[float] = field( 92 | default=0.4, 93 | metadata={"help": "Percentage of steps for LR constant phase (after warmup)"}, 94 | ) 95 | upload_final_model_to_wandb: Optional[bool] = field( 96 | default=False, 97 | metadata={"help": "Upload the final trained model to the WandB artifacts repository"}, 98 | ) 99 | upload_model_to_wandb_each_step: Optional[int] = field( 100 | default=None, 101 | metadata={"help": "Frequency (in steps) to upload the trained model to the WandB artifacts repository"}, 102 | ) 103 | apply_gaussian_noise_with_p: Optional[float] = field( 104 | default=0.5, 105 | metadata={"help": "Probability to apply Gaussian Noise in the original samples"}, 106 | ) 107 | apply_gain_with_p: Optional[float] = field( 108 | default=0.5, 109 | metadata={"help": "Probability to apply Gain in the original samples"}, 110 | ) 111 | apply_pitch_shift_with_p: Optional[float] = field( 112 | default=0.5, 113 | metadata={"help": "Probability to apply Pitch Shift in the original samples"}, 114 | ) 115 | apply_time_stretch_with_p: Optional[float] = field( 116 | default=0.5, 117 | metadata={"help": "Probability to apply Time Stretch in the original samples"}, 118 | ) 119 | min_char_occurrence_ratio: Optional[float] = field( 120 | default=None, 121 | metadata={"help": "Minimum ratio of character occurrences to be considered for the vocabulary builder"}, 122 | ) 123 | max_dataset_size_vocab_builder: Optional[int] = field( 124 | default=10000, 125 | metadata={"help": "Maximum size of the dataset to be considered for vocabulary builder"}, 126 | ) 127 | remove_samples_with_oov_from_training: Optional[bool] = field( 128 | default=False, 129 | metadata={"help": "Whether to remove samples from training when there are OOV characters on them"}, 130 | ) 131 | use_only_top_k_most_common_accent: Optional[int] = field( 132 | default=None, 133 | metadata={"help": "Use only the top most common accent in dataset for training"}, 134 | ) 135 | 136 | @dataclass 137 | class ModelArguments: 138 | """ 139 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 140 | """ 141 | 142 | model_name_or_path: str = field( 143 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 144 | ) 145 | cache_dir: Optional[str] = field( 146 | default=None, 147 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 148 | ) 149 | freeze_feature_extractor: Optional[bool] = field( 150 | default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} 151 | ) 152 | attention_dropout: Optional[float] = field( 153 | default=0.1, metadata={"help": "The dropout ratio for the attention probabilities."} 154 | ) 155 | activation_dropout: Optional[float] = field( 156 | default=0.1, metadata={"help": "The dropout ratio for activations inside the fully connected layer."} 157 | ) 158 | hidden_dropout: Optional[float] = field( 159 | default=0.1, 160 | metadata={ 161 | "help": "The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler." 162 | }, 163 | ) 164 | feat_proj_dropout: Optional[float] = field( 165 | default=0.1, 166 | metadata={"help": "The dropout probabilitiy for all 1D convolutional layers in feature extractor."}, 167 | ) 168 | mask_time_prob: Optional[float] = field( 169 | default=0.05, 170 | metadata={ 171 | "help": "Propability of each feature vector along the time axis to be chosen as the start of the vector" 172 | "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature" 173 | "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``." 174 | }, 175 | ) 176 | gradient_checkpointing: Optional[bool] = field( 177 | default=True, 178 | metadata={ 179 | "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." 180 | }, 181 | ) 182 | layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."}) 183 | 184 | 185 | @dataclass 186 | class DataTrainingArguments: 187 | """ 188 | Arguments pertaining to what data we are going to input our model for training and eval. 189 | 190 | Using `HfArgumentParser` we can turn this class 191 | into argparse arguments to be able to specify them on 192 | the command line. 193 | """ 194 | 195 | dataset_config_name: Optional[str] = field( 196 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 197 | ) 198 | overwrite_cache: bool = field( 199 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 200 | ) 201 | preprocessing_num_workers: Optional[int] = field( 202 | default=None, 203 | metadata={"help": "The number of processes to use for the preprocessing."}, 204 | ) 205 | max_train_samples: Optional[int] = field( 206 | default=None, 207 | metadata={ 208 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 209 | "value if set." 210 | }, 211 | ) 212 | max_val_samples: Optional[int] = field( 213 | default=1000, 214 | metadata={ 215 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 216 | "value if set." 217 | }, 218 | ) 219 | val_ratio: Optional[float] = field( 220 | default=0.2, 221 | metadata={ 222 | "help": "Percentage of dataset samples to be used for evaluation, default is 20%" 223 | }, 224 | ) 225 | chars_to_ignore: List[str] = list_field( 226 | default=[",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞", 227 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]", 228 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。", 229 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽", 230 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"], 231 | metadata={"help": "A list of characters to remove from the transcripts."}, 232 | ) 233 | min_duration: Optional[float] = field( 234 | default=0.0, 235 | metadata={ 236 | "help": "The minimum duration (in seconds) that a sample needs to have to be considered for training" 237 | }, 238 | ) 239 | max_duration: Optional[float] = field( 240 | default=float("inf"), 241 | metadata={ 242 | "help": "The maximum duration (in seconds) that a sample needs to have to be considered for training" 243 | }, 244 | ) 245 | use_only_common_voice_data: bool = field( 246 | default=False, metadata={"help": "Use only common voice data in training."} 247 | ) 248 | 249 | 250 | @dataclass 251 | class DataCollatorCTCWithPadding: 252 | """ 253 | Data collator that will dynamically pad the inputs received. 254 | Args: 255 | processor (:class:`~transformers.Wav2Vec2Processor`) 256 | The processor used for proccessing the data. 257 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 258 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 259 | among: 260 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 261 | sequence if provided). 262 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 263 | maximum acceptable input length for the model if that argument is not provided. 264 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 265 | different lengths). 266 | max_length (:obj:`int`, `optional`): 267 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 268 | max_length_labels (:obj:`int`, `optional`): 269 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 270 | pad_to_multiple_of (:obj:`int`, `optional`): 271 | If set will pad the sequence to a multiple of the provided value. 272 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 273 | 7.5 (Volta). 274 | """ 275 | 276 | processor: Wav2Vec2Processor 277 | padding: Union[bool, str] = True 278 | max_length: Optional[int] = None 279 | max_length_labels: Optional[int] = None 280 | pad_to_multiple_of: Optional[int] = None 281 | pad_to_multiple_of_labels: Optional[int] = None 282 | 283 | def __init__(self, processor, padding=True, apply_gaussian_noise_with_p=0.5, apply_gain_with_p=0.5, apply_pitch_shift_with_p=0.5, 284 | apply_time_stretch_with_p=0.5, sample_rate=16_000): 285 | self.processor = processor 286 | self.padding = padding 287 | self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p 288 | self.apply_gain_with_p = apply_gain_with_p 289 | self.apply_pitch_shift_with_p = apply_pitch_shift_with_p 290 | self.apply_time_stretch_with_p = apply_time_stretch_with_p 291 | self.sample_rate = sample_rate 292 | 293 | self.augmentator = None 294 | if self.apply_gaussian_noise_with_p + self.apply_gain_with_p + self.apply_pitch_shift_with_p + self.apply_time_stretch_with_p > 0: 295 | self.augmentator = Compose([ 296 | TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=False, p=self.apply_time_stretch_with_p), 297 | PitchShift(min_semitones=-1, max_semitones=1, p=self.apply_pitch_shift_with_p), 298 | Gain(min_gain_in_db=-1, max_gain_in_db=1, p=self.apply_gain_with_p), 299 | AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=self.apply_gaussian_noise_with_p), 300 | ]) 301 | 302 | def _apply_augmentation(self, input_values: List[float]): 303 | """apply some audio augmentations in the given input_values""" 304 | if self.augmentator is not None: 305 | return self.augmentator(samples=np.array(input_values), sample_rate=self.sample_rate).tolist() 306 | else: 307 | return input_values 308 | 309 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 310 | # split inputs and labels since they have to be of different lenghts and need 311 | # different padding methods 312 | 313 | input_features = [{"input_values": self._apply_augmentation(feature["input_values"])} for feature in features] 314 | label_features = [{"input_ids": feature["labels"]} for feature in features] 315 | 316 | batch = self.processor.pad( 317 | input_features, 318 | padding=self.padding, 319 | max_length=self.max_length, 320 | pad_to_multiple_of=self.pad_to_multiple_of, 321 | return_tensors="pt", 322 | ) 323 | with self.processor.as_target_processor(): 324 | labels_batch = self.processor.pad( 325 | label_features, 326 | padding=self.padding, 327 | max_length=self.max_length_labels, 328 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 329 | return_tensors="pt", 330 | ) 331 | 332 | # replace padding with -100 to ignore loss correctly 333 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 334 | 335 | batch["labels"] = labels 336 | 337 | return batch 338 | 339 | 340 | class CTCTrainer(Trainer): 341 | 342 | def __init__(self, model_output_dir, length_field_name="length", upload_model_to_wandb_each_step=None, lr_warmup_ratio=0.1, 343 | lr_constant_ratio=0.4, sampling_rate=16_000, **kwargs): 344 | super().__init__(**kwargs) 345 | self.model_output_dir = model_output_dir 346 | self.length_field_name = length_field_name 347 | self.upload_model_to_wandb_each_step = upload_model_to_wandb_each_step 348 | self.lr_warmup_ratio = lr_warmup_ratio 349 | self.lr_constant_ratio = lr_constant_ratio 350 | self.sampling_rate = sampling_rate 351 | 352 | def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: 353 | if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance( 354 | self.train_dataset, collections.abc.Sized 355 | ): 356 | return None 357 | 358 | # Build the sampler. 359 | if self.args.group_by_length: 360 | lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None 361 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 362 | if self.args.world_size <= 1: 363 | return LengthGroupedSampler( 364 | self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name 365 | ) 366 | else: 367 | return DistributedLengthGroupedSampler( 368 | self.train_dataset, 369 | self.args.train_batch_size, 370 | num_replicas=self.args.world_size, 371 | rank=self.args.process_index, 372 | lengths=lengths, 373 | model_input_name=model_input_name, 374 | ) 375 | 376 | else: 377 | return super()._get_train_sampler() 378 | 379 | def create_scheduler(self, num_training_steps: int): 380 | """ 381 | Setup the scheduler. The optimizer of the trainer must have been set up before this method is called. 382 | 383 | This method was built based on https://arxiv.org/pdf/2006.13979 : 384 | "The learning rate schedule has three phases: warm up for the first 10% of updates, 385 | keep constant for 40% and then linearly decay for the remainder" 386 | 387 | Args: 388 | num_training_steps (int): The number of training steps to do. 389 | """ 390 | def lr_lambda(current_step): 391 | warmup_steps = int(num_training_steps * self.lr_warmup_ratio) 392 | constant_steps = int(num_training_steps * self.lr_constant_ratio) 393 | if current_step < warmup_steps: 394 | return float(current_step) / float(max(1, warmup_steps)) 395 | elif (self.lr_warmup_ratio + self.lr_constant_ratio) == 1.0 or current_step < (warmup_steps + constant_steps): 396 | return 1 397 | else: 398 | return max( 399 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - (warmup_steps + constant_steps))) 400 | ) 401 | 402 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) 403 | 404 | def _apply_some_audio_transformations(self, inputs): 405 | """Perform some audio transformations""" 406 | 407 | # adding an extra dimmention for the channels as our data is mono audio and 408 | # the expected shape of input for torch_audiomentations is (batch_size, num_channels, num_samples) 409 | transformed_inputs = inputs["input_values"].unsqueeze(1) 410 | 411 | transformed_inputs = self.augmentator(transformed_inputs, sample_rate=self.sampling_rate) 412 | 413 | # returning the inputs to the original shape 414 | transformed_inputs = torch.squeeze(transformed_inputs, 1) 415 | 416 | inputs["input_values"] = transformed_inputs 417 | 418 | return inputs 419 | 420 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 421 | """ 422 | Perform a training step on a batch of inputs. 423 | 424 | Subclass and override to inject custom behavior. 425 | 426 | Args: 427 | model (:obj:`nn.Module`): 428 | The model to train. 429 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 430 | The inputs and targets of the model. 431 | 432 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 433 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 434 | 435 | Return: 436 | :obj:`torch.Tensor`: The tensor with training loss on this batch. 437 | """ 438 | 439 | model.train() 440 | inputs = self._prepare_inputs(inputs) 441 | 442 | if self.use_amp: 443 | with autocast(): 444 | loss = self.compute_loss(model, inputs) 445 | else: 446 | loss = self.compute_loss(model, inputs) 447 | 448 | if self.args.n_gpu > 1: 449 | if model.module.config.ctc_loss_reduction == "mean": 450 | loss = loss.mean() 451 | elif model.module.config.ctc_loss_reduction == "sum": 452 | loss = loss.sum() / (inputs["labels"] >= 0).sum() 453 | else: 454 | raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']") 455 | 456 | if self.args.gradient_accumulation_steps > 1: 457 | loss = loss / self.args.gradient_accumulation_steps 458 | 459 | if self.use_amp: 460 | self.scaler.scale(loss).backward() 461 | elif self.use_apex: 462 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 463 | scaled_loss.backward() 464 | elif self.deepspeed: 465 | self.deepspeed.backward(loss) 466 | else: 467 | loss.backward() 468 | 469 | if self.upload_model_to_wandb_each_step is not None and self.state.global_step > 0 \ 470 | and self.state.global_step % self.upload_model_to_wandb_each_step == 0: 471 | upload_model_to_wandb(self.model_output_dir, name=f"{wandb.run.name}_{self.state.global_step}", metadata={"loss": float(loss)}) 472 | 473 | return loss.detach() 474 | 475 | 476 | def build_tokenizer(model_output_dir, dataset, num_proc, min_char_occurrence_ratio): 477 | 478 | def extract_all_chars(batch): 479 | all_text = " ".join(batch["text"]).replace("", "") 480 | return {"all_text": [all_text]} 481 | 482 | vocab_train = dataset.map( 483 | extract_all_chars, 484 | batched=True, 485 | batch_size=-1, 486 | remove_columns=dataset.column_names, 487 | num_proc=num_proc 488 | ) 489 | 490 | special_vocab_dict = {"": 0, "": 1, "": 2, "": 3, "|": 4} 491 | 492 | min_char_occurrence = int(min_char_occurrence_ratio * len(vocab_train["all_text"][0])) if min_char_occurrence_ratio is not None else 1 493 | 494 | if min_char_occurrence > 1: 495 | character_counter = collections.Counter(vocab_train["all_text"][0]) 496 | vocab_list = [character for character, count in character_counter.items() if count >= min_char_occurrence] 497 | else: 498 | vocab_list = set(vocab_train["all_text"][0]) 499 | 500 | vocab_list = [x for x in vocab_list if x.isalpha() or x in ["-", "'"]] # removing non-alpha (except - or ') characters 501 | 502 | vocab_list = sorted(vocab_list) 503 | vocab_dict = {v: k + len(special_vocab_dict) for k, v in enumerate(vocab_list)} 504 | vocab_dict = dict(special_vocab_dict, **vocab_dict) 505 | 506 | vocab_path = os.path.join(model_output_dir, "vocab.json") 507 | 508 | with open(vocab_path, "w") as vocab_file: 509 | json.dump(vocab_dict, vocab_file) 510 | 511 | return Wav2Vec2CTCTokenizer( 512 | vocab_path, 513 | unk_token="", 514 | pad_token="", 515 | word_delimiter_token="|", 516 | ) 517 | 518 | 519 | def upload_model_to_wandb(model_output_dir, name, metadata=None): 520 | artifact = wandb.Artifact(name=name, type="model", metadata=metadata) 521 | artifact.add_dir(model_output_dir) 522 | wandb.run.log_artifact(artifact) 523 | 524 | 525 | def main(): 526 | # See all possible arguments in src/transformers/training_args.py 527 | # or by passing the --help flag to this script. 528 | # We now keep distinct sets of args, for a cleaner separation of concerns. 529 | 530 | # override default run name 531 | 532 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, AdditionalTrainingArguments, TrainingArguments)) 533 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 534 | # If we pass only one argument to the script and it's the path to a json file, 535 | # let's parse it to get our arguments. 536 | model_args, data_args, additional_training_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 537 | else: 538 | model_args, data_args, additional_training_args, training_args = parser.parse_args_into_dataclasses() 539 | 540 | os.makedirs(training_args.output_dir, exist_ok=True) 541 | os.makedirs(model_args.cache_dir, exist_ok=True) 542 | 543 | wandb.init(dir=model_args.cache_dir) 544 | 545 | # Detecting last checkpoint. 546 | last_checkpoint = None 547 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 548 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 549 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 550 | raise ValueError( 551 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 552 | "Use --overwrite_output_dir to overcome." 553 | ) 554 | elif last_checkpoint is not None: 555 | logger.info( 556 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 557 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 558 | ) 559 | 560 | # Setup logging 561 | logging.basicConfig( 562 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 563 | datefmt="%m/%d/%Y %H:%M:%S", 564 | handlers=[logging.StreamHandler(sys.stdout)], 565 | ) 566 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 567 | 568 | # Log on each process the small summary: 569 | logger.warning( 570 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 571 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 572 | ) 573 | # Set the verbosity to info of the Transformers logger (on main process only): 574 | if is_main_process(training_args.local_rank): 575 | transformers.utils.logging.set_verbosity_info() 576 | logger.info("Training/evaluation parameters %s", training_args) 577 | 578 | # Set seed before initializing model. 579 | set_seed(training_args.seed) 580 | 581 | # Get the datasets: 582 | 583 | # As Common Voice dataset for most of the languages are really small, we'll merge the train and validation splits 584 | dataset = datasets.load_dataset( 585 | "dataset_ext.py", data_args.dataset_config_name, 586 | split="train+validation", 587 | cache_dir=model_args.cache_dir 588 | ) 589 | 590 | print("DATASET COUNT:") 591 | print(collections.Counter(dataset["dataset"])) 592 | 593 | if data_args.val_ratio > 0 and data_args.max_val_samples > 0 and training_args.do_eval: 594 | if len(dataset) * data_args.val_ratio > data_args.max_val_samples: 595 | dataset = dataset.train_test_split(test_size=data_args.max_val_samples) 596 | else: 597 | dataset = dataset.train_test_split(test_size=data_args.val_ratio) 598 | 599 | train_dataset = dataset["train"] 600 | eval_dataset = dataset["test"] 601 | 602 | else: 603 | train_dataset = dataset 604 | eval_dataset = None 605 | 606 | 607 | # Filtering dataset: 608 | 609 | train_dataset_original_size = len(train_dataset) 610 | if eval_dataset is not None: 611 | eval_dataset_original_size = len(eval_dataset) 612 | 613 | if data_args.use_only_common_voice_data: 614 | train_dataset = train_dataset.filter( 615 | lambda example: example["dataset"] == "common_voice", 616 | num_proc=data_args.preprocessing_num_workers 617 | ) 618 | 619 | train_dataset = train_dataset.filter( 620 | lambda example: example["duration"] >= data_args.min_duration and example["duration"] <= data_args.max_duration, 621 | num_proc=data_args.preprocessing_num_workers 622 | ) 623 | 624 | if data_args.max_train_samples is not None and train_dataset_original_size > data_args.max_train_samples: 625 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 626 | 627 | if eval_dataset is not None and data_args.max_val_samples is not None and eval_dataset_original_size > data_args.max_val_samples: 628 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) 629 | 630 | train_dataset_final_size = len(train_dataset) 631 | if eval_dataset is not None: 632 | eval_dataset_final_size = len(eval_dataset) 633 | 634 | logger.info(f"After filtering {train_dataset_final_size} of {train_dataset_original_size} samples will be used to train the model") 635 | if eval_dataset is not None: 636 | logger.info(f"After filtering {eval_dataset_final_size} of {eval_dataset_original_size} samples will be used to eval the model") 637 | 638 | # Create and save tokenizer 639 | chars_to_ignore_regex = f"[{re.escape(''.join(data_args.chars_to_ignore))}]" 640 | 641 | def remove_special_characters(batch): 642 | batch["text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).strip().upper() + " " 643 | return batch 644 | 645 | train_dataset = train_dataset.map( 646 | remove_special_characters, 647 | remove_columns=["sentence"], 648 | num_proc=data_args.preprocessing_num_workers 649 | ) 650 | if eval_dataset is not None: 651 | eval_dataset = eval_dataset.map( 652 | remove_special_characters, 653 | remove_columns=["sentence"], 654 | num_proc=data_args.preprocessing_num_workers 655 | ) 656 | 657 | # Load pretrained model and tokenizer 658 | # 659 | # Distributed training: 660 | # The .from_pretrained methods guarantee that only one local process can concurrently 661 | # download model & vocab. 662 | 663 | if model_args.model_name_or_path in PRETRAINED_MODELS: 664 | dataset = datasets.concatenate_datasets([train_dataset, eval_dataset]) if eval_dataset is not None else train_dataset 665 | if len(dataset) > additional_training_args.max_dataset_size_vocab_builder: 666 | dataset = dataset.select(range(additional_training_args.max_dataset_size_vocab_builder)) 667 | tokenizer = build_tokenizer(training_args.output_dir, dataset, data_args.preprocessing_num_workers, additional_training_args.min_char_occurrence_ratio) 668 | feature_extractor = Wav2Vec2FeatureExtractor( 669 | feature_size=1, sampling_rate=16_000, padding_value=0.0, do_normalize=True, return_attention_mask=True 670 | ) 671 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 672 | else: 673 | processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path) 674 | 675 | if additional_training_args.remove_samples_with_oov_from_training: 676 | vocab = set(processor.tokenizer.encoder.keys()) 677 | train_dataset_size = len(train_dataset) 678 | train_dataset = train_dataset.filter( 679 | lambda example: vocab.issuperset(example["text"].replace(" ", "")), 680 | num_proc=data_args.preprocessing_num_workers 681 | ) 682 | print(f"OOV found in {train_dataset_size - len(train_dataset)} samples, and they were removed from training set") 683 | print(f"The final training set size is {len(train_dataset)}") 684 | 685 | if additional_training_args.use_only_top_k_most_common_accent is not None: 686 | 687 | train_dataset_size = len(train_dataset) 688 | 689 | accent_count = collections.Counter(train_dataset["accent"]) 690 | # accent_count.pop("", None) 691 | major_accents = [k for k, x in accent_count.most_common(additional_training_args.use_only_top_k_most_common_accent)] 692 | 693 | print(f"ACCENT COUNT: {accent_count}") 694 | 695 | train_dataset = train_dataset.filter( 696 | lambda example: example["accent"] in major_accents, 697 | num_proc=data_args.preprocessing_num_workers 698 | ) 699 | 700 | print(f"{train_dataset_size - len(train_dataset)} were removed from dataset due accent filtering, the final training dataset size is {len(train_dataset)}") 701 | 702 | # save the feature_extractor and the tokenizer 703 | processor.save_pretrained(training_args.output_dir) 704 | 705 | model = Wav2Vec2ForCTC.from_pretrained( 706 | model_args.model_name_or_path, 707 | cache_dir=model_args.cache_dir, 708 | activation_dropout=model_args.activation_dropout, 709 | attention_dropout=model_args.attention_dropout, 710 | hidden_dropout=model_args.hidden_dropout, 711 | feat_proj_dropout=model_args.feat_proj_dropout, 712 | mask_time_prob=model_args.mask_time_prob, 713 | gradient_checkpointing=model_args.gradient_checkpointing, 714 | layerdrop=model_args.layerdrop, 715 | ctc_loss_reduction="mean", 716 | pad_token_id=processor.tokenizer.pad_token_id, 717 | vocab_size=len(processor.tokenizer), 718 | ctc_zero_infinity=True 719 | ) 720 | 721 | # Preprocessing the datasets. 722 | # We need to read the audio files as arrays and tokenize the targets. 723 | def speech_file_to_array_fn(batch): 724 | speech_array, sampling_rate = sf.read(batch["path"]) 725 | batch["speech"] = speech_array 726 | batch["sampling_rate"] = sampling_rate 727 | batch["target_text"] = batch["text"] 728 | return batch 729 | 730 | print("TRAIN DATASET COUNT:") 731 | print(collections.Counter(train_dataset["dataset"])) 732 | print("EVAL DATASET COUNT:") 733 | print(collections.Counter(eval_dataset["dataset"])) 734 | 735 | train_dataset = train_dataset.map( 736 | speech_file_to_array_fn, 737 | remove_columns=train_dataset.column_names, 738 | num_proc=data_args.preprocessing_num_workers 739 | ) 740 | if eval_dataset is not None: 741 | eval_dataset = eval_dataset.map( 742 | speech_file_to_array_fn, 743 | remove_columns=eval_dataset.column_names, 744 | num_proc=data_args.preprocessing_num_workers 745 | ) 746 | 747 | def prepare_dataset(batch): 748 | # check that all files have the correct sampling rate 749 | assert ( 750 | len(set(batch["sampling_rate"])) == 1 751 | ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}." 752 | batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values 753 | 754 | # Setup the processor for targets 755 | with processor.as_target_processor(): 756 | batch["labels"] = processor(batch["target_text"]).input_ids 757 | return batch 758 | 759 | train_dataset = train_dataset.map( 760 | prepare_dataset, 761 | remove_columns=train_dataset.column_names, 762 | batch_size=training_args.per_device_train_batch_size, 763 | batched=True, 764 | num_proc=data_args.preprocessing_num_workers 765 | ) 766 | if eval_dataset is not None: 767 | eval_dataset = eval_dataset.map( 768 | prepare_dataset, 769 | remove_columns=eval_dataset.column_names, 770 | batch_size=training_args.per_device_train_batch_size, 771 | batched=True, 772 | num_proc=data_args.preprocessing_num_workers 773 | ) 774 | 775 | # Pre-compute sample lengths 776 | def input_lengths(example): 777 | example["length"] = len(example["input_values"]) 778 | return example 779 | 780 | train_dataset = train_dataset.map( 781 | input_lengths, 782 | num_proc=data_args.preprocessing_num_workers 783 | ) 784 | 785 | # Metric 786 | wer_metric = datasets.load_metric("wer.py") 787 | cer_metric = datasets.load_metric("cer.py") 788 | 789 | def compute_metrics(pred): 790 | pred_logits = pred.predictions 791 | pred_ids = np.argmax(pred_logits, axis=-1) 792 | 793 | pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id 794 | 795 | pred_str = processor.batch_decode(pred_ids) 796 | # we do not want to group tokens when computing the metrics 797 | label_str = processor.batch_decode(pred.label_ids, group_tokens=False) 798 | 799 | wer = wer_metric.compute(predictions=pred_str, references=label_str, chunk_size=1000) 800 | cer = cer_metric.compute(predictions=pred_str, references=label_str, chunk_size=1000) 801 | 802 | return {"wer": wer, "cer": cer} 803 | 804 | if model_args.freeze_feature_extractor: 805 | model.freeze_feature_extractor() 806 | 807 | # Data collator 808 | data_collator = DataCollatorCTCWithPadding( 809 | processor=processor, 810 | padding=True, 811 | apply_gaussian_noise_with_p=additional_training_args.apply_gaussian_noise_with_p, 812 | apply_gain_with_p=additional_training_args.apply_gain_with_p, 813 | apply_pitch_shift_with_p=additional_training_args.apply_pitch_shift_with_p, 814 | apply_time_stretch_with_p=additional_training_args.apply_time_stretch_with_p, 815 | sample_rate=16_000, 816 | ) 817 | 818 | # Initialize our Trainer 819 | trainer = CTCTrainer( 820 | model_output_dir=training_args.output_dir, 821 | length_field_name="length", 822 | upload_model_to_wandb_each_step=additional_training_args.upload_model_to_wandb_each_step, 823 | lr_warmup_ratio=additional_training_args.lr_warmup_ratio, 824 | lr_constant_ratio=additional_training_args.lr_constant_ratio, 825 | sampling_rate=16_000, 826 | model=model, 827 | data_collator=data_collator, 828 | args=training_args, 829 | compute_metrics=compute_metrics, 830 | train_dataset=train_dataset, 831 | eval_dataset=eval_dataset, 832 | tokenizer=processor.feature_extractor, 833 | ) 834 | 835 | # Training 836 | if training_args.do_train: 837 | if last_checkpoint is not None: 838 | checkpoint = last_checkpoint 839 | elif os.path.isdir(model_args.model_name_or_path): 840 | checkpoint = model_args.model_name_or_path 841 | else: 842 | checkpoint = None 843 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 844 | trainer.save_model() 845 | 846 | metrics = train_result.metrics 847 | max_train_samples = ( 848 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 849 | ) 850 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 851 | 852 | trainer.log_metrics("train", metrics) 853 | trainer.save_metrics("train", metrics) 854 | trainer.save_state() 855 | 856 | # Evaluation 857 | metrics = {} 858 | if eval_dataset is not None and training_args.do_eval: 859 | logger.info("*** Evaluate ***") 860 | metrics = trainer.evaluate() 861 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset) 862 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) 863 | 864 | trainer.log_metrics("eval", metrics) 865 | trainer.save_metrics("eval", metrics) 866 | 867 | # save model files 868 | if additional_training_args.upload_final_model_to_wandb: 869 | upload_model_to_wandb(training_args.output_dir, name=f"{wandb.run.name}_final", metadata=metrics) 870 | 871 | if __name__ == "__main__": 872 | main() 873 | -------------------------------------------------------------------------------- /dataset_ext.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Common Voice Dataset""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import os 20 | import re 21 | import homoglyphs as hg 22 | import gdown 23 | import json 24 | import pandas as pd 25 | import glob 26 | 27 | import datasets 28 | 29 | import soundfile as sf 30 | import librosa 31 | import warnings 32 | 33 | from lang_trans.arabic import buckwalter 34 | 35 | _DATA_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/{}.tar.gz" 36 | 37 | _CITATION = """\ 38 | @inproceedings{commonvoice:2020, 39 | author = {Ardila, R. and Branson, M. and Davis, K. and Henretty, M. and Kohler, M. and Meyer, J. and Morais, R. and Saunders, L. and Tyers, F. M. and Weber, G.}, 40 | title = {Common Voice: A Massively-Multilingual Speech Corpus}, 41 | booktitle = {Proceedings of the 12th Conference on Language Resources and Evaluation (LREC 2020)}, 42 | pages = {4211--4215}, 43 | year = 2020 44 | } 45 | """ 46 | 47 | _DESCRIPTION = """\ 48 | Common Voice is Mozilla's initiative to help teach machines how real people speak. 49 | The dataset currently consists of 7,335 validated hours of speech in 60 languages, but we’re always adding more voices and languages. 50 | """ 51 | 52 | _HOMEPAGE = "https://commonvoice.mozilla.org/en/datasets" 53 | 54 | _LICENSE = "https://github.com/common-voice/common-voice/blob/main/LICENSE" 55 | 56 | _LANGUAGES = { 57 | "ab": { 58 | "Language": "Abkhaz", 59 | "Date": "2020-12-11", 60 | "Size": "39 MB", 61 | "Version": "ab_1h_2020-12-11", 62 | "Validated_Hr_Total": 0.05, 63 | "Overall_Hr_Total": 1, 64 | "Number_Of_Voice": 14, 65 | }, 66 | "ar": { 67 | "Language": "Arabic", 68 | "Date": "2020-12-11", 69 | "Size": "2 GB", 70 | "Version": "ar_77h_2020-12-11", 71 | "Validated_Hr_Total": 49, 72 | "Overall_Hr_Total": 77, 73 | "Number_Of_Voice": 672, 74 | }, 75 | "as": { 76 | "Language": "Assamese", 77 | "Date": "2020-12-11", 78 | "Size": "21 MB", 79 | "Version": "as_0.78h_2020-12-11", 80 | "Validated_Hr_Total": 0.74, 81 | "Overall_Hr_Total": 0.78, 82 | "Number_Of_Voice": 17, 83 | }, 84 | "br": { 85 | "Language": "Breton", 86 | "Date": "2020-12-11", 87 | "Size": "444 MB", 88 | "Version": "br_16h_2020-12-11", 89 | "Validated_Hr_Total": 7, 90 | "Overall_Hr_Total": 16, 91 | "Number_Of_Voice": 157, 92 | }, 93 | "ca": { 94 | "Language": "Catalan", 95 | "Date": "2020-12-11", 96 | "Size": "19 GB", 97 | "Version": "ca_748h_2020-12-11", 98 | "Validated_Hr_Total": 623, 99 | "Overall_Hr_Total": 748, 100 | "Number_Of_Voice": 5376, 101 | }, 102 | "cnh": { 103 | "Language": "Hakha Chin", 104 | "Date": "2020-12-11", 105 | "Size": "39 MB", 106 | "Version": "ab_1h_2020-12-11", 107 | "Validated_Hr_Total": 0.05, 108 | "Overall_Hr_Total": 1, 109 | "Number_Of_Voice": 14, 110 | }, 111 | "cs": { 112 | "Language": "Czech", 113 | "Date": "2020-12-11", 114 | "Size": "39 MB", 115 | "Version": "ab_1h_2020-12-11", 116 | "Validated_Hr_Total": 0.05, 117 | "Overall_Hr_Total": 1, 118 | "Number_Of_Voice": 14, 119 | }, 120 | "cv": { 121 | "Language": "Chuvash", 122 | "Date": "2020-12-11", 123 | "Size": "419 MB", 124 | "Version": "cv_16h_2020-12-11", 125 | "Validated_Hr_Total": 4, 126 | "Overall_Hr_Total": 16, 127 | "Number_Of_Voice": 92, 128 | }, 129 | "cy": { 130 | "Language": "Welsh", 131 | "Date": "2020-12-11", 132 | "Size": "3 GB", 133 | "Version": "cy_124h_2020-12-11", 134 | "Validated_Hr_Total": 95, 135 | "Overall_Hr_Total": 124, 136 | "Number_Of_Voice": 1382, 137 | }, 138 | "de": { 139 | "Language": "German", 140 | "Date": "2020-12-11", 141 | "Size": "22 GB", 142 | "Version": "de_836h_2020-12-11", 143 | "Validated_Hr_Total": 777, 144 | "Overall_Hr_Total": 836, 145 | "Number_Of_Voice": 12659, 146 | }, 147 | "dv": { 148 | "Language": "Dhivehi", 149 | "Date": "2020-12-11", 150 | "Size": "515 MB", 151 | "Version": "dv_19h_2020-12-11", 152 | "Validated_Hr_Total": 18, 153 | "Overall_Hr_Total": 19, 154 | "Number_Of_Voice": 167, 155 | }, 156 | "el": { 157 | "Language": "Greek", 158 | "Date": "2020-12-11", 159 | "Size": "364 MB", 160 | "Version": "el_13h_2020-12-11", 161 | "Validated_Hr_Total": 6, 162 | "Overall_Hr_Total": 13, 163 | "Number_Of_Voice": 118, 164 | }, 165 | "en": { 166 | "Language": "English", 167 | "Date": "2020-12-11", 168 | "Size": "56 GB", 169 | "Version": "en_2181h_2020-12-11", 170 | "Validated_Hr_Total": 1686, 171 | "Overall_Hr_Total": 2181, 172 | "Number_Of_Voice": 66173, 173 | }, 174 | "eo": { 175 | "Language": "Esperanto", 176 | "Date": "2020-12-11", 177 | "Size": "3 GB", 178 | "Version": "eo_102h_2020-12-11", 179 | "Validated_Hr_Total": 90, 180 | "Overall_Hr_Total": 102, 181 | "Number_Of_Voice": 574, 182 | }, 183 | "es": { 184 | "Language": "Spanish", 185 | "Date": "2020-12-11", 186 | "Size": "15 GB", 187 | "Version": "es_579h_2020-12-11", 188 | "Validated_Hr_Total": 324, 189 | "Overall_Hr_Total": 579, 190 | "Number_Of_Voice": 19484, 191 | }, 192 | "et": { 193 | "Language": "Estonian", 194 | "Date": "2020-12-11", 195 | "Size": "732 MB", 196 | "Version": "et_27h_2020-12-11", 197 | "Validated_Hr_Total": 19, 198 | "Overall_Hr_Total": 27, 199 | "Number_Of_Voice": 543, 200 | }, 201 | "eu": { 202 | "Language": "Basque", 203 | "Date": "2020-12-11", 204 | "Size": "3 GB", 205 | "Version": "eu_131h_2020-12-11", 206 | "Validated_Hr_Total": 89, 207 | "Overall_Hr_Total": 131, 208 | "Number_Of_Voice": 1028, 209 | }, 210 | "fa": { 211 | "Language": "Persian", 212 | "Date": "2020-12-11", 213 | "Size": "8 GB", 214 | "Version": "fa_321h_2020-12-11", 215 | "Validated_Hr_Total": 282, 216 | "Overall_Hr_Total": 321, 217 | "Number_Of_Voice": 3655, 218 | }, 219 | "fi": { 220 | "Language": "Finnish", 221 | "Date": "2020-12-11", 222 | "Size": "48 MB", 223 | "Version": "fi_1h_2020-12-11", 224 | "Validated_Hr_Total": 1, 225 | "Overall_Hr_Total": 1, 226 | "Number_Of_Voice": 27, 227 | }, 228 | "fr": { 229 | "Language": "French", 230 | "Date": "2020-12-11", 231 | "Size": "18 GB", 232 | "Version": "fr_682h_2020-12-11", 233 | "Validated_Hr_Total": 623, 234 | "Overall_Hr_Total": 682, 235 | "Number_Of_Voice": 12953, 236 | }, 237 | "fy-NL": { 238 | "Language": "Frisian", 239 | "Date": "2020-12-11", 240 | "Size": "1 GB", 241 | "Version": "fy-NL_46h_2020-12-11", 242 | "Validated_Hr_Total": 14, 243 | "Overall_Hr_Total": 46, 244 | "Number_Of_Voice": 467, 245 | }, 246 | "ga-IE": { 247 | "Language": "Irish", 248 | "Date": "2020-12-11", 249 | "Size": "149 MB", 250 | "Version": "ga-IE_5h_2020-12-11", 251 | "Validated_Hr_Total": 3, 252 | "Overall_Hr_Total": 5, 253 | "Number_Of_Voice": 101, 254 | }, 255 | "hi": { 256 | "Language": "Hindi", 257 | "Date": "2020-12-11", 258 | "Size": "20 MB", 259 | "Version": "hi_0.8h_2020-12-11", 260 | "Validated_Hr_Total": 0.54, 261 | "Overall_Hr_Total": 0.8, 262 | "Number_Of_Voice": 31, 263 | }, 264 | "hsb": { 265 | "Language": "Sorbian, Upper", 266 | "Date": "2020-12-11", 267 | "Size": "76 MB", 268 | "Version": "hsb_2h_2020-12-11", 269 | "Validated_Hr_Total": 2, 270 | "Overall_Hr_Total": 2, 271 | "Number_Of_Voice": 19, 272 | }, 273 | "hu": { 274 | "Language": "Hungarian", 275 | "Date": "2020-12-11", 276 | "Size": "232 MB", 277 | "Version": "hu_8h_2020-12-11", 278 | "Validated_Hr_Total": 8, 279 | "Overall_Hr_Total": 8, 280 | "Number_Of_Voice": 47, 281 | }, 282 | "ia": { 283 | "Language": "InterLinguia", 284 | "Date": "2020-12-11", 285 | "Size": "216 MB", 286 | "Version": "ia_8h_2020-12-11", 287 | "Validated_Hr_Total": 6, 288 | "Overall_Hr_Total": 8, 289 | "Number_Of_Voice": 36, 290 | }, 291 | "id": { 292 | "Language": "Indonesian", 293 | "Date": "2020-12-11", 294 | "Size": "454 MB", 295 | "Version": "id_17h_2020-12-11", 296 | "Validated_Hr_Total": 9, 297 | "Overall_Hr_Total": 17, 298 | "Number_Of_Voice": 219, 299 | }, 300 | "it": { 301 | "Language": "Italian", 302 | "Date": "2020-12-11", 303 | "Size": "5 GB", 304 | "Version": "it_199h_2020-12-11", 305 | "Validated_Hr_Total": 158, 306 | "Overall_Hr_Total": 199, 307 | "Number_Of_Voice": 5729, 308 | }, 309 | "ja": { 310 | "Language": "Japanese", 311 | "Date": "2020-12-11", 312 | "Size": "146 MB", 313 | "Version": "ja_5h_2020-12-11", 314 | "Validated_Hr_Total": 3, 315 | "Overall_Hr_Total": 5, 316 | "Number_Of_Voice": 235, 317 | }, 318 | "ka": { 319 | "Language": "Georgian", 320 | "Date": "2020-12-11", 321 | "Size": "99 MB", 322 | "Version": "ka_3h_2020-12-11", 323 | "Validated_Hr_Total": 3, 324 | "Overall_Hr_Total": 3, 325 | "Number_Of_Voice": 44, 326 | }, 327 | "kab": { 328 | "Language": "Kabyle", 329 | "Date": "2020-12-11", 330 | "Size": "16 GB", 331 | "Version": "kab_622h_2020-12-11", 332 | "Validated_Hr_Total": 525, 333 | "Overall_Hr_Total": 622, 334 | "Number_Of_Voice": 1309, 335 | }, 336 | "ky": { 337 | "Language": "Kyrgyz", 338 | "Date": "2020-12-11", 339 | "Size": "553 MB", 340 | "Version": "ky_22h_2020-12-11", 341 | "Validated_Hr_Total": 11, 342 | "Overall_Hr_Total": 22, 343 | "Number_Of_Voice": 134, 344 | }, 345 | "lg": { 346 | "Language": "Luganda", 347 | "Date": "2020-12-11", 348 | "Size": "199 MB", 349 | "Version": "lg_8h_2020-12-11", 350 | "Validated_Hr_Total": 3, 351 | "Overall_Hr_Total": 8, 352 | "Number_Of_Voice": 76, 353 | }, 354 | "lt": { 355 | "Language": "Lithuanian", 356 | "Date": "2020-12-11", 357 | "Size": "129 MB", 358 | "Version": "lt_4h_2020-12-11", 359 | "Validated_Hr_Total": 2, 360 | "Overall_Hr_Total": 4, 361 | "Number_Of_Voice": 30, 362 | }, 363 | "lv": { 364 | "Language": "Latvian", 365 | "Date": "2020-12-11", 366 | "Size": "199 MB", 367 | "Version": "lv_7h_2020-12-11", 368 | "Validated_Hr_Total": 6, 369 | "Overall_Hr_Total": 7, 370 | "Number_Of_Voice": 99, 371 | }, 372 | "mn": { 373 | "Language": "Mongolian", 374 | "Date": "2020-12-11", 375 | "Size": "464 MB", 376 | "Version": "mn_17h_2020-12-11", 377 | "Validated_Hr_Total": 11, 378 | "Overall_Hr_Total": 17, 379 | "Number_Of_Voice": 376, 380 | }, 381 | "mt": { 382 | "Language": "Maltese", 383 | "Date": "2020-12-11", 384 | "Size": "405 MB", 385 | "Version": "mt_15h_2020-12-11", 386 | "Validated_Hr_Total": 7, 387 | "Overall_Hr_Total": 15, 388 | "Number_Of_Voice": 171, 389 | }, 390 | "nl": { 391 | "Language": "Dutch", 392 | "Date": "2020-12-11", 393 | "Size": "2 GB", 394 | "Version": "nl_63h_2020-12-11", 395 | "Validated_Hr_Total": 59, 396 | "Overall_Hr_Total": 63, 397 | "Number_Of_Voice": 1012, 398 | }, 399 | "or": { 400 | "Language": "Odia", 401 | "Date": "2020-12-11", 402 | "Size": "190 MB", 403 | "Version": "or_7h_2020-12-11", 404 | "Validated_Hr_Total": 0.87, 405 | "Overall_Hr_Total": 7, 406 | "Number_Of_Voice": 34, 407 | }, 408 | "pa-IN": { 409 | "Language": "Punjabi", 410 | "Date": "2020-12-11", 411 | "Size": "67 MB", 412 | "Version": "pa-IN_2h_2020-12-11", 413 | "Validated_Hr_Total": 0.5, 414 | "Overall_Hr_Total": 2, 415 | "Number_Of_Voice": 26, 416 | }, 417 | "pl": { 418 | "Language": "Polish", 419 | "Date": "2020-12-11", 420 | "Size": "3 GB", 421 | "Version": "pl_129h_2020-12-11", 422 | "Validated_Hr_Total": 108, 423 | "Overall_Hr_Total": 129, 424 | "Number_Of_Voice": 2647, 425 | }, 426 | "pt": { 427 | "Language": "Portuguese", 428 | "Date": "2020-12-11", 429 | "Size": "2 GB", 430 | "Version": "pt_63h_2020-12-11", 431 | "Validated_Hr_Total": 50, 432 | "Overall_Hr_Total": 63, 433 | "Number_Of_Voice": 1120, 434 | }, 435 | "rm-sursilv": { 436 | "Language": "Romansh Sursilvan", 437 | "Date": "2020-12-11", 438 | "Size": "263 MB", 439 | "Version": "rm-sursilv_9h_2020-12-11", 440 | "Validated_Hr_Total": 5, 441 | "Overall_Hr_Total": 9, 442 | "Number_Of_Voice": 78, 443 | }, 444 | "rm-vallader": { 445 | "Language": "Romansh Vallader", 446 | "Date": "2020-12-11", 447 | "Size": "103 MB", 448 | "Version": "rm-vallader_3h_2020-12-11", 449 | "Validated_Hr_Total": 2, 450 | "Overall_Hr_Total": 3, 451 | "Number_Of_Voice": 39, 452 | }, 453 | "ro": { 454 | "Language": "Romanian", 455 | "Date": "2020-12-11", 456 | "Size": "250 MB", 457 | "Version": "ro_9h_2020-12-11", 458 | "Validated_Hr_Total": 6, 459 | "Overall_Hr_Total": 9, 460 | "Number_Of_Voice": 130, 461 | }, 462 | "ru": { 463 | "Language": "Russian", 464 | "Date": "2020-12-11", 465 | "Size": "3 GB", 466 | "Version": "ru_130h_2020-12-11", 467 | "Validated_Hr_Total": 111, 468 | "Overall_Hr_Total": 130, 469 | "Number_Of_Voice": 1412, 470 | }, 471 | "rw": { 472 | "Language": "Kinyarwanda", 473 | "Date": "2020-12-11", 474 | "Size": "40 GB", 475 | "Version": "rw_1510h_2020-12-11", 476 | "Validated_Hr_Total": 1183, 477 | "Overall_Hr_Total": 1510, 478 | "Number_Of_Voice": 410, 479 | }, 480 | "sah": { 481 | "Language": "Sakha", 482 | "Date": "2020-12-11", 483 | "Size": "173 MB", 484 | "Version": "sah_6h_2020-12-11", 485 | "Validated_Hr_Total": 4, 486 | "Overall_Hr_Total": 6, 487 | "Number_Of_Voice": 42, 488 | }, 489 | "sl": { 490 | "Language": "Slovenian", 491 | "Date": "2020-12-11", 492 | "Size": "212 MB", 493 | "Version": "sl_7h_2020-12-11", 494 | "Validated_Hr_Total": 5, 495 | "Overall_Hr_Total": 7, 496 | "Number_Of_Voice": 82, 497 | }, 498 | "sv-SE": { 499 | "Language": "Swedish", 500 | "Date": "2020-12-11", 501 | "Size": "402 MB", 502 | "Version": "sv-SE_15h_2020-12-11", 503 | "Validated_Hr_Total": 12, 504 | "Overall_Hr_Total": 15, 505 | "Number_Of_Voice": 222, 506 | }, 507 | "ta": { 508 | "Language": "Tamil", 509 | "Date": "2020-12-11", 510 | "Size": "648 MB", 511 | "Version": "ta_24h_2020-12-11", 512 | "Validated_Hr_Total": 14, 513 | "Overall_Hr_Total": 24, 514 | "Number_Of_Voice": 266, 515 | }, 516 | "th": { 517 | "Language": "Thai", 518 | "Date": "2020-12-11", 519 | "Size": "325 MB", 520 | "Version": "th_12h_2020-12-11", 521 | "Validated_Hr_Total": 8, 522 | "Overall_Hr_Total": 12, 523 | "Number_Of_Voice": 182, 524 | }, 525 | "tr": { 526 | "Language": "Turkish", 527 | "Date": "2020-12-11", 528 | "Size": "592 MB", 529 | "Version": "tr_22h_2020-12-11", 530 | "Validated_Hr_Total": 20, 531 | "Overall_Hr_Total": 22, 532 | "Number_Of_Voice": 678, 533 | }, 534 | "tt": { 535 | "Language": "Tatar", 536 | "Date": "2020-12-11", 537 | "Size": "741 MB", 538 | "Version": "tt_28h_2020-12-11", 539 | "Validated_Hr_Total": 26, 540 | "Overall_Hr_Total": 28, 541 | "Number_Of_Voice": 185, 542 | }, 543 | "uk": { 544 | "Language": "Ukrainian", 545 | "Date": "2020-12-11", 546 | "Size": "1 GB", 547 | "Version": "uk_43h_2020-12-11", 548 | "Validated_Hr_Total": 30, 549 | "Overall_Hr_Total": 43, 550 | "Number_Of_Voice": 459, 551 | }, 552 | "vi": { 553 | "Language": "Vietnamese", 554 | "Date": "2020-12-11", 555 | "Size": "50 MB", 556 | "Version": "vi_1h_2020-12-11", 557 | "Validated_Hr_Total": 0.74, 558 | "Overall_Hr_Total": 1, 559 | "Number_Of_Voice": 62, 560 | }, 561 | "vot": { 562 | "Language": "Votic", 563 | "Date": "2020-12-11", 564 | "Size": "7 MB", 565 | "Version": "vot_0.28h_2020-12-11", 566 | "Validated_Hr_Total": 0, 567 | "Overall_Hr_Total": 0.28, 568 | "Number_Of_Voice": 3, 569 | }, 570 | "zh-CN": { 571 | "Language": "Chinese (China)", 572 | "Date": "2020-12-11", 573 | "Size": "2 GB", 574 | "Version": "zh-CN_78h_2020-12-11", 575 | "Validated_Hr_Total": 56, 576 | "Overall_Hr_Total": 78, 577 | "Number_Of_Voice": 3501, 578 | }, 579 | "zh-HK": { 580 | "Language": "Chinese (Hong Kong)", 581 | "Date": "2020-12-11", 582 | "Size": "3 GB", 583 | "Version": "zh-HK_100h_2020-12-11", 584 | "Validated_Hr_Total": 50, 585 | "Overall_Hr_Total": 100, 586 | "Number_Of_Voice": 2536, 587 | }, 588 | "zh-TW": { 589 | "Language": "Chinese (Taiwan)", 590 | "Date": "2020-12-11", 591 | "Size": "2 GB", 592 | "Version": "zh-TW_78h_2020-12-11", 593 | "Validated_Hr_Total": 55, 594 | "Overall_Hr_Total": 78, 595 | "Number_Of_Voice": 1444, 596 | }, 597 | } 598 | 599 | _CSS10_URLS = { 600 | "de": "https://drive.google.com/uc?id=1wgCHGvT0S8YrNfRTVyn23sW-5MFknoHA", # 7427 samples 601 | "el": "https://drive.google.com/uc?id=10BNORyOqkosxEf3qAAtWM1qWjHEZzXTO", # 1844 samples 602 | "es": "https://drive.google.com/uc?id=1dyUvSxv0KowTseI35dE8UXpVsYFhEpQV", # 11100 samples 603 | "fi": "https://drive.google.com/uc?id=1H4-eGIgf4aK_s14uo-srbKMENpysuV2u", # 4842 samples 604 | "fr": "https://drive.google.com/uc?id=1kuhoDjhA_Cij0SJuMI_4kneDTR_cqahS", # 8648 samples 605 | "hu": "https://drive.google.com/uc?id=1ms2INJ1e0ChU0TMzgDYLa8jtoTK2gkmE", # 4515 samples 606 | "ja": "https://drive.google.com/uc?id=1E4k8FduAk-_wy85AQrGakZBcw2hLhmU6", # 6841 samples 607 | "nl": "https://drive.google.com/uc?id=1ji8QD4lJzInz2vomGkMafRjpz3gGBYsf", # 6494 samples 608 | "ru": "https://drive.google.com/uc?id=1tx3dpO8SX8CriF0YsK8XeISZc9yGRody", # 9599 samples 609 | "zh-CN": "https://drive.google.com/uc?id=1hliY4KD_I8y4FQg5zta9IDGN0HRQLRiv", # 2971 samples 610 | } 611 | 612 | _JSUT_URLS = { 613 | "ja": "http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip" # 7696 samples 614 | } 615 | 616 | _NST_URLS = { 617 | "sv-SE": { 618 | "metadata": "https://www.nb.no/sbfil/talegjenkjenning/16kHz_2020/se_2020/ADB_SWE_0467.tar.gz", 619 | "files": "https://www.nb.no/sbfil/talegjenkjenning/16kHz_2020/se_2020/lydfiler_16_1.tar.gz", # ? samples 620 | } 621 | } 622 | 623 | _FREE_ST_URLS = { 624 | "zh-CN": "https://www.openslr.org/resources/38/ST-CMDS-20170001_1-OS.tar.gz", # 102600 samples 625 | } 626 | 627 | _ARABIC_SPEECH = { 628 | "ar": "http://en.arabicspeechcorpus.com/arabic-speech-corpus.zip" # 1913 samples 629 | } 630 | 631 | _TIMIT = { 632 | "en": "https://data.deepai.org/timit.zip" # 4620 samples 633 | } 634 | 635 | _LIBRISPEECH_DL_URL = "http://www.openslr.org/resources/12/" 636 | _LIBRISPEECH = { 637 | "en": [ 638 | _LIBRISPEECH_DL_URL + "dev-clean.tar.gz", # 2703 samples 639 | _LIBRISPEECH_DL_URL + "dev-other.tar.gz", # 2864 samples 640 | _LIBRISPEECH_DL_URL + "train-clean-100.tar.gz", # 28539 samples 641 | _LIBRISPEECH_DL_URL + "train-clean-360.tar.gz", # 104014 samples 642 | _LIBRISPEECH_DL_URL + "train-other-500.tar.gz", # 148688 samples 643 | ] 644 | } 645 | 646 | _MAX_TRAIN_SAMPLES = 90000 647 | _MAX_VAL_SAMPLES = 10000 648 | 649 | class CommonVoiceConfig(datasets.BuilderConfig): 650 | """BuilderConfig for CommonVoice.""" 651 | 652 | def __init__(self, name, sub_version, **kwargs): 653 | """ 654 | Args: 655 | data_dir: `string`, the path to the folder containing the files in the 656 | downloaded .tar 657 | citation: `string`, citation for the data set 658 | url: `string`, url for information about the data set 659 | **kwargs: keyword arguments forwarded to super. 660 | """ 661 | self.sub_version = sub_version 662 | self.language = kwargs.pop("language", None) 663 | self.date_of_snapshot = kwargs.pop("date", None) 664 | self.size = kwargs.pop("size", None) 665 | self.validated_hr_total = kwargs.pop("val_hrs", None) 666 | self.total_hr_total = kwargs.pop("total_hrs", None) 667 | self.num_of_voice = kwargs.pop("num_of_voice", None) 668 | 669 | self.unk_token_regex = None 670 | if self.language in hg.Languages.get_all(): 671 | # creating regex to match language specific non valid characters 672 | currency_symbols = ["$", "£", "€", "¥", "₩", "₹", "₽", "₱", "₦", "₼", "ლ", "₭", "₴", "₲", "₫", "₡", "₵", "₿", "฿", "¢"] 673 | alphabet = list(hg.Languages.get_alphabet([self.language])) 674 | valid_chars = alphabet + currency_symbols 675 | self.unk_token_regex = "[^"+re.escape("".join(valid_chars))+"\s\d]" 676 | 677 | description = f"Common Voice speech to text dataset in {self.language} version {self.sub_version} of {self.date_of_snapshot}. The dataset comprises {self.validated_hr_total} of validated transcribed speech data from {self.num_of_voice} speakers. The dataset has a size of {self.size}" 678 | super(CommonVoiceConfig, self).__init__( 679 | name=name, version=datasets.Version("6.1.0", ""), description=description, **kwargs 680 | ) 681 | 682 | 683 | class CommonVoice(datasets.GeneratorBasedBuilder): 684 | 685 | BUILDER_CONFIGS = [ 686 | CommonVoiceConfig( 687 | name=lang_id, 688 | language=_LANGUAGES[lang_id]["Language"], 689 | sub_version=_LANGUAGES[lang_id]["Version"], 690 | date=_LANGUAGES[lang_id]["Date"], 691 | size=_LANGUAGES[lang_id]["Size"], 692 | val_hrs=_LANGUAGES[lang_id]["Validated_Hr_Total"], 693 | total_hrs=_LANGUAGES[lang_id]["Overall_Hr_Total"], 694 | num_of_voice=_LANGUAGES[lang_id]["Number_Of_Voice"], 695 | ) 696 | for lang_id in _LANGUAGES.keys() 697 | ] 698 | 699 | def _info(self): 700 | features = datasets.Features( 701 | { 702 | "client_id": datasets.Value("string"), 703 | "path": datasets.Value("string"), 704 | "sentence": datasets.Value("string"), 705 | "up_votes": datasets.Value("int64"), 706 | "down_votes": datasets.Value("int64"), 707 | "age": datasets.Value("string"), 708 | "gender": datasets.Value("string"), 709 | "accent": datasets.Value("string"), 710 | "locale": datasets.Value("string"), 711 | "segment": datasets.Value("string"), 712 | "duration": datasets.Value("float32"), 713 | "dataset": datasets.Value("string"), 714 | } 715 | ) 716 | 717 | return datasets.DatasetInfo( 718 | description=_DESCRIPTION, 719 | features=features, 720 | supervised_keys=None, 721 | homepage=_HOMEPAGE, 722 | license=_LICENSE, 723 | citation=_CITATION, 724 | ) 725 | 726 | def _download_from_gdrive(self, src_url: str, dst_path: str): 727 | """Downloading from Gdrive""" 728 | gdown.download(src_url, dst_path, quiet=False) 729 | 730 | def _split_generators(self, dl_manager): 731 | """Returns SplitGenerators.""" 732 | dl_path = dl_manager.download_and_extract(_DATA_URL.format(self.config.name)) 733 | abs_path_to_data = os.path.join(dl_path, "cv-corpus-6.1-2020-12-11", self.config.name) 734 | abs_path_to_clips = os.path.join(abs_path_to_data, "clips") 735 | 736 | css10_dir = None 737 | if self.config.name in _CSS10_URLS: 738 | css10_url = _CSS10_URLS[self.config.name] 739 | css10_dir = dl_manager.extract(dl_manager.download_custom(css10_url, self._download_from_gdrive)) 740 | 741 | jsut_dir = None 742 | if self.config.name in _JSUT_URLS: 743 | jsut_url = _JSUT_URLS[self.config.name] 744 | jsut_dir = dl_manager.download_and_extract(jsut_url) 745 | jsut_dir = os.path.join(jsut_dir, "jsut_ver1.1") 746 | 747 | nst_metadata_dir = None 748 | nst_files_dir = None 749 | if self.config.name in _NST_URLS: 750 | nst_metadata_dir = dl_manager.download_and_extract(_NST_URLS[self.config.name]["metadata"]) 751 | nst_files_dir = dl_manager.download_and_extract(_NST_URLS[self.config.name]["files"]) 752 | 753 | free_st_dir = None 754 | if self.config.name in _FREE_ST_URLS: 755 | free_st_dir = dl_manager.download_and_extract(_FREE_ST_URLS[self.config.name]) 756 | free_st_dir = os.path.join(free_st_dir, "ST-CMDS-20170001_1-OS") 757 | 758 | arabic_speech_dir = None 759 | if self.config.name in _ARABIC_SPEECH: 760 | arabic_speech_dir = dl_manager.download_and_extract(_ARABIC_SPEECH[self.config.name]) 761 | arabic_speech_dir = os.path.join(arabic_speech_dir, "arabic-speech-corpus") 762 | 763 | timit_dir = None 764 | if self.config.name in _TIMIT: 765 | timit_dir = dl_manager.download_and_extract(_TIMIT[self.config.name]) 766 | 767 | librispeech_dirs = None 768 | if self.config.name in _LIBRISPEECH: 769 | librispeech_dirs = [] 770 | for librispeech_url in _LIBRISPEECH[self.config.name]: 771 | librispeech_dir = dl_manager.download_and_extract(librispeech_url) 772 | librispeech_dirs.append(librispeech_dir) 773 | 774 | return [ 775 | datasets.SplitGenerator( 776 | name=datasets.Split.TRAIN, 777 | gen_kwargs={ 778 | "filepath": os.path.join(abs_path_to_data, "train.tsv"), 779 | "path_to_clips": abs_path_to_clips, 780 | "css10_dir": css10_dir, 781 | "jsut_dir": jsut_dir, 782 | "nst_metadata_dir": nst_metadata_dir, 783 | "nst_files_dir": nst_files_dir, 784 | "free_st_dir": free_st_dir, 785 | "arabic_speech_dir": arabic_speech_dir, 786 | "timit_dir": timit_dir, 787 | "librispeech_dirs": librispeech_dirs, 788 | "max_samples": _MAX_TRAIN_SAMPLES 789 | }, 790 | ), 791 | datasets.SplitGenerator( 792 | name=datasets.Split.VALIDATION, 793 | gen_kwargs={ 794 | "filepath": os.path.join(abs_path_to_data, "dev.tsv"), 795 | "path_to_clips": abs_path_to_clips, 796 | "css10_dir": None, 797 | "jsut_dir": None, 798 | "nst_metadata_dir": None, 799 | "nst_files_dir": None, 800 | "free_st_dir": None, 801 | "arabic_speech_dir": None, 802 | "timit_dir": None, 803 | "librispeech_dirs": None, 804 | "max_samples": _MAX_VAL_SAMPLES 805 | }, 806 | ) 807 | ] 808 | 809 | def _convert_to_flac_and_save_it(self, path, delete_original_file=True): 810 | """We'll convert all the audio files to FLAC format to speedup the loading""" 811 | 812 | sample_path, sample_extension = os.path.splitext(path) 813 | new_path = f"{sample_path}.flac" 814 | 815 | if not os.path.isfile(new_path): 816 | 817 | with warnings.catch_warnings(): 818 | warnings.simplefilter("ignore") 819 | speech_array, sample_rate = librosa.load(path, sr=16_000) 820 | 821 | sf.write(new_path, speech_array, sample_rate) 822 | 823 | if delete_original_file: 824 | os.remove(path) 825 | 826 | return new_path 827 | 828 | def _common_voice_examples_generator(self, filepath, path_to_clips): 829 | 830 | data_fields = list(self._info().features.keys()) 831 | path_idx = data_fields.index("path") 832 | 833 | with open(filepath, encoding="utf-8") as f: 834 | lines = f.readlines() 835 | 836 | for line in lines[1:]: 837 | field_values = line.strip().split("\t") 838 | 839 | # set absolute path for mp3 audio file 840 | field_values[path_idx] = os.path.join(path_to_clips, field_values[path_idx]) 841 | 842 | # if data is incomplete, fill with empty values 843 | if len(field_values) < len(data_fields): 844 | field_values += (len(data_fields) - len(field_values)) * ["''"] 845 | 846 | sample = {key: value for key, value in zip(data_fields, field_values)} 847 | 848 | new_path = self._convert_to_flac_and_save_it(sample.get("path")) 849 | speech_array, sampling_rate = sf.read(new_path) 850 | sample["duration"] = len(speech_array) / sampling_rate 851 | sample["path"] = new_path 852 | sample["dataset"] = "common_voice" 853 | 854 | if self.config.unk_token_regex is not None: 855 | sample["sentence"] = re.sub(self.config.unk_token_regex, "", sample["sentence"]) 856 | 857 | yield sample 858 | 859 | def _css10_examples_generator(self, css10_dir): 860 | 861 | with open(os.path.join(css10_dir, "transcript.txt"), encoding="utf-8") as f: 862 | lines = f.readlines() 863 | 864 | for line in lines: 865 | values = line.strip().split("|") 866 | 867 | audio_path = self._convert_to_flac_and_save_it(os.path.join(css10_dir, values[0])) 868 | text = values[1] if self.config.name in ["ja", "zh"] else values[2] 869 | text = re.sub("\s+", " ", text) # remove multiple spaces 870 | duration = float(values[3]) 871 | 872 | if self.config.unk_token_regex is not None: 873 | text = re.sub(self.config.unk_token_regex, "", text) 874 | 875 | yield { 876 | "client_id": None, 877 | "path": audio_path, 878 | "sentence": text, 879 | "up_votes": 0, 880 | "down_votes": 0, 881 | "age": None, 882 | "gender": None, 883 | "accent": None, 884 | "locale": None, 885 | "segment": None, 886 | "duration": duration, 887 | "dataset": "css10" 888 | } 889 | 890 | def _jsut_examples_generator(self, jsut_dir): 891 | 892 | for subset in os.listdir(jsut_dir): 893 | 894 | if not os.path.isdir(os.path.join(jsut_dir, subset)): 895 | continue 896 | 897 | transcript_path = os.path.join(jsut_dir, subset, "transcript_utf8.txt") 898 | 899 | with open(transcript_path, encoding="utf-8") as f: 900 | 901 | lines = f.readlines() 902 | 903 | for line in lines: 904 | 905 | values = line.split(":") 906 | audio_path = os.path.join(jsut_dir, subset, "wav", f"{values[0]}.wav") 907 | text = values[1] 908 | text = re.sub("\s+", " ", text) # remove multiple spaces 909 | 910 | if self.config.unk_token_regex is not None: 911 | text = re.sub(self.config.unk_token_regex, "", text) 912 | 913 | new_audio_path = self._convert_to_flac_and_save_it(audio_path) 914 | speech_array, sampling_rate = sf.read(new_audio_path) 915 | duration = len(speech_array) / sampling_rate 916 | 917 | yield { 918 | "client_id": None, 919 | "path": new_audio_path, 920 | "sentence": text, 921 | "up_votes": 0, 922 | "down_votes": 0, 923 | "age": None, 924 | "gender": None, 925 | "accent": None, 926 | "locale": None, 927 | "segment": None, 928 | "duration": duration, 929 | "dataset": "jsut" 930 | } 931 | 932 | def _nst_examples_generator(self, nst_metadata_dir, nst_files_dir): 933 | 934 | for metadata_filename in os.listdir(nst_metadata_dir): 935 | 936 | metadata_filepath = os.path.join(nst_metadata_dir, metadata_filename) 937 | 938 | with open(metadata_filepath) as metadata_file: 939 | metadata = json.load(metadata_file) 940 | 941 | client_id = metadata.get("info", {}).get("Speaker_ID", None) 942 | age = metadata.get("info", {}).get("Age", None) 943 | gender = metadata.get("info", {}).get("Sex", None) 944 | lang = metadata.get("metadata").get("lang") 945 | pid = metadata.get("pid") 946 | audio_dir = os.path.join(nst_files_dir, lang, pid) 947 | 948 | for val_recording in metadata.get("val_recordings", []): 949 | 950 | audio_filename = f"{pid}_{val_recording.get('file').replace('.wav', '-1.wav')}" 951 | audio_path = os.path.join(audio_dir, audio_filename) 952 | 953 | # there are some missing files on the original dataset, so we need to handle this 954 | if not os.path.isfile(audio_path): 955 | continue 956 | 957 | text = val_recording.get("text") 958 | text = re.sub("\s+", " ", text) # remove multiple spaces 959 | 960 | if self.config.unk_token_regex is not None: 961 | text = re.sub(self.config.unk_token_regex, "", text) 962 | 963 | new_audio_path = self._convert_to_flac_and_save_it(audio_path) 964 | speech_array, sampling_rate = sf.read(new_audio_path) 965 | duration = len(speech_array) / sampling_rate 966 | 967 | yield { 968 | "client_id": client_id, 969 | "path": new_audio_path, 970 | "sentence": text, 971 | "up_votes": 0, 972 | "down_votes": 0, 973 | "age": age, 974 | "gender": gender, 975 | "accent": None, 976 | "locale": None, 977 | "segment": None, 978 | "duration": duration, 979 | "dataset": "nst" 980 | } 981 | 982 | def _free_st_examples_generator(self, free_st_dir): 983 | 984 | for filename in os.listdir(free_st_dir): 985 | 986 | if filename.endswith(".wav"): 987 | 988 | audio_path = os.path.join(free_st_dir, filename) 989 | text_path = os.path.join(free_st_dir, filename.replace(".wav", ".txt")) 990 | 991 | with open(text_path, "r") as text_file: 992 | text = text_file.read().replace("\n", "").strip() 993 | text = re.sub("\s+", " ", text) # remove multiple spaces 994 | 995 | if self.config.unk_token_regex is not None: 996 | text = re.sub(self.config.unk_token_regex, "", text) 997 | 998 | new_audio_path = self._convert_to_flac_and_save_it(audio_path) 999 | speech_array, sampling_rate = sf.read(new_audio_path) 1000 | duration = len(speech_array) / sampling_rate 1001 | 1002 | yield { 1003 | "client_id": None, 1004 | "path": new_audio_path, 1005 | "sentence": text, 1006 | "up_votes": 0, 1007 | "down_votes": 0, 1008 | "age": None, 1009 | "gender": None, 1010 | "accent": None, 1011 | "locale": None, 1012 | "segment": None, 1013 | "duration": duration, 1014 | "dataset": "free_st" 1015 | } 1016 | 1017 | def _arabic_speech_examples_generator(self, arabic_speech_dir): 1018 | 1019 | with open(os.path.join(arabic_speech_dir, "orthographic-transcript.txt"), encoding="utf-8") as f: 1020 | 1021 | lines = f.readlines() 1022 | 1023 | for line in lines: 1024 | 1025 | values = line.split('" "') 1026 | filename = values[0].strip()[1:] 1027 | text = values[1].strip()[:-1] 1028 | audio_path = os.path.join(arabic_speech_dir, "wav", filename) 1029 | 1030 | # converting buckwalter format to arabic letters 1031 | text = buckwalter.untransliterate(text) 1032 | text = re.sub("\s+", " ", text) # remove multiple spaces 1033 | 1034 | if self.config.unk_token_regex is not None: 1035 | text = re.sub(self.config.unk_token_regex, "", text) 1036 | 1037 | new_audio_path = self._convert_to_flac_and_save_it(audio_path) 1038 | speech_array, sampling_rate = sf.read(new_audio_path) 1039 | duration = len(speech_array) / sampling_rate 1040 | 1041 | yield { 1042 | "client_id": None, 1043 | "path": new_audio_path, 1044 | "sentence": text, 1045 | "up_votes": 0, 1046 | "down_votes": 0, 1047 | "age": None, 1048 | "gender": None, 1049 | "accent": None, 1050 | "locale": None, 1051 | "segment": None, 1052 | "duration": duration, 1053 | "dataset": "arabic_speech" 1054 | } 1055 | 1056 | def _timit_examples_generator(self, timit_dir): 1057 | 1058 | data_info_csv = os.path.join(timit_dir, "train_data.csv") 1059 | 1060 | """Generate examples from TIMIT archive_path based on the test/train csv information.""" 1061 | # Extract the archive path 1062 | data_path = os.path.join(os.path.dirname(data_info_csv).strip(), "data") 1063 | 1064 | # Read the data info to extract rows mentioning about non-converted audio only 1065 | data_info = pd.read_csv(data_info_csv, encoding="utf8") 1066 | # making sure that the columns having no information about the file paths are removed 1067 | data_info.dropna(subset=["path_from_data_dir"], inplace=True) 1068 | 1069 | # filter out only the required information for data preparation 1070 | data_info = data_info.loc[(data_info["is_audio"]) & (~data_info["is_converted_audio"])] 1071 | 1072 | # Iterating the contents of the data to extract the relevant information 1073 | for audio_idx in range(data_info.shape[0]): 1074 | audio_data = data_info.iloc[audio_idx] 1075 | 1076 | # extract the path to audio 1077 | wav_path = os.path.join(data_path, *(audio_data["path_from_data_dir"].split("/"))) 1078 | 1079 | # extract transcript 1080 | with open(wav_path.replace(".WAV", ".TXT"), "r", encoding="utf-8") as op: 1081 | transcript = " ".join(op.readlines()[0].split()[2:]) # first two items are sample number 1082 | 1083 | new_audio_path = self._convert_to_flac_and_save_it(wav_path) 1084 | speech_array, sampling_rate = sf.read(new_audio_path) 1085 | duration = len(speech_array) / sampling_rate 1086 | 1087 | yield { 1088 | "client_id": str(audio_data["speaker_id"]), 1089 | "path": new_audio_path, 1090 | "sentence": transcript, 1091 | "up_votes": 0, 1092 | "down_votes": 0, 1093 | "age": None, 1094 | "gender": None, 1095 | "accent": audio_data["dialect_region"], 1096 | "locale": None, 1097 | "segment": None, 1098 | "duration": duration, 1099 | "dataset": "timit" 1100 | } 1101 | 1102 | def _librispeech_examples_generator(self, librispeech_dir): 1103 | 1104 | transcripts_glob = os.path.join(librispeech_dir, "LibriSpeech", "*/*/*/*.txt") 1105 | for transcript_file in sorted(glob.glob(transcripts_glob)): 1106 | path = os.path.dirname(transcript_file) 1107 | # with open(os.path.join(path, transcript_file), "r", encoding="utf-8") as f: 1108 | with open(transcript_file, "r", encoding="utf-8") as f: 1109 | for line in f: 1110 | line = line.strip() 1111 | key, transcript = line.split(" ", 1) 1112 | audio_file = f"{key}.flac" 1113 | audio_file = os.path.join(path, audio_file) 1114 | speaker_id, chapter_id = [int(el) for el in key.split("-")[:2]] 1115 | 1116 | speech_array, sampling_rate = sf.read(audio_file) 1117 | duration = len(speech_array) / sampling_rate 1118 | 1119 | yield { 1120 | "client_id": str(speaker_id), 1121 | "path": audio_file, 1122 | "sentence": transcript, 1123 | "up_votes": 0, 1124 | "down_votes": 0, 1125 | "age": None, 1126 | "gender": None, 1127 | "accent": None, 1128 | "locale": None, 1129 | "segment": None, 1130 | "duration": duration, 1131 | "dataset": "librispeech" 1132 | } 1133 | 1134 | def _generate_examples(self, filepath, path_to_clips, css10_dir, jsut_dir, nst_metadata_dir, 1135 | nst_files_dir, free_st_dir, arabic_speech_dir, timit_dir, 1136 | librispeech_dirs, max_samples): 1137 | """ Yields examples. """ 1138 | _id = 0 1139 | 1140 | for example in self._common_voice_examples_generator(filepath, path_to_clips): 1141 | if _id == max_samples: 1142 | break 1143 | yield _id, example 1144 | _id += 1 1145 | 1146 | if timit_dir is not None and _id < max_samples: 1147 | for example in self._timit_examples_generator(timit_dir): 1148 | if _id < max_samples: 1149 | yield _id, example 1150 | _id += 1 1151 | else: 1152 | break 1153 | 1154 | if css10_dir is not None and _id < max_samples: 1155 | for example in self._css10_examples_generator(css10_dir): 1156 | if _id < max_samples: 1157 | yield _id, example 1158 | _id += 1 1159 | else: 1160 | break 1161 | 1162 | if librispeech_dirs is not None and _id < max_samples: 1163 | for librispeech_dir in librispeech_dirs: 1164 | for example in self._librispeech_examples_generator(librispeech_dir): 1165 | if _id < max_samples: 1166 | yield _id, example 1167 | _id += 1 1168 | else: 1169 | break 1170 | 1171 | if jsut_dir is not None and _id < max_samples: 1172 | for example in self._jsut_examples_generator(jsut_dir): 1173 | if _id < max_samples: 1174 | yield _id, example 1175 | _id += 1 1176 | else: 1177 | break 1178 | 1179 | if nst_files_dir is not None and _id < max_samples: 1180 | for example in self._nst_examples_generator(nst_metadata_dir, nst_files_dir): 1181 | if _id < max_samples: 1182 | yield _id, example 1183 | _id += 1 1184 | else: 1185 | break 1186 | 1187 | if free_st_dir is not None and _id < max_samples: 1188 | for example in self._free_st_examples_generator(free_st_dir): 1189 | if _id < max_samples: 1190 | yield _id, example 1191 | _id += 1 1192 | else: 1193 | break 1194 | 1195 | if arabic_speech_dir is not None and _id < max_samples: 1196 | root_dirs = [arabic_speech_dir, os.path.join(arabic_speech_dir, "test set")] 1197 | for root_dir in root_dirs: 1198 | for example in self._arabic_speech_examples_generator(root_dir): 1199 | if _id < max_samples: 1200 | yield _id, example 1201 | _id += 1 1202 | else: 1203 | break 1204 | --------------------------------------------------------------------------------