├── benchmark ├── __init__.py ├── benchmark_data.py ├── run_all.py ├── clip_benchmark.py ├── gain_benchmark.py ├── low_pass_benchmark.py ├── band_pass_benchmark.py ├── band_stop_benchmark.py ├── high_pass_benchmark.py ├── add_background_noise_benchmark.py └── benchmark_tool.py ├── fast_audiomentations ├── transforms │ ├── __init__.py │ ├── clip.py │ ├── _impl │ │ ├── _clip_triton.py │ │ ├── _gain_triton.py │ │ ├── _mix_triton.py │ │ └── _filter_triton.py │ ├── gain.py │ ├── low_pass_filter.py │ ├── high_pass_filter.py │ ├── band_stop_filter.py │ ├── band_pass_filter.py │ └── add_background_noise.py ├── functions │ └── audio_io.py └── __init__.py ├── examples ├── test.py ├── clip.py ├── gain.py └── low_pass_filter.py ├── tests └── data │ ├── 44k.wav │ ├── 44k_noise.wav │ └── dupl.sh ├── requirements.txt ├── LICENSE ├── .gitignore ├── README.md └── benchmark_local_result ├── clip.txt ├── gain.txt ├── low_pass_filter.txt ├── band_pass.txt ├── band_stop.txt ├── high_pass_filter.txt └── add_background_noise.txt /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations import * # Smoke test: importing the package runs its __init__, which imports every transform module 2 | 3 | print("SUCCESS") # Reached only if the import above did not raise 4 | -------------------------------------------------------------------------------- /tests/data/44k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lallapallooza/fast-audiomentations/HEAD/tests/data/44k.wav -------------------------------------------------------------------------------- /benchmark/benchmark_data.py: 
-------------------------------------------------------------------------------- 1 | 2 | PATH_TO_SINGLE_AUDIO='tests/data/44k.wav' 3 | PATH_TO_NOISES='tests/data/noise' -------------------------------------------------------------------------------- /tests/data/44k_noise.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lallapallooza/fast-audiomentations/HEAD/tests/data/44k_noise.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | triton==2.1.0 2 | numpy==1.26.1 3 | torch==2.1.2+cu121 4 | torchaudio==2.1.2+cu121 5 | soundfile==0.12.1 6 | audiomentations==0.34.1 7 | torch_audiomentations==0.11.0 8 | rich 9 | nvidia-dali-cuda120==1.32.0 10 | -------------------------------------------------------------------------------- /tests/data/dupl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Navigate to the "noise" directory 4 | cd noise 5 | 6 | # Loop through all files in the directory 7 | for file in *; do 8 | # Check if it's a file and not a directory 9 | if [ -f "$file" ]; then 10 | # Generate a random prefix 11 | prefix=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 8 | head -n 1) 12 | # Duplicate the file with the new prefix 13 | cp "$file" "${prefix}_$file" 14 | fi 15 | done 16 | -------------------------------------------------------------------------------- /fast_audiomentations/functions/audio_io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import soundfile as sf 3 | 4 | def load(path: str): 5 | samples, sr = sf.read(path, dtype="float32") 6 | return torch.tensor(samples, device='cuda', dtype=torch.float32).reshape(1, -1), sr 7 | 8 | def save(samples: torch.Tensor, sr: int, path: str): 9 | if samples.ndim > 2: 10 | raise 
ValueError("Audios with more than 2 dimensions are not supported.") 11 | if samples.ndim == 2: 12 | if samples.shape[0] != 1: 13 | raise ValueError("Audios with shape (C, L) where C != 1 are not supported.") 14 | samples = samples.squeeze(0) 15 | 16 | sf.write(path, samples.cpu(), sr) 17 | -------------------------------------------------------------------------------- /fast_audiomentations/__init__.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms.low_pass_filter import LowPassFilter 2 | from fast_audiomentations.transforms.band_pass_filter import BandPassFilter 3 | from fast_audiomentations.transforms.band_stop_filter import BandStopFilter 4 | from fast_audiomentations.transforms.high_pass_filter import HighPassFilter 5 | from fast_audiomentations.transforms.clip import Clip 6 | from fast_audiomentations.transforms.gain import Gain 7 | from fast_audiomentations.transforms.add_background_noise import AddBackgroundNoise 8 | 9 | 10 | __version__ = "0.1.0" 11 | __all__ = [ 12 | "LowPassFilter", 13 | "BandPassFilter", 14 | "BandStopFilter", 15 | "HighPassFilter", 16 | "Clip", 17 | "Gain", 18 | "AddBackgroundNoise" 19 | ] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lallapallooza 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 
| copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/clip.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from fast_audiomentations.functions.audio_io import load, save 3 | from fast_audiomentations import Clip 4 | 5 | 6 | def main(input_path, output_path, batch_size_repeat): 7 | # Load the audio file 8 | samples, sr = load(input_path) 9 | 10 | # Convert samples to tensor, reshape, and repeat 11 | samples = samples.repeat(batch_size_repeat, 1).contiguous() 12 | 13 | # Initialize the Clip augmentation 14 | gain = Clip(min=-0.3, max=0.3, p=1.0) 15 | 16 | # Apply the clip augmentation 17 | augmented_samples = gain(samples=samples, sample_rate=sr) 18 | 19 | # Save all augmented samples to separate files 20 | for i in range(len(augmented_samples)): 21 | save(augmented_samples[i], sr, f"{output_path}_{i}.wav") 22 | 23 | 24 | if __name__ == "__main__": 25 | if len(sys.argv) != 4: 26 | print("Usage: python -m examples.clip ") 27 | sys.exit(1) 28 | 29 | input_path = sys.argv[1] 30 | output_path = sys.argv[2] 31 | batch_size_repeat = int(sys.argv[3]) 32 | 33 | main(input_path, output_path, batch_size_repeat) 34 | -------------------------------------------------------------------------------- /examples/gain.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from 
fast_audiomentations.functions.audio_io import load, save 3 | from fast_audiomentations import Gain 4 | import torch 5 | 6 | 7 | def main(input_path, output_path, batch_size_repeat): 8 | # Load the audio file 9 | samples, sr = load(input_path) 10 | 11 | # Convert samples to tensor, reshape, and repeat 12 | samples = samples.repeat(batch_size_repeat, 1).contiguous() 13 | 14 | # Initialize the Gain augmentation 15 | gain = Gain(min_gain_in_db=-12, max_gain_in_db=12, p=1.0, dtype=torch.float32) 16 | 17 | # Apply the gain augmentation 18 | augmented_samples = gain(samples=samples, sample_rate=sr) 19 | 20 | # Save all augmented samples to separate files 21 | for i in range(len(augmented_samples)): 22 | save(augmented_samples[i], sr, f"{output_path}_{i}.wav") 23 | 24 | 25 | if __name__ == "__main__": 26 | if len(sys.argv) != 4: 27 | print("Usage: python -m examples.gain ") 28 | sys.exit(1) 29 | 30 | input_path = sys.argv[1] 31 | output_path = sys.argv[2] 32 | batch_size_repeat = int(sys.argv[3]) 33 | 34 | main(input_path, output_path, batch_size_repeat) 35 | -------------------------------------------------------------------------------- /examples/low_pass_filter.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from fast_audiomentations.functions.audio_io import load, save 3 | from fast_audiomentations import LowPassFilter 4 | 5 | 6 | def main(input_path, output_path, batch_size_repeat): 7 | # Load the audio file 8 | samples, sr = load(input_path) 9 | 10 | # Repeat the samples 'batch_size_repeat' times along the first dimension 11 | samples = samples.repeat(batch_size_repeat, 1).contiguous() 12 | 13 | # Initialize the LowPassFilter with specified cutoff frequencies and probability 14 | low_pass_filter = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0, num_taps=101) 15 | 16 | # Apply the low-pass filter to the samples 17 | augmented_samples = low_pass_filter(samples=samples, sample_rate=sr) 18 | 19 | # 
Save all augmented samples to separate files 20 | for i in range(len(augmented_samples)): 21 | save(augmented_samples[i], sr, f"{output_path}_{i}.wav") 22 | 23 | 24 | if __name__ == "__main__": 25 | if len(sys.argv) != 4: 26 | print("Usage: python -m examples.low_pass_filter ") 27 | sys.exit(1) 28 | 29 | input_path = sys.argv[1] 30 | output_path = sys.argv[2] 31 | batch_size_repeat = int(sys.argv[3]) 32 | 33 | main(input_path, output_path, batch_size_repeat) 34 | -------------------------------------------------------------------------------- /benchmark/run_all.py: -------------------------------------------------------------------------------- 1 | from benchmark.benchmark_tool import Benchmark, _benchmark_registry 2 | import time 3 | 4 | def do_init(): 5 | import os 6 | import glob 7 | import importlib.util 8 | 9 | # List of module names to be imported when 'from benchmark import *' is used 10 | __all__ = [] 11 | 12 | # Path to the directory of this file 13 | directory_path = os.path.dirname(__file__) 14 | 15 | # Import all .py files in this directory 16 | for file_path in glob.glob(os.path.join(directory_path, '*.py')): 17 | # Skip __init__.py 18 | if os.path.basename(file_path).startswith('__'): 19 | continue 20 | 21 | module_name = os.path.splitext(os.path.basename(file_path))[0] 22 | __all__.append(module_name) 23 | 24 | spec = importlib.util.spec_from_file_location(module_name, file_path) 25 | module = importlib.util.module_from_spec(spec) 26 | spec.loader.exec_module(module) 27 | 28 | # Add the module to globals so it's accessible as if 'from import *' 29 | globals()[module_name] = module 30 | 31 | if __name__ == '__main__': 32 | do_init() 33 | print('Start benchmarking') 34 | for bench_type in _benchmark_registry.suites.keys(): 35 | benchmark = Benchmark(name=f"Benchmark of type \'{bench_type}\'", bench_type=bench_type) 36 | benchmark.run() 37 | benchmark.print_results() 38 | -------------------------------------------------------------------------------- 
/fast_audiomentations/transforms/clip.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._clip_triton import apply_clip as _apply_clip_triton 2 | 3 | import random 4 | import torch 5 | 6 | 7 | class Clip: 8 | """ 9 | Class for applying clipping to audio samples. 10 | 11 | Attributes: 12 | __min (float): The minimum value for clipping. 13 | __max (float): The maximum value for clipping. 14 | p (float): Probability of applying the clipping operation. 15 | """ 16 | def __init__(self, min: float = -1.0, max: float = 1.0, p=0.5): 17 | """ 18 | Initializes the Clip class with given clipping range and probability. 19 | 20 | @param min: The minimum value for clipping. 21 | @param max: The maximum value for clipping. 22 | @param p: Probability of applying the clipping operation. 23 | """ 24 | self.__min = min 25 | self.__max = max 26 | self.p = p 27 | 28 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 29 | """ 30 | Apply clipping to the audio samples. 31 | 32 | @param samples: Input audio samples tensor. 33 | @param sample_rate: Sample rate of the audio. Not used in this function but included for API consistency. 34 | @param inplace: If True, perform the operation in-place. 35 | @return: Audio samples after applying clipping. 
36 | """ 37 | if random.random() < self.p: 38 | return _apply_clip_triton(samples, self.__min, self.__max, inplace=inplace) 39 | return samples -------------------------------------------------------------------------------- /fast_audiomentations/transforms/_impl/_clip_triton.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | import itertools 5 | 6 | 7 | @triton.autotune( 8 | configs=[ 9 | triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps) 10 | for (block_size, num_warps) in 11 | itertools.product([32, 64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32]) 12 | ], 13 | key=['n_audios', 'audio_len'], 14 | ) 15 | @triton.jit 16 | def apply_clip_kernel(samples_ptr, min, max, output_ptr, n_audios, audio_len, BLOCK_SIZE: tl.constexpr): 17 | audio_idx = tl.program_id(0) 18 | 19 | if audio_idx >= n_audios: 20 | return 21 | 22 | for i in range(0, audio_len, BLOCK_SIZE): 23 | sample_idx = i + tl.arange(0, BLOCK_SIZE) 24 | mask = sample_idx < audio_len 25 | 26 | samples = tl.load(samples_ptr + audio_idx * audio_len + sample_idx, mask=mask) 27 | result = tl.where(samples > max, max, samples) 28 | result = tl.where(result < min, min, result) 29 | tl.store(output_ptr + audio_idx * audio_len + sample_idx, result, mask=mask) 30 | 31 | 32 | def apply_clip(samples: torch.Tensor, min: float, max: float, inplace: bool = False): 33 | assert min < max 34 | assert samples.ndim == 2 35 | 36 | n_audios, audio_len = samples.shape 37 | 38 | grid = lambda _: (n_audios,) 39 | 40 | if inplace: 41 | apply_clip_kernel[grid](samples, min, max, samples, n_audios, audio_len) 42 | return samples 43 | else: 44 | copy = torch.empty_like(samples, dtype=samples.dtype) 45 | apply_clip_kernel[grid](samples, min, max, copy, n_audios, audio_len) 46 | return copy 47 | -------------------------------------------------------------------------------- 
/fast_audiomentations/transforms/_impl/_gain_triton.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | import itertools 5 | 6 | 7 | @triton.autotune( 8 | configs=[ 9 | triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps) 10 | for (block_size, num_warps) in 11 | itertools.product([32, 64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32]) 12 | ], 13 | key=['n_audios', 'audio_len'], 14 | ) 15 | @triton.jit 16 | def apply_gain_kernel(samples_ptr, amplitude_ratios_ptr, output_ptr, n_audios, audio_len, BLOCK_SIZE: tl.constexpr): 17 | audio_idx = tl.program_id(0) 18 | 19 | if audio_idx >= n_audios: 20 | return 21 | 22 | gain = tl.load(amplitude_ratios_ptr + audio_idx) 23 | 24 | for i in range(0, audio_len, BLOCK_SIZE): 25 | sample_idx = i + tl.arange(0, BLOCK_SIZE) 26 | mask = sample_idx < audio_len 27 | samples = tl.load(samples_ptr + audio_idx * audio_len + sample_idx, mask=mask) 28 | result = samples * gain 29 | tl.store(output_ptr + audio_idx * audio_len + sample_idx, result, mask=mask) 30 | 31 | 32 | def apply_gain(samples: torch.Tensor, amplitude_ratios: torch.Tensor, inplace: bool = False): 33 | assert samples.ndim == 2 and amplitude_ratios.ndim == 1 34 | n_audios, audio_len = samples.shape 35 | 36 | grid = lambda _: (n_audios,) 37 | 38 | if inplace: 39 | apply_gain_kernel[grid](samples, amplitude_ratios, samples, n_audios, audio_len) 40 | return samples 41 | else: 42 | copy = torch.empty_like(samples, device='cuda', dtype=samples.dtype) 43 | apply_gain_kernel[grid](samples, amplitude_ratios, copy, n_audios, audio_len) 44 | return copy 45 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/gain.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._gain_triton import apply_gain as _apply_gain_triton 2 | 3 
| import random 4 | import torch 5 | 6 | 7 | class Gain: 8 | """ 9 | Class for applying gain (volume adjustment) to audio samples. 10 | 11 | Attributes: 12 | min_gain_in_db (float): Minimum gain value in decibels. 13 | max_gain_in_db (float): Maximum gain value in decibels. 14 | p (float): Probability of applying the gain operation. 15 | dtype (torch.dtype): Data type for computation. 16 | """ 17 | 18 | def __init__(self, 19 | min_gain_in_db=-12, 20 | max_gain_in_db=12, 21 | p=0.5, 22 | buffer_size=129, 23 | dtype: torch.dtype = torch.float32): 24 | """ 25 | Initializes the Gain class with given gain range, probability, and buffer size. 26 | 27 | @param min_gain_in_db: Minimum gain value in decibels. 28 | @param max_gain_in_db: Maximum gain value in decibels. 29 | @param p: Probability of applying the gain operation. 30 | @param buffer_size: Size of the buffer for random gain generation. 31 | @param dtype: Data type for computation. 32 | """ 33 | self.min_gain_in_db = min_gain_in_db 34 | self.max_gain_in_db = max_gain_in_db 35 | self.p = p 36 | self.dtype = dtype 37 | 38 | self.random_buffer = torch.empty(buffer_size, device='cuda', dtype=dtype) 39 | 40 | def __generate_random_amplitude_ratios(self, num_audios): 41 | """ 42 | Generate random amplitude ratios for the gain operation. 43 | 44 | @param num_audios: Number of audio samples to process. 45 | @return: A tensor of random amplitude ratios. 46 | """ 47 | assert num_audios <= self.random_buffer.size(0) 48 | 49 | slice = self.random_buffer[:num_audios] 50 | slice.uniform_(self.min_gain_in_db, self.max_gain_in_db) 51 | 52 | return slice 53 | 54 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 55 | """ 56 | Apply gain (volume adjustment) to the audio samples. 57 | 58 | @param samples: Input audio samples tensor. 59 | @param sample_rate: Sample rate of the audio. Not used in this function but included for API consistency. 60 | @param inplace: If True, perform the operation in-place. 
61 | @return: Audio samples after applying gain. 62 | """ 63 | if random.random() < self.p: 64 | gain_factors = self.__generate_random_amplitude_ratios(samples.shape[0]) 65 | return _apply_gain_triton(samples, gain_factors, inplace=inplace) 66 | return samples 67 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/low_pass_filter.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._filter_triton import create_filters as _create_low_pass_filters 2 | from fast_audiomentations.transforms._impl._filter_triton import fft_conv1d as _fft_conv1d 3 | 4 | import random 5 | import torch 6 | 7 | 8 | class LowPassFilter: 9 | """ 10 | Class for applying a low-pass filter to audio samples. 11 | 12 | Attributes: 13 | min_cutoff_freq (int): Minimum cutoff frequency for the low-pass filter. 14 | max_cutoff_freq (int): Maximum cutoff frequency for the low-pass filter. 15 | num_taps (int): Number of filter taps. 16 | buffer_size (int): Size of the buffer for processing. 17 | p (float): Probability of applying the augmentation. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | min_cutoff_freq: int = 500, 23 | max_cutoff_freq: int = 2000, 24 | num_taps: int = 101, 25 | buffer_size: int = 129, 26 | p: float = 0.5 27 | ): 28 | self.__min_cutoff_freq = min_cutoff_freq 29 | self.__max_cutoff_freq = max_cutoff_freq 30 | self.p = p 31 | self.num_taps = num_taps 32 | self.window = torch.hamming_window(num_taps, device='cuda', dtype=torch.float32, periodic=False) 33 | self.random_buffer = torch.empty(buffer_size, device='cuda') 34 | half = (num_taps - 1) // 2 35 | self.time = torch.arange(-half, half + 1, dtype=torch.float32, device='cuda') 36 | self.filter_output = torch.empty((buffer_size, self.num_taps), device='cuda', dtype=torch.float32) 37 | 38 | def __generate_random_cutoffs(self, num_audios): 39 | """ 40 | Generate random cutoff frequencies for the low-pass filter. 41 | 42 | @param num_audios: Number of audio samples to process. 43 | @return: A tensor of random cutoff frequencies. 44 | """ 45 | assert num_audios <= self.random_buffer.size(0) 46 | 47 | buff_slice = self.random_buffer[:num_audios] 48 | buff_slice.uniform_(self.__min_cutoff_freq, self.__max_cutoff_freq) 49 | 50 | return buff_slice 51 | 52 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 53 | """ 54 | Apply the low-pass filter to the audio samples. 55 | 56 | @param samples: Input audio samples tensor. 57 | @param sample_rate: Sample rate of the audio. 58 | @param inplace: If True, perform the operation in-place. 59 | @return: Audio samples after applying the low-pass filter. 
60 | """ 61 | if random.random() < self.p: 62 | freqs = self.__generate_random_cutoffs(samples.shape[0]) 63 | 64 | buff_slice = self.filter_output[:len(freqs)] 65 | _create_low_pass_filters( 66 | buff_slice, 67 | freqs, 68 | self.time, 69 | self.window, 70 | sample_rate, 71 | self.num_taps, 72 | "low" 73 | ) 74 | return _fft_conv1d(samples, buff_slice) 75 | 76 | return samples -------------------------------------------------------------------------------- /fast_audiomentations/transforms/high_pass_filter.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._filter_triton import create_filters as _create_high_pass_filters 2 | from fast_audiomentations.transforms._impl._filter_triton import fft_conv1d as _fft_conv1d 3 | 4 | import random 5 | import torch 6 | 7 | 8 | class HighPassFilter: 9 | """ 10 | Class for applying a high-pass filter to audio samples. 11 | 12 | Attributes: 13 | min_cutoff_freq (int): Minimum cutoff frequency for the high-pass filter. 14 | max_cutoff_freq (int): Maximum cutoff frequency for the high-pass filter. 15 | num_taps (int): Number of filter taps. 16 | buffer_size (int): Size of the buffer for processing. 17 | p (float): Probability of applying the augmentation. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | min_cutoff_freq: int = 500, 23 | max_cutoff_freq: int = 2000, 24 | num_taps: int = 101, 25 | buffer_size: int = 129, 26 | p: float = 0.5 27 | ): 28 | self.__min_cutoff_freq = min_cutoff_freq 29 | self.__max_cutoff_freq = max_cutoff_freq 30 | self.p = p 31 | self.num_taps = num_taps 32 | self.window = torch.hamming_window(num_taps, device='cuda', dtype=torch.float32, periodic=False) 33 | self.random_buffer = torch.empty(buffer_size, device='cuda') 34 | half = (num_taps - 1) // 2 35 | self.time = torch.arange(-half, half + 1, dtype=torch.float32, device='cuda') 36 | self.filter_output = torch.empty((buffer_size, self.num_taps), device='cuda', dtype=torch.float32) 37 | 38 | def __generate_random_cutoffs(self, num_audios): 39 | """ 40 | Generate random cutoff frequencies for the high-pass filter. 41 | 42 | @param num_audios: Number of audio samples to process. 43 | @return: A tensor of random cutoff frequencies. 44 | """ 45 | assert num_audios <= self.random_buffer.size(0) 46 | 47 | buff_slice = self.random_buffer[:num_audios] 48 | buff_slice.uniform_(self.__min_cutoff_freq, self.__max_cutoff_freq) 49 | 50 | return buff_slice 51 | 52 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 53 | """ 54 | Apply the high-pass filter to the audio samples. 55 | 56 | @param samples: Input audio samples tensor. 57 | @param sample_rate: Sample rate of the audio. 58 | @param inplace: If True, perform the operation in-place. 59 | @return: Audio samples after applying the high-pass filter. 
60 | """ 61 | if random.random() < self.p: 62 | freqs = self.__generate_random_cutoffs(samples.shape[0]) 63 | 64 | buff_slice = self.filter_output[:len(freqs)] 65 | _create_high_pass_filters( 66 | buff_slice, 67 | freqs, 68 | self.time, 69 | self.window, 70 | sample_rate, 71 | self.num_taps, 72 | "high" 73 | ) 74 | return _fft_conv1d(samples, buff_slice) 75 | 76 | return samples 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fast-audiomentations 2 | `fast-audiomentations` is a Python library that leverages GPU acceleration for efficient audio augmentation, 3 | suitable for high-throughput audio analysis and machine learning applications. 4 | 5 | ## Key Features 6 | 7 | - Performance Showcase: Highlights the advantages of using GPU for audio processing tasks. 
8 | - Diverse Audio Augmentations: Offers a variety of transformations including noise addition, filtering, and gain adjustments, optimized for GPU performance. 9 | - User-Friendly API: Designed for ease of use, enabling quick integration into audio processing pipelines. 10 | - Triton-based GPU code: Triton kernels make high GPU utilization achievable. Real-world Triton examples for straightforward kernels are scarce, so this library can also serve as a practical study resource. 11 | 12 | ## Installation 13 | There is currently no package on PyPI, so install from source: 14 | ```bash 15 | git clone https://github.com/Lallapallooza/fast-audiomentations 16 | cd fast-audiomentations 17 | python -m pip install -r requirements.txt 18 | python -m examples.test 19 | ``` 20 | If "SUCCESS" is printed, the library is ready to use. 21 | Examples can be found in the `examples/` and `benchmark/` directories. 22 | 23 | ## Usage 24 | Here's an example of how to add background noise to a batch of audio samples 25 | using `fast-audiomentations`: 26 | 27 | ```python 28 | from fast_audiomentations import AddBackgroundNoise 29 | import torch 30 | 31 | batch_size = 128 32 | 33 | dataloader = AddBackgroundNoise.get_dali_dataloader( 34 | 'some/path/to/noises/in/dali/format', 35 | buffer_size=batch_size, 36 | n_workers=4 37 | ) 38 | add_background_noise = AddBackgroundNoise( 39 | noises_dataloader=dataloader, 40 | min_snr=-10, 41 | max_snr=10, 42 | buffer_size=batch_size, 43 | p=1.0 44 | ) 45 | 46 | sample_rate = 48000 47 | audio = load_my_audio_batch() # load as torch.Tensor 48 | audio_lens = torch.tensor([audio.shape[0] for i in range(batch_size)], device='cuda') 49 | 50 | # torch.Tensor 51 | audios_with_noise = add_background_noise( 52 | samples=audio, 53 | samples_lens=audio_lens, 54 | sample_rate=sample_rate 55 | ) 56 | ``` 57 | 58 | ## Benchmarking 59 | To validate the hypothesis that GPU can accelerate audio augmentation processing, 60 | benchmarks were
conducted against two libraries: 61 | [audiomentations](https://github.com/iver56/audiomentations) 62 | and 63 | [torch-audiomentations](https://github.com/asteroid-team/torch-audiomentations). 64 | 65 | Details about the specific batch sizes, augmentation types, and parameters can be found in the code. 66 | See the `benchmark/` directory. 67 | 68 | Test configuration: 69 | - NVIDIA GeForce RTX 3090 Ti 70 | - 64GB RAM 71 | - 12th Gen Intel(R) Core(TM) i9-12900KF 72 | - Samsung 980 PRO 73 | 74 | The results are available in the `benchmark_local_result/` directory. 75 | To run the benchmarks: 76 | 1. Change the parameters in `benchmark/benchmark_data.py`. 77 | 2. Run `python -m benchmark.run_all`. 78 | 79 | The results show that for stateless operations (no IO required), 80 | the speedup is significant (tens of times). 81 | However, for augmentations like adding background noise, 82 | the performance benefit is smaller due to IO/PCIe overhead. 83 | Even using DALI does not change the situation much; the GPU remains heavily underutilized. 84 | At the same time, the NVMe device is not fully utilized either, 85 | suggesting there might be some code inefficiency or a limitation in Python's IO performance. 86 | 87 | ## Future work 88 | 89 | - Add more augmentations (RIR convolution, SpecAugment, etc.). Requests for additional augmentations are welcome, and we will try to implement them promptly. 90 | - Rewrite the IO process for stateful cases (RIR, add_background_noise) to fully utilize IO capabilities. 91 | - Conduct benchmarks with network devices for stateful scenarios to demonstrate that asynchronous data reading may reduce latency. 92 | - Support an XLA backend. 93 | - Add wrappers such as `Compose`, `OneOf`, etc.
94 | 95 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/band_stop_filter.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._filter_triton import create_filters as _create_pass_filters 2 | from fast_audiomentations.transforms._impl._filter_triton import fft_conv1d as _fft_conv1d 3 | 4 | import random 5 | import torch 6 | 7 | 8 | class BandStopFilter: 9 | """ 10 | Class for applying a band-stop filter to audio samples. 11 | 12 | Attributes: 13 | min_center_freq (int): Minimum center frequency for the band-stop filter. 14 | max_center_freq (int): Maximum center frequency for the band-stop filter. 15 | num_taps (int): Number of filter taps. 16 | buffer_size (int): Size of the buffer for processing. 17 | p (float): Probability of applying the augmentation. 18 | """ 19 | def __init__( 20 | self, 21 | min_center_freq: int = 500, 22 | max_center_freq: int = 2000, 23 | num_taps: int = 101, 24 | buffer_size: int = 129, 25 | p: float = 0.5 26 | ): 27 | self.__min_center_freq = min_center_freq 28 | self.__max_center_freq = max_center_freq 29 | self.p = p 30 | self.num_taps = num_taps 31 | 32 | # Generating a Hamming window for filter design 33 | self.window = torch.hamming_window(num_taps, device='cuda', dtype=torch.float32, periodic=False) 34 | 35 | # Buffers for random frequency generation 36 | self.random_buffer_low = torch.empty(buffer_size, device='cuda') 37 | self.random_buffer_high = torch.empty(buffer_size, device='cuda') 38 | 39 | # Time buffer for filter design 40 | half = (num_taps - 1) // 2 41 | self.time = torch.arange(-half, half + 1, dtype=torch.float32, device='cuda') 42 | 43 | # Buffers for filter outputs 44 | self.filter_output_low = torch.empty((buffer_size, self.num_taps), device='cuda', dtype=torch.float32) 45 | self.filter_output_high = torch.empty((buffer_size, self.num_taps), device='cuda', 
dtype=torch.float32) 46 | 47 | def __generate_random_cutoffs(self, num_audios): 48 | """ 49 | Generate random cutoff frequencies for the band-stop filter. 50 | 51 | @param num_audios: Number of audio samples to process. 52 | @return: Tuple of tensors containing low and high cutoff frequencies. 53 | """ 54 | assert num_audios <= self.random_buffer_low.size(0) 55 | 56 | # Randomly generate low and high cutoff frequencies 57 | buff_slice_low = self.random_buffer_low[:num_audios] 58 | buff_slice_low.uniform_(self.__min_center_freq, self.__max_center_freq) 59 | 60 | buff_slice_high = self.random_buffer_high[:num_audios] 61 | buff_slice_high.uniform_(self.__min_center_freq, self.__max_center_freq) 62 | 63 | # Ensuring the high cutoff is always higher than the low cutoff 64 | # TODO: rewrite on triton as well 65 | buff_slice_high = torch.clip(buff_slice_high + buff_slice_low, 0, self.__max_center_freq) 66 | 67 | return buff_slice_low, buff_slice_high 68 | 69 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 70 | """ 71 | Apply the band-stop filter to the given audio samples. 72 | 73 | @param samples: Input audio samples tensor. 74 | @param sample_rate: Sample rate of the audio. 75 | @param inplace: If True, perform operation in-place. 76 | @return: Audio samples after applying the band-stop filter. 
77 | """ 78 | if random.random() < self.p: 79 | # Generate random frequencies for the filter 80 | freqs_low, freqs_high = self.__generate_random_cutoffs(samples.shape[0]) 81 | 82 | # Create filters based on the generated frequencies 83 | buff_slice_low = self.filter_output_low[:len(freqs_low)] 84 | _create_pass_filters( 85 | buff_slice_low, 86 | freqs_low, 87 | self.time, 88 | self.window, 89 | sample_rate, 90 | self.num_taps, 91 | mode="low" 92 | ) 93 | 94 | buff_slice_high = self.filter_output_high[:len(freqs_low)] 95 | _create_pass_filters( 96 | buff_slice_high, 97 | freqs_high, 98 | self.time, 99 | self.window, 100 | sample_rate, 101 | self.num_taps, 102 | mode="high" 103 | ) 104 | 105 | # Apply the convolution with the created filters 106 | return _fft_conv1d(samples, buff_slice_low - buff_slice_high) 107 | 108 | return samples 109 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/band_pass_filter.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._filter_triton import create_filters as _create_pass_filters 2 | from fast_audiomentations.transforms._impl._filter_triton import fft_conv1d as _fft_conv1d 3 | 4 | import random 5 | import torch 6 | 7 | 8 | class BandPassFilter: 9 | """ 10 | Class for applying a band-pass filter to audio samples. 11 | 12 | Attributes: 13 | min_center_freq (int): Minimum center frequency for the band-pass filter. 14 | max_center_freq (int): Maximum center frequency for the band-pass filter. 15 | num_taps (int): Number of filter taps. 16 | buffer_size (int): Size of the buffer for processing. 17 | p (float): Probability of applying the augmentation. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | min_center_freq: int = 500, 23 | max_center_freq: int = 2000, 24 | num_taps: int = 101, 25 | buffer_size: int = 129, 26 | p: float = 0.5 27 | ): 28 | self.__min_center_freq = min_center_freq 29 | self.__max_center_freq = max_center_freq 30 | self.p = p 31 | self.num_taps = num_taps 32 | 33 | # Generating a Hamming window for filter design 34 | self.window = torch.hamming_window(num_taps, device='cuda', dtype=torch.float32, periodic=False) 35 | 36 | # Buffers for random frequency generation 37 | self.random_buffer_low = torch.empty(buffer_size, device='cuda') 38 | self.random_buffer_high = torch.empty(buffer_size, device='cuda') 39 | 40 | # Time buffer for filter design 41 | half = (num_taps - 1) // 2 42 | self.time = torch.arange(-half, half + 1, dtype=torch.float32, device='cuda') 43 | 44 | # Buffers for filter outputs 45 | self.filter_output_low = torch.empty((buffer_size, self.num_taps), device='cuda', dtype=torch.float32) 46 | self.filter_output_high = torch.empty((buffer_size, self.num_taps), device='cuda', dtype=torch.float32) 47 | 48 | def __generate_random_cutoffs(self, num_audios): 49 | """ 50 | Generate random cutoff frequencies for the band-pass filter. 51 | 52 | @param num_audios: Number of audio samples to process. 53 | @return: Tuple of tensors containing low and high cutoff frequencies. 
54 | """ 55 | assert num_audios <= self.random_buffer_low.size(0) 56 | 57 | # Randomly generate low and high cutoff frequencies 58 | buff_slice_low = self.random_buffer_low[:num_audios] 59 | buff_slice_low.uniform_(self.__min_center_freq, self.__max_center_freq) 60 | 61 | buff_slice_high = self.random_buffer_high[:num_audios] 62 | buff_slice_high.uniform_(self.__min_center_freq, self.__max_center_freq) 63 | 64 | # TODO: rewrite on triton as well 65 | # Ensuring the high cutoff is always higher than the low cutoff 66 | buff_slice_high = torch.clip(buff_slice_high + buff_slice_low, 0, self.__max_center_freq) 67 | 68 | return buff_slice_low, buff_slice_high 69 | 70 | def __call__(self, samples: torch.Tensor, sample_rate: int, inplace=False): 71 | """ 72 | Apply the band-pass filter to the given audio samples. 73 | 74 | @param samples: Input audio samples tensor. 75 | @param sample_rate: Sample rate of the audio. 76 | @param inplace: If True, perform operation in-place. 77 | @return: Audio samples after applying the band-pass filter. 
78 | """ 79 | if random.random() < self.p: 80 | # Generate random frequencies for the filter 81 | freqs_low, freqs_high = self.__generate_random_cutoffs(samples.shape[0]) 82 | 83 | # Create filters based on the generated frequencies 84 | buff_slice_low = self.filter_output_low[:len(freqs_low)] 85 | _create_pass_filters( 86 | buff_slice_low, 87 | freqs_low, 88 | self.time, 89 | self.window, 90 | sample_rate, 91 | self.num_taps, 92 | mode="low" 93 | ) 94 | 95 | buff_slice_high = self.filter_output_high[:len(freqs_low)] 96 | _create_pass_filters( 97 | buff_slice_high, 98 | freqs_high, 99 | self.time, 100 | self.window, 101 | sample_rate, 102 | self.num_taps, 103 | mode="low" 104 | ) 105 | 106 | # Apply the convolution with the created filters 107 | return _fft_conv1d(samples, buff_slice_low - buff_slice_high) 108 | 109 | return samples 110 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/_impl/_mix_triton.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | import torch 4 | import itertools 5 | 6 | 7 | @triton.jit 8 | def rms_kernel(audios, audios_real_lens, audios_max_len, batch_idx, BLOCK_SIZE_RMS: tl.constexpr): 9 | audios_real_lens_vals = tl.load(audios_real_lens + batch_idx) 10 | 11 | _mean = tl.zeros([BLOCK_SIZE_RMS], dtype=tl.float32) 12 | for offset in range(0, audios_max_len, BLOCK_SIZE_RMS): 13 | audios_block_ptr = offset + tl.arange(0, BLOCK_SIZE_RMS) 14 | audios_mask = audios_block_ptr < audios_real_lens_vals 15 | 16 | audios_vals = tl.load(audios + batch_idx * audios_max_len + audios_block_ptr, mask=audios_mask) 17 | audios_partial_sum_sq = tl.where(audios_mask, tl.math.pow(audios_vals, 2.0), 0) 18 | _mean += audios_partial_sum_sq 19 | 20 | audios_global_sum_sq = tl.sum(_mean, axis=0) 21 | return tl.sqrt(audios_global_sum_sq / audios_real_lens_vals) 22 | 23 | @triton.autotune( 24 | configs=[ 25 | 
triton.Config({'BLOCK_SIZE_SUM': block_size_sum}, num_warps=num_warps) 26 | for (block_size_sum, num_warps) in 27 | itertools.product( 28 | [512, 1024], 29 | [2, 4, 8, 16] 30 | ) 31 | ], 32 | key=['clean_audio_max_len', 'noisy_audio_max_len'] 33 | ) 34 | @triton.jit 35 | def sum_with_snr_kernel( 36 | clean_audio, clean_audio_real_lens, clean_audio_max_len, desired_rms, 37 | noisy_audio_ptr, noisy_audio_real_lens, noisy_audio_max_len, 38 | output_ptr, BLOCK_SIZE_SUM: tl.constexpr, BLOCK_SIZE_RMS: tl.constexpr): 39 | batch_idx = tl.program_id(0) 40 | 41 | # RMS clean 42 | clean_audio_real_lens_val = tl.load(clean_audio_real_lens + batch_idx) 43 | clean_audio_rms = rms_kernel(clean_audio, clean_audio_real_lens, clean_audio_max_len, batch_idx, BLOCK_SIZE_RMS) 44 | 45 | # RMS noisy 46 | noisy_audio_real_lens_val = tl.load(noisy_audio_real_lens + batch_idx) 47 | 48 | noisy_audio_rms = rms_kernel(noisy_audio_ptr, noisy_audio_real_lens, noisy_audio_max_len, batch_idx, BLOCK_SIZE_RMS) 49 | 50 | # Desired RMS for noisy scale 51 | desired_rms_val = tl.load(desired_rms + batch_idx) 52 | relative_rms = clean_audio_rms / tl.math.pow(10.0, desired_rms_val / 20.0) 53 | 54 | for offset in range(0, clean_audio_max_len, BLOCK_SIZE_SUM): 55 | clean_audio_block_ptr = offset + tl.arange(0, BLOCK_SIZE_SUM) 56 | clean_audio_mask = clean_audio_block_ptr < clean_audio_real_lens_val 57 | clean_audio_vals = tl.load( 58 | clean_audio + batch_idx * clean_audio_max_len + clean_audio_block_ptr, 59 | mask=clean_audio_mask 60 | ) 61 | 62 | """ 63 | Adjusts the block's start position if it extends beyond the noisy audio array, shifting it leftward as needed. 
64 | This adjustment keeps the data block within the noisy audio array limits, accounting for its circular nature 65 | 66 | Scenario without adjustment: 67 | noisy_audio_array: |----|----|----|----|----|----|----|----| 68 | block: |~~~~~~~~~~~~~~~~| 69 | (Block fits within the array, no adjustment needed) 70 | 71 | Scenario with adjustment: 72 | noisy_audio_array: |----|----|----|----|----|----|----|----| 73 | block: |~~~~~~~~~~~~~~~~| 74 | (Block exceeds array bounds, needs to be shifted left) 75 | noisy_audio_array: |----|----|----|----|----|----|----|----| 76 | block: |~~~~~~~~~~~~~~~~| <--- Shifted left 77 | """ 78 | offset_over_max = offset % noisy_audio_real_lens_val 79 | 80 | offset_adjusted = offset_over_max - tl.math.min( 81 | offset_over_max, 82 | tl.math.max(0, (offset_over_max + BLOCK_SIZE_SUM) - noisy_audio_real_lens_val) 83 | ) 84 | 85 | noisy_audio_block_ptr = offset_adjusted + tl.arange(0, BLOCK_SIZE_SUM) 86 | 87 | noisy_audio_val = tl.load( 88 | noisy_audio_ptr + batch_idx * noisy_audio_max_len + noisy_audio_block_ptr, 89 | mask=noisy_audio_block_ptr < noisy_audio_real_lens_val 90 | ) 91 | 92 | tl.store( 93 | output_ptr + batch_idx * clean_audio_max_len + clean_audio_block_ptr, 94 | clean_audio_vals + noisy_audio_val * (relative_rms / noisy_audio_rms), 95 | mask=clean_audio_mask 96 | ) 97 | 98 | def sum_with_snr_triton(samples: torch.Tensor, samples_lens: torch.Tensor, samples_noise, samples_noise_lens: torch.Tensor, snrs): 99 | assert samples.is_contiguous() and samples_noise.is_contiguous(), "Samples must be contiguous" 100 | 101 | B, T = samples.shape 102 | output = torch.empty_like(samples, device=samples.device, dtype=samples.dtype) 103 | 104 | grid = lambda opt: (B,) 105 | 106 | sum_with_snr_kernel[grid]( 107 | samples, samples_lens, T, snrs, 108 | samples_noise, samples_noise_lens, samples_noise.shape[1], 109 | output, BLOCK_SIZE_RMS=max(1024, triton.next_power_of_2(max(T, samples_noise.shape[1]) // 1024))) 110 | 111 | return output 
-------------------------------------------------------------------------------- /benchmark/clip_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | 11 | @benchmark( 12 | bench_type='clip', 13 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 14 | n_iters=1000 15 | ) 16 | @benchmark( 17 | bench_type='clip', 18 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 19 | n_iters=1000 20 | ) 21 | @benchmark( 22 | bench_type='clip', 23 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 24 | n_iters=1000 25 | ) 26 | @benchmark( 27 | bench_type='clip', 28 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 29 | n_iters=1000 30 | ) 31 | @benchmark( 32 | bench_type='clip', 33 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 34 | n_iters=1000 35 | ) 36 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 37 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 38 | super().__init__( 39 | name=f"Audiomentations Clip (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 40 | warmup_iterations=10, 41 | iterations=n_iters, 42 | samples_per_iter=data_provider.batch_size 43 | ) 44 | 45 | self.data_provider = data_provider 46 | 47 | def on_start(self): 48 | from audiomentations import Clip 49 | self.clip = Clip(a_min=-0.6, a_max=0.6, p=1.0) 50 | self.audio = self.data_provider.get()[0] 51 | 52 | def run_iteration(self): 53 | for i in range(self.data_provider.batch_size): 54 | self.clip(samples=self.audio, sample_rate=44100) 55 | 56 | def run_warmup_iteration(self): 57 | for i 
in range(self.data_provider.batch_size): 58 | self.clip(samples=self.audio, sample_rate=44100) 59 | 60 | def on_end(self): 61 | pass 62 | 63 | 64 | @benchmark( 65 | bench_type='clip', 66 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 67 | n_iters=1000 68 | ) 69 | @benchmark( 70 | bench_type='clip', 71 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 72 | n_iters=1000 73 | ) 74 | @benchmark( 75 | bench_type='clip', 76 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 77 | n_iters=1000 78 | ) 79 | @benchmark( 80 | bench_type='clip', 81 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 82 | n_iters=1000 83 | ) 84 | @benchmark( 85 | bench_type='clip', 86 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 87 | n_iters=1000 88 | ) 89 | @benchmark( 90 | bench_type='clip', 91 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 92 | n_iters=1000 93 | ) 94 | @benchmark( 95 | bench_type='clip', 96 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 97 | n_iters=1000 98 | ) 99 | @benchmark( 100 | bench_type='clip', 101 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 102 | n_iters=1000 103 | ) 104 | @benchmark( 105 | bench_type='clip', 106 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 107 | n_iters=1000 108 | ) 109 | @benchmark( 110 | bench_type='clip', 111 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 112 | n_iters=1000 113 | ) 114 | class 
FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 115 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 116 | super().__init__( 117 | name=f"Fast Audiomentations Clip ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 118 | warmup_iterations=10, 119 | iterations=n_iters, 120 | samples_per_iter=data_provider.batch_size, 121 | is_gpu_timer_required=True 122 | ) 123 | 124 | self.data_provider = data_provider 125 | 126 | def on_start(self): 127 | from fast_audiomentations import Clip 128 | self.clip = Clip(min=-0.6, max=0.6, p=1.0) 129 | self.audio = self.data_provider.get() 130 | 131 | def run_iteration(self): 132 | self.clip(samples=self.audio, sample_rate=44100) 133 | 134 | def run_warmup_iteration(self): 135 | self.clip(samples=self.audio, sample_rate=44100) 136 | 137 | def on_end(self): 138 | pass 139 | 140 | 141 | if __name__ == '__main__': 142 | benchmark = Benchmark(name="Clip Benchmark", bench_type="clip") 143 | benchmark.run() 144 | benchmark.print_results() 145 | -------------------------------------------------------------------------------- /benchmark/gain_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | 11 | @benchmark( 12 | bench_type='gain', 13 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 14 | n_iters=1000 15 | ) 16 | @benchmark( 17 | bench_type='gain', 18 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 19 | n_iters=1000 20 | ) 21 | @benchmark( 22 | bench_type='gain', 23 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 24 | n_iters=1000 25 | ) 26 | @benchmark( 27 | bench_type='gain', 28 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 29 | n_iters=1000 30 | ) 31 | @benchmark( 32 | bench_type='gain', 33 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 34 | n_iters=1000 35 | ) 36 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 37 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 38 | super().__init__( 39 | name=f"Audiomentations Gain (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 40 | warmup_iterations=10, 41 | iterations=n_iters, 42 | samples_per_iter=data_provider.batch_size 43 | ) 44 | 45 | self.data_provider = data_provider 46 | 47 | def on_start(self): 48 | from audiomentations import Gain 49 | self.gain = Gain(min_gain_in_db=-10, max_gain_in_db=10, p=1.0) 50 | self.audio = self.data_provider.get()[0] 51 | 52 | def run_iteration(self): 53 | for i in range(self.data_provider.batch_size): 54 | self.gain(samples=self.audio, sample_rate=44100) 55 | 56 | def run_warmup_iteration(self): 57 | for i in range(self.data_provider.batch_size): 58 | self.gain(samples=self.audio, sample_rate=44100) 59 | 60 | def on_end(self): 61 | pass 62 | 63 | 64 | @benchmark( 65 | bench_type='gain', 66 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 67 | n_iters=1000 68 | ) 69 | @benchmark( 70 | bench_type='gain', 71 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 72 | n_iters=1000 73 | ) 74 | @benchmark( 75 | bench_type='gain', 76 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 77 | n_iters=1000 78 | ) 79 | @benchmark( 80 | bench_type='gain', 81 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 82 | n_iters=1000 83 | ) 84 | @benchmark( 85 | bench_type='gain', 86 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 87 | n_iters=1000 88 | ) 89 | @benchmark( 90 | bench_type='gain', 91 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 92 | n_iters=1000 93 | ) 94 | @benchmark( 95 | bench_type='gain', 96 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 97 | n_iters=1000 98 | ) 99 | @benchmark( 100 | bench_type='gain', 101 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 102 | n_iters=1000 103 | ) 104 | @benchmark( 105 | bench_type='gain', 106 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 107 | n_iters=1000 108 | ) 109 | @benchmark( 110 | bench_type='gain', 111 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 112 | n_iters=1000 113 | ) 114 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 115 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 116 | super().__init__( 117 | name=f"Fast Audiomentations Gain ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 118 | warmup_iterations=10, 119 | iterations=n_iters, 120 | samples_per_iter=data_provider.batch_size, 121 | is_gpu_timer_required=True 122 | ) 123 | 124 | self.data_provider = data_provider 125 | 126 | def on_start(self): 127 | from fast_audiomentations.transforms.gain import Gain 128 | self.gain = Gain(min_gain_in_db=-10, max_gain_in_db=10, p=1.0) 129 | self.audio = self.data_provider.get() 130 | 131 | def run_iteration(self): 132 | self.gain(samples=self.audio, sample_rate=44100) 133 | 134 | def run_warmup_iteration(self): 135 | self.gain(samples=self.audio, sample_rate=44100) 136 | 137 | def on_end(self): 138 | pass 139 | 140 | 141 | @benchmark( 142 | 'gain', 143 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 144 | n_iters=1000 145 | ) 146 | @benchmark( 147 | 'gain', 148 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 149 | n_iters=1000 150 | ) 151 | @benchmark( 152 | 'gain', 153 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 154 | n_iters=1000 155 | ) 156 | @benchmark( 157 | 'gain', 158 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 159 | n_iters=1000 160 | ) 161 | @benchmark( 162 | 'gain', 163 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 164 | n_iters=1000 165 | ) 166 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 167 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 168 | super().__init__( 169 | name=f"Torch Audiomentations Gain ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 170 | warmup_iterations=10, 171 | iterations=n_iters, 172 | samples_per_iter=data_provider.batch_size, 173 | is_gpu_timer_required=True 174 | ) 175 | self.data_provider = data_provider 176 | 177 | def on_start(self): 178 | from torch_audiomentations import Gain 179 | self.gain = Gain(min_gain_in_db=-10, max_gain_in_db=10, p=1.0) 180 | self.audio = self.data_provider.get().unsqueeze(1) 181 | 182 | def run_iteration(self): 183 | self.gain(samples=self.audio, sample_rate=44100) 184 | 185 | def run_warmup_iteration(self): 186 | self.gain(samples=self.audio, sample_rate=44100) 187 | 188 | def on_end(self): 189 | pass 190 | 191 | 192 | if __name__ == '__main__': 193 | benchmark = Benchmark(name="Gain Benchmark", bench_type='gain') 194 | benchmark.run() 195 | benchmark.print_results() 196 | -------------------------------------------------------------------------------- 
/benchmark/low_pass_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | 11 | @benchmark( 12 | bench_type='low_pass_filter', 13 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 14 | n_iters=1000 15 | ) 16 | @benchmark( 17 | bench_type='low_pass_filter', 18 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 19 | n_iters=1000 20 | ) 21 | @benchmark( 22 | bench_type='low_pass_filter', 23 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 24 | n_iters=1000 25 | ) 26 | @benchmark( 27 | bench_type='low_pass_filter', 28 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 29 | n_iters=1000 30 | ) 31 | @benchmark( 32 | bench_type='low_pass_filter', 33 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 34 | n_iters=1000 35 | ) 36 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 37 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 38 | super().__init__( 39 | name=f"Audiomentations Low Pass Filter (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 40 | warmup_iterations=10, 41 | iterations=n_iters, 42 | samples_per_iter=data_provider.batch_size 43 | ) 44 | 45 | self.data_provider = data_provider 46 | 47 | def on_start(self): 48 | from audiomentations import LowPassFilter 49 | self.low_pass_filter = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0) 50 | self.audio = self.data_provider.get()[0] 51 | 52 | def run_iteration(self): 53 | for i in range(self.data_provider.batch_size): 54 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 55 
| 56 | def run_warmup_iteration(self): 57 | for i in range(self.data_provider.batch_size): 58 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 59 | 60 | def on_end(self): 61 | pass 62 | 63 | 64 | @benchmark( 65 | bench_type='low_pass_filter', 66 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 67 | n_iters=1000 68 | ) 69 | @benchmark( 70 | bench_type='low_pass_filter', 71 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 72 | n_iters=1000 73 | ) 74 | @benchmark( 75 | bench_type='low_pass_filter', 76 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 77 | n_iters=1000 78 | ) 79 | @benchmark( 80 | bench_type='low_pass_filter', 81 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 82 | n_iters=1000 83 | ) 84 | @benchmark( 85 | bench_type='low_pass_filter', 86 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 87 | n_iters=1000 88 | ) 89 | @benchmark( 90 | bench_type='low_pass_filter', 91 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 92 | n_iters=1000 93 | ) 94 | @benchmark( 95 | bench_type='low_pass_filter', 96 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 97 | n_iters=1000 98 | ) 99 | @benchmark( 100 | bench_type='low_pass_filter', 101 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 102 | n_iters=1000 103 | ) 104 | @benchmark( 105 | bench_type='low_pass_filter', 106 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 107 | n_iters=1000 108 | ) 109 | @benchmark( 110 | bench_type='low_pass_filter', 111 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 112 | n_iters=1000 113 | ) 114 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 115 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 116 | super().__init__( 117 | name=f"Fast Audiomentations Low Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 118 | warmup_iterations=10, 119 | iterations=n_iters, 120 | samples_per_iter=data_provider.batch_size, 121 | is_gpu_timer_required=True 122 | ) 123 | 124 | self.data_provider = data_provider 125 | 126 | def on_start(self): 127 | from fast_audiomentations import LowPassFilter 128 | self.low_pass_filter = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0, num_taps=101) 129 | self.audio = self.data_provider.get() 130 | 131 | def run_iteration(self): 132 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 133 | 134 | def run_warmup_iteration(self): 135 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 136 | 137 | def on_end(self): 138 | pass 139 | 140 | 141 | @benchmark( 142 | bench_type='low_pass_filter', 143 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 144 | n_iters=1000 145 | ) 146 | @benchmark( 147 | bench_type='low_pass_filter', 148 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 149 | n_iters=1000 150 | ) 151 | @benchmark( 152 | bench_type='low_pass_filter', 153 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 154 | n_iters=1000 155 | ) 156 | @benchmark( 157 | bench_type='low_pass_filter', 158 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 159 | n_iters=1000 160 | ) 161 | @benchmark( 162 | bench_type='low_pass_filter', 163 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 164 | n_iters=1000 165 | ) 166 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 167 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 168 | super().__init__( 169 | name=f"Torch Audiomentations Low Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 170 | warmup_iterations=10, 171 | iterations=n_iters, 172 | samples_per_iter=data_provider.batch_size, 173 | is_gpu_timer_required=True 174 | ) 175 | self.data_provider = data_provider 176 | 177 | def on_start(self): 178 | from torch_audiomentations import LowPassFilter 179 | self.low_pass_filter = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0) 180 | self.audio = self.data_provider.get().unsqueeze(1) 181 | 182 | def run_iteration(self): 183 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 184 | 185 | def run_warmup_iteration(self): 186 | self.low_pass_filter(samples=self.audio, sample_rate=44100) 187 | 188 | def on_end(self): 189 | pass 190 | 191 | if __name__ == '__main__': 192 | benchmark = Benchmark(name="Low Pass Filter Benchmark", bench_type="low_pass_filter") 193 | benchmark.run() 194 | benchmark.print_results() 195 | -------------------------------------------------------------------------------- /benchmark/band_pass_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | @benchmark( 11 | bench_type='band_pass_filter', 12 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 13 | n_iters=1000 14 | ) 15 | @benchmark( 16 | bench_type='band_pass_filter', 17 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 18 | 
n_iters=1000 19 | ) 20 | @benchmark( 21 | bench_type='band_pass_filter', 22 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 23 | n_iters=1000 24 | ) 25 | @benchmark( 26 | bench_type='band_pass_filter', 27 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 28 | n_iters=1000 29 | ) 30 | @benchmark( 31 | bench_type='band_pass_filter', 32 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 33 | n_iters=1000 34 | ) 35 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 36 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 37 | super().__init__( 38 | name=f"Audiomentations Band Pass Filter (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 39 | warmup_iterations=10, 40 | iterations=n_iters, 41 | samples_per_iter=data_provider.batch_size 42 | ) 43 | 44 | self.data_provider = data_provider 45 | 46 | def on_start(self): 47 | from audiomentations import BandPassFilter 48 | self.band_pass_filter = BandPassFilter(min_center_freq=100, max_center_freq=10000, p=1.0) 49 | self.audio = self.data_provider.get()[0] 50 | 51 | def run_iteration(self): 52 | for i in range(self.data_provider.batch_size): 53 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 54 | 55 | def run_warmup_iteration(self): 56 | for i in range(self.data_provider.batch_size): 57 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 58 | 59 | def on_end(self): 60 | pass 61 | 62 | 63 | @benchmark( 64 | bench_type='band_pass_filter', 65 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 66 | n_iters=1000 67 | ) 68 | @benchmark( 69 | bench_type='band_pass_filter', 70 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 71 | n_iters=1000 72 | ) 73 | @benchmark( 74 | 
bench_type='band_pass_filter', 75 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 76 | n_iters=1000 77 | ) 78 | @benchmark( 79 | bench_type='band_pass_filter', 80 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 81 | n_iters=1000 82 | ) 83 | @benchmark( 84 | bench_type='band_pass_filter', 85 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 86 | n_iters=1000 87 | ) 88 | @benchmark( 89 | bench_type='band_pass_filter', 90 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 91 | n_iters=1000 92 | ) 93 | @benchmark( 94 | bench_type='band_pass_filter', 95 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 96 | n_iters=1000 97 | ) 98 | @benchmark( 99 | bench_type='band_pass_filter', 100 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 101 | n_iters=1000 102 | ) 103 | @benchmark( 104 | bench_type='band_pass_filter', 105 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 106 | n_iters=1000 107 | ) 108 | @benchmark( 109 | bench_type='band_pass_filter', 110 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 111 | n_iters=1000 112 | ) 113 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 114 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 115 | super().__init__( 116 | name=f"Fast Audiomentations Band Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 117 | warmup_iterations=10, 118 | iterations=n_iters, 119 | samples_per_iter=data_provider.batch_size, 120 | is_gpu_timer_required=True 121 | ) 122 | 123 | self.data_provider = data_provider 124 
| 125 | def on_start(self): 126 | from fast_audiomentations import BandPassFilter 127 | self.band_pass_filter = BandPassFilter(min_center_freq=100, max_center_freq=10000, p=1.0) 128 | self.audio = self.data_provider.get() 129 | 130 | def run_iteration(self): 131 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 132 | 133 | def run_warmup_iteration(self): 134 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 135 | 136 | def on_end(self): 137 | pass 138 | 139 | 140 | @benchmark( 141 | bench_type='band_pass_filter', 142 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 143 | n_iters=1000 144 | ) 145 | @benchmark( 146 | bench_type='band_pass_filter', 147 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 148 | n_iters=1000 149 | ) 150 | @benchmark( 151 | bench_type='band_pass_filter', 152 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 153 | n_iters=1000 154 | ) 155 | @benchmark( 156 | bench_type='band_pass_filter', 157 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 158 | n_iters=1000 159 | ) 160 | @benchmark( 161 | bench_type='band_pass_filter', 162 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 163 | n_iters=1000 164 | ) 165 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 166 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 167 | super().__init__( 168 | name=f"Torch Audiomentations Band Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 169 | warmup_iterations=10, 170 | iterations=n_iters, 171 | samples_per_iter=data_provider.batch_size, 172 | is_gpu_timer_required=True 173 | ) 174 | self.data_provider = data_provider 175 | 176 | def on_start(self): 177 | from 
torch_audiomentations import BandPassFilter 178 | self.band_pass_filter = BandPassFilter(min_center_frequency=100, max_center_frequency=10000, p=1.0) 179 | self.audio = self.data_provider.get().unsqueeze(1) 180 | 181 | def run_iteration(self): 182 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 183 | 184 | def run_warmup_iteration(self): 185 | self.band_pass_filter(samples=self.audio, sample_rate=44100) 186 | 187 | def on_end(self): 188 | pass 189 | 190 | if __name__ == '__main__': 191 | benchmark = Benchmark(name="Band Pass Filter Benchmark", bench_type='band_pass_filter') 192 | benchmark.run() 193 | benchmark.print_results() 194 | -------------------------------------------------------------------------------- /benchmark/band_stop_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | @benchmark( 11 | bench_type='band_stop_filter', 12 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 13 | n_iters=1000 14 | ) 15 | @benchmark( 16 | bench_type='band_stop_filter', 17 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 18 | n_iters=1000 19 | ) 20 | @benchmark( 21 | bench_type='band_stop_filter', 22 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 23 | n_iters=1000 24 | ) 25 | @benchmark( 26 | bench_type='band_stop_filter', 27 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 28 | n_iters=1000 29 | ) 30 | @benchmark( 31 | bench_type='band_stop_filter', 32 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 33 | n_iters=1000 34 | ) 35 | class 
AudiomentationsBenchmarkSuite(BenchmarkSuite): 36 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 37 | super().__init__( 38 | name=f"Audiomentations Band Stop Filter (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 39 | warmup_iterations=10, 40 | iterations=n_iters, 41 | samples_per_iter=data_provider.batch_size 42 | ) 43 | 44 | self.data_provider = data_provider 45 | 46 | def on_start(self): 47 | from audiomentations import BandStopFilter 48 | self.band_stop_filter = BandStopFilter(min_center_freq=100, max_center_freq=10000, p=1.0) 49 | self.audio = self.data_provider.get()[0] 50 | 51 | def run_iteration(self): 52 | for i in range(self.data_provider.batch_size): 53 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 54 | 55 | def run_warmup_iteration(self): 56 | for i in range(self.data_provider.batch_size): 57 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 58 | 59 | def on_end(self): 60 | pass 61 | 62 | 63 | @benchmark( 64 | bench_type='band_stop_filter', 65 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 66 | n_iters=1000 67 | ) 68 | @benchmark( 69 | bench_type='band_stop_filter', 70 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 71 | n_iters=1000 72 | ) 73 | @benchmark( 74 | bench_type='band_stop_filter', 75 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 76 | n_iters=1000 77 | ) 78 | @benchmark( 79 | bench_type='band_stop_filter', 80 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 81 | n_iters=1000 82 | ) 83 | @benchmark( 84 | bench_type='band_stop_filter', 85 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 86 | n_iters=1000 87 | ) 88 | @benchmark( 89 | bench_type='band_stop_filter', 
90 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 91 | n_iters=1000 92 | ) 93 | @benchmark( 94 | bench_type='band_stop_filter', 95 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 96 | n_iters=1000 97 | ) 98 | @benchmark( 99 | bench_type='band_stop_filter', 100 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 101 | n_iters=1000 102 | ) 103 | @benchmark( 104 | bench_type='band_stop_filter', 105 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 106 | n_iters=1000 107 | ) 108 | @benchmark( 109 | bench_type='band_stop_filter', 110 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 111 | n_iters=1000 112 | ) 113 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 114 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 115 | super().__init__( 116 | name=f"Fast Audiomentations Band Stop Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 117 | warmup_iterations=10, 118 | iterations=n_iters, 119 | samples_per_iter=data_provider.batch_size, 120 | is_gpu_timer_required=True 121 | ) 122 | 123 | self.data_provider = data_provider 124 | 125 | def on_start(self): 126 | from fast_audiomentations import BandStopFilter 127 | self.band_stop_filter = BandStopFilter(min_center_freq=100, max_center_freq=10000, p=1.0) 128 | self.audio = self.data_provider.get() 129 | 130 | def run_iteration(self): 131 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 132 | 133 | def run_warmup_iteration(self): 134 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 135 | 136 | def on_end(self): 137 | pass 138 | 139 | 140 | @benchmark( 141 | bench_type='band_stop_filter', 142 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 143 | n_iters=1000 144 | ) 145 | @benchmark( 146 | bench_type='band_stop_filter', 147 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 148 | n_iters=1000 149 | ) 150 | @benchmark( 151 | bench_type='band_stop_filter', 152 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 153 | n_iters=1000 154 | ) 155 | @benchmark( 156 | bench_type='band_stop_filter', 157 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 158 | n_iters=1000 159 | ) 160 | @benchmark( 161 | bench_type='band_stop_filter', 162 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 163 | n_iters=1000 164 | ) 165 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 166 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 167 | super().__init__( 168 | name=f"Torch Audiomentations Band Stop Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 169 | warmup_iterations=10, 170 | iterations=n_iters, 171 | samples_per_iter=data_provider.batch_size, 172 | is_gpu_timer_required=True 173 | ) 174 | self.data_provider = data_provider 175 | 176 | def on_start(self): 177 | from torch_audiomentations import BandStopFilter 178 | self.band_stop_filter = BandStopFilter(min_center_frequency=100, max_center_frequency=10000, p=1.0) 179 | self.audio = self.data_provider.get().unsqueeze(1) 180 | 181 | def run_iteration(self): 182 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 183 | 184 | def run_warmup_iteration(self): 185 | self.band_stop_filter(samples=self.audio, sample_rate=44100) 186 | 187 | def on_end(self): 188 | pass 189 | 190 | if __name__ == '__main__': 191 | benchmark = Benchmark(name="Band Stop Filter Benchmark", 
bench_type='band_stop_filter') 192 | benchmark.run() 193 | benchmark.print_results() 194 | -------------------------------------------------------------------------------- /benchmark/high_pass_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | 10 | 11 | @benchmark( 12 | bench_type='high_pass_filter', 13 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 14 | n_iters=1000 15 | ) 16 | @benchmark( 17 | bench_type='high_pass_filter', 18 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 19 | n_iters=1000 20 | ) 21 | @benchmark( 22 | bench_type='high_pass_filter', 23 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 24 | n_iters=1000 25 | ) 26 | @benchmark( 27 | bench_type='high_pass_filter', 28 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 29 | n_iters=1000 30 | ) 31 | @benchmark( 32 | bench_type='high_pass_filter', 33 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 34 | n_iters=1000 35 | ) 36 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 37 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 38 | super().__init__( 39 | name=f"Audiomentations High Pass Filter (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 40 | warmup_iterations=10, 41 | iterations=n_iters, 42 | samples_per_iter=data_provider.batch_size 43 | ) 44 | 45 | self.data_provider = data_provider 46 | 47 | def on_start(self): 48 | from audiomentations import HighPassFilter 49 | self.high_pass_filter = HighPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0) 50 | self.audio = 
self.data_provider.get()[0] 51 | 52 | def run_iteration(self): 53 | for i in range(self.data_provider.batch_size): 54 | self.high_pass_filter(samples=self.audio, sample_rate=44100) 55 | 56 | def run_warmup_iteration(self): 57 | for i in range(self.data_provider.batch_size): 58 | self.high_pass_filter(samples=self.audio, sample_rate=44100) 59 | 60 | def on_end(self): 61 | pass 62 | 63 | 64 | @benchmark( 65 | bench_type='high_pass_filter', 66 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 67 | n_iters=1000 68 | ) 69 | @benchmark( 70 | bench_type='high_pass_filter', 71 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 72 | n_iters=1000 73 | ) 74 | @benchmark( 75 | bench_type='high_pass_filter', 76 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 77 | n_iters=1000 78 | ) 79 | @benchmark( 80 | bench_type='high_pass_filter', 81 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 82 | n_iters=1000 83 | ) 84 | @benchmark( 85 | bench_type='high_pass_filter', 86 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 87 | n_iters=1000 88 | ) 89 | @benchmark( 90 | bench_type='high_pass_filter', 91 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 92 | n_iters=1000 93 | ) 94 | @benchmark( 95 | bench_type='high_pass_filter', 96 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 97 | n_iters=1000 98 | ) 99 | @benchmark( 100 | bench_type='high_pass_filter', 101 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 102 | n_iters=1000 103 | ) 104 | @benchmark( 105 | bench_type='high_pass_filter', 106 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 107 | n_iters=1000 108 | ) 109 | @benchmark( 110 | bench_type='high_pass_filter', 111 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 112 | n_iters=1000 113 | ) 114 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 115 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 116 | super().__init__( 117 | name=f"Fast Audiomentations High Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 118 | warmup_iterations=10, 119 | iterations=n_iters, 120 | samples_per_iter=data_provider.batch_size, 121 | is_gpu_timer_required=True 122 | ) 123 | 124 | self.data_provider = data_provider 125 | 126 | def on_start(self): 127 | from fast_audiomentations import HighPassFilter 128 | self.high_pass_filter = HighPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0) 129 | self.audio = self.data_provider.get() 130 | 131 | def run_iteration(self): 132 | r = self.high_pass_filter(samples=self.audio, sample_rate=44100) 133 | 134 | def run_warmup_iteration(self): 135 | self.high_pass_filter(samples=self.audio, sample_rate=44100) 136 | 137 | def on_end(self): 138 | pass 139 | 140 | 141 | @benchmark( 142 | bench_type='high_pass_filter', 143 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 144 | n_iters=1000 145 | ) 146 | @benchmark( 147 | bench_type='high_pass_filter', 148 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 149 | n_iters=1000 150 | ) 151 | @benchmark( 152 | bench_type='high_pass_filter', 153 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 154 | n_iters=1000 155 | ) 156 | @benchmark( 157 | bench_type='high_pass_filter', 158 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 159 | n_iters=1000 160 | ) 161 | @benchmark( 162 | bench_type='high_pass_filter', 163 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 164 | n_iters=1000 165 | ) 166 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 167 | def __init__(self, data_provider: SingleAudioProvider, n_iters): 168 | super().__init__( 169 | name=f"Torch Audiomentations High Pass Filter ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 170 | warmup_iterations=10, 171 | iterations=n_iters, 172 | samples_per_iter=data_provider.batch_size, 173 | is_gpu_timer_required=True 174 | ) 175 | self.data_provider = data_provider 176 | 177 | def on_start(self): 178 | from torch_audiomentations import HighPassFilter 179 | self.high_pass_filter = HighPassFilter(min_cutoff_freq=100, max_cutoff_freq=10000, p=1.0) 180 | self.audio = self.data_provider.get().unsqueeze(1) 181 | 182 | def run_iteration(self): 183 | self.high_pass_filter(samples=self.audio, sample_rate=44100) 184 | 185 | def run_warmup_iteration(self): 186 | self.high_pass_filter(samples=self.audio, sample_rate=44100) 187 | 188 | def on_end(self): 189 | pass 190 | 191 | if __name__ == '__main__': 192 | benchmark = Benchmark(name="High Pass Filter Benchmark", bench_type="high_pass_filter") 193 | benchmark.run() 194 | benchmark.print_results() 195 | -------------------------------------------------------------------------------- /fast_audiomentations/transforms/add_background_noise.py: -------------------------------------------------------------------------------- 1 | from fast_audiomentations.transforms._impl._mix_triton import sum_with_snr_triton as _sum_with_snr_triton 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torchaudio 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | from nvidia.dali import 
pipeline_def 11 | import nvidia.dali.fn as fn 12 | import nvidia.dali.types as types 13 | 14 | 15 | @pipeline_def 16 | def audio_pipeline(file_paths): 17 | """ 18 | Defines an NVIDIA DALI pipeline for audio file processing. 19 | 20 | @param file_paths: The root directory path where audio files are stored. 21 | @return: Processed audio and audio lengths as output from the pipeline. 22 | """ 23 | encoded, _ = fn.readers.file(file_root=file_paths) 24 | audio, sr = fn.decoders.audio(encoded, dtype=types.FLOAT, downmix=True) 25 | audio_length = fn.shapes(audio) 26 | 27 | return audio, audio_length 28 | 29 | 30 | class AddBackgroundNoise: 31 | """ 32 | Class for adding background noise to audio samples. 33 | 34 | Attributes: 35 | noises_dataloader (DataLoader | DALIDataLoader): Dataloader for background noise samples. 36 | min_snr (float): Minimum Signal-to-Noise Ratio. 37 | max_snr (float): Maximum Signal-to-Noise Ratio. 38 | p (float): Probability of applying the augmentation. 39 | dtype (torch.dtype): Data type for computation. 40 | """ 41 | class AudioDataset(Dataset): 42 | """ 43 | Dataset class to handle loading of audio files. 44 | 45 | @param file_paths: List of paths to audio files. 46 | """ 47 | def __init__(self, file_paths): 48 | self.file_paths = file_paths 49 | 50 | def __len__(self): 51 | return len(self.file_paths) 52 | 53 | def __getitem__(self, idx): 54 | """ 55 | Get a single item from the dataset. 56 | 57 | @param idx: Index of the item. 58 | @return: Waveform and its length of the audio at the given index. 
59 | """ 60 | try: 61 | waveform, sample_rate = torchaudio.load(self.file_paths[idx % len(self.file_paths)]) 62 | return waveform[0], waveform.shape[1] 63 | except Exception as e: 64 | print(e) 65 | 66 | def __init__(self, 67 | noises_dataloader, 68 | min_snr: float = 15.0, 69 | max_snr: float = 20.0, 70 | p=0.5, 71 | buffer_size=129, 72 | dtype: torch.dtype = torch.float32): 73 | self.min_snr = min_snr 74 | self.max_snr = max_snr 75 | self.p = p 76 | self.dtype = dtype 77 | self.noises_dataloader = noises_dataloader 78 | 79 | self.random_buffer = torch.empty(buffer_size, device='cuda', dtype=dtype) 80 | self.copy_stream = torch.cuda.Stream() 81 | self.padded_batch = None 82 | 83 | @staticmethod 84 | def get_default_dataloader(noises_paths, buffer_size=128, n_workers=8, prefetch_factor=2): 85 | """ 86 | Create a default DataLoader for loading noise samples. 87 | 88 | @param noises_paths: List of paths to noise files. 89 | @param buffer_size: Size of the buffer for batch loading. 90 | @param n_workers: Number of worker threads for data loading. 91 | @param prefetch_factor: Number of batches to prefetch. 92 | @return: A DataLoader instance configured for loading noise samples. 
93 | """ 94 | def audio_collate_fn(batch): 95 | max_length = max(waveform.size(0) for waveform, _ in batch) 96 | batch_waveforms = [] 97 | audio_lens = [] 98 | for waveform, waveform_len in batch: 99 | padded_waveform = torch.nn.functional.pad( 100 | waveform, (0, max_length - waveform.size(0)), mode='constant', value=0) 101 | batch_waveforms.append(padded_waveform) 102 | audio_lens.append(waveform_len) 103 | return torch.stack(batch_waveforms), torch.tensor(audio_lens) 104 | 105 | return DataLoader( 106 | AddBackgroundNoise.AudioDataset(noises_paths), 107 | batch_size=buffer_size, 108 | num_workers=n_workers, 109 | shuffle=True, 110 | prefetch_factor=prefetch_factor, 111 | pin_memory=False, 112 | collate_fn=audio_collate_fn, 113 | drop_last=True, 114 | persistent_workers=True 115 | ) 116 | 117 | @staticmethod 118 | def get_dali_dataloader(noises_paths, buffer_size=128, n_workers=8, device_id=0): 119 | """ 120 | Create a DataLoader using NVIDIA DALI for loading noise samples. 121 | 122 | @param noises_paths: List of paths to noise files. 123 | @param buffer_size: Size of the buffer for batch loading. 124 | @param n_workers: Number of worker threads for data loading. 125 | @param device_id: ID of the GPU device. 126 | @return: A DALI pipeline instance configured for loading noise samples. 
127 | """ 128 | pipe = audio_pipeline(batch_size=buffer_size, num_threads=n_workers, device_id=device_id, 129 | file_paths=noises_paths, prefetch_queue_depth=2) 130 | pipe.build() 131 | return pipe 132 | 133 | 134 | def __generate_random_snrs(self, num_audios): 135 | assert num_audios <= self.random_buffer.size(0) 136 | 137 | buff_slice = self.random_buffer[:num_audios] 138 | buff_slice.uniform_(self.min_snr, self.max_snr) 139 | 140 | return buff_slice 141 | 142 | def create_padded_batch(self, audios): 143 | # Extract and convert audio tensors and lengths to NumPy arrays 144 | audio_tensors = [np.array(tensor) for tensor in audios[0]] 145 | lengths = np.squeeze(np.array(audios[1]), 1) # Extracting and reshaping length array 146 | 147 | # Find the maximum length in the batch 148 | max_length = lengths.max() 149 | 150 | # Batch size is the number of audio tensors 151 | batch_size = len(audio_tensors) 152 | 153 | # Allocate a GPU matrix with the shape [batch_size, max_length] 154 | # TODO: some smart preallocation 155 | self.padded_batch = torch.zeros(batch_size, max_length, device='cuda') 156 | 157 | # Copy each audio tensor (converted from NumPy) to the padded batch 158 | for i, audio_tensor_np in enumerate(audio_tensors): 159 | audio_length = lengths[i] 160 | 161 | audio_tensor = torch.from_numpy(audio_tensor_np) 162 | 163 | self.padded_batch[i, :audio_length].copy_(audio_tensor, non_blocking=True) 164 | 165 | return self.padded_batch, torch.tensor(lengths, device='cuda') 166 | 167 | def __call__(self, samples: torch.Tensor, samples_lens: torch.Tensor, sample_rate: int, inplace=False): 168 | """ 169 | Apply the add background noise transformation. 170 | 171 | @param samples: Input audio samples. 172 | @param samples_lens: Lengths of the audio samples. 173 | @param sample_rate: Sample rate of the audio. 174 | @param inplace: If True, perform operation in-place. 175 | @return: Augmented audio samples. 
176 | """ 177 | if random.random() < self.p: 178 | audios = self.noises_dataloader.run() 179 | noises, noises_lens = self.create_padded_batch(audios) 180 | snrs = self.__generate_random_snrs(samples.shape[0]) 181 | 182 | return _sum_with_snr_triton( 183 | samples, samples_lens, 184 | noises, noises_lens, 185 | snrs 186 | ) 187 | return samples -------------------------------------------------------------------------------- /fast_audiomentations/transforms/_impl/_filter_triton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | import itertools 7 | from scipy.signal import firwin 8 | import numpy as np 9 | import math 10 | import soundfile as sf 11 | 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | 17 | @triton.jit 18 | def sinc_kernel( 19 | output_ptr, 20 | cutoffs_ptr, 21 | indices_ptr, 22 | num_taps, 23 | window_ptr, 24 | half_sample_rate, 25 | mode: tl.constexpr, 26 | BLOCK_SIZE: tl.constexpr): 27 | batch_idx = tl.program_id(1) 28 | pos = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 29 | mask = pos < num_taps 30 | 31 | cutoff_val = tl.load(cutoffs_ptr + batch_idx) / half_sample_rate 32 | index_val = tl.load(indices_ptr + pos, mask=mask) 33 | window_val = tl.load(window_ptr + pos, mask=mask) 34 | 35 | x = index_val * math.pi * cutoff_val 36 | sinc_val = tl.where(index_val == 0, 1., tl.sin(x) / x) 37 | windowed_sinc = sinc_val * window_val 38 | 39 | # Normalize each filter by the sum of its windowed sinc values 40 | normalized_sinc = windowed_sinc / tl.sum(windowed_sinc, axis=0) 41 | if mode == "high": 42 | center_idx = num_taps // 2 43 | adjusted_val = tl.where(pos == center_idx, 1.0 - normalized_sinc, -normalized_sinc) 44 | 45 | tl.store(output_ptr + batch_idx * num_taps + pos, adjusted_val, mask=mask) 46 | elif mode == "low": 47 | tl.store(output_ptr + batch_idx * num_taps + pos, normalized_sinc, mask=mask) 48 | else: 
49 | raise ValueError(f"Unknown mode: {mode}") 50 | 51 | 52 | def create_filters(filter_output, cutoff_freqs, time, window, sample_rate, num_taps, mode): 53 | grid_size = (1, len(cutoff_freqs)) 54 | 55 | sinc_kernel[grid_size]( 56 | filter_output, 57 | cutoff_freqs, 58 | time, 59 | num_taps, 60 | window, 61 | 0.5 * sample_rate, 62 | mode, 63 | triton.next_power_of_2(num_taps) 64 | ) 65 | 66 | @triton.autotune( 67 | configs=[ 68 | triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps) 69 | for (block_size, num_warps) in 70 | itertools.product([32, 64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32]) 71 | ], 72 | key=['length', 'kernel_size', 'stride', 'n_frames'] 73 | ) 74 | @triton.jit 75 | def unfold_kernel(input_ptr, output_ptr, length, kernel_size, stride, n_frames, BLOCK_SIZE: tl.constexpr): 76 | # Compute indices 77 | batch_idx = tl.program_id(0) 78 | 79 | # Global frame index 80 | frame_idx = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 81 | 82 | # Bounds check for the frame index 83 | mask = frame_idx < n_frames 84 | 85 | # Calculate position in input for each thread 86 | input_pos = frame_idx * stride 87 | 88 | # Each thread processes one frame if within bounds 89 | for i in range(kernel_size): 90 | in_bounds = mask & ((input_pos + i) < length) 91 | 92 | # Use tl.where to handle in-bounds and out-of-bounds cases 93 | val = tl.where(in_bounds, tl.load(input_ptr + batch_idx * length + input_pos + i, mask=in_bounds), 0) 94 | 95 | out_idx = batch_idx * n_frames * kernel_size + frame_idx * kernel_size + i 96 | tl.store(output_ptr + out_idx, val, mask=in_bounds) 97 | 98 | 99 | def unfold_triton(input, kernel_size, stride): 100 | assert input.ndim >= 2, "Input tensor must be at least 2D" 101 | length = input.shape[-1] 102 | n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1 103 | 104 | # Prepare output tensor 105 | output_shape = list(input.shape)[:-1] + [n_frames, kernel_size] 106 | output = 
torch.empty(output_shape, device=input.device, dtype=input.dtype) 107 | 108 | # Grid dimensions 109 | grid = lambda META: ( 110 | input.shape[0], 111 | triton.cdiv(n_frames, META['BLOCK_SIZE']) + (n_frames % META['BLOCK_SIZE'] != 0) 112 | ) 113 | 114 | # Launch kernel 115 | unfold_kernel[grid](input, output, length, kernel_size, stride, n_frames) 116 | 117 | return output 118 | 119 | def complex_mul_conjugate(a: torch.Tensor, b: torch.Tensor): 120 | a_real = a[..., 0].contiguous() 121 | a_imag = a[..., 1].contiguous() 122 | b_real = b[..., 0].contiguous() 123 | b_imag = b[..., 1].contiguous() 124 | 125 | return torch.stack(complex_mul_conjugate_triton(a_real, b_real, a_imag, b_imag), dim=-1) 126 | 127 | @triton.autotune( 128 | configs=[ 129 | triton.Config({}, num_warps=num_warps) 130 | for (num_warps) in [1, 2, 4, 8, 16, 32] 131 | ], 132 | key=['num_batches', 'num_frames', 'fft_size'] 133 | ) 134 | @triton.jit 135 | def complex_mul_conjugate_kernel( 136 | a_real_ptr, 137 | b_real_ptr, 138 | a_imag_ptr, 139 | b_imag_ptr, 140 | output1_ptr, 141 | output2_ptr, 142 | num_batches, 143 | num_frames, 144 | fft_size, 145 | BLOCK_SIZE: tl.constexpr): 146 | # Compute indices for batch and fft 147 | batch_idx = tl.program_id(0) 148 | 149 | # Ensure we don't go out of bounds for batch index 150 | if batch_idx >= num_batches: 151 | return 152 | 153 | fft_idx = tl.arange(0, BLOCK_SIZE) 154 | fft_mask = fft_idx < fft_size 155 | 156 | batch_by_fft = batch_idx * fft_size 157 | 158 | b_real_val = tl.load(b_real_ptr + batch_by_fft + fft_idx, mask=fft_mask) 159 | b_imag_val = tl.load(b_imag_ptr + batch_by_fft + fft_idx, mask=fft_mask) 160 | 161 | for frame_idx in range(num_frames): 162 | global_idx = num_frames * batch_by_fft + frame_idx * fft_size + fft_idx 163 | 164 | a_real_val = tl.load(a_real_ptr + global_idx, mask=fft_mask) 165 | a_imag_val = tl.load(a_imag_ptr + global_idx, mask=fft_mask) 166 | 167 | result1 = a_real_val * b_real_val + a_imag_val * b_imag_val 168 | result2 = 
a_imag_val * b_real_val - a_real_val * b_imag_val 169 | 170 | tl.store(output1_ptr + global_idx, result1, mask=fft_mask) 171 | tl.store(output2_ptr + global_idx, result2, mask=fft_mask) 172 | 173 | 174 | def complex_mul_conjugate_triton(a_real, b_real, a_imag, b_imag): 175 | assert a_real.shape[-1] == b_real.shape[-1] # Ensure last dimensions match for multiplication 176 | 177 | num_batches, num_frames, fft_size = a_real.shape 178 | 179 | # Output tensor 180 | output1 = torch.empty_like(a_real) 181 | output2 = torch.empty_like(a_real) 182 | 183 | # Define grid size for the kernel launch 184 | grid_size = (num_batches,) 185 | 186 | # Launch the kernel 187 | 188 | complex_mul_conjugate_kernel[grid_size]( 189 | a_real, 190 | b_real, 191 | a_imag, 192 | b_imag, 193 | output1, 194 | output2, 195 | num_batches, 196 | num_frames, 197 | fft_size, 198 | triton.next_power_of_2(fft_size) 199 | ) 200 | 201 | return output1, output2 202 | 203 | 204 | def fft_conv1d( 205 | x: torch.Tensor, weight: torch.Tensor, 206 | stride: int = 1, padding: int = 0, 207 | block_ratio: float = 8): 208 | x = F.pad(x, (padding, padding)) 209 | B, L = x.shape 210 | _, kernel_size = weight.shape 211 | 212 | block_size: int = min(int(kernel_size * block_ratio), L) 213 | 214 | weight = F.pad(weight, (0, block_size - weight.shape[-1]), mode='constant', value=0) 215 | if weight.dtype != torch.float16 and weight.shape[1].bit_count() != 1: 216 | weight_z = torch.view_as_real(torch.fft.rfft(weight.to(torch.float32), dim=-1)).to(torch.float16) 217 | else: 218 | weight_z = torch.view_as_real(torch.fft.rfft(weight, dim=-1)) 219 | 220 | frames = unfold_triton(x, block_size, block_size - kernel_size + 1) 221 | 222 | if frames.dtype == torch.float16 and frames.shape[1].bit_count() != 1: 223 | frames_z = torch.view_as_real(torch.fft.rfft(frames.to(torch.float32), dim=-1)).to(torch.float16) 224 | else: 225 | frames_z = torch.view_as_real(torch.fft.rfft(frames, dim=-1)) 226 | 227 | out_z = 
complex_mul_conjugate(frames_z, weight_z) 228 | 229 | if out_z.dtype == torch.float16 and out_z.shape[1].bit_count() != 1: 230 | out = torch.fft.irfft(torch.view_as_complex(out_z.to(torch.float32)), block_size, dim=-1).to(torch.float16) 231 | else: 232 | out = torch.fft.irfft(torch.view_as_complex(out_z), block_size, dim=-1) 233 | 234 | out = out[..., :-kernel_size + 1] 235 | out = out.reshape(B, 1, -1) 236 | out = out[..., ::stride] 237 | target_length = (L - kernel_size) // stride + 1 238 | out = out[..., :target_length] 239 | 240 | return out -------------------------------------------------------------------------------- /benchmark/add_background_noise_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from benchmark.benchmark_tool import ( 3 | Benchmark, 4 | BenchmarkSuite, 5 | SingleAudioProvider, 6 | SingleAudioProviderClasses, 7 | benchmark 8 | ) 9 | from benchmark.benchmark_data import PATH_TO_NOISES 10 | 11 | @benchmark( 12 | bench_type='add_background_noise', 13 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=1, dtype='float32'), 14 | path_to_sounds=PATH_TO_NOISES, 15 | n_iters=300 16 | ) 17 | @benchmark( 18 | bench_type='add_background_noise', 19 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=16, dtype='float32'), 20 | path_to_sounds=PATH_TO_NOISES, 21 | n_iters=300 22 | ) 23 | @benchmark( 24 | bench_type='add_background_noise', 25 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=32, dtype='float32'), 26 | path_to_sounds=PATH_TO_NOISES, 27 | n_iters=300 28 | ) 29 | @benchmark( 30 | bench_type='add_background_noise', 31 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=64, dtype='float32'), 32 | path_to_sounds=PATH_TO_NOISES, 33 | n_iters=300 34 | ) 35 | @benchmark( 36 | bench_type='add_background_noise', 37 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.NUMPY, batch_size=128, dtype='float32'), 38 | path_to_sounds=PATH_TO_NOISES, 39 | n_iters=300 40 | ) 41 | class AudiomentationsBenchmarkSuite(BenchmarkSuite): 42 | def __init__(self, data_provider: SingleAudioProvider, path_to_sounds: str, n_iters): 43 | super().__init__( 44 | name=f"Audiomentations Add Background Noise (dtype={data_provider.dtype}) (batch_size={data_provider.batch_size})", 45 | warmup_iterations=10, 46 | iterations=n_iters, 47 | samples_per_iter=data_provider.batch_size 48 | ) 49 | 50 | self.data_provider = data_provider 51 | self.path_to_sounds = path_to_sounds 52 | 53 | def on_start(self): 54 | from audiomentations import AddBackgroundNoise 55 | self.add_background_noise = AddBackgroundNoise( 56 | sounds_path=self.path_to_sounds, 57 | min_snr_db=-10, 58 | max_snr_in_db=10, 59 | p=1.0 60 | ) 61 | self.audio = self.data_provider.get()[0] 62 | 63 | def run_iteration(self): 64 | for i in range(self.data_provider.batch_size): 65 | self.add_background_noise(samples=self.audio, sample_rate=44100) 66 | 67 | def run_warmup_iteration(self): 68 | for i in range(self.data_provider.batch_size): 69 | self.add_background_noise(samples=self.audio, sample_rate=44100) 70 | 71 | def on_end(self): 72 | pass 73 | 74 | 75 | @benchmark( 76 | bench_type='add_background_noise', 77 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 78 | path_to_sounds=PATH_TO_NOISES, 79 | n_iters=300 80 | ) 81 | @benchmark( 82 | bench_type='add_background_noise', 83 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float16), 84 | path_to_sounds=PATH_TO_NOISES, 85 | n_iters=300 86 | ) 87 | @benchmark( 88 | bench_type='add_background_noise', 89 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 90 | path_to_sounds=PATH_TO_NOISES, 91 | n_iters=300 92 | ) 93 | 
@benchmark( 94 | bench_type='add_background_noise', 95 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float16), 96 | path_to_sounds=PATH_TO_NOISES, 97 | n_iters=300 98 | ) 99 | @benchmark( 100 | bench_type='add_background_noise', 101 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 102 | path_to_sounds=PATH_TO_NOISES, 103 | n_iters=300 104 | ) 105 | @benchmark( 106 | bench_type='add_background_noise', 107 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float16), 108 | path_to_sounds=PATH_TO_NOISES, 109 | n_iters=300 110 | ) 111 | @benchmark( 112 | bench_type='add_background_noise', 113 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 114 | path_to_sounds=PATH_TO_NOISES, 115 | n_iters=300 116 | ) 117 | @benchmark( 118 | bench_type='add_background_noise', 119 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float16), 120 | path_to_sounds=PATH_TO_NOISES, 121 | n_iters=300 122 | ) 123 | @benchmark( 124 | bench_type='add_background_noise', 125 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 126 | path_to_sounds=PATH_TO_NOISES, 127 | n_iters=300 128 | ) 129 | @benchmark( 130 | bench_type='add_background_noise', 131 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float16), 132 | path_to_sounds=PATH_TO_NOISES, 133 | n_iters=300 134 | ) 135 | class FastAudiomentationsBenchmarkSuite(BenchmarkSuite): 136 | def __init__(self, data_provider: SingleAudioProvider, path_to_sounds: str, n_iters): 137 | super().__init__( 138 | name=f"Fast Audiomentations Add Background Noise ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 139 | warmup_iterations=10, 140 | 
iterations=n_iters, 141 | samples_per_iter=data_provider.batch_size, 142 | is_gpu_timer_required=True 143 | ) 144 | 145 | self.data_provider = data_provider 146 | self.path_to_sounds = path_to_sounds 147 | 148 | def on_start(self): 149 | from fast_audiomentations import AddBackgroundNoise 150 | 151 | dataloader = AddBackgroundNoise.get_dali_dataloader( 152 | self.path_to_sounds, 153 | buffer_size=self.data_provider.batch_size, 154 | n_workers=4 155 | ) 156 | self.add_background_noise = AddBackgroundNoise( 157 | noises_dataloader=dataloader, 158 | min_snr=-10, 159 | max_snr=10, 160 | buffer_size=self.data_provider.batch_size, 161 | p=1.0 162 | ) 163 | self.audio = self.data_provider.get() 164 | self.audio_lens = torch.tensor([self.audio.shape[0] for i in range(self.data_provider.batch_size)], device='cuda') 165 | 166 | 167 | def run_iteration(self): 168 | self.add_background_noise(samples=self.audio, samples_lens=self.audio_lens, sample_rate=44100) 169 | 170 | def run_warmup_iteration(self): 171 | self.add_background_noise(samples=self.audio, samples_lens=self.audio_lens, sample_rate=44100) 172 | 173 | def on_end(self): 174 | pass 175 | 176 | 177 | @benchmark( 178 | bench_type='add_background_noise', 179 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=1, dtype=torch.float32), 180 | path_to_sounds=PATH_TO_NOISES, 181 | n_iters=300 182 | ) 183 | @benchmark( 184 | bench_type='add_background_noise', 185 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=16, dtype=torch.float32), 186 | path_to_sounds=PATH_TO_NOISES, 187 | n_iters=300 188 | ) 189 | @benchmark( 190 | bench_type='add_background_noise', 191 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=32, dtype=torch.float32), 192 | path_to_sounds=PATH_TO_NOISES, 193 | n_iters=300 194 | ) 195 | @benchmark( 196 | bench_type='add_background_noise', 197 | 
data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=64, dtype=torch.float32), 198 | path_to_sounds=PATH_TO_NOISES, 199 | n_iters=300 200 | ) 201 | @benchmark( 202 | bench_type='add_background_noise', 203 | data_provider=SingleAudioProvider(SingleAudioProviderClasses.TORCH_GPU, batch_size=128, dtype=torch.float32), 204 | path_to_sounds=PATH_TO_NOISES, 205 | n_iters=300 206 | ) 207 | class TorchAudiomentationsBenchmarkSuite(BenchmarkSuite): 208 | def __init__(self, data_provider: SingleAudioProvider, path_to_sounds: str, n_iters): 209 | super().__init__( 210 | name=f"Torch Audiomentations Add Background Noise ({data_provider.dtype}) (batch_size={data_provider.batch_size})", 211 | warmup_iterations=10, 212 | iterations=n_iters, 213 | samples_per_iter=data_provider.batch_size, 214 | is_gpu_timer_required=True 215 | ) 216 | self.data_provider = data_provider 217 | self.path_to_sounds = path_to_sounds 218 | 219 | def on_start(self): 220 | from torch_audiomentations import AddBackgroundNoise 221 | self.add_background_noise = AddBackgroundNoise( 222 | background_paths=self.path_to_sounds, 223 | min_snr_in_db=-10, 224 | max_snr_in_db=10, 225 | p=1.0 226 | ) 227 | self.audio = self.data_provider.get().unsqueeze(1) 228 | 229 | def run_iteration(self): 230 | self.add_background_noise(samples=self.audio, sample_rate=44100) 231 | 232 | def run_warmup_iteration(self): 233 | self.add_background_noise(samples=self.audio, sample_rate=44100) 234 | 235 | def on_end(self): 236 | pass 237 | 238 | if __name__ == '__main__': 239 | benchmark = Benchmark(name="Add Background Noise Benchmark", bench_type='add_background_noise') 240 | benchmark.run() 241 | benchmark.print_results() 242 | -------------------------------------------------------------------------------- /benchmark_local_result/clip.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.clip_benchmark  ✔  
15m 56s   mc  09:07:11  2 | Running Audiomentations Clip (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 3 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 4 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 5 | Running Audiomentations Clip (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 7 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 8 | Running Audiomentations Clip (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 10 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 11 | Running Audiomentations Clip (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 13 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 14 | Running Audiomentations Clip (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 16 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 17 | Running Fast Audiomentations Clip (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 19 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 20 | Running Fast Audiomentations Clip (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 22 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 23 | Running Fast Audiomentations Clip (torch.float16) (batch_size=64)... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 25 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 26 | Running Fast Audiomentations Clip (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 28 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 29 | Running Fast Audiomentations Clip (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 31 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 32 | Running Fast Audiomentations Clip (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 34 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 35 | Running Fast Audiomentations Clip (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 37 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 38 | Running Fast Audiomentations Clip (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 40 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 41 | Running Fast Audiomentations Clip (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 43 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 44 | Running Fast Audiomentations Clip (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Warmup... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 46 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 47 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 48 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 49 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 50 | │ 🚀 Fast Audiomentations Clip (torch.float16) (batch_size=128) │ 7100086.9274mcs │ 5546.9429mcs │ 178021.0890mcs │ 1.3908mcs │ 1.0000x │ 0.0000% │ 51 | │ ⚡ Fast Audiomentations Clip (torch.float16) (batch_size=64) │ 6151487.1120mcs │ 9611.6986mcs │ 105172.8956mcs │ 1.6433mcs │ 1.1816x │ 18.1578% │ 52 | │ ⚡ Fast Audiomentations Clip (torch.float16) (batch_size=32) │ 6335584.6405mcs │ 19798.7020mcs │ 69793.6321mcs │ 2.1811mcs │ 1.5682x │ 56.8210% │ 53 | │ ⚡ Fast Audiomentations Clip (torch.float32) (batch_size=128) │ 16100.8835mcs │ 12.5788mcs │ 354655.3285mcs │ 2.7707mcs │ 1.9922x │ 99.2210% │ 54 | │ 🐢 Fast Audiomentations Clip (torch.float32) (batch_size=64) │ 16719.1029mcs │ 26.1236mcs │ 211422.2077mcs │ 3.3035mcs │ 2.3752x │ 137.5249% │ 55 | │ 🐢 Fast Audiomentations Clip (torch.float32) (batch_size=32) │ 13491.8690mcs │ 42.1621mcs │ 106112.9909mcs │ 3.3160mcs │ 2.3843x │ 138.4279% │ 56 | │ 🐢 Fast Audiomentations Clip (torch.float16) (batch_size=16) │ 6431032.4192mcs │ 40193.9526mcs │ 58091.6802mcs │ 3.6307mcs │ 2.6106x │ 161.0553% │ 57 | │ 🐢 Fast Audiomentations Clip (torch.float32) (batch_size=16) │ 8401.8707mcs │ 52.5117mcs │ 74660.5767mcs │ 4.6663mcs │ 3.3551x │ 235.5134% │ 58 | │ 🐌 Fast Audiomentations 
Clip (torch.float32) (batch_size=1) │ 21923.5420mcs │ 2192.3542mcs │ 44268.3514mcs │ 44.2684mcs │ 31.8297x │ 3082.9650% │ 59 | │ 🐌 Fast Audiomentations Clip (torch.float16) (batch_size=1) │ 6823062.4199mcs │ 682306.2420mcs │ 45879.0396mcs │ 45.8790mcs │ 32.9878x │ 3198.7761% │ 60 | │ 🐌 Audiomentations Clip (dtype=float32) (batch_size=128) │ 330404.9969mcs │ 258.1289mcs │ 12166078.0907mcs │ 95.0475mcs │ 68.3407x │ 6734.0656% │ 61 | │ 🐌 Audiomentations Clip (dtype=float32) (batch_size=64) │ 71667.4328mcs │ 111.9804mcs │ 6111328.3634mcs │ 95.4895mcs │ 68.6585x │ 6765.8476% │ 62 | │ 🐌 Audiomentations Clip (dtype=float32) (batch_size=16) │ 22910.8334mcs │ 143.1927mcs │ 1535941.3624mcs │ 95.9963mcs │ 69.0229x │ 6802.2895% │ 63 | │ 🐌 Audiomentations Clip (dtype=float32) (batch_size=32) │ 36659.9560mcs │ 114.5624mcs │ 3081208.4675mcs │ 96.2878mcs │ 69.2324x │ 6823.2437% │ 64 | │ 🐌 Audiomentations Clip (dtype=float32) (batch_size=1) │ 4275.0835mcs │ 427.5084mcs │ 98846.9124mcs │ 98.8469mcs │ 71.0725x │ 7007.2505% │ 65 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 66 | -------------------------------------------------------------------------------- /benchmark/benchmark_tool.py: -------------------------------------------------------------------------------- 1 | from time import time, sleep 2 | from rich.console import Console 3 | from rich.table import Table 4 | from rich.progress import Progress 5 | from collections import defaultdict 6 | import torch 7 | from benchmark.benchmark_data import PATH_TO_SINGLE_AUDIO 8 | import soundfile as sf 9 | import enum 10 | import numpy as np 11 | 12 | class BenchmarkRegistry: 13 | def __init__(self): 14 | self.suites = defaultdict(list) 15 | 16 | def register(self, suite_class, bench_type, *args, **kwargs): 17 | for (this_suite_class, this_args, 
this_kwargs) in self.suites[bench_type]: 18 | if (this_suite_class.__name__ == suite_class.__name__ 19 | and set(this_args) == set(args) 20 | and this_kwargs == kwargs): 21 | return 22 | self.suites[bench_type].append((suite_class, args, kwargs)) 23 | 24 | def get_suites(self, bench_type): 25 | return [suite_class(*args, **kwargs) for suite_class, args, kwargs in self.suites[bench_type]] 26 | 27 | _benchmark_registry = BenchmarkRegistry() 28 | 29 | def benchmark(bench_type, *args, **kwargs): 30 | def decorator(suite_class): 31 | _benchmark_registry.register(suite_class, bench_type, *args, **kwargs) 32 | return suite_class 33 | return decorator 34 | 35 | class SingleAudioProviderClasses(enum.Enum): 36 | TORCH_GPU = 1 37 | TORCH_CPU = 2 38 | NUMPY = 3 39 | 40 | class SingleAudioProvider: 41 | def __init__(self, cls: SingleAudioProviderClasses, dtype: str | torch.dtype='float32', batch_size=1): 42 | samples, sr = sf.read(PATH_TO_SINGLE_AUDIO, dtype='float32') 43 | self.samples = samples 44 | self.sr = sr 45 | self.cls = cls 46 | self.dtype = dtype 47 | self.batch_size = batch_size 48 | 49 | def get(self): 50 | if self.cls == SingleAudioProviderClasses.TORCH_GPU: 51 | return torch.tensor(self.samples, device='cuda', dtype=self.dtype).repeat(self.batch_size, 1) 52 | elif self.cls == SingleAudioProviderClasses.TORCH_CPU: 53 | return torch.tensor(self.samples, device='cpu', dtype=self.dtype).repeat(self.batch_size, 1) 54 | elif self.cls == SingleAudioProviderClasses.NUMPY: 55 | return np.tile(self.samples, self.batch_size).reshape(self.batch_size, -1) 56 | else: 57 | raise ValueError(f"Unknown class {self.cls}") 58 | 59 | class BenchmarkSuite: 60 | def __init__(self, 61 | name: str, 62 | warmup_iterations: int = 10, 63 | iterations: int = 1000, 64 | samples_per_iter: int = 1, 65 | is_gpu_timer_required: bool = False): 66 | self.__name = name 67 | self.__warmup_iterations = warmup_iterations 68 | self.__iterations = iterations 69 | self.__samples_per_iter = 
samples_per_iter 70 | self.__is_gpu_timer_required = is_gpu_timer_required 71 | 72 | def iterations(self): 73 | return self.__iterations 74 | 75 | def warmup_iterations(self): 76 | return self.__warmup_iterations 77 | 78 | def samples_per_iter(self): 79 | return self.__samples_per_iter 80 | 81 | def name(self): 82 | return self.__name 83 | 84 | def run(self, progress, task): 85 | elapsed = [] 86 | for i in range(self.__iterations): 87 | 88 | torch.cuda.synchronize() 89 | if self.__is_gpu_timer_required: 90 | start_event = torch.cuda.Event(enable_timing=True) 91 | end_event = torch.cuda.Event(enable_timing=True) 92 | start_event.record() 93 | else: 94 | start = time() 95 | 96 | self.run_iteration() 97 | if self.__is_gpu_timer_required: 98 | end_event.record() 99 | end_event.synchronize() 100 | elapsed.append(start_event.elapsed_time(end_event) / 1000) 101 | else: 102 | end = time() 103 | elapsed.append(end - start) 104 | 105 | progress.update(task, advance=100. / self.__iterations) 106 | return elapsed 107 | 108 | def run_iteration(self): 109 | raise NotImplementedError 110 | 111 | def on_start(self): 112 | raise NotImplementedError 113 | 114 | def on_end(self): 115 | raise NotImplementedError 116 | 117 | def warmup(self, progress, task): 118 | for i in range(self.__warmup_iterations): 119 | self.run_warmup_iteration() 120 | progress.update(task, advance=100. 
/ self.__warmup_iterations) 121 | 122 | def run_warmup_iteration(self): 123 | raise NotImplementedError 124 | 125 | 126 | 127 | class Benchmark: 128 | def __init__(self, name: str, bench_type: str, auto_collect_suites=True): 129 | self.suites = [] 130 | self.name = name 131 | self.bench_type = bench_type 132 | self.timings_on_start = {} 133 | self.timings_on_end = {} 134 | self.timings_on_warmup = {} 135 | self.timings_on_run = {} 136 | 137 | if auto_collect_suites: 138 | self.collect_suites() 139 | 140 | def collect_suites(self): 141 | global _benchmark_registry 142 | for suite in _benchmark_registry.get_suites(self.bench_type): 143 | self.add_suite(suite) 144 | 145 | def add_suite(self, suite: BenchmarkSuite): 146 | self.suites.append(suite) 147 | 148 | def run(self): 149 | with Progress() as progress: 150 | for suite in self.suites: 151 | # Create a task for each suite 152 | task = progress.add_task(f"[green]Running {suite.name()}...", total=100) 153 | 154 | # Run the suite with progress updates 155 | self.run_suite_with_progress(suite, progress, task) 156 | 157 | @staticmethod 158 | def seconds_to_microseconds(seconds): 159 | return seconds * 1000000 160 | 161 | def run_suite_with_progress(self, suite, progress, task): 162 | # Start the suite 163 | start = time() 164 | suite.on_start() 165 | end = time() 166 | self.timings_on_start[suite.name()] = end - start 167 | progress.update(task, advance=25) 168 | 169 | # Warmup 170 | warmup_task = progress.add_task(f"Warmup...", total=100) 171 | suite.warmup(progress, warmup_task) 172 | end = time() 173 | self.timings_on_warmup[suite.name()] = end - start 174 | progress.update(task, advance=25) 175 | 176 | # Run 177 | run_task = progress.add_task(f"Running...", total=100) 178 | elapsed = suite.run(progress, run_task) 179 | self.timings_on_run[suite.name()] = sum(elapsed) 180 | progress.update(task, advance=25) 181 | 182 | # print(suite.name()) 183 | # for i in range(len(elapsed)): 184 | # print(f"{i}: 
{self.seconds_to_microseconds(elapsed[i])}") 185 | 186 | # End 187 | start = time() 188 | suite.on_end() 189 | end = time() 190 | self.timings_on_end[suite.name()] = end - start 191 | progress.update(task, advance=25) 192 | 193 | def print_results(self): 194 | console = Console() 195 | 196 | # Calculate warmup and run times 197 | warmup_times = {suite.name(): self.timings_on_warmup[suite.name()] for suite in self.suites} 198 | warmup_times_average = {suite.name(): self.timings_on_warmup[suite.name()] / 199 | suite.warmup_iterations() / 200 | suite.samples_per_iter() for suite in self.suites} 201 | 202 | run_times = {suite.name(): self.timings_on_run[suite.name()] for suite in self.suites} 203 | run_times_average = {suite.name(): self.timings_on_run[suite.name()] / 204 | suite.iterations() / 205 | suite.samples_per_iter() for suite in self.suites} 206 | 207 | # Calculate relative slowdown based on run times 208 | fastest_run_time = min(run_times_average.values()) 209 | relative_slowdown = {suite: time / fastest_run_time for suite, time in run_times_average.items()} 210 | 211 | # Gradient colors and emojis for performance 212 | def get_color_and_emoji(slowdown): 213 | if slowdown <= 1.1: # Close to fastest 214 | return "green", "🚀" 215 | elif slowdown <= 2: # Moderately slower 216 | return "yellow", "⚡" 217 | elif slowdown <= 5: # Slower 218 | return "orange", "🐢" 219 | else: # Much slower 220 | return "red", "🐌" 221 | 222 | # Create a table 223 | table = Table(show_header=True, header_style="bold magenta") 224 | table.add_column("Suite", style="dim", width=75) 225 | table.add_column("Warmup Time Total (mcs)", justify="right") 226 | table.add_column("Warmup Time Per Sample (mcs)", justify="right") 227 | table.add_column("Run Time Total (mcs)", justify="right") 228 | table.add_column("Run Time Per Sample (mcs)", justify="right") 229 | table.add_column("Relative Slowdown", justify="right") 230 | table.add_column("Percentage Slower (%)", justify="right") 231 | 232 | # 
Sort by run times 233 | sorted_suites = sorted(run_times_average.items(), key=lambda x: x[1]) 234 | 235 | # Add rows to the table 236 | for suite, _ in sorted_suites: 237 | run_time = self.seconds_to_microseconds(run_times[suite]) 238 | warmup_time = self.seconds_to_microseconds(warmup_times[suite]) 239 | warmup_time_per_sample = self.seconds_to_microseconds(warmup_times_average[suite]) 240 | run_time_per_sample = self.seconds_to_microseconds(run_times_average[suite]) 241 | slowdown = relative_slowdown[suite] 242 | percent_slower = (slowdown - 1) * 100 243 | color, emoji = get_color_and_emoji(slowdown) 244 | 245 | table.add_row( 246 | f"[{color}]{emoji} {suite}[/]", 247 | f"{warmup_time:.4f}mcs", f"{warmup_time_per_sample:.4f}mcs", 248 | f"{run_time:.4f}mcs", f"{run_time_per_sample:.4f}mcs", 249 | f"{slowdown:.4f}x", f"{percent_slower:.4f}%" 250 | ) 251 | 252 | console.print(table) 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /benchmark_local_result/gain.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.gain_benchmark  ✔  59s   mc  09:10:55  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations Gain (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations Gain (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations Gain (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations Gain (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations Gain (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations Gain (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations Gain (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations Gain (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations Gain (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations Gain (torch.float16) (batch_size=32)... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations Gain (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations Gain (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations Gain (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations Gain (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations Gain (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations Gain (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations Gain (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations Gain (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations Gain (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations Gain (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations Gain (torch.float16) (batch_size=128) │ 8759505.9872mcs │ 6843.3641mcs │ 184982.3672mcs │ 1.4452mcs │ 1.0000x │ 0.0000% │ 68 | │ ⚡ Fast Audiomentations Gain (torch.float16) (batch_size=64) │ 6351126.4324mcs │ 9923.6351mcs │ 111195.8071mcs │ 1.7374mcs │ 1.2022x │ 20.2231% │ 69 | │ ⚡ Fast Audiomentations Gain (torch.float16) (batch_size=32) │ 6330022.0966mcs │ 19781.3191mcs │ 77595.4579mcs │ 
2.4249mcs │ 1.6779x │ 67.7900% │ 70 | │ ⚡ Fast Audiomentations Gain (torch.float32) (batch_size=128) │ 73546.1712mcs │ 57.4579mcs │ 368522.9445mcs │ 2.8791mcs │ 1.9922x │ 99.2206% │ 71 | │ 🐢 Fast Audiomentations Gain (torch.float32) (batch_size=64) │ 43500.9003mcs │ 67.9702mcs │ 231702.6881mcs │ 3.6204mcs │ 2.5051x │ 150.5133% │ 72 | │ 🐢 Fast Audiomentations Gain (torch.float32) (batch_size=32) │ 42330.0266mcs │ 132.2813mcs │ 126643.6804mcs │ 3.9576mcs │ 2.7385x │ 173.8503% │ 73 | │ 🐢 Fast Audiomentations Gain (torch.float16) (batch_size=16) │ 6213977.0985mcs │ 38837.3569mcs │ 65651.9672mcs │ 4.1032mcs │ 2.8393x │ 183.9275% │ 74 | │ 🐢 Fast Audiomentations Gain (torch.float32) (batch_size=16) │ 16131.1626mcs │ 100.8198mcs │ 89645.4729mcs │ 5.6028mcs │ 3.8769x │ 287.6931% │ 75 | │ 🐌 Torch Audiomentations Gain (torch.float32) (batch_size=128) │ 86367.8455mcs │ 67.4749mcs │ 1581925.4411mcs │ 12.3588mcs │ 8.5518x │ 755.1763% │ 76 | │ 🐌 Torch Audiomentations Gain (torch.float32) (batch_size=64) │ 18994.5698mcs │ 29.6790mcs │ 889250.2408mcs │ 13.8945mcs │ 9.6144x │ 861.4432% │ 77 | │ 🐌 Torch Audiomentations Gain (torch.float32) (batch_size=32) │ 26453.2566mcs │ 82.6664mcs │ 531161.1511mcs │ 16.5988mcs │ 11.4857x │ 1048.5660% │ 78 | │ 🐌 Torch Audiomentations Gain (torch.float32) (batch_size=16) │ 15734.9110mcs │ 98.3432mcs │ 367529.2482mcs │ 22.9706mcs │ 15.8947x │ 1489.4672% │ 79 | │ 🐌 Audiomentations Gain (dtype=float32) (batch_size=1) │ 3818.5120mcs │ 381.8512mcs │ 34164.9055mcs │ 34.1649mcs │ 23.6407x │ 2264.0674% │ 80 | │ 🐌 Audiomentations Gain (dtype=float32) (batch_size=16) │ 10463.7146mcs │ 65.3982mcs │ 547697.7825mcs │ 34.2311mcs │ 23.6865x │ 2268.6486% │ 81 | │ 🐌 Audiomentations Gain (dtype=float32) (batch_size=32) │ 17522.3351mcs │ 54.7573mcs │ 1107546.0911mcs │ 34.6108mcs │ 23.9492x │ 2294.9225% │ 82 | │ 🐌 Audiomentations Gain (dtype=float32) (batch_size=128) │ 252207.0408mcs │ 197.0368mcs │ 4504890.9187mcs │ 35.1945mcs │ 24.3531x │ 2335.3083% │ 83 | │ 🐌 
Audiomentations Gain (dtype=float32) (batch_size=64) │ 33100.3666mcs │ 51.7193mcs │ 2290756.4640mcs │ 35.7931mcs │ 24.7673x │ 2376.7295% │ 84 | │ 🐌 Fast Audiomentations Gain (torch.float16) (batch_size=1) │ 7984286.3083mcs │ 798428.6308mcs │ 50432.3202mcs │ 50.4323mcs │ 34.8970x │ 3389.7040% │ 85 | │ 🐌 Fast Audiomentations Gain (torch.float32) (batch_size=1) │ 54518.4612mcs │ 5451.8461mcs │ 54452.0325mcs │ 54.4520mcs │ 37.6785x │ 3667.8511% │ 86 | │ 🐌 Torch Audiomentations Gain (torch.float32) (batch_size=1) │ 14967.4416mcs │ 1496.7442mcs │ 201640.8323mcs │ 201.6408mcs │ 139.5270x │ 13852.6956% │ 87 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 88 | -------------------------------------------------------------------------------- /benchmark_local_result/low_pass_filter.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.low_pass_benchmark  ✔  8m 46s   mc  09:25:19  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations Low Pass Filter (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations Low Pass Filter (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations Low Pass Filter (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations Low Pass Filter (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations Low Pass Filter (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=128) │ 9479858.3984mcs │ 7406.1394mcs │ 6539680.1028mcs │ 51.0913mcs │ 1.0000x │ 0.0000% │ 68 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=128) │ 12580.1563mcs │ 9.8282mcs │ 6545469.5730mcs │ 51.1365mcs │ 1.0009x │ 0.0885% │ 69 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=32) │ 9328.3653mcs │ 29.1511mcs │ 1659497.1881mcs │ 51.8593mcs │ 1.0150x │ 1.5033% │ 70 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=64) │ 789064.1689mcs │ 1232.9128mcs │ 3355513.0885mcs │ 52.4299mcs │ 1.0262x │ 2.6201% │ 71 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=32) │ 690002.6798mcs │ 2156.2584mcs │ 1685868.7358mcs │ 52.6834mcs │ 1.0312x │ 3.1163% │ 72 | │ 🚀 Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=64) │ 11173.2483mcs │ 17.4582mcs │ 3396280.4503mcs │ 53.0669mcs │ 1.0387x │ 3.8669% │ 73 | │ ⚡ Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=16) │ 10459.8999mcs │ 65.3744mcs │ 1115266.5288mcs │ 69.7042mcs │ 1.3643x │ 36.4307% │ 74 | │ ⚡ Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=16) │ 763370.2755mcs │ 4771.0642mcs │ 1150158.3067mcs │ 71.8849mcs │ 1.4070x │ 40.6990% │ 75 | │ 🐌 
Fast Audiomentations Low Pass Filter (torch.float32) (batch_size=1) │ 14244.7948mcs │ 1424.4795mcs │ 564696.2252mcs │ 564.6962mcs │ 11.0527x │ 1005.2699% │ 76 | │ 🐌 Fast Audiomentations Low Pass Filter (torch.float16) (batch_size=1) │ 743583.4408mcs │ 74358.3441mcs │ 613910.3366mcs │ 613.9103mcs │ 12.0160x │ 1101.5958% │ 77 | │ 🐌 Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=128) │ 1779130.6973mcs │ 1389.9459mcs │ 80979866.2987mcs │ 632.6552mcs │ 12.3828x │ 1138.2848% │ 78 | │ 🐌 Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=64) │ 457803.0109mcs │ 715.3172mcs │ 42682761.6444mcs │ 666.9182mcs │ 13.0535x │ 1205.3471% │ 79 | │ 🐌 Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=16) │ 109733.1047mcs │ 685.8319mcs │ 10746296.6695mcs │ 671.6435mcs │ 13.1460x │ 1214.5960% │ 80 | │ 🐌 Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=32) │ 251117.9447mcs │ 784.7436mcs │ 22745037.4985mcs │ 710.7824mcs │ 13.9120x │ 1291.2018% │ 81 | │ 🐌 Torch Audiomentations Low Pass Filter (torch.float32) (batch_size=1) │ 21835.3271mcs │ 2183.5327mcs │ 1041417.3760mcs │ 1041.4174mcs │ 20.3835x │ 1938.3478% │ 82 | │ 🐌 Audiomentations Low Pass Filter (dtype=float32) (batch_size=128) │ 1871715.7841mcs │ 1462.2780mcs │ 165035952.0912mcs │ 1289.3434mcs │ 25.2361x │ 2423.6089% │ 83 | │ 🐌 Audiomentations Low Pass Filter (dtype=float32) (batch_size=64) │ 837191.3433mcs │ 1308.1115mcs │ 82728461.7424mcs │ 1292.6322mcs │ 25.3005x │ 2430.0461% │ 84 | │ 🐌 Audiomentations Low Pass Filter (dtype=float32) (batch_size=32) │ 420398.9506mcs │ 1313.7467mcs │ 41675310.3733mcs │ 1302.3534mcs │ 25.4907x │ 2449.0733% │ 85 | │ 🐌 Audiomentations Low Pass Filter (dtype=float32) (batch_size=1) │ 16512.8708mcs │ 1651.2871mcs │ 1309809.9232mcs │ 1309.8099mcs │ 25.6367x │ 2463.6678% │ 86 | │ 🐌 Audiomentations Low Pass Filter (dtype=float32) (batch_size=16) │ 213551.9981mcs │ 1334.7000mcs │ 20961445.8084mcs │ 1310.0904mcs │ 25.6422x │ 2464.2167% │ 87 | 
└─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 88 | -------------------------------------------------------------------------------- /benchmark_local_result/band_pass.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.band_pass_benchmark  INT ✘  mc  08:34:00  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations Band Pass Filter (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations Band Pass Filter (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations Band Pass Filter (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations Band Pass Filter (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations Band Pass Filter (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=128) │ 12373.4474mcs │ 9.6668mcs │ 6399552.8026mcs │ 49.9965mcs │ 1.0000x │ 0.0000% │ 68 | │ 🚀 Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=128) │ 9565271.6160mcs │ 7472.8684mcs │ 6414372.0684mcs │ 50.1123mcs │ 1.0023x │ 0.2316% │ 69 | │ 🚀 Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=32) │ 764896.1544mcs │ 2390.3005mcs │ 1655801.2835mcs │ 51.7438mcs │ 1.0349x │ 3.4948% │ 70 | │ 🚀 Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=32) │ 9869.3371mcs │ 30.8417mcs │ 1662280.8615mcs │ 51.9463mcs │ 1.0390x │ 3.8998% │ 71 | │ 🚀 Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=64) │ 
759296.6557mcs │ 1186.4010mcs │ 3447901.8970mcs │ 53.8735mcs │ 1.0775x │ 7.7545% │ 72 | │ ⚡ Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=64) │ 11142.9691mcs │ 17.4109mcs │ 3542940.0640mcs │ 55.3584mcs │ 1.1072x │ 10.7246% │ 73 | │ ⚡ Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=16) │ 10815.6204mcs │ 67.5976mcs │ 999265.4073mcs │ 62.4541mcs │ 1.2492x │ 24.9169% │ 74 | │ ⚡ Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=16) │ 761017.3225mcs │ 4756.3583mcs │ 1089747.6792mcs │ 68.1092mcs │ 1.3623x │ 36.2280% │ 75 | │ 🐌 Fast Audiomentations Band Pass Filter (torch.float32) (batch_size=1) │ 14457.9411mcs │ 1445.7941mcs │ 400722.6866mcs │ 400.7227mcs │ 8.0150x │ 701.5014% │ 76 | │ 🐌 Fast Audiomentations Band Pass Filter (torch.float16) (batch_size=1) │ 770780.8018mcs │ 77078.0802mcs │ 401693.2478mcs │ 401.6932mcs │ 8.0344x │ 703.4426% │ 77 | │ 🐌 Audiomentations Band Pass Filter (dtype=float32) (batch_size=128) │ 2041798.8300mcs │ 1595.1553mcs │ 183139631.5098mcs │ 1430.7784mcs │ 28.6176x │ 2761.7567% │ 78 | │ 🐌 Audiomentations Band Pass Filter (dtype=float32) (batch_size=64) │ 934302.0916mcs │ 1459.8470mcs │ 91806728.6015mcs │ 1434.4801mcs │ 28.6916x │ 2769.1608% │ 79 | │ 🐌 Audiomentations Band Pass Filter (dtype=float32) (batch_size=32) │ 470837.1162mcs │ 1471.3660mcs │ 46064014.6732mcs │ 1439.5005mcs │ 28.7920x │ 2779.2021% │ 80 | │ 🐌 Audiomentations Band Pass Filter (dtype=float32) (batch_size=1) │ 21025.8961mcs │ 2102.5896mcs │ 1446767.5686mcs │ 1446.7676mcs │ 28.9374x │ 2793.7373% │ 81 | │ 🐌 Audiomentations Band Pass Filter (dtype=float32) (batch_size=16) │ 238082.1705mcs │ 1488.0136mcs │ 23215111.7325mcs │ 1450.9445mcs │ 29.0209x │ 2802.0917% │ 82 | │ 🐌 Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=128) │ 4476696.4912mcs │ 3497.4191mcs │ 280803409.7137mcs │ 2193.7766mcs │ 43.8786x │ 4287.8599% │ 83 | │ 🐌 Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=16) │ 
333304.6436mcs │ 2083.1540mcs │ 38027220.1443mcs │ 2376.7013mcs │ 47.5373x │ 4653.7347% │ 84 | │ 🐌 Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=64) │ 1598209.8579mcs │ 2497.2029mcs │ 153943295.5666mcs │ 2405.3640mcs │ 48.1106x │ 4711.0642% │ 85 | │ 🐌 Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=32) │ 735768.7950mcs │ 2299.2775mcs │ 79369528.5187mcs │ 2480.2978mcs │ 49.6094x │ 4860.9422% │ 86 | │ 🐌 Torch Audiomentations Band Pass Filter (torch.float32) (batch_size=1) │ 47212.6007mcs │ 4721.2601mcs │ 2951544.3507mcs │ 2951.5444mcs │ 59.0350x │ 5803.5012% │ 87 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 88 | -------------------------------------------------------------------------------- /benchmark_local_result/band_stop.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.band_stop_benchmark  ✔  15m 55s   mc  08:50:34  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations Band Stop Filter (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations Band Stop Filter (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations Band Stop Filter (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations Band Stop Filter (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations Band Stop Filter (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=128) │ 9513787.5080mcs │ 7432.6465mcs │ 6533075.1925mcs │ 51.0396mcs │ 1.0000x │ 0.0000% │ 68 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=128) │ 12942.7910mcs │ 10.1116mcs │ 6745739.0108mcs │ 52.7011mcs │ 1.0326x │ 3.2552% │ 69 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=64) │ 10890.0070mcs │ 17.0156mcs │ 3393847.8351mcs │ 53.0289mcs │ 1.0390x │ 3.8974% │ 70 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=32) │ 10614.8720mcs │ 33.1715mcs │ 1712774.0173mcs │ 53.5242mcs │ 1.0487x │ 4.8679% │ 71 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=64) │ 785462.6179mcs │ 1227.2853mcs │ 3435902.3967mcs │ 53.6860mcs │ 1.0518x │ 5.1848% │ 72 | │ 🚀 Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=32) │ 719348.4306mcs │ 2247.9638mcs │ 1729183.3923mcs │ 54.0370mcs │ 1.0587x │ 5.8726% │ 73 | │ ⚡ Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=16) │ 20036.9358mcs │ 125.2308mcs │ 1121768.7036mcs │ 70.1105mcs │ 1.3736x │ 37.3649% │ 74 | │ ⚡ Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=16) │ 725724.2203mcs │ 4535.7764mcs │ 1228764.7386mcs │ 76.7978mcs │ 1.5047x │ 50.4669% 
│ 75 | │ 🐌 Fast Audiomentations Band Stop Filter (torch.float32) (batch_size=1) │ 14875.4120mcs │ 1487.5412mcs │ 619267.0087mcs │ 619.2670mcs │ 12.1331x │ 1113.3058% │ 76 | │ 🐌 Fast Audiomentations Band Stop Filter (torch.float16) (batch_size=1) │ 884478.5690mcs │ 88447.8569mcs │ 634598.5917mcs │ 634.5986mcs │ 12.4334x │ 1143.3443% │ 77 | │ 🐌 Audiomentations Band Stop Filter (dtype=float32) (batch_size=128) │ 2079025.5070mcs │ 1624.2387mcs │ 186835763.9313mcs │ 1459.6544mcs │ 28.5984x │ 2759.8441% │ 78 | │ 🐌 Audiomentations Band Stop Filter (dtype=float32) (batch_size=64) │ 938365.6979mcs │ 1466.1964mcs │ 93524050.4742mcs │ 1461.3133mcs │ 28.6309x │ 2763.0943% │ 79 | │ 🐌 Audiomentations Band Stop Filter (dtype=float32) (batch_size=32) │ 472897.2912mcs │ 1477.8040mcs │ 47126485.8246mcs │ 1472.7027mcs │ 28.8541x │ 2785.4091% │ 80 | │ 🐌 Audiomentations Band Stop Filter (dtype=float32) (batch_size=16) │ 244487.5240mcs │ 1528.0470mcs │ 23845888.1378mcs │ 1490.3680mcs │ 29.2002x │ 2820.0200% │ 81 | │ 🐌 Audiomentations Band Stop Filter (dtype=float32) (batch_size=1) │ 17534.9712mcs │ 1753.4971mcs │ 1507139.4444mcs │ 1507.1394mcs │ 29.5288x │ 2852.8797% │ 82 | │ 🐌 Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=128) │ 3754051.2085mcs │ 2932.8525mcs │ 285680166.3971mcs │ 2231.8763mcs │ 43.7283x │ 4272.8284% │ 83 | │ 🐌 Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=32) │ 787430.5248mcs │ 2460.7204mcs │ 71895619.1273mcs │ 2246.7381mcs │ 44.0195x │ 4301.9465% │ 84 | │ 🐌 Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=64) │ 1466037.9887mcs │ 2290.6844mcs │ 146637201.9501mcs │ 2291.2063mcs │ 44.8907x │ 4389.0713% │ 85 | │ 🐌 Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=16) │ 419323.4444mcs │ 2620.7715mcs │ 40940346.0703mcs │ 2558.7716mcs │ 50.1330x │ 4913.3017% │ 86 | │ 🐌 Torch Audiomentations Band Stop Filter (torch.float32) (batch_size=1) │ 79014.5397mcs │ 7901.4540mcs │ 3099280.6987mcs │ 
3099.2807mcs │ 60.7230x │ 5972.3001% │ 87 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 88 | -------------------------------------------------------------------------------- /benchmark_local_result/high_pass_filter.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.high_pass_benchmark  ✔  52s   mc  09:15:56  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations High Pass Filter (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations High Pass Filter (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations High Pass Filter (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations High Pass Filter (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations High Pass Filter (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations High Pass Filter (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations High Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations High Pass Filter (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations High Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations High Pass Filter (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations High Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations High Pass Filter (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations High Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations High Pass Filter (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations High Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations High Pass Filter (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations High Pass Filter (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations High Pass Filter (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations High Pass Filter (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations High Pass Filter (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float32) (batch_size=32) │ 16178.3695mcs │ 50.5574mcs │ 1633060.0635mcs │ 51.0331mcs │ 1.0000x │ 0.0000% │ 68 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float16) (batch_size=32) │ 767281.0555mcs │ 2397.7533mcs │ 1666557.5351mcs │ 52.0799mcs │ 1.0205x │ 2.0512% │ 69 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float32) (batch_size=128) │ 13041.0194mcs │ 10.1883mcs │ 6849122.6883mcs │ 53.5088mcs │ 1.0485x │ 4.8511% │ 70 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float16) (batch_size=64) │ 738150.3582mcs │ 1153.3599mcs │ 3446843.9639mcs │ 53.8569mcs │ 1.0553x │ 5.5333% │ 71 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float32) (batch_size=64) │ 10715.2462mcs 
│ 16.7426mcs │ 3459167.3298mcs │ 54.0495mcs │ 1.0591x │ 5.9106% │ 72 | │ 🚀 Fast Audiomentations High Pass Filter (torch.float16) (batch_size=128) │ 9929131.5079mcs │ 7757.1340mcs │ 6973906.1165mcs │ 54.4836mcs │ 1.0676x │ 6.7613% │ 73 | │ ⚡ Fast Audiomentations High Pass Filter (torch.float32) (batch_size=16) │ 19281.1489mcs │ 120.5072mcs │ 982135.6156mcs │ 61.3835mcs │ 1.2028x │ 20.2816% │ 74 | │ ⚡ Fast Audiomentations High Pass Filter (torch.float16) (batch_size=16) │ 846808.6720mcs │ 5292.5542mcs │ 1031496.9629mcs │ 64.4686mcs │ 1.2633x │ 26.3269% │ 75 | │ 🐌 Fast Audiomentations High Pass Filter (torch.float32) (batch_size=1) │ 14634.3708mcs │ 1463.4371mcs │ 362972.6720mcs │ 362.9727mcs │ 7.1125x │ 611.2491% │ 76 | │ 🐌 Fast Audiomentations High Pass Filter (torch.float16) (batch_size=1) │ 707847.1184mcs │ 70784.7118mcs │ 387427.1669mcs │ 387.4272mcs │ 7.5917x │ 659.1680% │ 77 | │ 🐌 Torch Audiomentations High Pass Filter (torch.float32) (batch_size=128) │ 2006944.1795mcs │ 1567.9251mcs │ 83923237.0834mcs │ 655.6503mcs │ 12.8475x │ 1184.7543% │ 78 | │ 🐌 Torch Audiomentations High Pass Filter (torch.float32) (batch_size=64) │ 431875.4673mcs │ 674.8054mcs │ 41963267.5610mcs │ 655.6761mcs │ 12.8480x │ 1184.8048% │ 79 | │ 🐌 Torch Audiomentations High Pass Filter (torch.float32) (batch_size=32) │ 218618.3929mcs │ 683.1825mcs │ 21637137.0211mcs │ 676.1605mcs │ 13.2494x │ 1224.9443% │ 80 | │ 🐌 Torch Audiomentations High Pass Filter (torch.float32) (batch_size=16) │ 104858.1600mcs │ 655.3635mcs │ 11223090.5957mcs │ 701.4432mcs │ 13.7449x │ 1274.4860% │ 81 | │ 🐌 Torch Audiomentations High Pass Filter (torch.float32) (batch_size=1) │ 23466.1102mcs │ 2346.6110mcs │ 1062837.5374mcs │ 1062.8375mcs │ 20.8264x │ 1982.6424% │ 82 | │ 🐌 Audiomentations High Pass Filter (dtype=float32) (batch_size=128) │ 1825371.2654mcs │ 1426.0713mcs │ 165114495.7542mcs │ 1289.9570mcs │ 25.2769x │ 2427.6856% │ 83 | │ 🐌 Audiomentations High Pass Filter (dtype=float32) (batch_size=64) │ 
838546.5145mcs │ 1310.2289mcs │ 84661211.2522mcs │ 1322.8314mcs │ 25.9210x │ 2492.1034% │ 84 | │ 🐌 Audiomentations High Pass Filter (dtype=float32) (batch_size=32) │ 426318.1686mcs │ 1332.2443mcs │ 42397303.3428mcs │ 1324.9157mcs │ 25.9619x │ 2496.1876% │ 85 | │ 🐌 Audiomentations High Pass Filter (dtype=float32) (batch_size=1) │ 16426.0864mcs │ 1642.6086mcs │ 1326984.8824mcs │ 1326.9849mcs │ 26.0024x │ 2500.2422% │ 86 | │ 🐌 Audiomentations High Pass Filter (dtype=float32) (batch_size=16) │ 220810.8902mcs │ 1380.0681mcs │ 21543944.3588mcs │ 1346.4965mcs │ 26.3848x │ 2538.4754% │ 87 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 88 | -------------------------------------------------------------------------------- /benchmark_local_result/add_background_noise.txt: -------------------------------------------------------------------------------- 1 |    ~/code/fast-audiomentations  python3 -m benchmark.add_background_noise_benchmark  INT ✘  34s   mc  07:44:11  2 | /home/vitalii/.local/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call. 3 | torchaudio.set_audio_backend("soundfile") 4 | Running Audiomentations Add Background Noise (dtype=float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 5 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 6 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 7 | Running Audiomentations Add Background Noise (dtype=float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 8 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 9 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 10 | Running Audiomentations Add Background Noise (dtype=float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 11 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 12 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 13 | Running Audiomentations Add Background Noise (dtype=float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 14 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 15 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 16 | Running Audiomentations Add Background Noise (dtype=float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 17 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 18 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 19 | Running Fast Audiomentations Add Background Noise (torch.float16) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 20 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 21 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 22 | Running Fast Audiomentations Add Background Noise (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 23 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 24 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 25 | Running Fast Audiomentations Add Background Noise (torch.float16) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 26 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 27 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 28 | Running Fast Audiomentations Add Background Noise (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 29 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 30 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 31 | Running Fast Audiomentations Add Background Noise (torch.float16) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 32 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 33 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 34 | Running Fast Audiomentations Add Background Noise (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 35 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 36 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 37 | Running Fast Audiomentations Add Background Noise (torch.float16) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 38 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 39 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 40 | Running Fast Audiomentations Add Background Noise (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 41 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 42 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 43 | Running Fast Audiomentations Add Background Noise (torch.float16) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 44 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 45 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 46 | Running Fast Audiomentations Add Background Noise (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 47 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 48 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 49 | Running Torch Audiomentations Add Background Noise (torch.float32) (batch_size=128)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 50 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 51 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 52 | Running Torch Audiomentations Add Background Noise (torch.float32) (batch_size=64)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 53 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 54 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 55 | Running Torch Audiomentations Add Background Noise (torch.float32) (batch_size=32)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 56 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 57 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 58 | Running Torch Audiomentations Add Background Noise (torch.float32) (batch_size=16)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 59 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 60 | Running... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 61 | Running Torch Audiomentations Add Background Noise (torch.float32) (batch_size=1)... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 62 | Warmup... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 63 | Running... 
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 100% 0:00:01 64 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ 65 | ┃ Suite ┃ Warmup Time Total (mcs) ┃ Warmup Time Per Sample (mcs) ┃ Run Time Total (mcs) ┃ Run Time Per Sample (mcs) ┃ Relative Slowdown ┃ Percentage Slower (%) ┃ 66 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ 67 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float16) │ 236064.6725mcs │ 737.7021mcs │ 6740835.8175mcs │ 702.1704mcs │ 1.0000x │ 0.0000% │ 68 | │ (batch_size=32) │ │ │ │ │ │ │ 69 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float32) │ 232386.1122mcs │ 726.2066mcs │ 6749651.6422mcs │ 703.0887mcs │ 1.0013x │ 0.1308% │ 70 | │ (batch_size=32) │ │ │ │ │ │ │ 71 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float16) │ 265414.2380mcs │ 414.7097mcs │ 13878056.1001mcs │ 722.8154mcs │ 1.0294x │ 2.9402% │ 72 | │ (batch_size=64) │ │ │ │ │ │ │ 73 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float16) │ 161287.0693mcs │ 1008.0442mcs │ 3504929.4066mcs │ 730.1936mcs │ 1.0399x │ 3.9909% │ 74 | │ (batch_size=16) │ │ │ │ │ │ │ 75 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float32) │ 261575.9373mcs │ 408.7124mcs │ 14034798.0528mcs │ 730.9791mcs │ 1.0410x │ 4.1028% │ 76 | │ (batch_size=64) │ │ │ │ │ │ │ 77 | │ 🚀 Fast Audiomentations Add Background Noise (torch.float32) │ 161105.3944mcs │ 1006.9087mcs │ 3592804.6629mcs │ 748.5010mcs │ 1.0660x │ 6.5982% │ 78 | │ (batch_size=16) │ │ │ │ │ │ │ 79 | │ ⚡ Fast Audiomentations Add Background Noise (torch.float32) │ 738246.4409mcs │ 576.7550mcs │ 30054204.7977mcs │ 782.6616mcs │ 1.1146x │ 11.4632% │ 80 | │ 
(batch_size=128) │ │ │ │ │ │ │ 81 | │ ⚡ Fast Audiomentations Add Background Noise (torch.float16) │ 3475562.0956mcs │ 2715.2829mcs │ 31850917.7384mcs │ 829.4510mcs │ 1.1813x │ 18.1267% │ 82 | │ (batch_size=128) │ │ │ │ │ │ │ 83 | │ ⚡ Fast Audiomentations Add Background Noise (torch.float32) (batch_size=1) │ 62872.1714mcs │ 6287.2171mcs │ 417419.9675mcs │ 1391.3999mcs │ 1.9816x │ 98.1570% │ 84 | │ ⚡ Fast Audiomentations Add Background Noise (torch.float16) (batch_size=1) │ 62833.0708mcs │ 6283.3071mcs │ 418665.9207mcs │ 1395.5531mcs │ 1.9875x │ 98.7485% │ 85 | │ 🐢 Audiomentations Add Background Noise (dtype=float32) (batch_size=1) │ 79497.0989mcs │ 7949.7099mcs │ 643973.5889mcs │ 2146.5786mcs │ 3.0571x │ 205.7062% │ 86 | │ 🐢 Audiomentations Add Background Noise (dtype=float32) (batch_size=16) │ 453288.3167mcs │ 2833.0520mcs │ 10833084.5833mcs │ 2256.8926mcs │ 3.2142x │ 221.4167% │ 87 | │ 🐢 Audiomentations Add Background Noise (dtype=float32) (batch_size=128) │ 3710981.8459mcs │ 2899.2046mcs │ 87375024.7955mcs │ 2275.3913mcs │ 3.2405x │ 224.0512% │ 88 | │ 🐢 Audiomentations Add Background Noise (dtype=float32) (batch_size=64) │ 1544400.4536mcs │ 2413.1257mcs │ 45507883.7872mcs │ 2370.2023mcs │ 3.3755x │ 237.5537% │ 89 | │ 🐢 Audiomentations Add Background Noise (dtype=float32) (batch_size=32) │ 825649.7383mcs │ 2580.1554mcs │ 22923871.2788mcs │ 2387.9033mcs │ 3.4007x │ 240.0746% │ 90 | │ 🐌 Torch Audiomentations Add Background Noise (torch.float32) │ 17530441.2842mcs │ 27391.3145mcs │ 526497913.2080mcs │ 27421.7663mcs │ 39.0529x │ 3805.2866% │ 91 | │ (batch_size=64) │ │ │ │ │ │ │ 92 | │ 🐌 Torch Audiomentations Add Background Noise (torch.float32) │ 35179584.2648mcs │ 27484.0502mcs │ 1058446146.9727mcs │ 27563.7017mcs │ 39.2550x │ 3825.5004% │ 93 | │ (batch_size=128) │ │ │ │ │ │ │ 94 | │ 🐌 Torch Audiomentations Add Background Noise (torch.float32) │ 8570765.2569mcs │ 26783.6414mcs │ 265460135.6812mcs │ 27652.0975mcs │ 39.3809x │ 3838.0893% │ 95 | │ (batch_size=32) │ │ │ 
│ │ │ │ 96 | │ 🐌 Torch Audiomentations Add Background Noise (torch.float32) │ 4452070.4746mcs │ 27825.4405mcs │ 132781658.1421mcs │ 27662.8454mcs │ 39.3962x │ 3839.6200% │ 97 | │ (batch_size=16) │ │ │ │ │ │ │ 98 | │ 🐌 Torch Audiomentations Add Background Noise (torch.float32) │ 433716.0587mcs │ 43371.6059mcs │ 8526919.4756mcs │ 28423.0649mcs │ 40.4789x │ 3947.8871% │ 99 | │ (batch_size=1) │ │ │ │ │ │ │ 100 | └─────────────────────────────────────────────────────────────────────────────┴─────────────────────────┴──────────────────────────────┴──────────────────────┴───────────────────────────┴───────────────────┴───────────────────────┘ 101 | --------------------------------------------------------------------------------