├── .gitmodules
├── LICENSE
├── README.md
├── data
│   ├── arpa_phonemes
│   └── metadata_l2arctic
├── evaluate.py
├── hparams
│   ├── evaluate.yaml
│   ├── train.yaml
│   ├── train_mpl.yaml
│   └── transcribe.yaml
├── l2arctic_prepare.py
├── l2arctic_unlabeled_prepare.py
├── mpd_eval_v3.py
├── split_train_dev.py
├── train.py
├── train_mpl.py
└── transcribe.py

/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "speechbrain"]
2 | path = speechbrain
3 | url = https://github.com/Mu-Y/speechbrain
4 | branch = stable_sb
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Mu-Y
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mpl-mdd
2 |
3 | Code for our paper "[Improving Mispronunciation Detection with Wav2vec2-based Momentum Pseudo-Labeling for Accentedness and Intelligibility Assessment](https://arxiv.org/abs/2203.15937)". An audio demo is available [here](https://mu-y.github.io/speech_samples/mdd_IS22/).
4 |
5 | This repo contains code for fine-tuning a wav2vec2-based MDD model with momentum pseudo-labeling (MPL). The implementation is based on [SpeechBrain](https://github.com/speechbrain/speechbrain).
6 |
7 | ## Pull the repo
8 | ```
9 | git clone git@github.com:Mu-Y/mpl-mdd.git
10 | cd mpl-mdd
11 | git submodule update --init --recursive
12 | ```
13 |
14 | ## Install dependencies and set up env
15 | Install the requirements of SpeechBrain and a few extras.
16 | ```
17 | cd mpl-mdd/speechbrain
18 | pip install -r requirements.txt
19 | pip install textgrid transformers librosa
20 | ```
21 | Append the path to the `speechbrain` module to `PYTHONPATH`.
22 | ```
23 | export PYTHONPATH=$PYTHONPATH:<path to mpl-mdd>/speechbrain
24 | ```
25 |
26 | ## Data preparation
27 | First, download the [L2-ARCTIC](https://psi.engr.tamu.edu/l2-arctic-corpus/) dataset and unzip it.
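The prepare scripts expect the unzipped corpus root to contain one folder per speaker (the speaker IDs listed in `data/metadata_l2arctic`), each with `wav/`, `annotation/` (TextGrid), and `transcript/` subfolders; this is what `l2arctic_prepare.py` reads. A sketch of the expected layout (file names below are only illustrative):
```
<path to L2-ARCTIC folder>/
├── ABA/
│   ├── wav/           # arctic_a0001.wav, ...
│   ├── annotation/    # arctic_a0001.TextGrid, ...
│   └── transcript/    # arctic_a0001.txt, ...
├── BWC/
│   └── ...
└── TLV/
    └── ...
```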
Then run the following commands:
28 | ```
29 | # for labeled samples - get train.json and test.json
30 | python l2arctic_prepare.py <path to L2-ARCTIC folder>
31 |
32 | # for unlabeled samples - get train_unlabeled.json
33 | python l2arctic_unlabeled_prepare.py <path to L2-ARCTIC folder>
34 |
35 | # split a dev set from the training set - get train-train.json and train-dev.json
36 | python split_train_dev.py --in_json=data/train.json --out_json_train=data/train-train.json --out_json_dev=data/train-dev.json
37 | ```
38 |
39 |
40 |
41 | ## Training
42 | ### Step 1
43 | Fine-tune a pre-trained wav2vec2 model on the labeled samples.
44 | ```
45 | python train.py hparams/train.yaml
46 | ```
47 | ### Step 2
48 | Fine-tune the model from Step 1 with momentum pseudo-labeling, using both labeled and unlabeled samples.
49 | ```
50 | python train_mpl.py hparams/train_mpl.yaml
51 | ```
52 |
53 | ## Evaluate the trained model
54 | ```
55 | python evaluate.py hparams/evaluate.yaml
56 | ```
57 | This prints PER and MDD F1, and writes detailed PER and MDD result files. Note that the F1 printed here comes from an MDD evaluator that differs from the one used in the paper; the paper follows the evaluator of this prior work: https://github.com/cageyoko/CTC-Attention-Mispronunciation. To reproduce the paper's numbers, convert the predictions into the format accepted by that evaluator, which is straightforward.
58 |
59 | ## Inference with the trained model
60 | ```
61 | python transcribe.py hparams/transcribe.yaml
62 | ```
63 | By default, this command writes predictions for the L2-ARCTIC test set to a JSON file. You can change the save path in `hparams/transcribe.yaml`.
64 |
65 | ## Acknowledgements
66 | The code is adapted from several SpeechBrain recipes:
67 | https://github.com/speechbrain/speechbrain/tree/develop/recipes/TIMIT/ASR/seq2seq
68 | https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech/ASR/transformer
69 |
70 | ## Citation
71 | ```
72 | @inproceedings{yang22IS_Improving,
73 | author={Mu Yang and Kevin Hirschi and Stephen Daniel Looney and Okim Kang and John H.L. Hansen},
74 | title={{Improving Mispronunciation Detection with Wav2vec2-based Momentum Pseudo-Labeling for Accentedness and Intelligibility Assessment}},
75 | year=2022,
76 | booktitle={Proc.
Interspeech 2022}, 77 | pages={4481--4485}, 78 | doi={10.21437/Interspeech.2022-11039} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /data/arpa_phonemes: -------------------------------------------------------------------------------- 1 | aa 0 2 | ae 1 3 | ah 2 4 | ao 3 5 | aw 4 6 | ay 5 7 | b 6 8 | ch 7 9 | d 8 10 | dh 9 11 | eh 10 12 | er 11 13 | ey 12 14 | f 13 15 | g 14 16 | hh 15 17 | ih 16 18 | iy 17 19 | jh 18 20 | k 19 21 | l 20 22 | m 21 23 | n 22 24 | ng 23 25 | ow 24 26 | oy 25 27 | p 26 28 | r 27 29 | s 28 30 | sh 29 31 | t 30 32 | th 31 33 | uh 32 34 | uw 33 35 | v 34 36 | w 35 37 | y 36 38 | z 37 39 | zh 38 40 | sil 39 41 | -------------------------------------------------------------------------------- /data/metadata_l2arctic: -------------------------------------------------------------------------------- 1 | Speaker accent gender 2 | ABA EN_AB M 3 | SKA EN_AB F 4 | YBAA EN_AB M 5 | ZHAA EN_AB F 6 | BWC EN_CN M 7 | LXC EN_CN F 8 | NCC EN_CN F 9 | TXHC EN_CN M 10 | ASI EN_IN M 11 | RRBI EN_IN M 12 | SVBI EN_IN F 13 | TNI EN_IN F 14 | HJK EN_KR F 15 | HKK EN_KR M 16 | YDCK EN_KR F 17 | YKWK EN_KR M 18 | EBVS EN_SP M 19 | ERMS EN_SP M 20 | MBMPS EN_SP F 21 | NJS EN_SP F 22 | HQTV EN_VN M 23 | PNV EN_VN F 24 | THV EN_VN F 25 | TLV EN_VN M 26 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import speechbrain as sb 6 | from hyperpyyaml import load_hyperpyyaml 7 | from mpd_eval_v3 import MpdStats 8 | import librosa 9 | import json 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def make_attn_mask(wavs, wav_lens): 14 | """ 15 | wav_lens: relative lengths(i.e. 0-1) of a batch. shape: (bs, ) 16 | return a tensor of shape (bs, seq_len), representing mask on allowed positions. 17 | 1 for regular tokens, 0 for padded tokens 18 | """ 19 | abs_lens = (wav_lens*wavs.shape[1]).long() 20 | attn_mask = wavs.new(wavs.shape).zero_().long() 21 | for i in range(len(abs_lens)): 22 | attn_mask[i, :abs_lens[i]] = 1 23 | return attn_mask 24 | 25 | # Define training procedure 26 | class ASR(sb.Brain): 27 | def compute_forward(self, batch, stage): 28 | "Given an input batch it computes the phoneme probabilities." 29 | batch = batch.to(self.device) 30 | wavs, wav_lens = batch.sig 31 | # phns_bos, _ = batch.phn_encoded_bos 32 | 33 | if stage == sb.Stage.TRAIN: 34 | if hasattr(self.hparams, "augmentation"): 35 | wavs = self.hparams.augmentation(wavs, wav_lens) 36 | 37 | # some wav2vec models (e.g. large-lv60) needs attention_mask 38 | if self.modules.wav2vec2.feature_extractor.return_attention_mask: 39 | attn_mask = make_attn_mask(wavs, wav_lens) 40 | else: 41 | attn_mask = None 42 | feats = self.modules.wav2vec2(wavs, attention_mask=attn_mask) 43 | x = self.modules.enc(feats) 44 | 45 | # output layer for ctc log-probabilities 46 | logits = self.modules.ctc_lin(x) 47 | p_ctc = self.hparams.log_softmax(logits) 48 | 49 | return p_ctc, wav_lens 50 | 51 | def compute_objectives(self, predictions, batch, stage): 52 | "Given the network predictions and targets computed the NLL loss." 
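# The objective below is the CTC loss (self.hparams.ctc_cost) computed against the
# perceived training targets (batch.phn_encoded_target). Outside of TRAIN, the greedy
# CTC decode additionally feeds two metric trackers: per_metrics (hypothesis vs. target,
# i.e. PER) and mpd_metrics (hypothesis vs. the aligned canonical/perceived annotation).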
53 | 54 | p_ctc, wav_lens = predictions 55 | 56 | ids = batch.id 57 | # phns_eos, phn_lens_eos = batch.phn_encoded_eos 58 | targets, target_lens = batch.phn_encoded_target 59 | if stage != sb.Stage.TRAIN: 60 | canonicals, canonical_lens = batch.phn_encoded_canonical 61 | perceiveds, perceived_lens = batch.phn_encoded_perceived 62 | 63 | loss_ctc = self.hparams.ctc_cost(p_ctc, targets, wav_lens, target_lens) 64 | loss = loss_ctc 65 | 66 | # Record losses for posterity 67 | if stage != sb.Stage.TRAIN: 68 | # Note: sb.decoders.ctc_greedy_decode will also remove padded tokens 69 | # that is, it return a list of list with different lengths 70 | sequence = sb.decoders.ctc_greedy_decode( 71 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 72 | ) 73 | self.ctc_metrics.append(ids, p_ctc, targets, wav_lens, target_lens) 74 | 75 | self.per_metrics.append( 76 | ids=ids, 77 | predict=sequence, 78 | target=targets, 79 | predict_len=None, 80 | target_len=target_lens, 81 | ind2lab=self.label_encoder.decode_ndim, 82 | ) 83 | self.mpd_metrics.append( 84 | ids=ids, 85 | predict=sequence, 86 | canonical=canonicals, 87 | perceived=perceiveds, 88 | predict_len=None, 89 | canonical_len=canonical_lens, 90 | perceived_len=perceived_lens, 91 | ind2lab=self.label_encoder.decode_ndim, 92 | ) 93 | 94 | return loss 95 | 96 | def evaluate_batch(self, batch, stage): 97 | """Computations needed for validation/test batches""" 98 | predictions = self.compute_forward(batch, stage=stage) 99 | loss = self.compute_objectives(predictions, batch, stage=stage) 100 | return loss.detach() 101 | 102 | def on_stage_start(self, stage, epoch): 103 | "Gets called when a stage (either training, validation, test) starts." 104 | self.ctc_metrics = self.hparams.ctc_stats() 105 | if self.hparams.wav2vec2_specaug: 106 | self.modules.wav2vec2.model.config.apply_spec_augment = True 107 | 108 | if stage != sb.Stage.TRAIN: 109 | self.modules.wav2vec2.model.config.apply_spec_augment = False 110 | self.per_metrics = self.hparams.per_stats() 111 | self.mpd_metrics = MpdStats() 112 | 113 | def on_stage_end(self, stage, stage_loss, epoch): 114 | """Gets called at the end of a epoch.""" 115 | if stage == sb.Stage.TRAIN: 116 | self.train_loss = stage_loss 117 | else: 118 | per = self.per_metrics.summarize("error_rate") 119 | mpd_f1 = self.mpd_metrics.summarize("mpd_f1") 120 | 121 | if stage == sb.Stage.VALID: 122 | 123 | self.hparams.train_logger.log_stats( 124 | stats_meta={ 125 | "epoch": epoch, 126 | "lr_adam": self.adam_optimizer.param_groups[0]["lr"], 127 | "lr_wav2vec": self.wav2vec_optimizer.param_groups[0]["lr"], 128 | }, 129 | train_stats={"loss": self.train_loss}, 130 | valid_stats={ 131 | "loss": stage_loss, 132 | "ctc_loss": self.ctc_metrics.summarize("average"), 133 | "PER": per, 134 | "mpd_f1": mpd_f1 135 | }, 136 | ) 137 | self.checkpointer.save_and_keep_only( 138 | meta={"PER": per, "mpd_f1": mpd_f1}, min_keys=["PER"], max_keys=["mpd_f1"] 139 | ) 140 | 141 | if stage == sb.Stage.TEST: 142 | self.hparams.train_logger.log_stats( 143 | stats_meta={"Epoch loaded": 0}, 144 | test_stats={"loss": stage_loss, "PER": per, "mpd_f1": mpd_f1}, 145 | ) 146 | with open(self.hparams.wer_file, "w") as w: 147 | w.write("CTC loss stats:\n") 148 | self.ctc_metrics.write_stats(w) 149 | w.write("\nPER stats:\n") 150 | self.per_metrics.write_stats(w) 151 | print( 152 | "CTC and PER stats written to file", 153 | self.hparams.wer_file, 154 | ) 155 | with open(self.hparams.mpd_file, "w") as m: 156 | m.write("MPD results and stats:\n") 157 | 
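# MpdStats.write_stats (see mpd_eval_v3.print_mpd_details) writes the overall
# precision/recall/F1 plus, for each utterance, the canonical-vs-perceived and
# canonical-vs-hypothesis alignments.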
self.mpd_metrics.write_stats(m) 158 | print( 159 | "MPD results and stats written to file", 160 | self.hparams.mpd_file, 161 | ) 162 | 163 | 164 | def dataio_prep(hparams): 165 | """This function prepares the datasets to be used in the brain class. 166 | It also defines the data processing pipeline through user-defined functions.""" 167 | data_folder = hparams["data_folder_save"] 168 | # 1. Declarations: 169 | test_data = sb.dataio.dataset.DynamicItemDataset.from_json( 170 | json_path=hparams["test_annotation"], 171 | replacements={"data_root": data_folder}, 172 | ) 173 | test_data = test_data.filtered_sorted(sort_key="duration") 174 | 175 | datasets = [test_data] 176 | label_encoder = sb.dataio.encoder.CTCTextEncoder() 177 | 178 | # 2. Define audio pipeline: 179 | @sb.utils.data_pipeline.takes("wav") 180 | @sb.utils.data_pipeline.provides("sig") 181 | def audio_pipeline(wav): 182 | # sig = sb.dataio.dataio.read_audio(wav) 183 | # # sample rate change to 16000, e,g, using librosa 184 | # sig = torch.Tensor(librosa.core.load(wav, hparams["sample_rate"])[0]) 185 | # Use wav2vec processor to do normalization 186 | sig = hparams["wav2vec2"].feature_extractor( 187 | librosa.core.load(wav, hparams["sample_rate"])[0], 188 | sampling_rate=hparams["sample_rate"], 189 | ).input_values[0] 190 | sig = torch.Tensor(sig) 191 | return sig 192 | 193 | sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) 194 | 195 | # 3. Define text pipeline: 196 | @sb.utils.data_pipeline.takes("perceived_train_target", "canonical_aligned", "perceived_aligned") 197 | @sb.utils.data_pipeline.provides( 198 | "phn_list_target", 199 | "phn_encoded_list_target", 200 | "phn_encoded_target", 201 | "phn_list_canonical", 202 | "phn_encoded_list_canonical", 203 | "phn_encoded_canonical", 204 | "phn_list_perceived", 205 | "phn_encoded_list_perceived", 206 | "phn_encoded_perceived", 207 | ) 208 | def text_pipeline_test(target, canonical, perceived): 209 | phn_list_target = target.strip().split() 210 | yield phn_list_target 211 | phn_encoded_list_target = label_encoder.encode_sequence(phn_list_target) 212 | yield phn_encoded_list_target 213 | phn_encoded_target = torch.LongTensor(phn_encoded_list_target) 214 | yield phn_encoded_target 215 | phn_list_canonical = canonical.strip().split() 216 | yield phn_list_canonical 217 | phn_encoded_list_canonical = label_encoder.encode_sequence(phn_list_canonical) 218 | yield phn_encoded_list_canonical 219 | phn_encoded_canonical = torch.LongTensor(phn_encoded_list_canonical) 220 | yield phn_encoded_canonical 221 | phn_list_perceived = perceived.strip().split() 222 | yield phn_list_perceived 223 | phn_encoded_list_perceived = label_encoder.encode_sequence(phn_list_perceived) 224 | yield phn_encoded_list_perceived 225 | phn_encoded_perceived = torch.LongTensor(phn_encoded_list_perceived) 226 | yield phn_encoded_perceived 227 | 228 | sb.dataio.dataset.add_dynamic_item([test_data], text_pipeline_test) 229 | 230 | # 3. Fit encoder: 231 | # Load the label encoder 232 | lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") 233 | label_encoder.load(lab_enc_file) 234 | 235 | # 4. 
Set output: 236 | sb.dataio.dataset.set_output_keys( 237 | [test_data], 238 | ["id", "sig", "phn_encoded_target", "phn_encoded_canonical", "phn_encoded_perceived"], 239 | ) 240 | 241 | return test_data, label_encoder 242 | 243 | 244 | if __name__ == "__main__": 245 | # CLI: 246 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 247 | 248 | # Load hyperparameters file with command-line overrides 249 | with open(hparams_file) as fin: 250 | hparams = load_hyperpyyaml(fin, overrides) 251 | 252 | # Initialize ddp (useful only for multi-GPU DDP training) 253 | sb.utils.distributed.ddp_init_group(run_opts) 254 | 255 | # Create experiment directory 256 | sb.create_experiment_directory( 257 | experiment_directory=hparams["output_folder"], 258 | hyperparams_to_save=hparams_file, 259 | overrides=overrides, 260 | ) 261 | 262 | # Dataset IO prep: creating Dataset objects and proper encodings for phones 263 | test_data, label_encoder = dataio_prep(hparams) 264 | 265 | # Trainer initialization 266 | asr_brain = ASR( 267 | modules=hparams["modules"], 268 | hparams=hparams, 269 | run_opts=run_opts, 270 | checkpointer=hparams["checkpointer"], 271 | ) 272 | asr_brain.label_encoder = label_encoder 273 | 274 | # Test 275 | asr_brain.evaluate( 276 | test_data, 277 | test_loader_kwargs=hparams["test_dataloader_opts"], 278 | min_key="PER" 279 | ) 280 | -------------------------------------------------------------------------------- /hparams/evaluate.yaml: -------------------------------------------------------------------------------- 1 | # Seed needs to be set at top of yaml, before objects with parameters are made 2 | seed: 1234 3 | __set_seed: !apply:torch.manual_seed [!ref ] 4 | output_folder: !ref results/wav2vec2-base_ctc/ 5 | wer_file: !ref /wer.txt 6 | mpd_file: !ref /mpd.txt 7 | save_folder: !ref /save 8 | train_log: !ref /train_log.txt 9 | 10 | # URL for the wav2vec2 model. 
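# Note: keep this consistent with the checkpoint used in hparams/train.yaml and
# hparams/train_mpl.yaml -- the checkpointer below restores the fine-tuned enc,
# ctc_lin and wav2vec2 weights on top of this pre-trained model.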
11 | wav2vec2_hub: "facebook/wav2vec2-base" # wav2vec2-base, pre-trained only on LS 960h 12 | 13 | # Data files 14 | data_folder_save: "./data" 15 | # prepared l2arctic data 16 | test_annotation: !ref /test.json 17 | 18 | # Training parameters 19 | sample_rate: 16000 20 | 21 | # Model parameters 22 | activation: !name:torch.nn.LeakyReLU 23 | dnn_layers: 2 24 | dnn_neurons: 384 25 | freeze_wav2vec: False 26 | freeze_wav2vec_feature_extractor: True # freeze the CNN extractor in wav2vec 27 | wav2vec2_specaug: True 28 | 29 | # Outputs 30 | output_neurons: 42 # l2arctic: 40phns(sil)+err+blank=42 31 | blank_index: 0 32 | 33 | # Dataloader options 34 | test_dataloader_opts: 35 | batch_size: 1 36 | num_workers: 1 37 | 38 | enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN 39 | input_shape: [null, null, 768] 40 | activation: !ref 41 | dnn_blocks: !ref 42 | dnn_neurons: !ref 43 | 44 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 45 | source: !ref 46 | output_norm: True 47 | freeze: !ref 48 | freeze_feature_extractor: !ref 49 | save_path: !ref /wav2vec2_checkpoint 50 | 51 | ctc_lin: !new:speechbrain.nnet.linear.Linear 52 | input_size: !ref 53 | n_neurons: !ref # 39 phonemes + 1 blank 54 | 55 | log_softmax: !new:speechbrain.nnet.activations.Softmax 56 | apply_log: True 57 | 58 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 59 | blank_index: !ref 60 | 61 | model: !new:torch.nn.ModuleList 62 | - [!ref , !ref ] 63 | 64 | 65 | modules: 66 | wav2vec2: !ref 67 | enc: !ref 68 | ctc_lin: !ref 69 | 70 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 71 | checkpoints_dir: !ref 72 | recoverables: 73 | model: !ref 74 | wav2vec2: !ref 75 | 76 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 77 | save_file: !ref 78 | 79 | ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats 80 | metric: !name:speechbrain.nnet.losses.ctc_loss 81 | blank_index: !ref 82 | reduction: batch 83 | 84 | per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats 85 | -------------------------------------------------------------------------------- /hparams/train.yaml: -------------------------------------------------------------------------------- 1 | # Seed needs to be set at top of yaml, before objects with parameters are made 2 | seed: 1234 3 | __set_seed: !apply:torch.manual_seed [!ref ] 4 | output_folder: !ref results/wav2vec2-base_ctc/ 5 | wer_file: !ref /wer.txt 6 | mpd_file: !ref /mpd.txt 7 | save_folder: !ref /save 8 | train_log: !ref /train_log.txt 9 | 10 | # URL for the wav2vec2 model. 
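# Note: switching to a larger checkpoint (e.g. facebook/wav2vec2-large-lv60) also requires
# changing the enc input_shape below from 768 to 1024 (see the comment above enc); models
# whose feature extractor returns an attention mask are handled by make_attn_mask in train.py.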
11 | wav2vec2_hub: "facebook/wav2vec2-base" # wav2vec2-base, pre-trained only on LS 960h 12 | 13 | # Data files 14 | data_folder_save: "./data" 15 | # prepared l2arctic data 16 | train_annotation: !ref /train-train.json 17 | valid_annotation: !ref /train-dev.json 18 | test_annotation: !ref /test.json 19 | 20 | # Training parameters 21 | number_of_epochs: 50 22 | batch_size: 16 23 | lr: 0.0003 24 | lr_wav2vec: 0.00001 25 | sorting: ascending 26 | auto_mix_prec: False 27 | sample_rate: 16000 28 | gradient_accumulation: 2 29 | 30 | # Model parameters 31 | activation: !name:torch.nn.LeakyReLU 32 | dnn_layers: 2 33 | dnn_neurons: 384 34 | freeze_wav2vec: False 35 | freeze_wav2vec_feature_extractor: True # freeze the CNN extractor in wav2vec 36 | wav2vec2_specaug: True 37 | 38 | # Outputs 39 | output_neurons: 42 # l2arctic: 40phns(sil)+err+blank=42 40 | blank_index: 0 41 | 42 | # Dataloader options 43 | train_dataloader_opts: 44 | batch_size: !ref 45 | num_workers: !ref 46 | 47 | valid_dataloader_opts: 48 | batch_size: !ref 49 | num_workers: !ref 50 | 51 | test_dataloader_opts: 52 | batch_size: 1 53 | num_workers: 1 54 | 55 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 56 | sample_rate: !ref 57 | speeds: [95, 100, 105] 58 | 59 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 60 | limit: !ref 61 | 62 | # 1024 for wav2vec2-large, 768 for wav2vec2-base 63 | enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN 64 | input_shape: [null, null, 768] 65 | activation: !ref 66 | dnn_blocks: !ref 67 | dnn_neurons: !ref 68 | 69 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 70 | source: !ref 71 | output_norm: True 72 | freeze: !ref 73 | freeze_feature_extractor: !ref 74 | save_path: !ref /wav2vec2_checkpoint 75 | 76 | ctc_lin: !new:speechbrain.nnet.linear.Linear 77 | input_size: !ref 78 | n_neurons: !ref # 39 phonemes + 1 blank 79 | 80 | 81 | log_softmax: !new:speechbrain.nnet.activations.Softmax 82 | apply_log: True 83 | 84 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 85 | blank_index: !ref 86 | 87 | 88 | model: !new:torch.nn.ModuleList 89 | - [!ref , !ref ] 90 | 91 | adam_opt_class: !name:torch.optim.Adam 92 | lr: !ref 93 | 94 | wav2vec_opt_class: !name:torch.optim.Adam 95 | lr: !ref 96 | 97 | modules: 98 | wav2vec2: !ref 99 | enc: !ref 100 | ctc_lin: !ref 101 | 102 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 103 | checkpoints_dir: !ref 104 | recoverables: 105 | model: !ref 106 | wav2vec2: !ref 107 | counter: !ref 108 | 109 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 110 | save_file: !ref 111 | 112 | ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats 113 | metric: !name:speechbrain.nnet.losses.ctc_loss 114 | blank_index: !ref 115 | reduction: batch 116 | 117 | per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats 118 | -------------------------------------------------------------------------------- /hparams/train_mpl.yaml: -------------------------------------------------------------------------------- 1 | # Seed needs to be set at top of yaml, before objects with parameters are made 2 | seed: 1234 3 | __set_seed: !apply:torch.manual_seed [!ref ] 4 | output_folder: !ref results/wav2vec2-base_ctc/ 5 | wer_file: !ref /wer.txt 6 | mpd_file: !ref /mpd.txt 7 | save_folder: !ref /save 8 | train_log: !ref /train_log.txt 9 | 10 | # URL for the wav2vec2 model. 
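# Note: the teacher branch defined below (enc_teacher, ctc_lin_teacher, wav2vec2_teacher)
# is a copy of the corresponding student modules, so it is built from this same checkpoint.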
11 | wav2vec2_hub: "facebook/wav2vec2-base" # wav2vec2-base, pre-trained only on LS 960h 12 | 13 | # Data files 14 | data_folder_save: "./data" 15 | # prepared l2arctic data 16 | train_annotation: !ref /train-train.json 17 | unlabeled_annotation: !ref /train_unlabeled.json 18 | valid_annotation: !ref /train-dev.json 19 | test_annotation: !ref /test.json 20 | 21 | # Training parameters 22 | number_of_epochs: 100 23 | batch_size_labeled: 16 # 16 works for wav2vec2-base; 4 works for wav2vec-base-960h 24 | batch_size_unlabeled: 16 # 16 works for wav2vec2-base; 4 works for wav2vec-base-960h 25 | num_workers: 8 26 | lr: 0.0003 27 | lr_wav2vec: 0.00001 28 | sorting: random 29 | auto_mix_prec: False 30 | sample_rate: 16000 31 | gradient_accumulation: 2 32 | base_model_factor: 0.5 33 | 34 | # Model parameters 35 | activation: !name:torch.nn.LeakyReLU 36 | dnn_layers: 2 37 | dnn_neurons: 384 38 | freeze_wav2vec: False 39 | freeze_wav2vec_feature_extractor: True # freeze the CNN extractor in wav2vec 40 | wav2vec2_specaug: True 41 | 42 | # Outputs 43 | output_neurons: 42 # l2arctic: 40phns(sil)+err+blank=42 44 | blank_index: 0 45 | 46 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 47 | sample_rate: !ref 48 | speeds: [95, 100, 105] 49 | 50 | epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter 51 | limit: !ref 52 | 53 | # 1024 for wav2vec2-large, 768 for wav2vec2-base 54 | enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN 55 | input_shape: [null, null, 768] 56 | activation: !ref 57 | dnn_blocks: !ref 58 | dnn_neurons: !ref 59 | 60 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 61 | source: !ref 62 | output_norm: True 63 | freeze: !ref 64 | freeze_feature_extractor: !ref 65 | save_path: !ref /wav2vec2_checkpoint 66 | 67 | 68 | ctc_lin: !new:speechbrain.nnet.linear.Linear 69 | input_size: !ref 70 | n_neurons: !ref # 39 phonemes + 1 blank 71 | 72 | 73 | log_softmax: !new:speechbrain.nnet.activations.Softmax 74 | apply_log: True 75 | 76 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 77 | blank_index: !ref 78 | 79 | 80 | model: !new:torch.nn.ModuleList 81 | - [!ref , !ref ] 82 | 83 | adam_opt_class: !name:torch.optim.Adam 84 | lr: !ref 85 | 86 | wav2vec_opt_class: !name:torch.optim.Adam 87 | lr: !ref 88 | 89 | modules: 90 | wav2vec2: !ref 91 | enc: !ref 92 | ctc_lin: !ref 93 | 94 | ### teacher model layers 95 | enc_teacher: !copy 96 | ctc_lin_teacher: !copy 97 | wav2vec2_teacher: !copy 98 | model_teacher: !new:torch.nn.ModuleList 99 | - [!ref , !ref ] 100 | modules_teacher: 101 | wav2vec2: !ref 102 | enc: !ref 103 | ctc_lin: !ref 104 | 105 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 106 | checkpoints_dir: !ref 107 | recoverables: 108 | model: !ref 109 | wav2vec2: !ref 110 | counter: !ref 111 | 112 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 113 | save_file: !ref 114 | 115 | ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats 116 | metric: !name:speechbrain.nnet.losses.ctc_loss 117 | blank_index: !ref 118 | reduction: batch 119 | 120 | per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats 121 | -------------------------------------------------------------------------------- /hparams/transcribe.yaml: -------------------------------------------------------------------------------- 1 | # Seed needs to be set at top of yaml, before objects with parameters are made 2 | seed: 1234 3 | __set_seed: !apply:torch.manual_seed [!ref ] 4 | output_folder: !ref 
results/wav2vec2-base_ctc_mpl/ 5 | wer_file: !ref /wer.txt 6 | mpd_file: !ref /mpd.txt 7 | save_folder: !ref /save 8 | train_log: !ref /train_log.txt 9 | 10 | # URL for the wav2vec2 model. 11 | wav2vec2_hub: "facebook/wav2vec2-base" # wav2vec2-base, pre-trained only on LS 960h 12 | 13 | # Data files 14 | data_folder_save: "./data" 15 | inference_annotation: !ref /test.json 16 | inference_annotation_saved: !ref /test.pred.json 17 | 18 | # Training parameters 19 | sample_rate: 16000 20 | 21 | # Model parameters 22 | activation: !name:torch.nn.LeakyReLU 23 | dnn_layers: 2 24 | dnn_neurons: 384 25 | freeze_wav2vec: False 26 | freeze_wav2vec_feature_extractor: True # freeze the CNN extractor in wav2vec 27 | wav2vec2_specaug: False 28 | 29 | # Outputs 30 | output_neurons: 42 # TIMIT: 39phs(sil)+blank=40; l2arctic: 40phns(sil)+err+blank=42 31 | blank_index: 0 32 | 33 | inference_dataloader_opts: 34 | batch_size: 1 35 | num_workers: 1 36 | 37 | augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment 38 | sample_rate: !ref 39 | speeds: [95, 100, 105] 40 | 41 | # 1024 for wav2vec2-large, 768 for wav2vec2-base 42 | enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN 43 | input_shape: [null, null, 768] 44 | activation: !ref 45 | dnn_blocks: !ref 46 | dnn_neurons: !ref 47 | 48 | wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2 49 | source: !ref 50 | output_norm: True 51 | freeze: !ref 52 | freeze_feature_extractor: !ref 53 | save_path: !ref /wav2vec2_checkpoint 54 | 55 | 56 | ctc_lin: !new:speechbrain.nnet.linear.Linear 57 | input_size: !ref 58 | n_neurons: !ref # 39 phonemes + 1 blank 59 | 60 | log_softmax: !new:speechbrain.nnet.activations.Softmax 61 | apply_log: True 62 | 63 | ctc_cost: !name:speechbrain.nnet.losses.ctc_loss 64 | blank_index: !ref 65 | 66 | model: !new:torch.nn.ModuleList 67 | - [!ref , !ref ] 68 | 69 | modules: 70 | wav2vec2: !ref 71 | enc: !ref 72 | ctc_lin: !ref 73 | 74 | checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer 75 | checkpoints_dir: !ref 76 | recoverables: 77 | model: !ref 78 | wav2vec2: !ref 79 | 80 | train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger 81 | save_file: !ref 82 | 83 | ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats 84 | metric: !name:speechbrain.nnet.losses.ctc_loss 85 | blank_index: !ref 86 | reduction: batch 87 | 88 | per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats 89 | -------------------------------------------------------------------------------- /l2arctic_prepare.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | from speechbrain.dataio.dataio import read_audio 5 | from glob import glob 6 | from textgrid import TextGrid, IntervalTier 7 | import re 8 | import copy 9 | from collections import defaultdict 10 | 11 | 12 | SAMPLERATE = 44100 13 | phn_set="data/arpa_phonemes" 14 | def process_arpa_phoneme(path): 15 | with open(path, 'r') as f: 16 | lines = f.readlines() 17 | arpa_phonemes= [] 18 | for line in lines: 19 | items = line.strip().split() 20 | arpa_phonemes.append(items[0]) 21 | return arpa_phonemes 22 | 23 | ARPA_PHONEMES = process_arpa_phoneme(phn_set) 24 | 25 | def prepare_l2arctic( 26 | data_folder, 27 | save_json_train="train_l2arctic.json", 28 | save_json_test="test_l2arctic.json", 29 | metadata_l2arctic="data/metadata_l2arctic", 30 | test_spks=['TLV', 'NJS', 'TNI', 'TXHC', 'ZHAA', 'YKWK',] 31 | ): 32 | if not os.path.exists(os.path.dirname(save_json_train)): 
33 | os.makedirs(os.path.dirname(save_json_train)) 34 | if not os.path.exists(os.path.dirname(save_json_test)): 35 | os.makedirs(os.path.dirname(save_json_test)) 36 | total_spks = [] 37 | with open(metadata_l2arctic, 'r') as reader: 38 | for ii, line in enumerate(reader): 39 | if ii == 0: 40 | # Skip the header 41 | continue 42 | name, dialect, gender = line.split() 43 | total_spks.append(name) 44 | train_spks = [x for x in total_spks if x not in test_spks] 45 | 46 | spks_to_solve = [ 47 | (save_json_train, train_spks), 48 | (save_json_test, test_spks) 49 | ] 50 | for split, spks in spks_to_solve: 51 | make_json(data_folder, split, spks) 52 | 53 | def make_json(data_folder, split, spks): 54 | print("Creating {}".format(split)) 55 | 56 | json_data = defaultdict(dict) 57 | for spk in spks: 58 | spk_data = get_data_from_spk(data_folder, spk) 59 | json_data.update(spk_data) 60 | with open(split, mode="w") as json_f: 61 | json.dump(json_data, json_f, indent=2) 62 | 63 | print(f"{split} successfully created!") 64 | 65 | def get_data_from_spk(data_folder, spk): 66 | wav_dir = os.path.join(data_folder, spk, 'wav') 67 | tg_dir = os.path.join(data_folder, spk, 'annotation') 68 | text_dir = os.path.join(data_folder, spk, 'transcript') 69 | 70 | spk_data = defaultdict(dict) 71 | for tg_file in glob(os.path.join(tg_dir, "*.TextGrid")): 72 | 73 | tg = TextGrid() 74 | try: 75 | tg.read(tg_file) 76 | except ValueError: 77 | continue 78 | 79 | basename = os.path.basename(tg_file).split(".")[0] 80 | wav_file = os.path.join(wav_dir, basename + ".wav") 81 | text_file = os.path.join(text_dir, basename + '.txt') 82 | utt_data = get_data_from_utt(tg, wav_file, text_file, spk) 83 | spk_data.update(utt_data) 84 | return spk_data 85 | 86 | def get_data_from_utt(tg, wav_file, text_file, spk): 87 | utt_data = {} 88 | utt_data[wav_file] = {} 89 | utt_data[wav_file]["wav"] = wav_file 90 | # Reading the signal (to retrieve duration in seconds) 91 | signal = read_audio(wav_file) 92 | duration = len(signal) / SAMPLERATE 93 | utt_data[wav_file]["duration"] = duration 94 | utt_data[wav_file]["spk_id"] = spk 95 | ## To keep original human annotation, set `keep_artifical_sil=True`, `rm_repetitive_sil=False` 96 | ## this preserve the original alignment within the annotations 97 | cano_phns_align, perc_phns_align = get_phonemes(tg, keep_artificial_sil=True, rm_repetitive_sil=False) 98 | utt_data[wav_file]["canonical_aligned"] = cano_phns_align 99 | utt_data[wav_file]["perceived_aligned"] = perc_phns_align 100 | ## To get training target phones, set `keep_artifical_sil=False`, `rm_repetitive_sil=True` 101 | ## this apply some preprocessing on the perceived phones, i.e. 
rm artifical and repetitive sil 102 | _, target_phns = get_phonemes(tg, keep_artificial_sil=False, rm_repetitive_sil=True) 103 | utt_data[wav_file]["perceived_train_target"] = target_phns 104 | 105 | with open(text_file, "r") as reader: 106 | text = reader.readline() 107 | utt_data[wav_file]["wrd"] = text 108 | return utt_data 109 | 110 | def get_phonemes(tg, keep_artificial_sil=False, rm_repetitive_sil=True): 111 | phone_tier = tg.getFirst("phones") 112 | perceived_phones = normalize_tier_mark(phone_tier, "NormalizePhonePerceived", keep_artificial_sil) 113 | canonical_phones = normalize_tier_mark(phone_tier, "NormalizePhoneCanonical", keep_artificial_sil) 114 | canonical_phones = tier_to_list(canonical_phones) 115 | perceived_phones = tier_to_list(perceived_phones) 116 | if keep_artificial_sil: 117 | # when we preserve the artificial sils, the canonical phones and 118 | # perceived phones should be perfectly aligned 119 | assert len(canonical_phones) == len(perceived_phones) 120 | if rm_repetitive_sil: 121 | canonical_phones = remove_repetitive_sil(canonical_phones) 122 | perceived_phones = remove_repetitive_sil(perceived_phones) 123 | return " ".join(canonical_phones), " ".join(perceived_phones) 124 | 125 | def tier_to_list(tier): 126 | return [interval.mark for interval in tier] 127 | 128 | def get_word_bounds(word_tier, phone_tier): 129 | """ 130 | word_tier: [(minTime, maxTime, word1), (minTime, maxTime, word2), ...] 131 | phone_tier: [(minTime, maxTime, phn1), (minTime, maxTime, phn2), ...] 132 | return word_bounds: [(0, 3), (4, 7), ...], length should be the same as word_tier 133 | """ 134 | phn_interval_list = [x for x in phone_tier] 135 | word_interval_list = [x for x in word_tier] 136 | phn_idx = 0 137 | word_bounds = [] 138 | for word_idx in range(len(word_interval_list)): 139 | word_interval = word_interval_list[word_idx] 140 | bound = [] 141 | while word_interval.maxTime >= phn_interval_list[phn_idx].maxTime: 142 | bound.append(phn_idx) 143 | phn_idx += 1 144 | if phn_idx == len(phn_interval_list): 145 | break 146 | word_bounds.append(bound) 147 | 148 | word_bounds = [(x[0], x[-1]) for x in word_bounds if x != []] 149 | return word_bounds 150 | 151 | 152 | def remove_repetitive_sil(phone_list): 153 | # Filtering out consecutive silences by applying a mask with `True` marking 154 | # which sils to remove 155 | # e.g. 156 | # phone_list [ "a", "sil", "sil", "sil", "b"] 157 | # --- 158 | # create: 159 | # remove_sil_mask [False, True, True, False, False] 160 | # --- 161 | # so end result is: 162 | # phone_list ["a", "sil", "b"] 163 | 164 | remove_sil_mask = [True if x == "sil" else False for x in phone_list] 165 | 166 | for i, val in enumerate(remove_sil_mask): 167 | if val is True: 168 | if i == len(remove_sil_mask) - 1: 169 | remove_sil_mask[i] = False 170 | elif remove_sil_mask[i + 1] is False: 171 | remove_sil_mask[i] = False 172 | 173 | phone_list = [ 174 | phon for i, phon in enumerate(phone_list) if not remove_sil_mask[i] 175 | ] 176 | return phone_list 177 | 178 | def normalize_tier_mark(tier: IntervalTier, 179 | mode="NormalizePhoneCanonical", keep_artificial_sil=False) -> IntervalTier: 180 | """Normalize the marks of an IntervalTier. 181 | Refer to the code for supported modes. 182 | Args: 183 | tier: An IntervalTier object. 184 | mode: The filter function for each mark in the tier. 185 | Returns: 186 | tier: Mark-normalized tier. 
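    Example (illustrative): an interval marked "ih,iy,s" (canonical phone, perceived
        phone, error tag) yields "ih" under mode "NormalizePhoneCanonical" and "iy"
        under "NormalizePhonePerceived".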
187 | """ 188 | tier = copy.deepcopy(tier) 189 | tier_out = IntervalTier() 190 | if mode not in {"NormalizePhoneCanonical", 191 | "NormalizePhonePerceived", 192 | "NormalizePhoneAnnotation", 193 | "NormalizeWord"}: 194 | raise ValueError("Mode %s is not valid.", mode) 195 | for each_interval in tier.intervals: 196 | if mode == "NormalizePhoneCanonical": 197 | # Only keep the canonical pronunciation. 198 | p = normalize_phone(each_interval.mark, True, True, keep_artificial_sil) 199 | elif mode == "NormalizePhonePerceived": 200 | # Only keep the perceived pronunciation. 201 | p = normalize_phone(each_interval.mark, True, False, keep_artificial_sil) 202 | elif mode == "NormalizePhoneAnnotation": 203 | # Keep the annotations. 204 | p = normalize_phone(each_interval.mark, False) 205 | elif mode == "NormalizeWord": 206 | p = normalize_word(each_interval.mark) 207 | 208 | if p is None: 209 | continue 210 | if p == 'ax': 211 | p = 'ah' 212 | each_interval.mark = p 213 | assert p in ARPA_PHONEMES + ["err"], pdb.set_trace() 214 | tier_out.addInterval(each_interval) 215 | return tier_out 216 | 217 | def normalize_phone(s: str, is_rm_annotation=True, is_phoneme_canonical=True, 218 | keep_artificial_sil=False) -> str: 219 | """Normalize phoneme labels to lower case, stress-free form. 220 | This will also deal with L2-ARCTIC annotations. 221 | Args: 222 | s: A phoneme annotation. 223 | is_rm_annotation: [optional] Only return the canonical pronunciation if 224 | set to true, otherwise will keep the annotations. 225 | is_phoneme_canonical: [optional] If set to true, return canonical phoneme; otherwise 226 | return perceived phoneme. 227 | keep_artificial_sil: If true, will keep the artificial sil produced by the way L2ARCTIC was annotated. 228 | If false, will not have the sil 229 | e.g. when false, 'ah, sil, d' canonical: ah, perceived: None 230 | when true, 'ah, sil, d' canonical: ah, perceived: sil 231 | Returns: 232 | Normalized phoneme (canonical pronunciation or with annotations). 233 | """ 234 | t = s.lower() 235 | pattern = re.compile(r"[^a-z,]") 236 | parse_tag = pattern.sub("", t) 237 | if is_sil(parse_tag): 238 | return "sil" 239 | if len(parse_tag) == 0: 240 | raise ValueError("Input %s is invalid.", s) 241 | if len(parse_tag.split(",")) == 1: 242 | if parse_tag.split(",")[0] == 'ax': 243 | return 'ah' 244 | else: 245 | return parse_tag.split(",")[0] 246 | if is_rm_annotation: 247 | # This handles the L2-ARCTIC annotations, here we extract the canonical 248 | # pronunciation 249 | if keep_artificial_sil: 250 | if is_phoneme_canonical: 251 | return parse_tag.split(",")[0] 252 | else: 253 | return parse_tag.split(",")[1] 254 | elif not keep_artificial_sil: 255 | if is_phoneme_canonical: 256 | if parse_tag.split(",")[2] in ['s', 'd']: 257 | return parse_tag.split(",")[0] 258 | elif parse_tag.split(",")[2] == 'a': 259 | return None 260 | else: 261 | if parse_tag.split(",")[2] in ['s', 'a']: 262 | return parse_tag.split(",")[1] 263 | elif parse_tag.split(",")[2] == 'd': 264 | return None 265 | else: 266 | return parse_tag 267 | 268 | def is_sil(s: str) -> bool: 269 | """Test if the input string represents silence. 270 | Args: 271 | s: A phoneme label. 272 | Returns: 273 | True if is silence, otherwise False. 
274 | """ 275 | if s.lower() in {"sil", "sp", "spn", "pau", ""}: 276 | return True 277 | else: 278 | return False 279 | 280 | 281 | if __name__ == "__main__": 282 | 283 | prepare_l2arctic( 284 | data_folder=sys.argv[1], 285 | save_json_train="data/train.json", 286 | save_json_test="data/test.json", 287 | metadata_l2arctic="data/metadata_l2arctic", 288 | test_spks=['TLV', 'NJS', 'TNI', 'TXHC', 'ZHAA', 'YKWK',] 289 | ) 290 | 291 | 292 | -------------------------------------------------------------------------------- /l2arctic_unlabeled_prepare.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | from speechbrain.dataio.dataio import read_audio 5 | from glob import glob 6 | import re 7 | import copy 8 | from collections import defaultdict 9 | 10 | 11 | SAMPLERATE = 44100 12 | phn_set="data/arpa_phonemes" 13 | def process_arpa_phoneme(path): 14 | with open(path, 'r') as f: 15 | lines = f.readlines() 16 | arpa_phonemes= [] 17 | for line in lines: 18 | items = line.strip().split() 19 | arpa_phonemes.append(items[0]) 20 | return arpa_phonemes 21 | 22 | ARPA_PHONEMES = process_arpa_phoneme(phn_set) 23 | 24 | def prepare_l2arctic_unlabeled( 25 | data_folder, 26 | save_json_train="train_l2arctic_unlabeled.json", 27 | labeled_json="train_l2arctic.json", 28 | metadata_l2arctic="data/metadata_l2arctic", 29 | test_spks=['TLV', 'NJS', 'TNI', 'TXHC', 'ZHAA', 'YKWK',], 30 | ): 31 | if not os.path.exists(os.path.dirname(save_json_train)): 32 | os.makedirs(os.path.dirname(save_json_train)) 33 | total_spks = [] 34 | with open(metadata_l2arctic, 'r') as reader: 35 | for ii, line in enumerate(reader): 36 | if ii == 0: 37 | # Skip the header 38 | continue 39 | name, dialect, gender = line.split() 40 | total_spks.append(name) 41 | train_spks = [x for x in total_spks if x not in test_spks] 42 | 43 | with open(labeled_json, "r") as json_file: 44 | labeled_data = json.load(json_file) 45 | 46 | spks_to_solve = [ 47 | (save_json_train, train_spks), 48 | ] 49 | for split, spks in spks_to_solve: 50 | make_json(data_folder, split, spks, labeled_data) 51 | 52 | def make_json(data_folder, split, spks, labeled_data): 53 | """ 54 | check whether the wav is presented in labled_data 55 | we only keep those wav that were not labled. 
56 | """ 57 | print("Creating {}".format(split)) 58 | 59 | json_data = defaultdict(dict) 60 | for spk in spks: 61 | spk_data = get_data_from_spk(data_folder, spk, labeled_data) 62 | json_data.update(spk_data) 63 | with open(split, mode="w") as json_f: 64 | json.dump(json_data, json_f, indent=2) 65 | 66 | print(f"{split} successfully created!") 67 | 68 | def get_data_from_spk(data_folder, spk, labeled_data): 69 | wav_dir = os.path.join(data_folder, spk, 'wav') 70 | tg_dir = os.path.join(data_folder, spk, 'annotation') 71 | text_dir = os.path.join(data_folder, spk, 'transcript') 72 | spk_data = defaultdict(dict) 73 | for wav_file in glob(os.path.join(wav_dir, "*.wav")): 74 | if wav_file in labeled_data: 75 | continue 76 | 77 | basename = os.path.basename(wav_file).split(".")[0] 78 | text_file = os.path.join(text_dir, basename + '.txt') 79 | utt_data = get_data_from_utt(wav_file, text_file, spk) 80 | spk_data.update(utt_data) 81 | return spk_data 82 | 83 | def get_data_from_utt( wav_file, text_file, spk): 84 | utt_data = {} 85 | utt_data[wav_file] = {} 86 | utt_data[wav_file]["wav"] = wav_file 87 | # Reading the signal (to retrieve duration in seconds) 88 | signal = read_audio(wav_file) 89 | duration = len(signal) / SAMPLERATE 90 | utt_data[wav_file]["duration"] = duration 91 | utt_data[wav_file]["spk_id"] = spk 92 | 93 | with open(text_file, "r") as reader: 94 | text = reader.readline() 95 | utt_data[wav_file]["wrd"] = text 96 | return utt_data 97 | 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | prepare_l2arctic_unlabeled( 103 | data_folder=sys.argv[1], 104 | save_json_train="data/train_unlabeled.json", 105 | labeled_json="data/train.json", 106 | metadata_l2arctic="data/metadata_l2arctic", 107 | test_spks=['TLV', 'NJS', 'TNI', 'TXHC', 'ZHAA', 'YKWK',], 108 | ) 109 | 110 | 111 | -------------------------------------------------------------------------------- /mpd_eval_v3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import speechbrain as sb 6 | from speechbrain.utils.data_utils import undo_padding 7 | from speechbrain.utils.edit_distance import wer_details_for_batch 8 | from speechbrain.dataio.wer import print_alignments, _print_alignment 9 | from speechbrain.utils.metric_stats import MetricStats, ErrorRateStats 10 | 11 | EDIT_SYMBOLS = { 12 | "eq": "=", # when tokens are equal 13 | "ins": "I", 14 | "del": "D", 15 | "sub": "S", 16 | } 17 | 18 | class MpdStats(MetricStats): 19 | """Compute MDD eval metrics, adapted from speechbrain.utils.metric_stats.MetricStats 20 | see speechbrain.utils.metric_stats.MetricStats 21 | """ 22 | 23 | def __init__(self, merge_tokens=False, split_tokens=False, space_token="_"): 24 | self.clear() 25 | self.merge_tokens = merge_tokens 26 | self.split_tokens = split_tokens 27 | self.space_token = space_token 28 | 29 | def append( 30 | self, 31 | ids, 32 | predict, 33 | canonical, 34 | perceived, 35 | predict_len=None, 36 | canonical_len=None, 37 | perceived_len=None, 38 | ind2lab=None, 39 | ): 40 | self.ids.extend(ids) 41 | 42 | if predict_len is not None: 43 | predict = undo_padding(predict, predict_len) 44 | 45 | if canonical_len is not None: 46 | canonical = undo_padding(canonical, canonical_len) 47 | if perceived_len is not None: 48 | perceived = undo_padding(perceived, perceived_len) 49 | 50 | if ind2lab is not None: 51 | predict = ind2lab(predict) 52 | canonical = ind2lab(canonical) 53 | perceived = ind2lab(perceived) 54 | 55 | if 
self.merge_tokens: 56 | predict = merge_char(predict, space=self.space_token) 57 | target = merge_char(target, space=self.space_token) 58 | 59 | if self.split_tokens: 60 | predict = split_word(predict, space=self.space_token) 61 | target = split_word(target, space=self.space_token) 62 | 63 | ## remove parallel sil in cano and perc 64 | canonical, perceived = rm_parallel_sil_batch(canonical, perceived) 65 | assert len(canonical) == len(perceived) # make sure cano and perc are aligned 66 | 67 | ## remove all sil in hyp 68 | predict = [[x for x in y if x!= "sil"] for y in predict] 69 | 70 | 71 | alignments = [extract_alignment(c, p) for c, p in zip(canonical, perceived)] 72 | wer_details = wer_details_for_batch(ids=ids, 73 | refs=[[s for s in c if s != "sil"] for c in canonical], 74 | hyps=predict, 75 | compute_alignments=True) 76 | ## let's be clear about the two alignments' names, rename the keys 77 | for a, p, det in zip(alignments, perceived, wer_details): 78 | det["alignment_cano2hyp"] = det.pop("alignment") 79 | det["canonical"] = det.pop("ref_tokens") 80 | det["hypothesis"] = det.pop("hyp_tokens") 81 | det.update({"alignment_cano2perc": a}) 82 | det.update({"perceived": [s for s in p if s != "sil"]}) 83 | 84 | 85 | self.scores.extend(wer_details) 86 | 87 | def summarize(self, field=None): 88 | """Summarize the error_rate and return relevant statistics. 89 | * See MetricStats.summarize() 90 | """ 91 | # self.summary = wer_summary(self.scores) 92 | self.summary = mpd_summary(self.scores) 93 | 94 | # Add additional, more generic key 95 | self.summary["mpd_f1"] = self.summary["f1"] 96 | 97 | if field is not None: 98 | return self.summary[field] 99 | else: 100 | return self.summary 101 | 102 | def write_stats(self, filestream): 103 | """Write all relevant info (e.g., error rate alignments) to file. 
104 | * See MetricStats.write_stats() 105 | """ 106 | if not self.summary: 107 | self.summarize() 108 | 109 | print_mpd_details(self.scores, self.summary, filestream) 110 | 111 | 112 | def mpd_eval_on_dataset(in_json, mpd_file=sys.stdout, per_file=None): 113 | 114 | if per_file: 115 | error_rate_stats = ErrorRateStats() 116 | total_wer_details = [] 117 | 118 | for wav_id, wav_data in in_json.items(): 119 | cano_phns = wav_data["canonical_phn"].split() 120 | perc_phns = wav_data["phn"].split() 121 | cano_phns, perc_phns = rm_parallel_sil(cano_phns, perc_phns) 122 | assert len(cano_phns) == len(perc_phns) 123 | 124 | alignment = extract_alignment(cano_phns, perc_phns) 125 | 126 | 127 | hyp = [s for s in wav_data["hyp"].split() if s!= "sil"] 128 | # hyp = wav_data["hyp"].split() 129 | wer_details = wer_details_for_batch(ids=[wav_id], 130 | refs=[[s for s in cano_phns if s != "sil"]], 131 | hyps=[hyp], 132 | compute_alignments=True)[0] 133 | ## let's be clear about the two alignments' names, rename the keys 134 | wer_details["alignment_cano2hyp"] = wer_details.pop("alignment") 135 | wer_details["canonical"] = wer_details.pop("ref_tokens") 136 | wer_details["hypothesis"] = wer_details.pop("hyp_tokens") 137 | wer_details.update({"alignment_cano2perc": alignment}) 138 | wer_details.update({"perceived": [s for s in perc_phns if s != "sil"]}) 139 | wer_details.update({"wav_id": wav_id}) 140 | 141 | total_wer_details.append(wer_details) 142 | 143 | 144 | if per_file: 145 | error_rate_stats.append(ids=[wav_id], 146 | target=[[s for s in cano_phns if s != "sil"]], 147 | predict=[hyp]) 148 | 149 | if per_file: 150 | error_rate_stats.write_stats(per_file) 151 | mpd_stats = mpd_summary(total_wer_details) 152 | print_mpd_details(total_wer_details, mpd_stats, mpd_file) 153 | 154 | 155 | def mpd_summary(total_wer_details): 156 | 157 | total_ta, total_fr, total_fa, total_tr, total_cor_diag, total_err_diag = 0, 0, 0, 0, 0, 0 158 | total_ins, total_del, total_sub, total_eq = 0, 0, 0, 0 159 | for det in total_wer_details: 160 | 161 | total_ins += len([a for a in det["alignment_cano2perc"] if a[0] == "I"]) 162 | total_del += len([a for a in det["alignment_cano2perc"] if a[0] == "D"]) 163 | total_sub += len([a for a in det["alignment_cano2perc"] if a[0] == "S"]) 164 | total_eq += len([a for a in det["alignment_cano2perc"] if a[0] == "="]) 165 | 166 | 167 | ta, fr, fa, tr, cor_diag, err_diag = mpd_stats(det["alignment_cano2perc"], 168 | det["alignment_cano2hyp"], 169 | det["canonical"], 170 | det["perceived"], 171 | det["hypothesis"]) 172 | assert tr == (cor_diag + err_diag) 173 | det.update({ 174 | "ta": ta, 175 | "fr": fr, 176 | "fa": fa, 177 | "tr": tr, 178 | "cor_diag": cor_diag, 179 | "err_diag": err_diag, 180 | }) 181 | 182 | total_ta += ta 183 | total_fr += fr 184 | total_fa += fa 185 | total_tr += tr 186 | total_cor_diag += cor_diag 187 | total_err_diag += err_diag 188 | 189 | precision = 1.0*total_tr / (total_fr + total_tr) 190 | recall = 1.0*total_tr / (total_fa + total_tr) 191 | f1 = 2.0 * precision * recall / (precision + recall) 192 | return { 193 | "total_eq": total_eq, 194 | "total_sub": total_sub, 195 | "total_del": total_del, 196 | "total_ins": total_ins, 197 | "ta": total_ta, 198 | "fr": total_fr, 199 | "fa": total_fa, 200 | "tr": total_tr, 201 | "cor_diag": total_cor_diag, 202 | "err_diag": total_err_diag, 203 | "precision": precision, 204 | "recall": recall, 205 | "f1": f1 206 | } 207 | 208 | def print_mpd_details(wer_details, mpd_stats, mpd_file): 209 | 210 | 211 | print("In original 
annotation: \nTotal Eq: {}, Total Sub: {}, Total Del: {}, Total Ins: {}".format(\ 212 | mpd_stats["total_eq"], mpd_stats["total_sub"], mpd_stats["total_del"], mpd_stats["total_ins"]), file=mpd_file) 213 | print("Overall MPD results: \nTrue Accept: {}, False Rejection: {}, False Accept: {}, True Rejection: {}, Corr Diag: {}, Err Diag: {}".format(\ 214 | mpd_stats["ta"], mpd_stats["fr"], mpd_stats["fa"], mpd_stats["tr"], mpd_stats["cor_diag"], mpd_stats["err_diag"]), file=mpd_file) 215 | print("Precision: {}, Recall: {}, F1: {}".format(mpd_stats["precision"], mpd_stats["recall"], mpd_stats["f1"]), file=mpd_file) 216 | 217 | for det in wer_details: 218 | print("="*80, file=mpd_file) 219 | print(det["key"], file=mpd_file) 220 | print("Human annotation: Canonical vs Perceived:", file=mpd_file) 221 | _print_alignment(alignment=det["alignment_cano2perc"], 222 | a=det["canonical"], 223 | b=det["perceived"], 224 | file=mpd_file) 225 | 226 | print("Model Prediction: Canonical vs Hypothesis:", file=mpd_file) 227 | _print_alignment(alignment=det["alignment_cano2hyp"], 228 | a=det["canonical"], 229 | b=det["hypothesis"], 230 | file=mpd_file) 231 | print("True Accept: {}, False Rejection: {}, False Accept: {}, True Reject: {}, Corr Diag: {}, Err Diag: {}".format(\ 232 | det["ta"], det["fr"], det["fa"], det["tr"], det["cor_diag"], det["err_diag"]), file=mpd_file) 233 | 234 | 235 | 236 | def mpd_stats(align_c2p, align_c2h, c, p, h): 237 | """ 238 | schema: [(operator, idx_i(None), idx_j(None))] 239 | c: canonical 240 | p: perceived 241 | h: hypothesis 242 | """ 243 | cnt = 0 244 | ta, fr, fa, tr, cor_diag, err_diag = 0, 0, 0, 0, 0, 0 245 | # cano_len = 1 + max(x[1] for x in align_c2p) 246 | assert max(x[1] for x in align_c2p if x[1] is not None) == max(x[1] for x in align_c2h if x[1] is not None) 247 | 248 | i, j = 0, 0 249 | while i < len(align_c2p) and j < len(align_c2h): 250 | ## sub and del cases 251 | if align_c2p[i][1] is not None and \ 252 | align_c2h[j][1] is not None and \ 253 | align_c2p[i][1] == align_c2h[j][1]: 254 | assert align_c2p[i][0] != EDIT_SYMBOLS["ins"] 255 | assert align_c2h[j][0] != EDIT_SYMBOLS["ins"] 256 | if align_c2p[i][0] == EDIT_SYMBOLS["eq"]: 257 | ## canonical cases 258 | if align_c2h[j][0] == EDIT_SYMBOLS["eq"]: 259 | ta += 1 260 | else: 261 | fr += 1 262 | elif align_c2p[i][0] != EDIT_SYMBOLS["eq"]: 263 | ## mispronunciation cases 264 | if align_c2h[j][0] == EDIT_SYMBOLS["eq"]: 265 | fa += 1 266 | else: 267 | tr += 1 268 | if align_c2p[i][0] != align_c2h[j][0]: 269 | err_diag += 1 270 | elif align_c2p[i][0] == EDIT_SYMBOLS["del"] and align_c2h[j][0] == EDIT_SYMBOLS["del"]: 271 | cor_diag += 1 272 | elif align_c2p[i][0] == EDIT_SYMBOLS["sub"] and align_c2h[j][0] == EDIT_SYMBOLS["sub"]: 273 | if p[align_c2p[i][2]] == h[align_c2h[j][2]]: 274 | cor_diag += 1 275 | else: 276 | err_diag += 1 277 | i += 1 278 | j += 1 279 | ## ins cases 280 | elif align_c2p[i][1] is None and \ 281 | align_c2h[j][1] is not None: 282 | fa += 1 283 | i += 1 284 | elif align_c2p[i][1] is not None and \ 285 | align_c2h[j][1] is None: 286 | fr += 1 287 | j += 1 288 | elif align_c2p[i][1] is None and align_c2h[j][1] is None: 289 | tr += 1 290 | if p[align_c2p[i][2]] == h[align_c2h[j][2]]: 291 | cor_diag += 1 292 | else: 293 | err_diag += 1 294 | i += 1 295 | j += 1 296 | if i == len(align_c2p) and j != len(align_c2h): 297 | fr += len(align_c2h[j:]) 298 | if i != len(align_c2p) and j == len(align_c2h): 299 | fa += len(align_c2p[j:]) 300 | 301 | return ta, fr, fa, tr, cor_diag, err_diag 302 | 303 | 304 | 
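# Illustrative example (not in the original source), using the default gap_token="sil":
#   extract_alignment(["ae", "sil", "t"], ["aa", "d", "t"])
#   -> [("S", 0, 0), ("I", None, 1), ("=", 1, 2)]
# i.e. a substitution, an insertion (marked by the artificial "sil" in the canonical tier),
# and a match; the indices refer to the gap-free (sil-stripped) canonical/perceived lists.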
def extract_alignment(a, b, gap_token="sil"): 305 | """ 306 | a, b are two aligned lists (i.e. same length) 307 | gap_token is the artificial token placeholder used in L2Arctic annotation. In this case is a `sil` token 308 | """ 309 | alignment = [] 310 | idx_a, idx_b = 0, 0 311 | for str_a, str_b in zip(a, b): 312 | if str_a == gap_token and str_b != gap_token: 313 | alignment.append((EDIT_SYMBOLS["ins"], None, idx_b)) 314 | idx_b += 1 315 | elif str_a != gap_token and str_b == gap_token: 316 | alignment.append((EDIT_SYMBOLS["del"], idx_a, None)) 317 | idx_a += 1 318 | elif str_a != gap_token and str_b != gap_token and str_a != str_b: 319 | alignment.append((EDIT_SYMBOLS["sub"], idx_a, idx_b)) 320 | idx_a += 1 321 | idx_b += 1 322 | else: 323 | alignment.append((EDIT_SYMBOLS["eq"], idx_a, idx_b)) 324 | idx_a += 1 325 | idx_b += 1 326 | return alignment 327 | 328 | def rm_parallel_sil_batch(canos, percs): 329 | canos_out, percs_out = [], [] 330 | assert len(canos) == len(percs) ## batch size 331 | for cano, perc in zip(canos, percs): 332 | cano, perc = rm_parallel_sil(cano, perc) 333 | canos_out.append(cano) 334 | percs_out.append(perc) 335 | return canos_out, percs_out 336 | 337 | def rm_parallel_sil(canos, percs): 338 | canos_out, percs_out = [], [] 339 | assert len(canos) == len(percs) ## aligned 340 | for cano, perc in zip(canos, percs): 341 | if (cano==perc and cano=="sil"): 342 | continue 343 | canos_out.append(cano) 344 | percs_out.append(perc) 345 | return canos_out, percs_out 346 | 347 | 348 | def main(args): 349 | with open(args.json_path, "r") as f: 350 | json_data = json.load(f) 351 | per_file = open(args.per_file, "w") 352 | mpd_file = open(args.mpd_file, "w") 353 | mpd_eval_on_dataset(json_data, mpd_file, per_file) 354 | 355 | 356 | 357 | 358 | if __name__ == "__main__": 359 | p = argparse.ArgumentParser() 360 | p.add_argument("--json_path", type=str) 361 | p.add_argument("--per_file", type=str, default=None) 362 | p.add_argument("--mpd_file", type=str, default=None) 363 | args = p.parse_args() 364 | 365 | main(args) 366 | -------------------------------------------------------------------------------- /split_train_dev.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import random 4 | 5 | def split_list(in_list, ratio=0.1): 6 | small = random.sample(in_list, int(ratio*len(in_list))) 7 | big = [x for x in in_list if x not in small] 8 | return big, small 9 | 10 | def split_by_speaker(in_json, ratio=0.1): 11 | spks = set(in_json[wav_id]["spk_id"] for wav_id in in_json) 12 | out_train = {} 13 | out_dev = {} 14 | for spk in spks: 15 | spk_wav_ids = [wav_id for wav_id in in_json if in_json[wav_id]["spk_id"] == spk] 16 | train, dev = split_list(spk_wav_ids, ratio) 17 | for i in train: 18 | out_train.update({i: in_json[i]}) 19 | for i in dev: 20 | out_dev.update({i: in_json[i]}) 21 | return out_train, out_dev 22 | 23 | def main(args): 24 | with open(args.in_json, "r") as f: 25 | in_data = json.load(f) 26 | 27 | out_train, out_dev = split_by_speaker(in_data) 28 | 29 | with open(args.out_json_train, "w") as f: 30 | json.dump(out_train, f, indent=2) 31 | with open(args.out_json_dev, "w") as f: 32 | json.dump(out_dev, f, indent=2) 33 | 34 | 35 | if __name__ == "__main__": 36 | p = argparse.ArgumentParser() 37 | p.add_argument("--in_json", type=str) 38 | p.add_argument("--dev_ratio", type=float, default=0.1) 39 | p.add_argument("--out_json_train", type=str) 40 | p.add_argument("--out_json_dev", type=str) 41 | args 
= p.parse_args() 42 | 43 | main(args) 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import speechbrain as sb 6 | from hyperpyyaml import load_hyperpyyaml 7 | from mpd_eval_v3 import MpdStats 8 | import librosa 9 | import json 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def make_attn_mask(wavs, wav_lens): 14 | """ 15 | wav_lens: relative lengths(i.e. 0-1) of a batch. shape: (bs, ) 16 | return a tensor of shape (bs, seq_len), representing mask on allowed positions. 17 | 1 for regular tokens, 0 for padded tokens 18 | """ 19 | abs_lens = (wav_lens*wavs.shape[1]).long() 20 | attn_mask = wavs.new(wavs.shape).zero_().long() 21 | for i in range(len(abs_lens)): 22 | attn_mask[i, :abs_lens[i]] = 1 23 | return attn_mask 24 | 25 | # Define training procedure 26 | class ASR(sb.Brain): 27 | def compute_forward(self, batch, stage): 28 | "Given an input batch it computes the phoneme probabilities." 29 | batch = batch.to(self.device) 30 | wavs, wav_lens = batch.sig 31 | # phns_bos, _ = batch.phn_encoded_bos 32 | 33 | if stage == sb.Stage.TRAIN: 34 | if hasattr(self.hparams, "augmentation"): 35 | wavs = self.hparams.augmentation(wavs, wav_lens) 36 | 37 | # some wav2vec models (e.g. large-lv60) needs attention_mask 38 | if self.modules.wav2vec2.feature_extractor.return_attention_mask: 39 | attn_mask = make_attn_mask(wavs, wav_lens) 40 | else: 41 | attn_mask = None 42 | feats = self.modules.wav2vec2(wavs, attention_mask=attn_mask) 43 | x = self.modules.enc(feats) 44 | 45 | # output layer for ctc log-probabilities 46 | logits = self.modules.ctc_lin(x) 47 | p_ctc = self.hparams.log_softmax(logits) 48 | 49 | return p_ctc, wav_lens 50 | 51 | def compute_objectives(self, predictions, batch, stage): 52 | "Given the network predictions and targets computed the NLL loss." 
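# The loss here is plain CTC on the perceived-phoneme targets. Outside of
# training, the greedy CTC decode below also feeds the PER tracker and the
# MPD (mispronunciation detection) metrics, which compare the prediction
# against the canonical and perceived phone sequences.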
53 | 54 | p_ctc, wav_lens = predictions 55 | 56 | ids = batch.id 57 | targets, target_lens = batch.phn_encoded_target 58 | if stage != sb.Stage.TRAIN: 59 | canonicals, canonical_lens = batch.phn_encoded_canonical 60 | perceiveds, perceived_lens = batch.phn_encoded_perceived 61 | 62 | loss_ctc = self.hparams.ctc_cost(p_ctc, targets, wav_lens, target_lens) 63 | loss = loss_ctc 64 | 65 | # Record losses for posterity 66 | if stage != sb.Stage.TRAIN: 67 | # Note: sb.decoders.ctc_greedy_decode will also remove padded tokens 68 | # that is, it return a list of list with different lengths 69 | sequence = sb.decoders.ctc_greedy_decode( 70 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 71 | ) 72 | self.ctc_metrics.append(ids, p_ctc, targets, wav_lens, target_lens) 73 | 74 | self.per_metrics.append( 75 | ids=ids, 76 | predict=sequence, 77 | target=targets, 78 | predict_len=None, 79 | target_len=target_lens, 80 | ind2lab=self.label_encoder.decode_ndim, 81 | ) 82 | self.mpd_metrics.append( 83 | ids=ids, 84 | predict=sequence, 85 | canonical=canonicals, 86 | perceived=perceiveds, 87 | predict_len=None, 88 | canonical_len=canonical_lens, 89 | perceived_len=perceived_lens, 90 | ind2lab=self.label_encoder.decode_ndim, 91 | ) 92 | 93 | return loss 94 | 95 | def evaluate_batch(self, batch, stage): 96 | """Computations needed for validation/test batches""" 97 | predictions = self.compute_forward(batch, stage=stage) 98 | loss = self.compute_objectives(predictions, batch, stage=stage) 99 | return loss.detach() 100 | 101 | def on_stage_start(self, stage, epoch): 102 | "Gets called when a stage (either training, validation, test) starts." 103 | self.ctc_metrics = self.hparams.ctc_stats() 104 | if self.hparams.wav2vec2_specaug: 105 | self.modules.wav2vec2.model.config.apply_spec_augment = True 106 | 107 | if stage != sb.Stage.TRAIN: 108 | self.modules.wav2vec2.model.config.apply_spec_augment = False 109 | self.per_metrics = self.hparams.per_stats() 110 | self.mpd_metrics = MpdStats() 111 | 112 | def on_stage_end(self, stage, stage_loss, epoch): 113 | """Gets called at the end of a epoch.""" 114 | if stage == sb.Stage.TRAIN: 115 | self.train_loss = stage_loss 116 | else: 117 | per = self.per_metrics.summarize("error_rate") 118 | mpd_f1 = self.mpd_metrics.summarize("mpd_f1") 119 | 120 | if stage == sb.Stage.VALID: 121 | 122 | self.hparams.train_logger.log_stats( 123 | stats_meta={ 124 | "epoch": epoch, 125 | "lr_adam": self.adam_optimizer.param_groups[0]["lr"], 126 | "lr_wav2vec": self.wav2vec_optimizer.param_groups[0]["lr"], 127 | }, 128 | train_stats={"loss": self.train_loss}, 129 | valid_stats={ 130 | "loss": stage_loss, 131 | "ctc_loss": self.ctc_metrics.summarize("average"), 132 | "PER": per, 133 | "mpd_f1": mpd_f1 134 | }, 135 | ) 136 | self.checkpointer.save_and_keep_only( 137 | meta={"PER": per, "mpd_f1": mpd_f1}, min_keys=["PER"], max_keys=["mpd_f1"] 138 | ) 139 | 140 | if stage == sb.Stage.TEST: 141 | self.hparams.train_logger.log_stats( 142 | stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, 143 | test_stats={"loss": stage_loss, "PER": per, "mpd_f1": mpd_f1}, 144 | ) 145 | with open(self.hparams.wer_file, "w") as w: 146 | w.write("CTC loss stats:\n") 147 | self.ctc_metrics.write_stats(w) 148 | w.write("\nPER stats:\n") 149 | self.per_metrics.write_stats(w) 150 | print( 151 | "CTC and PER stats written to file", 152 | self.hparams.wer_file, 153 | ) 154 | with open(self.hparams.mpd_file, "w") as m: 155 | m.write("MPD results and stats:\n") 156 | self.mpd_metrics.write_stats(m) 157 | 
print( 158 | "MPD results and stats written to file", 159 | self.hparams.mpd_file, 160 | ) 161 | 162 | 163 | def fit_batch(self, batch): 164 | """Fit one batch, override to do multiple updates. 165 | 166 | The default implementation depends on a few methods being defined 167 | with a particular behavior: 168 | 169 | * ``compute_forward()`` 170 | * ``compute_objectives()`` 171 | 172 | Also depends on having optimizers passed at initialization. 173 | 174 | Arguments 175 | --------- 176 | batch : list of torch.Tensors 177 | Batch of data to use for training. Default implementation assumes 178 | this batch has two elements: inputs and targets. 179 | 180 | Returns 181 | ------- 182 | detached loss 183 | """ 184 | # Managing automatic mixed precision 185 | if self.auto_mix_prec: 186 | 187 | self.wav2vec_optimizer.zero_grad() 188 | self.adam_optimizer.zero_grad() 189 | 190 | with torch.cuda.amp.autocast(): 191 | outputs = self.compute_forward(batch, sb.Stage.TRAIN) 192 | loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN) 193 | 194 | self.scaler.scale(loss).backward() 195 | self.scaler.unscale_(self.wav2vec_optimizer) 196 | self.scaler.unscale_(self.adam_optimizer) 197 | 198 | if self.check_gradients(loss): 199 | self.scaler.step(self.wav2vec_optimizer) 200 | self.scaler.step(self.adam_optimizer) 201 | 202 | self.scaler.update() 203 | else: 204 | outputs = self.compute_forward(batch, sb.Stage.TRAIN) 205 | 206 | loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN) 207 | # normalize the loss by gradient_accumulation step 208 | (loss / self.hparams.gradient_accumulation).backward() 209 | 210 | if self.step % self.hparams.gradient_accumulation == 0: 211 | # gradient clipping & early stop if loss is not fini 212 | if self.check_gradients(loss): 213 | self.wav2vec_optimizer.step() 214 | self.adam_optimizer.step() 215 | 216 | self.wav2vec_optimizer.zero_grad() 217 | self.adam_optimizer.zero_grad() 218 | 219 | return loss.detach().cpu() 220 | 221 | def init_optimizers(self): 222 | "Initializes the wav2vec2 optimizer and model optimizer" 223 | self.wav2vec_optimizer = self.hparams.wav2vec_opt_class( 224 | self.modules.wav2vec2.model.parameters() 225 | ) 226 | self.adam_optimizer = self.hparams.adam_opt_class( 227 | self.hparams.model.parameters() 228 | ) 229 | 230 | if self.checkpointer is not None: 231 | self.checkpointer.add_recoverable( 232 | "wav2vec_opt", self.wav2vec_optimizer 233 | ) 234 | self.checkpointer.add_recoverable("adam_opt", self.adam_optimizer) 235 | def on_fit_start(self): 236 | """Gets called at the beginning of ``fit()``, on multiple processes 237 | if ``distributed_count > 0`` and backend is ddp. 238 | 239 | Default implementation compiles the jit modules, initializes 240 | optimizers, and loads the latest checkpoint to resume training. 241 | """ 242 | # Run this *after* starting all processes since jit modules cannot be 243 | # pickled. 
244 | self._compile_jit() 245 | 246 | # Wrap modules with parallel backend after jit 247 | self._wrap_distributed() 248 | 249 | # Initialize optimizers after parameters are configured 250 | self.init_optimizers() 251 | 252 | # Load latest checkpoint to resume training if interrupted 253 | ## NOTE: make sure to use the "best" model to continual training 254 | ## so we set the `min_key` argument 255 | if self.checkpointer is not None: 256 | self.checkpointer.recover_if_possible( 257 | device=torch.device(self.device), 258 | min_key="PER" 259 | ) 260 | 261 | 262 | 263 | def dataio_prep(hparams): 264 | """This function prepares the datasets to be used in the brain class. 265 | It also defines the data processing pipeline through user-defined functions.""" 266 | data_folder = hparams["data_folder_save"] 267 | # 1. Declarations: 268 | train_data = sb.dataio.dataset.DynamicItemDataset.from_json( 269 | json_path=hparams["train_annotation"], 270 | replacements={"data_root": data_folder}, 271 | ) 272 | if hparams["sorting"] == "ascending": 273 | # we sort training data to speed up training and get better results. 274 | train_data = train_data.filtered_sorted(sort_key="duration") 275 | # when sorting do not shuffle in dataloader ! otherwise is pointless 276 | hparams["train_dataloader_opts"]["shuffle"] = False 277 | 278 | elif hparams["sorting"] == "descending": 279 | train_data = train_data.filtered_sorted( 280 | sort_key="duration", reverse=True 281 | ) 282 | # when sorting do not shuffle in dataloader ! otherwise is pointless 283 | hparams["train_dataloader_opts"]["shuffle"] = False 284 | 285 | elif hparams["sorting"] == "random": 286 | pass 287 | 288 | else: 289 | raise NotImplementedError( 290 | "sorting must be random, ascending or descending" 291 | ) 292 | 293 | valid_data = sb.dataio.dataset.DynamicItemDataset.from_json( 294 | json_path=hparams["valid_annotation"], 295 | replacements={"data_root": data_folder}, 296 | ) 297 | valid_data = valid_data.filtered_sorted(sort_key="duration") 298 | 299 | test_data = sb.dataio.dataset.DynamicItemDataset.from_json( 300 | json_path=hparams["test_annotation"], 301 | replacements={"data_root": data_folder}, 302 | ) 303 | test_data = test_data.filtered_sorted(sort_key="duration") 304 | 305 | datasets = [train_data, valid_data, test_data] 306 | label_encoder = sb.dataio.encoder.CTCTextEncoder() 307 | 308 | # 2. Define audio pipeline: 309 | @sb.utils.data_pipeline.takes("wav") 310 | @sb.utils.data_pipeline.provides("sig") 311 | def audio_pipeline(wav): 312 | # sig = sb.dataio.dataio.read_audio(wav) 313 | # # sample rate change to 16000, e,g, using librosa 314 | # sig = torch.Tensor(librosa.core.load(wav, hparams["sample_rate"])[0]) 315 | # Use wav2vec processor to do normalization 316 | sig = hparams["wav2vec2"].feature_extractor( 317 | librosa.core.load(wav, hparams["sample_rate"])[0], 318 | sampling_rate=hparams["sample_rate"], 319 | ).input_values[0] 320 | sig = torch.Tensor(sig) 321 | return sig 322 | 323 | sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) 324 | 325 | # 3. 
Define text pipeline: 326 | @sb.utils.data_pipeline.takes("perceived_train_target") 327 | @sb.utils.data_pipeline.provides( 328 | "phn_list_target", 329 | "phn_encoded_list_target", 330 | "phn_encoded_target", 331 | ) 332 | def text_pipeline_train(phn): 333 | phn_list = phn.strip().split() 334 | yield phn_list 335 | phn_encoded_list = label_encoder.encode_sequence(phn_list) 336 | yield phn_encoded_list 337 | phn_encoded = torch.LongTensor(phn_encoded_list) 338 | yield phn_encoded 339 | 340 | @sb.utils.data_pipeline.takes("perceived_train_target", "canonical_aligned", "perceived_aligned") 341 | @sb.utils.data_pipeline.provides( 342 | "phn_list_target", 343 | "phn_encoded_list_target", 344 | "phn_encoded_target", 345 | "phn_list_canonical", 346 | "phn_encoded_list_canonical", 347 | "phn_encoded_canonical", 348 | "phn_list_perceived", 349 | "phn_encoded_list_perceived", 350 | "phn_encoded_perceived", 351 | ) 352 | def text_pipeline_test(target, canonical, perceived): 353 | phn_list_target = target.strip().split() 354 | yield phn_list_target 355 | phn_encoded_list_target = label_encoder.encode_sequence(phn_list_target) 356 | yield phn_encoded_list_target 357 | phn_encoded_target = torch.LongTensor(phn_encoded_list_target) 358 | yield phn_encoded_target 359 | phn_list_canonical = canonical.strip().split() 360 | yield phn_list_canonical 361 | phn_encoded_list_canonical = label_encoder.encode_sequence(phn_list_canonical) 362 | yield phn_encoded_list_canonical 363 | phn_encoded_canonical = torch.LongTensor(phn_encoded_list_canonical) 364 | yield phn_encoded_canonical 365 | phn_list_perceived = perceived.strip().split() 366 | yield phn_list_perceived 367 | phn_encoded_list_perceived = label_encoder.encode_sequence(phn_list_perceived) 368 | yield phn_encoded_list_perceived 369 | phn_encoded_perceived = torch.LongTensor(phn_encoded_list_perceived) 370 | yield phn_encoded_perceived 371 | 372 | sb.dataio.dataset.add_dynamic_item([train_data], text_pipeline_train) 373 | sb.dataio.dataset.add_dynamic_item([valid_data, test_data], text_pipeline_test) 374 | 375 | # 3. Fit encoder: 376 | # Load or compute the label encoder 377 | lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") 378 | special_labels = { 379 | "blank_label": hparams["blank_index"], 380 | } 381 | label_encoder.load_or_create( 382 | path=lab_enc_file, 383 | from_didatasets=[train_data], 384 | output_key="phn_list_target", 385 | special_labels=special_labels, 386 | sequence_input=True, 387 | ) 388 | 389 | # 4. 
Set output: 390 | sb.dataio.dataset.set_output_keys( 391 | [train_data], 392 | ["id", "sig", "phn_encoded_target"], 393 | ) 394 | sb.dataio.dataset.set_output_keys( 395 | [valid_data, test_data], 396 | ["id", "sig", "phn_encoded_target", "phn_encoded_canonical", "phn_encoded_perceived"], 397 | ) 398 | 399 | return train_data, valid_data, test_data, label_encoder 400 | 401 | 402 | if __name__ == "__main__": 403 | # CLI: 404 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 405 | 406 | # Load hyperparameters file with command-line overrides 407 | with open(hparams_file) as fin: 408 | hparams = load_hyperpyyaml(fin, overrides) 409 | 410 | # Initialize ddp (useful only for multi-GPU DDP training) 411 | sb.utils.distributed.ddp_init_group(run_opts) 412 | 413 | # Create experiment directory 414 | sb.create_experiment_directory( 415 | experiment_directory=hparams["output_folder"], 416 | hyperparams_to_save=hparams_file, 417 | overrides=overrides, 418 | ) 419 | 420 | # Dataset IO prep: creating Dataset objects and proper encodings for phones 421 | train_data, valid_data, test_data, label_encoder = dataio_prep(hparams) 422 | 423 | # Trainer initialization 424 | asr_brain = ASR( 425 | modules=hparams["modules"], 426 | hparams=hparams, 427 | run_opts=run_opts, 428 | checkpointer=hparams["checkpointer"], 429 | ) 430 | asr_brain.label_encoder = label_encoder 431 | 432 | # Training/validation loop 433 | asr_brain.fit( 434 | asr_brain.hparams.epoch_counter, 435 | train_data, 436 | valid_data, 437 | train_loader_kwargs=hparams["train_dataloader_opts"], 438 | valid_loader_kwargs=hparams["valid_dataloader_opts"], 439 | ) 440 | 441 | # Test 442 | asr_brain.evaluate( 443 | test_data, 444 | min_key="PER", 445 | test_loader_kwargs=hparams["test_dataloader_opts"], 446 | ) 447 | -------------------------------------------------------------------------------- /train_mpl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from tqdm.contrib import tqdm 5 | import torch 6 | import logging 7 | import speechbrain as sb 8 | from hyperpyyaml import load_hyperpyyaml 9 | from torch.utils.data import DataLoader 10 | from torch.nn.utils.rnn import pad_sequence 11 | from speechbrain.dataio.batch import PaddedBatch 12 | from speechbrain.core import Stage 13 | from speechbrain.utils.distributed import run_on_main 14 | from mpd_eval_v3 import MpdStats 15 | import librosa 16 | import json 17 | import itertools 18 | import math 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | def make_attn_mask(wavs, wav_lens): 23 | """ 24 | wav_lens: relative lengths(i.e. 0-1) of a batch. shape: (bs, ) 25 | return a tensor of shape (bs, seq_len), representing mask on allowed positions. 26 | 1 for regular tokens, 0 for padded tokens 27 | """ 28 | abs_lens = (wav_lens*wavs.shape[1]).long() 29 | attn_mask = wavs.new(wavs.shape).zero_().long() 30 | for i in range(len(abs_lens)): 31 | attn_mask[i, :abs_lens[i]] = 1 32 | return attn_mask 33 | 34 | # Define training procedure 35 | class ASR(sb.Brain): 36 | def compute_forward(self, batch, stage): 37 | "Given an input batch it computes the phoneme probabilities." 38 | batch = batch.to(self.device) 39 | wavs, wav_lens = batch.sig 40 | 41 | if stage == sb.Stage.TRAIN: 42 | if hasattr(self.hparams, "augmentation"): 43 | wavs = self.hparams.augmentation(wavs, wav_lens) 44 | 45 | # some wav2vec models (e.g. 
large-lv60) needs attention_mask 46 | if self.modules.wav2vec2.feature_extractor.return_attention_mask: 47 | attn_mask = make_attn_mask(wavs, wav_lens) 48 | else: 49 | attn_mask = None 50 | feats = self.modules.wav2vec2(wavs, attention_mask=attn_mask) 51 | x = self.modules.enc(feats) 52 | 53 | # output layer for ctc log-probabilities 54 | logits = self.modules.ctc_lin(x) 55 | p_ctc = self.hparams.log_softmax(logits) 56 | 57 | 58 | return p_ctc, wav_lens 59 | 60 | def compute_objectives(self, predictions, batch, stage): 61 | "Given the network predictions and targets computed the NLL loss." 62 | 63 | p_ctc, wav_lens = predictions 64 | 65 | ids = batch.id 66 | targets, target_lens = batch.phn_encoded_target 67 | if stage != sb.Stage.TRAIN: 68 | canonicals, canonical_lens = batch.phn_encoded_canonical 69 | perceiveds, perceived_lens = batch.phn_encoded_perceived 70 | 71 | loss_ctc = self.hparams.ctc_cost(p_ctc, targets, wav_lens, target_lens) 72 | loss = loss_ctc 73 | 74 | # Record losses for posterity 75 | if stage != sb.Stage.TRAIN: 76 | # Note: sb.decoders.ctc_greedy_decode will also remove padded tokens 77 | # that is, it return a list of list with different lengths 78 | sequence = sb.decoders.ctc_greedy_decode( 79 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 80 | ) 81 | self.ctc_metrics.append(ids, p_ctc, targets, wav_lens, target_lens) 82 | 83 | # Note: the None in the function arguments mean that we do not specify predict_len 84 | # meaning the padding has been removed so we do not need to predicted lengths 85 | self.per_metrics.append( 86 | ids=ids, 87 | predict=sequence, 88 | target=targets, 89 | predict_len=None, 90 | target_len=target_lens, 91 | ind2lab=self.label_encoder.decode_ndim, 92 | ) 93 | self.mpd_metrics.append( 94 | ids=ids, 95 | predict=sequence, 96 | canonical=canonicals, 97 | perceived=perceiveds, 98 | predict_len=None, 99 | canonical_len=canonical_lens, 100 | perceived_len=perceived_lens, 101 | ind2lab=self.label_encoder.decode_ndim, 102 | ) 103 | 104 | return loss 105 | 106 | def evaluate_batch(self, batch, stage): 107 | """Computations needed for validation/test batches""" 108 | predictions = self.compute_forward(batch, stage=stage) 109 | loss = self.compute_objectives(predictions, batch, stage=stage) 110 | return loss.detach() 111 | 112 | def on_stage_start(self, stage, epoch): 113 | "Gets called when a stage (either training, validation, test) starts." 
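# Stage setup: the CTC loss tracker is re-created at the start of every stage;
# SpecAugment on the wav2vec2 front-end stays enabled for training only and is
# switched off for validation/test, where the PER and MPD metric trackers are
# also instantiated.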
114 | self.ctc_metrics = self.hparams.ctc_stats() 115 | # self.seq_metrics = self.hparams.seq_stats() 116 | if self.hparams.wav2vec2_specaug: 117 | self.modules.wav2vec2.model.config.apply_spec_augment = True 118 | 119 | if stage != sb.Stage.TRAIN: 120 | self.modules.wav2vec2.model.config.apply_spec_augment = False 121 | self.per_metrics = self.hparams.per_stats() 122 | self.mpd_metrics = MpdStats() 123 | 124 | def on_stage_end(self, stage, stage_loss, epoch): 125 | """Gets called at the end of a epoch.""" 126 | if stage == sb.Stage.TRAIN: 127 | self.train_loss = stage_loss 128 | else: 129 | per = self.per_metrics.summarize("error_rate") 130 | mpd_f1 = self.mpd_metrics.summarize("mpd_f1") 131 | 132 | if stage == sb.Stage.VALID: 133 | 134 | self.hparams.train_logger.log_stats( 135 | stats_meta={ 136 | "epoch": epoch, 137 | "lr_adam": self.adam_optimizer.param_groups[0]["lr"], 138 | "lr_wav2vec": self.wav2vec_optimizer.param_groups[0]["lr"], 139 | }, 140 | train_stats={"loss": self.train_loss}, 141 | valid_stats={ 142 | "loss": stage_loss, 143 | "ctc_loss": self.ctc_metrics.summarize("average"), 144 | "PER": per, 145 | "mpd_f1": mpd_f1 146 | }, 147 | ) 148 | self.checkpointer.save_and_keep_only( 149 | meta={"PER": per, "mpd_f1": mpd_f1}, min_keys=["PER"], max_keys=["mpd_f1"] 150 | ) 151 | 152 | if stage == sb.Stage.TEST: 153 | self.hparams.train_logger.log_stats( 154 | stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, 155 | test_stats={"loss": stage_loss, "PER": per, "mpd_f1": mpd_f1}, 156 | ) 157 | with open(self.hparams.wer_file, "w") as w: 158 | w.write("CTC loss stats:\n") 159 | self.ctc_metrics.write_stats(w) 160 | w.write("\nPER stats:\n") 161 | self.per_metrics.write_stats(w) 162 | print( 163 | "CTC and PER stats written to file", 164 | self.hparams.wer_file, 165 | ) 166 | with open(self.hparams.mpd_file, "w") as m: 167 | m.write("MPD results and stats:\n") 168 | self.mpd_metrics.write_stats(m) 169 | print( 170 | "MPD results and stats written to file", 171 | self.hparams.mpd_file, 172 | ) 173 | 174 | 175 | def fit_batch(self, batch): 176 | """Fit one batch, override to do multiple updates. 177 | 178 | The default implementation depends on a few methods being defined 179 | with a particular behavior: 180 | 181 | * ``compute_forward()`` 182 | * ``compute_objectives()`` 183 | 184 | Also depends on having optimizers passed at initialization. 185 | 186 | Arguments 187 | --------- 188 | batch : list of torch.Tensors 189 | Batch of data to use for training. Default implementation assumes 190 | this batch has two elements: inputs and targets. 
191 | 192 | Returns 193 | ------- 194 | detached loss 195 | """ 196 | # Managing automatic mixed precision 197 | if self.auto_mix_prec: 198 | 199 | self.wav2vec_optimizer.zero_grad() 200 | self.adam_optimizer.zero_grad() 201 | 202 | with torch.cuda.amp.autocast(): 203 | outputs = self.compute_forward(batch, sb.Stage.TRAIN) 204 | loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN) 205 | 206 | self.scaler.scale(loss).backward() 207 | self.scaler.unscale_(self.wav2vec_optimizer) 208 | self.scaler.unscale_(self.adam_optimizer) 209 | 210 | if self.check_gradients(loss): 211 | self.scaler.step(self.wav2vec_optimizer) 212 | self.scaler.step(self.adam_optimizer) 213 | 214 | self.scaler.update() 215 | else: 216 | if not hasattr(batch, "phn_encoded_target"): 217 | ## unlabeled batch inference for pseudo labels 218 | ## make sure modules are in eval() mode, augmentations are DISABLED, see infer_batch() 219 | pls, pl_lens = self.infer_batch(batch) 220 | ## then, set to train() mode and ENABLE all augmentations to train on PLs 221 | self.modules.train() 222 | if self.hparams.wav2vec2_specaug: 223 | self.modules.wav2vec2.model.config.apply_spec_augment = True 224 | ## forward pass on unlabled batch, with all augmentations ENABLED 225 | outputs_u = self.compute_forward(batch, sb.Stage.TRAIN) 226 | loss = self.compute_objectives_unlabeled(outputs_u, pls, pl_lens) 227 | elif hasattr(batch, "phn_encoded_target"): 228 | ## labeled batch - make sure modules are in train() mode, augmentations are added 229 | self.modules.train() 230 | if self.hparams.wav2vec2_specaug: 231 | self.modules.wav2vec2.model.config.apply_spec_augment = True 232 | outputs_l = self.compute_forward(batch, sb.Stage.TRAIN) 233 | loss = self.compute_objectives(outputs_l, batch, sb.Stage.TRAIN) 234 | 235 | # normalize the loss by gradient_accumulation step 236 | (loss / self.hparams.gradient_accumulation).backward() 237 | 238 | if self.step % self.hparams.gradient_accumulation == 0: 239 | # gradient clipping & early stop if loss is not fini 240 | if self.check_gradients(loss): 241 | self.wav2vec_optimizer.step() 242 | self.adam_optimizer.step() 243 | 244 | self.wav2vec_optimizer.zero_grad() 245 | self.adam_optimizer.zero_grad() 246 | 247 | # momentum update on teacher model 248 | self.teacher_momentum_update() 249 | 250 | return loss.detach().cpu() 251 | 252 | def infer_batch(self, batch): 253 | ## make sure modules are in eval() mode, augmentations are disabled 254 | self.modules_teacher.eval() 255 | self.modules_teacher.wav2vec2.model.config.apply_spec_augment = False 256 | batch = batch.to(self.device) 257 | wavs, wav_lens = batch.sig 258 | with torch.no_grad(): 259 | # some wav2vec models (e.g. 
large-lv60) need attention_mask 260 | if self.modules_teacher.wav2vec2.feature_extractor.return_attention_mask: 261 | attn_mask = make_attn_mask(wavs, wav_lens) 262 | else: 263 | attn_mask = None 264 | feats = self.modules_teacher.wav2vec2(wavs, attention_mask=attn_mask) 265 | x = self.modules_teacher.enc(feats) 266 | 267 | # output layer for ctc log-probabilities 268 | logits = self.modules_teacher.ctc_lin(x) 269 | p_ctc = self.hparams.log_softmax(logits) 270 | 271 | pseudo_labels = sb.decoders.ctc_greedy_decode( 272 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 273 | ) 274 | max_len = max(len(x) for x in pseudo_labels) 275 | pseudo_label_lens = torch.tensor([float(len(x)/max_len) for x in pseudo_labels]) 276 | pseudo_labels = pad_sequence( 277 | [torch.tensor(x) for x in pseudo_labels], 278 | batch_first=True 279 | ) 280 | return pseudo_labels.to(self.device), pseudo_label_lens.to(self.device) 281 | 282 | def compute_objectives_unlabeled(self, predictions, targets, target_lens): 283 | "Simply compute the CTC loss" 284 | 285 | p_ctc, wav_lens = predictions 286 | 287 | loss_ctc = self.hparams.ctc_cost(p_ctc, targets, wav_lens, target_lens) 288 | return loss_ctc 289 | 290 | def teacher_momentum_update(self): 291 | with torch.no_grad(): 292 | for pname, param in self.modules_teacher.state_dict().items(): 293 | param.copy_(self.momentum_factor * param + (1 - self.momentum_factor) * self.modules.state_dict()[pname]) # in-place EMA update; a plain assignment would only rebind the local name and leave the teacher unchanged 294 | 295 | 296 | def init_optimizers(self): 297 | "Initializes the wav2vec2 optimizer and model optimizer" 298 | self.wav2vec_optimizer = self.hparams.wav2vec_opt_class( 299 | self.modules.wav2vec2.model.parameters() 300 | ) 301 | self.adam_optimizer = self.hparams.adam_opt_class( 302 | self.hparams.model.parameters() 303 | ) 304 | 305 | if self.checkpointer is not None: 306 | self.checkpointer.add_recoverable( 307 | "wav2vec_opt", self.wav2vec_optimizer 308 | ) 309 | self.checkpointer.add_recoverable("adam_opt", self.adam_optimizer) 310 | def on_fit_start(self): 311 | """Gets called at the beginning of ``fit()``, on multiple processes 312 | if ``distributed_count > 0`` and backend is ddp. 313 | 314 | Default implementation compiles the jit modules, initializes 315 | optimizers, and loads the latest checkpoint to resume training. 316 | """ 317 | # Run this *after* starting all processes since jit modules cannot be 318 | # pickled.
319 | self._compile_jit() 320 | 321 | # Wrap modules with parallel backend after jit 322 | self._wrap_distributed() 323 | 324 | # Initialize optimizers after parameters are configured 325 | self.init_optimizers() 326 | 327 | # Load latest checkpoint to resume training if interrupted 328 | ## NOTE: make sure to use the "best" model to continual training 329 | ## so we set the `min_key` argument 330 | if self.checkpointer is not None: 331 | self.checkpointer.recover_if_possible( 332 | device=torch.device(self.device), 333 | min_key="PER" 334 | ) 335 | 336 | ## set the epoch_counter to start from epoch 50 337 | self.hparams.epoch_counter.current=50 338 | 339 | 340 | ## initialize teacher model - load from the same base model ckpt 341 | chosen_ckpt = self.checkpointer.find_checkpoint(min_key="PER") 342 | model_layers = self.hparams.model_teacher 343 | wav2vec2_layers = self.hparams.wav2vec2_teacher 344 | model_layers.load_state_dict( 345 | torch.load(chosen_ckpt.paramfiles["model"], map_location=torch.device(self.device)) 346 | ) 347 | wav2vec2_layers.load_state_dict( 348 | torch.load(chosen_ckpt.paramfiles["wav2vec2"], map_location=torch.device(self.device)) 349 | ) 350 | self.modules_teacher.eval() 351 | 352 | self.set_momentum_factor( 353 | self.n_train_batch, 354 | self.hparams.epoch_counter.limit-self.hparams.epoch_counter.current 355 | ) 356 | 357 | def set_momentum_factor(self, n_train_batch, n_epochs): 358 | total_steps = float(n_train_batch * n_epochs // self.hparams.gradient_accumulation) 359 | self.momentum_factor = math.exp( (1/total_steps) * math.log(self.hparams.base_model_factor)) 360 | logger.info("Momentum Factor: {}".format(self.momentum_factor)) 361 | 362 | def fit( 363 | self, 364 | epoch_counter, 365 | train_data_l, 366 | train_data_u, 367 | valid_set=None, 368 | progressbar=None, 369 | train_loader_kwargs={}, 370 | valid_loader_kwargs={}, 371 | ): 372 | """Iterate epochs and datasets to improve objective. 373 | Relies on the existence of multiple functions that can (or should) be 374 | overridden. The following methods are used and expected to have a 375 | certain behavior: 376 | * ``fit_batch()`` 377 | * ``evaluate_batch()`` 378 | * ``update_average()`` 379 | If the initialization was done with distributed_count > 0 and the 380 | distributed_backend is ddp, this will generally handle multiprocess 381 | logic, like splitting the training data into subsets for each device and 382 | only saving a checkpoint on the main process. 383 | Arguments 384 | --------- 385 | epoch_counter : iterable 386 | Each call should return an integer indicating the epoch count. 387 | train_set : Dataset, DataLoader 388 | A set of data to use for training. If a Dataset is given, a 389 | DataLoader is automatically created. If a DataLoader is given, it is 390 | used directly. 391 | valid_set : Dataset, DataLoader 392 | A set of data to use for validation. If a Dataset is given, a 393 | DataLoader is automatically created. If a DataLoader is given, it is 394 | used directly. 395 | train_loader_kwargs : dict 396 | Kwargs passed to `make_dataloader()` for making the train_loader 397 | (if train_set is a Dataset, not DataLoader). 398 | E.G. batch_size, num_workers. 399 | DataLoader kwargs are all valid. 400 | valid_loader_kwargs : dict 401 | Kwargs passed to `make_dataloader()` for making the valid_loader 402 | (if valid_set is a Dataset, not DataLoader). 403 | E.g., batch_size, num_workers. 404 | DataLoader kwargs are all valid. 
405 | progressbar : bool 406 | Whether to display the progress of each epoch in a progressbar. 407 | """ 408 | 409 | self.n_train_batch = len(train_data_l) + len(train_data_u) 410 | 411 | self.on_fit_start() 412 | 413 | if progressbar is None: 414 | progressbar = not self.noprogressbar 415 | 416 | # Iterate epochs 417 | for epoch in epoch_counter: 418 | 419 | ## chain labeled data loader and unlabeled data loader 420 | train_set = itertools.chain(train_data_l, train_data_u) 421 | 422 | # Training stage 423 | self.on_stage_start(Stage.TRAIN, epoch) 424 | self.modules.train() 425 | 426 | # Reset nonfinite count to 0 each epoch 427 | self.nonfinite_count = 0 428 | 429 | if self.train_sampler is not None and hasattr( 430 | self.train_sampler, "set_epoch" 431 | ): 432 | self.train_sampler.set_epoch(epoch) 433 | 434 | # Time since last intra-epoch checkpoint 435 | last_ckpt_time = time.time() 436 | 437 | # Only show progressbar if requested and main_process 438 | enable = progressbar and sb.utils.distributed.if_main_process() 439 | with tqdm( 440 | train_set, 441 | initial=self.step, 442 | total=self.n_train_batch, 443 | dynamic_ncols=True, 444 | disable=not enable, 445 | ) as t: 446 | for batch in t: 447 | self.step += 1 448 | loss = self.fit_batch(batch) 449 | self.avg_train_loss = self.update_average( 450 | loss, self.avg_train_loss 451 | ) 452 | t.set_postfix(train_loss=self.avg_train_loss) 453 | 454 | # Debug mode only runs a few batches 455 | if self.debug and self.step == self.debug_batches: 456 | break 457 | 458 | if ( 459 | self.checkpointer is not None 460 | and self.ckpt_interval_minutes > 0 461 | and time.time() - last_ckpt_time 462 | >= self.ckpt_interval_minutes * 60.0 463 | ): 464 | # This should not use run_on_main, because that 465 | # includes a DDP barrier. That eventually leads to a 466 | # crash when the processes' 467 | # time.time() - last_ckpt_time differ and some 468 | # processes enter this block while others don't, 469 | # missing the barrier. 470 | if sb.utils.distributed.if_main_process(): 471 | self._save_intra_epoch_ckpt() 472 | last_ckpt_time = time.time() 473 | 474 | # Run train "on_stage_end" on all processes 475 | self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch) 476 | self.avg_train_loss = 0.0 477 | self.step = 0 478 | 479 | # Validation stage 480 | if valid_set is not None: 481 | self.on_stage_start(Stage.VALID, epoch) 482 | self.modules.eval() 483 | avg_valid_loss = 0.0 484 | with torch.no_grad(): 485 | for batch in tqdm( 486 | valid_set, dynamic_ncols=True, disable=not enable 487 | ): 488 | self.step += 1 489 | loss = self.evaluate_batch(batch, stage=Stage.VALID) 490 | avg_valid_loss = self.update_average( 491 | loss, avg_valid_loss 492 | ) 493 | 494 | # Debug mode only runs a few batches 495 | if self.debug and self.step == self.debug_batches: 496 | break 497 | 498 | # Only run validation "on_stage_end" on main process 499 | self.step = 0 500 | run_on_main( 501 | self.on_stage_end, 502 | args=[Stage.VALID, avg_valid_loss, epoch], 503 | ) 504 | 505 | # Debug mode only runs a few epochs 506 | if self.debug and epoch == self.debug_epochs: 507 | break 508 | 509 | 510 | def dataio_prep(hparams): 511 | """This function prepares the datasets to be used in the brain class. 512 | It also defines the data processing pipeline through user-defined functions.""" 513 | data_folder = hparams["data_folder_save"] 514 | # 1. 
Declarations: 515 | ## labeled training data 516 | train_data = sb.dataio.dataset.DynamicItemDataset.from_json( 517 | json_path=hparams["train_annotation"], 518 | replacements={"data_root": data_folder}, 519 | ) 520 | if hparams["sorting"] == "ascending": 521 | # we sort training data to speed up training and get better results. 522 | train_data = train_data.filtered_sorted(sort_key="duration") 523 | # when sorting do not shuffle in dataloader ! otherwise is pointless 524 | hparams["train_dataloader_opts"]["shuffle"] = False 525 | 526 | elif hparams["sorting"] == "descending": 527 | train_data = train_data.filtered_sorted( 528 | sort_key="duration", reverse=True 529 | ) 530 | # when sorting do not shuffle in dataloader ! otherwise is pointless 531 | hparams["train_dataloader_opts"]["shuffle"] = False 532 | 533 | elif hparams["sorting"] == "random": 534 | pass 535 | 536 | else: 537 | raise NotImplementedError( 538 | "sorting must be random, ascending or descending" 539 | ) 540 | ## unlabled training data 541 | train_data_u = sb.dataio.dataset.DynamicItemDataset.from_json( 542 | json_path=hparams["unlabeled_annotation"], 543 | replacements={"data_root": data_folder}, 544 | ) 545 | 546 | valid_data = sb.dataio.dataset.DynamicItemDataset.from_json( 547 | json_path=hparams["valid_annotation"], 548 | replacements={"data_root": data_folder}, 549 | ) 550 | valid_data = valid_data.filtered_sorted(sort_key="duration") 551 | 552 | test_data = sb.dataio.dataset.DynamicItemDataset.from_json( 553 | json_path=hparams["test_annotation"], 554 | replacements={"data_root": data_folder}, 555 | ) 556 | test_data = test_data.filtered_sorted(sort_key="duration") 557 | 558 | datasets = [train_data, train_data_u, valid_data, test_data] 559 | label_encoder = sb.dataio.encoder.CTCTextEncoder() 560 | 561 | # 2. Define audio pipeline: 562 | @sb.utils.data_pipeline.takes("wav") 563 | @sb.utils.data_pipeline.provides("sig") 564 | def audio_pipeline(wav): 565 | # sig = sb.dataio.dataio.read_audio(wav) 566 | # # sample rate change to 16000, e,g, using librosa 567 | # sig = torch.Tensor(librosa.core.load(wav, hparams["sample_rate"])[0]) 568 | # Use wav2vec processor to do normalization 569 | sig = hparams["wav2vec2"].feature_extractor( 570 | librosa.core.load(wav, hparams["sample_rate"])[0], 571 | sampling_rate=hparams["sample_rate"], 572 | ).input_values[0] 573 | sig = torch.Tensor(sig) 574 | return sig 575 | 576 | sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) 577 | 578 | # 3. 
Define text pipeline: 579 | @sb.utils.data_pipeline.takes("perceived_train_target") 580 | @sb.utils.data_pipeline.provides( 581 | "phn_list_target", 582 | "phn_encoded_list_target", 583 | "phn_encoded_target", 584 | ) 585 | def text_pipeline_train(phn): 586 | phn_list = phn.strip().split() 587 | yield phn_list 588 | phn_encoded_list = label_encoder.encode_sequence(phn_list) 589 | yield phn_encoded_list 590 | phn_encoded = torch.LongTensor(phn_encoded_list) 591 | yield phn_encoded 592 | 593 | @sb.utils.data_pipeline.takes("perceived_train_target", "canonical_aligned", "perceived_aligned") 594 | @sb.utils.data_pipeline.provides( 595 | "phn_list_target", 596 | "phn_encoded_list_target", 597 | "phn_encoded_target", 598 | "phn_list_canonical", 599 | "phn_encoded_list_canonical", 600 | "phn_encoded_canonical", 601 | "phn_list_perceived", 602 | "phn_encoded_list_perceived", 603 | "phn_encoded_perceived", 604 | ) 605 | def text_pipeline_test(target, canonical, perceived): 606 | phn_list_target = target.strip().split() 607 | yield phn_list_target 608 | phn_encoded_list_target = label_encoder.encode_sequence(phn_list_target) 609 | yield phn_encoded_list_target 610 | phn_encoded_target = torch.LongTensor(phn_encoded_list_target) 611 | yield phn_encoded_target 612 | phn_list_canonical = canonical.strip().split() 613 | yield phn_list_canonical 614 | phn_encoded_list_canonical = label_encoder.encode_sequence(phn_list_canonical) 615 | yield phn_encoded_list_canonical 616 | phn_encoded_canonical = torch.LongTensor(phn_encoded_list_canonical) 617 | yield phn_encoded_canonical 618 | phn_list_perceived = perceived.strip().split() 619 | yield phn_list_perceived 620 | phn_encoded_list_perceived = label_encoder.encode_sequence(phn_list_perceived) 621 | yield phn_encoded_list_perceived 622 | phn_encoded_perceived = torch.LongTensor(phn_encoded_list_perceived) 623 | yield phn_encoded_perceived 624 | 625 | sb.dataio.dataset.add_dynamic_item([train_data], text_pipeline_train) 626 | sb.dataio.dataset.add_dynamic_item([valid_data, test_data], text_pipeline_test) 627 | 628 | # 3. Fit encoder: 629 | # Load or compute the label encoder 630 | lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") 631 | special_labels = { 632 | "blank_label": hparams["blank_index"], 633 | } 634 | label_encoder.load_or_create( 635 | path=lab_enc_file, 636 | from_didatasets=[train_data], 637 | output_key="phn_list_target", 638 | special_labels=special_labels, 639 | sequence_input=True, 640 | ) 641 | 642 | # 4. 
Set output: 643 | sb.dataio.dataset.set_output_keys( 644 | [train_data], 645 | ["id", "sig", "phn_encoded_target"], 646 | ) 647 | sb.dataio.dataset.set_output_keys( 648 | [train_data_u], 649 | ["id", "sig"], 650 | ) 651 | sb.dataio.dataset.set_output_keys( 652 | [valid_data, test_data], 653 | ["id", "sig", "phn_encoded_target", "phn_encoded_canonical", "phn_encoded_perceived"], 654 | ) 655 | 656 | return train_data, train_data_u, valid_data, test_data, label_encoder 657 | 658 | 659 | if __name__ == "__main__": 660 | # CLI: 661 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 662 | 663 | # Load hyperparameters file with command-line overrides 664 | with open(hparams_file) as fin: 665 | hparams = load_hyperpyyaml(fin, overrides) 666 | 667 | # Initialize ddp (useful only for multi-GPU DDP training) 668 | sb.utils.distributed.ddp_init_group(run_opts) 669 | 670 | # Create experiment directory 671 | sb.create_experiment_directory( 672 | experiment_directory=hparams["output_folder"], 673 | hyperparams_to_save=hparams_file, 674 | overrides=overrides, 675 | ) 676 | 677 | # Dataset IO prep: creating Dataset objects and proper encodings for phones 678 | train_data_l, train_data_u, valid_data, test_data, label_encoder = dataio_prep(hparams) 679 | 680 | ## build data loaders for all datasets 681 | train_data_l = DataLoader( 682 | train_data_l, 683 | batch_size=hparams["batch_size_labeled"], 684 | drop_last=False, 685 | shuffle=True, 686 | sampler=None, 687 | collate_fn=PaddedBatch, 688 | num_workers=hparams["num_workers"] 689 | ) 690 | train_data_u = DataLoader( 691 | train_data_u, 692 | batch_size=hparams["batch_size_unlabeled"], 693 | drop_last=False, 694 | shuffle=True, 695 | sampler=None, 696 | collate_fn=PaddedBatch, 697 | num_workers=hparams["num_workers"] 698 | ) 699 | valid_data = DataLoader( 700 | valid_data, 701 | batch_size=hparams["batch_size_labeled"], 702 | drop_last=False, 703 | shuffle=False, 704 | sampler=None, 705 | collate_fn=PaddedBatch, 706 | num_workers=hparams["num_workers"] 707 | ) 708 | test_data = DataLoader( 709 | test_data, 710 | batch_size=1, 711 | drop_last=False, 712 | shuffle=False, 713 | sampler=None, 714 | collate_fn=PaddedBatch, 715 | num_workers=1 716 | ) 717 | 718 | 719 | # Trainer initialization 720 | asr_brain = ASR( 721 | modules=hparams["modules"], 722 | hparams=hparams, 723 | run_opts=run_opts, 724 | checkpointer=hparams["checkpointer"], 725 | ) 726 | asr_brain.label_encoder = label_encoder 727 | asr_brain.modules_teacher = torch.nn.ModuleDict(hparams["modules_teacher"]).to(asr_brain.device) 728 | 729 | # Training/validation loop 730 | asr_brain.fit( 731 | asr_brain.hparams.epoch_counter, 732 | train_data_l, 733 | train_data_u, 734 | valid_data, 735 | train_loader_kwargs=None, 736 | valid_loader_kwargs=None, 737 | ) 738 | 739 | # Test 740 | asr_brain.evaluate( 741 | test_data, 742 | min_key="PER", 743 | test_loader_kwargs=None, 744 | ) 745 | -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import speechbrain as sb 6 | from hyperpyyaml import load_hyperpyyaml 7 | import librosa 8 | from tqdm import tqdm 9 | import json 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def make_attn_mask(wavs, wav_lens): 14 | """ 15 | wav_lens: relative lengths(i.e. 0-1) of a batch. 
shape: (bs, ) 16 | return a tensor of shape (bs, seq_len), representing mask on allowed positions. 17 | 1 for regular tokens, 0 for padded tokens 18 | """ 19 | abs_lens = (wav_lens*wavs.shape[1]).long() 20 | attn_mask = wavs.new(wavs.shape).zero_().long() 21 | for i in range(len(abs_lens)): 22 | attn_mask[i, :abs_lens[i]] = 1 23 | return attn_mask 24 | 25 | # Define training procedure 26 | class ASR(sb.Brain): 27 | def compute_forward(self, batch, stage): 28 | "Given an input batch it computes the phoneme probabilities." 29 | batch = batch.to(self.device) 30 | ids = batch.id 31 | wavs, wav_lens = batch.sig 32 | 33 | if stage == sb.Stage.TRAIN: 34 | if hasattr(self.hparams, "augmentation"): 35 | wavs = self.hparams.augmentation(wavs, wav_lens) 36 | 37 | # some wav2vec models (e.g. large-lv60) needs attention_mask 38 | if self.modules.wav2vec2.feature_extractor.return_attention_mask: 39 | attn_mask = make_attn_mask(wavs, wav_lens) 40 | else: 41 | attn_mask = None 42 | feats = self.modules.wav2vec2(wavs, attention_mask=attn_mask) 43 | x = self.modules.enc(feats) 44 | 45 | # output layer for ctc log-probabilities 46 | logits = self.modules.ctc_lin(x) 47 | p_ctc = self.hparams.log_softmax(logits) 48 | # Note: sb.decoders.ctc_greedy_decode will also remove padded tokens 49 | # that is, it return a list of list with different lengths 50 | sequence = sb.decoders.ctc_greedy_decode( 51 | p_ctc, wav_lens, blank_id=self.hparams.blank_index 52 | ) 53 | transcriptions = [" ".join(self.label_encoder.decode_ndim(s)) for s in sequence] 54 | 55 | 56 | return ids, transcriptions 57 | 58 | def transcribe_dataset( 59 | self, 60 | dataset, # Must be obtained from the dataio_function 61 | min_key, # We load the model with the lowest WER 62 | loader_kwargs # opts for the dataloading 63 | ): 64 | 65 | # If dataset isn't a Dataloader, we create it. 66 | if not isinstance(dataset, torch.utils.data.DataLoader): 67 | loader_kwargs["ckpt_prefix"] = None 68 | dataset = self.make_dataloader( 69 | dataset, sb.Stage.TEST, **loader_kwargs 70 | ) 71 | 72 | 73 | self.on_evaluate_start(min_key=min_key) # We call the on_evaluate_start that will load the best model 74 | self.modules.eval() # We set the model to eval mode (remove dropout etc) 75 | self.modules.wav2vec2.model.config.apply_spec_augment = False # make sure no spec aug applied on wav2vec2 76 | 77 | # Now we iterate over the dataset and we simply compute_forward and decode 78 | with torch.no_grad(): 79 | 80 | wav_ids = [] 81 | transcripts = [] 82 | for batch in tqdm(dataset, dynamic_ncols=True): 83 | 84 | ids, preds = self.compute_forward(batch, stage=sb.Stage.TEST) 85 | 86 | transcripts.extend(preds) 87 | wav_ids.extend(ids) 88 | 89 | return wav_ids, transcripts 90 | 91 | 92 | def dataio_prep(hparams): 93 | """This function prepares the datasets to be used in the brain class. 94 | It also defines the data processing pipeline through user-defined functions.""" 95 | data_folder = hparams["data_folder_save"] 96 | # 1. Declarations: 97 | 98 | inference_data = sb.dataio.dataset.DynamicItemDataset.from_json( 99 | json_path=hparams["inference_annotation"], 100 | replacements={"data_root": data_folder}, 101 | ) 102 | inference_data = inference_data.filtered_sorted(sort_key="duration") 103 | 104 | datasets = [inference_data] 105 | label_encoder = sb.dataio.encoder.CTCTextEncoder() 106 | 107 | # 2. 
Define audio pipeline: 108 | @sb.utils.data_pipeline.takes("wav") 109 | @sb.utils.data_pipeline.provides("sig") 110 | def audio_pipeline(wav): 111 | # sig = sb.dataio.dataio.read_audio(wav) 112 | # # sample rate change to 16000, e,g, using librosa 113 | # sig = torch.Tensor(librosa.core.load(wav, hparams["sample_rate"])[0]) 114 | # Use wav2vec processor to do normalization 115 | sig = hparams["wav2vec2"].feature_extractor( 116 | librosa.core.load(wav, hparams["sample_rate"])[0], 117 | sampling_rate=hparams["sample_rate"], 118 | ).input_values[0] 119 | sig = torch.Tensor(sig) 120 | return sig 121 | 122 | sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) 123 | 124 | 125 | # 3. Fit encoder: 126 | # Load the label encoder 127 | lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") 128 | label_encoder.load(lab_enc_file) 129 | 130 | # 4. Set output: 131 | sb.dataio.dataset.set_output_keys( 132 | datasets, 133 | ["id", "sig"], 134 | ) 135 | 136 | return inference_data, label_encoder 137 | 138 | 139 | if __name__ == "__main__": 140 | # CLI: 141 | hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) 142 | 143 | # Load hyperparameters file with command-line overrides 144 | with open(hparams_file) as fin: 145 | hparams = load_hyperpyyaml(fin, overrides) 146 | 147 | 148 | # Initialize ddp (useful only for multi-GPU DDP training) 149 | sb.utils.distributed.ddp_init_group(run_opts) 150 | 151 | # Create experiment directory 152 | sb.create_experiment_directory( 153 | experiment_directory=hparams["output_folder"], 154 | hyperparams_to_save=hparams_file, 155 | overrides=overrides, 156 | ) 157 | 158 | 159 | # Dataset IO prep: creating Dataset objects and proper encodings for phones 160 | inference_data, label_encoder = dataio_prep(hparams) 161 | 162 | # Trainer initialization 163 | asr_brain = ASR( 164 | modules=hparams["modules"], 165 | hparams=hparams, 166 | run_opts=run_opts, 167 | checkpointer=hparams["checkpointer"], 168 | ) 169 | asr_brain.label_encoder = label_encoder 170 | wav_ids, transcripts = asr_brain.transcribe_dataset( 171 | dataset=inference_data, # Must be obtained from the dataio_function 172 | min_key="PER", # We load the model with the lowest PER 173 | loader_kwargs=hparams["inference_dataloader_opts"], # opts for the dataloading 174 | ) 175 | with open(hparams["inference_annotation"], "r") as json_f: 176 | unlabeled_data = json.load(json_f) 177 | for wav_id, transcript in zip(wav_ids, transcripts): 178 | unlabeled_data[wav_id].update({"pred_phns": transcript}) 179 | 180 | ## save as new json file 181 | with open(hparams["inference_annotation_saved"], "w") as json_f_save: 182 | json.dump(unlabeled_data, json_f_save, indent=2) 183 | --------------------------------------------------------------------------------
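A minimal sketch, not part of the repo, of how the predictions JSON written by `transcribe.py` might be inspected once `pred_phns` has been added. The path below is an assumption: substitute whatever `inference_annotation_saved` points to in `hparams/transcribe.yaml`; the other keys follow the fields used by the data preparation scripts above.
```
import json

# Hypothetical path: replace with the value of `inference_annotation_saved`
# from hparams/transcribe.yaml.
with open("data/train_unlabeled_with_preds.json") as f:
    preds = json.load(f)

# Each entry keeps its original fields (e.g. "wav", "spk_id", "duration")
# plus the space-separated phone string added under "pred_phns".
for wav_id, entry in list(preds.items())[:3]:
    print(wav_id, "->", entry["pred_phns"])
```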