├── demo
│   └── .gitkeep
├── .gitmodules
├── duration.py
├── preprocess.py
├── prepare_wenet_data.py
├── README.md
└── finetune_whisper.py
--------------------------------------------------------------------------------
/demo/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "wenet"]
	path = wenet
	url = https://github.com/turinaf/wenet.git
--------------------------------------------------------------------------------
/duration.py:
--------------------------------------------------------------------------------
import os
import wave

def get_wav_duration(file_path):
    # Duration in seconds = number of frames / frame rate
    with wave.open(file_path, 'r') as audio:
        frames = audio.getnframes()
        rate = audio.getframerate()
        duration = frames / float(rate)
    return duration

def total_duration_in_subfolders(root_folder):
    total_duration = 0.0
    for subdir, _, files in os.walk(root_folder):
        for filename in files:
            if filename.endswith('.wav'):
                file_path = os.path.join(subdir, filename)
                total_duration += get_wav_duration(file_path)
    return total_duration

root_folders = ["sagalee/train", "sagalee/dev", "sagalee/test"]
for root_folder in root_folders:
    duration_in_seconds = total_duration_in_subfolders(root_folder)
    hours = int(duration_in_seconds // 3600)
    minutes = int((duration_in_seconds % 3600) // 60)
    seconds = int(duration_in_seconds % 60)

    print(f"{root_folder} duration: {hours} hours, {minutes} minutes, {seconds} seconds")
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os
import re
from tqdm import tqdm

root_dir = "sagalee"
def remove_punc(sentence):
    # remove some special characters
    sentence = sentence.replace("è", "e").replace("ₒ", "").replace("•", "").replace("ʼ", "").replace("''", "").replace("_", " ").replace('\xa0', " ")
    # Remove punc while retaining apostrophe and dot in decimal numbers
    sentence = re.sub(r"(?!\b'\b)(?
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Paper is now available on arXiv: [Sagalee: an Open Source Automatic Speech Recognition Dataset for Oromo Language](https://arxiv.org/abs/2502.00421)
The dataset: [on this link](https://openslr.org/157/)

## News
- 🎉 [2024-12-20] Sagalee paper accepted to the [ICASSP 2025](https://2025.ieeeicassp.org/) conference
- ✨ [2024-11-28] Sagalee dataset released under the [CC BY-NC 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/legalcode) license.

## Training ASR on the Sagalee Dataset
### Clone this Repo
```
git clone https://github.com/turinaf/sagalee.git
cd sagalee
git submodule update --init --no-fetch
```
### Create env and install dependencies
```
conda create -n wenet python=3.10
conda activate wenet
conda install conda-forge::sox
pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
```
```
cd wenet
pip install -r requirements.txt
```
## Training recipes

### 1 Prepare the data
Running the script `prepare_wenet_data.py` prepares the data in the required format inside `wenet/examples/sagalee/s0/data/`. It organizes the wav files and transcripts into two files: `wav.scp`, containing two tab-separated columns (`wav_id` and `wav_path`), and `text`, containing two tab-separated columns (`wav_id` and `text_label`). A hypothetical sketch of such a script is included at the end of this listing.

`wav.scp` file:
```
sagalee_SPKR232_122 sagalee/train/SPKR232/sagalee_SPKR232_122.wav
sagalee_SPKR232_002 sagalee/train/SPKR232/sagalee_SPKR232_002.wav
```
`text` file:
```
sagalee_SPKR232_082 HOJJATAA JIRA JECHUUN KOMATE
sagalee_SPKR232_093 SAMMUU KEE KEESSA HIN KAAYANI
```
### 2 Run the training
After preparing the data, navigate to the directory containing `run.sh` and run the stages starting from stage 1.
```
cd wenet/examples/sagalee/s0
```
```
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
bash run.sh --stage 4 --stop_stage 4
bash run.sh --stage 5 --stop_stage 5
```
* Stage 1: extracts global CMVN (cepstral mean and variance normalization) statistics, which are used to normalize the acoustic features.
* Stage 2: generates the label token dictionary.
* Stage 3: generates the WeNet-required `data.list` file in JSON format.
* Stage 4: training.
* Stage 5: testing the trained model.
## Finetuning the Whisper model
- `finetune_whisper.py` fine-tunes Whisper large-v3 (you can change the model size) on the Sagalee dataset by freezing the bottom layers of the encoder; simply run the script to start finetuning.
```
python finetune_whisper.py
```
- For full-parameter finetuning, follow these [steps](https://github.com/turinaf/wenet/blob/f4ff710f95bb30bdd898fd463f2877a504df7533/examples/aishell/whisper/README.md) in the wenet recipe.
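
After finetuning completes, here is a minimal sketch for transcribing a single recording with the saved checkpoint (the checkpoint path comes from `finetune_whisper.py`; the wav path below is a hypothetical example):
```
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Path written by finetune_whisper.py; adjust if you changed exp_path/model_name
model = WhisperForConditionalGeneration.from_pretrained("exp/whisper-large-v3-om/best_model").eval()
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

wav, sr = torchaudio.load("sagalee/test/SPKR232/sagalee_SPKR232_001.wav")  # hypothetical file
wav = torchaudio.functional.resample(wav, sr, 16000)  # Whisper expects 16 kHz input
inputs = processor(wav.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    ids = model.generate(inputs.input_features, task="transcribe")
print(processor.batch_decode(ids, skip_special_tokens=True)[0])
```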

## Citation

```
@INPROCEEDINGS{10890761,
  author={Abu, Turi and Shi, Ying and Zheng, Thomas Fang and Wang, Dong},
  booktitle={ICASSP 2025 - 2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  title={Sagalee: an Open Source Automatic Speech Recognition Dataset for Oromo Language},
  year={2025},
  volume={},
  number={},
  pages={1-5},
  keywords={Crowdsourcing;Error analysis;Signal processing;Phonetics;Audio recording;Acoustics;Noise measurement;Speech processing;Research and development;Automatic speech recognition;Speech Recognition;Afaan Oromo;Dataset;Speech processing},
  doi={10.1109/ICASSP49660.2025.10890761}}
```

## Acknowledgement
The training code is adapted from [WeNet](https://github.com/wenet-e2e/wenet) and is used to train models on our custom [Sagalee](https://github.com/turinaf/Sagalee) dataset.
--------------------------------------------------------------------------------
/finetune_whisper.py:
--------------------------------------------------------------------------------
from transformers import WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_from_disk, Dataset, DatasetDict
import torch
import os
from torch.utils.data import DataLoader
from transformers import get_scheduler
from tqdm import tqdm
import torchaudio
import pandas as pd
import evaluate


def load_sagalee_dataset(base_dir):
    # Walk sagalee/{train,dev,test}/<speaker>/ and pair each wav with its .txt transcript
    dataset = []
    for split in ["train", "dev", "test"]:
        split_dir = os.path.join(base_dir, split)
        for speaker_id in tqdm(os.listdir(split_dir), total=len(os.listdir(split_dir)), desc=f"loading {split}"):
            speaker_dir = os.path.join(split_dir, speaker_id)
            if os.path.isdir(speaker_dir):
                for file in os.listdir(speaker_dir):
                    if file.endswith(".wav"):
                        audio_path = os.path.join(speaker_dir, file)
                        transcript_path = os.path.join(speaker_dir, file.replace('.wav', '.txt'))
                        with open(transcript_path, 'r') as f:
                            transcription = f.read().strip()
                        dataset.append({"audio": audio_path, "text": transcription, "split": split})
    return dataset

# Preprocessing: Tokenization and feature extraction
def preprocess_function(examples):
    audio_path = examples["audio"]
    speech_array, _ = torchaudio.load(audio_path)
    transcription = examples["text"]
    tokenized_input = tokenizer(transcription, return_tensors="pt").input_ids
    features = processor(speech_array.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_features
    return {
        "input_features": features.squeeze(),
        "labels": tokenized_input.squeeze(),
    }

# Load Whisper Model, Tokenizer, Processor
model_name = "whisper-large-v3"
model_path = f"openai/{model_name}"
tokenizer = WhisperTokenizer.from_pretrained(model_path)
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

# Freeze the bottom layers of the encoder (freeze each layer's parameters,
# not the layer module object itself)
n_freeze_layers = 20
for layer in model.model.encoder.layers[:n_freeze_layers]:
    for param in layer.parameters():
        param.requires_grad = False

# freeze entire encoder
# for param in model.model.encoder.parameters():
#     param.requires_grad = False

# Dataset
dataset_dir = "processed_dataset"
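# On-disk cache: after the first feature-extraction run the processed splits
# are saved under processed_dataset/ (relative to the working directory), so
# later runs reload them with load_from_disk() instead of repeating the
# expensive .map() preprocessing below.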
if os.path.exists(dataset_dir+"/train") and os.path.exists(dataset_dir+"/test") and os.path.exists(dataset_dir+"/dev"):
    print(f"Loading processed dataset from: {dataset_dir}")
    train_data = load_from_disk(dataset_dir+"/train")
    test_data = load_from_disk(dataset_dir+"/test")
    dev_data = load_from_disk(dataset_dir+"/dev")
else:
    print("Loading raw data and extracting features")
    # Prepare dataset and data loader
    base_dir = "/work103/turi/project/oasr/sagalee"
    dataset = load_sagalee_dataset(base_dir)
    dataset = Dataset.from_pandas(pd.DataFrame(dataset))
    dataset = DatasetDict({
        'train': dataset.filter(lambda x: x['split'] == 'train'),
        'dev': dataset.filter(lambda x: x['split'] == 'dev'),
        'test': dataset.filter(lambda x: x['split'] == 'test'),
    })
    # Preprocess data
    train_data = dataset['train'].map(preprocess_function)
    dev_data = dataset['dev'].map(preprocess_function)
    test_data = dataset['test'].map(preprocess_function)
    print(f"Preprocessed dataset\n{train_data}")
    train_data = train_data.remove_columns(['audio', 'text', 'split'])
    dev_data = dev_data.remove_columns(['audio', 'text', 'split'])
    test_data = test_data.remove_columns(['audio', 'text', 'split'])
    print(f"Train data columns removed: \n{train_data}")
    # Save processed data to disk for later use
    train_data.save_to_disk(dataset_dir+"/train")
    dev_data.save_to_disk(dataset_dir+"/dev")
    test_data.save_to_disk(dataset_dir+"/test")
    print(f"Saved processed dataset to {dataset_dir}")

# Custom collate function to pad sequences
def collate_fn(examples):
    input_features = [{'input_features': item['input_features']} for item in examples]
    batch = processor.feature_extractor.pad(input_features, return_tensors='pt')
    label_features = [{'input_ids': item['labels']} for item in examples]
    labels_batch = processor.tokenizer.pad(label_features, return_tensors='pt')
    # replace padding with -100 to ignore loss correctly
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    batch["labels"] = labels
    return batch


# Convert the dataset into a PyTorch DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
dev_dataloader = DataLoader(dev_data, batch_size=2, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn)

# Optimizer and Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
num_epochs = 20
num_training_steps = num_epochs * len(train_dataloader)  # one optimizer step per batch
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=500, num_training_steps=num_training_steps
)

# Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nDevice being used: {device}\n")
model.to(device)
model.generation_config.task = "transcribe"

wer_metric = evaluate.load("wer")
best_val_loss = float("inf")  # Initialize with a very large value
exp_path = f"exp/{model_name}-om"
best_model_path = f"{exp_path}/best_model"

checkpoint_path = f"{exp_path}/checkpoints"
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    total_correct = 0
    total_samples = 0
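    # Note: this is a plain FP32 loop with one optimizer step per batch; no
    # gradient clipping or accumulation is used. Adding
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) between
    # loss.backward() and optimizer.step() is a common stabilizer (an
    # assumption, not part of the original script).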
    progress_bar.set_description(f"Epoch {epoch+1}")
    for batch in train_dataloader:
        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)
        # Forward pass
        outputs = model(input_features=input_features, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.set_postfix(loss=loss.item())
        progress_bar.update(1)
    model.save_pretrained(f"{exp_path}/checkpoints/epoch_{epoch+1}")
    # Calculate average loss and accuracy for the epoch
    avg_loss = total_loss / len(train_dataloader)
    #avg_accuracy = total_correct / total_samples
    print(f"Epoch {epoch + 1}: Avg Loss = {avg_loss:.4f}")

    # Validate the model on the dev set
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_features=input_features, labels=labels)
            val_loss += outputs.loss.item()

    avg_val_loss = val_loss / len(dev_dataloader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    # Save the model if validation loss improves
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        model.save_pretrained(best_model_path)
        print(f"Best model saved with validation loss: {avg_val_loss:.4f}")

    model.train()  # Return to training mode after validation

print(f"\n\n COMPLETED TRAINING \n")

# Evaluation: Calculate Word Error Rate (WER) on the test set
model = WhisperForConditionalGeneration.from_pretrained(best_model_path).to(device)  # Load the best model for evaluation
model.eval()

for batch in test_dataloader:
    input_features = batch["input_features"].to(device)
    labels = batch["labels"].to(device)

    with torch.no_grad():
        generated_tokens = model.generate(input_features)

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    # Replace -100 (positions ignored by the loss) with pad tokens before decoding
    labels = labels.masked_fill(labels.eq(-100), tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    wer_metric.add_batch(predictions=decoded_preds, references=decoded_labels)

final_wer = wer_metric.compute()
print(f"Word Error Rate (WER) on Test Set: {final_wer:.2f}")
--------------------------------------------------------------------------------
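
The README above describes `prepare_wenet_data.py` (not included in this listing) as writing `wav.scp` and `text` manifests into `wenet/examples/sagalee/s0/data/`. Below is a minimal sketch of such a script, assuming the `sagalee/<split>/<speaker>/` layout used by `finetune_whisper.py`; every specific here is an assumption, not the original file:

```
import os

for split in ["train", "dev", "test"]:
    split_dir = os.path.join("sagalee", split)
    out_dir = os.path.join("wenet/examples/sagalee/s0/data", split)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "wav.scp"), "w") as wav_scp, \
         open(os.path.join(out_dir, "text"), "w") as text:
        for speaker in sorted(os.listdir(split_dir)):
            speaker_dir = os.path.join(split_dir, speaker)
            if not os.path.isdir(speaker_dir):
                continue
            for fname in sorted(os.listdir(speaker_dir)):
                if not fname.endswith(".wav"):
                    continue
                wav_id = fname[:-4]
                # transcript assumed to sit next to the wav as <wav_id>.txt,
                # matching the loader in finetune_whisper.py
                with open(os.path.join(speaker_dir, wav_id + ".txt")) as f:
                    label = f.read().strip()
                wav_scp.write(f"{wav_id}\t{os.path.join(speaker_dir, fname)}\n")
                text.write(f"{wav_id}\t{label}\n")
```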