├── requirements.txt
├── README.md
├── model.py
├── almost_inference.py
├── my_datasets.py
├── main.py
├── train_test.py
└── utils.py

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
pandas
torch
torchaudio
librosa
asrtoolkit
torch_optimizer
tqdm
wandb

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# QuartzNet (ASR, 1D separable convolutions)

Implementation of the model described in Kriman et al., 2019 (QuartzNet: Deep Automatic Speech Recognition with 1D Time-Channel Separable Convolutions).

Data can be downloaded here:
https://commonvoice.mozilla.org/en/datasets

Additional files needed to run the code (sorted indexes, preprocessed tsv, model weights): https://yadi.sk/d/tT-N6DRHkB5XTw?w=1

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn as nn


# Conv -> BatchNorm -> ReLU (used only with kernel_size=1, so no padding is needed)
def conv_bn_act(in_size, out_size, kernel_size, stride=1, dilation=1):
    return nn.Sequential(
        nn.Conv1d(in_size, out_size, kernel_size, stride, dilation=dilation),
        nn.BatchNorm1d(out_size),
        nn.ReLU()
    )


# Time-channel separable convolution: a depthwise conv over time followed by a
# pointwise (1x1) conv that mixes channels, then BatchNorm.
def sepconv_bn(in_size, out_size, kernel_size, stride=1, dilation=1, padding=None):
    if padding is None:
        padding = (kernel_size - 1) // 2  # "same" padding for dilation=1
    return nn.Sequential(
        nn.Conv1d(in_size, in_size, kernel_size,
                  stride=stride, dilation=dilation, groups=in_size,
                  padding=padding),
        nn.Conv1d(in_size, out_size, kernel_size=1),
        nn.BatchNorm1d(out_size)
    )


# Main block B_i: R repeats of (sepconv -> BN) with ReLU between repeats and a
# pointwise residual connection around the whole block.
class QnetBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size, stride=1, R=5):
        super().__init__()

        # NB: passing the Sequential from sepconv_bn into nn.ModuleList unpacks
        # its three children (depthwise conv, pointwise conv, BN) into the list;
        # kept as-is so existing checkpoints load without key remapping.
        self.layers = nn.ModuleList(sepconv_bn(in_size, out_size, kernel_size, stride))
        for _ in range(R - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(sepconv_bn(out_size, out_size, kernel_size, stride))
        self.layers = nn.Sequential(*self.layers)

        self.residual = nn.Sequential(
            nn.Conv1d(in_size, out_size, kernel_size=1),
            nn.BatchNorm1d(out_size)
        )

    def forward(self, x):
        # the final ReLU is applied after the residual addition, as in the paper
        return F.relu(self.residual(x) + self.layers(x))


class QuartzNet(nn.Module):
    def __init__(self, n_mels, num_classes):
        super().__init__()
        # C1: stride 2 halves the time dimension; the CTC input lengths are
        # halved accordingly (see preprocess_data in my_datasets.py).
        self.c1 = sepconv_bn(n_mels, 256, kernel_size=33, stride=2)
        self.blocks = nn.Sequential(
            #         in   out   k  s  R
            QnetBlock(256, 256, 33, 1, R=5),
            QnetBlock(256, 256, 39, 1, R=5),
            QnetBlock(256, 512, 51, 1, R=5),
            QnetBlock(512, 512, 63, 1, R=5),
            QnetBlock(512, 512, 75, 1, R=5)
        )
        # C2 uses dilation 2; padding=86 keeps the sequence length unchanged
        self.c2 = sepconv_bn(512, 512, kernel_size=87, dilation=2, padding=86)
        self.c3 = conv_bn_act(512, 1024, kernel_size=1)
        # NB: conv_bn_act ends in a ReLU, so the logits produced by c4 are
        # clipped at zero; the paper uses a plain pointwise conv here. Kept
        # as-is for compatibility with the published weights.
        self.c4 = conv_bn_act(1024, num_classes, kernel_size=1)

        self.init_weights()

    def init_weights(self):
        # no-op: PyTorch's default initializers are used
        pass

    def forward(self, x):
        c1 = F.relu(self.c1(x))
        blocks = self.blocks(c1)
        c2 = F.relu(self.c2(blocks))
        c3 = self.c3(c2)
        return self.c4(c3)
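
A quick way to sanity-check the architecture is to push a dummy batch through it and confirm the output shape. A minimal sketch, assuming the 64-mel, 28-class configuration used in main.py (batch size and sequence length here are arbitrary):

import torch

from model import QuartzNet

# Shape smoke test: only C1 has stride 2, so the time axis should halve.
model = QuartzNet(n_mels=64, num_classes=28)
model.eval()

x = torch.randn(4, 64, 400)  # (batch, n_mels, frames) - dummy log-mel batch
with torch.no_grad():
    out = model(x)

print(out.shape)  # expected: torch.Size([4, 28, 200])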
--------------------------------------------------------------------------------
/almost_inference.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchaudio
import random
import numpy as np
import wandb
import torch_optimizer

from torch.utils.data import DataLoader

# project imports
from my_datasets import TrainDataset, preprocess_data, transform_tr
from model import QuartzNet
from train_test import test


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def count_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])


if __name__ == '__main__':
    BATCH_SIZE = 10
    N_MELS = 64

    set_seed(21)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### Loading data and loaders
    my_dataset = TrainDataset(csv_file='train_preprocessed.tsv', transform=transform_tr)
    # precomputed sorted indexes (see the README for the download link)
    with open('sorted.npy', 'rb') as f:
        s = np.load(f)
    to_save = s[200:300][:, 0]
    val_set = torch.utils.data.Subset(my_dataset, to_save)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                            shuffle=True, collate_fn=preprocess_data, drop_last=True,
                            num_workers=0, pin_memory=True)

    ### wandb login
    wandb.login()
    wandb.init()

    ### Creating melspecs on GPU
    melspec = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,  ### 22050, 48000
        n_fft=1024,
        hop_length=256,
        n_mels=N_MELS  ### 64, 80
    ).to(device)

    ### Creating the model and loading a checkpoint
    model = QuartzNet(n_mels=N_MELS, num_classes=28)
    print('num of params', count_parameters(model))
    model.to(device)
    wandb.watch(model)
    # the optimizer is only created because test() takes it as an argument;
    # it is never stepped during evaluation
    opt = torch_optimizer.NovoGrad(
        model.parameters(),
        lr=0.01,
        betas=(0.8, 0.5),
        weight_decay=0.001,
    )

    # loading checkpoint
    checkpoint = torch.load('epoch_5', map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    CTCLoss = nn.CTCLoss(blank=0).to(device)
    test(model, opt, val_loader, CTCLoss, device, bs_width=8, melspec=melspec)
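
The checkpoints written by train() in train_test.py also carry the optimizer and scheduler states, so a run can be resumed rather than only evaluated. A minimal sketch; the checkpoint filename here is illustrative (use whatever name train() actually produced), and the objects are rebuilt exactly as in main.py:

import torch
import torch_optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR

from model import QuartzNet

# Rebuild model/optimizer/scheduler as in main.py, then restore their states.
model = QuartzNet(n_mels=64, num_classes=28)
opt = torch_optimizer.NovoGrad(model.parameters(), lr=0.01, betas=(0.8, 0.5), weight_decay=0.001)
scheduler = CosineAnnealingLR(opt, T_max=50, eta_min=0)

# 'epoch_0_and_3' is a hypothetical filename matching train()'s naming scheme.
checkpoint = torch.load('epoch_0_and_3', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch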
--------------------------------------------------------------------------------
/my_datasets.py:
--------------------------------------------------------------------------------
import torch
import torchaudio
import pandas as pd
import math

from torch import distributions
from torch.nn.utils.rnn import pad_sequence

from utils import TextTransform


class TrainDataset(torch.utils.data.Dataset):
    """Custom competition dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.answers = pd.read_csv(csv_file, sep='\t')
        self.transform = transform

    def __len__(self):
        return len(self.answers)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.answers.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()
        if len(utt.shape) != 1:
            # multi-channel recording: keep a single channel
            utt = utt[1]

        answer = self.answers.loc[idx, 'sentence']

        if self.transform:
            utt = self.transform(utt)

        sample = {'utt': utt, 'answer': answer}
        return sample


class TestDataset(torch.utils.data.Dataset):
    """Custom test dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.names = pd.read_csv(csv_file, sep='\t')
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.names.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()

        if self.transform:
            utt = self.transform(utt)

        sample = {'utt': utt}
        return sample


# win_len=1024, hop_len=256
# Number of mel-spectrogram frames for a waveform of x samples, computed before
# the spectrogram itself (torchaudio's centered STFT yields 1 + x // hop_length frames).
def mel_len(x):
    return int(x // 256) + 1


# Waveform augmentation: identity, light Gaussian noise, or a volume change,
# chosen uniformly at random.
def transform_tr(wav):
    aug_num = torch.randint(low=0, high=3, size=(1,)).item()
    augs = [
        lambda x: x,
        lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
        lambda x: torchaudio.transforms.Vol(.1)(x)
    ]
    return augs[aug_num](wav)


# collate_fn: pads waveforms and label sequences and records the true lengths
# that CTCLoss needs.
def preprocess_data(data):
    text_transform = TextTransform()
    wavs = []
    input_lens = []
    labels = []
    label_lens = []

    for el in data:
        wavs.append(el['utt'])
        # the model's first conv has stride 2, so the CTC input length is half
        # the number of mel frames (rounded up)
        input_lens.append(math.ceil(mel_len(el['utt'].shape[0]) / 2))
        label = torch.Tensor(text_transform.text_to_int(el['answer']))
        labels.append(label)
        label_lens.append(len(label))

    wavs = pad_sequence(wavs, batch_first=True)
    labels = pad_sequence(labels, batch_first=True)

    return wavs, input_lens, labels, label_lens
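
The frame-count formula in mel_len can be checked directly against torchaudio. A small verification sketch, assuming the same n_fft/hop_length as the pipeline above:

import torch
import torchaudio

# Compare mel_len's formula with torchaudio's actual frame count.
melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=1024, hop_length=256, n_mels=64)

for n_samples in [4000, 16000, 16001, 50000]:
    wav = torch.randn(n_samples)
    n_frames = melspec(wav).shape[-1]  # frames from the centered STFT
    predicted = n_samples // 256 + 1   # mel_len's formula
    print(n_samples, n_frames, predicted)  # the two counts should agree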
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchaudio
import random
import numpy as np
import wandb
import torch_optimizer

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

# project imports
from my_datasets import TrainDataset, TestDataset, preprocess_data, transform_tr
from model import QuartzNet
from train_test import train


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def count_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])


if __name__ == '__main__':
    BATCH_SIZE = 80
    NUM_EPOCHS = 5
    N_MELS = 64

    set_seed(21)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### Loading data and loaders
    my_dataset = TrainDataset(csv_file='train_preprocessed.tsv', transform=transform_tr)
    test_dataset = TestDataset(csv_file='cv-corpus-5.1-2020-06-22/en/test.tsv', transform=None)
    # precomputed sorted indexes (see the README for the download link)
    with open('sorted.npy', 'rb') as f:
        s = np.load(f)
    to_save = s[:100000][:, 0]
    my_dataset = torch.utils.data.Subset(my_dataset, to_save)
    train_set, val_set = torch.utils.data.random_split(my_dataset, [85000, 15000])

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=preprocess_data, drop_last=True,
                              num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                            shuffle=True, collate_fn=preprocess_data, drop_last=True,
                            num_workers=0, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    ### wandb login
    wandb.login()
    wandb.init()

    ### Creating melspecs on GPU
    melspec = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,  ### 22050, 48000
        n_fft=1024,
        hop_length=256,
        n_mels=N_MELS  ### 64, 80
    ).to(device)
    # training variant with SpecAugment-style frequency and time masking
    melspec_transforms = nn.Sequential(
        torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=1024, hop_length=256, n_mels=N_MELS),
        torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
        torchaudio.transforms.TimeMasking(time_mask_param=35),
    ).to(device)

    ### Creating model from scratch
    model = QuartzNet(n_mels=N_MELS, num_classes=28)
    print('num of params', count_parameters(model))
    model.to(device)
    wandb.watch(model)

    opt = torch_optimizer.NovoGrad(
        model.parameters(),
        lr=0.01,
        betas=(0.8, 0.5),
        weight_decay=0.001,
    )
    scheduler = CosineAnnealingLR(opt, T_max=50, eta_min=0, last_epoch=-1)
    CTCLoss = nn.CTCLoss(blank=0).to(device)
    train(model, opt, train_loader, scheduler, CTCLoss, device,
          n_epochs=NUM_EPOCHS, val_dl=val_loader,
          melspec=melspec, melspec_transforms=melspec_transforms)
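
FrequencyMasking and TimeMasking zero out random bands of the spectrogram, which is the SpecAugment-style regularization applied during training above. A minimal sketch of their effect on a dummy batch, using the same parameters:

import torch
import torchaudio

# Apply the same masking transforms to a dummy batch to see what they do.
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=35)

mels = torch.rand(2, 64, 200)  # (batch, n_mels, frames) - dummy spectrogram
masked = time_mask(freq_mask(mels))

# One band of up to 15 mel bins and one band of up to 35 frames get zeroed.
print((masked[0].sum(dim=1) == 0).sum().item(), 'masked mel bins')
print((masked[0].sum(dim=0) == 0).sum().item(), 'masked frames')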
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import wandb
import numpy as np

from tqdm import tqdm

from utils import cer, wer, decoder_func, beam_search_decoding


def train_epoch(model, optimizer, dataloader, CTCLoss, device, melspec_transforms):
    model.train()
    losses = []

    for i, (wavs, wavs_len, answ, answ_len) in tqdm(enumerate(dataloader)):
        wavs, answ = wavs.to(device), answ.to(device)

        # log-mel features with SpecAugment-style masking; the epsilon avoids log(0)
        trans_wavs = torch.log(melspec_transforms(wavs) + 1e-9)

        optimizer.zero_grad()

        output = model(trans_wavs)        # (batch, classes, time)
        output = F.log_softmax(output, dim=1)
        output = output.permute(2, 0, 1)  # (time, batch, classes), as CTCLoss expects

        loss = CTCLoss(output, answ, wavs_len, answ_len)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 15)
        optimizer.step()
        losses.append(loss.item())
        if i % 100 == 0:
            wandb.log({'mean_train_loss': loss})
            preds, targets = decoder_func(output, answ, answ_len, del_repeated=False)
            wandb.log({"CER_train": cer(targets[0], preds[0])})
            wandb.log({"WER_train": wer(targets[0], preds[0])})

    return np.mean(losses)


def train(model, opt, train_dl, scheduler, CTCLoss, device, n_epochs, val_dl=None,
          melspec=None, melspec_transforms=None):
    for epoch in range(n_epochs):
        print("Epoch {} of {}".format(epoch, n_epochs), 'LR', scheduler.get_last_lr())

        mean_loss = train_epoch(model, opt, train_dl, CTCLoss, device, melspec_transforms)
        print('MEAN EPOCH LOSS IS', mean_loss)

        scheduler.step()

        if val_dl is not None:
            test(model, opt, val_dl, CTCLoss, device, melspec=melspec)

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }, 'epoch_0_and_' + str(epoch))


def test(model, optimizer, dataloader, CTCLoss, device, melspec, bs_width=None):
    model.eval()

    cers, wers, cers_bs, wers_bs = [], [], [], []
    losses = []

    with torch.no_grad():
        for i, (wavs, wavs_len, answ, answ_len) in enumerate(dataloader):
            wavs, answ = wavs.to(device), answ.to(device)

            trans_wavs = torch.log(melspec(wavs) + 1e-9)

            output = model(trans_wavs)
            if bs_width is not None:
                output_bs = F.softmax(output, dim=1).permute(2, 0, 1)
                preds_bs, targets_bs = beam_search_decoding(output_bs, answ, answ_len, width=bs_width)

            output = F.log_softmax(output, dim=1)
            output = output.permute(2, 0, 1)  # (time, batch, classes)
            loss = CTCLoss(output, answ, wavs_len, answ_len)
            losses.append(loss.item())

            # argmax (greedy) decoding
            preds, targets = decoder_func(output, answ, answ_len, del_repeated=True)

            for j in range(len(preds)):
                if j == 0:
                    print('target: ', ''.join(targets[j]))
                    print('prediction: ', ''.join(preds[j]))

                cers.append(cer(targets[j], preds[j]))
                wers.append(wer(targets[j], preds[j]))
                # beam search is slow, so its metrics are tracked on the first
                # sample of each batch only
                if bs_width is not None and j == 0:
                    print('beamS pred:', ''.join(preds_bs[j]))
                    cers_bs.append(cer(targets_bs[j], preds_bs[j]))
                    wers_bs.append(wer(targets_bs[j], preds_bs[j]))

    avg_cer = np.mean(cers)
    avg_wer = np.mean(wers)
    if bs_width is not None:
        avg_cer_bs = np.mean(cers_bs)
        avg_wer_bs = np.mean(wers_bs)
        wandb.log({"CER_val_bs": avg_cer_bs})
        wandb.log({"WER_val_bs": avg_wer_bs})

    wandb.log({"CER_val": avg_cer})
    wandb.log({"WER_val": avg_wer})
    avg_loss = np.mean(losses)
    print('average test loss is', avg_loss)
    wandb.log({'mean_VAL_loss': avg_loss})
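
nn.CTCLoss expects log-probabilities of shape (time, batch, classes) together with per-sample input and target lengths, which is exactly what the permute above produces. A self-contained sketch of that contract with dummy tensors (the shapes are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy CTC setup mirroring the pipeline: 28 classes, blank index 0.
ctc = nn.CTCLoss(blank=0)

T, N, C = 50, 4, 28            # time steps, batch size, classes
logits = torch.randn(N, C, T)  # model output layout: (batch, classes, time)

log_probs = F.log_softmax(logits, dim=1).permute(2, 0, 1)  # -> (time, batch, classes)

targets = torch.randint(1, C, (N, 20))  # padded label matrix (0 is reserved for blank)
input_lens = torch.full((N,), T, dtype=torch.long)
target_lens = torch.tensor([20, 15, 7, 20])

loss = ctc(log_probs, targets, input_lens, target_lens)
print(loss.item())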
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import string
import asrtoolkit


class TextTransform:
    # Character vocabulary: 0 = apostrophe, 1 = space, 2..27 = a-z.
    # Caveat: index 0 is also used as the CTC blank (CTCLoss(blank=0)), so the
    # model can never emit an apostrophe -- the two symbols share an index.
    def __init__(self):
        self.char_dict = {}
        self.index_dict = {}

        self.char_dict['\''] = 0
        self.index_dict[0] = '\''
        self.char_dict[' '] = 1
        self.index_dict[1] = ' '
        for i, let in enumerate(string.ascii_lowercase):
            self.index_dict[i + 2] = let
            self.char_dict[let] = i + 2

    def text_to_int(self, text):
        labels = []
        for let in text:
            labels.append(self.char_dict[let])
        return labels

    def int_to_text(self, labels):
        text = []
        for num in labels:
            text.append(self.index_dict[num])
        return text


# argmax (greedy) decoding: drop blanks and, optionally, collapse repeated symbols
def decoder_func(output, answ, answ_lens, blank_label=0, del_repeated=True):
    decoded_preds, decoded_targs = [], []

    text_transform = TextTransform()

    batch_freqs = torch.argmax(output, dim=2).transpose(0, 1)  # (batch, time)

    for i, freqs in enumerate(batch_freqs):
        preds = []

        decoded_targs.append(
            text_transform.int_to_text(answ[i][:answ_lens[i]].tolist())
        )

        for j, num in enumerate(freqs):
            if num != blank_label:
                if del_repeated and j != 0 and num == freqs[j - 1]:
                    continue
                preds.append(num.item())
        decoded_preds.append(text_transform.int_to_text(preds))

    return decoded_preds, decoded_targs


# CTC prefix beam search decoding (no language model)
def beam_search_decoding(output, answ, answ_lens, blank_label=0, width=8):
    decoded_preds, decoded_targs = [], []

    text_transform = TextTransform()

    for i, mat in enumerate(output.transpose(0, 1)):  # mat: (time, classes) for sample i
        last = {}
        P_b, P_t = 1, 1
        P_nb = 0
        # beam -> [0: prob ending in blank, 1: prob ending in non-blank, 2: total prob]
        last[''] = [P_b, P_nb, P_t]

        for t in range(mat.shape[0]):
            curr = {}

            # keep only the `width` most probable beams
            cand = [(key, el) for (key, el) in last.items()]
            sorted_cand = sorted(cand, reverse=True, key=lambda x: x[1][2])  # by P_total
            best_beams = [key for (key, el) in sorted_cand][0:width]

            for beam in best_beams:
                P_nb = 0
                last_num = None
                if len(beam) > 0:
                    # the beam stays the same via a repeated last symbol
                    last_num = text_transform.text_to_int(beam[-1])[0]
                    P_nb = last[beam][1] * mat[t, last_num]

                # the beam stays the same via a blank
                P_b = last[beam][2] * mat[t, blank_label]

                if beam not in curr:
                    curr[beam] = [P_b, P_nb, P_b + P_nb]
                else:
                    curr[beam][0] += P_b
                    curr[beam][1] += P_nb
                    curr[beam][2] += P_b + P_nb

                # extend the beam with every non-blank symbol (0 is blank)
                for c in range(1, mat.shape[1]):
                    new_beam = beam + ''.join(text_transform.int_to_text([c]))

                    if last_num == c:
                        # repeated symbol: only paths ending in a blank can extend
                        P_nb = mat[t, c] * last[beam][0]
                    else:
                        P_nb = mat[t, c] * last[beam][2]

                    if new_beam not in curr:
                        curr[new_beam] = [0, P_nb, P_nb]
                    else:
                        curr[new_beam][1] += P_nb
                        curr[new_beam][2] += P_nb
            last = curr

        cand = [(key, el) for (key, el) in last.items()]
        sorted_cand = sorted(cand, reverse=True, key=lambda x: x[1][2])
        best_beam = [x[0] for x in sorted_cand][0]

        decoded_preds.append(best_beam)

        # i is the index within the batch
        decoded_targs.append(
            text_transform.int_to_text(answ[i][:answ_lens[i]].tolist())
        )

    return decoded_preds, decoded_targs


def cer(target, pred):
    return asrtoolkit.cer(''.join(target), ''.join(pred))


def wer(target, pred):
    return asrtoolkit.wer(''.join(target), ''.join(pred))
--------------------------------------------------------------------------------
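
A quick usage example of the text pipeline and the greedy decoder, with a hand-built dummy output tensor (the shapes mirror what test() passes in):

import torch

from utils import TextTransform, decoder_func

tt = TextTransform()
labels = tt.text_to_int('hello world')
print(labels)                           # [9, 6, 13, 13, 16, 1, 24, 16, 19, 13, 5]
print(''.join(tt.int_to_text(labels)))  # round-trips back to 'hello world'

# Dummy (time, batch, classes) scores whose argmax path is h(9) h(9) blank(0) i(10);
# with del_repeated=True this collapses to "hi".
T, C = 4, 28
output = torch.full((T, 1, C), -10.0)
for t, c in enumerate([9, 9, 0, 10]):
    output[t, 0, c] = 0.0

answ = torch.tensor([[9, 10]])  # target labels: "hi"
preds, targets = decoder_func(output, answ, [2], del_repeated=True)
print(''.join(preds[0]), ''.join(targets[0]))  # hi hi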