├── README.md ├── LICENSE ├── .gitignore ├── conda.yml ├── utils.py ├── vis-strfs.py └── models.py /README.md: -------------------------------------------------------------------------------- 1 | # strfnet-IS2020 2 | Learnable Spectro-Temporal Receptive Fields for Robust Voice Type Discrimination 3 | 4 | Official repo for the PyTorch implementation of the STRFNet system presented at INTERSPEECH 2020 5 | 6 | 7 | [[Paper](https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1878.pdf)]\ 8 | [[Arxiv](https://arxiv.org/abs/2010.09151)] 9 | 10 | ### Citation 11 | ``` 12 | @inproceedings{vuong2020learnable, 13 | author={Tyler Vuong and Yangyang Xia and Richard M. Stern}, 14 | title={Learnable Spectro-Temporal Receptive Fields for Robust Voice Type Discrimination}, 15 | year=2020, 16 | month = oct, 17 | booktitle={Interspeech 2020}, 18 | pages={1957--1961}, 19 | publisher = {{ISCA}}, 20 | doi={10.21437/Interspeech.2020-1878}, 21 | url={http://dx.doi.org/10.21437/Interspeech.2020-1878} 22 | } 23 | 24 | ``` 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tyler Vuong, Yangyang Raymond Xia, Richard Stern 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /conda.yml: -------------------------------------------------------------------------------- 1 | name: strfnet 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - blas=1.0=mkl 10 | - ca-certificates=2020.12.5=ha878542_0 11 | - certifi=2020.12.5=py38h578d9bd_1 12 | - cudatoolkit=11.0.221=h6bb024c_0 13 | - cycler=0.10.0=py_2 14 | - dbus=1.13.18=hb2f20db_0 15 | - expat=2.2.10=he6710b0_2 16 | - fontconfig=2.13.1=he4413a7_1000 17 | - freetype=2.10.4=h5ab3b9f_0 18 | - glib=2.66.1=h92f7085_0 19 | - gst-plugins-base=1.14.0=hbbd80ab_1 20 | - gstreamer=1.14.0=h28cd5cc_2 21 | - icu=58.2=hf484d3e_1000 22 | - intel-openmp=2020.2=254 23 | - jpeg=9b=h024ee3a_2 24 | - kiwisolver=1.3.1=py38h82cb98a_0 25 | - lcms2=2.11=h396b838_0 26 | - ld_impl_linux-64=2.33.1=h53a641e_7 27 | - libedit=3.1.20191231=h14c3975_1 28 | - libffi=3.3=he6710b0_2 29 | - libgcc-ng=9.1.0=hdf63c60_0 30 | - libgfortran-ng=7.3.0=hdf63c60_0 31 | - libpng=1.6.37=hbc83047_0 32 | - libstdcxx-ng=9.1.0=hdf63c60_0 33 | - libtiff=4.1.0=h2733197_1 34 | - libuuid=2.32.1=h14c3975_1000 35 | - libuv=1.40.0=h7b6447c_0 36 | - libxcb=1.13=h14c3975_1002 37 | - libxml2=2.9.10=hb55368b_3 38 | - lz4-c=1.9.3=h2531618_0 39 | - matplotlib=3.3.3=py38h578d9bd_0 40 | - 
def nextpow2(n):
    """Return the smallest power of 2 that is not less than the integer n.

    For example, nextpow2(5) == 8 and nextpow2(8) == 8.
    """
    return 1 << (n-1).bit_length()


