├── tests ├── __init__.py ├── datasets │ ├── __init__.py │ └── test_audio.py ├── modeling │ ├── __init__.py │ ├── test_pretrained.py │ ├── test_unet1d_model.py │ ├── test_modules.py │ ├── test_lightning_module.py │ └── test_waveunet_model.py ├── assets │ └── reverb │ │ └── rir_sample.wav ├── test_metrics.py └── test_transforms.py ├── src └── denoisers │ ├── scripts │ ├── __init__.py │ ├── publish.py │ └── train.py │ ├── datasets │ ├── __init__.py │ └── audio.py │ ├── modeling │ ├── unet1d │ │ ├── __init__.py │ │ ├── config.py │ │ ├── model.py │ │ └── modules.py │ ├── waveunet │ │ ├── __init__.py │ │ ├── config.py │ │ ├── modules.py │ │ └── model.py │ ├── __init__.py │ └── modules.py │ ├── __init__.py │ ├── testing │ └── __init__.py │ ├── metrics.py │ ├── datamodule.py │ ├── utils.py │ ├── losses.py │ ├── lightning_module.py │ └── transforms.py ├── setup.py ├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .pre-commit-config.yaml ├── pyproject.toml ├── .gitignore ├── requirements.txt ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for denoisers.""" 2 | -------------------------------------------------------------------------------- /src/denoisers/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | """Scripts.""" 2 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Test datasets.""" 2 | -------------------------------------------------------------------------------- /tests/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | """Test modeling.""" 2 | -------------------------------------------------------------------------------- /src/denoisers/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataset classes for the denoisers.""" 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Legacy support for setup.py based installs.""" 2 | from setuptools import setup 3 | 4 | setup() 5 | -------------------------------------------------------------------------------- /tests/assets/reverb/rir_sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/will-rice/denoisers/HEAD/tests/assets/reverb/rir_sample.wav -------------------------------------------------------------------------------- /src/denoisers/modeling/unet1d/__init__.py: -------------------------------------------------------------------------------- 1 | """UNet1D model.""" 2 | 3 | from denoisers.modeling.unet1d.config import UNet1DConfig 4 | from denoisers.modeling.unet1d.model import UNet1DModel 5 | 6 | __all__ = ["UNet1DConfig", "UNet1DModel"] 7 | -------------------------------------------------------------------------------- /src/denoisers/modeling/waveunet/__init__.py: -------------------------------------------------------------------------------- 1 | """WaveUnet model.""" 2 | 3 | from denoisers.modeling.waveunet.config import WaveUNetConfig 4 | from denoisers.modeling.waveunet.model import WaveUNetModel 5 | 6 | __all__ = ["WaveUNetConfig", "WaveUNetModel"] 7 | 
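For orientation before the per-file listings: both model families ship as a config/model pair with the same interface, and the checkpoints exercised in tests/modeling/test_pretrained.py can be loaded directly. A minimal sketch (the hub id is taken from those tests; random noise stands in for a real 24 kHz recording, and the `.audio` field comes from the model output classes defined in src/denoisers/modeling/):

import torch

from denoisers import WaveUNetModel

# Load published weights and denoise a (batch, channel, samples) waveform.
model = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz").eval()
noisy = torch.randn(1, 1, 24000)  # one second at 24 kHz
with torch.no_grad():
    denoised = model(noisy).audio

The same pattern applies to UNet1DModel, which the package root re-exports alongside WaveUNetModel.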
-------------------------------------------------------------------------------- /src/denoisers/__init__.py: -------------------------------------------------------------------------------- 1 | """Denoisers for the 1D and 2D cases.""" 2 | 3 | from denoisers.modeling.unet1d.config import UNet1DConfig 4 | from denoisers.modeling.unet1d.model import UNet1DModel 5 | from denoisers.modeling.waveunet.config import WaveUNetConfig 6 | from denoisers.modeling.waveunet.model import WaveUNetModel 7 | 8 | __all__ = ["WaveUNetConfig", "WaveUNetModel", "UNet1DConfig", "UNet1DModel"] 9 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | """Tests for metrics.""" 2 | 3 | from denoisers.metrics import PESQ 4 | from denoisers.testing import sine_wave 5 | 6 | 7 | def test_calculate_pesq() -> None: 8 | """Test calculate_pesq.""" 9 | pesq = PESQ(sample_rate=16000) 10 | audio = sine_wave(800, 1, 16000)[None] 11 | pred = sine_wave(800, 1, 16000)[None] 12 | score = pesq(pred, audio) 13 | assert score > 0.0 14 | -------------------------------------------------------------------------------- /src/denoisers/testing/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for tests.""" 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def sine_wave(frequency: float, duration: float, sample_rate: int) -> Tensor: 8 | """Generate a sine wave. 9 | 10 | Args: 11 | ---- 12 | frequency: Frequency of the sine wave. 13 | duration: Duration of the sine wave in seconds. 14 | sample_rate: Sample rate of the sine wave. 15 | 16 | Returns: 17 | ------- 18 | A torch tensor containing the sine wave. 
19 | """ 20 | return torch.sin( 21 | 2 * torch.pi * torch.arange(sample_rate * duration) * frequency / sample_rate, 22 | ).unsqueeze(0) 23 | -------------------------------------------------------------------------------- /src/denoisers/scripts/publish.py: -------------------------------------------------------------------------------- 1 | """Publish model script.""" 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | from denoisers.modeling import CONFIGS, MODELS 7 | 8 | 9 | def main() -> None: 10 | """Run publish.""" 11 | parser = argparse.ArgumentParser("Publish model to huggingface hub.") 12 | parser.add_argument("model", type=str, choices=MODELS.keys()) 13 | parser.add_argument("name", type=str) 14 | parser.add_argument("path", type=Path) 15 | args = parser.parse_args() 16 | 17 | model = MODELS[args.model](CONFIGS[args.model]()) 18 | model.from_pretrained(args.path) 19 | model.push_to_hub(args.name) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /src/denoisers/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | """Models.""" 2 | 3 | from typing import Any 4 | 5 | from denoisers.modeling.unet1d.config import UNet1DConfig 6 | from denoisers.modeling.unet1d.model import UNet1DModel 7 | from denoisers.modeling.waveunet.config import WaveUNetConfig 8 | from denoisers.modeling.waveunet.model import WaveUNetModel 9 | 10 | __all__ = [ 11 | "WaveUNetConfig", 12 | "WaveUNetModel", 13 | "UNet1DConfig", 14 | "UNet1DModel", 15 | "MODELS", 16 | "CONFIGS", 17 | ] 18 | 19 | MODELS: dict[str, Any] = { 20 | "unet1d": UNet1DModel, 21 | "waveunet": WaveUNetModel, 22 | } # Add your models here 23 | CONFIGS: dict[str, Any] = { 24 | "unet1d": UNet1DConfig, 25 | "waveunet": WaveUNetConfig, 26 | } # Add your configs here 27 | -------------------------------------------------------------------------------- /tests/datasets/test_audio.py: -------------------------------------------------------------------------------- 1 | """Test audio datasets.""" 2 | 3 | from pathlib import Path 4 | 5 | import torchaudio 6 | 7 | from denoisers.datasets.audio import AudioDataset 8 | from denoisers.testing import sine_wave 9 | 10 | 11 | def test_audio_dataset(tmpdir): 12 | """Test audio dataset.""" 13 | save_root = Path(tmpdir) / "test_dataset" 14 | save_root.mkdir(exist_ok=True, parents=True) 15 | save_path = save_root / "sample.flac" 16 | 17 | audio = sine_wave(800, 1, 16000) 18 | torchaudio.save(str(save_path), audio, 16000) 19 | 20 | dataset = AudioDataset(save_path.parent, sample_rate=16000, max_length=16384) 21 | assert len(dataset) == 1 22 | batch = dataset[0] 23 | assert batch.audio.shape == (1, 16384) 24 | assert batch.noisy.shape == (1, 16384) 25 | assert batch.lengths.shape == () 26 | -------------------------------------------------------------------------------- /tests/modeling/test_pretrained.py: -------------------------------------------------------------------------------- 1 | """Test HuggingFace Hub pretrained models.""" 2 | 3 | from denoisers import UNet1DModel 4 | from denoisers.modeling import WaveUNetModel 5 | 6 | 7 | def test_waveunet_vctk_24khz(): 8 | """Test WaveUNet 24kHz VCTK.""" 9 | model = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz") 10 | assert model.config.sample_rate == 24000 11 | assert model.config.norm_type == "batch" 12 | 13 | 14 | def test_waveunet_vctk_48khz(): 15 | """Test WaveUNet 48kHz VCTK.""" 16 | model = 
WaveUNetModel.from_pretrained("wrice/waveunet-vctk-48khz") 17 | assert model.config.sample_rate == 48000 18 | assert model.config.norm_type == "batch" 19 | 20 | 21 | def test_unet1d_vctk_48khz(): 22 | """Test UNet1D 48kHz VCTK.""" 23 | model = UNet1DModel.from_pretrained("wrice/unet1d-vctk-48khz") 24 | assert model.config.sample_rate == 48000 25 | assert model.config.norm_type == "layer" 26 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: ["main"] 9 | pull_request: 10 | branches: ["main"] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.11", "3.12", "3.13"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install uv 28 | run: | 29 | curl -LsSf https://astral.sh/uv/install.sh | sh 30 | - name: Run pre-commit hook 31 | run: | 32 | uv run pre-commit run -a 33 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: check-ast 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - id: check-merge-conflict 9 | - id: requirements-txt-fixer 10 | - repo: https://github.com/hukkin/mdformat 11 | rev: 0.7.22 # Use the desired version 12 | hooks: 13 | - id: mdformat 14 | # Optionally add plugins for extended functionality, e.g., GitHub Flavored Markdown (GFM) 15 | additional_dependencies: 16 | - mdformat-gfm 17 | - mdformat-black # For black-style formatting within code blocks 18 | - repo: local 19 | hooks: 20 | - id: ruff 21 | name: ruff 22 | language: system 23 | entry: uv run ruff check . --fix 24 | pass_filenames: false 25 | always_run: true 26 | - id: mypy 27 | name: mypy 28 | language: system 29 | entry: uv run mypy . 30 | pass_filenames: false 31 | always_run: true 32 | - id: pytest 33 | name: pytest 34 | language: system 35 | entry: uv run python -m pytest 36 | pass_filenames: false 37 | always_run: true 38 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 
9 | name: Upload Python Package
10 | 
11 | on:
12 |   release:
13 |     types: [published]
14 | 
15 | permissions:
16 |   contents: read
17 | 
18 | jobs:
19 |   deploy:
20 |     runs-on: ubuntu-latest
21 | 
22 |     steps:
23 |       - uses: actions/checkout@v3
24 |       - name: Set up Python
25 |         uses: actions/setup-python@v3
26 |         with:
27 |           python-version: "3.x"
28 |       - name: Install dependencies
29 |         run: |
30 |           python -m pip install --upgrade pip
31 |           pip install build
32 |       - name: Build package
33 |         run: python -m build
34 |       - name: Publish package
35 |         uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 |         with:
37 |           user: __token__
38 |           password: ${{ secrets.PYPI_API_TOKEN }}
39 | 
--------------------------------------------------------------------------------
/tests/modeling/test_unet1d_model.py:
--------------------------------------------------------------------------------
1 | """Tests for UNet1D model."""
2 | 
3 | import torch
4 | 
5 | from denoisers.modeling.unet1d.model import UNet1DConfig, UNet1DModel
6 | 
7 | 
8 | def test_config():
9 |     """Test config."""
10 |     config = UNet1DConfig(
11 |         max_length=8192,
12 |         sample_rate=16000,
13 |         channels=(1, 2, 3, 4, 5, 6),
14 |         kernel_size=3,
15 |         dropout=0.1,
16 |         activation="silu",
17 |         autoencoder=False,
18 |     )
19 |     assert config.max_length == 8192
20 |     assert config.sample_rate == 16000
21 |     assert config.channels == (1, 2, 3, 4, 5, 6)
22 |     assert config.kernel_size == 3
23 |     assert config.dropout == 0.1
24 |     assert config.activation == "silu"
25 |     assert config.autoencoder is False
26 | 
27 | 
28 | def test_model() -> None:
29 |     """Test model."""
30 |     config = UNet1DConfig(
31 |         max_length=16384,
32 |         sample_rate=16000,
33 |         channels=(2, 4, 6, 8),
34 |         kernel_size=3,
35 |         num_groups=2,
36 |     )
37 |     model = UNet1DModel(config)
38 |     model.eval()
39 | 
40 |     audio = torch.randn(1, 1, config.max_length)
41 |     with torch.no_grad():
42 |         recon = model(audio).audio
43 | 
44 |     assert isinstance(recon, torch.Tensor)
45 |     assert audio.shape == recon.shape
46 | 
--------------------------------------------------------------------------------
/src/denoisers/modeling/unet1d/config.py:
--------------------------------------------------------------------------------
1 | """UNet1D configuration file."""
2 | 
3 | from typing import Any, Optional
4 | 
5 | from transformers import PretrainedConfig
6 | 
7 | 
8 | class UNet1DConfig(PretrainedConfig):
9 |     """Configuration class to store the configuration of a `UNet1DModel`."""
10 | 
11 |     model_type = "unet1d"
12 | 
13 |     def __init__(
14 |         self,
15 |         channels: tuple[int, ...] 
= ( 16 | 32, 17 | 64, 18 | 96, 19 | 128, 20 | 160, 21 | 192, 22 | 224, 23 | 256, 24 | 288, 25 | 320, 26 | 352, 27 | 384, 28 | ), 29 | kernel_size: int = 3, 30 | num_groups: Optional[int] = None, 31 | dropout: float = 0.1, 32 | activation: str = "silu", 33 | autoencoder: bool = True, 34 | max_length: int = 48000, 35 | sample_rate: int = 48000, 36 | norm_type: str = "layer", 37 | **kwargs: Any, 38 | ) -> None: 39 | self.channels = channels 40 | self.kernel_size = kernel_size 41 | self.num_groups = num_groups 42 | self.dropout = dropout 43 | self.activation = activation 44 | self.autoencoder = autoencoder 45 | self.norm_type = norm_type 46 | super().__init__(**kwargs, max_length=max_length, sample_rate=sample_rate) 47 | -------------------------------------------------------------------------------- /src/denoisers/modeling/waveunet/config.py: -------------------------------------------------------------------------------- 1 | """WaveUNet configuration file.""" 2 | 3 | from typing import Any, Optional 4 | 5 | from transformers import PretrainedConfig 6 | 7 | 8 | class WaveUNetConfig(PretrainedConfig): 9 | """Configuration class to store the configuration of a `WaveUNetModel`.""" 10 | 11 | model_type = "waveunet" 12 | 13 | def __init__( 14 | self, 15 | in_channels: tuple[int, ...] = ( 16 | 24, 17 | 48, 18 | 72, 19 | 96, 20 | 120, 21 | 144, 22 | 168, 23 | 192, 24 | 216, 25 | 240, 26 | 264, 27 | 288, 28 | ), 29 | downsample_kernel_size: int = 15, 30 | upsample_kernel_size: int = 5, 31 | dropout: float = 0.1, 32 | activation: str = "leaky_relu", 33 | autoencoder: bool = False, 34 | max_length: int = 16384 * 10, 35 | sample_rate: int = 48000, 36 | norm_type: str = "batch", 37 | num_groups: Optional[int] = None, 38 | **kwargs: Any, 39 | ) -> None: 40 | self.in_channels = in_channels 41 | self.downsample_kernel_size = downsample_kernel_size 42 | self.upsample_kernel_size = upsample_kernel_size 43 | self.dropout = dropout 44 | self.activation = activation 45 | self.autoencoder = autoencoder 46 | self.norm_type = norm_type 47 | self.num_groups = num_groups 48 | super().__init__(**kwargs, max_length=max_length, sample_rate=sample_rate) 49 | -------------------------------------------------------------------------------- /src/denoisers/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for denoising.""" 2 | 3 | import torch 4 | import torchaudio 5 | from torch import nn 6 | from torchmetrics.functional.audio.dnsmos import ( 7 | deep_noise_suppression_mean_opinion_score, 8 | ) 9 | from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality 10 | 11 | 12 | class PESQ(nn.Module): 13 | """PESQ metric.""" 14 | 15 | def __init__(self, sample_rate: int = 48000): 16 | super().__init__() 17 | self.resample = torchaudio.transforms.Resample(sample_rate, 16000) 18 | 19 | def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 20 | """Forward pass.""" 21 | pred_resample = self.resample(preds) 22 | true_resample = self.resample(target) 23 | 24 | shortest = min(pred_resample.shape[-1], true_resample.shape[-1]) 25 | pred_resample = pred_resample[:, :, :shortest].squeeze(1) 26 | true_resample = true_resample[:, :, :shortest].squeeze(1) 27 | 28 | score = perceptual_evaluation_speech_quality( 29 | pred_resample, true_resample, fs=16000, mode="wb" 30 | ) 31 | 32 | return score.mean() 33 | 34 | 35 | class DNSMOS(nn.Module): 36 | """DNSMOS metric.""" 37 | 38 | def __init__(self, sample_rate: int = 48000): 39 | 
super().__init__() 40 | self.sample_rate = sample_rate 41 | 42 | def forward(self, preds: torch.Tensor) -> torch.Tensor: 43 | """Forward pass.""" 44 | pred_resample = torchaudio.functional.resample(preds, self.sample_rate, 16000) 45 | score = deep_noise_suppression_mean_opinion_score(pred_resample, 16000, False) 46 | return score.mean() 47 | -------------------------------------------------------------------------------- /src/denoisers/datamodule.py: -------------------------------------------------------------------------------- 1 | """Lightning datamodule.""" 2 | 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | from pytorch_lightning import LightningDataModule 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | class DenoisersDataModule(LightningDataModule): 12 | """Lightning DataModule for denoisers.""" 13 | 14 | def __init__( 15 | self, dataset: Dataset, batch_size: int = 24, num_workers: int = 8 16 | ) -> None: 17 | super().__init__() 18 | 19 | self._dataset = dataset 20 | self._batch_size = batch_size 21 | self._num_workers = num_workers 22 | 23 | def setup(self, stage: Optional[str] = "fit") -> None: 24 | """Split datasets.""" 25 | train_split = int(np.floor(len(self._dataset) * 0.95)) # type: ignore 26 | val_split = int(np.ceil(len(self._dataset) * 0.05)) # type: ignore 27 | 28 | self.train_dataset, self.val_dataset = torch.utils.data.random_split( 29 | self._dataset, 30 | lengths=(train_split, val_split), 31 | ) 32 | 33 | def train_dataloader(self) -> DataLoader: 34 | """Initialize train dataloader.""" 35 | return DataLoader( 36 | self.train_dataset, 37 | batch_size=self._batch_size, 38 | num_workers=self._num_workers, 39 | shuffle=True, 40 | ) 41 | 42 | def val_dataloader(self) -> DataLoader: 43 | """Initialize validation dataloader.""" 44 | return DataLoader( 45 | self.val_dataset, 46 | batch_size=self._batch_size, 47 | num_workers=self._num_workers, 48 | shuffle=False, 49 | drop_last=True, 50 | ) 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=78.1.1", 4 | ] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [project] 8 | name = "denoisers" 9 | version = "0.2.0" 10 | authors = [ 11 | { name="Will Rice", email="wrice20@gmail.com" }, 12 | ] 13 | description = "A package for training audio denoisers" 14 | readme = "README.md" 15 | requires-python = ">=3.9" 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | ] 21 | dependencies = [ 22 | "mypy>=1.13.0", 23 | "pydocstyle>=6.3.0", 24 | "pytest>=8.3.3", 25 | "ruff>=0.7.4", 26 | "wandb>=0.18.7", 27 | "matplotlib>=3.9.2", 28 | "pedalboard>=0.9.16", 29 | "pydub>=0.25.1", 30 | "pyroomacoustics>=0.8.2", 31 | "pre-commit>=4.0.1", 32 | "librosa>=0.10.2.post1", 33 | "audiomentations>=0.37.0", 34 | "onnxruntime>=1.20.1", 35 | "pesq>=0.0.4", 36 | "torch>=2.5.1", 37 | "torchaudio>=2.5.1", 38 | "transformers>=4.48.0", 39 | "torchvision>=0.20.1", 40 | "pytorch-lightning>=2.5.0.post0", 41 | ] 42 | 43 | [project.urls] 44 | "Homepage" = "https://github.com/will-rice/denoisers" 45 | "Bug Tracker" = "https://github.com/will-rice/denoisers/issues" 46 | 47 | [tool.ruff.lint.isort] 48 | known-first-party = ["denoisers"] 49 | 50 | [tool.ruff.lint] 51 | select = ["C", "E", "F", "I", "W", "D", "N", "B"] 52 | 
ignore = ["D107"] 53 | exclude = [".venv"] 54 | 55 | [tool.ruff.lint.pydocstyle] 56 | convention = "google" 57 | 58 | [tool.pytest.ini_options] 59 | filterwarnings = [ 60 | "ignore::DeprecationWarning", 61 | ] 62 | 63 | [tool.mypy] 64 | ignore_missing_imports = true 65 | follow_imports_for_stubs = true 66 | strict = false 67 | exclude = [".venv"] 68 | 69 | [project.scripts] 70 | train = "denoisers.scripts.train:main" 71 | publish = "denoisers.scripts.publish:main" 72 | -------------------------------------------------------------------------------- /tests/modeling/test_modules.py: -------------------------------------------------------------------------------- 1 | """Tests for modules.""" 2 | 3 | import torch 4 | 5 | from denoisers.modeling.modules import Activation, Downsample1D, Upsample1D 6 | 7 | 8 | def test_downsample_1d(): 9 | """Test downsample 1d.""" 10 | downsample = Downsample1D(1, 2, 3, 2, True, 1, True) 11 | 12 | assert isinstance(downsample, Downsample1D) 13 | assert isinstance(downsample.conv, torch.nn.Conv1d) 14 | assert downsample.conv.in_channels == 1 15 | assert downsample.conv.out_channels == 2 16 | assert downsample.conv.kernel_size == (3,) 17 | assert downsample.conv.stride == (2,) 18 | assert downsample.conv.padding == (1,) 19 | assert downsample.conv.bias is not None 20 | 21 | audio = torch.randn(1, 1, 800) 22 | out = downsample(audio) 23 | 24 | assert out.shape == (1, 2, 400) 25 | 26 | 27 | def test_upsample_1d(): 28 | """Test upsample 1d.""" 29 | upsample = Upsample1D(in_channels=1, out_channels=2, kernel_size=3, use_conv=True) 30 | 31 | assert isinstance(upsample, Upsample1D) 32 | assert isinstance(upsample.conv, torch.nn.Conv1d) 33 | assert upsample.conv.in_channels == 1 34 | assert upsample.conv.out_channels == 2 35 | assert upsample.conv.kernel_size == (3,) 36 | assert upsample.conv.stride == (1,) 37 | assert upsample.conv.padding == (1,) 38 | assert upsample.conv.bias is not None 39 | 40 | audio = torch.randn(1, 1, 800) 41 | out = upsample(audio) 42 | 43 | assert out.shape == (1, 2, 1600) 44 | 45 | 46 | def test_activation(): 47 | """Test activation.""" 48 | activation = Activation("relu") 49 | 50 | assert isinstance(activation, Activation) 51 | assert isinstance(activation.activation, torch.nn.ReLU) 52 | 53 | audio = torch.randn(1, 1, 800) 54 | out = activation(audio) 55 | 56 | assert out.shape == (1, 1, 800) 57 | 58 | # test leaky relu 59 | activation = Activation("leaky_relu") 60 | 61 | assert isinstance(activation, Activation) 62 | assert isinstance(activation.activation, torch.nn.LeakyReLU) 63 | 64 | audio = torch.randn(1, 1, 800) 65 | out = activation(audio) 66 | 67 | assert out.shape == (1, 1, 800) 68 | 69 | # test silu 70 | activation = Activation("silu") 71 | 72 | assert isinstance(activation, Activation) 73 | assert isinstance(activation.activation, torch.nn.SiLU) 74 | 75 | audio = torch.randn(1, 1, 800) 76 | out = activation(audio) 77 | 78 | assert out.shape == (1, 1, 800) 79 | -------------------------------------------------------------------------------- /tests/modeling/test_lightning_module.py: -------------------------------------------------------------------------------- 1 | """Test lightning module.""" 2 | 3 | import torch 4 | 5 | from denoisers import UNet1DConfig, UNet1DModel, WaveUNetConfig, WaveUNetModel 6 | from denoisers.datasets.audio import Batch 7 | from denoisers.lightning_module import DenoisersLightningModule 8 | 9 | 10 | def test_waveunet_lightning_module() -> None: 11 | """Test waveunet lightning module.""" 12 | config = 
WaveUNetConfig(
13 |         max_length=16384,
14 |         sample_rate=16000,
15 |         in_channels=(1, 2, 3),
16 |         downsample_kernel_size=3,
17 |         upsample_kernel_size=3,
18 |     )
19 |     model = WaveUNetModel(config)
20 |     lightning_module = DenoisersLightningModule(model)
21 | 
22 |     audio = torch.randn(1, 1, config.max_length)
23 |     batch = Batch(audio=audio, noisy=audio, lengths=torch.tensor([audio.shape[-1]]))
24 | 
25 |     # test forward
26 |     with torch.no_grad():
27 |         recon = lightning_module(audio).audio
28 | 
29 |     assert isinstance(recon, torch.Tensor)
30 |     assert audio.shape == recon.shape
31 | 
32 |     # test training step
33 |     loss = lightning_module.training_step(batch, 0)
34 |     assert isinstance(loss, torch.Tensor)
35 |     assert loss.shape == torch.Size([])
36 | 
37 |     # test validation step
38 |     loss = lightning_module.validation_step(batch, 0)
39 |     assert isinstance(loss, torch.Tensor)
40 |     assert loss.shape == torch.Size([])
41 | 
42 | 
43 | def test_unet1d_lightning_module() -> None:
44 |     """Test unet1d lightning module."""
45 |     config = UNet1DConfig(
46 |         max_length=16384,
47 |         sample_rate=16000,
48 |         channels=(2, 4, 6, 8),
49 |         kernel_size=3,
50 |         num_groups=2,
51 |     )
52 |     model = UNet1DModel(config)
53 |     lightning_module = DenoisersLightningModule(model)
54 | 
55 |     audio = torch.randn(1, 1, config.max_length)
56 |     batch = Batch(audio=audio, noisy=audio, lengths=torch.tensor([audio.shape[-1]]))
57 | 
58 |     # test forward
59 |     with torch.no_grad():
60 |         recon = lightning_module(audio).audio
61 | 
62 |     assert isinstance(recon, torch.Tensor)
63 |     assert audio.shape == recon.shape
64 | 
65 |     # test training step
66 |     loss = lightning_module.training_step(batch, 0)
67 |     assert isinstance(loss, torch.Tensor)
68 |     assert loss.shape == torch.Size([])
69 | 
70 |     # test validation step
71 |     loss = lightning_module.validation_step(batch, 0)
72 |     assert isinstance(loss, torch.Tensor)
73 |     assert loss.shape == torch.Size([])
74 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Some extras 132 | wandb/ 133 | logs/ 134 | 135 | # ruff 136 | .ruff_cache/ 137 | 138 | # mac 139 | .DS_Store 140 | -------------------------------------------------------------------------------- /src/denoisers/modeling/waveunet/modules.py: -------------------------------------------------------------------------------- 1 | """WaveUnet modules.""" 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from denoisers.modeling.modules import ( 9 | Activation, 10 | Downsample1D, 11 | Normalization, 12 | Upsample1D, 13 | ) 14 | 15 | 16 | class DownsampleBlock1D(nn.Module): 17 | """1d downsample block.""" 18 | 19 | def __init__( 20 | self, 21 | in_channels: int, 22 | out_channels: int, 23 | kernel_size: int = 3, 24 | stride: int = 2, 25 | dropout: float = 0.0, 26 | activation: str = "leaky_relu", 27 | bias: bool = True, 28 | num_groups: Optional[int] = None, 29 | norm_type: str = "batch", 30 | ) -> None: 31 | super().__init__() 32 | 33 | self.downsample = Downsample1D( 34 | in_channels, 35 | out_channels=out_channels, 36 | kernel_size=kernel_size, 37 | stride=stride, 38 | use_conv=True, 39 | padding=kernel_size // 2, 40 | bias=bias, 41 | ) 42 | self.norm = Normalization(out_channels, num_groups=num_groups, name=norm_type) 43 | self.activation = Activation(activation) 44 | self.dropout = nn.Dropout(dropout) 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | """Forward pass.""" 48 | x = self.downsample(x) 49 | x = self.norm(x) 50 | x = self.activation(x) 51 | x = self.dropout(x) 52 | return x 53 | 54 | 55 | class UpsampleBlock1D(nn.Module): 56 | """1d upsample block.""" 57 | 58 | def __init__( 59 | self, 60 | in_channels: int, 61 | out_channels: int, 
62 | kernel_size: int = 3, 63 | dropout: float = 0.0, 64 | activation: str = "leaky_relu", 65 | bias: bool = True, 66 | norm_type: str = "batch", 67 | num_groups: Optional[int] = None, 68 | ) -> None: 69 | super().__init__() 70 | self.upsample = Upsample1D( 71 | in_channels, 72 | out_channels=out_channels, 73 | kernel_size=kernel_size, 74 | use_conv=True, 75 | padding=kernel_size // 2, 76 | bias=bias, 77 | ) 78 | self.norm = Normalization(out_channels, num_groups=num_groups, name=norm_type) 79 | self.activation = Activation(activation) 80 | self.dropout = nn.Dropout(dropout) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | """Forward pass.""" 84 | x = self.upsample(x) 85 | x = self.norm(x) 86 | x = self.activation(x) 87 | x = self.dropout(x) 88 | return x 89 | -------------------------------------------------------------------------------- /tests/modeling/test_waveunet_model.py: -------------------------------------------------------------------------------- 1 | """Tests for WaveUNet model.""" 2 | 3 | import torch 4 | 5 | from denoisers.modeling.modules import Downsample1D, Normalization, Upsample1D 6 | from denoisers.modeling.waveunet.model import WaveUNetConfig, WaveUNetModel 7 | from denoisers.modeling.waveunet.modules import DownsampleBlock1D, UpsampleBlock1D 8 | 9 | 10 | def test_config() -> None: 11 | """Test config.""" 12 | config = WaveUNetConfig( 13 | max_length=8192, 14 | sample_rate=16000, 15 | in_channels=(1, 2, 3, 4, 5, 6), 16 | downsample_kernel_size=3, 17 | upsample_kernel_size=3, 18 | dropout=0.1, 19 | activation="leaky_relu", 20 | autoencoder=False, 21 | ) 22 | assert config.max_length == 8192 23 | assert config.sample_rate == 16000 24 | assert config.in_channels == (1, 2, 3, 4, 5, 6) 25 | assert config.downsample_kernel_size == 3 26 | assert config.upsample_kernel_size == 3 27 | assert config.dropout == 0.1 28 | assert config.activation == "leaky_relu" 29 | assert config.autoencoder is False 30 | 31 | 32 | def test_model() -> None: 33 | """Test model.""" 34 | config = WaveUNetConfig( 35 | max_length=16384, 36 | sample_rate=16000, 37 | in_channels=(1, 2, 3), 38 | downsample_kernel_size=3, 39 | upsample_kernel_size=3, 40 | ) 41 | model = WaveUNetModel(config) 42 | model.eval() 43 | 44 | audio = torch.randn(1, 1, config.max_length) 45 | with torch.no_grad(): 46 | recon = model(audio).audio 47 | 48 | assert isinstance(recon, torch.Tensor) 49 | assert audio.shape == recon.shape 50 | 51 | 52 | def test_upsample_block_1d(): 53 | """Test upsample block 1d.""" 54 | block = UpsampleBlock1D(1, 2, 3, 0.1, "leaky_relu", True) 55 | 56 | assert isinstance(block, UpsampleBlock1D) 57 | assert isinstance(block.upsample, Upsample1D) 58 | assert isinstance(block.norm, Normalization) 59 | assert isinstance(block.activation.activation, torch.nn.LeakyReLU) 60 | assert isinstance(block.dropout, torch.nn.Dropout) 61 | 62 | audio = torch.randn(1, 1, 800) 63 | out = block(audio) 64 | 65 | assert out.shape == (1, 2, 1600) 66 | 67 | 68 | def test_downsample_block_1d(): 69 | """Test downsample block 1d.""" 70 | block = DownsampleBlock1D(1, 2, 3, 2, 0.1, "leaky_relu", True) 71 | 72 | assert isinstance(block, DownsampleBlock1D) 73 | assert isinstance(block.downsample, Downsample1D) 74 | assert isinstance(block.norm, Normalization) 75 | assert isinstance(block.activation.activation, torch.nn.LeakyReLU) 76 | assert isinstance(block.dropout, torch.nn.Dropout) 77 | 78 | audio = torch.randn(1, 1, 800) 79 | out = block(audio) 80 | 81 | assert out.shape == (1, 2, 400) 82 | 
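The two block tests above pin down the shape contract used throughout the model: DownsampleBlock1D halves the time axis with a stride-2 convolution, and UpsampleBlock1D doubles it, so a down/up pair is shape-preserving. A minimal round-trip sketch (the channel counts here are arbitrary):

import torch

from denoisers.modeling.waveunet.modules import DownsampleBlock1D, UpsampleBlock1D

down = DownsampleBlock1D(1, 4, kernel_size=3)  # (1, 1, 800) -> (1, 4, 400)
up = UpsampleBlock1D(4, 1, kernel_size=3)  # (1, 4, 400) -> (1, 1, 800)

x = torch.randn(1, 1, 800)
assert up(down(x)).shape == x.shape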
--------------------------------------------------------------------------------
/src/denoisers/datasets/audio.py:
--------------------------------------------------------------------------------
1 | """Audio dataset."""
2 | 
3 | import random
4 | from pathlib import Path
5 | from typing import NamedTuple
6 | 
7 | import torch
8 | import torchaudio
9 | from audiomentations import AddColorNoise, AddGaussianNoise, Compose, RoomSimulator
10 | from torch.utils.data import Dataset
11 | 
12 | SUPPORTED_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg"}
13 | SAMPLE_RATES = [8000, 16000, 22050, 24000, 32000, 44100, 48000]
14 | 
15 | 
16 | class Batch(NamedTuple):
17 |     """Batch of inputs."""
18 | 
19 |     audio: torch.Tensor
20 |     noisy: torch.Tensor
21 |     lengths: torch.Tensor
22 | 
23 | 
24 | class AudioDataset(Dataset):
25 |     """Simple audio dataset."""
26 | 
27 |     def __init__(
28 |         self,
29 |         root: Path,
30 |         max_length: int,
31 |         sample_rate: int,
32 |         variable_sample_rate: bool = True,
33 |     ) -> None:
34 |         super().__init__()
35 |         self._root = root
36 | 
37 |         self._paths = []
38 |         for ext in SUPPORTED_EXTENSIONS:
39 |             self._paths.extend(list(self._root.glob(f"**/*{ext}")))
40 | 
41 |         self._max_length = max_length
42 |         self._sample_rate = sample_rate
43 |         self._variable_sample_rate = variable_sample_rate
44 | 
45 |         self._transforms = Compose(
46 |             [
47 |                 RoomSimulator(
48 |                     p=0.8, leave_length_unchanged=True, use_ray_tracing=False
49 |                 ),
50 |                 AddColorNoise(p=0.97),
51 |                 AddGaussianNoise(p=0.97),
52 |             ]
53 |         )
54 | 
55 |     def __len__(self) -> int:
56 |         """Return length of dataset."""
57 |         return len(self._paths)
58 | 
59 |     def __getitem__(self, idx: int) -> Batch:
60 |         """Return item from dataset."""
61 |         path = self._paths[idx]
62 |         audio, sr = torchaudio.load(str(path))
63 | 
64 |         if audio.shape[0] > 1:
65 |             audio = audio.mean(0, keepdim=True)
66 | 
67 |         new_sr = (
68 |             random.choice(SAMPLE_RATES)
69 |             if self._variable_sample_rate
70 |             else self._sample_rate
71 |         )
72 |         if sr != new_sr:
73 |             audio = torchaudio.functional.resample(audio, sr, new_sr)
74 | 
75 |         audio_length = min(audio.shape[-1], self._max_length)
76 | 
77 |         if audio_length < self._max_length:
78 |             pad_length = self._max_length - audio_length
79 |             audio = torch.nn.functional.pad(audio, (0, pad_length))
80 |         else:
81 |             start_idx = random.randint(0, audio.shape[-1] - self._max_length)
82 |             audio = audio[:, start_idx : start_idx + self._max_length]
83 | 
84 |         noisy = self._transforms(audio.clone().numpy(), sample_rate=new_sr)
85 |         noisy = torch.from_numpy(noisy)
86 | 
87 |         return Batch(audio=audio, noisy=noisy, lengths=torch.tensor(audio_length))
88 | 
--------------------------------------------------------------------------------
/src/denoisers/scripts/train.py:
--------------------------------------------------------------------------------
1 | """Train script."""
2 | 
3 | import argparse
4 | import warnings
5 | from pathlib import Path
6 | 
7 | import torch
8 | from pytorch_lightning import Trainer, callbacks, loggers, seed_everything
9 | 
10 | from denoisers.datamodule import DenoisersDataModule
11 | from denoisers.datasets.audio import AudioDataset
12 | from denoisers.lightning_module import DenoisersLightningModule
13 | from denoisers.modeling import CONFIGS, MODELS
14 | 
15 | if torch.cuda.is_available():
16 |     torch.backends.cudnn.benchmark = True
17 |     torch.backends.cudnn.allow_tf32 = True
18 |     torch.backends.cuda.matmul.allow_tf32 = True
19 | 
20 | 
21 | warnings.filterwarnings("ignore")
22 | 
23 | 
24 | def main() -> None:
25 |     """Run training."""
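    # Illustrative invocation via the console script registered in
    # pyproject.toml (the dataset path below is made up):
    #   train waveunet my-run /data/speech --batch_size 32 --ema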
26 |     parser = argparse.ArgumentParser("train parser")
27 |     parser.add_argument("model", type=str, choices=MODELS.keys())
28 |     parser.add_argument("name", type=str)
29 |     parser.add_argument("data_root", type=Path)
30 |     parser.add_argument("--project", default="denoisers", type=str)
31 |     parser.add_argument(
32 |         "--num_devices",
33 |         default=1,
34 |         type=int,
35 |     )
36 |     parser.add_argument("--batch_size", default=64, type=int)
37 |     parser.add_argument("--num_workers", default=4, type=int)
38 |     parser.add_argument("--seed", default=1234, type=int)
39 |     parser.add_argument("--log_path", default="logs", type=Path)
40 |     parser.add_argument("--checkpoint_path", default=None, type=Path)
41 |     parser.add_argument("--ema", action="store_true")
42 |     parser.add_argument("--push_to_hub", action="store_true")
43 |     parser.add_argument("--debug", action="store_true")
44 | 
45 |     args = parser.parse_args()
46 | 
47 |     seed_everything(args.seed)
48 | 
49 |     log_path = args.log_path / args.name
50 |     log_path.mkdir(exist_ok=True, parents=True)
51 | 
52 |     config = CONFIGS[args.model]()
53 |     model = MODELS[args.model](config)
54 |     lightning_module = DenoisersLightningModule(
55 |         model,
56 |         sync_dist=args.num_devices > 1,
57 |         use_ema=args.ema,
58 |         push_to_hub=args.push_to_hub,
59 |     )
60 | 
61 |     dataset = AudioDataset(
62 |         args.data_root, max_length=config.max_length, sample_rate=config.sample_rate
63 |     )
64 |     datamodule = DenoisersDataModule(
65 |         dataset, batch_size=args.batch_size, num_workers=args.num_workers
66 |     )
67 |     logger = loggers.WandbLogger(
68 |         project=args.project,
69 |         save_dir=log_path,
70 |         name=args.name,
71 |         offline=args.debug,
72 |     )
73 |     checkpoint_callback = callbacks.ModelCheckpoint(
74 |         dirpath=log_path,
75 |         filename="{step}",
76 |         save_last=True,
77 |     )
78 |     lr_monitor = callbacks.LearningRateMonitor(logging_interval="step")
79 | 
80 |     pretrained = args.checkpoint_path
81 |     last_checkpoint = pretrained if pretrained else log_path / "last.ckpt"
82 | 
83 |     trainer = Trainer(
84 |         default_root_dir=log_path,
85 |         max_epochs=10000,
86 |         accelerator="auto",
87 |         val_check_interval=0.25 if len(dataset) // args.batch_size > 5000 else 1.0,
88 |         devices=args.num_devices,
89 |         logger=logger,
90 |         precision="bf16-mixed",
91 |         accumulate_grad_batches=2,
92 |         callbacks=[checkpoint_callback, lr_monitor],
93 |         strategy="deepspeed_stage_2" if args.num_devices > 1 else "auto",
94 |     )
95 | 
96 |     trainer.fit(
97 |         lightning_module,
98 |         datamodule=datamodule,
99 |         ckpt_path=last_checkpoint if last_checkpoint.exists() else None,
100 |     )
101 | 
102 | 
103 | if __name__ == "__main__":
104 |     main()
105 | 
--------------------------------------------------------------------------------
/src/denoisers/utils.py:
--------------------------------------------------------------------------------
1 | """General utilities for the denoisers."""
2 | 
3 | from typing import Any, Optional
4 | 
5 | import librosa
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torchaudio
9 | import wandb
10 | 
11 | SPEC_FN = torchaudio.transforms.Spectrogram(
12 |     n_fft=2048,
13 |     win_length=1024,
14 |     hop_length=256,
15 |     center=True,
16 |     pad_mode="constant",
17 |     power=2.0,
18 | )
19 | 
20 | 
21 | def sequence_mask(length: Any, max_length: Optional[Any] = None) -> torch.Tensor:
22 |     """Create a boolean mask from sequence lengths."""
23 |     if max_length is None:
24 |         max_length = length.max()
25 |     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
26 |     mask = x.unsqueeze(0) < 
length.unsqueeze(1) 27 | return mask.unsqueeze(1) 28 | 29 | 30 | def plot_image_batch( 31 | clean: torch.Tensor, 32 | noisy: torch.Tensor, 33 | preds: torch.Tensor, 34 | name: str, 35 | ) -> None: 36 | """Plot a batch of images and log them to wandb.""" 37 | np_clean = clean.squeeze(1).float().numpy(force=True)[:5] 38 | np_noisy = noisy.squeeze(1).float().numpy(force=True)[:5] 39 | np_preds = preds.squeeze(1).float().numpy(force=True)[:5] 40 | 41 | fig, ax = plt.subplots(len(np_clean), 3, figsize=(20, 5 * len(np_clean))) 42 | for i, (c, n, p) in enumerate(zip(np_clean, np_noisy, np_preds)): 43 | ax[i][0].imshow(c, origin="lower", aspect="auto") 44 | ax[i][0].axis("off") 45 | ax[i][0].title.set_text("clean") 46 | 47 | ax[i][1].imshow(n, origin="lower", aspect="auto") 48 | ax[i][1].axis("off") 49 | ax[i][1].title.set_text("noisy") 50 | 51 | ax[i][2].imshow(p, origin="lower", aspect="auto") 52 | ax[i][2].axis("off") 53 | ax[i][2].title.set_text("preds") 54 | 55 | wandb.log({f"{name}_images": wandb.Image(fig)}) 56 | 57 | plt.close() 58 | 59 | 60 | def plot_image_from_audio( 61 | clean: torch.Tensor, 62 | noisy: torch.Tensor, 63 | preds: torch.Tensor, 64 | lengths: torch.Tensor, 65 | name: str, 66 | ) -> None: 67 | """Plot a batch of images and log them to wandb.""" 68 | clean = clean.squeeze(1).float().cpu().detach()[:5] 69 | noisy = noisy.squeeze(1).float().cpu().detach()[:5] 70 | preds = preds.squeeze(1).float().cpu().detach()[:5] 71 | 72 | fig, ax = plt.subplots(len(clean), 3, figsize=(20, 5 * len(clean))) 73 | 74 | for i, (c, n, p, length) in enumerate(zip(clean, noisy, preds, lengths)): 75 | original_spec = librosa.power_to_db(SPEC_FN(c[:length])) 76 | noisy_spec = librosa.power_to_db(SPEC_FN(n[:length])) 77 | pred_spec = librosa.power_to_db(SPEC_FN(p[:length])) 78 | 79 | ax[i][0].imshow(original_spec, origin="lower", aspect="auto") 80 | ax[i][0].axis("off") 81 | ax[i][0].title.set_text("clean") 82 | 83 | ax[i][1].imshow(noisy_spec, origin="lower", aspect="auto") 84 | ax[i][1].axis("off") 85 | ax[i][1].title.set_text("noisy") 86 | 87 | ax[i][2].imshow(pred_spec, origin="lower", aspect="auto") 88 | ax[i][2].axis("off") 89 | ax[i][2].title.set_text("preds") 90 | 91 | wandb.log({f"{name}_images": wandb.Image(fig)}) 92 | plt.close() 93 | 94 | 95 | def log_audio_batch( 96 | clean: torch.Tensor, 97 | noisy: torch.Tensor, 98 | preds: torch.Tensor, 99 | lengths: torch.Tensor, 100 | name: str, 101 | sample_rate: int = 24000, 102 | ) -> None: 103 | """Log a batch of audio to wandb.""" 104 | np_clean = clean.squeeze(1).float().numpy(force=True)[0][: int(lengths[0])] 105 | np_noisy = noisy.squeeze(1).float().numpy(force=True)[0][: int(lengths[0])] 106 | np_preds = preds.squeeze(1).float().numpy(force=True)[0][: int(lengths[0])] 107 | 108 | wandb.log( 109 | { 110 | f"{name}_audio": { 111 | f"{name}_clean": wandb.Audio(np_clean, sample_rate=sample_rate), 112 | f"{name}_noisy": wandb.Audio(np_noisy, sample_rate=sample_rate), 113 | f"{name}_pred": wandb.Audio(np_preds, sample_rate=sample_rate), 114 | }, 115 | }, 116 | ) 117 | -------------------------------------------------------------------------------- /src/denoisers/modeling/unet1d/model.py: -------------------------------------------------------------------------------- 1 | """UNet1D model.""" 2 | 3 | from typing import Any, Optional 4 | 5 | import torch 6 | from torch import Tensor, nn 7 | from transformers import PreTrainedModel 8 | 9 | from denoisers.modeling.unet1d.config import UNet1DConfig 10 | from denoisers.modeling.unet1d.modules import 
DownBlock1D, MidBlock1D, UpBlock1D 11 | 12 | 13 | class UNet1DModelOutputs: 14 | """Class for holding model outputs.""" 15 | 16 | def __init__(self, audio: Tensor, noise: Optional[Tensor] = None) -> None: 17 | self.audio = audio 18 | self.noise = noise 19 | 20 | 21 | class UNet1DModel(PreTrainedModel): 22 | """Pretrained UNet1D Model.""" 23 | 24 | config_class: Any = UNet1DConfig 25 | 26 | def __init__(self, config: UNet1DConfig) -> None: 27 | super().__init__(config) 28 | self.config = config 29 | self.model = UNet1D( 30 | channels=config.channels, 31 | kernel_size=config.kernel_size, 32 | num_groups=config.num_groups, 33 | activation=config.activation, 34 | dropout=config.dropout, 35 | norm_type=config.norm_type, 36 | ) 37 | 38 | def forward(self, inputs: Tensor) -> UNet1DModelOutputs: 39 | """Forward Pass.""" 40 | if self.config.autoencoder: 41 | audio = self.model(inputs) 42 | return UNet1DModelOutputs(audio=audio) 43 | else: 44 | noise = self.model(inputs) 45 | denoised = inputs - noise 46 | return UNet1DModelOutputs(audio=denoised, noise=noise) 47 | 48 | 49 | class UNet1D(nn.Module): 50 | """UNet1D model.""" 51 | 52 | def __init__( 53 | self, 54 | channels: tuple[int, ...] = ( 55 | 32, 56 | 64, 57 | 96, 58 | 128, 59 | 160, 60 | 192, 61 | 224, 62 | 256, 63 | 288, 64 | 320, 65 | 352, 66 | 384, 67 | ), 68 | kernel_size: int = 3, 69 | num_groups: Optional[int] = None, 70 | activation: str = "silu", 71 | dropout: float = 0.1, 72 | norm_type: str = "layer", 73 | ) -> None: 74 | super().__init__() 75 | self.in_conv = nn.Conv1d( 76 | 1, 77 | channels[0], 78 | kernel_size=kernel_size, 79 | padding=kernel_size // 2, 80 | ) 81 | self.encoder_layers = nn.ModuleList( 82 | [ 83 | DownBlock1D( 84 | channels[i], 85 | out_channels=channels[i + 1], 86 | kernel_size=kernel_size, 87 | num_groups=num_groups, 88 | dropout=dropout, 89 | activation=activation, 90 | norm_type=norm_type, 91 | ) 92 | for i in range(len(channels) - 1) 93 | ], 94 | ) 95 | self.middle = MidBlock1D( 96 | in_channels=channels[-1], 97 | out_channels=channels[-1], 98 | kernel_size=kernel_size, 99 | num_groups=num_groups, 100 | dropout=dropout, 101 | activation=activation, 102 | norm_type=norm_type, 103 | ) 104 | self.decoder_layers = nn.ModuleList( 105 | [ 106 | UpBlock1D( 107 | channels[i + 1], 108 | out_channels=channels[i], 109 | kernel_size=kernel_size, 110 | num_groups=num_groups, 111 | dropout=dropout, 112 | activation=activation, 113 | norm_type=norm_type, 114 | ) 115 | for i in reversed(range(len(channels) - 1)) 116 | ], 117 | ) 118 | self.out_conv = nn.Sequential( 119 | nn.Conv1d(channels[0] + 1, 1, kernel_size=1, padding=0), 120 | nn.Tanh(), 121 | ) 122 | 123 | def forward(self, inputs: Tensor) -> Tensor: 124 | """Forward Pass.""" 125 | out = self.in_conv(inputs) 126 | 127 | skips = [] 128 | for layer in self.encoder_layers: 129 | out = layer(out) 130 | skips.append(out) 131 | 132 | out = self.middle(out) 133 | 134 | for skip, layer in zip(reversed(skips), self.decoder_layers): 135 | out = layer(out[..., : skip.size(-1)] + skip) 136 | 137 | out = torch.concat([out, inputs], dim=1) 138 | out = self.out_conv(out) 139 | 140 | return out.float() 141 | -------------------------------------------------------------------------------- /src/denoisers/modeling/waveunet/model.py: -------------------------------------------------------------------------------- 1 | """WaveUNet Model.""" 2 | 3 | from typing import Any, Optional 4 | 5 | import torch 6 | from torch import nn 7 | from transformers import PreTrainedModel 8 | 9 | from 
denoisers.modeling.modules import Activation, Normalization
10 | from denoisers.modeling.waveunet.config import WaveUNetConfig
11 | from denoisers.modeling.waveunet.modules import DownsampleBlock1D, UpsampleBlock1D
12 | 
13 | 
14 | class WaveUNetModelOutputs:
15 |     """Class for holding model outputs."""
16 | 
17 |     def __init__(
18 |         self, audio: torch.Tensor, noise: Optional[torch.Tensor] = None
19 |     ) -> None:
20 |         self.audio = audio
21 |         self.noise = noise
22 | 
23 | 
24 | class WaveUNetModel(PreTrainedModel):
25 |     """Pretrained WaveUNet Model."""
26 | 
27 |     config_class: Any = WaveUNetConfig
28 | 
29 |     def __init__(self, config: WaveUNetConfig) -> None:
30 |         super().__init__(config)
31 |         self.config = config
32 |         self.model = WaveUNet(
33 |             in_channels=config.in_channels,
34 |             downsample_kernel_size=config.downsample_kernel_size,
35 |             upsample_kernel_size=config.upsample_kernel_size,
36 |             dropout=config.dropout,
37 |             activation=config.activation,
38 |             norm_type=config.norm_type,
39 |             num_groups=config.num_groups,
40 |         )
41 | 
42 |     def forward(self, inputs: torch.Tensor) -> WaveUNetModelOutputs:
43 |         """Forward Pass."""
44 |         if self.config.autoencoder:
45 |             audio = self.model(inputs)
46 |             return WaveUNetModelOutputs(audio=audio)
47 |         else:
48 |             noise = self.model(inputs)
49 |             denoised = inputs - noise
50 |             return WaveUNetModelOutputs(audio=denoised, noise=noise)
51 | 
52 | 
53 | class WaveUNet(nn.Module):
54 |     """WaveUNet Model."""
55 | 
56 |     def __init__(
57 |         self,
58 |         in_channels: tuple[int, ...] = (
59 |             24,
60 |             48,
61 |             72,
62 |             96,
63 |             120,
64 |             144,
65 |             168,
66 |             192,
67 |             216,
68 |             240,
69 |             264,
70 |             288,
71 |         ),
72 |         downsample_kernel_size: int = 15,
73 |         upsample_kernel_size: int = 5,
74 |         dropout: float = 0.0,
75 |         activation: str = "leaky_relu",
76 |         norm_type: str = "batch",
77 |         num_groups: Optional[int] = None,
78 |     ) -> None:
79 |         super().__init__()
80 |         self.in_conv = nn.Conv1d(
81 |             1,
82 |             in_channels[0],
83 |             kernel_size=downsample_kernel_size,
84 |             padding=downsample_kernel_size // 2,
85 |         )
86 |         self.encoder_layers = nn.ModuleList(
87 |             [
88 |                 DownsampleBlock1D(
89 |                     in_channels[i],
90 |                     out_channels=in_channels[i + 1],
91 |                     kernel_size=downsample_kernel_size,
92 |                     dropout=dropout,
93 |                     activation=activation,
94 |                     norm_type=norm_type,
95 |                     num_groups=num_groups,
96 |                 )
97 |                 for i in range(len(in_channels) - 1)
98 |             ],
99 |         )
100 |         self.middle = nn.Sequential(
101 |             nn.Conv1d(
102 |                 in_channels[-1],
103 |                 in_channels[-1],
104 |                 kernel_size=downsample_kernel_size,
105 |                 padding=downsample_kernel_size // 2,
106 |             ),
107 |             Normalization(in_channels[-1], num_groups=num_groups, name=norm_type),
108 |             Activation(activation),
109 |             nn.Dropout(dropout),
110 |         )
111 |         self.decoder_layers = nn.ModuleList(
112 |             [
113 |                 UpsampleBlock1D(
114 |                     2 * in_channels[i + 1],
115 |                     out_channels=in_channels[i],
116 |                     kernel_size=upsample_kernel_size,
117 |                     dropout=dropout,
118 |                     activation=activation,
119 |                     norm_type=norm_type,
120 |                     num_groups=num_groups,
121 |                 )
122 |                 for i in reversed(range(len(in_channels) - 1))
123 |             ],
124 |         )
125 |         self.out_conv = nn.Sequential(
126 |             nn.Conv1d(in_channels[0] + 1, 1, kernel_size=1, padding=0),
127 |             nn.Tanh(),
128 |         )
129 | 
130 |     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
131 |         """Forward Pass."""
132 |         out = self.in_conv(inputs)
133 | 
134 |         skips = []
135 |         for layer in self.encoder_layers:
136 |             out = layer(out)
137 |             skips.append(out)
138 | 
139 |         out = self.middle(out)
140 | 
141 |         for skip, layer in zip(reversed(skips), self.decoder_layers):
142 |             out = torch.concat([out[..., : skip.size(-1)], skip], dim=1)
143 |             out = layer(out)
144 | 
145 |         out = torch.concat([out, inputs], dim=1)
146 |         out = self.out_conv(out)
147 | 
148 |         return out.float()
149 | 
--------------------------------------------------------------------------------
/src/denoisers/losses.py:
--------------------------------------------------------------------------------
1 | """Denoisers losses."""
2 | 
3 | from typing import Any
4 | 
5 | import torch
6 | from torch import nn
7 | 
8 | 
9 | def stft(
10 |     x: torch.Tensor, fft_size: int, hop_size: int, win_length: int, window: torch.Tensor
11 | ) -> torch.Tensor:
12 |     """Perform STFT and convert to magnitude spectrogram.
13 | 
14 |     Args:
15 |         x (Tensor): Input signal tensor (B, T).
16 |         fft_size (int): FFT size.
17 |         hop_size (int): Hop size.
18 |         win_length (int): Window length.
19 |         window (Tensor): Window tensor.
20 | 
21 |     Returns:
22 |         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
23 |     """
24 |     x_stft = torch.stft(
25 |         x.squeeze(1), fft_size, hop_size, win_length, window, return_complex=True
26 |     )
27 |     x_stft = torch.view_as_real(x_stft)
28 |     real = x_stft[..., 0]
29 |     imag = x_stft[..., 1]
30 | 
31 |     return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
32 | 
33 | 
34 | class SpectralConvergenceLoss(nn.Module):
35 |     """Spectral convergence loss module."""
36 | 
37 |     def __init__(self) -> None:
38 |         super().__init__()
39 | 
40 |     def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
41 |         """Calculate forward propagation.
42 | 
43 |         Args:
44 |             x_mag (Tensor): Magnitude spectrogram of
45 |                 predicted signal (B, #frames, #freq_bins).
46 |             y_mag (Tensor): Magnitude spectrogram of
47 |                 ground truth signal (B, #frames, #freq_bins).
48 | 
49 |         Returns:
50 |             Tensor: Spectral convergence loss value.
51 |         """
52 |         return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
53 | 
54 | 
55 | class LogSTFTMagnitudeLoss(torch.nn.Module):
56 |     """Log STFT magnitude loss module."""
57 | 
58 |     def __init__(self) -> None:
59 |         super().__init__()
60 | 
61 |     def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
62 |         """Calculate forward propagation.
63 | 
64 |         Args:
65 |             x_mag (Tensor): Magnitude spectrogram of predicted signal
66 |                 (B, frames, freq_bins).
67 |             y_mag (Tensor): Magnitude spectrogram of ground truth signal
68 |                 (B, frames, freq_bins).
69 | 
70 |         Returns:
71 |             Tensor: Log STFT magnitude loss value.
72 |         """
73 |         return nn.functional.l1_loss(torch.log(y_mag), torch.log(x_mag))
74 | 
75 | 
76 | class STFTLoss(nn.Module):
77 |     """STFT loss module."""
78 | 
79 |     def __init__(
80 |         self,
81 |         fft_size: int = 1024,
82 |         shift_size: int = 120,
83 |         win_length: int = 600,
84 |         window: str = "hann_window",
85 |     ) -> None:
86 |         super().__init__()
87 |         self.fft_size = fft_size
88 |         self.shift_size = shift_size
89 |         self.win_length = win_length
90 |         self.register_buffer("window", getattr(torch, window)(win_length))
91 |         self.spectral_convergence_loss = SpectralConvergenceLoss()
92 |         self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
93 | 
94 |     def forward(
95 |         self, x: torch.Tensor, y: torch.Tensor
96 |     ) -> tuple[torch.Tensor, torch.Tensor]:
97 |         """Calculate forward propagation.
98 | 
99 |         Args:
100 |             x (Tensor): Predicted signal (B, T).
101 |             y (Tensor): Ground truth signal (B, T).
102 | 
103 |         Returns:
104 |             Tensor: Spectral convergence loss value.
105 |             Tensor: Log STFT magnitude loss value.
106 | """ 107 | window: Any = self.window 108 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, window) 109 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, window) 110 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 111 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 112 | 113 | return sc_loss, mag_loss 114 | 115 | 116 | class MultiResolutionSTFTLoss(nn.Module): 117 | """Multi resolution STFT loss module.""" 118 | 119 | def __init__( 120 | self, 121 | fft_sizes: tuple[int, ...] = (1024, 2048, 512), 122 | hop_sizes: tuple[int, ...] = (120, 240, 50), 123 | win_lengths: tuple[int, ...] = (600, 1200, 240), 124 | window: str = "hann_window", 125 | factor_sc: float = 0.1, 126 | factor_mag: float = 0.1, 127 | ) -> None: 128 | super().__init__() 129 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 130 | self.stft_losses = torch.nn.ModuleList() 131 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 132 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 133 | self.factor_sc = factor_sc 134 | self.factor_mag = factor_mag 135 | 136 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> float: 137 | """Calculate forward propagation. 138 | 139 | Args: 140 | x (Tensor): Predicted signal (B, T). 141 | y (Tensor): Ground truth signal (B, T). 142 | 143 | Returns: 144 | Tensor: Multi resolution spectral convergence loss value. 145 | Tensor: Multi resolution log STFT magnitude loss value. 146 | """ 147 | sc_loss = 0.0 148 | mag_loss = 0.0 149 | for f in self.stft_losses: 150 | sc_l, mag_l = f(x, y) 151 | sc_loss += sc_l 152 | mag_loss += mag_l 153 | sc_loss /= len(self.stft_losses) 154 | mag_loss /= len(self.stft_losses) 155 | 156 | return self.factor_sc * sc_loss + self.factor_mag * mag_loss 157 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | """Test transforms.""" 2 | 3 | from pathlib import Path 4 | 5 | import torch 6 | import torchaudio 7 | from torch import Tensor 8 | 9 | from denoisers.testing import sine_wave 10 | from denoisers.transforms import ( 11 | BreakTransform, 12 | ClipTransform, 13 | FilterTransform, 14 | FreqMask, 15 | GaussianNoise, 16 | NoiseFromFile, 17 | ReverbFromFile, 18 | ReverbFromSoundboard, 19 | SpecTransform, 20 | TimeMask, 21 | VolTransform, 22 | ) 23 | 24 | 25 | def test_gaussian_noise() -> None: 26 | """Test gaussian noise.""" 27 | transform = GaussianNoise(p=1.0) 28 | audio = sine_wave(800, 1, 16000) 29 | noisy_audio = transform(audio.clone()) 30 | 31 | assert isinstance(noisy_audio, Tensor) 32 | assert noisy_audio.shape == audio.shape 33 | assert not torch.isclose(audio, noisy_audio).all() 34 | 35 | noisy_audio = transform(audio.numpy()) 36 | assert isinstance(noisy_audio, Tensor) 37 | 38 | 39 | def test_filter_transform() -> None: 40 | """Test filter transform.""" 41 | transform = FilterTransform(p=1.0) 42 | audio = sine_wave(800, 1, 16000) 43 | noisy_audio = transform(audio.clone()) 44 | 45 | assert isinstance(noisy_audio, Tensor) 46 | assert audio.shape == noisy_audio.shape 47 | assert not torch.isclose(audio, noisy_audio).all() 48 | 49 | noisy_audio = transform(audio.numpy()) 50 | assert isinstance(noisy_audio, Tensor) 51 | 52 | 53 | def test_clip_transform() -> None: 54 | """Test clip transform.""" 55 | transform = ClipTransform(p=1.0, clip_ceil=0.5, clip_floor=-0.5) 56 | audio = sine_wave(800, 1, 16000) 57 | noisy_audio = transform(audio.clone()) 
58 | 59 | assert isinstance(noisy_audio, Tensor) 60 | assert audio.shape == noisy_audio.shape 61 | assert not torch.isclose(audio, noisy_audio).all() 62 | 63 | noisy_audio = transform(audio.numpy()) 64 | assert isinstance(noisy_audio, Tensor) 65 | 66 | 67 | def test_break_transform() -> None: 68 | """Test break transform.""" 69 | transform = BreakTransform(p=1.0) 70 | audio = sine_wave(800, 1, 16000) 71 | noisy_audio = transform(audio.clone()) 72 | 73 | assert isinstance(noisy_audio, Tensor) 74 | assert audio.shape == noisy_audio.shape 75 | assert not torch.isclose(audio, noisy_audio).all() 76 | 77 | noisy_audio = transform(audio.numpy()) 78 | assert isinstance(noisy_audio, Tensor) 79 | 80 | 81 | def test_reverb_from_soundboard() -> None: 82 | """Test reverb from soundboard.""" 83 | transform = ReverbFromSoundboard(p=1.0) 84 | audio = sine_wave(800, 1, 16000) 85 | noisy_audio = transform(audio.clone()) 86 | 87 | assert isinstance(noisy_audio, Tensor) 88 | assert audio.shape == noisy_audio.shape 89 | assert not torch.isclose(audio, noisy_audio).all() 90 | 91 | noisy_audio = transform(audio.numpy()) 92 | assert isinstance(noisy_audio, Tensor) 93 | 94 | 95 | def test_spec_transform() -> None: 96 | """Test spec transform.""" 97 | transform = SpecTransform(p=1.0) 98 | audio = sine_wave(800, 1, 16000) 99 | noisy_audio = transform(audio.clone()) 100 | 101 | assert isinstance(noisy_audio, Tensor) 102 | assert audio.shape == noisy_audio.shape 103 | assert not torch.isclose(audio, noisy_audio).all() 104 | 105 | noisy_audio = transform(audio.numpy()) 106 | assert isinstance(noisy_audio, Tensor) 107 | 108 | 109 | def test_vol_transform() -> None: 110 | """Test vol transform.""" 111 | transform = VolTransform(p=1.0, sample_rate=16000) 112 | audio = sine_wave(800, 1, 16000) 113 | noisy_audio = transform(audio.clone()) 114 | 115 | assert isinstance(noisy_audio, Tensor) 116 | assert audio.shape == noisy_audio.shape 117 | assert not torch.isclose(audio, noisy_audio).all() 118 | 119 | noisy_audio = transform(audio.numpy()) 120 | assert isinstance(noisy_audio, Tensor) 121 | 122 | 123 | def test_noise_from_file(tmpdir) -> None: 124 | """Test noise from file.""" 125 | save_path = Path(tmpdir) / "noises" 126 | save_path.mkdir(exist_ok=True, parents=True) 127 | 128 | audio = sine_wave(800, 1, 16000) 129 | torchaudio.save(str(save_path / "noise.flac"), torch.randn_like(audio), 16000) 130 | noise, _ = torchaudio.load(str(save_path / "noise.flac")) 131 | 132 | transform = NoiseFromFile(save_path, p=1.0, sample_rate=16000, num_samples=1) 133 | noisy_audio = transform(audio.numpy()) 134 | 135 | torch.testing.assert_close(noisy_audio, audio + noise) 136 | torch.testing.assert_close(noise, noisy_audio - audio) 137 | torch.testing.assert_close(audio, noisy_audio - noise) 138 | 139 | 140 | def test_reverb_from_file() -> None: 141 | """Test reverb from file.""" 142 | audio = sine_wave(800, 1, 8000) 143 | 144 | transform = ReverbFromFile( 145 | Path("tests/assets/reverb"), 146 | p=1.0, 147 | sample_rate=8000, 148 | num_samples=1, 149 | ) 150 | 151 | noisy_audio = transform(audio.clone()) 152 | assert isinstance(noisy_audio, Tensor) 153 | assert audio.shape == noisy_audio.shape 154 | assert not torch.isclose(audio, noisy_audio).all() 155 | 156 | noisy_audio = transform(audio.numpy()) 157 | assert isinstance(noisy_audio, Tensor) 158 | 159 | 160 | def test_freq_mask() -> None: 161 | """Test freq mask.""" 162 | spec = torch.ones(1, 2048, 100) 163 | noisy_spec = FreqMask(num_masks=2, size=80, p=1.0)(spec) 164 | assert 
noisy_spec.shape == spec.shape 165 | assert noisy_spec.sum() < spec.sum() 166 | 167 | 168 | def test_time_mask() -> None: 169 | """Test time mask.""" 170 | spec = torch.ones(1, 2048, 100) 171 | noisy_spec = TimeMask(num_masks=2, size=20, p=1.0)(spec) 172 | assert noisy_spec.shape == spec.shape 173 | assert noisy_spec.sum() < spec.sum() 174 | -------------------------------------------------------------------------------- /src/denoisers/modeling/unet1d/modules.py: -------------------------------------------------------------------------------- 1 | """Modules for 1D U-Net.""" 2 | 3 | from typing import Optional 4 | 5 | from torch import Tensor, nn 6 | 7 | from denoisers.modeling.modules import ( 8 | Activation, 9 | Downsample1D, 10 | Normalization, 11 | Upsample1D, 12 | ) 13 | 14 | 15 | class DownBlock1D(nn.Module): 16 | """Downsampling Block for 1D data.""" 17 | 18 | def __init__( 19 | self, 20 | in_channels: int, 21 | out_channels: int, 22 | kernel_size: int = 3, 23 | num_groups: Optional[int] = None, 24 | activation: str = "silu", 25 | dropout: float = 0.0, 26 | norm_type: str = "layer", 27 | ) -> None: 28 | super().__init__() 29 | self.res_block = ResBlock1D( 30 | in_channels=in_channels, 31 | out_channels=out_channels, 32 | kernel_size=kernel_size, 33 | num_groups=num_groups, 34 | activation=activation, 35 | dropout=dropout, 36 | norm_type=norm_type, 37 | ) 38 | self.downsample = Downsample1D( 39 | in_channels=out_channels, 40 | out_channels=out_channels, 41 | kernel_size=kernel_size, 42 | use_conv=True, 43 | ) 44 | 45 | def forward(self, x: Tensor) -> Tensor: 46 | """Forward Pass.""" 47 | x = self.res_block(x) 48 | x = self.downsample(x) 49 | return x 50 | 51 | 52 | class UpBlock1D(nn.Module): 53 | """Upsampling Block for 1D data.""" 54 | 55 | def __init__( 56 | self, 57 | in_channels: int, 58 | out_channels: int, 59 | kernel_size: int = 3, 60 | num_groups: Optional[int] = None, 61 | activation: str = "silu", 62 | dropout: float = 0.0, 63 | norm_type: str = "layer", 64 | ) -> None: 65 | super().__init__() 66 | self.res_block = ResBlock1D( 67 | in_channels=in_channels, 68 | out_channels=out_channels, 69 | kernel_size=kernel_size, 70 | num_groups=num_groups, 71 | activation=activation, 72 | dropout=dropout, 73 | norm_type=norm_type, 74 | ) 75 | self.upsample = Upsample1D( 76 | in_channels=out_channels, 77 | out_channels=out_channels, 78 | kernel_size=kernel_size, 79 | use_conv=True, 80 | ) 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | """Forward Pass.""" 84 | x = self.res_block(x) 85 | x = self.upsample(x) 86 | return x 87 | 88 | 89 | class ResBlock1D(nn.Module): 90 | """Residual Block for 1D data.""" 91 | 92 | def __init__( 93 | self, 94 | in_channels: int, 95 | out_channels: int, 96 | kernel_size: int = 3, 97 | num_groups: Optional[int] = None, 98 | activation: str = "silu", 99 | dropout: float = 0.0, 100 | norm_type: str = "layer", 101 | ) -> None: 102 | super().__init__() 103 | self.conv_1 = nn.Conv1d( 104 | in_channels, 105 | out_channels, 106 | kernel_size, 107 | padding=kernel_size // 2, 108 | bias=False, 109 | ) 110 | self.norm_1 = Normalization(out_channels, name=norm_type, num_groups=num_groups) 111 | self.activation_1 = Activation(activation, channels=out_channels) 112 | self.dropout = nn.Dropout(dropout) 113 | self.conv_2 = nn.Conv1d( 114 | out_channels, 115 | out_channels, 116 | kernel_size, 117 | padding=kernel_size // 2, 118 | bias=False, 119 | ) 120 | self.norm_2 = Normalization(out_channels, name=norm_type, num_groups=num_groups) 121 | self.activation_2 = 
Activation(activation, channels=out_channels) 122 | self.residual = nn.Conv1d(in_channels, out_channels, 1) 123 | 124 | def forward(self, x: Tensor) -> Tensor: 125 | """Forward Pass.""" 126 | residual = self.residual(x) 127 | x = self.conv_1(x) 128 | x = self.norm_1(x) 129 | x = self.activation_1(x) 130 | x = self.dropout(x) 131 | x = self.conv_2(x) 132 | x = self.norm_2(x) 133 | x = self.activation_2(x) 134 | x = self.dropout(x) 135 | return x + residual 136 | 137 | 138 | class MidBlock1D(nn.Module): 139 | """Middle Block for 1D data.""" 140 | 141 | def __init__( 142 | self, 143 | in_channels: int, 144 | out_channels: int, 145 | kernel_size: int = 3, 146 | num_groups: Optional[int] = None, 147 | num_heads: int = 8, 148 | activation: str = "silu", 149 | dropout: float = 0.0, 150 | norm_type: str = "layer", 151 | ) -> None: 152 | super().__init__() 153 | self.res_block_1 = ResBlock1D( 154 | in_channels=in_channels, 155 | out_channels=out_channels, 156 | kernel_size=kernel_size, 157 | num_groups=num_groups, 158 | activation=activation, 159 | dropout=dropout, 160 | norm_type=norm_type, 161 | ) 162 | # batch_first=True so the (batch, time, channels) layout produced in 163 | # forward() is interpreted correctly by the attention layer. 164 | self.attention = nn.MultiheadAttention( 165 | out_channels, num_heads=num_heads, batch_first=True 166 | ) 167 | self.res_block_2 = ResBlock1D( 168 | in_channels=out_channels, 169 | out_channels=out_channels, 170 | kernel_size=kernel_size, 171 | num_groups=num_groups, 172 | activation=activation, 173 | dropout=dropout, 174 | norm_type=norm_type, 175 | ) 176 | 177 | def forward(self, x: Tensor) -> Tensor: 178 | """Forward Pass.""" 179 | x = self.res_block_1(x) 180 | residual = x 181 | x = x.transpose(2, 1) 182 | x = self.attention(x, x, x)[0] 183 | x = x.transpose(2, 1) 184 | x = x + residual 185 | x = self.res_block_2(x) 186 | return x 187 | -------------------------------------------------------------------------------- /src/denoisers/lightning_module.py: -------------------------------------------------------------------------------- 1 | """Denoisers lightning module trainer.""" 2 | 3 | from typing import Any 4 | 5 | import torch 6 | import torchaudio 7 | import wandb 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from pytorch_lightning import LightningModule 10 | from pytorch_lightning.utilities import grad_norm 11 | from pytorch_lightning.utilities.memory import garbage_collection_cuda 12 | from torch import nn 13 | from torchmetrics import MetricCollection 14 | from torchmetrics.audio import ( 15 | ScaleInvariantSignalDistortionRatio, 16 | ScaleInvariantSignalNoiseRatio, 17 | SignalDistortionRatio, 18 | SignalNoiseRatio, 19 | ) 20 | from torchmetrics.functional.audio import deep_noise_suppression_mean_opinion_score 21 | from transformers import PreTrainedModel 22 | 23 | from denoisers.datasets.audio import Batch 24 | from denoisers.losses import MultiResolutionSTFTLoss 25 | from denoisers.metrics import DNSMOS 26 | from denoisers.utils import log_audio_batch, plot_image_from_audio 27 | 28 | 29 | class DenoisersLightningModule(LightningModule): 30 | """Denoisers lightning module.""" 31 | 32 | def __init__( 33 | self, 34 | model: PreTrainedModel, 35 | sync_dist: bool = False, 36 | use_ema: bool = False, 37 | push_to_hub: bool = False, 38 | ) -> None: 39 | super().__init__() 40 | self.model = model 41 | if use_ema: 42 | self.ema_model = torch.optim.swa_utils.AveragedModel( 43 | self.model, 44 | multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999), 45 | ) 46 | self.loss_fn = nn.L1Loss() 47 | self.stft_loss = MultiResolutionSTFTLoss() 48 | self.train_metrics = MetricCollection( 49 | { 50 |
"train_snr": SignalNoiseRatio(), 51 | "train_sdr": SignalDistortionRatio(), 52 | "train_sisnr": ScaleInvariantSignalNoiseRatio(), 53 | "train_sisdr": ScaleInvariantSignalDistortionRatio(), 54 | } 55 | ) 56 | self.val_metrics = MetricCollection( 57 | { 58 | "val_snr": SignalNoiseRatio(), 59 | "val_sdr": SignalDistortionRatio(), 60 | "val_sisnr": ScaleInvariantSignalNoiseRatio(), 61 | "val_sisdr": ScaleInvariantSignalDistortionRatio(), 62 | } 63 | ) 64 | self.dns_mos = DNSMOS() 65 | self.autoencoder: bool = self.model.config.autoencoder 66 | self.sync_dist = sync_dist 67 | self.use_ema = use_ema 68 | self.push_to_hub = push_to_hub 69 | self.last_val_batch: dict[str, Any] = {} 70 | 71 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 72 | """Forward Pass.""" 73 | return self.model(inputs) 74 | 75 | def training_step(self, batch: Batch, batch_idx: Any) -> torch.Tensor: 76 | """Train step.""" 77 | outputs = self(batch.noisy) 78 | 79 | if self.autoencoder: 80 | l1_loss = self.loss_fn(outputs.audio, batch.audio) 81 | else: 82 | l1_loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) 83 | 84 | stft_loss = self.stft_loss(outputs.audio.float(), batch.audio.float()) 85 | 86 | loss = l1_loss + stft_loss 87 | 88 | metrics = self.train_metrics(outputs.audio, batch.audio) 89 | 90 | self.log("train_loss", loss, prog_bar=True, sync_dist=self.sync_dist) 91 | self.log_dict( 92 | {**metrics, "train_stft_loss": stft_loss, "train_l1_loss": l1_loss}, 93 | sync_dist=self.sync_dist, 94 | ) 95 | 96 | return loss 97 | 98 | def validation_step(self, batch: Any, batch_idx: Any) -> torch.Tensor: 99 | """Val step.""" 100 | if self.use_ema: 101 | outputs = self.ema_model(batch.noisy) 102 | else: 103 | outputs = self.model(batch.noisy) 104 | 105 | if self.autoencoder: 106 | l1_loss = self.loss_fn(outputs.audio, batch.audio) 107 | else: 108 | l1_loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) 109 | 110 | stft_loss = self.stft_loss(outputs.audio.float(), batch.audio.float()) 111 | 112 | loss = l1_loss + stft_loss 113 | 114 | metrics = self.val_metrics(outputs.audio, batch.audio) 115 | 116 | self.log("val_loss", loss, prog_bar=True, sync_dist=self.sync_dist) 117 | self.log_dict( 118 | {**metrics, "val_stft_loss": stft_loss, "val_l1_loss": l1_loss}, 119 | sync_dist=self.sync_dist, 120 | ) 121 | 122 | self.last_val_batch = { 123 | "outputs": ( 124 | batch.audio.detach(), 125 | batch.noisy.detach(), 126 | outputs.audio.detach(), 127 | batch.lengths.detach(), 128 | ), 129 | } 130 | 131 | return loss 132 | 133 | @rank_zero_only 134 | def on_validation_epoch_end(self) -> None: 135 | """Val epoch end.""" 136 | outputs = self.last_val_batch["outputs"] 137 | audio, noisy, preds, lengths = outputs 138 | 139 | score = deep_noise_suppression_mean_opinion_score( 140 | torchaudio.functional.resample( 141 | preds.cpu(), self.model.config.sample_rate, 16000 142 | ), 143 | 16000, 144 | False, 145 | ).mean() 146 | wandb.log({"dns_mos": score}) 147 | 148 | log_audio_batch( 149 | audio, 150 | noisy, 151 | preds, 152 | lengths, 153 | name="val", 154 | sample_rate=self.model.config.sample_rate, 155 | ) 156 | plot_image_from_audio(audio, noisy, preds, lengths, "val") 157 | 158 | if self.use_ema: 159 | self.model.load_state_dict(self.ema_model.module.state_dict()) 160 | 161 | model_name = self.trainer.default_root_dir.split("/")[-1] 162 | self.model.save_pretrained(self.trainer.default_root_dir + "/" + model_name) 163 | 164 | if self.push_to_hub: 165 | self.model.push_to_hub(model_name) # type: ignore[arg-type] 166 | 
167 | garbage_collection_cuda() 168 | 169 | def on_before_zero_grad(self, *args: Any, **kwargs: Any) -> None: 170 | """Update EMA model.""" 171 | if self.use_ema: 172 | self.ema_model.update_parameters(self.model) 173 | 174 | def on_before_optimizer_step(self, optimizer: Any) -> None: 175 | """Before optimizer step.""" 176 | self.log_dict(grad_norm(self, norm_type=1), sync_dist=self.sync_dist) 177 | 178 | def configure_optimizers(self) -> Any: 179 | """Set optimizer.""" 180 | optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-2) 181 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999875) 182 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 183 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile pyproject.toml -o requirements.txt 3 | aiohappyeyeballs==2.4.3 4 | # via aiohttp 5 | aiohttp==3.11.6 6 | # via fsspec 7 | aiosignal==1.3.1 8 | # via aiohttp 9 | attrs==24.2.0 10 | # via aiohttp 11 | audiomentations==0.37.0 12 | # via denoisers (pyproject.toml) 13 | audioread==3.0.1 14 | # via librosa 15 | certifi==2024.8.30 16 | # via 17 | # requests 18 | # sentry-sdk 19 | cffi==1.17.1 20 | # via 21 | # numpy-minmax 22 | # numpy-rms 23 | # soundfile 24 | cfgv==3.4.0 25 | # via pre-commit 26 | charset-normalizer==3.4.0 27 | # via requests 28 | click==8.1.7 29 | # via wandb 30 | coloredlogs==15.0.1 31 | # via onnxruntime 32 | contourpy==1.3.1 33 | # via matplotlib 34 | cycler==0.12.1 35 | # via matplotlib 36 | cython==3.0.11 37 | # via pyroomacoustics 38 | decorator==5.1.1 39 | # via librosa 40 | distlib==0.3.9 41 | # via virtualenv 42 | docker-pycreds==0.4.0 43 | # via wandb 44 | filelock==3.16.1 45 | # via 46 | # huggingface-hub 47 | # torch 48 | # transformers 49 | # virtualenv 50 | flatbuffers==25.2.10 51 | # via onnxruntime 52 | fonttools==4.55.0 53 | # via matplotlib 54 | frozenlist==1.5.0 55 | # via 56 | # aiohttp 57 | # aiosignal 58 | fsspec==2024.10.0 59 | # via 60 | # huggingface-hub 61 | # pytorch-lightning 62 | # torch 63 | gitdb==4.0.11 64 | # via gitpython 65 | gitpython==3.1.43 66 | # via wandb 67 | huggingface-hub==0.26.2 68 | # via 69 | # tokenizers 70 | # transformers 71 | humanfriendly==10.0 72 | # via coloredlogs 73 | identify==2.6.3 74 | # via pre-commit 75 | idna==3.10 76 | # via 77 | # requests 78 | # yarl 79 | iniconfig==2.0.0 80 | # via pytest 81 | jinja2==3.1.4 82 | # via torch 83 | joblib==1.4.2 84 | # via 85 | # librosa 86 | # scikit-learn 87 | kiwisolver==1.4.7 88 | # via matplotlib 89 | lazy-loader==0.4 90 | # via librosa 91 | librosa==0.10.2.post1 92 | # via 93 | # denoisers (pyproject.toml) 94 | # audiomentations 95 | lightning-utilities==0.11.9 96 | # via 97 | # pytorch-lightning 98 | # torchmetrics 99 | llvmlite==0.43.0 100 | # via numba 101 | markupsafe==3.0.2 102 | # via jinja2 103 | matplotlib==3.9.2 104 | # via denoisers (pyproject.toml) 105 | mpmath==1.3.0 106 | # via sympy 107 | msgpack==1.1.0 108 | # via librosa 109 | multidict==6.1.0 110 | # via 111 | # aiohttp 112 | # yarl 113 | mypy==1.13.0 114 | # via denoisers (pyproject.toml) 115 | mypy-extensions==1.0.0 116 | # via mypy 117 | networkx==3.4.2 118 | # via torch 119 | nodeenv==1.9.1 120 | # via pre-commit 121 | numba==0.60.0 122 | # via librosa 123 | numpy==1.26.4 124 | # via 125 | # audiomentations 126 | # contourpy 127 | # librosa 
128 | # matplotlib 129 | # numba 130 | # numpy-minmax 131 | # numpy-rms 132 | # onnxruntime 133 | # pedalboard 134 | # pyroomacoustics 135 | # scikit-learn 136 | # scipy 137 | # soxr 138 | # torchmetrics 139 | # torchvision 140 | # transformers 141 | numpy-minmax==0.3.1 142 | # via audiomentations 143 | numpy-rms==0.4.2 144 | # via audiomentations 145 | onnxruntime==1.21.1 146 | # via denoisers (pyproject.toml) 147 | packaging==24.2 148 | # via 149 | # huggingface-hub 150 | # lazy-loader 151 | # lightning-utilities 152 | # matplotlib 153 | # onnxruntime 154 | # pooch 155 | # pytest 156 | # pytorch-lightning 157 | # torchmetrics 158 | # transformers 159 | pedalboard==0.9.16 160 | # via denoisers (pyproject.toml) 161 | pesq==0.0.4 162 | # via denoisers (pyproject.toml) 163 | pillow==11.0.0 164 | # via 165 | # matplotlib 166 | # torchvision 167 | platformdirs==4.3.6 168 | # via 169 | # pooch 170 | # virtualenv 171 | # wandb 172 | pluggy==1.5.0 173 | # via pytest 174 | pooch==1.8.2 175 | # via librosa 176 | pre-commit==4.0.1 177 | # via denoisers (pyproject.toml) 178 | propcache==0.2.0 179 | # via 180 | # aiohttp 181 | # yarl 182 | protobuf==5.29.5 183 | # via 184 | # onnxruntime 185 | # wandb 186 | psutil==6.1.0 187 | # via wandb 188 | pybind11==2.13.6 189 | # via pyroomacoustics 190 | pycparser==2.22 191 | # via cffi 192 | pydocstyle==6.3.0 193 | # via denoisers (pyproject.toml) 194 | pydub==0.25.1 195 | # via denoisers (pyproject.toml) 196 | pyparsing==3.2.0 197 | # via matplotlib 198 | pyroomacoustics==0.8.2 199 | # via denoisers (pyproject.toml) 200 | pytest==8.3.3 201 | # via denoisers (pyproject.toml) 202 | python-dateutil==2.9.0.post0 203 | # via matplotlib 204 | pytorch-lightning==2.4.0 205 | # via denoisers (pyproject.toml) 206 | pyyaml==6.0.2 207 | # via 208 | # huggingface-hub 209 | # pre-commit 210 | # pytorch-lightning 211 | # transformers 212 | # wandb 213 | regex==2024.11.6 214 | # via transformers 215 | requests==2.32.3 216 | # via 217 | # huggingface-hub 218 | # pooch 219 | # transformers 220 | # wandb 221 | ruff==0.11.6 222 | # via denoisers (pyproject.toml) 223 | safetensors==0.4.5 224 | # via transformers 225 | scikit-learn==1.5.2 226 | # via librosa 227 | scipy==1.12.0 228 | # via 229 | # audiomentations 230 | # librosa 231 | # pyroomacoustics 232 | # scikit-learn 233 | sentry-sdk==2.18.0 234 | # via wandb 235 | setproctitle==1.3.4 236 | # via wandb 237 | setuptools==78.1.1 238 | # via 239 | # lightning-utilities 240 | # torch 241 | # wandb 242 | six==1.16.0 243 | # via 244 | # docker-pycreds 245 | # python-dateutil 246 | smmap==5.0.1 247 | # via gitdb 248 | snowballstemmer==2.2.0 249 | # via pydocstyle 250 | soundfile==0.12.1 251 | # via librosa 252 | soxr==0.5.0.post1 253 | # via 254 | # audiomentations 255 | # librosa 256 | sympy==1.13.1 257 | # via 258 | # onnxruntime 259 | # torch 260 | threadpoolctl==3.5.0 261 | # via scikit-learn 262 | tokenizers==0.20.3 263 | # via transformers 264 | torch==2.8.0 265 | # via 266 | # denoisers (pyproject.toml) 267 | # pytorch-lightning 268 | # torchaudio 269 | # torchmetrics 270 | # torchvision 271 | torchaudio==2.5.1 272 | # via denoisers (pyproject.toml) 273 | torchmetrics==1.6.0 274 | # via pytorch-lightning 275 | torchvision==0.20.1 276 | # via denoisers (pyproject.toml) 277 | tqdm==4.67.0 278 | # via 279 | # huggingface-hub 280 | # pytorch-lightning 281 | # transformers 282 | transformers==4.48.0 283 | # via denoisers (pyproject.toml) 284 | typing-extensions==4.12.2 285 | # via 286 | # huggingface-hub 287 | # librosa 288 | # 
lightning-utilities 289 | # mypy 290 | # pytorch-lightning 291 | # torch 292 | urllib3==2.6.0 293 | # via 294 | # requests 295 | # sentry-sdk 296 | virtualenv==20.28.0 297 | # via pre-commit 298 | wandb==0.18.7 299 | # via denoisers (pyproject.toml) 300 | yarl==1.17.2 301 | # via aiohttp 302 | -------------------------------------------------------------------------------- /src/denoisers/modeling/modules.py: -------------------------------------------------------------------------------- 1 | """Modules for the denoiser models.""" 2 | 3 | from typing import Any, Optional 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class Upsample1D(nn.Module): 10 | """A 1D upsampling layer with an optional convolution. 11 | 12 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L29 13 | 14 | Parameters 15 | ---------- 16 | in_channels (`int`): 17 | number of channels in the inputs and outputs. 18 | out_channels (`int`, optional): 19 | number of output channels. Defaults to `channels`. 20 | use_conv (`bool`, default `False`): 21 | option to use a convolution. 22 | use_conv_transpose (`bool`, default `False`): 23 | option to use a convolution transpose. 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: Optional[int] = None, 31 | kernel_size: int = 3, 32 | use_conv: bool = False, 33 | use_conv_transpose: bool = False, 34 | padding: int = 1, 35 | bias: bool = True, 36 | ): 37 | super().__init__() 38 | self.channels = in_channels 39 | self.out_channels = out_channels or in_channels 40 | self.use_conv = use_conv 41 | self.use_conv_transpose = use_conv_transpose 42 | 43 | self.conv: Any = None 44 | if use_conv_transpose: 45 | self.conv = nn.ConvTranspose1d( 46 | in_channels, 47 | self.out_channels, 48 | kernel_size=kernel_size, 49 | stride=2, 50 | padding=padding, 51 | bias=bias, 52 | ) 53 | elif use_conv: 54 | self.conv = nn.Conv1d( 55 | self.channels, 56 | self.out_channels, 57 | kernel_size=kernel_size, 58 | padding=padding, 59 | bias=bias, 60 | ) 61 | 62 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 63 | """Forward pass.""" 64 | if self.use_conv_transpose: 65 | return self.conv(inputs) 66 | 67 | outputs = nn.functional.interpolate( 68 | inputs, 69 | scale_factor=2.0, 70 | mode="linear", 71 | align_corners=True, 72 | ) 73 | 74 | if self.use_conv: 75 | outputs = self.conv(outputs) 76 | 77 | return outputs 78 | 79 | 80 | class Downsample1D(nn.Module): 81 | """A 1D downsampling layer with an optional convolution. 82 | 83 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L70 84 | 85 | Parameters 86 | ---------- 87 | in_channels (`int`): 88 | number of channels in the inputs and outputs. 89 | out_channels (`int`, optional): 90 | number of output channels. Defaults to `channels`. 91 | kernel_size (`int`, default `3`): 92 | kernel size for the convolution. 93 | stride (`int`, default `2`): 94 | stride for the convolution. 95 | use_conv (`bool`, default `False`): 96 | option to use a convolution. 97 | padding (`int`, default `1`): 98 | padding for the convolution. 
99 | """ 100 | 101 | def __init__( 102 | self, 103 | in_channels: int, 104 | out_channels: Optional[int] = None, 105 | kernel_size: int = 3, 106 | stride: int = 2, 107 | use_conv: bool = False, 108 | padding: int = 1, 109 | bias: bool = True, 110 | ): 111 | super().__init__() 112 | self.channels = in_channels 113 | self.out_channels = out_channels or in_channels 114 | self.use_conv = use_conv 115 | 116 | self.conv: Any = None 117 | if use_conv: 118 | self.conv = nn.Conv1d( 119 | self.channels, 120 | self.out_channels, 121 | kernel_size=kernel_size, 122 | stride=stride, 123 | padding=padding, 124 | bias=bias, 125 | ) 126 | else: 127 | self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) 128 | 129 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 130 | """Forward pass.""" 131 | return self.conv(inputs) 132 | 133 | 134 | class Activation(nn.Module): 135 | """Activation function.""" 136 | 137 | def __init__(self, name: str, channels: Optional[int] = None): 138 | super().__init__() 139 | if name == "silu": 140 | self.activation: nn.Module = nn.SiLU(inplace=True) 141 | elif name == "relu": 142 | self.activation = nn.ReLU(inplace=True) 143 | elif name == "leaky_relu": 144 | self.activation = nn.LeakyReLU(0.2, inplace=True) 145 | elif name == "snake": 146 | if channels is None: 147 | raise ValueError( 148 | "Number of channels must be specified for Snake activation." 149 | ) 150 | self.activation = Snake1d(hidden_dim=channels) 151 | else: 152 | raise ValueError(f"{name} activation is not supported.") 153 | 154 | def forward(self, x: torch.Tensor) -> torch.Tensor: 155 | """Forward Pass.""" 156 | x = self.activation(x) 157 | return x 158 | 159 | 160 | class Normalization(nn.Module): 161 | """Normalization layer.""" 162 | 163 | def __init__(self, in_channels: int, name: str, num_groups: Optional[int] = None): 164 | super().__init__() 165 | self.name = name 166 | if name == "batch": 167 | self.norm: nn.Module = nn.BatchNorm1d(in_channels) 168 | elif name == "instance": 169 | self.norm = nn.InstanceNorm1d(in_channels) 170 | elif name == "group": 171 | if num_groups is None: 172 | raise ValueError("Number of groups must be specified for GroupNorm.") 173 | self.norm = nn.GroupNorm(num_groups, in_channels) 174 | elif name == "layer": 175 | self.norm = nn.LayerNorm(in_channels) 176 | else: 177 | raise ValueError(f"{name} normalization is not supported.") 178 | 179 | def forward(self, x: torch.Tensor) -> torch.Tensor: 180 | """Forward Pass.""" 181 | if self.name == "layer": 182 | return self.norm(x.transpose(2, 1)).transpose(2, 1) 183 | return self.norm(x) 184 | 185 | 186 | class Snake1d(nn.Module): 187 | """A 1-dimensional Snake activation function module. 
188 | 189 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_oobleck.py#L30 190 | 191 | """ 192 | 193 | def __init__(self, hidden_dim: int, logscale: bool = False): 194 | super().__init__() 195 | self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1), requires_grad=True) 196 | self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1), requires_grad=True) 197 | self.logscale = logscale 198 | 199 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 200 | """Forward pass.""" 201 | shape = hidden_states.shape 202 | 203 | alpha = self.alpha if not self.logscale else torch.exp(self.alpha) 204 | beta = self.beta if not self.logscale else torch.exp(self.beta) 205 | 206 | hidden_states = hidden_states.reshape(shape[0], shape[1], -1) 207 | hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin( 208 | alpha * hidden_states 209 | ).pow(2) 210 | hidden_states = hidden_states.reshape(shape) 211 | return hidden_states 212 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/denoisers/transforms.py: -------------------------------------------------------------------------------- 1 | """Transforms.""" 2 | 3 | import random 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | import numpy as np 8 | import torch 9 | import torchaudio 10 | from pedalboard import Reverb # type: ignore 11 | from torch import Tensor, nn 12 | 13 | 14 | class GaussianNoise(nn.Module): 15 | """Gaussian Noise Transform.""" 16 | 17 | def __init__(self, p: float = 0.5, db_min: int = 1, db_max: int = 30) -> None: 18 | super().__init__() 19 | self.p = p 20 | self.db_min = db_min 21 | self.db_max = db_max 22 | 23 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 24 | """Forward Pass.""" 25 | if isinstance(x, np.ndarray): 26 | x = torch.from_numpy(x) 27 | 28 | if random.random() <= self.p: 29 | db = torch.randint(self.db_min, self.db_max, (1,)) 30 | x = torchaudio.functional.add_noise(x, torch.randn_like(x), snr=db) 31 | 32 | return x 33 | 34 | 35 | class FilterTransform(nn.Module): 36 | """Filter Transform.""" 37 | 38 | def __init__( 39 | self, 40 | sample_rate: int = 24000, 41 | freq_ceil: int = 12000, 42 | freq_floor: int = 0, 43 | gain_ceil: int = 20, 44 | gain_floor: int = -20, 45 | q: float = 0.707, 46 | p: float = 0.5, 47 | ) -> None: 48 | super().__init__() 49 | self.sample_rate = sample_rate 50 | self.freq_ceil = freq_ceil 51 | self.freq_floor = freq_floor 52 | self.gain_ceil = gain_ceil 53 | self.gain_floor = gain_floor 54 | self.q = q 55 | self.p = p 56 | 57 | def get_gain(self) -> float: 58 | """Calculate gain.""" 59 | return (self.gain_floor - self.gain_ceil) * random.random() + self.gain_ceil 60 | 61 | def get_center_freq(self) -> float: 62 | """Calculate center frequency.""" 63 | return (self.freq_floor - self.freq_ceil) * random.random() + self.freq_ceil 64 | 65 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 66 | """Forward Pass.""" 67 | if isinstance(x, np.ndarray): 68 | x = torch.from_numpy(x) 69 | 70 | if random.random() <= self.p: 71 | gain = self.get_gain() 72 | center_freq = self.get_center_freq() 73 | 74 | x = torchaudio.functional.equalizer_biquad( 75 | x, 76 | sample_rate=self.sample_rate, 77 | center_freq=center_freq, 78 | gain=gain, 79 | Q=self.q, 80 | ) 81 | 82 | return x 83 | 84 | 85 | class ClipTransform(nn.Module): 86 | """Clip Transform.""" 87 | 88 | def __init__( 89 | self, 90 | clip_ceil: float = 1.0, 91 | clip_floor: float = 0.5, 92 | p: float = 0.5, 93 | ) -> None: 94 | super().__init__() 95 | self.clip_ceil = clip_ceil 96 | self.clip_floor = clip_floor 97 | self.p = p 98 | 99 | def get_clip(self) -> float: 100 | """Calculate clip level.""" 101 | return (self.clip_floor - self.clip_ceil) * random.random() + self.clip_ceil 102 | 103 | def forward(self, x: Union[Tensor, np.ndarray]) -> Tensor: 104 | """Forward Pass.""" 105 | if isinstance(x, np.ndarray): 106 | x = torch.from_numpy(x) 107 | 108 | if random.random() <= self.p: 109 | clip_level = self.get_clip() 110 | x[torch.abs(x) > clip_level] = clip_level 111 | 112 | return x 113 | 114 | 115 | class BreakTransform(nn.Module): 116 | """Break Transform.""" 117 | 118 | def __init__( 119 | self, 120 | sample_rate: int = 24000, 121 | break_duration: float = 0.001, 122 | break_ceil: int = 50, 123 | break_floor: int = 10, 124 | p: float = 0.5, 125 | ) -> None: 126 | super().__init__() 127 | self.sample_rate = sample_rate 128 | self.break_segment = 
sample_rate * break_duration 129 | self.break_ceil = break_ceil 130 | self.break_floor = break_floor 131 | self.p = p 132 | 133 | def get_mask(self, x: Tensor) -> Tensor: 134 | """Calculate mask.""" 135 | break_count = ( 136 | self.break_floor - self.break_ceil 137 | ) * random.random() + self.break_ceil 138 | break_duration = break_count * self.break_segment 139 | mask = torch.ones(x.size()) 140 | break_start = int(x.size(0) * random.random()) 141 | break_end = int(min(x.size(0), break_start + break_duration)) 142 | mask[break_start:break_end] = 0 143 | return mask 144 | 145 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 146 | """Forward Pass.""" 147 | if isinstance(x, np.ndarray): 148 | x = torch.from_numpy(x) 149 | 150 | if random.random() <= self.p: 151 | break_mask = self.get_mask(x) 152 | x = x * break_mask 153 | 154 | return x 155 | 156 | 157 | class ReverbFromSoundboard(nn.Module): 158 | """Reverb Transform.""" 159 | 160 | def __init__(self, sample_rate: int = 24000, p: float = 0.5) -> None: 161 | super().__init__() 162 | self.sample_rate = sample_rate 163 | self.reverb = Reverb() 164 | self.p = p 165 | 166 | @torch.no_grad() 167 | def forward(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: 168 | """Forward Pass.""" 169 | if isinstance(x, torch.Tensor): 170 | x = x.numpy() 171 | 172 | if random.random() <= self.p: 173 | self.reverb.room_size = random.random() 174 | x = self.reverb.process(x, self.sample_rate) 175 | 176 | x = torch.from_numpy(x) 177 | 178 | return x 179 | 180 | 181 | class SpecTransform(nn.Module): 182 | """Spectrogram Transform.""" 183 | 184 | def __init__(self, p: float = 0.5) -> None: 185 | super().__init__() 186 | self.a_hp = torch.tensor([-1.99599, 0.99600]) 187 | self.b_hp = torch.tensor([-2, 1]) 188 | self.p = p 189 | 190 | def _uni_rand(self) -> Tensor: 191 | return torch.rand(1) - 0.5 192 | 193 | def _rand_resp(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: 194 | a1 = 0.75 * self._uni_rand() 195 | a2 = 0.75 * self._uni_rand() 196 | b1 = 0.75 * self._uni_rand() 197 | b2 = 0.75 * self._uni_rand() 198 | return a1, a2, b1, b2 199 | 200 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 201 | """Forward Pass.""" 202 | if isinstance(x, np.ndarray): 203 | x = torch.from_numpy(x) 204 | 205 | if random.random() <= self.p: 206 | a1, a2, b1, b2 = self._rand_resp() 207 | x = torchaudio.functional.biquad( 208 | x, 209 | 1, 210 | self.b_hp[0], 211 | self.b_hp[1], 212 | 1, 213 | self.a_hp[0], 214 | self.a_hp[1], 215 | ) 216 | x = torchaudio.functional.biquad(x, 1, b1, b2, 1, a1, a2) 217 | 218 | return x 219 | 220 | 221 | class VolTransform(nn.Module): 222 | """Volume Transform.""" 223 | 224 | def __init__( 225 | self, 226 | sample_rate: int = 24000, 227 | segment_len: float = 0.5, 228 | vol_ceil: int = 10, 229 | vol_floor: int = -10, 230 | p: float = 0.5, 231 | ) -> None: 232 | super().__init__() 233 | self.sample_rate = sample_rate 234 | self.segment_len = segment_len 235 | self.segment_samples = int(self.sample_rate * self.segment_len) 236 | self.vol_ceil = vol_ceil 237 | self.vol_floor = vol_floor 238 | self.p = p 239 | 240 | def get_vol(self, sample_length: int) -> Tensor: 241 | """Get volume.""" 242 | segments = sample_length / (self.segment_len * self.sample_rate) 243 | step_db = torch.arange( 244 | self.vol_ceil, 245 | self.vol_floor, 246 | (self.vol_floor - self.vol_ceil) / segments, 247 | ) 248 | return step_db 249 | 250 | @staticmethod 251 | def apply_gain(segments: Tensor, db: Tensor) -> 
Tensor: 252 | """Apply gain.""" 253 | gain = torch.pow(10.0, (0.05 * db)) 254 | segments = segments * gain 255 | return segments 256 | 257 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 258 | """Forward Pass.""" 259 | if isinstance(x, np.ndarray): 260 | x = torch.from_numpy(x) 261 | 262 | if random.random() <= self.p: 263 | step_db = self.get_vol(x.size(0)) 264 | for i in range(step_db.size(0)): 265 | start = i * self.segment_samples 266 | end = min((i + 1) * self.segment_samples, x.size(0)) 267 | x[start:end] = self.apply_gain(x[start:end], step_db[i]) 268 | 269 | return x 270 | 271 | 272 | class NoiseFromFile(nn.Module): 273 | """Add background noise from random file.""" 274 | 275 | def __init__( 276 | self, 277 | root: Path, 278 | p: float = 1.0, 279 | sample_rate: int = 24000, 280 | num_samples: int = 1000, 281 | ) -> None: 282 | super().__init__() 283 | self.root = root 284 | self.p = p 285 | self.sample_rate = sample_rate 286 | noise_paths = random.choices(list(root.glob("**/*.flac")), k=num_samples) 287 | self.noises = [torchaudio.load(str(noise))[0] for noise in noise_paths] 288 | print(f"Loaded {len(self.noises)} noises") 289 | 290 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 291 | """Forward Pass.""" 292 | if isinstance(x, np.ndarray): 293 | x = torch.from_numpy(x) 294 | 295 | if random.random() < self.p: 296 | noise = random.choice(self.noises).to(x.device) 297 | x = x + noise[:, : x.size(1)] 298 | 299 | return x 300 | 301 | 302 | class ReverbFromFile(nn.Module): 303 | """Add reverb to a sample from a rir file.""" 304 | 305 | def __init__( 306 | self, 307 | root: Path, 308 | p: float = 0.5, 309 | sample_rate: int = 24000, 310 | num_samples: int = 1000, 311 | ) -> None: 312 | super().__init__() 313 | self.root = root 314 | self.p = p 315 | self.sample_rate = sample_rate 316 | response_paths = random.choices(list(root.glob("**/*.wav")), k=num_samples) 317 | self.responses = [torchaudio.load(str(r))[0] for r in response_paths] 318 | 319 | def forward(self, x: Union[Tensor, np.ndarray]) -> Union[Tensor, np.ndarray]: 320 | """Forward Pass.""" 321 | if isinstance(x, np.ndarray): 322 | x = torch.from_numpy(x) 323 | 324 | if random.random() < self.p: 325 | rir_raw = random.choice(self.responses) 326 | rir_raw = rir_raw[random.randint(0, rir_raw.shape[0] - 1)][None] 327 | rir = rir_raw 328 | rir = rir / torch.norm(rir, p=2) 329 | rir = torch.flip(rir, [1]) 330 | x = torch.nn.functional.pad(x, (rir.shape[1] - 1, 0)) 331 | x = nn.functional.conv1d(x[None, ...], rir[None, ...])[0] 332 | 333 | return x 334 | 335 | 336 | class FreqMask(nn.Module): 337 | """Frequency Mask.""" 338 | 339 | def __init__(self, num_masks: int, size: int, p: float = 0.5) -> None: 340 | super().__init__() 341 | self.num_masks = num_masks 342 | self.size = size 343 | self.p = p 344 | self.transform = torchaudio.transforms.FrequencyMasking( 345 | freq_mask_param=self.size, 346 | ) 347 | 348 | def forward(self, x: Tensor) -> Tensor: 349 | """Forward Pass.""" 350 | if random.random() <= self.p: 351 | for _ in range(self.num_masks): 352 | x = self.transform(x) 353 | return x 354 | 355 | 356 | class TimeMask(nn.Module): 357 | """Time Mask.""" 358 | 359 | def __init__(self, num_masks: int, size: int, p: float = 0.5) -> None: 360 | super().__init__() 361 | self.num_masks = num_masks 362 | self.size = size 363 | self.p = p 364 | self.transform = torchaudio.transforms.TimeMasking( 365 | time_mask_param=self.size, 366 | p=1.0, 367 | ) 368 | 369 | def 
forward(self, x: Tensor) -> Tensor: 370 | """Forward Pass.""" 371 | if random.random() <= self.p: 372 | for _ in range(self.num_masks): 373 | x = self.transform(x) 374 | return x 375 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoisers 2 | 3 | [![PyPI version](https://badge.fury.io/py/denoisers.svg)](https://badge.fury.io/py/denoisers) 4 | [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) 5 | [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) 6 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wrice/denoisers) 7 | 8 | Denoisers is a comprehensive deep learning library for audio denoising with a focus on simplicity, flexibility, and state-of-the-art performance. The library provides two main neural network architectures optimized for different use cases: WaveUNet for high-quality waveform processing and UNet1D for efficient real-time applications. 9 | 10 | ## 🎯 Key Features 11 | 12 | - **Two State-of-the-Art Architectures**: 13 | 14 | - **WaveUNet**: Based on the [original paper](https://arxiv.org/abs/1806.03185), optimized for high-fidelity audio restoration 15 | - **UNet1D**: Custom architecture inspired by diffusion models, designed for efficiency and real-time processing 16 | 17 | - **Pre-trained Models**: Ready-to-use models available on Hugging Face Hub 18 | 19 | - **Easy Integration**: Simple API for both inference and training 20 | 21 | - **Production Ready**: Built with PyTorch Lightning for scalable training and deployment 22 | 23 | - **Comprehensive Metrics**: Advanced audio quality metrics including PESQ, STOI, and DNS-MOS 24 | 25 | - **Flexible Training**: Support for various loss functions, data augmentation, and training strategies 26 | 27 | ## 🚀 Quick Start 28 | 29 | ### Installation 30 | 31 | ```bash 32 | pip install denoisers 33 | ``` 34 | 35 | ### Basic Usage 36 | 37 | #### Inference with Pre-trained Models 38 | 39 | ```python 40 | import torch 41 | import torchaudio 42 | from denoisers import WaveUNetModel 43 | from tqdm import tqdm 44 | 45 | # Load pre-trained model 46 | model = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz") 47 | 48 | # Load and preprocess audio 49 | audio, sr = torchaudio.load("noisy_audio.wav") 50 | 51 | # Resample if necessary 52 | if sr != model.config.sample_rate: 53 | audio = torchaudio.functional.resample(audio, sr, model.config.sample_rate) 54 | 55 | # Convert to mono if stereo 56 | if audio.size(0) > 1: 57 | audio = audio.mean(0, keepdim=True) 58 | 59 | # Process audio in chunks to handle long files 60 | chunk_size = model.config.max_length 61 | padding = abs(audio.size(-1) % chunk_size - chunk_size) 62 | padded = torch.nn.functional.pad(audio, (0, padding)) 63 | 64 | clean = [] 65 | for i in tqdm(range(0, padded.shape[-1], chunk_size)): 66 | audio_chunk = padded[:, i : i + chunk_size] 67 | with torch.no_grad(): 68 | clean_chunk = model(audio_chunk[None]).audio 69 | clean.append(clean_chunk.squeeze(0)) 70 | 71 | # Concatenate results and remove padding 72 | denoised = torch.concat(clean, 1)[:, : audio.shape[-1]] 73 | 74 | # Save denoised audio 75 | torchaudio.save("clean_audio.wav", denoised, model.config.sample_rate) 76 | ``` 77 | 78 | #### Available Pre-trained Models 79 | 80 | | Model | Sample Rate | Architecture | Use Case | 81 | | 
------------------------------- | ----------- | ------------ | ----------------------------- | 82 | | `wrice/waveunet-vctk-24khz` | 24kHz | WaveUNet | Efficient speech denoising | 83 | | `wrice/waveunet-vctk-48khz` | 48kHz | WaveUNet | High-quality speech denoising | 84 | | `wrice/unet1d-vctk-24khz` | 24kHz | UNet1D | Efficient speech denoising | 85 | | `wrice/unet1d-vctk-48khz` | 48kHz | UNet1D | High-quality speech denoising | 86 | | `wrice/unet1d-vctk-8to48khz` | 8-48kHz | UNet1D | Robust multi-rate denoising | 87 | | `wrice/unet1d-xeno-canto-32khz` | 32kHz | UNet1D | Birdsong denoising | 88 | 89 | ## 🏗️ Architecture Overview 90 | 91 | ### WaveUNet 92 | 93 | The WaveUNet architecture implements a U-Net style network specifically designed for waveform processing: 94 | 95 | - **Encoder-Decoder Architecture**: Progressive downsampling followed by upsampling with skip connections 96 | - **Waveform Processing**: Direct operation on raw audio waveforms 97 | - **High Quality**: Optimized for maximum audio fidelity 98 | - **Configurable Depth**: Adjustable network depth for different complexity requirements 99 | 100 | **Key Parameters:** 101 | 102 | - `in_channels`: Channel progression through the network (default: 24→288) 103 | - `downsample_kernel_size`: Kernel size for downsampling layers (default: 15) 104 | - `upsample_kernel_size`: Kernel size for upsampling layers (default: 5) 105 | - `activation`: Activation function (default: "leaky_relu") 106 | - `max_length`: Maximum input length (default: 163,840 samples) 107 | 108 | ### UNet1D 109 | 110 | The UNet1D architecture is a custom implementation inspired by modern diffusion models: 111 | 112 | - **Efficient Design**: Optimized for computational efficiency and memory usage 113 | - **Real-time Capable**: Suitable for real-time audio processing applications 114 | - **Modern Architecture**: Incorporates latest advances in deep learning for audio 115 | - **Flexible Configuration**: Highly configurable for different audio types and quality requirements 116 | 117 | **Key Parameters:** 118 | 119 | - `channels`: Channel progression (default: 32→384) 120 | - `kernel_size`: Convolution kernel size (default: 3) 121 | - `activation`: Activation function (default: "silu") 122 | - `max_length`: Maximum input length (default: 48,000 samples) 123 | 124 | ## 🔧 Training Your Own Models 125 | 126 | ### Data Preparation 127 | 128 | Organize your training data in the following structure: 129 | 130 | ``` 131 | data_root/ 132 | ├── clean_audio1.wav 133 | ├── clean_audio2.flac 134 | ├── clean_audio3.mp3 135 | └── ... 
136 | ``` 137 | 138 | Supported formats: `.wav`, `.flac`, `.mp3`, `.ogg` 139 | 140 | ### Training Command 141 | 142 | ```bash 143 | # Train a WaveUNet model 144 | train waveunet my-model-name /path/to/data_root/ \ 145 | --batch_size 32 \ 146 | --num_devices 2 \ 147 | --ema \ 148 | --push_to_hub 149 | 150 | # Train a UNet1D model 151 | train unet1d my-unet1d-model /path/to/data_root/ \ 152 | --batch_size 64 \ 153 | --num_workers 8 \ 154 | --seed 42 155 | ``` 156 | 157 | ### Training Parameters 158 | 159 | | Parameter | Description | Default | 160 | | --------------- | --------------------------------------------- | --------------------- | 161 | | `--batch_size` | Training batch size | 64 | 162 | | `--num_devices` | Number of GPUs to use | 1 (if CUDA available) | 163 | | `--num_workers` | Data loading workers | 4 | 164 | | `--ema` | Enable Exponential Moving Average | False | 165 | | `--push_to_hub` | Push model to Hugging Face Hub after training | False | 166 | | `--seed` | Random seed for reproducibility | 1234 | 167 | | `--project` | Weights & Biases project name | "denoisers" | 168 | 169 | ### Custom Configuration 170 | 171 | Create custom model configurations by extending the base config classes: 172 | 173 | ```python 174 | from denoisers.modeling.waveunet.config import WaveUNetConfig 175 | 176 | # Custom WaveUNet configuration 177 | config = WaveUNetConfig( 178 | in_channels=(32, 64, 128, 256, 512), 179 | sample_rate=16000, 180 | max_length=32768, 181 | activation="relu", 182 | dropout=0.2, 183 | ) 184 | ``` 185 | 186 | ## 📊 Loss Functions and Metrics 187 | 188 | ### Loss Functions 189 | 190 | The library implements several advanced loss functions optimized for audio denoising: 191 | 192 | - **L1 Loss**: Basic reconstruction loss 193 | - **Multi-Resolution STFT Loss**: Frequency-domain loss for better perceptual quality 194 | - Spectral Convergence Loss 195 | - Log STFT Magnitude Loss 196 | - Multiple resolution scales for comprehensive frequency coverage 197 | 198 | ### Evaluation Metrics 199 | 200 | Comprehensive audio quality assessment: 201 | 202 | - **SNR**: Signal-to-Noise Ratio 203 | - **SDR**: Signal-to-Distortion Ratio 204 | - **SI-SNR**: Scale-Invariant Signal-to-Noise Ratio 205 | - **SI-SDR**: Scale-Invariant Signal-to-Distortion Ratio 206 | - **DNS-MOS**: Deep Noise Suppression Mean Opinion Score 207 | - **PESQ**: Perceptual Evaluation of Speech Quality 208 | 209 | ## 🔬 Advanced Features 210 | 211 | ### Data Augmentation 212 | 213 | Built-in audio augmentation pipeline using `audiomentations`: 214 | 215 | - **Gaussian Noise Addition**: Simulates various noise conditions 216 | - **Colored Noise**: Pink, brown, and other colored noise types 217 | - **Room Simulation**: Realistic acoustic environments using `pyroomacoustics` 218 | - **Dynamic Sample Rate**: Training with multiple sample rates for robustness 219 | 220 | ### Exponential Moving Average (EMA) 221 | 222 | Optional EMA model averaging for improved training stability and performance: 223 | 224 | ```bash 225 | train waveunet my-model /data/ --ema 226 | ``` 227 | 228 | ### Mixed Precision Training 229 | 230 | Automatic mixed precision training with PyTorch Lightning for faster training and reduced memory usage. 
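
Under the hood this is just PyTorch Lightning's `precision` flag rather than anything denoiser-specific. The exact flags the bundled `train` script passes aren't shown here, so treat the snippet below as a minimal sketch:

```python
from pytorch_lightning import Trainer

# "16-mixed" runs forward/backward under autocast with gradient scaling,
# while optimizer state and master weights stay in full precision.
trainer = Trainer(accelerator="auto", devices=1, precision="16-mixed")
```

On hardware with bfloat16 support, `precision="bf16-mixed"` offers similar savings without the need for gradient scaling.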
## 🔬 Advanced Features

### Data Augmentation

Built-in audio augmentation pipeline using `audiomentations`:

- **Gaussian Noise Addition**: Simulates various noise conditions
- **Colored Noise**: Pink, brown, and other colored noise types
- **Room Simulation**: Realistic acoustic environments using `pyroomacoustics`
- **Dynamic Sample Rate**: Training with multiple sample rates for robustness

### Exponential Moving Average (EMA)

Optional EMA model averaging for improved training stability and performance:

```bash
train waveunet my-model /data/ --ema
```

### Mixed Precision Training

Automatic mixed precision training with PyTorch Lightning for faster training and reduced memory usage.

### Weights & Biases Integration

Built-in experiment tracking and visualization:

- Training and validation metrics
- Audio sample logging
- Model checkpoints
- Hyperparameter tracking

## 📈 Model Publishing

Easily share your trained models on the Hugging Face Hub:

```bash
publish waveunet my-awesome-model /path/to/model/checkpoint
```

This automatically:

- Uploads model weights and configuration
- Creates model cards with training details
- Enables easy model sharing and distribution

## 🛠️ Development

### Setup Development Environment

```bash
# Clone the repository
git clone https://github.com/will-rice/denoisers.git
cd denoisers

# Install development dependencies
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install
```

### Code Quality

The project maintains high code quality standards:

- **Type Checking**: MyPy static type analysis
- **Linting**: Ruff for code formatting and style
- **Testing**: Pytest with comprehensive test coverage
- **Documentation**: Google-style docstrings

### Running Tests

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=denoisers --cov-report=html
```

## 📚 API Reference

### Core Models

#### WaveUNetModel

```python
class WaveUNetModel(PreTrainedModel):
    """WaveUNet model for audio denoising."""

    def forward(self, inputs: torch.Tensor) -> WaveUNetModelOutputs:
        """
        Args:
            inputs: Noisy audio tensor [batch_size, channels, length]

        Returns:
            WaveUNetModelOutputs with .audio attribute containing cleaned audio
        """
```

#### UNet1DModel

```python
class UNet1DModel(PreTrainedModel):
    """UNet1D model for efficient audio denoising."""

    def forward(self, inputs: torch.Tensor) -> UNet1DModelOutputs:
        """
        Args:
            inputs: Noisy audio tensor [batch_size, channels, length]

        Returns:
            UNet1DModelOutputs with .audio attribute containing cleaned audio
        """
```
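Because both models subclass `PreTrainedModel`, `from_pretrained` should accept any checkpoint ID from the pretrained models table above. A minimal inference sketch, with random noise standing in for a real recording:

```python
import torch

from denoisers import WaveUNetModel

model = WaveUNetModel.from_pretrained("wrice/waveunet-vctk-24khz")
model.eval()

# [batch_size, channels, length] per the forward() docstring above.
noisy = torch.randn(1, 1, 24_000)

with torch.no_grad():
    outputs = model(noisy)

denoised = outputs.audio  # cleaned waveform
```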
### Datasets

#### AudioDataset

```python
class AudioDataset(Dataset):
    """Audio dataset with automatic noise synthesis."""

    def __init__(
        self,
        root: Path,
        max_length: int,
        sample_rate: int,
        variable_sample_rate: bool = True,
    ):
        """
        Args:
            root: Path to audio files
            max_length: Maximum audio length in samples
            sample_rate: Target sample rate
            variable_sample_rate: Whether to use variable sample rates during training
        """
```

## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details on:

- Setting up the development environment
- Code style and standards
- Testing requirements
- Pull request process

### Areas for Contribution

- New model architectures
- Additional loss functions
- Enhanced data augmentation
- Performance optimizations
- Documentation improvements

## 📄 License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- [WaveUNet paper](https://arxiv.org/abs/1806.03185) for the original architecture
- The PyTorch Lightning team for the excellent training framework
- Hugging Face for model hosting and distribution
- The open-source audio processing community

## 📞 Support

- **Issues**: [GitHub Issues](https://github.com/will-rice/denoisers/issues)
- **Discussions**: [GitHub Discussions](https://github.com/will-rice/denoisers/discussions)
- **Email**: wrice20@gmail.com

## 🔗 Links

- [Hugging Face Demo](https://huggingface.co/spaces/wrice/denoisers)
- [PyPI Package](https://pypi.org/project/denoisers/)
- [Documentation](https://github.com/will-rice/denoisers/wiki)
- [Model Zoo](https://huggingface.co/wrice)

--------------------------------------------------------------------------------