├── src │   ├── models │   │   ├── __init__.py │   │   ├── hipama.py │   │   └── gopt.py │   ├── run_hipama.sh │   ├── run_gopt.sh │   ├── collect_summary.py │   ├── get_summary.py │   ├── prep_data │   │   ├── so762_stats │   │   ├── gen_seq_data_word.py │   │   ├── gen_seq_data_utt.py │   │   └── gen_seq_data_phn.py │   ├── custom_layers │   │   └── attention.py │   └── traintest.py ├── requirements.txt ├── data │   └── README.md ├── LICENSE └── README.md /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gopt import * 2 | from .hipama import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | kaldi-io==0.9.4 3 | kaldiio==2.17.2 4 | numpy==1.20.3 5 | pandas==1.5.0 -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | We used the same dataset as the baseline model, GOPT. 2 | 3 | Please download and configure the data in the same way as described in the GOPT repository: [https://github.com/YuanGongND/gopt/tree/master/data](https://github.com/YuanGongND/gopt/tree/master/data). 4 | -------------------------------------------------------------------------------- /src/run_hipama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##SBATCH -p sm 3 | ##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-5 4 | #SBATCH -p gpu 5 | #SBATCH -x sls-titan-[0-2] 6 | #SBATCH --gres=gpu:1 7 | #SBATCH -c 4 8 | #SBATCH -n 1 9 | #SBATCH --mem=24000 10 | #SBATCH --job-name="gopt" 11 | #SBATCH --output=../exp/log_%j.txt 12 | 13 | lr=1e-3 14 | depth=3 15 | head=1 16 | batch_size=25 17 | embed_dim=24 18 | model=hipama 19 | am=librispeech 20 | 21 | exp_dir=../exp/gopt-${lr}-${depth}-${head}-${batch_size}-${embed_dim}-${model}-${am}-br 22 | 23 | # repeat times 24 | repeat_list=(0 1 2 3 4) 25 | 26 | for repeat in "${repeat_list[@]}" 27 | do 28 | mkdir -p $exp_dir-${repeat} 29 | python ./traintest.py --lr ${lr} --exp-dir ${exp_dir}-${repeat} --goptdepth ${depth} --goptheads ${head} \ 30 | --batch_size ${batch_size} --embed_dim ${embed_dim} \ 31 | --model ${model} --am ${am} 32 | done 33 | 34 | python ./collect_summary.py --exp-dir $exp_dir -------------------------------------------------------------------------------- /src/run_gopt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##SBATCH -p sm 3 | ##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-5 4 | #SBATCH -p gpu 5 | #SBATCH -x sls-titan-[0-2] 6 | #SBATCH --gres=gpu:1 7 | #SBATCH -c 4 8 | #SBATCH -n 1 9 | #SBATCH --mem=24000 10 | #SBATCH --job-name="gopt" 11 | #SBATCH --output=../exp/log_%j.txt 12 | 13 | set -x 14 | # comment this line if not running on sls cluster 15 | . 
/data/sls/scratch/share-201907/slstoolchainrc 16 | source ../venv-gopt/bin/activate 17 | 18 | lr=1e-3 19 | depth=3 20 | head=1 21 | batch_size=25 22 | embed_dim=24 23 | model=gopt 24 | am=librispeech 25 | 26 | exp_dir=../exp/gopt-${lr}-${depth}-${head}-${batch_size}-${embed_dim}-${model}-${am}-br 27 | 28 | # repeat times 29 | repeat_list=(0 1 2 3 4) 30 | 31 | for repeat in "${repeat_list[@]}" 32 | do 33 | mkdir -p $exp_dir-${repeat} 34 | python ./traintest.py --lr ${lr} --exp-dir ${exp_dir}-${repeat} --goptdepth ${depth} --goptheads ${head} \ 35 | --batch_size ${batch_size} --embed_dim ${embed_dim} \ 36 | --model ${model} --am ${am} 37 | done 38 | 39 | python ./collect_summary.py --exp-dir $exp_dir -------------------------------------------------------------------------------- /src/collect_summary.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/23/21 11:33 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : collect_summary.py 7 | 8 | # collect summary of the repeated experiments. 9 | 10 | import argparse 11 | import os 12 | import numpy as np 13 | 14 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument("--exp-dir", type=str, default="./test", help="directory to dump experiments") 16 | args = parser.parse_args() 17 | 18 | result = [] 19 | # for each repeat experiment 20 | for i in range(0, 10): 21 | cur_exp_dir = args.exp_dir + '-' + str(i) 22 | if os.path.isfile(cur_exp_dir + '/result.csv'): 23 | try: 24 | print(cur_exp_dir) 25 | cur_res = np.loadtxt(cur_exp_dir + '/result.csv', delimiter=',', skiprows=1) # skip the header row written by traintest.py 26 | for last_epoch in range(cur_res.shape[0] - 1, -1, -1): 27 | if cur_res[last_epoch, 0] > 5e-03: 28 | break 29 | result.append(cur_res[last_epoch, :]) 30 | except: 31 | pass 32 | 33 | result = np.array(result) 34 | # get mean / std of the repeat experiments. 
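# columns of result.csv (and of the summary saved below) follow the header written by gen_result_header() in traintest.py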
35 | result_mean = np.mean(result, axis=0) 36 | result_std = np.std(result, axis=0) 37 | if os.path.exists(args.exp_dir) == False: 38 | os.mkdir(args.exp_dir) 39 | np.savetxt(args.exp_dir + '/result_summary.csv', [result_mean, result_std], delimiter=',') -------------------------------------------------------------------------------- /src/get_summary.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/21/21 9:50 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : get_summary.py 7 | 8 | import os 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | from datetime import datetime 12 | 13 | def get_immediate_subdirectories(a_dir): 14 | return [name for name in os.listdir(a_dir) if os.path.isdir(os.path.join(a_dir, name))] 15 | 16 | requirement_list = {'gopt': 0.0, 'lstm': 0.0} 17 | 18 | # second pass 19 | for requirement in requirement_list.keys(): 20 | threshold = requirement_list[requirement] 21 | result = [] 22 | root_path = '../exp/' 23 | exp_list = get_immediate_subdirectories(root_path) 24 | exp_list.sort() 25 | for exp in exp_list: 26 | if requirement in exp and os.path.isfile(root_path + exp + '/result_summary.csv'): 27 | try: 28 | print(exp) 29 | cur_res = np.loadtxt(root_path + exp + '/result_summary.csv', delimiter=',')[0] 30 | print(cur_res) 31 | test_mse = cur_res[2] 32 | test_corr = cur_res[3] 33 | te_utt_corr = cur_res[20:25] 34 | te_word_corr = cur_res[28:31] 35 | 36 | print(te_utt_corr) 37 | result.append([exp, test_mse, test_corr, te_utt_corr[4], te_word_corr[2]]) 38 | except: 39 | pass 40 | 41 | np.savetxt('../exp/' + requirement + '_summary_brief.csv', result, delimiter=',', fmt='%s') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Yuan Gong 4 | Copyright (c) 2023, Heejin Do 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /src/prep_data/so762_stats: -------------------------------------------------------------------------------- 1 | Training Set 2 | 3 | 47076 total phns, 2500 utterance, 15849 words 4 | 5 | Phn Distribution 6 | {3.0: 468, 4.0: 1390, 5.0: 4393, 6.0: 747, 7.0: 338, 8.0: 1236, 9.0: 901, 10.0: 165, 11.0: 1806, 12.0: 1532, 13.0: 1267, 14.0: 632, 15.0: 805, 16.0: 859, 17.0: 575, 18.0: 1093, 19.0: 2980, 20.0: 1851, 21.0: 189, 22.0: 1380, 23.0: 1784, 24.0: 1515, 25.0: 3002, 26.0: 534, 27.0: 776, 28.0: 36, 29.0: 699, 30.0: 1245, 31.0: 2073, 32.0: 353, 33.0: 4019, 34.0: 301, 35.0: 568, 36.0: 1356, 37.0: 807, 38.0: 1586, 39.0: 577, 40.0: 1232, 41.0: 6} 7 | 8 | Feature data mean: 3.203 and std: 4.045. 9 | 10 | Phn Score Distribution (rounded by 0.1) 11 | {0.0: 1344, 0.2: 44, 0.4: 603, 0.6: 85, 0.8: 488, 1.0: 155, 1.2: 813, 1.4000000000000001: 426, 1.6: 2416, 1.8: 2995, 2.0: 37707} 12 | 13 | Word Score Distribution 14 | 15 | Accuracy {'0': 6, '1': 18, '10': 40198, '2': 124, '3': 3544, '5': 1668, '6': 221, '7': 106, '8': 1180, '9': 11} 16 | 17 | Stress: {'10': 46329, '5': 747} 18 | 19 | Total {'10': 39225, '2': 42, '3': 1103, '4': 2547, '5': 256, '6': 1645, '7': 144, '8': 1130, '9': 984} 20 | 21 | Test Set 22 | 23 | 47369 total phns, 15967 words 24 | 25 | {3.0: 416, 4.0: 1465, 5.0: 4437, 6.0: 750, 7.0: 352, 8.0: 1228, 9.0: 846, 10.0: 181, 11.0: 1866, 12.0: 1495, 13.0: 1271, 14.0: 695, 15.0: 800, 16.0: 855, 17.0: 543, 18.0: 1081, 19.0: 2927, 20.0: 1868, 21.0: 179, 22.0: 1429, 23.0: 1858, 24.0: 1464, 25.0: 3012, 26.0: 565, 27.0: 782, 28.0: 20, 29.0: 690, 30.0: 1200, 31.0: 2157, 32.0: 346, 33.0: 3958, 34.0: 348, 35.0: 563, 36.0: 1409, 37.0: 810, 38.0: 1562, 39.0: 625, 40.0: 1297, 41.0: 19} 26 | 27 | Feat Mean: 3.212 and std: 4.045. 28 | 29 | Phn Score Distribution (rounded by 0.1) 30 | {0.0: 943, 0.2: 31, 0.4: 438, 0.6: 88, 0.8: 478, 1.0: 171, 1.2: 688, 1.4: 435, 1.6: 2456, 1.8: 3146, 2.0: 38495} 31 | 32 | 31816 words in total -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HiPAMA 2 | 3 | This repository is the implementation of the paper, [**Hierarchical Pronunciation Assessment with Multi-Aspect Attention**](https://ieeexplore.ieee.org/document/10095733/) (ICASSP 2023). 4 | 5 | > Our code is based on the open source, [https://github.com/YuanGongND/gopt](https://github.com/YuanGongND/gopt) (Gong et al, 2022). 6 | 7 | ## Citation 8 | Please cite our paper if you find this repository helpful. 
9 | 10 | ``` 11 | @INPROCEEDINGS{10095733, 12 | author={Do, Heejin and Kim, Yunsu and Lee, Gary Geunbae}, 13 | booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 14 | title={Hierarchical Pronunciation Assessment with Multi-Aspect Attention}, 15 | year={2023}, 16 | volume={}, 17 | number={}, 18 | pages={1-5}, 19 | doi={10.1109/ICASSP49357.2023.10095733}} 20 | ``` 21 | 22 | ## Dataset 23 | 24 | An open-source dataset, SpeechOcean762 (licensed under CC BY 4.0), is used. You can download it from [https://www.openslr.org/101](https://www.openslr.org/101). 25 | 26 | ## Package Requirements 27 | 28 | Install the packages below in your virtual environment before running the code. 29 | - python version 3.8.10 30 | - pytorch version '1.13.1+cu117' 31 | - numpy version 1.20.3 32 | - pandas version 1.5.0 33 | 34 | You can run the command below in your virtual environment: 35 | - `pip install -r requirements.txt` 36 | 37 | ## Training and Evaluation (HiPAMA) 38 | This bash script will run the model 5 times (repeats [0, 1, 2, 3, 4]). 39 | - `cd src` 40 | - `bash run_hipama.sh` 41 | 42 | Note that individual runs do not produce identical results due to random elements. 43 | 44 | The reported results in the paper are the averages of the final-epoch results over five different seeds. 45 | 46 | ## Run baseline (GOPT) 47 | This bash script will run the model 5 times (repeats [0, 1, 2, 3, 4]). 48 | - `cd src` 49 | - `bash run_gopt.sh` 50 | -------------------------------------------------------------------------------- /src/custom_layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class Aspect_Attention_op2(nn.Module): 8 | def __init__(self, embed_dim, op='none', activation='tanh', init_stdev=0.01): 9 | super().__init__() 10 | self.supports_masking = True 11 | assert op in {'attsum', 'attmean','none'} 12 | assert activation in {None, 'tanh'} 13 | self.op = op 14 | self.activation = activation 15 | self.init_stdev = init_stdev 16 | self._reset_parameters(embed_dim) 17 | 18 | def _reset_parameters(self, embed_dim): 19 | init_val_v = (torch.randn(embed_dim) * self.init_stdev) 20 | self.att_v = nn.Parameter(init_val_v) 21 | init_val_W = (torch.randn(embed_dim, embed_dim) * self.init_stdev) 22 | self.att_W = nn.Parameter(init_val_W) 23 | self.built = True 24 | 25 | def forward(self, x, x2, mask=None): 26 | 27 | y = torch.matmul(x2, self.att_W) 28 | 29 | if not self.activation: 30 | weights = torch.tensordot(self.att_v, y, dims=([0], [2])) 31 | elif self.activation == 'tanh': 32 | weights = torch.tensordot(self.att_v, torch.tanh(y), dims=([0], [2])) 33 | 34 | weights = F.softmax(weights, dim=0) 35 | out = x2 * weights.repeat(1, x2.shape[2]).reshape(weights.shape[0], weights.shape[1], x2.shape[2]) 36 | 37 | batch_size, hidden_dim, input_size = x.size(0), x.size(2), x2.size(1) 38 | 39 | self.score = torch.bmm(x, out.transpose(1, 2)) 40 | self.attn = F.softmax(self.score.view(-1, input_size), dim=1).view(batch_size, -1, input_size) 41 | context = torch.bmm(self.attn, out) 42 | 43 | return context 44 | 45 | def get_output_shape_for(self, input_shape): 46 | return (input_shape[0], input_shape[2]) 47 | 48 | def compute_output_shape(self, input_shape): 49 | return (input_shape[0], input_shape[2]) 50 | 51 | def compute_mask(self, x, mask): 52 | return None 53 | 54 | def get_config(self): 55 | config = 
{'op': self.op, 'activation': self.activation, 'init_stdev': self.init_stdev} 56 | # nn.Module does not provide get_config(); return this layer's own config directly 57 | return config 58 | -------------------------------------------------------------------------------- /src/prep_data/gen_seq_data_word.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/19/21 11:13 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_seq_data_word.py 7 | 8 | # Generate sequence word labels for seq2seq models from raw Kaldi GOP features. 9 | 10 | import numpy as np 11 | import json 12 | 13 | def load_feat(path): 14 | file = np.loadtxt(path, delimiter=',') 15 | return file 16 | 17 | def load_keys(path): 18 | file = np.loadtxt(path, delimiter=',', dtype=str) 19 | return file 20 | 21 | def load_label(path): 22 | file = np.loadtxt(path, delimiter=',', dtype=str) 23 | return file 24 | 25 | def process_label(label): 26 | pure_label = [] 27 | for i in range(0, label.shape[0]): 28 | pure_label.append(float(label[i, 1])) 29 | return np.array(pure_label) 30 | 31 | def process_feat_seq_word(feat, keys, labels): 32 | key_set = [] 33 | for i in range(keys.shape[0]): 34 | cur_key = keys[i].split('.')[0] 35 | key_set.append(cur_key) 36 | 37 | utt_cnt = len(list(set(key_set))) 38 | print('In total utterance number : ' + str(utt_cnt)) 39 | 40 | # -1 means n/a 41 | seq_label = np.zeros([utt_cnt, 50, 4]) - 1 42 | 43 | prev_utt_id = keys[0].split('.')[0] 44 | 45 | row = 0 46 | for i in range(feat.shape[0]): 47 | cur_utt_id, cur_tok_id = keys[i].split('.')[0], int(keys[i].split('.')[1]) 48 | if cur_utt_id != prev_utt_id: 49 | row += 1 50 | prev_utt_id = cur_utt_id 51 | 52 | seq_label[row, cur_tok_id, 0:3] = labels[i, 3:6] 53 | seq_label[row, cur_tok_id, 3] = labels[i, 1] 54 | 55 | return seq_label 56 | 57 | # utt label dict 58 | with open('scores.json') as f: 59 | utt2score = json.loads(f.read()) 60 | 61 | # sequence training data 62 | tr_feat = load_feat('../../data/raw_kaldi_gop/librispeech/tr_feats.csv') 63 | tr_keys = load_keys('../../data/raw_kaldi_gop/librispeech/tr_keys_word.csv') 64 | tr_label = load_label('../../data/raw_kaldi_gop/librispeech/tr_labels_word.csv') 65 | tr_label = process_feat_seq_word(tr_feat, tr_keys, tr_label) 66 | print(tr_label.shape) 67 | np.save('../../data/seq_data_librispeech/tr_label_word.npy', tr_label) 68 | 69 | # sequence test data 70 | te_feat = load_feat('../../data/raw_kaldi_gop/librispeech/te_feats.csv') 71 | te_keys = load_keys('../../data/raw_kaldi_gop/librispeech/te_keys_word.csv') 72 | te_label = load_label('../../data/raw_kaldi_gop/librispeech/te_labels_word.csv') 73 | te_label = process_feat_seq_word(te_feat, te_keys, te_label) 74 | print(te_label.shape) 75 | np.save('../../data/seq_data_librispeech/te_label_word.npy', te_label) 76 | -------------------------------------------------------------------------------- /src/prep_data/gen_seq_data_utt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/19/21 11:13 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_seq_data_utt.py 7 | 8 | # Generate sequence utterance labels for seq2seq models from raw Kaldi GOP features. 
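# Output: one row per utterance holding the five utterance-level scores from scores.json, in the order [accuracy, completeness, fluency, prosodic, total] (see process_feat_seq_utt below).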
9 | 10 | import numpy as np 11 | import json 12 | 13 | def load_feat(path): 14 | file = np.loadtxt(path, delimiter=',') 15 | return file 16 | 17 | def load_keys(path): 18 | file = np.loadtxt(path, delimiter=',', dtype=str) 19 | return file 20 | 21 | def load_label(path): 22 | file = np.loadtxt(path, delimiter=',', dtype=str) 23 | return file 24 | 25 | def process_label(label): 26 | pure_label = [] 27 | for i in range(0, label.shape[0]): 28 | pure_label.append(float(label[i, 1])) 29 | return np.array(pure_label) 30 | 31 | def process_feat_seq_utt(feat, keys, utt2score): 32 | key_set = [] 33 | for i in range(keys.shape[0]): 34 | cur_key = keys[i].split('.')[0] 35 | key_set.append(cur_key) 36 | 37 | utt_cnt = len(list(set(key_set))) 38 | print('In total utterance number : ' + str(utt_cnt)) 39 | 40 | seq_label = np.zeros([utt_cnt, 5]) 41 | 42 | prev_utt_id = keys[0].split('.')[0] 43 | 44 | row = 0 45 | for i in range(feat.shape[0]): 46 | cur_utt_id, cur_tok_id = keys[i].split('.')[0], int(keys[i].split('.')[1]) 47 | if cur_utt_id != prev_utt_id: 48 | row += 1 49 | prev_utt_id = cur_utt_id 50 | 51 | seq_label[row, 0] = utt2score[cur_utt_id]['accuracy'] 52 | seq_label[row, 1] = utt2score[cur_utt_id]['completeness'] 53 | seq_label[row, 2] = utt2score[cur_utt_id]['fluency'] 54 | seq_label[row, 3] = utt2score[cur_utt_id]['prosodic'] 55 | seq_label[row, 4] = utt2score[cur_utt_id]['total'] 56 | 57 | return seq_label 58 | 59 | # utt label dict 60 | with open('scores.json') as f: 61 | utt2score = json.loads(f.read()) 62 | 63 | # generate sequence training data 64 | tr_feat = load_feat('../../data/raw_kaldi_gop/librispeech/tr_feats.csv') 65 | tr_keys = load_keys('../../data/raw_kaldi_gop/librispeech/tr_keys_phn.csv') 66 | tr_label = process_feat_seq_utt(tr_feat, tr_keys, utt2score) 67 | print(tr_label.shape) 68 | np.save('../../data/seq_data_librispeech/tr_label_utt.npy', tr_label) 69 | 70 | # generate sequence test data 71 | te_feat = load_feat('../../data/raw_kaldi_gop/librispeech/te_feats.csv') 72 | te_keys = load_keys('../../data/raw_kaldi_gop/librispeech/te_keys_phn.csv') 73 | te_label = process_feat_seq_utt(te_feat, te_keys, utt2score) 74 | print(te_label.shape) 75 | np.save('../../data/seq_data_librispeech/te_label_utt.npy', te_label) -------------------------------------------------------------------------------- /src/prep_data/gen_seq_data_phn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/19/21 11:13 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gen_seq_data_phn.py 7 | 8 | # Generate sequence phone input and label for seq2seq models from raw Kaldi GOP features. 
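# Outputs: (tr|te)_feat.npy with shape [utt_cnt, 50, feat_dim] and (tr|te)_label_phn.npy with shape [utt_cnt, 50, 2] holding (phone id, phone score), saved under ../../data/seq_data_librispeech/.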
9 | 10 | import numpy as np 11 | 12 | def load_feat(path): 13 | file = np.loadtxt(path, delimiter=',') 14 | return file 15 | 16 | def load_keys(path): 17 | file = np.loadtxt(path, delimiter=',', dtype=str) 18 | return file 19 | 20 | def load_label(path): 21 | file = np.loadtxt(path, delimiter=',', dtype=str) 22 | return file 23 | 24 | def process_label(label): 25 | pure_label = [] 26 | for i in range(0, label.shape[0]): 27 | pure_label.append(float(label[i, 1])) 28 | return np.array(pure_label) 29 | 30 | def process_feat_seq(feat, keys, labels, phn_dict): 31 | key_set = [] 32 | for i in range(keys.shape[0]): 33 | cur_key = keys[i].split('.')[0] 34 | key_set.append(cur_key) 35 | 36 | feat_dim = feat.shape[1] - 1 37 | 38 | utt_cnt = len(list(set(key_set))) 39 | print('In total utterance number : ' + str(utt_cnt)) 40 | 41 | # Pad all sequence to 50 because the longest sequence of the so762 dataset is shorter than 50. 42 | seq_feat = np.zeros([utt_cnt, 50, feat_dim]) 43 | # -1 means n/a, padded token 44 | # [utt, seq_len, 0] is the phone label, and the [utt, seq_len, 1] is the score label 45 | seq_label = np.zeros([utt_cnt, 50, 2]) - 1 46 | 47 | # the key format is utt_id.phn_id 48 | prev_utt_id = keys[0].split('.')[0] 49 | 50 | row = 0 51 | for i in range(feat.shape[0]): 52 | cur_utt_id, cur_tok_id = keys[i].split('.')[0], int(keys[i].split('.')[1]) 53 | # if a new sequence, start a new row of the feature vector. 54 | if cur_utt_id != prev_utt_id: 55 | row += 1 56 | prev_utt_id = cur_utt_id 57 | 58 | # The first element is the phone label. 59 | seq_feat[row, cur_tok_id, :] = feat[i, 1:] 60 | 61 | # [utt, seq_len, 0] is the phone label 62 | seq_label[row, cur_tok_id, 0] = phn_dict[labels[i, 0]] 63 | # [utt, seq_len, 1] is the score label, range from 0-2 64 | seq_label[row, cur_tok_id, 1] = labels[i, 1] 65 | 66 | return seq_feat, seq_label 67 | 68 | def gen_phn_dict(label): 69 | phn_dict = {} 70 | phn_idx = 0 71 | for i in range(label.shape[0]): 72 | if label[i, 0] not in phn_dict: 73 | phn_dict[label[i, 0]] = phn_idx 74 | phn_idx += 1 75 | return phn_dict 76 | 77 | # generate sequence training data 78 | tr_feat = load_feat('../../data/raw_kaldi_gop/librispeech/tr_feats.csv') 79 | tr_keys = load_keys('../../data/raw_kaldi_gop/librispeech/tr_keys_phn.csv') 80 | tr_label = load_label('../../data/raw_kaldi_gop/librispeech/tr_labels_phn.csv') 81 | phn_dict = gen_phn_dict(tr_label) 82 | print(phn_dict) 83 | tr_feat, tr_label = process_feat_seq(tr_feat, tr_keys, tr_label, phn_dict) 84 | print(tr_feat.shape) 85 | print(tr_label.shape) 86 | np.save('../../data/seq_data_librispeech/tr_feat.npy', tr_feat) 87 | np.save('../../data/seq_data_librispeech/tr_label_phn.npy', tr_label) 88 | 89 | # generate sequence test data 90 | te_feat = load_feat('../../data/raw_kaldi_gop/librispeech/te_feats.csv') 91 | te_keys = load_keys('../../data/raw_kaldi_gop/librispeech/te_keys_phn.csv') 92 | te_label = load_label('../../data/raw_kaldi_gop/librispeech/te_labels_phn.csv') 93 | te_feat, te_label = process_feat_seq(te_feat, te_keys, te_label, phn_dict) 94 | print(te_feat.shape) 95 | print(te_label.shape) 96 | np.save('../../data/seq_data_librispeech/te_feat.npy', te_feat) 97 | np.save('../../data/seq_data_librispeech/te_label_phn.npy', te_label) 98 | -------------------------------------------------------------------------------- /src/models/hipama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from 
custom_layers.attention import Aspect_Attention_op2 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 8 | super().__init__() 9 | self.num_heads = num_heads 10 | head_dim = dim // num_heads 11 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 12 | self.scale = qk_scale or head_dim ** -0.5 13 | 14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 15 | self.attn_drop = nn.Dropout(attn_drop) 16 | self.proj = nn.Linear(dim, dim) 17 | self.proj_drop = nn.Dropout(proj_drop) 18 | 19 | def forward(self, x): 20 | B, N, C = x.shape 21 | #print(C) 22 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 23 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 24 | 25 | attn = (q @ k.transpose(-2, -1)) * self.scale 26 | attn = attn.softmax(dim=-1) 27 | attn = self.attn_drop(attn) 28 | 29 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 30 | x = self.proj(x) 31 | x = self.proj_drop(x) 32 | return x 33 | 34 | class HiPAMA(nn.Module): 35 | def __init__(self, embed_dim, depth, input_dim=84, num_heads=4): 36 | super().__init__() 37 | self.input_dim = input_dim 38 | self.embed_dim = embed_dim 39 | self.num_heads = num_heads 40 | self.conv_dim = 25 41 | 42 | # phone projection 43 | self.phn_proj = nn.Linear(40, embed_dim) 44 | 45 | # for phone classification 46 | self.in_proj = nn.Linear(self.input_dim, embed_dim) 47 | 48 | self.lstm = torch.nn.LSTM(input_size=embed_dim, hidden_size=embed_dim, num_layers=depth, batch_first=True) 49 | self.attn = Attention( 50 | embed_dim, num_heads=num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.) # Multi-head self-attention 51 | self.conv = torch.nn.Conv1d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=5, padding=2) # seq_len, 52 | 53 | self.mlp_head_phn = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 54 | 55 | # for word classification 56 | self.rep_w1 = nn.Linear(embed_dim, embed_dim) 57 | self.rep_w2 = nn.Linear(embed_dim, embed_dim) 58 | self.rep_w3 = nn.Linear(embed_dim, embed_dim) 59 | 60 | self.attn_tmp = Aspect_Attention_op2(embed_dim) 61 | 62 | self.mlp_head_word1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 63 | self.mlp_head_word2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 64 | self.mlp_head_word3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 65 | 66 | self.w_attn = Attention( 67 | embed_dim, num_heads=num_heads, qkv_bias=False, qk_scale=None, attn_drop=0.2, proj_drop=0.) 
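# self-attention applied over the averaged word-aspect representations; in forward(), its output feeds the five utterance-level aspect projections and heads below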
68 | 69 | # utterance level 70 | self.rep_utt1 = nn.Linear(embed_dim, embed_dim) 71 | self.rep_utt2 = nn.Linear(embed_dim, embed_dim) 72 | self.rep_utt3 = nn.Linear(embed_dim, embed_dim) 73 | self.rep_utt4 = nn.Linear(embed_dim, embed_dim) 74 | self.rep_utt5 = nn.Linear(embed_dim, embed_dim) 75 | self.utt_attn_tmp = Aspect_Attention_op2(embed_dim) 76 | 77 | self.mlp_head_utt1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 78 | self.mlp_head_utt2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 79 | self.mlp_head_utt3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 80 | self.mlp_head_utt4 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 81 | self.mlp_head_utt5 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 82 | 83 | # get the output of the last valid token 84 | def get_last_valid(self, input, mask): 85 | output = [] 86 | B = input.shape[0] 87 | seq_len = input.shape[1] 88 | for i in range(B): 89 | for j in range(seq_len): 90 | if mask[i, j] == 0: 91 | output.append(input[i, j-1]) 92 | break 93 | if j == seq_len - 1: 94 | print('append') 95 | output.append(input[i, j]) 96 | output = torch.stack(output, dim=0) 97 | return output.unsqueeze(1) 98 | 99 | # x shape in [batch_size, sequence_len, feat_dim] 100 | # phn in [batch_size, seq_len] 101 | def forward(self, x, phn): 102 | 103 | # batch size 104 | B = x.shape[0] 105 | seq_len = x.shape[1] 106 | valid_tok_mask = (phn>=0) 107 | 108 | # phn_one_hot in shape [batch_size, seq_len, feat_dim] 109 | phn_one_hot = torch.nn.functional.one_hot(phn.long()+1, num_classes=40).float() 110 | # phn_embed in shape [batch_size, seq_len, embed_dim] 111 | phn_embed = self.phn_proj(phn_one_hot) 112 | 113 | if self.embed_dim != self.input_dim: 114 | x = self.in_proj(x) 115 | 116 | x = x + phn_embed 117 | self.lstm.flatten_parameters() 118 | x = self.lstm(x)[0] 119 | 120 | x = self.attn(x) 121 | 122 | x = self.conv(x.transpose(1,2)) # x.transpose in shape [batch, feat_dim, seq_len] 123 | x = x.transpose(1,2) 124 | 125 | ### Fisrt output phn score 126 | p = self.mlp_head_phn(x).reshape(B, seq_len, 1) 127 | 128 | ### Second output word score with aspects attention 129 | w1 = self.rep_w1(x) 130 | w2 = self.rep_w2(x) 131 | w3 = self.rep_w3(x) 132 | 133 | w_list = (w1, w2, w3) 134 | w_attns = [] 135 | for i in range(len(w_list)): 136 | target_w = w_list[i] 137 | non_target_w = torch.cat((w_list[:i] + w_list[i+1:]), dim=1) 138 | w_attn = self.attn_tmp(target_w, non_target_w) 139 | w = target_w + w_attn 140 | w_attns.append(w) 141 | 142 | w1 = self.mlp_head_word1(w_attns[0]).reshape(B, seq_len, 1) 143 | w2 = self.mlp_head_word2(w_attns[1]).reshape(B, seq_len, 1) 144 | w3 = self.mlp_head_word3(w_attns[2]).reshape(B, seq_len, 1) 145 | 146 | ### Third output utterance score using words representation 147 | rep = (w_attns[0] + w_attns[1] + w_attns[2]) / 3 148 | rep = self.w_attn(rep) 149 | 150 | u1 = self.rep_utt1(rep) 151 | u2 = self.rep_utt2(rep) 152 | u3 = self.rep_utt3(rep) 153 | u4 = self.rep_utt4(rep) 154 | u5 = self.rep_utt5(rep) 155 | 156 | utt_list = (u1, u2, u3, u4, u5) 157 | utt_attns = [] 158 | for i in range(len(utt_list)): 159 | target_utt = utt_list[i] 160 | non_target_utt = torch.cat((utt_list[:i] + utt_list[i+1:]), dim=1) 161 | utt_attn = self.utt_attn_tmp(target_utt, non_target_utt) 162 | utt = target_utt + utt_attn 163 | utt_attns.append(utt) 164 | 165 | 166 | u1 = self.get_last_valid(self.mlp_head_utt1(utt_attns[0]).reshape(B, seq_len), valid_tok_mask) 167 | u2 = 
self.get_last_valid(self.mlp_head_utt2(utt_attns[1]).reshape(B, seq_len), valid_tok_mask) 168 | u3 = self.get_last_valid(self.mlp_head_utt3(utt_attns[2]).reshape(B, seq_len), valid_tok_mask) 169 | u4 = self.get_last_valid(self.mlp_head_utt4(utt_attns[3]).reshape(B, seq_len), valid_tok_mask) 170 | u5 = self.get_last_valid(self.mlp_head_utt5(utt_attns[4]).reshape(B, seq_len), valid_tok_mask) 171 | 172 | return u1, u2, u3, u4, u5, p, w1, w2, w3 -------------------------------------------------------------------------------- /src/models/gopt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 10/22/21 1:23 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : gopt.py 7 | 8 | # attention part is borrowed from the timm package. 9 | 10 | import math 11 | import warnings 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | # code from the t2t-vit paper 17 | def get_sinusoid_encoding(n_position, d_hid): 18 | ''' Sinusoid position encoding table ''' 19 | 20 | def get_position_angle_vec(position): 21 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 25 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 26 | 27 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 28 | 29 | 30 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 31 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 32 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 33 | def norm_cdf(x): 34 | # Computes standard normal cumulative distribution function 35 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 36 | 37 | if (mean < a - 2 * std) or (mean > b + 2 * std): 38 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 39 | "The distribution of values may be incorrect.", 40 | stacklevel=2) 41 | 42 | with torch.no_grad(): 43 | # Values are generated by using a truncated uniform distribution and 44 | # then using the inverse CDF for the normal distribution. 45 | # Get upper and lower cdf values 46 | l = norm_cdf((a - mean) / std) 47 | u = norm_cdf((b - mean) / std) 48 | 49 | # Uniformly fill tensor with values from [l, u], then translate to 50 | # [2l-1, 2u-1]. 
51 | tensor.uniform_(2 * l - 1, 2 * u - 1) 52 | 53 | # Use inverse cdf transform for normal distribution to get truncated 54 | # standard normal 55 | tensor.erfinv_() 56 | 57 | # Transform to proper mean, std 58 | tensor.mul_(std * math.sqrt(2.)) 59 | tensor.add_(mean) 60 | 61 | # Clamp to ensure it's in the proper range 62 | tensor.clamp_(min=a, max=b) 63 | return tensor 64 | 65 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 66 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 67 | 68 | 69 | class Attention(nn.Module): 70 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 71 | super().__init__() 72 | self.num_heads = num_heads 73 | head_dim = dim // num_heads 74 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 75 | self.scale = qk_scale or head_dim ** -0.5 76 | 77 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 78 | self.attn_drop = nn.Dropout(attn_drop) 79 | self.proj = nn.Linear(dim, dim) 80 | self.proj_drop = nn.Dropout(proj_drop) 81 | 82 | def forward(self, x): 83 | B, N, C = x.shape 84 | #print(C) 85 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 87 | 88 | attn = (q @ k.transpose(-2, -1)) * self.scale 89 | attn = attn.softmax(dim=-1) 90 | attn = self.attn_drop(attn) 91 | 92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | return x 96 | 97 | class Mlp(nn.Module): 98 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 99 | super().__init__() 100 | out_features = out_features or in_features 101 | hidden_features = hidden_features or in_features 102 | self.fc1 = nn.Linear(in_features, hidden_features) 103 | self.act = act_layer() 104 | self.fc2 = nn.Linear(hidden_features, out_features) 105 | self.drop = nn.Dropout(drop) 106 | 107 | def forward(self, x): 108 | x = self.fc1(x) 109 | x = self.act(x) 110 | x = self.drop(x) 111 | x = self.fc2(x) 112 | x = self.drop(x) 113 | return x 114 | 115 | class Block(nn.Module): 116 | 117 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 118 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 119 | super().__init__() 120 | self.norm1 = norm_layer(dim) 121 | self.attn = Attention( 122 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 123 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 124 | self.drop_path = nn.Identity() 125 | self.norm2 = norm_layer(dim) 126 | mlp_hidden_dim = int(dim * mlp_ratio) 127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 128 | 129 | def forward(self, x): 130 | x = x + self.drop_path(self.attn(self.norm1(x))) 131 | x = x + self.drop_path(self.mlp(self.norm2(x))) 132 | return x 133 | 134 | # standard GOPT model proposed in the paper 135 | class GOPT(nn.Module): 136 | def __init__(self, embed_dim, num_heads, depth, input_dim=84): 137 | super().__init__() 138 | self.input_dim = input_dim 139 | self.embed_dim = embed_dim 140 | # Transformer encode blocks 141 | self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads) for i in range(depth)]) 142 | 143 | # sin pos embedding or learnable pos embedding, 55 = 50 sequence length + 5 utt-level cls 
tokens 144 | #self.pos_embed = nn.Parameter(get_sinusoid_encoding(55, self.embed_dim) * 0.1, requires_grad=True) 145 | self.pos_embed = nn.Parameter(torch.zeros(1, 55, self.embed_dim)) 146 | trunc_normal_(self.pos_embed, std=.02) 147 | 148 | # for phone classification 149 | self.in_proj = nn.Linear(self.input_dim, embed_dim) 150 | self.mlp_head_phn = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 151 | 152 | # for word classification, 1=accuracy, 2=stress, 3=total 153 | self.mlp_head_word1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 154 | self.mlp_head_word2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 155 | self.mlp_head_word3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 156 | 157 | # canonical phone projection, assume there are 40 phns 158 | self.phn_proj = nn.Linear(40, embed_dim) 159 | 160 | # utterance level, 1=accuracy, 2=completeness, 3=fluency, 4=prosodic, 5=total score 161 | self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 162 | self.mlp_head_utt1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 163 | self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 164 | self.mlp_head_utt2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 165 | self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 166 | self.mlp_head_utt3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 167 | self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 168 | self.mlp_head_utt4 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 169 | self.cls_token5 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 170 | self.mlp_head_utt5 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 171 | 172 | # initialize the cls tokens 173 | trunc_normal_(self.cls_token1, std=.02) 174 | trunc_normal_(self.cls_token2, std=.02) 175 | trunc_normal_(self.cls_token3, std=.02) 176 | trunc_normal_(self.cls_token4, std=.02) 177 | trunc_normal_(self.cls_token5, std=.02) 178 | 179 | # x shape in [batch_size, sequence_len, feat_dim] 180 | # phn in [batch_size, seq_len] 181 | def forward(self, x, phn): 182 | 183 | # batch size 184 | B = x.shape[0] 185 | 186 | # phn_one_hot in shape [batch_size, seq_len, feat_dim] 187 | phn_one_hot = torch.nn.functional.one_hot(phn.long()+1, num_classes=40).float() 188 | # phn_embed in shape [batch_size, seq_len, embed_dim] 189 | phn_embed = self.phn_proj(phn_one_hot) 190 | 191 | # if the input dimension is different from the Transformer embedding dimension, project the input to same dim 192 | if self.embed_dim != self.input_dim: 193 | x = self.in_proj(x) 194 | 195 | x = x + phn_embed 196 | 197 | cls_token1 = self.cls_token1.expand(B, -1, -1) 198 | cls_token2 = self.cls_token2.expand(B, -1, -1) 199 | cls_token3 = self.cls_token3.expand(B, -1, -1) 200 | cls_token4 = self.cls_token4.expand(B, -1, -1) 201 | cls_token5 = self.cls_token5.expand(B, -1, -1) 202 | 203 | x = torch.cat((cls_token1, cls_token2, cls_token3, cls_token4, cls_token5, x), dim=1) 204 | 205 | x = x + self.pos_embed 206 | 207 | # forward to the Transformer encoder 208 | for blk in self.blocks: 209 | x = blk(x) 210 | 211 | # the first 5 tokens are utterance-level cls tokens, i.e., accuracy, completeness, fluency, prosodic, total scores 212 | u1 = self.mlp_head_utt1(x[:, 0]) 213 | u2 = self.mlp_head_utt2(x[:, 1]) 214 | u3 = self.mlp_head_utt3(x[:, 2]) 215 | u4 = self.mlp_head_utt4(x[:, 3]) 216 | u5 = self.mlp_head_utt5(x[:, 4]) 217 | 218 | # 6th-end tokens are phone 
score tokens 219 | p = self.mlp_head_phn(x[:, 5:]) 220 | 221 | # word score is propagated to phone-level, so word output is also at phone-level. 222 | # but different mlp heads are used, 1 = accuracy, 2 = stress, 3 = total 223 | w1 = self.mlp_head_word1(x[:, 5:]) 224 | w2 = self.mlp_head_word2(x[:, 5:]) 225 | w3 = self.mlp_head_word3(x[:, 5:]) 226 | return u1, u2, u3, u4, u5, p, w1, w2, w3 227 | 228 | # GOPT model without canonical phone embedding, performance worse than standard GOPT model 229 | class GOPTNoPhn(nn.Module): 230 | def __init__(self, embed_dim, num_heads, depth, input_dim=84): 231 | super().__init__() 232 | self.input_dim = input_dim 233 | self.embed_dim = embed_dim 234 | self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads) for i in range(depth)]) 235 | 236 | # sin pos embedding 237 | #self.pos_embed = nn.Parameter(get_sinusoid_encoding(55, self.embed_dim) * 0.1, requires_grad=True) 238 | self.pos_embed = nn.Parameter(torch.zeros(1, 55, self.embed_dim)) 239 | trunc_normal_(self.pos_embed, std=.02) 240 | 241 | # for phone classification 242 | self.in_proj = nn.Linear(self.input_dim, embed_dim) 243 | self.mlp_head_phn = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 244 | 245 | # for word classification 246 | self.mlp_head_word1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 247 | self.mlp_head_word2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 248 | self.mlp_head_word3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 249 | 250 | # phone projection 251 | self.phn_proj = nn.Linear(40, embed_dim) 252 | 253 | # utterance level 254 | self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 255 | self.mlp_head_utt1 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 256 | self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 257 | self.mlp_head_utt2 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 258 | self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 259 | self.mlp_head_utt3 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 260 | self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 261 | self.mlp_head_utt4 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 262 | self.cls_token5 = nn.Parameter(torch.zeros(1, 1, embed_dim)) 263 | self.mlp_head_utt5 = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1)) 264 | 265 | # initialize the cls tokens 266 | trunc_normal_(self.cls_token1, std=.02) 267 | trunc_normal_(self.cls_token2, std=.02) 268 | trunc_normal_(self.cls_token3, std=.02) 269 | trunc_normal_(self.cls_token4, std=.02) 270 | trunc_normal_(self.cls_token5, std=.02) 271 | 272 | # x shape in [batch_size, sequence_len, feat_dim] 273 | # phn in [batch_size, seq_len] 274 | def forward(self, x, phn): 275 | 276 | # batch size 277 | B = x.shape[0] 278 | 279 | # phn_one_hot in shape [batch_size, seq_len, feat_dim] 280 | phn_one_hot = torch.nn.functional.one_hot(phn.long()+1, num_classes=40).float() 281 | # phn_embed in shape [batch_size, seq_len, embed_dim] 282 | phn_embed = self.phn_proj(phn_one_hot) 283 | 284 | if self.embed_dim != self.input_dim: 285 | x = self.in_proj(x) 286 | 287 | #x = x + phn_embed 288 | 289 | cls_token1 = self.cls_token1.expand(B, -1, -1) 290 | cls_token2 = self.cls_token2.expand(B, -1, -1) 291 | cls_token3 = self.cls_token3.expand(B, -1, -1) 292 | cls_token4 = self.cls_token4.expand(B, -1, -1) 293 | cls_token5 = self.cls_token5.expand(B, -1, -1) 294 | 295 | x = 
torch.cat((cls_token1, cls_token2, cls_token3, cls_token4, cls_token5, x), dim=1) 296 | 297 | x = x + self.pos_embed 298 | 299 | for blk in self.blocks: 300 | x = blk(x) 301 | 302 | u1 = self.mlp_head_utt1(x[:, 0]) 303 | u2 = self.mlp_head_utt2(x[:, 1]) 304 | u3 = self.mlp_head_utt3(x[:, 2]) 305 | u4 = self.mlp_head_utt4(x[:, 3]) 306 | u5 = self.mlp_head_utt5(x[:, 4]) 307 | 308 | p = self.mlp_head_phn(x[:, 5:]) 309 | w1 = self.mlp_head_word1(x[:, 5:]) 310 | w2 = self.mlp_head_word2(x[:, 5:]) 311 | w3 = self.mlp_head_word3(x[:, 5:]) 312 | return u1, u2, u3, u4, u5, p, w1, w2, w3 -------------------------------------------------------------------------------- /src/traintest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 9/20/21 12:02 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : traintest.py 7 | 8 | # train and test the models 9 | import sys 10 | import os 11 | import time 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | sys.path.append(os.path.dirname(os.path.dirname(sys.path[0]))) 15 | 16 | from models import * 17 | import argparse 18 | 19 | print("I am process %s, running on %s: starting (%s)" % (os.getpid(), os.uname()[1], time.asctime())) 20 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument("--exp-dir", type=str, default="./exp/", help="directory to dump experiments") 22 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate') 23 | parser.add_argument("--n-epochs", type=int, default=100, help="number of maximum training epochs") 24 | parser.add_argument("--goptdepth", type=int, default=1, help="depth of gopt models") 25 | parser.add_argument("--goptheads", type=int, default=1, help="heads of gopt models") 26 | parser.add_argument("--batch_size", type=int, default=25, help="training batch size") 27 | parser.add_argument("--embed_dim", type=int, default=12, help="gopt transformer embedding dimension") 28 | parser.add_argument("--loss_w_phn", type=float, default=1, help="weight for phoneme-level loss") 29 | parser.add_argument("--loss_w_word", type=float, default=1, help="weight for word-level loss") 30 | parser.add_argument("--loss_w_utt", type=float, default=1, help="weight for utterance-level loss") 31 | parser.add_argument("--model", type=str, default='gopt', help="name of the model") 32 | parser.add_argument("--am", type=str, default='librispeech', help="name of the acoustic models") 33 | parser.add_argument("--noise", type=float, default=0., help="the scale of random noise added on the input GoP feature") 34 | 35 | # just to generate the header for the result.csv 36 | def gen_result_header(): 37 | phn_header = ['epoch', 'phone_train_mse', 'phone_train_pcc', 'phone_test_mse', 'phone_test_pcc', 'learning rate'] 38 | utt_header_set = ['utt_train_mse', 'utt_train_pcc', 'utt_test_mse', 'utt_test_pcc'] 39 | utt_header_score = ['accuracy', 'completeness', 'fluency', 'prosodic', 'total'] 40 | word_header_set = ['word_train_pcc', 'word_test_pcc'] 41 | word_header_score = ['accuracy', 'stress', 'total'] 42 | utt_header, word_header = [], [] 43 | for dset in utt_header_set: 44 | utt_header = utt_header + [dset+'_'+x for x in utt_header_score] 45 | for dset in word_header_set: 46 | word_header = word_header + [dset+'_'+x for x in word_header_score] 47 | header = phn_header + utt_header + 
word_header 48 | return header 49 | 50 | def train(audio_model, train_loader, test_loader, args): 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | print('running on ' + str(device)) 53 | 54 | # best_cum_mAP is checkpoint ensemble from the first epoch to the best epoch 55 | best_epoch, best_mse = 0, 999 56 | global_step, epoch = 0, 0 57 | exp_dir = args.exp_dir 58 | 59 | if not isinstance(audio_model, nn.DataParallel): 60 | audio_model = nn.DataParallel(audio_model) 61 | 62 | audio_model = audio_model.to(device) 63 | # Set up the optimizer 64 | trainables = [p for p in audio_model.parameters() if p.requires_grad] 65 | print('Total parameter number is : {:.3f} k'.format(sum(p.numel() for p in audio_model.parameters()) / 1e3)) 66 | print('Total trainable parameter number is : {:.3f} k'.format(sum(p.numel() for p in trainables) / 1e3)) 67 | optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999)) 68 | 69 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(20, 100, 5)), gamma=0.5, last_epoch=-1) 70 | 71 | loss_fn = nn.MSELoss() 72 | 73 | print("current #steps=%s, #epochs=%s" % (global_step, epoch)) 74 | print("start training...") 75 | result = np.zeros([args.n_epochs, 32]) 76 | 77 | while epoch < args.n_epochs: 78 | audio_model.train() 79 | for i, (audio_input, phn_label, phns, utt_label, word_label) in enumerate(train_loader): 80 | 81 | audio_input = audio_input.to(device, non_blocking=True) 82 | phn_label = phn_label.to(device, non_blocking=True) 83 | utt_label = utt_label.to(device, non_blocking=True) 84 | word_label = word_label.to(device, non_blocking=True) 85 | 86 | # warmup 87 | warm_up_step = 100 88 | if global_step <= warm_up_step and global_step % 5 == 0: 89 | warm_lr = (global_step / warm_up_step) * args.lr 90 | for param_group in optimizer.param_groups: 91 | param_group['lr'] = warm_lr 92 | print('warm-up learning rate is {:f}'.format(optimizer.param_groups[0]['lr'])) 93 | 94 | # add random noise for augmentation. 
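# note: (torch.rand(...) - 1) lies in [-1, 0), so with args.noise > 0 the added perturbation is uniform in [-args.noise, 0); the default --noise 0 disables it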
95 | noise = (torch.rand([audio_input.shape[0], audio_input.shape[1], audio_input.shape[2]]) - 1) * args.noise 96 | noise = noise.to(device, non_blocking=True) 97 | audio_input = audio_input + noise 98 | 99 | #print(phns.shape) 100 | u1, u2, u3, u4, u5, p, w1, w2, w3 = audio_model(audio_input, phns) 101 | 102 | # filter out the padded tokens, only calculate the loss based on the valid tokens 103 | # < 0 is a flag of padded tokens 104 | mask = (phn_label>=0) 105 | p = p.squeeze(2) 106 | p = p * mask 107 | phn_label = phn_label * mask 108 | 109 | loss_phn = loss_fn(p, phn_label) 110 | 111 | # avoid the 0 losses of the padded tokens impacting the performance 112 | loss_phn = loss_phn * (mask.shape[0] * mask.shape[1]) / torch.sum(mask) 113 | 114 | # utterance level loss, also mse 115 | utt_preds = torch.cat((u1, u2, u3, u4, u5), dim=1) 116 | loss_utt = loss_fn(utt_preds ,utt_label) 117 | 118 | # word level loss 119 | word_label = word_label[:, :, 0:3] 120 | mask = (word_label>=0) 121 | word_pred = torch.cat((w1,w2,w3), dim=2) 122 | word_pred = word_pred * mask 123 | word_label = word_label * mask 124 | loss_word = loss_fn(word_pred, word_label) 125 | loss_word = loss_word * (mask.shape[0] * mask.shape[1] * mask.shape[2]) / torch.sum(mask) 126 | 127 | loss = args.loss_w_phn * loss_phn + args.loss_w_utt * loss_utt + args.loss_w_word * loss_word 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | global_step += 1 133 | 134 | print('start validation') 135 | 136 | # ensemble results 137 | # don't save prediction for the training set 138 | tr_mse, tr_corr, tr_utt_mse, tr_utt_corr, tr_word_mse, tr_word_corr = validate(audio_model, train_loader, args, -1) 139 | te_mse, te_corr, te_utt_mse, te_utt_corr, te_word_mse, te_word_corr = validate(audio_model, test_loader, args, best_mse) 140 | 141 | print('Phone: Test MSE: {:.3f}, CORR: {:.3f}'.format(te_mse.item(), te_corr)) 142 | print('Utterance:, ACC: {:.3f}, COM: {:.3f}, FLU: {:.3f}, PROC: {:.3f}, Total: {:.3f}'.format(te_utt_corr[0], te_utt_corr[1], te_utt_corr[2], te_utt_corr[3], te_utt_corr[4])) 143 | print('Word:, ACC: {:.3f}, Stress: {:.3f}, Total: {:.3f}'.format(te_word_corr[0], te_word_corr[1], te_word_corr[2])) 144 | 145 | result[epoch, :6] = [epoch, tr_mse, tr_corr, te_mse, te_corr, optimizer.param_groups[0]['lr']] 146 | 147 | result[epoch, 6:26] = np.concatenate([tr_utt_mse, tr_utt_corr, te_utt_mse, te_utt_corr]) 148 | 149 | result[epoch, 26:32] = np.concatenate([tr_word_corr, te_word_corr]) 150 | 151 | header = ','.join(gen_result_header()) 152 | np.savetxt(exp_dir + '/result.csv', result, delimiter=',', header=header, comments='') 153 | print('-------------------validation finished-------------------') 154 | 155 | if te_mse < best_mse: 156 | best_mse = te_mse 157 | best_epoch = epoch 158 | 159 | if best_epoch == epoch: 160 | if os.path.exists("%s/models/" % (exp_dir)) == False: 161 | os.mkdir("%s/models" % (exp_dir)) 162 | torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir)) 163 | 164 | if global_step > warm_up_step: 165 | scheduler.step() 166 | 167 | print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr'])) 168 | epoch += 1 169 | 170 | def validate(audio_model, val_loader, args, best_mse): 171 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 172 | if not isinstance(audio_model, nn.DataParallel): 173 | audio_model = nn.DataParallel(audio_model) 174 | audio_model = audio_model.to(device) 175 | audio_model.eval() 176 | 177 | A_phn, A_phn_target 
= [], [] 178 | A_u1, A_u2, A_u3, A_u4, A_u5, A_utt_target = [], [], [], [], [], [] 179 | A_w1, A_w2, A_w3, A_word_target = [], [], [], [] 180 | with torch.no_grad(): 181 | for i, (audio_input, phn_label, phns, utt_label, word_label) in enumerate(val_loader): 182 | audio_input = audio_input.to(device) 183 | 184 | # compute output 185 | u1, u2, u3, u4, u5, p, w1, w2, w3 = audio_model(audio_input, phns) 186 | p = p.to('cpu').detach() 187 | u1, u2, u3, u4, u5 = u1.to('cpu').detach(), u2.to('cpu').detach(), u3.to('cpu').detach(), u4.to('cpu').detach(), u5.to('cpu').detach() 188 | w1, w2, w3 = w1.to('cpu').detach(), w2.to('cpu').detach(), w3.to('cpu').detach() 189 | 190 | A_phn.append(p) 191 | A_phn_target.append(phn_label) 192 | 193 | A_u1.append(u1) 194 | A_u2.append(u2) 195 | A_u3.append(u3) 196 | A_u4.append(u4) 197 | A_u5.append(u5) 198 | A_utt_target.append(utt_label) 199 | 200 | A_w1.append(w1) 201 | A_w2.append(w2) 202 | A_w3.append(w3) 203 | A_word_target.append(word_label) 204 | 205 | # phone level 206 | A_phn, A_phn_target = torch.cat(A_phn), torch.cat(A_phn_target) 207 | 208 | # utterance level 209 | A_u1, A_u2, A_u3, A_u4, A_u5, A_utt_target = torch.cat(A_u1), torch.cat(A_u2), torch.cat(A_u3), torch.cat(A_u4), torch.cat(A_u5), torch.cat(A_utt_target) 210 | 211 | # word level 212 | A_w1, A_w2, A_w3, A_word_target = torch.cat(A_w1), torch.cat(A_w2), torch.cat(A_w3), torch.cat(A_word_target) 213 | 214 | # get the scores 215 | phn_mse, phn_corr = valid_phn(A_phn, A_phn_target) 216 | 217 | A_utt = torch.cat((A_u1, A_u2, A_u3, A_u4, A_u5), dim=1) 218 | utt_mse, utt_corr = valid_utt(A_utt, A_utt_target) 219 | 220 | A_word = torch.cat((A_w1, A_w2, A_w3), dim=2) 221 | word_mse, word_corr, valid_word_pred, valid_word_target = valid_word(A_word, A_word_target) 222 | 223 | if phn_mse < best_mse: 224 | print('new best phn mse {:.3f}, now saving predictions.'.format(phn_mse)) 225 | 226 | # create the directory 227 | if os.path.exists(args.exp_dir + '/preds') == False: 228 | os.mkdir(args.exp_dir + '/preds') 229 | 230 | # saving the phn target, only do once 231 | if os.path.exists(args.exp_dir + '/preds/phn_target.npy') == False: 232 | np.save(args.exp_dir + '/preds/phn_target.npy', A_phn_target) 233 | np.save(args.exp_dir + '/preds/word_target.npy', valid_word_target) 234 | np.save(args.exp_dir + '/preds/utt_target.npy', A_utt_target) 235 | 236 | np.save(args.exp_dir + '/preds/phn_pred.npy', A_phn) 237 | np.save(args.exp_dir + '/preds/word_pred.npy', valid_word_pred) 238 | np.save(args.exp_dir + '/preds/utt_pred.npy', A_utt) 239 | 240 | return phn_mse, phn_corr, utt_mse, utt_corr, word_mse, word_corr 241 | 242 | def valid_phn(audio_output, target): 243 | valid_token_pred = [] 244 | valid_token_target = [] 245 | audio_output = audio_output.squeeze(2) 246 | for i in range(audio_output.shape[0]): 247 | for j in range(audio_output.shape[1]): 248 | # only count valid tokens, not padded tokens (represented by negative values) 249 | if target[i, j] >= 0: 250 | valid_token_pred.append(audio_output[i, j]) 251 | valid_token_target.append(target[i, j]) 252 | valid_token_target = np.array(valid_token_target) 253 | valid_token_pred = np.array(valid_token_pred) 254 | 255 | valid_token_mse = np.mean((valid_token_target - valid_token_pred) ** 2) 256 | corr = np.corrcoef(valid_token_pred, valid_token_target)[0, 1] 257 | return valid_token_mse, corr 258 | 259 | def valid_utt(audio_output, target): 260 | mse = [] 261 | corr = [] 262 | for i in range(5): 263 | cur_mse = np.mean(((audio_output[:, i] - target[:, i]) 
** 2).numpy()) 264 | cur_corr = np.corrcoef(audio_output[:, i], target[:, i])[0, 1] 265 | mse.append(cur_mse) 266 | corr.append(cur_corr) 267 | return mse, corr 268 | 269 | def valid_word(audio_output, target): 270 | word_id = target[:, :, -1] 271 | target = target[:, :, 0:3] 272 | 273 | valid_token_pred = [] 274 | valid_token_target = [] 275 | 276 | # unique, counts = np.unique(np.array(target), return_counts=True) 277 | # print(dict(zip(unique, counts))) 278 | 279 | # for each utterance 280 | for i in range(target.shape[0]): 281 | prev_w_id = 0 282 | start_id = 0 283 | # for each token 284 | for j in range(target.shape[1]): 285 | cur_w_id = word_id[i, j].int() 286 | # if a new word 287 | if cur_w_id != prev_w_id: 288 | # average each phone belongs to the word 289 | valid_token_pred.append(np.mean(audio_output[i, start_id: j, :].numpy(), axis=0)) 290 | valid_token_target.append(np.mean(target[i, start_id: j, :].numpy(), axis=0)) 291 | # sanity check, if the range indeed contains a single word 292 | if len(torch.unique(target[i, start_id: j, 1])) != 1: 293 | print(target[i, start_id: j, 0]) 294 | # if end of the utterance 295 | if cur_w_id == -1: 296 | break 297 | else: 298 | prev_w_id = cur_w_id 299 | start_id = j 300 | 301 | valid_token_pred = np.array(valid_token_pred) 302 | # this rounding is to solve the precision issue in the label 303 | valid_token_target = np.array(valid_token_target).round(2) 304 | 305 | mse_list, corr_list = [], [] 306 | # for each (accuracy, stress, total) word score 307 | for i in range(3): 308 | valid_token_mse = np.mean((valid_token_target[:, i] - valid_token_pred[:, i]) ** 2) 309 | corr = np.corrcoef(valid_token_pred[:, i], valid_token_target[:, i])[0, 1] 310 | mse_list.append(valid_token_mse) 311 | corr_list.append(corr) 312 | return mse_list, corr_list, valid_token_pred, valid_token_target 313 | 314 | 315 | class GoPDataset(Dataset): 316 | def __init__(self, set, am='librispeech'): 317 | # normalize the input to 0 mean and unit std. 318 | if am=='librispeech': 319 | dir='seq_data_librispeech' 320 | norm_mean, norm_std = 3.203, 4.045 321 | elif am=='paiia': 322 | dir='seq_data_paiia' 323 | norm_mean, norm_std = -0.652, 9.737 324 | elif am=='paiib': 325 | dir='seq_data_paiib' 326 | norm_mean, norm_std = -0.516, 9.247 327 | else: 328 | raise ValueError('Acoustic Model Unrecognized.') 329 | 330 | if set == 'train': 331 | self.feat = torch.tensor(np.load('../data/'+dir+'/tr_feat.npy'), dtype=torch.float) 332 | self.phn_label = torch.tensor(np.load('../data/'+dir+'/tr_label_phn.npy'), dtype=torch.float) 333 | self.utt_label = torch.tensor(np.load('../data/'+dir+'/tr_label_utt.npy'), dtype=torch.float) 334 | self.word_label = torch.tensor(np.load('../data/'+dir+'/tr_label_word.npy'), dtype=torch.float) 335 | elif set == 'test': 336 | self.feat = torch.tensor(np.load('../data/'+dir+'/te_feat.npy'), dtype=torch.float) 337 | self.phn_label = torch.tensor(np.load('../data/'+dir+'/te_label_phn.npy'), dtype=torch.float) 338 | self.utt_label = torch.tensor(np.load('../data/'+dir+'/te_label_utt.npy'), dtype=torch.float) 339 | self.word_label = torch.tensor(np.load('../data/'+dir+'/te_label_word.npy'), dtype=torch.float) 340 | 341 | # normalize the GOP feature using the training set mean and std (only count the valid token features, exclude the padded tokens). 
342 | self.feat = self.norm_valid(self.feat, norm_mean, norm_std) 343 | 344 | # normalize the utt_label to 0-2 (same with phn score range) 345 | self.utt_label = self.utt_label / 5 346 | # the last dim is word_id, so not normalizing 347 | self.word_label[:, :, 0:3] = self.word_label[:, :, 0:3] / 5 348 | self.phn_label[:, :, 1] = self.phn_label[:, :, 1] 349 | 350 | # only normalize valid tokens, not padded token 351 | def norm_valid(self, feat, norm_mean, norm_std): 352 | norm_feat = torch.zeros_like(feat) 353 | for i in range(feat.shape[0]): 354 | for j in range(feat.shape[1]): 355 | if feat[i, j, 0] != 0: 356 | norm_feat[i, j, :] = (feat[i, j, :] - norm_mean) / norm_std 357 | else: 358 | break 359 | return norm_feat 360 | 361 | def __len__(self): 362 | return self.feat.shape[0] 363 | 364 | def __getitem__(self, idx): 365 | # feat, phn_label, phn_id, utt_label, word_label 366 | return self.feat[idx, :], self.phn_label[idx, :, 1], self.phn_label[idx, :, 0], self.utt_label[idx, :], self.word_label[idx, :] 367 | 368 | args = parser.parse_args() 369 | 370 | am = args.am 371 | print('now train with {:s} acoustic models'.format(am)) 372 | feat_dim = {'librispeech':84, 'paiia':86, 'paiib': 88} 373 | input_dim=feat_dim[am] 374 | 375 | # nowa is the best models used in this work 376 | if args.model == 'gopt': 377 | print('now train a GOPT models') 378 | audio_mdl = GOPT(embed_dim=args.embed_dim, num_heads=args.goptheads, depth=args.goptdepth, input_dim=input_dim) 379 | # for ablation study only 380 | elif args.model == 'gopt_nophn': 381 | print('now train a GOPT models without canonical phone embedding') 382 | audio_mdl = GOPTNoPhn(embed_dim=args.embed_dim, num_heads=args.goptheads, depth=args.goptdepth, input_dim=input_dim) 383 | elif args.model == 'lstm': 384 | print('now train a baseline LSTM model') 385 | audio_mdl = BaselineLSTM(embed_dim=args.embed_dim, depth=args.goptdepth, input_dim=input_dim) 386 | elif args.model == 'hipama': 387 | print('now train a HiPAMA model') 388 | audio_mdl = HiPAMA(embed_dim=args.embed_dim, depth=args.goptdepth, input_dim=input_dim) 389 | 390 | tr_dataset = GoPDataset('train', am=am) 391 | tr_dataloader = DataLoader(tr_dataset, batch_size=args.batch_size, shuffle=True) 392 | te_dataset = GoPDataset('test', am=am) 393 | te_dataloader = DataLoader(te_dataset, batch_size=2500, shuffle=False) 394 | 395 | train(audio_mdl, tr_dataloader, te_dataloader, args) --------------------------------------------------------------------------------