├── README.md ├── deepdream.py ├── demo.html ├── diag.py ├── speech2text.py └── vis.py /README.md: -------------------------------------------------------------------------------- 1 | # speech2text.py 2 | This is a PyTorch inference script for NVIDIA OpenSeq2Seq's [wav2letter model](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/wave2letter.html). 3 | 4 | The [pretrained model weights for English](https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/w2l_plus_large_mp.h5) were exported from a TensorFlow [checkpoint](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/wave2letter.html#training) to HDF5 using a little [tfcheckpoint2pytorch](https://github.com/vadimkantorov/tfcheckpoint2pytorch) script that I wrote. 5 | 6 | **Limitations:** not ready for production; uses float32 weights; uses a greedy decoder; does not chunk the input 7 | 8 | **Dependencies:** `pytorch` (CPU version is OK), `numpy`, `scipy`, `h5py`; optional dependencies for saving the model weights to tfjs format: `tensorflow` v1.13.1 (install as `pip3 install tensorflow==1.13.1`), `tensorflowjs` (install as `pip3 install tensorflowjs --no-deps`, otherwise it would upgrade your TensorFlow from v1 to v2 and break everything) 9 | 10 | Credit for the original [wav2letter++ model](https://arxiv.org/abs/1812.07625) goes to the awesome Facebook AI Research scientists. 11 | 12 | # Example 13 | ```shell 14 | # download the pretrained model weights (English Wav2Letter, Russian Wav2Letter, English Jasper) 15 | wget https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/w2l_plus_large_mp.h5 # English, Wav2Letter 16 | wget https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/checkpoint_0010_epoch_01_iter_62500.model.h5 # Russian 17 | wget https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/jasper10x5_LibriSpeech_nvgrad_masks.h5.part_aa # English, Jasper, part 1 18 | wget https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/jasper10x5_LibriSpeech_nvgrad_masks.h5.part_ab # English, Jasper, part 2 19 | cat jasper10x5_LibriSpeech_nvgrad_masks.h5.part_aa jasper10x5_LibriSpeech_nvgrad_masks.h5.part_ab > jasper10x5_LibriSpeech_nvgrad_masks.h5 20 | 21 | # download and transcribe a wav file (16 kHz) 22 | # should print: my heart doth plead that thou in him doth lie a closet never pierced with crystal eyes but the defendant doth that plea deny and says in him thy fair appearance lies 23 | wget https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/121-123852-0004.wav 24 | python3 speech2text.py --model en_w2l --weights w2l_plus_large_mp.h5 -i 121-123852-0004.wav # use Wav2Letter model 25 | python3 speech2text.py --model en_jasper --weights jasper10x5_LibriSpeech_nvgrad_masks.h5 -i 121-123852-0004.wav # use Jasper model 26 | 27 | # transcribe a Russian wav file 28 | python3 speech2text.py --model ru_w2l --weights checkpoint_0010_epoch_01_iter_62500.model.h5 -i some_test.wav 29 | 30 | # save the model to ONNX format for inspection with https://lutzroeder.github.io/netron/ 31 | python3 speech2text.py --model en_w2l --weights w2l_plus_large_mp.h5 --onnx w2l_plus_large_mp.onnx 32 | 33 | # save the model to TensorFlow.js format 34 | python3 speech2text.py --model en_w2l --weights w2l_plus_large_mp.h5 --tfjs w2l_plus_large_mp.tfjs 35 | ``` 36 | 37 | # Browser demo with TensorFlow.js (work in progress) 38 | ```shell 39 | # download and extract the exported tfjs model 40 | wget 
https://github.com/vadimkantorov/inferspeech/releases/download/pretrained/w2l_plus_large_mp.tfjs.tar.gz 41 | tar -xf w2l_plus_large_mp.tfjs.tar.gz 42 | 43 | # serve the tfjs model and demo.html file 44 | python3 -m http.server 45 | 46 | # open the demo at http://localhost:8000/demo.html and transcribe the test file 121-123852-0004.wav 47 | ``` 48 | -------------------------------------------------------------------------------- /deepdream.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import math 4 | import json 5 | import base64 6 | import argparse 7 | import scipy.io.wavfile 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import matplotlib; matplotlib.use('Agg') 13 | import matplotlib.pyplot as plt 14 | 15 | import speech2text 16 | 17 | def perturb(batch_first, batch_last, K = 80): 18 | diff = batch_last - batch_first 19 | positive = F.relu(diff) 20 | small = diff.clone() 21 | small[:, K:] = 0 22 | large = diff * (diff.abs() < 0.25 * diff.max()).float() 23 | return batch_first + positive, batch_first + small, batch_first + large 24 | 25 | def vis_(batch_first_grad, batch_first, batch_last, scores_first, scores_last, K = 80): 26 | postproc = lambda decoded: ''.join('.' if c == '|' else c if i == 0 or c == ' ' or c != idx2chr(decoded[i - 1]) else '_' for i, c in enumerate(''.join(map(idx2chr, decoded)))) 27 | normalize_min_max = lambda scores, dim = 0: (scores - scores.min(dim = dim).values) / (scores.max(dim = dim).values - scores.min(dim = dim).values + 1e-16) 28 | entropy = lambda x, dim, eps = 1e-15: -(x / x.sum(dim = dim, keepdim = True).add(eps) * (x / x.sum(dim = dim, keepdim = True).add(eps) + eps).log()).sum(dim = dim) 29 | 30 | plt.figure(figsize=(6, 3)) 31 | def colorbar(): cb = plt.colorbar(); cb.outline.set_visible(False); cb.ax.tick_params(labelsize = 4, length = 0.3) 32 | title = lambda s: plt.title(s, fontsize = 5) 33 | ticks = lambda labelsize = 4, length = 1: plt.gca().tick_params(axis='both', which='both', labelsize=labelsize, length=length) or [ax.set_linewidth(0) for ax in plt.gca().spines.values()] 34 | plt.subplots_adjust(top = 0.99, bottom=0.01, hspace=0.8, wspace=0.4) 35 | 36 | num_subplots = 9 37 | plt.subplot(num_subplots, 1, 1) 38 | plt.imshow(batch_first[:K].log1p(), origin = 'lower', aspect = 'auto'); ticks(); colorbar() 39 | title('LogSpectrogram, original') 40 | 41 | plt.subplot(num_subplots, 1, 2) 42 | plt.imshow(batch_last[:K].log1p(), origin = 'lower', aspect = 'auto'); ticks(); colorbar() 43 | title('LogSpectrogram, dream') 44 | 45 | plt.subplot(num_subplots, 1, 3) 46 | diff = batch_last - batch_first 47 | plt.imshow((diff * (diff.abs() < 0.25 * diff.max()).float())[:K].log1p(), origin = 'lower', aspect = 'auto'); ticks(); colorbar() 48 | title('LogDiff') 49 | 50 | plt.subplot(num_subplots, 1, 4) 51 | plt.imshow(batch_first_grad[:K].log1p(), origin = 'lower', aspect = 'auto'); ticks(); colorbar() 52 | title('LogGrad') 53 | 54 | scores_first_01 = normalize_min_max(scores_first) 55 | scores_last_01 = normalize_min_max(scores_last) 56 | scores_first_softmax = F.softmax(scores_first, dim = 0) 57 | scores_last_softmax = F.softmax(scores_last, dim = 0) 58 | 59 | plt.subplot(num_subplots, 1, 5) 60 | plt.imshow(scores_first_01, origin = 'lower', aspect = 'auto'); ticks(); colorbar() 61 | title('Scores, original') 62 | 63 | plt.subplot(num_subplots, 1, 6) 64 | plt.imshow(scores_last_01, origin = 'lower', aspect = 'auto'); ticks(); colorbar() 65 | 
title('Scores, dream') 66 | 67 | plt.subplot(num_subplots, 1, 7) 68 | plt.imshow(scores_last_01 - scores_first_01, origin = 'lower', aspect = 'auto'); ticks(); colorbar() 69 | title('Diff') 70 | 71 | plt.subplot(num_subplots, 1, 8) 72 | plt.plot(entropy(scores_first_softmax, dim = 0).numpy(), linewidth = 0.3, color='b') 73 | plt.plot(entropy(scores_last_softmax, dim = 0).numpy(), linewidth = 0.3, color='r') 74 | plt.hlines(1.0, 0, scores_first_01.shape[-1]); colorbar() 75 | plt.xlim([0, scores_first_01.shape[-1]]) 76 | plt.ylim([0, 2.5]) 77 | plt.xticks(torch.arange(scores_first_01.shape[-1]), postproc(scores_last.argmax(dim = 0).tolist())) 78 | ticks(labelsize=2, length = 0) 79 | ax = plt.gca().twiny() 80 | ax.tick_params(axis='x') 81 | plt.xticks(torch.arange(scores_first_01.shape[-1]), postproc(scores_first.argmax(dim = 0).tolist())) 82 | ticks(labelsize=2, length = 0) 83 | title('Entropy') 84 | 85 | plt.subplot(num_subplots, 1, 9) 86 | plt.plot(F.kl_div(scores_first_softmax, scores_last_softmax, reduction = 'none').sum(dim = 0), linewidth = 0.3, color = 'b') 87 | plt.plot((scores_last_softmax - scores_first_softmax).abs().sum(dim = 0), linewidth = 0.3, color = 'g'); ticks(); colorbar() 88 | plt.xlim([0, scores_first_01.shape[-1]]) 89 | plt.ylim([-2, 2]) 90 | title('KL') 91 | 92 | buf = io.BytesIO() 93 | plt.savefig(buf, format = 'jpg', bbox_inches = 'tight', dpi = 300) 94 | plt.close() 95 | return buf.getvalue() 96 | 97 | if __name__ == '__main__': 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('--weights', default = 'model_checkpoint_0027_epoch_02.model.h5') #'checkpoint_0010_epoch_01_iter_62500.model.h5') 100 | parser.add_argument('-o', '--output-path', default = 'dream.html') 101 | parser.add_argument('-i', '--input-path', default = 'transcripts.json') 102 | parser.add_argument('--device', default = 'cuda') 103 | parser.add_argument('--num-iter', default = 100, type = int) 104 | parser.add_argument('--lr', default = 1e6, type = float) 105 | parser.add_argument('--max-norm', default = 100, type = float) 106 | args = parser.parse_args() 107 | 108 | ref_tra = list(sorted(json.load(open(args.input_path)), key = lambda j: j['cer'], reverse = True)) 109 | 110 | frontend, model, idx2chr, chr2idx = speech2text.load_model_ru_w2l(args.weights) 111 | model = model.to(args.device) 112 | 113 | vis = open(args.output_path , 'w') 114 | vis.write('
') 115 | 116 | for i, (reference, transcript, filename, cer) in enumerate(list(map(j.get, ['reference', 'transcript', 'filename', 'cer'])) for j in ref_tra): 117 | sample_rate, signal = scipy.io.wavfile.read(filename) 118 | if i > 5: continue 119 | 120 | signal = torch.from_numpy(signal).to(torch.float32).to(args.device).requires_grad_() 121 | labels = torch.IntTensor(list(map(chr2idx, reference))).to(args.device) 122 | 123 | batch_first, batch_first_grad, batch_last, scores_first, hyp_first, hyp_last, scores_last = None, None, None, None, None, None, None 124 | for k in range(args.num_iter): 125 | batch = frontend(signal, sample_rate).unsqueeze(0).requires_grad_(); batch.retain_grad() 126 | scores = model(batch) 127 | 128 | hyp = speech2text.decode_greedy(scores.squeeze(0), idx2chr) 129 | loss = F.ctc_loss(F.log_softmax(scores, dim = 1).permute(2, 0, 1), labels.unsqueeze(0), torch.IntTensor([scores.shape[-1]]).to(args.device), torch.IntTensor([len(labels)]).to(args.device), blank = chr2idx('|')) 130 | print(i, 'Loss:', float(loss)) 131 | if not hyp or (torch.isnan(loss) | torch.isinf(loss)).any(): 132 | continue 133 | model.zero_grad() 134 | loss.backward() 135 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 136 | signal.data.sub_(signal.grad.data.mul_(args.lr)) 137 | signal.grad.data.zero_() 138 | if k == 0: 139 | batch_first = batch.clone() 140 | batch_first_grad = batch.grad.clone().neg_() 141 | scores_first = scores.clone() 142 | hyp_first = hyp 143 | scores_last = scores.clone() 144 | batch_last = batch.clone() 145 | hyp_last = hyp 146 | 147 | print(i, '| #', k, 'REF: ', reference) 148 | print(i, '| #', k, 'HYP: ', hyp) 149 | print() 150 | 151 | 152 | hyp_positive, hyp_small, hyp_large = [speech2text.decode_greedy(model(x).squeeze(0), idx2chr) for x in perturb(batch_first, batch_last)] 153 | encoded = base64.b64encode(open(filename, 'rb').read()).decode('utf-8').replace('\n', '') 154 | vis.write(f'transcript will appear here9 | 10 | 11 | 12 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /diag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import Levenshtein 3 | 4 | class Alignment(object): 5 | SCORE_UNIFORM = 1 6 | SCORE_PROPORTION = 2 7 | 8 | def __init__(self): 9 | self.seq_a = None 10 | self.seq_b = None 11 | self.len_a = None 12 | self.len_b = None 13 | self.score_null = 5 14 | self.score_sub = -100 15 | self.score_del = -3 16 | self.score_ins = -3 17 | self.separator = '|' 18 | self.mode = Alignment.SCORE_UNIFORM 19 | 20 | def set_score(self, score_null=None, score_sub=None, score_del=None, score_ins=None): 21 | if score_null is not None: 22 | self.score_null = score_null 23 | if score_sub is not None: 24 | self.score_sub = score_sub 25 | if score_del is not None: 26 | self.score_del = score_del 27 | if score_ins is not None: 28 | self.score_ins = score_ins 29 | 30 | def match(self, a, b): 31 | if a == b and self.mode == Alignment.SCORE_UNIFORM: 32 | return self.score_null 33 | elif self.mode == Alignment.SCORE_UNIFORM: 34 | return self.score_sub 35 | elif a == b: 36 | return self.score_null * len(a) 37 | else: 38 | return self.score_sub * len(a) 39 | 40 | def delete(self, a): 41 | """ 42 | deleted elements are on seqa 43 | """ 44 | if self.mode == Alignment.SCORE_UNIFORM: 45 | return self.score_del 46 | return self.score_del * len(a) 47 | 48 | def insert(self, a): 49 | """ 50 | inserted elements are on seqb 51 | """ 52 | if self.mode == 
Alignment.SCORE_UNIFORM: 53 | return self.score_ins 54 | return self.score_ins * len(a) 55 | 56 | def score(self, aligned_seq_a, aligned_seq_b): 57 | score = 0 58 | for a, b in zip(aligned_seq_a, aligned_seq_b): 59 | if a == b: 60 | score += self.score_null 61 | else: 62 | if a == self.separator: 63 | score += self.score_ins 64 | elif b == self.separator: 65 | score += self.score_del 66 | else: 67 | score += self.score_sub 68 | return score 69 | 70 | def map_alignment(self, aligned_seq_a, aligned_seq_b): 71 | map_b2a = [] 72 | idx = 0 73 | for x, y in zip(aligned_seq_a, aligned_seq_b): 74 | if x == y: 75 | # if two positions are the same 76 | map_b2a.append(idx) 77 | idx += 1 78 | elif x == self.separator: 79 | # if a character is inserted in b, map b's 80 | # position to previous index in a 81 | # b[0]=0, b[1]=1, b[2]=1, b[3]=2 82 | # aa|bbb 83 | # aaabbb 84 | map_b2a.append(idx) 85 | elif y == self.separator: 86 | # if a character is deleted in a, increase 87 | # index in a, skip this position 88 | # b[0]=0, b[1]=1, b[2]=3 89 | # aaabbb 90 | # aa|bbb 91 | idx += 1 92 | continue 93 | return map_b2a 94 | 95 | 96 | class Needleman(Alignment): 97 | def __init__(self, *args): 98 | super(Needleman, self).__init__() 99 | self.semi_global = False 100 | self.matrix = None 101 | 102 | def init_matrix(self): 103 | rows = self.len_a + 1 104 | cols = self.len_b + 1 105 | self.matrix = [[0] * cols for i in range(rows)] 106 | 107 | def compute_matrix(self): 108 | seq_a = self.seq_a 109 | seq_b = self.seq_b 110 | len_a = self.len_a 111 | len_b = self.len_b 112 | 113 | if not self.semi_global: 114 | for i in range(1, len_a + 1): 115 | self.matrix[i][0] = self.delete(seq_a[i - 1]) + self.matrix[i - 1][0] 116 | for i in range(1, len_b + 1): 117 | self.matrix[0][i] = self.insert(seq_b[i - 1]) + self.matrix[0][i - 1] 118 | 119 | for i in range(1, len_a + 1): 120 | for j in range(1, len_b + 1): 121 | """ 122 | Note that rows = len_a+1, cols = len_b+1 123 | """ 124 | 125 | score_sub = self.matrix[i - 1][j - 1] + self.match(seq_a[i - 1], seq_b[j - 1]) 126 | score_del = self.matrix[i - 1][j] + self.delete(seq_a[i - 1]) 127 | score_ins = self.matrix[i][j - 1] + self.insert(seq_b[j - 1]) 128 | self.matrix[i][j] = max(score_sub, score_del, score_ins) 129 | 130 | def backtrack(self): 131 | aligned_seq_a, aligned_seq_b = [], [] 132 | seq_a, seq_b = self.seq_a, self.seq_b 133 | 134 | if self.semi_global: 135 | # semi-global settings, len_a = row numbers, column length, len_b = column number, row length 136 | last_col_max, val = max(enumerate([row[-1] for row in self.matrix]), key=lambda a: a[1]) 137 | last_row_max, val = max(enumerate([col for col in self.matrix[-1]]), key=lambda a: a[1]) 138 | 139 | if self.len_a < self.len_b: 140 | i, j = self.len_a, last_row_max 141 | aligned_seq_a = [self.separator] * (self.len_b - last_row_max) 142 | aligned_seq_b = seq_b[last_row_max:] 143 | else: 144 | i, j = last_col_max, self.len_b 145 | aligned_seq_a = seq_a[last_col_max:] 146 | aligned_seq_b = [self.separator] * (self.len_a - last_col_max) 147 | else: 148 | i, j = self.len_a, self.len_b 149 | 150 | mat = self.matrix 151 | 152 | while i > 0 or j > 0: 153 | # from end to start, choose insert/delete over match for a tie 154 | # why? 
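# note: the branches below are tried in the order insert, delete, match, so on a tie the gap branch wins over the match branch; this ordering is the tie-break referred to above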
155 | if self.semi_global and (i == 0 or j == 0): 156 | if i == 0 and j > 0: 157 | aligned_seq_a = [self.separator] * j + aligned_seq_a 158 | aligned_seq_b = seq_b[:j] + aligned_seq_b 159 | elif i > 0 and j == 0: 160 | aligned_seq_a = seq_a[:i] + aligned_seq_a 161 | aligned_seq_b = [self.separator] * i + aligned_seq_b 162 | break 163 | 164 | if j > 0 and mat[i][j] == mat[i][j - 1] + self.insert(seq_b[j - 1]): 165 | aligned_seq_a.insert(0, self.separator * len(seq_b[j - 1])) 166 | aligned_seq_b.insert(0, seq_b[j - 1]) 167 | j -= 1 168 | 169 | elif i > 0 and mat[i][j] == mat[i - 1][j] + self.delete(seq_a[i - 1]): 170 | aligned_seq_a.insert(0, seq_a[i - 1]) 171 | aligned_seq_b.insert(0, self.separator * len(seq_a[i - 1])) 172 | i -= 1 173 | 174 | elif i > 0 and j > 0 and mat[i][j] == mat[i - 1][j - 1] + self.match(seq_a[i - 1], seq_b[j - 1]): 175 | aligned_seq_a.insert(0, seq_a[i - 1]) 176 | aligned_seq_b.insert(0, seq_b[j - 1]) 177 | i -= 1 178 | j -= 1 179 | 180 | else: 181 | print(seq_a) 182 | print(seq_b) 183 | print(aligned_seq_a) 184 | print(aligned_seq_b) 185 | # print(mat) 186 | raise Exception('backtrack error', i, j, seq_a[i - 2:i + 1], seq_b[j - 2:j + 1]) 187 | pass 188 | 189 | return aligned_seq_a, aligned_seq_b 190 | 191 | def align(self, seq_a, seq_b, semi_global=True, mode=None): 192 | self.seq_a = seq_a 193 | self.seq_b = seq_b 194 | self.len_a = len(self.seq_a) 195 | self.len_b = len(self.seq_b) 196 | 197 | self.semi_global = semi_global 198 | 199 | # 0: left-end 0-penalty, 1: right-end 0-penalty, 2: both ends 0-penalty 200 | # self.semi_end = semi_end 201 | 202 | if mode is not None: 203 | self.mode = mode 204 | self.init_matrix() 205 | self.compute_matrix() 206 | return self.backtrack() 207 | 208 | 209 | class Hirschberg(Alignment): 210 | def __init__(self): 211 | super(Hirschberg, self).__init__() 212 | self.needleman = Needleman() 213 | 214 | def last_row(self, seqa, seqb): 215 | lena = len(seqa) 216 | lenb = len(seqb) 217 | pre_row = [0] * (lenb + 1) 218 | cur_row = [0] * (lenb + 1) 219 | 220 | for j in range(1, lenb + 1): 221 | pre_row[j] = pre_row[j - 1] + self.insert(seqb[j - 1]) 222 | 223 | for i in range(1, lena + 1): 224 | cur_row[0] = self.delete(seqa[i - 1]) + pre_row[0] 225 | for j in range(1, lenb + 1): 226 | score_sub = pre_row[j - 1] + self.match(seqa[i - 1], seqb[j - 1]) 227 | score_del = pre_row[j] + self.delete(seqa[i - 1]) 228 | score_ins = cur_row[j - 1] + self.insert(seqb[j - 1]) 229 | cur_row[j] = max(score_sub, score_del, score_ins) 230 | 231 | pre_row = cur_row 232 | cur_row = [0] * (lenb + 1) 233 | 234 | return pre_row 235 | 236 | def align_rec(self, seq_a, seq_b): 237 | aligned_a, aligned_b = [], [] 238 | len_a, len_b = len(seq_a), len(seq_b) 239 | 240 | if len_a == 0: 241 | for i in range(len_b): 242 | aligned_a.append(self.separator * len(seq_b[i])) 243 | aligned_b.append(seq_b[i]) 244 | elif len_b == 0: 245 | for i in range(len_a): 246 | aligned_a.append(seq_a[i]) 247 | aligned_b.append(self.separator * len(seq_a[i])) 248 | 249 | elif len(seq_a) == 1: 250 | aligned_a, aligned_b = self.needleman.align(seq_a, seq_b) 251 | 252 | else: 253 | mid_a = int(len_a / 2) 254 | 255 | rowleft = self.last_row(seq_a[:mid_a], seq_b) 256 | rowright = self.last_row(seq_a[mid_a:][::-1], seq_b[::-1]) 257 | 258 | rowright.reverse() 259 | 260 | row = [l + r for l, r in zip(rowleft, rowright)] 261 | maxidx, maxval = max(enumerate(row), key=lambda a: a[1]) 262 | 263 | mid_b = maxidx 264 | 265 | aligned_a_left, aligned_b_left = self.align_rec(seq_a[:mid_a], 
seq_b[:mid_b]) 266 | aligned_a_right, aligned_b_right = self.align_rec(seq_a[mid_a:], seq_b[mid_b:]) 267 | aligned_a = aligned_a_left + aligned_a_right 268 | aligned_b = aligned_b_left + aligned_b_right 269 | 270 | return aligned_a, aligned_b 271 | 272 | def align(self, seq_a, seq_b, mode=None): 273 | self.seq_a = seq_a 274 | self.seq_b = seq_b 275 | self.len_a = len(self.seq_a) 276 | self.len_b = len(self.seq_b) 277 | if mode is not None: 278 | self.mode = mode 279 | return self.align_rec(self.seq_a, self.seq_b) 280 | 281 | def analyze(ref, hyp, phonetic_replace_groups = []): 282 | ref0, hyp0 = ref, hyp 283 | ref, hyp = Needleman().align(list(ref), list(hyp)) 284 | r, h = '', '' 285 | i = 0 286 | while i < len(ref): 287 | if i + 1 < len(hyp) and ref[i] == '|' and hyp[i + 1] == '|': 288 | r += ref[i + 1] 289 | h += hyp[i] 290 | i += 2 291 | elif i + 1 < len(ref) and ref[i + 1] == '|' and hyp[i] == '|': 292 | r += ref[i] 293 | h += hyp[i + 1] 294 | i += 2 295 | else: 296 | r += ref[i] 297 | h += hyp[i] 298 | i += 1 299 | 300 | def words(): 301 | k = None 302 | for i in range(1 + len(r)): 303 | if i == len(r) or r[i] == ' ': 304 | yield r[k : i], h[k : i] 305 | k = None 306 | elif r[i] != '|' and r[i] != ' ' and k is None: 307 | k = i 308 | 309 | assert len(r) == len(h) 310 | phonetic_group = lambda c: ([i for i, g in enumerate(phonetic_replace_groups) if c in g] + [c])[0] 311 | a = dict( 312 | chars = dict( 313 | ok = sum(1 if r[i] == h[i] else 0 for i in range(len(r))), 314 | replace = sum(1 if r[i] != '|' and r[i] != h[i] and h[i] != '|' else 0 for i in range(len(r))), 315 | replace_phonetic = sum(1 if r[i] != '|' and r[i] != h[i] and h[i] != '|' and phonetic_group(r[i]) == phonetic_group(h[i]) else 0 for i in range(len(r))), 316 | delete = sum(1 if r[i] != '|' and r[i] != h[i] and h[i] == '|' else 0 for i in range(len(r))), 317 | insert = sum(1 if r[i] == '|' and h[i] != '|' else 0 for i in range(len(r))), 318 | total = len(r) 319 | ), 320 | spaces = dict( 321 | delete = sum(1 if r[i] == ' ' and h[i] != ' ' else 0 for i in range(len(r))), 322 | insert = sum(1 if h[i] == ' ' and r[i] != ' ' else 0 for i in range(len(r))), 323 | total = sum(1 if r[i] == ' ' else 0 for i in range(len(r))) 324 | ), 325 | words = dict( 326 | missing_prefix = sum(1 if h_[0] in ' |' else 0 for r_, h_ in words()), 327 | missing_suffix = sum(1 if h_[-1] in ' |' else 0 for r_, h_ in words()), 328 | ok_prefix_suffix = sum(1 if h_[0] not in ' |' and h_[-1] not in ' |' else 0 for r_, h_ in words()), 329 | delete = sum(1 if h_.count('|') > len(r_) // 2 else 0 for r_, h_ in words()), 330 | replace = sum(1 if sum(1 if h_[i] not in ' |' and h_[i] != r_[i] else 0 for i in range(len(r_))) > len(r_) // 2 else 0 for r_, h_ in words()), 331 | total = sum(1 if c == ' ' else 0 for c in ref) + 1 332 | ), 333 | alignment = dict(ref = r, hyp = h), 334 | input = dict(ref = ref0, hyp = hyp0), 335 | cer = Levenshtein.distance(ref0.replace(' ', ''), hyp0.replace(' ', '')) / len(ref0.replace(' ', '')) 336 | 337 | ) 338 | return a 339 | 340 | RU_PHONETIC_REPLACE_GROUPS = ['АО', 'БП', 'ЗСЦ', 'ВФ', 'ГК', 'ДТ', 'ЧЖШЩ', 'ЭЕИ', 'РЛ', 'ЮУ'] 341 | 342 | if __name__ == '__main__': 343 | parser = argparse.ArgumentParser() 344 | parser.add_argument('--ref') 345 | parser.add_argument('--hyp') 346 | args = parser.parse_args() 347 | 348 | print(analyze(args.ref, args.hyp, phonetic_replace_groups = RU_PHONETIC_REPLACE_GROUPS)) 349 | -------------------------------------------------------------------------------- /speech2text.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import math 4 | import numpy as np 5 | import h5py 6 | import scipy.io.wavfile 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | def load_model_en_jasper(model_weights, batch_norm_eps = 0.001, num_classes = 29, ABC = " ABCDEFGHIJKLMNOPQRSTUVWXYZ'|"): 12 | class JasperNet(nn.ModuleList): 13 | def __init__(self, num_classes): 14 | def conv_bn_residual(kernel_size, num_channels, stride = 1, dilation = 1, padding = 0, repeat = 1, num_channels_residual = []): 15 | return nn.ModuleDict(dict( 16 | conv = nn.ModuleList([nn.Conv1d(num_channels[0] if i == 0 else num_channels[1], num_channels[1], kernel_size = kernel_size, stride = stride, dilation = dilation, padding = padding) for i in range(repeat)]), 17 | conv_residual = nn.ModuleList([nn.Conv1d(in_channels, num_channels[1], kernel_size = 1) for in_channels in num_channels_residual]) 18 | )) 19 | 20 | blocks = nn.ModuleList([ 21 | conv_bn_residual(kernel_size = 11, num_channels = (64, 256), padding = 5, stride = 2), 22 | 23 | conv_bn_residual(kernel_size = 11, num_channels = (256, 256), padding = 5, repeat = 5, num_channels_residual = [256]), 24 | conv_bn_residual(kernel_size = 11, num_channels = (256, 256), padding = 5, repeat = 5, num_channels_residual = [256, 256]), 25 | 26 | conv_bn_residual(kernel_size = 13, num_channels = (256, 384), padding = 6, repeat = 5, num_channels_residual = [256, 256, 256]), 27 | conv_bn_residual(kernel_size = 13, num_channels = (384, 384), padding = 6, repeat = 5, num_channels_residual = [256, 256, 256, 384]), 28 | 29 | conv_bn_residual(kernel_size = 17, num_channels = (384, 512), padding = 8, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384]), 30 | conv_bn_residual(kernel_size = 17, num_channels = (512, 512), padding = 8, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384, 512]), 31 | 32 | conv_bn_residual(kernel_size = 21, num_channels = (512, 640), padding = 10, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384, 512, 512]), 33 | conv_bn_residual(kernel_size = 21, num_channels = (640, 640), padding = 10, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384, 512, 512, 640]), 34 | 35 | conv_bn_residual(kernel_size = 25, num_channels = (640, 768), padding = 12, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384, 512, 512, 640, 640]), 36 | conv_bn_residual(kernel_size = 25, num_channels = (768, 768), padding = 12, repeat = 5, num_channels_residual = [256, 256, 256, 384, 384, 512, 512, 640, 640, 768]), 37 | 38 | conv_bn_residual(kernel_size = 29, num_channels = (768, 896), padding = 28, dilation = 2), 39 | conv_bn_residual(kernel_size = 1, num_channels = (896, 1024)), 40 | 41 | nn.Conv1d(1024, num_classes, 1) 42 | ]) 43 | super(JasperNet, self).__init__(blocks) 44 | 45 | def forward(self, x): 46 | residual = [] 47 | for i, block in enumerate(list(self)[:-1]): 48 | for conv in block.conv[:-1]: 49 | x = F.relu(conv(x), inplace = True) 50 | x = block.conv[-1](x) 51 | for conv, r in zip(block.conv_residual, residual if i < len(self) - 3 else []): 52 | x = x + conv(r) 53 | x = F.relu(x, inplace = True) 54 | residual.append(x) 55 | return self[-1](x) 56 | 57 | model = JasperNet(num_classes = len(ABC)) 58 | h = h5py.File(model_weights) 59 | to_tensor = lambda path: torch.from_numpy(np.asarray(h[path])).to(torch.float32) 60 | state_dict = {} 61 | for param_name, param in model.state_dict().items(): 62 | ij = [int(c) for 
c in param_name.split('.') if c.isdigit()] 63 | weight, bias = None, None 64 | if len(ij) > 1: 65 | weight, moving_mean, moving_variance, gamma, beta = [to_tensor(f'ForwardPass/w2l_encoder/conv{1 + ij[0]}{1 + ij[1]}/{suffix}') for suffix in ['kernel', 'bn/moving_mean', 'bn/moving_variance', 'bn/gamma', 'bn/beta']] if 'residual' not in param_name else [to_tensor(f'ForwardPass/w2l_encoder/conv{1 + ij[0]}5/{suffix}') for suffix in [f'res_{ij[1]}/kernel', f'res_bn_{ij[1]}/moving_mean', f'res_bn_{ij[1]}/moving_variance', f'res_bn_{ij[1]}/gamma', f'res_bn_{ij[1]}/beta']] 66 | weight, bias = fuse_conv_bn(weight.permute(2, 1, 0), moving_mean, moving_variance, gamma, beta, batch_norm_eps = batch_norm_eps) 67 | else: 68 | weight, bias = [to_tensor(f'ForwardPass/fully_connected_ctc_decoder/fully_connected/{suffix}') for suffix in ['kernel', 'bias']] 69 | weight = weight.t().unsqueeze(-1) 70 | 71 | state_dict[param_name] = (weight if 'weight' in param_name else bias).to(param.dtype) 72 | model.load_state_dict(state_dict) 73 | 74 | def frontend(signal, sample_freq, window_size=20e-3, window_stride=10e-3, dither = 1e-5, window_fn = np.hanning, num_features = 64): 75 | def get_melscale_filterbanks(sr, n_fft, n_mels, fmin, fmax, dtype = np.float32): 76 | def hz_to_mel(frequencies): 77 | frequencies = np.asanyarray(frequencies) 78 | f_min = 0.0 79 | f_sp = 200.0 / 3 80 | mels = (frequencies - f_min) / f_sp 81 | min_log_hz = 1000.0 82 | min_log_mel = (min_log_hz - f_min) / f_sp 83 | logstep = np.log(6.4) / 27.0 84 | 85 | if frequencies.ndim: 86 | log_t = (frequencies >= min_log_hz) 87 | mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep 88 | elif frequencies >= min_log_hz: 89 | mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep 90 | 91 | return mels 92 | 93 | def mel_to_hz(mels): 94 | mels = np.asanyarray(mels) 95 | f_min = 0.0 96 | f_sp = 200.0 / 3 97 | freqs = f_min + f_sp * mels 98 | min_log_hz = 1000.0 99 | min_log_mel = (min_log_hz - f_min) / f_sp 100 | logstep = np.log(6.4) / 27.0 101 | 102 | if mels.ndim: 103 | log_t = (mels >= min_log_mel) 104 | freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) 105 | elif mels >= min_log_mel: 106 | freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel)) 107 | 108 | return freqs 109 | 110 | n_mels = int(n_mels) 111 | weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) 112 | 113 | fftfreqs = np.linspace(0, float(sr) / 2, int(1 + n_fft//2),endpoint=True) 114 | mel_f = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)) 115 | 116 | fdiff = np.diff(mel_f) 117 | ramps = np.subtract.outer(mel_f, fftfreqs) 118 | 119 | for i in range(n_mels): 120 | lower = -ramps[i] / fdiff[i] 121 | upper = ramps[i+2] / fdiff[i+1] 122 | weights[i] = np.maximum(0, np.minimum(lower, upper)) 123 | 124 | enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels]) 125 | weights *= enorm[:, np.newaxis] 126 | return torch.from_numpy(weights) 127 | 128 | signal = signal / (signal.abs().max() + 1e-5) 129 | audio_duration = len(signal) * 1.0 / sample_freq 130 | n_window_size = int(sample_freq * window_size) 131 | n_window_stride = int(sample_freq * window_stride) 132 | num_fft = 2**math.ceil(math.log2(window_size*sample_freq)) 133 | 134 | signal += dither * torch.randn_like(signal) 135 | S = torch.stft(signal, num_fft, hop_length=int(window_stride * sample_freq), win_length=int(window_size * sample_freq), window = torch.hann_window(int(window_size * sample_freq)).type_as(signal), pad_mode = 'reflect', center = 
True).pow(2).sum(dim = -1) 136 | mel_basis = get_melscale_filterbanks(sample_freq, num_fft, num_features, fmin=0, fmax=int(sample_freq/2)).type_as(S) 137 | 138 | features = torch.log(torch.matmul(mel_basis, S) + 1e-20) 139 | mean = features.mean(dim = 1, keepdim = True) 140 | std_dev = features.std(dim = 1, keepdim = True) 141 | return (features - mean) / std_dev 142 | 143 | return frontend, model, (lambda c: ABC[c]), ABC.index 144 | 145 | def load_model_ru_w2l(model_weights, batch_norm_eps = 1e-05, ABC = '|АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ2* '): 146 | def conv_bn(kernel_size, num_channels, stride = 1, padding = 0): 147 | return nn.Sequential( 148 | nn.Conv1d(num_channels[0], num_channels[1], kernel_size=kernel_size, stride=stride, padding=padding), 149 | nn.ReLU(inplace = True) 150 | ) 151 | 152 | model = nn.Sequential( 153 | conv_bn(kernel_size = 13, num_channels = (161, 768), stride = 2, padding = 6), 154 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 155 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 156 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 157 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 158 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 159 | conv_bn(kernel_size = 13, num_channels = (768, 768), stride = 1, padding = 6), 160 | conv_bn(kernel_size = 31, num_channels = (768, 2048), stride = 1, padding = 15), 161 | conv_bn(kernel_size = 1, num_channels = (2048, 2048), stride = 1, padding = 0), 162 | nn.Conv1d(2048, len(ABC), kernel_size=1, stride=1) 163 | ) 164 | 165 | h = h5py.File(model_weights) 166 | to_tensor = lambda path: torch.from_numpy(np.asarray(h[path])).to(torch.float32) 167 | state_dict = {} 168 | for param_name, param in model.state_dict().items(): 169 | ij = [int(c) for c in param_name if c.isdigit()] 170 | if len(ij) > 1: 171 | weight, moving_mean, moving_variance, gamma, beta = [to_tensor(f'rnns.{ij[0] * 3}.weight')] + [to_tensor(f'rnns.{ij[0] * 3 + 1}.{suffix}') for suffix in ['running_mean', 'running_var', 'weight', 'bias']] 172 | weight, bias = fuse_conv_bn(weight, moving_mean, moving_variance, gamma, beta, batch_norm_eps = batch_norm_eps) 173 | else: 174 | weight, bias = [to_tensor(f'fc.0.{suffix}') for suffix in ['weight', 'bias']] 175 | state_dict[param_name] = (weight if 'weight' in param_name else bias).to(param.dtype) 176 | model.load_state_dict(state_dict) 177 | 178 | def frontend(signal, sample_rate, window_size = 0.020, window_stride = 0.010, window = 'hann'): 179 | signal = signal / signal.abs().max() 180 | if sample_rate == 8000: 181 | signal, sample_rate = F.interpolate(signal.view(1, 1, -1), scale_factor = 2).squeeze(), 16000 182 | win_length = int(sample_rate * (window_size + 1e-8)) 183 | hop_length = int(sample_rate * (window_stride + 1e-8)) 184 | nfft = win_length 185 | return torch.stft(signal, nfft, win_length = win_length, hop_length = hop_length, window = torch.hann_window(nfft).type_as(signal), pad_mode = 'reflect', center = True).pow(2).sum(dim = -1).add(1e-9).sqrt() 186 | 187 | return frontend, model, (lambda c: ABC[c]), ABC.index 188 | 189 | def load_model_en_w2l(model_weights, batch_norm_eps = 0.001, ABC = " ABCDEFGHIJKLMNOPQRSTUVWXYZ'|"): 190 | def conv_block(kernel_size, num_channels, stride = 1, dilation = 1, repeat = 1, padding = 0): 191 | modules = [] 192 | for i in range(repeat): 193 | modules.append(nn.Conv1d(num_channels[0] if i == 0 else num_channels[1], 
num_channels[1], kernel_size = kernel_size, stride = stride, dilation = dilation, padding = padding)) 194 | modules.append(nn.Hardtanh(0, 20, inplace = True)) 195 | return nn.Sequential(*modules) 196 | 197 | model = nn.Sequential( 198 | conv_block(kernel_size = 11, num_channels = (64, 256), stride = 2, padding = 5), 199 | conv_block(kernel_size = 11, num_channels = (256, 256), repeat = 3, padding = 5), 200 | conv_block(kernel_size = 13, num_channels = (256, 384), repeat = 3, padding = 6), 201 | conv_block(kernel_size = 17, num_channels = (384, 512), repeat = 3, padding = 8), 202 | conv_block(kernel_size = 21, num_channels = (512, 640), repeat = 3, padding = 10), 203 | conv_block(kernel_size = 25, num_channels = (640, 768), repeat = 3, padding = 12), 204 | conv_block(kernel_size = 29, num_channels = (768, 896), repeat = 1, padding = 28, dilation = 2), 205 | conv_block(kernel_size = 1, num_channels = (896, 1024), repeat = 1), 206 | nn.Conv1d(1024, len(ABC), 1) 207 | ) 208 | 209 | h = h5py.File(model_weights) 210 | to_tensor = lambda path: torch.from_numpy(np.asarray(h[path])).to(torch.float32) 211 | state_dict = {} 212 | for param_name, param in model.state_dict().items(): 213 | ij = [int(c) for c in param_name if c.isdigit()] 214 | if len(ij) > 1: 215 | weight, moving_mean, moving_variance, gamma, beta = [to_tensor(f'ForwardPass/w2l_encoder/conv{1 + ij[0]}{1 + ij[1] // 2}/{suffix}') for suffix in ['kernel', 'bn/moving_mean', 'bn/moving_variance', 'bn/gamma', 'bn/beta']] 216 | weight, bias = fuse_conv_bn(weight.permute(2, 1, 0), moving_mean, moving_variance, gamma, beta, batch_norm_eps = batch_norm_eps) 217 | else: 218 | weight, bias = [to_tensor(f'ForwardPass/fully_connected_ctc_decoder/fully_connected/{suffix}') for suffix in ['kernel', 'bias']] 219 | weight = weight.t().unsqueeze(-1) 220 | state_dict[param_name] = (weight if 'weight' in param_name else bias).to(param.dtype) 221 | model.load_state_dict(state_dict) 222 | 223 | def frontend(signal, sample_rate, nfft = 512, nfilt = 64, preemph = 0.97, window_size = 0.020, window_stride = 0.010): 224 | def get_melscale_filterbanks(nfilt, nfft, samplerate): 225 | hz2mel = lambda hz: 2595 * math.log10(1+hz/700.) 
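# the mel scale used here is mel = 2595 * log10(1 + hz / 700) (so 1000 Hz maps to roughly 1000 mel); mel2hz below is its inverse, hz = 700 * (10 ** (mel / 2595) - 1)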
226 | mel2hz = lambda mel: torch.mul(700, torch.sub(torch.pow(10, torch.div(mel, 2595)), 1)) 227 | 228 | lowfreq = 0 229 | highfreq = samplerate // 2 230 | lowmel = hz2mel(lowfreq) 231 | highmel = hz2mel(highfreq) 232 | melpoints = torch.linspace(lowmel,highmel,nfilt+2); 233 | bin = torch.floor(torch.mul(nfft+1, torch.div(mel2hz(melpoints), samplerate))).tolist() 234 | 235 | fbank = torch.zeros([nfilt, nfft // 2 + 1]).tolist() 236 | for j in range(nfilt): 237 | for i in range(int(bin[j]), int(bin[j+1])): 238 | fbank[j][i] = (i - bin[j]) / (bin[j+1]-bin[j]) 239 | for i in range(int(bin[j+1]), int(bin[j+2])): 240 | fbank[j][i] = (bin[j+2]-i) / (bin[j+2]-bin[j+1]) 241 | return torch.tensor(fbank) 242 | 243 | preemphasis = lambda signal, coeff: torch.cat([signal[:1], torch.sub(signal[1:], torch.mul(signal[:-1], coeff))]) 244 | win_length = int(sample_rate * (window_size + 1e-8)) 245 | hop_length = int(sample_rate * (window_stride + 1e-8)) 246 | pspec = torch.stft(preemphasis(signal, preemph), nfft, win_length = win_length, hop_length = hop_length, window = torch.hann_window(win_length), pad_mode = 'constant', center = False).pow(2).sum(dim = -1) / nfft 247 | mel_basis = get_melscale_filterbanks(nfilt, nfft, sample_rate).type_as(pspec) 248 | features = torch.log(torch.add(torch.matmul(mel_basis, pspec), 1e-20)) 249 | return (features - features.mean()) / features.std() 250 | 251 | return frontend, model, (lambda c: ABC[c]), ABC.index 252 | 253 | def fuse_conv_bn(weight, moving_mean, moving_variance, gamma, beta, batch_norm_eps): 254 | factor = gamma * (moving_variance + batch_norm_eps).rsqrt() 255 | weight *= factor.view(-1, *([1] * (weight.dim() - 1))) 256 | bias = beta - moving_mean * factor 257 | return weight, bias 258 | 259 | def decode_greedy(scores, idx2chr): 260 | decoded_greedy = scores.argmax(dim = 0).tolist() 261 | decoded_text = ''.join(map(idx2chr, decoded_greedy)) 262 | return ''.join(c for i, c in enumerate(decoded_text) if (i == 0 or c != decoded_text[i - 1]) and c != '|') 263 | 264 | if __name__ == '__main__': 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument('--weights', default = 'w2l_plus_large_mp.h5') 267 | parser.add_argument('--model', default = 'en_w2l', choices = ['en_w2l', 'ru_w2l', 'en_jasper']) 268 | parser.add_argument('-i', '--input_path') 269 | parser.add_argument('--onnx') 270 | parser.add_argument('--tfjs') 271 | parser.add_argument('--pt') 272 | parser.add_argument('--tfjs_quantization_dtype', default = None, choices = ['uint8', 'uint16', None]) 273 | parser.add_argument('--device', default = 'cpu') 274 | args = parser.parse_args() 275 | 276 | torch.set_grad_enabled(False) 277 | frontend, model, idx2chr, chr2idx = dict(en_w2l = load_model_en_w2l, en_jasper = load_model_en_jasper, ru_w2l = load_model_ru_w2l)[args.model](args.weights) 278 | 279 | if args.input_path: 280 | sample_rate, signal = scipy.io.wavfile.read(args.input_path) 281 | assert sample_rate in [8000, 16000] 282 | features = frontend(torch.from_numpy(signal).to(torch.float32), sample_rate) 283 | scores = model.to(args.device)(features.unsqueeze(0).to(args.device)).squeeze(0) 284 | print(decode_greedy(scores, idx2chr)) 285 | 286 | if args.tfjs: 287 | # monkey-patching a module to have tfjs converter load with tf v1 288 | convert_tf_saved_model = None 289 | sys.modules['tensorflowjs.converters.tf_saved_model_conversion_v2'] = sys.modules[__name__] 290 | import tensorflowjs 291 | import tensorflow.keras as K 292 | pytorch2keras = lambda module: K.layers.Conv1D(module.out_channels, 
module.kernel_size, input_shape = (None, module.in_channels), data_format = 'channels_last', strides = module.stride, dilation_rate = module.dilation, padding = 'same', weights = [module.weight.detach().permute(2, 1, 0).numpy(), module.bias.detach().flatten().numpy()]) if isinstance(module, nn.Conv1d) else K.layers.ReLU(threshold = module.min_val, max_value = module.max_val) if isinstance(module, nn.Hardtanh) else K.layers.ReLU() if isinstance(module, nn.ReLU) else K.models.Sequential(list(map(pytorch2keras, module))) 293 | model, in_channels = pytorch2keras(model), model[0][0].in_channels 294 | model.build((None, None, in_channels)) 295 | tensorflowjs.converters.save_keras_model(model, args.tfjs, quantization_dtype = getattr(np, args.tfjs_quantization_dtype or '', None)) 296 | 297 | if args.onnx: 298 | # https://github.com/onnx/onnx/issues/740 299 | batch = torch.zeros(1, 1000, model[0][0].in_channels, dtype = torch.float32) 300 | torch.onnx.export(model, batch, args.onnx, input_names = ['input'], output_names = ['output']) 301 | 302 | if args.pt: 303 | torch.save(model, args.pt) 304 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import speech2text 4 | 5 | frontend, model, idx2chr, chr2idx = speech2text.load_model_en_jasper('jasper10x5_LibriSpeech_nvgrad_masks.h5') 6 | 7 | #model = torch.load('w2l_plus_large_mp.pt') 8 | 9 | convs = [m for m in model.modules() if isinstance(m, torch.nn.Conv1d) and m.kernel_size[0] > 1] 10 | print('\n'.join(str(c) for c in convs)) 11 | plt.figure(figsize = (6, 15)) 12 | 13 | for i, conv in enumerate(convs, start = 1): 14 | weight = conv.weight 15 | plt.subplot(len(convs), 1, i) 16 | plt.imshow(weight.abs().mean(dim = 1).detach().numpy(), origin = 'lower', aspect = 'auto') 17 | plt.gca().tick_params(axis='both', which='both', labelsize=5, length = 0) 18 | # plt.title(f'out_channels = {weight.shape[0]} | kernel_size = {weight.shape[-1]}') 19 | 20 | plt.subplots_adjust(top = 0.99, bottom=0.01, hspace=0.8, wspace=0.4) 21 | plt.savefig('vis.png', dpi = 150) 22 | --------------------------------------------------------------------------------
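A minimal usage sketch (not part of the repo's scripts; it assumes `w2l_plus_large_mp.h5` and `121-123852-0004.wav` have been downloaded as in the README, and that the `Levenshtein` package required by `diag.py` is installed) showing how the pieces combine: `speech2text.py` loads the frontend and model and decodes greedily, and `diag.py` scores the hypothesis against the reference transcript:

```python
import scipy.io.wavfile
import torch

import speech2text
import diag

# load the English Wav2Letter model, mirroring what speech2text.py's __main__ does
torch.set_grad_enabled(False)
frontend, model, idx2chr, chr2idx = speech2text.load_model_en_w2l('w2l_plus_large_mp.h5')

# read the 16 kHz test wav from the README and transcribe it with the greedy decoder
sample_rate, signal = scipy.io.wavfile.read('121-123852-0004.wav')
features = frontend(torch.from_numpy(signal).to(torch.float32), sample_rate)
scores = model(features.unsqueeze(0)).squeeze(0)  # (num_classes, time)
hyp = speech2text.decode_greedy(scores, idx2chr)
print('HYP:', hyp)

# score the hypothesis against the reference transcript from the README
# (the English alphabet in load_model_en_w2l is upper-case, so the reference is upper-cased too)
ref = 'MY HEART DOTH PLEAD THAT THOU IN HIM DOTH LIE A CLOSET NEVER PIERCED WITH CRYSTAL EYES BUT THE DEFENDANT DOTH THAT PLEA DENY AND SAYS IN HIM THY FAIR APPEARANCE LIES'
report = diag.analyze(ref, hyp)
print('CER:', report['cer'])
print('REF:', report['alignment']['ref'])
print('HYP:', report['alignment']['hyp'])
```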