├── .gitignore ├── LICENSE ├── README.md ├── audioldm_eval ├── __init__.py ├── audio │ ├── __init__.py │ ├── audio_processing.py │ ├── stft.py │ └── tools.py ├── datasets │ ├── __init__.py │ ├── load_mel.py │ └── transforms.py ├── eval.py ├── eval_parallel.py ├── feature_extractors │ ├── __init__.py │ ├── inception3.py │ ├── melception.py │ ├── melception_audioset.py │ └── panns │ │ ├── __init__.py │ │ ├── config.py │ │ ├── evaluate.py │ │ ├── finetune_template.py │ │ ├── losses.py │ │ ├── main.py │ │ ├── models.py │ │ ├── pytorch_utils.py │ │ └── utilities.py └── metrics │ ├── __init__.py │ ├── fad.py │ ├── fid.py │ ├── gs │ ├── __init__.py │ ├── geom_score.py │ ├── top_utils.py │ └── utils.py │ ├── isc.py │ ├── kid.py │ ├── kl.py │ ├── ndb.py │ └── validate.py ├── gen_test_file.py ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt/ 2 | *.pth 3 | *.wav 4 | *.npy 5 | *.egg-info 6 | __pycache__ 7 | vctk_test 8 | .DS_* 9 | script/* 10 | datasets/* 11 | test_fad/* 12 | *.ckpt 13 | *.json 14 | audio 15 | build 16 | dist 17 | *.pkl 18 | pickle_check.py 19 | test.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2022 Scott Chacon and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audio Generation Evaluation 2 | 3 | This toolbox aims to unify audio generation model evaluation for easier future comparison. 4 | 5 | ## Quick Start 6 | 7 | First, prepare the environment 8 | ```shell 9 | pip install git+https://github.com/haoheliu/audioldm_eval 10 | ``` 11 | 12 | Second, generate test dataset by 13 | ```shell 14 | python3 gen_test_file.py 15 | ``` 16 | 17 | Finally, perform a test run. A result for reference is attached [here](https://github.com/haoheliu/audioldm_eval/blob/main/example/paired_ref.json). 
18 | ```shell
19 | python3 test.py # Evaluate and save the json file to disk (example/paired.json)
20 | ```
21 | 
22 | ## Evaluation metrics
23 | We have the following metrics in this toolbox:
24 | 
25 | - Recommended:
26 |   - FAD: Frechet audio distance
27 |   - ISc: Inception score
28 | - Others, for reference:
29 |   - FD: Frechet distance, computed with either PANNs, a state-of-the-art audio classification model, or MERT, a music understanding model.
30 |   - KID: Kernel inception distance
31 |   - KL: KL divergence (softmax over logits)
32 |   - KL_Sigmoid: KL divergence (sigmoid over logits)
33 |   - PSNR: Peak signal-to-noise ratio
34 |   - SSIM: Structural similarity index measure
35 |   - LSD: Log-spectral distance
36 | 
37 | The evaluation function accepts the paths of two folders as its main parameters.
38 | 1. If the two folders contain **the same number of files with matching filenames**, the evaluation runs in **paired mode** (see the sketch of this filename check just before the Example section below).
39 | 2. If the two folders contain **different numbers of files, or files with different names**, the evaluation runs in **unpaired mode**.
40 | 
41 | **The following metrics are only calculated in paired mode**: KL, KL_Sigmoid, PSNR, SSIM, LSD.
42 | In unpaired mode, these metrics return -1.
43 | 
44 | ## Evaluation on AudioCaps and AudioSet
45 | 
46 | The AudioCaps test set consists of audio files with multiple text annotations. To evaluate the performance of AudioLDM, we randomly selected one annotation per audio file, which can be found in the [accompanying json file](https://github.com/haoheliu/audioldm_eval/tree/c9e936ea538c4db7e971d9528a2d2eb4edac975d/example/AudioCaps).
47 | 
48 | Given that the AudioSet evaluation set contains approximately 20,000 audio files, it may be impractical for audio generative models to perform evaluation on the entire set. As a result, we randomly selected 2,000 audio files for evaluation, with the corresponding annotations available in a [json file](https://github.com/haoheliu/audioldm_eval/tree/c9e936ea538c4db7e971d9528a2d2eb4edac975d/example/AudioSet).
49 | 
50 | For more information on our evaluation process, please refer to [our paper](https://arxiv.org/abs/2301.12503).
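Whether the toolbox runs in paired or unpaired mode is decided by a filename-intersection check inside the evaluation helpers (see `get_filename_intersection_ratio` in `audioldm_eval/eval.py`). The snippet below is only a minimal standalone sketch of that logic, useful for checking your folders in advance; `will_run_paired` is a hypothetical helper name and is not part of this package:

```python
import os

def will_run_paired(generated_dir, reference_dir, threshold=0.99):
    # Paired mode is used when (almost) every filename appears in both folders,
    # mirroring EvaluationHelper.get_filename_intersection_ratio (threshold=0.99).
    names_gen = set(os.listdir(generated_dir))
    names_ref = set(os.listdir(reference_dir))
    shared = names_gen & names_ref
    return (
        len(shared) / len(names_gen) > threshold
        and len(shared) / len(names_ref) > threshold
    )

# e.g. will_run_paired("example/paired", "example/reference") should return True
# when the generated files mirror the reference filenames.
```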
51 | 
52 | ## Example
53 | 
54 | Single-GPU mode:
55 | 
56 | ```python
57 | import torch
58 | from audioldm_eval import EvaluationHelper
59 | 
60 | # GPU acceleration is preferred
61 | device = torch.device(f"cuda:{0}")
62 | 
63 | generation_result_path = "example/paired"
64 | target_audio_path = "example/reference"
65 | 
66 | # Initialize a helper instance
67 | # `cnn14` refers to the PANNs model, `mert` refers to the MERT model
68 | evaluator = EvaluationHelper(16000, device, backbone="cnn14")
69 | 
70 | # Perform evaluation; results will be printed out and saved as a json file
71 | metrics = evaluator.main(
72 |     generation_result_path,
73 |     target_audio_path,
74 |     limit_num=None  # If you only intend to evaluate X (int) pairs of data, set limit_num=X
75 | )
76 | ```
77 | 
78 | Multi-GPU mode:
79 | 
80 | ```python
81 | import torch
82 | from audioldm_eval import EvaluationHelperParallel
83 | import torch.multiprocessing as mp
84 | 
85 | generation_result_path = "example/paired"
86 | target_audio_path = "example/reference"
87 | 
88 | if __name__ == '__main__':
89 |     # 2 denotes the number of GPUs; `cnn14` refers to the PANNs model, `mert` refers to the MERT model
90 |     evaluator = EvaluationHelperParallel(16000, 2, backbone="cnn14")
91 |     metrics = evaluator.main(
92 |         generation_result_path,
93 |         target_audio_path,
94 |         limit_num=None  # If you only intend to evaluate X (int) pairs of data, set limit_num=X
95 |     )
96 | ```
97 | 
98 | You can use `CUDA_VISIBLE_DEVICES` to specify the GPU/GPUs to use.
99 | 
100 | ```shell
101 | CUDA_VISIBLE_DEVICES=0,1 python3 test.py
102 | ```
103 | 
104 | ## Note
105 | - Update on 29 Sept 2024:
106 |   - **MERT inference:** Note that the MERT model is trained on 24 kHz audio, while this repository runs inference in either 16 kHz or 32 kHz mode. In both modes, the audio is resampled to 24 kHz before being passed to MERT.
107 |   - **FAD calculation:** Due to the implementation we currently use, the FAD calculation is performed only on the first GPU, even in parallel mode.
108 | - Update on 24 June 2023:
109 |   - **Issues on model evaluation:** I found that the PANNs-based Frechet Distance (FD) and KL score are sometimes not as robust as FAD. For example, when the generated audio is all silence, FD and KL can still indicate that the model performs well, while FAD and Inception Score (IS) correctly reflect the model's poor performance. The audio resampling method can also significantly affect FD (±30) and KL (±0.4).
110 |   - To address this issue, in another branch of this repo ([passt_replace_panns](https://github.com/haoheliu/audioldm_eval/tree/passt_replace_panns)), I replaced the PANNs model with PaSST, which I found to be more robust to the resampling method and other trivial mismatches.
111 | 
112 | - **Update on code:** The calculation of FAD is slow. Now, after processing each folder, the code saves the FAD features into an .npy file for later reuse.
113 | 
114 | ## TODO
115 | 
116 | - [ ] Add pretrained AudioLDM model.
117 | - [ ] Add CLAP score 118 | 119 | ## Cite this repo 120 | 121 | If you found this tool useful, please consider citing 122 | ```bibtex 123 | @article{audioldm2-2024taslp, 124 | author={Liu, Haohe and Yuan, Yi and Liu, Xubo and Mei, Xinhao and Kong, Qiuqiang and Tian, Qiao and Wang, Yuping and Wang, Wenwu and Wang, Yuxuan and Plumbley, Mark D.}, 125 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 126 | title={AudioLDM 2: Learning Holistic Audio Generation With Self-Supervised Pretraining}, 127 | year={2024}, 128 | volume={32}, 129 | pages={2871-2883}, 130 | doi={10.1109/TASLP.2024.3399607} 131 | } 132 | 133 | @article{liu2023audioldm, 134 | title={{AudioLDM}: Text-to-Audio Generation with Latent Diffusion Models}, 135 | author={Liu, Haohe and Chen, Zehua and Yuan, Yi and Mei, Xinhao and Liu, Xubo and Mandic, Danilo and Wang, Wenwu and Plumbley, Mark D}, 136 | journal={Proceedings of the International Conference on Machine Learning}, 137 | year={2023} 138 | pages={21450-21474} 139 | } 140 | ``` 141 | 142 | ## Reference 143 | 144 | > https://github.com/toshas/torch-fidelity 145 | 146 | > https://github.com/v-iashin/SpecVQGAN 147 | -------------------------------------------------------------------------------- /audioldm_eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics.fid import calculate_fid 2 | from .metrics.isc import calculate_isc 3 | from .metrics.kid import calculate_kid 4 | from .metrics.kl import calculate_kl 5 | from .eval import EvaluationHelper 6 | from .eval_parallel import EvaluationHelperParallel -------------------------------------------------------------------------------- /audioldm_eval/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # import audio.tools 2 | # import audio.stft 3 | # import audio.audio_processing 4 | from .stft import * 5 | from .audio_processing import * 6 | from .tools import * 7 | -------------------------------------------------------------------------------- /audioldm_eval/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 
39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /audioldm_eval/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audioldm_eval.audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, size=filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 
| 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn(sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax) 146 | mel_basis = torch.from_numpy(mel_basis).float() 147 | self.register_buffer("mel_basis", mel_basis) 148 | 149 | def spectral_normalize(self, magnitudes, normalize_fun): 150 | output = dynamic_range_compression(magnitudes, normalize_fun) 151 | return output 152 | 153 | 
def spectral_de_normalize(self, magnitudes): 154 | output = dynamic_range_decompression(magnitudes) 155 | return output 156 | 157 | def mel_spectrogram(self, y, normalize_fun=torch.log): 158 | """Computes mel-spectrograms from a batch of waves 159 | PARAMS 160 | ------ 161 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 162 | 163 | RETURNS 164 | ------- 165 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 166 | """ 167 | assert torch.min(y.data) >= -1 168 | assert torch.max(y.data) <= 1 169 | 170 | magnitudes, phases = self.stft_fn.transform(y) 171 | magnitudes = magnitudes.data 172 | mel_output = torch.matmul(self.mel_basis, magnitudes) 173 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 174 | energy = torch.norm(magnitudes, dim=1) 175 | 176 | return mel_output, energy 177 | -------------------------------------------------------------------------------- /audioldm_eval/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | import pickle 5 | import json 6 | from audioldm_eval.audio.audio_processing import griffin_lim 7 | 8 | 9 | def save_pickle(obj, fname): 10 | print("Save pickle at " + fname) 11 | with open(fname, "wb") as f: 12 | pickle.dump(obj, f) 13 | 14 | 15 | def load_pickle(fname): 16 | print("Load pickle at " + fname) 17 | with open(fname, "rb") as f: 18 | res = pickle.load(f) 19 | return res 20 | 21 | 22 | def write_json(my_dict, fname): 23 | print("Save json file at " + fname) 24 | json_str = json.dumps(my_dict) 25 | with open(fname, "w") as json_file: 26 | json_file.write(json_str) 27 | 28 | 29 | def load_json(fname): 30 | with open(fname, "r") as f: 31 | data = json.load(f) 32 | return data 33 | 34 | 35 | def get_mel_from_wav(audio, _stft): 36 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 37 | audio = torch.autograd.Variable(audio, requires_grad=False) 38 | melspec, energy = _stft.mel_spectrogram(audio) 39 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 40 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 41 | return melspec, energy 42 | 43 | 44 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 45 | mel = torch.stack([mel]) 46 | mel_decompress = _stft.spectral_de_normalize(mel) 47 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 48 | spec_from_mel_scaling = 1000 49 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 50 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 51 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 52 | 53 | audio = griffin_lim( 54 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters 55 | ) 56 | 57 | audio = audio.squeeze() 58 | audio = audio.cpu().numpy() 59 | audio_path = out_filename 60 | write(audio_path, _stft.sampling_rate, audio) 61 | -------------------------------------------------------------------------------- /audioldm_eval/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/audioldm_eval/8dc07ee7c42f9dc6e295460a1034175a0d49b436/audioldm_eval/datasets/__init__.py -------------------------------------------------------------------------------- /audioldm_eval/datasets/load_mel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import torchaudio 5 | from 
tqdm import tqdm 6 | # import librosa 7 | 8 | def pad_short_audio(audio, min_samples=32000): 9 | if(audio.size(-1) < min_samples): 10 | audio = torch.nn.functional.pad(audio, (0, min_samples - audio.size(-1)), mode='constant', value=0.0) 11 | return audio 12 | 13 | class MelPairedDataset(torch.utils.data.Dataset): 14 | def __init__( 15 | self, 16 | datadir1, 17 | datadir2, 18 | _stft, 19 | sr=16000, 20 | fbin_mean=None, 21 | fbin_std=None, 22 | augment=False, 23 | limit_num=None, 24 | ): 25 | self.datalist1 = [os.path.join(datadir1, x) for x in os.listdir(datadir1)] 26 | self.datalist1 = sorted(self.datalist1) 27 | 28 | self.datalist2 = [os.path.join(datadir2, x) for x in os.listdir(datadir2)] 29 | self.datalist2 = sorted(self.datalist2) 30 | 31 | if limit_num is not None: 32 | self.datalist1 = self.datalist1[:limit_num] 33 | self.datalist2 = self.datalist2[:limit_num] 34 | 35 | self.align_two_file_list() 36 | 37 | self._stft = _stft 38 | self.sr = sr 39 | self.augment = augment 40 | 41 | # if fbin_mean is not None: 42 | # self.fbin_mean = fbin_mean[..., None] 43 | # self.fbin_std = fbin_std[..., None] 44 | # else: 45 | # self.fbin_mean = None 46 | # self.fbin_std = None 47 | 48 | def align_two_file_list(self): 49 | data_dict1 = {os.path.basename(x): x for x in self.datalist1} 50 | data_dict2 = {os.path.basename(x): x for x in self.datalist2} 51 | 52 | keyset1 = set(data_dict1.keys()) 53 | keyset2 = set(data_dict2.keys()) 54 | 55 | intersect_keys = keyset1.intersection(keyset2) 56 | 57 | self.datalist1 = [data_dict1[k] for k in intersect_keys] 58 | self.datalist2 = [data_dict2[k] for k in intersect_keys] 59 | 60 | print("Two path have %s intersection files" % len(intersect_keys)) 61 | 62 | def __getitem__(self, index): 63 | while True: 64 | try: 65 | filename1 = self.datalist1[index] 66 | filename2 = self.datalist2[index] 67 | mel1, _, audio1 = self.get_mel_from_file(filename1) 68 | mel2, _, audio2 = self.get_mel_from_file(filename2) 69 | break 70 | except Exception as e: 71 | print(index, e) 72 | index = (index + 1) % len(self.datalist) 73 | 74 | # if(self.fbin_mean is not None): 75 | # mel = (mel - self.fbin_mean) / self.fbin_std 76 | min_len = min(mel1.shape[-1], mel2.shape[-1]) 77 | return ( 78 | mel1[..., :min_len], 79 | mel2[..., :min_len], 80 | os.path.basename(filename1), 81 | (audio1, audio2), 82 | ) 83 | 84 | def __len__(self): 85 | return len(self.datalist1) 86 | 87 | def get_mel_from_file(self, audio_file): 88 | audio, file_sr = torchaudio.load(audio_file) 89 | # Only use the first channel 90 | audio = audio[0:1,...] 
91 | audio = audio - audio.mean() 92 | if file_sr != self.sr: 93 | audio = torchaudio.functional.resample( 94 | audio, orig_freq=file_sr, new_freq=self.sr 95 | ) 96 | 97 | if self._stft is not None: 98 | melspec, energy = self.get_mel_from_wav(audio[0, ...]) 99 | else: 100 | melspec, energy = None, None 101 | 102 | return melspec, energy, audio 103 | 104 | def get_mel_from_wav(self, audio): 105 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 106 | audio = torch.autograd.Variable(audio, requires_grad=False) 107 | 108 | # ========================================================================= 109 | # Following the processing in https://github.com/v-iashin/SpecVQGAN/blob/5bc54f30eb89f82d129aa36ae3f1e90b60e73952/vocoder/mel2wav/extract_mel_spectrogram.py#L141 110 | melspec, energy = self._stft.mel_spectrogram(audio, normalize_fun=torch.log10) 111 | melspec = (melspec * 20) - 20 112 | melspec = (melspec + 100) / 100 113 | melspec = torch.clip(melspec, min=0, max=1.0) 114 | # ========================================================================= 115 | # Augment 116 | # if(self.augment): 117 | # for i in range(1): 118 | # random_start = int(torch.rand(1) * 950) 119 | # melspec[0,:,random_start:random_start+50] = 0.0 120 | # ========================================================================= 121 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 122 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 123 | return melspec, energy 124 | 125 | 126 | class WaveDataset(torch.utils.data.Dataset): 127 | def __init__( 128 | self, 129 | datadir, 130 | sr=16000, 131 | limit_num=None, 132 | ): 133 | self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir)] 134 | self.datalist = sorted(self.datalist) 135 | if limit_num is not None: 136 | self.datalist = self.datalist[:limit_num] 137 | self.sr = sr 138 | 139 | def __getitem__(self, index): 140 | while True: 141 | try: 142 | filename = self.datalist[index] 143 | waveform = self.read_from_file(filename) 144 | if waveform.size(-1) < 1: 145 | raise ValueError("empty file %s" % filename) 146 | break 147 | except Exception as e: 148 | print(index, e) 149 | index = (index + 1) % len(self.datalist) 150 | 151 | return waveform, os.path.basename(filename) 152 | 153 | def __len__(self): 154 | return len(self.datalist) 155 | 156 | def read_from_file(self, audio_file): 157 | audio, file_sr = torchaudio.load(audio_file) 158 | # Only use the first channel 159 | audio = audio[0:1,...] 
160 |         audio = audio - audio.mean()
161 | 
162 |         # if file_sr != self.sr and file_sr == 32000 and self.sr == 16000:
163 |         #     audio = audio[..., ::2]
164 |         # if file_sr != self.sr and file_sr == 48000 and self.sr == 16000:
165 |         #     audio = audio[..., ::3]
166 |         # el
167 | 
168 |         if file_sr != self.sr:
169 |             audio = torchaudio.functional.resample(
170 |                 audio, orig_freq=file_sr, new_freq=self.sr,  # rolloff=0.95, lowpass_filter_width=16
171 |             )
172 |             # audio = torch.FloatTensor(librosa.resample(audio.numpy(), file_sr, self.sr))
173 | 
174 |         audio = pad_short_audio(audio, min_samples=32000)
175 |         return audio
176 | 
177 | def load_npy_data(loader):
178 |     new_train = []
179 |     for waveform, filename in tqdm(loader):  # assumes the loader yields (waveform, filename), as WaveDataset does
180 |         batch = waveform.float().numpy()
181 |         new_train.append(
182 |             batch.reshape(
183 |                 -1,
184 |             )
185 |         )
186 |     new_train = np.array(new_train)
187 |     return new_train
188 | 
189 | 
190 | if __name__ == "__main__":
191 |     path = "/scratch/combined/result/ground/00294 harvest festival rumour 1_mel.npy"
192 |     temp = np.load(path)
193 |     print("temp", temp.shape)
194 | 
--------------------------------------------------------------------------------
/audioldm_eval/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from specvqgan.modules.losses.vggishish.transforms import Crop
3 | 
4 | 
5 | class FromMinusOneOneToZeroOne(object):
6 |     """Actually, it does not do [-1, 1] --> [0, 1] as promised. It would if the inputs were in [-1, 1],
7 |     but reconstructed specs are not."""
8 | 
9 |     def __call__(self, item):
10 |         item["image"] = (item["image"] + 1) / 2
11 |         return item
12 | 
13 | 
14 | class CropNoDict(Crop):
15 |     def __init__(self, cropped_shape, random_crop=None):
16 |         super().__init__(cropped_shape=cropped_shape, random_crop=random_crop)
17 | 
18 |     def __call__(self, x):
19 |         # albumentations expects an ndarray of size (H, W, ...) but we have a tensor of size (B, H, W).
20 |         # We will assume that the batch dim (B) is our "channel" dim and permute it to the end.
21 |         # Finally, we change the type back to torch.Tensor.
22 | x = self.preprocessor(image=x.permute(1, 2, 0).numpy())["image"].transpose( 23 | 2, 0, 1 24 | ) 25 | return torch.from_numpy(x) 26 | 27 | 28 | class GetInputFromBatchByKey(object): # get image from item dict 29 | def __init__(self, input_key): 30 | self.input_key = input_key 31 | 32 | def __call__(self, item): 33 | return item[self.input_key] 34 | 35 | 36 | class ToFloat32(object): 37 | def __call__(self, item): 38 | return item.float() 39 | -------------------------------------------------------------------------------- /audioldm_eval/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from audioldm_eval.datasets.load_mel import load_npy_data, MelPairedDataset, WaveDataset 3 | import numpy as np 4 | import argparse 5 | import datetime 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | import torchaudio.transforms as T 10 | from transformers import Wav2Vec2Processor, AutoModel, Wav2Vec2FeatureExtractor 11 | from audioldm_eval.metrics.fad import FrechetAudioDistance 12 | from audioldm_eval import calculate_fid, calculate_isc, calculate_kid, calculate_kl 13 | from skimage.metrics import peak_signal_noise_ratio as psnr 14 | from skimage.metrics import structural_similarity as ssim 15 | from audioldm_eval.feature_extractors.panns import Cnn14 16 | from audioldm_eval.audio.tools import save_pickle, load_pickle, write_json, load_json 17 | from ssr_eval.metrics import AudioMetrics 18 | import audioldm_eval.audio as Audio 19 | 20 | class EvaluationHelper: 21 | def __init__(self, sampling_rate, device, backbone="mert") -> None: 22 | 23 | self.device = device 24 | self.backbone = backbone 25 | self.sampling_rate = sampling_rate 26 | self.frechet = FrechetAudioDistance( 27 | use_pca=False, 28 | use_activation=False, 29 | verbose=True, 30 | ) 31 | 32 | # self.passt_model = get_basic_model(mode="logits") 33 | # self.passt_model.eval() 34 | # self.passt_model.to(self.device) 35 | 36 | # self.lsd_metric = AudioMetrics(self.sampling_rate) 37 | self.frechet.model = self.frechet.model.to(device) 38 | 39 | features_list = ["2048", "logits"] 40 | 41 | if self.backbone == "mert": 42 | self.mel_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) 43 | self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M",trust_remote_code=True) 44 | self.target_sample_rate = self.processor.sampling_rate 45 | self.resampler = T.Resample(orig_freq=self.sampling_rate, new_freq=self.target_sample_rate).to(self.device) 46 | elif self.backbone == "cnn14": 47 | if self.sampling_rate == 16000: 48 | self.mel_model = Cnn14( 49 | features_list=features_list, 50 | sample_rate=16000, 51 | window_size=512, 52 | hop_size=160, 53 | mel_bins=64, 54 | fmin=50, 55 | fmax=8000, 56 | classes_num=527, 57 | ) 58 | elif self.sampling_rate == 32000: 59 | self.mel_model = Cnn14( 60 | features_list=features_list, 61 | sample_rate=32000, 62 | window_size=1024, 63 | hop_size=320, 64 | mel_bins=64, 65 | fmin=50, 66 | fmax=14000, 67 | classes_num=527, 68 | ) 69 | else: 70 | raise ValueError( 71 | "We only support the evaluation on 16kHz and 32kHz sampling rate for CNN14." 
72 | ) 73 | else: 74 | raise ValueError("Backbone not supported") 75 | 76 | if self.sampling_rate == 16000: 77 | self._stft = Audio.TacotronSTFT(512, 160, 512, 64, 16000, 50, 8000) 78 | elif self.sampling_rate == 32000: 79 | self._stft = Audio.TacotronSTFT(1024, 320, 1024, 64, 32000, 50, 14000) 80 | else: 81 | raise ValueError( 82 | "We only support the evaluation on 16kHz and 32kHz sampling rate." 83 | ) 84 | 85 | self.mel_model.eval() 86 | self.mel_model.to(self.device) 87 | self.fbin_mean, self.fbin_std = None, None 88 | 89 | def main( 90 | self, 91 | generate_files_path, 92 | groundtruth_path, 93 | limit_num=None, 94 | ): 95 | print("Generted files", generate_files_path) 96 | print("Target files", groundtruth_path) 97 | 98 | self.file_init_check(generate_files_path) 99 | self.file_init_check(groundtruth_path) 100 | 101 | same_name = self.get_filename_intersection_ratio( 102 | generate_files_path, groundtruth_path, limit_num=limit_num 103 | ) 104 | 105 | metrics = self.calculate_metrics(generate_files_path, groundtruth_path, same_name, limit_num) # recalculate = True 106 | 107 | return metrics 108 | 109 | def file_init_check(self, dir): 110 | assert os.path.exists(dir), "The path does not exist %s" % dir 111 | assert len(os.listdir(dir)) > 1, "There is no files in %s" % dir 112 | 113 | def get_filename_intersection_ratio( 114 | self, dir1, dir2, threshold=0.99, limit_num=None 115 | ): 116 | self.datalist1 = [os.path.join(dir1, x) for x in os.listdir(dir1)] 117 | self.datalist1 = sorted(self.datalist1) 118 | 119 | self.datalist2 = [os.path.join(dir2, x) for x in os.listdir(dir2)] 120 | self.datalist2 = sorted(self.datalist2) 121 | 122 | data_dict1 = {os.path.basename(x): x for x in self.datalist1} 123 | data_dict2 = {os.path.basename(x): x for x in self.datalist2} 124 | 125 | keyset1 = set(data_dict1.keys()) 126 | keyset2 = set(data_dict2.keys()) 127 | 128 | intersect_keys = keyset1.intersection(keyset2) 129 | if ( 130 | len(intersect_keys) / len(keyset1) > threshold 131 | and len(intersect_keys) / len(keyset2) > threshold 132 | ): 133 | print( 134 | "+Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=True" 135 | % (len(intersect_keys), len(keyset1), len(keyset2)) 136 | ) 137 | return True 138 | else: 139 | print( 140 | "-Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=False" 141 | % (len(intersect_keys), len(keyset1), len(keyset2)) 142 | ) 143 | return False 144 | 145 | def calculate_lsd(self, pairedloader, same_name=True, time_offset=160 * 7): 146 | if same_name == False: 147 | return { 148 | "lsd": -1, 149 | "ssim_stft": -1, 150 | } 151 | print("Calculating LSD using a time offset of %s ..." 
% time_offset) 152 | lsd_avg = [] 153 | ssim_stft_avg = [] 154 | for _, _, filename, (audio1, audio2) in tqdm(pairedloader): 155 | audio1 = audio1.cpu().numpy()[0, 0] 156 | audio2 = audio2.cpu().numpy()[0, 0] 157 | 158 | # If you use HIFIGAN (verified on 2023-01-12), you need seven frames' offset 159 | audio1 = audio1[time_offset:] 160 | 161 | audio1 = audio1 - np.mean(audio1) 162 | audio2 = audio2 - np.mean(audio2) 163 | 164 | audio1 = audio1 / np.max(np.abs(audio1)) 165 | audio2 = audio2 / np.max(np.abs(audio2)) 166 | 167 | min_len = min(audio1.shape[0], audio2.shape[0]) 168 | 169 | audio1, audio2 = audio1[:min_len], audio2[:min_len] 170 | 171 | result = self.lsd(audio1, audio2) 172 | 173 | lsd_avg.append(result["lsd"]) 174 | ssim_stft_avg.append(result["ssim"]) 175 | 176 | return {"lsd": np.mean(lsd_avg), "ssim_stft": np.mean(ssim_stft_avg)} 177 | 178 | def lsd(self, audio1, audio2): 179 | result = self.lsd_metric.evaluation(audio1, audio2, None) 180 | return result 181 | 182 | def calculate_psnr_ssim(self, pairedloader, same_name=True): 183 | if same_name == False: 184 | return {"psnr": -1, "ssim": -1} 185 | psnr_avg = [] 186 | ssim_avg = [] 187 | for mel_gen, mel_target, filename, _ in tqdm(pairedloader): 188 | mel_gen = mel_gen.cpu().numpy()[0] 189 | mel_target = mel_target.cpu().numpy()[0] 190 | psnrval = psnr(mel_gen, mel_target) 191 | if np.isinf(psnrval): 192 | print("Infinite value encountered in psnr %s " % filename) 193 | continue 194 | psnr_avg.append(psnrval) 195 | data_range = max(np.max(mel_gen), np.max(mel_target)) - min(np.min(mel_gen), np.min(mel_target)) 196 | ssim_avg.append(ssim(mel_gen, mel_target, data_range=data_range)) 197 | return {"psnr": np.mean(psnr_avg), "ssim": np.mean(ssim_avg)} 198 | 199 | def calculate_metrics(self, generate_files_path, groundtruth_path, same_name, limit_num=None, calculate_psnr_ssim=False, calculate_lsd=False, recalculate=False): 200 | # Generation, target 201 | torch.manual_seed(0) 202 | 203 | num_workers = 6 204 | 205 | outputloader = DataLoader( 206 | WaveDataset( 207 | generate_files_path, 208 | self.sampling_rate, # TODO 209 | # 32000, 210 | limit_num=limit_num, 211 | ), 212 | batch_size=1, 213 | sampler=None, 214 | num_workers=num_workers, 215 | ) 216 | 217 | resultloader = DataLoader( 218 | WaveDataset( 219 | groundtruth_path, 220 | self.sampling_rate, # TODO 221 | # 32000, 222 | limit_num=limit_num, 223 | ), 224 | batch_size=1, 225 | sampler=None, 226 | num_workers=num_workers, 227 | ) 228 | 229 | out = {} 230 | 231 | # FAD 232 | ###################################################################################################################### 233 | if(recalculate): 234 | print("Calculate FAD score from scratch") 235 | fad_score = self.frechet.score(generate_files_path, groundtruth_path, limit_num=limit_num, recalculate=recalculate) 236 | out.update(fad_score) 237 | print("FAD: %s" % fad_score) 238 | ###################################################################################################################### 239 | 240 | # PANNs or PassT 241 | ###################################################################################################################### 242 | cache_path = groundtruth_path + "classifier_logits_feature_cache.pkl" 243 | if(os.path.exists(cache_path) and not recalculate): 244 | print("reload", cache_path) 245 | featuresdict_2 = load_pickle(cache_path) 246 | else: 247 | print("Extracting features from %s." 
% groundtruth_path) 248 | featuresdict_2 = self.get_featuresdict(resultloader) 249 | save_pickle(featuresdict_2, cache_path) 250 | 251 | cache_path = generate_files_path + "classifier_logits_feature_cache.pkl" 252 | if(os.path.exists(cache_path) and not recalculate): 253 | print("reload", cache_path) 254 | featuresdict_1 = load_pickle(cache_path) 255 | else: 256 | print("Extracting features from %s." % generate_files_path) 257 | featuresdict_1 = self.get_featuresdict(outputloader) 258 | save_pickle(featuresdict_1, cache_path) 259 | 260 | metric_kl, kl_ref, paths_1 = calculate_kl( 261 | featuresdict_1, featuresdict_2, "logits", same_name 262 | ) 263 | 264 | out.update(metric_kl) 265 | 266 | metric_isc = calculate_isc( 267 | featuresdict_1, 268 | feat_layer_name="logits", 269 | splits=10, 270 | samples_shuffle=True, 271 | rng_seed=2020, 272 | ) 273 | out.update(metric_isc) 274 | 275 | if("2048" in featuresdict_1.keys() and "2048" in featuresdict_2.keys()): 276 | metric_fid = calculate_fid( 277 | featuresdict_1, featuresdict_2, feat_layer_name="2048" 278 | ) 279 | out.update(metric_fid) 280 | 281 | # Metrics for Autoencoder 282 | ###################################################################################################################### 283 | if(calculate_psnr_ssim or calculate_lsd): 284 | pairedloader = DataLoader( 285 | MelPairedDataset( 286 | generate_files_path, 287 | groundtruth_path, 288 | self._stft, 289 | self.sampling_rate, 290 | self.fbin_mean, 291 | self.fbin_std, 292 | limit_num=limit_num, 293 | ), 294 | batch_size=1, 295 | sampler=None, 296 | num_workers=16, 297 | ) 298 | 299 | if(calculate_lsd): 300 | metric_lsd = self.calculate_lsd(pairedloader, same_name=same_name) 301 | out.update(metric_lsd) 302 | 303 | if(calculate_psnr_ssim): 304 | metric_psnr_ssim = self.calculate_psnr_ssim(pairedloader, same_name=same_name) 305 | out.update(metric_psnr_ssim) 306 | 307 | # metric_kid = calculate_kid( 308 | # featuresdict_1, 309 | # featuresdict_2, 310 | # feat_layer_name="2048", 311 | # subsets=100, 312 | # subset_size=1000, 313 | # degree=3, 314 | # gamma=None, 315 | # coef0=1, 316 | # rng_seed=2020, 317 | # ) 318 | # out.update(metric_kid) 319 | 320 | print("\n".join((f"{k}: {v:.7f}" for k, v in out.items()))) 321 | print("\n") 322 | print(limit_num) 323 | print( 324 | f'KL_Sigmoid: {out.get("kullback_leibler_divergence_sigmoid", float("nan")):8.5f};', 325 | f'KL: {out.get("kullback_leibler_divergence_softmax", float("nan")):8.5f};', 326 | f'PSNR: {out.get("psnr", float("nan")):.5f}', 327 | f'SSIM: {out.get("ssim", float("nan")):.5f}', 328 | f'ISc: {out.get("inception_score_mean", float("nan")):8.5f} ({out.get("inception_score_std", float("nan")):5f});', 329 | f'KID: {out.get("kernel_inception_distance_mean", float("nan")):.5f}', 330 | f'({out.get("kernel_inception_distance_std", float("nan")):.5f})', 331 | f'FD: {out.get("frechet_distance", float("nan")):8.5f};', 332 | f'FAD: {out.get("frechet_audio_distance", float("nan")):.5f}', 333 | f'LSD: {out.get("lsd", float("nan")):.5f}', 334 | # f'SSIM_STFT: {out.get("ssim_stft", float("nan")):.5f}', 335 | ) 336 | result = { 337 | "frechet_distance": out.get("frechet_distance", float("nan")), 338 | "frechet_audio_distance": out.get("frechet_audio_distance", float("nan")), 339 | "kullback_leibler_divergence_sigmoid": out.get( 340 | "kullback_leibler_divergence_sigmoid", float("nan") 341 | ), 342 | "kullback_leibler_divergence_softmax": out.get( 343 | "kullback_leibler_divergence_softmax", float("nan") 344 | ), 345 | "lsd": 
out.get("lsd", float("nan")), 346 | "psnr": out.get("psnr", float("nan")), 347 | "ssim": out.get("ssim", float("nan")), 348 | # "ssim_stft": out.get("ssim_stft", float("nan")), 349 | "inception_score_mean": out.get("inception_score_mean", float("nan")), 350 | "inception_score_std": out.get("inception_score_std", float("nan")), 351 | "kernel_inception_distance_mean": out.get( 352 | "kernel_inception_distance_mean", float("nan") 353 | ), 354 | "kernel_inception_distance_std": out.get( 355 | "kernel_inception_distance_std", float("nan") 356 | ), 357 | } 358 | 359 | json_path = os.path.join(os.path.dirname(generate_files_path), self.get_current_time()+"_"+os.path.basename(generate_files_path) + ".json") 360 | write_json(result, json_path) 361 | return result 362 | 363 | def get_current_time(self): 364 | now = datetime.datetime.now() 365 | return now.strftime("%Y-%m-%d-%H:%M:%S") 366 | 367 | def get_featuresdict(self, dataloader): 368 | out = None 369 | out_meta = None 370 | 371 | # transforms=StandardNormalizeAudio() 372 | for waveform, filename in tqdm(dataloader): 373 | try: 374 | metadict = { 375 | "file_path_": filename, 376 | } 377 | waveform = waveform.squeeze(1) 378 | 379 | # batch = transforms(batch) 380 | waveform = waveform.float().to(self.device) 381 | 382 | # featuresdict = {} 383 | # with torch.no_grad(): 384 | # if(waveform.size(-1) >= 320000): 385 | # waveform = waveform[...,:320000] 386 | # else: 387 | # waveform = torch.nn.functional.pad(waveform, (0,320000-waveform.size(-1))) 388 | # featuresdict["logits"] = self.passt_model(waveform) 389 | 390 | with torch.no_grad(): 391 | if self.backbone == "mert": 392 | waveform = self.resampler(waveform[0]) 393 | mert_input = self.processor(waveform, sampling_rate=self.target_sample_rate, return_tensors="pt").to(self.device) 394 | mert_output = self.mel_model(**mert_input, output_hidden_states=True) 395 | time_reduced_hidden_states = torch.stack(mert_output.hidden_states).squeeze().mean(dim=1) 396 | featuresdict = {"2048": time_reduced_hidden_states.cpu(), "logits": time_reduced_hidden_states.cpu()} 397 | elif self.backbone == "cnn14": 398 | featuresdict = self.mel_model(waveform) 399 | 400 | featuresdict = {k: [v.cpu()] for k, v in featuresdict.items()} 401 | 402 | if out is None: 403 | out = featuresdict 404 | else: 405 | out = {k: out[k] + featuresdict[k] for k in out.keys()} 406 | 407 | if out_meta is None: 408 | out_meta = metadict 409 | else: 410 | out_meta = {k: out_meta[k] + metadict[k] for k in out_meta.keys()} 411 | except Exception as e: 412 | import ipdb 413 | 414 | ipdb.set_trace() 415 | print("Classifier Inference error: ", e) 416 | continue 417 | 418 | out = {k: torch.cat(v, dim=0) for k, v in out.items()} 419 | return {**out, **out_meta} 420 | 421 | def sample_from(self, samples, number_to_use): 422 | assert samples.shape[0] >= number_to_use 423 | rand_order = np.random.permutation(samples.shape[0]) 424 | return samples[rand_order[: samples.shape[0]], :] 425 | 426 | 427 | if __name__ == "__main__": 428 | import yaml 429 | import argparse 430 | from audioldm_eval import EvaluationHelper 431 | import torch 432 | 433 | parser = argparse.ArgumentParser() 434 | 435 | parser.add_argument( 436 | "-g", 437 | "--generation_result_path", 438 | type=str, 439 | required=False, 440 | help="Audio sampling rate during evaluation", 441 | default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/balanced_train_segments", 442 | ) 443 | 444 | parser.add_argument( 445 | "-t", 446 | "--target_audio_path", 447 | type=str, 448 | 
required=False, 449 | help="Audio sampling rate during evaluation", 450 | default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/eval_segments", 451 | ) 452 | 453 | parser.add_argument( 454 | "-sr", 455 | "--sampling_rate", 456 | type=int, 457 | required=False, 458 | help="Audio sampling rate during evaluation", 459 | default=16000, 460 | ) 461 | 462 | parser.add_argument( 463 | "-l", 464 | "--limit_num", 465 | type=int, 466 | required=False, 467 | help="Audio clip numbers limit for evaluation", 468 | default=None, 469 | ) 470 | 471 | args = parser.parse_args() 472 | 473 | device = torch.device(f"cuda:{0}") 474 | 475 | evaluator = EvaluationHelper(args.sampling_rate, device) 476 | 477 | metrics = evaluator.main( 478 | args.generation_result_path, 479 | args.target_audio_path, 480 | limit_num=args.limit_num, 481 | same_name=args.same_name, 482 | ) 483 | 484 | print(metrics) 485 | -------------------------------------------------------------------------------- /audioldm_eval/eval_parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | import numpy as np 5 | import torch.multiprocessing as mp 6 | import torchaudio.transforms as T 7 | from transformers import Wav2Vec2Processor, AutoModel, Wav2Vec2FeatureExtractor 8 | from torch.utils.data import DataLoader, DistributedSampler 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | import torch.distributed as dist 11 | from audioldm_eval.datasets.load_mel import WaveDataset 12 | from audioldm_eval.metrics.fad import FrechetAudioDistance 13 | from audioldm_eval import calculate_fid, calculate_isc, calculate_kid, calculate_kl 14 | from audioldm_eval.feature_extractors.panns import Cnn14 15 | from audioldm_eval.audio.tools import save_pickle, load_pickle, write_json, load_json 16 | from tqdm import tqdm 17 | 18 | class EvaluationHelperParallel: 19 | def __init__(self, sampling_rate, num_gpus, batch_size=1, backbone="mert") -> None: 20 | self.sampling_rate = sampling_rate 21 | self.num_gpus = num_gpus 22 | self.backbone = backbone 23 | self.batch_size = batch_size 24 | 25 | def setup(self, rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '12355' 28 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 29 | torch.cuda.set_device(rank) 30 | 31 | def cleanup(self): 32 | dist.destroy_process_group() 33 | 34 | def init_models(self, rank): 35 | self.device = torch.device(f"cuda:{rank}") 36 | self.frechet = FrechetAudioDistance(use_pca=False, use_activation=False, verbose=True) 37 | self.frechet.model = self.frechet.model.to(self.device) 38 | 39 | features_list = ["2048", "logits"] 40 | 41 | if self.backbone == "mert": 42 | self.mel_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) 43 | self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M",trust_remote_code=True) 44 | self.target_sample_rate = self.processor.sampling_rate 45 | self.resampler = T.Resample(orig_freq=self.sampling_rate, new_freq=self.target_sample_rate).to(self.device) 46 | 47 | elif self.backbone == "cnn14": 48 | if self.sampling_rate == 16000: 49 | self.mel_model = Cnn14( 50 | features_list=features_list, 51 | sample_rate=16000, 52 | window_size=512, 53 | hop_size=160, 54 | mel_bins=64, 55 | fmin=50, 56 | fmax=8000, 57 | classes_num=527, 58 | ) 59 | elif self.sampling_rate == 32000: 60 | self.mel_model = Cnn14( 61 | features_list=features_list, 62 | 
sample_rate=32000, 63 | window_size=1024, 64 | hop_size=320, 65 | mel_bins=64, 66 | fmin=50, 67 | fmax=14000, 68 | classes_num=527, 69 | ) 70 | else: 71 | raise ValueError("We only support the evaluation on 16kHz and 32kHz sampling rate.") 72 | 73 | else: 74 | raise ValueError("Backbone not supported") 75 | 76 | self.mel_model = DDP(self.mel_model.to(self.device), device_ids=[rank]) 77 | self.mel_model.eval() 78 | 79 | def main(self, generate_files_path, groundtruth_path, limit_num=None): 80 | mp.spawn(self.run, args=(generate_files_path, groundtruth_path, limit_num), nprocs=self.num_gpus, join=True) 81 | 82 | def get_featuresdict(self, rank, dataloader): 83 | out = None 84 | out_meta = {"file_path_": []} 85 | 86 | for waveform, filename in dataloader: 87 | try: 88 | waveform = waveform.squeeze(1).float().to(self.device) 89 | 90 | with torch.no_grad(): 91 | if self.backbone == "mert": 92 | waveform = self.resampler(waveform[0]) 93 | mert_input = self.processor(waveform, sampling_rate=self.target_sample_rate, return_tensors="pt").to(self.device) 94 | mert_output = self.mel_model(**mert_input, output_hidden_states=True) 95 | time_reduced_hidden_states = torch.stack(mert_output.hidden_states).squeeze().mean(dim=1) 96 | featuresdict = {"2048": time_reduced_hidden_states, "logits": time_reduced_hidden_states} 97 | elif self.backbone == "cnn14": 98 | featuresdict = self.mel_model(waveform) 99 | 100 | featuresdict = {k: v for k, v in featuresdict.items()} 101 | 102 | if out is None: 103 | out = featuresdict 104 | else: 105 | out = {k: torch.cat([out[k], featuresdict[k]], dim=0) for k in out.keys()} 106 | 107 | out_meta["file_path_"].extend(filename) 108 | except Exception as e: 109 | print(f"Classifier Inference error on rank {rank}: ", e) 110 | continue 111 | 112 | return out, out_meta 113 | 114 | def gather_features(self, featuresdict, metadict): 115 | all_features = {} 116 | for k, v in featuresdict.items(): 117 | gathered = [torch.zeros_like(v) for _ in range(self.num_gpus)] 118 | dist.all_gather(gathered, v) 119 | all_features[k] = torch.cat(gathered, dim=0) 120 | 121 | all_meta = {} 122 | for k, v in metadict.items(): 123 | gathered = [None for _ in range(self.num_gpus)] 124 | dist.all_gather_object(gathered, v) 125 | all_meta[k] = sum(gathered, []) 126 | 127 | return {**all_features, **all_meta} 128 | 129 | def run(self, rank, generate_files_path, groundtruth_path, limit_num): 130 | self.setup(rank, self.num_gpus) 131 | self.init_models(rank) 132 | 133 | same_name = self.get_filename_intersection_ratio(generate_files_path, groundtruth_path, limit_num=limit_num) 134 | 135 | metrics = self.calculate_metrics(rank, generate_files_path, groundtruth_path, same_name, limit_num) # recalculate = True 136 | 137 | if rank == 0: 138 | print("\n".join((f"{k}: {v:.7f}" for k, v in metrics.items()))) 139 | json_path = os.path.join(os.path.dirname(generate_files_path), f"{self.get_current_time()}_{os.path.basename(generate_files_path)}.json") 140 | write_json(metrics, json_path) 141 | 142 | self.cleanup() 143 | 144 | def calculate_metrics(self, rank, generate_files_path, groundtruth_path, same_name, limit_num=None, calculate_psnr_ssim=False, calculate_lsd=False, recalculate=False): 145 | torch.manual_seed(0) 146 | num_workers = 6 147 | 148 | output_dataset = WaveDataset(generate_files_path, self.sampling_rate, limit_num=limit_num) 149 | result_dataset = WaveDataset(groundtruth_path, self.sampling_rate, limit_num=limit_num) 150 | 151 | output_sampler = DistributedSampler(output_dataset, 
num_replicas=self.num_gpus, rank=rank) 152 | result_sampler = DistributedSampler(result_dataset, num_replicas=self.num_gpus, rank=rank) 153 | 154 | outputloader = DataLoader( 155 | output_dataset, 156 | batch_size=self.batch_size, 157 | sampler=output_sampler, 158 | num_workers=num_workers, 159 | ) 160 | 161 | resultloader = DataLoader( 162 | result_dataset, 163 | batch_size=self.batch_size, 164 | sampler=result_sampler, 165 | num_workers=num_workers, 166 | ) 167 | 168 | out = {} 169 | if rank == 0: 170 | if(recalculate): 171 | print("Calculate FAD score from scratch") 172 | fad_score = self.frechet.score(generate_files_path, groundtruth_path, limit_num=limit_num, recalculate=recalculate) 173 | out.update(fad_score) 174 | print("FAD: %s" % fad_score) 175 | 176 | cache_path = generate_files_path + "classifier_logits_feature_cache.pkl" 177 | if os.path.exists(cache_path) and not recalculate: 178 | print("reload", cache_path) 179 | all_featuresdict_1 = load_pickle(cache_path) 180 | else: 181 | print(f"Extracting features from {generate_files_path}.") 182 | featuresdict_1, metadict_1 = self.get_featuresdict(rank, outputloader) 183 | all_featuresdict_1 = self.gather_features(featuresdict_1, metadict_1) 184 | if rank == 0: 185 | save_pickle(all_featuresdict_1, cache_path) 186 | 187 | cache_path = groundtruth_path + "classifier_logits_feature_cache.pkl" 188 | if os.path.exists(cache_path) and not recalculate: 189 | print("reload", cache_path) 190 | all_featuresdict_2 = load_pickle(cache_path) 191 | else: 192 | print(f"Extracting features from {groundtruth_path}.") 193 | featuresdict_2, metadict_2 = self.get_featuresdict(rank, resultloader) 194 | all_featuresdict_2 = self.gather_features(featuresdict_2, metadict_2) 195 | if rank == 0: 196 | save_pickle(all_featuresdict_2, cache_path) 197 | 198 | if rank == 0: 199 | for k, v in all_featuresdict_1.items(): 200 | if isinstance(v, torch.Tensor): 201 | all_featuresdict_1[k] = v.cpu() 202 | for k, v in all_featuresdict_2.items(): 203 | if isinstance(v, torch.Tensor): 204 | all_featuresdict_2[k] = v.cpu() 205 | 206 | metric_kl, _, _ = calculate_kl(all_featuresdict_1, all_featuresdict_2, "logits", same_name) 207 | out.update(metric_kl) 208 | 209 | metric_isc = calculate_isc(all_featuresdict_1, feat_layer_name="logits", splits=10, samples_shuffle=True, rng_seed=2020) 210 | out.update(metric_isc) 211 | 212 | if "2048" in all_featuresdict_1.keys() and "2048" in all_featuresdict_2.keys(): 213 | metric_fid = calculate_fid(all_featuresdict_1, all_featuresdict_2, feat_layer_name="2048") 214 | out.update(metric_fid) 215 | 216 | if(calculate_psnr_ssim or calculate_lsd): 217 | pairedloader = DataLoader( 218 | MelPairedDataset( 219 | generate_files_path, 220 | groundtruth_path, 221 | self._stft, 222 | self.sampling_rate, 223 | self.fbin_mean, 224 | self.fbin_std, 225 | limit_num=limit_num, 226 | ), 227 | batch_size=self.batch_size, 228 | sampler=None, 229 | num_workers=16, 230 | ) 231 | 232 | if(calculate_lsd): 233 | metric_lsd = self.calculate_lsd(pairedloader, same_name=same_name) 234 | out.update(metric_lsd) 235 | 236 | if(calculate_psnr_ssim): 237 | metric_psnr_ssim = self.calculate_psnr_ssim(pairedloader, same_name=same_name) 238 | out.update(metric_psnr_ssim) 239 | 240 | dist.barrier() 241 | return out 242 | 243 | def file_init_check(self, dir): 244 | assert os.path.exists(dir), "The path does not exist %s" % dir 245 | assert len(os.listdir(dir)) > 1, "There is no files in %s" % dir 246 | 247 | def get_filename_intersection_ratio( 248 | self, dir1, dir2, 
threshold=0.99, limit_num=None 249 | ): 250 | self.datalist1 = [os.path.join(dir1, x) for x in os.listdir(dir1)] 251 | self.datalist1 = sorted(self.datalist1) 252 | 253 | self.datalist2 = [os.path.join(dir2, x) for x in os.listdir(dir2)] 254 | self.datalist2 = sorted(self.datalist2) 255 | 256 | data_dict1 = {os.path.basename(x): x for x in self.datalist1} 257 | data_dict2 = {os.path.basename(x): x for x in self.datalist2} 258 | 259 | keyset1 = set(data_dict1.keys()) 260 | keyset2 = set(data_dict2.keys()) 261 | 262 | intersect_keys = keyset1.intersection(keyset2) 263 | if ( 264 | len(intersect_keys) / len(keyset1) > threshold 265 | and len(intersect_keys) / len(keyset2) > threshold 266 | ): 267 | print( 268 | "+Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=True" 269 | % (len(intersect_keys), len(keyset1), len(keyset2)) 270 | ) 271 | return True 272 | else: 273 | print( 274 | "-Two path have %s intersection files out of total %s & %s files. Processing two folder with same_name=False" 275 | % (len(intersect_keys), len(keyset1), len(keyset2)) 276 | ) 277 | return False 278 | 279 | def calculate_lsd(self, pairedloader, same_name=True, time_offset=160 * 7): 280 | if same_name == False: 281 | return { 282 | "lsd": -1, 283 | "ssim_stft": -1, 284 | } 285 | print("Calculating LSD using a time offset of %s ..." % time_offset) 286 | lsd_avg = [] 287 | ssim_stft_avg = [] 288 | for _, _, filename, (audio1, audio2) in tqdm(pairedloader): 289 | audio1 = audio1.cpu().numpy()[0, 0] 290 | audio2 = audio2.cpu().numpy()[0, 0] 291 | 292 | # If you use HIFIGAN (verified on 2023-01-12), you need seven frames' offset 293 | audio1 = audio1[time_offset:] 294 | 295 | audio1 = audio1 - np.mean(audio1) 296 | audio2 = audio2 - np.mean(audio2) 297 | 298 | audio1 = audio1 / np.max(np.abs(audio1)) 299 | audio2 = audio2 / np.max(np.abs(audio2)) 300 | 301 | min_len = min(audio1.shape[0], audio2.shape[0]) 302 | 303 | audio1, audio2 = audio1[:min_len], audio2[:min_len] 304 | 305 | result = self.lsd(audio1, audio2) 306 | 307 | lsd_avg.append(result["lsd"]) 308 | ssim_stft_avg.append(result["ssim"]) 309 | 310 | return {"lsd": np.mean(lsd_avg), "ssim_stft": np.mean(ssim_stft_avg)} 311 | 312 | def lsd(self, audio1, audio2): 313 | result = self.lsd_metric.evaluation(audio1, audio2, None) 314 | return result 315 | 316 | def calculate_psnr_ssim(self, pairedloader, same_name=True): 317 | if same_name == False: 318 | return {"psnr": -1, "ssim": -1} 319 | psnr_avg = [] 320 | ssim_avg = [] 321 | for mel_gen, mel_target, filename, _ in tqdm(pairedloader): 322 | mel_gen = mel_gen.cpu().numpy()[0] 323 | mel_target = mel_target.cpu().numpy()[0] 324 | psnrval = psnr(mel_gen, mel_target) 325 | if np.isinf(psnrval): 326 | print("Infinite value encountered in psnr %s " % filename) 327 | continue 328 | psnr_avg.append(psnrval) 329 | data_range = max(np.max(mel_gen), np.max(mel_target)) - min(np.min(mel_gen), np.min(mel_target)) 330 | ssim_avg.append(ssim(mel_gen, mel_target, data_range=data_range)) 331 | return {"psnr": np.mean(psnr_avg), "ssim": np.mean(ssim_avg)} 332 | 333 | def get_current_time(self): 334 | now = datetime.datetime.now() 335 | return now.strftime("%Y-%m-%d-%H:%M:%S") 336 | 337 | def sample_from(self, samples, number_to_use): 338 | assert samples.shape[0] >= number_to_use 339 | rand_order = np.random.permutation(samples.shape[0]) 340 | return samples[rand_order[: samples.shape[0]], :] -------------------------------------------------------------------------------- 
/audioldm_eval/feature_extractors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/audioldm_eval/8dc07ee7c42f9dc6e295460a1034175a0d49b436/audioldm_eval/feature_extractors/__init__.py -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/inception3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from `https://github.com/pytorch/vision`. 3 | Modified by Vladimir Iashin, 2021. 4 | """ 5 | import math 6 | import sys 7 | from contextlib import redirect_stdout 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from omegaconf.listconfig import ListConfig 13 | from torch.hub import load_state_dict_from_url 14 | from torch.nn.modules.utils import _ntuple 15 | from torchvision.models.inception import BasicConv2d 16 | from torchvision.models.inception import Inception3 as TorchVisionInception3 17 | 18 | 19 | class FeatureExtractorInceptionV3(TorchVisionInception3): 20 | def __init__( 21 | self, name, features_list, feature_extractor_weights_path=None, **kwargs 22 | ): 23 | """Build pretrained InceptionV3 24 | 25 | Parameters 26 | ---------- 27 | features_list: list 28 | A list of feature names from the list of provided by this extractor, 29 | which will be produced for each input 30 | feature_extractor_weights_path: str 31 | Path to the pretrained Inception model weights in PyTorch format. 32 | Refer to inception_features.py:__main__ for making your own. 33 | By default downloads the checkpoint from internet. 34 | """ 35 | super().__init__(num_classes=1008, **kwargs) 36 | self.input_image_size = 299 37 | self.provided_feats = ("64", "192", "768", "2048", "logits_unbiased", "logits") 38 | self.features_list = list(features_list) 39 | 40 | assert type(name) is str, "Feature extractor name must be a string" 41 | assert type(features_list) in ( 42 | list, 43 | tuple, 44 | ListConfig, 45 | ), "Wrong features list type" 46 | assert all( 47 | (a in self.provided_feats for a in features_list) 48 | ), "Requested features arent on the list" 49 | assert len(features_list) == len( 50 | set(features_list) 51 | ), "Duplicate features requested" 52 | 53 | PT_INCEPTION_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" 54 | self.Mixed_5b = InceptionA(192, pool_features=32) 55 | self.Mixed_5c = InceptionA(256, pool_features=64) 56 | self.Mixed_5d = InceptionA(288, pool_features=64) 57 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 58 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 59 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 60 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 61 | 62 | self.Mixed_7b = InceptionE_1(1280) 63 | self.Mixed_7c = InceptionE_2(2048) 64 | self.AuxLogits = nn.Identity() 65 | 66 | if feature_extractor_weights_path is None: 67 | with redirect_stdout(sys.stderr): 68 | state_dict = load_state_dict_from_url(PT_INCEPTION_URL, progress=True) 69 | else: 70 | state_dict = torch.load(feature_extractor_weights_path) 71 | self.load_state_dict(state_dict) 72 | 73 | for p in self.parameters(): 74 | p.requires_grad_(False) 75 | 76 | def forward(self, x): 77 | assert ( 78 | torch.is_tensor(x) and x.dtype == torch.uint8 79 | ), "Expecting x as torch.Tensor, dtype=torch.uint8" 80 | features = {} 81 | remaining_features = self.features_list.copy() 82 | 83 | x = x.float() 84 | # N x 3 x 
? x ? 85 | 86 | x = interpolate_bilinear_2d_like_tensorflow1x( 87 | x, 88 | size=(self.input_image_size, self.input_image_size), 89 | align_corners=False, 90 | ) 91 | # N x 3 x 299 x 299 92 | 93 | # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # happening in graph 94 | x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too 95 | # N x 3 x 299 x 299 96 | 97 | x = self.Conv2d_1a_3x3(x) 98 | # N x 32 x 149 x 149 99 | x = self.Conv2d_2a_3x3(x) 100 | # N x 32 x 147 x 147 101 | x = self.Conv2d_2b_3x3(x) 102 | # N x 64 x 147 x 147 103 | x = self.maxpool1(x) 104 | # N x 64 x 73 x 73 105 | 106 | if "64" in remaining_features: 107 | features["64"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 108 | remaining_features.remove("64") 109 | if len(remaining_features) == 0: 110 | return tuple(features[a] for a in self.features_list) 111 | 112 | x = self.Conv2d_3b_1x1(x) 113 | # N x 80 x 73 x 73 114 | x = self.Conv2d_4a_3x3(x) 115 | # N x 192 x 71 x 71 116 | x = self.maxpool2(x) 117 | # N x 192 x 35 x 35 118 | 119 | if "192" in remaining_features: 120 | features["192"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 121 | remaining_features.remove("192") 122 | if len(remaining_features) == 0: 123 | return tuple(features[a] for a in self.features_list) 124 | 125 | x = self.Mixed_5b(x) 126 | # N x 256 x 35 x 35 127 | x = self.Mixed_5c(x) 128 | # N x 288 x 35 x 35 129 | x = self.Mixed_5d(x) 130 | # N x 288 x 35 x 35 131 | x = self.Mixed_6a(x) 132 | # N x 768 x 17 x 17 133 | x = self.Mixed_6b(x) 134 | # N x 768 x 17 x 17 135 | x = self.Mixed_6c(x) 136 | # N x 768 x 17 x 17 137 | x = self.Mixed_6d(x) 138 | # N x 768 x 17 x 17 139 | x = self.Mixed_6e(x) 140 | # N x 768 x 17 x 17 141 | 142 | if "768" in remaining_features: 143 | features["768"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 144 | remaining_features.remove("768") 145 | if len(remaining_features) == 0: 146 | return tuple(features[a] for a in self.features_list) 147 | 148 | x = self.Mixed_7a(x) 149 | # N x 1280 x 8 x 8 150 | x = self.Mixed_7b(x) 151 | # N x 2048 x 8 x 8 152 | x = self.Mixed_7c(x) 153 | # N x 2048 x 8 x 8 154 | x = self.avgpool(x) 155 | # N x 2048 x 1 x 1 156 | 157 | x = torch.flatten(x, 1) 158 | # N x 2048 159 | 160 | if "2048" in remaining_features: 161 | features["2048"] = x 162 | remaining_features.remove("2048") 163 | if len(remaining_features) == 0: 164 | return tuple(features[a] for a in self.features_list) 165 | 166 | if "logits_unbiased" in remaining_features: 167 | x = x.mm(self.fc.weight.T) 168 | # N x 1008 (num_classes) 169 | features["logits_unbiased"] = x 170 | remaining_features.remove("logits_unbiased") 171 | if len(remaining_features) == 0: 172 | return tuple(features[a] for a in self.features_list) 173 | 174 | x = x + self.fc.bias.unsqueeze(0) 175 | else: 176 | x = self.fc(x) 177 | # N x 1008 (num_classes) 178 | 179 | features["logits"] = x 180 | return tuple(features[a] for a in self.features_list) 181 | 182 | @staticmethod 183 | def get_provided_features_list(): 184 | return 185 | 186 | def get_requested_features_list(self): 187 | return self.features_list 188 | 189 | def get_name(self): 190 | return self.name 191 | 192 | def convert_features_tuple_to_dict(self, features): 193 | """ 194 | The only compound return type of the forward function amenable to JIT tracing is tuple. 195 | This function simply helps to recover the mapping. 
196 | """ 197 | message = "Features must be the output of forward function" 198 | assert type(features) is tuple and len(features) == len( 199 | self.features_list 200 | ), message 201 | return dict( 202 | ((name, feature) for name, feature in zip(self.features_list, features)) 203 | ) 204 | 205 | 206 | class InceptionA(nn.Module): 207 | """Block from torchvision patched to be compatible with TensorFlow implementation""" 208 | 209 | def __init__(self, in_channels, pool_features): 210 | super(InceptionA, self).__init__() 211 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 212 | 213 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 214 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 215 | 216 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 217 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 218 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 219 | 220 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 221 | 222 | def forward(self, x): 223 | branch1x1 = self.branch1x1(x) 224 | 225 | branch5x5 = self.branch5x5_1(x) 226 | branch5x5 = self.branch5x5_2(branch5x5) 227 | 228 | branch3x3dbl = self.branch3x3dbl_1(x) 229 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 230 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 231 | 232 | # Patch: Tensorflow's average pool does not use the padded zero's in its average calculation 233 | branch_pool = F.avg_pool2d( 234 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 235 | ) 236 | branch_pool = self.branch_pool(branch_pool) 237 | 238 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 239 | return torch.cat(outputs, 1) 240 | 241 | 242 | class InceptionC(nn.Module): 243 | """Block from torchvision patched to be compatible with TensorFlow implementation""" 244 | 245 | def __init__(self, in_channels, channels_7x7): 246 | super(InceptionC, self).__init__() 247 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 248 | 249 | c7 = channels_7x7 250 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 251 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 252 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 253 | 254 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 255 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 256 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 257 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 258 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 259 | 260 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 261 | 262 | def forward(self, x): 263 | branch1x1 = self.branch1x1(x) 264 | 265 | branch7x7 = self.branch7x7_1(x) 266 | branch7x7 = self.branch7x7_2(branch7x7) 267 | branch7x7 = self.branch7x7_3(branch7x7) 268 | 269 | branch7x7dbl = self.branch7x7dbl_1(x) 270 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 271 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 272 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 273 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 274 | 275 | # Patch: Tensorflow's average pool does not use the padded zero's in its average calculation 276 | branch_pool = F.avg_pool2d( 277 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 278 | ) 279 | branch_pool = self.branch_pool(branch_pool) 280 | 281 
| outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 282 | return torch.cat(outputs, 1) 283 | 284 | 285 | class InceptionE_1(nn.Module): 286 | """First InceptionE block from torchvision patched to be compatible with TensorFlow implementation""" 287 | 288 | def __init__(self, in_channels): 289 | super(InceptionE_1, self).__init__() 290 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 291 | 292 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 293 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 294 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 295 | 296 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 297 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 298 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 299 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 300 | 301 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 302 | 303 | def forward(self, x): 304 | branch1x1 = self.branch1x1(x) 305 | 306 | branch3x3 = self.branch3x3_1(x) 307 | branch3x3 = [ 308 | self.branch3x3_2a(branch3x3), 309 | self.branch3x3_2b(branch3x3), 310 | ] 311 | branch3x3 = torch.cat(branch3x3, 1) 312 | 313 | branch3x3dbl = self.branch3x3dbl_1(x) 314 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 315 | branch3x3dbl = [ 316 | self.branch3x3dbl_3a(branch3x3dbl), 317 | self.branch3x3dbl_3b(branch3x3dbl), 318 | ] 319 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 320 | 321 | # Patch: Tensorflow's average pool does not use the padded zero's in its average calculation 322 | branch_pool = F.avg_pool2d( 323 | x, kernel_size=3, stride=1, padding=1, count_include_pad=False 324 | ) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | 330 | 331 | class InceptionE_2(nn.Module): 332 | """Second InceptionE block from torchvision patched to be compatible with TensorFlow implementation""" 333 | 334 | def __init__(self, in_channels): 335 | super(InceptionE_2, self).__init__() 336 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 337 | 338 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 339 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 340 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 341 | 342 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 343 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 344 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 345 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 346 | 347 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 348 | 349 | def forward(self, x): 350 | branch1x1 = self.branch1x1(x) 351 | 352 | branch3x3 = self.branch3x3_1(x) 353 | branch3x3 = [ 354 | self.branch3x3_2a(branch3x3), 355 | self.branch3x3_2b(branch3x3), 356 | ] 357 | branch3x3 = torch.cat(branch3x3, 1) 358 | 359 | branch3x3dbl = self.branch3x3dbl_1(x) 360 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 361 | branch3x3dbl = [ 362 | self.branch3x3dbl_3a(branch3x3dbl), 363 | self.branch3x3dbl_3b(branch3x3dbl), 364 | ] 365 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 366 | 367 | # Patch: TensorFlow Inception model uses max pooling instead of average 368 | # pooling. 
This is likely an error in this specific Inception 369 | # implementation, as other Inception models use average pooling here 370 | # (which matches the description in the paper). 371 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 372 | branch_pool = self.branch_pool(branch_pool) 373 | 374 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 375 | return torch.cat(outputs, 1) 376 | 377 | 378 | def interpolate_bilinear_2d_like_tensorflow1x( 379 | input, size=None, scale_factor=None, align_corners=None, method="slow" 380 | ): 381 | r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor` 382 | 383 | Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x: 384 | https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41 385 | https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85 386 | as per proposal: 387 | https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319 388 | 389 | Related materials: 390 | https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35 391 | https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/ 392 | https://machinethink.net/blog/coreml-upsampling/ 393 | 394 | Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape. 395 | 396 | The input dimensions are interpreted in the form: 397 | `mini-batch x channels x height x width`. 398 | 399 | Args: 400 | input (Tensor): the input tensor 401 | size (Tuple[int, int]): output spatial size. 402 | scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. 403 | align_corners (bool, optional): Same meaning as in TensorFlow 1.x. 404 | method (str, optional): 405 | 'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or 406 | 'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299) 407 | """ 408 | if method not in ("slow", "fast"): 409 | raise ValueError('how_exact can only be one of "slow", "fast"') 410 | 411 | if input.dim() != 4: 412 | raise ValueError("input must be a 4-D tensor") 413 | 414 | if align_corners is None: 415 | raise ValueError( 416 | "align_corners is not specified (use this function for a complete determinism)" 417 | ) 418 | 419 | def _check_size_scale_factor(dim): 420 | if size is None and scale_factor is None: 421 | raise ValueError("either size or scale_factor should be defined") 422 | if size is not None and scale_factor is not None: 423 | raise ValueError("only one of size or scale_factor should be defined") 424 | if ( 425 | scale_factor is not None 426 | and isinstance(scale_factor, tuple) 427 | and len(scale_factor) != dim 428 | ): 429 | raise ValueError( 430 | "scale_factor shape must match input shape. 
" 431 | "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) 432 | ) 433 | 434 | is_tracing = torch._C._get_tracing_state() 435 | 436 | def _output_size(dim): 437 | _check_size_scale_factor(dim) 438 | if size is not None: 439 | if is_tracing: 440 | return [torch.tensor(i) for i in size] 441 | else: 442 | return size 443 | scale_factors = _ntuple(dim)(scale_factor) 444 | # math.floor might return float in py2.7 445 | 446 | # make scale_factor a tensor in tracing so constant doesn't get baked in 447 | if is_tracing: 448 | return [ 449 | ( 450 | torch.floor( 451 | ( 452 | input.size(i + 2).float() 453 | * torch.tensor(scale_factors[i], dtype=torch.float32) 454 | ).float() 455 | ) 456 | ) 457 | for i in range(dim) 458 | ] 459 | else: 460 | return [ 461 | int(math.floor(float(input.size(i + 2)) * scale_factors[i])) 462 | for i in range(dim) 463 | ] 464 | 465 | def tf_calculate_resize_scale(in_size, out_size): 466 | if align_corners: 467 | if is_tracing: 468 | return (in_size - 1) / (out_size.float() - 1).clamp(min=1) 469 | else: 470 | return (in_size - 1) / max(1, out_size - 1) 471 | else: 472 | if is_tracing: 473 | return in_size / out_size.float() 474 | else: 475 | return in_size / out_size 476 | 477 | out_size = _output_size(2) 478 | scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1]) 479 | scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0]) 480 | 481 | def resample_using_grid_sample(): 482 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 483 | grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1 484 | 485 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 486 | grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1 487 | 488 | grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1) 489 | grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1]) 490 | 491 | grid_xy = torch.cat( 492 | (grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=2 493 | ).unsqueeze(0) 494 | grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1) 495 | 496 | out = F.grid_sample( 497 | input, grid_xy, mode="bilinear", padding_mode="border", align_corners=True 498 | ) 499 | return out 500 | 501 | def resample_manually(): 502 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 503 | grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32) 504 | grid_x_lo = grid_x.long() 505 | grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1) 506 | grid_dx = grid_x - grid_x_lo.float() 507 | 508 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 509 | grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32) 510 | grid_y_lo = grid_y.long() 511 | grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1) 512 | grid_dy = grid_y - grid_y_lo.float() 513 | 514 | # could be improved with index_select 515 | in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo] 516 | in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi] 517 | in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo] 518 | in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi] 519 | 520 | in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1]) 521 | in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1]) 522 | out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1) 523 | 524 | return out 525 | 526 | if method == "slow": 527 | out = resample_manually() 528 | else: 529 | out = resample_using_grid_sample() 530 | 531 | return out 532 | 
-------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/melception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision.models.inception import BasicConv2d, Inception3 4 | 5 | 6 | class Melception(Inception3): 7 | def __init__( 8 | self, num_classes, features_list, feature_extractor_weights_path, **kwargs 9 | ): 10 | # inception = Melception(num_classes=309) 11 | super().__init__(num_classes=num_classes, init_weights=True, **kwargs) 12 | self.features_list = list(features_list) 13 | # the same as https://github.com/pytorch/vision/blob/5339e63148/torchvision/models/inception.py#L95 14 | # but for 1-channel input instead of RGB. 15 | self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2) 16 | # also the 'hight' of the mel spec is 80 (vs 299 in RGB) we remove all max pool from Inception 17 | self.maxpool1 = torch.nn.Identity() 18 | self.maxpool2 = torch.nn.Identity() 19 | 20 | state_dict = torch.load(feature_extractor_weights_path, map_location="cpu") 21 | self.load_state_dict(state_dict["model"]) 22 | for p in self.parameters(): 23 | p.requires_grad_(False) 24 | 25 | def forward(self, x): 26 | features = {} 27 | remaining_features = self.features_list.copy() 28 | 29 | # B x 1 x 80 x 848 <- N x M x T 30 | x = x.unsqueeze(1) 31 | # (B, 32, 39, 423) <- 32 | x = self.Conv2d_1a_3x3(x) 33 | # (B, 32, 37, 421) <- 34 | x = self.Conv2d_2a_3x3(x) 35 | # (B, 64, 37, 421) <- 36 | x = self.Conv2d_2b_3x3(x) 37 | # (B, 64, 37, 421) <- 38 | x = self.maxpool1(x) 39 | 40 | if "64" in remaining_features: 41 | features["64"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 42 | remaining_features.remove("64") 43 | if len(remaining_features) == 0: 44 | return tuple(features[a] for a in self.features_list) 45 | 46 | # (B, 80, 37, 421) <- 47 | x = self.Conv2d_3b_1x1(x) 48 | # (B, 192, 35, 419) <- 49 | x = self.Conv2d_4a_3x3(x) 50 | # (B, 192, 35, 419) <- 51 | x = self.maxpool2(x) 52 | 53 | if "192" in remaining_features: 54 | features["192"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 55 | remaining_features.remove("192") 56 | if len(remaining_features) == 0: 57 | return tuple(features[a] for a in self.features_list) 58 | 59 | # (B, 256, 35, 419) <- 60 | x = self.Mixed_5b(x) 61 | # (B, 288, 35, 419) <- 62 | x = self.Mixed_5c(x) 63 | # (B, 288, 35, 419) <- 64 | x = self.Mixed_5d(x) 65 | # (B, 288, 35, 419) <- 66 | x = self.Mixed_6a(x) 67 | # (B, 768, 17, 209) <- 68 | x = self.Mixed_6b(x) 69 | # (B, 768, 17, 209) <- 70 | x = self.Mixed_6c(x) 71 | # (B, 768, 17, 209) <- 72 | x = self.Mixed_6d(x) 73 | # (B, 768, 17, 209) <- 74 | x = self.Mixed_6e(x) 75 | 76 | if "768" in remaining_features: 77 | features["768"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 78 | remaining_features.remove("768") 79 | if len(remaining_features) == 0: 80 | return tuple(features[a] for a in self.features_list) 81 | 82 | # (B, 1280, 8, 104) <- 83 | x = self.Mixed_7a(x) 84 | # (B, 2048, 8, 104) <- 85 | x = self.Mixed_7b(x) 86 | # (B, 2048, 8, 104) <- 87 | x = self.Mixed_7c(x) 88 | # (B, 2048, 1, 1) <- 89 | x = self.avgpool(x) 90 | # (B, 2048, 1, 1) <- 91 | x = self.dropout(x) 92 | 93 | # (B, 2048) <- 94 | x = torch.flatten(x, 1) 95 | 96 | if "2048" in remaining_features: 97 | features["2048"] = x 98 | remaining_features.remove("2048") 99 | if len(remaining_features) == 0: 100 | return tuple(features[a] for a in self.features_list) 101 | 102 | if "logits_unbiased" in 
remaining_features: 103 | # (B, num_classes) <- 104 | x = x.mm(self.fc.weight.T) 105 | features["logits_unbiased"] = x 106 | remaining_features.remove("logits_unbiased") 107 | if len(remaining_features) == 0: 108 | return tuple(features[a] for a in self.features_list) 109 | 110 | x = x + self.fc.bias.unsqueeze(0) 111 | else: 112 | x = self.fc(x) 113 | 114 | features["logits"] = x 115 | return tuple(features[a] for a in self.features_list) 116 | 117 | def convert_features_tuple_to_dict(self, features): 118 | """ 119 | The only compound return type of the forward function amenable to JIT tracing is tuple. 120 | This function simply helps to recover the mapping. 121 | """ 122 | message = "Features must be the output of forward function" 123 | assert type(features) is tuple and len(features) == len( 124 | self.features_list 125 | ), message 126 | return dict( 127 | ((name, feature) for name, feature in zip(self.features_list, features)) 128 | ) 129 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/melception_audioset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision.models.inception import BasicConv2d, Inception3 4 | from collections import OrderedDict 5 | 6 | 7 | def load_module2model(state_dict): 8 | new_state_dict = OrderedDict() 9 | for k, v in state_dict.items(): # k为module.xxx.weight, v为权重 10 | if k[:7] == "module.": 11 | name = k[7:] # 截取`module.`后面的xxx.weight 12 | new_state_dict[name] = v 13 | return new_state_dict 14 | 15 | 16 | class Melception(Inception3): 17 | def __init__( 18 | self, num_classes, features_list, feature_extractor_weights_path, **kwargs 19 | ): 20 | # inception = Melception(num_classes=309) 21 | super().__init__(num_classes=num_classes, init_weights=True, **kwargs) 22 | self.features_list = list(features_list) 23 | # the same as https://github.com/pytorch/vision/blob/5339e63148/torchvision/models/inception.py#L95 24 | # but for 1-channel input instead of RGB. 
25 | self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2) 26 | # also the 'hight' of the mel spec is 80 (vs 299 in RGB) we remove all max pool from Inception 27 | self.maxpool1 = torch.nn.Identity() 28 | self.maxpool2 = torch.nn.Identity() 29 | 30 | state_dict = torch.load(feature_extractor_weights_path, map_location="cpu") 31 | new_state_dict = load_module2model(state_dict["model"]) 32 | # print('before....') 33 | # print(self.state_dict()['Conv2d_1a_3x3.conv.weight']) 34 | # print('after') 35 | self.load_state_dict(new_state_dict) 36 | # print(self.state_dict()['Conv2d_1a_3x3.conv.weight']) 37 | # assert 1==2 38 | for p in self.parameters(): 39 | p.requires_grad_(False) 40 | 41 | def forward(self, x): 42 | features = {} 43 | remaining_features = self.features_list.copy() 44 | 45 | # B x 1 x 80 x 848 <- N x M x T 46 | x = x.unsqueeze(1) 47 | # (B, 32, 39, 423) <- 48 | x = self.Conv2d_1a_3x3(x) 49 | # (B, 32, 37, 421) <- 50 | x = self.Conv2d_2a_3x3(x) 51 | # (B, 64, 37, 421) <- 52 | x = self.Conv2d_2b_3x3(x) 53 | # (B, 64, 37, 421) <- 54 | x = self.maxpool1(x) 55 | 56 | if "64" in remaining_features: 57 | features["64"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 58 | remaining_features.remove("64") 59 | if len(remaining_features) == 0: 60 | return tuple(features[a] for a in self.features_list) 61 | 62 | # (B, 80, 37, 421) <- 63 | x = self.Conv2d_3b_1x1(x) 64 | # (B, 192, 35, 419) <- 65 | x = self.Conv2d_4a_3x3(x) 66 | # (B, 192, 35, 419) <- 67 | x = self.maxpool2(x) 68 | 69 | if "192" in remaining_features: 70 | features["192"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 71 | remaining_features.remove("192") 72 | if len(remaining_features) == 0: 73 | return tuple(features[a] for a in self.features_list) 74 | 75 | # (B, 256, 35, 419) <- 76 | x = self.Mixed_5b(x) 77 | # (B, 288, 35, 419) <- 78 | x = self.Mixed_5c(x) 79 | # (B, 288, 35, 419) <- 80 | x = self.Mixed_5d(x) 81 | # (B, 288, 35, 419) <- 82 | x = self.Mixed_6a(x) 83 | # (B, 768, 17, 209) <- 84 | x = self.Mixed_6b(x) 85 | # (B, 768, 17, 209) <- 86 | x = self.Mixed_6c(x) 87 | # (B, 768, 17, 209) <- 88 | x = self.Mixed_6d(x) 89 | # (B, 768, 17, 209) <- 90 | x = self.Mixed_6e(x) 91 | 92 | if "768" in remaining_features: 93 | features["768"] = F.adaptive_avg_pool2d(x, output_size=(1, 1)) 94 | remaining_features.remove("768") 95 | if len(remaining_features) == 0: 96 | return tuple(features[a] for a in self.features_list) 97 | 98 | # (B, 1280, 8, 104) <- 99 | x = self.Mixed_7a(x) 100 | # (B, 2048, 8, 104) <- 101 | x = self.Mixed_7b(x) 102 | # (B, 2048, 8, 104) <- 103 | x = self.Mixed_7c(x) 104 | # (B, 2048, 1, 1) <- 105 | x = self.avgpool(x) 106 | # (B, 2048, 1, 1) <- 107 | x = self.dropout(x) 108 | 109 | # (B, 2048) <- 110 | x = torch.flatten(x, 1) 111 | # print('x ',x.shape) 112 | if "2048" in remaining_features: 113 | features["2048"] = x 114 | remaining_features.remove("2048") 115 | if len(remaining_features) == 0: 116 | return tuple(features[a] for a in self.features_list) 117 | 118 | if "logits_unbiased" in remaining_features: 119 | # (B, num_classes) <- 120 | x = x.mm(self.fc.weight.T) 121 | features["logits_unbiased"] = x 122 | remaining_features.remove("logits_unbiased") 123 | if len(remaining_features) == 0: 124 | return tuple(features[a] for a in self.features_list) 125 | 126 | x = x + self.fc.bias.unsqueeze(0) 127 | else: 128 | x = self.fc(x) 129 | 130 | features["logits"] = x 131 | # print('x ',x.shape) 132 | # assert 1==2 133 | return tuple(features[a] for a in self.features_list) 134 | 135 | def 
convert_features_tuple_to_dict(self, features): 136 | """ 137 | The only compound return type of the forward function amenable to JIT tracing is tuple. 138 | This function simply helps to recover the mapping. 139 | """ 140 | message = "Features must be the output of forward function" 141 | assert type(features) is tuple and len(features) == len( 142 | self.features_list 143 | ), message 144 | return dict( 145 | ((name, feature) for name, feature in zip(self.features_list, features)) 146 | ) 147 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Cnn14, Cnn14_16k 2 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | 4 | sample_rate = 32000 5 | clip_samples = sample_rate * 10 # Audio clips are 10-second 6 | 7 | # Load label 8 | with open("metadata/class_labels_indices.csv", "r") as f: 9 | reader = csv.reader(f, delimiter=",") 10 | lines = list(reader) 11 | 12 | labels = [] 13 | ids = [] # Each label has a unique id such as "/m/068hy" 14 | for i1 in range(1, len(lines)): 15 | id = lines[i1][1] 16 | label = lines[i1][2] 17 | ids.append(id) 18 | labels.append(label) 19 | 20 | classes_num = len(labels) 21 | 22 | lb_to_ix = {label: i for i, label in enumerate(labels)} 23 | ix_to_lb = {i: label for i, label in enumerate(labels)} 24 | 25 | id_to_ix = {id: i for i, id in enumerate(ids)} 26 | ix_to_id = {i: id for i, id in enumerate(ids)} 27 | 28 | full_samples_per_class = np.array( 29 | [ 30 | 937432, 31 | 16344, 32 | 7822, 33 | 10271, 34 | 2043, 35 | 14420, 36 | 733, 37 | 1511, 38 | 1258, 39 | 424, 40 | 1751, 41 | 704, 42 | 369, 43 | 590, 44 | 1063, 45 | 1375, 46 | 5026, 47 | 743, 48 | 853, 49 | 1648, 50 | 714, 51 | 1497, 52 | 1251, 53 | 2139, 54 | 1093, 55 | 133, 56 | 224, 57 | 39469, 58 | 6423, 59 | 407, 60 | 1559, 61 | 4546, 62 | 6826, 63 | 7464, 64 | 2468, 65 | 549, 66 | 4063, 67 | 334, 68 | 587, 69 | 238, 70 | 1766, 71 | 691, 72 | 114, 73 | 2153, 74 | 236, 75 | 209, 76 | 421, 77 | 740, 78 | 269, 79 | 959, 80 | 137, 81 | 4192, 82 | 485, 83 | 1515, 84 | 655, 85 | 274, 86 | 69, 87 | 157, 88 | 1128, 89 | 807, 90 | 1022, 91 | 346, 92 | 98, 93 | 680, 94 | 890, 95 | 352, 96 | 4169, 97 | 2061, 98 | 1753, 99 | 9883, 100 | 1339, 101 | 708, 102 | 37857, 103 | 18504, 104 | 12864, 105 | 2475, 106 | 2182, 107 | 757, 108 | 3624, 109 | 677, 110 | 1683, 111 | 3583, 112 | 444, 113 | 1780, 114 | 2364, 115 | 409, 116 | 4060, 117 | 3097, 118 | 3143, 119 | 502, 120 | 723, 121 | 600, 122 | 230, 123 | 852, 124 | 1498, 125 | 1865, 126 | 1879, 127 | 2429, 128 | 5498, 129 | 5430, 130 | 2139, 131 | 1761, 132 | 1051, 133 | 831, 134 | 2401, 135 | 2258, 136 | 1672, 137 | 1711, 138 | 987, 139 | 646, 140 | 794, 141 | 25061, 142 | 5792, 143 | 4256, 144 | 96, 145 | 8126, 146 | 2740, 147 | 752, 148 | 513, 149 | 554, 150 | 106, 151 | 254, 152 | 1592, 153 | 556, 154 | 331, 155 | 615, 156 | 2841, 157 | 737, 158 | 265, 159 | 1349, 160 | 358, 161 | 1731, 162 | 1115, 163 | 295, 164 | 1070, 165 | 972, 166 | 174, 167 | 937780, 168 | 112337, 169 | 42509, 170 | 49200, 171 | 11415, 172 | 6092, 173 | 13851, 174 | 2665, 175 | 1678, 176 | 13344, 177 | 2329, 178 | 1415, 179 | 2244, 180 | 1099, 181 | 5024, 182 | 9872, 183 | 10948, 184 | 4409, 185 | 2732, 186 
| 1211, 187 | 1289, 188 | 4807, 189 | 5136, 190 | 1867, 191 | 16134, 192 | 14519, 193 | 3086, 194 | 19261, 195 | 6499, 196 | 4273, 197 | 2790, 198 | 8820, 199 | 1228, 200 | 1575, 201 | 4420, 202 | 3685, 203 | 2019, 204 | 664, 205 | 324, 206 | 513, 207 | 411, 208 | 436, 209 | 2997, 210 | 5162, 211 | 3806, 212 | 1389, 213 | 899, 214 | 8088, 215 | 7004, 216 | 1105, 217 | 3633, 218 | 2621, 219 | 9753, 220 | 1082, 221 | 26854, 222 | 3415, 223 | 4991, 224 | 2129, 225 | 5546, 226 | 4489, 227 | 2850, 228 | 1977, 229 | 1908, 230 | 1719, 231 | 1106, 232 | 1049, 233 | 152, 234 | 136, 235 | 802, 236 | 488, 237 | 592, 238 | 2081, 239 | 2712, 240 | 1665, 241 | 1128, 242 | 250, 243 | 544, 244 | 789, 245 | 2715, 246 | 8063, 247 | 7056, 248 | 2267, 249 | 8034, 250 | 6092, 251 | 3815, 252 | 1833, 253 | 3277, 254 | 8813, 255 | 2111, 256 | 4662, 257 | 2678, 258 | 2954, 259 | 5227, 260 | 1472, 261 | 2591, 262 | 3714, 263 | 1974, 264 | 1795, 265 | 4680, 266 | 3751, 267 | 6585, 268 | 2109, 269 | 36617, 270 | 6083, 271 | 16264, 272 | 17351, 273 | 3449, 274 | 5034, 275 | 3931, 276 | 2599, 277 | 4134, 278 | 3892, 279 | 2334, 280 | 2211, 281 | 4516, 282 | 2766, 283 | 2862, 284 | 3422, 285 | 1788, 286 | 2544, 287 | 2403, 288 | 2892, 289 | 4042, 290 | 3460, 291 | 1516, 292 | 1972, 293 | 1563, 294 | 1579, 295 | 2776, 296 | 1647, 297 | 4535, 298 | 3921, 299 | 1261, 300 | 6074, 301 | 2922, 302 | 3068, 303 | 1948, 304 | 4407, 305 | 712, 306 | 1294, 307 | 1019, 308 | 1572, 309 | 3764, 310 | 5218, 311 | 975, 312 | 1539, 313 | 6376, 314 | 1606, 315 | 6091, 316 | 1138, 317 | 1169, 318 | 7925, 319 | 3136, 320 | 1108, 321 | 2677, 322 | 2680, 323 | 1383, 324 | 3144, 325 | 2653, 326 | 1986, 327 | 1800, 328 | 1308, 329 | 1344, 330 | 122231, 331 | 12977, 332 | 2552, 333 | 2678, 334 | 7824, 335 | 768, 336 | 8587, 337 | 39503, 338 | 3474, 339 | 661, 340 | 430, 341 | 193, 342 | 1405, 343 | 1442, 344 | 3588, 345 | 6280, 346 | 10515, 347 | 785, 348 | 710, 349 | 305, 350 | 206, 351 | 4990, 352 | 5329, 353 | 3398, 354 | 1771, 355 | 3022, 356 | 6907, 357 | 1523, 358 | 8588, 359 | 12203, 360 | 666, 361 | 2113, 362 | 7916, 363 | 434, 364 | 1636, 365 | 5185, 366 | 1062, 367 | 664, 368 | 952, 369 | 3490, 370 | 2811, 371 | 2749, 372 | 2848, 373 | 15555, 374 | 363, 375 | 117, 376 | 1494, 377 | 1647, 378 | 5886, 379 | 4021, 380 | 633, 381 | 1013, 382 | 5951, 383 | 11343, 384 | 2324, 385 | 243, 386 | 372, 387 | 943, 388 | 734, 389 | 242, 390 | 3161, 391 | 122, 392 | 127, 393 | 201, 394 | 1654, 395 | 768, 396 | 134, 397 | 1467, 398 | 642, 399 | 1148, 400 | 2156, 401 | 1368, 402 | 1176, 403 | 302, 404 | 1909, 405 | 61, 406 | 223, 407 | 1812, 408 | 287, 409 | 422, 410 | 311, 411 | 228, 412 | 748, 413 | 230, 414 | 1876, 415 | 539, 416 | 1814, 417 | 737, 418 | 689, 419 | 1140, 420 | 591, 421 | 943, 422 | 353, 423 | 289, 424 | 198, 425 | 490, 426 | 7938, 427 | 1841, 428 | 850, 429 | 457, 430 | 814, 431 | 146, 432 | 551, 433 | 728, 434 | 1627, 435 | 620, 436 | 648, 437 | 1621, 438 | 2731, 439 | 535, 440 | 88, 441 | 1736, 442 | 736, 443 | 328, 444 | 293, 445 | 3170, 446 | 344, 447 | 384, 448 | 7640, 449 | 433, 450 | 215, 451 | 715, 452 | 626, 453 | 128, 454 | 3059, 455 | 1833, 456 | 2069, 457 | 3732, 458 | 1640, 459 | 1508, 460 | 836, 461 | 567, 462 | 2837, 463 | 1151, 464 | 2068, 465 | 695, 466 | 1494, 467 | 3173, 468 | 364, 469 | 88, 470 | 188, 471 | 740, 472 | 677, 473 | 273, 474 | 1533, 475 | 821, 476 | 1091, 477 | 293, 478 | 647, 479 | 318, 480 | 1202, 481 | 328, 482 | 532, 483 | 2847, 484 | 526, 485 | 721, 486 | 370, 487 | 258, 488 | 956, 489 | 
1269, 490 | 1641, 491 | 339, 492 | 1322, 493 | 4485, 494 | 286, 495 | 1874, 496 | 277, 497 | 757, 498 | 1393, 499 | 1330, 500 | 380, 501 | 146, 502 | 377, 503 | 394, 504 | 318, 505 | 339, 506 | 1477, 507 | 1886, 508 | 101, 509 | 1435, 510 | 284, 511 | 1425, 512 | 686, 513 | 621, 514 | 221, 515 | 117, 516 | 87, 517 | 1340, 518 | 201, 519 | 1243, 520 | 1222, 521 | 651, 522 | 1899, 523 | 421, 524 | 712, 525 | 1016, 526 | 1279, 527 | 124, 528 | 351, 529 | 258, 530 | 7043, 531 | 368, 532 | 666, 533 | 162, 534 | 7664, 535 | 137, 536 | 70159, 537 | 26179, 538 | 6321, 539 | 32236, 540 | 33320, 541 | 771, 542 | 1169, 543 | 269, 544 | 1103, 545 | 444, 546 | 364, 547 | 2710, 548 | 121, 549 | 751, 550 | 1609, 551 | 855, 552 | 1141, 553 | 2287, 554 | 1940, 555 | 3943, 556 | 289, 557 | ] 558 | ) 559 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/evaluate.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | from pytorch_utils import forward 4 | 5 | 6 | class Evaluator(object): 7 | def __init__(self, model): 8 | """Evaluator. 9 | 10 | Args: 11 | model: object 12 | """ 13 | self.model = model 14 | 15 | def evaluate(self, data_loader): 16 | """Forward evaluation data and calculate statistics. 17 | 18 | Args: 19 | data_loader: object 20 | 21 | Returns: 22 | statistics: dict, 23 | {'average_precision': (classes_num,), 'auc': (classes_num,)} 24 | """ 25 | 26 | # Forward 27 | output_dict = forward( 28 | model=self.model, generator=data_loader, return_target=True 29 | ) 30 | 31 | clipwise_output = output_dict["clipwise_output"] # (audios_num, classes_num) 32 | target = output_dict["target"] # (audios_num, classes_num) 33 | 34 | average_precision = metrics.average_precision_score( 35 | target, clipwise_output, average=None 36 | ) 37 | 38 | auc = metrics.roc_auc_score(target, clipwise_output, average=None) 39 | 40 | statistics = {"average_precision": average_precision, "auc": auc} 41 | 42 | return statistics 43 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/finetune_template.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(1, os.path.join(sys.path[0], "../utils")) 5 | import numpy as np 6 | import argparse 7 | import h5py 8 | import math 9 | import time 10 | import logging 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | 15 | torch.backends.cudnn.benchmark = True 16 | torch.manual_seed(0) 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | import torch.utils.data 21 | 22 | from utilities import get_filename 23 | from models import * 24 | import config 25 | 26 | 27 | class Transfer_Cnn14(nn.Module): 28 | def __init__( 29 | self, 30 | sample_rate, 31 | window_size, 32 | hop_size, 33 | mel_bins, 34 | fmin, 35 | fmax, 36 | classes_num, 37 | freeze_base, 38 | ): 39 | """Classifier for a new task using pretrained Cnn14 as a sub module.""" 40 | super(Transfer_Cnn14, self).__init__() 41 | audioset_classes_num = 527 42 | 43 | self.base = Cnn14( 44 | sample_rate, 45 | window_size, 46 | hop_size, 47 | mel_bins, 48 | fmin, 49 | fmax, 50 | audioset_classes_num, 51 | ) 52 | 53 | # Transfer to another task layer 54 | self.fc_transfer = nn.Linear(2048, classes_num, bias=True) 55 | 56 | if freeze_base: 57 | # Freeze AudioSet pretrained layers 58 | for param 
in self.base.parameters(): 59 | param.requires_grad = False 60 | 61 | self.init_weights() 62 | 63 | def init_weights(self): 64 | init_layer(self.fc_transfer) 65 | 66 | def load_from_pretrain(self, pretrained_checkpoint_path): 67 | checkpoint = torch.load(pretrained_checkpoint_path) 68 | self.base.load_state_dict(checkpoint["model"]) 69 | 70 | def forward(self, input, mixup_lambda=None): 71 | """Input: (batch_size, data_length)""" 72 | output_dict = self.base(input, mixup_lambda) 73 | embedding = output_dict["embedding"] 74 | 75 | clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1) 76 | output_dict["clipwise_output"] = clipwise_output 77 | 78 | return output_dict 79 | 80 | 81 | def train(args): 82 | 83 | # Arugments & parameters 84 | sample_rate = args.sample_rate 85 | window_size = args.window_size 86 | hop_size = args.hop_size 87 | mel_bins = args.mel_bins 88 | fmin = args.fmin 89 | fmax = args.fmax 90 | model_type = args.model_type 91 | pretrained_checkpoint_path = args.pretrained_checkpoint_path 92 | freeze_base = args.freeze_base 93 | device = "cuda" if (args.cuda and torch.cuda.is_available()) else "cpu" 94 | 95 | classes_num = config.classes_num 96 | pretrain = True if pretrained_checkpoint_path else False 97 | 98 | # Model 99 | Model = eval(model_type) 100 | model = Model( 101 | sample_rate, 102 | window_size, 103 | hop_size, 104 | mel_bins, 105 | fmin, 106 | fmax, 107 | classes_num, 108 | freeze_base, 109 | ) 110 | 111 | # Load pretrained model 112 | if pretrain: 113 | logging.info("Load pretrained model from {}".format(pretrained_checkpoint_path)) 114 | model.load_from_pretrain(pretrained_checkpoint_path) 115 | 116 | # Parallel 117 | print("GPU number: {}".format(torch.cuda.device_count())) 118 | model = torch.nn.DataParallel(model) 119 | 120 | if "cuda" in device: 121 | model.to(device) 122 | 123 | print("Load pretrained model successfully!") 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser(description="Example of parser. 
") 128 | subparsers = parser.add_subparsers(dest="mode") 129 | 130 | # Train 131 | parser_train = subparsers.add_parser("train") 132 | parser_train.add_argument("--sample_rate", type=int, required=True) 133 | parser_train.add_argument("--window_size", type=int, required=True) 134 | parser_train.add_argument("--hop_size", type=int, required=True) 135 | parser_train.add_argument("--mel_bins", type=int, required=True) 136 | parser_train.add_argument("--fmin", type=int, required=True) 137 | parser_train.add_argument("--fmax", type=int, required=True) 138 | parser_train.add_argument("--model_type", type=str, required=True) 139 | parser_train.add_argument("--pretrained_checkpoint_path", type=str) 140 | parser_train.add_argument("--freeze_base", action="store_true", default=False) 141 | parser_train.add_argument("--cuda", action="store_true", default=False) 142 | 143 | # Parse arguments 144 | args = parser.parse_args() 145 | args.filename = get_filename(__file__) 146 | 147 | if args.mode == "train": 148 | train(args) 149 | 150 | else: 151 | raise Exception("Error argument!") 152 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def clip_bce(output_dict, target_dict): 6 | """Binary crossentropy loss.""" 7 | return F.binary_cross_entropy(output_dict["clipwise_output"], target_dict["target"]) 8 | 9 | 10 | def get_loss_func(loss_type): 11 | if loss_type == "clip_bce": 12 | return clip_bce 13 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(1, os.path.join(sys.path[0], "../utils")) 5 | import numpy as np 6 | import argparse 7 | import time 8 | import logging 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torch.utils.data 15 | 16 | from utilities import ( 17 | create_folder, 18 | get_filename, 19 | create_logging, 20 | Mixup, 21 | StatisticsContainer, 22 | ) 23 | from models import ( 24 | Cnn14, 25 | Cnn14_no_specaug, 26 | Cnn14_no_dropout, 27 | Cnn6, 28 | Cnn10, 29 | ResNet22, 30 | ResNet38, 31 | ResNet54, 32 | Cnn14_emb512, 33 | Cnn14_emb128, 34 | Cnn14_emb32, 35 | MobileNetV1, 36 | MobileNetV2, 37 | LeeNet11, 38 | LeeNet24, 39 | DaiNet19, 40 | Res1dNet31, 41 | Res1dNet51, 42 | Wavegram_Cnn14, 43 | Wavegram_Logmel_Cnn14, 44 | Wavegram_Logmel128_Cnn14, 45 | Cnn14_16k, 46 | Cnn14_8k, 47 | Cnn14_mel32, 48 | Cnn14_mel128, 49 | Cnn14_mixup_time_domain, 50 | Cnn14_DecisionLevelMax, 51 | Cnn14_DecisionLevelAtt, 52 | ) 53 | from pytorch_utils import move_data_to_device, count_parameters, count_flops, do_mixup 54 | from data_generator import ( 55 | AudioSetDataset, 56 | TrainSampler, 57 | BalancedTrainSampler, 58 | AlternateTrainSampler, 59 | EvaluateSampler, 60 | collate_fn, 61 | ) 62 | from evaluate import Evaluator 63 | import config 64 | from losses import get_loss_func 65 | 66 | 67 | def train(args): 68 | """Train AudioSet tagging model. 
69 | 70 | Args: 71 | dataset_dir: str 72 | workspace: str 73 | data_type: 'balanced_train' | 'full_train' 74 | window_size: int 75 | hop_size: int 76 | mel_bins: int 77 | model_type: str 78 | loss_type: 'clip_bce' 79 | balanced: 'none' | 'balanced' | 'alternate' 80 | augmentation: 'none' | 'mixup' 81 | batch_size: int 82 | learning_rate: float 83 | resume_iteration: int 84 | early_stop: int 85 | accumulation_steps: int 86 | cuda: bool 87 | """ 88 | 89 | # Arguments & parameters 90 | workspace = args.workspace 91 | data_type = args.data_type 92 | sample_rate = args.sample_rate 93 | window_size = args.window_size 94 | hop_size = args.hop_size 95 | mel_bins = args.mel_bins 96 | fmin = args.fmin 97 | fmax = args.fmax 98 | model_type = args.model_type 99 | loss_type = args.loss_type 100 | balanced = args.balanced 101 | augmentation = args.augmentation 102 | batch_size = args.batch_size 103 | learning_rate = args.learning_rate 104 | resume_iteration = args.resume_iteration 105 | early_stop = args.early_stop 106 | device = ( 107 | torch.device("cuda") 108 | if args.cuda and torch.cuda.is_available() 109 | else torch.device("cpu") 110 | ) 111 | filename = args.filename 112 | 113 | num_workers = 8 114 | clip_samples = config.clip_samples 115 | classes_num = config.classes_num 116 | loss_func = get_loss_func(loss_type) 117 | 118 | # Paths 119 | black_list_csv = None 120 | 121 | train_indexes_hdf5_path = os.path.join( 122 | workspace, "hdf5s", "indexes", "{}.h5".format(data_type) 123 | ) 124 | 125 | eval_bal_indexes_hdf5_path = os.path.join( 126 | workspace, "hdf5s", "indexes", "balanced_train.h5" 127 | ) 128 | 129 | eval_test_indexes_hdf5_path = os.path.join(workspace, "hdf5s", "indexes", "eval.h5") 130 | 131 | checkpoints_dir = os.path.join( 132 | workspace, 133 | "checkpoints", 134 | filename, 135 | "sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}".format( 136 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax 137 | ), 138 | "data_type={}".format(data_type), 139 | model_type, 140 | "loss_type={}".format(loss_type), 141 | "balanced={}".format(balanced), 142 | "augmentation={}".format(augmentation), 143 | "batch_size={}".format(batch_size), 144 | ) 145 | create_folder(checkpoints_dir) 146 | 147 | statistics_path = os.path.join( 148 | workspace, 149 | "statistics", 150 | filename, 151 | "sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}".format( 152 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax 153 | ), 154 | "data_type={}".format(data_type), 155 | model_type, 156 | "loss_type={}".format(loss_type), 157 | "balanced={}".format(balanced), 158 | "augmentation={}".format(augmentation), 159 | "batch_size={}".format(batch_size), 160 | "statistics.pkl", 161 | ) 162 | create_folder(os.path.dirname(statistics_path)) 163 | 164 | logs_dir = os.path.join( 165 | workspace, 166 | "logs", 167 | filename, 168 | "sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}".format( 169 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax 170 | ), 171 | "data_type={}".format(data_type), 172 | model_type, 173 | "loss_type={}".format(loss_type), 174 | "balanced={}".format(balanced), 175 | "augmentation={}".format(augmentation), 176 | "batch_size={}".format(batch_size), 177 | ) 178 | 179 | create_logging(logs_dir, filemode="w") 180 | logging.info(args) 181 | 182 | if "cuda" in str(device): 183 | logging.info("Using GPU.") 184 | device = "cuda" 185 | else: 186 | logging.info("Using CPU. 
Set --cuda flag to use GPU.") 187 | device = "cpu" 188 | 189 | # Model 190 | Model = eval(model_type) 191 | model = Model( 192 | sample_rate=sample_rate, 193 | window_size=window_size, 194 | hop_size=hop_size, 195 | mel_bins=mel_bins, 196 | fmin=fmin, 197 | fmax=fmax, 198 | classes_num=classes_num, 199 | ) 200 | 201 | params_num = count_parameters(model) 202 | # flops_num = count_flops(model, clip_samples) 203 | logging.info("Parameters num: {}".format(params_num)) 204 | # logging.info('Flops num: {:.3f} G'.format(flops_num / 1e9)) 205 | 206 | # Dataset will be used by DataLoader later. Dataset takes a meta as input 207 | # and return a waveform and a target. 208 | dataset = AudioSetDataset(sample_rate=sample_rate) 209 | 210 | # Train sampler 211 | if balanced == "none": 212 | Sampler = TrainSampler 213 | elif balanced == "balanced": 214 | Sampler = BalancedTrainSampler 215 | elif balanced == "alternate": 216 | Sampler = AlternateTrainSampler 217 | 218 | train_sampler = Sampler( 219 | indexes_hdf5_path=train_indexes_hdf5_path, 220 | batch_size=batch_size * 2 if "mixup" in augmentation else batch_size, 221 | black_list_csv=black_list_csv, 222 | ) 223 | 224 | # Evaluate sampler 225 | eval_bal_sampler = EvaluateSampler( 226 | indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size 227 | ) 228 | 229 | eval_test_sampler = EvaluateSampler( 230 | indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size 231 | ) 232 | 233 | # Data loader 234 | train_loader = torch.utils.data.DataLoader( 235 | dataset=dataset, 236 | batch_sampler=train_sampler, 237 | collate_fn=collate_fn, 238 | num_workers=num_workers, 239 | pin_memory=True, 240 | ) 241 | 242 | eval_bal_loader = torch.utils.data.DataLoader( 243 | dataset=dataset, 244 | batch_sampler=eval_bal_sampler, 245 | collate_fn=collate_fn, 246 | num_workers=num_workers, 247 | pin_memory=True, 248 | ) 249 | 250 | eval_test_loader = torch.utils.data.DataLoader( 251 | dataset=dataset, 252 | batch_sampler=eval_test_sampler, 253 | collate_fn=collate_fn, 254 | num_workers=num_workers, 255 | pin_memory=True, 256 | ) 257 | 258 | if "mixup" in augmentation: 259 | mixup_augmenter = Mixup(mixup_alpha=1.0) 260 | 261 | # Evaluator 262 | evaluator = Evaluator(model=model) 263 | 264 | # Statistics 265 | statistics_container = StatisticsContainer(statistics_path) 266 | 267 | # Optimizer 268 | optimizer = optim.Adam( 269 | model.parameters(), 270 | lr=learning_rate, 271 | betas=(0.9, 0.999), 272 | eps=1e-08, 273 | weight_decay=0.0, 274 | amsgrad=True, 275 | ) 276 | 277 | train_bgn_time = time.time() 278 | 279 | # Resume training 280 | if resume_iteration > 0: 281 | resume_checkpoint_path = os.path.join( 282 | workspace, 283 | "checkpoints", 284 | filename, 285 | "sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}".format( 286 | sample_rate, window_size, hop_size, mel_bins, fmin, fmax 287 | ), 288 | "data_type={}".format(data_type), 289 | model_type, 290 | "loss_type={}".format(loss_type), 291 | "balanced={}".format(balanced), 292 | "augmentation={}".format(augmentation), 293 | "batch_size={}".format(batch_size), 294 | "{}_iterations.pth".format(resume_iteration), 295 | ) 296 | 297 | logging.info("Loading checkpoint {}".format(resume_checkpoint_path)) 298 | checkpoint = torch.load(resume_checkpoint_path) 299 | model.load_state_dict(checkpoint["model"]) 300 | train_sampler.load_state_dict(checkpoint["sampler"]) 301 | statistics_container.load_state_dict(resume_iteration) 302 | iteration = checkpoint["iteration"] 303 | 304 | else: 305 | 
iteration = 0 306 | 307 | # Parallel 308 | print("GPU number: {}".format(torch.cuda.device_count())) 309 | model = torch.nn.DataParallel(model) 310 | 311 | if "cuda" in str(device): 312 | model.to(device) 313 | 314 | time1 = time.time() 315 | 316 | for batch_data_dict in train_loader: 317 | """batch_data_dict: { 318 | 'audio_name': (batch_size [*2 if mixup],), 319 | 'waveform': (batch_size [*2 if mixup], clip_samples), 320 | 'target': (batch_size [*2 if mixup], classes_num), 321 | (ifexist) 'mixup_lambda': (batch_size * 2,)} 322 | """ 323 | 324 | # Evaluate 325 | if (iteration % 2000 == 0 and iteration > resume_iteration) or (iteration == 0): 326 | train_fin_time = time.time() 327 | 328 | bal_statistics = evaluator.evaluate(eval_bal_loader) 329 | test_statistics = evaluator.evaluate(eval_test_loader) 330 | 331 | logging.info( 332 | "Validate bal mAP: {:.3f}".format( 333 | np.mean(bal_statistics["average_precision"]) 334 | ) 335 | ) 336 | 337 | logging.info( 338 | "Validate test mAP: {:.3f}".format( 339 | np.mean(test_statistics["average_precision"]) 340 | ) 341 | ) 342 | 343 | statistics_container.append(iteration, bal_statistics, data_type="bal") 344 | statistics_container.append(iteration, test_statistics, data_type="test") 345 | statistics_container.dump() 346 | 347 | train_time = train_fin_time - train_bgn_time 348 | validate_time = time.time() - train_fin_time 349 | 350 | logging.info( 351 | "iteration: {}, train time: {:.3f} s, validate time: {:.3f} s" 352 | "".format(iteration, train_time, validate_time) 353 | ) 354 | 355 | logging.info("------------------------------------") 356 | 357 | train_bgn_time = time.time() 358 | 359 | # Save model 360 | if iteration % 100000 == 0: 361 | checkpoint = { 362 | "iteration": iteration, 363 | "model": model.module.state_dict(), 364 | "sampler": train_sampler.state_dict(), 365 | } 366 | 367 | checkpoint_path = os.path.join( 368 | checkpoints_dir, "{}_iterations.pth".format(iteration) 369 | ) 370 | 371 | torch.save(checkpoint, checkpoint_path) 372 | logging.info("Model saved to {}".format(checkpoint_path)) 373 | 374 | # Mixup lambda 375 | if "mixup" in augmentation: 376 | batch_data_dict["mixup_lambda"] = mixup_augmenter.get_lambda( 377 | batch_size=len(batch_data_dict["waveform"]) 378 | ) 379 | 380 | # Move data to device 381 | for key in batch_data_dict.keys(): 382 | batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device) 383 | 384 | # Forward 385 | model.train() 386 | 387 | if "mixup" in augmentation: 388 | batch_output_dict = model( 389 | batch_data_dict["waveform"], batch_data_dict["mixup_lambda"] 390 | ) 391 | """{'clipwise_output': (batch_size, classes_num), ...}""" 392 | 393 | batch_target_dict = { 394 | "target": do_mixup( 395 | batch_data_dict["target"], batch_data_dict["mixup_lambda"] 396 | ) 397 | } 398 | """{'target': (batch_size, classes_num)}""" 399 | else: 400 | batch_output_dict = model(batch_data_dict["waveform"], None) 401 | """{'clipwise_output': (batch_size, classes_num), ...}""" 402 | 403 | batch_target_dict = {"target": batch_data_dict["target"]} 404 | """{'target': (batch_size, classes_num)}""" 405 | 406 | # Loss 407 | loss = loss_func(batch_output_dict, batch_target_dict) 408 | 409 | # Backward 410 | loss.backward() 411 | print(loss) 412 | 413 | optimizer.step() 414 | optimizer.zero_grad() 415 | 416 | if iteration % 10 == 0: 417 | print( 418 | "--- Iteration: {}, train time: {:.3f} s / 10 iterations ---".format( 419 | iteration, time.time() - time1 420 | ) 421 | ) 422 | time1 = time.time() 423 | 424 | # 
Stop learning 425 | if iteration == early_stop: 426 | break 427 | 428 | iteration += 1 429 | 430 | 431 | if __name__ == "__main__": 432 | 433 | parser = argparse.ArgumentParser(description="Example of parser. ") 434 | subparsers = parser.add_subparsers(dest="mode") 435 | 436 | parser_train = subparsers.add_parser("train") 437 | parser_train.add_argument("--workspace", type=str, required=True) 438 | parser_train.add_argument( 439 | "--data_type", 440 | type=str, 441 | default="full_train", 442 | choices=["balanced_train", "full_train"], 443 | ) 444 | parser_train.add_argument("--sample_rate", type=int, default=32000) 445 | parser_train.add_argument("--window_size", type=int, default=1024) 446 | parser_train.add_argument("--hop_size", type=int, default=320) 447 | parser_train.add_argument("--mel_bins", type=int, default=64) 448 | parser_train.add_argument("--fmin", type=int, default=50) 449 | parser_train.add_argument("--fmax", type=int, default=14000) 450 | parser_train.add_argument("--model_type", type=str, required=True) 451 | parser_train.add_argument( 452 | "--loss_type", type=str, default="clip_bce", choices=["clip_bce"] 453 | ) 454 | parser_train.add_argument( 455 | "--balanced", 456 | type=str, 457 | default="balanced", 458 | choices=["none", "balanced", "alternate"], 459 | ) 460 | parser_train.add_argument( 461 | "--augmentation", type=str, default="mixup", choices=["none", "mixup"] 462 | ) 463 | parser_train.add_argument("--batch_size", type=int, default=32) 464 | parser_train.add_argument("--learning_rate", type=float, default=1e-3) 465 | parser_train.add_argument("--resume_iteration", type=int, default=0) 466 | parser_train.add_argument("--early_stop", type=int, default=1000000) 467 | parser_train.add_argument("--cuda", action="store_true", default=False) 468 | 469 | args = parser.parse_args() 470 | args.filename = get_filename(__file__) 471 | 472 | if args.mode == "train": 473 | train(args) 474 | 475 | else: 476 | raise Exception("Error argument!") 477 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def move_data_to_device(x, device): 8 | if "float" in str(x.dtype): 9 | x = torch.Tensor(x) 10 | elif "int" in str(x.dtype): 11 | x = torch.LongTensor(x) 12 | else: 13 | return x 14 | 15 | return x.to(device) 16 | 17 | 18 | def do_mixup(x, mixup_lambda): 19 | """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 20 | (1, 3, 5, ...). 21 | 22 | Args: 23 | x: (batch_size * 2, ...) 24 | mixup_lambda: (batch_size * 2,) 25 | 26 | Returns: 27 | out: (batch_size, ...) 28 | """ 29 | out = ( 30 | x[0::2].transpose(0, -1) * mixup_lambda[0::2] 31 | + x[1::2].transpose(0, -1) * mixup_lambda[1::2] 32 | ).transpose(0, -1) 33 | return out 34 | 35 | 36 | def append_to_dict(dict, key, value): 37 | if key in dict.keys(): 38 | dict[key].append(value) 39 | else: 40 | dict[key] = [value] 41 | 42 | 43 | def forward(model, generator, return_input=False, return_target=False): 44 | """Forward data to a model. 
45 | 46 | Args: 47 | model: object 48 | generator: object 49 | return_input: bool 50 | return_target: bool 51 | 52 | Returns: 53 | audio_name: (audios_num,) 54 | clipwise_output: (audios_num, classes_num) 55 | (ifexist) segmentwise_output: (audios_num, segments_num, classes_num) 56 | (ifexist) framewise_output: (audios_num, frames_num, classes_num) 57 | (optional) return_input: (audios_num, segment_samples) 58 | (optional) return_target: (audios_num, classes_num) 59 | """ 60 | output_dict = {} 61 | device = next(model.parameters()).device 62 | time1 = time.time() 63 | 64 | # Forward data to a model in mini-batches 65 | for n, batch_data_dict in enumerate(generator): 66 | print(n) 67 | batch_waveform = move_data_to_device(batch_data_dict["waveform"], device) 68 | 69 | with torch.no_grad(): 70 | model.eval() 71 | batch_output = model(batch_waveform) 72 | 73 | append_to_dict(output_dict, "audio_name", batch_data_dict["audio_name"]) 74 | 75 | append_to_dict( 76 | output_dict, 77 | "clipwise_output", 78 | batch_output["clipwise_output"].data.cpu().numpy(), 79 | ) 80 | 81 | if "segmentwise_output" in batch_output.keys(): 82 | append_to_dict( 83 | output_dict, 84 | "segmentwise_output", 85 | batch_output["segmentwise_output"].data.cpu().numpy(), 86 | ) 87 | 88 | if "framewise_output" in batch_output.keys(): 89 | append_to_dict( 90 | output_dict, 91 | "framewise_output", 92 | batch_output["framewise_output"].data.cpu().numpy(), 93 | ) 94 | 95 | if return_input: 96 | append_to_dict(output_dict, "waveform", batch_data_dict["waveform"]) 97 | 98 | if return_target: 99 | if "target" in batch_data_dict.keys(): 100 | append_to_dict(output_dict, "target", batch_data_dict["target"]) 101 | 102 | if n % 10 == 0: 103 | print( 104 | " --- Inference time: {:.3f} s / 10 iterations ---".format( 105 | time.time() - time1 106 | ) 107 | ) 108 | time1 = time.time() 109 | 110 | for key in output_dict.keys(): 111 | output_dict[key] = np.concatenate(output_dict[key], axis=0) 112 | 113 | return output_dict 114 | 115 | 116 | def interpolate(x, ratio): 117 | """Interpolate data in time domain. This is used to compensate the 118 | resolution reduction in downsampling of a CNN. 119 | 120 | Args: 121 | x: (batch_size, time_steps, classes_num) 122 | ratio: int, ratio to interpolate 123 | 124 | Returns: 125 | upsampled: (batch_size, time_steps * ratio, classes_num) 126 | """ 127 | (batch_size, time_steps, classes_num) = x.shape 128 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 129 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 130 | return upsampled 131 | 132 | 133 | def pad_framewise_output(framewise_output, frames_num): 134 | """Pad framewise_output to the same length as input frames. The pad value 135 | is the same as the value of the last frame. 136 | 137 | Args: 138 | framewise_output: (batch_size, frames_num, classes_num) 139 | frames_num: int, number of frames to pad 140 | 141 | Outputs: 142 | output: (batch_size, frames_num, classes_num) 143 | """ 144 | pad = framewise_output[:, -1:, :].repeat( 145 | 1, frames_num - framewise_output.shape[1], 1 146 | ) 147 | """tensor for padding""" 148 | 149 | output = torch.cat((framewise_output, pad), dim=1) 150 | """(batch_size, frames_num, classes_num)""" 151 | 152 | return output 153 | 154 | 155 | def count_parameters(model): 156 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 157 | 158 | 159 | def count_flops(model, audio_length): 160 | """Count flops. 
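The estimate registers forward hooks on Conv1d/Conv2d, Linear, BatchNorm, ReLU and pooling modules, runs a single forward pass on a random (1, audio_length) input, and sums the per-module multiply-add counts.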
Code modified from others' implementation.""" 161 | multiply_adds = True 162 | list_conv2d = [] 163 | 164 | def conv2d_hook(self, input, output): 165 | batch_size, input_channels, input_height, input_width = input[0].size() 166 | output_channels, output_height, output_width = output[0].size() 167 | 168 | kernel_ops = ( 169 | self.kernel_size[0] 170 | * self.kernel_size[1] 171 | * (self.in_channels / self.groups) 172 | * (2 if multiply_adds else 1) 173 | ) 174 | bias_ops = 1 if self.bias is not None else 0 175 | 176 | params = output_channels * (kernel_ops + bias_ops) 177 | flops = batch_size * params * output_height * output_width 178 | 179 | list_conv2d.append(flops) 180 | 181 | list_conv1d = [] 182 | 183 | def conv1d_hook(self, input, output): 184 | batch_size, input_channels, input_length = input[0].size() 185 | output_channels, output_length = output[0].size() 186 | 187 | kernel_ops = ( 188 | self.kernel_size[0] 189 | * (self.in_channels / self.groups) 190 | * (2 if multiply_adds else 1) 191 | ) 192 | bias_ops = 1 if self.bias is not None else 0 193 | 194 | params = output_channels * (kernel_ops + bias_ops) 195 | flops = batch_size * params * output_length 196 | 197 | list_conv1d.append(flops) 198 | 199 | list_linear = [] 200 | 201 | def linear_hook(self, input, output): 202 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 203 | 204 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 205 | bias_ops = self.bias.nelement() 206 | 207 | flops = batch_size * (weight_ops + bias_ops) 208 | list_linear.append(flops) 209 | 210 | list_bn = [] 211 | 212 | def bn_hook(self, input, output): 213 | list_bn.append(input[0].nelement() * 2) 214 | 215 | list_relu = [] 216 | 217 | def relu_hook(self, input, output): 218 | list_relu.append(input[0].nelement() * 2) 219 | 220 | list_pooling2d = [] 221 | 222 | def pooling2d_hook(self, input, output): 223 | batch_size, input_channels, input_height, input_width = input[0].size() 224 | output_channels, output_height, output_width = output[0].size() 225 | 226 | kernel_ops = self.kernel_size * self.kernel_size 227 | bias_ops = 0 228 | params = output_channels * (kernel_ops + bias_ops) 229 | flops = batch_size * params * output_height * output_width 230 | 231 | list_pooling2d.append(flops) 232 | 233 | list_pooling1d = [] 234 | 235 | def pooling1d_hook(self, input, output): 236 | batch_size, input_channels, input_length = input[0].size() 237 | output_channels, output_length = output[0].size() 238 | 239 | kernel_ops = self.kernel_size[0] 240 | bias_ops = 0 241 | 242 | params = output_channels * (kernel_ops + bias_ops) 243 | flops = batch_size * params * output_length 244 | 245 | list_pooling2d.append(flops) 246 | 247 | def foo(net): 248 | childrens = list(net.children()) 249 | if not childrens: 250 | if isinstance(net, nn.Conv2d): 251 | net.register_forward_hook(conv2d_hook) 252 | elif isinstance(net, nn.Conv1d): 253 | net.register_forward_hook(conv1d_hook) 254 | elif isinstance(net, nn.Linear): 255 | net.register_forward_hook(linear_hook) 256 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 257 | net.register_forward_hook(bn_hook) 258 | elif isinstance(net, nn.ReLU): 259 | net.register_forward_hook(relu_hook) 260 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 261 | net.register_forward_hook(pooling2d_hook) 262 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 263 | net.register_forward_hook(pooling1d_hook) 264 | else: 265 | print("Warning: flop of module {} is not 
counted!".format(net)) 266 | return 267 | for c in childrens: 268 | foo(c) 269 | 270 | # Register hook 271 | foo(model) 272 | 273 | device = device = next(model.parameters()).device 274 | input = torch.rand(1, audio_length).to(device) 275 | 276 | out = model(input) 277 | 278 | total_flops = ( 279 | sum(list_conv2d) 280 | + sum(list_conv1d) 281 | + sum(list_linear) 282 | + sum(list_bn) 283 | + sum(list_relu) 284 | + sum(list_pooling2d) 285 | + sum(list_pooling1d) 286 | ) 287 | 288 | return total_flops 289 | -------------------------------------------------------------------------------- /audioldm_eval/feature_extractors/panns/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from scipy import stats 5 | import datetime 6 | import pickle 7 | 8 | 9 | def create_folder(fd): 10 | if not os.path.exists(fd): 11 | os.makedirs(fd) 12 | 13 | 14 | def get_filename(path): 15 | path = os.path.realpath(path) 16 | na_ext = path.split("/")[-1] 17 | na = os.path.splitext(na_ext)[0] 18 | return na 19 | 20 | 21 | def get_sub_filepaths(folder): 22 | paths = [] 23 | for root, dirs, files in os.walk(folder): 24 | for name in files: 25 | path = os.path.join(root, name) 26 | paths.append(path) 27 | return paths 28 | 29 | 30 | def create_logging(log_dir, filemode): 31 | create_folder(log_dir) 32 | i1 = 0 33 | 34 | while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): 35 | i1 += 1 36 | 37 | log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) 38 | logging.basicConfig( 39 | level=logging.DEBUG, 40 | format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", 41 | datefmt="%a, %d %b %Y %H:%M:%S", 42 | filename=log_path, 43 | filemode=filemode, 44 | ) 45 | 46 | # Print to console 47 | console = logging.StreamHandler() 48 | console.setLevel(logging.INFO) 49 | formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") 50 | console.setFormatter(formatter) 51 | logging.getLogger("").addHandler(console) 52 | 53 | return logging 54 | 55 | 56 | def read_metadata(csv_path, classes_num, id_to_ix): 57 | """Read metadata of AudioSet from a csv file. 
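Expects the standard AudioSet segments csv layout: three header lines, then rows such as --4gqARaEJE, 0.000, 10.000, "/m/068hy,/m/07q6cd_,..." (YouTube id, start time, end time, quoted label ids).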
58 | 59 | Args: 60 | csv_path: str 61 | 62 | Returns: 63 | meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} 64 | """ 65 | 66 | with open(csv_path, "r") as fr: 67 | lines = fr.readlines() 68 | lines = lines[3:] # Remove heads 69 | 70 | audios_num = len(lines) 71 | targets = np.zeros((audios_num, classes_num), dtype=np.bool) 72 | audio_names = [] 73 | 74 | for n, line in enumerate(lines): 75 | items = line.split(", ") 76 | """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" 77 | 78 | audio_name = "Y{}.wav".format( 79 | items[0] 80 | ) # Audios are started with an extra 'Y' when downloading 81 | label_ids = items[3].split('"')[1].split(",") 82 | 83 | audio_names.append(audio_name) 84 | 85 | # Target 86 | for id in label_ids: 87 | ix = id_to_ix[id] 88 | targets[n, ix] = 1 89 | 90 | meta_dict = {"audio_name": np.array(audio_names), "target": targets} 91 | return meta_dict 92 | 93 | 94 | def float32_to_int16(x): 95 | assert np.max(np.abs(x)) <= 1.2 96 | x = np.clip(x, -1, 1) 97 | return (x * 32767.0).astype(np.int16) 98 | 99 | 100 | def int16_to_float32(x): 101 | return (x / 32767.0).astype(np.float32) 102 | 103 | 104 | def pad_or_truncate(x, audio_length): 105 | """Pad all audio to specific length.""" 106 | if len(x) <= audio_length: 107 | return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) 108 | else: 109 | return x[0:audio_length] 110 | 111 | 112 | def d_prime(auc): 113 | d_prime = stats.norm().ppf(auc) * np.sqrt(2.0) 114 | return d_prime 115 | 116 | 117 | class Mixup(object): 118 | def __init__(self, mixup_alpha, random_seed=1234): 119 | """Mixup coefficient generator.""" 120 | self.mixup_alpha = mixup_alpha 121 | self.random_state = np.random.RandomState(random_seed) 122 | 123 | def get_lambda(self, batch_size): 124 | """Get mixup random coefficients. 
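Coefficients are drawn from Beta(mixup_alpha, mixup_alpha) and returned in complementary pairs, e.g. for batch_size=4 the result looks like [lam0, 1 - lam0, lam1, 1 - lam1], matching the even/odd pairing applied by do_mixup in pytorch_utils.py.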
125 | Args: 126 | batch_size: int 127 | Returns: 128 | mixup_lambdas: (batch_size,) 129 | """ 130 | mixup_lambdas = [] 131 | for n in range(0, batch_size, 2): 132 | lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0] 133 | mixup_lambdas.append(lam) 134 | mixup_lambdas.append(1.0 - lam) 135 | 136 | return np.array(mixup_lambdas) 137 | 138 | 139 | class StatisticsContainer(object): 140 | def __init__(self, statistics_path): 141 | """Contain statistics of different training iterations.""" 142 | self.statistics_path = statistics_path 143 | 144 | self.backup_statistics_path = "{}_{}.pkl".format( 145 | os.path.splitext(self.statistics_path)[0], 146 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 147 | ) 148 | 149 | self.statistics_dict = {"bal": [], "test": []} 150 | 151 | def append(self, iteration, statistics, data_type): 152 | statistics["iteration"] = iteration 153 | self.statistics_dict[data_type].append(statistics) 154 | 155 | def dump(self): 156 | pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) 157 | pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) 158 | logging.info(" Dump statistics to {}".format(self.statistics_path)) 159 | logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) 160 | 161 | def load_state_dict(self, resume_iteration): 162 | self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) 163 | 164 | resume_statistics_dict = {"bal": [], "test": []} 165 | 166 | for key in self.statistics_dict.keys(): 167 | for statistics in self.statistics_dict[key]: 168 | if statistics["iteration"] <= resume_iteration: 169 | resume_statistics_dict[key].append(statistics) 170 | 171 | self.statistics_dict = resume_statistics_dict 172 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haoheliu/audioldm_eval/8dc07ee7c42f9dc6e295460a1034175a0d49b436/audioldm_eval/metrics/__init__.py -------------------------------------------------------------------------------- /audioldm_eval/metrics/fad.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate Frechet Audio Distance betweeen two audio directories. 
3 | 4 | Frechet distance implementation adapted from: https://github.com/mseitzer/pytorch-fid 5 | 6 | VGGish adapted from: https://github.com/harritaylor/torchvggish 7 | """ 8 | import os 9 | import numpy as np 10 | import torch 11 | 12 | from torch import nn 13 | from scipy import linalg 14 | from tqdm import tqdm 15 | from multiprocessing.dummy import Pool as ThreadPool 16 | from audioldm_eval.datasets.load_mel import WaveDataset 17 | from torch.utils.data import DataLoader 18 | 19 | class FrechetAudioDistance: 20 | def __init__( 21 | self, use_pca=False, use_activation=False, verbose=False, audio_load_worker=8 22 | ): 23 | self.__get_model(use_pca=use_pca, use_activation=use_activation) 24 | self.verbose = verbose 25 | self.audio_load_worker = audio_load_worker 26 | 27 | def __get_model(self, use_pca=False, use_activation=False): 28 | """ 29 | Params: 30 | -- x : Either 31 | (i) a string which is the directory of a set of audio files, or 32 | (ii) a np.ndarray of shape (num_samples, sample_length) 33 | """ 34 | self.model = torch.hub.load("harritaylor/torchvggish", "vggish") 35 | if not use_pca: 36 | self.model.postprocess = False 37 | if not use_activation: 38 | self.model.embeddings = nn.Sequential( 39 | *list(self.model.embeddings.children())[:-1] 40 | ) 41 | self.model.eval() 42 | 43 | def load_audio_data(self, x): 44 | outputloader = DataLoader( 45 | WaveDataset( 46 | x, 47 | 16000, 48 | limit_num=None, 49 | ), 50 | batch_size=1, 51 | sampler=None, 52 | num_workers=8, 53 | ) 54 | data_list = [] 55 | print("Loading data to RAM") 56 | for batch in tqdm(outputloader): 57 | data_list.append((batch[0][0,0], 16000)) 58 | return data_list 59 | 60 | def get_embeddings(self, x, sr=16000, limit_num=None): 61 | """ 62 | Get embeddings using VGGish model. 63 | Params: 64 | -- x : Either 65 | (i) a string which is the directory of a set of audio files, or 66 | (ii) a list of np.ndarray audio samples 67 | -- sr : Sampling rate, if x is a list of audio samples. Default value is 16000. 68 | """ 69 | embd_lst = [] 70 | x = self.load_audio_data(x) 71 | if isinstance(x, list): 72 | try: 73 | for audio, sr in tqdm(x, disable=(not self.verbose)): 74 | embd = self.model.forward(audio.numpy(), sr) 75 | if self.model.device == torch.device("cuda"): 76 | embd = embd.cpu() 77 | embd = embd.detach().numpy() 78 | embd_lst.append(embd) 79 | except Exception as e: 80 | print( 81 | "[Frechet Audio Distance] get_embeddings throw an exception: {}".format( 82 | str(e) 83 | ) 84 | ) 85 | else: 86 | raise AttributeError 87 | 88 | return np.concatenate(embd_lst, axis=0) 89 | 90 | def calculate_embd_statistics(self, embd_lst): 91 | if isinstance(embd_lst, list): 92 | embd_lst = np.array(embd_lst) 93 | mu = np.mean(embd_lst, axis=0) 94 | sigma = np.cov(embd_lst, rowvar=False) 95 | return mu, sigma 96 | 97 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): 98 | """ 99 | Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py 100 | 101 | Numpy implementation of the Frechet Distance. 102 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 103 | and X_2 ~ N(mu_2, C_2) is 104 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 105 | Stable version by Dougal J. Sutherland. 106 | Params: 107 | -- mu1 : Numpy array containing the activations of a layer of the 108 | inception net (like returned by the function 'get_predictions') 109 | for generated samples. 
110 | -- mu2 : The sample mean over activations, precalculated on an 111 | representative data set. 112 | -- sigma1: The covariance matrix over activations for generated samples. 113 | -- sigma2: The covariance matrix over activations, precalculated on an 114 | representative data set. 115 | Returns: 116 | -- : The Frechet Distance. 117 | """ 118 | 119 | mu1 = np.atleast_1d(mu1) 120 | mu2 = np.atleast_1d(mu2) 121 | 122 | sigma1 = np.atleast_2d(sigma1) 123 | sigma2 = np.atleast_2d(sigma2) 124 | 125 | assert ( 126 | mu1.shape == mu2.shape 127 | ), "Training and test mean vectors have different lengths" 128 | assert ( 129 | sigma1.shape == sigma2.shape 130 | ), "Training and test covariances have different dimensions" 131 | 132 | diff = mu1 - mu2 133 | 134 | # Product might be almost singular 135 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 136 | if not np.isfinite(covmean).all(): 137 | msg = ( 138 | "fid calculation produces singular product; " 139 | "adding %s to diagonal of cov estimates" 140 | ) % eps 141 | print(msg) 142 | offset = np.eye(sigma1.shape[0]) * eps 143 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 144 | 145 | # Numerical error might give slight imaginary component 146 | if np.iscomplexobj(covmean): 147 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 148 | m = np.max(np.abs(covmean.imag)) 149 | raise ValueError("Imaginary component {}".format(m)) 150 | covmean = covmean.real 151 | 152 | tr_covmean = np.trace(covmean) 153 | 154 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 155 | 156 | def score(self, background_dir, eval_dir, store_embds=False, limit_num=None, recalculate = False): 157 | # background_dir: generated samples 158 | # eval_dir: groundtruth samples 159 | try: 160 | fad_target_folder_cache = eval_dir + "_fad_feature_cache.npy" 161 | fad_generated_folder_cache = background_dir + "_fad_feature_cache.npy" 162 | 163 | if(not os.path.exists(fad_generated_folder_cache) or recalculate): 164 | embds_background = self.get_embeddings(background_dir, limit_num=limit_num) 165 | np.save(fad_generated_folder_cache, embds_background) 166 | else: 167 | print("Reload fad_generated_folder_cache", fad_generated_folder_cache) 168 | embds_background = np.load(fad_generated_folder_cache) 169 | 170 | if(not os.path.exists(fad_target_folder_cache) or recalculate): 171 | embds_eval = self.get_embeddings(eval_dir, limit_num=limit_num) 172 | np.save(fad_target_folder_cache, embds_eval) 173 | else: 174 | print("Reload fad_target_folder_cache", fad_target_folder_cache) 175 | embds_eval = np.load(fad_target_folder_cache) 176 | 177 | if store_embds: 178 | np.save("embds_background.npy", embds_background) 179 | np.save("embds_eval.npy", embds_eval) 180 | 181 | if len(embds_background) == 0: 182 | print( 183 | "[Frechet Audio Distance] background set dir is empty, exitting..." 
184 | ) 185 | return -1 186 | 187 | if len(embds_eval) == 0: 188 | print("[Frechet Audio Distance] eval set dir is empty, exitting...") 189 | return -1 190 | 191 | mu_background, sigma_background = self.calculate_embd_statistics( 192 | embds_background 193 | ) 194 | mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval) 195 | 196 | fad_score = self.calculate_frechet_distance( 197 | mu_background, sigma_background, mu_eval, sigma_eval 198 | ) 199 | 200 | return {"frechet_audio_distance": fad_score} 201 | 202 | except Exception as e: 203 | print("[Frechet Audio Distance] exception thrown, {}".format(str(e))) 204 | return -1 205 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | # FID评价保真度,越小越好 6 | def calculate_fid( 7 | featuresdict_1, featuresdict_2, feat_layer_name 8 | ): # using 2048 layer to calculate 9 | eps = 1e-6 10 | features_1 = featuresdict_1[feat_layer_name] 11 | features_2 = featuresdict_2[feat_layer_name] 12 | 13 | assert torch.is_tensor(features_1) and features_1.dim() == 2 14 | assert torch.is_tensor(features_2) and features_2.dim() == 2 15 | 16 | stat_1 = { 17 | "mu": np.mean(features_1.numpy(), axis=0), 18 | "sigma": np.cov(features_1.numpy(), rowvar=False), 19 | } 20 | stat_2 = { 21 | "mu": np.mean(features_2.numpy(), axis=0), 22 | "sigma": np.cov(features_2.numpy(), rowvar=False), 23 | } 24 | 25 | print("Computing Frechet Distance") 26 | 27 | mu1, sigma1 = stat_1["mu"], stat_1["sigma"] 28 | mu2, sigma2 = stat_2["mu"], stat_2["sigma"] 29 | assert mu1.shape == mu2.shape and mu1.dtype == mu2.dtype 30 | assert sigma1.shape == sigma2.shape and sigma1.dtype == sigma2.dtype 31 | 32 | mu1 = np.atleast_1d(mu1) 33 | mu2 = np.atleast_1d(mu2) 34 | 35 | sigma1 = np.atleast_2d(sigma1) 36 | sigma2 = np.atleast_2d(sigma2) 37 | 38 | assert ( 39 | mu1.shape == mu2.shape 40 | ), "Training and test mean vectors have different lengths" 41 | assert ( 42 | sigma1.shape == sigma2.shape 43 | ), "Training and test covariances have different dimensions" 44 | 45 | diff = mu1 - mu2 46 | 47 | # Product might be almost singular 48 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 49 | if not np.isfinite(covmean).all(): 50 | print( 51 | f"WARNING: fid calculation produces singular product; adding {eps} to diagonal of cov" 52 | ) 53 | offset = np.eye(sigma1.shape[0]) * eps 54 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 55 | 56 | # Numerical error might give slight imaginary component 57 | if np.iscomplexobj(covmean): 58 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 59 | m = np.max(np.abs(covmean.imag)) 60 | assert False, "Imaginary component {}".format(m) 61 | covmean = covmean.real 62 | 63 | tr_covmean = np.trace(covmean) 64 | 65 | fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 66 | 67 | return { 68 | "frechet_distance": float(fid), 69 | } 70 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/gs/__init__.py: -------------------------------------------------------------------------------- 1 | from .geom_score import * 2 | from .top_utils import * 3 | from .utils import * 4 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/gs/geom_score.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from .utils import relative 4 | from .utils import witness 5 | import numpy as np 6 | 7 | 8 | def rlt(X, L_0=64, gamma=None, i_max=100): 9 | """ 10 | This function implements Algorithm 1 for one sample of landmarks. 11 | 12 | Args: 13 | X: np.array representing the dataset. 14 | L_0: number of landmarks to sample. 15 | gamma: float, parameter determining the maximum persistence value. 16 | i_max: int, upper bound on the value of beta_1 to compute. 17 | 18 | Returns 19 | An array of size (i_max, ) containing RLT(i, 1, X, L) 20 | for randomly sampled landmarks. 21 | """ 22 | if not isinstance(X, np.ndarray): 23 | raise ValueError("X should be a numpy array") 24 | if len(X.shape) != 2: 25 | raise ValueError("X should be 2d array, got shape {}".format(X.shape)) 26 | N = X.shape[0] 27 | if gamma is None: 28 | gamma = 1.0 / 128 * N / 5000 29 | I_1, alpha_max = witness(X, L_0=L_0, gamma=gamma) 30 | res = relative(I_1, alpha_max, i_max=i_max) 31 | return res 32 | 33 | 34 | def rlts(X, L_0=64, gamma=None, i_max=100, n=1000): 35 | """ 36 | This function implements Algorithm 1. 37 | 38 | Args: 39 | X: np.array representing the dataset. 40 | L_0: number of landmarks to sample. 41 | gamma: float, parameter determining the maximum persistence value. 42 | i_max: int, upper bound on the value of beta_1 to compute. 43 | n: int, number of samples 44 | Returns 45 | An array of size (n, i_max) containing RLT(i, 1, X, L) 46 | for n collections of randomly sampled landmarks. 47 | """ 48 | rlts = np.zeros((n, i_max)) 49 | for i in range(n): 50 | rlts[i, :] = rlt(X, L_0, gamma, i_max) 51 | if i % 10 == 0: 52 | print("Done {}/{}".format(i, n)) 53 | return rlts 54 | 55 | 56 | def geom_score(rlts1, rlts2): 57 | """ 58 | This function implements Algorithm 2. 59 | 60 | Args: 61 | rlts1 and rlts2: arrays as returned by the function "rlts". 62 | Returns 63 | Float, a number representing topological similarity of two datasets. 
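Lower is better: the score is the squared L2 distance between the mean RLT curves, so two identical sets of RLTs give 0. Sketch (X1 and X2 are placeholder 2d arrays): score = geom_score(rlts(X1, n=100), rlts(X2, n=100)).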
64 | 65 | """ 66 | mrlt1 = np.mean(rlts1, axis=0) 67 | mrlt2 = np.mean(rlts2, axis=0) 68 | return np.sum((mrlt1 - mrlt2) ** 2) 69 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/gs/top_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def circle(N=5000): 5 | phi = 2 * np.pi * np.random.rand(N) 6 | x = [[np.sin(phi0), np.cos(phi0)] for phi0 in phi] 7 | x = np.array(x) 8 | x = x + 0.05 * np.random.randn(N, 2) 9 | return x 10 | 11 | 12 | def filled_circle(N=5000): 13 | ans = [] 14 | while len(ans) < N: 15 | x = np.random.rand(2) * 2.0 - 1.0 16 | if np.linalg.norm(x) < 1: 17 | ans.append(x) 18 | return np.array(ans) + 0.05 * np.random.randn(N, 2) 19 | 20 | 21 | def circle_quorter(N=5000): 22 | phi = np.pi * np.random.rand(N) + np.pi / 2 23 | x = [[np.sin(phi0), np.cos(phi0)] for phi0 in phi] 24 | x = np.array(x) 25 | x = x + 0.05 * np.random.randn(N, 2) 26 | return x 27 | 28 | 29 | def circle_thin(N=5000): 30 | phi = np.random.randn(N) 31 | x = [[np.sin(phi0), np.cos(phi0)] for phi0 in phi] 32 | x = np.array(x) 33 | x = x + 0.05 * np.random.randn(N, 2) 34 | return x 35 | 36 | 37 | def planar(N=5000, zdim=32, dim=784): 38 | A = np.random.rand(N, zdim) 39 | z = np.random.rand(zdim, dim) 40 | return np.dot(A, z) 41 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/gs/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | import gudhi 3 | except ImportError as e: 4 | import six 5 | 6 | error = e.__class__( 7 | "You are likely missing your GUDHI installation, " 8 | "you should visit http://gudhi.gforge.inria.fr/python/latest/installation.html " 9 | "for further instructions.\nIf you use conda, you can use\nconda install -c conda-forge gudhi" 10 | ) 11 | six.raise_from(error, e) 12 | 13 | import numpy as np 14 | from scipy.spatial.distance import cdist # , pdist, squareform 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def relative(I_1, alpha_max, i_max=100): 19 | """ 20 | For a collection of intervals I_1 this functions computes 21 | RLT by formulas (2) and (3). This function will be typically called 22 | on the output of the gudhi persistence_intervals_in_dimension function. 23 | 24 | Args: 25 | I_1: list of intervals e.g. [[0, 1], [0, 2], [0, np.inf]]. 26 | alpha_max: float, the maximal persistence value 27 | i_max: int, upper bound on the value of beta_1 to compute. 28 | 29 | Returns 30 | An array of size (i_max, ) containing desired RLT. 31 | """ 32 | 33 | persistence_intervals = [] 34 | # If for some interval we have that it persisted up to np.inf 35 | # we replace this point with alpha_max. 36 | for interval in I_1: 37 | if not np.isinf(interval[1]): 38 | persistence_intervals.append(list(interval)) 39 | elif np.isinf(interval[1]): 40 | persistence_intervals.append([interval[0], alpha_max]) 41 | 42 | # If there are no intervals in H1 then we always observed 0 holes. 
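    # In that case all mass is assigned to beta_1 = 0, i.e. rlt[0] = 1.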
43 | if len(persistence_intervals) == 0: 44 | rlt = np.zeros(i_max) 45 | rlt[0] = 1.0 46 | return rlt 47 | 48 | persistence_intervals_ext = persistence_intervals + [[0, alpha_max]] 49 | persistence_intervals_ext = np.array(persistence_intervals_ext) 50 | persistence_intervals = np.array(persistence_intervals) 51 | 52 | # Change in the value of beta_1 may happen only at the boundary points 53 | # of the intervals 54 | switch_points = np.sort(np.unique(persistence_intervals_ext.flatten())) 55 | rlt = np.zeros(i_max) 56 | for i in range(switch_points.shape[0] - 1): 57 | midpoint = (switch_points[i] + switch_points[i + 1]) / 2 58 | s = 0 59 | for interval in persistence_intervals: 60 | # Count how many intervals contain midpoint 61 | if midpoint >= interval[0] and midpoint < interval[1]: 62 | s = s + 1 63 | if s < i_max: 64 | rlt[s] += switch_points[i + 1] - switch_points[i] 65 | 66 | return rlt / alpha_max 67 | 68 | 69 | def lmrk_table(W, L): 70 | """ 71 | Helper function to construct an input for the gudhi.WitnessComplex 72 | function. 73 | 74 | Args: 75 | W: 2d array of size w x d, containing witnesses 76 | L: 2d array of size l x d containing landmarks 77 | 78 | Returns 79 | Return a 3d array D of size w x l x 2 and the maximal distance 80 | between W and L. 81 | 82 | D satisfies the property that D[i, :, :] is [idx_i, dists_i], 83 | where dists_i are the sorted distances from the i-th witness to each 84 | point in L and idx_i are the indices of the corresponding points 85 | in L, e.g., 86 | D[i, :, :] = [[0, 0.1], [1, 0.2], [3, 0.3], [2, 0.4]] 87 | """ 88 | 89 | a = cdist(W, L) 90 | max_val = np.max(a) 91 | idx = np.argsort(a) 92 | b = a[np.arange(np.shape(a)[0])[:, np.newaxis], idx] 93 | return np.dstack([idx, b]), max_val 94 | 95 | 96 | def random_landmarks(X, L_0=32): 97 | """ 98 | Randomly sample L_0 points from X. 99 | """ 100 | sz = X.shape[0] 101 | idx = np.random.choice(sz, L_0) 102 | L = X[idx] 103 | return L 104 | 105 | 106 | def witness(X, gamma=1.0 / 128, L_0=64): 107 | """ 108 | This function computes the persistence intervals for the dataset 109 | X using the witness complex. 110 | 111 | Args: 112 | X: 2d array representing the dataset. 113 | gamma: parameter determining the maximal persistence value. 114 | L_0: int, number of landmarks to use. 115 | 116 | Returns 117 | A list of persistence intervals and the maximal persistence value. 118 | """ 119 | L = random_landmarks(X, L_0) 120 | W = X 121 | lmrk_tab, max_dist = lmrk_table(W, L) 122 | wc = gudhi.WitnessComplex(lmrk_tab) 123 | alpha_max = max_dist * gamma 124 | st = wc.create_simplex_tree(max_alpha_square=alpha_max, limit_dimension=2) 125 | # this seems to modify the st object 126 | st.persistence(homology_coeff_field=2) 127 | diag = st.persistence_intervals_in_dimension(1) 128 | return diag, alpha_max 129 | 130 | 131 | def fancy_plot(y, color="C0", label="", alpha=0.3): 132 | """ 133 | A function for a nice visualization of MRLT. 
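Here MRLT denotes the mean RLT curve, i.e. np.mean(rlts, axis=0) as used in geom_score.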
134 | """ 135 | n = y.shape[0] 136 | x = np.arange(n) 137 | xleft = x - 0.5 138 | xright = x + 0.5 139 | X = np.array([xleft, xright]).T.flatten() 140 | Xn = np.zeros(X.shape[0] + 2) 141 | Xn[1:-1] = X 142 | Xn[0] = -0.5 143 | Xn[-1] = n - 0.5 144 | Y = np.array([y, y]).T.flatten() 145 | Yn = np.zeros(Y.shape[0] + 2) 146 | Yn[1:-1] = Y 147 | plt.bar(x, y, width=1, alpha=alpha, color=color, edgecolor=color) 148 | plt.plot(Xn, Yn, c=color, label=label, lw=3) 149 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/isc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def calculate_isc(featuresdict, feat_layer_name, rng_seed, samples_shuffle, splits): 6 | print("Computing Inception Score") 7 | 8 | features = featuresdict[feat_layer_name] 9 | 10 | assert torch.is_tensor(features) and features.dim() == 2 11 | N, C = features.shape 12 | if samples_shuffle: 13 | rng = np.random.RandomState(rng_seed) 14 | features = features[rng.permutation(N), :] 15 | features = features.double() 16 | 17 | p = features.softmax(dim=1) 18 | log_p = features.log_softmax(dim=1) 19 | 20 | scores = [] 21 | for i in range(splits): 22 | p_chunk = p[(i * N // splits) : ((i + 1) * N // splits), :] # 一部分的预测概率 23 | log_p_chunk = log_p[(i * N // splits) : ((i + 1) * N // splits), :] # log 24 | q_chunk = p_chunk.mean(dim=0, keepdim=True) # 概率的均值 25 | kl = p_chunk * (log_p_chunk - q_chunk.log()) # 26 | kl = kl.sum(dim=1).mean().exp().item() 27 | scores.append(kl) 28 | # print("scores",scores) 29 | return { 30 | "inception_score_mean": float(np.mean(scores)), 31 | "inception_score_std": float(np.std(scores)), 32 | } 33 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/kid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | # 分多组,每组一定的数量,然后每组分别计算MMD 6 | 7 | 8 | def calculate_kid( 9 | featuresdict_1, 10 | featuresdict_2, 11 | subsets, 12 | subset_size, 13 | degree, 14 | gamma, 15 | coef0, 16 | rng_seed, 17 | feat_layer_name, 18 | ): 19 | features_1 = featuresdict_1[feat_layer_name] 20 | features_2 = featuresdict_2[feat_layer_name] 21 | 22 | assert torch.is_tensor(features_1) and features_1.dim() == 2 23 | assert torch.is_tensor(features_2) and features_2.dim() == 2 24 | assert features_1.shape[1] == features_2.shape[1] 25 | if subset_size > len(features_2): 26 | print( 27 | f"WARNING: subset size ({subset_size}) is larger than feature length ({len(features_2)}). ", 28 | "Using", 29 | len(features_2), 30 | "for both datasets", 31 | ) 32 | subset_size = len(features_2) 33 | if subset_size > len(features_1): 34 | print( 35 | f"WARNING: subset size ({subset_size}) is larger than feature length ({len(features_1)}). 
", 36 | "Using", 37 | len(features_1), 38 | "for both datasets", 39 | ) 40 | subset_size = len(features_1) 41 | 42 | features_1 = features_1.cpu().numpy() 43 | features_2 = features_2.cpu().numpy() 44 | 45 | mmds = np.zeros(subsets) 46 | rng = np.random.RandomState(rng_seed) 47 | 48 | for i in tqdm( 49 | range(subsets), 50 | leave=False, 51 | unit="subsets", 52 | desc="Computing Kernel Inception Distance", 53 | ): 54 | f1 = features_1[rng.choice(len(features_1), subset_size, replace=False)] 55 | f2 = features_2[rng.choice(len(features_2), subset_size, replace=False)] 56 | o = polynomial_mmd(f1, f2, degree, gamma, coef0) 57 | mmds[i] = o 58 | 59 | return { 60 | "kernel_inception_distance_mean": float(np.mean(mmds)), 61 | "kernel_inception_distance_std": float(np.std(mmds)), 62 | } 63 | 64 | 65 | def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1): 66 | if gamma in [None, "none", "null", "None"]: 67 | gamma = 1.0 / X.shape[1] 68 | K = (np.matmul(X, Y.T) * gamma + coef0) ** degree 69 | return K 70 | 71 | 72 | def polynomial_mmd(features_1, features_2, degree, gamma, coef0): 73 | K_XX = polynomial_kernel( 74 | features_1, features_1, degree=degree, gamma=gamma, coef0=coef0 75 | ) 76 | K_YY = polynomial_kernel( 77 | features_2, features_2, degree=degree, gamma=gamma, coef0=coef0 78 | ) 79 | K_XY = polynomial_kernel( 80 | features_1, features_2, degree=degree, gamma=gamma, coef0=coef0 81 | ) 82 | 83 | # based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py 84 | # changed to not compute the full kernel matrix at once 85 | m = K_XX.shape[0] 86 | assert K_XX.shape == (m, m) 87 | assert K_XY.shape == (m, m) 88 | assert K_YY.shape == (m, m) 89 | 90 | diag_X = np.diagonal(K_XX) 91 | diag_Y = np.diagonal(K_YY) 92 | 93 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X 94 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 95 | K_XY_sums_0 = K_XY.sum(axis=0) 96 | 97 | Kt_XX_sum = Kt_XX_sums.sum() 98 | Kt_YY_sum = Kt_YY_sums.sum() 99 | K_XY_sum = K_XY_sums_0.sum() 100 | 101 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1)) 102 | mmd2 -= 2 * K_XY_sum / (m * m) 103 | 104 | return mmd2 105 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | import os 4 | 5 | 6 | def path_to_sharedkey(path, dataset_name, classes=None): 7 | if dataset_name.lower() == "vggsound": 8 | # a generic oneliner which extracts the unique filename for the dataset. 9 | # Works on both FakeFolder and VGGSound* datasets 10 | sharedkey = Path(path).stem.replace("_mel", "").split("_sample_")[0] 11 | elif dataset_name.lower() == "vas": 12 | # in the case of vas the procedure is a bit more tricky and involves relying on the premise that 13 | # the folder names (.../VAS_validation/cls_0, .../cls_1 etc) are made after enumerating sorted list 14 | # of classes. 
15 | classes = sorted(classes) 16 | target_to_label = {f"cls_{i}": c for i, c in enumerate(classes)} 17 | # replacing class folder with the name of the class to match the original dataset (cls_2 -> dog) 18 | for folder_cls_name, label in target_to_label.items(): 19 | path = path.replace(folder_cls_name, label).replace( 20 | "melspec_10s_22050hz/", "" 21 | ) 22 | # merging video name with class name to make a unique shared key 23 | sharedkey = ( 24 | Path(path).parent.stem 25 | + "_" 26 | + Path(path).stem.replace("_mel", "").split("_sample_")[0] 27 | ) 28 | elif dataset_name.lower() == "caps": # stem : 获取/.之间的部分 29 | sharedkey = Path(path).stem.replace("_mel", "").split("_sample_")[0] # 获得原文件名称 30 | else: 31 | raise NotImplementedError 32 | return sharedkey 33 | 34 | 35 | def calculate_kl(featuresdict_1, featuresdict_2, feat_layer_name, same_name=True): 36 | # test_input(featuresdict_1, featuresdict_2, feat_layer_name, dataset_name, classes) 37 | if not same_name: 38 | return ( 39 | { 40 | "kullback_leibler_divergence_sigmoid": float(-1), 41 | "kullback_leibler_divergence_softmax": float(-1), 42 | }, 43 | None, 44 | None, 45 | ) 46 | 47 | print( 48 | 'KL: Assuming that `input2` is "pseudo" target and `input1` is prediction. KL(input2_i||input1_i)' 49 | ) 50 | EPS = 1e-6 51 | features_1 = featuresdict_1[feat_layer_name] 52 | features_2 = featuresdict_2[feat_layer_name] 53 | # # print('features_1 ',features_1.shape) # the predicted (num*10, class_num) 54 | # # print('features_2 ',features_2.shape) # the true 55 | paths_1 = [os.path.basename(x) for x in featuresdict_1["file_path_"]] 56 | paths_2 = [os.path.basename(x) for x in featuresdict_2["file_path_"]] 57 | # # print('paths_1 ',len(paths_1)) its path 58 | # # print('paths_2 ',len(paths_2)) 59 | path_to_feats_1 = {p: f for p, f in zip(paths_1, features_1)} 60 | # #print(path_to_feats_1) 61 | path_to_feats_2 = {p: f for p, f in zip(paths_2, features_2)} 62 | # # dataset_name: caps 63 | # # in input1 (fakes) can have multiple samples per video, while input2 has only one real 64 | # sharedkey_to_feats_1 = {path_to_sharedkey(p, dataset_name, classes): [] for p in paths_1} 65 | sharedkey_to_feats_1 = {p: path_to_feats_1[p] for p in paths_1} 66 | sharedkey_to_feats_2 = {p: path_to_feats_2[p] for p in paths_2} 67 | # sharedkey_to_feats_2 = {path_to_sharedkey(p, dataset_name, classes):path_to_feats_2[p] for p in paths_1} 68 | 69 | features_1 = [] 70 | features_2 = [] 71 | 72 | for sharedkey, feat_2 in sharedkey_to_feats_2.items(): 73 | # print("feat_2",feat_2) 74 | if sharedkey not in sharedkey_to_feats_1.keys(): 75 | print("%s is not in the generation result" % sharedkey) 76 | continue 77 | features_1.extend([sharedkey_to_feats_1[sharedkey]]) 78 | # print("feature_step",len(features_1)) 79 | # print("share",sharedkey_to_feats_1[sharedkey]) 80 | # just replicating the ground truth logits to compare with multiple samples in prediction 81 | # samples_num = len(sharedkey_to_feats_1[sharedkey]) 82 | features_2.extend([feat_2]) 83 | 84 | features_1 = torch.stack(features_1, dim=0) 85 | features_2 = torch.stack(features_2, dim=0) 86 | 87 | kl_ref = torch.nn.functional.kl_div( 88 | (features_1.softmax(dim=1) + EPS).log(), 89 | features_2.softmax(dim=1), 90 | reduction="none", 91 | ) / len(features_1) 92 | kl_ref = torch.mean(kl_ref, dim=-1) 93 | 94 | # AudioGen use this formulation 95 | kl_softmax = torch.nn.functional.kl_div( 96 | (features_1.softmax(dim=1) + EPS).log(), 97 | features_2.softmax(dim=1), 98 | reduction="sum", 99 | ) / len(features_1) 100 
| 101 | # For multi-class audio clips, this formulation could be better 102 | kl_sigmoid = torch.nn.functional.kl_div( 103 | (features_1.sigmoid() + EPS).log(), features_2.sigmoid(), reduction="sum" 104 | ) / len(features_1) 105 | 106 | return ( 107 | { 108 | "kullback_leibler_divergence_sigmoid": float(kl_sigmoid), 109 | "kullback_leibler_divergence_softmax": float(kl_softmax), 110 | }, 111 | kl_ref, 112 | paths_1, 113 | ) 114 | 115 | 116 | def test_input(featuresdict_1, featuresdict_2, feat_layer_name, dataset_name, classes): 117 | assert feat_layer_name == "logits", "This KL div metric is implemented on logits." 118 | assert ( 119 | "file_path_" in featuresdict_1 and "file_path_" in featuresdict_2 120 | ), "File paths are missing" 121 | assert len(featuresdict_1) >= len( 122 | featuresdict_2 123 | ), "There are more samples in input1, than in input2" 124 | assert ( 125 | len(featuresdict_1) % len(featuresdict_2) == 0 126 | ), "Size of input1 is not a multiple of input1 size." 127 | if dataset_name == "vas": 128 | assert ( 129 | classes is not None 130 | ), f"Specify classes if you are using vas dataset. Now `classes` – {classes}" 131 | print( 132 | "KL: when FakesFolder on VAS is used as a dataset, we assume the original labels were sorted", 133 | "to produce the target_ids. E.g. `baby` -> `cls_0`; `cough` -> `cls_1`; `dog` -> `cls_2`.", 134 | ) 135 | 136 | 137 | if __name__ == "__main__": 138 | # p = torch.tensor([0.25, 0.25, 0.25, 0.25]).view(1, 4) 139 | # q = torch.tensor([0.25, 0.25, 0.25, 0.25]).view(1, 4) 140 | # 0. 141 | 142 | p = torch.tensor([0.5, 0.6, 0.7]).view(3, 1) 143 | p_ = 1 - p 144 | p = torch.cat([p, p_], dim=1).view(-1, 2) 145 | print(p) 146 | q = torch.tensor([0.5, 0.6, 0.7]).view(3, 1) 147 | q_ = 1 - q 148 | q = torch.cat([q, q_], dim=1).view(-1, 2) 149 | print(q.shape) 150 | kl = torch.nn.functional.kl_div(torch.log(q), p, reduction="sum") 151 | # 0.0853 152 | 153 | print(kl) 154 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/ndb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.cluster import KMeans 4 | from scipy.stats import norm 5 | from matplotlib import pyplot as plt 6 | import pickle as pkl 7 | 8 | 9 | class NDB: 10 | def __init__( 11 | self, 12 | training_data=None, 13 | number_of_bins=100, 14 | significance_level=0.05, 15 | z_threshold=None, 16 | whitening=False, 17 | max_dims=None, 18 | cache_folder=None, 19 | ): 20 | """ 21 | NDB Evaluation Class 22 | :param training_data: Optional - the training samples - array of m x d floats (m samples of dimension d) 23 | :param number_of_bins: Number of bins (clusters) default=100 24 | :param significance_level: The statistical significance level for the two-sample test 25 | :param z_threshold: Allow defining a threshold in terms of difference/SE for defining a bin as statistically different 26 | :param whitening: Perform data whitening - subtract mean and divide by per-dimension std 27 | :param max_dims: Max dimensions to use in K-means. 
By default derived automatically from d 28 | :param bins_file: Optional - file to write / read-from the clusters (to avoid re-calculation) 29 | """ 30 | self.number_of_bins = number_of_bins 31 | self.significance_level = significance_level 32 | self.z_threshold = z_threshold 33 | self.whitening = whitening 34 | self.ndb_eps = 1e-6 35 | self.training_mean = 0.0 36 | self.training_std = 1.0 37 | self.max_dims = max_dims 38 | self.cache_folder = cache_folder 39 | self.bin_centers = None 40 | self.bin_proportions = None 41 | self.ref_sample_size = None 42 | self.used_d_indices = None 43 | self.results_file = None 44 | self.test_name = "ndb_{}_bins_{}".format( 45 | self.number_of_bins, "whiten" if self.whitening else "orig" 46 | ) 47 | self.cached_results = {} 48 | if self.cache_folder: 49 | self.results_file = os.path.join( 50 | cache_folder, self.test_name + "_results.pkl" 51 | ) 52 | if os.path.isfile(self.results_file): 53 | # print('Loading previous results from', self.results_file, ':') 54 | self.cached_results = pkl.load(open(self.results_file, "rb")) 55 | # print(self.cached_results.keys()) 56 | if training_data is not None or cache_folder is not None: 57 | bins_file = None 58 | if cache_folder: 59 | os.makedirs(cache_folder, exist_ok=True) 60 | bins_file = os.path.join(cache_folder, self.test_name + ".pkl") 61 | self.construct_bins(training_data, bins_file) 62 | 63 | def construct_bins(self, training_samples, bins_file): 64 | """ 65 | Performs K-means clustering of the training samples 66 | :param training_samples: An array of m x d floats (m samples of dimension d) 67 | """ 68 | 69 | # if self.__read_from_bins_file(bins_file): 70 | # return 71 | n, d = training_samples.shape 72 | k = self.number_of_bins 73 | # print("k is",k) 74 | if self.whitening: 75 | self.training_mean = np.mean(training_samples, axis=0) 76 | self.training_std = np.std(training_samples, axis=0) + self.ndb_eps 77 | 78 | if self.max_dims is None and d > 1000: 79 | # To ran faster, perform binning on sampled data dimension (i.e. don't use all channels of all pixels) 80 | self.max_dims = d // 6 81 | 82 | whitened_samples = (training_samples - self.training_mean) / self.training_std 83 | d_used = d if self.max_dims is None else min(d, self.max_dims) 84 | self.used_d_indices = np.random.choice(d, d_used, replace=False) 85 | 86 | # print('Performing K-Means clustering of {} samples in dimension {} / {} to {} clusters ...'.format(n, d_used, d, k)) 87 | # print('Can take a couple of minutes...') 88 | if n // k > 1000: 89 | print( 90 | "Training data size should be ~500 times the number of bins (for reasonable speed and accuracy)" 91 | ) 92 | 93 | clusters = KMeans(n_clusters=k, max_iter=100).fit( 94 | whitened_samples[:, self.used_d_indices] 95 | ) 96 | 97 | bin_centers = np.zeros([k, d]) 98 | for i in range(k): 99 | bin_centers[i, :] = np.mean( 100 | whitened_samples[clusters.labels_ == i, :], axis=0 101 | ) 102 | 103 | # Organize bins by size 104 | label_vals, label_counts = np.unique(clusters.labels_, return_counts=True) 105 | bin_order = np.argsort(-label_counts) 106 | self.bin_proportions = label_counts[bin_order] / np.sum(label_counts) 107 | self.bin_centers = bin_centers[bin_order, :] 108 | self.ref_sample_size = n 109 | self.__write_to_bins_file(bins_file) 110 | # print('Done.') 111 | 112 | def evaluate(self, query_samples, model_label=None): 113 | """ 114 | Assign each sample to the nearest bin center (in L2). Pre-whiten if required. 
and calculate the NDB 115 | (Number of statistically Different Bins) and JS divergence scores. 116 | :param query_samples: An array of m x d floats (m samples of dimension d) 117 | :param model_label: optional label string for the evaluated model, allows plotting results of multiple models 118 | :return: results dictionary containing NDB and JS scores and array of labels (assigned bin for each query sample) 119 | """ 120 | n = query_samples.shape[0] 121 | query_bin_proportions, query_bin_assignments = self.__calculate_bin_proportions( 122 | query_samples 123 | ) 124 | # print("query",query_bin_proportions) 125 | # print(query_bin_proportions) 126 | # print("self",self.bin_proportions) 127 | different_bins = NDB.two_proportions_z_test( 128 | self.bin_proportions, 129 | self.ref_sample_size, 130 | query_bin_proportions, 131 | n, 132 | significance_level=self.significance_level, 133 | z_threshold=self.z_threshold, 134 | ) 135 | # print("different",different_bins) 136 | ndb = np.count_nonzero(different_bins) 137 | print("ndb", ndb) 138 | js = NDB.jensen_shannon_divergence(self.bin_proportions, query_bin_proportions) 139 | results = { 140 | "NDB": ndb, 141 | "JS": js, 142 | "Proportions": query_bin_proportions, 143 | "N": n, 144 | "Bin-Assignment": query_bin_assignments, 145 | "Different-Bins": different_bins, 146 | } 147 | 148 | if model_label: 149 | print("Results for {} samples from {}: ".format(n, model_label), end="") 150 | self.cached_results[model_label] = results 151 | if self.results_file: 152 | # print('Storing result to', self.results_file) 153 | pkl.dump(self.cached_results, open(self.results_file, "wb")) 154 | 155 | print("NDB =", ndb, "NDB/K =", ndb / self.number_of_bins, ", JS =", js) 156 | return results 157 | 158 | def print_results(self): 159 | print( 160 | "NSB results (K={}{}):".format( 161 | self.number_of_bins, ", data whitening" if self.whitening else "" 162 | ) 163 | ) 164 | for model in sorted(list(self.cached_results.keys())): 165 | res = self.cached_results[model] 166 | print( 167 | "%s: NDB = %d, NDB/K = %.3f, JS = %.4f" 168 | % (model, res["NDB"], res["NDB"] / self.number_of_bins, res["JS"]) 169 | ) 170 | 171 | def plot_results(self, models_to_plot=None): 172 | """ 173 | Plot the binning proportions of different methods 174 | :param models_to_plot: optional list of model labels to plot 175 | """ 176 | K = self.number_of_bins 177 | w = 1.0 / (len(self.cached_results) + 1) 178 | assert K == self.bin_proportions.size 179 | assert self.cached_results 180 | 181 | # Used for plotting only 182 | def calc_se(p1, n1, p2, n2): 183 | p = (p1 * n1 + p2 * n2) / (n1 + n2) 184 | return np.sqrt(p * (1 - p) * (1 / n1 + 1 / n2)) 185 | 186 | if not models_to_plot: 187 | models_to_plot = sorted(list(self.cached_results.keys())) 188 | 189 | # Visualize the standard errors using the train proportions and size and query sample size 190 | train_se = calc_se( 191 | self.bin_proportions, 192 | self.ref_sample_size, 193 | self.bin_proportions, 194 | self.cached_results[models_to_plot[0]]["N"], 195 | ) 196 | plt.bar( 197 | np.arange(0, K) + 0.5, 198 | height=train_se * 2.0, 199 | bottom=self.bin_proportions - train_se, 200 | width=1.0, 201 | label="Train$\pm$SE", 202 | color="gray", 203 | ) 204 | 205 | ymax = 0.0 206 | for i, model in enumerate(models_to_plot): 207 | results = self.cached_results[model] 208 | label = "%s (%i : %.4f)" % (model, results["NDB"], results["JS"]) 209 | ymax = max(ymax, np.max(results["Proportions"])) 210 | if K <= 70: 211 | plt.bar( 212 | np.arange(0, K) + (i + 1.0) 
* w, 213 | results["Proportions"], 214 | width=w, 215 | label=label, 216 | ) 217 | else: 218 | plt.plot( 219 | np.arange(0, K) + 0.5, results["Proportions"], "--*", label=label 220 | ) 221 | plt.legend(loc="best") 222 | plt.ylim((0.0, min(ymax, np.max(self.bin_proportions) * 4.0))) 223 | plt.grid(True) 224 | plt.title( 225 | "Binning Proportions Evaluation Results for {} bins (NDB : JS)".format(K) 226 | ) 227 | plt.show() 228 | 229 | def __calculate_bin_proportions(self, samples): 230 | if self.bin_centers is None: 231 | print( 232 | "First run construct_bins on samples from the reference training data" 233 | ) 234 | # print("as1",samples.shape[1]) 235 | # print("as2",self.bin_centers.shape[1]) 236 | assert samples.shape[1] == self.bin_centers.shape[1] 237 | n, d = samples.shape 238 | k = self.bin_centers.shape[0] 239 | D = np.zeros([n, k], dtype=samples.dtype) 240 | 241 | # print('Calculating bin assignments for {} samples...'.format(n)) 242 | whitened_samples = (samples - self.training_mean) / self.training_std 243 | for i in range(k): 244 | print(".", end="", flush=True) 245 | D[:, i] = np.linalg.norm( 246 | whitened_samples[:, self.used_d_indices] 247 | - self.bin_centers[i, self.used_d_indices], 248 | ord=2, 249 | axis=1, 250 | ) 251 | print() 252 | labels = np.argmin(D, axis=1) 253 | probs = np.zeros([k]) 254 | label_vals, label_counts = np.unique(labels, return_counts=True) 255 | probs[label_vals] = label_counts / n 256 | return probs, labels 257 | 258 | def __read_from_bins_file(self, bins_file): 259 | if bins_file and os.path.isfile(bins_file): 260 | print("Loading binning results from", bins_file) 261 | bins_data = pkl.load(open(bins_file, "rb")) 262 | self.bin_proportions = bins_data["proportions"] 263 | self.bin_centers = bins_data["centers"] 264 | self.ref_sample_size = bins_data["n"] 265 | self.training_mean = bins_data["mean"] 266 | self.training_std = bins_data["std"] 267 | self.used_d_indices = bins_data["d_indices"] 268 | return True 269 | return False 270 | 271 | def __write_to_bins_file(self, bins_file): 272 | if bins_file: 273 | print("Caching binning results to", bins_file) 274 | bins_data = { 275 | "proportions": self.bin_proportions, 276 | "centers": self.bin_centers, 277 | "n": self.ref_sample_size, 278 | "mean": self.training_mean, 279 | "std": self.training_std, 280 | "d_indices": self.used_d_indices, 281 | } 282 | pkl.dump(bins_data, open(bins_file, "wb")) 283 | 284 | @staticmethod 285 | def two_proportions_z_test(p1, n1, p2, n2, significance_level, z_threshold=None): 286 | # Per http://stattrek.com/hypothesis-test/difference-in-proportions.aspx 287 | # See also http://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/binotest.htm 288 | p = (p1 * n1 + p2 * n2) / (n1 + n2) 289 | se = np.sqrt(p * (1 - p) * (1 / n1 + 1 / n2)) 290 | z = (p1 - p2) / se 291 | # print("z",abs(z)) 292 | # Allow defining a threshold in terms as Z (difference relative to the SE) rather than in p-values. 293 | if z_threshold is not None: 294 | return abs(z) > z_threshold 295 | p_values = 2.0 * norm.cdf(-1.0 * np.abs(z)) # Two-tailed test 296 | return p_values < significance_level 297 | 298 | @staticmethod 299 | def jensen_shannon_divergence(p, q): 300 | """ 301 | Calculates the symmetric Jensen–Shannon divergence between the two PDFs 302 | """ 303 | m = (p + q) * 0.5 304 | return 0.5 * (NDB.kl_divergence(p, m) + NDB.kl_divergence(q, m)) 305 | 306 | @staticmethod 307 | def kl_divergence(p, q): 308 | """ 309 | The Kullback–Leibler divergence. 
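KL(p || q) = sum(p * log(p / q)), taken over the support of p.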
310 | Defined only if q != 0 whenever p != 0. 311 | """ 312 | assert np.all(np.isfinite(p)) 313 | assert np.all(np.isfinite(q)) 314 | assert not np.any(np.logical_and(p != 0, q == 0)) 315 | 316 | p_pos = p > 0 317 | return np.sum(p[p_pos] * np.log(p[p_pos] / q[p_pos])) 318 | 319 | 320 | if __name__ == "__main__": 321 | dim = 100 322 | k = 100 323 | n_train = k * 100 324 | n_test = k * 10 325 | 326 | train_samples = np.random.uniform(size=[n_train, dim]) 327 | ndb = NDB(training_data=train_samples, number_of_bins=k, whitening=True) 328 | 329 | test_samples = np.random.uniform(high=1.0, size=[n_test, dim]) 330 | ndb.evaluate(test_samples, model_label="Test") 331 | 332 | test_samples = np.random.uniform(high=0.9, size=[n_test, dim]) 333 | ndb.evaluate(test_samples, model_label="Good") 334 | 335 | test_samples = np.random.uniform(high=0.75, size=[n_test, dim]) 336 | ndb.evaluate(test_samples, model_label="Bad") 337 | 338 | ndb.plot_results(models_to_plot=["Test", "Good", "Bad"]) 339 | -------------------------------------------------------------------------------- /audioldm_eval/metrics/validate.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy import cov 3 | from numpy import trace 4 | from numpy import iscomplexobj 5 | from numpy.random import random 6 | from scipy.linalg import sqrtm 7 | 8 | 9 | def calculate_fid(act1, act2): 10 | # calculate mean and covariance statistics 11 | mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False) 12 | mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False) 13 | print("mu1 ", mu1.shape) 14 | print("mu2 ", mu2.shape) 15 | print("sigma1 ", sigma1.shape) 16 | print("sigma2 ", sigma2.shape) 17 | # calculate sum squared difference between means 18 | ssdiff = numpy.sum((mu1 - mu2) * 2.0) 19 | 20 | # calculate sqrt of product between cov 21 | covmean = sqrtm(sigma1.dot(sigma2)) 22 | 23 | # check and correct imaginary numbers from sqrt 24 | if iscomplexobj(covmean): 25 | covmean = covmean.real 26 | # calculate score 27 | fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean) 28 | return fid 29 | 30 | 31 | act1 = random(2048 * 2) 32 | act1 = act1.reshape((2, 2048)) 33 | act2 = random(2048 * 2) 34 | act2 = act2.reshape((2, 2048)) 35 | fid = calculate_fid(act1, act1) 36 | print("FID (same): %.3f" % fid) 37 | fid = calculate_fid(act1, act2) 38 | print("FID (different): %.3f" % fid) 39 | -------------------------------------------------------------------------------- /gen_test_file.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
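# Running this script writes sine-wave test sets to <test_files>/example/{reference,paired,unpaired} (see main below).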
15 | 
16 | """Creates a set of audio files to test FAD calculation."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import errno
23 | import os
24 | 
25 | from absl import app
26 | from absl import flags
27 | 
28 | import numpy as np
29 | import scipy.io.wavfile
30 | 
31 | _SAMPLE_RATE = 16000
32 | 
33 | FLAGS = flags.FLAGS
34 | flags.DEFINE_string(
35 |     "test_files", "", "Directory where the test files should be located"
36 | )
37 | 
38 | 
39 | def create_dir(output_dir):
40 |     """Ignore directory exists error."""
41 |     try:
42 |         os.makedirs(output_dir)
43 |     except OSError as exception:
44 |         if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
45 |             pass
46 |         else:
47 |             raise
48 | 
49 | 
50 | def add_noise(data, stddev):
51 |     """Adds Gaussian noise to the samples.
52 | 
53 |     Args:
54 |       data: 1d Numpy array containing floating point samples. Not necessarily
55 |         normalized.
56 |       stddev: The standard deviation of the added noise.
57 | 
58 |     Returns:
59 |       1d Numpy array containing the provided floating point samples with added
60 |       Gaussian noise.
61 | 
62 |     Raises:
63 |       ValueError: When data is not a 1d numpy array.
64 |     """
65 |     if len(data.shape) != 1:
66 |         raise ValueError("expected 1d numpy array.")
67 |     max_value = np.amax(np.abs(data))
68 |     num_samples = data.shape[0]
69 |     gauss = np.random.normal(0, stddev, (num_samples)) * max_value
70 |     return data + gauss
71 | 
72 | 
73 | def gen_sine_wave(freq=600, length_seconds=6, sample_rate=_SAMPLE_RATE, param=None):
74 |     """Creates a sine wave of the specified frequency, sample_rate and length."""
75 |     t = np.linspace(0, length_seconds, int(length_seconds * sample_rate))
76 |     samples = np.sin(2 * np.pi * t * freq)
77 |     if param:
78 |         samples = add_noise(samples, param)
79 |     return np.asarray(2**15 * samples, dtype=np.int16)
80 | 
81 | 
82 | def main(argv):
83 |     del argv  # Unused.
84 |     for target, count, param in [
85 |         ("reference", 50, 0.0),
86 |         ("paired", 50, 0.001),
87 |         ("unpaired", 25, 0.001),
88 |     ]:
89 |         output_dir = os.path.join(FLAGS.test_files, "example", target)
90 |         create_dir(output_dir)
91 |         print("output_dir:", output_dir)
92 |         frequencies = np.linspace(100, 1000, count).tolist()
93 |         for freq in frequencies:
94 |             samples = gen_sine_wave(freq, param=param)
95 |             filename = os.path.join(output_dir, "sin_%.0f.wav" % freq)
96 |             print("Creating: %s with %i samples." % (filename, samples.shape[0]))
97 |             scipy.io.wavfile.write(filename, _SAMPLE_RATE, samples)
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     os.makedirs("example", exist_ok=True)
102 |     app.run(main)
103 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # python3 setup.py sdist bdist_wheel
4 | """
5 | @File    :   setup.py
6 | @Contact :   haoheliu@gmail.com
7 | @License :   (C)Copyright 2020-2100
8 | 
9 | @Modify Time      @Author    @Version    @Description
10 | ------------      -------    --------    -----------
11 | 9/6/21 5:16 PM    Haohe Liu  1.0         None
12 | """
13 | 
14 | # !/usr/bin/env python
15 | # -*- coding: utf-8 -*-
16 | 
17 | # Note: To use the 'upload' functionality of this file, you must:
18 | #   $ pipenv install twine --dev
19 | 
20 | import io
21 | import os
22 | import sys
23 | from shutil import rmtree
24 | 
25 | from setuptools import find_packages, setup, Command
26 | 
27 | # Package meta-data.
28 | NAME = "audioldm_eval"
29 | DESCRIPTION = "This package is written for the evaluation of audio generation models."
30 | URL = "https://github.com/haoheliu/audioldm_eval"
31 | EMAIL = "haoheliu@gmail.com"
32 | AUTHOR = "Haohe Liu"
33 | REQUIRES_PYTHON = ">=3.6.0"
34 | VERSION = "0.0.5"
35 | 
36 | # What packages are required for this module to be executed?
37 | REQUIRED = [
38 |     "torch>=1.11.0",
39 |     "torchaudio",
40 |     "transformers",
41 |     "scikit-image",
42 |     "torchlibrosa",
43 |     "absl-py",
44 |     "scipy",
45 |     "tqdm",
46 |     "ssr_eval",
47 |     "librosa",
48 | ]
49 | 
50 | # What packages are optional?
51 | EXTRAS = {}
52 | 
53 | # The rest you shouldn't have to touch too much :)
54 | # ------------------------------------------------
55 | # Except, perhaps the License and Trove Classifiers!
56 | # If you do change the License, remember to change the Trove Classifier for that!
57 | 
58 | here = os.path.abspath(os.path.dirname(__file__))
59 | 
60 | # Import the README and use it as the long-description.
61 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
62 | try:
63 |     with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
64 |         long_description = "\n" + f.read()
65 | except FileNotFoundError:
66 |     long_description = DESCRIPTION
67 | 
68 | # Load the package's __version__.py module as a dictionary.
69 | about = {}
70 | if not VERSION:
71 |     project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
72 |     with open(os.path.join(here, project_slug, "__version__.py")) as f:
73 |         exec(f.read(), about)
74 | else:
75 |     about["__version__"] = VERSION
76 | 
77 | 
78 | class UploadCommand(Command):
79 |     """Support setup.py upload."""
80 | 
81 |     description = "Build and publish the package."
82 |     user_options = []
83 | 
84 |     @staticmethod
85 |     def status(s):
86 |         """Prints things in bold."""
87 |         print("\033[1m{0}\033[0m".format(s))
88 | 
89 |     def initialize_options(self):
90 |         pass
91 | 
92 |     def finalize_options(self):
93 |         pass
94 | 
95 |     def run(self):
96 |         try:
97 |             self.status("Removing previous builds…")
98 |             rmtree(os.path.join(here, "dist"))
99 |         except OSError:
100 |             pass
101 | 
102 |         self.status("Building Source and Wheel (universal) distribution…")
103 |         os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
104 | 
105 |         self.status("Uploading the package to PyPI via Twine…")
106 |         os.system("twine upload dist/*")
107 | 
108 |         self.status("Pushing git tags…")
109 |         os.system("git tag v{0}".format(about["__version__"]))
110 |         os.system("git push --tags")
111 | 
112 |         sys.exit()
113 | 
114 | 
115 | # Where the magic happens:
116 | setup(
117 |     name=NAME,
118 |     version=about["__version__"],
119 |     description=DESCRIPTION,
120 |     long_description=long_description,
121 |     long_description_content_type="text/markdown",
122 |     author=AUTHOR,
123 |     author_email=EMAIL,
124 |     python_requires=REQUIRES_PYTHON,
125 |     url=URL,
126 |     # packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
127 |     # If your package is a single module, use this instead of 'packages':
128 |     # py_modules=["torchsubband"],  # not used here; this package is distributed via find_packages() below
129 |     # entry_points={
130 |     #     'console_scripts': ['mycli=mymodule:cli'],
131 |     # },
132 |     install_requires=REQUIRED,
133 |     extras_require=EXTRAS,
134 |     packages=find_packages(),
135 |     include_package_data=True,
136 |     license="MIT",
137 |     classifiers=[
138 |         # Trove classifiers
139 |         # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
140 |         "License :: OSI Approved :: MIT License",
141 |         "Programming Language :: Python",
142 |         "Programming Language :: Python :: 3",
143 |         "Programming Language :: Python :: 3.7",
144 |         "Programming Language :: Python :: Implementation :: CPython",
145 |         "Programming Language :: Python :: Implementation :: PyPy",
146 |     ],
147 |     # $ setup.py publish support.
148 |     cmdclass={
149 |         "upload": UploadCommand,
150 |     },
151 | )
152 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from audioldm_eval import EvaluationHelper, EvaluationHelperParallel
3 | import torch.multiprocessing as mp
4 | 
5 | device = torch.device("cuda:0")
6 | 
7 | generation_result_path = "example/paired"
8 | # generation_result_path = "example/unpaired"
9 | target_audio_path = "example/reference"
10 | 
11 | ## Single GPU
12 | 
13 | evaluator = EvaluationHelper(16000, device)
14 | 
15 | # Perform evaluation; the results are printed and saved as a json file
16 | metrics = evaluator.main(
17 |     generation_result_path,
18 |     target_audio_path,
19 | )
20 | 
21 | ## Multiple GPUs
22 | 
23 | if __name__ == '__main__':
24 |     evaluator = EvaluationHelperParallel(16000, 2)
25 |     metrics = evaluator.main(
26 |         generation_result_path,
27 |         target_audio_path,
28 |     )
--------------------------------------------------------------------------------