├── tips
│   └── gammatone.png
├── setup.py
├── nnAudio2
│   ├── librosa_LICENSE.md
│   ├── augmentation.py
│   ├── librosa_filters.py
│   ├── __init__.py
│   └── Spectrogram.py
├── LICENSE
└── README.md

/tips/gammatone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangHelin1997/nnAudio2/HEAD/tips/gammatone.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("README.md", "r") as fh:
4 |     long_description = fh.read()
5 | 
6 | setuptools.setup(
7 |     name="nnAudio2",
8 |     version="0.0.1",
9 |     author="Helin Wang",
10 |     author_email="wanghl15@pku.edu.cn",
11 |     description="PyTorch implementation of audio processing functions based on nnAudio.",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/WangHelin1997/nnAudio2",
15 |     packages=setuptools.find_packages(),
16 |     classifiers=[
17 |         "Programming Language :: Python :: 3",
18 |         "License :: OSI Approved :: MIT License",
19 |         "Operating System :: OS Independent",
20 |     ],
21 |     python_requires='>=3.6',
22 | )
--------------------------------------------------------------------------------
/nnAudio2/librosa_LICENSE.md:
--------------------------------------------------------------------------------
1 | ISC License
2 | 
3 | Copyright (c) 2013--2017, librosa development team.
4 | 
5 | Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 The Python Packaging Authority
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 | 
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 | 
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nnAudio2
2 | Audio processing using PyTorch 1D convolutional networks (based on nnAudio).
3 | 
4 | Gammatone Spectrogram and SpecAugmentation are now available on GPU.
5 | 
6 | # Notice
7 | This code has been merged via pull request into https://github.com/KinWaiCheuk/nnAudio.
8 | 
9 | # Install
10 | ```
11 | $ pip install -i https://test.pypi.org/simple/ nnAudio2
12 | ```
13 | 
14 | # Dependencies
15 | Numpy 1.14.5
16 | 
17 | Scipy 1.2.0
18 | 
19 | PyTorch 1.1.0
20 | 
21 | Python >= 3.6
22 | 
23 | librosa = 0.7.0 (Theoretically nnAudio depends on librosa, but it only needs the single function `mel` from `librosa.filters`. To save users the trouble of installing librosa for one function, the corresponding code is copied into this package so that nnAudio runs without librosa installed.)
24 | 
25 | # Quick Start
26 | Gammatone Spectrogram is now available in nnAudio2; further details can be found at https://github.com/KinWaiCheuk/nnAudio.
27 | 
28 | In addition, more audio augmentation methods are available in nnAudio2, following https://github.com/qiuqiangkong/torchlibrosa.
29 | 
30 | ### 1. STFT
31 | ```python
32 | from nnAudio2 import Spectrogram
33 | Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50, fmax=6000, sr=22050, trainable=False)
34 | ```
35 | 
36 | ### 2. Mel Spectrogram
37 | ```python
38 | Spectrogram.MelSpectrogram(sr=22050, n_fft=2048, n_mels=128, hop_length=512, window='hann', center=True, pad_mode='reflect', htk=False, fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False)
39 | ```
40 | 
41 | ### 3. CQT
42 | ```python
43 | Spectrogram.CQT(sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=1, window='hann', center=True, pad_mode='reflect')
44 | ```
45 | 
46 | ### 4. Gammatone Spectrogram
47 | ```python
48 | Spectrogram.Gammatonegram(sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True, pad_mode='reflect', htk=False, fmin=50.0, fmax=None, norm=1, trainable_bins=False, trainable_STFT=False)
49 | ```
50 | 
51 | #### The Gammatone filters generated by nnAudio2
52 | ![alt text](https://github.com/WangHelin1997/nnAudio2/blob/master/tips/gammatone.png)
53 | 
54 | # References
55 | 1. https://github.com/KinWaiCheuk/nnAudio
56 | 2. https://github.com/qiuqiangkong/torchlibrosa
57 | 3. https://github.com/mcusi/gammatonegram
--------------------------------------------------------------------------------
/nnAudio2/augmentation.py:
--------------------------------------------------------------------------------
1 | # Reference : https://github.com/qiuqiangkong/torchlibrosa
2 | 
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class DropStripes(nn.Module):
12 |     def __init__(self, dim, drop_width, stripes_num):
13 |         """Drop stripes.
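        SpecAugment-style masking: randomly sets `stripes_num` stripes of width
        up to `drop_width` along dimension `dim` to zero; in eval mode the
        input passes through unchanged.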
14 |         Args:
15 |           dim: int, dimension along which to drop
16 |           drop_width: int, maximum width of stripes to drop
17 |           stripes_num: int, how many stripes to drop
18 |         """
19 |         super(DropStripes, self).__init__()
20 | 
21 |         assert dim in [2, 3]    # dim 2: time; dim 3: frequency
22 | 
23 |         self.dim = dim
24 |         self.drop_width = drop_width
25 |         self.stripes_num = stripes_num
26 | 
27 |     def forward(self, input):
28 |         """input: (batch_size, channels, time_steps, freq_bins)"""
29 | 
30 |         assert input.ndimension() == 4
31 | 
32 |         if self.training is False:
33 |             return input
34 | 
35 |         else:
36 |             batch_size = input.shape[0]
37 |             total_width = input.shape[self.dim]
38 | 
39 |             for n in range(batch_size):
40 |                 self.transform_slice(input[n], total_width)
41 | 
42 |             return input
43 | 
44 | 
45 |     def transform_slice(self, e, total_width):
46 |         """e: (channels, time_steps, freq_bins)"""
47 | 
48 |         for _ in range(self.stripes_num):
49 |             distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
50 |             bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
51 | 
52 |             if self.dim == 2:
53 |                 e[:, bgn : bgn + distance, :] = 0
54 |             elif self.dim == 3:
55 |                 e[:, :, bgn : bgn + distance] = 0
56 | 
57 | 
58 | class SpecAugmentation(nn.Module):
59 |     def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, 
60 |         freq_stripes_num):
61 |         """Spec augmentation.
62 |         [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
63 |         and Le, Q.V., 2019. SpecAugment: A simple data augmentation method
64 |         for automatic speech recognition. arXiv preprint arXiv:1904.08779.
65 |         Args:
66 |           time_drop_width: int
67 |           time_stripes_num: int
68 |           freq_drop_width: int
69 |           freq_stripes_num: int
70 |         """
71 | 
72 |         super(SpecAugmentation, self).__init__()
73 | 
74 |         self.time_dropper = DropStripes(dim=2, drop_width=time_drop_width, 
75 |             stripes_num=time_stripes_num)
76 | 
77 |         self.freq_dropper = DropStripes(dim=3, drop_width=freq_drop_width, 
78 |             stripes_num=freq_stripes_num)
79 | 
80 |     def forward(self, input):
81 |         x = self.time_dropper(input)
82 |         x = self.freq_dropper(x)
83 |         return x
84 | 
85 | 
86 | if __name__ == '__main__':
87 | 
88 |     torch.manual_seed(0)
89 |     random_state = np.random.RandomState(0)
90 |     np_data = random_state.normal(size=(10, 4, 640, 64))
91 |     pt_data = torch.Tensor(np_data)
92 | 
93 |     spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
94 |         freq_drop_width=16, freq_stripes_num=2)
95 | 
96 |     # Training stage
97 |     spec_augmenter.train()    # set to spec_augmenter.eval() for evaluation
98 |     result = spec_augmenter(pt_data)
99 | 
100 |     print(result.shape)
--------------------------------------------------------------------------------
/nnAudio2/librosa_filters.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import warnings
3 | ### ----------------Functions for generating kernels for Mel Spectrogram------------ ###
4 | # This code is equivalent to `from librosa.filters import mel`.
5 | # By doing so, we can run nnAudio without installing librosa
6 | def fft2gammatonemx(sr=20000, n_fft=2048, n_bins=64, width=1.0, fmin=0.0,
7 |                     fmax=11025, maxlen=1024):
8 |     """
9 |     # Ellis' description in MATLAB:
10 |     # [wts,cfreqa] = fft2gammatonemx(nfft, sr, nfilts, width, minfreq, maxfreq, maxlen)
11 |     #      Generate a matrix of weights to combine FFT bins into
12 |     #      Gammatone bins.  nfft defines the source FFT size at
13 |     #      sampling rate sr.
Optional nfilts specifies the number of 14 | # output bands required (default 64), and width is the 15 | # constant width of each band in Bark (default 1). 16 | # minfreq, maxfreq specify range covered in Hz (100, sr/2). 17 | # While wts has nfft columns, the second half are all zero. 18 | # Hence, aud spectrum is 19 | # fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft)); 20 | # maxlen truncates the rows to this many bins. 21 | # cfreqs returns the actual center frequencies of each 22 | # gammatone band in Hz. 23 | # 24 | # 2009/02/22 02:29:25 Dan Ellis dpwe@ee.columbia.edu based on rastamat/audspec.m 25 | # Sat May 27 15:37:50 2017 Maddie Cusimano, mcusi@mit.edu 27 May 2017: convert to python 26 | """ 27 | 28 | wts = np.zeros([n_bins,n_fft],dtype=np.float32) 29 | 30 | #after Slaney's MakeERBFilters 31 | EarQ = 9.26449; minBW = 24.7; order = 1; 32 | 33 | nFr = np.array(range(n_bins)) + 1 34 | em = EarQ*minBW 35 | cfreqs = (fmax+em)*np.exp(nFr*(-np.log(fmax + em)+np.log(fmin + em))/n_bins)-em 36 | cfreqs = cfreqs[::-1] 37 | 38 | GTord = 4 39 | ucircArray = np.array(range(int(n_fft/2 + 1))) 40 | ucirc = np.exp(1j*2*np.pi*ucircArray/n_fft); 41 | #justpoles = 0 :taking out the 'if' corresponding to this. 42 | 43 | ERB = width*np.power(np.power(cfreqs/EarQ,order) + np.power(minBW,order),1/order); 44 | B = 1.019 * 2 * np.pi * ERB; 45 | r = np.exp(-B/sr) 46 | theta = 2*np.pi*cfreqs/sr 47 | pole = r*np.exp(1j*theta) 48 | T = 1/sr 49 | ebt = np.exp(B*T); cpt = 2*cfreqs*np.pi*T; 50 | ccpt = 2*T*np.cos(cpt); scpt = 2*T*np.sin(cpt); 51 | A11 = -np.divide(np.divide(ccpt,ebt) + np.divide(np.sqrt(3+2**1.5)*scpt,ebt),2); 52 | A12 = -np.divide(np.divide(ccpt,ebt) - np.divide(np.sqrt(3+2**1.5)*scpt,ebt),2); 53 | A13 = -np.divide(np.divide(ccpt,ebt) + np.divide(np.sqrt(3-2**1.5)*scpt,ebt),2); 54 | A14 = -np.divide(np.divide(ccpt,ebt) - np.divide(np.sqrt(3-2**1.5)*scpt,ebt),2); 55 | zros = -np.array([A11, A12, A13, A14])/T; 56 | wIdx = range(int(n_fft/2 + 1)) 57 | gain = np.abs((-2*np.exp(4*1j*cfreqs*np.pi*T)*T + 2*np.exp(-(B*T) + 2*1j*cfreqs*np.pi*T)*T* (np.cos(2*cfreqs*np.pi*T) - np.sqrt(3 - 2**(3/2))* np.sin(2*cfreqs*np.pi*T))) *(-2*np.exp(4*1j*cfreqs*np.pi*T)*T + 2*np.exp(-(B*T) + 2*1j*cfreqs*np.pi*T)*T* (np.cos(2*cfreqs*np.pi*T) + np.sqrt(3 - 2**(3/2)) * np.sin(2*cfreqs*np.pi*T)))*(-2*np.exp(4*1j*cfreqs*np.pi*T)*T + 2*np.exp(-(B*T) + 2*1j*cfreqs*np.pi*T)*T* (np.cos(2*cfreqs*np.pi*T) - np.sqrt(3 + 2**(3/2))*np.sin(2*cfreqs*np.pi*T))) *(-2*np.exp(4*1j*cfreqs*np.pi*T)*T + 2*np.exp(-(B*T) + 2*1j*cfreqs*np.pi*T)*T* (np.cos(2*cfreqs*np.pi*T) + np.sqrt(3 + 2**(3/2))*np.sin(2*cfreqs*np.pi*T))) /(-2 / np.exp(2*B*T) - 2*np.exp(4*1j*cfreqs*np.pi*T) + 2*(1 + np.exp(4*1j*cfreqs*np.pi*T))/np.exp(B*T))**4); 58 | #in MATLAB, there used to be 64 where here it says n_bins: 59 | wts[:, wIdx] = ((T**4)/np.reshape(gain,(n_bins,1))) * np.abs(ucirc-np.reshape(zros[0],(n_bins,1)))*np.abs(ucirc-np.reshape(zros[1],(n_bins,1)))*np.abs(ucirc-np.reshape(zros[2],(n_bins,1)))*np.abs(ucirc-np.reshape(zros[3],(n_bins,1)))*(np.abs(np.power(np.multiply(np.reshape(pole,(n_bins,1))-ucirc,np.conj(np.reshape(pole,(n_bins,1)))-ucirc),-GTord))); 60 | wts = wts[:,range(maxlen)]; 61 | 62 | return wts, cfreqs 63 | 64 | def mel_to_hz(mels, htk=False): 65 | """Convert mel bin numbers to frequencies 66 | Examples 67 | -------- 68 | >>> librosa.mel_to_hz(3) 69 | 200. 70 | >>> librosa.mel_to_hz([1,2,3,4,5]) 71 | array([ 66.667, 133.333, 200. 
, 266.667, 333.333]) 72 | Parameters 73 | ---------- 74 | mels : np.ndarray [shape=(n,)], float 75 | mel bins to convert 76 | htk : bool 77 | use HTK formula instead of Slaney 78 | Returns 79 | ------- 80 | frequencies : np.ndarray [shape=(n,)] 81 | input mels in Hz 82 | See Also 83 | -------- 84 | hz_to_mel 85 | """ 86 | 87 | mels = np.asanyarray(mels) 88 | 89 | if htk: 90 | return 700.0 * (10.0**(mels / 2595.0) - 1.0) 91 | 92 | # Fill in the linear scale 93 | f_min = 0.0 94 | f_sp = 200.0 / 3 95 | freqs = f_min + f_sp * mels 96 | 97 | # And now the nonlinear scale 98 | min_log_hz = 1000.0 # beginning of log region (Hz) 99 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) 100 | logstep = np.log(6.4) / 27.0 # step size for log region 101 | 102 | if mels.ndim: 103 | # If we have vector data, vectorize 104 | log_t = (mels >= min_log_mel) 105 | freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) 106 | elif mels >= min_log_mel: 107 | # If we have scalar data, check directly 108 | freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel)) 109 | 110 | return freqs 111 | 112 | def hz_to_mel(frequencies, htk=False): 113 | """Convert Hz to Mels 114 | Examples 115 | -------- 116 | >>> librosa.hz_to_mel(60) 117 | 0.9 118 | >>> librosa.hz_to_mel([110, 220, 440]) 119 | array([ 1.65, 3.3 , 6.6 ]) 120 | Parameters 121 | ---------- 122 | frequencies : number or np.ndarray [shape=(n,)] , float 123 | scalar or array of frequencies 124 | htk : bool 125 | use HTK formula instead of Slaney 126 | Returns 127 | ------- 128 | mels : number or np.ndarray [shape=(n,)] 129 | input frequencies in Mels 130 | See Also 131 | -------- 132 | mel_to_hz 133 | """ 134 | 135 | frequencies = np.asanyarray(frequencies) 136 | 137 | if htk: 138 | return 2595.0 * np.log10(1.0 + frequencies / 700.0) 139 | 140 | # Fill in the linear part 141 | f_min = 0.0 142 | f_sp = 200.0 / 3 143 | 144 | mels = (frequencies - f_min) / f_sp 145 | 146 | # Fill in the log-scale part 147 | 148 | min_log_hz = 1000.0 # beginning of log region (Hz) 149 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) 150 | logstep = np.log(6.4) / 27.0 # step size for log region 151 | 152 | if frequencies.ndim: 153 | # If we have array data, vectorize 154 | log_t = (frequencies >= min_log_hz) 155 | mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep 156 | elif frequencies >= min_log_hz: 157 | # If we have scalar data, heck directly 158 | mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep 159 | 160 | return mels 161 | 162 | def fft_frequencies(sr=22050, n_fft=2048): 163 | '''Alternative implementation of `np.fft.fftfreq` 164 | Parameters 165 | ---------- 166 | sr : number > 0 [scalar] 167 | Audio sampling rate 168 | n_fft : int > 0 [scalar] 169 | FFT window size 170 | Returns 171 | ------- 172 | freqs : np.ndarray [shape=(1 + n_fft/2,)] 173 | Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)` 174 | Examples 175 | -------- 176 | >>> librosa.fft_frequencies(sr=22050, n_fft=16) 177 | array([ 0. , 1378.125, 2756.25 , 4134.375, 178 | 5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ]) 179 | ''' 180 | 181 | return np.linspace(0, 182 | float(sr) / 2, 183 | int(1 + n_fft//2), 184 | endpoint=True) 185 | 186 | def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False): 187 | """Compute an array of acoustic frequencies tuned to the mel scale. 188 | The mel scale is a quasi-logarithmic function of acoustic frequency 189 | designed such that perceptually similar pitch intervals (e.g. 
octaves) 190 | appear equal in width over the full hearing range. 191 | Because the definition of the mel scale is conditioned by a finite number 192 | of subjective psychoaoustical experiments, several implementations coexist 193 | in the audio signal processing literature [1]_. By default, librosa replicates 194 | the behavior of the well-established MATLAB Auditory Toolbox of Slaney [2]_. 195 | According to this default implementation, the conversion from Hertz to mel is 196 | linear below 1 kHz and logarithmic above 1 kHz. Another available implementation 197 | replicates the Hidden Markov Toolkit [3]_ (HTK) according to the following formula: 198 | `mel = 2595.0 * np.log10(1.0 + f / 700.0).` 199 | The choice of implementation is determined by the `htk` keyword argument: setting 200 | `htk=False` leads to the Auditory toolbox implementation, whereas setting it `htk=True` 201 | leads to the HTK implementation. 202 | .. [1] Umesh, S., Cohen, L., & Nelson, D. Fitting the mel scale. 203 | In Proc. International Conference on Acoustics, Speech, and Signal Processing 204 | (ICASSP), vol. 1, pp. 217-220, 1998. 205 | .. [2] Slaney, M. Auditory Toolbox: A MATLAB Toolbox for Auditory 206 | Modeling Work. Technical Report, version 2, Interval Research Corporation, 1998. 207 | .. [3] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., Liu, X., 208 | Moore, G., Odell, J., Ollason, D., Povey, D., Valtchev, V., & Woodland, P. 209 | The HTK book, version 3.4. Cambridge University, March 2009. 210 | See Also 211 | -------- 212 | hz_to_mel 213 | mel_to_hz 214 | librosa.feature.melspectrogram 215 | librosa.feature.mfcc 216 | Parameters 217 | ---------- 218 | n_mels : int > 0 [scalar] 219 | Number of mel bins. 220 | fmin : float >= 0 [scalar] 221 | Minimum frequency (Hz). 222 | fmax : float >= 0 [scalar] 223 | Maximum frequency (Hz). 224 | htk : bool 225 | If True, use HTK formula to convert Hz to mel. 226 | Otherwise (False), use Slaney's Auditory Toolbox. 227 | Returns 228 | ------- 229 | bin_frequencies : ndarray [shape=(n_mels,)] 230 | Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel 231 | axis. 232 | Examples 233 | -------- 234 | >>> librosa.mel_frequencies(n_mels=40) 235 | array([ 0. , 85.317, 170.635, 255.952, 236 | 341.269, 426.586, 511.904, 597.221, 237 | 682.538, 767.855, 853.173, 938.49 , 238 | 1024.856, 1119.114, 1222.042, 1334.436, 239 | 1457.167, 1591.187, 1737.532, 1897.337, 240 | 2071.84 , 2262.393, 2470.47 , 2697.686, 241 | 2945.799, 3216.731, 3512.582, 3835.643, 242 | 4188.417, 4573.636, 4994.285, 5453.621, 243 | 5955.205, 6502.92 , 7101.009, 7754.107, 244 | 8467.272, 9246.028, 10096.408, 11025. ]) 245 | """ 246 | 247 | # 'Center freqs' of mel bands - uniformly spaced between limits 248 | min_mel = hz_to_mel(fmin, htk=htk) 249 | max_mel = hz_to_mel(fmax, htk=htk) 250 | 251 | mels = np.linspace(min_mel, max_mel, n_mels) 252 | 253 | return mel_to_hz(mels, htk=htk) 254 | 255 | def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, 256 | norm=1, dtype=np.float32): 257 | """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins 258 | Parameters 259 | ---------- 260 | sr : number > 0 [scalar] 261 | sampling rate of the incoming signal 262 | n_fft : int > 0 [scalar] 263 | number of FFT components 264 | n_mels : int > 0 [scalar] 265 | number of Mel bands to generate 266 | fmin : float >= 0 [scalar] 267 | lowest frequency (in Hz) 268 | fmax : float >= 0 [scalar] 269 | highest frequency (in Hz). 
270 | If `None`, use `fmax = sr / 2.0` 271 | htk : bool [scalar] 272 | use HTK formula instead of Slaney 273 | norm : {None, 1, np.inf} [scalar] 274 | if 1, divide the triangular mel weights by the width of the mel band 275 | (area normalization). Otherwise, leave all the triangles aiming for 276 | a peak value of 1.0 277 | dtype : np.dtype 278 | The data type of the output basis. 279 | By default, uses 32-bit (single-precision) floating point. 280 | Returns 281 | ------- 282 | M : np.ndarray [shape=(n_mels, 1 + n_fft/2)] 283 | Mel transform matrix 284 | Notes 285 | ----- 286 | This function caches at level 10. 287 | Examples 288 | -------- 289 | >>> melfb = librosa.filters.mel(22050, 2048) 290 | >>> melfb 291 | array([[ 0. , 0.016, ..., 0. , 0. ], 292 | [ 0. , 0. , ..., 0. , 0. ], 293 | ..., 294 | [ 0. , 0. , ..., 0. , 0. ], 295 | [ 0. , 0. , ..., 0. , 0. ]]) 296 | Clip the maximum frequency to 8KHz 297 | >>> librosa.filters.mel(22050, 2048, fmax=8000) 298 | array([[ 0. , 0.02, ..., 0. , 0. ], 299 | [ 0. , 0. , ..., 0. , 0. ], 300 | ..., 301 | [ 0. , 0. , ..., 0. , 0. ], 302 | [ 0. , 0. , ..., 0. , 0. ]]) 303 | >>> import matplotlib.pyplot as plt 304 | >>> plt.figure() 305 | >>> librosa.display.specshow(melfb, x_axis='linear') 306 | >>> plt.ylabel('Mel filter') 307 | >>> plt.title('Mel filter bank') 308 | >>> plt.colorbar() 309 | >>> plt.tight_layout() 310 | >>> plt.show() 311 | """ 312 | 313 | if fmax is None: 314 | fmax = float(sr) / 2 315 | 316 | if norm is not None and norm != 1 and norm != np.inf: 317 | raise ParameterError('Unsupported norm: {}'.format(repr(norm))) 318 | 319 | # Initialize the weights 320 | n_mels = int(n_mels) 321 | weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) 322 | 323 | # Center freqs of each FFT bin 324 | fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) 325 | 326 | # 'Center freqs' of mel bands - uniformly spaced between limits 327 | mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk) 328 | 329 | fdiff = np.diff(mel_f) 330 | ramps = np.subtract.outer(mel_f, fftfreqs) 331 | 332 | for i in range(n_mels): 333 | # lower and upper slopes for all bins 334 | lower = -ramps[i] / fdiff[i] 335 | upper = ramps[i+2] / fdiff[i+1] 336 | 337 | # .. then intersect them with each other and zero 338 | weights[i] = np.maximum(0, np.minimum(lower, upper)) 339 | 340 | if norm == 1: 341 | # Slaney-style mel is scaled to be approx constant energy per channel 342 | enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels]) 343 | weights *= enorm[:, np.newaxis] 344 | 345 | # Only check weights if f_mel[0] is positive 346 | if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)): 347 | # This means we have an empty channel somewhere 348 | warnings.warn('Empty filters detected in mel frequency basis. ' 349 | 'Some channels will produce empty responses. ' 350 | 'Try increasing your sampling rate (and fmax) or ' 351 | 'reducing n_mels.') 352 | 353 | return weights 354 | 355 | def gammatone(sr, n_fft, n_bins=64, fmin=20.0, fmax=None, htk=False, 356 | norm=1, dtype=np.float32): 357 | """Create a Filterbank matrix to combine FFT bins into Gammatone bins 358 | Parameters 359 | ---------- 360 | sr : number > 0 [scalar] 361 | sampling rate of the incoming signal 362 | n_fft : int > 0 [scalar] 363 | number of FFT components 364 | n_bins : int > 0 [scalar] 365 | number of Mel bands to generate 366 | fmin : float >= 0 [scalar] 367 | lowest frequency (in Hz) 368 | fmax : float >= 0 [scalar] 369 | highest frequency (in Hz). 
370 | If `None`, use `fmax = sr / 2.0` 371 | htk : bool [scalar] 372 | use HTK formula instead of Slaney 373 | norm : {None, 1, np.inf} [scalar] 374 | if 1, divide the triangular mel weights by the width of the mel band 375 | (area normalization). Otherwise, leave all the triangles aiming for 376 | a peak value of 1.0 377 | dtype : np.dtype 378 | The data type of the output basis. 379 | By default, uses 32-bit (single-precision) floating point. 380 | Returns 381 | ------- 382 | G : np.ndarray [shape=(n_bins, 1 + n_fft/2)] 383 | Gammatone transform matrix 384 | """ 385 | 386 | if fmax is None: 387 | fmax = float(sr) / 2 388 | n_bins = int(n_bins) 389 | 390 | weights,_ = fft2gammatonemx(sr=sr, n_fft=n_fft, n_bins=n_bins, fmin=fmin, fmax=fmax, maxlen=int(n_fft//2+1)) 391 | 392 | return (1/n_fft)*weights 393 | ### ------------------End of Functions for generating kenral for Mel Spectrogram ----------------### 394 | -------------------------------------------------------------------------------- /nnAudio2/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import conv1d, conv2d 4 | 5 | import numpy as np 6 | import torch 7 | from time import time 8 | import math 9 | from scipy.signal import get_window 10 | from scipy import signal 11 | from scipy import fft 12 | import warnings 13 | 14 | sz_float = 4 # size of a float 15 | epsilon = 10e-8 # fudge factor for normalization 16 | 17 | # ---------------------------Filter design ----------------------------------- 18 | def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03): 19 | # calculate the highest frequency we need to preserve and the 20 | # lowest frequency we allow to pass through. Note that frequency 21 | # is on a scale from 0 to 1 where 0 is 0 and 1 is Nyquist 22 | # frequency of the signal BEFORE downsampling 23 | 24 | # transitionBandwidth = 0.03 25 | passbandMax = band_center / (1 + transitionBandwidth) 26 | stopbandMin = band_center * (1 + transitionBandwidth) 27 | 28 | # Unlike the filter tool we used online yesterday, this tool does 29 | # not allow us to specify how closely the filter matches our 30 | # specifications. Instead, we specify the length of the kernel. 31 | # The longer the kernel is, the more precisely it will match. 32 | # kernelLength = 256 33 | 34 | # We specify a list of key frequencies for which we will require 35 | # that the filter match a specific output gain. 36 | # From [0.0 to passbandMax] is the frequency range we want to keep 37 | # untouched and [stopbandMin, 1.0] is the range we want to remove 38 | keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0] 39 | 40 | # We specify a list of output gains to correspond to the key 41 | # frequencies listed above. 42 | # The first two gains are 1.0 because they correspond to the first 43 | # two key frequencies. 
the second two are 0.0 because they
44 |     # correspond to the stopband frequencies
45 |     gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
46 | 
47 |     # This command produces the filter kernel coefficients
48 |     filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
49 | 
50 |     return filterKernel.astype(np.float32)
51 | 
52 | def downsampling_by_n(x, filterKernel, n):
53 |     """Downsampling by a factor of n"""
54 |     x = conv1d(x,filterKernel,stride=n, padding=(filterKernel.shape[-1]-1)//2)
55 |     return x
56 | 
57 | def downsampling_by_2(x, filterKernel):
58 |     x = conv1d(x,filterKernel,stride=2, padding=(filterKernel.shape[-1]-1)//2)
59 |     return x
60 | 
61 | 
62 | ## Basic tools for computation ##
63 | def nextpow2(A):
64 |     return int(np.ceil(np.log2(A)))
65 | 
66 | def complex_mul(cqt_filter, stft):
67 |     """Since PyTorch does not support complex numbers and their operations, we need to write our own complex multiplication function. This one is specially designed for CQT usage."""
68 | 
69 |     cqt_filter_real = cqt_filter[0]
70 |     cqt_filter_imag = cqt_filter[1]
71 |     fourier_real = stft[0]
72 |     fourier_imag = stft[1]
73 | 
74 |     CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
75 |     CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
76 | 
77 |     return CQT_real, CQT_imag
78 | 
79 | def broadcast_dim(x):
80 |     """
81 |     Auto broadcast input so that it can fit into a Conv1d
82 |     """
83 |     if x.dim() == 2:
84 |         x = x[:, None, :]
85 |     elif x.dim() == 1:
86 |         x = x[None, None, :]
87 |     elif x.dim() == 3:
88 |         pass
89 |     else:
90 |         raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
91 |     return x
92 | 
93 | def broadcast_dim_conv2d(x):
94 |     """
95 |     Auto broadcast input so that it can fit into a Conv2d
96 |     """
97 |     if x.dim() == 3:
98 |         x = x[:, None, :,:]
99 | 
100 |     else:
101 |         raise ValueError("Only support input with shape = (batch, freq_bins, time_steps)")
102 |     return x
103 | 
104 | 
105 | ## Kernel generation functions ##
106 | def create_fourier_kernels(n_fft, freq_bins=None, fmin=50, fmax=6000, sr=44100, freq_scale='linear', window='hann'):
107 |     """
108 |     If freq_scale is 'no', the fmin and fmax arguments will be ignored
109 |     """
110 | 
111 |     if freq_bins==None:
112 |         freq_bins = n_fft//2+1
113 | 
114 |     s = np.arange(0, n_fft, 1.)
115 |     wsin = np.empty((freq_bins,1,n_fft))
116 |     wcos = np.empty((freq_bins,1,n_fft))
117 |     start_freq = fmin
118 |     end_freq = fmax
119 |     bins2freq = []
120 |     binslist = []
121 | 
122 |     # num_cycles = start_freq*d/44000.
123 |     # scaling_ind = np.log(end_freq/start_freq)/k
124 | 
125 |     # Choosing window shape
126 | 
127 |     window_mask = get_window(window,int(n_fft), fftbins=True)
128 | 
129 | 
130 |     if freq_scale == 'linear':
131 |         print("sampling rate = {}. Please make sure the sampling rate is correct in order to get a valid freq range".format(sr))
132 |         start_bin = start_freq*n_fft/sr
133 |         scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
134 |         for k in range(freq_bins): # Only half of the bins contain useful info
135 |             # print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
136 |             bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
137 |             binslist.append((k*scaling_ind+start_bin))
138 |             wsin[k,0,:] = window_mask*np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
139 |             wcos[k,0,:] = window_mask*np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
140 | 
141 |     elif freq_scale == 'log':
142 |         print("sampling rate = {}.
Please make sure the sampling rate is correct in order to get a valid freq range".format(sr)) 143 | start_bin = start_freq*n_fft/sr 144 | scaling_ind = np.log(end_freq/start_freq)/freq_bins 145 | for k in range(freq_bins): # Only half of the bins contain useful info 146 | # print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft)) 147 | bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft) 148 | binslist.append((np.exp(k*scaling_ind)*start_bin)) 149 | wsin[k,0,:] = window_mask*np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft) 150 | wcos[k,0,:] = window_mask*np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft) 151 | 152 | elif freq_scale == 'no': 153 | for k in range(freq_bins): # Only half of the bins contain useful info 154 | bins2freq.append(k*sr/n_fft) 155 | binslist.append(k) 156 | wsin[k,0,:] = window_mask*np.sin(2*np.pi*k*s/n_fft) 157 | wcos[k,0,:] = window_mask*np.cos(2*np.pi*k*s/n_fft) 158 | else: 159 | print("Please select the correct frequency scale, 'linear' or 'log'") 160 | return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist 161 | 162 | def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1, window='hann', fmax=None, topbin_check=True): 163 | """ 164 | Automatically create CQT kernels and convert it to frequency domain 165 | """ 166 | # norm arg is not functioning 167 | 168 | fftLen = 2**nextpow2(np.ceil(Q * fs / fmin)) 169 | # minWin = 2**nextpow2(np.ceil(Q * fs / fmax)) 170 | if (fmax != None) and (n_bins == None): 171 | n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins 172 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 173 | elif (fmax == None) and (n_bins != None): 174 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 175 | else: 176 | warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning) 177 | n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins 178 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 179 | if np.max(freqs) > fs/2 and topbin_check==True: 180 | raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(np.max(freqs))) 181 | tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64) 182 | specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64) 183 | for k in range(0, int(n_bins)): 184 | freq = freqs[k] 185 | l = np.ceil(Q * fs / freq) 186 | lenghts = np.ceil(Q * fs / freqs) 187 | # Centering the kernels 188 | if l%2==1: # pad more zeros on RHS 189 | start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1 190 | else: 191 | start = int(np.ceil(fftLen / 2.0 - l / 2.0)) 192 | sig = get_window(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l 193 | if norm: # Normalizing the filter # Trying to normalize like librosa 194 | tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm) 195 | else: 196 | tempKernel[k, start:start + int(l)] = sig 197 | # specKernel[k, :] = fft(tempKernel[k]) 198 | 199 | # return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lenghts).float() 200 | return tempKernel, fftLen, torch.tensor(lenghts).float() 201 | 202 | def create_cqt_kernels_t(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1, window='hann', fmax=None): 203 | """ 204 | Create cqt kernels in time-domain 205 | """ 206 | # norm arg is not functioning 207 | 208 | fftLen = 2**nextpow2(np.ceil(Q * fs / fmin)) 209 | # minWin = 2**nextpow2(np.ceil(Q * fs / fmax)) 
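    # Illustration of the sizing above (example values, not from the original code):
    # with bins_per_octave=12, Q = 1/(2**(1/12)-1) ≈ 16.82, so for fs=22050 and
    # fmin=220 we get Q*fs/fmin ≈ 1686 and hence fftLen = 2**nextpow2(1686) = 2048.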
210 | if (fmax != None) and (n_bins == None): 211 | n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins 212 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 213 | elif (fmax == None) and (n_bins != None): 214 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 215 | else: 216 | warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning) 217 | n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins 218 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave)) 219 | if np.max(freqs) > fs/2: 220 | raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(np.max(freqs))) 221 | tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64) 222 | specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64) 223 | for k in range(0, int(n_bins)): 224 | freq = freqs[k] 225 | l = np.ceil(Q * fs / freq) 226 | lenghts = np.ceil(Q * fs / freqs) 227 | # Centering the kernels 228 | if l%2==1: # pad more zeros on RHS 229 | start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1 230 | else: 231 | start = int(np.ceil(fftLen / 2.0 - l / 2.0)) 232 | sig = get_window(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l 233 | if norm: # Normalizing the filter # Trying to normalize like librosa 234 | tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm) 235 | else: 236 | tempKernel[k, start:start + int(l)] = sig 237 | # specKernel[k, :]=fft(conj(tempKernel[k, :])) 238 | 239 | return tempKernel, fftLen, torch.tensor(lenghts).float() 240 | 241 | ### ----------------Functions for generating kenral for Mel Spectrogram------------ ### 242 | def mel_to_hz(mels, htk=False): 243 | """Convert mel bin numbers to frequencies 244 | 245 | Examples 246 | -------- 247 | >>> librosa.mel_to_hz(3) 248 | 200. 249 | 250 | >>> librosa.mel_to_hz([1,2,3,4,5]) 251 | array([ 66.667, 133.333, 200. 
, 266.667, 333.333]) 252 | 253 | Parameters 254 | ---------- 255 | mels : np.ndarray [shape=(n,)], float 256 | mel bins to convert 257 | htk : bool 258 | use HTK formula instead of Slaney 259 | 260 | Returns 261 | ------- 262 | frequencies : np.ndarray [shape=(n,)] 263 | input mels in Hz 264 | 265 | See Also 266 | -------- 267 | hz_to_mel 268 | """ 269 | 270 | mels = np.asanyarray(mels) 271 | 272 | if htk: 273 | return 700.0 * (10.0**(mels / 2595.0) - 1.0) 274 | 275 | # Fill in the linear scale 276 | f_min = 0.0 277 | f_sp = 200.0 / 3 278 | freqs = f_min + f_sp * mels 279 | 280 | # And now the nonlinear scale 281 | min_log_hz = 1000.0 # beginning of log region (Hz) 282 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) 283 | logstep = np.log(6.4) / 27.0 # step size for log region 284 | 285 | if mels.ndim: 286 | # If we have vector data, vectorize 287 | log_t = (mels >= min_log_mel) 288 | freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) 289 | elif mels >= min_log_mel: 290 | # If we have scalar data, check directly 291 | freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel)) 292 | 293 | return freqs 294 | 295 | def hz_to_mel(frequencies, htk=False): 296 | """Convert Hz to Mels 297 | 298 | Examples 299 | -------- 300 | >>> librosa.hz_to_mel(60) 301 | 0.9 302 | >>> librosa.hz_to_mel([110, 220, 440]) 303 | array([ 1.65, 3.3 , 6.6 ]) 304 | 305 | Parameters 306 | ---------- 307 | frequencies : number or np.ndarray [shape=(n,)] , float 308 | scalar or array of frequencies 309 | htk : bool 310 | use HTK formula instead of Slaney 311 | 312 | Returns 313 | ------- 314 | mels : number or np.ndarray [shape=(n,)] 315 | input frequencies in Mels 316 | 317 | See Also 318 | -------- 319 | mel_to_hz 320 | """ 321 | 322 | frequencies = np.asanyarray(frequencies) 323 | 324 | if htk: 325 | return 2595.0 * np.log10(1.0 + frequencies / 700.0) 326 | 327 | # Fill in the linear part 328 | f_min = 0.0 329 | f_sp = 200.0 / 3 330 | 331 | mels = (frequencies - f_min) / f_sp 332 | 333 | # Fill in the log-scale part 334 | 335 | min_log_hz = 1000.0 # beginning of log region (Hz) 336 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) 337 | logstep = np.log(6.4) / 27.0 # step size for log region 338 | 339 | if frequencies.ndim: 340 | # If we have array data, vectorize 341 | log_t = (frequencies >= min_log_hz) 342 | mels[log_t] = min_log_mel + np.log(frequencies[log_t]/min_log_hz) / logstep 343 | elif frequencies >= min_log_hz: 344 | # If we have scalar data, heck directly 345 | mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep 346 | 347 | return mels 348 | 349 | def fft_frequencies(sr=22050, n_fft=2048): 350 | '''Alternative implementation of `np.fft.fftfreq` 351 | 352 | Parameters 353 | ---------- 354 | sr : number > 0 [scalar] 355 | Audio sampling rate 356 | 357 | n_fft : int > 0 [scalar] 358 | FFT window size 359 | 360 | 361 | Returns 362 | ------- 363 | freqs : np.ndarray [shape=(1 + n_fft/2,)] 364 | Frequencies `(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)` 365 | 366 | 367 | Examples 368 | -------- 369 | >>> librosa.fft_frequencies(sr=22050, n_fft=16) 370 | array([ 0. , 1378.125, 2756.25 , 4134.375, 371 | 5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ]) 372 | 373 | ''' 374 | 375 | return np.linspace(0, 376 | float(sr) / 2, 377 | int(1 + n_fft//2), 378 | endpoint=True) 379 | 380 | def mel_frequencies(n_mels=128, fmin=0.0, fmax=11025.0, htk=False): 381 | """Compute an array of acoustic frequencies tuned to the mel scale. 
382 | 383 | The mel scale is a quasi-logarithmic function of acoustic frequency 384 | designed such that perceptually similar pitch intervals (e.g. octaves) 385 | appear equal in width over the full hearing range. 386 | 387 | Because the definition of the mel scale is conditioned by a finite number 388 | of subjective psychoaoustical experiments, several implementations coexist 389 | in the audio signal processing literature [1]_. By default, librosa replicates 390 | the behavior of the well-established MATLAB Auditory Toolbox of Slaney [2]_. 391 | According to this default implementation, the conversion from Hertz to mel is 392 | linear below 1 kHz and logarithmic above 1 kHz. Another available implementation 393 | replicates the Hidden Markov Toolkit [3]_ (HTK) according to the following formula: 394 | 395 | `mel = 2595.0 * np.log10(1.0 + f / 700.0).` 396 | 397 | The choice of implementation is determined by the `htk` keyword argument: setting 398 | `htk=False` leads to the Auditory toolbox implementation, whereas setting it `htk=True` 399 | leads to the HTK implementation. 400 | 401 | .. [1] Umesh, S., Cohen, L., & Nelson, D. Fitting the mel scale. 402 | In Proc. International Conference on Acoustics, Speech, and Signal Processing 403 | (ICASSP), vol. 1, pp. 217-220, 1998. 404 | 405 | .. [2] Slaney, M. Auditory Toolbox: A MATLAB Toolbox for Auditory 406 | Modeling Work. Technical Report, version 2, Interval Research Corporation, 1998. 407 | 408 | .. [3] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., Liu, X., 409 | Moore, G., Odell, J., Ollason, D., Povey, D., Valtchev, V., & Woodland, P. 410 | The HTK book, version 3.4. Cambridge University, March 2009. 411 | 412 | 413 | See Also 414 | -------- 415 | hz_to_mel 416 | mel_to_hz 417 | librosa.feature.melspectrogram 418 | librosa.feature.mfcc 419 | 420 | 421 | Parameters 422 | ---------- 423 | n_mels : int > 0 [scalar] 424 | Number of mel bins. 425 | 426 | fmin : float >= 0 [scalar] 427 | Minimum frequency (Hz). 428 | 429 | fmax : float >= 0 [scalar] 430 | Maximum frequency (Hz). 431 | 432 | htk : bool 433 | If True, use HTK formula to convert Hz to mel. 434 | Otherwise (False), use Slaney's Auditory Toolbox. 435 | 436 | Returns 437 | ------- 438 | bin_frequencies : ndarray [shape=(n_mels,)] 439 | Vector of n_mels frequencies in Hz which are uniformly spaced on the Mel 440 | axis. 441 | 442 | Examples 443 | -------- 444 | >>> librosa.mel_frequencies(n_mels=40) 445 | array([ 0. , 85.317, 170.635, 255.952, 446 | 341.269, 426.586, 511.904, 597.221, 447 | 682.538, 767.855, 853.173, 938.49 , 448 | 1024.856, 1119.114, 1222.042, 1334.436, 449 | 1457.167, 1591.187, 1737.532, 1897.337, 450 | 2071.84 , 2262.393, 2470.47 , 2697.686, 451 | 2945.799, 3216.731, 3512.582, 3835.643, 452 | 4188.417, 4573.636, 4994.285, 5453.621, 453 | 5955.205, 6502.92 , 7101.009, 7754.107, 454 | 8467.272, 9246.028, 10096.408, 11025. 
]) 455 | 456 | """ 457 | 458 | # 'Center freqs' of mel bands - uniformly spaced between limits 459 | min_mel = hz_to_mel(fmin, htk=htk) 460 | max_mel = hz_to_mel(fmax, htk=htk) 461 | 462 | mels = np.linspace(min_mel, max_mel, n_mels) 463 | 464 | return mel_to_hz(mels, htk=htk) 465 | 466 | def mel(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, 467 | norm=1, dtype=np.float32): 468 | """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins 469 | 470 | Parameters 471 | ---------- 472 | sr : number > 0 [scalar] 473 | sampling rate of the incoming signal 474 | 475 | n_fft : int > 0 [scalar] 476 | number of FFT components 477 | 478 | n_mels : int > 0 [scalar] 479 | number of Mel bands to generate 480 | 481 | fmin : float >= 0 [scalar] 482 | lowest frequency (in Hz) 483 | 484 | fmax : float >= 0 [scalar] 485 | highest frequency (in Hz). 486 | If `None`, use `fmax = sr / 2.0` 487 | 488 | htk : bool [scalar] 489 | use HTK formula instead of Slaney 490 | 491 | norm : {None, 1, np.inf} [scalar] 492 | if 1, divide the triangular mel weights by the width of the mel band 493 | (area normalization). Otherwise, leave all the triangles aiming for 494 | a peak value of 1.0 495 | 496 | dtype : np.dtype 497 | The data type of the output basis. 498 | By default, uses 32-bit (single-precision) floating point. 499 | 500 | Returns 501 | ------- 502 | M : np.ndarray [shape=(n_mels, 1 + n_fft/2)] 503 | Mel transform matrix 504 | 505 | Notes 506 | ----- 507 | This function caches at level 10. 508 | 509 | Examples 510 | -------- 511 | >>> melfb = librosa.filters.mel(22050, 2048) 512 | >>> melfb 513 | array([[ 0. , 0.016, ..., 0. , 0. ], 514 | [ 0. , 0. , ..., 0. , 0. ], 515 | ..., 516 | [ 0. , 0. , ..., 0. , 0. ], 517 | [ 0. , 0. , ..., 0. , 0. ]]) 518 | 519 | 520 | Clip the maximum frequency to 8KHz 521 | 522 | >>> librosa.filters.mel(22050, 2048, fmax=8000) 523 | array([[ 0. , 0.02, ..., 0. , 0. ], 524 | [ 0. , 0. , ..., 0. , 0. ], 525 | ..., 526 | [ 0. , 0. , ..., 0. , 0. ], 527 | [ 0. , 0. , ..., 0. , 0. ]]) 528 | 529 | 530 | >>> import matplotlib.pyplot as plt 531 | >>> plt.figure() 532 | >>> librosa.display.specshow(melfb, x_axis='linear') 533 | >>> plt.ylabel('Mel filter') 534 | >>> plt.title('Mel filter bank') 535 | >>> plt.colorbar() 536 | >>> plt.tight_layout() 537 | >>> plt.show() 538 | """ 539 | 540 | if fmax is None: 541 | fmax = float(sr) / 2 542 | 543 | if norm is not None and norm != 1 and norm != np.inf: 544 | raise ParameterError('Unsupported norm: {}'.format(repr(norm))) 545 | 546 | # Initialize the weights 547 | n_mels = int(n_mels) 548 | weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) 549 | 550 | # Center freqs of each FFT bin 551 | fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) 552 | 553 | # 'Center freqs' of mel bands - uniformly spaced between limits 554 | mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk) 555 | 556 | fdiff = np.diff(mel_f) 557 | ramps = np.subtract.outer(mel_f, fftfreqs) 558 | 559 | for i in range(n_mels): 560 | # lower and upper slopes for all bins 561 | lower = -ramps[i] / fdiff[i] 562 | upper = ramps[i+2] / fdiff[i+1] 563 | 564 | # .. 
then intersect them with each other and zero 565 | weights[i] = np.maximum(0, np.minimum(lower, upper)) 566 | 567 | if norm == 1: 568 | # Slaney-style mel is scaled to be approx constant energy per channel 569 | enorm = 2.0 / (mel_f[2:n_mels+2] - mel_f[:n_mels]) 570 | weights *= enorm[:, np.newaxis] 571 | 572 | # Only check weights if f_mel[0] is positive 573 | if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)): 574 | # This means we have an empty channel somewhere 575 | warnings.warn('Empty filters detected in mel frequency basis. ' 576 | 'Some channels will produce empty responses. ' 577 | 'Try increasing your sampling rate (and fmax) or ' 578 | 'reducing n_mels.') 579 | 580 | return weights 581 | ### ------------------End of Functions for generating kenral for Mel Spectrogram ----------------### 582 | 583 | 584 | ### ------------------Spectrogram Classes---------------------------### 585 | class CQT1992(torch.nn.Module): 586 | def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=1, window='hann', center=True, pad_mode='reflect'): 587 | super(CQT1992, self).__init__() 588 | # norm arg is not functioning 589 | 590 | self.hop_length = hop_length 591 | self.center = center 592 | self.pad_mode = pad_mode 593 | self.norm = norm 594 | 595 | # creating kernels for CQT 596 | Q = 1/(2**(1/bins_per_octave)-1) 597 | 598 | print("Creating CQT kernels ...", end='\r') 599 | start = time() 600 | self.cqt_kernels, self.kernal_width, lenghts = create_cqt_kernels(Q, sr, fmin, n_bins, bins_per_octave, norm, window, fmax) 601 | self.cqt_kernels = fft(self.cqt_kernels)[:,:self.kernal_width//2+1] 602 | self.cqt_kernels_real = torch.tensor(self.cqt_kernels.real.astype(np.float32)) 603 | self.cqt_kernels_imag = torch.tensor(self.cqt_kernels.imag.astype(np.float32)) 604 | print("CQT kernels created, time used = {:.4f} seconds".format(time()-start)) 605 | 606 | # creating kernels for stft 607 | # self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernal_width # Trying to normalize as librosa 608 | # self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernal_width 609 | print("Creating STFT kernels ...", end='\r') 610 | start = time() 611 | wsin, wcos, self.bins2freq, _ = create_fourier_kernels(self.kernal_width, window='ones', freq_scale='no') 612 | self.wsin = torch.tensor(wsin) 613 | self.wcos = torch.tensor(wcos) 614 | print("STFT kernels created, time used = {:.4f} seconds".format(time()-start)) 615 | 616 | def forward(self,x): 617 | x = broadcast_dim(x) 618 | if self.center: 619 | if self.pad_mode == 'constant': 620 | padding = nn.ConstantPad1d(self.kernal_width//2, 0) 621 | elif self.pad_mode == 'reflect': 622 | padding = nn.ReflectionPad1d(self.kernal_width//2) 623 | 624 | x = padding(x) 625 | 626 | # STFT 627 | fourier_real = conv1d(x, self.wcos, stride=self.hop_length) 628 | fourier_imag = conv1d(x, self.wsin, stride=self.hop_length) 629 | 630 | # CQT 631 | CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag), 632 | (fourier_real, fourier_imag)) 633 | 634 | # Getting CQT Amplitude 635 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 636 | 637 | if self.norm: 638 | return CQT/self.kernal_width 639 | else: 640 | return CQT 641 | 642 | class CQT1992v2(torch.nn.Module): 643 | def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=1, window='hann', center=True, pad_mode='reflect'): 644 | super(CQT1992v2, self).__init__() 645 | # norm arg is not functioning 646 | 647 | 
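        # Unlike CQT1992 above, which computes an STFT first and then applies
        # frequency-domain CQT kernels, this version convolves the input
        # directly with the time-domain CQT kernels.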
self.hop_length = hop_length
648 |         self.center = center
649 |         self.pad_mode = pad_mode
650 | 
651 |         # creating kernels for CQT
652 |         Q = 1/(2**(1/bins_per_octave)-1)
653 | 
654 |         print("Creating CQT kernels ...", end='\r')
655 |         start = time()
656 |         self.cqt_kernels, self.kernal_width, lenghts = create_cqt_kernels(Q, sr, fmin, n_bins, bins_per_octave, norm, window, fmax)
657 |         self.cqt_kernels_real = torch.tensor(self.cqt_kernels.real).unsqueeze(1)
658 |         self.cqt_kernels_imag = torch.tensor(self.cqt_kernels.imag).unsqueeze(1)
659 |         print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
660 | 
661 |         # creating kernels for stft
662 |         # self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernal_width # Trying to normalize as librosa
663 |         # self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernal_width
664 | 
665 |     def forward(self,x):
666 |         x = broadcast_dim(x)
667 |         if self.center:
668 |             if self.pad_mode == 'constant':
669 |                 padding = nn.ConstantPad1d(self.kernal_width//2, 0)
670 |             elif self.pad_mode == 'reflect':
671 |                 padding = nn.ReflectionPad1d(self.kernal_width//2)
672 | 
673 |             x = padding(x)
674 | 
675 |         # CQT
676 |         CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)
677 |         CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)
678 | 
679 | 
680 |         # Getting CQT Amplitude
681 |         CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
682 |         return CQT
683 | 
684 | class STFT(torch.nn.Module):
685 |     """When using freq_scale, please set the correct sampling rate. The sampling rate is used to calculate the correct frequency."""
686 | 
687 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50, fmax=6000, sr=22050):
688 |         super(STFT, self).__init__()
689 |         self.stride = hop_length
690 |         self.center = center
691 |         self.pad_mode = pad_mode
692 |         self.n_fft = n_fft
693 | 
694 |         # Create filter windows for stft
695 |         wsin, wcos, self.bins2freq, self.bin_list = create_fourier_kernels(n_fft, freq_bins=freq_bins, window=window, freq_scale=freq_scale, fmin=fmin, fmax=fmax, sr=sr)
696 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
697 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
698 | 
699 |     def forward(self,x):
700 |         x = broadcast_dim(x)
701 |         if self.center:
702 |             if self.pad_mode == 'constant':
703 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
704 |             elif self.pad_mode == 'reflect':
705 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
706 | 
707 |             x = padding(x)
708 | 
709 |         spec = conv1d(x, self.wsin, stride=self.stride).pow(2) \
710 |              + conv1d(x, self.wcos, stride=self.stride).pow(2) # Doing STFT by using conv1d
711 |         return torch.sqrt(spec)
712 | 
713 |     def manual_forward(self,x):
714 |         x = broadcast_dim(x)
715 |         if self.center:
716 |             if self.pad_mode == 'constant':
717 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
718 |             elif self.pad_mode == 'reflect':
719 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
720 | 
721 |             x = padding(x)
722 | 
723 |         imag = conv1d(x, self.wsin, stride=self.stride).pow(2)
724 |         real = conv1d(x, self.wcos, stride=self.stride).pow(2) # Doing STFT by using conv1d
725 |         return real, imag
726 | 
727 | class DFT(torch.nn.Module):
728 |     """
729 |     The inverse function only works for 1 single frame, i.e.
input shape = (batch, n_fft, 1)
730 |     """
731 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50, fmax=6000, sr=22050):
732 |         super(DFT, self).__init__()
733 |         self.stride = hop_length
734 |         self.center = center
735 |         self.pad_mode = pad_mode
736 |         self.n_fft = n_fft
737 | 
738 |         # Create filter windows for stft
739 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=n_fft, window=window, freq_scale=freq_scale, fmin=fmin, fmax=fmax, sr=sr)
740 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
741 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
742 | 
743 |     def forward(self,x):
744 |         x = broadcast_dim(x)
745 |         if self.center:
746 |             if self.pad_mode == 'constant':
747 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
748 |             elif self.pad_mode == 'reflect':
749 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
750 | 
751 |             x = padding(x)
752 | 
753 |         imag = conv1d(x, self.wsin, stride=self.stride)
754 |         real = conv1d(x, self.wcos, stride=self.stride)
755 |         return (real, -imag)
756 | 
757 |     def inverse(self,x_real,x_imag):
758 |         x_real = broadcast_dim(x_real)
759 |         x_imag = broadcast_dim(x_imag)
760 | 
761 |         x_real.transpose_(1,2) # Prepare the right shape to do inverse
762 |         x_imag.transpose_(1,2) # Prepare the right shape to do inverse
763 | 
764 |         # if self.center:
765 |         #     if self.pad_mode == 'constant':
766 |         #         padding = nn.ConstantPad1d(self.n_fft//2, 0)
767 |         #     elif self.pad_mode == 'reflect':
768 |         #         padding = nn.ReflectionPad1d(self.n_fft//2)
769 | 
770 |         #     x_real = padding(x_real)
771 |         #     x_imag = padding(x_imag)
772 | 
773 |         # Watch out for the positive and negative signs
774 |         # ifft = e^(+2\pi*j)*X
775 | 
776 |         # ifft(X_real) = (a1, a2)
777 | 
778 |         # ifft(X_imag)*1j = (b1, b2)*1j
779 |         #                 = (-b2, b1)
780 | 
781 |         a1 = conv1d(x_real, self.wcos, stride=self.stride)
782 |         a2 = conv1d(x_real, self.wsin, stride=self.stride)
783 |         b1 = conv1d(x_imag, self.wcos, stride=self.stride)
784 |         b2 = conv1d(x_imag, self.wsin, stride=self.stride)
785 | 
786 |         imag = a2+b1
787 |         real = a1-b2
788 |         return (real/self.n_fft, imag/self.n_fft)
789 | 
790 | class iSTFT_complex_2d(torch.nn.Module):
791 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50, fmax=6000, sr=22050):
792 |         super(iSTFT_complex_2d, self).__init__()
793 |         self.stride = hop_length
794 |         self.center = center
795 |         self.pad_mode = pad_mode
796 |         self.n_fft = n_fft
797 | 
798 |         # Create filter windows for stft
799 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=n_fft, window=window, freq_scale=freq_scale, fmin=fmin, fmax=fmax, sr=sr)
800 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
801 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
802 | 
803 |         self.wsin = self.wsin[:,:,:,None] # adjust the filter shape to fit into 2d Conv
804 |         self.wcos = self.wcos[:,:,:,None]
805 | 
806 |     def forward(self,x_real,x_imag):
807 |         x_real = broadcast_dim_conv2d(x_real)
808 |         x_imag = broadcast_dim_conv2d(x_imag) # taking conjugate
809 | 
810 | 
811 |         # if self.center:
812 |         #     if self.pad_mode == 'constant':
813 |         #         padding = nn.ConstantPad1d(self.n_fft//2, 0)
814 |         #     elif self.pad_mode == 'reflect':
815 |         #         padding = nn.ReflectionPad1d(self.n_fft//2)
816 | 
817 |         #     x_real = padding(x_real)
818 |         #     x_imag = padding(x_imag)
819 | 
820 |         # Watch out for the positive and negative signs
821 |         # ifft = e^(+2\pi*j)*X
822 | 
823 |         # ifft(X_real) = (a1, a2)
824 | 
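        #   (a1, a2) are the cosine and sine correlations of x_real; (b1, b2) below are the same for x_imag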
825 |         # ifft(X_imag)*1j = (b1, b2)*1j
826 |         #                 = (-b2, b1)
827 | 
828 |         a1 = conv2d(x_real, self.wcos, stride=(1,1))
829 |         a2 = conv2d(x_real, self.wsin, stride=(1,1))
830 |         b1 = conv2d(x_imag, self.wcos, stride=(1,1))
831 |         b2 = conv2d(x_imag, self.wsin, stride=(1,1))
832 | 
833 |         imag = a2+b1
834 |         real = a1-b2
835 |         return (real/self.n_fft, imag/self.n_fft)
836 | 
837 | class MelSpectrogram(torch.nn.Module):
838 |     def __init__(self, sr=22050, n_fft=2048, n_mels=128, hop_length=512, window='hann', center=True, pad_mode='reflect', htk=False, fmin=0.0, fmax=None, norm=1):
839 |         super(MelSpectrogram, self).__init__()
840 |         self.stride = hop_length
841 |         self.center = center
842 |         self.pad_mode = pad_mode
843 |         self.n_fft = n_fft
844 | 
845 |         # Create filter windows for stft
846 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no', sr=sr)
847 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
848 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
849 | 
850 |         # Creating the kernel for the mel spectrogram
851 |         mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
852 |         self.mel_basis = torch.tensor(mel_basis)
853 |     def forward(self,x):
854 |         x = broadcast_dim(x)
855 |         if self.center:
856 |             if self.pad_mode == 'constant':
857 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
858 |             elif self.pad_mode == 'reflect':
859 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
860 | 
861 |             x = padding(x)
862 | 
863 |         spec = conv1d(x, self.wsin, stride=self.stride).pow(2) \
864 |              + conv1d(x, self.wcos, stride=self.stride).pow(2) # Doing STFT by using conv1d
865 | 
866 |         melspec = torch.matmul(self.mel_basis, spec)
867 |         return melspec
868 | 
869 | 
870 | ### ----------------CQT 2010------------------------------------------------------- ###
871 | 
872 | def cqt_filter_fft(sr, fmin, n_bins, bins_per_octave, tuning,
873 |                    filter_scale, norm, sparsity, hop_length=None,
874 |                    window='hann'):
875 |     '''Generate the frequency domain constant-Q filter basis. (Note: this function depends on librosa; see the commented import below.)'''
876 | 
877 |     basis, lengths = filters.constant_q(sr,
878 |                                         fmin=fmin,
879 |                                         n_bins=n_bins,
880 |                                         bins_per_octave=bins_per_octave,
881 |                                         tuning=tuning,
882 |                                         filter_scale=filter_scale,
883 |                                         norm=norm,
884 |                                         pad_fft=True,
885 |                                         window=window)
886 | 
887 |     # Filters are padded up to the nearest integral power of 2
888 |     n_fft = basis.shape[1]
889 | 
890 |     if (hop_length is not None and
891 |             n_fft < 2.0**(1 + np.ceil(np.log2(hop_length)))):
892 | 
893 |         n_fft = int(2.0 ** (1 + np.ceil(np.log2(hop_length))))
894 | 
895 |     # re-normalize bases with respect to the FFT window length
896 |     basis *= lengths[:, np.newaxis] / float(n_fft)
897 | 
898 |     # FFT and retain only the non-negative frequencies
899 |     fft = get_fftlib()
900 |     fft_basis = fft.fft(basis, n=n_fft, axis=1)[:, :(n_fft // 2)+1]
901 | 
902 |     # sparsify the basis
903 |     fft_basis = util.sparsify_rows(fft_basis, quantile=sparsity)
904 | 
905 |     return fft_basis, n_fft, lengths
906 | 
907 | # from librosa import filters, get_fftlib, util
908 | 
909 | class CQT2010(torch.nn.Module):
910 |     """
911 |     This algorithm uses the resampling method proposed in [1]. Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the input audio by a factor of 2 and convolving it with the small CQT kernel. Each time the input audio is downsampled, the CQT relative to the downsampled input is equivalent to the next lower octave.
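    Concretely, after k rounds of downsampling by 2, correlating the downsampled signal with the same top-octave kernels yields the CQT bins k octaves below the top octave.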
912 |     The kernel creation process is still the same as in the 1992 algorithm. Therefore, we can reuse the code from the 1992 algorithm [2].
913 |     [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
914 |     [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992).
915 | 
916 |     The early downsampling factor is used to downsample the input audio to reduce the CQT kernel size. The results with and without early downsampling are more or less the same, except in the very low frequency region where freq < 40 Hz.
917 | 
918 |     """
919 |     def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect', earlydownsample=True):
920 |         super(CQT2010, self).__init__()
921 | 
922 |         self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
923 |         # basis_norm is for normalizing the basis
924 |         self.hop_length = hop_length
925 |         self.pad_mode = pad_mode
926 |         self.n_bins = n_bins
927 |         self.earlydownsample = earlydownsample # We will activate early downsampling later if possible
928 | 
929 |         Q = 1/(2**(1/bins_per_octave)-1) # It will be used to calculate filter_cutoff and to create the CQT kernels
930 | 
931 |         # Creating lowpass filter and make it a torch tensor
932 |         print("Creating low pass filter ...", end='\r')
933 |         start = time()
934 |         self.lowpass_filter = torch.tensor(
935 |             create_lowpass_filter(
936 |                 band_center = 0.5,
937 |                 kernelLength=256,
938 |                 transitionBandwidth=0.001))
939 |         self.lowpass_filter = self.lowpass_filter[None,None,:] # Broadcast the tensor to the shape that fits conv1d
940 |         print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
941 | 
942 |         # Calculate the number of filters required for the kernel
943 |         # n_octaves determines how many resampling steps are required for the CQT
944 |         n_filters = min(bins_per_octave, n_bins)
945 |         self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
946 |         # print("n_octaves = ", self.n_octaves)
947 |         # Calculate the lowest frequency bin for the top octave kernel
948 |         self.fmin_t = fmin*2**(self.n_octaves-1)
949 |         remainder = n_bins % bins_per_octave
950 |         # print("remainder = ", remainder)
951 |         if remainder==0:
952 |             fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave) # Calculate the top bin frequency
953 |         else:
954 |             fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave) # Calculate the top bin frequency
955 |             self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the lowest bin of the top octave
956 |         if fmax_t > sr/2:
957 |             raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(fmax_t))
958 | 
959 | 
960 |         if self.earlydownsample == True: # Do early downsampling if this argument is True
961 |             print("Creating early downsampling filter ...", end='\r')
962 |             start = time()
963 |             sr, self.hop_length, self.downsample_factor, self.early_downsample_filter, self.earlydownsample = self.get_early_downsample_params(sr, hop_length, fmax_t, Q, self.n_octaves)
964 |             print("Early downsampling filter created, time used = {:.4f} seconds".format(time()-start))
965 |         else:
966 |             self.downsample_factor=1.
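        # A worked example of the octave bookkeeping above (an illustrative
        # sketch added for clarity; it assumes sr=22050, fmin=32.70, n_bins=84,
        # bins_per_octave=12, which are not all the constructor defaults):
        #   n_filters = min(12, 84) = 12
        #   n_octaves = ceil(84 / 12) = 7
        #   fmin_t    = 32.70 * 2**(7-1) = 2092.8 Hz   (lowest bin of the top octave)
        #   remainder = 84 % 12 = 0
        #   fmax_t    = 2092.8 * 2**(11/12) ~ 3951 Hz  (top bin, safely below the 11025 Hz Nyquist)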
967 | 
968 |         # Preparing CQT kernels
969 |         print("Creating CQT kernels ...", end='\r')
970 |         start = time()
971 |         # print("Q = {}, fmin_t = {}, n_filters = {}".format(Q, self.fmin_t, n_filters))
972 |         basis, self.n_fft, _ = create_cqt_kernels(Q, sr, self.fmin_t, n_filters, bins_per_octave, norm=basis_norm, topbin_check=False)
973 |         self.basis=basis
974 |         fft_basis = fft(basis)[:,:self.n_fft//2+1] # Convert the CQT kernel from the time domain to the frequency domain
975 | 
976 |         self.cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32)) # These CQT kernels are already in the frequency domain
977 |         self.cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32))
978 |         print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
979 |         # print("Getting cqt kernel done, n_fft = ",self.n_fft)
980 |         # Preparing kernels for Short-Time Fourier Transform (STFT)
981 |         # We set the frequency range in the CQT filter instead of here.
982 |         print("Creating STFT kernels ...", end='\r')
983 |         start = time()
984 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(self.n_fft, window='ones', freq_scale='no')
985 |         self.wsin = torch.tensor(wsin)
986 |         self.wcos = torch.tensor(wcos)
987 |         print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
988 | 
989 | 
990 | 
991 |         # If center==True, the STFT window will be put in the middle, and padding at the beginning and end is required.
992 |         if self.pad_mode == 'constant':
993 |             self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
994 |         elif self.pad_mode == 'reflect':
995 |             self.padding = nn.ReflectionPad1d(self.n_fft//2)
996 | 
997 | 
998 |     def get_cqt(self,x,hop_length, padding):
999 |         """Multiplying the STFT result with the cqt_kernel; see the 1992 CQT paper [2] for how to multiply the STFT result with the CQT kernel.
1000 |         [2] Brown, Judith C.C. and Miller Puckette.
        “An efficient algorithm for the calculation of a constant Q transform.” (1992)."""
1001 | 
1002 |         # STFT, converting the audio input from time domain to frequency domain
1003 |         try:
1004 |             x = padding(x) # When center == True, we need padding at the beginning and ending
1005 |         except:
1006 |             print("padding with reflection mode might not be the best choice, try using constant padding")
1007 |         fourier_real = conv1d(x, self.wcos, stride=hop_length)
1008 |         fourier_imag = conv1d(x, self.wsin, stride=hop_length)
1009 | 
1010 |         # Multiplying input with the CQT kernel in freq domain
1011 |         CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag),
1012 |                                          (fourier_real, fourier_imag))
1013 | 
1014 |         # Getting CQT Amplitude
1015 |         CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
1016 | 
1017 |         return CQT
1018 | 
1019 | 
1020 |     def get_early_downsample_params(self, sr, hop_length, fmax_t, Q, n_octaves):
1021 |         window_bandwidth = 1.5 # for hann window
1022 |         filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
1023 |         sr, hop_length, downsample_factor = self.early_downsample(sr, hop_length, n_octaves, sr//2, filter_cutoff)
1024 |         if downsample_factor != 1:
1025 |             print("Can do early downsample, factor = ", downsample_factor)
1026 |             earlydownsample=True
1027 |             # print("new sr = ", sr)
1028 |             # print("new hop_length = ", hop_length)
1029 |             early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor, kernelLength=256, transitionBandwidth=0.03)
1030 |             early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
1031 |         else:
1032 |             print("No early downsampling is required, downsample_factor = ", downsample_factor)
1033 |             early_downsample_filter = None
1034 |             earlydownsample=False
1035 |         return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
1036 | 
1037 |     # The following two downsampling count functions are obtained from librosa CQT
1038 |     # They are used to determine the number of pre-resampling steps if the starting and ending frequencies are both in the low frequency region.
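    # A worked example of the two count functions below (an illustrative sketch,
    # assuming sr=22050 so nyquist=11025, hop_length=512, n_octaves=2, and
    # filter_cutoff ~ 129 Hz, i.e. a CQT confined to low frequencies):
    #   downsample_count1 = max(0, ceil(log2(0.85*11025/129)) - 1 - 1) = 7 - 2 = 5
    #   downsample_count2 = max(0, nextpow2(512) - 2 + 1) = max(0, 9 - 2 + 1) = 8
    #   downsample_count  = min(5, 8) = 5  ->  downsample_factor = 2**5 = 32,
    #   so hop_length becomes 512//32 = 16 and sr becomes 22050/32 ~ 689 Hz.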
1039 |     def early_downsample_count(self, nyquist, filter_cutoff, hop_length, n_octaves):
1040 |         '''Compute the number of early downsampling operations'''
1041 | 
1042 |         downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
1043 |                                                        filter_cutoff)) - 1) - 1)
1044 |         # print("downsample_count1 = ", downsample_count1)
1045 |         num_twos = nextpow2(hop_length)
1046 |         downsample_count2 = max(0, num_twos - n_octaves + 1)
1047 |         # print("downsample_count2 = ",downsample_count2)
1048 | 
1049 |         return min(downsample_count1, downsample_count2)
1050 | 
1051 |     def early_downsample(self, sr, hop_length, n_octaves,
1052 |                          nyquist, filter_cutoff):
1053 |         '''Return new sampling rate and hop length after early downsampling'''
1054 |         downsample_count = self.early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
1055 |         # print("downsample_count = ", downsample_count)
1056 |         downsample_factor = 2**(downsample_count)
1057 | 
1058 |         hop_length //= downsample_factor # Getting new hop_length
1059 |         new_sr = sr / float(downsample_factor) # Getting new sampling rate
1060 | 
1061 |         sr = new_sr
1062 | 
1063 |         return sr, hop_length, downsample_factor
1064 | 
1065 | 
1066 |     def forward(self,x):
1067 |         x = broadcast_dim(x)
1068 |         if self.earlydownsample==True:
1069 |             x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
1070 |         hop = self.hop_length
1071 |         CQT = self.get_cqt(x, hop, self.padding) # Getting the top octave CQT
1072 | 
1073 |         x_down = x # Preparing a new variable for downsampling
1074 |         for i in range(self.n_octaves-1):
1075 |             hop = hop//2
1076 |             x_down = downsampling_by_2(x_down, self.lowpass_filter)
1077 |             CQT1 = self.get_cqt(x_down, hop, self.padding)
1078 |             CQT = torch.cat((CQT1, CQT),1) # Prepend the lower-octave CQT
1079 |         CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins
1080 |         CQT = CQT*2**(self.n_octaves-1) # Normalizing signals with respect to n_fft
1081 | 
1082 |         CQT = CQT*self.downsample_factor/2**(self.n_octaves-1) # Normalizing the output with the downsampling factor; the 2**(self.n_octaves-1) factor makes the magnitude match the 1992 implementation
1083 | 
1084 |         if self.norm:
1085 |             return CQT/self.n_fft
1086 |         else:
1087 |             return CQT
1088 | 
1089 | class CQT2019(torch.nn.Module):
1090 |     """
1091 |     This algorithm uses the resampling method proposed in [1]. Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the input audio by a factor of 2 and convolving it with the small CQT kernel. Every time the input audio is downsampled, the CQT relative to the downsampled input is equivalent to the next lower octave.
1092 |     The kernel creation process is still the same as in the 1992 algorithm. Therefore, we can reuse the code from the 1992 algorithm [2].
1093 |     [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010).
1094 |     [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992).
1095 | 
1096 |     The early downsampling factor is used to downsample the input audio to reduce the CQT kernel size.
The results with and without early downsampling are more or less the same, except in the very low frequency region where freq < 40 Hz.
1097 | 
1098 |     """
1099 |     def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect', earlydownsample=True):
1100 |         super(CQT2019, self).__init__()
1101 | 
1102 |         self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft
1103 |         # basis_norm is for normalizing the basis
1104 |         self.hop_length = hop_length
1105 |         self.pad_mode = pad_mode
1106 |         self.n_bins = n_bins
1107 |         self.earlydownsample = earlydownsample # We will activate early downsampling later if possible
1108 | 
1109 |         Q = 1/(2**(1/bins_per_octave)-1) # It will be used to calculate filter_cutoff and to create the CQT kernels
1110 | 
1111 |         # Creating lowpass filter and make it a torch tensor
1112 |         print("Creating low pass filter ...", end='\r')
1113 |         start = time()
1114 |         self.lowpass_filter = torch.tensor(
1115 |             create_lowpass_filter(
1116 |                 band_center = 0.50,
1117 |                 kernelLength=256,
1118 |                 transitionBandwidth=0.001))
1119 |         self.lowpass_filter = self.lowpass_filter[None,None,:] # Broadcast the tensor to the shape that fits conv1d
1120 |         print("Low pass filter created, time used = {:.4f} seconds".format(time()-start))
1121 | 
1122 |         # Calculate the number of filters required for the kernel
1123 |         # n_octaves determines how many resampling steps are required for the CQT
1124 |         n_filters = min(bins_per_octave, n_bins)
1125 |         self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
1126 |         print("num_octave = ", self.n_octaves)
1127 | 
1128 |         # Calculate the lowest frequency bin for the top octave kernel
1129 |         self.fmin_t = fmin*2**(self.n_octaves-1)
1130 |         remainder = n_bins % bins_per_octave
1131 |         # print("remainder = ", remainder)
1132 |         if remainder==0:
1133 |             fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave) # Calculate the top bin frequency
1134 |         else:
1135 |             fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave) # Calculate the top bin frequency
1136 |             self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the lowest bin of the top octave
1137 |         if fmax_t > sr/2:
1138 |             raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(fmax_t))
1139 | 
1140 |         if self.earlydownsample == True: # Do early downsampling if this argument is True
1141 |             print("Creating early downsampling filter ...", end='\r')
1142 |             start = time()
1143 |             sr, self.hop_length, self.downsample_factor, self.early_downsample_filter, self.earlydownsample = self.get_early_downsample_params(sr, hop_length, fmax_t, Q, self.n_octaves)
1144 |             print("Early downsampling filter created, time used = {:.4f} seconds".format(time()-start))
1145 |         else:
1146 |             self.downsample_factor=1.
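        # Note (an added remark, not original library commentary): unlike
        # CQT2010 above, which converts its kernels to the frequency domain and
        # multiplies them with an STFT, this class keeps the CQT kernels in the
        # time domain and convolves them with the (progressively downsampled)
        # audio directly in get_cqt() below.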
1147 | 
1148 |         # Preparing CQT kernels
1149 |         print("Creating CQT kernels ...", end='\r')
1150 |         start = time()
1151 |         basis, self.n_fft, _ = create_cqt_kernels(Q, sr, self.fmin_t, n_filters, bins_per_octave, norm=basis_norm, topbin_check=False)
1152 |         self.basis = basis
1153 |         self.cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1) # These CQT kernels stay in the time domain (unlike CQT2010, which converts them to the frequency domain)
1154 |         self.cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)
1155 |         print("CQT kernels created, time used = {:.4f} seconds".format(time()-start))
1156 |         # print("Getting cqt kernel done, n_fft = ",self.n_fft)
1157 | 
1158 |         # If center==True, the STFT window will be put in the middle, and padding at the beginning and end is required.
1159 |         if self.pad_mode == 'constant':
1160 |             self.padding = nn.ConstantPad1d(self.n_fft//2, 0)
1161 |         elif self.pad_mode == 'reflect':
1162 |             self.padding = nn.ReflectionPad1d(self.n_fft//2)
1163 | 
1164 | 
1165 |     def get_cqt(self,x,hop_length, padding):
1166 |         """Convolving the input with the time-domain CQT kernels; see the 1992 CQT paper [2] for the definition of the constant Q transform.
1167 |         [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992)."""
1168 | 
1169 |         # Direct convolution with the time-domain CQT kernels
1170 |         try:
1171 |             x = padding(x) # When center == True, we need padding at the beginning and ending
1172 |         except:
1173 |             print("padding with reflection mode might not be the best choice, try using constant padding")
1174 |         CQT_real = conv1d(x, self.cqt_kernels_real, stride=hop_length)
1175 |         CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=hop_length)
1176 | 
1177 |         # Getting CQT Amplitude
1178 |         CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2))
1179 | 
1180 |         return CQT
1181 | 
1182 | 
1183 |     def get_early_downsample_params(self, sr, hop_length, fmax_t, Q, n_octaves):
1184 |         window_bandwidth = 1.5 # for hann window
1185 |         filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
1186 |         sr, hop_length, downsample_factor = self.early_downsample(sr, hop_length, n_octaves, sr//2, filter_cutoff)
1187 |         if downsample_factor != 1:
1188 |             print("Can do early downsample, factor = ", downsample_factor)
1189 |             earlydownsample=True
1190 |             # print("new sr = ", sr)
1191 |             # print("new hop_length = ", hop_length)
1192 |             early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor, kernelLength=256, transitionBandwidth=0.03)
1193 |             early_downsample_filter = torch.tensor(early_downsample_filter)[None, None, :]
1194 |         else:
1195 |             print("No early downsampling is required, downsample_factor = ", downsample_factor)
1196 |             early_downsample_filter = None
1197 |             earlydownsample=False
1198 |         return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample
1199 | 
1200 |     # The following two downsampling count functions are obtained from librosa CQT
1201 |     # They are used to determine the number of pre-resampling steps if the starting and ending frequencies are both in the low frequency region.
1202 |     def early_downsample_count(self, nyquist, filter_cutoff, hop_length, n_octaves):
1203 |         '''Compute the number of early downsampling operations'''
1204 | 
1205 |         downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist /
1206 |                                                        filter_cutoff)) - 1) - 1)
1207 |         # print("downsample_count1 = ", downsample_count1)
1208 |         num_twos = nextpow2(hop_length)
1209 |         downsample_count2 = max(0, num_twos - n_octaves + 1)
1210 |         # print("downsample_count2 = ",downsample_count2)
1211 | 
1212 |         return min(downsample_count1, downsample_count2)
1213 | 
1214 |     def early_downsample(self, sr, hop_length, n_octaves,
1215 |                          nyquist, filter_cutoff):
1216 |         '''Return new sampling rate and hop length after early downsampling'''
1217 |         downsample_count = self.early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
1218 |         # print("downsample_count = ", downsample_count)
1219 |         downsample_factor = 2**(downsample_count)
1220 | 
1221 |         hop_length //= downsample_factor # Getting new hop_length
1222 |         new_sr = sr / float(downsample_factor) # Getting new sampling rate
1223 | 
1224 |         sr = new_sr
1225 | 
1226 |         return sr, hop_length, downsample_factor
1227 | 
1228 | 
1229 |     def forward(self,x):
1230 |         x = broadcast_dim(x)
1231 |         if self.earlydownsample==True:
1232 |             x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
1233 |         hop = self.hop_length
1234 |         CQT = self.get_cqt(x, hop, self.padding) # Getting the top octave CQT
1235 | 
1236 |         x_down = x # Preparing a new variable for downsampling
1237 |         for i in range(self.n_octaves-1):
1238 |             hop = hop//2
1239 |             x_down = downsampling_by_2(x_down, self.lowpass_filter)
1240 |             CQT1 = self.get_cqt(x_down, hop, self.padding)
1241 |             CQT = torch.cat((CQT1, CQT),1) # Prepend the lower-octave CQT
1242 |         CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins
1243 |         CQT = CQT*2**(self.n_octaves-1) # Normalizing signals with respect to n_fft
1244 |         CQT = CQT*self.downsample_factor/2**(self.n_octaves-1) # Normalizing the output with the downsampling factor; the 2**(self.n_octaves-1) factor makes the magnitude match the 1992 implementation
1245 | 
1246 |         return CQT
1247 | 
--------------------------------------------------------------------------------
/nnAudio2/Spectrogram.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import conv1d, conv2d
4 | 
5 | import numpy as np
6 | import torch
7 | from time import time
8 | import math
9 | from scipy.signal import get_window
10 | from scipy import signal
11 | from scipy import fft
12 | import warnings
13 | 
14 | from .librosa_filters import mel,gammatone # Use it for PyPip
15 | # from librosa_filters import mel # Use it for debug
16 | 
17 | sz_float = 4 # size of a float
18 | epsilon = 10e-8 # fudge factor for normalization
19 | 
20 | # ---------------------------Filter design -----------------------------------
21 | def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
22 |     """
23 |     Calculate the highest frequency we need to preserve and the lowest frequency we allow to pass through. Note that frequency is on a scale from 0 to 1, where 0 corresponds to 0 Hz and 1 to the Nyquist frequency of the signal BEFORE downsampling.
24 |     """
25 | 
26 |     # transitionBandwidth = 0.03
27 |     passbandMax = band_center / (1 + transitionBandwidth)
28 |     stopbandMin = band_center * (1 + transitionBandwidth)
29 | 
30 |     # Unlike some online filter design tools, this function does
31 |     # not allow us to specify how closely the filter matches our
32 |     # specifications. Instead, we specify the length of the kernel.
33 |     # The longer the kernel is, the more precisely it will match.
34 |     # kernelLength = 256
35 | 
36 |     # We specify a list of key frequencies for which we will require
37 |     # that the filter match a specific output gain.
38 |     # From [0.0 to passbandMax] is the frequency range we want to keep
39 |     # untouched and [stopbandMin, 1.0] is the range we want to remove
40 |     keyFrequencies = [0.0, passbandMax, stopbandMin, 1.0]
41 | 
42 |     # We specify a list of output gains to correspond to the key
43 |     # frequencies listed above.
44 |     # The first two gains are 1.0 because they correspond to the first
45 |     # two key frequencies. The second two are 0.0 because they
46 |     # correspond to the stopband frequencies
47 |     gainAtKeyFrequencies = [1.0, 1.0, 0.0, 0.0]
48 | 
49 |     # This command produces the filter kernel coefficients
50 |     filterKernel = signal.firwin2(kernelLength, keyFrequencies, gainAtKeyFrequencies)
51 | 
52 |     return filterKernel.astype(np.float32)
53 | 
54 | def downsampling_by_n(x, filterKernel, n):
55 |     """A helper function that downsamples the audio by an arbitrary factor n. It is used in CQT2010 and CQT2010v2
56 | 
57 |     Parameters
58 |     ----------
59 |     x : torch.Tensor
60 |         The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
61 | 
62 |     filterKernel : torch.Tensor
63 |         Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
64 | 
65 |     n : int
66 |         The downsampling factor
67 | 
68 |     Returns
69 |     -------
70 |     torch.Tensor
71 |         The downsampled waveform
72 | 
73 |     Examples
74 |     --------
75 |     >>> x_down = downsampling_by_n(x, filterKernel, 2)
76 |     """
77 |     x = conv1d(x,filterKernel,stride=n, padding=(filterKernel.shape[-1]-1)//2)
78 |     return x
79 | 
80 | def downsampling_by_2(x, filterKernel):
81 |     """A helper function that downsamples the audio by half. It is used in CQT2010 and CQT2010v2
82 | 
83 |     Parameters
84 |     ----------
85 |     x : torch.Tensor
86 |         The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``
87 | 
88 |     filterKernel : torch.Tensor
89 |         Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``
90 | 
91 |     Returns
92 |     -------
93 |     torch.Tensor
94 |         The downsampled waveform
95 | 
96 |     Examples
97 |     --------
98 |     >>> x_down = downsampling_by_2(x, filterKernel)
99 |     """
100 |     x = conv1d(x,filterKernel,stride=2, padding=(filterKernel.shape[-1]-1)//2)
101 |     return x
102 | 
103 | 
104 | ## Basic tools for computation ##
105 | def nextpow2(A):
106 |     """A helper function to calculate the exponent of the next power of 2 that is no smaller than ``A``.
107 | 
108 |     Parameters
109 |     ----------
110 |     A : float
111 |         A float number that is going to be rounded up to the next power of 2
112 | 
113 |     Returns
114 |     -------
115 |     int
116 |         The exponent ``n`` such that ``2**n`` is the smallest power of 2 greater than or equal to ``A``
117 | 
118 |     Examples
119 |     --------
120 | 
121 |     >>> nextpow2(6)
122 |     3
123 |     """
124 |     return int(np.ceil(np.log2(A)))
125 | 
126 | def complex_mul(cqt_filter, stft):
127 |     """Since PyTorch does not support complex numbers and their operations, we need to write our own complex multiplication function. This one is specially designed for CQT usage.
128 | 
129 |     Parameters
130 |     ----------
131 |     cqt_filter : tuple of torch.Tensor
132 |         The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``. The ``stft`` argument follows the same ``(real, imag)`` format.
133 | 
134 |     Returns
135 |     -------
136 |     tuple of torch.Tensor
137 |         The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
138 |     """
139 | 
140 |     cqt_filter_real = cqt_filter[0]
141 |     cqt_filter_imag = cqt_filter[1]
142 |     fourier_real = stft[0]
143 |     fourier_imag = stft[1]
144 | 
145 |     CQT_real = torch.matmul(cqt_filter_real, fourier_real) - torch.matmul(cqt_filter_imag, fourier_imag)
146 |     CQT_imag = torch.matmul(cqt_filter_real, fourier_imag) + torch.matmul(cqt_filter_imag, fourier_real)
147 | 
148 |     return CQT_real, CQT_imag
149 | 
150 | def broadcast_dim(x):
151 |     """
152 |     Auto broadcast input so that it can fit into a Conv1d
153 |     """
154 | 
155 |     if x.dim() == 2:
156 |         x = x[:, None, :]
157 |     elif x.dim() == 1:
158 |         # If nn.DataParallel is used, this broadcast doesn't work
159 |         x = x[None, None, :]
160 |     elif x.dim() == 3:
161 |         pass
162 |     else:
163 |         raise ValueError("Only support input with shape = (batch, len) or shape = (len)")
164 |     return x
165 | 
166 | def broadcast_dim_conv2d(x):
167 |     """
168 |     Auto broadcast input so that it can fit into a Conv2d
169 |     """
170 |     if x.dim() == 3:
171 |         x = x[:, None, :,:]
172 | 
173 |     else:
174 |         raise ValueError("Only support input with shape = (batch, freq_bins, time_steps)")
175 |     return x
176 | 
177 | 
178 | ## Kernel generation functions ##
179 | def create_fourier_kernels(n_fft, freq_bins=None, fmin=50,fmax=6000, sr=44100, freq_scale='linear', window='hann'):
180 |     """ This function creates the Fourier Kernel for STFT, Melspectrogram and CQT. Most of the parameters follow librosa conventions. Part of the code comes from pytorch_musicnet. https://github.com/jthickstun/pytorch_musicnet
181 | 
182 |     Parameters
183 |     ----------
184 |     n_fft : int
185 |         The window size
186 | 
187 |     freq_bins : int
188 |         Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
189 | 
190 |     fmin : int
191 |         The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument does nothing.
192 | 
193 |     fmax : int
194 |         The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument does nothing.
195 | 
196 |     sr : int
197 |         The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
198 | 
199 |     freq_scale: 'linear', 'log', or 'no'
200 |         Determine the spacing between each frequency bin. When 'linear' or 'log' is used, the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
201 | 
202 |     Returns
203 |     -------
204 |     wsin : numpy.array
205 |         Imaginary Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
206 | 
207 |     wcos : numpy.array
208 |         Real Fourier Kernel with the shape ``(freq_bins, 1, n_fft)``
209 | 
210 |     bins2freq : list
211 |         Mapping each frequency bin to frequency in Hz.
212 | 
213 |     binslist : list
214 |         The normalized frequency ``k`` in the digital domain. This ``k`` is the frequency index in the Discrete Fourier Transform equation $X[k] = \sum_{n=0}^{N-1} x[n] e^{-2\pi i k n / N}$
215 | 
216 |     """
217 |     if freq_bins==None:
218 |         freq_bins = n_fft//2+1
219 | 
220 |     s = np.arange(0, n_fft, 1.)
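    # What the loops below build (a descriptive note matching the code, not an
    # addition to it): for every frequency bin k with normalized frequency f_k,
    #   wcos[k, 0, n] = w[n] * cos(2*pi * f_k * n / n_fft)
    #   wsin[k, 0, n] = w[n] * sin(2*pi * f_k * n / n_fft)
    # where w is the window. conv1d with wcos/wsin therefore computes the real
    # part and the (sign-flipped) imaginary part of a windowed DFT, frame by frame.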
221 |     wsin = np.empty((freq_bins,1,n_fft))
222 |     wcos = np.empty((freq_bins,1,n_fft))
223 |     start_freq = fmin
224 |     end_freq = fmax
225 |     bins2freq = []
226 |     binslist = []
227 | 
228 |     # num_cycles = start_freq*d/44000.
229 |     # scaling_ind = np.log(end_freq/start_freq)/k
230 | 
231 |     # Choosing window shape
232 | 
233 |     window_mask = get_window(window,int(n_fft), fftbins=True)
234 | 
235 | 
236 |     if freq_scale == 'linear':
237 |         print("sampling rate = {}. Please make sure the sampling rate is correct in order to get a valid freq range".format(sr))
238 |         start_bin = start_freq*n_fft/sr
239 |         scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins
240 |         for k in range(freq_bins): # Only half of the bins contain useful info
241 |             # print("linear freq = {}".format((k*scaling_ind+start_bin)*sr/n_fft))
242 |             bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)
243 |             binslist.append((k*scaling_ind+start_bin))
244 |             wsin[k,0,:] = window_mask*np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
245 |             wcos[k,0,:] = window_mask*np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)
246 | 
247 |     elif freq_scale == 'log':
248 |         print("sampling rate = {}. Please make sure the sampling rate is correct in order to get a valid freq range".format(sr))
249 |         start_bin = start_freq*n_fft/sr
250 |         scaling_ind = np.log(end_freq/start_freq)/freq_bins
251 |         for k in range(freq_bins): # Only half of the bins contain useful info
252 |             # print("log freq = {}".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))
253 |             bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)
254 |             binslist.append((np.exp(k*scaling_ind)*start_bin))
255 |             wsin[k,0,:] = window_mask*np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
256 |             wcos[k,0,:] = window_mask*np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
257 | 
258 |     elif freq_scale == 'no':
259 |         for k in range(freq_bins): # Only half of the bins contain useful info
260 |             bins2freq.append(k*sr/n_fft)
261 |             binslist.append(k)
262 |             wsin[k,0,:] = window_mask*np.sin(2*np.pi*k*s/n_fft)
263 |             wcos[k,0,:] = window_mask*np.cos(2*np.pi*k*s/n_fft)
264 |     else:
265 |         print("Please select a correct frequency scale: 'linear', 'log', or 'no'")
266 |     return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist
267 | 
268 | def create_cqt_kernels(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1, window='hann', fmax=None, topbin_check=True):
269 |     """
270 |     Automatically create CQT kernels in the time domain; the conversion to the frequency domain is left to the caller
271 |     """
272 |     # norm arg is not functioning
273 | 
274 |     fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
275 |     # minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
276 |     if (fmax != None) and (n_bins == None):
277 |         n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
278 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
279 |     elif (fmax == None) and (n_bins != None):
280 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
281 |     else:
282 |         warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning)
283 |         n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
284 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
285 |     if np.max(freqs) > fs/2 and topbin_check==True:
286 |         raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(np.max(freqs)))
287 |     tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
288 |     specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
289 |     for k in range(0, int(n_bins)):
290 |         freq = freqs[k]
291 |         l = np.ceil(Q * fs / freq)
292 |         lenghts = np.ceil(Q * fs / freqs)
293 |         # Centering the kernels
294 |         if l%2==1: # pad more zeros on RHS
295 |             start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
296 |         else:
297 |             start = int(np.ceil(fftLen / 2.0 - l / 2.0))
298 |         sig = get_window(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
299 |         if norm: # Normalizing the filter # Trying to normalize like librosa
300 |             tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
301 |         else:
302 |             tempKernel[k, start:start + int(l)] = sig
303 |         # specKernel[k, :] = fft(tempKernel[k])
304 | 
305 |     # return specKernel[:,:fftLen//2+1], fftLen, torch.tensor(lenghts).float()
306 |     return tempKernel, fftLen, torch.tensor(lenghts).float()
307 | 
308 | def create_cqt_kernels_t(Q, fs, fmin, n_bins=84, bins_per_octave=12, norm=1, window='hann', fmax=None):
309 |     """
310 |     Create cqt kernels in time-domain
311 |     """
312 |     # norm arg is not functioning
313 | 
314 |     fftLen = 2**nextpow2(np.ceil(Q * fs / fmin))
315 |     # minWin = 2**nextpow2(np.ceil(Q * fs / fmax))
316 |     if (fmax != None) and (n_bins == None):
317 |         n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
318 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
319 |     elif (fmax == None) and (n_bins != None):
320 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
321 |     else:
322 |         warnings.warn('If fmax is given, n_bins will be ignored',SyntaxWarning)
323 |         n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin)) # Calculate the number of bins
324 |         freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.float(bins_per_octave))
325 |     if np.max(freqs) > fs/2:
326 |         raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(np.max(freqs)))
327 |     tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
328 |     specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
329 |     for k in range(0, int(n_bins)):
330 |         freq = freqs[k]
331 |         l = np.ceil(Q * fs / freq)
332 |         lenghts = np.ceil(Q * fs / freqs)
333 |         # Centering the kernels
334 |         if l%2==1: # pad more zeros on RHS
335 |             start = int(np.ceil(fftLen / 2.0 - l / 2.0))-1
336 |         else:
337 |             start = int(np.ceil(fftLen / 2.0 - l / 2.0))
338 |         sig = get_window(window,int(l), fftbins=True)*np.exp(np.r_[-l//2:l//2]*1j*2*np.pi*freq/fs)/l
339 |         if norm: # Normalizing the filter # Trying to normalize like librosa
340 |             tempKernel[k, start:start + int(l)] = sig/np.linalg.norm(sig, norm)
341 |         else:
342 |             tempKernel[k, start:start + int(l)] = sig
343 |         # specKernel[k, :]=fft(conj(tempKernel[k, :]))
344 | 
345 |     return tempKernel, fftLen, torch.tensor(lenghts).float()
346 | 
347 | 
348 | ### ------------------Spectrogram Classes---------------------------###
349 | 
350 | class STFT(torch.nn.Module):
351 |     """This function is to calculate the short-time Fourier transform (STFT) of the input signal. Input signal should be in either of the following shapes. 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred automatically if the input follows these 3 shapes. Most of the arguments follow the convention from librosa. This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
352 | 
353 |     Parameters
354 |     ----------
355 |     n_fft : int
356 |         The window size. Default value is 2048.
357 | 
358 |     freq_bins : int
359 |         Number of frequency bins. Default is ``None``, which means ``n_fft//2+1`` bins
360 | 
361 |     hop_length : int
362 |         The hop (or stride) size. Default value is 512.
363 | 
364 |     window : str
365 |         The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to scipy documentation for possible windowing functions. The default value is 'hann'
366 | 
367 |     freq_scale : 'linear', 'log', or 'no'
368 |         Determine the spacing between each frequency bin. When `linear` or `log` is used, the bin spacing can be controlled by ``fmin`` and ``fmax``. If 'no' is used, the bin will start at 0Hz and end at Nyquist frequency with linear spacing.
369 | 
370 |     center : bool
371 |         Putting the STFT kernel at the center of the time-step or not. If ``False``, the time index is the beginning of the STFT kernel, if ``True``, the time index is the center of the STFT kernel. Default value is ``True``.
372 | 
373 |     pad_mode : str
374 |         The padding method. Default value is 'reflect'.
375 | 
376 |     fmin : int
377 |         The starting frequency for the lowest frequency bin. If freq_scale is ``no``, this argument does nothing.
378 | 
379 |     fmax : int
380 |         The ending frequency for the highest frequency bin. If freq_scale is ``no``, this argument does nothing.
381 | 
382 |     sr : int
383 |         The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
384 | 
385 |     trainable : bool
386 |         Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT kernels will also be calculated and the STFT kernels will be updated during model training. Default value is ``False``
387 | 
388 |     output_format : str
389 |         Determine the return type. ``Magnitude`` will return the magnitude of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``; ``Complex`` will return the STFT result as a complex number, shape = ``(num_samples, freq_bins,time_steps, 2)``; ``Phase`` will return the phase of the STFT result, shape = ``(num_samples, freq_bins,time_steps)``. The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'.
390 | 
391 |     verbose : bool
392 |         If ``True``, it shows layer information. If ``False``, it suppresses all prints
393 | 
394 |     device : str
395 |         Choose which device to initialize this layer. Default value is 'cuda:0'
396 | 
397 |     Returns
398 |     -------
399 |     spectrogram : torch.tensor
400 |         It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)`` if 'Magnitude' or 'Phase' is used as the ``output_format``; shape = ``(num_samples, freq_bins,time_steps, 2)`` if 'Complex' is used as the ``output_format``
401 | 
402 |     Examples
403 |     --------
404 |     >>> spec_layer = Spectrogram.STFT()
405 |     >>> specs = spec_layer(x)
406 |     """
407 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50,fmax=6000, sr=22050, trainable=False, output_format='Magnitude', verbose=True, device='cuda:0'):
408 |         self.trainable = trainable
409 |         super(STFT, self).__init__()
410 |         self.stride = hop_length
411 |         self.center = center
412 |         self.pad_mode = pad_mode
413 |         self.n_fft = n_fft
414 |         self.trainable = trainable
415 |         self.output_format=output_format
416 |         self.device = device
417 |         start = time()
418 |         # Create filter windows for stft
419 |         wsin, wcos, self.bins2freq, self.bin_list = create_fourier_kernels(n_fft, freq_bins=freq_bins, window=window, freq_scale=freq_scale, fmin=fmin,fmax=fmax, sr=sr)
420 |         self.wsin = torch.tensor(wsin, dtype=torch.float, device=self.device)
421 |         self.wcos = torch.tensor(wcos, dtype=torch.float, device=self.device)
422 | 
423 |         # Making all these variables nn.Parameter, so that the model can be used with nn.DataParallel
424 |         self.wsin = torch.nn.Parameter(self.wsin, requires_grad=self.trainable)
425 |         self.wcos = torch.nn.Parameter(self.wcos, requires_grad=self.trainable)
426 | 
427 |         # if self.trainable==True:
428 |         #     self.wsin = torch.nn.Parameter(self.wsin)
429 |         #     self.wcos = torch.nn.Parameter(self.wcos)
430 | 
431 |         if verbose==True:
432 |             print("STFT kernels created, time used = {:.4f} seconds".format(time()-start))
433 |         else:
434 |             pass
435 | 
436 |     def forward(self,x):
437 |         x = broadcast_dim(x)
438 |         if self.center:
439 |             if self.pad_mode == 'constant':
440 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
441 |             elif self.pad_mode == 'reflect':
442 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
443 | 
444 |             x = padding(x)
445 | 
446 |         spec_imag = conv1d(x, self.wsin, stride=self.stride)
447 |         spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d
448 | 
449 |         if self.output_format=='Magnitude':
450 |             spec = spec_real.pow(2) + spec_imag.pow(2)
451 |             if self.trainable==True:
452 |                 return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0
453 |             else:
454 |                 return torch.sqrt(spec)
455 |         elif self.output_format=='Complex':
456 |             return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part
457 | 
458 |         elif self.output_format=='Phase':
459 |             return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 helps remove -0.0 elements, which lead to errors in calculating phase
460 | 
461 |         # This part is for implementing the librosa.core.magphase
462 |         # But it seems it is not useful
463 |         # phase_real = torch.cos(torch.atan2(spec_imag,spec_real))
464 |         # phase_imag = torch.sin(torch.atan2(spec_imag,spec_real))
465 |         # return torch.stack((phase_real,phase_imag), -1)
466 | 
467 |     def manual_forward(self,x):
468 |         x = broadcast_dim(x)
469 |         if self.center:
470 |             if self.pad_mode == 'constant':
471 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
472 |             elif self.pad_mode == 'reflect':
473 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
474 | 
475 |             x = padding(x)
476 | 
477 |         imag = conv1d(x, self.wsin, stride=self.stride).pow(2)
478 |         real = conv1d(x, self.wcos, stride=self.stride).pow(2) # Doing STFT by using conv1d
479 |         return real, imag
480 | 
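# ----------------------------------------------------------------------------
# A minimal usage sketch for the STFT layer above (an illustration added for
# clarity, not part of the library; it assumes a CUDA device is available,
# since the layer's default device is 'cuda:0'):
#
#   >>> x = torch.randn(4, 44100, device='cuda:0')   # (num_audio, len_audio)
#   >>> stft_layer = STFT(n_fft=2048, hop_length=512, output_format='Complex')
#   >>> spec = stft_layer(x)                         # (4, 1025, time_steps, 2)
#   >>> mag = torch.sqrt(spec[..., 0].pow(2) + spec[..., 1].pow(2))
#
# With freq_bins=None, the layer defaults to n_fft//2+1 = 1025 frequency bins.
# ----------------------------------------------------------------------------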
481 | class DFT(torch.nn.Module):
482 |     """
483 |     The inverse function only works for 1 single frame, i.e. input shape = (batch, n_fft, 1)
484 |     """
485 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50,fmax=6000, sr=22050):
486 |         super(DFT, self).__init__()
487 |         self.stride = hop_length
488 |         self.center = center
489 |         self.pad_mode = pad_mode
490 |         self.n_fft = n_fft
491 | 
492 |         # Create filter windows for stft
493 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=n_fft, window=window, freq_scale=freq_scale, fmin=fmin,fmax=fmax, sr=sr) # create_fourier_kernels returns 4 values; the unused binslist is discarded
494 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
495 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
496 | 
497 |     def forward(self,x):
498 |         x = broadcast_dim(x)
499 |         if self.center:
500 |             if self.pad_mode == 'constant':
501 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
502 |             elif self.pad_mode == 'reflect':
503 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
504 | 
505 |             x = padding(x)
506 | 
507 |         imag = conv1d(x, self.wsin, stride=self.stride)
508 |         real = conv1d(x, self.wcos, stride=self.stride)
509 |         return (real, -imag)
510 | 
511 |     def inverse(self,x_real,x_imag):
512 |         x_real = broadcast_dim(x_real)
513 |         x_imag = broadcast_dim(x_imag)
514 | 
515 |         x_real.transpose_(1,2) # Prepare the right shape to do inverse
516 |         x_imag.transpose_(1,2) # Prepare the right shape to do inverse
517 | 
518 |         # if self.center:
519 |         #     if self.pad_mode == 'constant':
520 |         #         padding = nn.ConstantPad1d(self.n_fft//2, 0)
521 |         #     elif self.pad_mode == 'reflect':
522 |         #         padding = nn.ReflectionPad1d(self.n_fft//2)
523 | 
524 |         #     x_real = padding(x_real)
525 |         #     x_imag = padding(x_imag)
526 | 
527 |         # Watch out for the positive and negative signs
528 |         # ifft = e^(+2\pi*j)*X
529 | 
530 |         # ifft(X_real) = (a1, a2)
531 | 
532 |         # ifft(X_imag)*1j = (b1, b2)*1j
533 |         #                 = (-b2, b1)
534 | 
535 |         a1 = conv1d(x_real, self.wcos, stride=self.stride)
536 |         a2 = conv1d(x_real, self.wsin, stride=self.stride)
537 |         b1 = conv1d(x_imag, self.wcos, stride=self.stride)
538 |         b2 = conv1d(x_imag, self.wsin, stride=self.stride)
539 | 
540 |         imag = a2+b1
541 |         real = a1-b2
542 |         return (real/self.n_fft, imag/self.n_fft)
543 | 
544 | class iSTFT_complex_2d(torch.nn.Module):
545 |     def __init__(self, n_fft=2048, freq_bins=None, hop_length=512, window='hann', freq_scale='no', center=True, pad_mode='reflect', fmin=50,fmax=6000, sr=22050):
546 |         super(iSTFT_complex_2d, self).__init__()
547 |         self.stride = hop_length
548 |         self.center = center
549 |         self.pad_mode = pad_mode
550 |         self.n_fft = n_fft
551 | 
552 |         # Create filter windows for stft
553 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=n_fft, window=window, freq_scale=freq_scale, fmin=fmin,fmax=fmax, sr=sr) # unpack all 4 return values
554 |         self.wsin = torch.tensor(wsin, dtype=torch.float)
555 |         self.wcos = torch.tensor(wcos, dtype=torch.float)
556 | 
557 |         self.wsin = self.wsin[:,:,:,None] # adjust the filter shape to fit into 2d Conv
558 |         self.wcos = self.wcos[:,:,:,None]
559 | 
560 |     def forward(self,x_real,x_imag):
561 |         x_real = broadcast_dim_conv2d(x_real)
562 |         x_imag = broadcast_dim_conv2d(x_imag) # taking conjugate
563 | 
564 | 
565 |         # if self.center:
566 |         #     if self.pad_mode == 'constant':
567 |         #         padding = nn.ConstantPad1d(self.n_fft//2, 0)
568 |         #     elif self.pad_mode == 'reflect':
569 |         #         padding = nn.ReflectionPad1d(self.n_fft//2)
570 | 
571 |         #     x_real = padding(x_real)
572 |         #     x_imag = padding(x_imag)
573 | 
574 |         # Watch out for the positive and negative signs
575 |         # ifft = e^(+2\pi*j)*X
576 | 
577 |         # ifft(X_real) = (a1, a2)
578 | 
579 |         # ifft(X_imag)*1j = (b1, b2)*1j
580 |         #                 = (-b2, b1)
581 | 
582 |         a1 = conv2d(x_real, self.wcos, stride=(1,1))
583 |         a2 = conv2d(x_real, self.wsin, stride=(1,1))
584 |         b1 = conv2d(x_imag, self.wcos, stride=(1,1))
585 |         b2 = conv2d(x_imag, self.wsin, stride=(1,1))
586 | 
587 |         imag = a2+b1
588 |         real = a1-b2
589 |         return (real/self.n_fft, imag/self.n_fft)
590 | class MelSpectrogram(torch.nn.Module):
591 |     """This function is to calculate the Melspectrogram of the input signal. Input signal should be in either of the following shapes. 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred automatically if the input follows these 3 shapes. Most of the arguments follow the convention from librosa. This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
592 | 
593 |     Parameters
594 |     ----------
595 |     sr : int
596 |         The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
597 | 
598 |     n_fft : int
599 |         The window size for the STFT. Default value is 2048
600 | 
601 |     n_mels : int
602 |         The number of Mel filter banks. The filter banks map the n_fft frequency bins to mel bins. Default value is 128
603 | 
604 |     hop_length : int
605 |         The hop (or stride) size. Default value is 512.
606 | 
607 |     window : str
608 |         The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to scipy documentation for possible windowing functions. The default value is 'hann'
609 | 
610 |     center : bool
611 |         Putting the STFT kernel at the center of the time-step or not. If ``False``, the time index is the beginning of the STFT kernel, if ``True``, the time index is the center of the STFT kernel. Default value is ``True``.
612 | 
613 |     pad_mode : str
614 |         The padding method. Default value is 'reflect'.
615 | 
616 |     htk : bool
617 |         When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the Mel scale is logarithmic. The default value is ``False``
618 | 
619 |     fmin : int
620 |         The starting frequency for the lowest Mel filter bank
621 | 
622 |     fmax : int
623 |         The ending frequency for the highest Mel filter bank
624 | 
625 |     trainable_mel : bool
626 |         Determine if the Mel filter banks are trainable or not. If ``True``, the gradients for Mel filter banks will also be calculated and the Mel filter banks will be updated during model training. Default value is ``False``
627 | 
628 |     trainable_STFT : bool
629 |         Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT kernels will also be calculated and the STFT kernels will be updated during model training. Default value is ``False``
630 | 
631 |     verbose : bool
632 |         If ``True``, it shows layer information. If ``False``, it suppresses all prints
633 | 
634 |     device : str
635 |         Choose which device to initialize this layer. Default value is 'cuda:0'
636 | 
637 |     Returns
638 |     -------
639 |     spectrogram : torch.tensor
640 |         It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
641 | 
642 |     Examples
643 |     --------
644 |     >>> spec_layer = Spectrogram.MelSpectrogram()
645 |     >>> specs = spec_layer(x)
646 |     """
647 | 
648 |     def __init__(self, sr=22050, n_fft=2048, n_mels=128, hop_length=512, window='hann', center=True, pad_mode='reflect', power=2.0, htk=False, fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False, verbose=True, device='cuda:0'):
649 |         super(MelSpectrogram, self).__init__()
650 |         self.stride = hop_length
651 |         self.center = center
652 |         self.pad_mode = pad_mode
653 |         self.n_fft = n_fft
654 |         self.device = device
655 |         self.power = power
656 | 
657 |         # Create filter windows for stft
658 |         start = time()
659 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no', sr=sr)
660 |         self.wsin = torch.tensor(wsin, dtype=torch.float, device=self.device)
661 |         self.wcos = torch.tensor(wcos, dtype=torch.float, device=self.device)
662 | 
663 | 
664 |         # Creating kernel for mel spectrogram
665 |         start = time()
666 |         mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
667 |         self.mel_basis = torch.tensor(mel_basis, device=self.device)
668 | 
669 |         if verbose==True:
670 |             print("STFT filter created, time used = {:.4f} seconds".format(time()-start))
671 |             print("Mel filter created, time used = {:.4f} seconds".format(time()-start))
672 |         else:
673 |             pass
674 |         # Making everything nn.Parameter, so that this model can support nn.DataParallel
675 |         self.mel_basis = torch.nn.Parameter(self.mel_basis, requires_grad=trainable_mel)
676 |         self.wsin = torch.nn.Parameter(self.wsin, requires_grad=trainable_STFT)
677 |         self.wcos = torch.nn.Parameter(self.wcos, requires_grad=trainable_STFT)
678 | 
679 |         # if trainable_mel==True:
680 |         #     self.mel_basis = torch.nn.Parameter(self.mel_basis)
681 |         # if trainable_STFT==True:
682 |         #     self.wsin = torch.nn.Parameter(self.wsin)
683 |         #     self.wcos = torch.nn.Parameter(self.wcos)
684 | 
685 |     def forward(self,x):
686 |         x = broadcast_dim(x)
687 |         if self.center:
688 |             if self.pad_mode == 'constant':
689 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
690 |             elif self.pad_mode == 'reflect':
691 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
692 | 
693 |             x = padding(x)
694 | 
695 |         spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
696 |             + conv1d(x, self.wcos, stride=self.stride).pow(2))**self.power # Doing STFT by using conv1d
697 | 
698 |         melspec = torch.matmul(self.mel_basis, spec)
699 |         return melspec
700 | 
701 | class Gammatonegram(torch.nn.Module):
702 |     """This function is to calculate the Gammatonegram of the input signal. Input signal should be in either of the following shapes. 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred automatically if the input follows these 3 shapes. This class inherits from ``torch.nn.Module``, therefore, the usage is the same as ``torch.nn.Module``.
703 | 
704 |     Parameters
705 |     ----------
706 |     sr : int
707 |         The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency.
708 | 
709 |     n_fft : int
710 |         The window size for the STFT. Default value is 2048
711 | 
712 |     n_bins : int
713 |         The number of Gammatone filter banks. The filter banks map the n_fft frequency bins to Gammatone bins. Default value is 64
714 | 
715 |     hop_length : int
716 |         The hop (or stride) size. Default value is 512.
717 | 
718 |     window : str
719 |         The windowing function for STFT. It uses ``scipy.signal.get_window``, please refer to scipy documentation for possible windowing functions. The default value is 'hann'
720 | 
721 |     center : bool
722 |         Putting the STFT kernel at the center of the time-step or not. If ``False``, the time index is the beginning of the STFT kernel, if ``True``, the time index is the center of the STFT kernel. Default value is ``True``.
723 | 
724 |     pad_mode : str
725 |         The padding method. Default value is 'reflect'.
726 | 
727 |     htk : bool
728 |         When ``False`` is used, the Mel scale is quasi-logarithmic. When ``True`` is used, the Mel scale is logarithmic. The default value is ``False``
729 | 
730 |     fmin : int
731 |         The starting frequency for the lowest Gammatone filter bank
732 | 
733 |     fmax : int
734 |         The ending frequency for the highest Gammatone filter bank
735 | 
736 |     trainable_bins : bool
737 |         Determine if the Gammatone filter banks are trainable or not. If ``True``, the gradients for the Gammatone filter banks will also be calculated and the filter banks will be updated during model training. Default value is ``False``
738 | 
739 |     trainable_STFT : bool
740 |         Determine if the STFT kernels are trainable or not. If ``True``, the gradients for STFT kernels will also be calculated and the STFT kernels will be updated during model training. Default value is ``False``
741 | 
742 |     verbose : bool
743 |         If ``True``, it shows layer information. If ``False``, it suppresses all prints
744 | 
745 |     device : str
746 |         Choose which device to initialize this layer. Default value is 'cuda:0'
747 | 
748 |     Returns
749 |     -------
750 |     spectrogram : torch.tensor
751 |         It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins,time_steps)``.
752 | 
753 |     Examples
754 |     --------
755 |     >>> spec_layer = Spectrogram.Gammatonegram()
756 |     >>> specs = spec_layer(x)
757 |     """
758 | 
759 |     def __init__(self, sr=44100, n_fft=2048, n_bins=64, hop_length=512, window='hann', center=True, pad_mode='reflect', power=2.0, htk=False, fmin=20.0, fmax=None, norm=1, trainable_bins=False, trainable_STFT=False, verbose=True, device='cuda:0'):
760 |         super(Gammatonegram, self).__init__()
761 |         self.stride = hop_length
762 |         self.center = center
763 |         self.pad_mode = pad_mode
764 |         self.n_fft = n_fft
765 |         self.device = device
766 |         self.power = power
767 | 
768 |         # Create filter windows for stft
769 |         start = time()
770 |         wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no', sr=sr)
771 |         self.wsin = torch.tensor(wsin, dtype=torch.float, device=self.device)
772 |         self.wcos = torch.tensor(wcos, dtype=torch.float, device=self.device)
773 | 
774 | 
775 |         # Creating kernel for Gammatone spectrogram
776 |         start = time()
777 |         gammatone_basis = gammatone(sr, n_fft, n_bins, fmin, fmax)
778 |         self.gammatone_basis = torch.tensor(gammatone_basis, device=self.device)
779 | 
780 |         if verbose==True:
781 |             print("STFT filter created, time used = {:.4f} seconds".format(time()-start))
782 |             print("Gammatone filter created, time used = {:.4f} seconds".format(time()-start))
783 |         else:
784 |             pass
785 |         # Making everything nn.Parameter, so that this model can support nn.DataParallel
786 |         self.gammatone_basis = torch.nn.Parameter(self.gammatone_basis, requires_grad=trainable_bins)
787 |         self.wsin = torch.nn.Parameter(self.wsin, requires_grad=trainable_STFT)
788 |         self.wcos = torch.nn.Parameter(self.wcos, requires_grad=trainable_STFT)
789 | 
790 |         # if trainable_mel==True:
791 |         #     self.mel_basis = torch.nn.Parameter(self.mel_basis)
792 |         # if trainable_STFT==True:
793 |         #     self.wsin = torch.nn.Parameter(self.wsin)
794 |         #     self.wcos = torch.nn.Parameter(self.wcos)
795 | 
796 |     def forward(self,x):
797 |         x = broadcast_dim(x)
798 |         if self.center:
799 |             if self.pad_mode == 'constant':
800 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
801 |             elif self.pad_mode == 'reflect':
802 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
803 | 
804 |             x = padding(x)
805 | 
806 |         spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
807 |             + conv1d(x, self.wcos, stride=self.stride).pow(2))**self.power # Doing STFT by using conv1d
808 | 
809 |         gammatonespec = torch.matmul(self.gammatone_basis, spec)
810 |         return gammatonespec
811 | 
812 | class MelSpectrogramv2(torch.nn.Module):
813 |     """This is an experimental feature that uses torch.stft when trainable kernels are not needed. Somewhat surprisingly, it is slower."""
814 |     def __init__(self, sr=22050, n_fft=2048, n_mels=128, hop_length=512, window='hann', center=True, pad_mode='reflect', htk=False, fmin=0.0, fmax=None, norm=1, trainable_mel=False, trainable_STFT=False, device='cuda:0'):
815 |         super(MelSpectrogramv2, self).__init__()
816 |         self.stride = hop_length
817 |         self.center = center
818 |         self.pad_mode = pad_mode
819 |         self.n_fft = n_fft
820 |         self.trainable_STFT=trainable_STFT
821 |         self.device = device
822 | 
823 |         # Create filter windows for stft
824 |         if self.trainable_STFT==True:
825 |             start = time()
826 |             wsin, wcos, self.bins2freq, _ = create_fourier_kernels(n_fft, freq_bins=None, window=window, freq_scale='no', sr=sr)
827 |             self.wsin = torch.tensor(wsin, dtype=torch.float, device=self.device)
828 |             self.wcos = torch.tensor(wcos, dtype=torch.float, device=self.device)
829 |             self.wsin = torch.nn.Parameter(self.wsin)
830 |             self.wcos = torch.nn.Parameter(self.wcos)
831 |             print("STFT filter created, time used = {:.4f} seconds".format(time()-start))
832 |         else:
833 |             window = get_window(window,int(n_fft), fftbins=True).astype(np.float32)
834 |             self.window = torch.tensor(window, device=self.device)
835 |         # Creating kernel for mel spectrogram
836 |         start = time()
837 |         mel_basis = mel(sr, n_fft, n_mels, fmin, fmax, htk=htk, norm=norm)
838 |         self.mel_basis = torch.tensor(mel_basis, device=self.device)
839 |         print("Mel filter created, time used = {:.4f} seconds".format(time()-start))
840 | 
841 |         if trainable_mel==True:
842 |             self.mel_basis = torch.nn.Parameter(self.mel_basis)
843 | 
844 | 
845 | 
846 |     def forward(self,x):
847 |         if self.center:
848 |             if self.pad_mode == 'constant':
849 |                 padding = nn.ConstantPad1d(self.n_fft//2, 0)
850 |             elif self.pad_mode == 'reflect':
851 |                 padding = nn.ReflectionPad1d(self.n_fft//2)
852 | 
853 | 
854 |         if self.trainable_STFT==False:
855 |             x = padding(x)
856 |             spec_complex = torch.stft(x, self.n_fft, self.stride, window=self.window)
857 |             spec = spec_complex[:,:,:,0].pow(2) + spec_complex[:,:,:,1].pow(2)
858 |         else:
859 |             x = broadcast_dim(x)
860 |             x = padding(x)
861 |             spec = conv1d(x, self.wsin, stride=self.stride).pow(2) \
862 |                 + conv1d(x, self.wcos, stride=self.stride).pow(2) # Doing STFT by using conv1d
863 | 
864 |         melspec = torch.matmul(self.mel_basis, spec)
865 |         return melspec
866 | 
867 | class CQT1992(torch.nn.Module):
868 |     def __init__(self, sr=22050, hop_length=512, fmin=220, fmax=None, n_bins=84, bins_per_octave=12, norm=1, window='hann', center=True, pad_mode='reflect', device="cuda:0"):
869 |         super(CQT1992, self).__init__()
870 |         # norm arg is not functioning
871 | 
872 |         self.hop_length = hop_length
self.center = center 874 | self.pad_mode = pad_mode 875 | self.norm = norm 876 | self.device = device 877 | 878 | # creating kernels for CQT 879 | Q = 1/(2**(1/bins_per_octave)-1) 880 | 881 | print("Creating CQT kernels ...", end='\r') 882 | start = time() 883 | self.cqt_kernels, self.kernal_width, self.lenghts = create_cqt_kernels(Q, sr, fmin, n_bins, bins_per_octave, norm, window, fmax) 884 | self.lenghts = self.lenghts.to(device) 885 | self.cqt_kernels = fft(self.cqt_kernels)[:,:self.kernal_width//2+1] 886 | self.cqt_kernels_real = torch.tensor(self.cqt_kernels.real.astype(np.float32), device=device) 887 | self.cqt_kernels_imag = torch.tensor(self.cqt_kernels.imag.astype(np.float32), device=device) 888 | print("CQT kernels created, time used = {:.4f} seconds".format(time()-start)) 889 | 890 | # creating kernels for stft 891 | # self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernal_width # Trying to normalize as librosa 892 | # self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernal_width 893 | print("Creating STFT kernels ...", end='\r') 894 | start = time() 895 | wsin, wcos, self.bins2freq, _ = create_fourier_kernels(self.kernal_width, window='ones', freq_scale='no') 896 | self.wsin = torch.tensor(wsin, device=device) 897 | self.wcos = torch.tensor(wcos, device=device) 898 | print("STFT kernels created, time used = {:.4f} seconds".format(time()-start)) 899 | 900 | def forward(self,x): 901 | x = broadcast_dim(x) 902 | if self.center: 903 | if self.pad_mode == 'constant': 904 | padding = nn.ConstantPad1d(self.kernal_width//2, 0) 905 | elif self.pad_mode == 'reflect': 906 | padding = nn.ReflectionPad1d(self.kernal_width//2) 907 | 908 | x = padding(x) 909 | 910 | # STFT 911 | fourier_real = conv1d(x, self.wcos, stride=self.hop_length) 912 | fourier_imag = conv1d(x, self.wsin, stride=self.hop_length) 913 | 914 | # CQT 915 | CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag), 916 | (fourier_real, fourier_imag)) 917 | 918 | # Getting the CQT amplitude 919 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 920 | 921 | if self.norm: 922 | return CQT/self.kernal_width*torch.sqrt(self.lenghts.view(-1,1)) 923 | else: 924 | return CQT*torch.sqrt(self.lenghts.view(-1,1)) 925 | 926 | 927 | class CQT2010(torch.nn.Module): 928 | """ 929 | This algorithm uses the resampling method proposed in [1]. Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the input audio by a factor of 2 and convolving it with the small CQT kernel. Every time the input audio is downsampled, the CQT relative to the downsampled input is equivalent to the next lower octave. 930 | The kernel creation process is still the same as in the 1992 algorithm; therefore, we can reuse the code from the 1992 algorithm [2]. 931 | [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010). 932 | [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992). 933 | 934 | The early downsampling factor downsamples the input audio to reduce the CQT kernel size.
The results with and without early downsampling are more or less the same, except in the very low frequency region (below about 40 Hz). 935 | 936 | """ 937 | def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect', earlydownsample=True, verbose=True, device='cuda:0'): 938 | super(CQT2010, self).__init__() 939 | 940 | self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft 941 | # basis_norm is for normalizing the basis 942 | self.hop_length = hop_length 943 | self.pad_mode = pad_mode 944 | self.n_bins = n_bins 945 | self.earlydownsample = earlydownsample # Early downsampling will be activated later if possible 946 | self.device = device 947 | 948 | Q = 1/(2**(1/bins_per_octave)-1) # It will be used to calculate filter_cutoff and to create the CQT kernels 949 | 950 | # Creating the lowpass filter and making it a torch tensor 951 | print("Creating low pass filter ...", end='\r') 952 | start = time() 953 | self.lowpass_filter = torch.tensor( 954 | create_lowpass_filter( 955 | band_center = 0.5, 956 | kernelLength=256, 957 | transitionBandwidth=0.001), device=self.device) 958 | self.lowpass_filter = self.lowpass_filter[None,None,:] # Broadcast the tensor to the shape that fits conv1d 959 | print("Low pass filter created, time used = {:.4f} seconds".format(time()-start)) 960 | 961 | # Calculate the number of filters required for the kernel 962 | # n_octaves determines how many resampling steps the CQT requires 963 | n_filters = min(bins_per_octave, n_bins) 964 | self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave)) 965 | # print("n_octaves = ", self.n_octaves) 966 | # Calculate the lowest frequency bin for the top octave kernel 967 | self.fmin_t = fmin*2**(self.n_octaves-1) 968 | remainder = n_bins % bins_per_octave 969 | # print("remainder = ", remainder) 970 | if remainder==0: 971 | fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave) # Calculate the top bin frequency 972 | else: 973 | fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave) # Calculate the top bin frequency 974 | self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the minimum bin of the top octave 975 | if fmax_t > sr/2: 976 | raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(fmax_t)) 977 | 978 | 979 | if self.earlydownsample == True: # Do early downsampling if this argument is True 980 | print("Creating early downsampling filter ...", end='\r') 981 | start = time() 982 | sr, self.hop_length, self.downsample_factor, self.early_downsample_filter, self.earlydownsample = self.get_early_downsample_params(sr, hop_length, fmax_t, Q, self.n_octaves, verbose) 983 | print("Early downsampling filter created, time used = {:.4f} seconds".format(time()-start)) 984 | else: 985 | self.downsample_factor=1.
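# Note: only the top-octave kernels are prepared below; forward() obtains each lower octave by halving the sample rate of the input with the lowpass filter above, so the same small kernels measure one octave lower after each halving.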
986 | 987 | # Preparing CQT kernels 988 | print("Creating CQT kernels ...", end='\r') 989 | start = time() 990 | # print("Q = {}, fmin_t = {}, n_filters = {}".format(Q, self.fmin_t, n_filters)) 991 | basis, self.n_fft, _ = create_cqt_kernels(Q, sr, self.fmin_t, n_filters, bins_per_octave, norm=basis_norm, topbin_check=False) 992 | 993 | # This is for the normalization in the end 994 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave)) 995 | lenghts = np.ceil(Q * sr / freqs) 996 | self.lenghts = torch.tensor(lenghts, device=self.device).float() 997 | 998 | 999 | self.basis=basis 1000 | fft_basis = fft(basis)[:,:self.n_fft//2+1] # Convert the CQT kernel from the time domain to the frequency domain 1001 | 1002 | self.cqt_kernels_real = torch.tensor(fft_basis.real.astype(np.float32), device=self.device) # These CQT kernels are already in the frequency domain 1003 | self.cqt_kernels_imag = torch.tensor(fft_basis.imag.astype(np.float32), device=self.device) 1004 | print("CQT kernels created, time used = {:.4f} seconds".format(time()-start)) 1005 | # print("Getting cqt kernel done, n_fft = ",self.n_fft) 1006 | # Preparing kernels for the Short-Time Fourier Transform (STFT) 1007 | # We set the frequency range in the CQT filter instead of here. 1008 | print("Creating STFT kernels ...", end='\r') 1009 | start = time() 1010 | wsin, wcos, self.bins2freq, _ = create_fourier_kernels(self.n_fft, window='ones', freq_scale='no') 1011 | self.wsin = torch.tensor(wsin, device=self.device) 1012 | self.wcos = torch.tensor(wcos, device=self.device) 1013 | print("STFT kernels created, time used = {:.4f} seconds".format(time()-start)) 1014 | 1015 | 1016 | 1017 | # If center==True, the STFT window will be put in the middle, and paddings at the beginning and ending are required. 1018 | if self.pad_mode == 'constant': 1019 | self.padding = nn.ConstantPad1d(self.n_fft//2, 0) 1020 | elif self.pad_mode == 'reflect': 1021 | self.padding = nn.ReflectionPad1d(self.n_fft//2) 1022 | 1023 | 1024 | def get_cqt(self,x,hop_length, padding): 1025 | """Multiplying the STFT result with the cqt_kernels; check out the 1992 CQT paper [1] for how to multiply the STFT result with the CQT kernel. 1026 | [1] Brown, Judith C.C. and Miller Puckette.
“An efficient algorithm for the calculation of a constant Q transform.” (1992).""" 1027 | 1028 | # STFT, converting the audio input from the time domain to the frequency domain 1029 | try: 1030 | x = padding(x) # When center == True, we need padding at the beginning and ending 1031 | except: 1032 | print("padding with reflection mode might not be the best choice, try using constant padding") 1033 | fourier_real = conv1d(x, self.wcos, stride=hop_length) 1034 | fourier_imag = conv1d(x, self.wsin, stride=hop_length) 1035 | 1036 | # Multiplying the input with the CQT kernel in the frequency domain 1037 | CQT_real, CQT_imag = complex_mul((self.cqt_kernels_real, self.cqt_kernels_imag), 1038 | (fourier_real, fourier_imag)) 1039 | 1040 | # Getting the CQT amplitude 1041 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 1042 | 1043 | return CQT 1044 | 1045 | 1046 | def get_early_downsample_params(self, sr, hop_length, fmax_t, Q, n_octaves, verbose): 1047 | window_bandwidth = 1.5 # for the hann window 1048 | filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q) 1049 | sr, hop_length, downsample_factor=self.early_downsample(sr, hop_length, n_octaves, sr//2, filter_cutoff) 1050 | if downsample_factor != 1: 1051 | if verbose==True: print("Early downsampling is possible, factor = ", downsample_factor) 1052 | earlydownsample=True 1053 | # print("new sr = ", sr) 1054 | # print("new hop_length = ", hop_length) 1055 | early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor, kernelLength=256, transitionBandwidth=0.03) 1056 | early_downsample_filter = torch.tensor(early_downsample_filter, device=self.device)[None, None, :] 1057 | else: 1058 | if verbose==True: print("No early downsampling is required, downsample_factor = ", downsample_factor) 1059 | early_downsample_filter = None 1060 | earlydownsample=False 1061 | return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample 1062 | 1063 | # The following two downsampling count functions are obtained from librosa's CQT 1064 | # They are used to determine the number of pre-resamplings if the starting and ending frequencies are both in low frequency regions.
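# Worked example with this class's default arguments (sr=22050, hop_length=512, fmin=32.70, n_bins=84, bins_per_octave=12): n_octaves = 7 and the top bin sits near 3951 Hz, so filter_cutoff ≈ 3951*(1 + 0.75/16.82) ≈ 4127 Hz. Then downsample_count1 = max(0, ceil(log2(0.85*11025/4127)) - 1 - 1) = 0 and downsample_count2 = max(0, nextpow2(512) - 7 + 1) = 3 (since nextpow2(512) = 9), so min(0, 3) = 0: no early downsampling takes place for the default configuration.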
1065 | def early_downsample_count(self, nyquist, filter_cutoff, hop_length, n_octaves): 1066 | '''Compute the number of early downsampling operations''' 1067 | 1068 | downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist / 1069 | filter_cutoff)) - 1) - 1) 1070 | # print("downsample_count1 = ", downsample_count1) 1071 | num_twos = nextpow2(hop_length) 1072 | downsample_count2 = max(0, num_twos - n_octaves + 1) 1073 | # print("downsample_count2 = ",downsample_count2) 1074 | 1075 | return min(downsample_count1, downsample_count2) 1076 | 1077 | def early_downsample(self, sr, hop_length, n_octaves, 1078 | nyquist, filter_cutoff): 1079 | '''Return the new sampling rate and hop length after early downsampling''' 1080 | downsample_count = self.early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves) 1081 | # print("downsample_count = ", downsample_count) 1082 | downsample_factor = 2**(downsample_count) 1083 | 1084 | hop_length //= downsample_factor # Getting the new hop_length 1085 | new_sr = sr / float(downsample_factor) # Getting the new sampling rate 1086 | 1087 | sr = new_sr 1088 | 1089 | return sr, hop_length, downsample_factor 1090 | 1091 | 1092 | def forward(self,x): 1093 | x = broadcast_dim(x) 1094 | if self.earlydownsample==True: 1095 | x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor) 1096 | hop = self.hop_length 1097 | CQT = self.get_cqt(x, hop, self.padding) # Getting the top octave CQT 1098 | 1099 | x_down = x # Preparing a new variable for downsampling 1100 | for i in range(self.n_octaves-1): 1101 | hop = hop//2 1102 | x_down = downsampling_by_2(x_down, self.lowpass_filter) 1103 | CQT1 = self.get_cqt(x_down, hop, self.padding) 1104 | CQT = torch.cat((CQT1, CQT),1) # Prepending the lower-octave CQT 1105 | CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins 1106 | CQT = CQT*2**(self.n_octaves-1) # Normalizing signals with respect to n_fft 1107 | 1108 | CQT = CQT*self.downsample_factor/2**(self.n_octaves-1) # Normalizing the output with the downsampling factor; 2**(self.n_octaves-1) makes the magnitude match the 1992 version 1109 | 1110 | if self.norm: 1111 | return CQT/self.n_fft*torch.sqrt(self.lenghts.view(-1,1)) 1112 | else: 1113 | return CQT*torch.sqrt(self.lenghts.view(-1,1)) 1114 | 1115 | class CQT1992v2(torch.nn.Module): 1116 | """This function calculates the CQT of the input signal. The input signal should be in one of the following shapes: 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred automatically if the input follows one of these 3 shapes. Most of the arguments follow the convention from librosa. This class inherits from ``torch.nn.Module``; therefore, the usage is the same as for ``torch.nn.Module``. 1117 | 1118 | This algorithm uses the method proposed in [1]. I slightly modified it so that it runs faster than the original 1992 algorithm; that is why I call it version 2. 1119 | [1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992). 1120 | 1121 | Parameters 1122 | ---------- 1123 | sr : int 1124 | The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency. 1125 | 1126 | hop_length : int 1127 | The hop (or stride) size. Default value is 512. 1128 | 1129 | fmin : float 1130 | The frequency of the lowest CQT bin. Default is 32.70 Hz, which corresponds to the note C1.
1131 | 1132 | fmax : float 1133 | The frequency of the highest CQT bin. Default is ``None``, in which case the highest CQT bin is inferred from ``n_bins`` and ``bins_per_octave``. If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored, and ``n_bins`` will be calculated automatically. Default is ``None`` 1134 | 1135 | n_bins : int 1136 | The total number of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``. 1137 | 1138 | bins_per_octave : int 1139 | Number of bins per octave. Default is 12. 1140 | 1141 | norm : int 1142 | Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization. Default is ``1``, which is the same as the normalization used in librosa. 1143 | 1144 | window : str 1145 | The windowing function for CQT. It uses ``scipy.signal.get_window``; please refer to the scipy documentation for possible windowing functions. The default value is 'hann' 1146 | 1147 | center : bool 1148 | Putting the CQT kernel at the center of the time-step or not. If ``False``, the time index is the beginning of the CQT kernel; if ``True``, the time index is the center of the CQT kernel. Default value is ``True``. 1149 | 1150 | pad_mode : str 1151 | The padding method. Default value is 'reflect'. 1152 | 1153 | trainable : bool 1154 | Determine if the CQT kernels are trainable or not. If ``True``, the gradients for the CQT kernels will also be calculated, and the CQT kernels will be updated during model training. Default value is ``False`` 1155 | 1156 | output_format : str 1157 | Determine the return type. ``Magnitude`` will return the magnitude of the CQT result, shape = ``(num_samples, freq_bins, time_steps)``; ``Complex`` will return the CQT result as complex numbers, shape = ``(num_samples, freq_bins, time_steps, 2)``; ``Phase`` will return the phase of the CQT result, shape = ``(num_samples, freq_bins, time_steps, 2)``. The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'. 1158 | 1159 | verbose : bool 1160 | If ``True``, it shows layer information. If ``False``, it suppresses all prints 1161 | 1162 | device : str 1163 | Choose which device to initialize this layer. Default value is 'cuda:0' 1164 | 1165 | Returns 1166 | ------- 1167 | spectrogram : torch.tensor 1168 | It returns a tensor of spectrograms.
shape = ``(num_samples, freq_bins, time_steps)`` if 'Magnitude' is used as the ``output_format``; shape = ``(num_samples, freq_bins, time_steps, 2)`` if 'Complex' or 'Phase' is used as the ``output_format`` 1169 | 1170 | Examples 1171 | -------- 1172 | >>> spec_layer = Spectrogram.CQT1992v2() 1173 | >>> specs = spec_layer(x) 1174 | """ 1175 | 1176 | def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12, norm=1, window='hann', center=True, pad_mode='reflect', trainable=False, output_format='Magnitude', verbose=True, device='cuda:0'): 1177 | super(CQT1992v2, self).__init__() 1178 | # norm arg is not functioning 1179 | self.trainable = trainable 1180 | self.hop_length = hop_length 1181 | self.center = center 1182 | self.pad_mode = pad_mode 1183 | self.output_format = output_format 1184 | self.device = device 1185 | 1186 | 1187 | # creating kernels for CQT 1188 | Q = 1/(2**(1/bins_per_octave)-1) 1189 | 1190 | if verbose==True: 1191 | print("Creating CQT kernels ...", end='\r') 1192 | 1193 | start = time() 1194 | self.cqt_kernels, self.kernal_width, self.lenghts = create_cqt_kernels(Q, sr, fmin, n_bins, bins_per_octave, norm, window, fmax) 1195 | self.lenghts = self.lenghts.to(device) 1196 | self.cqt_kernels_real = torch.tensor(self.cqt_kernels.real, device=self.device).unsqueeze(1) 1197 | self.cqt_kernels_imag = torch.tensor(self.cqt_kernels.imag, device=self.device).unsqueeze(1) 1198 | 1199 | # Making everything a Parameter to support nn.DataParallel 1200 | self.cqt_kernels_real = torch.nn.Parameter(self.cqt_kernels_real, requires_grad=trainable) 1201 | self.cqt_kernels_imag = torch.nn.Parameter(self.cqt_kernels_imag, requires_grad=trainable) 1202 | self.lenghts = torch.nn.Parameter(self.lenghts, requires_grad=False) 1203 | # if trainable==True: 1204 | # self.cqt_kernels_real = torch.nn.Parameter(self.cqt_kernels_real) 1205 | # self.cqt_kernels_imag = torch.nn.Parameter(self.cqt_kernels_imag) 1206 | 1207 | if verbose==True: 1208 | print("CQT kernels created, time used = {:.4f} seconds".format(time()-start)) 1209 | 1210 | # creating kernels for stft 1211 | # self.cqt_kernels_real*=lenghts.unsqueeze(1)/self.kernal_width # Trying to normalize as librosa 1212 | # self.cqt_kernels_imag*=lenghts.unsqueeze(1)/self.kernal_width 1213 | 1214 | def forward(self,x): 1215 | x = broadcast_dim(x) 1216 | if self.center: 1217 | if self.pad_mode == 'constant': 1218 | padding = nn.ConstantPad1d(self.kernal_width//2, 0) 1219 | elif self.pad_mode == 'reflect': 1220 | padding = nn.ReflectionPad1d(self.kernal_width//2) 1221 | 1222 | x = padding(x) 1223 | 1224 | # CQT 1225 | CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length)*torch.sqrt(self.lenghts.view(-1,1)) 1226 | CQT_imag = -conv1d(x, self.cqt_kernels_imag, stride=self.hop_length)*torch.sqrt(self.lenghts.view(-1,1)) 1227 | 1228 | if self.output_format=='Magnitude': 1229 | if self.trainable==False: 1230 | # Getting the CQT amplitude 1231 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 1232 | else: 1233 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)+1e-8) 1234 | return CQT 1235 | 1236 | elif self.output_format=='Complex': 1237 | return torch.stack((CQT_real,CQT_imag),-1) 1238 | 1239 | elif self.output_format=='Phase': 1240 | phase_real = torch.cos(torch.atan2(CQT_imag,CQT_real)) 1241 | phase_imag = torch.sin(torch.atan2(CQT_imag,CQT_real)) 1242 | return torch.stack((phase_real,phase_imag), -1) 1243 | 1244 | def forward_manual(self,x): 1245 | x = broadcast_dim(x) 1246 | if self.center: 1247 | if
self.pad_mode == 'constant': 1248 | padding = nn.ConstantPad1d(self.kernal_width//2, 0) 1249 | elif self.pad_mode == 'reflect': 1250 | padding = nn.ReflectionPad1d(self.kernal_width//2) 1251 | 1252 | x = padding(x) 1253 | 1254 | # CQT 1255 | CQT_real = conv1d(x, self.cqt_kernels_real, stride=self.hop_length) 1256 | CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=self.hop_length) 1257 | 1258 | 1259 | # Getting the CQT amplitude 1260 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 1261 | return CQT*torch.sqrt(self.lenghts.view(-1,1)) 1262 | 1263 | 1264 | class CQT2010v2(torch.nn.Module): 1265 | """This function calculates the CQT of the input signal. The input signal should be in one of the following shapes: 1. ``(len_audio)``, 2. ``(num_audio, len_audio)``, 3. ``(num_audio, 1, len_audio)``. The correct shape will be inferred automatically if the input follows one of these 3 shapes. Most of the arguments follow the convention from librosa. This class inherits from ``torch.nn.Module``; therefore, the usage is the same as for ``torch.nn.Module``. 1266 | 1267 | This algorithm uses the resampling method proposed in [1]. Instead of convolving the STFT results with a gigantic CQT kernel covering the full frequency spectrum, we make a small CQT kernel covering only the top octave. Then we keep downsampling the input audio by a factor of 2 and convolving it with the small CQT kernel. Every time the input audio is downsampled, the CQT relative to the downsampled input is equivalent to the next lower octave. 1268 | The kernel creation process is still the same as in the 1992 algorithm; therefore, we can reuse the code from the 1992 algorithm [2]. 1269 | [1] Schörkhuber, Christian. “CONSTANT-Q TRANSFORM TOOLBOX FOR MUSIC PROCESSING.” (2010). 1270 | [2] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992). 1271 | 1272 | The early downsampling factor downsamples the input audio to reduce the CQT kernel size. The results with and without early downsampling are more or less the same, except in the very low frequency region (below about 40 Hz). 1273 | 1274 | Parameters 1275 | ---------- 1276 | sr : int 1277 | The sampling rate for the input audio. It is used to calculate the correct ``fmin`` and ``fmax``. Setting the correct sampling rate is very important for calculating the correct frequency. 1278 | 1279 | hop_length : int 1280 | The hop (or stride) size. Default value is 512. 1281 | 1282 | fmin : float 1283 | The frequency of the lowest CQT bin. Default is 32.70 Hz, which corresponds to the note C1. 1284 | 1285 | fmax : float 1286 | The frequency of the highest CQT bin. Default is ``None``, in which case the highest CQT bin is inferred from ``n_bins`` and ``bins_per_octave``. If ``fmax`` is not ``None``, then the argument ``n_bins`` will be ignored, and ``n_bins`` will be calculated automatically. Default is ``None`` 1287 | 1288 | n_bins : int 1289 | The total number of CQT bins. Default is 84. Will be ignored if ``fmax`` is not ``None``. 1290 | 1291 | bins_per_octave : int 1292 | Number of bins per octave. Default is 12. 1293 | 1294 | norm : bool 1295 | Whether to normalize the final CQT result by dividing it by ``n_fft``. Default is ``True``. 1296 | 1297 | basis_norm : int 1298 | Normalization for the CQT kernels. ``1`` means L1 normalization, and ``2`` means L2 normalization. Default is ``1``, which is the same as the normalization used in librosa. 1299 | 1300 | window : str 1301 | The windowing function for CQT. It uses ``scipy.signal.get_window``; please refer to the scipy documentation for possible windowing functions.
The default value is 'hann' 1302 | 1303 | pad_mode : str 1304 | The padding method. Default value is 'reflect'. 1305 | 1306 | trainable : bool 1307 | Determine if the CQT kernels are trainable or not. If ``True``, the gradients for the CQT kernels will also be calculated, and the CQT kernels will be updated during model training. Default value is ``False`` 1308 | 1309 | output_format : str 1310 | Determine the return type. 'Magnitude' will return the magnitude of the CQT result, shape = ``(num_samples, freq_bins, time_steps)``; 'Complex' will return the CQT result as complex numbers, shape = ``(num_samples, freq_bins, time_steps, 2)``; 'Phase' will return the phase of the CQT result, shape = ``(num_samples, freq_bins, time_steps, 2)``. The complex number is stored as ``(real, imag)`` in the last axis. Default value is 'Magnitude'. 1311 | 1312 | verbose : bool 1313 | If ``True``, it shows layer information. If ``False``, it suppresses all prints 1314 | 1315 | device : str 1316 | Choose which device to initialize this layer. Default value is 'cuda:0' 1317 | 1318 | Returns 1319 | ------- 1320 | spectrogram : torch.tensor 1321 | It returns a tensor of spectrograms. shape = ``(num_samples, freq_bins, time_steps)`` if 'Magnitude' is used as the ``output_format``; shape = ``(num_samples, freq_bins, time_steps, 2)`` if 'Complex' or 'Phase' is used as the ``output_format`` 1322 | 1323 | Examples 1324 | -------- 1325 | >>> spec_layer = Spectrogram.CQT2010v2() 1326 | >>> specs = spec_layer(x) 1327 | """ 1328 | 1329 | 1330 | def __init__(self, sr=22050, hop_length=512, fmin=32.70, fmax=None, n_bins=84, bins_per_octave=12, norm=True, basis_norm=1, window='hann', pad_mode='reflect', earlydownsample=True, trainable=False, output_format='Magnitude', verbose=True, device='cuda:0'): 1331 | super(CQT2010v2, self).__init__() 1332 | 1333 | self.norm = norm # Now norm is used to normalize the final CQT result by dividing n_fft 1334 | # basis_norm is for normalizing the basis 1335 | self.hop_length = hop_length 1336 | self.pad_mode = pad_mode 1337 | self.n_bins = n_bins 1338 | self.earlydownsample = earlydownsample # Early downsampling will be activated later if possible 1339 | self.trainable = trainable 1340 | self.output_format = output_format 1341 | self.device = device 1342 | 1343 | Q = 1/(2**(1/bins_per_octave)-1) # It will be used to calculate filter_cutoff and to create the CQT kernels 1344 | 1345 | # Creating the lowpass filter and making it a torch tensor 1346 | if verbose==True: 1347 | print("Creating low pass filter ...", end='\r') 1348 | start = time() 1349 | 1350 | 1351 | 1352 | 1353 | 1354 | self.lowpass_filter = torch.tensor( 1355 | create_lowpass_filter( 1356 | band_center = 0.50, 1357 | kernelLength=256, 1358 | transitionBandwidth=0.001), device=self.device) 1359 | self.lowpass_filter = self.lowpass_filter[None,None,:] # Broadcast the tensor to the shape that fits conv1d 1360 | if verbose==True: 1361 | print("Low pass filter created, time used = {:.4f} seconds".format(time()-start)) 1362 | 1363 | # Calculate the number of filters required for the kernel 1364 | # n_octaves determines how many resampling steps the CQT requires 1365 | n_filters = min(bins_per_octave, n_bins) 1366 | self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave)) 1367 | if verbose==True: 1368 | print("num_octave = ", self.n_octaves) 1369 | 1370 | # Calculate the lowest frequency bin for the top octave kernel 1371 |
self.fmin_t = fmin*2**(self.n_octaves-1) 1372 | remainder = n_bins % bins_per_octave 1373 | # print("remainder = ", remainder) 1374 | if remainder==0: 1375 | fmax_t = self.fmin_t*2**((bins_per_octave-1)/bins_per_octave) # Calculate the top bin frequency 1376 | else: 1377 | fmax_t = self.fmin_t*2**((remainder-1)/bins_per_octave) # Calculate the top bin frequency 1378 | self.fmin_t = fmax_t/2**(1-1/bins_per_octave) # Adjusting the minimum bin of the top octave 1379 | if fmax_t > sr/2: 1380 | raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, please reduce the n_bins'.format(fmax_t)) 1381 | 1382 | if self.earlydownsample == True: # Do early downsampling if this argument is True 1383 | if verbose==True: 1384 | print("Creating early downsampling filter ...", end='\r') 1385 | start = time() 1386 | sr, self.hop_length, self.downsample_factor, self.early_downsample_filter, self.earlydownsample = self.get_early_downsample_params(sr, hop_length, fmax_t, Q, self.n_octaves, verbose) 1387 | if verbose==True: 1388 | print("Early downsampling filter created, time used = {:.4f} seconds".format(time()-start)) 1389 | else: 1390 | self.downsample_factor=1. 1391 | 1392 | # Preparing CQT kernels 1393 | if verbose==True: 1394 | print("Creating CQT kernels ...", end='\r') 1395 | start = time() 1396 | basis, self.n_fft, self.lenghts = create_cqt_kernels(Q, sr, self.fmin_t, n_filters, bins_per_octave, norm=basis_norm, topbin_check=False) 1397 | 1398 | # For the normalization in the end 1399 | freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave)) 1400 | lenghts = np.ceil(Q * sr / freqs) 1401 | self.lenghts = torch.tensor(lenghts,device=self.device).float() 1402 | 1403 | self.basis = basis 1404 | self.cqt_kernels_real = torch.tensor(basis.real.astype(np.float32),device=self.device).unsqueeze(1) # Unlike CQT2010, these kernels stay in the time domain and are applied directly with conv1d 1405 | self.cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32),device=self.device).unsqueeze(1) 1406 | 1407 | # Making them nn.Parameter so that the model can support nn.DataParallel 1408 | self.cqt_kernels_real = torch.nn.Parameter(self.cqt_kernels_real, requires_grad=self.trainable) 1409 | self.cqt_kernels_imag = torch.nn.Parameter(self.cqt_kernels_imag, requires_grad=self.trainable) 1410 | self.lenghts = torch.nn.Parameter(self.lenghts, requires_grad=False) 1411 | self.lowpass_filter = torch.nn.Parameter(self.lowpass_filter, requires_grad=False) 1412 | # if trainable==True: 1413 | # self.cqt_kernels_real = torch.nn.Parameter(self.cqt_kernels_real) 1414 | # self.cqt_kernels_imag = torch.nn.Parameter(self.cqt_kernels_imag) 1415 | if verbose==True: 1416 | print("CQT kernels created, time used = {:.4f} seconds".format(time()-start)) 1417 | # print("Getting cqt kernel done, n_fft = ",self.n_fft) 1418 | 1419 | # If center==True, the STFT window will be put in the middle, and paddings at the beginning and ending are required. 1420 | if self.pad_mode == 'constant': 1421 | self.padding = nn.ConstantPad1d(self.n_fft//2, 0) 1422 | elif self.pad_mode == 'reflect': 1423 | self.padding = nn.ReflectionPad1d(self.n_fft//2) 1424 | 1425 | 1426 | def get_cqt(self,x,hop_length, padding): 1427 | """Convolve the input directly with the time-domain CQT kernels; see the 1992 CQT paper [1] for background. 1428 | [1] Brown, Judith C.C. and Miller Puckette.
“An efficient algorithm for the calculation of a constant Q transform.” (1992).""" 1429 | 1430 | # Convolving the audio input with the time-domain CQT kernels 1431 | try: 1432 | x = padding(x) # When center == True, we need padding at the beginning and ending 1433 | except: 1434 | print("padding with reflection mode might not be the best choice, try using constant padding") 1435 | CQT_real = conv1d(x, self.cqt_kernels_real, stride=hop_length) 1436 | CQT_imag = conv1d(x, self.cqt_kernels_imag, stride=hop_length) 1437 | 1438 | # Getting the CQT amplitude 1439 | CQT = torch.sqrt(CQT_real.pow(2)+CQT_imag.pow(2)) 1440 | 1441 | return CQT 1442 | 1443 | def get_cqt_complex(self,x,hop_length, padding): 1444 | """Convolve the input directly with the complex CQT kernels and keep the complex result; see the 1992 CQT paper [1] for background. 1445 | [1] Brown, Judith C.C. and Miller Puckette. “An efficient algorithm for the calculation of a constant Q transform.” (1992).""" 1446 | 1447 | # Convolving the audio input with the time-domain CQT kernels 1448 | try: 1449 | x = padding(x) # When center == True, we need padding at the beginning and ending 1450 | except: 1451 | print("padding with reflection mode might not be the best choice, try using constant padding") 1452 | CQT_real = conv1d(x, self.cqt_kernels_real, stride=hop_length) 1453 | CQT_imag = -conv1d(x, self.cqt_kernels_imag, stride=hop_length) 1454 | 1455 | return torch.stack((CQT_real, CQT_imag),-1) 1456 | 1457 | def get_early_downsample_params(self, sr, hop_length, fmax_t, Q, n_octaves, verbose): 1458 | window_bandwidth = 1.5 # for the hann window 1459 | filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q) 1460 | sr, hop_length, downsample_factor=self.early_downsample(sr, hop_length, n_octaves, sr//2, filter_cutoff) 1461 | if downsample_factor != 1: 1462 | if verbose==True: 1463 | print("Early downsampling is possible, factor = ", downsample_factor) 1464 | earlydownsample=True 1465 | # print("new sr = ", sr) 1466 | # print("new hop_length = ", hop_length) 1467 | early_downsample_filter = create_lowpass_filter(band_center=1/downsample_factor, kernelLength=256, transitionBandwidth=0.03) 1468 | early_downsample_filter = torch.tensor(early_downsample_filter, device=self.device)[None, None, :] 1469 | else: 1470 | if verbose==True: 1471 | print("No early downsampling is required, downsample_factor = ", downsample_factor) 1472 | early_downsample_filter = None 1473 | earlydownsample=False 1474 | return sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample 1475 | 1476 | # The following two downsampling count functions are obtained from librosa's CQT 1477 | # They are used to determine the number of pre-resamplings if the starting and ending frequencies are both in low frequency regions.
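# The two helpers below mirror the CQT2010 versions above; see the worked example there for how the early-downsample count comes out with the default arguments.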
1478 | def early_downsample_count(self, nyquist, filter_cutoff, hop_length, n_octaves): 1479 | '''Compute the number of early downsampling operations''' 1480 | 1481 | downsample_count1 = max(0, int(np.ceil(np.log2(0.85 * nyquist / 1482 | filter_cutoff)) - 1) - 1) 1483 | # print("downsample_count1 = ", downsample_count1) 1484 | num_twos = nextpow2(hop_length) 1485 | downsample_count2 = max(0, num_twos - n_octaves + 1) 1486 | # print("downsample_count2 = ",downsample_count2) 1487 | 1488 | return min(downsample_count1, downsample_count2) 1489 | 1490 | def early_downsample(self, sr, hop_length, n_octaves, 1491 | nyquist, filter_cutoff): 1492 | '''Return the new sampling rate and hop length after early downsampling''' 1493 | downsample_count = self.early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves) 1494 | # print("downsample_count = ", downsample_count) 1495 | downsample_factor = 2**(downsample_count) 1496 | 1497 | hop_length //= downsample_factor # Getting the new hop_length 1498 | new_sr = sr / float(downsample_factor) # Getting the new sampling rate 1499 | 1500 | sr = new_sr 1501 | 1502 | return sr, hop_length, downsample_factor 1503 | 1504 | 1505 | def forward(self,x): 1506 | x = broadcast_dim(x) 1507 | if self.earlydownsample==True: 1508 | x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor) 1509 | hop = self.hop_length 1510 | CQT = self.get_cqt_complex(x, hop, self.padding) # Getting the top octave CQT 1511 | 1512 | x_down = x # Preparing a new variable for downsampling 1513 | for i in range(self.n_octaves-1): 1514 | hop = hop//2 1515 | x_down = downsampling_by_2(x_down, self.lowpass_filter) 1516 | CQT1 = self.get_cqt_complex(x_down, hop, self.padding) 1517 | CQT = torch.cat((CQT1, CQT),1) # Prepending the lower-octave CQT 1518 | CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins 1519 | CQT = CQT*2**(self.n_octaves-1) # Normalizing signals with respect to n_fft 1520 | # print("downsample_factor = ",self.downsample_factor) 1521 | # print(CQT.shape) 1522 | # print(self.lenghts.view(-1,1).shape) 1523 | CQT = CQT*self.downsample_factor/2**(self.n_octaves-1) # Normalizing the output with the downsampling factor; 2**(self.n_octaves-1) makes the magnitude match the 1992 version 1524 | CQT = CQT*torch.sqrt(self.lenghts.view(-1,1,1)) # Normalize again to get the same result as librosa 1525 | 1526 | if self.output_format=='Magnitude': 1527 | if self.trainable==False: 1528 | # Getting the CQT amplitude 1529 | return torch.sqrt(CQT.pow(2).sum(-1)) 1530 | else: 1531 | return torch.sqrt(CQT.pow(2).sum(-1)+1e-8) 1532 | 1533 | elif self.output_format=='Complex': 1534 | return CQT 1535 | 1536 | elif self.output_format=='Phase': 1537 | phase_real = torch.cos(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0])) 1538 | phase_imag = torch.sin(torch.atan2(CQT[:,:,:,1],CQT[:,:,:,0])) 1539 | return torch.stack((phase_real,phase_imag), -1) 1540 | 1541 | def forward_manual(self,x): 1542 | x = broadcast_dim(x) 1543 | if self.earlydownsample==True: 1544 | x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor) 1545 | hop = self.hop_length 1546 | CQT = self.get_cqt(x, hop, self.padding) # Getting the top octave CQT 1547 | 1548 | x_down = x # Preparing a new variable for downsampling 1549 | for i in range(self.n_octaves-1): 1550 | hop = hop//2 1551 | x_down = downsampling_by_2(x_down, self.lowpass_filter) 1552 | CQT1 = self.get_cqt(x_down, hop, self.padding) 1553 | CQT = torch.cat((CQT1, CQT),1) # Prepending the lower-octave CQT 1554 | CQT = CQT[:,-self.n_bins:,:] # Removing unwanted bottom bins 1555 | CQT = CQT*2**(self.n_octaves-1) # Normalizing signals
with respect to n_fft 1556 | # print("downsample_factor = ",self.downsample_factor) 1557 | CQT = CQT*self.downsample_factor/2**(self.n_octaves-1) # Normalizing the output with the downsampling factor; 2**(self.n_octaves-1) makes the magnitude match the 1992 version 1558 | 1559 | return CQT*torch.sqrt(self.lenghts.view(-1,1)) 1560 | 1561 | class CQT(CQT1992v2): 1562 | """An abbreviation for CQT1992v2. Please refer to the CQT1992v2 documentation""" 1563 | pass 1564 | --------------------------------------------------------------------------------