├── plot
│   └── multiPA.png
├── fairseq_roberta
│   └── README.md
├── fairseq_hubert
│   └── README.md
├── speechocean762
│   ├── convert_wavfile.py
│   └── create_training_and_testing_list.py
├── LICENSE.txt
├── get_gt_alignment.py
├── whisper_asr_all.py
├── get_training_features.py
├── dataloader.py
├── test_open.py
├── test_closed.py
├── api.py
├── evaluation_speechocean_closed.py
├── utils.py
├── README.md
├── evaluation_speechocean_open.py
├── model_assessment.py
├── processors.py
├── utils_assessment.py
├── Charsiu.py
└── models.py
/plot/multiPA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuwchen/MultiPA/HEAD/plot/multiPA.png
--------------------------------------------------------------------------------
/fairseq_roberta/README.md:
--------------------------------------------------------------------------------
1 |
2 | Download the roberta.base model from https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md
3 |
4 | Put model.pt and dict.txt here.
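
A quick sanity check that the files are in place (a minimal sketch; it mirrors how get_training_features.py and the test scripts load this model):

```
from fairseq.models.roberta import RobertaModel

# Load the checkpoint from this directory, exactly as the repo scripts do.
roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
roberta.eval()
tokens = roberta.encode('hello world')
print(roberta.extract_features(tokens).shape)  # expected: (1, num_tokens, 768)
```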
--------------------------------------------------------------------------------
/fairseq_hubert/README.md:
--------------------------------------------------------------------------------
1 |
2 | Download HuBERT Base (~95M params) from https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md
3 |
4 | Put the hubert_base_ls960.pt model here.
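
A quick sanity check that the checkpoint loads (a minimal sketch; it mirrors the load call in test_open.py, test_closed.py, and api.py):

```
import fairseq

# Load the HuBERT checkpoint the same way the inference scripts do.
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    ['./fairseq_hubert/hubert_base_ls960.pt'])
ssl_model = models[0]
print(type(ssl_model).__name__)  # expected: HubertModel
```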
--------------------------------------------------------------------------------
/speechocean762/convert_wavfile.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | import soundfile as sf
4 | from tqdm import tqdm
5 |
6 | def get_filepaths(directory):
7 | file_paths = []
8 | for root, directories, files in os.walk(directory):
9 | for filename in files:
10 | # Join the two strings in order to form the full filepath.
11 | filepath = os.path.join(root, filename)
12 | if filename.endswith('.WAV'):
13 | file_paths.append(filepath)
14 | return file_paths
15 |
16 | inputdir = './WAVE'
17 | outputdir = './wav'
18 |
19 | if not os.path.exists(outputdir):
20 | os.makedirs(outputdir)
21 |
22 | file_list = get_filepaths(inputdir)
23 |
24 |
25 | for path in tqdm(file_list):
26 | wavname = path.split(os.sep)[-1]
27 | new_path = os.path.join(outputdir, wavname.replace('.WAV','.wav'))
28 | if os.path.isfile(new_path):
29 | continue
30 | y, rate = librosa.load(path, sr=16000)
31 | sf.write(new_path, y, rate)
32 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yuwchen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/get_gt_alignment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from tqdm import tqdm
5 | from utils_assessment import *
6 | from Charsiu import charsiu_forced_aligner
7 |
8 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json
9 | data = json.load(f)
10 |
11 |
12 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list
13 | test_data = {}
14 | for line in test_file:
15 | wavidx = line.split('\t')[0]
16 | test_data[wavidx] = data[wavidx]
17 |
18 | gt_alignment_dir = './gt_alignment_test'
19 | os.makedirs(gt_alignment_dir, exist_ok=True)  # make sure the output dir exists before saving alignments
20 | wav_dir = './speechocean762/wav'
21 | charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
22 |
23 | for wavidx in tqdm(test_data.keys()):
24 |
25 | wavpath = os.path.join(wav_dir , wavidx+'.wav')
26 | gt_sen_list = []
27 | for word in data[wavidx]['words']:
28 | gt_sen_list.append(word['text'].lower())
29 |
30 | gt_sen = ' '.join(gt_sen_list)
31 | try:
32 | pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map = get_charsiu_alignment(wavpath, gt_sen, charsiu)
33 | selected_idx = get_match_index(pred_words, words)
34 | pred_words = np.asarray(pred_words)
35 | pred_words = pred_words[selected_idx]
36 | torch.save(pred_words, os.path.join(gt_alignment_dir, wavidx+'.pt'))
37 | if len(gt_sen_list)!=len(pred_words):
38 | print(wavidx)
39 | print(gt_sen_list)
40 | print(pred_words)
41 | except Exception as e:
42 | print(e)
43 | print(wavidx)
44 | print(gt_sen_list)
45 |
--------------------------------------------------------------------------------
/speechocean762/create_training_and_testing_list.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 |
5 | def get_filepaths(directory):
6 | file_paths = []
7 | for root, directories, files in os.walk(directory):
8 | for filename in files:
9 | # Join the two strings in order to form the full filepath.
10 | filepath = os.path.join(root, filename)
11 | if filename.endswith('.wav'):
12 | file_paths.append(filepath)
13 | return file_paths
14 |
15 | f = open('./resource/scores-detail.json')
16 | data = json.load(f)
17 |
18 |
19 | train_file = open('./train/wav.scp','r').read().splitlines()
20 | test_file = open('./test/wav.scp','r').read().splitlines()
21 |
22 | train_out = open('./speechocean762_train.txt','w')
23 | test_out = open('./speechocean762_test.txt','w')
24 |
25 | for line in train_file:
26 | wavidx = line.split('\t')[0]
27 | the_data = data[wavidx]
28 | accuracy = the_data['accuracy']
29 | completeness = the_data['completeness']
30 | fluency = the_data['fluency']
31 | prosodic = the_data['prosodic']
32 | total = the_data['total']
33 |
34 | for idx in range(5):
35 | W_acc_list = []
36 | W_stress_list = []
37 | W_total_list = []
38 | sen_length = len(the_data['words'])
39 | for w_idx in range(sen_length):
40 | w_acc = str(the_data['words'][w_idx]['accuracy'][idx])
41 | w_stress = str(the_data['words'][w_idx]['stress'][idx])
42 | w_total = str(the_data['words'][w_idx]['total'][idx])
43 | W_acc_list.append(w_acc)
44 | W_stress_list.append(w_stress)
45 | W_total_list.append(w_total)
46 |
47 | word_acc = ','.join(W_acc_list)
48 | word_stress = ','.join(W_stress_list)
49 | word_total = ','.join(W_total_list)
50 | raw = '{}.wav;{};{};{};{};{};{};{};{}\n'.format(wavidx, accuracy[idx], fluency[idx], prosodic[idx], total[idx], word_acc, word_stress, word_total, sen_length)
51 | train_out.write(raw)
52 |
53 |
54 | for line in test_file:
55 | wavidx = line.split('\t')[0]
56 | test_out.write(wavidx+'.wav\n')
--------------------------------------------------------------------------------
/whisper_asr_all.py:
--------------------------------------------------------------------------------
1 | import os
2 | import whisper
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | import argparse
7 |
8 | #Install whisper using: pip install -U openai-whisper
9 | #https://github.com/openai/whisper
10 |
11 | def get_filepaths(directory, format='.wav'):
12 | file_paths = []
13 | for root, _, files in os.walk(directory):
14 | for filename in files:
15 | if filename.endswith(format):
16 | file_paths.append(filename)
17 | return file_paths
18 |
19 | def main():
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of DATA/ directory')
23 | args = parser.parse_args()
24 | input_dir= args.datadir
25 |
26 | file_list = get_filepaths(input_dir, format='.wav') # collect all the .wav files in the dir
27 | file_list = set(file_list)
28 | model = whisper.load_model("base.en")
29 | outputname = os.path.join('whisper_results','speechocean_whisper_all_base_eng.csv')
30 |
31 | if not os.path.exists('whisper_results'):
32 | os.makedirs('whisper_results')
33 |
34 | try:
35 | print('Number of files:', len(file_list))
36 | df = pd.read_csv(outputname)
37 | exist_list = set(df['wavname'].to_list())
38 | print('Number of already processed files:', len(exist_list))
39 | file_list = file_list - exist_list
40 | print('Number of unprocessed files:',len(file_list))
41 | file_list = list(file_list)
42 | except Exception as e:
43 | print('Create new file')
44 | df = pd.DataFrame(columns=['wavname', 'transcript'])
45 | df.to_csv(outputname, sep=',', index=False, header=True)
46 |
47 |
48 | for filename in tqdm(file_list):
49 | path = os.path.join(input_dir, filename)
50 | result = model.transcribe(path, fp16=False)
51 | transcript = result['text']
52 | results = pd.DataFrame([{'wavname':filename,'transcript':transcript}])
53 | results.to_csv(outputname, mode='a', sep=',', index=False, header=False)
54 |
55 |
56 | if __name__ == '__main__':
57 | main()
58 |
--------------------------------------------------------------------------------
/get_training_features.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import whisper
5 | import numpy as np
6 | import pandas as pd
7 | from tqdm import tqdm
8 | from utils_assessment import *
9 | from Charsiu import charsiu_forced_aligner
10 | from fairseq.models.roberta import RobertaModel
11 |
12 |
13 | def get_transcript(df):
14 | asr_results = {}
15 | for index, row in df.iterrows():
16 | the_wavname = row['wavname'].replace('.wav','')
17 | the_transcript = row['transcript']
18 | asr_results[the_wavname] = the_transcript
19 | return asr_results
20 |
21 | def create_dir(outputdir):
22 | if not os.path.exists(outputdir):
23 | os.makedirs(outputdir)
24 |
25 | def main():
26 |
27 |
28 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json
29 | data = json.load(f)
30 |
31 | train_file = open('./speechocean762/speechocean762_train.txt','r').read().splitlines()
32 | train_list = []
33 | for line in train_file:
34 | wavname = line.split(';')[0].split('.')[0]
35 | train_list.append(wavname)
36 | train_list = list(set(train_list))
37 |
38 | df = pd.read_csv('./whisper_results/speechocean_whisper_all_base_eng.csv')
39 | #df_m = pd.read_csv('./whisper_results/speechocean_whisper_all_medium_eng.csv')
40 | #df_s = pd.read_csv('./whisper_results/speechocean_whisper_all_small_eng.csv')
41 | #df_t = pd.read_csv('./whisper_results/speechocean_whisper_all_tiny_eng.csv')
42 |
43 | whisper_results = get_transcript(df)
44 |
45 | wav_dir = './speechocean762/wav'
46 |
47 | outputdir_pred_words_gt = './feature_base/pred_words_gt'
48 | outputdir_features_p = './feature_base/features_p'
49 | outputdir_features_w = './feature_base/features_w'
50 | outputdir_phone_vector = './feature_base/phone_vector'
51 | outputdir_gt_word_embed = './feature_base/gt_word_embed'
52 | outputdir_asr_word_embed = './feature_base/asr_word_embed'
53 | outputdir_word_phone_map = './feature_base/word_phone_map'
54 |
55 | create_dir(outputdir_pred_words_gt)
56 | create_dir(outputdir_features_p)
57 | create_dir(outputdir_features_w)
58 | create_dir(outputdir_phone_vector)
59 | create_dir(outputdir_gt_word_embed)
60 | create_dir(outputdir_asr_word_embed)
61 | create_dir(outputdir_word_phone_map)
62 |
63 | charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
64 | roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
65 | roberta.eval()
66 |
67 | error_list = open('error_list.txt','w')
68 | for wavname in tqdm(train_list):
69 | try:
70 | wavpath = os.path.join(wav_dir , wavname+'.wav')
71 | gt_sen_list = []
72 | for word in data[wavname]['words']:
73 | gt_sen_list.append(word['text'].lower())
74 |
75 | gt_sen = ' '.join(gt_sen_list)
76 |
77 | asr_sen = whisper_results[wavname].lower()
78 | asr_sen = remove_pun_except_apostrophe(asr_sen)
79 | asr_sen = convert_num_to_word(asr_sen)
80 |
81 | pred_words_gt, features_p, features_w, phone_vector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wavpath, gt_sen, asr_sen, alignment_model=charsiu, word_model=roberta)
82 |
83 | torch.save(pred_words_gt, os.path.join(outputdir_pred_words_gt, wavname+'.pt'))
84 | torch.save(features_p, os.path.join(outputdir_features_p, wavname+'.pt'))
85 | torch.save(features_w, os.path.join(outputdir_features_w, wavname+'.pt'))
86 | torch.save(phone_vector, os.path.join(outputdir_phone_vector, wavname+'.pt'))
87 | torch.save(gt_word_embed, os.path.join(outputdir_gt_word_embed, wavname+'.pt'))
88 | torch.save(asr_word_embed, os.path.join(outputdir_asr_word_embed, wavname+'.pt'))
89 | torch.save(word_phone_map, os.path.join(outputdir_word_phone_map, wavname+'.pt'))
90 |
91 | if len(pred_words_gt) != len(gt_sen_list):
92 | error_list.write(wavname+'#'+str(pred_words_gt)+'#'+str(gt_sen_list)+'\n')
93 |
94 | except Exception as e:
95 | print(e)
96 | error_list.write(wavname+'\n')
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import torchaudio
5 | from torch.utils.data.dataset import Dataset
6 | from torch.nn.utils.rnn import pad_sequence
7 |
8 |
9 | # Feature directories used by MyDataset
10 | ASR_WORD_EMBED_DIR = 'feature_base/asr_word_embed'
11 | GT_WORD_EMBED_DIR = 'feature_base/gt_word_embed'
12 | GT_ALIGNMENT_DIR = 'feature_base/pred_words_gt'
13 | WORD_FEATURE_DIR = 'feature_base/features_w'
14 | PHONE_FEATURE_DIR = 'feature_base/features_p'
15 | PHONEVECTOR_DIR = 'feature_base/phone_vector'
16 | WORD_PHONE_MAP_DIR = 'feature_base/word_phone_map'
17 |
18 | SAMPLE_RATE=16000
19 |
20 | class MyDataset(Dataset):
21 |
22 | def __init__(self, rootdir, data_list):
23 |
24 | self.A_lookup = {}
25 | self.F_lookup = {}
26 | self.P_lookup = {}
27 | self.T_lookup = {}
28 | self.w_acc_lookup = {}
29 | self.w_stress_lookup = {}
30 | self.w_total_lookup = {}
31 | self.num_w = {}
32 |
33 | wavfiles = []
34 | for line in data_list:
35 | parts = line.split(';')
36 | wavfile = parts[0]
37 |
38 | A = float(parts[1])
39 | F = float(parts[2])
40 | P = float(parts[3])
41 | T = float(parts[4])
42 | w_acc = parts[5].split(',')
43 | w_stress = parts[6].split(',')
44 | w_total = parts[7].split(',')
45 |
46 | w_acc = [float(x) for x in w_acc]
47 | w_stress = [float(x) for x in w_stress]
48 | w_total = [float(x) for x in w_total]
49 |
50 | num_of_word = float(parts[8])
51 | self.A_lookup[wavfile] = A
52 | self.F_lookup[wavfile] = F
53 | self.P_lookup[wavfile] = P
54 | self.T_lookup[wavfile] = T
55 | self.w_acc_lookup[wavfile] = w_acc
56 | self.w_stress_lookup[wavfile] = w_stress
57 | self.w_total_lookup[wavfile] = w_total
58 | self.num_w[wavfile] = num_of_word
59 |
60 | wavfiles.append(wavfile)
61 |
62 | self.rootdir = rootdir
63 | self.wavfiles = sorted(wavfiles)
64 |
65 | def __getitem__(self, idx):
66 | wavfile = self.wavfiles[idx]
67 | wavpath = os.path.join(self.rootdir, wavfile)
68 | wav = torchaudio.load(wavpath)[0]
69 |
70 | try:
71 | asr_word_embed = torch.from_numpy(torch.load(os.path.join(ASR_WORD_EMBED_DIR, wavfile.replace('.wav','.pt')))).float()
72 | gt_word_embed = torch.from_numpy(torch.load(os.path.join(GT_WORD_EMBED_DIR, wavfile.replace('.wav','.pt')))).float()
73 | gt_alignment = torch.load(os.path.join(GT_ALIGNMENT_DIR, wavfile.replace('.wav','.pt')))
74 | features_w = torch.from_numpy(torch.load(os.path.join(WORD_FEATURE_DIR, wavfile.replace('.wav','.pt')))).float()
75 | features_p = torch.from_numpy(torch.load(os.path.join(PHONE_FEATURE_DIR, wavfile.replace('.wav','.pt')))).float()
76 | phonevector = torch.from_numpy(torch.load(os.path.join(PHONEVECTOR_DIR, wavfile.replace('.wav','.pt')))).float()
77 | word_phone_map = torch.load(os.path.join(WORD_PHONE_MAP_DIR, wavfile.replace('.wav','.pt')))
78 |
79 | except Exception as e:
80 | print(e, wavfile)
81 | return None
82 |
83 | num_w = int(len(gt_alignment))
84 |
85 | timesplit = [(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in gt_alignment]
86 |
87 | s_A = self.A_lookup[wavfile]
88 | s_F = self.F_lookup[wavfile]
89 | s_P = self.P_lookup[wavfile]
90 | s_T = self.T_lookup[wavfile]
91 |
92 | w_s_acc = torch.tensor(self.w_acc_lookup[wavfile])
93 | w_s_stress = torch.tensor(self.w_stress_lookup[wavfile])
94 | w_s_total = torch.tensor(self.w_total_lookup[wavfile])
95 |
96 |
97 | return wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile
98 |
99 |
100 | def __len__(self):
101 | return len(self.wavfiles)
102 |
103 |
104 | def collate_fn(self, batch): ## zero padding
105 |
106 | batch = list(filter(lambda x: x is not None, batch))
107 |
108 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile = zip(*batch)
109 |
110 | wavs = list(wav)
111 | max_len = max(wavs, key = lambda x : x.shape[1]).shape[1]
112 | output_wavs = []
113 | for wav in wavs:
114 | amount_to_pad = max_len - wav.shape[1]
115 | padded_wav = torch.nn.functional.pad(wav, (0, amount_to_pad), 'constant', 0)
116 | output_wavs.append(padded_wav)
117 | output_wavs = torch.stack(output_wavs, dim=0)
118 |
119 | phonevector = pad_sequence(phonevector, batch_first=True)
120 | asr_word_embed = pad_sequence(asr_word_embed, batch_first=True)
121 | gt_word_embed = pad_sequence(gt_word_embed, batch_first=True)
122 | features_w = pad_sequence(features_w, batch_first=True)
123 | features_p = pad_sequence(features_p, batch_first=True)
124 |
125 | w_s_acc = pad_sequence(w_s_acc, batch_first=True)
126 | w_s_stress = pad_sequence(w_s_stress, batch_first=True)
127 | w_s_total = pad_sequence(w_s_total, batch_first=True)
128 | s_A = torch.stack([torch.tensor(x) for x in list(s_A)], dim=0)
129 | s_F = torch.stack([torch.tensor(x) for x in list(s_F)], dim=0)
130 | s_P = torch.stack([torch.tensor(x) for x in list(s_P)], dim=0)
131 | s_T = torch.stack([torch.tensor(x) for x in list(s_T)], dim=0)
132 | timesplit = list(timesplit)
133 | word_phone_map = list(word_phone_map)
134 |
135 | return output_wavs, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, num_w, wavfile
136 |
137 |
--------------------------------------------------------------------------------
/test_open.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import argparse
5 | import torch
6 | import fairseq
7 | import whisper
8 | import numpy as np
9 | import torch.nn as nn
10 | import torchaudio
11 | from tqdm import tqdm
12 | from fairseq.models.roberta import RobertaModel
13 | from dataclasses import dataclass
14 | from utils_assessment import *
15 | from model_assessment import PronunciationPredictor
16 | from Charsiu import charsiu_forced_aligner
17 |
18 | gc.collect()
19 | torch.cuda.empty_cache()
20 |
21 | def main():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--fairseq_base_model', type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.')
24 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.')
25 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of your DATA/ directory')
26 | parser.add_argument('--datalist', default='./speechocean762/speechocean762_test.txt', type=str, help='')
27 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.')
28 |
29 |
30 | args = parser.parse_args()
31 |
32 | ssl_path = args.fairseq_base_model
33 | roberta_path = args.fairseq_roberta
34 | my_checkpoint_dir = args.ckptdir
35 | datadir = args.datadir
36 | datalist = args.datalist
37 |
38 |
39 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt')
40 | word_model.eval()
41 | whisper_model_s = whisper.load_model("medium.en")
42 | whisper_model_w = whisper.load_model("base.en")
43 |
44 |
45 | aligment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
46 |
47 | SSL_OUT_DIM = 768
48 | TEXT_OUT_DIM = 768
49 | SAMPLE_RATE = 16000
50 |
51 | print('Loading checkpoint')
52 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53 | print('DEVICE: ' + str(device))
54 |
55 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path])
56 | ssl_model = ssl_model[0]
57 |
58 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device)
59 | assessment_model.eval()
60 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best')))
61 |
62 | print('Loading data')
63 | validset = open(datalist,'r').read().splitlines()
64 | outfile = my_checkpoint_dir.split("/")[-1]+'_'+datalist.split('/')[-1].replace('.txt','_mb.txt')
65 |
66 | output_dir = 'Results'
67 | if not os.path.exists(output_dir):
68 | os.makedirs(output_dir)
69 |
70 | prediction = open(os.path.join(output_dir, outfile), 'w')
71 |
72 | print('Starting prediction')
73 | for filename in tqdm(validset):
74 |
75 | with torch.no_grad():
76 | if datalist is not None:
77 | filepath = os.path.join(datadir, filename)
78 | else:
79 | filepath=filename
80 | wav, sr = torchaudio.load(filepath)
81 | # resample the audio recording to 16000 Hz
82 | if sr!=16000:
83 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
84 | wav = transform(wav)
85 | sr = SAMPLE_RATE
86 |
87 | wav = torch.reshape(wav, (-1,))
88 | sen_asr_s = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_s)).lower()
89 | sen_asr_s = convert_num_to_word(sen_asr_s)
90 |
91 | sen_asr_w = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_w)).lower()
92 | sen_asr_w = convert_num_to_word(sen_asr_w)
93 |
94 | try:
95 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model)
96 |
97 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]]
98 | word_phone_map = [word_phone_map]
99 |
100 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0)
101 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0)
102 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0)
103 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0)
104 | asr_word_embed = torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0)
105 | wav = wav.to(device).unsqueeze(0)
106 |
107 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit)
108 | score_A = score_A.cpu().detach().numpy()[0]
109 | score_F = score_F.cpu().detach().numpy()[0]
110 | score_P = score_P.cpu().detach().numpy()[0]
111 | score_T = score_T.cpu().detach().numpy()[0]
112 | w_a = w_acc.cpu().detach().numpy()[0]
113 | w_s = w_stress.cpu().detach().numpy()[0]
114 | w_t = w_total.cpu().detach().numpy()[0]
115 |
116 | w_a = ','.join([str(num) for num in w_a])
117 | w_s = ','.join([str(num) for num in w_s])
118 | w_t = ','.join([str(num) for num in w_t])
119 |
120 | valid = 'T'
121 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, score_A, score_F, score_P, score_T, valid, sen_asr_s, sen_asr_w, w_a, w_s, w_t, pred_words_gt.tolist())
122 | print(output)
123 | prediction.write(output+'\n')
124 |
125 | except Exception as e:
126 | print(e)
127 | valid = 'F'
128 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, '', '', '', '', valid, sen_asr_s, sen_asr_w, '', '', '', '')
129 | prediction.write(output+'\n')
130 | continue
131 |
132 |
133 | torch.cuda.empty_cache()
134 |
135 |
136 |
137 |
138 | if __name__ == '__main__':
139 | main()
140 |
--------------------------------------------------------------------------------
/test_closed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import argparse
5 | import torch
6 | import fairseq
7 | import whisper
8 | import numpy as np
9 | import torch.nn as nn
10 | import torchaudio
11 | from tqdm import tqdm
12 | from fairseq.models.roberta import RobertaModel
13 | from dataclasses import dataclass
14 | from utils_assessment import *
15 | from model_assessment import PronunciationPredictor
16 | from Charsiu import charsiu_forced_aligner
17 |
18 | gc.collect()
19 | torch.cuda.empty_cache()
20 |
21 | def main():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--fairseq_base_model', type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.')
24 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.')
25 | parser.add_argument('--speechocean_gt', type=str, default='./speechocean762/resource/scores.json', help='Path to speechocean scores.json')
26 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path of your DATA/ directory')
27 | parser.add_argument('--datalist', default='./speechocean762/speechocean762_test.txt', type=str, help='')
28 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.')
29 |
30 |
31 | args = parser.parse_args()
32 |
33 | ssl_path = args.fairseq_base_model
34 | roberta_path = args.fairseq_roberta
35 | my_checkpoint_dir = args.ckptdir
36 | datadir = args.datadir
37 | datalist = args.datalist
38 |
39 | f = open(args.speechocean_gt)
40 | gt_data = json.load(f)
41 |
42 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt')
43 | word_model.eval()
44 | whisper_model_w = whisper.load_model("base.en")
45 |
46 | aligment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
47 |
48 | SSL_OUT_DIM = 768
49 | TEXT_OUT_DIM = 768
50 | SAMPLE_RATE = 16000
51 |
52 | print('Loading checkpoint')
53 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54 | print('DEVICE: ' + str(device))
55 |
56 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path])
57 | ssl_model = ssl_model[0]
58 |
59 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device)
60 | assessment_model.eval()
61 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best')))
62 |
63 | print('Loading data')
64 | validset = open(datalist,'r').read().splitlines()
65 | outfile = my_checkpoint_dir.split("/")[-1]+'_'+datalist.split('/')[-1].replace('.txt','_gtb.txt')
66 |
67 | output_dir = 'Results'
68 | if not os.path.exists(output_dir):
69 | os.makedirs(output_dir)
70 |
71 | prediction = open(os.path.join(output_dir, outfile), 'w')
72 |
73 | print('Starting prediction')
74 | for filename in tqdm(validset):
75 |
76 | with torch.no_grad():
77 | if datalist is not None:
78 | filepath = os.path.join(datadir, filename)
79 | else:
80 | filepath=filename
81 | wav, sr = torchaudio.load(filepath)
82 | # resample the audio recording to 16000 Hz
83 | if sr!=16000:
84 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
85 | wav = transform(wav)
86 | sr = SAMPLE_RATE
87 |
88 | sen_asr_s = []
89 | for word in gt_data[filename.replace('.wav','')]['words']:
90 | sen_asr_s.append(word['text'].lower())
91 | sen_asr_s = ' '.join(sen_asr_s)
92 |
93 | wav = torch.reshape(wav, (-1,))
94 |
95 | sen_asr_w = remove_pun_except_apostrophe(get_transcript(wav, whisper_model_w)).lower()
96 | sen_asr_w = convert_num_to_word(sen_asr_w)
97 |
98 | try:
99 |
100 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model)
101 |
102 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]]
103 | word_phone_map = [word_phone_map]
104 |
105 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0)
106 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0)
107 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0)
108 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0)
109 | asr_word_embed = torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0)
110 | wav = wav.to(device).unsqueeze(0)
111 |
112 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit)
113 | score_A = score_A.cpu().detach().numpy()[0]
114 | score_F = score_F.cpu().detach().numpy()[0]
115 | score_P = score_P.cpu().detach().numpy()[0]
116 | score_T = score_T.cpu().detach().numpy()[0]
117 | w_a = w_acc.cpu().detach().numpy()[0]
118 | w_s = w_stress.cpu().detach().numpy()[0]
119 | w_t = w_total.cpu().detach().numpy()[0]
120 |
121 | w_a = ','.join([str(num) for num in w_a])
122 | w_s = ','.join([str(num) for num in w_s])
123 | w_t = ','.join([str(num) for num in w_t])
124 |
125 | valid = 'T'
126 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, score_A, score_F, score_P, score_T, valid, sen_asr_s, sen_asr_w, w_a, w_s, w_t, pred_words_gt.tolist())
127 | print(output)
128 | prediction.write(output+'\n')
129 |
130 | except Exception as e:
131 | print(e)
132 | valid = 'F'
133 | output = "{}; A:{}; F:{}; P:{}; T:{}; Valid:{}; ASR_s:{}; ASR_w:{}; w_a:{}; w_s:{}; w_t:{}; alignment:{}".format(filename, '', '', '', '', valid, sen_asr_s, sen_asr_w, '', '', '', '')
134 | prediction.write(output+'\n')
135 | continue
136 |
137 |
138 | torch.cuda.empty_cache()
139 |
140 |
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import time
5 | import argparse
6 | import torch
7 | import copy
8 | import fairseq
9 | import whisper
10 | import numpy as np
11 | import torch.nn as nn
12 | import torchaudio
13 | from tqdm import tqdm
14 | from fairseq.models.roberta import RobertaModel
15 | from dataclasses import dataclass
16 | from utils_assessment import *
17 | from model_assessment import PronunciationPredictor
18 | from Charsiu import charsiu_forced_aligner
19 |
20 | gc.collect()
21 | torch.cuda.empty_cache()
22 |
23 | SSL_OUT_DIM = 768
24 | TEXT_OUT_DIM = 768
25 | SAMPLE_RATE = 16000
26 |
27 | def inference_one_seg(sen_asr_s, sen_asr_w, wav):
28 |
29 | sen_asr_s = remove_pun_except_apostrophe(sen_asr_s).lower()
30 | sen_asr_s = convert_num_to_word(sen_asr_s)
31 |
32 | sen_asr_w = remove_pun_except_apostrophe(sen_asr_w).lower()
33 | sen_asr_w = convert_num_to_word(sen_asr_w)
34 |
35 | pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map = feature_extraction(wav.numpy(), sen_asr_s, sen_asr_w, alignment_model=aligment_model, word_model=word_model)
36 |
37 | timesplit = [[(int(float(word[0])*SAMPLE_RATE), int(float(word[1])*SAMPLE_RATE)) for word in pred_words_gt]]
38 | word_phone_map = [word_phone_map]
39 |
40 | wav = torch.reshape(wav, (1, -1))
41 | wav = wav.to(device)
42 | features_p = torch.from_numpy(features_p).to(device).float().unsqueeze(0)
43 | features_w = torch.from_numpy(features_w).to(device).float().unsqueeze(0)
44 | phonevector = torch.from_numpy(phonevector).to(device).float().unsqueeze(0)
45 | gt_word_embed = torch.from_numpy(gt_word_embed).to(device).float().unsqueeze(0)
46 | asr_word_embed = torch.from_numpy(asr_word_embed).to(device).float().unsqueeze(0)
47 |
48 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total = assessment_model(wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit)
49 |
50 | torch.cuda.empty_cache()
51 |
52 | return score_A, score_F, score_P, score_T, w_acc, w_stress, w_total, pred_words_gt
53 |
54 |
55 | def predict_one_file(filepath, whisper_model_s, whisper_model_w, word_model):
56 |
57 | results = {}
58 | results['wavname'] = filepath.split('/')[-1]
59 | with torch.no_grad():
60 |
61 | wav, sr = torchaudio.load(filepath)
62 | # resample the audio recording to 16000 Hz
63 | if sr!=16000:
64 | transform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
65 | wav = transform(wav)
66 | sr = SAMPLE_RATE
67 |
68 | if wav.shape[0]!=1:
69 | wav = torch.mean(wav,0)
70 |
71 | wav = torch.reshape(wav, (-1, ))
72 |
73 | if wav.shape[0] < SAMPLE_RATE*15: #if input wavfile is less than 15s, process the wavfile at once
74 | sen_asr_s = get_transcript(wav, whisper_model_s)
75 | sen_asr_w = get_transcript(wav, whisper_model_w)
76 | score_A, score_F, score_P, score_T, w_acc, w_stress, w_total, pred_words_gt = inference_one_seg(sen_asr_s, sen_asr_w, wav)
77 | score_A = score_A.cpu().detach().numpy()[0]
78 | score_F = score_F.cpu().detach().numpy()[0]
79 | score_P = score_P.cpu().detach().numpy()[0]
80 | score_T = score_T.cpu().detach().numpy()[0]
81 | w_a = w_acc.cpu().detach().numpy()[0]
82 | w_s = w_stress.cpu().detach().numpy()[0]
83 | w_t = w_total.cpu().detach().numpy()[0]
84 | # 'wavname' was already stored in results above; keep adding to the same dict
85 | pred_words = [word[-1] for word in pred_words_gt]
86 | results['uttr_acc'] = score_A
87 | results['uttr_fluency'] = score_F
88 | results['uttr_prosodic'] = score_P
89 | results['uttr_total'] = score_T
90 | results['word_acc'] = w_a
91 | results['word_stress'] = w_s
92 | results['word_total'] = w_t
93 | results['word_text'] = pred_words
94 | results['transcript_S'] = sen_asr_s
95 | results['transcript_W'] = sen_asr_w
96 |
97 | return results
98 |
99 | else: #if wavfile longer than 15s, do the segmentation to prevent OOM
100 |
101 | sen_asr_s_all = get_transcript(wav, whisper_model_s, return_seg=True)
102 | sen_asr_w_all = ''
103 | for seg in sen_asr_s_all['segments']:
104 | sen_asr_s = seg['text']
105 | start = float(seg['start'])
106 | end = float(seg['end'])
107 | the_wav = wav[int(start*SAMPLE_RATE):int(end*SAMPLE_RATE)]
108 | sen_asr_w = get_transcript(the_wav, whisper_model_w)
109 | sen_asr_w_all = sen_asr_w_all+' '+sen_asr_w
110 | the_score_A, the_score_F, the_score_P, the_score_T, the_w_acc, the_w_stress, the_w_total, the_pred_words_gt = inference_one_seg(sen_asr_s, sen_asr_w, the_wav)
111 | the_pred_words = [word[-1] for word in the_pred_words_gt]
112 | try:
113 | score_A += the_score_A
114 | score_F += the_score_F
115 | score_P += the_score_P
116 | score_T += the_score_T
117 | w_acc = torch.cat((w_acc,the_w_acc.squeeze(0)))
118 | w_stress = torch.cat((w_stress,the_w_stress.squeeze(0)))
119 | w_total = torch.cat((w_total,the_w_total.squeeze(0)))  # concatenate with w_total, not w_acc
120 | pred_word.extend(the_pred_words)
121 | except Exception as e: # first segment: initialize the running scores
122 | score_A = the_score_A
123 | score_F = the_score_F
124 | score_P = the_score_P
125 | score_T = the_score_T
126 | w_acc = the_w_acc.squeeze(0)
127 | w_stress = the_w_stress.squeeze(0)
128 | w_total = the_w_total.squeeze(0)
129 | pred_word = the_pred_words
130 |
131 | num_of_seg = len(sen_asr_s_all['segments'])
132 | results['uttr_acc'] = (score_A/num_of_seg)
133 | results['uttr_fluency'] = (score_F/num_of_seg)
134 | results['uttr_prosodic'] = (score_P/num_of_seg)
135 | results['uttr_total'] = (score_T/num_of_seg)
136 | results['word_acc'] = w_acc.cpu().detach().numpy()
137 | results['word_stress'] = w_stress.cpu().detach().numpy()
138 | results['word_total'] = w_total.cpu().detach().numpy()
139 | results['word_text'] = pred_word
140 | results['transcript_S'] = sen_asr_s_all['text']
141 | results['transcript_W'] = sen_asr_w_all
142 |
143 | return results
144 |
145 |
146 | parser = argparse.ArgumentParser()
147 | parser.add_argument('--fairseq_base_model', type=str, default='./fairseq_hubert/hubert_base_ls960.pt', help='Path to pretrained fairseq hubert model.')
148 | parser.add_argument('--fairseq_roberta', type=str, default='./fairseq_roberta', help='Path to pretrained fairseq roberta.')
149 | parser.add_argument('--inputdir', type=str, help='Path to testing wavfile.')
150 | parser.add_argument('--ckptdir', type=str, help='Path to pretrained checkpoint.')
151 |
152 |
153 | args = parser.parse_args()
154 |
155 | ssl_path = args.fairseq_base_model
156 | roberta_path = args.fairseq_roberta
157 | my_checkpoint_dir = args.ckptdir
158 | file_dir = args.inputdir
159 |
160 | word_model = RobertaModel.from_pretrained(roberta_path, checkpoint_file='model.pt')
161 | word_model.eval()
162 | whisper_model_s = whisper.load_model("medium.en")
163 | whisper_model_w = whisper.load_model("base.en")
164 | aligment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
165 |
166 |
167 | print('Loading checkpoint')
168 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
169 | print('DEVICE: ' + str(device))
170 |
171 | ssl_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ssl_path])
172 | ssl_model = ssl_model[0]
173 |
174 | assessment_model = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM).to(device)
175 | assessment_model.eval()
176 | assessment_model.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO'+os.sep+'best')))
177 |
178 |
179 | filepath_list = get_filepaths(file_dir)
180 |
181 | for filepath in filepath_list:
182 | s = time.time()
183 | results = predict_one_file(filepath, whisper_model_s, whisper_model_w, word_model)
184 | print('Process time:', time.time()-s)
185 | print(results)
186 |
187 |
188 |
--------------------------------------------------------------------------------
/evaluation_speechocean_closed.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import scipy
4 | from sklearn.metrics import mean_squared_error
5 | import math
6 | import string
7 | import numpy as np
8 | import math
9 | from collections import Counter
10 |
11 | def print_result(pred, gt, score_name):
12 | mse = mean_squared_error(pred, gt)
13 | corr, _ = scipy.stats.pearsonr(pred, gt)
14 | spearman, _ = scipy.stats.spearmanr(pred, gt)
15 | #print('mse:', mse)
16 | #print('corr:', round(corr,4))
17 | #print('srcc:', round(spearman,4))
18 | print(score_name, round(corr,4))
19 |
20 |
21 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json
22 | data = json.load(f)
23 |
24 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list
25 | test_data = {}
26 | for line in test_file:
27 | wavidx = line.split('\t')[0]
28 | test_data[wavidx] = data[wavidx]
29 |
30 | def get_prediction(path):
31 | invalid = 0
32 | prediction = open(path,'r').read().splitlines()
33 | result_word = {}
34 | result_uttr = {}
35 | for sample in prediction:
36 |
37 | parts = sample.split(';')
38 | wavidx = parts[0].replace('.wav','')
39 | valid = parts[5].split(':')[1]
40 | if valid=='F':
41 | invalid+=1
42 | accuracy = 1.0
43 | fluency = 0.0
44 | prosodic = 0.0
45 | total = 0.0
46 | completeness = 0.0
47 | result_word[wavidx]={}
48 | result_word[wavidx]['word_accuracy'] = 0
49 | result_word[wavidx]['word_stress'] = 5
50 | result_word[wavidx]['word_total'] = 1
51 | result_word[wavidx]['text'] = ''
52 | else:
53 | accuracy = float(parts[1].split(':')[1])
54 | fluency = float(parts[2].split(':')[1])
55 | prosodic = float(parts[3].split(':')[1])
56 | total = float(parts[4].split(':')[1])
57 | alignment = eval(parts[-1].split(':')[1])
58 | time_interval = [float(word[1])-float(word[0]) for word in alignment]
59 | completeness = [1 if the_interval > 0.07 else 0 for the_interval in time_interval]
60 | completeness = sum(completeness)/len(completeness)
61 |
62 | w_a = eval(parts[8].split(':')[1])
63 | w_s = eval(parts[9].split(':')[1])
64 | w_t = eval(parts[10].split(':')[1])
65 | if isinstance(w_a , float):
66 | w_a = [w_a]
67 | w_s = [w_s]
68 | w_t = [w_t]
69 | w_a = [10 if x > 10 else x for x in w_a]
70 | w_s = [10 if x > 10 else x for x in w_s]
71 | w_t = [10 if x > 10 else x for x in w_t]
72 | result_word[wavidx]={}
73 | result_word[wavidx]['word_accuracy'] = w_a
74 | result_word[wavidx]['word_stress'] = w_s
75 | result_word[wavidx]['word_total'] = w_t
76 | result_word[wavidx]['text'] = eval(parts[-1].split(':')[1])
77 | result_word[wavidx]['text'] = [word[-1] for word in result_word[wavidx]['text']]
78 |
79 | result_uttr[wavidx]={}
80 | result_uttr[wavidx]['accuracy'] = accuracy
81 | result_uttr[wavidx]['fluency'] = fluency
82 | result_uttr[wavidx]['prosodic'] = prosodic
83 | result_uttr[wavidx]['total'] = total
84 | result_uttr[wavidx]['completeness'] = completeness
85 |
86 | #print(invalid)
87 | return result_word, result_uttr
88 |
89 | def pad_mismatch_sequence(the_gt_w_text, the_pred_w_text, the_pred_w_acc, the_pred_w_stress, the_pred_w_total):
90 | """
91 | Sometimes the model merges consecutive occurrences of the same word, e.g. "seven nine nine one" becomes "seven nine one".
92 | In this case, the number of predicted scores won't match the number of ground-truth words.
93 | Therefore, we duplicate the scores for the merged word.
94 | """
95 | padded_acc = []
96 | padded_stress = []
97 | padded_total = []
98 | asr_w_idx=0
99 |
100 | for gt_word in the_gt_w_text:
101 | if asr_w_idx>=len(the_pred_w_text):
102 | padded_acc.append(the_pred_w_acc[asr_w_idx-1])
103 | padded_stress.append(the_pred_w_stress[asr_w_idx-1])
104 | padded_total.append(the_pred_w_total[asr_w_idx-1])
105 | break
106 |
107 | if gt_word == the_pred_w_text[asr_w_idx]:
108 | padded_acc.append(the_pred_w_acc[asr_w_idx])
109 | padded_stress.append(the_pred_w_stress[asr_w_idx])
110 | padded_total.append(the_pred_w_total[asr_w_idx])
111 | asr_w_idx+=1
112 | else:
113 | padded_acc.append(the_pred_w_acc[asr_w_idx-1])
114 | padded_stress.append(the_pred_w_stress[asr_w_idx-1])
115 | padded_total.append(the_pred_w_total[asr_w_idx-1])
116 |
117 | return padded_acc, padded_stress, padded_total
118 |
119 | def calculate_performance(result_word, result_uttr, wav_idx_word, wav_idx_uttr):
120 |
121 | gt_A = []
122 | gt_F = []
123 | gt_P = []
124 | gt_T = []
125 | gt_C = []
126 |
127 | pred_A = []
128 | pred_F = []
129 | pred_P = []
130 | pred_T = []
131 | pred_C = []
132 |
133 | for wavidx in wav_idx_uttr:
134 | gt_A.append(test_data[wavidx]['accuracy'])
135 | pred_A.append(result_uttr[wavidx]['accuracy'])
136 | gt_F.append(test_data[wavidx]['fluency'])
137 | pred_F.append(result_uttr[wavidx]['fluency'])
138 | gt_P.append(test_data[wavidx]['prosodic'])
139 | pred_P.append(result_uttr[wavidx]['prosodic'])
140 | gt_T.append(test_data[wavidx]['total'])
141 | pred_T.append(result_uttr[wavidx]['total'])
142 | gt_C.append(test_data[wavidx]['completeness'])
143 | pred_C.append(result_uttr[wavidx]['completeness'])
144 |
145 | print('number of utterance', len(pred_A))
146 | print_result(pred_A, gt_A, 'sen-accuracy')
147 | print_result(pred_F, gt_F, 'sen-fluency')
148 | print_result(pred_P, gt_P, 'sen-prosody')
149 | print_result(pred_T, gt_T, 'sen-total')
150 | print_result(pred_C, gt_C,'sen-completeness')
151 |
152 | gt_w_acc = []
153 | gt_w_stress = []
154 | gt_w_total = []
155 | pred_w_acc = []
156 | pred_w_stress = []
157 | pred_w_total = []
158 | count_sen = 0
159 | for wavidx in wav_idx_word:
160 | the_gt_w_acc = []
161 | the_gt_w_stress = []
162 | the_gt_w_total = []
163 | the_gt_w_text = []
164 |
165 | for word in test_data[wavidx]['words']:
166 | the_gt_w_acc.append(int(word['accuracy']))
167 | the_gt_w_stress.append(int(word['stress']))
168 | the_gt_w_total.append(int(word['total']))
169 | the_gt_w_text.append(word['text'].lower())
170 |
171 | the_pred_w_acc = result_word[wavidx]['word_accuracy']
172 | the_pred_w_stress = result_word[wavidx]['word_stress']
173 | the_pred_w_total = result_word[wavidx]['word_total']
174 |
175 | if len(the_gt_w_text) != len(result_word[wavidx]['text']): #if ground-truth sen and predicted sen not equal in length
176 | if result_word[wavidx]['text'] == '': # if the ASR cannot recognize the sentence, return the lowest score in the training data
177 | gt_len = len(the_gt_w_text)
178 | the_pred_w_acc = [the_pred_w_acc for _ in range(gt_len)]
179 | the_pred_w_stress = [the_pred_w_stress for _ in range(gt_len)]
180 | the_pred_w_total = [the_pred_w_total for _ in range(gt_len)]
181 | else: # for the case where the forced alignment model merges consecutive occurrences of the same word
182 | the_pred_w_acc, the_pred_w_stress, the_pred_w_total = pad_mismatch_sequence(the_gt_w_text, result_word[wavidx]['text'], the_pred_w_acc, the_pred_w_stress, the_pred_w_total)
183 |
184 | #assert len(the_gt_w_acc) == len(the_pred_w_acc)
185 | gt_w_acc.extend(the_gt_w_acc)
186 | gt_w_stress.extend(the_gt_w_stress)
187 | gt_w_total.extend(the_gt_w_total)
188 | pred_w_acc.extend(the_pred_w_acc)
189 | pred_w_stress.extend(the_pred_w_stress)
190 | pred_w_total.extend(the_pred_w_total)
191 | count_sen+=1
192 |
193 | #print('number of sentences for word prediction:', count_sen, "# of words:", len(pred_w_acc))
194 | print_result(pred_w_acc, gt_w_acc, 'word-acc')
195 | print_result(pred_w_stress, gt_w_stress, 'word-stress')
196 | print_result(pred_w_total, gt_w_total, 'word-total')
197 |
198 |
199 |
200 | resultA_word, resultA_uttr = get_prediction('./Results/model_assessment_val9_r1_speechocean762_test_gtb.txt')
201 |
202 | calculate_performance(resultA_word, resultA_uttr, list(resultA_word.keys()),list(resultA_uttr.keys()))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import numpy as np
6 | from collections import defaultdict  # used in get_boundaries
7 | import re
8 | from praatio import textgrid
9 | from itertools import groupby
10 | from librosa.sequence import dtw
11 |
12 |
13 |
14 | def ctc2duration(phones,resolution=0.01):
15 | """
16 | Convert frame-level CTC outputs to (start, end, phone) durations.
17 |
18 | Parameters
19 | ----------
20 | phones : list
21 | A list of phone sequence
22 | resolution : float, optional
23 | The frame resolution in seconds. The default is 0.01.
24 |
25 | Returns
26 | -------
27 | merged : list
28 | A list of (start, end, phone) tuples with [PAD] frames merged.
29 |
30 | """
31 |
32 | counter = 0
33 | out = []
34 | for p,group in groupby(phones):
35 | length = len(list(group))
36 | out.append((round(counter*resolution,2),round((counter+length)*resolution,2),p))
37 | counter += length
38 |
39 | merged = []
40 | for i, (s,e,p) in enumerate(out):
41 | if i==0 and p=='[PAD]':
42 | merged.append((s,e,'[SIL]'))
43 | elif p=='[PAD]':
44 | merged.append((out[i-1][0],e,out[i-1][2]))
45 | elif i==len(out)-1:
46 | merged.append((s,e,p))
47 | return merged
48 |
49 |
50 | def seq2duration(phones,resolution=0.01):
51 | """
52 | Convert a frame-level phone sequence to (start, end, phone) durations.
53 |
54 | Parameters
55 | ----------
56 | phones : list
57 | A list of phone sequence
58 | resolution : float, optional
59 | The frame resolution in seconds. The default is 0.01.
60 |
61 | Returns
62 | -------
63 | out : list
64 | A list of (start, end, phone) tuples.
65 |
66 | """
67 |
68 | counter = 0
69 | out = []
70 | for p,group in groupby(phones):
71 | length = len(list(group))
72 | out.append((round(counter*resolution,3),round((counter+length)*resolution,3),p))
73 | counter += length
74 | return out
75 |
76 |
77 |
78 | def duration2textgrid(duration_seq,save_path=None):
79 | """
80 | Save duration values to textgrids
81 |
82 | Parameters
83 | ----------
84 | duration_seq : list
85 | A list of (start, end, phone) tuples.
86 | save_path : str, optional
87 | The path to save the TextGrid files. The default is None.
88 |
89 | Returns
90 | -------
91 | tg : textgrid.Textgrid
92 | A textgrid object containing duration information.
93 |
94 | """
95 |
96 | tg = textgrid.Textgrid()
97 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1])
98 | tg.addTier(phoneTier)
99 | if save_path:
100 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False)
101 | return tg
102 |
103 |
104 | def word2textgrid(duration_seq,word_seq,save_path=None):
105 | """
106 | Save duration values to textgrids
107 |
108 | Parameters
109 | ----------
110 | duration_seq : list
111 | A list of (start, end, phone) tuples; word_seq is the corresponding list of (start, end, word) tuples.
112 | save_path : str, optional
113 | The path to save the TextGrid files. The default is None.
114 |
115 | Returns
116 | -------
117 | tg : textgrid.Textgrid
118 | A textgrid object containing duration information.
119 |
120 | """
121 |
122 | tg = textgrid.Textgrid()
123 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1])
124 | tg.addTier(phoneTier)
125 | wordTier = textgrid.IntervalTier('words', word_seq, 0, word_seq[-1][1])
126 | tg.addTier(wordTier)
127 | if save_path:
128 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False)
129 | return tg
130 |
131 |
132 |
133 | def get_boundaries(phone_seq):
134 | """
135 | Get time of phone boundaries
136 |
137 | Parameters
138 | ----------
139 | phone_seq : list
140 | A list of (start, end, phone) tuples.
141 |
142 | Returns
143 | -------
144 | timings: A list of time stamps
145 | symbols: A list of phone symbols
146 |
147 | """
148 |
149 | boundaries = defaultdict(set)
150 | for s,e,p in phone_seq:
151 | boundaries[s].update([p.upper()])
152 | # boundaries[e].update([p.upper()+'_e'])
153 | timings = np.array(list(boundaries.keys()))
154 | symbols = list(boundaries.values())
155 | return (timings,symbols)
156 |
157 |
158 | def check_textgrid_duration(textgrid,duration):
159 | """
160 | Check whether the duration of a textgrid file equals 'duration'.
161 | If not, replace duration of the textgrid file.
162 |
163 | Parameters
164 | ----------
165 | textgrid : .TextGrid object
166 | A .TextGrid object.
167 | duration : float
168 | A given length of time.
169 |
170 | Returns
171 | -------
172 | textgrid : .TextGrid object
173 | A modified/unmodified textgrid.
174 |
175 | """
176 |
177 |
178 | endtime = textgrid.tierDict['phones'].entryList[-1].end
179 | if not endtime==duration:
180 | last = textgrid.tierDict['phones'].entryList.pop()
181 | textgrid.tierDict['phones'].entryList.append(last._replace(end=duration))
182 |
183 | return textgrid
184 |
185 |
186 | def textgrid_to_labels(phones,duration,resolution):
187 | """
188 | Convert a phone duration sequence into frame-level phone labels.
189 |
190 | Parameters
191 | ----------
192 | phones : list
193 | A list of phone sequence
194 | resolution : float, optional
195 | The frame resolution in seconds. The default is 0.01.
196 | duration : float
197 | A given length of time.
198 |
199 |
200 | Returns
201 | -------
202 | labels : list
203 | A list of phone labels.
204 |
205 | """
206 |
207 | labels = []
208 | clock = 0.0
209 |
210 | for i, (s,e,p) in enumerate(phones):
211 |
212 | assert clock >= s
213 | while clock <= e:
214 | labels.append(p)
215 | clock += resolution
216 |
217 | # if more than half of the current frame is outside the current phone
218 | # we'll label it as the next phone
219 | if np.abs(clock-e) > resolution/2:
220 | labels[-1] = phones[min(len(phones)-1,i+1)][2]
221 |
222 | # if the final time interval is longer than the total duration
223 | # we will chop off this frame
224 | if clock-duration > resolution/2:
225 | labels.pop()
226 |
227 | return labels
228 |
229 | def remove_null_and_numbers(labels):
230 | """
231 | Remove labels which are null, noise, or numbers.
232 |
233 | Parameters
234 | ----------
235 | labels : list
236 | A list of text labels.
237 |
238 | Returns
239 | -------
240 | out : list
241 | A list of new labels.
242 |
243 | """
244 |
245 | out = []
246 | noises = set(['SPN','NSN','LAU'])
247 | for l in labels:
248 | l = re.sub(r'\d+','',l)
249 | l = l.upper()
250 | if l == '' or l == 'SIL':
251 | l = '[SIL]'
252 | if l == 'SP':
253 | l = '[SIL]'
254 | if l in noises:
255 | l = '[UNK]'
256 | out.append(l)
257 | return out
258 |
259 |
260 | def insert_sil(phones):
261 | """
262 | Insert silences.
263 |
264 | Parameters
265 | ----------
266 | phones : list
267 | A list of phone sequence
268 |
269 | Returns
270 | -------
271 | out : list
272 | A list of new labels.
273 |
274 | """
275 |
276 | out = []
277 | for i,(s,e,p) in enumerate(phones):
278 |
279 | if out:
280 | if out[-1][1]!=s:
281 | out.append((out[-1][1],s,'[SIL]'))
282 | out.append((s,e,p))
283 | return out
284 |
285 |
286 | def forced_align(cost, phone_ids):
287 |
288 | """
289 | Force align text to audio.
290 |
291 | Parameters
292 | ----------
293 | cost : np.ndarray
294 | A (num_frames, vocab_size) matrix of frame-level phone scores.
295 | phone_ids : list
296 | A list of phone IDs.
297 |
298 | Returns
299 | -------
300 | align_id : list
301 | A list of IDs for aligned phones.
302 |
303 | """
304 |
305 | D,align = dtw(C=-cost[:,phone_ids],
306 | step_sizes_sigma=np.array([[1, 1], [1, 0]]))
307 |
308 | align_seq = [-1 for i in range(max(align[:,0])+1)]
309 | for i in list(align):
310 | # print(align)
311 | if align_seq[i[0]]<i[1]:
312 | align_seq[i[0]] = i[1]
313 | 
314 | align_id = list(align_seq)
315 | return align_id
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
10 |
11 |
12 | - [Requirement](#Requirement)
13 | - [Train and evaluate on speechocean762 dataset](#Train-and-evaluate-on-speechocean762-dataset)
14 | - [Test on your data](#Test-on-your-data)
15 | - [References](#References)
16 | - [MultiPA data](#MultiPA-data)
17 | - [Citation](#Citation)
18 |
19 |
20 | ## Requirement
21 |
22 | ```
23 | conda create -n MultiPA python=3.9
24 | conda activate MultiPA
25 | pip install fairseq
26 | pip install soundfile
27 | pip install -U openai-whisper
28 | pip install transformers
29 | pip install num2words
30 | pip install pyctcdecode
31 | pip install https://github.com/kpu/kenlm/archive/master.zip
32 | pip install spacy==2.3.0
33 | pip install levenshtein
34 | pip install nltk
35 | pip install praatio
36 | pip install g2pM
37 | pip install librosa
38 | pip install g2p-en
39 | pip install pandas
40 | ```
41 | Note: spacy must be a 2.x version.
42 |
43 |
44 | #### Download pre-trained models
45 | (1) Download [HuBERT Base (~95M params)](https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md) and put hubert_base_ls960.pt in the fairseq_hubert dir.
46 | (2) Download the [roberta.base model](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md) and put model.pt and dict.txt in the fairseq_roberta dir.
47 |
48 | ## Train and evaluate on speechocean762 dataset
49 |
50 | ### Step 1. Data Preparation
51 | (1) Download the speechocean762 dataset: [Link](https://www.openslr.org/101).
52 | (2) Put the resource, test, train, and WAVE directories in the speechocean762 dir.
53 | - i.e., merge the speechocean762 dir in this repo with the downloaded speechocean762 dir.
54 |
55 | (3) Convert the .WAV files to .wav files using:
56 | ```
57 | cd speechocean762
58 | python convert_wavfile.py
59 | ```
60 | (4) Generate the training and testing list using:
61 | ```
62 | cd speechocean762
63 | python create_training_and_testing_list.py
64 | ```
65 | (5) Obtain the Whisper ASR results:
66 | ```
67 | python whisper_asr_all.py
68 | ```
69 | (6) Generate the training features:
70 | ```
71 | python get_training_features.py
72 | ```
73 |
74 | ### Step 2. Model training
75 | ```
76 | python model_assessment.py --outdir model_assessment
77 | ```
78 | Note: the validation loss usually stops decreasing after 2 epochs.
79 |
80 | ### Step 3. Inference
81 | Get the assessment results in the closed response scenario (using the ground-truth transcript):
82 | ```
83 | python test_closed.py --ckptdir model_assessment
84 | ```
85 | The results will be saved to "Results/model_assessment_speechocean762_test_gtb.txt" in the following format:
86 | {wavname}; A:{uttr-accuracy}; F:{uttr-fluency}; P:{uttr-prosodic}; T:{uttr-total}; Valid:{whether_output_is_valid}; ASR_s:{ground-truth-sentence}; ASR_w:{asr-results}; w_a:{word-accuracy}; w_s:{word-stress}; w_t:{word-total}; alignment:{forced-alignment-result}
87 |
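
For reference, here is a minimal parsing sketch for one result line in this format (not part of the repo scripts; it assumes the transcript fields contain no extra ";" and follows the field order above):
```
with open('model_assessment_speechocean762_test_gtb.txt') as f:
    line = f.readline().strip()
parts = line.split(';')
wavname = parts[0]
uttr_accuracy = float(parts[1].split(':')[1])
uttr_fluency = float(parts[2].split(':')[1])
uttr_prosodic = float(parts[3].split(':')[1])
uttr_total = float(parts[4].split(':')[1])
is_valid = parts[5].split(':')[1] != 'F'  # 'F' marks utterances the model could not process
```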
88 |
89 | Get the assessment results in an open response scenario (using the ASRt result as the ground-truth transcript):
90 | ```
91 | python test_open.py --ckptdir model_assessment
92 | ```
93 | The results will be saved in "model_assessment_speechocean762_test_mb.txt".
94 |
95 |
96 | ### Step 4. Evaluation
97 | -----
98 | ### Closed response scenario
99 |
100 | Use "evaluation_speechocean_closed.py". Change the input path of the "get_prediction" function to the path of the txt file generated in Step 3.
101 |
102 | Note:
103 | - Since Whisper might give different recognition results for the same utterance, the performance scores will differ slightly across runs.
104 | - For utterances that MultiPA fails to process, the lowest scores in the training data are used (i.e., accuracy = 1, fluency = 0, prosodic = 0, total = 0, completeness = 0, word_accuracy = 0, word_stress = 5, and word_total = 1).
105 | - Scores higher than 10 are clipped to 10.
106 | - The scores in the paper are the average of five models trained with different random seeds.
107 |
108 | Closed response performance (PCC):
109 | | sen-accuracy | sen-fluency | sen-prosody | sen-total | word-accuracy | word-stress | word-total |
110 | |---------------|--------------|---------------|-------------|---------------|-------------|------------|
111 | | ~0.73 | ~0.79 | ~0.78 | ~0.76 |~0.52 |~0.19 | ~0.53 |
112 |
113 | ### Completeness assessment metric
114 | According to previous studies, the completeness score is difficult to learn with a neural network. Therefore, we propose a forced-alignment-based method to assess completeness instead of predicting the completeness score with the main model. The completeness score reflects whether an L2 learner pronounces all words in the target sentence without omissions. The idea is that if the L2 learner misses a word, the forced-alignment model will have difficulty aligning that word to the speech signal, resulting in a shorter duration for the missed word than for properly pronounced words. A predefined duration threshold, selected based on empirical experiments, then classifies words as complete or incomplete. Finally, the completeness score is the ratio of complete words to the total number of words in the transcript.
115 |
116 | We analyze the word duration distribution from the forced-alignment model. In this experiment, complete words are drawn from speechocean762 utterances with full completeness scores, whereas incomplete words are simulated by randomly inserting an extra word into the original sentence. Complete words follow a nearly normal distribution with a mean of about 0.38 seconds. In contrast, incomplete words exhibit a right-skewed distribution with a mean of approximately 0.075 seconds. Next, we assess how well various duration thresholds distinguish complete from incomplete words. A threshold of roughly 0.07 seconds yields the highest F1 score (approximately 85%), indicating that the forced-alignment model can effectively detect missing words in a target sentence.
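
For illustration, a minimal sketch of this scoring rule (a hypothetical helper, not a function in this repo), assuming a word-level forced-alignment output of (start, end, word) tuples and the ~0.07-second threshold mentioned above:
```
def completeness_score(word_alignment, min_duration=0.07):
    # word_alignment: list of (start_sec, end_sec, word) tuples from the forced aligner.
    # Words whose aligned duration falls below min_duration are treated as omitted.
    if not word_alignment:
        return 0.0
    complete = sum(1 for start, end, _ in word_alignment if (end - start) >= min_duration)
    return complete / len(word_alignment)
```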
117 |
118 |
119 | -----
120 | ### Open response scenario
121 |
122 | Use "evaluation_speechocean_open.py".
123 | (1) Calculate and save the alignment information of the ground-truth transcript using get_gt_alignment.py. Change the directory path on line 163 accordingly.
124 | (2) Change the input path of the "get_prediction" function to the path of the txt file generated in Step 3.
125 | Note:
126 | - The evaluation of the word-level assessment results differs from the closed response scenario because of a potential mismatch between the ground-truth labels and the predicted scores (the ground-truth labels are aligned with the ground-truth words, whereas the predicted word-level scores are aligned with the ASR-recognized words).
127 | - Since Whisper might give different recognition results for the same utterance, the performance scores will differ slightly across runs.
128 |
129 | | sen-accuracy | sen-fluency | sen-prosody | sen-total | word-accuracy | word-stress | word-total |
130 | |---------------|--------------|---------------|-------------|---------------|-------------|------------|
131 | | ~0.70 | ~0.77 | ~0.76 | ~0.73 |~0.42 |~0.24 | ~0.44 |
132 |
133 |
134 |
135 | ## Test on your data
136 |
137 | ```
138 | python api.py --inputdir /path/to/your/wav/dir --ckptdir model_assessment
139 | ```
140 | Note:
141 | - This API works in the open response scenario. Please replace "sen_asr_s" with the target sentence if you want to test in the closed response scenario.
142 | - One limitation of MultiPA is its difficulty with long utterances. First, MultiPA might fail to process a long utterance due to a GPU out-of-memory issue. In addition, its generalization capability might be limited because it was trained on utterances shorter than 20 seconds. Therefore, an additional audio segmentation step is recommended when applying MultiPA to long utterances. In api.py, we implement a simple segmentation method based on Whisper's results: if a wave file is longer than 15 seconds, the model works on the Whisper segments and merges (averages) the results instead of processing the entire wave file at once (see the sketch below).
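
A rough sketch of the segment-then-average idea (illustrative only; api.py implements its own version). Whisper's transcribe() returns per-segment timestamps that can be used to crop and score a long recording piece by piece:
```
import whisper

model = whisper.load_model("base.en")
result = model.transcribe("long_utterance.wav", fp16=False)

# Each segment carries start/end times; a long recording can be assessed
# segment by segment and the utterance-level scores averaged afterwards.
segments = [(seg["start"], seg["end"], seg["text"]) for seg in result["segments"]]
```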
143 |
144 | Pre-trained model:
145 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/1Kpm3BeEh6Rh7JZ5tatyHMUMipuo0RYds/view?usp=sharing)
146 |
147 | ## References
148 | The Charsiu.py, models.py, processors.py, and utils.py in this repo are revised from [Charsiu](https://github.com/lingjzhu/charsiu/tree/main).
149 | The major changes include:
150 | (1) returning the output embedding (the probability of all possible phones);
151 | (2) preventing the merging of durations of multiple identical words
152 | (e.g., transcript: very very -> return (s1, e1, very), (s2, e2, very) instead of (s1, e2, very)).
153 | -> However, in some cases, the model will still return only one alignment result, leading to a mismatch between the words in the input sentence and the aligned words.
154 |
155 | ## MultiPA data
156 |
157 | A pilot dataset for real-world open response scenario speech assessment.
158 | [Download](https://drive.google.com/drive/folders/1T1_xTcwPF94WtUU4XVId7bGYPvAvwCSZ?usp=sharing)
159 | Please submit a request for access; I will add your email to the share list.
160 |
161 | ### Acknowledgment
162 | This data was collected using the [Label Studio](https://labelstud.io/) Academic Program. We thank Label Studio for providing a platform that allows researchers to collect data easily.
163 |
164 | ## Citation
165 | Please cite our paper if you find this repository useful. Thanks!
166 |
167 | @inproceedings{chen2024multipa,
168 | title = {{MultiPA}: A Multi-task Speech Pronunciation Assessment Model for Open Response Scenarios},
169 | author = {Chen, Yu-Wen and Yu, Zhou and Hirschberg, Julia},
170 | booktitle = {Proc. INTERSPEECH 2024}
171 | }
172 |
--------------------------------------------------------------------------------
/evaluation_speechocean_open.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import scipy
4 | from sklearn.metrics import mean_squared_error
5 | import math
6 | import string
7 | import numpy as np
8 | import math
9 | import torch
10 |
11 | def pad_mismatch_sequence(the_gt_w_text, the_pred_w_text, the_pred_w_acc, the_pred_w_stress, the_pred_w_total):
12 | """
13 |     Sometimes the model will merge consecutive occurrences of the same word, e.g., "seven nine nine one" to "seven nine one".
14 |     In this case, the number of predicted scores will not be in line with the number of ground-truth words.
15 | Therefore, we duplicate the scores for the merged word.
16 | """
17 | padded_acc = []
18 | padded_stress = []
19 | padded_total = []
20 | asr_w_idx=0
21 |
22 | for gt_word in the_gt_w_text:
23 | if asr_w_idx>=len(the_pred_w_text):
24 | padded_acc.append(the_pred_w_acc[asr_w_idx-1])
25 | padded_stress.append(the_pred_w_stress[asr_w_idx-1])
26 | padded_total.append(the_pred_w_total[asr_w_idx-1])
27 | break
28 |
29 | if gt_word == the_pred_w_text[asr_w_idx]:
30 | padded_acc.append(the_pred_w_acc[asr_w_idx])
31 | padded_stress.append(the_pred_w_stress[asr_w_idx])
32 | padded_total.append(the_pred_w_total[asr_w_idx])
33 | asr_w_idx+=1
34 | else:
35 | padded_acc.append(the_pred_w_acc[asr_w_idx-1])
36 | padded_stress.append(the_pred_w_stress[asr_w_idx-1])
37 | padded_total.append(the_pred_w_total[asr_w_idx-1])
38 |
39 | return padded_acc, padded_stress, padded_total
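# Example (illustrative) for pad_mismatch_sequence: if the ground-truth words are
# ["seven", "nine", "nine", "one"] and the aligner merged the repeated word so the
# predicted words are ["seven", "nine", "one"] with scores [a, b, c], the padded
# scores become [a, b, b, c] -- the merged word's score is reused.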
40 |
41 | def align_two_sentences(result_gt, result_asr):
42 |
43 | asr_wordidx_list = []
44 | for _, gt_value in enumerate(result_gt):
45 | gt_start = gt_value[0]
46 | gt_end = gt_value[1]
47 | asr_wordidx = []
48 | for asr_idx, asr_value in enumerate(result_asr):
49 | asr_start = asr_value[0]
50 | asr_end = asr_value[1]
51 | if gt_end <= asr_start:
52 | break
53 | if gt_start >= asr_end:
54 | continue
55 | if max(gt_start, asr_start) <= min(gt_end, asr_end):
56 | asr_wordidx.append(asr_idx)
57 |
58 | asr_wordidx_list.append(asr_wordidx)
59 |
60 | return asr_wordidx_list
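# Example (illustrative) for align_two_sentences: if a ground-truth word spans
# 0.50-0.90 s and two ASR words span 0.45-0.60 s and 0.60-1.00 s, both ASR indices
# are collected for that ground-truth word because their time intervals overlap.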
61 |
62 | def print_result(pred, gt):
63 | mse = mean_squared_error(pred, gt)
64 | corr, _ = scipy.stats.pearsonr(pred, gt)
65 | spearman, _ = scipy.stats.spearmanr(pred, gt)
66 | #print('mse:', mse)
67 | #print('corr:', round(corr,4))
68 | #print('srcc:', round(spearman,4))
69 | print(round(corr,4))
70 |
71 |
72 | def align_two_sentences(result_gt, result_asr):
73 |
74 | asr_wordidx_list = []
75 | for _, gt_value in enumerate(result_gt):
76 | gt_start = gt_value[0]
77 | gt_end = gt_value[1]
78 | asr_wordidx = []
79 | for asr_idx, asr_value in enumerate(result_asr):
80 | asr_start = asr_value[0]
81 | asr_end = asr_value[1]
82 | if gt_end <= asr_start:
83 | break
84 | if gt_start >= asr_end:
85 | continue
86 | if max(gt_start, asr_start) <= min(gt_end, asr_end):
87 | asr_wordidx.append(asr_idx)
88 |
89 | asr_wordidx_list.append(asr_wordidx)
90 |
91 | return asr_wordidx_list
92 |
93 | f = open('./speechocean762/resource/scores.json') # path to speechocean score json
94 | data = json.load(f)
95 |
96 | test_file = open('./speechocean762/test/wav.scp','r').read().splitlines() # path to speechocean test list
97 | test_data = {}
98 | for line in test_file:
99 | wavidx = line.split('\t')[0]
100 | test_data[wavidx] = data[wavidx]
101 |
102 | def get_prediction(path):
103 | invalid = 0
104 | prediction = open(path,'r').read().splitlines()
105 |
106 | print('# of prediction:', len(prediction))
107 | result_word = {}
108 | result_uttr = {}
109 | for sample in prediction:
110 |
111 | parts = sample.split(';')
112 | wavidx = parts[0].replace('.wav','')
113 | valid = parts[5].split(':')[1]
114 | if valid=='F':
115 | invalid+=1
116 | accuracy = 1.0
117 | fluency = 0.0
118 | prosodic = 0.0
119 | total = 0.0
120 | result_word[wavidx]={}
121 | result_word[wavidx]['word_accuracy'] = 0
122 | result_word[wavidx]['word_stress'] = 5
123 | result_word[wavidx]['word_total'] = 1
124 | result_word[wavidx]['text'] = ''
125 | result_word[wavidx]['alignment'] = None
126 | else:
127 | accuracy = float(parts[1].split(':')[1])
128 | fluency = float(parts[2].split(':')[1])
129 | prosodic = float(parts[3].split(':')[1])
130 | total = float(parts[4].split(':')[1])
131 |
132 | w_a = eval(parts[8].split(':')[1])
133 | w_s = eval(parts[9].split(':')[1])
134 | w_t = eval(parts[10].split(':')[1])
135 | if isinstance(w_a , float):
136 | w_a = [w_a]
137 | w_s = [w_s]
138 | w_t = [w_t]
139 | w_a = [10 if x > 10 else x for x in w_a]
140 | w_s = [10 if x > 10 else x for x in w_s]
141 | w_t = [10 if x > 10 else x for x in w_t]
142 | result_word[wavidx]={}
143 | result_word[wavidx]['word_accuracy'] = w_a
144 | result_word[wavidx]['word_stress'] = w_s
145 | result_word[wavidx]['word_total'] = w_t
146 | result_word[wavidx]['text'] = eval(parts[-1].split(':')[1])
147 | result_word[wavidx]['text'] = [word[-1] for word in result_word[wavidx]['text']]
148 | result_word[wavidx]['alignment'] = eval(parts[-1].split(':')[1])
149 |
150 | result_uttr[wavidx]={}
151 | result_uttr[wavidx]['accuracy'] = accuracy
152 | result_uttr[wavidx]['fluency'] = fluency
153 | result_uttr[wavidx]['prosodic'] = prosodic
154 | result_uttr[wavidx]['total'] = total
155 |
156 |
157 | print('# of invalid sentence:', invalid)
158 | return result_word, result_uttr
159 |
160 |
161 | def calculate_performance(result_word, result_uttr, wav_idx_word, wav_idx_uttr):
162 |
163 | gt_alignment_dir = '../gt_alignment_test'
164 | gt_A = []
165 | gt_F = []
166 | gt_P = []
167 | gt_T = []
168 |
169 | pred_A = []
170 | pred_F = []
171 | pred_P = []
172 | pred_T = []
173 |
174 | for wavidx in wav_idx_uttr:
175 | gt_A.append(test_data[wavidx]['accuracy'])
176 | pred_A.append(result_uttr[wavidx]['accuracy'])
177 | gt_F.append(test_data[wavidx]['fluency'])
178 | pred_F.append(result_uttr[wavidx]['fluency'])
179 | gt_P.append(test_data[wavidx]['prosodic'])
180 | pred_P.append(result_uttr[wavidx]['prosodic'])
181 | gt_T.append(test_data[wavidx]['total'])
182 | pred_T.append(result_uttr[wavidx]['total'])
183 |
184 | print('number of utterance', len(pred_A))
185 | #print('accuracy')
186 | print_result(pred_A, gt_A)
187 | #print('fluency')
188 | print_result(pred_F, gt_F)
189 | #print('prosodic')
190 | print_result(pred_P, gt_P)
191 | #print('total')
192 | print_result(pred_T, gt_T)
193 |
194 | gt_w_acc = []
195 | gt_w_stress = []
196 | gt_w_total = []
197 | pred_w_acc = []
198 | pred_w_stress = []
199 | pred_w_total = []
200 | count_sen = 0
201 | for wavidx in wav_idx_word:
202 | the_gt_w_acc = []
203 | the_gt_w_stress = []
204 | the_gt_w_total = []
205 | the_gt_w_text = []
206 |
207 | for word in test_data[wavidx]['words']:
208 | the_gt_w_acc.append(int(word['accuracy']))
209 | the_gt_w_stress.append(int(word['stress']))
210 | the_gt_w_total.append(int(word['total']))
211 | the_gt_w_text.append(word['text'].lower())
212 |
213 | the_pred_w_acc = result_word[wavidx]['word_accuracy']
214 | the_pred_w_stress = result_word[wavidx]['word_stress']
215 | the_pred_w_total = result_word[wavidx]['word_total']
216 |
217 | if result_word[wavidx]['alignment'] is None: #model fails
218 | gt_len = len(the_gt_w_text)
219 | the_pred_w_acc = [the_pred_w_acc for _ in range(gt_len)]
220 | the_pred_w_stress = [the_pred_w_stress for _ in range(gt_len)]
221 | the_pred_w_total = [the_pred_w_total for _ in range(gt_len)]
222 | gt_w_acc.extend(the_gt_w_acc)
223 | gt_w_stress.extend(the_gt_w_stress)
224 | gt_w_total.extend(the_gt_w_total)
225 | pred_w_acc.extend(the_pred_w_acc)
226 | pred_w_stress.extend(the_pred_w_stress)
227 | pred_w_total.extend(the_pred_w_total)
228 | continue
229 |
230 | pred_sen = ' '.join(result_word[wavidx]['text'])
231 | gt_sen = ' '.join(the_gt_w_text)
232 | if pred_sen!=gt_sen:
233 |
234 | pred_alignment = result_word[wavidx]['alignment']
235 | gt_alignment = torch.load(os.path.join(gt_alignment_dir, wavidx+'.pt'))
236 | align_result = align_two_sentences(pred_alignment, gt_alignment)
237 |
238 | the_gt_w_acc = np.asarray(the_gt_w_acc)
239 | the_gt_w_stress = np.asarray(the_gt_w_stress)
240 | the_gt_w_total = np.asarray(the_gt_w_total)
241 |
242 | align_gt_w_acc = []
243 | align_gt_w_stress = []
244 | align_gt_w_total = []
245 | for widxs in align_result:
246 | if len(widxs)!=0:
247 | align_gt_w_acc.append(np.mean(the_gt_w_acc[widxs]))
248 | align_gt_w_stress.append(np.mean(the_gt_w_stress[widxs]))
249 | align_gt_w_total.append(np.mean(the_gt_w_total[widxs]))
250 | else:
251 | align_gt_w_acc.append(0)
252 | align_gt_w_stress.append(5)
253 | align_gt_w_total.append(1)
254 |
255 | gt_w_acc.extend(align_gt_w_acc)
256 | gt_w_stress.extend(align_gt_w_stress)
257 | gt_w_total.extend(align_gt_w_total)
258 | pred_w_acc.extend(the_pred_w_acc)
259 | pred_w_stress.extend(the_pred_w_stress)
260 | pred_w_total.extend(the_pred_w_total)
261 |
262 | else: # prediction the same as the ground-truth
263 |
264 |             if len(the_gt_w_text) != len(result_word[wavidx]['text']): #case where forced alignment merged duplicated words
265 | the_pred_w_acc, the_pred_w_stress, the_pred_w_total = pad_mismatch_sequence(the_gt_w_text, result_word[wavidx]['text'], the_pred_w_acc, the_pred_w_stress, the_pred_w_total)
266 |
267 | gt_w_acc.extend(the_gt_w_acc)
268 | gt_w_stress.extend(the_gt_w_stress)
269 | gt_w_total.extend(the_gt_w_total)
270 | pred_w_acc.extend(the_pred_w_acc)
271 | pred_w_stress.extend(the_pred_w_stress)
272 | pred_w_total.extend(the_pred_w_total)
273 |
274 |
275 | count_sen+=1
276 | #print(pred_sen, gt_sen)
277 |
278 | print('number of sentences for word prediction:', count_sen, "# of words:", len(pred_w_acc))
279 | #print('word acc')
280 | print_result(pred_w_acc, gt_w_acc)
281 | #print('word stress')
282 | print_result(pred_w_stress, gt_w_stress)
283 | #print('word total')
284 | print_result(pred_w_total, gt_w_total)
285 |
286 |
287 |
288 | resultA_word, resultA_uttr = get_prediction('./Results/model_assessment_val9_r2_speechocean762_test_mb.txt')
289 | calculate_performance(resultA_word, resultA_uttr, list(resultA_word.keys()),list(resultA_uttr.keys()))
--------------------------------------------------------------------------------
/model_assessment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import fairseq
4 | import random
5 | import torch
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from tqdm import tqdm
11 | from torch.utils.data import DataLoader
12 | from dataloader import MyDataset
13 |
14 |
15 | random.seed(1984)
16 |
17 |
18 | class PronunciationPredictor(nn.Module):
19 |
20 | def __init__(self, ssl_model, ssl_out_dim, text_out_dim):
21 | super(PronunciationPredictor, self).__init__()
22 | self.ssl_model = ssl_model
23 | self.ssl_features = ssl_out_dim # size of HuBERT embedding
24 | self.text_out_dim = text_out_dim # size of roberta embedding
25 | self.w_features_size = 10
26 | self.p_align_size = 9
27 | self.p_pred_size = 42
28 | self.phone_vector = 71*2
29 |
30 | hidden_word = self.ssl_features+self.text_out_dim*2+self.phone_vector+self.w_features_size+self.p_align_size+self.p_pred_size*2+1
31 | hidden_sen = 768+hidden_word
32 |
33 | self.output_accuracy = nn.Linear(hidden_sen, 1)
34 | self.output_fluency = nn.Linear(hidden_sen, 1)
35 | self.output_prosodic = nn.Linear(hidden_sen, 1)
36 | self.output_total = nn.Linear(hidden_sen, 1)
37 |
38 | self.fusion_layer = nn.TransformerEncoderLayer(d_model=hidden_word, nhead=30)
39 |
40 | self.fusion_layer_word = nn.TransformerEncoderLayer(d_model=hidden_word, nhead=30)
41 | self.w_feature_layer = nn.TransformerEncoderLayer(d_model=self.w_features_size, nhead=5)
42 | self.p_feature_layer_align = nn.TransformerEncoderLayer(d_model=self.p_align_size, nhead=3)
43 | self.p_feature_layer_pred_gt = nn.TransformerEncoderLayer(d_model=self.p_pred_size, nhead=7)
44 | self.p_feature_layer_pred_asr = nn.TransformerEncoderLayer(d_model=self.p_pred_size, nhead=7)
45 | self.p_feature_layer_1 = nn.TransformerEncoderLayer(d_model=self.p_align_size+self.p_pred_size*2, nhead=3)
46 | self.p_feature_layer_2 = nn.TransformerEncoderLayer(d_model=self.p_align_size+self.p_pred_size*2, nhead=3)
47 | self.phonevector_layer = nn.TransformerEncoderLayer(d_model=self.phone_vector, nhead=2)
48 |
49 | self.word_acc = nn.Conv1d(hidden_word, 1, kernel_size=1)
50 | self.word_stress = nn.Conv1d(hidden_word, 1, kernel_size=1)
51 | self.word_total = nn.Conv1d(hidden_word, 1, kernel_size=1)
52 |
53 | def forward(self, wav, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit):
54 |
55 | res = self.ssl_model(wav, mask=False, features_only=True)
56 | wav_embedding_raw = res['x']
57 |
58 |         ### align word-level features to the waveform
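        # HuBERT Base produces one frame per 320 input samples (~20 ms at 16 kHz),
        # so the sample-level timestamps below are converted to frame indices with // 320.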
59 | batch_size = gt_word_embed.shape[0]
60 | wav_aligned = torch.zeros((gt_word_embed.shape[0], gt_word_embed.shape[1], self.ssl_features)).cuda()
61 | for b_idx in range(batch_size):
62 | for w_idx in range(len(timesplit[b_idx])):
63 | start_point = timesplit[b_idx][w_idx][0] // 320
64 | end_point = timesplit[b_idx][w_idx][1] // 320
65 |                 if (end_point - start_point)==0: # avoid predicting NaN when no wav frames are aligned to the word
66 | the_word = wav_embedding_raw[b_idx, start_point:start_point+1, :]
67 | else:
68 | the_word = wav_embedding_raw[b_idx, start_point:end_point, :]
69 | aligned_wav_embed = the_word.mean(dim=0)
70 | wav_aligned[b_idx, w_idx, :] = aligned_wav_embed
71 |
72 |
73 | features_w = self.w_feature_layer(features_w)
74 | features_p = self.p_feature_layer_1(features_p)
75 | features_p[:,:,:self.p_align_size] = self.p_feature_layer_align(features_p[:,:,:self.p_align_size])
76 | features_p[:,:,self.p_align_size:self.p_align_size+self.p_pred_size] = self.p_feature_layer_pred_gt(features_p[:,:,self.p_align_size:self.p_align_size+self.p_pred_size])
77 | features_p[:,:,self.p_align_size+self.p_pred_size:] = self.p_feature_layer_pred_asr(features_p[:,:,self.p_align_size+self.p_pred_size:])
78 |
79 | # align phone-level features to word-level features
80 | features_p_aligned = torch.zeros((gt_word_embed.shape[0], gt_word_embed.shape[1], self.p_align_size+self.p_pred_size*2)).cuda()
81 | for b_idx in range(batch_size):
82 | for w_idx, p_list in enumerate(word_phone_map[b_idx]):
83 | features_p_aligned[b_idx, w_idx, :] = features_p[b_idx,p_list,:].mean(dim=0)
84 |
85 | features_p_aligned = self.p_feature_layer_2(features_p_aligned)
86 | phonevector = self.phonevector_layer(phonevector)
87 | fusion = torch.cat([wav_aligned, gt_word_embed, asr_word_embed, features_w, features_p_aligned, phonevector], dim=2)
88 | fusion = F.pad(fusion, (0, 1), mode='constant') # expand one dimension because original feature size is a prime number
89 |
90 | fusion_word = self.fusion_layer_word(fusion)
91 | fusion_word = fusion_word.transpose(1, 2)
92 | output_w_acc = self.word_acc(fusion_word)
93 | output_w_stress = self.word_stress(fusion_word)
94 | output_w_total = self.word_total(fusion_word)
95 | output_w_acc = output_w_acc.transpose(1,2).squeeze(2)
96 | output_w_stress = output_w_stress.transpose(1,2).squeeze(2)
97 | output_w_total = output_w_total.transpose(1,2).squeeze(2)
98 |
99 |
100 | fusion = self.fusion_layer(fusion)
101 | uttr_word = torch.mean(fusion, 1)
102 | wav_embedding = torch.mean(wav_embedding_raw, 1)
103 | uttr = torch.cat([wav_embedding, uttr_word], dim=1)
104 | output_A = self.output_accuracy(uttr)
105 | output_F = self.output_fluency(uttr)
106 | output_P = self.output_prosodic(uttr)
107 | output_T = self.output_total(uttr)
108 |
109 |
110 | return output_A.squeeze(1), output_F.squeeze(1), output_P.squeeze(1), output_T.squeeze(1), output_w_acc, output_w_stress, output_w_total
111 |
112 |
113 |
114 | def main():
115 |
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--datadir', default='./speechocean762/wav', type=str, help='Path to root data directory')
118 | parser.add_argument('--txtfiledir', default='./speechocean762', type=str, help='Path to training txt directory')
119 | parser.add_argument('--fairseq_base_model', default='./fairseq_hubert/hubert_base_ls960.pt', type=str, help='Path to pretrained fairseq base model')
120 | parser.add_argument('--finetune_from_checkpoint', type=str, required=False, help='Path to the checkpoint to finetune from')
121 | parser.add_argument('--outdir', type=str, required=False, default='model_assessment', help='Output directory for trained checkpoints')
122 |
123 | args = parser.parse_args()
124 |
125 | cp_path = args.fairseq_base_model
126 | datadir = args.datadir
127 | ckptdir = args.outdir
128 | txtfiledir = args.txtfiledir
129 | my_checkpoint_dir = args.finetune_from_checkpoint
130 |
131 | if not os.path.exists(ckptdir):
132 | os.makedirs(os.path.join(ckptdir,'PRO'))
133 |
134 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135 | print('DEVICE: ' + str(device))
136 |
137 | wavdir = os.path.join(datadir, '')
138 | input_list = open(os.path.join(txtfiledir, 'speechocean762_train.txt'),'r').read().splitlines()
139 | random.shuffle(input_list)
140 |
141 | trainlist = input_list[:int(len(input_list)*0.9)]
142 | validlist = input_list[int(len(input_list)*0.9):]
143 |
144 | SSL_OUT_DIM = 768
145 | TEXT_OUT_DIM = 768
146 |
147 | model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
148 | ssl_model = model[0]
149 | trainset = MyDataset(wavdir, trainlist)
150 | trainloader = DataLoader(trainset, batch_size=2, shuffle=True, num_workers=2, collate_fn=trainset.collate_fn)
151 | validset = MyDataset(wavdir, validlist)
152 | validloader = DataLoader(validset, batch_size=2, shuffle=True, num_workers=2, collate_fn=validset.collate_fn)
153 |
154 | net = PronunciationPredictor(ssl_model, SSL_OUT_DIM, TEXT_OUT_DIM)
155 | net = net.to(device)
156 |
157 | if my_checkpoint_dir != None:
158 | net.load_state_dict(torch.load(os.path.join(my_checkpoint_dir,'PRO','best')))
159 |
160 | criterion = nn.MSELoss()
161 | optimizer = optim.SGD(net.parameters(), lr=0.00005, momentum=0.7)
162 |
163 | PREV_VAL_LOSS=9999999999
164 | orig_patience=2
165 | patience=orig_patience
166 |
167 | for epoch in range(1,10):
168 | STEPS=0
169 | net.train()
170 | running_loss = 0.0
171 |
172 | for i, data in enumerate(tqdm(trainloader), 0):
173 |
174 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, _, wavname = data
175 | wav = wav.to(device)
176 | s_A = s_A.to(device)
177 | s_F = s_F.to(device)
178 | s_P = s_P.to(device)
179 | s_T = s_T.to(device)
180 | w_s_acc = w_s_acc.to(device)
181 | w_s_stress = w_s_stress.to(device)
182 | w_s_total = w_s_total.to(device)
183 | asr_word_embed = asr_word_embed.to(device)
184 | gt_word_embed = gt_word_embed.to(device)
185 | features_w = features_w.to(device)
186 | features_p = features_p.to(device)
187 | phonevector = phonevector.to(device)
188 |
189 | wav_input = wav.squeeze(1)
190 | optimizer.zero_grad()
191 | output_A, output_F, output_P, output_T, output_w_acc, output_w_stress, output_w_total = net(wav_input, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit)
192 | if output_w_acc.shape[1]!=w_s_acc.shape[1]:
193 | continue
194 | loss_A = criterion(output_A, s_A)
195 | loss_F = criterion(output_F, s_F)
196 | loss_P = criterion(output_P, s_P)
197 | loss_T = criterion(output_T, s_T)
198 | loss_wa = criterion(output_w_acc, w_s_acc)
199 | loss_ws = criterion(output_w_stress, w_s_stress)
200 | loss_wt = criterion(output_w_total, w_s_total)
201 | loss = loss_A + loss_F + loss_P + loss_T + loss_wa + loss_ws + loss_wt
202 |
203 | loss.backward()
204 | optimizer.step()
205 | STEPS += 1
206 | running_loss += loss.item()
207 |
208 | print('EPOCH: ' + str(epoch))
209 | print('AVG EPOCH TRAIN LOSS: ' + str(running_loss / STEPS))
210 |
211 |
212 | ## validation
213 | VALSTEPS=0
214 | epoch_val_loss = 0.0
215 | net.eval()
216 | ## clear memory to avoid OOM
217 | with torch.cuda.device(device):
218 | torch.cuda.empty_cache()
219 | torch.cuda.memory_allocated()
220 | torch.cuda.synchronize()
221 |
222 | for i, data in enumerate(validloader, 0):
223 | VALSTEPS+=1
224 |
225 | wav, s_A, s_F, s_P, s_T, w_s_acc, w_s_stress, w_s_total, timesplit, asr_word_embed, gt_word_embed, features_w, features_p, phonevector, word_phone_map, _, _ = data
226 | wav = wav.to(device)
227 | s_A = s_A.to(device)
228 | s_F = s_F.to(device)
229 | s_P = s_P.to(device)
230 | s_T = s_T.to(device)
231 | w_s_acc = w_s_acc.to(device)
232 | w_s_stress = w_s_stress.to(device)
233 | w_s_total = w_s_total.to(device)
234 | asr_word_embed = asr_word_embed.to(device)
235 | gt_word_embed = gt_word_embed.to(device)
236 | features_w = features_w.to(device)
237 | features_p = features_p.to(device)
238 | phonevector = phonevector.to(device)
239 |
240 | wav_input = wav.squeeze(1)
241 |
242 | with torch.no_grad():
243 | output_A, output_F, output_P, output_T, output_w_acc, output_w_stress, output_w_total = net(wav_input, asr_word_embed, gt_word_embed, features_p, features_w, phonevector, word_phone_map, timesplit)
244 | if output_w_acc.shape[1]!=w_s_acc.shape[1]:
245 | continue
246 | loss_A = criterion(output_A, s_A)
247 | loss_F = criterion(output_F, s_F)
248 | loss_P = criterion(output_P, s_P)
249 | loss_T = criterion(output_T, s_T)
250 | loss_wa = criterion(output_w_acc, w_s_acc)
251 | loss_ws = criterion(output_w_stress, w_s_stress)
252 | loss_wt = criterion(output_w_total, w_s_total)
253 | loss = loss_A + loss_F + loss_P + loss_T + loss_wa + loss_ws + loss_wt
254 |
255 | epoch_val_loss += loss.item()
256 |
257 | avg_val_loss=epoch_val_loss/VALSTEPS
258 | print('EPOCH VAL LOSS: ' + str(avg_val_loss))
259 | if avg_val_loss < PREV_VAL_LOSS:
260 | print('Loss has decreased')
261 | PREV_VAL_LOSS=avg_val_loss
262 | torch.save(net.state_dict(), os.path.join(ckptdir,'PRO','best'))
263 | patience = orig_patience
264 | else:
265 | patience-=1
266 | if patience == 0:
267 | print('loss has not decreased for ' + str(orig_patience) + ' epochs; early stopping at epoch ' + str(epoch))
268 | break
269 |
270 | print('Finished Training of Pronunciation Assessment Model')
271 |
272 | if __name__ == '__main__':
273 | main()
274 |
--------------------------------------------------------------------------------
/processors.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import re
6 | import numpy as np
7 | from itertools import groupby, chain, accumulate
8 | import soundfile as sf
9 | import librosa.core
10 | import unicodedata
11 | from builtins import str as unicode
12 | from nltk.tokenize import TweetTokenizer
13 | word_tokenize = TweetTokenizer().tokenize
14 |
15 | from g2p_en import G2p
16 | from g2p_en.expand import normalize_numbers
17 | from g2pM import G2pM
18 | from transformers import Wav2Vec2CTCTokenizer,Wav2Vec2FeatureExtractor, Wav2Vec2Processor
19 |
20 |
21 |
22 | class CharsiuPreprocessor:
23 |
24 | def __init__(self):
25 | pass
26 |
27 |
28 | def get_phones_and_words(self):
29 | raise NotImplementedError
30 |
31 |
32 | def get_phone_ids(self):
33 | raise NotImplementedError
34 |
35 |
36 | def mapping_phone2id(self,phone):
37 | '''
38 | Convert a phone to a numerical id
39 |
40 | Parameters
41 | ----------
42 | phone : str
43 | A phonetic symbol
44 |
45 | Returns
46 | -------
47 | int
48 | A one-hot id for the input phone
49 |
50 | '''
51 | return self.processor.tokenizer.convert_tokens_to_ids(phone)
52 |
53 | def mapping_id2phone(self,idx):
54 | '''
55 | Convert a numerical id to a phone
56 |
57 | Parameters
58 | ----------
59 | idx : int
60 | A one-hot id for a phone
61 |
62 | Returns
63 | -------
64 | str
65 | A phonetic symbol
66 |
67 | '''
68 |
69 | return self.processor.tokenizer.convert_ids_to_tokens(idx)
70 |
71 |
72 | def audio_preprocess(self,audio,sr=16000):
73 | '''
74 | Load and normalize audio
75 | If the sampling rate is incompatible with models, the input audio will be resampled.
76 |
77 | Parameters
78 | ----------
79 |         audio : str or np.ndarray
80 |             The path to the audio file, or an array of audio samples
81 | sr : int, optional
82 | Audio sampling rate, either 16000 or 32000. The default is 16000.
83 |
84 | Returns
85 | -------
86 | torch.Tensor [(n,)]
87 |             Audio samples as a one-dimensional torch tensor
88 |
89 | '''
90 | if type(audio)==str:
91 | if sr == 16000:
92 | features,fs = sf.read(audio)
93 | assert fs == 16000
94 | else:
95 | features, _ = librosa.core.load(audio,sr=sr)
96 | elif isinstance(audio, np.ndarray):
97 | features = audio
98 | else:
99 | raise Exception('The input must be a path or a numpy array!')
100 | return self.processor(features, sampling_rate=16000,return_tensors='pt').input_values.squeeze()
101 |
102 | '''
103 | English g2p processor
104 | '''
105 | class CharsiuPreprocessor_en(CharsiuPreprocessor):
106 |
107 | def __init__(self):
108 |
109 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_en_cmu')
110 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
111 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
112 | self.g2p = G2p()
113 | self.sil = '[SIL]'
114 | self.sil_idx = self.mapping_phone2id(self.sil)
115 | # self.punctuation = set('.,!?')
116 | self.punctuation = set()
117 |
118 | def get_phones_and_words(self,sen):
119 | '''
120 | Convert texts to phone sequence
121 |
122 | Parameters
123 | ----------
124 | sen : str
125 | A str of input sentence
126 |
127 | Returns
128 | -------
129 |         aligned_phones : list
130 |             A list of phone sequences (with stress marks), one tuple per word
131 |         aligned_words : list
132 |             A list of words aligned with the phone sequences
133 | 
134 | 
135 | 
136 |
137 | '''
138 |
139 | phones = self.g2p(sen)
140 | words = self._get_words(sen)
141 |
142 | phones = list(tuple(g) for k,g in groupby(phones, key=lambda x: x != ' ') if k)
143 |
144 | aligned_phones = []
145 | aligned_words = []
146 | for p,w in zip(phones,words):
147 | if re.search(r'\w+\d?',p[0]):
148 | aligned_phones.append(p)
149 | aligned_words.append(w)
150 | elif p in self.punctuation:
151 | aligned_phones.append((self.sil,))
152 | aligned_words.append(self.sil)
153 |
154 | assert len(aligned_words) == len(aligned_phones)
155 |
156 | return aligned_phones, aligned_words
157 |
158 | assert len(words) == len(phones)
159 |
160 | return phones, words
161 |
162 |
163 |
164 | def get_phone_ids(self,phones,append_silence=True):
165 | '''
166 | Convert phone sequence to ids
167 |
168 | Parameters
169 | ----------
170 | phones : list
171 | A list of phone sequence
172 | append_silence : bool, optional
173 | Whether silence is appended at the beginning and the end of the sequence.
174 | The default is True.
175 |
176 | Returns
177 | -------
178 | ids: list
179 | A list of one-hot representations of phones
180 |
181 | '''
182 | phones = list(chain.from_iterable(phones))
183 | ids = [self.mapping_phone2id(re.sub(r'\d','',p)) for p in phones]
184 |
185 | # append silence at the beginning and the end
186 | if append_silence:
187 | if ids[0]!=self.sil_idx:
188 | ids = [self.sil_idx]+ids
189 | if ids[-1]!=self.sil_idx:
190 | ids.append(self.sil_idx)
191 | return ids
192 |
193 |
194 |
195 | def _get_words(self,text):
196 | '''
197 | from G2P_en
198 | https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py
199 |
200 | Parameters
201 | ----------
202 |         text : str
203 |             An input sentence.
204 | 
205 |         Returns
206 |         -------
207 |         words : list
208 |             A list of tokenized words.
209 |
210 | '''
211 |
212 | text = unicode(text)
213 | text = normalize_numbers(text)
214 | text = ''.join(char for char in unicodedata.normalize('NFD', text)
215 | if unicodedata.category(char) != 'Mn') # Strip accents
216 | text = text.lower()
217 | text = re.sub("[^ a-z'.,?!\-]", "", text)
218 | text = text.replace("i.e.", "that is")
219 | text = text.replace("e.g.", "for example")
220 |
221 | # tokenization
222 | words = word_tokenize(text)
223 |
224 | return words
225 |
226 | def align_words(self, preds, phones, words, return_map=False):
227 |
228 | # add idx to the word so that it won't merge when the input sentence has duplicated words
229 | words_rep = [w+'_'+str(idx) for idx, (ph,w) in enumerate(zip(phones,words)) for p in ph]
230 | phones_rep = [re.sub(r'\d','',p) for ph,w in zip(phones,words) for p in ph]
231 | assert len(words_rep)==len(phones_rep)
232 | # match each phone to its word
233 | word_dur = []
234 | count = 0
235 | for dur in preds:
236 | if dur[-1] == '[SIL]':
237 | word_dur.append((dur,'[SIL]'))
238 | else:
239 | while dur[-1] != phones_rep[count]:
240 | count += 1
241 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word)
242 |
243 | # merge phone-to-word alignment to derive word duration
244 | words = []
245 | word_phone_map = [] #word-to-phone index: word_phone_map[i] = num of phone of word i
246 | for key, group in groupby(word_dur, lambda x: x[-1]):
247 | group = list(group)
248 | entry = (group[0][0][0],group[-1][0][1],key.split('_')[0])
249 | words.append(entry)
250 | word_phone_map.append(len(group))
251 | if not return_map:
252 | return words
253 | else:
254 |
255 | idx=0
256 | word_phone_map_converted = []
257 | for length in(word_phone_map):
258 | the_word = []
259 | for i in range(idx,idx+length):
260 | the_word.append(i)
261 | word_phone_map_converted.append(the_word)
262 | idx+=length
263 |
264 | return words, word_phone_map_converted
265 |
266 |
267 | '''
268 | Mandarin g2p processor
269 | '''
270 |
271 |
272 | class CharsiuPreprocessor_zh(CharsiuPreprocessor_en):
273 |
274 | def __init__(self):
275 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_zh_pinyin')
276 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
277 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
278 | self.g2p = G2pM()
279 | self.sil = "[SIL]"
280 | self.sil_idx = self.mapping_phone2id(self.sil)
281 | #self.punctuation = set('.,!?。,!?、')
282 | self.punctuation = set()
283 | # Pinyin tables
284 | self.consonant_list = set(['b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k',
285 | 'h', 'j', 'q', 'x', 'zh', 'ch', 'sh', 'r', 'z',
286 | 'c', 's'])
287 |
288 | self.transform_dict = {'ju':'jv', 'qu':'qv', 'xu':'xv','jue':'jve',
289 | 'que':'qve', 'xue':'xve','quan':'qvan',
290 | 'xuan':'xvan','juan':'jvan',
291 | 'qun':'qvn','xun':'xvn', 'jun':'jvn',
292 | 'yuan':'van', 'yue':'ve', 'yun':'vn',
293 | 'you':'iou', 'yan':'ian', 'yin':'in',
294 | 'wa':'ua', 'wo':'uo', 'wai':'uai',
295 | 'weng':'ueng', 'wang':'uang','wu':'u',
296 | 'yu':'v','yi':'i','yo':'io','ya':'ia', 'ye':'ie',
297 | 'yao':'iao','yang':'iang', 'ying':'ing', 'yong':'iong',
298 | 'yvan':'van', 'yve':'ve', 'yvn':'vn',
299 | 'wa':'ua', 'wo':'uo', 'wai':'uai',
300 | 'wei':'ui', 'wan':'uan', 'wen':'un',
301 | 'weng':'ueng', 'wang':'uang','yv':'v',
302 | 'wuen':'un','wuo':'uo','wuang':'uang',
303 | 'wuan':'uan','wua':'ua','wuai':'uai',
304 | 'zhi':'zhiii','chi':'chiii','shi':'shiii',
305 | 'zi':'zii','ci':'cii','si':'sii'}
306 | self.er_mapping ={'er1':('e1','rr'),'er2':('e2','rr'),'er3':('e3','rr'),'er4':('e4','rr'),
307 | 'er5':('e5','rr'),'r5':('e5','rr')}
308 | self.rhyme_mapping = {'iu1':'iou1','iu2':'iou2','iu3':'iou3','iu4':'iou4','iu5':'iou5',
309 | 'u:e1':'ve1','u:e2':'ve2','u:e3':'ve3','u:e4':'ve4','u:e5':'ve5',
310 | 'u:1':'v1','u:2':'v2','u:3':'v3','u:4':'v4','u:5':'v5',
311 | 'ueng1':('u1','eng1'),'ueng2':('u2','eng2'),'ueng3':('u3','eng3'),
312 | 'ueng4':('u4','eng4'),'ueng5':('u5','eng5'),'io5':('i5','o5'),
313 | 'io4':('i4','o4'),'io1':('i1','o1')}
314 |
315 | def get_phones_and_words(self,sen):
316 | '''
317 | Convert texts to phone sequence
318 |
319 | Parameters
320 | ----------
321 | sen : str
322 | A str of input sentence
323 |
324 | Returns
325 | -------
326 |         aligned_phones : list
327 |             A list of phone (consonant/rhyme) tuples, one per character
328 |         aligned_words : list
329 |             A list of characters aligned with the phone tuples
330 | 
331 | 
332 | '''
333 |
334 | phones = self.g2p(sen)
335 |
336 | aligned_phones = []
337 | aligned_words = []
338 | for p,w in zip(phones,sen):
339 | if re.search(r'\w+:?\d',p):
340 | aligned_phones.append(self._separate_syllable(self.transform_dict.get(p[:-1],p[:-1])+p[-1]))
341 | aligned_words.append(w)
342 | elif p in self.punctuation:
343 | aligned_phones.append((self.sil,))
344 | aligned_words.append(self.sil)
345 |
346 | assert len(aligned_phones)==len(aligned_words)
347 | return aligned_phones, aligned_words
348 |
349 |
350 | def get_phone_ids(self,phones,append_silence=True):
351 | '''
352 | Convert phone sequence to ids
353 |
354 | Parameters
355 | ----------
356 | phones : list
357 | A list of phone sequence
358 | append_silence : bool, optional
359 | Whether silence is appended at the beginning and the end of the sequence.
360 | The default is True.
361 |
362 | Returns
363 | -------
364 | ids: list
365 | A list of one-hot representations of phones
366 |
367 | '''
368 | phones = list(chain.from_iterable(phones))
369 | ids = [self.mapping_phone2id(p) for p in phones]
370 |
371 | # append silence at the beginning and the end
372 | if append_silence:
373 | if ids[0]!=self.sil_idx:
374 | ids = [self.sil_idx]+ids
375 | if ids[-1]!=self.sil_idx:
376 | ids.append(self.sil_idx)
377 | return ids
378 |
379 |
380 | def _separate_syllable(self,syllable):
381 | """
382 |         Separate a syllable into consonant + vowel parts
383 | 
384 |         Parameters
385 |         ----------
386 |         syllable : str
387 |             A pinyin syllable ending with a tone digit.
388 | 
389 |         Returns
390 |         -------
391 |         tuple of str
392 |             The consonant and rhyme parts, or the syllable itself if it has no consonant.
393 |
394 | """
395 |
396 | assert syllable[-1].isdigit()
397 | if syllable == 'ri4':
398 | return ('r','iii4')
399 | if syllable[:-1] == 'ueng' or syllable[:-1] == 'io':
400 | return self.rhyme_mapping.get(syllable,syllable)
401 | if syllable in self.er_mapping.keys():
402 | return self.er_mapping[syllable]
403 | if syllable[0:2] in self.consonant_list:
404 | #return syllable[0:2].encode('utf-8'),syllable[2:].encode('utf-8')
405 | return syllable[0:2], self.rhyme_mapping.get(syllable[2:],syllable[2:])
406 | elif syllable[0] in self.consonant_list:
407 | #return syllable[0].encode('utf-8'),syllable[1:].encode('utf-8')
408 | return syllable[0], self.rhyme_mapping.get(syllable[1:],syllable[1:])
409 | else:
410 | #return (syllable.encode('utf-8'),)
411 | return (syllable,)
412 |
413 |
414 | def align_words(self, preds, phones, words):
415 |
416 | words_rep = [w+str(i) for i,(ph,w) in enumerate(zip(phones,words)) for p in ph]
417 | phones_rep = [p for ph,w in zip(phones,words) for p in ph]
418 | assert len(words_rep)==len(phones_rep)
419 |
420 | # match each phone to its word
421 | word_dur = []
422 | count = 0
423 | for dur in preds:
424 | if dur[-1] == '[SIL]':
425 | word_dur.append((dur,'[SIL]'))
426 | else:
427 | while dur[-1] != phones_rep[count]:
428 | count += 1
429 | if count >= len(phones_rep):
430 | break
431 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word)
432 |
433 | # merge phone-to-word alignment to derive word duration
434 | words = []
435 | for key, group in groupby(word_dur, lambda x: x[-1]):
436 | group = list(group)
437 | entry = (group[0][0][0],group[-1][0][1],re.sub(r'\d','',key))
438 | words.append(entry)
439 |
440 | return words
441 |
442 |
443 |
444 | if __name__ == '__main__':
445 | '''
446 | Testing functions
447 | '''
448 |
449 | processor = CharsiuPreprocessor_zh()
450 | phones, words = processor.get_phones_and_words("鱼香肉丝、王道椒香鸡腿和川蜀鸡翅。")
451 | print(phones)
452 | print(words)
453 | ids = processor.get_phone_ids(phones)
454 | print(ids)
455 |
456 | processor = CharsiuPreprocessor_en()
457 | phones, words = processor.get_phones_and_words("I’m playing octopath right now!")
458 | print(phones)
459 | print(words)
460 | ids = processor.get_phone_ids(phones)
461 | print(ids)
462 |
463 |
464 |
465 |
--------------------------------------------------------------------------------
/utils_assessment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import nltk
3 | import string
4 | import num2words
5 | import numpy as np
6 | from dataclasses import dataclass
7 | from Levenshtein import ratio
8 | from nltk.corpus import cmudict
9 | from difflib import SequenceMatcher
10 |
11 |
12 | def fit_one_hot(inputlist):
13 | mapping = {}
14 | for i in range(len(inputlist)):
15 | mapping[inputlist[i]]=i
16 | return mapping
17 |
18 | """
19 | load the cmudict and convert it to a dict for shorter processing time
20 | """
21 | nltk.download('cmudict')
22 | all_phonemes = set()
23 | entries = cmudict.entries()
24 | cmudict_dict = {entry[0].lower(): entry[1] for entry in entries}
25 | for entry in entries:
26 | phonemes = entry[1]
27 | all_phonemes.update(phonemes)
28 | all_phonemes = list(all_phonemes)
29 | all_phonemes.append('')
30 | all_phonemes = fit_one_hot(all_phonemes)
31 | phone_vector_dimension = len(all_phonemes.keys())
32 |
33 |
34 | def get_filepaths(directory, format='.wav'):
35 | """
36 |     load all files in the directory
37 | Parameters
38 | ----input-----
39 | directory: str. path of directory
40 | ----Return----
41 | file_paths: list. paths of files in the directory
42 | """
43 | file_paths = []
44 | for root, _, files in os.walk(directory):
45 | for filename in files:
46 | filepath = os.path.join(root, filename)
47 | if filename.endswith(format):
48 | file_paths.append(filepath)
49 | return file_paths
50 |
51 | def remove_pun_except_apostrophe(input_string):
52 | """
53 |     remove punctuation (except for ') from the input string.
54 | """
55 | translator = str.maketrans('', '', string.punctuation.replace("'", ""))
56 | output_string = input_string.translate(translator).replace(' ',' ')
57 | return output_string
58 |
59 | def remove_pun(input_string):
60 | """
61 |     remove punctuation from the input string.
62 | """
63 | input_string = "".join([char for char in input_string if char not in string.punctuation])
64 | return input_string
65 |
66 | def get_transcript(audio, whisper_model, return_seg=False):
67 | """
68 | get ASR result using whisper model
69 |
70 | Parameters
71 | ----input-----
72 | audio: Union[str, np.ndarray, torch.Tensor]
73 | whisper_model:
74 | load whisper_model using:
75 |             ```
76 |             import whisper
77 |             whisper_model = whisper.load_model("base.en")
78 |             ```
79 | return_seg: bool
80 | whether to return the segmentation result
81 | ----Return----
82 | transcript: str. ASR result of the input wavfile
83 | """
84 | result = whisper_model.transcribe(audio, fp16=False)
85 | if not return_seg:
86 | return result['text']
87 | else:
88 | return result
89 |
90 | def convert_num_to_word(sen):
91 | """
92 |     convert digits in a sentence to words, e.g., "7112" to "seven one one two".
93 | """
94 | try: #for 4 digit samples of speechocean data.
95 | int(sen.replace(' ',''))
96 | sen = ' '.join([char for char in sen])
97 | sen = ' '.join([num2words.num2words(i) if i.isdigit() else i for i in sen.split()])
98 | sen = sen.replace(' ',' ')
99 | except:
100 | sen = ' '.join([num2words.num2words(i) if i.isdigit() else i for i in sen.split()])
101 | return sen
102 |
103 |
104 | def get_phone_list(word_list):
105 | """
106 | convert word to phone using cmudict
107 | Parameters
108 | ----input-----
109 | word_list: list of word. e.g. [[word1],[word2]...] or [[word1, word2],...]
110 | ----Return----
111 | phone_list: list of corresponding phone e.g [[p1-1,p1-2], [p2-1,p2-2,p2-3],...] or [[p1-1,p1-2,p2-1,p2-2,p2-3], ...]
112 | """
113 | phone_list = []
114 | for word_position in word_list:
115 | the_phone_list = []
116 | for word in word_position:
117 | phone = cmudict_dict.get(word.lower(), '')
118 | the_phone_list.extend(phone)
119 | phone_list.append(the_phone_list)
120 |
121 | return phone_list
122 |
123 | def get_phone_vector(phone_list):
124 | """
125 | convert phone to phone-vector using one-hot encoding
126 | Parameters
127 | ----input-----
128 | phone_list: list of phone. e.g. [[phone1-1, phone1-2],[phone2-1, phone2-2, phone2-3]...]
129 | ----Return----
130 | phone_vector: np.ndarray, [shape=(number of word, phone_vector_dimension=71)]
131 | """
132 | num_of_word = len(phone_list)
133 | phone_vector = np.zeros((num_of_word, phone_vector_dimension))
134 |
135 | for word_idx, the_phone_list in enumerate(phone_list):
136 | the_phone_vector = np.zeros((phone_vector_dimension, ))
137 | for phone in the_phone_list:
138 | the_phone_vector[all_phonemes[phone]]+=1
139 | phone_vector[word_idx,:] = the_phone_vector
140 |
141 | return phone_vector
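# Example (illustrative) for get_phone_vector: phone_list = [['HH', 'AH0', 'L', 'OW1']]
# returns an array of shape (1, phone_vector_dimension) whose entries at the indices of
# 'HH', 'AH0', 'L', and 'OW1' are incremented by one (repeated phones accumulate counts).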
142 |
143 | def get_phone_features_from_wordlist(gt_word_list, asr_word_list):
144 |
145 | """
146 |     get phone features of the ground-truth word list and the ASR word list
147 |     Parameters
148 |     ----input-----
149 |     gt_word_list: list of ground-truth words. e.g. [[word1-gt],[word2-gt],[word3-gt]...]
150 |     asr_word_list: list of ASR words. e.g. [[word1-asr, word2-asr],[word3],[word4]...]
151 | # note: gt_word_list[i] is aligned with asr_word_list[i] based on the audio-word forced alignment result
152 | ----Return----
153 | phone_distance: np.ndarray, [shape=(number of ground-truth word, 1)]
154 | the phone distance (by SequenceMatcher) between ground-truth word and the asr-word
155 | phone_vector: np.ndarray, [shape=(number of ground-truth word, (phone_vector_dimension=71)*2)]
156 | the phone vector of ground-truth word and the asr-word
157 | phone_count: np.ndarray, [shape=(number of ground-truth word, 1)]
158 | the number of phones of ground-truth word divided by the number of phones of asr word
159 | """
160 |
161 | gt_length = len(gt_word_list)
162 | gt_phone_list = get_phone_list(gt_word_list)
163 | asr_phone_list = get_phone_list(asr_word_list)
164 |
165 | gt_phone_vector = get_phone_vector(gt_phone_list)
166 | asr_phone_vector = get_phone_vector(asr_phone_list)
167 |
168 | phone_distance = np.zeros((gt_length, 1))
169 | phone_count = np.zeros((gt_length, 1))
170 | for word_idx in range(gt_length):
171 | the_distance = SequenceMatcher(None, gt_phone_list[word_idx],asr_phone_list[word_idx])
172 | phone_distance[word_idx,0]=the_distance.ratio()
173 | if len(asr_phone_list[word_idx])!=0:
174 | phone_count[word_idx,0]=len(gt_phone_list[word_idx])/len(asr_phone_list[word_idx])
175 |
176 | phone_vector = np.concatenate((gt_phone_vector, asr_phone_vector), axis=1)
177 |
178 | return phone_distance, phone_vector, phone_count
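# Example (illustrative) for get_phone_features_from_wordlist: for the ground-truth word
# "cat" -> ['K', 'AE1', 'T'] aligned with the ASR word "cap" -> ['K', 'AE1', 'P'],
# SequenceMatcher gives phone_distance = 2*2/6 ~= 0.67 and phone_count = 3/3 = 1.0.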
179 |
180 |
181 | def get_word_alignment_features(alignment_gt, alignment_asr):
182 | """
183 |     get word-aligned features of the ground-truth word list and the ASR word list
184 | Parameters
185 | ----input-----
186 | alignment_gt: list, len = number of words in the ground-truth sentence
187 | word-audio alignment result of ground-truth sentence. e.g., [[(start_time1, end_time1, word1)], [(start_time2, end_time2, word2)], ...]
188 | alignment_asr: list, len = number of words in the ASR sentence
189 | word-audio alignment result of ASR sentence. e.g., [[(start_time1, end_time1, word1)], [(start_time2, end_time2, word2)], ...]
190 |
191 | ----Return----
192 | gt_word_list: list, len = number of words in the ground-truth sentence
193 | words in the ground-truth sentence, e.g. [[word1_gt],[word2_gt],[word3_gt],...]
194 | asr_word_list: list, len = number of words in the ground-truth sentence
195 | words in the ASR sentence aligned with ground-truth word. e.g. [[word1_asr, word2_asr],[word3_asr],...]
196 | alignment_features: np.ndarray, [shape=(number of words in the ground-truth sentence, 10)]
197 | phonevector: np.ndarray, [shape=(number of words in the ground-truth sentence, 71*2)]
198 | phonevector of the ground-truth words and asr words
199 | asr_wordidx_list: list, len = number of words in the ground-truth sentence
200 | mapping between ground-truth words and asr words.
201 | asr_wordidx_list[i] = [j,m] means alignment_gt[i] is overlapped with alignment_asr[j] and alignment_asr[m]
202 | """
203 | gt_length = len(alignment_gt)
204 |
205 | gt_word_list = []
206 | asr_word_list = []
207 | asr_wordidx_list = []
208 | asr_distance = np.zeros((gt_length, 1))
209 | duration_gt = np.zeros((gt_length, 1))
210 | duration_asr = np.zeros((gt_length, 1))
211 | time_diff_start = np.zeros((gt_length, 1))
212 | time_diff_end = np.zeros((gt_length, 1))
213 | interval_gt = np.zeros((gt_length, 1))
214 | interval_asr = np.zeros((gt_length, 1))
215 |
216 | pre_end_gt = 0
217 | for gt_idx, gt_value in enumerate(alignment_gt):
218 |
219 | gt_word = gt_value[2]
220 | gt_start = float(gt_value[0])
221 | gt_end = float(gt_value[1])
222 |
223 | duration_gt[gt_idx,0] = (gt_end-gt_start)
224 | interval_gt[gt_idx,0] = (gt_start - pre_end_gt)
225 | pre_end_gt = gt_end
226 |
227 | gt_word_list.append([gt_word])
228 |
229 | asr_word_all = []
230 | asr_wordidx = []
231 | asr_start_flag = True
232 | the_asr_start = 0
233 | the_asr_end = float(alignment_asr[-1][1])
234 |
235 | pre_end_asr = 0
236 | asr_interval_list = 0
237 | for asr_idx, asr_value in enumerate(alignment_asr):
238 | asr_start = float(asr_value[0])
239 | asr_end = float(asr_value[1])
240 | if gt_end <= asr_start:
241 | break
242 | if gt_start >= asr_end:
243 | continue
244 | if max(gt_start, asr_start) <= min(gt_end, asr_end):
245 | asr_word = asr_value[2]
246 | asr_word_all.append(asr_word)
247 | asr_wordidx.append(asr_idx)
248 | asr_interval_list += (asr_start - pre_end_asr)
249 | pre_end_asr = asr_end
250 | the_asr_end = asr_end
251 | if asr_start_flag:
252 | the_asr_start = asr_start
253 | asr_start_flag= False
254 |
255 | duration_asr[gt_idx,0] = (the_asr_end - the_asr_start)
256 | time_diff_start[gt_idx,0] = (the_asr_start - gt_start)
257 | time_diff_end[gt_idx,0] = (the_asr_end - gt_end)
258 | if len(asr_wordidx)!=0:
259 | interval_asr[gt_idx,0] = asr_interval_list/len(asr_wordidx)
260 |
261 | asr_wordidx_list.append(asr_wordidx)
262 | asr_word_list.append(asr_word_all)
263 | asr_distance[gt_idx,0] = ratio(gt_word, ' '.join(asr_word_all))*10
264 |
265 | align_word_count = [len(asr_word_list[word_idx]) for word_idx in range(gt_length)]
266 | phone_distance, phonevector, phone_count = get_phone_features_from_wordlist(gt_word_list, asr_word_list)
267 |
268 | align_word_count = np.asarray(align_word_count)
269 | align_word_count = np.expand_dims(align_word_count, axis=1)
270 |
271 | alignment_features = np.concatenate((asr_distance, align_word_count, duration_gt, duration_asr, time_diff_start, time_diff_end, phone_distance, phone_count, interval_gt, interval_asr), axis=1)
272 |
273 | return gt_word_list, asr_word_list, alignment_features, phonevector, asr_wordidx_list
274 |
275 |
276 | def get_phone_alignment_features(pred_phones_gt, pred_phones_asr, pred_prob_gt, phone_ids_gt, pred_prob_asr, phone_ids_asr):
277 | """
278 |     get phone-aligned features of the ground-truth phone list and the ASR phone list
279 | Parameters
280 | ----input-----
281 | pred_phones_gt: list, len = number of phones in the ground-truth sentence
282 | phoneme-audio alignment result of ground-truth sentence. e.g., [[(start_time1, end_time1, phone1)], [(start_time2, end_time2, phone2)], ...]
283 | pred_phones_asr: list, len = number of phones in the ASR sentence
284 | phoneme-audio alignment result of ASR sentence. e.g., [[(start_time1, end_time1, phone1)], [(start_time2, end_time2, phone2)], ...]
285 | pred_prob_gt: np.ndarray, [shape=(number of phones in the ground-truth sentence, 42)]
286 | output of the charsiu model
287 | phone_ids_gt: list, len = number of phones in the ground-truth sentence
288 | index of the aligned phone.
289 | pred_prob_asr: np.ndarray, [shape=(number of phones in the ASR sentence, 42)]
290 | output of the charsiu model
291 | phone_ids_asr: list, len = number of phones in the ASR sentence
292 | index of the aligned phone.
293 |
294 | ----Return----
295 | features: np.ndarray, [shape=(number of phones in the ground-truth sentence, 93)]
296 | extracted features.
297 | features[:,:9] = alignment features
298 | features[:,9:9+42] = pred_prob_gt
299 | features[:,9+42:] = aligned_pred_prob_asr (align ASR phone to ground-truth phone using the time information)
300 | """
301 | gt_length = len(pred_phones_gt)
302 |
303 | asr_phone_list = []
304 | asr_phoneidx_list = []
305 | duration_gt = np.zeros((gt_length, 1))
306 | duration_asr = np.zeros((gt_length, 1))
307 | time_diff_start = np.zeros((gt_length, 1))
308 | time_diff_end = np.zeros((gt_length, 1))
309 | interval_gt = np.zeros((gt_length, 1))
310 | interval_asr = np.zeros((gt_length, 1))
311 |
312 | pre_end_gt = 0
313 | for gt_idx, gt_value in enumerate(pred_phones_gt):
314 |
315 | gt_start = gt_value[0]
316 | gt_end = gt_value[1]
317 |
318 | duration_gt[gt_idx,0] = (gt_end - gt_start)
319 | interval_gt[gt_idx,0] = (gt_start - pre_end_gt)
320 | pre_end_gt = gt_end
321 |
322 | asr_phone_all = []
323 | asr_phoneidx = []
324 | asr_start_flag = True
325 | the_asr_start = 0
326 |
327 | pre_end_asr = 0
328 | asr_interval_list = 0
329 | for asr_idx, asr_value in enumerate(pred_phones_asr):
330 | asr_start = asr_value[0]
331 | asr_end = asr_value[1]
332 | if gt_end <= asr_start:
333 | break
334 | if gt_start >= asr_end:
335 | continue
336 | if max(gt_start, asr_start) <= min(gt_end, asr_end):
337 | asr_phone = asr_value[2]
338 | asr_phone_all.append(asr_phone)
339 | asr_phoneidx.append(asr_idx)
340 | asr_interval_list += (asr_start - pre_end_asr)
341 | pre_end_asr = asr_end
342 | the_asr_end = asr_end
343 | if asr_start_flag:
344 | the_asr_start = asr_start
345 | asr_start_flag= False
346 |
347 | duration_asr[gt_idx,0] = (the_asr_end - the_asr_start)
348 | time_diff_start[gt_idx,0] = (the_asr_start - gt_start)
349 | time_diff_end[gt_idx,0] = (the_asr_end - gt_end)
350 | if len(asr_phoneidx)!=0:
351 | interval_asr[gt_idx,0] = asr_interval_list/len(asr_phoneidx)
352 |
353 | asr_phoneidx_list.append(asr_phoneidx)
354 | asr_phone_list.append(asr_phone_all)
355 |
356 | align_phone_count = [len(asr_phone_list[phone_idx]) for phone_idx in range(gt_length)]
357 |
358 | align_phone_count = np.asarray(align_phone_count)
359 | align_phone_count = np.expand_dims(align_phone_count, axis=1)
360 |
361 | the_gt_phone_prob = np.zeros((pred_prob_gt.shape[0], 1))
362 | the_asr_phone_prob = np.zeros((pred_prob_asr.shape[0], 1))
363 |
364 | the_gt_phone_prob[:, 0] = pred_prob_gt[np.arange(pred_prob_gt.shape[0]), phone_ids_gt]
365 | the_asr_phone_prob[:, 0] = pred_prob_asr[np.arange(pred_prob_asr.shape[0]), phone_ids_asr]
366 |
367 | aligned_asr_prob = np.zeros((pred_prob_gt.shape[0], 1))
368 | aligned_pred_prob_asr = np.zeros((pred_prob_gt.shape))
369 | for gt_p_idx, gt_p in enumerate(asr_phoneidx_list):
370 | aligned_asr_prob[gt_p_idx,:] = np.mean(the_asr_phone_prob[gt_p,:], axis=0)
371 | aligned_pred_prob_asr[gt_p_idx,:] = np.mean(pred_prob_asr[gt_p,:], axis=0)
372 |
373 | features = np.concatenate((align_phone_count, duration_gt, duration_asr, time_diff_start, time_diff_end, interval_gt, interval_asr, the_gt_phone_prob, aligned_asr_prob, pred_prob_gt, aligned_pred_prob_asr), axis=1)
374 |
375 | return features
376 |
377 |
378 | def get_roberta_word_embed(word_list, num_of_token, roberta):
379 | """
380 | get roberta word embedding of input word list
381 | Parameters
382 | ----input-----
383 | word_list: list, len = number of words
384 | num_of_token: int
385 |         number of words in the ground-truth sentence.
386 | roberta: object.
387 | load roberta model using:
388 |         ```
389 | from fairseq.models.roberta import RobertaModel
390 | roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
391 | roberta.eval()
392 |         ```
393 | ----Return----
394 | sen_vec: np.ndarray, [shape=(num_of_token, 768)]
395 | word-by-word roberta embedding
396 | """
397 | sen_vec = np.zeros((num_of_token, 768))
398 |
399 | for w_idx, the_word_list in enumerate(word_list):
400 | the_sen = ' '.join(the_word_list)
401 | if the_sen=='':
402 | continue
403 | doc = roberta.extract_features_aligned_to_words(the_sen)
404 | the_sen_vec = np.zeros((768,))
405 | for tok in doc:
406 |             if str(tok)=='<s>' or str(tok)=='</s>':
407 | continue
408 | the_vec = tok.vector.detach().numpy()
409 | the_sen_vec[:] += the_vec
410 | the_sen_vec /= len(the_word_list)
411 | sen_vec[w_idx,:] = the_sen_vec
412 |
413 | return sen_vec
414 |
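# Example usage of get_roberta_word_embed (a minimal sketch, assuming model.pt and dict.txt
# have been downloaded into ./fairseq_roberta as described in fairseq_roberta/README.md;
# the word_list values are made up for illustration):
#
#   from fairseq.models.roberta import RobertaModel
#   roberta = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
#   roberta.eval()
#   word_list = [['hello'], ['hello', 'world']]  # one sub-list of aligned ASR words per GT word
#   sen_vec = get_roberta_word_embed(word_list, len(word_list), roberta)  # shape (2, 768)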
415 | def get_charsiu_alignment(audio, sen, aligner):
416 | pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map = aligner.align(audio=audio, text=sen)
417 | return pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map
418 |
419 | def get_match_index(pred_words, words):
420 | """
421 | remove [SIL] in the charsiu word alignment result.
422 | """
423 | selected_idx = []
424 | curren_idx=0
425 | for the_word in words:
426 | for i in range(curren_idx, len(pred_words)):
427 | if pred_words[i][2]==the_word:
428 | selected_idx.append(i)
429 | curren_idx = i+1
430 | break
431 | return selected_idx
432 |
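# Illustrative example (hypothetical values): with
#   pred_words = [(0.00, 0.12, '[SIL]'), (0.12, 0.45, 'hello'), (0.45, 0.50, '[SIL]'), (0.50, 0.90, 'world')]
#   words      = ['hello', 'world']
# get_match_index returns [1, 3], i.e. the indices of the non-[SIL] entries.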
433 |
434 | def feature_extraction(audio, gt_sen, asr_sen, alignment_model, word_model):
435 | """
436 | extract features for the assessment model
437 | Parameters
438 | ----input-----
439 | audio: str. or np.ndarray [shape=(n,)]
440 | path to the input wavfile or np.ndarray of wavfile
441 | gt_sen: str.
442 | ground-truth sentence
443 | asr_sen: str.
444 |         ASR sentence
445 | alignment_model: object
446 | load charsiu model using:
447 |         ```
448 | from Charsiu import charsiu_forced_aligner
449 | alignment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
450 |         ```
451 |     word_model: object
452 | load roberta model using:
453 |         ```
454 | from fairseq.models.roberta import RobertaModel
455 | word_model = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
456 | word_model.eval()
457 |         ```
458 | ----Return----
459 | """
460 |
461 | pred_phones_gt, pred_words_gt, words_gt, pred_prob_gt, phone_ids_gt, word_phone_map_gt = get_charsiu_alignment(audio, gt_sen, alignment_model)
462 | pred_phones_asr, pred_words_asr, words_asr, pred_prob_asr, phone_ids_asr, _ = get_charsiu_alignment(audio, asr_sen, alignment_model)
463 |
464 | features_p = get_phone_alignment_features(pred_phones_gt, pred_phones_asr, pred_prob_gt, phone_ids_gt, pred_prob_asr, phone_ids_asr)
465 |
466 | selected_idx_gt = get_match_index(pred_words_gt, words_gt)
467 |
468 | word_phone_map_gt = [word_phone_map_gt[i] for i in selected_idx_gt]
469 |
470 | pred_words_gt = np.asarray(pred_words_gt)
471 | pred_words_gt = pred_words_gt[selected_idx_gt]
472 |
473 | selected_idx_asr = get_match_index(pred_words_asr, words_asr)
474 | pred_words_asr = np.asarray(pred_words_asr)
475 | pred_words_asr = pred_words_asr[selected_idx_asr]
476 | gt_word_list, asr_word_list, features_w, phonevector, _ = get_word_alignment_features(pred_words_gt, pred_words_asr)
477 |
478 | num_of_token = len(selected_idx_gt)
479 | gt_word_embed = get_roberta_word_embed(gt_word_list, num_of_token, word_model)
480 | asr_word_embed = get_roberta_word_embed(asr_word_list, num_of_token, word_model)
481 |
482 | return pred_words_gt, features_p, features_w, phonevector, gt_word_embed, asr_word_embed, word_phone_map_gt
483 |
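# End-to-end usage of feature_extraction (a minimal sketch, assuming the charsiu and fairseq
# checkpoints are set up as described in the docstrings above; 'sample.wav' and the two
# sentences are placeholders):
#
#   from Charsiu import charsiu_forced_aligner
#   from fairseq.models.roberta import RobertaModel
#
#   alignment_model = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
#   word_model = RobertaModel.from_pretrained('./fairseq_roberta', checkpoint_file='model.pt')
#   word_model.eval()
#
#   (pred_words_gt, features_p, features_w, phonevector,
#    gt_word_embed, asr_word_embed, word_phone_map_gt) = feature_extraction(
#       'sample.wav', 'we call it bear', 'we call it bird', alignment_model, word_model)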
--------------------------------------------------------------------------------
/Charsiu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import sys
6 | import torch
7 | from itertools import groupby
8 | sys.path.append('src/')
9 | import numpy as np
10 | #sys.path.insert(0,'src')
11 | from models import Wav2Vec2ForAttentionAlignment, Wav2Vec2ForFrameClassification, Wav2Vec2ForCTC
12 | from utils import seq2duration,forced_align,duration2textgrid,word2textgrid
13 | from processors import CharsiuPreprocessor_zh, CharsiuPreprocessor_en
14 |
15 | processors = {'zh':CharsiuPreprocessor_zh,
16 | 'en':CharsiuPreprocessor_en}
17 |
18 | class charsiu_aligner:
19 |
20 | def __init__(self,
21 | lang='en',
22 | sampling_rate=16000,
23 | device=None,
24 | recognizer=None,
25 | processor=None,
26 | resolution=0.01):
27 |
28 | self.lang = lang
29 |
30 | if processor is not None:
31 | self.processor = processor
32 | else:
33 | self.charsiu_processor = processors[self.lang]()
34 |
35 |
36 |
37 | self.resolution = resolution
38 |
39 | self.sr = sampling_rate
40 |
41 | self.recognizer = recognizer
42 |
43 | if device is None:
44 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
45 | else:
46 | self.device = device
47 |
48 |
49 | def _freeze_model(self):
50 | self.aligner.eval().to(self.device)
51 | if self.recognizer is not None:
52 | self.recognizer.eval().to(self.device)
53 |
54 |
55 |
56 | def align(self,audio,text):
57 | raise NotImplementedError()
58 |
59 |
60 |
61 | def serve(self,audio,save_to,output_format='variable',text=None):
62 | raise NotImplementedError()
63 |
64 |
65 | def _to_textgrid(self,phones,save_to):
66 | '''
67 | Convert output tuples to a textgrid file
68 |
69 | Parameters
70 | ----------
71 |         phones : list
72 |             aligned phones as (start_time, end_time, phone) tuples
73 |
74 | Returns
75 | -------
76 | None.
77 |
78 | '''
79 | duration2textgrid(phones,save_path=save_to)
80 | print('Alignment output has been saved to %s'%(save_to))
81 |
82 |
83 |
84 | def _to_tsv(self,phones,save_to):
85 | '''
86 | Convert output tuples to a tab-separated file
87 |
88 | Parameters
89 | ----------
90 |     phones : list
91 |         aligned phones as (start_time, end_time, phone) tuples
92 |
93 | Returns
94 | -------
95 | None.
96 |
97 | '''
98 | with open(save_to,'w') as f:
99 | for start,end,phone in phones:
100 | f.write('%s\t%s\t%s\n'%(start,end,phone))
101 | print('Alignment output has been saved to %s'%(save_to))
102 |
103 |
104 |
105 |
106 |
107 | class charsiu_forced_aligner(charsiu_aligner):
108 |
109 | def __init__(self, aligner, sil_threshold=4, **kwargs):
110 | super(charsiu_forced_aligner, self).__init__(**kwargs)
111 | self.aligner = Wav2Vec2ForFrameClassification.from_pretrained(aligner)
112 | self.sil_threshold = sil_threshold
113 |
114 | self._freeze_model()
115 |
116 |
117 | def align(self, audio, text):
118 | '''
119 | Perform forced alignment
120 |
121 | Parameters
122 | ----------
123 | audio : np.ndarray [shape=(n,)]
124 | time series of speech signal
125 | text : str
126 | The transcription
127 |
128 | Returns
129 | -------
130 |         pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map, where phones and words are aligned as (start_time, end_time, unit) tuples
131 |
132 | '''
133 | audio = self.charsiu_processor.audio_preprocess(audio,sr=self.sr)
134 | audio = torch.Tensor(audio).unsqueeze(0).to(self.device)
135 | phones, words = self.charsiu_processor.get_phones_and_words(text)
136 | phone_ids = self.charsiu_processor.get_phone_ids(phones)
137 |
138 | with torch.no_grad():
139 | out = self.aligner(audio)
140 |
141 | cost = torch.softmax(out.logits,dim=-1).detach().cpu().numpy().squeeze()
142 | sil_mask = self._get_sil_mask(cost)
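        # cost holds frame-by-frame phone posterior probabilities from the aligner;
        # sil_mask marks the frames decoded as silence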
143 |
144 | nonsil_idx = np.argwhere(sil_mask!=self.charsiu_processor.sil_idx).squeeze()
145 |         if nonsil_idx.size == 0:
146 | raise Exception("No speech detected! Please check the audio file!")
147 |
148 | aligned_phone_ids = forced_align(cost[nonsil_idx,:],phone_ids[1:-1])
149 |
150 | aligned_phones = [self.charsiu_processor.mapping_id2phone(phone_ids[1:-1][i]) for i in aligned_phone_ids]
151 |
152 | pred_phones_ori = self._merge_silence(aligned_phones,sil_mask)
153 |
154 | pred_phones = seq2duration(pred_phones_ori,resolution=self.resolution)
155 |
156 | pred_words, word_phone_map = self.charsiu_processor.align_words(pred_phones, phones, words, return_map=True)
157 |
158 | ## get the predicted probability of output phone
159 | counter = 0
160 | pred_prob = np.empty((len(pred_phones),cost.shape[1]))
161 | group_idx = 0
162 | for _, group in groupby(pred_phones_ori):
163 | length = len(list(group))
164 | pred_prob[group_idx,:] = np.mean(cost[counter:counter+length,:], axis=0)
165 | counter += length
166 | group_idx+=1
167 |
168 | final_phones = [the_phone[2] for the_phone in pred_phones ]
169 | phone_ids = self.charsiu_processor.mapping_phone2id(final_phones)
170 |
171 | return pred_phones, pred_words, words, pred_prob, phone_ids, word_phone_map
172 |
173 |
174 | def serve(self,audio,text,save_to,output_format='textgrid'):
175 | '''
176 | A wrapper function for quick inference
177 |
178 | Parameters
179 | ----------
180 |         audio : np.ndarray [shape=(n,)] or str
181 |             speech signal, or path to the input wav file
182 |         text : str
183 |             the transcription to align
184 |         output_format : str, optional
185 |             Output phone-audio alignment as a "tsv" or "textgrid" file.
186 | The default is 'textgrid'.
187 |
188 | Returns
189 | -------
190 | None.
191 |
192 | '''
193 |         phones, words, *_ = self.align(audio,text)
194 |
195 | if output_format == 'tsv':
196 | if save_to.endswith('.tsv'):
197 | save_to_phone = save_to.replace('.tsv','_phone.tsv')
198 | save_to_word = save_to.replace('.tsv','_word.tsv')
199 | else:
200 | save_to_phone = save_to + '_phone.tsv'
201 | save_to_word = save_to + '_word.tsv'
202 |
203 | self._to_tsv(phones, save_to_phone)
204 | self._to_tsv(words, save_to_word)
205 |
206 | elif output_format == 'textgrid':
207 | self._to_textgrid(phones, words, save_to)
208 | else:
209 |             raise Exception('Please specify the correct output format (tsv or textgrid)!')
210 |
211 | def _to_textgrid(self,phones,words,save_to):
212 | '''
213 | Convert output tuples to a textgrid file
214 |
215 | Parameters
216 | ----------
217 |         phones : list
218 |             aligned phones as (start_time, end_time, phone) tuples
219 |
220 | Returns
221 | -------
222 | None.
223 |
224 | '''
225 | word2textgrid(phones,words,save_path=save_to)
226 | print('Alignment output has been saved to %s'%(save_to))
227 |
228 |
229 | def _merge_silence(self,aligned_phones,sil_mask):
230 | # merge silent and non-silent intervals
231 | pred_phones = []
232 | count = 0
233 | for i in sil_mask:
234 | if i==self.charsiu_processor.sil_idx:
235 | pred_phones.append('[SIL]')
236 | else:
237 | pred_phones.append(aligned_phones[count])
238 | count += 1
239 | assert len(pred_phones) == len(sil_mask)
240 | return pred_phones
241 |
242 |
243 |
244 | def _get_sil_mask(self,cost):
245 | # single out silent intervals
246 |
247 | preds = np.argmax(cost,axis=-1)
248 | sil_mask = []
249 | for key, group in groupby(preds):
250 | group = list(group)
251 |             if (key==self.charsiu_processor.sil_idx and len(group)<self.sil_threshold):
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
96 |         assert len(kernels) > 0
97 | for kernel in kernels:
98 | self.cnns.append(nn.Conv1d(latest_size, cnn_size, kernel, padding=kernel//2))
99 | latest_size = cnn_size * len(kernels)
100 |
101 | self.out_linear = nn.Linear(latest_size, output_class_num)
102 |
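    # forward: linear projection -> parallel 1-D convolutions with different kernel widths
    # (outputs concatenated along the channel dimension) -> linear output layer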
103 | def forward(self, features):
104 | hidden = F.dropout(F.relu(self.in_linear(features)), p=self.drop_p)
105 |
106 | conv_feats = []
107 | hidden = hidden.transpose(1, 2).contiguous()
108 | for cnn in self.cnns:
109 | conv_feats.append(cnn(hidden))
110 | hidden = torch.cat(conv_feats, dim=1).transpose(1, 2).contiguous()
111 | hidden = F.dropout(F.relu(hidden), p=self.drop_p)
112 |
113 | predicted = self.out_linear(hidden)
114 | return predicted
115 |
116 |
117 |
118 | class RNN(nn.Module):
119 |
120 | def __init__(self,hidden_dim,out_dim):
121 | super().__init__()
122 |
123 | self.lstm = nn.LSTM(hidden_dim,hidden_dim,bidirectional=True,num_layers=1,batch_first=True)
124 | self.linear = nn.Sequential(nn.Linear(2*hidden_dim,hidden_dim),
125 | nn.ReLU(),
126 | nn.Linear(hidden_dim,out_dim))
127 |
128 |
129 | def forward(self, embeddings, lens):
130 |
131 | packed_input = pack_padded_sequence(embeddings, lens.cpu(), batch_first=True,enforce_sorted=False)
132 | packed_output, (ht, ct)= self.lstm(packed_input)
133 | out, _ = pad_packed_sequence(packed_output, batch_first=True)
134 | out = self.linear(out)
135 | return out
136 |
137 |
138 | class Wav2Vec2ForAttentionAlignment(Wav2Vec2ForPreTraining):
139 | '''
140 | Implementation adapted from: https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2ForPreTraining
141 | '''
142 | def __init__(self,config):
143 | super().__init__(config)
144 | bert_config = self.get_bert_config(config)
145 | self.bert = BertForMaskedPhoneLM(bert_config)
146 | self.cnn = ConvBank(config.hidden_size,
147 | config.bert_hidden_size,
148 | config.bert_convbank,
149 | config.bert_hidden_size,
150 | config.bert_hidden_size,
151 | config.hidden_dropout)
152 | #self.lm_head = nn.Linear(config.hidden_size,config.vocab_size)
153 | # self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
154 | # self.phone_rnn = RNN(384,config.vocab_size)
155 |
156 | self.attention = Attention(config.bert_hidden_size)
157 | self.align_loss = ForwardSumLoss()
158 |
159 | def freeze_wav2vec2(self):
160 | for param in self.wav2vec2.parameters():
161 | param.requires_grad = False
162 |
163 | def initialize_phone_model(self,path):
164 |
165 | self.bert = BertForMaskedPhoneLM.from_pretrained(path)
166 |
167 | def get_bert_config(self,config):
168 | bert_config = BertConfig(architectures=config.bert_architectures,
169 | attention_probs_dropout_prob=config.bert_attention_probs_dropout_prob,
170 | gradient_checkpointing=config.bert_gradient_checkpointing,
171 | hidden_act=config.bert_hidden_act,
172 | hidden_dropout_prob=config.bert_hidden_dropout_prob,
173 | hidden_size=config.bert_hidden_size,
174 | initializer_range=config.bert_initializer_range,
175 | intermediate_size=config.bert_intermediate_size,
176 | layer_norm_eps=config.bert_layer_norm_eps,
177 | max_position_embeddings=config.bert_max_position_embeddings,
178 | model_type=config.bert_model_type,
179 | num_attention_heads=config.bert_num_attention_heads,
180 | num_hidden_layers=config.bert_num_hidden_layers,
181 | pad_token_id=config.bert_pad_token_id,
182 | position_embedding_type=config.bert_position_embedding_type,
183 | transformers_version=config.bert_transformers_version,
184 | type_vocab_size=config.bert_type_vocab_size,
185 | use_cache=config.bert_use_cache,
186 | vocab_size=config.bert_vocab_size,
187 | convbank=config.bert_convbank)
188 |
189 | return bert_config
190 |
191 | def forward(
192 | self,
193 | input_values,
194 | attention_mask=None,
195 | output_attentions=None,
196 | output_hidden_states=None,
197 | mask_time_indices=None,
198 | return_dict=None,
199 | labels=None,
200 | labels_attention_mask=None,
201 | text_len=None,
202 | frame_len=None,
203 | weight=1
204 | ):
205 |
206 | # check the availability of attention masks
207 | # if not present, create full attention masks
208 | if attention_mask is None:
209 | attention_mask = torch.ones_like(input_values)
210 |
211 | if labels_attention_mask is None:
212 | labels_attention_mask = torch.ones_like(labels)
213 |
214 |
215 |
216 | outputs = self.wav2vec2(
217 | input_values,
218 | attention_mask=attention_mask,
219 | output_attentions=output_attentions,
220 | output_hidden_states=output_hidden_states,
221 | mask_time_indices=mask_time_indices,
222 | return_dict=return_dict,
223 | )
224 |
225 | # acoustic embeddings
226 | frame_hidden = outputs[0]
227 | # frame_hidden = self.dropout(frame_hidden)
228 | frame_hidden = self.cnn(frame_hidden)
229 |
230 | # phone embeddings
231 | phone_hidden = self.bert(input_ids=labels,attention_mask=labels_attention_mask).hidden_states[-1]
232 |
233 | # compute cross attention
234 | att_out,energy = self.attention(frame_hidden,phone_hidden,labels_attention_mask)
235 |
236 |
237 | # start masked modeling
238 | # 0. remove the blank symbol
239 | # 1. project all transformed features (including masked) to final vq dim
240 | transformer_features = self.project_hid(torch.tanh(att_out))
241 |
242 |
243 | # 2. quantize all (unmasked) extracted features and project to final vq dim
244 | extract_features = self.dropout_features(outputs[1])
245 | quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
246 | quantized_features = self.project_q(quantized_features)
247 |
248 |
249 | # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
250 | if attention_mask is not None:
251 |             # compute reduced attention_mask corresponding to feature vectors
252 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
253 |
254 | # loss_fct = nn.CrossEntropyLoss()
255 |
256 | # phone_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
257 |
258 |
259 | loss = None
260 | if self.training:
261 | # for training, we sample negatives
262 | # 3. sample K negatives (distractors) quantized states for contrastive loss
263 |
264 | negative_quantized_features = self._sample_negatives(
265 | quantized_features, self.config.num_negatives, attention_mask=attention_mask
266 | )
267 |
268 | # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
269 | # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
270 | logits = self.compute_contrastive_logits(
271 | quantized_features[None, :],
272 | negative_quantized_features,
273 | transformer_features,
274 | self.config.contrastive_logits_temperature,
275 | )
276 |
277 | # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
278 | # its cosine similarity will be masked
279 | neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
280 | if neg_is_pos.any():
281 | logits[1:][neg_is_pos] = float("-inf")
282 |
283 | # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
284 | # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
285 | preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
286 | target = ((1 - attention_mask.long()) * -100).transpose(0, 1).flatten()
287 | contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="mean")
288 |
289 | # 7. compute diversity loss: \mathbf{L}_d
290 | # num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
291 | # diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
292 |
293 | # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
294 | expanded_labels_attention_mask = (1-labels_attention_mask)*-10000.0
295 | expanded_labels_attention_mask = expanded_labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1)
296 | att = torch.log_softmax(energy+expanded_labels_attention_mask,dim=-1)
297 | align_loss = self.align_loss(att.unsqueeze(1),text_len,frame_len)
298 |
299 | # expanded_attention_mask = attention_mask.unsqueeze(2).repeat(1,1,energy.size(2)) * labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1)
300 | # expanded_attention_mask = (1-expanded_attention_mask)*-10000.0
301 | # phone_attention = torch.softmax((energy+expanded_attention_mask).transpose(2,1),dim=-1)
302 | # phone_emb = torch.bmm(phone_attention,frame_hidden)
303 | # prediction_scores = self.phone_rnn(phone_emb,text_len)
304 | # labels = labels.masked_fill(labels_attention_mask.ne(1), -100)
305 | # inter_phone = F.cosine_similarity(phone_emb[:,:-1,:],phone_emb[:,1:,:],dim=-1)*labels_attention_mask[:,1:]
306 | # interphone_loss = torch.sum(inter_phone)/torch.sum(labels_attention_mask[:,1:])
307 |
308 |
309 | loss = contrastive_loss + weight*align_loss #+ interphone_loss
310 |
311 |
312 | return CausalLMOutput(
313 | loss=loss, logits=energy, hidden_states=outputs.hidden_states, attentions=None
314 | )
315 |
316 |
317 | class BertForMaskedPhoneLM(BertForMaskedLM):
318 |
319 | def __init__(self,config):
320 | super().__init__(config)
321 | self.cnn = ConvBank(config.hidden_size,
322 | config.hidden_size,
323 | config.convbank,
324 | config.hidden_size,
325 | config.hidden_size,
326 | config.hidden_dropout_prob)
327 |
328 | def freeze_feature_extractor(self):
329 | for param in self.bert.parameters():
330 | param.requires_grad = False
331 |
332 | def forward(
333 | self,
334 | input_ids=None,
335 | attention_mask=None,
336 | token_type_ids=None,
337 | position_ids=None,
338 | head_mask=None,
339 | inputs_embeds=None,
340 | encoder_hidden_states=None,
341 | encoder_attention_mask=None,
342 | labels=None,
343 | output_attentions=None,
344 | output_hidden_states=True,
345 | ):
346 |
347 |
348 | outputs = self.bert(
349 | input_ids,
350 | attention_mask=attention_mask,
351 | token_type_ids=token_type_ids,
352 | position_ids=position_ids,
353 | head_mask=head_mask,
354 | inputs_embeds=inputs_embeds,
355 | encoder_hidden_states=encoder_hidden_states,
356 | encoder_attention_mask=encoder_attention_mask,
357 | output_attentions=output_attentions,
358 | output_hidden_states=output_hidden_states
359 | )
360 |
361 |
362 | prediction_scores = self.cnn(outputs.hidden_states[-1])
363 |
364 | masked_lm_loss = None
365 | if labels is not None:
366 | loss_fct = CrossEntropyLoss() # -100 index = padding token
367 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
368 |
369 | return MaskedLMOutput(
370 | loss=masked_lm_loss,
371 | logits=prediction_scores,
372 | hidden_states=outputs.hidden_states,
373 | attentions=outputs.attentions,
374 | )
375 |
376 |
377 |
378 | class Attention(nn.Module):
379 |
380 | def __init__(self,hidden_dim):
381 | super().__init__()
382 | self.q = nn.Linear(hidden_dim, hidden_dim)
383 | self.k = nn.Linear(hidden_dim, hidden_dim)
384 | # self.v = nn.Linear(hidden_dim*2, hidden_dim*2)
385 | self.layer_norm = nn.LayerNorm(hidden_dim)
386 |
387 | def forward(self,frame_hidden, phone_hidden,labels_attention_mask):
388 |
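        # cross-attention from acoustic frames (queries) to phone embeddings (keys):
        # energy[b, t, p] scores frame t against phone p, padded phones are masked out
        # before the softmax, and the attended phone vector is concatenated with the frame
        # representation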
389 | frame_hidden = self.q(frame_hidden)
390 | phone_hidden = self.k(phone_hidden)
391 |
392 | energy = torch.bmm(frame_hidden,phone_hidden.transpose(2,1))
393 | attention_mask = (1-labels_attention_mask)*-10000.0
394 | energy = energy+attention_mask.unsqueeze(1).repeat(1,energy.size(1),1)
395 |
396 | att_matrix = torch.softmax(energy,dim=-1)
397 | att_out = torch.bmm(att_matrix,phone_hidden)
398 | att_out = torch.cat([att_out,frame_hidden],dim=-1)
399 | # att_out = self.layer_norm(att_out + frame_hidden)
400 |
401 | return att_out, energy
402 |
403 |
404 | class Wav2Vec2ForFrameClassification(Wav2Vec2ForCTC):
405 |
406 | def forward(
407 | self,
408 | input_values,
409 | attention_mask=None,
410 | output_attentions=None,
411 | output_hidden_states=None,
412 | return_dict=None,
413 | labels=None,
414 | ):
415 |
416 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
417 |
418 | outputs = self.wav2vec2(
419 | input_values,
420 | attention_mask=attention_mask,
421 | output_attentions=output_attentions,
422 | output_hidden_states=output_hidden_states,
423 | return_dict=return_dict,
424 | )
425 |
426 | hidden_states = outputs[0]
427 | hidden_states = self.dropout(hidden_states)
428 |
429 | logits = self.lm_head(hidden_states)
430 |
431 | loss = None
432 | if labels is not None:
433 |
434 | if labels.max() >= self.config.vocab_size:
435 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
436 |
437 | # retrieve loss input_lengths from attention_mask
438 | attention_mask = (
439 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
440 | )
441 | # input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
442 |
443 | loss = nn.functional.cross_entropy(logits.view(-1,logits.size(2)), labels.flatten(), reduction="mean")
444 |
445 |
446 |
447 | if not return_dict:
448 | output = (logits,) + outputs[2:]
449 | return ((loss,) + output) if loss is not None else output
450 |
451 | return CausalLMOutput(
452 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
453 | )
454 |
455 |
456 |
457 | class Wav2Vec2ForCTCAndPretraining(Wav2Vec2ForPreTraining):
458 | '''
459 | Implementation adapted from: https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2ForPreTraining
460 | '''
461 | def __init__(self,config):
462 | super().__init__(config)
463 |
464 | self.dropout = nn.Dropout(config.final_dropout)
465 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
466 | self.cnn = ConvBank(config.vocab_size-1,config.hidden_size,[1],config.hidden_size,config.hidden_size,0.1)
467 | # was [1,3,5,7]
468 |
469 | def freeze_wav2vec2(self):
470 | for param in self.wav2vec2.parameters():
471 | param.requires_grad = False
472 |
473 | def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
474 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
475 | batch_size = attention_mask.shape[0]
476 |
477 | attention_mask = torch.zeros(
478 | (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
479 | )
480 |         # these two operations make sure that all values before the output lengths idxs are attended to
481 | attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
482 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
483 | return attention_mask
484 |
485 | @staticmethod
486 | def _sample_negatives(
487 | features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
488 | ):
489 | """
490 | Sample `num_negatives` vectors from feature vectors.
491 | """
492 | batch_size, sequence_length, hidden_size = features.shape
493 | if sequence_length <= 1:
494 | raise ValueError(
495 | f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
496 | )
497 |
498 | features = features.view(-1, hidden_size) # BTC => (BxT)C
499 |
500 | with torch.no_grad():
501 | # get `num_negatives` random vector indices from the same utterance
502 | sampled_negative_indices = []
503 | for batch_idx in range(batch_size):
504 | high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
505 | sampled_indices_slice = torch.randint(
506 | 0, high, size=(num_negatives * sequence_length,), device=features.device
507 | )
508 | sampled_negative_indices.append(sampled_indices_slice)
509 |
510 | sampled_negative_indices = torch.stack(sampled_negative_indices)
511 |
512 | # generate indices of the positive vectors themselves, repeat them `num_negatives` times
513 | feature_indices = (
514 | torch.arange(sequence_length, device=features.device)[:, None]
515 | .expand(sequence_length, num_negatives)
516 | .flatten()
517 | )
518 |
519 | # avoid sampling the same positive vector, but keep the distribution uniform
520 | sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
521 |
522 | # correct for batch size
523 | for batch_idx in range(1, batch_size):
524 | sampled_negative_indices[batch_idx] += batch_idx * sequence_length
525 |
526 | # take negative vectors from sampled indices
527 | sampled_negatives = features[sampled_negative_indices.view(-1)]
528 | sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
529 | 2, 0, 1, 3
530 | )
531 |
532 | return sampled_negatives
533 |
534 |
535 | def forward(
536 | self,
537 | input_values,
538 | labels=None,
539 | attention_mask=None,
540 | mask_time_indices=None,
541 | output_attentions=None,
542 | output_hidden_states=None,
543 | return_dict=None,
544 | ):
545 |
546 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
547 |
548 | if mask_time_indices is not None:
549 | mask_time_indices = mask_time_indices.to(torch.bool)
550 |
551 | outputs = self.wav2vec2(
552 | input_values,
553 | attention_mask=attention_mask,
554 | output_attentions=output_attentions,
555 | output_hidden_states=output_hidden_states,
556 | mask_time_indices=mask_time_indices,
557 | )
558 |
559 | # get CTC loss
560 | hidden_states = self.dropout(outputs[0])
561 | ctc_logits = self.lm_head(hidden_states)
562 |
563 | loss = None
564 | ctc_loss = None
565 | contrastive_loss = None
566 | codevector_perplexity = None
567 | if labels is not None:
568 |
569 | if labels.max() >= self.config.vocab_size:
570 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
571 |
572 | # retrieve loss input_lengths from attention_mask
573 | attention_mask = (
574 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
575 | )
576 | input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
577 |
578 | # assuming that padded tokens are filled with -100
579 | # when not being attended to
580 | labels_mask = labels >= 0
581 | target_lengths = labels_mask.sum(-1)
582 | flattened_targets = labels.masked_select(labels_mask)
583 |
584 | # ctc_loss doesn't support fp16
585 | log_probs = nn.functional.log_softmax(ctc_logits, dim=-1, dtype=torch.float32).transpose(0, 1)
586 |
587 | with torch.backends.cudnn.flags(enabled=False):
588 | ctc_loss = nn.functional.ctc_loss(
589 | log_probs,
590 | flattened_targets,
591 | input_lengths,
592 | target_lengths,
593 | blank=self.config.pad_token_id,
594 | reduction=self.config.ctc_loss_reduction,
595 | zero_infinity=self.config.ctc_zero_infinity,
596 | )
597 |
598 | # start masked modeling
599 | # 0. remove the blank symbol
600 | # 1. project all transformed features (including masked) to final vq dim
601 | transformer_features = self.project_hid(self.cnn(ctc_logits[:,:,:-1]))
602 |
603 |
604 | # 2. quantize all (unmasked) extracted features and project to final vq dim
605 | extract_features = self.dropout_features(outputs[1])
606 | quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
607 | quantized_features = self.project_q(quantized_features)
608 |
609 | loss = None
610 | if self.training:
611 | # for training, we sample negatives
612 | # 3. sample K negatives (distractors) quantized states for contrastive loss
613 | # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
614 | if attention_mask is not None:
615 |                 # compute reduced attention_mask corresponding to feature vectors
616 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
617 |
618 | negative_quantized_features = self._sample_negatives(
619 | quantized_features, self.config.num_negatives, attention_mask=attention_mask
620 | )
621 |
622 | # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
623 | # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
624 | logits = self.compute_contrastive_logits(
625 | quantized_features[None, :],
626 | negative_quantized_features,
627 | transformer_features,
628 | self.config.contrastive_logits_temperature,
629 | )
630 |
631 | # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
632 | # its cosine similarity will be masked
633 | neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
634 | if neg_is_pos.any():
635 | logits[1:][neg_is_pos] = float("-inf")
636 |
637 | # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
638 | # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
639 | preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
640 | target = ((1 - attention_mask.long()) * -100).transpose(0, 1).flatten()
641 | contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="mean")
642 |
643 | # 7. compute diversity loss: \mathbf{L}_d
644 | num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
645 | diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
646 |
647 | # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
648 | contrastive_loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
649 |
650 | loss = ctc_loss + contrastive_loss
651 |
652 | return Wav2Vec2ForCTCAndPretrainingOutput(
653 | loss=loss,
654 | ctc_logits = ctc_logits,
655 | ctc_loss = ctc_loss,
656 | contrastive_loss = contrastive_loss,
657 | codevector_perplexity=codevector_perplexity,
658 | projected_states=transformer_features,
659 | projected_quantized_states=quantized_features,
660 | hidden_states=outputs.hidden_states,
661 | attentions=outputs.attentions,
662 | )
663 |
664 |
665 |
--------------------------------------------------------------------------------