├── .coveragerc
├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── tests
│   ├── __init__.py
│   ├── test_augmentations.py
│   ├── test_compose.py
│   ├── test_readme_example.py
│   ├── test_utils.py
│   └── utils.py
└── torchaudio_augmentations
    ├── __init__.py
    ├── apply.py
    ├── augmentations
    │   ├── __init__.py
    │   ├── delay.py
    │   ├── filter.py
    │   ├── gain.py
    │   ├── high_low_pass.py
    │   ├── noise.py
    │   ├── pitch_shift.py
    │   ├── polarity_inversion.py
    │   ├── random_resized_crop.py
    │   ├── reverb.py
    │   └── reverse.py
    ├── compose.py
    └── utils.py
/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = torchaudio_augmentations 4 | 5 | [report] 6 | exclude_lines = 7 | if self.debug: 8 | pragma: no cover 9 | raise NotImplementedError 10 | if __name__ == .__main__.: 11 | ignore_errors = True 12 | omit = 13 | tests/* 14 | setup.py -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: ci 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | test: 14 | 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.9 22 | - name: Install dependencies 23 | run: | 24 | sudo apt-get install libsndfile1 25 | python -m pip install --upgrade pip 26 | pip install -e . 27 | pip install -e .[test] 28 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 29 | - name: Lint with black 30 | run: | 31 | black --check . 32 | - name: Test with pytest and generate coverage file 33 | run: | 34 | pytest --cov=./ --cov-report=xml 35 | - name: Upload coverage to Codecov 36 | uses: codecov/codecov-action@v2 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | audio_augmentations.egg-info/ 2 | build/ 3 | dist/ 4 | __pycache__/ 5 | .pytest_cache 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Janne Spijkervet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Audio Augmentations 2 | ![CI status](https://github.com/spijkervet/torchaudio-augmentations/actions/workflows/ci.yml/badge.svg) 3 | [![codecov](https://codecov.io/gh/Spijkervet/torchaudio-augmentations/branch/master/graph/badge.svg?token=0DEFJYJH5K)](https://codecov.io/gh/Spijkervet/torchaudio-augmentations) 4 | [![Downloads](https://pepy.tech/badge/torchaudio-augmentations)](https://pepy.tech/project/torchaudio-augmentations) 5 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4748582.svg)](https://zenodo.org/record/4748582#) 6 | 7 | Audio data augmentation library for PyTorch, for audio in the time domain. The focus of this repository is to: 8 | - Provide many audio transformations in an easy Python interface. 9 | - Have high test coverage. 10 | - Easily control stochastic (sequential) audio transformations. 11 | - Make every audio transformation differentiable with PyTorch's `nn.Module`. 12 | - Optimise audio transformations for CPU and GPU. 13 | 14 | It supports the stochastic transformations often used in self-supervised and semi-supervised learning methods. From a single interface, one can apply a single stochastic augmentation or create many stochastically transformed versions of the same audio example. 15 | 16 | This package follows the conventions set out by `torchvision` and `torchaudio`, with audio defined as a tensor of `[channel, time]`, or a batched representation `[batch, channel, time]`. Each individual augmentation can be initialized on its own, or be wrapped in a `RandomApply` interface, which applies the augmentation with probability `p`. 17 | 18 | 19 | ## Usage 20 | We can define a single audio augmentation, or several, which are applied sequentially to an audio waveform. 21 | ```python 22 | import torchaudio 23 | from torchaudio_augmentations import * 24 | audio, sr = torchaudio.load("tests/classical.00002.wav") 25 | 26 | num_samples = sr * 5 27 | transforms = [ 28 | RandomResizedCrop(n_samples=num_samples), 29 | RandomApply([PolarityInversion()], p=0.8), 30 | RandomApply([Noise(min_snr=0.001, max_snr=0.005)], p=0.3), 31 | RandomApply([Gain()], p=0.2), 32 | HighLowPass(sample_rate=sr), # this augmentation will always be applied in this augmentation chain! 33 | RandomApply([Delay(sample_rate=sr)], p=0.5), 34 | RandomApply([PitchShift( 35 | n_samples=num_samples, 36 | sample_rate=sr 37 | )], p=0.4), 38 | RandomApply([Reverb(sample_rate=sr)], p=0.3) 39 | ] 40 | ``` 41 | 
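Because this chain is stochastic, two calls on the same waveform will generally differ. As a minimal reproducibility sketch (an editorial addition, not part of the library): `RandomApply` draws from PyTorch's RNG, while transforms such as `Gain`, `Delay` and `Noise` draw from Python's built-in `random` module, so both generators should be seeded.

```python
import random

import torch
import torchaudio
from torchaudio_augmentations import Compose

# RandomApply samples from torch.rand; most individual transforms
# sample their parameters from Python's `random` module.
random.seed(42)
torch.manual_seed(42)

audio, sr = torchaudio.load("tests/classical.00002.wav")
transform = Compose(transforms=transforms)  # `transforms` as defined above
transformed_audio = transform(audio)
```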
42 | We can also define a stochastic augmentation on multiple transformations. The following applies polarity inversion and white noise together with a probability of 80%, gain with a probability of 20%, and delay and reverb together with a probability of 50%: 43 | ```python 44 | transforms = [ 45 | RandomResizedCrop(n_samples=num_samples), 46 | RandomApply([PolarityInversion(), Noise(min_snr=0.001, max_snr=0.005)], p=0.8), 47 | RandomApply([Gain()], p=0.2), 48 | RandomApply([Delay(sample_rate=sr), Reverb(sample_rate=sr)], p=0.5) 49 | ] 50 | ``` 51 | 52 | We can return either one or many versions of the same audio example: 53 | ```python 54 | transform = Compose(transforms=transforms) 55 | transformed_audio = transform(audio) 56 | >> transformed_audio.shape = [num_channels, num_samples] 57 | ``` 58 | 59 | ```python 60 | audio, sr = torchaudio.load("tests/classical.00002.wav") 61 | transform = ComposeMany(transforms=transforms, num_augmented_samples=4) 62 | transformed_audio = transform(audio) 63 | >> transformed_audio.shape = [4, num_channels, num_samples] 64 | ``` 65 | 66 | Similar to the `torchvision.datasets` interface, an instance of the `Compose` or `ComposeMany` class can be supplied to datasets and data loaders that accept a `transform=` argument (a minimal dataset sketch follows at the end of this README). 67 | 68 | 69 | ## Optional 70 | Install WavAugment for reverberation / pitch shifting: 71 | ``` 72 | pip install git+https://github.com/facebookresearch/WavAugment 73 | ``` 74 | 75 | # Cite 76 | You can cite this work with the following BibTeX: 77 | ``` 78 | @misc{spijkervet_torchaudio_augmentations, 79 | doi = {10.5281/ZENODO.4748582}, 80 | url = {https://zenodo.org/record/4748582}, 81 | author = {Spijkervet, Janne}, 82 | title = {Spijkervet/torchaudio-augmentations}, 83 | publisher = {Zenodo}, 84 | year = {2021}, 85 | copyright = {MIT License} 86 | } 87 | ``` 88 | 
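As a usage sketch (an editorial addition, not code from this repository), a `Compose` instance can be passed to a custom dataset's `transform=` argument; `MyAudioDataset` and the file list here are hypothetical:

```python
import torch
import torchaudio
from torchaudio_augmentations import Compose, RandomResizedCrop


class MyAudioDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that applies a transform to every loaded clip."""

    def __init__(self, filepaths, transform=None):
        self.filepaths = filepaths
        self.transform = transform

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, index):
        audio, sr = torchaudio.load(self.filepaths[index])
        if self.transform is not None:
            audio = self.transform(audio)
        return audio


transform = Compose([RandomResizedCrop(n_samples=22050 * 5)])
dataset = MyAudioDataset(["tests/classical.00002.wav"], transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
```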
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = "torchaudio-augmentations" 16 | DESCRIPTION = "Audio augmentations library for PyTorch, for audio in the time-domain." 17 | URL = "https://github.com/spijkervet/torchaudio-augmentations" 18 | EMAIL = "janne.spijkervet@gmail.com" 19 | AUTHOR = "Janne Spijkervet" 20 | REQUIRES_PYTHON = ">=3.6.0" 21 | VERSION = "0.2.4" 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = ["numpy", "torch", "torchaudio", "julius", "wavaugment", "torch-pitch-shift"] 25 | TEST_REQUIRED = ["pytest", "pytest-cov", "black", "librosa"] 26 | 27 | # What packages are optional? 28 | EXTRAS = { 29 | "fancy feature": [""], 30 | "test": TEST_REQUIRED, 31 | } 32 | 33 | # The rest you shouldn't have to touch too much :) 34 | # ------------------------------------------------ 35 | # Except, perhaps the License and Trove Classifiers! 36 | # If you do change the License, remember to change the Trove Classifier for that! 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 | 40 | # Import the README and use it as the long-description. 41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 42 | try: 43 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 44 | long_description = "\n" + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Load the package's __version__.py module as a dictionary. 49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 53 | exec(f.read(), about) 54 | else: 55 | about["__version__"] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | """Support setup.py upload.""" 60 | 61 | description = "Build and publish the package." 62 | user_options = [] 63 | 64 | @staticmethod 65 | def status(s): 66 | """Prints things in bold.""" 67 | print("\033[1m{0}\033[0m".format(s)) 68 | 69 | def initialize_options(self): 70 | pass 71 | 72 | def finalize_options(self): 73 | pass 74 | 75 | def run(self): 76 | try: 77 | self.status("Removing previous builds…") 78 | rmtree(os.path.join(here, "dist")) 79 | except OSError: 80 | pass 81 | 82 | self.status("Building Source and Wheel (universal) distribution…") 83 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 84 | 85 | self.status("Uploading the package to PyPI via Twine…") 86 | os.system("twine upload dist/*") 87 | 88 | self.status("Pushing git tags…") 89 | os.system("git tag v{0}".format(about["__version__"])) 90 | os.system("git push --tags") 91 | 92 | sys.exit() 93 | 94 | 95 | # Where the magic happens: 96 | setup( 97 | name=NAME, 98 | version=about["__version__"], 99 | description=DESCRIPTION, 100 | long_description=long_description, 101 | long_description_content_type="text/markdown", 102 | author=AUTHOR, 103 | author_email=EMAIL, 104 | python_requires=REQUIRES_PYTHON, 105 | url=URL, 106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 107 | # If your package is a single module, use this instead of 'packages': 108 | # py_modules=['mypackage'], 109 | # entry_points={ 110 | # 'console_scripts': ['mycli=mymodule:cli'], 111 | # }, 112 | install_requires=REQUIRED, 113 | extras_require=EXTRAS, 114 | include_package_data=True, 115 | license="MIT", 116 | classifiers=[ 117 | # Trove classifiers 118 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 119 | "License :: OSI Approved :: MIT License", 120 | "Programming Language :: Python", 121 | "Programming Language :: Python :: 3", 122 | "Programming Language :: Python :: 3.6", 123 | "Programming Language :: Python :: Implementation :: CPython", 124 | "Programming Language :: Python :: Implementation :: PyPy", 125 | ], 126 | # $ setup.py publish support. 
127 | cmdclass={ 128 | "upload": UploadCommand, 129 | }, 130 | ) 131 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/torchaudio-augmentations/891b3b6e19551c211e7cdab36376c7e67e9d199c/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_augmentations.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | import pytest 4 | import random 5 | import numpy as np 6 | from unittest import mock 7 | 8 | from torchaudio_augmentations import ( 9 | RandomApply, 10 | Compose, 11 | RandomResizedCrop, 12 | PolarityInversion, 13 | Noise, 14 | Gain, 15 | HighLowPass, 16 | Delay, 17 | PitchShift, 18 | Reverb, 19 | Reverse, 20 | LowPassFilter, 21 | HighPassFilter, 22 | ) 23 | from .utils import generate_waveform 24 | 25 | 26 | sample_rate = 22050 27 | num_samples = sample_rate * 5 28 | 29 | 30 | @pytest.mark.parametrize("seed", range(10)) 31 | @pytest.mark.parametrize("p", (0, 1)) 32 | @pytest.mark.parametrize("num_channels", (1, 2)) 33 | def test_random_apply(p, seed, num_channels): 34 | torch.manual_seed(seed) 35 | transform = RandomApply([PolarityInversion()], p=p) 36 | 37 | # num_channels is provided by the parametrize decorator above 38 | audio = generate_waveform(sample_rate, num_samples, num_channels) 39 | 40 | t_audio = transform(audio) 41 | if p == 0: 42 | assert torch.eq(t_audio, audio).all() == True 43 | elif p == 1: 44 | assert torch.eq(t_audio, audio).all() == False 45 | 46 | # Checking if RandomApply can be printed as string 47 | transform.__repr__() 48 | 49 | 50 | @pytest.mark.parametrize("num_channels", (1, 2)) 51 | def test_compose(num_channels): 52 | 53 | audio = generate_waveform(sample_rate, num_samples, num_channels) 54 | transforms = Compose([PolarityInversion(), Gain(min_gain=-20, max_gain=-19)]) 55 | s_transforms = torch.nn.Sequential(*transforms.transforms) 56 | random.seed(42) 57 | t_audio = transforms(audio) 58 | random.seed(42) 59 | t_audio_sequential_pass = s_transforms(audio) 60 | 61 | assert torch.eq(t_audio, t_audio_sequential_pass).all() 62 | 63 | t = Compose( 64 | [ 65 | lambda x: x, 66 | ] 67 | ) 68 | with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"): 69 | torch.jit.script(t) 70 | 71 | # check if we can print Compose() as string 72 | t.__repr__() 73 | 74 | 75 | @pytest.mark.parametrize("num_channels", [1, 2]) 76 | def test_random_resized_crop(num_channels): 77 | num_samples = 22050 * 5 78 | audio = generate_waveform(sample_rate, num_samples, num_channels) 79 | transform = Compose([RandomResizedCrop(num_samples)]) 80 | 81 | audio = transform(audio) 82 | assert audio.shape[0] == num_channels 83 | assert audio.shape[1] == num_samples 84 | 85 | 86 | @pytest.mark.parametrize( 87 | ["batch_size", "num_channels"], 88 | [ 89 | (1, 1), 90 | (4, 1), 91 | (16, 1), 92 | (1, 2), 93 | (4, 2), 94 | (16, 2), 95 | ], 96 | ) 97 | def test_random_resized_crop_batched(batch_size, num_channels): 98 | 99 | num_samples = 22050 * 5 100 | audio = generate_waveform(sample_rate, num_samples, num_channels) 101 | audio = audio.repeat(batch_size, 1, 1) 102 | 103 | transform = Compose([RandomResizedCrop(num_samples)]) 104 | 105 | audio = transform(audio) 106 | assert audio.shape[0] == batch_size 107 | assert audio.shape[1] == num_channels 108 | assert audio.shape[2] == num_samples 109 | 110 | 111 | 
@pytest.mark.parametrize("num_channels", [1, 2]) 112 | def test_polarity(num_channels): 113 | audio = generate_waveform(sample_rate, num_samples, num_channels=num_channels) 114 | transform = Compose( 115 | [PolarityInversion()], 116 | ) 117 | 118 | t_audio = transform(audio) 119 | assert torch.eq(t_audio, torch.neg(audio)).all() 120 | assert t_audio.shape == audio.shape 121 | 122 | audio = torch.Tensor([5, 6, -1, 3]) 123 | expected_audio = torch.Tensor([-5, -6, 1, -3]) 124 | t_audio = transform(audio) 125 | assert torch.eq(expected_audio, t_audio).all() 126 | 127 | mono_channel_audio = torch.Tensor([[5, 6, -1, 3]]) 128 | expected_audio = torch.Tensor([[-5, -6, 1, -3]]) 129 | t_audio = transform(mono_channel_audio) 130 | assert torch.eq(expected_audio, t_audio).all() 131 | 132 | stereo_channel_audio = mono_channel_audio.repeat(2, 1) 133 | expected_audio = expected_audio.repeat(2, 1) 134 | t_audio = transform(stereo_channel_audio) 135 | assert torch.eq(expected_audio, t_audio).all() 136 | 137 | batched_stereo_channel_audio = stereo_channel_audio.repeat(16, 1, 1) 138 | expected_audio = expected_audio.repeat(16, 1, 1) 139 | t_audio = transform(batched_stereo_channel_audio) 140 | assert torch.eq(expected_audio, t_audio).all() 141 | 142 | 143 | @pytest.mark.parametrize("num_channels", [1, 2]) 144 | def test_filter(num_channels): 145 | audio = generate_waveform(sample_rate, num_samples, num_channels) 146 | transform = Compose( 147 | [HighLowPass(sample_rate=sample_rate)], 148 | ) 149 | t_audio = transform(audio) 150 | # torchaudio.save("tests/filter.wav", t_audio, sample_rate=sample_rate) 151 | assert t_audio.shape == audio.shape 152 | 153 | 154 | @pytest.mark.parametrize("num_channels", [1, 2]) 155 | def test_delay(num_channels): 156 | audio = generate_waveform(sample_rate, num_samples, num_channels) 157 | transform = Compose( 158 | [Delay(sample_rate=sample_rate)], 159 | ) 160 | 161 | t_audio = transform(audio) 162 | # torchaudio.save("tests/delay.wav", t_audio, sample_rate=sample_rate) 163 | assert t_audio.shape == audio.shape 164 | 165 | 166 | @pytest.mark.parametrize("num_channels", [1, 2]) 167 | def test_gain(num_channels): 168 | audio = generate_waveform(sample_rate, num_samples, num_channels) 169 | transform = Compose( 170 | [Gain()], 171 | ) 172 | 173 | t_audio = transform(audio) 174 | # torchaudio.save("tests/gain.wav", t_audio, sample_rate=sample_rate) 175 | assert t_audio.shape == audio.shape 176 | 177 | 178 | @pytest.mark.parametrize("num_channels", [1, 2]) 179 | def test_noise(num_channels): 180 | audio = generate_waveform(sample_rate, num_samples, num_channels) 181 | transform = Compose( 182 | [Noise(min_snr=0.5, max_snr=1)], 183 | ) 184 | 185 | t_audio = transform(audio) 186 | # torchaudio.save("tests/noise.wav", t_audio, sample_rate=sample_rate) 187 | assert t_audio.shape == audio.shape 188 | 189 | 190 | @pytest.mark.parametrize("batch_size", [1, 2, 16]) 191 | @pytest.mark.parametrize("num_channels", [1, 2]) 192 | def test_pitch(batch_size, num_channels): 193 | audio = generate_waveform(sample_rate, num_samples, num_channels) 194 | 195 | if batch_size > 1: 196 | audio = audio.unsqueeze(dim=0) 197 | audio = audio.repeat(batch_size, 1, 1) 198 | 199 | transform = Compose( 200 | [PitchShift(n_samples=num_samples, sample_rate=sample_rate)], 201 | ) 202 | 203 | t_audio = transform(audio) 204 | # torchaudio.save("tests/pitch.wav", audio, sample_rate=sample_rate) 205 | # torchaudio.save("tests/t_pitch.wav", t_audio, sample_rate=sample_rate) 206 | assert t_audio.shape == audio.shape 207 | 
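(Editorial note.) The pitch-detection test below shifts a 440 Hz tone up by four semitones and expects a fundamental near 554 Hz; in equal temperament the shifted frequency is f' = f * 2^(n/12). A quick check of the test's constants:

```python
# Equal-temperament pitch shift: f' = f * 2**(n / 12)
source_frequency = 440  # Hz
semitones = 4
shifted_frequency = source_frequency * 2 ** (semitones / 12)
print(round(shifted_frequency))  # -> 554
```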
208 | 209 | def test_pitch_shift_transform_with_pitch_detection(): 210 | """To check semi-tone values, check: http://www.homepages.ucl.ac.uk/~sslyjjt/speech/semitone.html""" 211 | 212 | source_frequency = 440 213 | max_semitone_shift = 4 214 | expected_frequency_shift = 554 215 | 216 | num_channels = 1 217 | audio = generate_waveform( 218 | sample_rate, num_samples, num_channels, frequency=source_frequency 219 | ) 220 | pitch_shift = PitchShift( 221 | n_samples=num_samples, 222 | sample_rate=sample_rate, 223 | pitch_shift_min=max_semitone_shift, 224 | pitch_shift_max=max_semitone_shift + 1, 225 | ) 226 | 227 | t_audio = pitch_shift(audio) 228 | librosa_audio = t_audio[0].numpy() 229 | f0_hz, _, _ = librosa.pyin(librosa_audio, fmin=10, fmax=1000) 230 | 231 | # remove nan values: 232 | f0_hz = f0_hz[~np.isnan(f0_hz)] 233 | 234 | detected_f0_hz = np.max(f0_hz) 235 | 236 | detection_threshold_in_hz = 40 237 | # the detected frequency should differ from the expected frequency by less than 40 Hz. 238 | assert abs(detected_f0_hz - expected_frequency_shift) < detection_threshold_in_hz 239 | 240 | 241 | @pytest.mark.parametrize("num_channels", [1, 2]) 242 | def test_reverb(num_channels): 243 | audio = generate_waveform(sample_rate, num_samples, num_channels) 244 | transform = Compose( 245 | [Reverb(sample_rate=sample_rate)], 246 | ) 247 | 248 | t_audio = transform(audio) 249 | # torchaudio.save("tests/reverb.wav", t_audio, sample_rate=sample_rate) 250 | assert t_audio.shape == audio.shape 251 | 252 | 253 | @pytest.mark.parametrize("num_channels", [1, 2]) 254 | def test_reverse(num_channels): 255 | stereo_audio = generate_waveform(sample_rate, num_samples, num_channels) 256 | transform = Compose( 257 | [Reverse()], 258 | ) 259 | 260 | t_audio = transform(stereo_audio) 261 | 262 | reversed_single_channel = torch.flip(t_audio, [1])[0] 263 | assert torch.equal(reversed_single_channel, stereo_audio[0]) == True 264 | 265 | reversed_stereo_channel = torch.flip(t_audio, [0])[0] 266 | assert torch.equal(reversed_stereo_channel, stereo_audio[0]) == False 267 | 268 | assert t_audio.shape == stereo_audio.shape 269 | 270 | mono_audio = stereo_audio.mean(dim=0) 271 | assert mono_audio.shape[0] == stereo_audio.shape[1] 272 | 273 | 274 | @pytest.mark.parametrize("batch_size", [1, 2, 16]) 275 | @pytest.mark.parametrize("num_channels", [1, 2]) 276 | def test_lowpass_filter(batch_size, num_channels): 277 | audio = generate_waveform(sample_rate, num_samples, num_channels) 278 | 279 | if batch_size > 1: 280 | audio = audio.unsqueeze(dim=0) 281 | audio = audio.repeat(batch_size, 1, 1) 282 | 283 | transform = LowPassFilter(sample_rate, freq_low=200, freq_high=2000) 284 | 285 | t_audio = transform(audio) 286 | assert t_audio.shape == audio.shape 287 | 288 | 289 | @pytest.mark.parametrize("batch_size", [1, 2, 16]) 290 | @pytest.mark.parametrize("num_channels", [1, 2]) 291 | def test_highpass_filter(batch_size, num_channels): 292 | audio = generate_waveform(sample_rate, num_samples, num_channels) 293 | 294 | if batch_size > 1: 295 | audio = audio.unsqueeze(dim=0) 296 | audio = audio.repeat(batch_size, 1, 1) 297 | 298 | transform = HighPassFilter(sample_rate, freq_low=200, freq_high=2000) 299 | 300 | t_audio = transform(audio) 301 | assert t_audio.shape == audio.shape 302 | 303 | 304 | @mock.patch("random.randint") 305 | @pytest.mark.parametrize("batch_size", [1, 2, 16]) 306 | @pytest.mark.parametrize("num_channels", [1, 2]) 307 | def test_high_low_pass_filter(randint_function, batch_size, num_channels): 308 | audio = 
generate_waveform(sample_rate, num_samples, num_channels) 309 | 310 | if batch_size > 1: 311 | audio = audio.unsqueeze(dim=0) 312 | audio = audio.repeat(batch_size, 1, 1) 313 | 314 | transform = HighLowPass(sample_rate) 315 | 316 | # let's test for the high pass filter 317 | randint_function.return_value = 0 318 | t_audio = transform(audio) 319 | assert t_audio.shape == audio.shape 320 | 321 | # let's test for the low pass filter 322 | randint_function.return_value = 1 323 | t_audio = transform(audio) 324 | assert t_audio.shape == audio.shape 325 | -------------------------------------------------------------------------------- /tests/test_compose.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchaudio_augmentations import ( 4 | Compose, 5 | ComposeMany, 6 | RandomResizedCrop, 7 | Reverb, 8 | ) 9 | 10 | from .utils import generate_waveform 11 | 12 | sample_rate = 22050 13 | num_samples = sample_rate * 5 14 | 15 | 16 | @pytest.mark.parametrize("num_channels", [1, 2]) 17 | def test_compose(num_channels): 18 | audio = generate_waveform(sample_rate, num_samples, num_channels) 19 | transform = Compose( 20 | [ 21 | RandomResizedCrop(num_samples), 22 | ] 23 | ) 24 | 25 | t_audio = transform(audio) 26 | assert t_audio.shape[0] == num_channels 27 | assert t_audio.shape[1] == num_samples 28 | 29 | 30 | @pytest.mark.parametrize("num_channels", [1, 2]) 31 | def test_compose_many(num_channels): 32 | num_augmented_samples = 4 33 | 34 | audio = generate_waveform(sample_rate, num_samples, num_channels) 35 | transform = ComposeMany( 36 | [ 37 | RandomResizedCrop(num_samples), 38 | Reverb(sample_rate), 39 | ], 40 | num_augmented_samples=num_augmented_samples, 41 | ) 42 | 43 | t_audio = transform(audio) 44 | assert t_audio.shape[0] == num_augmented_samples 45 | assert t_audio.shape[1] == num_channels 46 | assert t_audio.shape[2] == num_samples 47 | 48 | for n in range(1, num_augmented_samples): 49 | assert torch.all(t_audio[0].eq(t_audio[n])) == False 50 | -------------------------------------------------------------------------------- /tests/test_readme_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import numpy as np 4 | from torchaudio_augmentations import ( 5 | Compose, 6 | RandomResizedCrop, 7 | RandomApply, 8 | PolarityInversion, 9 | Noise, 10 | Gain, 11 | HighLowPass, 12 | Delay, 13 | PitchShift, 14 | Reverb, 15 | ) 16 | 17 | sr = 22050 18 | 19 | 20 | def sine(num_samples, sr): 21 | freq = 440 22 | sine = np.sin(2 * np.pi * np.arange(num_samples) * freq / sr).astype(np.float32) 23 | return torch.from_numpy(sine).reshape(1, -1) 24 | 25 | 26 | def test_readme_example(): 27 | num_samples = sr * 5 28 | audio = sine(num_samples, sr) 29 | transforms = [ 30 | RandomResizedCrop(n_samples=num_samples), 31 | RandomApply([PolarityInversion()], p=0.8), 32 | RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3), 33 | RandomApply([Gain()], p=0.2), 34 | RandomApply([HighLowPass(sample_rate=sr)], p=0.8), 35 | RandomApply([Delay(sample_rate=sr)], p=0.5), 36 | RandomApply([PitchShift(n_samples=num_samples, sample_rate=sr)], p=0.4), 37 | RandomApply([Reverb(sample_rate=sr)], p=0.3), 38 | ] 39 | 40 | transform = Compose(transforms=transforms) 41 | transformed_audio = transform(audio) 42 | assert transformed_audio.shape[0] == 1 43 | -------------------------------------------------------------------------------- /tests/test_utils.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchaudio_augmentations.utils import ( 4 | add_audio_batch_dimension, 5 | remove_audio_batch_dimension, 6 | tensor_has_valid_audio_batch_dimension, 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "tensor,expected_value", 12 | [ 13 | (torch.zeros(1), False), 14 | (torch.zeros(1, 48000), False), 15 | (torch.zeros(16, 48000), False), 16 | (torch.zeros(1, 1, 48000), True), 17 | (torch.zeros(16, 1, 48000), True), 18 | ], 19 | ) 20 | def test_tensor_has_valid_audio_batch_dimension(tensor, expected_value): 21 | 22 | assert tensor_has_valid_audio_batch_dimension(tensor) == expected_value 23 | 24 | 25 | def test_add_audio_batch_dimension(): 26 | tensor = torch.ones(1, 48000) 27 | expected_tensor = torch.ones(1, 1, 48000) 28 | 29 | tensor = add_audio_batch_dimension(tensor) 30 | assert torch.eq(tensor, expected_tensor).all() 31 | assert tensor_has_valid_audio_batch_dimension(tensor) == True 32 | 33 | tensor = torch.ones(48000) 34 | expected_tensor = torch.ones(1, 48000) 35 | 36 | tensor = add_audio_batch_dimension(tensor) 37 | assert torch.eq(tensor, expected_tensor).all() 38 | assert tensor_has_valid_audio_batch_dimension(tensor) == False 39 | 40 | 41 | def test_remove_audio_batch_dimension(): 42 | tensor = torch.ones(1, 1, 48000) 43 | expected_tensor = torch.ones(1, 48000) 44 | 45 | tensor = remove_audio_batch_dimension(tensor) 46 | assert torch.eq(tensor, expected_tensor).all() 47 | 48 | tensor = torch.ones(1, 48000) 49 | expected_tensor = torch.ones(48000) 50 | 51 | tensor = remove_audio_batch_dimension(tensor) 52 | assert torch.eq(tensor, expected_tensor).all() 53 | assert tensor_has_valid_audio_batch_dimension(tensor) == False 54 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def generate_waveform( 6 | sample_rate: int, 7 | num_samples: int, 8 | num_channels: int, 9 | frequency: int = 440, 10 | ) -> torch.Tensor: 11 | 12 | # Divide the waveform length into three parts: attack (1/10), decay (1/2), sustain (the remaining 4/10). 13 | attack_length = num_samples // 10 14 | decay_length = num_samples // 2 15 | sustain_length = num_samples - (attack_length + decay_length) 16 | sustain_value = 0.1 # Sustain amplitude, between 0 and 1 17 | 18 | # Build the attack, decay and sustain envelope segments. 
19 | attack = np.linspace(0, 1, num=attack_length) 20 | decay = np.linspace(1, sustain_value, num=decay_length) 21 | sustain = np.ones(sustain_length) * sustain_value 22 | attack_decay_sustain = np.concatenate((attack, decay, sustain)) 23 | 24 | wavedata = np.sin(2 * np.pi * np.arange(num_samples) * frequency / sample_rate) 25 | 26 | wavedata = wavedata * attack_decay_sustain 27 | 28 | if num_channels == 2: 29 | wavedata = np.array([wavedata, wavedata * 0.9]) 30 | return torch.from_numpy(wavedata.astype(np.float32)).reshape(num_channels, -1) 31 | -------------------------------------------------------------------------------- /torchaudio_augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .apply import RandomApply 2 | from .compose import Compose, ComposeMany 3 | from .augmentations.delay import Delay 4 | from .augmentations.gain import Gain 5 | from .augmentations.filter import HighPassFilter, LowPassFilter 6 | from .augmentations.high_low_pass import HighLowPass 7 | from .augmentations.noise import Noise 8 | from .augmentations.pitch_shift import PitchShift 9 | from .augmentations.polarity_inversion import PolarityInversion 10 | from .augmentations.random_resized_crop import RandomResizedCrop 11 | from .augmentations.reverb import Reverb 12 | from .augmentations.reverse import Reverse 13 | -------------------------------------------------------------------------------- /torchaudio_augmentations/apply.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | 3 | # Copyright (c) Soumith Chintala 2016, 4 | # All rights reserved. 5 | 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | 9 | # * Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | 12 | # * Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | 16 | # * Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | import torch 32 | 33 | 34 | class RandomApply(torch.nn.Module): 35 | """Apply randomly a list of transformations with a given probability. 36 | 37 | .. 
note:: 38 | In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of 39 | transforms as shown below: 40 | 41 | >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ 42 | >>> transforms.ColorJitter(), 43 | >>> ]), p=0.3) 44 | >>> scripted_transforms = torch.jit.script(transforms) 45 | 46 | Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require 47 | `lambda` functions or ``PIL.Image``. 48 | 49 | Args: 50 | transforms (list or tuple or torch.nn.Module): list of transformations 51 | p (float): probability 52 | """ 53 | 54 | def __init__(self, transforms, p=0.5): 55 | super().__init__() 56 | self.transforms = transforms 57 | self.p = p 58 | 59 | def forward(self, img): 60 | if self.p < torch.rand(1): 61 | return img 62 | for t in self.transforms: 63 | img = t(img) 64 | return img 65 | 66 | def __repr__(self): 67 | format_string = self.__class__.__name__ + "(" 68 | format_string += "\n p={}".format(self.p) 69 | for t in self.transforms: 70 | format_string += "\n" 71 | format_string += " {0}".format(t) 72 | format_string += "\n)" 73 | return format_string 74 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/torchaudio-augmentations/891b3b6e19551c211e7cdab36376c7e67e9d199c/torchaudio_augmentations/augmentations/__init__.py -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/delay.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class Delay(torch.nn.Module): 7 | def __init__( 8 | self, 9 | sample_rate, 10 | volume_factor=0.5, 11 | min_delay=200, 12 | max_delay=500, 13 | delay_interval=50, 14 | ): 15 | super().__init__() 16 | self.sample_rate = sample_rate 17 | self.volume_factor = volume_factor 18 | self.min_delay = min_delay 19 | self.max_delay = max_delay 20 | self.delay_interval = delay_interval 21 | 22 | def calc_offset(self, ms): 23 | return int(ms * (self.sample_rate / 1000)) 24 | 25 | def forward(self, audio): 26 | ms = random.choice( 27 | np.arange(self.min_delay, self.max_delay, self.delay_interval) 28 | ) 29 | 30 | offset = self.calc_offset(ms) 31 | beginning = torch.zeros(audio.shape[0], offset).to(audio.device) 32 | end = audio[:, :-offset] 33 | delayed_signal = torch.cat((beginning, end), dim=1) 34 | delayed_signal = delayed_signal * self.volume_factor 35 | audio = (audio + delayed_signal) / 2 36 | return audio 37 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/filter.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from julius.filters import highpass_filter, lowpass_filter 4 | 5 | 6 | class FrequencyFilter(torch.nn.Module): 7 | def __init__( 8 | self, 9 | sample_rate: int, 10 | freq_low: float, 11 | freq_high: float, 12 | ): 13 | super().__init__() 14 | self.sample_rate = sample_rate 15 | self.freq_low = freq_low 16 | self.freq_high = freq_high 17 | 18 | def cutoff_frequency(self, frequency: float) -> float: 19 | return frequency / self.sample_rate 20 | 21 | def sample_uniform_frequency(self): 22 | return random.uniform(self.freq_low, self.freq_high) 
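(Editorial note.) The `julius` filters imported above take their cutoff as a fraction of the sample rate, which is exactly what `cutoff_frequency` computes. A minimal sketch of the normalization, assuming a 22050 Hz sample rate:

```python
import torch
from julius.filters import lowpass_filter

sample_rate = 22050
cutoff_hz = 2000
# julius expects a normalized cutoff, i.e. frequency / sample_rate
normalized_cutoff = cutoff_hz / sample_rate  # ~0.091

audio = torch.randn(1, sample_rate)  # one second of noise, [channel, time]
filtered = lowpass_filter(audio, cutoff=normalized_cutoff)
```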
23 | 24 | 25 | class HighPassFilter(FrequencyFilter): 26 | def __init__( 27 | self, 28 | sample_rate: int, 29 | freq_low: float = 200, 30 | freq_high: float = 1200, 31 | ): 32 | super().__init__(sample_rate, freq_low, freq_high) 33 | 34 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 35 | frequency = self.sample_uniform_frequency() 36 | cutoff = self.cutoff_frequency(frequency) 37 | audio = highpass_filter(audio, cutoff=cutoff) 38 | return audio 39 | 40 | 41 | class LowPassFilter(FrequencyFilter): 42 | def __init__( 43 | self, 44 | sample_rate: int, 45 | freq_low: float = 2200, 46 | freq_high: float = 4000, 47 | ): 48 | super().__init__(sample_rate, freq_low, freq_high) 49 | 50 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 51 | frequency = self.sample_uniform_frequency() 52 | cutoff = self.cutoff_frequency(frequency) 53 | audio = lowpass_filter(audio, cutoff=cutoff) 54 | return audio 55 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/gain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torchaudio.transforms import Vol 4 | 5 | 6 | class Gain(torch.nn.Module): 7 | def __init__(self, min_gain: float = -20.0, max_gain: float = -1): 8 | super().__init__() 9 | self.min_gain = min_gain 10 | self.max_gain = max_gain 11 | 12 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 13 | gain = random.uniform(self.min_gain, self.max_gain) 14 | audio = Vol(gain, gain_type="db")(audio) 15 | return audio 16 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/high_low_pass.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torchaudio_augmentations import HighPassFilter, LowPassFilter 4 | 5 | 6 | class HighLowPass(torch.nn.Module): 7 | def __init__( 8 | self, 9 | sample_rate: int, 10 | lowpass_freq_low: float = 2200, 11 | lowpass_freq_high: float = 4000, 12 | highpass_freq_low: float = 200, 13 | highpass_freq_high: float = 1200, 14 | ): 15 | super().__init__() 16 | self.sample_rate = sample_rate 17 | 18 | self.high_pass_filter = HighPassFilter( 19 | sample_rate, highpass_freq_low, highpass_freq_high 20 | ) 21 | self.low_pass_filter = LowPassFilter( 22 | sample_rate, lowpass_freq_low, lowpass_freq_high 23 | ) 24 | 25 | def forward(self, audio): 26 | highlowband = random.randint(0, 1) 27 | if highlowband == 0: 28 | audio = self.high_pass_filter(audio) 29 | else: 30 | audio = self.low_pass_filter(audio) 31 | return audio 32 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/noise.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class Noise(torch.nn.Module): 7 | def __init__(self, min_snr=0.0001, max_snr=0.01): 8 | """ 9 | :param min_snr: Minimum signal-to-noise ratio 10 | :param max_snr: Maximum signal-to-noise ratio 11 | """ 12 | super().__init__() 13 | self.min_snr = min_snr 14 | self.max_snr = max_snr 15 | 16 | def forward(self, audio): 17 | std = torch.std(audio) 18 | noise_std = random.uniform(self.min_snr * std, self.max_snr * std) 19 | 20 | noise = np.random.normal(0.0, noise_std, size=audio.shape).astype(np.float32) 21 | 22 | return audio + noise 23 | 
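(Editorial note.) In `Noise`, `min_snr` and `max_snr` bound the standard deviation of the added Gaussian noise relative to the signal's standard deviation, so smaller values mean quieter noise. A minimal usage sketch:

```python
import torch
from torchaudio_augmentations import Noise

audio = torch.randn(1, 22050)
transform = Noise(min_snr=0.001, max_snr=0.005)  # noise std drawn from [0.001, 0.005] * std(audio)
noisy = transform(audio)
assert noisy.shape == audio.shape
```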
-------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/pitch_shift.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import augment 4 | 5 | 6 | class PitchShift: 7 | def __init__( 8 | self, n_samples, sample_rate, pitch_shift_min=-7.0, pitch_shift_max=7.0 9 | ): 10 | self.n_samples = n_samples 11 | self.sample_rate = sample_rate 12 | self.pitch_shift_cents_min = int(pitch_shift_min * 100) 13 | self.pitch_shift_cents_max = int(pitch_shift_max * 100) 14 | self.src_info = {"rate": self.sample_rate} 15 | 16 | def process(self, x): 17 | n_steps = random.randint(self.pitch_shift_cents_min, self.pitch_shift_cents_max) 18 | effect_chain = augment.EffectChain().pitch(n_steps).rate(self.sample_rate) 19 | num_channels = x.shape[0] 20 | target_info = { 21 | "channels": num_channels, 22 | "length": self.n_samples, 23 | "rate": self.sample_rate, 24 | } 25 | y = effect_chain.apply(x, src_info=self.src_info, target_info=target_info) 26 | 27 | # sox might misbehave sometimes by giving nan/inf if sequences are too short (or silent) 28 | # and the effect chain includes eg `pitch` 29 | if torch.isnan(y).any() or torch.isinf(y).any(): 30 | return x.clone() 31 | 32 | if y.shape[1] != x.shape[1]: 33 | if y.shape[1] > x.shape[1]: 34 | y = y[:, : x.shape[1]] 35 | else: 36 | y0 = torch.zeros(num_channels, x.shape[1]).to(y.device) 37 | y0[:, : y.shape[1]] = y 38 | y = y0 39 | return y 40 | 41 | def __call__(self, audio): 42 | if audio.ndim == 3: 43 | for b in range(audio.shape[0]): 44 | audio[b] = self.process(audio[b]) 45 | return audio 46 | else: 47 | return self.process(audio) 48 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/polarity_inversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | 5 | class PolarityInversion(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, audio): 10 | audio = torch.neg(audio) 11 | return audio 12 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/random_resized_crop.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class RandomResizedCrop(torch.nn.Module): 6 | def __init__(self, n_samples): 7 | super().__init__() 8 | self.n_samples = n_samples 9 | 10 | def forward(self, audio): 11 | max_samples = audio.shape[-1] 12 | start_idx = random.randint(0, max_samples - self.n_samples) 13 | audio = audio[..., start_idx : start_idx + self.n_samples] 14 | return audio 15 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/reverb.py: -------------------------------------------------------------------------------- 1 | import augment 2 | import torch 3 | 4 | 5 | class Reverb(torch.nn.Module): 6 | def __init__( 7 | self, 8 | sample_rate, 9 | reverberance_min=0, 10 | reverberance_max=100, 11 | dumping_factor_min=0, 12 | dumping_factor_max=100, 13 | room_size_min=0, 14 | room_size_max=100, 15 | ): 16 | super().__init__() 17 | self.sample_rate = sample_rate 18 | self.reverberance_min = reverberance_min 19 | self.reverberance_max = reverberance_max 20 | self.dumping_factor_min = dumping_factor_min 21 | self.dumping_factor_max = dumping_factor_max 22 | 
self.room_size_min = room_size_min 23 | self.room_size_max = room_size_max 24 | self.src_info = {"rate": self.sample_rate} 25 | self.target_info = { 26 | "channels": 1, 27 | "rate": self.sample_rate, 28 | } 29 | 30 | def forward(self, audio): 31 | reverberance = torch.randint( 32 | self.reverberance_min, self.reverberance_max, size=(1,) 33 | ).item() 34 | dumping_factor = torch.randint( 35 | self.dumping_factor_min, self.dumping_factor_max, size=(1,) 36 | ).item() 37 | room_size = torch.randint( 38 | self.room_size_min, self.room_size_max, size=(1,) 39 | ).item() 40 | 41 | num_channels = audio.shape[0] 42 | effect_chain = ( 43 | augment.EffectChain() 44 | .reverb(reverberance, dumping_factor, room_size) 45 | .channels(num_channels) 46 | ) 47 | 48 | audio = effect_chain.apply( 49 | audio, src_info=self.src_info, target_info=self.target_info 50 | ) 51 | 52 | return audio 53 | -------------------------------------------------------------------------------- /torchaudio_augmentations/augmentations/reverse.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class Reverse(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, audio): 10 | return torch.flip(audio, dims=[-1]) 11 | -------------------------------------------------------------------------------- /torchaudio_augmentations/compose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Compose: 5 | """Data augmentation module that transforms any given data example with a chain of audio augmentations.""" 6 | 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, x): 11 | x = self.transform(x) 12 | return x 13 | 14 | def __repr__(self): 15 | format_string = self.__class__.__name__ + "(" 16 | for t in self.transforms: 17 | format_string += "\n" 18 | format_string += "\t{0}".format(t) 19 | format_string += "\n)" 20 | return format_string 21 | 22 | def transform(self, x): 23 | for t in self.transforms: 24 | x = t(x) 25 | return x 26 | 27 | 28 | class ComposeMany(Compose): 29 | """ 30 | Data augmentation module that transforms any given data example randomly, 31 | resulting in N correlated views of the same example. 32 | """ 33 | 34 | def __init__(self, transforms, num_augmented_samples): 35 | self.transforms = transforms 36 | self.num_augmented_samples = num_augmented_samples 37 | 38 | def __call__(self, x): 39 | samples = [] 40 | for _ in range(self.num_augmented_samples): 41 | samples.append(self.transform(x).unsqueeze(dim=0).clone()) 42 | return torch.cat(samples, dim=0) 43 | -------------------------------------------------------------------------------- /torchaudio_augmentations/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tensor_has_valid_audio_batch_dimension(tensor: torch.Tensor) -> bool: 5 | if tensor.ndim == 3: 6 | return True 7 | return False 8 | 9 | 10 | def add_audio_batch_dimension(tensor: torch.Tensor) -> torch.Tensor: 11 | return tensor.unsqueeze(dim=0) 12 | 13 | 14 | def remove_audio_batch_dimension(tensor: torch.Tensor) -> torch.Tensor: 15 | return tensor.squeeze(dim=0) 16 | --------------------------------------------------------------------------------
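(Editorial note.) A minimal sketch of how the batch-dimension helpers above can be combined with a transform; all helper functions come from `torchaudio_augmentations.utils` as defined in this file:

```python
import torch
from torchaudio_augmentations import PolarityInversion
from torchaudio_augmentations.utils import (
    add_audio_batch_dimension,
    remove_audio_batch_dimension,
    tensor_has_valid_audio_batch_dimension,
)

audio = torch.randn(2, 48000)  # [channel, time], no batch dimension yet
if not tensor_has_valid_audio_batch_dimension(audio):
    audio = add_audio_batch_dimension(audio)  # -> [1, 2, 48000]

transformed = PolarityInversion()(audio)
transformed = remove_audio_batch_dimension(transformed)  # back to [2, 48000]
```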