├── LICENSE ├── README.md ├── infer-batch.py ├── infer.py ├── infer_label.py ├── requirement.txt └── whisper_ph_asr ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── attentions.cpython-310.pyc ├── commons.cpython-310.pyc └── whisper_encoder.cpython-310.pyc ├── attentions.py ├── commons.py ├── mel_filters.npz └── whisper_encoder.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Infinity-INF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fast-phasr 2 | Phoneme and duration labeling based on Whisper small 3 | ## This project is temporarily abandoned; we recommend using the more advanced SOFA_AI instead 4 | ### https://github.com/colstone/SOFA_AI 5 | -------------------------------------------------------------------------------- /infer-batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import librosa 5 | from tqdm import tqdm 6 | import whisper_ph_asr 7 | 8 | devices = torch.cuda.is_available() 9 | if devices: 10 | print("Use CUDA") 11 | device = torch.device('cuda') 12 | asr = whisper_ph_asr.PhonemeAsr().cuda() 13 | else: 14 | print("Use CPU") 15 | device = torch.device('cpu') 16 | asr = whisper_ph_asr.PhonemeAsr().cpu() 17 | 18 | parser = argparse.ArgumentParser(description="Batch inference for audio files") 19 | parser.add_argument("--batch", action="store_true", help="Enable batch inference mode") 20 | parser.add_argument("input_dir", type=str, help="Input directory containing WAV files") 21 | args = parser.parse_args() 22 | 23 | def get_wav_file_list(input_dir): 24 | wav_file_list = [] 25 | for filename in os.listdir(input_dir): 26 | if filename.endswith(".wav"): 27 | wav_file_list.append(os.path.join(input_dir, filename)) 28 | return wav_file_list 29 | 30 | if args.batch: 31 | sounddir = args.input_dir 32 | wav_file_list = get_wav_file_list(sounddir) 33 | else: 34 | print("Input your sounds directory:") 35 | sounddir = input() 36 | wav_file_list = [sounddir] 37 | 38 | pth = "phasr.pth" 39 | ckpt = torch.load(pth) 40 | asr.load_state_dict(ckpt) 41 | 42 | for wav_file in tqdm(wav_file_list, desc="Processing"): 43 | wav16k, _ = librosa.load(wav_file, sr=16000) 44 | phonemes, durations = whisper_ph_asr.get_asr_result(asr, wav16k) 45 | 
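    # The loop below converts the (phoneme, duration) pairs returned by
    # get_asr_result() into HTK-style label lines. Durations are in seconds,
    # and HTK .lab files count time in 100 ns units, hence the * 10000000
    # scaling; each output line is "<start> <end> <phoneme>".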
htk_labels = [] 47 | current_time = 0 48 | 49 | for phoneme, duration in zip(phonemes, durations): 50 | htk_label = f"{current_time} {current_time + int(duration * 10000000)} {phoneme}" 51 | htk_labels.append(htk_label) 52 | current_time += int(duration * 10000000) 53 | 54 | output_filename = os.path.splitext(os.path.basename(wav_file))[0] + ".lab" 55 | output_path = os.path.join(os.path.dirname(wav_file), output_filename) 56 | 57 | with open(output_path, 'w') as f: 58 | for label in htk_labels: 59 | f.write(label + '\n') 60 | 61 | print("HTK-style labels saved as:", output_path) 62 | print("Inference completed for:", wav_file) 63 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | import whisper_ph_asr 4 | 5 | devices = torch.cuda.is_available() 6 | if devices: 7 | print("Use CUDA") 8 | device = torch.device('cuda') 9 | asr = whisper_ph_asr.PhonemeAsr().cuda() 10 | else: 11 | print("Use CPU") 12 | device = torch.device('cpu') 13 | asr = whisper_ph_asr.PhonemeAsr().cpu() 14 | 15 | print("Input your sounds directory:") 16 | sounddir = input() 17 | 18 | pth = "phasr.pth" 19 | ckpt = torch.load(pth) 20 | asr.load_state_dict(ckpt) 21 | 22 | wav16k, _ = librosa.load(sounddir, sr=16000) 23 | phonemes, durations = whisper_ph_asr.get_asr_result(asr, wav16k) 24 | 25 | print(phonemes, durations) 26 | -------------------------------------------------------------------------------- /infer_label.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | import whisper_ph_asr 4 | import os 5 | 6 | devices = torch.cuda.is_available() 7 | if devices: 8 | print("Use CUDA") 9 | device = torch.device('cuda') 10 | asr = whisper_ph_asr.PhonemeAsr().cuda() 11 | else: 12 | print("Use CPU") 13 | device = torch.device('cpu') 14 | asr = whisper_ph_asr.PhonemeAsr().cpu() 15 | 16 | print("Input your sounds directory:") 17 | sounddir = input() 18 | 19 | pth = "phasr.pth" 20 | ckpt = torch.load(pth) 21 | asr.load_state_dict(ckpt) 22 | 23 | wav16k, _ = librosa.load(sounddir, sr=16000) 24 | phonemes, durations = whisper_ph_asr.get_asr_result(asr, wav16k) 25 | 26 | htk_labels = [] 27 | current_time = 0 28 | 29 | for phoneme, duration in zip(phonemes, durations): 30 | htk_label = f"{current_time} {current_time + int(duration * 10000000)} {phoneme}" 31 | htk_labels.append(htk_label) 32 | current_time += int(duration * 10000000) 33 | 34 | output_filename = os.path.splitext(os.path.basename(sounddir))[0] + ".lab" 35 | output_path = os.path.join(os.path.dirname(sounddir), output_filename) 36 | 37 | with open(output_path, 'w') as f: 38 | for label in htk_labels: 39 | f.write(label + '\n') 40 | 41 | print("HTK-style labels saved as:", output_path) 42 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python 2 | librosa==0.10.0.post2 3 | numpy==1.24.4 4 | torch==2.0.1+cu118 5 | tqdm 6 | gradio 7 | -------------------------------------------------------------------------------- /whisper_ph_asr/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import torch 5 | from torch import nn 6 | 7 | from . import commons 8 | from . 
import attentions 9 | from .whisper_encoder import AudioEncoder, log_mel_spectrogram, pad_or_trim 10 | ttsing_phone_set = ['_'] + [ 11 | "b", "c", "ch", "d", "f", "g", "h", "j", "k", "l", "m", "n", "p", "q", "r", 12 | "s", "sh", "t", "x", "z", "zh", "a", "ai", "an", "ang", "ao", "e", "ei", 13 | "en", "eng", "er", "iii", "ii", "i", "ia", "ian", "iang", "iao", "ie", "in", 14 | "ing", "iong", "iou", "o", "ong", "ou", "u", "ua", "uai", "uan", "uang", 15 | "uei", "uen", "ueng", "uo", "v", "van", "ve", "vn", "AH", "AA", "AO", "ER", 16 | "IH", "IY", "UH", "UW", "EH", "AE", "AY", "EY", "OY", "AW", "OW", "P", "B", 17 | "T", "D", "K", "G", "M", "N", "NG", "L", "S", "Z", "Y", "TH", "DH", "SH", 18 | "ZH", "CH", "JH", "V", "W", "F", "R", "HH", "AH0", "AA0", "AO0", "ER0", 19 | "IH0", "IY0", "UH0", "UW0", "EH0", "AE0", "AY0", "EY0", "OY0", "AW0", "OW0", 20 | "AH1", "AA1", "AO1", "ER1", "IH1", "IY1", "UH1", "UW1", "EH1", "AE1", "AY1", 21 | "EY1", "OY1", "AW1", "OW1", "AH2", "AA2", "AO2", "ER2", "IH2", "IY2", "UH2", 22 | "UW2", "EH2", "AE2", "AY2", "EY2", "OY2", "AW2", "OW2", "AH3", "AA3", "AO3", 23 | "ER3", "IH3", "IY3", "UH3", "UW3", "EH3", "AE3", "AY3", "EY3", "OY3", "AW3", 24 | "OW3", "D-1", "T-1", "P*", "B*", "T*", "D*", "K*", "G*", "M*", "N*", "NG*", 25 | "L*", "S*", "Z*", "Y*", "TH*", "DH*", "SH*", "ZH*", "CH*", "JH*", "V*", 26 | "W*", "F*", "R*", "HH*", "sp", "sil", "or", "ar", "aor", "our", "angr", 27 | "eir", "engr", "air", "ianr", "iaor", "ir", "ingr", "ur", "iiir", "uar", 28 | "uangr", "uenr", "iir", "ongr", "uor", "ueir", "iar", "iangr", "inr", 29 | "iour", "vr", "uanr", "ruai", "TR", "rest", 30 | # opencpop 31 | 'w', 'SP', 'AP', 'un', 'y', 'ui', 'iu', 32 | # opencpop-strict 33 | 'i0', 'E', 'En', 34 | # japanese-common 35 | 'ts.', 'f.', 'sh.', 'ry.', 'py.', 'h.', 'p.', 'N.', 'a.', 'm.', 'w.', 'ky.', 36 | 'n.', 'd.', 'j.', 'cl.', 'ny.', 'z.', 'o.', 'y.', 't.', 'u.', 'r.', 'pau', 37 | 'ch.', 'e.', 'b.', 'k.', 'g.', 's.', 'i.', 38 | # japanese-unique 39 | 'gy.', 'my.', 'hy.', 'br', 'by.', 'v.', 'ty.', 'xx.', 'U.', 'I.', 'dy.' 
40 | ] 41 | ttsing_phone_to_int = {} 42 | int_to_ttsing_phone = {} 43 | for idx, item in enumerate(ttsing_phone_set): 44 | ttsing_phone_to_int[item] = idx 45 | int_to_ttsing_phone[idx] = item 46 | 47 | 48 | LRELU_SLOPE = 0.1 49 | 50 | 51 | hps = { 52 | "data": { 53 | "unit_dim": 768, 54 | }, 55 | "model": { 56 | "hidden_channels": 192, 57 | "spk_channels": 192, 58 | "filter_channels": 768, 59 | "n_heads": 2, 60 | "n_layers": 4, 61 | "kernel_size": 3, 62 | "p_dropout": 0.1, 63 | "prior_hidden_channels": 192, 64 | "prior_filter_channels": 768, 65 | "prior_n_heads": 2, 66 | "prior_n_layers": 4, 67 | "prior_kernel_size": 3, 68 | "prior_p_dropout": 0.1, 69 | "resblock": "1", 70 | "use_spectral_norm": False, 71 | "resblock_kernel_sizes": [3,7,11], 72 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 73 | "upsample_rates": [8,8,4,2], 74 | "upsample_initial_channel": 256, 75 | "upsample_kernel_sizes": [16,16,8,4], 76 | "n_harmonic": 64, 77 | "n_bands": 65 78 | } 79 | } 80 | 81 | 82 | class PhonemeAsr(nn.Module): 83 | """ 84 | Model 85 | """ 86 | 87 | def __init__(self): 88 | super().__init__() 89 | self.hps = hps 90 | 91 | self.pre_net = nn.Conv1d(768, hps["model"]["prior_hidden_channels"], 1) 92 | self.proj = nn.Conv1d(hps["model"]["prior_hidden_channels"], len(ttsing_phone_set), 1) 93 | self.encoder = attentions.Encoder( 94 | hps["model"]["prior_hidden_channels"], 95 | hps["model"]["prior_filter_channels"], 96 | hps["model"]["prior_n_heads"], 97 | hps["model"]["prior_n_layers"], 98 | hps["model"]["prior_kernel_size"], 99 | hps["model"]["prior_p_dropout"]) 100 | self.whisper_model = AudioEncoder(80, 1500, 768, 12, 12) 101 | 102 | def forward(self, units): 103 | phone_lengths = torch.LongTensor([units.shape[2]]).to(units.device) 104 | x = self.pre_net(units) 105 | x_mask = torch.unsqueeze(commons.sequence_mask(phone_lengths, x.size(2)), 1).to(x.dtype) 106 | x = self.encoder(x * x_mask, x_mask) 107 | x = self.proj(x) 108 | return x 109 | 110 | 111 | 112 | def get_whisper_units(model=None, wav16k_numpy=None): 113 | dev = next(model.parameters()).device 114 | mel = log_mel_spectrogram(wav16k_numpy).to(dev)[:, :3000] 115 | # if torch.cuda.is_available(): 116 | # mel = mel.to(torch.float16) 117 | feature_len = mel.shape[-1] // 2 118 | assert mel.shape[-1] < 3000, "Input audio is too long; only audio shorter than 30 seconds is supported" 119 | with torch.no_grad(): 120 | feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].cpu().transpose(1,2) 121 | return feature 122 | 123 | def load_checkpoint(checkpoint_path, model): 124 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 125 | model.load_state_dict(checkpoint_dict) 126 | 127 | def remove_consecutive_duplicates(lst): 128 | sr = 16000 129 | hop = 320 130 | new_lst = [] 131 | dur_lst = [] 132 | previous = None 133 | count = 1 134 | for item in lst: 135 | if item == previous: 136 | count += 1 137 | else: 138 | if previous: 139 | new_lst.append(f"{previous}") 140 | dur_lst.append(count*hop/sr) 141 | previous = item 142 | count = 1 143 | new_lst.append(f"{previous}") 144 | dur_lst.append(count*hop/sr) 145 | return new_lst, dur_lst 146 | 147 | def convert_x_to_phones(x): 148 | phoneme_ids = torch.argmax(x, dim=1) 149 | phones, durs = remove_consecutive_duplicates([int_to_ttsing_phone[int(i)] for i in phoneme_ids[0, :]]) 150 | return phones, durs 151 | 152 | def load_phoneme_asr_model(): 153 | # whisper_model = load_whisper_model() 154 | current_file = os.path.abspath(__file__) 155 | current_directory = os.path.dirname(current_file) 156 | checkpoint_path = 
f"{current_directory}/full_asr_model.pth" 157 | asr_model = PhonemeAsr(hps) 158 | _ = asr_model.eval() 159 | load_checkpoint(checkpoint_path, asr_model) 160 | if torch.cuda.is_available(): 161 | asr_model = asr_model.cuda() 162 | # asr_model = asr_model.half() 163 | return asr_model 164 | 165 | def get_asr_result(asr_model, wav16k_numpy): 166 | units = get_whisper_units(asr_model.whisper_model, wav16k_numpy) 167 | with torch.no_grad(): 168 | if torch.cuda.is_available(): 169 | units = units.cuda() 170 | x = asr_model(units) 171 | x = x.cpu() 172 | phones, durs = convert_x_to_phones(x) 173 | return phones, durs 174 | 175 | def get_silent_result(asr_model, wav16k_numpy): 176 | units = get_whisper_units(asr_model.whisper_model, wav16k_numpy) 177 | with torch.no_grad(): 178 | if torch.cuda.is_available(): 179 | units = units.cuda() 180 | x = asr_model(units) 181 | x = x.cpu() 182 | phoneme_ids = torch.argmax(x, dim=1) 183 | phonemes = [int_to_ttsing_phone[int(i)] for i in phoneme_ids[0, :]] 184 | 185 | res_list = [] 186 | previous = None 187 | for idx, item in enumerate(phonemes): 188 | if item != previous: 189 | if item in ["SP", "AP", "pau"]: 190 | res_list.append(item) 191 | else: 192 | res_list.append(None) 193 | 194 | previous = item 195 | else: 196 | res_list.append(None) 197 | 198 | 199 | # print(res_list) 200 | # print(len(phonemes)) 201 | # print(sum([1 if i==j else 0 for i, j in zip(res_list, phonemes)])) 202 | return res_list 203 | 204 | -------------------------------------------------------------------------------- /whisper_ph_asr/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Infinity-INF/fast-phasr/4a01a60ad5805613d607d604c4f8b145e8282bcd/whisper_ph_asr/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /whisper_ph_asr/__pycache__/attentions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Infinity-INF/fast-phasr/4a01a60ad5805613d607d604c4f8b145e8282bcd/whisper_ph_asr/__pycache__/attentions.cpython-310.pyc -------------------------------------------------------------------------------- /whisper_ph_asr/__pycache__/commons.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Infinity-INF/fast-phasr/4a01a60ad5805613d607d604c4f8b145e8282bcd/whisper_ph_asr/__pycache__/commons.cpython-310.pyc -------------------------------------------------------------------------------- /whisper_ph_asr/__pycache__/whisper_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Infinity-INF/fast-phasr/4a01a60ad5805613d607d604c4f8b145e8282bcd/whisper_ph_asr/__pycache__/whisper_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /whisper_ph_asr/attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from . 
import commons 9 | 10 | 11 | class LayerNorm(nn.Module): 12 | def __init__(self, channels, eps=1e-5): 13 | super().__init__() 14 | self.channels = channels 15 | self.eps = eps 16 | 17 | self.gamma = nn.Parameter(torch.ones(channels)) 18 | self.beta = nn.Parameter(torch.zeros(channels)) 19 | 20 | def forward(self, x): 21 | x = x.transpose(1, -1) 22 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 23 | return x.transpose(1, -1) 24 | 25 | 26 | class Encoder(nn.Module): 27 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, 28 | **kwargs): 29 | super().__init__() 30 | self.hidden_channels = hidden_channels 31 | self.filter_channels = filter_channels 32 | self.n_heads = n_heads 33 | self.n_layers = n_layers 34 | self.kernel_size = kernel_size 35 | self.p_dropout = p_dropout 36 | self.window_size = window_size 37 | 38 | self.drop = nn.Dropout(p_dropout) 39 | self.attn_layers = nn.ModuleList() 40 | self.norm_layers_1 = nn.ModuleList() 41 | self.ffn_layers = nn.ModuleList() 42 | self.norm_layers_2 = nn.ModuleList() 43 | for i in range(self.n_layers): 44 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 45 | window_size=window_size)) 46 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 47 | self.ffn_layers.append( 48 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 49 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 50 | 51 | def forward(self, x, x_mask): 52 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 53 | x = x * x_mask 54 | for i in range(self.n_layers): 55 | y = self.attn_layers[i](x, x, attn_mask) 56 | y = self.drop(y) 57 | x = self.norm_layers_1[i](x + y) 58 | 59 | y = self.ffn_layers[i](x, x_mask) 60 | y = self.drop(y) 61 | x = self.norm_layers_2[i](x + y) 62 | x = x * x_mask 63 | return x 64 | 65 | 66 | class Decoder(nn.Module): 67 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., 68 | proximal_bias=False, proximal_init=True, **kwargs): 69 | super().__init__() 70 | self.hidden_channels = hidden_channels 71 | self.filter_channels = filter_channels 72 | self.n_heads = n_heads 73 | self.n_layers = n_layers 74 | self.kernel_size = kernel_size 75 | self.p_dropout = p_dropout 76 | self.proximal_bias = proximal_bias 77 | self.proximal_init = proximal_init 78 | 79 | self.drop = nn.Dropout(p_dropout) 80 | self.self_attn_layers = nn.ModuleList() 81 | self.norm_layers_0 = nn.ModuleList() 82 | self.encdec_attn_layers = nn.ModuleList() 83 | self.norm_layers_1 = nn.ModuleList() 84 | self.ffn_layers = nn.ModuleList() 85 | self.norm_layers_2 = nn.ModuleList() 86 | for i in range(self.n_layers): 87 | self.self_attn_layers.append( 88 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 89 | proximal_bias=proximal_bias, proximal_init=proximal_init)) 90 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 91 | self.encdec_attn_layers.append( 92 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 93 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 94 | self.ffn_layers.append( 95 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 96 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 97 | 98 | def forward(self, x, x_mask, h, h_mask): 99 | """ 100 | x: decoder input 101 | h: encoder output 102 | """ 103 | 
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 104 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 105 | x = x * x_mask 106 | for i in range(self.n_layers): 107 | y = self.self_attn_layers[i](x, x, self_attn_mask) 108 | y = self.drop(y) 109 | x = self.norm_layers_0[i](x + y) 110 | 111 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 112 | y = self.drop(y) 113 | x = self.norm_layers_1[i](x + y) 114 | 115 | y = self.ffn_layers[i](x, x_mask) 116 | y = self.drop(y) 117 | x = self.norm_layers_2[i](x + y) 118 | x = x * x_mask 119 | return x 120 | 121 | 122 | class FFT(nn.Module): 123 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., 124 | proximal_bias=False, proximal_init=True, **kwargs): 125 | super().__init__() 126 | self.hidden_channels = hidden_channels 127 | self.filter_channels = filter_channels 128 | self.n_heads = n_heads 129 | self.n_layers = n_layers 130 | self.kernel_size = kernel_size 131 | self.p_dropout = p_dropout 132 | self.proximal_bias = proximal_bias 133 | self.proximal_init = proximal_init 134 | 135 | self.drop = nn.Dropout(p_dropout) 136 | self.self_attn_layers = nn.ModuleList() 137 | self.norm_layers_0 = nn.ModuleList() 138 | self.ffn_layers = nn.ModuleList() 139 | self.norm_layers_1 = nn.ModuleList() 140 | for i in range(self.n_layers): 141 | self.self_attn_layers.append( 142 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 143 | proximal_bias=proximal_bias, proximal_init=proximal_init)) 144 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 145 | self.ffn_layers.append( 146 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 147 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 148 | 149 | def forward(self, x, x_mask): 150 | """ 151 | x: decoder input 152 | h: encoder output 153 | """ 154 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 155 | x = x * x_mask 156 | for i in range(self.n_layers): 157 | y = self.self_attn_layers[i](x, x, self_attn_mask) 158 | y = self.drop(y) 159 | x = self.norm_layers_0[i](x + y) 160 | 161 | y = self.ffn_layers[i](x, x_mask) 162 | y = self.drop(y) 163 | x = self.norm_layers_1[i](x + y) 164 | x = x * x_mask 165 | return x 166 | 167 | 168 | class FFNs(nn.Module): 169 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., 170 | proximal_bias=False, proximal_init=True, **kwargs): 171 | super().__init__() 172 | self.hidden_channels = hidden_channels 173 | self.filter_channels = filter_channels 174 | self.n_heads = n_heads 175 | self.n_layers = n_layers 176 | self.kernel_size = kernel_size 177 | self.p_dropout = p_dropout 178 | self.proximal_bias = proximal_bias 179 | self.proximal_init = proximal_init 180 | 181 | self.drop = nn.Dropout(p_dropout) 182 | # self.self_attn_layers = nn.ModuleList() 183 | # self.norm_layers_0 = nn.ModuleList() 184 | self.ffn_layers = nn.ModuleList() 185 | self.norm_layers_1 = nn.ModuleList() 186 | for i in range(self.n_layers): 187 | # self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 188 | # self.norm_layers_0.append(LayerNorm(hidden_channels)) 189 | self.ffn_layers.append( 190 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 191 
| self.norm_layers_1.append(LayerNorm(hidden_channels)) 192 | 193 | def forward(self, x, x_mask): 194 | """ 195 | x: decoder input 196 | h: encoder output 197 | """ 198 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 199 | x = x * x_mask 200 | for i in range(self.n_layers): 201 | # y = self.self_attn_layers[i](x, x, self_attn_mask) 202 | # y = self.drop(y) 203 | # x = self.norm_layers_0[i](x + y) 204 | 205 | y = self.ffn_layers[i](x, x_mask) 206 | y = self.drop(y) 207 | x = self.norm_layers_1[i](x + y) 208 | x = x * x_mask 209 | return x 210 | 211 | 212 | class MultiHeadAttention(nn.Module): 213 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, 214 | block_length=None, proximal_bias=False, proximal_init=False): 215 | super().__init__() 216 | assert channels % n_heads == 0 217 | 218 | self.channels = channels 219 | self.out_channels = out_channels 220 | self.n_heads = n_heads 221 | self.p_dropout = p_dropout 222 | self.window_size = window_size 223 | self.heads_share = heads_share 224 | self.block_length = block_length 225 | self.proximal_bias = proximal_bias 226 | self.proximal_init = proximal_init 227 | self.attn = None 228 | 229 | self.k_channels = channels // n_heads 230 | self.conv_q = nn.Conv1d(channels, channels, 1) 231 | self.conv_k = nn.Conv1d(channels, channels, 1) 232 | self.conv_v = nn.Conv1d(channels, channels, 1) 233 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 234 | self.drop = nn.Dropout(p_dropout) 235 | 236 | if window_size is not None: 237 | n_heads_rel = 1 if heads_share else n_heads 238 | rel_stddev = self.k_channels ** -0.5 239 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 240 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 241 | 242 | nn.init.xavier_uniform_(self.conv_q.weight) 243 | nn.init.xavier_uniform_(self.conv_k.weight) 244 | nn.init.xavier_uniform_(self.conv_v.weight) 245 | if proximal_init: 246 | with torch.no_grad(): 247 | self.conv_k.weight.copy_(self.conv_q.weight) 248 | self.conv_k.bias.copy_(self.conv_q.bias) 249 | 250 | def forward(self, x, c, attn_mask=None): 251 | q = self.conv_q(x) 252 | k = self.conv_k(c) 253 | v = self.conv_v(c) 254 | 255 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 256 | 257 | x = self.conv_o(x) 258 | return x 259 | 260 | def attention(self, query, key, value, mask=None): 261 | # reshape [b, d, t] -> [b, n_h, t, d_k] 262 | b, d, t_s, t_t = (*key.size(), query.size(2)) 263 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 264 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 265 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 266 | 267 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 268 | if self.window_size is not None: 269 | assert t_s == t_t, "Relative attention is only available for self-attention." 270 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 271 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 272 | scores_local = self._relative_position_to_absolute_position(rel_logits) 273 | scores = scores + scores_local 274 | if self.proximal_bias: 275 | assert t_s == t_t, "Proximal bias is only available for self-attention." 
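            # Proximal bias adds -log1p(|i - j|) to the attention logits (see
            # _attention_bias_proximal below), nudging each position to attend
            # to its temporal neighbours.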
276 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 277 | if mask is not None: 278 | scores = scores.masked_fill(mask == 0, -1e4) 279 | if self.block_length is not None: 280 | assert t_s == t_t, "Local attention is only available for self-attention." 281 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 282 | scores = scores.masked_fill(block_mask == 0, -1e4) 283 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 284 | p_attn = self.drop(p_attn) 285 | output = torch.matmul(p_attn, value) 286 | if self.window_size is not None: 287 | relative_weights = self._absolute_position_to_relative_position(p_attn) 288 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 289 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 290 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 291 | return output, p_attn 292 | 293 | def _matmul_with_relative_values(self, x, y): 294 | """ 295 | x: [b, h, l, m] 296 | y: [h or 1, m, d] 297 | ret: [b, h, l, d] 298 | """ 299 | ret = torch.matmul(x, y.unsqueeze(0)) 300 | return ret 301 | 302 | def _matmul_with_relative_keys(self, x, y): 303 | """ 304 | x: [b, h, l, d] 305 | y: [h or 1, m, d] 306 | ret: [b, h, l, m] 307 | """ 308 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 309 | return ret 310 | 311 | def _get_relative_embeddings(self, relative_embeddings, length): 312 | max_relative_position = 2 * self.window_size + 1 313 | # Pad first before slice to avoid using cond ops. 314 | pad_length = max(length - (self.window_size + 1), 0) 315 | slice_start_position = max((self.window_size + 1) - length, 0) 316 | slice_end_position = slice_start_position + 2 * length - 1 317 | if pad_length > 0: 318 | padded_relative_embeddings = F.pad( 319 | relative_embeddings, 320 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 321 | else: 322 | padded_relative_embeddings = relative_embeddings 323 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 324 | return used_relative_embeddings 325 | 326 | def _relative_position_to_absolute_position(self, x): 327 | """ 328 | x: [b, h, l, 2*l-1] 329 | ret: [b, h, l, l] 330 | """ 331 | batch, heads, length, _ = x.size() 332 | # Concat columns of pad to shift from relative to absolute indexing. 333 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 334 | 335 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 336 | x_flat = x.view([batch, heads, length * 2 * length]) 337 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 338 | 339 | # Reshape and slice out the padded elements. 
340 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] 341 | return x_final 342 | 343 | def _absolute_position_to_relative_position(self, x): 344 | """ 345 | x: [b, h, l, l] 346 | ret: [b, h, l, 2*l-1] 347 | """ 348 | batch, heads, length, _ = x.size() 349 | # padd along column 350 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 351 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 352 | # add 0's in the beginning that will skew the elements after reshape 353 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 354 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 355 | return x_final 356 | 357 | def _attention_bias_proximal(self, length): 358 | """Bias for self-attention to encourage attention to close positions. 359 | Args: 360 | length: an integer scalar. 361 | Returns: 362 | a Tensor with shape [1, 1, length, length] 363 | """ 364 | r = torch.arange(length, dtype=torch.float32) 365 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 366 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 367 | 368 | 369 | class FFN(nn.Module): 370 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, 371 | causal=False): 372 | super().__init__() 373 | self.in_channels = in_channels 374 | self.out_channels = out_channels 375 | self.filter_channels = filter_channels 376 | self.kernel_size = kernel_size 377 | self.p_dropout = p_dropout 378 | self.activation = activation 379 | self.causal = causal 380 | 381 | if causal: 382 | self.padding = self._causal_padding 383 | else: 384 | self.padding = self._same_padding 385 | 386 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 387 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 388 | self.drop = nn.Dropout(p_dropout) 389 | 390 | def forward(self, x, x_mask): 391 | x = self.conv_1(self.padding(x * x_mask)) 392 | if self.activation == "gelu": 393 | x = x * torch.sigmoid(1.702 * x) 394 | else: 395 | x = torch.relu(x) 396 | x = self.drop(x) 397 | x = self.conv_2(self.padding(x * x_mask)) 398 | return x * x_mask 399 | 400 | def _causal_padding(self, x): 401 | if self.kernel_size == 1: 402 | return x 403 | pad_l = self.kernel_size - 1 404 | pad_r = 0 405 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 406 | x = F.pad(x, commons.convert_pad_shape(padding)) 407 | return x 408 | 409 | def _same_padding(self, x): 410 | if self.kernel_size == 1: 411 | return x 412 | pad_l = (self.kernel_size - 1) // 2 413 | pad_r = self.kernel_size // 2 414 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 415 | x = F.pad(x, commons.convert_pad_shape(padding)) 416 | return x 417 | -------------------------------------------------------------------------------- /whisper_ph_asr/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 
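# Note: convert_pad_shape() reverses the per-dimension pad list and flattens it
# into the last-dimension-first form expected by torch.nn.functional.pad, e.g.
# [[0, 0], [0, 0], [pad_l, pad_r]] -> [pad_l, pad_r, 0, 0, 0, 0], which pads
# only the final (time) axis of a [batch, channel, time] tensor.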
23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | # print("ret shape: ",ret.shape, ids_str) 51 | for i in range(x.size(0)): 52 | idx_str = ids_str[i] 53 | idx_end = idx_str + segment_size 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | return ret 56 | 57 | 58 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 59 | b, d, t = x.size() 60 | if x_lengths is None: 61 | x_lengths = t 62 | ids_str_max = x_lengths - segment_size - 1 63 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 64 | ret = slice_segments(x, ids_str, segment_size) 65 | return ret, ids_str 66 | 67 | 68 | def get_timing_signal_1d( 69 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 70 | position = torch.arange(length, dtype=torch.float) 71 | num_timescales = channels // 2 72 | log_timescale_increment = ( 73 | math.log(float(max_timescale) / float(min_timescale)) / 74 | (num_timescales - 1)) 75 | inv_timescales = min_timescale * torch.exp( 76 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 77 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 78 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 79 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 80 | signal = signal.view(1, channels, length) 81 | return signal 82 | 83 | 84 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 85 | b, channels, length = x.size() 86 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 87 | return x + signal.to(dtype=x.dtype, device=x.device) 88 | 89 | 90 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 91 | b, channels, length = x.size() 92 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 93 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 94 | 95 | 96 | def subsequent_mask(length): 97 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 98 | return mask 99 | 100 | 101 | @torch.jit.script 102 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 103 | n_channels_int = n_channels[0] 104 | in_act = input_a + input_b 105 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 106 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 107 | acts = t_act * s_act 108 | return acts 109 | 110 | 111 | def convert_pad_shape(pad_shape): 112 | l = pad_shape[::-1] 113 | pad_shape = [item for sublist in l for item in sublist] 114 | return pad_shape 115 | 116 | 117 | def shift_1d(x): 118 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 119 | return x 120 | 121 | 122 | def sequence_mask(length, max_length=None): 123 | if max_length is None: 124 | max_length = length.max() 125 | x = torch.arange(max_length, 
dtype=length.dtype, device=length.device) 126 | return x.unsqueeze(0) < length.unsqueeze(1) 127 | 128 | 129 | def generate_path(duration, mask): 130 | """ 131 | duration: [b, 1, t_x] 132 | mask: [b, 1, t_y, t_x] 133 | """ 134 | device = duration.device 135 | 136 | b, _, t_y, t_x = mask.shape 137 | cum_duration = torch.cumsum(duration, -1) 138 | 139 | cum_duration_flat = cum_duration.view(b * t_x) 140 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 141 | path = path.view(b, t_x, t_y) 142 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 143 | path = path.unsqueeze(1).transpose(2, 3) * mask 144 | return path 145 | 146 | 147 | def clip_grad_value_(parameters, clip_value, norm_type=2): 148 | if isinstance(parameters, torch.Tensor): 149 | parameters = [parameters] 150 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 151 | norm_type = float(norm_type) 152 | if clip_value is not None: 153 | clip_value = float(clip_value) 154 | 155 | total_norm = 0 156 | for p in parameters: 157 | param_norm = p.grad.data.norm(norm_type) 158 | total_norm += param_norm.item() ** norm_type 159 | if clip_value is not None: 160 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 161 | total_norm = total_norm ** (1. / norm_type) 162 | return total_norm 163 | -------------------------------------------------------------------------------- /whisper_ph_asr/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Infinity-INF/fast-phasr/4a01a60ad5805613d607d604c4f8b145e8282bcd/whisper_ph_asr/mel_filters.npz -------------------------------------------------------------------------------- /whisper_ph_asr/whisper_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor, nn 7 | from typing import Dict, Iterable, Optional 8 | from typing import Optional, Union 9 | import ffmpeg 10 | from functools import lru_cache 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 19 | 20 | 21 | class LayerNorm(nn.LayerNorm): 22 | def forward(self, x: Tensor) -> Tensor: 23 | return super().forward(x.float()).type(x.dtype) 24 | 25 | 26 | class Linear(nn.Linear): 27 | def forward(self, x: Tensor) -> Tensor: 28 | return F.linear( 29 | x, 30 | self.weight.to(x.dtype), 31 | None if self.bias is None else self.bias.to(x.dtype), 32 | ) 33 | 34 | 35 | class Conv1d(nn.Conv1d): 36 | def _conv_forward( 37 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 38 | ) -> Tensor: 39 | return super()._conv_forward( 40 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 41 | ) 42 | 43 | def sinusoids(length, channels, max_timescale=10000): 44 | """Returns sinusoids for positional embedding""" 45 | assert channels % 2 == 0 46 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 47 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 48 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 49 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 50 | 51 | 52 | class MultiHeadAttention(nn.Module): 53 | def __init__(self, n_state: int, n_head: int): 54 | 
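        # Whisper-style attention block: separate query/key/value/out projections,
        # with no bias on the key projection. The 1/sqrt(d_k) scaling happens in
        # qkv_attention() below, where q and k are each scaled by
        # (n_state // n_head) ** -0.25.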
super().__init__() 55 | self.n_head = n_head 56 | self.query = Linear(n_state, n_state) 57 | self.key = Linear(n_state, n_state, bias=False) 58 | self.value = Linear(n_state, n_state) 59 | self.out = Linear(n_state, n_state) 60 | 61 | def forward( 62 | self, 63 | x: Tensor, 64 | xa: Optional[Tensor] = None, 65 | mask: Optional[Tensor] = None, 66 | kv_cache: Optional[dict] = None, 67 | ): 68 | q = self.query(x) 69 | 70 | if kv_cache is None or xa is None or self.key not in kv_cache: 71 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 72 | # otherwise, perform key/value projections for self- or cross-attention as usual. 73 | k = self.key(x if xa is None else xa) 74 | v = self.value(x if xa is None else xa) 75 | else: 76 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 77 | k = kv_cache[self.key] 78 | v = kv_cache[self.value] 79 | 80 | wv, qk = self.qkv_attention(q, k, v, mask) 81 | return self.out(wv), qk 82 | 83 | def qkv_attention( 84 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None 85 | ): 86 | n_batch, n_ctx, n_state = q.shape 87 | scale = (n_state // self.n_head) ** -0.25 88 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 89 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 90 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 91 | 92 | qk = q @ k 93 | if mask is not None: 94 | qk = qk + mask[:n_ctx, :n_ctx] 95 | qk = qk.float() 96 | 97 | w = F.softmax(qk, dim=-1).to(q.dtype) 98 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() 99 | 100 | 101 | class ResidualAttentionBlock(nn.Module): 102 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 103 | super().__init__() 104 | 105 | self.attn = MultiHeadAttention(n_state, n_head) 106 | self.attn_ln = LayerNorm(n_state) 107 | 108 | self.cross_attn = ( 109 | MultiHeadAttention(n_state, n_head) if cross_attention else None 110 | ) 111 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 112 | 113 | n_mlp = n_state * 4 114 | self.mlp = nn.Sequential( 115 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) 116 | ) 117 | self.mlp_ln = LayerNorm(n_state) 118 | 119 | def forward( 120 | self, 121 | x: Tensor, 122 | xa: Optional[Tensor] = None, 123 | mask: Optional[Tensor] = None, 124 | kv_cache: Optional[dict] = None, 125 | ): 126 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] 127 | if self.cross_attn: 128 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] 129 | x = x + self.mlp(self.mlp_ln(x)) 130 | return x 131 | 132 | 133 | class AudioEncoder(nn.Module): 134 | def __init__( 135 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 136 | ): 137 | super().__init__() 138 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 139 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 140 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 141 | 142 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 143 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 144 | ) 145 | self.ln_post = LayerNorm(n_state) 146 | 147 | def forward(self, x: Tensor): 148 | """ 149 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 150 | the mel spectrogram of the audio 151 | """ 152 | x = F.gelu(self.conv1(x)) 153 | x = F.gelu(self.conv2(x)) 154 | x = x.permute(0, 2, 1) 155 | 156 | assert 
x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 157 | x = (x + self.positional_embedding).to(x.dtype) 158 | 159 | for block in self.blocks: 160 | x = block(x) 161 | 162 | x = self.ln_post(x) 163 | return x 164 | 165 | 166 | def load_audio(file: str, sr: int = SAMPLE_RATE): 167 | """ 168 | Open an audio file and read as mono waveform, resampling as necessary 169 | 170 | Parameters 171 | ---------- 172 | file: str 173 | The audio file to open 174 | 175 | sr: int 176 | The sample rate to resample the audio if necessary 177 | 178 | Returns 179 | ------- 180 | A NumPy array containing the audio waveform, in float32 dtype. 181 | """ 182 | try: 183 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 184 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 185 | out, _ = ( 186 | ffmpeg.input(file, threads=0) 187 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 188 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 189 | ) 190 | except ffmpeg.Error as e: 191 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 192 | 193 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 194 | 195 | 196 | 197 | @lru_cache(maxsize=None) 198 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 199 | """ 200 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 201 | Allows decoupling librosa dependency; saved using: 202 | 203 | np.savez_compressed( 204 | "mel_filters.npz", 205 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 206 | ) 207 | """ 208 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 209 | with np.load( 210 | os.path.join(os.path.dirname(__file__), "mel_filters.npz") 211 | ) as f: 212 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 213 | 214 | 215 | 216 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 217 | """ 218 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
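    Accepts either a torch.Tensor or a NumPy array and trims or zero-pads along `axis`.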
219 | """ 220 | if torch.is_tensor(array): 221 | if array.shape[axis] > length: 222 | array = array.index_select( 223 | dim=axis, index=torch.arange(length, device=array.device) 224 | ) 225 | 226 | if array.shape[axis] < length: 227 | pad_widths = [(0, 0)] * array.ndim 228 | pad_widths[axis] = (0, length - array.shape[axis]) 229 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 230 | else: 231 | if array.shape[axis] > length: 232 | array = array.take(indices=range(length), axis=axis) 233 | 234 | if array.shape[axis] < length: 235 | pad_widths = [(0, 0)] * array.ndim 236 | pad_widths[axis] = (0, length - array.shape[axis]) 237 | array = np.pad(array, pad_widths) 238 | 239 | return array 240 | 241 | 242 | def log_mel_spectrogram( 243 | audio: Union[str, np.ndarray, torch.Tensor], 244 | n_mels: int = N_MELS, 245 | padding: int = 0, 246 | device: Optional[Union[str, torch.device]] = None, 247 | ): 248 | """ 249 | Compute the log-Mel spectrogram of 250 | 251 | Parameters 252 | ---------- 253 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 254 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 255 | 256 | n_mels: int 257 | The number of Mel-frequency filters, only 80 is supported 258 | 259 | padding: int 260 | Number of zero samples to pad to the right 261 | 262 | device: Optional[Union[str, torch.device]] 263 | If given, the audio tensor is moved to this device before STFT 264 | 265 | Returns 266 | ------- 267 | torch.Tensor, shape = (80, n_frames) 268 | A Tensor that contains the Mel spectrogram 269 | """ 270 | if not torch.is_tensor(audio): 271 | if isinstance(audio, str): 272 | audio = load_audio(audio) 273 | audio = torch.from_numpy(audio) 274 | 275 | if device is not None: 276 | audio = audio.to(device) 277 | if padding > 0: 278 | audio = F.pad(audio, (0, padding)) 279 | window = torch.hann_window(N_FFT).to(audio.device) 280 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 281 | magnitudes = stft[..., :-1].abs() ** 2 282 | 283 | filters = mel_filters(audio.device, n_mels) 284 | mel_spec = filters @ magnitudes 285 | 286 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 287 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 288 | log_spec = (log_spec + 4.0) / 4.0 289 | return log_spec 290 | --------------------------------------------------------------------------------