├── demo
│   └── .gitkeep
├── .gitmodules
├── duration.py
├── preprocess.py
├── prepare_wenet_data.py
├── README.md
└── finetune_whisper.py
/demo/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "wenet"]
2 | path = wenet
3 | url = https://github.com/turinaf/wenet.git
4 |
--------------------------------------------------------------------------------
/duration.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wave
3 |
4 | def get_wav_duration(file_path):
5 | with wave.open(file_path, 'r') as audio:
6 | frames = audio.getnframes()
7 | rate = audio.getframerate()
8 | duration = frames / float(rate)
9 | return duration
10 |
11 | def total_duration_in_subfolders(root_folder):
12 | total_duration = 0.0
13 | for subdir, _, files in os.walk(root_folder):
14 | for filename in files:
15 | if filename.endswith('.wav'):
16 | file_path = os.path.join(subdir, filename)
17 | total_duration += get_wav_duration(file_path)
18 | return total_duration
19 |
20 | root_folders = ["sagalee/train", "sagalee/dev", "sagalee/test"]
21 | for root_folder in root_folders:
22 | duration_in_seconds = total_duration_in_subfolders(root_folder)
23 | hours = int(duration_in_seconds // 3600)
24 | minutes = int((duration_in_seconds % 3600) // 60)
25 | seconds = int(duration_in_seconds % 60)
26 |
27 | print(f"{root_folder} duration: {hours} hours, {minutes} minutes, {seconds} seconds")
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from tqdm import tqdm
4 |
5 | root_dir = "sagalee"
6 | def remove_punc(sentence):
7 | # remove some special characters
8 | sentence = sentence.replace("è", "e").replace("ₒ", "").replace("•", "").replace("ʼ", "").replace("''", "").replace("_", " ").replace('\xa0', " ")
9 | # Remove punc while retaining apostrophe and dot in decimal numbers
   10 |     # keep intra-word apostrophes and the decimal point between digits
   11 |     sentence = re.sub(r"(?!\b'\b)(?!(?<=\d)\.(?=\d))[^\w\s]", "", sentence)
   12 |     return sentence
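   13 |
   14 | # Driver loop: clean every transcript under root_dir in place. The original
   15 | # file is truncated after remove_punc() in this dump, so this loop is an
   16 | # assumed sketch of the intended usage, not the author's exact code.
   17 | for subdir, _, files in tqdm(list(os.walk(root_dir)), desc="cleaning"):
   18 |     for filename in files:
   19 |         if filename.endswith(".txt"):
   20 |             path = os.path.join(subdir, filename)
   21 |             with open(path, "r", encoding="utf-8") as f:
   22 |                 cleaned = remove_punc(f.read().strip())
   23 |             with open(path, "w", encoding="utf-8") as f:
   24 |                 f.write(cleaned)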
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Sagalee
 2 | An open source automatic speech recognition dataset for the Oromo language.
 3 |
 4 | The paper is now available on arXiv: [Sagalee: an Open Source Automatic Speech Recognition Dataset for Oromo Language](https://arxiv.org/abs/2502.00421)
 5 | The dataset is available on OpenSLR: [openslr.org/157](https://openslr.org/157/)
6 |
7 | ## News
8 | - 🎉 [2024-12-20] Sagalee paper accepted to [ICASSP 2025](https://2025.ieeeicassp.org/) Conference
9 | - ✨ [2024-11-28] Sagalee dataset released under [CC BY-NC 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/legalcode) license.
10 |
11 | ## Training ASR on Sagalee Dataset
12 | ### Clone this Repo
13 | ```
14 | git clone https://github.com/turinaf/sagalee.git
15 | cd sagalee
16 | git submodule update --init --no-fetch
17 | ```
18 | ### Create env and install dependencies
19 | ```
20 | conda create -n wenet python=3.10
21 | conda activate wenet
22 | conda install conda-forge::sox
23 | pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
24 | ```
25 | ```
26 | cd wenet
27 | pip install -r requirements.txt
28 | ```
29 | ## Training recipes
30 |
31 | ### 1 Prepare the data
32 | Running the script `prepare_wenet_data.py` prepares the data in the required format inside `wenet/examples/sagalee/s0/data/`. It organizes the wav files and transcripts into two files: `wav.scp`, containing two tab-separated columns (`wav_id` and `wav_path`), and `text`, containing two tab-separated columns (`wav_id` and `text_label`). A minimal sketch of this step follows the examples below.
33 |
34 |
35 | `wav.scp` file:
36 | ```
37 | sagalee_SPKR232_122 sagalee/train/SPKR232/sagalee_SPKR232_122.wav
38 | sagalee_SPKR232_002 sagalee/train/SPKR232/sagalee_SPKR232_002.wav
39 | ```
40 | `text` file:
41 | ```
42 | sagalee_SPKR232_082 HOJJATAA JIRA JECHUUN KOMATE
43 | sagalee_SPKR232_093 SAMMUU KEE KEESSA HIN KAAYANI
44 | ```
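45 |
46 | For orientation, here is a minimal sketch of what this preparation amounts to (the actual logic lives in `prepare_wenet_data.py`; the helper name and output layout here are illustrative):
47 | ```
48 | import os
49 |
50 | def write_wenet_files(split_dir, out_dir):
51 |     # Pair each wav with its .txt transcript and emit wav.scp and text.
52 |     os.makedirs(out_dir, exist_ok=True)
53 |     with open(os.path.join(out_dir, "wav.scp"), "w") as wav_scp, \
54 |          open(os.path.join(out_dir, "text"), "w") as text:
55 |         for speaker in sorted(os.listdir(split_dir)):
56 |             speaker_dir = os.path.join(split_dir, speaker)
57 |             if not os.path.isdir(speaker_dir):
58 |                 continue
59 |             for fname in sorted(os.listdir(speaker_dir)):
60 |                 if not fname.endswith(".wav"):
61 |                     continue
62 |                 wav_id = fname[:-4]  # e.g. sagalee_SPKR232_122
63 |                 wav_path = os.path.join(speaker_dir, fname)
64 |                 with open(wav_path.replace(".wav", ".txt")) as f:
65 |                     label = f.read().strip()
66 |                 wav_scp.write(f"{wav_id}\t{wav_path}\n")
67 |                 text.write(f"{wav_id}\t{label}\n")
68 | ```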
45 | ### 2 Run the training
46 | After preparing the data, navigate to the directory containing `run.sh` and run the stages in order, starting from stage 1.
47 | ```
48 | cd wenet/examples/sagalee/s0
49 | ```
50 | ```
51 | bash run.sh --stage 1 --stop_stage 1
52 | bash run.sh --stage 2 --stop_stage 2
53 | bash run.sh --stage 3 --stop_stage 3
54 | bash run.sh --stage 4 --stop_stage 4
55 | bash run.sh --stage 5 --stop_stage 5
56 | ```
57 | * Stage 1: Extract global CMVN (cepstral mean and variance normalization) statistics, which are used to normalize the acoustic features.
58 | * Stage 2: Generate the label token dictionary.
59 | * Stage 3: Generate `data.list`, the data manifest WeNet requires, in JSON format (see the example after this list).
60 | * Stage 4: Training.
61 | * Stage 5: Testing the trained model.
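62 |
63 | For reference, each line of the raw-format `data.list` generated in stage 3 is a JSON object with `key`, `wav`, and `txt` fields; a Sagalee entry would look roughly like this:
64 | ```
65 | {"key": "sagalee_SPKR232_082", "wav": "sagalee/train/SPKR232/sagalee_SPKR232_082.wav", "txt": "HOJJATAA JIRA JECHUUN KOMATE"}
66 | ```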
62 | ## Finetuning Whisper model
63 | - `finetune_whisper.py` fine-tunes Whisper large-v3 on the Sagalee dataset while freezing the bottom layers of the encoder. Simply run the script to start fine-tuning; see the note on changing the model size after this list.
64 | ```
65 | python finetune_whisper.py
66 | ```
67 | - For full-parameter fine-tuning, follow these [steps](https://github.com/turinaf/wenet/blob/f4ff710f95bb30bdd898fd463f2877a504df7533/examples/aishell/whisper/README.md) from the wenet recipe.
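68 |
69 | To try a smaller checkpoint, change `model_name` at the top of `finetune_whisper.py`; the corresponding `openai/...` model is pulled from the Hugging Face hub. Smaller models have fewer encoder layers, so reduce `n_freeze_layers` accordingly, e.g.:
70 | ```
71 | model_name = "whisper-small"        # resolves to openai/whisper-small
72 | model_path = f"openai/{model_name}"
73 | n_freeze_layers = 8                 # whisper-small has 12 encoder layers
74 | ```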
68 |
69 | ## Citation
70 |
71 | ```
72 | @INPROCEEDINGS{10890761,
73 | author={Abu, Turi and Shi, Ying and Zheng, Thomas Fang and Wang, Dong},
74 | booktitle={ICASSP 2025 - 2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
75 | title={Sagalee: an Open Source Automatic Speech Recognition Dataset for Oromo Language},
76 | year={2025},
77 | volume={},
78 | number={},
79 | pages={1-5},
80 | keywords={Crowdsourcing;Error analysis;Signal processing;Phonetics;Audio recording;Acoustics;Noise measurement;Speech processing;Research and development;Automatic speech recognition;Speech Recognition;Afaan Oromo;Dataset;Speech processing},
81 | doi={10.1109/ICASSP49660.2025.10890761}}
82 | ```
83 |
84 | ## Acknowledgement
85 | The training code is adapted from [WeNet](https://github.com/wenet-e2e/wenet) and is used to train models on our custom [Sagalee](https://github.com/turinaf/Sagalee) dataset.
86 |
--------------------------------------------------------------------------------
/finetune_whisper.py:
--------------------------------------------------------------------------------
1 | from transformers import WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
2 | from datasets import load_dataset, load_from_disk
3 | import torch
4 | import os
5 | from torch.utils.data import DataLoader
  6 | from transformers import get_scheduler  # AdamW below comes from torch.optim
7 | from tqdm import tqdm
8 | import torchaudio
9 | import pandas as pd
10 | from datasets import Dataset, DatasetDict
11 | import evaluate
12 |
13 |
14 | def load_sagalee_dataset(base_dir):
15 | dataset = []
16 | for split in ["train", "dev", "test"]:
17 | split_dir = os.path.join(base_dir, split)
 18 |         for speaker_id in tqdm(os.listdir(split_dir), desc=f"loading {split}"):
19 | speaker_dir = os.path.join(split_dir, speaker_id)
20 | if os.path.isdir(speaker_dir):
21 | for file in os.listdir(speaker_dir):
22 | if file.endswith(".wav"):
23 | audio_path = os.path.join(speaker_dir, file)
24 | transcript_path = os.path.join(speaker_dir, file.replace('.wav', '.txt'))
25 | with open(transcript_path, 'r') as f:
26 | transcription = f.read().strip()
27 | dataset.append({"audio": audio_path, "text": transcription, "split": split})
28 | return dataset
29 |
30 | # Preprocessing: Tokenization and feature extraction
31 | def preprocess_function(examples):
32 | audio_path = examples["audio"]
33 | speech_array, _ = torchaudio.load(audio_path)
34 | transcription = examples["text"]
35 | tokenized_input = tokenizer(transcription, return_tensors="pt").input_ids
36 | features = processor(speech_array.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_features
37 | return {
38 | "input_features": features.squeeze(),
39 | "labels": tokenized_input.squeeze(),
40 | }
41 |
42 | # Load Whisper Model, Tokenizer, Processor
43 | model_name = "whisper-large-v3"
44 | model_path = f"openai/{model_name}"
45 | tokenizer = WhisperTokenizer.from_pretrained(model_path)
46 | processor = WhisperProcessor.from_pretrained(model_path)
47 | model = WhisperForConditionalGeneration.from_pretrained(model_path)
48 |
49 | # Freeze the bottom layers of the encoder
50 | n_freeze_layers = 20
 51 | for layer in model.model.encoder.layers[:n_freeze_layers]:
 52 |     for param in layer.parameters():
 53 |         param.requires_grad = False
54 | # freeze entire encoder
55 | # for param in model.model.encoder.parameters():
56 | # param.requires_grad = False
57 |
58 | # Dataset
59 | dataset_dir = "processed_dataset"
60 | if os.path.exists(dataset_dir+"/train") and os.path.exists(dataset_dir+"/test") and os.path.exists(dataset_dir+"/dev"):
61 | print(f"Loading processed dataset from: {dataset_dir}")
62 | train_data = load_from_disk(dataset_dir+"/train")
63 | test_data = load_from_disk(dataset_dir+"/test")
64 | dev_data = load_from_disk(dataset_dir+"/dev")
65 | else:
66 | print("Loading raw data and extracting features")
67 | # Prepare dataset and data loader
68 | base_dir = "/work103/turi/project/oasr/sagalee"
69 | dataset = load_sagalee_dataset(base_dir)
70 | dataset = Dataset.from_pandas(pd.DataFrame(dataset))
71 | dataset = DatasetDict({'train': dataset.filter(lambda x: x['split'] == 'train'), 'dev': dataset.filter(lambda x: x['split']=='dev'), 'test': dataset.filter(lambda x: x['split']=='test')})
72 | # Preprocess data
73 | train_data = dataset['train'].map(preprocess_function)
74 | dev_data = dataset['dev'].map(preprocess_function)
75 | test_data = dataset['test'].map(preprocess_function)
76 | print(f"Preprocessed dataset\n{train_data}")
77 | train_data = train_data.remove_columns(['audio','text', 'split'])
78 | dev_data = dev_data.remove_columns(['audio','text', 'split'])
79 | test_data = test_data.remove_columns(['audio','text', 'split'])
80 | print(f"Train data columns removed: \n{train_data}")
81 | # Save processed data to disk for later use
82 | train_data.save_to_disk(dataset_dir+"/train")
83 | dev_data.save_to_disk(dataset_dir+"/dev")
84 | test_data.save_to_disk(dataset_dir+"/test")
85 | print(f"Saved processed dataset to {dataset_dir}")
86 |
87 | # Custom collate function to pad sequences
88 | def collate_fn(examples):
89 | input_features = [{'input_features': item['input_features']} for item in examples]
90 | batch = processor.feature_extractor.pad(input_features, return_tensors='pt')
91 | label_features = [{'input_ids': item['labels']} for item in examples]
92 | labels_batch = processor.tokenizer.pad(label_features, return_tensors='pt')
93 | # replace padding with -100 to ignore loss correctly
94 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
95 | batch["labels"] = labels
96 | return batch
97 |
98 |
99 | # Convert the dataset into a PyTorch DataLoader
100 | train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
101 | dev_dataloader = DataLoader(dev_data, batch_size=2, collate_fn=collate_fn)
102 | test_dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn)
103 |
104 | # Optimizer and Scheduler
105 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
106 | num_epochs = 20
107 | num_training_steps = num_epochs * len(train_dataloader)  # one optimizer step per batch
108 | lr_scheduler = get_scheduler(
109 | name="linear", optimizer=optimizer, num_warmup_steps=500, num_training_steps=num_training_steps
110 | )
111 |
112 | # Training
113 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114 | print(f"\nDevice being used: {device}\n")
115 | model.to(device)
116 | model.generation_config.task = "transcribe"
117 |
118 | wer_metric = evaluate.load("wer")
119 | best_val_loss = float("inf") # Initialize with a very large value
120 | exp_path = f"exp/{model_name}-om"
121 | best_model_path = f"{exp_path}/best_model"
122 |
123 | checkpoint_path = f"{exp_path}/checkpoints"
124 | if not os.path.exists(checkpoint_path):
125 | os.makedirs(checkpoint_path)
126 |
127 | progress_bar = tqdm(range(num_training_steps))
128 |
129 | model.train()
130 | for epoch in range(num_epochs):
131 | total_loss = 0
132 | total_correct = 0
133 | total_samples = 0
134 | progress_bar.set_description(f"Epoch {epoch+1}")
135 | for batch in train_dataloader:
136 | input_features = batch["input_features"].to(device)
137 | labels = batch["labels"].to(device)
138 | # Forward pass
139 | outputs = model(input_features=input_features, labels=labels)
140 | loss = outputs.loss
141 | total_loss += loss.item()
142 |
143 | # Backward pass
144 | loss.backward()
145 | optimizer.step()
146 | lr_scheduler.step()
147 | optimizer.zero_grad()
148 | progress_bar.set_postfix(loss=loss.item())
149 | progress_bar.update(1)
150 | model.save_pretrained(f"{exp_path}/checkpoints/epoch_{epoch+1}")
151 | # Calculate average loss and accuracy for the epoch
152 | avg_loss = total_loss / len(train_dataloader)
153 | #avg_accuracy = total_correct / total_samples
154 | print(f"Epoch {epoch + 1}: Avg Loss = {avg_loss:.4f}")
155 |
156 | # Validate the model on the dev set
157 | model.eval()
158 | val_loss = 0
159 | with torch.no_grad():
160 | for batch in dev_dataloader:
161 | input_features = batch["input_features"].to(device)
162 | labels = batch["labels"].to(device)
163 |
164 | outputs = model(input_features=input_features, labels=labels)
165 | val_loss += outputs.loss.item()
166 |
167 | avg_val_loss = val_loss / len(dev_dataloader)
168 | print(f"Validation Loss: {avg_val_loss:.4f}")
169 |
170 | # Save the model if validation loss improves
171 | if avg_val_loss < best_val_loss:
172 | best_val_loss = avg_val_loss
173 | model.save_pretrained(best_model_path)
174 | print(f"Best model saved with validation loss: {avg_val_loss:.4f}")
175 |
176 | model.train() # Return to training mode after validation
177 |
178 | print(f"\n\n COMPLETED TRAINING \n")
179 |
180 | # Evaluation: Calculate Word Error Rate (WER) on the test set
181 | model = WhisperForConditionalGeneration.from_pretrained(best_model_path).to(device)  # Load the best model for evaluation
182 | model.eval()
183 |
184 | for batch in test_dataloader:
185 | input_features = batch["input_features"].to(device)
186 | labels = batch["labels"].to(device)
187 |
188 | with torch.no_grad():
189 | generated_tokens = model.generate(input_features)
190 |
191 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
192 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
193 |
194 | wer_metric.add_batch(predictions=decoded_preds, references=decoded_labels)
195 |
196 | final_wer = wer_metric.compute()
197 | print(f"Word Error Rate (WER) on Test Set: {final_wer:.2f}")
198 |
--------------------------------------------------------------------------------