├── vocab.json ├── utils.py ├── LICENSE ├── preprocess.py ├── README.md ├── eval_asr.py └── train_asr.py /vocab.json: -------------------------------------------------------------------------------- 1 | {"[PAD]": 0, "<s>": 1, "</s>": 2, "[UNK]": 3, "|": 4, "'": 5, "a": 6, "b": 7, "c": 8, "d": 9, "e": 10, "f": 11, "g": 12, "h": 13, "i": 14, "j": 15, "k": 16, "l": 17, "m": 18, "n": 19, "o": 20, "p": 21, "q": 22, "r": 23, "s": 24, "t": 25, "u": 26, "v": 27, "w": 28, "x": 29, "y": 30, "z": 31, "ç": 32, "ö": 33, "ü": 34, "ğ": 35, "ı": 36, "ş": 37} -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unicode_tr import unicode_tr 3 | 4 | chars_to_remove_regex = '[,?.!\-\;\:"“%”�—…–()]' 5 | apostrophes = "[’‘`´ʹʻʼʽʿˈ]" 6 | 7 | def normalize_text(text): 8 | 9 | # Lower the text using 'unicode_tr' 10 | # Regular lower() does not work well for the Turkish language 11 | text_norm = unicode_tr(text).lower() 12 | # Unify apostrophes 13 | text_norm = re.sub(apostrophes, "'", text_norm) 14 | # Remove pre-defined chars 15 | text_norm = re.sub(chars_to_remove_regex, "", text_norm) 16 | # Remove apostrophes at word boundaries 17 | text_norm = text_norm.replace(" '", " ") 18 | text_norm = text_norm.replace("' ", " ") 19 | # Handle hatted characters 20 | text_norm = re.sub('[â]', 'a', text_norm) 21 | text_norm = re.sub('[î]', 'i', text_norm) 22 | text_norm = re.sub('[ô]', 'o', text_norm) 23 | text_norm = re.sub('[û]', 'u', text_norm) 24 | # Handle alternate characters 25 | text_norm = re.sub('[é]', 'e', text_norm) 26 | text_norm = re.sub('[ë]', 'e', text_norm) 27 | # Remove multiple spaces 28 | text_norm = re.sub(r"\s+", " ", text_norm) 29 | 30 | return text_norm 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Muhammet Poyraz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import argparse 5 | import pandas as pd 6 | from utils import normalize_text 7 | 8 | def check_invalid_char(sentence, vocab): 9 | return any([ch not in vocab for ch in re.sub(r"\s+", "", sentence)]) 10 | 11 | def load_common_voice_corpus(path): 12 | 13 | # Load CommonVoice TSV files 14 | columns_keep = ['path', 'sentence'] 15 | df_validated = pd.read_csv(os.path.join(path, 'validated.tsv'), sep='\t')[columns_keep] 16 | df_dev = pd.read_csv(os.path.join(path, 'dev.tsv'), sep='\t')[columns_keep] 17 | df_test = pd.read_csv(os.path.join(path, 'test.tsv'), sep='\t')[columns_keep] 18 | 19 | # Train set = Validated - (Dev + Test) 20 | dev_test_paths = df_dev['path'].to_list() + df_test['path'].to_list() 21 | df_train = df_validated[~df_validated['path'].isin(dev_test_paths)].copy() 22 | 23 | # Add full paths for audio records in all splits 24 | df_train['path'] = df_train['path'].apply(lambda x: os.path.join(path,'clips',x)) 25 | df_dev['path'] = df_dev['path'].apply(lambda x: os.path.join(path,'clips',x)) 26 | df_test['path'] = df_test['path'].apply(lambda x: os.path.join(path,'clips',x)) 27 | 28 | return df_train, df_dev, df_test 29 | 30 | def load_media_speech_corpus(path): 31 | 32 | # Load MediaSpeech corpus: each transcript .txt has a matching .flac audio file 33 | ms_paths, ms_sentences = [], [] 34 | for filename in os.listdir(path): 35 | if filename.endswith('txt'): 36 | ms_paths.append(os.path.join(path, filename.replace('txt','flac'))) 37 | with open(os.path.join(path, filename)) as fp: 38 | ms_sentences.append(fp.read().strip()) 39 | 40 | df_ms = pd.DataFrame({'path': ms_paths, 'sentence': ms_sentences}) 41 | return df_ms 42 | 43 | def filter_dataset(df, vocab): 44 | # Normalize sentences 45 | df['sentence'] = df['sentence'].apply(normalize_text) 46 | # Keep samples with valid sentences only 47 | df['isInvalid'] = df['sentence'].apply(check_invalid_char, vocab=vocab) 48 | df = df[~df['isInvalid']].drop(columns=['isInvalid']) 49 | return df 50 | 51 | def main(): 52 | 53 | # Parse arguments 54 | parser = argparse.ArgumentParser(description="Preprocess Turkish speech corpora for ASR training") 55 | parser.add_argument("--vocab", type=str, required=True, help="ASR vocabulary of tokens") 56 | parser.add_argument("--cv_path", type=str, required=True, help="Path for CommonVoice TR dataset") 57 | parser.add_argument("--media_speech_path", type=str, help="Path for MediaSpeech TR dataset") 58 | parser.add_argument("--output", type=str, help="Output directory") 59 | args = parser.parse_args() 60 | 61 | # Load vocab 62 | with open(args.vocab) as fp: 63 | vocab_dict = json.load(fp) 64 | 65 | # Load CommonVoice dataset 66 | df_cv_train, df_cv_dev, df_cv_test = load_common_voice_corpus(args.cv_path) 67 | 68 | # Load MediaSpeech dataset 69 | if args.media_speech_path: 70 | df_ms = load_media_speech_corpus(args.media_speech_path) 71 | 72 | # Clean and filter datasets 73 | df_train = filter_dataset(pd.concat([df_cv_train, df_ms], ignore_index=True) 74 | if args.media_speech_path else df_cv_train, vocab_dict) 75 | df_dev = filter_dataset(df_cv_dev, vocab_dict) 76 | df_test = filter_dataset(df_cv_test, vocab_dict) 77 | 78 | # Save 79 | if args.output: 80 | df_train.to_csv(os.path.join(args.output, 'train.csv'), index=False) 81 | df_dev.to_csv(os.path.join(args.output, 'validation.csv'), index=False) 82 | df_test.to_csv(os.path.join(args.output, 'test.csv'), 
index=False) 83 | 84 | return 85 | 86 | if __name__ == "__main__": 87 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wav2vec2-turkish 2 | Turkish Automated Speech Recognition (ASR) using Facebook's Wav2vec 2.0 models 3 | 4 | ## Fine-tuned Models 5 | The following Wav2vec 2.0 models were fine-tuned during Huggingface's [Robust Speech Challenge](https://github.com/huggingface/transformers/tree/master/examples/research_projects/robust-speech-event) event: 6 | 1. [mpoyraz/wav2vec2-xls-r-300m-cv6-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv6-turkish) achieves 8.83 % WER on the Common Voice 6.1 TR test split 7 | 2. [mpoyraz/wav2vec2-xls-r-300m-cv7-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv7-turkish) achieves 8.62 % WER on the Common Voice 7 TR test split 8 | 3. [mpoyraz/wav2vec2-xls-r-300m-cv8-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv8-turkish) achieves 10.61 % WER on the Common Voice 8 TR test split 9 | 10 | ## Datasets 11 | The following open-source speech corpora are available for Turkish: 12 | 1. [Mozilla Common Voice](https://commonvoice.mozilla.org/en/datasets) 13 | 2. [MediaSpeech](https://www.openslr.org/108/) 14 | 15 | This repo contains pre-processing and training scripts for these corpora. 16 | 17 | ## Pre-processing Datasets 18 | After downloading the Turkish speech corpora above, `preprocess.py` can be used to create the dataset files for training. 19 | - The script handles the text normalization required for proper training. 20 | - The Common Voice TR corpus is handled as follows: 21 | - Train split: all samples in the `validated` split except the `dev` and `test` samples are reserved for training. 22 | - Validation split: same as the `dev` split. 23 | - Test split: same as the `test` split. 24 | - The MediaSpeech corpus is fully included in the final train split if provided. 25 | - The final dataset CSV files with `path` & `sentence` columns are saved to the output directory: `train.csv`, `validation.csv` and `test.csv` 26 | 27 | ```bash 28 | python preprocess.py \ 29 | --vocab vocab.json \ 30 | --cv_path data/cv-corpus--/tr \ 31 | --media_speech_path data/TR \ 32 | --output data \ 33 | ``` 34 | ## Training 35 | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) is a large-scale multilingual pretrained model for speech and is used for fine-tuning on the Turkish speech corpora. The exact hyperparameters used are available on the model card of each fine-tuned model on the Huggingface model hub. 
36 | 37 | An example training command: 38 | 39 | ```bash 40 | python train_asr.py \ 41 | --model_name_or_path facebook/wav2vec2-xls-r-300m \ 42 | --vocab_path vocab.json \ 43 | --train_file train_validation.csv \ 44 | --validation_file test.csv \ 45 | --output_dir exp \ 46 | --audio_path_column_name path \ 47 | --text_column_name sentence \ 48 | --preprocessing_num_workers 4 \ 49 | --dataloader_num_workers 4 \ 50 | --eval_metrics wer cer \ 51 | --freeze_feature_extractor \ 52 | --mask_time_prob 0.1 \ 53 | --mask_feature_prob 0.1 \ 54 | --attention_dropout 0.05 \ 55 | --activation_dropout 0.05 \ 56 | --feat_proj_dropout 0.05 \ 57 | --final_dropout 0.1 \ 58 | --learning_rate 2.5e-4 \ 59 | --per_device_train_batch_size 8 \ 60 | --per_device_eval_batch_size 8 \ 61 | --gradient_accumulation_steps 8 \ 62 | --num_train_epochs 20 \ 63 | --warmup_steps 500 \ 64 | --eval_steps 500 \ 65 | --save_steps 500 \ 66 | --evaluation_strategy steps \ 67 | --save_total_limit 2 \ 68 | --gradient_checkpointing \ 69 | --fp16 \ 70 | --group_by_length \ 71 | --do_train \ 72 | --do_eval \ 73 | ``` 74 | ## Evaluation 75 | The following fine-tuned models are available on the Huggingface model hub, and each has an evaluation script `eval.py` with appropriate text normalization. The commands for running evaluations are also available on the model cards; an example invocation of this repo's `eval_asr.py` script is shown after the Language Model section below. 76 | 1. [mpoyraz/wav2vec2-xls-r-300m-cv6-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv6-turkish) achieves 8.83 % WER on the Common Voice 6.1 TR test split 77 | 2. [mpoyraz/wav2vec2-xls-r-300m-cv7-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv7-turkish) achieves 8.62 % WER on the Common Voice 7 TR test split 78 | 3. [mpoyraz/wav2vec2-xls-r-300m-cv8-turkish](https://huggingface.co/mpoyraz/wav2vec2-xls-r-300m-cv8-turkish) achieves 10.61 % WER on the Common Voice 8 TR test split 79 | 80 | ## Language Model 81 | For CTC beam search decoding with shallow LM fusion, an n-gram language model is trained on Turkish Wikipedia articles using KenLM, and the [ngram-lm-wiki](https://github.com/mpoyraz/ngram-lm-wiki) repo was used to generate the arpa LM and convert it into binary format. 
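An example evaluation run with this repo's `eval_asr.py` on the Common Voice TR test split; the model id, worker count and batch size below are illustrative values, while all flags correspond to the arguments defined in `eval_asr.py`:

```bash
python eval_asr.py \
  --model_name_or_path mpoyraz/wav2vec2-xls-r-300m-cv7-turkish \
  --dataset_name common_voice \
  --dataset_config_name tr \
  --eval_split_name test \
  --preprocessing_num_workers 4 \
  --batch_size 4
```

The following is a minimal sketch of how a KenLM binary built with `ngram-lm-wiki` might be attached to a fine-tuned model so that `eval_asr.py` takes its LM decoding branch (the `hasattr(processor, 'decoder')` check). It assumes `pyctcdecode` is installed; the model directory `exp` and the LM path `lm/tr_wiki.binary` are placeholder names, and the published models on the hub already ship with an LM-enabled processor.

```python
# Sketch: bundle a KenLM n-gram with a fine-tuned Wav2Vec2 model for shallow LM fusion.
# Paths below are illustrative placeholders, not files provided by this repo.
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

processor = Wav2Vec2Processor.from_pretrained("exp")  # fine-tuned model directory (assumed)

# Order the labels by vocabulary id so they line up with the CTC logit columns
vocab = processor.tokenizer.get_vocab()
labels = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]

decoder = build_ctcdecoder(
    labels=labels,
    kenlm_model_path="lm/tr_wiki.binary",  # binary LM built with KenLM / ngram-lm-wiki (assumed path)
)
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder,
)
# Saving into the model directory lets eval_asr.py load the LM-enabled processor via AutoProcessor
processor_with_lm.save_pretrained("exp")
```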
82 | -------------------------------------------------------------------------------- /eval_asr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import argparse 4 | import torch 5 | import torchaudio 6 | from transformers import ( 7 | AutoModelForCTC, 8 | AutoProcessor, 9 | ) 10 | from datasets import DatasetDict, load_dataset, load_metric, set_caching_enabled 11 | from utils import normalize_text 12 | 13 | set_caching_enabled(False) 14 | logger = logging.getLogger(__name__) 15 | 16 | def main(): 17 | # Parse arguments 18 | parser = argparse.ArgumentParser(description="Evaluate Wav2vec ASR with CTC") 19 | parser.add_argument("--model_name_or_path", type=str, required=True, 20 | help="Path to pretrained model or model identifier from huggingface.co/models") 21 | parser.add_argument("--dataset_name", type=str, default="common_voice", 22 | help="The configuration name of the dataset to use (via the datasets library)") 23 | parser.add_argument("--dataset_config_name", type=str, default="tr", 24 | help="The configuration name of the dataset to use (via the datasets library)") 25 | parser.add_argument("--eval_split_name", type=str, default="test", 26 | help="The name of the evaluation data set split to use (via the datasets library)") 27 | parser.add_argument("--use_auth_token", action='store_true', 28 | help="Use authentication for loading the dataset (via the datasets library)") 29 | parser.add_argument("--audio_column_name", type=str, default="audio", 30 | help="The name of the dataset column containing the audio data") 31 | parser.add_argument("--text_column_name", type=str, default="sentence", 32 | help="The name of the dataset column containing the text data") 33 | parser.add_argument("--preprocessing_num_workers", type=int, default=1, 34 | help="The number of processes to use for the preprocessing") 35 | parser.add_argument("--batch_size", type=int, default=4, 36 | help="The batch size for evaluation") 37 | args = parser.parse_args() 38 | 39 | # Setup logging 40 | logging.basicConfig( 41 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 42 | datefmt="%m/%d/%Y %H:%M:%S", 43 | handlers=[logging.StreamHandler(sys.stdout)], 44 | level=logging.INFO 45 | ) 46 | 47 | # Torch device to run evaluation 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | logging.info("Using {} for evaluation".format(device)) 50 | 51 | # Wav2vec2 processor and model 52 | processor = AutoProcessor.from_pretrained(args.model_name_or_path) 53 | model = AutoModelForCTC.from_pretrained(args.model_name_or_path) 54 | model = model.to(device) 55 | 56 | # Load evaluation dataset 57 | eval_dataset = load_dataset(args.dataset_name, args.dataset_config_name, 58 | split=args.eval_split_name, use_auth_token=args.use_auth_token) 59 | logging.info("Dataset '{}' - split '{}' has {} records".format( 60 | args.dataset_name, args.eval_split_name, len(eval_dataset))) 61 | 62 | # Preprocess text and resample audio 63 | def preprocess(sample): 64 | # Normalize text 65 | text = sample[args.text_column_name] 66 | sample['text'] = normalize_text(text) 67 | # Resample audio array 68 | resampler = torchaudio.transforms.Resample( 69 | sample[args.audio_column_name]["sampling_rate"], 70 | processor.feature_extractor.sampling_rate 71 | ) 72 | array_pt = torch.from_numpy(sample[args.audio_column_name]["array"]).unsqueeze(0) 73 | sample['audio_array'] = resampler(array_pt).squeeze().numpy() 74 | return sample 75 | 76 | eval_dataset = 
eval_dataset.map( 77 | preprocess, num_proc=args.preprocessing_num_workers 78 | ) 79 | 80 | # Predict on eval dataset 81 | def predict(batch): 82 | inputs = processor(batch['audio_array'], 83 | sampling_rate=processor.feature_extractor.sampling_rate, 84 | return_tensors="pt", padding=True) 85 | # Move torch tensor to the device 86 | for k in inputs.keys(): 87 | if inputs[k] is not None and torch.is_tensor(inputs[k]): 88 | inputs[k] = inputs[k].to(device) 89 | # Predict 90 | with torch.no_grad(): 91 | logits = model(**inputs).logits 92 | # Decode with LM 93 | if hasattr(processor, 'decoder'): 94 | decode_results = processor.batch_decode(logits.cpu().numpy()) 95 | batch["pred_strings"] = decode_results.text 96 | else: # No LM 97 | pred_ids = torch.argmax(logits, dim=-1) 98 | batch["pred_strings"] = processor.batch_decode(pred_ids) 99 | return batch 100 | 101 | eval_dataset = eval_dataset.map( 102 | predict, batched=True, batch_size=args.batch_size 103 | ) 104 | 105 | # Load metrics and calculate on eval dataset 106 | wer, cer = load_metric("wer"), load_metric("cer") 107 | wer_score = wer.compute(predictions=eval_dataset["pred_strings"], references=eval_dataset["text"]) 108 | cer_score = cer.compute(predictions=eval_dataset["pred_strings"], references=eval_dataset["text"]) 109 | logging.info("WER: {:.2f} % , CER: {:.2f} %".format(100*wer_score, 100*cer_score)) 110 | 111 | if __name__ == "__main__": 112 | main() -------------------------------------------------------------------------------- /train_asr.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | import warnings 8 | from dataclasses import dataclass, field 9 | from typing import Dict, List, Optional, Union 10 | 11 | import numpy as np 12 | import torch 13 | import torchaudio 14 | import datasets 15 | from datasets import DatasetDict, load_dataset, load_metric, set_caching_enabled 16 | 17 | import transformers 18 | from transformers import ( 19 | AutoConfig, 20 | AutoFeatureExtractor, 21 | AutoModelForCTC, 22 | AutoProcessor, 23 | AutoTokenizer, 24 | HfArgumentParser, 25 | Trainer, 26 | TrainingArguments, 27 | Wav2Vec2Processor, 28 | Wav2Vec2CTCTokenizer, 29 | Wav2Vec2FeatureExtractor, 30 | set_seed, 31 | ) 32 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 33 | from transformers.utils import check_min_version 34 | from transformers.utils.versions import require_version 35 | 36 | logger = logging.getLogger(__name__) 37 | set_caching_enabled(False) 38 | 39 | def list_field(default=None, metadata=None): 40 | return field(default_factory=lambda: default, metadata=metadata) 41 | 42 | @dataclass 43 | class ModelArguments: 44 | """ 45 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
46 | """ 47 | 48 | model_name_or_path: str = field( 49 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 50 | ) 51 | vocab_path: str = field( 52 | metadata={"help": "Path to ASR vocabulary, tokens as JSON file"} 53 | ) 54 | cache_dir: Optional[str] = field( 55 | default=None, 56 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 57 | ) 58 | freeze_feature_extractor: Optional[bool] = field( 59 | default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} 60 | ) 61 | attention_dropout: Optional[float] = field( 62 | default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."} 63 | ) 64 | activation_dropout: Optional[float] = field( 65 | default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."} 66 | ) 67 | feat_proj_dropout: Optional[float] = field( 68 | default=0.0, metadata={"help": "The dropout ratio for the projected features."} 69 | ) 70 | hidden_dropout: Optional[float] = field( 71 | default=0.0, 72 | metadata={ 73 | "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler." 74 | }, 75 | ) 76 | final_dropout: Optional[float] = field( 77 | default=0.0, 78 | metadata={"help": "The dropout probability for the final projection layer."}, 79 | ) 80 | mask_time_prob: Optional[float] = field( 81 | default=0.05, 82 | metadata={ 83 | "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector" 84 | "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature" 85 | "vectors will be masked along the time axis." 86 | }, 87 | ) 88 | mask_time_length: Optional[int] = field( 89 | default=10, 90 | metadata={"help": "Length of vector span to mask along the time axis."}, 91 | ) 92 | mask_feature_prob: Optional[float] = field( 93 | default=0.0, 94 | metadata={ 95 | "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" 96 | "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." 97 | }, 98 | ) 99 | mask_feature_length: Optional[int] = field( 100 | default=10, 101 | metadata={"help": "Length of vector span to mask along the feature axis."}, 102 | ) 103 | layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."}) 104 | ctc_loss_reduction: Optional[str] = field( 105 | default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} 106 | ) 107 | 108 | 109 | @dataclass 110 | class DataTrainingArguments: 111 | """ 112 | Arguments pertaining to what data we are going to input our model for training and eval. 113 | Using `HfArgumentParser` we can turn this class 114 | into argparse arguments to be able to specify them on 115 | the command line. 
116 | """ 117 | 118 | train_file: str = field( 119 | metadata={"help": "The training data file (CSV file)."} 120 | ) 121 | validation_file: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "An optional evaluation data file to evaluate on (CSV file)."}, 124 | ) 125 | delimiter: Optional[str] = field( 126 | default=",", 127 | metadata={"help": "Specifies the character delimiting individual cells in the CSV data"}, 128 | ) 129 | audio_path_column_name: Optional[str] = field( 130 | default="path", 131 | metadata={"help": "The name of the dataset column containing the audio paths. Defaults to 'path'"}, 132 | ) 133 | text_column_name: Optional[str] = field( 134 | default="text", 135 | metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, 136 | ) 137 | overwrite_cache: bool = field( 138 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 139 | ) 140 | preprocessing_num_workers: Optional[int] = field( 141 | default=None, 142 | metadata={"help": "The number of processes to use for the preprocessing."}, 143 | ) 144 | eval_metrics: Optional[List[str]] = list_field( 145 | default=["wer"], 146 | metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"}, 147 | ) 148 | max_duration_in_seconds: Optional[float] = field( 149 | default=20.0, 150 | metadata={ 151 | "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" 152 | }, 153 | ) 154 | min_duration_in_seconds: Optional[float] = field( 155 | default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} 156 | ) 157 | unk_token: Optional[str] = field( 158 | default="[UNK]", 159 | metadata={"help": "The unk token for the tokenizer"}, 160 | ) 161 | pad_token: Optional[str] = field( 162 | default="[PAD]", 163 | metadata={"help": "The padding token for the tokenizer"}, 164 | ) 165 | word_delimiter_token: Optional[str] = field( 166 | default="|", 167 | metadata={"help": "The word delimiter token for the tokenizer"}, 168 | ) 169 | 170 | @dataclass 171 | class DataCollatorCTCWithPadding: 172 | """ 173 | Data collator that will dynamically pad the inputs received. 174 | Args: 175 | processor (:class:`~transformers.AutoProcessor`) 176 | The processor used for proccessing the data. 177 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 178 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 179 | among: 180 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 181 | sequence if provided). 182 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 183 | maximum acceptable input length for the model if that argument is not provided. 184 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 185 | different lengths). 186 | max_length (:obj:`int`, `optional`): 187 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 188 | max_length_labels (:obj:`int`, `optional`): 189 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 190 | pad_to_multiple_of (:obj:`int`, `optional`): 191 | If set will pad the sequence to a multiple of the provided value. 
192 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 193 | 7.5 (Volta). 194 | """ 195 | 196 | processor: AutoProcessor 197 | padding: Union[bool, str] = "longest" 198 | pad_to_multiple_of: Optional[int] = None 199 | pad_to_multiple_of_labels: Optional[int] = None 200 | 201 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 202 | # split inputs and labels since they have to be of different lenghts and need 203 | # different padding methods 204 | input_features = [{"input_values": feature["input_values"]} for feature in features] 205 | label_features = [{"input_ids": feature["labels"]} for feature in features] 206 | 207 | batch = self.processor.pad( 208 | input_features, 209 | padding=self.padding, 210 | pad_to_multiple_of=self.pad_to_multiple_of, 211 | return_tensors="pt", 212 | ) 213 | 214 | with self.processor.as_target_processor(): 215 | labels_batch = self.processor.pad( 216 | label_features, 217 | padding=self.padding, 218 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 219 | return_tensors="pt", 220 | ) 221 | 222 | # replace padding with -100 to ignore loss correctly 223 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 224 | 225 | batch["labels"] = labels 226 | 227 | return batch 228 | 229 | def main(): 230 | # See all possible arguments in src/transformers/training_args.py 231 | # or by passing the --help flag to this script. 232 | # We now keep distinct sets of args, for a cleaner separation of concerns. 233 | 234 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 235 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 236 | # If we pass only one argument to the script and it's the path to a json file, 237 | # let's parse it to get our arguments. 238 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 239 | else: 240 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 241 | 242 | # Detecting last checkpoint. 243 | last_checkpoint = None 244 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 245 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 246 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 247 | raise ValueError( 248 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 249 | "Use --overwrite_output_dir to overcome." 250 | ) 251 | elif last_checkpoint is not None: 252 | logger.info( 253 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 254 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
255 | ) 256 | 257 | # Setup logging 258 | logging.basicConfig( 259 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 260 | datefmt="%m/%d/%Y %H:%M:%S", 261 | handlers=[logging.StreamHandler(sys.stdout)], 262 | ) 263 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 264 | 265 | # Log on each process the small summary: 266 | logger.warning( 267 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 268 | f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 269 | ) 270 | # Set the verbosity to info of the Transformers logger (on main process only): 271 | if is_main_process(training_args.local_rank): 272 | transformers.utils.logging.set_verbosity_info() 273 | logger.info("Training/evaluation parameters %s", training_args) 274 | 275 | # Set seed before initializing model. 276 | set_seed(training_args.seed) 277 | 278 | # Load the model config 279 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 280 | 281 | # We need to make sure that only first rank saves vocabulary 282 | # make sure all processes wait until vocab is created 283 | tokenizer_name_or_path = training_args.output_dir 284 | 285 | with training_args.main_process_first(): 286 | # Load vocab from file 287 | with open(model_args.vocab_path) as fp: 288 | vocab_dict = json.load(fp) 289 | 290 | vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json") 291 | if training_args.overwrite_output_dir and os.path.isfile(vocab_file): 292 | os.remove(vocab_file) 293 | 294 | # Save vocab dict to be loaded into tokenizer 295 | if not os.path.isfile(vocab_file): 296 | with open(vocab_file, "w") as file: 297 | json.dump(vocab_dict, file) 298 | 299 | # Tokenizer args 300 | tokenizer_kwargs = { 301 | "config": config if config.tokenizer_class is not None else None, 302 | "tokenizer_type": config.model_type if config.tokenizer_class is None else None, 303 | "unk_token": data_args.unk_token, 304 | "pad_token": data_args.pad_token, 305 | "word_delimiter_token": data_args.word_delimiter_token, 306 | } 307 | 308 | # Load tokenizer from the output dir where the vocab was just saved 309 | tokenizer = AutoTokenizer.from_pretrained( 310 | tokenizer_name_or_path, 311 | **tokenizer_kwargs, 312 | ) 313 | 314 | # Load feature_extractor 315 | feature_extractor = AutoFeatureExtractor.from_pretrained( 316 | model_args.model_name_or_path, cache_dir=model_args.cache_dir, 317 | ) 318 | sampling_rate_target = feature_extractor.sampling_rate 319 | 320 | # Update config for finetuning 321 | config.update( 322 | { 323 | "feat_proj_dropout": model_args.feat_proj_dropout, 324 | "attention_dropout": model_args.attention_dropout, 325 | "hidden_dropout": model_args.hidden_dropout, 326 | "final_dropout": model_args.final_dropout, 327 | "mask_time_prob": model_args.mask_time_prob, 328 | "mask_time_length": model_args.mask_time_length, 329 | "mask_feature_prob": model_args.mask_feature_prob, 330 | "mask_feature_length": model_args.mask_feature_length, 331 | "gradient_checkpointing": training_args.gradient_checkpointing, 332 | "layerdrop": model_args.layerdrop, 333 | "ctc_loss_reduction": model_args.ctc_loss_reduction, 334 | "pad_token_id": tokenizer.pad_token_id, 335 | "vocab_size": len(tokenizer), 336 | "activation_dropout": model_args.activation_dropout, 337 | "bos_token_id" : vocab_dict["<s>"], 338 | "eos_token_id" : vocab_dict["</s>"], 339 | "pad_token_id" : vocab_dict[data_args.pad_token] 340 | } 
341 | ) 342 | 343 | # create model 344 | model = AutoModelForCTC.from_pretrained( 345 | model_args.model_name_or_path, 346 | cache_dir=model_args.cache_dir, 347 | config=config, 348 | ) 349 | 350 | # Freeze encoder 351 | if model_args.freeze_feature_extractor: 352 | model.freeze_feature_extractor() 353 | 354 | # Create a single processor 355 | if is_main_process(training_args.local_rank): 356 | # save feature extractor, tokenizer and config 357 | feature_extractor.save_pretrained(training_args.output_dir) 358 | tokenizer.save_pretrained(training_args.output_dir) 359 | config.save_pretrained(training_args.output_dir) 360 | 361 | # Load processor 362 | try: 363 | processor = AutoProcessor.from_pretrained(training_args.output_dir) 364 | except (OSError, KeyError): 365 | warnings.warn( 366 | "Loading a processor from a feature extractor config that does not" 367 | " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following " 368 | " attribute to your `preprocessor_config.json` file to suppress this warning: " 369 | " `'processor_class': 'Wav2Vec2Processor'`", 370 | FutureWarning, 371 | ) 372 | processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir) 373 | 374 | # Custom data collator 375 | data_collator = DataCollatorCTCWithPadding(processor=processor) 376 | 377 | # Define evaluation metrics during training, *i.e.* word error rate, character error rate 378 | eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics} 379 | 380 | def compute_metrics(pred): 381 | # Prediction ids 382 | pred_ids = np.argmax(pred.predictions, axis=-1) 383 | # Convert -100 back to padding token id 384 | pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id 385 | # Prediction and label strings 386 | # we do not want to group tokens when computing the metrics 387 | pred_str = tokenizer.batch_decode(pred_ids) 388 | label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False) 389 | # Calcualte the metrics 390 | metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()} 391 | return metrics 392 | 393 | # Load the dataset from your local files. 
394 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 395 | datasets = load_dataset("csv", data_files=data_files, 396 | delimiter=data_args.delimiter, cache_dir=model_args.cache_dir) 397 | 398 | # Function to load audio and resample 399 | def load_audio_and_resample(path): 400 | # Load audio 401 | audio_array, sampling_rate = torchaudio.load(path) 402 | # Resample 403 | resampler = torchaudio.transforms.Resample(sampling_rate, sampling_rate_target) 404 | audio_array = resampler(audio_array).squeeze().numpy() 405 | return audio_array 406 | 407 | # Function to prepare inputs & targets 408 | def prepare_dataset(sample): 409 | # Load audio 410 | audio_array = load_audio_and_resample(sample[data_args.audio_path_column_name]) 411 | # Input features 412 | inputs = feature_extractor(audio_array, sampling_rate=sampling_rate_target) 413 | sample["input_values"] = inputs.input_values[0] 414 | sample["input_length"] = len(sample["input_values"]) 415 | # Encode targets 416 | sample["labels"] = tokenizer(sample[data_args.text_column_name]).input_ids 417 | return sample 418 | 419 | # Max & min input length for sample rate & max duration 420 | max_input_length = data_args.max_duration_in_seconds * sampling_rate_target 421 | min_input_length = data_args.min_duration_in_seconds * sampling_rate_target 422 | 423 | with training_args.main_process_first(desc="dataset map preprocessing"): 424 | # Prepare input features and targets 425 | datasets = datasets.map( 426 | prepare_dataset, 427 | remove_columns=[data_args.audio_path_column_name, data_args.text_column_name], 428 | desc="preprocess datasets", 429 | num_proc=data_args.preprocessing_num_workers 430 | ) 431 | 432 | # Filter data samples based on length 433 | def is_audio_in_length_range(length): 434 | return length > min_input_length and length < max_input_length 435 | 436 | datasets = datasets.filter( 437 | is_audio_in_length_range, 438 | num_proc=data_args.preprocessing_num_workers, 439 | input_columns=["input_length"], 440 | ) 441 | 442 | # Initialize Trainer 443 | trainer = Trainer( 444 | model=model, 445 | data_collator=data_collator, 446 | args=training_args, 447 | compute_metrics=compute_metrics, 448 | train_dataset=datasets["train"], 449 | eval_dataset=datasets["validation"] if data_args.validation_file else None, 450 | tokenizer=feature_extractor, 451 | ) 452 | 453 | # Training 454 | if training_args.do_train: 455 | # Use last checkpoint if exist 456 | if last_checkpoint is not None: 457 | checkpoint = last_checkpoint 458 | elif os.path.isdir(model_args.model_name_or_path): 459 | checkpoint = model_args.model_name_or_path 460 | else: 461 | checkpoint = None 462 | 463 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 464 | trainer.save_model() 465 | 466 | # Train metrics 467 | train_metrics = train_result.metrics 468 | train_metrics["train_samples"] = len(datasets["train"]) 469 | trainer.log_metrics("train", train_metrics) 470 | trainer.save_metrics("train", train_metrics) 471 | trainer.save_state() 472 | 473 | # Evaluation 474 | results = {} 475 | if training_args.do_eval: 476 | logger.info("*** Evaluate ***") 477 | eval_metrics = trainer.evaluate() 478 | # Evaluation metrics 479 | eval_metrics["eval_samples"] = len(datasets["validation"]) 480 | trainer.log_metrics("eval", eval_metrics) 481 | trainer.save_metrics("eval", eval_metrics) 482 | 483 | if __name__ == "__main__": 484 | main() 485 | --------------------------------------------------------------------------------