├── .gitignore
├── LICENSE
├── README.md
├── pcen
    ├── __init__.py
    ├── f2m.py
    └── pcen.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Ralph Tang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-PCEN
Efficient PyTorch reimplementation of [per-channel energy normalization](https://arxiv.org/pdf/1607.05666.pdf) with Mel
spectrogram features.

## Overview

Robustness to loudness differences in near- and far-field conditions is critical in high-quality speech recognition applications.
Spectrogram energies differ significantly between, say, shouting at arm's length and whispering from a distance.
This can worsen model quality, since the model itself would need to be robust across a wide range of inputs.
The log-compression step in the popular log-Mel transform partially addresses this issue by reducing the dynamic range of audio;
however, it ignores per-channel energy differences and is static by definition.

[Per-channel energy normalization](https://arxiv.org/pdf/1607.05666.pdf) is one solution to these problems.
It provides a per-channel, trainable front-end in place of the log compression, greatly improving model robustness in keyword spotting systems while remaining resource-efficient and easy to implement.

## Installation and Usage
1. PyTorch and NumPy are required. LibROSA and matplotlib are required only for the example.
2. To install via pip, run `pip install git+https://github.com/daemon/pytorch-pcen`. Otherwise, clone this repository and run `python setup.py install`.
3. To run the example in the module, place a 16 kHz WAV file named `yes.wav` in the current directory. Then run `python -m pcen.pcen`.

The following is a self-contained example of using a streaming PCEN layer:
```python
import pcen
import torch

# 40-dimensional features, 30-millisecond window, 10-millisecond shift;
# trainable defaults to False and is enabled explicitly here
transform = pcen.StreamingPCENTransform(n_mels=40, n_fft=480, hop_length=160, trainable=True)
audio = torch.empty(1, 16000).normal_(0, 0.1) # Gaussian noise

# 1600 is an arbitrary chunk size; chunking is unnecessary but demonstrates the streaming nature
streaming_chunks = audio.split(1600, 1)
pcen_chunks = [transform(chunk) for chunk in streaming_chunks] # Transform each chunk
transform.reset() # Reset the persistent streaming state
pcen_ = torch.cat(pcen_chunks, 1)
```

## Citation
Wang, Yuxuan, Pascal Getreuer, Thad Hughes, Richard F. Lyon, and Rif A. Saurous. [Trainable frontend for robust and far-field keyword spotting](https://arxiv.org/pdf/1607.05666.pdf).
In _Acoustics, Speech and Signal Processing (ICASSP), 2017 IEEE International Conference on_, pp. 5670-5674. IEEE, 2017.
```tex
@inproceedings{wang2017trainable,
  title={Trainable frontend for robust and far-field keyword spotting},
  author={Wang, Yuxuan and Getreuer, Pascal and Hughes, Thad and Lyon, Richard F and Saurous, Rif A},
  booktitle={Acoustics, Speech and Signal Processing (ICASSP), 2017 IEEE International Conference on},
  pages={5670--5674},
  year={2017},
  organization={IEEE}
}
```

--------------------------------------------------------------------------------
/pcen/__init__.py:
--------------------------------------------------------------------------------
from .pcen import *
--------------------------------------------------------------------------------
/pcen/f2m.py:
--------------------------------------------------------------------------------
"""
BSD 2-Clause License

Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import numpy as np
import torch
import torch.nn as nn


class F2M(nn.Module):
    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
    matrix. This uses triangular filter banks.
    Args:
        n_mels (int): number of MEL bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: sr // 2
        f_min (float): minimum frequency. default: 0
        n_fft (int): FFT size of the input STFT. default: 40
        onesided (bool): whether the input STFT is one-sided, i.e. has
            n_fft // 2 + 1 frequency bins. default: True
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=40, onesided=True):
        super().__init__()
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min
        self.n_fft = n_fft
        if onesided:
            # From here on, self.n_fft holds the number of frequency bins
            self.n_fft = self.n_fft // 2 + 1
        self._init_buffers()

    def _init_buffers(self):
        # Convert the frequency range to the Mel scale
        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        # Equally spaced points on the Mel scale, mapped back to Hz and then to FFT bins
        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        f_pts = (700 * (10**(m_pts / 2595) - 1))

        bins = torch.floor(((self.n_fft - 1) * 2) * f_pts / self.sr).long()

        # Build one triangular filter per Mel bin
        fb = torch.zeros(self.n_fft, self.n_mels)
        for m in range(1, self.n_mels + 1):
            f_m_minus = bins[m - 1].item()
            f_m = bins[m].item()
            f_m_plus = bins[m + 1].item()

            # A float arange keeps the ramps fractional regardless of how the
            # installed PyTorch version divides integer tensors
            if f_m_minus != f_m:
                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m, dtype=torch.float) - f_m_minus) / (f_m - f_m_minus)
            if f_m != f_m_plus:
                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus, dtype=torch.float)) / (f_m_plus - f_m)
        self.register_buffer("fb", fb)

    def forward(self, spec_f):
        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m
--------------------------------------------------------------------------------
/pcen/pcen.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .f2m import F2M


def pcen(x, eps=1E-6, s=0.025, alpha=0.98, delta=2, r=0.5, training=False, last_state=None, empty=True):
    # x holds one energy frame per step along dim -2; M is its smoothed version
    frames = x.split(1, -2)
    m_frames = []
    if empty:
        # No streaming state yet: the first frame initializes the smoother
        last_state = None
    for frame in frames:
        if last_state is None:
            last_state = frame
            m_frames.append(frame)
            continue
        # First-order IIR smoother: M[t] = (1 - s) * M[t - 1] + s * x[t]
        if training:
            m_frame = ((1 - s) * last_state).add_(s * frame)
        else:
            m_frame = (1 - s) * last_state + s * frame
        last_state = m_frame
        m_frames.append(m_frame)
    M = torch.cat(m_frames, 1)
    # PCEN: (x / (M + eps)^alpha + delta)^r - delta^r
    if training:
        pcen_ = (x / (M + eps).pow(alpha) + delta).pow(r) - delta ** r
    else:
        # In-place ops avoid extra allocations but overwrite x and M
        pcen_ = x.div_(M.add_(eps).pow_(alpha)).add_(delta).pow_(r).sub_(delta ** r)
    return pcen_, last_state


class StreamingPCENTransform(nn.Module):

    def __init__(self, eps=1E-6, s=0.025,
                 alpha=0.98, delta=2, r=0.5, trainable=False,
                 use_cuda_kernel=False, **stft_kwargs):
        super().__init__()
        self.use_cuda_kernel = use_cuda_kernel
        if trainable:
            self.s = nn.Parameter(torch.Tensor([s]))
            self.alpha = nn.Parameter(torch.Tensor([alpha]))
            self.delta = nn.Parameter(torch.Tensor([delta]))
            self.r = nn.Parameter(torch.Tensor([r]))
        else:
            self.s = s
            self.alpha = alpha
            self.delta = delta
            self.r = r
        self.eps = eps
        self.trainable = trainable
        self.stft_kwargs = stft_kwargs
        self.register_buffer("last_state", torch.zeros(stft_kwargs["n_mels"]))
        # Split the keyword arguments between the Mel filter bank (F2M) and torch.stft
        mel_keys = {"n_mels", "sr", "f_max", "f_min", "n_fft"}
        mel_keys = set(stft_kwargs.keys()).intersection(mel_keys)
        mel_kwargs = {k: stft_kwargs[k] for k in mel_keys}
        stft_keys = set(stft_kwargs.keys()) - mel_keys
        self.n_fft = stft_kwargs["n_fft"]
        self.stft_kwargs = {k: stft_kwargs[k] for k in stft_keys}
        self.f2m = F2M(**mel_kwargs)
        self.reset()

    def reset(self):
        # Mark the streaming state as empty; the next forward() call re-initializes it
        self.empty = True

    def forward(self, x):
        # NOTE: .norm(dim=-1) assumes torch.stft returns a real tensor with a trailing
        # (real, imaginary) dimension, as in older PyTorch releases
        x = torch.stft(x, self.n_fft, **self.stft_kwargs).norm(dim=-1, p=2)
        x = self.f2m(x.permute(0, 2, 1))
        if self.use_cuda_kernel:
            # NOTE: pcen_cuda_kernel is not defined in this file
            x, ls = pcen_cuda_kernel(x, self.eps, self.s, self.alpha, self.delta, self.r, self.trainable, self.last_state, self.empty)
        else:
            x, ls = pcen(x, self.eps, self.s, self.alpha, self.delta, self.r, self.training and self.trainable, self.last_state, self.empty)
        # Keep the smoother state for the next chunk, detached from the graph
        self.last_state = ls.detach()
        self.empty = False
        return x


if __name__ == "__main__":
    import time
    import librosa
    import librosa.display
    import matplotlib.pyplot as plt
    transform = StreamingPCENTransform(n_mels=40, n_fft=480, hop_length=160).cuda()
    x = torch.tensor(librosa.core.load("yes.wav", sr=16000)[0]).unsqueeze(0).cuda()
    n = 200

    # Non-streaming
    a = time.perf_counter()
    for _ in range(n):
        y = transform(x)
        transform.reset()
    b = time.perf_counter()
    print("{:.2} ms per second of audio.".format((b - a) / n * 1000))

    # Streaming in chunks of 1600
    x_chunks = x.split(1600, 1)
    a = time.perf_counter()
    for _ in range(n):
        y_chunks = list(map(transform, x_chunks))
        transform.reset()
    b = time.perf_counter()
    print("{:.2} ms per second of audio.".format((b - a) / n * 1000))

    librosa.display.specshow(y[0].cpu().numpy().T)
    plt.title("Non-streaming")
    plt.show()

    librosa.display.specshow(torch.cat(y_chunks, 1)[0].cpu().numpy().T)
    plt.title("Streaming")
    plt.show()

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools


setuptools.setup(
    name="pytorch-pcen",
    version="0.0.1",
    author="Ralph Tang",
    author_email="r33tang@uwaterloo.ca",
    description="Efficient implementation of per-channel energy normalization.",
    install_requires=["numpy"],
    url="https://github.com/daemon/pytorch-pcen",
    packages=setuptools.find_packages(),
    classifiers=(
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ),
)

--------------------------------------------------------------------------------
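The recurrence that `pcen()` in `/pcen/pcen.py` applies to each Mel channel can be hard to see through the tensor operations. The following is a minimal pure-Python sketch of the same arithmetic for a single channel, reusing the default values from `pcen()`. The name `pcen_reference` is hypothetical and not part of the package; the sketch takes a plain list of non-negative energies and is meant for checking intuition, not as a replacement for the PyTorch code.

```python
def pcen_reference(energies, eps=1e-6, s=0.025, alpha=0.98, delta=2.0, r=0.5):
    """Apply PCEN to one channel's sequence of non-negative energies.

    Hypothetical helper for illustration only; it mirrors the arithmetic of
    pcen() in pcen/pcen.py but works on a plain Python list.
    """
    smoothed = None
    out = []
    for e in energies:
        # The first frame initializes the smoother, as pcen() does when the
        # streaming state is empty; later frames use
        # M[t] = (1 - s) * M[t - 1] + s * E[t].
        smoothed = e if smoothed is None else (1 - s) * smoothed + s * e
        # PCEN[t] = (E[t] / (M[t] + eps)^alpha + delta)^r - delta^r
        out.append((e / (smoothed + eps) ** alpha + delta) ** r - delta ** r)
    return out


if __name__ == "__main__":
    # A jump in energy produces a large output that decays as the smoother catches up.
    print(pcen_reference([1.0, 1.0, 4.0, 4.0, 4.0]))
```

With `alpha` set to 1, a constant input yields almost the same output at any loudness, because the energy is divided by its own smoothed value. This is the gain normalization that the README's overview describes.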