├── tests ├── __init__.py ├── utils.py ├── test_audio.py ├── test_band_stop_filter.py ├── test_convolution.py ├── test_base_class.py ├── test_one_of.py ├── test_mel_utils.py ├── test_time_inversion.py ├── test_random_crop.py ├── test_config.py ├── test_low_pass_filter.py ├── test_high_pass_filter.py ├── test_band_pass_filter.py ├── test_pitch_shift.py ├── test_shuffle_channels.py ├── test_file_utils.py ├── test_differentiable.py ├── test_mix.py ├── test_some_of.py ├── test_spliceout.py ├── test_impulse_response.py ├── test_polarity_inversion.py ├── test_padding.py ├── test_colored_noise.py ├── test_shift.py ├── test_compose.py └── test_background_noise.py ├── torch_audiomentations ├── core │ ├── __init__.py │ └── composition.py ├── utils │ ├── __init__.py │ ├── multichannel.py │ ├── fft.py │ ├── mel_scale.py │ ├── object_dict.py │ ├── dsp.py │ ├── file.py │ ├── convolution.py │ └── config.py ├── augmentations │ ├── __init__.py │ ├── identity.py │ ├── high_pass_filter.py │ ├── polarity_inversion.py │ ├── band_stop_filter.py │ ├── time_inversion.py │ ├── padding.py │ ├── shuffle_channels.py │ ├── gain.py │ ├── random_crop.py │ ├── peak_normalization.py │ ├── splice_out.py │ ├── mix.py │ ├── low_pass_filter.py │ ├── impulse_response.py │ ├── pitch_shift.py │ ├── colored_noise.py │ ├── band_pass_filter.py │ ├── background_noise.py │ └── shift.py └── __init__.py ├── test_fixtures ├── ir │ ├── misc_file.txt │ └── impulse_response_0.wav ├── config.yml ├── bg │ ├── bg.wav │ └── stereo_noise.wav ├── perfect-alley1.ogg ├── perfect-alley2.ogg ├── acoustic_guitar_0.wav ├── bg_short │ └── bg_short.WAV └── config_compose.yml ├── pytest.ini ├── images ├── convolve_exec_time_plot.png ├── torch_audiomentations_logo.png └── visual_explanation_mode_etc.png ├── .editorconfig ├── .coveragerc ├── .github └── workflows │ ├── test_requirements.txt │ ├── test_formatting.yml │ └── ci.yml ├── pyproject.toml ├── codecov.yml ├── .flake8 ├── environment.yml ├── packaging.md ├── LICENSE ├── scripts ├── plot.py ├── measure_convolve_execution_time.py ├── perf_benchmark.py └── demo.py ├── setup.py └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_audiomentations/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_fixtures/ir/misc_file.txt: -------------------------------------------------------------------------------- 1 | This file serves as an example of a file that is not an audio file 2 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files=test*.py 3 | addopts=--cov torch_audiomentations --cov-report=xml 4 | -------------------------------------------------------------------------------- /test_fixtures/config.yml: 
-------------------------------------------------------------------------------- 1 | transform: Gain 2 | params: 3 | min_gain_in_db: -12.0 4 | mode: per_channel 5 | -------------------------------------------------------------------------------- /test_fixtures/bg/bg.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/bg/bg.wav -------------------------------------------------------------------------------- /torch_audiomentations/utils/multichannel.py: -------------------------------------------------------------------------------- 1 | def is_multichannel(samples) -> bool: 2 | return samples.shape[1] > 1 3 | -------------------------------------------------------------------------------- /test_fixtures/perfect-alley1.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/perfect-alley1.ogg -------------------------------------------------------------------------------- /test_fixtures/perfect-alley2.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/perfect-alley2.ogg -------------------------------------------------------------------------------- /images/convolve_exec_time_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/images/convolve_exec_time_plot.png -------------------------------------------------------------------------------- /test_fixtures/bg/stereo_noise.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/bg/stereo_noise.wav -------------------------------------------------------------------------------- /images/torch_audiomentations_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/images/torch_audiomentations_logo.png -------------------------------------------------------------------------------- /test_fixtures/acoustic_guitar_0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/acoustic_guitar_0.wav -------------------------------------------------------------------------------- /test_fixtures/bg_short/bg_short.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/bg_short/bg_short.WAV -------------------------------------------------------------------------------- /images/visual_explanation_mode_etc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/images/visual_explanation_mode_etc.png -------------------------------------------------------------------------------- /test_fixtures/ir/impulse_response_0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iver56/torch-audiomentations/HEAD/test_fixtures/ir/impulse_response_0.wav -------------------------------------------------------------------------------- 
/.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | 7 | [*.py] 8 | charset = utf-8 9 | indent_style = space 10 | indent_size = 4 11 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | BASE_DIR = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__)))) 5 | TEST_FIXTURES_DIR = BASE_DIR / "test_fixtures" 6 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | fail_under=80 3 | # Regexes for lines to exclude from consideration 4 | exclude_lines = 5 | pragma: no cover 6 | raise NotImplementedError 7 | 8 | [run] 9 | source=torch_audiomentations 10 | -------------------------------------------------------------------------------- /test_fixtures/config_compose.yml: -------------------------------------------------------------------------------- 1 | transform: Compose 2 | params: 3 | transforms: 4 | - transform: Gain 5 | params: 6 | min_gain_in_db: -12.0 7 | mode: per_channel 8 | - transform: PolarityInversion 9 | shuffle: True 10 | -------------------------------------------------------------------------------- /.github/workflows/test_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.22,<2 2 | scipy>=1.5.2 3 | torch==1.11.0 4 | torchaudio==0.11.0 5 | audioread>=2.1.8 6 | julius>=0.2.3,<0.3 7 | py-cpuinfo>=7.0.0 8 | pytest==5.3.4 9 | pytest-cov==2.8.1 10 | coverage==4.5.2 11 | PyYAML>=5.3.1 12 | torch-pitch-shift>=1.2.2 13 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.black] 8 | # https://github.com/psf/black 9 | line-length = 90 10 | target-version = ["py36"] 11 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" 12 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | coverage: 3 | status: 4 | patch: 5 | default: 6 | target: 60% 7 | project: 8 | default: 9 | target: auto # target is the base commit coverage 10 | threshold: 10% # allow this little decrease on project 11 | base: auto 12 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/fft.py: -------------------------------------------------------------------------------- 1 | try: 2 | # This works in PyTorch>=1.7 3 | from torch.fft import irfft, rfft 4 | except ModuleNotFoundError: 5 | # PyTorch<=1.6 6 | raise Exception( 7 | "torch-audiomentations does not support pytorch<=1.6. 
Please upgrade to pytorch 1.7 or newer.", 8 | ) 9 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/mel_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def convert_frequencies_to_mels(f: torch.Tensor) -> torch.Tensor: 5 | """ 6 | Convert f hertz to m mels 7 | 8 | https://en.wikipedia.org/wiki/Mel_scale#Formula 9 | """ 10 | return 2595.0 * torch.log10(1.0 + f / 700.0) 11 | 12 | 13 | def convert_mels_to_frequencies(m: torch.Tensor) -> torch.Tensor: 14 | """ 15 | Convert m mels to f hertz 16 | 17 | https://en.wikipedia.org/wiki/Mel_scale#History_and_other_formulas 18 | """ 19 | return 700.0 * (10 ** (m / 2595.0) - 1.0) 20 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 90 3 | exclude = scripts,*.egg,build 4 | select = E,W,F 5 | verbose = 2 6 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 7 | format = pylint 8 | ignore = 9 | # E731 - Do not assign a lambda expression, use a def 10 | E731 11 | # W605 - invalid escape sequence '\_'. Needed for docs 12 | W605 13 | # W504 - line break after binary operator 14 | W504 15 | # W503 - line break before binary operator, need for black 16 | W503 17 | # E203 - whitespace before ':'. Opposite convention enforced by black 18 | E203 19 | 20 | per-file-ignores = 21 | */__init__.py: F401 22 | -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from torch_audiomentations.utils.io import Audio 6 | 7 | 8 | class TestAudio: 9 | @pytest.mark.parametrize( 10 | "shape", 11 | [(1, 512), (1, 2, 512), (2, 1, 512)], 12 | ) 13 | def test_rms_normalize(self, shape: tuple): 14 | samples = torch.rand(size=shape, dtype=torch.float32) 15 | normalized_samples = Audio.rms_normalize(samples) 16 | 17 | assert samples.shape == normalized_samples.shape 18 | 19 | normalized_rms = np.sqrt(np.mean(np.square(normalized_samples.numpy()))) 20 | assert normalized_rms == pytest.approx(1.0) 21 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch-audiomentations-gpu 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - cudatoolkit=11.3 8 | - ffmpeg 9 | - matplotlib>=3,<4 10 | - numpy>=1.22,<2 11 | - pip>=20.1 12 | - python=3.9 13 | - pytorch::pytorch=1.11.0 14 | - pytorch::torchaudio=0.11.0 15 | - scipy=1.7.3 16 | - seaborn>=0.9,<1.0 17 | - pip: 18 | - audioread==2.1.8 19 | - black==23.12.1 20 | - coverage==5.3 21 | - julius>=0.2.3,<0.3 22 | - pandas==1.1.4 23 | - py-cpuinfo==7.0.0 24 | - pytest==7.4.4 25 | - pytest-cov==5.0.0 26 | - PyYAML==5.3.1 27 | - setuptools>=41.0.0 28 | - tqdm==4.49.0 29 | - twine 30 | - torch-pitch-shift>=1.2.2 31 | -------------------------------------------------------------------------------- /packaging.md: -------------------------------------------------------------------------------- 1 | * Check that all unit tests are OK 2 | * Run the demo script and listen to the sounds to empirically check the results 3 | * Bump the version number in `torch_audiomentations/__init__.py` in accordance 
with the [semantic versioning specification](https://semver.org/) 4 | * Write a summary of the changes in the changelog section in README.md 5 | * Commit and push the change with a commit message like this: "Release vx.y.z" (replace x.y.z with the package version) 6 | * Add and push a git tag to the release commit 7 | * Add a release here: https://github.com/asteroid-team/torch-audiomentations/releases/new 8 | * Update the Zenodo badge in README.md. Commit and push. 9 | * `python setup.py sdist bdist_wheel` 10 | * `python -m twine upload dist/*` 11 | -------------------------------------------------------------------------------- /.github/workflows/test_formatting.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | on: [push] 3 | 4 | jobs: 5 | code-black: 6 | name: CI 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v2 11 | - name: Set up Python 3.9 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.9 15 | 16 | - name: Install Black and flake8 17 | run: pip install black==23.12.1 flake8 18 | - name: Run Black 19 | run: python -m black --config=pyproject.toml --check torch_audiomentations tests test_fixtures 20 | 21 | - name: Lint with flake8 22 | # Exit on important linting errors and warn about others. 23 | run: | 24 | python -m flake8 torch_audiomentations tests --show-source --statistics --select=F6,F7,F82,F52 25 | python -m flake8 --config .flake8 --exit-zero torch_audiomentations tests --statistics 26 | -------------------------------------------------------------------------------- /tests/test_band_stop_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch_audiomentations import BandStopFilter 5 | 6 | 7 | class TestBandStopFilter: 8 | def test_band_reject_filter(self): 9 | samples = np.array( 10 | [ 11 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 12 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 13 | [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 14 | ], 15 | dtype=np.float32, 16 | ) 17 | sample_rate = 16000 18 | 19 | augment = BandStopFilter(p=1.0, output_type="dict") 20 | processed_samples = augment( 21 | samples=torch.from_numpy(samples), sample_rate=sample_rate 22 | ).samples.numpy() 23 | assert processed_samples.shape == samples.shape 24 | assert processed_samples.dtype == np.float32 25 | -------------------------------------------------------------------------------- /tests/test_convolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from numpy.testing import assert_almost_equal 3 | from scipy.signal import convolve as scipy_convolve 4 | 5 | from tests.utils import TEST_FIXTURES_DIR 6 | from torch_audiomentations.utils.convolution import convolve as torch_convolve 7 | from torch_audiomentations.utils.io import Audio 8 | 9 | 10 | class TestConvolution: 11 | def test_convolve(self): 12 | sample_rate = 16000 13 | 14 | file_path = TEST_FIXTURES_DIR / "acoustic_guitar_0.wav" 15 | audio = Audio(sample_rate, mono=True) 16 | samples = audio(file_path).numpy() 17 | ir_samples = audio(TEST_FIXTURES_DIR / "ir" / "impulse_response_0.wav").numpy() 18 | 19 | expected_output = scipy_convolve(samples, ir_samples) 20 | actual_output = torch_convolve( 21 | torch.from_numpy(samples), torch.from_numpy(ir_samples) 22 | ).numpy() 23 | 24 | 
assert_almost_equal(actual_output, expected_output, decimal=6) 25 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/object_dict.py: -------------------------------------------------------------------------------- 1 | # Inspired by tornado 2 | # https://www.tornadoweb.org/en/stable/_modules/tornado/util.html#ObjectDict 3 | 4 | import typing 5 | 6 | _ObjectDictBase = typing.Dict[str, typing.Any] 7 | 8 | 9 | class ObjectDict(_ObjectDictBase): 10 | """ 11 | Make a dictionary behave like an object, with attribute-style access. 12 | 13 | Here are some examples of how it can be used: 14 | 15 | o = ObjectDict(my_dict) 16 | # or like this: 17 | o = ObjectDict(samples=samples, sample_rate=sample_rate) 18 | 19 | # Attribute-style access 20 | samples = o.samples 21 | 22 | # Dict-style access 23 | samples = o["samples"] 24 | """ 25 | 26 | def __getattr__(self, name): 27 | # type: (str) -> typing.Any 28 | try: 29 | return self[name] 30 | except KeyError: 31 | raise AttributeError(name) 32 | 33 | def __setattr__(self, name, value): 34 | # type: (str, typing.Any) -> None 35 | self[name] = value 36 | -------------------------------------------------------------------------------- /tests/test_base_class.py: -------------------------------------------------------------------------------- 1 | import types 2 | import unittest 3 | 4 | import pytest 5 | import torch 6 | 7 | from torch_audiomentations import PolarityInversion 8 | 9 | 10 | class TestBaseClass(unittest.TestCase): 11 | def test_parameters(self): 12 | # Test that we can access the parameters function of nn.Module 13 | augment = PolarityInversion(p=1.0, output_type="dict") 14 | params = augment.parameters() 15 | assert isinstance(params, types.GeneratorType) 16 | 17 | def test_ndim_check(self): 18 | augment = PolarityInversion(p=1.0, output_type="dict") 19 | # 1D tensor not allowed 20 | with pytest.raises(RuntimeError): 21 | augment(torch.tensor([1.0, 0.5, 0.25, 0.125], dtype=torch.float32)) 22 | # 2D tensor not allowed 23 | with pytest.raises(RuntimeError): 24 | augment(torch.tensor([[1.0, 0.5, 0.25, 0.125]], dtype=torch.float32)) 25 | # 4D tensor not allowed 26 | with pytest.raises(RuntimeError): 27 | augment(torch.tensor([[[[1.0, 0.5, 0.25, 0.125]]]], dtype=torch.float32)) 28 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/identity.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import Tensor 3 | 4 | from ..core.transforms_interface import BaseWaveformTransform 5 | from ..utils.object_dict import ObjectDict 6 | 7 | 8 | class Identity(BaseWaveformTransform): 9 | """ 10 | This transform returns the input unchanged. It can be used for simplifying the code 11 | in cases where data augmentation should be disabled. 
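A minimal usage sketch (the input shape and keyword arguments here are illustrative; the call signature follows the other transforms in this library): >>> import torch >>> from torch_audiomentations import Identity >>> augment = Identity(p=1.0, output_type="dict") >>> samples = torch.rand(2, 1, 16000)  # (batch_size, num_channels, num_samples), hypothetical shape >>> output = augment(samples, sample_rate=16000) >>> output.samples.shape  # the samples come back unchanged torch.Size([2, 1, 16000])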
12 | """ 13 | 14 | supported_modes = {"per_batch", "per_example", "per_channel"} 15 | supports_multichannel = True 16 | requires_sample_rate = False 17 | supports_target = True 18 | requires_target = False 19 | 20 | def apply_transform( 21 | self, 22 | samples: Tensor = None, 23 | sample_rate: Optional[int] = None, 24 | targets: Optional[Tensor] = None, 25 | target_rate: Optional[int] = None, 26 | ) -> ObjectDict: 27 | return ObjectDict( 28 | samples=samples, 29 | sample_rate=sample_rate, 30 | targets=targets, 31 | target_rate=target_rate, 32 | ) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 asteroid-team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /torch_audiomentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations.background_noise import AddBackgroundNoise 2 | from .augmentations.band_pass_filter import BandPassFilter 3 | from .augmentations.band_stop_filter import BandStopFilter 4 | from .augmentations.colored_noise import AddColoredNoise 5 | from .augmentations.gain import Gain 6 | from .augmentations.high_pass_filter import HighPassFilter 7 | from .augmentations.identity import Identity 8 | from .augmentations.impulse_response import ApplyImpulseResponse 9 | from .augmentations.low_pass_filter import LowPassFilter 10 | from .augmentations.peak_normalization import PeakNormalization 11 | from .augmentations.pitch_shift import PitchShift 12 | from .augmentations.polarity_inversion import PolarityInversion 13 | from .augmentations.shift import Shift 14 | from .augmentations.shuffle_channels import ShuffleChannels 15 | from .augmentations.time_inversion import TimeInversion 16 | from .augmentations.splice_out import SpliceOut 17 | from .core.composition import Compose, SomeOf, OneOf 18 | from .utils.config import from_dict, from_yaml 19 | from .utils.convolution import convolve 20 | 21 | __version__ = "0.12.0" 22 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/dsp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_rms(samples): 5 | """ 6 | Calculates the root mean square. 
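A quick sketch of the expected behavior (the shapes below are hypothetical): the mean is taken over the last dimension only, so an input of shape (batch_size, num_channels, num_samples) yields one RMS value per channel. >>> import torch >>> calculate_rms(torch.ones(2, 1, 16000)).shape torch.Size([2, 1]) >>> float(calculate_rms(0.5 * torch.ones(1, 1, 4)))  # a constant signal at 0.5 has an RMS of 0.5 0.5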
7 | 8 | Based on https://github.com/iver56/audiomentations/blob/master/audiomentations/core/utils.py 9 | """ 10 | return torch.sqrt(torch.mean(torch.square(samples), dim=-1, keepdim=False)) 11 | 12 | 13 | def calculate_desired_noise_rms(clean_rms, snr): 14 | """ 15 | Given the Root Mean Square (RMS) of a clean sound and a desired signal-to-noise ratio (SNR), 16 | calculate the desired RMS of a noise sound to be mixed in. 17 | Based on https://github.com/Sato-Kunihiko/audio-SNR/blob/8d2c933b6c0afe6f1203251f4877e7a1068a6130/create_mixed_audio_file.py#L20 18 | 19 | :param clean_rms: Root Mean Square (RMS) - a value between 0.0 and 1.0 20 | :param snr: Signal-to-Noise (SNR) Ratio in dB - typically somewhere between -20 and 60 21 | :return: 22 | """ 23 | noise_rms = clean_rms / (10 ** (snr / 20)) 24 | return noise_rms 25 | 26 | 27 | def convert_decibels_to_amplitude_ratio(decibels): 28 | return 10 ** (decibels / 20) 29 | 30 | 31 | def convert_amplitude_ratio_to_decibels(amplitude_ratio): 32 | return 20 * torch.log10(amplitude_ratio) 33 | -------------------------------------------------------------------------------- /tests/test_one_of.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from torch_audiomentations import PolarityInversion, PeakNormalization, Gain, OneOf 6 | from torch_audiomentations.utils.object_dict import ObjectDict 7 | 8 | 9 | class TestOneOf(unittest.TestCase): 10 | def setUp(self): 11 | self.sample_rate = 16000 12 | self.audio = torch.randn(1, 1, 16000) 13 | 14 | self.transforms = [ 15 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-2, p=1.0), 16 | PolarityInversion(p=1.0), 17 | PeakNormalization(p=1.0), 18 | ] 19 | 20 | def test_one_of_without_specifying_output_type(self): 21 | augment = OneOf(self.transforms) 22 | 23 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 24 | output = augment(samples=self.audio, sample_rate=self.sample_rate) 25 | # The output type should be torch.Tensor until we switch to ObjectDict by default 26 | assert type(output) == torch.Tensor 27 | 28 | def test_one_of_dict(self): 29 | augment = OneOf(self.transforms, output_type="dict") 30 | 31 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 32 | output = augment(samples=self.audio, sample_rate=self.sample_rate) 33 | assert type(output) == ObjectDict 34 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | # Trigger the workflow on push or pull request 5 | on: [push, pull_request] 6 | 7 | jobs: 8 | src-test: 9 | name: unit-tests 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.9] 14 | 15 | # Timeout: https://stackoverflow.com/a/59076067/4521646 16 | timeout-minutes: 10 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install libsndfile 25 | run: | 26 | sudo apt update 27 | sudo apt install libsndfile1-dev libsndfile1 28 | 29 | # FIX requirement install 30 | - name: Install python dependencies 31 | run: | 32 | python -m pip install --upgrade --user pip --quiet 33 | python -m pip install -r .github/workflows/test_requirements.txt 
34 | python --version 35 | pip --version 36 | python -m pip list 37 | shell: bash 38 | 39 | - name: Pytest and coverage 40 | run: coverage run -a -m py.test tests 41 | 42 | - name: Coverage report 43 | run: | 44 | coverage xml 45 | bash <(curl -s https://codecov.io/bash) 46 | -------------------------------------------------------------------------------- /tests/test_mel_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_audiomentations.utils.mel_scale import ( 4 | convert_frequencies_to_mels, 5 | convert_mels_to_frequencies, 6 | ) 7 | 8 | 9 | class TestMelUtils: 10 | def test_mel_utils_with_tensor_input(self): 11 | frequencies = torch.tensor([0.0, 1.0, 40.0, 400.0, 2000.0, 20000.0, 50000.0]) 12 | mels = convert_frequencies_to_mels(frequencies) 13 | 14 | assert torch.allclose( 15 | mels, 16 | torch.tensor( 17 | [ 18 | 0.0000e0, 19 | 1.6089e0, 20 | 6.2627e1, 21 | 5.0938e2, 22 | 1.5214e3, 23 | 3.8169e3, 24 | 4.8265e3, 25 | ] 26 | ), 27 | rtol=1e-4, 28 | atol=1e-3, 29 | ) 30 | frequencies_again = convert_mels_to_frequencies(mels) 31 | assert torch.allclose(frequencies_again, frequencies, rtol=1e-5, atol=1e-3) 32 | 33 | def test_mel_utils_with_scalar_input(self): 34 | m = convert_frequencies_to_mels(torch.tensor(400.0)) 35 | 36 | assert torch.allclose( 37 | m, 38 | torch.tensor(5.0938e2), 39 | rtol=1e-4, 40 | atol=1e-3, 41 | ) 42 | f = convert_mels_to_frequencies(m) 43 | assert torch.allclose(f, torch.tensor(400.0), rtol=1e-5, atol=1e-3) 44 | -------------------------------------------------------------------------------- /tests/test_time_inversion.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from torch_audiomentations import TimeInversion 5 | 6 | 7 | class TestTimeInversion(unittest.TestCase): 8 | def setUp(self): 9 | self.augment = TimeInversion(p=1.0, output_type="dict") 10 | self.samples = torch.arange(1, 100, 1).type(torch.FloatTensor) 11 | self.expected_samples = torch.arange(99, 0, -1).type(torch.FloatTensor) 12 | 13 | def test_single_channel(self): 14 | samples = self.samples.unsqueeze(0).unsqueeze(0) # (B, C, T): (1, 1, 100) 15 | processed_samples = self.augment(samples=samples, sample_rate=16000).samples 16 | 17 | self.assertEqual(processed_samples.shape, samples.shape) 18 | self.assertTrue( 19 | torch.equal( 20 | processed_samples, self.expected_samples.unsqueeze(0).unsqueeze(0) 21 | ) 22 | ) 23 | 24 | def test_multi_channel(self): 25 | samples = torch.stack([self.samples, self.samples], dim=0).unsqueeze( 26 | 0 27 | ) # (B, C, T): (1, 2, 100) 28 | processed_samples = self.augment(samples=samples, sample_rate=16000).samples 29 | 30 | self.assertEqual(processed_samples.shape, samples.shape) 31 | self.assertTrue( 32 | torch.equal(processed_samples[:, 0], self.expected_samples.unsqueeze(0)) 33 | ) 34 | self.assertTrue( 35 | torch.equal(processed_samples[:, 1], self.expected_samples.unsqueeze(0)) 36 | ) 37 | -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import numpy as np 4 | 5 | 6 | def show_horizontal_bar_chart(exec_time_dict, title): 7 | """ 8 | Plot execution times as a horizontal bar chart. 
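A call might look like this (the labels and timing values are made up for illustration): show_horizontal_bar_chart({"convolve on CPU": 1.52, "convolve on CUDA": 0.11}, "Convolution execution time")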
9 | 10 | :param exec_time_dict: a dict where keys are descriptions and values are execution times 11 | :param title: The title of the plot 12 | :return: 13 | """ 14 | fig = plt.figure(figsize=(8, 8)) 15 | ax = fig.add_subplot(111) 16 | ax.set_xlabel("Execution time (s)\nNote: log-scaled axis!") 17 | ax.invert_yaxis() # labels read top-to-bottom 18 | 19 | ax.set_title(title, loc="left") 20 | 21 | labels = exec_time_dict.keys() 22 | exec_times = [exec_time_dict[label] for label in labels] 23 | labels = [label if label is not None else "unknown" for label in labels] 24 | 25 | zipped_data = zip(exec_times, labels) 26 | zipped_data = sorted(zipped_data, reverse=True) 27 | exec_times, labels = zip(*zipped_data) 28 | 29 | ypos = np.arange(len(exec_times)) 30 | plt.barh(ypos, exec_times, align="center") 31 | plt.yticks(ypos, labels, fontsize=12) 32 | plt.xscale("log") 33 | max_exec_time = max(exec_times) 34 | 35 | for i, value in enumerate(exec_times): 36 | if value < 0.5 * max_exec_time: 37 | plt.text(value, i, " {:.2f} s".format(value), va="center") 38 | else: 39 | plt.text( 40 | value, 41 | i, 42 | "{:.2f} s ".format(value), 43 | va="center", 44 | ha="right", 45 | color="white", 46 | ) 47 | 48 | # Tweak spacing to prevent clipping of tick-labels 49 | plt.subplots_adjust(left=0.4) 50 | plt.show() 51 | plt.close(fig) 52 | -------------------------------------------------------------------------------- /tests/test_random_crop.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import pytest 4 | import numpy as np 5 | from torch_audiomentations.augmentations.random_crop import RandomCrop 6 | 7 | 8 | class TestRandomCrop(unittest.TestCase): 9 | def test_crop(self): 10 | samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32) 11 | sampling_rate = 16000 12 | crop_to = 1.5 13 | desired_samples_len = sampling_rate * crop_to 14 | Crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate) 15 | cropped_samples = Crop(samples) 16 | 17 | self.assertEqual(desired_samples_len, cropped_samples.size(-1)) 18 | 19 | def test_crop_larger_cropto(self): 20 | samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32) 21 | sampling_rate = 16000 22 | crop_to = 3 23 | Crop = RandomCrop(max_length=crop_to, sampling_rate=sampling_rate) 24 | cropped_samples = Crop(samples) 25 | 26 | np.testing.assert_array_equal(samples, cropped_samples) 27 | self.assertEqual(samples.size(-1), cropped_samples.size(-1)) 28 | 29 | @pytest.mark.skip(reason="output_type is not implemented yet") 30 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 31 | def test_crop_on_device_cuda(self): 32 | samples = torch.rand( 33 | size=(8, 2, 32000), dtype=torch.float32, device=torch.device("cuda") 34 | ) 35 | sampling_rate = 16000 36 | crop_to = 1.5 37 | desired_samples_len = sampling_rate * crop_to 38 | Crop = RandomCrop( 39 | max_length=crop_to, sampling_rate=sampling_rate, output_type="dict" 40 | ) 41 | cropped_samples = Crop(samples) 42 | 43 | self.assertEqual(desired_samples_len, cropped_samples.size(-1)) 44 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tests.utils import TEST_FIXTURES_DIR 4 | from torch_audiomentations import from_dict, from_yaml 5 | from torch_audiomentations import Gain, Compose 6 | 7 | 8 | class TestFromConfig(unittest.TestCase): 9 | def 
test_from_dict(self): 10 | config = { 11 | "transform": "Gain", 12 | "params": {"min_gain_in_db": -12.0, "mode": "per_channel"}, 13 | } 14 | transform = from_dict(config) 15 | 16 | assert isinstance(transform, Gain) 17 | assert transform.min_gain_in_db == -12.0 18 | assert transform.max_gain_in_db == 6.0 19 | assert transform.mode == "per_channel" 20 | 21 | def test_from_yaml(self): 22 | file_yml = TEST_FIXTURES_DIR / "config.yml" 23 | transform = from_yaml(file_yml) 24 | 25 | assert isinstance(transform, Gain) 26 | assert transform.min_gain_in_db == -12.0 27 | assert transform.max_gain_in_db == 6.0 28 | assert transform.mode == "per_channel" 29 | 30 | def test_from_dict_compose(self): 31 | config = { 32 | "transform": "Compose", 33 | "params": { 34 | "shuffle": True, 35 | "transforms": [ 36 | { 37 | "transform": "Gain", 38 | "params": {"min_gain_in_db": -12.0, "mode": "per_channel"}, 39 | }, 40 | {"transform": "PolarityInversion"}, 41 | ], 42 | }, 43 | } 44 | transform = from_dict(config) 45 | assert isinstance(transform, Compose) 46 | 47 | def test_from_yaml_compose(self): 48 | file_yml = TEST_FIXTURES_DIR / "config_compose.yml" 49 | transform = from_yaml(file_yml) 50 | assert isinstance(transform, Compose) 51 | -------------------------------------------------------------------------------- /tests/test_low_pass_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch_audiomentations import LowPassFilter 5 | 6 | 7 | class TestLowPassFilter: 8 | def test_low_pass_filter(self): 9 | samples = np.array( 10 | [ 11 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 12 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 13 | [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 14 | ], 15 | dtype=np.float32, 16 | ) 17 | sample_rate = 16000 18 | 19 | augment = LowPassFilter( 20 | min_cutoff_freq=200, max_cutoff_freq=7000, p=1.0, output_type="dict" 21 | ) 22 | processed_samples = augment( 23 | samples=torch.from_numpy(samples), sample_rate=sample_rate 24 | ).samples.numpy() 25 | assert processed_samples.shape == samples.shape 26 | assert processed_samples.dtype == np.float32 27 | 28 | def test_equal_cutoff_min_max(self): 29 | samples = np.array( 30 | [ 31 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 32 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 33 | [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 34 | ], 35 | dtype=np.float32, 36 | ) 37 | sample_rate = 16000 38 | 39 | augment = LowPassFilter( 40 | min_cutoff_freq=2000, max_cutoff_freq=2000, p=1.0, output_type="dict" 41 | ) 42 | processed_samples = augment( 43 | samples=torch.from_numpy(samples), sample_rate=sample_rate 44 | ).samples.numpy() 45 | assert processed_samples.shape == samples.shape 46 | assert processed_samples.dtype == np.float32 47 | -------------------------------------------------------------------------------- /tests/test_high_pass_filter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from torch_audiomentations import HighPassFilter 8 | 9 | 10 | class TestHighPassFilter(unittest.TestCase): 11 | def test_high_pass_filter(self): 12 | samples = np.array( 13 | [ 14 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 15 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 16 
| [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 17 | ], 18 | dtype=np.float32, 19 | ) 20 | sample_rate = 16000 21 | 22 | augment = HighPassFilter(p=1.0, output_type="dict") 23 | processed_samples = augment( 24 | samples=torch.from_numpy(samples), sample_rate=sample_rate 25 | ).samples.numpy() 26 | self.assertEqual(processed_samples.shape, samples.shape) 27 | self.assertEqual(processed_samples.dtype, np.float32) 28 | 29 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 30 | def test_high_pass_filter_cuda(self): 31 | samples = np.array( 32 | [ 33 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 34 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 35 | [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 36 | ], 37 | dtype=np.float32, 38 | ) 39 | sample_rate = 16000 40 | 41 | augment = HighPassFilter(p=1.0, output_type="dict") 42 | processed_samples = ( 43 | augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) 44 | .samples.cpu() 45 | .numpy() 46 | ) 47 | self.assertEqual(processed_samples.shape, samples.shape) 48 | self.assertEqual(processed_samples.dtype, np.float32) 49 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/high_pass_filter.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from typing import Optional 3 | 4 | from ..augmentations.low_pass_filter import LowPassFilter 5 | from ..utils.object_dict import ObjectDict 6 | 7 | 8 | class HighPassFilter(LowPassFilter): 9 | """ 10 | Apply high-pass filtering to the input audio. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | min_cutoff_freq: float = 20.0, 16 | max_cutoff_freq: float = 2400.0, 17 | mode: str = "per_example", 18 | p: float = 0.5, 19 | p_mode: str = None, 20 | sample_rate: int = None, 21 | target_rate: int = None, 22 | output_type: Optional[str] = None, 23 | ): 24 | """ 25 | :param min_cutoff_freq: Minimum cutoff frequency in hertz 26 | :param max_cutoff_freq: Maximum cutoff frequency in hertz 27 | :param mode: 28 | :param p: 29 | :param p_mode: 30 | :param sample_rate: 31 | :param target_rate: 32 | """ 33 | 34 | super().__init__( 35 | min_cutoff_freq, 36 | max_cutoff_freq, 37 | mode=mode, 38 | p=p, 39 | p_mode=p_mode, 40 | sample_rate=sample_rate, 41 | target_rate=target_rate, 42 | output_type=output_type, 43 | ) 44 | 45 | def apply_transform( 46 | self, 47 | samples: Tensor = None, 48 | sample_rate: Optional[int] = None, 49 | targets: Optional[Tensor] = None, 50 | target_rate: Optional[int] = None, 51 | ) -> ObjectDict: 52 | perturbed = super().apply_transform( 53 | samples=samples.clone(), 54 | sample_rate=sample_rate, 55 | targets=targets.clone() if targets is not None else None, 56 | target_rate=target_rate, 57 | ) 58 | 59 | perturbed.samples = samples - perturbed.samples 60 | return perturbed 61 | -------------------------------------------------------------------------------- /tests/test_band_pass_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from torch_audiomentations import BandPassFilter 6 | 7 | 8 | class TestBandPassFilter: 9 | def test_band_pass_filter(self): 10 | samples = np.array( 11 | [ 12 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 13 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 14 | [[0.9, 0.5, 
-0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 15 | ], 16 | dtype=np.float32, 17 | ) 18 | sample_rate = 16000 19 | 20 | augment = BandPassFilter(p=1.0, output_type="dict") 21 | for _ in range(20): 22 | processed_samples = augment( 23 | samples=torch.from_numpy(samples), sample_rate=sample_rate 24 | ).samples.numpy() 25 | assert processed_samples.shape == samples.shape 26 | assert processed_samples.dtype == np.float32 27 | 28 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 29 | def test_band_pass_filter_cuda(self): 30 | samples = np.array( 31 | [ 32 | [[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]], 33 | [[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]], 34 | [[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]], 35 | ], 36 | dtype=np.float32, 37 | ) 38 | sample_rate = 16000 39 | 40 | augment = BandPassFilter(p=1.0, output_type="dict") 41 | for _ in range(20): 42 | processed_samples = ( 43 | augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) 44 | .samples.cpu() 45 | .numpy() 46 | ) 47 | assert processed_samples.shape == samples.shape 48 | assert processed_samples.dtype == np.float32 49 | -------------------------------------------------------------------------------- /tests/test_pitch_shift.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from numpy.testing import assert_almost_equal 7 | from torch_audiomentations import PitchShift 8 | 9 | 10 | def get_example(): 11 | return ( 12 | torch.rand( 13 | size=(8, 2, 32000), 14 | dtype=torch.float32, 15 | device="cuda" if torch.cuda.is_available() else "cpu", 16 | ) 17 | - 0.5 18 | ) 19 | 20 | 21 | class TestPitchShift(unittest.TestCase): 22 | def test_per_example_shift(self): 23 | samples = get_example() 24 | aug = PitchShift(sample_rate=16000, p=1, mode="per_example", output_type="dict") 25 | aug.randomize_parameters(samples) 26 | results = aug.apply_transform(samples).samples 27 | self.assertEqual(results.shape, samples.shape) 28 | 29 | def test_per_channel_shift(self): 30 | samples = get_example() 31 | aug = PitchShift(sample_rate=16000, p=1, mode="per_channel", output_type="dict") 32 | aug.randomize_parameters(samples) 33 | results = aug.apply_transform(samples).samples 34 | self.assertEqual(results.shape, samples.shape) 35 | 36 | def test_per_batch_shift(self): 37 | samples = get_example() 38 | aug = PitchShift(sample_rate=16000, p=1, mode="per_batch", output_type="dict") 39 | aug.randomize_parameters(samples) 40 | results = aug.apply_transform(samples).samples 41 | self.assertEqual(results.shape, samples.shape) 42 | 43 | def error_raised(self): 44 | error = False 45 | try: 46 | PitchShift( 47 | sample_rate=16000, 48 | p=1, 49 | mode="per_example", 50 | min_transpose_semitones=0.0, 51 | max_transpose_semitones=0.0, 52 | output_type="dict", 53 | ) 54 | except ValueError: 55 | error = True 56 | if not error: 57 | raise ValueError("Invalid transpositions were not detected") 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | 5 | from setuptools import setup, find_packages 6 | 7 | with open("README.md", "r") as readme_file: 8 | long_description = readme_file.read() 9 | 10 | here = os.path.abspath(os.path.dirname(__file__)) 11 | 12 | 13 | def 
read(*parts): 14 | with codecs.open(os.path.join(here, *parts), "r") as fp: 15 | return fp.read() 16 | 17 | 18 | def find_version(*file_paths): 19 | version_file = read(*file_paths) 20 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 21 | if version_match: 22 | return version_match.group(1) 23 | raise RuntimeError("Unable to find version string.") 24 | 25 | 26 | setup( 27 | name="torch-audiomentations", 28 | version=find_version("torch_audiomentations", "__init__.py"), 29 | author="Iver Jordal", 30 | description="A Pytorch library for audio data augmentation. Inspired by audiomentations." 31 | " Useful for deep learning.", 32 | license="MIT", 33 | long_description=long_description, 34 | long_description_content_type="text/markdown", 35 | url="https://github.com/asteroid-team/torch-audiomentations", 36 | packages=find_packages( 37 | exclude=["build", "scripts", "dist", "images", "test_fixtures", "tests"] 38 | ), 39 | install_requires=[ 40 | "julius>=0.2.3,<0.3", 41 | "torch>=1.7.0", 42 | "torchaudio>=0.9.0,<2.9", 43 | "torch-pitch-shift>=1.2.2", 44 | ], 45 | extras_require={"extras": ["PyYAML"]}, 46 | python_requires=">=3.6", 47 | classifiers=[ 48 | "Programming Language :: Python :: 3", 49 | "License :: OSI Approved :: MIT License", 50 | "Operating System :: OS Independent", 51 | "Development Status :: 3 - Alpha", 52 | "Intended Audience :: Developers", 53 | "Intended Audience :: Science/Research", 54 | "Topic :: Multimedia", 55 | "Topic :: Multimedia :: Sound/Audio", 56 | "Topic :: Scientific/Engineering", 57 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | 6 | SUPPORTED_EXTENSIONS = (".wav",) 7 | 8 | 9 | def find_audio_files_in_paths( 10 | paths: Union[List[Path], List[str], Path, str], 11 | filename_endings=SUPPORTED_EXTENSIONS, 12 | traverse_subdirectories=True, 13 | follow_symlinks=True, 14 | ): 15 | """Return a list of paths to all audio files with the given extension(s) contained in the list or in its directories. 16 | Also traverses subdirectories by default. 17 | """ 18 | 19 | file_paths = [] 20 | 21 | if isinstance(paths, (list, tuple, set)): 22 | paths = list(paths) 23 | else: 24 | paths = [paths] 25 | 26 | for p in paths: 27 | if str(p).lower().endswith(SUPPORTED_EXTENSIONS): 28 | file_path = Path(os.path.abspath(p)) 29 | file_paths.append(file_path) 30 | elif os.path.isdir(p): 31 | file_paths += find_audio_files( 32 | p, 33 | filename_endings=filename_endings, 34 | traverse_subdirectories=traverse_subdirectories, 35 | follow_symlinks=follow_symlinks, 36 | ) 37 | return file_paths 38 | 39 | 40 | def find_audio_files( 41 | root_path, 42 | filename_endings=SUPPORTED_EXTENSIONS, 43 | traverse_subdirectories=True, 44 | follow_symlinks=True, 45 | ): 46 | """Return a list of paths to all audio files with the given extension(s) in a directory. 47 | Also traverses subdirectories by default. 
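For example (the path argument here is just illustrative): find_audio_files("test_fixtures")  # returns Path objects for every .wav file found, including files in subdirectories. The filename check is case-insensitive, so a file like bg_short.WAV is matched as well.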
48 | """ 49 | file_paths = [] 50 | 51 | for root, dirs, filenames in os.walk(root_path, followlinks=follow_symlinks): 52 | filenames = sorted(filenames) 53 | for filename in filenames: 54 | input_path = os.path.abspath(root) 55 | file_path = os.path.join(input_path, filename) 56 | 57 | if filename.lower().endswith(filename_endings): 58 | file_paths.append(Path(file_path)) 59 | if not traverse_subdirectories: 60 | # prevent descending into subfolders 61 | break 62 | 63 | return file_paths 64 | -------------------------------------------------------------------------------- /tests/test_shuffle_channels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from numpy.testing import assert_array_equal 5 | 6 | from torch_audiomentations import ShuffleChannels 7 | from torch_audiomentations.core.transforms_interface import ModeNotSupportedException 8 | 9 | 10 | class TestShuffleChannels: 11 | def test_shuffle_mono(self): 12 | samples = torch.from_numpy( 13 | np.array([[[1.0, -1.0, 1.0, -1.0, 1.0]]], dtype=np.float32) 14 | ) 15 | augment = ShuffleChannels(p=1.0, output_type="dict") 16 | 17 | with pytest.warns(UserWarning): 18 | processed_samples = augment(samples).samples 19 | 20 | assert_array_equal(samples.numpy(), processed_samples.numpy()) 21 | 22 | @pytest.mark.parametrize( 23 | "device_name", 24 | [ 25 | pytest.param("cpu"), 26 | pytest.param( 27 | "cuda", 28 | marks=pytest.mark.skip("Requires CUDA") 29 | if not torch.cuda.is_available() 30 | else [], 31 | ), 32 | ], 33 | ) 34 | def test_variability_within_batch(self, device_name): 35 | device = torch.device(device_name) 36 | torch.manual_seed(42) 37 | 38 | samples = np.array( 39 | [[1.0, -1.0, 1.0, -1.0, 1.0], [0.1, -0.1, 0.1, -0.1, 0.1]], dtype=np.float32 40 | ) 41 | samples = np.stack([samples] * 1000, axis=0) 42 | samples = torch.from_numpy(samples).to(device) 43 | 44 | augment = ShuffleChannels(p=1.0, output_type="dict") 45 | processed_samples = augment(samples).samples 46 | 47 | orders = {"original": 0, "swapped": 0} 48 | for i in range(processed_samples.shape[0]): 49 | if processed_samples[i, 0, 0] > 0.5: 50 | orders["original"] += 1 51 | else: 52 | orders["swapped"] += 1 53 | 54 | for order in orders: 55 | assert orders[order] > 50 56 | 57 | def test_unsupported_mode(self): 58 | with pytest.raises(ModeNotSupportedException): 59 | ShuffleChannels(mode="per_batch", p=1.0, output_type="dict") 60 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/polarity_inversion.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from typing import Optional 3 | 4 | from ..core.transforms_interface import BaseWaveformTransform 5 | from ..utils.object_dict import ObjectDict 6 | 7 | 8 | class PolarityInversion(BaseWaveformTransform): 9 | """ 10 | Flip the audio samples upside-down, reversing their polarity. In other words, multiply the 11 | waveform by -1, so negative values become positive, and vice versa. The result will sound 12 | the same compared to the original when played back in isolation. However, when mixed with 13 | other audio sources, the result may be different. This waveform inversion technique 14 | is sometimes used for audio cancellation or obtaining the difference between two waveforms. 
15 | However, in the context of audio data augmentation, this transform can be useful when 16 | training phase-aware machine learning models. 17 | """ 18 | 19 | supported_modes = {"per_batch", "per_example", "per_channel"} 20 | 21 | supports_multichannel = True 22 | requires_sample_rate = False 23 | 24 | supports_target = True 25 | requires_target = False 26 | 27 | def __init__( 28 | self, 29 | mode: str = "per_example", 30 | p: float = 0.5, 31 | p_mode: Optional[str] = None, 32 | sample_rate: Optional[int] = None, 33 | target_rate: Optional[int] = None, 34 | output_type: Optional[str] = None, 35 | ): 36 | super().__init__( 37 | mode=mode, 38 | p=p, 39 | p_mode=p_mode, 40 | sample_rate=sample_rate, 41 | target_rate=target_rate, 42 | output_type=output_type, 43 | ) 44 | 45 | def apply_transform( 46 | self, 47 | samples: Tensor = None, 48 | sample_rate: Optional[int] = None, 49 | targets: Optional[Tensor] = None, 50 | target_rate: Optional[int] = None, 51 | ) -> ObjectDict: 52 | return ObjectDict( 53 | samples=-samples, 54 | sample_rate=sample_rate, 55 | targets=targets, 56 | target_rate=target_rate, 57 | ) 58 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/band_stop_filter.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from typing import Optional 3 | 4 | from ..augmentations.band_pass_filter import BandPassFilter 5 | from ..utils.object_dict import ObjectDict 6 | 7 | 8 | class BandStopFilter(BandPassFilter): 9 | """ 10 | Apply band-stop filtering to the input audio. Also known as notch filter, 11 | band reject filter and frequency mask. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | min_center_frequency=200, 17 | max_center_frequency=4000, 18 | min_bandwidth_fraction=0.5, 19 | max_bandwidth_fraction=1.99, 20 | mode: str = "per_example", 21 | p: float = 0.5, 22 | p_mode: str = None, 23 | sample_rate: int = None, 24 | target_rate: int = None, 25 | output_type: Optional[str] = None, 26 | ): 27 | """ 28 | :param min_center_frequency: Minimum center frequency in hertz 29 | :param max_center_frequency: Maximum center frequency in hertz 30 | :param min_bandwidth_fraction: Minimum bandwidth fraction relative to center 31 | frequency (number between 0.0 and 2.0) 32 | :param max_bandwidth_fraction: Maximum bandwidth fraction relative to center 33 | frequency (number between 0.0 and 2.0) 34 | :param mode: 35 | :param p: 36 | :param p_mode: 37 | :param sample_rate: 38 | :param target_rate: 39 | """ 40 | 41 | super().__init__( 42 | min_center_frequency, 43 | max_center_frequency, 44 | min_bandwidth_fraction, 45 | max_bandwidth_fraction, 46 | mode=mode, 47 | p=p, 48 | p_mode=p_mode, 49 | sample_rate=sample_rate, 50 | target_rate=target_rate, 51 | output_type=output_type, 52 | ) 53 | 54 | def apply_transform( 55 | self, 56 | samples: Tensor = None, 57 | sample_rate: Optional[int] = None, 58 | targets: Optional[Tensor] = None, 59 | target_rate: Optional[int] = None, 60 | ) -> ObjectDict: 61 | perturbed = super().apply_transform( 62 | samples.clone(), 63 | sample_rate, 64 | targets=targets.clone() if targets is not None else None, 65 | target_rate=target_rate, 66 | ) 67 | 68 | perturbed.samples = samples - perturbed.samples 69 | return perturbed 70 | -------------------------------------------------------------------------------- /tests/test_file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
shutil 3 | import tempfile 4 | import uuid 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from tests.utils import TEST_FIXTURES_DIR 10 | from torch_audiomentations.utils.file import find_audio_files, find_audio_files_in_paths 11 | 12 | 13 | class TestFileUtils: 14 | def test_find_audio_files(self): 15 | file_paths = find_audio_files(TEST_FIXTURES_DIR) 16 | file_paths = [Path(fp).name for fp in file_paths] 17 | assert set(file_paths) == { 18 | "acoustic_guitar_0.wav", 19 | "bg.wav", 20 | "bg_short.WAV", 21 | "impulse_response_0.wav", 22 | "stereo_noise.wav", 23 | } 24 | 25 | def test_find_audio_files_in_paths(self): 26 | paths = [ 27 | os.path.join(TEST_FIXTURES_DIR, "bg"), 28 | os.path.join(TEST_FIXTURES_DIR, "bg_short"), 29 | os.path.join(TEST_FIXTURES_DIR, "ir", "impulse_response_0.wav"), 30 | ] 31 | file_paths = find_audio_files_in_paths(paths) 32 | file_paths = [Path(fp).name for fp in file_paths] 33 | assert set(file_paths) == { 34 | "bg.wav", 35 | "bg_short.WAV", 36 | "impulse_response_0.wav", 37 | "stereo_noise.wav", 38 | } 39 | 40 | @pytest.mark.skipif( 41 | os.name == "nt", reason="Symlink testing is not relevant on Windows" 42 | ) 43 | def test_follow_directory_symlink(self): 44 | tmp_dir_1 = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())[0:12]) 45 | tmp_dir_2 = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())[0:12]) 46 | os.makedirs(tmp_dir_1, exist_ok=True) 47 | os.makedirs(tmp_dir_2, exist_ok=True) 48 | 49 | assert tmp_dir_1 != tmp_dir_2 50 | 51 | tmp_file_path = os.path.join(tmp_dir_1, "{}.wav".format(str(uuid.uuid4())[0:12])) 52 | shutil.copyfile(TEST_FIXTURES_DIR / "acoustic_guitar_0.wav", tmp_file_path) 53 | 54 | file_paths = find_audio_files(tmp_dir_2) 55 | assert len(file_paths) == 0 56 | 57 | symlink_path = os.path.join(tmp_dir_2, "subdir") 58 | os.symlink(tmp_dir_1, symlink_path, target_is_directory=True) 59 | 60 | file_paths = find_audio_files(tmp_dir_2) 61 | assert len(file_paths) == 1 62 | assert Path(file_paths[0]).name == Path(tmp_file_path).name 63 | 64 | os.unlink(symlink_path) 65 | os.unlink(tmp_file_path) 66 | os.rmdir(tmp_dir_1) 67 | os.rmdir(tmp_dir_2) 68 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/time_inversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.object_dict import ObjectDict 7 | 8 | 9 | class TimeInversion(BaseWaveformTransform): 10 | """ 11 | Reverse (invert) the audio along the time axis similar to random flip of 12 | an image in the visual domain. This can be relevant in the context of audio 13 | classification. 
It was successfully applied in the paper 14 | AudioCLIP: Extending CLIP to Image, Text and Audio 15 | https://arxiv.org/pdf/2106.13043.pdf 16 | """ 17 | 18 | supported_modes = {"per_batch", "per_example", "per_channel"} 19 | 20 | supports_multichannel = True 21 | requires_sample_rate = False 22 | 23 | supports_target = True 24 | requires_target = False 25 | 26 | def __init__( 27 | self, 28 | mode: str = "per_example", 29 | p: float = 0.5, 30 | p_mode: str = None, 31 | sample_rate: int = None, 32 | target_rate: int = None, 33 | output_type: Optional[str] = None, 34 | ): 35 | """ 36 | :param mode: 37 | :param p: 38 | :param p_mode: 39 | :param sample_rate: 40 | """ 41 | super().__init__( 42 | mode=mode, 43 | p=p, 44 | p_mode=p_mode, 45 | sample_rate=sample_rate, 46 | target_rate=target_rate, 47 | output_type=output_type, 48 | ) 49 | 50 | def apply_transform( 51 | self, 52 | samples: Tensor = None, 53 | sample_rate: Optional[int] = None, 54 | targets: Optional[Tensor] = None, 55 | target_rate: Optional[int] = None, 56 | ) -> ObjectDict: 57 | # torch.flip() is supposed to be slower than np.flip() 58 | # An alternative is to use advanced indexing: https://github.com/pytorch/pytorch/issues/16424 59 | # reverse_index = torch.arange(selected_samples.size(-1) - 1, -1, -1).to(selected_samples.device) 60 | # transformed_samples = selected_samples[..., reverse_index] 61 | # return transformed_samples 62 | 63 | flipped_samples = torch.flip(samples, dims=(-1,)) 64 | if targets is None: 65 | flipped_targets = targets 66 | else: 67 | flipped_targets = torch.flip(targets, dims=(-2,)) 68 | 69 | return ObjectDict( 70 | samples=flipped_samples, 71 | sample_rate=sample_rate, 72 | targets=flipped_targets, 73 | target_rate=target_rate, 74 | ) 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Jetbrains IDE 132 | .idea 133 | 134 | # Scripts/demo output 135 | scripts/**/*.wav 136 | scripts/**/*.png 137 | 138 | # Anaconda temporary requirements file 139 | conda*.requirements.txt 140 | 141 | # vscode 142 | .vscode 143 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from torch import Tensor 4 | 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.object_dict import ObjectDict 7 | 8 | 9 | class Padding(BaseWaveformTransform): 10 | supported_modes = {"per_batch", "per_example", "per_channel"} 11 | supports_multichannel = True 12 | requires_sample_rate = False 13 | 14 | supports_target = True 15 | requires_target = False 16 | 17 | def __init__( 18 | self, 19 | min_fraction=0.1, 20 | max_fraction=0.5, 21 | pad_section="end", 22 | mode="per_batch", 23 | p=0.5, 24 | p_mode: Optional[str] = None, 25 | sample_rate: Optional[int] = None, 26 | target_rate: Optional[int] = None, 27 | output_type: Optional[str] = None, 28 | ): 29 | super().__init__( 30 | mode=mode, 31 | p=p, 32 | p_mode=p_mode, 33 | sample_rate=sample_rate, 34 | target_rate=target_rate, 35 | output_type=output_type, 36 | ) 37 | self.min_fraction = min_fraction 38 | self.max_fraction = max_fraction 39 | self.pad_section = pad_section 40 | if not self.min_fraction >= 0.0: 41 | raise ValueError("minimum fraction should be greater than zero.") 42 | if self.min_fraction > self.max_fraction: 43 | raise ValueError( 44 | "minimum fraction should be less than or equal to maximum fraction." 
45 | ) 46 | assert self.pad_section in ( 47 | "start", 48 | "end", 49 | ), 'pad_section must be "start" or "end"' 50 | 51 | def randomize_parameters( 52 | self, 53 | samples: Tensor = None, 54 | sample_rate: Optional[int] = None, 55 | targets: Optional[Tensor] = None, 56 | target_rate: Optional[int] = None, 57 | ): 58 | input_length = samples.shape[-1] 59 | self.transform_parameters["pad_length"] = torch.randint( 60 | int(input_length * self.min_fraction), 61 | int(input_length * self.max_fraction), 62 | (samples.shape[0],), 63 | ) 64 | 65 | def apply_transform( 66 | self, 67 | samples: Tensor, 68 | sample_rate: Optional[int] = None, 69 | targets: Optional[int] = None, 70 | target_rate: Optional[int] = None, 71 | ) -> ObjectDict: 72 | for i, index in enumerate(self.transform_parameters["pad_length"]): 73 | if self.pad_section == "start": 74 | samples[i, :, :index] = 0.0 75 | else: 76 | samples[i, :, -index:] = 0.0 77 | 78 | return ObjectDict( 79 | samples=samples, 80 | sample_rate=sample_rate, 81 | targets=targets, 82 | target_rate=target_rate, 83 | ) 84 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/convolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_audiomentations.utils.fft import rfft, irfft 4 | 5 | _NEXT_FAST_LEN = {} 6 | 7 | 8 | def next_fast_len(size): 9 | """ 10 | Returns the next largest number ``n >= size`` whose prime factors are all 11 | 2, 3, or 5. These sizes are efficient for fast fourier transforms. 12 | Equivalent to :func:`scipy.fftpack.next_fast_len`. 13 | 14 | Note: This function was originally copied from the https://github.com/pyro-ppl/pyro 15 | repository, where the license was Apache 2.0. Any modifications to the original code can be 16 | found at https://github.com/asteroid-team/torch-audiomentations/commits 17 | 18 | :param int size: A positive number. 19 | :returns: A possibly larger number. 20 | :rtype int: 21 | """ 22 | try: 23 | return _NEXT_FAST_LEN[size] 24 | except KeyError: 25 | pass 26 | 27 | assert isinstance(size, int) and size > 0 28 | next_size = size 29 | while True: 30 | remaining = next_size 31 | for n in (2, 3, 5): 32 | while remaining % n == 0: 33 | remaining //= n 34 | if remaining == 1: 35 | _NEXT_FAST_LEN[size] = next_size 36 | return next_size 37 | next_size += 1 38 | 39 | 40 | def convolve(signal, kernel, mode="full"): 41 | """ 42 | Computes the 1-d convolution of signal by kernel using FFTs. 43 | The two arguments should have the same rightmost dim, but may otherwise be 44 | arbitrarily broadcastable. 45 | 46 | Note: This function was originally copied from the https://github.com/pyro-ppl/pyro 47 | repository, where the license was Apache 2.0. Any modifications to the original code can be 48 | found at https://github.com/asteroid-team/torch-audiomentations/commits 49 | 50 | :param torch.Tensor signal: A signal to convolve. 51 | :param torch.Tensor kernel: A convolution kernel. 52 | :param str mode: One of: 'full', 'valid', 'same'. 53 | :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)`` 54 | and ``n = kernel.size(-1)``, the rightmost size of the result will be: 55 | ``m + n - 1`` if mode is 'full'; 56 | ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or 57 | ``max(m, n)`` if mode is 'same'. 
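For example, letting ``m = 5`` and ``n = 3``, the result has 5 + 3 - 1 = 7 samples with mode='full', 5 - 3 + 1 = 3 samples with mode='valid', and max(5, 3) = 5 samples with mode='same'.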
58 | :rtype torch.Tensor: 59 | """ 60 | m = signal.size(-1) 61 | n = kernel.size(-1) 62 | if mode == "full": 63 | truncate = m + n - 1 64 | elif mode == "valid": 65 | truncate = max(m, n) - min(m, n) + 1 66 | elif mode == "same": 67 | truncate = max(m, n) 68 | else: 69 | raise ValueError("Unknown mode: {}".format(mode)) 70 | 71 | # Compute convolution using fft. 72 | padded_size = m + n - 1 73 | # Round up for cheaper fft. 74 | fast_fft_size = next_fast_len(padded_size) 75 | f_signal = rfft(signal, n=fast_fft_size) 76 | f_kernel = rfft(kernel, n=fast_fft_size) 77 | f_result = f_signal * f_kernel 78 | result = irfft(f_result, n=fast_fft_size) 79 | 80 | start_idx = (padded_size - truncate) // 2 81 | return result[..., start_idx : start_idx + truncate] 82 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/shuffle_channels.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import warnings 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | from ..core.transforms_interface import BaseWaveformTransform 9 | from ..utils.object_dict import ObjectDict 10 | 11 | 12 | class ShuffleChannels(BaseWaveformTransform): 13 | """ 14 | Given multichannel audio input (e.g. stereo), shuffle the channels, e.g. so left can become right and vice versa. 15 | This transform can help combat positional bias in machine learning models that input multichannel waveforms. 16 | 17 | If the input audio is mono, this transform does nothing except emit a warning. 18 | """ 19 | 20 | supports_multichannel = True 21 | requires_sample_rate = False 22 | supported_modes = {"per_example"} 23 | 24 | def __init__( 25 | self, 26 | mode: str = "per_example", 27 | p: float = 0.5, 28 | p_mode: Optional[str] = None, 29 | sample_rate: Optional[int] = None, 30 | target_rate: Optional[int] = None, 31 | output_type: Optional[str] = None, 32 | ): 33 | super().__init__( 34 | mode=mode, 35 | p=p, 36 | p_mode=p_mode, 37 | sample_rate=sample_rate, 38 | target_rate=target_rate, 39 | output_type=output_type, 40 | ) 41 | 42 | def randomize_parameters( 43 | self, 44 | samples: Tensor = None, 45 | sample_rate: Optional[int] = None, 46 | targets: Optional[Tensor] = None, 47 | target_rate: Optional[int] = None, 48 | ): 49 | batch_size = samples.shape[0] 50 | num_channels = samples.shape[1] 51 | assert num_channels <= 255 52 | permutations = torch.zeros( 53 | (batch_size, num_channels), dtype=torch.int64, device=samples.device 54 | ) 55 | for i in range(batch_size): 56 | permutations[i] = torch.randperm(num_channels, device=samples.device) 57 | self.transform_parameters["permutations"] = permutations 58 | 59 | def apply_transform( 60 | self, 61 | samples: Tensor = None, 62 | sample_rate: Optional[int] = None, 63 | targets: Optional[Tensor] = None, 64 | target_rate: Optional[int] = None, 65 | ) -> ObjectDict: 66 | if samples.shape[1] == 1: 67 | warnings.warn( 68 | "Mono audio was passed to ShuffleChannels - there are no channels to shuffle." 69 | " The input will be returned unchanged." 
70 | ) 71 | return ObjectDict( 72 | samples=samples, 73 | sample_rate=sample_rate, 74 | targets=targets, 75 | target_rate=target_rate, 76 | ) 77 | 78 | for i in range(samples.size(0)): 79 | samples[i] = samples[i, self.transform_parameters["permutations"][i]] 80 | if targets is not None: 81 | targets[i] = targets[i, self.transform_parameters["permutations"][i]] 82 | 83 | return ObjectDict( 84 | samples=samples, 85 | sample_rate=sample_rate, 86 | targets=targets, 87 | target_rate=target_rate, 88 | ) 89 | -------------------------------------------------------------------------------- /tests/test_differentiable.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from torch.optim import SGD 6 | 7 | from tests.utils import TEST_FIXTURES_DIR 8 | from torch_audiomentations import ( 9 | AddBackgroundNoise, 10 | ApplyImpulseResponse, 11 | Gain, 12 | PeakNormalization, 13 | PolarityInversion, 14 | Compose, 15 | Shift, 16 | LowPassFilter, 17 | HighPassFilter, 18 | ) 19 | 20 | BG_NOISE_PATH = TEST_FIXTURES_DIR / "bg" 21 | IR_PATH = TEST_FIXTURES_DIR / "ir" 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "augment", 26 | [ 27 | # Differentiable transforms: 28 | AddBackgroundNoise(BG_NOISE_PATH, 20, p=1.0, output_type="dict"), 29 | ApplyImpulseResponse(IR_PATH, p=1.0, output_type="dict"), 30 | Compose( 31 | transforms=[ 32 | Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=1.0), 33 | PolarityInversion(p=1.0), 34 | ], 35 | output_type="dict", 36 | ), 37 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0, output_type="dict"), 38 | PolarityInversion(p=1.0, output_type="dict"), 39 | Shift(p=1.0, output_type="dict"), 40 | # Non-differentiable transforms: 41 | # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: 42 | # [torch.DoubleTensor [1, 1, 5]], which is output 0 of IndexBackward, is at version 1; expected version 0 instead. 43 | # Hint: enable anomaly detection to find the operation that failed to compute its gradient, 44 | # with torch.autograd.set_detect_anomaly(True). 45 | pytest.param( 46 | HighPassFilter(p=1.0, output_type="dict"), 47 | marks=pytest.mark.skip("Not differentiable"), 48 | ), 49 | pytest.param( 50 | LowPassFilter(p=1.0, output_type="dict"), 51 | marks=pytest.mark.skip("Not differentiable"), 52 | ), 53 | pytest.param( 54 | PeakNormalization(p=1.0, output_type="dict"), 55 | marks=pytest.mark.skip("Not differentiable"), 56 | ), 57 | ], 58 | ) 59 | def test_transform_is_differentiable(augment): 60 | sample_rate = 16000 61 | # Note: using float64 dtype to be compatible with AddBackgroundNoise fixtures 62 | samples = torch.tensor( 63 | [[1.0, 0.5, -0.25, -0.125, 0.0]], dtype=torch.float64 64 | ).unsqueeze(1) 65 | samples_cpy = deepcopy(samples) 66 | 67 | # We are going to convert the input tensor to a nn.Parameter so that we can 68 | # track the gradients with respect to it. We'll "optimize" the input signal 69 | # to be closer to that after the augmentation to test differentiability 70 | # of the transform. If the signal got changed in any way, and the test 71 | # didn't crash, it means it works. 
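    # Concretely: the loss below is mean(|x - T(x)|), backpropagated through the
    # transform T itself. Transforms that break the autograd graph (for example by
    # modifying, in place, a tensor that backward() still needs) raise a RuntimeError
    # here, which is exactly what the skipped transforms above demonstrate.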
72 | samples = torch.nn.Parameter(samples) 73 | optim = SGD([samples], lr=1.0) 74 | for i in range(10): 75 | optim.zero_grad() 76 | transformed = augment(samples=samples, sample_rate=sample_rate).samples 77 | # Compute mean absolute error 78 | loss = torch.mean(torch.abs(samples - transformed)) 79 | loss.backward() 80 | optim.step() 81 | 82 | assert (samples != samples_cpy).any() 83 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/gain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.dsp import convert_decibels_to_amplitude_ratio 7 | from ..utils.object_dict import ObjectDict 8 | 9 | 10 | class Gain(BaseWaveformTransform): 11 | """ 12 | Multiply the audio by a random amplitude factor to reduce or increase the volume. This 13 | technique can help a model become somewhat invariant to the overall gain of the input audio. 14 | 15 | Warning: This transform can return samples outside the [-1, 1] range, which may lead to 16 | clipping or wrap distortion, depending on what you do with the audio in a later stage. 17 | See also https://en.wikipedia.org/wiki/Clipping_(audio)#Digital_clipping 18 | """ 19 | 20 | supported_modes = {"per_batch", "per_example", "per_channel"} 21 | 22 | supports_multichannel = True 23 | requires_sample_rate = False 24 | 25 | supports_target = True 26 | requires_target = False 27 | 28 | def __init__( 29 | self, 30 | min_gain_in_db: float = -18.0, 31 | max_gain_in_db: float = 6.0, 32 | mode: str = "per_example", 33 | p: float = 0.5, 34 | p_mode: Optional[str] = None, 35 | sample_rate: Optional[int] = None, 36 | target_rate: Optional[int] = None, 37 | output_type: Optional[str] = None, 38 | ): 39 | super().__init__( 40 | mode=mode, 41 | p=p, 42 | p_mode=p_mode, 43 | sample_rate=sample_rate, 44 | target_rate=target_rate, 45 | output_type=output_type, 46 | ) 47 | self.min_gain_in_db = min_gain_in_db 48 | self.max_gain_in_db = max_gain_in_db 49 | if self.min_gain_in_db >= self.max_gain_in_db: 50 | raise ValueError("max_gain_in_db must be higher than min_gain_in_db") 51 | 52 | def randomize_parameters( 53 | self, 54 | samples: Tensor = None, 55 | sample_rate: Optional[int] = None, 56 | targets: Optional[Tensor] = None, 57 | target_rate: Optional[int] = None, 58 | ): 59 | distribution = torch.distributions.Uniform( 60 | low=torch.tensor( 61 | self.min_gain_in_db, dtype=torch.float32, device=samples.device 62 | ), 63 | high=torch.tensor( 64 | self.max_gain_in_db, dtype=torch.float32, device=samples.device 65 | ), 66 | validate_args=True, 67 | ) 68 | selected_batch_size = samples.size(0) 69 | self.transform_parameters["gain_factors"] = ( 70 | convert_decibels_to_amplitude_ratio( 71 | distribution.sample(sample_shape=(selected_batch_size,)) 72 | ) 73 | .unsqueeze(1) 74 | .unsqueeze(1) 75 | ) 76 | 77 | def apply_transform( 78 | self, 79 | samples: Tensor = None, 80 | sample_rate: Optional[int] = None, 81 | targets: Optional[Tensor] = None, 82 | target_rate: Optional[int] = None, 83 | ) -> ObjectDict: 84 | return ObjectDict( 85 | samples=samples * self.transform_parameters["gain_factors"], 86 | sample_rate=sample_rate, 87 | targets=targets, 88 | target_rate=target_rate, 89 | ) 90 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/random_crop.py: 
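A minimal usage sketch for the Gain transform above (hypothetical shapes and sample rate; the deliberately narrow range pins the gain to roughly -6 dB, i.e. an amplitude ratio of 10 ** (-6 / 20), about 0.501):

import torch
from torch_audiomentations import Gain

augment = Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6.0, p=1.0, output_type="dict")
audio = torch.rand(4, 2, 16000) * 2 - 1  # [batch_size, num_channels, num_samples]
quieter = augment(samples=audio, sample_rate=16000).samples  # approximately 0.501 * audio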
-------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | import warnings 4 | from torch_audiomentations.utils.multichannel import is_multichannel 5 | from ..core.transforms_interface import MultichannelAudioNotSupportedException 6 | 7 | 8 | class RandomCrop(torch.nn.Module): 9 | 10 | """Crop the audio to predefined length in max_length.""" 11 | 12 | supports_multichannel = True 13 | 14 | def __init__( 15 | self, max_length: float, sampling_rate: int, max_length_unit: str = "seconds" 16 | ): 17 | """ 18 | :param max_length: length to which samples are to be cropped. 19 | :sampling_rate: sampling rate of input samples. 20 | :max_length_unit: defines the unit of max_length. 21 | "seconds": Number of seconds 22 | "samples": Number of audio samples 23 | """ 24 | super(RandomCrop, self).__init__() 25 | self.sampling_rate = sampling_rate 26 | if max_length_unit == "seconds": 27 | self.num_samples = int(self.sampling_rate * max_length) 28 | elif max_length_unit == "samples": 29 | self.num_samples = int(max_length) 30 | else: 31 | raise ValueError('max_length_unit must be "samples" or "seconds"') 32 | 33 | def forward(self, samples, sampling_rate: typing.Optional[int] = None): 34 | sample_rate = sampling_rate or self.sampling_rate 35 | if sample_rate is None: 36 | raise RuntimeError("sample_rate is required") 37 | 38 | if len(samples) == 0: 39 | warnings.warn( 40 | "An empty samples tensor was passed to {}".format(self.__class__.__name__) 41 | ) 42 | return samples 43 | 44 | if len(samples.shape) != 3: 45 | raise RuntimeError( 46 | "torch-audiomentations expects input tensors to be three-dimensional, with" 47 | " dimension ordering like [batch_size, num_channels, num_samples]. If your" 48 | " audio is mono, you can use a shape like [batch_size, 1, num_samples]." 49 | ) 50 | 51 | if is_multichannel(samples): 52 | if samples.shape[1] > samples.shape[2]: 53 | warnings.warn( 54 | "Multichannel audio must have channels first, not channels last. 
In" 55 | " other words, the shape must be (batch size, channels, samples), not" 56 | " (batch_size, samples, channels)" 57 | ) 58 | if not self.supports_multichannel: 59 | raise MultichannelAudioNotSupportedException( 60 | "{} only supports mono audio, not multichannel audio".format( 61 | self.__class__.__name__ 62 | ) 63 | ) 64 | 65 | if samples.shape[2] < self.num_samples: 66 | warnings.warn("audio length less than cropping length") 67 | return samples 68 | 69 | start_indices = torch.randint( 70 | 0, samples.shape[2] - self.num_samples, (samples.shape[2],) 71 | ) 72 | samples_cropped = torch.empty( 73 | (samples.shape[0], samples.shape[1], self.num_samples), device=samples.device 74 | ) 75 | for i, sample in enumerate(samples): 76 | samples_cropped[i] = sample.unsqueeze(0)[ 77 | :, :, start_indices[i] : start_indices[i] + self.num_samples 78 | ] 79 | 80 | return samples_cropped 81 | -------------------------------------------------------------------------------- /tests/test_mix.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from torch_audiomentations.augmentations.mix import Mix 6 | from torch_audiomentations.utils.dsp import calculate_rms 7 | from torch_audiomentations.utils.io import Audio 8 | from .utils import TEST_FIXTURES_DIR 9 | 10 | 11 | class TestMix(unittest.TestCase): 12 | def setUp(self): 13 | self.sample_rate = 16000 14 | audio = Audio(self.sample_rate, mono=True) 15 | self.guitar = audio(TEST_FIXTURES_DIR / "acoustic_guitar_0.wav")[None] 16 | self.noise = audio(TEST_FIXTURES_DIR / "bg" / "bg.wav")[None] 17 | 18 | common_num_samples = min(self.guitar.shape[-1], self.noise.shape[-1]) 19 | self.guitar = self.guitar[:, :, :common_num_samples] 20 | 21 | self.guitar_target = torch.zeros( 22 | (1, 1, common_num_samples // 7, 2), dtype=torch.int64 23 | ) 24 | self.guitar_target[:, :, :, 0] = 1 25 | 26 | self.noise = self.noise[:, :, :common_num_samples] 27 | self.noise_target = torch.zeros( 28 | (1, 1, common_num_samples // 7, 2), dtype=torch.int64 29 | ) 30 | self.noise_target[:, :, :, 1] = 1 31 | 32 | self.input_audios = torch.cat([self.guitar, self.noise], dim=0) 33 | self.input_targets = torch.cat([self.guitar_target, self.noise_target], dim=0) 34 | 35 | def test_varying_snr_within_batch(self): 36 | min_snr_in_db = 3 37 | max_snr_in_db = 30 38 | augment = Mix( 39 | min_snr_in_db=min_snr_in_db, 40 | max_snr_in_db=max_snr_in_db, 41 | p=1.0, 42 | output_type="dict", 43 | ) 44 | mixed_audios = augment(self.input_audios, self.sample_rate).samples 45 | 46 | self.assertEqual(tuple(mixed_audios.shape), tuple(self.input_audios.shape)) 47 | self.assertFalse(torch.equal(mixed_audios, self.input_audios)) 48 | 49 | added_audios = mixed_audios - self.input_audios 50 | 51 | for i in range(len(self.input_audios)): 52 | signal_rms = calculate_rms(self.input_audios[i]) 53 | added_rms = calculate_rms(added_audios[i]) 54 | snr_in_db = 20 * torch.log10(signal_rms / added_rms).item() 55 | self.assertGreaterEqual(snr_in_db, min_snr_in_db) 56 | self.assertLessEqual(snr_in_db, max_snr_in_db) 57 | 58 | def test_targets_union(self): 59 | augment = Mix(p=1.0, mix_target="union", output_type="dict") 60 | mixtures = augment( 61 | samples=self.input_audios, 62 | sample_rate=self.sample_rate, 63 | targets=self.input_targets, 64 | ) 65 | mixed_targets = mixtures.targets 66 | 67 | # check guitar target is still active in first (guitar) sample 68 | self.assertTrue( 69 | torch.equal(mixed_targets[0, :, :, 0], 
self.input_targets[0, :, :, 0]) 70 | ) 71 | # check noise target is still active in second (noise) sample 72 | self.assertTrue( 73 | torch.equal(mixed_targets[1, :, :, 1], self.input_targets[1, :, :, 1]) 74 | ) 75 | 76 | def test_targets_original(self): 77 | augment = Mix(p=1.0, mix_target="original", output_type="dict") 78 | mixtures = augment( 79 | samples=self.input_audios, 80 | sample_rate=self.sample_rate, 81 | targets=self.input_targets, 82 | ) 83 | mixed_targets = mixtures.targets 84 | 85 | self.assertTrue(torch.equal(mixed_targets, self.input_targets)) 86 | -------------------------------------------------------------------------------- /tests/test_some_of.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | from numpy.testing import assert_array_equal 6 | 7 | from torch_audiomentations import PolarityInversion, PeakNormalization, Gain 8 | from torch_audiomentations import SomeOf 9 | from torch_audiomentations.utils.object_dict import ObjectDict 10 | 11 | 12 | class TestSomeOf(unittest.TestCase): 13 | def setUp(self): 14 | self.sample_rate = 16000 15 | self.audio = torch.randn(1, 1, 16000) 16 | 17 | self.transforms = [ 18 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-2, p=1.0), 19 | PolarityInversion(p=1.0), 20 | PeakNormalization(p=1.0), 21 | ] 22 | 23 | def test_some_of_without_specifying_output_type(self): 24 | augment = SomeOf(2, self.transforms) 25 | 26 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 27 | output = augment(samples=self.audio, sample_rate=self.sample_rate) 28 | # This dtype should be torch.Tensor until we switch to ObjectDict by default 29 | assert type(output) == torch.Tensor 30 | self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied 31 | 32 | def test_some_of_dict(self): 33 | augment = SomeOf(2, self.transforms, output_type="dict") 34 | 35 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 36 | output = augment(samples=self.audio, sample_rate=self.sample_rate) 37 | assert type(output) == ObjectDict 38 | self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied 39 | 40 | def test_some_of_with_p_zero(self): 41 | augment = SomeOf(2, self.transforms, p=0.0, output_type="dict") 42 | 43 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 44 | processed_samples = augment( 45 | samples=self.audio, sample_rate=self.sample_rate 46 | ).samples 47 | self.assertEqual(len(augment.transform_indexes), 0) # 0 transforms applied 48 | 49 | def test_some_of_tuple(self): 50 | augment = SomeOf((1, None), self.transforms, output_type="dict") 51 | 52 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 53 | processed_samples = augment( 54 | samples=self.audio, sample_rate=self.sample_rate 55 | ).samples 56 | self.assertTrue( 57 | len(augment.transform_indexes) > 0 58 | ) # at least one transform applied 59 | 60 | def test_some_of_freeze_and_unfreeze_parameters(self): 61 | augment = SomeOf(2, self.transforms, output_type="dict") 62 | 63 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 64 | samples = torch.from_numpy(samples) 65 | 66 | self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet 67 | processed_samples1 = augment( 68 | samples=samples, sample_rate=self.sample_rate 69 | ).samples.numpy() 70 | transform_indexes1 = augment.transform_indexes 71 | 
self.assertEqual(len(augment.transform_indexes), 2) 72 | 73 | augment.freeze_parameters() 74 | 75 | processed_samples2 = augment( 76 | samples=samples, sample_rate=self.sample_rate 77 | ).samples.numpy() 78 | transform_indexes2 = augment.transform_indexes 79 | assert_array_equal(processed_samples1, processed_samples2) 80 | assert_array_equal(transform_indexes1, transform_indexes2) 81 | -------------------------------------------------------------------------------- /tests/test_spliceout.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | import pytest 5 | 6 | from torch_audiomentations.augmentations.splice_out import SpliceOut 7 | from torch_audiomentations import Compose 8 | 9 | 10 | class TestSpliceout(unittest.TestCase): 11 | def test_splice_out(self): 12 | audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32) 13 | augment = Compose( 14 | [ 15 | SpliceOut(num_time_intervals=10, max_width=400, output_type="dict"), 16 | ], 17 | output_type="dict", 18 | ) 19 | splice_out_samples = augment( 20 | samples=audio_samples, sample_rate=16000 21 | ).samples.numpy() 22 | 23 | assert splice_out_samples.dtype == np.float32 24 | 25 | def test_splice_out_odd_hann(self): 26 | audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32) 27 | augment = Compose( 28 | [ 29 | SpliceOut(num_time_intervals=10, max_width=400, output_type="dict"), 30 | ], 31 | output_type="dict", 32 | ) 33 | splice_out_samples = augment( 34 | samples=audio_samples, sample_rate=16100 35 | ).samples.numpy() 36 | 37 | assert splice_out_samples.dtype == np.float32 38 | 39 | def test_splice_out_per_batch(self): 40 | audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32) 41 | augment = Compose( 42 | [ 43 | SpliceOut( 44 | num_time_intervals=10, 45 | max_width=400, 46 | mode="per_batch", 47 | p=1.0, 48 | output_type="dict", 49 | ), 50 | ], 51 | output_type="dict", 52 | ) 53 | splice_out_samples = augment( 54 | samples=audio_samples, sample_rate=16000 55 | ).samples.numpy() 56 | 57 | assert splice_out_samples.dtype == np.float32 58 | self.assertLess(splice_out_samples.sum(), audio_samples.numpy().sum()) 59 | self.assertEqual(splice_out_samples.shape, audio_samples.shape) 60 | 61 | def test_splice_out_multichannel(self): 62 | audio_samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32) 63 | augment = Compose( 64 | [ 65 | SpliceOut(num_time_intervals=10, max_width=400, output_type="dict"), 66 | ], 67 | output_type="dict", 68 | ) 69 | splice_out_samples = augment( 70 | samples=audio_samples, sample_rate=16000 71 | ).samples.numpy() 72 | 73 | assert splice_out_samples.dtype == np.float32 74 | self.assertLess(splice_out_samples.sum(), audio_samples.numpy().sum()) 75 | self.assertEqual(splice_out_samples.shape, audio_samples.shape) 76 | 77 | @pytest.mark.skip(reason="This test fails and SpliceOut is not released yet") 78 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 79 | def test_splice_out_cuda(self): 80 | audio_samples = ( 81 | torch.rand( 82 | size=(8, 1, 32000), dtype=torch.float32, device=torch.device("cuda") 83 | ) 84 | - 0.5 85 | ) 86 | augment = Compose( 87 | [ 88 | SpliceOut(num_time_intervals=10, max_width=400, output_type="dict"), 89 | ], 90 | output_type="dict", 91 | ) 92 | splice_out_samples = ( 93 | augment(samples=audio_samples, sample_rate=16000).samples.cpu().numpy() 94 | ) 95 | 96 | assert splice_out_samples.dtype == np.float32 97 | 
self.assertLess(splice_out_samples.sum(), audio_samples.cpu().numpy().sum()) 98 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/peak_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | from typing import Optional 4 | from torch import Tensor 5 | 6 | from ..core.transforms_interface import BaseWaveformTransform 7 | from ..utils.object_dict import ObjectDict 8 | 9 | 10 | class PeakNormalization(BaseWaveformTransform): 11 | """ 12 | Apply a constant amount of gain, so that highest signal level present in each audio snippet 13 | in the batch becomes 0 dBFS, i.e. the loudest level allowed if all samples must be between 14 | -1 and 1. 15 | 16 | This transform has an alternative mode (apply_to="only_too_loud_sounds") where it only 17 | applies to audio snippets that have extreme values outside the [-1, 1] range. This is useful 18 | for avoiding digital clipping in audio that is too loud, while leaving other audio 19 | untouched. 20 | """ 21 | 22 | supported_modes = {"per_batch", "per_example", "per_channel"} 23 | 24 | supports_multichannel = True 25 | requires_sample_rate = False 26 | 27 | supports_target = True 28 | requires_target = False 29 | 30 | def __init__( 31 | self, 32 | apply_to="all", 33 | mode: str = "per_example", 34 | p: float = 0.5, 35 | p_mode: typing.Optional[str] = None, 36 | sample_rate: typing.Optional[int] = None, 37 | target_rate: typing.Optional[int] = None, 38 | output_type: Optional[str] = None, 39 | ): 40 | super().__init__( 41 | mode=mode, 42 | p=p, 43 | p_mode=p_mode, 44 | sample_rate=sample_rate, 45 | target_rate=target_rate, 46 | output_type=output_type, 47 | ) 48 | assert apply_to in ("all", "only_too_loud_sounds") 49 | self.apply_to = apply_to 50 | 51 | def randomize_parameters( 52 | self, 53 | samples: Tensor = None, 54 | sample_rate: Optional[int] = None, 55 | targets: Optional[Tensor] = None, 56 | target_rate: Optional[int] = None, 57 | ): 58 | # Compute the most extreme value of each multichannel audio snippet in the batch 59 | most_extreme_values, _ = torch.max(torch.abs(samples), dim=-1) 60 | most_extreme_values, _ = torch.max(most_extreme_values, dim=-1) 61 | 62 | if self.apply_to == "all": 63 | # Avoid division by zero 64 | self.transform_parameters["selector"] = ( 65 | most_extreme_values > 0.0 66 | ) # type: torch.BoolTensor 67 | elif self.apply_to == "only_too_loud_sounds": 68 | # Apply peak normalization only to audio examples with 69 | # values outside the [-1, 1] range 70 | self.transform_parameters["selector"] = ( 71 | most_extreme_values > 1.0 72 | ) # type: torch.BoolTensor 73 | else: 74 | raise Exception("Unknown value of apply_to in PeakNormalization instance!") 75 | if self.transform_parameters["selector"].any(): 76 | self.transform_parameters["divisors"] = torch.reshape( 77 | most_extreme_values[self.transform_parameters["selector"]], (-1, 1, 1) 78 | ) 79 | 80 | def apply_transform( 81 | self, 82 | samples: Tensor = None, 83 | sample_rate: Optional[int] = None, 84 | targets: Optional[Tensor] = None, 85 | target_rate: Optional[int] = None, 86 | ) -> ObjectDict: 87 | if "divisors" in self.transform_parameters: 88 | samples[self.transform_parameters["selector"]] /= self.transform_parameters[ 89 | "divisors" 90 | ] 91 | 92 | return ObjectDict( 93 | samples=samples, 94 | sample_rate=sample_rate, 95 | targets=targets, 96 | target_rate=target_rate, 97 | ) 98 | 
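# A minimal usage sketch (hypothetical shapes and sample rate), runnable as a script:
if __name__ == "__main__":
    example_audio = 1.5 * torch.randn(2, 1, 16000)  # some peaks exceed the [-1, 1] range
    normalize = PeakNormalization(
        apply_to="only_too_loud_sounds", p=1.0, output_type="dict"
    )
    normalized = normalize(samples=example_audio, sample_rate=16000).samples
    assert normalized.abs().max() <= 1.0  # too-loud examples get scaled down to peak at 0 dBFS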
-------------------------------------------------------------------------------- /tests/test_impulse_response.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_audiomentations import ApplyImpulseResponse 7 | from torch_audiomentations.utils.io import Audio 8 | from .utils import TEST_FIXTURES_DIR 9 | 10 | 11 | @pytest.fixture 12 | def sample_rate(): 13 | yield 16000 14 | 15 | 16 | @pytest.fixture 17 | def input_audio(sample_rate): 18 | audio = Audio(sample_rate, mono=True) 19 | return audio(os.path.join(TEST_FIXTURES_DIR, "acoustic_guitar_0.wav"))[None] 20 | 21 | 22 | @pytest.fixture 23 | def input_audios(input_audio): 24 | batch_size = 32 25 | return torch.cat([input_audio] * batch_size, dim=0) 26 | 27 | 28 | @pytest.fixture 29 | def ir_path(): 30 | yield os.path.join(TEST_FIXTURES_DIR, "ir") 31 | 32 | 33 | @pytest.fixture() 34 | def ir_transform(ir_path, sample_rate): 35 | return ApplyImpulseResponse( 36 | ir_path, p=1.0, sample_rate=sample_rate, output_type="dict" 37 | ) 38 | 39 | 40 | @pytest.fixture() 41 | def ir_transform_no_guarantee(ir_path, sample_rate): 42 | return ApplyImpulseResponse( 43 | ir_path, p=0.0, sample_rate=sample_rate, output_type="dict" 44 | ) 45 | 46 | 47 | def test_impulse_response_guaranteed_with_single_tensor_input(ir_transform, input_audio): 48 | mixed_input = ir_transform(input_audio).samples 49 | assert mixed_input.shape == input_audio.shape 50 | assert not torch.equal(mixed_input, input_audio) 51 | 52 | 53 | @pytest.mark.parametrize("compensate_for_propagation_delay", [False, True]) 54 | def test_impulse_response_guaranteed_with_batched_tensor_input( 55 | ir_path, sample_rate, input_audios, compensate_for_propagation_delay 56 | ): 57 | mixed_inputs = ApplyImpulseResponse( 58 | ir_path, 59 | compensate_for_propagation_delay=compensate_for_propagation_delay, 60 | p=1.0, 61 | sample_rate=sample_rate, 62 | output_type="dict", 63 | )(input_audios).samples 64 | assert mixed_inputs.shape == input_audios.shape 65 | assert not torch.equal(mixed_inputs, input_audios) 66 | 67 | 68 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 69 | def test_impulse_response_guaranteed_with_batched_cuda_tensor_input( 70 | input_audios, ir_transform 71 | ): 72 | input_audio_cuda = input_audios.cuda() 73 | mixed_inputs = ir_transform(input_audio_cuda).samples 74 | assert not torch.equal(mixed_inputs, input_audio_cuda) 75 | assert mixed_inputs.shape == input_audio_cuda.shape 76 | assert mixed_inputs.dtype == input_audio_cuda.dtype 77 | assert mixed_inputs.device == input_audio_cuda.device 78 | 79 | 80 | def test_impulse_response_no_guarantee_with_single_tensor_input( 81 | input_audio, ir_transform_no_guarantee 82 | ): 83 | mixed_input = ir_transform_no_guarantee(input_audio).samples 84 | assert mixed_input.shape == input_audio.shape 85 | 86 | 87 | def test_impulse_response_no_guarantee_with_batched_tensor_input( 88 | input_audios, ir_transform_no_guarantee 89 | ): 90 | mixed_inputs = ir_transform_no_guarantee(input_audios).samples 91 | assert mixed_inputs.shape == input_audios.shape 92 | 93 | 94 | def test_impulse_response_guaranteed_with_zero_length_samples(ir_transform): 95 | empty_audio = torch.empty(0, 1, 16000) 96 | with pytest.warns(UserWarning, match="An empty samples tensor was passed"): 97 | mixed_inputs = ir_transform(empty_audio).samples 98 | 99 | assert torch.equal(mixed_inputs, empty_audio) 100 | 101 | 102 | def 
test_impulse_response_access_file_paths(ir_path, sample_rate, input_audios): 103 | augment = ApplyImpulseResponse( 104 | ir_path, p=1.0, sample_rate=sample_rate, output_type="dict" 105 | ) 106 | mixed_inputs = augment(samples=input_audios, sample_rate=sample_rate).samples 107 | 108 | assert mixed_inputs.shape == input_audios.shape 109 | 110 | ir_paths = augment.transform_parameters["ir_paths"] 111 | assert len(ir_paths) == input_audios.size(0) 112 | assert str(ir_paths[0]) == os.path.join(ir_path, "impulse_response_0.wav") 113 | -------------------------------------------------------------------------------- /tests/test_polarity_inversion.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from numpy.testing import assert_almost_equal 7 | 8 | from torch_audiomentations import PolarityInversion 9 | 10 | 11 | class TestPolarityInversion(unittest.TestCase): 12 | def test_polarity_inversion(self): 13 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 14 | sample_rate = 16000 15 | 16 | augment = PolarityInversion(p=1.0, output_type="dict") 17 | inverted_samples = augment( 18 | samples=torch.from_numpy(samples), sample_rate=sample_rate 19 | ).samples.numpy() 20 | assert_almost_equal( 21 | inverted_samples, 22 | np.array([[[-1.0, -0.5, 0.25, 0.125, 0.0]]], dtype=np.float32), 23 | ) 24 | self.assertEqual(inverted_samples.dtype, np.float32) 25 | 26 | def test_polarity_inversion_zero_probability(self): 27 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 28 | sample_rate = 16000 29 | 30 | augment = PolarityInversion(p=0.0, output_type="dict") 31 | processed_samples = augment( 32 | samples=torch.from_numpy(samples), sample_rate=sample_rate 33 | ).samples.numpy() 34 | assert_almost_equal( 35 | processed_samples, 36 | np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), 37 | ) 38 | self.assertEqual(processed_samples.dtype, np.float32) 39 | 40 | def test_polarity_inversion_variability_within_batch(self): 41 | samples = np.array([[1.0, 0.5, 0.25, 0.125, 0.0]], dtype=np.float32) 42 | samples_batch = np.stack([samples] * 10000, axis=0) 43 | sample_rate = 16000 44 | 45 | augment = PolarityInversion(p=0.5, output_type="dict") 46 | processed_samples = augment( 47 | samples=torch.from_numpy(samples_batch), sample_rate=sample_rate 48 | ).samples.numpy() 49 | 50 | num_unprocessed_examples = 0 51 | num_processed_examples = 0 52 | for i in range(processed_samples.shape[0]): 53 | if np.sum(processed_samples[i]) > 0: 54 | num_unprocessed_examples += 1 55 | else: 56 | num_processed_examples += 1 57 | 58 | self.assertEqual(num_unprocessed_examples + num_processed_examples, 10000) 59 | 60 | print(num_processed_examples) 61 | self.assertGreater(num_processed_examples, 2000) 62 | self.assertLess(num_processed_examples, 8000) 63 | 64 | def test_polarity_inversion_multichannel(self): 65 | samples = np.array( 66 | [[[1.0, 0.5, -0.25, -0.125, 0.0]], [[1.0, 0.5, -0.25, -0.125, 0.0]]], 67 | dtype=np.float32, 68 | ) 69 | sample_rate = 16000 70 | 71 | augment = PolarityInversion(p=1.0, output_type="dict") 72 | inverted_samples = augment( 73 | samples=torch.from_numpy(samples), sample_rate=sample_rate 74 | ).samples.numpy() 75 | assert_almost_equal( 76 | inverted_samples, 77 | np.array( 78 | [[[-1.0, -0.5, 0.25, 0.125, 0.0]], [[-1.0, -0.5, 0.25, 0.125, 0.0]]], 79 | dtype=np.float32, 80 | ), 81 | ) 82 | self.assertEqual(inverted_samples.dtype, np.float32) 
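    # Note on test_polarity_inversion_variability_within_batch above: with p=0.5 and
    # 10000 independent examples, the number of processed examples is binomial with
    # mean 5000 and standard deviation sqrt(10000 * 0.25) = 50, so the 2000/8000
    # bounds leave an enormous margin against flaky failures.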
83 | 84 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 85 | def test_polarity_inversion_cuda(self): 86 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 87 | sample_rate = 16000 88 | 89 | augment = PolarityInversion(p=1.0, output_type="dict").cuda() 90 | inverted_samples = ( 91 | augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate) 92 | .samples.cpu() 93 | .numpy() 94 | ) 95 | assert_almost_equal( 96 | inverted_samples, 97 | np.array([[[-1.0, -0.5, 0.25, 0.125, 0.0]]], dtype=np.float32), 98 | ) 99 | self.assertEqual(inverted_samples.dtype, np.float32) 100 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/splice_out.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from typing import Optional 4 | from torch import Tensor 5 | from torch.nn.functional import pad 6 | 7 | from ..core.transforms_interface import BaseWaveformTransform 8 | from ..utils.dsp import convert_decibels_to_amplitude_ratio 9 | from ..utils.object_dict import ObjectDict 10 | 11 | 12 | class SpliceOut(BaseWaveformTransform): 13 | 14 | """ 15 | SpliceOut augmentation, as proposed in https://arxiv.org/pdf/2110.00046.pdf 16 | Silence padding is added at the end to retain the audio length. 17 | """ 18 | 19 | supported_modes = {"per_batch", "per_example"} 20 | requires_sample_rate = True 21 | 22 | def __init__( 23 | self, 24 | num_time_intervals=8, 25 | max_width=400, 26 | mode: str = "per_example", 27 | p: float = 0.5, 28 | p_mode: Optional[str] = None, 29 | sample_rate: Optional[int] = None, 30 | target_rate: Optional[int] = None, 31 | output_type: Optional[str] = None, 32 | ): 33 | """ 34 | :param num_time_intervals: number of time intervals to splice out 35 | :param max_width: maximum width of each spliced-out interval, in milliseconds 36 | 37 | """ 38 | 39 | super().__init__( 40 | mode=mode, 41 | p=p, 42 | p_mode=p_mode, 43 | sample_rate=sample_rate, 44 | target_rate=target_rate, 45 | output_type=output_type, 46 | ) 47 | self.num_time_intervals = num_time_intervals 48 | self.max_width = max_width 49 | 50 | def randomize_parameters( 51 | self, 52 | samples: Tensor = None, 53 | sample_rate: Optional[int] = None, 54 | targets: Optional[Tensor] = None, 55 | target_rate: Optional[int] = None, 56 | ): 57 | self.transform_parameters["splice_lengths"] = torch.randint( 58 | low=0, 59 | high=int(sample_rate * self.max_width * 1e-3), 60 | size=(samples.shape[0], self.num_time_intervals), 61 | ) 62 | 63 | def apply_transform( 64 | self, 65 | samples: Tensor = None, 66 | sample_rate: Optional[int] = None, 67 | targets: Optional[Tensor] = None, 68 | target_rate: Optional[int] = None, 69 | ) -> ObjectDict: 70 | spliceout_samples = [] 71 | 72 | for i in range(samples.shape[0]): 73 | random_lengths = self.transform_parameters["splice_lengths"][i] 74 | sample = samples[i][:, :] 75 | for j in range(self.num_time_intervals): 76 | start = torch.randint( 77 | 0, 78 | sample.shape[-1] - random_lengths[j], 79 | size=(1,), 80 | ) 81 | 82 | if random_lengths[j] % 2 != 0: 83 | random_lengths[j] += 1 84 | 85 | hann_window_len = random_lengths[j] 86 | hann_window = torch.hann_window(hann_window_len, device=samples.device) 87 | hann_window_left, hann_window_right = ( 88 | hann_window[: hann_window_len // 2], 89 | hann_window[hann_window_len // 2 :], 90 | ) 91 | 92 | fading_out, fading_in = ( 93 | sample[:, start : start + 
random_lengths[j] // 2], 94 | sample[:, start + random_lengths[j] // 2 : start + random_lengths[j]], 95 | ) 96 | crossfade = hann_window_right * fading_out + hann_window_left * fading_in 97 | sample = torch.cat( 98 | ( 99 | sample[:, :start], 100 | crossfade[:, :], 101 | sample[:, start + random_lengths[j] :], 102 | ), 103 | dim=-1, 104 | ) 105 | 106 | padding = torch.zeros( 107 | (samples[i].shape[0], samples[i].shape[-1] - sample.shape[-1]), 108 | dtype=torch.float32, 109 | device=sample.device, 110 | ) 111 | sample = torch.cat((sample, padding), dim=-1) 112 | spliceout_samples.append(sample.unsqueeze(0)) 113 | 114 | return ObjectDict( 115 | samples=torch.cat(spliceout_samples, dim=0), 116 | sample_rate=sample_rate, 117 | targets=targets, 118 | target_rate=target_rate, 119 | ) 120 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/mix.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch import Tensor 4 | 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.dsp import calculate_rms 7 | from ..utils.io import Audio 8 | from ..utils.object_dict import ObjectDict 9 | 10 | 11 | class Mix(BaseWaveformTransform): 12 | """ 13 | Create a new sample by mixing it with another random sample from the same batch 14 | 15 | Signal-to-noise ratio (where "noise" is the second random sample) is selected 16 | randomly between `min_snr_in_db` and `max_snr_in_db`. 17 | 18 | `mix_target` controls how resulting targets are generated. It can be one of 19 | "original" (targets are those of the original sample) or "union" (targets are the 20 | union of original and overlapping targets) 21 | 22 | """ 23 | 24 | supported_modes = {"per_example", "per_channel"} 25 | 26 | supports_multichannel = True 27 | requires_sample_rate = False 28 | 29 | supports_target = True 30 | requires_target = False 31 | 32 | def __init__( 33 | self, 34 | min_snr_in_db: float = 0.0, 35 | max_snr_in_db: float = 5.0, 36 | mix_target: str = "union", 37 | mode: str = "per_example", 38 | p: float = 0.5, 39 | p_mode: str = None, 40 | sample_rate: int = None, 41 | target_rate: int = None, 42 | output_type: Optional[str] = None, 43 | ): 44 | super().__init__( 45 | mode=mode, 46 | p=p, 47 | p_mode=p_mode, 48 | sample_rate=sample_rate, 49 | target_rate=target_rate, 50 | output_type=output_type, 51 | ) 52 | self.min_snr_in_db = min_snr_in_db 53 | self.max_snr_in_db = max_snr_in_db 54 | if self.min_snr_in_db > self.max_snr_in_db: 55 | raise ValueError("min_snr_in_db must not be greater than max_snr_in_db") 56 | 57 | self.mix_target = mix_target 58 | if mix_target == "original": 59 | self._mix_target = lambda target, background_target, snr: target 60 | 61 | elif mix_target == "union": 62 | self._mix_target = lambda target, background_target, snr: torch.maximum( 63 | target, background_target 64 | ) 65 | 66 | else: 67 | raise ValueError("mix_target must be one of 'original' or 'union'.") 68 | 69 | def randomize_parameters( 70 | self, 71 | samples: Tensor = None, 72 | sample_rate: Optional[int] = None, 73 | targets: Optional[Tensor] = None, 74 | target_rate: Optional[int] = None, 75 | ): 76 | batch_size, num_channels, num_samples = samples.shape 77 | snr_distribution = torch.distributions.Uniform( 78 | low=torch.tensor( 79 | self.min_snr_in_db, 80 | dtype=torch.float32, 81 | device=samples.device, 82 | ), 83 | high=torch.tensor( 84 | self.max_snr_in_db, 85 | 
dtype=torch.float32, 86 | device=samples.device, 87 | ), 88 | validate_args=True, 89 | ) 90 | 91 | # randomize SNRs 92 | self.transform_parameters["snr_in_db"] = snr_distribution.sample( 93 | sample_shape=(batch_size,) 94 | ) 95 | 96 | # randomize index of second sample 97 | self.transform_parameters["sample_idx"] = torch.randint( 98 | 0, 99 | batch_size, 100 | (batch_size,), 101 | device=samples.device, 102 | ) 103 | 104 | def apply_transform( 105 | self, 106 | samples: Tensor = None, 107 | sample_rate: Optional[int] = None, 108 | targets: Optional[Tensor] = None, 109 | target_rate: Optional[int] = None, 110 | ) -> ObjectDict: 111 | snr = self.transform_parameters["snr_in_db"] 112 | idx = self.transform_parameters["sample_idx"] 113 | 114 | background_samples = Audio.rms_normalize(samples[idx]) 115 | background_rms = calculate_rms(samples) / (10 ** (snr.unsqueeze(dim=-1) / 20)) 116 | 117 | mixed_samples = samples + background_rms.unsqueeze(-1) * background_samples 118 | 119 | if targets is None: 120 | mixed_targets = None 121 | 122 | else: 123 | background_targets = targets[idx] 124 | mixed_targets = self._mix_target(targets, background_targets, snr) 125 | 126 | return ObjectDict( 127 | samples=mixed_samples, 128 | sample_rate=sample_rate, 129 | targets=mixed_targets, 130 | target_rate=target_rate, 131 | ) 132 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/low_pass_filter.py: -------------------------------------------------------------------------------- 1 | import julius 2 | import torch 3 | from torch import Tensor 4 | from typing import Optional 5 | 6 | 7 | from ..core.transforms_interface import BaseWaveformTransform 8 | from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies 9 | from ..utils.object_dict import ObjectDict 10 | 11 | 12 | class LowPassFilter(BaseWaveformTransform): 13 | """ 14 | Apply low-pass filtering to the input audio. 
15 | """ 16 | 17 | supported_modes = {"per_batch", "per_example", "per_channel"} 18 | 19 | supports_multichannel = True 20 | requires_sample_rate = True 21 | 22 | supports_target = True 23 | requires_target = False 24 | 25 | def __init__( 26 | self, 27 | min_cutoff_freq: float = 150.0, 28 | max_cutoff_freq: float = 7500.0, 29 | mode: str = "per_example", 30 | p: float = 0.5, 31 | p_mode: str = None, 32 | sample_rate: int = None, 33 | target_rate: int = None, 34 | output_type: Optional[str] = None, 35 | ): 36 | """ 37 | :param min_cutoff_freq: Minimum cutoff frequency in hertz 38 | :param max_cutoff_freq: Maximum cutoff frequency in hertz 39 | :param mode: 40 | :param p: 41 | :param p_mode: 42 | :param sample_rate: 43 | """ 44 | super().__init__( 45 | mode=mode, 46 | p=p, 47 | p_mode=p_mode, 48 | sample_rate=sample_rate, 49 | target_rate=target_rate, 50 | output_type=output_type, 51 | ) 52 | 53 | self.min_cutoff_freq = min_cutoff_freq 54 | self.max_cutoff_freq = max_cutoff_freq 55 | if self.min_cutoff_freq > self.max_cutoff_freq: 56 | raise ValueError("min_cutoff_freq must not be greater than max_cutoff_freq") 57 | 58 | self.cached_lpf = None 59 | 60 | def randomize_parameters( 61 | self, 62 | samples: Tensor = None, 63 | sample_rate: Optional[int] = None, 64 | targets: Optional[Tensor] = None, 65 | target_rate: Optional[int] = None, 66 | ): 67 | """ 68 | :params samples: (batch_size, num_channels, num_samples) 69 | """ 70 | batch_size, _, num_samples = samples.shape 71 | 72 | if self.min_cutoff_freq == self.max_cutoff_freq: 73 | # Speed up computation by caching the LPF instance if the cutoff is constant 74 | cutoff_as_fraction_of_sr = self.min_cutoff_freq / sample_rate 75 | lpf_needs_init = ( 76 | self.cached_lpf is None 77 | or self.cached_lpf.cutoff != cutoff_as_fraction_of_sr 78 | ) 79 | if lpf_needs_init: 80 | self.cached_lpf = julius.LowPassFilter(cutoff=cutoff_as_fraction_of_sr) 81 | self.transform_parameters["cutoff_freq"] = torch.full( 82 | size=(batch_size,), 83 | fill_value=self.min_cutoff_freq, 84 | dtype=torch.float32, 85 | device=samples.device, 86 | ) 87 | else: 88 | # Sample frequencies uniformly in mel space, then convert back to frequency 89 | dist = torch.distributions.Uniform( 90 | low=convert_frequencies_to_mels( 91 | torch.tensor( 92 | self.min_cutoff_freq, dtype=torch.float32, device=samples.device 93 | ) 94 | ), 95 | high=convert_frequencies_to_mels( 96 | torch.tensor( 97 | self.max_cutoff_freq, dtype=torch.float32, device=samples.device 98 | ) 99 | ), 100 | validate_args=True, 101 | ) 102 | self.transform_parameters["cutoff_freq"] = convert_mels_to_frequencies( 103 | dist.sample(sample_shape=(batch_size,)) 104 | ) 105 | self.cached_lpf = None 106 | 107 | def apply_transform( 108 | self, 109 | samples: Tensor = None, 110 | sample_rate: Optional[int] = None, 111 | targets: Optional[Tensor] = None, 112 | target_rate: Optional[int] = None, 113 | ) -> ObjectDict: 114 | batch_size, num_channels, num_samples = samples.shape 115 | 116 | if self.cached_lpf is None: 117 | cutoffs_as_fraction_of_sample_rate = ( 118 | self.transform_parameters["cutoff_freq"] / sample_rate 119 | ) 120 | # TODO: Instead of using a for loop, perform batched compute to speed things up 121 | for i in range(batch_size): 122 | samples[i] = julius.lowpass_filter( 123 | samples[i], cutoffs_as_fraction_of_sample_rate[i].item() 124 | ) 125 | else: 126 | for i in range(batch_size): 127 | samples[i] = self.cached_lpf(samples[i]) 128 | 129 | return ObjectDict( 130 | samples=samples, 131 | 
sample_rate=sample_rate, 132 | targets=targets, 133 | target_rate=target_rate, 134 | ) 135 | -------------------------------------------------------------------------------- /tests/test_padding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | from numpy.testing import assert_almost_equal 5 | import pytest 6 | 7 | from torch_audiomentations.augmentations.padding import Padding 8 | 9 | 10 | class TestPadding(unittest.TestCase): 11 | def test_padding_end(self): 12 | audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32) 13 | augment = Padding( 14 | min_fraction=0.2, 15 | max_fraction=0.5, 16 | pad_section="end", 17 | p=1.0, 18 | output_type="dict", 19 | ) 20 | padded_samples = augment(audio_samples).samples 21 | 22 | self.assertEqual(audio_samples.shape, padded_samples.shape) 23 | assert_almost_equal(padded_samples[..., -6400:].numpy(), np.zeros((2, 2, 6400))) 24 | 25 | def test_padding_start(self): 26 | audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32) 27 | augment = Padding( 28 | min_fraction=0.2, 29 | max_fraction=0.5, 30 | pad_section="start", 31 | p=1.0, 32 | output_type="dict", 33 | ) 34 | padded_samples = augment(audio_samples).samples 35 | 36 | self.assertEqual(audio_samples.shape, padded_samples.shape) 37 | assert_almost_equal(padded_samples[..., :6400].numpy(), np.zeros((2, 2, 6400))) 38 | 39 | def test_padding_zero(self): 40 | audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32) 41 | augment = Padding(min_fraction=0.2, max_fraction=0.5, p=0.0, output_type="dict") 42 | padded_samples = augment(audio_samples).samples 43 | 44 | self.assertEqual(audio_samples.shape, padded_samples.shape) 45 | assert_almost_equal(audio_samples.numpy(), padded_samples.numpy()) 46 | 47 | def test_padding_perexample(self): 48 | audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32) 49 | augment = Padding( 50 | min_fraction=0.2, 51 | max_fraction=0.5, 52 | pad_section="start", 53 | p=0.5, 54 | mode="per_example", 55 | p_mode="per_example", 56 | output_type="dict", 57 | ) 58 | 59 | padded_samples = augment(audio_samples).samples.numpy() 60 | num_unprocessed_examples = 0.0 61 | num_processed_examples = 0.0 62 | for i, sample in enumerate(padded_samples): 63 | if np.allclose(audio_samples[i], sample): 64 | num_unprocessed_examples += 1 65 | else: 66 | num_processed_examples += 1 67 | 68 | self.assertLess(padded_samples.sum(), audio_samples.numpy().sum()) 69 | 70 | def test_padding_perchannel(self): 71 | audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32) 72 | augment = Padding( 73 | min_fraction=0.2, 74 | max_fraction=0.5, 75 | pad_section="start", 76 | p=0.5, 77 | mode="per_channel", 78 | p_mode="per_channel", 79 | output_type="dict", 80 | ) 81 | 82 | padded_samples = augment(audio_samples).samples.numpy() 83 | num_unprocessed_examples = 0.0 84 | num_processed_examples = 0.0 85 | for i, sample in enumerate(padded_samples): 86 | if np.allclose(audio_samples[i], sample): 87 | num_unprocessed_examples += 1 88 | else: 89 | num_processed_examples += 1 90 | 91 | self.assertLess(padded_samples.sum(), audio_samples.numpy().sum()) 92 | 93 | def test_padding_variability_perexample(self): 94 | audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32) 95 | augment = Padding( 96 | min_fraction=0.2, 97 | max_fraction=0.5, 98 | pad_section="start", 99 | p=0.5, 100 | mode="per_example", 101 | p_mode="per_example", 102 | output_type="dict", 
103 | ) 104 | 105 | padded_samples = augment(audio_samples).samples.numpy() 106 | num_unprocessed_examples = 0.0 107 | num_processed_examples = 0.0 108 | for i, sample in enumerate(padded_samples): 109 | if np.allclose(audio_samples[i], sample): 110 | num_unprocessed_examples += 1 111 | else: 112 | num_processed_examples += 1 113 | 114 | self.assertEqual(num_processed_examples + num_unprocessed_examples, 10) 115 | self.assertGreater(num_processed_examples, 2) 116 | self.assertLess(num_unprocessed_examples, 8) 117 | 118 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 119 | def test_padding_cuda(self): 120 | audio_samples = torch.rand( 121 | size=(2, 2, 32000), dtype=torch.float32, device=torch.device("cuda") 122 | ) 123 | augment = Padding(min_fraction=0.2, max_fraction=0.5, p=1.0, output_type="dict") 124 | padded_samples = augment(audio_samples).samples 125 | 126 | self.assertEqual(audio_samples.shape, padded_samples.shape) 127 | assert_almost_equal( 128 | padded_samples[..., -6400:].cpu().numpy(), np.zeros((2, 2, 6400)) 129 | ) 130 | -------------------------------------------------------------------------------- /scripts/measure_convolve_execution_time.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from cpuinfo import get_cpu_info 4 | from scipy.signal import convolve as scipy_convolve 5 | from tqdm import tqdm 6 | 7 | from scripts.demo import TEST_FIXTURES_DIR, timer 8 | from scripts.plot import show_horizontal_bar_chart 9 | from torch_audiomentations.utils.convolution import convolve as torch_convolve 10 | from torch_audiomentations.utils.io import Audio 11 | 12 | if __name__ == "__main__": 13 | file_path = TEST_FIXTURES_DIR / "acoustic_guitar_0.wav" 14 | sample_rate = 48000 15 | audio = Audio(sample_rate, mono=True) 16 | samples = audio(file_path).numpy() 17 | ir_samples = audio(TEST_FIXTURES_DIR / "ir" / "impulse_response_0.wav").numpy() 18 | 19 | is_cuda_available = torch.cuda.is_available() 20 | print("Is torch CUDA available:", is_cuda_available) 21 | 22 | num_examples = 32 23 | 24 | execution_times = { 25 | # tuples of (description, batch size) 26 | ("scipy direct", 1): [], 27 | ("scipy FFT", 1): [], 28 | ("torch FFT CPU", 1): [], 29 | ("torch FFT CPU", num_examples): [], 30 | ("torch FFT CUDA", 1): [], 31 | ("torch FFT CUDA", num_examples): [], 32 | } 33 | 34 | for i in tqdm(range(num_examples), desc="scipy, method='direct', batch size=1"): 35 | with timer("scipy direct") as t: 36 | expected_output = scipy_convolve(samples, ir_samples, method="direct") 37 | execution_times[(t.description, 1)].append(t.execution_time) 38 | 39 | for i in tqdm(range(num_examples), desc="scipy, method='fft', batch size=1"): 40 | with timer("scipy FFT") as t: 41 | expected_output = scipy_convolve(samples, ir_samples, method="fft") 42 | execution_times[(t.description, 1)].append(t.execution_time) 43 | 44 | pytorch_samples_cpu = torch.from_numpy(samples) 45 | pytorch_ir_samples_cpu = torch.from_numpy(ir_samples) 46 | 47 | for i in tqdm(range(num_examples), desc="torch FFT CPU, batch size=1"): 48 | with timer("torch FFT CPU") as t: 49 | _ = torch_convolve(pytorch_samples_cpu, pytorch_ir_samples_cpu).numpy() 50 | execution_times[(t.description, 1)].append(t.execution_time) 51 | 52 | pytorch_samples_batch_cpu = torch.stack([pytorch_samples_cpu] * num_examples) 53 | 54 | for i in tqdm(range(5), desc="torch FFT CPU, batch size={}".format(num_examples)): 55 | with timer("torch FFT CPU") as t: 56 | _ 
= torch_convolve(pytorch_samples_batch_cpu, pytorch_ir_samples_cpu).numpy() 57 | execution_times[(t.description, num_examples)].append(t.execution_time) 58 | 59 | if is_cuda_available: 60 | pytorch_samples_cuda = torch.from_numpy(samples).cuda() 61 | pytorch_ir_samples_cuda = torch.from_numpy(ir_samples).cuda() 62 | 63 | for i in tqdm(range(num_examples), desc="torch FFT CUDA, batch size=1"): 64 | with timer("torch FFT CUDA") as t: 65 | _ = ( 66 | torch_convolve(pytorch_samples_cuda, pytorch_ir_samples_cuda) 67 | .cpu() 68 | .numpy() 69 | ) 70 | execution_times[(t.description, 1)].append(t.execution_time) 71 | 72 | pytorch_samples_batch_cuda = pytorch_samples_batch_cpu.cuda() 73 | 74 | for i in tqdm( 75 | range(5), desc="torch FFT CUDA, batch size={}".format(num_examples) 76 | ): 77 | with timer("torch FFT CUDA") as t: 78 | _ = ( 79 | torch_convolve(pytorch_samples_batch_cuda, pytorch_ir_samples_cuda) 80 | .cpu() 81 | .numpy() 82 | ) 83 | execution_times[(t.description, num_examples)].append(t.execution_time) 84 | 85 | normalized_execution_times = {} 86 | for description, batch_size in execution_times: 87 | times = execution_times[(description, batch_size)] 88 | if len(times) == 0: 89 | continue 90 | times[0] = float(np.median(times)) 91 | # We consider the first execution to be a warm-up, as it may take many magnitudes longer 92 | # Therefore, we simply replace its value with the median 93 | batch_execution_time = num_examples * sum(times) / (batch_size * len(times)) 94 | print( 95 | "{:<20} batch size = {:<4} {:.3f} s".format( 96 | description, batch_size, batch_execution_time 97 | ) 98 | ) 99 | normalized_execution_times[ 100 | "{}, batch size={}".format(description, batch_size) 101 | ] = batch_execution_time 102 | 103 | cpu_info = get_cpu_info() 104 | cpu_info_string = "CPU: {}".format(cpu_info["brand_raw"]) 105 | print(cpu_info_string) 106 | 107 | plot_title = "Convolving an IR with {} sounds\n{}".format( 108 | num_examples, cpu_info_string 109 | ) 110 | 111 | if is_cuda_available: 112 | cuda_device_name = torch.cuda.get_device_name() 113 | cuda_device_string = "CUDA device: {}".format(cuda_device_name) 114 | print(cuda_device_string) 115 | plot_title += "\n{}".format(cuda_device_string) 116 | 117 | show_horizontal_bar_chart(normalized_execution_times, plot_title) 118 | -------------------------------------------------------------------------------- /tests/test_colored_noise.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pytest 3 | import torch 4 | 5 | from torch_audiomentations import AddColoredNoise 6 | from torch_audiomentations.utils.io import Audio 7 | from .utils import TEST_FIXTURES_DIR 8 | 9 | 10 | @pytest.fixture 11 | def setup_audio(): 12 | sample_rate = 16000 13 | audio = Audio(sample_rate=sample_rate) 14 | batch_size = 16 15 | empty_input_audio = torch.empty(0, 1, 16000) 16 | 17 | input_audio = audio(TEST_FIXTURES_DIR / "acoustic_guitar_0.wav").unsqueeze(0) 18 | input_audios = torch.cat([input_audio] * batch_size, dim=0) 19 | 20 | cl_noise_transform_guaranteed = AddColoredNoise(20, p=1.0, output_type="dict") 21 | cl_noise_transform_no_guarantee = AddColoredNoise(20, p=0.0, output_type="dict") 22 | 23 | return { 24 | "sample_rate": sample_rate, 25 | "empty_input_audio": empty_input_audio, 26 | "input_audio": input_audio, 27 | "input_audios": input_audios, 28 | "cl_noise_transform_guaranteed": cl_noise_transform_guaranteed, 29 | "cl_noise_transform_no_guarantee": cl_noise_transform_no_guarantee, 30 | } 31 | 
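# Illustrative usage sketch of AddColoredNoise, mirroring the fixture above and the
# tests below; the SNR range, batch shape and 16 kHz rate are arbitrary example values
# (output_type="dict" wraps the result in an ObjectDict whose .samples holds the audio):
#
#     import torch
#     from torch_audiomentations import AddColoredNoise
#
#     transform = AddColoredNoise(
#         min_snr_in_db=10.0, max_snr_in_db=20.0, p=1.0, output_type="dict"
#     )
#     batch = torch.randn(4, 1, 16000)  # (batch_size, num_channels, num_samples)
#     noisy = transform(batch, sample_rate=16000).samples
#     assert noisy.shape == batch.shape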
32 | 33 | def test_colored_noise_no_guarantee_with_single_tensor(setup_audio): 34 | input_audio = setup_audio["input_audio"] 35 | transform = setup_audio["cl_noise_transform_no_guarantee"] 36 | sample_rate = setup_audio["sample_rate"] 37 | 38 | mixed_input = transform(input_audio, sample_rate).samples 39 | assert torch.equal(mixed_input, input_audio) 40 | assert mixed_input.size(0) == input_audio.size(0) 41 | 42 | 43 | def test_background_noise_no_guarantee_with_empty_tensor(setup_audio): 44 | empty_input_audio = setup_audio["empty_input_audio"] 45 | transform = setup_audio["cl_noise_transform_no_guarantee"] 46 | sample_rate = setup_audio["sample_rate"] 47 | 48 | with pytest.warns(UserWarning, match="An empty samples tensor was passed"): 49 | mixed_input = transform(empty_input_audio, sample_rate).samples 50 | 51 | assert torch.equal(mixed_input, empty_input_audio) 52 | assert mixed_input.size(0) == empty_input_audio.size(0) 53 | 54 | 55 | def test_colored_noise_guaranteed_with_zero_length_samples(setup_audio): 56 | empty_input_audio = setup_audio["empty_input_audio"] 57 | transform = setup_audio["cl_noise_transform_guaranteed"] 58 | sample_rate = setup_audio["sample_rate"] 59 | 60 | with pytest.warns(UserWarning, match="An empty samples tensor was passed"): 61 | mixed_input = transform(empty_input_audio, sample_rate).samples 62 | 63 | assert torch.equal(mixed_input, empty_input_audio) 64 | assert mixed_input.size(0) == empty_input_audio.size(0) 65 | 66 | 67 | def test_colored_noise_guaranteed_with_single_tensor(setup_audio): 68 | input_audio = setup_audio["input_audio"] 69 | transform = setup_audio["cl_noise_transform_guaranteed"] 70 | sample_rate = setup_audio["sample_rate"] 71 | 72 | mixed_input = transform(input_audio, sample_rate).samples 73 | assert not torch.equal(mixed_input, input_audio) 74 | assert mixed_input.size(0) == input_audio.size(0) 75 | assert mixed_input.size(1) == input_audio.size(1) 76 | 77 | 78 | def test_colored_noise_guaranteed_with_batched_tensor(setup_audio): 79 | random.seed(42) 80 | input_audios = setup_audio["input_audios"] 81 | transform = setup_audio["cl_noise_transform_guaranteed"] 82 | sample_rate = setup_audio["sample_rate"] 83 | 84 | mixed_inputs = transform(input_audios, sample_rate).samples 85 | assert not torch.equal(mixed_inputs, input_audios) 86 | assert mixed_inputs.size(0) == input_audios.size(0) 87 | assert mixed_inputs.size(1) == input_audios.size(1) 88 | 89 | 90 | def test_same_min_max_f_decay(setup_audio): 91 | random.seed(42) 92 | input_audios = setup_audio["input_audios"] 93 | sample_rate = setup_audio["sample_rate"] 94 | 95 | transform = AddColoredNoise( 96 | 20, min_f_decay=1.0, max_f_decay=1.0, p=1.0, output_type="dict" 97 | ) 98 | outputs = transform(input_audios, sample_rate).samples 99 | assert outputs.size(0) == input_audios.size(0) 100 | assert outputs.size(1) == input_audios.size(1) 101 | 102 | 103 | def test_invalid_params(): 104 | with pytest.raises(ValueError): 105 | AddColoredNoise(min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict") 106 | with pytest.raises(ValueError): 107 | AddColoredNoise(min_f_decay=2, max_f_decay=1, p=1.0, output_type="dict") 108 | 109 | 110 | def test_various_lengths_and_sample_rates(): 111 | random.seed(42) 112 | transform = AddColoredNoise( 113 | min_snr_in_db=10, max_snr_in_db=12, p=1.0, output_type="dict" 114 | ) 115 | 116 | for _ in range(100): 117 | length = random.randint(1000, 100_000) 118 | sample_rate = random.randint(1000, 100_000) 119 | input_audio = torch.randn(1, 1, length, 
dtype=torch.float32) 120 | output_audio = transform(input_audio, sample_rate=sample_rate).samples 121 | 122 | assert output_audio.shape == input_audio.shape 123 | assert output_audio.dtype == input_audio.dtype 124 | 125 | input_audio = torch.randn(1, 1, 16001, dtype=torch.float32) 126 | output_audio = transform(input_audio, sample_rate=16001).samples 127 | assert output_audio.shape == input_audio.shape 128 | assert not torch.equal(output_audio, input_audio) 129 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/impulse_response.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Union, List, Optional 4 | from torch import Tensor 5 | 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | from ..core.transforms_interface import BaseWaveformTransform, EmptyPathException 10 | from ..utils.convolution import convolve 11 | from ..utils.file import find_audio_files_in_paths 12 | from ..utils.io import Audio 13 | from ..utils.object_dict import ObjectDict 14 | 15 | 16 | class ApplyImpulseResponse(BaseWaveformTransform): 17 | """ 18 | Convolve the given audio with impulse responses. 19 | """ 20 | 21 | supported_modes = {"per_batch", "per_example", "per_channel"} 22 | 23 | # Note: This transform has only partial support for multichannel audio. IRs that are not 24 | # mono get mixed down to mono before they are convolved with all channels in the input. 25 | supports_multichannel = True 26 | requires_sample_rate = True 27 | 28 | supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) 29 | requires_target = False 30 | 31 | def __init__( 32 | self, 33 | ir_paths: Union[List[Path], List[str], Path, str], 34 | convolve_mode: str = "full", 35 | compensate_for_propagation_delay: bool = False, 36 | mode: str = "per_example", 37 | p: float = 0.5, 38 | p_mode: str = None, 39 | sample_rate: int = None, 40 | target_rate: int = None, 41 | output_type: Optional[str] = None, 42 | ): 43 | """ 44 | :param ir_paths: Either a path to a folder with audio files or a list of paths to audio files. 45 | :param convolve_mode: 46 | :param compensate_for_propagation_delay: Convolving audio with a RIR normally 47 | introduces a bit of delay, especially when the peak absolute amplitude in the 48 | RIR is not in the very beginning. When compensate_for_propagation_delay is 49 | set to True, the returned slices of audio will be offset to compensate for 50 | this delay. 
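            (For example — hypothetical numbers — if the peak of an IR lies 240 samples
            after its start, the convolved audio is sliced from sample 240 onward, so the
            output stays time-aligned with the dry input.)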
51 | :param mode: 52 | :param p: 53 | :param p_mode: 54 | :param sample_rate: 55 | :param target_rate: 56 | """ 57 | 58 | super().__init__( 59 | mode=mode, 60 | p=p, 61 | p_mode=p_mode, 62 | sample_rate=sample_rate, 63 | target_rate=target_rate, 64 | output_type=output_type, 65 | ) 66 | 67 | # TODO: check that one can read audio files 68 | self.ir_paths = find_audio_files_in_paths(ir_paths) 69 | 70 | if sample_rate is not None: 71 | self.audio = Audio(sample_rate=sample_rate, mono=True) 72 | 73 | if len(self.ir_paths) == 0: 74 | raise EmptyPathException("There are no supported audio files found.") 75 | 76 | self.convolve_mode = convolve_mode 77 | self.compensate_for_propagation_delay = compensate_for_propagation_delay 78 | 79 | def randomize_parameters( 80 | self, 81 | samples: Tensor = None, 82 | sample_rate: Optional[int] = None, 83 | targets: Optional[Tensor] = None, 84 | target_rate: Optional[int] = None, 85 | ): 86 | batch_size, _, _ = samples.shape 87 | 88 | audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) 89 | 90 | random_ir_paths = random.choices(self.ir_paths, k=batch_size) 91 | 92 | self.transform_parameters["ir"] = pad_sequence( 93 | [audio(ir_path).transpose(0, 1) for ir_path in random_ir_paths], 94 | batch_first=True, 95 | padding_value=0.0, 96 | ).transpose(1, 2) 97 | 98 | self.transform_parameters["ir_paths"] = random_ir_paths 99 | 100 | def apply_transform( 101 | self, 102 | samples: Tensor = None, 103 | sample_rate: Optional[int] = None, 104 | targets: Optional[Tensor] = None, 105 | target_rate: Optional[int] = None, 106 | ) -> ObjectDict: 107 | batch_size, num_channels, num_samples = samples.shape 108 | 109 | # (batch_size, 1, max_ir_length) 110 | ir = self.transform_parameters["ir"].to(samples.device) 111 | 112 | convolved_samples = convolve( 113 | samples, ir.expand(-1, num_channels, -1), mode=self.convolve_mode 114 | ) 115 | 116 | if self.compensate_for_propagation_delay: 117 | propagation_delays = ir.abs().argmax(dim=2, keepdim=False)[:, 0] 118 | convolved_samples = torch.stack( 119 | [ 120 | convolved_sample[ 121 | :, propagation_delay : propagation_delay + num_samples 122 | ] 123 | for convolved_sample, propagation_delay in zip( 124 | convolved_samples, propagation_delays 125 | ) 126 | ], 127 | dim=0, 128 | ) 129 | 130 | return ObjectDict( 131 | samples=convolved_samples, 132 | sample_rate=sample_rate, 133 | targets=targets, # FIXME compensate targets as well? 134 | target_rate=target_rate, 135 | ) 136 | 137 | else: 138 | return ObjectDict( 139 | samples=convolved_samples[..., :num_samples], 140 | sample_rate=sample_rate, 141 | targets=targets, # FIXME crop targets as well? 142 | target_rate=target_rate, 143 | ) 144 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/pitch_shift.py: -------------------------------------------------------------------------------- 1 | from random import choices 2 | 3 | from torch import Tensor 4 | from typing import Optional 5 | from torch_pitch_shift import pitch_shift, get_fast_shifts, semitones_to_ratio 6 | 7 | from ..core.transforms_interface import BaseWaveformTransform 8 | from ..utils.object_dict import ObjectDict 9 | 10 | 11 | class PitchShift(BaseWaveformTransform): 12 | """ 13 | Pitch-shift sounds up or down without changing the tempo. 
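    Example (illustrative sketch; the 16 kHz rate and output_type="dict" are example
    choices — PitchShift requires sample_rate at construction time):

        >>> import torch
        >>> from torch_audiomentations import PitchShift
        >>> transform = PitchShift(sample_rate=16000, p=1.0, output_type="dict")
        >>> audio = torch.randn(2, 1, 16000)  # (batch_size, num_channels, num_samples)
        >>> shifted = transform(audio, sample_rate=16000).samples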
14 | """ 15 | 16 | supported_modes = {"per_batch", "per_example", "per_channel"} 17 | 18 | supports_multichannel = True 19 | requires_sample_rate = True 20 | 21 | supports_target = True 22 | requires_target = False 23 | 24 | def __init__( 25 | self, 26 | min_transpose_semitones: float = -4.0, 27 | max_transpose_semitones: float = 4.0, 28 | mode: str = "per_example", 29 | p: float = 0.5, 30 | p_mode: str = None, 31 | sample_rate: int = None, 32 | target_rate: int = None, 33 | output_type: Optional[str] = None, 34 | ): 35 | """ 36 | :param sample_rate: 37 | :param min_transpose_semitones: Minimum pitch shift transposition in semitones (default -4.0) 38 | :param max_transpose_semitones: Maximum pitch shift transposition in semitones (default +4.0) 39 | :param mode: ``per_example``, ``per_channel``, or ``per_batch``. Default ``per_example``. 40 | :param p: 41 | :param p_mode: 42 | :param target_rate: 43 | """ 44 | super().__init__( 45 | mode=mode, 46 | p=p, 47 | p_mode=p_mode, 48 | sample_rate=sample_rate, 49 | target_rate=target_rate, 50 | output_type=output_type, 51 | ) 52 | 53 | if min_transpose_semitones > max_transpose_semitones: 54 | raise ValueError("max_transpose_semitones must be > min_transpose_semitones") 55 | if not sample_rate: 56 | raise ValueError("sample_rate is invalid.") 57 | self._sample_rate = sample_rate 58 | self._fast_shifts = get_fast_shifts( 59 | sample_rate, 60 | lambda x: x >= semitones_to_ratio(min_transpose_semitones) 61 | and x <= semitones_to_ratio(max_transpose_semitones) 62 | and x != 1, 63 | ) 64 | if not len(self._fast_shifts): 65 | raise ValueError( 66 | "No fast pitch-shift ratios could be computed for the given sample rate and transpose range." 67 | ) 68 | self._mode = mode 69 | 70 | def randomize_parameters( 71 | self, 72 | samples: Tensor = None, 73 | sample_rate: Optional[int] = None, 74 | targets: Optional[Tensor] = None, 75 | target_rate: Optional[int] = None, 76 | ): 77 | """ 78 | :param samples: (batch_size, num_channels, num_samples) 79 | :param sample_rate: 80 | """ 81 | batch_size, num_channels, num_samples = samples.shape 82 | 83 | if self._mode == "per_example": 84 | self.transform_parameters["transpositions"] = choices( 85 | self._fast_shifts, k=batch_size 86 | ) 87 | elif self._mode == "per_channel": 88 | self.transform_parameters["transpositions"] = list( 89 | zip( 90 | *[ 91 | choices(self._fast_shifts, k=batch_size) 92 | for i in range(num_channels) 93 | ] 94 | ) 95 | ) 96 | elif self._mode == "per_batch": 97 | self.transform_parameters["transpositions"] = choices(self._fast_shifts, k=1) 98 | 99 | def apply_transform( 100 | self, 101 | samples: Tensor = None, 102 | sample_rate: Optional[int] = None, 103 | targets: Optional[Tensor] = None, 104 | target_rate: Optional[int] = None, 105 | ) -> ObjectDict: 106 | """ 107 | :param samples: (batch_size, num_channels, num_samples) 108 | :param sample_rate: 109 | """ 110 | batch_size, num_channels, num_samples = samples.shape 111 | 112 | if sample_rate is not None and sample_rate != self._sample_rate: 113 | raise ValueError( 114 | "sample_rate must match the value of sample_rate " 115 | + "passed into the PitchShift constructor" 116 | ) 117 | sample_rate = self.sample_rate 118 | 119 | if self._mode == "per_example": 120 | for i in range(batch_size): 121 | samples[i, ...] 
= pitch_shift( 122 | samples[i][None], 123 | self.transform_parameters["transpositions"][i], 124 | sample_rate, 125 | )[0] 126 | 127 | elif self._mode == "per_channel": 128 | for i in range(batch_size): 129 | for j in range(num_channels): 130 | samples[i, j, ...] = pitch_shift( 131 | samples[i][j][None][None], 132 | self.transform_parameters["transpositions"][i][j], 133 | sample_rate, 134 | )[0][0] 135 | 136 | elif self._mode == "per_batch": 137 | samples = pitch_shift( 138 | samples, self.transform_parameters["transpositions"][0], sample_rate 139 | ) 140 | 141 | return ObjectDict( 142 | samples=samples, 143 | sample_rate=sample_rate, 144 | targets=targets, 145 | target_rate=target_rate, 146 | ) 147 | -------------------------------------------------------------------------------- /torch_audiomentations/utils/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Text, Optional, Union 3 | from importlib import import_module 4 | 5 | import torch_audiomentations 6 | from torch_audiomentations import Compose 7 | from torch_audiomentations.core.transforms_interface import BaseWaveformTransform 8 | 9 | # TODO: define this elsewhere? 10 | # TODO: update when a new type of transform is added (e.g. BaseSpectrogramTransform? OneOf? SomeOf?) 11 | # https://github.com/asteroid-team/torch-audiomentations/issues/26 12 | Transform = Union[BaseWaveformTransform, Compose] 13 | 14 | 15 | def get_class_by_name( 16 | class_name: str, default_module_name: str = "torch_audiomentations" 17 | ) -> type: 18 | """Load class by its name 19 | 20 | Parameters 21 | ---------- 22 | class_name : `str` 23 | default_module_name : `str`, optional 24 | When provided and `class_name` does not contain the absolute path. 25 | Defaults to "torch_audiomentations". 26 | 27 | Returns 28 | ------- 29 | Klass : `type` 30 | Class. 31 | 32 | Example 33 | ------- 34 | >>> YourAugmentation = get_class_by_name('your_package.your_module.YourAugmentation') 35 | >>> YourAugmentation = get_class_by_name('YourAugmentation', default_module_name='your_package.your_module') 36 | 37 | >>> from torch_audiomentations import Gain 38 | >>> assert Gain == get_class_by_name('Gain') 39 | """ 40 | tokens = class_name.split(".") 41 | 42 | if len(tokens) == 1: 43 | if default_module_name is None: 44 | msg = ( 45 | f'Could not infer module name from class name "{class_name}".' 46 | f"Please provide default module name." 47 | ) 48 | raise ValueError(msg) 49 | module_name = default_module_name 50 | else: 51 | module_name = ".".join(tokens[:-1]) 52 | class_name = tokens[-1] 53 | 54 | return getattr(import_module(module_name), class_name) 55 | 56 | 57 | def from_dict(config: Dict[Text, Union[Text, Dict[Text, Any]]]) -> Transform: 58 | """Instantiate a transform from a configuration dictionary. 59 | 60 | `from_dict` can be used to instantiate a transform from its class name. 61 | For instance, these two pieces of code are equivalent: 62 | 63 | >>> from torch_audiomentations import Gain 64 | >>> transform = Gain(min_gain_in_db=-12.0) 65 | 66 | >>> transform = from_dict({'transform': 'Gain', 67 | ... 'params': {'min_gain_in_db': -12.0}}) 68 | 69 | Transforms composition is also supported: 70 | 71 | >>> compose = from_dict( 72 | ... {'transform': 'Compose', 73 | ... 'params': {'transforms': [{'transform': 'Gain', 74 | ... 'params': {'min_gain_in_db': -12.0, 75 | ... 'mode': 'per_channel'}}, 76 | ... {'transform': 'PolarityInversion'}], 77 | ... 
'shuffle': True}}) 78 | 79 | :param config: configuration - a configuration dictionary 80 | :returns: A transform. 81 | :rtype Transform: 82 | """ 83 | 84 | try: 85 | TransformClassName: Text = config["transform"] 86 | except KeyError: 87 | raise ValueError( 88 | "A (currently missing) 'transform' key should be used to define the transform type." 89 | ) 90 | 91 | try: 92 | TransformClass = get_class_by_name(TransformClassName) 93 | except AttributeError: 94 | raise ValueError( 95 | f"torch_audiomentations does not implement {TransformClassName} transform." 96 | ) 97 | 98 | transform_params: Dict = config.get("params", dict()) 99 | if not isinstance(transform_params, dict): 100 | raise ValueError( 101 | "Transform parameters must be provided as {'param_name': param_value} dictionary." 102 | ) 103 | 104 | if TransformClassName in ["Compose", "OneOf", "SomeOf"]: 105 | transform_params["transforms"] = [ 106 | from_dict(sub_transform_config) 107 | for sub_transform_config in transform_params["transforms"] 108 | ] 109 | 110 | return TransformClass(**transform_params) 111 | 112 | 113 | def from_yaml(file_yml: Union[Path, Text]) -> Transform: 114 | """Instantiate a transform from a YAML configuration file. 115 | 116 | `from_yaml` can be used to instantiate a transform from a YAML file. 117 | For instance, these two pieces of code are equivalent: 118 | 119 | >>> from torch_audiomentations import Gain 120 | >>> transform = Gain(min_gain_in_db=-12.0, mode="per_channel") 121 | 122 | >>> transform = from_yaml("config.yml") 123 | 124 | where the content of `config.yml` is something like: 125 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 126 | # config.yml 127 | transform: Gain 128 | params: 129 | min_gain_in_db: -12.0 130 | mode: per_channel 131 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 132 | 133 | Transforms composition is also supported: 134 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 135 | # config.yml 136 | transform: Compose 137 | params: 138 | shuffle: True 139 | transforms: 140 | - transform: Gain 141 | params: 142 | min_gain_in_db: -12.0 143 | mode: per_channel 144 | - transform: PolarityInversion 145 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 146 | 147 | :param file_yml: configuration file - a path to a YAML file with the following structure: 148 | :returns: A transform. 149 | :rtype Transform: 150 | """ 151 | 152 | try: 153 | import yaml 154 | except ImportError as e: 155 | raise ImportError( 156 | "PyYAML package is needed by `from_yaml`: please install it first." 
157 | ) 158 | 159 | with open(file_yml, "r") as f: 160 | config = yaml.load(f, Loader=yaml.SafeLoader) 161 | 162 | return from_dict(config) 163 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/colored_noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | from math import ceil 5 | 6 | from torch_audiomentations.utils.fft import rfft, irfft 7 | from ..core.transforms_interface import BaseWaveformTransform 8 | from ..utils.dsp import calculate_rms 9 | from ..utils.io import Audio 10 | from ..utils.object_dict import ObjectDict 11 | 12 | 13 | def _gen_noise(f_decay, num_samples, sample_rate, device): 14 | """ 15 | Generate colored noise with f_decay decay using torch.fft 16 | """ 17 | noise = torch.normal( 18 | torch.tensor(0.0, device=device), 19 | torch.tensor(1.0, device=device), 20 | (sample_rate,), 21 | device=device, 22 | ) 23 | spec = rfft(noise) 24 | mask = 1 / ( 25 | torch.linspace(1, (sample_rate / 2) ** 0.5, spec.shape[0], device=device) 26 | ** f_decay 27 | ) 28 | spec *= mask 29 | noise = Audio.rms_normalize(irfft(spec).unsqueeze(0)).squeeze() 30 | noise = torch.cat([noise] * int(ceil(num_samples / noise.shape[0]))) 31 | return noise[:num_samples] 32 | 33 | 34 | class AddColoredNoise(BaseWaveformTransform): 35 | """ 36 | Add colored noises to the input audio. 37 | """ 38 | 39 | supported_modes = {"per_batch", "per_example", "per_channel"} 40 | 41 | supports_multichannel = True 42 | requires_sample_rate = True 43 | 44 | supports_target = True 45 | requires_target = False 46 | 47 | def __init__( 48 | self, 49 | min_snr_in_db: float = 3.0, 50 | max_snr_in_db: float = 30.0, 51 | min_f_decay: float = -2.0, 52 | max_f_decay: float = 2.0, 53 | mode: str = "per_example", 54 | p: float = 0.5, 55 | p_mode: str = None, 56 | sample_rate: int = None, 57 | target_rate: int = None, 58 | output_type: Optional[str] = None, 59 | ): 60 | """ 61 | :param min_snr_in_db: minimum SNR in dB. 62 | :param max_snr_in_db: maximum SNR in dB. 63 | :param min_f_decay: 64 | defines the minimum frequency power decay (1/f**f_decay). 65 | Typical values are "white noise" (f_decay=0), "pink noise" (f_decay=1), 66 | "brown noise" (f_decay=2), "blue noise (f_decay=-1)" and "violet noise" 67 | (f_decay=-2) 68 | :param max_f_decay: 69 | defines the maximum power decay (1/f**f_decay) for non-white noises. 
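            (For instance, f_decay=1 — pink noise — corresponds to the noise power
            falling off by roughly 3 dB per octave, while f_decay=0 leaves the
            spectrum flat, i.e. white noise.)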
70 | :param mode: 71 | :param p: 72 | :param p_mode: 73 | :param sample_rate: 74 | :param target_rate: 75 | """ 76 | 77 | super().__init__( 78 | mode=mode, 79 | p=p, 80 | p_mode=p_mode, 81 | sample_rate=sample_rate, 82 | target_rate=target_rate, 83 | output_type=output_type, 84 | ) 85 | 86 | self.min_snr_in_db = min_snr_in_db 87 | self.max_snr_in_db = max_snr_in_db 88 | if self.min_snr_in_db > self.max_snr_in_db: 89 | raise ValueError("min_snr_in_db must not be greater than max_snr_in_db") 90 | 91 | self.min_f_decay = min_f_decay 92 | self.max_f_decay = max_f_decay 93 | if self.min_f_decay > self.max_f_decay: 94 | raise ValueError("min_f_decay must not be greater than max_f_decay") 95 | 96 | def randomize_parameters( 97 | self, 98 | samples: Tensor = None, 99 | sample_rate: Optional[int] = None, 100 | targets: Optional[Tensor] = None, 101 | target_rate: Optional[int] = None, 102 | ): 103 | """ 104 | :params selected_samples: (batch_size, num_channels, num_samples) 105 | """ 106 | batch_size, _, num_samples = samples.shape 107 | 108 | # (batch_size, ) SNRs 109 | for param, mini, maxi in [ 110 | ("snr_in_db", self.min_snr_in_db, self.max_snr_in_db), 111 | ("f_decay", self.min_f_decay, self.max_f_decay), 112 | ]: 113 | if mini == maxi: 114 | self.transform_parameters[param] = torch.full( 115 | size=(batch_size,), 116 | fill_value=mini, 117 | dtype=torch.float32, 118 | device=samples.device, 119 | ) 120 | else: 121 | dist = torch.distributions.Uniform( 122 | low=torch.tensor(mini, dtype=torch.float32, device=samples.device), 123 | high=torch.tensor(maxi, dtype=torch.float32, device=samples.device), 124 | validate_args=True, 125 | ) 126 | self.transform_parameters[param] = dist.sample(sample_shape=(batch_size,)) 127 | 128 | def apply_transform( 129 | self, 130 | samples: Tensor = None, 131 | sample_rate: Optional[int] = None, 132 | targets: Optional[Tensor] = None, 133 | target_rate: Optional[int] = None, 134 | ) -> ObjectDict: 135 | batch_size, num_channels, num_samples = samples.shape 136 | 137 | # (batch_size, num_samples) 138 | noise = torch.stack( 139 | [ 140 | _gen_noise( 141 | self.transform_parameters["f_decay"][i], 142 | num_samples, 143 | sample_rate, 144 | samples.device, 145 | ) 146 | for i in range(batch_size) 147 | ] 148 | ) 149 | 150 | # (batch_size, num_channels) 151 | noise_rms = calculate_rms(samples) / ( 152 | 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) 153 | ) 154 | 155 | return ObjectDict( 156 | samples=samples 157 | + noise_rms.unsqueeze(-1) 158 | * noise.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), 159 | sample_rate=sample_rate, 160 | targets=targets, 161 | target_rate=target_rate, 162 | ) 163 | -------------------------------------------------------------------------------- /scripts/perf_benchmark.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import time 10 | import torch 11 | from tqdm import tqdm 12 | 13 | from torch_audiomentations import ( 14 | PolarityInversion, 15 | Gain, 16 | PeakNormalization, 17 | Shift, 18 | ShuffleChannels, 19 | LowPassFilter, 20 | HighPassFilter, 21 | ) 22 | 23 | BASE_DIR = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__)))) 24 | SCRIPTS_DIR = BASE_DIR / "scripts" 25 | 26 | 27 | class timer(object): 28 | """ 29 | timer: A class used to measure the execution time of a block 
of code that is 30 | inside a "with" statement. 31 | 32 | Example: 33 | 34 | ``` 35 | with timer("Count to 500000"): 36 | x = 0 37 | for i in range(500000): 38 | x += 1 39 | print(x) 40 | ``` 41 | 42 | Will output: 43 | 500000 44 | Count to 500000: 0.04 s 45 | 46 | Warning: The time resolution used here may be limited to 1 ms 47 | """ 48 | 49 | def __init__(self, description="Execution time", verbose=False): 50 | self.description = description 51 | self.verbose = verbose 52 | self.execution_time = None 53 | 54 | def __enter__(self): 55 | self.t = time.time() 56 | return self 57 | 58 | def __exit__(self, type, value, traceback): 59 | self.execution_time = time.time() - self.t 60 | if self.verbose: 61 | print("{}: {:.3f} s".format(self.description, self.execution_time)) 62 | 63 | 64 | def measure_execution_time( 65 | transform, batch_size, num_channels, duration, sample_rate, device_name, device 66 | ): 67 | transform_name = transform.__class__.__name__ 68 | num_samples = int(duration * sample_rate) 69 | samples = torch.rand( 70 | (batch_size, num_channels, num_samples), dtype=torch.float32, device=device 71 | ) 72 | perf_objects = [] 73 | for i in range(3): 74 | perf_obj = { 75 | "metrics": {}, 76 | "params": { 77 | "transform": transform_name, 78 | "batch_size": batch_size, 79 | "num_channels": num_channels, 80 | "duration": duration, 81 | "sample_rate": sample_rate, 82 | "num_samples": num_samples, 83 | "device_name": device_name, 84 | }, 85 | } 86 | with timer() as t: 87 | transform(samples=samples, sample_rate=sample_rate).cpu() 88 | perf_obj["metrics"]["execution_time"] = t.execution_time 89 | perf_objects.append(perf_obj) 90 | return perf_objects 91 | 92 | 93 | if __name__ == "__main__": 94 | """ 95 | For each transformation, apply it to an example sound and write the transformed sounds to 96 | an output folder. Also crudely measure and print execution time. 
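    The measurements are also written out as one box plot per grouping parameter
    (batch size, channel count, length and device) under scripts/perf_benchmark_output/.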
97 | """ 98 | np.random.seed(42) 99 | random.seed(42) 100 | 101 | output_dir = SCRIPTS_DIR / "perf_benchmark_output" 102 | os.makedirs(output_dir, exist_ok=True) 103 | 104 | params = dict( 105 | batch_sizes=[1, 2, 4, 8, 16], 106 | channels=[1, 2, 4, 8], 107 | durations=[1, 2, 4, 8, 16], 108 | sample_rates=[16000], 109 | devices=["cpu", "cuda"], 110 | ) 111 | 112 | if not torch.cuda.is_available(): 113 | params["devices"].remove("cuda") 114 | 115 | devices = { 116 | device_name: torch.device(device_name) for device_name in params["devices"] 117 | } 118 | 119 | transforms = [ 120 | Gain(p=1.0), 121 | HighPassFilter(p=1.0), 122 | LowPassFilter(p=1.0), 123 | PolarityInversion(p=1.0), 124 | PeakNormalization(p=1.0), 125 | Shift(p=1.0), 126 | ShuffleChannels(p=1.0), 127 | ] 128 | 129 | perf_objects = [] 130 | 131 | for device_name in params["devices"]: 132 | device = devices[device_name] 133 | for batch_size in tqdm(params["batch_sizes"]): 134 | for num_channels in params["channels"]: 135 | for duration in params["durations"]: 136 | for sample_rate in params["sample_rates"]: 137 | for transform in transforms: 138 | perf_objects += measure_execution_time( 139 | transform, 140 | batch_size, 141 | num_channels, 142 | duration, 143 | sample_rate, 144 | device_name, 145 | device, 146 | ) 147 | 148 | params_to_group_by = ["batch_size", "num_channels", "num_samples", "device_name"] 149 | for group_by_param in tqdm(params_to_group_by, desc="Making plots"): 150 | param_values = [] 151 | metric_values = [] 152 | transform_names = [] 153 | for perf_obj in perf_objects: 154 | param_values.append(perf_obj["params"][group_by_param]) 155 | metric_values.append(perf_obj["metrics"]["execution_time"]) 156 | transform_names.append(perf_obj["params"]["transform"]) 157 | 158 | df = pd.DataFrame( 159 | { 160 | group_by_param: param_values, 161 | "exec_time": metric_values, 162 | "transform": transform_names, 163 | } 164 | ) 165 | 166 | violin_plot_file_path = str(output_dir / "{}_plot.png".format(group_by_param)) 167 | 168 | fig = plt.figure() 169 | ax = fig.add_subplot(111) 170 | ax.set_title("execution time grouped by {}".format(group_by_param)) 171 | g = sns.boxplot(x=group_by_param, y="exec_time", data=df, ax=ax, hue="transform") 172 | g.set_yscale("log") 173 | fig.tight_layout() 174 | plt.savefig(violin_plot_file_path) 175 | plt.close(fig) 176 | -------------------------------------------------------------------------------- /tests/test_shift.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from numpy.testing import assert_almost_equal 7 | 8 | from torch_audiomentations import Shift 9 | 10 | 11 | class TestShift: 12 | @pytest.mark.parametrize( 13 | "device_name", 14 | [ 15 | pytest.param("cpu"), 16 | pytest.param( 17 | "cuda", 18 | marks=pytest.mark.skip("Requires CUDA") 19 | if not torch.cuda.is_available() 20 | else [], 21 | ), 22 | ], 23 | ) 24 | def test_shift_by_1_sample_3dim(self, device_name): 25 | device = torch.device(device_name) 26 | samples = torch.arange(4)[None, None].repeat(2, 2, 1).to(device=device) 27 | samples[1] += 1 28 | 29 | augment = Shift( 30 | min_shift=1, max_shift=1, shift_unit="samples", p=1.0, output_type="dict" 31 | ) 32 | processed_samples = augment(samples).samples 33 | 34 | assert_almost_equal( 35 | processed_samples.cpu(), 36 | [[[3, 0, 1, 2], [3, 0, 1, 2]], [[4, 1, 2, 3], [4, 1, 2, 3]]], 37 | ) 38 | 39 | def test_shift_by_1_sample_without_rollover(self): 40 | 
samples = torch.arange(4)[None, None].repeat(2, 2, 1) 41 | samples[1] += 1 42 | 43 | augment = Shift( 44 | min_shift=1, 45 | max_shift=1, 46 | shift_unit="samples", 47 | rollover=False, 48 | p=1.0, 49 | output_type="dict", 50 | ) 51 | 52 | processed_samples = augment(samples=samples).samples 53 | assert_almost_equal( 54 | processed_samples, 55 | [[[0, 0, 1, 2], [0, 0, 1, 2]], [[0, 1, 2, 3], [0, 1, 2, 3]]], 56 | ) 57 | 58 | def test_negative_shift_by_2_samples(self): 59 | samples = torch.arange(4)[None, None].repeat(2, 2, 1) 60 | samples[1] += 1 61 | 62 | augment = Shift( 63 | min_shift=-2, 64 | max_shift=-2, 65 | shift_unit="samples", 66 | rollover=True, 67 | p=1.0, 68 | output_type="dict", 69 | ) 70 | 71 | processed_samples = augment(samples=samples).samples 72 | assert_almost_equal( 73 | processed_samples, 74 | [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], 75 | ) 76 | 77 | def test_shift_by_fraction(self): 78 | samples = torch.arange(4)[None, None].repeat(2, 2, 1) 79 | samples[1] += 1 80 | 81 | augment = Shift( 82 | min_shift=0.5, 83 | max_shift=0.5, 84 | shift_unit="fraction", 85 | rollover=True, 86 | p=1.0, 87 | output_type="dict", 88 | ) 89 | 90 | processed_samples = augment(samples=samples).samples 91 | assert_almost_equal( 92 | processed_samples, 93 | [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], 94 | ) 95 | 96 | def test_shift_by_seconds(self): 97 | samples = torch.arange(4)[None, None].repeat(2, 2, 1) 98 | samples[1] += 1 99 | 100 | augment = Shift( 101 | min_shift=-2, 102 | max_shift=-2, 103 | shift_unit="seconds", 104 | p=1.0, 105 | sample_rate=1, 106 | output_type="dict", 107 | ) 108 | processed_samples = augment(samples).samples 109 | 110 | assert_almost_equal( 111 | processed_samples, 112 | [[[2, 3, 0, 1], [2, 3, 0, 1]], [[3, 4, 1, 2], [3, 4, 1, 2]]], 113 | ) 114 | 115 | def test_shift_by_seconds_specify_sample_rate_in_both_init_and_forward(self): 116 | samples = torch.arange(4)[None, None].repeat(2, 2, 1) 117 | samples[1] += 1 118 | init_sample_rate = 1 119 | forward_sample_rate = 2 120 | 121 | augment = Shift( 122 | min_shift=1, 123 | max_shift=1, 124 | shift_unit="seconds", 125 | p=1.0, 126 | sample_rate=init_sample_rate, 127 | rollover=False, 128 | output_type="dict", 129 | ) 130 | # If sample_rate is specified in both __init__ and forward, then the latter will be used 131 | processed_samples = augment(samples, sample_rate=forward_sample_rate).samples 132 | assert_almost_equal( 133 | processed_samples, 134 | [[[0, 0, 0, 1], [0, 0, 0, 1]], [[0, 0, 1, 2], [0, 0, 1, 2]]], 135 | ) 136 | 137 | def test_shift_by_seconds_without_specifying_sample_rate(self): 138 | samples = torch.arange(4)[None, None].repeat(2, 2, 1) 139 | samples[1] += 1 140 | 141 | augment = Shift( 142 | min_shift=-3, max_shift=-3, shift_unit="seconds", p=1.0, output_type="dict" 143 | ) 144 | with pytest.raises(RuntimeError): 145 | augment(samples) 146 | 147 | with pytest.raises(RuntimeError): 148 | augment(samples, sample_rate=None) 149 | 150 | def test_variability_within_batch(self): 151 | torch.manual_seed(42) 152 | 153 | samples = torch.arange(4)[None, None].repeat(1000, 2, 1) 154 | augment = Shift( 155 | min_shift=-1, max_shift=1, shift_unit="samples", p=1.0, output_type="dict" 156 | ) 157 | processed_samples = augment(samples).samples 158 | 159 | applied_shift_counts = {-1: 0, 0: 0, 1: 0} 160 | for i in range(samples.shape[0]): 161 | applied_shift = None 162 | for shift in range(augment.min_shift, augment.max_shift + 1): 163 | rolled = np.roll(samples[i], shift, axis=1) 
164 | if np.array_equal(rolled, processed_samples[i]): 165 | applied_shift = shift 166 | break 167 | assert applied_shift is not None 168 | applied_shift_counts[applied_shift] += 1 169 | 170 | for shift in range(0, augment.max_shift + 1): 171 | assert applied_shift_counts[shift] > 50 172 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/band_pass_filter.py: -------------------------------------------------------------------------------- 1 | import julius 2 | import torch 3 | from torch import Tensor 4 | from typing import Optional 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.mel_scale import convert_frequencies_to_mels, convert_mels_to_frequencies 7 | from ..utils.object_dict import ObjectDict 8 | 9 | 10 | class BandPassFilter(BaseWaveformTransform): 11 | """ 12 | Apply band-pass filtering to the input audio. 13 | """ 14 | 15 | supported_modes = {"per_batch", "per_example", "per_channel"} 16 | 17 | supports_multichannel = True 18 | requires_sample_rate = True 19 | 20 | supports_target = True 21 | requires_target = False 22 | 23 | def __init__( 24 | self, 25 | min_center_frequency=200, 26 | max_center_frequency=4000, 27 | min_bandwidth_fraction=0.5, 28 | max_bandwidth_fraction=1.99, 29 | mode: str = "per_example", 30 | p: float = 0.5, 31 | p_mode: str = None, 32 | sample_rate: int = None, 33 | target_rate: int = None, 34 | output_type: Optional[str] = None, 35 | ): 36 | """ 37 | :param min_center_frequency: Minimum center frequency in hertz 38 | :param max_center_frequency: Maximum center frequency in hertz 39 | :param min_bandwidth_fraction: Minimum bandwidth fraction relative to center 40 | frequency (number between 0.0 and 2.0) 41 | :param max_bandwidth_fraction: Maximum bandwidth fraction relative to center 42 | frequency (number between 0.0 and 2.0) 43 | :param mode: 44 | :param p: 45 | :param p_mode: 46 | :param sample_rate: 47 | """ 48 | super().__init__( 49 | mode=mode, 50 | p=p, 51 | p_mode=p_mode, 52 | sample_rate=sample_rate, 53 | target_rate=target_rate, 54 | output_type=output_type, 55 | ) 56 | 57 | self.min_center_frequency = min_center_frequency 58 | self.max_center_frequency = max_center_frequency 59 | self.min_bandwidth_fraction = min_bandwidth_fraction 60 | self.max_bandwidth_fraction = max_bandwidth_fraction 61 | 62 | if max_center_frequency < min_center_frequency: 63 | raise ValueError( 64 | f"max_center_frequency ({max_center_frequency}) should be larger than " 65 | f"min_center_frequency ({min_center_frequency})." 66 | ) 67 | 68 | if min_bandwidth_fraction <= 0.0: 69 | raise ValueError("min_bandwidth_fraction must be a positive number") 70 | 71 | if max_bandwidth_fraction < min_bandwidth_fraction: 72 | raise ValueError( 73 | f"max_bandwidth_fraction ({max_bandwidth_fraction}) should be larger than " 74 | f"min_bandwidth_fraction ({min_bandwidth_fraction})." 75 | ) 76 | 77 | if max_bandwidth_fraction >= 2.0: 78 | raise ValueError( 79 | f"max_bandwidth_fraction ({max_bandwidth_fraction}) should be smaller than 2.0," 80 | f"since otherwise low_cut_frequency of the band can be smaller than 0 Hz." 
81 | ) 82 | 83 | def randomize_parameters( 84 | self, 85 | samples: Tensor = None, 86 | sample_rate: Optional[int] = None, 87 | targets: Optional[Tensor] = None, 88 | target_rate: Optional[int] = None, 89 | ): 90 | """ 91 | :params samples: (batch_size, num_channels, num_samples) 92 | """ 93 | 94 | batch_size, _, num_samples = samples.shape 95 | 96 | # Sample frequencies uniformly in mel space, then convert back to frequency 97 | def get_dist(min_freq, max_freq): 98 | dist = torch.distributions.Uniform( 99 | low=convert_frequencies_to_mels( 100 | torch.tensor(min_freq, dtype=torch.float32, device=samples.device) 101 | ), 102 | high=convert_frequencies_to_mels( 103 | torch.tensor(max_freq, dtype=torch.float32, device=samples.device) 104 | ), 105 | validate_args=True, 106 | ) 107 | return dist 108 | 109 | center_dist = get_dist(self.min_center_frequency, self.max_center_frequency) 110 | self.transform_parameters["center_freq"] = convert_mels_to_frequencies( 111 | center_dist.sample(sample_shape=(batch_size,)) 112 | ) 113 | 114 | bandwidth_dist = torch.distributions.Uniform( 115 | low=torch.tensor( 116 | self.min_bandwidth_fraction, dtype=torch.float32, device=samples.device 117 | ), 118 | high=torch.tensor( 119 | self.max_bandwidth_fraction, dtype=torch.float32, device=samples.device 120 | ), 121 | ) 122 | self.transform_parameters["bandwidth"] = bandwidth_dist.sample( 123 | sample_shape=(batch_size,) 124 | ) 125 | 126 | def apply_transform( 127 | self, 128 | samples: Tensor = None, 129 | sample_rate: Optional[int] = None, 130 | targets: Optional[Tensor] = None, 131 | target_rate: Optional[int] = None, 132 | ) -> ObjectDict: 133 | batch_size, num_channels, num_samples = samples.shape 134 | 135 | low_cutoffs_as_fraction_of_sample_rate = ( 136 | self.transform_parameters["center_freq"] 137 | * (1 - 0.5 * self.transform_parameters["bandwidth"]) 138 | / sample_rate 139 | ) 140 | high_cutoffs_as_fraction_of_sample_rate = ( 141 | self.transform_parameters["center_freq"] 142 | * (1 + 0.5 * self.transform_parameters["bandwidth"]) 143 | / sample_rate 144 | ) 145 | # TODO: Instead of using a for loop, perform batched compute to speed things up 146 | for i in range(batch_size): 147 | samples[i] = julius.bandpass_filter( 148 | samples[i], 149 | cutoff_low=low_cutoffs_as_fraction_of_sample_rate[i].item(), 150 | cutoff_high=high_cutoffs_as_fraction_of_sample_rate[i].item(), 151 | ) 152 | 153 | return ObjectDict( 154 | samples=samples, 155 | sample_rate=sample_rate, 156 | targets=targets, 157 | target_rate=target_rate, 158 | ) 159 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/background_noise.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Union, List, Optional 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from ..core.transforms_interface import BaseWaveformTransform, EmptyPathException 9 | from ..utils.dsp import calculate_rms 10 | from ..utils.file import find_audio_files_in_paths 11 | from ..utils.io import Audio 12 | from ..utils.object_dict import ObjectDict 13 | 14 | 15 | class AddBackgroundNoise(BaseWaveformTransform): 16 | """ 17 | Add background noise to the input audio. 18 | """ 19 | 20 | supported_modes = {"per_batch", "per_example", "per_channel"} 21 | 22 | # Note: This transform has only partial support for multichannel audio. 
Noises that are not 23 | # mono get mixed down to mono before they are added to all channels in the input. 24 | supports_multichannel = True 25 | requires_sample_rate = True 26 | 27 | supports_target = True 28 | requires_target = False 29 | 30 | def __init__( 31 | self, 32 | background_paths: Union[List[Path], List[str], Path, str], 33 | min_snr_in_db: float = 3.0, 34 | max_snr_in_db: float = 30.0, 35 | mode: str = "per_example", 36 | p: float = 0.5, 37 | p_mode: str = None, 38 | sample_rate: int = None, 39 | target_rate: int = None, 40 | output_type: Optional[str] = None, 41 | ): 42 | """ 43 | 44 | :param background_paths: Either a path to a folder with audio files or a list of paths 45 | to audio files. 46 | :param min_snr_in_db: minimum SNR in dB. 47 | :param max_snr_in_db: maximum SNR in dB. 48 | :param mode: 49 | :param p: 50 | :param p_mode: 51 | :param sample_rate: 52 | """ 53 | 54 | super().__init__( 55 | mode=mode, 56 | p=p, 57 | p_mode=p_mode, 58 | sample_rate=sample_rate, 59 | target_rate=target_rate, 60 | output_type=output_type, 61 | ) 62 | 63 | # TODO: check that one can read audio files 64 | self.background_paths = find_audio_files_in_paths(background_paths) 65 | 66 | if sample_rate is not None: 67 | self.audio = Audio(sample_rate=sample_rate, mono=True) 68 | 69 | if len(self.background_paths) == 0: 70 | raise EmptyPathException("There are no supported audio files found.") 71 | 72 | self.min_snr_in_db = min_snr_in_db 73 | self.max_snr_in_db = max_snr_in_db 74 | if self.min_snr_in_db > self.max_snr_in_db: 75 | raise ValueError("min_snr_in_db must not be greater than max_snr_in_db") 76 | 77 | def random_background(self, audio: Audio, target_num_samples: int) -> torch.Tensor: 78 | pieces = [] 79 | 80 | # TODO: support repeat short samples instead of concatenating from different files 81 | 82 | missing_num_samples = target_num_samples 83 | while missing_num_samples > 0: 84 | background_path = random.choice(self.background_paths) 85 | background_num_samples = audio.get_num_samples(background_path) 86 | 87 | if background_num_samples > missing_num_samples: 88 | sample_offset = random.randint( 89 | 0, background_num_samples - missing_num_samples 90 | ) 91 | num_samples = missing_num_samples 92 | background_samples = audio( 93 | background_path, sample_offset=sample_offset, num_samples=num_samples 94 | ) 95 | missing_num_samples = 0 96 | else: 97 | background_samples = audio(background_path) 98 | missing_num_samples -= background_num_samples 99 | 100 | pieces.append(background_samples) 101 | 102 | # the inner call to rms_normalize ensures concatenated pieces share the same RMS (1) 103 | # the outer call to rms_normalize ensures that the resulting background has an RMS of 1 104 | # (this simplifies "apply_transform" logic) 105 | return audio.rms_normalize( 106 | torch.cat([audio.rms_normalize(piece) for piece in pieces], dim=1) 107 | ) 108 | 109 | def randomize_parameters( 110 | self, 111 | samples: Tensor = None, 112 | sample_rate: Optional[int] = None, 113 | targets: Optional[Tensor] = None, 114 | target_rate: Optional[int] = None, 115 | ): 116 | """ 117 | 118 | :params samples: (batch_size, num_channels, num_samples) 119 | """ 120 | 121 | batch_size, _, num_samples = samples.shape 122 | 123 | # (batch_size, num_samples) RMS-normalized background noise 124 | audio = self.audio if hasattr(self, "audio") else Audio(sample_rate, mono=True) 125 | self.transform_parameters["background"] = torch.stack( 126 | [self.random_background(audio, num_samples) for _ in range(batch_size)] 127 | ) 
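        # Each SNR drawn below is converted to a linear gain in apply_transform:
        #     background_rms = rms(samples) / 10 ** (snr_in_db / 20)
        # e.g. an SNR of 20 dB scales the unit-RMS background to 1/10 of the signal RMS.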
128 | 129 | # (batch_size, ) SNRs 130 | if self.min_snr_in_db == self.max_snr_in_db: 131 | self.transform_parameters["snr_in_db"] = torch.full( 132 | size=(batch_size,), 133 | fill_value=self.min_snr_in_db, 134 | dtype=torch.float32, 135 | device=samples.device, 136 | ) 137 | else: 138 | snr_distribution = torch.distributions.Uniform( 139 | low=torch.tensor( 140 | self.min_snr_in_db, dtype=torch.float32, device=samples.device 141 | ), 142 | high=torch.tensor( 143 | self.max_snr_in_db, dtype=torch.float32, device=samples.device 144 | ), 145 | validate_args=True, 146 | ) 147 | self.transform_parameters["snr_in_db"] = snr_distribution.sample( 148 | sample_shape=(batch_size,) 149 | ) 150 | 151 | def apply_transform( 152 | self, 153 | samples: Tensor = None, 154 | sample_rate: Optional[int] = None, 155 | targets: Optional[Tensor] = None, 156 | target_rate: Optional[int] = None, 157 | ) -> ObjectDict: 158 | batch_size, num_channels, num_samples = samples.shape 159 | 160 | # (batch_size, num_samples) 161 | background = self.transform_parameters["background"].to(samples.device) 162 | 163 | # (batch_size, num_channels) 164 | background_rms = calculate_rms(samples) / ( 165 | 10 ** (self.transform_parameters["snr_in_db"].unsqueeze(dim=-1) / 20) 166 | ) 167 | 168 | return ObjectDict( 169 | samples=samples 170 | + background_rms.unsqueeze(-1) 171 | * background.view(batch_size, 1, num_samples).expand(-1, num_channels, -1), 172 | sample_rate=sample_rate, 173 | targets=targets, 174 | target_rate=target_rate, 175 | ) 176 | -------------------------------------------------------------------------------- /tests/test_compose.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import numpy as np 5 | import torch 6 | from numpy.testing import assert_almost_equal, assert_array_equal 7 | from torchaudio.transforms import Vol 8 | 9 | from torch_audiomentations import PolarityInversion, Compose, PeakNormalization, Gain 10 | from torch_audiomentations.augmentations.shuffle_channels import ShuffleChannels 11 | from torch_audiomentations.utils.dsp import convert_decibels_to_amplitude_ratio 12 | 13 | 14 | class TestCompose(unittest.TestCase): 15 | def test_compose_without_specifying_output_type(self): 16 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 17 | sample_rate = 16000 18 | 19 | augment = Compose( 20 | [ 21 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0), 22 | PolarityInversion(p=1.0), 23 | ] 24 | ) 25 | processed_samples = augment( 26 | samples=torch.from_numpy(samples), sample_rate=sample_rate 27 | ) 28 | # This dtype should be torch.Tensor until we switch to ObjectDict as default 29 | assert type(processed_samples) == torch.Tensor 30 | processed_samples = processed_samples.numpy() 31 | expected_factor = -convert_decibels_to_amplitude_ratio(-6) 32 | assert_almost_equal( 33 | processed_samples, 34 | expected_factor 35 | * np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), 36 | decimal=6, 37 | ) 38 | self.assertEqual(processed_samples.dtype, np.float32) 39 | 40 | def test_compose_dict(self): 41 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 42 | sample_rate = 16000 43 | 44 | augment = Compose( 45 | [ 46 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0), 47 | PolarityInversion(p=1.0), 48 | ], 49 | output_type="dict", 50 | ) 51 | processed_samples = augment( 52 | samples=torch.from_numpy(samples), sample_rate=sample_rate 53 | ).samples.numpy() 54 | 
expected_factor = -convert_decibels_to_amplitude_ratio(-6) 55 | assert_almost_equal( 56 | processed_samples, 57 | expected_factor 58 | * np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), 59 | decimal=6, 60 | ) 61 | self.assertEqual(processed_samples.dtype, np.float32) 62 | 63 | def test_compose_with_torchaudio_transform(self): 64 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 65 | sample_rate = 16000 66 | 67 | augment = Compose( 68 | [ 69 | Vol(gain=-6, gain_type="db"), 70 | PolarityInversion(p=1.0), 71 | ], 72 | output_type="dict", 73 | ) 74 | processed_samples = augment( 75 | samples=torch.from_numpy(samples), sample_rate=sample_rate 76 | ).samples.numpy() 77 | expected_factor = -convert_decibels_to_amplitude_ratio(-6) 78 | assert_almost_equal( 79 | processed_samples, 80 | expected_factor 81 | * np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32), 82 | decimal=6, 83 | ) 84 | self.assertEqual(processed_samples.dtype, np.float32) 85 | 86 | def test_compose_with_p_zero(self): 87 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 88 | sample_rate = 16000 89 | 90 | augment = Compose( 91 | transforms=[ 92 | Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0), 93 | PolarityInversion(p=1.0), 94 | ], 95 | p=0.0, 96 | output_type="dict", 97 | ) 98 | processed_samples = augment( 99 | samples=torch.from_numpy(samples), sample_rate=sample_rate 100 | ).samples.numpy() 101 | assert_array_equal(samples, processed_samples) 102 | 103 | def test_freeze_and_unfreeze_parameters(self): 104 | torch.manual_seed(42) 105 | 106 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 107 | sample_rate = 16000 108 | 109 | augment = Compose( 110 | transforms=[ 111 | Gain( 112 | min_gain_in_db=-16.000001, 113 | max_gain_in_db=-2, 114 | p=1.0, 115 | ), 116 | PolarityInversion(p=1.0), 117 | ], 118 | output_type="dict", 119 | ) 120 | 121 | processed_samples1 = augment( 122 | samples=torch.from_numpy(samples), sample_rate=sample_rate 123 | ).samples.numpy() 124 | augment.freeze_parameters() 125 | processed_samples2 = augment( 126 | samples=torch.from_numpy(samples), sample_rate=sample_rate 127 | ).samples.numpy() 128 | assert_array_equal(processed_samples1, processed_samples2) 129 | 130 | augment.unfreeze_parameters() 131 | processed_samples3 = augment( 132 | samples=torch.from_numpy(samples), sample_rate=sample_rate 133 | ).samples.numpy() 134 | self.assertNotEqual(processed_samples1[0, 0, 0], processed_samples3[0, 0, 0]) 135 | 136 | def test_shuffle(self): 137 | random.seed(42) 138 | samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32) 139 | sample_rate = 16000 140 | 141 | augment = Compose( 142 | transforms=[ 143 | Gain(min_gain_in_db=-18.0, max_gain_in_db=-16.0, p=1.0), 144 | PeakNormalization(p=1.0), 145 | ], 146 | shuffle=True, 147 | output_type="dict", 148 | ) 149 | num_peak_normalization_last = 0 150 | num_gain_last = 0 151 | for i in range(100): 152 | processed_samples = augment( 153 | samples=torch.from_numpy(samples), sample_rate=sample_rate 154 | ).samples.numpy() 155 | 156 | # Either PeakNormalization or Gain was applied last 157 | if processed_samples[0, 0, 0] < 0.2: 158 | num_gain_last += 1 159 | elif processed_samples[0, 0, 0] == 1.0: 160 | num_peak_normalization_last += 1 161 | else: 162 | raise AssertionError("Unexpected value!") 163 | 164 | self.assertGreater(num_peak_normalization_last, 10) 165 | self.assertGreater(num_gain_last, 10) 166 | 167 | def 
test_supported_modes_property(self): 168 | augment = Compose( 169 | transforms=[ 170 | PeakNormalization(p=1.0), 171 | ], 172 | output_type="dict", 173 | ) 174 | assert augment.supported_modes == {"per_batch", "per_example", "per_channel"} 175 | 176 | augment = Compose( 177 | transforms=[ 178 | PeakNormalization( 179 | p=1.0, 180 | ), 181 | ShuffleChannels( 182 | p=1.0, 183 | ), 184 | ], 185 | output_type="dict", 186 | ) 187 | assert augment.supported_modes == {"per_example"} 188 | -------------------------------------------------------------------------------- /torch_audiomentations/augmentations/shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union 3 | from torch import Tensor 4 | 5 | from ..core.transforms_interface import BaseWaveformTransform 6 | from ..utils.object_dict import ObjectDict 7 | 8 | 9 | def shift_gpu(tensor: torch.Tensor, r: torch.Tensor, rollover: bool = False): 10 | """Shift or roll a batch of tensors""" 11 | b, c, t = tensor.shape 12 | 13 | # Arange indexes 14 | x = torch.arange(t, device=tensor.device) 15 | 16 | # Apply Roll 17 | r = r[:, None, None] 18 | idxs = (x - r).expand([b, c, t]) 19 | ret = torch.gather(tensor, 2, idxs % t) 20 | if rollover: 21 | return ret 22 | 23 | # Cut where we've rolled over 24 | cut_points = (idxs + 1).clamp(0) 25 | cut_points[cut_points > t] = 0 26 | ret[cut_points == 0] = 0 27 | return ret 28 | 29 | 30 | def shift_cpu( 31 | selected_samples: torch.Tensor, shift_samples: torch.Tensor, rollover: bool = False 32 | ): 33 | """Shift or roll a batch of tensors with the help of a for loop and torch.roll()""" 34 | selected_batch_size = selected_samples.size(0) 35 | 36 | for i in range(selected_batch_size): 37 | num_samples_to_shift = shift_samples[i].item() 38 | selected_samples[i] = torch.roll( 39 | selected_samples[i], shifts=num_samples_to_shift, dims=-1 40 | ) 41 | 42 | if not rollover: 43 | if num_samples_to_shift > 0: 44 | selected_samples[i, ..., :num_samples_to_shift] = 0.0 45 | elif num_samples_to_shift < 0: 46 | selected_samples[i, ..., num_samples_to_shift:] = 0.0 47 | 48 | return selected_samples 49 | 50 | 51 | class Shift(BaseWaveformTransform): 52 | """ 53 | Shift the audio forwards or backwards, with or without rollover 54 | """ 55 | 56 | supported_modes = {"per_batch", "per_example", "per_channel"} 57 | 58 | supports_multichannel = True 59 | requires_sample_rate = True 60 | 61 | supports_target = False # FIXME: some work is needed to support targets (see FIXMEs in apply_transform) 62 | requires_target = False 63 | 64 | def __init__( 65 | self, 66 | min_shift: Union[float, int] = -0.5, 67 | max_shift: Union[float, int] = 0.5, 68 | shift_unit: str = "fraction", 69 | rollover: bool = True, 70 | mode: str = "per_example", 71 | p: float = 0.5, 72 | p_mode: Optional[str] = None, 73 | sample_rate: Optional[int] = None, 74 | target_rate: Optional[int] = None, 75 | output_type: Optional[str] = None, 76 | ): 77 | """ 78 | 79 | :param min_shift: minimum amount of shifting in time. See also shift_unit. 80 | :param max_shift: maximum amount of shifting in time. See also shift_unit. 81 | :param shift_unit: Defines the unit of the value of min_shift and max_shift. 82 | "fraction": Fraction of the total sound length 83 | "samples": Number of audio samples 84 | "seconds": Number of seconds 85 | :param rollover: When set to True, samples that roll beyond the first or last position 86 | are re-introduced at the last or first. 
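# A minimal usage sketch (not an excerpt from shift.py) of the parameters described
# above, called the way scripts/demo.py calls the public API. With shift_unit="seconds"
# a sample rate is needed at call time, and with rollover=False the region vacated by
# the shift is filled with zeros rather than wrapped around.
import torch
from torch_audiomentations import Shift

apply_shift = Shift(
    min_shift=-0.25,
    max_shift=0.25,
    shift_unit="seconds",
    rollover=False,
    p=1.0,
    output_type="dict",
)
audio = torch.rand(4, 2, 32000) * 2 - 1  # (batch_size, num_channels, num_samples)
shifted = apply_shift(samples=audio, sample_rate=16000).samples  # same shape as audio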
When set to False, samples that roll beyond 87 | the first or last position are discarded. In other words, rollover=False results in 88 | an empty space (with zeroes). 89 | :param mode: 90 | :param p: 91 | :param p_mode: 92 | :param sample_rate: 93 | :param target_rate: 94 | """ 95 | super().__init__( 96 | mode=mode, 97 | p=p, 98 | p_mode=p_mode, 99 | sample_rate=sample_rate, 100 | target_rate=target_rate, 101 | output_type=output_type, 102 | ) 103 | self.min_shift = min_shift 104 | self.max_shift = max_shift 105 | self.shift_unit = shift_unit 106 | self.rollover = rollover 107 | if self.min_shift > self.max_shift: 108 | raise ValueError("min_shift must not be greater than max_shift") 109 | if self.shift_unit not in ("fraction", "samples", "seconds"): 110 | raise ValueError('shift_unit must be "samples", "fraction" or "seconds"') 111 | 112 | def randomize_parameters( 113 | self, 114 | samples: Tensor = None, 115 | sample_rate: Optional[int] = None, 116 | targets: Optional[Tensor] = None, 117 | target_rate: Optional[int] = None, 118 | ): 119 | if self.shift_unit == "samples": 120 | min_shift_in_samples = self.min_shift 121 | max_shift_in_samples = self.max_shift 122 | 123 | elif self.shift_unit == "fraction": 124 | min_shift_in_samples = int(round(self.min_shift * samples.shape[-1])) 125 | max_shift_in_samples = int(round(self.max_shift * samples.shape[-1])) 126 | 127 | elif self.shift_unit == "seconds": 128 | min_shift_in_samples = int(round(self.min_shift * sample_rate)) 129 | max_shift_in_samples = int(round(self.max_shift * sample_rate)) 130 | 131 | else: 132 | raise ValueError("Invalid shift_unit") 133 | 134 | assert ( 135 | torch.iinfo(torch.int32).min 136 | <= min_shift_in_samples 137 | <= torch.iinfo(torch.int32).max 138 | ) 139 | assert ( 140 | torch.iinfo(torch.int32).min 141 | <= max_shift_in_samples 142 | <= torch.iinfo(torch.int32).max 143 | ) 144 | selected_batch_size = samples.size(0) 145 | if min_shift_in_samples == max_shift_in_samples: 146 | self.transform_parameters["num_samples_to_shift"] = torch.full( 147 | size=(selected_batch_size,), 148 | fill_value=min_shift_in_samples, 149 | dtype=torch.int32, 150 | device=samples.device, 151 | ) 152 | 153 | else: 154 | self.transform_parameters["num_samples_to_shift"] = torch.randint( 155 | low=min_shift_in_samples, 156 | high=max_shift_in_samples + 1, 157 | size=(selected_batch_size,), 158 | dtype=torch.int32, 159 | device=samples.device, 160 | ) 161 | 162 | def apply_transform( 163 | self, 164 | samples: Tensor = None, 165 | sample_rate: Optional[int] = None, 166 | targets: Optional[Tensor] = None, 167 | target_rate: Optional[int] = None, 168 | ) -> ObjectDict: 169 | num_samples_to_shift = self.transform_parameters["num_samples_to_shift"] 170 | 171 | # Select fastest implementation based on device 172 | shift = shift_gpu if samples.device.type == "cuda" else shift_cpu 173 | shifted_samples = shift(samples, num_samples_to_shift, self.rollover) 174 | 175 | if targets is None or target_rate == 0: 176 | shifted_targets = targets 177 | 178 | else: 179 | num_frames_to_shift = int( 180 | round(target_rate * num_samples_to_shift / sample_rate) 181 | ) 182 | shifted_targets = shift( 183 | targets.transpose(-2, -1), num_frames_to_shift, self.rollover 184 | ).transpose(-2, -1) 185 | 186 | return ObjectDict( 187 | samples=shifted_samples, 188 | sample_rate=sample_rate, 189 | targets=shifted_targets, 190 | target_rate=target_rate, 191 | ) 192 | 193 | def is_sample_rate_required(self) -> bool: 194 | # Sample rate is required only if 
shift_unit is "seconds" 195 | return self.shift_unit == "seconds" 196 | -------------------------------------------------------------------------------- /tests/test_background_noise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import tempfile 5 | import unittest 6 | import uuid 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import pytest 11 | import torch 12 | from scipy.io.wavfile import write 13 | 14 | from torch_audiomentations import AddBackgroundNoise 15 | from torch_audiomentations.utils.dsp import calculate_rms 16 | from .utils import TEST_FIXTURES_DIR 17 | from torch_audiomentations.utils.io import Audio 18 | 19 | 20 | class TestAddBackgroundNoise(unittest.TestCase): 21 | def setUp(self): 22 | self.sample_rate = 16000 23 | self.batch_size = 16 24 | self.empty_input_audio = torch.empty(0, 1, 16000) 25 | # TODO: use utils.io.Audio 26 | 27 | audio = Audio(self.sample_rate, mono=True) 28 | self.input_audio = audio(TEST_FIXTURES_DIR / "acoustic_guitar_0.wav")[None] 29 | self.input_audios = torch.cat([self.input_audio] * self.batch_size, dim=0) 30 | 31 | self.bg_path = TEST_FIXTURES_DIR / "bg" 32 | self.bg_short_path = TEST_FIXTURES_DIR / "bg_short" 33 | self.bg_noise_transform_guaranteed = AddBackgroundNoise( 34 | self.bg_path, 20, p=1.0, output_type="dict" 35 | ) 36 | self.bg_short_noise_transform_guaranteed = AddBackgroundNoise( 37 | self.bg_short_path, 20, p=1.0, output_type="dict" 38 | ) 39 | self.bg_noise_transform_no_guarantee = AddBackgroundNoise( 40 | self.bg_path, 20, p=0.0, output_type="dict" 41 | ) 42 | 43 | def test_background_noise_no_guarantee_with_single_tensor(self): 44 | mixed_input = self.bg_noise_transform_no_guarantee( 45 | self.input_audio, self.sample_rate 46 | ).samples 47 | self.assertTrue(torch.equal(mixed_input, self.input_audio)) 48 | self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) 49 | 50 | def test_background_noise_no_guarantee_with_empty_tensor(self): 51 | with self.assertWarns(UserWarning) as warning_context_manager: 52 | mixed_input = self.bg_noise_transform_no_guarantee( 53 | self.empty_input_audio, self.sample_rate 54 | ).samples 55 | 56 | self.assertIn( 57 | "An empty samples tensor was passed", str(warning_context_manager.warning) 58 | ) 59 | 60 | self.assertTrue(torch.equal(mixed_input, self.empty_input_audio)) 61 | self.assertEqual(mixed_input.size(0), self.empty_input_audio.size(0)) 62 | 63 | def test_background_noise_guaranteed_with_zero_length_samples(self): 64 | with self.assertWarns(UserWarning) as warning_context_manager: 65 | mixed_input = self.bg_noise_transform_guaranteed( 66 | self.empty_input_audio, self.sample_rate 67 | ).samples 68 | 69 | self.assertIn( 70 | "An empty samples tensor was passed", str(warning_context_manager.warning) 71 | ) 72 | 73 | self.assertTrue(torch.equal(mixed_input, self.empty_input_audio)) 74 | self.assertEqual(mixed_input.size(0), self.empty_input_audio.size(0)) 75 | 76 | def test_background_noise_guaranteed_with_single_tensor(self): 77 | mixed_input = self.bg_noise_transform_guaranteed( 78 | self.input_audio, self.sample_rate 79 | ).samples 80 | self.assertFalse(torch.equal(mixed_input, self.input_audio)) 81 | self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) 82 | self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) 83 | 84 | def test_background_noise_guaranteed_with_batched_tensor(self): 85 | random.seed(42) 86 | mixed_inputs = self.bg_noise_transform_guaranteed( 
87 | self.input_audios, self.sample_rate 88 | ).samples 89 | self.assertFalse(torch.equal(mixed_inputs, self.input_audios)) 90 | self.assertEqual(mixed_inputs.size(0), self.input_audios.size(0)) 91 | self.assertEqual(mixed_inputs.size(1), self.input_audios.size(1)) 92 | 93 | def test_background_short_noise_guaranteed_with_batched_tensor(self): 94 | mixed_input = self.bg_short_noise_transform_guaranteed( 95 | self.input_audio, self.sample_rate 96 | ).samples 97 | self.assertFalse(torch.equal(mixed_input, self.input_audio)) 98 | self.assertEqual(mixed_input.size(0), self.input_audio.size(0)) 99 | self.assertEqual(mixed_input.size(1), self.input_audio.size(1)) 100 | 101 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") 102 | def test_background_short_noise_guaranteed_with_batched_cuda_tensor(self): 103 | input_audio_cuda = self.input_audio.cuda() 104 | mixed_input = self.bg_short_noise_transform_guaranteed( 105 | input_audio_cuda, self.sample_rate 106 | ).samples 107 | assert not torch.equal(mixed_input, input_audio_cuda) 108 | assert mixed_input.shape == input_audio_cuda.shape 109 | assert mixed_input.dtype == input_audio_cuda.dtype 110 | assert mixed_input.device == input_audio_cuda.device 111 | 112 | def test_varying_snr_within_batch(self): 113 | min_snr_in_db = 3 114 | max_snr_in_db = 30 115 | augment = AddBackgroundNoise( 116 | self.bg_path, 117 | min_snr_in_db=min_snr_in_db, 118 | max_snr_in_db=max_snr_in_db, 119 | p=1.0, 120 | output_type="dict", 121 | ) 122 | augmented_audios = augment(self.input_audios, self.sample_rate).samples 123 | 124 | self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape)) 125 | self.assertFalse(torch.equal(augmented_audios, self.input_audios)) 126 | 127 | added_noises = augmented_audios - self.input_audios 128 | 129 | actual_snr_values = [] 130 | for i in range(len(self.input_audios)): 131 | signal_rms = calculate_rms(self.input_audios[i]) 132 | noise_rms = calculate_rms(added_noises[i]) 133 | snr_in_db = 20 * torch.log10(signal_rms / noise_rms).item() 134 | self.assertGreaterEqual(snr_in_db, min_snr_in_db) 135 | self.assertLessEqual(snr_in_db, max_snr_in_db) 136 | 137 | actual_snr_values.append(snr_in_db) 138 | 139 | self.assertGreater(max(actual_snr_values) - min(actual_snr_values), 13.37) 140 | 141 | def test_invalid_params(self): 142 | with self.assertRaises(ValueError): 143 | augment = AddBackgroundNoise( 144 | self.bg_path, min_snr_in_db=30, max_snr_in_db=3, p=1.0, output_type="dict" 145 | ) 146 | 147 | def test_min_equals_max(self): 148 | desired_snr = 3.0 149 | augment = AddBackgroundNoise( 150 | self.bg_path, 151 | min_snr_in_db=desired_snr, 152 | max_snr_in_db=desired_snr, 153 | p=1.0, 154 | output_type="dict", 155 | ) 156 | augmented_audios = augment(self.input_audios, self.sample_rate).samples 157 | 158 | self.assertEqual(tuple(augmented_audios.shape), tuple(self.input_audios.shape)) 159 | self.assertFalse(torch.equal(augmented_audios, self.input_audios)) 160 | 161 | added_noises = augmented_audios - self.input_audios 162 | for i in range(len(self.input_audios)): 163 | signal_rms = calculate_rms(self.input_audios[i]) 164 | noise_rms = calculate_rms(added_noises[i]) 165 | snr_in_db = 20 * torch.log10(signal_rms / noise_rms).item() 166 | self.assertAlmostEqual(snr_in_db, desired_snr, places=5) 167 | 168 | def test_compatibility_of_resampled_length(self): 169 | random.seed(42) 170 | 171 | for _ in range(30): 172 | input_length = random.randint(1333, 1399) 173 | bg_length = random.randint(1333, 1399) 
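# A worked aside (a sketch, not part of this test module) showing the dB arithmetic
# that test_varying_snr_within_batch and test_min_equals_max above rely on:
# AddBackgroundNoise scales the noise so that rms_noise = rms_signal / 10 ** (snr_db / 20),
# and measuring 20 * log10(rms_signal / rms_noise) afterwards recovers snr_db.
import torch

def _scale_noise_to_snr(signal: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    signal_rms = signal.square().mean(dim=-1, keepdim=True).sqrt()
    noise_rms = noise.square().mean(dim=-1, keepdim=True).sqrt()
    desired_noise_rms = signal_rms / (10 ** (snr_db / 20))
    return noise * desired_noise_rms / noise_rms

_signal = torch.rand(1, 16000) * 2 - 1
_noise = torch.randn(1, 16000)
_scaled_noise = _scale_noise_to_snr(_signal, _noise, snr_db=3.0)
_measured_snr_db = 20 * torch.log10(
    _signal.square().mean(dim=-1).sqrt() / _scaled_noise.square().mean(dim=-1).sqrt()
)
# _measured_snr_db is approximately 3.0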
174 | input_sample_rate = random.randint(1000, 5000) 175 | bg_sample_rate = random.randint(1000, 5000) 176 | 177 | noise = np.random.uniform( 178 | low=-0.2, 179 | high=0.2, 180 | size=(bg_length,), 181 | ).astype(np.float32) 182 | tmp_dir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) 183 | try: 184 | os.makedirs(tmp_dir) 185 | write(os.path.join(tmp_dir, "noise.wav"), rate=bg_sample_rate, data=noise) 186 | 187 | print( 188 | f"input_length={input_length}, input_sample_rate={input_sample_rate}," 189 | f" bg_length={bg_length}, bg_sample_rate={bg_sample_rate}" 190 | ) 191 | input_audio = torch.randn(1, 1, input_length, dtype=torch.float32) 192 | transform = AddBackgroundNoise( 193 | tmp_dir, 194 | min_snr_in_db=4, 195 | max_snr_in_db=6, 196 | p=1.0, 197 | sample_rate=input_sample_rate, 198 | output_type="dict", 199 | ) 200 | transform(input_audio) 201 | except Exception: 202 | raise 203 | finally: 204 | shutil.rmtree(tmp_dir) 205 | -------------------------------------------------------------------------------- /torch_audiomentations/core/composition.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Union, Optional, Tuple 3 | 4 | from torch import Tensor 5 | import torch.nn 6 | import warnings 7 | 8 | from torch_audiomentations.core.transforms_interface import BaseWaveformTransform 9 | from torch_audiomentations.utils.object_dict import ObjectDict 10 | 11 | 12 | class BaseCompose(torch.nn.Module): 13 | """This class can apply a sequence of transforms to waveforms.""" 14 | 15 | def __init__( 16 | self, 17 | transforms: List[ 18 | torch.nn.Module 19 | ], # FIXME: do we really want to support regular nn.Module? 20 | shuffle: bool = False, 21 | p: float = 1.0, 22 | p_mode="per_batch", 23 | output_type: Optional[str] = None, 24 | ): 25 | """ 26 | :param transforms: List of waveform transform instances 27 | :param shuffle: Should the order of transforms be shuffled? 28 | :param p: The probability of applying the Compose to the given batch. 29 | :param p_mode: Only "per_batch" is supported at the moment. 30 | :param output_type: This optional argument can be set to "tensor" or "dict". 31 | """ 32 | super().__init__() 33 | self.p = p 34 | if p_mode != "per_batch": 35 | # TODO: Support per_example as well? And per_channel? 36 | raise ValueError(f'p_mode = "{p_mode}" is not supported') 37 | self.p_mode = p_mode 38 | self.shuffle = shuffle 39 | self.are_parameters_frozen = False 40 | 41 | if output_type is None: 42 | warnings.warn( 43 | f"Transforms now expect an `output_type` argument that currently defaults to 'tensor', " 44 | f"will default to 'dict' in v0.12, and will be removed in v0.13. Make sure to update " 45 | f"your code to something like:\n" 46 | f" >>> augment = {self.__class__.__name__}(..., output_type='dict')\n" 47 | f" >>> augmented_samples = augment(samples).samples", 48 | FutureWarning, 49 | ) 50 | output_type = "tensor" 51 | 52 | elif output_type == "tensor": 53 | warnings.warn( 54 | f"`output_type` argument will default to 'dict' in v0.12, and will be removed in v0.13. 
" 55 | f"Make sure to update your code to something like:\n" 56 | f"your code to something like:\n" 57 | f" >>> augment = {self.__class__.__name__}(..., output_type='dict')\n" 58 | f" >>> augmented_samples = augment(samples).samples", 59 | DeprecationWarning, 60 | ) 61 | 62 | self.output_type = output_type 63 | 64 | self.transforms = torch.nn.ModuleList(transforms) 65 | for tfm in self.transforms: 66 | tfm.output_type = "dict" 67 | 68 | def freeze_parameters(self): 69 | """ 70 | Mark all parameters as frozen, i.e. do not randomize them for each call. This can be 71 | useful if you want to apply an effect chain with the exact same parameters to multiple 72 | sounds. 73 | """ 74 | self.are_parameters_frozen = True 75 | for transform in self.transforms: 76 | transform.freeze_parameters() 77 | 78 | def unfreeze_parameters(self): 79 | """ 80 | Unmark all parameters as frozen, i.e. let them be randomized for each call. 81 | """ 82 | self.are_parameters_frozen = False 83 | for transform in self.transforms: 84 | transform.unfreeze_parameters() 85 | 86 | @property 87 | def supported_modes(self) -> set: 88 | """Return the intersection of supported modes of the transforms in the composition.""" 89 | currently_supported_modes = {"per_batch", "per_example", "per_channel"} 90 | for transform in self.transforms: 91 | currently_supported_modes = currently_supported_modes.intersection( 92 | transform.supported_modes 93 | ) 94 | return currently_supported_modes 95 | 96 | 97 | class Compose(BaseCompose): 98 | def forward( 99 | self, 100 | samples: Tensor = None, 101 | sample_rate: Optional[int] = None, 102 | targets: Optional[Tensor] = None, 103 | target_rate: Optional[int] = None, 104 | ) -> ObjectDict: 105 | inputs = ObjectDict( 106 | samples=samples, 107 | sample_rate=sample_rate, 108 | targets=targets, 109 | target_rate=target_rate, 110 | ) 111 | 112 | if random.random() < self.p: 113 | transform_indexes = list(range(len(self.transforms))) 114 | if self.shuffle: 115 | random.shuffle(transform_indexes) 116 | for i in transform_indexes: 117 | tfm = self.transforms[i] 118 | if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): 119 | inputs = self.transforms[i](**inputs) 120 | 121 | else: 122 | assert isinstance(tfm, torch.nn.Module) 123 | # FIXME: do we really want to support regular nn.Module? 124 | inputs.samples = self.transforms[i](inputs.samples) 125 | 126 | return inputs.samples if self.output_type == "tensor" else inputs 127 | 128 | 129 | class SomeOf(BaseCompose): 130 | """ 131 | SomeOf randomly picks several of the given transforms and applies them. 
132 | The number of transforms to be applied can be chosen as follows: 133 | 134 | - Pick exactly n transforms 135 | Example: pick exactly 2 of the transforms 136 | `SomeOf(2, [transform1, transform2, transform3])` 137 | 138 | - Pick between a minimum and maximum number of transforms 139 | Example: pick 1 to 3 of the transforms 140 | `SomeOf((1, 3), [transform1, transform2, transform3])` 141 | 142 | Example: Pick 2 to all of the transforms 143 | `SomeOf((2, None), [transform1, transform2, transform3])` 144 | """ 145 | 146 | def __init__( 147 | self, 148 | num_transforms: Union[int, Tuple[int, int]], 149 | transforms: List[torch.nn.Module], 150 | p: float = 1.0, 151 | p_mode="per_batch", 152 | output_type: Optional[str] = None, 153 | ): 154 | super().__init__( 155 | transforms=transforms, p=p, p_mode=p_mode, output_type=output_type 156 | ) 157 | 158 | self.transform_indexes = [] 159 | self.num_transforms = num_transforms 160 | self.all_transforms_indexes = list(range(len(self.transforms))) 161 | 162 | if isinstance(num_transforms, tuple): 163 | self.min_num_transforms = num_transforms[0] 164 | self.max_num_transforms = ( 165 | num_transforms[1] if num_transforms[1] else len(transforms) 166 | ) 167 | else: 168 | self.min_num_transforms = self.max_num_transforms = num_transforms 169 | 170 | assert self.min_num_transforms >= 1, "min_num_transforms must be >= 1" 171 | assert self.min_num_transforms <= len( 172 | transforms 173 | ), "num_transforms must be <= len(transforms)" 174 | assert self.max_num_transforms <= len( 175 | transforms 176 | ), "max_num_transforms must be <= len(transforms)" 177 | 178 | def randomize_parameters(self): 179 | num_transforms_to_apply = random.randint( 180 | self.min_num_transforms, self.max_num_transforms 181 | ) 182 | self.transform_indexes = sorted( 183 | random.sample(self.all_transforms_indexes, num_transforms_to_apply) 184 | ) 185 | 186 | def forward( 187 | self, 188 | samples: Tensor = None, 189 | sample_rate: Optional[int] = None, 190 | targets: Optional[Tensor] = None, 191 | target_rate: Optional[int] = None, 192 | ) -> ObjectDict: 193 | inputs = ObjectDict( 194 | samples=samples, 195 | sample_rate=sample_rate, 196 | targets=targets, 197 | target_rate=target_rate, 198 | ) 199 | 200 | if random.random() < self.p: 201 | if not self.are_parameters_frozen: 202 | self.randomize_parameters() 203 | 204 | for i in self.transform_indexes: 205 | tfm = self.transforms[i] 206 | if isinstance(tfm, (BaseWaveformTransform, BaseCompose)): 207 | inputs = self.transforms[i](**inputs) 208 | 209 | else: 210 | assert isinstance(tfm, torch.nn.Module) 211 | # FIXME: do we really want to support regular nn.Module? 212 | inputs.samples = self.transforms[i](inputs.samples) 213 | 214 | return inputs.samples if self.output_type == "tensor" else inputs 215 | 216 | 217 | class OneOf(SomeOf): 218 | """ 219 | OneOf randomly picks one of the given transforms and applies it. 
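    Example (a usage sketch; it assumes OneOf is exported at the package top level,
    like Compose, and uses the Gain and PolarityInversion transforms from this repo):

        import torch
        from torch_audiomentations import OneOf, Gain, PolarityInversion

        augment = OneOf(
            [Gain(min_gain_in_db=-6.0, max_gain_in_db=6.0, p=1.0), PolarityInversion(p=1.0)],
            output_type="dict",
        )
        samples = torch.rand(8, 1, 16000) * 2 - 1
        augmented = augment(samples=samples, sample_rate=16000).samples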
220 | """ 221 | 222 | def __init__( 223 | self, 224 | transforms: List[torch.nn.Module], 225 | p: float = 1.0, 226 | p_mode="per_batch", 227 | output_type: Optional[str] = None, 228 | ): 229 | super().__init__( 230 | num_transforms=1, 231 | transforms=transforms, 232 | p=p, 233 | p_mode=p_mode, 234 | output_type=output_type, 235 | ) 236 | -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from scipy.io import wavfile 9 | 10 | from torch_audiomentations import ( 11 | PolarityInversion, 12 | Gain, 13 | PeakNormalization, 14 | Compose, 15 | Shift, 16 | AddBackgroundNoise, 17 | ApplyImpulseResponse, 18 | AddColoredNoise, 19 | HighPassFilter, 20 | LowPassFilter, 21 | BandPassFilter, 22 | PitchShift, 23 | BandStopFilter, 24 | TimeInversion, 25 | Identity, 26 | ) 27 | from torch_audiomentations.augmentations.padding import Padding 28 | from torch_audiomentations.augmentations.shuffle_channels import ShuffleChannels 29 | from torch_audiomentations.augmentations.splice_out import SpliceOut 30 | from torch_audiomentations.core.transforms_interface import ModeNotSupportedException 31 | from torch_audiomentations.utils.object_dict import ObjectDict 32 | from torch_audiomentations.utils.io import Audio 33 | 34 | SAMPLE_RATE = 44100 35 | 36 | BASE_DIR = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__)))) 37 | SCRIPTS_DIR = BASE_DIR / "scripts" 38 | TEST_FIXTURES_DIR = BASE_DIR / "test_fixtures" 39 | 40 | 41 | class timer(object): 42 | """ 43 | timer: A class used to measure the execution time of a block of code that is 44 | inside a "with" statement. 45 | 46 | Example: 47 | 48 | ``` 49 | with timer("Count to 500000"): 50 | x = 0 51 | for i in range(500000): 52 | x += 1 53 | print(x) 54 | ``` 55 | 56 | Will output: 57 | 500000 58 | Count to 500000: 0.04 s 59 | 60 | Warning: The time resolution used here may be limited to 1 ms 61 | """ 62 | 63 | def __init__(self, description="Execution time", verbose=False): 64 | self.description = description 65 | self.verbose = verbose 66 | self.execution_time = None 67 | 68 | def __enter__(self): 69 | self.t = time.time() 70 | return self 71 | 72 | def __exit__(self, type, value, traceback): 73 | self.execution_time = time.time() - self.t 74 | if self.verbose: 75 | print("{}: {:.3f} s".format(self.description, self.execution_time)) 76 | 77 | 78 | if __name__ == "__main__": 79 | """ 80 | For each transformation, apply it to an example sound and write the transformed sounds to 81 | an output folder. Also crudely measure and print execution time. 
82 | """ 83 | output_dir = os.path.join(SCRIPTS_DIR, "demo_output") 84 | os.makedirs(output_dir, exist_ok=True) 85 | 86 | torch.manual_seed(43) 87 | np.random.seed(43) 88 | random.seed(43) 89 | 90 | filenames = ["perfect-alley1.ogg", "perfect-alley2.ogg"] 91 | audio = Audio(SAMPLE_RATE, mono=True) 92 | samples1 = audio(os.path.join(TEST_FIXTURES_DIR, filenames[0])) 93 | _, num_samples1 = samples1.shape 94 | samples2 = audio(os.path.join(TEST_FIXTURES_DIR, filenames[1])) 95 | _, num_samples2 = samples2.shape 96 | num_samples = min(num_samples1, num_samples2) 97 | samples = torch.stack([samples1[:, :num_samples], samples2[:, :num_samples]], dim=0) 98 | 99 | modes = ["per_batch", "per_example", "per_channel"] 100 | for mode in modes: 101 | transforms = [ 102 | { 103 | "get_instance": lambda: AddBackgroundNoise( 104 | background_paths=TEST_FIXTURES_DIR / "bg", mode=mode, p=1.0 105 | ), 106 | "num_runs": 5, 107 | }, 108 | {"get_instance": lambda: AddColoredNoise(mode=mode, p=1.0), "num_runs": 5}, 109 | { 110 | "get_instance": lambda: ApplyImpulseResponse( 111 | ir_paths=TEST_FIXTURES_DIR / "ir", mode=mode, p=1.0 112 | ), 113 | "num_runs": 1, 114 | }, 115 | { 116 | "get_instance": lambda: ApplyImpulseResponse( 117 | ir_paths=TEST_FIXTURES_DIR / "ir", 118 | compensate_for_propagation_delay=True, 119 | mode=mode, 120 | p=1.0, 121 | ), 122 | "name": "ApplyImpulseResponse with compensate_for_propagation_delay set to True", 123 | "num_runs": 1, 124 | }, 125 | {"get_instance": lambda: BandPassFilter(mode=mode, p=1.0), "num_runs": 5}, 126 | {"get_instance": lambda: BandStopFilter(mode=mode, p=1.0), "num_runs": 5}, 127 | { 128 | "get_instance": lambda: Compose( 129 | transforms=[ 130 | Gain( 131 | min_gain_in_db=-18.0, max_gain_in_db=-16.0, mode=mode, p=1.0 132 | ), 133 | PeakNormalization(mode=mode, p=1.0), 134 | ], 135 | shuffle=True, 136 | ), 137 | "name": "Shuffled Compose with Gain and PeakNormalization", 138 | "num_runs": 5, 139 | }, 140 | { 141 | "get_instance": lambda: Compose( 142 | transforms=[ 143 | Gain( 144 | min_gain_in_db=-18.0, max_gain_in_db=-16.0, mode=mode, p=0.5 145 | ), 146 | PolarityInversion(mode=mode, p=0.5), 147 | ], 148 | shuffle=True, 149 | ), 150 | "name": "Compose with Gain and PolarityInversion", 151 | "num_runs": 5, 152 | }, 153 | {"get_instance": lambda: Gain(mode=mode, p=1.0), "num_runs": 5}, 154 | {"get_instance": lambda: HighPassFilter(mode=mode, p=1.0), "num_runs": 5}, 155 | {"get_instance": lambda: Identity(mode=mode, p=1.0), "num_runs": 1}, 156 | {"get_instance": lambda: LowPassFilter(mode=mode, p=1.0), "num_runs": 5}, 157 | {"get_instance": lambda: Padding(mode=mode, p=1.0), "num_runs": 5}, 158 | {"get_instance": lambda: PeakNormalization(mode=mode, p=1.0), "num_runs": 1}, 159 | { 160 | "get_instance": lambda: PitchShift( 161 | sample_rate=SAMPLE_RATE, mode=mode, p=1.0 162 | ), 163 | "num_runs": 5, 164 | }, 165 | {"get_instance": lambda: PolarityInversion(mode=mode, p=1.0), "num_runs": 1}, 166 | {"get_instance": lambda: Shift(mode=mode, p=1.0), "num_runs": 5}, 167 | {"get_instance": lambda: ShuffleChannels(mode=mode, p=1.0), "num_runs": 5}, 168 | { 169 | "get_instance": lambda: SpliceOut(mode=mode, p=1.0), 170 | "num_runs": 5, 171 | }, 172 | {"get_instance": lambda: TimeInversion(mode=mode, p=1.0), "num_runs": 1}, 173 | ] 174 | 175 | execution_times = {} 176 | 177 | for transform in transforms: 178 | try: 179 | augmenter = transform["get_instance"]() 180 | except ModeNotSupportedException: 181 | continue 182 | transform_name = ( 183 | transform.get("name") 184 | if 
transform.get("name") 185 | else augmenter.__class__.__name__ 186 | ) 187 | execution_times[transform_name] = [] 188 | for i in range(transform["num_runs"]): 189 | with timer() as t: 190 | augmented_samples = augmenter( 191 | samples=samples, sample_rate=SAMPLE_RATE 192 | ) 193 | print( 194 | augmenter.__class__.__name__, 195 | "is output ObjectDict:", 196 | type(augmented_samples) is ObjectDict, 197 | ) 198 | augmented_samples = ( 199 | augmented_samples.samples.numpy() 200 | if type(augmented_samples) is ObjectDict 201 | else augmented_samples.numpy() 202 | ) 203 | execution_times[transform_name].append(t.execution_time) 204 | for example_idx, original_filename in enumerate(filenames): 205 | output_file_path = os.path.join( 206 | output_dir, 207 | "{}_{}_{:03d}_{}.wav".format( 208 | transform_name, mode, i, Path(original_filename).stem 209 | ), 210 | ) 211 | wavfile.write( 212 | output_file_path, 213 | rate=SAMPLE_RATE, 214 | data=augmented_samples[example_idx].transpose(), 215 | ) 216 | 217 | for transform_name in execution_times: 218 | if len(execution_times[transform_name]) > 1: 219 | print( 220 | "{:<52} {:.3f} s (std: {:.3f} s)".format( 221 | transform_name, 222 | np.mean(execution_times[transform_name]), 223 | np.std(execution_times[transform_name]), 224 | ) 225 | ) 226 | else: 227 | print( 228 | "{:<52} {:.3f} s".format( 229 | transform_name, np.mean(execution_times[transform_name]) 230 | ) 231 | ) 232 | --------------------------------------------------------------------------------