├── README.md
├── cqt_nsgt_pytorch
│   ├── CQT_nsgt.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── CQT_nsgt.cpython-310.pyc
│   │   ├── CQT_nsgt.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── fscale.cpython-310.pyc
│   │   ├── fscale.cpython-38.pyc
│   │   ├── nsdual.cpython-310.pyc
│   │   ├── nsdual.cpython-38.pyc
│   │   ├── nsgfwin.cpython-310.pyc
│   │   ├── nsgfwin.cpython-38.pyc
│   │   ├── util.cpython-310.pyc
│   │   └── util.cpython-38.pyc
│   ├── fscale.py
│   ├── nsdual.py
│   ├── nsgfwin.py
│   └── util.py
├── setup.py
└── tests
    ├── test.m
    └── test_notebook.ipynb

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# CQT_pytorch

PyTorch implementation of the invertible CQT based on non-stationary Gabor filters.

The transform has near-perfect reconstruction, is differentiable, and is GPU-efficient.

## Install

```bash
pip install cqt-nsgt-pytorch
```

## Usage

```py
import torch
from cqt_nsgt_pytorch import CQT_nsgt

# parameter examples
numocts = 9
binsoct = 64
fs = 44100
Ls = 131072

cqt = CQT_nsgt(numocts, binsoct, mode="matrix_complete", fs=fs, audio_len=Ls, device="cuda", dtype=torch.float32)

audio = ...  # load some audio file, shape=[batch, channels, time]

X = cqt.fwd(audio)  # forward transform
# X.shape=[batch, channels, frequency, time]
audio_reconstructed = cqt.bwd(X)  # backward transform
```

## Modes of operation

Different versions of the transform are implemented; they can be selected with the `mode` parameter. Except for "matrix" and "oct", which discard the DC and Nyquist bands, all modes have perfect reconstruction.

mode | Description | Output shape
------------- | ------------- | -------------
"critical" | (default) Critical sampling (no redundancy); slow implementation. | list of tensors, each with a different time resolution
"matrix" | Equal time resolution for every frequency band; maximum redundancy (discards DC and Nyquist). | 2d-Tensor \[binsoct \times numocts, T\]
"matrix_complete" | Same as above, but DC and Nyquist are included. | 2d-Tensor \[binsoct \times numocts + 2, T\]
"matrix_slow" | Slower version of "matrix_complete". May show similar efficiency on CPU, and consumes far less memory. | 2d-Tensor \[binsoct \times numocts + 2, T\]
"oct" | Trade-off between structure and redundancy. The frequency bins are grouped into octave bands, each octave with a different time resolution; the time lengths are restricted to powers of 2 (discards DC and Nyquist). | list of tensors, one per octave band, each with a different time resolution
"oct_complete" | Same as above, but DC and Nyquist are included. | list of tensors, one per octave band plus DC and Nyquist bands, each with a different time resolution
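
For example, the list-valued modes can be inspected as follows (an illustrative sketch, not part of the original examples; the exact per-octave time lengths depend on the parameters):

```py
import torch
from cqt_nsgt_pytorch import CQT_nsgt

cqt = CQT_nsgt(numocts=9, binsoct=64, mode="oct", fs=44100, audio_len=131072)
audio = torch.randn(1, 1, 131072)  # dummy input, shape=[batch, channels, time]

X = cqt.fwd(audio)  # list with one tensor per octave band
for i, Xi in enumerate(X):
    print(i, Xi.shape)  # [batch, channels, binsoct, T_i], with T_i a power of 2
```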
## TODO
- [x] On "matrix" mode, give the option to also output the DC and Nyquist bands. Same for "oct" mode. Document how this discarding is implemented.
- [ ] Write proper documentation.
- [ ] Test it with mixed precision. Problems with powers of 2, etc. Maybe this will require zero padding...
- [ ] Make apply_hpf_DC() and apply_lpf_DC() handier and clearer. Document their usage.
- [ ] Accelerate the "critical" mode; a method similar to the one used in "oct" could also apply. (Update: seems a bit tricky memory-wise.)
- [ ] Clean up the whole __init__() method, as it is currently a mess.
- [ ] Report the efficiency of the implementation on GPU (time and frequency). Briefly: it is fast, as everything is vectorized, but it may consume too much memory, especially on the backward pass.
- [x] Check if there is more redundancy to get rid of. Apparently, there is not.

--------------------------------------------------------------------------------
/cqt_nsgt_pytorch/CQT_nsgt.py:
--------------------------------------------------------------------------------

import torch
#from src.nsgt.cq import NSGT

from .fscale import LogScale, FlexLogOctScale

from .nsgfwin import nsgfwin
from .nsdual import nsdual
from .util import calcwinrange

import math
from math import ceil

def next_power_of_2(x):
    return 1 if x == 0 else 2**math.ceil(math.log2(x))


class CQT_nsgt():
    def __init__(self, numocts, binsoct, mode="critical", window="hann", flex_Q=None, fs=44100, audio_len=44100, device="cpu", dtype=torch.float32):
        """
        args:
            numocts (int): number of octaves
            binsoct (int): number of bins per octave. Can be a list if mode="flex_oct".
            mode (string): defines the mode of operation:
                "critical": (default) critical sampling (no redundancy). Returns a list of tensors, each with a different time resolution (slow implementation).
                "critical_fast": not implemented
                "matrix": returns a 2d matrix, maximum redundancy (discards DC and Nyquist)
                "matrix_pow2": returns a 2d matrix, maximum redundancy (discards DC and Nyquist; the time resolution is rounded up to a power of 2)
                "matrix_complete": returns a 2d matrix, maximum redundancy (with DC and Nyquist)
                "matrix_slow": returns a 2d matrix, maximum redundancy (slow implementation)
                "oct": octave-wise rasterization (moderate redundancy). Returns a list of tensors, one per octave, each with a different time resolution (discards DC and Nyquist).
                "oct_complete": octave-wise rasterization (moderate redundancy). Returns a list of tensors, one per octave, each with a different time resolution (with DC and Nyquist).
            fs (float): sampling frequency
            audio_len (int): sample length
            device
        """

        fmax = fs/2 - 10**-6  # the maximum frequency is Nyquist
        self.Ls = audio_len  # the length is given

        fmin = fmax/(2**numocts)
        fbins = int(binsoct*numocts)
        self.numocts = numocts
        self.binsoct = binsoct

        if mode == "flex_oct":
            # note: time_reductions is not defined in this scope; the "flex_oct" path is unfinished
            self.scale = FlexLogOctScale(fs, self.numocts, self.binsoct, time_reductions)
        else:
            self.scale = LogScale(fmin, fmax, fbins)

        self.fs = fs

        self.device = torch.device(device)
        self.mode = mode
        self.dtype = dtype

        self.frqs, self.q = self.scale()

        self.g, rfbas, self.M = nsgfwin(self.frqs, self.q, self.fs, self.Ls, dtype=self.dtype, device=self.device, min_win=4, window=window)

        sl = slice(0, len(self.g)//2 + 1)

        # coefficients per slice
        self.ncoefs = max(int(math.ceil(float(len(gii))/mii))*mii for mii, gii in zip(self.M[sl], self.g[sl]))

        if mode == "matrix" or mode == "matrix_complete" or mode == "matrix_slow":
            # just use the maximum resolution everywhere
            self.M[:] = self.M.max()
        elif mode == "matrix_pow2":
            self.size_per_oct = []
            self.M[:] = next_power_of_2(self.M.max())

        elif mode == "oct" or mode == "oct_complete":
            # round up all the lengths of an octave to the next power of 2
            self.size_per_oct = []
            idx = 1
            for i in range(numocts):
                value = next_power_of_2(self.M[idx:idx+binsoct].max())
                #value = M[idx:idx+binsoct].max()
                self.size_per_oct.append(value)
                self.M[idx:idx+binsoct] = value
                self.M[-idx-binsoct:-idx] = value
                idx += binsoct
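        # Illustration (added comment): next_power_of_2 snaps each octave's time
        # length to the next power of two, e.g. next_power_of_2(3000) -> 4096,
        # while exact powers of two are left unchanged (next_power_of_2(4096) -> 4096).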
        # calculate shifts
        self.wins, self.nn = calcwinrange(self.g, rfbas, self.Ls, device=self.device)
        # calculate dual windows
        self.gd = nsdual(self.g, self.wins, self.nn, self.M, dtype=self.dtype, device=self.device)

        # filter DC
        self.Hlpf = torch.zeros(self.Ls, dtype=self.dtype, device=self.device)
        self.Hlpf[0:len(self.g[0])//2] = self.g[0][:len(self.g[0])//2]*self.gd[0][:len(self.g[0])//2]*self.M[0]
        self.Hlpf[-len(self.g[0])//2:] = self.g[0][len(self.g[0])//2:]*self.gd[0][len(self.g[0])//2:]*self.M[0]
        # filter Nyquist
        nyquist_idx = len(self.g)//2
        Lg = len(self.g[nyquist_idx])
        self.Hlpf[self.wins[nyquist_idx][0:(Lg+1)//2]] += self.g[nyquist_idx][(Lg)//2:]*self.gd[nyquist_idx][(Lg)//2:]*self.M[nyquist_idx]
        self.Hlpf[self.wins[nyquist_idx][-(Lg-1)//2:]] += self.g[nyquist_idx][:(Lg)//2]*self.gd[nyquist_idx][:(Lg)//2]*self.M[nyquist_idx]

        self.Hhpf = 1 - self.Hlpf
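        # Sketch of how these masks behave (added comment; `cqt` stands for an
        # instance of this class):
        #   X = torch.fft.fft(x)
        #   x_low = torch.fft.ifft(X*torch.conj(cqt.Hlpf)).real   # DC + Nyquist residual
        #   x_high = torch.fft.ifft(X*torch.conj(cqt.Hhpf)).real  # everything else
        # Since Hhpf = 1 - Hlpf, x_low + x_high recovers x up to numerical precision.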
        # FORWARD!! this is from nsgtf
        #self.forward = lambda s: nsgtf(s, self.g, self.wins, self.nn, self.M, mode=self.mode, device=self.device)
        #sl = slice(0, len(self.g)//2+1)
        if mode == "matrix" or mode == "oct" or mode == "matrix_pow2":
            sl = slice(1, len(self.g)//2)  # getting rid of the DC component and the Nyquist
        else:
            sl = slice(0, len(self.g)//2 + 1)

        self.maxLg_enc = max(int(ceil(float(len(gii))/mii))*mii for mii, gii in zip(self.M[sl], self.g[sl]))

        self.loopparams_enc = []
        for mii, gii, win_range in zip(self.M[sl], self.g[sl], self.wins[sl]):
            Lg = len(gii)
            col = int(ceil(float(Lg)/mii))
            assert col*mii >= Lg
            assert col == 1
            p = (mii, win_range, Lg, col)
            self.loopparams_enc.append(p)


        def get_ragged_giis(g, wins, ms, mode):
            #ragged_giis = [torch.nn.functional.pad(torch.unsqueeze(gii, dim=0), (0, self.maxLg_enc-gii.shape[0])) for gii in gd[sl]]
            #ragged_giis = []
            c = torch.zeros((len(g), self.Ls//2 + 1), dtype=self.dtype, device=self.device)
            ix = []
            if mode == "oct":
                for i in range(self.numocts):
                    ix.append(torch.zeros((self.binsoct, self.size_per_oct[i]), dtype=torch.int64, device=self.device))
            elif mode == "matrix" or mode == "matrix_pow2":
                ix.append(torch.zeros((len(g), self.maxLg_enc), dtype=torch.int64, device=self.device))

            elif mode == "oct_complete" or mode == "matrix_complete":
                ix.append(torch.zeros((1, ms[0]), dtype=torch.int64, device=self.device))
                count = 0
                for i in range(1, len(g)-1):
                    if count == 0 or ms[i] == ms[i-1]:
                        count += 1
                    else:
                        ix.append(torch.zeros((count, ms[i-1]), dtype=torch.int64, device=self.device))
                        count = 1

                ix.append(torch.zeros((count, ms[i-1]), dtype=torch.int64, device=self.device))

                ix.append(torch.zeros((1, ms[-1]), dtype=torch.int64, device=self.device))

            j = 0
            k = 0
            for i, (gii, win_range) in enumerate(zip(g, wins)):
                if i > 0:
                    if ms[i] != ms[i-1] or ((mode == "oct_complete" or mode == "matrix_complete") and (j == 0 or i == len(g)-1)):
                        j += 1
                        k = 0

                gii = torch.fft.fftshift(gii).unsqueeze(0)
                Lg = gii.shape[1]

                if (i == 0 or i == len(g)-1) and (mode == "oct_complete" or mode == "matrix_complete"):
                    # special case for the DC and Nyquist bins, as we don't want to use the mirrored
                    # frequencies. Take this into account during the forward pass; we would just need
                    # to conjugate or something similar!
                    if i == 0:
                        c[i, win_range[Lg//2:]] = gii[..., Lg//2:]

                        ix[j][0, :(Lg+1)//2] = win_range[Lg//2:].unsqueeze(0)
                        ix[j][0, -(Lg//2):] = torch.flip(win_range[Lg//2:].unsqueeze(0), (-1,))
                    if i == len(g)-1:
                        c[i, win_range[:(Lg+1)//2]] = gii[..., :(Lg+1)//2]

                        ix[j][0, :(Lg+1)//2] = torch.flip(win_range[:(Lg+1)//2].unsqueeze(0), (-1,))  # rethink this
                        ix[j][0, -(Lg//2):] = win_range[:(Lg)//2].unsqueeze(0)
                else:
                    c[i, win_range] = gii

                    ix[j][k, :(Lg+1)//2] = win_range[Lg//2:].unsqueeze(0)
                    ix[j][k, -(Lg//2):] = win_range[:Lg//2].unsqueeze(0)

                k += 1
            #a = torch.unsqueeze(gii, dim=0)
            #b = torch.nn.functional.pad(a, (0, self.maxLg_enc-gii.shape[0]))
            #ragged_giis.append(b)
            #dirty unsqueeze
            return torch.conj(c), ix
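        # Illustration of the index tensors built above (added comment; the toy
        # shapes are assumptions, not the library's actual forward code):
        #   X = torch.arange(10, dtype=torch.cfloat)              # stand-in spectrum
        #   ix_demo = torch.tensor([[2, 3, 4, 0], [5, 6, 7, 0]])  # two bands, padded with bin 0
        #   bands = torch.gather(X.unsqueeze(0).expand(2, -1), -1, ix_demo)  # dense [2, 4] block
        # i.e. each row of an ix tensor lists the rFFT bins one band reads, so a
        # single vectorized gather replaces a ragged per-band loop.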
        if self.mode == "matrix" or self.mode == "matrix_complete" or self.mode == "matrix_pow2":
            self.giis, self.idx_enc = get_ragged_giis(self.g[sl], self.wins[sl], self.M[sl], self.mode)
            #self.idx_enc = self.idx_enc[0]
            #self.idx_enc = self.idx_enc.unsqueeze(0).unsqueeze(0)
        elif self.mode == "oct" or self.mode == "oct_complete":
            self.giis, self.idx_enc = get_ragged_giis(self.g[sl], self.wins[sl], self.M[sl], self.mode)
            #self.idx_enc = self.idx_enc.unsqueeze(0).unsqueeze(0)
        elif self.mode == "critical" or self.mode == "matrix_slow":
            #self.giis, self.idx_enc = get_ragged_giis(self.g[sl], self.wins[sl], self.M[sl], self.mode)
            ragged_giis = [torch.nn.functional.pad(torch.unsqueeze(gii, dim=0), (0, self.maxLg_enc-gii.shape[0])) for gii in self.g[sl]]
            self.giis = torch.conj(torch.cat(ragged_giis))

        # BACKWARD!! this is from nsigtf
        #self.backward = lambda c: nsigtf(c, self.gd, self.wins, self.nn, self.Ls, mode=self.mode, device=self.device)

        self.maxLg_dec = max(len(gdii) for gdii in self.gd)
        if self.mode == "matrix_pow2":
            self.maxLg_dec = self.maxLg_enc
        #print(self.maxLg_enc, self.maxLg_dec)

        #ragged_gdiis = [torch.nn.functional.pad(torch.unsqueeze(gdii, dim=0), (0, self.maxLg_dec-gdii.shape[0])) for gdii in self.gd]
        #self.gdiis = torch.conj(torch.cat(ragged_gdiis))

        def get_ragged_gdiis(gd, wins, mode, ms=None):
            ragged_gdiis = []
            # initialize the index with the center, to make sure that it points to a zero
            ix = torch.zeros((len(gd), self.Ls//2 + 1), dtype=torch.int64, device=self.device) + self.maxLg_dec//2
            for i, (g, win_range) in enumerate(zip(gd, wins)):
                Lg = g.shape[0]
                gl = g[:(Lg+1)//2]
                gr = g[(Lg+1)//2:]
                zeros = torch.zeros(self.maxLg_dec-Lg, dtype=g.dtype, device=g.device)  # pre-allocation
                paddedg = torch.cat((gl, zeros, gr), 0).unsqueeze(0)
                ragged_gdiis.append(paddedg)

                wr1 = win_range[:(Lg)//2]
                wr2 = win_range[-((Lg+1)//2):]
                if mode == "matrix_complete" and i == 0:
                    #ix[i, wr1] = torch.Tensor([self.maxLg_dec-(Lg//2)+i for i in range(len(wr1))]).to(torch.int64)  # the end part
                    ix[i, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(torch.int64).to(self.device)  # the start part
                elif mode == "matrix_complete" and i == len(gd)-1:
                    ix[i, wr1] = torch.Tensor([self.maxLg_dec-(Lg//2)+i for i in range(len(wr1))]).to(torch.int64).to(self.device)  # the end part
                    #ix[i, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(torch.int64)  # the start part
                else:
                    ix[i, wr1] = torch.Tensor([self.maxLg_dec-(Lg//2)+i for i in range(len(wr1))]).to(torch.int64).to(self.device)  # the end part
                    ix[i, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(torch.int64).to(self.device)  # the start part

            return torch.conj(torch.cat(ragged_gdiis)).to(self.dtype)*self.maxLg_dec, ix

        def get_ragged_gdiis_critical(gd, ms):
            seq_gdiis = []
            ragged_gdiis = []
            mprev = -1
            for i, (g, m) in enumerate(zip(gd, ms)):
                if i > 0 and m != mprev:
                    gdii = torch.conj(torch.cat(ragged_gdiis))
                    if len(gdii.shape) == 1:
                        gdii = gdii.unsqueeze(0)
                    #seq_gdiis.append(gdii[0:gdii.shape[0]//2 + 1])
                    seq_gdiis.append(gdii)
                    ragged_gdiis = []

                Lg = g.shape[0]
                gl = g[:(Lg+1)//2]
                gr = g[(Lg+1)//2:]
                zeros = torch.zeros(m-Lg, dtype=g.dtype, device=g.device)  # pre-allocation
                paddedg = torch.cat((gl, zeros, gr), 0).unsqueeze(0)*m
                ragged_gdiis.append(paddedg)
                mprev = m

            gdii = torch.conj(torch.cat(ragged_gdiis))
            seq_gdiis.append(gdii)
            #seq_gdiis.append(gdii[0:gdii.shape[0]//2 + 1])
            return seq_gdiis
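        # Note (added comment): the dual windows above are scaled by the band's
        # time length (*m here, *self.maxLg_dec in get_ragged_gdiis); this appears
        # to compensate for the 1/n normalization of torch.fft.ifft in the
        # per-band inverse, so the round trip stays close to unity gain.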
        def get_ragged_gdiis_oct(gd, ms, wins, mode):
            seq_gdiis = []
            ragged_gdiis = []
            mprev = -1
            ix = []
            if mode == "oct_complete":
                ix += [torch.zeros((1, self.Ls//2 + 1), dtype=torch.int64, device=self.device) + ms[0]//2]

            ix += [torch.zeros((self.binsoct, self.Ls//2 + 1), dtype=torch.int64, device=self.device) + self.size_per_oct[j]//2 for j in range(len(self.size_per_oct))]
            if mode == "oct_complete":
                ix += [torch.zeros((1, self.Ls//2 + 1), dtype=torch.int64, device=self.device) + ms[-1]//2]

            # initialize the index with the center, to make sure that it points to a zero
            j = 0
            k = 0
            for i, (g, m, win_range) in enumerate(zip(gd, ms, wins)):
                if i > 0 and m != mprev or (mode == "oct_complete" and i == len(gd)-1):
                    # take care when the size of DC is the same as the next octave,
                    # or when the last octave has the same size as Nyquist!
                    gdii = torch.conj(torch.cat(ragged_gdiis))
                    if len(gdii.shape) == 1:
                        gdii = gdii.unsqueeze(0)
                    #seq_gdiis.append(gdii[0:gdii.shape[0]//2 + 1])
                    seq_gdiis.append(gdii.to(self.dtype))
                    ragged_gdiis = []
                    j += 1
                    k = 0

                Lg = g.shape[0]
                gl = g[:(Lg+1)//2]
                gr = g[(Lg+1)//2:]
                zeros = torch.zeros(m-Lg, dtype=g.dtype, device=g.device)  # pre-allocation
                paddedg = torch.cat((gl, zeros, gr), 0).unsqueeze(0)*m
                ragged_gdiis.append(paddedg)
                mprev = m

                wr1 = win_range[:(Lg)//2]
                wr2 = win_range[-((Lg+1)//2):]
                if mode == "oct_complete" and i == 0:
                    #ix[i, wr1] = torch.Tensor([self.maxLg_dec-(Lg//2)+i for i in range(len(wr1))]).to(torch.int64)  # the end part
                    ix[0][k, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(self.device).to(torch.int64)  # the start part
                elif mode == "oct_complete" and i == len(gd)-1:
                    ix[-1][k, wr1] = torch.Tensor([m-(Lg//2)+i for i in range(len(wr1))]).to(self.device).to(torch.int64)  # the end part
                    #ix[i, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(torch.int64)  # the start part
                else:
                    ix[j][k, wr1] = torch.Tensor([m-(Lg//2)+i for i in range(len(wr1))]).to(self.device).to(torch.int64)  # the end part
                    ix[j][k, wr2] = torch.Tensor([i for i in range(len(wr2))]).to(self.device).to(torch.int64)  # the start part
                k += 1

            gdii = torch.conj(torch.cat(ragged_gdiis))
            seq_gdiis.append(gdii.to(self.dtype))
            #seq_gdiis.append(gdii[0:gdii.shape[0]//2 + 1])

            return seq_gdiis, ix

        if self.mode == "matrix" or self.mode == "matrix_complete":
            self.gdiis, self.idx_dec = get_ragged_gdiis(self.gd[sl], self.wins[sl], self.mode)
            #self.gdiis = self.gdiis[sl]
            #self.gdiis = self.gdiis[0:(self.gdiis.shape[0]//2 + 1)]
        elif self.mode == "matrix_pow2":
            self.gdiis, self.idx_dec = get_ragged_gdiis(self.gd[sl], self.wins[sl], self.mode, ms=self.M[sl])
        elif self.mode == "oct" or self.mode == "oct_complete":
            self.gdiis, self.idx_dec = get_ragged_gdiis_oct(self.gd[sl], self.M[sl], self.wins[sl], self.mode)
            for gdiis in self.gdiis:
                gdiis.to(self.dtype)
        elif self.mode == "critical":
            self.gdiis = get_ragged_gdiis_critical(self.gd[sl], self.M[sl])
        elif self.mode == "matrix_slow":
            ragged_gdiis = [torch.nn.functional.pad(torch.unsqueeze(gdii, dim=0), (0, self.maxLg_dec-gdii.shape[0])) for gdii in self.gd]
            self.gdiis = torch.conj(torch.cat(ragged_gdiis))

        self.loopparams_dec = []
        for gdii, win_range in zip(self.gd[sl], self.wins[sl]):
            Lg = len(gdii)
            wr1 = win_range[:(Lg)//2]
            wr2 = win_range[-((Lg+1)//2):]
            p = (wr1, wr2, Lg)
            self.loopparams_dec.append(p)

    def apply_hpf_DC(self, x):
        Lin = x.shape[-1]
        if Lin < self.Ls:
            x = torch.nn.functional.pad(x, (0, self.Ls - Lin))
        elif Lin > self.Ls:
            raise ValueError("Input signal is longer than the maximum length. I could have patched it, but I didn't. sorry :(")

        X = torch.fft.fft(x)
        X = X*torch.conj(self.Hhpf)
        out = torch.fft.ifft(X).real
        if Lin < self.Ls:
            out = out[..., :Lin]
        return out

    def apply_lpf_DC(self, x):
        Lin = x.shape[-1]
        if Lin < self.Ls:
            x = torch.nn.functional.pad(x, (0, self.Ls - Lin))
        elif Lin > self.Ls:
            raise ValueError("Input signal is longer than the maximum length. I could have patched it, but I didn't. sorry :(")
        X = torch.fft.fft(x)
        X = X*torch.conj(self.Hlpf)
        out = torch.fft.ifft(X).real
        if Lin < self.Ls:
            out = out[..., :Lin]
        return out
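    # Usage sketch for the two filters above (added comment; tensor shapes follow
    # the README example and are assumptions, not documented API):
    #   x = torch.randn(1, 1, 131072)       # [batch, channels, time]
    #   x_hp = cqt.apply_hpf_DC(x)          # signal minus the DC/Nyquist residual
    #   x_lp = cqt.apply_lpf_DC(x)          # the complementary DC/Nyquist part
    #   assert torch.allclose(x, x_hp + x_lp, atol=1e-5)  # lossless split: Hhpf = 1 - Hlpf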
sorry :(") 356 | 357 | X=torch.fft.fft(x) 358 | X=X*torch.conj(self.Hhpf) 359 | out= torch.fft.ifft(X).real 360 | if Lin self.Ls: 371 | raise ValueError("Input signal is longer than the maximum length. I could have patched it, but I didn't. sorry :(") 372 | X=torch.fft.fft(x) 373 | X=X*torch.conj(self.Hlpf) 374 | out= torch.fft.ifft(X).real 375 | if Lin 0) 49 | if lim != 0: 50 | # f partly <= 0 51 | f = f[lim:] 52 | q = q[lim:] 53 | 54 | lim = np.argmax(f >= nf) 55 | if lim != 0: 56 | # f partly >= nf 57 | f = f[:lim] 58 | q = q[:lim] 59 | 60 | assert len(f) == len(q) 61 | assert np.all((f[1:]-f[:-1]) > 0) # frequencies must be increasing 62 | assert np.all(q > 0) # all q must be > 0 63 | 64 | qneeded = f*(Ls/(8.*sr)) 65 | #if np.any(q >= qneeded) and dowarn: 66 | # warn("Q-factor too high for frequencies %s"%",".join("%.2f"%fi for fi in f[q >= qneeded])) 67 | 68 | fbas = f 69 | lbas = len(fbas) 70 | 71 | frqs = np.concatenate(((0.,),fbas,(nf,))) 72 | 73 | fbas = np.concatenate((frqs,sr-frqs[-2:0:-1])) 74 | 75 | # at this point: fbas.... frequencies in Hz 76 | 77 | fbas *= float(Ls)/sr 78 | 79 | # Omega[k] in the paper 80 | M = np.zeros(fbas.shape, dtype=int) 81 | M[0] = np.round(2*fbas[1]) 82 | #M[1]= 83 | M[1] = np.round(fbas[1]/q[0]) 84 | for k in range(2,lbas+1): 85 | #M[k] = np.round(fbas[k]/q[k-1]) 86 | M[k]= np.round(fbas[k+1]-fbas[k-1]) #this is nyq! 87 | #M[k] = 88 | #M[lbas]=np.round(fbas[lbas]/q[-1]) 89 | M[lbas+1]= np.round(fbas[k+1]-fbas[k-1]) #this is nyq! 90 | M[lbas+2:]=M[lbas:0:-1] #symmetry! 91 | 92 | #M[-1] = np.round(Ls-fbas[-2]) 93 | 94 | np.clip(M, min_win, np.inf, out=M) 95 | 96 | 97 | if window=="hann": 98 | print("using a hann window") 99 | g = [hannwin(m, device=device).to(dtype) for m in M] 100 | elif window=="blackharr": 101 | print("using a blackharr window") 102 | g = [blackharr(m, device=device).to(dtype) for m in M] 103 | elif window[0]=="kaiser": 104 | print("using a kaiser window with beta=",window[1]) 105 | str, beta= window 106 | g = [kaiserwin(m,beta, device=device).to(dtype) for m in M] 107 | 108 | #g[0]=tukeywin(M[0], 0.2, device=device).to(dtype) 109 | 110 | fbas[lbas] = (fbas[lbas-1]+fbas[lbas+1])/2 111 | fbas[lbas+2] = Ls-fbas[lbas] 112 | rfbas = np.round(fbas).astype(int) 113 | 114 | 115 | return g,rfbas,M 116 | -------------------------------------------------------------------------------- /cqt_nsgt_pytorch/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 2 | 3 | """ 4 | Python implementation of Non-Stationary Gabor Transform (NSGT) 5 | derived from MATLAB code by NUHAG, University of Vienna, Austria 6 | 7 | Thomas Grill, 2011-2015 8 | http://grrrr.org/nsgt 9 | 10 | Austrian Research Institute for Artificial Intelligence (OFAI) 11 | AudioMiner project, supported by Vienna Science and Technology Fund (WWTF) 12 | """ 13 | 14 | import numpy as np 15 | import torch 16 | from math import exp, floor, ceil, pi 17 | #import scipy.signal 18 | 19 | 20 | def hannwin(l, device="cpu"): 21 | r = torch.arange(l,dtype=float, device=torch.device(device)) 22 | r *= np.pi*2./l 23 | r = torch.cos(r) 24 | r += 1. 25 | r *= 0.5 26 | return r 27 | 28 | #design a kaiser window 29 | def kaiserwin(l, beta, device="cpu"): 30 | beta=torch.tensor(beta, dtype=float, device=torch.device(device)) 31 | r = torch.arange(l,dtype=float, device=torch.device(device)) 32 | r *= np.pi*2./l 33 | r = torch.cos(r) 34 | r += 1. 
    r *= 0.5
    r = torch.sqrt(r)
    r = torch.i0(beta*torch.sqrt(1.-r**2))/(2.*torch.i0(beta))
    r = torch.roll(r, l//2)
    return r


# alternative windows!! maybe it could be interesting to switch, to get
# better time or frequency resolution, who knows...
def blackharr(n, l=None, mod=True, device="cpu"):
    if l is None:
        l = n
    nn = (n//2)*2
    k = torch.arange(n, device=torch.device(device))
    if not mod:
        bh = 0.35875 - 0.48829*torch.cos(k*(2*pi/nn)) + 0.14128*torch.cos(k*(4*pi/nn)) - 0.01168*torch.cos(k*(6*pi/nn))
    else:
        bh = 0.35872 - 0.48832*torch.cos(k*(2*pi/nn)) + 0.14128*torch.cos(k*(4*pi/nn)) - 0.01168*torch.cos(k*(6*pi/nn))
    bh = torch.hstack((bh, torch.zeros(l-n, dtype=bh.dtype, device=torch.device(device))))
    bh = torch.hstack((bh[-n//2:], bh[:-n//2]))
    return bh

def blackharrcw(bandwidth, corr_shift):
    flip = -1 if corr_shift < 0 else 1
    corr_shift *= flip

    M = np.ceil(bandwidth/2 + corr_shift - 1)*2
    win = np.concatenate((np.arange(M//2, M), np.arange(0, M//2))) - corr_shift
    win = (0.35872 - 0.48832*np.cos(win*(2*np.pi/bandwidth)) + 0.14128*np.cos(win*(4*np.pi/bandwidth)) - 0.01168*np.cos(win*(6*np.pi/bandwidth)))*(win <= bandwidth)*(win >= 0)

    return win[::flip], M


def _isseq(x):
    try:
        len(x)
    except TypeError:
        return False
    return True


def calcwinrange(g, rfbas, Ls, device="cpu"):
    shift = np.concatenate(((np.mod(-rfbas[-1], Ls),), rfbas[1:]-rfbas[:-1]))

    timepos = np.cumsum(shift)
    nn = timepos[-1]
    timepos -= shift[0]  # calculate positions from the shift vector

    wins = []
    for gii, tpii in zip(g, timepos):
        Lg = len(gii)
        win_range = torch.arange(-(Lg//2)+tpii, Lg-(Lg//2)+tpii, dtype=int, device=torch.device(device))
        win_range %= nn

        wins.append(win_range)

    return wins, nn

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------

from setuptools import find_packages, setup

setup(
    name="cqt-nsgt-pytorch",
    packages=find_packages(exclude=[]),
    version="0.0.9",
    license="MIT",
    description="PyTorch implementation of an invertible and differentiable Constant-Q Transform, based on the Non-stationary Gabor Transform (NSGT), for audio processing.",
    long_description_content_type="text/markdown",
    author="Eloi Moliner",
    author_email="eloi.moliner@aalto.fi",
    url="https://github.com/eloimoliner/CQT_pytorch",
    keywords=["audio processing", "constant-q transform", "deep learning", "pytorch", "nsgt"],
    install_requires=[
        "torch>=1.13.0",
        "numpy>=1.19.5",
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

--------------------------------------------------------------------------------
/tests/test.m:
--------------------------------------------------------------------------------

clear all;
close all;
[a, fs] = audioread("test_dir/0.wav");

a = a(1:131072);

A = fft(a);
N = length(A)

L = 512
k = 1:L

g = (0.5 + 0.5*cos(k.*pi*2/L))
H = zeros(length(A), 1)
H(1:L/2) = g(1:L/2)
H(N-L/2+1:end) = g(L/2+1:end)

Hhpf = 1 - H

B = A.*Hhpf
b = real(ifft(B))

C = fft(b)
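% Added sketch (not in the original test): verify the complementary low-pass
% split, mirroring apply_hpf_DC/apply_lpf_DC in CQT_nsgt.py. Since Hhpf = 1-H,
% the two branches should sum back to the input up to machine precision.
Hlpf = H;
c = real(ifft(A.*Hlpf));
err = max(abs(a - (b + c)))  % expected to be on the order of eps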
--------------------------------------------------------------------------------