├── results.png
├── models
│   ├── mel_filters.npz
│   ├── transformer_wrapper.py
│   ├── transformer_config.py
│   ├── whisper_wrapper.py
│   └── whisper_ni_predictors.py
├── requirements.txt
├── .gitattributes
├── checkpoints
│   ├── multi_head_model.pt
│   └── single_head_model.pt
├── README.md
└── get_score.py

/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leto19/WhiSQA/HEAD/results.png
--------------------------------------------------------------------------------
/models/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leto19/WhiSQA/HEAD/models/mel_filters.npz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.4
2 | torch==2.1.0
3 | torchaudio==2.1.0
4 | torchinfo==1.8.0
5 | transformers==4.35.0
6 | soundfile
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | checkpoints/multi_head_model.pt filter=lfs diff=lfs merge=lfs -text
2 | checkpoints/single_head_model.pt filter=lfs diff=lfs merge=lfs -text
3 | 
--------------------------------------------------------------------------------
/checkpoints/multi_head_model.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:448e4d528ed13a1c335486a8251b3b6681d596549f9c5d30592abffbe9550d0f
3 | size 366158486
4 | 
--------------------------------------------------------------------------------
/checkpoints/single_head_model.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2da43d59294c358bdd3f450d0aef5ce57ad20bed3beb34657ef3a795a4766810
3 | size 364031990
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper-based Speech Quality Assessment (WhiSQA)
2 | 
3 | Usage:
4 | `python3 get_score.py /path/to/mono_16k_wav_file.wav`
5 | 
6 | Requires `git-lfs` to fetch the model checkpoints.
7 | 
8 | ![Results](results.png)
9 | 
--------------------------------------------------------------------------------
/get_score.py:
--------------------------------------------------------------------------------
1 | from models.whisper_ni_predictors import whisperMetricPredictorEncoderLayersTransformerSmall, whisperMetricPredictorEncoderLayersTransformerSmalldim
2 | import sys
3 | import torchaudio
4 | import argparse
5 | import torch
6 | 
7 | def get_score(audio_file: str, model_type: str) -> torch.Tensor:
8 |     """
9 |     Compute a quality score for a given audio file.
10 | 
11 |     Args:
12 |         audio_file (str): Path to the audio file; must be 16 kHz sample rate and mono.
13 |         model_type (str): "single" for a single MOS prediction (more accurate) or "multi" for multidimensional output [MOS, Noisiness, Coloration, Discontinuity and Loudness].
14 | 
15 |     Returns:
16 |         score (torch.Tensor): either the MOS score or MOS + speech dimension scores
17 |     """
18 |     if torch.cuda.is_available():
19 |         device = torch.device("cuda")
20 |     elif torch.backends.mps.is_available():
21 |         device = torch.device("mps")  # for Apple Silicon (M1) Macs
22 |     else:
23 |         device = torch.device("cpu")  # may be slow!
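    # The checkpoint paths below are relative, so the script is expected to be
    # run from the repository root, e.g. (with a hypothetical input file):
    #   python3 get_score.py my_speech_16k_mono.wav --model_type multi
    # Both models end in a sigmoid, so raw outputs lie in (0, 1); the __main__
    # block below rescales them by 5 into the usual quality-rating range.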
24 | 25 | if model_type == "single": 26 | model = whisperMetricPredictorEncoderLayersTransformerSmall() 27 | model.load_state_dict(torch.load("checkpoints/single_head_model.pt",map_location=device)) 28 | elif model_type == "multi": 29 | model = whisperMetricPredictorEncoderLayersTransformerSmalldim() 30 | model.load_state_dict(torch.load("checkpoints/multi_head_model.pt",map_location=device)) 31 | else: 32 | raise ValueError("Model type not supported") 33 | 34 | model.eval() 35 | model.to(device) 36 | waveform, sample_rate = torchaudio.load(audio_file) 37 | 38 | #check channels 39 | if waveform.shape[0] != 1: 40 | raise ValueError("Number of input channels must be 1") 41 | 42 | # Check sample rate 43 | if sample_rate != 16000: 44 | raise ValueError("Sample rate must be 16000") 45 | 46 | waveform = waveform.to(device) 47 | score = model(waveform) 48 | if model_type == "multi": 49 | score = score.squeeze(0) 50 | return score 51 | 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser(description="Get a score for a given audio file") 56 | parser.add_argument("audio_file", type=str, help="Path to the audio file") 57 | parser.add_argument("--model_type", type=str, help="Single headed MOS or multidimension [MOS,Noisiness, Coloration,Discontinuity and Loudness]", default="single") 58 | args = parser.parse_args() 59 | 60 | 61 | score = get_score(args.audio_file, args.model_type) 62 | print(score.shape) 63 | if args.model_type == "single": 64 | print("MOS", score.item() * 5) 65 | else: 66 | mos = score[0].item() * 5 67 | noisiness = score[1].item() * 5 68 | coloration = score[2].item() * 5 69 | discontinuity = score[3].item() * 5 70 | loudness = score[4].item() * 5 71 | print("MOS", mos) 72 | print("Noisiness", noisiness) 73 | print("Coloration", coloration) 74 | print("Discontinuity", discontinuity) 75 | print("Loudness", loudness) 76 | 77 | sys.exit(0) 78 | -------------------------------------------------------------------------------- /models/transformer_wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import Tensor, nn 3 | import torch 4 | 5 | try: 6 | from transformer_config import Config 7 | except: 8 | from models.transformer_config import Config 9 | 10 | class TransformerWrapper(nn.Module): 11 | 12 | def __init__(self, config: Config): 13 | super().__init__() 14 | 15 | self.config = config 16 | # Normalization. 17 | self.norm = nn.BatchNorm1d(config.dim_input) 18 | 19 | # Position encoding. 20 | # if config.xlsr_name == "hubert_encoder" or config.xlsr_name == "hubert_full" or config.xlsr_name == "whisper_full": 21 | # self.position_encoding = PositionalEncodingVariable(config) 22 | # else: 23 | # self.position_encoding = PositionalEncoding(config) 24 | self.position_encoding = PositionalEncoding(config) 25 | 26 | # Down-projection to transformer dim. 27 | self.linear_proj = nn.Linear( 28 | in_features=config.dim_input, 29 | out_features=config.dim_transformer, 30 | ) 31 | self.linear_proj_drop = nn.Dropout(config.dropout) 32 | 33 | # Transformer encoder. 
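        # For the checkpoints shipped with WhiSQA the config supplies
        # dim_transformer=256, nhead_transformer=4 and nlayers_transformer=4,
        # so each layer below is 4-head self-attention with a 512-dim
        # (dim_transformer * 2) feed-forward block, operating batch-first.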
34 | encoder_layer = nn.TransformerEncoderLayer( 35 | d_model=config.dim_transformer, 36 | dim_feedforward=config.dim_transformer*2, 37 | nhead=config.nhead_transformer, 38 | batch_first=True, 39 | dropout=config.dropout, 40 | ) 41 | self.transformer_encoder = nn.TransformerEncoder( 42 | encoder_layer=encoder_layer, 43 | num_layers=config.nlayers_transformer, 44 | ) 45 | self.dropout = nn.Dropout(config.dropout) 46 | 47 | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: 48 | 49 | # Normalization. 50 | # Transform from (N, L, C) to (N, C, L) and back. 51 | x = self.norm(x.transpose(-1,-2)).transpose(-1,-2) 52 | 53 | # Linear projection down to transformer dim. 54 | x = self.linear_proj(x) 55 | x = self.linear_proj_drop(x) 56 | 57 | # Position encoding + transformer. 58 | x = self.position_encoding(x) 59 | x = self.transformer_encoder(x, mask) 60 | 61 | x = self.dropout(x) 62 | 63 | return x 64 | 65 | 66 | class PositionalEncoding(nn.Module): 67 | 68 | def __init__(self, config: Config): 69 | super().__init__() 70 | 71 | d_model: int = config.dim_transformer 72 | seq_len: int = config.feat_seq_len 73 | position = torch.arange(seq_len).unsqueeze(1) 74 | div_term = torch.exp(torch.arange(0, d_model, 2) 75 | * (-math.log(2*seq_len) / d_model)) 76 | pe = torch.zeros(1, seq_len, d_model) 77 | pe[0, :, 0::2] = torch.sin(position * div_term) 78 | pe[0, :, 1::2] = torch.cos(position * div_term) 79 | self.register_buffer('pe', pe) 80 | 81 | def forward(self, x: Tensor) -> Tensor: 82 | """ 83 | Args: 84 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 85 | """ 86 | x = x + self.pe.expand(x.shape) 87 | return x 88 | 89 | class PositionalEncodingVariable(nn.Module): 90 | """ 91 | Positional encoding module for variable-length sequences. 92 | 93 | Args: 94 | config (Config): Configuration object containing model parameters. 95 | 96 | Attributes: 97 | pe (Tensor): Positional encoding tensor. 98 | 99 | """ 100 | 101 | def __init__(self, config: Config): 102 | super().__init__() 103 | 104 | d_model: int = config.dim_transformer 105 | seq_len: int = config.feat_seq_len 106 | position = torch.arange(seq_len).unsqueeze(1) 107 | div_term = torch.exp(torch.arange(0, d_model, 2) 108 | * (-math.log(2*seq_len) / d_model)) 109 | pe = torch.zeros(1, seq_len, d_model) 110 | pe[0, :, 0::2] = torch.sin(position * div_term) 111 | pe[0, :, 1::2] = torch.cos(position * div_term) 112 | #self.register_buffer('pe', pe) 113 | 114 | def forward(self, x: Tensor) -> Tensor: 115 | """ 116 | Apply positional encoding to the input tensor. 117 | 118 | Args: 119 | x (Tensor): Input tensor of shape [seq_len, batch_size, embedding_dim] 120 | 121 | Returns: 122 | Tensor: Output tensor with positional encoding applied. 
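
        Note:
            Unlike PositionalEncoding above, the table is rebuilt on every call
            from the current sequence length (x.shape[1]); the model dimension
            is hard-coded to 256 in this implementation.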
123 | 124 | """ 125 | d_model: int = 256 126 | seq_len: int = x.shape[1] 127 | position = torch.arange(seq_len).unsqueeze(1) 128 | div_term = torch.exp(torch.arange(0, d_model, 2) 129 | * (-math.log(2*seq_len) / d_model)) 130 | pe = torch.zeros(1, seq_len, d_model).cuda() 131 | pe[0, :, 0::2] = torch.sin(position * div_term) 132 | pe[0, :, 1::2] = torch.cos(position * div_term) 133 | self.register_buffer('pe', pe) 134 | self.pe.to(x.device) 135 | 136 | x = x + self.pe.expand(x.shape) 137 | return x 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /models/transformer_config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import torch 3 | from torch.nn.functional import pad 4 | 5 | class Input(Enum): 6 | MFCC = 0 7 | XLSR = 1 8 | 9 | class CenterCrop(torch.nn.Module): 10 | def __init__(self, seq_len: int) -> None: 11 | super().__init__() 12 | self.seq_len = seq_len 13 | 14 | def forward(self, x: torch.Tensor): 15 | # Center crop. 16 | unsqueezed = False 17 | if x.dim() == 2: 18 | unsqueezed = True 19 | x = x.unsqueeze(0) 20 | assert x.dim() == 3 # N, L, C 21 | 22 | if x.size(1) > self.seq_len: 23 | center_start_idx = int(x.size(1) / 2 - self.seq_len / 2) 24 | start_idx = center_start_idx 25 | end_idx = start_idx + self.seq_len 26 | x = x[:, start_idx:end_idx, :] 27 | if x.size(1) < self.seq_len: 28 | to_pad = self.seq_len - x.size(1) 29 | # Pad the end of sequence dimension. 30 | x = pad(x, (0,0,0,to_pad,0,0), mode="constant", value=0.0) 31 | 32 | if unsqueezed: 33 | x = x.squeeze(0) 34 | 35 | return x 36 | 37 | class Config: 38 | 39 | name: str = None 40 | input: Input = None 41 | feat_seq_len: int = None 42 | dim_input: int = None 43 | dim_transformer: int = None 44 | dim_head_in: int = None 45 | dim_head_out: int = None 46 | 47 | def __init__( 48 | self, 49 | name: str, 50 | input: Input, 51 | feat_seq_len: int, 52 | dim_transformer: int = None, 53 | xlsr_name: str = None, 54 | nhead_transformer: int = 4, 55 | nlayers_transformer: int = 2, 56 | ): 57 | if input == Input.MFCC: 58 | xlsr_name = None 59 | 60 | # Check valid parameters. 61 | assert feat_seq_len > 0, "feat_seq_len must be positive." 62 | 63 | # Save parameters. 64 | self.name = name 65 | self.input = input 66 | self.feat_seq_len = feat_seq_len 67 | self.dim_transformer = dim_transformer 68 | self.xlsr_name = xlsr_name 69 | self.nhead_transformer = nhead_transformer 70 | self.nlayers_transformer = nlayers_transformer 71 | if xlsr_name is not None: 72 | # From XLS-R paper Table 2: Model architectures. 
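            # _b is the number of transformer blocks (-1 acts as a placeholder
            # for the hubert/whisper wrappers) and _h is the feature dimension
            # of the upstream representation, which becomes dim_input for the
            # TransformerWrapper. The whisper_* entries assume Whisper-small,
            # whose encoder states are 768-dimensional.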
73 | if xlsr_name == "wav2vec2-xls-r-300m": 74 | _b = 24 75 | _h = 1024 76 | elif xlsr_name == "wav2vec2-xls-r-1b": 77 | _b = 48 78 | _h = 1280 79 | elif xlsr_name == "wav2vec2-xls-r-2b": 80 | _b = 48 81 | _h = 1920 82 | elif xlsr_name == "hubert_encoder": 83 | _b = -1 84 | _h = 512 85 | elif xlsr_name == "hubert_encoder_t": 86 | _b = -1 87 | _h = 384 88 | elif xlsr_name == "hubert_full": 89 | _b = -1 90 | _h = 768 91 | elif xlsr_name == "hubert_full_t": 92 | _b = -1 93 | _h = 384 94 | elif xlsr_name == "whisper_encoder": 95 | _b = -1 96 | _h = 768 97 | elif xlsr_name == "whisper_encoder_ref": 98 | _b = -1 99 | _h = 768*2 100 | elif xlsr_name == "whisper_encoder_t": 101 | _b = -1 102 | _h = 1500 103 | elif xlsr_name == "whisper_full": 104 | _b = -1 105 | _h = 768 106 | elif xlsr_name == "whisper_full_t": 107 | _b = -1 108 | _h = 384 109 | self.xlsr_layers = _b + 1 # +1 for CNN activation "layer0" 110 | self.dim_input = _h 111 | else: 112 | if self.feat_seq_len == 80: #handle transposed mfcc 113 | self.xlsr_layers = None 114 | self.dim_input = 3000 115 | else: 116 | self.xlsr_layers = None 117 | self.dim_input = 80 # MFCC 118 | 119 | self.dim_head_in = self.dim_transformer # * self.feat_seq_len 120 | self.dim_head_out = 1 121 | 122 | self.dropout = 0.1 # TODO 123 | 124 | 125 | # Length of feature frame window. 126 | FEAT_SEQ_LEN = 256 127 | 128 | 129 | ####################### TRANSFORMER_32DEEP_CONFIG #################### 130 | MFCC_TRANSFORMER_32DEEP_CONFIG = Config( 131 | "MFCC_TRANSFORMER_32DEEP_CONFIG", 132 | Input.MFCC, 133 | feat_seq_len=FEAT_SEQ_LEN, 134 | dim_transformer=256, 135 | xlsr_name=None, 136 | nhead_transformer=4, 137 | nlayers_transformer=4, 138 | ) 139 | 140 | HUBERT_ENCODER_CONFIG= Config( 141 | "HUBERT_ENCODER_CONFIG", 142 | Input.XLSR, 143 | feat_seq_len=256, 144 | dim_transformer=256, 145 | xlsr_name="hubert_encoder", 146 | nhead_transformer=4, 147 | nlayers_transformer=4, 148 | ) 149 | 150 | 151 | 152 | HUBERT_ENCODER_CONFIG_T = Config( 153 | "HUBERT_ENCODER_CONFIG", 154 | Input.XLSR, 155 | feat_seq_len=512, 156 | dim_transformer=256, 157 | xlsr_name="hubert_encoder_t", 158 | nhead_transformer=4, 159 | nlayers_transformer=4, 160 | ) 161 | 162 | 163 | 164 | 165 | 166 | HUBERT_FULL_CONFIG = Config( 167 | "HUBERT_FULL_CONFIG", 168 | Input.XLSR, 169 | feat_seq_len=768, 170 | dim_transformer=256, 171 | xlsr_name="hubert_full", 172 | nhead_transformer=4, 173 | nlayers_transformer=4, 174 | ) 175 | 176 | 177 | WHISPER_ENCODER_CONFIG = Config( 178 | "WHISPER_ENCODER_CONFIG", 179 | Input.XLSR, 180 | feat_seq_len=1500, 181 | dim_transformer=768, 182 | xlsr_name="whisper_encoder", 183 | nhead_transformer=4, 184 | nlayers_transformer=4, 185 | ) 186 | 187 | WHISPER_ENCODER_CONFIG = Config( 188 | "WHISPER_ENCODER_CONFIG_REF", 189 | Input.XLSR, 190 | feat_seq_len=1500, 191 | dim_transformer=768, 192 | xlsr_name="whisper_encoder", 193 | nhead_transformer=4, 194 | nlayers_transformer=4, 195 | ) 196 | 197 | 198 | WHISPER_ENCODER_CONFIG_MEDIUM = Config( 199 | "WHISPER_ENCODER_CONFIG", 200 | Input.XLSR, 201 | feat_seq_len=1500, 202 | dim_transformer=512, 203 | xlsr_name="whisper_encoder", 204 | nhead_transformer=4, 205 | nlayers_transformer=4, 206 | ) 207 | 208 | 209 | 210 | WHISPER_ENCODER_CONFIG_SMALL = Config( 211 | "WHISPER_ENCODER_CONFIG", 212 | Input.XLSR, 213 | feat_seq_len=1500, 214 | dim_transformer=256, 215 | xlsr_name="whisper_encoder", 216 | nhead_transformer=4, 217 | nlayers_transformer=4, 218 | ) 219 | WHISPER_ENCODER_CONFIG_SMALL_T = Config( 220 | 
"WHISPER_ENCODER_CONFIG", 221 | Input.XLSR, 222 | feat_seq_len=768, 223 | dim_transformer=256, 224 | xlsr_name="whisper_encoder", 225 | nhead_transformer=4, 226 | nlayers_transformer=4, 227 | ) 228 | 229 | WHISPER_ENCODER_CONFIG_MEL = Config( 230 | "WHISPER_ENCODER_CONFIG", 231 | Input.MFCC, 232 | feat_seq_len=3000, 233 | dim_transformer=256, 234 | xlsr_name="whisper_encoder", 235 | nhead_transformer=4, 236 | nlayers_transformer=4, 237 | ) 238 | 239 | 240 | 241 | WHISPER_ENCODER_CONFIG_SMALLER = Config( 242 | "WHISPER_ENCODER_CONFIG", 243 | Input.XLSR, 244 | feat_seq_len=1500, 245 | dim_transformer=128, 246 | xlsr_name="whisper_encoder", 247 | nhead_transformer=4, 248 | nlayers_transformer=4, 249 | ) 250 | 251 | WHISPER_ENCODER_CONFIG_SMALLER_T = Config( 252 | "WHISPER_ENCODER_CONFIG", 253 | Input.XLSR, 254 | feat_seq_len=768, 255 | dim_transformer=128, 256 | xlsr_name="whisper_encoder_t", 257 | nhead_transformer=4, 258 | nlayers_transformer=4, 259 | ) 260 | 261 | WHISPER_FULL_CONFIG_SMALL= Config( 262 | "WHISPER_FULL_CONFIG", 263 | Input.XLSR, 264 | feat_seq_len=768, 265 | dim_transformer=256, 266 | xlsr_name="whisper_full", 267 | nhead_transformer=4, 268 | nlayers_transformer=4, 269 | ) 270 | 271 | 272 | XLSR_300M_TRANSFORMER_32DEEP_CONFIG = Config( 273 | "XLSR_300M_TRANSFORMER_32DEEP_CONFIG", 274 | Input.XLSR, 275 | feat_seq_len=FEAT_SEQ_LEN, 276 | dim_transformer=32, 277 | xlsr_name="wav2vec2-xls-r-300m", 278 | nhead_transformer=4, 279 | nlayers_transformer=4, 280 | ) 281 | 282 | XLSR_1B_TRANSFORMER_32DEEP_CONFIG = Config( 283 | "XLSR_1B_TRANSFORMER_32DEEP_CONFIG", 284 | Input.XLSR, 285 | feat_seq_len=FEAT_SEQ_LEN, 286 | dim_transformer=32, 287 | xlsr_name="wav2vec2-xls-r-1b", 288 | nhead_transformer=4, 289 | nlayers_transformer=4, 290 | ) 291 | 292 | XLSR_2B_TRANSFORMER_32DEEP_CONFIG = Config( 293 | "XLSR_2B_TRANSFORMER_32DEEP_CONFIG", 294 | Input.XLSR, 295 | feat_seq_len=FEAT_SEQ_LEN, 296 | dim_transformer=32, 297 | xlsr_name="wav2vec2-xls-r-2b", 298 | nhead_transformer=4, 299 | nlayers_transformer=4, 300 | ) -------------------------------------------------------------------------------- /models/whisper_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers import WhisperModel, WhisperFeatureExtractor, WhisperForConditionalGeneration 5 | from functools import lru_cache 6 | from typing import Optional, Union 7 | import numpy as np 8 | 9 | 10 | SAMPLE_RATE = 16000 11 | N_FFT = 400 12 | N_MELS = 80 13 | HOP_LENGTH = 160 14 | CHUNK_LENGTH = 30 15 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 16 | #print("N_SAMPLES: ",N_SAMPLES) 17 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 18 | 19 | 20 | def log_mel_spectrogram( 21 | audio: Union[str, np.ndarray, torch.Tensor], 22 | n_mels: int = N_MELS, 23 | padding: int = 0, 24 | device: Optional[Union[str, torch.device]] = None, 25 | ): 26 | """ 27 | Compute the log-Mel spectrogram of 28 | 29 | Parameters 30 | ---------- 31 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 32 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 33 | 34 | n_mels: int 35 | The number of Mel-frequency filters, only 80 is supported 36 | 37 | padding: int 38 | Number of zero samples to pad to the right 39 | 40 | device: Optional[Union[str, torch.device]] 41 | If given, the audio tensor is moved to this device 
before STFT 42 | 43 | Returns 44 | ------- 45 | torch.Tensor, shape = (80, n_frames) 46 | A Tensor that contains the Mel spectrogram 47 | """ 48 | if device is not None: 49 | audio = audio.to(device) 50 | if padding > 0: 51 | audio = F.pad(audio, (0, padding)) 52 | window = torch.hann_window(N_FFT).to(audio.device) 53 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 54 | magnitudes = stft[..., :-1].abs() ** 2 55 | 56 | filters = mel_filters(audio.device, n_mels) 57 | mel_spec = filters @ magnitudes 58 | 59 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 60 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 61 | log_spec = (log_spec + 4.0) / 4.0 62 | return log_spec 63 | 64 | @lru_cache(maxsize=None) 65 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 66 | """ 67 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 68 | Allows decoupling librosa dependency; saved using: 69 | 70 | np.savez_compressed("mel_filters.npz",mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)) 71 | """ 72 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 73 | with np.load("models/mel_filters.npz",allow_pickle=True) as f: 74 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 75 | 76 | 77 | 78 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 79 | """ 80 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 81 | """ 82 | if torch.is_tensor(array): 83 | if array.shape[axis] > length: 84 | array = array.index_select( 85 | dim=axis, index=torch.arange(length, device=array.device) 86 | ) 87 | 88 | if array.shape[axis] < length: 89 | pad_widths = [(0, 0)] * array.ndim 90 | pad_widths[axis] = (0, length - array.shape[axis]) 91 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 92 | else: 93 | if array.shape[axis] > length: 94 | array = array.take(indices=range(length), axis=axis) 95 | 96 | if array.shape[axis] < length: 97 | pad_widths = [(0, 0)] * array.ndim 98 | pad_widths[axis] = (0, length - array.shape[axis]) 99 | array = np.pad(array, pad_widths) 100 | 101 | return array 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | class WhisperWrapper_full(nn.Module): 110 | def __init__(self, layer = None, use_feat_extractor = False, pretrained_model = None, num_layers = 12, *args, **kwargs): 111 | super().__init__(*args, **kwargs) 112 | 113 | # using layer = -1 returns all layers in form (1, time, feat_dim, layers) 114 | # otherwise single layer in form (1, time, feat_dim) 115 | 116 | self.num_layers = num_layers 117 | self.use_feat_extractor = use_feat_extractor 118 | if layer is None: 119 | self.layer = 12 120 | else: 121 | self.layer = layer 122 | 123 | # if use_feat_extractor: 124 | # self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small") 125 | if pretrained_model is None: 126 | self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") 127 | else: 128 | self.model = WhisperForConditionalGeneration.from_pretrained(pretrained_model) 129 | 130 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 131 | 132 | 133 | def forward(self, data): 134 | 135 | if self.use_feat_extractor: 136 | #print(data.shape) 137 | #print(type(data)) 138 | data = log_mel_spectrogram(data,padding=N_SAMPLES) 139 | data = pad_or_trim(data, length=3000).to(self.device) 140 | #print("feature shape: ",data.shape) 141 | #print("requires grad: ",data.requires_grad) 142 | 143 | outputs = self.model.generate( 144 | 
input_features = data, 145 | output_hidden_states = True, 146 | return_dict_in_generate = True 147 | ) 148 | #print(outputs.sequences) 149 | #print(outputs.decoder_hidden_states[0][0].shape) 150 | if self.layer == -1: 151 | decoder_hidden = [] 152 | for layer in range(self.num_layers): 153 | hidden = torch.stack([outputs.decoder_hidden_states[word][layer][:][:] for word in range(len(outputs.decoder_hidden_states))]) 154 | #hidden has dim ('word', batch, layer,feat_dim) 155 | hidden = hidden.permute(1,0,3,2) 156 | #hidden has dim (batch, 'word', feat_dim, layer) 157 | #print(layer,hidden.shape) 158 | decoder_hidden.append(hidden) 159 | decoder_hidden = torch.stack(decoder_hidden, dim = -1).squeeze(3) 160 | #print("decoder_hidden size: ",decoder_hidden.size()) 161 | elif self.layer == None: 162 | decoder_hidden = torch.stack([outputs.decoder_hidden_states[word][self.num_layers-1][0][0] for word in range(len(outputs.decoder_hidden_states))]) 163 | decoder_hidden = decoder_hidden.unsqueeze(0) 164 | else: 165 | decoder_hidden = torch.stack([outputs.decoder_hidden_states[word][self.layer][0][0] for word in range(len(outputs.decoder_hidden_states))]) 166 | decoder_hidden = decoder_hidden.unsqueeze(0) 167 | #print(f"decoder_hidden size: {decoder_hidden.size()}") 168 | # print(decoder_hidden.size()) 169 | #input(">>>") 170 | return decoder_hidden 171 | 172 | 173 | class WhisperWrapper_encoder(nn.Module): 174 | def __init__(self, layer = None, use_feat_extractor = False, pretrained_model = None, *args, **kwargs): 175 | super().__init__(*args, **kwargs) 176 | 177 | self.use_feat_extractor = use_feat_extractor 178 | self.layer = layer 179 | 180 | if not use_feat_extractor: 181 | self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small") 182 | if pretrained_model is None: 183 | model = WhisperModel.from_pretrained("openai/whisper-small") 184 | else: 185 | model = WhisperModel.from_pretrained(pretrained_model) 186 | self.model = model.encoder 187 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 188 | 189 | def forward(self, data): 190 | 191 | if self.use_feat_extractor: 192 | #print(data.shape) 193 | data_padded = pad_or_trim(data, length=N_SAMPLES).to(self.device) 194 | #print("data padded shape: ",data_padded.shape) 195 | data_feats = log_mel_spectrogram(data_padded) 196 | 197 | #print("feature shape after log mel: ",data_feats.shape) 198 | else: 199 | #print("data shape: ",data.shape) 200 | d_list = [] 201 | for d in data: 202 | d_list.append(d.to('cpu').tolist()) 203 | data = self.feature_extractor(d_list, sampling_rate = 16000, return_tensors = 'pt') 204 | #print(data) 205 | data_feats = data.input_features.to(self.device) 206 | #print("data shape after",data_feats.shape) 207 | if self.layer is None: 208 | data = self.model( 209 | input_features = data_feats, 210 | return_dict = True 211 | ) 212 | #print(data) 213 | data = data[0] 214 | #print(data.shape) 215 | elif self.layer == -1: 216 | data = self.model( 217 | input_features = data_feats, 218 | return_dict = True, 219 | output_hidden_states = True 220 | ) 221 | #print(data.hidden_states[0].shape) 222 | layers = [] 223 | for layer in range(len(data.hidden_states)): 224 | 225 | layers.append(data.hidden_states[layer]) 226 | data = torch.stack(layers, dim = -1) 227 | #print(data.shape) 228 | else: 229 | data = self.model( 230 | input_features = data_feats, 231 | return_dict = True, 232 | output_hidden_states = True 233 | ) 234 | data = data.hidden_states[self.layer] 235 | 236 | return data 237 | 238 | 
class WhisperWrapper_encoder_debug(nn.Module): 239 | def __init__(self, layer = None, use_feat_extractor = False, pretrained_model = None, *args, **kwargs): 240 | super().__init__(*args, **kwargs) 241 | 242 | self.use_feat_extractor = use_feat_extractor 243 | self.layer = layer 244 | 245 | if not use_feat_extractor: 246 | self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small") 247 | if pretrained_model is None: 248 | model = WhisperModel.from_pretrained("openai/whisper-small") 249 | else: 250 | model = WhisperModel.from_pretrained(pretrained_model) 251 | 252 | self.model = model.encoder 253 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 254 | 255 | def forward(self, data): 256 | 257 | if self.use_feat_extractor: 258 | print(data.shape) 259 | data_padded = pad_or_trim(data, length=N_SAMPLES).to(self.device) 260 | print("data padded shape: ",data_padded.shape) 261 | data_feats = log_mel_spectrogram(data_padded) 262 | 263 | print("feature shape after log mel: ",data_feats.shape) 264 | else: 265 | print("data shape: ",data.shape) 266 | d_list = [] 267 | for d in data: 268 | d_list.append(d.to('cpu').tolist()) 269 | data = self.feature_extractor(d_list, sampling_rate = 16000, return_tensors = 'pt') 270 | #print(data) 271 | data_feats = data.input_features.to(self.device) 272 | print("data shape after",data_feats.shape) 273 | if self.layer is None: 274 | data = self.model( 275 | input_features = data_feats, 276 | return_dict = True 277 | ) 278 | #print(data) 279 | data = data[0] 280 | print(data.shape) 281 | elif self.layer == -1: 282 | data = self.model( 283 | input_features = data_feats, 284 | return_dict = True, 285 | output_hidden_states = True 286 | ) 287 | #print(data.hidden_states[0].shape) 288 | layers = [] 289 | for layer in range(len(data.hidden_states)): 290 | 291 | layers.append(data.hidden_states[layer]) 292 | data = torch.stack(layers, dim = -1) 293 | #print(data.shape) 294 | else: 295 | data = self.model( 296 | input_features = data_feats, 297 | return_dict = True, 298 | output_hidden_states = True 299 | ) 300 | data = data.hidden_states[self.layer] 301 | 302 | return data, data_feats 303 | 304 | -------------------------------------------------------------------------------- /models/whisper_ni_predictors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import Tensor, nn 4 | try: 5 | from whisper_wrapper import WhisperWrapper_full,WhisperWrapper_encoder,pad_or_trim, log_mel_spectrogram 6 | from transformer_wrapper import TransformerWrapper 7 | from transformer_config import CenterCrop,Config,Input 8 | except: 9 | from models.whisper_wrapper import WhisperWrapper_full,WhisperWrapper_encoder, pad_or_trim, log_mel_spectrogram 10 | from models.transformer_wrapper import TransformerWrapper 11 | from models.transformer_config import CenterCrop,Config,Input 12 | 13 | class PoolAttFF(torch.nn.Module): 14 | ''' 15 | PoolAttFF: Attention-Pooling module with additonal feed-forward network. 
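
    Computes per-frame attention weights with a two-layer feed-forward network
    (linear1 -> ReLU -> dropout -> linear2), softmax-normalises them over the
    time axis, pools the sequence as the attention-weighted sum, and maps the
    pooled vector to a single scalar with linear3.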
16 | ''' 17 | def __init__(self, dim_head_in): 18 | super().__init__() 19 | 20 | self.linear1 = nn.Linear(dim_head_in, 2*dim_head_in) 21 | self.linear2 = nn.Linear(2*dim_head_in, 1) 22 | 23 | self.linear3 = nn.Linear(dim_head_in, 1) 24 | 25 | self.activation = F.relu 26 | self.dropout = nn.Dropout(0.1) 27 | 28 | def forward(self, x: Tensor): 29 | 30 | att = self.linear2(self.dropout(self.activation(self.linear1(x)))) 31 | att = att.transpose(2,1) 32 | att = F.softmax(att, dim=2) 33 | 34 | x = torch.bmm(att, x) 35 | 36 | x = x.squeeze(1) 37 | 38 | x = self.linear3(x) 39 | 40 | return x 41 | 42 | 43 | 44 | 45 | class whisperMetricPredictorEncoderTransformerSmall(nn.Module): 46 | """Transformer based varient on metric estimator 47 | 48 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 49 | """ 50 | def __init__( 51 | self, feat_seq=1500): 52 | super().__init__() 53 | self.norm_input = nn.BatchNorm1d(768) 54 | 55 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True) 56 | self.feat_extract.requires_grad_(False) 57 | 58 | self.config = Config( 59 | "WHISPER_ENCODER_CONFIG", 60 | Input.XLSR, 61 | feat_seq_len=feat_seq, 62 | dim_transformer=256, 63 | xlsr_name="whisper_encoder", 64 | nhead_transformer=4, 65 | nlayers_transformer=4, 66 | ) 67 | self.transformer = TransformerWrapper(self.config) 68 | 69 | 70 | 71 | self.attenPool = PoolAttFF(self.config.dim_transformer) 72 | 73 | self.sigmoid = nn.Sigmoid() 74 | 75 | def forward(self, x): 76 | 77 | out_feats = self.feat_extract(x) #whisper encoder returns (B, 1500, 512) 78 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 79 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 80 | out = self.attenPool(out) #attenPool returns (B, 1) 81 | out = self.sigmoid(out) #sigmoid returns (B, 1) 82 | return out 83 | 84 | 85 | class whisperMetricPredictorEncoderTransformerSmallT(nn.Module): 86 | """Transformer based varient on metric estimator 87 | 88 | based on 89 | """ 90 | def __init__( 91 | self, feat_seq=1500): 92 | super().__init__() 93 | self.norm_input = nn.BatchNorm1d(feat_seq) 94 | 95 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True) 96 | self.feat_extract.requires_grad_(False) 97 | 98 | self.config = Config( 99 | "WHISPER_ENCODER_CONFIG", 100 | Input.XLSR, 101 | feat_seq_len=768, 102 | dim_transformer=256, 103 | xlsr_name="whisper_encoder_t", 104 | nhead_transformer=4, 105 | nlayers_transformer=4, 106 | ) 107 | self.transformer = TransformerWrapper(self.config) 108 | 109 | 110 | 111 | self.attenPool = PoolAttFF(self.config.dim_transformer) 112 | 113 | self.sigmoid = nn.Sigmoid() 114 | 115 | def forward(self, x): 116 | 117 | out_feats = self.feat_extract(x) #whisper encoder returns (B, 1500, 512) 118 | out_feats = out_feats.permute(0,2,1) #normalize and permute back to (B, 1500, 512) 119 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 120 | 121 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 122 | out = self.attenPool(out) #attenPool returns (B, 1) 123 | out = self.sigmoid(out) #sigmoid returns (B, 1) 124 | return out 125 | 126 | class whisperMetricPredictorEncoderLayersTransformerSmall(nn.Module): 127 | """Transformer based varient on metric estimator 128 | 129 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 130 | """ 131 | def __init__( 132 | self, feat_seq=1500): 133 | super().__init__() 134 | self.norm_input = 
nn.BatchNorm1d(768) 135 | 136 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True, layer=-1) 137 | self.feat_extract.requires_grad_(False) 138 | self.layer_weights = nn.Parameter(torch.ones(13)) 139 | self.softmax = nn.Softmax(dim=0) 140 | 141 | self.config = Config( 142 | "WHISPER_ENCODER_CONFIG", 143 | Input.XLSR, 144 | feat_seq_len=feat_seq, 145 | dim_transformer=256, 146 | xlsr_name="whisper_encoder", 147 | nhead_transformer=4, 148 | nlayers_transformer=4, 149 | ) 150 | self.transformer = TransformerWrapper(self.config) 151 | 152 | 153 | 154 | self.attenPool = PoolAttFF(self.config.dim_transformer) 155 | 156 | self.sigmoid = nn.Sigmoid() 157 | 158 | def forward(self, x): 159 | 160 | out_feats = self.feat_extract(x) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 161 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 13 tensors 162 | #print(self.layer_weights) 163 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 164 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 165 | out = self.attenPool(out) #attenPool returns (B, 1) 166 | out = self.sigmoid(out) #sigmoid returns (B, 1) 167 | return out 168 | class whisperMetricPredictorEncoderLayersTransformerSmalldim(nn.Module): 169 | """Transformer based varient on metric estimator 170 | 171 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 172 | """ 173 | def __init__( 174 | self, feat_seq=1500): 175 | super().__init__() 176 | self.norm_input = nn.BatchNorm1d(768) 177 | 178 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True, layer=-1) 179 | self.feat_extract.requires_grad_(False) 180 | self.layer_weights = nn.Parameter(torch.ones(13)) 181 | self.softmax = nn.Softmax(dim=0) 182 | 183 | self.config = Config( 184 | "WHISPER_ENCODER_CONFIG", 185 | Input.XLSR, 186 | feat_seq_len=feat_seq, 187 | dim_transformer=256, 188 | xlsr_name="whisper_encoder", 189 | nhead_transformer=4, 190 | nlayers_transformer=4, 191 | ) 192 | self.transformer = TransformerWrapper(self.config) 193 | 194 | 195 | 196 | self.attenPool1 = PoolAttFF(self.config.dim_transformer) 197 | self.attenPool2 = PoolAttFF(self.config.dim_transformer) 198 | self.attenPool3 = PoolAttFF(self.config.dim_transformer) 199 | self.attenPool4 = PoolAttFF(self.config.dim_transformer) 200 | self.attenPool5 = PoolAttFF(self.config.dim_transformer) 201 | 202 | self.sigmoid = nn.Sigmoid() 203 | 204 | def forward(self, x): 205 | 206 | out_feats = self.feat_extract(x) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 207 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 13 tensors 208 | #print(self.layer_weights) 209 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 210 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 211 | 212 | out1 = self.attenPool1(out) #attenPool returns (B, 1) 213 | out1 = self.sigmoid(out1) #sigmoid returns (B, 1) 214 | out2 = self.attenPool2(out) #attenPool returns (B, 1) 215 | out2 = self.sigmoid(out2) 216 | 217 | out3 = self.attenPool3(out) #attenPool returns (B, 1) 218 | out3 = self.sigmoid(out3) 219 | 220 | out4 = self.attenPool4(out) #attenPool returns (B, 1) 221 | out4 = self.sigmoid(out4) 222 | 223 | out5 = self.attenPool5(out) #attenPool returns (B, 1) 224 | out5 = self.sigmoid(out5) 225 | 226 | #return all 5 outputs ona new dimension 227 | out = 
torch.stack([out1,out2,out3,out4,out5],dim=1) 228 | 229 | return out 230 | class whisperMetricPredictorEncoderLayersTransformerSmallRef(nn.Module): 231 | """Transformer based varient on metric estimator 232 | 233 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 234 | """ 235 | def __init__( 236 | self, feat_seq=1500): 237 | super().__init__() 238 | self.norm_input = nn.BatchNorm1d(768) 239 | self.norm_input_ref = nn.BatchNorm1d(768) 240 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True, layer=-1) 241 | self.feat_extract.requires_grad_(False) 242 | self.layer_weights = nn.Parameter(torch.ones(13)) 243 | self.layer_weights_ref = nn.Parameter(torch.ones(13)) 244 | self.softmax_ref = nn.Softmax(dim=0) 245 | self.softmax = nn.Softmax(dim=0) 246 | 247 | self.config = Config( 248 | "WHISPER_ENCODER_CONFIG_REF", 249 | Input.XLSR, 250 | feat_seq_len=feat_seq, 251 | dim_transformer=256, 252 | xlsr_name="whisper_encoder_ref", 253 | nhead_transformer=4, 254 | nlayers_transformer=4, 255 | ) 256 | self.transformer = TransformerWrapper(self.config) 257 | 258 | 259 | 260 | self.attenPool1 = PoolAttFF(self.config.dim_transformer) 261 | 262 | 263 | self.sigmoid = nn.Sigmoid() 264 | 265 | def forward(self, x, y): 266 | 267 | out_feats = self.feat_extract(x) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 268 | 269 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 13 tensors 270 | print(self.layer_weights) 271 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 272 | 273 | out_feats_ref = self.feat_extract(y) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 274 | out_feats_ref = out_feats_ref @ self.softmax_ref(self.layer_weights_ref) #weighted sum of the 13 tensors 275 | print(self.layer_weights_ref) 276 | out_feats_ref = self.norm_input_ref(out_feats_ref.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 277 | 278 | 279 | #concatenate the two inputs 280 | #print("out_feats",out_feats.shape) 281 | #print("out_feats_ref",out_feats_ref.shape) 282 | out_feats = torch.cat([out_feats,out_feats_ref],dim=2) 283 | #print(out_feats.shape) 284 | 285 | 286 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 287 | 288 | 289 | out1 = self.attenPool1(out) #attenPool returns (B, 1) 290 | out1 = self.sigmoid(out1) #sigmoid returns (B, 1) 291 | 292 | 293 | return out1#,out_feats 294 | class whisperMetricPredictorEncoderLayersTransformerSmallDimRef(nn.Module): 295 | """Transformer based varient on metric estimator 296 | 297 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 298 | """ 299 | def __init__( 300 | self, feat_seq=1500): 301 | super().__init__() 302 | self.norm_input = nn.BatchNorm1d(768) 303 | self.norm_input_ref = nn.BatchNorm1d(768) 304 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True, layer=-1) 305 | self.feat_extract.requires_grad_(False) 306 | self.layer_weights = nn.Parameter(torch.ones(13)) 307 | self.layer_weights_ref = nn.Parameter(torch.ones(13)) 308 | self.softmax_ref = nn.Softmax(dim=0) 309 | self.softmax = nn.Softmax(dim=0) 310 | 311 | self.config = Config( 312 | "WHISPER_ENCODER_CONFIG_REF", 313 | Input.XLSR, 314 | feat_seq_len=feat_seq, 315 | dim_transformer=256, 316 | xlsr_name="whisper_encoder_ref", 317 | nhead_transformer=4, 318 | nlayers_transformer=4, 319 | ) 320 | self.transformer = TransformerWrapper(self.config) 321 | 322 | 323 | 324 | self.attenPool1 = 
PoolAttFF(self.config.dim_transformer) 325 | self.attenPool2 = PoolAttFF(self.config.dim_transformer) 326 | self.attenPool3 = PoolAttFF(self.config.dim_transformer) 327 | self.attenPool4 = PoolAttFF(self.config.dim_transformer) 328 | self.attenPool5 = PoolAttFF(self.config.dim_transformer) 329 | 330 | self.sigmoid = nn.Sigmoid() 331 | 332 | def forward(self, x, y): 333 | 334 | out_feats = self.feat_extract(x) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 335 | 336 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 13 tensors 337 | print(self.layer_weights) 338 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 339 | 340 | out_feats_ref = self.feat_extract(y) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 341 | out_feats_ref = out_feats_ref @ self.softmax_ref(self.layer_weights_ref) #weighted sum of the 13 tensors 342 | print(self.layer_weights_ref) 343 | out_feats_ref = self.norm_input_ref(out_feats_ref.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 1500, 512) 344 | 345 | 346 | #concatenate the two inputs 347 | #print("out_feats",out_feats.shape) 348 | #print("out_feats_ref",out_feats_ref.shape) 349 | out_feats = torch.cat([out_feats,out_feats_ref],dim=2) 350 | #print(out_feats.shape) 351 | 352 | 353 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 354 | 355 | 356 | out1 = self.attenPool1(out) #attenPool returns (B, 1) 357 | out1 = self.sigmoid(out1) #sigmoid returns (B, 1) 358 | out2 = self.attenPool2(out) #attenPool returns (B, 1) 359 | out2 = self.sigmoid(out2) 360 | 361 | out3 = self.attenPool3(out) #attenPool returns (B, 1) 362 | out3 = self.sigmoid(out3) 363 | 364 | out4 = self.attenPool4(out) #attenPool returns (B, 1) 365 | out4 = self.sigmoid(out4) 366 | 367 | out5 = self.attenPool5(out) #attenPool returns (B, 1) 368 | out5 = self.sigmoid(out5) 369 | 370 | #return all 5 outputs ona new dimension 371 | out = torch.stack([out1,out2,out3,out4,out5],dim=1) 372 | 373 | return out#,out_feats 374 | 375 | 376 | 377 | class whisperMetricPredictorEncoderLayersTransformerSmallT(nn.Module): 378 | """Transformer based varient on metric estimator 379 | 380 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 381 | """ 382 | def __init__( 383 | self, feat_seq=1500): 384 | super().__init__() 385 | self.norm_input = nn.BatchNorm1d(feat_seq) 386 | 387 | self.feat_extract = WhisperWrapper_encoder(use_feat_extractor=True, layer=-1) 388 | self.feat_extract.requires_grad_(False) 389 | self.layer_weights = nn.Parameter(torch.ones(13)) 390 | self.softmax = nn.Softmax(dim=0) 391 | 392 | self.config = Config( 393 | "WHISPER_ENCODER_CONFIG", 394 | Input.XLSR, 395 | feat_seq_len=768, 396 | dim_transformer=256, 397 | xlsr_name="whisper_encoder_t", 398 | nhead_transformer=4, 399 | nlayers_transformer=4, 400 | ) 401 | self.transformer = TransformerWrapper(self.config) 402 | 403 | 404 | 405 | self.attenPool = PoolAttFF(self.config.dim_transformer) 406 | 407 | self.sigmoid = nn.Sigmoid() 408 | 409 | def forward(self, x): 410 | 411 | out_feats = self.feat_extract(x) #whisper encoder a list of 13 tensors of shape (B, 1500, 512) 412 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 13 tensors 413 | print(self.layer_weights) 414 | 415 | out_feats = out_feats.permute(0,2,1) #swap axes to (B, 512, 1500) 416 | 417 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 
1500, 512) 418 | out = self.transformer(out_feats) # transformer returns (B, 1500, 256) 419 | out = self.attenPool(out) #attenPool returns (B, 1) 420 | out = self.sigmoid(out) #sigmoid returns (B, 1) 421 | return out 422 | 423 | 424 | class whisperMetricPredictorMelTransformerSmall(nn.Module): 425 | """Transformer based varient on metric estimator 426 | 427 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 428 | """ 429 | def __init__(self, feat_seq=3000): 430 | super().__init__() 431 | 432 | 433 | self.config = Config( 434 | "MFCC_TRANSFORMER_32DEEP_CONFIG", 435 | Input.MFCC, 436 | feat_seq_len=feat_seq, 437 | dim_transformer=256, 438 | xlsr_name=None, 439 | nhead_transformer=4, 440 | nlayers_transformer=4, 441 | ) 442 | self.norm_input = nn.BatchNorm1d(80) 443 | 444 | self.transformer = TransformerWrapper(self.config) 445 | 446 | self.attenPool = PoolAttFF(self.config.dim_transformer) 447 | 448 | self.sigmoid = nn.Sigmoid() 449 | 450 | def forward(self, x): 451 | N_SAMPLES = 16000*30 452 | data_padded = pad_or_trim(x, length=N_SAMPLES) #pad or trim to 30 seconds, returns (B, 480000) 453 | data_feats = log_mel_spectrogram(data_padded).swapaxes(1,2) #returns (B, 3000, 80) 454 | 455 | data_feats = self.norm_input(data_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 3000, 80) 456 | out_trans = self.transformer(data_feats) # transformer returns (B, 3000, 256) 457 | out = self.attenPool(out_trans) #attenPool returns (B, 1) 458 | out = self.sigmoid(out) 459 | 460 | return out 461 | 462 | 463 | class whisperMetricPredictorMelTransformerSmallT (nn.Module): 464 | """Transformer based varient on metric estimator 465 | 466 | based on https://github.com/lcn-kul/xls-r-analysis-sqa/ 467 | """ 468 | def __init__(self, feat_seq=3000): 469 | super().__init__() 470 | 471 | 472 | self.config = Config( 473 | "MFCC_TRANSFORMER_32DEEP_CONFIG", 474 | Input.MFCC, 475 | feat_seq_len=80, 476 | dim_transformer=256, 477 | xlsr_name="mel_T", 478 | nhead_transformer=4, 479 | nlayers_transformer=4, 480 | ) 481 | self.norm_input = nn.BatchNorm1d(feat_seq) 482 | 483 | self.transformer = TransformerWrapper(self.config) 484 | 485 | self.attenPool = PoolAttFF(self.config.dim_transformer) 486 | 487 | self.sigmoid = nn.Sigmoid() 488 | 489 | def forward(self, x): 490 | N_SAMPLES = 16000*30 491 | data_padded = pad_or_trim(x, length=N_SAMPLES) #pad or trim to 30 seconds, returns (B, 480000) 492 | data_feats = log_mel_spectrogram(data_padded) #returns (B, 80, 3000) 493 | 494 | data_feats = self.norm_input(data_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 80, 3000) 495 | out_trans = self.transformer(data_feats) # transformer returns (B, 3000, 256) 496 | out = self.attenPool(out_trans) #attenPool returns (B, 1) 497 | out = self.sigmoid(out) 498 | 499 | return out 500 | 501 | 502 | 503 | class whisperMetricPredictorFullTransformerSmall(nn.Module): 504 | def __init__(self, feat_seq=768//2): 505 | super().__init__() 506 | 507 | 508 | 509 | self.feat_extract = WhisperWrapper_full(layer=-1,use_feat_extractor=True) 510 | self.feat_extract.requires_grad_(False) 511 | self.config = Config( 512 | "WHISPER_FULL_CONFIG", 513 | Input.XLSR, 514 | feat_seq_len=feat_seq, 515 | dim_transformer=256, 516 | xlsr_name="whisper_full", 517 | nhead_transformer=4, 518 | nlayers_transformer=4, 519 | ) 520 | self.cc = CenterCrop(feat_seq) 521 | self.norm_input = nn.BatchNorm1d(768) 522 | self.transformer = TransformerWrapper(self.config) 523 | self.norm_input = nn.BatchNorm1d(768) 524 | 
self.attenPool = PoolAttFF(self.config.dim_transformer) 525 | 526 | 527 | self.sigmoid = nn.Sigmoid() 528 | def forward(self, x): 529 | out_feats = self.feat_extract(x)[:,:,:,-1] #whisper encoder returns (B, 1500, 768) 530 | out_feats = self.cc (out_feats) #center crop to 384 531 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 384, 768) 532 | out = self.transformer(out_feats) # transformer returns (B, 384, 256) 533 | out = self.attenPool(out) #attenPool returns (B, 1) 534 | out = self.sigmoid(out) #sigmoid returns (B, 1) 535 | return out 536 | 537 | class whisperMetricPredictorFullTransformerSmallT(nn.Module): 538 | def __init__(self, feat_seq=384): 539 | super().__init__() 540 | 541 | 542 | 543 | self.feat_extract = WhisperWrapper_full(layer=-1,use_feat_extractor=True) 544 | self.feat_extract.requires_grad_(False) 545 | self.config = Config( 546 | "WHISPER_FULL_CONFIG", 547 | Input.XLSR, 548 | feat_seq_len=768, 549 | dim_transformer=256, 550 | xlsr_name="whisper_full_t", 551 | nhead_transformer=4, 552 | nlayers_transformer=4, 553 | ) 554 | self.cc = CenterCrop(feat_seq) 555 | self.norm_input = nn.BatchNorm1d(feat_seq) 556 | 557 | self.transformer = TransformerWrapper(self.config) 558 | self.attenPool = PoolAttFF(self.config.dim_transformer) 559 | 560 | 561 | self.sigmoid = nn.Sigmoid() 562 | def forward(self, x): 563 | out_feats = self.feat_extract(x)[:,:,:,-1] #whisper encoder returns (B, W, 768) 564 | out_feats = self.cc (out_feats) #center crop to 384 565 | 566 | out_feats= out_feats.permute(0,2,1) #swap axes to (B, 768, 384) 567 | 568 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 768, 384) 569 | 570 | out = self.transformer(out_feats) # transformer returns (B, 768, 256) 571 | out = self.attenPool(out) #attenPool returns (B, 1) 572 | out = self.sigmoid(out) #sigmoid returns (B, 1) 573 | return out 574 | 575 | 576 | 577 | class whisperMetricPredictorFullLayersTransformerSmall(nn.Module): 578 | def __init__(self, feat_seq=768//2): 579 | super().__init__() 580 | 581 | 582 | 583 | self.feat_extract = WhisperWrapper_full(layer=-1,use_feat_extractor=True) 584 | self.feat_extract.requires_grad_(False) 585 | self.config = Config( 586 | "WHISPER_FULL_CONFIG", 587 | Input.XLSR, 588 | feat_seq_len=feat_seq, 589 | dim_transformer=256, 590 | xlsr_name="whisper_full", 591 | nhead_transformer=4, 592 | nlayers_transformer=4, 593 | ) 594 | self.cc = CenterCrop(feat_seq) 595 | self.norm_input = nn.BatchNorm1d(768) 596 | self.transformer = TransformerWrapper(self.config) 597 | self.norm_input = nn.BatchNorm1d(768) 598 | self.attenPool = PoolAttFF(self.config.dim_transformer) 599 | self.layer_weights = nn.Parameter(torch.ones(12)) 600 | self.softmax = nn.Softmax(dim=0) 601 | 602 | self.sigmoid = nn.Sigmoid() 603 | def forward(self, x): 604 | out_feats = self.feat_extract(x) #whisper encoder returns list (B, 1500, 768,12) 605 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 12 tensors (B, 1500, 768) 606 | print(self.layer_weights) 607 | out_feats = self.cc (out_feats) #center crop to 384 608 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 384, 768) 609 | out = self.transformer(out_feats) # transformer returns (B, 384, 256) 610 | out = self.attenPool(out) #attenPool returns (B, 1) 611 | out = self.sigmoid(out) #sigmoid returns (B, 1) 612 | return out 613 | 614 | 615 | class 
whisperMetricPredictorFullLayersTransformerSmallT(nn.Module): 616 | def __init__(self, feat_seq=384): 617 | super().__init__() 618 | 619 | 620 | 621 | self.feat_extract = WhisperWrapper_full(layer=-1,use_feat_extractor=True) 622 | self.feat_extract.requires_grad_(False) 623 | self.config = Config( 624 | "WHISPER_FULL_CONFIG", 625 | Input.XLSR, 626 | feat_seq_len=768, 627 | dim_transformer=256, 628 | xlsr_name="whisper_full_t", 629 | nhead_transformer=4, 630 | nlayers_transformer=4, 631 | ) 632 | self.cc = CenterCrop(feat_seq) 633 | self.norm_input = nn.BatchNorm1d(feat_seq) 634 | 635 | self.transformer = TransformerWrapper(self.config) 636 | self.attenPool = PoolAttFF(self.config.dim_transformer) 637 | self.layer_weights = nn.Parameter(torch.ones(12)) 638 | self.softmax = nn.Softmax(dim=0) 639 | 640 | self.sigmoid = nn.Sigmoid() 641 | def forward(self, x): 642 | out_feats = self.feat_extract(x) #whisper encoder returns list (B, 1500, 768,12) 643 | out_feats = out_feats @ self.softmax(self.layer_weights) #weighted sum of the 12 tensors (B, 1500, 768) 644 | print(self.layer_weights) 645 | out_feats = self.cc (out_feats) #center crop to 384 646 | 647 | out_feats= out_feats.permute(0,2,1) #swap axes to (B, 768, 384) 648 | 649 | out_feats = self.norm_input(out_feats.permute(0,2,1)).permute(0,2,1) #normalize and permute back to (B, 768, 384) 650 | 651 | out = self.transformer(out_feats) # transformer returns (B, 768, 256) 652 | out = self.attenPool(out) #attenPool returns (B, 1) 653 | out = self.sigmoid(out) #sigmoid returns (B, 1) 654 | return out 655 | 656 | --------------------------------------------------------------------------------
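A minimal programmatic usage sketch, assuming the repository root as the working directory, the pinned requirements installed, the LFS checkpoints fetched, and a 16 kHz mono WAV at the hypothetical path `sample_16k_mono.wav`:

import torch

from get_score import get_score

# Single-head model: one overall MOS estimate in (0, 1), rescaled to the 0-5 range.
with torch.no_grad():
    score = get_score("sample_16k_mono.wav", model_type="single")
print("MOS:", score.item() * 5)

# Multi-head model: MOS plus Noisiness, Coloration, Discontinuity and Loudness.
with torch.no_grad():
    dims = get_score("sample_16k_mono.wav", model_type="multi")
for name, value in zip(["MOS", "Noisiness", "Coloration", "Discontinuity", "Loudness"], dims):
    print(name, value.item() * 5)

This mirrors the command-line entry point in get_score.py; note that each call reloads the Whisper backbone and the selected WhiSQA checkpoint.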