├── vocab.json
├── utils.py
├── LICENSE
├── preprocess.py
├── README.md
├── eval_asr.py
└── train_asr.py
/vocab.json:
--------------------------------------------------------------------------------
1 | {"[PAD]": 0, "": 1, "": 2, "[UNK]": 3, "|": 4, "'": 5, "a": 6, "b": 7, "c": 8, "d": 9, "e": 10, "f": 11, "g": 12, "h": 13, "i": 14, "j": 15, "k": 16, "l": 17, "m": 18, "n": 19, "o": 20, "p": 21, "q": 22, "r": 23, "s": 24, "t": 25, "u": 26, "v": 27, "w": 28, "x": 29, "y": 30, "z": 31, "ç": 32, "ö": 33, "ü": 34, "ğ": 35, "ı": 36, "ş": 37}
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unicode_tr import unicode_tr
3 |
4 | chars_to_remove_regex = '[,?.!\-\;\:"“%”�—…–()]'
5 | apostrophes = "[’‘`´ʹʻʼʽʿˈ]"
6 |
7 | def normalize_text(text):
8 |
9 | # Lower the text using 'unicode_tr'
10 | # Regular lower() does not work well for Turkish Language
11 | text_norm = unicode_tr(text).lower()
12 | # Unify apostrophes
13 | text_norm = re.sub(apostrophes, "'", text_norm)
14 | # Remove pre-defined chars
15 | text_norm = re.sub(chars_to_remove_regex, "", text_norm)
16 | # Remove single quotes
17 | text_norm = text_norm.replace(" '", " ")
18 | text_norm = text_norm.replace("' ", " ")
19 | # Handle hatted characters
20 | text_norm = re.sub('[â]', 'a', text_norm)
21 | text_norm = re.sub('[î]', 'i', text_norm)
22 | text_norm = re.sub('[ô]', 'o', text_norm)
23 | text_norm = re.sub('[û]', 'u', text_norm)
24 | # Handle alternate characters
25 | text_norm = re.sub('[é]', 'e', text_norm)
26 | text_norm = re.sub('[ë]', 'e', text_norm)
27 | # Remove multiple spaces
28 | text_norm = re.sub(r"\s+", " ", text_norm)
29 |
30 | return text_norm
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Muhammet Poyraz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import argparse
5 | import pandas as pd
6 | from utils import normalize_text
7 |
8 | def check_invalid_char(sentence, vocab):
9 | return any([ch not in vocab for ch in re.sub(r"\s+", "", sentence)])
10 |
11 | def load_common_voice_corpus(path):
12 |
13 | # Load CommonVoice TSV files
14 | columns_keep = ['path', 'sentence']
15 | df_validated = pd.read_csv(os.path.join(path, 'validated.tsv'), sep='\t')[columns_keep]
16 | df_dev = pd.read_csv(os.path.join(path, 'dev.tsv'), sep='\t')[columns_keep]
17 | df_test = pd.read_csv(os.path.join(path, 'test.tsv'), sep='\t')[columns_keep]
18 |
19 | # Train set = Validated - (Dev + Test)
20 | dev_test_paths = df_dev['path'].to_list() + df_test['path'].to_list()
21 | df_train = df_validated[~df_validated['path'].isin(dev_test_paths)].copy()
22 |
23 | # Add full paths for audio records in all splits
24 | df_train['path'] = df_train['path'].apply(lambda x: os.path.join(path,'clips',x))
25 | df_dev['path'] = df_dev['path'].apply(lambda x: os.path.join(path,'clips',x))
26 | df_test['path'] = df_test['path'].apply(lambda x: os.path.join(path,'clips',x))
27 |
28 | return df_train, df_dev, df_test
29 |
30 | def load_media_speech_corpus(path):
31 |
32 | # Load Media Speech corpus
33 | ms_paths, ms_sentences = [], []
34 |     for file_name in os.listdir(path):
35 |         if file_name.endswith('txt'):
36 |             ms_paths.append(os.path.join(path, file_name.replace('txt','flac')))
37 |             with open(os.path.join(path, file_name)) as fp:
38 | ms_sentences.append(fp.read().strip())
39 |
40 | df_ms = pd.DataFrame({'path': ms_paths, 'sentence': ms_sentences})
41 | return df_ms
42 |
43 | def filter_dataset(df, vocab):
44 | # Normalize sentences
45 | df['sentence'] = df['sentence'].apply(normalize_text)
46 | # Keep samples with valid sentences only
47 | df['isInvalid'] = df['sentence'].apply(check_invalid_char, vocab=vocab)
48 |     df = df[~df['isInvalid']].drop(columns=['isInvalid'])
49 | return df
50 |
51 | def main():
52 |
53 | # Parse arguments
54 | parser = argparse.ArgumentParser(description="Train Wav2vec ASR with CTC")
55 | parser.add_argument("--vocab", type=str, required=True, help="ASR vocabulary of tokens")
56 | parser.add_argument("--cv_path", type=str, required=True, help="Path for CommonVoice TR dataset")
57 | parser.add_argument("--media_speech_path", type=str, help="Path for MediaSpeech TR dataset")
58 | parser.add_argument("--output", type=str, help="Output directory")
59 | args = parser.parse_args()
60 |
61 | # Load vocab
62 | with open(args.vocab) as fp:
63 | vocab_dict = json.load(fp)
64 |
65 | # Load CommonVoice dataset
66 | df_cv_train, df_cv_dev, df_cv_test = load_common_voice_corpus(args.cv_path)
67 |
68 | # Load MediaSpeech dataset
69 | if args.media_speech_path:
70 | df_ms = load_media_speech_corpus(args.media_speech_path)
71 |
72 | # Clean and filter datasets
73 | df_train = filter_dataset(pd.concat([df_cv_train, df_ms], ignore_index=True)
74 | if args.media_speech_path else df_cv_train, vocab_dict)
75 | df_dev = filter_dataset(df_cv_dev, vocab_dict)
76 | df_test = filter_dataset(df_cv_test, vocab_dict)
77 |
78 | # Save
79 | if args.output:
80 | df_train.to_csv(os.path.join(args.output, 'train.csv'), index=False)
81 | df_dev.to_csv(os.path.join(args.output, 'validation.csv'), index=False)
82 | df_test.to_csv(os.path.join(args.output, 'test.csv'), index=False)
83 |
84 | return
85 |
86 | if __name__ == "__main__":
87 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wav2vec2-turkish
2 | Turkish Automatic Speech Recognition (ASR) using Facebook's Wav2vec 2.0 models
3 |
4 | ## Fine-tuned Models
5 | The following Wav2vec 2.0 models were finetuned during Huggingface's [Robust Speech Challenge](https://github.com/huggingface/transformers/tree/master/examples/research_projects/robust-speech-event) event:
6 | 1. [mpoyraz/wav2vec2-xls-r-300m-cv6-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv6-turkish) achieves 8.83 % WER on the Common Voice 6.1 TR test split
7 | 2. [mpoyraz/wav2vec2-xls-r-300m-cv7-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv7-turkish) achieves 8.62 % WER on the Common Voice 7 TR test split
8 | 3. [mpoyraz/wav2vec2-xls-r-300m-cv8-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv8-turkish) achieves 10.61 % WER on the Common Voice 8 TR test split
9 |
10 | ## Datasets
11 | The following open-source speech corpora are available for Turkish:
12 | 1. [Mozilla Common Voice](https://commonvoice.mozilla.org/en/datasets)
13 | 2. [MediaSpeech](https://www.openslr.org/108/)
14 |
15 | This repo contains pre-processing and training scripts for these corpora.
16 |
17 | ## Pre-processing Datasets
18 | After downloading the Turkish speech corpora above, `preprocess.py` can be used to create dataset files for training.
19 | - The script handles the text normalization required for proper training; a short normalization sketch follows the usage command below.
20 | - Common Voice TR corpus is handled as follows:
21 |   - Train split: all samples in the `validated` split, excluding the `dev` and `test` samples, are reserved for training.
22 | - Validation split: same as `dev` split.
23 | - Test split: same as `test` split.
24 | - The MediaSpeech corpus is fully included in the final train split if provided.
25 | - Final dataset CSV files with `path` & `sentence` columns are saved to the output directory: `train.csv`, `validation.csv` and `test.csv`.
26 |
27 | ```bash
28 | python preprocess.py \
29 | --vocab vocab.json \
30 | --cv_path data/cv-corpus--/tr \
31 | --media_speech_path data/TR \
32 |   --output data
33 | ```
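For reference, here is a minimal sketch of what the normalization in `utils.py` does; the example sentence and the printed output are illustrative, not taken from the corpus:

```python
from utils import normalize_text

# Turkish-aware lowercasing, apostrophe unification, punctuation removal
print(normalize_text("Merhaba, DÜNYA! Bugün 'İstanbul'da' hava çok güzel."))
# merhaba dünya bugün istanbul'da hava çok güzel
```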
34 | ## Training
35 | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) is a large-scale multilingual pretrained speech model and is used here for fine-tuning on Turkish speech corpora. The exact hyperparameters used are available on the model card of each fine-tuned model on the Huggingface model hub.
36 |
37 | An example training command:
38 |
39 | ```bash
40 | python train_asr.py \
41 | --model_name_or_path facebook/wav2vec2-xls-r-300m \
42 | --vocab_path vocab.json \
43 | --train_file train_validation.csv \
44 | --validation_file test.csv \
45 | --output_dir exp \
46 | --audio_path_column_name path \
47 | --text_column_name sentence \
48 | --preprocessing_num_workers 4 \
49 | --dataloader_num_workers 4 \
50 | --eval_metrics wer cer \
51 | --freeze_feature_extractor \
52 | --mask_time_prob 0.1 \
53 | --mask_feature_prob 0.1 \
54 | --attention_dropout 0.05 \
55 | --activation_dropout 0.05 \
56 | --feat_proj_dropout 0.05 \
57 | --final_dropout 0.1 \
58 | --learning_rate 2.5e-4 \
59 | --per_device_train_batch_size 8 \
60 | --per_device_eval_batch_size 8 \
61 | --gradient_accumulation_steps 8 \
62 | --num_train_epochs 20 \
63 | --warmup_steps 500 \
64 | --eval_steps 500 \
65 | --save_steps 500 \
66 | --evaluation_strategy steps \
67 | --save_total_limit 2 \
68 | --gradient_checkpointing \
69 | --fp16 \
70 | --group_by_length \
71 | --do_train \
72 |   --do_eval
73 | ```
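With `--per_device_train_batch_size 8` and `--gradient_accumulation_steps 8`, the effective batch size per optimizer step is 8 × 8 = 64 samples on a single GPU (scaled further by the number of devices in multi-GPU training).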
74 | ## Evaluation
75 | The following fine-tuned models are available on the Huggingface model hub and each has an evaluation script `eval.py` with appropriate text normalization. The commands for running evaluations are also available on the model cards; alternatively, this repo's `eval_asr.py` can be used as shown after the list below.
76 | 1. [mpoyraz/wav2vec2-xls-r-300m-cv6-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv6-turkish) achieves 8.83 % WER on the Common Voice 6.1 TR test split
77 | 2. [mpoyraz/wav2vec2-xls-r-300m-cv7-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv7-turkish) achieves 8.62 % WER on the Common Voice 7 TR test split
78 | 3. [mpoyraz/wav2vec2-xls-r-300m-cv8-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv8-turkish) achieves 10.61 % WER on the Common Voice 8 TR test split
79 |
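An example evaluation command with `eval_asr.py`; the model name is illustrative and the dataset arguments simply follow the script's defaults:

```bash
python eval_asr.py \
  --model_name_or_path mpoyraz/wav2vec2-xls-r-300m-cv6-turkish \
  --dataset_name common_voice \
  --dataset_config_name tr \
  --eval_split_name test \
  --batch_size 4
```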
80 | ## Language Model
81 | For CTC beam search decoding with shallow LM fusion, an n-gram language model was trained on Turkish Wikipedia articles using KenLM; the [ngram-lm-wiki](https://github.com/mpoyraz/ngram-lm-wiki) repo was used to generate the ARPA LM and convert it into binary format.
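As a rough sketch (not part of this repo's scripts), such a KenLM binary can be attached to a fine-tuned model with `pyctcdecode` and `Wav2Vec2ProcessorWithLM`; the `exp` output directory and the `lm.binary` path below are assumptions based on the training command above:

```python
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
from pyctcdecode import build_ctcdecoder

# Load the processor saved by train_asr.py (feature extractor + tokenizer)
processor = Wav2Vec2Processor.from_pretrained("exp")

# Vocabulary tokens ordered by id, as expected by pyctcdecode
vocab = processor.tokenizer.get_vocab()
labels = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]

# Build a CTC beam search decoder around the KenLM binary
decoder = build_ctcdecoder(labels, kenlm_model_path="lm.binary")

# Processor that decodes with shallow LM fusion; eval_asr.py uses
# `processor.decoder` automatically when this processor is loaded
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder,
)
processor_with_lm.save_pretrained("exp")
```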
82 |
--------------------------------------------------------------------------------
/eval_asr.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import argparse
4 | import torch
5 | import torchaudio
6 | from transformers import (
7 | AutoModelForCTC,
8 | AutoProcessor,
9 | )
10 | from datasets import DatasetDict, load_dataset, load_metric, set_caching_enabled
11 | from utils import normalize_text
12 |
13 | set_caching_enabled(False)
14 | logger = logging.getLogger(__name__)
15 |
16 | def main():
17 | # Parse arguments
18 | parser = argparse.ArgumentParser(description="Evaluate Wav2vec ASR with CTC")
19 | parser.add_argument("--model_name_or_path", type=str, required=True,
20 | help="Path to pretrained model or model identifier from huggingface.co/models")
21 | parser.add_argument("--dataset_name", type=str, default="common_voice",
22 |                         help="The name of the dataset to use (via the datasets library)")
23 | parser.add_argument("--dataset_config_name", type=str, default="tr",
24 | help="The configuration name of the dataset to use (via the datasets library)")
25 | parser.add_argument("--eval_split_name", type=str, default="test",
26 | help="The name of the evaluation data set split to use (via the datasets library)")
27 | parser.add_argument("--use_auth_token", action='store_true',
28 | help="Use authentication for loading the dataset (via the datasets library)")
29 | parser.add_argument("--audio_column_name", type=str, default="audio",
30 | help="The name of the dataset column containing the audio data")
31 | parser.add_argument("--text_column_name", type=str, default="sentence",
32 | help="The name of the dataset column containing the text data")
33 | parser.add_argument("--preprocessing_num_workers", type=int, default=1,
34 | help="The number of processes to use for the preprocessing")
35 | parser.add_argument("--batch_size", type=int, default=4,
36 | help="The batch size for evaluation")
37 | args = parser.parse_args()
38 |
39 | # Setup logging
40 | logging.basicConfig(
41 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
42 | datefmt="%m/%d/%Y %H:%M:%S",
43 | handlers=[logging.StreamHandler(sys.stdout)],
44 | level=logging.INFO
45 | )
46 |
47 | # Torch device to run evaluation
48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 | logging.info("Using {} for evaluation".format(device))
50 |
51 | # Wav2vec2 processor and model
52 | processor = AutoProcessor.from_pretrained(args.model_name_or_path)
53 | model = AutoModelForCTC.from_pretrained(args.model_name_or_path)
54 | model = model.to(device)
55 |
56 | # Load evaluation dataset
57 | eval_dataset = load_dataset(args.dataset_name, args.dataset_config_name,
58 | split=args.eval_split_name, use_auth_token=args.use_auth_token)
59 | logging.info("Dataset '{}' - split '{}' has {} records".format(
60 | args.dataset_name, args.eval_split_name, len(eval_dataset)))
61 |
62 | # Preprocess text and resample audio
63 | def preprocess(sample):
64 | # Normalize text
65 | text = sample[args.text_column_name]
66 | sample['text'] = normalize_text(text)
67 | # Resample audio array
68 | resampler = torchaudio.transforms.Resample(
69 | sample[args.audio_column_name]["sampling_rate"],
70 | processor.feature_extractor.sampling_rate
71 | )
72 | array_pt = torch.from_numpy(sample[args.audio_column_name]["array"]).unsqueeze(0)
73 | sample['audio_array'] = resampler(array_pt).squeeze().numpy()
74 | return sample
75 |
76 | eval_dataset = eval_dataset.map(
77 | preprocess, num_proc=args.preprocessing_num_workers
78 | )
79 |
80 | # Predict on eval dataset
81 | def predict(batch):
82 | inputs = processor(batch['audio_array'],
83 | sampling_rate=processor.feature_extractor.sampling_rate,
84 | return_tensors="pt", padding=True)
85 | # Move torch tensor to the device
86 | for k in inputs.keys():
87 | if inputs[k] is not None and torch.is_tensor(inputs[k]):
88 | inputs[k] = inputs[k].to(device)
89 | # Predict
90 | with torch.no_grad():
91 | logits = model(**inputs).logits
92 | # Decode with LM
93 | if hasattr(processor, 'decoder'):
94 | decode_results = processor.batch_decode(logits.cpu().numpy())
95 | batch["pred_strings"] = decode_results.text
96 | else: # No LM
97 | pred_ids = torch.argmax(logits, dim=-1)
98 | batch["pred_strings"] = processor.batch_decode(pred_ids)
99 | return batch
100 |
101 | eval_dataset = eval_dataset.map(
102 | predict, batched=True, batch_size=args.batch_size
103 | )
104 |
105 | # Load metrics and calculate on eval dataset
106 | wer, cer = load_metric("wer"), load_metric("cer")
107 | wer_score = wer.compute(predictions=eval_dataset["pred_strings"], references=eval_dataset["text"])
108 | cer_score = cer.compute(predictions=eval_dataset["pred_strings"], references=eval_dataset["text"])
109 | logging.info("WER: {:.2f} % , CER: {:.2f} %".format(100*wer_score, 100*cer_score))
110 |
111 | if __name__ == "__main__":
112 | main()
--------------------------------------------------------------------------------
/train_asr.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import sys
7 | import warnings
8 | from dataclasses import dataclass, field
9 | from typing import Dict, List, Optional, Union
10 |
11 | import numpy as np
12 | import torch
13 | import torchaudio
14 | import datasets
15 | from datasets import DatasetDict, load_dataset, load_metric, set_caching_enabled
16 |
17 | import transformers
18 | from transformers import (
19 | AutoConfig,
20 | AutoFeatureExtractor,
21 | AutoModelForCTC,
22 | AutoProcessor,
23 | AutoTokenizer,
24 | HfArgumentParser,
25 | Trainer,
26 | TrainingArguments,
27 | Wav2Vec2Processor,
28 | Wav2Vec2CTCTokenizer,
29 | Wav2Vec2FeatureExtractor,
30 | set_seed,
31 | )
32 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
33 | from transformers.utils import check_min_version
34 | from transformers.utils.versions import require_version
35 |
36 | logger = logging.getLogger(__name__)
37 | set_caching_enabled(False)
38 |
39 | def list_field(default=None, metadata=None):
40 | return field(default_factory=lambda: default, metadata=metadata)
41 |
42 | @dataclass
43 | class ModelArguments:
44 | """
45 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
46 | """
47 |
48 | model_name_or_path: str = field(
49 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
50 | )
51 | vocab_path: str = field(
52 | metadata={"help": "Path to ASR vocabulary, tokens as JSON file"}
53 | )
54 | cache_dir: Optional[str] = field(
55 | default=None,
56 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
57 | )
58 | freeze_feature_extractor: Optional[bool] = field(
59 | default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
60 | )
61 | attention_dropout: Optional[float] = field(
62 | default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
63 | )
64 | activation_dropout: Optional[float] = field(
65 | default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
66 | )
67 | feat_proj_dropout: Optional[float] = field(
68 | default=0.0, metadata={"help": "The dropout ratio for the projected features."}
69 | )
70 | hidden_dropout: Optional[float] = field(
71 | default=0.0,
72 | metadata={
73 | "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
74 | },
75 | )
76 | final_dropout: Optional[float] = field(
77 | default=0.0,
78 | metadata={"help": "The dropout probability for the final projection layer."},
79 | )
80 | mask_time_prob: Optional[float] = field(
81 | default=0.05,
82 | metadata={
83 |             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector "
84 |             "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
85 |             "vectors will be masked along the time axis."
86 | },
87 | )
88 | mask_time_length: Optional[int] = field(
89 | default=10,
90 | metadata={"help": "Length of vector span to mask along the time axis."},
91 | )
92 | mask_feature_prob: Optional[float] = field(
93 | default=0.0,
94 | metadata={
95 |             "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector "
96 |             "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the feature axis."
97 | },
98 | )
99 | mask_feature_length: Optional[int] = field(
100 | default=10,
101 | metadata={"help": "Length of vector span to mask along the feature axis."},
102 | )
103 | layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
104 | ctc_loss_reduction: Optional[str] = field(
105 | default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
106 | )
107 |
108 |
109 | @dataclass
110 | class DataTrainingArguments:
111 | """
112 | Arguments pertaining to what data we are going to input our model for training and eval.
113 | Using `HfArgumentParser` we can turn this class
114 | into argparse arguments to be able to specify them on
115 | the command line.
116 | """
117 |
118 | train_file: str = field(
119 | metadata={"help": "The training data file (CSV file)."}
120 | )
121 | validation_file: Optional[str] = field(
122 | default=None,
123 | metadata={"help": "An optional evaluation data file to evaluate on (CSV file)."},
124 | )
125 | delimiter: Optional[str] = field(
126 | default=",",
127 | metadata={"help": "Specifies the character delimiting individual cells in the CSV data"},
128 | )
129 | audio_path_column_name: Optional[str] = field(
130 | default="path",
131 | metadata={"help": "The name of the dataset column containing the audio paths. Defaults to 'path'"},
132 | )
133 | text_column_name: Optional[str] = field(
134 | default="text",
135 | metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
136 | )
137 | overwrite_cache: bool = field(
138 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
139 | )
140 | preprocessing_num_workers: Optional[int] = field(
141 | default=None,
142 | metadata={"help": "The number of processes to use for the preprocessing."},
143 | )
144 | eval_metrics: Optional[List[str]] = list_field(
145 | default=["wer"],
146 | metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
147 | )
148 | max_duration_in_seconds: Optional[float] = field(
149 | default=20.0,
150 | metadata={
151 |             "help": "Filter out audio files that are longer than `max_duration_in_seconds` seconds"
152 | },
153 | )
154 | min_duration_in_seconds: Optional[float] = field(
155 | default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
156 | )
157 | unk_token: Optional[str] = field(
158 | default="[UNK]",
159 | metadata={"help": "The unk token for the tokenizer"},
160 | )
161 | pad_token: Optional[str] = field(
162 | default="[PAD]",
163 | metadata={"help": "The padding token for the tokenizer"},
164 | )
165 | word_delimiter_token: Optional[str] = field(
166 | default="|",
167 | metadata={"help": "The word delimiter token for the tokenizer"},
168 | )
169 |
170 | @dataclass
171 | class DataCollatorCTCWithPadding:
172 | """
173 | Data collator that will dynamically pad the inputs received.
174 | Args:
175 | processor (:class:`~transformers.AutoProcessor`)
176 |             The processor used for processing the data.
177 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
178 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
179 | among:
180 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
181 |               sequence is provided).
182 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
183 | maximum acceptable input length for the model if that argument is not provided.
184 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
185 | different lengths).
186 | max_length (:obj:`int`, `optional`):
187 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
188 | max_length_labels (:obj:`int`, `optional`):
189 | Maximum length of the ``labels`` returned list and optionally padding length (see above).
190 | pad_to_multiple_of (:obj:`int`, `optional`):
191 | If set will pad the sequence to a multiple of the provided value.
192 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
193 | 7.5 (Volta).
194 | """
195 |
196 | processor: AutoProcessor
197 | padding: Union[bool, str] = "longest"
198 | pad_to_multiple_of: Optional[int] = None
199 | pad_to_multiple_of_labels: Optional[int] = None
200 |
201 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
202 |         # split inputs and labels since they have to be of different lengths and need
203 | # different padding methods
204 | input_features = [{"input_values": feature["input_values"]} for feature in features]
205 | label_features = [{"input_ids": feature["labels"]} for feature in features]
206 |
207 | batch = self.processor.pad(
208 | input_features,
209 | padding=self.padding,
210 | pad_to_multiple_of=self.pad_to_multiple_of,
211 | return_tensors="pt",
212 | )
213 |
214 | with self.processor.as_target_processor():
215 | labels_batch = self.processor.pad(
216 | label_features,
217 | padding=self.padding,
218 | pad_to_multiple_of=self.pad_to_multiple_of_labels,
219 | return_tensors="pt",
220 | )
221 |
222 | # replace padding with -100 to ignore loss correctly
223 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
224 |
225 | batch["labels"] = labels
226 |
227 | return batch
228 |
229 | def main():
230 | # See all possible arguments in src/transformers/training_args.py
231 | # or by passing the --help flag to this script.
232 | # We now keep distinct sets of args, for a cleaner separation of concerns.
233 |
234 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
235 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
236 | # If we pass only one argument to the script and it's the path to a json file,
237 | # let's parse it to get our arguments.
238 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
239 | else:
240 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
241 |
242 | # Detecting last checkpoint.
243 | last_checkpoint = None
244 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
245 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
246 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
247 | raise ValueError(
248 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
249 | "Use --overwrite_output_dir to overcome."
250 | )
251 | elif last_checkpoint is not None:
252 | logger.info(
253 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
254 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
255 | )
256 |
257 | # Setup logging
258 | logging.basicConfig(
259 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
260 | datefmt="%m/%d/%Y %H:%M:%S",
261 | handlers=[logging.StreamHandler(sys.stdout)],
262 | )
263 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
264 |
265 | # Log on each process the small summary:
266 | logger.warning(
267 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
268 | f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
269 | )
270 | # Set the verbosity to info of the Transformers logger (on main process only):
271 | if is_main_process(training_args.local_rank):
272 | transformers.utils.logging.set_verbosity_info()
273 | logger.info("Training/evaluation parameters %s", training_args)
274 |
275 | # Set seed before initializing model.
276 | set_seed(training_args.seed)
277 |
278 | # Load the model config
279 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
280 |
281 | # We need to make sure that only first rank saves vocabulary
282 | # make sure all processes wait until vocab is created
283 | tokenizer_name_or_path = training_args.output_dir
284 |
285 | with training_args.main_process_first():
286 | # Load vocab from file
287 | with open(model_args.vocab_path) as fp:
288 | vocab_dict = json.load(fp)
289 |
290 | vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
291 | if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
292 | os.remove(vocab_file)
293 |
294 | # Save vocab dict to be loaded into tokenizer
295 | if not os.path.isfile(vocab_file):
296 | with open(vocab_file, "w") as file:
297 | json.dump(vocab_dict, file)
298 |
299 | # Tokenizer args
300 | tokenizer_kwargs = {
301 | "config": config if config.tokenizer_class is not None else None,
302 | "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
303 | "unk_token": data_args.unk_token,
304 | "pad_token": data_args.pad_token,
305 | "word_delimiter_token": data_args.word_delimiter_token,
306 | }
307 |
308 | # Load feature_extractor and tokenizer
309 | tokenizer = AutoTokenizer.from_pretrained(
310 | tokenizer_name_or_path,
311 | **tokenizer_kwargs,
312 | )
313 |
314 | # Load feature_extractor
315 | feature_extractor = AutoFeatureExtractor.from_pretrained(
316 | model_args.model_name_or_path, cache_dir=model_args.cache_dir,
317 | )
318 | sampling_rate_target = feature_extractor.sampling_rate
319 |
320 | # Update config for finetuning
321 | config.update(
322 | {
323 | "feat_proj_dropout": model_args.feat_proj_dropout,
324 | "attention_dropout": model_args.attention_dropout,
325 | "hidden_dropout": model_args.hidden_dropout,
326 | "final_dropout": model_args.final_dropout,
327 | "mask_time_prob": model_args.mask_time_prob,
328 | "mask_time_length": model_args.mask_time_length,
329 | "mask_feature_prob": model_args.mask_feature_prob,
330 | "mask_feature_length": model_args.mask_feature_length,
331 | "gradient_checkpointing": training_args.gradient_checkpointing,
332 | "layerdrop": model_args.layerdrop,
333 | "ctc_loss_reduction": model_args.ctc_loss_reduction,
334 | "pad_token_id": tokenizer.pad_token_id,
335 | "vocab_size": len(tokenizer),
336 | "activation_dropout": model_args.activation_dropout,
337 |             "bos_token_id" : vocab_dict["<s>"],
338 |             "eos_token_id" : vocab_dict["</s>"],
339 | "pad_token_id" : vocab_dict[data_args.pad_token]
340 | }
341 | )
342 |
343 | # create model
344 | model = AutoModelForCTC.from_pretrained(
345 | model_args.model_name_or_path,
346 | cache_dir=model_args.cache_dir,
347 | config=config,
348 | )
349 |
350 |     # Freeze the feature extractor (CNN feature encoder)
351 | if model_args.freeze_feature_extractor:
352 | model.freeze_feature_extractor()
353 |
354 | # Create a single processor
355 | if is_main_process(training_args.local_rank):
356 | # save feature extractor, tokenizer and config
357 | feature_extractor.save_pretrained(training_args.output_dir)
358 | tokenizer.save_pretrained(training_args.output_dir)
359 | config.save_pretrained(training_args.output_dir)
360 |
361 | # Load processor
362 | try:
363 | processor = AutoProcessor.from_pretrained(training_args.output_dir)
364 | except (OSError, KeyError):
365 | warnings.warn(
366 | "Loading a processor from a feature extractor config that does not"
367 | " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
368 | " attribute to your `preprocessor_config.json` file to suppress this warning: "
369 | " `'processor_class': 'Wav2Vec2Processor'`",
370 | FutureWarning,
371 | )
372 | processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
373 |
374 | # Custom data collator
375 | data_collator = DataCollatorCTCWithPadding(processor=processor)
376 |
377 | # Define evaluation metrics during training, *i.e.* word error rate, character error rate
378 | eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
379 |
380 | def compute_metrics(pred):
381 | # Prediction ids
382 | pred_ids = np.argmax(pred.predictions, axis=-1)
383 | # Convert -100 back to padding token id
384 | pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
385 | # Prediction and label strings
386 | # we do not want to group tokens when computing the metrics
387 | pred_str = tokenizer.batch_decode(pred_ids)
388 | label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
389 |         # Calculate the metrics
390 | metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
391 | return metrics
392 |
393 | # Load the dataset from your local files.
394 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
395 | datasets = load_dataset("csv", data_files=data_files,
396 | delimiter=data_args.delimiter, cache_dir=model_args.cache_dir)
397 |
398 | # Function to load audio and resample
399 | def load_audio_and_resample(path):
400 | # Load audio
401 | audio_array, sampling_rate = torchaudio.load(path)
402 | # Resample
403 | resampler = torchaudio.transforms.Resample(sampling_rate, sampling_rate_target)
404 | audio_array = resampler(audio_array).squeeze().numpy()
405 | return audio_array
406 |
407 | # Function to prepare inputs & targets
408 | def prepare_dataset(sample):
409 | # Load audio
410 | audio_array = load_audio_and_resample(sample[data_args.audio_path_column_name])
411 | # Input features
412 | inputs = feature_extractor(audio_array, sampling_rate=sampling_rate_target)
413 | sample["input_values"] = inputs.input_values[0]
414 | sample["input_length"] = len(sample["input_values"])
415 | # Encode targets
416 | sample["labels"] = tokenizer(sample[data_args.text_column_name]).input_ids
417 | return sample
418 |
419 | # Max & min input length for sample rate & max duration
420 | max_input_length = data_args.max_duration_in_seconds * sampling_rate_target
421 | min_input_length = data_args.min_duration_in_seconds * sampling_rate_target
422 |
423 | with training_args.main_process_first(desc="dataset map preprocessing"):
424 | # Prepare input features and targets
425 | datasets = datasets.map(
426 | prepare_dataset,
427 | remove_columns=[data_args.audio_path_column_name, data_args.text_column_name],
428 | desc="preprocess datasets",
429 | num_proc=data_args.preprocessing_num_workers
430 | )
431 |
432 | # Filter data samples based on length
433 | def is_audio_in_length_range(length):
434 | return length > min_input_length and length < max_input_length
435 |
436 | datasets = datasets.filter(
437 | is_audio_in_length_range,
438 | num_proc=data_args.preprocessing_num_workers,
439 | input_columns=["input_length"],
440 | )
441 |
442 | # Initialize Trainer
443 | trainer = Trainer(
444 | model=model,
445 | data_collator=data_collator,
446 | args=training_args,
447 | compute_metrics=compute_metrics,
448 | train_dataset=datasets["train"],
449 | eval_dataset=datasets["validation"] if data_args.validation_file else None,
450 | tokenizer=feature_extractor,
451 | )
452 |
453 | # Training
454 | if training_args.do_train:
455 |         # Use the last checkpoint if it exists
456 | if last_checkpoint is not None:
457 | checkpoint = last_checkpoint
458 | elif os.path.isdir(model_args.model_name_or_path):
459 | checkpoint = model_args.model_name_or_path
460 | else:
461 | checkpoint = None
462 |
463 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
464 | trainer.save_model()
465 |
466 | # Train metrics
467 | train_metrics = train_result.metrics
468 | train_metrics["train_samples"] = len(datasets["train"])
469 | trainer.log_metrics("train", train_metrics)
470 | trainer.save_metrics("train", train_metrics)
471 | trainer.save_state()
472 |
473 | # Evaluation
474 | results = {}
475 | if training_args.do_eval:
476 | logger.info("*** Evaluate ***")
477 | eval_metrics = trainer.evaluate()
478 | # Evaluation metrics
479 | eval_metrics["eval_samples"] = len(datasets["validation"])
480 | trainer.log_metrics("eval", eval_metrics)
481 | trainer.save_metrics("eval", eval_metrics)
482 |
483 | if __name__ == "__main__":
484 | main()
485 |
--------------------------------------------------------------------------------