def hilbert(x, ndft=None):
    r"""Analytic signal of x.

    Return the analytic signal of a real signal x, x + j\hat{x}, where \hat{x}
    is the Hilbert transform of x.

    Parameters
    ----------
    x: torch.Tensor
        Audio signal to be analyzed.
        Always assumes x is real, and x.shape[-1] is the signal length.
    ndft: int, optional
        DFT size. When given, x is zero-padded on the right along its last
        dimension to this length before the transform. Must not be smaller
        than x.shape[-1]. Default (None) uses the signal length itself.

    Returns
    -------
    out: torch.Tensor
        Real tensor whose last dimension holds (real, imaginary) parts.
        out.shape == (*x.shape, 2) when `ndft` is None; otherwise the
        signal dimension has length `ndft`.

    Raises
    ------
    ValueError
        If `ndft` is smaller than the signal length.

    """
    # Imported here because this module only imports `torch` and `torch.nn`
    # at file level; the bare name `F` was previously undefined.
    import torch.nn.functional as F

    if ndft is None:
        sig = x
    else:
        if ndft < x.size(-1):
            raise ValueError(
                f"ndft ({ndft}) must not be smaller than the signal "
                f"length ({x.size(-1)})."
            )
        sig = F.pad(x, (0, ndft-x.size(-1)))
    siglen = sig.size(-1)

    # Frequency-domain mask: keep DC (and Nyquist for even lengths), double
    # the positive frequencies, and zero the negative frequencies.
    h = torch.zeros(siglen, dtype=sig.dtype, device=sig.device)
    if siglen % 2 == 0:
        h[0] = h[siglen//2] = 1
        h[1:siglen//2] = 2
    else:
        h[0] = 1
        h[1:(siglen+1)//2] = 2

    if hasattr(torch, "rfft"):
        # Legacy function-style FFT API (removed from torch after 1.7):
        # complex numbers are stored as a trailing dimension of size 2.
        xspec = torch.rfft(sig, 1, onesided=False)
        return torch.ifft(xspec * h.unsqueeze(-1), 1)

    # Module-style FFT API with native complex tensors.
    from torch.fft import fft, ifft
    analytic = ifft(fft(sig, dim=-1) * h, dim=-1)
    return torch.view_as_real(analytic)
def strf(time, freq, sr, bins_per_octave, rate=1, scale=1, phi=0, theta=0,
         ndft=None):
    """Spectral-temporal response fields for both up and down direction.

    Implement the STRF described in Chi, Ru, and Shamma:
    Chi, T., Ru, P., & Shamma, S. A. (2005). Multiresolution spectrotemporal
    analysis of complex sounds. The Journal of the Acoustical Society of
    America, 118(2), 887–906. https://doi.org/10.1121/1.1945807.

    Parameters
    ----------
    time: int or float
        Time support in seconds. The returned STRF will cover the range
        [0, time).
    freq: int or float
        Frequency support in number of octaves. The returned STRF will
        cover the range [-freq, freq).
    sr: int
        Sampling rate of the time axis in Hz; the temporal response has
        int(sr*time) samples.
    bins_per_octave: int
        Number of frequency bins per octave on the log-frequency scale.
    rate: int or float
        Stretch factor in time.
    scale: int or float
        Stretch factor in frequency.
    phi: float
        Orientation of spectral evolution in radians.
    theta: float
        Orientation of time evolution in radians.
    ndft: int, optional
        DFT size used for the Hilbert transforms. Must be at least as long
        as the longer of the two 1-D responses. Default (None) uses exactly
        that length.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Two real arrays of shape (time samples, frequency samples): the
        real part of the outer product of the analytic temporal and spectral
        responses, and the same with the temporal response conjugated.

    Raises
    ------
    ValueError
        If `ndft` is smaller than the longer 1-D response.

    """
    def _hs(x, scale):
        """Construct a 1-D spectral impulse response with a 2-diff Gaussian.

        This is the prototype filter suggested by Chi et al.
        """
        sx = scale * x
        return scale * (1-(2*np.pi*sx)**2) * np.exp(-(2*np.pi*sx)**2/2)

    def _ht(t, rate):
        """Construct a 1-D temporal impulse response with a Gamma function.

        This is the prototype filter suggested by Chi et al.
        """
        rt = rate * t
        return rate * rt**2 * np.exp(-3.5*rt) * np.sin(2*np.pi*rt)

    hs = _hs(np.linspace(-freq, freq, endpoint=False,
                         num=int(2*freq*bins_per_octave)), scale)
    ht = _ht(np.linspace(0, time, endpoint=False, num=int(sr*time)), rate)

    longest = max(len(hs), len(ht))
    if ndft is None:
        # The previous default first called an undefined `nextpow2` and then
        # overwrote its result with this value; keep the effective default.
        ndft = longest
    elif ndft < longest:
        raise ValueError(
            f"ndft ({ndft}) must be at least the longest response "
            f"length ({longest})."
        )

    # Quadrature components of each 1-D response via the analytic signal.
    hsa = signal.hilbert(hs, ndft)[:len(hs)]
    hta = signal.hilbert(ht, ndft)[:len(ht)]
    # Rotate each response by its phase, then take its analytic signal.
    hirs = hs * np.cos(phi) + hsa.imag * np.sin(phi)
    hirt = ht * np.cos(theta) + hta.imag * np.sin(theta)
    hirs_ = signal.hilbert(hirs, ndft)[:len(hs)]
    hirt_ = signal.hilbert(hirt, ndft)[:len(ht)]
    return np.outer(hirt_, hirs_).real,\
        np.outer(np.conj(hirt_), hirs_).real
def is_strf_param(nm):
    """Check if a parameter name string is one of STRF parameters."""
    return any(n in nm for n in ("rates_", "scales_", "phis_", "thetas_"))


class GaborSTRFConv(nn.Module):
    """Gabor-STRF-based cross-correlation kernel."""
    def __init__(self, supn, supk, nkern, rates=None, scales=None):
        """Instantiate a Gabor-based STRF convolution layer.

        Parameters
        ----------
        supn: int
            Time support in number of frames. Also the window length.
            Even values are incremented by one to force an odd length.
        supk: int
            Frequency support in number of channels. Also the window length.
            Even values are incremented by one to force an odd length.
        nkern: int
            Number of kernels, each with a learnable rate and scale.
        rates: array_like of float, None
            Initial values for temporal modulation (length `nkern`).
            None or an empty sequence draws uniform values in [0, pi).
        scales: array_like of float, None
            Initial values for spectral modulation (length `nkern`).
            None or an empty sequence draws uniform values in [0, pi).

        Raises
        ------
        ValueError
            If `rates` or `scales` is given but is not 1-D of length `nkern`.

        """
        super(GaborSTRFConv, self).__init__()
        if supk % 2 == 0:  # force odd number
            supk += 1
        self.supk = torch.arange(supk, dtype=torch.float32)
        if supn % 2 == 0:  # force odd number
            supn += 1
        self.supn = torch.arange(supn, dtype=self.supk.dtype)
        self.padding = (supn//2, supk//2)

        # Set up learnable parameters. Both initialisers are validated
        # before any random number is drawn.
        rates = self._init_values(rates, nkern, "rates")
        scales = self._init_values(scales, nkern, "scales")
        if rates is None:
            rates = torch.rand(nkern) * math.pi
        if scales is None:
            scales = torch.rand(nkern) * math.pi
        self.rates_ = nn.Parameter(rates)
        self.scales_ = nn.Parameter(scales)

    @staticmethod
    def _init_values(values, nkern, name):
        """Convert an optional initialiser to a float32 vector, or None."""
        if values is None:
            return None
        # Explicit conversion avoids truth-testing tensors/arrays, which is
        # ambiguous for more than one element.
        values = torch.as_tensor(values, dtype=torch.float32).detach().clone()
        if values.numel() == 0:  # empty means "not specified"
            return None
        if values.dim() != 1 or values.numel() != nkern:
            raise ValueError(
                f"`{name}` must be a 1-D sequence of length {nkern}; "
                f"got shape {tuple(values.shape)}."
            )
        return values

    def strfs(self):
        """Make STRFs using the current parameters.

        Returns a (2*nkern, time, frequency) tensor: the first `nkern`
        kernels are windowed cosine products and the last `nkern` are the
        corresponding windowed sine products.
        """
        if self.supn.device != self.rates_.device:  # for first run
            self.supn = self.supn.to(self.rates_.device)
            self.supk = self.supk.to(self.rates_.device)
        n0, k0 = self.padding
        # Phase of each kernel at every centred index: K x N and K x F.
        nphase = self.rates_.unsqueeze(1) * (self.supn-n0)
        kphase = self.scales_.unsqueeze(1) * (self.supk-k0)
        nwind = .5 - .5 * torch.cos(2*math.pi*self.supn/(len(self.supn)+1))
        kwind = .5 - .5 * torch.cos(2*math.pi*self.supk/(len(self.supk)+1))
        # Per-kernel outer products through broadcasting: K x N x F.
        strfr = (torch.cos(nphase)*nwind).unsqueeze(-1) \
            * (torch.cos(kphase)*kwind).unsqueeze(1)
        strfi = (torch.sin(nphase)*nwind).unsqueeze(-1) \
            * (torch.sin(kphase)*kwind).unsqueeze(1)

        return torch.cat((strfr, strfi), 0)

    def forward(self, sigspec):
        """Forward pass real spectra [Batch x Time x Frequency].

        A 2-D input is treated as a single example. The output has shape
        [Batch x 2*nkern x Time x Frequency].
        """
        if len(sigspec.shape) == 2:  # expand batch dimension if single eg
            sigspec = sigspec.unsqueeze(0)
        strfs = self.strfs().unsqueeze(1).type_as(sigspec)
        return F.conv2d(sigspec.unsqueeze(1), strfs, padding=self.padding)
114 | 115 | """ 116 | super(STRFConv, self).__init__() 117 | 118 | # For printing 119 | self.__rep = f"""STRF(fr={fr}, bins_per_octave={bins_per_octave}, 120 | suptime={suptime}, supoct={supoct}, nkern={nkern}, 121 | rates={rates}, scales={scales}, phis={phis}, 122 | thetas={thetas})""" 123 | 124 | # Determine time & frequency support 125 | _fsteps = int(supoct * bins_per_octave) # spectral step on one side 126 | self.supf = torch.linspace(-supoct, supoct, steps=2*_fsteps+1) 127 | _tsteps = int(fr*suptime) 128 | if _tsteps % 2 == 0: # force odd number 129 | _tsteps += 1 130 | self.supt = torch.arange(_tsteps).type_as(self.supf)/fr 131 | self.padding = (_tsteps//2, _fsteps) 132 | self.ndft = max(nextpow2(max(len(self.supf), len(self.supt))), 128) 133 | 134 | # Set up learnable parameters 135 | for param in (rates, scales, phis, thetas): 136 | assert (not param) or len(param) == nkern 137 | if not rates: 138 | rates = torch.rand(nkern) * 10 139 | if not scales: 140 | scales = torch.rand(nkern) / 5 141 | if not phis: 142 | phis = 2*math.pi * torch.rand(nkern) 143 | if not thetas: 144 | thetas = 2*math.pi * torch.rand(nkern) 145 | self.rates_ = nn.Parameter(torch.Tensor(rates)) 146 | self.scales_ = nn.Parameter(torch.Tensor(scales)) 147 | self.phis_ = nn.Parameter(torch.Tensor(phis)) 148 | self.thetas_ = nn.Parameter(torch.Tensor(thetas)) 149 | 150 | @staticmethod 151 | def _hs(x, scale): 152 | """Spectral evolution.""" 153 | sx = scale * x 154 | return scale * (1-(2*math.pi*sx)**2) * torch.exp(-(2*math.pi*sx)**2/2) 155 | 156 | @staticmethod 157 | def _ht(t, rate): 158 | """Temporal evolution.""" 159 | rt = rate * t 160 | return rate * rt**2 * torch.exp(-3.5*rt) * torch.sin(2*math.pi*rt) 161 | 162 | def strfs(self): 163 | """Make STRFs using current parameters.""" 164 | if self.supt.device != self.rates_.device: # for first run 165 | self.supt = self.supt.to(self.rates_.device) 166 | self.supf = self.supf.to(self.rates_.device) 167 | K, S, T = len(self.rates_), 
len(self.supf), len(self.supt) 168 | # Construct STRFs 169 | hs = self._hs(self.supf, self.scales_.view(K, 1)) 170 | ht = self._ht(self.supt, self.rates_.view(K, 1)) 171 | hsa = hilbert(hs, self.ndft)[:, :hs.size(-1), :] 172 | hta = hilbert(ht, self.ndft)[:, :ht.size(-1), :] 173 | hirs = hs * torch.cos(self.phis_.view(K, 1)) \ 174 | + hsa[..., 1] * torch.sin(self.phis_.view(K, 1)) 175 | hirt = ht * torch.cos(self.thetas_.view(K, 1)) \ 176 | + hta[..., 1] * torch.sin(self.thetas_.view(K, 1)) 177 | hirs_ = hilbert(hirs, self.ndft)[:, :hs.size(-1), :] # K x S x 2 178 | hirt_ = hilbert(hirt, self.ndft)[:, :ht.size(-1), :] # K x T x 2 179 | 180 | # for a single strf: 181 | # strfdn = hirt_[:, 0] * hirs_[:, 0] - hirt_[:, 1] * hirs_[:, 1] 182 | # strfup = hirt_[:, 0] * hirs_[:, 0] + hirt_[:, 1] * hirs_[:, 1] 183 | rreal = hirt_[..., 0].view(K, T, 1) * hirs_[..., 0].view(K, 1, S) 184 | rimag = hirt_[..., 1].view(K, T, 1) * hirs_[..., 1].view(K, 1, S) 185 | strfs = torch.cat((rreal-rimag, rreal+rimag), 0) # 2K x T x S 186 | 187 | return strfs 188 | 189 | def forward(self, sigspec): 190 | """Convolve a spectrographic representation with all STRF kernels. 191 | 192 | Parameters 193 | ---------- 194 | sigspec: `torch.Tensor` (batch_size, time_dim, freq_dim) 195 | Batch of spectrograms. 196 | The frequency dimension should be logarithmically spaced. 197 | 198 | Returns 199 | ------- 200 | features: `torch.Tensor` (batch_size, nkern, time_dim, freq_dim) 201 | Batch of STRF activatations. 
def init_STRFNet(sample_batch,
                 num_classes,
                 num_kernels=32,
                 residual_channels=(32, 32),
                 embedding_dimension=1024,
                 num_rnn_layers=2,
                 frame_rate=None, bins_per_octave=None,
                 time_support=None, frequency_support=None,
                 conv2d_sizes=(3, 3),
                 mlp_hiddims=(),
                 activate_out=nn.LogSoftmax(dim=1)
                 ):
    """Initialize a STRFNet for multi-class classification.

    This is a one-stop solution to create STRFNet and its variants.

    Parameters
    ----------
    sample_batch: [Batch,Time,Frequency] torch.FloatTensor
        A batch of training examples. It is passed once through the
        convolutional front end (without gradients) to infer the input size
        of the linear layer, so later inputs must produce the same flattened
        feature size.
    num_classes: int
        Number of classes for the classification task.

    Keyword Parameters
    ------------------
    num_kernels: int, 32
        2*num_kernels is the number of STRF/2D kernels.
        Doubling is due to the two orientations of the STRFs.
    residual_channels: sequence of int, (32, 32)
        Specify the number of conv2d channels for each residual block.
    embedding_dimension: int, 1024
        Output size of the linear layer and hidden size of the GRU.
    num_rnn_layers: int, 2
        Number of stacked bidirectional GRU layers.
    frame_rate: float, None
        Sampling rate [samples/second] / hop size [samples].
        No STRF kernels by default.
    bins_per_octave: int, None
        Frequency bins per octave in CQT sense. (TODO: extend for non-CQT rep.)
        No STRF kernels by default.
    time_support: float, None
        If frame_rate or bins_per_octave is None, interpret as GaborSTRFConv.
            - Number of frames (int) spanned by each STRF kernel.
        Otherwise, interpret as STRFConv.
            - Number of seconds spanned by each STRF kernel.
        No STRF kernels by default.
    frequency_support: int/float, None
        If frame_rate or bins_per_octave is None, interpret as GaborSTRFConv.
            - Number of frequency bins (int) spanned by each STRF kernel.
        Otherwise, interpret as STRFConv.
            - Number of octaves spanned by each STRF kernel.
        No STRF kernels by default.
    conv2d_sizes: (int, int), (3, 3)
        nn.Conv2d kernel dimensions. Even sizes are incremented by one.
        None disables the plain conv2d branch.
    mlp_hiddims: sequence of int, ()
        Final MLP hidden layer dimensions.
        Default has no hidden layers.
    activate_out: callable, nn.LogSoftmax(dim=1)
        Activation function at the final layer.
        Default uses LogSoftmax for multi-class classification.

    Returns
    -------
    STRFNet
        The assembled network.

    Raises
    ------
    TypeError
        If Gabor kernels are selected and a support is not an int.
    ValueError
        If neither a STRF layer nor a conv2d layer is configured.
    """
    # Work on private list copies so caller-owned sequences are never shared
    # with the modules built below.
    residual_channels = list(residual_channels)
    mlp_hiddims = list(mlp_hiddims)

    # A STRF front end needs both supports; frame rate and bins per octave
    # additionally select the wavelet (STRFConv) kernels.
    kernel_type = None
    is_strfnet = time_support is not None and frequency_support is not None
    if is_strfnet:
        if frame_rate is not None and bins_per_octave is not None:
            kernel_type = 'wavelet'
        else:
            if not all(isinstance(p, int)
                       for p in (time_support, frequency_support)):
                raise TypeError(
                    "Gabor kernels need integer time_support and "
                    "frequency_support (frames and frequency bins)."
                )
            kernel_type = 'gabor'
    is_cnn = conv2d_sizes is not None
    is_hybrid = is_strfnet and is_cnn
    if is_hybrid:
        print(f"Preparing for Hybrid STRFNet; kernel type is {kernel_type}.")
    elif is_strfnet:
        print(f"Preparing for STRFNet; kernel type is {kernel_type}.")
    elif is_cnn:
        print("Preparing for CNN.")
    else:
        raise ValueError("Insufficient parameters. Check example_STRFNet.")

    if not is_strfnet:
        strf_layer = None
    elif kernel_type == 'wavelet':
        strf_layer = STRFConv(
            frame_rate, bins_per_octave,
            time_support, frequency_support, num_kernels
        )
    else:
        strf_layer = GaborSTRFConv(
            time_support, frequency_support, num_kernels
        )

    if is_cnn:
        d1, d2 = conv2d_sizes
        if d1 % 2 == 0:
            d1 += 1
            print("Enforcing odd conv2d dimension.")
        if d2 % 2 == 0:
            d2 += 1
            print("Enforcing odd conv2d dimension.")
        conv2d_layer = nn.Conv2d(
            1, 2*num_kernels,  # Double to match the total number of STRFs
            (d1, d2), padding=(d1//2, d2//2)
        )
    else:
        conv2d_layer = None

    # The hybrid model concatenates both branches along the channel axis.
    residual_layer = ModResnet(
        (4 if is_hybrid else 2)*num_kernels, residual_channels, False
    )
    # Dry run to learn the flattened feature size fed to the linear layer.
    with torch.no_grad():
        flattened_dimension = STRFNet.cnn_forward(
            sample_batch, strf_layer, conv2d_layer, residual_layer
        ).shape[-1]

    linear_layer = nn.Linear(flattened_dimension, embedding_dimension)
    rnn = nn.GRU(
        embedding_dimension, embedding_dimension, batch_first=True,
        num_layers=num_rnn_layers, bidirectional=True
    )

    # The GRU is bidirectional, hence twice the embedding dimension.
    mlp = MLP(
        2*embedding_dimension, num_classes, hiddims=mlp_hiddims,
        activate_hid=nn.LeakyReLU(),
        activate_out=activate_out,
        batchnorm=[True]*len(mlp_hiddims)
    )

    return STRFNet(strf_layer, conv2d_layer, residual_layer,
                   linear_layer, rnn, mlp)
358 | attn_applied = torch.matmul(x.transpose(2, 1), attn).squeeze(-1) 359 | return attn_applied, attn 360 | 361 | 362 | class STRFNet(nn.Module): 363 | """A generic STRFNet with generic or STRF kernels in the first layer. 364 | 365 | Processing workflow: 366 | Feat. -> STRF/conv2d -> Residual CNN -> Attention -> MLP -> Class prob. 367 | BxTxF ----> BxTxF -------> BxTxF ---------> BxF ----> BxK -> BxC 368 | 369 | """ 370 | def __init__(self, strf_layer, conv2d_layer, residual_layer, 371 | linear_layer, rnn, mlp 372 | ): 373 | """See init_STRFNet for initializing each component.""" 374 | super(STRFNet, self).__init__() 375 | self.strf_layer = strf_layer 376 | self.conv2d_layer = conv2d_layer 377 | self.residual_layer = residual_layer 378 | self.linear_layer = linear_layer 379 | self.rnn = rnn 380 | self.attention_layer = SelfAttention(2*rnn.hidden_size) 381 | self.mlp = mlp 382 | 383 | def forward(self, x, return_embedding=False): 384 | """Forward pass a batch-by-time-by-frequency tensor.""" 385 | x = self.cnn_forward( 386 | x, self.strf_layer, self.conv2d_layer, self.residual_layer 387 | ) 388 | x = self.linear_layer(x) 389 | x, _ = self.rnn(x) 390 | x, attn = self.attention_layer(x) 391 | out = self.mlp(x) 392 | 393 | if return_embedding: 394 | return out, x 395 | 396 | return out 397 | 398 | @staticmethod 399 | def cnn_forward(x, strf_layer, conv2d_layer, residual_layer): 400 | """Forward until the beginning of linear layer. 401 | 402 | Deals with CNN, STRFNet, or Hybrid. 
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 nn.Conv2d with padding 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection."""
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        """Instantiate a residual block.

        Parameters
        ----------
        in_channels: int
            Number of input channels.
        out_channels: int
            Number of output channels.
        stride: int
            Stride of the first convolution.
        downsample: nn.Module, None
            Optional module applied to the input on the skip path so that it
            matches the shape of the convolutional path.

        """
        super(ResidualBlock, self).__init__()
        self.convlayers = nn.Sequential(
            conv3x3(in_channels, out_channels, stride),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            conv3x3(out_channels, out_channels),
            nn.BatchNorm2d(out_channels)
        )
        self.downsample = downsample

    def forward(self, x):
        """Return relu(skip(x) + convlayers(x))."""
        # Compare against None explicitly: truth-testing an nn.Sequential
        # goes through its length rather than its presence.
        residual = x if self.downsample is None else self.downsample(x)
        return torch.relu(residual + self.convlayers(x))


class ModResnet(nn.Module):
    """Modified ResNet from the PyTorch tutorial."""
    def __init__(self, in_chan, res_chans, pool=True):
        """Instantiate a series of residual blocks.

        One stride-1 stage with res_chans[0] channels is built first, then
        one stride-2 stage for every entry of res_chans (the first entry
        included), i.e. len(res_chans)+1 stages of two blocks each.

        Parameters
        ----------
        in_chan: int
            Input channel number
        res_chans: list(int)
            Channel number for each residual block.
        pool: bool
            If True, average-pool with an (8, 5) window and flatten the
            result to one vector per example in forward().

        Raises
        ------
        ValueError
            If res_chans is empty.

        """
        super(ModResnet, self).__init__()
        self.in_channels = in_chan
        if len(res_chans) == 0:
            raise ValueError("Requires at least one residual block!")
        res_layers = [self.make_layer(ResidualBlock, res_chans[0], 2)]
        for cc in res_chans:
            res_layers.append(self.make_layer(ResidualBlock, cc, 2, 2))

        self.res_layers = nn.Sequential(*res_layers)
        if pool:
            self.avg_pool = nn.AvgPool2d((8, 5))
        self.pool = pool

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of `blocks` residual blocks.

        Only the first block applies `stride` and the channel change; it gets
        a conv + batch-norm skip path whenever the shape changes. Updates
        self.in_channels to `out_channels` for the next stage.
        """
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        layers.extend(
            block(out_channels, out_channels) for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run all residual stages, then optionally pool and flatten."""
        out = self.res_layers(x)
        if self.pool:  # average pool and then flatten out to single vector
            out = self.avg_pool(out)
            out = out.view(out.size(0), -1)

        return out