├── LICENSE ├── README.md ├── evaluate.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sony Research Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vocoder Evaluation 2 | 3 | This repository contains the evaluation tool used in **"BigVSAN: Enhancing GAN-based Neural Vocoders with Slicing Adversarial Network"** (*[arXiv 2309.02836](https://arxiv.org/abs/2309.02836)*). 4 | Please cite [[1](#citation)] in your work when using this code in your experiments. 
5 | 6 | ## Quick Start 7 | 8 | First, prepare the environment: 9 | ```shell 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Then, run the evaluation: 14 | ```shell 15 | python evaluate.py <gt_dir 1> <synth_dir 1> [<gt_dir 2> <synth_dir 2> ...] [--output_file <results.json>] 16 | ``` 17 | `gt_dir n` is a directory that contains ground-truth audio files, and `synth_dir n` is a directory that contains the corresponding synthesized audio files. Every `.wav` file in `synth_dir n` must have a counterpart with the same file name in `gt_dir n`, and each such pair must be time-aligned in advance. All files must be sampled at 24 kHz. 18 | 19 | `evaluate.py` prints the metrics calculated for each `gt_dir n`-`synth_dir n` pair together with their macro averages across all pairs; pass `--output_file` to also save them as a JSON file. An evaluation can take some time to complete. 20 | 21 | ## Supported evaluation metrics 22 | This toolbox supports the following metrics: 23 | 24 | - M-STFT: Multi-resolution short-time Fourier transform (STFT) loss 25 | - PESQ: Perceptual evaluation of speech quality 26 | - MCD: Mel-cepstral distortion 27 | - Periodicity: Periodicity error 28 | - V/UV F1: F1 score of voiced/unvoiced classification 29 | 30 | ## Citation 31 | 32 | If you find this tool useful, please consider citing: 33 | 34 | [1] Shibuya, T., Takida, Y., Mitsufuji, Y., 35 | "BigVSAN: Enhancing GAN-based Neural Vocoders with Slicing Adversarial Network," 36 | ICASSP 2024. 
37 | ```bibtex 38 | @inproceedings{shibuya2024bigvsan, 39 | title={{BigVSAN}: Enhancing GAN-based Neural Vocoders with Slicing Adversarial Network}, 40 | author={Shibuya, Takashi and Takida, Yuhta and Mitsufuji, Yuki}, 41 | booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 42 | year={2024} 43 | } 44 | ``` 45 | 46 | ## References 47 | 48 | > https://github.com/NVIDIA/BigVGAN 49 | 50 | > https://github.com/csteinmetz1/auraloss 51 | 52 | > https://github.com/ludlows/PESQ 53 | 54 | > https://github.com/ttslr/python-MCD 55 | 56 | > https://github.com/descriptinc/cargan 57 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import auraloss 3 | import functools 4 | import json 5 | import librosa 6 | import numpy as np 7 | import os 8 | import pysptk 9 | import torch 10 | import torchaudio as ta 11 | 12 | from cargan.evaluate.objective.metrics import Pitch 13 | from cargan.preprocess.pitch import from_audio 14 | from fastdtw import fastdtw 15 | from pesq import pesq 16 | from scipy.io.wavfile import read 17 | from scipy.spatial.distance import euclidean 18 | from tqdm import tqdm 19 | 20 | SR_TARGET = 24000 21 | MAX_WAV_VALUE = 32768.0 22 | 23 | 24 | def load_wav(full_path): 25 | sampling_rate, audio = read(full_path) 26 | if sampling_rate != SR_TARGET: 27 | raise IOError( 28 | f'Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {SR_TARGET} Hz' 29 | ) 30 | 31 | audio = audio / MAX_WAV_VALUE 32 | 33 | audio = torch.FloatTensor(audio) 34 | audio = audio.unsqueeze(0) 35 | 36 | return audio 37 | 38 | 39 | def readmgc(x): 40 | frame_length = 1024 41 | hop_length = 256 42 | # Windowing 43 | frames = librosa.util.frame(x, frame_length=frame_length, hop_length=hop_length).astype(np.float64).T 44 | frames *= pysptk.blackman(frame_length) 
45 | assert frames.shape[1] == frame_length 46 | # Order of mel-cepstrum 47 | order = 25 48 | alpha = 0.41 49 | stage = 5 50 | gamma = -1.0 / stage 51 | 52 | mgc = pysptk.mgcep(frames, order, alpha, gamma) 53 | mgc = mgc.reshape(-1, order + 1) 54 | return mgc 55 | 56 | 57 | def evaluate(gt_dir, synth_dir): 58 | """Perform objective evaluation""" 59 | files = [file for file in os.listdir(synth_dir) if file.endswith('.wav')] 60 | gpu = 0 if torch.cuda.is_available() else None 61 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 62 | torch.cuda.empty_cache() 63 | 64 | mrstft_tot = 0.0 65 | pesq_tot = 0.0 66 | s = 0.0 67 | frames_tot = 0 68 | 69 | resampler_16k = ta.transforms.Resample(SR_TARGET, 16000).to(device) 70 | resampler_22k = ta.transforms.Resample(SR_TARGET, 22050).to(device) 71 | 72 | # Modules for evaluation metrics 73 | loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device=device) 74 | batch_metrics_periodicity = Pitch() 75 | periodicity_fn = functools.partial(from_audio, gpu=gpu) 76 | 77 | with torch.no_grad(): 78 | 79 | iterator = tqdm(files, dynamic_ncols=True, desc=f'Evaluating {synth_dir}') 80 | for wavID in iterator: 81 | 82 | y = load_wav(os.path.join(gt_dir, wavID)) 83 | y_g_hat = load_wav(os.path.join(synth_dir, wavID)) 84 | y = y.to(device) 85 | y_g_hat = y_g_hat.to(device) 86 | 87 | y_16k = resampler_16k(y) 88 | y_g_hat_16k = resampler_16k(y_g_hat) 89 | 90 | y_22k = resampler_22k(y) 91 | y_g_hat_22k = resampler_22k(y_g_hat) 92 | 93 | # MRSTFT calculation 94 | mrstft_tot += loss_mrstft(y_g_hat, y).item() 95 | 96 | # PESQ calculation 97 | y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 98 | y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() 99 | pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, 'wb') 100 | 101 | # MCD calculation 102 | y_double_22k = (y_22k[0] * MAX_WAV_VALUE).double().cpu().numpy() 103 | y_g_hat_double_22k = (y_g_hat_22k[0] * MAX_WAV_VALUE).double().cpu().numpy() 104 | 105 
| y_mgc = readmgc(y_double_22k) 106 | y_g_hat_mgc = readmgc(y_g_hat_double_22k) 107 | 108 | _, path = fastdtw(y_mgc, y_g_hat_mgc, dist=euclidean) 109 | 110 | y_path = list(map(lambda l: l[0], path)) 111 | y_g_hat_path = list(map(lambda l: l[1], path)) 112 | y_mgc = y_mgc[y_path] 113 | y_g_hat_mgc = y_g_hat_mgc[y_g_hat_path] 114 | 115 | frames_tot += y_mgc.shape[0] 116 | 117 | z = y_mgc - y_g_hat_mgc 118 | s += np.sqrt((z * z).sum(-1)).sum() 119 | 120 | # Periodicity calculation 121 | true_pitch, true_periodicity = periodicity_fn(y_22k) 122 | pred_pitch, pred_periodicity = periodicity_fn(y_g_hat_22k) 123 | batch_metrics_periodicity.update(true_pitch, true_periodicity, pred_pitch, pred_periodicity) 124 | 125 | results = batch_metrics_periodicity() 126 | 127 | return { 128 | 'M-STFT': mrstft_tot / len(files), 129 | 'PESQ': pesq_tot / len(files), 130 | 'MCD': 10.0 / np.log(10.0) * np.sqrt(2.0) * float(s) / float(frames_tot), 131 | 'Periodicity': results['periodicity'], 132 | 'V/UV F1': results['f1'], 133 | } 134 | 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('list_wavs_dir', nargs='+') 139 | parser.add_argument('--output_file', default=None) 140 | a = parser.parse_args() 141 | 142 | if len(a.list_wavs_dir) & 1: 143 | raise ValueError('The number of directories should be even.') 144 | 145 | # Check directories 146 | list_gt_dir = [] 147 | list_synth_dir = [] 148 | for i in range(0, len(a.list_wavs_dir), 2): 149 | gt_dir = a.list_wavs_dir[i] 150 | synth_dir = a.list_wavs_dir[i + 1] 151 | 152 | gt_files = set(os.listdir(gt_dir)) 153 | synth_files = set([file for file in os.listdir(synth_dir) if file.endswith('.wav')]) 154 | if gt_files < synth_files: 155 | raise IOError( 156 | f'Each file in "{synth_dir}" needs to have the corresponding file that has the same name in "{gt_dir}"' 157 | ) 158 | 159 | list_gt_dir.append(gt_dir) 160 | list_synth_dir.append(synth_dir) 161 | 162 | # Evaluate waveforms 163 | results_tot = { 164 | 
'M-STFT': 0.0, 165 | 'PESQ': 0.0, 166 | 'MCD': 0.0, 167 | 'Periodicity': 0.0, 168 | 'V/UV F1': 0.0, 169 | 'dir_results': {}, 170 | } 171 | for gt_dir, synth_dir in zip(list_gt_dir, list_synth_dir): 172 | results = evaluate(gt_dir, synth_dir) 173 | results_tot['M-STFT'] += results['M-STFT'] 174 | results_tot['PESQ'] += results['PESQ'] 175 | results_tot['MCD'] += results['MCD'] 176 | results_tot['Periodicity'] += results['Periodicity'] 177 | results_tot['V/UV F1'] += results['V/UV F1'] 178 | results_tot['dir_results'][synth_dir] = results 179 | results_tot['M-STFT'] /= len(results_tot['dir_results']) 180 | results_tot['PESQ'] /= len(results_tot['dir_results']) 181 | results_tot['MCD'] /= len(results_tot['dir_results']) 182 | results_tot['Periodicity'] /= len(results_tot['dir_results']) 183 | results_tot['V/UV F1'] /= len(results_tot['dir_results']) 184 | 185 | # Print to stdout 186 | print(results_tot) 187 | 188 | if a.output_file: 189 | # Write results 190 | with open(a.output_file, 'w') as file: 191 | json.dump(results_tot, file, ensure_ascii=False, indent=4) 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | auraloss 2 | cargan 3 | fastdtw 4 | librosa 5 | pesq 6 | pysptk 7 | torchaudio 8 | tqdm 9 | --------------------------------------------------------------------------------