├── .gitignore
├── README.md
├── my_863_corpus
│   ├── README.md
│   ├── conf
│   │   ├── cnn_lstm_ctc_setting.conf
│   │   └── lstm_ctc_setting.conf
│   ├── run.sh
│   └── steps
│       ├── BeamSearch.py
│       ├── cnn_lstm_ctc.py
│       ├── ctcDecoder.py
│       ├── data_loader.py
│       ├── lstm_ctc.py
│       ├── model.py
│       ├── test.py
│       └── utils.py
├── requirements.txt
└── timit
    ├── conf
    │   ├── backup.conf
    │   ├── ctc_config.yaml
    │   ├── dev_spk.list
    │   ├── fbank.conf
    │   ├── mfcc.conf
    │   ├── phones.60-48-39.map
    │   ├── test_spk.list
    │   └── train_spk.list
    ├── local
    │   ├── make_spectrum.py
    │   ├── normalize_phone.py
    │   └── timit_data_prep.sh
    ├── models
    │   └── model_ctc.py
    ├── path.sh
    ├── run.sh
    ├── steps
    │   ├── get_model_units.py
    │   ├── make_feat.sh
    │   ├── test_ctc.py
    │   ├── train_ctc.py
    │   ├── train_lm.sh
    │   ├── visualize.py
    │   └── visualize.py.old
    └── utils
        ├── BeamSearch.py
        ├── NgramLM.py
        ├── ctcDecoder.py
        ├── data_loader.py
        └── tools.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | 
4 | my_863_corpus/*
5 | log/
6 | checkpoint/
7 | data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Update:
2 | Updated to pytorch 1.2 and python 3.
3 | 
4 | # CTC-based Automatic Speech Recognition
5 | This is a CTC-based speech recognition system built with pytorch.
6 | 
7 | At present, the system only supports phoneme recognition.
8 | 
9 | You can also run it at the word level, but expect a high error rate.
10 | 
11 | Another option is to decode with a lexicon and a word-level language model using WFSTs, which is not included in this system.
12 | 
13 | ## Data
14 | English Corpus: TIMIT
15 | - Training set: 3696 sentences (excluding SA utterances)
16 | - Dev set: 400 sentences
17 | - Test set: 192 sentences
18 | 
19 | Chinese Corpus: 863 Corpus
20 | - Training set:
21 | 
22 | | Speaker | UtterId | Utterances |
23 | | :-: | :-: | :-: |
24 | | M50, F50 | A1-A521, AW1-AW129 | 650 sentences |
25 | | M54, F54 | B522-B1040, BW130-BW259 | 649 sentences |
26 | | M60, F60 | C1041-C1560, CW260-CW388 | 649 sentences |
27 | | M64, F64 | D1-D625 | 625 sentences |
28 | | All | | 5146 sentences |
29 | 
30 | - Test set:
31 | 
32 | | Speaker | UtterId | Utterances |
33 | | :-: | :-: | :-: |
34 | | M51, F51 | A1-A100 | 100 sentences |
35 | | M55, F55 | B522-B521 | 100 sentences |
36 | | M61, F61 | C1041-C1140 | 100 sentences |
37 | | M63, F63 | D1-D100 | 100 sentences |
38 | | All | | 800 sentences |
39 | 
40 | ## Install
41 | - Install [Pytorch](http://pytorch.org/)
42 | - ~~Install [warp-ctc](https://github.com/SeanNaren/warp-ctc) and bind it to pytorch.~~
43 | ~~Notice: If you use python 2, reinstall pytorch from source instead of pip.~~
44 | The built-in CTC loss of pytorch 1.2 (nn.CTCLoss) is used now.
45 | - Install [Kaldi](https://github.com/kaldi-asr/kaldi). We use Kaldi to extract MFCC and fbank features.
46 | - Install pytorch [torchaudio](https://github.com/pytorch/audio.git) (needed when using waveforms as input).
47 | - ~~Install [KenLM](https://github.com/kpu/kenlm) to train an n-gram language model if needed~~.
48 | IRSTLM from the Kaldi tools is used instead.
49 | - Install and start visdom
50 | ```
51 | pip3 install visdom
52 | python -m visdom.server
53 | ```
54 | - Install other python packages
55 | ```
56 | pip install -r requirements.txt
57 | ```
58 | 
59 | ## Usage
60 | 1. Install all the packages according to the Install section.
61 | 2. Revise the top script run.sh.
62 | 3. Open the config file to revise the hyper-parameters.
63 | 4. Run the top script; it supports four stages:
64 | ```bash
65 | bash run.sh      # data_prepare + AM training + LM training + testing
66 | bash run.sh 1    # AM training + LM training + testing
67 | bash run.sh 2    # LM training + testing
68 | bash run.sh 3    # testing
69 | ```
70 | RNN LM training is not implemented yet; it is on the to-do list.
71 | 
72 | ## Data Prepare
73 | 1. Extract 39-dim MFCC and 40-dim fbank features with Kaldi.
74 | 2. Use compute-cmvn-stats and apply-cmvn on the training data to get the global mean and variance, then normalize the features.
75 | 3. Dataset and DataLoader from torch.utils.data are subclassed to prepare data for training. You can find them in steps/data_loader.py.
76 | 
77 | ## Model
78 | - RNN + DNN + CTC
79 | The RNN here can be nn.LSTM or nn.GRU.
80 | - CNN + RNN + DNN + CTC
81 | The CNN is used to reduce the spectral variability caused by speaker and environment differences.
82 | - How to choose
83 | Use add_cnn to choose between the two models. If add_cnn is True, CNN + RNN + DNN + CTC is used.
84 | 
85 | ## Training
86 | - initial-lr = 0.001
87 | - decay = 0.5
88 | - weight-decay = 0.005
89 | 
90 | The learning rate is decayed when the accuracy on the dev set stays within a small range (end_adjust_acc) for ten epochs.
91 | The learning rate is adjusted at most 8 times; this limit can be altered in steps/train_ctc.py (line 367).
92 | The optimizer is torch.optim.Adam with weight decay 0.005.
93 | 
94 | ## Decoder
95 | ### Greedy decoder:
96 | Take the most probable output at each frame, then collapse repeats and remove blanks to get the path.
97 | WER and CER are calculated with the corresponding functions of the decoder class.
98 | ### Beam decoder:
99 | Implemented in python. [Original code](https://github.com/githubharald/CTCDecoder)
100 | It was adapted to support phonemes and batch decoding.
101 | Beam search improves phoneme accuracy by about 0.2%.
102 | A phoneme-level language model is now integrated into the beam search decoder.
103 | 
104 | ## ToDo
105 | - Combine with an RNN LM
106 | - Beam search with an RNN LM
107 | - The code in my_863_corpus is a mess and needs to be cleaned up.
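As a quick illustration of the pieces described above, here is a minimal, self-contained sketch (not code from this repository) of feeding log-softmax outputs to the built-in nn.CTCLoss that replaced warp-ctc, followed by the greedy collapse used by the greedy decoder. The (seq_len, batch, num_class+1) layout and blank index 0 match the conventions used in steps/, but all variable names and sizes below are made up for the example.

```python
import torch
import torch.nn as nn

T, N, C = 120, 4, 49                 # frames, batch size, 48 phonemes + 1 blank (index 0); illustrative sizes
logits = torch.randn(T, N, C)        # stand-in for the RNN / CNN-RNN output
log_probs = logits.log_softmax(dim=2)

targets = torch.randint(1, C, (N, 20), dtype=torch.long)   # phoneme indices, blank excluded
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 20, dtype=torch.long)

# Built-in CTC loss (PyTorch >= 1.0), replacing the old warp-ctc binding
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print('CTC loss:', loss.item())

# Greedy decoding: argmax per frame, collapse repeats, drop blanks
best_path = log_probs.argmax(dim=2).transpose(0, 1)        # (batch, frames)
for utt in best_path:
    prev, decoded = 0, []
    for idx in utt.tolist():
        if idx != 0 and idx != prev:
            decoded.append(idx)
        prev = idx
    print(decoded[:10])                                     # first few decoded phoneme indices
```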
108 | 
109 | 
--------------------------------------------------------------------------------
/my_863_corpus/README.md:
--------------------------------------------------------------------------------
1 | # DataSet
2 | Chinese Corpus: 863 Corpus
3 | - Training set:
4 | M50 F50 A1-A521
5 | AW1-AW129 650 sentences
6 | M54 F54 B522-B1040
7 | BW130-BW259 649 sentences
8 | M60 F60 C1041-C1560
9 | CW260-CW388 649 sentences
10 | M64 F64 D1-D625 625 sentences
11 | - Total: 5146 sentences
12 | 
13 | Test set:
14 | M51 F51 A1-A100 100 sentences
15 | M55 F55 B522-B521 100 sentences
16 | M61 F61 C1041-C1140 100 sentences
17 | M63 F63 D1-D100 100 sentences
18 | Total: 800 sentences
19 | 
20 | # Label 66 phonemes
21 | a 1 ai 2 an 3 ang 4 ao 5 as 6 b 7 c 8 ch 9 d 10
22 | e 11 ei 12 en 13 eng 14 er 15 es 16 f 17 g 18 h 19 i 20
23 | ia 21 ian 22 iang 23 iao 24 ie 25 ih 26 in 27 ing 28 iong 29 is 30
24 | iu 31 iz 32 j 33 k 34 l 35 m 36 n 37 o 38 ong 39 os 40
25 | ou 41 p 42 q 43 r 44 s 45 sh 46 sil 47 t 48 u 49 ua 50
26 | uai 51 uan 52 uang 53 ueng 54 ui 55 un 56 uo 57 us 58 v 59 van 60
27 | ve 61 vn 62 vs 63 x 64 z 65 zh 66
28 | 
29 | # Features:
30 | mfcc39 fbank40 spectrum201
--------------------------------------------------------------------------------
/my_863_corpus/conf/cnn_lstm_ctc_setting.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset = TIMIT
3 | data_dir = /home/fan/pytorch/CTC_pytorch/my_863_corpus/data_prepare/data
4 | feature_type = spectrum
5 | n_feats = 201
6 | out_type = phone
7 | 
8 | [Model]
9 | rnn_input_size = 201
10 | rnn_hidden_size = 256
11 | rnn_layers = 4
12 | rnn_type = nn.LSTM
13 | bidirectional = True
14 | batch_norm = True
15 | num_class = 48
16 | drop_out = 0
17 | 
18 | [Training]
19 | init_lr = 0.001
20 | num_epoches = 30
21 | least_train_epoch = 5
22 | end_adjust_acc = 0.5
23 | lr_decay = 0.5
24 | batch_size = 8
25 | weight_decay = 0.005
26 | 
27 | 
--------------------------------------------------------------------------------
/my_863_corpus/conf/lstm_ctc_setting.conf:
--------------------------------------------------------------------------------
1 | [Data]
2 | dataset = 863_corpus
3 | data_dir = /home/fan/pytorch/CTC_pytorch/my_863_corpus/data_prepare/data
4 | feature_type = fbank
5 | n_feats = 40
6 | out_type = phone
7 | 
8 | [Model]
9 | rnn_input_size = 40
10 | rnn_hidden_size = 256
11 | rnn_layers = 4
12 | rnn_type = nn.LSTM
13 | bidirectional = True
14 | batch_norm = True
15 | num_class = 66
16 | drop_out = 0
17 | model_file = ./log/best_model_cv87.7975047985.pkl
18 | 
19 | [Training]
20 | init_lr = 0.001
21 | num_epoches = 500
22 | end_adjust_acc = 1.5
23 | lr_decay = 0.5
24 | batch_size = 16
25 | weight_decay = 0.005
26 | 
27 | 
--------------------------------------------------------------------------------
/my_863_corpus/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Top script for my experiment
4 | # Author: Richardfan
5 | # 2017.11.1
6 | 
7 | lstm_ctc_CONF_FILE='./conf/lstm_ctc_setting.conf'
8 | cnn_lstm_ctc_CONF_FILE='./conf/cnn_lstm_ctc_setting.conf'
9 | LOG_DIR='./log/'
10 | 
11 | echo ========================================================
12 | echo " Training "
13 | echo ========================================================
14 | 
15 | #python steps/lstm_ctc.py --conf $lstm_ctc_CONF_FILE --log-dir $LOG_DIR
16 | python steps/cnn_lstm_ctc.py --conf $cnn_lstm_ctc_CONF_FILE --log-dir $LOG_DIR
17 | 
18 | 
19 | echo 
======================================================== 20 | echo " Greedy Decoding " 21 | echo ======================================================== 22 | 23 | #python steps/test.py --conf $lstm_ctc_CONF_FILE --decode-type 'Greedy' 24 | python steps/test.py --conf $cnn_lstm_ctc_CONF_FILE --decode-type 'Greedy' 25 | 26 | -------------------------------------------------------------------------------- /my_863_corpus/steps/BeamSearch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | class BeamEntry: 8 | "information about one single beam at specific time-step" 9 | def __init__(self): 10 | self.prTotal=0 # blank and non-blank 11 | self.prNonBlank=0 # non-blank 12 | self.prBlank=0 # blank 13 | self.y=() # labelling at current time-step 14 | 15 | 16 | class BeamState: 17 | "information about beams at specific time-step" 18 | def __init__(self): 19 | self.entries={} 20 | 21 | def norm(self): 22 | "length-normalise probabilities to avoid penalising long labellings" 23 | for (k,v) in self.entries.items(): 24 | labellingLen=len(self.entries[k].y) 25 | self.entries[k].prTotal=self.entries[k].prTotal*(1.0/(labellingLen if labellingLen else 1)) 26 | 27 | def sort(self): 28 | "return beams sorted by probability" 29 | u=[v for (k,v) in self.entries.items()] 30 | s=sorted(u, reverse=True, key=lambda x:x.prTotal) 31 | return [x.y for x in s] 32 | 33 | class ctcBeamSearch(object): 34 | def __init__(self, classes, beam_width, lm, blank_index=0): 35 | self.classes = classes 36 | self.beamWidth = beam_width 37 | self.lm = lm 38 | self.blank_index = blank_index 39 | 40 | def calcExtPr(self, k, y, t, mat, beamState): 41 | "probability for extending labelling y to y+k" 42 | 43 | # language model (char bigrams) 44 | bigramProb=1 45 | if self.lm: 46 | c1=self.classes[y[-1] if len(y) else self.classes.index(' ')] 47 | c2=self.classes[k] 48 | lmFactor=0.01 # controls influence of language model 49 | bigramProb=self.lm.getCharBigram(c1,c2)**lmFactor 50 | 51 | # optical model (RNN) 52 | if len(y) and y[-1]==k: 53 | return mat[t, k]*bigramProb*beamState.entries[y].prBlank 54 | else: 55 | return mat[t, k]*bigramProb*beamState.entries[y].prTotal 56 | 57 | def addLabelling(self, beamState, y): 58 | "adds labelling if it does not exist yet" 59 | if y not in beamState.entries: 60 | beamState.entries[y]=BeamEntry() 61 | 62 | def decode(self, inputs, inputs_list): 63 | ''' 64 | mat : FloatTesnor batch * timesteps * class 65 | ''' 66 | batches, maxT, maxC = inputs.size() 67 | res = [] 68 | 69 | for batch in range(batches): 70 | mat = inputs[batch].numpy() 71 | # Initialise beam state 72 | last=BeamState() 73 | y=() 74 | last.entries[y]=BeamEntry() 75 | last.entries[y].prBlank=1 76 | last.entries[y].prTotal=1 77 | 78 | # go over all time-steps 79 | for t in range(inputs_list[batch]): 80 | curr=BeamState() 81 | if (1 - mat[t, self.blank_index]) < 0.1: #跳过概率很接近1的blank帧,增加解码速度 82 | continue 83 | # get best labellings 84 | BHat=last.sort()[0:self.beamWidth] #取前beam个最好的结果 85 | #print(BHat) 86 | # go over best labellings 87 | for y in BHat: 88 | prNonBlank=0 89 | # if nonempty labelling 90 | if len(y)>0: 91 | # seq prob so far and prob of seeing last label again 92 | prNonBlank=last.entries[y].prNonBlank*mat[t, y[-1]] #相同的y两种可能,加入重复或者加入空白,如果之前没有字符,在NonBlank概率为0 93 | 94 | # calc probabilities 95 | prBlank=(last.entries[y].prTotal)*mat[t, self.blank_index] 96 | # save result 97 | self.addLabelling(curr, y) 
98 | curr.entries[y].y=y 99 | curr.entries[y].prNonBlank+=prNonBlank 100 | curr.entries[y].prBlank+=prBlank 101 | curr.entries[y].prTotal+=prBlank+prNonBlank 102 | 103 | # extend current labelling 104 | for k in range(maxC): #t时刻加入其它的label,此时Blank的概率为0,如果加入的label与最后一个相同,因为不能重复,所以上一个字符一定是blank 105 | if k != self.blank_index: 106 | newY=y+(k,) 107 | prNonBlank=self.calcExtPr(k, y, t, mat, last) 108 | 109 | # save result 110 | self.addLabelling(curr, newY) 111 | curr.entries[newY].y=newY 112 | curr.entries[newY].prNonBlank+=prNonBlank 113 | curr.entries[newY].prTotal+=prNonBlank 114 | 115 | # set new beam state 116 | last=curr 117 | 118 | # normalise probabilities according to labelling length 119 | last.norm() 120 | 121 | # sort by probability 122 | bestLabelling=last.sort()[0] # get most probable labelling 123 | 124 | # map labels to chars 125 | res_b =' '.join([self.classes[l] for l in bestLabelling]) 126 | res.append(res_b) 127 | return res 128 | 129 | 130 | if __name__=='__main__': 131 | classes=["a","b"] 132 | mat=np.array([[[0.4, 0, 0.6], [0.4, 0, 0.6], [0, 1, 0], [0, 0, 0]], [[0.4, 0, 0.6],[0.4, 0, 0.6], [0.4, 0.1, 0.5], [0.2, 0.5, 0.3]]]) 133 | mat = torch.FloatTensor(mat) 134 | input_list = [2, 2] 135 | print('Test beam search') 136 | expected='a' 137 | decoder = ctcBeamSearch(classes, 10, None, blank_index=len(classes)) 138 | actual=decoder.decode(mat, input_list) 139 | print('Expected: "'+expected+'"') 140 | print(actual) 141 | -------------------------------------------------------------------------------- /my_863_corpus/steps/cnn_lstm_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | #train process for the model 5 | 6 | from data_loader import myDataset, myCNNDataLoader 7 | from model import * 8 | from ctcDecoder import GreedyDecoder 9 | from warpctc_pytorch import CTCLoss 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import time 14 | import numpy as np 15 | import argparse 16 | import ConfigParser 17 | import os 18 | import copy 19 | 20 | def train(model, train_loader, loss_fn, optimizer, logger, print_every=20): 21 | model.train() 22 | 23 | total_loss = 0 24 | print_loss = 0 25 | i = 0 26 | for data in train_loader: 27 | inputs, targets, input_sizes, input_sizes_list, target_sizes = data 28 | batch_size = inputs.size(0) 29 | 30 | inputs = Variable(inputs, requires_grad=False) 31 | targets = Variable(targets, requires_grad=False) 32 | input_sizes = Variable(input_sizes, requires_grad=False) 33 | target_sizes = Variable(target_sizes, requires_grad=False) 34 | 35 | if USE_CUDA: 36 | inputs = inputs.cuda() 37 | 38 | out = model(inputs) 39 | 40 | loss = loss_fn(out, targets, input_sizes, target_sizes) 41 | loss /= batch_size 42 | print_loss += loss.data[0] 43 | 44 | if (i + 1) % print_every == 0: 45 | print('batch = %d, loss = %.4f' % (i+1, print_loss / print_every)) 46 | logger.debug('batch = %d, loss = %.4f' % (i+1, print_loss / print_every)) 47 | print_loss = 0 48 | 49 | total_loss += loss.data[0] 50 | optimizer.zero_grad() 51 | loss.backward() 52 | nn.utils.clip_grad_norm(model.parameters(), 400) #防止梯度爆炸或者梯度消失,限制参数范围 53 | optimizer.step() 54 | i += 1 55 | average_loss = total_loss / i 56 | print("Epoch done, average loss: %.4f" % average_loss) 57 | logger.info("Epoch done, average loss: %.4f" % average_loss) 58 | return average_loss 59 | 60 | def dev(model, dev_loader, decoder, logger): 61 | model.eval() 62 | total_cer = 0 63 | total_tokens = 0 
64 | 65 | for data in dev_loader: 66 | inputs, targets, input_sizes, input_sizes_list, target_sizes =data 67 | batch_size = inputs.size(1) 68 | 69 | inputs = Variable(inputs, volatile=True, requires_grad=False) 70 | 71 | if USE_CUDA: 72 | inputs = inputs.cuda() 73 | 74 | probs = model(inputs) 75 | probs = probs.data.cpu() 76 | if decoder.space_idx == -1: 77 | total_cer += decoder.phone_word_error(probs, input_sizes_list, targets, target_sizes)[1] 78 | else: 79 | total_cer += decoder.phone_word_error(probs, input_sizes_list, targets, target_sizes)[0] 80 | total_tokens += sum(target_sizes) 81 | acc = 1 - float(total_cer) / total_tokens 82 | return acc*100 83 | 84 | def init_logger(log_file): 85 | import logging 86 | from logging.handlers import RotatingFileHandler 87 | 88 | logger = logging.getLogger() 89 | hdl = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=10) 90 | formatter=logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s') 91 | hdl.setFormatter(formatter) 92 | logger.addHandler(hdl) 93 | logger.setLevel(logging.DEBUG) 94 | return logger 95 | 96 | RNN = {'nn.LSTM':nn.LSTM, 'nn.GRU': nn.GRU, 'nn.RNN':nn.RNN} 97 | parser = argparse.ArgumentParser(description='cnn_lstm_ctc') 98 | parser.add_argument('--conf', default='../conf/cnn_lstm_ctc_setting.conf' , help='conf file with Argument of LSTM and training') 99 | parser.add_argument('--log-dir', dest='log_dir', default='../log', help='log file for training') 100 | 101 | def main(): 102 | args = parser.parse_args() 103 | cf = ConfigParser.ConfigParser() 104 | try: 105 | cf.read(args.conf) 106 | except: 107 | print("conf file not exists") 108 | 109 | logger = init_logger(os.path.join(args.log_dir, 'train_cnn_lstm_ctc.log')) 110 | dataset = cf.get('Data', 'dataset') 111 | data_dir = cf.get('Data', 'data_dir') 112 | feature_type = cf.get('Data', 'feature_type') 113 | out_type = cf.get('Data', 'out_type') 114 | n_feats = cf.getint('Data', 'n_feats') 115 | batch_size = cf.getint("Training", 'batch_size') 116 | 117 | #Data Loader 118 | train_dataset = myDataset(data_dir, data_set='train', feature_type=feature_type, out_type=out_type, n_feats=n_feats) 119 | train_loader = myCNNDataLoader(train_dataset, batch_size=batch_size, shuffle=True, 120 | num_workers=4, pin_memory=False) 121 | dev_dataset = myDataset(data_dir, data_set="test", feature_type=feature_type, out_type=out_type, n_feats=n_feats) 122 | dev_loader = myCNNDataLoader(dev_dataset, batch_size=batch_size, shuffle=False, 123 | num_workers=4, pin_memory=False) 124 | 125 | #decoder for dev set 126 | decoder = GreedyDecoder(dev_dataset.int2phone, space_idx=-1, blank_index=0) 127 | 128 | #Define Model 129 | rnn_input_size = cf.getint('Model', 'rnn_input_size') 130 | rnn_hidden_size = cf.getint('Model', 'rnn_hidden_size') 131 | rnn_layers = cf.getint('Model', 'rnn_layers') 132 | rnn_type = RNN[cf.get('Model', 'rnn_type')] 133 | bidirectional = cf.getboolean('Model', 'bidirectional') 134 | batch_norm = cf.getboolean('Model', 'batch_norm') 135 | num_class = cf.getint('Model', 'num_class') 136 | drop_out = cf.getfloat('Model', 'num_class') 137 | model = CNN_LSTM_CTC(rnn_input_size=rnn_input_size, rnn_hidden_size=rnn_hidden_size, rnn_layers=rnn_layers, 138 | rnn_type=rnn_type, bidirectional=bidirectional, batch_norm=batch_norm, 139 | num_class=num_class, drop_out=drop_out) 140 | #model.apply(xavier_uniform_init) 141 | print(model.name) 142 | 143 | #Training 144 | init_lr = cf.getfloat('Training', 'init_lr') 145 | num_epoches = 
cf.getint('Training', 'num_epoches') 146 | end_adjust_acc = cf.getfloat('Training', 'end_adjust_acc') 147 | decay = cf.getfloat("Training", 'lr_decay') 148 | weight_decay = cf.getfloat("Training", 'weight_decay') 149 | try: 150 | seed = cf.getint('Training', 'seed') 151 | except: 152 | seed = torch.cuda.initial_seed() 153 | params = { 'num_epoches':num_epoches, 'end_adjust_acc':end_adjust_acc, 'seed':seed, 154 | 'decay':decay, 'learning_rate':init_lr, 'weight_decay':weight_decay, 'batch_size':batch_size, 155 | 'feature_type':feature_type, 'n_feats': n_feats, 'out_type': out_type } 156 | 157 | if USE_CUDA: 158 | torch.cuda.manual_seed(seed) 159 | model = model.cuda() 160 | 161 | print(params) 162 | 163 | loss_fn = CTCLoss() 164 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=weight_decay) 165 | 166 | #visualization for training 167 | from visdom import Visdom 168 | viz = Visdom(env='863_corpus') 169 | title = dataset+' '+feature_type+str(n_feats)+' CNN_LSTM_CTC' 170 | opts = [dict(title=title+" Loss", ylabel = 'Loss', xlabel = 'Epoch'), 171 | dict(title=title+" CER on Train", ylabel = 'CER', xlabel = 'Epoch'), 172 | dict(title=title+' CER on DEV', ylabel = 'DEV CER', xlabel = 'Epoch')] 173 | viz_window = [None, None, None] 174 | 175 | count = 0 176 | learning_rate = init_lr 177 | acc_best = -100 178 | acc_best_true = -100 179 | adjust_rate_flag = False 180 | stop_train = False 181 | adjust_time = 0 182 | start_time = time.time() 183 | loss_results = [] 184 | training_cer_results = [] 185 | dev_cer_results = [] 186 | 187 | while not stop_train: 188 | if count >= num_epoches: 189 | break 190 | count += 1 191 | 192 | if adjust_rate_flag: 193 | learning_rate *= decay 194 | adjust_rate_flag = False 195 | for param in optimizer.param_groups: 196 | param['lr'] *= decay 197 | 198 | print("Start training epoch: %d, learning_rate: %.5f" % (count, learning_rate)) 199 | logger.info("Start training epoch: %d, learning_rate: %.5f" % (count, learning_rate)) 200 | 201 | loss = train(model, train_loader, loss_fn, optimizer, logger, print_every=20) 202 | loss_results.append(loss) 203 | cer = dev(model, train_loader, decoder, logger) 204 | print("cer on training set is %.4f" % cer) 205 | logger.info("cer on training set is %.4f" % cer) 206 | training_cer_results.append(cer) 207 | acc = dev(model, dev_loader, decoder, logger) 208 | dev_cer_results.append(acc) 209 | 210 | #model_path_accept = './log/epoch'+str(count)+'_lr'+str(learning_rate)+'_cv'+str(acc)+'.pkl' 211 | #model_path_reject = './log/epoch'+str(count)+'_lr'+str(learning_rate)+'_cv'+str(acc)+'_rejected.pkl' 212 | 213 | if acc > (acc_best + end_adjust_acc): 214 | acc_best = acc 215 | adjust_rate_count = 0 216 | model_state = copy.deepcopy(model.state_dict()) 217 | op_state = copy.deepcopy(optimizer.state_dict()) 218 | elif (acc > acc_best - end_adjust_acc): 219 | adjust_rate_count += 1 220 | if acc > acc_best and acc > acc_best_true: 221 | acc_best_true = acc 222 | model_state = copy.deepcopy(model.state_dict()) 223 | op_state = copy.deepcopy(optimizer.state_dict()) 224 | else: 225 | adjust_rate_count = 0 226 | #torch.save(model.state_dict(), model_path_reject) 227 | print("adjust_rate_count:"+str(adjust_rate_count)) 228 | print('adjust_time:'+str(adjust_time)) 229 | logger.info("adjust_rate_count:"+str(adjust_rate_count)) 230 | logger.info('adjust_time:'+str(adjust_time)) 231 | 232 | if adjust_rate_count == 10: 233 | adjust_rate_flag = True 234 | adjust_time += 1 235 | adjust_rate_count = 0 236 | acc_best = acc_best_true 
237 | model.load_state_dict(model_state) 238 | optimizer.load_state_dict(op_state) 239 | 240 | if adjust_time == 8: 241 | stop_train = True 242 | 243 | time_used = (time.time() - start_time) / 60 244 | print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" % (count, acc, time_used)) 245 | logger.info("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" % (count, acc, time_used)) 246 | x_axis = range(count) 247 | y_axis = [loss_results[0:count], training_cer_results[0:count], dev_cer_results[0:count]] 248 | for x in range(len(viz_window)): 249 | if viz_window[x] is None: 250 | viz_window[x] = viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), opts = opts[x],) 251 | else: 252 | viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), win = viz_window[x], update = 'replace',) 253 | 254 | print("End training, best cv acc is: %.4f" % acc_best) 255 | logger.info("End training, best cv acc is: %.4f" % acc_best) 256 | best_path = os.path.join(args.log_dir, 'best_model'+'_cv'+str(acc_best)+'.pkl') 257 | cf.set('Model', 'model_file', best_path) 258 | cf.write(open(args.conf, 'w')) 259 | params['epoch']=count 260 | torch.save(CNN_LSTM_CTC.save_package(model, optimizer=optimizer, epoch=params, loss_results=loss_results, training_cer_results=training_cer_results, dev_cer_results=dev_cer_results), best_path) 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /my_863_corpus/steps/ctcDecoder.py: -------------------------------------------------------------------------------- 1 | #/usr/bin/python 2 | #encoding=utf-8 3 | 4 | #greedy decoder and beamsearch decoder for ctc 5 | 6 | import torch 7 | import numpy as np 8 | 9 | class Decoder(object): 10 | def __init__(self, int2char, space_idx = 1, blank_index = 0): 11 | self.int_to_char = int2char 12 | self.space_idx = space_idx 13 | self.blank_index = blank_index 14 | self.num_word = 0 15 | self.num_char = 0 16 | 17 | def greedy_decoder(self, prob_tensor, frame_seq_len): 18 | prob_tensor = prob_tensor.transpose(0,1) #batch_size*seq_len*output_size 19 | _, decoded = torch.max(prob_tensor, 2) 20 | decoded = decoded.view(decoded.size(0), decoded.size(1)) 21 | strings = self._convert_to_strings(decoded, frame_seq_len) 22 | return self._process_strings(strings, remove_rep=True) 23 | 24 | def decode(self): 25 | raise NotImplementedError; 26 | 27 | def cer_wer(self, prob_tensor, frame_seq_len, targets, target_sizes): 28 | strings = self.decode(prob_tensor, frame_seq_len) 29 | targets = self._unflatten_targets(targets, target_sizes) 30 | target_strings = self._process_strings(self._convert_to_strings(targets)) 31 | cer = 0 32 | wer = 0 33 | for x in range(len(target_strings)): 34 | cer += self.cer(strings[x], target_strings[x]) / float(len(target_strings[x])) 35 | wer += self.wer(strings[x], target_strings[x]) / float(len(target_strings[x].split())) 36 | return cer, wer 37 | 38 | def phone_word_error(self, prob_tensor, frame_seq_len, targets, target_sizes): 39 | strings = self.decode(prob_tensor, frame_seq_len) 40 | targets = self._unflatten_targets(targets, target_sizes) 41 | target_strings = self._process_strings(self._convert_to_strings(targets)) 42 | cer = 0 43 | wer = 0 44 | for x in range(len(target_strings)): 45 | cer += self.cer(strings[x], target_strings[x]) 46 | wer += self.wer(strings[x], target_strings[x]) 47 | self.num_word += len(target_strings[x].split()) 48 | self.num_char += len(target_strings[x]) 49 | return cer, wer 50 | 51 | def 
_unflatten_targets(self, targets, target_sizes): 52 | split_targets = [] 53 | offset = 0 54 | for size in target_sizes: 55 | split_targets.append(targets[offset : offset + size]) 56 | offset += size 57 | return split_targets 58 | 59 | def _process_strings(self, seqs, remove_rep = False): 60 | processed_strings = [] 61 | for seq in seqs: 62 | string = self._process_string(seq, remove_rep) 63 | processed_strings.append(string) 64 | return processed_strings 65 | 66 | def _process_string(self, seq, remove_rep = False): 67 | string = '' 68 | for i, char in enumerate(seq): 69 | if char != self.int_to_char[self.blank_index]: 70 | if remove_rep and i != 0 and char == seq[i - 1]: #remove dumplicates 71 | pass 72 | elif self.space_idx == -1: 73 | string = string + ' '+ char 74 | elif char == self.int_to_char[self.space_idx]: 75 | string += ' ' 76 | else: 77 | string = string + char 78 | return string 79 | 80 | def _convert_to_strings(self, seq, sizes=None): 81 | strings = [] 82 | for x in range(len(seq)): 83 | seq_len = sizes[x] if sizes is not None else len(seq[x]) 84 | string = self._convert_to_string(seq[x], seq_len) 85 | strings.append(string) 86 | return strings 87 | 88 | def _convert_to_string(self, seq, sizes): 89 | result = [] 90 | for i in range(sizes): 91 | result.append(self.int_to_char[seq[i]]) 92 | if self.space_idx == -1: 93 | return result 94 | else: 95 | return ''.join(result) 96 | 97 | def wer(self, s1, s2): 98 | b = set(s1.split() + s2.split()) 99 | word2int = dict(zip(b, range(len(b)))) 100 | 101 | w1 = [word2int[w] for w in s1.split()] 102 | w2 = [word2int[w] for w in s2.split()] 103 | return self._edit_distance(w1, w2) 104 | 105 | def cer(self, s1, s2): 106 | return self._edit_distance(s1, s2) 107 | 108 | def _edit_distance(self, src_seq, tgt_seq): # compute edit distance between two iterable objects 109 | L1, L2 = len(src_seq), len(tgt_seq) 110 | if L1 == 0: return L2 111 | if L2 == 0: return L1 112 | # construct matrix of size (L1 + 1, L2 + 1) 113 | dist = [[0] * (L2 + 1) for i in range(L1 + 1)] 114 | for i in range(1, L2 + 1): 115 | dist[0][i] = dist[0][i-1] + 1 116 | for i in range(1, L1 + 1): 117 | dist[i][0] = dist[i-1][0] + 1 118 | for i in range(1, L1 + 1): 119 | for j in range(1, L2 + 1): 120 | if src_seq[i - 1] == tgt_seq[j - 1]: 121 | cost = 0 122 | else: 123 | cost = 1 124 | dist[i][j] = min(dist[i][j-1] + 1, dist[i-1][j] + 1, dist[i-1][j-1] + cost) 125 | return dist[L1][L2] 126 | 127 | 128 | class GreedyDecoder(Decoder): 129 | def decode(self, prob_tensor, frame_seq_len): 130 | prob_tensor = prob_tensor.transpose(0,1) # (n, t, c) 131 | _, decoded = torch.max(prob_tensor, 2) 132 | decoded = decoded.view(decoded.size(0), decoded.size(1)) 133 | decoded = self._convert_to_strings(decoded, frame_seq_len) # convert digit idx to chars 134 | return self._process_strings(decoded, remove_rep=True) 135 | 136 | 137 | class BeamDecoder(Decoder): 138 | def __init__(self, int2char, beam_width = 100, blank_index = 0, space_idx = -1, lm_path=None): 139 | self.beam_width = beam_width 140 | self.top_n = top_paths 141 | self.labels = ['#'] 142 | for digit in int2char: 143 | if digit != 0: 144 | self.labels.append(int2char[digit]) 145 | super(BeamDecoder, self).__init__(int2char, space_idx=space_idx, blank_index=blank_index) 146 | 147 | import BeamSearch 148 | self._decoder = BeamSearch.ctcBeamSearch(self.labels, beam_width, None, blank_index) 149 | #import TokenPassing 150 | #self._decoder = TokenPassing.ctcTokenPassing(self.labels, dic, blank_index=blank_index) 151 | 152 | def 
convert_to_strings(self, out, seq_len): 153 | results = [] 154 | for b, batch in enumerate(out): 155 | utterances = [] 156 | for p, utt in enumerate(batch): 157 | size = seq_len[b][p] 158 | utterances.append(' '.join(map(lambda x:self.int_to_char[x], utt[0:size]))) 159 | results.append(utterances) 160 | return results 161 | 162 | def decode(self, prob_tensor, frame_seq_len=None): 163 | probs = prob_tensor.transpose(0, 1) 164 | res = self._decoder.decode(probs, frame_seq_len) 165 | return res 166 | 167 | if __name__ == '__main__': 168 | decoder = Decoder('abcde', 1, 2) 169 | print(decoder._convert_to_strings([[1,2,1,0,3],[1,2,1,1,1]])) 170 | 171 | -------------------------------------------------------------------------------- /my_863_corpus/steps/data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import os 5 | import h5py 6 | import numpy as np 7 | import torch 8 | import sys 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data import Dataset 11 | import scipy.signal 12 | import math 13 | from utils import parse_audio, process_kaldi_feat, process_label_file, process_map_file, F_Mel 14 | import utils 15 | 16 | windows = {'hamming':scipy.signal.hamming, 'hann':scipy.signal.hann, 'blackman':scipy.signal.blackman, 17 | 'bartlett':scipy.signal.bartlett} 18 | audio_conf = {"sample_rate":16000, 'window_size':0.025, 'window_stride':0.01, 'window': 'hamming'} 19 | 20 | #Override the class of Dataset 21 | #Define my own dataset over timit used the feature extracted by kaldi 22 | class myDataset(Dataset): 23 | def __init__(self, data_dir, data_set='train', feature_type='spectrum', out_type='phone', n_feats=39, normalize=True, mel=False): 24 | self.data_set = data_set 25 | self.out_type = out_type 26 | self.feature_type = feature_type 27 | self.normalize = normalize 28 | self.mel = mel 29 | h5_file = os.path.join(data_dir, feature_type+'_'+out_type+'_tmp', data_set+'.h5py') 30 | wav_path = os.path.join(data_dir, 'wav_path', data_set+'.wav.scp') 31 | mfcc_file = os.path.join(data_dir, "feature_"+feature_type, data_set+'.txt') 32 | label_file = os.path.join(data_dir,"label_"+out_type, data_set+'.text') 33 | char_file = os.path.join(data_dir, out_type+'_list.txt') 34 | if not os.path.exists(h5_file): 35 | if feature_type != 'spectrum': 36 | self.n_feats = n_feats 37 | print("Process %s data in kaldi format..." % data_set) 38 | self.process_txt(mfcc_file, label_file, char_file, h5_file) 39 | else: 40 | print("Extract spectrum with librosa...") 41 | self.n_feats = int(audio_conf['sample_rate']*audio_conf['window_size']/2+1) 42 | self.process_audio(wav_path, label_file, char_file, h5_file) 43 | else: 44 | if feature_type != "spectrum": 45 | self.n_feats = n_feats 46 | else: 47 | self.n_feats = int(audio_conf["sample_rate"]*audio_conf["window_size"]/2+1) 48 | #self.n_feats = n_feats 49 | print("Loading %s data from h5py file..." % data_set) 50 | self.load_h5py(h5_file) 51 | 52 | def process_txt(self, mfcc_file, label_file, char_file, h5_file): 53 | #read map file 54 | self.char_map, self.int2phone = process_map_file(char_file) 55 | 56 | #read the label file 57 | label_dict = process_label_file(label_file, self.out_type, self.char_map) 58 | 59 | #read the mfcc file 60 | mfcc_dict = process_kaldi_feat(mfcc_file, self.n_feats) 61 | 62 | if len(mfcc_dict) != len(label_dict): 63 | print("%s data: The num of wav and text are not the same!" 
% self.data_set) 64 | sys.exit(1) 65 | 66 | self.features_label = [] 67 | #save the data as h5 file 68 | f = h5py.File(h5_file, 'w') 69 | f.create_dataset("phone_map_key", data=self.char_map.keys()) 70 | f.create_dataset("phone_map_value", data = self.char_map.values()) 71 | for utt in mfcc_dict: 72 | grp = f.create_group(utt) 73 | self.features_label.append((torch.FloatTensor(np.array(mfcc_dict[utt])), label_dict[utt].tolist())) 74 | grp.create_dataset('data', data=np.array(mfcc_dict[utt])) 75 | grp.create_dataset('label', data=label_dict[utt]) 76 | print("Saved the %s data to h5py file" % self.data_set) 77 | #print(self.__getitem__(1)) 78 | 79 | def process_audio(self, wav_path, label_file, char_file, h5_file): 80 | #read map file 81 | self.char_map, self.int2phone = process_map_file(char_file) 82 | 83 | #read the label file 84 | label_dict = process_label_file(label_file, self.out_type, self.char_map) 85 | 86 | #extract spectrum 87 | spec_dict = dict() 88 | f = open(wav_path, 'r') 89 | for line in f.readlines(): 90 | utt, path = line.strip().split() 91 | spect = self.parse_audio(path) 92 | #print(spect) 93 | spec_dict[utt] = spect.numpy() 94 | f.close() 95 | 96 | if self.normalize: 97 | i = 0 98 | for utt in spec_dict: 99 | if i == 0: 100 | spec_all = torch.FloatTensor(spec_dict[utt]) 101 | else: 102 | spec_all = torch.cat((spec_all, torch.FloatTensor(spec_dict[utt])), 0) 103 | i += 1 104 | mean = torch.mean(spec_all, 0, True) 105 | std = torch.std(spec_all, 0, True) 106 | for utt in spec_dict: 107 | tmp = torch.add(torch.FloatTensor(spec_dict[utt]), -1, mean) 108 | spec_dict[utt] = torch.div(tmp, std).numpy() 109 | 110 | if len(spec_dict) != len(label_dict): 111 | print("%s data: The num of wav and text are not the same!" % self.data_set) 112 | sys.exit(1) 113 | 114 | self.features_label = [] 115 | #save the data as h5 file 116 | f = h5py.File(h5_file, 'w') 117 | f.create_dataset("phone_map_key", data=self.char_map.keys()) 118 | f.create_dataset("phone_map_value", data = self.char_map.values()) 119 | for utt in spec_dict: 120 | grp = f.create_group(utt) 121 | self.features_label.append((torch.FloatTensor(spec_dict[utt]), label_dict[utt].tolist())) 122 | grp.create_dataset('data', data=spec_dict[utt]) 123 | grp.create_dataset('label', data=label_dict[utt]) 124 | print("Saved the %s data to h5py file" % self.data_set) 125 | 126 | 127 | def parse_audio(self, path): 128 | y = load_audio(path) 129 | n_fft = int(audio_conf['sample_rate']*audio_conf["window_size"]) 130 | win_length = n_fft 131 | hop_length = int(audio_conf['sample_rate']*audio_conf['window_stride']) 132 | window = windows[audio_conf['window']] 133 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 134 | win_length=win_length, window=window) 135 | spect, phase = librosa.magphase(D) 136 | spect = np.log1p(spect) 137 | spect = torch.FloatTensor(spect) 138 | 139 | return spect.transpose(0,1) 140 | 141 | def load_h5py(self, h5_file): 142 | self.features_label = [] 143 | f = h5py.File(h5_file, 'r') 144 | for grp in f: 145 | if grp != 'phone_map_key' and grp != 'phone_map_value': 146 | self.features_label.append((torch.FloatTensor(np.asarray(f[grp]['data'])), np.asarray(f[grp]['label']).tolist())) 147 | self.char_map = dict() 148 | self.int2phone = dict() 149 | keys = f['phone_map_key'] 150 | values = f['phone_map_value'] 151 | for i in range(len(keys)): 152 | self.char_map[str(keys[i])] = values[i] 153 | self.int2phone[values[i]] = keys[i] 154 | self.int2phone[0]='#' 155 | print("Load %d sentences from %s dataset" % 
(self.__len__(), self.data_set)) 156 | 157 | def __getitem__(self, idx): 158 | if self.mel: 159 | spect, label = self.features_label[idx] 160 | spect = F_Mel(spect, audio_conf) 161 | return (spect, label) 162 | else: 163 | return self.features_label[idx] 164 | 165 | def __len__(self): 166 | return len(self.features_label) 167 | 168 | def create_RNN_input(batch): 169 | def func(p): 170 | return p[0].size(0) 171 | 172 | #sort batch according to the frame nums 173 | batch = sorted(batch, reverse=True, key=func) 174 | longest_sample = batch[0][0] 175 | feat_size = longest_sample.size(1) 176 | #feat_size = 101 177 | max_length = longest_sample.size(0) 178 | batch_size = len(batch) 179 | inputs = torch.zeros(batch_size, max_length, feat_size) 180 | input_sizes = torch.IntTensor(batch_size) 181 | target_sizes = torch.IntTensor(batch_size) 182 | targets = [] 183 | input_size_list = [] 184 | for x in range(batch_size): 185 | sample = batch[x] 186 | feature = sample[0] 187 | #feature = sample[0].transpose(0,1)[:101].transpose(0,1) 188 | label = sample[1] 189 | seq_length = feature.size(0) 190 | inputs[x].narrow(0, 0, seq_length).copy_(feature) 191 | input_sizes[x] = seq_length 192 | input_size_list.append(seq_length) 193 | target_sizes[x] = len(label) 194 | targets.extend(label) 195 | targets = torch.IntTensor(targets) 196 | #src_pos = [[(pos+1) if (w!=[0]*feat_size).any() else 0 for pos, w in enumerate(instance)] for instance in inputs.numpy()] 197 | #src_pos = torch.LongTensor(np.array(src_pos)) 198 | return inputs, targets, input_sizes, input_size_list, target_sizes 199 | 200 | #class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 201 | # sampler=None, batch_sampler=None, num_workers=0, 202 | # collate_fn=, 203 | # pin_memory=False, drop_last=False) 204 | #subclass of DataLoader and rewrite the collate_fn to form batch 205 | 206 | class myDataLoader(DataLoader): 207 | def __init__(self, *args, **kwargs): 208 | super(myDataLoader, self).__init__(*args, **kwargs) 209 | self.collate_fn = create_RNN_input 210 | 211 | class myCNNDataLoader(DataLoader): 212 | def __init__(self, *args, **kwargs): 213 | super(myCNNDataLoader, self).__init__(*args, **kwargs) 214 | self.collate_fn = create_CNN_input 215 | 216 | def create_CNN_input(batch): 217 | def func(p): 218 | return p[0].size(0) 219 | 220 | def change_size(size): 221 | size = int(math.floor((size-11)/2)+1) 222 | #size = int(math.floor((size-11)/1)+1) 223 | return size 224 | 225 | #sort batch according to the frame nums 226 | batch = sorted(batch, reverse=True, key=func) 227 | longest_sample = batch[0][0] 228 | feat_size = longest_sample.size(1) 229 | max_length = longest_sample.size(0) 230 | batch_size = len(batch) 231 | inputs = torch.zeros(batch_size, 1, max_length, feat_size) 232 | input_sizes = torch.IntTensor(batch_size) 233 | target_sizes = torch.IntTensor(batch_size) 234 | targets = [] 235 | input_size_list = [] 236 | for x in range(batch_size): 237 | sample = batch[x] 238 | feature = sample[0] 239 | label = sample[1] 240 | seq_length = feature.size(0) 241 | inputs[x][0].narrow(0, 0, seq_length).copy_(feature) 242 | input_sizes[x] = change_size(seq_length) 243 | input_size_list.append(change_size(seq_length)) 244 | target_sizes[x] = len(label) 245 | targets.extend(label) 246 | targets = torch.IntTensor(targets) 247 | return inputs, targets, input_sizes, input_size_list, target_sizes 248 | 249 | if __name__ == '__main__': 250 | dev_dataset = myDataset('../data_prepare/data', data_set='test', feature_type="spectrum", 
out_type='phone', n_feats=201) 251 | dev_loader = myDataLoader(dev_dataset, batch_size=8, shuffle=True, 252 | num_workers=4, pin_memory=False) 253 | #print(dev_dataset.int2phone) 254 | -------------------------------------------------------------------------------- /my_863_corpus/steps/lstm_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | #train process for the model 5 | 6 | from data_loader import myDataset, myDataLoader 7 | from model import * 8 | from ctcDecoder import GreedyDecoder 9 | from warpctc_pytorch import CTCLoss 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import time 14 | import numpy as np 15 | import argparse 16 | import ConfigParser 17 | import os 18 | import copy 19 | 20 | 21 | def train(model, train_loader, loss_fn, optimizer, logger, print_every=20): 22 | model.train() 23 | 24 | total_loss = 0 25 | print_loss = 0 26 | i = 0 27 | for data in train_loader: 28 | inputs, targets, input_sizes, input_sizes_list, target_sizes = data 29 | batch_size = inputs.size(0) 30 | inputs = inputs.transpose(0,1) 31 | 32 | inputs = Variable(inputs, requires_grad=False) 33 | targets = Variable(targets, requires_grad=False) 34 | input_sizes = Variable(input_sizes, requires_grad=False) 35 | target_sizes = Variable(target_sizes, requires_grad=False) 36 | 37 | if USE_CUDA: 38 | inputs = inputs.cuda() 39 | 40 | #pack padded input sequence 41 | inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_sizes_list) 42 | out = model(inputs) 43 | 44 | loss = loss_fn(out, targets, input_sizes, target_sizes) 45 | loss /= batch_size 46 | print_loss += loss.data[0] 47 | 48 | if (i + 1) % print_every == 0: 49 | print('batch = %d, loss = %.4f' % (i+1, print_loss / print_every)) 50 | logger.debug('batch = %d, loss = %.4f' % (i+1, print_loss / print_every)) 51 | print_loss = 0 52 | 53 | total_loss += loss.data[0] 54 | optimizer.zero_grad() 55 | loss.backward() 56 | nn.utils.clip_grad_norm(model.parameters(), 400) #防止梯度爆炸或者梯度消失,限制参数范围 57 | optimizer.step() 58 | i += 1 59 | average_loss = total_loss / i 60 | print("Epoch done, average loss: %.4f" % average_loss) 61 | logger.info("Epoch done, average loss: %.4f" % average_loss) 62 | return average_loss 63 | 64 | def dev(model, dev_loader, decoder, logger): 65 | model.eval() 66 | total_cer = 0 67 | total_tokens = 0 68 | 69 | for data in dev_loader: 70 | inputs, targets, input_sizes, input_sizes_list, target_sizes =data 71 | batch_size = inputs.size(1) 72 | inputs = inputs.transpose(0, 1) 73 | 74 | inputs = Variable(inputs, volatile=True, requires_grad=False) 75 | 76 | if USE_CUDA: 77 | inputs = inputs.cuda() 78 | 79 | inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_sizes_list) 80 | probs = model(inputs) 81 | 82 | probs = probs.data.cpu() 83 | if decoder.space_idx == -1: 84 | total_cer += decoder.phone_word_error(probs, input_sizes_list, targets, target_sizes)[1] 85 | else: 86 | total_cer += decoder.phone_word_error(probs, input_sizes_list, targets, target_sizes)[0] 87 | total_tokens += sum(target_sizes) 88 | acc = 1 - float(total_cer) / total_tokens 89 | return acc*100 90 | 91 | def init_logger(log_file): 92 | import logging 93 | from logging.handlers import RotatingFileHandler 94 | 95 | logger = logging.getLogger() 96 | hdl = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=10) 97 | formatter=logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s') 98 | 
hdl.setFormatter(formatter) 99 | logger.addHandler(hdl) 100 | logger.setLevel(logging.DEBUG) 101 | return logger 102 | 103 | RNN = {'nn.LSTM':nn.LSTM, 'nn.GRU': nn.GRU, 'nn.RNN':nn.RNN} 104 | parser = argparse.ArgumentParser(description='lstm_ctc') 105 | parser.add_argument('--conf', default='../conf/lstm_ctc_setting.conf' , help='conf file with Argument of LSTM and training') 106 | parser.add_argument('--log-dir', dest='log_dir', default='../log/', help='log dir for training') 107 | 108 | def main(): 109 | args = parser.parse_args() 110 | cf = ConfigParser.ConfigParser() 111 | try: 112 | cf.read(args.conf) 113 | except: 114 | print("conf file not exists") 115 | 116 | logger = init_logger(os.path.join(args.log_dir, 'train_lstm_ctc.log')) 117 | dataset = cf.get('Data', 'dataset') 118 | data_dir = cf.get('Data', 'data_dir') 119 | feature_type = cf.get('Data', 'feature_type') 120 | out_type = cf.get('Data', 'out_type') 121 | n_feats = cf.getint('Data', 'n_feats') 122 | batch_size = cf.getint("Training", 'batch_size') 123 | 124 | #Data Loader 125 | train_dataset = myDataset(data_dir, data_set='train', feature_type=feature_type, out_type=out_type, n_feats=n_feats) 126 | train_loader = myDataLoader(train_dataset, batch_size=batch_size, shuffle=True, 127 | num_workers=4, pin_memory=False) 128 | dev_dataset = myDataset(data_dir, data_set="test", feature_type=feature_type, out_type=out_type, n_feats=n_feats) 129 | dev_loader = myDataLoader(dev_dataset, batch_size=batch_size, shuffle=False, 130 | num_workers=4, pin_memory=False) 131 | 132 | #decoder for dev set 133 | decoder = GreedyDecoder(dev_dataset.int2phone, space_idx=-1, blank_index=0) 134 | 135 | #Define Model 136 | rnn_input_size = cf.getint('Model', 'rnn_input_size') 137 | rnn_hidden_size = cf.getint('Model', 'rnn_hidden_size') 138 | rnn_layers = cf.getint('Model', 'rnn_layers') 139 | rnn_type = RNN[cf.get('Model', 'rnn_type')] 140 | bidirectional = cf.getboolean('Model', 'bidirectional') 141 | batch_norm = cf.getboolean('Model', 'batch_norm') 142 | num_class = cf.getint('Model', 'num_class') 143 | drop_out = cf.getfloat('Model', 'num_class') 144 | model = CTC_RNN(rnn_input_size=rnn_input_size, rnn_hidden_size=rnn_hidden_size, rnn_layers=rnn_layers, 145 | rnn_type=rnn_type, bidirectional=bidirectional, batch_norm=batch_norm, 146 | num_class=num_class, drop_out=drop_out) 147 | #model.apply(xavier_uniform_init) 148 | print(model.name) 149 | 150 | #Training 151 | init_lr = cf.getfloat('Training', 'init_lr') 152 | num_epoches = cf.getint('Training', 'num_epoches') 153 | end_adjust_acc = cf.getfloat('Training', 'end_adjust_acc') 154 | decay = cf.getfloat("Training", 'lr_decay') 155 | weight_decay = cf.getfloat("Training", 'weight_decay') 156 | try: 157 | seed = cf.getint('Training', 'seed') 158 | except: 159 | seed = torch.cuda.initial_seed() 160 | 161 | params = { 'num_epoches':num_epoches, 'end_adjust_acc':end_adjust_acc, 'seed':seed, 162 | 'decay':decay, 'learning_rate':init_lr, 'weight_decay':weight_decay, 'batch_size':batch_size, 163 | 'feature_type':feature_type, 'n_feats': n_feats, 'out_type': out_type } 164 | 165 | if USE_CUDA: 166 | torch.cuda.manual_seed(seed) 167 | model = model.cuda() 168 | 169 | print(params) 170 | 171 | loss_fn = CTCLoss() 172 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=weight_decay) 173 | 174 | #visualization for training 175 | from visdom import Visdom 176 | viz = Visdom(env='863_corpus') 177 | title = dataset+' '+feature_type+str(n_feats)+' LSTM_CTC' 178 | opts = 
[dict(title=title+" Loss", ylabel = 'Loss', xlabel = 'Epoch'), 179 | dict(title=title+" CER on Train", ylabel = 'CER', xlabel = 'Epoch'), 180 | dict(title=title+' CER on DEV', ylabel = 'DEV CER', xlabel = 'Epoch')] 181 | viz_window = [None, None, None] 182 | 183 | count = 0 184 | learning_rate = init_lr 185 | acc_best = -100 186 | acc_best_true = -100 187 | adjust_rate_count = 0 188 | adjust_rate_flag = False 189 | stop_train = False 190 | adjust_time = 0 191 | start_time = time.time() 192 | loss_results = [] 193 | training_cer_results = [] 194 | dev_cer_results = [] 195 | 196 | while not stop_train: 197 | if count >= num_epoches: 198 | break 199 | count += 1 200 | 201 | if adjust_rate_flag: 202 | learning_rate *= decay 203 | adjust_rate_flag = False 204 | for param in optimizer.param_groups: 205 | param['lr'] *= decay 206 | 207 | print("Start training epoch: %d, learning_rate: %.5f" % (count, learning_rate)) 208 | logger.info("Start training epoch: %d, learning_rate: %.5f" % (count, learning_rate)) 209 | 210 | loss = train(model, train_loader, loss_fn, optimizer, logger) 211 | loss_results.append(loss) 212 | cer = dev(model, train_loader, decoder, logger) 213 | print("cer on training set is %.4f" % cer) 214 | logger.info("cer on training set is %.4f" % cer) 215 | training_cer_results.append(cer) 216 | acc = dev(model, dev_loader, decoder, logger) 217 | dev_cer_results.append(acc) 218 | 219 | #model_path_accept = './log/epoch'+str(count)+'_lr'+str(learning_rate)+'_cv'+str(acc)+'.pkl' 220 | #model_path_reject = './log/epoch'+str(count)+'_lr'+str(learning_rate)+'_cv'+str(acc)+'_rejected.pkl' 221 | 222 | if acc > (acc_best + end_adjust_acc): 223 | acc_best = acc 224 | adjust_rate_count = 0 225 | model_state = copy.deepcopy(model.state_dict()) 226 | op_state = copy.deepcopy(optimizer.state_dict()) 227 | elif (acc > acc_best - end_adjust_acc): 228 | adjust_rate_count += 1 229 | if acc > acc_best and acc > acc_best_true: 230 | acc_best_true = acc 231 | model_state = copy.deepcopy(model.state_dict()) 232 | op_state = copy.deepcopy(optimizer.state_dict()) 233 | else: 234 | adjust_rate_count = 0 235 | #torch.save(model.state_dict(), model_path_reject) 236 | print("adjust_rate_count:"+str(adjust_rate_count)) 237 | print('adjust_time:'+str(adjust_time)) 238 | logger.info("adjust_rate_count:"+str(adjust_rate_count)) 239 | logger.info('adjust_time:'+str(adjust_time)) 240 | 241 | if adjust_rate_count == 10: 242 | adjust_rate_flag = True 243 | adjust_time += 1 244 | adjust_rate_count = 0 245 | acc_best = acc_best_true 246 | model.load_state_dict(model_state) 247 | optimizer.load_state_dict(op_state) 248 | 249 | if adjust_time == 8: 250 | stop_train = True 251 | 252 | time_used = (time.time() - start_time) / 60 253 | print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" % (count, acc, time_used)) 254 | logger.info("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" % (count, acc, time_used)) 255 | 256 | x_axis = range(count) 257 | y_axis = [loss_results[0:count], training_cer_results[0:count], dev_cer_results[0:count]] 258 | for x in range(len(viz_window)): 259 | if viz_window[x] is None: 260 | viz_window[x] = viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), opts = opts[x],) 261 | else: 262 | viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), win = viz_window[x], update = 'replace',) 263 | 264 | print("End training, best cv acc is: %.4f" % acc_best) 265 | logger.info("End training, best cv acc is: %.4f" % acc_best) 266 | best_path = os.path.join(args.log_dir, 
'best_model'+'_cv'+str(acc_best)+'.pkl') 267 | cf.set('Model', 'model_file', best_path) 268 | cf.write(open(args.conf, 'w')) 269 | params['epoch'] = count 270 | torch.save(CTC_RNN.save_package(model, optimizer=optimizer, epoch=params, loss_results=loss_results, training_cer_results=training_cer_results, dev_cer_results=dev_cer_results), best_path) 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /my_863_corpus/steps/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import torch 5 | import math 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from collections import OrderedDict 10 | 11 | __author__ = "Richardfan" 12 | 13 | support_rnn = {'lstm': nn.LSTM, 'rnn': nn.RNN, 'gru': nn.GRU} 14 | USE_CUDA = True 15 | 16 | def position_encoding_init(n_position, d_pos_vec): 17 | position_enc = np.array([ 18 | [pos / np.power(10000, 2*i / d_pos_vec) for i in range(d_pos_vec)] 19 | if pos !=0 else np.zeros(d_pos_vec) for pos in range(n_position)] 20 | ) 21 | 22 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) #dim 2i 23 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) #dim 2i+1 24 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 25 | 26 | class Encoder(nn.Module): 27 | def __init__(self, n_position, d_word_vec=512): 28 | super(Encoder, self).__init__() 29 | self.position_enc = nn.Embedding(n_position, d_word_vec, padding_idx=0) 30 | self.position_enc.weight.data = position_encoding_init(n_position, d_word_vec) 31 | 32 | def forward(self, src_pos): 33 | enc_input = self.position_enc(src_pos) 34 | 35 | return enc_input 36 | 37 | class SequenceWise(nn.Module): 38 | def __init__(self, module): 39 | super(SequenceWise, self).__init__() 40 | self.module = module 41 | 42 | def forward(self, x): 43 | try: 44 | x, batch_size_len = x.data, x.batch_sizes 45 | #print(x) 46 | #x.data: sum(x_len) * num_features 47 | x = self.module(x) 48 | x = nn.utils.rnn.PackedSequence(x, batch_size_len) 49 | except: 50 | t, n = x.size(0), x.size(1) 51 | x = x.view(t*n, -1) 52 | #print(x) 53 | #x : sum(x_len) * num_features 54 | x = self.module(x) 55 | x = x.view(t, n, -1) 56 | return x 57 | 58 | def __repr__(self): 59 | tmpstr = self.__class__.__name__ + ' (\n' 60 | tmpstr += self.module.__repr__() 61 | tmpstr += ')' 62 | return tmpstr 63 | 64 | class InferenceBatchLogSoftmax(nn.Module): 65 | def forward(self, x): 66 | #x: seq_len * batch_size * num 67 | 68 | if not self.training: 69 | seq_len = x.size()[0] 70 | return torch.stack([F.log_softmax(x[i]) for i in range(seq_len)], 0) 71 | else: 72 | return x 73 | 74 | class BatchRNN(nn.Module): 75 | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, 76 | bidirectional=False, batch_norm=True, dropout = 0.1): 77 | super(BatchRNN, self).__init__() 78 | self.input_size = input_size 79 | self.hidden_size = hidden_size 80 | self.bidirectional = bidirectional 81 | self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None 82 | self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, 83 | bidirectional=bidirectional, dropout = dropout, bias=False) 84 | 85 | def forward(self, x): 86 | if self.batch_norm is not None: 87 | x = self.batch_norm(x) 88 | x, _ = self.rnn(x) 89 | self.rnn.flatten_parameters() 90 | return x 91 | 92 | 93 | class CTC_RNN(nn.Module): 94 | def __init__(self, 
rnn_input_size=40, rnn_hidden_size=768, rnn_layers=5, 95 | rnn_type=nn.LSTM, bidirectional=True, 96 | batch_norm=True, num_class=28, drop_out = 0.1): 97 | super(CTC_RNN, self).__init__() 98 | self.rnn_input_size = rnn_input_size 99 | self.rnn_hidden_size = rnn_hidden_size 100 | self.rnn_layers = rnn_layers 101 | self.rnn_type = rnn_type 102 | self.num_class = num_class 103 | self.num_directions = 2 if bidirectional else 1 104 | self.name = 'CTC_RNN' 105 | self._drop_out = drop_out 106 | 107 | rnns = [] 108 | rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, 109 | rnn_type=rnn_type, bidirectional=bidirectional, 110 | batch_norm=False) 111 | rnns.append(('0', rnn)) 112 | for i in range(rnn_layers-1): 113 | rnn = BatchRNN(input_size=self.num_directions*rnn_hidden_size, 114 | hidden_size=rnn_hidden_size, rnn_type=rnn_type, 115 | bidirectional=bidirectional, dropout=drop_out, batch_norm=batch_norm) 116 | rnns.append(('%d' % (i+1), rnn)) 117 | 118 | self.rnns = nn.Sequential(OrderedDict(rnns)) 119 | 120 | if batch_norm : 121 | fc = nn.Sequential(nn.BatchNorm1d(self.num_directions*rnn_hidden_size), 122 | nn.Linear(self.num_directions*rnn_hidden_size, num_class+1, bias=False)) 123 | else: 124 | fc = nn.Linear(self.num_directions*rnn_hidden_size, num_class+1, bias=False) 125 | 126 | self.fc = SequenceWise(fc) 127 | self.inference_log_softmax = InferenceBatchLogSoftmax() 128 | 129 | def forward(self, x): 130 | #x: packed padded sequence 131 | #x.data: means the origin data 132 | #x.batch_sizes: the batch_size of each frames 133 | #x_len: type:list not torch.IntTensor 134 | x = self.rnns(x) 135 | #print(x) 136 | x = self.fc(x) 137 | 138 | x, batch_seq = nn.utils.rnn.pad_packed_sequence(x,batch_first=False) 139 | x = self.inference_log_softmax(x) 140 | 141 | return x 142 | 143 | @staticmethod 144 | def save_package(model, optimizer=None, decoder=None, epoch=None, loss_results=None, training_cer_results=None, dev_cer_results=None): 145 | package = { 146 | 'input_size': model.rnn_input_size, 147 | 'hidden_size': model.rnn_hidden_size, 148 | 'rnn_layers': model.rnn_layers, 149 | 'rnn_type': model.rnn_type, 150 | 'num_class': model.num_class, 151 | 'bidirectional': model.num_directions, 152 | '_drop_out' : model._drop_out, 153 | 'name': model.name, 154 | 'state_dict': model.state_dict() 155 | } 156 | if optimizer is not None: 157 | package['optim_dict'] = optimizer.state_dict() 158 | if decoder is not None: 159 | package['decoder'] = decoder 160 | if epoch is not None: 161 | package['epoch'] = epoch 162 | if loss_results is not None: 163 | package['loss_results'] = loss_results 164 | package['training_cer_results'] = training_cer_results 165 | package['dev_cer_results'] = dev_cer_results 166 | return package 167 | 168 | class CNN_LSTM_CTC(nn.Module): 169 | def __init__(self, rnn_input_size=201, rnn_hidden_size=256, rnn_layers=4, 170 | rnn_type=nn.LSTM, bidirectional=True, 171 | batch_norm=True, num_class=48, drop_out=0.1): 172 | super(CNN_LSTM_CTC, self).__init__() 173 | self.rnn_input_size = rnn_input_size 174 | self.rnn_hidden_size = rnn_hidden_size 175 | self.rnn_layers = rnn_layers 176 | self.rnn_type = rnn_type 177 | self.num_class = num_class 178 | self.num_directions = 2 if bidirectional else 1 179 | self._drop_out = drop_out 180 | self.name = 'CNN_LSTM_CTC' 181 | 182 | self.conv = nn.Sequential( 183 | nn.Conv2d(1, 16, kernel_size=(11, 5), stride=(2, 2)), 184 | nn.BatchNorm2d(16), 185 | nn.Hardtanh(0, 20, inplace=True), 186 | #nn.Conv2d(32, 32, kernel_size=(11, 21), stride=(1, 
2)), 187 | #nn.BatchNorm2d(32), 188 | #nn.Hardtanh(0, 20, inplace=True) 189 | ) 190 | 191 | rnn_input_size = int(math.floor(rnn_input_size-5)/2+1) 192 | #rnn_input_size = int(math.floor(rnn_input_size-21)/2+1) 193 | rnn_input_size *= 16 194 | 195 | rnns = [] 196 | rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, 197 | rnn_type=rnn_type, bidirectional=bidirectional, 198 | batch_norm=False) 199 | 200 | rnns.append(('0', rnn)) 201 | for i in range(rnn_layers-1): 202 | rnn = BatchRNN(input_size=self.num_directions*rnn_hidden_size, 203 | hidden_size=rnn_hidden_size, rnn_type=rnn_type, 204 | bidirectional=bidirectional, dropout = drop_out, batch_norm = batch_norm) 205 | rnns.append(('%d' % (i+1), rnn)) 206 | 207 | self.rnns = nn.Sequential(OrderedDict(rnns)) 208 | 209 | if batch_norm : 210 | fc = nn.Sequential(nn.BatchNorm1d(self.num_directions*rnn_hidden_size), 211 | nn.Linear(self.num_directions*rnn_hidden_size, num_class+1, bias=False)) 212 | else: 213 | fc = nn.Linear(self.num_directions*rnn_hidden_size, num_class+1, bias=False) 214 | 215 | self.fc = SequenceWise(fc) 216 | self.inference_log_softmax = InferenceBatchLogSoftmax() 217 | 218 | def forward(self, x): 219 | #x: batch_size * 1 * max_seq_length * feat_size 220 | x = self.conv(x) 221 | x = x.transpose(2, 3).contiguous() 222 | sizes = x.size() 223 | 224 | x = x.view(sizes[0], sizes[1]*sizes[2], sizes[3]) 225 | x = x.transpose(1,2).transpose(0,1).contiguous() 226 | 227 | x = self.rnns(x) 228 | #print(x) 229 | 230 | x = self.fc(x) 231 | 232 | x = self.inference_log_softmax(x) 233 | 234 | return x 235 | 236 | @staticmethod 237 | def save_package(model, optimizer=None, decoder=None, epoch=None, loss_results=None, training_cer_results=None, dev_cer_results=None): 238 | package = { 239 | 'input_size': model.rnn_input_size, 240 | 'hidden_size': model.rnn_hidden_size, 241 | 'rnn_layers': model.rnn_layers, 242 | 'rnn_type': model.rnn_type, 243 | 'num_class': model.num_class, 244 | 'bidirectional': model.num_directions, 245 | '_drop_out': model._drop_out, 246 | 'name': model.name, 247 | 'state_dict': model.state_dict() 248 | } 249 | if optimizer is not None: 250 | package['optim_dict'] = optimizer.state_dict() 251 | if decoder is not None: 252 | package['decoder'] = decoder 253 | if epoch is not None: 254 | package['epoch'] = epoch 255 | if loss_results is not None: 256 | package['loss_results'] = loss_results 257 | package['training_cer_results'] = training_cer_results 258 | package['dev_cer_results'] = dev_cer_results 259 | return package 260 | 261 | def xavier_uniform_init(m): 262 | for param in m.parameters(): 263 | if param.data.ndimension() > 1: 264 | nn.init.xavier_uniform(param.data) 265 | 266 | if __name__ == '__main__': 267 | #model = CNN_LSTM_CTC(rnn_input_size=201, rnn_hidden_size=256, rnn_layers=4, 268 | # rnn_type=nn.LSTM, bidirectional=True, batch_norm=True, 269 | # num_class=48, drop_out=0) 270 | #model.apply(xavier_uniform_init) 271 | encoder = Encoder(11, 20) 272 | src_pos = [[1,2,3,4,5,6,7,8,9,10],[1,2,3,4,5,0,0,0,0,0]] 273 | src_pos = torch.LongTensor(src_pos) 274 | src_pos = torch.autograd.Variable(src_pos) 275 | print(src_pos) 276 | out = encoder(src_pos) 277 | print(out) 278 | 279 | -------------------------------------------------------------------------------- /my_863_corpus/steps/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | from data_loader import myDataset 5 | from data_loader import myDataLoader, 
myCNNDataLoader 6 | from model import * 7 | from ctcDecoder import GreedyDecoder, BeamDecoder 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | import time 12 | import argparse 13 | import sys 14 | if sys.version[0] == '2': 15 | import ConfigParser 16 | else: 17 | import configparser 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--conf', help='conf file for training') 21 | parser.add_argument('--model-path', dest='model_path', help='Model file to decode for test') 22 | parser.add_argument('--decode-type', dest='decode_type', default='Greedy', help='Decoder for test. GreadyDecoder or Beam search Decoder') 23 | 24 | def test(): 25 | args = parser.parse_args() 26 | if args.model_path is not None: 27 | package = torch.load(args.model_path) 28 | data_dir = '../data_prepare/data' 29 | else: 30 | cf = ConfigParser.ConfigParser() 31 | cf.read(args.conf) 32 | model_path = cf.get('Model', 'model_file') 33 | data_dir = cf.get('Data', 'data_dir') 34 | package = torch.load(model_path) 35 | 36 | input_size = package['input_size'] 37 | layers = package['rnn_layers'] 38 | hidden_size = package['hidden_size'] 39 | rnn_type = package['rnn_type'] 40 | num_class = package["num_class"] 41 | feature_type = package['epoch']['feature_type'] 42 | n_feats = package['epoch']['n_feats'] 43 | out_type = package['epoch']['out_type'] 44 | model_type = package['name'] 45 | drop_out = package['_drop_out'] 46 | #weight_decay = package['epoch']['weight_decay'] 47 | #print(weight_decay) 48 | 49 | decoder_type = args.decode_type 50 | 51 | test_dataset = myDataset(data_dir, data_set='test', feature_type=feature_type, out_type=out_type, n_feats=n_feats) 52 | 53 | if model_type == 'CNN_LSTM_CTC': 54 | model = CNN_LSTM_CTC(rnn_input_size=input_size, rnn_hidden_size=hidden_size, rnn_layers=layers, 55 | rnn_type=rnn_type, bidirectional=True, batch_norm=True, num_class=num_class, drop_out=drop_out) 56 | test_loader = myCNNDataLoader(test_dataset, batch_size=8, shuffle=False, 57 | num_workers=4, pin_memory=False) 58 | else: 59 | model = CTC_RNN(rnn_input_size=input_size, rnn_hidden_size=hidden_size, rnn_layers=layers, 60 | rnn_type=rnn_type, bidirectional=True, batch_norm=True, num_class=num_class, drop_out=drop_out) 61 | test_loader = myDataLoader(test_dataset, batch_size=8, shuffle=False, 62 | num_workers=4, pin_memory=False) 63 | 64 | model.load_state_dict(package['state_dict']) 65 | model.eval() 66 | 67 | if USE_CUDA: 68 | model = model.cuda() 69 | 70 | if decoder_type == 'Greedy': 71 | decoder = GreedyDecoder(test_dataset.int2phone, space_idx=-1, blank_index=0) 72 | else: 73 | decoder = BeamDecoder(test_dataset.int2phone, top_paths=40, beam_width=20, blank_index=0, space_idx=-1, 74 | lm_path=None, lm_alpha=0.8, lm_beta=1, cutoff_prob=1.0, dic=test_dataset.phone_word) 75 | 76 | total_wer = 0 77 | total_cer = 0 78 | start = time.time() 79 | for data in test_loader: 80 | inputs, target, input_sizes, input_size_list, target_sizes = data 81 | if model.name == 'CTC_RNN': 82 | inputs = inputs.transpose(0,1) 83 | inputs = Variable(inputs, volatile=True, requires_grad=False) 84 | if USE_CUDA: 85 | inputs = inputs.cuda() 86 | 87 | if model.name == 'CTC_RNN': 88 | inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_size_list) 89 | probs = model(inputs) 90 | probs = probs.data.cpu() 91 | #print(probs) 92 | 93 | decoded = decoder.decode(probs, input_size_list) 94 | 95 | targets = decoder._unflatten_targets(target, target_sizes) 96 | labels = 
decoder._process_strings(decoder._convert_to_strings(targets)) 97 | for x in range(len(labels)): 98 | print("origin: "+ labels[x]) 99 | print("decoded: "+ decoded[x]) 100 | cer = 0 101 | wer = 0 102 | for x in range(len(labels)): 103 | cer += decoder.cer(decoded[x], labels[x]) 104 | wer += decoder.wer(decoded[x], labels[x]) 105 | decoder.num_word += len(labels[x].split()) 106 | decoder.num_char += len(labels[x]) 107 | total_cer += cer 108 | total_wer += wer 109 | CER = (1 - float(total_cer) / decoder.num_char)*100 110 | WER = (1 - float(total_wer) / decoder.num_word)*100 111 | print("Character error rate on test set: %.4f" % CER) 112 | print("Word error rate on test set: %.4f" % WER) 113 | end = time.time() 114 | time_used = (end - start) / 60.0 115 | print("Time used for decoding %d sentences: %.4f minutes" % (len(test_dataset), time_used)) 116 | 117 | if __name__ == "__main__": 118 | test() 119 | -------------------------------------------------------------------------------- /my_863_corpus/steps/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | __author__ = 'Richardfan' 5 | 6 | import torchaudio 7 | import librosa 8 | import numpy as np 9 | import torch 10 | import math 11 | 12 | def load_audio(path): 13 | ''' 14 | Input: 15 | path : string 载入音频的路径 16 | Output: 17 | sound : numpy.ndarray 单声道音频数据,如果是多声道进行平均 18 | ''' 19 | sound, _ = torchaudio.load(path) 20 | sound = sound.numpy() 21 | if len(sound.shape) > 1: 22 | if sound.shape[1] == 1: 23 | sound = sound.squeeze() 24 | else: 25 | sound - sound.mean(axis=1) 26 | return sound 27 | 28 | def parse_audio(path, audio_conf, windows): 29 | ''' 30 | Input: 31 | path : string 导入音频的路径 32 | audio_conf : dcit 求频谱的音频参数 33 | windows : dict 加窗类型 34 | Output: 35 | spect : FloatTensor 每帧的频谱 36 | ''' 37 | y = load_audio(path) 38 | n_fft = int(audio_conf['sample_rate']*audio_conf["window_size"]) 39 | win_length = n_fft 40 | hop_length = int(audio_conf['sample_rate']*audio_conf['window_stride']) 41 | window = windows[audio_conf['window']] 42 | #D = librosa.cqt(y, sr=audio_conf['sample_rate']) 43 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 44 | win_length=win_length, window=window) 45 | spect, phase = librosa.magphase(D) 46 | spect = np.log1p(spect) 47 | spect = torch.FloatTensor(spect) 48 | 49 | return spect.transpose(0,1) 50 | 51 | def F_Mel(fre_f, audio_conf): 52 | ''' 53 | Input: 54 | fre_f : FloatTensor log spectrum 55 | audio_conf : 主要需要用到采样率 56 | Output: 57 | mel_f : FloatTensor 换成mel频谱 58 | ''' 59 | n_mels = fre_f.size(1) 60 | mel_bin = librosa.mel_frequencies(n_mels=n_mels, fmin=0, fmax=audio_conf["sample_rate"]/2) / 40 61 | count = 0 62 | fre_f = fre_f.numpy().tolist() 63 | mel_f = [] 64 | for frame in fre_f: 65 | mel_f_frame = [] 66 | for i in range(n_mels): 67 | left = int(math.floor(mel_bin[i])) 68 | right = left + 1 69 | tmp = (frame[right] - frame[left]) * (mel_bin[i] - left) + frame[left] #线性插值 70 | mel_f_frame.append(tmp) 71 | mel_f.append(mel_f_frame) 72 | return torch.FloatTensor(mel_f) 73 | 74 | 75 | def process_kaldi_feat(feat_file, feat_size): 76 | ''' 77 | Input: 78 | feat_file : string 特征文件路径 79 | feat_size : int 特征大小 即特征维数 80 | Output: 81 | feat_dict : dict 特征文件中的特征,每个utt的特征是list类型 82 | ''' 83 | feat_dict = dict() 84 | f = open(feat_file, 'r') 85 | for line in f.readlines(): 86 | feat_frame = list() 87 | line = line.strip().split() 88 | if len(line) == 2: 89 | utt = line[0] 90 | feat_dict[utt] = list() 91 | continue 92 | if 
len(line) > 2: 93 | for i in range(feat_size): 94 | feat_frame.append(float(line[i])) 95 | feat_dict[utt].append(feat_frame) 96 | f.close() 97 | return feat_dict 98 | 99 | def process_label_file(label_file, label_type, char_map): 100 | ''' 101 | Input: 102 | label_file : string 标签文件路径 103 | label_type : string 标签类型(目前只支持字符和音素) 104 | char_map : dict 标签和数字的对应关系 105 | Output: 106 | label_dict : dict 所有句子的标签,每个句子是numpy类型 107 | ''' 108 | label_dict = dict() 109 | f = open(label_file, 'r') 110 | for label in f.readlines(): 111 | label = label.strip() 112 | label_list = [] 113 | if label_type == 'char': 114 | utt = label.split('\t', 1)[0] 115 | label = label.split('\t', 1)[1] 116 | for i in range(len(label)): 117 | if label[i].lower() in char_map: 118 | label_list.append(char_map[label[i].lower()]) 119 | if label[i] == ' ': 120 | label_list.append(28) 121 | else: 122 | label = label.split() 123 | utt = label[0] 124 | for i in range(1,len(label)): 125 | label_list.append(char_map[label[i]]) 126 | label_dict[utt] = np.array(label_list) 127 | f.close() 128 | return label_dict 129 | 130 | def process_map_file(map_file): 131 | ''' 132 | Input: 133 | map_file : string label和数字的对应关系文件 134 | Output: 135 | char_map : dict 对应关系字典 136 | int2phone : dict 数字到label的对应关系 137 | ''' 138 | char_map = dict() 139 | int2phone = dict() 140 | f = open(map_file, 'r') 141 | for line in f.readlines(): 142 | char, num = line.strip().split(' ') 143 | char_map[char] = int(num) 144 | int2phone[int(num)] = char 145 | f.close() 146 | int2phone[0] = '#' 147 | return char_map, int2phone 148 | 149 | if __name__ == '__main__': 150 | import scipy.signal 151 | windows = {'hamming':scipy.signal.hamming, 'hann':scipy.signal.hann, 'blackman':scipy.signal.blackman, 152 | 'bartlett':scipy.signal.bartlett} 153 | audio_conf = {"sample_rate":16000, 'window_size':0.025, 'window_stride':0.01, 'window': 'hamming'} 154 | path = '/home/fan/Audio_data/TIMIT/train/dr1/fcjf0/sa1.wav' 155 | spect = parse_audio(path, audio_conf, windows) 156 | mel_f = F_Mel(spect, audio_conf) 157 | print(spect) 158 | print(mel_f) 159 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | visdom 4 | -------------------------------------------------------------------------------- /timit/conf/backup.conf: -------------------------------------------------------------------------------- 1 | [Data] 2 | dataset = TIMIT 3 | data_dir = /home/fan/pytorch/CTC_pytorch/timit/data_prepare 4 | feature_type = spectrum 5 | n_feats = 201 6 | out_type = phone 7 | mel = True 8 | 9 | [Model] 10 | rnn_input_size = 201 11 | rnn_hidden_size = 256 12 | rnn_layers = 4 13 | rnn_type = nn.LSTM 14 | bidirectional = True 15 | batch_norm = True 16 | num_class = 48 17 | drop_out = 0 18 | add_cnn = True 19 | model_file = 20 | 21 | [CNN] 22 | layers = 2 23 | channel = [(1, 32), (32, 32)] 24 | kernel_size = [(3, 41), (3, 21)] 25 | stride = [(1, 2), (2, 2)] 26 | padding = [(0, 0), (0, 0)] 27 | pooling = None 28 | batch_norm = True 29 | activation_function = relu 30 | 31 | [Training] 32 | use_cuda = True 33 | init_lr = 0.001 34 | num_epoches = 500 35 | end_adjust_acc = 2 36 | lr_decay = 0.5 37 | batch_size = 16 38 | weight_decay = 0.05 39 | seed = 1234 40 | 41 | [Decode] 42 | decode_type = Greedy 43 | beam_width = 20 44 | lm_alpha = 0.1 45 | eval_dataset = test 46 | 47 | -------------------------------------------------------------------------------- 
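Note on the [CNN]/[Model] settings above: when add_cnn is enabled, the conv layers shrink the frequency axis before features reach the RNN, so the dimension the RNN actually sees is not the raw rnn_input_size / n_feats value. Below is a minimal sketch of that computation, mirroring the floor((W + 2*padding - kernel)/stride) + 1 arithmetic used in models/model_ctc.py; the helper name is illustrative only, not part of the repo.

```python
import math

def rnn_feat_dim(freq_dim, cnn_layers):
    # cnn_layers: (out_channel, kernel, stride, padding); the last three are
    # (time, freq) pairs, written the same way as in the config files above.
    for out_channel, kernel, stride, padding in cnn_layers:
        freq_dim = int(math.floor((freq_dim + 2 * padding[1] - kernel[1]) / stride[1]) + 1)
    return freq_dim * out_channel  # channels are flattened into the feature axis

# backup.conf values: 201 spectrum bins, two conv layers
layers = [(32, (3, 41), (1, 2), (0, 0)),
          (32, (3, 21), (2, 2), (0, 0))]
print(rnn_feat_dim(201, layers))  # 992 features per frame are fed to the RNN
```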
/timit/conf/ctc_config.yaml: -------------------------------------------------------------------------------- 1 | #exp name and save dir 2 | exp_name: 'ctc_fbank_cnn' 3 | checkpoint_dir: 'checkpoint/' 4 | 5 | #Data 6 | vocab_file: 'data/units' 7 | train_scp_path: 'data/train/fbank.scp' 8 | train_lab_path: 'data/train/phn_text' 9 | valid_scp_path: 'data/dev/fbank.scp' 10 | valid_lab_path: 'data/dev/phn_text' 11 | left_ctx: 0 12 | right_ctx: 2 13 | n_skip_frame: 2 14 | n_downsample: 2 15 | num_workers: 1 16 | shuffle_train: True 17 | feature_dim: 81 18 | output_class_dim: 39 19 | mel: False 20 | feature_type: "fbank" 21 | 22 | #Model 23 | rnn_input_size: 243 24 | rnn_hidden_size: 384 25 | rnn_layers: 4 26 | rnn_type: "nn.LSTM" 27 | bidirectional: True 28 | batch_norm: True 29 | drop_out: 0.2 30 | 31 | #CNN 32 | add_cnn: True 33 | layers: 2 34 | channel: "[(1, 32), (32, 32)]" 35 | kernel_size: "[(3, 3), (3, 3)]" 36 | stride: "[(1, 2), (2, 2)]" 37 | padding: "[(1, 1), (1, 1)]" 38 | pooling: "None" 39 | batch_norm: True 40 | activation_function: "relu" 41 | 42 | #[Training] 43 | use_gpu: True 44 | init_lr: 0.001 45 | num_epoches: 500 46 | end_adjust_acc: 2 47 | lr_decay: 0.5 48 | batch_size: 8 49 | weight_decay: 0.0005 50 | seed: 1 51 | verbose_step: 50 52 | 53 | #[test] 54 | test_scp_path: 'data/test/fbank.scp' 55 | test_lab_path: 'data/test/phn_text' 56 | decode_type: "Greedy" 57 | beam_width: 10 58 | lm_alpha: 0.1 59 | lm_path: 'data/lm_phone_bg.arpa' 60 | 61 | -------------------------------------------------------------------------------- /timit/conf/dev_spk.list: -------------------------------------------------------------------------------- 1 | faks0 2 | fdac1 3 | fjem0 4 | mgwt0 5 | mjar0 6 | mmdb1 7 | mmdm2 8 | mpdf0 9 | fcmh0 10 | fkms0 11 | mbdg0 12 | mbwm0 13 | mcsh0 14 | fadg0 15 | fdms0 16 | fedw0 17 | mgjf0 18 | mglb0 19 | mrtk0 20 | mtaa0 21 | mtdt0 22 | mthc0 23 | mwjg0 24 | fnmr0 25 | frew0 26 | fsem0 27 | mbns0 28 | mmjr0 29 | mdls0 30 | mdlf0 31 | mdvc0 32 | mers0 33 | fmah0 34 | fdrw0 35 | mrcs0 36 | mrjm4 37 | fcal1 38 | mmwh0 39 | fjsj0 40 | majc0 41 | mjsw0 42 | mreb0 43 | fgjd0 44 | fjmg0 45 | mroa0 46 | mteb0 47 | mjfc0 48 | mrjr0 49 | fmml0 50 | mrws1 51 | -------------------------------------------------------------------------------- /timit/conf/fbank.conf: -------------------------------------------------------------------------------- 1 | --window-type=hamming 2 | --num-mel-bins=80 3 | --use-energy 4 | 5 | -------------------------------------------------------------------------------- /timit/conf/mfcc.conf: -------------------------------------------------------------------------------- 1 | --use-energy=false # only non-default option. 
2 | -------------------------------------------------------------------------------- /timit/conf/phones.60-48-39.map: -------------------------------------------------------------------------------- 1 | aa aa aa 2 | ae ae ae 3 | ah ah ah 4 | ao ao aa 5 | aw aw aw 6 | ax ax ah 7 | ax-h ax ah 8 | axr er er 9 | ay ay ay 10 | b b b 11 | bcl vcl sil 12 | ch ch ch 13 | d d d 14 | dcl vcl sil 15 | dh dh dh 16 | dx dx dx 17 | eh eh eh 18 | el el l 19 | em m m 20 | en en n 21 | eng ng ng 22 | epi epi sil 23 | er er er 24 | ey ey ey 25 | f f f 26 | g g g 27 | gcl vcl sil 28 | h# sil sil 29 | hh hh hh 30 | hv hh hh 31 | ih ih ih 32 | ix ix ih 33 | iy iy iy 34 | jh jh jh 35 | k k k 36 | kcl cl sil 37 | l l l 38 | m m m 39 | n n n 40 | ng ng ng 41 | nx n n 42 | ow ow ow 43 | oy oy oy 44 | p p p 45 | pau sil sil 46 | pcl cl sil 47 | q 48 | r r r 49 | s s s 50 | sh sh sh 51 | t t t 52 | tcl cl sil 53 | th th th 54 | uh uh uh 55 | uw uw uw 56 | ux uw uw 57 | v v v 58 | w w w 59 | y y y 60 | z z z 61 | zh zh sh 62 | -------------------------------------------------------------------------------- /timit/conf/test_spk.list: -------------------------------------------------------------------------------- 1 | mdab0 2 | mwbt0 3 | felc0 4 | mtas1 5 | mwew0 6 | fpas0 7 | mjmp0 8 | mlnt0 9 | fpkt0 10 | mlll0 11 | mtls0 12 | fjlm0 13 | mbpm0 14 | mklt0 15 | fnlp0 16 | mcmj0 17 | mjdh0 18 | fmgd0 19 | mgrt0 20 | mnjm0 21 | fdhc0 22 | mjln0 23 | mpam0 24 | fmld0 25 | -------------------------------------------------------------------------------- /timit/conf/train_spk.list: -------------------------------------------------------------------------------- 1 | fcjf0 2 | fdaw0 3 | fdml0 4 | fecd0 5 | fetb0 6 | fjsp0 7 | fkfb0 8 | fmem0 9 | fsah0 10 | fsjk1 11 | fsma0 12 | ftbr0 13 | fvfb0 14 | fvmh0 15 | mcpm0 16 | mdac0 17 | mdpk0 18 | medr0 19 | mgrl0 20 | mjeb1 21 | mjwt0 22 | mkls0 23 | mklw0 24 | mmgg0 25 | mmrp0 26 | mpgh0 27 | mpgr0 28 | mpsw0 29 | mrai0 30 | mrcg0 31 | mrdd0 32 | mrso0 33 | mrws0 34 | mtjs0 35 | mtpf0 36 | mtrr0 37 | mwad0 38 | mwar0 39 | faem0 40 | fajw0 41 | fcaj0 42 | fcmm0 43 | fcyl0 44 | fdas1 45 | fdnc0 46 | fdxw0 47 | feac0 48 | fhlm0 49 | fjkl0 50 | fkaa0 51 | flma0 52 | flmc0 53 | fmjb0 54 | fmkf0 55 | fmmh0 56 | fpjf0 57 | frll0 58 | fscn0 59 | fskl0 60 | fsrh0 61 | ftmg0 62 | marc0 63 | mbjv0 64 | mcew0 65 | mctm0 66 | mdbp0 67 | mdem0 68 | mdlb0 69 | mdlc2 70 | mdmt0 71 | mdps0 72 | mdss0 73 | mdwd0 74 | mefg0 75 | mhrm0 76 | mjae0 77 | mjbg0 78 | mjde0 79 | mjeb0 80 | mjhi0 81 | mjma0 82 | mjmd0 83 | mjpm0 84 | mjrp0 85 | mkah0 86 | mkaj0 87 | mkdt0 88 | mkjo0 89 | mmaa0 90 | mmag0 91 | mmds0 92 | mmgk0 93 | mmxs0 94 | mppc0 95 | mprb0 96 | mrab0 97 | mrcw0 98 | mrfk0 99 | mrgs0 100 | mrhl0 101 | mrjh0 102 | mrjm0 103 | mrjm1 104 | mrjt0 105 | mrlj0 106 | mrlr0 107 | mrms0 108 | msat0 109 | mtat1 110 | mtbc0 111 | mtdb0 112 | mtjg0 113 | mwsb0 114 | mzmb0 115 | falk0 116 | fcke0 117 | fcmg0 118 | fdfb0 119 | fdjh0 120 | feme0 121 | fgcs0 122 | fgrw0 123 | fjlg0 124 | fjlr0 125 | flac0 126 | fljd0 127 | fltm0 128 | fmjf0 129 | fntb0 130 | fpaz0 131 | fsjs0 132 | fsjw0 133 | fskc0 134 | fsls0 135 | madc0 136 | makb0 137 | makr0 138 | mapv0 139 | mbef0 140 | mcal0 141 | mcdc0 142 | mcdd0 143 | mcef0 144 | mdbb1 145 | mddc0 146 | mdef0 147 | mdhs0 148 | mdjm0 149 | mdlc0 150 | mdlh0 151 | mdns0 152 | mdss1 153 | mdtb0 154 | mdwm0 155 | mfmc0 156 | mgaf0 157 | mhjb0 158 | mhmr0 159 | milb0 160 | mjda0 161 | mjjb0 162 | mjkr0 163 | mjlg1 164 | mjrh1 165 | mkls1 166 | mkxl0 167 | mlns0 168 
| mmam0 169 | mmar0 170 | mmeb0 171 | mmjb1 172 | mmsm0 173 | mprd0 174 | mrbc0 175 | mrds0 176 | mree0 177 | mreh1 178 | mrjb1 179 | mrtc0 180 | mrtj0 181 | mrwa0 182 | msfv0 183 | mtjm0 184 | mtkp0 185 | mtlb0 186 | mtpg0 187 | mtpp0 188 | mvjh0 189 | mwdk0 190 | mwgr0 191 | falr0 192 | fbas0 193 | fbmj0 194 | fcag0 195 | fdkn0 196 | feeh0 197 | fjwb1 198 | fjxp0 199 | fkdw0 200 | fklc0 201 | flhd0 202 | flkm0 203 | fpaf0 204 | fsak0 205 | fssb0 206 | maeb0 207 | marw0 208 | mbma0 209 | mbwp0 210 | mcdr0 211 | mcss0 212 | mdcd0 213 | mdma0 214 | mesg0 215 | mfrm0 216 | mfwk0 217 | mgag0 218 | mgjc0 219 | mgrp0 220 | mgxp0 221 | mjac0 222 | mjdc0 223 | mjee0 224 | mjjj0 225 | mjlb0 226 | mjls0 227 | mjmm0 228 | mjpm1 229 | mjrh0 230 | mjsr0 231 | mjws0 232 | mjxl0 233 | mkam0 234 | mlbc0 235 | mlel0 236 | mljc0 237 | mljh0 238 | mlsh0 239 | mmbs0 240 | mmdm0 241 | mmgc0 242 | mnet0 243 | mpeb0 244 | mprk0 245 | mprt0 246 | mrab1 247 | mrfl0 248 | mrgm0 249 | mrsp0 250 | msfh0 251 | msmc0 252 | msms0 253 | msrg0 254 | mstf0 255 | mtas0 256 | mtqc0 257 | mtrc0 258 | mtrt0 259 | fbjl0 260 | fbmh0 261 | fcdr1 262 | fdmy0 263 | fdtd0 264 | fear0 265 | fexm0 266 | fgdp0 267 | fgmb0 268 | fjxm0 269 | fkkh0 270 | flja0 271 | fljg0 272 | flmk0 273 | flod0 274 | fmpg0 275 | fpmy0 276 | fsag0 277 | fsdc0 278 | fsjg0 279 | fskp0 280 | fsmm0 281 | fsms1 282 | ftbw0 283 | ftlg0 284 | mbgt0 285 | mchl0 286 | mclm0 287 | mdas0 288 | mdhl0 289 | mdsj0 290 | mdwh0 291 | megj0 292 | mewm0 293 | mfer0 294 | mges0 295 | mgsh0 296 | mhit0 297 | mhmg0 298 | mjdm0 299 | mjfh0 300 | mjpg0 301 | mjrg0 302 | mjwg0 303 | mjxa0 304 | mmab1 305 | mmcc0 306 | mmdm1 307 | mmvp0 308 | mmwb0 309 | mpmb0 310 | mram0 311 | mrav0 312 | mrew1 313 | mrkm0 314 | mrld0 315 | mrml0 316 | mrvg0 317 | msas0 318 | msdh0 319 | msem1 320 | msrr0 321 | mtat0 322 | mtdp0 323 | mtmt0 324 | mvlo0 325 | mwac0 326 | mwch0 327 | mwem0 328 | mwsh0 329 | fapb0 330 | fbch0 331 | fhxs0 332 | fjdm2 333 | fklc1 334 | flag0 335 | fmju0 336 | fpad0 337 | frjb0 338 | fsbk0 339 | fsdj0 340 | fsgf0 341 | ftaj0 342 | mabc0 343 | majp0 344 | mbma1 345 | mcae0 346 | mdrd0 347 | meal0 348 | mejl0 349 | mesj0 350 | mjrk0 351 | mkes0 352 | mkln0 353 | mmdb0 354 | mpgr1 355 | mrmb0 356 | mrxb0 357 | msat1 358 | msds0 359 | msjk0 360 | msmr0 361 | msvs0 362 | mtju0 363 | mtxs0 364 | fblv0 365 | fcjs0 366 | fcrz0 367 | fjen0 368 | fjhk0 369 | fjrp1 370 | fjsk0 371 | fkde0 372 | fksr0 373 | fleh0 374 | flet0 375 | fmah1 376 | fmkc0 377 | fpab1 378 | fpac0 379 | freh0 380 | fspm0 381 | fvkb0 382 | madd0 383 | maeo0 384 | mafm0 385 | mbar0 386 | mbbr0 387 | mbml0 388 | mbom0 389 | mbth0 390 | mclk0 391 | mcre0 392 | mcth0 393 | mdcm0 394 | mded0 395 | mdks0 396 | mdlc1 397 | mdlm0 398 | mdlr0 399 | mdlr1 400 | mdpb0 401 | mfxs0 402 | mfxv0 403 | mgak0 404 | mgar0 405 | mgaw0 406 | mgsl0 407 | mhbs0 408 | mhxl0 409 | mjai0 410 | mjdg0 411 | mjfr0 412 | mjjm0 413 | mjra0 414 | mkag0 415 | mkdb0 416 | mklr0 417 | mmdg0 418 | mmws1 419 | mntw0 420 | mpar0 421 | mpfu0 422 | mrem0 423 | mrlj1 424 | mrmg0 425 | mrmh0 426 | mrpc1 427 | msah1 428 | msdb0 429 | mses0 430 | mtab0 431 | mter0 432 | mtkd0 433 | mtlc0 434 | mtml0 435 | mtmn0 436 | mtpr0 437 | mtwh1 438 | mvrw0 439 | mwre0 440 | mwrp0 441 | fbcg1 442 | fceg0 443 | fclt0 444 | fjrb0 445 | fklh0 446 | fmbg0 447 | fnkl0 448 | fpls0 449 | mbcg0 450 | mbsb0 451 | mcxm0 452 | mejs0 453 | mkdd0 454 | mkrg0 455 | mmea0 456 | mmlm0 457 | mmpm0 458 | mmws0 459 | mrdm0 460 | mrlk0 461 | mrre0 462 | mtcs0 463 | 
-------------------------------------------------------------------------------- /timit/local/make_spectrum.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | 3 | #The code make the Full-mell Spectrogram feature and save it as ark and scp 4 | #Author: Richardfan 5 | #Date: 2018.4.24 6 | 7 | import torchaudio 8 | import librosa 9 | import numpy as np 10 | import struct 11 | import sys 12 | import scipy.signal 13 | 14 | class KaldiWriteOut(object): 15 | def __init__(self, ark_path, scp_path): 16 | self.ark_path = ark_path 17 | self.scp_path = scp_path 18 | self.ark_file_write = open(ark_path, 'wb') 19 | self.scp_file_write = open(scp_path, 'w') 20 | self.pos = 0 21 | 22 | def write_kaldi_mat(self, utt_id, utt_mat): 23 | utt_mat = np.asarray(utt_mat, dtype=np.float32) 24 | rows, cols = utt_mat.shape 25 | self.ark_file_write.write(struct.pack('<%ds'%(len(utt_id)), utt_id)) 26 | self.ark_file_write.write(struct.pack(' 1: 48 | if sound.shape[1] == 1: 49 | sound = sound.squeeze() 50 | else: 51 | sound - sound.mean(axis=1) 52 | return sound 53 | 54 | def parse_audio(path, audio_conf, windows, normalize=True): 55 | ''' 56 | Input: 57 | path : string 导入音频的路径 58 | audio_conf : dcit 求频谱的音频参数 59 | windows : dict 加窗类型 60 | Output: 61 | spect : ndarray 每帧的频谱 62 | ''' 63 | y = load_audio(path) 64 | n_fft = int(audio_conf['sample_rate']*audio_conf["window_size"]) 65 | win_length = n_fft 66 | hop_length = int(audio_conf['sample_rate']*audio_conf['window_stride']) 67 | window = windows[audio_conf['window']] 68 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 69 | win_length=win_length, window=window) 70 | spect, phase = librosa.magphase(D) 71 | 72 | spect = np.log1p(spect) 73 | 74 | if normalize: 75 | mean = spect.mean() 76 | std = spect.std() 77 | spect = np.add(spect, -mean) 78 | spect = np.divide(spect, std) 79 | 80 | return spect.transpose() 81 | 82 | def make_spectrum(wave_path, ark_file, scp_file): 83 | windows = {'hamming':scipy.signal.hamming, 'hann':scipy.signal.hann, 'blackman':scipy.signal.blackman, 84 | 'bartlett':scipy.signal.bartlett} 85 | audio_conf = {"sample_rate":16000, 'window_size':0.025, 'window_stride':0.01, 'window': 'hamming'} 86 | arkwriter = KaldiWriteOut(ark_file, scp_file) 87 | with open(wave_path, 'r') as rf: 88 | i = 0 89 | for lines in rf.readlines(): 90 | utt_id, path = lines.strip().split() 91 | utt_mat = parse_audio(path, audio_conf, windows, normalize=True) 92 | arkwriter.write_kaldi_mat(utt_id, utt_mat) 93 | i += 1 94 | if i %10 == 0: 95 | print("Processed %d sentences..." % i) 96 | arkwriter.close() 97 | print("Done. Processed %d sentences..." 
% i) 98 | 99 | if __name__ == '__main__': 100 | if len(sys.argv) != 4: 101 | print("Usage: python "+sys.argv[0] + ' [wav_path] [ark file to write] [scp file to write]') 102 | sys.exit(1) 103 | wave_path = sys.argv[1] 104 | ark_file = sys.argv[2] 105 | scp_file = sys.argv[3] 106 | make_spectrum(wave_path, ark_file, scp_file) 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /timit/local/normalize_phone.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | 3 | import os 4 | import sys 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description="Normalize the phoneme on TIMIT") 8 | parser.add_argument("--map", default="./decode_map_48-39/phones.60-48-39.map", help="The map file") 9 | parser.add_argument("--to", default=48, help="Determine how many phonemes to map") 10 | parser.add_argument("--src", default='./data_prepare/train/phn_text', help="The source file to mapping") 11 | parser.add_argument("--tgt", default='./data_prepare/train/48_text' ,help="The target file after mapping") 12 | 13 | def main(): 14 | args = parser.parse_args() 15 | if not os.path.exists(args.map) or not os.path.exists(args.src): 16 | print("Map file or source file not exist !") 17 | sys.exit(1) 18 | 19 | map_dict = {} 20 | with open(args.map) as f: 21 | for line in f.readlines(): 22 | line = line.strip().split('\t') 23 | if args.to == "60-48": 24 | if len(line) == 1: 25 | map_dict[line[0]] = "" 26 | else: 27 | map_dict[line[0]] = line[1] 28 | elif args.to == "60-39": 29 | if len(line) == 1: 30 | map_dict[line[0]] = "" 31 | else: 32 | map_dict[line[0]] = line[2] 33 | elif args.to == "48-39": 34 | if len(line) == 3: 35 | map_dict[line[1]] = line[2] 36 | else: 37 | print("%s phonemes are not supported" % args.to) 38 | sys.exit(1) 39 | 40 | with open(args.src, 'r') as rf, open(args.tgt, 'w') as wf: 41 | for line in rf.readlines(): 42 | line = line.strip().split(' ') 43 | uttid, utt = line[0], line[1:] 44 | map_utt = [ map_dict[phone] for phone in utt if map_dict[phone] != "" ] 45 | wf.writelines(uttid + ' ' + ' '.join(map_utt) + '\n') 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /timit/local/timit_data_prep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #此文件用来得到训练集,验证集和测试集的音频路径文件和转录文本即标签文件以便后续处理 4 | #输入的参数时TIMIT数据库的路径。 5 | #更换数据集之后,因为数据的目录结构不一致,需要对此脚本进行简单的修改。 6 | 7 | if [ $# -ne 2 ]; then 8 | echo "Need directory of TIMIT dataset !" 9 | exit 1; 10 | fi 11 | 12 | conf_dir=`pwd`/conf 13 | prepare_dir=`pwd`/data 14 | map_file=$conf_dir/phones.60-48-39.map 15 | phoneme_map=$2 16 | 17 | . path.sh 18 | sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | [ -f $conf_dir/test_spk.list ] || error_exit "$PROG: Eval-set speaker list not found."; 25 | [ -f $conf_dir/dev_spk.list ] || error_exit "$PROG: dev-set speaker list not found."; 26 | 27 | #根据数据库train,test的名称修改,有时候下载下来train可能是大写或者是其他形式 28 | train_dir=train 29 | test_dir=test 30 | 31 | ls -d "$1"/$train_dir/dr*/* | sed -e "s:^.*/::" > $conf_dir/train_spk.list 32 | 33 | tmpdir=`pwd`/tmp 34 | mkdir -p $tmpdir $prepare_dir 35 | for x in train dev test; do 36 | if [ ! -d $prepare_dir/$x ]; then 37 | mkdir -p $prepare_dir/$x 38 | fi 39 | 40 | # 只使用 si & sx 的语音. 
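# Only SI & SX utterances are kept: SA* recordings are excluded (matching the
# 3696/400/192 sentence counts given in the top-level README), and the grep
# against conf/${x}_spk.list assigns utterances to train/dev/test by speaker.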
41 | find $1/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ 42 | | grep -f $conf_dir/${x}_spk.list > $tmpdir/${x}_sph.flist 43 | 44 | #获得每句话的id标识 45 | sed -e 's:.*/\(.*\)/\(.*\).WAV$:\1_\2:i' $tmpdir/${x}_sph.flist \ 46 | > $tmpdir/${x}_sph.uttids 47 | 48 | #生成wav.scp,即每句话的音频路径 49 | paste -d" " $tmpdir/${x}_sph.uttids $tmpdir/${x}_sph.flist \ 50 | | sort -k1,1 > $prepare_dir/$x/wav.scp 51 | 52 | awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < $prepare_dir/$x/wav.scp > $prepare_dir/$x/wav_sph.scp 53 | 54 | for y in wrd phn; do 55 | find $1/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.'$y'' \ 56 | | grep -f $conf_dir/${x}_spk.list > $tmpdir/${x}_txt.flist 57 | sed -e 's:.*/\(.*\)/\(.*\).'$y'$:\1_\2:i' $tmpdir/${x}_txt.flist \ 58 | > $tmpdir/${x}_txt.uttids 59 | while read line; do 60 | [ -f $line ] || error_exit "Cannot find transcription file '$line'"; 61 | cut -f3 -d' ' "$line" | tr '\n' ' ' | sed -e 's: *$:\n:' 62 | done < $tmpdir/${x}_txt.flist > $tmpdir/${x}_txt.trans 63 | 64 | #将句子标识(uttid)和文本标签放在一行并按照uttid进行排序使其与音频路径顺序一致 65 | paste -d" " $tmpdir/${x}_txt.uttids $tmpdir/${x}_txt.trans \ 66 | | sort -k1,1 > $tmpdir/${x}.trans 67 | 68 | #生成文本标签 69 | cat $tmpdir/${x}.trans | sort > $prepare_dir/$x/${y}_text || exit 1; 70 | if [ $y == phn ]; then 71 | cp $prepare_dir/$x/${y}_text $prepare_dir/$x/${y}_text.tmp 72 | python local/normalize_phone.py --map $map_file --to $phoneme_map --src $prepare_dir/$x/${y}_text.tmp --tgt $prepare_dir/$x/${y}_text 73 | rm -f $prepare_dir/$x/${y}_text.tmp 74 | fi 75 | done 76 | done 77 | 78 | rm -rf $tmpdir 79 | 80 | echo "Data preparation succeeded" 81 | -------------------------------------------------------------------------------- /timit/models/model_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import editdistance as ed 8 | import torch.nn.functional as F 9 | from collections import OrderedDict 10 | 11 | __author__ = "Ruchao Fan" 12 | 13 | class BatchRNN(nn.Module): 14 | """ 15 | Add BatchNorm before rnn to generate a batchrnn layer 16 | """ 17 | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, 18 | bidirectional=False, batch_norm=True, dropout=0.1): 19 | super(BatchRNN, self).__init__() 20 | self.input_size = input_size 21 | self.hidden_size = hidden_size 22 | self.bidirectional = bidirectional 23 | self.batch_norm = nn.BatchNorm1d(input_size) if batch_norm else None 24 | self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, 25 | bidirectional=bidirectional, bias=False) 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | def forward(self, x): 29 | if self.batch_norm is not None: 30 | x = x.transpose(-1, -2) 31 | x = self.batch_norm(x) 32 | x = x.transpose(-1, -2) 33 | x, _ = self.rnn(x) 34 | x = self.dropout(x) 35 | #self.rnn.flatten_parameters() 36 | return x 37 | 38 | class LayerCNN(nn.Module): 39 | """ 40 | One CNN layer include conv2d, batchnorm, activation and maxpooling 41 | """ 42 | def __init__(self, in_channel, out_channel, kernel_size, stride, padding, pooling_size=None, 43 | activation_function=nn.ReLU, batch_norm=True, dropout=0.1): 44 | super(LayerCNN, self).__init__() 45 | if len(kernel_size) == 2: 46 | self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding) 47 | self.batch_norm = nn.BatchNorm2d(out_channel) if batch_norm else None 48 | else: 49 | self.conv = 
nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding) 50 | self.batch_norm = nn.BatchNorm1d(out_channel) if batch_norm else None 51 | self.activation = activation_function(inplace=True) 52 | if pooling_size is not None and len(kernel_size) == 2: 53 | self.pooling = nn.MaxPool2d(pooling_size) 54 | elif len(kernel_size) == 1: 55 | self.pooling = nn.MaxPool1d(pooling_size) 56 | else: 57 | self.pooling = None 58 | self.dropout = nn.Dropout(p=dropout) 59 | 60 | def forward(self, x): 61 | x = self.conv(x) 62 | if self.batch_norm is not None: 63 | x = self.batch_norm(x) 64 | x = self.activation(x) 65 | if self.pooling is not None: 66 | x = self.pooling(x) 67 | x = self.dropout(x) 68 | return x 69 | 70 | class CTC_Model(nn.Module): 71 | def __init__(self, add_cnn=False, cnn_param=None, rnn_param=None, num_class=39, drop_out=0.1): 72 | """ 73 | add_cnn [bool]: whether add cnn in the model 74 | cnn_param [dict]: cnn parameters, only support Conv2d i.e. 75 | cnn_param = {"layer":[[(in_channel, out_channel), (kernel_size), (stride), (padding), (pooling_size)],...], 76 | "batch_norm":True, "activate_function":nn.ReLU} 77 | rnn_param [dict]: rnn parameters i.e. 78 | rnn_param = {"rnn_input_size":201, "rnn_hidden_size":256, ....} 79 | num_class [int]: the number of modelling units, add blank to be the number of classes 80 | drop_out [float]: drop_out rate for all 81 | """ 82 | super(CTC_Model, self).__init__() 83 | self.add_cnn = add_cnn 84 | self.cnn_param = cnn_param 85 | if rnn_param is None or type(rnn_param) != dict: 86 | raise ValueError("rnn_param need to be a dict to contain all params of rnn!") 87 | self.rnn_param = rnn_param 88 | self.num_class = num_class 89 | self.num_directions = 2 if rnn_param["bidirectional"] else 1 90 | self.drop_out = drop_out 91 | 92 | if add_cnn: 93 | cnns = [] 94 | activation = cnn_param["activate_function"] 95 | batch_norm = cnn_param["batch_norm"] 96 | rnn_input_size = rnn_param["rnn_input_size"] 97 | cnn_layers = cnn_param["layer"] 98 | for n in range(len(cnn_layers)): 99 | in_channel = cnn_layers[n][0][0] 100 | out_channel = cnn_layers[n][0][1] 101 | kernel_size = cnn_layers[n][1] 102 | stride = cnn_layers[n][2] 103 | padding = cnn_layers[n][3] 104 | pooling_size = cnn_layers[n][4] 105 | 106 | cnn = LayerCNN(in_channel, out_channel, kernel_size, stride, padding, pooling_size, 107 | activation_function=activation, batch_norm=batch_norm, dropout=drop_out) 108 | cnns.append(('%d' % n, cnn)) 109 | 110 | try: 111 | rnn_input_size = int(math.floor((rnn_input_size+2*padding[1]-kernel_size[1])/stride[1])+1) 112 | except: 113 | #if using 1-d Conv 114 | rnn_input_size = rnn_input_size 115 | self.conv = nn.Sequential(OrderedDict(cnns)) 116 | rnn_input_size *= out_channel 117 | else: 118 | rnn_input_size = rnn_param["rnn_input_size"] 119 | 120 | rnns = [] 121 | rnn_hidden_size = rnn_param["rnn_hidden_size"] 122 | rnn_type = rnn_param["rnn_type"] 123 | rnn_layers = rnn_param["rnn_layers"] 124 | bidirectional = rnn_param["bidirectional"] 125 | batch_norm = rnn_param["batch_norm"] 126 | rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, 127 | bidirectional=bidirectional, dropout=drop_out, batch_norm=False) 128 | rnns.append(('0', rnn)) 129 | for i in range(rnn_layers-1): 130 | rnn = BatchRNN(input_size=self.num_directions*rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, 131 | bidirectional=bidirectional, dropout=drop_out, batch_norm=batch_norm) 132 | rnns.append(('%d' % (i+1), rnn)) 133 
| self.rnns = nn.Sequential(OrderedDict(rnns)) 134 | 135 | if batch_norm: 136 | self.fc = nn.Sequential(nn.BatchNorm1d(self.num_directions*rnn_hidden_size), 137 | nn.Linear(self.num_directions*rnn_hidden_size, num_class, bias=False),) 138 | else: 139 | self.fc = nn.Linear(self.num_directions*rnn_hidden_size, num_class, bias=False) 140 | self.log_softmax = nn.LogSoftmax(dim=-1) 141 | 142 | def forward(self, x, visualize=False): 143 | #x: batch_size * 1 * max_seq_length * feat_size 144 | if visualize: 145 | visual = [x] 146 | 147 | if self.add_cnn: 148 | x = self.conv(x.unsqueeze(1)) 149 | 150 | if visualize: 151 | visual.append(x) 152 | 153 | x = x.transpose(1, 2).contiguous() 154 | sizes = x.size() 155 | if len(sizes) > 3: 156 | x = x.view(sizes[0], sizes[1], sizes[2]*sizes[3]) 157 | 158 | x = x.transpose(0,1).contiguous() 159 | 160 | if visualize: 161 | visual.append(x) 162 | 163 | x = self.rnns(x) 164 | seq_len, batch, _ = x.size() 165 | x = x.view(seq_len*batch, -1) 166 | x = self.fc(x) 167 | x = x.view(seq_len, batch, -1) 168 | out = self.log_softmax(x) 169 | 170 | if visualize: 171 | visual.append(out) 172 | return out, visual 173 | return out 174 | else: 175 | x = x.transpose(0, 1) 176 | x = self.rnns(x) 177 | seq_len, batch, _ = x.size() 178 | x = x.view(seq_len*batch, -1) 179 | x = self.fc(x) 180 | x = x.view(seq_len, batch, -1) 181 | out = self.log_softmax(x) 182 | if visualize: 183 | visual.append(out) 184 | return out, visual 185 | return out 186 | 187 | def compute_wer(self, index, input_sizes, targets, target_sizes): 188 | batch_errs = 0 189 | batch_tokens = 0 190 | for i in range(len(index)): 191 | label = targets[i][:target_sizes[i]] 192 | pred = [] 193 | for j in range(len(index[i][:input_sizes[i]])): 194 | if index[i][j] == 0: 195 | continue 196 | if j == 0: 197 | pred.append(index[i][j]) 198 | if j > 0 and index[i][j] != index[i][j-1]: 199 | pred.append(index[i][j]) 200 | batch_errs += ed.eval(label, pred) 201 | batch_tokens += len(label) 202 | return batch_errs, batch_tokens 203 | 204 | def add_weights_noise(self): 205 | for param in self.parameters(): 206 | weight_noise = param.data.new(param.size()).normal_(0, 0.075).type_as(param.type()) 207 | param = torch.nn.parameter.Parameter(param.data + weight_noise) 208 | 209 | @staticmethod 210 | def save_package(model, optimizer=None, decoder=None, epoch=None, loss_results=None, dev_loss_results=None, dev_cer_results=None): 211 | package = { 212 | 'rnn_param': model.rnn_param, 213 | 'add_cnn': model.add_cnn, 214 | 'cnn_param': model.cnn_param, 215 | 'num_class': model.num_class, 216 | '_drop_out': model.drop_out, 217 | 'state_dict': model.state_dict() 218 | } 219 | if optimizer is not None: 220 | package['optim_dict'] = optimizer.state_dict() 221 | if decoder is not None: 222 | package['decoder'] = decoder 223 | if epoch is not None: 224 | package['epoch'] = epoch 225 | if loss_results is not None: 226 | package['loss_results'] = loss_results 227 | package['dev_loss_results'] = dev_loss_results 228 | package['dev_cer_results'] = dev_cer_results 229 | return package 230 | 231 | if __name__ == '__main__': 232 | model = CTC_Model(add_cnn=True, cnn_param={"batch_norm":True, "activativate_function":nn.ReLU, "layer":[[(1,32), (3,41), (1,2), (0,0), None], 233 | [(32,32), (3,21), (2,2), (0,0), None]]}, num_class=48, drop_out=0) 234 | for idx, m in CTC_Model.modules(): 235 | print(idx, m) 236 | 237 | -------------------------------------------------------------------------------- /timit/path.sh: 
-------------------------------------------------------------------------------- 1 | KALDI_ROOT=../../kaldi 2 | 3 | . $KALDI_ROOT/tools/config/common_path.sh 4 | export LC_ALL=C 5 | -------------------------------------------------------------------------------- /timit/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Author: Ruchao Fan 4 | #2017.11.1 Training acoustic model and decode with phoneme-level bigram 5 | #2018.4.30 Replace the h5py with ark and simplify the data_loader.py 6 | #2019.12.20 Update to pytorch1.2 and python3.7 7 | 8 | . path.sh 9 | 10 | stage=0 11 | 12 | timit_dir='' 13 | phoneme_map='60-39' 14 | feat_dir='data' #dir to save feature 15 | feat_type='fbank' #fbank, mfcc, spectrogram 16 | config_file='conf/ctc_config.yaml' 17 | 18 | if [ ! -z $1 ]; then 19 | stage=$1 20 | fi 21 | 22 | if [ $stage -le 0 ]; then 23 | echo "Step 0: Data Preparation ..." 24 | local/timit_data_prep.sh $timit_dir $phoneme_map || exit 1; 25 | python3 steps/get_model_units.py $feat_dir/train/phn_text 26 | fi 27 | 28 | if [ $stage -le 1 ]; then 29 | echo "Step 1: Feature Extraction..." 30 | steps/make_feat.sh $feat_type $feat_dir || exit 1; 31 | fi 32 | 33 | if [ $stage -le 2 ]; then 34 | echo "Step 2: Acoustic Model(CTC) Training..." 35 | CUDA_VISIBLE_DEVICE='0' python3 steps/train_ctc.py --conf $config_file || exit 1; 36 | fi 37 | 38 | if [ $stage -le 3 ]; then 39 | echo "Step 3: LM Model Training..." 40 | steps/train_lm.sh $feat_dir || exit 1; 41 | fi 42 | 43 | if [ $stage -le 4 ]; then 44 | echo "Step 4: Decoding..." 45 | CUDA_VISIBLE_DEVICE='0' python3 steps/test_ctc.py --conf $config_file || exit 1; 46 | fi 47 | 48 | -------------------------------------------------------------------------------- /timit/steps/get_model_units.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | if len(sys.argv) != 2: 5 | print("We need training text to generate the modelling units.") 6 | sys.exit(1) 7 | 8 | train_text = sys.argv[1] 9 | units_file = 'data/units' 10 | 11 | units = {} 12 | with open(train_text, 'r') as fin: 13 | line = fin.readline() 14 | while line: 15 | line = line.strip().split(' ') 16 | for char in line[1:]: 17 | try: 18 | if units[char] == True: 19 | continue 20 | except: 21 | units[char] = True 22 | line = fin.readline() 23 | 24 | fwriter = open(units_file, 'w') 25 | for char in units: 26 | print(char, file=fwriter) 27 | 28 | 29 | -------------------------------------------------------------------------------- /timit/steps/make_feat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #The script is to make fbank, mfcc and spectrogram from kaldi 4 | 5 | feat_type=$1 6 | data_dir=$2 7 | conf_dir=conf 8 | compress=false 9 | 10 | if [ "$feat_type" != "fbank" || "$feat_type" != "mfcc" || "$feat_type" != "spectrogram" ]; then 11 | echo "Feature type $feat_type does not support!" 12 | exit 1; 13 | else 14 | echo ============================================================================ 15 | echo " $feat_type Feature Extration and CMVN " 16 | echo ============================================================================ 17 | 18 | feat_config=$conf_dir/$feat_type.conf 19 | if [ ! -f $feat_config ]; then 20 | echo "missing file $feat_config!" 
21 | exit 1; 22 | fi 23 | 24 | x=train 25 | compute-$feat_type-feats --config=$feat_config scp,p:$data_dir/$x/wav_sph.scp \ 26 | ark,scp:$data_dir/$x/raw_$feat_type.ark,$data_dir/$x/raw_$feat_type.scp 27 | #compute mean and variance with all training samples 28 | compute-cmvn-stats --binary=false scp:$data_dir/$x/raw_$feat_type.scp $data_dir/global_${feat_type}_cmvn.txt 29 | #apply cmvn for training set 30 | apply-cmvn --norm-vars=true $data_dir/global_${feat_type}_cmvn.txt scp:$data_dir/$x/raw_$feat_type.scp ark:- |\ 31 | copy-feats --compress=$compress ark:- ark,scp:$data_dir/$x/$feat_type.ark,$data_dir/$x/$feat_type.scp 32 | rm -f $data_dir/$x/raw_$feat_type.ark $data_dir/$x/raw_$feat_type.scp 33 | 34 | for x in dev test; do 35 | compute-$feat_type-feats --config=$feat_config scp,p:$data_dir/$x/wav_sph.scp ark:- | \ 36 | apply-cmvn --norm-vars=true $data_dir/global_${feat_type}_cmvn.txt ark:- ark:- |\ 37 | copy-feats --compress=$compress ark:- ark,scp:$data_dir/$x/$feat_type.ark,$data_dir/$x/$feat_type.scp 38 | done 39 | fi 40 | 41 | echo "Finished successfully on" `date` 42 | exit 0 43 | -------------------------------------------------------------------------------- /timit/steps/test_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import os 5 | import time 6 | import sys 7 | import torch 8 | import yaml 9 | import argparse 10 | import torch.nn as nn 11 | 12 | sys.path.append('./') 13 | from models.model_ctc import * 14 | from utils.ctcDecoder import GreedyDecoder, BeamDecoder 15 | from utils.data_loader import Vocab, SpeechDataset, SpeechDataLoader 16 | from steps.train_ctc import Config 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--conf', help='conf file for training') 20 | 21 | def test(): 22 | args = parser.parse_args() 23 | try: 24 | conf = yaml.safe_load(open(args.conf,'r')) 25 | except: 26 | print("Config file not exist!") 27 | sys.exit(1) 28 | 29 | opts = Config() 30 | for k,v in conf.items(): 31 | setattr(opts, k, v) 32 | print('{:50}:{}'.format(k, v)) 33 | 34 | use_cuda = opts.use_gpu 35 | device = torch.device('cuda') if use_cuda else torch.device('cpu') 36 | 37 | model_path = os.path.join(opts.checkpoint_dir, opts.exp_name, 'ctc_best_model.pkl') 38 | package = torch.load(model_path) 39 | 40 | rnn_param = package["rnn_param"] 41 | add_cnn = package["add_cnn"] 42 | cnn_param = package["cnn_param"] 43 | num_class = package["num_class"] 44 | feature_type = package['epoch']['feature_type'] 45 | n_feats = package['epoch']['n_feats'] 46 | drop_out = package['_drop_out'] 47 | mel = opts.mel 48 | 49 | beam_width = opts.beam_width 50 | lm_alpha = opts.lm_alpha 51 | decoder_type = opts.decode_type 52 | vocab_file = opts.vocab_file 53 | 54 | vocab = Vocab(vocab_file) 55 | test_dataset = SpeechDataset(vocab, opts.test_scp_path, opts.test_lab_path, opts) 56 | test_loader = SpeechDataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers, pin_memory=False) 57 | 58 | model = CTC_Model(rnn_param=rnn_param, add_cnn=add_cnn, cnn_param=cnn_param, num_class=num_class, drop_out=drop_out) 59 | model.to(device) 60 | model.load_state_dict(package['state_dict']) 61 | model.eval() 62 | 63 | 64 | if decoder_type == 'Greedy': 65 | decoder = GreedyDecoder(vocab.index2word, space_idx=-1, blank_index=0) 66 | else: 67 | decoder = BeamDecoder(vocab.index2word, beam_width=beam_width, blank_index=0, space_idx=-1, lm_path=opts.lm_path, 
lm_alpha=opts.lm_alpha) 68 | 69 | total_wer = 0 70 | total_cer = 0 71 | start = time.time() 72 | with torch.no_grad(): 73 | for data in test_loader: 74 | inputs, input_sizes, targets, target_sizes, utt_list = data 75 | inputs = inputs.to(device) 76 | #rnput_sizes = input_sizes.to(device) 77 | #target = target.to(device) 78 | #target_sizes = target_sizes.to(device) 79 | 80 | probs = model(inputs) 81 | 82 | max_length = probs.size(0) 83 | input_sizes = (input_sizes * max_length).long() 84 | 85 | probs = probs.cpu() 86 | decoded = decoder.decode(probs, input_sizes.numpy().tolist()) 87 | 88 | targets, target_sizes = targets.numpy(), target_sizes.numpy() 89 | labels = [] 90 | for i in range(len(targets)): 91 | label = [ vocab.index2word[num] for num in targets[i][:target_sizes[i]]] 92 | labels.append(' '.join(label)) 93 | 94 | for x in range(len(targets)): 95 | print("origin : " + labels[x]) 96 | print("decoded: " + decoded[x]) 97 | cer = 0 98 | wer = 0 99 | for x in range(len(labels)): 100 | cer += decoder.cer(decoded[x], labels[x]) 101 | wer += decoder.wer(decoded[x], labels[x]) 102 | decoder.num_word += len(labels[x].split()) 103 | decoder.num_char += len(labels[x]) 104 | total_cer += cer 105 | total_wer += wer 106 | CER = (float(total_cer) / decoder.num_char)*100 107 | WER = (float(total_wer) / decoder.num_word)*100 108 | print("Character error rate on test set: %.4f" % CER) 109 | print("Word error rate on test set: %.4f" % WER) 110 | end = time.time() 111 | time_used = (end - start) / 60.0 112 | print("time used for decode %d sentences: %.4f minutes." % (len(test_dataset), time_used)) 113 | 114 | if __name__ == "__main__": 115 | test() 116 | -------------------------------------------------------------------------------- /timit/steps/train_ctc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import os 5 | import sys 6 | import copy 7 | import time 8 | import yaml 9 | import argparse 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | sys.path.append('./') 16 | from models.model_ctc import * 17 | #from warpctc_pytorch import CTCLoss # use built-in nn.CTCLoss 18 | from utils.data_loader import Vocab, SpeechDataset, SpeechDataLoader 19 | 20 | supported_rnn = {'nn.LSTM':nn.LSTM, 'nn.GRU': nn.GRU, 'nn.RNN':nn.RNN} 21 | supported_activate = {'relu':nn.ReLU, 'tanh':nn.Tanh, 'sigmoid':nn.Sigmoid} 22 | 23 | parser = argparse.ArgumentParser(description='cnn_lstm_ctc') 24 | parser.add_argument('--conf', default='conf/ctc_config.yaml' , help='conf file with argument of LSTM and training') 25 | 26 | def run_epoch(epoch_id, model, data_iter, loss_fn, device, optimizer=None, print_every=20, is_training=True): 27 | if is_training: 28 | model.train() 29 | else: 30 | model.eval() 31 | 32 | total_loss = 0 33 | total_tokens = 0 34 | total_errs = 0 35 | cur_loss = 0 36 | 37 | for i, data in enumerate(data_iter): 38 | inputs, input_sizes, targets, target_sizes, utt_list = data 39 | inputs = inputs.to(device) 40 | input_sizes = input_sizes.to(device) 41 | targets = targets.to(device) 42 | target_sizes = target_sizes.to(device) 43 | 44 | out = model(inputs) 45 | out_len, batch_size, _ = out.size() 46 | input_sizes = (input_sizes * out_len).long() 47 | loss = loss_fn(out, targets, input_sizes, target_sizes) 48 | loss /= batch_size 49 | cur_loss += loss.item() 50 | total_loss += loss.item() 51 | prob, index = torch.max(out, dim=-1) 52 | batch_errs, batch_tokens = 
model.compute_wer(index.transpose(0,1).cpu().numpy(), input_sizes.cpu().numpy(), targets.cpu().numpy(), target_sizes.cpu().numpy()) 53 | total_errs += batch_errs 54 | total_tokens += batch_tokens 55 | 56 | if (i + 1) % print_every == 0 and is_training: 57 | print('Epoch = %d, step = %d, cur_loss = %.4f, total_loss = %.4f, total_wer = %.4f' % (epoch_id, 58 | i+1, cur_loss / print_every, total_loss / (i+1), total_errs / total_tokens )) 59 | cur_loss = 0 60 | 61 | if is_training: 62 | optimizer.zero_grad() 63 | loss.backward() 64 | #nn.utils.clip_grad_norm_(model.parameters(), 400) 65 | optimizer.step() 66 | average_loss = total_loss / (i+1) 67 | training = "Train" if is_training else "Valid" 68 | print("Epoch %d %s done, total_loss: %.4f, total_wer: %.4f" % (epoch_id, training, average_loss, total_errs / total_tokens)) 69 | return 1-total_errs / total_tokens, average_loss 70 | 71 | class Config(object): 72 | batch_size = 4 73 | dropout = 0.1 74 | 75 | def main(conf): 76 | opts = Config() 77 | for k, v in conf.items(): 78 | setattr(opts, k, v) 79 | print('{:50}:{}'.format(k, v)) 80 | 81 | device = torch.device('cuda') if opts.use_gpu else torch.device('cpu') 82 | torch.manual_seed(opts.seed) 83 | np.random.seed(opts.seed) 84 | if opts.use_gpu: 85 | torch.cuda.manual_seed(opts.seed) 86 | 87 | #Data Loader 88 | vocab = Vocab(opts.vocab_file) 89 | train_dataset = SpeechDataset(vocab, opts.train_scp_path, opts.train_lab_path, opts) 90 | dev_dataset = SpeechDataset(vocab, opts.valid_scp_path, opts.valid_lab_path, opts) 91 | train_loader = SpeechDataLoader(train_dataset, batch_size=opts.batch_size, shuffle=opts.shuffle_train, num_workers=opts.num_workers) 92 | dev_loader = SpeechDataLoader(dev_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers) 93 | 94 | #Define Model 95 | rnn_type = supported_rnn[opts.rnn_type] 96 | rnn_param = {"rnn_input_size":opts.rnn_input_size, "rnn_hidden_size":opts.rnn_hidden_size, "rnn_layers":opts.rnn_layers, 97 | "rnn_type":rnn_type, "bidirectional":opts.bidirectional, "batch_norm":opts.batch_norm} 98 | 99 | num_class = vocab.n_words 100 | opts.output_class_dim = vocab.n_words 101 | drop_out = opts.drop_out 102 | add_cnn = opts.add_cnn 103 | 104 | cnn_param = {} 105 | channel = eval(opts.channel) 106 | kernel_size = eval(opts.kernel_size) 107 | stride = eval(opts.stride) 108 | padding = eval(opts.padding) 109 | pooling = eval(opts.pooling) 110 | activation_function = supported_activate[opts.activation_function] 111 | cnn_param['batch_norm'] = opts.batch_norm 112 | cnn_param['activate_function'] = activation_function 113 | cnn_param["layer"] = [] 114 | for layer in range(opts.layers): 115 | layer_param = [channel[layer], kernel_size[layer], stride[layer], padding[layer]] 116 | if pooling is not None: 117 | layer_param.append(pooling[layer]) 118 | else: 119 | layer_param.append(None) 120 | cnn_param["layer"].append(layer_param) 121 | 122 | model = CTC_Model(add_cnn=add_cnn, cnn_param=cnn_param, rnn_param=rnn_param, num_class=num_class, drop_out=drop_out) 123 | model = model.to(device) 124 | num_params = 0 125 | for name, param in model.named_parameters(): 126 | num_params += param.numel() 127 | print("Number of parameters %d" % num_params) 128 | for idx, m in enumerate(model.children()): 129 | print(idx, m) 130 | 131 | #Training 132 | init_lr = opts.init_lr 133 | num_epoches = opts.num_epoches 134 | end_adjust_acc = opts.end_adjust_acc 135 | decay = opts.lr_decay 136 | weight_decay = opts.weight_decay 137 | batch_size = opts.batch_size 138 | 
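    # Learning-rate schedule used further below: when the dev loss stays within
    # end_adjust_acc of the best value for 10 epochs (adjust_rate_count), the lr
    # is multiplied by lr_decay; training stops after 8 such adjustments
    # (adjust_time), as described in the top-level README.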
139 | params = { 'num_epoches':num_epoches, 'end_adjust_acc':end_adjust_acc, 'mel': opts.mel, 'seed':opts.seed, 140 | 'decay':decay, 'learning_rate':init_lr, 'weight_decay':weight_decay, 'batch_size':batch_size, 141 | 'feature_type':opts.feature_type, 'n_feats': opts.feature_dim } 142 | print(params) 143 | 144 | loss_fn = nn.CTCLoss(reduction='sum') 145 | optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=weight_decay) 146 | 147 | #visualization for training 148 | from visdom import Visdom 149 | viz = Visdom() 150 | if add_cnn: 151 | title = opts.feature_type + str(opts.feature_dim) + ' CNN_LSTM_CTC' 152 | else: 153 | title = opts.feature_type + str(opts.feature_dim) + ' LSTM_CTC' 154 | 155 | viz_opts = [dict(title=title+" Loss", ylabel = 'Loss', xlabel = 'Epoch'), 156 | dict(title=title+" Loss on Dev", ylabel = 'DEV Loss', xlabel = 'Epoch'), 157 | dict(title=title+' CER on DEV', ylabel = 'DEV CER', xlabel = 'Epoch')] 158 | viz_window = [None, None, None] 159 | 160 | count = 0 161 | learning_rate = init_lr 162 | loss_best = 1000 163 | loss_best_true = 1000 164 | adjust_rate_flag = False 165 | stop_train = False 166 | adjust_time = 0 167 | acc_best = 0 168 | start_time = time.time() 169 | loss_results = [] 170 | dev_loss_results = [] 171 | dev_cer_results = [] 172 | 173 | while not stop_train: 174 | if count >= num_epoches: 175 | break 176 | count += 1 177 | 178 | if adjust_rate_flag: 179 | learning_rate *= decay 180 | adjust_rate_flag = False 181 | for param in optimizer.param_groups: 182 | param['lr'] *= decay 183 | 184 | print("Start training epoch: %d, learning_rate: %.5f" % (count, learning_rate)) 185 | 186 | train_acc, loss = run_epoch(count, model, train_loader, loss_fn, device, optimizer=optimizer, print_every=opts.verbose_step, is_training=True) 187 | loss_results.append(loss) 188 | acc, dev_loss = run_epoch(count, model, dev_loader, loss_fn, device, optimizer=None, print_every=opts.verbose_step, is_training=False) 189 | print("loss on dev set is %.4f" % dev_loss) 190 | dev_loss_results.append(dev_loss) 191 | dev_cer_results.append(acc) 192 | 193 | #adjust learning rate by dev_loss 194 | if dev_loss < (loss_best - end_adjust_acc): 195 | loss_best = dev_loss 196 | loss_best_true = dev_loss 197 | adjust_rate_count = 0 198 | model_state = copy.deepcopy(model.state_dict()) 199 | op_state = copy.deepcopy(optimizer.state_dict()) 200 | elif (dev_loss < loss_best + end_adjust_acc): 201 | adjust_rate_count += 1 202 | if dev_loss < loss_best and dev_loss < loss_best_true: 203 | loss_best_true = dev_loss 204 | model_state = copy.deepcopy(model.state_dict()) 205 | op_state = copy.deepcopy(optimizer.state_dict()) 206 | else: 207 | adjust_rate_count = 10 208 | 209 | if acc > acc_best: 210 | acc_best = acc 211 | best_model_state = copy.deepcopy(model.state_dict()) 212 | best_op_state = copy.deepcopy(optimizer.state_dict()) 213 | 214 | print("adjust_rate_count:"+str(adjust_rate_count)) 215 | print('adjust_time:'+str(adjust_time)) 216 | 217 | if adjust_rate_count == 10: 218 | adjust_rate_flag = True 219 | adjust_time += 1 220 | adjust_rate_count = 0 221 | if loss_best > loss_best_true: 222 | loss_best = loss_best_true 223 | model.load_state_dict(model_state) 224 | optimizer.load_state_dict(op_state) 225 | 226 | if adjust_time == 8: 227 | stop_train = True 228 | 229 | time_used = (time.time() - start_time) / 60 230 | print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" % (count, acc, time_used)) 231 | 232 | x_axis = range(count) 233 | y_axis = 
[loss_results[0:count], dev_loss_results[0:count], dev_cer_results[0:count]]
234 |         for x in range(len(viz_window)):
235 |             if viz_window[x] is None:
236 |                 viz_window[x] = viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), opts = viz_opts[x],)
237 |             else:
238 |                 viz.line(X = np.array(x_axis), Y = np.array(y_axis[x]), win = viz_window[x], update = 'replace',)
239 | 
240 |     print("End training, best dev loss is: %.4f, acc is: %.4f" % (loss_best, acc_best))
241 |     model.load_state_dict(best_model_state)
242 |     optimizer.load_state_dict(best_op_state)
243 |     save_dir = os.path.join(opts.checkpoint_dir, opts.exp_name)
244 |     if not os.path.exists(save_dir):
245 |         os.makedirs(save_dir)
246 |     best_path = os.path.join(save_dir, 'ctc_best_model.pkl')
247 |     params['epoch'] = count
248 | 
249 |     torch.save(CTC_Model.save_package(model, optimizer=optimizer, epoch=params, loss_results=loss_results, dev_loss_results=dev_loss_results, dev_cer_results=dev_cer_results), best_path)
250 | 
251 | if __name__ == '__main__':
252 |     args = parser.parse_args()
253 |     try:
254 |         config_path = args.conf
255 |         conf = yaml.safe_load(open(config_path, 'r'))
256 |     except Exception:
257 |         print("Config file missing or could not be parsed, please check.")
258 |         sys.exit(1)
259 |     main(conf)
260 | 
--------------------------------------------------------------------------------
/timit/steps/train_lm.sh:
--------------------------------------------------------------------------------
1 | # Train the phone-level bigram LM with IRSTLM from kaldi/tools
2 | 
3 | . path.sh
4 | export IRSTLM=$KALDI_ROOT/tools/irstlm/
5 | export PATH=${PATH}:$IRSTLM/bin
6 | 
7 | srcdir=$1
8 | 
9 | if ! command -v prune-lm >/dev/null 2>&1 ; then
10 |   echo "$0: Error: the IRSTLM is not available or compiled" >&2
11 |   echo "$0: Error: We used to install it by default, but" >&2
12 |   echo "$0: Error: this is no longer the case." >&2
13 |   echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2
14 |   echo "$0: Error: and run extras/install_irstlm.sh" >&2
15 |   exit 1
16 | fi
17 | 
18 | cut -d' ' -f2- $srcdir/train/phn_text | sed -e 's:^:<s> :' -e 's:$: </s>:' \
19 |   > $srcdir/lm_train.text
20 | 
21 | build-lm.sh -i $srcdir/lm_train.text -n 2 -o lm_phone_bg.ilm.gz
22 | 
23 | compile-lm lm_phone_bg.ilm.gz -t=yes /dev/stdout > $srcdir/lm_phone_bg.arpa
24 | 
25 | rm -f lm_phone_bg.ilm.gz
26 | 
27 | 
28 | 
29 | 
--------------------------------------------------------------------------------
/timit/steps/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #encoding=utf-8
3 | 
4 | # old version, needs updating
5 | from data_loader import myDataset
6 | from data_loader import myDataLoader, myCNNDataLoader
7 | from model import *
8 | from ctcDecoder import GreedyDecoder, BeamDecoder
9 | import torch
10 | import torch.nn as nn
11 | from torch.autograd import Variable
12 | import time
13 | import visdom
14 | 
15 | 
16 | def test():
17 |     model_path = '../log/exp_cnn_lstm_ctc_spectrum201/exp_cnn3*41_3*21_4lstm_ctc_Melspectrum_stride_1_2/exp2_82.1483/best_model_cv80.8660423723.pkl'
18 |     package = torch.load(model_path)
19 |     data_dir = '../data_prepare/data'
20 |     rnn_param = package["rnn_param"]
21 |     add_cnn = package["add_cnn"]
22 |     cnn_param = package["cnn_param"]
23 |     num_class = package["num_class"]
24 |     feature_type = package['epoch']['feature_type']
25 |     n_feats = package['epoch']['n_feats']
26 |     out_type = package['epoch']['out_type']
27 |     drop_out = package['_drop_out']
28 |     try:
29 |         mel = package['epoch']['mel']
30 |     except KeyError:
31 |         mel = False
32 |     #weight_decay = package['epoch']['weight_decay']
33 |     #print(weight_decay)
34 | 
35 |     decoder_type = 'Greedy'
36 | 
37 |     test_dataset = myDataset(data_dir, data_set='train', feature_type=feature_type, out_type=out_type, n_feats=n_feats, mel=mel)
38 | 
39 |     model = CTC_Model(rnn_param=rnn_param, add_cnn=add_cnn, cnn_param=cnn_param, num_class=num_class, drop_out=drop_out)
40 | 
41 |     if add_cnn:
42 |         test_loader = myCNNDataLoader(test_dataset, batch_size=1, shuffle=False,
43 |                                 num_workers=4, pin_memory=False)
44 |     else:
45 |         test_loader = myDataLoader(test_dataset, batch_size=1, shuffle=False,
46 |                                 num_workers=4, pin_memory=False)
47 | 
48 |     model.load_state_dict(package['state_dict'])
49 |     model.eval()
50 | 
51 |     if USE_CUDA:
52 |         model = model.cuda()
53 | 
54 |     if decoder_type == 'Greedy':
55 |         decoder = GreedyDecoder(test_dataset.int2phone, space_idx=-1, blank_index=0)
56 |     else:
57 |         decoder = BeamDecoder(test_dataset.int2phone)
58 | 
59 |     import pickle
60 |     f = open('../decode_map_48-39/map_dict.pkl', 'rb')
61 |     map_dict = pickle.load(f)
62 |     f.close()
63 |     print(map_dict)
64 | 
65 |     vis = visdom.Visdom(env='fan')
66 |     legend = []
67 |     for i in range(49):
68 |         legend.append(test_dataset.int2phone[i])
69 | 
70 |     for data in test_loader:
71 |         inputs, target, input_sizes, input_size_list, target_sizes = data
72 |         if not add_cnn:
73 |             inputs = inputs.transpose(0,1)
74 | 
75 |         inputs = Variable(inputs, volatile=True, requires_grad=False)
76 |         if USE_CUDA:
77 |             inputs = inputs.cuda()
78 | 
79 |         if not add_cnn:
80 |             inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_size_list)
81 | 
82 |         probs, visual = model(inputs, visualize=True)
83 |         probs = probs.data.cpu()
84 | 
85 |         if add_cnn:
86 |             max_length = probs.size(0)
87 |             input_size_list = [int(x*max_length) for x in input_size_list]
88 | 
89 |         decoded = decoder.decode(probs, 
input_size_list) 90 | targets = decoder._unflatten_targets(target, target_sizes) 91 | labels = decoder._process_strings(decoder._convert_to_strings(targets)) 92 | 93 | for x in range(len(labels)): 94 | label = labels[x].strip().split(' ') 95 | for i in range(len(label)): 96 | label[i] = map_dict[label[i]] 97 | labels[x] = ' '.join(label) 98 | decode = decoded[x].strip().split(' ') 99 | for i in range(len(decode)): 100 | decode[i] = map_dict[decode[i]] 101 | decoded[x] = ' '.join(decode) 102 | 103 | for x in range(len(labels)): 104 | print("origin: "+ labels[x]) 105 | print("decoded: "+ decoded[x]) 106 | 107 | if add_cnn: 108 | spectrum_inputs = visual[0][0][0].transpose(0, 1).data.cpu() 109 | opts = dict(title=labels[0], xlabel="frame", ylabel='spectrum') 110 | vis.heatmap(spectrum_inputs, opts = opts) 111 | 112 | opts = dict(title=labels[0], xlabel="frame", ylabel='feature_after_cnn') 113 | after_cnn = visual[1][0][0].transpose(0, 1).data.cpu() 114 | vis.heatmap(after_cnn, opts = opts) 115 | 116 | opts = dict(title=labels[0], xlabel="frame", ylabel='feature_before_rnn') 117 | before_rnn = visual[2].transpose(0, 1)[0].transpose(0, 1).data.cpu() 118 | vis.heatmap(before_rnn, opts=opts) 119 | 120 | show_prob = visual[3].transpose(0, 1)[0].data.cpu() 121 | line_opts = dict(title=decoded[0], xlabel="frame", ylabel="probability", legend=legend) 122 | x = show_prob.size()[0] 123 | vis.line(show_prob.numpy(), X=np.array(range(x)), opts=line_opts) 124 | else: 125 | spectrum_inputs = visual[0][0][0].transpose(0, 1).data.cpu() 126 | opts = dict(title=labels[0], xlabel="frame", ylabel='spectrum') 127 | vis.heatmap(spectrum_inputs, opts = opts) 128 | 129 | show_prob = visual[1].transpose(0, 1)[0].data.cpu() 130 | line_opts = dict(title=decoded[0], xlabel="frame", ylabel="probability", legend=legend) 131 | x = show_prob.size()[0] 132 | vis.line(show_prob.numpy(), X=np.array(range(x)), opts=line_opts) 133 | break 134 | 135 | if __name__ == "__main__": 136 | test() 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /timit/steps/visualize.py.old: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | from data_loader import myDataset 5 | from data_loader import myDataLoader, myCNNDataLoader 6 | from model import * 7 | from ctcDecoder import GreedyDecoder, BeamDecoder 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | import time 12 | import visdom 13 | 14 | 15 | def test(): 16 | model_path = '../log/exp_cnn_lstm_ctc_spectrum201/exp_cnn3*41_3*21_4lstm_ctc_Melspectrum_stride_1_2/exp2_82.1483/best_model_cv80.8660423723.pkl' 17 | package = torch.load(model_path) 18 | data_dir = '../data_prepare/data' 19 | input_size = package['input_size'] 20 | layers = package['rnn_layers'] 21 | hidden_size = package['hidden_size'] 22 | rnn_type = package['rnn_type'] 23 | num_class = package["num_class"] 24 | feature_type = package['epoch']['feature_type'] 25 | n_feats = package['epoch']['n_feats'] 26 | out_type = package['epoch']['out_type'] 27 | model_type = package['name'] 28 | drop_out = package['_drop_out'] 29 | try: 30 | mel = package['epoch']['mel'] 31 | except: 32 | mel = False 33 | #weight_decay = package['epoch']['weight_decay'] 34 | #print(weight_decay) 35 | 36 | decoder_type = 'Greedy' 37 | 38 | test_dataset = myDataset(data_dir, data_set='train', feature_type=feature_type, out_type=out_type, n_feats=n_feats, mel=mel) 39 | 40 | if model_type == 
'CNN_LSTM_CTC': 41 | model = CNN_LSTM_CTC(rnn_input_size=input_size, rnn_hidden_size=hidden_size, rnn_layers=layers, 42 | rnn_type=rnn_type, bidirectional=True, batch_norm=True, num_class=num_class, drop_out=drop_out) 43 | test_loader = myCNNDataLoader(test_dataset, batch_size=1, shuffle=False, 44 | num_workers=4, pin_memory=False) 45 | else: 46 | model = CTC_RNN(rnn_input_size=input_size, rnn_hidden_size=hidden_size, rnn_layers=layers, 47 | rnn_type=rnn_type, bidirectional=True, batch_norm=True, num_class=num_class, drop_out=drop_out) 48 | test_loader = myDataLoader(test_dataset, batch_size=8, shuffle=False, 49 | num_workers=4, pin_memory=False) 50 | 51 | model.load_state_dict(package['state_dict']) 52 | model.eval() 53 | 54 | if USE_CUDA: 55 | model = model.cuda() 56 | 57 | if decoder_type == 'Greedy': 58 | decoder = GreedyDecoder(test_dataset.int2phone, space_idx=-1, blank_index=0) 59 | else: 60 | decoder = BeamDecoder(test_dataset.int2phone, top_paths=3, beam_width=20, blank_index=0, space_idx=-1, 61 | lm_path=None, dict_path=None, 62 | trie_path=None, lm_alpha=10, lm_beta1=1, lm_beta2=1) 63 | import pickle 64 | f = open('../decode_map_48-39/map_dict.pkl', 'rb') 65 | map_dict = pickle.load(f) 66 | f.close() 67 | print(map_dict) 68 | 69 | vis = visdom.Visdom(env='fan') 70 | legend = [] 71 | for i in range(49): 72 | legend.append(test_dataset.int2phone[i]) 73 | 74 | for data in test_loader: 75 | inputs, target, input_sizes, input_size_list, target_sizes = data 76 | if model.name == 'CTC_RNN': 77 | inputs = inputs.transpose(0,1) 78 | 79 | inputs = Variable(inputs, volatile=True, requires_grad=False) 80 | if USE_CUDA: 81 | inputs = inputs.cuda() 82 | 83 | if model.name == 'CTC_RNN': 84 | inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_size_list) 85 | probs, visual = model(inputs, visualize=True) 86 | probs = probs.data.cpu() 87 | 88 | decoded = decoder.decode(probs, input_size_list) 89 | targets = decoder._unflatten_targets(target, target_sizes) 90 | labels = decoder._process_strings(decoder._convert_to_strings(targets)) 91 | 92 | for x in range(len(labels)): 93 | label = labels[x].strip().split(' ') 94 | for i in range(len(label)): 95 | label[i] = map_dict[label[i]] 96 | labels[x] = ' '.join(label) 97 | decode = decoded[x].strip().split(' ') 98 | for i in range(len(decode)): 99 | decode[i] = map_dict[decode[i]] 100 | decoded[x] = ' '.join(decode) 101 | 102 | for x in range(len(labels)): 103 | print("origin: "+ labels[x]) 104 | print("decoded: "+ decoded[x]) 105 | 106 | spectrum_inputs = visual[0][0][0].transpose(0, 1).data.cpu() 107 | opts = dict(title=labels[0], xlabel="frame", ylabel='spectrum') 108 | vis.heatmap(spectrum_inputs, opts = opts) 109 | 110 | opts = dict(title=labels[0], xlabel="frame", ylabel='feature_after_cnn1') 111 | after_cnn = visual[1][0][0].transpose(0, 1).data.cpu() 112 | vis.heatmap(after_cnn, opts = opts) 113 | 114 | opts = dict(title=labels[0], xlabel="frame", ylabel='feature_after_cnn2') 115 | after_cnn2 = visual[2][0][0].transpose(0, 1).data.cpu() 116 | vis.heatmap(after_cnn2, opts = opts) 117 | 118 | opts = dict(title=labels[0], xlabel="frame", ylabel='feature_before_rnn') 119 | before_rnn = visual[3].transpose(0, 1)[0].transpose(0, 1).data.cpu() 120 | vis.heatmap(before_rnn, opts=opts) 121 | 122 | show_prob = visual[4].transpose(0, 1)[0].data.cpu() 123 | line_opts = dict(title=decoded[0], xlabel="frame", ylabel="probability", legend=legend) 124 | x = show_prob.size()[0] 125 | vis.line(show_prob.numpy(), X=np.array(range(x)), opts=line_opts) 126 | 
break
127 | 
128 | if __name__ == "__main__":
129 |     test()
130 | 
131 | 
132 | 
133 | 
--------------------------------------------------------------------------------
/timit/utils/BeamSearch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #encoding=utf-8
3 | 
4 | import math
5 | 
6 | LOG_ZERO = -99999999.0
7 | LOG_ONE = 0.0
8 | 
9 | class BeamEntry:
10 |     "information about one single beam at a specific time-step"
11 |     def __init__(self):
12 |         self.prTotal=LOG_ZERO    # blank and non-blank
13 |         self.prNonBlank=LOG_ZERO # non-blank
14 |         self.prBlank=LOG_ZERO    # blank
15 |         self.y=()                # labelling at current time-step
16 | 
17 | 
18 | class BeamState:
19 |     "information about beams at a specific time-step"
20 |     def __init__(self):
21 |         self.entries={}
22 | 
23 |     def norm(self):
24 |         "length-normalise probabilities to avoid penalising long labellings"
25 |         for (k,v) in self.entries.items():
26 |             labellingLen=len(self.entries[k].y)
27 |             self.entries[k].prTotal=self.entries[k].prTotal*(1.0/(labellingLen if labellingLen else 1))
28 | 
29 |     def sort(self):
30 |         "return beams sorted by probability"
31 |         u=[v for (k,v) in self.entries.items()]
32 |         s=sorted(u, reverse=True, key=lambda x:x.prTotal)
33 |         return [x.y for x in s]
34 | 
35 | class ctcBeamSearch(object):
36 |     def __init__(self, classes, beam_width, lm, lm_alpha=0.01, blank_index=0):
37 |         self.classes = classes
38 |         self.beamWidth = beam_width
39 |         self.lm_alpha = lm_alpha
40 |         self.lm = lm
41 |         self.blank_index = blank_index
42 | 
43 |     def log_add_prob(self, log_x, log_y):
44 |         if log_x <= LOG_ZERO:
45 |             return log_y
46 |         if log_y <= LOG_ZERO:
47 |             return log_x
48 |         if (log_y - log_x) > 0.0:
49 |             log_y, log_x = log_x, log_y
50 |         return log_x + math.log(1 + math.exp(log_y - log_x))
51 | 
52 |     def calcExtPr(self, k, y, t, mat, beamState):
53 |         "probability for extending labelling y to y+k"
54 | 
55 |         # language model (char bigrams)
56 |         bigramProb=LOG_ONE
57 |         if self.lm:
58 |             c1=self.classes[y[-1]] if len(y) else ''
59 |             c2=self.classes[k]
60 |             bigramProb = self.lm.get_bi_prob(c1,c2) * self.lm_alpha
61 | 
62 |         # optical model (RNN)
63 |         if len(y) and y[-1]==k and mat[t-1, self.blank_index] < 0.9:
64 |             return math.log(mat[t, k]) + bigramProb + beamState.entries[y].prBlank
65 |         else:
66 |             return math.log(mat[t, k]) + bigramProb + beamState.entries[y].prTotal
67 | 
68 |     def addLabelling(self, beamState, y):
69 |         "adds labelling if it does not exist yet"
70 |         if y not in beamState.entries:
71 |             beamState.entries[y]=BeamEntry()
72 | 
73 |     def decode(self, inputs, inputs_list):
74 |         '''
75 |         inputs : FloatTensor, batch * timesteps * class
76 |         '''
77 |         batches, maxT, maxC = inputs.size()
78 |         res = []
79 | 
80 |         for batch in range(batches):
81 |             mat = inputs[batch].numpy()
82 |             # Initialise beam state
83 |             last=BeamState()
84 |             y=()
85 |             last.entries[y]=BeamEntry()
86 |             last.entries[y].prBlank=LOG_ONE
87 |             last.entries[y].prTotal=LOG_ONE
88 | 
89 |             # go over all time-steps
90 |             for t in range(inputs_list[batch]):
91 |                 curr=BeamState()
92 |                 # skip frames whose blank probability is very close to 1, to speed up decoding
93 |                 if (1 - mat[t, self.blank_index]) < 0.1:
94 |                     continue
95 |                 # keep only the beamWidth best labellings from the last time-step
96 |                 BHat=last.sort()[0:self.beamWidth]
97 |                 # go over best labellings
98 |                 for y in BHat:
99 |                     prNonBlank=LOG_ZERO
100 |                     # if nonempty labelling
101 |                     if len(y)>0:
102 |                         # the same labelling y can be kept by repeating its last label or by emitting a blank; with no previous label the non-blank prob stays zero
103 |                         prNonBlank=last.entries[y].prNonBlank + math.log(mat[t, y[-1]])
104 | 
105 |                     # calc probabilities
106 |                     prBlank = (last.entries[y].prTotal) + math.log(mat[t, self.blank_index])
107 |                     # save result
108 |                     self.addLabelling(curr, y)
109 |                     curr.entries[y].y=y
110 |                     curr.entries[y].prNonBlank = self.log_add_prob(curr.entries[y].prNonBlank, prNonBlank)
111 |                     curr.entries[y].prBlank = self.log_add_prob(curr.entries[y].prBlank, prBlank)
112 |                     prTotal = self.log_add_prob(prBlank, prNonBlank)
113 |                     curr.entries[y].prTotal = self.log_add_prob(curr.entries[y].prTotal, prTotal)
114 | 
115 |                     # extend y with each non-blank label k at time t; the new labelling's blank prob is zero, and if k equals the last label of y the previous frame must have been a blank, since CTC does not allow direct repeats
116 |                     for k in range(maxC):
117 |                         if k != self.blank_index:
118 |                             newY=y+(k,)
119 |                             prNonBlank=self.calcExtPr(k, y, t, mat, last)
120 | 
121 |                             # save result
122 |                             self.addLabelling(curr, newY)
123 |                             curr.entries[newY].y=newY
124 |                             curr.entries[newY].prNonBlank = self.log_add_prob(curr.entries[newY].prNonBlank, prNonBlank)
125 |                             curr.entries[newY].prTotal = self.log_add_prob(curr.entries[newY].prTotal, prNonBlank)
126 | 
127 |                 # set new beam state
128 |                 last=curr
129 | 
130 |             BHat=last.sort()[0:self.beamWidth]
131 |             # go over best labellings
132 |             curr = BeamState()
133 |             for y in BHat:
134 |                 newY = y
135 |                 c1 = self.classes[y[-1]]
136 |                 c2 = ""
137 |                 prNonBlank = last.entries[newY].prTotal + self.lm.get_bi_prob(c1, c2) * self.lm_alpha
138 |                 self.addLabelling(curr, newY)
139 |                 curr.entries[newY].y=newY
140 |                 curr.entries[newY].prNonBlank = self.log_add_prob(curr.entries[newY].prNonBlank, prNonBlank)
141 |                 curr.entries[newY].prTotal = self.log_add_prob(curr.entries[newY].prTotal, prNonBlank)
142 | 
143 |             last = curr
144 |             # normalise probabilities according to labelling length
145 |             last.norm()
146 | 
147 |             # sort by probability
148 |             bestLabelling=last.sort()[0] # get most probable labelling
149 | 
150 |             # map labels to chars
151 |             res_b = ' '.join([self.classes[l] for l in bestLabelling])
152 |             res.append(res_b)
153 |         return res
154 | 
155 | 
--------------------------------------------------------------------------------
/timit/utils/NgramLM.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #encoding=utf-8
3 | 
4 | # Get n-gram probabilities from an ARPA file
5 | 
6 | import re
7 | import math
8 | 
9 | n_grams = ["unigram", 'bigram', 'trigram', '4gram', '5gram']
10 | 
11 | class LanguageModel:
12 |     """
13 |     New version of LanguageModel which can read a text ARPA file, such as
14 |     one generated by KenLM or IRSTLM
15 |     """
16 |     def __init__(self, arpa_file=None, n_gram=2, start='<s>', end='</s>', unk='<unk>'):
17 |         "Load arpa file to get words and prob"
18 |         self.n_gram = n_gram
19 |         self.start = start
20 |         self.end = end
21 |         self.unk = unk
22 |         self.scale = math.log(10)  # ARPA stores base-10 log probabilities; convert them to natural log
23 |         self.initngrams(arpa_file)
24 | 
25 |     def initngrams(self, fn):
26 |         "internal init of word unigrams and bigrams"
27 |         self.unigram = {}
28 |         self.bigram = {}
29 |         if self.n_gram == 3:
30 |             self.trigram = {}
31 | 
32 |         # go through the text and create each bigram
33 |         f = open(fn, 'r')
34 |         recording = 0
35 |         for lines in f.readlines():
36 |             line = lines.strip('\n')
37 |             #a = re.match('gram', line)
38 |             if line == "\\1-grams:":
39 |                 recording = 1
40 |                 continue
41 |             if line == "\\2-grams:":
42 |                 recording = 2
43 |                 continue
44 |             if recording == 1:
45 |                 line = line.split('\t')
46 |                 if len(line) == 3:
47 |                     self.unigram[line[1]] = [self.scale * float(line[0]), self.scale * float(line[2])] #save the prob and backoff prob
48 |                 elif len(line) == 2:
49 |                     self.unigram[line[1]] = [self.scale * float(line[0]), 0.0]
50 |             if recording == 2:
51 |                 line = line.split('\t')
52 |                 if len(line) == 3:
53 |                     #print(line[1])
54 | self.bigram[line[1]] = [self.scale * float(line[0]), self.scale * float(line[2])] 55 | elif len(line) == 2: 56 | self.bigram[line[1]] = [self.scale * float(line[0]), 0.0] 57 | f.close() 58 | self.unigram['UNK'] = self.unigram[self.unk] 59 | 60 | 61 | def get_uni_prob(self, wid): 62 | "Returns unigram probabiliy of word" 63 | return self.unigram[wid][0] 64 | 65 | def get_bi_prob(self, w1, w2): 66 | ''' 67 | Return bigrams probability p(w2 | w1) 68 | if bigrame does not exist, use backoff prob 69 | ''' 70 | if w1 == '': 71 | w1 = self.start 72 | if w2 == '': 73 | w2 = self.end 74 | key = w1 + ' ' + w2 75 | if key not in self.bigram: 76 | return self.unigram[w1][1] + self.unigram[w2][0] 77 | else: 78 | return self.bigram[key][0] 79 | 80 | def score_bg(self, sentence): 81 | ''' 82 | Score a sentence using bigram, return P(sentence) 83 | ''' 84 | val = 0.0 85 | words = sentence.strip().split() 86 | val += self.get_bi_prob(self.start, words[0]) 87 | for i in range(len(words)-1): 88 | val += self.get_bi_prob(words[i], words[i+1]) 89 | val += self.get_bi_prob(words[-1], self.end) 90 | return val 91 | 92 | if __name__ == "__main__": 93 | lm = LanguageModel('./data_prepare/bigram.arpa') 94 | #print(lm.bigram['你 好']) 95 | print(lm.get_bi_prob('', 'sil')) 96 | #print(lm.score_bg("中国 呼吸")) 97 | 98 | -------------------------------------------------------------------------------- /timit/utils/ctcDecoder.py: -------------------------------------------------------------------------------- 1 | #/usr/bin/python 2 | #encoding=utf-8 3 | 4 | #greedy decoder and beamsearch decoder for ctc 5 | 6 | import torch 7 | import numpy as np 8 | 9 | class Decoder(object): 10 | "解码器基类定义,作用是将模型的输出转化为文本使其能够与标签计算正确率" 11 | def __init__(self, int2char, space_idx = 1, blank_index = 0): 12 | ''' 13 | int2char : 将类别转化为字符标签 14 | space_idx : 空格符号的索引,如果为为-1,表示空格不是一个类别 15 | blank_index : 空白类的索引,默认设置为0 16 | ''' 17 | self.int_to_char = int2char 18 | self.space_idx = space_idx 19 | self.blank_index = blank_index 20 | self.num_word = 0 21 | self.num_char = 0 22 | 23 | def decode(self): 24 | "解码函数,在GreedyDecoder和BeamDecoder继承类中实现" 25 | raise NotImplementedError; 26 | 27 | def phone_word_error(self, prob_tensor, frame_seq_len, targets, target_sizes): 28 | '''计算词错率和字符错误率 29 | Args: 30 | prob_tensor : 模型的输出 31 | frame_seq_len : 每个样本的帧长 32 | targets : 样本标签 33 | target_sizes : 每个样本标签的长度 34 | Returns: 35 | wer : 词错率,以space为间隔分开作为词 36 | cer : 字符错误率 37 | ''' 38 | strings = self.decode(prob_tensor, frame_seq_len) 39 | targets = self._unflatten_targets(targets, target_sizes) 40 | target_strings = self._process_strings(self._convert_to_strings(targets)) 41 | 42 | cer = 0 43 | wer = 0 44 | for x in range(len(target_strings)): 45 | cer += self.cer(strings[x], target_strings[x]) 46 | wer += self.wer(strings[x], target_strings[x]) 47 | self.num_word += len(target_strings[x].split()) 48 | self.num_char += len(target_strings[x]) 49 | return cer, wer 50 | 51 | def _unflatten_targets(self, targets, target_sizes): 52 | '''将标签按照每个样本的标签长度进行分割 53 | Args: 54 | targets : 数字表示的标签 55 | target_sizes : 每个样本标签的长度 56 | Returns: 57 | split_targets : 得到的分割后的标签 58 | ''' 59 | split_targets = [] 60 | offset = 0 61 | for size in target_sizes: 62 | split_targets.append(targets[offset : offset + size]) 63 | offset += size 64 | return split_targets 65 | 66 | def _process_strings(self, seqs, remove_rep = False): 67 | '''处理转化后的字符序列,包括去重复等,将list转化为string 68 | Args: 69 | seqs : 待处理序列 70 | remove_rep : 是否去重复 71 | Returns: 72 | processed_strings : 处理后的字符序列 73 | ''' 74 | 
processed_strings = [] 75 | for seq in seqs: 76 | string = self._process_string(seq, remove_rep) 77 | processed_strings.append(string) 78 | return processed_strings 79 | 80 | def _process_string(self, seq, remove_rep = False): 81 | string = '' 82 | for i, char in enumerate(seq): 83 | if char != self.int_to_char[self.blank_index]: 84 | if remove_rep and i != 0 and char == seq[i - 1]: #remove dumplicates 85 | pass 86 | elif self.space_idx == -1: 87 | string = string + ' '+ char 88 | elif char == self.int_to_char[self.space_idx]: 89 | string += ' ' 90 | else: 91 | string = string + char 92 | return string 93 | 94 | def _convert_to_strings(self, seq, sizes=None): 95 | '''将数字序列的输出转化为字符序列 96 | Args: 97 | seqs : 待转化序列 98 | sizes : 每个样本序列的长度 99 | Returns: 100 | strings : 转化后的字符序列 101 | ''' 102 | strings = [] 103 | for x in range(len(seq)): 104 | seq_len = sizes[x] if sizes is not None else len(seq[x]) 105 | string = self._convert_to_string(seq[x], seq_len) 106 | strings.append(string) 107 | return strings 108 | 109 | def _convert_to_string(self, seq, sizes): 110 | result = [] 111 | for i in range(sizes): 112 | result.append(self.int_to_char[seq[i]]) 113 | if self.space_idx == -1: 114 | return result 115 | else: 116 | return ''.join(result) 117 | 118 | def wer(self, s1, s2): 119 | "将空格作为分割计算词错误率" 120 | b = set(s1.split() + s2.split()) 121 | word2int = dict(zip(b, range(len(b)))) 122 | 123 | w1 = [word2int[w] for w in s1.split()] 124 | w2 = [word2int[w] for w in s2.split()] 125 | return self._edit_distance(w1, w2) 126 | 127 | def cer(self, s1, s2): 128 | "计算字符错误率" 129 | return self._edit_distance(s1, s2) 130 | 131 | def _edit_distance(self, src_seq, tgt_seq): 132 | "计算两个序列的编辑距离,用来计算字符错误率" 133 | L1, L2 = len(src_seq), len(tgt_seq) 134 | if L1 == 0: return L2 135 | if L2 == 0: return L1 136 | # construct matrix of size (L1 + 1, L2 + 1) 137 | dist = [[0] * (L2 + 1) for i in range(L1 + 1)] 138 | for i in range(1, L2 + 1): 139 | dist[0][i] = dist[0][i-1] + 1 140 | for i in range(1, L1 + 1): 141 | dist[i][0] = dist[i-1][0] + 1 142 | for i in range(1, L1 + 1): 143 | for j in range(1, L2 + 1): 144 | if src_seq[i - 1] == tgt_seq[j - 1]: 145 | cost = 0 146 | else: 147 | cost = 1 148 | dist[i][j] = min(dist[i][j-1] + 1, dist[i-1][j] + 1, dist[i-1][j-1] + cost) 149 | return dist[L1][L2] 150 | 151 | 152 | class GreedyDecoder(Decoder): 153 | "直接解码,把每一帧的输出概率最大的值作为输出值,而不是整个序列概率最大的值" 154 | def decode(self, prob_tensor, frame_seq_len): 155 | '''解码函数 156 | Args: 157 | prob_tensor : 网络模型输出 158 | frame_seq_len : 每一样本的帧数 159 | Returns: 160 | 解码得到的string,即识别结果 161 | ''' 162 | prob_tensor = prob_tensor.transpose(0,1) 163 | _, decoded = torch.max(prob_tensor, 2) 164 | decoded = decoded.view(decoded.size(0), decoded.size(1)) 165 | decoded = self._convert_to_strings(decoded.numpy(), frame_seq_len) 166 | return self._process_strings(decoded, remove_rep=True) 167 | 168 | class BeamDecoder(Decoder): 169 | "Beam search 解码。解码结果为整个序列概率的最大值" 170 | def __init__(self, int2char, beam_width = 200, blank_index = 0, space_idx = -1, lm_path=None, lm_alpha=0.01): 171 | self.beam_width = beam_width 172 | super(BeamDecoder, self).__init__(int2char, space_idx=space_idx, blank_index=blank_index) 173 | 174 | import sys 175 | sys.path.append('../') 176 | import utils.BeamSearch as uBeam 177 | import utils.NgramLM as uNgram 178 | lm = uNgram.LanguageModel(arpa_file=lm_path) 179 | self._decoder = uBeam.ctcBeamSearch(int2char, beam_width, lm, lm_alpha=lm_alpha, blank_index = blank_index) 180 | 181 | def decode(self, prob_tensor, frame_seq_len=None): 
182 | '''解码函数 183 | Args: 184 | prob_tensor : 网络模型输出 185 | frame_seq_len : 每一样本的帧数 186 | Returns: 187 | res : 解码得到的string,即识别结果 188 | ''' 189 | probs = prob_tensor.transpose(0, 1) 190 | probs = torch.exp(probs) 191 | res = self._decoder.decode(probs, frame_seq_len) 192 | return res 193 | 194 | 195 | if __name__ == '__main__': 196 | decoder = Decoder('abcde', 1, 2) 197 | print(decoder._convert_to_strings([[1,2,1,0,3],[1,2,1,1,1]])) 198 | 199 | -------------------------------------------------------------------------------- /timit/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | import torch 5 | import kaldiio 6 | import numpy as np 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | from utils.tools import load_wave, F_Mel, make_context, skip_feat 10 | 11 | audio_conf = {"sample_rate":16000, 'window_size':0.025, 'window_stride':0.01, 'window': 'hamming'} 12 | 13 | class Vocab(object): 14 | def __init__(self, vocab_file): 15 | self.vocab_file = vocab_file 16 | self.word2index = {"blank": 0, "UNK": 1} 17 | self.index2word = {0: "blank", 1: "UNK"} 18 | self.word2count = {} 19 | self.n_words = 2 20 | self.read_lang() 21 | 22 | def add_sentence(self, sentence): 23 | for word in sentence.split(' '): 24 | self.add_word(word) 25 | 26 | def add_word(self, word): 27 | if word not in self.word2index: 28 | self.word2index[word] = self.n_words 29 | self.word2count[word] = 1 30 | self.index2word[self.n_words] = word 31 | self.n_words += 1 32 | else: 33 | self.word2count[word] += 1 34 | 35 | def read_lang(self): 36 | print("Reading vocabulary from {}".format(self.vocab_file)) 37 | with open(self.vocab_file, 'r') as rf: 38 | line = rf.readline() 39 | while line: 40 | line = line.strip().split(' ') 41 | if len(line) > 1: 42 | sen = ' '.join(line[1:]) 43 | else: 44 | sen = line[0] 45 | self.add_sentence(sen) 46 | line = rf.readline() 47 | print("Vocabulary size is {}".format(self.n_words)) 48 | 49 | 50 | class SpeechDataset(Dataset): 51 | def __init__(self, vocab, scp_path, lab_path, opts): 52 | self.vocab = vocab 53 | self.scp_path = scp_path 54 | self.lab_path = lab_path 55 | self.left_ctx = opts.left_ctx 56 | self.right_ctx = opts.right_ctx 57 | self.n_skip_frame = opts.n_skip_frame 58 | self.n_downsample = opts.n_downsample 59 | self.feature_type = opts.feature_type 60 | self.mel = opts.mel 61 | 62 | if opts.feature_type == "waveform": 63 | self.label_dict = process_label_file(label_file, self.out_type, self.class2int) 64 | self.item = [] 65 | with open(wav_path, 'r') as f: 66 | for line in f.readlines(): 67 | utt, path = line.strip().split('\t') 68 | self.item.append((path, self.label_dict[utt])) 69 | else: 70 | self.process_feature_label() 71 | 72 | def process_feature_label(self): 73 | path_dict = [] 74 | #read the ark path 75 | with open(self.scp_path, 'r') as rf: 76 | line = rf.readline() 77 | while line: 78 | utt, path = line.strip().split(' ') 79 | path_dict.append((utt, path)) 80 | line = rf.readline() 81 | 82 | #read the label 83 | label_dict = dict() 84 | with open(self.lab_path, 'r') as rf: 85 | line = rf.readline() 86 | while line: 87 | utt, label = line.strip().split(' ', 1) 88 | label_dict[utt] = [self.vocab.word2index[c] if c in self.vocab.word2index else self.vocab.word2index['UNK'] for c in label.split()] 89 | line = rf.readline() 90 | 91 | assert len(path_dict) == len(label_dict) 92 | print("Reading %d lines from %s" % (len(label_dict), self.lab_path)) 93 | 94 | self.item = 
[] 95 | for i in range(len(path_dict)): 96 | utt, path = path_dict[i] 97 | self.item.append((path, label_dict[utt], utt)) 98 | 99 | def __getitem__(self, idx): 100 | if self.feature_type == "waveform": 101 | path, label = self.item[idx] 102 | return (load_wave(path), label) 103 | else: 104 | path, label, utt = self.item[idx] 105 | feat = kaldiio.load_mat(path) 106 | feat= skip_feat(make_context(feat, self.left_ctx, self.right_ctx), self.n_skip_frame) 107 | seq_len, dim = feat.shape 108 | if seq_len % self.n_downsample != 0: 109 | pad_len = self.n_downsample - seq_len % self.n_downsample 110 | feat = np.vstack([feat, np.zeros((pad_len, dim))]) 111 | if self.mel: 112 | return (F_Mel(torch.from_numpy(feat), audio_conf), label) 113 | else: 114 | return (torch.from_numpy(feat), torch.LongTensor(label), utt) 115 | 116 | def __len__(self): 117 | return len(self.item) 118 | 119 | def create_input(batch): 120 | inputs_max_length = max(x[0].size(0) for x in batch) 121 | feat_size = batch[0][0].size(1) 122 | targets_max_length = max(x[1].size(0) for x in batch) 123 | batch_size = len(batch) 124 | batch_data = torch.zeros(batch_size, inputs_max_length, feat_size) 125 | batch_label = torch.zeros(batch_size, targets_max_length) 126 | input_sizes = torch.zeros(batch_size) 127 | target_sizes = torch.zeros(batch_size) 128 | utt_list = [] 129 | 130 | for x in range(batch_size): 131 | feature, label, utt = batch[x] 132 | feature_length = feature.size(0) 133 | label_length = label.size(0) 134 | 135 | batch_data[x].narrow(0, 0, feature_length).copy_(feature) 136 | batch_label[x].narrow(0, 0, label_length).copy_(label) 137 | input_sizes[x] = feature_length / inputs_max_length 138 | target_sizes[x] = label_length 139 | utt_list.append(utt) 140 | return batch_data.float(), input_sizes.float(), batch_label.long(), target_sizes.long(), utt_list 141 | 142 | ''' 143 | class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, 144 | collate_fn=, pin_memory=False, drop_last=False) 145 | subclass of DataLoader and rewrite the collate_fn to form batch 146 | ''' 147 | 148 | class SpeechDataLoader(DataLoader): 149 | def __init__(self, *args, **kwargs): 150 | super(SpeechDataLoader, self).__init__(*args, **kwargs) 151 | self.collate_fn = create_input 152 | 153 | if __name__ == '__main__': 154 | dev_dataset = SpeechDataset() 155 | dev_dataloader = SpeechDataLoader(dev_dataset, batch_size=2, shuffle=True) 156 | 157 | import visdom 158 | viz = visdom.Visdom(env='fan') 159 | for i in range(1): 160 | show = dev_dataset[i][0].transpose(0, 1) 161 | text = dev_dataset[i][1] 162 | for num in range(len(text)): 163 | text[num] = dev_dataset.int2class[text[num]] 164 | text = ' '.join(text) 165 | opts = dict(title=text, xlabel='frame', ylabel='spectrum') 166 | viz.heatmap(show, opts = opts) 167 | -------------------------------------------------------------------------------- /timit/utils/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #encoding=utf-8 3 | 4 | __author__ = 'Ruchao Fan' 5 | 6 | import math 7 | import torch 8 | import numpy as np 9 | #import librosa 10 | #import torchaudio 11 | 12 | def load_audio(path): 13 | """ 14 | Args: 15 | path : string 载入音频的路径 16 | Returns: 17 | sound : numpy.ndarray 单声道音频数据,如果是多声道进行平均 18 | """ 19 | sound, _ = torchaudio.load(path) 20 | sound = sound.numpy() 21 | if len(sound.shape) > 1: 22 | if sound.shape[1] == 1: 23 | sound = sound.squeeze() 24 | else: 25 | sound = 
sound.mean(axis=1) 26 | return sound 27 | 28 | def load_wave(path, normalize=True): 29 | """ 30 | Args: 31 | path : string 载入音频的路径 32 | Returns: 33 | """ 34 | sound = load_audio(path) 35 | wave = torch.FloatTensor(sound) 36 | if normalize: 37 | mean = wave.mean() 38 | std = wave.std() 39 | wave.add_(-mean) 40 | wave.div_(std) 41 | return wave 42 | 43 | def F_Mel(fre_f, audio_conf): 44 | ''' 45 | Input: 46 | fre_f : FloatTensor log spectrum 47 | audio_conf : 主要需要用到采样率 48 | Output: 49 | mel_f : FloatTensor 换成mel频谱 50 | ''' 51 | n_mels = fre_f.size(1) 52 | mel_bin = librosa.mel_frequencies(n_mels=n_mels, fmin=0, fmax=audio_conf["sample_rate"]/2) * audio_conf["window_size"] 53 | count = 0 54 | fre_f = fre_f.numpy().tolist() 55 | mel_f = [] 56 | for frame in fre_f: 57 | mel_f_frame = [] 58 | for i in range(n_mels): 59 | left = int(math.floor(mel_bin[i])) 60 | right = left + 1 61 | tmp = (frame[right] - frame[left]) * (mel_bin[i] - left) + frame[left] #线性插值 62 | mel_f_frame.append(tmp) 63 | mel_f.append(mel_f_frame) 64 | return torch.FloatTensor(mel_f) 65 | 66 | def make_context(feature, left, right): 67 | if left==0 and right == 0: 68 | return feature 69 | feature = [feature] 70 | for i in range(left): 71 | feature.append(np.vstack((feature[-1][0], feature[-1][:-1]))) 72 | feature.reverse() 73 | for i in range(right): 74 | feature.append(np.vstack((feature[-1][1:], feature[-1][-1]))) 75 | return np.hstack(feature) 76 | 77 | def skip_feat(feature, skip): 78 | ''' 79 | ''' 80 | if skip == 1 or skip == 0: 81 | return feature 82 | skip_feature=[] 83 | for i in range(feature.shape[0]): 84 | if i % skip == 0: 85 | skip_feature.append(feature[i]) 86 | return np.vstack(skip_feature) 87 | 88 | def process_label_file(label_file, label_type, class2int): 89 | ''' 90 | Input: 91 | label_file : string 标签文件路径 92 | label_type : string 标签类型(目前只支持字符和音素) 93 | class2int : dict 标签和数字的对应关系 94 | Output: 95 | label_dict : dict 所有句子的标签,每个句子是numpy类型 96 | ''' 97 | label_dict = dict() 98 | f = open(label_file, 'r') 99 | for label in f.readlines(): 100 | label = label.strip() 101 | label_list = [] 102 | if label_type == 'char': 103 | utt = label.split('\t', 1)[0] 104 | label = label.split('\t', 1)[1] 105 | for i in range(len(label)): 106 | if label[i].lower() in class2int: 107 | label_list.append(class2int[label[i].lower()]) 108 | if label[i] == ' ': 109 | label_list.append(class2int['SPACE']) 110 | else: 111 | label = label.split() 112 | utt = label[0] 113 | for i in range(1,len(label)): 114 | label_list.append(class2int[label[i]]) 115 | label_dict[utt] = label_list 116 | f.close() 117 | return label_dict 118 | 119 | ''' 120 | if __name__ == '__main__': 121 | import scipy.signal 122 | windows = {'hamming':scipy.signal.hamming, 'hann':scipy.signal.hann, 'blackman':scipy.signal.blackman, 123 | 'bartlett':scipy.signal.bartlett} 124 | audio_conf = {"sample_rate":16000, 'window_size':0.025, 'window_stride':0.01, 'window': 'hamming'} 125 | path = '/home/fan/Audio_data/TIMIT/test/dr7/fdhc0/si1559.wav' 126 | spect = parse_audio(path, audio_conf, windows, normalize=True) 127 | mel_f = F_Mel(spect, audio_conf) 128 | wave = load_wav(path) 129 | print(wave) 130 | 131 | import visdom 132 | viz = visdom.Visdom(env='fan') 133 | viz.heatmap(spect.transpose(0, 1), opts=dict(title="Log Spectrum", xlabel="She had your dark suit in greasy wash water all year.", ylabel="Frequency")) 134 | viz.heatmap(mel_f.transpose(0, 1), opts=dict(title="Log Mel Spectrum", xlabel="She had your dark suit in greasy wash water all year.", ylabel="Frequency")) 
135 | viz.line(wave.numpy()) 136 | ''' 137 | 138 | --------------------------------------------------------------------------------