├── run_all.sh
├── playground
│   ├── spellcheck_tests_3.py
│   ├── spellcheck_tests.py
│   ├── augmentation_tests.py
│   └── spellcheck_tests_2.py
├── requirements.txt
├── finetune.sh
├── LICENSE
├── finetune_with_params.sh
├── generate_all_trainings.py
├── common_voice_usage.py
├── wav2vec_languages.csv
├── sweep.yaml
├── Dockerfile
├── common_voice_eval.py
├── home-server.html
├── .gitignore
├── README.md
├── cer.py
├── wer.py
├── MODEL_CARD.md
├── run_common_voice.py
└── dataset_ext.py
/run_all.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | supervisord -n -u 42420 -c /etc/supervisor/supervisor.conf
3 |
--------------------------------------------------------------------------------
/playground/spellcheck_tests_3.py:
--------------------------------------------------------------------------------
1 | from autocorrect import Speller
2 |
3 | spell = Speller('pl')
4 | print(spell('ptaaki latatją kluczmm'))
5 |
6 | spell = Speller('fr')
7 | print(spell("CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".lower()))
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | torchaudio==0.8.1
3 | datasets==1.5.0
4 | jiwer==2.2.0
5 | soundfile==0.10.3.post1
6 | lang-trans==0.6.0
7 | librosa==0.8.0
8 | samplerate==0.1.0
9 | git+https://github.com/huggingface/transformers.git
10 | scipy==1.5.4
11 | audiomentations==0.16.0
12 | torch-audiomentations==0.6.0
13 | pyloudnorm==0.1.0
14 | wandb==0.10.23
15 | homoglyphs==2.0.4
16 | pandas==1.2.3
17 | matplotlib==3.3.4
18 | gdown==3.12.2
--------------------------------------------------------------------------------
/playground/spellcheck_tests.py:
--------------------------------------------------------------------------------
1 | from spellchecker import SpellChecker
2 |
3 | # load the built-in French dictionary
4 | spell = SpellChecker(language="fr")
5 |
6 | # CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ACHÉMÉNIDE ET SEPT DES SASSANIDES.
7 | words = "CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".split()
8 |
9 | for word in words:
10 | word = word.lower()
11 | if word in spell:
12 | print("'{}' is spelled correctly!".format(word))
13 | else:
14 | cor = spell.correction(word)
15 | print("The best spelling for '{}' is '{}'".format(word, cor))
16 |
17 | print("If that is not enough; here are all possible candidate words:")
18 | print(spell.candidates(word))
--------------------------------------------------------------------------------
/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python run_common_voice.py \
3 | --model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
4 | --dataset_config_name="tr" \
5 | --output_dir=/workspace/container_0/wav2vec2-large-xlsr-turkish-demo \
6 | --cache_dir=/workspace/container_0 \
7 | --overwrite_output_dir \
8 | --num_train_epochs="1" \
9 | --per_device_train_batch_size="32" \
10 | --per_device_eval_batch_size="32" \
11 | --evaluation_strategy="steps" \
12 | --learning_rate="3e-4" \
13 | --warmup_steps="500" \
14 | --fp16 \
15 | --freeze_feature_extractor \
16 | --save_steps="10" \
17 | --eval_steps="10" \
18 | --save_total_limit="1" \
19 | --logging_steps="10" \
20 | --group_by_length \
21 | --feat_proj_dropout="0.0" \
22 | --layerdrop="0.1" \
23 | --gradient_checkpointing \
24 | --do_train --do_eval \
25 | --max_train_samples 100 --max_val_samples 100
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jonatas Grosman
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/finetune_with_params.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | python /workspace/wav2vec/run_common_voice.py \
5 | --model_name_or_path=$model_name_or_path \
6 | --dataset_config_name=$dataset_config_name \
7 | --output_dir=$output_dir \
8 | --cache_dir=$cache_dir \
9 | --overwrite_output_dir \
10 | --num_train_epochs=$num_train_epochs \
11 | --per_device_train_batch_size=$per_device_train_batch_size \
12 | --evaluation_strategy=$evaluation_strategy \
13 | --learning_rate=$learning_rate \
14 | --warmup_steps=$warmup_steps \
15 | --fp16 \
16 | --freeze_feature_extractor \
17 | --save_steps=$save_steps \
18 | --eval_steps=$eval_steps \
19 | --save_total_limit=$save_total_limit \
20 | --logging_steps=$logging_steps \
21 | --group_by_length \
22 | --feat_proj_dropout=$feat_proj_dropout \
23 | --layerdrop=$layerdrop \
24 | --gradient_checkpointing \
25 | --do_train \
26 | --do_eval \
27 | --max_train_samples $max_train_samples \
28 | --max_val_samples $max_val_samples \
29 | --report_to $report_to \
30 | --run_name $run_name \
31 | --augmentation_factor $augmentation_factor
32 |
33 |
--------------------------------------------------------------------------------
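
finetune_with_params.sh above reads all of its hyperparameters from environment variables; the Dockerfile further below sets defaults for most of them, and generate_all_trainings.py passes them with `-e` flags. A minimal sketch, not part of the repo and with purely illustrative values, of invoking the script from Python with an explicit environment:

```python
# Sketch only: launch the parametrized fine-tuning script with an explicit
# environment. The variable names mirror finetune_with_params.sh and the
# Dockerfile ENV block; the values below are illustrative, not recommendations.
import os
import subprocess

env = dict(
    os.environ,
    model_name_or_path="facebook/wav2vec2-large-xlsr-53",
    dataset_config_name="tr",
    output_dir="/workspace/output_models/wav2vec2-large-xlsr-turkish-demo",
    cache_dir="/workspace/data",
    num_train_epochs="1",
    per_device_train_batch_size="32",
    evaluation_strategy="steps",
    learning_rate="3e-4",
    warmup_steps="500",
    save_steps="10",
    eval_steps="10",
    save_total_limit="1",
    logging_steps="10",
    feat_proj_dropout="0.0",
    layerdrop="0.1",
    max_train_samples="100",
    max_val_samples="100",
    report_to="none",
    run_name="local-test",
    augmentation_factor="0",
)
subprocess.run(["sh", "/workspace/wav2vec/finetune_with_params.sh"], env=env, check=True)
```
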
/generate_all_trainings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[8]:
5 |
6 |
7 | import os
8 | import csv
9 |
10 |
11 | # In[20]:
12 |
13 |
14 | with open('wav2vec_languages.csv') as csv_file:
15 |     csv_reader = csv.reader(csv_file, delimiter=';')
16 | # This skips the first row of the CSV file because it's a header
17 | next(csv_reader)
18 |     for (language_code, language_full_name, *_rest) in csv_reader:
19 | print(f"#Launching Training for {language_code}-{language_full_name}")
20 | cmd = f"ovhai job run --gpu 1 --name '{language_code}-{language_full_name}' --volume output_models@GRA/{language_code}:/workspace/output_models:RW:cache -e model_name_or_path='facebook/wav2vec2-large-xlsr-53' -e dataset_config_name={language_code} -e output_dir='/workspace/output_models/wav2vec2-large-xlsr-{language_code}-{language_full_name}-demo' -e cache_dir='/workspace/data' -e num_train_epochs=10 databuzzword/hf-wav2vec -- sh /workspace/wav2vec/finetune_with_params.sh"
21 | print(cmd)
22 | stream = os.popen(cmd)
23 | output = stream.read()
24 |         print(output)
25 |
--------------------------------------------------------------------------------
/common_voice_usage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import librosa
3 | import warnings
4 | from datasets import load_dataset
5 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6 |
7 | LANG_ID = "pt"
8 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"
9 | SAMPLES = 10
10 |
11 | test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")
12 |
13 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
14 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
15 |
16 | # Preprocessing the datasets.
17 | # We need to read the audio files as arrays
18 | def speech_file_to_array_fn(batch):
19 | with warnings.catch_warnings():
20 | warnings.simplefilter("ignore")
21 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
22 | batch["speech"] = speech_array
23 | batch["sentence"] = batch["sentence"].upper()
24 | return batch
25 |
26 | test_dataset = test_dataset.map(speech_file_to_array_fn)
27 | inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
28 |
29 | with torch.no_grad():
30 | logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
31 |
32 | predicted_ids = torch.argmax(logits, dim=-1)
33 | predicted_sentences = processor.batch_decode(predicted_ids)
34 |
35 | for i, predicted_sentence in enumerate(predicted_sentences):
36 | print("-" * 100)
37 | print("Reference:", test_dataset[i]["sentence"])
38 | print("Prediction:", predicted_sentence)
39 |
--------------------------------------------------------------------------------
/wav2vec_languages.csv:
--------------------------------------------------------------------------------
1 | language_code;language_full_name;valid_hours_size;has_homoglyphs;
2 | ab;abkhazian;0.05;False;
3 | vi;vietnamese;0.74;True;x
4 | as;assamese;0.74;False;
5 | br;breton;7;False;
6 | en;english;1686;True;
7 | cnh;cnh;2;False;
8 | cs;czech;36;False;
9 | cv;chuvash;4;False;
10 | cy;welsh;95;False;
11 | de;german;777;True;
12 | dv;divehi;18;False;
13 | ca;catalan;623;True;
14 | fr;french;623;True;
15 | es;spanish;324;True;
16 | it;italian;158;True;
17 | ru;russian;111;True;
18 | eu;basque;89;False;
19 | fa;persian;282;False;
20 | pl;polish;108;True;
21 | lv;latvian;99;True;
22 | fy-NL;western_frisian-netherlands;14;False;
23 | ga-IE;irish-ireland;3;False;
24 | hi;hindi;0.54;False;
25 | hsb;upper_sorbian;2;False;
26 | eo;esperanto;90;True;
27 | ia;interlingua;6;False;
28 | id;indonesian;9;False;
29 | nl;dutch;59;True;
30 | ja;japanese;3;False;
31 | ka;georgian;3;False;
32 | kab;kabyle;525;False;
33 | ky;kyrgyz;11;False;
34 | lg;ganda;3;False;
35 | pt;portuguese;50;True;
36 | ar;arabic;49;True;
37 | mn;mongolian;11;False;
38 | mt;maltese;7;False;
39 | tr;turkish;20;True;
40 | or;odia;0.87;False;
41 | pa-IN;punjabi-india;0.5;False;
42 | et;estonian;19;True;
43 | hu;hungarian;8;True;x
44 | rm-sursilv;romansh_sursilv;5;False;
45 | rm-vallader;romansh_vallader;2;False;
46 | th;thai;8;True;x
47 | el;greek;6;True;
48 | rw;kinyarwanda;1183;False;
49 | sah;sakha;4;False;
50 | ro;romanian;6;True;x
51 | sv-SE;swedish-sweden;12;False;
52 | ta;tamil;14;False;
53 | sl;slovenian;5;True;x
54 | lt;lithuanian;2;True;x
55 | tt;tatar;26;False;
56 | uk;ukrainian;30;False;
57 | fi;finnish;1;True;x
58 | vot;votic;0;False;
59 | zh-CN;chinese-china;56;False;
60 | zh-HK;chinese-hong_kong_sar_china;50;False;
61 | zh-TW;chinese-taiwan;55;False;
--------------------------------------------------------------------------------
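
wav2vec_languages.csv above is semicolon-delimited, with a trailing separator on every row; generate_all_trainings.py therefore reads it with `delimiter=';'` and ignores the extra fields. A minimal sketch, not part of the repo, of parsing it with the standard library:

```python
# Sketch only: iterate over the language list shipped with the repo.
import csv

with open("wav2vec_languages.csv", newline="") as csv_file:
    for row in csv.DictReader(csv_file, delimiter=";"):
        # The trailing ";" on each row only adds an unused empty column.
        print(row["language_code"], row["language_full_name"],
              row["valid_hours_size"], row["has_homoglyphs"])
```
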
/sweep.yaml:
--------------------------------------------------------------------------------
1 | program: run_common_voice.py
2 | name: hf-wav2vec-sprint-fi
3 | method: random
4 | metric:
5 | goal: minimize
6 | name: eval/loss
7 | parameters:
8 | seed:
9 | value: 42
10 | report_to:
11 | value: wandb
12 | model_name_or_path:
13 | value: facebook/wav2vec2-large-xlsr-53
14 | dataset_config_name:
15 | value: fi
16 | output_dir:
17 | value: ../models/fi/wav2vec2-large-xlsr-fi-sweep
18 | cache_dir:
19 | value: ../data/fi
20 | overwrite_output_dir:
21 | value: True
22 | fp16:
23 | value: True
24 | max_steps:
25 | value: 500
26 | eval_steps:
27 | value: 100
28 | logging_steps:
29 | value: 100
30 | do_eval:
31 | value: True
32 | do_train:
33 | value: True
34 | per_device_train_batch_size:
35 | value: 16
36 | per_device_eval_batch_size:
37 | value: 16
38 | dataloader_num_workers:
39 | value: 10
40 | preprocessing_num_workers:
41 | value: 10
42 | load_best_model_at_end:
43 | value: True
44 | save_total_limit:
45 | value: 1
46 | evaluation_strategy:
47 | value: steps
48 | freeze_feature_extractor:
49 | value: True
50 | group_by_length:
51 | value: True
52 | min_duration:
53 | value: 2.0
54 | max_duration:
55 | value: 9.0
56 | lr_warmup_ratio:
57 | value: 0.5
58 | lr_constant_ratio:
59 | value: 0.0
60 | augmentation_factor:
61 | values: [0, 1]
62 | layerdrop:
63 | value: 0.0
64 | learning_rate:
65 | values: [1e-4, 3e-4, 6e-4, 1e-3]
66 | attention_dropout:
67 | values: [0.05, 0.1, 0.2]
68 | activation_dropout:
69 | values: [0.05, 0.1, 0.2]
70 | hidden_dropout:
71 | values: [0.05, 0.1, 0.2]
72 | feat_proj_dropout:
73 | values: [0.05, 0.1, 0.2]
74 | mask_time_prob:
75 | values: [0.05, 0.1, 0.2]
76 | early_terminate:
77 | type: hyperband
78 | min_iter: 200
79 |
--------------------------------------------------------------------------------
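
sweep.yaml above is meant to be registered with `wandb sweep` and executed with `wandb agent`, as described in the README. A minimal sketch, not part of the repo, of doing the same from Python with the wandb API; it assumes PyYAML is installed and the WANDB_* environment variables are already exported:

```python
# Sketch only: create and run the sweep from Python instead of the wandb CLI.
import yaml
import wandb

with open("sweep.yaml") as f:
    sweep_config = yaml.safe_load(f)

sweep_id = wandb.sweep(sweep_config, project="YOUR_WANDB_PROJECT")
wandb.agent(sweep_id)  # equivalent to `wandb agent <sweep_id>`, runs `program: run_common_voice.py`
```
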
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ovhcom/ai-training-one-for-all
2 |
3 | RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
4 |
5 | RUN apt-get update && \
6 |     apt-get install -y bash \
7 | build-essential \
8 | libsndfile1-dev \
9 | git-lfs \
10 | ffmpeg \
11 | sox \
12 | libsox-fmt-mp3
13 |
14 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
15 |     apt-get install -y git-lfs && \
16 | git lfs install
17 |
18 | RUN python3 -m pip install --no-cache-dir --upgrade pip && \
19 | python3 -m pip install --no-cache-dir \
20 | torch==1.8.1 \
21 | torchaudio==0.8.1 \
22 | datasets==1.5.0 \
23 | jiwer==2.2.0 \
24 | soundfile==0.10.3.post1 \
25 | lang-trans==0.6.0 \
26 | librosa==0.8.0 \
27 | samplerate==0.1.0 \
28 | scipy==1.5.4 \
29 | audiomentations==0.16.0 \
30 | torch-audiomentations==0.6.0 \
31 | pyloudnorm==0.1.0 \
32 | wandb==0.10.23 \
33 | homoglyphs==2.0.4 \
34 | gdown
35 |
36 | RUN pip3 uninstall -y typing allennlp
37 |
38 | RUN pip3 install git+https://github.com/huggingface/transformers.git
39 |
40 | RUN mkdir -p /workspace/wav2vec/
41 |
42 | COPY finetune.sh finetune_with_params.sh run_common_voice.py dataset_ext.py cer.py wer.py common_voice_eval.py common_voice_usage.py /workspace/wav2vec/
43 |
44 | COPY home-server.html run_all.sh /usr/bin/
45 |
46 | RUN chown -R 42420:42420 /workspace
47 |
48 | RUN chown -R 42420:42420 /usr/bin/run_all.sh
49 |
50 | #Default training env variables
51 | ENV model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
52 | dataset_config_name="fr" \
53 | output_dir="/workspace/output_models/wav2vec2-large-xlsr-french-demo" \
54 | cache_dir="/workspace/data" \
55 | num_train_epochs="1" \
56 | per_device_train_batch_size="32" \
57 | evaluation_strategy="steps" \
58 | learning_rate="3e-4" \
59 | warmup_steps="500" \
60 | save_steps="10" \
61 | eval_steps="10" \
62 | save_total_limit="1" \
63 | logging_steps="10" \
64 | feat_proj_dropout="0.0" \
65 | layerdrop="0.1" \
66 | max_train_samples=100 \
67 | max_val_samples=100
68 |
69 | WORKDIR /workspace
70 | ENTRYPOINT []
71 | #CMD ["sh", "/usr/bin/run_all.sh"]
72 | CMD ["supervisord", "-n", "-u", "42420", "-c", "/etc/supervisor/supervisor.conf"]
73 |
--------------------------------------------------------------------------------
/common_voice_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import librosa
3 | import re
4 | import warnings
5 | from datasets import load_dataset, load_metric
6 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
7 |
8 | LANG_ID = "pt"
9 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"
10 | DEVICE = "cuda"
11 |
12 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
13 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
14 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
15 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
16 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
17 |
18 | test_dataset = load_dataset("common_voice", LANG_ID, split="test")
19 |
20 | wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
21 | cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py
22 |
23 | chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
24 |
25 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
26 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
27 | model.to(DEVICE)
28 |
29 | # Preprocessing the datasets.
30 | # We need to read the audio files as arrays
31 | def speech_file_to_array_fn(batch):
32 | with warnings.catch_warnings():
33 | warnings.simplefilter("ignore")
34 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
35 | batch["speech"] = speech_array
36 | batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
37 | return batch
38 |
39 | test_dataset = test_dataset.map(speech_file_to_array_fn)
40 |
41 | # Running inference on the preprocessed dataset.
42 | # We batch the speech arrays and collect the predicted transcriptions
43 | def evaluate(batch):
44 | inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
45 |
46 | with torch.no_grad():
47 | logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits
48 |
49 | pred_ids = torch.argmax(logits, dim=-1)
50 | batch["pred_strings"] = processor.batch_decode(pred_ids)
51 | return batch
52 |
53 | result = test_dataset.map(evaluate, batched=True, batch_size=8)
54 |
55 | predictions = [x.upper() for x in result["pred_strings"]]
56 | references = [x.upper() for x in result["sentence"]]
57 |
58 | print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
59 | print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
60 |
--------------------------------------------------------------------------------
/home-server.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | OVH AI Training Job
4 |
5 |
6 |
7 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/playground/augmentation_tests.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import datasets
4 | import torch
5 | import torchaudio
6 | import soundfile as sf
7 | import librosa
8 |
9 | from audiomentations import (
10 | Compose,
11 | AddGaussianNoise,
12 | AddGaussianSNR,
13 | ClippingDistortion,
14 | FrequencyMask,
15 | Gain,
16 | LoudnessNormalization,
17 | Normalize,
18 | PitchShift,
19 | PolarityInversion,
20 | Shift,
21 | TimeMask,
22 | TimeStretch,
23 | )
24 |
25 | os.makedirs("_ignore_data", exist_ok=True)
26 |
27 | # creating augmentation pipeline
28 |
29 | augmentator = Compose([
30 | AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.5),
31 | Gain(min_gain_in_db=-1, max_gain_in_db=1, p=0.5),
32 | PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
33 | # TimeStretch(min_rate=0.7, max_rate=1.3, leave_length_unchanged=False, p=0.5),
34 | # FrequencyMask(min_frequency_band=0.0, max_frequency_band=0.5, p=0.5),
35 | # TimeMask(min_band_part=0.0, max_band_part=0.01, fade=True, p=0.5),
36 | # ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=5, p=0.5),
37 | # LoudnessNormalization(min_lufs_in_db=-31, max_lufs_in_db=-13, p=0.5)
38 | # PolarityInversion(p=0.5),
39 | # Shift(min_fraction=-0.01, max_fraction=0.01, rollover=False, p=0.5),
40 | # AddGaussianSNR(min_SNR=0.001, max_SNR=1.0, p=0.5),
41 | # Normalize(p=0.5),
42 | ])
43 |
44 | # Get the dataset
45 |
46 | # train_dataset = datasets.load_dataset(
47 | # "common_voice_ext.py", "fi", augmentation_factor=2, split="train+validation", cache_dir="_ignore_data/cache/fi"
48 | # )
49 |
50 | train_dataset = datasets.load_dataset(
51 | "common_voice", "pt", split="train+validation", cache_dir="_ignore_data/cache/pt"
52 | )
53 |
54 | # Getting one sample per gender
55 |
56 | male_sample = None
57 | female_sample = None
58 |
59 | for sample in train_dataset:
60 |
61 | if sample.get("gender") == "male":
62 | male_sample = sample
63 | elif sample.get("gender") == "female":
64 | female_sample = sample
65 |
66 | if male_sample is not None and female_sample is not None:
67 | break
68 |
69 | # Augmenting data and saving it
70 |
71 | speech_array, sample_rate = librosa.load(male_sample["path"], sr = 16000, res_type="zero_order_hold")
72 | speech_array_augmented = augmentator(samples=speech_array, sample_rate=sample_rate)
73 | sf.write("_ignore_data/male_original.wav", speech_array, sample_rate, subtype="PCM_24")
74 | sf.write("_ignore_data/male_augmented.flac", speech_array_augmented, sample_rate)
75 |
76 | speech_array, sample_rate = librosa.load(female_sample["path"], sr = 16000, res_type="zero_order_hold")
77 | speech_array_augmented = augmentator(samples=speech_array, sample_rate=sample_rate)
78 | sf.write("_ignore_data/female_original.wav", speech_array, sample_rate, subtype="PCM_24")
79 | sf.write("_ignore_data/female_augmented.wav", speech_array_augmented, sample_rate, subtype="PCM_24")
80 |
81 | print(":)")
82 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # SO
132 | .DS_Store
133 |
134 | # IDE
135 | .vscode
136 |
137 | # ignorable folder/files pattern
138 | _ignore*
139 |
140 | # lock files
141 | *.lock
142 |
143 | # wandB files
144 | wandb
145 |
146 | # training files
147 | all_results.json
148 | config.json
149 | eval_results.json
150 | preprocessor_config.json
151 | pytorch_model.bin
152 | special_tokens_map.json
153 | tokenizer_config.json
154 | train_results.json
155 | trainer_state.json
156 | training_args.bin
157 | vocab.json
158 | *.zip
159 |
--------------------------------------------------------------------------------
/playground/spellcheck_tests_2.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import contextualSpellCheck
3 |
4 | # ENGLISH
5 |
6 | # python -m spacy download en_core_web_sm
7 | # nlp = spacy.load("en_core_web_sm")
8 |
9 | # nlp.add_pipe("contextual spellchecker", config={"max_edit_dist": 100})
10 |
11 | # doc = nlp("Icome was $9.4 milion compared to the last yiear of $2.7 milion.")
12 | # print(doc._.performed_spellCheck)
13 | # print(doc._.outcome_spellCheck)
14 |
15 | # # Doc Extention
16 | # print(doc._.contextual_spellCheck)
17 |
18 | # print(doc._.performed_spellCheck)
19 |
20 | # print(doc._.suggestions_spellCheck)
21 |
22 | # print(doc._.outcome_spellCheck)
23 |
24 | # print(doc._.score_spellCheck)
25 |
26 | # # Token Extention
27 | # print(doc[4]._.get_require_spellCheck)
28 |
29 | # print(doc[4]._.get_suggestion_spellCheck)
30 |
31 | # print(doc[4]._.score_spellCheck)
32 |
33 | # # Span Extention
34 | # print(doc[2:6]._.get_has_spellCheck)
35 |
36 | # print(doc[2:6]._.score_spellCheck)
37 |
38 | # JAPANESE
39 |
40 | # python -m spacy download ja_core_news_sm
41 | # pip install mecab-python3==0.996.5
42 | # pip install ipadic==1.0.0
43 | # pip install unidic-lite==1.0.6
44 | # pip install fugashi==1.1.0
45 | # nlp = spacy.load("ja_core_news_sm")
46 |
47 | # nlp.add_pipe(
48 | # "contextual spellchecker",
49 | # config={
50 | # "model_name": "cl-tohoku/bert-base-japanese-whole-word-masking",
51 | # "max_edit_dist": 2,
52 | # },
53 | # )
54 |
55 | # doc = nlp("しかし大勢においては、ここような事故はウィキペディアの拡大には影響を及ぼしていない。")
56 | # print(doc._.performed_spellCheck)
57 | # print(doc._.outcome_spellCheck)
58 |
59 | # PORTUGUESE
60 |
61 | # python -m spacy download pt_core_news_lg
62 | # nlp = spacy.load("pt_core_news_lg")
63 |
64 | # nlp.add_pipe(
65 | # "contextual spellchecker",
66 | # config={
67 | # "model_name": "neuralmind/bert-large-portuguese-cased",
68 | # "max_edit_dist": 5,
69 | # },
70 | # )
71 |
72 | ## PEDIR DINHEIRO EMPRESTADO ÀS PESSOAS DA ALDEIA
73 | # doc = nlp("EDIR DINHEIRO EMPRESTADO ÀS PESSOAS DO ALDEIRA".lower())
74 | # print(doc._.performed_spellCheck)
75 | # print(doc._.outcome_spellCheck)
76 |
77 |
78 | # SPANISH
79 |
80 | # # python -m spacy download es_dep_news_trf
81 | # nlp = spacy.load("es_dep_news_trf")
82 |
83 | # nlp.add_pipe(
84 | # "contextual spellchecker",
85 | # config={
86 | # "model_name": "Geotrend/bert-base-es-cased",
87 | # "max_edit_dist": 5,
88 | # },
89 | # )
90 |
91 | ## HABITAN EN AGUAS POCO PROFUNDAS Y ROCOSAS
92 | # doc = nlp("HABITAN AGUAS POCO PROFUNDAS Y ROCOSAS".lower())
93 | # print(doc._.performed_spellCheck)
94 | # print(doc._.outcome_spellCheck)
95 |
96 |
97 | # FRENCH
98 |
99 | # python -m spacy download fr_core_news_sm
100 | nlp = spacy.load("fr_core_news_sm")
101 |
102 | nlp.add_pipe(
103 | "contextual spellchecker",
104 | config={
105 | "model_name": "camembert-base",
106 | "max_edit_dist": 5,
107 | },
108 | )
109 |
110 | # CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ACHÉMÉNIDE ET SEPT DES SASSANIDES.
111 | doc = nlp("CE SITE CONTIENT QUATRE TOMBEAUX DE LA DYNASTIE ASHÉMÉNIDE ET SEPT DES SASANNIDES".lower())
112 | doc = nlp("CE SITE CONTIENT QUATRE TOMBEAUX".lower())
113 | print(doc._.performed_spellCheck)
114 | print(doc._.outcome_spellCheck)
115 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repo is deprecated in favor of https://github.com/jonatasgrosman/huggingsound
2 |
3 | # Wav2Vec Trainer
4 |
5 | This repository is based on https://github.com/jqueguiner/wav2vec2-sprint
6 |
7 | ## Building docker image
8 |
9 | Dockerhub available at https://hub.docker.com/r/patilsuraj/hf-wav2vec
10 |
11 | To build the Docker image:
12 |
13 | ```
14 | $ docker build -t hf-wav2vec-sprint -f Dockerfile .
15 | ```
16 |
17 | To tag the image for Docker Hub
18 | (first create a repository on Docker Hub):
19 | ```
20 | $ docker tag hf-wav2vec-sprint your-dockerhub-user/hf-wav2vec-sprint
21 | ```
22 |
23 | To push it to Docker Hub:
24 |
25 | ```
26 | $ docker push your-dockerhub-user/hf-wav2vec-sprint
27 | ```
28 |
29 | ## Running WandB sweep
30 |
31 | Initialize your sweep from any machine...
32 |
33 | ```
34 | $ export WANDB_API_KEY=YOUR_WANDB_API_KEY
35 | $ export WANDB_ENTITY=YOUR_WANDB_ENTITY
36 | $ export WANDB_PROJECT=YOUR_WANDB_PROJECT
37 |
38 | $ wandb sweep sweep.yaml
39 | ```
40 | ... the command above will give you a sweep id; save it, and on the training machine run:
41 |
42 | ```
43 | $ export WANDB_API_KEY=YOUR_WANDB_API_KEY
44 | $ export WANDB_ENTITY=YOUR_WANDB_ENTITY
45 | $ export WANDB_PROJECT=YOUR_WANDB_PROJECT
46 |
47 | $ wandb agent YOUR_SWEEP_ID
48 | ```
49 |
50 | ## Uploading model to HF
51 |
52 | You need to upload the following files to the HF repository:
53 |
54 | - preprocessor_config.json
55 | - special_tokens_map.json
56 | - tokenizer_config.json
57 | - vocab.json
58 | - config.json
59 | - pytorch_model.bin
60 | - README.md (create this file based on the MODEL_CARD.md)
61 |
62 | ```
63 | $ git config --global user.email "email@example.com"
64 |
65 | $ git config --global user.name "Your name"
66 |
67 | $ transformers-cli login
68 |
69 | $ transformers-cli repo create your-model-name
70 |
71 | $ git clone https://username:password_or_token@huggingface.co/username/your-model-name && cd your-model-name
72 |
73 | $ git add .
74 |
75 | $ git commit -m "Initial commit"
76 |
77 | $ git push
78 |
79 | ```
80 |
81 | ## Troubleshooting
82 |
83 | - audioread.exceptions.NoBackendError: `$ sudo apt-get install ffmpeg sox libsox-fmt-mp3`
84 |
85 |
86 | ## Finetuned models
87 |
88 | ### Wav2Vec2-XLSR-53
89 |
90 | - [Arabic](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-arabic)
91 | - [Chinese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn)
92 | - [Dutch](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-dutch)
93 | - [Finnish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-finnish)
94 | - [French](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-french)
95 | - [German](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-german)
96 | - [Greek](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-greek)
97 | - [Hungarian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-hungarian)
98 | - [Italian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-italian)
99 | - [Japanese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-japanese)
100 | - [Persian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-persian)
101 | - [Polish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-polish)
102 | - [Portuguese](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-portuguese)
103 | - [Russian](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-russian)
104 | - [Spanish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish)
105 |
--------------------------------------------------------------------------------
/cer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Datasets Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Character Error Ratio (CER) metric. """
16 | import jiwer
17 | import jiwer.transforms as tr
18 | from typing import List
19 | import datasets
20 | import gc
21 |
22 | _CITATION = """\
23 | @inproceedings{inproceedings,
24 | author = {Morris, Andrew and Maier, Viktoria and Green, Phil},
25 | year = {2004},
26 | month = {01},
27 | pages = {},
28 | title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.}
29 | }
30 | """
31 | _DESCRIPTION = """\
32 | Character error rate (CER) is a common metric of the performance of an automatic speech recognition system.
33 | CER is similar to Word Error Rate (WER), but operates on characters instead of words. Please refer to the docs of WER for further information.
34 | Character error rate can be computed as:
35 | CER = (S + D + I) / N = (S + D + I) / (S + D + C)
36 | where
37 | S is the number of substitutions,
38 | D is the number of deletions,
39 | I is the number of insertions,
40 | C is the number of correct characters,
41 | N is the number of characters in the reference (N=S+D+C).
42 | CER's output is a non-negative number, usually between 0 and 1 (it can exceed 1 when there are many insertions). This value indicates the proportion of characters that were incorrectly predicted. The lower the value, the better the
43 | performance of the ASR system, with a CER of 0 being a perfect score.
44 | """
45 | _KWARGS_DESCRIPTION = """
46 | Computes CER score of transcribed segments against references.
47 | Args:
48 | references: list of references for each speech input.
49 |     predictions: list of transcriptions to score.
50 | Returns:
51 | (float): the character error rate
52 | Examples:
53 | >>> predictions = ["this is the prediction", "there is an other sample"]
54 | >>> references = ["this is the reference", "there is another one"]
55 | >>> cer = datasets.load_metric("cer")
56 | >>> cer_score = cer.compute(predictions=predictions, references=references)
57 | >>> print(cer_score)
58 | 0.34146341463414637
59 | """
60 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61 | class CER(datasets.Metric):
62 | def _info(self):
63 | return datasets.MetricInfo(
64 | description=_DESCRIPTION,
65 | citation=_CITATION,
66 | inputs_description=_KWARGS_DESCRIPTION,
67 | features=datasets.Features(
68 | {
69 | "predictions": datasets.Value("string", id="sequence"),
70 | "references": datasets.Value("string", id="sequence"),
71 | }
72 | ),
73 | codebase_urls=["https://github.com/jitsi/jiwer/"],
74 | reference_urls=[
75 | "https://en.wikipedia.org/wiki/Word_error_rate",
76 | ],
77 | )
78 | def _compute(self, predictions, references, chunk_size=None):
79 | if chunk_size is None:
80 | preds = [char for seq in predictions for char in list(seq)]
81 | refs = [char for seq in references for char in list(seq)]
82 | return jiwer.wer(refs, preds)
83 | start = 0
84 | end = chunk_size
85 | H, S, D, I = 0, 0, 0, 0
86 | while start < len(references):
87 | preds = [char for seq in predictions[start:end] for char in list(seq)]
88 | refs = [char for seq in references[start:end] for char in list(seq)]
89 | chunk_metrics = jiwer.compute_measures(refs, preds)
90 | H = H + chunk_metrics["hits"]
91 | S = S + chunk_metrics["substitutions"]
92 | D = D + chunk_metrics["deletions"]
93 | I = I + chunk_metrics["insertions"]
94 | start += chunk_size
95 | end += chunk_size
96 | del preds
97 | del refs
98 | del chunk_metrics
99 | gc.collect()
100 | return float(S + D + I) / float(H + S + D)
101 |
--------------------------------------------------------------------------------
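
cer.py above is a custom datasets metric script; common_voice_eval.py loads it with `load_metric("cer.py")` and passes `chunk_size` to enable the chunked aggregation path. A minimal usage sketch, reusing the illustrative strings from the docstring above:

```python
# Sketch only: load the custom CER metric from its local script and compute it
# in chunks (chunk_size is forwarded to CER._compute).
from datasets import load_metric

cer = load_metric("cer.py")
predictions = ["this is the prediction", "there is an other sample"]
references = ["this is the reference", "there is another one"]
print(cer.compute(predictions=predictions, references=references, chunk_size=1))
```
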
/wer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Datasets Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Word Error Ratio (WER) metric. """
16 |
17 | import jiwer
18 | import datasets
19 |
20 | _CITATION = """\
21 | @inproceedings{inproceedings,
22 | author = {Morris, Andrew and Maier, Viktoria and Green, Phil},
23 | year = {2004},
24 | month = {01},
25 | pages = {},
26 | title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.}
27 | }
28 | """
29 |
30 | _DESCRIPTION = """\
31 | Word error rate (WER) is a common metric of the performance of an automatic speech recognition system.
32 | The general difficulty of measuring performance lies in the fact that the recognized word sequence can have a different length from the reference word sequence (supposedly the correct one). The WER is derived from the Levenshtein distance, working at the word level instead of the phoneme level. The WER is a valuable tool for comparing different systems as well as for evaluating improvements within one system. This kind of measurement, however, provides no details on the nature of translation errors and further work is therefore required to identify the main source(s) of error and to focus any research effort.
33 | This problem is solved by first aligning the recognized word sequence with the reference (spoken) word sequence using dynamic string alignment. Examination of this issue is seen through a theory called the power law that states the correlation between perplexity and word error rate.
34 | Word error rate can then be computed as:
35 | WER = (S + D + I) / N = (S + D + I) / (S + D + C)
36 | where
37 | S is the number of substitutions,
38 | D is the number of deletions,
39 | I is the number of insertions,
40 | C is the number of correct words,
41 | N is the number of words in the reference (N=S+D+C).
42 | WER's output is a non-negative number, usually between 0 and 1 (it can exceed 1 when there are many insertions). This value indicates the proportion of words that were incorrectly predicted. The lower the value, the better the
43 | performance of the ASR system, with a WER of 0 being a perfect score.
44 | """
45 |
46 | _KWARGS_DESCRIPTION = """
47 | Computes WER score of transcribed segments against references.
48 | Args:
49 | references: list of references for each speech input.
50 |     predictions: list of transcriptions to score.
51 | Returns:
52 | (float): the word error rate
53 | Examples:
54 | >>> predictions = ["this is the prediction", "there is an other sample"]
55 | >>> references = ["this is the reference", "there is another one"]
56 | >>> wer = datasets.load_metric("wer")
57 | >>> wer_score = wer.compute(predictions=predictions, references=references)
58 | >>> print(wer_score)
59 | 0.5
60 | """
61 |
62 |
63 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
64 | class WER(datasets.Metric):
65 | def _info(self):
66 | return datasets.MetricInfo(
67 | description=_DESCRIPTION,
68 | citation=_CITATION,
69 | inputs_description=_KWARGS_DESCRIPTION,
70 | features=datasets.Features(
71 | {
72 | "predictions": datasets.Value("string", id="sequence"),
73 | "references": datasets.Value("string", id="sequence"),
74 | }
75 | ),
76 | codebase_urls=["https://github.com/jitsi/jiwer/"],
77 | reference_urls=[
78 | "https://en.wikipedia.org/wiki/Word_error_rate",
79 | ],
80 | )
81 |
82 | def _compute(self, predictions, references, chunk_size=None):
83 | if chunk_size is None: return jiwer.wer(references, predictions)
84 | start = 0
85 | end = chunk_size
86 | H, S, D, I = 0, 0, 0, 0
87 | while start < len(references):
88 | chunk_metrics = jiwer.compute_measures(references[start:end], predictions[start:end])
89 | H = H + chunk_metrics["hits"]
90 | S = S + chunk_metrics["substitutions"]
91 | D = D + chunk_metrics["deletions"]
92 | I = I + chunk_metrics["insertions"]
93 | start += chunk_size
94 | end += chunk_size
95 | return float(S + D + I) / float(H + S + D)
96 |
--------------------------------------------------------------------------------
/MODEL_CARD.md:
--------------------------------------------------------------------------------
1 | ---
2 | language: pt
3 | datasets:
4 | - common_voice
5 | metrics:
6 | - wer
7 | - cer
8 | tags:
9 | - audio
10 | - automatic-speech-recognition
11 | - speech
12 | - xlsr-fine-tuning-week
13 | license: apache-2.0
14 | model-index:
15 | - name: XLSR Wav2Vec2 Portuguese by Jonatas Grosman
16 | results:
17 | - task:
18 | name: Speech Recognition
19 | type: automatic-speech-recognition
20 | dataset:
21 | name: Common Voice pt
22 | type: common_voice
23 | args: pt
24 | metrics:
25 | - name: Test WER
26 | type: wer
27 | value: {wer_result_on_test}
28 | - name: Test CER
29 | type: cer
30 |       value: {cer_result_on_test} #TODO (IMPORTANT): replace {wer_result_on_test} and {cer_result_on_test} with the WER and CER you achieved on the common_voice test set. They should be in the format XX.XX (don't add the % sign here). **Please** remember to fill out these values after you have evaluated your model, so that your model appears on the leaderboard. If you fill out this model card before evaluating your model, please remember to edit it afterward to fill in the final values
31 | ---
32 |
33 | # Wav2Vec2-Large-XLSR-53-Portuguese
34 |
35 | Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Portuguese using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
36 | When using this model, make sure that your speech input is sampled at 16kHz.
37 |
38 | The script used for training can be found here: https://github.com/jonatasgrosman/wav2vec2-sprint
39 | ## Usage
40 |
41 | The model can be used directly (without a language model) as follows:
42 |
43 | ```python
44 | import torch
45 | import librosa
46 | from datasets import load_dataset
47 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
48 |
49 | LANG_ID = "pt"
50 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"
51 | SAMPLES = 5
52 |
53 | test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")
54 |
55 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
56 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
57 |
58 | # Preprocessing the datasets.
59 | # We need to read the audio files as arrays
60 | def speech_file_to_array_fn(batch):
61 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
62 | batch["speech"] = speech_array
63 | batch["sentence"] = batch["sentence"].upper()
64 | return batch
65 |
66 | test_dataset = test_dataset.map(speech_file_to_array_fn)
67 | inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
68 |
69 | with torch.no_grad():
70 | logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
71 |
72 | predicted_ids = torch.argmax(logits, dim=-1)
73 | predicted_sentences = processor.batch_decode(predicted_ids)
74 |
75 | for i, predicted_sentence in enumerate(predicted_sentences):
76 | print("-" * 100)
77 | print("Reference:", test_dataset[i]["sentence"])
78 | print("Prediction:", predicted_sentence)
79 | ```
80 |
81 | | Reference | Prediction |
82 | | ------------- | ------------- |
83 | | NEM O RADAR NEM OS OUTROS INSTRUMENTOS DETECTARAM O BOMBARDEIRO STEALTH. | NEM UM VADA ME OS OUTOS INSTRUMENTOS DE TETERAM UM BAMBEDER OSTAU |
84 | | PEDIR DINHEIRO EMPRESTADO ÀS PESSOAS DA ALDEIA | PEDIAR DINHEIRO EMPRESTADO DÀS PESSOAS DA ALDEIA |
85 | | OITO | OITO |
86 | | TRANCÁ-LOS | TRAM CALDOS |
87 | | REALIZAR UMA INVESTIGAÇÃO PARA RESOLVER O PROBLEMA | REALIZARAMA INVESTIGAÇÃO PARA RESOLVER O PROBLEMA |
88 |
89 | ## Evaluation
90 |
91 | The model can be evaluated as follows on the Portuguese test data of Common Voice.
92 |
93 | ```python
94 | import torch
95 | import librosa
96 | import re
97 | from datasets import load_dataset, load_metric
98 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
99 |
100 | LANG_ID = "pt"
101 | MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"
102 | DEVICE = "cuda"
103 |
104 | CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
105 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
106 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
107 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
108 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
109 |
110 | test_dataset = load_dataset("common_voice", LANG_ID, split="test")
111 |
112 | wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
113 | cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py
114 |
115 | chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
116 |
117 | processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
118 | model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
119 | model.to(DEVICE)
120 |
121 | # Preprocessing the datasets.
122 | # We need to read the audio files as arrays
123 | def speech_file_to_array_fn(batch):
124 | batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
125 | speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
126 | batch["speech"] = speech_array
127 | return batch
128 |
129 | test_dataset = test_dataset.map(speech_file_to_array_fn)
130 |
131 | # Running inference on the preprocessed dataset.
132 | # We batch the speech arrays and collect the predicted transcriptions
133 | def evaluate(batch):
134 | inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
135 |
136 | with torch.no_grad():
137 | logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits
138 |
139 | pred_ids = torch.argmax(logits, dim=-1)
140 | batch["pred_strings"] = processor.batch_decode(pred_ids)
141 | return batch
142 |
143 | result = test_dataset.map(evaluate, batched=True, batch_size=8)
144 |
145 | predictions = [x.upper() for x in result["pred_strings"]]
146 | references = [x.upper() for x in result["sentence"]]
147 |
148 | print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
149 | print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
150 | ```
151 |
152 | **Test Result**:
153 |
154 | In the table below I report the Word Error Rate (WER) and the Character Error Rate (CER) of the model. I ran the evaluation script described above on other models as well (on YYYY-MM-DD). Note that the table below may show results that differ from those already reported; such differences may be caused by specificities of the other evaluation scripts used.
155 |
156 | | Model | WER | CER |
157 | | ------------- | ------------- | ------------- |
158 | | **jonatasgrosman/wav2vec2-large-xlsr-53-portuguese** | XX.XX% | XX.XX% |
159 | | username/model_name | XX.XX% | XX.XX% |
160 |
--------------------------------------------------------------------------------
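
The model card above stresses that the input speech must be sampled at 16 kHz. A minimal sketch, not part of the repo, of resampling an arbitrary audio file with librosa and soundfile (both already in requirements.txt); the file names are placeholders:

```python
# Sketch only: resample an audio file to the 16 kHz rate the model expects.
import librosa
import soundfile as sf

speech_array, sampling_rate = librosa.load("input_audio.mp3", sr=16_000)  # placeholder path
sf.write("input_audio_16k.wav", speech_array, sampling_rate)
```
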
/run_common_voice.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import sys
7 | import collections
8 | from dataclasses import dataclass, field
9 | from typing import Any, Dict, List, Optional, Union
10 |
11 | import datasets
12 | import numpy as np
13 | import torch
14 | import soundfile as sf
15 | from packaging import version
16 | from torch import nn
17 | from pathlib import Path
18 | import wandb
19 |
20 | # from torch_audiomentations import Compose, Gain  # unused: shadowed by the audiomentations imports below
21 | from audiomentations import (
22 | Compose,
23 | AddGaussianNoise,
24 | AddGaussianSNR,
25 | ClippingDistortion,
26 | FrequencyMask,
27 | Gain,
28 | LoudnessNormalization,
29 | Normalize,
30 | PitchShift,
31 | PolarityInversion,
32 | Shift,
33 | TimeMask,
34 | TimeStretch,
35 | )
36 |
37 | import transformers
38 | from transformers import (
39 | HfArgumentParser,
40 | Trainer,
41 | TrainingArguments,
42 | Wav2Vec2CTCTokenizer,
43 | Wav2Vec2FeatureExtractor,
44 | Wav2Vec2ForCTC,
45 | Wav2Vec2Processor,
46 | is_apex_available,
47 | set_seed,
48 | )
49 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
50 | from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler
51 |
52 |
53 | PRETRAINED_MODELS = [
54 | "facebook/wav2vec2-large",
55 | "facebook/wav2vec2-large-xlsr-53",
56 | "facebook/wav2vec2-large-es-voxpopuli",
57 | "facebook/wav2vec2-large-fr-voxpopuli",
58 | "facebook/wav2vec2-large-it-voxpopuli",
59 | "facebook/wav2vec2-large-nl-voxpopuli",
60 | "facebook/wav2vec2-large-sv-voxpopuli",
61 | "facebook/wav2vec2-large-10k-voxpopuli",
62 | "facebook/wav2vec2-large-100k-voxpopuli"
63 | ]
64 |
65 |
66 | if is_apex_available():
67 | from apex import amp
68 |
69 |
70 | if version.parse(torch.__version__) >= version.parse("1.6"):
71 | _is_native_amp_available = True
72 | from torch.cuda.amp import autocast
73 |
74 | logger = logging.getLogger(__name__)
75 |
76 |
77 | def list_field(default=None, metadata=None):
78 | return field(default_factory=lambda: default, metadata=metadata)
79 |
80 |
81 | @dataclass
82 | class AdditionalTrainingArguments:
83 | """
84 | Additional training arguments
85 | """
86 |
87 | lr_warmup_ratio: Optional[float] = field(
88 | default=0.1,
89 | metadata={"help": "Percentage of steps for LR warmup phase"},
90 | )
91 | lr_constant_ratio: Optional[float] = field(
92 | default=0.4,
93 | metadata={"help": "Percentage of steps for LR constant phase (after warmup)"},
94 | )
95 | upload_final_model_to_wandb: Optional[bool] = field(
96 | default=False,
97 | metadata={"help": "Upload the final trained model to the WandB artifacts repository"},
98 | )
99 | upload_model_to_wandb_each_step: Optional[int] = field(
100 | default=None,
101 | metadata={"help": "Frequency (in steps) to upload the trained model to the WandB artifacts repository"},
102 | )
103 | apply_gaussian_noise_with_p: Optional[float] = field(
104 | default=0.5,
105 | metadata={"help": "Probability to apply Gaussian Noise in the original samples"},
106 | )
107 | apply_gain_with_p: Optional[float] = field(
108 | default=0.5,
109 | metadata={"help": "Probability to apply Gain in the original samples"},
110 | )
111 | apply_pitch_shift_with_p: Optional[float] = field(
112 | default=0.5,
113 | metadata={"help": "Probability to apply Pitch Shift in the original samples"},
114 | )
115 | apply_time_stretch_with_p: Optional[float] = field(
116 | default=0.5,
117 | metadata={"help": "Probability to apply Time Stretch in the original samples"},
118 | )
119 | min_char_occurrence_ratio: Optional[float] = field(
120 | default=None,
121 | metadata={"help": "Minimum ratio of character occurrences to be considered for the vocabulary builder"},
122 | )
123 | max_dataset_size_vocab_builder: Optional[int] = field(
124 | default=10000,
125 | metadata={"help": "Maximum size of the dataset to be considered for vocabulary builder"},
126 | )
127 | remove_samples_with_oov_from_training: Optional[bool] = field(
128 | default=False,
129 | metadata={"help": "Whether to remove samples from training when there are OOV characters on them"},
130 | )
131 | use_only_top_k_most_common_accent: Optional[int] = field(
132 | default=None,
133 |         metadata={"help": "Use only the top k most common accents in the dataset for training"},
134 | )
135 |
136 | @dataclass
137 | class ModelArguments:
138 | """
139 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
140 | """
141 |
142 | model_name_or_path: str = field(
143 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
144 | )
145 | cache_dir: Optional[str] = field(
146 | default=None,
147 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
148 | )
149 | freeze_feature_extractor: Optional[bool] = field(
150 | default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
151 | )
152 | attention_dropout: Optional[float] = field(
153 | default=0.1, metadata={"help": "The dropout ratio for the attention probabilities."}
154 | )
155 | activation_dropout: Optional[float] = field(
156 | default=0.1, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
157 | )
158 | hidden_dropout: Optional[float] = field(
159 | default=0.1,
160 | metadata={
161 |             "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
162 | },
163 | )
164 | feat_proj_dropout: Optional[float] = field(
165 | default=0.1,
166 |         metadata={"help": "The dropout probability for all 1D convolutional layers in the feature extractor."},
167 | )
168 | mask_time_prob: Optional[float] = field(
169 | default=0.05,
170 | metadata={
171 |             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
172 | "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
173 | "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
174 | },
175 | )
176 | gradient_checkpointing: Optional[bool] = field(
177 | default=True,
178 | metadata={
179 | "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
180 | },
181 | )
182 | layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
183 |
184 |
185 | @dataclass
186 | class DataTrainingArguments:
187 | """
188 | Arguments pertaining to what data we are going to input our model for training and eval.
189 |
190 | Using `HfArgumentParser` we can turn this class
191 | into argparse arguments to be able to specify them on
192 | the command line.
193 | """
194 |
195 | dataset_config_name: Optional[str] = field(
196 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
197 | )
198 | overwrite_cache: bool = field(
199 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
200 | )
201 | preprocessing_num_workers: Optional[int] = field(
202 | default=None,
203 | metadata={"help": "The number of processes to use for the preprocessing."},
204 | )
205 | max_train_samples: Optional[int] = field(
206 | default=None,
207 | metadata={
208 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
209 | "value if set."
210 | },
211 | )
212 | max_val_samples: Optional[int] = field(
213 | default=1000,
214 | metadata={
215 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
216 | "value if set."
217 | },
218 | )
219 | val_ratio: Optional[float] = field(
220 | default=0.2,
221 | metadata={
222 | "help": "Percentage of dataset samples to be used for evaluation, default is 20%"
223 | },
224 | )
225 | chars_to_ignore: List[str] = list_field(
226 | default=[",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
227 | "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
228 | "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
229 | "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
230 | "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"],
231 | metadata={"help": "A list of characters to remove from the transcripts."},
232 | )
233 | min_duration: Optional[float] = field(
234 | default=0.0,
235 | metadata={
236 | "help": "The minimum duration (in seconds) that a sample needs to have to be considered for training"
237 | },
238 | )
239 | max_duration: Optional[float] = field(
240 | default=float("inf"),
241 | metadata={
242 |             "help": "The maximum duration (in seconds) that a sample may have to be considered for training"
243 | },
244 | )
245 | use_only_common_voice_data: bool = field(
246 | default=False, metadata={"help": "Use only common voice data in training."}
247 | )
248 |
249 |
250 | @dataclass
251 | class DataCollatorCTCWithPadding:
252 | """
253 | Data collator that will dynamically pad the inputs received.
254 | Args:
255 | processor (:class:`~transformers.Wav2Vec2Processor`)
256 |             The processor used for processing the data.
257 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
258 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
259 | among:
260 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
261 |               sequence is provided).
262 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
263 | maximum acceptable input length for the model if that argument is not provided.
264 |             * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
265 | different lengths).
266 | max_length (:obj:`int`, `optional`):
267 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
268 | max_length_labels (:obj:`int`, `optional`):
269 | Maximum length of the ``labels`` returned list and optionally padding length (see above).
270 | pad_to_multiple_of (:obj:`int`, `optional`):
271 | If set will pad the sequence to a multiple of the provided value.
272 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
273 | 7.5 (Volta).
274 | """
275 |
276 | processor: Wav2Vec2Processor
277 | padding: Union[bool, str] = True
278 | max_length: Optional[int] = None
279 | max_length_labels: Optional[int] = None
280 | pad_to_multiple_of: Optional[int] = None
281 | pad_to_multiple_of_labels: Optional[int] = None
282 |
283 | def __init__(self, processor, padding=True, apply_gaussian_noise_with_p=0.5, apply_gain_with_p=0.5, apply_pitch_shift_with_p=0.5,
284 | apply_time_stretch_with_p=0.5, sample_rate=16_000):
285 | self.processor = processor
286 | self.padding = padding
287 | self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p
288 | self.apply_gain_with_p = apply_gain_with_p
289 | self.apply_pitch_shift_with_p = apply_pitch_shift_with_p
290 | self.apply_time_stretch_with_p = apply_time_stretch_with_p
291 | self.sample_rate = sample_rate
292 |
293 | self.augmentator = None
294 | if self.apply_gaussian_noise_with_p + self.apply_gain_with_p + self.apply_pitch_shift_with_p + self.apply_time_stretch_with_p > 0:
295 | self.augmentator = Compose([
296 | TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=False, p=self.apply_time_stretch_with_p),
297 | PitchShift(min_semitones=-1, max_semitones=1, p=self.apply_pitch_shift_with_p),
298 | Gain(min_gain_in_db=-1, max_gain_in_db=1, p=self.apply_gain_with_p),
299 | AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=self.apply_gaussian_noise_with_p),
300 | ])
301 |
302 | def _apply_augmentation(self, input_values: List[float]):
303 | """apply some audio augmentations in the given input_values"""
304 | if self.augmentator is not None:
305 | return self.augmentator(samples=np.array(input_values), sample_rate=self.sample_rate).tolist()
306 | else:
307 | return input_values
308 |
309 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
310 |         # split inputs and labels since they have to be of different lengths and need
311 | # different padding methods
312 |
313 | input_features = [{"input_values": self._apply_augmentation(feature["input_values"])} for feature in features]
314 | label_features = [{"input_ids": feature["labels"]} for feature in features]
315 |
316 | batch = self.processor.pad(
317 | input_features,
318 | padding=self.padding,
319 | max_length=self.max_length,
320 | pad_to_multiple_of=self.pad_to_multiple_of,
321 | return_tensors="pt",
322 | )
323 | with self.processor.as_target_processor():
324 | labels_batch = self.processor.pad(
325 | label_features,
326 | padding=self.padding,
327 | max_length=self.max_length_labels,
328 | pad_to_multiple_of=self.pad_to_multiple_of_labels,
329 | return_tensors="pt",
330 | )
331 |
332 | # replace padding with -100 to ignore loss correctly
333 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
334 |
335 | batch["labels"] = labels
336 |
337 | return batch
338 |
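# Illustrative usage sketch for the collator above (kept as comments; the checkpoint name and the toy
# feature values are assumptions for demonstration only, not part of this training script):
#
#   from transformers import Wav2Vec2Processor
#   processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
#   collator = DataCollatorCTCWithPadding(processor=processor, apply_gaussian_noise_with_p=0.0,
#                                         apply_gain_with_p=0.0, apply_pitch_shift_with_p=0.0,
#                                         apply_time_stretch_with_p=0.0)
#   features = [{"input_values": [0.1] * 16000, "labels": processor.tokenizer("HELLO").input_ids},
#               {"input_values": [0.2] * 8000, "labels": processor.tokenizer("WORLD").input_ids}]
#   batch = collator(features)
#   # batch["input_values"] is padded to the longest sample, shape (2, 16000), and padded label
#   # positions in batch["labels"] are set to -100 so they are ignored by the CTC loss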
339 |
340 | class CTCTrainer(Trainer):
341 |
342 | def __init__(self, model_output_dir, length_field_name="length", upload_model_to_wandb_each_step=None, lr_warmup_ratio=0.1,
343 | lr_constant_ratio=0.4, sampling_rate=16_000, **kwargs):
344 | super().__init__(**kwargs)
345 | self.model_output_dir = model_output_dir
346 | self.length_field_name = length_field_name
347 | self.upload_model_to_wandb_each_step = upload_model_to_wandb_each_step
348 | self.lr_warmup_ratio = lr_warmup_ratio
349 | self.lr_constant_ratio = lr_constant_ratio
350 | self.sampling_rate = sampling_rate
351 |
352 | def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
353 | if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
354 | self.train_dataset, collections.abc.Sized
355 | ):
356 | return None
357 |
358 | # Build the sampler.
359 | if self.args.group_by_length:
360 | lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None
361 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
362 | if self.args.world_size <= 1:
363 | return LengthGroupedSampler(
364 | self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
365 | )
366 | else:
367 | return DistributedLengthGroupedSampler(
368 | self.train_dataset,
369 | self.args.train_batch_size,
370 | num_replicas=self.args.world_size,
371 | rank=self.args.process_index,
372 | lengths=lengths,
373 | model_input_name=model_input_name,
374 | )
375 |
376 | else:
377 | return super()._get_train_sampler()
378 |
379 | def create_scheduler(self, num_training_steps: int):
380 | """
381 | Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
382 |
383 | This method was built based on https://arxiv.org/pdf/2006.13979 :
384 | "The learning rate schedule has three phases: warm up for the first 10% of updates,
385 | keep constant for 40% and then linearly decay for the remainder"
386 |
387 | Args:
388 | num_training_steps (int): The number of training steps to do.
389 | """
390 | def lr_lambda(current_step):
391 | warmup_steps = int(num_training_steps * self.lr_warmup_ratio)
392 | constant_steps = int(num_training_steps * self.lr_constant_ratio)
393 | if current_step < warmup_steps:
394 | return float(current_step) / float(max(1, warmup_steps))
395 | elif (self.lr_warmup_ratio + self.lr_constant_ratio) == 1.0 or current_step < (warmup_steps + constant_steps):
396 | return 1
397 | else:
398 | return max(
399 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - (warmup_steps + constant_steps)))
400 | )
401 |
402 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
403 |
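    # Worked example (illustrative): with num_training_steps=1000, lr_warmup_ratio=0.1 and
    # lr_constant_ratio=0.4, lr_lambda returns the multiplier applied to the base learning rate:
    #   step 50   -> 50 / 100 = 0.5            (warmup over the first 100 steps)
    #   step 300  -> 1.0                       (constant phase, steps 100-499)
    #   step 750  -> (1000 - 750) / 500 = 0.5  (linear decay over the last 500 steps)
    #   step 1000 -> 0.0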
404 | def _apply_some_audio_transformations(self, inputs):
405 | """Perform some audio transformations"""
406 |
407 |         # adding an extra dimension for the channels as our data is mono audio and
408 | # the expected shape of input for torch_audiomentations is (batch_size, num_channels, num_samples)
409 | transformed_inputs = inputs["input_values"].unsqueeze(1)
410 |
411 | transformed_inputs = self.augmentator(transformed_inputs, sample_rate=self.sampling_rate)
412 |
413 | # returning the inputs to the original shape
414 | transformed_inputs = torch.squeeze(transformed_inputs, 1)
415 |
416 | inputs["input_values"] = transformed_inputs
417 |
418 | return inputs
419 |
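    # Note (descriptive): this helper expects a `self.augmentator` attribute (e.g. a
    # torch_audiomentations Compose), which is not set in __init__, and the method is not called
    # from training_step in this script, so it is effectively unused here.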
420 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
421 | """
422 | Perform a training step on a batch of inputs.
423 |
424 | Subclass and override to inject custom behavior.
425 |
426 | Args:
427 | model (:obj:`nn.Module`):
428 | The model to train.
429 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
430 | The inputs and targets of the model.
431 |
432 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
433 | argument :obj:`labels`. Check your model's documentation for all accepted arguments.
434 |
435 | Return:
436 | :obj:`torch.Tensor`: The tensor with training loss on this batch.
437 | """
438 |
439 | model.train()
440 | inputs = self._prepare_inputs(inputs)
441 |
442 | if self.use_amp:
443 | with autocast():
444 | loss = self.compute_loss(model, inputs)
445 | else:
446 | loss = self.compute_loss(model, inputs)
447 |
448 | if self.args.n_gpu > 1:
449 | if model.module.config.ctc_loss_reduction == "mean":
450 | loss = loss.mean()
451 | elif model.module.config.ctc_loss_reduction == "sum":
452 | loss = loss.sum() / (inputs["labels"] >= 0).sum()
453 | else:
454 | raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
455 |
456 | if self.args.gradient_accumulation_steps > 1:
457 | loss = loss / self.args.gradient_accumulation_steps
458 |
459 | if self.use_amp:
460 | self.scaler.scale(loss).backward()
461 | elif self.use_apex:
462 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
463 | scaled_loss.backward()
464 | elif self.deepspeed:
465 | self.deepspeed.backward(loss)
466 | else:
467 | loss.backward()
468 |
469 | if self.upload_model_to_wandb_each_step is not None and self.state.global_step > 0 \
470 | and self.state.global_step % self.upload_model_to_wandb_each_step == 0:
471 | upload_model_to_wandb(self.model_output_dir, name=f"{wandb.run.name}_{self.state.global_step}", metadata={"loss": float(loss)})
472 |
473 | return loss.detach()
474 |
475 |
476 | def build_tokenizer(model_output_dir, dataset, num_proc, min_char_occurrence_ratio):
477 |
478 | def extract_all_chars(batch):
479 | all_text = " ".join(batch["text"]).replace("", "")
480 | return {"all_text": [all_text]}
481 |
482 | vocab_train = dataset.map(
483 | extract_all_chars,
484 | batched=True,
485 | batch_size=-1,
486 | remove_columns=dataset.column_names,
487 | num_proc=num_proc
488 | )
489 |
490 |     special_vocab_dict = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4}
491 |
492 | min_char_occurrence = int(min_char_occurrence_ratio * len(vocab_train["all_text"][0])) if min_char_occurrence_ratio is not None else 1
493 |
494 | if min_char_occurrence > 1:
495 | character_counter = collections.Counter(vocab_train["all_text"][0])
496 | vocab_list = [character for character, count in character_counter.items() if count >= min_char_occurrence]
497 | else:
498 | vocab_list = set(vocab_train["all_text"][0])
499 |
500 | vocab_list = [x for x in vocab_list if x.isalpha() or x in ["-", "'"]] # removing non-alpha (except - or ') characters
501 |
502 | vocab_list = sorted(vocab_list)
503 | vocab_dict = {v: k + len(special_vocab_dict) for k, v in enumerate(vocab_list)}
504 | vocab_dict = dict(special_vocab_dict, **vocab_dict)
505 |
506 | vocab_path = os.path.join(model_output_dir, "vocab.json")
507 |
508 | with open(vocab_path, "w") as vocab_file:
509 | json.dump(vocab_dict, vocab_file)
510 |
511 | return Wav2Vec2CTCTokenizer(
512 | vocab_path,
513 |         unk_token="<unk>",
514 |         pad_token="<pad>",
515 | word_delimiter_token="|",
516 | )
517 |
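# Illustrative example (assumed toy data, not output of a real run): for transcripts containing only
# the characters "a", "b" and "c", build_tokenizer writes a vocab.json like
#   {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "a": 5, "b": 6, "c": 7}
# and returns a Wav2Vec2CTCTokenizer that uses "|" as the word delimiter token.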
518 |
519 | def upload_model_to_wandb(model_output_dir, name, metadata=None):
520 | artifact = wandb.Artifact(name=name, type="model", metadata=metadata)
521 | artifact.add_dir(model_output_dir)
522 | wandb.run.log_artifact(artifact)
523 |
524 |
525 | def main():
526 | # See all possible arguments in src/transformers/training_args.py
527 | # or by passing the --help flag to this script.
528 | # We now keep distinct sets of args, for a cleaner separation of concerns.
529 |
530 | # override default run name
531 |
532 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, AdditionalTrainingArguments, TrainingArguments))
533 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
534 | # If we pass only one argument to the script and it's the path to a json file,
535 | # let's parse it to get our arguments.
536 | model_args, data_args, additional_training_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
537 | else:
538 | model_args, data_args, additional_training_args, training_args = parser.parse_args_into_dataclasses()
539 |
540 | os.makedirs(training_args.output_dir, exist_ok=True)
541 | os.makedirs(model_args.cache_dir, exist_ok=True)
542 |
543 | wandb.init(dir=model_args.cache_dir)
544 |
545 | # Detecting last checkpoint.
546 | last_checkpoint = None
547 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
548 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
549 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
550 | raise ValueError(
551 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
552 | "Use --overwrite_output_dir to overcome."
553 | )
554 | elif last_checkpoint is not None:
555 | logger.info(
556 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
557 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
558 | )
559 |
560 | # Setup logging
561 | logging.basicConfig(
562 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
563 | datefmt="%m/%d/%Y %H:%M:%S",
564 | handlers=[logging.StreamHandler(sys.stdout)],
565 | )
566 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
567 |
568 | # Log on each process the small summary:
569 | logger.warning(
570 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
571 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
572 | )
573 | # Set the verbosity to info of the Transformers logger (on main process only):
574 | if is_main_process(training_args.local_rank):
575 | transformers.utils.logging.set_verbosity_info()
576 | logger.info("Training/evaluation parameters %s", training_args)
577 |
578 | # Set seed before initializing model.
579 | set_seed(training_args.seed)
580 |
581 | # Get the datasets:
582 |
583 | # As Common Voice dataset for most of the languages are really small, we'll merge the train and validation splits
584 | dataset = datasets.load_dataset(
585 | "dataset_ext.py", data_args.dataset_config_name,
586 | split="train+validation",
587 | cache_dir=model_args.cache_dir
588 | )
589 |
590 | print("DATASET COUNT:")
591 | print(collections.Counter(dataset["dataset"]))
592 |
593 | if data_args.val_ratio > 0 and data_args.max_val_samples > 0 and training_args.do_eval:
594 | if len(dataset) * data_args.val_ratio > data_args.max_val_samples:
595 | dataset = dataset.train_test_split(test_size=data_args.max_val_samples)
596 | else:
597 | dataset = dataset.train_test_split(test_size=data_args.val_ratio)
598 |
599 | train_dataset = dataset["train"]
600 | eval_dataset = dataset["test"]
601 |
602 | else:
603 | train_dataset = dataset
604 | eval_dataset = None
605 |
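    # Example (illustrative): with 10,000 samples, val_ratio=0.2 and max_val_samples=1000,
    # 0.2 * 10,000 = 2,000 > 1,000, so the eval split is capped at 1,000 samples.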
606 |
607 | # Filtering dataset:
608 |
609 | train_dataset_original_size = len(train_dataset)
610 | if eval_dataset is not None:
611 | eval_dataset_original_size = len(eval_dataset)
612 |
613 | if data_args.use_only_common_voice_data:
614 | train_dataset = train_dataset.filter(
615 | lambda example: example["dataset"] == "common_voice",
616 | num_proc=data_args.preprocessing_num_workers
617 | )
618 |
619 | train_dataset = train_dataset.filter(
620 | lambda example: example["duration"] >= data_args.min_duration and example["duration"] <= data_args.max_duration,
621 | num_proc=data_args.preprocessing_num_workers
622 | )
623 |
624 | if data_args.max_train_samples is not None and train_dataset_original_size > data_args.max_train_samples:
625 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
626 |
627 | if eval_dataset is not None and data_args.max_val_samples is not None and eval_dataset_original_size > data_args.max_val_samples:
628 | eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
629 |
630 | train_dataset_final_size = len(train_dataset)
631 | if eval_dataset is not None:
632 | eval_dataset_final_size = len(eval_dataset)
633 |
634 | logger.info(f"After filtering {train_dataset_final_size} of {train_dataset_original_size} samples will be used to train the model")
635 | if eval_dataset is not None:
636 | logger.info(f"After filtering {eval_dataset_final_size} of {eval_dataset_original_size} samples will be used to eval the model")
637 |
638 | # Create and save tokenizer
639 | chars_to_ignore_regex = f"[{re.escape(''.join(data_args.chars_to_ignore))}]"
640 |
641 | def remove_special_characters(batch):
642 | batch["text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).strip().upper() + " "
643 | return batch
644 |
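    # Example (illustrative): remove_special_characters maps "Olá, mundo!" to "OLÁ MUNDO "
    # (punctuation in chars_to_ignore removed, text upper-cased, trailing space appended).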
645 | train_dataset = train_dataset.map(
646 | remove_special_characters,
647 | remove_columns=["sentence"],
648 | num_proc=data_args.preprocessing_num_workers
649 | )
650 | if eval_dataset is not None:
651 | eval_dataset = eval_dataset.map(
652 | remove_special_characters,
653 | remove_columns=["sentence"],
654 | num_proc=data_args.preprocessing_num_workers
655 | )
656 |
657 | # Load pretrained model and tokenizer
658 | #
659 | # Distributed training:
660 | # The .from_pretrained methods guarantee that only one local process can concurrently
661 | # download model & vocab.
662 |
663 | if model_args.model_name_or_path in PRETRAINED_MODELS:
664 | dataset = datasets.concatenate_datasets([train_dataset, eval_dataset]) if eval_dataset is not None else train_dataset
665 | if len(dataset) > additional_training_args.max_dataset_size_vocab_builder:
666 | dataset = dataset.select(range(additional_training_args.max_dataset_size_vocab_builder))
667 | tokenizer = build_tokenizer(training_args.output_dir, dataset, data_args.preprocessing_num_workers, additional_training_args.min_char_occurrence_ratio)
668 | feature_extractor = Wav2Vec2FeatureExtractor(
669 | feature_size=1, sampling_rate=16_000, padding_value=0.0, do_normalize=True, return_attention_mask=True
670 | )
671 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
672 | else:
673 | processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path)
674 |
675 | if additional_training_args.remove_samples_with_oov_from_training:
676 | vocab = set(processor.tokenizer.encoder.keys())
677 | train_dataset_size = len(train_dataset)
678 | train_dataset = train_dataset.filter(
679 | lambda example: vocab.issuperset(example["text"].replace(" ", "")),
680 | num_proc=data_args.preprocessing_num_workers
681 | )
682 |         print(f"OOV found in {train_dataset_size - len(train_dataset)} samples, which were removed from the training set")
683 | print(f"The final training set size is {len(train_dataset)}")
684 |
685 | if additional_training_args.use_only_top_k_most_common_accent is not None:
686 |
687 | train_dataset_size = len(train_dataset)
688 |
689 | accent_count = collections.Counter(train_dataset["accent"])
690 | # accent_count.pop("", None)
691 | major_accents = [k for k, x in accent_count.most_common(additional_training_args.use_only_top_k_most_common_accent)]
692 |
693 | print(f"ACCENT COUNT: {accent_count}")
694 |
695 | train_dataset = train_dataset.filter(
696 | lambda example: example["accent"] in major_accents,
697 | num_proc=data_args.preprocessing_num_workers
698 | )
699 |
700 |     print(f"{train_dataset_size - len(train_dataset)} samples were removed from the dataset due to accent filtering, the final training dataset size is {len(train_dataset)}")
701 |
702 | # save the feature_extractor and the tokenizer
703 | processor.save_pretrained(training_args.output_dir)
704 |
705 | model = Wav2Vec2ForCTC.from_pretrained(
706 | model_args.model_name_or_path,
707 | cache_dir=model_args.cache_dir,
708 | activation_dropout=model_args.activation_dropout,
709 | attention_dropout=model_args.attention_dropout,
710 | hidden_dropout=model_args.hidden_dropout,
711 | feat_proj_dropout=model_args.feat_proj_dropout,
712 | mask_time_prob=model_args.mask_time_prob,
713 | gradient_checkpointing=model_args.gradient_checkpointing,
714 | layerdrop=model_args.layerdrop,
715 | ctc_loss_reduction="mean",
716 | pad_token_id=processor.tokenizer.pad_token_id,
717 | vocab_size=len(processor.tokenizer),
718 | ctc_zero_infinity=True
719 | )
720 |
721 | # Preprocessing the datasets.
722 | # We need to read the audio files as arrays and tokenize the targets.
723 | def speech_file_to_array_fn(batch):
724 | speech_array, sampling_rate = sf.read(batch["path"])
725 | batch["speech"] = speech_array
726 | batch["sampling_rate"] = sampling_rate
727 | batch["target_text"] = batch["text"]
728 | return batch
729 |
730 |     print("TRAIN DATASET COUNT:")
731 |     print(collections.Counter(train_dataset["dataset"]))
732 |     if eval_dataset is not None:
733 |         print("EVAL DATASET COUNT:", collections.Counter(eval_dataset["dataset"]))
734 |
735 | train_dataset = train_dataset.map(
736 | speech_file_to_array_fn,
737 | remove_columns=train_dataset.column_names,
738 | num_proc=data_args.preprocessing_num_workers
739 | )
740 | if eval_dataset is not None:
741 | eval_dataset = eval_dataset.map(
742 | speech_file_to_array_fn,
743 | remove_columns=eval_dataset.column_names,
744 | num_proc=data_args.preprocessing_num_workers
745 | )
746 |
747 | def prepare_dataset(batch):
748 | # check that all files have the correct sampling rate
749 | assert (
750 | len(set(batch["sampling_rate"])) == 1
751 | ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
752 | batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
753 |
754 | # Setup the processor for targets
755 | with processor.as_target_processor():
756 | batch["labels"] = processor(batch["target_text"]).input_ids
757 | return batch
758 |
759 | train_dataset = train_dataset.map(
760 | prepare_dataset,
761 | remove_columns=train_dataset.column_names,
762 | batch_size=training_args.per_device_train_batch_size,
763 | batched=True,
764 | num_proc=data_args.preprocessing_num_workers
765 | )
766 | if eval_dataset is not None:
767 | eval_dataset = eval_dataset.map(
768 | prepare_dataset,
769 | remove_columns=eval_dataset.column_names,
770 | batch_size=training_args.per_device_train_batch_size,
771 | batched=True,
772 | num_proc=data_args.preprocessing_num_workers
773 | )
774 |
775 | # Pre-compute sample lengths
776 | def input_lengths(example):
777 | example["length"] = len(example["input_values"])
778 | return example
779 |
780 | train_dataset = train_dataset.map(
781 | input_lengths,
782 | num_proc=data_args.preprocessing_num_workers
783 | )
784 |
785 | # Metric
786 | wer_metric = datasets.load_metric("wer.py")
787 | cer_metric = datasets.load_metric("cer.py")
788 |
789 | def compute_metrics(pred):
790 | pred_logits = pred.predictions
791 | pred_ids = np.argmax(pred_logits, axis=-1)
792 |
793 | pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
794 |
795 | pred_str = processor.batch_decode(pred_ids)
796 | # we do not want to group tokens when computing the metrics
797 | label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
798 |
799 | wer = wer_metric.compute(predictions=pred_str, references=label_str, chunk_size=1000)
800 | cer = cer_metric.compute(predictions=pred_str, references=label_str, chunk_size=1000)
801 |
802 | return {"wer": wer, "cer": cer}
803 |
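    # Example (illustrative): prediction "O OLÁ MUNDO" vs. reference "OLÁ MUNDO" gives
    # WER = 1 inserted word / 2 reference words = 0.5; CER is computed analogously over characters.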
804 | if model_args.freeze_feature_extractor:
805 | model.freeze_feature_extractor()
806 |
807 | # Data collator
808 | data_collator = DataCollatorCTCWithPadding(
809 | processor=processor,
810 | padding=True,
811 | apply_gaussian_noise_with_p=additional_training_args.apply_gaussian_noise_with_p,
812 | apply_gain_with_p=additional_training_args.apply_gain_with_p,
813 | apply_pitch_shift_with_p=additional_training_args.apply_pitch_shift_with_p,
814 | apply_time_stretch_with_p=additional_training_args.apply_time_stretch_with_p,
815 | sample_rate=16_000,
816 | )
817 |
818 | # Initialize our Trainer
819 | trainer = CTCTrainer(
820 | model_output_dir=training_args.output_dir,
821 | length_field_name="length",
822 | upload_model_to_wandb_each_step=additional_training_args.upload_model_to_wandb_each_step,
823 | lr_warmup_ratio=additional_training_args.lr_warmup_ratio,
824 | lr_constant_ratio=additional_training_args.lr_constant_ratio,
825 | sampling_rate=16_000,
826 | model=model,
827 | data_collator=data_collator,
828 | args=training_args,
829 | compute_metrics=compute_metrics,
830 | train_dataset=train_dataset,
831 | eval_dataset=eval_dataset,
832 | tokenizer=processor.feature_extractor,
833 | )
834 |
835 | # Training
836 | if training_args.do_train:
837 | if last_checkpoint is not None:
838 | checkpoint = last_checkpoint
839 | elif os.path.isdir(model_args.model_name_or_path):
840 | checkpoint = model_args.model_name_or_path
841 | else:
842 | checkpoint = None
843 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
844 | trainer.save_model()
845 |
846 | metrics = train_result.metrics
847 | max_train_samples = (
848 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
849 | )
850 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
851 |
852 | trainer.log_metrics("train", metrics)
853 | trainer.save_metrics("train", metrics)
854 | trainer.save_state()
855 |
856 | # Evaluation
857 | metrics = {}
858 | if eval_dataset is not None and training_args.do_eval:
859 | logger.info("*** Evaluate ***")
860 | metrics = trainer.evaluate()
861 | max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
862 | metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
863 |
864 | trainer.log_metrics("eval", metrics)
865 | trainer.save_metrics("eval", metrics)
866 |
867 | # save model files
868 | if additional_training_args.upload_final_model_to_wandb:
869 | upload_model_to_wandb(training_args.output_dir, name=f"{wandb.run.name}_final", metadata=metrics)
870 |
871 | if __name__ == "__main__":
872 | main()
873 |
--------------------------------------------------------------------------------
/dataset_ext.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Common Voice Dataset"""
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import re
21 | import homoglyphs as hg
22 | import gdown
23 | import json
24 | import pandas as pd
25 | import glob
26 |
27 | import datasets
28 |
29 | import soundfile as sf
30 | import librosa
31 | import warnings
32 |
33 | from lang_trans.arabic import buckwalter
34 |
35 | _DATA_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/{}.tar.gz"
36 |
37 | _CITATION = """\
38 | @inproceedings{commonvoice:2020,
39 | author = {Ardila, R. and Branson, M. and Davis, K. and Henretty, M. and Kohler, M. and Meyer, J. and Morais, R. and Saunders, L. and Tyers, F. M. and Weber, G.},
40 | title = {Common Voice: A Massively-Multilingual Speech Corpus},
41 | booktitle = {Proceedings of the 12th Conference on Language Resources and Evaluation (LREC 2020)},
42 | pages = {4211--4215},
43 | year = 2020
44 | }
45 | """
46 |
47 | _DESCRIPTION = """\
48 | Common Voice is Mozilla's initiative to help teach machines how real people speak.
49 | The dataset currently consists of 7,335 validated hours of speech in 60 languages, but we’re always adding more voices and languages.
50 | """
51 |
52 | _HOMEPAGE = "https://commonvoice.mozilla.org/en/datasets"
53 |
54 | _LICENSE = "https://github.com/common-voice/common-voice/blob/main/LICENSE"
55 |
56 | _LANGUAGES = {
57 | "ab": {
58 | "Language": "Abkhaz",
59 | "Date": "2020-12-11",
60 | "Size": "39 MB",
61 | "Version": "ab_1h_2020-12-11",
62 | "Validated_Hr_Total": 0.05,
63 | "Overall_Hr_Total": 1,
64 | "Number_Of_Voice": 14,
65 | },
66 | "ar": {
67 | "Language": "Arabic",
68 | "Date": "2020-12-11",
69 | "Size": "2 GB",
70 | "Version": "ar_77h_2020-12-11",
71 | "Validated_Hr_Total": 49,
72 | "Overall_Hr_Total": 77,
73 | "Number_Of_Voice": 672,
74 | },
75 | "as": {
76 | "Language": "Assamese",
77 | "Date": "2020-12-11",
78 | "Size": "21 MB",
79 | "Version": "as_0.78h_2020-12-11",
80 | "Validated_Hr_Total": 0.74,
81 | "Overall_Hr_Total": 0.78,
82 | "Number_Of_Voice": 17,
83 | },
84 | "br": {
85 | "Language": "Breton",
86 | "Date": "2020-12-11",
87 | "Size": "444 MB",
88 | "Version": "br_16h_2020-12-11",
89 | "Validated_Hr_Total": 7,
90 | "Overall_Hr_Total": 16,
91 | "Number_Of_Voice": 157,
92 | },
93 | "ca": {
94 | "Language": "Catalan",
95 | "Date": "2020-12-11",
96 | "Size": "19 GB",
97 | "Version": "ca_748h_2020-12-11",
98 | "Validated_Hr_Total": 623,
99 | "Overall_Hr_Total": 748,
100 | "Number_Of_Voice": 5376,
101 | },
102 | "cnh": {
103 | "Language": "Hakha Chin",
104 | "Date": "2020-12-11",
105 | "Size": "39 MB",
106 | "Version": "ab_1h_2020-12-11",
107 | "Validated_Hr_Total": 0.05,
108 | "Overall_Hr_Total": 1,
109 | "Number_Of_Voice": 14,
110 | },
111 | "cs": {
112 | "Language": "Czech",
113 | "Date": "2020-12-11",
114 | "Size": "39 MB",
115 | "Version": "ab_1h_2020-12-11",
116 | "Validated_Hr_Total": 0.05,
117 | "Overall_Hr_Total": 1,
118 | "Number_Of_Voice": 14,
119 | },
120 | "cv": {
121 | "Language": "Chuvash",
122 | "Date": "2020-12-11",
123 | "Size": "419 MB",
124 | "Version": "cv_16h_2020-12-11",
125 | "Validated_Hr_Total": 4,
126 | "Overall_Hr_Total": 16,
127 | "Number_Of_Voice": 92,
128 | },
129 | "cy": {
130 | "Language": "Welsh",
131 | "Date": "2020-12-11",
132 | "Size": "3 GB",
133 | "Version": "cy_124h_2020-12-11",
134 | "Validated_Hr_Total": 95,
135 | "Overall_Hr_Total": 124,
136 | "Number_Of_Voice": 1382,
137 | },
138 | "de": {
139 | "Language": "German",
140 | "Date": "2020-12-11",
141 | "Size": "22 GB",
142 | "Version": "de_836h_2020-12-11",
143 | "Validated_Hr_Total": 777,
144 | "Overall_Hr_Total": 836,
145 | "Number_Of_Voice": 12659,
146 | },
147 | "dv": {
148 | "Language": "Dhivehi",
149 | "Date": "2020-12-11",
150 | "Size": "515 MB",
151 | "Version": "dv_19h_2020-12-11",
152 | "Validated_Hr_Total": 18,
153 | "Overall_Hr_Total": 19,
154 | "Number_Of_Voice": 167,
155 | },
156 | "el": {
157 | "Language": "Greek",
158 | "Date": "2020-12-11",
159 | "Size": "364 MB",
160 | "Version": "el_13h_2020-12-11",
161 | "Validated_Hr_Total": 6,
162 | "Overall_Hr_Total": 13,
163 | "Number_Of_Voice": 118,
164 | },
165 | "en": {
166 | "Language": "English",
167 | "Date": "2020-12-11",
168 | "Size": "56 GB",
169 | "Version": "en_2181h_2020-12-11",
170 | "Validated_Hr_Total": 1686,
171 | "Overall_Hr_Total": 2181,
172 | "Number_Of_Voice": 66173,
173 | },
174 | "eo": {
175 | "Language": "Esperanto",
176 | "Date": "2020-12-11",
177 | "Size": "3 GB",
178 | "Version": "eo_102h_2020-12-11",
179 | "Validated_Hr_Total": 90,
180 | "Overall_Hr_Total": 102,
181 | "Number_Of_Voice": 574,
182 | },
183 | "es": {
184 | "Language": "Spanish",
185 | "Date": "2020-12-11",
186 | "Size": "15 GB",
187 | "Version": "es_579h_2020-12-11",
188 | "Validated_Hr_Total": 324,
189 | "Overall_Hr_Total": 579,
190 | "Number_Of_Voice": 19484,
191 | },
192 | "et": {
193 | "Language": "Estonian",
194 | "Date": "2020-12-11",
195 | "Size": "732 MB",
196 | "Version": "et_27h_2020-12-11",
197 | "Validated_Hr_Total": 19,
198 | "Overall_Hr_Total": 27,
199 | "Number_Of_Voice": 543,
200 | },
201 | "eu": {
202 | "Language": "Basque",
203 | "Date": "2020-12-11",
204 | "Size": "3 GB",
205 | "Version": "eu_131h_2020-12-11",
206 | "Validated_Hr_Total": 89,
207 | "Overall_Hr_Total": 131,
208 | "Number_Of_Voice": 1028,
209 | },
210 | "fa": {
211 | "Language": "Persian",
212 | "Date": "2020-12-11",
213 | "Size": "8 GB",
214 | "Version": "fa_321h_2020-12-11",
215 | "Validated_Hr_Total": 282,
216 | "Overall_Hr_Total": 321,
217 | "Number_Of_Voice": 3655,
218 | },
219 | "fi": {
220 | "Language": "Finnish",
221 | "Date": "2020-12-11",
222 | "Size": "48 MB",
223 | "Version": "fi_1h_2020-12-11",
224 | "Validated_Hr_Total": 1,
225 | "Overall_Hr_Total": 1,
226 | "Number_Of_Voice": 27,
227 | },
228 | "fr": {
229 | "Language": "French",
230 | "Date": "2020-12-11",
231 | "Size": "18 GB",
232 | "Version": "fr_682h_2020-12-11",
233 | "Validated_Hr_Total": 623,
234 | "Overall_Hr_Total": 682,
235 | "Number_Of_Voice": 12953,
236 | },
237 | "fy-NL": {
238 | "Language": "Frisian",
239 | "Date": "2020-12-11",
240 | "Size": "1 GB",
241 | "Version": "fy-NL_46h_2020-12-11",
242 | "Validated_Hr_Total": 14,
243 | "Overall_Hr_Total": 46,
244 | "Number_Of_Voice": 467,
245 | },
246 | "ga-IE": {
247 | "Language": "Irish",
248 | "Date": "2020-12-11",
249 | "Size": "149 MB",
250 | "Version": "ga-IE_5h_2020-12-11",
251 | "Validated_Hr_Total": 3,
252 | "Overall_Hr_Total": 5,
253 | "Number_Of_Voice": 101,
254 | },
255 | "hi": {
256 | "Language": "Hindi",
257 | "Date": "2020-12-11",
258 | "Size": "20 MB",
259 | "Version": "hi_0.8h_2020-12-11",
260 | "Validated_Hr_Total": 0.54,
261 | "Overall_Hr_Total": 0.8,
262 | "Number_Of_Voice": 31,
263 | },
264 | "hsb": {
265 | "Language": "Sorbian, Upper",
266 | "Date": "2020-12-11",
267 | "Size": "76 MB",
268 | "Version": "hsb_2h_2020-12-11",
269 | "Validated_Hr_Total": 2,
270 | "Overall_Hr_Total": 2,
271 | "Number_Of_Voice": 19,
272 | },
273 | "hu": {
274 | "Language": "Hungarian",
275 | "Date": "2020-12-11",
276 | "Size": "232 MB",
277 | "Version": "hu_8h_2020-12-11",
278 | "Validated_Hr_Total": 8,
279 | "Overall_Hr_Total": 8,
280 | "Number_Of_Voice": 47,
281 | },
282 | "ia": {
283 |         "Language": "Interlingua",
284 | "Date": "2020-12-11",
285 | "Size": "216 MB",
286 | "Version": "ia_8h_2020-12-11",
287 | "Validated_Hr_Total": 6,
288 | "Overall_Hr_Total": 8,
289 | "Number_Of_Voice": 36,
290 | },
291 | "id": {
292 | "Language": "Indonesian",
293 | "Date": "2020-12-11",
294 | "Size": "454 MB",
295 | "Version": "id_17h_2020-12-11",
296 | "Validated_Hr_Total": 9,
297 | "Overall_Hr_Total": 17,
298 | "Number_Of_Voice": 219,
299 | },
300 | "it": {
301 | "Language": "Italian",
302 | "Date": "2020-12-11",
303 | "Size": "5 GB",
304 | "Version": "it_199h_2020-12-11",
305 | "Validated_Hr_Total": 158,
306 | "Overall_Hr_Total": 199,
307 | "Number_Of_Voice": 5729,
308 | },
309 | "ja": {
310 | "Language": "Japanese",
311 | "Date": "2020-12-11",
312 | "Size": "146 MB",
313 | "Version": "ja_5h_2020-12-11",
314 | "Validated_Hr_Total": 3,
315 | "Overall_Hr_Total": 5,
316 | "Number_Of_Voice": 235,
317 | },
318 | "ka": {
319 | "Language": "Georgian",
320 | "Date": "2020-12-11",
321 | "Size": "99 MB",
322 | "Version": "ka_3h_2020-12-11",
323 | "Validated_Hr_Total": 3,
324 | "Overall_Hr_Total": 3,
325 | "Number_Of_Voice": 44,
326 | },
327 | "kab": {
328 | "Language": "Kabyle",
329 | "Date": "2020-12-11",
330 | "Size": "16 GB",
331 | "Version": "kab_622h_2020-12-11",
332 | "Validated_Hr_Total": 525,
333 | "Overall_Hr_Total": 622,
334 | "Number_Of_Voice": 1309,
335 | },
336 | "ky": {
337 | "Language": "Kyrgyz",
338 | "Date": "2020-12-11",
339 | "Size": "553 MB",
340 | "Version": "ky_22h_2020-12-11",
341 | "Validated_Hr_Total": 11,
342 | "Overall_Hr_Total": 22,
343 | "Number_Of_Voice": 134,
344 | },
345 | "lg": {
346 | "Language": "Luganda",
347 | "Date": "2020-12-11",
348 | "Size": "199 MB",
349 | "Version": "lg_8h_2020-12-11",
350 | "Validated_Hr_Total": 3,
351 | "Overall_Hr_Total": 8,
352 | "Number_Of_Voice": 76,
353 | },
354 | "lt": {
355 | "Language": "Lithuanian",
356 | "Date": "2020-12-11",
357 | "Size": "129 MB",
358 | "Version": "lt_4h_2020-12-11",
359 | "Validated_Hr_Total": 2,
360 | "Overall_Hr_Total": 4,
361 | "Number_Of_Voice": 30,
362 | },
363 | "lv": {
364 | "Language": "Latvian",
365 | "Date": "2020-12-11",
366 | "Size": "199 MB",
367 | "Version": "lv_7h_2020-12-11",
368 | "Validated_Hr_Total": 6,
369 | "Overall_Hr_Total": 7,
370 | "Number_Of_Voice": 99,
371 | },
372 | "mn": {
373 | "Language": "Mongolian",
374 | "Date": "2020-12-11",
375 | "Size": "464 MB",
376 | "Version": "mn_17h_2020-12-11",
377 | "Validated_Hr_Total": 11,
378 | "Overall_Hr_Total": 17,
379 | "Number_Of_Voice": 376,
380 | },
381 | "mt": {
382 | "Language": "Maltese",
383 | "Date": "2020-12-11",
384 | "Size": "405 MB",
385 | "Version": "mt_15h_2020-12-11",
386 | "Validated_Hr_Total": 7,
387 | "Overall_Hr_Total": 15,
388 | "Number_Of_Voice": 171,
389 | },
390 | "nl": {
391 | "Language": "Dutch",
392 | "Date": "2020-12-11",
393 | "Size": "2 GB",
394 | "Version": "nl_63h_2020-12-11",
395 | "Validated_Hr_Total": 59,
396 | "Overall_Hr_Total": 63,
397 | "Number_Of_Voice": 1012,
398 | },
399 | "or": {
400 | "Language": "Odia",
401 | "Date": "2020-12-11",
402 | "Size": "190 MB",
403 | "Version": "or_7h_2020-12-11",
404 | "Validated_Hr_Total": 0.87,
405 | "Overall_Hr_Total": 7,
406 | "Number_Of_Voice": 34,
407 | },
408 | "pa-IN": {
409 | "Language": "Punjabi",
410 | "Date": "2020-12-11",
411 | "Size": "67 MB",
412 | "Version": "pa-IN_2h_2020-12-11",
413 | "Validated_Hr_Total": 0.5,
414 | "Overall_Hr_Total": 2,
415 | "Number_Of_Voice": 26,
416 | },
417 | "pl": {
418 | "Language": "Polish",
419 | "Date": "2020-12-11",
420 | "Size": "3 GB",
421 | "Version": "pl_129h_2020-12-11",
422 | "Validated_Hr_Total": 108,
423 | "Overall_Hr_Total": 129,
424 | "Number_Of_Voice": 2647,
425 | },
426 | "pt": {
427 | "Language": "Portuguese",
428 | "Date": "2020-12-11",
429 | "Size": "2 GB",
430 | "Version": "pt_63h_2020-12-11",
431 | "Validated_Hr_Total": 50,
432 | "Overall_Hr_Total": 63,
433 | "Number_Of_Voice": 1120,
434 | },
435 | "rm-sursilv": {
436 | "Language": "Romansh Sursilvan",
437 | "Date": "2020-12-11",
438 | "Size": "263 MB",
439 | "Version": "rm-sursilv_9h_2020-12-11",
440 | "Validated_Hr_Total": 5,
441 | "Overall_Hr_Total": 9,
442 | "Number_Of_Voice": 78,
443 | },
444 | "rm-vallader": {
445 | "Language": "Romansh Vallader",
446 | "Date": "2020-12-11",
447 | "Size": "103 MB",
448 | "Version": "rm-vallader_3h_2020-12-11",
449 | "Validated_Hr_Total": 2,
450 | "Overall_Hr_Total": 3,
451 | "Number_Of_Voice": 39,
452 | },
453 | "ro": {
454 | "Language": "Romanian",
455 | "Date": "2020-12-11",
456 | "Size": "250 MB",
457 | "Version": "ro_9h_2020-12-11",
458 | "Validated_Hr_Total": 6,
459 | "Overall_Hr_Total": 9,
460 | "Number_Of_Voice": 130,
461 | },
462 | "ru": {
463 | "Language": "Russian",
464 | "Date": "2020-12-11",
465 | "Size": "3 GB",
466 | "Version": "ru_130h_2020-12-11",
467 | "Validated_Hr_Total": 111,
468 | "Overall_Hr_Total": 130,
469 | "Number_Of_Voice": 1412,
470 | },
471 | "rw": {
472 | "Language": "Kinyarwanda",
473 | "Date": "2020-12-11",
474 | "Size": "40 GB",
475 | "Version": "rw_1510h_2020-12-11",
476 | "Validated_Hr_Total": 1183,
477 | "Overall_Hr_Total": 1510,
478 | "Number_Of_Voice": 410,
479 | },
480 | "sah": {
481 | "Language": "Sakha",
482 | "Date": "2020-12-11",
483 | "Size": "173 MB",
484 | "Version": "sah_6h_2020-12-11",
485 | "Validated_Hr_Total": 4,
486 | "Overall_Hr_Total": 6,
487 | "Number_Of_Voice": 42,
488 | },
489 | "sl": {
490 | "Language": "Slovenian",
491 | "Date": "2020-12-11",
492 | "Size": "212 MB",
493 | "Version": "sl_7h_2020-12-11",
494 | "Validated_Hr_Total": 5,
495 | "Overall_Hr_Total": 7,
496 | "Number_Of_Voice": 82,
497 | },
498 | "sv-SE": {
499 | "Language": "Swedish",
500 | "Date": "2020-12-11",
501 | "Size": "402 MB",
502 | "Version": "sv-SE_15h_2020-12-11",
503 | "Validated_Hr_Total": 12,
504 | "Overall_Hr_Total": 15,
505 | "Number_Of_Voice": 222,
506 | },
507 | "ta": {
508 | "Language": "Tamil",
509 | "Date": "2020-12-11",
510 | "Size": "648 MB",
511 | "Version": "ta_24h_2020-12-11",
512 | "Validated_Hr_Total": 14,
513 | "Overall_Hr_Total": 24,
514 | "Number_Of_Voice": 266,
515 | },
516 | "th": {
517 | "Language": "Thai",
518 | "Date": "2020-12-11",
519 | "Size": "325 MB",
520 | "Version": "th_12h_2020-12-11",
521 | "Validated_Hr_Total": 8,
522 | "Overall_Hr_Total": 12,
523 | "Number_Of_Voice": 182,
524 | },
525 | "tr": {
526 | "Language": "Turkish",
527 | "Date": "2020-12-11",
528 | "Size": "592 MB",
529 | "Version": "tr_22h_2020-12-11",
530 | "Validated_Hr_Total": 20,
531 | "Overall_Hr_Total": 22,
532 | "Number_Of_Voice": 678,
533 | },
534 | "tt": {
535 | "Language": "Tatar",
536 | "Date": "2020-12-11",
537 | "Size": "741 MB",
538 | "Version": "tt_28h_2020-12-11",
539 | "Validated_Hr_Total": 26,
540 | "Overall_Hr_Total": 28,
541 | "Number_Of_Voice": 185,
542 | },
543 | "uk": {
544 | "Language": "Ukrainian",
545 | "Date": "2020-12-11",
546 | "Size": "1 GB",
547 | "Version": "uk_43h_2020-12-11",
548 | "Validated_Hr_Total": 30,
549 | "Overall_Hr_Total": 43,
550 | "Number_Of_Voice": 459,
551 | },
552 | "vi": {
553 | "Language": "Vietnamese",
554 | "Date": "2020-12-11",
555 | "Size": "50 MB",
556 | "Version": "vi_1h_2020-12-11",
557 | "Validated_Hr_Total": 0.74,
558 | "Overall_Hr_Total": 1,
559 | "Number_Of_Voice": 62,
560 | },
561 | "vot": {
562 | "Language": "Votic",
563 | "Date": "2020-12-11",
564 | "Size": "7 MB",
565 | "Version": "vot_0.28h_2020-12-11",
566 | "Validated_Hr_Total": 0,
567 | "Overall_Hr_Total": 0.28,
568 | "Number_Of_Voice": 3,
569 | },
570 | "zh-CN": {
571 | "Language": "Chinese (China)",
572 | "Date": "2020-12-11",
573 | "Size": "2 GB",
574 | "Version": "zh-CN_78h_2020-12-11",
575 | "Validated_Hr_Total": 56,
576 | "Overall_Hr_Total": 78,
577 | "Number_Of_Voice": 3501,
578 | },
579 | "zh-HK": {
580 | "Language": "Chinese (Hong Kong)",
581 | "Date": "2020-12-11",
582 | "Size": "3 GB",
583 | "Version": "zh-HK_100h_2020-12-11",
584 | "Validated_Hr_Total": 50,
585 | "Overall_Hr_Total": 100,
586 | "Number_Of_Voice": 2536,
587 | },
588 | "zh-TW": {
589 | "Language": "Chinese (Taiwan)",
590 | "Date": "2020-12-11",
591 | "Size": "2 GB",
592 | "Version": "zh-TW_78h_2020-12-11",
593 | "Validated_Hr_Total": 55,
594 | "Overall_Hr_Total": 78,
595 | "Number_Of_Voice": 1444,
596 | },
597 | }
598 |
599 | _CSS10_URLS = {
600 | "de": "https://drive.google.com/uc?id=1wgCHGvT0S8YrNfRTVyn23sW-5MFknoHA", # 7427 samples
601 | "el": "https://drive.google.com/uc?id=10BNORyOqkosxEf3qAAtWM1qWjHEZzXTO", # 1844 samples
602 | "es": "https://drive.google.com/uc?id=1dyUvSxv0KowTseI35dE8UXpVsYFhEpQV", # 11100 samples
603 | "fi": "https://drive.google.com/uc?id=1H4-eGIgf4aK_s14uo-srbKMENpysuV2u", # 4842 samples
604 | "fr": "https://drive.google.com/uc?id=1kuhoDjhA_Cij0SJuMI_4kneDTR_cqahS", # 8648 samples
605 | "hu": "https://drive.google.com/uc?id=1ms2INJ1e0ChU0TMzgDYLa8jtoTK2gkmE", # 4515 samples
606 | "ja": "https://drive.google.com/uc?id=1E4k8FduAk-_wy85AQrGakZBcw2hLhmU6", # 6841 samples
607 | "nl": "https://drive.google.com/uc?id=1ji8QD4lJzInz2vomGkMafRjpz3gGBYsf", # 6494 samples
608 | "ru": "https://drive.google.com/uc?id=1tx3dpO8SX8CriF0YsK8XeISZc9yGRody", # 9599 samples
609 | "zh-CN": "https://drive.google.com/uc?id=1hliY4KD_I8y4FQg5zta9IDGN0HRQLRiv", # 2971 samples
610 | }
611 |
612 | _JSUT_URLS = {
613 | "ja": "http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip" # 7696 samples
614 | }
615 |
616 | _NST_URLS = {
617 | "sv-SE": {
618 | "metadata": "https://www.nb.no/sbfil/talegjenkjenning/16kHz_2020/se_2020/ADB_SWE_0467.tar.gz",
619 | "files": "https://www.nb.no/sbfil/talegjenkjenning/16kHz_2020/se_2020/lydfiler_16_1.tar.gz", # ? samples
620 | }
621 | }
622 |
623 | _FREE_ST_URLS = {
624 | "zh-CN": "https://www.openslr.org/resources/38/ST-CMDS-20170001_1-OS.tar.gz", # 102600 samples
625 | }
626 |
627 | _ARABIC_SPEECH = {
628 | "ar": "http://en.arabicspeechcorpus.com/arabic-speech-corpus.zip" # 1913 samples
629 | }
630 |
631 | _TIMIT = {
632 | "en": "https://data.deepai.org/timit.zip" # 4620 samples
633 | }
634 |
635 | _LIBRISPEECH_DL_URL = "http://www.openslr.org/resources/12/"
636 | _LIBRISPEECH = {
637 | "en": [
638 | _LIBRISPEECH_DL_URL + "dev-clean.tar.gz", # 2703 samples
639 | _LIBRISPEECH_DL_URL + "dev-other.tar.gz", # 2864 samples
640 | _LIBRISPEECH_DL_URL + "train-clean-100.tar.gz", # 28539 samples
641 | _LIBRISPEECH_DL_URL + "train-clean-360.tar.gz", # 104014 samples
642 | _LIBRISPEECH_DL_URL + "train-other-500.tar.gz", # 148688 samples
643 | ]
644 | }
645 |
646 | _MAX_TRAIN_SAMPLES = 90000
647 | _MAX_VAL_SAMPLES = 10000
648 |
649 | class CommonVoiceConfig(datasets.BuilderConfig):
650 | """BuilderConfig for CommonVoice."""
651 |
652 | def __init__(self, name, sub_version, **kwargs):
653 | """
654 | Args:
655 | data_dir: `string`, the path to the folder containing the files in the
656 | downloaded .tar
657 | citation: `string`, citation for the data set
658 | url: `string`, url for information about the data set
659 | **kwargs: keyword arguments forwarded to super.
660 | """
661 | self.sub_version = sub_version
662 | self.language = kwargs.pop("language", None)
663 | self.date_of_snapshot = kwargs.pop("date", None)
664 | self.size = kwargs.pop("size", None)
665 | self.validated_hr_total = kwargs.pop("val_hrs", None)
666 | self.total_hr_total = kwargs.pop("total_hrs", None)
667 | self.num_of_voice = kwargs.pop("num_of_voice", None)
668 |
669 | self.unk_token_regex = None
670 | if self.language in hg.Languages.get_all():
671 | # creating regex to match language specific non valid characters
672 | currency_symbols = ["$", "£", "€", "¥", "₩", "₹", "₽", "₱", "₦", "₼", "ლ", "₭", "₴", "₲", "₫", "₡", "₵", "₿", "฿", "¢"]
673 | alphabet = list(hg.Languages.get_alphabet([self.language]))
674 | valid_chars = alphabet + currency_symbols
675 | self.unk_token_regex = "[^"+re.escape("".join(valid_chars))+"\s\d]"
676 |
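        # Example (illustrative): for the "fr" config the regex matches any character that is not in
        # the French alphabet, a currency symbol, whitespace or a digit; the generators below use it
        # to strip such characters from the transcripts.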
677 |         description = f"Common Voice speech to text dataset in {self.language} version {self.sub_version} of {self.date_of_snapshot}. The dataset comprises {self.validated_hr_total} hours of validated transcribed speech data from {self.num_of_voice} speakers. The dataset has a size of {self.size}"
678 | super(CommonVoiceConfig, self).__init__(
679 | name=name, version=datasets.Version("6.1.0", ""), description=description, **kwargs
680 | )
681 |
682 |
683 | class CommonVoice(datasets.GeneratorBasedBuilder):
684 |
685 | BUILDER_CONFIGS = [
686 | CommonVoiceConfig(
687 | name=lang_id,
688 | language=_LANGUAGES[lang_id]["Language"],
689 | sub_version=_LANGUAGES[lang_id]["Version"],
690 | date=_LANGUAGES[lang_id]["Date"],
691 | size=_LANGUAGES[lang_id]["Size"],
692 | val_hrs=_LANGUAGES[lang_id]["Validated_Hr_Total"],
693 | total_hrs=_LANGUAGES[lang_id]["Overall_Hr_Total"],
694 | num_of_voice=_LANGUAGES[lang_id]["Number_Of_Voice"],
695 | )
696 | for lang_id in _LANGUAGES.keys()
697 | ]
698 |
699 | def _info(self):
700 | features = datasets.Features(
701 | {
702 | "client_id": datasets.Value("string"),
703 | "path": datasets.Value("string"),
704 | "sentence": datasets.Value("string"),
705 | "up_votes": datasets.Value("int64"),
706 | "down_votes": datasets.Value("int64"),
707 | "age": datasets.Value("string"),
708 | "gender": datasets.Value("string"),
709 | "accent": datasets.Value("string"),
710 | "locale": datasets.Value("string"),
711 | "segment": datasets.Value("string"),
712 | "duration": datasets.Value("float32"),
713 | "dataset": datasets.Value("string"),
714 | }
715 | )
716 |
717 | return datasets.DatasetInfo(
718 | description=_DESCRIPTION,
719 | features=features,
720 | supervised_keys=None,
721 | homepage=_HOMEPAGE,
722 | license=_LICENSE,
723 | citation=_CITATION,
724 | )
725 |
726 | def _download_from_gdrive(self, src_url: str, dst_path: str):
727 | """Downloading from Gdrive"""
728 | gdown.download(src_url, dst_path, quiet=False)
729 |
730 | def _split_generators(self, dl_manager):
731 | """Returns SplitGenerators."""
732 | dl_path = dl_manager.download_and_extract(_DATA_URL.format(self.config.name))
733 | abs_path_to_data = os.path.join(dl_path, "cv-corpus-6.1-2020-12-11", self.config.name)
734 | abs_path_to_clips = os.path.join(abs_path_to_data, "clips")
735 |
736 | css10_dir = None
737 | if self.config.name in _CSS10_URLS:
738 | css10_url = _CSS10_URLS[self.config.name]
739 | css10_dir = dl_manager.extract(dl_manager.download_custom(css10_url, self._download_from_gdrive))
740 |
741 | jsut_dir = None
742 | if self.config.name in _JSUT_URLS:
743 | jsut_url = _JSUT_URLS[self.config.name]
744 | jsut_dir = dl_manager.download_and_extract(jsut_url)
745 | jsut_dir = os.path.join(jsut_dir, "jsut_ver1.1")
746 |
747 | nst_metadata_dir = None
748 | nst_files_dir = None
749 | if self.config.name in _NST_URLS:
750 | nst_metadata_dir = dl_manager.download_and_extract(_NST_URLS[self.config.name]["metadata"])
751 | nst_files_dir = dl_manager.download_and_extract(_NST_URLS[self.config.name]["files"])
752 |
753 | free_st_dir = None
754 | if self.config.name in _FREE_ST_URLS:
755 | free_st_dir = dl_manager.download_and_extract(_FREE_ST_URLS[self.config.name])
756 | free_st_dir = os.path.join(free_st_dir, "ST-CMDS-20170001_1-OS")
757 |
758 | arabic_speech_dir = None
759 | if self.config.name in _ARABIC_SPEECH:
760 | arabic_speech_dir = dl_manager.download_and_extract(_ARABIC_SPEECH[self.config.name])
761 | arabic_speech_dir = os.path.join(arabic_speech_dir, "arabic-speech-corpus")
762 |
763 | timit_dir = None
764 | if self.config.name in _TIMIT:
765 | timit_dir = dl_manager.download_and_extract(_TIMIT[self.config.name])
766 |
767 | librispeech_dirs = None
768 | if self.config.name in _LIBRISPEECH:
769 | librispeech_dirs = []
770 | for librispeech_url in _LIBRISPEECH[self.config.name]:
771 | librispeech_dir = dl_manager.download_and_extract(librispeech_url)
772 | librispeech_dirs.append(librispeech_dir)
773 |
774 | return [
775 | datasets.SplitGenerator(
776 | name=datasets.Split.TRAIN,
777 | gen_kwargs={
778 | "filepath": os.path.join(abs_path_to_data, "train.tsv"),
779 | "path_to_clips": abs_path_to_clips,
780 | "css10_dir": css10_dir,
781 | "jsut_dir": jsut_dir,
782 | "nst_metadata_dir": nst_metadata_dir,
783 | "nst_files_dir": nst_files_dir,
784 | "free_st_dir": free_st_dir,
785 | "arabic_speech_dir": arabic_speech_dir,
786 | "timit_dir": timit_dir,
787 | "librispeech_dirs": librispeech_dirs,
788 | "max_samples": _MAX_TRAIN_SAMPLES
789 | },
790 | ),
791 | datasets.SplitGenerator(
792 | name=datasets.Split.VALIDATION,
793 | gen_kwargs={
794 | "filepath": os.path.join(abs_path_to_data, "dev.tsv"),
795 | "path_to_clips": abs_path_to_clips,
796 | "css10_dir": None,
797 | "jsut_dir": None,
798 | "nst_metadata_dir": None,
799 | "nst_files_dir": None,
800 | "free_st_dir": None,
801 | "arabic_speech_dir": None,
802 | "timit_dir": None,
803 | "librispeech_dirs": None,
804 | "max_samples": _MAX_VAL_SAMPLES
805 | },
806 | )
807 | ]
808 |
809 | def _convert_to_flac_and_save_it(self, path, delete_original_file=True):
810 |         """We'll convert all the audio files to FLAC format to speed up loading"""
811 |
812 | sample_path, sample_extension = os.path.splitext(path)
813 | new_path = f"{sample_path}.flac"
814 |
815 | if not os.path.isfile(new_path):
816 |
817 | with warnings.catch_warnings():
818 | warnings.simplefilter("ignore")
819 | speech_array, sample_rate = librosa.load(path, sr=16_000)
820 |
821 | sf.write(new_path, speech_array, sample_rate)
822 |
823 | if delete_original_file:
824 | os.remove(path)
825 |
826 | return new_path
827 |
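    # Example (illustrative): "clips/sample.mp3" is loaded with librosa at 16 kHz, written back as
    # "clips/sample.flac", and the original mp3 is removed unless delete_original_file=False.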
828 | def _common_voice_examples_generator(self, filepath, path_to_clips):
829 |
830 | data_fields = list(self._info().features.keys())
831 | path_idx = data_fields.index("path")
832 |
833 | with open(filepath, encoding="utf-8") as f:
834 | lines = f.readlines()
835 |
836 | for line in lines[1:]:
837 | field_values = line.strip().split("\t")
838 |
839 | # set absolute path for mp3 audio file
840 | field_values[path_idx] = os.path.join(path_to_clips, field_values[path_idx])
841 |
842 | # if data is incomplete, fill with empty values
843 | if len(field_values) < len(data_fields):
844 | field_values += (len(data_fields) - len(field_values)) * ["''"]
845 |
846 | sample = {key: value for key, value in zip(data_fields, field_values)}
847 |
848 | new_path = self._convert_to_flac_and_save_it(sample.get("path"))
849 | speech_array, sampling_rate = sf.read(new_path)
850 | sample["duration"] = len(speech_array) / sampling_rate
851 | sample["path"] = new_path
852 | sample["dataset"] = "common_voice"
853 |
854 | if self.config.unk_token_regex is not None:
855 | sample["sentence"] = re.sub(self.config.unk_token_regex, "", sample["sentence"])
856 |
857 | yield sample
858 |
859 | def _css10_examples_generator(self, css10_dir):
860 |
861 | with open(os.path.join(css10_dir, "transcript.txt"), encoding="utf-8") as f:
862 | lines = f.readlines()
863 |
864 | for line in lines:
865 | values = line.strip().split("|")
866 |
867 | audio_path = self._convert_to_flac_and_save_it(os.path.join(css10_dir, values[0]))
868 |             text = values[1] if self.config.name in ["ja", "zh-CN"] else values[2]
869 | text = re.sub("\s+", " ", text) # remove multiple spaces
870 | duration = float(values[3])
871 |
872 | if self.config.unk_token_regex is not None:
873 | text = re.sub(self.config.unk_token_regex, "", text)
874 |
875 | yield {
876 | "client_id": None,
877 | "path": audio_path,
878 | "sentence": text,
879 | "up_votes": 0,
880 | "down_votes": 0,
881 | "age": None,
882 | "gender": None,
883 | "accent": None,
884 | "locale": None,
885 | "segment": None,
886 | "duration": duration,
887 | "dataset": "css10"
888 | }
889 |
890 | def _jsut_examples_generator(self, jsut_dir):
891 |
892 | for subset in os.listdir(jsut_dir):
893 |
894 | if not os.path.isdir(os.path.join(jsut_dir, subset)):
895 | continue
896 |
897 | transcript_path = os.path.join(jsut_dir, subset, "transcript_utf8.txt")
898 |
899 | with open(transcript_path, encoding="utf-8") as f:
900 |
901 | lines = f.readlines()
902 |
903 | for line in lines:
904 |
905 | values = line.split(":")
906 | audio_path = os.path.join(jsut_dir, subset, "wav", f"{values[0]}.wav")
907 | text = values[1]
908 | text = re.sub("\s+", " ", text) # remove multiple spaces
909 |
910 | if self.config.unk_token_regex is not None:
911 | text = re.sub(self.config.unk_token_regex, "", text)
912 |
913 | new_audio_path = self._convert_to_flac_and_save_it(audio_path)
914 | speech_array, sampling_rate = sf.read(new_audio_path)
915 | duration = len(speech_array) / sampling_rate
916 |
917 | yield {
918 | "client_id": None,
919 | "path": new_audio_path,
920 | "sentence": text,
921 | "up_votes": 0,
922 | "down_votes": 0,
923 | "age": None,
924 | "gender": None,
925 | "accent": None,
926 | "locale": None,
927 | "segment": None,
928 | "duration": duration,
929 | "dataset": "jsut"
930 | }
931 |
932 | def _nst_examples_generator(self, nst_metadata_dir, nst_files_dir):
933 |
934 | for metadata_filename in os.listdir(nst_metadata_dir):
935 |
936 | metadata_filepath = os.path.join(nst_metadata_dir, metadata_filename)
937 |
938 | with open(metadata_filepath) as metadata_file:
939 | metadata = json.load(metadata_file)
940 |
941 | client_id = metadata.get("info", {}).get("Speaker_ID", None)
942 | age = metadata.get("info", {}).get("Age", None)
943 | gender = metadata.get("info", {}).get("Sex", None)
944 | lang = metadata.get("metadata").get("lang")
945 | pid = metadata.get("pid")
946 | audio_dir = os.path.join(nst_files_dir, lang, pid)
947 |
948 | for val_recording in metadata.get("val_recordings", []):
949 |
950 | audio_filename = f"{pid}_{val_recording.get('file').replace('.wav', '-1.wav')}"
951 | audio_path = os.path.join(audio_dir, audio_filename)
952 |
953 |                 # some files are missing from the original dataset, so we skip them here
954 | if not os.path.isfile(audio_path):
955 | continue
956 |
957 | text = val_recording.get("text")
958 |                 text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace into single spaces
959 |
960 | if self.config.unk_token_regex is not None:
961 | text = re.sub(self.config.unk_token_regex, "", text)
962 |
963 | new_audio_path = self._convert_to_flac_and_save_it(audio_path)
964 | speech_array, sampling_rate = sf.read(new_audio_path)
965 | duration = len(speech_array) / sampling_rate
966 |
967 | yield {
968 | "client_id": client_id,
969 | "path": new_audio_path,
970 | "sentence": text,
971 | "up_votes": 0,
972 | "down_votes": 0,
973 | "age": age,
974 | "gender": gender,
975 | "accent": None,
976 | "locale": None,
977 | "segment": None,
978 | "duration": duration,
979 | "dataset": "nst"
980 | }
981 |
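    # Free ST layout as parsed below: a flat directory in which every <name>.wav recording is
    # accompanied by a <name>.txt file holding its transcript.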
982 | def _free_st_examples_generator(self, free_st_dir):
983 |
984 | for filename in os.listdir(free_st_dir):
985 |
986 | if filename.endswith(".wav"):
987 |
988 | audio_path = os.path.join(free_st_dir, filename)
989 | text_path = os.path.join(free_st_dir, filename.replace(".wav", ".txt"))
990 |
991 |                 with open(text_path, "r", encoding="utf-8") as text_file:
992 | text = text_file.read().replace("\n", "").strip()
993 |                     text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace into single spaces
994 |
995 | if self.config.unk_token_regex is not None:
996 | text = re.sub(self.config.unk_token_regex, "", text)
997 |
998 | new_audio_path = self._convert_to_flac_and_save_it(audio_path)
999 | speech_array, sampling_rate = sf.read(new_audio_path)
1000 | duration = len(speech_array) / sampling_rate
1001 |
1002 | yield {
1003 | "client_id": None,
1004 | "path": new_audio_path,
1005 | "sentence": text,
1006 | "up_votes": 0,
1007 | "down_votes": 0,
1008 | "age": None,
1009 | "gender": None,
1010 | "accent": None,
1011 | "locale": None,
1012 | "segment": None,
1013 | "duration": duration,
1014 | "dataset": "free_st"
1015 | }
1016 |
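    # Arabic Speech Corpus layout as parsed below: orthographic-transcript.txt lines of the form
    # "<wav filename>" "<Buckwalter-transliterated text>", with the audio under a wav/ folder;
    # the transliteration is converted back to Arabic script before further cleaning.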
1017 | def _arabic_speech_examples_generator(self, arabic_speech_dir):
1018 |
1019 | with open(os.path.join(arabic_speech_dir, "orthographic-transcript.txt"), encoding="utf-8") as f:
1020 |
1021 | lines = f.readlines()
1022 |
1023 | for line in lines:
1024 |
1025 | values = line.split('" "')
1026 | filename = values[0].strip()[1:]
1027 | text = values[1].strip()[:-1]
1028 | audio_path = os.path.join(arabic_speech_dir, "wav", filename)
1029 |
1030 |                 # convert the Buckwalter transliteration back to Arabic script
1031 |                 text = buckwalter.untransliterate(text)
1032 |                 text = re.sub(r"\s+", " ", text)  # collapse runs of whitespace into single spaces
1033 |
1034 | if self.config.unk_token_regex is not None:
1035 | text = re.sub(self.config.unk_token_regex, "", text)
1036 |
1037 | new_audio_path = self._convert_to_flac_and_save_it(audio_path)
1038 | speech_array, sampling_rate = sf.read(new_audio_path)
1039 | duration = len(speech_array) / sampling_rate
1040 |
1041 | yield {
1042 | "client_id": None,
1043 | "path": new_audio_path,
1044 | "sentence": text,
1045 | "up_votes": 0,
1046 | "down_votes": 0,
1047 | "age": None,
1048 | "gender": None,
1049 | "accent": None,
1050 | "locale": None,
1051 | "segment": None,
1052 | "duration": duration,
1053 | "dataset": "arabic_speech"
1054 | }
1055 |
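    # TIMIT layout as parsed below: a train_data.csv index whose rows point (via
    # "path_from_data_dir") to .WAV files under a "data" folder and carry speaker id and dialect
    # region columns; each paired .TXT file holds "<start sample> <end sample> <transcript>".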
1056 |     def _timit_examples_generator(self, timit_dir):
1057 |         """Generate examples from the TIMIT archive based on the train_data.csv information."""
1058 | 
1059 |         data_info_csv = os.path.join(timit_dir, "train_data.csv")
1060 | 
1061 |         # The audio files live under a "data" directory next to the csv
1062 |         data_path = os.path.join(os.path.dirname(data_info_csv).strip(), "data")
1063 |
1064 | # Read the data info to extract rows mentioning about non-converted audio only
1065 | data_info = pd.read_csv(data_info_csv, encoding="utf8")
1066 | # making sure that the columns having no information about the file paths are removed
1067 | data_info.dropna(subset=["path_from_data_dir"], inplace=True)
1068 |
1069 | # filter out only the required information for data preparation
1070 | data_info = data_info.loc[(data_info["is_audio"]) & (~data_info["is_converted_audio"])]
1071 |
1072 | # Iterating the contents of the data to extract the relevant information
1073 | for audio_idx in range(data_info.shape[0]):
1074 | audio_data = data_info.iloc[audio_idx]
1075 |
1076 | # extract the path to audio
1077 | wav_path = os.path.join(data_path, *(audio_data["path_from_data_dir"].split("/")))
1078 |
1079 | # extract transcript
1080 | with open(wav_path.replace(".WAV", ".TXT"), "r", encoding="utf-8") as op:
1081 |                 transcript = " ".join(op.readlines()[0].split()[2:])  # skip the first two items (start/end sample indices)
1082 |
1083 | new_audio_path = self._convert_to_flac_and_save_it(wav_path)
1084 | speech_array, sampling_rate = sf.read(new_audio_path)
1085 | duration = len(speech_array) / sampling_rate
1086 |
1087 | yield {
1088 | "client_id": str(audio_data["speaker_id"]),
1089 | "path": new_audio_path,
1090 | "sentence": transcript,
1091 | "up_votes": 0,
1092 | "down_votes": 0,
1093 | "age": None,
1094 | "gender": None,
1095 | "accent": audio_data["dialect_region"],
1096 | "locale": None,
1097 | "segment": None,
1098 | "duration": duration,
1099 | "dataset": "timit"
1100 | }
1101 |
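    # LibriSpeech layout as parsed below: LibriSpeech/<split>/<speaker>/<chapter>/ directories,
    # each with a transcript .txt of "<utterance id> <transcript>" lines and one .flac file per
    # utterance; the audio is already FLAC, so no conversion step is needed here.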
1102 | def _librispeech_examples_generator(self, librispeech_dir):
1103 |
1104 | transcripts_glob = os.path.join(librispeech_dir, "LibriSpeech", "*/*/*/*.txt")
1105 | for transcript_file in sorted(glob.glob(transcripts_glob)):
1106 | path = os.path.dirname(transcript_file)
1107 |             # each line of the transcript file is "<utterance id> <transcript>"
1108 | with open(transcript_file, "r", encoding="utf-8") as f:
1109 | for line in f:
1110 | line = line.strip()
1111 | key, transcript = line.split(" ", 1)
1112 | audio_file = f"{key}.flac"
1113 | audio_file = os.path.join(path, audio_file)
1114 | speaker_id, chapter_id = [int(el) for el in key.split("-")[:2]]
1115 |
1116 | speech_array, sampling_rate = sf.read(audio_file)
1117 | duration = len(speech_array) / sampling_rate
1118 |
1119 | yield {
1120 | "client_id": str(speaker_id),
1121 | "path": audio_file,
1122 | "sentence": transcript,
1123 | "up_votes": 0,
1124 | "down_votes": 0,
1125 | "age": None,
1126 | "gender": None,
1127 | "accent": None,
1128 | "locale": None,
1129 | "segment": None,
1130 | "duration": duration,
1131 | "dataset": "librispeech"
1132 | }
1133 |
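    # The generators above are combined below: Common Voice examples come first, then each
    # optional corpus in a fixed order, all sharing one running integer key and a single
    # max_samples cap on the total number of yielded examples.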
1134 | def _generate_examples(self, filepath, path_to_clips, css10_dir, jsut_dir, nst_metadata_dir,
1135 | nst_files_dir, free_st_dir, arabic_speech_dir, timit_dir,
1136 | librispeech_dirs, max_samples):
1137 |         """Yields (id, example) pairs from every configured corpus, stopping once max_samples examples have been produced."""
1138 |         _id = 0
1139 | 
1140 |         for example in self._common_voice_examples_generator(filepath, path_to_clips):
1141 |             if _id >= max_samples:
1142 |                 break
1143 |             yield _id, example
1144 |             _id += 1
1145 |
1146 | if timit_dir is not None and _id < max_samples:
1147 | for example in self._timit_examples_generator(timit_dir):
1148 | if _id < max_samples:
1149 | yield _id, example
1150 | _id += 1
1151 | else:
1152 | break
1153 |
1154 | if css10_dir is not None and _id < max_samples:
1155 | for example in self._css10_examples_generator(css10_dir):
1156 | if _id < max_samples:
1157 | yield _id, example
1158 | _id += 1
1159 | else:
1160 | break
1161 |
1162 | if librispeech_dirs is not None and _id < max_samples:
1163 | for librispeech_dir in librispeech_dirs:
1164 | for example in self._librispeech_examples_generator(librispeech_dir):
1165 | if _id < max_samples:
1166 | yield _id, example
1167 | _id += 1
1168 | else:
1169 | break
1170 |
1171 | if jsut_dir is not None and _id < max_samples:
1172 | for example in self._jsut_examples_generator(jsut_dir):
1173 | if _id < max_samples:
1174 | yield _id, example
1175 | _id += 1
1176 | else:
1177 | break
1178 |
1179 | if nst_files_dir is not None and _id < max_samples:
1180 | for example in self._nst_examples_generator(nst_metadata_dir, nst_files_dir):
1181 | if _id < max_samples:
1182 | yield _id, example
1183 | _id += 1
1184 | else:
1185 | break
1186 |
1187 | if free_st_dir is not None and _id < max_samples:
1188 | for example in self._free_st_examples_generator(free_st_dir):
1189 | if _id < max_samples:
1190 | yield _id, example
1191 | _id += 1
1192 | else:
1193 | break
1194 |
1195 | if arabic_speech_dir is not None and _id < max_samples:
1196 | root_dirs = [arabic_speech_dir, os.path.join(arabic_speech_dir, "test set")]
1197 | for root_dir in root_dirs:
1198 | for example in self._arabic_speech_examples_generator(root_dir):
1199 | if _id < max_samples:
1200 | yield _id, example
1201 | _id += 1
1202 | else:
1203 | break
1204 |
--------------------------------------------------------------------------------