├── plot └── multiPA.png ├── fairseq_roberta └── README.md ├── fairseq_hubert └── README.md ├── speechocean762 ├── convert_wavfile.py └── create_training_and_testing_list.py ├── LICENSE.txt ├── get_gt_alignment.py ├── whisper_asr_all.py ├── get_training_features.py ├── dataloader.py ├── test_open.py ├── test_closed.py ├── api.py ├── evaluation_speechocean_closed.py ├── utils.py ├── README.md ├── evaluation_speechocean_open.py ├── model_assessment.py ├── processors.py ├── utils_assessment.py ├── Charsiu.py └── models.py /plot/multiPA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuwchen/MultiPA/HEAD/plot/multiPA.png -------------------------------------------------------------------------------- /fairseq_roberta/README.md: -------------------------------------------------------------------------------- 1 | 2 | Download roberta.base model from https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md 3 | 4 | Put the model.pt and dict.txt here -------------------------------------------------------------------------------- /fairseq_hubert/README.md: -------------------------------------------------------------------------------- 1 | 2 | Download HuBERT Base (~95M params) from https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md 3 | 4 | Put the hubert_base_ls960.pt model here. -------------------------------------------------------------------------------- /speechocean762/convert_wavfile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import soundfile as sf 4 | from tqdm import tqdm 5 | 6 | def get_filepaths(directory): 7 | file_paths = [] 8 | for root, directories, files in os.walk(directory): 9 | for filename in files: 10 | # Join the two strings in order to form the full filepath. 11 | filepath = os.path.join(root, filename) 12 | if filename.endswith('.WAV'): 13 | file_paths.append(filepath) 14 | return file_paths 15 | 16 | inputdir = './WAVE' 17 | outputdir = './wav' 18 | 19 | if not os.path.exists(outputdir): 20 | os.makedirs(outputdir) 21 | 22 | file_list = get_filepaths(inputdir) 23 | 24 | 25 | for path in tqdm(file_list): 26 | wavname = path.split(os.sep)[-1] 27 | new_path = os.path.join(outputdir, wavname.replace('.WAV','.wav')) 28 | if os.path.isfile(new_path): 29 | continue 30 | y, rate = librosa.load(path, sr=16000) 31 | sf.write(new_path, y, rate) 32 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 yuwchen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /get_gt_alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from tqdm import tqdm 5 | from utils_assessment import * 6 | from Charsiu import charsiu_forced_aligner 7 | 8 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json 9 | data = json.load(f) 10 | 11 | 12 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list 13 | test_data = {} 14 | for line in test_file: 15 | wavidx = line.split('\t')[0] 16 | test_data[wavidx] = data[wavidx] 17 | 18 | 19 | gt_alignment_dir = './gt_alignment_test' 20 | wav_dir = './speechocean762/wav' 21 | charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 22 | 23 | for wavidx in tqdm(test_data.keys()): 24 | 25 | wavpath = os.path.join(wav_dir , wavidx+'.wav') 26 | gt_sen_list = [] 27 | for word in data[wavidx]['words']: 28 | gt_sen_list.append(word['text'].lower()) 29 | 30 | gt_sen = ' '.join(gt_sen_list) 31 | try: 32 | pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map = get_charsiu_alignment(wavpath, gt_sen, charsiu) 33 | selected_idx = get_match_index(pred_words, words) 34 | pred_words = np.asarray(pred_words) 35 | pred_words = pred_words[selected_idx] 36 | torch.save(pred_words, os.path.join(gt_alignment_dir, wavidx+'.pt')) 37 | if len(gt_sen_list)!=len(pred_words): 38 | print(wavidx) 39 | print(gt_sen_list) 40 | print(pred_words) 41 | except Exception as e: 42 | print(e) 43 | print(wavidx) 44 | print(gt_sen_list) 45 | -------------------------------------------------------------------------------- /speechocean762/create_training_and_testing_list.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | def get_filepaths(directory): 6 | file_paths = [] 7 | for root, directories, files in os.walk(directory): 8 | for filename in files: 9 | # Join the two strings in order to form the full filepath. 
10 | filepath = os.path.join(root, filename) 11 | if filename.endswith('.wav'): 12 | file_paths.append(filepath) 13 | return file_paths 14 | 15 | f = open('./resource/scores-detail.json') 16 | data = json.load(f) 17 | 18 | 19 | train_file = open('./train/wav.scp','r').read().splitlines() 20 | test_file = open('./test/wav.scp','r').read().splitlines() 21 | 22 | train_out = open('./speechocean762_train.txt','w') 23 | test_out = open('./speechocean762_test.txt','w') 24 | 25 | for line in train_file: 26 | wavidx = line.split('\t')[0] 27 | the_data = data[wavidx] 28 | accuracy = the_data['accuracy'] 29 | completeness = the_data['completeness'] 30 | fluency = the_data['fluency'] 31 | prosodic = the_data['prosodic'] 32 | total = the_data['total'] 33 | 34 | for idx in range(5): 35 | W_acc_list = [] 36 | W_stress_list = [] 37 | W_total_list = [] 38 | sen_length = len(the_data['words']) 39 | for w_idx in range(sen_length): 40 | w_acc = str(the_data['words'][w_idx]['accuracy'][idx]) 41 | w_stress = str(the_data['words'][w_idx]['stress'][idx]) 42 | w_total = str(the_data['words'][w_idx]['total'][idx]) 43 | W_acc_list.append(w_acc) 44 | W_stress_list.append(w_stress) 45 | W_total_list.append(w_total) 46 | 47 | word_acc = ','.join(W_acc_list) 48 | word_stress = ','.join(W_stress_list) 49 | word_total = ','.join(W_total_list) 50 | raw = '{}.wav;{};{};{};{};{};{};{};{}\n'.format(wavidx, accuracy[idx], fluency[idx], prosodic[idx], total[idx], word_acc, word_stress, word_total, sen_length) 51 | train_out.write(raw) 52 | 53 | 54 | for line in test_file: 55 | wavidx = line.split('\t')[0] 56 | test_out.write(wavidx+'.wav\n') -------------------------------------------------------------------------------- /whisper_asr_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import whisper 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | #Install whisper using: pip install -U openai-whisper 9 | #https://github.com/openai/whisper 10 | 11 | def get_filepaths(directory, format='.wav'): 12 | file_paths = [] 13 | for root, _, files in os.walk(directory): 14 | for filename in files: 15 | if filename.endswith(format): 16 | file_paths.append(filename) 17 | return file_paths 18 | 19 | def main(): 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of DATA/ directory') 23 | args = parser.parse_args() 24 | input_dir= args.datadir 25 | 26 | file_list = get_filepaths(input_dir, format='.wav') #loop all the .wav file in dir 27 | file_list = set(file_list) 28 | model = whisper.load_model("base.en") 29 | outputname = os.path.join('whisper_results','speechocean_whisper_all_base_eng.csv') 30 | 31 | if not os.path.exists('whisper_results'): 32 | os.makedirs('whisper_results') 33 | 34 | try: 35 | print('Number of files:', len(file_list)) 36 | df = pd.read_csv(outputname) 37 | exist_list = set(df['wavname'].to_list()) 38 | print('Number of already processed files:', len(exist_list)) 39 | file_list = file_list - exist_list 40 | print('Number of unprocessed files:',len(file_list)) 41 | file_list = list(file_list) 42 | except Exception as e: 43 | print('Create new file') 44 | df = pd.DataFrame(columns=['wavname', 'transcript']) 45 | df.to_csv(outputname, sep=',', index=False, header=True) 46 | 47 | 48 | for filename in tqdm(file_list): 49 | path = os.path.join(input_dir, filename) 50 | result = model.transcribe(path, fp16=False) 51 | transcript = 
result['text'] 52 | results = pd.DataFrame([{'wavname':filename,'transcript':transcript}]) 53 | results.to_csv(outputname, mode='a', sep=',', index=False, header=False) 54 | 55 | 56 | if __name__ == '__main__': 57 | main() 58 | -------------------------------------------------------------------------------- /get_training_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import whisper 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from utils_assessment import * 9 | from Charsiu import charsiu_forced_aligner 10 | from fairseq.models.roberta import RobertaModel 11 | 12 | 13 | def get_transcript(df): 14 | asr_results = {} 15 | for index, row in df.iterrows(): 16 | the_wavname = row['wavname'].replace('.wav','') 17 | the_transcript = row['transcript'] 18 | asr_results[the_wavname] = the_transcript 19 | return asr_results 20 | 21 | def create_dir(outputdir): 22 | if not os.path.exists(outputdir): 23 | os.makedirs(outputdir) 24 | 25 | def main(): 26 | 27 | 28 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json 29 | data = json.load(f) 30 | 31 | train_file = open('./speechocean762/speechocean762_train.txt','r').read().splitlines() 32 | train_list = [] 33 | for line in train_file: 34 | wavname = line.split(';')[0].split('.')[0] 35 | train_list.append(wavname) 36 | train_list = list(set(train_list)) 37 | 38 | df = pd.read_csv('./whisper_results/speechocean_whisper_all_base_eng.csv') 39 | #df_m = pd.read_csv('./whisper_results/speechocean_whisper_all_medium_eng.csv') 40 | #df_s = pd.read_csv('./whisper_results/speechocean_whisper_all_small_eng.csv') 41 | #df_t = pd.read_csv('./whisper_results/speechocean_whisper_all_tiny_eng.csv') 42 | 43 | whisper_results = get_transcript(df) 44 | 45 | wav_dir = './speechocean762/wav' 46 | 47 | outputdir_pred_words_gt = './feature_base/pred_words_gt' 48 | outputdir_features_p = './feature_base/features_p' 49 | outputdir_features_w = './feature_base/features_w' 50 | outputdir_phone_vector = './feature_base/phone_vector' 51 | outputdir_gt_word_embed = './feature_base/gt_word_embed' 52 | outputdir_asr_word_embed = './feature_base/asr_word_embed' 53 | outputdir_word_phone_map = './feature_base/word_phone_map' 54 | 55 | create_dir(outputdir_pred_words_gt) 56 | create_dir(outputdir_features_p) 57 | create_dir(outputdir_features_w) 58 | create_dir(outputdir_phone_vector) 59 | create_dir(outputdir_gt_word_embed) 60 | create_dir(outputdir_asr_word_embed) 61 | create_dir(outputdir_word_phone_map) 62 | 63 | charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 64 | roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt') 65 | roberta.eval() 66 | 67 | error_list = open('error_list.txt','w') 68 | for wavname in tqdm(train_list): 69 | try: 70 | wavpath = os.path.join(wav_dir , wavname+'.wav') 71 | gt_sen_list = [] 72 | for word in data[wavname]['words']: 73 | gt_sen_list.append(word['text'].lower()) 74 | 75 | gt_sen = ' '.join(gt_sen_list) 76 | 77 | asr_sen = whisper_results[wavname].lower() 78 | asr_sen = remove_pun_except_apostrophe(asr_sen) 79 | asr_sen = convert_num_to_word(asr_sen) 80 | 81 | pred_words_gt, features_p, features_w, phone_vector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wavpath, gt_sen, asr_sen, alignment_model=charsiu, word_model=roberta) 82 | 83 | torch.save(pred_words_gt, os.path.join(outputdir_pred_words_gt, wavname+'.pt')) 84 | 
torch.save(features_p, os.path.join(outputdir_features_p, wavname+'.pt')) 85 | torch.save(features_w, os.path.join(outputdir_features_w, wavname+'.pt')) 86 | torch.save(phone_vector, os.path.join(outputdir_phone_vector, wavname+'.pt')) 87 | torch.save(gt_word_embed, os.path.join(outputdir_gt_word_embed, wavname+'.pt')) 88 | torch.save(asr_word_embed, os.path.join(outputdir_asr_word_embed, wavname+'.pt')) 89 | torch.save(word_phone_map, os.path.join(outputdir_word_phone_map, wavname+'.pt')) 90 | 91 | if len(pred_words_gt) != len(gt_sen_list): 92 | error_list.write(wavname+'#'+str(pred_words_gt)+'#'+str(gt_sen_list)+'\n') 93 | 94 | except Exception as e: 95 | print(e) 96 | error_list.write(wavname+'\n') 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torchaudio 5 | from torch.utils.data.dataset import Dataset 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | 9 | # For MyDatasetW5 10 | ASR_WORD_EMBED_DIR = 'feature_base/asr_word_embed' 11 | GT_WORD_EMBED_DIR = 'feature_base/gt_word_embed' 12 | GT_ALIGNMENT_DIR = 'feature_base/pred_words_gt' 13 | WORD_FEATURE_DIR = 'feature_base/features_w' 14 | PHONE_FEATURE_DIR = 'feature_base/features_p' 15 | PHONEVECTOR_DIR = 'feature_base/phone_vector' 16 | WORD_PHONE_MAP_DIR = 'feature_base/word_phone_map' 17 | 18 | SAMPLE_RATE=16000 19 | 20 | class MyDataset(Dataset): 21 | 22 | def __init__(self, rootdir, data_list): 23 | 24 | self.A_lookup = {} 25 | self.F_lookup = {} 26 | self.P_lookup = {} 27 | self.T_lookup = {} 28 | self.w_acc_lookup = {} 29 | self.w_stress_lookup = {} 30 | self.w_total_lookup = {} 31 | self.num_w = {} 32 | 33 | wavfiles = [] 34 | for line in data_list: 35 | parts = line.split(';') 36 | wavfile = parts[0] 37 | 38 | A = float(parts[1]) 39 | F = float(parts[2]) 40 | P = float(parts[3]) 41 | T = float(parts[4]) 42 | w_acc = parts[5].split(',') 43 | w_stress = parts[6].split(',') 44 | w_total = parts[7].split(',') 45 | 46 | w_acc = [float(x) for x in w_acc] 47 | w_stress = [float(x) for x in w_stress] 48 | w_total = [float(x) for x in w_total] 49 | 50 | num_of_word = float(parts[8]) 51 | self.A_lookup[wavfile] = A 52 | self.F_lookup[wavfile] = F 53 | self.P_lookup[wavfile] = P 54 | self.T_lookup[wavfile] = T 55 | self.w_acc_lookup[wavfile] = w_acc 56 | self.w_stress_lookup[wavfile] = w_stress 57 | self.w_total_lookup[wavfile] = w_total 58 | self.num_w[wavfile] = num_of_word 59 | 60 | wavfiles.append(wavfile) 61 | 62 | self.rootdir = rootdir 63 | self.wavfiles = sorted(wavfiles) 64 | 65 | def __getitem__(self, idx): 66 | wavfile = self.wavfiles[idx] 67 | wavpath = os.path.join(self.rootdir, wavfile) 68 | wav = torchaudio.load(wavpath)[0] 69 | 70 | try: 71 | asr_word_embed = torch.from_numpy(torch.load(os.path.join(ASR_WORD_EMBED_DIR, wavfile.replace('.wav','.pt')))).float() 72 | gt_word_embed = torch.from_numpy(torch.load(os.path.join(GT_WORD_EMBED_DIR, wavfile.replace('.wav','.pt')))).float() 73 | gt_alignment = torch.load(os.path.join(GT_ALIGNMENT_DIR, wavfile.replace('.wav','.pt'))) 74 | features_w = torch.from_numpy(torch.load(os.path.join(WORD_FEATURE_DIR, wavfile.replace('.wav','.pt')))).float() 75 | features_p = torch.from_numpy(torch.load(os.path.join(PHONE_FEATURE_DIR, wavfile.replace('.wav','.pt')))).float() 76 | phonevector = 
torch.from_numpy(torch.load(os.path.join(PHONEVECTOR_DIR, wavfile.replace('.wav','.pt')))).float() 77 | word_phone_map = torch.load(os.path.join(WORD_PHONE_MAP_DIR, wavfile.replace('.wav','.pt'))) 78 | 79 | except Exception as e: 80 | print(e, wavfile) 81 | return None 82 | 83 | num_w = int(len(gt_alignment)) 84 | 85 | timesplit = [(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in gt_alignment] 86 | 87 | s_A = self.A_lookup[wavfile] 88 | s_F = self.F_lookup[wavfile] 89 | s_P = self.P_lookup[wavfile] 90 | s_T = self.T_lookup[wavfile] 91 | 92 | w_s_acc = torch.tensor(self.w_acc_lookup[wavfile]) 93 | w_s_stress = torch.tensor(self.w_stress_lookup[wavfile]) 94 | w_s_total = torch.tensor(self.w_total_lookup[wavfile]) 95 | 96 | 97 | return wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile 98 | 99 | 100 | def __len__(self): 101 | return len(self.wavfiles) 102 | 103 | 104 | def collate_fn(self, batch): ## zero padding 105 | 106 | batch = list(filter(lambda x: x is not None, batch)) 107 | 108 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile = zip(*batch) 109 | 110 | wavs = list(wav) 111 | max_len = max(wavs, key = lambda x : x.shape[1]).shape[1] 112 | output_wavs = [] 113 | for wav in wavs: 114 | amount_to_pad = max_len - wav.shape[1] 115 | padded_wav = torch.nn.functional.pad(wav, (0, amount_to_pad), 'constant', 0) 116 | output_wavs.append(padded_wav) 117 | output_wavs = torch.stack(output_wavs, dim=0) 118 | 119 | phonevector = pad_sequence(phonevector, batch_first=True) 120 | asr_word_embed = pad_sequence(asr_word_embed, batch_first=True) 121 | gt_word_embed = pad_sequence(gt_word_embed, batch_first=True) 122 | features_w = pad_sequence(features_w, batch_first=True) 123 | features_p = pad_sequence(features_p, batch_first=True) 124 | 125 | w_s_acc = pad_sequence(w_s_acc, batch_first=True) 126 | w_s_stress = pad_sequence(w_s_stress, batch_first=True) 127 | w_s_total = pad_sequence(w_s_total, batch_first=True) 128 | s_A = torch.stack([torch.tensor(x) for x in list(s_A)], dim=0) 129 | s_F = torch.stack([torch.tensor(x) for x in list(s_F)], dim=0) 130 | s_P = torch.stack([torch.tensor(x) for x in list(s_P)], dim=0) 131 | s_T = torch.stack([torch.tensor(x) for x in list(s_T)], dim=0) 132 | timesplit = list(timesplit) 133 | word_phone_map = list(word_phone_map) 134 | 135 | return output_wavs, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile 136 | 137 | -------------------------------------------------------------------------------- /test_open.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import argparse 5 | import torch 6 | import fairseq 7 | import whisper 8 | import numpy as np 9 | import torch.nn as nn 10 | import torchaudio 11 | from tqdm import tqdm 12 | from fairseq.models.roberta import RobertaModel 13 | from dataclasses import dataclass 14 | from utils_assessment import * 15 | from model_assessment import PronunciationPredictor 16 | from Charsiu import charsiu_forced_aligner 17 | 18 | gc.collect() 19 | torch.cuda.empty_cache() 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--fairseq_base_model', 
type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.') 24 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.') 25 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of your DATA/ directory') 26 | parser.add_argument('--datalist', default='./speechocean762/speechocean762_test.txt', type=str, help='') 27 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.') 28 | 29 | 30 | args = parser.parse_args() 31 | 32 | ssl_path = args.fairseq_base_model 33 | roberta_path = args.fairseq_roberta 34 | my_checkpoint_dir = args.ckptdir 35 | datadir = args.datadir 36 | datalist = args.datalist 37 | 38 | 39 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt') 40 | word_model.eval() 41 | whisper_model_s = whisper.load_model("medium.en") 42 | whisper_model_w = whisper.load_model("base.en") 43 | 44 | 45 | aligment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 46 | 47 | SSL_OUT_DIM = 768 48 | TEXT_OUT_DIM = 768 49 | SAMPLE_RATE = 16000 50 | 51 | print('Loading checkpoint') 52 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 53 | print('DEVICE: ' + str(device)) 54 | 55 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path]) 56 | ssl_model = ssl_model[0] 57 | 58 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device) 59 | assessment_model.eval() 60 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best'))) 61 | 62 | print('Loading data') 63 | validset = open(datalist,'r').read().splitlines() 64 | outfile = my_checkpoint_dir.split("/")[-1]+'_'+datalist.split('/')[-1].replace('.txt','_mb.txt') 65 | 66 | output_dir = 'Results' 67 | if not os.path.exists(output_dir): 68 | os.makedirs(output_dir) 69 | 70 | prediction = open(os.path.join(output_dir, outfile), 'w') 71 | 72 | print('Starting prediction') 73 | for filename in tqdm(validset): 74 | 75 | with torch.no_grad(): 76 | if datalist is not None: 77 | filepath = os.path.join(datadir, filename) 78 | else: 79 | filepath=filename 80 | wav, sr = torchaudio.load(filepath) 81 | #resample audio recordin to 16000Hz 82 | if sr!=16000: 83 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE) 84 | wav = transform(wav) 85 | sr = SAMPLE_RATE 86 | 87 | wav = torch.reshape(wav, (-1,)) 88 | sen_asr_s = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_s)).lower() 89 | sen_asr_s = convert_num_to_word(sen_asr_s) 90 | 91 | sen_asr_w = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_w)).lower() 92 | sen_asr_w = convert_num_to_word(sen_asr_w) 93 | 94 | try: 95 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model) 96 | 97 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]] 98 | word_phone_map = [word_phone_map] 99 | 100 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0) 101 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0) 102 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0) 103 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0) 104 | asr_word_embed = 
torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0) 105 | wav = wav.to(device).unsqueeze(0) 106 | 107 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit) 108 | score_A = score_A.cpu().detach().numpy()[0] 109 | score_F = score_F.cpu().detach().numpy()[0] 110 | score_P = score_P.cpu().detach().numpy()[0] 111 | score_T = score_T.cpu().detach().numpy()[0] 112 | w_a = w_acc.cpu().detach().numpy()[0] 113 | w_s = w_stress.cpu().detach().numpy()[0] 114 | w_t = w_total.cpu().detach().numpy()[0] 115 | 116 | w_a = ','.join([str(num) for num in w_a]) 117 | w_s = ','.join([str(num) for num in w_s]) 118 | w_t = ','.join([str(num) for num in w_t]) 119 | 120 | valid = 'T' 121 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, score_A, score_F, score_P, score_T, valid, sen_asr_s, sen_asr_w, w_a, w_s, w_t, pred_words_gt.tolist()) 122 | print(output) 123 | prediction.write(output+'\n') 124 | 125 | except Exception as e: 126 | print(e) 127 | valid = 'F' 128 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, '', '', '', '', valid, sen_asr_s, sen_asr_w, '', '', '', '') 129 | prediction.write(output+'\n') 130 | continue 131 | 132 | 133 | torch.cuda.empty_cache() 134 | 135 | 136 | 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /test_closed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import argparse 5 | import torch 6 | import fairseq 7 | import whisper 8 | import numpy as np 9 | import torch.nn as nn 10 | import torchaudio 11 | from tqdm import tqdm 12 | from fairseq.models.roberta import RobertaModel 13 | from dataclasses import dataclass 14 | from utils_assessment import * 15 | from model_assessment import PronunciationPredictor 16 | from Charsiu import charsiu_forced_aligner 17 | 18 | gc.collect() 19 | torch.cuda.empty_cache() 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--fairseq_base_model', type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.') 24 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.') 25 | parser.add_argument('--speechocean_gt', type=str, default='./speechocean762/resource/scores.json', help='Path to speechocean scores.json') 26 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of your DATA/ directory') 27 | parser.add_argument('--datalist', default='./speechocean762/speechocean762_test.txt', type=str, help='') 28 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.') 29 | 30 | 31 | args = parser.parse_args() 32 | 33 | ssl_path = args.fairseq_base_model 34 | roberta_path = args.fairseq_roberta 35 | my_checkpoint_dir = args.ckptdir 36 | datadir = args.datadir 37 | datalist = args.datalist 38 | 39 | f = open(args.speechocean_gt) 40 | gt_data = json.load(f) 41 | 42 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt') 43 | word_model.eval() 44 | whisper_model_w = whisper.load_model("base.en") 45 | 46 | aligment_model = 
charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 47 | 48 | SSL_OUT_DIM = 768 49 | TEXT_OUT_DIM = 768 50 | SAMPLE_RATE = 16000 51 | 52 | print('Loading checkpoint') 53 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 | print('DEVICE: ' + str(device)) 55 | 56 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path]) 57 | ssl_model = ssl_model[0] 58 | 59 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device) 60 | assessment_model.eval() 61 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best'))) 62 | 63 | print('Loading data') 64 | validset = open(datalist,'r').read().splitlines() 65 | outfile = my_checkpoint_dir.split("/")[-1]+'_'+datalist.split('/')[-1].replace('.txt','_gtb.txt') 66 | 67 | output_dir = 'Results' 68 | if not os.path.exists(output_dir): 69 | os.makedirs(output_dir) 70 | 71 | prediction = open(os.path.join(output_dir, outfile), 'w') 72 | 73 | print('Starting prediction') 74 | for filename in tqdm(validset): 75 | 76 | with torch.no_grad(): 77 | if datalist is not None: 78 | filepath = os.path.join(datadir, filename) 79 | else: 80 | filepath=filename 81 | wav, sr = torchaudio.load(filepath) 82 | #resample audio recordin to 16000Hz 83 | if sr!=16000: 84 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE) 85 | wav = transform(wav) 86 | sr = SAMPLE_RATE 87 | 88 | sen_asr_s = [] 89 | for word in gt_data[filename.replace('.wav','')]['words']: 90 | sen_asr_s.append(word['text'].lower()) 91 | sen_asr_s = ' '.join(sen_asr_s) 92 | 93 | wav = torch.reshape(wav, (-1,)) 94 | 95 | sen_asr_w = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_w)).lower() 96 | sen_asr_w = convert_num_to_word(sen_asr_w) 97 | 98 | try: 99 | 100 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model) 101 | 102 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]] 103 | word_phone_map = [word_phone_map] 104 | 105 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0) 106 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0) 107 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0) 108 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0) 109 | asr_word_embed = torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0) 110 | wav = wav.to(device).unsqueeze(0) 111 | 112 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit) 113 | score_A = score_A.cpu().detach().numpy()[0] 114 | score_F = score_F.cpu().detach().numpy()[0] 115 | score_P = score_P.cpu().detach().numpy()[0] 116 | score_T = score_T.cpu().detach().numpy()[0] 117 | w_a = w_acc.cpu().detach().numpy()[0] 118 | w_s = w_stress.cpu().detach().numpy()[0] 119 | w_t = w_total.cpu().detach().numpy()[0] 120 | 121 | w_a = ','.join([str(num) for num in w_a]) 122 | w_s = ','.join([str(num) for num in w_s]) 123 | w_t = ','.join([str(num) for num in w_t]) 124 | 125 | valid = 'T' 126 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, score_A, score_F, score_P, score_T, valid, sen_asr_s, sen_asr_w, w_a, 
w_s, w_t, pred_words_gt.tolist()) 127 | print(output) 128 | prediction.write(output+'\n') 129 | 130 | except Exception as e: 131 | print(e) 132 | valid = 'F' 133 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, '', '', '', '', valid, sen_asr_s, sen_asr_w, '', '', '', '') 134 | prediction.write(output+'\n') 135 | continue 136 | 137 | 138 | torch.cuda.empty_cache() 139 | 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import time 5 | import argparse 6 | import torch 7 | import copy 8 | import fairseq 9 | import whisper 10 | import numpy as np 11 | import torch.nn as nn 12 | import torchaudio 13 | from tqdm import tqdm 14 | from fairseq.models.roberta import RobertaModel 15 | from dataclasses import dataclass 16 | from utils_assessment import * 17 | from model_assessment import PronunciationPredictor 18 | from Charsiu import charsiu_forced_aligner 19 | 20 | gc.collect() 21 | torch.cuda.empty_cache() 22 | 23 | SSL_OUT_DIM = 768 24 | TEXT_OUT_DIM = 768 25 | SAMPLE_RATE = 16000 26 | 27 | def inference_one_seg(sen_asr_s, sen_asr_w, wav): 28 | 29 | sen_asr_s = remove_pun_except_apostrophe(sen_asr_s).lower() 30 | sen_asr_s = convert_num_to_word(sen_asr_s) 31 | 32 | sen_asr_w = remove_pun_except_apostrophe(sen_asr_w).lower() 33 | sen_asr_w = convert_num_to_word(sen_asr_w) 34 | 35 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model) 36 | 37 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]] 38 | word_phone_map = [word_phone_map] 39 | 40 | wav = torch.reshape(wav, (1, -1)) 41 | wav = wav.to(device) 42 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0) 43 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0) 44 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0) 45 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0) 46 | asr_word_embed = torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0) 47 | 48 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit) 49 | 50 | torch.cuda.empty_cache() 51 | 52 | return score_A, score_F, score_P, score_T, w_acc, w_stress, w_total, pred_words_gt 53 | 54 | 55 | def predict_one_file(filepath, whisper_model_s, whisper_model_w, word_model): 56 | 57 | results = {} 58 | results['wavname'] = filepath.split('/')[-1] 59 | with torch.no_grad(): 60 | 61 | wav, sr = torchaudio.load(filepath) 62 | #resample audio recordin to 16000Hz 63 | if sr!=16000: 64 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE) 65 | wav = transform(wav) 66 | sr = SAMPLE_RATE 67 | 68 | if wav.shape[0]!=1: 69 | wav = torch.mean(wav,0) 70 | 71 | wav = torch.reshape(wav, (-1, )) 72 | 73 | if wav.shape[0] < SAMPLE_RATE*15: #if input wavfile is less than 15s, process the wavfile at once 74 | sen_asr_s = get_transcript(wav, whisper_model_s) 75 | sen_asr_w = get_transcript(wav, whisper_model_w) 76 | score_A, score_F, score_P, score_T, w_acc, 
w_stress, w_total, pred_words_gt = inference_one_seg(sen_asr_s, sen_asr_w, wav) 77 | score_A = score_A.cpu().detach().numpy()[0] 78 | score_F = score_F.cpu().detach().numpy()[0] 79 | score_P = score_P.cpu().detach().numpy()[0] 80 | score_T = score_T.cpu().detach().numpy()[0] 81 | w_a = w_acc.cpu().detach().numpy()[0] 82 | w_s = w_stress.cpu().detach().numpy()[0] 83 | w_t = w_total.cpu().detach().numpy()[0] 84 | results = {} 85 | pred_words = [word[-1] for word in pred_words_gt] 86 | results['uttr_acc'] = score_A 87 | results['uttr_fluency'] = score_F 88 | results['uttr_prosodic'] = score_P 89 | results['uttr_total'] = score_T 90 | results['word_acc'] = w_a 91 | results['word_stress'] = w_s 92 | results['word_total'] = w_t 93 | results['word_text'] = pred_words 94 | results['transcript_S'] = sen_asr_s 95 | results['transcript_W'] = sen_asr_w 96 | 97 | return results 98 | 99 | else: #if wavfile longer than 15s, do the segmentation to prevent OOM 100 | 101 | sen_asr_s_all = get_transcript(wav, whisper_model_s, return_seg=True) 102 | sen_asr_w_all = '' 103 | for seg in sen_asr_s_all['segments']: 104 | sen_asr_s = seg['text'] 105 | start = float(seg['start']) 106 | end = float(seg['end']) 107 | the_wav = wav[int(start*SAMPLE_RATE):int(end*SAMPLE_RATE)] 108 | sen_asr_w = get_transcript(the_wav, whisper_model_w) 109 | sen_asr_w_all = sen_asr_w_all+' '+sen_asr_w 110 | the_score_A, the_score_F, the_score_P, the_score_T, the_w_acc, the_w_stress, the_w_total, the_pred_words_gt = inference_one_seg(sen_asr_s, sen_asr_w, the_wav) 111 | the_pred_words = [word[-1] for word in the_pred_words_gt] 112 | try: 113 | score_A += the_score_A 114 | score_F += the_score_F 115 | score_P += the_score_P 116 | score_T += the_score_T 117 | w_acc = torch.cat((w_acc,the_w_acc.squeeze(0))) 118 | w_stress = torch.cat((w_stress,the_w_stress.squeeze(0))) 119 | w_total = torch.cat((w_acc,the_w_total.squeeze(0))) 120 | pred_word.extend(the_pred_words) 121 | except Exception as e: #first word 122 | score_A = the_score_A 123 | score_F = the_score_F 124 | score_P = the_score_P 125 | score_T = the_score_T 126 | w_acc = the_w_acc.squeeze(0) 127 | w_stress = the_w_stress.squeeze(0) 128 | w_total = the_w_total.squeeze(0) 129 | pred_word = the_pred_words 130 | 131 | num_of_seg = len(sen_asr_s_all['segments']) 132 | results['uttr_acc'] = (score_A/num_of_seg) 133 | results['uttr_fluency'] = (score_F/num_of_seg) 134 | results['uttr_prosodic'] = (score_P/num_of_seg) 135 | results['uttr_total'] = (score_T/num_of_seg) 136 | results['word_acc'] = w_acc.cpu().detach().numpy() 137 | results['word_stress'] = w_stress.cpu().detach().numpy() 138 | results['word_total'] = w_total.cpu().detach().numpy() 139 | results['word_text'] = pred_word 140 | results['transcript_S'] = sen_asr_s_all['text'] 141 | results['transcript_W'] = sen_asr_w_all 142 | 143 | return results 144 | 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument('--fairseq_base_model', type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.') 148 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.') 149 | parser.add_argument('--inputdir', type=str, help='Path to testing wavfile.') 150 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.') 151 | 152 | 153 | args = parser.parse_args() 154 | 155 | ssl_path = args.fairseq_base_model 156 | roberta_path = args.fairseq_roberta 157 | my_checkpoint_dir = args.ckptdir 158 | 
file_dir = args.inputdir 159 | 160 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt') 161 | word_model.eval() 162 | whisper_model_s = whisper.load_model("medium.en") 163 | whisper_model_w = whisper.load_model("base.en") 164 | aligment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 165 | 166 | 167 | print('Loading checkpoint') 168 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 169 | print('DEVICE: ' + str(device)) 170 | 171 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path]) 172 | ssl_model = ssl_model[0] 173 | 174 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device) 175 | assessment_model.eval() 176 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best'))) 177 | 178 | 179 | filepath_list = get_filepaths(file_dir) 180 | 181 | for filepath in filepath_list: 182 | s = time.time() 183 | results = predict_one_file(filepath, whisper_model_s, whisper_model_w, word_model) 184 | print('Process time:', time.time()-s) 185 | print(results) 186 | 187 | 188 | -------------------------------------------------------------------------------- /evaluation_speechocean_closed.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import scipy 4 | from sklearn.metrics import mean_squared_error 5 | import math 6 | import string 7 | import numpy as np 8 | import math 9 | from collections import Counter 10 | 11 | def print_result(pred, gt, score_name): 12 | mse = mean_squared_error(pred, gt) 13 | corr, _ = scipy.stats.pearsonr(pred, gt) 14 | spearman, _ = scipy.stats.spearmanr(pred, gt) 15 | #print('mse:', mse) 16 | #print('corr:', round(corr,4)) 17 | #print('srcc:', round(spearman,4)) 18 | print(score_name, round(corr,4)) 19 | 20 | 21 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json 22 | data = json.load(f) 23 | 24 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list 25 | test_data = {} 26 | for line in test_file: 27 | wavidx = line.split('\t')[0] 28 | test_data[wavidx] = data[wavidx] 29 | 30 | def get_prediction(path): 31 | invalid = 0 32 | prediction = open(path,'r').read().splitlines() 33 | result_word = {} 34 | result_uttr = {} 35 | for sample in prediction: 36 | 37 | parts = sample.split(';') 38 | wavidx = parts[0].replace('.wav','') 39 | valid = parts[5].split(':')[1] 40 | if valid=='F': 41 | invalid+=1 42 | accuracy = 1.0 43 | fluency = 0.0 44 | prosodic = 0.0 45 | total = 0.0 46 | completeness = 0.0 47 | result_word[wavidx]={} 48 | result_word[wavidx]['word_accuracy'] = 0 49 | result_word[wavidx]['word_stress'] = 5 50 | result_word[wavidx]['word_total'] = 1 51 | result_word[wavidx]['text'] = '' 52 | else: 53 | accuracy = float(parts[1].split(':')[1]) 54 | fluency = float(parts[2].split(':')[1]) 55 | prosodic = float(parts[3].split(':')[1]) 56 | total = float(parts[4].split(':')[1]) 57 | alignment = eval(parts[-1].split(':')[1]) 58 | time_interval = [float(word[1])-float(word[0]) for word in alignment] 59 | completeness = [1 if the_interval > 0.07 else 0 for the_interval in time_interval] 60 | completeness = sum(completeness)/len(completeness) 61 | 62 | w_a = eval(parts[8].split(':')[1]) 63 | w_s = eval(parts[9].split(':')[1]) 64 | w_t = eval(parts[10].split(':')[1]) 65 | if isinstance(w_a , float): 66 | w_a = [w_a] 67 | w_s = [w_s] 68 | w_t = [w_t] 69 | w_a = 
[10 if x > 10 else x for x in w_a] 70 | w_s = [10 if x > 10 else x for x in w_s] 71 | w_t = [10 if x > 10 else x for x in w_t] 72 | result_word[wavidx]={} 73 | result_word[wavidx]['word_accuracy'] = w_a 74 | result_word[wavidx]['word_stress'] = w_s 75 | result_word[wavidx]['word_total'] = w_t 76 | result_word[wavidx]['text'] = eval(parts[-1].split(':')[1]) 77 | result_word[wavidx]['text'] = [word[-1] for word in result_word[wavidx]['text']] 78 | 79 | result_uttr[wavidx]={} 80 | result_uttr[wavidx]['accuracy'] = accuracy 81 | result_uttr[wavidx]['fluency'] = fluency 82 | result_uttr[wavidx]['prosodic'] = prosodic 83 | result_uttr[wavidx]['total'] = total 84 | result_uttr[wavidx]['completeness'] = completeness 85 | 86 | #print(invalid) 87 | return result_word, result_uttr 88 | 89 | def pad_mismatch_sequence(the_gt_w_text, the_pred_w_text, the_pred_w_acc, the_pred_w_stress, the_pred_w_total): 90 | """ 91 | Sometimes the model will merge consecutive occurrences of the same word. e.g. "seven nine nine one" to "seven nine one" 92 | In this case, the number of predicted scores won't in line with the number of ground-truth words. 93 | Therefore, we duplicate the scores for the merged word. 94 | """ 95 | padded_acc = [] 96 | padded_stress = [] 97 | padded_total = [] 98 | asr_w_idx=0 99 | 100 | for gt_word in the_gt_w_text: 101 | if asr_w_idx>=len(the_pred_w_text): 102 | padded_acc.append(the_pred_w_acc[asr_w_idx-1]) 103 | padded_stress.append(the_pred_w_stress[asr_w_idx-1]) 104 | padded_total.append(the_pred_w_total[asr_w_idx-1]) 105 | break 106 | 107 | if gt_word == the_pred_w_text[asr_w_idx]: 108 | padded_acc.append(the_pred_w_acc[asr_w_idx]) 109 | padded_stress.append(the_pred_w_stress[asr_w_idx]) 110 | padded_total.append(the_pred_w_total[asr_w_idx]) 111 | asr_w_idx+=1 112 | else: 113 | padded_acc.append(the_pred_w_acc[asr_w_idx-1]) 114 | padded_stress.append(the_pred_w_stress[asr_w_idx-1]) 115 | padded_total.append(the_pred_w_total[asr_w_idx-1]) 116 | 117 | return padded_acc, padded_stress, padded_total 118 | 119 | def calculate_performance(result_word, result_uttr, wav_idx_word, wav_idx_uttr): 120 | 121 | gt_A = [] 122 | gt_F = [] 123 | gt_P = [] 124 | gt_T = [] 125 | gt_C = [] 126 | 127 | pred_A = [] 128 | pred_F = [] 129 | pred_P = [] 130 | pred_T = [] 131 | pred_C = [] 132 | 133 | for wavidx in wav_idx_uttr: 134 | gt_A.append(test_data[wavidx]['accuracy']) 135 | pred_A.append(result_uttr[wavidx]['accuracy']) 136 | gt_F.append(test_data[wavidx]['fluency']) 137 | pred_F.append(result_uttr[wavidx]['fluency']) 138 | gt_P.append(test_data[wavidx]['prosodic']) 139 | pred_P.append(result_uttr[wavidx]['prosodic']) 140 | gt_T.append(test_data[wavidx]['total']) 141 | pred_T.append(result_uttr[wavidx]['total']) 142 | gt_C.append(test_data[wavidx]['completeness']) 143 | pred_C.append(result_uttr[wavidx]['completeness']) 144 | 145 | print('number of utterance', len(pred_A)) 146 | print_result(pred_A, gt_A, 'sen-accuracy') 147 | print_result(pred_F, gt_F, 'sen-fluency') 148 | print_result(pred_P, gt_P, 'sen-prosody') 149 | print_result(pred_T, gt_T, 'sen-total') 150 | print_result(pred_C, gt_C,'sen-completeness') 151 | 152 | gt_w_acc = [] 153 | gt_w_stress = [] 154 | gt_w_total = [] 155 | pred_w_acc = [] 156 | pred_w_stress = [] 157 | pred_w_total = [] 158 | count_sen = 0 159 | for wavidx in wav_idx_word: 160 | the_gt_w_acc = [] 161 | the_gt_w_stress = [] 162 | the_gt_w_total = [] 163 | the_gt_w_text = [] 164 | 165 | for word in test_data[wavidx]['words']: 166 | 
the_gt_w_acc.append(int(word['accuracy'])) 167 | the_gt_w_stress.append(int(word['stress'])) 168 | the_gt_w_total.append(int(word['total'])) 169 | the_gt_w_text.append(word['text'].lower()) 170 | 171 | the_pred_w_acc = result_word[wavidx]['word_accuracy'] 172 | the_pred_w_stress = result_word[wavidx]['word_stress'] 173 | the_pred_w_total = result_word[wavidx]['word_total'] 174 | 175 | if len(the_gt_w_text) != len(result_word[wavidx]['text']): #if ground-truth sen and predicted sen not equal in length 176 | if result_word[wavidx]['text'] == '': # if the ASR cannot recognize the sentence, return the lowest score in the training data 177 | gt_len = len(the_gt_w_text) 178 | the_pred_w_acc = [the_pred_w_acc for _ in range(gt_len)] 179 | the_pred_w_stress = [the_pred_w_stress for _ in range(gt_len)] 180 | the_pred_w_total = [the_pred_w_total for _ in range(gt_len)] 181 | else: # for the case where the forced alignment model merges consecutive occurrences of the same word 182 | the_pred_w_acc, the_pred_w_stress, the_pred_w_total = pad_mismatch_sequence(the_gt_w_text, result_word[wavidx]['text'], the_pred_w_acc, the_pred_w_stress, the_pred_w_total) 183 | 184 | #assert len(the_gt_w_acc) == len(the_pred_w_acc) 185 | gt_w_acc.extend(the_gt_w_acc) 186 | gt_w_stress.extend(the_gt_w_stress) 187 | gt_w_total.extend(the_gt_w_total) 188 | pred_w_acc.extend(the_pred_w_acc) 189 | pred_w_stress.extend(the_pred_w_stress) 190 | pred_w_total.extend(the_pred_w_total) 191 | count_sen+=1 192 | 193 | #print('number of sentences for word prediction:', count_sen, "# of words:", len(pred_w_acc)) 194 | print_result(pred_w_acc, gt_w_acc, 'word-acc') 195 | print_result(pred_w_stress, gt_w_stress, 'word-stress') 196 | print_result(pred_w_total, gt_w_total, 'word-total') 197 | 198 | 199 | 200 | resultA_word, resultA_uttr = get_prediction('./Results/model_assessment_val9_r1_speechocean762_test_gtb.txt') 201 | 202 | calculate_performance(resultA_word, resultA_uttr, list(resultA_word.keys()),list(resultA_uttr.keys())) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import numpy as np 6 | 7 | import re 8 | from praatio import textgrid 9 | from itertools import groupby 10 | from librosa.sequence import dtw 11 | 12 | 13 | 14 | def ctc2duration(phones,resolution=0.01): 15 | """ 16 | xxxxx convert ctc to duration 17 | 18 | Parameters 19 | ---------- 20 | phones : list 21 | A list of phone sequence 22 | resolution : float, optional 23 | The resolution of xxxxx. The default is 0.01. 24 | 25 | Returns 26 | ------- 27 | merged : list 28 | xxxxx A list of duration values. 29 | 30 | """ 31 | 32 | counter = 0 33 | out = [] 34 | for p,group in groupby(phones): 35 | length = len(list(group)) 36 | out.append((round(counter*resolution,2),round((counter+length)*resolution,2),p)) 37 | counter += length 38 | 39 | merged = [] 40 | for i, (s,e,p) in enumerate(out): 41 | if i==0 and p=='[PAD]': 42 | merged.append((s,e,'[SIL]')) 43 | elif p=='[PAD]': 44 | merged.append((out[i-1][0],e,out[i-1][2])) 45 | elif i==len(out)-1: 46 | merged.append((s,e,p)) 47 | return merged 48 | 49 | 50 | def seq2duration(phones,resolution=0.01): 51 | """ 52 | xxxxx convert phone sequence to duration 53 | 54 | Parameters 55 | ---------- 56 | phones : list 57 | A list of phone sequence 58 | resolution : float, optional 59 | The resolution of xxxxx. The default is 0.01. 
60 | 61 | Returns 62 | ------- 63 | out : list 64 | xxxxx A list of duration values. 65 | 66 | """ 67 | 68 | counter = 0 69 | out = [] 70 | for p,group in groupby(phones): 71 | length = len(list(group)) 72 | out.append((round(counter*resolution,3),round((counter+length)*resolution,3),p)) 73 | counter += length 74 | return out 75 | 76 | 77 | 78 | def duration2textgrid(duration_seq,save_path=None): 79 | """ 80 | Save duration values to textgrids 81 | 82 | Parameters 83 | ---------- 84 | duration_seq : list 85 | xxxxx A list of duration values. 86 | save_path : str, optional 87 | The path to save the TextGrid files. The default is None. 88 | 89 | Returns 90 | ------- 91 | tg : TextGrid file?? str?? xxxxx? 92 | A textgrid object containing duration information. 93 | 94 | """ 95 | 96 | tg = textgrid.Textgrid() 97 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1]) 98 | tg.addTier(phoneTier) 99 | if save_path: 100 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False) 101 | return tg 102 | 103 | 104 | def word2textgrid(duration_seq,word_seq,save_path=None): 105 | """ 106 | Save duration values to textgrids 107 | 108 | Parameters 109 | ---------- 110 | duration_seq : list 111 | xxxxx A list of duration values. 112 | save_path : str, optional 113 | The path to save the TextGrid files. The default is None. 114 | 115 | Returns 116 | ------- 117 | tg : TextGrid file?? str?? xxxxx? 118 | A textgrid object containing duration information. 119 | 120 | """ 121 | 122 | tg = textgrid.Textgrid() 123 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1]) 124 | tg.addTier(phoneTier) 125 | wordTier = textgrid.IntervalTier('words', word_seq, 0, word_seq[-1][1]) 126 | tg.addTier(wordTier) 127 | if save_path: 128 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False) 129 | return tg 130 | 131 | 132 | 133 | def get_boundaries(phone_seq): 134 | """ 135 | Get time of phone boundaries 136 | 137 | Parameters 138 | ---------- 139 | phone_seq : list xxxx? 140 | A list of phone sequence. 141 | 142 | Returns 143 | ------- 144 | timings: A list of time stamps 145 | symbols: A list of phone symbols 146 | 147 | """ 148 | 149 | boundaries = defaultdict(set) 150 | for s,e,p in phone_seq: 151 | boundaries[s].update([p.upper()]) 152 | # boundaries[e].update([p.upper()+'_e']) 153 | timings = np.array(list(boundaries.keys())) 154 | symbols = list(boundaries.values()) 155 | return (timings,symbols) 156 | 157 | 158 | def check_textgrid_duration(textgrid,duration): 159 | """ 160 | Check whether the duration of a textgrid file equals to 'duration'. 161 | If not, replace duration of the textgrid file. 162 | 163 | Parameters 164 | ---------- 165 | textgrid : .TextGrid object 166 | A .TextGrid object. 167 | duration : float 168 | A given length of time. 169 | 170 | Returns 171 | ------- 172 | textgrid : .TextGrid object 173 | A modified/unmodified textgrid. 174 | 175 | """ 176 | 177 | 178 | endtime = textgrid.tierDict['phones'].entryList[-1].end 179 | if not endtime==duration: 180 | last = textgrid.tierDict['phones'].entryList.pop() 181 | textgrid.tierDict['phones'].entryList.append(last._replace(end=duration)) 182 | 183 | return textgrid 184 | 185 | 186 | def textgrid_to_labels(phones,duration,resolution): 187 | """ 188 | 189 | 190 | Parameters 191 | ---------- 192 | phones : list 193 | A list of phone sequence 194 | resolution : float, optional 195 | The resolution of xxxxx. The default is 0.01. 
196 | duration : float 197 | A given length of time. 198 | 199 | 200 | Returns 201 | ------- 202 | labels : list 203 | A list of phone labels. 204 | 205 | """ 206 | 207 | labels = [] 208 | clock = 0.0 209 | 210 | for i, (s,e,p) in enumerate(phones): 211 | 212 | assert clock >= s 213 | while clock <= e: 214 | labels.append(p) 215 | clock += resolution 216 | 217 | # if more than half of the current frame is outside the current phone 218 | # we'll label it as the next phone 219 | if np.abs(clock-e) > resolution/2: 220 | labels[-1] = phones[min(len(phones)-1,i+1)][2] 221 | 222 | # if the final time interval is longer than the total duration 223 | # we will chop off this frame 224 | if clock-duration > resolution/2: 225 | labels.pop() 226 | 227 | return labels 228 | 229 | def remove_null_and_numbers(labels): 230 | """ 231 | Remove labels which are null, noise, or numbers. 232 | 233 | Parameters 234 | ---------- 235 | labels : list 236 | A list of text labels. 237 | 238 | Returns 239 | ------- 240 | out : list 241 | A list of new labels. 242 | 243 | """ 244 | 245 | out = [] 246 | noises = set(['SPN','NSN','LAU']) 247 | for l in labels: 248 | l = re.sub(r'\d+','',l) 249 | l = l.upper() 250 | if l == '' or l == 'SIL': 251 | l = '[SIL]' 252 | if l == 'SP': 253 | l = '[SIL]' 254 | if l in noises: 255 | l = '[UNK]' 256 | out.append(l) 257 | return out 258 | 259 | 260 | def insert_sil(phones): 261 | """ 262 | Insert silences. 263 | 264 | Parameters 265 | ---------- 266 | phones : list 267 | A list of phone sequence 268 | 269 | Returns 270 | ------- 271 | out : list 272 | A list of new labels. 273 | 274 | """ 275 | 276 | out = [] 277 | for i,(s,e,p) in enumerate(phones): 278 | 279 | if out: 280 | if out[-1][1]!=s: 281 | out.append((out[-1][1],s,'[SIL]')) 282 | out.append((s,e,p)) 283 | return out 284 | 285 | 286 | def forced_align(cost, phone_ids): 287 | 288 | """ 289 | Force align text to audio. 290 | 291 | Parameters 292 | ---------- 293 | cost : float xxxxx 294 | xxxxx. 295 | phone_ids : list 296 | A list of phone IDs. 297 | 298 | Returns 299 | ------- 300 | align_id : list 301 | A list of IDs for aligned phones. 302 | 303 | """ 304 | 305 | D,align = dtw(C=-cost[:,phone_ids], 306 | step_sizes_sigma=np.array([[1, 1], [1, 0]])) 307 | 308 | align_seq = [-1 for i in range(max(align[:,0])+1)] 309 | for i in list(align): 310 | # print(align) 311 | if align_seq[i[0]] 10 | 11 | 12 | - [Requirement](#Requirement) 13 | - [Train and evalaute on speechocean762 dataset](#Train-and-evalaute-on-speechocean762-dataset) 14 | - [Test on your data](#Test-on-your-data) 15 | - [References](#References) 16 | - [MultiPA data](#MultiPA-data) 17 | - [Citation](#Citation) 18 | 19 | 20 | ## Requirement 21 | 22 | ``` 23 | conda create -n MultiPA python=3.9 24 | conda activate MultiPA 25 | pip install fairseq 26 | pip install soundfile 27 | pip install -U openai-whisper 28 | pip install transformers 29 | pip install num2words 30 | pip install pyctcdecode 31 | pip install https://github.com/kpu/kenlm/archive/master.zip 32 | pip install spacy==2.3.0 33 | pip install levenshtein 34 | pip install nltk 35 | pip install praatio 36 | pip install g2pM 37 | pip install librosa 38 | pip install g2p-en 39 | pip install pandas 40 | ``` 41 | Note: spacy needs to be 2.x version 42 | 43 | 44 | #### Download pre-trained model 45 | (1) Download [HuBERT Base (~95M params)](https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md), and put the hubert_base_ls960.pt in fairseq_hubert dir. 
46 | (2) Download [roberta.base model](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md), and put the model.pt and dict.txt in fairseq_roberta dir. 47 | 48 | ## Train and evaluate on speechocean762 dataset 49 | 50 | ### Step 1. Data Preparation 51 | (1) Download the speechocean762 dataset: [Link](https://www.openslr.org/101). 52 | (2) Put the resource, test, train, and WAVE directories in the speechocean762 dir. 53 | - i.e., merge the speechocean762 dir in this repo with the downloaded speechocean762 dir. 54 | 55 | (3) Convert the .WAV files to .wav files using: 56 | ``` 57 | cd speechocean762 58 | python convert_wavfile.py 59 | ``` 60 | (4) Generate the training and testing lists using: 61 | ``` 62 | cd speechocean762 63 | python create_training_and_testing_list.py 64 | ``` 65 | (5) Obtain the Whisper ASR results: 66 | ``` 67 | python whisper_asr_all.py 68 | ``` 69 | (6) Generate the training features: 70 | ``` 71 | python get_training_features.py 72 | ``` 73 | 74 | ### Step 2. Model training 75 | ``` 76 | python model_assessment.py --outdir model_assessment 77 | ``` 78 | Note: usually, the validation loss stops decreasing after 2 epochs. 79 | 80 | ### Step 3. Inference 81 | Get the assessment results in the closed response scenario (using the ground-truth transcript): 82 | ``` 83 | python test_closed.py --ckptdir model_assessment 84 | ``` 85 | The results will be saved in "model_assessment_speechocean762_test_gtb.txt" with the format: 86 | {wavname}; A:{uttr-accuracy}; F:{uttr-fluency}; P:{uttr-prosodic}; T:{uttr-total}; Valid:{whether_output_is_valid}; ASR_s:{ground-truth-sentence}; ASR_w:{asr-results}; w_a:{word-accuracy}; w_s:{word-stress}; w_t:{word-total}; alignment:{forced-alignment-result} 87 | 88 | 89 | Get the assessment results in the open response scenario (using the result of ASRt as the ground-truth transcript): 90 | ``` 91 | python test_open.py --ckptdir model_assessment 92 | ``` 93 | The results will be saved in "model_assessment_speechocean762_test_mb.txt". 94 | 95 | 96 | ### Step 4. Evaluation 97 | ----- 98 | ### Closed response scenario 99 | 100 | Use "evaluation_speechocean_closed.py". Change the input path of the "get_prediction" function to the path of the txt file generated in Step 3. 101 | 102 | Note: 103 | - Since Whisper might give different recognition results for the same utterance, the performance scores will be slightly different for different runs. 104 | - For utterances that MultiPA fails to process, the lowest scores in the training data are used (i.e., accuracy = 1, fluency = 0, prosodic = 0, total = 0, completeness = 0, word_accuracy = 0, word_stress = 5, and word_total = 1). 105 | - Scores higher than 10 are clipped to 10. 106 | - The scores in the paper are the average of five models trained with different random seeds. 107 | 108 | Closed response performance (PCC): 109 | | sen-accuracy | sen-fluency | sen-prosody | sen-total | word-accuracy | word-stress | word-total | 110 | |---------------|--------------|---------------|-------------|---------------|-------------|------------| 111 | | ~0.73 | ~0.79 | ~0.78 | ~0.76 | ~0.52 | ~0.19 | ~0.53 | 112 | 113 | ### Completeness assessment metric 114 | According to the results of previous studies, the completeness score is difficult to learn using a neural network. Therefore, we propose a forced-alignment-based method to assess completeness instead of training the completeness score with the main structure. The completeness score is defined as whether an L2 learner completes all words in the target sentence without any omissions. The idea is that if the L2 learner misses a word during pronunciation, the forced-alignment model will have difficulty aligning that word to the speech signal, resulting in a shorter duration for the missed word compared with properly pronounced words. Then, a predefined duration threshold, selected based on empirical experiments, classifies words as complete or incomplete. Finally, the completeness score is calculated as the ratio of complete words to the total number of words in the transcript. 115 | 116 | We analyze the word duration distribution from the forced-alignment model. In this experiment, complete words are drawn from the speechocean762 dataset with full completeness scores, whereas incomplete words are simulated by randomly inserting an extra word into the original sentence. Complete words follow a nearly normal distribution, averaging around 0.38 seconds. In contrast, incomplete words exhibit a right-skewed pattern, with a mean of approximately 0.075 seconds. Next, we assess the effectiveness of various duration thresholds in distinguishing complete from incomplete words. Our findings reveal that a threshold of roughly 0.07 seconds yields the highest F1 score (approximately 85%), indicating that the forced-alignment model can effectively detect missing words in a target sentence.
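For illustration, this completeness rule (as used in evaluation_speechocean_closed.py) boils down to the small helper below; this is a simplified sketch, and the alignment tuples in the usage example are made up:
```
# Sketch of the forced-alignment-based completeness score.
# `alignment` is a list of (start_sec, end_sec, word) tuples from the forced aligner;
# words whose duration falls below the threshold are treated as omitted.
def completeness_score(alignment, threshold=0.07):
    durations = [float(end) - float(start) for start, end, _word in alignment]
    is_complete = [d > threshold for d in durations]
    return sum(is_complete) / len(is_complete)

# Example with a made-up alignment: the second word is too short to count as pronounced.
completeness_score([(0.00, 0.41, 'we'), (0.41, 0.45, 'call'), (0.45, 0.90, 'it')])  # -> 0.67
```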
The completeness score is defined as whether an L2 learner produces all the words in the target sentence without omissions. The idea is that if the L2 learner misses a word during pronunciation, the forced-alignment model will have difficulty aligning that word to the speech signal, resulting in a shorter duration for the missed word compared with properly pronounced words. A predefined duration threshold, selected based on empirical experiments, can then classify words as complete or incomplete. Finally, the completeness score is calculated as the ratio of complete words to the total number of words in the transcript.
115 | 
116 | We analyze the word duration distribution obtained from the forced-alignment model. In this experiment, complete words are drawn from speechocean762 utterances with full completeness scores, whereas incomplete words are simulated by randomly inserting an extra word into the original sentence. Complete words follow a nearly normal distribution, averaging around 0.38 seconds. In contrast, incomplete words exhibit a right-skewed pattern, with a mean of approximately 0.075 seconds. Next, we assess how effectively various duration thresholds distinguish complete from incomplete words. A threshold of roughly 0.07 seconds yields the highest F1 score, approximately 85%, indicating that the forced-alignment model can effectively detect missing words in a target sentence.
117 | 
118 | 
119 | -----
120 | ### Open response scenario
121 | 
122 | Use "evaluation_speechocean_open.py".
123 | (1) Calculate and save the alignment information of the ground-truth transcript using get_gt_alignment.py. Change the gt_alignment_dir path in line 163 to the directory where these alignments are saved.
124 | (2) Change the input path of the "get_prediction" function to the path of the txt file generated in Step 3.
125 | Note:
126 | - The evaluation of the word-level assessment results differs from the closed response scenario because there is a potential mismatch between the ground-truth labels and the predicted scores (the ground-truth labels are aligned with the ground-truth words, whereas the predicted word-level scores are aligned with the ASR-recognized words).
127 | - Since Whisper might give different recognition results for the same utterance, the performance scores will differ slightly between runs.
128 | 
129 | | sen-accuracy | sen-fluency | sen-prosody | sen-total | word-accuracy | word-stress | word-total |
130 | |---------------|--------------|---------------|-------------|---------------|-------------|------------|
131 | | ~0.70 | ~0.77 | ~0.76 | ~0.73 | ~0.42 | ~0.24 | ~0.44 |
132 | 
133 | 
134 | 
135 | ## Test on your data
136 | 
137 | ```
138 | python api.py --inputdir /path/to/your/wav/dir --ckptdir model_assessment
139 | ```
140 | Note:
141 | - This API works in the open response scenario. Please replace "sen_asr_s" with the target sentence if you want to test in the closed response scenario.
142 | - One limitation of MultiPA is its handling of long utterances. First, MultiPA might fail to process a long utterance due to GPU out-of-memory issues. In addition, its generalization ability might be limited because it was trained on utterances shorter than 20 seconds. Therefore, an additional audio segmentation step is recommended when using MultiPA on long utterances. In api.py, we implement a simple segmentation method based on Whisper's results.
Specifically, if a wave file is longer than 15 seconds, the model works on the Whisper segments and merges (averages) the segment-level results instead of processing the entire wave file at once.
143 | 
144 | Pretrained model:
145 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/1Kpm3BeEh6Rh7JZ5tatyHMUMipuo0RYds/view?usp=sharing)
146 | 
147 | ## References
148 | The Charsiu.py, models.py, processors.py, and utils.py in this repo are revised from [Charsiu](https://github.com/lingjzhu/charsiu/tree/main).
149 | The major changes include:
150 | (1) returning the output embedding (i.e., the probability of all possible phones)
151 | (2) preventing the merging of the durations of multiple identical words
152 | (e.g., transcript: very very -> return (s1, e1, very), (s2, e2, very) instead of (s1, e2, very))
153 | -> However, in some cases, the model will still return only one alignment result, leading to a mismatch between the words in the input sentence and the aligned words.
154 | 
155 | ## MultiPA data
156 | 
157 | A pilot dataset for speech assessment in a real-world open response scenario.
158 | [Download](https://drive.google.com/drive/folders/1T1_xTcwPF94WtUU4XVId7bGYPvAvwCSZ?usp=sharing)
159 | Please submit an access request, and I will add your email to the share list.
160 | 
161 | ### Acknowledgement
162 | This data was collected using the [Label Studio](https://labelstud.io/) Academic Program. We thank Label Studio for providing a platform that allows researchers to collect data easily.
163 | 
164 | ## Citation
165 | Please cite our paper if you find this repository useful. Thanks!
166 | 
167 | @inproceedings{chen2024multipa,
168 | title = {{MultiPA}: A Multi-task Speech Pronunciation Assessment Model for Open Response Scenarios},
169 | author = {Chen, Yu-Wen and Yu, Zhou and Hirschberg, Julia},
170 | booktitle = {Proc. INTERSPEECH 2024}
171 | }
172 | 
--------------------------------------------------------------------------------
/evaluation_speechocean_open.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import scipy
4 | from sklearn.metrics import mean_squared_error
5 | import math
6 | import string
7 | import numpy as np
8 | import math
9 | import torch
10 | 
11 | def pad_mismatch_sequence(the_gt_w_text, the_pred_w_text, the_pred_w_acc, the_pred_w_stress, the_pred_w_total):
12 |     """
13 |     Sometimes the model will merge consecutive occurrences of the same word, e.g., "seven nine nine one" to "seven nine one".
14 |     In this case, the number of predicted scores won't be in line with the number of ground-truth words.
15 |     Therefore, we duplicate the scores for the merged word.
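    For illustration (hypothetical score values):
        the_gt_w_text   = ['seven', 'nine', 'nine', 'one']
        the_pred_w_text = ['seven', 'nine', 'one']
        the_pred_w_acc  = [8.0, 7.5, 9.0]
        -> padded_acc   = [8.0, 7.5, 7.5, 9.0]   (the score of the merged "nine" is reused)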
16 | """ 17 | padded_acc = [] 18 | padded_stress = [] 19 | padded_total = [] 20 | asr_w_idx=0 21 | 22 | for gt_word in the_gt_w_text: 23 | if asr_w_idx>=len(the_pred_w_text): 24 | padded_acc.append(the_pred_w_acc[asr_w_idx-1]) 25 | padded_stress.append(the_pred_w_stress[asr_w_idx-1]) 26 | padded_total.append(the_pred_w_total[asr_w_idx-1]) 27 | break 28 | 29 | if gt_word == the_pred_w_text[asr_w_idx]: 30 | padded_acc.append(the_pred_w_acc[asr_w_idx]) 31 | padded_stress.append(the_pred_w_stress[asr_w_idx]) 32 | padded_total.append(the_pred_w_total[asr_w_idx]) 33 | asr_w_idx+=1 34 | else: 35 | padded_acc.append(the_pred_w_acc[asr_w_idx-1]) 36 | padded_stress.append(the_pred_w_stress[asr_w_idx-1]) 37 | padded_total.append(the_pred_w_total[asr_w_idx-1]) 38 | 39 | return padded_acc, padded_stress, padded_total 40 | 41 | def align_two_sentences(result_gt, result_asr): 42 | 43 | asr_wordidx_list = [] 44 | for _, gt_value in enumerate(result_gt): 45 | gt_start = gt_value[0] 46 | gt_end = gt_value[1] 47 | asr_wordidx = [] 48 | for asr_idx, asr_value in enumerate(result_asr): 49 | asr_start = asr_value[0] 50 | asr_end = asr_value[1] 51 | if gt_end <= asr_start: 52 | break 53 | if gt_start >= asr_end: 54 | continue 55 | if max(gt_start, asr_start) <= min(gt_end, asr_end): 56 | asr_wordidx.append(asr_idx) 57 | 58 | asr_wordidx_list.append(asr_wordidx) 59 | 60 | return asr_wordidx_list 61 | 62 | def print_result(pred, gt): 63 | mse = mean_squared_error(pred, gt) 64 | corr, _ = scipy.stats.pearsonr(pred, gt) 65 | spearman, _ = scipy.stats.spearmanr(pred, gt) 66 | #print('mse:', mse) 67 | #print('corr:', round(corr,4)) 68 | #print('srcc:', round(spearman,4)) 69 | print(round(corr,4)) 70 | 71 | 72 | def align_two_sentences(result_gt, result_asr): 73 | 74 | asr_wordidx_list = [] 75 | for _, gt_value in enumerate(result_gt): 76 | gt_start = gt_value[0] 77 | gt_end = gt_value[1] 78 | asr_wordidx = [] 79 | for asr_idx, asr_value in enumerate(result_asr): 80 | asr_start = asr_value[0] 81 | asr_end = asr_value[1] 82 | if gt_end <= asr_start: 83 | break 84 | if gt_start >= asr_end: 85 | continue 86 | if max(gt_start, asr_start) <= min(gt_end, asr_end): 87 | asr_wordidx.append(asr_idx) 88 | 89 | asr_wordidx_list.append(asr_wordidx) 90 | 91 | return asr_wordidx_list 92 | 93 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json 94 | data = json.load(f) 95 | 96 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list 97 | test_data = {} 98 | for line in test_file: 99 | wavidx = line.split('\t')[0] 100 | test_data[wavidx] = data[wavidx] 101 | 102 | def get_prediction(path): 103 | invalid = 0 104 | prediction = open(path,'r').read().splitlines() 105 | 106 | print('# of prediction:', len(prediction)) 107 | result_word = {} 108 | result_uttr = {} 109 | for sample in prediction: 110 | 111 | parts = sample.split(';') 112 | wavidx = parts[0].replace('.wav','') 113 | valid = parts[5].split(':')[1] 114 | if valid=='F': 115 | invalid+=1 116 | accuracy = 1.0 117 | fluency = 0.0 118 | prosodic = 0.0 119 | total = 0.0 120 | result_word[wavidx]={} 121 | result_word[wavidx]['word_accuracy'] = 0 122 | result_word[wavidx]['word_stress'] = 5 123 | result_word[wavidx]['word_total'] = 1 124 | result_word[wavidx]['text'] = '' 125 | result_word[wavidx]['alignment'] = None 126 | else: 127 | accuracy = float(parts[1].split(':')[1]) 128 | fluency = float(parts[2].split(':')[1]) 129 | prosodic = float(parts[3].split(':')[1]) 130 | total = 
float(parts[4].split(':')[1]) 131 | 132 | w_a = eval(parts[8].split(':')[1]) 133 | w_s = eval(parts[9].split(':')[1]) 134 | w_t = eval(parts[10].split(':')[1]) 135 | if isinstance(w_a , float): 136 | w_a = [w_a] 137 | w_s = [w_s] 138 | w_t = [w_t] 139 | w_a = [10 if x > 10 else x for x in w_a] 140 | w_s = [10 if x > 10 else x for x in w_s] 141 | w_t = [10 if x > 10 else x for x in w_t] 142 | result_word[wavidx]={} 143 | result_word[wavidx]['word_accuracy'] = w_a 144 | result_word[wavidx]['word_stress'] = w_s 145 | result_word[wavidx]['word_total'] = w_t 146 | result_word[wavidx]['text'] = eval(parts[-1].split(':')[1]) 147 | result_word[wavidx]['text'] = [word[-1] for word in result_word[wavidx]['text']] 148 | result_word[wavidx]['alignment'] = eval(parts[-1].split(':')[1]) 149 | 150 | result_uttr[wavidx]={} 151 | result_uttr[wavidx]['accuracy'] = accuracy 152 | result_uttr[wavidx]['fluency'] = fluency 153 | result_uttr[wavidx]['prosodic'] = prosodic 154 | result_uttr[wavidx]['total'] = total 155 | 156 | 157 | print('# of invalid sentence:', invalid) 158 | return result_word, result_uttr 159 | 160 | 161 | def calculate_performance(result_word, result_uttr, wav_idx_word, wav_idx_uttr): 162 | 163 | gt_alignment_dir = '../gt_alignment_test' 164 | gt_A = [] 165 | gt_F = [] 166 | gt_P = [] 167 | gt_T = [] 168 | 169 | pred_A = [] 170 | pred_F = [] 171 | pred_P = [] 172 | pred_T = [] 173 | 174 | for wavidx in wav_idx_uttr: 175 | gt_A.append(test_data[wavidx]['accuracy']) 176 | pred_A.append(result_uttr[wavidx]['accuracy']) 177 | gt_F.append(test_data[wavidx]['fluency']) 178 | pred_F.append(result_uttr[wavidx]['fluency']) 179 | gt_P.append(test_data[wavidx]['prosodic']) 180 | pred_P.append(result_uttr[wavidx]['prosodic']) 181 | gt_T.append(test_data[wavidx]['total']) 182 | pred_T.append(result_uttr[wavidx]['total']) 183 | 184 | print('number of utterance', len(pred_A)) 185 | #print('accuracy') 186 | print_result(pred_A, gt_A) 187 | #print('fluency') 188 | print_result(pred_F, gt_F) 189 | #print('prosodic') 190 | print_result(pred_P, gt_P) 191 | #print('total') 192 | print_result(pred_T, gt_T) 193 | 194 | gt_w_acc = [] 195 | gt_w_stress = [] 196 | gt_w_total = [] 197 | pred_w_acc = [] 198 | pred_w_stress = [] 199 | pred_w_total = [] 200 | count_sen = 0 201 | for wavidx in wav_idx_word: 202 | the_gt_w_acc = [] 203 | the_gt_w_stress = [] 204 | the_gt_w_total = [] 205 | the_gt_w_text = [] 206 | 207 | for word in test_data[wavidx]['words']: 208 | the_gt_w_acc.append(int(word['accuracy'])) 209 | the_gt_w_stress.append(int(word['stress'])) 210 | the_gt_w_total.append(int(word['total'])) 211 | the_gt_w_text.append(word['text'].lower()) 212 | 213 | the_pred_w_acc = result_word[wavidx]['word_accuracy'] 214 | the_pred_w_stress = result_word[wavidx]['word_stress'] 215 | the_pred_w_total = result_word[wavidx]['word_total'] 216 | 217 | if result_word[wavidx]['alignment'] is None: #model fails 218 | gt_len = len(the_gt_w_text) 219 | the_pred_w_acc = [the_pred_w_acc for _ in range(gt_len)] 220 | the_pred_w_stress = [the_pred_w_stress for _ in range(gt_len)] 221 | the_pred_w_total = [the_pred_w_total for _ in range(gt_len)] 222 | gt_w_acc.extend(the_gt_w_acc) 223 | gt_w_stress.extend(the_gt_w_stress) 224 | gt_w_total.extend(the_gt_w_total) 225 | pred_w_acc.extend(the_pred_w_acc) 226 | pred_w_stress.extend(the_pred_w_stress) 227 | pred_w_total.extend(the_pred_w_total) 228 | continue 229 | 230 | pred_sen = ' '.join(result_word[wavidx]['text']) 231 | gt_sen = ' '.join(the_gt_w_text) 232 | if pred_sen!=gt_sen: 233 | 
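            # The ASR transcript differs from the ground truth, so the score sequences cannot be compared word by word.
            # Load the saved ground-truth alignment and, for each ASR-recognized word, average the scores of the
            # ground-truth words that overlap it in time (align_two_sentences); ASR words with no overlapping
            # ground-truth word fall back to the lowest training scores (accuracy 0, stress 5, total 1).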
234 | pred_alignment = result_word[wavidx]['alignment'] 235 | gt_alignment = torch.load(os.path.join(gt_alignment_dir, wavidx+'.pt')) 236 | align_result = align_two_sentences(pred_alignment, gt_alignment) 237 | 238 | the_gt_w_acc = np.asarray(the_gt_w_acc) 239 | the_gt_w_stress = np.asarray(the_gt_w_stress) 240 | the_gt_w_total = np.asarray(the_gt_w_total) 241 | 242 | align_gt_w_acc = [] 243 | align_gt_w_stress = [] 244 | align_gt_w_total = [] 245 | for widxs in align_result: 246 | if len(widxs)!=0: 247 | align_gt_w_acc.append(np.mean(the_gt_w_acc[widxs])) 248 | align_gt_w_stress.append(np.mean(the_gt_w_stress[widxs])) 249 | align_gt_w_total.append(np.mean(the_gt_w_total[widxs])) 250 | else: 251 | align_gt_w_acc.append(0) 252 | align_gt_w_stress.append(5) 253 | align_gt_w_total.append(1) 254 | 255 | gt_w_acc.extend(align_gt_w_acc) 256 | gt_w_stress.extend(align_gt_w_stress) 257 | gt_w_total.extend(align_gt_w_total) 258 | pred_w_acc.extend(the_pred_w_acc) 259 | pred_w_stress.extend(the_pred_w_stress) 260 | pred_w_total.extend(the_pred_w_total) 261 | 262 | else: # prediction the same as the ground-truth 263 | 264 | if len(the_gt_w_text) != len(result_word[wavidx]['text']): #condition that forced alignment merge the duplicated sentence 265 | the_pred_w_acc, the_pred_w_stress, the_pred_w_total = pad_mismatch_sequence(the_gt_w_text, result_word[wavidx]['text'], the_pred_w_acc, the_pred_w_stress, the_pred_w_total) 266 | 267 | gt_w_acc.extend(the_gt_w_acc) 268 | gt_w_stress.extend(the_gt_w_stress) 269 | gt_w_total.extend(the_gt_w_total) 270 | pred_w_acc.extend(the_pred_w_acc) 271 | pred_w_stress.extend(the_pred_w_stress) 272 | pred_w_total.extend(the_pred_w_total) 273 | 274 | 275 | count_sen+=1 276 | #print(pred_sen, gt_sen) 277 | 278 | print('number of sentences for word prediction:', count_sen, "# of words:", len(pred_w_acc)) 279 | #print('word acc') 280 | print_result(pred_w_acc, gt_w_acc) 281 | #print('word stress') 282 | print_result(pred_w_stress, gt_w_stress) 283 | #print('word total') 284 | print_result(pred_w_total, gt_w_total) 285 | 286 | 287 | 288 | resultA_word, resultA_uttr = get_prediction('./Results/model_assessment_val9_r2_speechocean762_test_mb.txt') 289 | calculate_performance(resultA_word, resultA_uttr, list(resultA_word.keys()),list(resultA_uttr.keys())) -------------------------------------------------------------------------------- /model_assessment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import fairseq 4 | import random 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from tqdm import tqdm 11 | from torch.utils.data import DataLoader 12 | from dataloader import MyDataset 13 | 14 | 15 | random.seed(1984) 16 | 17 | 18 | class PronunciationPredictor(nn.Module): 19 | 20 | def __init__(self, ssl_model, ssl_out_dim, text_out_dim): 21 | super(PronunciationPredictor, self).__init__() 22 | self.ssl_model = ssl_model 23 | self.ssl_features = ssl_out_dim # size of HuBERT embedding 24 | self.text_out_dim = text_out_dim # size of roberta embedding 25 | self.w_features_size = 10 26 | self.p_align_size = 9 27 | self.p_pred_size = 42 28 | self.phone_vector = 71*2 29 | 30 | hidden_word = self.ssl_features+self.text_out_dim*2+self.phone_vector+self.w_features_size+self.p_align_size+self.p_pred_size*2+1 31 | hidden_sen = 768+hidden_word 32 | 33 | self.output_accuracy = nn.Linear(hidden_sen, 1) 34 | self.output_fluency = 
nn.Linear(hidden_sen, 1) 35 | self.output_prosodic = nn.Linear(hidden_sen, 1) 36 | self.output_total = nn.Linear(hidden_sen, 1) 37 | 38 | self.fusion_layer = nn.TransformerEncoderLayer(d_model=hidden_word, nhead=30) 39 | 40 | self.fusion_layer_word = nn.TransformerEncoderLayer(d_model=hidden_word, nhead=30) 41 | self.w_feature_layer = nn.TransformerEncoderLayer(d_model=self.w_features_size, nhead=5) 42 | self.p_feature_layer_align = nn.TransformerEncoderLayer(d_model=self.p_align_size, nhead=3) 43 | self.p_feature_layer_pred_gt = nn.TransformerEncoderLayer(d_model=self.p_pred_size, nhead=7) 44 | self.p_feature_layer_pred_asr = nn.TransformerEncoderLayer(d_model=self.p_pred_size, nhead=7) 45 | self.p_feature_layer_1 = nn.TransformerEncoderLayer(d_model=self.p_align_size+self.p_pred_size*2, nhead=3) 46 | self.p_feature_layer_2 = nn.TransformerEncoderLayer(d_model=self.p_align_size+self.p_pred_size*2, nhead=3) 47 | self.phonevector_layer = nn.TransformerEncoderLayer(d_model=self.phone_vector, nhead=2) 48 | 49 | self.word_acc = nn.Conv1d(hidden_word, 1, kernel_size=1) 50 | self.word_stress = nn.Conv1d(hidden_word, 1, kernel_size=1) 51 | self.word_total = nn.Conv1d(hidden_word, 1, kernel_size=1) 52 | 53 | def forward(self, wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit): 54 | 55 | res = self.ssl_model(wav, mask=False, features_only=True) 56 | wav_embedding_raw = res['x'] 57 | 58 | ### align word-level features to the wavform 59 | batch_size = gt_word_embed.shape[0] 60 | wav_aligned = torch.zeros((gt_word_embed.shape[0], gt_word_embed.shape[1], self.ssl_features)).cuda() 61 | for b_idx in range(batch_size): 62 | for w_idx in range(len(timesplit[b_idx])): 63 | start_point = timesplit[b_idx][w_idx][0] // 320 64 | end_point = timesplit[b_idx][w_idx][1] // 320 65 | if (end_point - start_point)==0: # avoid predict nan because of no aligned wav segment 66 | the_word = wav_embedding_raw[b_idx, start_point:start_point+1, :] 67 | else: 68 | the_word = wav_embedding_raw[b_idx, start_point:end_point, :] 69 | aligned_wav_embed = the_word.mean(dim=0) 70 | wav_aligned[b_idx, w_idx, :] = aligned_wav_embed 71 | 72 | 73 | features_w = self.w_feature_layer(features_w) 74 | features_p = self.p_feature_layer_1(features_p) 75 | features_p[:,:,:self.p_align_size] = self.p_feature_layer_align(features_p[:,:,:self.p_align_size]) 76 | features_p[:,:,self.p_align_size:self.p_align_size+self.p_pred_size] = self.p_feature_layer_pred_gt(features_p[:,:,self.p_align_size:self.p_align_size+self.p_pred_size]) 77 | features_p[:,:,self.p_align_size+self.p_pred_size:] = self.p_feature_layer_pred_asr(features_p[:,:,self.p_align_size+self.p_pred_size:]) 78 | 79 | # align phone-level features to word-level features 80 | features_p_aligned = torch.zeros((gt_word_embed.shape[0], gt_word_embed.shape[1], self.p_align_size+self.p_pred_size*2)).cuda() 81 | for b_idx in range(batch_size): 82 | for w_idx, p_list in enumerate(word_phone_map[b_idx]): 83 | features_p_aligned[b_idx, w_idx, :] = features_p[b_idx,p_list,:].mean(dim=0) 84 | 85 | features_p_aligned = self.p_feature_layer_2(features_p_aligned) 86 | phonevector = self.phonevector_layer(phonevector) 87 | fusion = torch.cat([wav_aligned, gt_word_embed, asr_word_embed, features_w, features_p_aligned, phonevector], dim=2) 88 | fusion = F.pad(fusion, (0, 1), mode='constant') # expand one dimension because original feature size is a prime number 89 | 90 | fusion_word = self.fusion_layer_word(fusion) 91 | fusion_word = 
fusion_word.transpose(1, 2) 92 | output_w_acc = self.word_acc(fusion_word) 93 | output_w_stress = self.word_stress(fusion_word) 94 | output_w_total = self.word_total(fusion_word) 95 | output_w_acc = output_w_acc.transpose(1,2).squeeze(2) 96 | output_w_stress = output_w_stress.transpose(1,2).squeeze(2) 97 | output_w_total = output_w_total.transpose(1,2).squeeze(2) 98 | 99 | 100 | fusion = self.fusion_layer(fusion) 101 | uttr_word = torch.mean(fusion, 1) 102 | wav_embedding = torch.mean(wav_embedding_raw, 1) 103 | uttr = torch.cat([wav_embedding, uttr_word], dim=1) 104 | output_A = self.output_accuracy(uttr) 105 | output_F = self.output_fluency(uttr) 106 | output_P = self.output_prosodic(uttr) 107 | output_T = self.output_total(uttr) 108 | 109 | 110 | return output_A.squeeze(1), output_F.squeeze(1), output_P.squeeze(1), output_T.squeeze(1), output_w_acc, output_w_stress, output_w_total 111 | 112 | 113 | 114 | def main(): 115 | 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path to root data directory') 118 | parser.add_argument('--txtfiledir', default='./speechocean762', type=str, help='Path to training txt directory') 119 | parser.add_argument('--fairseq_base_model', default='./fairseq_hubert/hubert_base_ls960.pt', type=str, help='Path to pretrained fairseq base model') 120 | parser.add_argument('--finetune_from_checkpoint', type=str, required=False, help='Path to the checkpoint to finetune from') 121 | parser.add_argument('--outdir', type=str, required=False, default='model_assessment', help='Output directory for trained checkpoints') 122 | 123 | args = parser.parse_args() 124 | 125 | cp_path = args.fairseq_base_model 126 | datadir = args.datadir 127 | ckptdir = args.outdir 128 | txtfiledir = args.txtfiledir 129 | my_checkpoint_dir = args.finetune_from_checkpoint 130 | 131 | if not os.path.exists(ckptdir): 132 | os.makedirs(os.path.join(ckptdir,'PRO')) 133 | 134 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 135 | print('DEVICE: ' + str(device)) 136 | 137 | wavdir = os.path.join(datadir, '') 138 | input_list = open(os.path.join(txtfiledir, 'speechocean762_train.txt'),'r').read().splitlines() 139 | random.shuffle(input_list) 140 | 141 | trainlist = input_list[:int(len(input_list)*0.9)] 142 | validlist = input_list[int(len(input_list)*0.9):] 143 | 144 | SSL_OUT_DIM = 768 145 | TEXT_OUT_DIM = 768 146 | 147 | model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) 148 | ssl_model = model[0] 149 | trainset = MyDataset(wavdir, trainlist) 150 | trainloader = DataLoader(trainset, batch_size=2, shuffle=True, num_workers=2, collate_fn=trainset.collate_fn) 151 | validset = MyDataset(wavdir, validlist) 152 | validloader = DataLoader(validset, batch_size=2, shuffle=True, num_workers=2, collate_fn=validset.collate_fn) 153 | 154 | net = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM) 155 | net = net.to(device) 156 | 157 | if my_checkpoint_dir != None: 158 | net.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO','best'))) 159 | 160 | criterion = nn.MSELoss() 161 | optimizer = optim.SGD(net.parameters(), lr=0.00005, momentum=0.7) 162 | 163 | PREV_VAL_LOSS=9999999999 164 | orig_patience=2 165 | patience=orig_patience 166 | 167 | for epoch in range(1,10): 168 | STEPS=0 169 | net.train() 170 | running_loss = 0.0 171 | 172 | for i, data in enumerate(tqdm(trainloader), 0): 173 | 174 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, 
asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, _, wavname = data 175 | wav = wav.to(device) 176 | s_A = s_A.to(device) 177 | s_F = s_F.to(device) 178 | s_P = s_P.to(device) 179 | s_T = s_T.to(device) 180 | w_s_acc = w_s_acc.to(device) 181 | w_s_stress = w_s_stress.to(device) 182 | w_s_total = w_s_total.to(device) 183 | asr_word_embed = asr_word_embed.to(device) 184 | gt_word_embed = gt_word_embed.to(device) 185 | features_w = features_w.to(device) 186 | features_p = features_p.to(device) 187 | phonevector = phonevector.to(device) 188 | 189 | wav_input = wav.squeeze(1) 190 | optimizer.zero_grad() 191 | output_A, output_F, output_P, output_T, output_w_acc, output_w_stress, output_w_total = net(wav_input, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit) 192 | if output_w_acc.shape[1]!=w_s_acc.shape[1]: 193 | continue 194 | loss_A = criterion(output_A, s_A) 195 | loss_F = criterion(output_F, s_F) 196 | loss_P = criterion(output_P, s_P) 197 | loss_T = criterion(output_T, s_T) 198 | loss_wa = criterion(output_w_acc, w_s_acc) 199 | loss_ws = criterion(output_w_stress, w_s_stress) 200 | loss_wt = criterion(output_w_total, w_s_total) 201 | loss = loss_A + loss_F + loss_P + loss_T + loss_wa + loss_ws + loss_wt 202 | 203 | loss.backward() 204 | optimizer.step() 205 | STEPS += 1 206 | running_loss += loss.item() 207 | 208 | print('EPOCH: ' + str(epoch)) 209 | print('AVG EPOCH TRAIN LOSS: ' + str(running_loss / STEPS)) 210 | 211 | 212 | ## validation 213 | VALSTEPS=0 214 | epoch_val_loss = 0.0 215 | net.eval() 216 | ## clear memory to avoid OOM 217 | with torch.cuda.device(device): 218 | torch.cuda.empty_cache() 219 | torch.cuda.memory_allocated() 220 | torch.cuda.synchronize() 221 | 222 | for i, data in enumerate(validloader, 0): 223 | VALSTEPS+=1 224 | 225 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, _, _ = data 226 | wav = wav.to(device) 227 | s_A = s_A.to(device) 228 | s_F = s_F.to(device) 229 | s_P = s_P.to(device) 230 | s_T = s_T.to(device) 231 | w_s_acc = w_s_acc.to(device) 232 | w_s_stress = w_s_stress.to(device) 233 | w_s_total = w_s_total.to(device) 234 | asr_word_embed = asr_word_embed.to(device) 235 | gt_word_embed = gt_word_embed.to(device) 236 | features_w = features_w.to(device) 237 | features_p = features_p.to(device) 238 | phonevector = phonevector.to(device) 239 | 240 | wav_input = wav.squeeze(1) 241 | 242 | with torch.no_grad(): 243 | output_A, output_F, output_P, output_T, output_w_acc, output_w_stress, output_w_total = net(wav_input, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit) 244 | if output_w_acc.shape[1]!=w_s_acc.shape[1]: 245 | continue 246 | loss_A = criterion(output_A, s_A) 247 | loss_F = criterion(output_F, s_F) 248 | loss_P = criterion(output_P, s_P) 249 | loss_T = criterion(output_T, s_T) 250 | loss_wa = criterion(output_w_acc, w_s_acc) 251 | loss_ws = criterion(output_w_stress, w_s_stress) 252 | loss_wt = criterion(output_w_total, w_s_total) 253 | loss = loss_A + loss_F + loss_P + loss_T + loss_wa + loss_ws + loss_wt 254 | 255 | epoch_val_loss += loss.item() 256 | 257 | avg_val_loss=epoch_val_loss/VALSTEPS 258 | print('EPOCH VAL LOSS: ' + str(avg_val_loss)) 259 | if avg_val_loss < PREV_VAL_LOSS: 260 | print('Loss has decreased') 261 | PREV_VAL_LOSS=avg_val_loss 262 | torch.save(net.state_dict(), os.path.join(ckptdir,'PRO','best')) 
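            # The single checkpoint at <ckptdir>/PRO/best is overwritten whenever the validation loss improves,
            # and the early-stopping patience counter is reset below.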
263 | patience = orig_patience 264 | else: 265 | patience-=1 266 | if patience == 0: 267 | print('loss has not decreased for ' + str(orig_patience) + ' epochs; early stopping at epoch ' + str(epoch)) 268 | break 269 | 270 | print('Finished Training of Pronunciation Assessment Model') 271 | 272 | if __name__ == '__main__': 273 | main() 274 | -------------------------------------------------------------------------------- /processors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import re 6 | import numpy as np 7 | from itertools import groupby, chain, accumulate 8 | import soundfile as sf 9 | import librosa.core 10 | import unicodedata 11 | from builtins import str as unicode 12 | from nltk.tokenize import TweetTokenizer 13 | word_tokenize = TweetTokenizer().tokenize 14 | 15 | from g2p_en import G2p 16 | from g2p_en.expand import normalize_numbers 17 | from g2pM import G2pM 18 | from transformers import Wav2Vec2CTCTokenizer,Wav2Vec2FeatureExtractor, Wav2Vec2Processor 19 | 20 | 21 | 22 | class CharsiuPreprocessor: 23 | 24 | def __init__(self): 25 | pass 26 | 27 | 28 | def get_phones_and_words(self): 29 | raise NotImplementedError 30 | 31 | 32 | def get_phone_ids(self): 33 | raise NotImplementedError 34 | 35 | 36 | def mapping_phone2id(self,phone): 37 | ''' 38 | Convert a phone to a numerical id 39 | 40 | Parameters 41 | ---------- 42 | phone : str 43 | A phonetic symbol 44 | 45 | Returns 46 | ------- 47 | int 48 | A one-hot id for the input phone 49 | 50 | ''' 51 | return self.processor.tokenizer.convert_tokens_to_ids(phone) 52 | 53 | def mapping_id2phone(self,idx): 54 | ''' 55 | Convert a numerical id to a phone 56 | 57 | Parameters 58 | ---------- 59 | idx : int 60 | A one-hot id for a phone 61 | 62 | Returns 63 | ------- 64 | str 65 | A phonetic symbol 66 | 67 | ''' 68 | 69 | return self.processor.tokenizer.convert_ids_to_tokens(idx) 70 | 71 | 72 | def audio_preprocess(self,audio,sr=16000): 73 | ''' 74 | Load and normalize audio 75 | If the sampling rate is incompatible with models, the input audio will be resampled. 76 | 77 | Parameters 78 | ---------- 79 | path : str 80 | The path to the audio 81 | sr : int, optional 82 | Audio sampling rate, either 16000 or 32000. The default is 16000. 
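        Note: despite the parameter name shown above, the first argument may be either a file path (str)
        or a NumPy array of samples; both cases are handled in the function body.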
83 | 84 | Returns 85 | ------- 86 | torch.Tensor [(n,)] 87 | A list of audio sample as an one dimensional torch tensor 88 | 89 | ''' 90 | if type(audio)==str: 91 | if sr == 16000: 92 | features,fs = sf.read(audio) 93 | assert fs == 16000 94 | else: 95 | features, _ = librosa.core.load(audio,sr=sr) 96 | elif isinstance(audio, np.ndarray): 97 | features = audio 98 | else: 99 | raise Exception('The input must be a path or a numpy array!') 100 | return self.processor(features, sampling_rate=16000,return_tensors='pt').input_values.squeeze() 101 | 102 | ''' 103 | English g2p processor 104 | ''' 105 | class CharsiuPreprocessor_en(CharsiuPreprocessor): 106 | 107 | def __init__(self): 108 | 109 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_en_cmu') 110 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 111 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 112 | self.g2p = G2p() 113 | self.sil = '[SIL]' 114 | self.sil_idx = self.mapping_phone2id(self.sil) 115 | # self.punctuation = set('.,!?') 116 | self.punctuation = set() 117 | 118 | def get_phones_and_words(self,sen): 119 | ''' 120 | Convert texts to phone sequence 121 | 122 | Parameters 123 | ---------- 124 | sen : str 125 | A str of input sentence 126 | 127 | Returns 128 | ------- 129 | sen_clean : list 130 | A list of phone sequence without stress marks 131 | sen : list 132 | A list of phone sequence with stress marks 133 | 134 | 135 | xxxxx should sen_clean be deleted? 136 | 137 | ''' 138 | 139 | phones = self.g2p(sen) 140 | words = self._get_words(sen) 141 | 142 | phones = list(tuple(g) for k,g in groupby(phones, key=lambda x: x != ' ') if k) 143 | 144 | aligned_phones = [] 145 | aligned_words = [] 146 | for p,w in zip(phones,words): 147 | if re.search(r'\w+\d?',p[0]): 148 | aligned_phones.append(p) 149 | aligned_words.append(w) 150 | elif p in self.punctuation: 151 | aligned_phones.append((self.sil,)) 152 | aligned_words.append(self.sil) 153 | 154 | assert len(aligned_words) == len(aligned_phones) 155 | 156 | return aligned_phones, aligned_words 157 | 158 | assert len(words) == len(phones) 159 | 160 | return phones, words 161 | 162 | 163 | 164 | def get_phone_ids(self,phones,append_silence=True): 165 | ''' 166 | Convert phone sequence to ids 167 | 168 | Parameters 169 | ---------- 170 | phones : list 171 | A list of phone sequence 172 | append_silence : bool, optional 173 | Whether silence is appended at the beginning and the end of the sequence. 174 | The default is True. 175 | 176 | Returns 177 | ------- 178 | ids: list 179 | A list of one-hot representations of phones 180 | 181 | ''' 182 | phones = list(chain.from_iterable(phones)) 183 | ids = [self.mapping_phone2id(re.sub(r'\d','',p)) for p in phones] 184 | 185 | # append silence at the beginning and the end 186 | if append_silence: 187 | if ids[0]!=self.sil_idx: 188 | ids = [self.sil_idx]+ids 189 | if ids[-1]!=self.sil_idx: 190 | ids.append(self.sil_idx) 191 | return ids 192 | 193 | 194 | 195 | def _get_words(self,text): 196 | ''' 197 | from G2P_en 198 | https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py 199 | 200 | Parameters 201 | ---------- 202 | sen : TYPE 203 | DESCRIPTION. 204 | 205 | Returns 206 | ------- 207 | words : TYPE 208 | DESCRIPTION. 
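        Illustrative example (assumed behaviour of the normalization and TweetTokenizer):
            'I have 2 cats.' -> ['i', 'have', 'two', 'cats', '.']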
209 | 210 | ''' 211 | 212 | text = unicode(text) 213 | text = normalize_numbers(text) 214 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 215 | if unicodedata.category(char) != 'Mn') # Strip accents 216 | text = text.lower() 217 | text = re.sub("[^ a-z'.,?!\-]", "", text) 218 | text = text.replace("i.e.", "that is") 219 | text = text.replace("e.g.", "for example") 220 | 221 | # tokenization 222 | words = word_tokenize(text) 223 | 224 | return words 225 | 226 | def align_words(self, preds, phones, words, return_map=False): 227 | 228 | # add idx to the word so that it won't merge when the input sentence has duplicated words 229 | words_rep = [w+'_'+str(idx) for idx, (ph,w) in enumerate(zip(phones,words)) for p in ph] 230 | phones_rep = [re.sub(r'\d','',p) for ph,w in zip(phones,words) for p in ph] 231 | assert len(words_rep)==len(phones_rep) 232 | # match each phone to its word 233 | word_dur = [] 234 | count = 0 235 | for dur in preds: 236 | if dur[-1] == '[SIL]': 237 | word_dur.append((dur,'[SIL]')) 238 | else: 239 | while dur[-1] != phones_rep[count]: 240 | count += 1 241 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word) 242 | 243 | # merge phone-to-word alignment to derive word duration 244 | words = [] 245 | word_phone_map = [] #word-to-phone index: word_phone_map[i] = num of phone of word i 246 | for key, group in groupby(word_dur, lambda x: x[-1]): 247 | group = list(group) 248 | entry = (group[0][0][0],group[-1][0][1],key.split('_')[0]) 249 | words.append(entry) 250 | word_phone_map.append(len(group)) 251 | if not return_map: 252 | return words 253 | else: 254 | 255 | idx=0 256 | word_phone_map_converted = [] 257 | for length in(word_phone_map): 258 | the_word = [] 259 | for i in range(idx,idx+length): 260 | the_word.append(i) 261 | word_phone_map_converted.append(the_word) 262 | idx+=length 263 | 264 | return words, word_phone_map_converted 265 | 266 | 267 | ''' 268 | Mandarin g2p processor 269 | ''' 270 | 271 | 272 | class CharsiuPreprocessor_zh(CharsiuPreprocessor_en): 273 | 274 | def __init__(self): 275 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_zh_pinyin') 276 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 277 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 278 | self.g2p = G2pM() 279 | self.sil = "[SIL]" 280 | self.sil_idx = self.mapping_phone2id(self.sil) 281 | #self.punctuation = set('.,!?。,!?、') 282 | self.punctuation = set() 283 | # Pinyin tables 284 | self.consonant_list = set(['b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 285 | 'h', 'j', 'q', 'x', 'zh', 'ch', 'sh', 'r', 'z', 286 | 'c', 's']) 287 | 288 | self.transform_dict = {'ju':'jv', 'qu':'qv', 'xu':'xv','jue':'jve', 289 | 'que':'qve', 'xue':'xve','quan':'qvan', 290 | 'xuan':'xvan','juan':'jvan', 291 | 'qun':'qvn','xun':'xvn', 'jun':'jvn', 292 | 'yuan':'van', 'yue':'ve', 'yun':'vn', 293 | 'you':'iou', 'yan':'ian', 'yin':'in', 294 | 'wa':'ua', 'wo':'uo', 'wai':'uai', 295 | 'weng':'ueng', 'wang':'uang','wu':'u', 296 | 'yu':'v','yi':'i','yo':'io','ya':'ia', 'ye':'ie', 297 | 'yao':'iao','yang':'iang', 'ying':'ing', 'yong':'iong', 298 | 'yvan':'van', 'yve':'ve', 'yvn':'vn', 299 | 'wa':'ua', 'wo':'uo', 'wai':'uai', 300 | 'wei':'ui', 'wan':'uan', 'wen':'un', 301 | 'weng':'ueng', 'wang':'uang','yv':'v', 302 | 'wuen':'un','wuo':'uo','wuang':'uang', 303 | 'wuan':'uan','wua':'ua','wuai':'uai', 304 | 
'zhi':'zhiii','chi':'chiii','shi':'shiii', 305 | 'zi':'zii','ci':'cii','si':'sii'} 306 | self.er_mapping ={'er1':('e1','rr'),'er2':('e2','rr'),'er3':('e3','rr'),'er4':('e4','rr'), 307 | 'er5':('e5','rr'),'r5':('e5','rr')} 308 | self.rhyme_mapping = {'iu1':'iou1','iu2':'iou2','iu3':'iou3','iu4':'iou4','iu5':'iou5', 309 | 'u:e1':'ve1','u:e2':'ve2','u:e3':'ve3','u:e4':'ve4','u:e5':'ve5', 310 | 'u:1':'v1','u:2':'v2','u:3':'v3','u:4':'v4','u:5':'v5', 311 | 'ueng1':('u1','eng1'),'ueng2':('u2','eng2'),'ueng3':('u3','eng3'), 312 | 'ueng4':('u4','eng4'),'ueng5':('u5','eng5'),'io5':('i5','o5'), 313 | 'io4':('i4','o4'),'io1':('i1','o1')} 314 | 315 | def get_phones_and_words(self,sen): 316 | ''' 317 | Convert texts to phone sequence 318 | 319 | Parameters 320 | ---------- 321 | sen : str 322 | A str of input sentence 323 | 324 | Returns 325 | ------- 326 | sen_clean : list 327 | A list of phone sequence without stress marks 328 | sen : list 329 | A list of phone sequence with stress marks 330 | 331 | xxxxx should sen_clean be removed? 332 | ''' 333 | 334 | phones = self.g2p(sen) 335 | 336 | aligned_phones = [] 337 | aligned_words = [] 338 | for p,w in zip(phones,sen): 339 | if re.search(r'\w+:?\d',p): 340 | aligned_phones.append(self._separate_syllable(self.transform_dict.get(p[:-1],p[:-1])+p[-1])) 341 | aligned_words.append(w) 342 | elif p in self.punctuation: 343 | aligned_phones.append((self.sil,)) 344 | aligned_words.append(self.sil) 345 | 346 | assert len(aligned_phones)==len(aligned_words) 347 | return aligned_phones, aligned_words 348 | 349 | 350 | def get_phone_ids(self,phones,append_silence=True): 351 | ''' 352 | Convert phone sequence to ids 353 | 354 | Parameters 355 | ---------- 356 | phones : list 357 | A list of phone sequence 358 | append_silence : bool, optional 359 | Whether silence is appended at the beginning and the end of the sequence. 360 | The default is True. 361 | 362 | Returns 363 | ------- 364 | ids: list 365 | A list of one-hot representations of phones 366 | 367 | ''' 368 | phones = list(chain.from_iterable(phones)) 369 | ids = [self.mapping_phone2id(p) for p in phones] 370 | 371 | # append silence at the beginning and the end 372 | if append_silence: 373 | if ids[0]!=self.sil_idx: 374 | ids = [self.sil_idx]+ids 375 | if ids[-1]!=self.sil_idx: 376 | ids.append(self.sil_idx) 377 | return ids 378 | 379 | 380 | def _separate_syllable(self,syllable): 381 | """ 382 | seprate syllable to consonant + ' ' + vowel 383 | 384 | Parameters 385 | ---------- 386 | syllable : xxxxx TYPE 387 | xxxxx DESCRIPTION. 388 | 389 | Returns 390 | ------- 391 | syllable: xxxxx TYPE 392 | xxxxxx DESCRIPTION. 
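        Illustrative examples (derived from the rules below):
            'ri4'    -> ('r', 'iii4')
            'zhong1' -> ('zh', 'ong1')
            'an4'    -> ('an4',)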
393 | 394 | """ 395 | 396 | assert syllable[-1].isdigit() 397 | if syllable == 'ri4': 398 | return ('r','iii4') 399 | if syllable[:-1] == 'ueng' or syllable[:-1] == 'io': 400 | return self.rhyme_mapping.get(syllable,syllable) 401 | if syllable in self.er_mapping.keys(): 402 | return self.er_mapping[syllable] 403 | if syllable[0:2] in self.consonant_list: 404 | #return syllable[0:2].encode('utf-8'),syllable[2:].encode('utf-8') 405 | return syllable[0:2], self.rhyme_mapping.get(syllable[2:],syllable[2:]) 406 | elif syllable[0] in self.consonant_list: 407 | #return syllable[0].encode('utf-8'),syllable[1:].encode('utf-8') 408 | return syllable[0], self.rhyme_mapping.get(syllable[1:],syllable[1:]) 409 | else: 410 | #return (syllable.encode('utf-8'),) 411 | return (syllable,) 412 | 413 | 414 | def align_words(self, preds, phones, words): 415 | 416 | words_rep = [w+str(i) for i,(ph,w) in enumerate(zip(phones,words)) for p in ph] 417 | phones_rep = [p for ph,w in zip(phones,words) for p in ph] 418 | assert len(words_rep)==len(phones_rep) 419 | 420 | # match each phone to its word 421 | word_dur = [] 422 | count = 0 423 | for dur in preds: 424 | if dur[-1] == '[SIL]': 425 | word_dur.append((dur,'[SIL]')) 426 | else: 427 | while dur[-1] != phones_rep[count]: 428 | count += 1 429 | if count >= len(phones_rep): 430 | break 431 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word) 432 | 433 | # merge phone-to-word alignment to derive word duration 434 | words = [] 435 | for key, group in groupby(word_dur, lambda x: x[-1]): 436 | group = list(group) 437 | entry = (group[0][0][0],group[-1][0][1],re.sub(r'\d','',key)) 438 | words.append(entry) 439 | 440 | return words 441 | 442 | 443 | 444 | if __name__ == '__main__': 445 | ''' 446 | Testing functions 447 | ''' 448 | 449 | processor = CharsiuPreprocessor_zh() 450 | phones, words = processor.get_phones_and_words("鱼香肉丝、王道椒香鸡腿和川蜀鸡翅。") 451 | print(phones) 452 | print(words) 453 | ids = processor.get_phone_ids(phones) 454 | print(ids) 455 | 456 | processor = CharsiuPreprocessor_en() 457 | phones, words = processor.get_phones_and_words("I’m playing octopath right now!") 458 | print(phones) 459 | print(words) 460 | ids = processor.get_phone_ids(phones) 461 | print(ids) 462 | 463 | 464 | 465 | -------------------------------------------------------------------------------- /utils_assessment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import string 4 | import num2words 5 | import numpy as np 6 | from dataclasses import dataclass 7 | from Levenshtein import ratio 8 | from nltk.corpus import cmudict 9 | from difflib import SequenceMatcher 10 | 11 | 12 | def fit_one_hot(inputlist): 13 | mapping = {} 14 | for i in range(len(inputlist)): 15 | mapping[inputlist[i]]=i 16 | return mapping 17 | 18 | """ 19 | load the cumdict and convert it to dict for shorter processing time 20 | """ 21 | nltk.download('cmudict') 22 | all_phonemes = set() 23 | entries = cmudict.entries() 24 | cmudict_dict = {entry[0].lower(): entry[1] for entry in entries} 25 | for entry in entries: 26 | phonemes = entry[1] 27 | all_phonemes.update(phonemes) 28 | all_phonemes = list(all_phonemes) 29 | all_phonemes.append('') 30 | all_phonemes = fit_one_hot(all_phonemes) 31 | phone_vector_dimension = len(all_phonemes.keys()) 32 | 33 | 34 | def get_filepaths(directory, format='.wav'): 35 | """ 36 | load all file in the directory 37 | Parameters 38 | ----input----- 39 | directory: str. 
path of directory 40 | ----Return---- 41 | file_paths: list. paths of files in the directory 42 | """ 43 | file_paths = [] 44 | for root, _, files in os.walk(directory): 45 | for filename in files: 46 | filepath = os.path.join(root, filename) 47 | if filename.endswith(format): 48 | file_paths.append(filepath) 49 | return file_paths 50 | 51 | def remove_pun_except_apostrophe(input_string): 52 | """ 53 | remove punctuations (except for ' ) of the inupt string. 54 | """ 55 | translator = str.maketrans('', '', string.punctuation.replace("'", "")) 56 | output_string = input_string.translate(translator).replace(' ',' ') 57 | return output_string 58 | 59 | def remove_pun(input_string): 60 | """ 61 | remove punctuations of the input_string. 62 | """ 63 | input_string = "".join([char for char in input_string if char not in string.punctuation]) 64 | return input_string 65 | 66 | def get_transcript(audio, whisper_model, return_seg=False): 67 | """ 68 | get ASR result using whisper model 69 | 70 | Parameters 71 | ----input----- 72 | audio: Union[str, np.ndarray, torch.Tensor] 73 | whisper_model: 74 | load whisper_model using: 75 | 、、、 76 | import whisper 77 | whisper_model = whisper.load_model("base.en") 78 | 、、、 79 | return_seg: bool 80 | whether to return the segmentation result 81 | ----Return---- 82 | transcript: str. ASR result of the input wavfile 83 | """ 84 | result = whisper_model.transcribe(audio, fp16=False) 85 | if not return_seg: 86 | return result['text'] 87 | else: 88 | return result 89 | 90 | def convert_num_to_word(sen): 91 | """ 92 | convert digit in a sentence to word. e.g. "7112" to "seven one one two". 93 | """ 94 | try: #for 4 digit samples of speechocean data. 95 | int(sen.replace(' ','')) 96 | sen = ' '.join([char for char in sen]) 97 | sen = ' '.join([num2words.num2words(i) if i.isdigit() else i for i in sen.split()]) 98 | sen = sen.replace(' ',' ') 99 | except: 100 | sen = ' '.join([num2words.num2words(i) if i.isdigit() else i for i in sen.split()]) 101 | return sen 102 | 103 | 104 | def get_phone_list(word_list): 105 | """ 106 | convert word to phone using cmudict 107 | Parameters 108 | ----input----- 109 | word_list: list of word. e.g. [[word1],[word2]...] or [[word1, word2],...] 110 | ----Return---- 111 | phone_list: list of corresponding phone e.g [[p1-1,p1-2], [p2-1,p2-2,p2-3],...] or [[p1-1,p1-2,p2-1,p2-2,p2-3], ...] 112 | """ 113 | phone_list = [] 114 | for word_position in word_list: 115 | the_phone_list = [] 116 | for word in word_position: 117 | phone = cmudict_dict.get(word.lower(), '') 118 | the_phone_list.extend(phone) 119 | phone_list.append(the_phone_list) 120 | 121 | return phone_list 122 | 123 | def get_phone_vector(phone_list): 124 | """ 125 | convert phone to phone-vector using one-hot encoding 126 | Parameters 127 | ----input----- 128 | phone_list: list of phone. e.g. [[phone1-1, phone1-2],[phone2-1, phone2-2, phone2-3]...] 
129 | ----Return---- 130 | phone_vector: np.ndarray, [shape=(number of word, phone_vector_dimension=71)] 131 | """ 132 | num_of_word = len(phone_list) 133 | phone_vector = np.zeros((num_of_word, phone_vector_dimension)) 134 | 135 | for word_idx, the_phone_list in enumerate(phone_list): 136 | the_phone_vector = np.zeros((phone_vector_dimension, )) 137 | for phone in the_phone_list: 138 | the_phone_vector[all_phonemes[phone]]+=1 139 | phone_vector[word_idx,:] = the_phone_vector 140 | 141 | return phone_vector 142 | 143 | def get_phone_features_from_wordlist(gt_word_list, asr_word_list): 144 | 145 | """ 146 | get phone features of ground-turth word list and ASR word list 147 | Parameters 148 | ----input----- 149 | gt_word_list: list of ground-turth word. e.g. [[word1-gt],[word2-gt],[word3-gt]...] 150 | asr_word_list: list of asr word. e.g. [[word1-asr, word2-asr],[word3],[word4]...] 151 | # note: gt_word_list[i] is aligned with asr_word_list[i] based on the audio-word forced alignment result 152 | ----Return---- 153 | phone_distance: np.ndarray, [shape=(number of ground-truth word, 1)] 154 | the phone distance (by SequenceMatcher) between ground-truth word and the asr-word 155 | phone_vector: np.ndarray, [shape=(number of ground-truth word, (phone_vector_dimension=71)*2)] 156 | the phone vector of ground-truth word and the asr-word 157 | phone_count: np.ndarray, [shape=(number of ground-truth word, 1)] 158 | the number of phones of ground-truth word divided by the number of phones of asr word 159 | """ 160 | 161 | gt_length = len(gt_word_list) 162 | gt_phone_list = get_phone_list(gt_word_list) 163 | asr_phone_list = get_phone_list(asr_word_list) 164 | 165 | gt_phone_vector = get_phone_vector(gt_phone_list) 166 | asr_phone_vector = get_phone_vector(asr_phone_list) 167 | 168 | phone_distance = np.zeros((gt_length, 1)) 169 | phone_count = np.zeros((gt_length, 1)) 170 | for word_idx in range(gt_length): 171 | the_distance = SequenceMatcher(None, gt_phone_list[word_idx],asr_phone_list[word_idx]) 172 | phone_distance[word_idx,0]=the_distance.ratio() 173 | if len(asr_phone_list[word_idx])!=0: 174 | phone_count[word_idx,0]=len(gt_phone_list[word_idx])/len(asr_phone_list[word_idx]) 175 | 176 | phone_vector = np.concatenate((gt_phone_vector, asr_phone_vector), axis=1) 177 | 178 | return phone_distance, phone_vector, phone_count 179 | 180 | 181 | def get_word_alignment_features(alignment_gt, alignment_asr): 182 | """ 183 | get word-aligned features of ground-turth word list and ASR word list 184 | Parameters 185 | ----input----- 186 | alignment_gt: list, len = number of words in the ground-truth sentence 187 | word-audio alignment result of ground-truth sentence. e.g., [[(start_time1, end_time1, word1)], [(start_time2, end_time2, word2)], ...] 188 | alignment_asr: list, len = number of words in the ASR sentence 189 | word-audio alignment result of ASR sentence. e.g., [[(start_time1, end_time1, word1)], [(start_time2, end_time2, word2)], ...] 190 | 191 | ----Return---- 192 | gt_word_list: list, len = number of words in the ground-truth sentence 193 | words in the ground-truth sentence, e.g. [[word1_gt],[word2_gt],[word3_gt],...] 194 | asr_word_list: list, len = number of words in the ground-truth sentence 195 | words in the ASR sentence aligned with ground-truth word. e.g. [[word1_asr, word2_asr],[word3_asr],...] 
196 | alignment_features: np.ndarray, [shape=(number of words in the ground-truth sentence, 10)] 197 | phonevector: np.ndarray, [shape=(number of words in the ground-truth sentence, 71*2)] 198 | phonevector of the ground-truth words and asr words 199 | asr_wordidx_list: list, len = number of words in the ground-truth sentence 200 | mapping between ground-truth words and asr words. 201 | asr_wordidx_list[i] = [j,m] means alignment_gt[i] is overlapped with alignment_asr[j] and alignment_asr[m] 202 | """ 203 | gt_length = len(alignment_gt) 204 | 205 | gt_word_list = [] 206 | asr_word_list = [] 207 | asr_wordidx_list = [] 208 | asr_distance = np.zeros((gt_length, 1)) 209 | duration_gt = np.zeros((gt_length, 1)) 210 | duration_asr = np.zeros((gt_length, 1)) 211 | time_diff_start = np.zeros((gt_length, 1)) 212 | time_diff_end = np.zeros((gt_length, 1)) 213 | interval_gt = np.zeros((gt_length, 1)) 214 | interval_asr = np.zeros((gt_length, 1)) 215 | 216 | pre_end_gt = 0 217 | for gt_idx, gt_value in enumerate(alignment_gt): 218 | 219 | gt_word = gt_value[2] 220 | gt_start = float(gt_value[0]) 221 | gt_end = float(gt_value[1]) 222 | 223 | duration_gt[gt_idx,0] = (gt_end-gt_start) 224 | interval_gt[gt_idx,0] = (gt_start - pre_end_gt) 225 | pre_end_gt = gt_end 226 | 227 | gt_word_list.append([gt_word]) 228 | 229 | asr_word_all = [] 230 | asr_wordidx = [] 231 | asr_start_flag = True 232 | the_asr_start = 0 233 | the_asr_end = float(alignment_asr[-1][1]) 234 | 235 | pre_end_asr = 0 236 | asr_interval_list = 0 237 | for asr_idx, asr_value in enumerate(alignment_asr): 238 | asr_start = float(asr_value[0]) 239 | asr_end = float(asr_value[1]) 240 | if gt_end <= asr_start: 241 | break 242 | if gt_start >= asr_end: 243 | continue 244 | if max(gt_start, asr_start) <= min(gt_end, asr_end): 245 | asr_word = asr_value[2] 246 | asr_word_all.append(asr_word) 247 | asr_wordidx.append(asr_idx) 248 | asr_interval_list += (asr_start - pre_end_asr) 249 | pre_end_asr = asr_end 250 | the_asr_end = asr_end 251 | if asr_start_flag: 252 | the_asr_start = asr_start 253 | asr_start_flag= False 254 | 255 | duration_asr[gt_idx,0] = (the_asr_end - the_asr_start) 256 | time_diff_start[gt_idx,0] = (the_asr_start - gt_start) 257 | time_diff_end[gt_idx,0] = (the_asr_end - gt_end) 258 | if len(asr_wordidx)!=0: 259 | interval_asr[gt_idx,0] = asr_interval_list/len(asr_wordidx) 260 | 261 | asr_wordidx_list.append(asr_wordidx) 262 | asr_word_list.append(asr_word_all) 263 | asr_distance[gt_idx,0] = ratio(gt_word, ' '.join(asr_word_all))*10 264 | 265 | align_word_count = [len(asr_word_list[word_idx]) for word_idx in range(gt_length)] 266 | phone_distance, phonevector, phone_count = get_phone_features_from_wordlist(gt_word_list, asr_word_list) 267 | 268 | align_word_count = np.asarray(align_word_count) 269 | align_word_count = np.expand_dims(align_word_count, axis=1) 270 | 271 | alignment_features = np.concatenate((asr_distance, align_word_count, duration_gt, duration_asr, time_diff_start, time_diff_end, phone_distance, phone_count, interval_gt, interval_asr), axis=1) 272 | 273 | return gt_word_list, asr_word_list, alignment_features, phonevector, asr_wordidx_list 274 | 275 | 276 | def get_phone_alignment_features(pred_phones_gt, pred_phones_asr, pred_prob_gt, phone_ids_gt, pred_prob_asr, phone_ids_asr): 277 | """ 278 | get phone-aligned features of ground-turth phone list and ASR phone list 279 | Parameters 280 | ----input----- 281 | pred_phones_gt: list, len = number of phones in the ground-truth sentence 282 | phoneme-audio alignment 
result of ground-truth sentence. e.g., [[(start_time1, end_time1, phone1)], [(start_time2, end_time2, phone2)], ...] 283 | pred_phones_asr: list, len = number of phones in the ASR sentence 284 | phoneme-audio alignment result of ASR sentence. e.g., [[(start_time1, end_time1, phone1)], [(start_time2, end_time2, phone2)], ...] 285 | pred_prob_gt: np.ndarray, [shape=(number of phones in the ground-truth sentence, 42)] 286 | output of the charsiu model 287 | phone_ids_gt: list, len = number of phones in the ground-truth sentence 288 | index of the aligned phone. 289 | pred_prob_asr: np.ndarray, [shape=(number of phones in the ASR sentence, 42)] 290 | output of the charsiu model 291 | phone_ids_asr: list, len = number of phones in the ASR sentence 292 | index of the aligned phone. 293 | 294 | ----Return---- 295 | features: np.ndarray, [shape=(number of phones in the ground-truth sentence, 93)] 296 | extracted features. 297 | features[:,:9] = alignment features 298 | features[:,9:9+42] = pred_prob_gt 299 | features[:,9+42:] = aligned_pred_prob_asr (align ASR phone to ground-truth phone using the time information) 300 | """ 301 | gt_length = len(pred_phones_gt) 302 | 303 | asr_phone_list = [] 304 | asr_phoneidx_list = [] 305 | duration_gt = np.zeros((gt_length, 1)) 306 | duration_asr = np.zeros((gt_length, 1)) 307 | time_diff_start = np.zeros((gt_length, 1)) 308 | time_diff_end = np.zeros((gt_length, 1)) 309 | interval_gt = np.zeros((gt_length, 1)) 310 | interval_asr = np.zeros((gt_length, 1)) 311 | 312 | pre_end_gt = 0 313 | for gt_idx, gt_value in enumerate(pred_phones_gt): 314 | 315 | gt_start = gt_value[0] 316 | gt_end = gt_value[1] 317 | 318 | duration_gt[gt_idx,0] = (gt_end - gt_start) 319 | interval_gt[gt_idx,0] = (gt_start - pre_end_gt) 320 | pre_end_gt = gt_end 321 | 322 | asr_phone_all = [] 323 | asr_phoneidx = [] 324 | asr_start_flag = True 325 | the_asr_start = 0 326 | 327 | pre_end_asr = 0 328 | asr_interval_list = 0 329 | for asr_idx, asr_value in enumerate(pred_phones_asr): 330 | asr_start = asr_value[0] 331 | asr_end = asr_value[1] 332 | if gt_end <= asr_start: 333 | break 334 | if gt_start >= asr_end: 335 | continue 336 | if max(gt_start, asr_start) <= min(gt_end, asr_end): 337 | asr_phone = asr_value[2] 338 | asr_phone_all.append(asr_phone) 339 | asr_phoneidx.append(asr_idx) 340 | asr_interval_list += (asr_start - pre_end_asr) 341 | pre_end_asr = asr_end 342 | the_asr_end = asr_end 343 | if asr_start_flag: 344 | the_asr_start = asr_start 345 | asr_start_flag= False 346 | 347 | duration_asr[gt_idx,0] = (the_asr_end - the_asr_start) 348 | time_diff_start[gt_idx,0] = (the_asr_start - gt_start) 349 | time_diff_end[gt_idx,0] = (the_asr_end - gt_end) 350 | if len(asr_phoneidx)!=0: 351 | interval_asr[gt_idx,0] = asr_interval_list/len(asr_phoneidx) 352 | 353 | asr_phoneidx_list.append(asr_phoneidx) 354 | asr_phone_list.append(asr_phone_all) 355 | 356 | align_phone_count = [len(asr_phone_list[phone_idx]) for phone_idx in range(gt_length)] 357 | 358 | align_phone_count = np.asarray(align_phone_count) 359 | align_phone_count = np.expand_dims(align_phone_count, axis=1) 360 | 361 | the_gt_phone_prob = np.zeros((pred_prob_gt.shape[0], 1)) 362 | the_asr_phone_prob = np.zeros((pred_prob_asr.shape[0], 1)) 363 | 364 | the_gt_phone_prob[:, 0] = pred_prob_gt[np.arange(pred_prob_gt.shape[0]), phone_ids_gt] 365 | the_asr_phone_prob[:, 0] = pred_prob_asr[np.arange(pred_prob_asr.shape[0]), phone_ids_asr] 366 | 367 | aligned_asr_prob = np.zeros((pred_prob_gt.shape[0], 1)) 368 | aligned_pred_prob_asr = 
np.zeros((pred_prob_gt.shape)) 369 | for gt_p_idx, gt_p in enumerate(asr_phoneidx_list): 370 | aligned_asr_prob[gt_p_idx,:] = np.mean(the_asr_phone_prob[gt_p,:], axis=0) 371 | aligned_pred_prob_asr[gt_p_idx,:] = np.mean(pred_prob_asr[gt_p,:], axis=0) 372 | 373 | features = np.concatenate((align_phone_count, duration_gt, duration_asr, time_diff_start, time_diff_end, interval_gt, interval_asr, the_gt_phone_prob, aligned_asr_prob, pred_prob_gt, aligned_pred_prob_asr), axis=1) 374 | 375 | return features 376 | 377 | 378 | def get_roberta_word_embed(word_list, num_of_token, roberta): 379 | """ 380 | get roberta word embedding of input word list 381 | Parameters 382 | ----input----- 383 | word_list: list, len = number of words 384 | num_of_token: int 385 | number of the words in the sentence. 386 | roberta: object. 387 | load roberta model using: 388 | 、、、 389 | from fairseq.models.roberta import RobertaModel 390 | roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt') 391 | roberta.eval() 392 | 、、、 393 | ----Return---- 394 | sen_vec: np.ndarray, [shape=(num_of_token, 768)] 395 | word-by-word roberta embedding 396 | """ 397 | sen_vec = np.zeros((num_of_token, 768)) 398 | 399 | for w_idx, the_word_list in enumerate(word_list): 400 | the_sen = ' '.join(the_word_list) 401 | if the_sen=='': 402 | continue 403 | doc = roberta.extract_features_aligned_to_words(the_sen) 404 | the_sen_vec = np.zeros((768,)) 405 | for tok in doc: 406 | if str(tok)=='' or str(tok)=='': 407 | continue 408 | the_vec = tok.vector.detach().numpy() 409 | the_sen_vec[:] += the_vec 410 | the_sen_vec /= len(the_word_list) 411 | sen_vec[w_idx,:] = the_sen_vec 412 | 413 | return sen_vec 414 | 415 | def get_charsiu_alignment(audio, sen, aligner): 416 | pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map = aligner.align(audio=audio, text=sen) 417 | return pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map 418 | 419 | def get_match_index(pred_words, words): 420 | """ 421 | remove [SIL] in the charsiu word alignment result. 422 | """ 423 | selected_idx = [] 424 | curren_idx=0 425 | for the_word in words: 426 | for i in range(curren_idx, len(pred_words)): 427 | if pred_words[i][2]==the_word: 428 | selected_idx.append(i) 429 | curren_idx = i+1 430 | break 431 | return selected_idx 432 | 433 | 434 | def feature_extraction(audio, gt_sen, asr_sen, alignment_model, word_model): 435 | """ 436 | extract features for the assessment model 437 | Parameters 438 | ----input----- 439 | audio: str. or np.ndarray [shape=(n,)] 440 | path to the input wavfile or np.ndarray of wavfile 441 | gt_sen: str. 442 | ground-truth sentence 443 | asr_sen: str. 
444 | ASR sentence 445 | alignment_model: object 446 | load charsiu model using: 447 | ``` 448 | from Charsiu import charsiu_forced_aligner 449 | alignment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms') 450 | ``` 451 | word_model: object 452 | load roberta model using: 453 | ``` 454 | from fairseq.models.roberta import RobertaModel 455 | word_model = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt') 456 | word_model.eval() 457 | ``` 458 | ----Return---- pred_words_gt: word-audio alignment of the ground-truth sentence ([SIL] removed); features_p: np.ndarray, [shape=(number of phones in the ground-truth sentence, 93)] phone-level features; features_w, phonevector: word-level alignment features; gt_word_embed, asr_word_embed: np.ndarray, [shape=(number of words in the ground-truth sentence, 768)] roberta word embeddings; word_phone_map_gt: word-to-phone mapping of the ground-truth words (see the usage sketch at the end of this document) 459 | """ 460 | 461 | pred_phones_gt, pred_words_gt, words_gt, pred_prob_gt, phone_ids_gt, word_phone_map_gt = get_charsiu_alignment(audio, gt_sen, alignment_model) 462 | pred_phones_asr, pred_words_asr, words_asr, pred_prob_asr, phone_ids_asr, _ = get_charsiu_alignment(audio, asr_sen, alignment_model) 463 | 464 | features_p = get_phone_alignment_features(pred_phones_gt, pred_phones_asr, pred_prob_gt, phone_ids_gt, pred_prob_asr, phone_ids_asr) 465 | 466 | selected_idx_gt = get_match_index(pred_words_gt, words_gt) 467 | 468 | word_phone_map_gt = [word_phone_map_gt[i] for i in selected_idx_gt] 469 | 470 | pred_words_gt = np.asarray(pred_words_gt) 471 | pred_words_gt = pred_words_gt[selected_idx_gt] 472 | 473 | selected_idx_asr = get_match_index(pred_words_asr, words_asr) 474 | pred_words_asr = np.asarray(pred_words_asr) 475 | pred_words_asr = pred_words_asr[selected_idx_asr] 476 | gt_word_list, asr_word_list, features_w, phonevector, _ = get_word_alignment_features(pred_words_gt, pred_words_asr) 477 | 478 | num_of_token = len(selected_idx_gt) 479 | gt_word_embed = get_roberta_word_embed(gt_word_list, num_of_token, word_model) 480 | asr_word_embed = get_roberta_word_embed(asr_word_list, num_of_token, word_model) 481 | 482 | return pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map_gt 483 | -------------------------------------------------------------------------------- /Charsiu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import sys 6 | import torch 7 | from itertools import groupby 8 | sys.path.append('src/') 9 | import numpy as np 10 | #sys.path.insert(0,'src') 11 | from models import Wav2Vec2ForAttentionAlignment, Wav2Vec2ForFrameClassification, Wav2Vec2ForCTC 12 | from utils import seq2duration,forced_align,duration2textgrid,word2textgrid 13 | from processors import CharsiuPreprocessor_zh, CharsiuPreprocessor_en 14 | 15 | processors = {'zh':CharsiuPreprocessor_zh, 16 | 'en':CharsiuPreprocessor_en} 17 | 18 | class charsiu_aligner: 19 | 20 | def __init__(self, 21 | lang='en', 22 | sampling_rate=16000, 23 | device=None, 24 | recognizer=None, 25 | processor=None, 26 | resolution=0.01): 27 | 28 | self.lang = lang 29 | 30 | if processor is not None: 31 | self.processor = processor 32 | else: 33 | self.charsiu_processor = processors[self.lang]() 34 | 35 | 36 | 37 | self.resolution = resolution 38 | 39 | self.sr = sampling_rate 40 | 41 | self.recognizer = recognizer 42 | 43 | if device is None: 44 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | else: 46 | self.device = device 47 | 48 | 49 | def _freeze_model(self): 50 | self.aligner.eval().to(self.device) 51 | if self.recognizer is not None: 52 | self.recognizer.eval().to(self.device) 53 | 54 | 55 | 56 | def align(self,audio,text): 57 | raise NotImplementedError() 58 | 59 | 60 | 61 | def serve(self,audio,save_to,output_format='variable',text=None): 62 | raise
NotImplementedError() 63 | 64 | 65 | def _to_textgrid(self,phones,save_to): 66 | ''' 67 | Convert output tuples to a textgrid file 68 | 69 | Parameters 70 | ---------- 71 | phones : TYPE 72 | DESCRIPTION. 73 | 74 | Returns 75 | ------- 76 | None. 77 | 78 | ''' 79 | duration2textgrid(phones,save_path=save_to) 80 | print('Alignment output has been saved to %s'%(save_to)) 81 | 82 | 83 | 84 | def _to_tsv(self,phones,save_to): 85 | ''' 86 | Convert output tuples to a tab-separated file 87 | 88 | Parameters 89 | ---------- 90 | phones : TYPE 91 | DESCRIPTION. 92 | 93 | Returns 94 | ------- 95 | None. 96 | 97 | ''' 98 | with open(save_to,'w') as f: 99 | for start,end,phone in phones: 100 | f.write('%s\t%s\t%s\n'%(start,end,phone)) 101 | print('Alignment output has been saved to %s'%(save_to)) 102 | 103 | 104 | 105 | 106 | 107 | class charsiu_forced_aligner(charsiu_aligner): 108 | 109 | def __init__(self, aligner, sil_threshold=4, **kwargs): 110 | super(charsiu_forced_aligner, self).__init__(**kwargs) 111 | self.aligner = Wav2Vec2ForFrameClassification.from_pretrained(aligner) 112 | self.sil_threshold = sil_threshold 113 | 114 | self._freeze_model() 115 | 116 | 117 | def align(self, audio, text): 118 | ''' 119 | Perform forced alignment 120 | 121 | Parameters 122 | ---------- 123 | audio : np.ndarray [shape=(n,)] 124 | time series of speech signal 125 | text : str 126 | The transcription 127 | 128 | Returns 129 | ------- 130 | A tuple of aligned phones in the form (start_time, end_time, phone) 131 | 132 | ''' 133 | audio = self.charsiu_processor.audio_preprocess(audio,sr=self.sr) 134 | audio = torch.Tensor(audio).unsqueeze(0).to(self.device) 135 | phones, words = self.charsiu_processor.get_phones_and_words(text) 136 | phone_ids = self.charsiu_processor.get_phone_ids(phones) 137 | 138 | with torch.no_grad(): 139 | out = self.aligner(audio) 140 | 141 | cost = torch.softmax(out.logits,dim=-1).detach().cpu().numpy().squeeze() 142 | sil_mask = self._get_sil_mask(cost) 143 | 144 | nonsil_idx = np.argwhere(sil_mask!=self.charsiu_processor.sil_idx).squeeze() 145 | if nonsil_idx is None: 146 | raise Exception("No speech detected! Please check the audio file!") 147 | 148 | aligned_phone_ids = forced_align(cost[nonsil_idx,:],phone_ids[1:-1]) 149 | 150 | aligned_phones = [self.charsiu_processor.mapping_id2phone(phone_ids[1:-1][i]) for i in aligned_phone_ids] 151 | 152 | pred_phones_ori = self._merge_silence(aligned_phones,sil_mask) 153 | 154 | pred_phones = seq2duration(pred_phones_ori,resolution=self.resolution) 155 | 156 | pred_words, word_phone_map = self.charsiu_processor.align_words(pred_phones, phones, words, return_map=True) 157 | 158 | ## get the predicted probability of output phone 159 | counter = 0 160 | pred_prob = np.empty((len(pred_phones),cost.shape[1])) 161 | group_idx = 0 162 | for _, group in groupby(pred_phones_ori): 163 | length = len(list(group)) 164 | pred_prob[group_idx,:] = np.mean(cost[counter:counter+length,:], axis=0) 165 | counter += length 166 | group_idx+=1 167 | 168 | final_phones = [the_phone[2] for the_phone in pred_phones ] 169 | phone_ids = self.charsiu_processor.mapping_phone2id(final_phones) 170 | 171 | return pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map 172 | 173 | 174 | def serve(self,audio,text,save_to,output_format='textgrid'): 175 | ''' 176 | A wrapper function for quick inference 177 | 178 | Parameters 179 | ---------- 180 | audio : TYPE 181 | DESCRIPTION. 182 | text : TYPE, optional 183 | DESCRIPTION. The default is None. 
184 | output_format : str, optional 185 | Output phone-taudio alignment as a "tsv" or "textgrid" file. 186 | The default is 'textgrid'. 187 | 188 | Returns 189 | ------- 190 | None. 191 | 192 | ''' 193 | phones, words = self.align(audio,text) 194 | 195 | if output_format == 'tsv': 196 | if save_to.endswith('.tsv'): 197 | save_to_phone = save_to.replace('.tsv','_phone.tsv') 198 | save_to_word = save_to.replace('.tsv','_word.tsv') 199 | else: 200 | save_to_phone = save_to + '_phone.tsv' 201 | save_to_word = save_to + '_word.tsv' 202 | 203 | self._to_tsv(phones, save_to_phone) 204 | self._to_tsv(words, save_to_word) 205 | 206 | elif output_format == 'textgrid': 207 | self._to_textgrid(phones, words, save_to) 208 | else: 209 | raise Exception('Please specify the correct output format (tsv or textgird)!') 210 | 211 | def _to_textgrid(self,phones,words,save_to): 212 | ''' 213 | Convert output tuples to a textgrid file 214 | 215 | Parameters 216 | ---------- 217 | phones : TYPE 218 | DESCRIPTION. 219 | 220 | Returns 221 | ------- 222 | None. 223 | 224 | ''' 225 | word2textgrid(phones,words,save_path=save_to) 226 | print('Alignment output has been saved to %s'%(save_to)) 227 | 228 | 229 | def _merge_silence(self,aligned_phones,sil_mask): 230 | # merge silent and non-silent intervals 231 | pred_phones = [] 232 | count = 0 233 | for i in sil_mask: 234 | if i==self.charsiu_processor.sil_idx: 235 | pred_phones.append('[SIL]') 236 | else: 237 | pred_phones.append(aligned_phones[count]) 238 | count += 1 239 | assert len(pred_phones) == len(sil_mask) 240 | return pred_phones 241 | 242 | 243 | 244 | def _get_sil_mask(self,cost): 245 | # single out silent intervals 246 | 247 | preds = np.argmax(cost,axis=-1) 248 | sil_mask = [] 249 | for key, group in groupby(preds): 250 | group = list(group) 251 | if (key==self.charsiu_processor.sil_idx and len(group) 0 97 | for kernel in kernels: 98 | self.cnns.append(nn.Conv1d(latest_size, cnn_size, kernel, padding=kernel//2)) 99 | latest_size = cnn_size * len(kernels) 100 | 101 | self.out_linear = nn.Linear(latest_size, output_class_num) 102 | 103 | def forward(self, features): 104 | hidden = F.dropout(F.relu(self.in_linear(features)), p=self.drop_p) 105 | 106 | conv_feats = [] 107 | hidden = hidden.transpose(1, 2).contiguous() 108 | for cnn in self.cnns: 109 | conv_feats.append(cnn(hidden)) 110 | hidden = torch.cat(conv_feats, dim=1).transpose(1, 2).contiguous() 111 | hidden = F.dropout(F.relu(hidden), p=self.drop_p) 112 | 113 | predicted = self.out_linear(hidden) 114 | return predicted 115 | 116 | 117 | 118 | class RNN(nn.Module): 119 | 120 | def __init__(self,hidden_dim,out_dim): 121 | super().__init__() 122 | 123 | self.lstm = nn.LSTM(hidden_dim,hidden_dim,bidirectional=True,num_layers=1,batch_first=True) 124 | self.linear = nn.Sequential(nn.Linear(2*hidden_dim,hidden_dim), 125 | nn.ReLU(), 126 | nn.Linear(hidden_dim,out_dim)) 127 | 128 | 129 | def forward(self, embeddings, lens): 130 | 131 | packed_input = pack_padded_sequence(embeddings, lens.cpu(), batch_first=True,enforce_sorted=False) 132 | packed_output, (ht, ct)= self.lstm(packed_input) 133 | out, _ = pad_packed_sequence(packed_output, batch_first=True) 134 | out = self.linear(out) 135 | return out 136 | 137 | 138 | class Wav2Vec2ForAttentionAlignment(Wav2Vec2ForPreTraining): 139 | ''' 140 | Implementation adapted from: https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2ForPreTraining 141 | ''' 142 | def __init__(self,config): 143 | 
super().__init__(config) 144 | bert_config = self.get_bert_config(config) 145 | self.bert = BertForMaskedPhoneLM(bert_config) 146 | self.cnn = ConvBank(config.hidden_size, 147 | config.bert_hidden_size, 148 | config.bert_convbank, 149 | config.bert_hidden_size, 150 | config.bert_hidden_size, 151 | config.hidden_dropout) 152 | #self.lm_head = nn.Linear(config.hidden_size,config.vocab_size) 153 | # self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) 154 | # self.phone_rnn = RNN(384,config.vocab_size) 155 | 156 | self.attention = Attention(config.bert_hidden_size) 157 | self.align_loss = ForwardSumLoss() 158 | 159 | def freeze_wav2vec2(self): 160 | for param in self.wav2vec2.parameters(): 161 | param.requires_grad = False 162 | 163 | def initialize_phone_model(self,path): 164 | 165 | self.bert = BertForMaskedPhoneLM.from_pretrained(path) 166 | 167 | def get_bert_config(self,config): 168 | bert_config = BertConfig(architectures=config.bert_architectures, 169 | attention_probs_dropout_prob=config.bert_attention_probs_dropout_prob, 170 | gradient_checkpointing=config.bert_gradient_checkpointing, 171 | hidden_act=config.bert_hidden_act, 172 | hidden_dropout_prob=config.bert_hidden_dropout_prob, 173 | hidden_size=config.bert_hidden_size, 174 | initializer_range=config.bert_initializer_range, 175 | intermediate_size=config.bert_intermediate_size, 176 | layer_norm_eps=config.bert_layer_norm_eps, 177 | max_position_embeddings=config.bert_max_position_embeddings, 178 | model_type=config.bert_model_type, 179 | num_attention_heads=config.bert_num_attention_heads, 180 | num_hidden_layers=config.bert_num_hidden_layers, 181 | pad_token_id=config.bert_pad_token_id, 182 | position_embedding_type=config.bert_position_embedding_type, 183 | transformers_version=config.bert_transformers_version, 184 | type_vocab_size=config.bert_type_vocab_size, 185 | use_cache=config.bert_use_cache, 186 | vocab_size=config.bert_vocab_size, 187 | convbank=config.bert_convbank) 188 | 189 | return bert_config 190 | 191 | def forward( 192 | self, 193 | input_values, 194 | attention_mask=None, 195 | output_attentions=None, 196 | output_hidden_states=None, 197 | mask_time_indices=None, 198 | return_dict=None, 199 | labels=None, 200 | labels_attention_mask=None, 201 | text_len=None, 202 | frame_len=None, 203 | weight=1 204 | ): 205 | 206 | # check the availability of attention masks 207 | # if not present, create full attention masks 208 | if attention_mask is None: 209 | attention_mask = torch.ones_like(input_values) 210 | 211 | if labels_attention_mask is None: 212 | labels_attention_mask = torch.ones_like(labels) 213 | 214 | 215 | 216 | outputs = self.wav2vec2( 217 | input_values, 218 | attention_mask=attention_mask, 219 | output_attentions=output_attentions, 220 | output_hidden_states=output_hidden_states, 221 | mask_time_indices=mask_time_indices, 222 | return_dict=return_dict, 223 | ) 224 | 225 | # acoustic embeddings 226 | frame_hidden = outputs[0] 227 | # frame_hidden = self.dropout(frame_hidden) 228 | frame_hidden = self.cnn(frame_hidden) 229 | 230 | # phone embeddings 231 | phone_hidden = self.bert(input_ids=labels,attention_mask=labels_attention_mask).hidden_states[-1] 232 | 233 | # compute cross attention 234 | att_out,energy = self.attention(frame_hidden,phone_hidden,labels_attention_mask) 235 | 236 | 237 | # start masked modeling 238 | # 0. remove the blank symbol 239 | # 1. 
project all transformed features (including masked) to final vq dim 240 | transformer_features = self.project_hid(torch.tanh(att_out)) 241 | 242 | 243 | # 2. quantize all (unmasked) extracted features and project to final vq dim 244 | extract_features = self.dropout_features(outputs[1]) 245 | quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices) 246 | quantized_features = self.project_q(quantized_features) 247 | 248 | 249 | # if attention_mask is passed, make sure that padded feature vectors cannot be sampled 250 | if attention_mask is not None: 251 | # compute reduced attention_mask correponding to feature vectors 252 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 253 | 254 | # loss_fct = nn.CrossEntropyLoss() 255 | 256 | # phone_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 257 | 258 | 259 | loss = None 260 | if self.training: 261 | # for training, we sample negatives 262 | # 3. sample K negatives (distractors) quantized states for contrastive loss 263 | 264 | negative_quantized_features = self._sample_negatives( 265 | quantized_features, self.config.num_negatives, attention_mask=attention_mask 266 | ) 267 | 268 | # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` 269 | # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf 270 | logits = self.compute_contrastive_logits( 271 | quantized_features[None, :], 272 | negative_quantized_features, 273 | transformer_features, 274 | self.config.contrastive_logits_temperature, 275 | ) 276 | 277 | # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), 278 | # its cosine similarity will be masked 279 | neg_is_pos = (quantized_features == negative_quantized_features).all(-1) 280 | if neg_is_pos.any(): 281 | logits[1:][neg_is_pos] = float("-inf") 282 | 283 | # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = 284 | # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) 285 | preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) 286 | target = ((1 - attention_mask.long()) * -100).transpose(0, 1).flatten() 287 | contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="mean") 288 | 289 | # 7. compute diversity loss: \mathbf{L}_d 290 | # num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups 291 | # diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors 292 | 293 | # 8. 
\mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d 294 | expanded_labels_attention_mask = (1-labels_attention_mask)*-10000.0 295 | expanded_labels_attention_mask = expanded_labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 296 | att = torch.log_softmax(energy+expanded_labels_attention_mask,dim=-1) 297 | align_loss = self.align_loss(att.unsqueeze(1),text_len,frame_len) 298 | 299 | # expanded_attention_mask = attention_mask.unsqueeze(2).repeat(1,1,energy.size(2)) * labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 300 | # expanded_attention_mask = (1-expanded_attention_mask)*-10000.0 301 | # phone_attention = torch.softmax((energy+expanded_attention_mask).transpose(2,1),dim=-1) 302 | # phone_emb = torch.bmm(phone_attention,frame_hidden) 303 | # prediction_scores = self.phone_rnn(phone_emb,text_len) 304 | # labels = labels.masked_fill(labels_attention_mask.ne(1), -100) 305 | # inter_phone = F.cosine_similarity(phone_emb[:,:-1,:],phone_emb[:,1:,:],dim=-1)*labels_attention_mask[:,1:] 306 | # interphone_loss = torch.sum(inter_phone)/torch.sum(labels_attention_mask[:,1:]) 307 | 308 | 309 | loss = contrastive_loss + weight*align_loss #+ interphone_loss 310 | 311 | 312 | return CausalLMOutput( 313 | loss=loss, logits=energy, hidden_states=outputs.hidden_states, attentions=None 314 | ) 315 | 316 | 317 | class BertForMaskedPhoneLM(BertForMaskedLM): 318 | 319 | def __init__(self,config): 320 | super().__init__(config) 321 | self.cnn = ConvBank(config.hidden_size, 322 | config.hidden_size, 323 | config.convbank, 324 | config.hidden_size, 325 | config.hidden_size, 326 | config.hidden_dropout_prob) 327 | 328 | def freeze_feature_extractor(self): 329 | for param in self.bert.parameters(): 330 | param.requires_grad = False 331 | 332 | def forward( 333 | self, 334 | input_ids=None, 335 | attention_mask=None, 336 | token_type_ids=None, 337 | position_ids=None, 338 | head_mask=None, 339 | inputs_embeds=None, 340 | encoder_hidden_states=None, 341 | encoder_attention_mask=None, 342 | labels=None, 343 | output_attentions=None, 344 | output_hidden_states=True, 345 | ): 346 | 347 | 348 | outputs = self.bert( 349 | input_ids, 350 | attention_mask=attention_mask, 351 | token_type_ids=token_type_ids, 352 | position_ids=position_ids, 353 | head_mask=head_mask, 354 | inputs_embeds=inputs_embeds, 355 | encoder_hidden_states=encoder_hidden_states, 356 | encoder_attention_mask=encoder_attention_mask, 357 | output_attentions=output_attentions, 358 | output_hidden_states=output_hidden_states 359 | ) 360 | 361 | 362 | prediction_scores = self.cnn(outputs.hidden_states[-1]) 363 | 364 | masked_lm_loss = None 365 | if labels is not None: 366 | loss_fct = CrossEntropyLoss() # -100 index = padding token 367 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 368 | 369 | return MaskedLMOutput( 370 | loss=masked_lm_loss, 371 | logits=prediction_scores, 372 | hidden_states=outputs.hidden_states, 373 | attentions=outputs.attentions, 374 | ) 375 | 376 | 377 | 378 | class Attention(nn.Module): 379 | 380 | def __init__(self,hidden_dim): 381 | super().__init__() 382 | self.q = nn.Linear(hidden_dim, hidden_dim) 383 | self.k = nn.Linear(hidden_dim, hidden_dim) 384 | # self.v = nn.Linear(hidden_dim*2, hidden_dim*2) 385 | self.layer_norm = nn.LayerNorm(hidden_dim) 386 | 387 | def forward(self,frame_hidden, phone_hidden,labels_attention_mask): 388 | 389 | frame_hidden = self.q(frame_hidden) 390 | phone_hidden = self.k(phone_hidden) 391 | 392 | energy = 
torch.bmm(frame_hidden,phone_hidden.transpose(2,1)) 393 | attention_mask = (1-labels_attention_mask)*-10000.0 394 | energy = energy+attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 395 | 396 | att_matrix = torch.softmax(energy,dim=-1) 397 | att_out = torch.bmm(att_matrix,phone_hidden) 398 | att_out = torch.cat([att_out,frame_hidden],dim=-1) 399 | # att_out = self.layer_norm(att_out + frame_hidden) 400 | 401 | return att_out, energy 402 | 403 | 404 | class Wav2Vec2ForFrameClassification(Wav2Vec2ForCTC): 405 | 406 | def forward( 407 | self, 408 | input_values, 409 | attention_mask=None, 410 | output_attentions=None, 411 | output_hidden_states=None, 412 | return_dict=None, 413 | labels=None, 414 | ): 415 | 416 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 417 | 418 | outputs = self.wav2vec2( 419 | input_values, 420 | attention_mask=attention_mask, 421 | output_attentions=output_attentions, 422 | output_hidden_states=output_hidden_states, 423 | return_dict=return_dict, 424 | ) 425 | 426 | hidden_states = outputs[0] 427 | hidden_states = self.dropout(hidden_states) 428 | 429 | logits = self.lm_head(hidden_states) 430 | 431 | loss = None 432 | if labels is not None: 433 | 434 | if labels.max() >= self.config.vocab_size: 435 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") 436 | 437 | # retrieve loss input_lengths from attention_mask 438 | attention_mask = ( 439 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) 440 | ) 441 | # input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 442 | 443 | loss = nn.functional.cross_entropy(logits.view(-1,logits.size(2)), labels.flatten(), reduction="mean") 444 | 445 | 446 | 447 | if not return_dict: 448 | output = (logits,) + outputs[2:] 449 | return ((loss,) + output) if loss is not None else output 450 | 451 | return CausalLMOutput( 452 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions 453 | ) 454 | 455 | 456 | 457 | class Wav2Vec2ForCTCAndPretraining(Wav2Vec2ForPreTraining): 458 | ''' 459 | Implementation adapted from: https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2ForPreTraining 460 | ''' 461 | def __init__(self,config): 462 | super().__init__(config) 463 | 464 | self.dropout = nn.Dropout(config.final_dropout) 465 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) 466 | self.cnn = ConvBank(config.vocab_size-1,config.hidden_size,[1],config.hidden_size,config.hidden_size,0.1) 467 | # was [1,3,5,7] 468 | 469 | def freeze_wav2vec2(self): 470 | for param in self.wav2vec2.parameters(): 471 | param.requires_grad = False 472 | 473 | def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): 474 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 475 | batch_size = attention_mask.shape[0] 476 | 477 | attention_mask = torch.zeros( 478 | (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device 479 | ) 480 | # these two operations makes sure that all values before the output lengths idxs are attended to 481 | attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 482 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 483 | return attention_mask 484 | 485 | @staticmethod 486 | def 
_sample_negatives( 487 | features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None 488 | ): 489 | """ 490 | Sample `num_negatives` vectors from feature vectors. 491 | """ 492 | batch_size, sequence_length, hidden_size = features.shape 493 | if sequence_length <= 1: 494 | raise ValueError( 495 | f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." 496 | ) 497 | 498 | features = features.view(-1, hidden_size) # BTC => (BxT)C 499 | 500 | with torch.no_grad(): 501 | # get `num_negatives` random vector indices from the same utterance 502 | sampled_negative_indices = [] 503 | for batch_idx in range(batch_size): 504 | high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 505 | sampled_indices_slice = torch.randint( 506 | 0, high, size=(num_negatives * sequence_length,), device=features.device 507 | ) 508 | sampled_negative_indices.append(sampled_indices_slice) 509 | 510 | sampled_negative_indices = torch.stack(sampled_negative_indices) 511 | 512 | # generate indices of the positive vectors themselves, repeat them `num_negatives` times 513 | feature_indices = ( 514 | torch.arange(sequence_length, device=features.device)[:, None] 515 | .expand(sequence_length, num_negatives) 516 | .flatten() 517 | ) 518 | 519 | # avoid sampling the same positive vector, but keep the distribution uniform 520 | sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 521 | 522 | # correct for batch size 523 | for batch_idx in range(1, batch_size): 524 | sampled_negative_indices[batch_idx] += batch_idx * sequence_length 525 | 526 | # take negative vectors from sampled indices 527 | sampled_negatives = features[sampled_negative_indices.view(-1)] 528 | sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute( 529 | 2, 0, 1, 3 530 | ) 531 | 532 | return sampled_negatives 533 | 534 | 535 | def forward( 536 | self, 537 | input_values, 538 | labels=None, 539 | attention_mask=None, 540 | mask_time_indices=None, 541 | output_attentions=None, 542 | output_hidden_states=None, 543 | return_dict=None, 544 | ): 545 | 546 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 547 | 548 | if mask_time_indices is not None: 549 | mask_time_indices = mask_time_indices.to(torch.bool) 550 | 551 | outputs = self.wav2vec2( 552 | input_values, 553 | attention_mask=attention_mask, 554 | output_attentions=output_attentions, 555 | output_hidden_states=output_hidden_states, 556 | mask_time_indices=mask_time_indices, 557 | ) 558 | 559 | # get CTC loss 560 | hidden_states = self.dropout(outputs[0]) 561 | ctc_logits = self.lm_head(hidden_states) 562 | 563 | loss = None 564 | ctc_loss = None 565 | contrastive_loss = None 566 | codevector_perplexity = None 567 | if labels is not None: 568 | 569 | if labels.max() >= self.config.vocab_size: 570 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") 571 | 572 | # retrieve loss input_lengths from attention_mask 573 | attention_mask = ( 574 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) 575 | ) 576 | input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 577 | 578 | # assuming that padded tokens are filled with -100 579 | # when not being attended to 580 | labels_mask = labels >= 0 581 | target_lengths 
= labels_mask.sum(-1) 582 | flattened_targets = labels.masked_select(labels_mask) 583 | 584 | # ctc_loss doesn't support fp16 585 | log_probs = nn.functional.log_softmax(ctc_logits, dim=-1, dtype=torch.float32).transpose(0, 1) 586 | 587 | with torch.backends.cudnn.flags(enabled=False): 588 | ctc_loss = nn.functional.ctc_loss( 589 | log_probs, 590 | flattened_targets, 591 | input_lengths, 592 | target_lengths, 593 | blank=self.config.pad_token_id, 594 | reduction=self.config.ctc_loss_reduction, 595 | zero_infinity=self.config.ctc_zero_infinity, 596 | ) 597 | 598 | # start masked modeling 599 | # 0. remove the blank symbol 600 | # 1. project all transformed features (including masked) to final vq dim 601 | transformer_features = self.project_hid(self.cnn(ctc_logits[:,:,:-1])) 602 | 603 | 604 | # 2. quantize all (unmasked) extracted features and project to final vq dim 605 | extract_features = self.dropout_features(outputs[1]) 606 | quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices) 607 | quantized_features = self.project_q(quantized_features) 608 | 609 | loss = None 610 | if self.training: 611 | # for training, we sample negatives 612 | # 3. sample K negatives (distractors) quantized states for contrastive loss 613 | # if attention_mask is passed, make sure that padded feature vectors cannot be sampled 614 | if attention_mask is not None: 615 | # compute reduced attention_mask correponding to feature vectors 616 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 617 | 618 | negative_quantized_features = self._sample_negatives( 619 | quantized_features, self.config.num_negatives, attention_mask=attention_mask 620 | ) 621 | 622 | # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` 623 | # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf 624 | logits = self.compute_contrastive_logits( 625 | quantized_features[None, :], 626 | negative_quantized_features, 627 | transformer_features, 628 | self.config.contrastive_logits_temperature, 629 | ) 630 | 631 | # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), 632 | # its cosine similarity will be masked 633 | neg_is_pos = (quantized_features == negative_quantized_features).all(-1) 634 | if neg_is_pos.any(): 635 | logits[1:][neg_is_pos] = float("-inf") 636 | 637 | # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = 638 | # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) 639 | preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) 640 | target = ((1 - attention_mask.long()) * -100).transpose(0, 1).flatten() 641 | contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="mean") 642 | 643 | # 7. compute diversity loss: \mathbf{L}_d 644 | num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups 645 | diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors 646 | 647 | # 8. 
\mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d 648 | contrastive_loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss 649 | 650 | loss = ctc_loss + contrastive_loss 651 | 652 | return Wav2Vec2ForCTCAndPretrainingOutput( 653 | loss=loss, 654 | ctc_logits = ctc_logits, 655 | ctc_loss = ctc_loss, 656 | contrastive_loss = contrastive_loss, 657 | codevector_perplexity=codevector_perplexity, 658 | projected_states=transformer_features, 659 | projected_quantized_states=quantized_features, 660 | hidden_states=outputs.hidden_states, 661 | attentions=outputs.attentions, 662 | ) 663 | 664 | 665 | --------------------------------------------------------------------------------
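The docstrings in utils_assessment.py above show how the Charsiu forced aligner and the fairseq RoBERTa model are loaded; the sketch below ties them together with feature_extraction. It is a minimal illustration rather than part of the repository: the wav path and the two example sentences are placeholders, and it assumes the script is run from the repository root so that utils_assessment.py and Charsiu.py are importable.

```python
# Minimal usage sketch of the feature-extraction pipeline (illustrative only).
# The wav path and the sentences below are placeholders.
from fairseq.models.roberta import RobertaModel
from Charsiu import charsiu_forced_aligner
from utils_assessment import feature_extraction

# phone/word aligner (charsiu) and word-embedding model (roberta),
# loaded exactly as shown in the feature_extraction docstring
alignment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
word_model = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
word_model.eval()

wav_path = './speechocean762/wav/example.wav'   # placeholder path to a 16 kHz wav file
gt_sen = 'we call it bear'                      # prompt (ground-truth) sentence, placeholder
asr_sen = 'we call it bird'                     # ASR transcription of the learner's speech, placeholder

(pred_words_gt, features_p, features_w, phonevector,
 gt_word_embed, asr_word_embed, word_phone_map_gt) = feature_extraction(
    wav_path, gt_sen, asr_sen, alignment_model, word_model)

print(features_p.shape)     # (number of phones in gt_sen, 93)
print(gt_word_embed.shape)  # (number of words in gt_sen, 768)
```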