├── .gitignore
├── README.md
├── SCOREQ_NeurIPS24.pdf
├── data
    ├── opus.wav
    └── ref.wav
├── figs
    └── results.png
└── src
    └── scoreq
        ├── __init__.py
        ├── __main__.py
        └── scoreq.py


/.gitignore:
--------------------------------------------------------------------------------
pt-models*
.vscode*
__pycache__*
**/__pycache__/
dist*
pyproject.toml
requirements.txt
*.pyc
*.out
*.sh
*.th
test_package/


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SCOREQ: Speech Contrastive Regression for Quality Assessment

SCOREQ is a framework for speech quality assessment based on pre-training the encoder with the SCOREQ loss.

This repo provides four speech quality metrics trained with the SCOREQ framework.

| Domain | Train Set | Usage Mode | Prediction |
|---|---|---|---|
| Natural speech | NISQA TRAIN SIM | No-reference | Mean Opinion Score |
| Natural speech | NISQA TRAIN SIM | Non-matching reference, full-reference | Euclidean distance to clean speech |
| Synthetic speech | VoiceMOS 22 Train Set | No-reference | Mean Opinion Score |
| Synthetic speech | VoiceMOS 22 Train Set | Non-matching reference | Euclidean distance to clean speech |

## Performance - Pearson Correlation

| Dataset | Domain | NISQA (NISQA TRAIN SIM) | NR-PESQ | NR-SI SDR | NORESQA-M | NR-SCOREQ Natural | NR-SCOREQ Synthetic |
|---|---|---|---|---|---|---|---|
| NISQA TEST FOR | Online Conferencing Simulated (codecs, background noise, packet loss, etc.) | 0.91 | 0.79 | 0.74 | 0.68 | **0.97** | 0.82 |
| NISQA TEST P501 | Online Conferencing Simulated (codecs, background noise, packet loss, etc.) | 0.94 | 0.88 | 0.81 | 0.70 | **0.96** | 0.86 |
| DNS Squim | Speech Enhancement, Background Noise | // | 0.96 | 0.99 | // | // | // |
| VoiceMOS Test 1 | Speech Synthesis | 0.54 | 0.71 | 0.67 | 0.85 | 0.86 | **0.90** |
| VoiceMOS Test 2 | Speech Synthesis | 0.64 | 0.49 | 0.55 | 0.91 | 0.82 | **0.98** |
| NOIZEUS | Speech Enhancement, Background Noise | 0.85 | 0.75 | 0.70 | 0.15 | **0.91** | 0.59 |
| NISQA TEST LT | Online Conferencing Live | 0.84 | 0.66 | 0.56 | 0.60 | **0.86** | 0.81 |
| P23 EXP3 | Packet Loss, Codecs | 0.82 | 0.77 | 0.17 | 0.71 | **0.94** | 0.88 |
| TCD VOIP | VoIP Degradations | 0.76 | 0.76 | 0.76 | 0.61 | 0.85 | **0.87** |
| TENCENT | Online Conferencing Simulated (codecs, background noise, packet loss, etc.) | 0.78 | 0.78 | 0.77 | 0.57 | **0.86** | 0.78 |
| P23 EXP1 | Codecs | 0.76 | 0.70 | 0.82 | 0.40 | **0.96** | 0.92 |
| TENCENT-Rev | Real-World Reverberation | 0.40 | 0.36 | 0.32 | 0.36 | **0.79** | 0.43 |


## Installation
SCOREQ is hosted on PyPI. It can be installed in your Python environment with the following command:
```
pip install scoreq
```

The expected sampling rate is 16 kHz. The script automatically resamples audio with different sampling rates.
SCOREQ models accept variable input length.

### First run
The PyTorch weights are hosted on Zenodo. The first run might be slower due to model download.

## Using SCOREQ
SCOREQ can be used in 2 modes and for 2 domains by setting the arguments ```data_domain``` and ```mode```.
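As a rough illustration of how the two arguments combine, the sketch below lists the four supported `(data_domain, mode)` pairs together with the checkpoint file that `src/scoreq/scoreq.py` downloads for each. The `CHECKPOINTS` dictionary and the `checkpoint_for` helper are hypothetical names used only for this example; they are not part of the package API.

```python
# Hypothetical lookup table, not part of the scoreq package.
# File names are the ones referenced in src/scoreq/scoreq.py.
CHECKPOINTS = {
    ("natural", "nr"): "adapt_nr_telephone.pt",
    ("natural", "ref"): "fixed_nmr_telephone.pt",
    ("synthetic", "nr"): "adapt_nr_synthetic.pt",
    ("synthetic", "ref"): "fixed_nmr_synthetic.pt",
}


def checkpoint_for(data_domain, mode):
    """Return the checkpoint file name for a (data_domain, mode) pair."""
    try:
        return CHECKPOINTS[(data_domain, mode)]
    except KeyError:
        raise ValueError(
            f"Unsupported combination: data_domain={data_domain!r}, mode={mode!r}"
        ) from None


for (domain, mode), filename in CHECKPOINTS.items():
    print(f"data_domain={domain:<9} mode={mode:<3} -> {filename}")
```

In this sketch an unsupported pair raises `ValueError`; the package itself raises a generic `Exception` in that situation.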

### Using SCOREQ from the command line

`data_domain`, `mode`, and the test audio path are positional arguments; `--ref_path` is an option that is required in `ref` mode.

| Domain | Usage Mode | CLI |
|---|---|---|
| Natural speech | No-reference | ```python -m scoreq natural nr /path/to/test_audio``` |
| Natural speech | Non-matching reference, full-reference | ```python -m scoreq natural ref /path/to/test_audio --ref_path /path/to/ref_audio``` |
| Synthetic speech | No-reference | ```python -m scoreq synthetic nr /path/to/test_audio``` |
| Synthetic speech | Non-matching reference | ```python -m scoreq synthetic ref /path/to/test_audio --ref_path /path/to/ref_audio``` |


### Using SCOREQ inside Python
Inside Python, first import the package.
The examples below use the wav files provided in the data directory.

```
import scoreq

# Predict quality of natural speech in NR mode
nr_scoreq = scoreq.Scoreq(data_domain='natural', mode='nr')
pred_mos = nr_scoreq.predict(test_path='./data/opus.wav', ref_path=None)

# Predict quality of natural speech in REF mode
ref_scoreq = scoreq.Scoreq(data_domain='natural', mode='ref')
pred_distance = ref_scoreq.predict(test_path='./data/opus.wav', ref_path='./data/ref.wav')

# Predict quality of synthetic speech in NR mode
nr_scoreq = scoreq.Scoreq(data_domain='synthetic', mode='nr')
pred_mos = nr_scoreq.predict(test_path='./data/opus.wav', ref_path=None)

# Predict quality of synthetic speech in REF mode
ref_scoreq = scoreq.Scoreq(data_domain='synthetic', mode='ref')
pred_distance = ref_scoreq.predict(test_path='./data/opus.wav', ref_path='./data/ref.wav')
```

## Correct usage

### Reference Mode

Use `mode='ref'` for both non-matching reference and full-reference scenarios. Which one applies depends on the clean speech input:

- The metric operates in **full-reference mode** when provided with the clean counterpart of the test audio.
- It operates in **non-matching reference mode** when any other clean speech recording is used as the reference.

Full-reference mode is expected for applications in the natural speech domain where the clean counterpart is typically available, e.g., speech codecs.
SCOREQ learns a distance metric and expects clean speech as the non-matching reference. Note that the model has not been evaluated with other types of non-matching references.

### Input Length

SCOREQ accepts inputs of any length. However, it was trained and evaluated on segments up to 15 seconds. We recommend trimming your audio to 10–15 seconds to avoid running out of memory with longer recordings. This duration is sufficient for accurate audio quality predictions.

### Sampling rate
SCOREQ automatically resamples input files to 16 kHz, the sampling rate used for training. The model was evaluated on four test sets in full-band mode (48 kHz), yielding promising results. While we encourage experimenting with SCOREQ at higher sampling rates, any conclusions drawn from using rates above 16 kHz should be validated through listening tests.

## Paper - NeurIPS 2024
Check our paper [here](https://arxiv.org/pdf/2410.06675)
```
@article{ragano2024scoreq,
  title={SCOREQ: Speech Quality Assessment with Contrastive Regression},
  author={Ragano, Alessandro and Skoglund, Jan and Hines, Andrew},
  journal={arXiv preprint arXiv:2410.06675},
  year={2024}
}
```
[![DOI](https://zenodo.org/badge/868888288.svg)](https://doi.org/10.5281/zenodo.14735580)

The SCOREQ code is licensed under the MIT license. Dependencies of the project are available under separate license terms.

Copyright © 2024 Alessandro Ragano


--------------------------------------------------------------------------------
/SCOREQ_NeurIPS24.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alessandroragano/scoreq/bc1d19894092129f5dff774a0c9d942ac626d2a1/SCOREQ_NeurIPS24.pdf


--------------------------------------------------------------------------------
/data/opus.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alessandroragano/scoreq/bc1d19894092129f5dff774a0c9d942ac626d2a1/data/opus.wav


--------------------------------------------------------------------------------
/data/ref.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alessandroragano/scoreq/bc1d19894092129f5dff774a0c9d942ac626d2a1/data/ref.wav


--------------------------------------------------------------------------------
/figs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alessandroragano/scoreq/bc1d19894092129f5dff774a0c9d942ac626d2a1/figs/results.png


--------------------------------------------------------------------------------
/src/scoreq/__init__.py:
--------------------------------------------------------------------------------
from .scoreq import Scoreq
scoreq = Scoreq()


--------------------------------------------------------------------------------
/src/scoreq/__main__.py:
--------------------------------------------------------------------------------
import click
from scoreq import Scoreq

@click.command()
@click.argument('data_domain', type=click.Choice(['natural', 'synthetic']))
@click.argument('mode', type=click.Choice(['nr', 'ref']))
@click.argument('test_path', type=click.Path(exists=True))
@click.option('--ref_path', type=click.Path(exists=True), help='Path to the reference audio file (required in "ref" mode)')
@click.option('--device', type=str, default=None, help='Specify device, cuda or cpu. Automatically set cuda if None and GPU is detected')
def main(data_domain, mode, test_path, device=None, ref_path=None):
    """Audio quality assessment using SCOREQ"""

    # Check if 'ref_path' is provided in 'ref' mode
    if mode == 'ref' and ref_path is None:
        raise click.UsageError("Error: --ref_path is required in 'ref' mode")

    # If 'mode' is 'nr', set 'ref_path' to None explicitly
    if mode == 'nr':
        ref_path = None

    scoreq_model = Scoreq(device, data_domain, mode)
    scoreq_score = scoreq_model.predict(test_path, ref_path)

if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/src/scoreq/scoreq.py:
--------------------------------------------------------------------------------
import fairseq
import torch
import torch.nn as nn
import numpy as np
import torchaudio
import os
from urllib.request import urlretrieve

class Scoreq():
    """
    Main class for handling the SCOREQ audio quality assessment model.

    This class loads the pre-trained SCOREQ model, processes audio files, and makes predictions in both
    no-reference (NR) and reference-based (FR/NMR) modes. It supports both natural and synthetic speech
    data domains.
    """
    def __init__(self, device=None, data_domain='natural', mode='nr'):
        """
        Initializes the Scoreq object.

        Args:
            device: Device to run the model on ('cuda' or 'cpu'). If None, automatically detects GPU availability.
            data_domain: Domain of the audio data ('natural' or 'synthetic').
            mode: Mode of operation ('nr' for no-reference or 'ref' for either full-reference or non-matching-reference modes).
        """

        # Store variables
        self.data_domain = data_domain
        self.mode = mode

        # *** DEVICE SETTINGS ***
        # Automatically set based on GPU detection
        if torch.cuda.is_available():
            self.DEVICE = 'cuda'
        else:
            self.DEVICE = 'cpu'

        # An explicit user choice overrides auto-detection
        if device is not None:
            self.DEVICE = device

        print(f'SCOREQ running on: {self.DEVICE}')

        # *** LOAD MODEL ***
        # *** PyTorch models directory ***
        if not os.path.isdir('./pt-models'):
            print('Creating pt-models directory')
            os.makedirs('./pt-models')

        # Download wav2vec 2.0
        url_w2v = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt"
        CHECKPOINT_PATH = './pt-models/wav2vec_small.pt'
        if not os.path.isfile(CHECKPOINT_PATH):
            print('Downloading wav2vec 2.0')
            urlretrieve(url_w2v, CHECKPOINT_PATH)
            print('Completed')

        # w2v BASE parameters
        W2V_OUT_DIM = 768
        EMB_DIM = 256

        # Load w2v BASE
        w2v_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([CHECKPOINT_PATH])
        ssl_model = w2v_model[0]
        ssl_model.remove_pretraining_modules()

        # Create SCOREQ model
        pt_model = TripletModel(ssl_model, W2V_OUT_DIM, EMB_DIM)

        # Add MOS projection layer for no-reference mode
        if mode == 'nr':
            model = MosPredictor(pt_model, emb_dim=W2V_OUT_DIM)
        elif mode == 'ref':
            model = pt_model

        # Load weights
        if data_domain == 'natural':
            if mode == 'nr':
                MODEL_PATH = './pt-models/adapt_nr_telephone.pt'
                url_scoreq = 'https://zenodo.org/records/13860326/files/adapt_nr_telephone.pt'
                if not os.path.isfile(MODEL_PATH):
                    print('Downloading PyTorch weights from Zenodo')
                    print('SCOREQ | Mode: No-Reference | Data: Natural speech')
                    urlretrieve(url_scoreq, MODEL_PATH)
                    print('Download completed')
            elif mode == 'ref':
                MODEL_PATH = './pt-models/fixed_nmr_telephone.pt'
                url_scoreq = 'https://zenodo.org/records/13860326/files/fixed_nmr_telephone.pt'
                if not os.path.isfile(MODEL_PATH):
                    print('Downloading PyTorch weights from Zenodo')
                    print('SCOREQ | Mode: Full-Reference/NMR | Data: Natural speech')
                    urlretrieve(url_scoreq, MODEL_PATH)
                    print('Download completed')
            else:
                raise Exception('Mode must be either "nr" for no-reference or "ref" for full-reference and non-matching reference.')
        elif data_domain == 'synthetic':
            if mode == 'nr':
                MODEL_PATH = './pt-models/adapt_nr_synthetic.pt'
                url_scoreq = 'https://zenodo.org/records/13860326/files/adapt_nr_synthetic.pt'
                if not os.path.isfile(MODEL_PATH):
                    print('Downloading PyTorch weights from Zenodo')
                    print('SCOREQ | Mode: No-Reference | Data: Synthetic speech')
                    urlretrieve(url_scoreq, MODEL_PATH)
                    print('Download completed')
            elif mode == 'ref':
                MODEL_PATH = './pt-models/fixed_nmr_synthetic.pt'
                url_scoreq = 'https://zenodo.org/records/13860326/files/fixed_nmr_synthetic.pt'
                if not os.path.isfile(MODEL_PATH):
                    print('Downloading PyTorch weights from Zenodo')
                    print('SCOREQ | Mode: Full-reference/NMR | Data: Synthetic speech')
                    urlretrieve(url_scoreq, MODEL_PATH)
                    print('Download completed')
            else:
                raise Exception('Mode must be either "nr" for no-reference or "ref" for full-reference and non-matching reference.')
        else:
            raise Exception('Invalid data domain, you must select either "natural" or "synthetic".')

        model.load_state_dict(torch.load(MODEL_PATH, map_location=self.DEVICE))
        self.model = model
        self.model.to(self.DEVICE)
        self.model.eval()


    def predict(self, test_path, ref_path=None):
        """
        Makes predictions on audio files.

        Args:
            test_path: Path to the test audio file.
            ref_path: Path to the reference audio file (required in 'ref' mode). If ref_path is the clean counterpart, the model works in full-reference mode. If it is any other clean speech, it works in non-matching-reference mode.

        Returns:
            The predicted quality score: MOS (1-5) in 'nr' mode, Euclidean distance with respect to ref_path in 'ref' mode.
        """

        # Check invalid input
        if test_path is None:
            raise Exception('test_path not specified, you need to pass a valid path to an audio file')

        if self.mode == 'ref':
            if ref_path is None:
                raise Exception('ref_path must be a wav file in ref mode, found None')

        # *** CHOOSE MODE ***
        # No-Reference (NR) mode
        if self.mode == 'nr':
            pred = np.round(self.nr_scoreq(test_path), 4)
            print(f'SCOREQ | No-Reference Mode | Domain {self.data_domain} | {test_path}: {pred}')

        elif self.mode == 'ref':
            # Full-reference (FR) mode or Non-Matching Reference (NMR) mode depending on which reference audio is used
            pred = self.ref_scoreq(test_path, ref_path)
            print(f'SCOREQ | Fr/Nmr-Reference Mode | Domain {self.data_domain} | Ref-> {ref_path}, Test-> {test_path}: {pred}')

        else:
            raise Exception('Selected mode is not valid, choose between nr and ref')

        return pred

    def nr_scoreq(self, test_path):
        """
        Performs no-reference quality prediction.

        Args:
            test_path: Path to the test audio file.

        Returns:
            The predicted MOS.
        """

        wave = self.load_processing(test_path).to(self.DEVICE)
        with torch.no_grad():
            pred_mos = self.model(wave).item()

        return pred_mos

    def ref_scoreq(self, test_path, ref_path):
        """
        Performs reference-based quality prediction.

        Args:
            test_path: Path to the test audio file.
            ref_path: Path to the reference audio file. It can either be the clean counterpart (full-reference) or any clean speech (non-matching reference).

        Returns:
            The Euclidean distance between the embeddings of the test and reference audio files.
        """
        test_wave = self.load_processing(test_path).to(self.DEVICE)
        ref_wave = self.load_processing(ref_path).to(self.DEVICE)

        # Get embeddings
        with torch.no_grad():
            test_emb = self.model(test_wave)
            ref_emb = self.model(ref_wave)

        # Get Euclidean distance
        scoreq_dist = torch.cdist(test_emb, ref_emb).item()
        return scoreq_dist

    # Load wave file
    def load_processing(self, filepath, target_sr=16000, trim=False):
        """
        Loads and preprocesses an audio file.

        Args:
            filepath: Path to the audio file, or a numpy array whose first element is the path.
            target_sr: Target sample rate (default: 16000 Hz).
            trim: Whether to trim the audio to 10 seconds (default: False).

        Returns:
            The preprocessed audio waveform as a PyTorch tensor.
        """

        # Load waveform
        if isinstance(filepath, np.ndarray):
            filepath = filepath[0]
        wave, sr = torchaudio.load(filepath)

        # Downmix to mono by averaging the first two channels
        if wave.shape[0] > 1:
            wave = ((wave[0,:] + wave[1,:])/2).unsqueeze(0)

        # Resample to the target rate (16 kHz) if needed
        if sr != target_sr:
            wave = torchaudio.transforms.Resample(sr, target_sr)(wave)
            sr = target_sr

        # Trim audio to 10 secs
        if trim:
            if wave.shape[1] > sr*10:
                wave = wave[:, :sr*10]

        return wave

class TripletModel(nn.Module):
    """
    Helper class defining the underlying neural network architecture for the SCOREQ model.
    """

    def __init__(self, ssl_model, ssl_out_dim, emb_dim=256):
        """
        Initializes the TripletModel.

        Args:
            ssl_model: The pre-trained self-supervised learning model (e.g., wav2vec).
            ssl_out_dim: Output dimension of the SSL model.
            emb_dim: Dimension of the final embedding (default: 256).
        """

        super(TripletModel, self).__init__()
        self.ssl_model = ssl_model
        self.ssl_features = ssl_out_dim
        self.embedding_layer = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.ssl_features, emb_dim)
        )

    def forward(self, wav, phead=False):
        """
        Defines the forward pass of the model.

        Args:
            wav: Input audio waveform.
            phead: If True, apply the projection head (embedding layer) and L2-normalize the result (default: False).

        Returns:
            The time-averaged SSL features, or the normalized projected embedding when phead is True.
        """

        wav = wav.squeeze(1)
        res = self.ssl_model(wav, mask=False, features_only=True)
        x = res['x']
        x = torch.mean(x, 1)

        # Choose if you want to keep projection head, remove for NR mode. Const model shows better performance in ODM without phead.
        if phead:
            x = self.embedding_layer(x)
            x = torch.nn.functional.normalize(x, dim=1)
        return x

# ******** MOS PREDICTOR **********
class MosPredictor(nn.Module):
    """
    Helper class that adds a layer for predicting Mean Opinion Scores (MOS) in the no-reference mode.
    """

    def __init__(self, pt_model, emb_dim=768):
        """
        Initializes the MosPredictor.

        Args:
            pt_model: The pre-trained triplet model.
            emb_dim: Dimension of the embedding (default: 768).
        """
        super(MosPredictor, self).__init__()
        self.pt_model = pt_model
        self.mos_layer = nn.Linear(emb_dim, 1)

    def forward(self, wav):
        """
        Defines the forward pass of the MOS predictor.

        Args:
            wav: Input audio waveform.

        Returns:
            The predicted MOS.
        """
        x = self.pt_model(wav, phead=False)
        if len(x.shape) == 3:
            x.squeeze_(2)
        out = self.mos_layer(x)
        return out


--------------------------------------------------------------------------------
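
With `phead=False`, the reference-mode path through `TripletModel.forward` and `ref_scoreq` reduces to two steps: average the frame-level features over time, then take the Euclidean distance between the two pooled vectors. The dependency-free sketch below illustrates only that arithmetic. The 2-dimensional frame features are made up for the example (the real model pools 768-dimensional wav2vec 2.0 features), and `mean_pool` and `euclidean` are illustrative names that do not appear in the package.

```python
import math


def mean_pool(frames):
    """Average a list of frame-level feature vectors over the time axis."""
    n_frames = len(frames)
    return [sum(column) / n_frames for column in zip(*frames)]


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.dist(a, b)


# Made-up frame features: two frames for the test clip, one for the reference.
test_frames = [[0.0, 0.0], [2.0, 2.0]]
ref_frames = [[4.0, 5.0]]

test_emb = mean_pool(test_frames)   # [1.0, 1.0]
ref_emb = mean_pool(ref_frames)     # [4.0, 5.0]
distance = euclidean(test_emb, ref_emb)
print(distance)  # 5.0
```

A smaller distance means the test embedding lies closer to the clean reference embedding.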