├── tests
│   ├── __init__.py
│   ├── manual_test_gpu.py
│   ├── test_auraloss.py
│   └── simple_train_gpu.py
├── .pre-commit-config.yaml
├── auraloss
│   ├── __init__.py
│   ├── plotting.py
│   ├── utils.py
│   ├── perceptual.py
│   ├── time.py
│   └── freq.py
├── examples
│   ├── compressor
│   │   ├── test.sh
│   │   ├── train.sh
│   │   ├── comp-metrics.json
│   │   ├── _test_comp.py
│   │   ├── README.md
│   │   ├── train_comp.py
│   │   ├── data.py
│   │   └── tcn.py
│   └── speech-denoise
│       ├── train_denoise.py
│       └── data.py
├── pyproject.toml
├── .github
│   └── workflows
│       ├── tests.yml
│       └── publish-to-test-pypi.yml
├── setup.py
├── docs
│   └── auraloss-logo.svg
├── .gitignore
├── README.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: 21.5b1
4 | hooks:
5 | - id: black
6 | language_version: python3
--------------------------------------------------------------------------------
/auraloss/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level module for auraloss"""
2 |
3 | # Import auraloss sub-modules
4 | from . import freq
5 | from . import time
6 | from . import perceptual
7 |
--------------------------------------------------------------------------------
/tests/manual_test_gpu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import auraloss
3 |
4 | y_hat = torch.randn(2, 1, 131072)
5 | y = torch.randn(2, 1, 131072)
6 |
7 | loss_fn = auraloss.freq.MelSTFTLoss(44100)
8 | loss_fn2 = auraloss.freq.MultiResolutionSTFTLoss()
9 |
10 | # loss_fn.cuda()
11 |
12 | y_hat = y_hat.cuda()
13 | y = y.cuda()
14 |
15 | loss = loss_fn2(y_hat, y)
16 | loss = loss_fn(y_hat, y)
17 |
--------------------------------------------------------------------------------
/examples/compressor/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python examples/compressor/test_comp.py \
2 | --root_dir /import/c4dm-datasets/SignalTrain_LA2A_Dataset_1.1 \
3 | --logdir lightning_logs/version_9 \
4 | --batch_size 128 \
5 | --sample_rate 44100 \
6 | --eval_subset "test" \
7 | --eval_length 262144 \
8 | --num_workers 8 \
9 | --gpus 1 \
10 | --precision 16 \
11 | --preload True \
12 | --save_dir "./examples/compressor/audio" \
13 | --num_examples 100 \
14 | #--auto_lr_find
15 |
--------------------------------------------------------------------------------
/examples/compressor/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python examples/compressor/train_comp.py \
2 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \
3 | --max_epochs 20 \
4 | --batch_size 128 \
5 | --sample_rate 44100 \
6 | --train_length 32768 \
7 | --eval_length 262144 \
8 | --num_workers 8 \
9 | --kernel_size 15 \
10 | --channel_width 32 \
11 | --dilation_growth 2 \
12 | --lr 0.001 \
13 | --gpus 1 \
14 | --shuffle True \
15 | --precision 16 \
16 | --preload True \
17 | #--auto_lr_find
18 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "auraloss"
3 | version = "0.4.0"
4 | description = "Collection of audio-focused loss functions in PyTorch."
5 | authors = [
6 |     { name = "Christian Steinmetz", email = "c.j.steinmetz@qmul.ac.uk" },
7 | ]
8 | 
9 | dependencies = ["torch", "numpy"]
10 |
11 | [build-system]
12 | # Minimum requirements for the build system to execute.
13 | requires = ["setuptools", "wheel", "attrs"]
14 | build-backend = "setuptools.build_meta"
15 |
16 | [project.optional-dependencies]
17 | all = ["matplotlib", "librosa", "scipy"]
18 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: pytest
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.7", "3.8", "3.9", "3.10"]
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | sudo apt-get install libsndfile1-dev
23 | pip install .[all]
24 | pip install pytest
25 | - name: Test with pytest
26 | run: |
27 | pytest
--------------------------------------------------------------------------------
/auraloss/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.signal
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def compare_filters(iir_b, iir_a, fir_b, fs=1):
7 |
8 | # compute response for IIR filter
9 | w_iir, h_iir = scipy.signal.freqz(iir_b, iir_a, fs=fs, worN=2048)
10 |
11 | # compute response for FIR filter
12 | w_fir, h_fir = scipy.signal.freqz(fir_b, fs=fs)
13 |
14 | h_iir_db = 20 * np.log10(np.abs(h_iir) + 1e-8)
15 | h_fir_db = 20 * np.log10(np.abs(h_fir) + 1e-8)
16 |
17 | plt.plot(w_iir, h_iir_db, label="IIR filter")
18 | plt.plot(w_fir, h_fir_db, label="FIR approx. filter")
19 | plt.xscale("log")
20 | plt.ylim([-50, 10])
21 | plt.xlim([10, 22.05e3])
22 | plt.xlabel("Freq. (Hz)")
23 | plt.ylabel("Mag. (dB)")
24 | plt.legend()
25 | plt.grid()
26 | plt.show()
27 |
--------------------------------------------------------------------------------
/auraloss/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scipy.signal
3 |
4 |
5 | def apply_reduction(losses, reduction="none"):
6 | """Apply reduction to collection of losses."""
7 | if reduction == "mean":
8 | losses = losses.mean()
9 | elif reduction == "sum":
10 | losses = losses.sum()
11 | return losses
12 |
13 | def get_window(win_type: str, win_length: int):
14 | """Return a window function.
15 |
16 | Args:
17 |         win_type (str): Window type. Can either be one of the window functions provided by PyTorch
18 | ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
19 | or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
20 | win_length (int): Window length
21 |
22 | Returns:
23 | win: The window as a 1D torch tensor
24 | """
25 |
26 | try:
27 | win = getattr(torch, win_type)(win_length)
28 |     except AttributeError:  # not a torch window function, fall back to SciPy
29 | win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
30 |
31 | return win
32 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-test-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python distributions to PyPI and TestPyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Publish Python distributions to PyPI and TestPyPI
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@master
11 | - name: Set up Python 3.7
12 | uses: actions/setup-python@v1
13 | with:
14 | python-version: 3.7
15 | - name: Install pypa/build
16 | run: >-
17 | python -m
18 | pip install
19 | build
20 | --user
21 | - name: Build a binary wheel and a source tarball
22 | run: >-
23 | python -m
24 | build
25 | --sdist
26 | --wheel
27 | --outdir dist/
28 | .
29 | - name: Publish distribution to Test PyPI
30 | if: startsWith(github.ref, 'refs/tags')
31 | uses: pypa/gh-action-pypi-publish@master
32 | with:
33 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
34 | repository_url: https://test.pypi.org/legacy/
35 | - name: Publish distribution to PyPI
36 | if: startsWith(github.ref, 'refs/tags')
37 | uses: pypa/gh-action-pypi-publish@master
38 | with:
39 | password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Inspired from https://github.com/kennethreitz/setup.py
3 |
4 | from pathlib import Path
5 | from setuptools import setup, find_packages
6 |
7 | NAME = "auraloss"
8 | DESCRIPTION = "Audio-focused loss functions in PyTorch"
9 | URL = "https://github.com/csteinmetz1/auraloss"
10 | EMAIL = "c.j.steinmetz@qmul.ac.uk"
11 | AUTHOR = "Christian Steinmetz"
12 | REQUIRES_PYTHON = ">=3.6.0"
13 | VERSION = "0.4.0"
14 |
15 | HERE = Path(__file__).parent
16 |
17 | try:
18 | with open(HERE / "README.md", encoding="utf-8") as f:
19 | long_description = "\n" + f.read()
20 | except FileNotFoundError:
21 | long_description = DESCRIPTION
22 |
23 | setup(
24 | name=NAME,
25 | version=VERSION,
26 | description=DESCRIPTION,
27 | long_description=long_description,
28 | long_description_content_type="text/markdown",
29 | author=AUTHOR,
30 | author_email=EMAIL,
31 | python_requires=REQUIRES_PYTHON,
32 | url=URL,
33 | packages=["auraloss"],
34 | install_requires=["torch", "numpy"],
35 | extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]},
36 | include_package_data=True,
37 | license="Apache License 2.0",
38 | classifiers=[
39 | "License :: OSI Approved :: Apache Software License",
40 | "Topic :: Multimedia :: Sound/Audio",
41 | "Topic :: Scientific/Engineering",
42 | ],
43 | )
44 |
--------------------------------------------------------------------------------
/docs/auraloss-logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/speech-denoise/train_denoise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from argparse import ArgumentParser
4 |
5 | from tcn import TCNModel
6 | from data import LibriMixDataset
7 |
8 | parser = ArgumentParser()
9 |
10 | # add PROGRAM level args
11 | parser.add_argument("--root_dir", type=str, default="./data")
12 | parser.add_argument("--sample_rate", type=int, default=8000)
13 | parser.add_argument("--train_subset", type=str, default="train")
14 | parser.add_argument("--val_subset", type=str, default="val")
15 | parser.add_argument("--train_length", type=int, default=16384)
16 | parser.add_argument("--eval_length", type=int, default=32768)
17 | parser.add_argument("--batch_size", type=int, default=8)
18 | parser.add_argument("--num_workers", type=int, default=0)
19 |
20 | # add model specific args
21 | parser = TCNModel.add_model_specific_args(parser)
22 |
23 | # add all the available trainer options to argparse
24 | parser = pl.Trainer.add_argparse_args(parser)
25 |
26 | # parse them args
27 | args = parser.parse_args()
28 |
29 | # init the trainer and model
30 | trainer = pl.Trainer.from_argparse_args(args)
31 |
32 | # setup the dataloaders
33 | train_dataset = LibriMixDataset(
34 | args.root_dir, subset=args.train_subset, length=args.train_length
35 | )
36 | train_dataloader = torch.utils.data.DataLoader(
37 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers
38 | )
39 |
40 | val_dataset = LibriMixDataset(
41 | args.root_dir, subset=args.val_subset, length=args.eval_length
42 | )
43 | val_dataloader = torch.utils.data.DataLoader(
44 | val_dataset, batch_size=args.batch_size, num_workers=args.num_workers
45 | )
46 |
47 | dict_args = vars(args)
48 | model = TCNModel(**dict_args)
49 |
50 |
51 | # find proper learning rate
52 | trainer.tune(model, train_dataloader)
53 |
54 | # train!
55 | trainer.fit(model, train_dataloader, val_dataloader)
56 |
--------------------------------------------------------------------------------
/examples/compressor/comp-metrics.json:
--------------------------------------------------------------------------------
1 | {
2 | "l1": [
3 | {
4 | "val_loss": 2.192999839782715,
5 | "val_loss/L1": 0.004868578165769577,
6 | "val_loss/ESR": 0.009247570298612118,
7 | "val_loss/DC": 0.00017245385970454663,
8 | "val_loss/LogCosh": 2.789249992929399e-05,
9 | "val_loss/STFT": 0.8255783319473267,
10 | "val_loss/MRSTFT": 0.7987862825393677,
11 | "val_loss/RRSTFT": 0.5543186664581299
12 | }
13 | ],
14 | "logcosh": [
15 | {
16 | "val_loss": 2.149608612060547,
17 | "val_loss/L1": 0.005570689216256142,
18 | "val_loss/ESR": 0.010359607636928558,
19 | "val_loss/DC": 0.0003921029274351895,
20 | "val_loss/LogCosh": 3.224901593057439e-05,
21 | "val_loss/STFT": 0.8068493604660034,
22 | "val_loss/MRSTFT": 0.7804977297782898,
23 | "val_loss/RRSTFT": 0.5459066033363342
24 | }
25 | ],
26 | "esr+dc": [
27 | {
28 | "val_loss": 2.2166659832000732,
29 | "val_loss/L1": 0.005301260389387608,
30 | "val_loss/ESR": 0.009672155603766441,
31 | "val_loss/DC": 4.820536560146138e-05,
32 | "val_loss/LogCosh": 3.0196371881174855e-05,
33 | "val_loss/STFT": 0.8323166370391846,
34 | "val_loss/MRSTFT": 0.8060272932052612,
35 | "val_loss/RRSTFT": 0.5632702112197876
36 | }
37 | ],
38 | "stft": [
39 | {
40 | "val_loss": 1.2915085554122925,
41 | "val_loss/L1": 0.008976267650723457,
42 | "val_loss/ESR": 0.054853539913892746,
43 | "val_loss/DC": 1.080286983778933e-05,
44 | "val_loss/LogCosh": 0.00017625998589210212,
45 | "val_loss/STFT": 0.452709823846817,
46 | "val_loss/MRSTFT": 0.43302521109580994,
47 | "val_loss/RRSTFT": 0.34175658226013184
48 | }
49 | ],
50 | "mrstft": [
51 | {
52 | "val_loss": 1.2609974145889282,
53 | "val_loss/L1": 0.008976605720818043,
54 | "val_loss/ESR": 0.056137651205062866,
55 | "val_loss/DC": 5.441876783152111e-05,
56 | "val_loss/LogCosh": 0.0001797690347302705,
57 | "val_loss/STFT": 0.44097578525543213,
58 | "val_loss/MRSTFT": 0.42079290747642517,
59 | "val_loss/RRSTFT": 0.33388033509254456
60 | }
61 | ],
62 | "rrstft": [
63 | {
64 | "val_loss": 1.6379966735839844,
65 | "val_loss/L1": 0.015477120876312256,
66 | "val_loss/ESR": 0.21662524342536926,
67 | "val_loss/DC": 0.00022139745124150068,
68 | "val_loss/LogCosh": 0.0007051324937492609,
69 | "val_loss/STFT": 0.5189737677574158,
70 | "val_loss/MRSTFT": 0.49819114804267883,
71 | "val_loss/RRSTFT": 0.38780292868614197
72 | }
73 | ]
74 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # more
132 | .vscode
133 | data/
134 | .DS_Store
135 | lightning_logs/
136 |
137 | test-env/
138 | *.wav
139 |
--------------------------------------------------------------------------------
/examples/compressor/_test_comp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import torch
5 | import torchsummary
6 | import pytorch_lightning as pl
7 | from argparse import ArgumentParser
8 |
9 | from tcn import TCNModel
10 | from data import SignalTrainLA2ADataset
11 |
12 | parser = ArgumentParser()
13 |
14 | # add PROGRAM level args
15 | parser.add_argument("--root_dir", type=str, default="./data")
16 | parser.add_argument("--preload", type=bool, default=False)
17 | parser.add_argument("--sample_rate", type=int, default=44100)
18 | parser.add_argument("--logdir", type=str, default="./")
19 | parser.add_argument("--eval_subset", type=str, default="val")
20 | parser.add_argument("--eval_length", type=int, default=262144)
21 | parser.add_argument("--batch_size", type=int, default=8)
22 | parser.add_argument("--num_workers", type=int, default=0)
23 |
24 | # add model specific args
25 | parser = TCNModel.add_model_specific_args(parser)
26 |
27 | # add all the available trainer options to argparse
28 | parser = pl.Trainer.add_argparse_args(parser)
29 |
30 | # parse them args
31 | args = parser.parse_args()
32 |
33 | # setup the dataloaders
34 | test_dataset = SignalTrainLA2ADataset(
35 | args.root_dir,
36 | subset=args.eval_subset,
37 | preload=args.preload,
38 | length=args.eval_length,
39 | )
40 |
41 | test_dataloader = torch.utils.data.DataLoader(
42 | test_dataset,
43 | shuffle=False,
44 | batch_size=args.batch_size,
45 | num_workers=args.num_workers,
46 | )
47 |
48 | results = {}
49 |
50 | # the losses we will test
51 | losses = ["l1", "logcosh", "esr+dc", "stft", "mrstft", "rrstft"]
52 |
53 | for loss_model in losses:
54 |
55 | root_logdir = os.path.join(args.logdir, loss_model, "lightning_logs", "version_0")
56 |
57 | checkpoint_path = glob.glob(os.path.join(root_logdir, "checkpoints", "*"))[0]
58 | print(checkpoint_path)
59 | hparams_file = os.path.join(root_logdir, "hparams.yaml")
60 |
61 | model = TCNModel.load_from_checkpoint(
62 | checkpoint_path=checkpoint_path,
63 | hparams_file=hparams_file,
64 | map_location="cuda:0",
65 | )
66 |
67 | model.hparams.save_dir = args.save_dir
68 | model.hparams.num_examples = args.num_examples
69 |
70 | # init trainer with whatever options
71 | trainer = pl.Trainer.from_argparse_args(args)
72 |
73 | # set the seed
74 | pl.seed_everything(42)
75 |
76 | # test (pass in the model)
77 | res = trainer.test(model, test_dataloaders=test_dataloader)
78 |
79 | # store in dict
80 | results[loss_model] = res
81 |
82 | # save final metrics to disk
83 | with open(os.path.join("examples", "compressor", "comp-metrics.json"), "w") as fp:
84 | json.dump(results, fp, indent=True)
85 |
--------------------------------------------------------------------------------
/examples/compressor/README.md:
--------------------------------------------------------------------------------
1 | # Analog dynamic range compressor modeling
2 |
3 | ## Dataset
4 |
5 | The [SignalTrain LA2A dataset](https://zenodo.org/record/3824876) (19GB) is available for download on Zenodo.
6 | This dataset contains monophonic audio examples, with input and output targets recorded from an [LA2A dynamic range compressor](https://en.wikipedia.org/wiki/LA-2A_Leveling_Amplifier). Different parameterizations of the two controls (threshold and compress/limit) are changed for each example.
7 | We provide a `DataLoader` in [`examples/compressor/data.py`](data.py).
8 |
9 |
10 | In our experiments we use V1.1, which makes corrections by time aligning some files.
11 | Download and extract this dataset before proceeding with the evaluation or retraining.
12 |
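As a quick sketch of how the dataset class is used (the path, patch length, and loader settings below are illustrative, not fixed by the repository), it can be wrapped in a standard PyTorch `DataLoader`:
```
import torch
from data import SignalTrainLA2ADataset  # examples/compressor/data.py

# build the training split from fixed-length patches
train_dataset = SignalTrainLA2ADataset(
    "/path/to/SignalTrain_LA2A_Dataset_1.1",
    subset="train",
    length=32768,
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=8
)

# each batch is a tuple of (input, target, params) tensors
input_audio, target_audio, params = next(iter(train_dataloader))
```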
13 | ## Pre-trained models
14 |
15 | We provide the pre-trained model checkpoints for each of the six models for download [here](https://drive.google.com/file/d/1g1pHDVSOOtvJjIovfskX9X2295jqYD-J/view?usp=sharing) (16MB). Download this `.tgz` and extract it.
16 | You can run the evaluation (on the test set) with the [`examples/test_comp.py`](test_comp.py) script from the root directory after the dataset and the checkpoints have been downloaded.
17 |
18 | We evaluate with patches of 262,144 samples (~6 seconds at 44.1 kHz) and a batch size of 128 (same as training), which requires around 12 GB of VRAM.
19 | We evaluate with half precision, as the models were trained in half precision as well.
20 | The `preload` flag will load all of the audio files into RAM before training starts, which runs faster than continually reading them from disk.
21 |
22 | Below is the call we used to generate the metrics in the [paper](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf).
23 | ```
24 | python examples/test_comp.py \
25 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \
26 | --logdir path/to/checkpoints/version_9 \
27 | --batch_size 128 \
28 | --sample_rate 44100 \
29 | --eval_subset "test" \
30 | --eval_length 262144 \
31 | --num_workers 8 \
32 | --gpus 1 \
33 | --shuffle False \
34 | --precision 16 \
35 | --preload True \
36 | ```
37 |
38 | ## Retraining
39 | If you wish to retrain the models you can do so using the [`examples/train_comp.py`](train_comp.py) script.
40 | Below is the call we use to train the models.
41 | In this case we train each of the six models for 20 epochs, which takes ~6.5h per model, for a total of ~40h when training on an NVIDIA Quadro RTX 6000.
42 | ```
43 | python examples/train_comp.py \
44 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \
45 | --max_epochs 20 \
46 | --batch_size 128 \
47 | --sample_rate 44100 \
48 | --train_length 32768 \
49 | --eval_length 262144 \
50 | --num_workers 8 \
51 | --kernel_size 15 \
52 | --channel_width 32 \
53 | --dilation_growth 2 \
54 | --lr 0.001 \
55 | --gpus 1 \
56 | --shuffle True \
57 | --precision 16 \
58 | --preload True \
59 | ```
--------------------------------------------------------------------------------
/examples/compressor/train_comp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import pytorch_lightning as pl
5 | from argparse import ArgumentParser
6 |
7 | from tcn import TCNModel
8 | from data import SignalTrainLA2ADataset
9 |
10 | parser = ArgumentParser()
11 |
12 | # add PROGRAM level args
13 | parser.add_argument("--root_dir", type=str, default="./data")
14 | parser.add_argument("--preload", type=bool, default=False)
15 | parser.add_argument("--sample_rate", type=int, default=44100)
16 | parser.add_argument("--shuffle", type=bool, default=False)
17 | parser.add_argument("--train_subset", type=str, default="train")
18 | parser.add_argument("--val_subset", type=str, default="val")
19 | parser.add_argument("--train_length", type=int, default=32768)
20 | parser.add_argument("--eval_length", type=int, default=32768)
21 | parser.add_argument("--batch_size", type=int, default=8)
22 | parser.add_argument("--num_workers", type=int, default=0)
23 |
24 | # add model specific args
25 | parser = TCNModel.add_model_specific_args(parser)
26 |
27 | # add all the available trainer options to argparse
28 | parser = pl.Trainer.add_argparse_args(parser)
29 |
30 | # parse them args
31 | args = parser.parse_args()
32 |
33 | # setup the dataloaders
34 | train_dataset = SignalTrainLA2ADataset(
35 | args.root_dir,
36 | subset=args.train_subset,
37 | half=True if args.precision == 16 else False,
38 | preload=args.preload,
39 | length=args.train_length,
40 | )
41 |
42 | train_dataloader = torch.utils.data.DataLoader(
43 | train_dataset,
44 | shuffle=args.shuffle,
45 | batch_size=args.batch_size,
46 | num_workers=args.num_workers,
47 | )
48 |
49 | val_dataset = SignalTrainLA2ADataset(
50 | args.root_dir,
51 | preload=args.preload,
52 | half=True if args.precision == 16 else False,
53 | subset=args.val_subset,
54 | length=args.eval_length,
55 | )
56 |
57 | val_dataloader = torch.utils.data.DataLoader(
58 | val_dataset, shuffle=False, batch_size=2, num_workers=args.num_workers
59 | )
60 |
61 |
62 | past_logs = sorted(glob.glob(os.path.join("lightning_logs", "*")))
63 | if len(past_logs) > 0:
64 | version = int(os.path.basename(past_logs[-1]).split("_")[-1]) + 1
65 | else:
66 | version = 0
67 |
68 | # the losses we will test
69 | if args.train_loss is None:
70 | losses = ["l1", "logcosh", "esr+dc", "stft", "mrstft", "rrstft"]
71 | else:
72 | losses = [args.train_loss]
73 |
74 | for loss_fn in losses:
75 |
76 | print(f"training with {loss_fn}")
77 | # init logger
78 | logdir = os.path.join("lightning_logs", f"version_{version}", loss_fn)
79 | print(logdir)
80 | args.default_root_dir = logdir
81 |
82 | # init the trainer and model
83 | trainer = pl.Trainer.from_argparse_args(args)
84 | print(trainer.default_root_dir)
85 |
86 | # set the seed
87 | pl.seed_everything(42)
88 |
89 | dict_args = vars(args)
90 | dict_args["nparams"] = 2
91 | dict_args["train_loss"] = loss_fn
92 | model = TCNModel(**dict_args)
93 |
94 | # train!
95 | trainer.fit(model, train_dataloader, val_dataloader)
96 |
--------------------------------------------------------------------------------
/examples/speech-denoise/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import torch
5 | import torchaudio
6 | import numpy as np
7 | import soundfile as sf
8 |
9 | torchaudio.set_audio_backend("sox_io")
10 |
11 |
12 | class LibriMixDataset(torch.utils.data.Dataset):
13 | """LibriMix dataset."""
14 |
15 | def __init__(self, root_dir, subset="train", length=16384, noisy=False):
16 | """
17 | Args:
18 | root_dir (str): Path to the preprocessed LibriMix files.
19 | subset (str, optional): Pull data either from "train", "val", or "test" subsets. (Default: "train")
20 |             length (int, optional): Number of samples in the returned examples. (Default: 16384)
21 |             noisy (bool, optional): Use mixtures with additive noise, otherwise anechoic mixes. (Default: False)
22 | """
23 | self.root_dir = root_dir
24 | self.subset = subset
25 | self.length = length
26 | self.noisy = noisy
27 |
28 | # set the mix directory if we want clean or noisy mixes as input
29 | self.mix_dir = "mix_both" if self.noisy else "mix_clean"
30 |
31 | # get all the files in the mix directory first
32 | self.files = glob.glob(
33 | os.path.join(self.root_dir, self.subset, self.mix_dir, "*.wav")
34 | )
35 | self.hours = 0 # total number of hours of data in the subset
36 |
37 | # loop over files to count total length
38 | for filename in self.files:
39 |             md = torchaudio.info(filename)
40 |             self.hours += (md.num_frames / md.sample_rate) / 3600
41 |
42 | # we then want to remove the path and extract just file ids
43 | self.files = [os.path.basename(filename) for filename in self.files]
44 | print(
45 | f"Located {len(self.files)} examples totaling {self.hours:0.1f} hr in the {self.subset} subset."
46 | )
47 |
48 | def __len__(self):
49 | return 32 # len(self.files)
50 |
51 | def __getitem__(self, idx):
52 |
53 | eid = self.files[idx]
54 |
55 | # use torchaudio to load them, which should be pretty fast
56 | s1, sr = torchaudio.load(os.path.join(self.root_dir, self.subset, "s1", eid))
57 | s2, sr = torchaudio.load(os.path.join(self.root_dir, self.subset, "s2", eid))
58 | noise, sr = torchaudio.load(
59 | os.path.join(self.root_dir, self.subset, "noise", eid)
60 | )
61 | mix, sr = torchaudio.load(
62 | os.path.join(self.root_dir, self.subset, self.mix_dir, eid)
63 | )
64 |
65 | # get the length of the current file in samples
66 |         num_frames = torchaudio.info(os.path.join(self.root_dir, self.subset, "s1", eid)).num_frames
67 | 
68 |         # pad if too short
69 |         if num_frames < self.length:
70 |             pad_length = self.length - num_frames
71 |             s1 = torch.nn.functional.pad(s1, (0, pad_length))
72 |             s2 = torch.nn.functional.pad(s2, (0, pad_length))
73 |             noise = torch.nn.functional.pad(noise, (0, pad_length))
74 |             mix = torch.nn.functional.pad(mix, (0, pad_length))
75 |             num_frames = self.length
76 | 
77 |         # choose a random patch of `length` samples for training
78 |         start_idx = np.random.randint(0, num_frames - self.length + 1)
79 | stop_idx = start_idx + self.length
80 |
81 | # extract these patches from each sample
82 | s1 = s1[0, start_idx:stop_idx].unsqueeze(dim=0)
83 | s2 = s2[0, start_idx:stop_idx].unsqueeze(dim=0)
84 | noise = noise[0, start_idx:stop_idx].unsqueeze(dim=0)
85 | mix = mix[0, start_idx:stop_idx].unsqueeze(dim=0)
86 |
87 | return s1, s2, noise, mix
88 |
--------------------------------------------------------------------------------
/auraloss/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class SumAndDifference(torch.nn.Module):
6 | """Sum and difference signal extraction module."""
7 |
8 | def __init__(self):
9 | """Initialize sum and difference extraction module."""
10 | super(SumAndDifference, self).__init__()
11 |
12 | def forward(self, x):
13 | """Calculate forward propagation.
14 |
15 | Args:
16 | x (Tensor): Predicted signal (B, #channels, #samples).
17 | Returns:
18 | Tensor: Sum signal.
19 | Tensor: Difference signal.
20 | """
21 | if not (x.size(1) == 2): # inputs must be stereo
22 | raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
23 |
24 | sum_sig = self.sum(x).unsqueeze(1)
25 | diff_sig = self.diff(x).unsqueeze(1)
26 |
27 | return sum_sig, diff_sig
28 |
29 | @staticmethod
30 | def sum(x):
31 | return x[:, 0, :] + x[:, 1, :]
32 |
33 | @staticmethod
34 | def diff(x):
35 | return x[:, 0, :] - x[:, 1, :]
36 |
37 |
38 | class FIRFilter(torch.nn.Module):
39 | """FIR pre-emphasis filtering module.
40 |
41 | Args:
42 | filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
43 | coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
44 | ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
45 |         plot (bool): Plot the magnitude response of the filter. Default: False
46 |
47 |     Based upon the perceptual loss pre-emphasis filters proposed by
48 | [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
49 |
50 | A-weighting filter - "aw"
51 | First-order highpass - "hp"
52 | Folded differentiator - "fd"
53 |
54 |     Note that the default coefficient value of 0.85 is optimized for
55 |     a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
56 | """
57 |
58 | def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
59 | """Initilize FIR pre-emphasis filtering module."""
60 | super(FIRFilter, self).__init__()
61 | self.filter_type = filter_type
62 | self.coef = coef
63 | self.fs = fs
64 | self.ntaps = ntaps
65 | self.plot = plot
66 |
67 | import scipy.signal
68 |
69 | if ntaps % 2 == 0:
70 | raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
71 |
72 | if filter_type == "hp":
73 | self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
74 | self.fir.weight.requires_grad = False
75 | self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
76 | elif filter_type == "fd":
77 | self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
78 | self.fir.weight.requires_grad = False
79 | self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
80 | elif filter_type == "aw":
81 | # Definition of analog A-weighting filter according to IEC/CD 1672.
82 | f1 = 20.598997
83 | f2 = 107.65265
84 | f3 = 737.86223
85 | f4 = 12194.217
86 | A1000 = 1.9997
87 |
88 | NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
89 | DENs = np.polymul(
90 | [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
91 | [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
92 | )
93 | DENs = np.polymul(
94 | np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
95 | )
96 |
97 | # convert analog filter to digital filter
98 | b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
99 |
100 | # compute the digital filter frequency response
101 | w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
102 |
103 | # then we fit to 101 tap FIR filter with least squares
104 | taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
105 |
106 | # now implement this digital FIR filter as a Conv1d layer
107 | self.fir = torch.nn.Conv1d(
108 | 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
109 | )
110 | self.fir.weight.requires_grad = False
111 | self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
112 |
113 | if plot:
114 | from .plotting import compare_filters
115 | compare_filters(b, a, taps, fs=fs)
116 |
117 | def forward(self, input, target):
118 | """Calculate forward propagation.
119 | Args:
120 | input (Tensor): Predicted signal (B, #channels, #samples).
121 | target (Tensor): Groundtruth signal (B, #channels, #samples).
122 | Returns:
123 | Tensor: Filtered signal.
124 | """
125 | input = torch.nn.functional.conv1d(
126 | input, self.fir.weight.data, padding=self.ntaps // 2
127 | )
128 | target = torch.nn.functional.conv1d(
129 | target, self.fir.weight.data, padding=self.ntaps // 2
130 | )
131 | return input, target
132 |
--------------------------------------------------------------------------------
/tests/test_auraloss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | import auraloss
5 |
6 |
7 | def test_mrstft():
8 | target = torch.rand(8, 2, 44100)
9 | pred = torch.rand(8, 2, 44100)
10 |
11 | loss = auraloss.freq.MultiResolutionSTFTLoss()
12 | res = loss(pred, target)
13 | assert res is not None
14 |
15 |
16 | def test_stft():
17 | target = torch.rand(8, 2, 44100)
18 | pred = torch.rand(8, 2, 44100)
19 |
20 | loss = auraloss.freq.STFTLoss()
21 | res = loss(pred, target)
22 | assert res is not None
23 |
24 |
25 | def test_stft_weights_a():
26 | target = torch.rand(8, 2, 44100)
27 | pred = torch.rand(8, 2, 44100)
28 |     # test different weights
29 | loss = auraloss.freq.STFTLoss(
30 | w_log_mag=1.0,
31 | w_lin_mag=0.0,
32 | w_sc=1.0,
33 | reduction="mean",
34 | )
35 | res = loss(pred, target)
36 | assert res is not None
37 |
38 |
39 | def test_stft_reduction():
40 | target = torch.rand(8, 2, 44100)
41 | pred = torch.rand(8, 2, 44100)
42 | # test the reduction
43 | loss = auraloss.freq.STFTLoss(
44 | w_log_mag=1.0,
45 | w_lin_mag=1.0,
46 | w_sc=0.0,
47 | reduction="none",
48 | )
49 | res = loss(pred, target)
50 | print(res.shape)
51 | assert len(res.shape) > 1
52 |
53 |
54 | def test_sum_and_difference():
55 | target = torch.rand(8, 2, 44100)
56 | pred = torch.rand(8, 2, 44100)
57 | loss = auraloss.freq.SumAndDifferenceSTFTLoss(
58 | fft_sizes=[512, 2048, 8192],
59 | hop_sizes=[128, 512, 2048],
60 | win_lengths=[512, 2048, 8192],
61 | )
62 | res = loss(pred, target)
63 | assert res is not None
64 |
65 |
66 | def test_perceptual_sum_and_difference():
67 | target = torch.rand(8, 2, 44100)
68 | pred = torch.rand(8, 2, 44100)
69 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
70 | fft_sizes=[512, 2048, 8192],
71 | hop_sizes=[128, 512, 2048],
72 | win_lengths=[512, 2048, 8192],
73 | perceptual_weighting=True,
74 | sample_rate=44100,
75 | )
76 |
77 | res = loss_fn(pred, target)
78 | assert res is not None
79 |
80 |
81 | def test_perceptual_mel_sum_and_difference():
82 | target = torch.rand(8, 2, 44100)
83 | pred = torch.rand(8, 2, 44100)
84 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
85 | fft_sizes=[1024, 2048, 8192],
86 | hop_sizes=[256, 512, 2048],
87 | win_lengths=[1024, 2048, 8192],
88 | perceptual_weighting=True,
89 | sample_rate=44100,
90 | scale="mel",
91 | n_bins=128,
92 | )
93 |
94 | res = loss_fn(pred, target)
95 | assert res is not None
96 |
97 |
98 | def test_melstft():
99 | target = torch.rand(8, 2, 44100)
100 | pred = torch.rand(8, 2, 44100)
101 | # test MelSTFT
102 | loss = auraloss.freq.MelSTFTLoss(44100)
103 | res = loss(pred, target)
104 | assert res is not None
105 |
106 |
107 | def test_melstft_reduction():
108 | target = torch.rand(8, 2, 44100)
109 | pred = torch.rand(8, 2, 44100)
110 | # test MelSTFT with no reduction
111 | loss = auraloss.freq.MelSTFTLoss(44100, reduction="none")
112 | res = loss(pred, target)
113 | assert len(res) > 1
114 |
115 |
116 | def test_multires_mel():
117 | target = torch.rand(8, 2, 44100)
118 | pred = torch.rand(8, 2, 44100)
119 | sample_rate = 44100
120 | loss = auraloss.freq.MultiResolutionSTFTLoss(
121 | scale="mel",
122 | n_bins=64,
123 | sample_rate=sample_rate,
124 | )
125 | res = loss(pred, target)
126 | assert res is not None
127 |
128 |
129 | def test_perceptual_multires_mel():
130 | target = torch.rand(8, 2, 44100)
131 | pred = torch.rand(8, 2, 44100)
132 | sample_rate = 44100
133 | loss = auraloss.freq.MultiResolutionSTFTLoss(
134 | fft_sizes=[1024, 2048, 8192],
135 | hop_sizes=[256, 512, 2048],
136 | win_lengths=[1024, 2048, 8192],
137 | scale="mel",
138 | n_bins=128,
139 | sample_rate=sample_rate,
140 | perceptual_weighting=True,
141 | )
142 | res = loss(pred, target)
143 | assert res is not None
144 |
145 |
146 | def test_stft_l2():
147 | N = 32
148 | n = torch.arange(N)
149 |
150 | f = N / 4
151 | target = torch.cos(2 * math.pi * f * n / N)
152 | target = target[None, None, :]
153 | pred = torch.zeros_like(target)
154 |
155 | loss = auraloss.freq.STFTLoss(
156 | fft_size=N,
157 | hop_size=N + 1, # eliminate padding artefacts by enforcing only one hop
158 | win_length=N,
159 | window="ones", # eliminate windowing artefacts
160 | w_sc=0.0,
161 | w_log_mag=0.0,
162 | w_lin_mag=1.0,
163 | w_phs=0.0,
164 | mag_distance="L2",
165 | )
166 | res = loss(pred, target)
167 |
168 |     # the tone's energy sits in a single DFT bin with magnitude N/2, so the MSE over the N//2 + 1 one-sided bins is (N/2)^2 / (N/2 + 1)
169 | expected_loss = ((N // 2) ** 2) / (N // 2 + 1)
170 |
171 | torch.testing.assert_close(res, torch.tensor(expected_loss), rtol=1e-3, atol=1e-3)
172 |
173 |
174 | def test_multires_l2():
175 | N = 32
176 | n = torch.arange(N)
177 |
178 | f = N / 4
179 | target = torch.cos(2 * math.pi * f * n / N)
180 | target = target[None, None, :]
181 | pred = torch.zeros_like(target)
182 |
183 | loss = auraloss.freq.MultiResolutionSTFTLoss(
184 | fft_sizes=[N],
185 | hop_sizes=[N + 1], # eliminate padding artefacts by enforcing only one hop
186 | win_lengths=[N],
187 | window="ones", # eliminate windowing artefacts
188 | w_sc=0.0,
189 | w_log_mag=0.0,
190 | w_lin_mag=1.0,
191 | w_phs=0.0,
192 | mag_distance="L2",
193 | )
194 | res = loss(pred, target)
195 |
196 | expected_loss = ((N // 2) ** 2) / (N // 2 + 1)
197 |
198 | torch.testing.assert_close(res, torch.tensor(expected_loss), rtol=1e-3, atol=1e-3)
199 |
--------------------------------------------------------------------------------
/examples/compressor/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import torch
5 | import torchaudio
6 | import numpy as np
7 | import soundfile as sf
8 |
9 | torchaudio.set_audio_backend("sox_io")
10 |
11 |
12 | class SignalTrainLA2ADataset(torch.utils.data.Dataset):
13 | """SignalTrain LA2A dataset. Source: [10.5281/zenodo.3824876](https://zenodo.org/record/3824876)."""
14 |
15 | def __init__(
16 | self,
17 | root_dir,
18 | subset="train",
19 | length=16384,
20 | preload=False,
21 | half=True,
22 | use_soundfile=False,
23 | ):
24 | """
25 | Args:
26 | root_dir (str): Path to the root directory of the SignalTrain dataset.
27 | subset (str, optional): Pull data either from "train", "val", or "test" subsets. (Default: "train")
28 |             length (int, optional): Number of samples in the returned examples. (Default: 16384)
29 | preload (bool, optional): Read in all data into RAM during init. (Default: False)
30 | half (bool, optional): Store the float32 audio as float16. (Default: True)
31 | use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
32 | """
33 | self.root_dir = root_dir
34 | self.subset = subset
35 | self.length = length
36 | self.preload = preload
37 | self.half = half
38 | self.use_soundfile = use_soundfile
39 |
40 |         # get all the target files in the directory first
41 | self.target_files = glob.glob(
42 | os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav")
43 | )
44 | self.input_files = glob.glob(
45 | os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav")
46 | )
47 |
48 | self.examples = []
49 | self.hours = 0 # total number of hours of data in the subset
50 |
51 |         # ensure that the sets are ordered correctly
52 | self.target_files.sort()
53 | self.input_files.sort()
54 |
55 | # get the parameters
56 | self.params = [
57 | (
58 | float(f.split("__")[1].replace(".wav", "")),
59 | float(f.split("__")[2].replace(".wav", "")),
60 | )
61 | for f in self.target_files
62 | ]
63 |
64 | # loop over files to count total length
65 | for idx, (tfile, ifile, params) in enumerate(
66 | zip(self.target_files, self.input_files, self.params)
67 | ):
68 |
69 | ifile_id = int(os.path.basename(ifile).split("_")[1])
70 | tfile_id = int(os.path.basename(tfile).split("_")[1])
71 | if ifile_id != tfile_id:
72 | raise RuntimeError(
73 | f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset."
74 | )
75 |
76 | md = torchaudio.info(tfile)
77 | self.hours += (md.num_frames / md.sample_rate) / 3600
78 | num_frames = md.num_frames
79 |
80 | if self.preload:
81 | sys.stdout.write(
82 | f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r"
83 | )
84 | sys.stdout.flush()
85 | input, sr = self.load(ifile)
86 | target, sr = self.load(tfile)
87 |
88 | num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
89 | if input.shape[-1] != target.shape[-1]:
90 | print(
91 | os.path.basename(ifile),
92 | input.shape[-1],
93 | os.path.basename(tfile),
94 | target.shape[-1],
95 | )
96 | raise RuntimeError("Found potentially corrupt file!")
97 | if self.half:
98 | input = input.half()
99 | target = target.half()
100 | else:
101 | input = None
102 | target = None
103 |
104 | # create one entry for each patch
105 | for n in range((num_frames // self.length) - 1):
106 | offset = int(n * self.length)
107 | end = offset + self.length
108 | self.examples.append(
109 | {
110 | "idx": idx,
111 | "target_file": tfile,
112 | "input_file": ifile,
113 | "input_audio": input[:, offset:end]
114 | if input is not None
115 | else None,
116 | "target_audio": target[:, offset:end]
117 | if input is not None
118 | else None,
119 | "params": params,
120 | "offset": offset,
121 | "frames": num_frames,
122 | }
123 | )
124 |
125 | # we then want to get the input files
126 | print(
127 | f"Located {len(self.examples)} examples totaling {self.hours:0.1f} hr in the {self.subset} subset."
128 | )
129 |
130 | def __len__(self):
131 | return len(self.examples)
132 |
133 | def __getitem__(self, idx):
134 | if self.preload:
135 | audio_idx = self.examples[idx]["idx"]
136 | offset = self.examples[idx]["offset"]
137 | input = self.examples[idx]["input_audio"]
138 | target = self.examples[idx]["target_audio"]
139 | else:
140 | offset = self.examples[idx]["offset"]
141 | input, sr = torchaudio.load(
142 | self.examples[idx]["input_file"],
143 | num_frames=self.length,
144 | frame_offset=offset,
145 | normalize=False,
146 | )
147 | target, sr = torchaudio.load(
148 | self.examples[idx]["target_file"],
149 | num_frames=self.length,
150 | frame_offset=offset,
151 | normalize=False,
152 | )
153 | if self.half:
154 | input = input.half()
155 | target = target.half()
156 |
157 | # at random with p=0.5 flip the phase
158 | if np.random.rand() > 0.5:
159 | input *= -1
160 | target *= -1
161 |
162 | # then get the tuple of parameters
163 | params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
164 | params[:, 1] /= 100
165 |
166 | return input, target, params
167 |
168 | def load(self, filename):
169 | if self.use_soundfile:
170 | x, sr = sf.read(filename, always_2d=True)
171 | x = torch.tensor(x.T)
172 | else:
173 | x, sr = torchaudio.load(filename, normalize=False)
174 | return x, sr
175 |
--------------------------------------------------------------------------------
/tests/simple_train_gpu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import auraloss
3 | import torchaudio
4 | from tqdm import tqdm
5 |
6 |
7 | def center_crop(x, length: int):
8 | start = (x.shape[-1] - length) // 2
9 | stop = start + length
10 | return x[..., start:stop]
11 |
12 |
13 | def causal_crop(x, length: int):
14 | stop = x.shape[-1] - 1
15 | start = stop - length
16 | return x[..., start:stop]
17 |
18 |
19 | class TCNBlock(torch.nn.Module):
20 | def __init__(
21 | self,
22 | in_ch,
23 | out_ch,
24 | kernel_size=3,
25 | padding="same",
26 | dilation=1,
27 | grouped=False,
28 | causal=False,
29 | **kwargs,
30 | ):
31 | super(TCNBlock, self).__init__()
32 |
33 | self.in_ch = in_ch
34 | self.out_ch = out_ch
35 | self.kernel_size = kernel_size
36 | self.padding = padding
37 | self.dilation = dilation
38 | self.grouped = grouped
39 | self.causal = causal
40 |
41 | groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
42 |
43 | if padding == "same":
44 | pad_value = (kernel_size - 1) + ((kernel_size - 1) * (dilation - 1))
45 | elif padding in ["none", "valid"]:
46 | pad_value = 0
47 |
48 | self.conv1 = torch.nn.Conv1d(
49 | in_ch,
50 | out_ch,
51 | kernel_size=kernel_size,
52 | padding=0, # testing a change in padding was pad_value//2
53 | dilation=dilation,
54 | groups=groups,
55 | bias=False,
56 | )
57 | if grouped:
58 | self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
59 |
60 | else:
61 | self.bn = torch.nn.BatchNorm1d(out_ch)
62 |
63 | self.relu = torch.nn.PReLU(out_ch)
64 | self.res = torch.nn.Conv1d(
65 | in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False
66 | )
67 |
68 | def forward(self, x):
69 | x_in = x
70 |
71 | x = self.conv1(x)
72 | # x = self.bn(x)
73 | x = self.relu(x)
74 |
75 | x_res = self.res(x_in)
76 | if self.causal:
77 | x = x + causal_crop(x_res, x.shape[-1])
78 | else:
79 | x = x + center_crop(x_res, x.shape[-1])
80 |
81 | return x
82 |
83 |
84 | class TCNModel(torch.nn.Module):
85 | """Temporal convolutional network.
86 | Args:
87 | nparams (int): Number of conditioning parameters.
88 | ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
89 | noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
90 | nblocks (int): Number of total TCN blocks. Default: 10
91 | kernel_size (int): Width of the convolutional kernels. Default: 3
92 |         dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
93 |         channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 2
94 | channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
95 | stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
96 | grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
97 | causal (bool): Causal TCN configuration does not consider future input values. Default: False
98 | """
99 |
100 | def __init__(
101 | self,
102 | ninputs=1,
103 | noutputs=1,
104 | nblocks=10,
105 | kernel_size=3,
106 | dilation_growth=1,
107 | channel_growth=1,
108 | channel_width=32,
109 | stack_size=10,
110 | grouped=False,
111 | causal=False,
112 | ):
113 | super().__init__()
114 |
115 | self.blocks = torch.nn.ModuleList()
116 | for n in range(nblocks):
117 | in_ch = out_ch if n > 0 else ninputs
118 |
119 | if channel_growth > 1:
120 |                 out_ch = in_ch * channel_growth
121 | else:
122 | out_ch = channel_width
123 |
124 | dilation = dilation_growth ** (n % stack_size)
125 | self.blocks.append(
126 | TCNBlock(
127 | in_ch,
128 | out_ch,
129 | kernel_size=kernel_size,
130 | dilation=dilation,
131 | padding="same" if causal else "valid",
132 | causal=causal,
133 | grouped=grouped,
134 | )
135 | )
136 |
137 | self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
138 |
139 | def forward(self, x):
140 | # iterate over blocks passing conditioning
141 | for idx, block in enumerate(self.blocks):
142 | x = block(x)
143 | return torch.tanh(self.output(x))
144 |
145 |     def compute_receptive_field(self):
146 |         """Compute the receptive field in samples."""
147 |         # read the kernel size and dilation stored on each block (this plain nn.Module has no hparams)
148 |         rf = 1
149 |         for block in self.blocks:
150 |             rf = rf + ((block.kernel_size - 1) * block.dilation)
151 |         return rf
152 |
153 |
154 | if __name__ == "__main__":
155 |
156 | y, sr = torchaudio.load("../sounds/assets/drum_kit_clean.wav")
157 | y /= y.abs().max()
158 | # y = y.repeat(2, 1)
159 | print(y.shape)
160 |
161 | # created distorted copy
162 | # x = torch.tanh(y * 4.0)
163 | x = y + (0.01 * torch.randn(y.shape))
164 |
165 | # move data to gpu
166 | x = x.cuda()
167 | y = y.cuda()
168 |
169 | x = x.view(1, 2, -1)
170 | y = y.view(1, 2, -1)
171 |
172 | # create simple network
173 | model = TCNModel(ninputs=2, noutputs=2, kernel_size=13, dilation_growth=2)
174 | model.cuda()
175 |
176 | # create loss function
177 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
178 | fft_sizes=[1024, 2048, 8192],
179 | hop_sizes=[256, 512, 2048],
180 | win_lengths=[1024, 2048, 8192],
181 | perceptual_weighting=True,
182 | sample_rate=44100,
183 | scale="mel",
184 | n_bins=128,
185 | )
186 | # loss_fn.cuda()
187 |
188 | # loss_fn = auraloss.freq.MultiResolutionSTFTLoss( fft_sizes=[1024, 2048, 8192],
189 | # hop_sizes=[256, 512, 2048],
190 | # win_lengths=[1024, 2048, 8192],)
191 |
192 | # loss_fn = torch.nn.MSELoss()
193 |
194 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
195 |
196 | # run optimization
197 | pbar = tqdm(range(1000))
198 | for iter_idx in pbar:
199 |
200 | optimizer.zero_grad()
201 |
202 | y_hat = model(x)
203 |
204 | y_crop = causal_crop(y, y_hat.shape[-1])
205 |
206 | loss = loss_fn(y_hat, y_crop)
207 | loss.backward()
208 | optimizer.step()
209 |
210 | pbar.set_description(f"loss: {loss.item():0.4f}")
211 |
212 | torchaudio.save(
213 | "tests/simple_train_gpu_output.wav", y_hat.detach().cpu().view(2, -1), sr
214 | )
215 | torchaudio.save(
216 | "tests/simple_train_gpu_input.wav", x.detach().cpu().view(2, -1), sr
217 | )
218 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
| Loss function | Interface | Reference |
|---|---|---|
| **Time domain** | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| **Frequency domain** | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| **Perceptual transforms** | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |
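
As a minimal usage sketch (shapes and values below are illustrative), each loss is called on a pair of `(batch, channels, samples)` tensors, mirroring the calls in `tests/test_auraloss.py`:
```
import torch
import auraloss

# stand-ins for a model prediction and its target
pred = torch.rand(8, 2, 44100)
target = torch.rand(8, 2, 44100)

# multi-resolution STFT loss with its default resolutions
mrstft = auraloss.freq.MultiResolutionSTFTLoss()
loss = mrstft(pred, target)

# mel-scaled STFT loss requires the sample rate
melstft = auraloss.freq.MelSTFTLoss(44100)
loss = melstft(pred, target)
```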