├── tests ├── __init__.py ├── manual_test_gpu.py ├── test_auraloss.py └── simple_train_gpu.py ├── .pre-commit-config.yaml ├── auraloss ├── __init__.py ├── plotting.py ├── utils.py ├── perceptual.py ├── time.py └── freq.py ├── examples ├── compressor │ ├── test.sh │ ├── train.sh │ ├── comp-metrics.json │ ├── _test_comp.py │ ├── README.md │ ├── train_comp.py │ ├── data.py │ └── tcn.py └── speech-denoise │ ├── train_denoise.py │ └── data.py ├── pyproject.toml ├── .github └── workflows │ ├── tests.yml │ └── publish-to-test-pypi.yml ├── setup.py ├── docs └── auraloss-logo.svg ├── .gitignore ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 21.5b1 4 | hooks: 5 | - id: black 6 | language_version: python3 -------------------------------------------------------------------------------- /auraloss/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level module for auraloss""" 2 | 3 | # Import auraloss sub-modules 4 | from . import freq 5 | from . import time 6 | from . import perceptual 7 | -------------------------------------------------------------------------------- /tests/manual_test_gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import auraloss 3 | 4 | y_hat = torch.randn(2, 1, 131072) 5 | y = torch.randn(2, 1, 131072) 6 | 7 | loss_fn = auraloss.freq.MelSTFTLoss(44100) 8 | loss_fn2 = auraloss.freq.MultiResolutionSTFTLoss() 9 | 10 | # loss_fn.cuda() 11 | 12 | y_hat = y_hat.cuda() 13 | y = y.cuda() 14 | 15 | loss = loss_fn2(y_hat, y) 16 | loss = loss_fn(y_hat, y) 17 | -------------------------------------------------------------------------------- /examples/compressor/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python examples/compressor/test_comp.py \ 2 | --root_dir /import/c4dm-datasets/SignalTrain_LA2A_Dataset_1.1 \ 3 | --logdir lightning_logs/version_9 \ 4 | --batch_size 128 \ 5 | --sample_rate 44100 \ 6 | --eval_subset "test" \ 7 | --eval_length 262144 \ 8 | --num_workers 8 \ 9 | --gpus 1 \ 10 | --precision 16 \ 11 | --preload True \ 12 | --save_dir "./examples/compressor/audio" \ 13 | --num_examples 100 \ 14 | #--auto_lr_find 15 | -------------------------------------------------------------------------------- /examples/compressor/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python examples/compressor/train_comp.py \ 2 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \ 3 | --max_epochs 20 \ 4 | --batch_size 128 \ 5 | --sample_rate 44100 \ 6 | --train_length 32768 \ 7 | --eval_length 262144 \ 8 | --num_workers 8 \ 9 | --kernel_size 15 \ 10 | --channel_width 32 \ 11 | --dilation_growth 2 \ 12 | --lr 0.001 \ 13 | --gpus 1 \ 14 | --shuffle True \ 15 | --precision 16 \ 16 | --preload True \ 17 | #--auto_lr_find 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "auraloss" 3 | version = "0.4.0" 4 | description = "Collection of audio-focused loss functions 
in PyTorch." 5 | authors = [ 6 | { name = "Christian Steinmetz" }, 7 | { email = "c.j.steinmetz@qmul.ac.uk" }, 8 | ] 9 | dependencies = ["torch", "numpy"] 10 | 11 | [build-system] 12 | # Minimum requirements for the build system to execute. 13 | requires = ["setuptools", "wheel", "attrs"] 14 | build-backend = "setuptools.build_meta" 15 | 16 | [project.optional-dependencies] 17 | all = ["matplotlib", "librosa", "scipy", "scipy"] 18 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.7", "3.8", "3.9", "3.10"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | sudo apt-get install libsndfile1-dev 23 | pip install .[all] 24 | pip install pytest 25 | - name: Test with pytest 26 | run: | 27 | pytest -------------------------------------------------------------------------------- /auraloss/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def compare_filters(iir_b, iir_a, fir_b, fs=1): 7 | 8 | # compute response for IIR filter 9 | w_iir, h_iir = scipy.signal.freqz(iir_b, iir_a, fs=fs, worN=2048) 10 | 11 | # compute response for FIR filter 12 | w_fir, h_fir = scipy.signal.freqz(fir_b, fs=fs) 13 | 14 | h_iir_db = 20 * np.log10(np.abs(h_iir) + 1e-8) 15 | h_fir_db = 20 * np.log10(np.abs(h_fir) + 1e-8) 16 | 17 | plt.plot(w_iir, h_iir_db, label="IIR filter") 18 | plt.plot(w_fir, h_fir_db, label="FIR approx. filter") 19 | plt.xscale("log") 20 | plt.ylim([-50, 10]) 21 | plt.xlim([10, 22.05e3]) 22 | plt.xlabel("Freq. (Hz)") 23 | plt.ylabel("Mag. (dB)") 24 | plt.legend() 25 | plt.grid() 26 | plt.show() 27 | -------------------------------------------------------------------------------- /auraloss/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.signal 3 | 4 | 5 | def apply_reduction(losses, reduction="none"): 6 | """Apply reduction to collection of losses.""" 7 | if reduction == "mean": 8 | losses = losses.mean() 9 | elif reduction == "sum": 10 | losses = losses.sum() 11 | return losses 12 | 13 | def get_window(win_type: str, win_length: int): 14 | """Return a window function. 15 | 16 | Args: 17 | win_type (str): Window type. Can either be one of the window function provided in PyTorch 18 | ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] 19 | or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). 
20 | win_length (int): Window length 21 | 22 | Returns: 23 | win: The window as a 1D torch tensor 24 | """ 25 | 26 | try: 27 | win = getattr(torch, win_type)(win_length) 28 | except: 29 | win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length)) 30 | 31 | return win 32 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python distributions to PyPI and TestPyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Publish Python distributions to PyPI and TestPyPI 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@master 11 | - name: Set up Python 3.7 12 | uses: actions/setup-python@v1 13 | with: 14 | python-version: 3.7 15 | - name: Install pypa/build 16 | run: >- 17 | python -m 18 | pip install 19 | build 20 | --user 21 | - name: Build a binary wheel and a source tarball 22 | run: >- 23 | python -m 24 | build 25 | --sdist 26 | --wheel 27 | --outdir dist/ 28 | . 29 | - name: Publish distribution to Test PyPI 30 | if: startsWith(github.ref, 'refs/tags') 31 | uses: pypa/gh-action-pypi-publish@master 32 | with: 33 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 34 | repository_url: https://test.pypi.org/legacy/ 35 | - name: Publish distribution to PyPI 36 | if: startsWith(github.ref, 'refs/tags') 37 | uses: pypa/gh-action-pypi-publish@master 38 | with: 39 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Inspired from https://github.com/kennethreitz/setup.py 3 | 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | 7 | NAME = "auraloss" 8 | DESCRIPTION = "Audio-focused loss functions in PyTorch" 9 | URL = "https://github.com/csteinmetz1/auraloss" 10 | EMAIL = "c.j.steinmetz@qmul.ac.uk" 11 | AUTHOR = "Christian Steinmetz" 12 | REQUIRES_PYTHON = ">=3.6.0" 13 | VERSION = "0.4.0" 14 | 15 | HERE = Path(__file__).parent 16 | 17 | try: 18 | with open(HERE / "README.md", encoding="utf-8") as f: 19 | long_description = "\n" + f.read() 20 | except FileNotFoundError: 21 | long_description = DESCRIPTION 22 | 23 | setup( 24 | name=NAME, 25 | version=VERSION, 26 | description=DESCRIPTION, 27 | long_description=long_description, 28 | long_description_content_type="text/markdown", 29 | author=AUTHOR, 30 | author_email=EMAIL, 31 | python_requires=REQUIRES_PYTHON, 32 | url=URL, 33 | packages=["auraloss"], 34 | install_requires=["torch", "numpy"], 35 | extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]}, 36 | include_package_data=True, 37 | license="Apache License 2.0", 38 | classifiers=[ 39 | "License :: OSI Approved :: Apache Software License", 40 | "Topic :: Multimedia :: Sound/Audio", 41 | "Topic :: Scientific/Engineering", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /docs/auraloss-logo.svg: -------------------------------------------------------------------------------- 1 | auraloss-logoal -------------------------------------------------------------------------------- /examples/speech-denoise/train_denoise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from argparse import ArgumentParser 4 | 5 
| from tcn import TCNModel 6 | from data import LibriMixDataset 7 | 8 | parser = ArgumentParser() 9 | 10 | # add PROGRAM level args 11 | parser.add_argument("--root_dir", type=str, default="./data") 12 | parser.add_argument("--sample_rate", type=int, default=8000) 13 | parser.add_argument("--train_subset", type=str, default="train") 14 | parser.add_argument("--val_subset", type=str, default="val") 15 | parser.add_argument("--train_length", type=int, default=16384) 16 | parser.add_argument("--eval_length", type=int, default=32768) 17 | parser.add_argument("--batch_size", type=int, default=8) 18 | parser.add_argument("--num_workers", type=int, default=0) 19 | 20 | # add model specific args 21 | parser = TCNModel.add_model_specific_args(parser) 22 | 23 | # add all the available trainer options to argparse 24 | parser = pl.Trainer.add_argparse_args(parser) 25 | 26 | # parse them args 27 | args = parser.parse_args() 28 | 29 | # init the trainer and model 30 | trainer = pl.Trainer.from_argparse_args(args) 31 | 32 | # setup the dataloaders 33 | train_dataset = LibriMixDataset( 34 | args.root_dir, subset=args.train_subset, length=args.train_length 35 | ) 36 | train_dataloader = torch.utils.data.DataLoader( 37 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers 38 | ) 39 | 40 | val_dataset = LibriMixDataset( 41 | args.root_dir, subset=args.val_subset, length=args.eval_length 42 | ) 43 | val_dataloader = torch.utils.data.DataLoader( 44 | val_dataset, batch_size=args.batch_size, num_workers=args.num_workers 45 | ) 46 | 47 | dict_args = vars(args) 48 | model = TCNModel(**dict_args) 49 | 50 | 51 | # find proper learning rate 52 | trainer.tune(model, train_dataloader) 53 | 54 | # train! 55 | trainer.fit(model, train_dataloader, val_dataloader) 56 | -------------------------------------------------------------------------------- /examples/compressor/comp-metrics.json: -------------------------------------------------------------------------------- 1 | { 2 | "l1": [ 3 | { 4 | "val_loss": 2.192999839782715, 5 | "val_loss/L1": 0.004868578165769577, 6 | "val_loss/ESR": 0.009247570298612118, 7 | "val_loss/DC": 0.00017245385970454663, 8 | "val_loss/LogCosh": 2.789249992929399e-05, 9 | "val_loss/STFT": 0.8255783319473267, 10 | "val_loss/MRSTFT": 0.7987862825393677, 11 | "val_loss/RRSTFT": 0.5543186664581299 12 | } 13 | ], 14 | "logcosh": [ 15 | { 16 | "val_loss": 2.149608612060547, 17 | "val_loss/L1": 0.005570689216256142, 18 | "val_loss/ESR": 0.010359607636928558, 19 | "val_loss/DC": 0.0003921029274351895, 20 | "val_loss/LogCosh": 3.224901593057439e-05, 21 | "val_loss/STFT": 0.8068493604660034, 22 | "val_loss/MRSTFT": 0.7804977297782898, 23 | "val_loss/RRSTFT": 0.5459066033363342 24 | } 25 | ], 26 | "esr+dc": [ 27 | { 28 | "val_loss": 2.2166659832000732, 29 | "val_loss/L1": 0.005301260389387608, 30 | "val_loss/ESR": 0.009672155603766441, 31 | "val_loss/DC": 4.820536560146138e-05, 32 | "val_loss/LogCosh": 3.0196371881174855e-05, 33 | "val_loss/STFT": 0.8323166370391846, 34 | "val_loss/MRSTFT": 0.8060272932052612, 35 | "val_loss/RRSTFT": 0.5632702112197876 36 | } 37 | ], 38 | "stft": [ 39 | { 40 | "val_loss": 1.2915085554122925, 41 | "val_loss/L1": 0.008976267650723457, 42 | "val_loss/ESR": 0.054853539913892746, 43 | "val_loss/DC": 1.080286983778933e-05, 44 | "val_loss/LogCosh": 0.00017625998589210212, 45 | "val_loss/STFT": 0.452709823846817, 46 | "val_loss/MRSTFT": 0.43302521109580994, 47 | "val_loss/RRSTFT": 0.34175658226013184 48 | } 49 | ], 50 | "mrstft": [ 51 | { 52 | "val_loss": 
1.2609974145889282, 53 | "val_loss/L1": 0.008976605720818043, 54 | "val_loss/ESR": 0.056137651205062866, 55 | "val_loss/DC": 5.441876783152111e-05, 56 | "val_loss/LogCosh": 0.0001797690347302705, 57 | "val_loss/STFT": 0.44097578525543213, 58 | "val_loss/MRSTFT": 0.42079290747642517, 59 | "val_loss/RRSTFT": 0.33388033509254456 60 | } 61 | ], 62 | "rrstft": [ 63 | { 64 | "val_loss": 1.6379966735839844, 65 | "val_loss/L1": 0.015477120876312256, 66 | "val_loss/ESR": 0.21662524342536926, 67 | "val_loss/DC": 0.00022139745124150068, 68 | "val_loss/LogCosh": 0.0007051324937492609, 69 | "val_loss/STFT": 0.5189737677574158, 70 | "val_loss/MRSTFT": 0.49819114804267883, 71 | "val_loss/RRSTFT": 0.38780292868614197 72 | } 73 | ] 74 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # more 132 | .vscode 133 | data/ 134 | .DS_Store 135 | lightning_logs/ 136 | 137 | test-env/ 138 | *.wav 139 | -------------------------------------------------------------------------------- /examples/compressor/_test_comp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import torchsummary 6 | import pytorch_lightning as pl 7 | from argparse import ArgumentParser 8 | 9 | from tcn import TCNModel 10 | from data import SignalTrainLA2ADataset 11 | 12 | parser = ArgumentParser() 13 | 14 | # add PROGRAM level args 15 | parser.add_argument("--root_dir", type=str, default="./data") 16 | parser.add_argument("--preload", type=bool, default=False) 17 | parser.add_argument("--sample_rate", type=int, default=44100) 18 | parser.add_argument("--logdir", type=str, default="./") 19 | parser.add_argument("--eval_subset", type=str, default="val") 20 | parser.add_argument("--eval_length", type=int, default=262144) 21 | parser.add_argument("--batch_size", type=int, default=8) 22 | parser.add_argument("--num_workers", type=int, default=0) 23 | 24 | # add model specific args 25 | parser = TCNModel.add_model_specific_args(parser) 26 | 27 | # add all the available trainer options to argparse 28 | parser = pl.Trainer.add_argparse_args(parser) 29 | 30 | # parse them args 31 | args = parser.parse_args() 32 | 33 | # setup the dataloaders 34 | test_dataset = SignalTrainLA2ADataset( 35 | args.root_dir, 36 | subset=args.eval_subset, 37 | preload=args.preload, 38 | length=args.eval_length, 39 | ) 40 | 41 | test_dataloader = torch.utils.data.DataLoader( 42 | test_dataset, 43 | shuffle=False, 44 | batch_size=args.batch_size, 45 | num_workers=args.num_workers, 46 | ) 47 | 48 | results = {} 49 | 50 | # the losses we will test 51 | losses = ["l1", "logcosh", "esr+dc", "stft", "mrstft", "rrstft"] 52 | 53 | for loss_model in losses: 54 | 55 | root_logdir = os.path.join(args.logdir, loss_model, "lightning_logs", "version_0") 56 | 57 | checkpoint_path = glob.glob(os.path.join(root_logdir, "checkpoints", "*"))[0] 58 | print(checkpoint_path) 59 | hparams_file = os.path.join(root_logdir, "hparams.yaml") 60 | 61 | model = TCNModel.load_from_checkpoint( 62 | checkpoint_path=checkpoint_path, 63 | hparams_file=hparams_file, 64 | map_location="cuda:0", 65 | ) 66 | 67 | model.hparams.save_dir = args.save_dir 68 | model.hparams.num_examples = args.num_examples 69 | 70 | # init trainer with whatever options 71 | trainer = pl.Trainer.from_argparse_args(args) 72 | 73 | # set the seed 74 | pl.seed_everything(42) 75 | 76 | # test (pass in the model) 77 | res = trainer.test(model, test_dataloaders=test_dataloader) 78 | 79 | # store in dict 80 | results[loss_model] = res 81 | 82 | # save final metrics to disk 83 | with open(os.path.join("examples", "compressor", "comp-metrics.json"), "w") as fp: 84 | json.dump(results, fp, indent=True) 85 | 
-------------------------------------------------------------------------------- /examples/compressor/README.md: -------------------------------------------------------------------------------- 1 | # Analog dynamic range compressor modeling 2 | 3 | ## Dataset 4 | 5 | The [SignalTrain LA2A dataset](https://zenodo.org/record/3824876) (19GB) is available for download on Zenodo. 6 | This dataset contains monophonic audio examples, with input and output targets recorded from an [LA2A dynamic range compressor](https://en.wikipedia.org/wiki/LA-2A_Leveling_Amplifier). Different parameterizations of the two controls (threshold and compress/limit) are changed for each example. 7 | We provide a `DataLoader` in [`examples/compressor/data.py`](data.py). 8 | 9 | 10 | In our experiments we use V1.1, which makes corrections by time-aligning some files. 11 | Download and extract this dataset before proceeding with the evaluation or retraining. 12 | 13 | ## Pre-trained models 14 | 15 | We provide the pre-trained model checkpoints for each of the six models for download [here](https://drive.google.com/file/d/1g1pHDVSOOtvJjIovfskX9X2295jqYD-J/view?usp=sharing) (16MB). Download this `.tgz` and extract it. 16 | You can run the evaluation (on the test set) with the [`examples/test_comp.py`](test_comp.py) script from the root directory after the dataset and the checkpoints have been downloaded. 17 | 18 | We evaluate with patches of 262,144 samples (~6 seconds at 44.1 kHz) and a batch size of 128 (same as training), which requires around 12 GB of VRAM. 19 | We evaluate with half precision, as the models were trained in half precision as well. 20 | The `preload` flag will load all of the audio files into RAM before the run starts, which is faster than continually reading them from disk. 21 | 22 | Below is the call we used to generate the metrics in the [paper](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf). 23 | ``` 24 | python examples/test_comp.py \ 25 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \ 26 | --logdir path/to/checkpoints/version_9 \ 27 | --batch_size 128 \ 28 | --sample_rate 44100 \ 29 | --eval_subset "test" \ 30 | --eval_length 262144 \ 31 | --num_workers 8 \ 32 | --gpus 1 \ 33 | --shuffle False \ 34 | --precision 16 \ 35 | --preload True \ 36 | ``` 37 | 38 | ## Retraining 39 | If you wish to retrain the models you can do so using the [`examples/train_comp.py`](train_comp.py) script. 40 | Below is the call we use to train the models. 41 | In this case we train each of the six models for 20 epochs, which takes ~6.5h per model, for a total of ~40h when training on an NVIDIA Quadro RTX 6000. 
42 | ``` 43 | python examples/train_comp.py \ 44 | --root_dir /path/to/SignalTrain_LA2A_Dataset_1.1 \ 45 | --max_epochs 20 \ 46 | --batch_size 128 \ 47 | --sample_rate 44100 \ 48 | --train_length 32768 \ 49 | --eval_length 262144 \ 50 | --num_workers 8 \ 51 | --kernel_size 15 \ 52 | --channel_width 32 \ 53 | --dilation_growth 2 \ 54 | --lr 0.001 \ 55 | --gpus 1 \ 56 | --shuffle True \ 57 | --precision 16 \ 58 | --preload True \ 59 | ``` -------------------------------------------------------------------------------- /examples/compressor/train_comp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import pytorch_lightning as pl 5 | from argparse import ArgumentParser 6 | 7 | from tcn import TCNModel 8 | from data import SignalTrainLA2ADataset 9 | 10 | parser = ArgumentParser() 11 | 12 | # add PROGRAM level args 13 | parser.add_argument("--root_dir", type=str, default="./data") 14 | parser.add_argument("--preload", type=bool, default=False) 15 | parser.add_argument("--sample_rate", type=int, default=44100) 16 | parser.add_argument("--shuffle", type=bool, default=False) 17 | parser.add_argument("--train_subset", type=str, default="train") 18 | parser.add_argument("--val_subset", type=str, default="val") 19 | parser.add_argument("--train_length", type=int, default=32768) 20 | parser.add_argument("--eval_length", type=int, default=32768) 21 | parser.add_argument("--batch_size", type=int, default=8) 22 | parser.add_argument("--num_workers", type=int, default=0) 23 | 24 | # add model specific args 25 | parser = TCNModel.add_model_specific_args(parser) 26 | 27 | # add all the available trainer options to argparse 28 | parser = pl.Trainer.add_argparse_args(parser) 29 | 30 | # parse them args 31 | args = parser.parse_args() 32 | 33 | # setup the dataloaders 34 | train_dataset = SignalTrainLA2ADataset( 35 | args.root_dir, 36 | subset=args.train_subset, 37 | half=True if args.precision == 16 else False, 38 | preload=args.preload, 39 | length=args.train_length, 40 | ) 41 | 42 | train_dataloader = torch.utils.data.DataLoader( 43 | train_dataset, 44 | shuffle=args.shuffle, 45 | batch_size=args.batch_size, 46 | num_workers=args.num_workers, 47 | ) 48 | 49 | val_dataset = SignalTrainLA2ADataset( 50 | args.root_dir, 51 | preload=args.preload, 52 | half=True if args.precision == 16 else False, 53 | subset=args.val_subset, 54 | length=args.eval_length, 55 | ) 56 | 57 | val_dataloader = torch.utils.data.DataLoader( 58 | val_dataset, shuffle=False, batch_size=2, num_workers=args.num_workers 59 | ) 60 | 61 | 62 | past_logs = sorted(glob.glob(os.path.join("lightning_logs", "*"))) 63 | if len(past_logs) > 0: 64 | version = int(os.path.basename(past_logs[-1]).split("_")[-1]) + 1 65 | else: 66 | version = 0 67 | 68 | # the losses we will test 69 | if args.train_loss is None: 70 | losses = ["l1", "logcosh", "esr+dc", "stft", "mrstft", "rrstft"] 71 | else: 72 | losses = [args.train_loss] 73 | 74 | for loss_fn in losses: 75 | 76 | print(f"training with {loss_fn}") 77 | # init logger 78 | logdir = os.path.join("lightning_logs", f"version_{version}", loss_fn) 79 | print(logdir) 80 | args.default_root_dir = logdir 81 | 82 | # init the trainer and model 83 | trainer = pl.Trainer.from_argparse_args(args) 84 | print(trainer.default_root_dir) 85 | 86 | # set the seed 87 | pl.seed_everything(42) 88 | 89 | dict_args = vars(args) 90 | dict_args["nparams"] = 2 91 | dict_args["train_loss"] = loss_fn 92 | model = TCNModel(**dict_args) 93 | 94 | # 
train! 95 | trainer.fit(model, train_dataloader, val_dataloader) 96 | -------------------------------------------------------------------------------- /examples/speech-denoise/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import torch 5 | import torchaudio 6 | import numpy as np 7 | import soundfile as sf 8 | 9 | torchaudio.set_audio_backend("sox_io") 10 | 11 | 12 | class LibriMixDataset(torch.utils.data.Dataset): 13 | """LibriMix dataset.""" 14 | 15 | def __init__(self, root_dir, subset="train", length=16384, noisy=False): 16 | """ 17 | Args: 18 | root_dir (str): Path to the preprocessed LibriMix files. 19 | subset (str, optional): Pull data either from "train", "val", or "test" subsets. (Default: "train") 20 | length (int, optional): Number of samples in the returned examples. (Default: 40) 21 | noise (bool, optional): Use mixtures with additive noise, otherwise anechoic mixes. (Default: False) 22 | """ 23 | self.root_dir = root_dir 24 | self.subset = subset 25 | self.length = length 26 | self.noisy = noisy 27 | 28 | # set the mix directory if we want clean or noisy mixes as input 29 | self.mix_dir = "mix_both" if self.noisy else "mix_clean" 30 | 31 | # get all the files in the mix directory first 32 | self.files = glob.glob( 33 | os.path.join(self.root_dir, self.subset, self.mix_dir, "*.wav") 34 | ) 35 | self.hours = 0 # total number of hours of data in the subset 36 | 37 | # loop over files to count total length 38 | for filename in self.files: 39 | si, ei = torchaudio.info(filename) 40 | self.hours += (si.length / si.rate) / 3600 41 | 42 | # we then want to remove the path and extract just file ids 43 | self.files = [os.path.basename(filename) for filename in self.files] 44 | print( 45 | f"Located {len(self.files)} examples totaling {self.hours:0.1f} hr in the {self.subset} subset." 
46 | ) 47 | 48 | def __len__(self): 49 | return 32 # len(self.files) 50 | 51 | def __getitem__(self, idx): 52 | 53 | eid = self.files[idx] 54 | 55 | # use torchaudio to load them, which should be pretty fast 56 | s1, sr = torchaudio.load(os.path.join(self.root_dir, self.subset, "s1", eid)) 57 | s2, sr = torchaudio.load(os.path.join(self.root_dir, self.subset, "s2", eid)) 58 | noise, sr = torchaudio.load( 59 | os.path.join(self.root_dir, self.subset, "noise", eid) 60 | ) 61 | mix, sr = torchaudio.load( 62 | os.path.join(self.root_dir, self.subset, self.mix_dir, eid) 63 | ) 64 | 65 | # get the length of the current file in samples 66 | si, ei = torchaudio.info(os.path.join(self.root_dir, self.subset, "s1", eid)) 67 | 68 | # pad if too short 69 | if si.length < self.length: 70 | pad_length = self.length - si.length 71 | s1 = torch.nn.functional.pad(s1, (0, pad_length)) 72 | s2 = torch.nn.functional.pad(s2, (0, pad_length)) 73 | noise = torch.nn.functional.pad(noise, (0, pad_length)) 74 | mix = torch.nn.functional.pad(mix, (0, pad_length)) 75 | si.length = self.length 76 | 77 | # choose a random patch of `length` samples for training 78 | start_idx = np.random.randint(0, si.length - self.length + 1) 79 | stop_idx = start_idx + self.length 80 | 81 | # extract these patches from each sample 82 | s1 = s1[0, start_idx:stop_idx].unsqueeze(dim=0) 83 | s2 = s2[0, start_idx:stop_idx].unsqueeze(dim=0) 84 | noise = noise[0, start_idx:stop_idx].unsqueeze(dim=0) 85 | mix = mix[0, start_idx:stop_idx].unsqueeze(dim=0) 86 | 87 | return s1, s2, noise, mix 88 | -------------------------------------------------------------------------------- /auraloss/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class SumAndDifference(torch.nn.Module): 6 | """Sum and difference signal extraction module.""" 7 | 8 | def __init__(self): 9 | """Initialize sum and difference extraction module.""" 10 | super(SumAndDifference, self).__init__() 11 | 12 | def forward(self, x): 13 | """Calculate forward propagation. 14 | 15 | Args: 16 | x (Tensor): Predicted signal (B, #channels, #samples). 17 | Returns: 18 | Tensor: Sum signal. 19 | Tensor: Difference signal. 20 | """ 21 | if not (x.size(1) == 2): # inputs must be stereo 22 | raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).") 23 | 24 | sum_sig = self.sum(x).unsqueeze(1) 25 | diff_sig = self.diff(x).unsqueeze(1) 26 | 27 | return sum_sig, diff_sig 28 | 29 | @staticmethod 30 | def sum(x): 31 | return x[:, 0, :] + x[:, 1, :] 32 | 33 | @staticmethod 34 | def diff(x): 35 | return x[:, 0, :] - x[:, 1, :] 36 | 37 | 38 | class FIRFilter(torch.nn.Module): 39 | """FIR pre-emphasis filtering module. 40 | 41 | Args: 42 | filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp" 43 | coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85 44 | ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101 45 | plot (bool): Plot the magnitude respond of the filter. Default: False 46 | 47 | Based upon the perceptual loss pre-empahsis filters proposed by 48 | [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). 49 | 50 | A-weighting filter - "aw" 51 | First-order highpass - "hp" 52 | Folded differentiator - "fd" 53 | 54 | Note that the default coefficeint value of 0.85 is optimized for 55 | a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates. 
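    Example (a minimal usage sketch): apply A-weighting pre-emphasis to mono
    signals of shape (batch, 1, samples) before computing a time-domain loss.
        >>> pre_emphasis = FIRFilter(filter_type="aw", fs=44100)
        >>> input, target = pre_emphasis(input, target)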
56 | """ 57 | 58 | def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): 59 | """Initilize FIR pre-emphasis filtering module.""" 60 | super(FIRFilter, self).__init__() 61 | self.filter_type = filter_type 62 | self.coef = coef 63 | self.fs = fs 64 | self.ntaps = ntaps 65 | self.plot = plot 66 | 67 | import scipy.signal 68 | 69 | if ntaps % 2 == 0: 70 | raise ValueError(f"ntaps must be odd (ntaps={ntaps}).") 71 | 72 | if filter_type == "hp": 73 | self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) 74 | self.fir.weight.requires_grad = False 75 | self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) 76 | elif filter_type == "fd": 77 | self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) 78 | self.fir.weight.requires_grad = False 79 | self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1) 80 | elif filter_type == "aw": 81 | # Definition of analog A-weighting filter according to IEC/CD 1672. 82 | f1 = 20.598997 83 | f2 = 107.65265 84 | f3 = 737.86223 85 | f4 = 12194.217 86 | A1000 = 1.9997 87 | 88 | NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0] 89 | DENs = np.polymul( 90 | [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2], 91 | [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2], 92 | ) 93 | DENs = np.polymul( 94 | np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2] 95 | ) 96 | 97 | # convert analog filter to digital filter 98 | b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs) 99 | 100 | # compute the digital filter frequency response 101 | w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) 102 | 103 | # then we fit to 101 tap FIR filter with least squares 104 | taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) 105 | 106 | # now implement this digital FIR filter as a Conv1d layer 107 | self.fir = torch.nn.Conv1d( 108 | 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 109 | ) 110 | self.fir.weight.requires_grad = False 111 | self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) 112 | 113 | if plot: 114 | from .plotting import compare_filters 115 | compare_filters(b, a, taps, fs=fs) 116 | 117 | def forward(self, input, target): 118 | """Calculate forward propagation. 119 | Args: 120 | input (Tensor): Predicted signal (B, #channels, #samples). 121 | target (Tensor): Groundtruth signal (B, #channels, #samples). 122 | Returns: 123 | Tensor: Filtered signal. 
124 | """ 125 | input = torch.nn.functional.conv1d( 126 | input, self.fir.weight.data, padding=self.ntaps // 2 127 | ) 128 | target = torch.nn.functional.conv1d( 129 | target, self.fir.weight.data, padding=self.ntaps // 2 130 | ) 131 | return input, target 132 | -------------------------------------------------------------------------------- /tests/test_auraloss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import auraloss 5 | 6 | 7 | def test_mrstft(): 8 | target = torch.rand(8, 2, 44100) 9 | pred = torch.rand(8, 2, 44100) 10 | 11 | loss = auraloss.freq.MultiResolutionSTFTLoss() 12 | res = loss(pred, target) 13 | assert res is not None 14 | 15 | 16 | def test_stft(): 17 | target = torch.rand(8, 2, 44100) 18 | pred = torch.rand(8, 2, 44100) 19 | 20 | loss = auraloss.freq.STFTLoss() 21 | res = loss(pred, target) 22 | assert res is not None 23 | 24 | 25 | def test_stft_weights_a(): 26 | target = torch.rand(8, 2, 44100) 27 | pred = torch.rand(8, 2, 44100) 28 | # test difference weights 29 | loss = auraloss.freq.STFTLoss( 30 | w_log_mag=1.0, 31 | w_lin_mag=0.0, 32 | w_sc=1.0, 33 | reduction="mean", 34 | ) 35 | res = loss(pred, target) 36 | assert res is not None 37 | 38 | 39 | def test_stft_reduction(): 40 | target = torch.rand(8, 2, 44100) 41 | pred = torch.rand(8, 2, 44100) 42 | # test the reduction 43 | loss = auraloss.freq.STFTLoss( 44 | w_log_mag=1.0, 45 | w_lin_mag=1.0, 46 | w_sc=0.0, 47 | reduction="none", 48 | ) 49 | res = loss(pred, target) 50 | print(res.shape) 51 | assert len(res.shape) > 1 52 | 53 | 54 | def test_sum_and_difference(): 55 | target = torch.rand(8, 2, 44100) 56 | pred = torch.rand(8, 2, 44100) 57 | loss = auraloss.freq.SumAndDifferenceSTFTLoss( 58 | fft_sizes=[512, 2048, 8192], 59 | hop_sizes=[128, 512, 2048], 60 | win_lengths=[512, 2048, 8192], 61 | ) 62 | res = loss(pred, target) 63 | assert res is not None 64 | 65 | 66 | def test_perceptual_sum_and_difference(): 67 | target = torch.rand(8, 2, 44100) 68 | pred = torch.rand(8, 2, 44100) 69 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss( 70 | fft_sizes=[512, 2048, 8192], 71 | hop_sizes=[128, 512, 2048], 72 | win_lengths=[512, 2048, 8192], 73 | perceptual_weighting=True, 74 | sample_rate=44100, 75 | ) 76 | 77 | res = loss_fn(pred, target) 78 | assert res is not None 79 | 80 | 81 | def test_perceptual_mel_sum_and_difference(): 82 | target = torch.rand(8, 2, 44100) 83 | pred = torch.rand(8, 2, 44100) 84 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss( 85 | fft_sizes=[1024, 2048, 8192], 86 | hop_sizes=[256, 512, 2048], 87 | win_lengths=[1024, 2048, 8192], 88 | perceptual_weighting=True, 89 | sample_rate=44100, 90 | scale="mel", 91 | n_bins=128, 92 | ) 93 | 94 | res = loss_fn(pred, target) 95 | assert res is not None 96 | 97 | 98 | def test_melstft(): 99 | target = torch.rand(8, 2, 44100) 100 | pred = torch.rand(8, 2, 44100) 101 | # test MelSTFT 102 | loss = auraloss.freq.MelSTFTLoss(44100) 103 | res = loss(pred, target) 104 | assert res is not None 105 | 106 | 107 | def test_melstft_reduction(): 108 | target = torch.rand(8, 2, 44100) 109 | pred = torch.rand(8, 2, 44100) 110 | # test MelSTFT with no reduction 111 | loss = auraloss.freq.MelSTFTLoss(44100, reduction="none") 112 | res = loss(pred, target) 113 | assert len(res) > 1 114 | 115 | 116 | def test_multires_mel(): 117 | target = torch.rand(8, 2, 44100) 118 | pred = torch.rand(8, 2, 44100) 119 | sample_rate = 44100 120 | loss = auraloss.freq.MultiResolutionSTFTLoss( 121 | 
scale="mel", 122 | n_bins=64, 123 | sample_rate=sample_rate, 124 | ) 125 | res = loss(pred, target) 126 | assert res is not None 127 | 128 | 129 | def test_perceptual_multires_mel(): 130 | target = torch.rand(8, 2, 44100) 131 | pred = torch.rand(8, 2, 44100) 132 | sample_rate = 44100 133 | loss = auraloss.freq.MultiResolutionSTFTLoss( 134 | fft_sizes=[1024, 2048, 8192], 135 | hop_sizes=[256, 512, 2048], 136 | win_lengths=[1024, 2048, 8192], 137 | scale="mel", 138 | n_bins=128, 139 | sample_rate=sample_rate, 140 | perceptual_weighting=True, 141 | ) 142 | res = loss(pred, target) 143 | assert res is not None 144 | 145 | 146 | def test_stft_l2(): 147 | N = 32 148 | n = torch.arange(N) 149 | 150 | f = N / 4 151 | target = torch.cos(2 * math.pi * f * n / N) 152 | target = target[None, None, :] 153 | pred = torch.zeros_like(target) 154 | 155 | loss = auraloss.freq.STFTLoss( 156 | fft_size=N, 157 | hop_size=N + 1, # eliminate padding artefacts by enforcing only one hop 158 | win_length=N, 159 | window="ones", # eliminate windowing artefacts 160 | w_sc=0.0, 161 | w_log_mag=0.0, 162 | w_lin_mag=1.0, 163 | w_phs=0.0, 164 | mag_distance="L2", 165 | ) 166 | res = loss(pred, target) 167 | 168 | # MSE of energy in concentrated in single DFT bin 169 | expected_loss = ((N // 2) ** 2) / (N // 2 + 1) 170 | 171 | torch.testing.assert_close(res, torch.tensor(expected_loss), rtol=1e-3, atol=1e-3) 172 | 173 | 174 | def test_multires_l2(): 175 | N = 32 176 | n = torch.arange(N) 177 | 178 | f = N / 4 179 | target = torch.cos(2 * math.pi * f * n / N) 180 | target = target[None, None, :] 181 | pred = torch.zeros_like(target) 182 | 183 | loss = auraloss.freq.MultiResolutionSTFTLoss( 184 | fft_sizes=[N], 185 | hop_sizes=[N + 1], # eliminate padding artefacts by enforcing only one hop 186 | win_lengths=[N], 187 | window="ones", # eliminate windowing artefacts 188 | w_sc=0.0, 189 | w_log_mag=0.0, 190 | w_lin_mag=1.0, 191 | w_phs=0.0, 192 | mag_distance="L2", 193 | ) 194 | res = loss(pred, target) 195 | 196 | expected_loss = ((N // 2) ** 2) / (N // 2 + 1) 197 | 198 | torch.testing.assert_close(res, torch.tensor(expected_loss), rtol=1e-3, atol=1e-3) 199 | -------------------------------------------------------------------------------- /examples/compressor/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import torch 5 | import torchaudio 6 | import numpy as np 7 | import soundfile as sf 8 | 9 | torchaudio.set_audio_backend("sox_io") 10 | 11 | 12 | class SignalTrainLA2ADataset(torch.utils.data.Dataset): 13 | """SignalTrain LA2A dataset. Source: [10.5281/zenodo.3824876](https://zenodo.org/record/3824876).""" 14 | 15 | def __init__( 16 | self, 17 | root_dir, 18 | subset="train", 19 | length=16384, 20 | preload=False, 21 | half=True, 22 | use_soundfile=False, 23 | ): 24 | """ 25 | Args: 26 | root_dir (str): Path to the root directory of the SignalTrain dataset. 27 | subset (str, optional): Pull data either from "train", "val", or "test" subsets. (Default: "train") 28 | length (int, optional): Number of samples in the returned examples. (Default: 40) 29 | preload (bool, optional): Read in all data into RAM during init. (Default: False) 30 | half (bool, optional): Store the float32 audio as float16. (Default: True) 31 | use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. 
(Default: False) 32 | """ 33 | self.root_dir = root_dir 34 | self.subset = subset 35 | self.length = length 36 | self.preload = preload 37 | self.half = half 38 | self.use_soundfile = use_soundfile 39 | 40 | # get all the target files files in the directory first 41 | self.target_files = glob.glob( 42 | os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav") 43 | ) 44 | self.input_files = glob.glob( 45 | os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav") 46 | ) 47 | 48 | self.examples = [] 49 | self.hours = 0 # total number of hours of data in the subset 50 | 51 | # ensure that the sets are ordered correctlty 52 | self.target_files.sort() 53 | self.input_files.sort() 54 | 55 | # get the parameters 56 | self.params = [ 57 | ( 58 | float(f.split("__")[1].replace(".wav", "")), 59 | float(f.split("__")[2].replace(".wav", "")), 60 | ) 61 | for f in self.target_files 62 | ] 63 | 64 | # loop over files to count total length 65 | for idx, (tfile, ifile, params) in enumerate( 66 | zip(self.target_files, self.input_files, self.params) 67 | ): 68 | 69 | ifile_id = int(os.path.basename(ifile).split("_")[1]) 70 | tfile_id = int(os.path.basename(tfile).split("_")[1]) 71 | if ifile_id != tfile_id: 72 | raise RuntimeError( 73 | f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset." 74 | ) 75 | 76 | md = torchaudio.info(tfile) 77 | self.hours += (md.num_frames / md.sample_rate) / 3600 78 | num_frames = md.num_frames 79 | 80 | if self.preload: 81 | sys.stdout.write( 82 | f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r" 83 | ) 84 | sys.stdout.flush() 85 | input, sr = self.load(ifile) 86 | target, sr = self.load(tfile) 87 | 88 | num_frames = int(np.min([input.shape[-1], target.shape[-1]])) 89 | if input.shape[-1] != target.shape[-1]: 90 | print( 91 | os.path.basename(ifile), 92 | input.shape[-1], 93 | os.path.basename(tfile), 94 | target.shape[-1], 95 | ) 96 | raise RuntimeError("Found potentially corrupt file!") 97 | if self.half: 98 | input = input.half() 99 | target = target.half() 100 | else: 101 | input = None 102 | target = None 103 | 104 | # create one entry for each patch 105 | for n in range((num_frames // self.length) - 1): 106 | offset = int(n * self.length) 107 | end = offset + self.length 108 | self.examples.append( 109 | { 110 | "idx": idx, 111 | "target_file": tfile, 112 | "input_file": ifile, 113 | "input_audio": input[:, offset:end] 114 | if input is not None 115 | else None, 116 | "target_audio": target[:, offset:end] 117 | if input is not None 118 | else None, 119 | "params": params, 120 | "offset": offset, 121 | "frames": num_frames, 122 | } 123 | ) 124 | 125 | # we then want to get the input files 126 | print( 127 | f"Located {len(self.examples)} examples totaling {self.hours:0.1f} hr in the {self.subset} subset." 
128 | ) 129 | 130 | def __len__(self): 131 | return len(self.examples) 132 | 133 | def __getitem__(self, idx): 134 | if self.preload: 135 | audio_idx = self.examples[idx]["idx"] 136 | offset = self.examples[idx]["offset"] 137 | input = self.examples[idx]["input_audio"] 138 | target = self.examples[idx]["target_audio"] 139 | else: 140 | offset = self.examples[idx]["offset"] 141 | input, sr = torchaudio.load( 142 | self.examples[idx]["input_file"], 143 | num_frames=self.length, 144 | frame_offset=offset, 145 | normalize=False, 146 | ) 147 | target, sr = torchaudio.load( 148 | self.examples[idx]["target_file"], 149 | num_frames=self.length, 150 | frame_offset=offset, 151 | normalize=False, 152 | ) 153 | if self.half: 154 | input = input.half() 155 | target = target.half() 156 | 157 | # at random with p=0.5 flip the phase 158 | if np.random.rand() > 0.5: 159 | input *= -1 160 | target *= -1 161 | 162 | # then get the tuple of parameters 163 | params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0) 164 | params[:, 1] /= 100 165 | 166 | return input, target, params 167 | 168 | def load(self, filename): 169 | if self.use_soundfile: 170 | x, sr = sf.read(filename, always_2d=True) 171 | x = torch.tensor(x.T) 172 | else: 173 | x, sr = torchaudio.load(filename, normalize=False) 174 | return x, sr 175 | -------------------------------------------------------------------------------- /tests/simple_train_gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import auraloss 3 | import torchaudio 4 | from tqdm import tqdm 5 | 6 | 7 | def center_crop(x, length: int): 8 | start = (x.shape[-1] - length) // 2 9 | stop = start + length 10 | return x[..., start:stop] 11 | 12 | 13 | def causal_crop(x, length: int): 14 | stop = x.shape[-1] - 1 15 | start = stop - length 16 | return x[..., start:stop] 17 | 18 | 19 | class TCNBlock(torch.nn.Module): 20 | def __init__( 21 | self, 22 | in_ch, 23 | out_ch, 24 | kernel_size=3, 25 | padding="same", 26 | dilation=1, 27 | grouped=False, 28 | causal=False, 29 | **kwargs, 30 | ): 31 | super(TCNBlock, self).__init__() 32 | 33 | self.in_ch = in_ch 34 | self.out_ch = out_ch 35 | self.kernel_size = kernel_size 36 | self.padding = padding 37 | self.dilation = dilation 38 | self.grouped = grouped 39 | self.causal = causal 40 | 41 | groups = out_ch if grouped and (in_ch % out_ch == 0) else 1 42 | 43 | if padding == "same": 44 | pad_value = (kernel_size - 1) + ((kernel_size - 1) * (dilation - 1)) 45 | elif padding in ["none", "valid"]: 46 | pad_value = 0 47 | 48 | self.conv1 = torch.nn.Conv1d( 49 | in_ch, 50 | out_ch, 51 | kernel_size=kernel_size, 52 | padding=0, # testing a change in padding was pad_value//2 53 | dilation=dilation, 54 | groups=groups, 55 | bias=False, 56 | ) 57 | if grouped: 58 | self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1) 59 | 60 | else: 61 | self.bn = torch.nn.BatchNorm1d(out_ch) 62 | 63 | self.relu = torch.nn.PReLU(out_ch) 64 | self.res = torch.nn.Conv1d( 65 | in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False 66 | ) 67 | 68 | def forward(self, x): 69 | x_in = x 70 | 71 | x = self.conv1(x) 72 | # x = self.bn(x) 73 | x = self.relu(x) 74 | 75 | x_res = self.res(x_in) 76 | if self.causal: 77 | x = x + causal_crop(x_res, x.shape[-1]) 78 | else: 79 | x = x + center_crop(x_res, x.shape[-1]) 80 | 81 | return x 82 | 83 | 84 | class TCNModel(torch.nn.Module): 85 | """Temporal convolutional network. 86 | Args: 87 | nparams (int): Number of conditioning parameters. 
88 | ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1 89 | noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1 90 | nblocks (int): Number of total TCN blocks. Default: 10 91 | kernel_size (int): Width of the convolutional kernels. Default: 3 92 | dialation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1 93 | channel_growth (int): Compute the output channels at each black as in_ch * channel_growth. Default: 2 94 | channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64 95 | stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10 96 | grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False 97 | causal (bool): Causal TCN configuration does not consider future input values. Default: False 98 | """ 99 | 100 | def __init__( 101 | self, 102 | ninputs=1, 103 | noutputs=1, 104 | nblocks=10, 105 | kernel_size=3, 106 | dilation_growth=1, 107 | channel_growth=1, 108 | channel_width=32, 109 | stack_size=10, 110 | grouped=False, 111 | causal=False, 112 | ): 113 | super().__init__() 114 | 115 | self.blocks = torch.nn.ModuleList() 116 | for n in range(nblocks): 117 | in_ch = out_ch if n > 0 else ninputs 118 | 119 | if channel_growth > 1: 120 | out_ch = in_ch * self.hparams.channel_growth 121 | else: 122 | out_ch = channel_width 123 | 124 | dilation = dilation_growth ** (n % stack_size) 125 | self.blocks.append( 126 | TCNBlock( 127 | in_ch, 128 | out_ch, 129 | kernel_size=kernel_size, 130 | dilation=dilation, 131 | padding="same" if causal else "valid", 132 | causal=causal, 133 | grouped=grouped, 134 | ) 135 | ) 136 | 137 | self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1) 138 | 139 | def forward(self, x): 140 | # iterate over blocks passing conditioning 141 | for idx, block in enumerate(self.blocks): 142 | x = block(x) 143 | return torch.tanh(self.output(x)) 144 | 145 | def compute_receptive_field(self): 146 | """Compute the receptive field in samples.""" 147 | rf = self.hparams.kernel_size 148 | for n in range(1, self.hparams.nblocks): 149 | dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size) 150 | rf = rf + ((self.hparams.kernel_size - 1) * dilation) 151 | return 152 | 153 | 154 | if __name__ == "__main__": 155 | 156 | y, sr = torchaudio.load("../sounds/assets/drum_kit_clean.wav") 157 | y /= y.abs().max() 158 | # y = y.repeat(2, 1) 159 | print(y.shape) 160 | 161 | # created distorted copy 162 | # x = torch.tanh(y * 4.0) 163 | x = y + (0.01 * torch.randn(y.shape)) 164 | 165 | # move data to gpu 166 | x = x.cuda() 167 | y = y.cuda() 168 | 169 | x = x.view(1, 2, -1) 170 | y = y.view(1, 2, -1) 171 | 172 | # create simple network 173 | model = TCNModel(ninputs=2, noutputs=2, kernel_size=13, dilation_growth=2) 174 | model.cuda() 175 | 176 | # create loss function 177 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss( 178 | fft_sizes=[1024, 2048, 8192], 179 | hop_sizes=[256, 512, 2048], 180 | win_lengths=[1024, 2048, 8192], 181 | perceptual_weighting=True, 182 | sample_rate=44100, 183 | scale="mel", 184 | n_bins=128, 185 | ) 186 | # loss_fn.cuda() 187 | 188 | # loss_fn = auraloss.freq.MultiResolutionSTFTLoss( fft_sizes=[1024, 2048, 8192], 189 | # hop_sizes=[256, 512, 2048], 190 | # win_lengths=[1024, 2048, 8192],) 191 | 192 | # loss_fn = torch.nn.MSELoss() 193 | 194 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) 195 | 196 | # run 
optimization 197 | pbar = tqdm(range(1000)) 198 | for iter_idx in pbar: 199 | 200 | optimizer.zero_grad() 201 | 202 | y_hat = model(x) 203 | 204 | y_crop = causal_crop(y, y_hat.shape[-1]) 205 | 206 | loss = loss_fn(y_hat, y_crop) 207 | loss.backward() 208 | optimizer.step() 209 | 210 | pbar.set_description(f"loss: {loss.item():0.4f}") 211 | 212 | torchaudio.save( 213 | "tests/simple_train_gpu_output.wav", y_hat.detach().cpu().view(2, -1), sr 214 | ) 215 | torchaudio.save( 216 | "tests/simple_train_gpu_input.wav", x.detach().cpu().view(2, -1), sr 217 | ) 218 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # auraloss 4 | 5 | ![auraloss logo](docs/auraloss-logo.svg) 6 | 7 | A collection of audio-focused loss functions in PyTorch. 8 | 9 | [[PDF](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf)] 10 | 11 | 
12 | 13 | ## Setup 14 | 15 | ``` 16 | pip install auraloss 17 | ``` 18 | 19 | If you want to use `MelSTFTLoss()` or `FIRFilter()` you will need to specify the extra install (librosa and scipy). 20 | 21 | ``` 22 | pip install auraloss[all] 23 | ``` 24 | 25 | ## Usage 26 | 27 | ```python 28 | import torch 29 | import auraloss 30 | 31 | mrstft = auraloss.freq.MultiResolutionSTFTLoss() 32 | 33 | input = torch.rand(8,1,44100) 34 | target = torch.rand(8,1,44100) 35 | 36 | loss = mrstft(input, target) 37 | ``` 38 | 39 | **NEW**: Perceptual weighting with mel scaled spectrograms. 40 | 41 | ```python 42 | 43 | bs = 8 44 | chs = 1 45 | seq_len = 131072 46 | sample_rate = 44100 47 | 48 | # some audio you want to compare 49 | target = torch.rand(bs, chs, seq_len) 50 | pred = torch.rand(bs, chs, seq_len) 51 | 52 | # define the loss function 53 | loss_fn = auraloss.freq.MultiResolutionSTFTLoss( 54 | fft_sizes=[1024, 2048, 8192], 55 | hop_sizes=[256, 512, 2048], 56 | win_lengths=[1024, 2048, 8192], 57 | scale="mel", 58 | n_bins=128, 59 | sample_rate=sample_rate, 60 | perceptual_weighting=True, 61 | ) 62 | 63 | # compute 64 | loss = loss_fn(pred, target) 65 | 66 | ``` 67 | 68 | ## Citation 69 | If you use this code in your work please consider citing us. 70 | ```bibtex 71 | @inproceedings{steinmetz2020auraloss, 72 | title={auraloss: {A}udio focused loss functions in {PyTorch}}, 73 | author={Steinmetz, Christian J. and Reiss, Joshua D.}, 74 | booktitle={Digital Music Research Network One-day Workshop (DMRN+15)}, 75 | year={2020} 76 | } 77 | ``` 78 | 79 | 80 | # Loss functions 81 | 82 | We categorize the loss functions as either time-domain or frequency-domain approaches. 83 | Additionally, we include perceptual transforms. 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 |
| Loss function | Interface | Reference |
|---|---|---|
| **Time domain** | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| **Frequency domain** | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| **Perceptual transforms** | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |
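The time-domain losses listed above share the same call signature as the frequency-domain losses. A minimal sketch (random tensors stand in for real audio, with shape `(batch, channels, samples)`):

```python
import torch
import auraloss

target = torch.rand(8, 1, 44100)
pred = torch.rand(8, 1, 44100)

esr_loss = auraloss.time.ESRLoss()      # error-to-signal ratio
sisdr_loss = auraloss.time.SISDRLoss()  # returns the negative SI-SDR

print(esr_loss(pred, target), sisdr_loss(pred, target))
```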
166 | 167 | \* [Wang et al., 2019](https://arxiv.org/abs/1904.12088) also propose a multi-resolution spectral loss (that [Engel et al., 2020](https://arxiv.org/abs/2001.04643) follow), 168 | but they do not include both the log magnitude (L1 distance) and spectral convergence terms, introduced in [Arik et al., 2018](https://arxiv.org/abs/1808.0671), and then extended for the multi-resolution case in [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480). 169 | 170 | ## Examples 171 | 172 | Currently we include an example using a set of the loss functions to train a TCN for modeling an analog dynamic range compressor. 173 | For details please refer to the details in [`examples/compressor`](examples/compressor). 174 | We provide pre-trained models, evaluation scripts to compute the metrics in the [paper](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf), as well as scripts to retrain models. 175 | 176 | There are some more advanced things you can do based upon the `STFTLoss` class. 177 | For example, you can compute both linear and log scaled STFT errors as in [Engel et al., 2020](https://arxiv.org/abs/2001.04643). 178 | In this case we do not include the spectral convergence term. 179 | ```python 180 | stft_loss = auraloss.freq.STFTLoss( 181 | w_log_mag=1.0, 182 | w_lin_mag=1.0, 183 | w_sc=0.0, 184 | ) 185 | ``` 186 | 187 | There is also a Mel-scaled STFT loss, which has some special requirements. 188 | This loss requires you set the sample rate as well as specify the correct device. 189 | ```python 190 | sample_rate = 44100 191 | melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda") 192 | ``` 193 | 194 | You can also build a multi-resolution Mel-scaled STFT loss with 64 bins easily. 195 | Make sure you pass the correct device where the tensors you are comparing will be. 196 | ```python 197 | loss_fn = auraloss.freq.MultiResolutionSTFTLoss( 198 | scale="mel", 199 | n_bins=64, 200 | sample_rate=sample_rate, 201 | device="cuda" 202 | ) 203 | ``` 204 | 205 | If you are computing a loss on stereo audio you may want to consider the sum and difference (mid/side) loss. 206 | Below we have shown an example of using this loss function with the perceptual weighting and mel scaling for 207 | further perceptual relevance. 208 | 209 | ```python 210 | 211 | target = torch.rand(8, 2, 44100) 212 | pred = torch.rand(8, 2, 44100) 213 | 214 | loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss( 215 | fft_sizes=[1024, 2048, 8192], 216 | hop_sizes=[256, 512, 2048], 217 | win_lengths=[1024, 2048, 8192], 218 | perceptual_weighting=True, 219 | sample_rate=44100, 220 | scale="mel", 221 | n_bins=128, 222 | ) 223 | 224 | loss = loss_fn(pred, target) 225 | ``` 226 | 227 | # Development 228 | 229 | Run tests locally with pytest. 230 | 231 | ```python -m pytest``` 232 | -------------------------------------------------------------------------------- /auraloss/time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor as T 3 | 4 | from .utils import apply_reduction 5 | 6 | 7 | class ESRLoss(torch.nn.Module): 8 | """Error-to-signal ratio loss function module. 9 | 10 | See [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). 
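    The per-example loss is the energy of the error signal normalized by the
    energy of the target, i.e. ((target - input) ** 2).sum() / ((target ** 2).sum() + eps),
    computed over the final dimension before the reduction is applied.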
11 | 12 | Args: 13 | reduction (string, optional): Specifies the reduction to apply to the output: 14 | 'none': no reduction will be applied, 15 | 'mean': the sum of the output will be divided by the number of elements in the output, 16 | 'sum': the output will be summed. Default: 'mean' 17 | Shape: 18 | - input : :math:`(batch, nchs, ...)`. 19 | - target: :math:`(batch, nchs, ...)`. 20 | """ 21 | 22 | def __init__(self, eps: float = 1e-8, reduction: str = "mean") -> None: 23 | super().__init__() 24 | self.eps = eps 25 | self.reduction = reduction 26 | 27 | def forward(self, input: T, target: T) -> T: 28 | num = ((target - input) ** 2).sum(dim=-1) 29 | denom = (target ** 2).sum(dim=-1) + self.eps 30 | losses = num / denom 31 | losses = apply_reduction(losses, reduction=self.reduction) 32 | return losses 33 | 34 | 35 | class DCLoss(torch.nn.Module): 36 | """DC loss function module. 37 | 38 | See [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). 39 | 40 | Args: 41 | reduction (string, optional): Specifies the reduction to apply to the output: 42 | 'none': no reduction will be applied, 43 | 'mean': the sum of the output will be divided by the number of elements in the output, 44 | 'sum': the output will be summed. Default: 'mean' 45 | Shape: 46 | - input : :math:`(batch, nchs, ...)`. 47 | - target: :math:`(batch, nchs, ...)`. 48 | """ 49 | 50 | def __init__(self, eps: float = 1e-8, reduction: str = "mean") -> None: 51 | super().__init__() 52 | self.eps = eps 53 | self.reduction = reduction 54 | 55 | def forward(self, input: T, target: T) -> T: 56 | num = (target - input).mean(dim=-1) ** 2 57 | denom = (target ** 2).mean(dim=-1) + self.eps 58 | losses = num / denom 59 | losses = apply_reduction(losses, self.reduction) 60 | return losses 61 | 62 | 63 | class LogCoshLoss(torch.nn.Module): 64 | """Log-cosh loss function module. 65 | 66 | See [Chen et al., 2019](https://openreview.net/forum?id=rkglvsC9Ym). 67 | 68 | Args: 69 | a (float, optional): Smoothness hyperparameter. Smaller is smoother. Default: 1.0 70 | eps (float, optional): Small epsilon value for stablity. Default: 1e-8 71 | reduction (string, optional): Specifies the reduction to apply to the output: 72 | 'none': no reduction will be applied, 73 | 'mean': the sum of the output will be divided by the number of elements in the output, 74 | 'sum': the output will be summed. Default: 'mean' 75 | Shape: 76 | - input : :math:`(batch, nchs, ...)`. 77 | - target: :math:`(batch, nchs, ...)`. 78 | """ 79 | 80 | def __init__(self, a=1.0, eps=1e-8, reduction="mean"): 81 | super(LogCoshLoss, self).__init__() 82 | self.a = a 83 | self.eps = eps 84 | self.reduction = reduction 85 | 86 | def forward(self, input, target): 87 | losses = ( 88 | (1 / self.a) * torch.log(torch.cosh(self.a * (input - target)) + self.eps) 89 | ).mean(-1) 90 | losses = apply_reduction(losses, self.reduction) 91 | return losses 92 | 93 | 94 | class SNRLoss(torch.nn.Module): 95 | """Signal-to-noise ratio loss module. 96 | 97 | Note that this does NOT implement the SDR from 98 | [Vincent et al., 2006](https://ieeexplore.ieee.org/document/1643671), 99 | which includes the application of a 512-tap FIR filter. 
100 | """ 101 | 102 | def __init__(self, zero_mean=True, eps=1e-8, reduction="mean"): 103 | super(SNRLoss, self).__init__() 104 | self.zero_mean = zero_mean 105 | self.eps = eps 106 | self.reduction = reduction 107 | 108 | def forward(self, input, target): 109 | if self.zero_mean: 110 | input_mean = torch.mean(input, dim=-1, keepdim=True) 111 | target_mean = torch.mean(target, dim=-1, keepdim=True) 112 | input = input - input_mean 113 | target = target - target_mean 114 | 115 | res = input - target 116 | losses = 10 * torch.log10( 117 | (target ** 2).sum(-1) / ((res ** 2).sum(-1) + self.eps) + self.eps 118 | ) 119 | losses = apply_reduction(losses, self.reduction) 120 | return -losses 121 | 122 | 123 | class SISDRLoss(torch.nn.Module): 124 | """Scale-invariant signal-to-distortion ratio loss module. 125 | 126 | Note that this returns the negative of the SI-SDR loss. 127 | 128 | See [Le Roux et al., 2018](https://arxiv.org/abs/1811.02508) 129 | 130 | Args: 131 | zero_mean (bool, optional) Remove any DC offset in the inputs. Default: ``True`` 132 | eps (float, optional): Small epsilon value for stablity. Default: 1e-8 133 | reduction (string, optional): Specifies the reduction to apply to the output: 134 | 'none': no reduction will be applied, 135 | 'mean': the sum of the output will be divided by the number of elements in the output, 136 | 'sum': the output will be summed. Default: 'mean' 137 | Shape: 138 | - input : :math:`(batch, nchs, ...)`. 139 | - target: :math:`(batch, nchs, ...)`. 140 | """ 141 | 142 | def __init__(self, zero_mean=True, eps=1e-8, reduction="mean"): 143 | super(SISDRLoss, self).__init__() 144 | self.zero_mean = zero_mean 145 | self.eps = eps 146 | self.reduction = reduction 147 | 148 | def forward(self, input, target): 149 | if self.zero_mean: 150 | input_mean = torch.mean(input, dim=-1, keepdim=True) 151 | target_mean = torch.mean(target, dim=-1, keepdim=True) 152 | input = input - input_mean 153 | target = target - target_mean 154 | 155 | alpha = (input * target).sum(-1) / (((target ** 2).sum(-1)) + self.eps) 156 | target = target * alpha.unsqueeze(-1) 157 | res = input - target 158 | 159 | losses = 10 * torch.log10( 160 | (target ** 2).sum(-1) / ((res ** 2).sum(-1) + self.eps) + self.eps 161 | ) 162 | losses = apply_reduction(losses, self.reduction) 163 | return -losses 164 | 165 | 166 | class SDSDRLoss(torch.nn.Module): 167 | """Scale-dependent signal-to-distortion ratio loss module. 168 | 169 | Note that this returns the negative of the SD-SDR loss. 170 | 171 | See [Le Roux et al., 2018](https://arxiv.org/abs/1811.02508) 172 | 173 | Args: 174 | zero_mean (bool, optional) Remove any DC offset in the inputs. Default: ``True`` 175 | eps (float, optional): Small epsilon value for stablity. Default: 1e-8 176 | reduction (string, optional): Specifies the reduction to apply to the output: 177 | 'none': no reduction will be applied, 178 | 'mean': the sum of the output will be divided by the number of elements in the output, 179 | 'sum': the output will be summed. Default: 'mean' 180 | Shape: 181 | - input : :math:`(batch, nchs, ...)`. 182 | - target: :math:`(batch, nchs, ...)`. 
183 | """ 184 | 185 | def __init__(self, zero_mean=True, eps=1e-8, reduction="mean"): 186 | super(SDSDRLoss, self).__init__() 187 | self.zero_mean = zero_mean 188 | self.eps = eps 189 | self.reduction = reduction 190 | 191 | def forward(self, input, target): 192 | if self.zero_mean: 193 | input_mean = torch.mean(input, dim=-1, keepdim=True) 194 | target_mean = torch.mean(target, dim=-1, keepdim=True) 195 | input = input - input_mean 196 | target = target - target_mean 197 | 198 | alpha = (input * target).sum(-1) / (((target ** 2).sum(-1)) + self.eps) 199 | scaled_target = target * alpha.unsqueeze(-1) 200 | res = input - target 201 | 202 | losses = 10 * torch.log10( 203 | (scaled_target ** 2).sum(-1) / ((res ** 2).sum(-1) + self.eps) + self.eps 204 | ) 205 | losses = apply_reduction(losses, self.reduction) 206 | return -losses 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /examples/compressor/tcn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | from argparse import ArgumentParser 7 | 8 | import auraloss 9 | 10 | 11 | def center_crop(x, shape): 12 | start = (x.shape[-1] - shape[-1]) // 2 13 | stop = start + shape[-1] 14 | return x[..., start:stop] 15 | 16 | 17 | class FiLM(torch.nn.Module): 18 | def __init__(self, num_features, cond_dim): 19 | super(FiLM, self).__init__() 20 | self.num_features = num_features 21 | self.bn = torch.nn.BatchNorm1d(num_features, affine=False) 22 | self.adaptor = torch.nn.Linear(cond_dim, num_features * 2) 23 | 24 | def forward(self, x, cond): 25 | 26 | cond = self.adaptor(cond) 27 | g, b = torch.chunk(cond, 2, dim=-1) 28 | g = g.permute(0, 2, 1) 29 | b = b.permute(0, 2, 1) 30 | 31 | x = self.bn(x) # apply BatchNorm without affine 32 | x = (x * g) + b # then apply conditional affine 33 | 34 | return x 35 | 36 | 37 | class TCNBlock(torch.nn.Module): 38 | def __init__( 39 | self, 40 | in_ch, 41 | out_ch, 42 | kernel_size=3, 43 | padding=0, 44 | dilation=1, 45 | depthwise=False, 46 | conditional=False, 47 | **kwargs, 48 | ): 49 | super(TCNBlock, self).__init__() 50 | 51 | self.in_ch = in_ch 52 | self.out_ch = out_ch 53 | self.kernel_size = kernel_size 54 | self.padding = padding 55 | self.dilation = dilation 56 | self.depthwise = depthwise 57 | self.conditional = conditional 58 | 59 | groups = out_ch if depthwise and (in_ch % out_ch == 0) else 1 60 | 61 | self.conv1 = torch.nn.Conv1d( 62 | in_ch, 63 | out_ch, 64 | kernel_size=kernel_size, 65 | padding=padding, 66 | dilation=dilation, 67 | groups=groups, 68 | bias=False, 69 | ) 70 | if depthwise: 71 | self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1) 72 | 73 | if conditional: 74 | self.film = FiLM(out_ch, 128) 75 | else: 76 | self.bn = torch.nn.BatchNorm1d(out_ch) 77 | 78 | self.relu = torch.nn.PReLU(out_ch) 79 | self.res = torch.nn.Conv1d( 80 | in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False 81 | ) 82 | 83 | def forward(self, x, p=None): 84 | x_in = x 85 | 86 | x = self.conv1(x) 87 | if self.depthwise: # apply pointwise conv 88 | x = self.conv1b(x) 89 | if p is not None: # apply FiLM conditioning 90 | x = self.film(x, p) 91 | else: 92 | x = self.bn(x) 93 | x = self.relu(x) 94 | 95 | x_res = self.res(x_in) 96 | x = x + center_crop(x_res, x.shape) 97 | 98 | return x 99 | 100 | 101 | class TCNModel(pl.LightningModule): 102 | """Temporal convolutional network with conditioning module. 103 | 104 | Params: 105 | nparams (int): Number of conditioning parameters. 106 | ninputs (int): Number of input channels (mono = 1, stereo = 2). Default: 1 107 | noutputs (int): Number of output channels (mono = 1, stereo = 2). Default: 1 108 | nblocks (int): Number of total TCN blocks. Default: 10 109 | kernel_size (int): Width of the convolutional kernels. Default: 3 110 | dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1 111 | channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 1 112 | channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64 113 | stack_size (int): Number of blocks that constitute a single stack of blocks. 
Default: 10 114 | depthwise (bool): Use depthwise-separable convolutions to reduce the total number of parameters. Default: False 115 | num_examples (int): Number of evaluation audio examples to log after each epoch. Default: 4 116 | """ 117 | 118 | def __init__( 119 | self, 120 | nparams, 121 | ninputs=1, 122 | noutputs=1, 123 | nblocks=10, 124 | kernel_size=3, 125 | dilation_growth=1, 126 | channel_growth=1, 127 | channel_width=64, 128 | stack_size=10, 129 | depthwise=False, 130 | num_examples=4, 131 | save_dir=None, 132 | **kwargs, 133 | ): 134 | super(TCNModel, self).__init__() 135 | 136 | self.save_hyperparameters() 137 | 138 | # setup loss functions 139 | self.l1 = torch.nn.L1Loss() 140 | self.esr = auraloss.time.ESRLoss() 141 | self.dc = auraloss.time.DCLoss() 142 | self.logcosh = auraloss.time.LogCoshLoss() 143 | self.sisdr = auraloss.time.SISDRLoss() 144 | self.stft = auraloss.freq.STFTLoss() 145 | self.mrstft = auraloss.freq.MultiResolutionSTFTLoss() 146 | self.rrstft = auraloss.freq.RandomResolutionSTFTLoss() 147 | 148 | if nparams > 0: 149 | self.gen = torch.nn.Sequential( 150 | torch.nn.Linear(nparams, 32), 151 | torch.nn.PReLU(), 152 | torch.nn.Linear(32, 64), 153 | torch.nn.PReLU(), 154 | torch.nn.Linear(64, 128), 155 | torch.nn.PReLU(), 156 | ) 157 | 158 | self.blocks = torch.nn.ModuleList() 159 | for n in range(nblocks): 160 | in_ch = out_ch if n > 0 else ninputs 161 | out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width 162 | 163 | dilation = dilation_growth ** (n % stack_size) 164 | self.blocks.append( 165 | TCNBlock( 166 | in_ch, 167 | out_ch, 168 | kernel_size=kernel_size, 169 | dilation=dilation, 170 | depthwise=self.hparams.depthwise, 171 | conditional=True if nparams > 0 else False, 172 | ) 173 | ) 174 | 175 | self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1) 176 | 177 | def forward(self, x, p=None): 178 | # if parameters present, 179 | # compute global conditioning 180 | if p is not None: 181 | cond = self.gen(p) 182 | else: 183 | cond = None 184 | 185 | # iterate over blocks passing conditioning 186 | for idx, block in enumerate(self.blocks): 187 | x = block(x, cond) 188 | if idx == 0: 189 | skips = x 190 | else: 191 | skips = center_crop(skips, x.shape) + x 192 | 193 | return torch.tanh(self.output(x + skips)) 194 | 195 | def compute_receptive_field(self): 196 | """Compute the receptive field in samples.""" 197 | rf = self.hparams.kernel_size 198 | for n in range(1, self.hparams.nblocks): 199 | dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size) 200 | rf = rf + ((self.hparams.kernel_size - 1) * dilation) 201 | rf = rf + ((self.hparams.kernel_size - 1) * 1) 202 | return rf 203 | 204 | def training_step(self, batch, batch_idx): 205 | input, target, params = batch 206 | 207 | # pass the input through the model 208 | pred = self(input, params) 209 | 210 | # crop the target signal 211 | target = center_crop(target, pred.shape) 212 | 213 | # compute the error using appropriate loss 214 | if self.hparams.train_loss == "l1": 215 | loss = self.l1(pred, target) 216 | elif self.hparams.train_loss == "esr+dc": 217 | loss = self.esr(pred, target) + self.dc(pred, target) 218 | elif self.hparams.train_loss == "logcosh": 219 | loss = self.logcosh(pred, target) 220 | elif self.hparams.train_loss == "sisdr": 221 | loss = self.sisdr(pred, target) 222 | elif self.hparams.train_loss == "stft": 223 | loss = self.stft(pred, target) 224 | elif self.hparams.train_loss == "mrstft": 225 | loss = self.mrstft(pred, target) 226 | elif 
self.hparams.train_loss == "rrstft": 227 | loss = self.rrstft(pred, target) 228 | else: 229 | raise NotImplementedError(f"Invalid loss fn: {self.hparams.train_loss}") 230 | 231 | self.log( 232 | "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 233 | ) 234 | 235 | return loss 236 | 237 | def validation_step(self, batch, batch_idx): 238 | input, target, params = batch 239 | 240 | # pass the input thrgouh the mode 241 | pred = self(input, params) 242 | 243 | # crop the input and target signals 244 | input_crop = center_crop(input, pred.shape) 245 | target_crop = center_crop(target, pred.shape) 246 | 247 | # compute the validation error using all losses 248 | l1_loss = self.l1(pred, target_crop) 249 | esr_loss = self.esr(pred, target_crop) 250 | dc_loss = self.dc(pred, target_crop) 251 | logcosh_loss = self.logcosh(pred, target_crop) 252 | sisdr_loss = self.sisdr(pred, target_crop) 253 | stft_loss = self.stft(pred, target_crop) 254 | mrstft_loss = self.mrstft(pred, target_crop) 255 | rrstft_loss = self.rrstft(pred, target_crop) 256 | 257 | aggregate_loss = ( 258 | l1_loss 259 | + esr_loss 260 | + dc_loss 261 | + logcosh_loss 262 | + sisdr_loss 263 | + mrstft_loss 264 | + stft_loss 265 | + rrstft_loss 266 | ) 267 | 268 | self.log("val_loss", aggregate_loss) 269 | self.log("val_loss/L1", l1_loss) 270 | self.log("val_loss/ESR", esr_loss) 271 | self.log("val_loss/DC", dc_loss) 272 | self.log("val_loss/LogCosh", logcosh_loss) 273 | self.log("val_loss/SI-SDR", sisdr_loss) 274 | self.log("val_loss/STFT", stft_loss) 275 | self.log("val_loss/MRSTFT", mrstft_loss) 276 | self.log("val_loss/RRSTFT", rrstft_loss) 277 | 278 | # move tensors to cpu for logging 279 | outputs = { 280 | "input": input_crop.cpu().numpy(), 281 | "target": target_crop.cpu().numpy(), 282 | "pred": pred.cpu().numpy(), 283 | "params": params.cpu().numpy(), 284 | } 285 | 286 | return outputs 287 | 288 | def validation_epoch_end(self, validation_step_outputs): 289 | # flatten the output validation step dicts to a single dict 290 | outputs = {"input": [], "target": [], "pred": [], "params": []} 291 | 292 | for out in validation_step_outputs: 293 | for key, val in out.items(): 294 | bs = val.shape[0] 295 | for bidx in np.arange(bs): 296 | outputs[key].append(val[bidx, ...]) 297 | 298 | example_indices = np.arange(len(outputs["input"])) 299 | rand_indices = np.random.choice( 300 | example_indices, 301 | replace=False, 302 | size=np.min([len(outputs["input"]), self.hparams.num_examples]), 303 | ) 304 | 305 | for idx, rand_idx in enumerate(list(rand_indices)): 306 | i = outputs["input"][rand_idx].squeeze() 307 | t = outputs["target"][rand_idx].squeeze() 308 | p = outputs["pred"][rand_idx].squeeze() 309 | prm = outputs["params"][rand_idx].squeeze() 310 | 311 | # log audio examples 312 | self.logger.experiment.add_audio( 313 | f"input/{idx}", 314 | i, 315 | self.global_step, 316 | sample_rate=self.hparams.sample_rate, 317 | ) 318 | self.logger.experiment.add_audio( 319 | f"target/{idx}", 320 | t, 321 | self.global_step, 322 | sample_rate=self.hparams.sample_rate, 323 | ) 324 | self.logger.experiment.add_audio( 325 | f"pred/{idx}", p, self.global_step, sample_rate=self.hparams.sample_rate 326 | ) 327 | 328 | if self.hparams.save_dir is not None: 329 | if not os.path.isdir(self.hparams.save_dir): 330 | os.makedirs(self.hparams.save_dir) 331 | 332 | input_filename = os.path.join( 333 | self.hparams.save_dir, 334 | f"{idx}-input-{int(prm[0]):1d}-{prm[1]:0.2f}.wav", 335 | ) 336 | target_filename = os.path.join( 337 | 
self.hparams.save_dir, 338 | f"{idx}-target-{int(prm[0]):1d}-{prm[1]:0.2f}.wav", 339 | ) 340 | 341 | if not os.path.isfile(input_filename): 342 | torchaudio.save( 343 | input_filename, 344 | torch.tensor(i).view(1, -1).float(), 345 | sample_rate=self.hparams.sample_rate, 346 | ) 347 | 348 | if not os.path.isfile(target_filename): 349 | torchaudio.save( 350 | target_filename, 351 | torch.tensor(t).view(1, -1).float(), 352 | sample_rate=self.hparams.sample_rate, 353 | ) 354 | 355 | torchaudio.save( 356 | os.path.join( 357 | self.hparams.save_dir, 358 | f"{idx}-pred-{self.hparams.train_loss}-{int(prm[0]):1d}-{prm[1]:0.2f}.wav", 359 | ), 360 | torch.tensor(p).view(1, -1).float(), 361 | sample_rate=self.hparams.sample_rate, 362 | ) 363 | 364 | def test_step(self, batch, batch_idx): 365 | return self.validation_step(batch, batch_idx) 366 | 367 | def test_epoch_end(self, test_step_outputs): 368 | return self.validation_epoch_end(test_step_outputs) 369 | 370 | def configure_optimizers(self): 371 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) 372 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 373 | optimizer, patience=4, verbose=True 374 | ) 375 | return { 376 | "optimizer": optimizer, 377 | "lr_scheduler": lr_scheduler, 378 | "monitor": "val_loss", 379 | } 380 | 381 | # add any model hyperparameters here 382 | @staticmethod 383 | def add_model_specific_args(parent_parser): 384 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 385 | # --- model related --- 386 | parser.add_argument("--ninputs", type=int, default=1) 387 | parser.add_argument("--noutputs", type=int, default=1) 388 | parser.add_argument("--nblocks", type=int, default=10) 389 | parser.add_argument("--kernel_size", type=int, default=3) 390 | parser.add_argument("--dilation_growth", type=int, default=1) 391 | parser.add_argument("--channel_growth", type=int, default=1) 392 | parser.add_argument("--channel_width", type=int, default=64) 393 | parser.add_argument("--stack_size", type=int, default=10) 394 | parser.add_argument("--depthwise", default=False, action="store_true") 395 | # --- training related --- 396 | parser.add_argument("--lr", type=float, default=1e-3) 397 | parser.add_argument("--train_loss", type=str, default="l1") 398 | # --- vadliation related --- 399 | parser.add_argument("--save_dir", type=str, default=None) 400 | parser.add_argument("--num_examples", type=int, default=4) 401 | 402 | return parser 403 | -------------------------------------------------------------------------------- /auraloss/freq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Any 4 | 5 | from .utils import apply_reduction, get_window 6 | from .perceptual import SumAndDifference, FIRFilter 7 | 8 | 9 | class SpectralConvergenceLoss(torch.nn.Module): 10 | """Spectral convergence loss module. 11 | 12 | See [Arik et al., 2018](https://arxiv.org/abs/1808.06719). 13 | """ 14 | 15 | def __init__(self): 16 | super(SpectralConvergenceLoss, self).__init__() 17 | 18 | def forward(self, x_mag, y_mag): 19 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 20 | 21 | 22 | class STFTMagnitudeLoss(torch.nn.Module): 23 | """STFT magnitude loss module. 
24 | 25 | See [Arik et al., 2018](https://arxiv.org/abs/1808.06719) 26 | and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1) 27 | 28 | Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the 29 | compression strength (larger value results in more compression), and `log_eps` can be used 30 | to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive 31 | output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression. 32 | 33 | Args: 34 | log (bool, optional): Log-scale the STFT magnitudes, 35 | or use linear scale. Default: True 36 | log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm. 37 | Default: 0.0 38 | log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm. 39 | Default: 1.0 40 | distance (str, optional): Distance function ["L1", "L2"]. Default: "L1" 41 | reduction (str, optional): Reduction of the loss elements. Default: "mean" 42 | """ 43 | 44 | def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"): 45 | super(STFTMagnitudeLoss, self).__init__() 46 | 47 | self.log = log 48 | self.log_eps = log_eps 49 | self.log_fac = log_fac 50 | 51 | if distance == "L1": 52 | self.distance = torch.nn.L1Loss(reduction=reduction) 53 | elif distance == "L2": 54 | self.distance = torch.nn.MSELoss(reduction=reduction) 55 | else: 56 | raise ValueError(f"Invalid distance: '{distance}'.") 57 | 58 | def forward(self, x_mag, y_mag): 59 | if self.log: 60 | x_mag = torch.log(self.log_fac * x_mag + self.log_eps) 61 | y_mag = torch.log(self.log_fac * y_mag + self.log_eps) 62 | return self.distance(x_mag, y_mag) 63 | 64 | 65 | class STFTLoss(torch.nn.Module): 66 | """STFT loss module. 67 | 68 | See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472). 69 | 70 | Args: 71 | fft_size (int, optional): FFT size in samples. Default: 1024 72 | hop_size (int, optional): Hop size of the FFT in samples. Default: 256 73 | win_length (int, optional): Length of the FFT analysis window. Default: 1024 74 | window (str, optional): Window to apply before FFT, can either be one of the window functions provided in PyTorch 75 | ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] 76 | or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). 77 | Default: 'hann_window' 78 | w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 79 | w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 80 | w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 81 | w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 82 | sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None 83 | scale (str, optional): Optional frequency scaling method, options include: 84 | ['mel', 'chroma'] 85 | Default: None 86 | n_bins (int, optional): Number of scaling frequency bins. Default: None. 87 | perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False 88 | scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False 89 | eps (float, optional): Small epsilon value for stability. Default: 1e-8 90 | output (str, optional): Format of the loss returned. 
91 | 'loss' : Return only the raw, aggregate loss term. 92 | 'full' : Return the raw loss, plus intermediate loss terms. 93 | Default: 'loss' 94 | reduction (str, optional): Specifies the reduction to apply to the output: 95 | 'none': no reduction will be applied, 96 | 'mean': the sum of the output will be divided by the number of elements in the output, 97 | 'sum': the output will be summed. 98 | Default: 'mean' 99 | mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. 100 | device (str, optional): Place the filterbanks on specified device. Default: None 101 | 102 | Returns: 103 | loss: 104 | Aggreate loss term. Only returned if output='loss'. By default. 105 | loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss: 106 | Aggregate and intermediate loss terms. Only returned if output='full'. 107 | """ 108 | 109 | def __init__( 110 | self, 111 | fft_size: int = 1024, 112 | hop_size: int = 256, 113 | win_length: int = 1024, 114 | window: str = "hann_window", 115 | w_sc: float = 1.0, 116 | w_log_mag: float = 1.0, 117 | w_lin_mag: float = 0.0, 118 | w_phs: float = 0.0, 119 | sample_rate: float = None, 120 | scale: str = None, 121 | n_bins: int = None, 122 | perceptual_weighting: bool = False, 123 | scale_invariance: bool = False, 124 | eps: float = 1e-8, 125 | output: str = "loss", 126 | reduction: str = "mean", 127 | mag_distance: str = "L1", 128 | device: Any = None, 129 | **kwargs 130 | ): 131 | super().__init__() 132 | self.fft_size = fft_size 133 | self.hop_size = hop_size 134 | self.win_length = win_length 135 | self.window = get_window(window, win_length) 136 | self.w_sc = w_sc 137 | self.w_log_mag = w_log_mag 138 | self.w_lin_mag = w_lin_mag 139 | self.w_phs = w_phs 140 | self.sample_rate = sample_rate 141 | self.scale = scale 142 | self.n_bins = n_bins 143 | self.perceptual_weighting = perceptual_weighting 144 | self.scale_invariance = scale_invariance 145 | self.eps = eps 146 | self.output = output 147 | self.reduction = reduction 148 | self.mag_distance = mag_distance 149 | self.device = device 150 | 151 | self.phs_used = bool(self.w_phs) 152 | 153 | self.spectralconv = SpectralConvergenceLoss() 154 | self.logstft = STFTMagnitudeLoss( 155 | log=True, 156 | reduction=reduction, 157 | distance=mag_distance, 158 | **kwargs 159 | ) 160 | self.linstft = STFTMagnitudeLoss( 161 | log=False, 162 | reduction=reduction, 163 | distance=mag_distance, 164 | **kwargs 165 | ) 166 | 167 | # setup mel filterbank 168 | if scale is not None: 169 | try: 170 | import librosa.filters 171 | except Exception as e: 172 | print(e) 173 | print("Try `pip install auraloss[all]`.") 174 | 175 | if self.scale == "mel": 176 | assert sample_rate != None # Must set sample rate to use mel scale 177 | assert n_bins <= fft_size # Must be more FFT bins than Mel bins 178 | fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) 179 | fb = torch.tensor(fb).unsqueeze(0) 180 | 181 | elif self.scale == "chroma": 182 | assert sample_rate != None # Must set sample rate to use chroma scale 183 | assert n_bins <= fft_size # Must be more FFT bins than chroma bins 184 | fb = librosa.filters.chroma( 185 | sr=sample_rate, n_fft=fft_size, n_chroma=n_bins 186 | ) 187 | 188 | else: 189 | raise ValueError( 190 | f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'." 
191 | ) 192 | 193 | self.register_buffer("fb", fb) 194 | 195 | if scale is not None and device is not None: 196 | self.fb = self.fb.to(self.device) # move filterbank to device 197 | 198 | if self.perceptual_weighting: 199 | if sample_rate is None: 200 | raise ValueError( 201 | f"`sample_rate` must be supplied when `perceptual_weighting = True`." 202 | ) 203 | self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) 204 | 205 | def stft(self, x): 206 | """Perform STFT. 207 | Args: 208 | x (Tensor): Input signal tensor (B, T). 209 | 210 | Returns: 211 | Tensor: x_mag, x_phs 212 | Magnitude and phase spectra (B, fft_size // 2 + 1, frames). 213 | """ 214 | x_stft = torch.stft( 215 | x, 216 | self.fft_size, 217 | self.hop_size, 218 | self.win_length, 219 | self.window, 220 | return_complex=True, 221 | ) 222 | x_mag = torch.sqrt( 223 | torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) 224 | ) 225 | 226 | # torch.angle is expensive, so it is only evaluated if the values are used in the loss 227 | if self.phs_used: 228 | x_phs = torch.angle(x_stft) 229 | else: 230 | x_phs = None 231 | 232 | return x_mag, x_phs 233 | 234 | def forward(self, input: torch.Tensor, target: torch.Tensor): 235 | bs, chs, seq_len = input.size() 236 | 237 | if self.perceptual_weighting: # apply optional A-weighting via FIR filter 238 | # since FIRFilter only support mono audio we will move channels to batch dim 239 | input = input.view(bs * chs, 1, -1) 240 | target = target.view(bs * chs, 1, -1) 241 | 242 | # now apply the filter to both 243 | self.prefilter.to(input.device) 244 | input, target = self.prefilter(input, target) 245 | 246 | # now move the channels back 247 | input = input.view(bs, chs, -1) 248 | target = target.view(bs, chs, -1) 249 | 250 | # compute the magnitude and phase spectra of input and target 251 | self.window = self.window.to(input.device) 252 | 253 | x_mag, x_phs = self.stft(input.view(-1, input.size(-1))) 254 | y_mag, y_phs = self.stft(target.view(-1, target.size(-1))) 255 | 256 | # apply relevant transforms 257 | if self.scale is not None: 258 | self.fb = self.fb.to(input.device) 259 | x_mag = torch.matmul(self.fb, x_mag) 260 | y_mag = torch.matmul(self.fb, y_mag) 261 | 262 | # normalize scales 263 | if self.scale_invariance: 264 | alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1])) 265 | y_mag = y_mag * alpha.unsqueeze(-1) 266 | 267 | # compute loss terms 268 | sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0 269 | log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0 270 | lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0 271 | phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0 272 | 273 | # combine loss terms 274 | loss = ( 275 | (self.w_sc * sc_mag_loss) 276 | + (self.w_log_mag * log_mag_loss) 277 | + (self.w_lin_mag * lin_mag_loss) 278 | + (self.w_phs * phs_loss) 279 | ) 280 | 281 | loss = apply_reduction(loss, reduction=self.reduction) 282 | 283 | if self.output == "loss": 284 | return loss 285 | elif self.output == "full": 286 | return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss 287 | 288 | 289 | class MelSTFTLoss(STFTLoss): 290 | """Mel-scale STFT loss module.""" 291 | 292 | def __init__( 293 | self, 294 | sample_rate, 295 | fft_size=1024, 296 | hop_size=256, 297 | win_length=1024, 298 | window="hann_window", 299 | w_sc=1.0, 300 | w_log_mag=1.0, 301 | w_lin_mag=0.0, 302 | w_phs=0.0, 303 | n_mels=128, 304 | **kwargs, 305 | ): 306 | super(MelSTFTLoss, self).__init__( 
307 | fft_size, 308 | hop_size, 309 | win_length, 310 | window, 311 | w_sc, 312 | w_log_mag, 313 | w_lin_mag, 314 | w_phs, 315 | sample_rate, 316 | "mel", 317 | n_mels, 318 | **kwargs, 319 | ) 320 | 321 | 322 | class ChromaSTFTLoss(STFTLoss): 323 | """Chroma-scale STFT loss module.""" 324 | 325 | def __init__( 326 | self, 327 | sample_rate, 328 | fft_size=1024, 329 | hop_size=256, 330 | win_length=1024, 331 | window="hann_window", 332 | w_sc=1.0, 333 | w_log_mag=1.0, 334 | w_lin_mag=0.0, 335 | w_phs=0.0, 336 | n_chroma=12, 337 | **kwargs, 338 | ): 339 | super(ChromaSTFTLoss, self).__init__( 340 | fft_size, 341 | hop_size, 342 | win_length, 343 | window, 344 | w_sc, 345 | w_log_mag, 346 | w_lin_mag, 347 | w_phs, 348 | sample_rate, 349 | "chroma", 350 | n_chroma, 351 | **kwargs, 352 | ) 353 | 354 | 355 | class MultiResolutionSTFTLoss(torch.nn.Module): 356 | """Multi resolution STFT loss module. 357 | 358 | See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480) 359 | 360 | Args: 361 | fft_sizes (list): List of FFT sizes. 362 | hop_sizes (list): List of hop sizes. 363 | win_lengths (list): List of window lengths. 364 | window (str, optional): Window to apply before FFT, options include: 365 | 'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] 366 | Default: 'hann_window' 367 | w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 368 | w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 369 | w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 370 | w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 371 | sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None 372 | scale (str, optional): Optional frequency scaling method, options include: 373 | ['mel', 'chroma'] 374 | Default: None 375 | n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None. 376 | scale_invariance (bool, optional): Perform an optimal scaling of the target. 
Default: False 377 | """ 378 | 379 | def __init__( 380 | self, 381 | fft_sizes: List[int] = [1024, 2048, 512], 382 | hop_sizes: List[int] = [120, 240, 50], 383 | win_lengths: List[int] = [600, 1200, 240], 384 | window: str = "hann_window", 385 | w_sc: float = 1.0, 386 | w_log_mag: float = 1.0, 387 | w_lin_mag: float = 0.0, 388 | w_phs: float = 0.0, 389 | sample_rate: float = None, 390 | scale: str = None, 391 | n_bins: int = None, 392 | perceptual_weighting: bool = False, 393 | scale_invariance: bool = False, 394 | **kwargs, 395 | ): 396 | super().__init__() 397 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all 398 | self.fft_sizes = fft_sizes 399 | self.hop_sizes = hop_sizes 400 | self.win_lengths = win_lengths 401 | 402 | self.stft_losses = torch.nn.ModuleList() 403 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 404 | self.stft_losses += [ 405 | STFTLoss( 406 | fs, 407 | ss, 408 | wl, 409 | window, 410 | w_sc, 411 | w_log_mag, 412 | w_lin_mag, 413 | w_phs, 414 | sample_rate, 415 | scale, 416 | n_bins, 417 | perceptual_weighting, 418 | scale_invariance, 419 | **kwargs, 420 | ) 421 | ] 422 | 423 | def forward(self, x, y): 424 | mrstft_loss = 0.0 425 | sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], [] 426 | 427 | for f in self.stft_losses: 428 | if f.output == "full": # extract just first term 429 | tmp_loss = f(x, y) 430 | mrstft_loss += tmp_loss[0] 431 | sc_mag_loss.append(tmp_loss[1]) 432 | log_mag_loss.append(tmp_loss[2]) 433 | lin_mag_loss.append(tmp_loss[3]) 434 | phs_loss.append(tmp_loss[4]) 435 | else: 436 | mrstft_loss += f(x, y) 437 | 438 | mrstft_loss /= len(self.stft_losses) 439 | 440 | if f.output == "loss": 441 | return mrstft_loss 442 | else: 443 | return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss 444 | 445 | 446 | class RandomResolutionSTFTLoss(torch.nn.Module): 447 | """Random resolution STFT loss module. 448 | 449 | See [Steinmetz & Reiss, 2020](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf) 450 | 451 | Args: 452 | resolutions (int): Total number of STFT resolutions. 453 | min_fft_size (int): Smallest FFT size. 454 | max_fft_size (int): Largest FFT size. 455 | min_hop_size (int): Smallest hop size as proportion of window size. 456 | max_hop_size (int): Largest hop size as proportion of window size. 457 | windows (list): Window function types to sample from. 458 | randomize_rate (int): Number of forwards before STFTs are randomized. 
459 | """ 460 | 461 | def __init__( 462 | self, 463 | resolutions=3, 464 | min_fft_size=16, 465 | max_fft_size=32768, 466 | min_hop_size=0.1, 467 | max_hop_size=1.0, 468 | windows=[ 469 | "hann_window", 470 | "bartlett_window", 471 | "blackman_window", 472 | "hamming_window", 473 | "kaiser_window", 474 | ], 475 | w_sc=1.0, 476 | w_log_mag=1.0, 477 | w_lin_mag=0.0, 478 | w_phs=0.0, 479 | sample_rate=None, 480 | scale=None, 481 | n_mels=None, 482 | randomize_rate=1, 483 | **kwargs, 484 | ): 485 | super().__init__() 486 | self.resolutions = resolutions 487 | self.min_fft_size = min_fft_size 488 | self.max_fft_size = max_fft_size 489 | self.min_hop_size = min_hop_size 490 | self.max_hop_size = max_hop_size 491 | self.windows = windows 492 | self.randomize_rate = randomize_rate 493 | self.w_sc = w_sc 494 | self.w_log_mag = w_log_mag 495 | self.w_lin_mag = w_lin_mag 496 | self.w_phs = w_phs 497 | self.sample_rate = sample_rate 498 | self.scale = scale 499 | self.n_mels = n_mels 500 | 501 | self.nforwards = 0 502 | self.randomize_losses() # init the losses 503 | 504 | def randomize_losses(self): 505 | # clear the existing STFT losses 506 | self.stft_losses = torch.nn.ModuleList() 507 | for n in range(self.resolutions): 508 | frame_size = 2 ** np.random.randint( 509 | np.log2(self.min_fft_size), np.log2(self.max_fft_size) 510 | ) 511 | hop_size = int( 512 | frame_size 513 | * ( 514 | self.min_hop_size 515 | + (np.random.rand() * (self.max_hop_size - self.min_hop_size)) 516 | ) 517 | ) 518 | window_length = int(frame_size * np.random.choice([1.0, 0.5, 0.25])) 519 | window = np.random.choice(self.windows) 520 | self.stft_losses += [ 521 | STFTLoss( 522 | frame_size, 523 | hop_size, 524 | window_length, 525 | window, 526 | self.w_sc, 527 | self.w_log_mag, 528 | self.w_lin_mag, 529 | self.w_phs, 530 | self.sample_rate, 531 | self.scale, 532 | self.n_mels, 533 | ) 534 | ] 535 | 536 | def forward(self, input, target): 537 | if input.size(-1) <= self.max_fft_size: 538 | raise ValueError( 539 | f"Input length ({input.size(-1)}) must be larger than largest FFT size ({self.max_fft_size})." 540 | ) 541 | elif target.size(-1) <= self.max_fft_size: 542 | raise ValueError( 543 | f"Target length ({target.size(-1)}) must be larger than largest FFT size ({self.max_fft_size})." 544 | ) 545 | 546 | if self.nforwards % self.randomize_rate == 0: 547 | self.randomize_losses() 548 | 549 | loss = 0.0 550 | for f in self.stft_losses: 551 | loss += f(input, target) 552 | loss /= len(self.stft_losses) 553 | 554 | self.nforwards += 1 555 | 556 | return loss 557 | 558 | 559 | class SumAndDifferenceSTFTLoss(torch.nn.Module): 560 | """Sum and difference sttereo STFT loss module. 561 | 562 | See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) 563 | 564 | Args: 565 | fft_sizes (List[int]): List of FFT sizes. 566 | hop_sizes (List[int]): List of hop sizes. 567 | win_lengths (List[int]): List of window lengths. 568 | window (str, optional): Window function type. 569 | w_sum (float, optional): Weight of the sum loss component. Default: 1.0 570 | w_diff (float, optional): Weight of the difference loss component. Default: 1.0 571 | perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False 572 | mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False 573 | n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 574 | sample_rate (float, optional): Audio sample rate. 
Default: None 575 | output (str, optional): Format of the loss returned. 576 | 'loss' : Return only the raw, aggregate loss term. 577 | 'full' : Return the raw loss, plus intermediate loss terms. 578 | Default: 'loss' 579 | """ 580 | 581 | def __init__( 582 | self, 583 | fft_sizes: List[int], 584 | hop_sizes: List[int], 585 | win_lengths: List[int], 586 | window: str = "hann_window", 587 | w_sum: float = 1.0, 588 | w_diff: float = 1.0, 589 | output: str = "loss", 590 | **kwargs, 591 | ): 592 | super().__init__() 593 | self.sd = SumAndDifference() 594 | self.w_sum = w_sum 595 | self.w_diff = w_diff 596 | self.output = output 597 | self.mrstft = MultiResolutionSTFTLoss( 598 | fft_sizes, 599 | hop_sizes, 600 | win_lengths, 601 | window, 602 | **kwargs, 603 | ) 604 | 605 | def forward(self, input: torch.Tensor, target: torch.Tensor): 606 | """This loss function assumes batched input of stereo audio in the time domain. 607 | 608 | Args: 609 | input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). 610 | target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). 611 | 612 | Returns: 613 | loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'. 614 | loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): 615 | Aggregate and intermediate loss terms. Only returned if output='full'. 616 | """ 617 | assert input.shape == target.shape # must have same shape 618 | bs, chs, seq_len = input.size() 619 | 620 | # compute sum and difference signals for both 621 | input_sum, input_diff = self.sd(input) 622 | target_sum, target_diff = self.sd(target) 623 | 624 | # compute error in STFT domain 625 | sum_loss = self.mrstft(input_sum, target_sum) 626 | diff_loss = self.mrstft(input_diff, target_diff) 627 | loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2 628 | 629 | if self.output == "loss": 630 | return loss 631 | elif self.output == "full": 632 | return loss, sum_loss, diff_loss 633 | --------------------------------------------------------------------------------